Skip to main content

codemem_engine/
memory_ops.rs

1use crate::scoring;
2use crate::CodememEngine;
3use crate::SplitPart;
4use codemem_core::{
5    CodememError, Edge, GraphBackend, MemoryNode, MemoryType, RelationshipType, VectorBackend,
6};
7use std::collections::HashMap;
8use std::sync::atomic::Ordering;
9
10impl CodememEngine {
11    // ── Persistence ─────────────────────────────────────────────────────
12
13    /// Persist a memory through the full pipeline: storage → BM25 → graph → embedding → vector.
14    pub fn persist_memory(&self, memory: &MemoryNode) -> Result<(), CodememError> {
15        self.persist_memory_inner(memory, true)
16    }
17
18    /// Persist a memory without saving the vector index to disk.
19    /// Use this in batch operations, then call `save_index()` once at the end.
20    pub(crate) fn persist_memory_no_save(&self, memory: &MemoryNode) -> Result<(), CodememError> {
21        self.persist_memory_inner(memory, false)
22    }
23
24    /// Inner persist implementation with optional index save.
25    ///
26    /// H3: Lock ordering is enforced to prevent deadlocks:
27    /// 1. Embeddings lock (acquire, embed, drop)
28    /// 2. BM25 lock
29    /// 3. Graph lock
30    /// 4. Vector lock
31    fn persist_memory_inner(&self, memory: &MemoryNode, save: bool) -> Result<(), CodememError> {
32        // Auto-populate session_id from the engine's active session if not already set
33        let memory = if memory.session_id.is_none() {
34            if let Some(active_sid) = self.active_session_id() {
35                let mut m = memory.clone();
36                m.session_id = Some(active_sid);
37                std::borrow::Cow::Owned(m)
38            } else {
39                std::borrow::Cow::Borrowed(memory)
40            }
41        } else {
42            std::borrow::Cow::Borrowed(memory)
43        };
44        let memory = memory.as_ref();
45
46        // H3: Step 1 — Acquire embeddings lock first, embed, save result, drop lock.
47        // This prevents holding the embeddings lock while acquiring vector/graph locks.
48        // Embedding is computed before the transaction since it doesn't touch SQLite.
49        let embedding_result = match self.lock_embeddings() {
50            Ok(Some(emb)) => {
51                let enriched = self.enrich_memory_text(
52                    &memory.content,
53                    memory.memory_type,
54                    &memory.tags,
55                    memory.namespace.as_deref(),
56                    Some(&memory.id),
57                );
58                let result = emb.embed(&enriched).ok();
59                drop(emb);
60                result
61            }
62            Ok(None) => None,
63            Err(e) => {
64                tracing::warn!("Embeddings lock failed during persist: {e}");
65                None
66            }
67        };
68
69        // 2F: Wrap all SQLite mutations in a single transaction so that the
70        // database cannot be left in an inconsistent state if one step fails.
71        // The HNSW vector index is NOT in SQLite, so vector insertion happens
72        // after commit — if it fails, the memory is still persisted without
73        // its embedding, which is recoverable.
74        self.storage.begin_transaction()?;
75
76        let result = self.persist_memory_sqlite(memory, &embedding_result);
77
78        match result {
79            Ok(()) => {
80                self.storage.commit_transaction()?;
81            }
82            Err(e) => {
83                if let Err(rb_err) = self.storage.rollback_transaction() {
84                    tracing::error!("Failed to rollback transaction after persist error: {rb_err}");
85                }
86                return Err(e);
87            }
88        }
89
90        // 2. Update BM25 index (in-memory, not SQLite)
91        match self.lock_bm25() {
92            Ok(mut bm25) => {
93                bm25.add_document(&memory.id, &memory.content);
94            }
95            Err(e) => tracing::warn!("BM25 lock failed during persist: {e}"),
96        }
97
98        // 3. Add memory node to in-memory graph (already persisted to SQLite above)
99        match self.lock_graph() {
100            Ok(mut graph) => {
101                let node = codemem_core::GraphNode {
102                    id: memory.id.clone(),
103                    kind: codemem_core::NodeKind::Memory,
104                    label: scoring::truncate_content(&memory.content, 80),
105                    payload: std::collections::HashMap::new(),
106                    centrality: 0.0,
107                    memory_id: Some(memory.id.clone()),
108                    namespace: memory.namespace.clone(),
109                };
110                if let Err(e) = graph.add_node(node) {
111                    tracing::warn!(
112                        "Failed to add graph node in-memory for memory {}: {e}",
113                        memory.id
114                    );
115                }
116            }
117            Err(e) => tracing::warn!("Graph lock failed during persist: {e}"),
118        }
119
120        // 3b. Auto-link to memories with shared tags (session co-membership, topic overlap)
121        self.auto_link_by_tags(memory);
122
123        // H3: Step 4 — Insert embedding into HNSW vector index (separate from SQLite).
124        if let Some(vec) = &embedding_result {
125            if let Ok(mut vi) = self.lock_vector() {
126                if let Err(e) = vi.insert(&memory.id, vec) {
127                    tracing::warn!("Failed to insert into vector index for {}: {e}", memory.id);
128                }
129            }
130        }
131
132        // C5: Set dirty flag instead of calling save_index() after each persist.
133        // Callers should use flush_if_dirty() to batch save the index.
134        if save {
135            self.save_index(); // save_index() clears dirty flag
136        } else {
137            self.dirty.store(true, Ordering::Release);
138        }
139
140        Ok(())
141    }
142
143    /// Execute all SQLite mutations for a memory persist.
144    ///
145    /// Called within the transaction opened by `persist_memory_inner`.
146    /// Inserts the memory row, graph node, and embedding (if available).
147    fn persist_memory_sqlite(
148        &self,
149        memory: &MemoryNode,
150        embedding: &Option<Vec<f32>>,
151    ) -> Result<(), CodememError> {
152        // 1. Store memory in SQLite
153        self.storage.insert_memory(memory)?;
154
155        // 2. Insert graph node in SQLite
156        let node = codemem_core::GraphNode {
157            id: memory.id.clone(),
158            kind: codemem_core::NodeKind::Memory,
159            label: scoring::truncate_content(&memory.content, 80),
160            payload: std::collections::HashMap::new(),
161            centrality: 0.0,
162            memory_id: Some(memory.id.clone()),
163            namespace: memory.namespace.clone(),
164        };
165        if let Err(e) = self.storage.insert_graph_node(&node) {
166            tracing::warn!("Failed to insert graph node for memory {}: {e}", memory.id);
167        }
168
169        // 3. Store embedding in SQLite (vector blob, not HNSW index)
170        if let Some(vec) = embedding {
171            if let Err(e) = self.storage.store_embedding(&memory.id, vec) {
172                tracing::warn!("Failed to store embedding for {}: {e}", memory.id);
173            }
174        }
175
176        Ok(())
177    }
178
179    // ── Store with Links ──────────────────────────────────────────────────
180
181    /// Store a memory with optional explicit link IDs.
182    ///
183    /// Runs the full pipeline: persist → explicit RELATES_TO edges → auto-link
184    /// to code nodes → save index. This consolidates domain logic that was
185    /// previously spread across the MCP transport layer.
186    pub fn store_memory_with_links(
187        &self,
188        memory: &MemoryNode,
189        links: &[String],
190    ) -> Result<(), CodememError> {
191        self.persist_memory(memory)?;
192
193        // Create RELATES_TO edges for explicit links
194        if !links.is_empty() {
195            let now = chrono::Utc::now();
196            let mut graph = self.lock_graph()?;
197            for link_id in links {
198                let edge = Edge {
199                    id: format!("{}-RELATES_TO-{link_id}", memory.id),
200                    src: memory.id.clone(),
201                    dst: link_id.clone(),
202                    relationship: RelationshipType::RelatesTo,
203                    weight: 1.0,
204                    properties: HashMap::new(),
205                    created_at: now,
206                    valid_from: None,
207                    valid_to: None,
208                };
209                if let Err(e) = self.storage.insert_graph_edge(&edge) {
210                    tracing::warn!("Failed to persist link edge to {link_id}: {e}");
211                }
212                if let Err(e) = graph.add_edge(edge) {
213                    tracing::warn!("Failed to add link edge to {link_id}: {e}");
214                }
215            }
216        }
217
218        // Auto-link to code nodes mentioned in content
219        self.auto_link_to_code_nodes(&memory.id, &memory.content, links);
220
221        Ok(())
222    }
223
224    // ── Edge Helpers ─────────────────────────────────────────────────────
225
226    /// Add an edge to both storage and in-memory graph.
227    pub fn add_edge(&self, edge: Edge) -> Result<(), CodememError> {
228        self.storage.insert_graph_edge(&edge)?;
229        let mut graph = self.lock_graph()?;
230        graph.add_edge(edge)?;
231        Ok(())
232    }
233
234    // ── Self-Editing ────────────────────────────────────────────────────
235
236    /// Refine a memory: create a new version with an EVOLVED_INTO edge from old to new.
237    pub fn refine_memory(
238        &self,
239        old_id: &str,
240        content: Option<&str>,
241        tags: Option<Vec<String>>,
242        importance: Option<f64>,
243    ) -> Result<(MemoryNode, String), CodememError> {
244        let old_memory = self
245            .storage
246            .get_memory(old_id)?
247            .ok_or_else(|| CodememError::NotFound(format!("Memory not found: {old_id}")))?;
248
249        let new_content = content.unwrap_or(&old_memory.content);
250        let new_tags = tags.unwrap_or_else(|| old_memory.tags.clone());
251        let new_importance = importance.unwrap_or(old_memory.importance);
252
253        let mut memory = MemoryNode::new(new_content, old_memory.memory_type);
254        let new_id = memory.id.clone();
255        memory.importance = new_importance;
256        memory.confidence = old_memory.confidence;
257        memory.tags = new_tags;
258        memory.metadata = old_memory.metadata.clone();
259        memory.namespace = old_memory.namespace.clone();
260
261        self.persist_memory(&memory)?;
262
263        // Create EVOLVED_INTO edge from old -> new
264        let now = chrono::Utc::now();
265        let edge = Edge {
266            id: format!("{old_id}-EVOLVED_INTO-{new_id}"),
267            src: old_id.to_string(),
268            dst: new_id.clone(),
269            relationship: RelationshipType::EvolvedInto,
270            weight: 1.0,
271            properties: std::collections::HashMap::new(),
272            created_at: now,
273            valid_from: Some(now),
274            valid_to: None,
275        };
276        if let Err(e) = self.add_edge(edge) {
277            tracing::warn!("Failed to add EVOLVED_INTO edge: {e}");
278        }
279
280        Ok((memory, new_id))
281    }
282
283    /// Split a memory into multiple parts, each linked via PART_OF edges.
284    pub fn split_memory(
285        &self,
286        source_id: &str,
287        parts: &[SplitPart],
288    ) -> Result<Vec<String>, CodememError> {
289        let source_memory = self
290            .storage
291            .get_memory(source_id)?
292            .ok_or_else(|| CodememError::NotFound(format!("Memory not found: {source_id}")))?;
293
294        if parts.is_empty() {
295            return Err(CodememError::InvalidInput(
296                "'parts' array must not be empty".to_string(),
297            ));
298        }
299
300        // Validate all parts upfront before persisting anything
301        for part in parts {
302            if part.content.is_empty() {
303                return Err(CodememError::InvalidInput(
304                    "Each part must have a non-empty 'content' field".to_string(),
305                ));
306            }
307        }
308
309        let now = chrono::Utc::now();
310        let mut child_ids: Vec<String> = Vec::new();
311
312        for part in parts {
313            let tags = part
314                .tags
315                .clone()
316                .unwrap_or_else(|| source_memory.tags.clone());
317            let importance = part.importance.unwrap_or(source_memory.importance);
318
319            let mut memory = MemoryNode::new(part.content.clone(), source_memory.memory_type);
320            let child_id = memory.id.clone();
321            memory.importance = importance;
322            memory.confidence = source_memory.confidence;
323            memory.tags = tags;
324            memory.namespace = source_memory.namespace.clone();
325
326            if let Err(e) = self.persist_memory_no_save(&memory) {
327                // Clean up already-created child memories
328                for created_id in &child_ids {
329                    if let Err(del_err) = self.delete_memory(created_id) {
330                        tracing::warn!(
331                            "Failed to clean up child memory {created_id} after split failure: {del_err}"
332                        );
333                    }
334                }
335                return Err(e);
336            }
337
338            // Create PART_OF edge: child -> source
339            let edge = Edge {
340                id: format!("{child_id}-PART_OF-{source_id}"),
341                src: child_id.clone(),
342                dst: source_id.to_string(),
343                relationship: RelationshipType::PartOf,
344                weight: 1.0,
345                properties: std::collections::HashMap::new(),
346                created_at: now,
347                valid_from: Some(now),
348                valid_to: None,
349            };
350            if let Err(e) = self.add_edge(edge) {
351                tracing::warn!("Failed to add PART_OF edge: {e}");
352            }
353
354            child_ids.push(child_id);
355        }
356
357        self.save_index();
358        Ok(child_ids)
359    }
360
361    /// Merge multiple memories into one, linked via SUMMARIZES edges.
362    pub fn merge_memories(
363        &self,
364        source_ids: &[String],
365        content: &str,
366        memory_type: MemoryType,
367        importance: f64,
368        tags: Vec<String>,
369    ) -> Result<String, CodememError> {
370        if source_ids.len() < 2 {
371            return Err(CodememError::InvalidInput(
372                "'source_ids' must contain at least 2 IDs".to_string(),
373            ));
374        }
375
376        // Verify all sources exist
377        let id_refs: Vec<&str> = source_ids.iter().map(|s| s.as_str()).collect();
378        let found = self.storage.get_memories_batch(&id_refs)?;
379        if found.len() != source_ids.len() {
380            let found_ids: std::collections::HashSet<&str> =
381                found.iter().map(|m| m.id.as_str()).collect();
382            let missing: Vec<&str> = id_refs
383                .iter()
384                .filter(|id| !found_ids.contains(**id))
385                .copied()
386                .collect();
387            return Err(CodememError::NotFound(format!(
388                "Source memories not found: {}",
389                missing.join(", ")
390            )));
391        }
392
393        let mut memory = MemoryNode::new(content, memory_type);
394        let merged_id = memory.id.clone();
395        memory.importance = importance;
396        memory.confidence = found.iter().map(|m| m.confidence).sum::<f64>() / found.len() as f64;
397        memory.tags = tags;
398        memory.namespace = found.iter().find_map(|m| m.namespace.clone());
399
400        self.persist_memory_no_save(&memory)?;
401
402        // Create SUMMARIZES edges: merged -> each source
403        let now = chrono::Utc::now();
404        for source_id in source_ids {
405            let edge = Edge {
406                id: format!("{merged_id}-SUMMARIZES-{source_id}"),
407                src: merged_id.clone(),
408                dst: source_id.clone(),
409                relationship: RelationshipType::Summarizes,
410                weight: 1.0,
411                properties: std::collections::HashMap::new(),
412                created_at: now,
413                valid_from: Some(now),
414                valid_to: None,
415            };
416            if let Err(e) = self.add_edge(edge) {
417                tracing::warn!("Failed to add SUMMARIZES edge to {source_id}: {e}");
418            }
419        }
420
421        self.save_index();
422        Ok(merged_id)
423    }
424
425    /// Update a memory's content and/or importance, re-embedding if needed.
426    pub fn update_memory(
427        &self,
428        id: &str,
429        content: &str,
430        importance: Option<f64>,
431    ) -> Result<(), CodememError> {
432        self.storage.update_memory(id, content, importance)?;
433
434        // Update BM25 index
435        self.lock_bm25()?.add_document(id, content);
436
437        // Update graph node label
438        if let Ok(mut graph) = self.lock_graph() {
439            if let Ok(Some(mut node)) = graph.get_node(id) {
440                node.label = scoring::truncate_content(content, 80);
441                if let Err(e) = graph.add_node(node) {
442                    tracing::warn!("Failed to update graph node for {id}: {e}");
443                }
444            }
445        }
446
447        // Re-embed with contextual enrichment
448        // H3: Acquire embeddings lock, embed, drop lock before acquiring vector lock.
449        if let Some(emb_guard) = self.lock_embeddings()? {
450            let (mem_type, tags, namespace) =
451                if let Ok(Some(mem)) = self.storage.get_memory_no_touch(id) {
452                    (mem.memory_type, mem.tags, mem.namespace)
453                } else {
454                    (MemoryType::Context, vec![], None)
455                };
456            let enriched =
457                self.enrich_memory_text(content, mem_type, &tags, namespace.as_deref(), Some(id));
458            let emb_result = emb_guard.embed(&enriched);
459            drop(emb_guard);
460            if let Ok(embedding) = emb_result {
461                if let Err(e) = self.storage.store_embedding(id, &embedding) {
462                    tracing::warn!("Failed to store embedding for {id}: {e}");
463                }
464                let mut vec = self.lock_vector()?;
465                if let Err(e) = vec.remove(id) {
466                    tracing::warn!("Failed to remove old vector for {id}: {e}");
467                }
468                if let Err(e) = vec.insert(id, &embedding) {
469                    tracing::warn!("Failed to insert new vector for {id}: {e}");
470                }
471            }
472        }
473
474        self.save_index();
475        Ok(())
476    }
477
478    /// Update only the importance of a memory.
479    /// Routes through the engine to maintain the transport → engine → storage boundary.
480    pub fn update_importance(&self, id: &str, importance: f64) -> Result<(), CodememError> {
481        self.storage
482            .batch_update_importance(&[(id.to_string(), importance)])?;
483        Ok(())
484    }
485
486    /// Delete a memory from all subsystems.
487    ///
488    /// M1: Uses `delete_memory_cascade` on the storage backend to wrap all
489    /// SQLite deletes (memory + graph nodes/edges + embedding) in a single
490    /// transaction when the backend supports it. In-memory structures
491    /// (vector, graph, BM25) are cleaned up separately with proper lock ordering.
492    pub fn delete_memory(&self, id: &str) -> Result<bool, CodememError> {
493        // Use cascade delete for all storage-side operations in a single transaction.
494        let deleted = self.storage.delete_memory_cascade(id)?;
495        if !deleted {
496            return Ok(false);
497        }
498
499        // Clean up in-memory structures with proper lock ordering:
500        // vector first, then graph, then BM25.
501        let mut vec = self.lock_vector()?;
502        if let Err(e) = vec.remove(id) {
503            tracing::warn!("Failed to remove {id} from vector index: {e}");
504        }
505        drop(vec);
506
507        let mut graph = self.lock_graph()?;
508        if let Err(e) = graph.remove_node(id) {
509            tracing::warn!("Failed to remove {id} from in-memory graph: {e}");
510        }
511        drop(graph);
512
513        self.lock_bm25()?.remove_document(id);
514
515        // Persist vector index to disk
516        self.save_index();
517        Ok(true)
518    }
519}