1use mentedb_cognitive::write_inference::{InferredAction, WriteInferenceEngine};
2use mentedb_core::MemoryNode;
3use mentedb_core::types::{AgentId, MemoryId};
4use mentedb_embedding::provider::EmbeddingProvider;
5
6use crate::config::ExtractionConfig;
7use crate::error::ExtractionError;
8use crate::provider::ExtractionProvider;
9use crate::schema::{ExtractedMemory, ExtractionResult};
10
/// A single observation produced while comparing a newly extracted memory
/// against already-stored memories.
///
/// Values of this type are built by `ExtractionPipeline::check_contradictions`
/// from the actions returned by the write-inference engine.
#[derive(Debug, Clone)]
pub struct CognitiveFinding {
    /// Category of the observation.
    pub finding_type: CognitiveFindingType,
    /// Human-readable explanation of the finding.
    pub description: String,
    /// Id of the memory the finding refers to, when there is one.
    pub related_memory_id: Option<MemoryId>,
}
21
/// Category of a [`CognitiveFinding`].
///
/// The enum carries no data, so it is `Copy` and `Hash`: callers can pass it
/// by value and use it as a map or set key without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CognitiveFindingType {
    /// The new memory was flagged as contradicting an existing memory.
    Contradiction,
    /// An existing memory was marked as possibly obsolete.
    Obsolescence,
    /// The new memory is related to an existing memory (an edge was suggested).
    Related,
    /// The confidence of an existing memory should be updated.
    ConfidenceUpdate,
}
30
/// Counters summarising one run of `ExtractionPipeline::process`.
#[derive(Debug, Clone, Default)]
pub struct ExtractionStats {
    /// Number of memories returned by the extraction step, before filtering.
    pub total_extracted: usize,
    /// Number of memories that ended up in the `to_store` list.
    pub accepted: usize,
    /// Number of memories rejected by the quality (confidence) filter.
    pub rejected_quality: usize,
    /// Number of memories rejected as duplicates of existing memories.
    pub rejected_duplicate: usize,
    /// Number of memories that had at least one contradiction finding.
    pub contradictions_found: usize,
}
40
/// Outcome of `ExtractionPipeline::process`: the extracted memories sorted
/// into buckets, plus summary statistics.
#[derive(Debug)]
pub struct ProcessedExtractionResult {
    /// Memories that passed the quality filter, were not duplicates and had
    /// no contradiction finding.
    pub to_store: Vec<ExtractedMemory>,
    /// Memories rejected by the quality (confidence) filter.
    pub rejected_low_quality: Vec<ExtractedMemory>,
    /// Memories rejected by the duplicate check against existing memories.
    pub rejected_duplicate: Vec<ExtractedMemory>,
    /// Memories with at least one contradiction finding, each paired with
    /// every finding produced for it (not only the contradictions).
    pub contradictions: Vec<(ExtractedMemory, Vec<CognitiveFinding>)>,
    /// Bucket sizes for this run.
    pub stats: ExtractionStats,
}
55
/// Memory-extraction pipeline that is generic over the extraction provider.
///
/// The provider is asked to extract from a conversation; the configuration
/// controls the number of passes, the extraction cap, and the quality,
/// deduplication and contradiction checks.
pub struct ExtractionPipeline<P: ExtractionProvider> {
    /// Backend that receives the conversation and a prompt.
    provider: P,
    /// Settings read by every stage of the pipeline.
    config: ExtractionConfig,
}
62
63impl<P: ExtractionProvider> ExtractionPipeline<P> {
64 pub fn new(provider: P, config: ExtractionConfig) -> Self {
65 Self { provider, config }
66 }
67
68 pub async fn extract_from_conversation(
71 &self,
72 conversation: &str,
73 ) -> Result<Vec<ExtractedMemory>, ExtractionError> {
74 let result = self.extract_full(conversation).await?;
75 Ok(result.memories)
76 }
77
78 pub async fn extract_full(
81 &self,
82 conversation: &str,
83 ) -> Result<ExtractionResult, ExtractionError> {
84 use crate::prompts::{extraction_system_prompt, extraction_verification_prompt};
85
86 let system_prompt = extraction_system_prompt();
87 let raw_response = self.provider.extract(conversation, system_prompt).await?;
88
89 let mut result = self.parse_extraction_response(&raw_response)?;
90
91 if self.config.extraction_passes >= 2 && !result.memories.is_empty() {
93 let first_pass_facts: String = result
94 .memories
95 .iter()
96 .map(|m| format!("- {}", m.content))
97 .collect::<Vec<_>>()
98 .join("\n");
99 let verify_prompt = extraction_verification_prompt(&first_pass_facts);
100 match self.provider.extract(conversation, &verify_prompt).await {
101 Ok(verify_response) => {
102 if let Ok(verify_result) = self.parse_extraction_response(&verify_response) {
103 let new_memories = verify_result.memories.len();
104 let new_entities = verify_result.entities.len();
105 result.memories.extend(verify_result.memories);
106 result.entities.extend(verify_result.entities);
107 if new_memories > 0 || new_entities > 0 {
108 tracing::info!(
109 new_memories,
110 new_entities,
111 "verification pass found additional extractions"
112 );
113 }
114 }
115 }
116 Err(e) => {
117 tracing::warn!("verification pass failed, using first pass only: {}", e);
118 }
119 }
120 }
121
122 if result.memories.len() > self.config.max_extractions_per_conversation {
123 tracing::warn!(
124 extracted = result.memories.len(),
125 max = self.config.max_extractions_per_conversation,
126 "truncating extractions to configured maximum"
127 );
128 result
129 .memories
130 .truncate(self.config.max_extractions_per_conversation);
131 }
132
133 Ok(result)
134 }
135
136 fn parse_extraction_response(&self, raw: &str) -> Result<ExtractionResult, ExtractionError> {
139 let trimmed = raw.trim();
140
141 if trimmed.is_empty() {
143 return Ok(ExtractionResult {
144 memories: vec![],
145 entities: vec![],
146 });
147 }
148
149 let stripped = if trimmed.starts_with("```") {
151 let without_prefix = trimmed
152 .trim_start_matches("```json")
153 .trim_start_matches("```");
154 without_prefix.trim_end_matches("```").trim()
155 } else {
156 trimmed
157 };
158
159 let json_str = if let Some(start) = stripped.find('{') {
162 let candidate = &stripped[start..];
163 let mut depth = 0i32;
164 let mut in_string = false;
165 let mut escape_next = false;
166 let mut end = candidate.len();
167 for (i, ch) in candidate.char_indices() {
168 if escape_next {
169 escape_next = false;
170 continue;
171 }
172 if in_string {
173 match ch {
174 '\\' => escape_next = true,
175 '"' => in_string = false,
176 _ => {}
177 }
178 continue;
179 }
180 match ch {
181 '"' => in_string = true,
182 '{' => depth += 1,
183 '}' => {
184 depth -= 1;
185 if depth == 0 {
186 end = i + 1;
187 break;
188 }
189 }
190 _ => {}
191 }
192 }
193 &candidate[..end]
194 } else {
195 return Ok(ExtractionResult {
197 memories: vec![],
198 entities: vec![],
199 });
200 };
201
202 let value: serde_json::Value = serde_json::from_str(json_str).map_err(|e| {
205 tracing::error!(
206 error = %e,
207 response_preview = &json_str[..json_str.len().min(200)],
208 "failed to parse LLM extraction response as JSON"
209 );
210 ExtractionError::ParseError(format!("Failed to parse extraction JSON: {e}"))
211 })?;
212
213 serde_json::from_value::<ExtractionResult>(value).map_err(|e| {
214 tracing::error!(
215 error = %e,
216 "failed to deserialize extraction JSON into ExtractionResult"
217 );
218 ExtractionError::ParseError(format!("Failed to parse extraction JSON: {e}"))
219 })
220 }
221
222 pub fn filter_quality(&self, memories: &[ExtractedMemory]) -> Vec<ExtractedMemory> {
224 memories
225 .iter()
226 .filter(|m| m.confidence >= self.config.quality_threshold)
227 .cloned()
228 .collect()
229 }
230
231 pub fn check_contradictions(
234 &self,
235 new_memory: &ExtractedMemory,
236 existing: &[MemoryNode],
237 embedding_provider: &dyn EmbeddingProvider,
238 ) -> Vec<CognitiveFinding> {
239 if !self.config.enable_contradiction_check || existing.is_empty() {
240 return Vec::new();
241 }
242
243 let embedding = match embedding_provider.embed(&new_memory.content) {
244 Ok(e) => e,
245 Err(err) => {
246 tracing::warn!(error = %err, "failed to embed memory for contradiction check");
247 return Vec::new();
248 }
249 };
250
251 let memory_type = map_extraction_type_to_memory_type(&new_memory.memory_type);
252 let temp_node = MemoryNode::new(
253 AgentId::nil(),
254 memory_type,
255 new_memory.content.clone(),
256 embedding,
257 );
258
259 let engine = WriteInferenceEngine::new();
260 let actions = engine.infer_on_write(&temp_node, existing, &[]);
261
262 let mut findings = Vec::new();
263 for action in actions {
264 match action {
265 InferredAction::FlagContradiction {
266 existing: existing_id,
267 reason,
268 ..
269 } => {
270 findings.push(CognitiveFinding {
271 finding_type: CognitiveFindingType::Contradiction,
272 description: reason,
273 related_memory_id: Some(existing_id),
274 });
275 }
276 InferredAction::MarkObsolete {
277 memory,
278 superseded_by: _,
279 } => {
280 findings.push(CognitiveFinding {
281 finding_type: CognitiveFindingType::Obsolescence,
282 description: format!("Memory {memory} may be obsolete"),
283 related_memory_id: Some(memory),
284 });
285 }
286 InferredAction::UpdateConfidence {
287 memory,
288 new_confidence,
289 } => {
290 findings.push(CognitiveFinding {
291 finding_type: CognitiveFindingType::ConfidenceUpdate,
292 description: format!(
293 "Confidence for {memory} should be updated to {new_confidence:.2}"
294 ),
295 related_memory_id: Some(memory),
296 });
297 }
298 InferredAction::CreateEdge { target, .. } => {
299 findings.push(CognitiveFinding {
300 finding_type: CognitiveFindingType::Related,
301 description: format!("Related to existing memory {target}"),
302 related_memory_id: Some(target),
303 });
304 }
305 _ => {}
306 }
307 }
308
309 findings
310 }
311
312 pub fn check_duplicates(
315 &self,
316 new_memory: &ExtractedMemory,
317 existing: &[MemoryNode],
318 embedding_provider: &dyn EmbeddingProvider,
319 ) -> bool {
320 if !self.config.enable_deduplication || existing.is_empty() {
321 return false;
322 }
323
324 let new_embedding = match embedding_provider.embed(&new_memory.content) {
325 Ok(e) => e,
326 Err(err) => {
327 tracing::warn!(error = %err, "failed to embed memory for dedup check");
328 return false;
329 }
330 };
331
332 for mem in existing {
333 let sim = cosine_similarity(&new_embedding, &mem.embedding);
334 if sim >= self.config.deduplication_threshold {
335 tracing::debug!(
336 similarity = sim,
337 threshold = self.config.deduplication_threshold,
338 existing_id = %mem.id,
339 "duplicate detected"
340 );
341 return true;
342 }
343 }
344
345 false
346 }
347
348 pub async fn process(
351 &self,
352 conversation: &str,
353 existing_memories: &[MemoryNode],
354 embedding_provider: &dyn EmbeddingProvider,
355 ) -> Result<ProcessedExtractionResult, ExtractionError> {
356 let all_memories = self.extract_from_conversation(conversation).await?;
357 let total_extracted = all_memories.len();
358
359 let quality_passed = self.filter_quality(&all_memories);
360 let rejected_low_quality: Vec<ExtractedMemory> = all_memories
361 .iter()
362 .filter(|m| m.confidence < self.config.quality_threshold)
363 .cloned()
364 .collect();
365
366 let mut to_store = Vec::new();
367 let mut rejected_duplicate = Vec::new();
368 let mut contradictions = Vec::new();
369
370 for memory in quality_passed {
371 if self.check_duplicates(&memory, existing_memories, embedding_provider) {
372 rejected_duplicate.push(memory);
373 continue;
374 }
375
376 let findings =
377 self.check_contradictions(&memory, existing_memories, embedding_provider);
378 let has_contradiction = findings
379 .iter()
380 .any(|f| f.finding_type == CognitiveFindingType::Contradiction);
381
382 if has_contradiction {
383 contradictions.push((memory, findings));
384 } else {
385 to_store.push(memory);
386 }
387 }
388
389 let stats = ExtractionStats {
390 total_extracted,
391 accepted: to_store.len(),
392 rejected_quality: rejected_low_quality.len(),
393 rejected_duplicate: rejected_duplicate.len(),
394 contradictions_found: contradictions.len(),
395 };
396
397 tracing::info!(
398 total = stats.total_extracted,
399 accepted = stats.accepted,
400 rejected_quality = stats.rejected_quality,
401 rejected_duplicate = stats.rejected_duplicate,
402 contradictions = stats.contradictions_found,
403 "extraction pipeline complete"
404 );
405
406 Ok(ProcessedExtractionResult {
407 to_store,
408 rejected_low_quality,
409 rejected_duplicate,
410 contradictions,
411 stats,
412 })
413 }
414}
415
416pub fn map_extraction_type_to_memory_type(
418 extraction_type: &str,
419) -> mentedb_core::memory::MemoryType {
420 use mentedb_core::memory::MemoryType;
421 match extraction_type.to_lowercase().as_str() {
422 "decision" | "preference" | "fact" | "entity" => MemoryType::Semantic,
423 "correction" => MemoryType::Correction,
424 "anti_pattern" => MemoryType::AntiPattern,
425 _ => MemoryType::Episodic,
426 }
427}
428
/// Cosine similarity of two vectors.
///
/// Returns `0.0` when the slices are empty, differ in length, or when either
/// vector has zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // One pass accumulating the dot product and both squared magnitudes, in
    // element order.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let magnitude = sq_a.sqrt() * sq_b.sqrt();
    if magnitude == 0.0 {
        return 0.0;
    }
    dot / magnitude
}