Skip to main content

bamboo_compression/
summarizer.rs

1//! Conversation summarization for rolling context management.
2//!
3//! When conversations are truncated due to token limits, a summary preserves
4//! key information from earlier context.
5
6use async_trait::async_trait;
7use bamboo_domain::{Message, Role};
8use std::collections::HashSet;
9
10/// Trait for summarization implementations.
11#[async_trait]
12pub trait Summarizer: Send + Sync {
13    /// Generate a summary of the given messages.
14    ///
15    /// Returns a string containing the summary.
16    async fn summarize(&self, messages: &[Message]) -> Result<String, crate::types::BudgetError>;
17
18    /// Get the estimated token count of the summary.
19    ///
20    /// Used to ensure the summary fits within the budget.
21    fn estimate_summary_tokens(&self, message_count: usize) -> u32 {
22        // Rough estimate: each message contributes ~50 tokens to the summary
23        (message_count * 50).min(1000) as u32
24    }
25}
26
/// Heuristic summarizer that extracts key points without using an LLM.
///
/// This is a lightweight summarization approach that lists:
/// 1. The earliest (up to 10) non-empty user questions/requests
/// 2. The distinct names of tools that were called, sorted
/// 3. The most recent (up to 3) non-empty assistant responses
///
/// This provides continuity without expensive LLM calls. The type carries no
/// state, so it is freely copyable and all instances compare equal.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct HeuristicSummarizer;
37
38impl HeuristicSummarizer {
39    /// Create a new heuristic summarizer.
40    pub fn new() -> Self {
41        Self
42    }
43
44    /// Extract user questions from messages.
45    fn extract_user_questions<'a>(&self, messages: &'a [Message]) -> Vec<&'a str> {
46        messages
47            .iter()
48            .filter(|m| m.role == Role::User)
49            .filter(|m| !m.content.is_empty())
50            .take(10) // Limit to prevent huge summaries
51            .map(|m| m.content.as_str())
52            .collect()
53    }
54
55    /// Extract tool calls that were made.
56    fn extract_tools_used(&self, messages: &[Message]) -> Vec<String> {
57        let mut tools = HashSet::new();
58
59        for message in messages {
60            if let Some(ref tool_calls) = message.tool_calls {
61                for call in tool_calls {
62                    tools.insert(call.function.name.clone());
63                }
64            }
65        }
66
67        let mut result: Vec<String> = tools.into_iter().collect();
68        result.sort();
69        result
70    }
71
72    /// Extract key assistant responses.
73    fn extract_key_responses<'a>(&self, messages: &'a [Message]) -> Vec<&'a str> {
74        messages
75            .iter()
76            .filter(|m| m.role == Role::Assistant)
77            .filter(|m| !m.content.is_empty())
78            .rev() // Take most recent first
79            .take(3)
80            .map(|m| m.content.as_str())
81            .collect()
82    }
83
84    /// Safely truncate a string at a character boundary.
85    /// Uses char_indices() to ensure we don't split UTF-8 multi-byte characters.
86    fn safe_truncate(&self, s: &str, max_chars: usize) -> String {
87        if s.chars().count() <= max_chars {
88            return s.to_string();
89        }
90
91        // Take up to max_chars characters safely
92        let truncated: String = s.chars().take(max_chars).collect();
93        format!("{}...", truncated)
94    }
95}
96
97#[async_trait]
98impl Summarizer for HeuristicSummarizer {
99    async fn summarize(&self, messages: &[Message]) -> Result<String, crate::types::BudgetError> {
100        if messages.is_empty() {
101            return Ok("No conversation history.".to_string());
102        }
103
104        let questions = self.extract_user_questions(messages);
105        let tools = self.extract_tools_used(messages);
106        let responses = self.extract_key_responses(messages);
107
108        let mut summary_parts = Vec::new();
109
110        // User requests section
111        if !questions.is_empty() {
112            summary_parts.push("## User Requests".to_string());
113            for (i, q) in questions.iter().enumerate() {
114                // Truncate long questions for the summary (safe UTF-8)
115                let truncated = self.safe_truncate(q, 200);
116                summary_parts.push(format!("{}. {}", i + 1, truncated));
117            }
118        }
119
120        // Tools used section
121        if !tools.is_empty() {
122            summary_parts.push("\n## Tools Used".to_string());
123            for tool in tools {
124                summary_parts.push(format!("- {}", tool));
125            }
126        }
127
128        // Key responses section
129        if !responses.is_empty() {
130            summary_parts.push("\n## Key Outcomes".to_string());
131            for (i, r) in responses.iter().enumerate() {
132                // Truncate long responses (safe UTF-8)
133                let truncated = self.safe_truncate(r, 300);
134                summary_parts.push(format!("{}. {}", i + 1, truncated));
135            }
136        }
137
138        if summary_parts.is_empty() {
139            Ok("Previous conversation context available.".to_string())
140        } else {
141            Ok(summary_parts.join("\n"))
142        }
143    }
144}
145
/// Trigger conditions for when to create a summary.
///
/// Evaluated by [`SummaryManager::should_summarize`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SummaryTrigger {
    /// Summarize whenever the caller reports that truncation occurred.
    OnTruncation,
    /// Summarize once the conversation contains at least `interval` messages.
    ///
    /// The check counts messages (not user/assistant rounds). It fires on
    /// every evaluation while the count stays at or above `interval`.
    Periodic { interval: usize },
    /// Summarize once the current token count reaches `threshold` (inclusive).
    TokenThreshold { threshold: u32 },
}
156
157/// Manager for conversation summarization.
158pub struct SummaryManager {
159    summarizer: Box<dyn Summarizer>,
160    trigger: SummaryTrigger,
161}
162
163impl std::fmt::Debug for SummaryManager {
164    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165        f.debug_struct("SummaryManager")
166            .field("trigger", &self.trigger)
167            .finish_non_exhaustive()
168    }
169}
170
171impl SummaryManager {
172    /// Create a new summary manager.
173    pub fn new(summarizer: impl Summarizer + 'static, trigger: SummaryTrigger) -> Self {
174        Self {
175            summarizer: Box::new(summarizer),
176            trigger,
177        }
178    }
179
180    /// Check if summarization should be triggered based on conversation state.
181    pub fn should_summarize(
182        &self,
183        messages: &[Message],
184        _truncation_occurred: bool,
185        current_token_count: u32,
186    ) -> bool {
187        match &self.trigger {
188            SummaryTrigger::OnTruncation => _truncation_occurred,
189            SummaryTrigger::Periodic { interval } => messages.len() >= *interval,
190            SummaryTrigger::TokenThreshold { threshold } => current_token_count >= *threshold,
191        }
192    }
193
194    /// Generate a summary of the messages.
195    pub async fn summarize(
196        &self,
197        messages: &[Message],
198    ) -> Result<String, crate::types::BudgetError> {
199        self.summarizer.summarize(messages).await
200    }
201
202    /// Estimate the token count of a summary for N messages.
203    pub fn estimate_summary_tokens(&self, message_count: usize) -> u32 {
204        self.summarizer.estimate_summary_tokens(message_count)
205    }
206}
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_domain::{FunctionCall, ToolCall};

    /// Build a function-type tool call with empty JSON arguments.
    fn tool_call(id: &str, name: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: "{}".to_string(),
            },
        }
    }

    /// Build a manager backed by the heuristic summarizer.
    fn manager_with(trigger: SummaryTrigger) -> SummaryManager {
        SummaryManager::new(HeuristicSummarizer::new(), trigger)
    }

    #[test]
    fn heuristic_summarizer_extracts_user_questions() {
        let summarizer = HeuristicSummarizer::new();
        let messages = vec![
            Message::user("What is the weather?"),
            Message::assistant("It's sunny.", None),
            Message::user("What about tomorrow?"),
        ];

        let questions = summarizer.extract_user_questions(&messages);
        assert_eq!(questions, vec!["What is the weather?", "What about tomorrow?"]);
    }

    #[test]
    fn heuristic_summarizer_limits_user_questions_to_earliest_ten() {
        let summarizer = HeuristicSummarizer::new();
        let messages: Vec<Message> = (0..15)
            .map(|i| Message::user(&format!("Question {}", i)))
            .collect();

        let questions = summarizer.extract_user_questions(&messages);
        assert_eq!(questions.len(), 10);
        assert_eq!(questions[0], "Question 0");
        assert_eq!(questions[9], "Question 9");
    }

    #[test]
    fn heuristic_summarizer_extracts_tools_used() {
        let summarizer = HeuristicSummarizer::new();
        let messages = vec![
            Message::user("Search for something"),
            Message::assistant("I'll search", Some(vec![tool_call("call_1", "search")])),
        ];

        let tools = summarizer.extract_tools_used(&messages);
        assert_eq!(tools, vec!["search"]);
    }

    #[test]
    fn heuristic_summarizer_dedups_and_sorts_tools() {
        let summarizer = HeuristicSummarizer::new();
        let messages = vec![
            Message::assistant(
                "First",
                Some(vec![tool_call("call_1", "search"), tool_call("call_2", "read")]),
            ),
            Message::assistant("Second", Some(vec![tool_call("call_3", "search")])),
        ];

        let tools = summarizer.extract_tools_used(&messages);
        assert_eq!(tools, vec!["read", "search"]);
    }

    #[test]
    fn heuristic_summarizer_extracts_key_responses() {
        let summarizer = HeuristicSummarizer::new();
        let messages = vec![
            Message::user("Hello"),
            Message::assistant("First response", None),
            Message::user("How are you?"),
            Message::assistant("Most recent response", None),
        ];

        let responses = summarizer.extract_key_responses(&messages);
        // Most recent first.
        assert_eq!(responses, vec!["Most recent response", "First response"]);
    }

    #[test]
    fn heuristic_summarizer_limits_key_responses_to_latest_three() {
        let summarizer = HeuristicSummarizer::new();
        let messages: Vec<Message> = (0..5)
            .map(|i| Message::assistant(&format!("Answer {}", i), None))
            .collect();

        let responses = summarizer.extract_key_responses(&messages);
        assert_eq!(responses, vec!["Answer 4", "Answer 3", "Answer 2"]);
    }

    #[tokio::test]
    async fn heuristic_summarizer_generates_summary() {
        let summarizer = HeuristicSummarizer::new();
        let messages = vec![
            Message::user("What is Rust?"),
            Message::assistant("Rust is a systems programming language.", None),
        ];

        let summary = summarizer.summarize(&messages).await.unwrap();
        assert!(summary.contains("User Requests"));
        assert!(summary.contains("What is Rust?"));
    }

    #[tokio::test]
    async fn heuristic_summarizer_formats_all_sections() {
        let summarizer = HeuristicSummarizer::new();
        let messages = vec![
            Message::user("Find the docs"),
            Message::assistant("Searching", Some(vec![tool_call("call_1", "search")])),
            Message::assistant("Found them", None),
        ];

        let summary = summarizer.summarize(&messages).await.unwrap();
        assert_eq!(
            summary,
            "## User Requests\n1. Find the docs\n\n## Tools Used\n- search\n\n## Key Outcomes\n1. Found them\n2. Searching"
        );
    }

    #[tokio::test]
    async fn heuristic_summarizer_handles_empty_and_contentless_input() {
        let summarizer = HeuristicSummarizer::new();

        let empty = summarizer.summarize(&[]).await.unwrap();
        assert_eq!(empty, "No conversation history.");

        let contentless = vec![Message::user(""), Message::assistant("", None)];
        let placeholder = summarizer.summarize(&contentless).await.unwrap();
        assert_eq!(placeholder, "Previous conversation context available.");
    }

    #[test]
    fn summary_trigger_on_truncation() {
        let manager = manager_with(SummaryTrigger::OnTruncation);

        assert!(manager.should_summarize(&[], true, 0));
        // Token count must be irrelevant for this trigger.
        assert!(!manager.should_summarize(&[], false, u32::MAX));
    }

    #[test]
    fn summary_trigger_periodic() {
        let manager = manager_with(SummaryTrigger::Periodic { interval: 5 });
        let four: Vec<Message> = (0..4).map(|_| Message::user("Test")).collect();
        let five: Vec<Message> = (0..5).map(|_| Message::user("Test")).collect();

        assert!(!manager.should_summarize(&four, true, u32::MAX));
        assert!(manager.should_summarize(&five, false, 0));
    }

    #[test]
    fn summary_trigger_token_threshold() {
        let manager = manager_with(SummaryTrigger::TokenThreshold { threshold: 1000 });

        assert!(!manager.should_summarize(&[], true, 999));
        // The threshold is inclusive.
        assert!(manager.should_summarize(&[], false, 1000));
        assert!(manager.should_summarize(&[], false, 1001));
    }

    #[test]
    fn summary_manager_estimates_and_caps_tokens() {
        let manager = manager_with(SummaryTrigger::OnTruncation);

        assert_eq!(manager.estimate_summary_tokens(0), 0);
        assert_eq!(manager.estimate_summary_tokens(4), 200);
        assert_eq!(manager.estimate_summary_tokens(100), 1000);
    }

    #[test]
    fn summary_manager_debug_shows_trigger_only() {
        let manager = manager_with(SummaryTrigger::OnTruncation);
        assert_eq!(
            format!("{:?}", manager),
            "SummaryManager { trigger: OnTruncation, .. }"
        );
    }

    #[tokio::test]
    async fn summary_manager_delegates_summarize() {
        let manager = manager_with(SummaryTrigger::OnTruncation);
        let summary = manager.summarize(&[]).await.unwrap();
        assert_eq!(summary, "No conversation history.");
    }

    #[test]
    fn safe_truncate_handles_ascii() {
        let summarizer = HeuristicSummarizer::new();
        let truncated = summarizer.safe_truncate("Hello world this is a test", 10);
        assert_eq!(truncated, "Hello worl...");
    }

    #[test]
    fn safe_truncate_handles_unicode() {
        let summarizer = HeuristicSummarizer::new();
        // Emoji are 4-byte UTF-8 sequences.
        let truncated = summarizer.safe_truncate("Hello 😀🎉🚀 World with emoji", 10);
        assert_eq!(truncated, "Hello 😀🎉🚀 ...");
    }

    #[test]
    fn safe_truncate_handles_cjk() {
        let summarizer = HeuristicSummarizer::new();
        // CJK characters are 3-byte UTF-8 sequences.
        let truncated = summarizer.safe_truncate("这是一个中文测试消息用于验证截断", 10);
        assert_eq!(truncated, "这是一个中文测试消息...");
    }

    #[test]
    fn safe_truncate_handles_mixed_unicode() {
        let summarizer = HeuristicSummarizer::new();
        let truncated = summarizer.safe_truncate("Hello 世界 🌍 test message", 8);
        assert_eq!(truncated, "Hello 世界...");
    }

    #[test]
    fn safe_truncate_boundaries() {
        let summarizer = HeuristicSummarizer::new();
        // Exactly at the limit: unchanged, no ellipsis.
        assert_eq!(summarizer.safe_truncate("abc", 3), "abc");
        // Zero limit on non-empty input: only the ellipsis remains.
        assert_eq!(summarizer.safe_truncate("abc", 0), "...");
        // Empty input is never truncated.
        assert_eq!(summarizer.safe_truncate("", 0), "");
    }

    #[tokio::test]
    async fn summarizer_handles_unicode_messages() {
        let summarizer = HeuristicSummarizer::new();

        // 32 chars per repetition, 320 total: exceeds the 200-char request limit.
        let long_unicode =
            "这是一段很长的中文消息需要被截断以测试我们的安全截断功能 😀🎉🚀".repeat(10);
        let messages = vec![
            Message::user(&long_unicode),
            Message::assistant("Response", None),
        ];

        let summary = summarizer.summarize(&messages).await.unwrap();
        assert!(summary.contains("User Requests"));

        // "1. " + 200 chars + "..."
        let first_request = summary.lines().nth(1).unwrap();
        assert!(first_request.starts_with("1. "));
        assert!(first_request.ends_with("..."));
        assert_eq!(first_request.chars().count(), 3 + 200 + 3);
    }

    #[test]
    fn safe_truncate_returns_short_text_unchanged() {
        let summarizer = HeuristicSummarizer::new();
        let text = "Short";
        assert_eq!(summarizer.safe_truncate(text, 100), text);
    }
}