1use crate::providers::{
4 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
5};
6use crate::tokenizer::{count_tokens, message_overhead};
7use crate::truncate::truncate_with_suffix;
8use anyhow::Result;
9use async_trait::async_trait;
10use std::collections::HashSet;
11
12use super::dependency::DependencyBuilder;
13
14use super::config::{CompressionBias, CompressionConfig};
15use super::types::{CompressionStrategy, SummarizedSegment};
16
/// Abstraction over anything that can turn a slice of conversation messages
/// into a [`SummarizedSegment`].
///
/// Implementors must be `Send + Sync` so a compressor can be shared across
/// async tasks.
#[async_trait]
pub trait Compressor: Send + Sync {
    /// Summarizes `messages` into a single segment.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementation fails to produce a summary.
    async fn summarize(
        &self,
        messages: &[Message],
        config: &CompressionConfig,
    ) -> Result<SummarizedSegment>;

    /// Returns the name of the model backing this compressor.
    fn model_name(&self) -> &str;
}
34
35pub struct AiCompressor {
37 provider: Box<dyn Provider>,
38 model: String,
39 hardcode_config: crate::compress::hardcode_config::HardcodeConfig,
40}
41
42impl AiCompressor {
43 pub fn new(provider: Box<dyn Provider>, model: String) -> Self {
44 Self {
45 provider,
46 model,
47 hardcode_config: crate::compress::hardcode_config::HardcodeConfig::default(),
48 }
49 }
50
51 pub fn with_hardcode_config(mut self, config: crate::compress::hardcode_config::HardcodeConfig) -> Self {
53 self.hardcode_config = config;
54 self
55 }
56}
57
/// System prompt (in Chinese) sent with every summarization request.
///
/// It instructs the model to answer with plain text only (no tool calls), to
/// organize its reasoning inside `<analysis>` tags, and then to emit a
/// structured summary made of nine `【…】` sections, each kept short, with
/// sensitive user instructions and unfinished tasks preserved.
///
/// The `【…】` section headers listed here are the same ones that
/// `parse_summary_response` matches on, so the two must be kept in sync.
const SUMMARY_SYSTEM_PROMPT: &str = r#"CRITICAL: 仅用文本响应。不要调用任何工具。

- 不要使用 read、bash、grep、glob、edit、write 或任何其他工具
- 你已在上方对话中获得所需的所有上下文
- 工具调用将被拒绝并浪费你唯一的 turn — 你将失败任务
- 你的整个响应必须是纯文本摘要

---

你是一个对话历史压缩助手。将对话压缩为结构化摘要。

在提供最终摘要前,将你的分析包裹在 <analysis> 标签中以组织思路:
<analysis>
1. 按时间顺序分析每条消息
2. 识别:用户请求、助手行动、关键决策、错误及修复
3. 特别注意用户的敏感指令(禁止/必须)
4. 双重检查技术准确性和完整性
</analysis>

输出要求:
- 结构化:使用9个章节格式
- 关键:只保留重要信息,忽略无关细节
- 敏感:必须保留用户的敏感指令(禁止、必须等)— 原样保留
- 任务:必须保留未完成的待办事项
- 决策:必须保留关键方案选择和理由

9章节输出格式:
【摘要】一句话概括主要工作(50字以内)
【已完成】列出已完成的操作(工具调用、文件变更)
【未完成】列出待办任务和阻塞项 — 最关键,压缩后恢复需要
【关键决策】重要选择及理由(技术选型、方案决策)
【敏感指令】用户的禁止/必须指令(必须原样保留)
【技术栈】使用的语言、框架、库、工具
【文件变更】读取、修改、创建的文件路径
【问题记录】遇到的问题及解决方案
【下一步】建议的下一步操作(直接引用最近对话展示任务中断点)

每章节控制在100字以内,空章节可省略。
输出摘要后立即停止,不要添加任何解释或后续建议。

