1use crate::providers::{
4 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
5};
6use crate::truncate::truncate_with_suffix;
7use anyhow::Result;
8use async_trait::async_trait;
9use std::collections::HashSet;
10
11use super::config::{CompressionBias, CompressionConfig};
12use super::types::{CompressionStrategy, SummarizedSegment};
13
/// Produces a condensed [`SummarizedSegment`] from a run of conversation messages.
///
/// Declared with `#[async_trait]` so implementors can await I/O (e.g. a provider
/// chat call) inside `summarize`; `Send + Sync` lets compressors be shared
/// across tasks/threads.
#[async_trait]
pub trait Compressor: Send + Sync {
    /// Summarizes `messages` into a [`SummarizedSegment`].
    ///
    /// `config` is passed through for implementations that want it; whether it
    /// is consulted is implementation-defined.
    ///
    /// # Errors
    /// Implementation-defined; e.g. an AI-backed compressor surfaces provider errors.
    async fn summarize(
        &self,
        messages: &[Message],
        config: &CompressionConfig,
    ) -> Result<SummarizedSegment>;

    /// Identifier of the model this compressor uses.
    fn model_name(&self) -> &str;
}
31
/// [`Compressor`] that asks an LLM provider to write the summary.
pub struct AiCompressor {
    /// Provider the summarization chat request is sent to.
    provider: Box<dyn Provider>,
    /// Model identifier reported by `model_name`.
    model: String,
}
37
38impl AiCompressor {
39 pub fn new(provider: Box<dyn Provider>, model: String) -> Self {
40 Self { provider, model }
41 }
42}
43
/// System prompt sent with every summarization request.
///
/// In Chinese, it forbids tool calls and asks the model to reason inside an
/// `<analysis>` block first. It then asks for an answer using nine `【…】`
/// section headers: 摘要 (summary), 已完成 (done), 未完成 (pending),
/// 关键决策 (key decisions), 敏感指令 (sensitive instructions), 技术栈 (tech stack),
/// 文件变更 (file changes), 问题记录 (issues) and 下一步 (next steps). Each section
/// is limited to about 100 characters.
///
/// Keep the header spellings in sync with the header list in
/// `parse_summary_response`, which splits the reply on these exact strings.
const SUMMARY_SYSTEM_PROMPT: &str = r#"CRITICAL: 仅用文本响应。不要调用任何工具。

- 不要使用 read、bash、grep、glob、edit、write 或任何其他工具
- 你已在上方对话中获得所需的所有上下文
- 工具调用将被拒绝并浪费你唯一的 turn — 你将失败任务
- 你的整个响应必须是纯文本摘要

---

你是一个对话历史压缩助手。将对话压缩为结构化摘要。

在提供最终摘要前,将你的分析包裹在 <analysis> 标签中以组织思路:
<analysis>
1. 按时间顺序分析每条消息
2. 识别:用户请求、助手行动、关键决策、错误及修复
3. 特别注意用户的敏感指令(禁止/必须)
4. 双重检查技术准确性和完整性
</analysis>

输出要求:
- 结构化:使用9个章节格式
- 关键:只保留重要信息,忽略无关细节
- 敏感:必须保留用户的敏感指令(禁止、必须等)— 原样保留
- 任务:必须保留未完成的待办事项
- 决策:必须保留关键方案选择和理由

9章节输出格式:
【摘要】一句话概括主要工作(50字以内)
【已完成】列出已完成的操作(工具调用、文件变更)
【未完成】列出待办任务和阻塞项 — 最关键,压缩后恢复需要
【关键决策】重要选择及理由(技术选型、方案决策)
【敏感指令】用户的禁止/必须指令(必须原样保留)
【技术栈】使用的语言、框架、库、工具
【文件变更】读取、修改、创建的文件路径
【问题记录】遇到的问题及解决方案
【下一步】建议的下一步操作(直接引用最近对话展示任务中断点)

每章节控制在100字以内,空章节可省略。
输出摘要后立即停止,不要添加任何解释或后续建议。

