use async_trait::async_trait;
use bamboo_domain::{Message, Role};
use std::collections::HashSet;
/// Strategy for condensing a slice of conversation messages into a single summary string.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait Summarizer: Send + Sync {
    /// Produces a summary of `messages`.
    ///
    /// # Errors
    ///
    /// Returns a `BudgetError` when the implementation cannot produce a summary.
    async fn summarize(&self, messages: &[Message]) -> Result<String, crate::types::BudgetError>;

    /// Estimates the size of a summary covering `message_count` messages.
    ///
    /// The default estimate is 50 per message, capped at 1000.
    fn estimate_summary_tokens(&self, message_count: usize) -> u32 {
        // Saturate rather than overflow: a plain `*` would panic in debug builds and
        // wrap to a bogus estimate in release builds for very large counts.
        let capped = message_count.saturating_mul(50).min(1000);
        // Lossless cast: `capped` is at most 1000.
        capped as u32
    }
}
/// A [`Summarizer`] that builds its summary by extracting user requests, tool names,
/// and recent assistant replies directly from the messages.
#[derive(Debug, Default)]
pub struct HeuristicSummarizer;
impl HeuristicSummarizer {
pub fn new() -> Self {
Self
}
fn extract_user_questions<'a>(&self, messages: &'a [Message]) -> Vec<&'a str> {
messages
.iter()
.filter(|m| m.role == Role::User)
.filter(|m| !m.content.is_empty())
.take(10) .map(|m| m.content.as_str())
.collect()
}
fn extract_tools_used(&self, messages: &[Message]) -> Vec<String> {
let mut tools = HashSet::new();
for message in messages {
if let Some(ref tool_calls) = message.tool_calls {
for call in tool_calls {
tools.insert(call.function.name.clone());
}
}
}
let mut result: Vec<String> = tools.into_iter().collect();
result.sort();
result
}
fn extract_key_responses<'a>(&self, messages: &'a [Message]) -> Vec<&'a str> {
messages
.iter()
.filter(|m| m.role == Role::Assistant)
.filter(|m| !m.content.is_empty())
.rev() .take(3)
.map(|m| m.content.as_str())
.collect()
}
fn safe_truncate(&self, s: &str, max_chars: usize) -> String {
if s.chars().count() <= max_chars {
return s.to_string();
}
let truncated: String = s.chars().take(max_chars).collect();
format!("{}...", truncated)
}
}
#[async_trait]
impl Summarizer for HeuristicSummarizer {
    /// Builds a Markdown-style summary with up to three sections: `## User Requests`,
    /// `## Tools Used`, and `## Key Outcomes`.
    ///
    /// Sections with no content are omitted and the remaining ones are separated by a
    /// single blank line. Returns a fixed placeholder string when `messages` is empty or
    /// when nothing could be extracted from it. This implementation never returns `Err`.
    async fn summarize(&self, messages: &[Message]) -> Result<String, crate::types::BudgetError> {
        if messages.is_empty() {
            return Ok("No conversation history.".to_string());
        }

        let questions = self.extract_user_questions(messages);
        let tools = self.extract_tools_used(messages);
        let responses = self.extract_key_responses(messages);

        // Each entry is one fully rendered section (heading plus its lines). Rendering
        // sections independently keeps the blank-line separator out of the headings, so
        // the summary never starts with an empty line when an earlier section is absent.
        let mut sections: Vec<String> = Vec::new();

        if !questions.is_empty() {
            let mut lines = vec!["## User Requests".to_string()];
            lines.extend(
                questions
                    .iter()
                    .enumerate()
                    .map(|(i, q)| format!("{}. {}", i + 1, self.safe_truncate(q, 200))),
            );
            sections.push(lines.join("\n"));
        }

        if !tools.is_empty() {
            let mut lines = vec!["## Tools Used".to_string()];
            lines.extend(tools.iter().map(|tool| format!("- {}", tool)));
            sections.push(lines.join("\n"));
        }

        if !responses.is_empty() {
            let mut lines = vec!["## Key Outcomes".to_string()];
            lines.extend(
                responses
                    .iter()
                    .enumerate()
                    .map(|(i, r)| format!("{}. {}", i + 1, self.safe_truncate(r, 300))),
            );
            sections.push(lines.join("\n"));
        }

        if sections.is_empty() {
            Ok("Previous conversation context available.".to_string())
        } else {
            Ok(sections.join("\n\n"))
        }
    }
}
/// Condition under which `SummaryManager::should_summarize` reports that a summary is due.
#[derive(Debug, Clone)]
pub enum SummaryTrigger {
/// Fires whenever the caller reports that truncation occurred.
OnTruncation,
/// Fires once the number of messages is at least `interval`.
Periodic { interval: usize },
/// Fires once the current token count is at least `threshold`.
TokenThreshold { threshold: u32 },
}
/// Pairs a boxed [`Summarizer`] with the [`SummaryTrigger`] that decides when to run it.
pub struct SummaryManager {
/// Strategy that produces the summary text.
summarizer: Box<dyn Summarizer>,
/// Condition evaluated by `should_summarize`.
trigger: SummaryTrigger,
}
impl std::fmt::Debug for SummaryManager {
    /// Formats the manager showing only its `trigger`.
    ///
    /// The boxed summarizer is left out and the output is marked as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("SummaryManager");
        builder.field("trigger", &self.trigger);
        builder.finish_non_exhaustive()
    }
}
impl SummaryManager {
    /// Creates a manager that boxes `summarizer` and uses `trigger` to decide when a
    /// summary is due.
    pub fn new(summarizer: impl Summarizer + 'static, trigger: SummaryTrigger) -> Self {
        Self {
            summarizer: Box::new(summarizer),
            trigger,
        }
    }