REMINDER: 不要调用任何工具。仅用纯文本响应。"#;
99
100#[async_trait]
101impl Compressor for AiCompressor {
102 async fn summarize(
103 &self,
104 messages: &[Message],
105 _config: &CompressionConfig,
106 ) -> Result<SummarizedSegment> {
107 let prompt = build_summary_prompt(messages);
108
109 let request = ChatRequest {
110 messages: vec![Message {
111 role: Role::User,
112 content: MessageContent::Text(prompt),
113 }],
114 tools: vec![],
115 system: Some(SUMMARY_SYSTEM_PROMPT.to_string()),
116 think: false,
117 max_tokens: 1024,
118 server_tools: vec![],
119 enable_caching: false,
120 };
121
122 let response = self.provider.chat(request).await?;
123 let summary_text = extract_text_from_response(&response);
124 let (summary, key_points) = parse_summary_response(&summary_text, &self.hardcode_config);
125
126 Ok(SummarizedSegment {
127 time_range: (chrono::Utc::now(), chrono::Utc::now()),
128 original_count: messages.len(),
129 summary,
130 key_points,
131 })
132 }
133
134 fn model_name(&self) -> &str {
135 &self.model
136 }
137}
138
139fn extract_text_from_response(response: &ChatResponse) -> String {
140 response
141 .content
142 .iter()
143 .filter_map(|block| {
144 if let ContentBlock::Text { text } = block {
145 Some(text.clone())
146 } else {
147 None
148 }
149 })
150 .collect::<Vec<_>>()
151 .join("\n")
152}
153
154fn parse_summary_response(text: &str, config: &crate::compress::hardcode_config::HardcodeConfig) -> (String, Vec<String>) {
155 let mut summary = String::new();
156 let mut key_points: Vec<String> = Vec::new();
157
158 let sections = [
160 "【摘要】",
161 "【已完成】",
162 "【未完成】",
163 "【关键决策】",
164 "【敏感指令】",
165 "【技术栈】",
166 "【文件变更】",
167 "【问题记录】",
168 "【下一步】",
169 ];
170
171 for line in text.lines() {
172 let line = line.trim();
173
174 let is_header = sections.iter().any(|s| line.starts_with(s));
176
177 if is_header {
178 for section in §ions {
180 if line.starts_with(section) {
181 let replaced = line.replace(section, "");
182 let content = replaced.trim();
183 if !content.is_empty() {
184 if *section == "【摘要】" {
185 summary = content.to_string();
186 } else {
187 key_points.push(format!("{}{}", section, content));
188 }
189 }
190 break;
191 }
192 }
193 } else if !line.is_empty() {
194 if line.starts_with("•") || line.starts_with("-") || line.starts_with("*") {
196 let point = line.trim_start_matches(['•', '-', '*']).trim();
197 if !point.is_empty() {
198 key_points.push(point.to_string());
199 }
200 } else if summary.is_empty() {
201 summary = line.to_string();
203 }
204 }
205 }
206
207 if summary.is_empty() && !text.is_empty() {
209 summary = text.lines().take(3).collect::<Vec<_>>().join(" ");
210 if summary.len() > config.summary_length_threshold {
211 summary = truncate_with_suffix(&summary, config.summary_length_threshold);
212 }
213 }
214
215 (summary, key_points)
216}
217
218pub fn compress_messages(
224 messages: &[Message],
225 strategy: CompressionStrategy,
226 config: &CompressionConfig,
227) -> Result<Vec<Message>> {
228 match strategy {
229 CompressionStrategy::Truncate => truncate_compress(messages, config),
230 CompressionStrategy::SlidingWindow => sliding_window_compress(messages, config),
231 CompressionStrategy::Summarize => sliding_window_compress(messages, config),
232 CompressionStrategy::BiasBased => compress_with_bias(messages, config),
233 }
234}
235
236pub fn compress_with_bias(
238 messages: &[Message],
239 config: &CompressionConfig,
240) -> Result<Vec<Message>> {
241 if messages.len() <= config.min_preserve_messages {
242 return Ok(messages.to_vec());
243 }
244
245 let scored: Vec<(usize, Message, f64)> = messages
246 .iter()
247 .enumerate()
248 .map(|(idx, msg)| {
249 (
250 idx,
251 msg.clone(),
252 calculate_preservation_score(msg, idx, messages.len(), &config.bias),
253 )
254 })
255 .collect();
256
257 let mut scored_with_recency: Vec<(usize, Message, f64)> = scored
258 .into_iter()
259 .map(|(idx, msg, score)| {
260 let recency_bonus = if idx >= messages.len() - config.min_preserve_messages {
261 100.0
262 } else {
263 (idx as f64 / messages.len() as f64) * 20.0
264 };
265 (idx, msg, score + recency_bonus)
266 })
267 .collect();
268
269 scored_with_recency.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
270
271 let target_count = if config.bias.aggressive {
272 config.min_preserve_messages
273 } else {
274 let estimated = estimate_total_tokens(messages);
275 let target_tokens = (estimated as f64 * config.target_ratio) as u32;
276 let avg = estimated / messages.len() as u32;
277 (target_tokens / avg.max(1)) as usize
278 };
279
280 let to_keep: HashSet<usize> = scored_with_recency
281 .iter()
282 .take(target_count)
283 .map(|(idx, _, _)| *idx)
284 .collect();
285
286 let compressed: Vec<Message> = messages
287 .iter()
288 .enumerate()
289 .filter(|(idx, _)| to_keep.contains(idx))
290 .map(|(_, msg)| msg.clone())
291 .collect();
292
293 Ok(compressed)
294}
295
296fn calculate_preservation_score(
297 message: &Message,
298 _index: usize,
299 _total: usize, bias: &CompressionBias,
301) -> f64 {
302 let mut score: f64 = 10.0;
303
304 match message.role {
309 Role::User => {
310 if bias.preserve_user_questions {
311 score += 30.0;
312 }
313 }
314 Role::Assistant => {
315 score += 5.0;
316 }
317 Role::Tool => {
318 if bias.preserve_tools {
319 score += 25.0;
320 }
321 }
322 Role::System => {
323 score += 40.0;
324 }
325 }
326
327 match &message.content {
328 MessageContent::Text(text) => {
329 for keyword in &bias.preserve_keywords {
330 if text.to_lowercase().contains(&keyword.to_lowercase()) {
331 score += 15.0;
332 }
333 }
334 if contains_sensitive_instructions(text) {
335 score += 50.0;
336 }
337 }
338 MessageContent::Blocks(blocks) => {
339 for block in blocks {
340 match block {
341 ContentBlock::ToolUse { name, .. } => {
342 if bias.preserve_tools {
343 score += 20.0;
344 }
345 if name == "write" || name == "edit" || name == "bash" {
346 score += 10.0;
347 }
348 if name == "todo_write" {
350 score += 60.0;
351 }
352 if name == "ask" {
354 score += 50.0;
355 }
356 }
357 ContentBlock::ToolResult { content, .. } => {
358 if bias.preserve_tools {
359 score += 20.0;
360 }
361 if contains_sensitive_instructions(content) {
362 score += 30.0;
363 }
364 if content.contains("TodoWrite") || content.contains("todo") {
366 score += 40.0;
367 }
368 if content.contains("AskUserQuestion") || content.contains("answer") {
370 score += 30.0;
371 }
372 }
373 ContentBlock::Thinking { .. } => {
374 if bias.preserve_thinking {
375 score += 25.0;
376 } else {
377 score -= 5.0;
378 }
379 }
380 ContentBlock::Text { text } => {
381 if contains_sensitive_instructions(text) {
382 score += 50.0;
383 }
384 }
385 _ => {}
386 }
387 }
388 }
389 }
390
391 score
392}
393
/// Reports whether `text` contains a prohibition/obligation marker.
///
/// Matching is a case-insensitive substring search over a fixed list of
/// Chinese and English phrases.
fn contains_sensitive_instructions(text: &str) -> bool {
    const MARKERS: [&str; 7] = [
        "不要",
        "禁止",
        "必须",
        "不允许",
        "never",
        "must not",
        "do not",
    ];

    let normalized = text.to_lowercase();
    for marker in MARKERS.iter() {
        if normalized.contains(marker) {
            return true;
        }
    }
    false
}
407
408fn truncate_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
409 if messages.len() <= config.min_preserve_messages {
410 return Ok(messages.to_vec());
411 }
412 Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
413}
414
415fn sliding_window_compress(
416 messages: &[Message],
417 config: &CompressionConfig,
418) -> Result<Vec<Message>> {
419 if messages.len() <= config.min_preserve_messages {
420 return Ok(messages.to_vec());
421 }
422
423 let deps = DependencyBuilder::build(messages);
425
426 let recent_start = messages.len().saturating_sub(config.min_preserve_messages);
432
433 let mut preserve_indices: HashSet<usize> = HashSet::new();
435
436 for i in recent_start..messages.len() {
438 preserve_indices.insert(i);
439
440 for pair_idx in deps.get_pair_indices(i) {
442 preserve_indices.insert(pair_idx);
443 }
444 }
445
446 let preserved_msgs: Vec<Message> = preserve_indices
448 .iter()
449 .filter(|&i| *i < messages.len())
450 .map(|&i| messages[i].clone())
451 .collect();
452
453 let preserved_tokens = estimate_total_tokens(&preserved_msgs);
454 let current_total = estimate_total_tokens(messages);
455 let target_tokens = (current_total as f64 * config.target_ratio) as u32;
456
457 if preserved_tokens <= target_tokens {
459 let mut sorted_indices: Vec<usize> = preserve_indices.iter().cloned().collect();
460 sorted_indices.sort();
461 let result: Vec<Message> = sorted_indices
462 .iter()
463 .filter(|&i| *i < messages.len())
464 .map(|&i| messages[i].clone())
465 .collect();
466 return Ok(result);
467 }
468
469 let mut sorted_indices: Vec<usize> = preserve_indices.iter().cloned().collect();
471 sorted_indices.sort();
472
473 while sorted_indices.len() > config.min_preserve_messages {
475 let candidate_tokens = sorted_indices
476 .iter()
477 .copied()
478 .filter(|&i| i < messages.len())
479 .map(|i| estimate_tokens(&messages[i]))
480 .sum::<u32>();
481
482 if candidate_tokens <= target_tokens {
483 break;
484 }
485
486 if sorted_indices.len() > 2 {
489 let oldest_idx = sorted_indices[0];
490 let pair_indices = deps.get_pair_indices(oldest_idx);
491
492 if pair_indices.is_empty() || pair_indices.iter().all(|p| !sorted_indices.contains(p)) {
494 sorted_indices.remove(0);
495 } else {
496 sorted_indices.remove(1);
498 }
499 } else {
500 break;
501 }
502 }
503
504 let result: Vec<Message> = sorted_indices
505 .iter()
506 .filter(|&i| *i < messages.len())
507 .map(|&i| messages[i].clone())
508 .collect();
509
510 Ok(result)
511}
512
513pub fn estimate_tokens(message: &Message) -> u32 {
519 let content_text = match &message.content {
520 MessageContent::Text(t) => t.clone(),
521 MessageContent::Blocks(blocks) => {
522 let mut text = String::new();
523 for block in blocks {
524 match block {
525 ContentBlock::Text { text: t } => {
526 text.push_str(t);
527 text.push(' '); }
529 ContentBlock::ToolUse { name, input, .. } => {
530 text.push_str(name);
531 text.push(' '); text.push_str(&input.to_string());
533 text.push(' '); }
535 ContentBlock::ToolResult { content, .. } => {
536 text.push_str(content);
537 text.push(' '); }
539 ContentBlock::Thinking { thinking, .. } => {
540 text.push_str(thinking);
541 text.push(' '); }
543 _ => {}
544 }
545 }
546 text
547 }
548 };
549
550 let content_tokens = count_tokens(&content_text);
552 let role_tokens = count_tokens(&format!("{:?}: ", message.role));
553
554 content_tokens + role_tokens + message_overhead()
556}
557
558pub fn estimate_total_tokens(messages: &[Message]) -> u32 {
560 messages.iter().map(estimate_tokens).sum()
561}
562
563pub fn should_compress(
565 current_tokens: u32,
566 context_size: Option<u32>,
567 config: &CompressionConfig,
568) -> bool {
569 match context_size {
570 Some(size) => (current_tokens as f64 / size as f64) >= config.threshold,
571 None => false,
572 }
573}
574
575pub fn build_summary_prompt(messages: &[Message]) -> String {
577 let history = messages
578 .iter()
579 .map(|m| {
580 let role = match m.role {
581 Role::User => "用户",
582 Role::Assistant => "助手",
583 Role::Tool => "工具",
584 Role::System => "系统",
585 };
586 let preview = match &m.content {
587 MessageContent::Text(t) => truncate_with_suffix(t, 200),
588 MessageContent::Blocks(blocks) => blocks
589 .iter()
590 .map(|b| match b {
591 ContentBlock::Text { text } => truncate_with_suffix(text, 100),
592 ContentBlock::ToolUse { name, .. } => format!("[工具: {}]", name),
593 ContentBlock::ToolResult { content, .. } => {
594 truncate_with_suffix(content, 100)
595 }
596 _ => "[...]".to_string(),
597 })
598 .collect::<Vec<_>>()
599 .join(" | "),
600 };
601 format!("{}: {}", role, preview)
602 })
603 .collect::<Vec<_>>()
604 .join("\n");
605
606 format!(
607 "请将以下对话压缩为简洁摘要({} 条消息):\n{}",
608 messages.len(),
609 history
610 )
611}
612
613use super::pipeline::CompressionPipeline;
618use super::types::AiCompressionMode;
619
620pub async fn compress_messages_with_ai(
625 messages: &[Message],
626 config: &CompressionConfig,
627 ai_mode: AiCompressionMode,
628 fast_model: Option<Box<dyn Provider>>,
629 token_usage: u32,
630 context_window: u32,
631) -> Result<Vec<Message>> {
632 let mut pipeline = match (ai_mode, fast_model) {
633 (AiCompressionMode::None, _) => CompressionPipeline::new_rule_only(config.clone()),
634 (AiCompressionMode::Light | AiCompressionMode::Deep, Some(model)) => {
635 CompressionPipeline::new_with_ai(config.clone(), model)
636 }
637 _ => CompressionPipeline::new_rule_only(config.clone()),
638 };
639
640 let result = pipeline
641 .execute(messages, ai_mode, token_usage, context_window)
642 .await?;
643 Ok(result.messages)
644}
645
646pub async fn compress_messages_with_full_ai(
650 messages: &[Message],
651 config: &CompressionConfig,
652 ai_mode: AiCompressionMode,
653 fast_model: Box<dyn Provider>,
654 main_model: Box<dyn Provider>,
655 token_usage: u32,
656 context_window: u32,
657) -> Result<Vec<Message>> {
658 let mut pipeline =
659 CompressionPipeline::new_with_full_ai(config.clone(), fast_model, main_model);
660
661 let result = pipeline
662 .execute(messages, ai_mode, token_usage, context_window)
663 .await?;
664 Ok(result.messages)
665}
666
667pub fn score_messages_only(
671 messages: &[Message],
672 config: &CompressionConfig,
673) -> Vec<super::types::ScoredMessage> {
674 let pipeline = CompressionPipeline::new_rule_only(config.clone());
675 pipeline.score_only(messages)
676}
677
#[cfg(test)]
mod tests {
    use super::*;

    /// A short plain-text message must cost at least a few tokens.
    #[test]
    fn test_estimate_tokens_simple() {
        let greeting = Message {
            role: Role::User,
            content: MessageContent::Text(String::from("Hello world")),
        };
        let tokens = estimate_tokens(&greeting);
        assert!(tokens >= 3);
    }

    /// With the default config, 50% usage triggers compression and 40% does not.
    #[test]
    fn test_should_compress() {
        let config = CompressionConfig::default();
        let window = Some(200_000);

        let at_half = should_compress(100_000, window, &config);
        let at_forty_percent = should_compress(80_000, window, &config);

        assert!(at_half);
        assert!(!at_forty_percent);
    }
}