// codemem_engine/memory_ops.rs

1use crate::scoring;
2use crate::CodememEngine;
3use crate::SplitPart;
4use codemem_core::{CodememError, Edge, MemoryNode, MemoryType, RelationshipType};
5use std::collections::HashMap;
6use std::sync::atomic::Ordering;
7
8impl CodememEngine {
9    // ── Persistence ─────────────────────────────────────────────────────
10
11    /// Persist a memory through the full pipeline: storage → BM25 → graph → embedding → vector.
12    pub fn persist_memory(&self, memory: &MemoryNode) -> Result<(), CodememError> {
13        self.persist_memory_inner(memory, true)
14    }
15
16    /// Persist a memory without saving the vector index to disk.
17    /// Use this in batch operations, then call `save_index()` once at the end.
18    pub(crate) fn persist_memory_no_save(&self, memory: &MemoryNode) -> Result<(), CodememError> {
19        self.persist_memory_inner(memory, false)
20    }
21
    /// Inner persist implementation with optional index save.
    ///
    /// Pipeline: auto-populate scope fields (session, TTL, repo) →
    /// embed (only if the provider is already loaded) → SQLite transaction
    /// (memory row + graph node + embedding blob) → BM25 → in-memory graph →
    /// tag auto-links → HNSW vector index → save or mark dirty.
    ///
    /// H3: Lock ordering is enforced to prevent deadlocks:
    /// 1. Embeddings lock (acquire, embed, drop)
    /// 2. BM25 lock
    /// 3. Graph lock
    /// 4. Vector lock
    fn persist_memory_inner(&self, memory: &MemoryNode, save: bool) -> Result<(), CodememError> {
        // Auto-populate session_id from the engine's active session if not already set.
        // Cow defers cloning the MemoryNode until a field actually needs filling in.
        let memory = if memory.session_id.is_none() {
            if let Some(active_sid) = self.active_session_id() {
                let mut m = memory.clone();
                m.session_id = Some(active_sid);
                std::borrow::Cow::Owned(m)
            } else {
                std::borrow::Cow::Borrowed(memory)
            }
        } else {
            std::borrow::Cow::Borrowed(memory)
        };

        // Auto-set expires_at for session memories if not explicitly set.
        // A configured TTL of 0 disables automatic expiry entirely.
        let memory = if memory.expires_at.is_none() && memory.session_id.is_some() {
            let ttl_hours = self.config.memory.default_session_ttl_hours;
            if ttl_hours > 0 {
                let mut m = memory.into_owned();
                m.expires_at = Some(chrono::Utc::now() + chrono::Duration::hours(ttl_hours as i64));
                std::borrow::Cow::Owned(m)
            } else {
                memory
            }
        } else {
            memory
        };

        // Auto-populate repo/git_ref from engine scope if not already set.
        // Note: git_ref is only filled alongside repo; a memory with repo set
        // but git_ref unset is left as-is.
        let memory = if memory.repo.is_none() {
            if let Some(scope) = self.scope() {
                let mut m = memory.into_owned();
                m.repo = Some(scope.repo.clone());
                m.git_ref = Some(scope.git_ref.clone());
                std::borrow::Cow::Owned(m)
            } else {
                memory
            }
        } else {
            memory
        };
        // From here on, work with a plain reference regardless of Cow variant.
        let memory = memory.as_ref();

        // H3: Step 1 — Embed if the provider is already loaded (don't trigger lazy init).
        // Lifecycle hooks skip embedding for speed; the provider gets initialized on
        // first recall/search, and backfill_embeddings() picks up any gaps.
        // Embedding failures degrade to None — the memory persists without a vector.
        let embedding_result = if self.embeddings_ready() {
            match self.lock_embeddings() {
                Ok(Some(emb)) => {
                    let enriched = self.enrich_memory_text(
                        &memory.content,
                        memory.memory_type,
                        &memory.tags,
                        memory.namespace.as_deref(),
                        Some(&memory.id),
                    );
                    let result = emb.embed(&enriched).ok();
                    // Drop the embeddings lock before any other lock is taken (H3).
                    drop(emb);
                    result
                }
                Ok(None) => None,
                Err(e) => {
                    tracing::warn!("Embeddings lock failed during persist: {e}");
                    None
                }
            }
        } else {
            None
        };

        // 2F: Wrap all SQLite mutations in a single transaction so that the
        // database cannot be left in an inconsistent state if one step fails.
        // The HNSW vector index is NOT in SQLite, so vector insertion happens
        // after commit — if it fails, the memory is still persisted without
        // its embedding, which is recoverable.
        self.storage.begin_transaction()?;

        let result = self.persist_memory_sqlite(memory, &embedding_result);

        match result {
            Ok(()) => {
                self.storage.commit_transaction()?;
            }
            Err(e) => {
                // Best-effort rollback; the original persist error is what we return.
                if let Err(rb_err) = self.storage.rollback_transaction() {
                    tracing::error!("Failed to rollback transaction after persist error: {rb_err}");
                }
                return Err(e);
            }
        }

        // 2. Update BM25 index if already loaded (don't trigger lazy init).
        // The BM25 index rebuilds from all memories on first access anyway.
        if self.bm25_ready() {
            match self.lock_bm25() {
                Ok(mut bm25) => {
                    bm25.add_document(&memory.id, &memory.content);
                }
                Err(e) => tracing::warn!("BM25 lock failed during persist: {e}"),
            }
        }

        // 3. Add memory node to in-memory graph (already persisted to SQLite above).
        // Failures here are non-fatal: the graph can be rebuilt from storage.
        match self.lock_graph() {
            Ok(mut graph) => {
                let node = codemem_core::GraphNode {
                    id: memory.id.clone(),
                    kind: codemem_core::NodeKind::Memory,
                    label: scoring::truncate_content(&memory.content, 80),
                    payload: std::collections::HashMap::new(),
                    centrality: 0.0,
                    memory_id: Some(memory.id.clone()),
                    namespace: memory.namespace.clone(),
                    valid_from: None,
                    valid_to: None,
                };
                if let Err(e) = graph.add_node(node) {
                    tracing::warn!(
                        "Failed to add graph node in-memory for memory {}: {e}",
                        memory.id
                    );
                }
            }
            Err(e) => tracing::warn!("Graph lock failed during persist: {e}"),
        }

        // 3b. Auto-link to memories with shared tags (session co-membership, topic overlap)
        self.auto_link_by_tags(memory);

        // H3: Step 4 — Insert embedding into HNSW vector index if already loaded.
        if let Some(vec) = &embedding_result {
            if self.vector_ready() {
                if let Ok(mut vi) = self.lock_vector() {
                    if let Err(e) = vi.insert(&memory.id, vec) {
                        tracing::warn!("Failed to insert into vector index for {}: {e}", memory.id);
                    }
                }
            }
        }

        // C5: Set dirty flag instead of calling save_index() after each persist.
        // Callers should use flush_if_dirty() to batch save the index.
        if save {
            self.save_index(); // save_index() clears dirty flag
        } else {
            self.dirty.store(true, Ordering::Release);
        }

        Ok(())
    }
179
180    /// Execute all SQLite mutations for a memory persist.
181    ///
182    /// Called within the transaction opened by `persist_memory_inner`.
183    /// Inserts the memory row, graph node, and embedding (if available).
184    fn persist_memory_sqlite(
185        &self,
186        memory: &MemoryNode,
187        embedding: &Option<Vec<f32>>,
188    ) -> Result<(), CodememError> {
189        // 1. Store memory in SQLite
190        self.storage.insert_memory(memory)?;
191
192        // 2. Insert graph node in SQLite
193        let node = codemem_core::GraphNode {
194            id: memory.id.clone(),
195            kind: codemem_core::NodeKind::Memory,
196            label: scoring::truncate_content(&memory.content, 80),
197            payload: std::collections::HashMap::new(),
198            centrality: 0.0,
199            memory_id: Some(memory.id.clone()),
200            namespace: memory.namespace.clone(),
201            valid_from: None,
202            valid_to: None,
203        };
204        if let Err(e) = self.storage.insert_graph_node(&node) {
205            tracing::warn!("Failed to insert graph node for memory {}: {e}", memory.id);
206        }
207
208        // 3. Store embedding in SQLite (vector blob, not HNSW index)
209        if let Some(vec) = embedding {
210            if let Err(e) = self.storage.store_embedding(&memory.id, vec) {
211                tracing::warn!("Failed to store embedding for {}: {e}", memory.id);
212            }
213        }
214
215        Ok(())
216    }
217
218    // ── Store with Links ──────────────────────────────────────────────────
219
220    /// Store a memory with optional explicit link IDs.
221    ///
222    /// Runs the full pipeline: persist → explicit RELATES_TO edges → auto-link
223    /// to code nodes → save index. This consolidates domain logic that was
224    /// previously spread across the MCP transport layer.
225    pub fn store_memory_with_links(
226        &self,
227        memory: &MemoryNode,
228        links: &[String],
229    ) -> Result<(), CodememError> {
230        self.persist_memory(memory)?;
231
232        // Create RELATES_TO edges for explicit links
233        if !links.is_empty() {
234            let now = chrono::Utc::now();
235            let mut graph = self.lock_graph()?;
236            for link_id in links {
237                let edge = Edge {
238                    id: format!("{}-RELATES_TO-{link_id}", memory.id),
239                    src: memory.id.clone(),
240                    dst: link_id.clone(),
241                    relationship: RelationshipType::RelatesTo,
242                    weight: 1.0,
243                    properties: HashMap::new(),
244                    created_at: now,
245                    valid_from: None,
246                    valid_to: None,
247                };
248                if let Err(e) = self.storage.insert_graph_edge(&edge) {
249                    tracing::warn!("Failed to persist link edge to {link_id}: {e}");
250                }
251                if let Err(e) = graph.add_edge(edge) {
252                    tracing::warn!("Failed to add link edge to {link_id}: {e}");
253                }
254            }
255        }
256
257        // Auto-link to code nodes mentioned in content
258        self.auto_link_to_code_nodes(&memory.id, &memory.content, links);
259
260        Ok(())
261    }
262
263    // ── Edge Helpers ─────────────────────────────────────────────────────
264
265    /// Add an edge to both storage and in-memory graph.
266    pub fn add_edge(&self, edge: Edge) -> Result<(), CodememError> {
267        self.storage.insert_graph_edge(&edge)?;
268        let mut graph = self.lock_graph()?;
269        graph.add_edge(edge)?;
270        Ok(())
271    }
272
273    // ── Self-Editing ────────────────────────────────────────────────────
274
275    /// Refine a memory: create a new version with an EVOLVED_INTO edge from old to new.
276    pub fn refine_memory(
277        &self,
278        old_id: &str,
279        content: Option<&str>,
280        tags: Option<Vec<String>>,
281        importance: Option<f64>,
282    ) -> Result<(MemoryNode, String), CodememError> {
283        let old_memory = self
284            .storage
285            .get_memory(old_id)?
286            .ok_or_else(|| CodememError::NotFound(format!("Memory not found: {old_id}")))?;
287
288        let new_content = content.unwrap_or(&old_memory.content);
289        let new_tags = tags.unwrap_or_else(|| old_memory.tags.clone());
290        let new_importance = importance.unwrap_or(old_memory.importance);
291
292        let mut memory = MemoryNode::new(new_content, old_memory.memory_type);
293        let new_id = memory.id.clone();
294        memory.importance = new_importance;
295        memory.confidence = old_memory.confidence;
296        memory.tags = new_tags;
297        memory.metadata = old_memory.metadata.clone();
298        memory.namespace = old_memory.namespace.clone();
299
300        self.persist_memory(&memory)?;
301
302        // Create EVOLVED_INTO edge from old -> new
303        let now = chrono::Utc::now();
304        let edge = Edge {
305            id: format!("{old_id}-EVOLVED_INTO-{new_id}"),
306            src: old_id.to_string(),
307            dst: new_id.clone(),
308            relationship: RelationshipType::EvolvedInto,
309            weight: 1.0,
310            properties: std::collections::HashMap::new(),
311            created_at: now,
312            valid_from: Some(now),
313            valid_to: None,
314        };
315        if let Err(e) = self.add_edge(edge) {
316            tracing::warn!("Failed to add EVOLVED_INTO edge: {e}");
317        }
318
319        Ok((memory, new_id))
320    }
321
322    /// Split a memory into multiple parts, each linked via PART_OF edges.
323    pub fn split_memory(
324        &self,
325        source_id: &str,
326        parts: &[SplitPart],
327    ) -> Result<Vec<String>, CodememError> {
328        let source_memory = self
329            .storage
330            .get_memory(source_id)?
331            .ok_or_else(|| CodememError::NotFound(format!("Memory not found: {source_id}")))?;
332
333        if parts.is_empty() {
334            return Err(CodememError::InvalidInput(
335                "'parts' array must not be empty".to_string(),
336            ));
337        }
338
339        // Validate all parts upfront before persisting anything
340        for part in parts {
341            if part.content.is_empty() {
342                return Err(CodememError::InvalidInput(
343                    "Each part must have a non-empty 'content' field".to_string(),
344                ));
345            }
346        }
347
348        let now = chrono::Utc::now();
349        let mut child_ids: Vec<String> = Vec::new();
350
351        for part in parts {
352            let tags = part
353                .tags
354                .clone()
355                .unwrap_or_else(|| source_memory.tags.clone());
356            let importance = part.importance.unwrap_or(source_memory.importance);
357
358            let mut memory = MemoryNode::new(part.content.clone(), source_memory.memory_type);
359            let child_id = memory.id.clone();
360            memory.importance = importance;
361            memory.confidence = source_memory.confidence;
362            memory.tags = tags;
363            memory.namespace = source_memory.namespace.clone();
364
365            if let Err(e) = self.persist_memory_no_save(&memory) {
366                // Clean up already-created child memories
367                for created_id in &child_ids {
368                    if let Err(del_err) = self.delete_memory(created_id) {
369                        tracing::warn!(
370                            "Failed to clean up child memory {created_id} after split failure: {del_err}"
371                        );
372                    }
373                }
374                return Err(e);
375            }
376
377            // Create PART_OF edge: child -> source
378            let edge = Edge {
379                id: format!("{child_id}-PART_OF-{source_id}"),
380                src: child_id.clone(),
381                dst: source_id.to_string(),
382                relationship: RelationshipType::PartOf,
383                weight: 1.0,
384                properties: std::collections::HashMap::new(),
385                created_at: now,
386                valid_from: Some(now),
387                valid_to: None,
388            };
389            if let Err(e) = self.add_edge(edge) {
390                tracing::warn!("Failed to add PART_OF edge: {e}");
391            }
392
393            child_ids.push(child_id);
394        }
395
396        self.save_index();
397        Ok(child_ids)
398    }
399
400    /// Merge multiple memories into one, linked via SUMMARIZES edges.
401    pub fn merge_memories(
402        &self,
403        source_ids: &[String],
404        content: &str,
405        memory_type: MemoryType,
406        importance: f64,
407        tags: Vec<String>,
408    ) -> Result<String, CodememError> {
409        if source_ids.len() < 2 {
410            return Err(CodememError::InvalidInput(
411                "'source_ids' must contain at least 2 IDs".to_string(),
412            ));
413        }
414
415        // Verify all sources exist
416        let id_refs: Vec<&str> = source_ids.iter().map(|s| s.as_str()).collect();
417        let found = self.storage.get_memories_batch(&id_refs)?;
418        if found.len() != source_ids.len() {
419            let found_ids: std::collections::HashSet<&str> =
420                found.iter().map(|m| m.id.as_str()).collect();
421            let missing: Vec<&str> = id_refs
422                .iter()
423                .filter(|id| !found_ids.contains(**id))
424                .copied()
425                .collect();
426            return Err(CodememError::NotFound(format!(
427                "Source memories not found: {}",
428                missing.join(", ")
429            )));
430        }
431
432        let mut memory = MemoryNode::new(content, memory_type);
433        let merged_id = memory.id.clone();
434        memory.importance = importance;
435        memory.confidence = found.iter().map(|m| m.confidence).sum::<f64>() / found.len() as f64;
436        memory.tags = tags;
437        memory.namespace = found.iter().find_map(|m| m.namespace.clone());
438
439        self.persist_memory_no_save(&memory)?;
440
441        // Create SUMMARIZES edges: merged -> each source
442        let now = chrono::Utc::now();
443        for source_id in source_ids {
444            let edge = Edge {
445                id: format!("{merged_id}-SUMMARIZES-{source_id}"),
446                src: merged_id.clone(),
447                dst: source_id.clone(),
448                relationship: RelationshipType::Summarizes,
449                weight: 1.0,
450                properties: std::collections::HashMap::new(),
451                created_at: now,
452                valid_from: Some(now),
453                valid_to: None,
454            };
455            if let Err(e) = self.add_edge(edge) {
456                tracing::warn!("Failed to add SUMMARIZES edge to {source_id}: {e}");
457            }
458        }
459
460        self.save_index();
461        Ok(merged_id)
462    }
463
464    /// Update a memory's content and/or importance, re-embedding if needed.
465    pub fn update_memory(
466        &self,
467        id: &str,
468        content: &str,
469        importance: Option<f64>,
470    ) -> Result<(), CodememError> {
471        self.storage.update_memory(id, content, importance)?;
472
473        // Update BM25 index
474        self.lock_bm25()?.add_document(id, content);
475
476        // Update graph node label
477        if let Ok(mut graph) = self.lock_graph() {
478            if let Ok(Some(mut node)) = graph.get_node(id) {
479                node.label = scoring::truncate_content(content, 80);
480                if let Err(e) = graph.add_node(node) {
481                    tracing::warn!("Failed to update graph node for {id}: {e}");
482                }
483            }
484        }
485
486        // Re-embed with contextual enrichment
487        // H3: Acquire embeddings lock, embed, drop lock before acquiring vector lock.
488        if let Some(emb_guard) = self.lock_embeddings()? {
489            let (mem_type, tags, namespace) =
490                if let Ok(Some(mem)) = self.storage.get_memory_no_touch(id) {
491                    (mem.memory_type, mem.tags, mem.namespace)
492                } else {
493                    (MemoryType::Context, vec![], None)
494                };
495            let enriched =
496                self.enrich_memory_text(content, mem_type, &tags, namespace.as_deref(), Some(id));
497            let emb_result = emb_guard.embed(&enriched);
498            drop(emb_guard);
499            if let Ok(embedding) = emb_result {
500                if let Err(e) = self.storage.store_embedding(id, &embedding) {
501                    tracing::warn!("Failed to store embedding for {id}: {e}");
502                }
503                let mut vec = self.lock_vector()?;
504                if let Err(e) = vec.remove(id) {
505                    tracing::warn!("Failed to remove old vector for {id}: {e}");
506                }
507                if let Err(e) = vec.insert(id, &embedding) {
508                    tracing::warn!("Failed to insert new vector for {id}: {e}");
509                }
510            }
511        }
512
513        self.save_index();
514        Ok(())
515    }
516
517    /// Update only the importance of a memory.
518    /// Routes through the engine to maintain the transport → engine → storage boundary.
519    pub fn update_importance(&self, id: &str, importance: f64) -> Result<(), CodememError> {
520        self.storage
521            .batch_update_importance(&[(id.to_string(), importance)])?;
522        Ok(())
523    }
524
525    /// Delete a memory from all subsystems.
526    ///
527    /// M1: Uses `delete_memory_cascade` on the storage backend to wrap all
528    /// SQLite deletes (memory + graph nodes/edges + embedding) in a single
529    /// transaction when the backend supports it. In-memory structures
530    /// (vector, graph, BM25) are cleaned up separately with proper lock ordering.
531    pub fn delete_memory(&self, id: &str) -> Result<bool, CodememError> {
532        // Use cascade delete for all storage-side operations in a single transaction.
533        let deleted = self.storage.delete_memory_cascade(id)?;
534        if !deleted {
535            return Ok(false);
536        }
537
538        // Clean up in-memory structures with proper lock ordering:
539        // vector first, then graph, then BM25.
540        let mut vec = self.lock_vector()?;
541        if let Err(e) = vec.remove(id) {
542            tracing::warn!("Failed to remove {id} from vector index: {e}");
543        }
544        drop(vec);
545
546        let mut graph = self.lock_graph()?;
547        if let Err(e) = graph.remove_node(id) {
548            tracing::warn!("Failed to remove {id} from in-memory graph: {e}");
549        }
550        drop(graph);
551
552        self.lock_bm25()?.remove_document(id);
553
554        // Persist vector index to disk
555        self.save_index();
556        Ok(true)
557    }
558
559    /// Opportunistic sweep of expired memories. Rate-limited to once per 60 seconds
560    /// to avoid adding overhead to every recall call.
561    pub(crate) fn sweep_expired_memories(&self) {
562        let now = chrono::Utc::now().timestamp();
563        let last = self.last_expiry_sweep.load(Ordering::Relaxed);
564        if now - last < 60 {
565            return;
566        }
567        // CAS to avoid concurrent sweeps
568        if self
569            .last_expiry_sweep
570            .compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed)
571            .is_err()
572        {
573            return;
574        }
575        match self.storage.delete_expired_memories() {
576            Ok(0) => {}
577            Ok(n) => tracing::debug!("Swept {n} expired memories"),
578            Err(e) => tracing::warn!("Expired memory sweep failed: {e}"),
579        }
580    }
581}