//! codemem_engine/memory_ops.rs — memory persistence, linking, and
//! self-editing operations (persist, store-with-links, refine, split,
//! merge, update, delete) for `CodememEngine`.

1use crate::scoring;
2use crate::CodememEngine;
3use crate::SplitPart;
4use codemem_core::{
5    CodememError, Edge, GraphBackend, MemoryNode, MemoryType, RelationshipType, VectorBackend,
6};
7use std::collections::HashMap;
8use std::sync::atomic::Ordering;
9
10impl CodememEngine {
11    // ── Persistence ─────────────────────────────────────────────────────
12
13    /// Persist a memory through the full pipeline: storage → BM25 → graph → embedding → vector.
14    pub fn persist_memory(&self, memory: &MemoryNode) -> Result<(), CodememError> {
15        self.persist_memory_inner(memory, true)
16    }
17
18    /// Persist a memory without saving the vector index to disk.
19    /// Use this in batch operations, then call `save_index()` once at the end.
20    pub(crate) fn persist_memory_no_save(&self, memory: &MemoryNode) -> Result<(), CodememError> {
21        self.persist_memory_inner(memory, false)
22    }
23
    /// Inner persist implementation with optional index save.
    ///
    /// Pipeline: session stamp → embed (if provider loaded) → SQLite
    /// transaction (memory row + graph node + embedding blob) → BM25 →
    /// in-memory graph → tag auto-link → HNSW insert → save/dirty-flag.
    ///
    /// `save` selects between flushing the vector index to disk here
    /// (`persist_memory`) and merely marking it dirty (`persist_memory_no_save`).
    ///
    /// H3: Lock ordering is enforced to prevent deadlocks:
    /// 1. Embeddings lock (acquire, embed, drop)
    /// 2. BM25 lock
    /// 3. Graph lock
    /// 4. Vector lock
    fn persist_memory_inner(&self, memory: &MemoryNode, save: bool) -> Result<(), CodememError> {
        // Auto-populate session_id from the engine's active session if not already set.
        // Cow lets us avoid cloning the node in the common case where no
        // rewrite is needed (session already set, or no active session).
        let memory = if memory.session_id.is_none() {
            if let Some(active_sid) = self.active_session_id() {
                let mut m = memory.clone();
                m.session_id = Some(active_sid);
                std::borrow::Cow::Owned(m)
            } else {
                std::borrow::Cow::Borrowed(memory)
            }
        } else {
            std::borrow::Cow::Borrowed(memory)
        };
        // From here on, `memory` is a plain `&MemoryNode` again.
        let memory = memory.as_ref();

        // H3: Step 1 — Embed if the provider is already loaded (don't trigger lazy init).
        // Lifecycle hooks skip embedding for speed; the provider gets initialized on
        // first recall/search, and backfill_embeddings() picks up any gaps.
        let embedding_result = if self.embeddings_ready() {
            match self.lock_embeddings() {
                Ok(Some(emb)) => {
                    // Enrich with type/tags/namespace context so the embedding
                    // captures more than the raw content string.
                    let enriched = self.enrich_memory_text(
                        &memory.content,
                        memory.memory_type,
                        &memory.tags,
                        memory.namespace.as_deref(),
                        Some(&memory.id),
                    );
                    // `.ok()`: an embed failure degrades to "no embedding",
                    // which backfill_embeddings() can repair later.
                    let result = emb.embed(&enriched).ok();
                    // Drop the embeddings guard before any other lock is taken (H3).
                    drop(emb);
                    result
                }
                Ok(None) => None,
                Err(e) => {
                    tracing::warn!("Embeddings lock failed during persist: {e}");
                    None
                }
            }
        } else {
            None
        };

        // 2F: Wrap all SQLite mutations in a single transaction so that the
        // database cannot be left in an inconsistent state if one step fails.
        // The HNSW vector index is NOT in SQLite, so vector insertion happens
        // after commit — if it fails, the memory is still persisted without
        // its embedding, which is recoverable.
        self.storage.begin_transaction()?;

        let result = self.persist_memory_sqlite(memory, &embedding_result);

        match result {
            Ok(()) => {
                self.storage.commit_transaction()?;
            }
            Err(e) => {
                // Best-effort rollback; the original persist error is what we
                // propagate, the rollback failure is only logged.
                if let Err(rb_err) = self.storage.rollback_transaction() {
                    tracing::error!("Failed to rollback transaction after persist error: {rb_err}");
                }
                return Err(e);
            }
        }

        // 2. Update BM25 index if already loaded (don't trigger lazy init).
        // The BM25 index rebuilds from all memories on first access anyway.
        if self.bm25_ready() {
            match self.lock_bm25() {
                Ok(mut bm25) => {
                    bm25.add_document(&memory.id, &memory.content);
                }
                Err(e) => tracing::warn!("BM25 lock failed during persist: {e}"),
            }
        }

        // 3. Add memory node to in-memory graph (already persisted to SQLite above).
        // Lock failures are non-fatal: the graph can be rebuilt from SQLite.
        match self.lock_graph() {
            Ok(mut graph) => {
                let node = codemem_core::GraphNode {
                    id: memory.id.clone(),
                    kind: codemem_core::NodeKind::Memory,
                    label: scoring::truncate_content(&memory.content, 80),
                    payload: std::collections::HashMap::new(),
                    centrality: 0.0,
                    memory_id: Some(memory.id.clone()),
                    namespace: memory.namespace.clone(),
                };
                if let Err(e) = graph.add_node(node) {
                    tracing::warn!(
                        "Failed to add graph node in-memory for memory {}: {e}",
                        memory.id
                    );
                }
            }
            Err(e) => tracing::warn!("Graph lock failed during persist: {e}"),
        }

        // 3b. Auto-link to memories with shared tags (session co-membership, topic overlap)
        self.auto_link_by_tags(memory);

        // H3: Step 4 — Insert embedding into HNSW vector index if already loaded.
        // Runs after commit (see 2F above): a failure here leaves the memory
        // persisted without a vector, which is recoverable.
        if let Some(vec) = &embedding_result {
            if self.vector_ready() {
                if let Ok(mut vi) = self.lock_vector() {
                    if let Err(e) = vi.insert(&memory.id, vec) {
                        tracing::warn!("Failed to insert into vector index for {}: {e}", memory.id);
                    }
                }
            }
        }

        // C5: Set dirty flag instead of calling save_index() after each persist.
        // Callers should use flush_if_dirty() to batch save the index.
        if save {
            self.save_index(); // save_index() clears dirty flag
        } else {
            self.dirty.store(true, Ordering::Release);
        }

        Ok(())
    }
151
152    /// Execute all SQLite mutations for a memory persist.
153    ///
154    /// Called within the transaction opened by `persist_memory_inner`.
155    /// Inserts the memory row, graph node, and embedding (if available).
156    fn persist_memory_sqlite(
157        &self,
158        memory: &MemoryNode,
159        embedding: &Option<Vec<f32>>,
160    ) -> Result<(), CodememError> {
161        // 1. Store memory in SQLite
162        self.storage.insert_memory(memory)?;
163
164        // 2. Insert graph node in SQLite
165        let node = codemem_core::GraphNode {
166            id: memory.id.clone(),
167            kind: codemem_core::NodeKind::Memory,
168            label: scoring::truncate_content(&memory.content, 80),
169            payload: std::collections::HashMap::new(),
170            centrality: 0.0,
171            memory_id: Some(memory.id.clone()),
172            namespace: memory.namespace.clone(),
173        };
174        if let Err(e) = self.storage.insert_graph_node(&node) {
175            tracing::warn!("Failed to insert graph node for memory {}: {e}", memory.id);
176        }
177
178        // 3. Store embedding in SQLite (vector blob, not HNSW index)
179        if let Some(vec) = embedding {
180            if let Err(e) = self.storage.store_embedding(&memory.id, vec) {
181                tracing::warn!("Failed to store embedding for {}: {e}", memory.id);
182            }
183        }
184
185        Ok(())
186    }
187
188    // ── Store with Links ──────────────────────────────────────────────────
189
190    /// Store a memory with optional explicit link IDs.
191    ///
192    /// Runs the full pipeline: persist → explicit RELATES_TO edges → auto-link
193    /// to code nodes → save index. This consolidates domain logic that was
194    /// previously spread across the MCP transport layer.
195    pub fn store_memory_with_links(
196        &self,
197        memory: &MemoryNode,
198        links: &[String],
199    ) -> Result<(), CodememError> {
200        self.persist_memory(memory)?;
201
202        // Create RELATES_TO edges for explicit links
203        if !links.is_empty() {
204            let now = chrono::Utc::now();
205            let mut graph = self.lock_graph()?;
206            for link_id in links {
207                let edge = Edge {
208                    id: format!("{}-RELATES_TO-{link_id}", memory.id),
209                    src: memory.id.clone(),
210                    dst: link_id.clone(),
211                    relationship: RelationshipType::RelatesTo,
212                    weight: 1.0,
213                    properties: HashMap::new(),
214                    created_at: now,
215                    valid_from: None,
216                    valid_to: None,
217                };
218                if let Err(e) = self.storage.insert_graph_edge(&edge) {
219                    tracing::warn!("Failed to persist link edge to {link_id}: {e}");
220                }
221                if let Err(e) = graph.add_edge(edge) {
222                    tracing::warn!("Failed to add link edge to {link_id}: {e}");
223                }
224            }
225        }
226
227        // Auto-link to code nodes mentioned in content
228        self.auto_link_to_code_nodes(&memory.id, &memory.content, links);
229
230        Ok(())
231    }
232
233    // ── Edge Helpers ─────────────────────────────────────────────────────
234
235    /// Add an edge to both storage and in-memory graph.
236    pub fn add_edge(&self, edge: Edge) -> Result<(), CodememError> {
237        self.storage.insert_graph_edge(&edge)?;
238        let mut graph = self.lock_graph()?;
239        graph.add_edge(edge)?;
240        Ok(())
241    }
242
243    // ── Self-Editing ────────────────────────────────────────────────────
244
245    /// Refine a memory: create a new version with an EVOLVED_INTO edge from old to new.
246    pub fn refine_memory(
247        &self,
248        old_id: &str,
249        content: Option<&str>,
250        tags: Option<Vec<String>>,
251        importance: Option<f64>,
252    ) -> Result<(MemoryNode, String), CodememError> {
253        let old_memory = self
254            .storage
255            .get_memory(old_id)?
256            .ok_or_else(|| CodememError::NotFound(format!("Memory not found: {old_id}")))?;
257
258        let new_content = content.unwrap_or(&old_memory.content);
259        let new_tags = tags.unwrap_or_else(|| old_memory.tags.clone());
260        let new_importance = importance.unwrap_or(old_memory.importance);
261
262        let mut memory = MemoryNode::new(new_content, old_memory.memory_type);
263        let new_id = memory.id.clone();
264        memory.importance = new_importance;
265        memory.confidence = old_memory.confidence;
266        memory.tags = new_tags;
267        memory.metadata = old_memory.metadata.clone();
268        memory.namespace = old_memory.namespace.clone();
269
270        self.persist_memory(&memory)?;
271
272        // Create EVOLVED_INTO edge from old -> new
273        let now = chrono::Utc::now();
274        let edge = Edge {
275            id: format!("{old_id}-EVOLVED_INTO-{new_id}"),
276            src: old_id.to_string(),
277            dst: new_id.clone(),
278            relationship: RelationshipType::EvolvedInto,
279            weight: 1.0,
280            properties: std::collections::HashMap::new(),
281            created_at: now,
282            valid_from: Some(now),
283            valid_to: None,
284        };
285        if let Err(e) = self.add_edge(edge) {
286            tracing::warn!("Failed to add EVOLVED_INTO edge: {e}");
287        }
288
289        Ok((memory, new_id))
290    }
291
292    /// Split a memory into multiple parts, each linked via PART_OF edges.
293    pub fn split_memory(
294        &self,
295        source_id: &str,
296        parts: &[SplitPart],
297    ) -> Result<Vec<String>, CodememError> {
298        let source_memory = self
299            .storage
300            .get_memory(source_id)?
301            .ok_or_else(|| CodememError::NotFound(format!("Memory not found: {source_id}")))?;
302
303        if parts.is_empty() {
304            return Err(CodememError::InvalidInput(
305                "'parts' array must not be empty".to_string(),
306            ));
307        }
308
309        // Validate all parts upfront before persisting anything
310        for part in parts {
311            if part.content.is_empty() {
312                return Err(CodememError::InvalidInput(
313                    "Each part must have a non-empty 'content' field".to_string(),
314                ));
315            }
316        }
317
318        let now = chrono::Utc::now();
319        let mut child_ids: Vec<String> = Vec::new();
320
321        for part in parts {
322            let tags = part
323                .tags
324                .clone()
325                .unwrap_or_else(|| source_memory.tags.clone());
326            let importance = part.importance.unwrap_or(source_memory.importance);
327
328            let mut memory = MemoryNode::new(part.content.clone(), source_memory.memory_type);
329            let child_id = memory.id.clone();
330            memory.importance = importance;
331            memory.confidence = source_memory.confidence;
332            memory.tags = tags;
333            memory.namespace = source_memory.namespace.clone();
334
335            if let Err(e) = self.persist_memory_no_save(&memory) {
336                // Clean up already-created child memories
337                for created_id in &child_ids {
338                    if let Err(del_err) = self.delete_memory(created_id) {
339                        tracing::warn!(
340                            "Failed to clean up child memory {created_id} after split failure: {del_err}"
341                        );
342                    }
343                }
344                return Err(e);
345            }
346
347            // Create PART_OF edge: child -> source
348            let edge = Edge {
349                id: format!("{child_id}-PART_OF-{source_id}"),
350                src: child_id.clone(),
351                dst: source_id.to_string(),
352                relationship: RelationshipType::PartOf,
353                weight: 1.0,
354                properties: std::collections::HashMap::new(),
355                created_at: now,
356                valid_from: Some(now),
357                valid_to: None,
358            };
359            if let Err(e) = self.add_edge(edge) {
360                tracing::warn!("Failed to add PART_OF edge: {e}");
361            }
362
363            child_ids.push(child_id);
364        }
365
366        self.save_index();
367        Ok(child_ids)
368    }
369
370    /// Merge multiple memories into one, linked via SUMMARIZES edges.
371    pub fn merge_memories(
372        &self,
373        source_ids: &[String],
374        content: &str,
375        memory_type: MemoryType,
376        importance: f64,
377        tags: Vec<String>,
378    ) -> Result<String, CodememError> {
379        if source_ids.len() < 2 {
380            return Err(CodememError::InvalidInput(
381                "'source_ids' must contain at least 2 IDs".to_string(),
382            ));
383        }
384
385        // Verify all sources exist
386        let id_refs: Vec<&str> = source_ids.iter().map(|s| s.as_str()).collect();
387        let found = self.storage.get_memories_batch(&id_refs)?;
388        if found.len() != source_ids.len() {
389            let found_ids: std::collections::HashSet<&str> =
390                found.iter().map(|m| m.id.as_str()).collect();
391            let missing: Vec<&str> = id_refs
392                .iter()
393                .filter(|id| !found_ids.contains(**id))
394                .copied()
395                .collect();
396            return Err(CodememError::NotFound(format!(
397                "Source memories not found: {}",
398                missing.join(", ")
399            )));
400        }
401
402        let mut memory = MemoryNode::new(content, memory_type);
403        let merged_id = memory.id.clone();
404        memory.importance = importance;
405        memory.confidence = found.iter().map(|m| m.confidence).sum::<f64>() / found.len() as f64;
406        memory.tags = tags;
407        memory.namespace = found.iter().find_map(|m| m.namespace.clone());
408
409        self.persist_memory_no_save(&memory)?;
410
411        // Create SUMMARIZES edges: merged -> each source
412        let now = chrono::Utc::now();
413        for source_id in source_ids {
414            let edge = Edge {
415                id: format!("{merged_id}-SUMMARIZES-{source_id}"),
416                src: merged_id.clone(),
417                dst: source_id.clone(),
418                relationship: RelationshipType::Summarizes,
419                weight: 1.0,
420                properties: std::collections::HashMap::new(),
421                created_at: now,
422                valid_from: Some(now),
423                valid_to: None,
424            };
425            if let Err(e) = self.add_edge(edge) {
426                tracing::warn!("Failed to add SUMMARIZES edge to {source_id}: {e}");
427            }
428        }
429
430        self.save_index();
431        Ok(merged_id)
432    }
433
434    /// Update a memory's content and/or importance, re-embedding if needed.
435    pub fn update_memory(
436        &self,
437        id: &str,
438        content: &str,
439        importance: Option<f64>,
440    ) -> Result<(), CodememError> {
441        self.storage.update_memory(id, content, importance)?;
442
443        // Update BM25 index
444        self.lock_bm25()?.add_document(id, content);
445
446        // Update graph node label
447        if let Ok(mut graph) = self.lock_graph() {
448            if let Ok(Some(mut node)) = graph.get_node(id) {
449                node.label = scoring::truncate_content(content, 80);
450                if let Err(e) = graph.add_node(node) {
451                    tracing::warn!("Failed to update graph node for {id}: {e}");
452                }
453            }
454        }
455
456        // Re-embed with contextual enrichment
457        // H3: Acquire embeddings lock, embed, drop lock before acquiring vector lock.
458        if let Some(emb_guard) = self.lock_embeddings()? {
459            let (mem_type, tags, namespace) =
460                if let Ok(Some(mem)) = self.storage.get_memory_no_touch(id) {
461                    (mem.memory_type, mem.tags, mem.namespace)
462                } else {
463                    (MemoryType::Context, vec![], None)
464                };
465            let enriched =
466                self.enrich_memory_text(content, mem_type, &tags, namespace.as_deref(), Some(id));
467            let emb_result = emb_guard.embed(&enriched);
468            drop(emb_guard);
469            if let Ok(embedding) = emb_result {
470                if let Err(e) = self.storage.store_embedding(id, &embedding) {
471                    tracing::warn!("Failed to store embedding for {id}: {e}");
472                }
473                let mut vec = self.lock_vector()?;
474                if let Err(e) = vec.remove(id) {
475                    tracing::warn!("Failed to remove old vector for {id}: {e}");
476                }
477                if let Err(e) = vec.insert(id, &embedding) {
478                    tracing::warn!("Failed to insert new vector for {id}: {e}");
479                }
480            }
481        }
482
483        self.save_index();
484        Ok(())
485    }
486
487    /// Update only the importance of a memory.
488    /// Routes through the engine to maintain the transport → engine → storage boundary.
489    pub fn update_importance(&self, id: &str, importance: f64) -> Result<(), CodememError> {
490        self.storage
491            .batch_update_importance(&[(id.to_string(), importance)])?;
492        Ok(())
493    }
494
495    /// Delete a memory from all subsystems.
496    ///
497    /// M1: Uses `delete_memory_cascade` on the storage backend to wrap all
498    /// SQLite deletes (memory + graph nodes/edges + embedding) in a single
499    /// transaction when the backend supports it. In-memory structures
500    /// (vector, graph, BM25) are cleaned up separately with proper lock ordering.
501    pub fn delete_memory(&self, id: &str) -> Result<bool, CodememError> {
502        // Use cascade delete for all storage-side operations in a single transaction.
503        let deleted = self.storage.delete_memory_cascade(id)?;
504        if !deleted {
505            return Ok(false);
506        }
507
508        // Clean up in-memory structures with proper lock ordering:
509        // vector first, then graph, then BM25.
510        let mut vec = self.lock_vector()?;
511        if let Err(e) = vec.remove(id) {
512            tracing::warn!("Failed to remove {id} from vector index: {e}");
513        }
514        drop(vec);
515
516        let mut graph = self.lock_graph()?;
517        if let Err(e) = graph.remove_node(id) {
518            tracing::warn!("Failed to remove {id} from in-memory graph: {e}");
519        }
520        drop(graph);
521
522        self.lock_bm25()?.remove_document(id);
523
524        // Persist vector index to disk
525        self.save_index();
526        Ok(true)
527    }
528}