omega_memory/
individual.rs

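//! Individual-scale memory (tiers 1-4: instant, session, episodic, semantic).
//! Instant and session tiers are in-process maps; episodic and semantic tiers
//! are persisted by `AgentDBWrapper` as JSON files on disk.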
3
4use crate::{Memory, MemoryContent, MemoryError, MemoryTier, Query};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11/// Individual memory system managing Tier 1-4
12pub struct IndividualMemory {
13    /// Tier 1: Instant memory (in-memory only)
14    instant: Arc<RwLock<HashMap<String, Memory>>>,
15
16    /// Tier 2: Session memory (in-memory + AgentDB)
17    session: Arc<RwLock<HashMap<String, Memory>>>,
18
19    /// Tier 3: Episodic memory (AgentDB)
20    episodic: Arc<RwLock<AgentDBWrapper>>,
21
22    /// Tier 4: Semantic memory (AgentDB with indexing)
23    semantic: Arc<RwLock<AgentDBWrapper>>,
24}
25
26impl IndividualMemory {
27    pub async fn new() -> Result<Self, MemoryError> {
28        let episodic_path = PathBuf::from("/tmp/omega/memory/episodic.agentdb");
29        let semantic_path = PathBuf::from("/tmp/omega/memory/semantic.agentdb");
30
31        // Ensure directories exist
32        if let Some(parent) = episodic_path.parent() {
33            tokio::fs::create_dir_all(parent)
34                .await
35                .map_err(|e| MemoryError::Storage(format!("Failed to create directory: {}", e)))?;
36        }
37
38        Ok(Self {
39            instant: Arc::new(RwLock::new(HashMap::new())),
40            session: Arc::new(RwLock::new(HashMap::new())),
41            episodic: Arc::new(RwLock::new(AgentDBWrapper::new(episodic_path).await?)),
42            semantic: Arc::new(RwLock::new(AgentDBWrapper::new(semantic_path).await?)),
43        })
44    }
45
46    pub async fn store(&self, memory: Memory) -> Result<String, MemoryError> {
47        let id = memory.id.clone();
48
49        match memory.tier {
50            MemoryTier::Instant => {
51                self.instant.write().await.insert(id.clone(), memory);
52                self.prune_instant().await?;
53            }
54            MemoryTier::Session => {
55                self.session.write().await.insert(id.clone(), memory);
56                self.prune_session().await?;
57            }
58            MemoryTier::Episodic => {
59                self.episodic.write().await.store(memory).await?;
60            }
61            MemoryTier::Semantic => {
62                self.semantic.write().await.store(memory).await?;
63            }
64            _ => {
65                return Err(MemoryError::Storage(format!(
66                    "Invalid tier {:?} for individual memory",
67                    memory.tier
68                )));
69            }
70        }
71
72        Ok(id)
73    }
74
75    pub async fn recall(
76        &self,
77        query: &Query,
78        tiers: &[MemoryTier],
79    ) -> Result<Vec<Memory>, MemoryError> {
80        let mut results = Vec::new();
81
82        for tier in tiers {
83            match tier {
84                MemoryTier::Instant => {
85                    let instant_mem = self.instant.read().await;
86                    let mut memories: Vec<Memory> = instant_mem.values().cloned().collect();
87                    memories = self.filter_memories(memories, query);
88                    results.extend(memories);
89                }
90                MemoryTier::Session => {
91                    let session_mem = self.session.read().await;
92                    let mut memories: Vec<Memory> = session_mem.values().cloned().collect();
93                    memories = self.filter_memories(memories, query);
94                    results.extend(memories);
95                }
96                MemoryTier::Episodic => {
97                    let episodic_results = self.episodic.read().await.search(query).await?;
98                    results.extend(episodic_results);
99                }
100                MemoryTier::Semantic => {
101                    let semantic_results = self.semantic.read().await.search(query).await?;
102                    results.extend(semantic_results);
103                }
104                _ => {}
105            }
106        }
107
108        Ok(results)
109    }
110
111    pub async fn stats(&self) -> IndividualMemoryStats {
112        let instant_count = self.instant.read().await.len();
113        let session_count = self.session.read().await.len();
114        let episodic_count = self.episodic.read().await.count().await;
115        let semantic_count = self.semantic.read().await.count().await;
116
117        IndividualMemoryStats {
118            instant: instant_count,
119            session: session_count,
120            episodic: episodic_count,
121            semantic: semantic_count,
122            total: instant_count + session_count + episodic_count + semantic_count,
123        }
124    }
125
126    async fn prune_instant(&self) -> Result<(), MemoryError> {
127        let mut instant = self.instant.write().await;
128        let max_size = MemoryTier::Instant.typical_size();
129
130        if instant.len() > max_size {
131            let mut entries: Vec<_> = instant.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
132            entries.sort_by(|a, b| {
133                a.1.accessed_at
134                    .cmp(&b.1.accessed_at)
135            });
136
137            let to_remove = entries.len() - max_size;
138            for (key, _) in entries.iter().take(to_remove) {
139                instant.remove(key);
140            }
141        }
142
143        Ok(())
144    }
145
146    async fn prune_session(&self) -> Result<(), MemoryError> {
147        let mut session = self.session.write().await;
148        let max_size = MemoryTier::Session.typical_size();
149
150        if session.len() > max_size {
151            let mut entries: Vec<_> = session.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
152            entries.sort_by(|a, b| {
153                b.1.relevance_score()
154                    .partial_cmp(&a.1.relevance_score())
155                    .unwrap_or(std::cmp::Ordering::Equal)
156            });
157
158            let to_remove = entries.len() - max_size;
159            for (key, memory) in entries.iter().rev().take(to_remove) {
160                // Promote important memories to episodic before removing
161                if memory.importance > 0.3 {
162                    let mut promoted = memory.clone();
163                    promoted.tier = MemoryTier::Episodic;
164                    self.episodic.write().await.store(promoted).await?;
165                }
166                session.remove(key);
167            }
168        }
169
170        Ok(())
171    }
172
173    fn filter_memories(&self, memories: Vec<Memory>, query: &Query) -> Vec<Memory> {
174        memories
175            .into_iter()
176            .filter(|m| {
177                // Filter by importance threshold
178                if let Some(min_importance) = query.min_importance {
179                    if m.importance < min_importance {
180                        return false;
181                    }
182                }
183
184                // Text matching if text query provided
185                if let Some(ref text) = query.text {
186                    if let MemoryContent::Text(ref content) = m.content {
187                        if !content.to_lowercase().contains(&text.to_lowercase()) {
188                            return false;
189                        }
190                    } else {
191                        return false;
192                    }
193                }
194
195                true
196            })
197            .collect()
198    }
199}
200
201/// AgentDB wrapper for persistent storage
202pub struct AgentDBWrapper {
203    path: PathBuf,
204    memories: HashMap<String, Memory>,
205}
206
207impl AgentDBWrapper {
208    async fn new(path: PathBuf) -> Result<Self, MemoryError> {
209        let mut wrapper = Self {
210            path,
211            memories: HashMap::new(),
212        };
213
214        // Load existing memories if file exists
215        if wrapper.path.exists() {
216            wrapper.load().await?;
217        }
218
219        Ok(wrapper)
220    }
221
222    async fn store(&mut self, memory: Memory) -> Result<(), MemoryError> {
223        self.memories.insert(memory.id.clone(), memory);
224        self.save().await?;
225        Ok(())
226    }
227
228    async fn search(&self, query: &Query) -> Result<Vec<Memory>, MemoryError> {
229        let mut results: Vec<Memory> = self.memories.values().cloned().collect();
230
231        // Filter by importance
232        if let Some(min_importance) = query.min_importance {
233            results.retain(|m| m.importance >= min_importance);
234        }
235
236        // Vector similarity search if embedding provided
237        if let Some(ref query_embedding) = query.embedding {
238            results.sort_by(|a, b| {
239                let sim_a = cosine_similarity(&a.embedding, query_embedding);
240                let sim_b = cosine_similarity(&b.embedding, query_embedding);
241                sim_b.partial_cmp(&sim_a).unwrap_or(std::cmp::Ordering::Equal)
242            });
243
244            // Take top k results
245            if let Some(limit) = query.limit {
246                results.truncate(limit);
247            }
248        }
249
250        Ok(results)
251    }
252
253    async fn count(&self) -> usize {
254        self.memories.len()
255    }
256
257    async fn load(&mut self) -> Result<(), MemoryError> {
258        let data = tokio::fs::read(&self.path).await?;
259        self.memories = serde_json::from_slice(&data)?;
260        Ok(())
261    }
262
263    async fn save(&self) -> Result<(), MemoryError> {
264        let data = serde_json::to_vec_pretty(&self.memories)?;
265        tokio::fs::write(&self.path, data).await?;
266        Ok(())
267    }
268}
269
/// Cosine of the angle between vectors `a` and `b`.
///
/// Yields `0.0` when the slices differ in length or when either one has zero
/// magnitude (which includes empty slices), so callers never divide by zero.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Accumulate the dot product and both squared norms in a single pass.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );

    let (len_a, len_b) = (norm_a_sq.sqrt(), norm_b_sq.sqrt());
    if len_a == 0.0 || len_b == 0.0 {
        0.0
    } else {
        dot / (len_a * len_b)
    }
}
286
287#[derive(Debug, Clone, Serialize, Deserialize)]
288pub struct IndividualMemoryStats {
289    pub instant: usize,
290    pub session: usize,
291    pub episodic: usize,
292    pub semantic: usize,
293    pub total: usize,
294}
295
#[cfg(test)]
mod tests {
    // `MemoryContent` and the other parent imports arrive through this glob.
    use super::*;

    #[tokio::test]
    async fn test_instant_memory() {
        let mem = IndividualMemory::new().await.unwrap();
        let memory = Memory::new(
            MemoryTier::Instant,
            MemoryContent::Text("test".to_string()),
            vec![0.1, 0.2, 0.3],
            0.5,
        );

        let id = mem.store(memory).await.unwrap();
        assert!(!id.is_empty());

        // Instant and session tiers are in-process only, so a fresh instance
        // holds exactly the memory just stored and nothing in session.
        let stats = mem.stats().await;
        assert_eq!(stats.instant, 1);
        assert_eq!(stats.session, 0);
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 1.0);

        let c = vec![1.0, 0.0, 0.0];
        let d = vec![0.0, 1.0, 0.0];
        assert_eq!(cosine_similarity(&c, &d), 0.0);
    }

    #[test]
    fn test_cosine_similarity_degenerate_inputs() {
        // Mismatched lengths, empty slices and zero-magnitude vectors score 0.0.
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0);
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);

        // Opposite directions score -1.0.
        assert!((cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]) + 1.0).abs() < 1e-6);
    }
}