Skip to main content

oxios_memory/memory/manager/
store.rs

1//! Memory store operations: save/load, index management, search.
2//!
3//! Integrates HNSW index (usearch) for fast approximate nearest neighbor search
4//! alongside the abstract storage backend for persistence.
5
6use std::collections::HashMap;
7
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10
11use crate::memory::auto_protect::AutoProtector;
12use crate::memory::embedding::EmbeddingVector;
13use crate::memory::storage::MemoryStorageExt;
14#[cfg(feature = "sqlite-memory")]
15use crate::memory::types::MemoryTier;
16use crate::memory::types::{content_hash, dedup_by_id, extract_keywords, MemoryEntry, MemoryType};
17
18use super::MemoryManager;
19
20// ---------------------------------------------------------------------------
21// VectorIndexSnapshot
22// ---------------------------------------------------------------------------
23
24/// Snapshot of the vector index for persistence.
25#[derive(Debug, Clone, Serialize, Deserialize)]
26struct VectorIndexSnapshot {
27    /// Snapshot creation timestamp.
28    created_at: DateTime<Utc>,
29    /// Number of entries in the snapshot.
30    entry_count: usize,
31    /// Map of entry ID to embedding vector.
32    entries: HashMap<String, EmbeddingVector>,
33}
34
35// ---------------------------------------------------------------------------
36// Store & search operations
37// ---------------------------------------------------------------------------
38
39impl MemoryManager {
40    /// Returns total entries across all memory types (from disk).
41    pub async fn total_entries(&self) -> usize {
42        let mut total = 0;
43        for mt in MemoryType::all() {
44            if let Ok(entries) = self.list(*mt, 1_000_000).await {
45                total += entries.len();
46            }
47        }
48        total
49    }
50
51    /// Rebuild the vector index from all stored memories.
52    ///
53    /// Call once at startup to populate the in-memory index from
54    /// persisted memory entries.
55    pub async fn rebuild_index(&self) -> anyhow::Result<()> {
56        // Collect all entries outside the lock
57        let mut entries_to_index: Vec<(String, EmbeddingVector)> = Vec::new();
58
59        for mt in MemoryType::all() {
60            if let Ok(names) = self.storage.list_category(mt.category()).await {
61                for name in names {
62                    if let Ok(Some(entry)) = self
63                        .storage
64                        .load_json::<MemoryEntry>(mt.category(), &name)
65                        .await
66                    {
67                        let vector = self.embedding.embed(&entry.content).await?;
68                        entries_to_index.push((entry.id.clone(), vector));
69                    }
70                }
71            }
72        }
73
74        // Now acquire the lock only for the write
75        {
76            let mut index = self.vector_index.write();
77            index.clear();
78            for (id, vector) in entries_to_index {
79                index.insert(id, vector);
80            }
81        }
82
83        tracing::info!(
84            entries = self.vector_index.read().len(),
85            "Memory vector index rebuilt"
86        );
87        Ok(())
88    }
89
90    /// Save the current vector index to disk as a snapshot.
91    pub async fn save_index_snapshot(&self) -> anyhow::Result<()> {
92        let snapshot = {
93            let index = self.vector_index.read();
94            VectorIndexSnapshot {
95                created_at: chrono::Utc::now(),
96                entry_count: index.len(),
97                entries: index.clone(),
98            }
99        };
100
101        self.storage
102            .save_json("memory", "vector_index_snapshot", &snapshot)
103            .await?;
104
105        self.git_commit("memory/vector_index_snapshot.json", "memory: snapshot save")
106            .await;
107
108        tracing::debug!(
109            entries = snapshot.entry_count,
110            "Vector index snapshot saved"
111        );
112        Ok(())
113    }
114
115    /// Load a previously saved vector index snapshot from disk.
116    pub async fn load_index_snapshot(&self) -> anyhow::Result<usize> {
117        let snapshot: Option<VectorIndexSnapshot> = self
118            .storage
119            .load_json("memory", "vector_index_snapshot")
120            .await?;
121
122        match snapshot {
123            Some(snap) => {
124                let count = snap.entry_count;
125                let mut index = self.vector_index.write();
126                *index = snap.entries;
127                tracing::info!(entries = count, "Vector index snapshot loaded");
128                Ok(count)
129            }
130            None => {
131                tracing::debug!("No vector index snapshot found");
132                Ok(0)
133            }
134        }
135    }
136
137    /// Store a memory entry. Returns the entry ID.
138    ///
139    /// When SQLite backend is enabled, delegates to `SqliteMemoryStore`.
140    /// Otherwise computes and stores the entry's text vector in the in-memory
141    /// index for future semantic search.
142    pub async fn remember(&self, entry: MemoryEntry) -> anyhow::Result<String> {
143        // ── SQLite fast path (RFC-012) ──
144        #[cfg(feature = "sqlite-memory")]
145        if let Some(ref sqlite) = self.sqlite_store {
146            return sqlite.remember(&entry).await;
147        }
148
149        // ── Legacy JSON path ──
150        let id = entry.id.clone();
151        let vector = self.embedding.embed(&entry.content).await?;
152        let category = entry.memory_type.category();
153        self.storage.save_json(category, &id, &entry).await?;
154
155        self.git_commit(
156            &format!("{category}/{id}.json"),
157            &format!("memory: store {id}"),
158        )
159        .await;
160
161        // Update vector index
162        {
163            let mut index = self.vector_index.write();
164            index.insert(id.clone(), vector.clone());
165        }
166
167        // Update HNSW index if attached
168        if let Some(f32_vec) = vector.to_f32_dense() {
169            let hnsw = self.hnsw_index.read();
170            if let Some(ref hnsw) = *hnsw {
171                if let Err(e) = hnsw.add_entry(&id, &f32_vec) {
172                    tracing::warn!(id = %id, error = %e, "Failed to update HNSW index on remember");
173                }
174            }
175        }
176
177        tracing::debug!(id = %id, ty = entry.memory_type.label(), "Memory stored");
178        Ok(id)
179    }
180
181    /// Retrieve a single memory by ID.
182    ///
183    /// Records access for auto-protection tracking.
184    pub async fn get(
185        &self,
186        id: &str,
187        memory_type: MemoryType,
188    ) -> anyhow::Result<Option<MemoryEntry>> {
189        #[cfg(feature = "sqlite-memory")]
190        if let Some(ref sqlite) = self.sqlite_store {
191            return sqlite.get(id, memory_type);
192        }
193        let result: Option<MemoryEntry> =
194            self.storage.load_json(memory_type.category(), id).await?;
195        if let Some(mut entry) = result {
196            AutoProtector::record_access(&mut entry, "");
197            Ok(Some(entry))
198        } else {
199            Ok(None)
200        }
201    }
202
203    /// Delete a memory entry.
204    pub async fn forget(&self, id: &str, memory_type: MemoryType) -> anyhow::Result<bool> {
205        #[cfg(feature = "sqlite-memory")]
206        if let Some(ref sqlite) = self.sqlite_store {
207            return sqlite.forget(id, memory_type);
208        }
209        let result = self.storage.delete_file(memory_type.category(), id).await?;
210
211        // Remove from HNSW index if attached
212        {
213            let hnsw = self.hnsw_index.read();
214            if let Some(ref hnsw) = *hnsw {
215                if let Err(e) = hnsw.remove_entry(id) {
216                    tracing::warn!(id = %id, error = %e, "Failed to remove from HNSW index on forget");
217                }
218            }
219        }
220
221        Ok(result)
222    }
223
224    /// List memories of a given type, most recent first.
225    pub async fn list(
226        &self,
227        memory_type: MemoryType,
228        limit: usize,
229    ) -> anyhow::Result<Vec<MemoryEntry>> {
230        #[cfg(feature = "sqlite-memory")]
231        if let Some(ref sqlite) = self.sqlite_store {
232            return sqlite.list(memory_type, limit);
233        }
234        let category = memory_type.category();
235        let names = self.storage.list_category(category).await?;
236        let mut entries = Vec::new();
237        for name in names.into_iter().take(limit.saturating_mul(2)) {
238            if let Ok(Some(entry)) = self.storage.load_json::<MemoryEntry>(category, &name).await {
239                entries.push(entry);
240            }
241        }
242        // Sort by created_at descending (most recent first)
243        entries.sort_by_key(|b| std::cmp::Reverse(b.created_at));
244        entries.truncate(limit);
245        Ok(entries)
246    }
247
248    /// Search memories by semantic similarity (vector search).
249    ///
250    /// Falls back to keyword search when the vector index is empty or
251    /// yields no results above the similarity threshold.
252    pub async fn search(
253        &self,
254        query: &str,
255        memory_type: Option<MemoryType>,
256        limit: usize,
257    ) -> anyhow::Result<Vec<MemoryEntry>> {
258        #[cfg(feature = "sqlite-memory")]
259        if let Some(ref sqlite) = self.sqlite_store {
260            return sqlite.search(query, memory_type, limit).await;
261        }
262        let query_vector = self.embedding.embed(query).await?;
263
264        // Scope the read lock: compute scores, then drop before any await.
265        let scored: Vec<(String, f64)> = {
266            let index = self.vector_index.read();
267            let mut scored: Vec<(String, f64)> = index
268                .iter()
269                .map(|(id, vector)| {
270                    let score = query_vector.cosine_similarity(vector);
271                    (id.clone(), score)
272                })
273                .filter(|(_, score)| *score > 0.1)
274                .collect();
275            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
276            scored.truncate(limit);
277            scored
278        }; // lock dropped here, before any .await
279
280        // If index was empty, scored will be empty — fall back immediately
281        if scored.is_empty() {
282            return self.keyword_search(query, memory_type, limit).await;
283        }
284
285        // Determine which memory types to search
286        let types: &[MemoryType] = match memory_type {
287            Some(ref t) => std::slice::from_ref(t),
288            None => MemoryType::all(),
289        };
290
291        // Load entries from storage (no lock held)
292        let mut results = Vec::new();
293        for (id, score) in scored {
294            for mt in types {
295                if let Ok(Some(mut entry)) = self
296                    .storage
297                    .load_json::<MemoryEntry>(mt.category(), &id)
298                    .await
299                {
300                    AutoProtector::record_access(&mut entry, "");
301                    tracing::debug!(id = %id, score, "Vector search hit");
302                    results.push(entry);
303                    break;
304                }
305            }
306        }
307
308        // Fall back to keyword search if no results
309        if results.is_empty() {
310            return self.keyword_search(query, memory_type, limit).await;
311        }
312
313        Ok(results)
314    }
315
316    /// Keyword-based search (original algorithm, used as fallback).
317    pub(crate) async fn keyword_search(
318        &self,
319        query: &str,
320        memory_type: Option<MemoryType>,
321        limit: usize,
322    ) -> anyhow::Result<Vec<MemoryEntry>> {
323        let keywords = extract_keywords(query);
324        let types = match memory_type {
325            Some(t) => vec![t],
326            None => MemoryType::all().to_vec(),
327        };
328
329        let mut results = Vec::new();
330        for ty in &types {
331            let entries = self.list(*ty, limit * 2).await?;
332            for entry in entries {
333                let matches = keywords.iter().any(|k| {
334                    let k_lower = k.to_lowercase();
335                    entry.content.to_lowercase().contains(&k_lower)
336                        || entry
337                            .tags
338                            .iter()
339                            .any(|t| t.to_lowercase().contains(&k_lower))
340                });
341                if matches {
342                    results.push(entry);
343                }
344            }
345        }
346
347        results.sort_by(|a, b| {
348            b.importance
349                .partial_cmp(&a.importance)
350                .unwrap_or(std::cmp::Ordering::Equal)
351        });
352        results.truncate(limit);
353        Ok(results)
354    }
355
356    /// Recall relevant memories for a new session.
357    ///
358    /// Combines recent conversation summaries, session summaries,
359    /// and keyword-matched facts/episodes.
360    pub async fn recall(&self, query: &str) -> anyhow::Result<Vec<MemoryEntry>> {
361        #[cfg(feature = "sqlite-memory")]
362        if let Some(ref sqlite) = self.sqlite_store {
363            return sqlite.recall(query, self.max_recall).await;
364        }
365        let limit = self.max_recall;
366
367        // 1. Recent conversation summaries (always include)
368        let recent = self
369            .list(MemoryType::Conversation, 3)
370            .await
371            .unwrap_or_default();
372
373        // 2. Recent session summaries
374        let sessions = self.list(MemoryType::Session, 2).await.unwrap_or_default();
375
376        // 3. Keyword-matched facts and episodes
377        let relevant = self.search(query, None, limit).await.unwrap_or_default();
378
379        // 4. Combine and deduplicate
380        let mut combined = recent;
381        combined.extend(sessions);
382        combined.extend(relevant);
383        dedup_by_id(&mut combined);
384        combined.truncate(limit);
385        Ok(combined)
386    }
387
388    /// Blend recalled memories into the system prompt.
389    pub fn blend_into_prompt(&self, memories: &[MemoryEntry], system_prompt: &str) -> String {
390        #[cfg(feature = "sqlite-memory")]
391        if let Some(ref sqlite) = self.sqlite_store {
392            return sqlite.blend_into_prompt(memories, system_prompt);
393        }
394
395        if memories.is_empty() {
396            return system_prompt.to_string();
397        }
398
399        let memory_block = memories
400            .iter()
401            .map(|m| format!("- [{}] {}", m.memory_type.label(), m.content))
402            .collect::<Vec<_>>()
403            .join("\n");
404
405        format!("{system_prompt}\n\n## Relevant Memory\n\n{memory_block}")
406    }
407
408    /// Recall with Flash Attention re-ranking (Phase 6).
409    #[cfg(feature = "sqlite-memory")]
410    pub async fn recall_with_rerank(&self, query: &str) -> anyhow::Result<Vec<MemoryEntry>> {
411        if let Some(ref sqlite) = self.sqlite_store {
412            return sqlite.recall_with_rerank(query, self.max_recall).await;
413        }
414        // Fallback to standard recall
415        self.recall(query).await
416    }
417
418    /// Check if a memory entry with identical content already exists.
419    ///
420    /// Uses a fast hash comparison against the in-memory vector index.
421    pub async fn is_duplicate(&self, content: &str) -> bool {
422        let hash = content_hash(content);
423
424        // Check semantic similarity via vector index first (fast)
425        let query_vector = match self.embedding.embed(content).await {
426            Ok(v) => v,
427            Err(_) => return false,
428        };
429        let similar = {
430            let index = self.vector_index.read();
431            index
432                .iter()
433                .any(|(_, vector)| query_vector.cosine_similarity(vector) > 0.95)
434        };
435        if similar {
436            return true;
437        }
438
439        // Then check exact content hash across all types
440        for mt in MemoryType::all() {
441            if let Ok(entries) = self.list(*mt, 1000).await {
442                for entry in entries {
443                    if content_hash(&entry.content) == hash {
444                        return true;
445                    }
446                }
447            }
448        }
449        false
450    }
451
452    /// Store a memory entry only if no duplicate content exists.
453    ///
454    /// Returns the entry ID if stored, or `None` if duplicate.
455    pub async fn remember_unique(&self, entry: MemoryEntry) -> anyhow::Result<Option<String>> {
456        #[cfg(feature = "sqlite-memory")]
457        if let Some(ref sqlite) = self.sqlite_store {
458            return sqlite.remember_unique(&entry).await;
459        }
460        if self.is_duplicate(&entry.content).await {
461            tracing::debug!(id = %entry.id, "Skipping duplicate memory");
462            return Ok(None);
463        }
464        let id = self.remember(entry).await?;
465        Ok(Some(id))
466    }
467
468    /// Recall with proactive enhancement.
469    ///
470    /// Extends the standard `recall()` with proactive memory injection
471    /// based on `RecallTiming` triggers.
472    pub async fn recall_with_proactive(
473        &self,
474        query: &str,
475        recall_timing: &mut Option<crate::memory::proactive::RecallTiming>,
476    ) -> anyhow::Result<Vec<MemoryEntry>> {
477        // Step 1: Standard recall
478        let mut combined = self.recall(query).await?;
479
480        // Step 2: Proactive enhancement based on timing triggers
481        let should_recall = recall_timing
482            .as_mut()
483            .map(|t| t.should_recall(query))
484            .unwrap_or(true);
485
486        if should_recall && combined.len() < self.max_recall {
487            #[cfg(feature = "sqlite-memory")]
488            if self.sqlite_store.is_some() {
489                let remaining = self.max_recall - combined.len();
490                let warm = self.list_by_tier(MemoryTier::Warm, remaining).await?;
491                let mut seen_ids: std::collections::HashSet<String> =
492                    combined.iter().map(|e| e.id.clone()).collect();
493                for entry in warm {
494                    if seen_ids.insert(entry.id.clone()) && combined.len() < self.max_recall {
495                        combined.push(entry);
496                    }
497                }
498            }
499
500            #[cfg(not(feature = "sqlite-memory"))]
501            {
502                let proactive = crate::memory::proactive::ProactiveRecall::new(5, 0.6);
503                let extra = proactive.recall(self, query, &combined).await?;
504                combined.extend(extra);
505                dedup_by_id(&mut combined);
506                combined.truncate(self.max_recall);
507            }
508
509            #[cfg(feature = "sqlite-memory")]
510            if self.sqlite_store.is_none() {
511                let proactive = crate::memory::proactive::ProactiveRecall::new(5, 0.6);
512                let extra = proactive.recall(self, query, &combined).await?;
513                combined.extend(extra);
514                dedup_by_id(&mut combined);
515                combined.truncate(self.max_recall);
516            }
517        }
518
519        Ok(combined)
520    }
521}