Skip to main content

ai_agents_memory/
compacting.rs

1//! CompactingMemory implementation with auto-summarization
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8
9use ai_agents_core::{ChatMessage, MemorySnapshot, Result};
10
11use super::Memory;
12use super::context::{CompressResult, ConversationContext, estimate_tokens};
13use super::summarizer::Summarizer;
14
15pub struct CompactingMemory {
16    summary: RwLock<Option<String>>,
17    messages: RwLock<Vec<ChatMessage>>,
18    summarized_count: RwLock<usize>,
19    config: CompactingMemoryConfig,
20    summarizer: Arc<dyn Summarizer>,
21    compression_history: RwLock<Vec<CompressionEvent>>,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct CompactingMemoryConfig {
26    // This field is never used: currently only `compress_threshold` and `summary_batch_size` control the compression behavior.
27    // FIXME: implement for this one later.
28    #[serde(default = "default_max_recent_messages")]
29    pub max_recent_messages: usize,
30
31    #[serde(default = "default_compress_threshold")]
32    pub compress_threshold: usize,
33
34    #[serde(default = "default_summarize_batch_size")]
35    pub summarize_batch_size: usize,
36
37    // FIXME: unlimited length as default value?
38    #[serde(default = "default_max_summary_length")]
39    pub max_summary_length: usize,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct CompressionEvent {
44    pub timestamp: chrono::DateTime<chrono::Utc>,
45    pub messages_compressed: usize,
46    pub summary_length_before: usize,
47    pub summary_length_after: usize,
48}
49
/// Serde/`Default` fallback for `CompactingMemoryConfig::max_recent_messages`.
fn default_max_recent_messages() -> usize {
    const DEFAULT_MAX_RECENT_MESSAGES: usize = 50;
    DEFAULT_MAX_RECENT_MESSAGES
}
53
/// Serde/`Default` fallback for `CompactingMemoryConfig::compress_threshold`.
fn default_compress_threshold() -> usize {
    const DEFAULT_COMPRESS_THRESHOLD: usize = 30;
    DEFAULT_COMPRESS_THRESHOLD
}
57
/// Serde/`Default` fallback for `CompactingMemoryConfig::summarize_batch_size`.
fn default_summarize_batch_size() -> usize {
    const DEFAULT_SUMMARIZE_BATCH_SIZE: usize = 10;
    DEFAULT_SUMMARIZE_BATCH_SIZE
}
61
/// Serde/`Default` fallback for `CompactingMemoryConfig::max_summary_length`.
fn default_max_summary_length() -> usize {
    const DEFAULT_MAX_SUMMARY_LENGTH: usize = 2000;
    DEFAULT_MAX_SUMMARY_LENGTH
}
65
66impl Default for CompactingMemoryConfig {
67    fn default() -> Self {
68        Self {
69            max_recent_messages: default_max_recent_messages(),
70            compress_threshold: default_compress_threshold(),
71            summarize_batch_size: default_summarize_batch_size(),
72            max_summary_length: default_max_summary_length(),
73        }
74    }
75}
76
77impl CompactingMemory {
78    pub fn new(summarizer: Arc<dyn Summarizer>, config: CompactingMemoryConfig) -> Self {
79        Self {
80            summary: RwLock::new(None),
81            messages: RwLock::new(Vec::new()),
82            summarized_count: RwLock::new(0),
83            config,
84            summarizer,
85            compression_history: RwLock::new(Vec::new()),
86        }
87    }
88
89    pub fn with_default_config(summarizer: Arc<dyn Summarizer>) -> Self {
90        Self::new(summarizer, CompactingMemoryConfig::default())
91    }
92
93    pub fn config(&self) -> &CompactingMemoryConfig {
94        &self.config
95    }
96
97    pub fn summary(&self) -> Option<String> {
98        self.summary.read().clone()
99    }
100
101    pub fn summarized_count(&self) -> usize {
102        *self.summarized_count.read()
103    }
104
105    pub fn compression_history(&self) -> Vec<CompressionEvent> {
106        self.compression_history.read().clone()
107    }
108
109    fn record_compression(&self, messages_compressed: usize, before: usize, after: usize) {
110        let event = CompressionEvent {
111            timestamp: chrono::Utc::now(),
112            messages_compressed,
113            summary_length_before: before,
114            summary_length_after: after,
115        };
116        self.compression_history.write().push(event);
117    }
118}
119
120#[async_trait]
121impl ai_agents_core::Memory for CompactingMemory {
122    async fn add_message(&self, message: ChatMessage) -> Result<()> {
123        self.messages.write().push(message);
124        Ok(())
125    }
126
127    async fn get_messages(&self, limit: Option<usize>) -> Result<Vec<ChatMessage>> {
128        let messages = self.messages.read();
129        match limit {
130            Some(n) => {
131                let start = messages.len().saturating_sub(n);
132                Ok(messages[start..].to_vec())
133            }
134            None => Ok(messages.clone()),
135        }
136    }
137
138    async fn clear(&self) -> Result<()> {
139        *self.summary.write() = None;
140        self.messages.write().clear();
141        *self.summarized_count.write() = 0;
142        self.compression_history.write().clear();
143        Ok(())
144    }
145
146    fn len(&self) -> usize {
147        self.messages.read().len()
148    }
149
150    async fn snapshot(&self) -> Result<MemorySnapshot> {
151        let messages = self.messages.read().clone();
152        let summary = self.summary.read().clone();
153
154        let mut snapshot = MemorySnapshot::new(messages);
155        if let Some(s) = summary {
156            snapshot = snapshot.with_summary(s);
157        }
158        Ok(snapshot)
159    }
160
161    async fn restore(&self, snapshot: MemorySnapshot) -> Result<()> {
162        *self.messages.write() = snapshot.messages;
163        *self.summary.write() = snapshot.summary;
164        *self.summarized_count.write() = 0;
165        self.compression_history.write().clear();
166        Ok(())
167    }
168
169    async fn evict_oldest(&self, count: usize) -> Result<Vec<ChatMessage>> {
170        let mut messages = self.messages.write();
171        let evict_count = count.min(messages.len());
172        let evicted: Vec<ChatMessage> = messages.drain(..evict_count).collect();
173        Ok(evicted)
174    }
175}
176
177#[async_trait]
178impl Memory for CompactingMemory {
179    async fn get_context(&self) -> Result<ConversationContext> {
180        let messages = self.messages.read().clone();
181        let summary = self.summary.read().clone();
182        let summarized_count = *self.summarized_count.read();
183        let total_messages = messages.len() + summarized_count;
184
185        let mut ctx = ConversationContext::with_messages(messages);
186        ctx.total_messages = total_messages;
187
188        if let Some(s) = summary {
189            ctx = ctx.with_summary(s, summarized_count);
190        }
191
192        Ok(ctx)
193    }
194
195    async fn compress(&self, summarizer: Option<&dyn Summarizer>) -> Result<CompressResult> {
196        let message_count = self.messages.read().len();
197
198        if message_count < self.config.compress_threshold {
199            return Ok(CompressResult::NotNeeded);
200        }
201
202        let summarizer = summarizer.unwrap_or(self.summarizer.as_ref());
203        let batch_size = self.config.summarize_batch_size.min(message_count);
204
205        let messages_to_summarize: Vec<ChatMessage> = {
206            let messages = self.messages.read();
207            messages[..batch_size].to_vec()
208        };
209
210        let new_summary = summarizer.summarize(&messages_to_summarize).await?;
211
212        let summary_before_len = self.summary.read().as_ref().map(|s| s.len()).unwrap_or(0);
213
214        let existing_summary = self.summary.read().clone();
215        let combined_summary = match existing_summary {
216            Some(existing) => summarizer.merge_summaries(&[existing, new_summary]).await?,
217            None => new_summary,
218        };
219
220        let final_summary = if combined_summary.len() > self.config.max_summary_length {
221            combined_summary[..self.config.max_summary_length].to_string()
222        } else {
223            combined_summary
224        };
225
226        let summary_after_len = final_summary.len();
227
228        {
229            let mut messages = self.messages.write();
230            messages.drain(..batch_size);
231        }
232
233        *self.summary.write() = Some(final_summary.clone());
234        *self.summarized_count.write() += batch_size;
235
236        self.record_compression(batch_size, summary_before_len, summary_after_len);
237
238        let tokens_before: u32 = messages_to_summarize
239            .iter()
240            .map(|m| estimate_tokens(&m.content))
241            .sum();
242        let tokens_after = estimate_tokens(&final_summary);
243        let tokens_saved = tokens_before.saturating_sub(tokens_after);
244
245        Ok(CompressResult::Compressed {
246            messages_summarized: batch_size,
247            new_summary_length: summary_after_len,
248            tokens_saved,
249        })
250    }
251
252    fn needs_compression(&self) -> bool {
253        self.messages.read().len() >= self.config.compress_threshold
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::summarizer::NoopSummarizer;
    use ai_agents_core::{Memory as CoreMemory, Role};

    /// Builds a user message with the given text and no name or timestamp.
    fn make_message(content: &str) -> ChatMessage {
        ChatMessage {
            role: Role::User,
            content: content.to_string(),
            name: None,
            timestamp: None,
        }
    }

    /// Memory with small limits (threshold 5, batch 3) so tests can trigger
    /// compression with only a handful of messages.
    fn create_test_memory() -> CompactingMemory {
        let config = CompactingMemoryConfig {
            max_recent_messages: 10,
            compress_threshold: 5,
            summarize_batch_size: 3,
            max_summary_length: 500,
        };
        CompactingMemory::new(Arc::new(NoopSummarizer), config)
    }

    /// Adds `count` messages whose text is `prefix` followed by the index
    /// (`<prefix>0`, `<prefix>1`, ...).
    async fn add_numbered(memory: &CompactingMemory, prefix: &str, count: usize) {
        for i in 0..count {
            let text = format!("{}{}", prefix, i);
            memory.add_message(make_message(&text)).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_basic_add_and_get() {
        let memory = create_test_memory();

        for text in ["Hello", "World"] {
            memory.add_message(make_message(text)).await.unwrap();
        }

        let messages = memory.get_messages(None).await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, "Hello");
        assert_eq!(messages[1].content, "World");
    }

    #[tokio::test]
    async fn test_get_messages_with_limit() {
        let memory = create_test_memory();
        add_numbered(&memory, "msg", 5).await;

        // Only the two most recent messages come back, still oldest first.
        let messages = memory.get_messages(Some(2)).await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, "msg3");
        assert_eq!(messages[1].content, "msg4");
    }

    #[tokio::test]
    async fn test_clear() {
        let memory = create_test_memory();

        memory.add_message(make_message("test")).await.unwrap();
        assert!(!memory.is_empty());

        memory.clear().await.unwrap();
        assert!(memory.is_empty());
        assert!(memory.summary().is_none());
    }

    #[tokio::test]
    async fn test_needs_compression() {
        let memory = create_test_memory();

        // One below the threshold of 5.
        add_numbered(&memory, "msg", 4).await;
        assert!(!memory.needs_compression());

        // The fifth message reaches the threshold.
        memory.add_message(make_message("msg4")).await.unwrap();
        assert!(memory.needs_compression());
    }

    #[tokio::test]
    async fn test_compress_not_needed() {
        let memory = create_test_memory();

        for text in ["msg1", "msg2"] {
            memory.add_message(make_message(text)).await.unwrap();
        }

        let result = memory.compress(None).await.unwrap();
        assert!(matches!(result, CompressResult::NotNeeded));
    }

    #[tokio::test]
    async fn test_compress_when_needed() {
        let memory = create_test_memory();
        add_numbered(&memory, "message number ", 6).await;

        assert!(memory.needs_compression());

        let result = memory.compress(None).await.unwrap();

        match result {
            CompressResult::Compressed {
                messages_summarized,
                ..
            } => assert_eq!(messages_summarized, 3),
            _ => panic!("Expected Compressed result"),
        }

        // Three of the six messages were folded into the summary.
        assert_eq!(memory.len(), 3);
        assert!(memory.summary().is_some());
        assert_eq!(memory.summarized_count(), 3);
    }

    #[tokio::test]
    async fn test_get_context() {
        let memory = create_test_memory();
        add_numbered(&memory, "msg", 6).await;

        memory.compress(None).await.unwrap();

        let ctx = memory.get_context().await.unwrap();
        assert!(ctx.summary.is_some());
        assert_eq!(ctx.messages.len(), 3);
        assert_eq!(ctx.summarized_count, 3);
    }

    #[tokio::test]
    async fn test_snapshot_restore() {
        let memory = create_test_memory();

        for text in ["msg1", "msg2"] {
            memory.add_message(make_message(text)).await.unwrap();
        }

        let snapshot = memory.snapshot().await.unwrap();
        assert_eq!(snapshot.messages.len(), 2);

        memory.clear().await.unwrap();
        assert!(memory.is_empty());

        memory.restore(snapshot).await.unwrap();
        let messages = memory.get_messages(None).await.unwrap();
        assert_eq!(messages.len(), 2);
    }

    #[tokio::test]
    async fn test_compression_history() {
        let memory = create_test_memory();
        add_numbered(&memory, "msg", 6).await;

        memory.compress(None).await.unwrap();

        let history = memory.compression_history();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].messages_compressed, 3);
    }

    #[test]
    fn test_config_default() {
        let CompactingMemoryConfig {
            max_recent_messages,
            compress_threshold,
            summarize_batch_size,
            max_summary_length,
        } = CompactingMemoryConfig::default();

        assert_eq!(max_recent_messages, 50);
        assert_eq!(compress_threshold, 30);
        assert_eq!(summarize_batch_size, 10);
        assert_eq!(max_summary_length, 2000);
    }

    #[tokio::test]
    async fn test_evict_oldest() {
        let memory = create_test_memory();
        add_numbered(&memory, "msg", 5).await;

        let evicted = memory.evict_oldest(2).await.unwrap();
        assert_eq!(evicted.len(), 2);
        assert_eq!(evicted[0].content, "msg0");
        assert_eq!(evicted[1].content, "msg1");

        let remaining = memory.get_messages(None).await.unwrap();
        assert_eq!(remaining.len(), 3);
        assert_eq!(remaining[0].content, "msg2");
    }
}