agent_io/memory/
manager.rs1use std::sync::Arc;
4
5use super::buffer::RingBuffer;
6use super::embeddings::EmbeddingProvider;
7use super::entry::{MemoryEntry, MemoryType};
8use super::store::MemoryStore;
9use crate::Result;
10
/// Tuning parameters for [`MemoryManager`].
///
/// Build one with [`Default::default`] and override individual fields with struct-update
/// syntax, e.g. `MemoryConfig { retrieval_limit: 10, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct MemoryConfig {
    /// Size passed to `RingBuffer::new` for the short-term buffer (presumably its capacity).
    pub short_term_size: usize,
    /// When `true`, memories whose type is not `MemoryType::ShortTerm` are written to the
    /// long-term store and `recall` searches that store; when `false`, every memory is kept
    /// in the short-term buffer and the store is not searched.
    pub enable_long_term: bool,
    /// Maximum number of entries returned by a recall; also passed to the store search.
    pub retrieval_limit: usize,
    /// Minimum cosine similarity a short-term entry needs to be recalled; also passed to the
    /// store search (whose exact threshold semantics are defined by the store).
    pub relevance_threshold: f32,
    /// Token budget used when building a context string; tokens are estimated as the
    /// content's byte length divided by 4.
    pub max_context_tokens: u64,
    /// NOTE(review): `MemoryManager` in this module never reads this field — confirm where
    /// (or whether) importance decay is actually applied.
    pub importance_decay: f32,
}

impl Default for MemoryConfig {
    /// Returns the stock configuration: `short_term_size` 20, long-term storage enabled,
    /// up to 5 recalled memories at relevance >= 0.7, a 2000-token context budget, and an
    /// importance decay factor of 0.95.
    fn default() -> Self {
        Self {
            short_term_size: 20,
            enable_long_term: true,
            retrieval_limit: 5,
            relevance_threshold: 0.7,
            max_context_tokens: 2000,
            importance_decay: 0.95,
        }
    }
}

