Skip to main content

matrixcode_core/compress/
compressor.rs

1//! Compression functions and AI compressor implementation.
2
3use crate::providers::{
4    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
5};
6use crate::truncate::truncate_with_suffix;
7use anyhow::Result;
8use async_trait::async_trait;
9use std::collections::HashSet;
10
11use super::config::{CompressionBias, CompressionConfig};
12use super::types::{CompressionStrategy, SummarizedSegment};
13
14// ============================================================================
15// Compressor Trait
16// ============================================================================
17
18/// Compressor trait for different implementations.
19#[async_trait]
20pub trait Compressor: Send + Sync {
21    /// Compress messages using AI summarization.
22    async fn summarize(
23        &self,
24        messages: &[Message],
25        config: &CompressionConfig,
26    ) -> Result<SummarizedSegment>;
27
28    /// Get the model name used.
29    fn model_name(&self) -> &str;
30}
31
32/// AI-based compressor using a Provider.
33pub struct AiCompressor {
34    provider: Box<dyn Provider>,
35    model: String,
36}
37
38impl AiCompressor {
39    pub fn new(provider: Box<dyn Provider>, model: String) -> Self {
40        Self { provider, model }
41    }
42}
43
/// System prompt sent with every summarization request from
/// `AiCompressor::summarize`.
///
/// It forbids tool calls, asks the model to reason inside an `<analysis>` block
/// first, and then to answer in nine `【…】` sections. The section headers listed
/// here must stay in sync with the headers `parse_summary_response` recognises.
const SUMMARY_SYSTEM_PROMPT: &str = r#"CRITICAL: 仅用文本响应。不要调用任何工具。

- 不要使用 read、bash、grep、glob、edit、write 或任何其他工具
- 你已在上方对话中获得所需的所有上下文
- 工具调用将被拒绝并浪费你唯一的 turn — 你将失败任务
- 你的整个响应必须是纯文本摘要

---

你是一个对话历史压缩助手。将对话压缩为结构化摘要。

在提供最终摘要前,将你的分析包裹在 <analysis> 标签中以组织思路:
<analysis>
1. 按时间顺序分析每条消息
2. 识别:用户请求、助手行动、关键决策、错误及修复
3. 特别注意用户的敏感指令(禁止/必须)
4. 双重检查技术准确性和完整性
</analysis>

输出要求:
- 结构化:使用9个章节格式
- 关键:只保留重要信息,忽略无关细节
- 敏感:必须保留用户的敏感指令(禁止、必须等)— 原样保留
- 任务:必须保留未完成的待办事项
- 决策:必须保留关键方案选择和理由

9章节输出格式:
【摘要】一句话概括主要工作(50字以内)
【已完成】列出已完成的操作(工具调用、文件变更)
【未完成】列出待办任务和阻塞项 — 最关键,压缩后恢复需要
【关键决策】重要选择及理由(技术选型、方案决策)
【敏感指令】用户的禁止/必须指令(必须原样保留)
【技术栈】使用的语言、框架、库、工具
【文件变更】读取、修改、创建的文件路径
【问题记录】遇到的问题及解决方案
【下一步】建议的下一步操作(直接引用最近对话展示任务中断点)

每章节控制在100字以内,空章节可省略。
输出摘要后立即停止,不要添加任何解释或后续建议。

