1use a3s_memory::{MemoryItem, MemoryStore, MemoryType, PrunePolicy, RelevanceConfig};
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use std::collections::VecDeque;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
22#[serde(rename_all = "camelCase")]
23pub struct MemoryConfig {
24 #[serde(default)]
26 pub relevance: RelevanceConfig,
27 #[serde(default = "MemoryConfig::default_max_short_term")]
29 pub max_short_term: usize,
30 #[serde(default = "MemoryConfig::default_max_working")]
32 pub max_working: usize,
33 #[serde(default)]
35 pub prune_policy: Option<PrunePolicy>,
36 #[serde(default = "MemoryConfig::default_prune_interval_secs")]
38 pub prune_interval_secs: u64,
39}
40
41impl MemoryConfig {
42 fn default_max_short_term() -> usize {
43 100
44 }
45 fn default_max_working() -> usize {
46 10
47 }
48 fn default_prune_interval_secs() -> u64 {
49 3600
50 }
51}
52
53impl Default for MemoryConfig {
54 fn default() -> Self {
55 Self {
56 relevance: RelevanceConfig::default(),
57 max_short_term: 100,
58 max_working: 10,
59 prune_policy: None,
60 prune_interval_secs: 3600,
61 }
62 }
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct MemoryStats {
72 pub long_term_count: usize,
73 pub short_term_count: usize,
74 pub working_count: usize,
75}
76
77#[derive(Clone)]
83pub struct AgentMemory {
84 pub(crate) store: Arc<dyn MemoryStore>,
86 short_term: Arc<RwLock<VecDeque<MemoryItem>>>,
88 working: Arc<RwLock<Vec<MemoryItem>>>,
90 pub(crate) max_short_term: usize,
91 pub(crate) max_working: usize,
92 pub(crate) relevance_config: RelevanceConfig,
93}
94
95impl std::fmt::Debug for AgentMemory {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 f.debug_struct("AgentMemory")
98 .field("max_short_term", &self.max_short_term)
99 .field("max_working", &self.max_working)
100 .finish()
101 }
102}
103
104impl AgentMemory {
105 pub fn new(store: Arc<dyn MemoryStore>) -> Self {
107 Self::with_config(store, MemoryConfig::default())
108 }
109
110 pub fn with_config(store: Arc<dyn MemoryStore>, config: MemoryConfig) -> Self {
115 if let Some(policy) = config.prune_policy.clone() {
116 let store_for_task = Arc::clone(&store);
117 let interval_secs = config.prune_interval_secs;
118 tokio::spawn(async move {
119 let mut ticker =
120 tokio::time::interval(std::time::Duration::from_secs(interval_secs));
121 ticker.tick().await; loop {
123 ticker.tick().await;
124 if let Err(e) = store_for_task.prune(&policy).await {
125 tracing::warn!("memory prune failed: {e}");
126 }
127 }
128 });
129 }
130
131 Self {
132 store,
133 short_term: Arc::new(RwLock::new(VecDeque::new())),
134 working: Arc::new(RwLock::new(Vec::new())),
135 max_short_term: config.max_short_term,
136 max_working: config.max_working,
137 relevance_config: config.relevance,
138 }
139 }
140
141 pub(crate) fn score(&self, item: &MemoryItem, now: DateTime<Utc>) -> f32 {
142 let age_days = (now - item.timestamp).num_seconds() as f32 / 86400.0;
143 let decay = (-age_days / self.relevance_config.decay_days).exp();
144 item.importance * self.relevance_config.importance_weight
145 + decay * self.relevance_config.recency_weight
146 }
147
148 pub async fn remember(&self, item: MemoryItem) -> anyhow::Result<()> {
150 self.store.store(item.clone()).await?;
151 let mut short_term = self.short_term.write().await;
152 short_term.push_back(item);
153 if short_term.len() > self.max_short_term {
154 short_term.pop_front();
155 }
156 Ok(())
157 }
158
159 pub async fn remember_success(
161 &self,
162 prompt: &str,
163 tools_used: &[String],
164 result: &str,
165 ) -> anyhow::Result<()> {
166 let content = format!(
167 "Success: {}\nTools: {}\nResult: {}",
168 prompt,
169 tools_used.join(", "),
170 result
171 );
172 let item = MemoryItem::new(content)
173 .with_importance(0.8)
174 .with_tag("success")
175 .with_tag("pattern")
176 .with_type(MemoryType::Procedural)
177 .with_metadata("prompt", prompt)
178 .with_metadata("tools", tools_used.join(","));
179 self.remember(item).await
180 }
181
182 pub async fn remember_failure(
184 &self,
185 prompt: &str,
186 error: &str,
187 attempted_tools: &[String],
188 ) -> anyhow::Result<()> {
189 let content = format!(
190 "Failure: {}\nError: {}\nAttempted tools: {}",
191 prompt,
192 error,
193 attempted_tools.join(", ")
194 );
195 let item = MemoryItem::new(content)
196 .with_importance(0.9)
197 .with_tag("failure")
198 .with_tag("avoid")
199 .with_type(MemoryType::Episodic)
200 .with_metadata("prompt", prompt)
201 .with_metadata("error", error);
202 self.remember(item).await
203 }
204
205 pub async fn recall_similar(
207 &self,
208 prompt: &str,
209 limit: usize,
210 ) -> anyhow::Result<Vec<MemoryItem>> {
211 self.store.search(prompt, limit).await
212 }
213
214 pub async fn recall_by_tags(
216 &self,
217 tags: &[String],
218 limit: usize,
219 ) -> anyhow::Result<Vec<MemoryItem>> {
220 self.store.search_by_tags(tags, limit).await
221 }
222
223 pub async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
225 self.store.get_recent(limit).await
226 }
227
228 pub async fn add_to_working(&self, item: MemoryItem) -> anyhow::Result<()> {
230 let mut working = self.working.write().await;
231 working.push(item);
232 if working.len() > self.max_working {
233 let now = Utc::now();
234 working.sort_by(|a, b| {
235 self.score(b, now)
236 .partial_cmp(&self.score(a, now))
237 .unwrap_or(std::cmp::Ordering::Equal)
238 });
239 working.truncate(self.max_working);
240 }
241 Ok(())
242 }
243
244 pub async fn get_working(&self) -> Vec<MemoryItem> {
246 self.working.read().await.clone()
247 }
248
249 pub async fn clear_working(&self) {
251 self.working.write().await.clear();
252 }
253
254 pub async fn get_short_term(&self) -> Vec<MemoryItem> {
256 self.short_term.read().await.iter().cloned().collect()
257 }
258
259 pub async fn clear_short_term(&self) {
261 self.short_term.write().await.clear();
262 }
263
264 pub async fn stats(&self) -> anyhow::Result<MemoryStats> {
266 Ok(MemoryStats {
267 long_term_count: self.store.count().await?,
268 short_term_count: self.short_term.read().await.len(),
269 working_count: self.working.read().await.len(),
270 })
271 }
272
273 pub fn store(&self) -> &Arc<dyn MemoryStore> {
275 &self.store
276 }
277
278 pub async fn working_count(&self) -> usize {
280 self.working.read().await.len()
281 }
282
283 pub async fn short_term_count(&self) -> usize {
285 self.short_term.read().await.len()
286 }
287}
288
289pub struct MemoryContextProvider {
295 memory: AgentMemory,
296}
297
298impl MemoryContextProvider {
299 pub fn new(memory: AgentMemory) -> Self {
300 Self { memory }
301 }
302}
303
304pub(crate) fn memory_items_to_context_result(
305 provider: impl Into<String>,
306 items: Vec<MemoryItem>,
307) -> crate::context::ContextResult {
308 let mut result = crate::context::ContextResult::new(provider);
309 for item in items {
310 let token_count = (item.content.len() / 4).max(1);
311 let context_item = crate::context::ContextItem::new(
312 &item.id,
313 crate::context::ContextType::Memory,
314 &item.content,
315 )
316 .with_relevance(item.relevance_score())
317 .with_token_count(token_count)
318 .with_source(format!("memory://{}", item.id))
319 .with_provenance("long_term_memory")
320 .with_priority(0.35)
321 .with_trust(0.7)
322 .with_freshness(0.5);
323 result.add_item(context_item);
324 }
325 result
326}
327
328#[async_trait::async_trait]
329impl crate::context::ContextProvider for MemoryContextProvider {
330 fn name(&self) -> &str {
331 "memory"
332 }
333
334 async fn query(
335 &self,
336 query: &crate::context::ContextQuery,
337 ) -> anyhow::Result<crate::context::ContextResult> {
338 let limit = query.max_results.min(5);
339 let items = self.memory.recall_similar(&query.query, limit).await?;
340
341 Ok(memory_items_to_context_result("memory", items))
342 }
343
344 async fn on_turn_complete(
345 &self,
346 _session_id: &str,
347 prompt: &str,
348 response: &str,
349 ) -> anyhow::Result<()> {
350 self.memory.remember_success(prompt, &[], response).await
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use a3s_memory::InMemoryStore;

    /// Build an `AgentMemory` with explicit limits and weights.
    ///
    /// This bypasses `with_config`, so no background task is involved.
    fn memory_with_limits(
        max_short_term: usize,
        max_working: usize,
        relevance_config: RelevanceConfig,
    ) -> AgentMemory {
        AgentMemory {
            store: Arc::new(InMemoryStore::new()),
            short_term: Arc::new(RwLock::new(VecDeque::new())),
            working: Arc::new(RwLock::new(Vec::new())),
            max_short_term,
            max_working,
            relevance_config,
        }
    }

    #[test]
    fn test_memory_config_default_matches_serde_defaults() {
        let config = MemoryConfig::default();
        assert_eq!(config.max_short_term, MemoryConfig::default_max_short_term());
        assert_eq!(config.max_working, MemoryConfig::default_max_working());
        assert_eq!(
            config.prune_interval_secs,
            MemoryConfig::default_prune_interval_secs()
        );
        assert!(config.prune_policy.is_none());
    }

    #[tokio::test]
    async fn test_agent_memory_remember_and_recall() {
        let memory = AgentMemory::new(Arc::new(InMemoryStore::new()));
        memory
            .remember_success("create file", &["write".to_string()], "ok")
            .await
            .unwrap();
        memory
            .remember_failure("delete file", "denied", &["bash".to_string()])
            .await
            .unwrap();

        let results = memory.recall_similar("create", 10).await.unwrap();
        assert!(!results.is_empty());

        let stats = memory.stats().await.unwrap();
        assert_eq!(stats.long_term_count, 2);
        assert_eq!(stats.short_term_count, 2);
    }

    #[tokio::test]
    async fn test_agent_memory_working() {
        let memory = AgentMemory::new(Arc::new(InMemoryStore::new()));
        memory
            .add_to_working(MemoryItem::new("task").with_type(MemoryType::Working))
            .await
            .unwrap();
        assert_eq!(memory.working_count().await, 1);
        memory.clear_working().await;
        assert_eq!(memory.working_count().await, 0);
    }

    #[tokio::test]
    async fn test_agent_memory_working_overflow_trims() {
        // With recency weight 0, the score is the importance, so the survivors are deterministic.
        let memory = memory_with_limits(
            100,
            3,
            RelevanceConfig {
                decay_days: 30.0,
                importance_weight: 1.0,
                recency_weight: 0.0,
            },
        );
        for i in 0..5 {
            memory
                .add_to_working(
                    MemoryItem::new(format!("task {i}")).with_importance(i as f32 * 0.2),
                )
                .await
                .unwrap();
        }

        let working = memory.get_working().await;
        assert_eq!(working.len(), 3);
        let mut kept: Vec<&str> = working.iter().map(|item| item.content.as_str()).collect();
        kept.sort_unstable();
        assert_eq!(kept, ["task 2", "task 3", "task 4"]);
    }

    #[tokio::test]
    async fn test_agent_memory_recall_by_tags() {
        let memory = AgentMemory::new(Arc::new(InMemoryStore::new()));
        memory
            .remember_success("create file", &["write".to_string()], "ok")
            .await
            .unwrap();
        memory
            .remember_failure("delete file", "denied", &["bash".to_string()])
            .await
            .unwrap();

        let successes = memory
            .recall_by_tags(&["success".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(successes.len(), 1);
        let failures = memory
            .recall_by_tags(&["failure".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(failures.len(), 1);
    }

    #[tokio::test]
    async fn test_agent_memory_short_term_trim() {
        let memory = memory_with_limits(3, 10, RelevanceConfig::default());
        for i in 0..5 {
            memory
                .remember(MemoryItem::new(format!("item {i}")))
                .await
                .unwrap();
        }
        assert_eq!(memory.short_term_count().await, 3);

        // The oldest entries are evicted first, and the remaining order is preserved.
        let items = memory.get_short_term().await;
        let contents: Vec<&str> = items.iter().map(|item| item.content.as_str()).collect();
        assert_eq!(contents, ["item 2", "item 3", "item 4"]);
    }

    #[tokio::test]
    async fn test_agent_memory_prune_delegates() {
        let store = Arc::new(InMemoryStore::new());
        let memory = AgentMemory::new(store.clone());

        let mut old_item = MemoryItem::new("stale").with_importance(0.2);
        old_item.timestamp = Utc::now() - chrono::Duration::days(100);
        store.store(old_item).await.unwrap();

        assert_eq!(store.count().await.unwrap(), 1);

        let policy = PrunePolicy {
            max_age_days: 90,
            min_importance_to_keep: 0.5,
            max_items: 0,
        };
        let deleted = memory.store().prune(&policy).await.unwrap();
        assert_eq!(deleted, 1);
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[test]
    fn test_agent_memory_score_uses_config() {
        let config = MemoryConfig {
            relevance: RelevanceConfig {
                decay_days: 7.0,
                importance_weight: 0.9,
                recency_weight: 0.1,
            },
            ..Default::default()
        };
        let memory = AgentMemory::with_config(Arc::new(InMemoryStore::new()), config);
        let item = MemoryItem::new("Test").with_importance(1.0);
        let score = memory.score(&item, Utc::now());
        assert!(score > 0.95, "Score was {score}");
    }
}