use anyhow::Result;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use crate::providers::{ContentBlock, Message, MessageContent, Provider, Role, ChatRequest, ChatResponse};
/// Default for `CompressionConfig::threshold`: the used-tokens / context-size ratio
/// at or above which `should_compress` returns `true`.
pub const DEFAULT_COMPRESSION_THRESHOLD: f64 = 0.75;
/// Default for `CompressionConfig::min_preserve_messages`.
pub const MIN_MESSAGES_TO_KEEP: usize = 8;
/// Default for `CompressionConfig::target_ratio`: the fraction of the estimated
/// token total that the token-based strategies aim for.
pub const DEFAULT_TARGET_RATIO: f64 = 0.4;
/// Model name returned by `CompressionConfig::compressor_model_name` when no override is set.
pub const DEFAULT_COMPRESSOR_MODEL: &str = "claude-3-5-haiku-20241022";
/// Switches that steer how `calculate_preservation_score` ranks messages for
/// bias-based compression. The derived `Default` has every flag off and no keywords.
#[derive(Debug, Clone, Default)]
pub struct CompressionBias {
/// Adds a score bonus to `Role::Tool` messages and to tool-use / tool-result blocks.
pub preserve_tools: bool,
/// Adds a score bonus to thinking blocks; when unset they take a small penalty instead.
pub preserve_thinking: bool,
/// Adds a score bonus to `Role::User` messages.
pub preserve_user_questions: bool,
/// When unset, plain-text messages whose `len()` exceeds 2000 lose score.
pub compact_long_outputs: bool,
/// Makes `compress_with_bias` keep only `min_preserve_messages` messages.
pub aggressive: bool,
/// Keywords, matched case-insensitively, whose presence raises a message's score.
pub preserve_keywords: Vec<String>,
}
impl CompressionBias {
    /// Preset that favours tool traffic and user questions, with a small
    /// bilingual list of "decision / important / key" keywords.
    pub fn balanced() -> Self {
        let preserve_keywords = ["决定", "decision", "重要", "important", "关键", "key"]
            .iter()
            .map(|word| String::from(*word))
            .collect();
        Self {
            preserve_keywords,
            preserve_tools: true,
            preserve_user_questions: true,
            preserve_thinking: false,
            compact_long_outputs: false,
            aggressive: false,
        }
    }

    /// Preset that turns on every preservation flag (including thinking blocks and
    /// `compact_long_outputs`) and extends the keyword list with "done / success".
    pub fn preserve_important() -> Self {
        let preserve_keywords = [
            "决定", "decision", "重要", "important", "关键", "key", "完成", "done", "成功",
            "success",
        ]
        .iter()
        .map(|word| String::from(*word))
        .collect();
        Self {
            preserve_keywords,
            preserve_tools: true,
            preserve_thinking: true,
            preserve_user_questions: true,
            compact_long_outputs: true,
            aggressive: false,
        }
    }

    /// Preset with every preservation flag off, no keywords, and `aggressive` set.
    pub fn aggressive() -> Self {
        Self {
            preserve_keywords: Vec::new(),
            aggressive: true,
            preserve_tools: false,
            preserve_thinking: false,
            preserve_user_questions: false,
            compact_long_outputs: false,
        }
    }

    /// Preset that only favours tool traffic, with "tool / execute / file" keywords.
    pub fn tool_focused() -> Self {
        let preserve_keywords = ["工具", "tool", "执行", "execute", "文件", "file"]
            .iter()
            .map(|word| String::from(*word))
            .collect();
        Self {
            preserve_keywords,
            preserve_tools: true,
            preserve_thinking: false,
            preserve_user_questions: false,
            compact_long_outputs: false,
            aggressive: false,
        }
    }

    /// Builds a bias from a textual spec.
    ///
    /// The spec is trimmed and lower-cased first. A preset name (`balanced`,
    /// `default`, the empty string, `aggressive`, `preserve_important`,
    /// `important`, `tool_focused`, `tools`) selects that preset. Anything else
    /// is read as whitespace-separated tokens applied on top of
    /// `Self::default()`:
    ///
    /// * `preserve:<a,b,...>` switches on the named flags,
    /// * `keywords:<a,b,...>` replaces the keyword list,
    /// * `aggressive` sets the aggressive flag.
    ///
    /// Unrecognised tokens and flag names are ignored; this function currently
    /// never returns `Err`.
    pub fn parse(spec: &str) -> Result<Self> {
        let spec = spec.trim().to_lowercase();

        // Whole-spec preset names take precedence over token parsing.
        match spec.as_str() {
            "" | "balanced" | "default" => return Ok(Self::balanced()),
            "aggressive" => return Ok(Self::aggressive()),
            "preserve_important" | "important" => return Ok(Self::preserve_important()),
            "tool_focused" | "tools" => return Ok(Self::tool_focused()),
            _ => {}
        }

        let mut bias = Self::default();
        for token in spec.split_whitespace() {
            if let Some(flags) = token.strip_prefix("preserve:") {
                for flag in flags.split(',').map(str::trim) {
                    match flag {
                        "tools" | "tool" => bias.preserve_tools = true,
                        "thinking" | "think" => bias.preserve_thinking = true,
                        "user" | "questions" => bias.preserve_user_questions = true,
                        "compact" | "long" => bias.compact_long_outputs = true,
                        _ => {}
                    }
                }
                continue;
            }
            if let Some(words) = token.strip_prefix("keywords:") {
                // A later `keywords:` token overwrites an earlier one.
                bias.preserve_keywords = words
                    .split(',')
                    .map(|word| word.trim().to_string())
                    .filter(|word| !word.is_empty())
                    .collect();
                continue;
            }
            if token == "aggressive" {
                bias.aggressive = true;
            }
        }
        Ok(bias)
    }

