use anyhow::Result;
use crate::providers::{ContentBlock, Message, MessageContent};
use super::summarizer::Summarizer;
use super::types::{AiCompressionMode, CompressionThresholds};
/// Shrinks oversized tool-result content inside conversation messages.
///
/// Content is either truncated or, when a summarizer is configured and an AI
/// mode is requested, replaced with an AI-generated summary.
pub struct ToolCompressor {
/// Optional AI summarizer; `None` when built with `new_truncate_only`,
/// in which case content is only ever truncated.
summarizer: Option<Summarizer>,
/// Size thresholds; this file reads `small_content` and `medium_content`
/// and compares them against estimated token counts.
thresholds: CompressionThresholds,
}
impl ToolCompressor {
    /// Creates a compressor without a summarizer.
    ///
    /// Every compression performed by this instance falls back to truncation,
    /// regardless of the AI mode passed to [`ToolCompressor::compress_results`].
    pub fn new_truncate_only(thresholds: CompressionThresholds) -> Self {
        Self {
            summarizer: None,
            thresholds,
        }
    }

    /// Creates a compressor that can summarize content with `summarizer`.
    pub fn new_with_ai(summarizer: Summarizer, thresholds: CompressionThresholds) -> Self {
        Self {
            summarizer: Some(summarizer),
            thresholds,
        }
    }

    /// Returns a copy of `messages` in which large tool results are compressed.
    ///
    /// Only `ContentBlock::ToolResult` blocks inside `MessageContent::Blocks`
    /// are considered. A tool result whose estimated token count is below
    /// `thresholds.small_content` is left untouched; everything else is
    /// replaced by the output of the truncation/summarization step.
    /// The input slice itself is never modified.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by the summarizer; in that case no
    /// partially compressed result is returned.
    pub async fn compress_results(
        &self,
        messages: &[Message],
        ai_mode: AiCompressionMode,
    ) -> Result<Vec<Message>> {
        // Work on a clone so callers keep the uncompressed history.
        let mut result = messages.to_vec();
        for msg in &mut result {
            if let MessageContent::Blocks(blocks) = &mut msg.content {
                for block in blocks.iter_mut() {
                    if let ContentBlock::ToolResult { content, .. } = block {
                        let tokens = estimate_tokens_str(content);
                        if tokens < self.thresholds.small_content {
                            // Small enough already; keep verbatim.
                            continue;
                        }
                        let compressed = self.compress_content(content, tokens, ai_mode).await?;
                        *content = compressed;
                    }
                }
            }
        }
        Ok(result)
    }

    /// Compresses a single piece of content whose estimated size is `tokens`.
    ///
    /// * No summarizer configured, or `ai_mode` is `None`: truncate.
    /// * `tokens` below `thresholds.medium_content`: light summary, whatever
    ///   AI mode was requested.
    /// * Otherwise: light or deep summary according to `ai_mode`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the summarizer.
    async fn compress_content(
        &self,
        content: &str,
        tokens: u32,
        ai_mode: AiCompressionMode,
    ) -> Result<String> {
        // Bind the summarizer by pattern instead of `is_none()` + `unwrap()`,
        // so there is no panic path if this logic is edited later.
        let summarizer = match self.summarizer.as_ref() {
            Some(summarizer) if ai_mode != AiCompressionMode::None => summarizer,
            _ => return Ok(self.truncate_content(content)),
        };

        // Medium-sized content never gets more than a light summary.
        let effective_mode = if tokens < self.thresholds.medium_content {
            AiCompressionMode::Light
        } else {
            ai_mode
        };

        match effective_mode {
            AiCompressionMode::Light => {
                let summary = summarizer.summarize_light(content).await?;
                Ok(format!("[摘要] {}", summary))
            }
            AiCompressionMode::Deep => {
                let summary = summarizer.summarize_deep(content).await?;
                Ok(format!("[详细摘要] {}", summary))
            }
            // Already handled by the early return above; kept so the match
            // stays exhaustive without a wildcard arm.
            AiCompressionMode::None => Ok(self.truncate_content(content)),
        }
    }

