// ai_agents_memory/summarizer.rs

use std::sync::Arc;

use async_trait::async_trait;

use ai_agents_core::{ChatMessage, LLMProvider, Result, Role};
/// Strategy for condensing chat history into compact text summaries.
#[async_trait]
pub trait Summarizer: Send + Sync {
    /// Condense `messages` into a single summary string.
    async fn summarize(&self, messages: &[ChatMessage]) -> Result<String>;

    /// Maximum number of messages that should be handed to one
    /// `summarize` call. Defaults to 20.
    fn max_batch_size(&self) -> usize {
        20
    }

    /// Combine several partial summaries into one. The default simply
    /// joins them with a blank line; implementations may do better.
    async fn merge_summaries(&self, summaries: &[String]) -> Result<String> {
        Ok(summaries.join("\n\n"))
    }
}

/// [`Summarizer`] implementation that delegates summarization and merging
/// to an [`LLMProvider`].
pub struct LLMSummarizer {
    // Backing LLM used for both summarize and merge calls.
    llm: Arc<dyn LLMProvider>,
    // Prompt for summarize; expected to contain the `{conversation}` placeholder.
    prompt_template: String,
    // Prompt for merge_summaries; expected to contain the `{summaries}` placeholder.
    merge_prompt_template: String,
    // Value reported by `max_batch_size`; kept >= 1 by `with_batch_size`.
    max_batch_size: usize,
}

37impl LLMSummarizer {
38 pub fn new(llm: Arc<dyn LLMProvider>) -> Self {
39 Self {
40 llm,
41 prompt_template: DEFAULT_SUMMARY_PROMPT.to_string(),
42 merge_prompt_template: DEFAULT_MERGE_PROMPT.to_string(),
43 max_batch_size: 20,
44 }
45 }
46
47 pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
48 self.prompt_template = prompt.into();
49 self
50 }
51
52 pub fn with_merge_prompt(mut self, prompt: impl Into<String>) -> Self {
53 self.merge_prompt_template = prompt.into();
54 self
55 }
56
57 pub fn with_batch_size(mut self, size: usize) -> Self {
58 self.max_batch_size = size.max(1);
59 self
60 }
61
62 fn format_messages(&self, messages: &[ChatMessage]) -> String {
63 messages
64 .iter()
65 .map(|m| format!("{}: {}", format_role(&m.role), m.content))
66 .collect::<Vec<_>>()
67 .join("\n")
68 }
69}
70
/// Human-readable label for a [`Role`], used when rendering transcripts.
fn format_role(role: &Role) -> &'static str {
    match role {
        Role::System => "System",
        Role::User => "User",
        Role::Assistant => "Assistant",
        Role::Tool => "Tool",
        Role::Function => "Function",
    }
}

81#[async_trait]
82impl Summarizer for LLMSummarizer {
83 async fn summarize(&self, messages: &[ChatMessage]) -> Result<String> {
84 if messages.is_empty() {
85 return Ok(String::new());
86 }
87
88 let conversation = self.format_messages(messages);
89 let prompt = self
90 .prompt_template
91 .replace("{conversation}", &conversation);
92
93 let llm_messages = vec![ChatMessage::user(&prompt)];
94
95 let response = self.llm.complete(&llm_messages, None).await?;
96 Ok(response.content.trim().to_string())
97 }
98
99 fn max_batch_size(&self) -> usize {
100 self.max_batch_size
101 }
102
103 async fn merge_summaries(&self, summaries: &[String]) -> Result<String> {
104 if summaries.is_empty() {
105 return Ok(String::new());
106 }
107
108 if summaries.len() == 1 {
109 return Ok(summaries[0].clone());
110 }
111
112 let combined = summaries.join("\n---\n");
113 let prompt = self.merge_prompt_template.replace("{summaries}", &combined);
114
115 let llm_messages = vec![ChatMessage::user(&prompt)];
116
117 let response = self.llm.complete(&llm_messages, None).await?;
118 Ok(response.content.trim().to_string())
119 }
120}
121
/// Default prompt for [`Summarizer::summarize`]; `{conversation}` is
/// replaced with the formatted transcript.
pub const DEFAULT_SUMMARY_PROMPT: &str = r#"Summarize the following conversation concisely, preserving key information, decisions, and context that would be important for continuing the conversation:

{conversation}

Summary:"#;

/// Default prompt for [`Summarizer::merge_summaries`]; `{summaries}` is
/// replaced with the "---"-separated partial summaries.
pub const DEFAULT_MERGE_PROMPT: &str = r#"Merge the following conversation summaries into a single coherent summary, preserving all important information:

{summaries}

Merged Summary:"#;

/// Trivial [`Summarizer`] that performs no LLM call: it just concatenates
/// message contents. Useful in tests or as a zero-cost placeholder.
pub struct NoopSummarizer;