    /// Reports whether the configured trigger fires for the given conversation state.
    ///
    /// - `OnTruncation`: returns `truncation_occurred`.
    /// - `Periodic { interval }`: returns `true` when `messages.len() >= interval`.
    /// - `TokenThreshold { threshold }`: returns `true` when
    ///   `current_token_count >= threshold`.
    ///
    /// Each trigger consults only its own input; the other arguments are ignored.
    pub fn should_summarize(
        &self,
        messages: &[Message],
        // Named without a leading underscore because the value is read below.
        truncation_occurred: bool,
        current_token_count: u32,
    ) -> bool {
        match &self.trigger {
            SummaryTrigger::OnTruncation => truncation_occurred,
            SummaryTrigger::Periodic { interval } => messages.len() >= *interval,
            SummaryTrigger::TokenThreshold { threshold } => current_token_count >= *threshold,
        }
    }

    /// Summarizes `messages` by delegating to the wrapped summarizer.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the wrapped summarizer returns.
    pub async fn summarize(
        &self,
        messages: &[Message],
    ) -> Result<String, crate::types::BudgetError> {
        self.summarizer.summarize(messages).await
    }

    /// Returns the wrapped summarizer's size estimate for a summary of `message_count`
    /// messages.
    pub fn estimate_summary_tokens(&self, message_count: usize) -> u32 {
        self.summarizer.estimate_summary_tokens(message_count)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// User messages are collected in order; assistant messages are skipped.
    #[test]
    fn heuristic_summarizer_extracts_user_questions() {
        let summarizer = HeuristicSummarizer::new();
        let messages = vec![
            Message::user("What is the weather?"),
            Message::assistant("It's sunny.", None),
            Message::user("What about tomorrow?"),
        ];
        let questions = summarizer.extract_user_questions(&messages);
        assert_eq!(questions.len(), 2);
        assert!(questions[0].contains("weather"));
    }

    /// Tool names are taken from the tool calls attached to messages.
    #[test]
    fn heuristic_summarizer_extracts_tools_used() {
        use bamboo_domain::{FunctionCall, ToolCall};
        let summarizer = HeuristicSummarizer::new();
        let tool_call = ToolCall {
            id: "call_1".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: "{}".to_string(),
            },
        };
        let messages = vec![
            Message::user("Search for something"),
            Message::assistant("I'll search", Some(vec![tool_call])),
        ];
        let tools = summarizer.extract_tools_used(&messages);
        assert_eq!(tools, vec!["search"]);
    }

