1use mentedb_cognitive::write_inference::{InferredAction, WriteInferenceEngine};
2use mentedb_core::MemoryNode;
3use mentedb_core::types::{AgentId, MemoryId};
4use mentedb_embedding::provider::EmbeddingProvider;
5
6use crate::config::ExtractionConfig;
7use crate::error::ExtractionError;
8use crate::prompts::extraction_system_prompt;
9use crate::provider::ExtractionProvider;
10use crate::schema::{ExtractedMemory, ExtractionResult};
11
12#[derive(Debug, Clone)]
14pub struct CognitiveFinding {
15 pub finding_type: CognitiveFindingType,
17 pub description: String,
19 pub related_memory_id: Option<MemoryId>,
21}
22
/// Category of a `CognitiveFinding`.
///
/// The enum is fieldless, so it also derives `Copy` and `Hash`. It can then
/// be passed by value and used as a set or map key, for example to group
/// findings by kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CognitiveFindingType {
    /// The new memory conflicts with an existing one.
    Contradiction,
    /// A memory may be obsolete (superseded).
    Obsolescence,
    /// The new memory is related to an existing one.
    Related,
    /// A memory's confidence should be revised.
    ConfidenceUpdate,
}
31
/// Summary counts for one run of the extraction pipeline.
#[derive(Clone, Debug, Default)]
pub struct ExtractionStats {
    /// Number of memories returned by extraction, after truncation to the
    /// configured per-conversation maximum.
    pub total_extracted: usize,
    /// Number of memories accepted for storage.
    pub accepted: usize,
    /// Number rejected for falling below the quality threshold.
    pub rejected_quality: usize,
    /// Number rejected as duplicates of existing memories.
    pub rejected_duplicate: usize,
    /// Number of memories (not individual findings) flagged with at least
    /// one contradiction.
    pub contradictions_found: usize,
}
41
42#[derive(Debug)]
44pub struct ProcessedExtractionResult {
45 pub to_store: Vec<ExtractedMemory>,
47 pub rejected_low_quality: Vec<ExtractedMemory>,
49 pub rejected_duplicate: Vec<ExtractedMemory>,
51 pub contradictions: Vec<(ExtractedMemory, Vec<CognitiveFinding>)>,
53 pub stats: ExtractionStats,
55}
56
57pub struct ExtractionPipeline<P: ExtractionProvider> {
60 provider: P,
61 config: ExtractionConfig,
62}
63
64impl<P: ExtractionProvider> ExtractionPipeline<P> {
65 pub fn new(provider: P, config: ExtractionConfig) -> Self {
66 Self { provider, config }
67 }
68
69 pub async fn extract_from_conversation(
72 &self,
73 conversation: &str,
74 ) -> Result<Vec<ExtractedMemory>, ExtractionError> {
75 let system_prompt = extraction_system_prompt();
76 let raw_response = self.provider.extract(conversation, system_prompt).await?;
77
78 let result = self.parse_extraction_response(&raw_response)?;
79
80 let mut memories = result.memories;
81 if memories.len() > self.config.max_extractions_per_conversation {
82 tracing::warn!(
83 extracted = memories.len(),
84 max = self.config.max_extractions_per_conversation,
85 "truncating extractions to configured maximum"
86 );
87 memories.truncate(self.config.max_extractions_per_conversation);
88 }
89
90 Ok(memories)
91 }
92
93 fn parse_extraction_response(&self, raw: &str) -> Result<ExtractionResult, ExtractionError> {
96 let trimmed = raw.trim();
97
98 let json_str = if trimmed.starts_with("```") {
100 let without_prefix = trimmed
101 .trim_start_matches("```json")
102 .trim_start_matches("```");
103 without_prefix.trim_end_matches("```").trim()
104 } else {
105 trimmed
106 };
107
108 serde_json::from_str::<ExtractionResult>(json_str).map_err(|e| {
109 tracing::error!(
110 error = %e,
111 response_preview = &json_str[..json_str.len().min(200)],
112 "failed to parse LLM extraction response"
113 );
114 ExtractionError::ParseError(format!("Failed to parse extraction JSON: {e}"))
115 })
116 }
117
118 pub fn filter_quality(&self, memories: &[ExtractedMemory]) -> Vec<ExtractedMemory> {
120 memories
121 .iter()
122 .filter(|m| m.confidence >= self.config.quality_threshold)
123 .cloned()
124 .collect()
125 }
126
127 pub fn check_contradictions(
130 &self,
131 new_memory: &ExtractedMemory,
132 existing: &[MemoryNode],
133 embedding_provider: &dyn EmbeddingProvider,
134 ) -> Vec<CognitiveFinding> {
135 if !self.config.enable_contradiction_check || existing.is_empty() {
136 return Vec::new();
137 }
138
139 let embedding = match embedding_provider.embed(&new_memory.content) {
140 Ok(e) => e,
141 Err(err) => {
142 tracing::warn!(error = %err, "failed to embed memory for contradiction check");
143 return Vec::new();
144 }
145 };
146
147 let memory_type = map_extraction_type_to_memory_type(&new_memory.memory_type);
148 let temp_node = MemoryNode::new(
149 AgentId::nil(),
150 memory_type,
151 new_memory.content.clone(),
152 embedding,
153 );
154
155 let engine = WriteInferenceEngine::new();
156 let actions = engine.infer_on_write(&temp_node, existing, &[]);
157
158 let mut findings = Vec::new();
159 for action in actions {
160 match action {
161 InferredAction::FlagContradiction {
162 existing: existing_id,
163 reason,
164 ..
165 } => {
166 findings.push(CognitiveFinding {
167 finding_type: CognitiveFindingType::Contradiction,
168 description: reason,
169 related_memory_id: Some(existing_id),
170 });
171 }
172 InferredAction::MarkObsolete {
173 memory,
174 superseded_by: _,
175 } => {
176 findings.push(CognitiveFinding {
177 finding_type: CognitiveFindingType::Obsolescence,
178 description: format!("Memory {memory} may be obsolete"),
179 related_memory_id: Some(memory),
180 });
181 }
182 InferredAction::UpdateConfidence {
183 memory,
184 new_confidence,
185 } => {
186 findings.push(CognitiveFinding {
187 finding_type: CognitiveFindingType::ConfidenceUpdate,
188 description: format!(
189 "Confidence for {memory} should be updated to {new_confidence:.2}"
190 ),
191 related_memory_id: Some(memory),
192 });
193 }
194 InferredAction::CreateEdge { target, .. } => {
195 findings.push(CognitiveFinding {
196 finding_type: CognitiveFindingType::Related,
197 description: format!("Related to existing memory {target}"),
198 related_memory_id: Some(target),
199 });
200 }
201 _ => {}
202 }
203 }
204
205 findings
206 }
207
208 pub fn check_duplicates(
211 &self,
212 new_memory: &ExtractedMemory,
213 existing: &[MemoryNode],
214 embedding_provider: &dyn EmbeddingProvider,
215 ) -> bool {
216 if !self.config.enable_deduplication || existing.is_empty() {
217 return false;
218 }
219
220 let new_embedding = match embedding_provider.embed(&new_memory.content) {
221 Ok(e) => e,
222 Err(err) => {
223 tracing::warn!(error = %err, "failed to embed memory for dedup check");
224 return false;
225 }
226 };
227
228 for mem in existing {
229 let sim = cosine_similarity(&new_embedding, &mem.embedding);
230 if sim >= self.config.deduplication_threshold {
231 tracing::debug!(
232 similarity = sim,
233 threshold = self.config.deduplication_threshold,
234 existing_id = %mem.id,
235 "duplicate detected"
236 );
237 return true;
238 }
239 }
240
241 false
242 }
243
244 pub async fn process(
247 &self,
248 conversation: &str,
249 existing_memories: &[MemoryNode],
250 embedding_provider: &dyn EmbeddingProvider,
251 ) -> Result<ProcessedExtractionResult, ExtractionError> {
252 let all_memories = self.extract_from_conversation(conversation).await?;
253 let total_extracted = all_memories.len();
254
255 let quality_passed = self.filter_quality(&all_memories);
256 let rejected_low_quality: Vec<ExtractedMemory> = all_memories
257 .iter()
258 .filter(|m| m.confidence < self.config.quality_threshold)
259 .cloned()
260 .collect();
261
262 let mut to_store = Vec::new();
263 let mut rejected_duplicate = Vec::new();
264 let mut contradictions = Vec::new();
265
266 for memory in quality_passed {
267 if self.check_duplicates(&memory, existing_memories, embedding_provider) {
268 rejected_duplicate.push(memory);
269 continue;
270 }
271
272 let findings =
273 self.check_contradictions(&memory, existing_memories, embedding_provider);
274 let has_contradiction = findings
275 .iter()
276 .any(|f| f.finding_type == CognitiveFindingType::Contradiction);
277
278 if has_contradiction {
279 contradictions.push((memory, findings));
280 } else {
281 to_store.push(memory);
282 }
283 }
284
285 let stats = ExtractionStats {
286 total_extracted,
287 accepted: to_store.len(),
288 rejected_quality: rejected_low_quality.len(),
289 rejected_duplicate: rejected_duplicate.len(),
290 contradictions_found: contradictions.len(),
291 };
292
293 tracing::info!(
294 total = stats.total_extracted,
295 accepted = stats.accepted,
296 rejected_quality = stats.rejected_quality,
297 rejected_duplicate = stats.rejected_duplicate,
298 contradictions = stats.contradictions_found,
299 "extraction pipeline complete"
300 );
301
302 Ok(ProcessedExtractionResult {
303 to_store,
304 rejected_low_quality,
305 rejected_duplicate,
306 contradictions,
307 stats,
308 })
309 }
310}
311
312pub fn map_extraction_type_to_memory_type(
314 extraction_type: &str,
315) -> mentedb_core::memory::MemoryType {
316 use mentedb_core::memory::MemoryType;
317 match extraction_type.to_lowercase().as_str() {
318 "decision" | "preference" | "fact" | "entity" => MemoryType::Semantic,
319 "correction" => MemoryType::Correction,
320 "anti_pattern" => MemoryType::AntiPattern,
321 _ => MemoryType::Episodic,
322 }
323}
324
/// Cosine similarity of two equal-length vectors.
///
/// Returns `0.0` when the vectors are empty, differ in length, or either has
/// zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Accumulate sequentially, in element order, so rounding matches a plain loop.
    let (dot, sum_sq_a, sum_sq_b) = a.iter().zip(b.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let magnitude = sum_sq_a.sqrt() * sum_sq_b.sqrt();
    if magnitude == 0.0 {
        0.0
    } else {
        dot / magnitude
    }
}