matrixcode_core/compress/
compressor.rs1use crate::providers::{
4 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
5};
6use crate::truncate::truncate_with_suffix;
7use anyhow::Result;
8use async_trait::async_trait;
9use std::collections::HashSet;
10
11use super::config::{CompressionBias, CompressionConfig};
12use super::types::{CompressionStrategy, SummarizedSegment};
13
14#[async_trait]
20pub trait Compressor: Send + Sync {
21 async fn summarize(
23 &self,
24 messages: &[Message],
25 config: &CompressionConfig,
26 ) -> Result<SummarizedSegment>;
27
28 fn model_name(&self) -> &str;
30}
31
32pub struct AiCompressor {
34 provider: Box<dyn Provider>,
35 model: String,
36}
37
38impl AiCompressor {
39 pub fn new(provider: Box<dyn Provider>, model: String) -> Self {
40 Self { provider, model }
41 }
42}
43
/// System prompt sent with every summarization request.
///
/// The text (in Chinese) tells the model it is a conversation-history
/// compression assistant and asks for a summary of at most 200 characters that
/// keeps important operations and decisions and always preserves the user's
/// sensitive instructions, output directly with no wrapper.
const SUMMARY_SYSTEM_PROMPT: &str = r#"你是一个对话历史压缩助手。将对话压缩为简洁摘要。

输出要求:
- 简洁:摘要控制在 200 字以内
- 关键:只保留重要操作和决策
- 敏感:必须保留用户的敏感指令

