Skip to main content

matrixcode_core/compress/
compressor.rs

1//! Compression functions and AI compressor implementation.
2
3use crate::providers::{
4    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
5};
6use crate::truncate::truncate_with_suffix;
7use anyhow::Result;
8use async_trait::async_trait;
9use std::collections::HashSet;
10
11use super::config::{CompressionBias, CompressionConfig};
12use super::types::{CompressionStrategy, SummarizedSegment};
13
14// ============================================================================
15// Compressor Trait
16// ============================================================================
17
18/// Compressor trait for different implementations.
19#[async_trait]
20pub trait Compressor: Send + Sync {
21    /// Compress messages using AI summarization.
22    async fn summarize(
23        &self,
24        messages: &[Message],
25        config: &CompressionConfig,
26    ) -> Result<SummarizedSegment>;
27
28    /// Get the model name used.
29    fn model_name(&self) -> &str;
30}
31
32/// AI-based compressor using a Provider.
33pub struct AiCompressor {
34    provider: Box<dyn Provider>,
35    model: String,
36}
37
38impl AiCompressor {
39    pub fn new(provider: Box<dyn Provider>, model: String) -> Self {
40        Self { provider, model }
41    }
42}
43
/// System prompt for the summarization request (written in Chinese).
///
/// It instructs the model to:
/// - answer with plain text only and never call tools;
/// - organize its reasoning inside an `<analysis>` block before the summary;
/// - emit the summary as up to nine `【…】` sections (摘要, 已完成, 未完成,
///   关键决策, 敏感指令, 技术栈, 文件变更, 问题记录, 下一步), each kept
///   within 100 characters, with empty sections omitted;
/// - preserve the user's prohibition/obligation instructions verbatim;
/// - stop immediately after the summary.
///
/// NOTE(review): because the reply starts with an `<analysis>` block, whoever
/// parses it must skip that block, and it presumably counts against the
/// request's output-token limit — confirm the budget is sufficient.
const SUMMARY_SYSTEM_PROMPT: &str = r#"CRITICAL: 仅用文本响应。不要调用任何工具。

- 不要使用 read、bash、grep、glob、edit、write 或任何其他工具
- 你已在上方对话中获得所需的所有上下文
- 工具调用将被拒绝并浪费你唯一的 turn — 你将失败任务
- 你的整个响应必须是纯文本摘要

---

你是一个对话历史压缩助手。将对话压缩为结构化摘要。

在提供最终摘要前,将你的分析包裹在 <analysis> 标签中以组织思路:
<analysis>
1. 按时间顺序分析每条消息
2. 识别:用户请求、助手行动、关键决策、错误及修复
3. 特别注意用户的敏感指令(禁止/必须)
4. 双重检查技术准确性和完整性
</analysis>

输出要求:
- 结构化:使用9个章节格式
- 关键:只保留重要信息,忽略无关细节
- 敏感:必须保留用户的敏感指令(禁止、必须等)— 原样保留
- 任务:必须保留未完成的待办事项
- 决策:必须保留关键方案选择和理由

9章节输出格式:
【摘要】一句话概括主要工作(50字以内)
【已完成】列出已完成的操作(工具调用、文件变更)
【未完成】列出待办任务和阻塞项 — 最关键,压缩后恢复需要
【关键决策】重要选择及理由(技术选型、方案决策)
【敏感指令】用户的禁止/必须指令(必须原样保留)
【技术栈】使用的语言、框架、库、工具
【文件变更】读取、修改、创建的文件路径
【问题记录】遇到的问题及解决方案
【下一步】建议的下一步操作(直接引用最近对话展示任务中断点)

每章节控制在100字以内,空章节可省略。
输出摘要后立即停止,不要添加任何解释或后续建议。

