Skip to main content

matrixcode_core/compress/
compressor.rs

1//! Compression functions and AI compressor implementation.
2
3use crate::providers::{
4    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
5};
6use crate::tokenizer::{count_tokens, message_overhead};
7use crate::truncate::truncate_with_suffix;
8use anyhow::Result;
9use async_trait::async_trait;
10use std::collections::HashSet;
11
12use super::dependency::DependencyBuilder;
13
14use super::config::{CompressionBias, CompressionConfig};
15use super::types::{CompressionStrategy, SummarizedSegment};
16
17// ============================================================================
18// Compressor Trait
19// ============================================================================
20
21/// Compressor trait for different implementations.
22#[async_trait]
23pub trait Compressor: Send + Sync {
24    /// Compress messages using AI summarization.
25    async fn summarize(
26        &self,
27        messages: &[Message],
28        config: &CompressionConfig,
29    ) -> Result<SummarizedSegment>;
30
31    /// Get the model name used.
32    fn model_name(&self) -> &str;
33}
34
35/// AI-based compressor using a Provider.
36pub struct AiCompressor {
37    provider: Box<dyn Provider>,
38    model: String,
39    hardcode_config: crate::compress::hardcode_config::HardcodeConfig,
40}
41
42impl AiCompressor {
43    pub fn new(provider: Box<dyn Provider>, model: String) -> Self {
44        Self {
45            provider,
46            model,
47            hardcode_config: crate::compress::hardcode_config::HardcodeConfig::default(),
48        }
49    }
50    
51    /// Set custom hardcode config
52    pub fn with_hardcode_config(mut self, config: crate::compress::hardcode_config::HardcodeConfig) -> Self {
53        self.hardcode_config = config;
54        self
55    }
56}
57
/// System prompt sent with every summarization request made by [`AiCompressor`].
///
/// The (Chinese) text tells the model to:
/// - answer in plain text only, with no tool calls;
/// - organise its reasoning inside `<analysis>` tags first;
/// - emit a 9-section summary.
///
/// The section headers (`【摘要】`, `【已完成】`, ...) are the same strings that
/// `parse_summary_response` looks for, so keep the two lists in sync.
const SUMMARY_SYSTEM_PROMPT: &str = r#"CRITICAL: 仅用文本响应。不要调用任何工具。

- 不要使用 read、bash、grep、glob、edit、write 或任何其他工具
- 你已在上方对话中获得所需的所有上下文
- 工具调用将被拒绝并浪费你唯一的 turn — 你将失败任务
- 你的整个响应必须是纯文本摘要

---

你是一个对话历史压缩助手。将对话压缩为结构化摘要。

在提供最终摘要前,将你的分析包裹在 <analysis> 标签中以组织思路:
<analysis>
1. 按时间顺序分析每条消息
2. 识别:用户请求、助手行动、关键决策、错误及修复
3. 特别注意用户的敏感指令(禁止/必须)
4. 双重检查技术准确性和完整性
</analysis>

输出要求:
- 结构化:使用9个章节格式
- 关键:只保留重要信息,忽略无关细节
- 敏感:必须保留用户的敏感指令(禁止、必须等)— 原样保留
- 任务:必须保留未完成的待办事项
- 决策:必须保留关键方案选择和理由

9章节输出格式:
【摘要】一句话概括主要工作(50字以内)
【已完成】列出已完成的操作(工具调用、文件变更)
【未完成】列出待办任务和阻塞项 — 最关键,压缩后恢复需要
【关键决策】重要选择及理由(技术选型、方案决策)
【敏感指令】用户的禁止/必须指令(必须原样保留)
【技术栈】使用的语言、框架、库、工具
【文件变更】读取、修改、创建的文件路径
【问题记录】遇到的问题及解决方案
【下一步】建议的下一步操作(直接引用最近对话展示任务中断点)

每章节控制在100字以内,空章节可省略。
输出摘要后立即停止,不要添加任何解释或后续建议。