136#[async_trait]
137impl Summarizer for NoopSummarizer {
138 async fn summarize(&self, messages: &[ChatMessage]) -> Result<String> {
139 Ok(messages
140 .iter()
141 .map(|m| m.content.clone())
142 .collect::<Vec<_>>()
143 .join(" | "))
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use ai_agents_core::{FinishReason, LLMChunk, LLMConfig, LLMError, LLMFeature, LLMResponse};
    use parking_lot::Mutex;

    /// LLM stub that replays canned responses.
    ///
    /// Responses are consumed front-to-back (FIFO). The previous version
    /// popped from the Vec tail, so a test queueing several responses would
    /// have received them in reverse order of queueing.
    struct MockLLMProvider {
        responses: Mutex<Vec<String>>,
    }

    impl MockLLMProvider {
        fn new(responses: Vec<String>) -> Self {
            Self {
                responses: Mutex::new(responses),
            }
        }
    }

    #[async_trait]
    impl LLMProvider for MockLLMProvider {
        async fn complete(
            &self,
            _messages: &[ChatMessage],
            _config: Option<&LLMConfig>,
        ) -> std::result::Result<LLMResponse, LLMError> {
            let mut queue = self.responses.lock();
            // FIFO: take the oldest queued response; fall back to a fixed
            // string once the queue is exhausted.
            let response = if queue.is_empty() {
                "Summary of conversation".to_string()
            } else {
                queue.remove(0)
            };
            Ok(LLMResponse::new(response, FinishReason::Stop))
        }

        async fn complete_stream(
            &self,
            _messages: &[ChatMessage],
            _config: Option<&LLMConfig>,
        ) -> std::result::Result<
            Box<dyn futures::Stream<Item = std::result::Result<LLMChunk, LLMError>> + Unpin + Send>,
            LLMError,
        > {
            Err(LLMError::Other(
                "Streaming not supported in mock".to_string(),
            ))
        }

        fn provider_name(&self) -> &str {
            "mock"
        }

        fn supports(&self, _feature: LLMFeature) -> bool {
            true
        }
    }

    /// Shorthand for building a `ChatMessage` without optional metadata.
    fn make_message(role: Role, content: &str) -> ChatMessage {
        ChatMessage {
            role,
            content: content.to_string(),
            name: None,
            timestamp: None,
        }
    }

    #[tokio::test]
    async fn test_llm_summarizer_basic() {
        let provider = Arc::new(MockLLMProvider::new(vec!["Test summary".to_string()]));
        let summarizer = LLMSummarizer::new(provider);

        let messages = vec![
            make_message(Role::User, "Hello"),
            make_message(Role::Assistant, "Hi there!"),
        ];

        let summary = summarizer.summarize(&messages).await.unwrap();
        assert_eq!(summary, "Test summary");
    }

    #[tokio::test]
    async fn test_llm_summarizer_empty_messages() {
        let provider = Arc::new(MockLLMProvider::new(vec![]));
        let summarizer = LLMSummarizer::new(provider);

        let summary = summarizer.summarize(&[]).await.unwrap();
        assert!(summary.is_empty());
    }

    #[tokio::test]
    async fn test_llm_summarizer_custom_prompt() {
        let provider = Arc::new(MockLLMProvider::new(vec!["Custom summary".to_string()]));
        let summarizer = LLMSummarizer::new(provider).with_prompt("Custom prompt: {conversation}");

        let messages = vec![make_message(Role::User, "Test")];
        let summary = summarizer.summarize(&messages).await.unwrap();
        assert_eq!(summary, "Custom summary");
    }

    #[tokio::test]
    async fn test_merge_summaries() {
        let provider = Arc::new(MockLLMProvider::new(vec!["Merged summary".to_string()]));
        let summarizer = LLMSummarizer::new(provider);

        let summaries = vec!["Summary 1".to_string(), "Summary 2".to_string()];
        let merged = summarizer.merge_summaries(&summaries).await.unwrap();
        assert_eq!(merged, "Merged summary");
    }

    #[tokio::test]
    async fn test_merge_single_summary() {
        let provider = Arc::new(MockLLMProvider::new(vec![]));
        let summarizer = LLMSummarizer::new(provider);

        let summaries = vec!["Only summary".to_string()];
        let merged = summarizer.merge_summaries(&summaries).await.unwrap();
        assert_eq!(merged, "Only summary");
    }

    #[tokio::test]
    async fn test_noop_summarizer() {
        let summarizer = NoopSummarizer;

        let messages = vec![
            make_message(Role::User, "Hello"),
            make_message(Role::Assistant, "Hi"),
        ];

        let summary = summarizer.summarize(&messages).await.unwrap();
        assert!(summary.contains("Hello"));
        assert!(summary.contains("Hi"));
    }

    #[test]
    fn test_max_batch_size() {
        let provider = Arc::new(MockLLMProvider::new(vec![]));
        let summarizer = LLMSummarizer::new(provider).with_batch_size(10);
        assert_eq!(summarizer.max_batch_size(), 10);
    }

    #[test]
    fn test_format_role() {
        assert_eq!(format_role(&Role::User), "User");
        assert_eq!(format_role(&Role::Assistant), "Assistant");
        assert_eq!(format_role(&Role::System), "System");
        assert_eq!(format_role(&Role::Tool), "Tool");
        assert_eq!(format_role(&Role::Function), "Function");
    }
}