    /// Renders the active flags and keywords as a comma-separated, human-readable
    /// list, or `"default"` when nothing is set.
    pub fn format(&self) -> String {
        let toggles = [
            (self.preserve_tools, "tools"),
            (self.preserve_thinking, "thinking"),
            (self.preserve_user_questions, "user"),
            (self.compact_long_outputs, "compact"),
            (self.aggressive, "aggressive"),
        ];
        let mut parts: Vec<String> = toggles
            .iter()
            .filter(|(enabled, _)| *enabled)
            .map(|(_, label)| label.to_string())
            .collect();

        if !self.preserve_keywords.is_empty() {
            parts.push(format!("keywords:{}", self.preserve_keywords.join(",")));
        }

        if parts.is_empty() {
            return "default".to_string();
        }
        parts.join(", ")
    }
}
/// Tunables shared by every compression strategy in this file.
#[derive(Debug, Clone)]
pub struct CompressionConfig {
/// Used-tokens / context-size ratio at or above which `should_compress` returns `true`.
pub threshold: f64,
/// Fraction of the estimated token total that the token-based strategies aim for.
pub target_ratio: f64,
/// Inputs with at most this many messages are returned unchanged by every strategy.
pub min_preserve_messages: usize,
/// NOTE(review): not read by any code visible in this file; presumably consulted by callers — confirm.
pub use_summarization: bool,
/// Optional model override; `compressor_model_name` falls back to `DEFAULT_COMPRESSOR_MODEL`.
pub compressor_model: Option<String>,
/// Scoring preferences used by `compress_with_bias`.
pub bias: CompressionBias,
}
impl Default for CompressionConfig {
fn default() -> Self {
Self {
threshold: DEFAULT_COMPRESSION_THRESHOLD,
target_ratio: DEFAULT_TARGET_RATIO,
min_preserve_messages: MIN_MESSAGES_TO_KEEP,
use_summarization: true,
compressor_model: None,
bias: CompressionBias::balanced(),
}
}
}
impl CompressionConfig {
    /// Returns the configured compressor model, or `DEFAULT_COMPRESSOR_MODEL`
    /// when no override has been set.
    pub fn compressor_model_name(&self) -> &str {
        match &self.compressor_model {
            Some(name) => name.as_str(),
            None => DEFAULT_COMPRESSOR_MODEL,
        }
    }
}
/// Outcome of one compression pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionResult {
/// Number of messages before compression.
pub original_count: usize,
/// Number of messages after compression.
pub new_count: usize,
/// Token savings as supplied by the caller (rendered as an approximation by `format_summary`).
pub tokens_saved: u32,
/// Summary text, if the pass produced one.
pub summary: Option<String>,
/// Strategy that produced this result.
pub strategy: CompressionStrategy,
/// Creation time; `CompressionResult::new` sets it to `Utc::now()`.
pub timestamp: DateTime<Utc>,
}
impl CompressionResult {
    /// Creates a result stamped with the current UTC time.
    pub fn new(
        original_count: usize,
        new_count: usize,
        tokens_saved: u32,
        summary: Option<String>,
        strategy: CompressionStrategy,
    ) -> Self {
        let timestamp = Utc::now();
        Self {
            timestamp,
            strategy,
            summary,
            tokens_saved,
            new_count,
            original_count,
        }
    }

