1use a3s_memory::{MemoryItem, MemoryStore, MemoryType, RelevanceConfig};
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use std::collections::VecDeque;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
22#[serde(rename_all = "camelCase")]
23pub struct MemoryConfig {
24 #[serde(default)]
26 pub relevance: RelevanceConfig,
27 #[serde(default = "MemoryConfig::default_max_short_term")]
29 pub max_short_term: usize,
30 #[serde(default = "MemoryConfig::default_max_working")]
32 pub max_working: usize,
33}
34
35impl MemoryConfig {
36 fn default_max_short_term() -> usize {
37 100
38 }
39 fn default_max_working() -> usize {
40 10
41 }
42}
43
44impl Default for MemoryConfig {
45 fn default() -> Self {
46 Self {
47 relevance: RelevanceConfig::default(),
48 max_short_term: 100,
49 max_working: 10,
50 }
51 }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct MemoryStats {
61 pub long_term_count: usize,
62 pub short_term_count: usize,
63 pub working_count: usize,
64}
65
66#[derive(Clone)]
72pub struct AgentMemory {
73 pub(crate) store: Arc<dyn MemoryStore>,
75 short_term: Arc<RwLock<VecDeque<MemoryItem>>>,
77 working: Arc<RwLock<Vec<MemoryItem>>>,
79 pub(crate) max_short_term: usize,
80 pub(crate) max_working: usize,
81 pub(crate) relevance_config: RelevanceConfig,
82}
83
84impl std::fmt::Debug for AgentMemory {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 f.debug_struct("AgentMemory")
87 .field("max_short_term", &self.max_short_term)
88 .field("max_working", &self.max_working)
89 .finish()
90 }
91}
92
93impl AgentMemory {
94 pub fn new(store: Arc<dyn MemoryStore>) -> Self {
96 Self::with_config(store, MemoryConfig::default())
97 }
98
99 pub fn with_config(store: Arc<dyn MemoryStore>, config: MemoryConfig) -> Self {
101 Self {
102 store,
103 short_term: Arc::new(RwLock::new(VecDeque::new())),
104 working: Arc::new(RwLock::new(Vec::new())),
105 max_short_term: config.max_short_term,
106 max_working: config.max_working,
107 relevance_config: config.relevance,
108 }
109 }
110
111 pub(crate) fn score(&self, item: &MemoryItem, now: DateTime<Utc>) -> f32 {
112 let age_days = (now - item.timestamp).num_seconds() as f32 / 86400.0;
113 let decay = (-age_days / self.relevance_config.decay_days).exp();
114 item.importance * self.relevance_config.importance_weight
115 + decay * self.relevance_config.recency_weight
116 }
117
118 pub async fn remember(&self, item: MemoryItem) -> anyhow::Result<()> {
120 self.store.store(item.clone()).await?;
121 let mut short_term = self.short_term.write().await;
122 short_term.push_back(item);
123 if short_term.len() > self.max_short_term {
124 short_term.pop_front();
125 }
126 Ok(())
127 }
128
129 pub async fn remember_success(
131 &self,
132 prompt: &str,
133 tools_used: &[String],
134 result: &str,
135 ) -> anyhow::Result<()> {
136 let content = format!(
137 "Success: {}\nTools: {}\nResult: {}",
138 prompt,
139 tools_used.join(", "),
140 result
141 );
142 let item = MemoryItem::new(content)
143 .with_importance(0.8)
144 .with_tag("success")
145 .with_tag("pattern")
146 .with_type(MemoryType::Procedural)
147 .with_metadata("prompt", prompt)
148 .with_metadata("tools", tools_used.join(","));
149 self.remember(item).await
150 }
151
152 pub async fn remember_failure(
154 &self,
155 prompt: &str,
156 error: &str,
157 attempted_tools: &[String],
158 ) -> anyhow::Result<()> {
159 let content = format!(
160 "Failure: {}\nError: {}\nAttempted tools: {}",
161 prompt,
162 error,
163 attempted_tools.join(", ")
164 );
165 let item = MemoryItem::new(content)
166 .with_importance(0.9)
167 .with_tag("failure")
168 .with_tag("avoid")
169 .with_type(MemoryType::Episodic)
170 .with_metadata("prompt", prompt)
171 .with_metadata("error", error);
172 self.remember(item).await
173 }
174
175 pub async fn recall_similar(
177 &self,
178 prompt: &str,
179 limit: usize,
180 ) -> anyhow::Result<Vec<MemoryItem>> {
181 self.store.search(prompt, limit).await
182 }
183
184 pub async fn recall_by_tags(
186 &self,
187 tags: &[String],
188 limit: usize,
189 ) -> anyhow::Result<Vec<MemoryItem>> {
190 self.store.search_by_tags(tags, limit).await
191 }
192
193 pub async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
195 self.store.get_recent(limit).await
196 }
197
198 pub async fn add_to_working(&self, item: MemoryItem) -> anyhow::Result<()> {
200 let mut working = self.working.write().await;
201 working.push(item);
202 if working.len() > self.max_working {
203 let now = Utc::now();
204 working.sort_by(|a, b| {
205 self.score(b, now)
206 .partial_cmp(&self.score(a, now))
207 .unwrap_or(std::cmp::Ordering::Equal)
208 });
209 working.truncate(self.max_working);
210 }
211 Ok(())
212 }
213
214 pub async fn get_working(&self) -> Vec<MemoryItem> {
216 self.working.read().await.clone()
217 }
218
219 pub async fn clear_working(&self) {
221 self.working.write().await.clear();
222 }
223
224 pub async fn get_short_term(&self) -> Vec<MemoryItem> {
226 self.short_term.read().await.iter().cloned().collect()
227 }
228
229 pub async fn clear_short_term(&self) {
231 self.short_term.write().await.clear();
232 }
233
234 pub async fn stats(&self) -> anyhow::Result<MemoryStats> {
236 Ok(MemoryStats {
237 long_term_count: self.store.count().await?,
238 short_term_count: self.short_term.read().await.len(),
239 working_count: self.working.read().await.len(),
240 })
241 }
242
243 pub fn store(&self) -> &Arc<dyn MemoryStore> {
245 &self.store
246 }
247
248 pub async fn working_count(&self) -> usize {
250 self.working.read().await.len()
251 }
252
253 pub async fn short_term_count(&self) -> usize {
255 self.short_term.read().await.len()
256 }
257}
258
259pub struct MemoryContextProvider {
265 memory: AgentMemory,
266}
267
268impl MemoryContextProvider {
269 pub fn new(memory: AgentMemory) -> Self {
270 Self { memory }
271 }
272}
273
274#[async_trait::async_trait]
275impl crate::context::ContextProvider for MemoryContextProvider {
276 fn name(&self) -> &str {
277 "memory"
278 }
279
280 async fn query(
281 &self,
282 query: &crate::context::ContextQuery,
283 ) -> anyhow::Result<crate::context::ContextResult> {
284 let limit = query.max_results.min(5);
285 let items = self.memory.recall_similar(&query.query, limit).await?;
286
287 let mut result = crate::context::ContextResult::new("memory");
288 for item in items {
289 let relevance = item.relevance_score();
290 let token_count = item.content.len() / 4;
291 let context_item = crate::context::ContextItem::new(
292 &item.id,
293 crate::context::ContextType::Memory,
294 &item.content,
295 )
296 .with_relevance(relevance)
297 .with_token_count(token_count)
298 .with_source("memory");
299 result.add_item(context_item);
300 }
301 Ok(result)
302 }
303
304 async fn on_turn_complete(
305 &self,
306 _session_id: &str,
307 prompt: &str,
308 response: &str,
309 ) -> anyhow::Result<()> {
310 self.memory.remember_success(prompt, &[], response).await
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use a3s_memory::InMemoryStore;
    use std::sync::Arc;

    /// Builds an `AgentMemory` over a fresh in-memory store using `config`.
    fn memory_with(config: MemoryConfig) -> AgentMemory {
        AgentMemory::with_config(Arc::new(InMemoryStore::new()), config)
    }

    /// Content strings of `items`, in order, for readable assertions.
    fn contents(items: &[MemoryItem]) -> Vec<String> {
        items.iter().map(|item| item.content.to_string()).collect()
    }

    #[tokio::test]
    async fn test_agent_memory_remember_and_recall() {
        let memory = AgentMemory::new(Arc::new(InMemoryStore::new()));
        memory
            .remember_success("create file", &["write".to_string()], "ok")
            .await
            .unwrap();
        memory
            .remember_failure("delete file", "denied", &["bash".to_string()])
            .await
            .unwrap();

        let results = memory.recall_similar("create", 10).await.unwrap();
        assert!(!results.is_empty());

        let stats = memory.stats().await.unwrap();
        assert_eq!(stats.long_term_count, 2);
        assert_eq!(stats.short_term_count, 2);
    }

    #[tokio::test]
    async fn test_agent_memory_working() {
        let memory = AgentMemory::new(Arc::new(InMemoryStore::new()));
        memory
            .add_to_working(MemoryItem::new("task").with_type(MemoryType::Working))
            .await
            .unwrap();
        assert_eq!(memory.working_count().await, 1);
        memory.clear_working().await;
        assert_eq!(memory.working_count().await, 0);
    }

    #[tokio::test]
    async fn test_agent_memory_working_overflow_trims() {
        // Score purely by importance so the surviving items are deterministic.
        let memory = memory_with(MemoryConfig {
            relevance: RelevanceConfig {
                decay_days: 30.0,
                importance_weight: 1.0,
                recency_weight: 0.0,
            },
            max_working: 3,
            ..Default::default()
        });
        for i in 0..5 {
            memory
                .add_to_working(
                    MemoryItem::new(format!("task {i}")).with_importance(i as f32 * 0.2),
                )
                .await
                .unwrap();
        }
        let working = memory.get_working().await;
        assert_eq!(working.len(), 3);
        // The lowest-importance items are evicted; survivors are ordered best-first.
        assert_eq!(contents(&working), ["task 4", "task 3", "task 2"]);
    }

    #[tokio::test]
    async fn test_agent_memory_recall_by_tags() {
        let memory = AgentMemory::new(Arc::new(InMemoryStore::new()));
        memory
            .remember_success("create file", &["write".to_string()], "ok")
            .await
            .unwrap();
        memory
            .remember_failure("delete file", "denied", &["bash".to_string()])
            .await
            .unwrap();

        let successes = memory
            .recall_by_tags(&["success".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(successes.len(), 1);
        let failures = memory
            .recall_by_tags(&["failure".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(failures.len(), 1);
    }

    #[tokio::test]
    async fn test_agent_memory_short_term_trim() {
        let memory = memory_with(MemoryConfig {
            max_short_term: 3,
            ..Default::default()
        });
        for i in 0..5 {
            memory
                .remember(MemoryItem::new(format!("item {i}")))
                .await
                .unwrap();
        }
        assert_eq!(memory.short_term_count().await, 3);
        // The short-term buffer is FIFO: the two oldest items were evicted.
        assert_eq!(
            contents(&memory.get_short_term().await),
            ["item 2", "item 3", "item 4"]
        );
    }

    #[test]
    fn test_agent_memory_score_uses_config() {
        let config = MemoryConfig {
            relevance: RelevanceConfig {
                decay_days: 7.0,
                importance_weight: 0.9,
                recency_weight: 0.1,
            },
            ..Default::default()
        };
        let memory = memory_with(config);
        let item = MemoryItem::new("Test").with_importance(1.0);
        let score = memory.score(&item, Utc::now());
        assert!(score > 0.95, "Score was {score}");
    }
}