Skip to main content

matrixcode_core/compress/
compressor.rs

1//! Compression functions and AI compressor implementation.
2
3use crate::providers::{
4    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
5};
6use crate::truncate::truncate_with_suffix;
7use anyhow::Result;
8use async_trait::async_trait;
9use std::collections::HashSet;
10
11use super::config::{CompressionBias, CompressionConfig};
12use super::types::{CompressionStrategy, SummarizedSegment};
13
14// ============================================================================
15// Compressor Trait
16// ============================================================================
17
/// Abstraction over strategies that condense a run of conversation messages
/// into a single [`SummarizedSegment`].
///
/// The `Send + Sync` supertraits allow implementations to be shared across
/// threads.
#[async_trait]
pub trait Compressor: Send + Sync {
    /// Summarizes `messages` into a [`SummarizedSegment`].
    ///
    /// `config` carries the compression settings; an implementation is free
    /// to ignore it.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementation cannot produce a summary.
    async fn summarize(
        &self,
        messages: &[Message],
        config: &CompressionConfig,
    ) -> Result<SummarizedSegment>;

    /// Returns the name of the model this compressor uses.
    fn model_name(&self) -> &str;
}
31
32/// AI-based compressor using a Provider.
33pub struct AiCompressor {
34    provider: Box<dyn Provider>,
35    model: String,
36}
37
38impl AiCompressor {
39    pub fn new(provider: Box<dyn Provider>, model: String) -> Self {
40        Self { provider, model }
41    }
42}
43
/// System prompt (written in Chinese) used by `AiCompressor::summarize`.
///
/// It instructs the model to act as a conversation-history compressor: keep
/// the summary within 200 characters, retain only important operations and
/// decisions, always preserve the user's sensitive instructions, and output
/// the summary text directly.
const SUMMARY_SYSTEM_PROMPT: &str = r#"你是一个对话历史压缩助手。将对话压缩为简洁摘要。

输出要求:
- 简洁:摘要控制在 200 字以内
- 关键:只保留重要操作和决策
- 敏感:必须保留用户的敏感指令

