1use crate::providers::{
4 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
5};
6use crate::truncate::truncate_with_suffix;
7use anyhow::Result;
8use async_trait::async_trait;
9use std::collections::HashSet;
10
11use super::config::{CompressionBias, CompressionConfig};
12use super::types::{CompressionStrategy, SummarizedSegment};
13
/// Abstraction over anything that can condense a slice of conversation
/// messages into a [`SummarizedSegment`].
///
/// Implementors must be `Send + Sync`; `#[async_trait]` is what allows the
/// trait to declare an `async fn`.
#[async_trait]
pub trait Compressor: Send + Sync {
    /// Summarizes `messages`, with `config` available to the implementation.
    ///
    /// # Errors
    /// Returns whatever error the implementation reports when it cannot
    /// produce a summary.
    async fn summarize(
        &self,
        messages: &[Message],
        config: &CompressionConfig,
    ) -> Result<SummarizedSegment>;

    /// Returns the model name this compressor reports for itself.
    fn model_name(&self) -> &str;
}
31
/// [`Compressor`] that asks a chat [`Provider`] to write the summary.
pub struct AiCompressor {
    /// Provider the summarization request is sent through.
    provider: Box<dyn Provider>,
    /// Name returned by `model_name`.
    model: String,
}
37
impl AiCompressor {
    /// Creates a compressor that talks to `provider`.
    ///
    /// `model` is only stored and handed back by `model_name`; the request
    /// built in `summarize` does not carry it.
    pub fn new(provider: Box<dyn Provider>, model: String) -> Self {
        Self { provider, model }
    }
}
43
/// System prompt sent with every summarization request.
///
/// The prompt is written in Chinese. It forbids tool calls, asks the model to
/// put its reasoning inside `<analysis>` tags, and requests a nine-section
/// summary whose `【…】` headers are the ones `parse_summary_response`
/// looks for. The text is runtime data: changing it changes model behavior.
const SUMMARY_SYSTEM_PROMPT: &str = r#"CRITICAL: 仅用文本响应。不要调用任何工具。

- 不要使用 read、bash、grep、glob、edit、write 或任何其他工具
- 你已在上方对话中获得所需的所有上下文
- 工具调用将被拒绝并浪费你唯一的 turn — 你将失败任务
- 你的整个响应必须是纯文本摘要

---

你是一个对话历史压缩助手。将对话压缩为结构化摘要。

在提供最终摘要前,将你的分析包裹在 <analysis> 标签中以组织思路:
<analysis>
1. 按时间顺序分析每条消息
2. 识别:用户请求、助手行动、关键决策、错误及修复
3. 特别注意用户的敏感指令(禁止/必须)
4. 双重检查技术准确性和完整性
</analysis>

输出要求:
- 结构化:使用9个章节格式
- 关键:只保留重要信息,忽略无关细节
- 敏感:必须保留用户的敏感指令(禁止、必须等)— 原样保留
- 任务:必须保留未完成的待办事项
- 决策:必须保留关键方案选择和理由

9章节输出格式:
【摘要】一句话概括主要工作(50字以内)
【已完成】列出已完成的操作(工具调用、文件变更)
【未完成】列出待办任务和阻塞项 — 最关键,压缩后恢复需要
【关键决策】重要选择及理由(技术选型、方案决策)
【敏感指令】用户的禁止/必须指令(必须原样保留)
【技术栈】使用的语言、框架、库、工具
【文件变更】读取、修改、创建的文件路径
【问题记录】遇到的问题及解决方案
【下一步】建议的下一步操作(直接引用最近对话展示任务中断点)

每章节控制在100字以内,空章节可省略。
输出摘要后立即停止,不要添加任何解释或后续建议。

