Skip to main content

oxios_kernel/memory/
mod.rs

1//! Agent memory system.
2//!
3//! Provides persistent memory for agents across sessions.
4//! Memory entries are stored as JSON files via StateStore.
5//! Supports embedding-based vector search using TF-IDF + cosine similarity.
6
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::Arc;
10
11use anyhow::Result;
12use chrono::{DateTime, Utc};
13use parking_lot::RwLock;
14use serde::{Deserialize, Serialize};
15
16use crate::embedding::{EmbeddingProvider, EmbeddingVector, TfIdfEmbeddingProvider};
17use crate::git_layer::GitLayer;
18use crate::state_store::StateStore;
19
20// Re-export budget types so external `use crate::memory::X` paths still work.
21pub use budget::{CurationCandidate, CurationReport, MemoryBudget};
22pub use store::HnswMemoryIndex;
23
24// ---------------------------------------------------------------------------
25// Content hashing
26// ---------------------------------------------------------------------------
27
28use std::collections::hash_map::DefaultHasher;
29use std::hash::{Hash, Hasher};
30
/// Hash `content` into a 64-bit fingerprint used for deduplication.
///
/// Every call starts from a freshly constructed [`DefaultHasher`], so equal
/// strings yield equal fingerprints within one build of the program.
///
/// NOTE(review): the standard library does not specify `DefaultHasher`'s
/// algorithm across Rust releases — confirm these values are never persisted
/// and compared across differently compiled binaries.
pub fn content_hash(content: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(content, &mut state);
    Hasher::finish(&state)
}
37
38// ---------------------------------------------------------------------------
39// TextVector (TF-IDF vector for semantic similarity)
40// ---------------------------------------------------------------------------
41
/// Simple term-frequency vector for text similarity.
///
/// Tokenizes text into terms, computes normalized term frequency,
/// and supports cosine similarity comparison. No external embedding
/// model needed — works for any language including Korean.
///
/// Only per-document term frequencies are stored in this type; it holds no
/// corpus-wide (IDF) statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextVector {
    /// Term → frequency, where each frequency is the term's occurrence count
    /// divided by the total number of tokens in the source text.
    tf: HashMap<String, f64>,
}
52
53impl TextVector {
54    /// Create a text vector from input text.
55    pub fn from_text(text: &str) -> Self {
56        let mut tf: HashMap<String, f64> = HashMap::new();
57        let terms = Self::tokenize(text);
58        let total = terms.len() as f64;
59
60        for term in terms {
61            *tf.entry(term).or_insert(0.0) += 1.0;
62        }
63
64        // Normalize by total term count
65        if total > 0.0 {
66            for v in tf.values_mut() {
67                *v /= total;
68            }
69        }
70
71        Self { tf }
72    }
73
74    /// Tokenize text into terms (language-agnostic).
75    /// Splits on whitespace and punctuation, lowercases.
76    /// Preserves Korean Hangul syllables (U+AC00–U+D7A3) within tokens.
77    pub fn tokenize(text: &str) -> Vec<String> {
78        text.to_lowercase()
79            .split(|c: char| !c.is_alphanumeric() && !('\u{AC00}'..='\u{D7A3}').contains(&c))
80            .filter(|s| !s.is_empty() && s.len() > 1)
81            .map(|s| s.to_string())
82            .collect()
83    }
84
85    /// Returns a reference to the term-frequency map.
86    pub fn tf_map(&self) -> &HashMap<String, f64> {
87        &self.tf
88    }
89
90    /// Compute cosine similarity between two vectors.
91    pub fn cosine_similarity(&self, other: &TextVector) -> f64 {
92        let mut dot = 0.0;
93        let mut norm_a = 0.0;
94        let mut norm_b = 0.0;
95
96        for (term, &a) in &self.tf {
97            norm_a += a * a;
98            if let Some(&b) = other.tf.get(term) {
99                dot += a * b;
100            }
101        }
102        for &b in other.tf.values() {
103            norm_b += b * b;
104        }
105
106        if norm_a == 0.0 || norm_b == 0.0 {
107            return 0.0;
108        }
109
110        dot / (norm_a.sqrt() * norm_b.sqrt())
111    }
112}
113
114// ---------------------------------------------------------------------------
115// Types
116// ---------------------------------------------------------------------------
117
/// Memory entry type.
///
/// Serialized in `snake_case` (e.g. `"conversation"`), per the serde
/// `rename_all` attribute below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryType {
    /// Conversation compaction summary (auto-generated).
    Conversation,
    /// Session-end summary (auto-generated).
    Session,
    /// Agent-stored fact.
    Fact,
    /// Episode memory (event/experience).
    Episode,
    /// Static knowledge (user/program-provided).
    Knowledge,
}
133
134impl MemoryType {
135    /// Category name used in StateStore.
136    pub fn category(&self) -> &'static str {
137        match self {
138            MemoryType::Conversation => "memory/conversations",
139            MemoryType::Session => "memory/sessions",
140            MemoryType::Fact => "memory/facts",
141            MemoryType::Episode => "memory/episodes",
142            MemoryType::Knowledge => "memory/knowledge",
143        }
144    }
145
146    /// Human-readable label.
147    pub fn label(&self) -> &'static str {
148        match self {
149            MemoryType::Conversation => "conversation",
150            MemoryType::Session => "session",
151            MemoryType::Fact => "fact",
152            MemoryType::Episode => "episode",
153            MemoryType::Knowledge => "knowledge",
154        }
155    }
156}
157
/// A single memory entry.
///
/// Serializable with serde; several fields fall back to defaults when they
/// are absent from the serialized form (see the per-field notes).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Unique ID.
    pub id: String,
    /// Memory type.
    pub memory_type: MemoryType,
    /// Content (Markdown).
    pub content: String,
    /// Creator (agent name, "compaction", "system", etc.).
    pub source: String,
    /// Related session ID. Omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// Tags for search. Defaults to an empty list when missing.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Importance (0.0 – 1.0). Defaults to `default_importance()` (0.5) when
    /// missing.
    #[serde(default = "default_importance")]
    pub importance: f32,
    /// Creation timestamp.
    pub created_at: DateTime<Utc>,
    /// Last access timestamp.
    pub accessed_at: DateTime<Utc>,
    /// Access count. Defaults to 0 when missing.
    #[serde(default)]
    pub access_count: u32,
}
186
/// Serde default for `MemoryEntry::importance`: the midpoint of the
/// documented 0.0 – 1.0 range.
fn default_importance() -> f32 {
    const NEUTRAL_IMPORTANCE: f32 = 0.5;
    NEUTRAL_IMPORTANCE
}
190
191// ---------------------------------------------------------------------------
192// MemoryManager
193// ---------------------------------------------------------------------------
194
/// Agent memory manager.
///
/// Stores and retrieves memory entries using the file-based StateStore.
/// Supports embedding-based vector search via an in-memory TF-IDF index
/// that is rebuilt on startup.
///
/// The vector and HNSW indexes sit behind `RwLock`s so they can be updated
/// through a shared `&self` reference.
pub struct MemoryManager {
    /// Backing store for persisted memory entries.
    state_store: Arc<StateStore>,
    /// Maximum number of memories returned by recall.
    max_recall: usize,
    /// Vector index for semantic search (id → EmbeddingVector).
    vector_index: RwLock<HashMap<String, EmbeddingVector>>,
    /// Embedding provider for generating vectors.
    embedding: Arc<dyn EmbeddingProvider>,
    /// Optional git layer for version-controlled memory.
    git_layer: Option<Arc<GitLayer>>,
    /// Optional HNSW index for fast ANN search.
    hnsw_index: RwLock<Option<Arc<HnswMemoryIndex>>>,
}
212
213impl std::fmt::Debug for MemoryManager {
214    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215        f.debug_struct("MemoryManager")
216            .field("max_recall", &self.max_recall)
217            .field("index_size", &self.vector_index.read().len())
218            .finish()
219    }
220}
221
222impl MemoryManager {
223    /// Create a new MemoryManager.
224    pub fn new(state_store: Arc<StateStore>) -> Self {
225        Self {
226            state_store,
227            max_recall: 10,
228            vector_index: RwLock::new(HashMap::new()),
229            embedding: Arc::new(TfIdfEmbeddingProvider),
230            git_layer: None,
231            hnsw_index: RwLock::new(None),
232        }
233    }
234
    /// Attach a git layer for version-controlled saves.
    ///
    /// Replaces any previously attached layer. Requires exclusive access
    /// (`&mut self`), unlike `set_hnsw_index`.
    pub fn set_git_layer(&mut self, gl: Arc<GitLayer>) {
        self.git_layer = Some(gl);
    }
239
240    /// Create a Space-scoped MemoryManager.
241    ///
242    /// Each Space gets its own StateStore under the given directory,
243    /// providing natural memory isolation between Spaces.
244    pub fn for_space(space_dir: PathBuf) -> Self {
245        let memory_dir = space_dir.join("memory");
246        let state_store = Arc::new(StateStore::new(memory_dir).unwrap_or_else(|_| {
247            // Fallback: create in temp dir
248            StateStore::new(std::env::temp_dir().join("oxios-memory")).unwrap()
249        }));
250        Self::new(state_store)
251    }
252
253    /// Attach an HNSW index for fast semantic search.
254    ///
255    /// Once attached, `remember()` and `forget()` automatically keep
256    /// the HNSW index in sync with the state store.
257    pub fn set_hnsw_index(&self, index: Arc<HnswMemoryIndex>) {
258        *self.hnsw_index.write() = Some(index);
259    }
260
261    /// Commit a file to git if git_layer is available.
262    fn git_commit(&self, rel_path: &str, message: &str) {
263        if let Some(ref gl) = self.git_layer {
264            if gl.is_enabled() {
265                let _ = gl.commit_file(rel_path, message);
266            }
267        }
268    }
269
    /// Set max memories returned by recall.
    ///
    /// Builder-style: consumes the manager and returns it with `max_recall`
    /// set to `n`.
    pub fn with_max_recall(mut self, n: usize) -> Self {
        self.max_recall = n;
        self
    }
275
    /// Apply MemoryConfig settings.
    ///
    /// Builder-style: consumes the manager and returns it. The only setting
    /// read from `config` here is `max_recall`.
    pub fn with_config(mut self, config: &crate::config::MemoryConfig) -> Self {
        self.max_recall = config.max_recall;
        self
    }
281
282    /// Returns the number of entries in the vector index.
283    pub fn vector_index_size(&self) -> usize {
284        self.vector_index.read().len()
285    }
286
287    /// Compute effective importance of a memory entry.
288    ///
289    /// Effective importance = base_importance * (1 + log(1 + access_count))
290    /// Memories accessed frequently get a boost.
291    pub fn effective_importance(entry: &MemoryEntry) -> f32 {
292        let access_boost = (1.0_f32 + entry.access_count as f32).ln();
293        entry.importance * (1.0 + access_boost)
294    }
295
296    /// Curate memories: identify candidates for removal based on budget.
297    ///
298    /// Returns a report of how many entries would be pruned per type.
299    pub async fn curate(&self, budget: &MemoryBudget) -> Result<CurationReport> {
300        let mut report = CurationReport::default();
301
302        for mt in &[
303            MemoryType::Conversation,
304            MemoryType::Session,
305            MemoryType::Fact,
306            MemoryType::Episode,
307            MemoryType::Knowledge,
308        ] {
309            let entries = self.list(*mt, budget.max_per_type * 2).await?;
310            if entries.len() <= budget.max_per_type {
311                continue;
312            }
313
314            // Sort by effective importance ascending (least important first)
315            let total_count = entries.len();
316            let mut scored: Vec<_> = entries
317                .into_iter()
318                .map(|e| (e.id.clone(), e.memory_type, Self::effective_importance(&e)))
319                .collect();
320            scored.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
321
322            let to_remove = scored.len() - budget.max_per_type;
323            for (id, memory_type, score) in scored.into_iter().take(to_remove) {
324                report.candidates_for_removal.push(CurationCandidate {
325                    id,
326                    memory_type,
327                    effective_importance: score,
328                });
329            }
330            report.total_before += total_count;
331        }
332
333        // Actually remove candidates
334        for candidate in &report.candidates_for_removal {
335            if self
336                .forget(&candidate.id, candidate.memory_type)
337                .await
338                .is_ok()
339            {
340                report.removed += 1;
341            }
342        }
343
344        report.total_after = report.total_before - report.removed;
345        Ok(report)
346    }
347
348    /// Spawn a background curation task.
349    ///
350    /// Returns immediately; curation runs asynchronously.
351    pub fn spawn_curation_task(self: &Arc<Self>, budget: MemoryBudget) {
352        let mgr = Arc::clone(self);
353        tokio::spawn(async move {
354            match mgr.curate(&budget).await {
355                Ok(report) => {
356                    if report.removed > 0 {
357                        tracing::info!(
358                            removed = report.removed,
359                            candidates = report.candidates_for_removal.len(),
360                            "Memory curation complete"
361                        );
362                    }
363                }
364                Err(e) => {
365                    tracing::warn!(error = %e, "Memory curation failed");
366                }
367            }
368        });
369    }
370}
371
372// ---------------------------------------------------------------------------
373// Helpers
374// ---------------------------------------------------------------------------
375
/// Extract search keywords from a query string.
///
/// The query is split on whitespace. Each word then has ASCII punctuation
/// stripped from both ends (so `"rust"`, `(memory)` and `system?` become
/// `rust`, `memory` and `system`) and is lowercased. Punctuation inside a
/// word is kept.
///
/// A word is dropped when it is an English stop word or when its UTF-8
/// length is two bytes or less. The length check is on bytes, so a single
/// multi-byte character can still pass.
///
/// Keywords are returned in query order; duplicates are not removed.
pub(crate) fn extract_keywords(query: &str) -> Vec<String> {
    const STOP_WORDS: &[&str] = &[
        "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had",
        "do", "does", "did", "will", "would", "could", "should", "may", "might", "can", "shall",
        "to", "of", "in", "for", "on", "with", "at", "by", "from", "as", "into", "through",
        "during", "before", "after", "above", "below", "between", "out", "off", "over", "under",
        "again", "further", "then", "once", "and", "but", "or", "nor", "not", "so", "yet", "both",
        "either", "neither", "each", "every", "all", "any", "few", "more", "most", "other", "some",
        "such", "no", "only", "own", "same", "than", "too", "very", "just", "because", "if",
        "when", "where", "how", "what", "which", "who", "whom", "this", "that", "these", "those",
        "i", "me", "my", "we", "our", "you", "your", "he", "him", "his", "she", "her", "it", "its",
        "they", "them", "their",
    ];

    query
        .split_whitespace()
        .map(|word| {
            // Strip surrounding punctuation: quotes and brackets at the front
            // matter as much as a trailing '?' or ','.
            word.trim_matches(|c: char| c.is_ascii_punctuation())
                .to_lowercase()
        })
        .filter(|word| word.len() > 2 && !STOP_WORDS.contains(&word.as_str()))
        .collect()
}
403
404/// Remove duplicate entries by ID, keeping the first occurrence.
405pub(crate) fn dedup_by_id(entries: &mut Vec<MemoryEntry>) {
406    let mut seen = std::collections::HashSet::new();
407    entries.retain(|e| seen.insert(e.id.clone()));
408}
409
410// ---------------------------------------------------------------------------
411// Sub-modules
412// ---------------------------------------------------------------------------
413
414pub mod auto_memory_bridge;
415mod budget;
416mod chunking;
417pub mod flash_attention;
418mod graph;
419mod hnsw;
420pub mod hyperbolic;
421pub mod normalizer;
422pub(crate) mod store;
423
424pub use store::SemanticHit;
425
426// Re-export key types from sub-modules.
427pub use chunking::{chunk_fixed, chunk_paragraphs, ChunkConfig, TextChunk};
428pub use graph::MemoryGraph;
429pub use hnsw::HnswIndex;
430pub use normalizer::{cosine_similarity_f32, l2_normalize_f32, l2_normalize_f64};
431
432// ---------------------------------------------------------------------------
433// Tests
434// ---------------------------------------------------------------------------
435
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Helpers ----

    /// Build a minimal entry whose content is derived from `id`.
    fn make_entry(id: &str, ty: MemoryType) -> MemoryEntry {
        MemoryEntry {
            id: id.to_string(),
            memory_type: ty,
            content: format!("Test content for {}", id),
            source: "test".to_string(),
            session_id: None,
            tags: vec![],
            importance: 0.5,
            created_at: Utc::now(),
            accessed_at: Utc::now(),
            access_count: 0,
        }
    }

    /// Build a `Fact` entry with explicit content.
    fn make_fact(id: &str, content: &str) -> MemoryEntry {
        MemoryEntry {
            content: content.to_string(),
            ..make_entry(id, MemoryType::Fact)
        }
    }

    /// Create a store and manager rooted in `dir`.
    ///
    /// Every test passes its own `tempfile::TempDir`, so tests never share
    /// on-disk state.
    fn manager_in(dir: &tempfile::TempDir) -> (Arc<StateStore>, MemoryManager) {
        let store = Arc::new(StateStore::new(dir.path().to_path_buf()).unwrap());
        let mgr = MemoryManager::new(store.clone());
        (store, mgr)
    }

    // ---- Type and helper tests ----

    #[test]
    fn test_memory_type_category() {
        assert_eq!(MemoryType::Conversation.category(), "memory/conversations");
        assert_eq!(MemoryType::Fact.category(), "memory/facts");
        assert_eq!(MemoryType::Knowledge.category(), "memory/knowledge");
    }

    #[test]
    fn test_extract_keywords() {
        let kw = extract_keywords("How do I implement a Rust agent system?");
        assert!(kw.contains(&"implement".to_string()));
        assert!(kw.contains(&"rust".to_string()));
        assert!(kw.contains(&"agent".to_string()));
        assert!(kw.contains(&"system".to_string()));
        // stop words filtered
        assert!(!kw.contains(&"how".to_string()));
        assert!(!kw.contains(&"do".to_string()));
    }

    #[test]
    fn test_dedup_by_id() {
        let mut entries = vec![
            make_entry("a", MemoryType::Fact),
            make_entry("b", MemoryType::Fact),
            make_entry("a", MemoryType::Episode), // duplicate id
        ];
        dedup_by_id(&mut entries);
        assert_eq!(entries.len(), 2);
        // The first occurrence of each id survives, in original order.
        assert_eq!(entries[0].id, "a");
        assert_eq!(entries[0].memory_type, MemoryType::Fact);
        assert_eq!(entries[1].id, "b");
    }

    // ---- Prompt blending tests ----

    #[test]
    fn test_blend_into_prompt_empty() {
        let temp_dir = tempfile::tempdir().unwrap();
        let (_store, mgr) = manager_in(&temp_dir);
        let result = mgr.blend_into_prompt(&[], "You are an agent.");
        assert_eq!(result, "You are an agent.");
    }

    #[test]
    fn test_blend_into_prompt_with_memories() {
        let temp_dir = tempfile::tempdir().unwrap();
        let (_store, mgr) = manager_in(&temp_dir);
        let memories = vec![make_entry("test", MemoryType::Fact)];
        let result = mgr.blend_into_prompt(&memories, "You are an agent.");
        assert!(result.contains("## Relevant Memory"));
        assert!(result.contains("[fact]"));
    }

    // ---- Vector search tests ----

    #[test]
    fn test_text_vector_cosine_similarity() {
        let v1 = TextVector::from_text("fix the null pointer error in main.rs");
        let v2 = TextVector::from_text("null pointer error found in rust code");
        let v3 = TextVector::from_text("update the documentation for deployment");

        // Similar texts should have high similarity
        assert!(
            v1.cosine_similarity(&v2) > 0.3,
            "Similar texts should have > 0.3 similarity"
        );

        // Different texts should have low similarity
        assert!(
            v1.cosine_similarity(&v3) < 0.2,
            "Different texts should have < 0.2 similarity"
        );
    }

    #[test]
    fn test_text_vector_korean() {
        let v1 = TextVector::from_text("main.rs 파일의 null pointer 에러 수정");
        let v2 = TextVector::from_text("null pointer 오류를 수정했습니다");
        let v3 = TextVector::from_text("문서 업데이트 배포 가이드");

        assert!(v1.cosine_similarity(&v2) > 0.1, "Korean+code similarity");
        assert!(v1.cosine_similarity(&v3) < 0.1, "Korean different topics");
    }

    #[test]
    fn test_text_vector_empty() {
        let v1 = TextVector::from_text("");
        let v2 = TextVector::from_text("hello");
        assert_eq!(v1.cosine_similarity(&v2), 0.0);
    }

    #[test]
    fn test_text_vector_identical() {
        let v1 = TextVector::from_text("rust programming language");
        let v2 = TextVector::from_text("rust programming language");
        let sim = v1.cosine_similarity(&v2);
        assert!(
            (sim - 1.0).abs() < 1e-9,
            "Identical texts should have similarity ~1.0, got {}",
            sim
        );
    }

    #[test]
    fn test_tokenize_korean() {
        let terms = TextVector::tokenize("main.rs 파일의 버그를 수정");
        assert!(!terms.is_empty(), "Korean text should produce tokens");
        // Hangul words must survive tokenization intact, next to ASCII ones.
        assert!(terms.contains(&"main".to_string()));
        assert!(terms.contains(&"파일의".to_string()));
        assert!(terms.contains(&"수정".to_string()));
    }

    #[tokio::test]
    async fn test_vector_search_over_keyword_fallback() {
        let temp_dir = tempfile::tempdir().unwrap();
        let (_store, mgr) = manager_in(&temp_dir);

        // Store some memories
        let entry1 = make_fact(
            "vec-test-1",
            "Rust is a systems programming language focused on safety",
        );
        let entry2 = make_fact(
            "vec-test-2",
            "Python is great for machine learning and data science",
        );

        mgr.remember(entry1).await.unwrap();
        mgr.remember(entry2).await.unwrap();

        // Vector search should find the Rust entry for a Rust-related query
        let results = mgr
            .search("systems programming with rust", None, 5)
            .await
            .unwrap();
        assert!(!results.is_empty(), "Vector search should find results");
        assert_eq!(
            results[0].id, "vec-test-1",
            "Should find the Rust entry first"
        );
    }

    #[tokio::test]
    async fn test_rebuild_index() {
        let temp_dir = tempfile::tempdir().unwrap();
        let (store, mgr) = manager_in(&temp_dir);

        // Store a memory directly via state_store (bypassing remember to test rebuild)
        let entry = make_fact("rebuild-test-1", "memory for rebuild test");
        store
            .save_json("memory/facts", "rebuild-test-1", &entry)
            .await
            .unwrap();

        // Index should be empty before rebuild
        assert_eq!(mgr.vector_index.read().len(), 0);

        // Rebuild
        mgr.rebuild_index().await.unwrap();

        // Index should now contain the entry
        assert_eq!(mgr.vector_index.read().len(), 1);
        assert!(mgr.vector_index.read().contains_key("rebuild-test-1"));
    }
}