Skip to main content

ai_agents_memory/
compacting.rs

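//! `CompactingMemory`: conversation memory that folds its oldest messages into a rolling
//! summary when `compress` is called (compression is explicit; `add_message` only appends).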
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8
9use ai_agents_core::{ChatMessage, MemorySnapshot, Result};
10
11use super::Memory;
12use super::context::{CompressResult, ConversationContext, estimate_tokens};
13use super::summarizer::Summarizer;
14
/// Returns the longest prefix of `text` that contains at most `max_chars` characters
/// (Unicode scalar values).
///
/// The cut always lands on a `char` boundary, so the slice never panics on multi-byte text.
/// When `text` has `max_chars` or fewer characters the whole string is returned, and a
/// `max_chars` of 0 always yields `""`.
fn prefix_at_char_boundary(text: &str, max_chars: usize) -> &str {
    // The byte offset of the (max_chars)-th char is exactly where the kept prefix ends;
    // if there is no such char, the text is short enough to keep entirely.
    let end = text
        .char_indices()
        .map(|(byte_offset, _)| byte_offset)
        .nth(max_chars)
        .unwrap_or(text.len());
    &text[..end]
}
25
26pub struct CompactingMemory {
27    summary: RwLock<Option<String>>,
28    messages: RwLock<Vec<ChatMessage>>,
29    summarized_count: RwLock<usize>,
30    config: CompactingMemoryConfig,
31    summarizer: Arc<dyn Summarizer>,
32    compression_history: RwLock<Vec<CompressionEvent>>,
33}
34
/// Tuning knobs for `CompactingMemory` compression.
///
/// Every field has a serde default, so any subset of fields may be omitted when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactingMemoryConfig {
    /// Intended cap on the number of retained recent messages (default 50).
    ///
    /// NOTE: not consulted by the compression logic in this file; compression is driven only by
    /// `compress_threshold` and `summarize_batch_size`.
    // FIXME: enforce this limit (or deprecate the field).
    #[serde(default = "default_max_recent_messages")]
    pub max_recent_messages: usize,

    /// Number of pending (un-summarized) messages at which `needs_compression` returns `true`
    /// and `compress` starts doing work (default 30).
    #[serde(default = "default_compress_threshold")]
    pub compress_threshold: usize,

    /// Maximum number of oldest messages folded into the summary per `compress` call
    /// (default 10).
    #[serde(default = "default_summarize_batch_size")]
    pub summarize_batch_size: usize,

    /// Maximum summary length in characters (Unicode scalar values, not bytes); longer
    /// summaries are cut at a character boundary (default 2000).
    // FIXME: decide whether a sentinel (e.g. 0 or an `Option`) should mean "unlimited";
    // today a value of 0 truncates every summary to the empty string.
    #[serde(default = "default_max_summary_length")]
    pub max_summary_length: usize,
}
52
/// Record of one successful `compress` call, kept in the memory's compression history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionEvent {
    /// Time (UTC) at which the compression was recorded.
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Number of messages folded into the summary by this compression.
    pub messages_compressed: usize,
    /// Summary length before this compression, in bytes (`String::len`); 0 if there was none.
    pub summary_length_before: usize,
    /// Summary length after merging and truncation, in bytes (`String::len`).
    ///
    /// `CompactingMemoryConfig::max_summary_length` is counted in characters, so for non-ASCII
    /// summaries this byte count can exceed that limit.
    pub summary_length_after: usize,
}
60
/// Default for `CompactingMemoryConfig::max_recent_messages`.
const DEFAULT_MAX_RECENT_MESSAGES: usize = 50;
/// Default for `CompactingMemoryConfig::compress_threshold`.
const DEFAULT_COMPRESS_THRESHOLD: usize = 30;
/// Default for `CompactingMemoryConfig::summarize_batch_size`.
const DEFAULT_SUMMARIZE_BATCH_SIZE: usize = 10;
/// Default for `CompactingMemoryConfig::max_summary_length` (in characters).
const DEFAULT_MAX_SUMMARY_LENGTH: usize = 2000;

// serde's `default = "..."` attribute needs a function path, so each constant gets a
// thin accessor.

fn default_max_recent_messages() -> usize {
    DEFAULT_MAX_RECENT_MESSAGES
}

fn default_compress_threshold() -> usize {
    DEFAULT_COMPRESS_THRESHOLD
}

fn default_summarize_batch_size() -> usize {
    DEFAULT_SUMMARIZE_BATCH_SIZE
}

fn default_max_summary_length() -> usize {
    DEFAULT_MAX_SUMMARY_LENGTH
}
76
77impl Default for CompactingMemoryConfig {
78    fn default() -> Self {
79        Self {
80            max_recent_messages: default_max_recent_messages(),
81            compress_threshold: default_compress_threshold(),
82            summarize_batch_size: default_summarize_batch_size(),
83            max_summary_length: default_max_summary_length(),
84        }
85    }
86}
87
88impl CompactingMemory {
89    pub fn new(summarizer: Arc<dyn Summarizer>, config: CompactingMemoryConfig) -> Self {
90        Self {
91            summary: RwLock::new(None),
92            messages: RwLock::new(Vec::new()),
93            summarized_count: RwLock::new(0),
94            config,
95            summarizer,
96            compression_history: RwLock::new(Vec::new()),
97        }
98    }
99
100    pub fn with_default_config(summarizer: Arc<dyn Summarizer>) -> Self {
101        Self::new(summarizer, CompactingMemoryConfig::default())
102    }
103
104    pub fn config(&self) -> &CompactingMemoryConfig {
105        &self.config
106    }
107
108    pub fn summary(&self) -> Option<String> {
109        self.summary.read().clone()
110    }
111
112    pub fn summarized_count(&self) -> usize {
113        *self.summarized_count.read()
114    }
115
116    pub fn compression_history(&self) -> Vec<CompressionEvent> {
117        self.compression_history.read().clone()
118    }
119
120    fn record_compression(&self, messages_compressed: usize, before: usize, after: usize) {
121        let event = CompressionEvent {
122            timestamp: chrono::Utc::now(),
123            messages_compressed,
124            summary_length_before: before,
125            summary_length_after: after,
126        };
127        self.compression_history.write().push(event);
128    }
129}
130
131#[async_trait]
132impl ai_agents_core::Memory for CompactingMemory {
133    async fn add_message(&self, message: ChatMessage) -> Result<()> {
134        self.messages.write().push(message);
135        Ok(())
136    }
137
138    async fn get_messages(&self, limit: Option<usize>) -> Result<Vec<ChatMessage>> {
139        let messages = self.messages.read();
140        match limit {
141            Some(n) => {
142                let start = messages.len().saturating_sub(n);
143                Ok(messages[start..].to_vec())
144            }
145            None => Ok(messages.clone()),
146        }
147    }
148
149    async fn clear(&self) -> Result<()> {
150        *self.summary.write() = None;
151        self.messages.write().clear();
152        *self.summarized_count.write() = 0;
153        self.compression_history.write().clear();
154        Ok(())
155    }
156
157    fn len(&self) -> usize {
158        self.messages.read().len()
159    }
160
161    async fn snapshot(&self) -> Result<MemorySnapshot> {
162        let messages = self.messages.read().clone();
163        let summary = self.summary.read().clone();
164
165        let mut snapshot = MemorySnapshot::new(messages);
166        if let Some(s) = summary {
167            snapshot = snapshot.with_summary(s);
168        }
169        Ok(snapshot)
170    }
171
172    async fn restore(&self, snapshot: MemorySnapshot) -> Result<()> {
173        *self.messages.write() = snapshot.messages;
174        *self.summary.write() = snapshot.summary;
175        *self.summarized_count.write() = 0;
176        self.compression_history.write().clear();
177        Ok(())
178    }
179
180    async fn evict_oldest(&self, count: usize) -> Result<Vec<ChatMessage>> {
181        let mut messages = self.messages.write();
182        let evict_count = count.min(messages.len());
183        let evicted: Vec<ChatMessage> = messages.drain(..evict_count).collect();
184        Ok(evicted)
185    }
186}
187
188#[async_trait]
189impl Memory for CompactingMemory {
190    async fn get_context(&self) -> Result<ConversationContext> {
191        let messages = self.messages.read().clone();
192        let summary = self.summary.read().clone();
193        let summarized_count = *self.summarized_count.read();
194        let total_messages = messages.len() + summarized_count;
195
196        let mut ctx = ConversationContext::with_messages(messages);
197        ctx.total_messages = total_messages;
198
199        if let Some(s) = summary {
200            ctx = ctx.with_summary(s, summarized_count);
201        }
202
203        Ok(ctx)
204    }
205
206    async fn compress(&self, summarizer: Option<&dyn Summarizer>) -> Result<CompressResult> {
207        let message_count = self.messages.read().len();
208
209        if message_count < self.config.compress_threshold {
210            return Ok(CompressResult::NotNeeded);
211        }
212
213        let summarizer = summarizer.unwrap_or(self.summarizer.as_ref());
214        let batch_size = self.config.summarize_batch_size.min(message_count);
215
216        let messages_to_summarize: Vec<ChatMessage> = {
217            let messages = self.messages.read();
218            messages[..batch_size].to_vec()
219        };
220
221        let new_summary = summarizer.summarize(&messages_to_summarize).await?;
222
223        let summary_before_len = self.summary.read().as_ref().map(|s| s.len()).unwrap_or(0);
224
225        let existing_summary = self.summary.read().clone();
226        let combined_summary = match existing_summary {
227            Some(existing) => summarizer.merge_summaries(&[existing, new_summary]).await?,
228            None => new_summary,
229        };
230
231        let truncated = prefix_at_char_boundary(&combined_summary, self.config.max_summary_length);
232        let final_summary = if truncated.len() < combined_summary.len() {
233            truncated.to_string()
234        } else {
235            combined_summary
236        };
237
238        let summary_after_len = final_summary.len();
239
240        {
241            let mut messages = self.messages.write();
242            messages.drain(..batch_size);
243        }
244
245        *self.summary.write() = Some(final_summary.clone());
246        *self.summarized_count.write() += batch_size;
247
248        self.record_compression(batch_size, summary_before_len, summary_after_len);
249
250        let tokens_before: u32 = messages_to_summarize
251            .iter()
252            .map(|m| estimate_tokens(&m.content))
253            .sum();
254        let tokens_after = estimate_tokens(&final_summary);
255        let tokens_saved = tokens_before.saturating_sub(tokens_after);
256
257        Ok(CompressResult::Compressed {
258            messages_summarized: batch_size,
259            new_summary_length: summary_after_len,
260            tokens_saved,
261        })
262    }
263
264    fn needs_compression(&self) -> bool {
265        self.messages.read().len() >= self.config.compress_threshold
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::summarizer::NoopSummarizer;
    use ai_agents_core::{Memory as CoreMemory, Role};

    /// Builds a user message with the given text.
    fn make_message(content: &str) -> ChatMessage {
        ChatMessage {
            role: Role::User,
            content: content.to_string(),
            name: None,
            timestamp: None,
        }
    }

    /// Builds a memory backed by `NoopSummarizer` with the given config.
    fn create_memory_with(config: CompactingMemoryConfig) -> CompactingMemory {
        CompactingMemory::new(Arc::new(NoopSummarizer), config)
    }

    /// Small thresholds: compress once 5 messages are pending, folding 3 at a time.
    fn create_test_memory() -> CompactingMemory {
        create_memory_with(CompactingMemoryConfig {
            max_recent_messages: 10,
            compress_threshold: 5,
            summarize_batch_size: 3,
            max_summary_length: 500,
        })
    }

    /// Appends messages `msg{start}` .. `msg{end - 1}`.
    async fn fill_range(memory: &CompactingMemory, start: usize, end: usize) {
        for i in start..end {
            memory
                .add_message(make_message(&format!("msg{}", i)))
                .await
                .unwrap();
        }
    }

    #[tokio::test]
    async fn test_basic_add_and_get() {
        let memory = create_test_memory();

        memory.add_message(make_message("Hello")).await.unwrap();
        memory.add_message(make_message("World")).await.unwrap();

        let messages = memory.get_messages(None).await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, "Hello");
        assert_eq!(messages[1].content, "World");
    }

    #[tokio::test]
    async fn test_get_messages_with_limit() {
        let memory = create_test_memory();
        fill_range(&memory, 0, 5).await;

        let messages = memory.get_messages(Some(2)).await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, "msg3");
        assert_eq!(messages[1].content, "msg4");
    }

    #[tokio::test]
    async fn test_get_messages_limit_edge_cases() {
        let memory = create_test_memory();
        fill_range(&memory, 0, 3).await;

        assert!(memory.get_messages(Some(0)).await.unwrap().is_empty());
        assert_eq!(memory.get_messages(Some(10)).await.unwrap().len(), 3);
    }

    #[tokio::test]
    async fn test_clear() {
        let memory = create_test_memory();

        memory.add_message(make_message("test")).await.unwrap();
        assert!(!memory.is_empty());

        memory.clear().await.unwrap();
        assert!(memory.is_empty());
        assert!(memory.summary().is_none());
    }

    #[tokio::test]
    async fn test_clear_resets_compression_state() {
        let memory = create_test_memory();
        fill_range(&memory, 0, 6).await;
        memory.compress(None).await.unwrap();

        memory.clear().await.unwrap();
        assert!(memory.summary().is_none());
        assert_eq!(memory.summarized_count(), 0);
        assert!(memory.compression_history().is_empty());
    }

    #[tokio::test]
    async fn test_needs_compression() {
        let memory = create_test_memory();
        fill_range(&memory, 0, 4).await;
        assert!(!memory.needs_compression());

        memory.add_message(make_message("msg4")).await.unwrap();
        assert!(memory.needs_compression());
    }

    #[tokio::test]
    async fn test_compress_not_needed() {
        let memory = create_test_memory();

        memory.add_message(make_message("msg1")).await.unwrap();
        memory.add_message(make_message("msg2")).await.unwrap();

        let result = memory.compress(None).await.unwrap();
        assert!(matches!(result, CompressResult::NotNeeded));
    }

    #[tokio::test]
    async fn test_compress_when_needed() {
        let memory = create_test_memory();

        for i in 0..6 {
            memory
                .add_message(make_message(&format!("message number {}", i)))
                .await
                .unwrap();
        }

        assert!(memory.needs_compression());

        let result = memory.compress(None).await.unwrap();

        if let CompressResult::Compressed {
            messages_summarized,
            ..
        } = result
        {
            assert_eq!(messages_summarized, 3);
        } else {
            panic!("Expected Compressed result");
        }

        assert_eq!(memory.len(), 3);
        assert!(memory.summary().is_some());
        assert_eq!(memory.summarized_count(), 3);
    }

    #[tokio::test]
    async fn test_compress_at_exact_threshold() {
        let memory = create_test_memory();
        fill_range(&memory, 0, 5).await;

        let result = memory.compress(None).await.unwrap();
        assert!(matches!(
            result,
            CompressResult::Compressed {
                messages_summarized: 3,
                ..
            }
        ));
        assert_eq!(memory.len(), 2);
    }

    #[tokio::test]
    async fn test_repeated_compression_accumulates() {
        let memory = create_test_memory();
        fill_range(&memory, 0, 6).await;
        memory.compress(None).await.unwrap();

        // Three messages remain, below the threshold of five.
        let result = memory.compress(None).await.unwrap();
        assert!(matches!(result, CompressResult::NotNeeded));

        fill_range(&memory, 6, 9).await;
        memory.compress(None).await.unwrap();

        assert_eq!(memory.summarized_count(), 6);
        assert_eq!(memory.len(), 3);
        assert_eq!(memory.compression_history().len(), 2);

        let remaining = memory.get_messages(None).await.unwrap();
        assert_eq!(remaining[0].content, "msg6");
    }

    #[tokio::test]
    async fn test_summary_respects_max_length() {
        let memory = create_memory_with(CompactingMemoryConfig {
            max_recent_messages: 10,
            compress_threshold: 2,
            summarize_batch_size: 2,
            max_summary_length: 4,
        });
        fill_range(&memory, 0, 2).await;
        memory.compress(None).await.unwrap();

        let summary = memory
            .summary()
            .expect("a successful compression always stores a summary");
        assert!(summary.chars().count() <= 4);
    }

    #[tokio::test]
    async fn test_get_context() {
        let memory = create_test_memory();
        fill_range(&memory, 0, 6).await;

        memory.compress(None).await.unwrap();

        let ctx = memory.get_context().await.unwrap();
        assert!(ctx.summary.is_some());
        assert_eq!(ctx.messages.len(), 3);
        assert_eq!(ctx.summarized_count, 3);
    }

    #[tokio::test]
    async fn test_snapshot_restore() {
        let memory = create_test_memory();

        memory.add_message(make_message("msg1")).await.unwrap();
        memory.add_message(make_message("msg2")).await.unwrap();

        let snapshot = memory.snapshot().await.unwrap();
        assert_eq!(snapshot.messages.len(), 2);

        memory.clear().await.unwrap();
        assert!(memory.is_empty());

        memory.restore(snapshot).await.unwrap();
        let messages = memory.get_messages(None).await.unwrap();
        assert_eq!(messages.len(), 2);
    }

    #[tokio::test]
    async fn test_snapshot_restore_keeps_summary() {
        let memory = create_test_memory();
        fill_range(&memory, 0, 6).await;
        memory.compress(None).await.unwrap();
        let summary_before = memory.summary();

        let snapshot = memory.snapshot().await.unwrap();
        memory.clear().await.unwrap();
        memory.restore(snapshot).await.unwrap();

        assert_eq!(memory.summary(), summary_before);
        assert_eq!(memory.len(), 3);
    }

    #[tokio::test]
    async fn test_compression_history() {
        let memory = create_test_memory();
        fill_range(&memory, 0, 6).await;

        memory.compress(None).await.unwrap();

        let history = memory.compression_history();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].messages_compressed, 3);
    }

    #[test]
    fn test_config_default() {
        let config = CompactingMemoryConfig::default();
        assert_eq!(config.max_recent_messages, 50);
        assert_eq!(config.compress_threshold, 30);
        assert_eq!(config.summarize_batch_size, 10);
        assert_eq!(config.max_summary_length, 2000);
    }

    #[tokio::test]
    async fn test_evict_oldest() {
        let memory = create_test_memory();
        fill_range(&memory, 0, 5).await;

        let evicted = memory.evict_oldest(2).await.unwrap();
        assert_eq!(evicted.len(), 2);
        assert_eq!(evicted[0].content, "msg0");
        assert_eq!(evicted[1].content, "msg1");

        let remaining = memory.get_messages(None).await.unwrap();
        assert_eq!(remaining.len(), 3);
        assert_eq!(remaining[0].content, "msg2");
    }

    #[tokio::test]
    async fn test_evict_more_than_available() {
        let memory = create_test_memory();
        fill_range(&memory, 0, 2).await;

        let evicted = memory.evict_oldest(5).await.unwrap();
        assert_eq!(evicted.len(), 2);
        assert!(memory.is_empty());
    }

    #[test]
    fn test_prefix_at_char_boundary_handles_unicode() {
        let text = "계약서 내용을 확인하고 싶어서";
        let prefix = prefix_at_char_boundary(text, 5);
        assert_eq!(prefix.chars().count(), 5);
        assert!(text.starts_with(prefix));
    }

    #[test]
    fn test_prefix_at_char_boundary_edge_cases() {
        assert_eq!(prefix_at_char_boundary("hello", 0), "");
        assert_eq!(prefix_at_char_boundary("", 3), "");
        assert_eq!(prefix_at_char_boundary("hello", 5), "hello");
        assert_eq!(prefix_at_char_boundary("hello", 10), "hello");
        assert_eq!(prefix_at_char_boundary("héllo", 2), "hé");
    }
}