REMINDER: 不要调用任何工具。仅用纯文本响应。"#;
85
86#[async_trait]
87impl Compressor for AiCompressor {
88 async fn summarize(
89 &self,
90 messages: &[Message],
91 _config: &CompressionConfig,
92 ) -> Result<SummarizedSegment> {
93 let prompt = build_summary_prompt(messages);
94
95 let request = ChatRequest {
96 messages: vec![Message {
97 role: Role::User,
98 content: MessageContent::Text(prompt),
99 }],
100 tools: vec![],
101 system: Some(SUMMARY_SYSTEM_PROMPT.to_string()),
102 think: false,
103 max_tokens: 1024,
104 server_tools: vec![],
105 enable_caching: false,
106 };
107
108 let response = self.provider.chat(request).await?;
109 let summary_text = extract_text_from_response(&response);
110 let (summary, key_points) = parse_summary_response(&summary_text);
111
112 Ok(SummarizedSegment {
113 time_range: (chrono::Utc::now(), chrono::Utc::now()),
114 original_count: messages.len(),
115 summary,
116 key_points,
117 })
118 }
119
120 fn model_name(&self) -> &str {
121 &self.model
122 }
123}
124
125fn extract_text_from_response(response: &ChatResponse) -> String {
126 response
127 .content
128 .iter()
129 .filter_map(|block| {
130 if let ContentBlock::Text { text } = block {
131 Some(text.clone())
132 } else {
133 None
134 }
135 })
136 .collect::<Vec<_>>()
137 .join("\n")
138}
139
140fn parse_summary_response(text: &str) -> (String, Vec<String>) {
141 let mut summary = String::new();
142 let mut key_points: Vec<String> = Vec::new();
143
144 let sections = [
146 "【摘要】",
147 "【已完成】",
148 "【未完成】",
149 "【关键决策】",
150 "【敏感指令】",
151 "【技术栈】",
152 "【文件变更】",
153 "【问题记录】",
154 "【下一步】",
155 ];
156
157 for line in text.lines() {
158 let line = line.trim();
159
160 let is_header = sections.iter().any(|s| line.starts_with(s));
162
163 if is_header {
164 for section in §ions {
166 if line.starts_with(section) {
167 let replaced = line.replace(section, "");
168 let content = replaced.trim();
169 if !content.is_empty() {
170 if *section == "【摘要】" {
171 summary = content.to_string();
172 } else {
173 key_points.push(format!("{}{}", section, content));
174 }
175 }
176 break;
177 }
178 }
179 } else if !line.is_empty() {
180 if line.starts_with("•") || line.starts_with("-") || line.starts_with("*") {
182 let point = line.trim_start_matches(['•', '-', '*']).trim();
183 if !point.is_empty() {
184 key_points.push(point.to_string());
185 }
186 } else if summary.is_empty() {
187 summary = line.to_string();
189 }
190 }
191 }
192
193 if summary.is_empty() && !text.is_empty() {
195 summary = text.lines().take(3).collect::<Vec<_>>().join(" ");
196 if summary.len() > 200 {
197 summary = truncate_with_suffix(&summary, 200);
198 }
199 }
200
201 (summary, key_points)
202}
203
204pub fn compress_messages(
210 messages: &[Message],
211 strategy: CompressionStrategy,
212 config: &CompressionConfig,
213) -> Result<Vec<Message>> {
214 match strategy {
215 CompressionStrategy::Truncate => truncate_compress(messages, config),
216 CompressionStrategy::SlidingWindow => sliding_window_compress(messages, config),
217 CompressionStrategy::Summarize => sliding_window_compress(messages, config),
218 CompressionStrategy::BiasBased => compress_with_bias(messages, config),
219 }
220}
221
222pub fn compress_with_bias(
224 messages: &[Message],
225 config: &CompressionConfig,
226) -> Result<Vec<Message>> {
227 if messages.len() <= config.min_preserve_messages {
228 return Ok(messages.to_vec());
229 }
230
231 let scored: Vec<(usize, Message, f64)> = messages
232 .iter()
233 .enumerate()
234 .map(|(idx, msg)| {
235 (
236 idx,
237 msg.clone(),
238 calculate_preservation_score(msg, idx, messages.len(), &config.bias),
239 )
240 })
241 .collect();
242
243 let mut scored_with_recency: Vec<(usize, Message, f64)> = scored
244 .into_iter()
245 .map(|(idx, msg, score)| {
246 let recency_bonus = if idx >= messages.len() - config.min_preserve_messages {
247 100.0
248 } else {
249 (idx as f64 / messages.len() as f64) * 20.0
250 };
251 (idx, msg, score + recency_bonus)
252 })
253 .collect();
254
255 scored_with_recency.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
256
257 let target_count = if config.bias.aggressive {
258 config.min_preserve_messages
259 } else {
260 let estimated = estimate_total_tokens(messages);
261 let target_tokens = (estimated as f64 * config.target_ratio) as u32;
262 let avg = estimated / messages.len() as u32;
263 (target_tokens / avg.max(1)) as usize
264 };
265
266 let to_keep: HashSet<usize> = scored_with_recency
267 .iter()
268 .take(target_count)
269 .map(|(idx, _, _)| *idx)
270 .collect();
271
272 let compressed: Vec<Message> = messages
273 .iter()
274 .enumerate()
275 .filter(|(idx, _)| to_keep.contains(idx))
276 .map(|(_, msg)| msg.clone())
277 .collect();
278
279 Ok(compressed)
280}
281
282fn calculate_preservation_score(
283 message: &Message,
284 index: usize,
285 _total: usize, bias: &CompressionBias,
287) -> f64 {
288 let mut score: f64 = 10.0;
289
290 if index == 0 {
292 score += 100.0;
293 }
294
295 match message.role {
296 Role::User => {
297 if bias.preserve_user_questions {
298 score += 30.0;
299 }
300 }
301 Role::Assistant => {
302 score += 5.0;
303 }
304 Role::Tool => {
305 if bias.preserve_tools {
306 score += 25.0;
307 }
308 }
309 Role::System => {
310 score += 40.0;
311 }
312 }
313
314 match &message.content {
315 MessageContent::Text(text) => {
316 for keyword in &bias.preserve_keywords {
317 if text.to_lowercase().contains(&keyword.to_lowercase()) {
318 score += 15.0;
319 }
320 }
321 if contains_sensitive_instructions(text) {
322 score += 50.0;
323 }
324 }
325 MessageContent::Blocks(blocks) => {
326 for block in blocks {
327 match block {
328 ContentBlock::ToolUse { name, .. } => {
329 if bias.preserve_tools {
330 score += 20.0;
331 }
332 if name == "write" || name == "edit" || name == "bash" {
333 score += 10.0;
334 }
335 if name == "todo_write" {
337 score += 60.0;
338 }
339 if name == "ask" {
341 score += 50.0;
342 }
343 }
344 ContentBlock::ToolResult { content, .. } => {
345 if bias.preserve_tools {
346 score += 20.0;
347 }
348 if contains_sensitive_instructions(content) {
349 score += 30.0;
350 }
351 if content.contains("TodoWrite") || content.contains("todo") {
353 score += 40.0;
354 }
355 if content.contains("AskUserQuestion") || content.contains("answer") {
357 score += 30.0;
358 }
359 }
360 ContentBlock::Thinking { .. } => {
361 if bias.preserve_thinking {
362 score += 25.0;
363 } else {
364 score -= 5.0;
365 }
366 }
367 ContentBlock::Text { text } => {
368 if contains_sensitive_instructions(text) {
369 score += 50.0;
370 }
371 }
372 _ => {}
373 }
374 }
375 }
376 }
377
378 score
379}
380
/// Reports whether `text` contains one of the prohibition/obligation markers
/// (`不要`, `禁止`, `必须`, `不允许`, `never`, `must not`, `do not`).
///
/// The check is a plain substring search on the lowercased text, so it also
/// fires inside longer words (for example `do nothing` contains `do not`).
fn contains_sensitive_instructions(text: &str) -> bool {
    const MARKERS: [&str; 7] = [
        "不要",
        "禁止",
        "必须",
        "不允许",
        "never",
        "must not",
        "do not",
    ];

    let haystack = text.to_lowercase();
    for marker in MARKERS.iter() {
        if haystack.contains(*marker) {
            return true;
        }
    }
    false
}
394
395fn truncate_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
396 if messages.len() <= config.min_preserve_messages {
397 return Ok(messages.to_vec());
398 }
399 Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
400}
401
402fn sliding_window_compress(
403 messages: &[Message],
404 config: &CompressionConfig,
405) -> Result<Vec<Message>> {
406 if messages.len() <= config.min_preserve_messages {
407 return Ok(messages.to_vec());
408 }
409
410 let first_msg = messages.first().cloned();
416 let recent_start = messages.len().saturating_sub(config.min_preserve_messages);
417 let recent_msgs = &messages[recent_start..];
418
419 let first_tokens = first_msg.as_ref().map(estimate_tokens).unwrap_or(0);
421 let recent_tokens = estimate_total_tokens(recent_msgs);
422 let current_total = estimate_total_tokens(messages);
423 let target_tokens = (current_total as f64 * config.target_ratio) as u32;
424
425 if first_tokens + recent_tokens <= target_tokens {
427 let mut result: Vec<Message> = Vec::new();
429 if let Some(first) = first_msg {
430 result.push(first);
431 }
432 result.extend(recent_msgs.iter().cloned());
433 return Ok(result);
434 }
435
436 for drop_count in 0..recent_msgs.len() {
438 let candidate = &recent_msgs[drop_count..];
439 if estimate_total_tokens(candidate) <= target_tokens {
440 return Ok(candidate.to_vec());
441 }
442 }
443
444 Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
446}
447
448pub fn estimate_tokens(message: &Message) -> u32 {
454 let (ascii, non_ascii) = match &message.content {
455 MessageContent::Text(t) => count_chars(t),
456 MessageContent::Blocks(blocks) => {
457 let mut a = 0u32;
458 let mut n = 0u32;
459 for block in blocks {
460 match block {
461 ContentBlock::Text { text } => {
462 let (ca, cn) = count_chars(text);
463 a += ca;
464 n += cn;
465 }
466 ContentBlock::ToolUse { name, input, .. } => {
467 let (ca, cn) = count_chars(name);
468 a += ca;
469 n += cn;
470 let (ja, jn) = count_chars(&input.to_string());
471 a += ja;
472 n += jn;
473 }
474 ContentBlock::ToolResult { content, .. } => {
475 let (ca, cn) = count_chars(content);
476 a += ca;
477 n += cn;
478 }
479 ContentBlock::Thinking { thinking, .. } => {
480 let (ca, cn) = count_chars(thinking);
481 a += ca;
482 n += cn;
483 }
484 _ => {}
485 }
486 }
487 (a, n)
488 }
489 };
490
491 let ascii_tokens = (ascii as f64 * 0.25).ceil() as u32;
492 let non_ascii_tokens = (non_ascii as f64 * 0.67).ceil() as u32;
493 (ascii_tokens + non_ascii_tokens + 10).max(1)
494}
495
/// Counts the characters of `s`, returning `(ascii, non_ascii)`.
///
/// Counting is per Unicode scalar value, not per byte.
fn count_chars(s: &str) -> (u32, u32) {
    s.chars().fold((0u32, 0u32), |(ascii, other), ch| {
        if ch.is_ascii() {
            (ascii + 1, other)
        } else {
            (ascii, other + 1)
        }
    })
}
508
509pub fn estimate_total_tokens(messages: &[Message]) -> u32 {
511 messages.iter().map(estimate_tokens).sum()
512}
513
514pub fn should_compress(
516 current_tokens: u32,
517 context_size: Option<u32>,
518 config: &CompressionConfig,
519) -> bool {
520 match context_size {
521 Some(size) => (current_tokens as f64 / size as f64) >= config.threshold,
522 None => false,
523 }
524}
525
526pub fn build_summary_prompt(messages: &[Message]) -> String {
528 let history = messages
529 .iter()
530 .map(|m| {
531 let role = match m.role {
532 Role::User => "用户",
533 Role::Assistant => "助手",
534 Role::Tool => "工具",
535 Role::System => "系统",
536 };
537 let preview = match &m.content {
538 MessageContent::Text(t) => truncate_with_suffix(t, 200),
539 MessageContent::Blocks(blocks) => blocks
540 .iter()
541 .map(|b| match b {
542 ContentBlock::Text { text } => truncate_with_suffix(text, 100),
543 ContentBlock::ToolUse { name, .. } => format!("[工具: {}]", name),
544 ContentBlock::ToolResult { content, .. } => {
545 truncate_with_suffix(content, 100)
546 }
547 _ => "[...]".to_string(),
548 })
549 .collect::<Vec<_>>()
550 .join(" | "),
551 };
552 format!("{}: {}", role, preview)
553 })
554 .collect::<Vec<_>>()
555 .join("\n");
556
557 format!(
558 "请将以下对话压缩为简洁摘要({} 条消息):\n{}",
559 messages.len(),
560 history
561 )
562}
563
564use super::pipeline::CompressionPipeline;
569use super::types::AiCompressionMode;
570
571pub async fn compress_messages_with_ai(
576 messages: &[Message],
577 config: &CompressionConfig,
578 ai_mode: AiCompressionMode,
579 fast_model: Option<Box<dyn Provider>>,
580 token_usage: u32,
581 context_window: u32,
582) -> Result<Vec<Message>> {
583 let mut pipeline = match (ai_mode, fast_model) {
584 (AiCompressionMode::None, _) => CompressionPipeline::new_rule_only(config.clone()),
585 (AiCompressionMode::Light | AiCompressionMode::Deep, Some(model)) => {
586 CompressionPipeline::new_with_ai(config.clone(), model)
587 }
588 _ => CompressionPipeline::new_rule_only(config.clone()),
589 };
590
591 let result = pipeline
592 .execute(messages, ai_mode, token_usage, context_window)
593 .await?;
594 Ok(result.messages)
595}
596
597pub async fn compress_messages_with_full_ai(
601 messages: &[Message],
602 config: &CompressionConfig,
603 ai_mode: AiCompressionMode,
604 fast_model: Box<dyn Provider>,
605 main_model: Box<dyn Provider>,
606 token_usage: u32,
607 context_window: u32,
608) -> Result<Vec<Message>> {
609 let mut pipeline =
610 CompressionPipeline::new_with_full_ai(config.clone(), fast_model, main_model);
611
612 let result = pipeline
613 .execute(messages, ai_mode, token_usage, context_window)
614 .await?;
615 Ok(result.messages)
616}
617
618pub fn score_messages_only(
622 messages: &[Message],
623 config: &CompressionConfig,
624) -> Vec<super::types::ScoredMessage> {
625 let pipeline = CompressionPipeline::new_rule_only(config.clone());
626 pipeline.score_only(messages)
627}
628
// Unit tests for the token estimator and the compression trigger.
#[cfg(test)]
mod tests {
    use super::*;

    /// A short ASCII message must be estimated at no fewer than 3 tokens.
    #[test]
    fn test_estimate_tokens_simple() {
        let msg = Message {
            role: Role::User,
            content: MessageContent::Text("Hello world".to_string()),
        };
        // 11 ASCII chars -> ceil(11 * 0.25) = 3, plus the fixed +10 added per message.
        assert!(estimate_tokens(&msg) >= 3);
    }

    /// With the default config, 50% usage triggers compression and 40% does not.
    #[test]
    fn test_should_compress() {
        let config = CompressionConfig::default();
        // 100_000 / 200_000 = 0.5
        assert!(should_compress(100_000, Some(200_000), &config));
        // 80_000 / 200_000 = 0.4; together these pin the default threshold to (0.4, 0.5].
        assert!(!should_compress(80_000, Some(200_000), &config));
    }
}