Skip to main content

ai_memory/multistep_ingest/
pipeline.rs

1// Copyright 2026 AlphaOne LLC
2// SPDX-License-Identifier: Apache-2.0
3
4//! Form 3 — pipeline descriptor + the two default Batman exemplars.
5//!
6//! A [`Pipeline`] is an ordered list of [`Stage`]s. Helpers go first
7//! (deterministic, parallel-where-independent); LLM stages follow with
8//! explicit trust slots pointing back at the helper outputs.
9
10use crate::models::field_names;
11use serde::{Deserialize, Serialize};
12use serde_json::{Value, json};
13
14use super::helpers::{HelperKind, HelperParams};
15
16/// Named pipeline variant exposed at the MCP tool surface. Operators
17/// pick a variant via `pipeline_variant: "two_phase" | "four_step"` and
18/// can override the descriptor entirely via `pipeline_override` (a
19/// JSON-encoded [`Pipeline`]).
20#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
21#[serde(rename_all = "snake_case")]
22pub enum PipelineVariant {
23    /// Understand-Anything two-phase exemplar.
24    TwoPhase,
25    /// OpenKB four-step exemplar.
26    FourStep,
27}
28
29impl PipelineVariant {
30    /// Snake-case discriminator used in the shared-prefix builder and
31    /// the JSON trace.
32    #[must_use]
33    pub const fn as_str(self) -> &'static str {
34        match self {
35            Self::TwoPhase => "two_phase",
36            Self::FourStep => "four_step",
37        }
38    }
39
40    /// Parse a variant tag (snake_case). Returns `None` for unknown
41    /// inputs so the caller can surface a structured validation error.
42    #[must_use]
43    pub fn from_str(s: &str) -> Option<Self> {
44        match s {
45            "two_phase" => Some(Self::TwoPhase),
46            "four_step" => Some(Self::FourStep),
47            _ => None,
48        }
49    }
50}
51
52/// Reference to a prior helper output, surfaced to an LLM stage via its
53/// explicit-trust slot. `stage_index` is the zero-based position of the
54/// helper stage that produced the output; the executor resolves these
55/// against its in-flight context.
56#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
57pub struct HelperOutputRef {
58    /// Zero-based index of the producing helper stage.
59    pub stage_index: usize,
60    /// Label for the slot — appears in the LLM prompt so the model can
61    /// distinguish multiple trust slots (`"overlap"`, `"classification"`,
62    /// etc.).
63    pub label: String,
64}
65
66/// A pipeline stage. Helpers run first; LLM stages follow with trust
67/// slots resolved against the helper outputs.
68#[derive(Debug, Clone, Serialize, Deserialize)]
69#[serde(tag = "type", rename_all = "snake_case")]
70pub enum Stage {
71    /// Deterministic helper stage. Runs synchronously, produces a JSON
72    /// payload, no LLM involvement.
73    Helper {
74        /// Which deterministic helper to run.
75        kind: HelperKind,
76        /// Helper parameters. The executor merges in the run-time
77        /// `content` / `candidates` if the descriptor omitted them.
78        #[serde(default)]
79        params: HelperParams,
80    },
81    /// LLM call stage. The prompt template is appended to the SHARED
82    /// PREFIX from [`super::cache::build_shared_prefix`]; trust slots
83    /// are rendered verbatim into the prompt.
84    LlmCall {
85        /// Free-form prompt body (the stage-specific tail of the
86        /// shared-prefix sandwich).
87        prompt_template: String,
88        /// Trust slots — references to prior helper outputs that get
89        /// rendered into the prompt under the explicit-trust banner.
90        #[serde(default)]
91        trust_inputs: Vec<HelperOutputRef>,
92        /// Output schema hint forwarded to the LLM and echoed in the
93        /// trace so callers can route the parsed JSON.
94        #[serde(default)]
95        output_schema: Value,
96        /// Stage label — surfaces in the trace and the LLM prompt.
97        #[serde(default)]
98        label: String,
99    },
100}
101
102/// A pipeline descriptor. Each stage runs in declaration order; the
103/// executor enforces "helpers before LLM stages" so the trust slots
104/// are always resolvable.
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct Pipeline {
107    /// Variant tag (drives the shared-prefix builder).
108    pub variant: PipelineVariant,
109    /// Stages in execution order.
110    pub stages: Vec<Stage>,
111    /// System prompt shared across every LLM stage in this pipeline.
112    /// Goes into the prompt-cache-friendly shared prefix; do NOT put
113    /// per-stage variation here.
114    #[serde(default)]
115    pub system_prompt: String,
116}
117
118impl Pipeline {
119    /// Variant-tag accessor used by the executor when assembling the
120    /// shared prefix.
121    #[must_use]
122    pub fn variant_tag(&self) -> &'static str {
123        self.variant.as_str()
124    }
125}
126
127/// Understand-Anything-style two-phase pipeline.
128///
129/// Phase 1 (Helper): FTS overlap + Jaccard pre-filter against existing
130/// memories. Both helpers run in the same stage chain; the executor
131/// parallelises them because they have no inter-dependency.
132///
133/// Phase 2 (LLM): synthesise summary + tags + atoms with explicit trust
134/// citing the helper output.
135#[must_use]
136pub fn two_phase_default() -> Pipeline {
137    Pipeline {
138        variant: PipelineVariant::TwoPhase,
139        system_prompt: "Synthesise the incoming content into a structured \
140                        memory envelope with title, summary, tags, and atoms. \
141                        Cite the helper output verbatim when it carries \
142                        candidate overlaps or classifications."
143            .to_string(),
144        stages: vec![
145            Stage::Helper {
146                kind: HelperKind::FtsClassifier,
147                params: HelperParams::default(),
148            },
149            Stage::Helper {
150                kind: HelperKind::JaccardOverlap,
151                params: HelperParams::default(),
152            },
153            Stage::LlmCall {
154                label: "synthesise".to_string(),
155                prompt_template: "Produce a JSON object {title, summary, \
156                                  tags[], atoms[]} where each atom is a \
157                                  standalone fact distilled from the content. \
158                                  The trust slots below carry the \
159                                  pre-computed classifier label and the top \
160                                  candidate overlaps."
161                    .to_string(),
162                trust_inputs: vec![
163                    HelperOutputRef {
164                        stage_index: 0,
165                        label: "classification".to_string(),
166                    },
167                    HelperOutputRef {
168                        stage_index: 1,
169                        label: "overlap".to_string(),
170                    },
171                ],
172                output_schema: json!({
173                    "type": "object",
174                    "required": ["title", "summary", "tags", "atoms"],
175                    (field_names::PROPERTIES): {
176                        "title": {"type": "string"},
177                        "summary": {"type": "string"},
178                        "tags": {"type": "array", "items": {"type": "string"}},
179                        "atoms": {"type": "array", "items": {"type": "string"}}
180                    }
181                }),
182            },
183        ],
184    }
185}
186
187/// OpenKB-style four-step pipeline.
188///
189/// Stage 1 (Helper): `load_context` — assemble candidate set via FTS
190/// classifier + Jaccard overlap. (Two helpers under one logical "load"
191/// step.)
192///
193/// Stage 2 (LLM): `classify` — what kind of fact is this. Trust slot
194/// carries the FTS classifier label.
195///
196/// Stage 3 (LLM): `enrich` — extract entities, claims, relations.
197/// Trust slot carries the overlap output so the LLM doesn't re-rank.
198///
199/// Stage 4 (LLM): `emit` — final structured memory output.
200///
201/// All LLM stages share the SAME system prompt prefix so the
202/// prompt-cache key stays stable across stages within a run.
203#[must_use]
204pub fn four_step_default() -> Pipeline {
205    Pipeline {
206        variant: PipelineVariant::FourStep,
207        system_prompt: "Run the OpenKB four-step ingest pipeline. Each \
208                        stage produces a JSON object that feeds the next \
209                        stage. Trust the helper output verbatim — do not \
210                        re-derive classifications or overlap scores."
211            .to_string(),
212        stages: vec![
213            Stage::Helper {
214                kind: HelperKind::FtsClassifier,
215                params: HelperParams::default(),
216            },
217            Stage::Helper {
218                kind: HelperKind::JaccardOverlap,
219                params: HelperParams::default(),
220            },
221            Stage::LlmCall {
222                label: "classify".to_string(),
223                prompt_template: "Classify this content. Return JSON \
224                                  {fact_kind, confidence}."
225                    .to_string(),
226                trust_inputs: vec![HelperOutputRef {
227                    stage_index: 0,
228                    label: HelperKind::FtsClassifier.as_str().to_string(),
229                }],
230                output_schema: json!({
231                    "type": "object",
232                    "required": ["fact_kind", field_names::CONFIDENCE],
233                    (field_names::PROPERTIES): {
234                        "fact_kind": {
235                            "type": "string",
236                            "enum": ["procedural", "declarative", "episodic"]
237                        },
238                        (field_names::CONFIDENCE): {
239                            "type": "number",
240                            "minimum": 0.0,
241                            "maximum": 1.0
242                        }
243                    }
244                }),
245            },
246            Stage::LlmCall {
247                label: "enrich".to_string(),
248                prompt_template: "Extract entities, claims, and relations \
249                                  from the content. Return JSON {entities[], \
250                                  claims[], relations[]}."
251                    .to_string(),
252                trust_inputs: vec![HelperOutputRef {
253                    stage_index: 1,
254                    label: "overlap".to_string(),
255                }],
256                output_schema: json!({
257                    "type": "object",
258                    "required": ["entities", "claims", "relations"],
259                    (field_names::PROPERTIES): {
260                        "entities": {"type": "array", "items": {"type": "string"}},
261                        "claims": {"type": "array", "items": {"type": "string"}},
262                        "relations": {"type": "array", "items": {"type": "object"}}
263                    }
264                }),
265            },
266            Stage::LlmCall {
267                label: "emit".to_string(),
268                prompt_template: "Emit the final memory envelope. Return \
269                                  JSON {title, summary, tags[], \
270                                  proposed_links[]}."
271                    .to_string(),
272                trust_inputs: vec![
273                    HelperOutputRef {
274                        stage_index: 0,
275                        label: HelperKind::FtsClassifier.as_str().to_string(),
276                    },
277                    HelperOutputRef {
278                        stage_index: 1,
279                        label: "overlap".to_string(),
280                    },
281                ],
282                output_schema: json!({
283                    "type": "object",
284                    "required": ["title", "summary", "tags", "proposed_links"],
285                    (field_names::PROPERTIES): {
286                        "title": {"type": "string"},
287                        "summary": {"type": "string"},
288                        "tags": {"type": "array", "items": {"type": "string"}},
289                        "proposed_links": {"type": "array", "items": {"type": "object"}}
290                    }
291                }),
292            },
293        ],
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// True for deterministic helper stages.
    fn is_helper(stage: &Stage) -> bool {
        matches!(stage, Stage::Helper { .. })
    }

    /// True for LLM call stages.
    fn is_llm(stage: &Stage) -> bool {
        matches!(stage, Stage::LlmCall { .. })
    }

    /// Asserts the executor's structural invariants for a descriptor: no
    /// helper appears after the first LLM stage, and every trust slot points
    /// backwards at a helper stage (so it is always resolvable).
    fn assert_trust_slots_resolvable(p: &Pipeline) {
        let first_llm = p.stages.iter().position(is_llm).unwrap_or(p.stages.len());
        assert!(
            p.stages[first_llm..].iter().all(|s| !is_helper(s)),
            "helper stages must all precede the first LLM stage"
        );
        for (idx, stage) in p.stages.iter().enumerate() {
            let Stage::LlmCall {
                trust_inputs,
                label,
                ..
            } = stage
            else {
                continue;
            };
            for slot in trust_inputs {
                assert!(
                    slot.stage_index < idx,
                    "stage `{}`: trust slot `{}` must reference an earlier stage",
                    label,
                    slot.label
                );
                assert!(
                    is_helper(&p.stages[slot.stage_index]),
                    "stage `{}`: trust slot `{}` must reference a helper stage",
                    label,
                    slot.label
                );
            }
        }
    }

    #[test]
    fn pipeline_variant_round_trip_via_str() {
        assert_eq!(
            PipelineVariant::from_str("two_phase"),
            Some(PipelineVariant::TwoPhase)
        );
        assert_eq!(
            PipelineVariant::from_str("four_step"),
            Some(PipelineVariant::FourStep)
        );
        assert_eq!(PipelineVariant::from_str("nonsense"), None);
        for v in [PipelineVariant::TwoPhase, PipelineVariant::FourStep] {
            assert_eq!(PipelineVariant::from_str(v.as_str()), Some(v));
        }
    }

    #[test]
    fn pipeline_variant_serde_tag_matches_as_str() {
        // `as_str` feeds the trace/prefix while serde drives the MCP surface;
        // both must agree on the tag or `pipeline_override` round-trips break.
        for v in [PipelineVariant::TwoPhase, PipelineVariant::FourStep] {
            let encoded = serde_json::to_value(v).expect("serialises");
            assert_eq!(encoded, Value::String(v.as_str().to_owned()));
            let decoded: PipelineVariant =
                serde_json::from_value(json!(v.as_str())).expect("deserialises");
            assert_eq!(decoded, v);
        }
    }

    #[test]
    fn two_phase_default_has_two_phases() {
        let p = two_phase_default();
        assert_eq!(p.variant, PipelineVariant::TwoPhase);
        // Two helpers (FTS + Jaccard) feed a single LLM synthesise stage.
        assert_eq!(p.stages.iter().filter(|s| is_helper(s)).count(), 2);
        assert_eq!(p.stages.iter().filter(|s| is_llm(s)).count(), 1);
    }

    #[test]
    fn four_step_default_has_four_logical_stages() {
        let p = four_step_default();
        assert_eq!(p.variant, PipelineVariant::FourStep);
        // Stage 1 (Helper) decomposes into 2 helpers; stages 2/3/4 are
        // three LLM calls, in classify → enrich → emit order.
        assert_eq!(p.stages.iter().filter(|s| is_helper(s)).count(), 2);
        let llm_labels: Vec<&str> = p
            .stages
            .iter()
            .filter_map(|s| match s {
                Stage::LlmCall { label, .. } => Some(label.as_str()),
                Stage::Helper { .. } => None,
            })
            .collect();
        assert_eq!(llm_labels, ["classify", "enrich", "emit"]);
    }

    #[test]
    fn two_phase_llm_stage_references_both_helpers() {
        let p = two_phase_default();
        let Stage::LlmCall { trust_inputs, .. } = p.stages.last().unwrap() else {
            panic!("last stage should be LLM call");
        };
        assert_eq!(trust_inputs.len(), 2);
        assert_eq!(trust_inputs[0].stage_index, 0);
        assert_eq!(trust_inputs[1].stage_index, 1);
    }

    #[test]
    fn four_step_llm_stages_each_have_trust_inputs() {
        let p = four_step_default();
        for stage in &p.stages {
            if let Stage::LlmCall { trust_inputs, .. } = stage {
                assert!(
                    !trust_inputs.is_empty(),
                    "every LLM stage must have at least one trust input"
                );
            }
        }
    }

    #[test]
    fn default_pipelines_have_resolvable_trust_slots() {
        assert_trust_slots_resolvable(&two_phase_default());
        assert_trust_slots_resolvable(&four_step_default());
    }

    #[test]
    fn pipeline_descriptor_round_trips_through_serde() {
        for p in [two_phase_default(), four_step_default()] {
            let s = serde_json::to_string(&p).expect("serialises");
            let back: Pipeline = serde_json::from_str(&s).expect("deserialises");
            assert_eq!(back.variant, p.variant);
            assert_eq!(back.stages.len(), p.stages.len());
            assert_eq!(back.system_prompt, p.system_prompt);
            // Compare full JSON values (map equality ignores key order) so a
            // field silently dropped on the way back is caught.
            assert_eq!(
                serde_json::to_value(&back).expect("re-serialises"),
                serde_json::to_value(&p).expect("serialises")
            );
        }
    }

    #[test]
    fn minimal_override_falls_back_to_serde_defaults() {
        // The shape an operator might hand in via `pipeline_override`.
        let p: Pipeline = serde_json::from_value(json!({
            "variant": "four_step",
            "stages": [{"type": "llm_call", "prompt_template": "go"}]
        }))
        .expect("minimal descriptor deserialises");
        assert_eq!(p.variant, PipelineVariant::FourStep);
        assert!(p.system_prompt.is_empty());
        let [
            Stage::LlmCall {
                prompt_template,
                trust_inputs,
                output_schema,
                label,
            },
        ] = p.stages.as_slice()
        else {
            panic!("expected exactly one LLM stage");
        };
        assert_eq!(prompt_template.as_str(), "go");
        assert!(trust_inputs.is_empty());
        assert!(output_schema.is_null());
        assert!(label.is_empty());
    }

    #[test]
    fn variant_tag_matches_as_str() {
        assert_eq!(two_phase_default().variant_tag(), "two_phase");
        assert_eq!(four_step_default().variant_tag(), "four_step");
    }
}
385}