    /// One-line description: message counts before/after, approximate tokens
    /// saved, and the strategy's display name.
    pub fn format_summary(&self) -> String {
        let saved = format_tokens(self.tokens_saved);
        let label = match self.strategy {
            CompressionStrategy::BiasBased => "bias-based",
            CompressionStrategy::Summarize => "AI summarize",
            CompressionStrategy::SlidingWindow => "sliding window",
            CompressionStrategy::Truncate => "truncate",
        };
        format!(
            "{} messages → {} messages (saved ~{} tokens, {})",
            self.original_count, self.new_count, saved, label
        )
    }
}
/// Formats a token count compactly: the plain number below 1 000, one decimal
/// with a `K` suffix below 10 000, and whole thousands with a `K` suffix above.
pub fn format_tokens(n: u32) -> String {
    let thousands = f64::from(n) / 1_000.0;
    match n {
        0..=999 => n.to_string(),
        1_000..=9_999 => format!("{:.1}K", thousands),
        _ => format!("{:.0}K", thousands),
    }
}
/// How a conversation history is shrunk. Serialized with snake_case variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompressionStrategy {
/// Keep only the trailing `min_preserve_messages` messages.
Truncate,
/// Keep a trailing window that starts on a user message and fits the token target.
SlidingWindow,
/// AI summary of older messages; `compress_messages` has no compressor and falls
/// back to the sliding window for this variant (see `compress_messages_with_ai`).
Summarize,
/// Keep the highest-scoring messages according to `CompressionBias`.
BiasBased,
}
/// A run of messages replaced by a summary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SummarizedSegment {
/// (start, end) timestamps for the segment; `AiCompressor` fills both with the
/// time of summarization rather than the times of the original messages.
pub time_range: (DateTime<Utc>, DateTime<Utc>),
/// Number of messages that were summarized.
pub original_count: usize,
/// Summary text.
pub summary: String,
/// Bullet-style key points extracted from the summary response.
pub key_points: Vec<String>,
}
impl SummarizedSegment {
    /// Renders the segment as a single `Role::User` text message containing a
    /// header with the original message count, the summary, and the key points
    /// as a bullet list (or a "none" marker when there are no key points).
    pub fn to_message(&self) -> Message {
        let bullet_list = match self.key_points.as_slice() {
            [] => "无".to_string(),
            points => points
                .iter()
                .map(|point| format!("• {}", point))
                .collect::<Vec<_>>()
                .join("\n"),
        };
        let body = format!(
            "[对话摘要 - 原 {} 条消息]\n\n{}\n\n关键要点:\n{}",
            self.original_count, self.summary, bullet_list
        );
        Message {
            role: Role::User,
            content: MessageContent::Text(body),
        }
    }
}
/// Compact record of a past compression, without the summary text itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionHistoryEntry {
/// When the compression result was created.
pub timestamp: DateTime<Utc>,
/// Strategy that was used.
pub strategy: CompressionStrategy,
/// Number of messages before compression.
pub original_count: usize,
/// Number of messages after compression.
pub new_count: usize,
/// Token savings copied from the result.
pub tokens_saved: u32,
/// Whether the originating result carried a summary.
pub has_summary: bool,
}
impl CompressionHistoryEntry {
    /// Copies the bookkeeping fields out of `result`, recording only whether a
    /// summary was present.
    pub fn from_result(result: &CompressionResult) -> Self {
        let has_summary = result.summary.is_some();
        Self {
            has_summary,
            tokens_saved: result.tokens_saved,
            new_count: result.new_count,
            original_count: result.original_count,
            strategy: result.strategy,
            timestamp: result.timestamp,
        }
    }

    /// One display line: minute-resolution timestamp, strategy name, message
    /// counts, approximate tokens saved, and a marker telling summarised
    /// entries apart from plain cuts.
    pub fn format_line(&self) -> String {
        let when = self.timestamp.format("%Y-%m-%d %H:%M");
        let label = match self.strategy {
            CompressionStrategy::BiasBased => "bias-based",
            CompressionStrategy::Summarize => "AI summarize",
            CompressionStrategy::SlidingWindow => "sliding window",
            CompressionStrategy::Truncate => "truncate",
        };
        let saved = format_tokens(self.tokens_saved);
        let marker = match self.has_summary {
            true => "📝",
            false => "✂️",
        };
        format!(
            "{} {} - {} msgs → {} msgs (~{} tokens saved) {}",
            when, label, self.original_count, self.new_count, saved, marker
        )
    }
}
/// Something that can turn a run of messages into a `SummarizedSegment`.
#[async_trait]
pub trait Compressor: Send + Sync {
/// Summarizes `messages`; `config` is available to implementations for tuning.
async fn summarize(&self, messages: &[Message], config: &CompressionConfig) -> Result<SummarizedSegment>;
/// Name of the model this compressor reports using.
fn model_name(&self) -> &str;
}
/// `Compressor` backed by a chat `Provider`.
pub struct AiCompressor {
// Provider that receives the summarization request.
provider: Box<dyn Provider>,
// Returned by `model_name`; the request built in `summarize` does not carry it.
model: String,
}
impl AiCompressor {
pub fn new(provider: Box<dyn Provider>, model: String) -> Self {
Self { provider, model }
}
}
#[async_trait]
impl Compressor for AiCompressor {
    /// Sends one user message containing the summary prompt (with the module's
    /// system prompt) to the provider and parses the text reply into a summary
    /// plus key points.
    ///
    /// `_config` is not used. Both ends of `time_range` are set to the current
    /// time rather than derived from the messages.
    ///
    /// # Errors
    /// Propagates any error returned by the provider's `chat` call.
    async fn summarize(&self, messages: &[Message], _config: &CompressionConfig) -> Result<SummarizedSegment> {
        let user_turn = Message {
            role: Role::User,
            content: MessageContent::Text(build_summary_prompt(messages)),
        };
        let request = ChatRequest {
            system: Some(SUMMARY_SYSTEM_PROMPT.to_string()),
            messages: vec![user_turn],
            max_tokens: 1024,
            think: false,
            enable_caching: false,
            tools: vec![],
            server_tools: vec![],
        };

        let response = self.provider.chat(request).await?;
        let reply = extract_text_from_response(&response);
        let (summary, key_points) = parse_summary_response(&reply);

        Ok(SummarizedSegment {
            time_range: (Utc::now(), Utc::now()),
            original_count: messages.len(),
            summary,
            key_points,
        })
    }

