oxios_memory/memory/manager/
store.rs1use std::collections::HashMap;
7
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10
11use crate::memory::auto_protect::AutoProtector;
12use crate::memory::embedding::EmbeddingVector;
13use crate::memory::storage::MemoryStorageExt;
14#[cfg(feature = "sqlite-memory")]
15use crate::memory::types::MemoryTier;
16use crate::memory::types::{content_hash, dedup_by_id, extract_keywords, MemoryEntry, MemoryType};
17
18use super::MemoryManager;
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
26struct VectorIndexSnapshot {
27 created_at: DateTime<Utc>,
29 entry_count: usize,
31 entries: HashMap<String, EmbeddingVector>,
33}
34
35impl MemoryManager {
40 pub async fn total_entries(&self) -> usize {
42 let mut total = 0;
43 for mt in MemoryType::all() {
44 if let Ok(entries) = self.list(*mt, 1_000_000).await {
45 total += entries.len();
46 }
47 }
48 total
49 }
50
51 pub async fn rebuild_index(&self) -> anyhow::Result<()> {
56 let mut entries_to_index: Vec<(String, EmbeddingVector)> = Vec::new();
58
59 for mt in MemoryType::all() {
60 if let Ok(names) = self.storage.list_category(mt.category()).await {
61 for name in names {
62 if let Ok(Some(entry)) = self
63 .storage
64 .load_json::<MemoryEntry>(mt.category(), &name)
65 .await
66 {
67 let vector = self.embedding.embed(&entry.content).await?;
68 entries_to_index.push((entry.id.clone(), vector));
69 }
70 }
71 }
72 }
73
74 {
76 let mut index = self.vector_index.write();
77 index.clear();
78 for (id, vector) in entries_to_index {
79 index.insert(id, vector);
80 }
81 }
82
83 tracing::info!(
84 entries = self.vector_index.read().len(),
85 "Memory vector index rebuilt"
86 );
87 Ok(())
88 }
89
90 pub async fn save_index_snapshot(&self) -> anyhow::Result<()> {
92 let snapshot = {
93 let index = self.vector_index.read();
94 VectorIndexSnapshot {
95 created_at: chrono::Utc::now(),
96 entry_count: index.len(),
97 entries: index.clone(),
98 }
99 };
100
101 self.storage
102 .save_json("memory", "vector_index_snapshot", &snapshot)
103 .await?;
104
105 self.git_commit("memory/vector_index_snapshot.json", "memory: snapshot save")
106 .await;
107
108 tracing::debug!(
109 entries = snapshot.entry_count,
110 "Vector index snapshot saved"
111 );
112 Ok(())
113 }
114
115 pub async fn load_index_snapshot(&self) -> anyhow::Result<usize> {
117 let snapshot: Option<VectorIndexSnapshot> = self
118 .storage
119 .load_json("memory", "vector_index_snapshot")
120 .await?;
121
122 match snapshot {
123 Some(snap) => {
124 let count = snap.entry_count;
125 let mut index = self.vector_index.write();
126 *index = snap.entries;
127 tracing::info!(entries = count, "Vector index snapshot loaded");
128 Ok(count)
129 }
130 None => {
131 tracing::debug!("No vector index snapshot found");
132 Ok(0)
133 }
134 }
135 }
136
137 pub async fn remember(&self, entry: MemoryEntry) -> anyhow::Result<String> {
143 #[cfg(feature = "sqlite-memory")]
145 if let Some(ref sqlite) = self.sqlite_store {
146 return sqlite.remember(&entry).await;
147 }
148
149 let id = entry.id.clone();
151 let vector = self.embedding.embed(&entry.content).await?;
152 let category = entry.memory_type.category();
153 self.storage.save_json(category, &id, &entry).await?;
154
155 self.git_commit(
156 &format!("{category}/{id}.json"),
157 &format!("memory: store {id}"),
158 )
159 .await;
160
161 {
163 let mut index = self.vector_index.write();
164 index.insert(id.clone(), vector.clone());
165 }
166
167 if let Some(f32_vec) = vector.to_f32_dense() {
169 let hnsw = self.hnsw_index.read();
170 if let Some(ref hnsw) = *hnsw {
171 if let Err(e) = hnsw.add_entry(&id, &f32_vec) {
172 tracing::warn!(id = %id, error = %e, "Failed to update HNSW index on remember");
173 }
174 }
175 }
176
177 tracing::debug!(id = %id, ty = entry.memory_type.label(), "Memory stored");
178 Ok(id)
179 }
180
181 pub async fn get(
185 &self,
186 id: &str,
187 memory_type: MemoryType,
188 ) -> anyhow::Result<Option<MemoryEntry>> {
189 #[cfg(feature = "sqlite-memory")]
190 if let Some(ref sqlite) = self.sqlite_store {
191 return sqlite.get(id, memory_type);
192 }
193 let result: Option<MemoryEntry> =
194 self.storage.load_json(memory_type.category(), id).await?;
195 if let Some(mut entry) = result {
196 AutoProtector::record_access(&mut entry, "");
197 Ok(Some(entry))
198 } else {
199 Ok(None)
200 }
201 }
202
203 pub async fn forget(&self, id: &str, memory_type: MemoryType) -> anyhow::Result<bool> {
205 #[cfg(feature = "sqlite-memory")]
206 if let Some(ref sqlite) = self.sqlite_store {
207 return sqlite.forget(id, memory_type);
208 }
209 let result = self.storage.delete_file(memory_type.category(), id).await?;
210
211 {
213 let hnsw = self.hnsw_index.read();
214 if let Some(ref hnsw) = *hnsw {
215 if let Err(e) = hnsw.remove_entry(id) {
216 tracing::warn!(id = %id, error = %e, "Failed to remove from HNSW index on forget");
217 }
218 }
219 }
220
221 Ok(result)
222 }
223
224 pub async fn list(
226 &self,
227 memory_type: MemoryType,
228 limit: usize,
229 ) -> anyhow::Result<Vec<MemoryEntry>> {
230 #[cfg(feature = "sqlite-memory")]
231 if let Some(ref sqlite) = self.sqlite_store {
232 return sqlite.list(memory_type, limit);
233 }
234 let category = memory_type.category();
235 let names = self.storage.list_category(category).await?;
236 let mut entries = Vec::new();
237 for name in names.into_iter().take(limit.saturating_mul(2)) {
238 if let Ok(Some(entry)) = self.storage.load_json::<MemoryEntry>(category, &name).await {
239 entries.push(entry);
240 }
241 }
242 entries.sort_by_key(|b| std::cmp::Reverse(b.created_at));
244 entries.truncate(limit);
245 Ok(entries)
246 }
247
248 pub async fn search(
253 &self,
254 query: &str,
255 memory_type: Option<MemoryType>,
256 limit: usize,
257 ) -> anyhow::Result<Vec<MemoryEntry>> {
258 #[cfg(feature = "sqlite-memory")]
259 if let Some(ref sqlite) = self.sqlite_store {
260 return sqlite.search(query, memory_type, limit).await;
261 }
262 let query_vector = self.embedding.embed(query).await?;
263
264 let scored: Vec<(String, f64)> = {
266 let index = self.vector_index.read();
267 let mut scored: Vec<(String, f64)> = index
268 .iter()
269 .map(|(id, vector)| {
270 let score = query_vector.cosine_similarity(vector);
271 (id.clone(), score)
272 })
273 .filter(|(_, score)| *score > 0.1)
274 .collect();
275 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
276 scored.truncate(limit);
277 scored
278 }; if scored.is_empty() {
282 return self.keyword_search(query, memory_type, limit).await;
283 }
284
285 let types: &[MemoryType] = match memory_type {
287 Some(ref t) => std::slice::from_ref(t),
288 None => MemoryType::all(),
289 };
290
291 let mut results = Vec::new();
293 for (id, score) in scored {
294 for mt in types {
295 if let Ok(Some(mut entry)) = self
296 .storage
297 .load_json::<MemoryEntry>(mt.category(), &id)
298 .await
299 {
300 AutoProtector::record_access(&mut entry, "");
301 tracing::debug!(id = %id, score, "Vector search hit");
302 results.push(entry);
303 break;
304 }
305 }
306 }
307
308 if results.is_empty() {
310 return self.keyword_search(query, memory_type, limit).await;
311 }
312
313 Ok(results)
314 }
315
316 pub(crate) async fn keyword_search(
318 &self,
319 query: &str,
320 memory_type: Option<MemoryType>,
321 limit: usize,
322 ) -> anyhow::Result<Vec<MemoryEntry>> {
323 let keywords = extract_keywords(query);
324 let types = match memory_type {
325 Some(t) => vec![t],
326 None => MemoryType::all().to_vec(),
327 };
328
329 let mut results = Vec::new();
330 for ty in &types {
331 let entries = self.list(*ty, limit * 2).await?;
332 for entry in entries {
333 let matches = keywords.iter().any(|k| {
334 let k_lower = k.to_lowercase();
335 entry.content.to_lowercase().contains(&k_lower)
336 || entry
337 .tags
338 .iter()
339 .any(|t| t.to_lowercase().contains(&k_lower))
340 });
341 if matches {
342 results.push(entry);
343 }
344 }
345 }
346
347 results.sort_by(|a, b| {
348 b.importance
349 .partial_cmp(&a.importance)
350 .unwrap_or(std::cmp::Ordering::Equal)
351 });
352 results.truncate(limit);
353 Ok(results)
354 }
355
356 pub async fn recall(&self, query: &str) -> anyhow::Result<Vec<MemoryEntry>> {
361 #[cfg(feature = "sqlite-memory")]
362 if let Some(ref sqlite) = self.sqlite_store {
363 return sqlite.recall(query, self.max_recall).await;
364 }
365 let limit = self.max_recall;
366
367 let recent = self
369 .list(MemoryType::Conversation, 3)
370 .await
371 .unwrap_or_default();
372
373 let sessions = self.list(MemoryType::Session, 2).await.unwrap_or_default();
375
376 let relevant = self.search(query, None, limit).await.unwrap_or_default();
378
379 let mut combined = recent;
381 combined.extend(sessions);
382 combined.extend(relevant);
383 dedup_by_id(&mut combined);
384 combined.truncate(limit);
385 Ok(combined)
386 }
387
388 pub fn blend_into_prompt(&self, memories: &[MemoryEntry], system_prompt: &str) -> String {
390 #[cfg(feature = "sqlite-memory")]
391 if let Some(ref sqlite) = self.sqlite_store {
392 return sqlite.blend_into_prompt(memories, system_prompt);
393 }
394
395 if memories.is_empty() {
396 return system_prompt.to_string();
397 }
398
399 let memory_block = memories
400 .iter()
401 .map(|m| format!("- [{}] {}", m.memory_type.label(), m.content))
402 .collect::<Vec<_>>()
403 .join("\n");
404
405 format!("{system_prompt}\n\n## Relevant Memory\n\n{memory_block}")
406 }
407
/// Recalls memories for `query` using the SQLite store's reranking when available.
///
/// Only compiled with the `sqlite-memory` feature. Without a configured SQLite
/// store this is the same as [`Self::recall`].
#[cfg(feature = "sqlite-memory")]
pub async fn recall_with_rerank(&self, query: &str) -> anyhow::Result<Vec<MemoryEntry>> {
    match self.sqlite_store {
        Some(ref sqlite) => sqlite.recall_with_rerank(query, self.max_recall).await,
        None => self.recall(query).await,
    }
}
417
418 pub async fn is_duplicate(&self, content: &str) -> bool {
422 let hash = content_hash(content);
423
424 let query_vector = match self.embedding.embed(content).await {
426 Ok(v) => v,
427 Err(_) => return false,
428 };
429 let similar = {
430 let index = self.vector_index.read();
431 index
432 .iter()
433 .any(|(_, vector)| query_vector.cosine_similarity(vector) > 0.95)
434 };
435 if similar {
436 return true;
437 }
438
439 for mt in MemoryType::all() {
441 if let Ok(entries) = self.list(*mt, 1000).await {
442 for entry in entries {
443 if content_hash(&entry.content) == hash {
444 return true;
445 }
446 }
447 }
448 }
449 false
450 }
451
452 pub async fn remember_unique(&self, entry: MemoryEntry) -> anyhow::Result<Option<String>> {
456 #[cfg(feature = "sqlite-memory")]
457 if let Some(ref sqlite) = self.sqlite_store {
458 return sqlite.remember_unique(&entry).await;
459 }
460 if self.is_duplicate(&entry.content).await {
461 tracing::debug!(id = %entry.id, "Skipping duplicate memory");
462 return Ok(None);
463 }
464 let id = self.remember(entry).await?;
465 Ok(Some(id))
466 }
467
468 pub async fn recall_with_proactive(
473 &self,
474 query: &str,
475 recall_timing: &mut Option<crate::memory::proactive::RecallTiming>,
476 ) -> anyhow::Result<Vec<MemoryEntry>> {
477 let mut combined = self.recall(query).await?;
479
480 let should_recall = recall_timing
482 .as_mut()
483 .map(|t| t.should_recall(query))
484 .unwrap_or(true);
485
486 if should_recall && combined.len() < self.max_recall {
487 #[cfg(feature = "sqlite-memory")]
488 if self.sqlite_store.is_some() {
489 let remaining = self.max_recall - combined.len();
490 let warm = self.list_by_tier(MemoryTier::Warm, remaining).await?;
491 let mut seen_ids: std::collections::HashSet<String> =
492 combined.iter().map(|e| e.id.clone()).collect();
493 for entry in warm {
494 if seen_ids.insert(entry.id.clone()) && combined.len() < self.max_recall {
495 combined.push(entry);
496 }
497 }
498 }
499
500 #[cfg(not(feature = "sqlite-memory"))]
501 {
502 let proactive = crate::memory::proactive::ProactiveRecall::new(5, 0.6);
503 let extra = proactive.recall(self, query, &combined).await?;
504 combined.extend(extra);
505 dedup_by_id(&mut combined);
506 combined.truncate(self.max_recall);
507 }
508
509 #[cfg(feature = "sqlite-memory")]
510 if self.sqlite_store.is_none() {
511 let proactive = crate::memory::proactive::ProactiveRecall::new(5, 0.6);
512 let extra = proactive.recall(self, query, &combined).await?;
513 combined.extend(extra);
514 dedup_by_id(&mut combined);
515 combined.truncate(self.max_recall);
516 }
517 }
518
519 Ok(combined)
520 }
521}