// matrixcode_core/compress/compressor.rs

//! Compression functions and AI compressor implementation.
2
3use crate::providers::{
4    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
5};
6use crate::truncate::truncate_with_suffix;
7use anyhow::Result;
8use async_trait::async_trait;
9use std::collections::HashSet;
10
11use super::config::{CompressionBias, CompressionConfig};
12use super::types::{CompressionStrategy, SummarizedSegment};
13
14// ============================================================================
15// Compressor Trait
16// ============================================================================
17
18/// Compressor trait for different implementations.
19#[async_trait]
20pub trait Compressor: Send + Sync {
21    /// Compress messages using AI summarization.
22    async fn summarize(
23        &self,
24        messages: &[Message],
25        config: &CompressionConfig,
26    ) -> Result<SummarizedSegment>;
27
28    /// Get the model name used.
29    fn model_name(&self) -> &str;
30}
31
32/// AI-based compressor using a Provider.
33pub struct AiCompressor {
34    provider: Box<dyn Provider>,
35    model: String,
36}
37
38impl AiCompressor {
39    pub fn new(provider: Box<dyn Provider>, model: String) -> Self {
40        Self { provider, model }
41    }
42}
43
/// System prompt sent with every summarization request made by `AiCompressor`.
///
/// It asks the model (in Chinese) for a structured summary made of nine `【…】`
/// sections, each kept within 100 characters, with empty sections omitted.
/// The section headers listed here must stay in sync with the ones that
/// `parse_summary_response` recognises.
const SUMMARY_SYSTEM_PROMPT: &str = r#"你是一个对话历史压缩助手。将对话压缩为结构化摘要。

输出要求:
- 结构化:使用9个章节格式
- 关键:只保留重要信息,忽略无关细节
- 敏感:必须保留用户的敏感指令(禁止、必须等)
- 任务:必须保留未完成的待办事项
- 决策:必须保留关键方案选择和理由

9章节输出格式:
【摘要】一句话概括主要工作(50字以内)
【已完成】列出已完成的操作(工具调用、文件变更)
【未完成】列出待办任务和阻塞项
【关键决策】重要选择及理由(技术选型、方案决策)
【敏感指令】用户的禁止/必须指令(必须原样保留)
【技术栈】使用的语言、框架、库、工具
【文件变更】读取、修改、创建的文件路径
【问题记录】遇到的问题及解决方案
【下一步】建议的下一步操作