REMINDER: 不要调用任何工具。仅用纯文本响应。"#;
85
86#[async_trait]
87impl Compressor for AiCompressor {
88    async fn summarize(
89        &self,
90        messages: &[Message],
91        _config: &CompressionConfig,
92    ) -> Result<SummarizedSegment> {
93        let prompt = build_summary_prompt(messages);
94
95        let request = ChatRequest {
96            messages: vec![Message {
97                role: Role::User,
98                content: MessageContent::Text(prompt),
99            }],
100            tools: vec![],
101            system: Some(SUMMARY_SYSTEM_PROMPT.to_string()),
102            think: false,
103            max_tokens: 1024,
104            server_tools: vec![],
105            enable_caching: false,
106        };
107
108        let response = self.provider.chat(request).await?;
109        let summary_text = extract_text_from_response(&response);
110        let (summary, key_points) = parse_summary_response(&summary_text);
111
112        Ok(SummarizedSegment {
113            time_range: (chrono::Utc::now(), chrono::Utc::now()),
114            original_count: messages.len(),
115            summary,
116            key_points,
117        })
118    }
119
120    fn model_name(&self) -> &str {
121        &self.model
122    }
123}
124
125fn extract_text_from_response(response: &ChatResponse) -> String {
126    response
127        .content
128        .iter()
129        .filter_map(|block| {
130            if let ContentBlock::Text { text } = block {
131                Some(text.clone())
132            } else {
133                None
134            }
135        })
136        .collect::<Vec<_>>()
137        .join("\n")
138}
139
140fn parse_summary_response(text: &str) -> (String, Vec<String>) {
141    let mut summary = String::new();
142    let mut key_points: Vec<String> = Vec::new();
143
144    // Parse 9-section structured format
145    let sections = [
146        "【摘要】",
147        "【已完成】",
148        "【未完成】",
149        "【关键决策】",
150        "【敏感指令】",
151        "【技术栈】",
152        "【文件变更】",
153        "【问题记录】",
154        "【下一步】",
155    ];
156
157    for line in text.lines() {
158        let line = line.trim();
159
160        // Check if this is a section header
161        let is_header = sections.iter().any(|s| line.starts_with(s));
162
163        if is_header {
164            // Extract content after the header
165            for section in &sections {
166                if line.starts_with(section) {
167                    let replaced = line.replace(section, "");
168                    let content = replaced.trim();
169                    if !content.is_empty() {
170                        if *section == "【摘要】" {
171                            summary = content.to_string();
172                        } else {
173                            key_points.push(format!("{}{}", section, content));
174                        }
175                    }
176                    break;
177                }
178            }
179        } else if !line.is_empty() {
180            // This is content under a section
181            if line.starts_with("•") || line.starts_with("-") || line.starts_with("*") {
182                let point = line.trim_start_matches(['•', '-', '*']).trim();
183                if !point.is_empty() {
184                    key_points.push(point.to_string());
185                }
186            } else if summary.is_empty() {
187                // Fallback: first non-empty line as summary
188                summary = line.to_string();
189            }
190        }
191    }
192
193    // Fallback if no structured format found
194    if summary.is_empty() && !text.is_empty() {
195        summary = text.lines().take(3).collect::<Vec<_>>().join(" ");
196        if summary.len() > 200 {
197            summary = truncate_with_suffix(&summary, 200);
198        }
199    }
200
201    (summary, key_points)
202}
203
204// ============================================================================
205// Compression Functions
206// ============================================================================
207
208/// Compress messages synchronously.
209pub fn compress_messages(
210    messages: &[Message],
211    strategy: CompressionStrategy,
212    config: &CompressionConfig,
213) -> Result<Vec<Message>> {
214    match strategy {
215        CompressionStrategy::Truncate => truncate_compress(messages, config),
216        CompressionStrategy::SlidingWindow => sliding_window_compress(messages, config),
217        CompressionStrategy::Summarize => sliding_window_compress(messages, config),
218        CompressionStrategy::BiasBased => compress_with_bias(messages, config),
219    }
220}
221
222/// Compress with bias-based scoring.
223pub fn compress_with_bias(
224    messages: &[Message],
225    config: &CompressionConfig,
226) -> Result<Vec<Message>> {
227    if messages.len() <= config.min_preserve_messages {
228        return Ok(messages.to_vec());
229    }
230
231    let scored: Vec<(usize, Message, f64)> = messages
232        .iter()
233        .enumerate()
234        .map(|(idx, msg)| {
235            (
236                idx,
237                msg.clone(),
238                calculate_preservation_score(msg, idx, messages.len(), &config.bias),
239            )
240        })
241        .collect();
242
243    let mut scored_with_recency: Vec<(usize, Message, f64)> = scored
244        .into_iter()
245        .map(|(idx, msg, score)| {
246            let recency_bonus = if idx >= messages.len() - config.min_preserve_messages {
247                100.0
248            } else {
249                (idx as f64 / messages.len() as f64) * 20.0
250            };
251            (idx, msg, score + recency_bonus)
252        })
253        .collect();
254
255    scored_with_recency.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
256
257    let target_count = if config.bias.aggressive {
258        config.min_preserve_messages
259    } else {
260        let estimated = estimate_total_tokens(messages);
261        let target_tokens = (estimated as f64 * config.target_ratio) as u32;
262        let avg = estimated / messages.len() as u32;
263        (target_tokens / avg.max(1)) as usize
264    };
265
266    let to_keep: HashSet<usize> = scored_with_recency
267        .iter()
268        .take(target_count)
269        .map(|(idx, _, _)| *idx)
270        .collect();
271
272    let compressed: Vec<Message> = messages
273        .iter()
274        .enumerate()
275        .filter(|(idx, _)| to_keep.contains(idx))
276        .map(|(_, msg)| msg.clone())
277        .collect();
278
279    Ok(compressed)
280}
281
282fn calculate_preservation_score(
283    message: &Message,
284    index: usize,
285    _total: usize, // Reserved for future use (total message count)
286    bias: &CompressionBias,
287) -> f64 {
288    let mut score: f64 = 10.0;
289
290    // First message (user's original request) gets highest priority
291    if index == 0 {
292        score += 100.0;
293    }
294
295    match message.role {
296        Role::User => {
297            if bias.preserve_user_questions {
298                score += 30.0;
299            }
300        }
301        Role::Assistant => {
302            score += 5.0;
303        }
304        Role::Tool => {
305            if bias.preserve_tools {
306                score += 25.0;
307            }
308        }
309        Role::System => {
310            score += 40.0;
311        }
312    }
313
314    match &message.content {
315        MessageContent::Text(text) => {
316            for keyword in &bias.preserve_keywords {
317                if text.to_lowercase().contains(&keyword.to_lowercase()) {
318                    score += 15.0;
319                }
320            }
321            if contains_sensitive_instructions(text) {
322                score += 50.0;
323            }
324        }
325        MessageContent::Blocks(blocks) => {
326            for block in blocks {
327                match block {
328                    ContentBlock::ToolUse { name, .. } => {
329                        if bias.preserve_tools {
330                            score += 20.0;
331                        }
332                        if name == "write" || name == "edit" || name == "bash" {
333                            score += 10.0;
334                        }
335                        // todo_write gets high priority - preserve task tracking
336                        if name == "todo_write" {
337                            score += 60.0;
338                        }
339                        // ask tool contains key decisions
340                        if name == "ask" {
341                            score += 50.0;
342                        }
343                    }
344                    ContentBlock::ToolResult { content, .. } => {
345                        if bias.preserve_tools {
346                            score += 20.0;
347                        }
348                        if contains_sensitive_instructions(content) {
349                            score += 30.0;
350                        }
351                        // Preserve todo_write results (task status)
352                        if content.contains("TodoWrite") || content.contains("todo") {
353                            score += 40.0;
354                        }
355                        // Preserve ask responses (user decisions)
356                        if content.contains("AskUserQuestion") || content.contains("answer") {
357                            score += 30.0;
358                        }
359                    }
360                    ContentBlock::Thinking { .. } => {
361                        if bias.preserve_thinking {
362                            score += 25.0;
363                        } else {
364                            score -= 5.0;
365                        }
366                    }
367                    ContentBlock::Text { text } => {
368                        if contains_sensitive_instructions(text) {
369                            score += 50.0;
370                        }
371                    }
372                    _ => {}
373                }
374            }
375        }
376    }
377
378    score
379}
380
/// Whether `text` contains a prohibition/obligation marker.
///
/// The check is case-insensitive against a fixed list of Chinese and English
/// phrases: "不要", "禁止", "必须", "不允许", "never", "must not", "do not".
fn contains_sensitive_instructions(text: &str) -> bool {
    const MARKERS: [&str; 7] = ["不要", "禁止", "必须", "不允许", "never", "must not", "do not"];

    let haystack = text.to_lowercase();
    MARKERS.iter().any(|marker| haystack.contains(marker))
}
394
395fn truncate_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
396    if messages.len() <= config.min_preserve_messages {
397        return Ok(messages.to_vec());
398    }
399    Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
400}
401
402fn sliding_window_compress(
403    messages: &[Message],
404    config: &CompressionConfig,
405) -> Result<Vec<Message>> {
406    if messages.len() <= config.min_preserve_messages {
407        return Ok(messages.to_vec());
408    }
409
410    // Enhanced sliding window strategy:
411    // 1. Always keep first message (original user request)
412    // 2. Summarize middle messages if too long
413    // 3. Keep recent messages intact
414
415    let first_msg = messages.first().cloned();
416    let recent_start = messages.len().saturating_sub(config.min_preserve_messages);
417    let recent_msgs = &messages[recent_start..];
418
419    // Calculate tokens for first + recent
420    let first_tokens = first_msg.as_ref().map(estimate_tokens).unwrap_or(0);
421    let recent_tokens = estimate_total_tokens(recent_msgs);
422    let current_total = estimate_total_tokens(messages);
423    let target_tokens = (current_total as f64 * config.target_ratio) as u32;
424
425    // If first + recent already exceeds target, just use recent (drop first)
426    if first_tokens + recent_tokens <= target_tokens {
427        // We can keep first message + recent messages
428        let mut result: Vec<Message> = Vec::new();
429        if let Some(first) = first_msg {
430            result.push(first);
431        }
432        result.extend(recent_msgs.iter().cloned());
433        return Ok(result);
434    }
435
436    // If still too long, try dropping older messages from recent section
437    for drop_count in 0..recent_msgs.len() {
438        let candidate = &recent_msgs[drop_count..];
439        if estimate_total_tokens(candidate) <= target_tokens {
440            return Ok(candidate.to_vec());
441        }
442    }
443
444    // Last resort: just keep minimum recent messages
445    Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
446}
447
448// ============================================================================
449// Token Estimation
450// ============================================================================
451
452/// Estimate token count for a message.
453pub fn estimate_tokens(message: &Message) -> u32 {
454    let (ascii, non_ascii) = match &message.content {
455        MessageContent::Text(t) => count_chars(t),
456        MessageContent::Blocks(blocks) => {
457            let mut a = 0u32;
458            let mut n = 0u32;
459            for block in blocks {
460                match block {
461                    ContentBlock::Text { text } => {
462                        let (ca, cn) = count_chars(text);
463                        a += ca;
464                        n += cn;
465                    }
466                    ContentBlock::ToolUse { name, input, .. } => {
467                        let (ca, cn) = count_chars(name);
468                        a += ca;
469                        n += cn;
470                        let (ja, jn) = count_chars(&input.to_string());
471                        a += ja;
472                        n += jn;
473                    }
474                    ContentBlock::ToolResult { content, .. } => {
475                        let (ca, cn) = count_chars(content);
476                        a += ca;
477                        n += cn;
478                    }
479                    ContentBlock::Thinking { thinking, .. } => {
480                        let (ca, cn) = count_chars(thinking);
481                        a += ca;
482                        n += cn;
483                    }
484                    _ => {}
485                }
486            }
487            (a, n)
488        }
489    };
490
491    let ascii_tokens = (ascii as f64 * 0.25).ceil() as u32;
492    let non_ascii_tokens = (non_ascii as f64 * 0.67).ceil() as u32;
493    (ascii_tokens + non_ascii_tokens + 10).max(1)
494}
495
/// Count `(ascii, non_ascii)` characters (Unicode scalar values) in `s`.
fn count_chars(s: &str) -> (u32, u32) {
    s.chars().fold((0u32, 0u32), |(ascii, other), ch| {
        if ch.is_ascii() {
            (ascii + 1, other)
        } else {
            (ascii, other + 1)
        }
    })
}
508
509/// Estimate total tokens for a message list.
510pub fn estimate_total_tokens(messages: &[Message]) -> u32 {
511    messages.iter().map(estimate_tokens).sum()
512}
513
514/// Check if compression should be triggered.
515pub fn should_compress(
516    current_tokens: u32,
517    context_size: Option<u32>,
518    config: &CompressionConfig,
519) -> bool {
520    match context_size {
521        Some(size) => (current_tokens as f64 / size as f64) >= config.threshold,
522        None => false,
523    }
524}
525
526/// Build a prompt for summarization.
527pub fn build_summary_prompt(messages: &[Message]) -> String {
528    let history = messages
529        .iter()
530        .map(|m| {
531            let role = match m.role {
532                Role::User => "用户",
533                Role::Assistant => "助手",
534                Role::Tool => "工具",
535                Role::System => "系统",
536            };
537            let preview = match &m.content {
538                MessageContent::Text(t) => truncate_with_suffix(t, 200),
539                MessageContent::Blocks(blocks) => blocks
540                    .iter()
541                    .map(|b| match b {
542                        ContentBlock::Text { text } => truncate_with_suffix(text, 100),
543                        ContentBlock::ToolUse { name, .. } => format!("[工具: {}]", name),
544                        ContentBlock::ToolResult { content, .. } => {
545                            truncate_with_suffix(content, 100)
546                        }
547                        _ => "[...]".to_string(),
548                    })
549                    .collect::<Vec<_>>()
550                    .join(" | "),
551            };
552            format!("{}: {}", role, preview)
553        })
554        .collect::<Vec<_>>()
555        .join("\n");
556
557    format!(
558        "请将以下对话压缩为简洁摘要({} 条消息):\n{}",
559        messages.len(),
560        history
561    )
562}
563
564// ============================================================================
565// New Pipeline-Based Compression (Async)
566// ============================================================================
567
568use super::pipeline::CompressionPipeline;
569use super::types::AiCompressionMode;
570
571/// Compress messages with AI assistance (async version).
572///
573/// This is the new recommended API for compression with intelligent
574/// scoring, dependency tracking, and content summarization.
575pub async fn compress_messages_with_ai(
576    messages: &[Message],
577    config: &CompressionConfig,
578    ai_mode: AiCompressionMode,
579    fast_model: Option<Box<dyn Provider>>,
580    token_usage: u32,
581    context_window: u32,
582) -> Result<Vec<Message>> {
583    let mut pipeline = match (ai_mode, fast_model) {
584        (AiCompressionMode::None, _) => CompressionPipeline::new_rule_only(config.clone()),
585        (AiCompressionMode::Light | AiCompressionMode::Deep, Some(model)) => {
586            CompressionPipeline::new_with_ai(config.clone(), model)
587        }
588        _ => CompressionPipeline::new_rule_only(config.clone()),
589    };
590
591    let result = pipeline
592        .execute(messages, ai_mode, token_usage, context_window)
593        .await?;
594    Ok(result.messages)
595}
596
597/// Compress messages with full AI support (async version).
598///
599/// Uses both fast_model and main_model for different compression tasks.
600pub async fn compress_messages_with_full_ai(
601    messages: &[Message],
602    config: &CompressionConfig,
603    ai_mode: AiCompressionMode,
604    fast_model: Box<dyn Provider>,
605    main_model: Box<dyn Provider>,
606    token_usage: u32,
607    context_window: u32,
608) -> Result<Vec<Message>> {
609    let mut pipeline =
610        CompressionPipeline::new_with_full_ai(config.clone(), fast_model, main_model);
611
612    let result = pipeline
613        .execute(messages, ai_mode, token_usage, context_window)
614        .await?;
615    Ok(result.messages)
616}
617
618/// Score messages without compressing (analysis only).
619///
620/// Useful for debugging and understanding compression decisions.
621pub fn score_messages_only(
622    messages: &[Message],
623    config: &CompressionConfig,
624) -> Vec<super::types::ScoredMessage> {
625    let pipeline = CompressionPipeline::new_rule_only(config.clone());
626    pipeline.score_only(messages)
627}
628
629// ============================================================================
630// Tests
631// ============================================================================
632
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a plain-text user message.
    fn user_text(text: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_estimate_tokens_simple() {
        let estimate = estimate_tokens(&user_text("Hello world"));
        assert!(estimate >= 3);
    }

    #[test]
    fn test_should_compress() {
        let config = CompressionConfig::default();
        // The default threshold is 0.5: exactly half the window triggers...
        assert!(should_compress(100_000, Some(200_000), &config));
        // ...while 0.4 of the window does not.
        assert!(!should_compress(80_000, Some(200_000), &config));
    }
}