1use std::sync::Arc;
4
5use async_trait::async_trait;
6use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8
9use ai_agents_core::{ChatMessage, MemorySnapshot, Result};
10
11use super::Memory;
12use super::context::{CompressResult, ConversationContext, estimate_tokens};
13use super::summarizer::Summarizer;
14
/// Returns the longest prefix of `text` that contains at most `max_chars`
/// characters (Unicode scalar values), never splitting a multi-byte character.
///
/// When `text` has `max_chars` characters or fewer, the whole string is
/// returned; a `max_chars` of zero yields the empty string.
fn prefix_at_char_boundary(text: &str, max_chars: usize) -> &str {
    // Byte offset at which character number `max_chars` starts. If the text has
    // no such character it is already short enough, so the cut is its full length.
    let end = text
        .char_indices()
        .nth(max_chars)
        .map_or(text.len(), |(offset, _)| offset);
    &text[..end]
}
25
/// Conversation memory that keeps un-summarized messages verbatim and folds the
/// oldest ones into a running text summary when compression is requested.
///
/// Every mutable piece of state sits behind its own `parking_lot::RwLock`, so
/// all operations work through `&self`.
pub struct CompactingMemory {
    /// Running summary of the messages that have been compressed away;
    /// `None` until a summary has been stored.
    summary: RwLock<Option<String>>,
    /// Messages that have not been summarized yet, oldest first.
    messages: RwLock<Vec<ChatMessage>>,
    /// Number of messages that have been folded into `summary`.
    summarized_count: RwLock<usize>,
    /// Thresholds and limits that drive compression.
    config: CompactingMemoryConfig,
    /// Summarizer used when a compression call does not supply its own.
    summarizer: Arc<dyn Summarizer>,
    /// One entry per completed compression, in the order they happened.
    compression_history: RwLock<Vec<CompressionEvent>>,
}
34
/// Tunable limits for [`CompactingMemory`].
///
/// Every field has a serde default, so a partially specified (or empty)
/// configuration deserializes successfully.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactingMemoryConfig {
    /// Defaults to 50.
    // NOTE(review): this value is not read anywhere in the visible part of this
    // file — confirm where (or whether) it is enforced.
    #[serde(default = "default_max_recent_messages")]
    pub max_recent_messages: usize,

    /// Number of stored messages at which compression kicks in (inclusive).
    /// Defaults to 30.
    #[serde(default = "default_compress_threshold")]
    pub compress_threshold: usize,

    /// Maximum number of oldest messages summarized by one compression.
    /// Defaults to 10.
    #[serde(default = "default_summarize_batch_size")]
    pub summarize_batch_size: usize,

    /// Maximum summary length in characters (not bytes); longer summaries are
    /// truncated to this many characters. Defaults to 2000.
    #[serde(default = "default_max_summary_length")]
    pub max_summary_length: usize,
}
52
/// Record of a single compression performed by [`CompactingMemory`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionEvent {
    /// UTC time at which the event was recorded.
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Number of messages that were summarized in this compression.
    pub messages_compressed: usize,
    /// Length in bytes of the stored summary before the compression
    /// (0 when there was none).
    pub summary_length_before: usize,
    /// Length in bytes of the stored summary after the compression.
    pub summary_length_after: usize,
}
60
/// Serde fallback for `CompactingMemoryConfig::max_recent_messages`.
fn default_max_recent_messages() -> usize {
    const DEFAULT_MAX_RECENT_MESSAGES: usize = 50;
    DEFAULT_MAX_RECENT_MESSAGES
}
64
/// Serde fallback for `CompactingMemoryConfig::compress_threshold`.
fn default_compress_threshold() -> usize {
    const DEFAULT_COMPRESS_THRESHOLD: usize = 30;
    DEFAULT_COMPRESS_THRESHOLD
}
68
/// Serde fallback for `CompactingMemoryConfig::summarize_batch_size`.
fn default_summarize_batch_size() -> usize {
    const DEFAULT_SUMMARIZE_BATCH_SIZE: usize = 10;
    DEFAULT_SUMMARIZE_BATCH_SIZE
}
72
/// Serde fallback for `CompactingMemoryConfig::max_summary_length`.
fn default_max_summary_length() -> usize {
    const DEFAULT_MAX_SUMMARY_LENGTH: usize = 2000;
    DEFAULT_MAX_SUMMARY_LENGTH
}
76
77impl Default for CompactingMemoryConfig {
78 fn default() -> Self {
79 Self {
80 max_recent_messages: default_max_recent_messages(),
81 compress_threshold: default_compress_threshold(),
82 summarize_batch_size: default_summarize_batch_size(),
83 max_summary_length: default_max_summary_length(),
84 }
85 }
86}
87
88impl CompactingMemory {
89 pub fn new(summarizer: Arc<dyn Summarizer>, config: CompactingMemoryConfig) -> Self {
90 Self {
91 summary: RwLock::new(None),
92 messages: RwLock::new(Vec::new()),
93 summarized_count: RwLock::new(0),
94 config,
95 summarizer,
96 compression_history: RwLock::new(Vec::new()),
97 }
98 }
99
100 pub fn with_default_config(summarizer: Arc<dyn Summarizer>) -> Self {
101 Self::new(summarizer, CompactingMemoryConfig::default())
102 }
103
104 pub fn config(&self) -> &CompactingMemoryConfig {
105 &self.config
106 }
107
108 pub fn summary(&self) -> Option<String> {
109 self.summary.read().clone()
110 }
111
112 pub fn summarized_count(&self) -> usize {
113 *self.summarized_count.read()
114 }
115
116 pub fn compression_history(&self) -> Vec<CompressionEvent> {
117 self.compression_history.read().clone()
118 }
119
120 fn record_compression(&self, messages_compressed: usize, before: usize, after: usize) {
121 let event = CompressionEvent {
122 timestamp: chrono::Utc::now(),
123 messages_compressed,
124 summary_length_before: before,
125 summary_length_after: after,
126 };
127 self.compression_history.write().push(event);
128 }
129}
130
131#[async_trait]
132impl ai_agents_core::Memory for CompactingMemory {
133 async fn add_message(&self, message: ChatMessage) -> Result<()> {
134 self.messages.write().push(message);
135 Ok(())
136 }
137
138 async fn get_messages(&self, limit: Option<usize>) -> Result<Vec<ChatMessage>> {
139 let messages = self.messages.read();
140 match limit {
141 Some(n) => {
142 let start = messages.len().saturating_sub(n);
143 Ok(messages[start..].to_vec())
144 }
145 None => Ok(messages.clone()),
146 }
147 }
148
149 async fn clear(&self) -> Result<()> {
150 *self.summary.write() = None;
151 self.messages.write().clear();
152 *self.summarized_count.write() = 0;
153 self.compression_history.write().clear();
154 Ok(())
155 }
156
157 fn len(&self) -> usize {
158 self.messages.read().len()
159 }
160
161 async fn snapshot(&self) -> Result<MemorySnapshot> {
162 let messages = self.messages.read().clone();
163 let summary = self.summary.read().clone();
164
165 let mut snapshot = MemorySnapshot::new(messages);
166 if let Some(s) = summary {
167 snapshot = snapshot.with_summary(s);
168 }
169 Ok(snapshot)
170 }
171
172 async fn restore(&self, snapshot: MemorySnapshot) -> Result<()> {
173 *self.messages.write() = snapshot.messages;
174 *self.summary.write() = snapshot.summary;
175 *self.summarized_count.write() = 0;
176 self.compression_history.write().clear();
177 Ok(())
178 }
179
180 async fn evict_oldest(&self, count: usize) -> Result<Vec<ChatMessage>> {
181 let mut messages = self.messages.write();
182 let evict_count = count.min(messages.len());
183 let evicted: Vec<ChatMessage> = messages.drain(..evict_count).collect();
184 Ok(evicted)
185 }
186}
187
188#[async_trait]
189impl Memory for CompactingMemory {
190 async fn get_context(&self) -> Result<ConversationContext> {
191 let messages = self.messages.read().clone();
192 let summary = self.summary.read().clone();
193 let summarized_count = *self.summarized_count.read();
194 let total_messages = messages.len() + summarized_count;
195
196 let mut ctx = ConversationContext::with_messages(messages);
197 ctx.total_messages = total_messages;
198
199 if let Some(s) = summary {
200 ctx = ctx.with_summary(s, summarized_count);
201 }
202
203 Ok(ctx)
204 }
205
206 async fn compress(&self, summarizer: Option<&dyn Summarizer>) -> Result<CompressResult> {
207 let message_count = self.messages.read().len();
208
209 if message_count < self.config.compress_threshold {
210 return Ok(CompressResult::NotNeeded);
211 }
212
213 let summarizer = summarizer.unwrap_or(self.summarizer.as_ref());
214 let batch_size = self.config.summarize_batch_size.min(message_count);
215
216 let messages_to_summarize: Vec<ChatMessage> = {
217 let messages = self.messages.read();
218 messages[..batch_size].to_vec()
219 };
220
221 let new_summary = summarizer.summarize(&messages_to_summarize).await?;
222
223 let summary_before_len = self.summary.read().as_ref().map(|s| s.len()).unwrap_or(0);
224
225 let existing_summary = self.summary.read().clone();
226 let combined_summary = match existing_summary {
227 Some(existing) => summarizer.merge_summaries(&[existing, new_summary]).await?,
228 None => new_summary,
229 };
230
231 let truncated = prefix_at_char_boundary(&combined_summary, self.config.max_summary_length);
232 let final_summary = if truncated.len() < combined_summary.len() {
233 truncated.to_string()
234 } else {
235 combined_summary
236 };
237
238 let summary_after_len = final_summary.len();
239
240 {
241 let mut messages = self.messages.write();
242 messages.drain(..batch_size);
243 }
244
245 *self.summary.write() = Some(final_summary.clone());
246 *self.summarized_count.write() += batch_size;
247
248 self.record_compression(batch_size, summary_before_len, summary_after_len);
249
250 let tokens_before: u32 = messages_to_summarize
251 .iter()
252 .map(|m| estimate_tokens(&m.content))
253 .sum();
254 let tokens_after = estimate_tokens(&final_summary);
255 let tokens_saved = tokens_before.saturating_sub(tokens_after);
256
257 Ok(CompressResult::Compressed {
258 messages_summarized: batch_size,
259 new_summary_length: summary_after_len,
260 tokens_saved,
261 })
262 }
263
264 fn needs_compression(&self) -> bool {
265 self.messages.read().len() >= self.config.compress_threshold
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::summarizer::NoopSummarizer;
    use ai_agents_core::{Memory as CoreMemory, Role};

    /// Builds a user message carrying `content`, with no name or timestamp.
    fn make_message(content: &str) -> ChatMessage {
        ChatMessage {
            role: Role::User,
            content: content.to_owned(),
            name: None,
            timestamp: None,
        }
    }

    /// Memory with small limits (threshold 5, batch 3) backed by `NoopSummarizer`.
    fn create_test_memory() -> CompactingMemory {
        CompactingMemory::new(
            Arc::new(NoopSummarizer),
            CompactingMemoryConfig {
                max_recent_messages: 10,
                compress_threshold: 5,
                summarize_batch_size: 3,
                max_summary_length: 500,
            },
        )
    }

    /// Adds `count` messages whose text is `prefix` followed by 0, 1, 2, ...
    async fn add_numbered(memory: &CompactingMemory, prefix: &str, count: usize) {
        for i in 0..count {
            memory
                .add_message(make_message(&format!("{}{}", prefix, i)))
                .await
                .unwrap();
        }
    }

    #[tokio::test]
    async fn test_basic_add_and_get() {
        let memory = create_test_memory();

        for text in ["Hello", "World"] {
            memory.add_message(make_message(text)).await.unwrap();
        }

        let messages = memory.get_messages(None).await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, "Hello");
        assert_eq!(messages[1].content, "World");
    }

    #[tokio::test]
    async fn test_get_messages_with_limit() {
        let memory = create_test_memory();
        add_numbered(&memory, "msg", 5).await;

        let messages = memory.get_messages(Some(2)).await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, "msg3");
        assert_eq!(messages[1].content, "msg4");
    }

    #[tokio::test]
    async fn test_clear() {
        let memory = create_test_memory();

        memory.add_message(make_message("test")).await.unwrap();
        assert!(!memory.is_empty());

        memory.clear().await.unwrap();
        assert!(memory.is_empty());
        assert!(memory.summary().is_none());
    }

    #[tokio::test]
    async fn test_needs_compression() {
        let memory = create_test_memory();

        // One below the threshold of five: no compression yet.
        add_numbered(&memory, "msg", 4).await;
        assert!(!memory.needs_compression());

        // The fifth message reaches the threshold.
        memory.add_message(make_message("msg4")).await.unwrap();
        assert!(memory.needs_compression());
    }

    #[tokio::test]
    async fn test_compress_not_needed() {
        let memory = create_test_memory();

        for text in ["msg1", "msg2"] {
            memory.add_message(make_message(text)).await.unwrap();
        }

        let result = memory.compress(None).await.unwrap();
        assert!(matches!(result, CompressResult::NotNeeded));
    }

    #[tokio::test]
    async fn test_compress_when_needed() {
        let memory = create_test_memory();
        add_numbered(&memory, "message number ", 6).await;

        assert!(memory.needs_compression());

        match memory.compress(None).await.unwrap() {
            CompressResult::Compressed {
                messages_summarized,
                ..
            } => assert_eq!(messages_summarized, 3),
            _ => panic!("Expected Compressed result"),
        }

        assert_eq!(memory.len(), 3);
        assert!(memory.summary().is_some());
        assert_eq!(memory.summarized_count(), 3);
    }

    #[tokio::test]
    async fn test_get_context() {
        let memory = create_test_memory();
        add_numbered(&memory, "msg", 6).await;

        memory.compress(None).await.unwrap();

        let ctx = memory.get_context().await.unwrap();
        assert!(ctx.summary.is_some());
        assert_eq!(ctx.messages.len(), 3);
        assert_eq!(ctx.summarized_count, 3);
    }

    #[tokio::test]
    async fn test_snapshot_restore() {
        let memory = create_test_memory();

        for text in ["msg1", "msg2"] {
            memory.add_message(make_message(text)).await.unwrap();
        }

        let snapshot = memory.snapshot().await.unwrap();
        assert_eq!(snapshot.messages.len(), 2);

        memory.clear().await.unwrap();
        assert!(memory.is_empty());

        memory.restore(snapshot).await.unwrap();
        let restored = memory.get_messages(None).await.unwrap();
        assert_eq!(restored.len(), 2);
    }

    #[tokio::test]
    async fn test_compression_history() {
        let memory = create_test_memory();
        add_numbered(&memory, "msg", 6).await;

        memory.compress(None).await.unwrap();

        let history = memory.compression_history();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].messages_compressed, 3);
    }

    #[test]
    fn test_config_default() {
        let CompactingMemoryConfig {
            max_recent_messages,
            compress_threshold,
            summarize_batch_size,
            max_summary_length,
        } = CompactingMemoryConfig::default();

        assert_eq!(max_recent_messages, 50);
        assert_eq!(compress_threshold, 30);
        assert_eq!(summarize_batch_size, 10);
        assert_eq!(max_summary_length, 2000);
    }

    #[tokio::test]
    async fn test_evict_oldest() {
        let memory = create_test_memory();
        add_numbered(&memory, "msg", 5).await;

        let evicted = memory.evict_oldest(2).await.unwrap();
        assert_eq!(evicted.len(), 2);
        assert_eq!(evicted[0].content, "msg0");
        assert_eq!(evicted[1].content, "msg1");

        let remaining = memory.get_messages(None).await.unwrap();
        assert_eq!(remaining.len(), 3);
        assert_eq!(remaining[0].content, "msg2");
    }

    #[test]
    fn test_prefix_at_char_boundary_handles_unicode() {
        // Multi-byte characters: slicing by byte count would split one.
        let text = "계약서 내용을 확인하고 싶어서";
        let prefix = prefix_at_char_boundary(text, 5);
        assert_eq!(prefix.chars().count(), 5);
        assert!(text.starts_with(prefix));
    }
}