请直接输出摘要内容。"#;
52
53#[async_trait]
54impl Compressor for AiCompressor {
55    async fn summarize(
56        &self,
57        messages: &[Message],
58        _config: &CompressionConfig,
59    ) -> Result<SummarizedSegment> {
60        let prompt = build_summary_prompt(messages);
61
62        let request = ChatRequest {
63            messages: vec![Message {
64                role: Role::User,
65                content: MessageContent::Text(prompt),
66            }],
67            tools: vec![],
68            system: Some(SUMMARY_SYSTEM_PROMPT.to_string()),
69            think: false,
70            max_tokens: 1024,
71            server_tools: vec![],
72            enable_caching: false,
73        };
74
75        let response = self.provider.chat(request).await?;
76        let summary_text = extract_text_from_response(&response);
77        let (summary, key_points) = parse_summary_response(&summary_text);
78
79        Ok(SummarizedSegment {
80            time_range: (chrono::Utc::now(), chrono::Utc::now()),
81            original_count: messages.len(),
82            summary,
83            key_points,
84        })
85    }
86
87    fn model_name(&self) -> &str {
88        &self.model
89    }
90}
91
92fn extract_text_from_response(response: &ChatResponse) -> String {
93    response
94        .content
95        .iter()
96        .filter_map(|block| {
97            if let ContentBlock::Text { text } = block {
98                Some(text.clone())
99            } else {
100                None
101            }
102        })
103        .collect::<Vec<_>>()
104        .join("\n")
105}
106
107fn parse_summary_response(text: &str) -> (String, Vec<String>) {
108    let mut summary = String::new();
109    let mut key_points: Vec<String> = Vec::new();
110
111    for line in text.lines() {
112        let line = line.trim();
113        if line.starts_with("•") || line.starts_with("-") || line.starts_with("*") {
114            let point = line.trim_start_matches(['•', '-', '*']).trim();
115            if !point.is_empty() {
116                key_points.push(point.to_string());
117            }
118        } else if !line.is_empty() && summary.is_empty() {
119            summary = line.to_string();
120        }
121    }
122
123    if summary.is_empty() && !text.is_empty() {
124        summary = text.lines().take(3).collect::<Vec<_>>().join(" ");
125        if summary.len() > 200 {
126            summary = truncate_with_suffix(&summary, 200);
127        }
128    }
129
130    (summary, key_points)
131}
132
133// ============================================================================
134// Compression Functions
135// ============================================================================
136
137/// Compress messages synchronously.
138pub fn compress_messages(
139    messages: &[Message],
140    strategy: CompressionStrategy,
141    config: &CompressionConfig,
142) -> Result<Vec<Message>> {
143    match strategy {
144        CompressionStrategy::Truncate => truncate_compress(messages, config),
145        CompressionStrategy::SlidingWindow => sliding_window_compress(messages, config),
146        CompressionStrategy::Summarize => sliding_window_compress(messages, config),
147        CompressionStrategy::BiasBased => compress_with_bias(messages, config),
148    }
149}
150
151/// Compress with bias-based scoring.
152pub fn compress_with_bias(
153    messages: &[Message],
154    config: &CompressionConfig,
155) -> Result<Vec<Message>> {
156    if messages.len() <= config.min_preserve_messages {
157        return Ok(messages.to_vec());
158    }
159
160    let scored: Vec<(usize, Message, f64)> = messages
161        .iter()
162        .enumerate()
163        .map(|(idx, msg)| {
164            (
165                idx,
166                msg.clone(),
167                calculate_preservation_score(msg, idx, messages.len(), &config.bias),
168            )
169        })
170        .collect();
171
172    let mut scored_with_recency: Vec<(usize, Message, f64)> = scored
173        .into_iter()
174        .map(|(idx, msg, score)| {
175            let recency_bonus = if idx >= messages.len() - config.min_preserve_messages {
176                100.0
177            } else {
178                (idx as f64 / messages.len() as f64) * 20.0
179            };
180            (idx, msg, score + recency_bonus)
181        })
182        .collect();
183
184    scored_with_recency.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
185
186    let target_count = if config.bias.aggressive {
187        config.min_preserve_messages
188    } else {
189        let estimated = estimate_total_tokens(messages);
190        let target_tokens = (estimated as f64 * config.target_ratio) as u32;
191        let avg = estimated / messages.len() as u32;
192        (target_tokens / avg.max(1)) as usize
193    };
194
195    let to_keep: HashSet<usize> = scored_with_recency
196        .iter()
197        .take(target_count)
198        .map(|(idx, _, _)| *idx)
199        .collect();
200
201    let compressed: Vec<Message> = messages
202        .iter()
203        .enumerate()
204        .filter(|(idx, _)| to_keep.contains(idx))
205        .map(|(_, msg)| msg.clone())
206        .collect();
207
208    Ok(compressed)
209}
210
211fn calculate_preservation_score(
212    message: &Message,
213    _index: usize,
214    _total: usize,
215    bias: &CompressionBias,
216) -> f64 {
217    let mut score: f64 = 10.0;
218
219    match message.role {
220        Role::User => {
221            if bias.preserve_user_questions {
222                score += 30.0;
223            }
224        }
225        Role::Assistant => {
226            score += 5.0;
227        }
228        Role::Tool => {
229            if bias.preserve_tools {
230                score += 25.0;
231            }
232        }
233        Role::System => {
234            score += 40.0;
235        }
236    }
237
238    match &message.content {
239        MessageContent::Text(text) => {
240            for keyword in &bias.preserve_keywords {
241                if text.to_lowercase().contains(&keyword.to_lowercase()) {
242                    score += 15.0;
243                }
244            }
245            if contains_sensitive_instructions(text) {
246                score += 50.0;
247            }
248        }
249        MessageContent::Blocks(blocks) => {
250            for block in blocks {
251                match block {
252                    ContentBlock::ToolUse { name, .. } => {
253                        if bias.preserve_tools {
254                            score += 20.0;
255                        }
256                        if name == "write" || name == "edit" || name == "bash" {
257                            score += 10.0;
258                        }
259                    }
260                    ContentBlock::ToolResult { content, .. } => {
261                        if bias.preserve_tools {
262                            score += 20.0;
263                        }
264                        if contains_sensitive_instructions(content) {
265                            score += 30.0;
266                        }
267                    }
268                    ContentBlock::Thinking { .. } => {
269                        if bias.preserve_thinking {
270                            score += 25.0;
271                        } else {
272                            score -= 5.0;
273                        }
274                    }
275                    ContentBlock::Text { text } => {
276                        if contains_sensitive_instructions(text) {
277                            score += 50.0;
278                        }
279                    }
280                    _ => {}
281                }
282            }
283        }
284    }
285
286    score
287}
288
/// Reports whether `text` contains any prohibition/obligation marker.
///
/// Matching is a plain substring search over the lowercased text, so the
/// English markers are found regardless of case.
fn contains_sensitive_instructions(text: &str) -> bool {
    const MARKERS: [&str; 7] = [
        "不要",
        "禁止",
        "必须",
        "不允许",
        "never",
        "must not",
        "do not",
    ];

    let haystack = text.to_lowercase();
    for marker in MARKERS.iter() {
        if haystack.contains(*marker) {
            return true;
        }
    }
    false
}
302
303fn truncate_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
304    if messages.len() <= config.min_preserve_messages {
305        return Ok(messages.to_vec());
306    }
307    Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
308}
309
310fn sliding_window_compress(
311    messages: &[Message],
312    config: &CompressionConfig,
313) -> Result<Vec<Message>> {
314    if messages.len() <= config.min_preserve_messages {
315        return Ok(messages.to_vec());
316    }
317
318    let target_tokens = (estimate_total_tokens(messages) as f64 * config.target_ratio) as u32;
319
320    for start_idx in config.min_preserve_messages..messages.len() {
321        let candidate = &messages[start_idx..];
322        if estimate_total_tokens(candidate) <= target_tokens {
323            return Ok(candidate.to_vec());
324        }
325    }
326
327    Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
328}
329
330// ============================================================================
331// Token Estimation
332// ============================================================================
333
334/// Estimate token count for a message.
335pub fn estimate_tokens(message: &Message) -> u32 {
336    let (ascii, non_ascii) = match &message.content {
337        MessageContent::Text(t) => count_chars(t),
338        MessageContent::Blocks(blocks) => {
339            let mut a = 0u32;
340            let mut n = 0u32;
341            for block in blocks {
342                match block {
343                    ContentBlock::Text { text } => {
344                        let (ca, cn) = count_chars(text);
345                        a += ca;
346                        n += cn;
347                    }
348                    ContentBlock::ToolUse { name, input, .. } => {
349                        let (ca, cn) = count_chars(name);
350                        a += ca;
351                        n += cn;
352                        let (ja, jn) = count_chars(&input.to_string());
353                        a += ja;
354                        n += jn;
355                    }
356                    ContentBlock::ToolResult { content, .. } => {
357                        let (ca, cn) = count_chars(content);
358                        a += ca;
359                        n += cn;
360                    }
361                    ContentBlock::Thinking { thinking, .. } => {
362                        let (ca, cn) = count_chars(thinking);
363                        a += ca;
364                        n += cn;
365                    }
366                    _ => {}
367                }
368            }
369            (a, n)
370        }
371    };
372
373    let ascii_tokens = (ascii as f64 * 0.25).ceil() as u32;
374    let non_ascii_tokens = (non_ascii as f64 * 0.67).ceil() as u32;
375    (ascii_tokens + non_ascii_tokens + 10).max(1)
376}
377
/// Counts the characters of `s`, returned as `(ascii, non_ascii)`.
fn count_chars(s: &str) -> (u32, u32) {
    s.chars().fold((0u32, 0u32), |(ascii, other), ch| {
        if ch.is_ascii() {
            (ascii + 1, other)
        } else {
            (ascii, other + 1)
        }
    })
}
390
391/// Estimate total tokens for a message list.
392pub fn estimate_total_tokens(messages: &[Message]) -> u32 {
393    messages.iter().map(estimate_tokens).sum()
394}
395
396/// Check if compression should be triggered.
397pub fn should_compress(
398    current_tokens: u32,
399    context_size: Option<u32>,
400    config: &CompressionConfig,
401) -> bool {
402    match context_size {
403        Some(size) => (current_tokens as f64 / size as f64) >= config.threshold,
404        None => false,
405    }
406}
407
408/// Build a prompt for summarization.
409pub fn build_summary_prompt(messages: &[Message]) -> String {
410    let history = messages
411        .iter()
412        .map(|m| {
413            let role = match m.role {
414                Role::User => "用户",
415                Role::Assistant => "助手",
416                Role::Tool => "工具",
417                Role::System => "系统",
418            };
419            let preview = match &m.content {
420                MessageContent::Text(t) => truncate_with_suffix(t, 200),
421                MessageContent::Blocks(blocks) => blocks
422                    .iter()
423                    .map(|b| match b {
424                        ContentBlock::Text { text } => truncate_with_suffix(text, 100),
425                        ContentBlock::ToolUse { name, .. } => format!("[工具: {}]", name),
426                        ContentBlock::ToolResult { content, .. } => {
427                            truncate_with_suffix(content, 100)
428                        }
429                        _ => "[...]".to_string(),
430                    })
431                    .collect::<Vec<_>>()
432                    .join(" | "),
433            };
434            format!("{}: {}", role, preview)
435        })
436        .collect::<Vec<_>>()
437        .join("\n");
438
439    format!(
440        "请将以下对话压缩为简洁摘要({} 条消息):\n{}",
441        messages.len(),
442        history
443    )
444}
445
446// ============================================================================
447// Tests
448// ============================================================================
449
#[cfg(test)]
mod tests {
    use super::*;

    /// A short ASCII message must be estimated at no fewer than 3 tokens.
    #[test]
    fn test_estimate_tokens_simple() {
        let content = MessageContent::Text(String::from("Hello world"));
        let msg = Message {
            role: Role::User,
            content,
        };
        let estimate = estimate_tokens(&msg);
        assert!(estimate >= 3);
    }

    /// With the default config, 50% context usage must not trigger
    /// compression while 80% usage must.
    #[test]
    fn test_should_compress() {
        let config = CompressionConfig::default();
        let context = Some(200_000);
        assert!(!should_compress(100_000, context, &config));
        assert!(should_compress(160_000, context, &config));
    }
}