use std::sync::Arc;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use super::cache::{
CacheKey, EXPLICIT_TRUST_INSTRUCTION, PromptCacheTelemetry, build_shared_prefix,
};
#[cfg(test)]
use super::helpers::HelperParams;
use super::helpers::{HelperContext, HelperOutput, MemoryHandle, run_helper_with};
use super::pipeline::{HelperOutputRef, Pipeline, Stage};
/// Default cap, counted in `char`s, on how much incoming content is embedded in each LLM
/// stage prompt.
///
/// `IngestExecutor::run` falls back to this value when no cap was configured through
/// `IngestExecutor::with_max_content_chars`.
pub const DEFAULT_MULTISTEP_MAX_CONTENT_CHARS: usize = 1500;
/// Record of one executed pipeline stage, as collected into [`ExecutionTrace::stages`].
///
/// Serialized with an internal `"stage_type"` tag whose value is `"helper"` or `"llm_call"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "stage_type", rename_all = "snake_case")]
pub enum StageOutcome {
    /// A helper (non-LLM) stage.
    Helper {
        /// Position of this stage in `Pipeline::stages`.
        index: usize,
        /// Helper kind name as returned by `HelperKind::as_str` (e.g. `"fts_classifier"`).
        helper: String,
        /// The helper's textual summary.
        summary: String,
        /// The helper's JSON payload; the same payload is rendered into LLM trust slots.
        payload: Value,
        /// Byte length of the full incoming content passed to `run`.
        content_bytes: usize,
    },
    /// A stage that was rendered into a prompt and sent to the LLM backend.
    LlmCall {
        /// Position of this stage in `Pipeline::stages`.
        index: usize,
        /// The stage label from the pipeline definition.
        label: String,
        /// The exact prompt text that was dispatched.
        prompt: String,
        /// The string produced by `CacheKey::as_hex` for the shared prompt prefix.
        cache_key: String,
        /// Backend reply parsed as JSON, or `{"raw": <text>}` when it was not valid JSON.
        response: Value,
        /// Byte length of the (possibly truncated) content view embedded in `prompt`.
        content_bytes: usize,
        /// Whether the content view was cut to fit the per-stage character cap.
        content_truncated: bool,
    },
}
/// Everything observed while running one pipeline through `IngestExecutor::run`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionTrace {
    /// The pipeline's `variant_tag()`, e.g. `"two_phase"` or `"four_step"`.
    pub variant: String,
    /// One outcome per executed stage: all helper stages first, then all LLM stages, each
    /// group in pipeline order. Use each outcome's `index` to map back to the pipeline.
    pub stages: Vec<StageOutcome>,
    /// Sorted, de-duplicated prompt-cache keys reported for this trace; see
    /// `IngestExecutor::run` for how they are gathered.
    pub distinct_cache_keys: Vec<String>,
    /// Whether the executor reported that the recorded LLM stages all shared one cache key.
    pub prompt_cache_consistent: bool,
    /// The last LLM stage's response; if no LLM stage ran, the payload of the
    /// highest-indexed helper stage; failing both, an empty JSON object.
    pub final_output: Value,
    /// Content size in bytes for each entry of `stages`, in the same order: the full
    /// incoming content for helper stages, the embedded content view for LLM stages.
    pub bytes_per_stage: Vec<usize>,
}
/// Failures that abort `IngestExecutor::run`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExecutorError {
    /// An LLM stage's trust input pointed at a stage index that produced no helper output.
    InvalidTrustSlot {
        stage_index: usize,
        label: String,
    },
    /// The LLM backend reported an error; carries the backend's message.
    LlmDispatch(String),
    /// The pipeline had no stages at all.
    EmptyPipeline,
}
impl std::fmt::Display for ExecutorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EmptyPipeline => f.write_str("pipeline has no stages"),
            Self::LlmDispatch(reason) => write!(f, "llm dispatch failed: {}", reason),
            Self::InvalidTrustSlot { stage_index, label } => write!(
                f,
                "invalid trust slot: stage_index={} (label={})",
                stage_index, label
            ),
        }
    }
}
impl std::error::Error for ExecutorError {}
/// Backend that turns one fully rendered prompt into raw model output.
///
/// `IngestExecutor::run` calls `dispatch` once per LLM stage. The `Send + Sync` bound lets an
/// implementation be shared behind `Arc` across threads.
pub trait LlmDispatch: Send + Sync {
    /// Sends `prompt` and returns the backend's raw text, or an error message on failure.
    ///
    /// The executor tries to parse the returned text as JSON and stores it as
    /// `{"raw": text}` when parsing fails.
    fn dispatch(&self, prompt: &str) -> Result<String, String>;
}
/// [`LlmDispatch`] implementation backed by a shared `crate::llm::OllamaClient`.
pub struct OllamaDispatch {
    client: Arc<crate::llm::OllamaClient>,
}
impl OllamaDispatch {
    /// Wraps an existing client handle; no request is issued here.
    #[must_use]
    pub fn new(client: Arc<crate::llm::OllamaClient>) -> Self {
        OllamaDispatch { client }
    }
}
impl LlmDispatch for OllamaDispatch {
    /// Forwards `prompt` to `OllamaClient::generate` (second argument `None`) and turns any
    /// client error into its `Display` text.
    fn dispatch(&self, prompt: &str) -> Result<String, String> {
        let generated = self.client.generate(prompt, None);
        generated.map_err(|client_err| client_err.to_string())
    }
}
/// Scripted [`LlmDispatch`] that replays a fixed queue of results in FIFO order.
///
/// Each `dispatch` call ignores the prompt and pops the next queued result. Once the queue is
/// empty, every call returns `Err("mock: queue exhausted")`.
pub struct MockLlmDispatch {
    // FIFO of scripted results; `VecDeque` makes each pop O(1) instead of shifting a `Vec`.
    responses: std::sync::Mutex<std::collections::VecDeque<Result<String, String>>>,
}
impl MockLlmDispatch {
    /// Builds a mock that returns `responses` in order, one per `dispatch` call.
    #[must_use]
    pub fn new(responses: Vec<Result<String, String>>) -> Self {
        Self {
            responses: std::sync::Mutex::new(responses.into()),
        }
    }
}
impl LlmDispatch for MockLlmDispatch {
    fn dispatch(&self, _prompt: &str) -> Result<String, String> {
        // The queue holds plain data with no invariants a panicking holder could break, so a
        // poisoned lock is still safe to keep using instead of cascading the panic.
        let mut queue = self
            .responses
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        queue
            .pop_front()
            .unwrap_or_else(|| Err("mock: queue exhausted".to_string()))
    }
}
/// Runs a multistep ingest [`Pipeline`]: helper stages locally, LLM stages through `D`.
pub struct IngestExecutor<D: LlmDispatch + ?Sized> {
    /// Backend used for every LLM stage (`?Sized` permits e.g. `Arc<dyn LlmDispatch>`).
    dispatch: Arc<D>,
    /// Shared sink that receives one cache key per dispatched LLM stage.
    telemetry: Arc<PromptCacheTelemetry>,
    /// Per-stage content cap in chars; `None` means `DEFAULT_MULTISTEP_MAX_CONTENT_CHARS`.
    max_content_chars: Option<usize>,
    /// Debug builds only: address of the content buffer each helper stage read. `run`
    /// appends to it and nothing in this file clears it.
    helper_content_ptrs: Arc<std::sync::Mutex<Vec<usize>>>,
}
impl<D: LlmDispatch + ?Sized> IngestExecutor<D> {
#[must_use]
pub fn new(dispatch: Arc<D>) -> Self {
Self {
dispatch,
telemetry: Arc::new(PromptCacheTelemetry::new()),
max_content_chars: None,
helper_content_ptrs: Arc::new(std::sync::Mutex::new(Vec::new())),
}
}
#[must_use]
pub fn with_max_content_chars(mut self, cap: usize) -> Self {
self.max_content_chars = Some(cap);
self
}
#[must_use]
pub fn telemetry(&self) -> Arc<PromptCacheTelemetry> {
Arc::clone(&self.telemetry)
}
#[doc(hidden)]
#[must_use]
pub fn helper_content_ptrs(&self) -> Vec<usize> {
self.helper_content_ptrs
.lock()
.map(|g| g.clone())
.unwrap_or_default()
}
pub fn run(
&self,
pipeline: &Pipeline,
content: &str,
candidates: &[MemoryHandle],
content_embedding: Option<&[f32]>,
namespace: Option<&str>,
) -> Result<ExecutionTrace, ExecutorError> {
if pipeline.stages.is_empty() {
return Err(ExecutorError::EmptyPipeline);
}
let mut helper_outputs: Vec<Option<HelperOutput>> = vec![None; pipeline.stages.len()];
let mut stage_outcomes: Vec<StageOutcome> = Vec::with_capacity(pipeline.stages.len());
let mut bytes_per_stage: Vec<usize> = Vec::with_capacity(pipeline.stages.len());
let helper_ctx = HelperContext::new(content, candidates, content_embedding, namespace);
#[cfg(debug_assertions)]
let content_ptr_for_test = content.as_ptr() as usize;
for (idx, stage) in pipeline.stages.iter().enumerate() {
if let Stage::Helper { kind, params } = stage {
#[cfg(debug_assertions)]
{
let effective_ptr = helper_ctx.effective_content(params).as_ptr() as usize;
if let Ok(mut g) = self.helper_content_ptrs.lock() {
g.push(effective_ptr);
}
if params.content.is_empty() {
debug_assert_eq!(effective_ptr, content_ptr_for_test);
}
}
let out = run_helper_with(*kind, params, &helper_ctx);
bytes_per_stage.push(content.len());
stage_outcomes.push(StageOutcome::Helper {
index: idx,
helper: out.kind.as_str().to_string(),
summary: out.summary.clone(),
payload: out.payload.clone(),
content_bytes: content.len(),
});
helper_outputs[idx] = Some(out);
}
}
let prefix = build_shared_prefix(pipeline.variant_tag(), &pipeline.system_prompt);
let cache_key = CacheKey::from_prefix(&prefix);
let llm_cap = self
.max_content_chars
.unwrap_or(DEFAULT_MULTISTEP_MAX_CONTENT_CHARS);
let mut last_llm_response: Option<Value> = None;
for (idx, stage) in pipeline.stages.iter().enumerate() {
let Stage::LlmCall {
prompt_template,
trust_inputs,
output_schema,
label,
} = stage
else {
continue;
};
let trust_block = render_trust_inputs(trust_inputs, &helper_outputs)?;
let (content_view, truncated) = truncate_content_for_llm(content, llm_cap);
let stage_tail = format!(
"\n[STAGE label={label} index={idx}]\n\
[INCOMING CONTENT]\n{content_view}\n\
[TRUST INPUTS]\n{trust_block}\n\
[TASK]\n{prompt_template}\n\
[OUTPUT SCHEMA]\n{schema}\n",
schema = serde_json::to_string(output_schema).unwrap_or_else(|_| "{}".to_string()),
);
let prompt = format!("{prefix}{stage_tail}");
self.telemetry.record(cache_key.clone());
let response_text = self
.dispatch
.dispatch(&prompt)
.map_err(ExecutorError::LlmDispatch)?;
let response_value = match serde_json::from_str::<Value>(&response_text) {
Ok(v) => v,
Err(_) => json!({ "raw": response_text }),
};
let content_bytes = content_view.len();
bytes_per_stage.push(content_bytes);
stage_outcomes.push(StageOutcome::LlmCall {
index: idx,
label: label.clone(),
prompt,
cache_key: cache_key.as_hex().to_string(),
response: response_value.clone(),
content_bytes,
content_truncated: truncated,
});
last_llm_response = Some(response_value);
}
let distinct_cache_keys: Vec<String> = {
let mut seen: Vec<String> =
self.telemetry.snapshot().into_iter().map(|k| k.0).collect();
seen.sort();
seen.dedup();
seen
};
let prompt_cache_consistent = self.telemetry.all_keys_match();
let final_output = last_llm_response.unwrap_or_else(|| {
helper_outputs
.iter()
.rev()
.find_map(|o| o.as_ref().map(|h| h.payload.clone()))
.unwrap_or_else(|| json!({}))
});
Ok(ExecutionTrace {
variant: pipeline.variant_tag().to_string(),
stages: stage_outcomes,
distinct_cache_keys,
prompt_cache_consistent,
final_output,
bytes_per_stage,
})
}
}
/// Limits `content` to at most `cap` chars for embedding in an LLM prompt.
///
/// Returns the (possibly shortened) text and whether it was shortened. A cut keeps the first
/// `cap` chars and appends ` [...truncated N chars]`, where `N` is the number of chars
/// dropped. `cap == 0` means "no limit". Content that fits is returned borrowed, without
/// copying.
fn truncate_content_for_llm(content: &str, cap: usize) -> (std::borrow::Cow<'_, str>, bool) {
    use std::borrow::Cow;

    if cap == 0 {
        return (Cow::Borrowed(content), false);
    }
    // Byte offset of the first char beyond the cap; `None` means everything already fits.
    match content.char_indices().nth(cap) {
        None => (Cow::Borrowed(content), false),
        Some((cut, _)) => {
            let (kept, dropped) = content.split_at(cut);
            let dropped_chars = dropped.chars().count();
            (
                Cow::Owned(format!("{kept} [...truncated {dropped_chars} chars]")),
                true,
            )
        }
    }
}
/// Renders the `[TRUST INPUTS]` section of an LLM stage prompt.
///
/// With no inputs, returns a `(none — but: …)` note that still carries
/// `EXPLICIT_TRUST_INSTRUCTION`. Otherwise it emits the instruction followed by one
/// `<<TRUST label=… helper=…>>` block per input, containing the referenced helper's
/// pretty-printed JSON payload (empty if serialization fails).
///
/// # Errors
/// [`ExecutorError::InvalidTrustSlot`] if an input's `stage_index` is out of range or points
/// at a stage with no helper output.
fn render_trust_inputs(
    inputs: &[HelperOutputRef],
    helper_outputs: &[Option<HelperOutput>],
) -> Result<String, ExecutorError> {
    use std::fmt::Write as _;

    if inputs.is_empty() {
        return Ok(format!("(none — but: {EXPLICIT_TRUST_INSTRUCTION})"));
    }

    let mut rendered = format!("{EXPLICIT_TRUST_INSTRUCTION}\n\n");
    for slot in inputs {
        let Some(helper) = helper_outputs
            .get(slot.stage_index)
            .and_then(Option::as_ref)
        else {
            return Err(ExecutorError::InvalidTrustSlot {
                stage_index: slot.stage_index,
                label: slot.label.clone(),
            });
        };
        let body = serde_json::to_string_pretty(&helper.payload).unwrap_or_default();
        // Writing into a `String` cannot fail.
        let _ = write!(
            rendered,
            "<<TRUST label={} helper={}>>\n{}\n<<END TRUST>>\n\n",
            slot.label,
            helper.kind.as_str(),
            body
        );
    }
    Ok(rendered)
}
#[cfg(test)]
mod tests {
    use super::super::helpers::HelperKind;
    use super::super::pipeline::PipelineVariant;
    use super::*;
    use crate::multistep_ingest::pipeline::{four_step_default, two_phase_default};

