Skip to main content

oxios_memory/memory/manager/
store.rs

1//! Memory store operations: save/load, index management, search.
2//!
3//! Integrates HNSW index (usearch) for fast approximate nearest neighbor search
4//! alongside the abstract storage backend for persistence.
5
6use std::collections::HashMap;
7
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10
11use crate::memory::auto_protect::AutoProtector;
12use crate::memory::embedding::EmbeddingVector;
13use crate::memory::storage::MemoryStorageExt;
14#[cfg(feature = "sqlite-memory")]
15use crate::memory::types::MemoryTier;
16use crate::memory::types::{MemoryEntry, MemoryType, content_hash, dedup_by_id, extract_keywords};
17
18use super::MemoryManager;
19
20// ---------------------------------------------------------------------------
21// VectorIndexSnapshot
22// ---------------------------------------------------------------------------
23
/// Serializable snapshot of the in-memory vector index.
///
/// Built by `MemoryManager::save_index_snapshot` from the live index and
/// read back by `MemoryManager::load_index_snapshot`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct VectorIndexSnapshot {
    /// Snapshot creation timestamp (UTC).
    created_at: DateTime<Utc>,
    /// Number of entries the index held when the snapshot was taken.
    entry_count: usize,
    /// Map of entry ID to embedding vector.
    entries: HashMap<String, EmbeddingVector>,
}
34
35// ---------------------------------------------------------------------------
36// Store & search operations
37// ---------------------------------------------------------------------------
38
39impl MemoryManager {
40    /// Returns total entries across all memory types (from disk).
41    pub async fn total_entries(&self) -> usize {
42        let mut total = 0;
43        for mt in MemoryType::all() {
44            if let Ok(entries) = self.list(*mt, 1_000_000).await {
45                total += entries.len();
46            }
47        }
48        total
49    }
50
51    /// Rebuild the vector index from all stored memories.
52    ///
53    /// Call once at startup to populate the in-memory index from
54    /// persisted memory entries.
55    pub async fn rebuild_index(&self) -> anyhow::Result<()> {
56        // Collect all entries outside the lock
57        let mut entries_to_index: Vec<(String, EmbeddingVector)> = Vec::new();
58
59        for mt in MemoryType::all() {
60            if let Ok(names) = self.storage.list_category(mt.category()).await {
61                for name in names {
62                    if let Ok(Some(entry)) = self
63                        .storage
64                        .load_json::<MemoryEntry>(mt.category(), &name)
65                        .await
66                    {
67                        let vector = self.embedding.embed(&entry.content).await?;
68                        entries_to_index.push((entry.id.clone(), vector));
69                    }
70                }
71            }
72        }
73
74        // Now acquire the lock only for the write
75        {
76            let mut index = self.vector_index.write();
77            index.clear();
78            for (id, vector) in entries_to_index {
79                index.insert(id, vector);
80            }
81        }
82
83        tracing::info!(
84            entries = self.vector_index.read().len(),
85            "Memory vector index rebuilt"
86        );
87        Ok(())
88    }
89
90    /// Save the current vector index to disk as a snapshot.
91    pub async fn save_index_snapshot(&self) -> anyhow::Result<()> {
92        let snapshot = {
93            let index = self.vector_index.read();
94            VectorIndexSnapshot {
95                created_at: chrono::Utc::now(),
96                entry_count: index.len(),
97                entries: index.clone(),
98            }
99        };
100
101        self.storage
102            .save_json("memory", "vector_index_snapshot", &snapshot)
103            .await?;
104
105        self.git_commit("memory/vector_index_snapshot.json", "memory: snapshot save")
106            .await;
107
108        tracing::debug!(
109            entries = snapshot.entry_count,
110            "Vector index snapshot saved"
111        );
112        Ok(())
113    }
114
115    /// Load a previously saved vector index snapshot from disk.
116    pub async fn load_index_snapshot(&self) -> anyhow::Result<usize> {
117        let snapshot: Option<VectorIndexSnapshot> = self
118            .storage
119            .load_json("memory", "vector_index_snapshot")
120            .await?;
121
122        match snapshot {
123            Some(snap) => {
124                let count = snap.entry_count;
125                let mut index = self.vector_index.write();
126                *index = snap.entries;
127                tracing::info!(entries = count, "Vector index snapshot loaded");
128                Ok(count)
129            }
130            None => {
131                tracing::debug!("No vector index snapshot found");
132                Ok(0)
133            }
134        }
135    }
136
137    /// Store a memory entry. Returns the entry ID.
138    ///
139    /// When SQLite backend is enabled, delegates to `SqliteMemoryStore`.
140    /// Otherwise computes and stores the entry's text vector in the in-memory
141    /// index for future semantic search.
142    pub async fn remember(&self, entry: MemoryEntry) -> anyhow::Result<String> {
143        // ── SQLite fast path (RFC-012) ──
144        #[cfg(feature = "sqlite-memory")]
145        if let Some(ref sqlite) = self.sqlite_store {
146            return sqlite.remember(&entry).await;
147        }
148
149        // ── Legacy JSON path ──
150        let id = entry.id.clone();
151        let vector = self.embedding.embed(&entry.content).await?;
152        let category = entry.memory_type.category();
153        self.storage.save_json(category, &id, &entry).await?;
154
155        self.git_commit(
156            &format!("{category}/{id}.json"),
157            &format!("memory: store {id}"),
158        )
159        .await;
160
161        // Update vector index
162        {
163            let mut index = self.vector_index.write();
164            index.insert(id.clone(), vector.clone());
165        }
166
167        // Update HNSW index if attached
168        if let Some(f32_vec) = vector.to_f32_dense() {
169            let hnsw = self.hnsw_index.read();
170            if let Some(ref hnsw) = *hnsw
171                && let Err(e) = hnsw.add_entry(&id, &f32_vec)
172            {
173                tracing::warn!(id = %id, error = %e, "Failed to update HNSW index on remember");
174            }
175        }
176
177        tracing::debug!(id = %id, ty = entry.memory_type.label(), "Memory stored");
178        Ok(id)
179    }
180
181    /// Retrieve a single memory by ID.
182    ///
183    /// This is a pure read — it does NOT mutate access counters. Bumping
184    /// `access_count`/`accessed_at`/`seen_in_sessions` here polluted entries
185    /// returned to callers (notably `get_by_id`, which scans all 9 types) with
186    /// access-tracking side-effects that were never persisted anyway. Access
187    /// recording belongs to explicit recall/touch paths, not to reads.
188    pub async fn get(
189        &self,
190        id: &str,
191        memory_type: MemoryType,
192    ) -> anyhow::Result<Option<MemoryEntry>> {
193        #[cfg(feature = "sqlite-memory")]
194        if let Some(ref sqlite) = self.sqlite_store {
195            return sqlite.get(id, memory_type);
196        }
197        self.storage.load_json(memory_type.category(), id).await
198    }
199
200    /// Delete a memory entry.
201    pub async fn forget(&self, id: &str, memory_type: MemoryType) -> anyhow::Result<bool> {
202        #[cfg(feature = "sqlite-memory")]
203        if let Some(ref sqlite) = self.sqlite_store {
204            return sqlite.forget(id, memory_type);
205        }
206        let result = self.storage.delete_file(memory_type.category(), id).await?;
207
208        // Remove from HNSW index if attached
209        {
210            let hnsw = self.hnsw_index.read();
211            if let Some(ref hnsw) = *hnsw
212                && let Err(e) = hnsw.remove_entry(id)
213            {
214                tracing::warn!(id = %id, error = %e, "Failed to remove from HNSW index on forget");
215            }
216        }
217
218        Ok(result)
219    }
220
221    /// List memories of a given type, most recent first.
222    pub async fn list(
223        &self,
224        memory_type: MemoryType,
225        limit: usize,
226    ) -> anyhow::Result<Vec<MemoryEntry>> {
227        #[cfg(feature = "sqlite-memory")]
228        if let Some(ref sqlite) = self.sqlite_store {
229            return sqlite.list(memory_type, limit);
230        }
231        let category = memory_type.category();
232        let names = self.storage.list_category(category).await?;
233        let mut entries = Vec::new();
234        for name in names.into_iter().take(limit.saturating_mul(2)) {
235            if let Ok(Some(entry)) = self.storage.load_json::<MemoryEntry>(category, &name).await {
236                entries.push(entry);
237            }
238        }
239        // Sort by created_at descending (most recent first)
240        entries.sort_by_key(|b| std::cmp::Reverse(b.created_at));
241        entries.truncate(limit);
242        Ok(entries)
243    }
244
245    /// Search memories by semantic similarity (vector search).
246    ///
247    /// Falls back to keyword search when the vector index is empty or
248    /// yields no results above the similarity threshold.
249    pub async fn search(
250        &self,
251        query: &str,
252        memory_type: Option<MemoryType>,
253        limit: usize,
254    ) -> anyhow::Result<Vec<MemoryEntry>> {
255        #[cfg(feature = "sqlite-memory")]
256        if let Some(ref sqlite) = self.sqlite_store {
257            return sqlite.search(query, memory_type, limit).await;
258        }
259        let query_vector = self.embedding.embed(query).await?;
260
261        // Scope the read lock: compute scores, then drop before any await.
262        let scored: Vec<(String, f64)> = {
263            let index = self.vector_index.read();
264            let mut scored: Vec<(String, f64)> = index
265                .iter()
266                .map(|(id, vector)| {
267                    let score = query_vector.cosine_similarity(vector);
268                    (id.clone(), score)
269                })
270                .filter(|(_, score)| *score > 0.1)
271                .collect();
272            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
273            scored.truncate(limit);
274            scored
275        }; // lock dropped here, before any .await
276
277        // If index was empty, scored will be empty — fall back immediately
278        if scored.is_empty() {
279            return self.keyword_search(query, memory_type, limit).await;
280        }
281
282        // Determine which memory types to search
283        let types: &[MemoryType] = match memory_type {
284            Some(ref t) => std::slice::from_ref(t),
285            None => MemoryType::all(),
286        };
287
288        // Load entries from storage (no lock held)
289        let mut results = Vec::new();
290        for (id, score) in scored {
291            for mt in types {
292                if let Ok(Some(mut entry)) = self
293                    .storage
294                    .load_json::<MemoryEntry>(mt.category(), &id)
295                    .await
296                {
297                    AutoProtector::record_access(&mut entry, "");
298                    tracing::debug!(id = %id, score, "Vector search hit");
299                    results.push(entry);
300                    break;
301                }
302            }
303        }
304
305        // Fall back to keyword search if no results
306        if results.is_empty() {
307            return self.keyword_search(query, memory_type, limit).await;
308        }
309
310        Ok(results)
311    }
312
313    /// Keyword-based search (original algorithm, used as fallback).
314    pub(crate) async fn keyword_search(
315        &self,
316        query: &str,
317        memory_type: Option<MemoryType>,
318        limit: usize,
319    ) -> anyhow::Result<Vec<MemoryEntry>> {
320        let keywords = extract_keywords(query);
321        let types = match memory_type {
322            Some(t) => vec![t],
323            None => MemoryType::all().to_vec(),
324        };
325
326        let mut results = Vec::new();
327        for ty in &types {
328            let entries = self.list(*ty, limit * 2).await?;
329            for entry in entries {
330                let matches = keywords.iter().any(|k| {
331                    let k_lower = k.to_lowercase();
332                    entry.content.to_lowercase().contains(&k_lower)
333                        || entry
334                            .tags
335                            .iter()
336                            .any(|t| t.to_lowercase().contains(&k_lower))
337                });
338                if matches {
339                    results.push(entry);
340                }
341            }
342        }
343
344        results.sort_by(|a, b| {
345            b.importance
346                .partial_cmp(&a.importance)
347                .unwrap_or(std::cmp::Ordering::Equal)
348        });
349        results.truncate(limit);
350        Ok(results)
351    }
352
353    /// Recall relevant memories for a new session.
354    ///
355    /// Combines recent conversation summaries, session summaries,
356    /// and keyword-matched facts/episodes.
357    pub async fn recall(&self, query: &str) -> anyhow::Result<Vec<MemoryEntry>> {
358        #[cfg(feature = "sqlite-memory")]
359        if let Some(ref sqlite) = self.sqlite_store {
360            return sqlite.recall(query, self.max_recall).await;
361        }
362        let limit = self.max_recall;
363
364        // 1. Recent conversation summaries (always include)
365        let recent = self
366            .list(MemoryType::Conversation, 3)
367            .await
368            .unwrap_or_default();
369
370        // 2. Recent session summaries
371        let sessions = self.list(MemoryType::Session, 2).await.unwrap_or_default();
372
373        // 3. Keyword-matched facts and episodes
374        let relevant = self.search(query, None, limit).await.unwrap_or_default();
375
376        // 4. Combine and deduplicate
377        let mut combined = recent;
378        combined.extend(sessions);
379        combined.extend(relevant);
380        dedup_by_id(&mut combined);
381        combined.truncate(limit);
382        Ok(combined)
383    }
384
385    /// Blend recalled memories into the system prompt.
386    pub fn blend_into_prompt(&self, memories: &[MemoryEntry], system_prompt: &str) -> String {
387        #[cfg(feature = "sqlite-memory")]
388        if let Some(ref sqlite) = self.sqlite_store {
389            return sqlite.blend_into_prompt(memories, system_prompt);
390        }
391
392        if memories.is_empty() {
393            return system_prompt.to_string();
394        }
395
396        let memory_block = memories
397            .iter()
398            .map(|m| format!("- [{}] {}", m.memory_type.label(), m.content))
399            .collect::<Vec<_>>()
400            .join("\n");
401
402        format!("{system_prompt}\n\n## Relevant Memory\n\n{memory_block}")
403    }
404
405    /// Recall with Flash Attention re-ranking (Phase 6).
406    #[cfg(feature = "sqlite-memory")]
407    pub async fn recall_with_rerank(&self, query: &str) -> anyhow::Result<Vec<MemoryEntry>> {
408        if let Some(ref sqlite) = self.sqlite_store {
409            return sqlite.recall_with_rerank(query, self.max_recall).await;
410        }
411        // Fallback to standard recall
412        self.recall(query).await
413    }
414
415    /// Check if a memory entry with identical content already exists.
416    ///
417    /// Uses a fast hash comparison against the in-memory vector index.
418    pub async fn is_duplicate(&self, content: &str) -> bool {
419        let hash = content_hash(content);
420
421        // Check semantic similarity via vector index first (fast)
422        let query_vector = match self.embedding.embed(content).await {
423            Ok(v) => v,
424            Err(_) => return false,
425        };
426        let similar = {
427            let index = self.vector_index.read();
428            index
429                .iter()
430                .any(|(_, vector)| query_vector.cosine_similarity(vector) > 0.95)
431        };
432        if similar {
433            return true;
434        }
435
436        // Then check exact content hash across all types
437        for mt in MemoryType::all() {
438            if let Ok(entries) = self.list(*mt, 1000).await {
439                for entry in entries {
440                    if content_hash(&entry.content) == hash {
441                        return true;
442                    }
443                }
444            }
445        }
446        false
447    }
448
449    /// Store a memory entry only if no duplicate content exists.
450    ///
451    /// Returns the entry ID if stored, or `None` if duplicate.
452    pub async fn remember_unique(&self, entry: MemoryEntry) -> anyhow::Result<Option<String>> {
453        #[cfg(feature = "sqlite-memory")]
454        if let Some(ref sqlite) = self.sqlite_store {
455            return sqlite.remember_unique(&entry).await;
456        }
457        if self.is_duplicate(&entry.content).await {
458            tracing::debug!(id = %entry.id, "Skipping duplicate memory");
459            return Ok(None);
460        }
461        let id = self.remember(entry).await?;
462        Ok(Some(id))
463    }
464
465    /// Recall with proactive enhancement.
466    ///
467    /// Extends the standard `recall()` with proactive memory injection
468    /// based on `RecallTiming` triggers.
469    pub async fn recall_with_proactive(
470        &self,
471        query: &str,
472        recall_timing: &mut Option<crate::memory::proactive::RecallTiming>,
473    ) -> anyhow::Result<Vec<MemoryEntry>> {
474        // Step 1: Standard recall
475        let mut combined = self.recall(query).await?;
476
477        // Step 2: Proactive enhancement based on timing triggers
478        let should_recall = recall_timing
479            .as_mut()
480            .map(|t| t.should_recall(query))
481            .unwrap_or(true);
482
483        if should_recall && combined.len() < self.max_recall {
484            #[cfg(feature = "sqlite-memory")]
485            if self.sqlite_store.is_some() {
486                let remaining = self.max_recall - combined.len();
487                let warm = self.list_by_tier(MemoryTier::Warm, remaining).await?;
488                let mut seen_ids: std::collections::HashSet<String> =
489                    combined.iter().map(|e| e.id.clone()).collect();
490                for entry in warm {
491                    if seen_ids.insert(entry.id.clone()) && combined.len() < self.max_recall {
492                        combined.push(entry);
493                    }
494                }
495            }
496
497            #[cfg(not(feature = "sqlite-memory"))]
498            {
499                let proactive = crate::memory::proactive::ProactiveRecall::new(5, 0.6);
500                let extra = proactive.recall(self, query, &combined).await?;
501                combined.extend(extra);
502                dedup_by_id(&mut combined);
503                combined.truncate(self.max_recall);
504            }
505
506            #[cfg(feature = "sqlite-memory")]
507            if self.sqlite_store.is_none() {
508                let proactive = crate::memory::proactive::ProactiveRecall::new(5, 0.6);
509                let extra = proactive.recall(self, query, &combined).await?;
510                combined.extend(extra);
511                dedup_by_id(&mut combined);
512                combined.truncate(self.max_recall);
513            }
514        }
515
516        Ok(combined)
517    }
518}