1use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::Arc;
10
11use anyhow::Result;
12use chrono::{DateTime, Utc};
13use parking_lot::RwLock;
14use serde::{Deserialize, Serialize};
15
16use crate::embedding::{EmbeddingProvider, EmbeddingVector, TfIdfEmbeddingProvider};
17use crate::git_layer::GitLayer;
18use crate::state_store::StateStore;
19
20pub use budget::{CurationCandidate, CurationReport, MemoryBudget};
22pub use store::HnswMemoryIndex;
23
24use std::collections::hash_map::DefaultHasher;
29use std::hash::{Hash, Hasher};
30
/// Computes a 64-bit hash of `content` with the standard library's
/// [`DefaultHasher`].
///
/// The same string always yields the same value within one build. The
/// standard library does not guarantee the algorithm behind `DefaultHasher`
/// across Rust releases, so treat persisted hashes with care.
pub fn content_hash(content: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(content, &mut state);
    Hasher::finish(&state)
}
37
/// Sparse term-frequency representation of a piece of text.
///
/// Built by [`TextVector::from_text`] and compared with
/// [`TextVector::cosine_similarity`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextVector {
    /// Maps each token to its relative frequency
    /// (occurrences divided by the total number of tokens in the text).
    tf: HashMap<String, f64>,
}
52
53impl TextVector {
54 pub fn from_text(text: &str) -> Self {
56 let mut tf: HashMap<String, f64> = HashMap::new();
57 let terms = Self::tokenize(text);
58 let total = terms.len() as f64;
59
60 for term in terms {
61 *tf.entry(term).or_insert(0.0) += 1.0;
62 }
63
64 if total > 0.0 {
66 for v in tf.values_mut() {
67 *v /= total;
68 }
69 }
70
71 Self { tf }
72 }
73
74 pub fn tokenize(text: &str) -> Vec<String> {
78 text.to_lowercase()
79 .split(|c: char| !c.is_alphanumeric() && !('\u{AC00}'..='\u{D7A3}').contains(&c))
80 .filter(|s| !s.is_empty() && s.len() > 1)
81 .map(|s| s.to_string())
82 .collect()
83 }
84
85 pub fn tf_map(&self) -> &HashMap<String, f64> {
87 &self.tf
88 }
89
90 pub fn cosine_similarity(&self, other: &TextVector) -> f64 {
92 let mut dot = 0.0;
93 let mut norm_a = 0.0;
94 let mut norm_b = 0.0;
95
96 for (term, &a) in &self.tf {
97 norm_a += a * a;
98 if let Some(&b) = other.tf.get(term) {
99 dot += a * b;
100 }
101 }
102 for &b in other.tf.values() {
103 norm_b += b * b;
104 }
105
106 if norm_a == 0.0 || norm_b == 0.0 {
107 return 0.0;
108 }
109
110 dot / (norm_a.sqrt() * norm_b.sqrt())
111 }
112}
113
/// Kind of memory an entry belongs to.
///
/// Serialized with snake_case variant names (e.g. `"conversation"`). Each
/// variant maps to a storage category via [`MemoryType::category`] and a
/// short label via [`MemoryType::label`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryType {
    /// Category `memory/conversations`, label `conversation`.
    Conversation,
    /// Category `memory/sessions`, label `session`.
    Session,
    /// Category `memory/facts`, label `fact`.
    Fact,
    /// Category `memory/episodes`, label `episode`.
    Episode,
    /// Category `memory/knowledge`, label `knowledge`.
    Knowledge,
}
133
134impl MemoryType {
135 pub fn category(&self) -> &'static str {
137 match self {
138 MemoryType::Conversation => "memory/conversations",
139 MemoryType::Session => "memory/sessions",
140 MemoryType::Fact => "memory/facts",
141 MemoryType::Episode => "memory/episodes",
142 MemoryType::Knowledge => "memory/knowledge",
143 }
144 }
145
146 pub fn label(&self) -> &'static str {
148 match self {
149 MemoryType::Conversation => "conversation",
150 MemoryType::Session => "session",
151 MemoryType::Fact => "fact",
152 MemoryType::Episode => "episode",
153 MemoryType::Knowledge => "knowledge",
154 }
155 }
156}
157
/// A single stored memory record.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Identifier of the entry; `dedup_by_id` treats entries with equal ids
    /// as duplicates.
    pub id: String,
    /// Kind of memory this entry represents.
    pub memory_type: MemoryType,
    /// Text content of the memory.
    pub content: String,
    /// Origin of the memory.
    // NOTE(review): the exact vocabulary (agent name, tool, "test", ...) is
    // not visible in this file; confirm against callers.
    pub source: String,
    /// Optional session identifier; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// Free-form tags; defaults to an empty list when absent on deserialization.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Base importance weight; defaults to `0.5` when absent on
    /// deserialization. Scaled by access count in
    /// `MemoryManager::effective_importance`.
    #[serde(default = "default_importance")]
    pub importance: f32,
    /// Creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
    /// Access timestamp (UTC).
    // NOTE(review): presumably refreshed on recall; confirm where it is updated.
    pub accessed_at: DateTime<Utc>,
    /// Access counter; defaults to `0` when absent on deserialization.
    #[serde(default)]
    pub access_count: u32,
}
186
/// Serde default for [`MemoryEntry::importance`] when the field is missing.
fn default_importance() -> f32 {
    const DEFAULT_IMPORTANCE: f32 = 0.5;
    DEFAULT_IMPORTANCE
}
190
/// Coordinates memory storage, the in-memory vector index and curation.
pub struct MemoryManager {
    /// Backing store the manager was created with.
    state_store: Arc<StateStore>,
    /// Recall limit; `new` sets it to 10 and `with_max_recall` /
    /// `with_config` override it.
    // NOTE(review): where the limit is enforced is not visible in this file.
    max_recall: usize,
    /// In-memory map from string key to embedding vector.
    // NOTE(review): keys look like entry ids (see `test_rebuild_index`); confirm.
    vector_index: RwLock<HashMap<String, EmbeddingVector>>,
    /// Embedding provider; `new` installs `TfIdfEmbeddingProvider`.
    embedding: Arc<dyn EmbeddingProvider>,
    /// Optional git layer used by `git_commit`; `None` until `set_git_layer`.
    git_layer: Option<Arc<GitLayer>>,
    /// Optional HNSW index; `None` until `set_hnsw_index` is called.
    hnsw_index: RwLock<Option<Arc<HnswMemoryIndex>>>,
}
212
213impl std::fmt::Debug for MemoryManager {
214 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215 f.debug_struct("MemoryManager")
216 .field("max_recall", &self.max_recall)
217 .field("index_size", &self.vector_index.read().len())
218 .finish()
219 }
220}
221
222impl MemoryManager {
223 pub fn new(state_store: Arc<StateStore>) -> Self {
225 Self {
226 state_store,
227 max_recall: 10,
228 vector_index: RwLock::new(HashMap::new()),
229 embedding: Arc::new(TfIdfEmbeddingProvider),
230 git_layer: None,
231 hnsw_index: RwLock::new(None),
232 }
233 }
234
235 pub fn set_git_layer(&mut self, gl: Arc<GitLayer>) {
237 self.git_layer = Some(gl);
238 }
239
240 pub fn for_space(space_dir: PathBuf) -> Self {
245 let memory_dir = space_dir.join("memory");
246 let state_store = Arc::new(StateStore::new(memory_dir).unwrap_or_else(|_| {
247 StateStore::new(std::env::temp_dir().join("oxios-memory")).unwrap()
249 }));
250 Self::new(state_store)
251 }
252
253 pub fn set_hnsw_index(&self, index: Arc<HnswMemoryIndex>) {
258 *self.hnsw_index.write() = Some(index);
259 }
260
261 fn git_commit(&self, rel_path: &str, message: &str) {
263 if let Some(ref gl) = self.git_layer {
264 if gl.is_enabled() {
265 let _ = gl.commit_file(rel_path, message);
266 }
267 }
268 }
269
270 pub fn with_max_recall(mut self, n: usize) -> Self {
272 self.max_recall = n;
273 self
274 }
275
276 pub fn with_config(mut self, config: &crate::config::MemoryConfig) -> Self {
278 self.max_recall = config.max_recall;
279 self
280 }
281
282 pub fn vector_index_size(&self) -> usize {
284 self.vector_index.read().len()
285 }
286
287 pub fn effective_importance(entry: &MemoryEntry) -> f32 {
292 let access_boost = (1.0_f32 + entry.access_count as f32).ln();
293 entry.importance * (1.0 + access_boost)
294 }
295
296 pub async fn curate(&self, budget: &MemoryBudget) -> Result<CurationReport> {
300 let mut report = CurationReport::default();
301
302 for mt in &[
303 MemoryType::Conversation,
304 MemoryType::Session,
305 MemoryType::Fact,
306 MemoryType::Episode,
307 MemoryType::Knowledge,
308 ] {
309 let entries = self.list(*mt, budget.max_per_type * 2).await?;
310 if entries.len() <= budget.max_per_type {
311 continue;
312 }
313
314 let total_count = entries.len();
316 let mut scored: Vec<_> = entries
317 .into_iter()
318 .map(|e| (e.id.clone(), e.memory_type, Self::effective_importance(&e)))
319 .collect();
320 scored.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
321
322 let to_remove = scored.len() - budget.max_per_type;
323 for (id, memory_type, score) in scored.into_iter().take(to_remove) {
324 report.candidates_for_removal.push(CurationCandidate {
325 id,
326 memory_type,
327 effective_importance: score,
328 });
329 }
330 report.total_before += total_count;
331 }
332
333 for candidate in &report.candidates_for_removal {
335 if self
336 .forget(&candidate.id, candidate.memory_type)
337 .await
338 .is_ok()
339 {
340 report.removed += 1;
341 }
342 }
343
344 report.total_after = report.total_before - report.removed;
345 Ok(report)
346 }
347
348 pub fn spawn_curation_task(self: &Arc<Self>, budget: MemoryBudget) {
352 let mgr = Arc::clone(self);
353 tokio::spawn(async move {
354 match mgr.curate(&budget).await {
355 Ok(report) => {
356 if report.removed > 0 {
357 tracing::info!(
358 removed = report.removed,
359 candidates = report.candidates_for_removal.len(),
360 "Memory curation complete"
361 );
362 }
363 }
364 Err(e) => {
365 tracing::warn!(error = %e, "Memory curation failed");
366 }
367 }
368 });
369 }
370}
371
/// Extracts lowercase search keywords from a free-text query.
///
/// The query is split on whitespace. ASCII punctuation is stripped from both
/// ends of every word, so quoted or parenthesized words such as `"agent"` or
/// `(memory)` yield clean keywords and quoted stop words are still filtered.
/// The word is then lowercased. Words of at most two bytes and English stop
/// words are dropped. Order and duplicates are preserved.
pub(crate) fn extract_keywords(query: &str) -> Vec<String> {
    const STOP_WORDS: &[&str] = &[
        "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had",
        "do", "does", "did", "will", "would", "could", "should", "may", "might", "can", "shall",
        "to", "of", "in", "for", "on", "with", "at", "by", "from", "as", "into", "through",
        "during", "before", "after", "above", "below", "between", "out", "off", "over", "under",
        "again", "further", "then", "once", "and", "but", "or", "nor", "not", "so", "yet", "both",
        "either", "neither", "each", "every", "all", "any", "few", "more", "most", "other", "some",
        "such", "no", "only", "own", "same", "than", "too", "very", "just", "because", "if",
        "when", "where", "how", "what", "which", "who", "whom", "this", "that", "these", "those",
        "i", "me", "my", "we", "our", "you", "your", "he", "him", "his", "she", "her", "it", "its",
        "they", "them", "their",
    ];

    query
        .split_whitespace()
        .map(|word| {
            // Trim both ends: leading quotes or brackets would otherwise stay
            // attached and defeat the stop-word comparison below.
            word.trim_matches(|c: char| c.is_ascii_punctuation())
                .to_lowercase()
        })
        // `len()` counts UTF-8 bytes, not characters.
        .filter(|word| word.len() > 2 && !STOP_WORDS.contains(&word.as_str()))
        .collect()
}
403
404pub(crate) fn dedup_by_id(entries: &mut Vec<MemoryEntry>) {
406 let mut seen = std::collections::HashSet::new();
407 entries.retain(|e| seen.insert(e.id.clone()));
408}
409
410pub mod auto_memory_bridge;
415mod budget;
416mod chunking;
417pub mod embedding_cache;
418pub mod flash_attention;
419mod graph;
420mod hnsw;
421pub mod hyperbolic;
422pub mod normalizer;
423pub(crate) mod store;
424
425pub use embedding_cache::{CacheStats, EmbeddingCache};
426pub use store::SemanticHit;
427
428pub use chunking::{chunk_fixed, chunk_paragraphs, ChunkConfig, TextChunk};
430pub use graph::MemoryGraph;
431pub use hnsw::HnswIndex;
432pub use normalizer::{cosine_similarity_f32, l2_normalize_f32, l2_normalize_f64};
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal entry whose content mentions `id`.
    fn make_entry(id: &str, ty: MemoryType) -> MemoryEntry {
        MemoryEntry {
            id: id.to_string(),
            memory_type: ty,
            content: format!("Test content for {}", id),
            source: "test".to_string(),
            session_id: None,
            tags: vec![],
            importance: 0.5,
            created_at: Utc::now(),
            accessed_at: Utc::now(),
            access_count: 0,
        }
    }

    /// Same as `make_entry`, but with explicit content.
    fn make_entry_with_content(id: &str, ty: MemoryType, content: &str) -> MemoryEntry {
        MemoryEntry {
            content: content.to_string(),
            ..make_entry(id, ty)
        }
    }

    /// Creates a manager over a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned so the directory stays alive for the
    /// whole test body; every test gets its own isolated store.
    fn temp_manager() -> (tempfile::TempDir, Arc<StateStore>, MemoryManager) {
        let temp_dir = tempfile::tempdir().unwrap();
        let store = Arc::new(StateStore::new(temp_dir.path().to_path_buf()).unwrap());
        let mgr = MemoryManager::new(Arc::clone(&store));
        (temp_dir, store, mgr)
    }

    #[test]
    fn test_memory_type_category() {
        assert_eq!(MemoryType::Conversation.category(), "memory/conversations");
        assert_eq!(MemoryType::Fact.category(), "memory/facts");
        assert_eq!(MemoryType::Knowledge.category(), "memory/knowledge");
    }

    #[test]
    fn test_extract_keywords() {
        let kw = extract_keywords("How do I implement a Rust agent system?");
        assert!(kw.contains(&"implement".to_string()));
        assert!(kw.contains(&"rust".to_string()));
        assert!(kw.contains(&"agent".to_string()));
        assert!(kw.contains(&"system".to_string()));
        // Stop words and very short words must be filtered out.
        assert!(!kw.contains(&"how".to_string()));
        assert!(!kw.contains(&"do".to_string()));
    }

    #[test]
    fn test_dedup_by_id() {
        let mut entries = vec![
            make_entry("a", MemoryType::Fact),
            make_entry("b", MemoryType::Fact),
            // Duplicate id with a different type: must be dropped.
            make_entry("a", MemoryType::Episode),
        ];
        dedup_by_id(&mut entries);
        assert_eq!(entries.len(), 2);
        // The first occurrence wins and order is preserved.
        assert_eq!(entries[0].id, "a");
        assert_eq!(entries[0].memory_type, MemoryType::Fact);
        assert_eq!(entries[1].id, "b");
    }

    #[test]
    fn test_blend_into_prompt_empty() {
        let (_temp_dir, _store, mgr) = temp_manager();
        let result = mgr.blend_into_prompt(&[], "You are an agent.");
        assert_eq!(result, "You are an agent.");
    }

    #[test]
    fn test_blend_into_prompt_with_memories() {
        let (_temp_dir, _store, mgr) = temp_manager();
        let memories = vec![make_entry("test", MemoryType::Fact)];
        let result = mgr.blend_into_prompt(&memories, "You are an agent.");
        assert!(result.contains("## Relevant Memory"));
        assert!(result.contains("[fact]"));
    }

    #[test]
    fn test_text_vector_cosine_similarity() {
        let v1 = TextVector::from_text("fix the null pointer error in main.rs");
        let v2 = TextVector::from_text("null pointer error found in rust code");
        let v3 = TextVector::from_text("update the documentation for deployment");

        assert!(
            v1.cosine_similarity(&v2) > 0.3,
            "Similar texts should have > 0.3 similarity"
        );

        assert!(
            v1.cosine_similarity(&v3) < 0.2,
            "Different texts should have < 0.2 similarity"
        );
    }

    #[test]
    fn test_text_vector_korean() {
        let v1 = TextVector::from_text("main.rs 파일의 null pointer 에러 수정");
        let v2 = TextVector::from_text("null pointer 오류를 수정했습니다");
        let v3 = TextVector::from_text("문서 업데이트 배포 가이드");

        assert!(v1.cosine_similarity(&v2) > 0.1, "Korean+code similarity");
        assert!(v1.cosine_similarity(&v3) < 0.1, "Korean different topics");
    }

    #[test]
    fn test_text_vector_empty() {
        let v1 = TextVector::from_text("");
        let v2 = TextVector::from_text("hello");
        assert_eq!(v1.cosine_similarity(&v2), 0.0);
    }

    #[test]
    fn test_text_vector_identical() {
        let v1 = TextVector::from_text("rust programming language");
        let v2 = TextVector::from_text("rust programming language");
        let sim = v1.cosine_similarity(&v2);
        assert!(
            (sim - 1.0).abs() < 1e-9,
            "Identical texts should have similarity ~1.0, got {}",
            sim
        );
    }

    #[test]
    fn test_tokenize_korean() {
        let terms = TextVector::tokenize("main.rs 파일의 버그를 수정");
        assert!(!terms.is_empty(), "Korean text should produce tokens");
        // Hangul words must survive as whole tokens next to the ASCII ones.
        assert_eq!(terms, vec!["main", "rs", "파일의", "버그를", "수정"]);
    }

    #[tokio::test]
    async fn test_vector_search_over_keyword_fallback() {
        let (_temp_dir, _store, mgr) = temp_manager();

        let entry1 = make_entry_with_content(
            "vec-test-1",
            MemoryType::Fact,
            "Rust is a systems programming language focused on safety",
        );
        let entry2 = make_entry_with_content(
            "vec-test-2",
            MemoryType::Fact,
            "Python is great for machine learning and data science",
        );

        mgr.remember(entry1).await.unwrap();
        mgr.remember(entry2).await.unwrap();

        let results = mgr
            .search("systems programming with rust", None, 5)
            .await
            .unwrap();
        assert!(!results.is_empty(), "Vector search should find results");
        assert_eq!(
            results[0].id, "vec-test-1",
            "Should find the Rust entry first"
        );
    }

    #[tokio::test]
    async fn test_rebuild_index() {
        let (_temp_dir, store, mgr) = temp_manager();

        // Write straight to the store so the manager's index is bypassed.
        let entry = make_entry_with_content(
            "rebuild-test-1",
            MemoryType::Fact,
            "memory for rebuild test",
        );
        store
            .save_json("memory/facts", "rebuild-test-1", &entry)
            .await
            .unwrap();

        // Nothing was indexed yet.
        assert_eq!(mgr.vector_index.read().len(), 0);

        mgr.rebuild_index().await.unwrap();

        // The rebuild must pick up the entry written directly to the store.
        assert_eq!(mgr.vector_index.read().len(), 1);
        assert!(mgr.vector_index.read().contains_key("rebuild-test-1"));
    }
}