    /// Candidate memory with the given id and body, and no embedding or namespace.
    fn handle(id: &str, body: &str) -> MemoryHandle {
        MemoryHandle {
            id: id.to_string(),
            body: body.to_string(),
            embedding: None,
            namespace: None,
        }
    }

    /// Executor whose mock backend replays `replies`, all as `Ok`, in order.
    fn executor_replaying(replies: &[&str]) -> IngestExecutor<MockLlmDispatch> {
        let scripted: Vec<Result<String, String>> =
            replies.iter().map(|reply| Ok((*reply).to_string())).collect();
        IngestExecutor::new(Arc::new(MockLlmDispatch::new(scripted)))
    }

    /// Ad-hoc `TwoPhase`-tagged pipeline with the given stages and system prompt.
    fn custom_pipeline(stages: Vec<Stage>, system_prompt: &str) -> Pipeline {
        Pipeline {
            variant: PipelineVariant::TwoPhase,
            stages,
            system_prompt: system_prompt.to_string(),
        }
    }

    #[test]
    fn helper_then_llm_runs_in_order_and_renders_trust_slot() {
        let exec = executor_replaying(&[r#"{"title":"T","summary":"S","tags":[],"atoms":[]}"#]);
        let candidates = [handle("c1", "a quick fox")];
        let trace = exec
            .run(
                &two_phase_default(),
                "the quick brown fox",
                &candidates,
                None,
                Some("global"),
            )
            .expect("pipeline runs");

        assert!(matches!(trace.stages[0], StageOutcome::Helper { .. }));
        assert!(matches!(trace.stages[1], StageOutcome::Helper { .. }));
        let StageOutcome::LlmCall { prompt, .. } = &trace.stages[2] else {
            panic!("stage 2 must be an LLM call");
        };
        assert!(
            prompt.contains(EXPLICIT_TRUST_INSTRUCTION),
            "LLM prompt must carry the explicit-trust instruction verbatim"
        );
        let cites_helper = ["jaccard_overlap", "fts_classifier"]
            .iter()
            .any(|kind| prompt.contains(*kind));
        assert!(
            cites_helper,
            "LLM prompt must cite a helper kind from the trust slots"
        );
    }

    #[test]
    fn two_phase_pipeline_produces_structured_output() {
        let exec = executor_replaying(&[
            r#"{"title":"T","summary":"S","tags":["a"],"atoms":["one","two"]}"#,
        ]);
        let trace = exec
            .run(&two_phase_default(), "anything", &[], None, None)
            .expect("ok");
        assert_eq!(trace.variant, "two_phase");
        assert_eq!(trace.final_output["title"], "T");
        assert_eq!(
            trace.final_output["atoms"].as_array().map(Vec::len),
            Some(2)
        );
    }

    #[test]
    fn four_step_pipeline_produces_structured_output() {
        let exec = executor_replaying(&[
            r#"{"fact_kind":"declarative","confidence":0.9}"#,
            r#"{"entities":["a"],"claims":["c"],"relations":[]}"#,
            r#"{"title":"X","summary":"Y","tags":[],"proposed_links":[]}"#,
        ]);
        let pipeline = four_step_default();
        let trace = exec
            .run(&pipeline, "Paris is the capital of France.", &[], None, None)
            .expect("ok");
        assert_eq!(trace.variant, "four_step");
        let llm_stages = trace
            .stages
            .iter()
            .filter(|outcome| matches!(outcome, StageOutcome::LlmCall { .. }))
            .count();
        assert_eq!(llm_stages, 3);
        assert_eq!(trace.final_output["title"], "X");
    }

    #[test]
    fn prompt_cache_key_is_consistent_across_stages_within_a_run() {
        let exec = executor_replaying(&[
            r#"{"fact_kind":"declarative","confidence":0.5}"#,
            r#"{"entities":[],"claims":[],"relations":[]}"#,
            r#"{"title":"T","summary":"S","tags":[],"proposed_links":[]}"#,
        ]);
        let trace = exec
            .run(&four_step_default(), "content", &[], None, None)
            .expect("ok");
        assert!(
            trace.prompt_cache_consistent,
            "every LLM stage within a run must share the cache key"
        );
        assert_eq!(
            trace.distinct_cache_keys.len(),
            1,
            "exactly one distinct cache key for a single-variant run"
        );
    }

    #[test]
    fn explicit_trust_instruction_appears_in_every_llm_prompt() {
        let exec = executor_replaying(&["{}"; 3]);
        let trace = exec
            .run(&four_step_default(), "content", &[], None, None)
            .expect("ok");
        let prompts = trace.stages.iter().filter_map(|outcome| match outcome {
            StageOutcome::LlmCall { prompt, .. } => Some(prompt),
            StageOutcome::Helper { .. } => None,
        });
        for prompt in prompts {
            assert!(
                prompt.contains(EXPLICIT_TRUST_INSTRUCTION),
                "every LLM prompt must carry the explicit-trust phrase"
            );
        }
    }

    #[test]
    fn empty_pipeline_returns_structured_error() {
        let exec = executor_replaying(&[]);
        let outcome = exec.run(&custom_pipeline(Vec::new(), ""), "x", &[], None, None);
        let err = outcome.expect_err("empty pipeline should error");
        assert!(matches!(err, ExecutorError::EmptyPipeline));
    }

    #[test]
    fn helper_only_pipeline_uses_last_helper_payload_as_final_output() {
        let exec = executor_replaying(&[]);
        let classifier = Stage::Helper {
            kind: HelperKind::FtsClassifier,
            params: HelperParams::default(),
        };
        let trace = exec
            .run(
                &custom_pipeline(vec![classifier], ""),
                "first, do X. then do Y.",
                &[],
                None,
                None,
            )
            .expect("ok");
        assert_eq!(trace.final_output["helper"], "fts_classifier");
        assert_eq!(trace.final_output["fact_kind"], "procedural");
    }

    #[test]
    fn invalid_trust_slot_index_returns_structured_error() {
        let exec = executor_replaying(&["{}"]);
        let dangling = HelperOutputRef {
            stage_index: 99,
            label: "missing".to_string(),
        };
        let broken = Stage::LlmCall {
            prompt_template: "anything".to_string(),
            trust_inputs: vec![dangling],
            output_schema: json!({}),
            label: "broken".to_string(),
        };
        let err = exec
            .run(&custom_pipeline(vec![broken], "x"), "y", &[], None, None)
            .expect_err("invalid trust slot must error");
        assert!(matches!(err, ExecutorError::InvalidTrustSlot { .. }));
    }

    #[test]
    fn telemetry_records_one_key_per_llm_stage() {
        let exec = executor_replaying(&["{}", "{}", "{}"]);
        let telemetry = exec.telemetry();
        exec.run(&four_step_default(), "content", &[], None, None)
            .unwrap();
        assert_eq!(telemetry.len(), 3, "four-step has 3 LLM stages");
        assert!(telemetry.all_keys_match());
    }
}