bamboo_compression/
summarizer.rs1use async_trait::async_trait;
7use bamboo_domain::{Message, Role};
8use std::collections::HashSet;
9
10#[async_trait]
12pub trait Summarizer: Send + Sync {
13 async fn summarize(&self, messages: &[Message]) -> Result<String, crate::types::BudgetError>;
17
18 fn estimate_summary_tokens(&self, message_count: usize) -> u32 {
22 (message_count * 50).min(1000) as u32
24 }
25}
26
/// Stateless summarizer that builds summaries by extracting and truncating
/// message text.
///
/// The type carries no data, so it is free to copy and all values compare equal.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct HeuristicSummarizer;
37
38impl HeuristicSummarizer {
39 pub fn new() -> Self {
41 Self
42 }
43
44 fn extract_user_questions<'a>(&self, messages: &'a [Message]) -> Vec<&'a str> {
46 messages
47 .iter()
48 .filter(|m| m.role == Role::User)
49 .filter(|m| !m.content.is_empty())
50 .take(10) .map(|m| m.content.as_str())
52 .collect()
53 }
54
55 fn extract_tools_used(&self, messages: &[Message]) -> Vec<String> {
57 let mut tools = HashSet::new();
58
59 for message in messages {
60 if let Some(ref tool_calls) = message.tool_calls {
61 for call in tool_calls {
62 tools.insert(call.function.name.clone());
63 }
64 }
65 }
66
67 let mut result: Vec<String> = tools.into_iter().collect();
68 result.sort();
69 result
70 }
71
72 fn extract_key_responses<'a>(&self, messages: &'a [Message]) -> Vec<&'a str> {
74 messages
75 .iter()
76 .filter(|m| m.role == Role::Assistant)
77 .filter(|m| !m.content.is_empty())
78 .rev() .take(3)
80 .map(|m| m.content.as_str())
81 .collect()
82 }
83
84 fn safe_truncate(&self, s: &str, max_chars: usize) -> String {
87 if s.chars().count() <= max_chars {
88 return s.to_string();
89 }
90
91 let truncated: String = s.chars().take(max_chars).collect();
93 format!("{}...", truncated)
94 }
95}
96
97#[async_trait]
98impl Summarizer for HeuristicSummarizer {
99 async fn summarize(&self, messages: &[Message]) -> Result<String, crate::types::BudgetError> {
100 if messages.is_empty() {
101 return Ok("No conversation history.".to_string());
102 }
103
104 let questions = self.extract_user_questions(messages);
105 let tools = self.extract_tools_used(messages);
106 let responses = self.extract_key_responses(messages);
107
108 let mut summary_parts = Vec::new();
109
110 if !questions.is_empty() {
112 summary_parts.push("## User Requests".to_string());
113 for (i, q) in questions.iter().enumerate() {
114 let truncated = self.safe_truncate(q, 200);
116 summary_parts.push(format!("{}. {}", i + 1, truncated));
117 }
118 }
119
120 if !tools.is_empty() {
122 summary_parts.push("\n## Tools Used".to_string());
123 for tool in tools {
124 summary_parts.push(format!("- {}", tool));
125 }
126 }
127
128 if !responses.is_empty() {
130 summary_parts.push("\n## Key Outcomes".to_string());
131 for (i, r) in responses.iter().enumerate() {
132 let truncated = self.safe_truncate(r, 300);
134 summary_parts.push(format!("{}. {}", i + 1, truncated));
135 }
136 }
137
138 if summary_parts.is_empty() {
139 Ok("Previous conversation context available.".to_string())
140 } else {
141 Ok(summary_parts.join("\n"))
142 }
143 }
144}
145
/// Condition under which a summary should be generated.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SummaryTrigger {
    /// Summarize whenever the caller reports that truncation occurred.
    OnTruncation,
    /// Summarize once the number of messages reaches `interval`.
    Periodic { interval: usize },
    /// Summarize once the current token count reaches `threshold`.
    TokenThreshold { threshold: u32 },
}
156
157pub struct SummaryManager {
159 summarizer: Box<dyn Summarizer>,
160 trigger: SummaryTrigger,
161}
162
163impl std::fmt::Debug for SummaryManager {
164 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165 f.debug_struct("SummaryManager")
166 .field("trigger", &self.trigger)
167 .finish_non_exhaustive()
168 }
169}
170
171impl SummaryManager {
172 pub fn new(summarizer: impl Summarizer + 'static, trigger: SummaryTrigger) -> Self {
174 Self {
175 summarizer: Box::new(summarizer),
176 trigger,
177 }
178 }
179
180 pub fn should_summarize(
182 &self,
183 messages: &[Message],
184 _truncation_occurred: bool,
185 current_token_count: u32,
186 ) -> bool {
187 match &self.trigger {
188 SummaryTrigger::OnTruncation => _truncation_occurred,
189 SummaryTrigger::Periodic { interval } => messages.len() >= *interval,
190 SummaryTrigger::TokenThreshold { threshold } => current_token_count >= *threshold,
191 }
192 }
193
194 pub async fn summarize(
196 &self,
197 messages: &[Message],
198 ) -> Result<String, crate::types::BudgetError> {
199 self.summarizer.summarize(messages).await
200 }
201
202 pub fn estimate_summary_tokens(&self, message_count: usize) -> u32 {
204 self.summarizer.estimate_summary_tokens(message_count)
205 }
206}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager around the heuristic summarizer so the trigger tests
    /// exercise the real `should_summarize` decision.
    fn manager_with(trigger: SummaryTrigger) -> SummaryManager {
        SummaryManager::new(HeuristicSummarizer::new(), trigger)
    }

    #[test]
    fn heuristic_summarizer_extracts_user_questions() {
        let summarizer = HeuristicSummarizer::new();
        let messages = vec![
            Message::user("What is the weather?"),
            Message::assistant("It's sunny.", None),
            Message::user("What about tomorrow?"),
        ];

        let questions = summarizer.extract_user_questions(&messages);
        assert_eq!(questions.len(), 2);
        assert!(questions[0].contains("weather"));
        assert!(questions[1].contains("tomorrow"));
    }

    #[test]
    fn heuristic_summarizer_extracts_tools_used() {
        use bamboo_domain::{FunctionCall, ToolCall};

        let summarizer = HeuristicSummarizer::new();
        let tool_call = ToolCall {
            id: "call_1".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let messages = vec![
            Message::user("Search for something"),
            Message::assistant("I'll search", Some(vec![tool_call])),
        ];

        let tools = summarizer.extract_tools_used(&messages);
        assert_eq!(tools, vec!["search"]);
    }

    #[test]
    fn heuristic_summarizer_extracts_key_responses() {
        let summarizer = HeuristicSummarizer::new();
        let messages = vec![
            Message::user("Hello"),
            Message::assistant("First response", None),
            Message::user("How are you?"),
            Message::assistant("Most recent response", None),
        ];

        // Responses come back newest first.
        let responses = summarizer.extract_key_responses(&messages);
        assert_eq!(responses, vec!["Most recent response", "First response"]);
    }

    #[tokio::test]
    async fn heuristic_summarizer_generates_summary() {
        let summarizer = HeuristicSummarizer::new();
        let messages = vec![
            Message::user("What is Rust?"),
            Message::assistant("Rust is a systems programming language.", None),
        ];

        let summary = summarizer.summarize(&messages).await.unwrap();
        assert!(summary.contains("User Requests"));
        assert!(summary.contains("What is Rust?"));
        assert!(summary.contains("Key Outcomes"));
    }

    #[test]
    fn summary_trigger_on_truncation() {
        let manager = manager_with(SummaryTrigger::OnTruncation);

        assert!(manager.should_summarize(&[], true, 0));
        // Neither message count nor token count matters for this trigger.
        assert!(!manager.should_summarize(&[], false, u32::MAX));
    }

    #[test]
    fn summary_trigger_periodic() {
        let manager = manager_with(SummaryTrigger::Periodic { interval: 5 });
        let messages: Vec<Message> = (0..5).map(|_| Message::user("Test")).collect();

        assert!(manager.should_summarize(&messages, false, 0));
        // One message short of the interval must not fire.
        assert!(!manager.should_summarize(&messages[..4], false, 0));
    }

    #[test]
    fn summary_trigger_token_threshold() {
        let manager = manager_with(SummaryTrigger::TokenThreshold { threshold: 1000 });

        assert!(manager.should_summarize(&[], false, 1000));
        assert!(!manager.should_summarize(&[], false, 999));
    }

    #[test]
    fn safe_truncate_handles_ascii() {
        let summarizer = HeuristicSummarizer::new();
        let text = "Hello world this is a test";
        let truncated = summarizer.safe_truncate(text, 10);

        assert_eq!(truncated, "Hello worl...");
    }

    #[test]
    fn safe_truncate_handles_unicode() {
        let summarizer = HeuristicSummarizer::new();

        let text = "Hello 😀🎉🚀 World with emoji";
        let truncated = summarizer.safe_truncate(text, 10);

        assert!(truncated.starts_with("Hello "));
        assert!(truncated.ends_with("..."));
        // Ten kept characters plus the three-character ellipsis.
        assert_eq!(truncated.chars().count(), 13);
    }

    #[test]
    fn safe_truncate_handles_cjk() {
        let summarizer = HeuristicSummarizer::new();

        let text = "这是一个中文测试消息用于验证截断";
        let truncated = summarizer.safe_truncate(text, 10);

        assert_eq!(truncated, "这是一个中文测试消息...");
    }

    #[test]
    fn safe_truncate_handles_mixed_unicode() {
        let summarizer = HeuristicSummarizer::new();

        let text = "Hello 世界 🌍 test message";
        let truncated = summarizer.safe_truncate(text, 8);

        assert_eq!(truncated, "Hello 世界...");
    }

    #[tokio::test]
    async fn summarizer_handles_unicode_messages() {
        let summarizer = HeuristicSummarizer::new();

        let long_unicode =
            "这是一段很长的中文消息需要被截断以测试我们的安全截断功能 😀🎉🚀".repeat(10);
        let messages = vec![
            Message::user(&long_unicode),
            Message::assistant("Response", None),
        ];

        let summary = summarizer.summarize(&messages).await.unwrap();
        assert!(summary.contains("User Requests"));
        // The over-long request must have been cut, not copied verbatim.
        assert!(summary.contains("..."));
        assert!(!summary.contains(long_unicode.as_str()));
    }

    #[test]
    fn safe_truncate_returns_short_text_unchanged() {
        let summarizer = HeuristicSummarizer::new();
        let text = "Short";
        let truncated = summarizer.safe_truncate(text, 100);

        assert_eq!(truncated, text);
    }
}