Skip to main content

matrixcode_core/compress/
tool_compressor.rs

1//! Tool result compression for large content.
2//!
3//! Compresses large tool results (e.g., file reads) using
4//! summarization or truncation to reduce token usage.
5
6use anyhow::Result;
7
8use crate::providers::{ContentBlock, Message, MessageContent};
9use super::summarizer::Summarizer;
10use super::types::{AiCompressionMode, CompressionThresholds};
11
12/// Compressor for tool results.
13pub struct ToolCompressor {
14    /// Summarizer for AI-based compression.
15    summarizer: Option<Summarizer>,
16    /// Thresholds for content size.
17    thresholds: CompressionThresholds,
18}
19
20impl ToolCompressor {
21    /// Create a new tool compressor without AI summarization.
22    pub fn new_truncate_only(thresholds: CompressionThresholds) -> Self {
23        Self {
24            summarizer: None,
25            thresholds,
26        }
27    }
28
29    /// Create a new tool compressor with AI summarization.
30    pub fn new_with_ai(summarizer: Summarizer, thresholds: CompressionThresholds) -> Self {
31        Self {
32            summarizer: Some(summarizer),
33            thresholds,
34        }
35    }
36
37    /// Compress large tool results in all messages.
38    pub async fn compress_results(
39        &self,
40        messages: &[Message],
41        ai_mode: AiCompressionMode,
42    ) -> Result<Vec<Message>> {
43        let mut result = messages.to_vec();
44
45        for msg in &mut result {
46            if let MessageContent::Blocks(blocks) = &mut msg.content {
47                for block in blocks.iter_mut() {
48                    if let ContentBlock::ToolResult { content, .. } = block {
49                        let tokens = estimate_tokens_str(content);
50
51                        if tokens < self.thresholds.small_content {
52                            // Keep unchanged
53                            continue;
54                        }
55
56                        // Compress based on mode
57                        let compressed = self.compress_content(content, tokens, ai_mode).await?;
58                        *content = compressed;
59                    }
60                }
61            }
62        }
63
64        Ok(result)
65    }
66
67    /// Compress a single content string.
68    async fn compress_content(
69        &self,
70        content: &str,
71        tokens: u32,
72        ai_mode: AiCompressionMode,
73    ) -> Result<String> {
74        // No AI mode: truncate
75        if ai_mode == AiCompressionMode::None || self.summarizer.is_none() {
76            return Ok(self.truncate_content(content));
77        }
78
79        let summarizer = self.summarizer.as_ref().unwrap();
80
81        // Medium content: light summary
82        if tokens < self.thresholds.medium_content {
83            let summary = summarizer.summarize_light(content).await?;
84            return Ok(format!("[摘要] {}", summary));
85        }
86
87        // Large content: choose based on ai_mode
88        match ai_mode {
89            AiCompressionMode::Light => {
90                let summary = summarizer.summarize_light(content).await?;
91                Ok(format!("[摘要] {}", summary))
92            }
93            AiCompressionMode::Deep => {
94                let summary = summarizer.summarize_deep(content).await?;
95                Ok(format!("[详细摘要] {}", summary))
96            }
97            AiCompressionMode::None => Ok(self.truncate_content(content)),
98        }
99    }
100
101    /// Truncate content without AI.
102    fn truncate_content(&self, content: &str) -> String {
103        // Preserve ends for better context
104        Summarizer::truncate_preserve_ends(content, self.thresholds.small_content)
105    }
106
107    /// Check if a tool result needs compression.
108    pub fn needs_compression(content: &str, thresholds: &CompressionThresholds) -> bool {
109        estimate_tokens_str(content) >= thresholds.small_content
110    }
111}
112
113/// Estimate tokens from string (simplified).
114fn estimate_tokens_str(s: &str) -> u32 {
115    let (ascii, non_ascii) = count_chars(s);
116    let ascii_tokens = (ascii as f64 * 0.25).ceil() as u32;
117    let non_ascii_tokens = (non_ascii as f64 * 0.67).ceil() as u32;
118    ascii_tokens + non_ascii_tokens
119}
120
/// Count the characters of `s`, split into `(ascii, non_ascii)`.
///
/// Counting is per `char` (Unicode scalar value), not per byte.
fn count_chars(s: &str) -> (u32, u32) {
    s.chars().fold((0u32, 0u32), |(ascii, other), ch| {
        if ch.is_ascii() {
            (ascii + 1, other)
        } else {
            (ascii, other + 1)
        }
    })
}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_needs_compression() {
        let thresholds = CompressionThresholds::default();

        let short = "短内容";
        assert!(!ToolCompressor::needs_compression(short, &thresholds));

        // Each repetition is 5 CJK characters plus 3 ASCII dots. 200 repetitions estimate
        // to ceil(1000 * 0.67) + ceil(600 * 0.25) = 670 + 150 = 820 tokens.
        // Assumes the default `small_content` threshold is below that (~500 tokens).
        let long = "很长的内容...".repeat(200);
        assert!(ToolCompressor::needs_compression(&long, &thresholds));
    }

    #[test]
    fn test_needs_compression_boundary() {
        let thresholds = CompressionThresholds::default();
        let limit = thresholds.small_content as usize;

        // Four ASCII characters estimate to exactly one token, so this string sits
        // exactly on the threshold.
        let at_limit = "a".repeat(limit * 4);
        assert!(ToolCompressor::needs_compression(&at_limit, &thresholds));

        if limit > 0 {
            let below_limit = "a".repeat((limit - 1) * 4);
            assert!(!ToolCompressor::needs_compression(&below_limit, &thresholds));
        }
    }

    #[test]
    fn test_truncate_content() {
        let thresholds = CompressionThresholds::default();
        let compressor = ToolCompressor::new_truncate_only(thresholds);

        // Content must be longer than the threshold for truncation to happen.
        let content = "开头内容中间很长的部分结尾内容".repeat(50);
        let result = compressor.truncate_content(&content);

        assert!(result.len() < content.len());
        assert!(result.contains("[内容截断]"));
    }

    #[test]
    fn test_estimate_tokens_str() {
        assert_eq!(estimate_tokens_str(""), 0);
        // 11 ASCII chars * 0.25 = 2.75, rounded up to 3.
        assert_eq!(estimate_tokens_str("hello world"), 3);
        // 4 non-ASCII chars * 0.67 = 2.68, rounded up to 3.
        assert_eq!(estimate_tokens_str("你好世界"), 3);
        // Groups round up separately: ceil(4 * 0.25) + ceil(2 * 0.67) = 1 + 2.
        assert_eq!(estimate_tokens_str("abcd你好"), 3);
    }

    #[test]
    fn test_count_chars() {
        assert_eq!(count_chars(""), (0, 0));
        assert_eq!(count_chars("abc"), (3, 0));
        assert_eq!(count_chars("a你b好"), (2, 2));
    }
}