matrixcode_core/compress/
compressor.rs1use crate::providers::{
4 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
5};
6use crate::truncate::truncate_with_suffix;
7use anyhow::Result;
8use async_trait::async_trait;
9use std::collections::HashSet;
10
11use super::config::{CompressionBias, CompressionConfig};
12use super::types::{CompressionStrategy, SummarizedSegment};
13
14#[async_trait]
20pub trait Compressor: Send + Sync {
21 async fn summarize(
23 &self,
24 messages: &[Message],
25 config: &CompressionConfig,
26 ) -> Result<SummarizedSegment>;
27
28 fn model_name(&self) -> &str;
30}
31
32pub struct AiCompressor {
34 provider: Box<dyn Provider>,
35 model: String,
36}
37
38impl AiCompressor {
39 pub fn new(provider: Box<dyn Provider>, model: String) -> Self {
40 Self { provider, model }
41 }
42}
43
/// System prompt attached to every summarization request sent by `AiCompressor`.
///
/// The (Chinese) text asks the model to act as a conversation-history compressor
/// and produce a short summary (about 200 characters). The summary must retain:
/// important operations and decisions, the user's sensitive instructions,
/// unfinished to-dos, and key design choices with their reasons. The requested
/// output sections are 【摘要】 (one-line summary), 【已完成】 (completed work),
/// 【未完成】 (pending tasks, if any) and 【关键决策】 (key decisions, if any).
///
/// NOTE(review): `parse_summary_response` does not parse these section headers.
/// It takes the first non-bullet line as the summary and every `•`/`-`/`*` line
/// as a key point, so section boundaries are lost.
const SUMMARY_SYSTEM_PROMPT: &str = r#"你是一个对话历史压缩助手。将对话压缩为简洁摘要。

输出要求:
- 简洁:摘要控制在 200 字以内
- 关键:只保留重要操作和决策
- 敏感:必须保留用户的敏感指令
- 任务:必须保留未完成的待办事项
- 决策:必须保留关键方案选择和理由

输出格式:
【摘要】一句话概括主要工作
【已完成】列出已完成的操作
【未完成】列出待办任务(如有)
【关键决策】重要选择及理由(如有)

