use crate::models::field_names;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use super::helpers::{HelperKind, HelperParams};
/// Which built-in stage layout a [`Pipeline`] follows.
///
/// Serialized as the snake_case wire names given on each variant;
/// [`PipelineVariant::as_str`] returns the same strings.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum PipelineVariant {
    /// The layout built by [`two_phase_default`] (`"two_phase"`).
    #[serde(rename = "two_phase")]
    TwoPhase,
    /// The layout built by [`four_step_default`] (`"four_step"`).
    #[serde(rename = "four_step")]
    FourStep,
}
impl PipelineVariant {
#[must_use]
pub const fn as_str(self) -> &'static str {
match self {
Self::TwoPhase => "two_phase",
Self::FourStep => "four_step",
}
}
#[must_use]
pub fn from_str(s: &str) -> Option<Self> {
match s {
"two_phase" => Some(Self::TwoPhase),
"four_step" => Some(Self::FourStep),
_ => None,
}
}
}
/// Points an LLM stage at the output of another stage in the same pipeline.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct HelperOutputRef {
    /// Index of the referenced stage. In the built-in pipelines this is the
    /// position of a helper stage within [`Pipeline::stages`].
    pub stage_index: usize,
    /// Label attached to the referenced output (e.g. `"overlap"`).
    pub label: String,
}
/// One step of a [`Pipeline`].
///
/// Serialized as an internally tagged object: the `"type"` field is
/// `"helper"` or `"llm_call"`, and the variant's fields sit beside it.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Stage {
    /// A helper step identified by its [`HelperKind`].
    #[serde(rename = "helper")]
    Helper {
        /// Which helper this step runs.
        kind: HelperKind,
        /// Helper parameters; `HelperParams::default()` when omitted.
        #[serde(default)]
        params: HelperParams,
    },
    /// A language-model call.
    #[serde(rename = "llm_call")]
    LlmCall {
        /// Instruction text for this call.
        prompt_template: String,
        /// Stage outputs handed to this call; empty when omitted.
        #[serde(default)]
        trust_inputs: Vec<HelperOutputRef>,
        /// JSON-Schema-style description of the expected reply;
        /// `Value::Null` when omitted.
        #[serde(default)]
        output_schema: Value,
        /// Name of this step (e.g. `"classify"`); empty when omitted.
        #[serde(default)]
        label: String,
    },
}
/// A complete pipeline descriptor: its variant, its stages, and a system prompt.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Pipeline {
    /// Which built-in layout this descriptor follows.
    pub variant: PipelineVariant,
    /// Stages in declaration order. [`HelperOutputRef::stage_index`] is
    /// presumably an index into this list.
    pub stages: Vec<Stage>,
    /// System prompt text; empty when omitted from serialized input.
    #[serde(default)]
    pub system_prompt: String,
}
impl Pipeline {
#[must_use]
pub fn variant_tag(&self) -> &'static str {
self.variant.as_str()
}
}
#[must_use]
pub fn two_phase_default() -> Pipeline {
Pipeline {
variant: PipelineVariant::TwoPhase,
system_prompt: "Synthesise the incoming content into a structured \
memory envelope with title, summary, tags, and atoms. \
Cite the helper output verbatim when it carries \
candidate overlaps or classifications."
.to_string(),
stages: vec![
Stage::Helper {
kind: HelperKind::FtsClassifier,
params: HelperParams::default(),
},
Stage::Helper {
kind: HelperKind::JaccardOverlap,
params: HelperParams::default(),
},
Stage::LlmCall {
label: "synthesise".to_string(),
prompt_template: "Produce a JSON object {title, summary, \
tags[], atoms[]} where each atom is a \
standalone fact distilled from the content. \
The trust slots below carry the \
pre-computed classifier label and the top \
candidate overlaps."
.to_string(),
trust_inputs: vec![
HelperOutputRef {
stage_index: 0,
label: "classification".to_string(),
},
HelperOutputRef {
stage_index: 1,
label: "overlap".to_string(),
},
],
output_schema: json!({
"type": "object",
"required": ["title", "summary", "tags", "atoms"],
(field_names::PROPERTIES): {
"title": {"type": "string"},
"summary": {"type": "string"},
"tags": {"type": "array", "items": {"type": "string"}},
"atoms": {"type": "array", "items": {"type": "string"}}
}
}),
},
],
}
}
#[must_use]
pub fn four_step_default() -> Pipeline {
Pipeline {
variant: PipelineVariant::FourStep,
system_prompt: "Run the OpenKB four-step ingest pipeline. Each \
stage produces a JSON object that feeds the next \
stage. Trust the helper output verbatim — do not \
re-derive classifications or overlap scores."
.to_string(),
stages: vec![
Stage::Helper {
kind: HelperKind::FtsClassifier,
params: HelperParams::default(),
},
Stage::Helper {
kind: HelperKind::JaccardOverlap,
params: HelperParams::default(),
},
Stage::LlmCall {
label: "classify".to_string(),
prompt_template: "Classify this content. Return JSON \
{fact_kind, confidence}."
.to_string(),
trust_inputs: vec![HelperOutputRef {
stage_index: 0,
label: HelperKind::FtsClassifier.as_str().to_string(),
}],
output_schema: json!({
"type": "object",
"required": ["fact_kind", field_names::CONFIDENCE],
(field_names::PROPERTIES): {
"fact_kind": {
"type": "string",
"enum": ["procedural", "declarative", "episodic"]
},
(field_names::CONFIDENCE): {
"type": "number",
"minimum": 0.0,
"maximum": 1.0
}
}
}),
},
Stage::LlmCall {
label: "enrich".to_string(),
prompt_template: "Extract entities, claims, and relations \
from the content. Return JSON {entities[], \
claims[], relations[]}."
.to_string(),
trust_inputs: vec![HelperOutputRef {
stage_index: 1,
label: "overlap".to_string(),
}],
output_schema: json!({
"type": "object",
"required": ["entities", "claims", "relations"],
(field_names::PROPERTIES): {
"entities": {"type": "array", "items": {"type": "string"}},
"claims": {"type": "array", "items": {"type": "string"}},
"relations": {"type": "array", "items": {"type": "object"}}
}
}),
},
Stage::LlmCall {
label: "emit".to_string(),
prompt_template: "Emit the final memory envelope. Return \
JSON {title, summary, tags[], \
proposed_links[]}."
.to_string(),
trust_inputs: vec![
HelperOutputRef {
stage_index: 0,
label: HelperKind::FtsClassifier.as_str().to_string(),
},
HelperOutputRef {
stage_index: 1,
label: "overlap".to_string(),
},
],
output_schema: json!({
"type": "object",
"required": ["title", "summary", "tags", "proposed_links"],
(field_names::PROPERTIES): {
"title": {"type": "string"},
"summary": {"type": "string"},
"tags": {"type": "array", "items": {"type": "string"}},
"proposed_links": {"type": "array", "items": {"type": "object"}}
}
}),
},
],
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every built-in pipeline, paired with the variant it is expected to report.
    fn all_defaults() -> [(PipelineVariant, Pipeline); 2] {
        [
            (PipelineVariant::TwoPhase, two_phase_default()),
            (PipelineVariant::FourStep, four_step_default()),
        ]
    }

    #[test]
    fn pipeline_variant_round_trip_via_str() {
        assert_eq!(
            PipelineVariant::from_str("two_phase"),
            Some(PipelineVariant::TwoPhase)
        );
        assert_eq!(
            PipelineVariant::from_str("four_step"),
            Some(PipelineVariant::FourStep)
        );
        assert_eq!(PipelineVariant::from_str("nonsense"), None);
        // Parsing is exact: no case folding and no whitespace trimming.
        assert_eq!(PipelineVariant::from_str("Two_Phase"), None);
        assert_eq!(PipelineVariant::from_str(" two_phase"), None);
        for variant in [PipelineVariant::TwoPhase, PipelineVariant::FourStep] {
            assert_eq!(PipelineVariant::from_str(variant.as_str()), Some(variant));
        }
    }

    #[test]
    fn two_phase_default_has_two_phases() {
        let p = two_phase_default();
        assert_eq!(p.variant, PipelineVariant::TwoPhase);
        let helpers = p
            .stages
            .iter()
            .filter(|s| matches!(s, Stage::Helper { .. }))
            .count();
        let llms = p
            .stages
            .iter()
            .filter(|s| matches!(s, Stage::LlmCall { .. }))
            .count();
        assert_eq!(helpers, 2);
        assert_eq!(llms, 1);
    }

    #[test]
    fn four_step_default_has_four_logical_stages() {
        let p = four_step_default();
        assert_eq!(p.variant, PipelineVariant::FourStep);
        // The helpers together form the first logical step; the three LLM
        // calls are the remaining three.
        let helpers = p
            .stages
            .iter()
            .filter(|s| matches!(s, Stage::Helper { .. }))
            .count();
        let llms = p
            .stages
            .iter()
            .filter(|s| matches!(s, Stage::LlmCall { .. }))
            .count();
        assert_eq!(helpers, 2);
        assert_eq!(llms, 3);
    }

    #[test]
    fn two_phase_llm_stage_references_both_helpers() {
        let p = two_phase_default();
        let Stage::LlmCall { trust_inputs, .. } = p.stages.last().unwrap() else {
            panic!("last stage should be LLM call");
        };
        assert_eq!(trust_inputs.len(), 2);
        assert_eq!(trust_inputs[0].stage_index, 0);
        assert_eq!(trust_inputs[1].stage_index, 1);
    }

    #[test]
    fn four_step_llm_stages_each_have_trust_inputs() {
        let p = four_step_default();
        for stage in &p.stages {
            if let Stage::LlmCall { trust_inputs, .. } = stage {
                assert!(
                    !trust_inputs.is_empty(),
                    "every LLM stage must have at least one trust input"
                );
            }
        }
    }

    #[test]
    fn trust_inputs_reference_earlier_helper_stages() {
        for (_, p) in all_defaults() {
            for (position, stage) in p.stages.iter().enumerate() {
                let Stage::LlmCall {
                    label,
                    trust_inputs,
                    ..
                } = stage
                else {
                    continue;
                };
                for input in trust_inputs {
                    assert!(
                        input.stage_index < position,
                        "stage {label:?} at {position} references stage {}, which is not earlier",
                        input.stage_index
                    );
                    assert!(
                        matches!(p.stages[input.stage_index], Stage::Helper { .. }),
                        "stage {label:?} references stage {}, which is not a helper",
                        input.stage_index
                    );
                }
            }
        }
    }

    #[test]
    fn llm_output_schemas_declare_every_required_field() {
        for (_, p) in all_defaults() {
            for stage in &p.stages {
                let Stage::LlmCall {
                    label,
                    output_schema,
                    ..
                } = stage
                else {
                    continue;
                };
                let properties = output_schema
                    .get(field_names::PROPERTIES)
                    .and_then(Value::as_object)
                    .unwrap_or_else(|| panic!("stage {label:?} schema has no properties object"));
                let required = output_schema["required"]
                    .as_array()
                    .unwrap_or_else(|| panic!("stage {label:?} schema has no required array"));
                for name in required {
                    let name = name.as_str().expect("required entries are strings");
                    assert!(
                        properties.contains_key(name),
                        "stage {label:?} requires {name:?} but does not declare it"
                    );
                }
            }
        }
    }

    #[test]
    fn pipeline_descriptor_round_trips_through_serde() {
        for (_, p) in all_defaults() {
            let s = serde_json::to_string(&p).expect("serialises");
            let back: Pipeline = serde_json::from_str(&s).expect("deserialises");
            assert_eq!(back.variant, p.variant);
            assert_eq!(back.stages.len(), p.stages.len());
            assert_eq!(back.system_prompt, p.system_prompt);
            // Re-serialising the decoded descriptor must reproduce the same document.
            assert_eq!(
                serde_json::to_value(&back).expect("re-serialises"),
                serde_json::to_value(&p).expect("serialises"),
            );
        }
    }

    #[test]
    fn serialised_form_uses_snake_case_tags() {
        let value = serde_json::to_value(two_phase_default()).expect("serialises");
        assert_eq!(value["variant"], "two_phase");
        assert_eq!(value["stages"][0]["type"], "helper");
        assert_eq!(value["stages"][1]["type"], "helper");
        assert_eq!(value["stages"][2]["type"], "llm_call");
        assert_eq!(value["stages"][2]["label"], "synthesise");

        let value = serde_json::to_value(four_step_default()).expect("serialises");
        assert_eq!(value["variant"], "four_step");
        assert_eq!(value["stages"][4]["type"], "llm_call");
        assert_eq!(value["stages"][4]["label"], "emit");
    }

    #[test]
    fn variant_tag_matches_as_str() {
        assert_eq!(two_phase_default().variant_tag(), "two_phase");
        assert_eq!(four_step_default().variant_tag(), "four_step");
        for (variant, p) in all_defaults() {
            assert_eq!(p.variant, variant);
            assert_eq!(p.variant_tag(), variant.as_str());
        }
    }
}