Skip to main content

matrixcode_core/compress/
compressor.rs

1//! Compression functions and AI compressor implementation.
2
3use crate::providers::{
4    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
5};
6use crate::tokenizer::{count_tokens, message_overhead};
7use crate::truncate::truncate_with_suffix;
8use anyhow::Result;
9use async_trait::async_trait;
10use std::collections::HashSet;
11
12use super::config::{CompressionBias, CompressionConfig};
13use super::types::{CompressionStrategy, SummarizedSegment};
14
15// ============================================================================
16// Compressor Trait
17// ============================================================================
18
/// Compressor trait for different implementations.
///
/// An implementor condenses a slice of conversation messages into a single
/// [`SummarizedSegment`]. The `Send + Sync` bound allows a compressor to be
/// shared across threads and async tasks.
#[async_trait]
pub trait Compressor: Send + Sync {
    /// Compress messages using AI summarization.
    ///
    /// `config` carries the caller's compression settings; an implementation
    /// is free to ignore it (the `AiCompressor` implementation in this file
    /// does).
    async fn summarize(
        &self,
        messages: &[Message],
        config: &CompressionConfig,
    ) -> Result<SummarizedSegment>;

    /// Get the model name used.
    fn model_name(&self) -> &str;
}
32
33/// AI-based compressor using a Provider.
34pub struct AiCompressor {
35    provider: Box<dyn Provider>,
36    model: String,
37    hardcode_config: crate::compress::hardcode_config::HardcodeConfig,
38}
39
40impl AiCompressor {
41    pub fn new(provider: Box<dyn Provider>, model: String) -> Self {
42        Self {
43            provider,
44            model,
45            hardcode_config: crate::compress::hardcode_config::HardcodeConfig::default(),
46        }
47    }
48    
49    /// Set custom hardcode config
50    pub fn with_hardcode_config(mut self, config: crate::compress::hardcode_config::HardcodeConfig) -> Self {
51        self.hardcode_config = config;
52        self
53    }
54}
55
/// System prompt sent with every summarization request.
///
/// In short it tells the model to: answer in plain text only (no tool calls);
/// first organise its reasoning inside an `<analysis>` block; then emit a
/// structured summary using the nine `【…】` section headers, keeping each
/// section under 100 characters and preserving user prohibitions/requirements
/// verbatim.
///
/// NOTE: the nine section headers below must stay in sync with the list in
/// `parse_summary_response`. A header that does not match is not recognised
/// there, and no error is raised.
const SUMMARY_SYSTEM_PROMPT: &str = r#"CRITICAL: 仅用文本响应。不要调用任何工具。

- 不要使用 read、bash、grep、glob、edit、write 或任何其他工具
- 你已在上方对话中获得所需的所有上下文
- 工具调用将被拒绝并浪费你唯一的 turn — 你将失败任务
- 你的整个响应必须是纯文本摘要

---

你是一个对话历史压缩助手。将对话压缩为结构化摘要。

在提供最终摘要前,将你的分析包裹在 <analysis> 标签中以组织思路:
<analysis>
1. 按时间顺序分析每条消息
2. 识别:用户请求、助手行动、关键决策、错误及修复
3. 特别注意用户的敏感指令(禁止/必须)
4. 双重检查技术准确性和完整性
</analysis>

输出要求:
- 结构化:使用9个章节格式
- 关键:只保留重要信息,忽略无关细节
- 敏感:必须保留用户的敏感指令(禁止、必须等)— 原样保留
- 任务:必须保留未完成的待办事项
- 决策:必须保留关键方案选择和理由

9章节输出格式:
【摘要】一句话概括主要工作(50字以内)
【已完成】列出已完成的操作(工具调用、文件变更)
【未完成】列出待办任务和阻塞项 — 最关键,压缩后恢复需要
【关键决策】重要选择及理由(技术选型、方案决策)
【敏感指令】用户的禁止/必须指令(必须原样保留)
【技术栈】使用的语言、框架、库、工具
【文件变更】读取、修改、创建的文件路径
【问题记录】遇到的问题及解决方案
【下一步】建议的下一步操作(直接引用最近对话展示任务中断点)

每章节控制在100字以内,空章节可省略。
输出摘要后立即停止,不要添加任何解释或后续建议。

