// ai_agents_memory/summarizer.rs
//! Summarizer trait and implementations for memory compression

use std::sync::Arc;

use async_trait::async_trait;

use ai_agents_core::{ChatMessage, LLMProvider, Result, Role};

9/// Summarizes conversation messages for memory compression.
10///
11/// Built-in implementations: `LLMSummarizer` (uses an LLM to generate summaries)
12/// and `NoopSummarizer` (concatenates messages, for testing).
13/// Most users use `LLMSummarizer`, auto-configured from the YAML `summarizer_llm` field.
14#[async_trait]
15pub trait Summarizer: Send + Sync {
16    /// Produce a summary from a batch of messages.
17    async fn summarize(&self, messages: &[ChatMessage]) -> Result<String>;
18
19    /// Maximum messages per summarization call. Returns 20 by default.
20    fn max_batch_size(&self) -> usize {
21        20
22    }
23
24    /// Combine multiple summaries into one. Joins with `\n\n` by default.
25    async fn merge_summaries(&self, summaries: &[String]) -> Result<String> {
26        Ok(summaries.join("\n\n"))
27    }
28}
29
/// LLM-backed [`Summarizer`]: prompts a model to condense conversations.
pub struct LLMSummarizer {
    // Provider used for both summarize and merge calls.
    llm: Arc<dyn LLMProvider>,
    // Summary prompt; must contain the `{conversation}` placeholder.
    prompt_template: String,
    // Merge prompt; must contain the `{summaries}` placeholder.
    merge_prompt_template: String,
    // Cap on messages per summarization call; kept >= 1 by the builder.
    max_batch_size: usize,
}

37impl LLMSummarizer {
38    pub fn new(llm: Arc<dyn LLMProvider>) -> Self {
39        Self {
40            llm,
41            prompt_template: DEFAULT_SUMMARY_PROMPT.to_string(),
42            merge_prompt_template: DEFAULT_MERGE_PROMPT.to_string(),
43            max_batch_size: 20,
44        }
45    }
46
47    pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
48        self.prompt_template = prompt.into();
49        self
50    }
51
52    pub fn with_merge_prompt(mut self, prompt: impl Into<String>) -> Self {
53        self.merge_prompt_template = prompt.into();
54        self
55    }
56
57    pub fn with_batch_size(mut self, size: usize) -> Self {
58        self.max_batch_size = size.max(1);
59        self
60    }
61
62    fn format_messages(&self, messages: &[ChatMessage]) -> String {
63        messages
64            .iter()
65            .map(|m| format!("{}: {}", format_role(&m.role), m.content))
66            .collect::<Vec<_>>()
67            .join("\n")
68    }
69}
70
71fn format_role(role: &Role) -> &'static str {
72    match role {
73        Role::System => "System",
74        Role::User => "User",
75        Role::Assistant => "Assistant",
76        Role::Tool => "Tool",
77        Role::Function => "Function",
78    }
79}
80
81#[async_trait]
82impl Summarizer for LLMSummarizer {
83    async fn summarize(&self, messages: &[ChatMessage]) -> Result<String> {
84        if messages.is_empty() {
85            return Ok(String::new());
86        }
87
88        let conversation = self.format_messages(messages);
89        let prompt = self
90            .prompt_template
91            .replace("{conversation}", &conversation);
92
93        let llm_messages = vec![ChatMessage::user(&prompt)];
94
95        let response = self.llm.complete(&llm_messages, None).await?;
96        Ok(response.content.trim().to_string())
97    }
98
99    fn max_batch_size(&self) -> usize {
100        self.max_batch_size
101    }
102
103    async fn merge_summaries(&self, summaries: &[String]) -> Result<String> {
104        if summaries.is_empty() {
105            return Ok(String::new());
106        }
107
108        if summaries.len() == 1 {
109            return Ok(summaries[0].clone());
110        }
111
112        let combined = summaries.join("\n---\n");
113        let prompt = self.merge_prompt_template.replace("{summaries}", &combined);
114
115        let llm_messages = vec![ChatMessage::user(&prompt)];
116
117        let response = self.llm.complete(&llm_messages, None).await?;
118        Ok(response.content.trim().to_string())
119    }
120}
121
/// Default prompt for [`LLMSummarizer::summarize`]; `{conversation}` is
/// replaced with the formatted message transcript.
pub const DEFAULT_SUMMARY_PROMPT: &str = r#"Summarize the following conversation concisely, preserving key information, decisions, and context that would be important for continuing the conversation:

{conversation}

Summary:"#;

/// Default prompt for [`LLMSummarizer::merge_summaries`]; `{summaries}` is
/// replaced with the `---`-separated list of partial summaries.
pub const DEFAULT_MERGE_PROMPT: &str = r#"Merge the following conversation summaries into a single coherent summary, preserving all important information:

{summaries}

Merged Summary:"#;

