1use crate::models::field_names;
11use serde::{Deserialize, Serialize};
12use serde_json::{Value, json};
13
14use super::helpers::{HelperKind, HelperParams};
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
21#[serde(rename_all = "snake_case")]
22pub enum PipelineVariant {
23 TwoPhase,
25 FourStep,
27}
28
29impl PipelineVariant {
30 #[must_use]
33 pub const fn as_str(self) -> &'static str {
34 match self {
35 Self::TwoPhase => "two_phase",
36 Self::FourStep => "four_step",
37 }
38 }
39
40 #[must_use]
43 pub fn from_str(s: &str) -> Option<Self> {
44 match s {
45 "two_phase" => Some(Self::TwoPhase),
46 "four_step" => Some(Self::FourStep),
47 _ => None,
48 }
49 }
50}
51
52#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
57pub struct HelperOutputRef {
58 pub stage_index: usize,
60 pub label: String,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
69#[serde(tag = "type", rename_all = "snake_case")]
70pub enum Stage {
71 Helper {
74 kind: HelperKind,
76 #[serde(default)]
79 params: HelperParams,
80 },
81 LlmCall {
85 prompt_template: String,
88 #[serde(default)]
91 trust_inputs: Vec<HelperOutputRef>,
92 #[serde(default)]
95 output_schema: Value,
96 #[serde(default)]
98 label: String,
99 },
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct Pipeline {
107 pub variant: PipelineVariant,
109 pub stages: Vec<Stage>,
111 #[serde(default)]
115 pub system_prompt: String,
116}
117
118impl Pipeline {
119 #[must_use]
122 pub fn variant_tag(&self) -> &'static str {
123 self.variant.as_str()
124 }
125}
126
127#[must_use]
136pub fn two_phase_default() -> Pipeline {
137 Pipeline {
138 variant: PipelineVariant::TwoPhase,
139 system_prompt: "Synthesise the incoming content into a structured \
140 memory envelope with title, summary, tags, and atoms. \
141 Cite the helper output verbatim when it carries \
142 candidate overlaps or classifications."
143 .to_string(),
144 stages: vec![
145 Stage::Helper {
146 kind: HelperKind::FtsClassifier,
147 params: HelperParams::default(),
148 },
149 Stage::Helper {
150 kind: HelperKind::JaccardOverlap,
151 params: HelperParams::default(),
152 },
153 Stage::LlmCall {
154 label: "synthesise".to_string(),
155 prompt_template: "Produce a JSON object {title, summary, \
156 tags[], atoms[]} where each atom is a \
157 standalone fact distilled from the content. \
158 The trust slots below carry the \
159 pre-computed classifier label and the top \
160 candidate overlaps."
161 .to_string(),
162 trust_inputs: vec![
163 HelperOutputRef {
164 stage_index: 0,
165 label: "classification".to_string(),
166 },
167 HelperOutputRef {
168 stage_index: 1,
169 label: "overlap".to_string(),
170 },
171 ],
172 output_schema: json!({
173 "type": "object",
174 "required": ["title", "summary", "tags", "atoms"],
175 (field_names::PROPERTIES): {
176 "title": {"type": "string"},
177 "summary": {"type": "string"},
178 "tags": {"type": "array", "items": {"type": "string"}},
179 "atoms": {"type": "array", "items": {"type": "string"}}
180 }
181 }),
182 },
183 ],
184 }
185}
186
187#[must_use]
204pub fn four_step_default() -> Pipeline {
205 Pipeline {
206 variant: PipelineVariant::FourStep,
207 system_prompt: "Run the OpenKB four-step ingest pipeline. Each \
208 stage produces a JSON object that feeds the next \
209 stage. Trust the helper output verbatim — do not \
210 re-derive classifications or overlap scores."
211 .to_string(),
212 stages: vec![
213 Stage::Helper {
214 kind: HelperKind::FtsClassifier,
215 params: HelperParams::default(),
216 },
217 Stage::Helper {
218 kind: HelperKind::JaccardOverlap,
219 params: HelperParams::default(),
220 },
221 Stage::LlmCall {
222 label: "classify".to_string(),
223 prompt_template: "Classify this content. Return JSON \
224 {fact_kind, confidence}."
225 .to_string(),
226 trust_inputs: vec![HelperOutputRef {
227 stage_index: 0,
228 label: HelperKind::FtsClassifier.as_str().to_string(),
229 }],
230 output_schema: json!({
231 "type": "object",
232 "required": ["fact_kind", field_names::CONFIDENCE],
233 (field_names::PROPERTIES): {
234 "fact_kind": {
235 "type": "string",
236 "enum": ["procedural", "declarative", "episodic"]
237 },
238 (field_names::CONFIDENCE): {
239 "type": "number",
240 "minimum": 0.0,
241 "maximum": 1.0
242 }
243 }
244 }),
245 },
246 Stage::LlmCall {
247 label: "enrich".to_string(),
248 prompt_template: "Extract entities, claims, and relations \
249 from the content. Return JSON {entities[], \
250 claims[], relations[]}."
251 .to_string(),
252 trust_inputs: vec![HelperOutputRef {
253 stage_index: 1,
254 label: "overlap".to_string(),
255 }],
256 output_schema: json!({
257 "type": "object",
258 "required": ["entities", "claims", "relations"],
259 (field_names::PROPERTIES): {
260 "entities": {"type": "array", "items": {"type": "string"}},
261 "claims": {"type": "array", "items": {"type": "string"}},
262 "relations": {"type": "array", "items": {"type": "object"}}
263 }
264 }),
265 },
266 Stage::LlmCall {
267 label: "emit".to_string(),
268 prompt_template: "Emit the final memory envelope. Return \
269 JSON {title, summary, tags[], \
270 proposed_links[]}."
271 .to_string(),
272 trust_inputs: vec![
273 HelperOutputRef {
274 stage_index: 0,
275 label: HelperKind::FtsClassifier.as_str().to_string(),
276 },
277 HelperOutputRef {
278 stage_index: 1,
279 label: "overlap".to_string(),
280 },
281 ],
282 output_schema: json!({
283 "type": "object",
284 "required": ["title", "summary", "tags", "proposed_links"],
285 (field_names::PROPERTIES): {
286 "title": {"type": "string"},
287 "summary": {"type": "string"},
288 "tags": {"type": "array", "items": {"type": "string"}},
289 "proposed_links": {"type": "array", "items": {"type": "object"}}
290 }
291 }),
292 },
293 ],
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, used to drive the exhaustive conversion checks below.
    const ALL_VARIANTS: [PipelineVariant; 2] =
        [PipelineVariant::TwoPhase, PipelineVariant::FourStep];

    /// True for `Stage::Helper`.
    fn is_helper(stage: &Stage) -> bool {
        matches!(stage, Stage::Helper { .. })
    }

    /// True for `Stage::LlmCall`.
    fn is_llm_call(stage: &Stage) -> bool {
        matches!(stage, Stage::LlmCall { .. })
    }

    /// Labels of the LLM stages of `p`, in stage order.
    fn llm_labels(p: &Pipeline) -> Vec<&str> {
        p.stages
            .iter()
            .filter_map(|stage| match stage {
                Stage::LlmCall { label, .. } => Some(label.as_str()),
                Stage::Helper { .. } => None,
            })
            .collect()
    }

    #[test]
    fn pipeline_variant_round_trip_via_str() {
        assert_eq!(
            PipelineVariant::from_str("two_phase"),
            Some(PipelineVariant::TwoPhase)
        );
        assert_eq!(
            PipelineVariant::from_str("four_step"),
            Some(PipelineVariant::FourStep)
        );
        assert_eq!(PipelineVariant::from_str("nonsense"), None);
        // Parsing is exact: no trimming, no case folding.
        assert_eq!(PipelineVariant::from_str(""), None);
        assert_eq!(PipelineVariant::from_str("Two_Phase"), None);
        assert_eq!(PipelineVariant::from_str(" four_step"), None);
        for variant in ALL_VARIANTS {
            assert_eq!(PipelineVariant::from_str(variant.as_str()), Some(variant));
        }
    }

    #[test]
    fn pipeline_variant_serde_tag_matches_as_str() {
        for variant in ALL_VARIANTS {
            let encoded = serde_json::to_value(variant).expect("serialises");
            assert_eq!(
                encoded,
                serde_json::Value::String(variant.as_str().to_string())
            );
            let decoded: PipelineVariant =
                serde_json::from_value(encoded).expect("deserialises");
            assert_eq!(decoded, variant);
        }
    }

    #[test]
    fn two_phase_default_has_two_phases() {
        let p = two_phase_default();
        assert_eq!(p.variant, PipelineVariant::TwoPhase);
        // Phase one: the two helpers. Phase two: the single synthesis call.
        assert_eq!(p.stages.iter().filter(|s| is_helper(s)).count(), 2);
        assert_eq!(p.stages.iter().filter(|s| is_llm_call(s)).count(), 1);
        assert_eq!(llm_labels(&p), ["synthesise"]);
    }

    #[test]
    fn four_step_default_has_four_logical_stages() {
        let p = four_step_default();
        assert_eq!(p.variant, PipelineVariant::FourStep);
        // Helpers first, then the classify / enrich / emit LLM calls.
        assert_eq!(p.stages.len(), 5);
        assert!(p.stages[..2].iter().all(is_helper));
        assert!(p.stages[2..].iter().all(is_llm_call));
        assert_eq!(llm_labels(&p), ["classify", "enrich", "emit"]);
    }

    #[test]
    fn two_phase_llm_stage_references_both_helpers() {
        let p = two_phase_default();
        let Stage::LlmCall { trust_inputs, .. } = p.stages.last().unwrap() else {
            panic!("last stage should be LLM call");
        };
        assert_eq!(trust_inputs.len(), 2);
        assert_eq!(trust_inputs[0].stage_index, 0);
        assert_eq!(trust_inputs[1].stage_index, 1);
    }

    #[test]
    fn four_step_llm_stages_each_have_trust_inputs() {
        let p = four_step_default();
        for stage in &p.stages {
            if let Stage::LlmCall { trust_inputs, .. } = stage {
                assert!(
                    !trust_inputs.is_empty(),
                    "every LLM stage must have at least one trust input"
                );
            }
        }
    }

    #[test]
    fn trust_inputs_reference_earlier_helper_stages() {
        for p in [two_phase_default(), four_step_default()] {
            for (position, stage) in p.stages.iter().enumerate() {
                let Stage::LlmCall {
                    trust_inputs, label, ..
                } = stage
                else {
                    continue;
                };
                for input in trust_inputs {
                    let target = p.stages.get(input.stage_index);
                    assert!(
                        input.stage_index < position
                            && matches!(target, Some(Stage::Helper { .. })),
                        "stage {label:?} trusts stage {} which is not an earlier helper",
                        input.stage_index
                    );
                }
            }
        }
    }

    #[test]
    fn pipeline_descriptor_round_trips_through_serde() {
        for p in [two_phase_default(), four_step_default()] {
            let s = serde_json::to_string(&p).expect("serialises");
            let back: Pipeline = serde_json::from_str(&s).expect("deserialises");
            assert_eq!(back.variant, p.variant);
            assert_eq!(back.stages.len(), p.stages.len());
            // Full structural comparison: catches dropped or renamed fields
            // (prompts, trust inputs, schemas) that a length check misses.
            assert_eq!(
                serde_json::to_value(&back).expect("re-serialises"),
                serde_json::to_value(&p).expect("serialises"),
            );
        }
    }

    #[test]
    fn llm_call_optional_fields_default_when_absent() {
        let raw = r#"{
            "variant": "two_phase",
            "stages": [{"type": "llm_call", "prompt_template": "p"}]
        }"#;
        let p: Pipeline = serde_json::from_str(raw).expect("deserialises");
        assert_eq!(p.variant, PipelineVariant::TwoPhase);
        assert!(p.system_prompt.is_empty());
        let [
            Stage::LlmCall {
                prompt_template,
                trust_inputs,
                output_schema,
                label,
            },
        ] = p.stages.as_slice()
        else {
            panic!("expected exactly one LLM stage");
        };
        assert_eq!(prompt_template, "p");
        assert!(trust_inputs.is_empty());
        assert!(output_schema.is_null());
        assert!(label.is_empty());
    }

    #[test]
    fn variant_tag_matches_as_str() {
        assert_eq!(two_phase_default().variant_tag(), "two_phase");
        assert_eq!(four_step_default().variant_tag(), "four_step");
    }
}