1use std::collections::HashMap;
7use std::path::PathBuf;
8use std::sync::atomic::{AtomicU64, Ordering};
9
10use anyhow::Result;
11use chrono::{DateTime, Utc};
12use parking_lot::RwLock;
13use serde::{Deserialize, Serialize};
14
15use crate::embedding::EmbeddingVector;
16
17use super::hnsw::HnswIndex;
18use super::normalizer::l2_normalize_f32;
19use super::{content_hash, dedup_by_id, extract_keywords, MemoryEntry, MemoryManager, MemoryType};
20
/// On-disk form of the in-memory vector index.
///
/// Written by `MemoryManager::save_index_snapshot` and read back by
/// `MemoryManager::load_index_snapshot`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct VectorIndexSnapshot {
    /// When the snapshot was taken (UTC).
    created_at: DateTime<Utc>,
    /// Number of entries in `entries` at the time the snapshot was written.
    entry_count: usize,
    /// Memory entry id -> embedding vector for that entry's content.
    entries: HashMap<String, EmbeddingVector>,
}
35
36impl MemoryManager {
41 pub async fn total_entries(&self) -> usize {
43 let mut total = 0;
44 for mt in [
45 MemoryType::Conversation,
46 MemoryType::Session,
47 MemoryType::Fact,
48 MemoryType::Episode,
49 MemoryType::Knowledge,
50 ] {
51 if let Ok(entries) = self.list(mt, usize::MAX).await {
52 total += entries.len();
53 }
54 }
55 total
56 }
57
58 pub async fn rebuild_index(&self) -> Result<()> {
63 let mut entries_to_index: Vec<(String, EmbeddingVector)> = Vec::new();
65
66 for mt in &[
67 MemoryType::Conversation,
68 MemoryType::Session,
69 MemoryType::Fact,
70 MemoryType::Episode,
71 MemoryType::Knowledge,
72 ] {
73 if let Ok(names) = self.state_store.list_category(mt.category()).await {
74 for name in names {
75 if let Ok(Some(entry)) = self
76 .state_store
77 .load_json::<MemoryEntry>(mt.category(), &name)
78 .await
79 {
80 let vector = self.embedding.embed(&entry.content).await?;
81 entries_to_index.push((entry.id.clone(), vector));
82 }
83 }
84 }
85 }
86
87 {
89 let mut index = self.vector_index.write();
90 index.clear();
91 for (id, vector) in entries_to_index {
92 index.insert(id, vector);
93 }
94 }
95
96 tracing::info!(
97 entries = self.vector_index.read().len(),
98 "Memory vector index rebuilt"
99 );
100 Ok(())
101 }
102
103 pub async fn save_index_snapshot(&self) -> Result<()> {
105 let snapshot = {
106 let index = self.vector_index.read();
107 VectorIndexSnapshot {
108 created_at: chrono::Utc::now(),
109 entry_count: index.len(),
110 entries: index.clone(),
111 }
112 };
113
114 self.state_store
115 .save_json("memory", "vector_index_snapshot", &snapshot)
116 .await?;
117
118 self.git_commit("memory/vector_index_snapshot.json", "memory: snapshot save");
119
120 tracing::debug!(
121 entries = snapshot.entry_count,
122 "Vector index snapshot saved"
123 );
124 Ok(())
125 }
126
127 pub async fn load_index_snapshot(&self) -> Result<usize> {
129 let snapshot: Option<VectorIndexSnapshot> = self
130 .state_store
131 .load_json("memory", "vector_index_snapshot")
132 .await?;
133
134 match snapshot {
135 Some(snap) => {
136 let count = snap.entry_count;
137 let mut index = self.vector_index.write();
138 *index = snap.entries;
139 tracing::info!(entries = count, "Vector index snapshot loaded");
140 Ok(count)
141 }
142 None => {
143 tracing::debug!("No vector index snapshot found");
144 Ok(0)
145 }
146 }
147 }
148
149 pub async fn remember(&self, entry: MemoryEntry) -> Result<String> {
154 let id = entry.id.clone();
155 let vector = self.embedding.embed(&entry.content).await?;
156 let category = entry.memory_type.category();
157 self.state_store.save_json(category, &id, &entry).await?;
158
159 self.git_commit(
160 &format!("{}/{}.json", category, id),
161 &format!("memory: store {}", id),
162 );
163
164 {
166 let mut index = self.vector_index.write();
167 index.insert(id.clone(), vector.clone());
168 }
169
170 if let Some(f32_vec) = vector.to_f32_dense() {
172 let hnsw = self.hnsw_index.read();
173 if let Some(ref hnsw) = *hnsw {
174 if let Err(e) = hnsw.add_entry(&id, &f32_vec) {
175 tracing::warn!(id = %id, error = %e, "Failed to update HNSW index on remember");
176 }
177 }
178 }
179
180 tracing::debug!(id = %id, ty = entry.memory_type.label(), "Memory stored");
181 Ok(id)
182 }
183
184 pub async fn get(&self, id: &str, memory_type: MemoryType) -> Result<Option<MemoryEntry>> {
186 self.state_store.load_json(memory_type.category(), id).await
187 }
188
189 pub async fn forget(&self, id: &str, memory_type: MemoryType) -> Result<bool> {
191 let result = self
192 .state_store
193 .delete_file(memory_type.category(), id)
194 .await?;
195
196 {
198 let hnsw = self.hnsw_index.read();
199 if let Some(ref hnsw) = *hnsw {
200 if let Err(e) = hnsw.remove_entry(id) {
201 tracing::warn!(id = %id, error = %e, "Failed to remove from HNSW index on forget");
202 }
203 }
204 }
205
206 Ok(result)
207 }
208
209 pub async fn list(&self, memory_type: MemoryType, limit: usize) -> Result<Vec<MemoryEntry>> {
211 let category = memory_type.category();
212 let names = self.state_store.list_category(category).await?;
213 let mut entries = Vec::new();
214 for name in names.into_iter().take(limit * 2) {
215 if let Ok(Some(entry)) = self
216 .state_store
217 .load_json::<MemoryEntry>(category, &name)
218 .await
219 {
220 entries.push(entry);
221 }
222 }
223 entries.sort_by_key(|b| std::cmp::Reverse(b.created_at));
225 entries.truncate(limit);
226 Ok(entries)
227 }
228
229 pub async fn search(
234 &self,
235 query: &str,
236 memory_type: Option<MemoryType>,
237 limit: usize,
238 ) -> Result<Vec<MemoryEntry>> {
239 let query_vector = self.embedding.embed(query).await?;
240
241 let scored: Vec<(String, f64)> = {
243 let index = self.vector_index.read();
244 let mut scored: Vec<(String, f64)> = index
245 .iter()
246 .map(|(id, vector)| {
247 let score = query_vector.cosine_similarity(vector);
248 (id.clone(), score)
249 })
250 .filter(|(_, score)| *score > 0.1)
251 .collect();
252 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
253 scored.truncate(limit);
254 scored
255 }; if scored.is_empty() {
259 return self.keyword_search(query, memory_type, limit).await;
260 }
261
262 let all_types: &[MemoryType] = &[
264 MemoryType::Conversation,
265 MemoryType::Session,
266 MemoryType::Fact,
267 MemoryType::Episode,
268 MemoryType::Knowledge,
269 ];
270 let types: &[MemoryType] = match memory_type {
271 Some(ref t) => std::slice::from_ref(t),
272 None => all_types,
273 };
274
275 let mut results = Vec::new();
277 for (id, score) in scored {
278 for mt in types {
279 if let Ok(Some(mut entry)) = self
280 .state_store
281 .load_json::<MemoryEntry>(mt.category(), &id)
282 .await
283 {
284 entry.access_count += 1;
285 entry.accessed_at = chrono::Utc::now();
286 tracing::debug!(id = %id, score, "Vector search hit");
287 results.push(entry);
288 break;
289 }
290 }
291 }
292
293 if results.is_empty() {
295 return self.keyword_search(query, memory_type, limit).await;
296 }
297
298 Ok(results)
299 }
300
301 async fn keyword_search(
303 &self,
304 query: &str,
305 memory_type: Option<MemoryType>,
306 limit: usize,
307 ) -> Result<Vec<MemoryEntry>> {
308 let keywords = extract_keywords(query);
309 let types = match memory_type {
310 Some(t) => vec![t],
311 None => vec![
312 MemoryType::Conversation,
313 MemoryType::Fact,
314 MemoryType::Episode,
315 MemoryType::Knowledge,
316 ],
317 };
318
319 let mut results = Vec::new();
320 for ty in &types {
321 let entries = self.list(*ty, limit * 2).await?;
322 for entry in entries {
323 let matches = keywords.iter().any(|k| {
324 let k_lower = k.to_lowercase();
325 entry.content.to_lowercase().contains(&k_lower)
326 || entry
327 .tags
328 .iter()
329 .any(|t| t.to_lowercase().contains(&k_lower))
330 });
331 if matches {
332 results.push(entry);
333 }
334 }
335 }
336
337 results.sort_by(|a, b| {
338 b.importance
339 .partial_cmp(&a.importance)
340 .unwrap_or(std::cmp::Ordering::Equal)
341 });
342 results.truncate(limit);
343 Ok(results)
344 }
345
346 pub async fn recall(&self, query: &str) -> Result<Vec<MemoryEntry>> {
351 let limit = self.max_recall;
352
353 let recent = self
355 .list(MemoryType::Conversation, 3)
356 .await
357 .unwrap_or_default();
358
359 let sessions = self.list(MemoryType::Session, 2).await.unwrap_or_default();
361
362 let relevant = self.search(query, None, limit).await.unwrap_or_default();
364
365 let mut combined = recent;
367 combined.extend(sessions);
368 combined.extend(relevant);
369 dedup_by_id(&mut combined);
370 combined.truncate(limit);
371 Ok(combined)
372 }
373
374 pub fn blend_into_prompt(&self, memories: &[MemoryEntry], system_prompt: &str) -> String {
376 if memories.is_empty() {
377 return system_prompt.to_string();
378 }
379
380 let memory_block = memories
381 .iter()
382 .map(|m| format!("- [{}] {}", m.memory_type.label(), m.content))
383 .collect::<Vec<_>>()
384 .join("\n");
385
386 format!("{system_prompt}\n\n## Relevant Memory\n\n{memory_block}")
387 }
388
389 pub async fn summarize_session(
394 &self,
395 session: &crate::state_store::Session,
396 ) -> Result<Option<String>> {
397 if session.user_messages.is_empty() {
398 return Ok(None);
399 }
400
401 let mut summary_parts = Vec::new();
403
404 if let Some(first_msg) = session.user_messages.first() {
406 summary_parts.push(format!("User: {}", first_msg.content));
407 }
408
409 if let Some(last_response) = session.agent_responses.last() {
411 let truncated = if last_response.content.len() > 500 {
412 format!("{}...", &last_response.content[..500])
413 } else {
414 last_response.content.clone()
415 };
416 summary_parts.push(format!("Agent: {}", truncated));
417 }
418
419 if let Some(ref seed_id) = session.active_seed_id {
421 summary_parts.push(format!("Seed: {}", seed_id));
422 }
423 if let Some(ref persona_id) = session.active_persona_id {
424 summary_parts.push(format!("Persona: {}", persona_id));
425 }
426
427 let content = summary_parts.join("\n");
428 let entry = MemoryEntry {
429 id: format!(
430 "session-{}-{}",
431 session.id.0,
432 chrono::Utc::now().timestamp()
433 ),
434 memory_type: MemoryType::Session,
435 content,
436 source: "session_summary".to_string(),
437 session_id: Some(session.id.0.clone()),
438 tags: vec![],
439 importance: 0.6,
440 created_at: chrono::Utc::now(),
441 accessed_at: chrono::Utc::now(),
442 access_count: 0,
443 };
444
445 let id = self.remember(entry).await?;
446 Ok(Some(id))
447 }
448
449 pub async fn is_duplicate(&self, content: &str) -> bool {
453 let hash = content_hash(content);
454
455 let query_vector = match self.embedding.embed(content).await {
457 Ok(v) => v,
458 Err(_) => return false,
459 };
460 let similar = {
461 let index = self.vector_index.read();
462 index
463 .iter()
464 .any(|(_, vector)| query_vector.cosine_similarity(vector) > 0.95)
465 };
466 if similar {
467 return true;
468 }
469
470 for mt in &[
472 MemoryType::Conversation,
473 MemoryType::Session,
474 MemoryType::Fact,
475 MemoryType::Episode,
476 MemoryType::Knowledge,
477 ] {
478 if let Ok(entries) = self.list(*mt, 1000).await {
479 for entry in entries {
480 if content_hash(&entry.content) == hash {
481 return true;
482 }
483 }
484 }
485 }
486 false
487 }
488
489 pub async fn remember_unique(&self, entry: MemoryEntry) -> Result<Option<String>> {
493 if self.is_duplicate(&entry.content).await {
494 tracing::debug!(id = %entry.id, "Skipping duplicate memory");
495 return Ok(None);
496 }
497 let id = self.remember(entry).await?;
498 Ok(Some(id))
499 }
500}
501
/// One result of `MemoryManager::semantic_search`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticHit {
    /// The matched memory entry (as loaded, with its access counters bumped in memory).
    pub entry: MemoryEntry,
    /// Distance reported by the HNSW index; `0.0` for hits produced by the keyword fallback.
    pub distance: f32,
    /// `1.0 - distance` for HNSW hits; `0.0` for hits produced by the keyword fallback.
    pub similarity: f32,
}
516
/// HNSW-backed vector index over memory entries.
///
/// The underlying `HnswIndex` is keyed by `u64`, so string entry ids are mapped to
/// numeric keys through the two maps below.
pub struct HnswMemoryIndex {
    /// The HNSW graph holding L2-normalized vectors.
    index: RwLock<HnswIndex>,
    /// Reverse mapping: numeric key -> memory entry id.
    key_to_id: RwLock<HashMap<u64, String>>,
    /// Forward mapping: memory entry id -> numeric key.
    id_to_key: RwLock<HashMap<String, u64>>,
    /// Next numeric key to hand out (1 for a fresh index, max restored key + 1 otherwise).
    next_key: AtomicU64,
    /// Directory used by `persist` / `restore_or_new`; `None` disables persistence.
    persist_path: Option<PathBuf>,
}
533
534impl std::fmt::Debug for HnswMemoryIndex {
535 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
536 f.debug_struct("HnswMemoryIndex")
537 .field("size", &self.len())
538 .field("dimensions", &self.index.read().dimensions())
539 .finish()
540 }
541}
542
543impl HnswMemoryIndex {
544 pub fn new(dimensions: usize, capacity: usize, persist_path: Option<PathBuf>) -> Result<Self> {
551 let index = HnswIndex::new(dimensions, capacity)?;
552 Ok(Self {
553 index: RwLock::new(index),
554 key_to_id: RwLock::new(HashMap::new()),
555 id_to_key: RwLock::new(HashMap::new()),
556 next_key: AtomicU64::new(1), persist_path,
558 })
559 }
560
561 pub fn restore_or_new(
563 dimensions: usize,
564 capacity: usize,
565 persist_path: Option<PathBuf>,
566 ) -> Result<Self> {
567 if let Some(ref path) = persist_path {
568 let index_path = path.join("memory.usearch");
569 let mapping_path = path.join("key_map.json");
570
571 if index_path.exists() && mapping_path.exists() {
572 tracing::info!(path = %index_path.display(), "Restoring HNSW index from disk");
573
574 if let Ok(index) = HnswIndex::load(&index_path) {
575 if let Ok(data) = std::fs::read_to_string(&mapping_path) {
576 if let Ok((k2i, i2k)) = serde_json::from_str::<(
577 HashMap<u64, String>,
578 HashMap<String, u64>,
579 )>(&data)
580 {
581 let max_key = k2i.keys().max().copied().unwrap_or(0);
582 return Ok(Self {
583 index: RwLock::new(index),
584 key_to_id: RwLock::new(k2i),
585 id_to_key: RwLock::new(i2k),
586 next_key: AtomicU64::new(max_key + 1),
587 persist_path,
588 });
589 }
590 }
591 }
592
593 tracing::warn!("Failed to restore HNSW index, creating new one");
594 }
595 }
596
597 Self::new(dimensions, capacity, persist_path)
598 }
599
600 fn get_or_create_key(&self, id: &str) -> u64 {
602 {
604 let i2k = self.id_to_key.read();
605 if let Some(&key) = i2k.get(id) {
606 return key;
607 }
608 }
609
610 let mut i2k = self.id_to_key.write();
612 let mut k2i = self.key_to_id.write();
613
614 if let Some(&key) = i2k.get(id) {
616 return key;
617 }
618
619 let key = self.next_key.fetch_add(1, Ordering::Relaxed);
620 i2k.insert(id.to_string(), key);
621 k2i.insert(key, id.to_string());
622 key
623 }
624
625 pub fn add_entry(&self, id: &str, vector: &[f32]) -> Result<()> {
627 let key = self.get_or_create_key(id);
628 let mut normalized = vector.to_vec();
629 l2_normalize_f32(&mut normalized);
630 self.index.write().add(key, &normalized)?;
631 Ok(())
632 }
633
634 pub fn remove_entry(&self, id: &str) -> Result<()> {
636 let key = {
637 let i2k = self.id_to_key.read();
638 i2k.get(id).copied()
639 };
640 if let Some(key) = key {
641 self.index.write().remove(key)?;
642 let mut k2i = self.key_to_id.write();
643 let mut i2k = self.id_to_key.write();
644 k2i.remove(&key);
645 i2k.remove(id);
646 }
647 Ok(())
648 }
649
650 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
654 let mut normalized = query.to_vec();
655 l2_normalize_f32(&mut normalized);
656
657 let raw = self.index.read().search(&normalized, k)?;
658 let k2i = self.key_to_id.read();
659
660 let results = raw
661 .into_iter()
662 .filter_map(|(key, dist)| k2i.get(&key).map(|id| (id.clone(), dist)))
663 .collect();
664
665 Ok(results)
666 }
667
668 pub fn len(&self) -> usize {
670 self.index.read().len()
671 }
672
673 pub fn is_empty(&self) -> bool {
675 self.index.read().is_empty()
676 }
677
678 pub fn persist(&self) -> Result<()> {
680 if let Some(ref path) = self.persist_path {
681 std::fs::create_dir_all(path)?;
682
683 let index_path = path.join("memory.usearch");
684 let mapping_path = path.join("key_map.json");
685
686 self.index.read().save(&index_path)?;
688
689 let k2i = self.key_to_id.read();
691 let i2k = self.id_to_key.read();
692 let data = serde_json::to_string(&(k2i.clone(), &*i2k))?;
693 std::fs::write(&mapping_path, data)?;
694
695 tracing::debug!(path = %path.display(), entries = self.len(), "HNSW index persisted");
696 }
697 Ok(())
698 }
699}
700
701impl MemoryManager {
706 pub async fn semantic_search(
724 &self,
725 query: &str,
726 memory_type: Option<MemoryType>,
727 limit: usize,
728 hnsw_index: &HnswMemoryIndex,
729 ) -> Result<Vec<SemanticHit>> {
730 if hnsw_index.is_empty() {
732 tracing::debug!("HNSW index empty, falling back to keyword search");
733 return self
734 .keyword_search(query, memory_type, limit)
735 .await
736 .map(|entries| {
737 entries
738 .into_iter()
739 .map(|entry| SemanticHit {
740 entry,
741 distance: 0.0,
742 similarity: 0.0,
743 })
744 .collect()
745 });
746 }
747
748 let query_vector = self.embedding.embed(query).await?;
750 let query_f32 = match query_vector.to_f32_dense() {
751 Some(v) => v,
752 None => {
753 tracing::debug!("Query embedding is sparse, falling back to keyword search");
754 return self
755 .keyword_search(query, memory_type, limit)
756 .await
757 .map(|entries| {
758 entries
759 .into_iter()
760 .map(|entry| SemanticHit {
761 entry,
762 distance: 0.0,
763 similarity: 0.0,
764 })
765 .collect()
766 });
767 }
768 };
769
770 let raw_hits = hnsw_index.search(&query_f32, limit * 2)?;
772
773 let all_types: &[MemoryType] = &[
775 MemoryType::Conversation,
776 MemoryType::Session,
777 MemoryType::Fact,
778 MemoryType::Episode,
779 MemoryType::Knowledge,
780 ];
781 let types: &[MemoryType] = match memory_type {
782 Some(ref t) => std::slice::from_ref(t),
783 None => all_types,
784 };
785
786 let mut results = Vec::new();
788 for (id, distance) in raw_hits {
789 for mt in types {
790 if let Ok(Some(mut entry)) = self
791 .state_store
792 .load_json::<MemoryEntry>(mt.category(), &id)
793 .await
794 {
795 entry.access_count += 1;
797 entry.accessed_at = chrono::Utc::now();
798
799 let similarity = 1.0 - distance;
800 results.push(SemanticHit {
801 entry,
802 distance,
803 similarity,
804 });
805 break;
806 }
807 }
808 if results.len() >= limit {
809 break;
810 }
811 }
812
813 results.sort_by(|a, b| {
815 b.similarity
816 .partial_cmp(&a.similarity)
817 .unwrap_or(std::cmp::Ordering::Equal)
818 });
819
820 tracing::debug!(
821 query = %query,
822 hits = results.len(),
823 "Semantic search complete"
824 );
825
826 if results.is_empty() {
828 return self
829 .keyword_search(query, memory_type, limit)
830 .await
831 .map(|entries| {
832 entries
833 .into_iter()
834 .map(|entry| SemanticHit {
835 entry,
836 distance: 0.0,
837 similarity: 0.0,
838 })
839 .collect()
840 });
841 }
842
843 Ok(results)
844 }
845
846 pub async fn rebuild_hnsw_index(&self, hnsw_index: &HnswMemoryIndex) -> Result<usize> {
850 let mut count = 0;
851
852 for mt in &[
853 MemoryType::Conversation,
854 MemoryType::Session,
855 MemoryType::Fact,
856 MemoryType::Episode,
857 MemoryType::Knowledge,
858 ] {
859 if let Ok(names) = self.state_store.list_category(mt.category()).await {
860 for name in names {
861 if let Ok(Some(entry)) = self
862 .state_store
863 .load_json::<MemoryEntry>(mt.category(), &name)
864 .await
865 {
866 let vector = self.embedding.embed(&entry.content).await?;
867 if let Some(f32_vec) = vector.to_f32_dense() {
868 if let Err(e) = hnsw_index.add_entry(&entry.id, &f32_vec) {
869 tracing::warn!(
870 id = %entry.id,
871 error = %e,
872 "Failed to add entry to HNSW index"
873 );
874 continue;
875 }
876 count += 1;
877 }
878 }
879 }
880 }
881 }
882
883 tracing::info!(entries = count, "HNSW index rebuilt");
884 Ok(count)
885 }
886}