41pub struct MemoryManager {
43 config: MemoryConfig,
44 short_term: RingBuffer<MemoryEntry>,
45 store: Arc<dyn MemoryStore>,
46 embedder: Arc<dyn EmbeddingProvider>,
47}
48
49impl MemoryManager {
50 pub fn new(
52 config: MemoryConfig,
53 store: Arc<dyn MemoryStore>,
54 embedder: Arc<dyn EmbeddingProvider>,
55 ) -> Self {
56 Self {
57 short_term: RingBuffer::new(config.short_term_size),
58 store,
59 embedder,
60 config,
61 }
62 }
63
64 pub fn with_defaults(
66 store: Arc<dyn MemoryStore>,
67 embedder: Arc<dyn EmbeddingProvider>,
68 ) -> Self {
69 Self::new(MemoryConfig::default(), store, embedder)
70 }
71
72 pub async fn remember(&mut self, content: &str, memory_type: MemoryType) -> Result<String> {
74 let embedding = self.embedder.embed(content).await?;
75
76 let entry = MemoryEntry::new(content)
77 .with_type(memory_type)
78 .with_embedding(embedding);
79
80 match memory_type {
81 MemoryType::ShortTerm => {
82 self.short_term.push(entry.clone());
83 Ok(entry.id)
84 }
85 _ => {
86 if self.config.enable_long_term {
87 self.store.add(entry).await
88 } else {
89 self.short_term.push(entry.clone());
90 Ok(entry.id)
91 }
92 }
93 }
94 }
95
96 pub async fn remember_important(
98 &mut self,
99 content: &str,
100 memory_type: MemoryType,
101 importance: f32,
102 ) -> Result<String> {
103 let embedding = self.embedder.embed(content).await?;
104
105 let entry = MemoryEntry::new(content)
106 .with_type(memory_type)
107 .with_embedding(embedding)
108 .with_importance(importance);
109
110 match memory_type {
111 MemoryType::ShortTerm => {
112 self.short_term.push(entry.clone());
113 Ok(entry.id)
114 }
115 _ => {
116 if self.config.enable_long_term {
117 self.store.add(entry).await
118 } else {
119 self.short_term.push(entry.clone());
120 Ok(entry.id)
121 }
122 }
123 }
124 }
125
126 pub async fn recall(&self, query: &str) -> Result<Vec<MemoryEntry>> {
128 let query_embedding = self.embedder.embed(query).await?;
129
130 let mut memories = if self.config.enable_long_term {
132 self.store
133 .search_by_embedding(
134 &query_embedding,
135 self.config.retrieval_limit,
136 self.config.relevance_threshold,
137 )
138 .await?
139 } else {
140 Vec::new()
141 };
142
143 for entry in self.short_term.iter_recent() {
145 if let Some(ref embedding) = entry.embedding {
146 let similarity = Self::cosine_similarity(&query_embedding, embedding);
147 if similarity >= self.config.relevance_threshold {
148 memories.push(entry.clone());
149 }
150 }
151 }
152
153 memories.sort_by(|a, b| {
155 b.relevance_score()
156 .partial_cmp(&a.relevance_score())
157 .unwrap_or(std::cmp::Ordering::Equal)
158 });
159
160 memories.truncate(self.config.retrieval_limit);
162
163 Ok(memories)
164 }
165
166 pub fn build_context(&self, memories: &[MemoryEntry]) -> String {
168 let mut context = String::new();
169 let mut token_count = 0;
170
171 for memory in memories {
172 let tokens = memory.content.len() / 4; if token_count + tokens > self.config.max_context_tokens as usize {
174 break;
175 }
176 context.push_str(&memory.content);
177 context.push_str("\n\n");
178 token_count += tokens;
179 }
180
181 context
182 }
183
184 pub async fn recall_context(&self, query: &str) -> Result<String> {
186 let memories = self.recall(query).await?;
187 Ok(self.build_context(&memories))
188 }
189
190 pub fn short_term(&self) -> &RingBuffer<MemoryEntry> {
192 &self.short_term
193 }
194
195 pub fn store(&self) -> &Arc<dyn MemoryStore> {
197 &self.store
198 }
199
200 pub fn clear_short_term(&mut self) {
202 self.short_term.clear();
203 }
204
205 pub async fn clear_all(&mut self) -> Result<()> {
207 self.short_term.clear();
208 self.store.clear().await
209 }
210
211 pub async fn count(&self) -> Result<usize> {
213 let long_term_count = self.store.count().await?;
214 Ok(self.short_term.len() + long_term_count)
215 }
216
217 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
219 if a.len() != b.len() || a.is_empty() {
220 return 0.0;
221 }
222
223 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
224 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
225 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
226
227 if norm_a == 0.0 || norm_b == 0.0 {
228 0.0
229 } else {
230 dot / (norm_a * norm_b)
231 }
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::embeddings::MockEmbedding;
    use crate::memory::InMemoryStore;

    /// Fresh manager with the default config, an empty in-memory store and a mock embedder.
    fn fresh_manager() -> MemoryManager {
        let backing_store = Arc::new(InMemoryStore::new());
        let mock_embedder = Arc::new(MockEmbedding::new(128));
        MemoryManager::with_defaults(backing_store, mock_embedder)
    }

    #[tokio::test]
    async fn test_remember_and_recall() {
        let mut manager = fresh_manager();

        let stored_id = manager
            .remember("I like Rust programming", MemoryType::LongTerm)
            .await
            .unwrap();
        assert!(!stored_id.is_empty());

        let recalled = manager.recall("programming").await.unwrap();
        assert!(!recalled.is_empty());
    }

    #[tokio::test]
    async fn test_short_term_memory() {
        let mut manager = fresh_manager();

        manager
            .remember("Temporary thought", MemoryType::ShortTerm)
            .await
            .unwrap();

        assert_eq!(manager.short_term().len(), 1);
    }

    #[tokio::test]
    async fn test_build_context() {
        let manager = fresh_manager();
        let entries = [
            MemoryEntry::new("First memory"),
            MemoryEntry::new("Second memory"),
        ];

        let rendered = manager.build_context(&entries);
        for expected in ["First memory", "Second memory"] {
            assert!(rendered.contains(expected));
        }
    }

    #[tokio::test]
    async fn test_clear() {
        let mut manager = fresh_manager();

        manager
            .remember("Test", MemoryType::ShortTerm)
            .await
            .unwrap();

        manager.clear_short_term();
        assert!(manager.short_term().is_empty());
    }
}