REMINDER: 不要调用任何工具。仅用纯文本响应。"#;
99
100#[async_trait]
101impl Compressor for AiCompressor {
102    async fn summarize(
103        &self,
104        messages: &[Message],
105        _config: &CompressionConfig,
106    ) -> Result<SummarizedSegment> {
107        let prompt = build_summary_prompt(messages);
108
109        let request = ChatRequest {
110            messages: vec![Message {
111                role: Role::User,
112                content: MessageContent::Text(prompt),
113            }],
114            tools: vec![],
115            system: Some(SUMMARY_SYSTEM_PROMPT.to_string()),
116            think: false,
117            max_tokens: 1024,
118            server_tools: vec![],
119            enable_caching: false,
120        };
121
122        let response = self.provider.chat(request).await?;
123        let summary_text = extract_text_from_response(&response);
124        let (summary, key_points) = parse_summary_response(&summary_text, &self.hardcode_config);
125
126        Ok(SummarizedSegment {
127            time_range: (chrono::Utc::now(), chrono::Utc::now()),
128            original_count: messages.len(),
129            summary,
130            key_points,
131        })
132    }
133
134    fn model_name(&self) -> &str {
135        &self.model
136    }
137}
138
139fn extract_text_from_response(response: &ChatResponse) -> String {
140    response
141        .content
142        .iter()
143        .filter_map(|block| {
144            if let ContentBlock::Text { text } = block {
145                Some(text.clone())
146            } else {
147                None
148            }
149        })
150        .collect::<Vec<_>>()
151        .join("\n")
152}
153
154fn parse_summary_response(text: &str, config: &crate::compress::hardcode_config::HardcodeConfig) -> (String, Vec<String>) {
155    let mut summary = String::new();
156    let mut key_points: Vec<String> = Vec::new();
157
158    // Parse 9-section structured format
159    let sections = [
160        "【摘要】",
161        "【已完成】",
162        "【未完成】",
163        "【关键决策】",
164        "【敏感指令】",
165        "【技术栈】",
166        "【文件变更】",
167        "【问题记录】",
168        "【下一步】",
169    ];
170
171    for line in text.lines() {
172        let line = line.trim();
173
174        // Check if this is a section header
175        let is_header = sections.iter().any(|s| line.starts_with(s));
176
177        if is_header {
178            // Extract content after the header
179            for section in &sections {
180                if line.starts_with(section) {
181                    let replaced = line.replace(section, "");
182                    let content = replaced.trim();
183                    if !content.is_empty() {
184                        if *section == "【摘要】" {
185                            summary = content.to_string();
186                        } else {
187                            key_points.push(format!("{}{}", section, content));
188                        }
189                    }
190                    break;
191                }
192            }
193        } else if !line.is_empty() {
194            // This is content under a section
195            if line.starts_with("•") || line.starts_with("-") || line.starts_with("*") {
196                let point = line.trim_start_matches(['•', '-', '*']).trim();
197                if !point.is_empty() {
198                    key_points.push(point.to_string());
199                }
200            } else if summary.is_empty() {
201                // Fallback: first non-empty line as summary
202                summary = line.to_string();
203            }
204        }
205    }
206
207    // Fallback if no structured format found
208    if summary.is_empty() && !text.is_empty() {
209        summary = text.lines().take(3).collect::<Vec<_>>().join(" ");
210        if summary.len() > config.summary_length_threshold {
211            summary = truncate_with_suffix(&summary, config.summary_length_threshold);
212        }
213    }
214
215    (summary, key_points)
216}
217
218// ============================================================================
219// Compression Functions
220// ============================================================================
221
222/// Compress messages synchronously.
223pub fn compress_messages(
224    messages: &[Message],
225    strategy: CompressionStrategy,
226    config: &CompressionConfig,
227) -> Result<Vec<Message>> {
228    match strategy {
229        CompressionStrategy::Truncate => truncate_compress(messages, config),
230        CompressionStrategy::SlidingWindow => sliding_window_compress(messages, config),
231        CompressionStrategy::Summarize => sliding_window_compress(messages, config),
232        CompressionStrategy::BiasBased => compress_with_bias(messages, config),
233    }
234}
235
236/// Compress with bias-based scoring.
237pub fn compress_with_bias(
238    messages: &[Message],
239    config: &CompressionConfig,
240) -> Result<Vec<Message>> {
241    if messages.len() <= config.min_preserve_messages {
242        return Ok(messages.to_vec());
243    }
244
245    let scored: Vec<(usize, Message, f64)> = messages
246        .iter()
247        .enumerate()
248        .map(|(idx, msg)| {
249            (
250                idx,
251                msg.clone(),
252                calculate_preservation_score(msg, idx, messages.len(), &config.bias),
253            )
254        })
255        .collect();
256
257    let mut scored_with_recency: Vec<(usize, Message, f64)> = scored
258        .into_iter()
259        .map(|(idx, msg, score)| {
260            let recency_bonus = if idx >= messages.len() - config.min_preserve_messages {
261                100.0
262            } else {
263                (idx as f64 / messages.len() as f64) * 20.0
264            };
265            (idx, msg, score + recency_bonus)
266        })
267        .collect();
268
269    scored_with_recency.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
270
271    let target_count = if config.bias.aggressive {
272        config.min_preserve_messages
273    } else {
274        let estimated = estimate_total_tokens(messages);
275        let target_tokens = (estimated as f64 * config.target_ratio) as u32;
276        let avg = estimated / messages.len() as u32;
277        (target_tokens / avg.max(1)) as usize
278    };
279
280    let to_keep: HashSet<usize> = scored_with_recency
281        .iter()
282        .take(target_count)
283        .map(|(idx, _, _)| *idx)
284        .collect();
285
286    let compressed: Vec<Message> = messages
287        .iter()
288        .enumerate()
289        .filter(|(idx, _)| to_keep.contains(idx))
290        .map(|(_, msg)| msg.clone())
291        .collect();
292
293    Ok(compressed)
294}
295
296fn calculate_preservation_score(
297    message: &Message,
298    _index: usize,
299    _total: usize, // Reserved for future use (total message count)
300    bias: &CompressionBias,
301) -> f64 {
302    let mut score: f64 = 10.0;
303
304    // Note: First message is no longer given unconditional high priority
305    // It will be scored based on its content like other messages.
306    // This prevents context pollution when user starts a new topic in multi-turn conversations.
307
308    match message.role {
309        Role::User => {
310            if bias.preserve_user_questions {
311                score += 30.0;
312            }
313        }
314        Role::Assistant => {
315            score += 5.0;
316        }
317        Role::Tool => {
318            if bias.preserve_tools {
319                score += 25.0;
320            }
321        }
322        Role::System => {
323            score += 40.0;
324        }
325    }
326
327    match &message.content {
328        MessageContent::Text(text) => {
329            for keyword in &bias.preserve_keywords {
330                if text.to_lowercase().contains(&keyword.to_lowercase()) {
331                    score += 15.0;
332                }
333            }
334            if contains_sensitive_instructions(text) {
335                score += 50.0;
336            }
337        }
338        MessageContent::Blocks(blocks) => {
339            for block in blocks {
340                match block {
341                    ContentBlock::ToolUse { name, .. } => {
342                        if bias.preserve_tools {
343                            score += 20.0;
344                        }
345                        if name == "write" || name == "edit" || name == "bash" {
346                            score += 10.0;
347                        }
348                        // todo_write gets high priority - preserve task tracking
349                        if name == "todo_write" {
350                            score += 60.0;
351                        }
352                        // ask tool contains key decisions
353                        if name == "ask" {
354                            score += 50.0;
355                        }
356                    }
357                    ContentBlock::ToolResult { content, .. } => {
358                        if bias.preserve_tools {
359                            score += 20.0;
360                        }
361                        if contains_sensitive_instructions(content) {
362                            score += 30.0;
363                        }
364                        // Preserve todo_write results (task status)
365                        if content.contains("TodoWrite") || content.contains("todo") {
366                            score += 40.0;
367                        }
368                        // Preserve ask responses (user decisions)
369                        if content.contains("AskUserQuestion") || content.contains("answer") {
370                            score += 30.0;
371                        }
372                    }
373                    ContentBlock::Thinking { .. } => {
374                        if bias.preserve_thinking {
375                            score += 25.0;
376                        } else {
377                            score -= 5.0;
378                        }
379                    }
380                    ContentBlock::Text { text } => {
381                        if contains_sensitive_instructions(text) {
382                            score += 50.0;
383                        }
384                    }
385                    _ => {}
386                }
387            }
388        }
389    }
390
391    score
392}
393
/// Returns `true` if `text` contains a prohibition or obligation marker.
///
/// Markers are Chinese `不要`, `禁止`, `必须`, `不允许` or English `never`, `must not`,
/// `do not`. English markers are matched case-insensitively.
fn contains_sensitive_instructions(text: &str) -> bool {
    const MARKERS: [&str; 7] = [
        "不要",
        "禁止",
        "必须",
        "不允许",
        "never",
        "must not",
        "do not",
    ];

    let lowered = text.to_lowercase();
    for &marker in MARKERS.iter() {
        if lowered.contains(marker) {
            return true;
        }
    }
    false
}
407
408fn truncate_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
409    if messages.len() <= config.min_preserve_messages {
410        return Ok(messages.to_vec());
411    }
412    Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
413}
414
415fn sliding_window_compress(
416    messages: &[Message],
417    config: &CompressionConfig,
418) -> Result<Vec<Message>> {
419    if messages.len() <= config.min_preserve_messages {
420        return Ok(messages.to_vec());
421    }
422
423    // Build dependency graph to preserve tool_use/tool_result pairs
424    let deps = DependencyBuilder::build(messages);
425
426    // Enhanced sliding window strategy with dependency protection:
427    // 1. Preserve tool_use/tool_result pairs
428    // 2. Keep recent messages intact
429    // Note: First message is NOT automatically preserved - it competes based on content scoring
430
431    let recent_start = messages.len().saturating_sub(config.min_preserve_messages);
432
433    // Collect indices to preserve
434    let mut preserve_indices: HashSet<usize> = HashSet::new();
435
436    // Preserve recent messages (including the current user message which is likely at the end)
437    for i in recent_start..messages.len() {
438        preserve_indices.insert(i);
439
440        // Also preserve dependency pairs for recent messages
441        for pair_idx in deps.get_pair_indices(i) {
442            preserve_indices.insert(pair_idx);
443        }
444    }
445
446    // Calculate tokens for preserved messages
447    let preserved_msgs: Vec<Message> = preserve_indices
448        .iter()
449        .filter(|&i| *i < messages.len())
450        .map(|&i| messages[i].clone())
451        .collect();
452
453    let preserved_tokens = estimate_total_tokens(&preserved_msgs);
454    let current_total = estimate_total_tokens(messages);
455    let target_tokens = (current_total as f64 * config.target_ratio) as u32;
456
457    // If preserved messages fit within target, return them sorted by index
458    if preserved_tokens <= target_tokens {
459        let mut sorted_indices: Vec<usize> = preserve_indices.iter().cloned().collect();
460        sorted_indices.sort();
461        let result: Vec<Message> = sorted_indices
462            .iter()
463            .filter(|&i| *i < messages.len())
464            .map(|&i| messages[i].clone())
465            .collect();
466        return Ok(result);
467    }
468
469    // If still too long, try dropping older messages while preserving pairs
470    let mut sorted_indices: Vec<usize> = preserve_indices.iter().cloned().collect();
471    sorted_indices.sort();
472
473    // Try removing from the oldest messages
474    while sorted_indices.len() > config.min_preserve_messages {
475        let candidate_tokens = sorted_indices
476            .iter()
477            .copied()
478            .filter(|&i| i < messages.len())
479            .map(|i| estimate_tokens(&messages[i]))
480            .sum::<u32>();
481
482        if candidate_tokens <= target_tokens {
483            break;
484        }
485
486        // Find the oldest message and try to drop it
487        // But keep its pair if it's a tool_use/tool_result
488        if sorted_indices.len() > 2 {
489            let oldest_idx = sorted_indices[0];
490            let pair_indices = deps.get_pair_indices(oldest_idx);
491
492            // Only drop if dropping won't orphan a pair
493            if pair_indices.is_empty() || pair_indices.iter().all(|p| !sorted_indices.contains(p)) {
494                sorted_indices.remove(0);
495            } else {
496                // Can't drop this one, try next
497                sorted_indices.remove(1);
498            }
499        } else {
500            break;
501        }
502    }
503
504    let result: Vec<Message> = sorted_indices
505        .iter()
506        .filter(|&i| *i < messages.len())
507        .map(|&i| messages[i].clone())
508        .collect();
509
510    Ok(result)
511}
512
513// ============================================================================
514// Token Estimation
515// ============================================================================
516
517/// Count tokens for a message using tiktoken (accurate).
518pub fn estimate_tokens(message: &Message) -> u32 {
519    let content_text = match &message.content {
520        MessageContent::Text(t) => t.clone(),
521        MessageContent::Blocks(blocks) => {
522            let mut text = String::new();
523            for block in blocks {
524                match block {
525                    ContentBlock::Text { text: t } => {
526                        text.push_str(t);
527                        text.push(' '); // Separator
528                    }
529                    ContentBlock::ToolUse { name, input, .. } => {
530                        text.push_str(name);
531                        text.push(' '); // Separator
532                        text.push_str(&input.to_string());
533                        text.push(' '); // Separator
534                    }
535                    ContentBlock::ToolResult { content, .. } => {
536                        text.push_str(content);
537                        text.push(' '); // Separator
538                    }
539                    ContentBlock::Thinking { thinking, .. } => {
540                        text.push_str(thinking);
541                        text.push(' '); // Separator
542                    }
543                    _ => {}
544                }
545            }
546            text
547        }
548    };
549
550    // Use tiktoken for accurate token counting
551    let content_tokens = count_tokens(&content_text);
552    let role_tokens = count_tokens(&format!("{:?}: ", message.role));
553    
554    // Total = content + role prefix + overhead
555    content_tokens + role_tokens + message_overhead()
556}
557
558/// Estimate total tokens for a message list.
559pub fn estimate_total_tokens(messages: &[Message]) -> u32 {
560    messages.iter().map(estimate_tokens).sum()
561}
562
563/// Check if compression should be triggered.
564pub fn should_compress(
565    current_tokens: u32,
566    context_size: Option<u32>,
567    config: &CompressionConfig,
568) -> bool {
569    match context_size {
570        Some(size) => (current_tokens as f64 / size as f64) >= config.threshold,
571        None => false,
572    }
573}
574
575/// Build a prompt for summarization.
576pub fn build_summary_prompt(messages: &[Message]) -> String {
577    let history = messages
578        .iter()
579        .map(|m| {
580            let role = match m.role {
581                Role::User => "用户",
582                Role::Assistant => "助手",
583                Role::Tool => "工具",
584                Role::System => "系统",
585            };
586            let preview = match &m.content {
587                MessageContent::Text(t) => truncate_with_suffix(t, 200),
588                MessageContent::Blocks(blocks) => blocks
589                    .iter()
590                    .map(|b| match b {
591                        ContentBlock::Text { text } => truncate_with_suffix(text, 100),
592                        ContentBlock::ToolUse { name, .. } => format!("[工具: {}]", name),
593                        ContentBlock::ToolResult { content, .. } => {
594                            truncate_with_suffix(content, 100)
595                        }
596                        _ => "[...]".to_string(),
597                    })
598                    .collect::<Vec<_>>()
599                    .join(" | "),
600            };
601            format!("{}: {}", role, preview)
602        })
603        .collect::<Vec<_>>()
604        .join("\n");
605
606    format!(
607        "请将以下对话压缩为简洁摘要({} 条消息):\n{}",
608        messages.len(),
609        history
610    )
611}
612
613// ============================================================================
614// New Pipeline-Based Compression (Async)
615// ============================================================================
616
617use super::pipeline::CompressionPipeline;
618use super::types::AiCompressionMode;
619
620/// Compress messages with AI assistance (async version).
621///
622/// This is the new recommended API for compression with intelligent
623/// scoring, dependency tracking, and content summarization.
624pub async fn compress_messages_with_ai(
625    messages: &[Message],
626    config: &CompressionConfig,
627    ai_mode: AiCompressionMode,
628    fast_model: Option<Box<dyn Provider>>,
629    token_usage: u32,
630    context_window: u32,
631) -> Result<Vec<Message>> {
632    let mut pipeline = match (ai_mode, fast_model) {
633        (AiCompressionMode::None, _) => CompressionPipeline::new_rule_only(config.clone()),
634        (AiCompressionMode::Light | AiCompressionMode::Deep, Some(model)) => {
635            CompressionPipeline::new_with_ai(config.clone(), model)
636        }
637        _ => CompressionPipeline::new_rule_only(config.clone()),
638    };
639
640    let result = pipeline
641        .execute(messages, ai_mode, token_usage, context_window)
642        .await?;
643    Ok(result.messages)
644}
645
646/// Compress messages with full AI support (async version).
647///
648/// Uses both fast_model and main_model for different compression tasks.
649pub async fn compress_messages_with_full_ai(
650    messages: &[Message],
651    config: &CompressionConfig,
652    ai_mode: AiCompressionMode,
653    fast_model: Box<dyn Provider>,
654    main_model: Box<dyn Provider>,
655    token_usage: u32,
656    context_window: u32,
657) -> Result<Vec<Message>> {
658    let mut pipeline =
659        CompressionPipeline::new_with_full_ai(config.clone(), fast_model, main_model);
660
661    let result = pipeline
662        .execute(messages, ai_mode, token_usage, context_window)
663        .await?;
664    Ok(result.messages)
665}
666
667/// Score messages without compressing (analysis only).
668///
669/// Useful for debugging and understanding compression decisions.
670pub fn score_messages_only(
671    messages: &[Message],
672    config: &CompressionConfig,
673) -> Vec<super::types::ScoredMessage> {
674    let pipeline = CompressionPipeline::new_rule_only(config.clone());
675    pipeline.score_only(messages)
676}
677
678// ============================================================================
679// Tests
680// ============================================================================
681
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain-text message for tests.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_estimate_tokens_simple() {
        let msg = text_message(Role::User, "Hello world");
        assert!(estimate_tokens(&msg) >= 3);
    }

    #[test]
    fn test_estimate_total_tokens_empty() {
        assert_eq!(estimate_total_tokens(&[]), 0);
    }

    #[test]
    fn test_should_compress() {
        let config = CompressionConfig::default();
        // Threshold is 0.5, so 100K/200K = 0.5 triggers compression
        assert!(should_compress(100_000, Some(200_000), &config));
        // 80K/200K = 0.4, below threshold
        assert!(!should_compress(80_000, Some(200_000), &config));
    }

    #[test]
    fn test_should_compress_unknown_context_size() {
        let config = CompressionConfig::default();
        // Without a known context window there is nothing to compare against.
        assert!(!should_compress(1_000_000, None, &config));
    }

    #[test]
    fn test_contains_sensitive_instructions() {
        assert!(contains_sensitive_instructions("Do NOT delete the database"));
        assert!(contains_sensitive_instructions("请不要修改配置"));
        assert!(!contains_sensitive_instructions("please refactor this module"));
    }

    #[test]
    fn test_truncate_compress_keeps_tail() {
        let mut config = CompressionConfig::default();
        config.min_preserve_messages = 2;
        let messages: Vec<Message> = (0..5)
            .map(|i| text_message(Role::User, &format!("m{}", i)))
            .collect();

        let out = truncate_compress(&messages, &config).unwrap();
        assert_eq!(out.len(), 2);
        match &out[0].content {
            MessageContent::Text(t) => assert_eq!(t, "m3"),
            _ => panic!("expected a text message"),
        }
    }
}