1use crate::{Memory, MemoryContent, MemoryError, MemoryTier, Query};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11pub struct IndividualMemory {
13 instant: Arc<RwLock<HashMap<String, Memory>>>,
15
16 session: Arc<RwLock<HashMap<String, Memory>>>,
18
19 episodic: Arc<RwLock<AgentDBWrapper>>,
21
22 semantic: Arc<RwLock<AgentDBWrapper>>,
24}
25
26impl IndividualMemory {
27 pub async fn new() -> Result<Self, MemoryError> {
28 let episodic_path = PathBuf::from("/tmp/omega/memory/episodic.agentdb");
29 let semantic_path = PathBuf::from("/tmp/omega/memory/semantic.agentdb");
30
31 if let Some(parent) = episodic_path.parent() {
33 tokio::fs::create_dir_all(parent)
34 .await
35 .map_err(|e| MemoryError::Storage(format!("Failed to create directory: {}", e)))?;
36 }
37
38 Ok(Self {
39 instant: Arc::new(RwLock::new(HashMap::new())),
40 session: Arc::new(RwLock::new(HashMap::new())),
41 episodic: Arc::new(RwLock::new(AgentDBWrapper::new(episodic_path).await?)),
42 semantic: Arc::new(RwLock::new(AgentDBWrapper::new(semantic_path).await?)),
43 })
44 }
45
46 pub async fn store(&self, memory: Memory) -> Result<String, MemoryError> {
47 let id = memory.id.clone();
48
49 match memory.tier {
50 MemoryTier::Instant => {
51 self.instant.write().await.insert(id.clone(), memory);
52 self.prune_instant().await?;
53 }
54 MemoryTier::Session => {
55 self.session.write().await.insert(id.clone(), memory);
56 self.prune_session().await?;
57 }
58 MemoryTier::Episodic => {
59 self.episodic.write().await.store(memory).await?;
60 }
61 MemoryTier::Semantic => {
62 self.semantic.write().await.store(memory).await?;
63 }
64 _ => {
65 return Err(MemoryError::Storage(format!(
66 "Invalid tier {:?} for individual memory",
67 memory.tier
68 )));
69 }
70 }
71
72 Ok(id)
73 }
74
75 pub async fn recall(
76 &self,
77 query: &Query,
78 tiers: &[MemoryTier],
79 ) -> Result<Vec<Memory>, MemoryError> {
80 let mut results = Vec::new();
81
82 for tier in tiers {
83 match tier {
84 MemoryTier::Instant => {
85 let instant_mem = self.instant.read().await;
86 let mut memories: Vec<Memory> = instant_mem.values().cloned().collect();
87 memories = self.filter_memories(memories, query);
88 results.extend(memories);
89 }
90 MemoryTier::Session => {
91 let session_mem = self.session.read().await;
92 let mut memories: Vec<Memory> = session_mem.values().cloned().collect();
93 memories = self.filter_memories(memories, query);
94 results.extend(memories);
95 }
96 MemoryTier::Episodic => {
97 let episodic_results = self.episodic.read().await.search(query).await?;
98 results.extend(episodic_results);
99 }
100 MemoryTier::Semantic => {
101 let semantic_results = self.semantic.read().await.search(query).await?;
102 results.extend(semantic_results);
103 }
104 _ => {}
105 }
106 }
107
108 Ok(results)
109 }
110
111 pub async fn stats(&self) -> IndividualMemoryStats {
112 let instant_count = self.instant.read().await.len();
113 let session_count = self.session.read().await.len();
114 let episodic_count = self.episodic.read().await.count().await;
115 let semantic_count = self.semantic.read().await.count().await;
116
117 IndividualMemoryStats {
118 instant: instant_count,
119 session: session_count,
120 episodic: episodic_count,
121 semantic: semantic_count,
122 total: instant_count + session_count + episodic_count + semantic_count,
123 }
124 }
125
126 async fn prune_instant(&self) -> Result<(), MemoryError> {
127 let mut instant = self.instant.write().await;
128 let max_size = MemoryTier::Instant.typical_size();
129
130 if instant.len() > max_size {
131 let mut entries: Vec<_> = instant.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
132 entries.sort_by(|a, b| {
133 a.1.accessed_at
134 .cmp(&b.1.accessed_at)
135 });
136
137 let to_remove = entries.len() - max_size;
138 for (key, _) in entries.iter().take(to_remove) {
139 instant.remove(key);
140 }
141 }
142
143 Ok(())
144 }
145
146 async fn prune_session(&self) -> Result<(), MemoryError> {
147 let mut session = self.session.write().await;
148 let max_size = MemoryTier::Session.typical_size();
149
150 if session.len() > max_size {
151 let mut entries: Vec<_> = session.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
152 entries.sort_by(|a, b| {
153 b.1.relevance_score()
154 .partial_cmp(&a.1.relevance_score())
155 .unwrap_or(std::cmp::Ordering::Equal)
156 });
157
158 let to_remove = entries.len() - max_size;
159 for (key, memory) in entries.iter().rev().take(to_remove) {
160 if memory.importance > 0.3 {
162 let mut promoted = memory.clone();
163 promoted.tier = MemoryTier::Episodic;
164 self.episodic.write().await.store(promoted).await?;
165 }
166 session.remove(key);
167 }
168 }
169
170 Ok(())
171 }
172
173 fn filter_memories(&self, memories: Vec<Memory>, query: &Query) -> Vec<Memory> {
174 memories
175 .into_iter()
176 .filter(|m| {
177 if let Some(min_importance) = query.min_importance {
179 if m.importance < min_importance {
180 return false;
181 }
182 }
183
184 if let Some(ref text) = query.text {
186 if let MemoryContent::Text(ref content) = m.content {
187 if !content.to_lowercase().contains(&text.to_lowercase()) {
188 return false;
189 }
190 } else {
191 return false;
192 }
193 }
194
195 true
196 })
197 .collect()
198 }
199}
200
201pub struct AgentDBWrapper {
203 path: PathBuf,
204 memories: HashMap<String, Memory>,
205}
206
207impl AgentDBWrapper {
208 async fn new(path: PathBuf) -> Result<Self, MemoryError> {
209 let mut wrapper = Self {
210 path,
211 memories: HashMap::new(),
212 };
213
214 if wrapper.path.exists() {
216 wrapper.load().await?;
217 }
218
219 Ok(wrapper)
220 }
221
222 async fn store(&mut self, memory: Memory) -> Result<(), MemoryError> {
223 self.memories.insert(memory.id.clone(), memory);
224 self.save().await?;
225 Ok(())
226 }
227
228 async fn search(&self, query: &Query) -> Result<Vec<Memory>, MemoryError> {
229 let mut results: Vec<Memory> = self.memories.values().cloned().collect();
230
231 if let Some(min_importance) = query.min_importance {
233 results.retain(|m| m.importance >= min_importance);
234 }
235
236 if let Some(ref query_embedding) = query.embedding {
238 results.sort_by(|a, b| {
239 let sim_a = cosine_similarity(&a.embedding, query_embedding);
240 let sim_b = cosine_similarity(&b.embedding, query_embedding);
241 sim_b.partial_cmp(&sim_a).unwrap_or(std::cmp::Ordering::Equal)
242 });
243
244 if let Some(limit) = query.limit {
246 results.truncate(limit);
247 }
248 }
249
250 Ok(results)
251 }
252
253 async fn count(&self) -> usize {
254 self.memories.len()
255 }
256
257 async fn load(&mut self) -> Result<(), MemoryError> {
258 let data = tokio::fs::read(&self.path).await?;
259 self.memories = serde_json::from_slice(&data)?;
260 Ok(())
261 }
262
263 async fn save(&self) -> Result<(), MemoryError> {
264 let data = serde_json::to_vec_pretty(&self.memories)?;
265 tokio::fs::write(&self.path, data).await?;
266 Ok(())
267 }
268}
269
/// Cosine similarity of two vectors.
///
/// Yields `0.0` instead of dividing by zero when the inputs differ in length
/// or when either one has zero magnitude (which includes empty input).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();

    if a.len() != b.len() {
        return 0.0;
    }

    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (norm_a * norm_b)
}
286
287#[derive(Debug, Clone, Serialize, Deserialize)]
288pub struct IndividualMemoryStats {
289 pub instant: usize,
290 pub session: usize,
291 pub episodic: usize,
292 pub semantic: usize,
293 pub total: usize,
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MemoryContent;

    #[tokio::test]
    async fn test_instant_memory() {
        let mem = IndividualMemory::new().await.unwrap();
        let memory = Memory::new(
            MemoryTier::Instant,
            MemoryContent::Text("test".to_string()),
            vec![0.1, 0.2, 0.3],
            0.5,
        );

        let id = mem.store(memory).await.unwrap();
        assert!(!id.is_empty());
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 1.0);

        let c = vec![1.0, 0.0, 0.0];
        let d = vec![0.0, 1.0, 0.0];
        assert_eq!(cosine_similarity(&c, &d), 0.0);
    }

    #[test]
    fn test_cosine_similarity_degenerate_inputs() {
        // Mismatched lengths are reported as "unrelated" rather than panicking.
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0);
        // A zero-magnitude vector would otherwise produce NaN.
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
        let empty: [f32; 0] = [];
        assert_eq!(cosine_similarity(&empty, &empty), 0.0);
    }

    #[test]
    fn test_cosine_similarity_opposite_and_scaled() {
        assert!((cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]) + 1.0).abs() < 1e-6);
        assert!((cosine_similarity(&[1.0, 1.0], &[2.0, 2.0]) - 1.0).abs() < 1e-6);
    }
}