请直接输出摘要内容。"#;
52
53#[async_trait]
54impl Compressor for AiCompressor {
55 async fn summarize(
56 &self,
57 messages: &[Message],
58 _config: &CompressionConfig,
59 ) -> Result<SummarizedSegment> {
60 let prompt = build_summary_prompt(messages);
61
62 let request = ChatRequest {
63 messages: vec![Message {
64 role: Role::User,
65 content: MessageContent::Text(prompt),
66 }],
67 tools: vec![],
68 system: Some(SUMMARY_SYSTEM_PROMPT.to_string()),
69 think: false,
70 max_tokens: 1024,
71 server_tools: vec![],
72 enable_caching: false,
73 };
74
75 let response = self.provider.chat(request).await?;
76 let summary_text = extract_text_from_response(&response);
77 let (summary, key_points) = parse_summary_response(&summary_text);
78
79 Ok(SummarizedSegment {
80 time_range: (chrono::Utc::now(), chrono::Utc::now()),
81 original_count: messages.len(),
82 summary,
83 key_points,
84 })
85 }
86
87 fn model_name(&self) -> &str {
88 &self.model
89 }
90}
91
92fn extract_text_from_response(response: &ChatResponse) -> String {
93 response
94 .content
95 .iter()
96 .filter_map(|block| {
97 if let ContentBlock::Text { text } = block {
98 Some(text.clone())
99 } else {
100 None
101 }
102 })
103 .collect::<Vec<_>>()
104 .join("\n")
105}
106
107fn parse_summary_response(text: &str) -> (String, Vec<String>) {
108 let mut summary = String::new();
109 let mut key_points: Vec<String> = Vec::new();
110
111 for line in text.lines() {
112 let line = line.trim();
113 if line.starts_with("•") || line.starts_with("-") || line.starts_with("*") {
114 let point = line.trim_start_matches(['•', '-', '*']).trim();
115 if !point.is_empty() {
116 key_points.push(point.to_string());
117 }
118 } else if !line.is_empty() && summary.is_empty() {
119 summary = line.to_string();
120 }
121 }
122
123 if summary.is_empty() && !text.is_empty() {
124 summary = text.lines().take(3).collect::<Vec<_>>().join(" ");
125 if summary.len() > 200 {
126 summary = truncate_with_suffix(&summary, 200);
127 }
128 }
129
130 (summary, key_points)
131}
132
133pub fn compress_messages(
139 messages: &[Message],
140 strategy: CompressionStrategy,
141 config: &CompressionConfig,
142) -> Result<Vec<Message>> {
143 match strategy {
144 CompressionStrategy::Truncate => truncate_compress(messages, config),
145 CompressionStrategy::SlidingWindow => sliding_window_compress(messages, config),
146 CompressionStrategy::Summarize => sliding_window_compress(messages, config),
147 CompressionStrategy::BiasBased => compress_with_bias(messages, config),
148 }
149}
150
151pub fn compress_with_bias(
153 messages: &[Message],
154 config: &CompressionConfig,
155) -> Result<Vec<Message>> {
156 if messages.len() <= config.min_preserve_messages {
157 return Ok(messages.to_vec());
158 }
159
160 let scored: Vec<(usize, Message, f64)> = messages
161 .iter()
162 .enumerate()
163 .map(|(idx, msg)| {
164 (
165 idx,
166 msg.clone(),
167 calculate_preservation_score(msg, idx, messages.len(), &config.bias),
168 )
169 })
170 .collect();
171
172 let mut scored_with_recency: Vec<(usize, Message, f64)> = scored
173 .into_iter()
174 .map(|(idx, msg, score)| {
175 let recency_bonus = if idx >= messages.len() - config.min_preserve_messages {
176 100.0
177 } else {
178 (idx as f64 / messages.len() as f64) * 20.0
179 };
180 (idx, msg, score + recency_bonus)
181 })
182 .collect();
183
184 scored_with_recency.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
185
186 let target_count = if config.bias.aggressive {
187 config.min_preserve_messages
188 } else {
189 let estimated = estimate_total_tokens(messages);
190 let target_tokens = (estimated as f64 * config.target_ratio) as u32;
191 let avg = estimated / messages.len() as u32;
192 (target_tokens / avg.max(1)) as usize
193 };
194
195 let to_keep: HashSet<usize> = scored_with_recency
196 .iter()
197 .take(target_count)
198 .map(|(idx, _, _)| *idx)
199 .collect();
200
201 let compressed: Vec<Message> = messages
202 .iter()
203 .enumerate()
204 .filter(|(idx, _)| to_keep.contains(idx))
205 .map(|(_, msg)| msg.clone())
206 .collect();
207
208 Ok(compressed)
209}
210
211fn calculate_preservation_score(
212 message: &Message,
213 _index: usize,
214 _total: usize,
215 bias: &CompressionBias,
216) -> f64 {
217 let mut score: f64 = 10.0;
218
219 match message.role {
220 Role::User => {
221 if bias.preserve_user_questions {
222 score += 30.0;
223 }
224 }
225 Role::Assistant => {
226 score += 5.0;
227 }
228 Role::Tool => {
229 if bias.preserve_tools {
230 score += 25.0;
231 }
232 }
233 Role::System => {
234 score += 40.0;
235 }
236 }
237
238 match &message.content {
239 MessageContent::Text(text) => {
240 for keyword in &bias.preserve_keywords {
241 if text.to_lowercase().contains(&keyword.to_lowercase()) {
242 score += 15.0;
243 }
244 }
245 if contains_sensitive_instructions(text) {
246 score += 50.0;
247 }
248 }
249 MessageContent::Blocks(blocks) => {
250 for block in blocks {
251 match block {
252 ContentBlock::ToolUse { name, .. } => {
253 if bias.preserve_tools {
254 score += 20.0;
255 }
256 if name == "write" || name == "edit" || name == "bash" {
257 score += 10.0;
258 }
259 }
260 ContentBlock::ToolResult { content, .. } => {
261 if bias.preserve_tools {
262 score += 20.0;
263 }
264 if contains_sensitive_instructions(content) {
265 score += 30.0;
266 }
267 }
268 ContentBlock::Thinking { .. } => {
269 if bias.preserve_thinking {
270 score += 25.0;
271 } else {
272 score -= 5.0;
273 }
274 }
275 ContentBlock::Text { text } => {
276 if contains_sensitive_instructions(text) {
277 score += 50.0;
278 }
279 }
280 _ => {}
281 }
282 }
283 }
284 }
285
286 score
287}
288
/// Returns `true` when `text` contains one of the fixed prohibition or
/// obligation phrases (Chinese and English), matched case-insensitively as
/// plain substrings.
fn contains_sensitive_instructions(text: &str) -> bool {
    const PATTERNS: [&str; 7] = [
        "不要",
        "禁止",
        "必须",
        "不允许",
        "never",
        "must not",
        "do not",
    ];

    // The English patterns are lowercase, so lowercasing the input once makes
    // the match case-insensitive.
    let lowered = text.to_lowercase();
    PATTERNS.iter().any(|pattern| lowered.contains(pattern))
}
302
303fn truncate_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
304 if messages.len() <= config.min_preserve_messages {
305 return Ok(messages.to_vec());
306 }
307 Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
308}
309
310fn sliding_window_compress(
311 messages: &[Message],
312 config: &CompressionConfig,
313) -> Result<Vec<Message>> {
314 if messages.len() <= config.min_preserve_messages {
315 return Ok(messages.to_vec());
316 }
317
318 let target_tokens = (estimate_total_tokens(messages) as f64 * config.target_ratio) as u32;
319
320 for start_idx in config.min_preserve_messages..messages.len() {
321 let candidate = &messages[start_idx..];
322 if estimate_total_tokens(candidate) <= target_tokens {
323 return Ok(candidate.to_vec());
324 }
325 }
326
327 Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
328}
329
330pub fn estimate_tokens(message: &Message) -> u32 {
336 let (ascii, non_ascii) = match &message.content {
337 MessageContent::Text(t) => count_chars(t),
338 MessageContent::Blocks(blocks) => {
339 let mut a = 0u32;
340 let mut n = 0u32;
341 for block in blocks {
342 match block {
343 ContentBlock::Text { text } => {
344 let (ca, cn) = count_chars(text);
345 a += ca;
346 n += cn;
347 }
348 ContentBlock::ToolUse { name, input, .. } => {
349 let (ca, cn) = count_chars(name);
350 a += ca;
351 n += cn;
352 let (ja, jn) = count_chars(&input.to_string());
353 a += ja;
354 n += jn;
355 }
356 ContentBlock::ToolResult { content, .. } => {
357 let (ca, cn) = count_chars(content);
358 a += ca;
359 n += cn;
360 }
361 ContentBlock::Thinking { thinking, .. } => {
362 let (ca, cn) = count_chars(thinking);
363 a += ca;
364 n += cn;
365 }
366 _ => {}
367 }
368 }
369 (a, n)
370 }
371 };
372
373 let ascii_tokens = (ascii as f64 * 0.25).ceil() as u32;
374 let non_ascii_tokens = (non_ascii as f64 * 0.67).ceil() as u32;
375 (ascii_tokens + non_ascii_tokens + 10).max(1)
376}
377
/// Counts the characters of `s`, returning `(ascii, non_ascii)`.
///
/// Counts are per Unicode scalar value (`char`), not per byte.
fn count_chars(s: &str) -> (u32, u32) {
    s.chars().fold((0u32, 0u32), |(ascii, non_ascii), ch| {
        if ch.is_ascii() {
            (ascii + 1, non_ascii)
        } else {
            (ascii, non_ascii + 1)
        }
    })
}
390
391pub fn estimate_total_tokens(messages: &[Message]) -> u32 {
393 messages.iter().map(estimate_tokens).sum()
394}
395
396pub fn should_compress(
398 current_tokens: u32,
399 context_size: Option<u32>,
400 config: &CompressionConfig,
401) -> bool {
402 match context_size {
403 Some(size) => (current_tokens as f64 / size as f64) >= config.threshold,
404 None => false,
405 }
406}
407
408pub fn build_summary_prompt(messages: &[Message]) -> String {
410 let history = messages
411 .iter()
412 .map(|m| {
413 let role = match m.role {
414 Role::User => "用户",
415 Role::Assistant => "助手",
416 Role::Tool => "工具",
417 Role::System => "系统",
418 };
419 let preview = match &m.content {
420 MessageContent::Text(t) => truncate_with_suffix(t, 200),
421 MessageContent::Blocks(blocks) => blocks
422 .iter()
423 .map(|b| match b {
424 ContentBlock::Text { text } => truncate_with_suffix(text, 100),
425 ContentBlock::ToolUse { name, .. } => format!("[工具: {}]", name),
426 ContentBlock::ToolResult { content, .. } => {
427 truncate_with_suffix(content, 100)
428 }
429 _ => "[...]".to_string(),
430 })
431 .collect::<Vec<_>>()
432 .join(" | "),
433 };
434 format!("{}: {}", role, preview)
435 })
436 .collect::<Vec<_>>()
437 .join("\n");
438
439 format!(
440 "请将以下对话压缩为简洁摘要({} 条消息):\n{}",
441 messages.len(),
442 history
443 )
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// A short ASCII message still yields a positive estimate, because of the
    /// fixed per-message amount added by `estimate_tokens`.
    #[test]
    fn test_estimate_tokens_simple() {
        let msg = Message {
            role: Role::User,
            content: MessageContent::Text("Hello world".to_string()),
        };
        assert!(estimate_tokens(&msg) >= 3);
    }

    /// Compression triggers only once the fill ratio reaches the default
    /// threshold, and never when the context size is unknown or zero.
    #[test]
    fn test_should_compress() {
        let config = CompressionConfig::default();
        assert!(!should_compress(100_000, Some(200_000), &config));
        assert!(should_compress(160_000, Some(200_000), &config));
        assert!(!should_compress(160_000, None, &config));
        assert!(!should_compress(160_000, Some(0), &config));
    }
}