// mentedb_extraction/pipeline.rs

1use mentedb_cognitive::write_inference::{InferredAction, WriteInferenceEngine};
2use mentedb_core::MemoryNode;
3use mentedb_core::types::{AgentId, MemoryId};
4use mentedb_embedding::provider::EmbeddingProvider;
5
6use crate::config::ExtractionConfig;
7use crate::error::ExtractionError;
8use crate::provider::ExtractionProvider;
9use crate::schema::{ExtractedMemory, ExtractionResult};
10
11/// Findings from cognitive checks (contradiction detection).
12#[derive(Debug, Clone)]
13pub struct CognitiveFinding {
14    /// What type of issue was found.
15    pub finding_type: CognitiveFindingType,
16    /// Human-readable description of the finding.
17    pub description: String,
18    /// ID of the existing memory involved, if any.
19    pub related_memory_id: Option<MemoryId>,
20}
21
/// Categories a [`CognitiveFinding`] can fall into.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CognitiveFindingType {
    /// The new memory was flagged as contradicting an existing memory.
    Contradiction,
    /// An existing memory may have been made obsolete.
    Obsolescence,
    /// The new memory is related to an existing memory (an edge was proposed).
    Related,
    /// The confidence of an existing memory should be adjusted.
    ConfidenceUpdate,
}
30
/// Counters summarising one run of `ExtractionPipeline::process`.
#[derive(Clone, Debug, Default)]
pub struct ExtractionStats {
    /// Number of memories returned by the extraction step, before filtering.
    pub total_extracted: usize,
    /// Number of memories that ended up in the "to store" set.
    pub accepted: usize,
    /// Number of memories rejected by the confidence threshold.
    pub rejected_quality: usize,
    /// Number of memories rejected as duplicates of existing memories.
    pub rejected_duplicate: usize,
    /// Number of memories for which a contradiction was reported.
    pub contradictions_found: usize,
}
40
41/// The complete result of running the extraction pipeline.
42#[derive(Debug)]
43pub struct ProcessedExtractionResult {
44    /// Memories that passed all checks and should be stored.
45    pub to_store: Vec<ExtractedMemory>,
46    /// Memories rejected for low confidence scores.
47    pub rejected_low_quality: Vec<ExtractedMemory>,
48    /// Memories rejected as duplicates of existing memories.
49    pub rejected_duplicate: Vec<ExtractedMemory>,
50    /// Memories that contradict existing ones (stored anyway, with findings).
51    pub contradictions: Vec<(ExtractedMemory, Vec<CognitiveFinding>)>,
52    /// Summary statistics.
53    pub stats: ExtractionStats,
54}
55
56/// The main extraction engine. Takes raw conversations, extracts structured
57/// memories via an LLM, then filters and validates them before storage.
58pub struct ExtractionPipeline<P: ExtractionProvider> {
59    provider: P,
60    config: ExtractionConfig,
61}
62
63impl<P: ExtractionProvider> ExtractionPipeline<P> {
64    pub fn new(provider: P, config: ExtractionConfig) -> Self {
65        Self { provider, config }
66    }
67
68    /// Call the LLM to extract memories from a conversation, then parse the
69    /// response and filter by quality threshold.
70    pub async fn extract_from_conversation(
71        &self,
72        conversation: &str,
73    ) -> Result<Vec<ExtractedMemory>, ExtractionError> {
74        let result = self.extract_full(conversation).await?;
75        Ok(result.memories)
76    }
77
78    /// Extract memories AND entities from a conversation.
79    /// Returns the full ExtractionResult including structured entities.
80    pub async fn extract_full(
81        &self,
82        conversation: &str,
83    ) -> Result<ExtractionResult, ExtractionError> {
84        use crate::prompts::{extraction_system_prompt, extraction_verification_prompt};
85
86        let system_prompt = extraction_system_prompt();
87        let raw_response = self.provider.extract(conversation, system_prompt).await?;
88
89        let mut result = self.parse_extraction_response(&raw_response)?;
90
91        // Verification pass: re-read conversation to find what the first pass missed
92        if self.config.extraction_passes >= 2 && !result.memories.is_empty() {
93            let first_pass_facts: String = result
94                .memories
95                .iter()
96                .map(|m| format!("- {}", m.content))
97                .collect::<Vec<_>>()
98                .join("\n");
99            let verify_prompt = extraction_verification_prompt(&first_pass_facts);
100            match self.provider.extract(conversation, &verify_prompt).await {
101                Ok(verify_response) => {
102                    if let Ok(verify_result) = self.parse_extraction_response(&verify_response) {
103                        let new_memories = verify_result.memories.len();
104                        let new_entities = verify_result.entities.len();
105                        result.memories.extend(verify_result.memories);
106                        result.entities.extend(verify_result.entities);
107                        if new_memories > 0 || new_entities > 0 {
108                            tracing::info!(
109                                new_memories,
110                                new_entities,
111                                "verification pass found additional extractions"
112                            );
113                        }
114                    }
115                }
116                Err(e) => {
117                    tracing::warn!("verification pass failed, using first pass only: {}", e);
118                }
119            }
120        }
121
122        if result.memories.len() > self.config.max_extractions_per_conversation {
123            tracing::warn!(
124                extracted = result.memories.len(),
125                max = self.config.max_extractions_per_conversation,
126                "truncating extractions to configured maximum"
127            );
128            result
129                .memories
130                .truncate(self.config.max_extractions_per_conversation);
131        }
132
133        Ok(result)
134    }
135
136    /// Parse the raw JSON response from the LLM into an ExtractionResult.
137    /// Handles edge cases like markdown fences around JSON and preamble text.
138    fn parse_extraction_response(&self, raw: &str) -> Result<ExtractionResult, ExtractionError> {
139        let trimmed = raw.trim();
140
141        // Empty response = no memories to extract
142        if trimmed.is_empty() {
143            return Ok(ExtractionResult {
144                memories: vec![],
145                entities: vec![],
146            });
147        }
148
149        // Strip markdown code fences if present
150        let stripped = if trimmed.starts_with("```") {
151            let without_prefix = trimmed
152                .trim_start_matches("```json")
153                .trim_start_matches("```");
154            without_prefix.trim_end_matches("```").trim()
155        } else {
156            trimmed
157        };
158
159        // Find the outermost JSON object using brace-depth matching
160        // that respects quoted strings (handles braces inside string values)
161        let json_str = if let Some(start) = stripped.find('{') {
162            let candidate = &stripped[start..];
163            let mut depth = 0i32;
164            let mut in_string = false;
165            let mut escape_next = false;
166            let mut end = candidate.len();
167            for (i, ch) in candidate.char_indices() {
168                if escape_next {
169                    escape_next = false;
170                    continue;
171                }
172                if in_string {
173                    match ch {
174                        '\\' => escape_next = true,
175                        '"' => in_string = false,
176                        _ => {}
177                    }
178                    continue;
179                }
180                match ch {
181                    '"' => in_string = true,
182                    '{' => depth += 1,
183                    '}' => {
184                        depth -= 1;
185                        if depth == 0 {
186                            end = i + 1;
187                            break;
188                        }
189                    }
190                    _ => {}
191                }
192            }
193            &candidate[..end]
194        } else {
195            // No JSON object found — LLM returned plain text (e.g. "No memories to extract")
196            return Ok(ExtractionResult {
197                memories: vec![],
198                entities: vec![],
199            });
200        };
201
202        // Parse with serde_json::Value first (tolerates duplicate keys — last one wins)
203        // then convert to ExtractionResult. LLMs sometimes emit duplicate fields.
204        let value: serde_json::Value = serde_json::from_str(json_str).map_err(|e| {
205            tracing::error!(
206                error = %e,
207                response_preview = &json_str[..json_str.len().min(200)],
208                "failed to parse LLM extraction response as JSON"
209            );
210            ExtractionError::ParseError(format!("Failed to parse extraction JSON: {e}"))
211        })?;
212
213        serde_json::from_value::<ExtractionResult>(value).map_err(|e| {
214            tracing::error!(
215                error = %e,
216                "failed to deserialize extraction JSON into ExtractionResult"
217            );
218            ExtractionError::ParseError(format!("Failed to parse extraction JSON: {e}"))
219        })
220    }
221
222    /// Remove memories below the configured confidence threshold.
223    pub fn filter_quality(&self, memories: &[ExtractedMemory]) -> Vec<ExtractedMemory> {
224        memories
225            .iter()
226            .filter(|m| m.confidence >= self.config.quality_threshold)
227            .cloned()
228            .collect()
229    }
230
231    /// Check a new extracted memory against existing memories for contradictions
232    /// using the WriteInferenceEngine.
233    pub fn check_contradictions(
234        &self,
235        new_memory: &ExtractedMemory,
236        existing: &[MemoryNode],
237        embedding_provider: &dyn EmbeddingProvider,
238    ) -> Vec<CognitiveFinding> {
239        if !self.config.enable_contradiction_check || existing.is_empty() {
240            return Vec::new();
241        }
242
243        let embedding = match embedding_provider.embed(&new_memory.content) {
244            Ok(e) => e,
245            Err(err) => {
246                tracing::warn!(error = %err, "failed to embed memory for contradiction check");
247                return Vec::new();
248            }
249        };
250
251        let memory_type = map_extraction_type_to_memory_type(&new_memory.memory_type);
252        let temp_node = MemoryNode::new(
253            AgentId::nil(),
254            memory_type,
255            new_memory.content.clone(),
256            embedding,
257        );
258
259        let engine = WriteInferenceEngine::new();
260        let actions = engine.infer_on_write(&temp_node, existing, &[]);
261
262        let mut findings = Vec::new();
263        for action in actions {
264            match action {
265                InferredAction::FlagContradiction {
266                    existing: existing_id,
267                    reason,
268                    ..
269                } => {
270                    findings.push(CognitiveFinding {
271                        finding_type: CognitiveFindingType::Contradiction,
272                        description: reason,
273                        related_memory_id: Some(existing_id),
274                    });
275                }
276                InferredAction::MarkObsolete {
277                    memory,
278                    superseded_by: _,
279                } => {
280                    findings.push(CognitiveFinding {
281                        finding_type: CognitiveFindingType::Obsolescence,
282                        description: format!("Memory {memory} may be obsolete"),
283                        related_memory_id: Some(memory),
284                    });
285                }
286                InferredAction::UpdateConfidence {
287                    memory,
288                    new_confidence,
289                } => {
290                    findings.push(CognitiveFinding {
291                        finding_type: CognitiveFindingType::ConfidenceUpdate,
292                        description: format!(
293                            "Confidence for {memory} should be updated to {new_confidence:.2}"
294                        ),
295                        related_memory_id: Some(memory),
296                    });
297                }
298                InferredAction::CreateEdge { target, .. } => {
299                    findings.push(CognitiveFinding {
300                        finding_type: CognitiveFindingType::Related,
301                        description: format!("Related to existing memory {target}"),
302                        related_memory_id: Some(target),
303                    });
304                }
305                _ => {}
306            }
307        }
308
309        findings
310    }
311
312    /// Check if a new memory is too similar to any existing memory
313    /// (above deduplication_threshold).
314    pub fn check_duplicates(
315        &self,
316        new_memory: &ExtractedMemory,
317        existing: &[MemoryNode],
318        embedding_provider: &dyn EmbeddingProvider,
319    ) -> bool {
320        if !self.config.enable_deduplication || existing.is_empty() {
321            return false;
322        }
323
324        let new_embedding = match embedding_provider.embed(&new_memory.content) {
325            Ok(e) => e,
326            Err(err) => {
327                tracing::warn!(error = %err, "failed to embed memory for dedup check");
328                return false;
329            }
330        };
331
332        for mem in existing {
333            let sim = cosine_similarity(&new_embedding, &mem.embedding);
334            if sim >= self.config.deduplication_threshold {
335                tracing::debug!(
336                    similarity = sim,
337                    threshold = self.config.deduplication_threshold,
338                    existing_id = %mem.id,
339                    "duplicate detected"
340                );
341                return true;
342            }
343        }
344
345        false
346    }
347
348    /// Run the full extraction pipeline: extract -> filter quality ->
349    /// check duplicates -> check contradictions.
350    pub async fn process(
351        &self,
352        conversation: &str,
353        existing_memories: &[MemoryNode],
354        embedding_provider: &dyn EmbeddingProvider,
355    ) -> Result<ProcessedExtractionResult, ExtractionError> {
356        let all_memories = self.extract_from_conversation(conversation).await?;
357        let total_extracted = all_memories.len();
358
359        let quality_passed = self.filter_quality(&all_memories);
360        let rejected_low_quality: Vec<ExtractedMemory> = all_memories
361            .iter()
362            .filter(|m| m.confidence < self.config.quality_threshold)
363            .cloned()
364            .collect();
365
366        let mut to_store = Vec::new();
367        let mut rejected_duplicate = Vec::new();
368        let mut contradictions = Vec::new();
369
370        for memory in quality_passed {
371            if self.check_duplicates(&memory, existing_memories, embedding_provider) {
372                rejected_duplicate.push(memory);
373                continue;
374            }
375
376            let findings =
377                self.check_contradictions(&memory, existing_memories, embedding_provider);
378            let has_contradiction = findings
379                .iter()
380                .any(|f| f.finding_type == CognitiveFindingType::Contradiction);
381
382            if has_contradiction {
383                contradictions.push((memory, findings));
384            } else {
385                to_store.push(memory);
386            }
387        }
388
389        let stats = ExtractionStats {
390            total_extracted,
391            accepted: to_store.len(),
392            rejected_quality: rejected_low_quality.len(),
393            rejected_duplicate: rejected_duplicate.len(),
394            contradictions_found: contradictions.len(),
395        };
396
397        tracing::info!(
398            total = stats.total_extracted,
399            accepted = stats.accepted,
400            rejected_quality = stats.rejected_quality,
401            rejected_duplicate = stats.rejected_duplicate,
402            contradictions = stats.contradictions_found,
403            "extraction pipeline complete"
404        );
405
406        Ok(ProcessedExtractionResult {
407            to_store,
408            rejected_low_quality,
409            rejected_duplicate,
410            contradictions,
411            stats,
412        })
413    }
414}
415
416/// Map extraction type strings to MemoryType enum variants.
417pub fn map_extraction_type_to_memory_type(
418    extraction_type: &str,
419) -> mentedb_core::memory::MemoryType {
420    use mentedb_core::memory::MemoryType;
421    match extraction_type.to_lowercase().as_str() {
422        "decision" | "preference" | "fact" | "entity" => MemoryType::Semantic,
423        "correction" => MemoryType::Correction,
424        "anti_pattern" => MemoryType::AntiPattern,
425        _ => MemoryType::Episodic,
426    }
427}
428
/// Cosine similarity of two vectors.
///
/// Returns `0.0` when the slices are empty, differ in length, or when the
/// product of their Euclidean norms is zero.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Accumulate the dot product and both squared norms in one pass.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let magnitude = sq_a.sqrt() * sq_b.sqrt();
    if magnitude == 0.0 {
        return 0.0;
    }
    dot / magnitude
}