Skip to main content

matrixcode_core/compress/
compressor.rs

1//! Compression functions and AI compressor implementation.
2
3use crate::providers::{
4    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
5};
6use crate::truncate::truncate_with_suffix;
7use anyhow::Result;
8use async_trait::async_trait;
9use std::collections::HashSet;
10
11use super::config::{CompressionBias, CompressionConfig};
12use super::types::{CompressionStrategy, SummarizedSegment};
13
14// ============================================================================
15// Compressor Trait
16// ============================================================================
17
18/// Compressor trait for different implementations.
19#[async_trait]
20pub trait Compressor: Send + Sync {
21    /// Compress messages using AI summarization.
22    async fn summarize(
23        &self,
24        messages: &[Message],
25        config: &CompressionConfig,
26    ) -> Result<SummarizedSegment>;
27
28    /// Get the model name used.
29    fn model_name(&self) -> &str;
30}
31
32/// AI-based compressor using a Provider.
33pub struct AiCompressor {
34    provider: Box<dyn Provider>,
35    model: String,
36}
37
38impl AiCompressor {
39    pub fn new(provider: Box<dyn Provider>, model: String) -> Self {
40        Self { provider, model }
41    }
42}
43
/// System prompt (in Chinese) for the summarization request made by `AiCompressor`.
///
/// It asks the model to:
/// - keep the summary within 200 characters;
/// - keep only important operations and decisions;
/// - preserve the user's sensitive instructions, unfinished to-dos, and key design
///   choices with their rationale.
///
/// It also asks the model to answer in four labelled sections:
/// - 【摘要】: a one-line summary;
/// - 【已完成】: completed work;
/// - 【未完成】: pending tasks;
/// - 【关键决策】: key decisions.
///
/// NOTE(review): `parse_summary_response` only treats lines starting with `•`, `-` or `*`
/// as key points, and takes the first other non-empty line as the summary. With this
/// format the summary will likely be the 【摘要】 line including its label. Confirm that
/// this is intended.
const SUMMARY_SYSTEM_PROMPT: &str = r#"你是一个对话历史压缩助手。将对话压缩为简洁摘要。

输出要求:
- 简洁:摘要控制在 200 字以内
- 关键:只保留重要操作和决策
- 敏感:必须保留用户的敏感指令
- 任务:必须保留未完成的待办事项
- 决策:必须保留关键方案选择和理由

输出格式:
【摘要】一句话概括主要工作
【已完成】列出已完成的操作
【未完成】列出待办任务(如有)
【关键决策】重要选择及理由(如有)

