1use std::sync::Arc;
4
5use async_trait::async_trait;
6use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8
9use ai_agents_core::{ChatMessage, MemorySnapshot, Result};
10
11use super::Memory;
12use super::context::{CompressResult, ConversationContext, estimate_tokens};
13use super::summarizer::Summarizer;
14
15pub struct CompactingMemory {
16 summary: RwLock<Option<String>>,
17 messages: RwLock<Vec<ChatMessage>>,
18 summarized_count: RwLock<usize>,
19 config: CompactingMemoryConfig,
20 summarizer: Arc<dyn Summarizer>,
21 compression_history: RwLock<Vec<CompressionEvent>>,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct CompactingMemoryConfig {
26 #[serde(default = "default_max_recent_messages")]
29 pub max_recent_messages: usize,
30
31 #[serde(default = "default_compress_threshold")]
32 pub compress_threshold: usize,
33
34 #[serde(default = "default_summarize_batch_size")]
35 pub summarize_batch_size: usize,
36
37 #[serde(default = "default_max_summary_length")]
39 pub max_summary_length: usize,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct CompressionEvent {
44 pub timestamp: chrono::DateTime<chrono::Utc>,
45 pub messages_compressed: usize,
46 pub summary_length_before: usize,
47 pub summary_length_after: usize,
48}
49
/// Default for `CompactingMemoryConfig::max_recent_messages` (also used by serde).
fn default_max_recent_messages() -> usize {
    const DEFAULT_MAX_RECENT_MESSAGES: usize = 50;
    DEFAULT_MAX_RECENT_MESSAGES
}
53
/// Default for `CompactingMemoryConfig::compress_threshold` (also used by serde).
fn default_compress_threshold() -> usize {
    const DEFAULT_COMPRESS_THRESHOLD: usize = 30;
    DEFAULT_COMPRESS_THRESHOLD
}
57
/// Default for `CompactingMemoryConfig::summarize_batch_size` (also used by serde).
fn default_summarize_batch_size() -> usize {
    const DEFAULT_SUMMARIZE_BATCH_SIZE: usize = 10;
    DEFAULT_SUMMARIZE_BATCH_SIZE
}
61
/// Default for `CompactingMemoryConfig::max_summary_length` (also used by serde).
fn default_max_summary_length() -> usize {
    const DEFAULT_MAX_SUMMARY_LENGTH: usize = 2000;
    DEFAULT_MAX_SUMMARY_LENGTH
}
65
66impl Default for CompactingMemoryConfig {
67 fn default() -> Self {
68 Self {
69 max_recent_messages: default_max_recent_messages(),
70 compress_threshold: default_compress_threshold(),
71 summarize_batch_size: default_summarize_batch_size(),
72 max_summary_length: default_max_summary_length(),
73 }
74 }
75}
76
77impl CompactingMemory {
78 pub fn new(summarizer: Arc<dyn Summarizer>, config: CompactingMemoryConfig) -> Self {
79 Self {
80 summary: RwLock::new(None),
81 messages: RwLock::new(Vec::new()),
82 summarized_count: RwLock::new(0),
83 config,
84 summarizer,
85 compression_history: RwLock::new(Vec::new()),
86 }
87 }
88
89 pub fn with_default_config(summarizer: Arc<dyn Summarizer>) -> Self {
90 Self::new(summarizer, CompactingMemoryConfig::default())
91 }
92
93 pub fn config(&self) -> &CompactingMemoryConfig {
94 &self.config
95 }
96
97 pub fn summary(&self) -> Option<String> {
98 self.summary.read().clone()
99 }
100
101 pub fn summarized_count(&self) -> usize {
102 *self.summarized_count.read()
103 }
104
105 pub fn compression_history(&self) -> Vec<CompressionEvent> {
106 self.compression_history.read().clone()
107 }
108
109 fn record_compression(&self, messages_compressed: usize, before: usize, after: usize) {
110 let event = CompressionEvent {
111 timestamp: chrono::Utc::now(),
112 messages_compressed,
113 summary_length_before: before,
114 summary_length_after: after,
115 };
116 self.compression_history.write().push(event);
117 }
118}
119
120#[async_trait]
121impl ai_agents_core::Memory for CompactingMemory {
122 async fn add_message(&self, message: ChatMessage) -> Result<()> {
123 self.messages.write().push(message);
124 Ok(())
125 }
126
127 async fn get_messages(&self, limit: Option<usize>) -> Result<Vec<ChatMessage>> {
128 let messages = self.messages.read();
129 match limit {
130 Some(n) => {
131 let start = messages.len().saturating_sub(n);
132 Ok(messages[start..].to_vec())
133 }
134 None => Ok(messages.clone()),
135 }
136 }
137
138 async fn clear(&self) -> Result<()> {
139 *self.summary.write() = None;
140 self.messages.write().clear();
141 *self.summarized_count.write() = 0;
142 self.compression_history.write().clear();
143 Ok(())
144 }
145
146 fn len(&self) -> usize {
147 self.messages.read().len()
148 }
149
150 async fn snapshot(&self) -> Result<MemorySnapshot> {
151 let messages = self.messages.read().clone();
152 let summary = self.summary.read().clone();
153
154 let mut snapshot = MemorySnapshot::new(messages);
155 if let Some(s) = summary {
156 snapshot = snapshot.with_summary(s);
157 }
158 Ok(snapshot)
159 }
160
161 async fn restore(&self, snapshot: MemorySnapshot) -> Result<()> {
162 *self.messages.write() = snapshot.messages;
163 *self.summary.write() = snapshot.summary;
164 *self.summarized_count.write() = 0;
165 self.compression_history.write().clear();
166 Ok(())
167 }
168
169 async fn evict_oldest(&self, count: usize) -> Result<Vec<ChatMessage>> {
170 let mut messages = self.messages.write();
171 let evict_count = count.min(messages.len());
172 let evicted: Vec<ChatMessage> = messages.drain(..evict_count).collect();
173 Ok(evicted)
174 }
175}
176
177#[async_trait]
178impl Memory for CompactingMemory {
179 async fn get_context(&self) -> Result<ConversationContext> {
180 let messages = self.messages.read().clone();
181 let summary = self.summary.read().clone();
182 let summarized_count = *self.summarized_count.read();
183 let total_messages = messages.len() + summarized_count;
184
185 let mut ctx = ConversationContext::with_messages(messages);
186 ctx.total_messages = total_messages;
187
188 if let Some(s) = summary {
189 ctx = ctx.with_summary(s, summarized_count);
190 }
191
192 Ok(ctx)
193 }
194
195 async fn compress(&self, summarizer: Option<&dyn Summarizer>) -> Result<CompressResult> {
196 let message_count = self.messages.read().len();
197
198 if message_count < self.config.compress_threshold {
199 return Ok(CompressResult::NotNeeded);
200 }
201
202 let summarizer = summarizer.unwrap_or(self.summarizer.as_ref());
203 let batch_size = self.config.summarize_batch_size.min(message_count);
204
205 let messages_to_summarize: Vec<ChatMessage> = {
206 let messages = self.messages.read();
207 messages[..batch_size].to_vec()
208 };
209
210 let new_summary = summarizer.summarize(&messages_to_summarize).await?;
211
212 let summary_before_len = self.summary.read().as_ref().map(|s| s.len()).unwrap_or(0);
213
214 let existing_summary = self.summary.read().clone();
215 let combined_summary = match existing_summary {
216 Some(existing) => summarizer.merge_summaries(&[existing, new_summary]).await?,
217 None => new_summary,
218 };
219
220 let final_summary = if combined_summary.len() > self.config.max_summary_length {
221 combined_summary[..self.config.max_summary_length].to_string()
222 } else {
223 combined_summary
224 };
225
226 let summary_after_len = final_summary.len();
227
228 {
229 let mut messages = self.messages.write();
230 messages.drain(..batch_size);
231 }
232
233 *self.summary.write() = Some(final_summary.clone());
234 *self.summarized_count.write() += batch_size;
235
236 self.record_compression(batch_size, summary_before_len, summary_after_len);
237
238 let tokens_before: u32 = messages_to_summarize
239 .iter()
240 .map(|m| estimate_tokens(&m.content))
241 .sum();
242 let tokens_after = estimate_tokens(&final_summary);
243 let tokens_saved = tokens_before.saturating_sub(tokens_after);
244
245 Ok(CompressResult::Compressed {
246 messages_summarized: batch_size,
247 new_summary_length: summary_after_len,
248 tokens_saved,
249 })
250 }
251
252 fn needs_compression(&self) -> bool {
253 self.messages.read().len() >= self.config.compress_threshold
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::summarizer::NoopSummarizer;
    use ai_agents_core::{Memory as CoreMemory, Role};

    /// Builds a user message with the given text and no name or timestamp.
    fn make_message(content: &str) -> ChatMessage {
        let content = content.to_string();
        ChatMessage {
            role: Role::User,
            content,
            name: None,
            timestamp: None,
        }
    }

    /// Memory with small limits (threshold 5, batch 3) backed by `NoopSummarizer`.
    fn create_test_memory() -> CompactingMemory {
        CompactingMemory::new(
            Arc::new(NoopSummarizer),
            CompactingMemoryConfig {
                max_recent_messages: 10,
                compress_threshold: 5,
                summarize_batch_size: 3,
                max_summary_length: 500,
            },
        )
    }

    /// Appends `count` messages whose content is `prefix` followed by the index.
    async fn add_numbered(memory: &CompactingMemory, prefix: &str, count: usize) {
        for index in 0..count {
            let text = format!("{}{}", prefix, index);
            memory.add_message(make_message(&text)).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_basic_add_and_get() {
        let memory = create_test_memory();

        for text in ["Hello", "World"] {
            memory.add_message(make_message(text)).await.unwrap();
        }

        let messages = memory.get_messages(None).await.unwrap();
        let contents: Vec<&str> = messages.iter().map(|m| m.content.as_str()).collect();
        assert_eq!(contents, ["Hello", "World"]);
    }

    #[tokio::test]
    async fn test_get_messages_with_limit() {
        let memory = create_test_memory();
        add_numbered(&memory, "msg", 5).await;

        let messages = memory.get_messages(Some(2)).await.unwrap();
        let contents: Vec<&str> = messages.iter().map(|m| m.content.as_str()).collect();
        assert_eq!(contents, ["msg3", "msg4"]);
    }

    #[tokio::test]
    async fn test_clear() {
        let memory = create_test_memory();

        memory.add_message(make_message("test")).await.unwrap();
        assert!(!memory.is_empty());

        memory.clear().await.unwrap();
        assert!(memory.is_empty());
        assert!(memory.summary().is_none());
    }

    #[tokio::test]
    async fn test_needs_compression() {
        let memory = create_test_memory();

        // One below the threshold of 5: no compression yet.
        add_numbered(&memory, "msg", 4).await;
        assert!(!memory.needs_compression());

        // The fifth message reaches the threshold.
        memory.add_message(make_message("msg4")).await.unwrap();
        assert!(memory.needs_compression());
    }

    #[tokio::test]
    async fn test_compress_not_needed() {
        let memory = create_test_memory();

        for text in ["msg1", "msg2"] {
            memory.add_message(make_message(text)).await.unwrap();
        }

        let result = memory.compress(None).await.unwrap();
        assert!(matches!(result, CompressResult::NotNeeded));
    }

    #[tokio::test]
    async fn test_compress_when_needed() {
        let memory = create_test_memory();
        add_numbered(&memory, "message number ", 6).await;

        assert!(memory.needs_compression());

        match memory.compress(None).await.unwrap() {
            CompressResult::Compressed {
                messages_summarized,
                ..
            } => assert_eq!(messages_summarized, 3),
            _ => panic!("Expected Compressed result"),
        }

        assert_eq!(memory.len(), 3);
        assert!(memory.summary().is_some());
        assert_eq!(memory.summarized_count(), 3);
    }

    #[tokio::test]
    async fn test_get_context() {
        let memory = create_test_memory();
        add_numbered(&memory, "msg", 6).await;

        memory.compress(None).await.unwrap();

        let ctx = memory.get_context().await.unwrap();
        assert!(ctx.summary.is_some());
        assert_eq!(ctx.messages.len(), 3);
        assert_eq!(ctx.summarized_count, 3);
    }

    #[tokio::test]
    async fn test_snapshot_restore() {
        let memory = create_test_memory();

        for text in ["msg1", "msg2"] {
            memory.add_message(make_message(text)).await.unwrap();
        }

        let snapshot = memory.snapshot().await.unwrap();
        assert_eq!(snapshot.messages.len(), 2);

        memory.clear().await.unwrap();
        assert!(memory.is_empty());

        memory.restore(snapshot).await.unwrap();
        let restored = memory.get_messages(None).await.unwrap();
        assert_eq!(restored.len(), 2);
    }

    #[tokio::test]
    async fn test_compression_history() {
        let memory = create_test_memory();
        add_numbered(&memory, "msg", 6).await;

        memory.compress(None).await.unwrap();

        let history = memory.compression_history();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].messages_compressed, 3);
    }

    #[test]
    fn test_config_default() {
        let CompactingMemoryConfig {
            max_recent_messages,
            compress_threshold,
            summarize_batch_size,
            max_summary_length,
        } = CompactingMemoryConfig::default();

        assert_eq!(max_recent_messages, 50);
        assert_eq!(compress_threshold, 30);
        assert_eq!(summarize_batch_size, 10);
        assert_eq!(max_summary_length, 2000);
    }

    #[tokio::test]
    async fn test_evict_oldest() {
        let memory = create_test_memory();
        add_numbered(&memory, "msg", 5).await;

        let evicted = memory.evict_oldest(2).await.unwrap();
        assert_eq!(evicted.len(), 2);
        assert_eq!(evicted[0].content, "msg0");
        assert_eq!(evicted[1].content, "msg1");

        let remaining = memory.get_messages(None).await.unwrap();
        assert_eq!(remaining.len(), 3);
        assert_eq!(remaining[0].content, "msg2");
    }
}