1use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::Arc;
10
11use anyhow::Result;
12use chrono::{DateTime, Utc};
13use parking_lot::RwLock;
14use serde::{Deserialize, Serialize};
15
16use crate::embedding::{EmbeddingProvider, EmbeddingVector, TfIdfEmbeddingProvider};
17use crate::git_layer::GitLayer;
18use crate::state_store::StateStore;
19
20pub use budget::{CurationCandidate, CurationReport, MemoryBudget};
22pub use store::HnswMemoryIndex;
23
24use std::collections::hash_map::DefaultHasher;
29use std::hash::{Hash, Hasher};
30
/// Returns a 64-bit hash of `content` computed with std's [`DefaultHasher`].
///
/// Equal strings always hash equally within one build. NOTE(review): the `DefaultHasher`
/// algorithm is not guaranteed to be stable across Rust releases, so these values should not
/// be persisted or compared across builds unless that is acceptable to callers — confirm.
pub fn content_hash(content: &str) -> u64 {
    let mut state = DefaultHasher::default();
    Hash::hash(content, &mut state);
    state.finish()
}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct TextVector {
49 tf: HashMap<String, f64>,
51}
52
53impl TextVector {
54 pub fn from_text(text: &str) -> Self {
56 let mut tf: HashMap<String, f64> = HashMap::new();
57 let terms = Self::tokenize(text);
58 let total = terms.len() as f64;
59
60 for term in terms {
61 *tf.entry(term).or_insert(0.0) += 1.0;
62 }
63
64 if total > 0.0 {
66 for v in tf.values_mut() {
67 *v /= total;
68 }
69 }
70
71 Self { tf }
72 }
73
74 pub fn tokenize(text: &str) -> Vec<String> {
78 text.to_lowercase()
79 .split(|c: char| !c.is_alphanumeric() && !('\u{AC00}'..='\u{D7A3}').contains(&c))
80 .filter(|s| !s.is_empty() && s.len() > 1)
81 .map(|s| s.to_string())
82 .collect()
83 }
84
85 pub fn tf_map(&self) -> &HashMap<String, f64> {
87 &self.tf
88 }
89
90 pub fn cosine_similarity(&self, other: &TextVector) -> f64 {
92 let mut dot = 0.0;
93 let mut norm_a = 0.0;
94 let mut norm_b = 0.0;
95
96 for (term, &a) in &self.tf {
97 norm_a += a * a;
98 if let Some(&b) = other.tf.get(term) {
99 dot += a * b;
100 }
101 }
102 for &b in other.tf.values() {
103 norm_b += b * b;
104 }
105
106 if norm_a == 0.0 || norm_b == 0.0 {
107 return 0.0;
108 }
109
110 dot / (norm_a.sqrt() * norm_b.sqrt())
111 }
112}
113
114#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
120#[serde(rename_all = "snake_case")]
121pub enum MemoryType {
122 Conversation,
124 Session,
126 Fact,
128 Episode,
130 Knowledge,
132}
133
134impl MemoryType {
135 pub fn category(&self) -> &'static str {
137 match self {
138 MemoryType::Conversation => "memory/conversations",
139 MemoryType::Session => "memory/sessions",
140 MemoryType::Fact => "memory/facts",
141 MemoryType::Episode => "memory/episodes",
142 MemoryType::Knowledge => "memory/knowledge",
143 }
144 }
145
146 pub fn label(&self) -> &'static str {
148 match self {
149 MemoryType::Conversation => "conversation",
150 MemoryType::Session => "session",
151 MemoryType::Fact => "fact",
152 MemoryType::Episode => "episode",
153 MemoryType::Knowledge => "knowledge",
154 }
155 }
156}
157
/// A single memory record as stored and recalled by [`MemoryManager`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Identifier of the entry; `dedup_by_id` treats entries with equal ids as duplicates.
    pub id: String,
    /// Kind of memory; determines its storage category via `MemoryType::category`.
    pub memory_type: MemoryType,
    /// The remembered text.
    pub content: String,
    /// Free-form origin of the memory (unit tests use `"test"`).
    pub source: String,
    /// Optional session id; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// Free-form tags; defaults to empty when absent from serialized data.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Base importance score; defaults to 0.5 when absent. Combined with `access_count` in
    /// `MemoryManager::effective_importance`.
    #[serde(default = "default_importance")]
    pub importance: f32,
    /// Creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
    /// Last-access timestamp (UTC). NOTE(review): not updated in this file — presumably
    /// maintained on recall by the store module; confirm.
    pub accessed_at: DateTime<Utc>,
    /// Number of recorded accesses; defaults to 0 when absent. Boosts effective importance
    /// logarithmically during curation.
    #[serde(default)]
    pub access_count: u32,
}
186
/// Serde default for [`MemoryEntry::importance`] when the field is missing: 0.5.
fn default_importance() -> f32 {
    1.0 / 2.0
}
190
/// Manages memory entries persisted in a [`StateStore`], together with an in-memory
/// embedding index and optional git-backed commits and HNSW index.
pub struct MemoryManager {
    /// Persistent backing store for memory entries.
    state_store: Arc<StateStore>,
    /// Maximum number of memories to recall (10 by default; see `with_max_recall` /
    /// `with_config`).
    max_recall: usize,
    /// Embedding vectors keyed by memory id.
    vector_index: RwLock<HashMap<String, EmbeddingVector>>,
    /// Provider used to embed text; `TfIdfEmbeddingProvider` by default.
    embedding: Arc<dyn EmbeddingProvider>,
    /// Optional git layer used by `git_commit` for best-effort commits.
    git_layer: Option<Arc<GitLayer>>,
    /// Optional HNSW index. It is behind a lock so that `set_hnsw_index` can install it
    /// through `&self`.
    hnsw_index: RwLock<Option<Arc<HnswMemoryIndex>>>,
}
212
213impl std::fmt::Debug for MemoryManager {
214 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215 f.debug_struct("MemoryManager")
216 .field("max_recall", &self.max_recall)
217 .field("index_size", &self.vector_index.read().len())
218 .finish()
219 }
220}
221
222impl MemoryManager {
223 pub fn new(state_store: Arc<StateStore>) -> Self {
225 Self {
226 state_store,
227 max_recall: 10,
228 vector_index: RwLock::new(HashMap::new()),
229 embedding: Arc::new(TfIdfEmbeddingProvider),
230 git_layer: None,
231 hnsw_index: RwLock::new(None),
232 }
233 }
234
235 pub fn set_git_layer(&mut self, gl: Arc<GitLayer>) {
237 self.git_layer = Some(gl);
238 }
239
240 pub fn for_space(space_dir: PathBuf) -> Self {
245 let memory_dir = space_dir.join("memory");
246 let state_store = Arc::new(StateStore::new(memory_dir).unwrap_or_else(|_| {
247 StateStore::new(std::env::temp_dir().join("oxios-memory")).unwrap()
249 }));
250 Self::new(state_store)
251 }
252
253 pub fn set_hnsw_index(&self, index: Arc<HnswMemoryIndex>) {
258 *self.hnsw_index.write() = Some(index);
259 }
260
261 fn git_commit(&self, rel_path: &str, message: &str) {
263 if let Some(ref gl) = self.git_layer {
264 if gl.is_enabled() {
265 let _ = gl.commit_file(rel_path, message);
266 }
267 }
268 }
269
270 pub fn with_max_recall(mut self, n: usize) -> Self {
272 self.max_recall = n;
273 self
274 }
275
276 pub fn with_config(mut self, config: &crate::config::MemoryConfig) -> Self {
278 self.max_recall = config.max_recall;
279 self
280 }
281
282 pub fn vector_index_size(&self) -> usize {
284 self.vector_index.read().len()
285 }
286
287 pub fn effective_importance(entry: &MemoryEntry) -> f32 {
292 let access_boost = (1.0_f32 + entry.access_count as f32).ln();
293 entry.importance * (1.0 + access_boost)
294 }
295
296 pub async fn curate(&self, budget: &MemoryBudget) -> Result<CurationReport> {
300 let mut report = CurationReport::default();
301
302 for mt in &[
303 MemoryType::Conversation,
304 MemoryType::Session,
305 MemoryType::Fact,
306 MemoryType::Episode,
307 MemoryType::Knowledge,
308 ] {
309 let entries = self.list(*mt, budget.max_per_type * 2).await?;
310 if entries.len() <= budget.max_per_type {
311 continue;
312 }
313
314 let total_count = entries.len();
316 let mut scored: Vec<_> = entries
317 .into_iter()
318 .map(|e| (e.id.clone(), e.memory_type, Self::effective_importance(&e)))
319 .collect();
320 scored.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
321
322 let to_remove = scored.len() - budget.max_per_type;
323 for (id, memory_type, score) in scored.into_iter().take(to_remove) {
324 report.candidates_for_removal.push(CurationCandidate {
325 id,
326 memory_type,
327 effective_importance: score,
328 });
329 }
330 report.total_before += total_count;
331 }
332
333 for candidate in &report.candidates_for_removal {
335 if self
336 .forget(&candidate.id, candidate.memory_type)
337 .await
338 .is_ok()
339 {
340 report.removed += 1;
341 }
342 }
343
344 report.total_after = report.total_before - report.removed;
345 Ok(report)
346 }
347
348 pub fn spawn_curation_task(self: &Arc<Self>, budget: MemoryBudget) {
352 let mgr = Arc::clone(self);
353 tokio::spawn(async move {
354 match mgr.curate(&budget).await {
355 Ok(report) => {
356 if report.removed > 0 {
357 tracing::info!(
358 removed = report.removed,
359 candidates = report.candidates_for_removal.len(),
360 "Memory curation complete"
361 );
362 }
363 }
364 Err(e) => {
365 tracing::warn!(error = %e, "Memory curation failed");
366 }
367 }
368 });
369 }
370}
371
/// Extracts search keywords from a free-text query.
///
/// Steps:
/// - Split the query on whitespace.
/// - Trim ASCII punctuation from both ends of each word and lowercase it.
/// - Drop words in a small English stop-word list, and words whose UTF-8 length is 2 bytes
///   or less.
///
/// Order is preserved and duplicates are kept.
///
/// Note: the length check counts bytes, not characters. A two-syllable Hangul word (6 bytes)
/// therefore survives, while two-letter ASCII words such as "go" are dropped.
pub(crate) fn extract_keywords(query: &str) -> Vec<String> {
    const STOP_WORDS: &[&str] = &[
        "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had",
        "do", "does", "did", "will", "would", "could", "should", "may", "might", "can", "shall",
        "to", "of", "in", "for", "on", "with", "at", "by", "from", "as", "into", "through",
        "during", "before", "after", "above", "below", "between", "out", "off", "over", "under",
        "again", "further", "then", "once", "and", "but", "or", "nor", "not", "so", "yet", "both",
        "either", "neither", "each", "every", "all", "any", "few", "more", "most", "other", "some",
        "such", "no", "only", "own", "same", "than", "too", "very", "just", "because", "if",
        "when", "where", "how", "what", "which", "who", "whom", "this", "that", "these", "those",
        "i", "me", "my", "we", "our", "you", "your", "he", "him", "his", "she", "her", "it", "its",
        "they", "them", "their",
    ];

    query
        .split_whitespace()
        // Trim both ends so wrapped words like "(rust)" or "\"agent\"" match their bare form.
        .map(|w| w.trim_matches(|c: char| c.is_ascii_punctuation()).to_lowercase())
        .filter(|w| w.len() > 2 && !STOP_WORDS.contains(&w.as_str()))
        .collect()
}
403
404pub(crate) fn dedup_by_id(entries: &mut Vec<MemoryEntry>) {
406 let mut seen = std::collections::HashSet::new();
407 entries.retain(|e| seen.insert(e.id.clone()));
408}
409
410pub mod auto_memory_bridge;
415mod budget;
416mod chunking;
417pub mod flash_attention;
418mod graph;
419mod hnsw;
420pub mod hyperbolic;
421pub mod normalizer;
422pub(crate) mod store;
423
424pub use store::SemanticHit;
425
426pub use chunking::{chunk_fixed, chunk_paragraphs, ChunkConfig, TextChunk};
428pub use graph::MemoryGraph;
429pub use hnsw::HnswIndex;
430pub use normalizer::{cosine_similarity_f32, l2_normalize_f32, l2_normalize_f64};
431
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a `StateStore` in a fresh per-test temporary directory.
    ///
    /// Keep the returned `TempDir` alive while the store is in use; dropping it deletes the
    /// directory. A per-test dir avoids sharing one global temp path between tests and runs.
    fn temp_store() -> (tempfile::TempDir, Arc<StateStore>) {
        let dir = tempfile::tempdir().unwrap();
        let store = Arc::new(StateStore::new(dir.path().to_path_buf()).unwrap());
        (dir, store)
    }

    #[test]
    fn test_memory_type_category() {
        assert_eq!(MemoryType::Conversation.category(), "memory/conversations");
        assert_eq!(MemoryType::Fact.category(), "memory/facts");
        assert_eq!(MemoryType::Knowledge.category(), "memory/knowledge");
    }

    #[test]
    fn test_extract_keywords() {
        let kw = extract_keywords("How do I implement a Rust agent system?");
        assert!(kw.contains(&"implement".to_string()));
        assert!(kw.contains(&"rust".to_string()));
        assert!(kw.contains(&"agent".to_string()));
        assert!(kw.contains(&"system".to_string()));
        // Stop word.
        assert!(!kw.contains(&"how".to_string()));
        // Too short.
        assert!(!kw.contains(&"do".to_string()));
    }

    #[test]
    fn test_dedup_by_id() {
        let mut entries = vec![
            make_entry("a", MemoryType::Fact),
            make_entry("b", MemoryType::Fact),
            make_entry("a", MemoryType::Episode),
        ];
        dedup_by_id(&mut entries);
        let ids: Vec<&str> = entries.iter().map(|e| e.id.as_str()).collect();
        assert_eq!(ids, ["a", "b"]);
        // The first occurrence of a duplicated id is the one kept.
        assert_eq!(entries[0].memory_type, MemoryType::Fact);
    }

    #[test]
    fn test_blend_into_prompt_empty() {
        let (_dir, store) = temp_store();
        let mgr = MemoryManager::new(store);
        let result = mgr.blend_into_prompt(&[], "You are an agent.");
        assert_eq!(result, "You are an agent.");
    }

    #[test]
    fn test_blend_into_prompt_with_memories() {
        let (_dir, store) = temp_store();
        let mgr = MemoryManager::new(store);
        let memories = vec![make_entry("test", MemoryType::Fact)];
        let result = mgr.blend_into_prompt(&memories, "You are an agent.");
        assert!(result.contains("## Relevant Memory"));
        assert!(result.contains("[fact]"));
    }

    #[test]
    fn test_text_vector_cosine_similarity() {
        let v1 = TextVector::from_text("fix the null pointer error in main.rs");
        let v2 = TextVector::from_text("null pointer error found in rust code");
        let v3 = TextVector::from_text("update the documentation for deployment");

        assert!(
            v1.cosine_similarity(&v2) > 0.3,
            "Similar texts should have > 0.3 similarity"
        );

        assert!(
            v1.cosine_similarity(&v3) < 0.2,
            "Different texts should have < 0.2 similarity"
        );
    }

    #[test]
    fn test_text_vector_korean() {
        let v1 = TextVector::from_text("main.rs 파일의 null pointer 에러 수정");
        let v2 = TextVector::from_text("null pointer 오류를 수정했습니다");
        let v3 = TextVector::from_text("문서 업데이트 배포 가이드");

        assert!(v1.cosine_similarity(&v2) > 0.1, "Korean+code similarity");
        assert!(v1.cosine_similarity(&v3) < 0.1, "Korean different topics");
    }

    #[test]
    fn test_text_vector_empty() {
        let v1 = TextVector::from_text("");
        let v2 = TextVector::from_text("hello");
        assert_eq!(v1.cosine_similarity(&v2), 0.0);
    }

    #[test]
    fn test_text_vector_identical() {
        let v1 = TextVector::from_text("rust programming language");
        let v2 = TextVector::from_text("rust programming language");
        let sim = v1.cosine_similarity(&v2);
        assert!(
            (sim - 1.0).abs() < 1e-9,
            "Identical texts should have similarity ~1.0, got {}",
            sim
        );
    }

    #[test]
    fn test_tokenize_korean() {
        let terms = TextVector::tokenize("main.rs 파일의 버그를 수정");
        assert!(!terms.is_empty(), "Korean text should produce tokens");
    }

    #[tokio::test]
    async fn test_vector_search_over_keyword_fallback() {
        let (_dir, store) = temp_store();
        let mgr = MemoryManager::new(Arc::clone(&store));

        let entry1 = MemoryEntry {
            content: "Rust is a systems programming language focused on safety".to_string(),
            ..make_entry("vec-test-1", MemoryType::Fact)
        };
        let entry2 = MemoryEntry {
            content: "Python is great for machine learning and data science".to_string(),
            ..make_entry("vec-test-2", MemoryType::Fact)
        };

        mgr.remember(entry1).await.unwrap();
        mgr.remember(entry2).await.unwrap();

        let results = mgr
            .search("systems programming with rust", None, 5)
            .await
            .unwrap();
        assert!(!results.is_empty(), "Vector search should find results");
        assert_eq!(
            results[0].id, "vec-test-1",
            "Should find the Rust entry first"
        );
    }

    #[tokio::test]
    async fn test_rebuild_index() {
        let (_dir, store) = temp_store();
        let mgr = MemoryManager::new(Arc::clone(&store));

        let entry = MemoryEntry {
            content: "memory for rebuild test".to_string(),
            ..make_entry("rebuild-test-1", MemoryType::Fact)
        };
        // Write directly to the store, bypassing the manager, so the index starts empty.
        store
            .save_json("memory/facts", "rebuild-test-1", &entry)
            .await
            .unwrap();

        assert_eq!(mgr.vector_index.read().len(), 0);

        mgr.rebuild_index().await.unwrap();

        assert_eq!(mgr.vector_index.read().len(), 1);
        assert!(mgr.vector_index.read().contains_key("rebuild-test-1"));
    }

    /// Builds a `MemoryEntry` with the given id and type, plus neutral defaults.
    fn make_entry(id: &str, ty: MemoryType) -> MemoryEntry {
        MemoryEntry {
            id: id.to_string(),
            memory_type: ty,
            content: format!("Test content for {}", id),
            source: "test".to_string(),
            session_id: None,
            tags: vec![],
            importance: 0.5,
            created_at: Utc::now(),
            accessed_at: Utc::now(),
            access_count: 0,
        }
    }
}