请直接输出内容。"#;
60
61#[async_trait]
62impl Compressor for AiCompressor {
63    async fn summarize(
64        &self,
65        messages: &[Message],
66        _config: &CompressionConfig,
67    ) -> Result<SummarizedSegment> {
68        let prompt = build_summary_prompt(messages);
69
70        let request = ChatRequest {
71            messages: vec![Message {
72                role: Role::User,
73                content: MessageContent::Text(prompt),
74            }],
75            tools: vec![],
76            system: Some(SUMMARY_SYSTEM_PROMPT.to_string()),
77            think: false,
78            max_tokens: 1024,
79            server_tools: vec![],
80            enable_caching: false,
81        };
82
83        let response = self.provider.chat(request).await?;
84        let summary_text = extract_text_from_response(&response);
85        let (summary, key_points) = parse_summary_response(&summary_text);
86
87        Ok(SummarizedSegment {
88            time_range: (chrono::Utc::now(), chrono::Utc::now()),
89            original_count: messages.len(),
90            summary,
91            key_points,
92        })
93    }
94
95    fn model_name(&self) -> &str {
96        &self.model
97    }
98}
99
100fn extract_text_from_response(response: &ChatResponse) -> String {
101    response
102        .content
103        .iter()
104        .filter_map(|block| {
105            if let ContentBlock::Text { text } = block {
106                Some(text.clone())
107            } else {
108                None
109            }
110        })
111        .collect::<Vec<_>>()
112        .join("\n")
113}
114
115fn parse_summary_response(text: &str) -> (String, Vec<String>) {
116    let mut summary = String::new();
117    let mut key_points: Vec<String> = Vec::new();
118
119    for line in text.lines() {
120        let line = line.trim();
121        if line.starts_with("•") || line.starts_with("-") || line.starts_with("*") {
122            let point = line.trim_start_matches(['•', '-', '*']).trim();
123            if !point.is_empty() {
124                key_points.push(point.to_string());
125            }
126        } else if !line.is_empty() && summary.is_empty() {
127            summary = line.to_string();
128        }
129    }
130
131    if summary.is_empty() && !text.is_empty() {
132        summary = text.lines().take(3).collect::<Vec<_>>().join(" ");
133        if summary.len() > 200 {
134            summary = truncate_with_suffix(&summary, 200);
135        }
136    }
137
138    (summary, key_points)
139}
140
141// ============================================================================
142// Compression Functions
143// ============================================================================
144
145/// Compress messages synchronously.
146pub fn compress_messages(
147    messages: &[Message],
148    strategy: CompressionStrategy,
149    config: &CompressionConfig,
150) -> Result<Vec<Message>> {
151    match strategy {
152        CompressionStrategy::Truncate => truncate_compress(messages, config),
153        CompressionStrategy::SlidingWindow => sliding_window_compress(messages, config),
154        CompressionStrategy::Summarize => sliding_window_compress(messages, config),
155        CompressionStrategy::BiasBased => compress_with_bias(messages, config),
156    }
157}
158
159/// Compress with bias-based scoring.
160pub fn compress_with_bias(
161    messages: &[Message],
162    config: &CompressionConfig,
163) -> Result<Vec<Message>> {
164    if messages.len() <= config.min_preserve_messages {
165        return Ok(messages.to_vec());
166    }
167
168    let scored: Vec<(usize, Message, f64)> = messages
169        .iter()
170        .enumerate()
171        .map(|(idx, msg)| {
172            (
173                idx,
174                msg.clone(),
175                calculate_preservation_score(msg, idx, messages.len(), &config.bias),
176            )
177        })
178        .collect();
179
180    let mut scored_with_recency: Vec<(usize, Message, f64)> = scored
181        .into_iter()
182        .map(|(idx, msg, score)| {
183            let recency_bonus = if idx >= messages.len() - config.min_preserve_messages {
184                100.0
185            } else {
186                (idx as f64 / messages.len() as f64) * 20.0
187            };
188            (idx, msg, score + recency_bonus)
189        })
190        .collect();
191
192    scored_with_recency.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
193
194    let target_count = if config.bias.aggressive {
195        config.min_preserve_messages
196    } else {
197        let estimated = estimate_total_tokens(messages);
198        let target_tokens = (estimated as f64 * config.target_ratio) as u32;
199        let avg = estimated / messages.len() as u32;
200        (target_tokens / avg.max(1)) as usize
201    };
202
203    let to_keep: HashSet<usize> = scored_with_recency
204        .iter()
205        .take(target_count)
206        .map(|(idx, _, _)| *idx)
207        .collect();
208
209    let compressed: Vec<Message> = messages
210        .iter()
211        .enumerate()
212        .filter(|(idx, _)| to_keep.contains(idx))
213        .map(|(_, msg)| msg.clone())
214        .collect();
215
216    Ok(compressed)
217}
218
219fn calculate_preservation_score(
220    message: &Message,
221    index: usize,
222    _total: usize,  // Reserved for future use (total message count)
223    bias: &CompressionBias,
224) -> f64 {
225    let mut score: f64 = 10.0;
226
227    // First message (user's original request) gets highest priority
228    if index == 0 {
229        score += 100.0;
230    }
231
232    match message.role {
233        Role::User => {
234            if bias.preserve_user_questions {
235                score += 30.0;
236            }
237        }
238        Role::Assistant => {
239            score += 5.0;
240        }
241        Role::Tool => {
242            if bias.preserve_tools {
243                score += 25.0;
244            }
245        }
246        Role::System => {
247            score += 40.0;
248        }
249    }
250
251    match &message.content {
252        MessageContent::Text(text) => {
253            for keyword in &bias.preserve_keywords {
254                if text.to_lowercase().contains(&keyword.to_lowercase()) {
255                    score += 15.0;
256                }
257            }
258            if contains_sensitive_instructions(text) {
259                score += 50.0;
260            }
261        }
262        MessageContent::Blocks(blocks) => {
263            for block in blocks {
264                match block {
265                    ContentBlock::ToolUse { name, .. } => {
266                        if bias.preserve_tools {
267                            score += 20.0;
268                        }
269                        if name == "write" || name == "edit" || name == "bash" {
270                            score += 10.0;
271                        }
272                        // todo_write gets high priority - preserve task tracking
273                        if name == "todo_write" {
274                            score += 60.0;
275                        }
276                        // ask tool contains key decisions
277                        if name == "ask" {
278                            score += 50.0;
279                        }
280                    }
281                    ContentBlock::ToolResult { content, .. } => {
282                        if bias.preserve_tools {
283                            score += 20.0;
284                        }
285                        if contains_sensitive_instructions(content) {
286                            score += 30.0;
287                        }
288                        // Preserve todo_write results (task status)
289                        if content.contains("TodoWrite") || content.contains("todo") {
290                            score += 40.0;
291                        }
292                        // Preserve ask responses (user decisions)
293                        if content.contains("AskUserQuestion") || content.contains("answer") {
294                            score += 30.0;
295                        }
296                    }
297                    ContentBlock::Thinking { .. } => {
298                        if bias.preserve_thinking {
299                            score += 25.0;
300                        } else {
301                            score -= 5.0;
302                        }
303                    }
304                    ContentBlock::Text { text } => {
305                        if contains_sensitive_instructions(text) {
306                            score += 50.0;
307                        }
308                    }
309                    _ => {}
310                }
311            }
312        }
313    }
314
315    score
316}
317
/// Reports whether `text` contains a prohibition or hard requirement.
///
/// The text is lowercased and searched for any of a fixed set of Chinese and English
/// marker phrases, matched as plain substrings.
fn contains_sensitive_instructions(text: &str) -> bool {
    const MARKERS: [&str; 7] = [
        "不要",
        "禁止",
        "必须",
        "不允许",
        "never",
        "must not",
        "do not",
    ];
    let haystack = text.to_lowercase();
    MARKERS.iter().any(|&marker| haystack.contains(marker))
}
331
332fn truncate_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
333    if messages.len() <= config.min_preserve_messages {
334        return Ok(messages.to_vec());
335    }
336    Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
337}
338
339fn sliding_window_compress(
340    messages: &[Message],
341    config: &CompressionConfig,
342) -> Result<Vec<Message>> {
343    if messages.len() <= config.min_preserve_messages {
344        return Ok(messages.to_vec());
345    }
346
347    let target_tokens = (estimate_total_tokens(messages) as f64 * config.target_ratio) as u32;
348
349    for start_idx in config.min_preserve_messages..messages.len() {
350        let candidate = &messages[start_idx..];
351        if estimate_total_tokens(candidate) <= target_tokens {
352            return Ok(candidate.to_vec());
353        }
354    }
355
356    Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
357}
358
359// ============================================================================
360// Token Estimation
361// ============================================================================
362
363/// Estimate token count for a message.
364pub fn estimate_tokens(message: &Message) -> u32 {
365    let (ascii, non_ascii) = match &message.content {
366        MessageContent::Text(t) => count_chars(t),
367        MessageContent::Blocks(blocks) => {
368            let mut a = 0u32;
369            let mut n = 0u32;
370            for block in blocks {
371                match block {
372                    ContentBlock::Text { text } => {
373                        let (ca, cn) = count_chars(text);
374                        a += ca;
375                        n += cn;
376                    }
377                    ContentBlock::ToolUse { name, input, .. } => {
378                        let (ca, cn) = count_chars(name);
379                        a += ca;
380                        n += cn;
381                        let (ja, jn) = count_chars(&input.to_string());
382                        a += ja;
383                        n += jn;
384                    }
385                    ContentBlock::ToolResult { content, .. } => {
386                        let (ca, cn) = count_chars(content);
387                        a += ca;
388                        n += cn;
389                    }
390                    ContentBlock::Thinking { thinking, .. } => {
391                        let (ca, cn) = count_chars(thinking);
392                        a += ca;
393                        n += cn;
394                    }
395                    _ => {}
396                }
397            }
398            (a, n)
399        }
400    };
401
402    let ascii_tokens = (ascii as f64 * 0.25).ceil() as u32;
403    let non_ascii_tokens = (non_ascii as f64 * 0.67).ceil() as u32;
404    (ascii_tokens + non_ascii_tokens + 10).max(1)
405}
406
/// Counts the characters of `s`, split into `(ascii, non_ascii)`.
fn count_chars(s: &str) -> (u32, u32) {
    s.chars().fold((0u32, 0u32), |(ascii, other), ch| {
        if ch.is_ascii() {
            (ascii + 1, other)
        } else {
            (ascii, other + 1)
        }
    })
}
419
420/// Estimate total tokens for a message list.
421pub fn estimate_total_tokens(messages: &[Message]) -> u32 {
422    messages.iter().map(estimate_tokens).sum()
423}
424
425/// Check if compression should be triggered.
426pub fn should_compress(
427    current_tokens: u32,
428    context_size: Option<u32>,
429    config: &CompressionConfig,
430) -> bool {
431    match context_size {
432        Some(size) => (current_tokens as f64 / size as f64) >= config.threshold,
433        None => false,
434    }
435}
436
437/// Build a prompt for summarization.
438pub fn build_summary_prompt(messages: &[Message]) -> String {
439    let history = messages
440        .iter()
441        .map(|m| {
442            let role = match m.role {
443                Role::User => "用户",
444                Role::Assistant => "助手",
445                Role::Tool => "工具",
446                Role::System => "系统",
447            };
448            let preview = match &m.content {
449                MessageContent::Text(t) => truncate_with_suffix(t, 200),
450                MessageContent::Blocks(blocks) => blocks
451                    .iter()
452                    .map(|b| match b {
453                        ContentBlock::Text { text } => truncate_with_suffix(text, 100),
454                        ContentBlock::ToolUse { name, .. } => format!("[工具: {}]", name),
455                        ContentBlock::ToolResult { content, .. } => {
456                            truncate_with_suffix(content, 100)
457                        }
458                        _ => "[...]".to_string(),
459                    })
460                    .collect::<Vec<_>>()
461                    .join(" | "),
462            };
463            format!("{}: {}", role, preview)
464        })
465        .collect::<Vec<_>>()
466        .join("\n");
467
468    format!(
469        "请将以下对话压缩为简洁摘要({} 条消息):\n{}",
470        messages.len(),
471        history
472    )
473}
474
475// ============================================================================
476// Tests
477// ============================================================================
478
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain-text message with the given role.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    /// Extracts the plain-text content of each message (empty for block content).
    fn texts(messages: &[Message]) -> Vec<String> {
        messages
            .iter()
            .map(|m| match &m.content {
                MessageContent::Text(t) => t.clone(),
                _ => String::new(),
            })
            .collect()
    }

    #[test]
    fn test_estimate_tokens_simple() {
        let msg = text_message(Role::User, "Hello world");
        assert!(estimate_tokens(&msg) >= 3);
    }

    #[test]
    fn test_count_chars_splits_ascii_and_non_ascii() {
        assert_eq!(count_chars(""), (0, 0));
        assert_eq!(count_chars("abc"), (3, 0));
        assert_eq!(count_chars("你好ab"), (2, 2));
    }

    #[test]
    fn test_contains_sensitive_instructions() {
        assert!(contains_sensitive_instructions("Do NOT touch the config"));
        assert!(contains_sensitive_instructions("请不要修改这个文件"));
        assert!(!contains_sensitive_instructions("please refactor this"));
    }

    #[test]
    fn test_should_compress() {
        let config = CompressionConfig::default();
        assert!(!should_compress(100_000, Some(200_000), &config));
        assert!(should_compress(160_000, Some(200_000), &config));
    }

    #[test]
    fn test_should_compress_without_context_size() {
        let config = CompressionConfig::default();
        assert!(!should_compress(u32::MAX, None, &config));
    }

    #[test]
    fn test_truncate_keeps_most_recent_messages() {
        let config = CompressionConfig::default();
        let keep = config.min_preserve_messages;
        let messages: Vec<Message> = (0..keep + 3)
            .map(|i| text_message(Role::User, &format!("message {}", i)))
            .collect();

        let out = truncate_compress(&messages, &config).unwrap();
        let expected: Vec<String> = (3..keep + 3).map(|i| format!("message {}", i)).collect();
        assert_eq!(out.len(), keep);
        assert_eq!(texts(&out), expected);
    }

    #[test]
    fn test_empty_input_stays_empty_for_every_strategy() {
        let config = CompressionConfig::default();
        let empty: Vec<Message> = Vec::new();
        assert!(compress_messages(&empty, CompressionStrategy::Truncate, &config)
            .unwrap()
            .is_empty());
        assert!(compress_messages(&empty, CompressionStrategy::SlidingWindow, &config)
            .unwrap()
            .is_empty());
        assert!(compress_messages(&empty, CompressionStrategy::Summarize, &config)
            .unwrap()
            .is_empty());
        assert!(compress_messages(&empty, CompressionStrategy::BiasBased, &config)
            .unwrap()
            .is_empty());
    }
}