请直接输出内容。"#;
60
61#[async_trait]
62impl Compressor for AiCompressor {
63 async fn summarize(
64 &self,
65 messages: &[Message],
66 _config: &CompressionConfig,
67 ) -> Result<SummarizedSegment> {
68 let prompt = build_summary_prompt(messages);
69
70 let request = ChatRequest {
71 messages: vec![Message {
72 role: Role::User,
73 content: MessageContent::Text(prompt),
74 }],
75 tools: vec![],
76 system: Some(SUMMARY_SYSTEM_PROMPT.to_string()),
77 think: false,
78 max_tokens: 1024,
79 server_tools: vec![],
80 enable_caching: false,
81 };
82
83 let response = self.provider.chat(request).await?;
84 let summary_text = extract_text_from_response(&response);
85 let (summary, key_points) = parse_summary_response(&summary_text);
86
87 Ok(SummarizedSegment {
88 time_range: (chrono::Utc::now(), chrono::Utc::now()),
89 original_count: messages.len(),
90 summary,
91 key_points,
92 })
93 }
94
95 fn model_name(&self) -> &str {
96 &self.model
97 }
98}
99
100fn extract_text_from_response(response: &ChatResponse) -> String {
101 response
102 .content
103 .iter()
104 .filter_map(|block| {
105 if let ContentBlock::Text { text } = block {
106 Some(text.clone())
107 } else {
108 None
109 }
110 })
111 .collect::<Vec<_>>()
112 .join("\n")
113}
114
115fn parse_summary_response(text: &str) -> (String, Vec<String>) {
116 let mut summary = String::new();
117 let mut key_points: Vec<String> = Vec::new();
118
119 for line in text.lines() {
120 let line = line.trim();
121 if line.starts_with("•") || line.starts_with("-") || line.starts_with("*") {
122 let point = line.trim_start_matches(['•', '-', '*']).trim();
123 if !point.is_empty() {
124 key_points.push(point.to_string());
125 }
126 } else if !line.is_empty() && summary.is_empty() {
127 summary = line.to_string();
128 }
129 }
130
131 if summary.is_empty() && !text.is_empty() {
132 summary = text.lines().take(3).collect::<Vec<_>>().join(" ");
133 if summary.len() > 200 {
134 summary = truncate_with_suffix(&summary, 200);
135 }
136 }
137
138 (summary, key_points)
139}
140
141pub fn compress_messages(
147 messages: &[Message],
148 strategy: CompressionStrategy,
149 config: &CompressionConfig,
150) -> Result<Vec<Message>> {
151 match strategy {
152 CompressionStrategy::Truncate => truncate_compress(messages, config),
153 CompressionStrategy::SlidingWindow => sliding_window_compress(messages, config),
154 CompressionStrategy::Summarize => sliding_window_compress(messages, config),
155 CompressionStrategy::BiasBased => compress_with_bias(messages, config),
156 }
157}
158
159pub fn compress_with_bias(
161 messages: &[Message],
162 config: &CompressionConfig,
163) -> Result<Vec<Message>> {
164 if messages.len() <= config.min_preserve_messages {
165 return Ok(messages.to_vec());
166 }
167
168 let scored: Vec<(usize, Message, f64)> = messages
169 .iter()
170 .enumerate()
171 .map(|(idx, msg)| {
172 (
173 idx,
174 msg.clone(),
175 calculate_preservation_score(msg, idx, messages.len(), &config.bias),
176 )
177 })
178 .collect();
179
180 let mut scored_with_recency: Vec<(usize, Message, f64)> = scored
181 .into_iter()
182 .map(|(idx, msg, score)| {
183 let recency_bonus = if idx >= messages.len() - config.min_preserve_messages {
184 100.0
185 } else {
186 (idx as f64 / messages.len() as f64) * 20.0
187 };
188 (idx, msg, score + recency_bonus)
189 })
190 .collect();
191
192 scored_with_recency.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
193
194 let target_count = if config.bias.aggressive {
195 config.min_preserve_messages
196 } else {
197 let estimated = estimate_total_tokens(messages);
198 let target_tokens = (estimated as f64 * config.target_ratio) as u32;
199 let avg = estimated / messages.len() as u32;
200 (target_tokens / avg.max(1)) as usize
201 };
202
203 let to_keep: HashSet<usize> = scored_with_recency
204 .iter()
205 .take(target_count)
206 .map(|(idx, _, _)| *idx)
207 .collect();
208
209 let compressed: Vec<Message> = messages
210 .iter()
211 .enumerate()
212 .filter(|(idx, _)| to_keep.contains(idx))
213 .map(|(_, msg)| msg.clone())
214 .collect();
215
216 Ok(compressed)
217}
218
219fn calculate_preservation_score(
220 message: &Message,
221 index: usize,
222 _total: usize, bias: &CompressionBias,
224) -> f64 {
225 let mut score: f64 = 10.0;
226
227 if index == 0 {
229 score += 100.0;
230 }
231
232 match message.role {
233 Role::User => {
234 if bias.preserve_user_questions {
235 score += 30.0;
236 }
237 }
238 Role::Assistant => {
239 score += 5.0;
240 }
241 Role::Tool => {
242 if bias.preserve_tools {
243 score += 25.0;
244 }
245 }
246 Role::System => {
247 score += 40.0;
248 }
249 }
250
251 match &message.content {
252 MessageContent::Text(text) => {
253 for keyword in &bias.preserve_keywords {
254 if text.to_lowercase().contains(&keyword.to_lowercase()) {
255 score += 15.0;
256 }
257 }
258 if contains_sensitive_instructions(text) {
259 score += 50.0;
260 }
261 }
262 MessageContent::Blocks(blocks) => {
263 for block in blocks {
264 match block {
265 ContentBlock::ToolUse { name, .. } => {
266 if bias.preserve_tools {
267 score += 20.0;
268 }
269 if name == "write" || name == "edit" || name == "bash" {
270 score += 10.0;
271 }
272 if name == "todo_write" {
274 score += 60.0;
275 }
276 if name == "ask" {
278 score += 50.0;
279 }
280 }
281 ContentBlock::ToolResult { content, .. } => {
282 if bias.preserve_tools {
283 score += 20.0;
284 }
285 if contains_sensitive_instructions(content) {
286 score += 30.0;
287 }
288 if content.contains("TodoWrite") || content.contains("todo") {
290 score += 40.0;
291 }
292 if content.contains("AskUserQuestion") || content.contains("answer") {
294 score += 30.0;
295 }
296 }
297 ContentBlock::Thinking { .. } => {
298 if bias.preserve_thinking {
299 score += 25.0;
300 } else {
301 score -= 5.0;
302 }
303 }
304 ContentBlock::Text { text } => {
305 if contains_sensitive_instructions(text) {
306 score += 50.0;
307 }
308 }
309 _ => {}
310 }
311 }
312 }
313 }
314
315 score
316}
317
/// Returns `true` if `text` contains a prohibition or obligation phrase.
///
/// Matching is case-insensitive and covers Chinese markers (不要, 禁止, 必须,
/// 不允许) plus English "never", "must not" and "do not".
fn contains_sensitive_instructions(text: &str) -> bool {
    const MARKERS: [&str; 7] = [
        "不要",
        "禁止",
        "必须",
        "不允许",
        "never",
        "must not",
        "do not",
    ];
    let haystack = text.to_lowercase();
    MARKERS.iter().any(|marker| haystack.contains(*marker))
}
331
332fn truncate_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
333 if messages.len() <= config.min_preserve_messages {
334 return Ok(messages.to_vec());
335 }
336 Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
337}
338
339fn sliding_window_compress(
340 messages: &[Message],
341 config: &CompressionConfig,
342) -> Result<Vec<Message>> {
343 if messages.len() <= config.min_preserve_messages {
344 return Ok(messages.to_vec());
345 }
346
347 let first_msg = messages.first().cloned();
353 let recent_start = messages.len().saturating_sub(config.min_preserve_messages);
354 let recent_msgs = &messages[recent_start..];
355
356 let first_tokens = first_msg.as_ref().map(|m| estimate_tokens(m)).unwrap_or(0);
358 let recent_tokens = estimate_total_tokens(recent_msgs);
359 let current_total = estimate_total_tokens(messages);
360 let target_tokens = (current_total as f64 * config.target_ratio) as u32;
361
362 if first_tokens + recent_tokens <= target_tokens {
364 let mut result: Vec<Message> = Vec::new();
366 if let Some(first) = first_msg {
367 result.push(first);
368 }
369 result.extend(recent_msgs.iter().cloned());
370 return Ok(result);
371 }
372
373 for drop_count in 0..recent_msgs.len() {
375 let candidate = &recent_msgs[drop_count..];
376 if estimate_total_tokens(candidate) <= target_tokens {
377 return Ok(candidate.to_vec());
378 }
379 }
380
381 Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
383}
384
385pub fn estimate_tokens(message: &Message) -> u32 {
391 let (ascii, non_ascii) = match &message.content {
392 MessageContent::Text(t) => count_chars(t),
393 MessageContent::Blocks(blocks) => {
394 let mut a = 0u32;
395 let mut n = 0u32;
396 for block in blocks {
397 match block {
398 ContentBlock::Text { text } => {
399 let (ca, cn) = count_chars(text);
400 a += ca;
401 n += cn;
402 }
403 ContentBlock::ToolUse { name, input, .. } => {
404 let (ca, cn) = count_chars(name);
405 a += ca;
406 n += cn;
407 let (ja, jn) = count_chars(&input.to_string());
408 a += ja;
409 n += jn;
410 }
411 ContentBlock::ToolResult { content, .. } => {
412 let (ca, cn) = count_chars(content);
413 a += ca;
414 n += cn;
415 }
416 ContentBlock::Thinking { thinking, .. } => {
417 let (ca, cn) = count_chars(thinking);
418 a += ca;
419 n += cn;
420 }
421 _ => {}
422 }
423 }
424 (a, n)
425 }
426 };
427
428 let ascii_tokens = (ascii as f64 * 0.25).ceil() as u32;
429 let non_ascii_tokens = (non_ascii as f64 * 0.67).ceil() as u32;
430 (ascii_tokens + non_ascii_tokens + 10).max(1)
431}
432
/// Counts the characters of `s` as `(ascii, non_ascii)`.
fn count_chars(s: &str) -> (u32, u32) {
    s.chars().fold((0u32, 0u32), |(ascii, other), ch| {
        if ch.is_ascii() {
            (ascii + 1, other)
        } else {
            (ascii, other + 1)
        }
    })
}
445
446pub fn estimate_total_tokens(messages: &[Message]) -> u32 {
448 messages.iter().map(estimate_tokens).sum()
449}
450
451pub fn should_compress(
453 current_tokens: u32,
454 context_size: Option<u32>,
455 config: &CompressionConfig,
456) -> bool {
457 match context_size {
458 Some(size) => (current_tokens as f64 / size as f64) >= config.threshold,
459 None => false,
460 }
461}
462
463pub fn build_summary_prompt(messages: &[Message]) -> String {
465 let history = messages
466 .iter()
467 .map(|m| {
468 let role = match m.role {
469 Role::User => "用户",
470 Role::Assistant => "助手",
471 Role::Tool => "工具",
472 Role::System => "系统",
473 };
474 let preview = match &m.content {
475 MessageContent::Text(t) => truncate_with_suffix(t, 200),
476 MessageContent::Blocks(blocks) => blocks
477 .iter()
478 .map(|b| match b {
479 ContentBlock::Text { text } => truncate_with_suffix(text, 100),
480 ContentBlock::ToolUse { name, .. } => format!("[工具: {}]", name),
481 ContentBlock::ToolResult { content, .. } => {
482 truncate_with_suffix(content, 100)
483 }
484 _ => "[...]".to_string(),
485 })
486 .collect::<Vec<_>>()
487 .join(" | "),
488 };
489 format!("{}: {}", role, preview)
490 })
491 .collect::<Vec<_>>()
492 .join("\n");
493
494 format!(
495 "请将以下对话压缩为简洁摘要({} 条消息):\n{}",
496 messages.len(),
497 history
498 )
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain-text user message.
    fn user_text(text: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_estimate_tokens_simple() {
        assert!(estimate_tokens(&user_text("Hello world")) >= 3);
    }

    #[test]
    fn test_estimate_tokens_weights_non_ascii_higher() {
        // Four CJK characters are weighted more heavily than four ASCII ones.
        let cjk = estimate_tokens(&user_text("你好你好"));
        let ascii = estimate_tokens(&user_text("abcd"));
        assert!(cjk > ascii);
    }

    #[test]
    fn test_estimate_total_tokens_empty() {
        assert_eq!(estimate_total_tokens(&[]), 0);
    }

    #[test]
    fn test_count_chars_splits_ascii_and_non_ascii() {
        assert_eq!(count_chars(""), (0, 0));
        assert_eq!(count_chars("ab你好c"), (3, 2));
    }

    #[test]
    fn test_contains_sensitive_instructions() {
        assert!(contains_sensitive_instructions("Please DO NOT push"));
        assert!(contains_sensitive_instructions("不要修改配置"));
        assert!(!contains_sensitive_instructions("looks good"));
    }

    #[test]
    fn test_parse_summary_response_bullets() {
        let (summary, points) =
            parse_summary_response("【摘要】完成重构\n- 修改 a.rs\n* 更新测试\n• 补充文档\n");
        assert_eq!(summary, "【摘要】完成重构");
        assert_eq!(points, vec!["修改 a.rs", "更新测试", "补充文档"]);
    }

    #[test]
    fn test_parse_summary_response_empty() {
        let (summary, points) = parse_summary_response("");
        assert!(summary.is_empty());
        assert!(points.is_empty());
    }

    #[test]
    fn test_should_compress() {
        let config = CompressionConfig::default();
        assert!(!should_compress(100_000, Some(200_000), &config));
        assert!(should_compress(160_000, Some(200_000), &config));
        assert!(!should_compress(160_000, None, &config));
    }
}