REMINDER: 不要调用任何工具。仅用纯文本响应。"#;
85
86#[async_trait]
87impl Compressor for AiCompressor {
88 async fn summarize(
89 &self,
90 messages: &[Message],
91 _config: &CompressionConfig,
92 ) -> Result<SummarizedSegment> {
93 let prompt = build_summary_prompt(messages);
94
95 let request = ChatRequest {
96 messages: vec![Message {
97 role: Role::User,
98 content: MessageContent::Text(prompt),
99 }],
100 tools: vec![],
101 system: Some(SUMMARY_SYSTEM_PROMPT.to_string()),
102 think: false,
103 max_tokens: 1024,
104 server_tools: vec![],
105 enable_caching: false,
106 };
107
108 let response = self.provider.chat(request).await?;
109 let summary_text = extract_text_from_response(&response);
110 let (summary, key_points) = parse_summary_response(&summary_text);
111
112 Ok(SummarizedSegment {
113 time_range: (chrono::Utc::now(), chrono::Utc::now()),
114 original_count: messages.len(),
115 summary,
116 key_points,
117 })
118 }
119
120 fn model_name(&self) -> &str {
121 &self.model
122 }
123}
124
125fn extract_text_from_response(response: &ChatResponse) -> String {
126 response
127 .content
128 .iter()
129 .filter_map(|block| {
130 if let ContentBlock::Text { text } = block {
131 Some(text.clone())
132 } else {
133 None
134 }
135 })
136 .collect::<Vec<_>>()
137 .join("\n")
138}
139
140fn parse_summary_response(text: &str) -> (String, Vec<String>) {
141 let mut summary = String::new();
142 let mut key_points: Vec<String> = Vec::new();
143
144 let sections = [
146 "【摘要】", "【已完成】", "【未完成】", "【关键决策】",
147 "【敏感指令】", "【技术栈】", "【文件变更】", "【问题记录】", "【下一步】"
148 ];
149
150 for line in text.lines() {
151 let line = line.trim();
152
153 let is_header = sections.iter().any(|s| line.starts_with(s));
155
156 if is_header {
157 for section in §ions {
159 if line.starts_with(section) {
160 let replaced = line.replace(section, "");
161 let content = replaced.trim();
162 if !content.is_empty() {
163 if *section == "【摘要】" {
164 summary = content.to_string();
165 } else {
166 key_points.push(format!("{}{}", section, content));
167 }
168 }
169 break;
170 }
171 }
172 } else if !line.is_empty() {
173 if line.starts_with("•") || line.starts_with("-") || line.starts_with("*") {
175 let point = line.trim_start_matches(['•', '-', '*']).trim();
176 if !point.is_empty() {
177 key_points.push(point.to_string());
178 }
179 } else if summary.is_empty() {
180 summary = line.to_string();
182 }
183 }
184 }
185
186 if summary.is_empty() && !text.is_empty() {
188 summary = text.lines().take(3).collect::<Vec<_>>().join(" ");
189 if summary.len() > 200 {
190 summary = truncate_with_suffix(&summary, 200);
191 }
192 }
193
194 (summary, key_points)
195}
196
197pub fn compress_messages(
203 messages: &[Message],
204 strategy: CompressionStrategy,
205 config: &CompressionConfig,
206) -> Result<Vec<Message>> {
207 match strategy {
208 CompressionStrategy::Truncate => truncate_compress(messages, config),
209 CompressionStrategy::SlidingWindow => sliding_window_compress(messages, config),
210 CompressionStrategy::Summarize => sliding_window_compress(messages, config),
211 CompressionStrategy::BiasBased => compress_with_bias(messages, config),
212 }
213}
214
215pub fn compress_with_bias(
217 messages: &[Message],
218 config: &CompressionConfig,
219) -> Result<Vec<Message>> {
220 if messages.len() <= config.min_preserve_messages {
221 return Ok(messages.to_vec());
222 }
223
224 let scored: Vec<(usize, Message, f64)> = messages
225 .iter()
226 .enumerate()
227 .map(|(idx, msg)| {
228 (
229 idx,
230 msg.clone(),
231 calculate_preservation_score(msg, idx, messages.len(), &config.bias),
232 )
233 })
234 .collect();
235
236 let mut scored_with_recency: Vec<(usize, Message, f64)> = scored
237 .into_iter()
238 .map(|(idx, msg, score)| {
239 let recency_bonus = if idx >= messages.len() - config.min_preserve_messages {
240 100.0
241 } else {
242 (idx as f64 / messages.len() as f64) * 20.0
243 };
244 (idx, msg, score + recency_bonus)
245 })
246 .collect();
247
248 scored_with_recency.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
249
250 let target_count = if config.bias.aggressive {
251 config.min_preserve_messages
252 } else {
253 let estimated = estimate_total_tokens(messages);
254 let target_tokens = (estimated as f64 * config.target_ratio) as u32;
255 let avg = estimated / messages.len() as u32;
256 (target_tokens / avg.max(1)) as usize
257 };
258
259 let to_keep: HashSet<usize> = scored_with_recency
260 .iter()
261 .take(target_count)
262 .map(|(idx, _, _)| *idx)
263 .collect();
264
265 let compressed: Vec<Message> = messages
266 .iter()
267 .enumerate()
268 .filter(|(idx, _)| to_keep.contains(idx))
269 .map(|(_, msg)| msg.clone())
270 .collect();
271
272 Ok(compressed)
273}
274
275fn calculate_preservation_score(
276 message: &Message,
277 index: usize,
278 _total: usize, bias: &CompressionBias,
280) -> f64 {
281 let mut score: f64 = 10.0;
282
283 if index == 0 {
285 score += 100.0;
286 }
287
288 match message.role {
289 Role::User => {
290 if bias.preserve_user_questions {
291 score += 30.0;
292 }
293 }
294 Role::Assistant => {
295 score += 5.0;
296 }
297 Role::Tool => {
298 if bias.preserve_tools {
299 score += 25.0;
300 }
301 }
302 Role::System => {
303 score += 40.0;
304 }
305 }
306
307 match &message.content {
308 MessageContent::Text(text) => {
309 for keyword in &bias.preserve_keywords {
310 if text.to_lowercase().contains(&keyword.to_lowercase()) {
311 score += 15.0;
312 }
313 }
314 if contains_sensitive_instructions(text) {
315 score += 50.0;
316 }
317 }
318 MessageContent::Blocks(blocks) => {
319 for block in blocks {
320 match block {
321 ContentBlock::ToolUse { name, .. } => {
322 if bias.preserve_tools {
323 score += 20.0;
324 }
325 if name == "write" || name == "edit" || name == "bash" {
326 score += 10.0;
327 }
328 if name == "todo_write" {
330 score += 60.0;
331 }
332 if name == "ask" {
334 score += 50.0;
335 }
336 }
337 ContentBlock::ToolResult { content, .. } => {
338 if bias.preserve_tools {
339 score += 20.0;
340 }
341 if contains_sensitive_instructions(content) {
342 score += 30.0;
343 }
344 if content.contains("TodoWrite") || content.contains("todo") {
346 score += 40.0;
347 }
348 if content.contains("AskUserQuestion") || content.contains("answer") {
350 score += 30.0;
351 }
352 }
353 ContentBlock::Thinking { .. } => {
354 if bias.preserve_thinking {
355 score += 25.0;
356 } else {
357 score -= 5.0;
358 }
359 }
360 ContentBlock::Text { text } => {
361 if contains_sensitive_instructions(text) {
362 score += 50.0;
363 }
364 }
365 _ => {}
366 }
367 }
368 }
369 }
370
371 score
372}
373
/// Reports whether `text` contains a prohibition or obligation marker.
///
/// The text is lowercased first, so the English markers (`never`, `must not`,
/// `do not`) match case-insensitively. The Chinese markers (不要, 禁止, 必须,
/// 不允许) match as substrings.
fn contains_sensitive_instructions(text: &str) -> bool {
    const MARKERS: [&str; 7] = ["不要", "禁止", "必须", "不允许", "never", "must not", "do not"];

    let haystack = text.to_lowercase();
    for marker in MARKERS.iter() {
        if haystack.contains(*marker) {
            return true;
        }
    }
    false
}
387
388fn truncate_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
389 if messages.len() <= config.min_preserve_messages {
390 return Ok(messages.to_vec());
391 }
392 Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
393}
394
395fn sliding_window_compress(
396 messages: &[Message],
397 config: &CompressionConfig,
398) -> Result<Vec<Message>> {
399 if messages.len() <= config.min_preserve_messages {
400 return Ok(messages.to_vec());
401 }
402
403 let first_msg = messages.first().cloned();
409 let recent_start = messages.len().saturating_sub(config.min_preserve_messages);
410 let recent_msgs = &messages[recent_start..];
411
412 let first_tokens = first_msg.as_ref().map(estimate_tokens).unwrap_or(0);
414 let recent_tokens = estimate_total_tokens(recent_msgs);
415 let current_total = estimate_total_tokens(messages);
416 let target_tokens = (current_total as f64 * config.target_ratio) as u32;
417
418 if first_tokens + recent_tokens <= target_tokens {
420 let mut result: Vec<Message> = Vec::new();
422 if let Some(first) = first_msg {
423 result.push(first);
424 }
425 result.extend(recent_msgs.iter().cloned());
426 return Ok(result);
427 }
428
429 for drop_count in 0..recent_msgs.len() {
431 let candidate = &recent_msgs[drop_count..];
432 if estimate_total_tokens(candidate) <= target_tokens {
433 return Ok(candidate.to_vec());
434 }
435 }
436
437 Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
439}
440
441pub fn estimate_tokens(message: &Message) -> u32 {
447 let (ascii, non_ascii) = match &message.content {
448 MessageContent::Text(t) => count_chars(t),
449 MessageContent::Blocks(blocks) => {
450 let mut a = 0u32;
451 let mut n = 0u32;
452 for block in blocks {
453 match block {
454 ContentBlock::Text { text } => {
455 let (ca, cn) = count_chars(text);
456 a += ca;
457 n += cn;
458 }
459 ContentBlock::ToolUse { name, input, .. } => {
460 let (ca, cn) = count_chars(name);
461 a += ca;
462 n += cn;
463 let (ja, jn) = count_chars(&input.to_string());
464 a += ja;
465 n += jn;
466 }
467 ContentBlock::ToolResult { content, .. } => {
468 let (ca, cn) = count_chars(content);
469 a += ca;
470 n += cn;
471 }
472 ContentBlock::Thinking { thinking, .. } => {
473 let (ca, cn) = count_chars(thinking);
474 a += ca;
475 n += cn;
476 }
477 _ => {}
478 }
479 }
480 (a, n)
481 }
482 };
483
484 let ascii_tokens = (ascii as f64 * 0.25).ceil() as u32;
485 let non_ascii_tokens = (non_ascii as f64 * 0.67).ceil() as u32;
486 (ascii_tokens + non_ascii_tokens + 10).max(1)
487}
488
/// Counts the characters of `s` as `(ascii, non_ascii)`.
fn count_chars(s: &str) -> (u32, u32) {
    s.chars().fold((0u32, 0u32), |(ascii, other), ch| {
        if ch.is_ascii() {
            (ascii + 1, other)
        } else {
            (ascii, other + 1)
        }
    })
}
501
502pub fn estimate_total_tokens(messages: &[Message]) -> u32 {
504 messages.iter().map(estimate_tokens).sum()
505}
506
507pub fn should_compress(
509 current_tokens: u32,
510 context_size: Option<u32>,
511 config: &CompressionConfig,
512) -> bool {
513 match context_size {
514 Some(size) => (current_tokens as f64 / size as f64) >= config.threshold,
515 None => false,
516 }
517}
518
519pub fn build_summary_prompt(messages: &[Message]) -> String {
521 let history = messages
522 .iter()
523 .map(|m| {
524 let role = match m.role {
525 Role::User => "用户",
526 Role::Assistant => "助手",
527 Role::Tool => "工具",
528 Role::System => "系统",
529 };
530 let preview = match &m.content {
531 MessageContent::Text(t) => truncate_with_suffix(t, 200),
532 MessageContent::Blocks(blocks) => blocks
533 .iter()
534 .map(|b| match b {
535 ContentBlock::Text { text } => truncate_with_suffix(text, 100),
536 ContentBlock::ToolUse { name, .. } => format!("[工具: {}]", name),
537 ContentBlock::ToolResult { content, .. } => {
538 truncate_with_suffix(content, 100)
539 }
540 _ => "[...]".to_string(),
541 })
542 .collect::<Vec<_>>()
543 .join(" | "),
544 };
545 format!("{}: {}", role, preview)
546 })
547 .collect::<Vec<_>>()
548 .join("\n");
549
550 format!(
551 "请将以下对话压缩为简洁摘要({} 条消息):\n{}",
552 messages.len(),
553 history
554 )
555}
556
557use super::pipeline::CompressionPipeline;
562use super::types::AiCompressionMode;
563
564pub async fn compress_messages_with_ai(
569 messages: &[Message],
570 config: &CompressionConfig,
571 ai_mode: AiCompressionMode,
572 fast_model: Option<Box<dyn Provider>>,
573 token_usage: u32,
574 context_window: u32,
575) -> Result<Vec<Message>> {
576 let mut pipeline = match (ai_mode, fast_model) {
577 (AiCompressionMode::None, _) => CompressionPipeline::new_rule_only(config.clone()),
578 (AiCompressionMode::Light | AiCompressionMode::Deep, Some(model)) => {
579 CompressionPipeline::new_with_ai(config.clone(), model)
580 }
581 _ => CompressionPipeline::new_rule_only(config.clone()),
582 };
583
584 let result = pipeline.execute(messages, ai_mode, token_usage, context_window).await?;
585 Ok(result.messages)
586}
587
588pub async fn compress_messages_with_full_ai(
592 messages: &[Message],
593 config: &CompressionConfig,
594 ai_mode: AiCompressionMode,
595 fast_model: Box<dyn Provider>,
596 main_model: Box<dyn Provider>,
597 token_usage: u32,
598 context_window: u32,
599) -> Result<Vec<Message>> {
600 let mut pipeline = CompressionPipeline::new_with_full_ai(
601 config.clone(),
602 fast_model,
603 main_model,
604 );
605
606 let result = pipeline.execute(messages, ai_mode, token_usage, context_window).await?;
607 Ok(result.messages)
608}
609
610pub fn score_messages_only(
614 messages: &[Message],
615 config: &CompressionConfig,
616) -> Vec<super::types::ScoredMessage> {
617 let pipeline = CompressionPipeline::new_rule_only(config.clone());
618 pipeline.score_only(messages)
619}
620
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain-text user message for the tests below.
    fn text_message(text: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_estimate_tokens_simple() {
        let msg = text_message("Hello world");
        assert!(estimate_tokens(&msg) >= 3);
    }

    #[test]
    fn test_estimate_tokens_weights_non_ascii_higher() {
        // Same character count, but non-ASCII characters carry a higher weight.
        let ascii = text_message("abcd");
        let cjk = text_message("你好世界");
        assert!(estimate_tokens(&cjk) > estimate_tokens(&ascii));
    }

    #[test]
    fn test_estimate_total_tokens_empty() {
        assert_eq!(estimate_total_tokens(&[]), 0);
    }

    #[test]
    fn test_count_chars_splits_ascii_and_non_ascii() {
        assert_eq!(count_chars(""), (0, 0));
        assert_eq!(count_chars("你好, world"), (7, 2));
    }

    #[test]
    fn test_contains_sensitive_instructions() {
        assert!(contains_sensitive_instructions("请不要修改配置"));
        assert!(contains_sensitive_instructions("You MUST NOT push to main"));
        assert!(!contains_sensitive_instructions("please refactor the parser"));
    }

    #[test]
    fn test_should_compress() {
        let config = CompressionConfig::default();
        assert!(should_compress(100_000, Some(200_000), &config));
        assert!(!should_compress(80_000, Some(200_000), &config));
    }

    #[test]
    fn test_should_compress_without_context_size() {
        // An unknown context window never triggers compression.
        let config = CompressionConfig::default();
        assert!(!should_compress(u32::MAX, None, &config));
    }
}