134pub struct NoopSummarizer;
135
136#[async_trait]
137impl Summarizer for NoopSummarizer {
138    async fn summarize(&self, messages: &[ChatMessage]) -> Result<String> {
139        Ok(messages
140            .iter()
141            .map(|m| m.content.clone())
142            .collect::<Vec<_>>()
143            .join(" | "))
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use ai_agents_core::{FinishReason, LLMChunk, LLMConfig, LLMError, LLMFeature, LLMResponse};
    use parking_lot::Mutex;

    /// Scripted provider: pops canned replies off a stack, falling back to a
    /// fixed string when the script runs dry.
    struct MockLLMProvider {
        responses: Mutex<Vec<String>>,
    }

    impl MockLLMProvider {
        fn new(responses: Vec<String>) -> Self {
            let responses = Mutex::new(responses);
            Self { responses }
        }
    }

    #[async_trait]
    impl LLMProvider for MockLLMProvider {
        async fn complete(
            &self,
            _messages: &[ChatMessage],
            _config: Option<&LLMConfig>,
        ) -> std::result::Result<LLMResponse, LLMError> {
            let reply = match self.responses.lock().pop() {
                Some(scripted) => scripted,
                None => "Summary of conversation".to_string(),
            };
            Ok(LLMResponse::new(reply, FinishReason::Stop))
        }

        async fn complete_stream(
            &self,
            _messages: &[ChatMessage],
            _config: Option<&LLMConfig>,
        ) -> std::result::Result<
            Box<dyn futures::Stream<Item = std::result::Result<LLMChunk, LLMError>> + Unpin + Send>,
            LLMError,
        > {
            Err(LLMError::Other(
                "Streaming not supported in mock".to_string(),
            ))
        }

        fn provider_name(&self) -> &str {
            "mock"
        }

        fn supports(&self, _feature: LLMFeature) -> bool {
            true
        }
    }

    /// Bare message with no name or timestamp.
    fn make_message(role: Role, content: &str) -> ChatMessage {
        ChatMessage {
            role,
            content: content.to_owned(),
            name: None,
            timestamp: None,
        }
    }

    #[tokio::test]
    async fn test_llm_summarizer_basic() {
        let mock = Arc::new(MockLLMProvider::new(vec!["Test summary".to_string()]));
        let summarizer = LLMSummarizer::new(mock);

        let conversation = vec![
            make_message(Role::User, "Hello"),
            make_message(Role::Assistant, "Hi there!"),
        ];

        let result = summarizer.summarize(&conversation).await.unwrap();
        assert_eq!(result, "Test summary");
    }

    #[tokio::test]
    async fn test_llm_summarizer_empty_messages() {
        // Empty input must short-circuit without ever calling the provider.
        let mock = Arc::new(MockLLMProvider::new(vec![]));
        let summarizer = LLMSummarizer::new(mock);

        let result = summarizer.summarize(&[]).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_llm_summarizer_custom_prompt() {
        let mock = Arc::new(MockLLMProvider::new(vec!["Custom summary".to_string()]));
        let summarizer = LLMSummarizer::new(mock).with_prompt("Custom prompt: {conversation}");

        let conversation = vec![make_message(Role::User, "Test")];
        let result = summarizer.summarize(&conversation).await.unwrap();
        assert_eq!(result, "Custom summary");
    }

    #[tokio::test]
    async fn test_merge_summaries() {
        let mock = Arc::new(MockLLMProvider::new(vec!["Merged summary".to_string()]));
        let summarizer = LLMSummarizer::new(mock);

        let parts = vec!["Summary 1".to_string(), "Summary 2".to_string()];
        let merged = summarizer.merge_summaries(&parts).await.unwrap();
        assert_eq!(merged, "Merged summary");
    }

    #[tokio::test]
    async fn test_merge_single_summary() {
        // A single summary is returned verbatim without a provider call.
        let mock = Arc::new(MockLLMProvider::new(vec![]));
        let summarizer = LLMSummarizer::new(mock);

        let parts = vec!["Only summary".to_string()];
        let merged = summarizer.merge_summaries(&parts).await.unwrap();
        assert_eq!(merged, "Only summary");
    }

    #[tokio::test]
    async fn test_noop_summarizer() {
        let summarizer = NoopSummarizer;

        let conversation = vec![
            make_message(Role::User, "Hello"),
            make_message(Role::Assistant, "Hi"),
        ];

        let result = summarizer.summarize(&conversation).await.unwrap();
        assert!(result.contains("Hello"));
        assert!(result.contains("Hi"));
    }

    #[test]
    fn test_max_batch_size() {
        let mock = Arc::new(MockLLMProvider::new(vec![]));
        let summarizer = LLMSummarizer::new(mock).with_batch_size(10);
        assert_eq!(summarizer.max_batch_size(), 10);
    }

    #[test]
    fn test_format_role() {
        assert_eq!(format_role(&Role::User), "User");
        assert_eq!(format_role(&Role::Assistant), "Assistant");
        assert_eq!(format_role(&Role::System), "System");
        assert_eq!(format_role(&Role::Tool), "Tool");
        assert_eq!(format_role(&Role::Function), "Function");
    }
}