REMINDER: 不要调用任何工具。仅用纯文本响应。"#;
97
98#[async_trait]
99impl Compressor for AiCompressor {
100    async fn summarize(
101        &self,
102        messages: &[Message],
103        _config: &CompressionConfig,
104    ) -> Result<SummarizedSegment> {
105        let prompt = build_summary_prompt(messages);
106
107        let request = ChatRequest {
108            messages: vec![Message {
109                role: Role::User,
110                content: MessageContent::Text(prompt),
111            }],
112            tools: vec![],
113            system: Some(SUMMARY_SYSTEM_PROMPT.to_string()),
114            think: false,
115            max_tokens: 1024,
116            server_tools: vec![],
117            enable_caching: false,
118        };
119
120        let response = self.provider.chat(request).await?;
121        let summary_text = extract_text_from_response(&response);
122        let (summary, key_points) = parse_summary_response(&summary_text, &self.hardcode_config);
123
124        Ok(SummarizedSegment {
125            time_range: (chrono::Utc::now(), chrono::Utc::now()),
126            original_count: messages.len(),
127            summary,
128            key_points,
129        })
130    }
131
132    fn model_name(&self) -> &str {
133        &self.model
134    }
135}
136
137fn extract_text_from_response(response: &ChatResponse) -> String {
138    response
139        .content
140        .iter()
141        .filter_map(|block| {
142            if let ContentBlock::Text { text } = block {
143                Some(text.clone())
144            } else {
145                None
146            }
147        })
148        .collect::<Vec<_>>()
149        .join("\n")
150}
151
152fn parse_summary_response(text: &str, config: &crate::compress::hardcode_config::HardcodeConfig) -> (String, Vec<String>) {
153    let mut summary = String::new();
154    let mut key_points: Vec<String> = Vec::new();
155
156    // Parse 9-section structured format
157    let sections = [
158        "【摘要】",
159        "【已完成】",
160        "【未完成】",
161        "【关键决策】",
162        "【敏感指令】",
163        "【技术栈】",
164        "【文件变更】",
165        "【问题记录】",
166        "【下一步】",
167    ];
168
169    for line in text.lines() {
170        let line = line.trim();
171
172        // Check if this is a section header
173        let is_header = sections.iter().any(|s| line.starts_with(s));
174
175        if is_header {
176            // Extract content after the header
177            for section in &sections {
178                if line.starts_with(section) {
179                    let replaced = line.replace(section, "");
180                    let content = replaced.trim();
181                    if !content.is_empty() {
182                        if *section == "【摘要】" {
183                            summary = content.to_string();
184                        } else {
185                            key_points.push(format!("{}{}", section, content));
186                        }
187                    }
188                    break;
189                }
190            }
191        } else if !line.is_empty() {
192            // This is content under a section
193            if line.starts_with("•") || line.starts_with("-") || line.starts_with("*") {
194                let point = line.trim_start_matches(['•', '-', '*']).trim();
195                if !point.is_empty() {
196                    key_points.push(point.to_string());
197                }
198            } else if summary.is_empty() {
199                // Fallback: first non-empty line as summary
200                summary = line.to_string();
201            }
202        }
203    }
204
205    // Fallback if no structured format found
206    if summary.is_empty() && !text.is_empty() {
207        summary = text.lines().take(3).collect::<Vec<_>>().join(" ");
208        if summary.len() > config.summary_length_threshold {
209            summary = truncate_with_suffix(&summary, config.summary_length_threshold);
210        }
211    }
212
213    (summary, key_points)
214}
215
216// ============================================================================
217// Compression Functions
218// ============================================================================
219
220/// Compress messages synchronously.
221pub fn compress_messages(
222    messages: &[Message],
223    strategy: CompressionStrategy,
224    config: &CompressionConfig,
225) -> Result<Vec<Message>> {
226    match strategy {
227        CompressionStrategy::Truncate => truncate_compress(messages, config),
228        CompressionStrategy::SlidingWindow => sliding_window_compress(messages, config),
229        CompressionStrategy::Summarize => sliding_window_compress(messages, config),
230        CompressionStrategy::BiasBased => compress_with_bias(messages, config),
231    }
232}
233
234/// Compress with bias-based scoring.
235pub fn compress_with_bias(
236    messages: &[Message],
237    config: &CompressionConfig,
238) -> Result<Vec<Message>> {
239    if messages.len() <= config.min_preserve_messages {
240        return Ok(messages.to_vec());
241    }
242
243    let scored: Vec<(usize, Message, f64)> = messages
244        .iter()
245        .enumerate()
246        .map(|(idx, msg)| {
247            (
248                idx,
249                msg.clone(),
250                calculate_preservation_score(msg, idx, messages.len(), &config.bias),
251            )
252        })
253        .collect();
254
255    let mut scored_with_recency: Vec<(usize, Message, f64)> = scored
256        .into_iter()
257        .map(|(idx, msg, score)| {
258            let recency_bonus = if idx >= messages.len() - config.min_preserve_messages {
259                100.0
260            } else {
261                (idx as f64 / messages.len() as f64) * 20.0
262            };
263            (idx, msg, score + recency_bonus)
264        })
265        .collect();
266
267    scored_with_recency.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
268
269    let target_count = if config.bias.aggressive {
270        config.min_preserve_messages
271    } else {
272        let estimated = estimate_total_tokens(messages);
273        let target_tokens = (estimated as f64 * config.target_ratio) as u32;
274        let avg = estimated / messages.len() as u32;
275        (target_tokens / avg.max(1)) as usize
276    };
277
278    let to_keep: HashSet<usize> = scored_with_recency
279        .iter()
280        .take(target_count)
281        .map(|(idx, _, _)| *idx)
282        .collect();
283
284    let compressed: Vec<Message> = messages
285        .iter()
286        .enumerate()
287        .filter(|(idx, _)| to_keep.contains(idx))
288        .map(|(_, msg)| msg.clone())
289        .collect();
290
291    Ok(compressed)
292}
293
294fn calculate_preservation_score(
295    message: &Message,
296    index: usize,
297    _total: usize, // Reserved for future use (total message count)
298    bias: &CompressionBias,
299) -> f64 {
300    let mut score: f64 = 10.0;
301
302    // First message (user's original request) gets highest priority
303    if index == 0 {
304        score += 100.0;
305    }
306
307    match message.role {
308        Role::User => {
309            if bias.preserve_user_questions {
310                score += 30.0;
311            }
312        }
313        Role::Assistant => {
314            score += 5.0;
315        }
316        Role::Tool => {
317            if bias.preserve_tools {
318                score += 25.0;
319            }
320        }
321        Role::System => {
322            score += 40.0;
323        }
324    }
325
326    match &message.content {
327        MessageContent::Text(text) => {
328            for keyword in &bias.preserve_keywords {
329                if text.to_lowercase().contains(&keyword.to_lowercase()) {
330                    score += 15.0;
331                }
332            }
333            if contains_sensitive_instructions(text) {
334                score += 50.0;
335            }
336        }
337        MessageContent::Blocks(blocks) => {
338            for block in blocks {
339                match block {
340                    ContentBlock::ToolUse { name, .. } => {
341                        if bias.preserve_tools {
342                            score += 20.0;
343                        }
344                        if name == "write" || name == "edit" || name == "bash" {
345                            score += 10.0;
346                        }
347                        // todo_write gets high priority - preserve task tracking
348                        if name == "todo_write" {
349                            score += 60.0;
350                        }
351                        // ask tool contains key decisions
352                        if name == "ask" {
353                            score += 50.0;
354                        }
355                    }
356                    ContentBlock::ToolResult { content, .. } => {
357                        if bias.preserve_tools {
358                            score += 20.0;
359                        }
360                        if contains_sensitive_instructions(content) {
361                            score += 30.0;
362                        }
363                        // Preserve todo_write results (task status)
364                        if content.contains("TodoWrite") || content.contains("todo") {
365                            score += 40.0;
366                        }
367                        // Preserve ask responses (user decisions)
368                        if content.contains("AskUserQuestion") || content.contains("answer") {
369                            score += 30.0;
370                        }
371                    }
372                    ContentBlock::Thinking { .. } => {
373                        if bias.preserve_thinking {
374                            score += 25.0;
375                        } else {
376                            score -= 5.0;
377                        }
378                    }
379                    ContentBlock::Text { text } => {
380                        if contains_sensitive_instructions(text) {
381                            score += 50.0;
382                        }
383                    }
384                    _ => {}
385                }
386            }
387        }
388    }
389
390    score
391}
392
/// Report whether `text` looks like it carries a prohibition or requirement.
///
/// Matching is a case-insensitive substring search against a fixed list of
/// Chinese and English markers ("不要", "禁止", "必须", "不允许", "never",
/// "must not", "do not"). Because it is a plain substring test, words such as
/// "nevertheless" also match.
fn contains_sensitive_instructions(text: &str) -> bool {
    const MARKERS: [&str; 7] = [
        "不要",
        "禁止",
        "必须",
        "不允许",
        "never",
        "must not",
        "do not",
    ];

    let haystack = text.to_lowercase();
    for marker in MARKERS {
        if haystack.contains(marker) {
            return true;
        }
    }
    false
}
406
407fn truncate_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
408    if messages.len() <= config.min_preserve_messages {
409        return Ok(messages.to_vec());
410    }
411    Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
412}
413
414fn sliding_window_compress(
415    messages: &[Message],
416    config: &CompressionConfig,
417) -> Result<Vec<Message>> {
418    if messages.len() <= config.min_preserve_messages {
419        return Ok(messages.to_vec());
420    }
421
422    // Enhanced sliding window strategy:
423    // 1. Always keep first message (original user request)
424    // 2. Summarize middle messages if too long
425    // 3. Keep recent messages intact
426
427    let first_msg = messages.first().cloned();
428    let recent_start = messages.len().saturating_sub(config.min_preserve_messages);
429    let recent_msgs = &messages[recent_start..];
430
431    // Calculate tokens for first + recent
432    let first_tokens = first_msg.as_ref().map(estimate_tokens).unwrap_or(0);
433    let recent_tokens = estimate_total_tokens(recent_msgs);
434    let current_total = estimate_total_tokens(messages);
435    let target_tokens = (current_total as f64 * config.target_ratio) as u32;
436
437    // If first + recent already exceeds target, just use recent (drop first)
438    if first_tokens + recent_tokens <= target_tokens {
439        // We can keep first message + recent messages
440        let mut result: Vec<Message> = Vec::new();
441        if let Some(first) = first_msg {
442            result.push(first);
443        }
444        result.extend(recent_msgs.iter().cloned());
445        return Ok(result);
446    }
447
448    // If still too long, try dropping older messages from recent section
449    for drop_count in 0..recent_msgs.len() {
450        let candidate = &recent_msgs[drop_count..];
451        if estimate_total_tokens(candidate) <= target_tokens {
452            return Ok(candidate.to_vec());
453        }
454    }
455
456    // Last resort: just keep minimum recent messages
457    Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
458}
459
460// ============================================================================
461// Token Estimation
462// ============================================================================
463
464/// Count tokens for a message using tiktoken (accurate).
465pub fn estimate_tokens(message: &Message) -> u32 {
466    let content_text = match &message.content {
467        MessageContent::Text(t) => t.clone(),
468        MessageContent::Blocks(blocks) => {
469            let mut text = String::new();
470            for block in blocks {
471                match block {
472                    ContentBlock::Text { text: t } => {
473                        text.push_str(t);
474                        text.push(' '); // Separator
475                    }
476                    ContentBlock::ToolUse { name, input, .. } => {
477                        text.push_str(name);
478                        text.push(' '); // Separator
479                        text.push_str(&input.to_string());
480                        text.push(' '); // Separator
481                    }
482                    ContentBlock::ToolResult { content, .. } => {
483                        text.push_str(content);
484                        text.push(' '); // Separator
485                    }
486                    ContentBlock::Thinking { thinking, .. } => {
487                        text.push_str(thinking);
488                        text.push(' '); // Separator
489                    }
490                    _ => {}
491                }
492            }
493            text
494        }
495    };
496
497    // Use tiktoken for accurate token counting
498    let content_tokens = count_tokens(&content_text);
499    let role_tokens = count_tokens(&format!("{:?}: ", message.role));
500    
501    // Total = content + role prefix + overhead
502    content_tokens + role_tokens + message_overhead()
503}
504
/// Legacy character counter: returns `(ascii_chars, non_ascii_chars)`.
///
/// Deprecated in favour of the tokenizer-based `estimate_tokens`.
#[deprecated(note = "Use estimate_tokens with tiktoken instead")]
fn count_chars(s: &str) -> (u32, u32) {
    s.chars().fold((0u32, 0u32), |(ascii, other), ch| {
        if ch.is_ascii() {
            (ascii + 1, other)
        } else {
            (ascii, other + 1)
        }
    })
}
519
520/// Estimate total tokens for a message list.
521pub fn estimate_total_tokens(messages: &[Message]) -> u32 {
522    messages.iter().map(estimate_tokens).sum()
523}
524
525/// Check if compression should be triggered.
526pub fn should_compress(
527    current_tokens: u32,
528    context_size: Option<u32>,
529    config: &CompressionConfig,
530) -> bool {
531    match context_size {
532        Some(size) => (current_tokens as f64 / size as f64) >= config.threshold,
533        None => false,
534    }
535}
536
537/// Build a prompt for summarization.
538pub fn build_summary_prompt(messages: &[Message]) -> String {
539    let history = messages
540        .iter()
541        .map(|m| {
542            let role = match m.role {
543                Role::User => "用户",
544                Role::Assistant => "助手",
545                Role::Tool => "工具",
546                Role::System => "系统",
547            };
548            let preview = match &m.content {
549                MessageContent::Text(t) => truncate_with_suffix(t, 200),
550                MessageContent::Blocks(blocks) => blocks
551                    .iter()
552                    .map(|b| match b {
553                        ContentBlock::Text { text } => truncate_with_suffix(text, 100),
554                        ContentBlock::ToolUse { name, .. } => format!("[工具: {}]", name),
555                        ContentBlock::ToolResult { content, .. } => {
556                            truncate_with_suffix(content, 100)
557                        }
558                        _ => "[...]".to_string(),
559                    })
560                    .collect::<Vec<_>>()
561                    .join(" | "),
562            };
563            format!("{}: {}", role, preview)
564        })
565        .collect::<Vec<_>>()
566        .join("\n");
567
568    format!(
569        "请将以下对话压缩为简洁摘要({} 条消息):\n{}",
570        messages.len(),
571        history
572    )
573}
574
575// ============================================================================
576// New Pipeline-Based Compression (Async)
577// ============================================================================
578
579use super::pipeline::CompressionPipeline;
580use super::types::AiCompressionMode;
581
582/// Compress messages with AI assistance (async version).
583///
584/// This is the new recommended API for compression with intelligent
585/// scoring, dependency tracking, and content summarization.
586pub async fn compress_messages_with_ai(
587    messages: &[Message],
588    config: &CompressionConfig,
589    ai_mode: AiCompressionMode,
590    fast_model: Option<Box<dyn Provider>>,
591    token_usage: u32,
592    context_window: u32,
593) -> Result<Vec<Message>> {
594    let mut pipeline = match (ai_mode, fast_model) {
595        (AiCompressionMode::None, _) => CompressionPipeline::new_rule_only(config.clone()),
596        (AiCompressionMode::Light | AiCompressionMode::Deep, Some(model)) => {
597            CompressionPipeline::new_with_ai(config.clone(), model)
598        }
599        _ => CompressionPipeline::new_rule_only(config.clone()),
600    };
601
602    let result = pipeline
603        .execute(messages, ai_mode, token_usage, context_window)
604        .await?;
605    Ok(result.messages)
606}
607
608/// Compress messages with full AI support (async version).
609///
610/// Uses both fast_model and main_model for different compression tasks.
611pub async fn compress_messages_with_full_ai(
612    messages: &[Message],
613    config: &CompressionConfig,
614    ai_mode: AiCompressionMode,
615    fast_model: Box<dyn Provider>,
616    main_model: Box<dyn Provider>,
617    token_usage: u32,
618    context_window: u32,
619) -> Result<Vec<Message>> {
620    let mut pipeline =
621        CompressionPipeline::new_with_full_ai(config.clone(), fast_model, main_model);
622
623    let result = pipeline
624        .execute(messages, ai_mode, token_usage, context_window)
625        .await?;
626    Ok(result.messages)
627}
628
629/// Score messages without compressing (analysis only).
630///
631/// Useful for debugging and understanding compression decisions.
632pub fn score_messages_only(
633    messages: &[Message],
634    config: &CompressionConfig,
635) -> Vec<super::types::ScoredMessage> {
636    let pipeline = CompressionPipeline::new_rule_only(config.clone());
637    pipeline.score_only(messages)
638}
639
640// ============================================================================
641// Tests
642// ============================================================================
643
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience constructor for a plain-text user message.
    fn user_text(body: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(body.to_string()),
        }
    }

    #[test]
    fn test_estimate_tokens_simple() {
        let tokens = estimate_tokens(&user_text("Hello world"));
        assert!(tokens >= 3, "expected at least 3 tokens, got {}", tokens);
    }

    #[test]
    fn test_should_compress() {
        let config = CompressionConfig::default();
        let context = Some(200_000);
        // Threshold is 0.5, so 100K/200K = 0.5 triggers compression
        assert!(should_compress(100_000, context, &config));
        // 80K/200K = 0.4, below threshold
        assert!(!should_compress(80_000, context, &config));
    }
}