Skip to main content

oxios_kernel/memory/
mod.rs

1//! Agent memory system.
2//!
3//! Provides persistent memory for agents across sessions.
4//! Memory entries are stored as JSON files via StateStore.
5//! Supports embedding-based vector search using TF-IDF + cosine similarity.
6
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::Arc;
10
11use anyhow::Result;
12use chrono::{DateTime, Utc};
13use parking_lot::RwLock;
14use serde::{Deserialize, Serialize};
15
16use crate::embedding::{EmbeddingProvider, EmbeddingVector, TfIdfEmbeddingProvider};
17use crate::git_layer::GitLayer;
18use crate::state_store::StateStore;
19
20// Re-export budget types so external `use crate::memory::X` paths still work.
21pub use budget::{CurationCandidate, CurationReport, MemoryBudget};
22pub use store::HnswMemoryIndex;
23
24// ---------------------------------------------------------------------------
25// Content hashing
26// ---------------------------------------------------------------------------
27
28use std::collections::hash_map::DefaultHasher;
29use std::hash::{Hash, Hasher};
30
/// Hash `content` into a 64-bit fingerprint used for deduplication.
///
/// The value comes from the standard library's [`DefaultHasher`] in its
/// default state, so equal strings always map to the same fingerprint.
///
/// NOTE(review): the standard library does not specify `DefaultHasher`'s
/// algorithm and says its hashes should not be relied upon across releases —
/// confirm these values are never persisted across toolchain upgrades.
pub fn content_hash(content: &str) -> u64 {
    let mut state = DefaultHasher::default();
    Hash::hash(content, &mut state);
    Hasher::finish(&state)
}
37
38// ---------------------------------------------------------------------------
39// TextVector (TF-IDF vector for semantic similarity)
40// ---------------------------------------------------------------------------
41
42/// Simple TF-IDF vector for text similarity.
43///
44/// Tokenizes text into terms, computes normalized term frequency,
45/// and supports cosine similarity comparison. No external embedding
46/// model needed — works for any language including Korean.
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct TextVector {
49    /// Term frequencies (normalized).
50    tf: HashMap<String, f64>,
51}
52
53impl TextVector {
54    /// Create a text vector from input text.
55    pub fn from_text(text: &str) -> Self {
56        let mut tf: HashMap<String, f64> = HashMap::new();
57        let terms = Self::tokenize(text);
58        let total = terms.len() as f64;
59
60        for term in terms {
61            *tf.entry(term).or_insert(0.0) += 1.0;
62        }
63
64        // Normalize by total term count
65        if total > 0.0 {
66            for v in tf.values_mut() {
67                *v /= total;
68            }
69        }
70
71        Self { tf }
72    }
73
74    /// Tokenize text into terms (language-agnostic).
75    /// Splits on whitespace and punctuation, lowercases.
76    /// Preserves Korean Hangul syllables (U+AC00–U+D7A3) within tokens.
77    pub fn tokenize(text: &str) -> Vec<String> {
78        text.to_lowercase()
79            .split(|c: char| !c.is_alphanumeric() && !('\u{AC00}'..='\u{D7A3}').contains(&c))
80            .filter(|s| !s.is_empty() && s.len() > 1)
81            .map(|s| s.to_string())
82            .collect()
83    }
84
85    /// Returns a reference to the term-frequency map.
86    pub fn tf_map(&self) -> &HashMap<String, f64> {
87        &self.tf
88    }
89
90    /// Compute cosine similarity between two vectors.
91    pub fn cosine_similarity(&self, other: &TextVector) -> f64 {
92        let mut dot = 0.0;
93        let mut norm_a = 0.0;
94        let mut norm_b = 0.0;
95
96        for (term, &a) in &self.tf {
97            norm_a += a * a;
98            if let Some(&b) = other.tf.get(term) {
99                dot += a * b;
100            }
101        }
102        for &b in other.tf.values() {
103            norm_b += b * b;
104        }
105
106        if norm_a == 0.0 || norm_b == 0.0 {
107            return 0.0;
108        }
109
110        dot / (norm_a.sqrt() * norm_b.sqrt())
111    }
112}
113
114// ---------------------------------------------------------------------------
115// Types
116// ---------------------------------------------------------------------------
117
118/// Memory entry type.
119#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
120#[serde(rename_all = "snake_case")]
121pub enum MemoryType {
122    /// Conversation compaction summary (auto-generated).
123    Conversation,
124    /// Session-end summary (auto-generated).
125    Session,
126    /// Agent-stored fact.
127    Fact,
128    /// Episode memory (event/experience).
129    Episode,
130    /// Static knowledge (user/program-provided).
131    Knowledge,
132}
133
134impl MemoryType {
135    /// Category name used in StateStore.
136    pub fn category(&self) -> &'static str {
137        match self {
138            MemoryType::Conversation => "memory/conversations",
139            MemoryType::Session => "memory/sessions",
140            MemoryType::Fact => "memory/facts",
141            MemoryType::Episode => "memory/episodes",
142            MemoryType::Knowledge => "memory/knowledge",
143        }
144    }
145
146    /// Human-readable label.
147    pub fn label(&self) -> &'static str {
148        match self {
149            MemoryType::Conversation => "conversation",
150            MemoryType::Session => "session",
151            MemoryType::Fact => "fact",
152            MemoryType::Episode => "episode",
153            MemoryType::Knowledge => "knowledge",
154        }
155    }
156}
157
158/// A single memory entry.
159#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct MemoryEntry {
161    /// Unique ID.
162    pub id: String,
163    /// Memory type.
164    pub memory_type: MemoryType,
165    /// Content (Markdown).
166    pub content: String,
167    /// Creator (agent name, "compaction", "system", etc.).
168    pub source: String,
169    /// Related session ID.
170    #[serde(skip_serializing_if = "Option::is_none")]
171    pub session_id: Option<String>,
172    /// Tags for search.
173    #[serde(default)]
174    pub tags: Vec<String>,
175    /// Importance (0.0 – 1.0).
176    #[serde(default = "default_importance")]
177    pub importance: f32,
178    /// Creation timestamp.
179    pub created_at: DateTime<Utc>,
180    /// Last access timestamp.
181    pub accessed_at: DateTime<Utc>,
182    /// Access count.
183    #[serde(default)]
184    pub access_count: u32,
185}
186
/// Importance assigned to a `MemoryEntry` whose serialized form omits the field.
fn default_importance() -> f32 {
    const DEFAULT_IMPORTANCE: f32 = 0.5;
    DEFAULT_IMPORTANCE
}
190
191// ---------------------------------------------------------------------------
192// MemoryManager
193// ---------------------------------------------------------------------------
194
195/// Agent memory manager.
196///
197/// Stores and retrieves memory entries using the file-based StateStore.
198/// Supports embedding-based vector search via an in-memory TF-IDF index
199/// that is rebuilt on startup.
200pub struct MemoryManager {
201    state_store: Arc<StateStore>,
202    max_recall: usize,
203    /// Vector index for semantic search (id → EmbeddingVector).
204    vector_index: RwLock<HashMap<String, EmbeddingVector>>,
205    /// Embedding provider for generating vectors.
206    embedding: Arc<dyn EmbeddingProvider>,
207    /// Optional git layer for version-controlled memory.
208    git_layer: Option<Arc<GitLayer>>,
209    /// Optional HNSW index for fast ANN search.
210    hnsw_index: RwLock<Option<Arc<HnswMemoryIndex>>>,
211}
212
213impl std::fmt::Debug for MemoryManager {
214    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215        f.debug_struct("MemoryManager")
216            .field("max_recall", &self.max_recall)
217            .field("index_size", &self.vector_index.read().len())
218            .finish()
219    }
220}
221
222impl MemoryManager {
223    /// Create a new MemoryManager.
224    pub fn new(state_store: Arc<StateStore>) -> Self {
225        Self {
226            state_store,
227            max_recall: 10,
228            vector_index: RwLock::new(HashMap::new()),
229            embedding: Arc::new(TfIdfEmbeddingProvider),
230            git_layer: None,
231            hnsw_index: RwLock::new(None),
232        }
233    }
234
235    /// Attach a git layer for version-controlled saves.
236    pub fn set_git_layer(&mut self, gl: Arc<GitLayer>) {
237        self.git_layer = Some(gl);
238    }
239
240    /// Create a Space-scoped MemoryManager.
241    ///
242    /// Each Space gets its own StateStore under the given directory,
243    /// providing natural memory isolation between Spaces.
244    pub fn for_space(space_dir: PathBuf) -> Self {
245        let memory_dir = space_dir.join("memory");
246        let state_store = Arc::new(StateStore::new(memory_dir).unwrap_or_else(|_| {
247            // Fallback: create in temp dir
248            StateStore::new(std::env::temp_dir().join("oxios-memory")).unwrap()
249        }));
250        Self::new(state_store)
251    }
252
253    /// Attach an HNSW index for fast semantic search.
254    ///
255    /// Once attached, `remember()` and `forget()` automatically keep
256    /// the HNSW index in sync with the state store.
257    pub fn set_hnsw_index(&self, index: Arc<HnswMemoryIndex>) {
258        *self.hnsw_index.write() = Some(index);
259    }
260
261    /// Commit a file to git if git_layer is available.
262    fn git_commit(&self, rel_path: &str, message: &str) {
263        if let Some(ref gl) = self.git_layer {
264            if gl.is_enabled() {
265                let _ = gl.commit_file(rel_path, message);
266            }
267        }
268    }
269
270    /// Set max memories returned by recall.
271    pub fn with_max_recall(mut self, n: usize) -> Self {
272        self.max_recall = n;
273        self
274    }
275
276    /// Apply MemoryConfig settings.
277    pub fn with_config(mut self, config: &crate::config::MemoryConfig) -> Self {
278        self.max_recall = config.max_recall;
279        self
280    }
281
282    /// Returns the number of entries in the vector index.
283    pub fn vector_index_size(&self) -> usize {
284        self.vector_index.read().len()
285    }
286
287    /// Compute effective importance of a memory entry.
288    ///
289    /// Effective importance = base_importance * (1 + log(1 + access_count))
290    /// Memories accessed frequently get a boost.
291    pub fn effective_importance(entry: &MemoryEntry) -> f32 {
292        let access_boost = (1.0_f32 + entry.access_count as f32).ln();
293        entry.importance * (1.0 + access_boost)
294    }
295
296    /// Curate memories: identify candidates for removal based on budget.
297    ///
298    /// Returns a report of how many entries would be pruned per type.
299    pub async fn curate(&self, budget: &MemoryBudget) -> Result<CurationReport> {
300        let mut report = CurationReport::default();
301
302        for mt in &[
303            MemoryType::Conversation,
304            MemoryType::Session,
305            MemoryType::Fact,
306            MemoryType::Episode,
307            MemoryType::Knowledge,
308        ] {
309            let entries = self.list(*mt, budget.max_per_type * 2).await?;
310            if entries.len() <= budget.max_per_type {
311                continue;
312            }
313
314            // Sort by effective importance ascending (least important first)
315            let total_count = entries.len();
316            let mut scored: Vec<_> = entries
317                .into_iter()
318                .map(|e| (e.id.clone(), e.memory_type, Self::effective_importance(&e)))
319                .collect();
320            scored.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
321
322            let to_remove = scored.len() - budget.max_per_type;
323            for (id, memory_type, score) in scored.into_iter().take(to_remove) {
324                report.candidates_for_removal.push(CurationCandidate {
325                    id,
326                    memory_type,
327                    effective_importance: score,
328                });
329            }
330            report.total_before += total_count;
331        }
332
333        // Actually remove candidates
334        for candidate in &report.candidates_for_removal {
335            if self
336                .forget(&candidate.id, candidate.memory_type)
337                .await
338                .is_ok()
339            {
340                report.removed += 1;
341            }
342        }
343
344        report.total_after = report.total_before - report.removed;
345        Ok(report)
346    }
347
348    /// Spawn a background curation task.
349    ///
350    /// Returns immediately; curation runs asynchronously.
351    pub fn spawn_curation_task(self: &Arc<Self>, budget: MemoryBudget) {
352        let mgr = Arc::clone(self);
353        tokio::spawn(async move {
354            match mgr.curate(&budget).await {
355                Ok(report) => {
356                    if report.removed > 0 {
357                        tracing::info!(
358                            removed = report.removed,
359                            candidates = report.candidates_for_removal.len(),
360                            "Memory curation complete"
361                        );
362                    }
363                }
364                Err(e) => {
365                    tracing::warn!(error = %e, "Memory curation failed");
366                }
367            }
368        });
369    }
370}
371
372// ---------------------------------------------------------------------------
373// Helpers
374// ---------------------------------------------------------------------------
375
/// Extract search keywords from a query string.
///
/// The query is split on whitespace; each word has ASCII punctuation
/// stripped from **both** ends (so `"rust"`, `(agent)` and `system?` become
/// `rust`, `agent` and `system`) and is lowercased. Words of at most two
/// bytes and common English stop words are dropped. Keywords are returned
/// in query order; duplicates are kept.
pub(crate) fn extract_keywords(query: &str) -> Vec<String> {
    const STOP_WORDS: &[&str] = &[
        "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had",
        "do", "does", "did", "will", "would", "could", "should", "may", "might", "can", "shall",
        "to", "of", "in", "for", "on", "with", "at", "by", "from", "as", "into", "through",
        "during", "before", "after", "above", "below", "between", "out", "off", "over", "under",
        "again", "further", "then", "once", "and", "but", "or", "nor", "not", "so", "yet", "both",
        "either", "neither", "each", "every", "all", "any", "few", "more", "most", "other", "some",
        "such", "no", "only", "own", "same", "than", "too", "very", "just", "because", "if",
        "when", "where", "how", "what", "which", "who", "whom", "this", "that", "these", "those",
        "i", "me", "my", "we", "our", "you", "your", "he", "him", "his", "she", "her", "it", "its",
        "they", "them", "their",
    ];

    query
        .split_whitespace()
        .map(|word| {
            // Strip surrounding punctuation so quoted or parenthesised words
            // still match the stop-word list and the stored content.
            word.trim_matches(|c: char| c.is_ascii_punctuation())
                .to_lowercase()
        })
        .filter(|word| word.len() > 2 && !STOP_WORDS.contains(&word.as_str()))
        .collect()
}
403
404/// Remove duplicate entries by ID, keeping the first occurrence.
405pub(crate) fn dedup_by_id(entries: &mut Vec<MemoryEntry>) {
406    let mut seen = std::collections::HashSet::new();
407    entries.retain(|e| seen.insert(e.id.clone()));
408}
409
410// ---------------------------------------------------------------------------
411// Sub-modules
412// ---------------------------------------------------------------------------
413
414pub mod auto_memory_bridge;
415mod budget;
416mod chunking;
417pub mod embedding_cache;
418pub mod flash_attention;
419mod graph;
420mod hnsw;
421pub mod hyperbolic;
422pub mod normalizer;
423pub(crate) mod store;
424
425pub use embedding_cache::{CacheStats, EmbeddingCache};
426pub use store::SemanticHit;
427
428// Re-export key types from sub-modules.
429pub use chunking::{chunk_fixed, chunk_paragraphs, ChunkConfig, TextChunk};
430pub use graph::MemoryGraph;
431pub use hnsw::HnswIndex;
432pub use normalizer::{cosine_similarity_f32, l2_normalize_f32, l2_normalize_f64};
433
434// ---------------------------------------------------------------------------
435// Tests
436// ---------------------------------------------------------------------------
437
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Fixtures ----

    /// Build a minimal entry with the given id and type and default-ish
    /// values for everything else.
    fn make_entry(id: &str, ty: MemoryType) -> MemoryEntry {
        MemoryEntry {
            id: id.to_string(),
            memory_type: ty,
            content: format!("Test content for {}", id),
            source: "test".to_string(),
            session_id: None,
            tags: vec![],
            importance: 0.5,
            created_at: Utc::now(),
            accessed_at: Utc::now(),
            access_count: 0,
        }
    }

    /// Create a manager backed by a fresh, isolated temporary directory.
    ///
    /// The returned `TempDir` guard must be kept alive for the duration of
    /// the test; dropping it deletes the directory.
    fn temp_manager() -> (tempfile::TempDir, Arc<StateStore>, MemoryManager) {
        let temp_dir = tempfile::tempdir().unwrap();
        let store = Arc::new(StateStore::new(temp_dir.path().to_path_buf()).unwrap());
        let mgr = MemoryManager::new(store.clone());
        (temp_dir, store, mgr)
    }

    // ---- Type and helper tests ----

    #[test]
    fn test_memory_type_category() {
        assert_eq!(MemoryType::Conversation.category(), "memory/conversations");
        assert_eq!(MemoryType::Session.category(), "memory/sessions");
        assert_eq!(MemoryType::Fact.category(), "memory/facts");
        assert_eq!(MemoryType::Episode.category(), "memory/episodes");
        assert_eq!(MemoryType::Knowledge.category(), "memory/knowledge");
    }

    #[test]
    fn test_memory_type_label() {
        assert_eq!(MemoryType::Conversation.label(), "conversation");
        assert_eq!(MemoryType::Session.label(), "session");
        assert_eq!(MemoryType::Fact.label(), "fact");
        assert_eq!(MemoryType::Episode.label(), "episode");
        assert_eq!(MemoryType::Knowledge.label(), "knowledge");
    }

    #[test]
    fn test_extract_keywords() {
        let kw = extract_keywords("How do I implement a Rust agent system?");
        assert!(kw.contains(&"implement".to_string()));
        assert!(kw.contains(&"rust".to_string()));
        assert!(kw.contains(&"agent".to_string()));
        assert!(kw.contains(&"system".to_string()));
        // stop words filtered
        assert!(!kw.contains(&"how".to_string()));
        assert!(!kw.contains(&"do".to_string()));
    }

    #[test]
    fn test_dedup_by_id() {
        let mut entries = vec![
            make_entry("a", MemoryType::Fact),
            make_entry("b", MemoryType::Fact),
            make_entry("a", MemoryType::Episode), // duplicate id
        ];
        dedup_by_id(&mut entries);
        assert_eq!(entries.len(), 2);
    }

    #[test]
    fn test_blend_into_prompt_empty() {
        // Isolated temp dir instead of a fixed shared path under the system temp dir.
        let (_temp_dir, _store, mgr) = temp_manager();
        let result = mgr.blend_into_prompt(&[], "You are an agent.");
        assert_eq!(result, "You are an agent.");
    }

    #[test]
    fn test_blend_into_prompt_with_memories() {
        let (_temp_dir, _store, mgr) = temp_manager();
        let memories = vec![make_entry("test", MemoryType::Fact)];
        let result = mgr.blend_into_prompt(&memories, "You are an agent.");
        assert!(result.contains("## Relevant Memory"));
        assert!(result.contains("[fact]"));
    }

    // ---- Vector search tests ----

    #[test]
    fn test_text_vector_cosine_similarity() {
        let v1 = TextVector::from_text("fix the null pointer error in main.rs");
        let v2 = TextVector::from_text("null pointer error found in rust code");
        let v3 = TextVector::from_text("update the documentation for deployment");

        // Similar texts should have high similarity
        assert!(
            v1.cosine_similarity(&v2) > 0.3,
            "Similar texts should have > 0.3 similarity"
        );

        // Different texts should have low similarity
        assert!(
            v1.cosine_similarity(&v3) < 0.2,
            "Different texts should have < 0.2 similarity"
        );
    }

    #[test]
    fn test_text_vector_korean() {
        let v1 = TextVector::from_text("main.rs 파일의 null pointer 에러 수정");
        let v2 = TextVector::from_text("null pointer 오류를 수정했습니다");
        let v3 = TextVector::from_text("문서 업데이트 배포 가이드");

        assert!(v1.cosine_similarity(&v2) > 0.1, "Korean+code similarity");
        assert!(v1.cosine_similarity(&v3) < 0.1, "Korean different topics");
    }

    #[test]
    fn test_text_vector_empty() {
        let v1 = TextVector::from_text("");
        let v2 = TextVector::from_text("hello");
        assert_eq!(v1.cosine_similarity(&v2), 0.0);
    }

    #[test]
    fn test_text_vector_identical() {
        let v1 = TextVector::from_text("rust programming language");
        let v2 = TextVector::from_text("rust programming language");
        let sim = v1.cosine_similarity(&v2);
        assert!(
            (sim - 1.0).abs() < 1e-9,
            "Identical texts should have similarity ~1.0, got {}",
            sim
        );
    }

    #[test]
    fn test_tokenize_korean() {
        let terms = TextVector::tokenize("main.rs 파일의 버그를 수정");
        // Should contain at least some meaningful tokens
        assert!(!terms.is_empty(), "Korean text should produce tokens");
    }

    #[tokio::test]
    async fn test_vector_search_over_keyword_fallback() {
        let (_temp_dir, _store, mgr) = temp_manager();

        // Store some memories
        let entry1 = MemoryEntry {
            content: "Rust is a systems programming language focused on safety".to_string(),
            ..make_entry("vec-test-1", MemoryType::Fact)
        };
        let entry2 = MemoryEntry {
            content: "Python is great for machine learning and data science".to_string(),
            ..make_entry("vec-test-2", MemoryType::Fact)
        };

        mgr.remember(entry1).await.unwrap();
        mgr.remember(entry2).await.unwrap();

        // Vector search should find the Rust entry for a Rust-related query
        let results = mgr
            .search("systems programming with rust", None, 5)
            .await
            .unwrap();
        assert!(!results.is_empty(), "Vector search should find results");
        assert_eq!(
            results[0].id, "vec-test-1",
            "Should find the Rust entry first"
        );
    }

    #[tokio::test]
    async fn test_rebuild_index() {
        let (_temp_dir, store, mgr) = temp_manager();

        // Store a memory directly via state_store (bypassing remember to test rebuild)
        let entry = MemoryEntry {
            content: "memory for rebuild test".to_string(),
            ..make_entry("rebuild-test-1", MemoryType::Fact)
        };
        store
            .save_json("memory/facts", "rebuild-test-1", &entry)
            .await
            .unwrap();

        // Index should be empty before rebuild
        assert_eq!(mgr.vector_index.read().len(), 0);

        // Rebuild
        mgr.rebuild_index().await.unwrap();

        // Index should now contain the entry
        assert_eq!(mgr.vector_index.read().len(), 1);
        assert!(mgr.vector_index.read().contains_key("rebuild-test-1"));
    }
}