    /// Truncates `content` via `Summarizer::truncate_preserve_ends`, passing
    /// `thresholds.small_content` as the limit argument.
    // NOTE(review): the name suggests the head and tail are kept and the limit
    // is a token budget — confirm against the summarizer module.
    fn truncate_content(&self, content: &str) -> String {
        Summarizer::truncate_preserve_ends(content, self.thresholds.small_content)
    }

    /// Reports whether `content` is large enough to be compressed, i.e. its
    /// estimated token count is at least `thresholds.small_content`.
    ///
    /// This is the same cut-off that `compress_results` applies.
    pub fn needs_compression(content: &str, thresholds: &CompressionThresholds) -> bool {
        estimate_tokens_str(content) >= thresholds.small_content
    }
}
/// Estimates how many tokens `s` occupies.
///
/// ASCII characters are weighted at 0.25 tokens each and all other
/// characters at 0.67 tokens each; each class is rounded up on its own
/// before the two parts are added.
fn estimate_tokens_str(s: &str) -> u32 {
    // Weighted, rounded-up token count for one character class.
    let weighted = |char_count: u32, tokens_per_char: f64| -> u32 {
        (char_count as f64 * tokens_per_char).ceil() as u32
    };

    let (ascii_count, other_count) = count_chars(s);
    weighted(ascii_count, 0.25) + weighted(other_count, 0.67)
}
/// Counts the characters of `s`, split into `(ascii, non_ascii)`.
///
/// Counting is per `char` (Unicode scalar value), not per byte.
fn count_chars(s: &str) -> (u32, u32) {
    s.chars().fold((0u32, 0u32), |(ascii_seen, other_seen), c| {
        if c.is_ascii() {
            (ascii_seen + 1, other_seen)
        } else {
            (ascii_seen, other_seen + 1)
        }
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Short content stays below the default small-content threshold, while a
    /// long repeated string reaches it.
    #[test]
    fn test_needs_compression() {
        let thresholds = CompressionThresholds::default();
        let short = "短内容";
        assert!(!ToolCompressor::needs_compression(short, &thresholds));
        let long = "很长的内容...".repeat(200);
        assert!(ToolCompressor::needs_compression(&long, &thresholds));
    }

    /// Truncation must shorten the content and leave a truncation marker.
    #[test]
    fn test_truncate_content() {
        let thresholds = CompressionThresholds::default();
        let compressor = ToolCompressor::new_truncate_only(thresholds);
        let content = "开头内容中间很长的部分结尾内容".repeat(50);
        let result = compressor.truncate_content(&content);
        assert!(result.len() < content.len());
        assert!(result.contains("[内容截断]"));
    }

    /// Pins the estimator to exact values so that a change to the weights or
    /// the rounding is caught, rather than only checking loose bounds.
    #[test]
    fn test_estimate_tokens_str() {
        // Nothing to count.
        assert_eq!(estimate_tokens_str(""), 0);
        // 11 ASCII chars * 0.25 = 2.75, rounded up to 3.
        assert_eq!(estimate_tokens_str("hello world"), 3);
        // 4 non-ASCII chars * 0.67 = 2.68, rounded up to 3.
        assert_eq!(estimate_tokens_str("你好世界"), 3);
        // Each class is rounded up separately: ceil(0.5) + ceil(0.67) = 2.
        assert_eq!(estimate_tokens_str("hi你"), 2);
    }

    /// Characters are classified per `char`, not per byte.
    #[test]
    fn test_count_chars() {
        assert_eq!(count_chars(""), (0, 0));
        assert_eq!(count_chars("hello world"), (11, 0));
        assert_eq!(count_chars("你好世界"), (0, 4));
        assert_eq!(count_chars("hi你"), (2, 1));
    }
}