    /// The most recent assistant reply comes first.
    #[test]
    fn heuristic_summarizer_extracts_key_responses() {
        let summarizer = HeuristicSummarizer::new();
        let messages = vec![
            Message::user("Hello"),
            Message::assistant("First response", None),
            Message::user("How are you?"),
            Message::assistant("Most recent response", None),
        ];
        let responses = summarizer.extract_key_responses(&messages);
        assert_eq!(responses[0], "Most recent response");
    }

    #[tokio::test]
    async fn heuristic_summarizer_generates_summary() {
        let summarizer = HeuristicSummarizer::new();
        let messages = vec![
            Message::user("What is Rust?"),
            Message::assistant("Rust is a systems programming language.", None),
        ];
        let summary = summarizer.summarize(&messages).await.unwrap();
        assert!(summary.contains("User Requests"));
        assert!(summary.contains("What is Rust?"));
    }

    /// `OnTruncation` follows the truncation flag and ignores the other inputs.
    #[test]
    fn summary_trigger_on_truncation() {
        let manager = SummaryManager::new(HeuristicSummarizer::new(), SummaryTrigger::OnTruncation);
        let messages = vec![Message::user("Test")];
        assert!(manager.should_summarize(&messages, true, 0));
        assert!(!manager.should_summarize(&messages, false, u32::MAX));
    }

    /// `Periodic` fires once the message count reaches the interval.
    #[test]
    fn summary_trigger_periodic() {
        let manager = SummaryManager::new(
            HeuristicSummarizer::new(),
            SummaryTrigger::Periodic { interval: 5 },
        );
        let messages: Vec<Message> = (0..5).map(|_| Message::user("Test")).collect();
        assert!(!manager.should_summarize(&messages[..4], true, u32::MAX));
        assert!(manager.should_summarize(&messages, false, 0));
    }

    /// `TokenThreshold` fires once the token count reaches the threshold.
    #[test]
    fn summary_trigger_token_threshold() {
        let manager = SummaryManager::new(
            HeuristicSummarizer::new(),
            SummaryTrigger::TokenThreshold { threshold: 1000 },
        );
        let messages: Vec<Message> = Vec::new();
        assert!(!manager.should_summarize(&messages, true, 999));
        assert!(manager.should_summarize(&messages, false, 1000));
    }

    #[test]
    fn safe_truncate_handles_ascii() {
        let summarizer = HeuristicSummarizer::new();
        let text = "Hello world this is a test";
        let truncated = summarizer.safe_truncate(text, 10);
        assert!(truncated.ends_with("..."));
        assert!(truncated.chars().count() <= 13);
    }

    #[test]
    fn safe_truncate_handles_unicode() {
        let summarizer = HeuristicSummarizer::new();
        let text = "Hello 😀🎉🚀 World with emoji";
        let truncated = summarizer.safe_truncate(text, 10);
        assert!(truncated.ends_with("..."));
        assert!(truncated.chars().count() <= 13);
    }

    #[test]
    fn safe_truncate_handles_cjk() {
        let summarizer = HeuristicSummarizer::new();
        let text = "这是一个中文测试消息用于验证截断";
        let truncated = summarizer.safe_truncate(text, 10);
        assert!(truncated.ends_with("..."));
        assert!(truncated.chars().count() <= 13);
    }

    #[test]
    fn safe_truncate_handles_mixed_unicode() {
        let summarizer = HeuristicSummarizer::new();
        let text = "Hello 世界 🌍 test message";
        let truncated = summarizer.safe_truncate(text, 8);
        assert!(truncated.ends_with("..."));
        assert!(truncated.chars().count() <= 11);
    }

    /// Long multi-byte input must be summarized without panicking on a char boundary.
    #[tokio::test]
    async fn summarizer_handles_unicode_messages() {
        let summarizer = HeuristicSummarizer::new();
        let long_unicode =
            "这是一段很长的中文消息需要被截断以测试我们的安全截断功能 😀🎉🚀".repeat(10);
        let messages = vec![
            Message::user(&long_unicode),
            Message::assistant("Response", None),
        ];
        let summary = summarizer.summarize(&messages).await.unwrap();
        assert!(summary.contains("User Requests"));
    }

    #[test]
    fn safe_truncate_returns_short_text_unchanged() {
        let summarizer = HeuristicSummarizer::new();
        let text = "Short";
        let truncated = summarizer.safe_truncate(text, 100);
        assert_eq!(truncated, text);
    }
}