Skip to main content

mentedb_extraction/
pipeline.rs

1use mentedb_cognitive::write_inference::{InferredAction, WriteInferenceEngine};
2use mentedb_core::MemoryNode;
3use mentedb_core::types::{AgentId, MemoryId};
4use mentedb_embedding::provider::EmbeddingProvider;
5
6use crate::config::ExtractionConfig;
7use crate::error::ExtractionError;
8use crate::prompts::extraction_system_prompt;
9use crate::provider::ExtractionProvider;
10use crate::schema::{ExtractedMemory, ExtractionResult};
11
12/// Findings from cognitive checks (contradiction detection).
13#[derive(Debug, Clone)]
14pub struct CognitiveFinding {
15    /// What type of issue was found.
16    pub finding_type: CognitiveFindingType,
17    /// Human-readable description of the finding.
18    pub description: String,
19    /// ID of the existing memory involved, if any.
20    pub related_memory_id: Option<MemoryId>,
21}
22
/// Category of a [`CognitiveFinding`].
///
/// This is a plain fieldless tag, so it is `Copy` and `Hash` and can be passed
/// by value or used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CognitiveFindingType {
    /// The new memory conflicts with an existing memory.
    Contradiction,
    /// An existing memory may be superseded by the new one.
    Obsolescence,
    /// The new memory is related to an existing memory.
    Related,
    /// The confidence of an existing memory should be revised.
    ConfidenceUpdate,
}
31
/// Summary counters for one run of the extraction pipeline.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExtractionStats {
    /// Number of memories returned by the extraction step (after truncation to
    /// the configured per-conversation maximum).
    pub total_extracted: usize,
    /// Number of memories accepted for storage (`to_store`).
    pub accepted: usize,
    /// Number of memories rejected for a confidence below the quality threshold.
    pub rejected_quality: usize,
    /// Number of memories rejected as near-duplicates of existing memories.
    pub rejected_duplicate: usize,
    /// Number of memories with at least one contradiction finding. This counts
    /// memories, not individual findings.
    pub contradictions_found: usize,
}
41
42/// The complete result of running the extraction pipeline.
43#[derive(Debug)]
44pub struct ProcessedExtractionResult {
45    /// Memories that passed all checks and should be stored.
46    pub to_store: Vec<ExtractedMemory>,
47    /// Memories rejected for low confidence scores.
48    pub rejected_low_quality: Vec<ExtractedMemory>,
49    /// Memories rejected as duplicates of existing memories.
50    pub rejected_duplicate: Vec<ExtractedMemory>,
51    /// Memories that contradict existing ones (stored anyway, with findings).
52    pub contradictions: Vec<(ExtractedMemory, Vec<CognitiveFinding>)>,
53    /// Summary statistics.
54    pub stats: ExtractionStats,
55}
56
57/// The main extraction engine. Takes raw conversations, extracts structured
58/// memories via an LLM, then filters and validates them before storage.
59pub struct ExtractionPipeline<P: ExtractionProvider> {
60    provider: P,
61    config: ExtractionConfig,
62}
63
64impl<P: ExtractionProvider> ExtractionPipeline<P> {
65    pub fn new(provider: P, config: ExtractionConfig) -> Self {
66        Self { provider, config }
67    }
68
69    /// Call the LLM to extract memories from a conversation, then parse the
70    /// response and filter by quality threshold.
71    pub async fn extract_from_conversation(
72        &self,
73        conversation: &str,
74    ) -> Result<Vec<ExtractedMemory>, ExtractionError> {
75        let system_prompt = extraction_system_prompt();
76        let raw_response = self.provider.extract(conversation, system_prompt).await?;
77
78        let result = self.parse_extraction_response(&raw_response)?;
79
80        let mut memories = result.memories;
81        if memories.len() > self.config.max_extractions_per_conversation {
82            tracing::warn!(
83                extracted = memories.len(),
84                max = self.config.max_extractions_per_conversation,
85                "truncating extractions to configured maximum"
86            );
87            memories.truncate(self.config.max_extractions_per_conversation);
88        }
89
90        Ok(memories)
91    }
92
93    /// Parse the raw JSON response from the LLM into an ExtractionResult.
94    /// Handles edge cases like markdown fences around JSON.
95    fn parse_extraction_response(&self, raw: &str) -> Result<ExtractionResult, ExtractionError> {
96        let trimmed = raw.trim();
97
98        // Strip markdown code fences if present
99        let json_str = if trimmed.starts_with("```") {
100            let without_prefix = trimmed
101                .trim_start_matches("```json")
102                .trim_start_matches("```");
103            without_prefix.trim_end_matches("```").trim()
104        } else {
105            trimmed
106        };
107
108        serde_json::from_str::<ExtractionResult>(json_str).map_err(|e| {
109            tracing::error!(
110                error = %e,
111                response_preview = &json_str[..json_str.len().min(200)],
112                "failed to parse LLM extraction response"
113            );
114            ExtractionError::ParseError(format!("Failed to parse extraction JSON: {e}"))
115        })
116    }
117
118    /// Remove memories below the configured confidence threshold.
119    pub fn filter_quality(&self, memories: &[ExtractedMemory]) -> Vec<ExtractedMemory> {
120        memories
121            .iter()
122            .filter(|m| m.confidence >= self.config.quality_threshold)
123            .cloned()
124            .collect()
125    }
126
127    /// Check a new extracted memory against existing memories for contradictions
128    /// using the WriteInferenceEngine.
129    pub fn check_contradictions(
130        &self,
131        new_memory: &ExtractedMemory,
132        existing: &[MemoryNode],
133        embedding_provider: &dyn EmbeddingProvider,
134    ) -> Vec<CognitiveFinding> {
135        if !self.config.enable_contradiction_check || existing.is_empty() {
136            return Vec::new();
137        }
138
139        let embedding = match embedding_provider.embed(&new_memory.content) {
140            Ok(e) => e,
141            Err(err) => {
142                tracing::warn!(error = %err, "failed to embed memory for contradiction check");
143                return Vec::new();
144            }
145        };
146
147        let memory_type = map_extraction_type_to_memory_type(&new_memory.memory_type);
148        let temp_node = MemoryNode::new(
149            AgentId::nil(),
150            memory_type,
151            new_memory.content.clone(),
152            embedding,
153        );
154
155        let engine = WriteInferenceEngine::new();
156        let actions = engine.infer_on_write(&temp_node, existing, &[]);
157
158        let mut findings = Vec::new();
159        for action in actions {
160            match action {
161                InferredAction::FlagContradiction {
162                    existing: existing_id,
163                    reason,
164                    ..
165                } => {
166                    findings.push(CognitiveFinding {
167                        finding_type: CognitiveFindingType::Contradiction,
168                        description: reason,
169                        related_memory_id: Some(existing_id),
170                    });
171                }
172                InferredAction::MarkObsolete {
173                    memory,
174                    superseded_by: _,
175                } => {
176                    findings.push(CognitiveFinding {
177                        finding_type: CognitiveFindingType::Obsolescence,
178                        description: format!("Memory {memory} may be obsolete"),
179                        related_memory_id: Some(memory),
180                    });
181                }
182                InferredAction::UpdateConfidence {
183                    memory,
184                    new_confidence,
185                } => {
186                    findings.push(CognitiveFinding {
187                        finding_type: CognitiveFindingType::ConfidenceUpdate,
188                        description: format!(
189                            "Confidence for {memory} should be updated to {new_confidence:.2}"
190                        ),
191                        related_memory_id: Some(memory),
192                    });
193                }
194                InferredAction::CreateEdge { target, .. } => {
195                    findings.push(CognitiveFinding {
196                        finding_type: CognitiveFindingType::Related,
197                        description: format!("Related to existing memory {target}"),
198                        related_memory_id: Some(target),
199                    });
200                }
201                _ => {}
202            }
203        }
204
205        findings
206    }
207
208    /// Check if a new memory is too similar to any existing memory
209    /// (above deduplication_threshold).
210    pub fn check_duplicates(
211        &self,
212        new_memory: &ExtractedMemory,
213        existing: &[MemoryNode],
214        embedding_provider: &dyn EmbeddingProvider,
215    ) -> bool {
216        if !self.config.enable_deduplication || existing.is_empty() {
217            return false;
218        }
219
220        let new_embedding = match embedding_provider.embed(&new_memory.content) {
221            Ok(e) => e,
222            Err(err) => {
223                tracing::warn!(error = %err, "failed to embed memory for dedup check");
224                return false;
225            }
226        };
227
228        for mem in existing {
229            let sim = cosine_similarity(&new_embedding, &mem.embedding);
230            if sim >= self.config.deduplication_threshold {
231                tracing::debug!(
232                    similarity = sim,
233                    threshold = self.config.deduplication_threshold,
234                    existing_id = %mem.id,
235                    "duplicate detected"
236                );
237                return true;
238            }
239        }
240
241        false
242    }
243
244    /// Run the full extraction pipeline: extract -> filter quality ->
245    /// check duplicates -> check contradictions.
246    pub async fn process(
247        &self,
248        conversation: &str,
249        existing_memories: &[MemoryNode],
250        embedding_provider: &dyn EmbeddingProvider,
251    ) -> Result<ProcessedExtractionResult, ExtractionError> {
252        let all_memories = self.extract_from_conversation(conversation).await?;
253        let total_extracted = all_memories.len();
254
255        let quality_passed = self.filter_quality(&all_memories);
256        let rejected_low_quality: Vec<ExtractedMemory> = all_memories
257            .iter()
258            .filter(|m| m.confidence < self.config.quality_threshold)
259            .cloned()
260            .collect();
261
262        let mut to_store = Vec::new();
263        let mut rejected_duplicate = Vec::new();
264        let mut contradictions = Vec::new();
265
266        for memory in quality_passed {
267            if self.check_duplicates(&memory, existing_memories, embedding_provider) {
268                rejected_duplicate.push(memory);
269                continue;
270            }
271
272            let findings =
273                self.check_contradictions(&memory, existing_memories, embedding_provider);
274            let has_contradiction = findings
275                .iter()
276                .any(|f| f.finding_type == CognitiveFindingType::Contradiction);
277
278            if has_contradiction {
279                contradictions.push((memory, findings));
280            } else {
281                to_store.push(memory);
282            }
283        }
284
285        let stats = ExtractionStats {
286            total_extracted,
287            accepted: to_store.len(),
288            rejected_quality: rejected_low_quality.len(),
289            rejected_duplicate: rejected_duplicate.len(),
290            contradictions_found: contradictions.len(),
291        };
292
293        tracing::info!(
294            total = stats.total_extracted,
295            accepted = stats.accepted,
296            rejected_quality = stats.rejected_quality,
297            rejected_duplicate = stats.rejected_duplicate,
298            contradictions = stats.contradictions_found,
299            "extraction pipeline complete"
300        );
301
302        Ok(ProcessedExtractionResult {
303            to_store,
304            rejected_low_quality,
305            rejected_duplicate,
306            contradictions,
307            stats,
308        })
309    }
310}
311
312/// Map extraction type strings to MemoryType enum variants.
313pub fn map_extraction_type_to_memory_type(
314    extraction_type: &str,
315) -> mentedb_core::memory::MemoryType {
316    use mentedb_core::memory::MemoryType;
317    match extraction_type.to_lowercase().as_str() {
318        "decision" | "preference" | "fact" | "entity" => MemoryType::Semantic,
319        "correction" => MemoryType::Correction,
320        "anti_pattern" => MemoryType::AntiPattern,
321        _ => MemoryType::Episodic,
322    }
323}
324
/// Cosine similarity between two embedding vectors.
///
/// Returns `0.0` when the slices differ in length, are empty, or either one
/// has zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating (dot product, |a|^2, |b|^2), summed left to right.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let magnitude = sq_a.sqrt() * sq_b.sqrt();
    if magnitude == 0.0 {
        0.0
    } else {
        dot / magnitude
    }
}