1use a3s_memory::{MemoryItem, MemoryStore, MemoryType, PrunePolicy, RelevanceConfig};
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use std::collections::VecDeque;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
/// Configuration for [`AgentMemory`].
///
/// Serialized with camelCase field names. Every field carries a serde default, so a
/// partial (or empty) config object deserializes to the same values that
/// [`MemoryConfig::default`] produces.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MemoryConfig {
    /// Weights and decay horizon used to score working-memory items when the
    /// working set overflows and must be trimmed.
    #[serde(default)]
    pub relevance: RelevanceConfig,
    /// Maximum number of items kept in the short-term buffer (default 100); the
    /// oldest item is evicted first.
    #[serde(default = "MemoryConfig::default_max_short_term")]
    pub max_short_term: usize,
    /// Maximum number of items kept in working memory (default 10); the
    /// lowest-scoring items are evicted first.
    #[serde(default = "MemoryConfig::default_max_working")]
    pub max_working: usize,
    /// When `Some`, [`AgentMemory::with_config`] starts a background task that
    /// periodically prunes the long-term store with this policy. `None` (the
    /// default) disables background pruning.
    #[serde(default)]
    pub prune_policy: Option<PrunePolicy>,
    /// Seconds between background prune runs (default 3600). Only consulted when
    /// `prune_policy` is set.
    #[serde(default = "MemoryConfig::default_prune_interval_secs")]
    pub prune_interval_secs: u64,
}
40
41impl MemoryConfig {
42 fn default_max_short_term() -> usize {
43 100
44 }
45 fn default_max_working() -> usize {
46 10
47 }
48 fn default_prune_interval_secs() -> u64 {
49 3600
50 }
51}
52
53impl Default for MemoryConfig {
54 fn default() -> Self {
55 Self {
56 relevance: RelevanceConfig::default(),
57 max_short_term: 100,
58 max_working: 10,
59 prune_policy: None,
60 prune_interval_secs: 3600,
61 }
62 }
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct MemoryStats {
72 pub long_term_count: usize,
73 pub short_term_count: usize,
74 pub working_count: usize,
75}
76
/// Three-tier agent memory.
///
/// The tiers are:
/// - a long-term [`MemoryStore`];
/// - a bounded short-term buffer that evicts the oldest items first;
/// - a bounded working set that is trimmed by relevance score.
///
/// All mutable state sits behind `Arc`, so clones share the same store and buffers.
#[derive(Clone)]
pub struct AgentMemory {
    /// Long-term store; every item passed to `remember` is written here first.
    pub(crate) store: Arc<dyn MemoryStore>,
    /// Recently remembered items, oldest at the front; capped at `max_short_term`.
    short_term: Arc<RwLock<VecDeque<MemoryItem>>>,
    /// Working-memory items; when longer than `max_working`, trimmed to the
    /// highest-scoring entries.
    working: Arc<RwLock<Vec<MemoryItem>>>,
    /// Capacity of `short_term`.
    pub(crate) max_short_term: usize,
    /// Capacity of `working`.
    pub(crate) max_working: usize,
    /// Scoring parameters used when trimming `working`.
    pub(crate) relevance_config: RelevanceConfig,
}
94
95impl std::fmt::Debug for AgentMemory {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 f.debug_struct("AgentMemory")
98 .field("max_short_term", &self.max_short_term)
99 .field("max_working", &self.max_working)
100 .finish()
101 }
102}
103
104impl AgentMemory {
105 pub fn new(store: Arc<dyn MemoryStore>) -> Self {
107 Self::with_config(store, MemoryConfig::default())
108 }
109
110 pub fn with_config(store: Arc<dyn MemoryStore>, config: MemoryConfig) -> Self {
115 if let Some(policy) = config.prune_policy.clone() {
116 let store_for_task = Arc::clone(&store);
117 let interval_secs = config.prune_interval_secs;
118 tokio::spawn(async move {
119 let mut ticker =
120 tokio::time::interval(std::time::Duration::from_secs(interval_secs));
121 ticker.tick().await; loop {
123 ticker.tick().await;
124 if let Err(e) = store_for_task.prune(&policy).await {
125 tracing::warn!("memory prune failed: {e}");
126 }
127 }
128 });
129 }
130
131 Self {
132 store,
133 short_term: Arc::new(RwLock::new(VecDeque::new())),
134 working: Arc::new(RwLock::new(Vec::new())),
135 max_short_term: config.max_short_term,
136 max_working: config.max_working,
137 relevance_config: config.relevance,
138 }
139 }
140
141 pub(crate) fn score(&self, item: &MemoryItem, now: DateTime<Utc>) -> f32 {
142 let age_days = (now - item.timestamp).num_seconds() as f32 / 86400.0;
143 let decay = (-age_days / self.relevance_config.decay_days).exp();
144 item.importance * self.relevance_config.importance_weight
145 + decay * self.relevance_config.recency_weight
146 }
147
148 pub async fn remember(&self, item: MemoryItem) -> anyhow::Result<()> {
150 self.store.store(item.clone()).await?;
151 let mut short_term = self.short_term.write().await;
152 short_term.push_back(item);
153 if short_term.len() > self.max_short_term {
154 short_term.pop_front();
155 }
156 Ok(())
157 }
158
159 pub async fn remember_success(
161 &self,
162 prompt: &str,
163 tools_used: &[String],
164 result: &str,
165 ) -> anyhow::Result<()> {
166 let content = format!(
167 "Success: {}\nTools: {}\nResult: {}",
168 prompt,
169 tools_used.join(", "),
170 result
171 );
172 let item = MemoryItem::new(content)
173 .with_importance(0.8)
174 .with_tag("success")
175 .with_tag("pattern")
176 .with_type(MemoryType::Procedural)
177 .with_metadata("prompt", prompt)
178 .with_metadata("tools", tools_used.join(","));
179 self.remember(item).await
180 }
181
182 pub async fn remember_failure(
184 &self,
185 prompt: &str,
186 error: &str,
187 attempted_tools: &[String],
188 ) -> anyhow::Result<()> {
189 let content = format!(
190 "Failure: {}\nError: {}\nAttempted tools: {}",
191 prompt,
192 error,
193 attempted_tools.join(", ")
194 );
195 let item = MemoryItem::new(content)
196 .with_importance(0.9)
197 .with_tag("failure")
198 .with_tag("avoid")
199 .with_type(MemoryType::Episodic)
200 .with_metadata("prompt", prompt)
201 .with_metadata("error", error);
202 self.remember(item).await
203 }
204
205 pub async fn recall_similar(
207 &self,
208 prompt: &str,
209 limit: usize,
210 ) -> anyhow::Result<Vec<MemoryItem>> {
211 self.store.search(prompt, limit).await
212 }
213
214 pub async fn recall_by_tags(
216 &self,
217 tags: &[String],
218 limit: usize,
219 ) -> anyhow::Result<Vec<MemoryItem>> {
220 self.store.search_by_tags(tags, limit).await
221 }
222
223 pub async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
225 self.store.get_recent(limit).await
226 }
227
228 pub async fn add_to_working(&self, item: MemoryItem) -> anyhow::Result<()> {
230 let mut working = self.working.write().await;
231 working.push(item);
232 if working.len() > self.max_working {
233 let now = Utc::now();
234 working.sort_by(|a, b| {
235 self.score(b, now)
236 .partial_cmp(&self.score(a, now))
237 .unwrap_or(std::cmp::Ordering::Equal)
238 });
239 working.truncate(self.max_working);
240 }
241 Ok(())
242 }
243
244 pub async fn get_working(&self) -> Vec<MemoryItem> {
246 self.working.read().await.clone()
247 }
248
249 pub async fn clear_working(&self) {
251 self.working.write().await.clear();
252 }
253
254 pub async fn get_short_term(&self) -> Vec<MemoryItem> {
256 self.short_term.read().await.iter().cloned().collect()
257 }
258
259 pub async fn clear_short_term(&self) {
261 self.short_term.write().await.clear();
262 }
263
264 pub async fn stats(&self) -> anyhow::Result<MemoryStats> {
266 Ok(MemoryStats {
267 long_term_count: self.store.count().await?,
268 short_term_count: self.short_term.read().await.len(),
269 working_count: self.working.read().await.len(),
270 })
271 }
272
273 pub fn store(&self) -> &Arc<dyn MemoryStore> {
275 &self.store
276 }
277
278 pub async fn working_count(&self) -> usize {
280 self.working.read().await.len()
281 }
282
283 pub async fn short_term_count(&self) -> usize {
285 self.short_term.read().await.len()
286 }
287}
288
289pub struct MemoryContextProvider {
295 memory: AgentMemory,
296}
297
298impl MemoryContextProvider {
299 pub fn new(memory: AgentMemory) -> Self {
300 Self { memory }
301 }
302}
303
304#[async_trait::async_trait]
305impl crate::context::ContextProvider for MemoryContextProvider {
306 fn name(&self) -> &str {
307 "memory"
308 }
309
310 async fn query(
311 &self,
312 query: &crate::context::ContextQuery,
313 ) -> anyhow::Result<crate::context::ContextResult> {
314 let limit = query.max_results.min(5);
315 let items = self.memory.recall_similar(&query.query, limit).await?;
316
317 let mut result = crate::context::ContextResult::new("memory");
318 for item in items {
319 let relevance = item.relevance_score();
320 let token_count = item.content.len() / 4;
321 let context_item = crate::context::ContextItem::new(
322 &item.id,
323 crate::context::ContextType::Memory,
324 &item.content,
325 )
326 .with_relevance(relevance)
327 .with_token_count(token_count)
328 .with_source("memory");
329 result.add_item(context_item);
330 }
331 Ok(result)
332 }
333
334 async fn on_turn_complete(
335 &self,
336 _session_id: &str,
337 prompt: &str,
338 response: &str,
339 ) -> anyhow::Result<()> {
340 self.memory.remember_success(prompt, &[], response).await
341 }
342}
343
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's imports (Arc, RwLock, VecDeque, Utc,
    // PrunePolicy, RelevanceConfig, MemoryStore, ...).
    use super::*;
    use a3s_memory::InMemoryStore;

    #[tokio::test]
    async fn test_agent_memory_remember_and_recall() {
        let memory = AgentMemory::new(Arc::new(InMemoryStore::new()));
        memory
            .remember_success("create file", &["write".to_string()], "ok")
            .await
            .unwrap();
        memory
            .remember_failure("delete file", "denied", &["bash".to_string()])
            .await
            .unwrap();

        let results = memory.recall_similar("create", 10).await.unwrap();
        assert!(!results.is_empty());

        let stats = memory.stats().await.unwrap();
        assert_eq!(stats.long_term_count, 2);
        assert_eq!(stats.short_term_count, 2);
    }

    #[tokio::test]
    async fn test_agent_memory_working() {
        let memory = AgentMemory::new(Arc::new(InMemoryStore::new()));
        memory
            .add_to_working(MemoryItem::new("task").with_type(MemoryType::Working))
            .await
            .unwrap();
        assert_eq!(memory.working_count().await, 1);
        memory.clear_working().await;
        assert_eq!(memory.working_count().await, 0);
    }

    #[tokio::test]
    async fn test_agent_memory_working_overflow_trims() {
        // Recency is weighted to zero, so eviction depends on importance alone and the
        // surviving items are fully determined.
        let memory = AgentMemory {
            store: Arc::new(InMemoryStore::new()),
            short_term: Arc::new(RwLock::new(VecDeque::new())),
            working: Arc::new(RwLock::new(Vec::new())),
            max_short_term: 100,
            max_working: 3,
            relevance_config: RelevanceConfig {
                decay_days: 7.0,
                importance_weight: 1.0,
                recency_weight: 0.0,
            },
        };
        for i in 0..5 {
            memory
                .add_to_working(
                    MemoryItem::new(format!("task {i}")).with_importance(i as f32 * 0.2),
                )
                .await
                .unwrap();
        }
        let kept: Vec<String> = memory
            .get_working()
            .await
            .iter()
            .map(|item| item.content.to_string())
            .collect();
        assert_eq!(kept, ["task 4", "task 3", "task 2"]);
    }

    #[tokio::test]
    async fn test_agent_memory_recall_by_tags() {
        let memory = AgentMemory::new(Arc::new(InMemoryStore::new()));
        memory
            .remember_success("create file", &["write".to_string()], "ok")
            .await
            .unwrap();
        memory
            .remember_failure("delete file", "denied", &["bash".to_string()])
            .await
            .unwrap();

        let successes = memory
            .recall_by_tags(&["success".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(successes.len(), 1);
        let failures = memory
            .recall_by_tags(&["failure".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(failures.len(), 1);
    }

    #[tokio::test]
    async fn test_agent_memory_short_term_trim() {
        let store = Arc::new(InMemoryStore::new());
        let memory = AgentMemory {
            store,
            short_term: Arc::new(RwLock::new(VecDeque::new())),
            working: Arc::new(RwLock::new(Vec::new())),
            max_short_term: 3,
            max_working: 10,
            relevance_config: RelevanceConfig::default(),
        };
        for i in 0..5 {
            memory
                .remember(MemoryItem::new(format!("item {i}")))
                .await
                .unwrap();
        }
        assert_eq!(memory.short_term_count().await, 3);
    }

    #[tokio::test]
    async fn test_agent_memory_prune_delegates() {
        let store = Arc::new(InMemoryStore::new());
        let memory = AgentMemory::new(store.clone());

        // Backdate an unimportant item so it falls outside the policy's age window.
        let mut old_item = a3s_memory::MemoryItem::new("stale").with_importance(0.2);
        old_item.timestamp = chrono::Utc::now() - chrono::Duration::days(100);
        store.store(old_item).await.unwrap();

        assert_eq!(store.count().await.unwrap(), 1);

        let policy = PrunePolicy {
            max_age_days: 90,
            min_importance_to_keep: 0.5,
            max_items: 0,
        };
        let deleted = memory.store().prune(&policy).await.unwrap();
        assert_eq!(deleted, 1);
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[test]
    fn test_agent_memory_score_uses_config() {
        let config = MemoryConfig {
            relevance: RelevanceConfig {
                decay_days: 7.0,
                importance_weight: 0.9,
                recency_weight: 0.1,
            },
            ..Default::default()
        };
        let memory = AgentMemory::with_config(Arc::new(InMemoryStore::new()), config);
        let item = MemoryItem::new("Test").with_importance(1.0);
        let score = memory.score(&item, Utc::now());
        assert!(score > 0.95, "Score was {score}");
    }

    #[test]
    fn test_memory_config_default_matches_serde_defaults() {
        let config = MemoryConfig::default();
        assert_eq!(config.max_short_term, MemoryConfig::default_max_short_term());
        assert_eq!(config.max_working, MemoryConfig::default_max_working());
        assert_eq!(
            config.prune_interval_secs,
            MemoryConfig::default_prune_interval_secs()
        );
        assert!(config.prune_policy.is_none());
    }
}