REMINDER: 不要调用任何工具。仅用纯文本响应。"#;
85
86#[async_trait]
87impl Compressor for AiCompressor {
88    async fn summarize(
89        &self,
90        messages: &[Message],
91        _config: &CompressionConfig,
92    ) -> Result<SummarizedSegment> {
93        let prompt = build_summary_prompt(messages);
94
95        let request = ChatRequest {
96            messages: vec![Message {
97                role: Role::User,
98                content: MessageContent::Text(prompt),
99            }],
100            tools: vec![],
101            system: Some(SUMMARY_SYSTEM_PROMPT.to_string()),
102            think: false,
103            max_tokens: 1024,
104            server_tools: vec![],
105            enable_caching: false,
106        };
107
108        let response = self.provider.chat(request).await?;
109        let summary_text = extract_text_from_response(&response);
110        let (summary, key_points) = parse_summary_response(&summary_text);
111
112        Ok(SummarizedSegment {
113            time_range: (chrono::Utc::now(), chrono::Utc::now()),
114            original_count: messages.len(),
115            summary,
116            key_points,
117        })
118    }
119
120    fn model_name(&self) -> &str {
121        &self.model
122    }
123}
124
125fn extract_text_from_response(response: &ChatResponse) -> String {
126    response
127        .content
128        .iter()
129        .filter_map(|block| {
130            if let ContentBlock::Text { text } = block {
131                Some(text.clone())
132            } else {
133                None
134            }
135        })
136        .collect::<Vec<_>>()
137        .join("\n")
138}
139
140fn parse_summary_response(text: &str) -> (String, Vec<String>) {
141    let mut summary = String::new();
142    let mut key_points: Vec<String> = Vec::new();
143
144    // Parse 9-section structured format
145    let sections = [
146        "【摘要】", "【已完成】", "【未完成】", "【关键决策】",
147        "【敏感指令】", "【技术栈】", "【文件变更】", "【问题记录】", "【下一步】"
148    ];
149
150    for line in text.lines() {
151        let line = line.trim();
152
153        // Check if this is a section header
154        let is_header = sections.iter().any(|s| line.starts_with(s));
155
156        if is_header {
157            // Extract content after the header
158            for section in &sections {
159                if line.starts_with(section) {
160                    let replaced = line.replace(section, "");
161                    let content = replaced.trim();
162                    if !content.is_empty() {
163                        if *section == "【摘要】" {
164                            summary = content.to_string();
165                        } else {
166                            key_points.push(format!("{}{}", section, content));
167                        }
168                    }
169                    break;
170                }
171            }
172        } else if !line.is_empty() {
173            // This is content under a section
174            if line.starts_with("•") || line.starts_with("-") || line.starts_with("*") {
175                let point = line.trim_start_matches(['•', '-', '*']).trim();
176                if !point.is_empty() {
177                    key_points.push(point.to_string());
178                }
179            } else if summary.is_empty() {
180                // Fallback: first non-empty line as summary
181                summary = line.to_string();
182            }
183        }
184    }
185
186    // Fallback if no structured format found
187    if summary.is_empty() && !text.is_empty() {
188        summary = text.lines().take(3).collect::<Vec<_>>().join(" ");
189        if summary.len() > 200 {
190            summary = truncate_with_suffix(&summary, 200);
191        }
192    }
193
194    (summary, key_points)
195}
196
197// ============================================================================
198// Compression Functions
199// ============================================================================
200
201/// Compress messages synchronously.
202pub fn compress_messages(
203    messages: &[Message],
204    strategy: CompressionStrategy,
205    config: &CompressionConfig,
206) -> Result<Vec<Message>> {
207    match strategy {
208        CompressionStrategy::Truncate => truncate_compress(messages, config),
209        CompressionStrategy::SlidingWindow => sliding_window_compress(messages, config),
210        CompressionStrategy::Summarize => sliding_window_compress(messages, config),
211        CompressionStrategy::BiasBased => compress_with_bias(messages, config),
212    }
213}
214
215/// Compress with bias-based scoring.
216pub fn compress_with_bias(
217    messages: &[Message],
218    config: &CompressionConfig,
219) -> Result<Vec<Message>> {
220    if messages.len() <= config.min_preserve_messages {
221        return Ok(messages.to_vec());
222    }
223
224    let scored: Vec<(usize, Message, f64)> = messages
225        .iter()
226        .enumerate()
227        .map(|(idx, msg)| {
228            (
229                idx,
230                msg.clone(),
231                calculate_preservation_score(msg, idx, messages.len(), &config.bias),
232            )
233        })
234        .collect();
235
236    let mut scored_with_recency: Vec<(usize, Message, f64)> = scored
237        .into_iter()
238        .map(|(idx, msg, score)| {
239            let recency_bonus = if idx >= messages.len() - config.min_preserve_messages {
240                100.0
241            } else {
242                (idx as f64 / messages.len() as f64) * 20.0
243            };
244            (idx, msg, score + recency_bonus)
245        })
246        .collect();
247
248    scored_with_recency.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
249
250    let target_count = if config.bias.aggressive {
251        config.min_preserve_messages
252    } else {
253        let estimated = estimate_total_tokens(messages);
254        let target_tokens = (estimated as f64 * config.target_ratio) as u32;
255        let avg = estimated / messages.len() as u32;
256        (target_tokens / avg.max(1)) as usize
257    };
258
259    let to_keep: HashSet<usize> = scored_with_recency
260        .iter()
261        .take(target_count)
262        .map(|(idx, _, _)| *idx)
263        .collect();
264
265    let compressed: Vec<Message> = messages
266        .iter()
267        .enumerate()
268        .filter(|(idx, _)| to_keep.contains(idx))
269        .map(|(_, msg)| msg.clone())
270        .collect();
271
272    Ok(compressed)
273}
274
275fn calculate_preservation_score(
276    message: &Message,
277    index: usize,
278    _total: usize,  // Reserved for future use (total message count)
279    bias: &CompressionBias,
280) -> f64 {
281    let mut score: f64 = 10.0;
282
283    // First message (user's original request) gets highest priority
284    if index == 0 {
285        score += 100.0;
286    }
287
288    match message.role {
289        Role::User => {
290            if bias.preserve_user_questions {
291                score += 30.0;
292            }
293        }
294        Role::Assistant => {
295            score += 5.0;
296        }
297        Role::Tool => {
298            if bias.preserve_tools {
299                score += 25.0;
300            }
301        }
302        Role::System => {
303            score += 40.0;
304        }
305    }
306
307    match &message.content {
308        MessageContent::Text(text) => {
309            for keyword in &bias.preserve_keywords {
310                if text.to_lowercase().contains(&keyword.to_lowercase()) {
311                    score += 15.0;
312                }
313            }
314            if contains_sensitive_instructions(text) {
315                score += 50.0;
316            }
317        }
318        MessageContent::Blocks(blocks) => {
319            for block in blocks {
320                match block {
321                    ContentBlock::ToolUse { name, .. } => {
322                        if bias.preserve_tools {
323                            score += 20.0;
324                        }
325                        if name == "write" || name == "edit" || name == "bash" {
326                            score += 10.0;
327                        }
328                        // todo_write gets high priority - preserve task tracking
329                        if name == "todo_write" {
330                            score += 60.0;
331                        }
332                        // ask tool contains key decisions
333                        if name == "ask" {
334                            score += 50.0;
335                        }
336                    }
337                    ContentBlock::ToolResult { content, .. } => {
338                        if bias.preserve_tools {
339                            score += 20.0;
340                        }
341                        if contains_sensitive_instructions(content) {
342                            score += 30.0;
343                        }
344                        // Preserve todo_write results (task status)
345                        if content.contains("TodoWrite") || content.contains("todo") {
346                            score += 40.0;
347                        }
348                        // Preserve ask responses (user decisions)
349                        if content.contains("AskUserQuestion") || content.contains("answer") {
350                            score += 30.0;
351                        }
352                    }
353                    ContentBlock::Thinking { .. } => {
354                        if bias.preserve_thinking {
355                            score += 25.0;
356                        } else {
357                            score -= 5.0;
358                        }
359                    }
360                    ContentBlock::Text { text } => {
361                        if contains_sensitive_instructions(text) {
362                            score += 50.0;
363                        }
364                    }
365                    _ => {}
366                }
367            }
368        }
369    }
370
371    score
372}
373
/// Returns `true` if `text` (case-insensitively) contains a prohibition or
/// obligation marker such as "不要", "必须", "never", "must" or "don't".
///
/// Each Chinese marker has English counterparts: 不要 → "do not"/"don't",
/// 禁止 → "forbidden", 必须 → "must" (which also covers "must not"),
/// 不允许 → "not allowed"; "never" is kept as before. Matching is plain
/// substring search, so it is a heuristic and can produce false positives.
fn contains_sensitive_instructions(text: &str) -> bool {
    const PATTERNS: &[&str] = &[
        // Chinese markers.
        "不要",
        "禁止",
        "必须",
        "不允许",
        // English markers.
        "never",
        "must",
        "do not",
        "don't",
        "don’t",
        "not allowed",
        "forbidden",
    ];
    let lower = text.to_lowercase();
    PATTERNS.iter().any(|pattern| lower.contains(pattern))
}
387
388fn truncate_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
389    if messages.len() <= config.min_preserve_messages {
390        return Ok(messages.to_vec());
391    }
392    Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
393}
394
395fn sliding_window_compress(
396    messages: &[Message],
397    config: &CompressionConfig,
398) -> Result<Vec<Message>> {
399    if messages.len() <= config.min_preserve_messages {
400        return Ok(messages.to_vec());
401    }
402
403    // Enhanced sliding window strategy:
404    // 1. Always keep first message (original user request)
405    // 2. Summarize middle messages if too long
406    // 3. Keep recent messages intact
407
408    let first_msg = messages.first().cloned();
409    let recent_start = messages.len().saturating_sub(config.min_preserve_messages);
410    let recent_msgs = &messages[recent_start..];
411
412    // Calculate tokens for first + recent
413    let first_tokens = first_msg.as_ref().map(estimate_tokens).unwrap_or(0);
414    let recent_tokens = estimate_total_tokens(recent_msgs);
415    let current_total = estimate_total_tokens(messages);
416    let target_tokens = (current_total as f64 * config.target_ratio) as u32;
417
418    // If first + recent already exceeds target, just use recent (drop first)
419    if first_tokens + recent_tokens <= target_tokens {
420        // We can keep first message + recent messages
421        let mut result: Vec<Message> = Vec::new();
422        if let Some(first) = first_msg {
423            result.push(first);
424        }
425        result.extend(recent_msgs.iter().cloned());
426        return Ok(result);
427    }
428
429    // If still too long, try dropping older messages from recent section
430    for drop_count in 0..recent_msgs.len() {
431        let candidate = &recent_msgs[drop_count..];
432        if estimate_total_tokens(candidate) <= target_tokens {
433            return Ok(candidate.to_vec());
434        }
435    }
436
437    // Last resort: just keep minimum recent messages
438    Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
439}
440
441// ============================================================================
442// Token Estimation
443// ============================================================================
444
445/// Estimate token count for a message.
446pub fn estimate_tokens(message: &Message) -> u32 {
447    let (ascii, non_ascii) = match &message.content {
448        MessageContent::Text(t) => count_chars(t),
449        MessageContent::Blocks(blocks) => {
450            let mut a = 0u32;
451            let mut n = 0u32;
452            for block in blocks {
453                match block {
454                    ContentBlock::Text { text } => {
455                        let (ca, cn) = count_chars(text);
456                        a += ca;
457                        n += cn;
458                    }
459                    ContentBlock::ToolUse { name, input, .. } => {
460                        let (ca, cn) = count_chars(name);
461                        a += ca;
462                        n += cn;
463                        let (ja, jn) = count_chars(&input.to_string());
464                        a += ja;
465                        n += jn;
466                    }
467                    ContentBlock::ToolResult { content, .. } => {
468                        let (ca, cn) = count_chars(content);
469                        a += ca;
470                        n += cn;
471                    }
472                    ContentBlock::Thinking { thinking, .. } => {
473                        let (ca, cn) = count_chars(thinking);
474                        a += ca;
475                        n += cn;
476                    }
477                    _ => {}
478                }
479            }
480            (a, n)
481        }
482    };
483
484    let ascii_tokens = (ascii as f64 * 0.25).ceil() as u32;
485    let non_ascii_tokens = (non_ascii as f64 * 0.67).ceil() as u32;
486    (ascii_tokens + non_ascii_tokens + 10).max(1)
487}
488
/// Counts the characters of `s`, returning `(ascii, non_ascii)`.
fn count_chars(s: &str) -> (u32, u32) {
    s.chars().fold((0u32, 0u32), |(ascii, other), ch| {
        if ch.is_ascii() {
            (ascii + 1, other)
        } else {
            (ascii, other + 1)
        }
    })
}
501
502/// Estimate total tokens for a message list.
503pub fn estimate_total_tokens(messages: &[Message]) -> u32 {
504    messages.iter().map(estimate_tokens).sum()
505}
506
507/// Check if compression should be triggered.
508pub fn should_compress(
509    current_tokens: u32,
510    context_size: Option<u32>,
511    config: &CompressionConfig,
512) -> bool {
513    match context_size {
514        Some(size) => (current_tokens as f64 / size as f64) >= config.threshold,
515        None => false,
516    }
517}
518
519/// Build a prompt for summarization.
520pub fn build_summary_prompt(messages: &[Message]) -> String {
521    let history = messages
522        .iter()
523        .map(|m| {
524            let role = match m.role {
525                Role::User => "用户",
526                Role::Assistant => "助手",
527                Role::Tool => "工具",
528                Role::System => "系统",
529            };
530            let preview = match &m.content {
531                MessageContent::Text(t) => truncate_with_suffix(t, 200),
532                MessageContent::Blocks(blocks) => blocks
533                    .iter()
534                    .map(|b| match b {
535                        ContentBlock::Text { text } => truncate_with_suffix(text, 100),
536                        ContentBlock::ToolUse { name, .. } => format!("[工具: {}]", name),
537                        ContentBlock::ToolResult { content, .. } => {
538                            truncate_with_suffix(content, 100)
539                        }
540                        _ => "[...]".to_string(),
541                    })
542                    .collect::<Vec<_>>()
543                    .join(" | "),
544            };
545            format!("{}: {}", role, preview)
546        })
547        .collect::<Vec<_>>()
548        .join("\n");
549
550    format!(
551        "请将以下对话压缩为简洁摘要({} 条消息):\n{}",
552        messages.len(),
553        history
554    )
555}
556
557// ============================================================================
558// New Pipeline-Based Compression (Async)
559// ============================================================================
560
561use super::pipeline::CompressionPipeline;
562use super::types::AiCompressionMode;
563
564/// Compress messages with AI assistance (async version).
565///
566/// This is the new recommended API for compression with intelligent
567/// scoring, dependency tracking, and content summarization.
568pub async fn compress_messages_with_ai(
569    messages: &[Message],
570    config: &CompressionConfig,
571    ai_mode: AiCompressionMode,
572    fast_model: Option<Box<dyn Provider>>,
573    token_usage: u32,
574    context_window: u32,
575) -> Result<Vec<Message>> {
576    let mut pipeline = match (ai_mode, fast_model) {
577        (AiCompressionMode::None, _) => CompressionPipeline::new_rule_only(config.clone()),
578        (AiCompressionMode::Light | AiCompressionMode::Deep, Some(model)) => {
579            CompressionPipeline::new_with_ai(config.clone(), model)
580        }
581        _ => CompressionPipeline::new_rule_only(config.clone()),
582    };
583
584    let result = pipeline.execute(messages, ai_mode, token_usage, context_window).await?;
585    Ok(result.messages)
586}
587
588/// Compress messages with full AI support (async version).
589///
590/// Uses both fast_model and main_model for different compression tasks.
591pub async fn compress_messages_with_full_ai(
592    messages: &[Message],
593    config: &CompressionConfig,
594    ai_mode: AiCompressionMode,
595    fast_model: Box<dyn Provider>,
596    main_model: Box<dyn Provider>,
597    token_usage: u32,
598    context_window: u32,
599) -> Result<Vec<Message>> {
600    let mut pipeline = CompressionPipeline::new_with_full_ai(
601        config.clone(),
602        fast_model,
603        main_model,
604    );
605
606    let result = pipeline.execute(messages, ai_mode, token_usage, context_window).await?;
607    Ok(result.messages)
608}
609
610/// Score messages without compressing (analysis only).
611///
612/// Useful for debugging and understanding compression decisions.
613pub fn score_messages_only(
614    messages: &[Message],
615    config: &CompressionConfig,
616) -> Vec<super::types::ScoredMessage> {
617    let pipeline = CompressionPipeline::new_rule_only(config.clone());
618    pipeline.score_only(messages)
619}
620
621// ============================================================================
622// Tests
623// ============================================================================
624
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain-text message with the given role.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_estimate_tokens_simple() {
        let msg = text_message(Role::User, "Hello world");
        assert!(estimate_tokens(&msg) >= 3);
    }

    #[test]
    fn test_count_chars_splits_ascii_and_non_ascii() {
        assert_eq!(count_chars(""), (0, 0));
        assert_eq!(count_chars("ab你好"), (2, 2));
    }

    #[test]
    fn test_should_compress() {
        let config = CompressionConfig::default();
        // Threshold is 0.5, so 100K/200K = 0.5 triggers compression
        assert!(should_compress(100_000, Some(200_000), &config));
        // 80K/200K = 0.4, below threshold
        assert!(!should_compress(80_000, Some(200_000), &config));
    }

    #[test]
    fn test_should_compress_without_context_size() {
        let config = CompressionConfig::default();
        assert!(!should_compress(u32::MAX, None, &config));
    }

    #[test]
    fn test_truncate_keeps_most_recent() {
        let config = CompressionConfig::default();
        let keep = config.min_preserve_messages;
        let messages: Vec<Message> = (0..keep + 3)
            .map(|i| text_message(Role::User, &format!("m{}", i)))
            .collect();

        let kept = truncate_compress(&messages, &config).unwrap();
        assert_eq!(kept.len(), keep);
        if let Some(first_kept) = kept.first() {
            match &first_kept.content {
                MessageContent::Text(text) => assert_eq!(text, "m3"),
                MessageContent::Blocks(_) => panic!("expected a text message"),
            }
        }
    }

    #[test]
    fn test_parse_summary_response_sections() {
        let (summary, points) =
            parse_summary_response("【摘要】完成重构\n【已完成】修改 main.rs\n- 补充文档");
        assert_eq!(summary, "完成重构");
        assert_eq!(
            points,
            vec!["【已完成】修改 main.rs".to_string(), "补充文档".to_string()]
        );
    }
}