    /// Returns the model name given at construction.
    fn model_name(&self) -> &str {
        self.model.as_str()
    }
}
// System prompt (in Chinese) sent with every summarization request by `AiCompressor`.
// It asks for a concise, structured summary of at most 200 characters that keeps
// key operations and decisions, and that always retains the user's explicit
// instructions ("do not ...", "must ...", "forbidden ...") and stated preferences.
// The text is sent verbatim to the model; edit with care.
const SUMMARY_SYSTEM_PROMPT: &str = r#"你是一个对话历史压缩助手。你的任务是将对话历史压缩为简洁的摘要,保留关键信息。
输出要求:
- 简洁:摘要控制在 200 字以内
- 关键:只保留重要操作和决策
- 结构化:使用清晰格式
- 敏感:必须保留用户的敏感指令(如"不要..."、"必须..."、"禁止..."等)
- 偏好:保留用户的偏好设置和决策
请直接输出摘要内容。"#;
/// Concatenates every `ContentBlock::Text` block of `response`, in order,
/// separated by newlines. Non-text blocks are skipped.
fn extract_text_from_response(response: &ChatResponse) -> String {
    let mut pieces = Vec::new();
    for block in response.content.iter() {
        if let ContentBlock::Text { text } = block {
            pieces.push(text.clone());
        }
    }
    pieces.join("\n")
}
/// Splits a model reply into `(summary, key_points)`.
///
/// Each non-empty, trimmed line is classified in this order:
/// * lines starting with `•`, `-` or `*` are bullets; the text after the
///   markers becomes a key point;
/// * lines starting with "已完成" or "操作" are section headers of the form
///   `<label>:<content>`; any inline content after the label becomes a key point;
/// * the first remaining line starts the summary;
/// * later lines are appended to the summary only while no key point has been
///   seen and the summary is still shorter than 200 bytes.
///
/// If no summary line was found, the first three lines of `text` are joined
/// and truncated to 200 bytes as a fallback.
fn parse_summary_response(text: &str) -> (String, Vec<String>) {
    const BULLETS: [char; 3] = ['•', '-', '*'];
    // Both the ASCII colon and the full-width colon common in Chinese text.
    const COLONS: [char; 2] = [':', '：'];

    let mut summary = String::new();
    let mut key_points: Vec<String> = Vec::new();

    for line in text.lines().map(str::trim).filter(|l| !l.is_empty()) {
        if line.starts_with(BULLETS) {
            let point = line.trim_start_matches(BULLETS).trim();
            if !point.is_empty() {
                key_points.push(point.to_string());
            }
        } else if line.starts_with("已完成") || line.starts_with("操作") {
            // Strip the label first and the colon second. Stripping both in a
            // single pass would also eat the content whenever it directly
            // follows the colon, because CJK ideographs count as alphabetic.
            let after_label = line.trim_start_matches(char::is_alphabetic);
            let ops = after_label.trim_start_matches(COLONS).trim();
            if !ops.is_empty() && ops != ":" && ops != "：" {
                key_points.push(ops.to_string());
            }
        } else if summary.is_empty() {
            summary = line.to_string();
        } else if key_points.is_empty() && summary.len() < 200 {
            summary.push(' ');
            summary.push_str(line);
        }
    }

    // Reply consisted only of bullets/headers: fall back to its opening lines.
    if summary.is_empty() && !text.is_empty() {
        summary = text.lines().take(3).collect::<Vec<_>>().join(" ");
        if summary.len() > 200 {
            summary = truncate_text(&summary, 200);
        }
    }
    (summary, key_points)
}
/// Returns `s` unchanged when it is at most `max` bytes long; otherwise cuts it
/// at the last UTF-8 character boundary not exceeding `max` bytes and appends
/// `"..."` (so the result may be up to three bytes longer than `max`).
fn truncate_text(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }
    // Index 0 is always a boundary, so the search cannot come up empty.
    let cut = (0..=max)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    let mut shortened = String::with_capacity(cut + 3);
    shortened.push_str(&s[..cut]);
    shortened.push_str("...");
    shortened
}
/// Synchronous entry point: dispatches to the implementation for `strategy`.
///
/// No compressor is available here, so `CompressionStrategy::Summarize` uses
/// the sliding-window implementation; the AI-backed path is
/// `compress_messages_with_ai`.
pub fn compress_messages(
    messages: &[Message],
    strategy: CompressionStrategy,
    config: &CompressionConfig,
) -> Result<Vec<Message>> {
    match strategy {
        CompressionStrategy::BiasBased => compress_with_bias(messages, config),
        CompressionStrategy::Truncate => truncate_compress(messages, config),
        CompressionStrategy::SlidingWindow | CompressionStrategy::Summarize => {
            sliding_window_compress(messages, config)
        }
    }
}
/// Keeps the highest-scoring messages, returned in their original order.
///
/// Each message gets its `calculate_preservation_score` plus a recency bonus:
/// +100 for the last `min_preserve_messages` messages, otherwise up to +20
/// proportional to position. The number kept is `min_preserve_messages` when
/// `config.bias.aggressive` is set; otherwise it is derived from
/// `target_ratio` and the average estimated tokens per message, never fewer
/// than `min_preserve_messages`.
///
/// Inputs with at most `min_preserve_messages` messages are returned unchanged.
///
/// NOTE(review): selection is per message, so a tool-use message can be kept
/// without its tool result (or vice versa) — confirm downstream tolerates that.
pub fn compress_with_bias(
    messages: &[Message],
    config: &CompressionConfig,
) -> Result<Vec<Message>> {
    let total = messages.len();
    let min_keep = config.min_preserve_messages;
    if total <= min_keep {
        return Ok(messages.to_vec());
    }

    // Rank (index, score) pairs only; messages are cloned once, at the end,
    // and only the ones that survive.
    let recent_from = total - min_keep;
    let mut ranked: Vec<(usize, f64)> = messages
        .iter()
        .enumerate()
        .map(|(idx, msg)| {
            let base = calculate_preservation_score(msg, idx, total, &config.bias);
            let recency_bonus = if idx >= recent_from {
                100.0
            } else {
                (idx as f64 / total as f64) * 20.0
            };
            (idx, base + recency_bonus)
        })
        .collect();
    // Stable sort, highest score first: ties keep the earlier message first.
    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

    let target_count = if config.bias.aggressive {
        min_keep
    } else {
        let estimated_tokens = estimate_total_tokens(messages);
        let target_tokens = (estimated_tokens as f64 * config.target_ratio) as u32;
        let avg_tokens_per_msg = estimated_tokens / total as u32;
        let calculated = (target_tokens / avg_tokens_per_msg.max(1)) as usize;
        calculated.max(min_keep)
    };

    let keep: HashSet<usize> = ranked
        .iter()
        .take(target_count)
        .map(|&(idx, _)| idx)
        .collect();

    Ok(messages
        .iter()
        .enumerate()
        .filter(|(idx, _)| keep.contains(idx))
        .map(|(_, msg)| msg.clone())
        .collect())
}
/// Scores how much a message deserves to survive bias-based compression
/// (higher means keep).
///
/// Starts from 10 and adds:
/// * a role bonus (system +40; user +30 and tool +25 when the matching bias
///   flag is set; assistant +5);
/// * for plain text: +15 per matching keyword, +50 if
///   `contains_sensitive_instructions` matches, and -10 for text longer than
///   2000 bytes when `compact_long_outputs` is unset;
/// * for block content, per block: tool use +20 (with `preserve_tools`) and
///   +10 for `write` / `edit` / `bash`; tool result +20 (with
///   `preserve_tools`), +10 per keyword, +30 if sensitive; thinking +25 or -5
///   depending on `preserve_thinking`; text +15 per keyword, +50 if sensitive.
///
/// `_index` and `_total` are unused.
fn calculate_preservation_score(message: &Message, _index: usize, _total: usize, bias: &CompressionBias) -> f64 {
    /// Number of (already lower-cased) keywords found in `text`, ignoring case.
    fn keyword_hits(text: &str, keywords_lower: &[String]) -> f64 {
        if keywords_lower.is_empty() {
            return 0.0;
        }
        let haystack = text.to_lowercase();
        keywords_lower
            .iter()
            .filter(|keyword| haystack.contains(keyword.as_str()))
            .count() as f64
    }

    // Lower-case the keywords once per message instead of once per comparison.
    let keywords_lower: Vec<String> = bias
        .preserve_keywords
        .iter()
        .map(|keyword| keyword.to_lowercase())
        .collect();

    let mut score: f64 = 10.0;

    score += match message.role {
        Role::System => 40.0,
        Role::Assistant => 5.0,
        Role::User if bias.preserve_user_questions => 30.0,
        Role::Tool if bias.preserve_tools => 25.0,
        Role::User | Role::Tool => 0.0,
    };

    match &message.content {
        MessageContent::Text(text) => {
            score += 15.0 * keyword_hits(text, &keywords_lower);
            if contains_sensitive_instructions(text) {
                score += 50.0;
            }
            if !bias.compact_long_outputs && text.len() > 2000 {
                score -= 10.0;
            }
        }
        MessageContent::Blocks(blocks) => {
            for block in blocks {
                match block {
                    ContentBlock::ToolUse { name, .. } => {
                        if bias.preserve_tools {
                            score += 20.0;
                        }
                        // Tools that change state are worth more than read-only ones.
                        if name == "write" || name == "edit" || name == "bash" {
                            score += 10.0;
                        }
                    }
                    ContentBlock::ToolResult { content, .. } => {
                        if bias.preserve_tools {
                            score += 20.0;
                        }
                        score += 10.0 * keyword_hits(content, &keywords_lower);
                        if contains_sensitive_instructions(content) {
                            score += 30.0;
                        }
                    }
                    ContentBlock::Thinking { .. } => {
                        score += if bias.preserve_thinking { 25.0 } else { -5.0 };
                    }
                    ContentBlock::Text { text } => {
                        score += 15.0 * keyword_hits(text, &keywords_lower);
                        if contains_sensitive_instructions(text) {
                            score += 50.0;
                        }
                    }
                    _ => {}
                }
            }
        }
    }
    score
}
/// Returns `true` when `text`, lower-cased, contains any of a fixed list of
/// Chinese/English phrases marking prohibitions, requirements, secrets,
/// decisions, preferences or "keep as is" instructions.
///
/// Matching is plain substring search, so a phrase also matches inside longer
/// words or across word boundaries.
fn contains_sensitive_instructions(text: &str) -> bool {
    const SENSITIVE_PATTERNS: &[&str] = &[
        "不要", "禁止", "不能", "千万别", "禁止使用",
        "never do", "must not", "should not", "cannot", "avoid",
        "必须", "一定要", "务必", "必须使用",
        "must", "required", "mandatory",
        "敏感", "隐私", "密码", "secret", "password", "credential",
        "private", "sensitive", "confidential",
        "决定", "决策", "critical", "important", "关键",
        "偏好", "我喜欢", "我习惯", "prefer", "preference",
        "严格按照", "遵循", "按原样", "strictly", "exactly",
        "不要修改", "不要改动", "keep original", "as is",
    ];
    let haystack = text.to_lowercase();
    SENSITIVE_PATTERNS
        .iter()
        .any(|needle| haystack.contains(*needle))
}
/// Replaces everything except the last `min_preserve_messages` messages with a
/// single summary message produced by `compressor`.
///
/// Returns the new message list (summary first, then the kept tail) together
/// with the segment that was generated. Inputs with at most
/// `min_preserve_messages` messages are returned unchanged with `None`.
///
/// # Errors
/// Propagates any error from `compressor.summarize`.
pub async fn compress_messages_with_ai(
    messages: &[Message],
    compressor: &dyn Compressor,
    config: &CompressionConfig,
) -> Result<(Vec<Message>, Option<SummarizedSegment>)> {
    let keep_tail = config.min_preserve_messages;
    if messages.len() <= keep_tail {
        return Ok((messages.to_vec(), None));
    }

    let (older, recent) = messages.split_at(messages.len() - keep_tail);
    let segment = compressor.summarize(older, config).await?;

    let mut compressed = Vec::with_capacity(recent.len() + 1);
    compressed.push(segment.to_message());
    compressed.extend_from_slice(recent);
    Ok((compressed, Some(segment)))
}
/// Keeps only the last `min_preserve_messages` messages; shorter inputs are
/// returned whole.
fn truncate_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
    // Saturates to 0 when the input is not longer than the keep count, which
    // makes the slice below the entire input.
    let tail_start = messages.len().saturating_sub(config.min_preserve_messages);
    let tail = &messages[tail_start..];
    Ok(tail.to_vec())
}
fn sliding_window_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
if messages.len() <= config.min_preserve_messages {
return Ok(messages.to_vec());
}
let total_tokens = estimate_total_tokens(messages);
let target_tokens = (total_tokens as f64 * config.target_ratio) as u32;
let mut turn_boundaries: Vec<usize> = Vec::new();
for (i, msg) in messages.iter().enumerate() {
if msg.role == Role::User {
turn_boundaries.push(i);
}
}
let min_start_idx = messages.len().saturating_sub(config.min_preserve_messages);
for &start_idx in turn_boundaries.iter() {
if messages.len() - start_idx < config.min_preserve_messages {
continue;
}
let candidate_messages = &messages[start_idx..];
let candidate_tokens = estimate_total_tokens(candidate_messages);
if candidate_tokens <= target_tokens {
return Ok(candidate_messages.to_vec());
}
}
Ok(messages[min_start_idx..].to_vec())
}
/// Rough token estimate for one message: total content length divided by
/// three, never less than 1.
///
/// Length is taken with `len()` (bytes for `String` content). For block
/// content it sums text, tool-use name plus serialized input, tool-result
/// content and thinking text; other block kinds count as zero.
pub fn estimate_tokens(message: &Message) -> u32 {
    let content_len: usize = match &message.content {
        MessageContent::Text(text) => text.len(),
        MessageContent::Blocks(blocks) => blocks
            .iter()
            .map(|block| match block {
                ContentBlock::Text { text } => text.len(),
                ContentBlock::ToolUse { name, input, .. } => name.len() + input.to_string().len(),
                ContentBlock::ToolResult { content, .. } => content.len(),
                ContentBlock::Thinking { thinking, .. } => thinking.len(),
                _ => 0,
            })
            .sum(),
    };
    (content_len / 3).max(1) as u32
}
pub fn estimate_total_tokens(messages: &[Message]) -> u32 {
messages.iter().map(estimate_tokens).sum()
}
/// Returns `true` when `current_tokens / context_size` has reached
/// `config.threshold`. With an unknown context size (`None`) it never
/// recommends compression.
pub fn should_compress(
    current_tokens: u32,
    context_size: Option<u32>,
    config: &CompressionConfig,
) -> bool {
    context_size.map_or(false, |size| {
        let usage = current_tokens as f64 / size as f64;
        usage >= config.threshold
    })
}
/// Builds the user prompt for the summarizer.
///
/// Each message becomes one `<role label>: <preview>` line. Plain text is cut
/// to 200 bytes; block content shows text and tool results cut to 100 bytes,
/// tool uses as a bracketed tool name, and other blocks as `[...]`, joined by
/// `" | "`. The lines are embedded in a fixed (Chinese) instruction template
/// together with the message count.
pub fn build_summary_prompt(messages: &[Message]) -> String {
    let mut lines: Vec<String> = Vec::with_capacity(messages.len());
    for message in messages {
        let speaker = match message.role {
            Role::User => "用户",
            Role::Assistant => "助手",
            Role::Tool => "工具",
            Role::System => "系统",
        };
        let preview = match &message.content {
            MessageContent::Text(text) => truncate_for_summary(text, 200),
            MessageContent::Blocks(blocks) => {
                let mut parts: Vec<String> = Vec::new();
                for block in blocks.iter() {
                    let part = match block {
                        ContentBlock::Text { text } => truncate_for_summary(text, 100),
                        ContentBlock::ToolUse { name, .. } => format!("[工具: {}]", name),
                        ContentBlock::ToolResult { content, .. } => truncate_for_summary(content, 100),
                        _ => "[...]".to_string(),
                    };
                    parts.push(part);
                }
                parts.join(" | ")
            }
        };
        lines.push(format!("{}: {}", speaker, preview));
    }
    let history_text = lines.join("\n");

    // The template lines are part of the raw string and must stay unindented.
    format!(
        r#"请将以下对话历史压缩为简洁摘要:
对话历史({} 条消息):
{}
请输出:
1. 概述(一句话描述主要任务)
2. 已完成的关键操作(2-3 条)
3. 当前状态(如果有)"#,
        messages.len(),
        history_text
    )
}
/// Shortens `s` for inclusion in the summary prompt; delegates directly to
/// `truncate_text`, so strings longer than `max` bytes are cut on a character
/// boundary and suffixed with `"..."`.
fn truncate_for_summary(s: &str, max: usize) -> String {
truncate_text(s, max)
}
// Unit tests for the token estimator, the compression strategies, summary
// parsing, and the formatting / parsing helpers.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// An 11-byte text yields 11 / 3 = 3 estimated tokens.
#[test]
fn test_estimate_tokens_simple() {
let msg = Message {
role: Role::User,
content: MessageContent::Text("Hello world".to_string()),
};
assert!(estimate_tokens(&msg) >= 3);
}
// 50% usage is below the default 0.75 threshold.
#[test]
fn test_should_compress_below_threshold() {
let config = CompressionConfig::default();
assert!(!should_compress(100_000, Some(200_000), &config));
}
// 80% usage is above the default 0.75 threshold.
#[test]
fn test_should_compress_above_threshold() {
let config = CompressionConfig::default();
assert!(should_compress(160_000, Some(200_000), &config));
}
// Truncation keeps exactly the last `min_preserve_messages` messages.
#[test]
fn test_truncate_compress_keeps_minimum() {
let messages: Vec<Message> = (0..10)
.map(|i| Message {
role: Role::User,
content: MessageContent::Text(format!("Message {}", i)),
})
.collect();
let config = CompressionConfig {
min_preserve_messages: 4,
..Default::default()
};
let compressed = truncate_compress(&messages, &config).unwrap();
assert_eq!(compressed.len(), 4);
assert_eq!(compressed[0].content, MessageContent::Text("Message 6".to_string()));
}
// The sliding window never returns fewer than the minimum and keeps a user turn.
#[test]
fn test_sliding_window_preserves_turns() {
let messages: Vec<Message> = vec![
Message { role: Role::User, content: MessageContent::Text("Q1 - this is a longer question to test token estimation".to_string()) },
Message { role: Role::Assistant, content: MessageContent::Text("A1 - this is a longer answer with more content for token estimation".to_string()) },
Message { role: Role::User, content: MessageContent::Text("Q2 - another longer question for testing".to_string()) },
Message { role: Role::Assistant, content: MessageContent::Text("A2 - another longer answer for testing token estimation properly".to_string()) },
Message { role: Role::User, content: MessageContent::Text("Q3 - the third question in this test".to_string()) },
Message { role: Role::Assistant, content: MessageContent::Text("A3 - the third answer with sufficient content".to_string()) },
];
let config = CompressionConfig {
min_preserve_messages: 4,
target_ratio: 0.5,
..Default::default()
};
let compressed = sliding_window_compress(&messages, &config).unwrap();
assert!(compressed.len() >= config.min_preserve_messages);
assert!(compressed.iter().any(|m| m.role == Role::User));
}
// A reply with one prose line and two bullets yields a summary and two key points.
#[test]
fn test_parse_summary_response() {
let text = "用户请求实现登录功能。\n已完成操作:\n• 创建了 login.rs 文件\n• 添加了密码验证逻辑\n当前状态:测试中";
let (summary, key_points) = parse_summary_response(text);
assert!(!summary.is_empty());
assert!(key_points.len() >= 2);
}
// The formatted summary mentions both counts and the strategy name.
#[test]
fn test_compression_result_format() {
let result = CompressionResult::new(
20,
8,
5000,
Some("摘要内容".to_string()),
CompressionStrategy::Summarize,
);
let formatted = result.format_summary();
assert!(formatted.contains("20"));
assert!(formatted.contains("8"));
assert!(formatted.contains("AI summarize"));
}
// A history entry copies the strategy and records the absence of a summary.
#[test]
fn test_compression_history_entry() {
let result = CompressionResult::new(
15,
6,
3000,
None,
CompressionStrategy::SlidingWindow,
);
let entry = CompressionHistoryEntry::from_result(&result);
assert_eq!(entry.strategy, CompressionStrategy::SlidingWindow);
assert!(!entry.has_summary);
}
// Preset names and their aliases resolve to the matching presets.
#[test]
fn test_compression_bias_parse() {
let balanced = CompressionBias::parse("balanced").unwrap();
assert!(balanced.preserve_tools);
assert!(balanced.preserve_user_questions);
let aggressive = CompressionBias::parse("aggressive").unwrap();
assert!(!aggressive.preserve_tools);
assert!(aggressive.aggressive);
let important = CompressionBias::parse("important").unwrap();
assert!(important.preserve_thinking);
assert!(important.preserve_tools);
let tools = CompressionBias::parse("tools").unwrap();
assert!(tools.preserve_tools);
assert!(!tools.preserve_thinking);
}
// The balanced preset lists its two active flags when formatted.
#[test]
fn test_compression_bias_format() {
let bias = CompressionBias::balanced();
let formatted = bias.format();
assert!(formatted.contains("tools"));
assert!(formatted.contains("user"));
}
// With the tool-focused bias, the tool-use message should survive compression
// (or almost nothing is dropped).
#[test]
fn test_compress_with_bias_preserves_tools() {
let messages: Vec<Message> = vec![
Message { role: Role::User, content: MessageContent::Text("Q1".to_string()) },
Message {
role: Role::Assistant,
content: MessageContent::Blocks(vec![
ContentBlock::ToolUse { id: "1".to_string(), name: "read".to_string(), input: json!({}) }
])
},
Message { role: Role::Tool, content: MessageContent::Blocks(vec![
ContentBlock::ToolResult { tool_use_id: "1".to_string(), content: "file content".to_string() }
])},
Message { role: Role::User, content: MessageContent::Text("Q2".to_string()) },
Message { role: Role::Assistant, content: MessageContent::Text("A2".to_string()) },
Message { role: Role::User, content: MessageContent::Text("Q3".to_string()) },
Message { role: Role::Assistant, content: MessageContent::Text("A3".to_string()) },
];
let config = CompressionConfig {
min_preserve_messages: 2,
bias: CompressionBias::tool_focused(),
..Default::default()
};
let compressed = compress_with_bias(&messages, &config).unwrap();
let has_tool_use = compressed.iter().any(|m| {
matches!(&m.content, MessageContent::Blocks(blocks) if
blocks.iter().any(|b| matches!(b, ContentBlock::ToolUse { .. })))
});
assert!(has_tool_use || compressed.len() >= messages.len() - 2);
}
// The bias-based strategy, reached through the dispatcher, never grows the
// input and never drops below the minimum.
#[test]
fn test_bias_based_strategy() {
let messages: Vec<Message> = (0..10)
.map(|i| Message {
role: if i % 2 == 0 { Role::User } else { Role::Assistant },
content: MessageContent::Text(format!("Message {}", i)),
})
.collect();
let config = CompressionConfig {
min_preserve_messages: 4,
bias: CompressionBias::aggressive(),
..Default::default()
};
let compressed = compress_messages(&messages, CompressionStrategy::BiasBased, &config).unwrap();
assert!(compressed.len() <= messages.len());
assert!(compressed.len() >= config.min_preserve_messages);
}
}