每章节控制在100字以内,空章节可省略。
请直接输出内容。"#;
66
67#[async_trait]
68impl Compressor for AiCompressor {
69    async fn summarize(
70        &self,
71        messages: &[Message],
72        _config: &CompressionConfig,
73    ) -> Result<SummarizedSegment> {
74        let prompt = build_summary_prompt(messages);
75
76        let request = ChatRequest {
77            messages: vec![Message {
78                role: Role::User,
79                content: MessageContent::Text(prompt),
80            }],
81            tools: vec![],
82            system: Some(SUMMARY_SYSTEM_PROMPT.to_string()),
83            think: false,
84            max_tokens: 1024,
85            server_tools: vec![],
86            enable_caching: false,
87        };
88
89        let response = self.provider.chat(request).await?;
90        let summary_text = extract_text_from_response(&response);
91        let (summary, key_points) = parse_summary_response(&summary_text);
92
93        Ok(SummarizedSegment {
94            time_range: (chrono::Utc::now(), chrono::Utc::now()),
95            original_count: messages.len(),
96            summary,
97            key_points,
98        })
99    }
100
101    fn model_name(&self) -> &str {
102        &self.model
103    }
104}
105
106fn extract_text_from_response(response: &ChatResponse) -> String {
107    response
108        .content
109        .iter()
110        .filter_map(|block| {
111            if let ContentBlock::Text { text } = block {
112                Some(text.clone())
113            } else {
114                None
115            }
116        })
117        .collect::<Vec<_>>()
118        .join("\n")
119}
120
121fn parse_summary_response(text: &str) -> (String, Vec<String>) {
122    let mut summary = String::new();
123    let mut key_points: Vec<String> = Vec::new();
124
125    // Parse 9-section structured format
126    let sections = [
127        "【摘要】", "【已完成】", "【未完成】", "【关键决策】",
128        "【敏感指令】", "【技术栈】", "【文件变更】", "【问题记录】", "【下一步】"
129    ];
130
131    for line in text.lines() {
132        let line = line.trim();
133
134        // Check if this is a section header
135        let is_header = sections.iter().any(|s| line.starts_with(s));
136
137        if is_header {
138            // Extract content after the header
139            for section in &sections {
140                if line.starts_with(section) {
141                    let replaced = line.replace(section, "");
142                    let content = replaced.trim();
143                    if !content.is_empty() {
144                        if *section == "【摘要】" {
145                            summary = content.to_string();
146                        } else {
147                            key_points.push(format!("{}{}", section, content));
148                        }
149                    }
150                    break;
151                }
152            }
153        } else if !line.is_empty() {
154            // This is content under a section
155            if line.starts_with("•") || line.starts_with("-") || line.starts_with("*") {
156                let point = line.trim_start_matches(['•', '-', '*']).trim();
157                if !point.is_empty() {
158                    key_points.push(point.to_string());
159                }
160            } else if summary.is_empty() {
161                // Fallback: first non-empty line as summary
162                summary = line.to_string();
163            }
164        }
165    }
166
167    // Fallback if no structured format found
168    if summary.is_empty() && !text.is_empty() {
169        summary = text.lines().take(3).collect::<Vec<_>>().join(" ");
170        if summary.len() > 200 {
171            summary = truncate_with_suffix(&summary, 200);
172        }
173    }
174
175    (summary, key_points)
176}
177
178// ============================================================================
179// Compression Functions
180// ============================================================================
181
182/// Compress messages synchronously.
183pub fn compress_messages(
184    messages: &[Message],
185    strategy: CompressionStrategy,
186    config: &CompressionConfig,
187) -> Result<Vec<Message>> {
188    match strategy {
189        CompressionStrategy::Truncate => truncate_compress(messages, config),
190        CompressionStrategy::SlidingWindow => sliding_window_compress(messages, config),
191        CompressionStrategy::Summarize => sliding_window_compress(messages, config),
192        CompressionStrategy::BiasBased => compress_with_bias(messages, config),
193    }
194}
195
196/// Compress with bias-based scoring.
197pub fn compress_with_bias(
198    messages: &[Message],
199    config: &CompressionConfig,
200) -> Result<Vec<Message>> {
201    if messages.len() <= config.min_preserve_messages {
202        return Ok(messages.to_vec());
203    }
204
205    let scored: Vec<(usize, Message, f64)> = messages
206        .iter()
207        .enumerate()
208        .map(|(idx, msg)| {
209            (
210                idx,
211                msg.clone(),
212                calculate_preservation_score(msg, idx, messages.len(), &config.bias),
213            )
214        })
215        .collect();
216
217    let mut scored_with_recency: Vec<(usize, Message, f64)> = scored
218        .into_iter()
219        .map(|(idx, msg, score)| {
220            let recency_bonus = if idx >= messages.len() - config.min_preserve_messages {
221                100.0
222            } else {
223                (idx as f64 / messages.len() as f64) * 20.0
224            };
225            (idx, msg, score + recency_bonus)
226        })
227        .collect();
228
229    scored_with_recency.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
230
231    let target_count = if config.bias.aggressive {
232        config.min_preserve_messages
233    } else {
234        let estimated = estimate_total_tokens(messages);
235        let target_tokens = (estimated as f64 * config.target_ratio) as u32;
236        let avg = estimated / messages.len() as u32;
237        (target_tokens / avg.max(1)) as usize
238    };
239
240    let to_keep: HashSet<usize> = scored_with_recency
241        .iter()
242        .take(target_count)
243        .map(|(idx, _, _)| *idx)
244        .collect();
245
246    let compressed: Vec<Message> = messages
247        .iter()
248        .enumerate()
249        .filter(|(idx, _)| to_keep.contains(idx))
250        .map(|(_, msg)| msg.clone())
251        .collect();
252
253    Ok(compressed)
254}
255
256fn calculate_preservation_score(
257    message: &Message,
258    index: usize,
259    _total: usize,  // Reserved for future use (total message count)
260    bias: &CompressionBias,
261) -> f64 {
262    let mut score: f64 = 10.0;
263
264    // First message (user's original request) gets highest priority
265    if index == 0 {
266        score += 100.0;
267    }
268
269    match message.role {
270        Role::User => {
271            if bias.preserve_user_questions {
272                score += 30.0;
273            }
274        }
275        Role::Assistant => {
276            score += 5.0;
277        }
278        Role::Tool => {
279            if bias.preserve_tools {
280                score += 25.0;
281            }
282        }
283        Role::System => {
284            score += 40.0;
285        }
286    }
287
288    match &message.content {
289        MessageContent::Text(text) => {
290            for keyword in &bias.preserve_keywords {
291                if text.to_lowercase().contains(&keyword.to_lowercase()) {
292                    score += 15.0;
293                }
294            }
295            if contains_sensitive_instructions(text) {
296                score += 50.0;
297            }
298        }
299        MessageContent::Blocks(blocks) => {
300            for block in blocks {
301                match block {
302                    ContentBlock::ToolUse { name, .. } => {
303                        if bias.preserve_tools {
304                            score += 20.0;
305                        }
306                        if name == "write" || name == "edit" || name == "bash" {
307                            score += 10.0;
308                        }
309                        // todo_write gets high priority - preserve task tracking
310                        if name == "todo_write" {
311                            score += 60.0;
312                        }
313                        // ask tool contains key decisions
314                        if name == "ask" {
315                            score += 50.0;
316                        }
317                    }
318                    ContentBlock::ToolResult { content, .. } => {
319                        if bias.preserve_tools {
320                            score += 20.0;
321                        }
322                        if contains_sensitive_instructions(content) {
323                            score += 30.0;
324                        }
325                        // Preserve todo_write results (task status)
326                        if content.contains("TodoWrite") || content.contains("todo") {
327                            score += 40.0;
328                        }
329                        // Preserve ask responses (user decisions)
330                        if content.contains("AskUserQuestion") || content.contains("answer") {
331                            score += 30.0;
332                        }
333                    }
334                    ContentBlock::Thinking { .. } => {
335                        if bias.preserve_thinking {
336                            score += 25.0;
337                        } else {
338                            score -= 5.0;
339                        }
340                    }
341                    ContentBlock::Text { text } => {
342                        if contains_sensitive_instructions(text) {
343                            score += 50.0;
344                        }
345                    }
346                    _ => {}
347                }
348            }
349        }
350    }
351
352    score
353}
354
/// Heuristic check for user constraints such as prohibitions or hard requirements.
///
/// Performs a case-insensitive substring match against a fixed list of Chinese
/// and English markers: `不要`, `禁止`, `必须`, `不允许`, `never`, `must not`, `do not`.
fn contains_sensitive_instructions(text: &str) -> bool {
    const MARKERS: [&str; 7] = [
        "不要",
        "禁止",
        "必须",
        "不允许",
        "never",
        "must not",
        "do not",
    ];
    let haystack = text.to_lowercase();
    for marker in MARKERS {
        if haystack.contains(marker) {
            return true;
        }
    }
    false
}
368
369fn truncate_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
370    if messages.len() <= config.min_preserve_messages {
371        return Ok(messages.to_vec());
372    }
373    Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
374}
375
376fn sliding_window_compress(
377    messages: &[Message],
378    config: &CompressionConfig,
379) -> Result<Vec<Message>> {
380    if messages.len() <= config.min_preserve_messages {
381        return Ok(messages.to_vec());
382    }
383
384    // Enhanced sliding window strategy:
385    // 1. Always keep first message (original user request)
386    // 2. Summarize middle messages if too long
387    // 3. Keep recent messages intact
388
389    let first_msg = messages.first().cloned();
390    let recent_start = messages.len().saturating_sub(config.min_preserve_messages);
391    let recent_msgs = &messages[recent_start..];
392
393    // Calculate tokens for first + recent
394    let first_tokens = first_msg.as_ref().map(|m| estimate_tokens(m)).unwrap_or(0);
395    let recent_tokens = estimate_total_tokens(recent_msgs);
396    let current_total = estimate_total_tokens(messages);
397    let target_tokens = (current_total as f64 * config.target_ratio) as u32;
398
399    // If first + recent already exceeds target, just use recent (drop first)
400    if first_tokens + recent_tokens <= target_tokens {
401        // We can keep first message + recent messages
402        let mut result: Vec<Message> = Vec::new();
403        if let Some(first) = first_msg {
404            result.push(first);
405        }
406        result.extend(recent_msgs.iter().cloned());
407        return Ok(result);
408    }
409
410    // If still too long, try dropping older messages from recent section
411    for drop_count in 0..recent_msgs.len() {
412        let candidate = &recent_msgs[drop_count..];
413        if estimate_total_tokens(candidate) <= target_tokens {
414            return Ok(candidate.to_vec());
415        }
416    }
417
418    // Last resort: just keep minimum recent messages
419    Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
420}
421
422// ============================================================================
423// Token Estimation
424// ============================================================================
425
426/// Estimate token count for a message.
427pub fn estimate_tokens(message: &Message) -> u32 {
428    let (ascii, non_ascii) = match &message.content {
429        MessageContent::Text(t) => count_chars(t),
430        MessageContent::Blocks(blocks) => {
431            let mut a = 0u32;
432            let mut n = 0u32;
433            for block in blocks {
434                match block {
435                    ContentBlock::Text { text } => {
436                        let (ca, cn) = count_chars(text);
437                        a += ca;
438                        n += cn;
439                    }
440                    ContentBlock::ToolUse { name, input, .. } => {
441                        let (ca, cn) = count_chars(name);
442                        a += ca;
443                        n += cn;
444                        let (ja, jn) = count_chars(&input.to_string());
445                        a += ja;
446                        n += jn;
447                    }
448                    ContentBlock::ToolResult { content, .. } => {
449                        let (ca, cn) = count_chars(content);
450                        a += ca;
451                        n += cn;
452                    }
453                    ContentBlock::Thinking { thinking, .. } => {
454                        let (ca, cn) = count_chars(thinking);
455                        a += ca;
456                        n += cn;
457                    }
458                    _ => {}
459                }
460            }
461            (a, n)
462        }
463    };
464
465    let ascii_tokens = (ascii as f64 * 0.25).ceil() as u32;
466    let non_ascii_tokens = (non_ascii as f64 * 0.67).ceil() as u32;
467    (ascii_tokens + non_ascii_tokens + 10).max(1)
468}
469
/// Count the characters of `s`, returned as `(ascii, non_ascii)`.
fn count_chars(s: &str) -> (u32, u32) {
    s.chars().fold((0u32, 0u32), |(ascii, other), ch| {
        if ch.is_ascii() {
            (ascii + 1, other)
        } else {
            (ascii, other + 1)
        }
    })
}
482
483/// Estimate total tokens for a message list.
484pub fn estimate_total_tokens(messages: &[Message]) -> u32 {
485    messages.iter().map(estimate_tokens).sum()
486}
487
488/// Check if compression should be triggered.
489pub fn should_compress(
490    current_tokens: u32,
491    context_size: Option<u32>,
492    config: &CompressionConfig,
493) -> bool {
494    match context_size {
495        Some(size) => (current_tokens as f64 / size as f64) >= config.threshold,
496        None => false,
497    }
498}
499
500/// Build a prompt for summarization.
501pub fn build_summary_prompt(messages: &[Message]) -> String {
502    let history = messages
503        .iter()
504        .map(|m| {
505            let role = match m.role {
506                Role::User => "用户",
507                Role::Assistant => "助手",
508                Role::Tool => "工具",
509                Role::System => "系统",
510            };
511            let preview = match &m.content {
512                MessageContent::Text(t) => truncate_with_suffix(t, 200),
513                MessageContent::Blocks(blocks) => blocks
514                    .iter()
515                    .map(|b| match b {
516                        ContentBlock::Text { text } => truncate_with_suffix(text, 100),
517                        ContentBlock::ToolUse { name, .. } => format!("[工具: {}]", name),
518                        ContentBlock::ToolResult { content, .. } => {
519                            truncate_with_suffix(content, 100)
520                        }
521                        _ => "[...]".to_string(),
522                    })
523                    .collect::<Vec<_>>()
524                    .join(" | "),
525            };
526            format!("{}: {}", role, preview)
527        })
528        .collect::<Vec<_>>()
529        .join("\n");
530
531    format!(
532        "请将以下对话压缩为简洁摘要({} 条消息):\n{}",
533        messages.len(),
534        history
535    )
536}
537
538// ============================================================================
539// New Pipeline-Based Compression (Async)
540// ============================================================================
541
542use super::pipeline::CompressionPipeline;
543use super::types::AiCompressionMode;
544
545/// Compress messages with AI assistance (async version).
546///
547/// This is the new recommended API for compression with intelligent
548/// scoring, dependency tracking, and content summarization.
549pub async fn compress_messages_with_ai(
550    messages: &[Message],
551    config: &CompressionConfig,
552    ai_mode: AiCompressionMode,
553    fast_model: Option<Box<dyn Provider>>,
554    token_usage: u32,
555    context_window: u32,
556) -> Result<Vec<Message>> {
557    let mut pipeline = match (ai_mode, fast_model) {
558        (AiCompressionMode::None, _) => CompressionPipeline::new_rule_only(config.clone()),
559        (AiCompressionMode::Light | AiCompressionMode::Deep, Some(model)) => {
560            CompressionPipeline::new_with_ai(config.clone(), model)
561        }
562        _ => CompressionPipeline::new_rule_only(config.clone()),
563    };
564
565    let result = pipeline.execute(messages, ai_mode, token_usage, context_window).await?;
566    Ok(result.messages)
567}
568
569/// Compress messages with full AI support (async version).
570///
571/// Uses both fast_model and main_model for different compression tasks.
572pub async fn compress_messages_with_full_ai(
573    messages: &[Message],
574    config: &CompressionConfig,
575    ai_mode: AiCompressionMode,
576    fast_model: Box<dyn Provider>,
577    main_model: Box<dyn Provider>,
578    token_usage: u32,
579    context_window: u32,
580) -> Result<Vec<Message>> {
581    let mut pipeline = CompressionPipeline::new_with_full_ai(
582        config.clone(),
583        fast_model,
584        main_model,
585    );
586
587    let result = pipeline.execute(messages, ai_mode, token_usage, context_window).await?;
588    Ok(result.messages)
589}
590
591/// Score messages without compressing (analysis only).
592///
593/// Useful for debugging and understanding compression decisions.
594pub fn score_messages_only(
595    messages: &[Message],
596    config: &CompressionConfig,
597) -> Vec<super::types::ScoredMessage> {
598    let pipeline = CompressionPipeline::new_rule_only(config.clone());
599    pipeline.score_only(messages)
600}
601
602// ============================================================================
603// Tests
604// ============================================================================
605
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a plain-text user message.
    fn user(text: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_estimate_tokens_simple() {
        assert!(estimate_tokens(&user("Hello world")) >= 3);
    }

    #[test]
    fn test_estimate_tokens_weights_non_ascii_higher() {
        // 4 ASCII chars -> ceil(1.0) = 1 token; 4 non-ASCII chars -> ceil(2.68) = 3;
        // both plus the flat 10-token overhead.
        assert_eq!(estimate_tokens(&user("abcd")), 11);
        assert_eq!(estimate_tokens(&user("压缩对话")), 13);
    }

    #[test]
    fn test_count_chars_splits_ascii_and_non_ascii() {
        assert_eq!(count_chars(""), (0, 0));
        assert_eq!(count_chars("ab中文c"), (3, 2));
    }

    #[test]
    fn test_contains_sensitive_instructions() {
        assert!(contains_sensitive_instructions("请不要删除这个文件"));
        assert!(contains_sensitive_instructions("You MUST NOT push to main"));
        assert!(!contains_sensitive_instructions("please refactor the parser"));
    }

    #[test]
    fn test_parse_summary_response_sections() {
        let (summary, key_points) =
            parse_summary_response("【摘要】实现登录\n【已完成】修改 auth.rs\n- 添加单元测试\n");
        assert_eq!(summary, "实现登录");
        assert_eq!(
            key_points,
            vec!["【已完成】修改 auth.rs".to_string(), "添加单元测试".to_string()]
        );
    }

    #[test]
    fn test_parse_summary_response_unstructured_fallback() {
        let (summary, key_points) = parse_summary_response("plain answer\nsecond line");
        assert_eq!(summary, "plain answer");
        assert!(key_points.is_empty());
    }

    #[test]
    fn test_truncate_compress_keeps_tail() {
        let mut config = CompressionConfig::default();
        config.min_preserve_messages = 2;
        let messages: Vec<Message> = ["a", "b", "c"].iter().map(|t| user(t)).collect();
        let kept = truncate_compress(&messages, &config).unwrap();
        assert_eq!(kept.len(), 2);
        assert!(matches!(&kept[0].content, MessageContent::Text(t) if t == "b"));
        assert!(matches!(&kept[1].content, MessageContent::Text(t) if t == "c"));
    }

    #[test]
    fn test_should_compress() {
        let config = CompressionConfig::default();
        // Threshold is 0.5, so 100K/200K = 0.5 triggers compression
        assert!(should_compress(100_000, Some(200_000), &config));
        // 80K/200K = 0.4, below threshold
        assert!(!should_compress(80_000, Some(200_000), &config));
        // Unknown context size never triggers
        assert!(!should_compress(100_000, None, &config));
    }
}