// agent_io/memory/manager.rs
//! Memory manager - core orchestration

3use std::sync::Arc;
4
5use super::buffer::RingBuffer;
6use super::embeddings::EmbeddingProvider;
7use super::entry::{MemoryEntry, MemoryType};
8use super::store::MemoryStore;
9use crate::Result;
10
/// Memory configuration for [`MemoryManager`].
///
/// Override individual fields with struct-update syntax, e.g.
/// `MemoryConfig { enable_long_term: false, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct MemoryConfig {
    /// Short-term buffer size (number of entries the ring buffer is created with).
    pub short_term_size: usize,
    /// Send non-short-term memories to the long-term store. When `false`, they are
    /// kept in the short-term buffer instead.
    pub enable_long_term: bool,
    /// Maximum number of memories returned by a single recall.
    pub retrieval_limit: usize,
    /// Minimum similarity (0.0 - 1.0) a memory must reach to be recalled.
    pub relevance_threshold: f32,
    /// Token budget for context building; tokens are approximated as content bytes / 4.
    pub max_context_tokens: u64,
    /// Importance decay rate per day.
    /// NOTE(review): not read by `MemoryManager` in this file — confirm where it is applied.
    pub importance_decay: f32,
}
27
28impl Default for MemoryConfig {
29    fn default() -> Self {
30        Self {
31            short_term_size: 20,
32            enable_long_term: true,
33            retrieval_limit: 5,
34            relevance_threshold: 0.7,
35            max_context_tokens: 2000,
36            importance_decay: 0.95,
37        }
38    }
39}
40
41/// Memory manager orchestrates short-term and long-term memory
42pub struct MemoryManager {
43    config: MemoryConfig,
44    short_term: RingBuffer<MemoryEntry>,
45    store: Arc<dyn MemoryStore>,
46    embedder: Arc<dyn EmbeddingProvider>,
47}
48
49impl MemoryManager {
50    /// Create a new memory manager
51    pub fn new(
52        config: MemoryConfig,
53        store: Arc<dyn MemoryStore>,
54        embedder: Arc<dyn EmbeddingProvider>,
55    ) -> Self {
56        Self {
57            short_term: RingBuffer::new(config.short_term_size),
58            store,
59            embedder,
60            config,
61        }
62    }
63
64    /// Create with default configuration
65    pub fn with_defaults(
66        store: Arc<dyn MemoryStore>,
67        embedder: Arc<dyn EmbeddingProvider>,
68    ) -> Self {
69        Self::new(MemoryConfig::default(), store, embedder)
70    }
71
72    /// Store a memory
73    pub async fn remember(&mut self, content: &str, memory_type: MemoryType) -> Result<String> {
74        let embedding = self.embedder.embed(content).await?;
75
76        let entry = MemoryEntry::new(content)
77            .with_type(memory_type)
78            .with_embedding(embedding);
79
80        match memory_type {
81            MemoryType::ShortTerm => {
82                self.short_term.push(entry.clone());
83                Ok(entry.id)
84            }
85            _ => {
86                if self.config.enable_long_term {
87                    self.store.add(entry).await
88                } else {
89                    self.short_term.push(entry.clone());
90                    Ok(entry.id)
91                }
92            }
93        }
94    }
95
96    /// Store a memory with importance
97    pub async fn remember_important(
98        &mut self,
99        content: &str,
100        memory_type: MemoryType,
101        importance: f32,
102    ) -> Result<String> {
103        let embedding = self.embedder.embed(content).await?;
104
105        let entry = MemoryEntry::new(content)
106            .with_type(memory_type)
107            .with_embedding(embedding)
108            .with_importance(importance);
109
110        match memory_type {
111            MemoryType::ShortTerm => {
112                self.short_term.push(entry.clone());
113                Ok(entry.id)
114            }
115            _ => {
116                if self.config.enable_long_term {
117                    self.store.add(entry).await
118                } else {
119                    self.short_term.push(entry.clone());
120                    Ok(entry.id)
121                }
122            }
123        }
124    }
125
126    /// Recall relevant memories for a query
127    pub async fn recall(&self, query: &str) -> Result<Vec<MemoryEntry>> {
128        let query_embedding = self.embedder.embed(query).await?;
129
130        // Search long-term memory
131        let mut memories = if self.config.enable_long_term {
132            self.store
133                .search_by_embedding(
134                    &query_embedding,
135                    self.config.retrieval_limit,
136                    self.config.relevance_threshold,
137                )
138                .await?
139        } else {
140            Vec::new()
141        };
142
143        // Add short-term memories
144        for entry in self.short_term.iter_recent() {
145            if let Some(ref embedding) = entry.embedding {
146                let similarity = Self::cosine_similarity(&query_embedding, embedding);
147                if similarity >= self.config.relevance_threshold {
148                    memories.push(entry.clone());
149                }
150            }
151        }
152
153        // Sort by relevance score
154        memories.sort_by(|a, b| {
155            b.relevance_score()
156                .partial_cmp(&a.relevance_score())
157                .unwrap_or(std::cmp::Ordering::Equal)
158        });
159
160        // Limit results
161        memories.truncate(self.config.retrieval_limit);
162
163        Ok(memories)
164    }
165
166    /// Build context string from memories
167    pub fn build_context(&self, memories: &[MemoryEntry]) -> String {
168        let mut context = String::new();
169        let mut token_count = 0;
170
171        for memory in memories {
172            let tokens = memory.content.len() / 4; // Approximate token count
173            if token_count + tokens > self.config.max_context_tokens as usize {
174                break;
175            }
176            context.push_str(&memory.content);
177            context.push_str("\n\n");
178            token_count += tokens;
179        }
180
181        context
182    }
183
184    /// Recall and build context in one step
185    pub async fn recall_context(&self, query: &str) -> Result<String> {
186        let memories = self.recall(query).await?;
187        Ok(self.build_context(&memories))
188    }
189
190    /// Get short-term memory buffer
191    pub fn short_term(&self) -> &RingBuffer<MemoryEntry> {
192        &self.short_term
193    }
194
195    /// Get memory store
196    pub fn store(&self) -> &Arc<dyn MemoryStore> {
197        &self.store
198    }
199
200    /// Clear short-term memory
201    pub fn clear_short_term(&mut self) {
202        self.short_term.clear();
203    }
204
205    /// Clear all memories (including long-term)
206    pub async fn clear_all(&mut self) -> Result<()> {
207        self.short_term.clear();
208        self.store.clear().await
209    }
210
211    /// Get total memory count
212    pub async fn count(&self) -> Result<usize> {
213        let long_term_count = self.store.count().await?;
214        Ok(self.short_term.len() + long_term_count)
215    }
216
217    /// Calculate cosine similarity
218    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
219        if a.len() != b.len() || a.is_empty() {
220            return 0.0;
221        }
222
223        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
224        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
225        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
226
227        if norm_a == 0.0 || norm_b == 0.0 {
228            0.0
229        } else {
230            dot / (norm_a * norm_b)
231        }
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::InMemoryStore;
    use crate::memory::embeddings::MockEmbedding;

    /// Build a manager with `config` over a fresh `InMemoryStore` and a
    /// `MockEmbedding::new(128)` embedder.
    fn manager_with(config: MemoryConfig) -> MemoryManager {
        let store = Arc::new(InMemoryStore::new());
        let embedder = Arc::new(MockEmbedding::new(128));
        MemoryManager::new(config, store, embedder)
    }

    #[tokio::test]
    async fn test_remember_and_recall() {
        let store = Arc::new(InMemoryStore::new());
        let embedder = Arc::new(MockEmbedding::new(128));
        let mut manager = MemoryManager::with_defaults(store, embedder);

        // Store a memory
        let id = manager
            .remember("I like Rust programming", MemoryType::LongTerm)
            .await
            .unwrap();
        assert!(!id.is_empty());

        // Recall should find it
        let memories = manager.recall("programming").await.unwrap();
        assert!(!memories.is_empty());
    }

    #[tokio::test]
    async fn test_short_term_memory() {
        let mut manager = manager_with(MemoryConfig::default());

        manager
            .remember("Temporary thought", MemoryType::ShortTerm)
            .await
            .unwrap();

        assert_eq!(manager.short_term().len(), 1);
    }

    // build_context is synchronous, so no async runtime is needed.
    #[test]
    fn test_build_context() {
        let manager = manager_with(MemoryConfig::default());

        let memories = vec![
            MemoryEntry::new("First memory"),
            MemoryEntry::new("Second memory"),
        ];

        let context = manager.build_context(&memories);
        assert!(context.contains("First memory"));
        assert!(context.contains("Second memory"));
    }

    #[test]
    fn test_build_context_respects_token_budget() {
        // "First memory" (12 bytes) costs 3 tokens and exactly fills the budget;
        // "Second memory" (13 bytes, 3 tokens) would exceed it.
        let manager = manager_with(MemoryConfig {
            max_context_tokens: 3,
            ..MemoryConfig::default()
        });

        let memories = vec![
            MemoryEntry::new("First memory"),
            MemoryEntry::new("Second memory"),
        ];

        let context = manager.build_context(&memories);
        assert!(context.contains("First memory"));
        assert!(!context.contains("Second memory"));
    }

    #[tokio::test]
    async fn test_clear_short_term() {
        let mut manager = manager_with(MemoryConfig::default());

        manager
            .remember("Test", MemoryType::ShortTerm)
            .await
            .unwrap();

        manager.clear_short_term();
        assert!(manager.short_term().is_empty());
    }

    #[tokio::test]
    async fn test_long_term_falls_back_to_short_term_when_disabled() {
        let mut manager = manager_with(MemoryConfig {
            enable_long_term: false,
            ..MemoryConfig::default()
        });

        let id = manager
            .remember_important("Important fact", MemoryType::LongTerm, 0.9)
            .await
            .unwrap();
        assert!(!id.is_empty());

        // With long-term storage disabled the entry must land in the buffer.
        assert_eq!(manager.short_term().len(), 1);
        assert_eq!(manager.count().await.unwrap(), 1);
    }
}