//! Memory persistence and self-editing operations for `CodememEngine`
//! (persist pipeline, edge helpers, refine/split/merge/update/delete).
use crate::scoring;
use crate::CodememEngine;
use crate::SplitPart;
use codemem_core::{
    CodememError, Edge, GraphBackend, MemoryNode, MemoryType, RelationshipType, VectorBackend,
};
use std::sync::atomic::Ordering;

9impl CodememEngine {
10    // ── Persistence ─────────────────────────────────────────────────────
11
12    /// Persist a memory through the full pipeline: storage → BM25 → graph → embedding → vector.
13    pub fn persist_memory(&self, memory: &MemoryNode) -> Result<(), CodememError> {
14        self.persist_memory_inner(memory, true)
15    }
16
17    /// Persist a memory without saving the vector index to disk.
18    /// Use this in batch operations, then call `save_index()` once at the end.
19    pub(crate) fn persist_memory_no_save(&self, memory: &MemoryNode) -> Result<(), CodememError> {
20        self.persist_memory_inner(memory, false)
21    }
22
23    /// Inner persist implementation with optional index save.
24    ///
25    /// H3: Lock ordering is enforced to prevent deadlocks:
26    /// 1. Embeddings lock (acquire, embed, drop)
27    /// 2. BM25 lock
28    /// 3. Graph lock
29    /// 4. Vector lock
30    fn persist_memory_inner(&self, memory: &MemoryNode, save: bool) -> Result<(), CodememError> {
31        // Auto-populate session_id from the engine's active session if not already set
32        let memory = if memory.session_id.is_none() {
33            if let Some(active_sid) = self.active_session_id() {
34                let mut m = memory.clone();
35                m.session_id = Some(active_sid);
36                std::borrow::Cow::Owned(m)
37            } else {
38                std::borrow::Cow::Borrowed(memory)
39            }
40        } else {
41            std::borrow::Cow::Borrowed(memory)
42        };
43        let memory = memory.as_ref();
44
45        // H3: Step 1 — Acquire embeddings lock first, embed, save result, drop lock.
46        // This prevents holding the embeddings lock while acquiring vector/graph locks.
47        // Embedding is computed before the transaction since it doesn't touch SQLite.
48        let embedding_result = match self.lock_embeddings() {
49            Ok(Some(emb)) => {
50                let enriched = self.enrich_memory_text(
51                    &memory.content,
52                    memory.memory_type,
53                    &memory.tags,
54                    memory.namespace.as_deref(),
55                    Some(&memory.id),
56                );
57                let result = emb.embed(&enriched).ok();
58                drop(emb);
59                result
60            }
61            Ok(None) => None,
62            Err(e) => {
63                tracing::warn!("Embeddings lock failed during persist: {e}");
64                None
65            }
66        };
67
68        // 2F: Wrap all SQLite mutations in a single transaction so that the
69        // database cannot be left in an inconsistent state if one step fails.
70        // The HNSW vector index is NOT in SQLite, so vector insertion happens
71        // after commit — if it fails, the memory is still persisted without
72        // its embedding, which is recoverable.
73        self.storage.begin_transaction()?;
74
75        let result = self.persist_memory_sqlite(memory, &embedding_result);
76
77        match result {
78            Ok(()) => {
79                self.storage.commit_transaction()?;
80            }
81            Err(e) => {
82                if let Err(rb_err) = self.storage.rollback_transaction() {
83                    tracing::error!("Failed to rollback transaction after persist error: {rb_err}");
84                }
85                return Err(e);
86            }
87        }
88
89        // 2. Update BM25 index (in-memory, not SQLite)
90        match self.lock_bm25() {
91            Ok(mut bm25) => {
92                bm25.add_document(&memory.id, &memory.content);
93            }
94            Err(e) => tracing::warn!("BM25 lock failed during persist: {e}"),
95        }
96
97        // 3. Add memory node to in-memory graph (already persisted to SQLite above)
98        match self.lock_graph() {
99            Ok(mut graph) => {
100                let node = codemem_core::GraphNode {
101                    id: memory.id.clone(),
102                    kind: codemem_core::NodeKind::Memory,
103                    label: scoring::truncate_content(&memory.content, 80),
104                    payload: std::collections::HashMap::new(),
105                    centrality: 0.0,
106                    memory_id: Some(memory.id.clone()),
107                    namespace: memory.namespace.clone(),
108                };
109                if let Err(e) = graph.add_node(node) {
110                    tracing::warn!(
111                        "Failed to add graph node in-memory for memory {}: {e}",
112                        memory.id
113                    );
114                }
115            }
116            Err(e) => tracing::warn!("Graph lock failed during persist: {e}"),
117        }
118
119        // 3b. Auto-link to memories with shared tags (session co-membership, topic overlap)
120        self.auto_link_by_tags(memory);
121
122        // H3: Step 4 — Insert embedding into HNSW vector index (separate from SQLite).
123        if let Some(vec) = &embedding_result {
124            if let Ok(mut vi) = self.lock_vector() {
125                if let Err(e) = vi.insert(&memory.id, vec) {
126                    tracing::warn!("Failed to insert into vector index for {}: {e}", memory.id);
127                }
128            }
129        }
130
131        // C5: Set dirty flag instead of calling save_index() after each persist.
132        // Callers should use flush_if_dirty() to batch save the index.
133        if save {
134            self.save_index(); // save_index() clears dirty flag
135        } else {
136            self.dirty.store(true, Ordering::Release);
137        }
138
139        Ok(())
140    }
141
142    /// Execute all SQLite mutations for a memory persist.
143    ///
144    /// Called within the transaction opened by `persist_memory_inner`.
145    /// Inserts the memory row, graph node, and embedding (if available).
146    fn persist_memory_sqlite(
147        &self,
148        memory: &MemoryNode,
149        embedding: &Option<Vec<f32>>,
150    ) -> Result<(), CodememError> {
151        // 1. Store memory in SQLite
152        self.storage.insert_memory(memory)?;
153
154        // 2. Insert graph node in SQLite
155        let node = codemem_core::GraphNode {
156            id: memory.id.clone(),
157            kind: codemem_core::NodeKind::Memory,
158            label: scoring::truncate_content(&memory.content, 80),
159            payload: std::collections::HashMap::new(),
160            centrality: 0.0,
161            memory_id: Some(memory.id.clone()),
162            namespace: memory.namespace.clone(),
163        };
164        if let Err(e) = self.storage.insert_graph_node(&node) {
165            tracing::warn!("Failed to insert graph node for memory {}: {e}", memory.id);
166        }
167
168        // 3. Store embedding in SQLite (vector blob, not HNSW index)
169        if let Some(vec) = embedding {
170            if let Err(e) = self.storage.store_embedding(&memory.id, vec) {
171                tracing::warn!("Failed to store embedding for {}: {e}", memory.id);
172            }
173        }
174
175        Ok(())
176    }
177
178    // ── Edge Helpers ─────────────────────────────────────────────────────
179
180    /// Add an edge to both storage and in-memory graph.
181    pub fn add_edge(&self, edge: Edge) -> Result<(), CodememError> {
182        self.storage.insert_graph_edge(&edge)?;
183        let mut graph = self.lock_graph()?;
184        graph.add_edge(edge)?;
185        Ok(())
186    }
187
188    // ── Self-Editing ────────────────────────────────────────────────────
189
190    /// Refine a memory: create a new version with an EVOLVED_INTO edge from old to new.
191    pub fn refine_memory(
192        &self,
193        old_id: &str,
194        content: Option<&str>,
195        tags: Option<Vec<String>>,
196        importance: Option<f64>,
197    ) -> Result<(MemoryNode, String), CodememError> {
198        let old_memory = self
199            .storage
200            .get_memory(old_id)?
201            .ok_or_else(|| CodememError::NotFound(format!("Memory not found: {old_id}")))?;
202
203        let new_content = content.unwrap_or(&old_memory.content);
204        let new_tags = tags.unwrap_or_else(|| old_memory.tags.clone());
205        let new_importance = importance.unwrap_or(old_memory.importance);
206
207        let now = chrono::Utc::now();
208        let new_id = uuid::Uuid::new_v4().to_string();
209        let hash = codemem_storage::Storage::content_hash(new_content);
210
211        let memory = MemoryNode {
212            id: new_id.clone(),
213            content: new_content.to_string(),
214            memory_type: old_memory.memory_type,
215            importance: new_importance,
216            confidence: old_memory.confidence,
217            access_count: 0,
218            content_hash: hash,
219            tags: new_tags,
220            metadata: old_memory.metadata.clone(),
221            namespace: old_memory.namespace.clone(),
222            session_id: None,
223            created_at: now,
224            updated_at: now,
225            last_accessed_at: now,
226        };
227
228        self.persist_memory(&memory)?;
229
230        // Create EVOLVED_INTO edge from old -> new
231        let edge = Edge {
232            id: format!("{old_id}-EVOLVED_INTO-{new_id}"),
233            src: old_id.to_string(),
234            dst: new_id.clone(),
235            relationship: RelationshipType::EvolvedInto,
236            weight: 1.0,
237            properties: std::collections::HashMap::new(),
238            created_at: now,
239            valid_from: Some(now),
240            valid_to: None,
241        };
242        if let Err(e) = self.add_edge(edge) {
243            tracing::warn!("Failed to add EVOLVED_INTO edge: {e}");
244        }
245
246        Ok((memory, new_id))
247    }
248
249    /// Split a memory into multiple parts, each linked via PART_OF edges.
250    pub fn split_memory(
251        &self,
252        source_id: &str,
253        parts: &[SplitPart],
254    ) -> Result<Vec<String>, CodememError> {
255        let source_memory = self
256            .storage
257            .get_memory(source_id)?
258            .ok_or_else(|| CodememError::NotFound(format!("Memory not found: {source_id}")))?;
259
260        if parts.is_empty() {
261            return Err(CodememError::InvalidInput(
262                "'parts' array must not be empty".to_string(),
263            ));
264        }
265
266        // Validate all parts upfront before persisting anything
267        for part in parts {
268            if part.content.is_empty() {
269                return Err(CodememError::InvalidInput(
270                    "Each part must have a non-empty 'content' field".to_string(),
271                ));
272            }
273        }
274
275        let now = chrono::Utc::now();
276        let mut child_ids: Vec<String> = Vec::new();
277
278        for part in parts {
279            let tags = part
280                .tags
281                .clone()
282                .unwrap_or_else(|| source_memory.tags.clone());
283            let importance = part.importance.unwrap_or(source_memory.importance);
284
285            let child_id = uuid::Uuid::new_v4().to_string();
286            let hash = codemem_storage::Storage::content_hash(&part.content);
287
288            let memory = MemoryNode {
289                id: child_id.clone(),
290                content: part.content.clone(),
291                memory_type: source_memory.memory_type,
292                importance,
293                confidence: source_memory.confidence,
294                access_count: 0,
295                content_hash: hash,
296                tags,
297                metadata: std::collections::HashMap::new(),
298                namespace: source_memory.namespace.clone(),
299                session_id: None,
300                created_at: now,
301                updated_at: now,
302                last_accessed_at: now,
303            };
304
305            if let Err(e) = self.persist_memory_no_save(&memory) {
306                // Clean up already-created child memories
307                for created_id in &child_ids {
308                    if let Err(del_err) = self.delete_memory(created_id) {
309                        tracing::warn!(
310                            "Failed to clean up child memory {created_id} after split failure: {del_err}"
311                        );
312                    }
313                }
314                return Err(e);
315            }
316
317            // Create PART_OF edge: child -> source
318            let edge = Edge {
319                id: format!("{child_id}-PART_OF-{source_id}"),
320                src: child_id.clone(),
321                dst: source_id.to_string(),
322                relationship: RelationshipType::PartOf,
323                weight: 1.0,
324                properties: std::collections::HashMap::new(),
325                created_at: now,
326                valid_from: Some(now),
327                valid_to: None,
328            };
329            if let Err(e) = self.add_edge(edge) {
330                tracing::warn!("Failed to add PART_OF edge: {e}");
331            }
332
333            child_ids.push(child_id);
334        }
335
336        self.save_index();
337        Ok(child_ids)
338    }
339
340    /// Merge multiple memories into one, linked via SUMMARIZES edges.
341    pub fn merge_memories(
342        &self,
343        source_ids: &[String],
344        content: &str,
345        memory_type: MemoryType,
346        importance: f64,
347        tags: Vec<String>,
348    ) -> Result<String, CodememError> {
349        if source_ids.len() < 2 {
350            return Err(CodememError::InvalidInput(
351                "'source_ids' must contain at least 2 IDs".to_string(),
352            ));
353        }
354
355        // Verify all sources exist
356        let id_refs: Vec<&str> = source_ids.iter().map(|s| s.as_str()).collect();
357        let found = self.storage.get_memories_batch(&id_refs)?;
358        if found.len() != source_ids.len() {
359            let found_ids: std::collections::HashSet<&str> =
360                found.iter().map(|m| m.id.as_str()).collect();
361            let missing: Vec<&str> = id_refs
362                .iter()
363                .filter(|id| !found_ids.contains(**id))
364                .copied()
365                .collect();
366            return Err(CodememError::NotFound(format!(
367                "Source memories not found: {}",
368                missing.join(", ")
369            )));
370        }
371
372        let now = chrono::Utc::now();
373        let merged_id = uuid::Uuid::new_v4().to_string();
374        let hash = codemem_storage::Storage::content_hash(content);
375
376        let memory = MemoryNode {
377            id: merged_id.clone(),
378            content: content.to_string(),
379            memory_type,
380            importance,
381            confidence: found.iter().map(|m| m.confidence).sum::<f64>() / found.len() as f64,
382            access_count: 0,
383            content_hash: hash,
384            tags,
385            metadata: std::collections::HashMap::new(),
386            namespace: found.iter().find_map(|m| m.namespace.clone()),
387            session_id: None,
388            created_at: now,
389            updated_at: now,
390            last_accessed_at: now,
391        };
392
393        self.persist_memory_no_save(&memory)?;
394
395        // Create SUMMARIZES edges: merged -> each source
396        for source_id in source_ids {
397            let edge = Edge {
398                id: format!("{merged_id}-SUMMARIZES-{source_id}"),
399                src: merged_id.clone(),
400                dst: source_id.clone(),
401                relationship: RelationshipType::Summarizes,
402                weight: 1.0,
403                properties: std::collections::HashMap::new(),
404                created_at: now,
405                valid_from: Some(now),
406                valid_to: None,
407            };
408            if let Err(e) = self.add_edge(edge) {
409                tracing::warn!("Failed to add SUMMARIZES edge to {source_id}: {e}");
410            }
411        }
412
413        self.save_index();
414        Ok(merged_id)
415    }
416
417    /// Update a memory's content and/or importance, re-embedding if needed.
418    pub fn update_memory(
419        &self,
420        id: &str,
421        content: &str,
422        importance: Option<f64>,
423    ) -> Result<(), CodememError> {
424        self.storage.update_memory(id, content, importance)?;
425
426        // Update BM25 index
427        self.lock_bm25()?.add_document(id, content);
428
429        // Update graph node label
430        if let Ok(mut graph) = self.lock_graph() {
431            if let Ok(Some(mut node)) = graph.get_node(id) {
432                node.label = scoring::truncate_content(content, 80);
433                if let Err(e) = graph.add_node(node) {
434                    tracing::warn!("Failed to update graph node for {id}: {e}");
435                }
436            }
437        }
438
439        // Re-embed with contextual enrichment
440        // H3: Acquire embeddings lock, embed, drop lock before acquiring vector lock.
441        if let Some(emb_guard) = self.lock_embeddings()? {
442            let (mem_type, tags, namespace) =
443                if let Ok(Some(mem)) = self.storage.get_memory_no_touch(id) {
444                    (mem.memory_type, mem.tags, mem.namespace)
445                } else {
446                    (MemoryType::Context, vec![], None)
447                };
448            let enriched =
449                self.enrich_memory_text(content, mem_type, &tags, namespace.as_deref(), Some(id));
450            let emb_result = emb_guard.embed(&enriched);
451            drop(emb_guard);
452            if let Ok(embedding) = emb_result {
453                if let Err(e) = self.storage.store_embedding(id, &embedding) {
454                    tracing::warn!("Failed to store embedding for {id}: {e}");
455                }
456                let mut vec = self.lock_vector()?;
457                if let Err(e) = vec.remove(id) {
458                    tracing::warn!("Failed to remove old vector for {id}: {e}");
459                }
460                if let Err(e) = vec.insert(id, &embedding) {
461                    tracing::warn!("Failed to insert new vector for {id}: {e}");
462                }
463            }
464        }
465
466        self.save_index();
467        Ok(())
468    }
469
470    /// Update only the importance of a memory.
471    /// Routes through the engine to maintain the transport → engine → storage boundary.
472    pub fn update_importance(&self, id: &str, importance: f64) -> Result<(), CodememError> {
473        self.storage
474            .batch_update_importance(&[(id.to_string(), importance)])?;
475        Ok(())
476    }
477
478    /// Delete a memory from all subsystems.
479    ///
480    /// M1: Uses `delete_memory_cascade` on the storage backend to wrap all
481    /// SQLite deletes (memory + graph nodes/edges + embedding) in a single
482    /// transaction when the backend supports it. In-memory structures
483    /// (vector, graph, BM25) are cleaned up separately with proper lock ordering.
484    pub fn delete_memory(&self, id: &str) -> Result<bool, CodememError> {
485        // Use cascade delete for all storage-side operations in a single transaction.
486        let deleted = self.storage.delete_memory_cascade(id)?;
487        if !deleted {
488            return Ok(false);
489        }
490
491        // Clean up in-memory structures with proper lock ordering:
492        // vector first, then graph, then BM25.
493        let mut vec = self.lock_vector()?;
494        if let Err(e) = vec.remove(id) {
495            tracing::warn!("Failed to remove {id} from vector index: {e}");
496        }
497        drop(vec);
498
499        let mut graph = self.lock_graph()?;
500        if let Err(e) = graph.remove_node(id) {
501            tracing::warn!("Failed to remove {id} from in-memory graph: {e}");
502        }
503        drop(graph);
504
505        self.lock_bm25()?.remove_document(id);
506
507        // Persist vector index to disk
508        self.save_index();
509        Ok(true)
510    }
511}