1use std::collections::HashMap;
7use std::path::PathBuf;
8use std::sync::atomic::{AtomicU64, Ordering};
9
10use anyhow::Result;
11use chrono::{DateTime, Utc};
12use parking_lot::RwLock;
13use serde::{Deserialize, Serialize};
14
15use crate::embedding::EmbeddingVector;
16
17use super::hnsw::HnswIndex;
18use super::normalizer::l2_normalize_f32;
19use super::{content_hash, dedup_by_id, extract_keywords, MemoryEntry, MemoryManager, MemoryType};
20
/// On-disk snapshot of the in-memory vector index, written by
/// `MemoryManager::save_index_snapshot` and read back by
/// `MemoryManager::load_index_snapshot` (stored as `memory/vector_index_snapshot`).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct VectorIndexSnapshot {
    /// Time at which the snapshot was taken (`Utc::now()` at save time).
    created_at: DateTime<Utc>,
    /// Number of entries in the index when the snapshot was taken.
    entry_count: usize,
    /// Memory entry id -> embedding vector.
    entries: HashMap<String, EmbeddingVector>,
}
35
36impl MemoryManager {
41 pub async fn total_entries(&self) -> usize {
43 let mut total = 0;
44 for mt in [
45 MemoryType::Conversation,
46 MemoryType::Session,
47 MemoryType::Fact,
48 MemoryType::Episode,
49 MemoryType::Knowledge,
50 ] {
51 if let Ok(entries) = self.list(mt, 1_000_000).await {
53 total += entries.len();
54 }
55 }
56 total
57 }
58
59 pub async fn rebuild_index(&self) -> Result<()> {
64 let mut entries_to_index: Vec<(String, EmbeddingVector)> = Vec::new();
66
67 for mt in &[
68 MemoryType::Conversation,
69 MemoryType::Session,
70 MemoryType::Fact,
71 MemoryType::Episode,
72 MemoryType::Knowledge,
73 ] {
74 if let Ok(names) = self.state_store.list_category(mt.category()).await {
75 for name in names {
76 if let Ok(Some(entry)) = self
77 .state_store
78 .load_json::<MemoryEntry>(mt.category(), &name)
79 .await
80 {
81 let vector = self.embedding.embed(&entry.content).await?;
82 entries_to_index.push((entry.id.clone(), vector));
83 }
84 }
85 }
86 }
87
88 {
90 let mut index = self.vector_index.write();
91 index.clear();
92 for (id, vector) in entries_to_index {
93 index.insert(id, vector);
94 }
95 }
96
97 tracing::info!(
98 entries = self.vector_index.read().len(),
99 "Memory vector index rebuilt"
100 );
101 Ok(())
102 }
103
104 pub async fn save_index_snapshot(&self) -> Result<()> {
106 let snapshot = {
107 let index = self.vector_index.read();
108 VectorIndexSnapshot {
109 created_at: chrono::Utc::now(),
110 entry_count: index.len(),
111 entries: index.clone(),
112 }
113 };
114
115 self.state_store
116 .save_json("memory", "vector_index_snapshot", &snapshot)
117 .await?;
118
119 self.git_commit("memory/vector_index_snapshot.json", "memory: snapshot save");
120
121 tracing::debug!(
122 entries = snapshot.entry_count,
123 "Vector index snapshot saved"
124 );
125 Ok(())
126 }
127
128 pub async fn load_index_snapshot(&self) -> Result<usize> {
130 let snapshot: Option<VectorIndexSnapshot> = self
131 .state_store
132 .load_json("memory", "vector_index_snapshot")
133 .await?;
134
135 match snapshot {
136 Some(snap) => {
137 let count = snap.entry_count;
138 let mut index = self.vector_index.write();
139 *index = snap.entries;
140 tracing::info!(entries = count, "Vector index snapshot loaded");
141 Ok(count)
142 }
143 None => {
144 tracing::debug!("No vector index snapshot found");
145 Ok(0)
146 }
147 }
148 }
149
150 pub async fn remember(&self, entry: MemoryEntry) -> Result<String> {
155 let id = entry.id.clone();
156 let vector = self.embedding.embed(&entry.content).await?;
157 let category = entry.memory_type.category();
158 self.state_store.save_json(category, &id, &entry).await?;
159
160 self.git_commit(
161 &format!("{}/{}.json", category, id),
162 &format!("memory: store {}", id),
163 );
164
165 {
167 let mut index = self.vector_index.write();
168 index.insert(id.clone(), vector.clone());
169 }
170
171 if let Some(f32_vec) = vector.to_f32_dense() {
173 let hnsw = self.hnsw_index.read();
174 if let Some(ref hnsw) = *hnsw {
175 if let Err(e) = hnsw.add_entry(&id, &f32_vec) {
176 tracing::warn!(id = %id, error = %e, "Failed to update HNSW index on remember");
177 }
178 }
179 }
180
181 tracing::debug!(id = %id, ty = entry.memory_type.label(), "Memory stored");
182 Ok(id)
183 }
184
185 pub async fn get(&self, id: &str, memory_type: MemoryType) -> Result<Option<MemoryEntry>> {
187 self.state_store.load_json(memory_type.category(), id).await
188 }
189
190 pub async fn forget(&self, id: &str, memory_type: MemoryType) -> Result<bool> {
192 let result = self
193 .state_store
194 .delete_file(memory_type.category(), id)
195 .await?;
196
197 {
199 let hnsw = self.hnsw_index.read();
200 if let Some(ref hnsw) = *hnsw {
201 if let Err(e) = hnsw.remove_entry(id) {
202 tracing::warn!(id = %id, error = %e, "Failed to remove from HNSW index on forget");
203 }
204 }
205 }
206
207 Ok(result)
208 }
209
210 pub async fn list(&self, memory_type: MemoryType, limit: usize) -> Result<Vec<MemoryEntry>> {
212 let category = memory_type.category();
213 let names = self.state_store.list_category(category).await?;
214 let mut entries = Vec::new();
215 for name in names.into_iter().take(limit.saturating_mul(2)) {
216 if let Ok(Some(entry)) = self
217 .state_store
218 .load_json::<MemoryEntry>(category, &name)
219 .await
220 {
221 entries.push(entry);
222 }
223 }
224 entries.sort_by_key(|b| std::cmp::Reverse(b.created_at));
226 entries.truncate(limit);
227 Ok(entries)
228 }
229
230 pub async fn search(
235 &self,
236 query: &str,
237 memory_type: Option<MemoryType>,
238 limit: usize,
239 ) -> Result<Vec<MemoryEntry>> {
240 let query_vector = self.embedding.embed(query).await?;
241
242 let scored: Vec<(String, f64)> = {
244 let index = self.vector_index.read();
245 let mut scored: Vec<(String, f64)> = index
246 .iter()
247 .map(|(id, vector)| {
248 let score = query_vector.cosine_similarity(vector);
249 (id.clone(), score)
250 })
251 .filter(|(_, score)| *score > 0.1)
252 .collect();
253 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
254 scored.truncate(limit);
255 scored
256 }; if scored.is_empty() {
260 return self.keyword_search(query, memory_type, limit).await;
261 }
262
263 let all_types: &[MemoryType] = &[
265 MemoryType::Conversation,
266 MemoryType::Session,
267 MemoryType::Fact,
268 MemoryType::Episode,
269 MemoryType::Knowledge,
270 ];
271 let types: &[MemoryType] = match memory_type {
272 Some(ref t) => std::slice::from_ref(t),
273 None => all_types,
274 };
275
276 let mut results = Vec::new();
278 for (id, score) in scored {
279 for mt in types {
280 if let Ok(Some(mut entry)) = self
281 .state_store
282 .load_json::<MemoryEntry>(mt.category(), &id)
283 .await
284 {
285 entry.access_count += 1;
286 entry.accessed_at = chrono::Utc::now();
287 tracing::debug!(id = %id, score, "Vector search hit");
288 results.push(entry);
289 break;
290 }
291 }
292 }
293
294 if results.is_empty() {
296 return self.keyword_search(query, memory_type, limit).await;
297 }
298
299 Ok(results)
300 }
301
302 async fn keyword_search(
304 &self,
305 query: &str,
306 memory_type: Option<MemoryType>,
307 limit: usize,
308 ) -> Result<Vec<MemoryEntry>> {
309 let keywords = extract_keywords(query);
310 let types = match memory_type {
311 Some(t) => vec![t],
312 None => vec![
313 MemoryType::Conversation,
314 MemoryType::Fact,
315 MemoryType::Episode,
316 MemoryType::Knowledge,
317 ],
318 };
319
320 let mut results = Vec::new();
321 for ty in &types {
322 let entries = self.list(*ty, limit * 2).await?;
323 for entry in entries {
324 let matches = keywords.iter().any(|k| {
325 let k_lower = k.to_lowercase();
326 entry.content.to_lowercase().contains(&k_lower)
327 || entry
328 .tags
329 .iter()
330 .any(|t| t.to_lowercase().contains(&k_lower))
331 });
332 if matches {
333 results.push(entry);
334 }
335 }
336 }
337
338 results.sort_by(|a, b| {
339 b.importance
340 .partial_cmp(&a.importance)
341 .unwrap_or(std::cmp::Ordering::Equal)
342 });
343 results.truncate(limit);
344 Ok(results)
345 }
346
347 pub async fn recall(&self, query: &str) -> Result<Vec<MemoryEntry>> {
352 let limit = self.max_recall;
353
354 let recent = self
356 .list(MemoryType::Conversation, 3)
357 .await
358 .unwrap_or_default();
359
360 let sessions = self.list(MemoryType::Session, 2).await.unwrap_or_default();
362
363 let relevant = self.search(query, None, limit).await.unwrap_or_default();
365
366 let mut combined = recent;
368 combined.extend(sessions);
369 combined.extend(relevant);
370 dedup_by_id(&mut combined);
371 combined.truncate(limit);
372 Ok(combined)
373 }
374
375 pub fn blend_into_prompt(&self, memories: &[MemoryEntry], system_prompt: &str) -> String {
377 if memories.is_empty() {
378 return system_prompt.to_string();
379 }
380
381 let memory_block = memories
382 .iter()
383 .map(|m| format!("- [{}] {}", m.memory_type.label(), m.content))
384 .collect::<Vec<_>>()
385 .join("\n");
386
387 format!("{system_prompt}\n\n## Relevant Memory\n\n{memory_block}")
388 }
389
390 pub async fn summarize_session(
395 &self,
396 session: &crate::state_store::Session,
397 ) -> Result<Option<String>> {
398 if session.user_messages.is_empty() {
399 return Ok(None);
400 }
401
402 let mut summary_parts = Vec::new();
404
405 if let Some(first_msg) = session.user_messages.first() {
407 summary_parts.push(format!("User: {}", first_msg.content));
408 }
409
410 if let Some(last_response) = session.agent_responses.last() {
412 let truncated = if last_response.content.len() > 500 {
413 format!("{}...", &last_response.content[..500])
414 } else {
415 last_response.content.clone()
416 };
417 summary_parts.push(format!("Agent: {}", truncated));
418 }
419
420 if let Some(ref seed_id) = session.active_seed_id {
422 summary_parts.push(format!("Seed: {}", seed_id));
423 }
424 if let Some(ref persona_id) = session.active_persona_id {
425 summary_parts.push(format!("Persona: {}", persona_id));
426 }
427
428 let content = summary_parts.join("\n");
429 let entry = MemoryEntry {
430 id: format!(
431 "session-{}-{}",
432 session.id.0,
433 chrono::Utc::now().timestamp()
434 ),
435 memory_type: MemoryType::Session,
436 content,
437 source: "session_summary".to_string(),
438 session_id: Some(session.id.0.clone()),
439 tags: vec![],
440 importance: 0.6,
441 created_at: chrono::Utc::now(),
442 accessed_at: chrono::Utc::now(),
443 access_count: 0,
444 };
445
446 let id = self.remember(entry).await?;
447 Ok(Some(id))
448 }
449
450 pub async fn is_duplicate(&self, content: &str) -> bool {
454 let hash = content_hash(content);
455
456 let query_vector = match self.embedding.embed(content).await {
458 Ok(v) => v,
459 Err(_) => return false,
460 };
461 let similar = {
462 let index = self.vector_index.read();
463 index
464 .iter()
465 .any(|(_, vector)| query_vector.cosine_similarity(vector) > 0.95)
466 };
467 if similar {
468 return true;
469 }
470
471 for mt in &[
473 MemoryType::Conversation,
474 MemoryType::Session,
475 MemoryType::Fact,
476 MemoryType::Episode,
477 MemoryType::Knowledge,
478 ] {
479 if let Ok(entries) = self.list(*mt, 1000).await {
480 for entry in entries {
481 if content_hash(&entry.content) == hash {
482 return true;
483 }
484 }
485 }
486 }
487 false
488 }
489
490 pub async fn remember_unique(&self, entry: MemoryEntry) -> Result<Option<String>> {
494 if self.is_duplicate(&entry.content).await {
495 tracing::debug!(id = %entry.id, "Skipping duplicate memory");
496 return Ok(None);
497 }
498 let id = self.remember(entry).await?;
499 Ok(Some(id))
500 }
501}
502
/// A memory entry returned by `MemoryManager::semantic_search`, with its score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticHit {
    /// The matched memory entry (with in-memory access stats bumped).
    pub entry: MemoryEntry,
    /// Distance reported by the HNSW index; `0.0` for keyword-fallback hits.
    pub distance: f32,
    /// `1.0 - distance` for HNSW hits; `0.0` for keyword-fallback hits.
    pub similarity: f32,
}
517
/// HNSW approximate-nearest-neighbour index over memory entries.
///
/// The underlying `HnswIndex` is keyed by numeric `u64` keys. This wrapper keeps
/// a two-way mapping between those keys and memory entry ids, and can optionally
/// persist both to disk.
pub struct HnswMemoryIndex {
    /// Underlying HNSW index, keyed by numeric keys.
    index: RwLock<HnswIndex>,
    /// Numeric key -> memory entry id.
    key_to_id: RwLock<HashMap<u64, String>>,
    /// Memory entry id -> numeric key (inverse of `key_to_id`).
    id_to_key: RwLock<HashMap<String, u64>>,
    /// Next numeric key to hand out.
    next_key: AtomicU64,
    /// Directory holding `memory.usearch` and `key_map.json`; `None` disables
    /// persistence.
    persist_path: Option<PathBuf>,
}
534
535impl std::fmt::Debug for HnswMemoryIndex {
536 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
537 f.debug_struct("HnswMemoryIndex")
538 .field("size", &self.len())
539 .field("dimensions", &self.index.read().dimensions())
540 .finish()
541 }
542}
543
544impl HnswMemoryIndex {
545 pub fn new(dimensions: usize, capacity: usize, persist_path: Option<PathBuf>) -> Result<Self> {
552 let index = HnswIndex::new(dimensions, capacity)?;
553 Ok(Self {
554 index: RwLock::new(index),
555 key_to_id: RwLock::new(HashMap::new()),
556 id_to_key: RwLock::new(HashMap::new()),
557 next_key: AtomicU64::new(1), persist_path,
559 })
560 }
561
562 pub fn restore_or_new(
564 dimensions: usize,
565 capacity: usize,
566 persist_path: Option<PathBuf>,
567 ) -> Result<Self> {
568 if let Some(ref path) = persist_path {
569 let index_path = path.join("memory.usearch");
570 let mapping_path = path.join("key_map.json");
571
572 if index_path.exists() && mapping_path.exists() {
573 tracing::info!(path = %index_path.display(), "Restoring HNSW index from disk");
574
575 if let Ok(index) = HnswIndex::load(&index_path) {
576 if let Ok(data) = std::fs::read_to_string(&mapping_path) {
577 if let Ok((k2i, i2k)) = serde_json::from_str::<(
578 HashMap<u64, String>,
579 HashMap<String, u64>,
580 )>(&data)
581 {
582 let max_key = k2i.keys().max().copied().unwrap_or(0);
583 return Ok(Self {
584 index: RwLock::new(index),
585 key_to_id: RwLock::new(k2i),
586 id_to_key: RwLock::new(i2k),
587 next_key: AtomicU64::new(max_key + 1),
588 persist_path,
589 });
590 }
591 }
592 }
593
594 tracing::warn!("Failed to restore HNSW index, creating new one");
595 }
596 }
597
598 Self::new(dimensions, capacity, persist_path)
599 }
600
601 fn get_or_create_key(&self, id: &str) -> u64 {
603 {
605 let i2k = self.id_to_key.read();
606 if let Some(&key) = i2k.get(id) {
607 return key;
608 }
609 }
610
611 let mut i2k = self.id_to_key.write();
613 let mut k2i = self.key_to_id.write();
614
615 if let Some(&key) = i2k.get(id) {
617 return key;
618 }
619
620 let key = self.next_key.fetch_add(1, Ordering::Relaxed);
621 i2k.insert(id.to_string(), key);
622 k2i.insert(key, id.to_string());
623 key
624 }
625
626 pub fn add_entry(&self, id: &str, vector: &[f32]) -> Result<()> {
628 let key = self.get_or_create_key(id);
629 let mut normalized = vector.to_vec();
630 l2_normalize_f32(&mut normalized);
631 self.index.write().add(key, &normalized)?;
632 Ok(())
633 }
634
635 pub fn remove_entry(&self, id: &str) -> Result<()> {
637 let key = {
638 let i2k = self.id_to_key.read();
639 i2k.get(id).copied()
640 };
641 if let Some(key) = key {
642 self.index.write().remove(key)?;
643 let mut k2i = self.key_to_id.write();
644 let mut i2k = self.id_to_key.write();
645 k2i.remove(&key);
646 i2k.remove(id);
647 }
648 Ok(())
649 }
650
651 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
655 let mut normalized = query.to_vec();
656 l2_normalize_f32(&mut normalized);
657
658 let raw = self.index.read().search(&normalized, k)?;
659 let k2i = self.key_to_id.read();
660
661 let results = raw
662 .into_iter()
663 .filter_map(|(key, dist)| k2i.get(&key).map(|id| (id.clone(), dist)))
664 .collect();
665
666 Ok(results)
667 }
668
669 pub fn len(&self) -> usize {
671 self.index.read().len()
672 }
673
674 pub fn is_empty(&self) -> bool {
676 self.index.read().is_empty()
677 }
678
679 pub fn persist(&self) -> Result<()> {
681 if let Some(ref path) = self.persist_path {
682 std::fs::create_dir_all(path)?;
683
684 let index_path = path.join("memory.usearch");
685 let mapping_path = path.join("key_map.json");
686
687 self.index.read().save(&index_path)?;
689
690 let k2i = self.key_to_id.read();
692 let i2k = self.id_to_key.read();
693 let data = serde_json::to_string(&(k2i.clone(), &*i2k))?;
694 std::fs::write(&mapping_path, data)?;
695
696 tracing::debug!(path = %path.display(), entries = self.len(), "HNSW index persisted");
697 }
698 Ok(())
699 }
700}
701
702impl MemoryManager {
707 pub async fn semantic_search(
725 &self,
726 query: &str,
727 memory_type: Option<MemoryType>,
728 limit: usize,
729 hnsw_index: &HnswMemoryIndex,
730 ) -> Result<Vec<SemanticHit>> {
731 if hnsw_index.is_empty() {
733 tracing::debug!("HNSW index empty, falling back to keyword search");
734 return self
735 .keyword_search(query, memory_type, limit)
736 .await
737 .map(|entries| {
738 entries
739 .into_iter()
740 .map(|entry| SemanticHit {
741 entry,
742 distance: 0.0,
743 similarity: 0.0,
744 })
745 .collect()
746 });
747 }
748
749 let query_vector = self.embedding.embed(query).await?;
751 let query_f32 = match query_vector.to_f32_dense() {
752 Some(v) => v,
753 None => {
754 tracing::debug!("Query embedding is sparse, falling back to keyword search");
755 return self
756 .keyword_search(query, memory_type, limit)
757 .await
758 .map(|entries| {
759 entries
760 .into_iter()
761 .map(|entry| SemanticHit {
762 entry,
763 distance: 0.0,
764 similarity: 0.0,
765 })
766 .collect()
767 });
768 }
769 };
770
771 let raw_hits = hnsw_index.search(&query_f32, limit * 2)?;
773
774 let all_types: &[MemoryType] = &[
776 MemoryType::Conversation,
777 MemoryType::Session,
778 MemoryType::Fact,
779 MemoryType::Episode,
780 MemoryType::Knowledge,
781 ];
782 let types: &[MemoryType] = match memory_type {
783 Some(ref t) => std::slice::from_ref(t),
784 None => all_types,
785 };
786
787 let mut results = Vec::new();
789 for (id, distance) in raw_hits {
790 for mt in types {
791 if let Ok(Some(mut entry)) = self
792 .state_store
793 .load_json::<MemoryEntry>(mt.category(), &id)
794 .await
795 {
796 entry.access_count += 1;
798 entry.accessed_at = chrono::Utc::now();
799
800 let similarity = 1.0 - distance;
801 results.push(SemanticHit {
802 entry,
803 distance,
804 similarity,
805 });
806 break;
807 }
808 }
809 if results.len() >= limit {
810 break;
811 }
812 }
813
814 results.sort_by(|a, b| {
816 b.similarity
817 .partial_cmp(&a.similarity)
818 .unwrap_or(std::cmp::Ordering::Equal)
819 });
820
821 tracing::debug!(
822 query = %query,
823 hits = results.len(),
824 "Semantic search complete"
825 );
826
827 if results.is_empty() {
829 return self
830 .keyword_search(query, memory_type, limit)
831 .await
832 .map(|entries| {
833 entries
834 .into_iter()
835 .map(|entry| SemanticHit {
836 entry,
837 distance: 0.0,
838 similarity: 0.0,
839 })
840 .collect()
841 });
842 }
843
844 Ok(results)
845 }
846
847 pub async fn rebuild_hnsw_index(&self, hnsw_index: &HnswMemoryIndex) -> Result<usize> {
851 let mut count = 0;
852
853 for mt in &[
854 MemoryType::Conversation,
855 MemoryType::Session,
856 MemoryType::Fact,
857 MemoryType::Episode,
858 MemoryType::Knowledge,
859 ] {
860 if let Ok(names) = self.state_store.list_category(mt.category()).await {
861 for name in names {
862 if let Ok(Some(entry)) = self
863 .state_store
864 .load_json::<MemoryEntry>(mt.category(), &name)
865 .await
866 {
867 let vector = self.embedding.embed(&entry.content).await?;
868 if let Some(f32_vec) = vector.to_f32_dense() {
869 if let Err(e) = hnsw_index.add_entry(&entry.id, &f32_vec) {
870 tracing::warn!(
871 id = %entry.id,
872 error = %e,
873 "Failed to add entry to HNSW index"
874 );
875 continue;
876 }
877 count += 1;
878 }
879 }
880 }
881 }
882 }
883
884 tracing::info!(entries = count, "HNSW index rebuilt");
885 Ok(count)
886 }
887}