1use crate::providers::{
4 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
5};
6use crate::tokenizer::{count_tokens, message_overhead};
7use crate::truncate::truncate_with_suffix;
8use anyhow::Result;
9use async_trait::async_trait;
10use std::collections::HashSet;
11
12use super::config::{CompressionBias, CompressionConfig};
13use super::types::{CompressionStrategy, SummarizedSegment};
14
/// Abstraction over anything that can condense a run of chat messages into a
/// single [`SummarizedSegment`].
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait Compressor: Send + Sync {
    /// Summarizes `messages` under the given `config`.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementation cannot produce a summary.
    async fn summarize(
        &self,
        messages: &[Message],
        config: &CompressionConfig,
    ) -> Result<SummarizedSegment>;

    /// Returns the name of the model associated with this compressor.
    fn model_name(&self) -> &str;
}
32
/// [`Compressor`] implementation that obtains its summary from a chat
/// [`Provider`].
pub struct AiCompressor {
    /// Chat backend the summarization request is sent to.
    provider: Box<dyn Provider>,
    /// Model name reported by `model_name`; the request built in `summarize`
    /// does not carry it.
    model: String,
    /// Tunables consulted while parsing the model's reply (for example
    /// `summary_length_threshold`).
    hardcode_config: crate::compress::hardcode_config::HardcodeConfig,
}
39
40impl AiCompressor {
41 pub fn new(provider: Box<dyn Provider>, model: String) -> Self {
42 Self {
43 provider,
44 model,
45 hardcode_config: crate::compress::hardcode_config::HardcodeConfig::default(),
46 }
47 }
48
49 pub fn with_hardcode_config(mut self, config: crate::compress::hardcode_config::HardcodeConfig) -> Self {
51 self.hardcode_config = config;
52 self
53 }
54}
55
/// System prompt attached to the summarization request built in
/// `AiCompressor::summarize`.
///
/// The prompt is written in Chinese. It forbids tool calls, asks the model to
/// reason inside `<analysis>` tags first, and then to emit up to nine
/// `【…】`-headed sections. `parse_summary_response` recognises the same nine
/// headers, so the two must be kept in sync.
const SUMMARY_SYSTEM_PROMPT: &str = r#"CRITICAL: 仅用文本响应。不要调用任何工具。

- 不要使用 read、bash、grep、glob、edit、write 或任何其他工具
- 你已在上方对话中获得所需的所有上下文
- 工具调用将被拒绝并浪费你唯一的 turn — 你将失败任务
- 你的整个响应必须是纯文本摘要

---

你是一个对话历史压缩助手。将对话压缩为结构化摘要。

在提供最终摘要前,将你的分析包裹在 <analysis> 标签中以组织思路:
<analysis>
1. 按时间顺序分析每条消息
2. 识别:用户请求、助手行动、关键决策、错误及修复
3. 特别注意用户的敏感指令(禁止/必须)
4. 双重检查技术准确性和完整性
</analysis>

输出要求:
- 结构化:使用9个章节格式
- 关键:只保留重要信息,忽略无关细节
- 敏感:必须保留用户的敏感指令(禁止、必须等)— 原样保留
- 任务:必须保留未完成的待办事项
- 决策:必须保留关键方案选择和理由

9章节输出格式:
【摘要】一句话概括主要工作(50字以内)
【已完成】列出已完成的操作(工具调用、文件变更)
【未完成】列出待办任务和阻塞项 — 最关键,压缩后恢复需要
【关键决策】重要选择及理由(技术选型、方案决策)
【敏感指令】用户的禁止/必须指令(必须原样保留)
【技术栈】使用的语言、框架、库、工具
【文件变更】读取、修改、创建的文件路径
【问题记录】遇到的问题及解决方案
【下一步】建议的下一步操作(直接引用最近对话展示任务中断点)

每章节控制在100字以内,空章节可省略。
输出摘要后立即停止,不要添加任何解释或后续建议。

REMINDER: 不要调用任何工具。仅用纯文本响应。"#;
97
98#[async_trait]
99impl Compressor for AiCompressor {
100 async fn summarize(
101 &self,
102 messages: &[Message],
103 _config: &CompressionConfig,
104 ) -> Result<SummarizedSegment> {
105 let prompt = build_summary_prompt(messages);
106
107 let request = ChatRequest {
108 messages: vec![Message {
109 role: Role::User,
110 content: MessageContent::Text(prompt),
111 }],
112 tools: vec![],
113 system: Some(SUMMARY_SYSTEM_PROMPT.to_string()),
114 think: false,
115 max_tokens: 1024,
116 server_tools: vec![],
117 enable_caching: false,
118 };
119
120 let response = self.provider.chat(request).await?;
121 let summary_text = extract_text_from_response(&response);
122 let (summary, key_points) = parse_summary_response(&summary_text, &self.hardcode_config);
123
124 Ok(SummarizedSegment {
125 time_range: (chrono::Utc::now(), chrono::Utc::now()),
126 original_count: messages.len(),
127 summary,
128 key_points,
129 })
130 }
131
132 fn model_name(&self) -> &str {
133 &self.model
134 }
135}
136
137fn extract_text_from_response(response: &ChatResponse) -> String {
138 response
139 .content
140 .iter()
141 .filter_map(|block| {
142 if let ContentBlock::Text { text } = block {
143 Some(text.clone())
144 } else {
145 None
146 }
147 })
148 .collect::<Vec<_>>()
149 .join("\n")
150}
151
152fn parse_summary_response(text: &str, config: &crate::compress::hardcode_config::HardcodeConfig) -> (String, Vec<String>) {
153 let mut summary = String::new();
154 let mut key_points: Vec<String> = Vec::new();
155
156 let sections = [
158 "【摘要】",
159 "【已完成】",
160 "【未完成】",
161 "【关键决策】",
162 "【敏感指令】",
163 "【技术栈】",
164 "【文件变更】",
165 "【问题记录】",
166 "【下一步】",
167 ];
168
169 for line in text.lines() {
170 let line = line.trim();
171
172 let is_header = sections.iter().any(|s| line.starts_with(s));
174
175 if is_header {
176 for section in §ions {
178 if line.starts_with(section) {
179 let replaced = line.replace(section, "");
180 let content = replaced.trim();
181 if !content.is_empty() {
182 if *section == "【摘要】" {
183 summary = content.to_string();
184 } else {
185 key_points.push(format!("{}{}", section, content));
186 }
187 }
188 break;
189 }
190 }
191 } else if !line.is_empty() {
192 if line.starts_with("•") || line.starts_with("-") || line.starts_with("*") {
194 let point = line.trim_start_matches(['•', '-', '*']).trim();
195 if !point.is_empty() {
196 key_points.push(point.to_string());
197 }
198 } else if summary.is_empty() {
199 summary = line.to_string();
201 }
202 }
203 }
204
205 if summary.is_empty() && !text.is_empty() {
207 summary = text.lines().take(3).collect::<Vec<_>>().join(" ");
208 if summary.len() > config.summary_length_threshold {
209 summary = truncate_with_suffix(&summary, config.summary_length_threshold);
210 }
211 }
212
213 (summary, key_points)
214}
215
216pub fn compress_messages(
222 messages: &[Message],
223 strategy: CompressionStrategy,
224 config: &CompressionConfig,
225) -> Result<Vec<Message>> {
226 match strategy {
227 CompressionStrategy::Truncate => truncate_compress(messages, config),
228 CompressionStrategy::SlidingWindow => sliding_window_compress(messages, config),
229 CompressionStrategy::Summarize => sliding_window_compress(messages, config),
230 CompressionStrategy::BiasBased => compress_with_bias(messages, config),
231 }
232}
233
234pub fn compress_with_bias(
236 messages: &[Message],
237 config: &CompressionConfig,
238) -> Result<Vec<Message>> {
239 if messages.len() <= config.min_preserve_messages {
240 return Ok(messages.to_vec());
241 }
242
243 let scored: Vec<(usize, Message, f64)> = messages
244 .iter()
245 .enumerate()
246 .map(|(idx, msg)| {
247 (
248 idx,
249 msg.clone(),
250 calculate_preservation_score(msg, idx, messages.len(), &config.bias),
251 )
252 })
253 .collect();
254
255 let mut scored_with_recency: Vec<(usize, Message, f64)> = scored
256 .into_iter()
257 .map(|(idx, msg, score)| {
258 let recency_bonus = if idx >= messages.len() - config.min_preserve_messages {
259 100.0
260 } else {
261 (idx as f64 / messages.len() as f64) * 20.0
262 };
263 (idx, msg, score + recency_bonus)
264 })
265 .collect();
266
267 scored_with_recency.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
268
269 let target_count = if config.bias.aggressive {
270 config.min_preserve_messages
271 } else {
272 let estimated = estimate_total_tokens(messages);
273 let target_tokens = (estimated as f64 * config.target_ratio) as u32;
274 let avg = estimated / messages.len() as u32;
275 (target_tokens / avg.max(1)) as usize
276 };
277
278 let to_keep: HashSet<usize> = scored_with_recency
279 .iter()
280 .take(target_count)
281 .map(|(idx, _, _)| *idx)
282 .collect();
283
284 let compressed: Vec<Message> = messages
285 .iter()
286 .enumerate()
287 .filter(|(idx, _)| to_keep.contains(idx))
288 .map(|(_, msg)| msg.clone())
289 .collect();
290
291 Ok(compressed)
292}
293
294fn calculate_preservation_score(
295 message: &Message,
296 index: usize,
297 _total: usize, bias: &CompressionBias,
299) -> f64 {
300 let mut score: f64 = 10.0;
301
302 if index == 0 {
304 score += 100.0;
305 }
306
307 match message.role {
308 Role::User => {
309 if bias.preserve_user_questions {
310 score += 30.0;
311 }
312 }
313 Role::Assistant => {
314 score += 5.0;
315 }
316 Role::Tool => {
317 if bias.preserve_tools {
318 score += 25.0;
319 }
320 }
321 Role::System => {
322 score += 40.0;
323 }
324 }
325
326 match &message.content {
327 MessageContent::Text(text) => {
328 for keyword in &bias.preserve_keywords {
329 if text.to_lowercase().contains(&keyword.to_lowercase()) {
330 score += 15.0;
331 }
332 }
333 if contains_sensitive_instructions(text) {
334 score += 50.0;
335 }
336 }
337 MessageContent::Blocks(blocks) => {
338 for block in blocks {
339 match block {
340 ContentBlock::ToolUse { name, .. } => {
341 if bias.preserve_tools {
342 score += 20.0;
343 }
344 if name == "write" || name == "edit" || name == "bash" {
345 score += 10.0;
346 }
347 if name == "todo_write" {
349 score += 60.0;
350 }
351 if name == "ask" {
353 score += 50.0;
354 }
355 }
356 ContentBlock::ToolResult { content, .. } => {
357 if bias.preserve_tools {
358 score += 20.0;
359 }
360 if contains_sensitive_instructions(content) {
361 score += 30.0;
362 }
363 if content.contains("TodoWrite") || content.contains("todo") {
365 score += 40.0;
366 }
367 if content.contains("AskUserQuestion") || content.contains("answer") {
369 score += 30.0;
370 }
371 }
372 ContentBlock::Thinking { .. } => {
373 if bias.preserve_thinking {
374 score += 25.0;
375 } else {
376 score -= 5.0;
377 }
378 }
379 ContentBlock::Text { text } => {
380 if contains_sensitive_instructions(text) {
381 score += 50.0;
382 }
383 }
384 _ => {}
385 }
386 }
387 }
388 }
389
390 score
391}
392
/// Reports whether `text` contains any prohibition / obligation marker
/// (Chinese or English), compared case-insensitively.
fn contains_sensitive_instructions(text: &str) -> bool {
    // All markers are lower-case because the input is lower-cased before matching.
    const MARKERS: [&str; 7] = [
        "不要",
        "禁止",
        "必须",
        "不允许",
        "never",
        "must not",
        "do not",
    ];

    let haystack = text.to_lowercase();
    for marker in MARKERS.iter() {
        if haystack.contains(*marker) {
            return true;
        }
    }
    false
}
406
407fn truncate_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
408 if messages.len() <= config.min_preserve_messages {
409 return Ok(messages.to_vec());
410 }
411 Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
412}
413
414fn sliding_window_compress(
415 messages: &[Message],
416 config: &CompressionConfig,
417) -> Result<Vec<Message>> {
418 if messages.len() <= config.min_preserve_messages {
419 return Ok(messages.to_vec());
420 }
421
422 let first_msg = messages.first().cloned();
428 let recent_start = messages.len().saturating_sub(config.min_preserve_messages);
429 let recent_msgs = &messages[recent_start..];
430
431 let first_tokens = first_msg.as_ref().map(estimate_tokens).unwrap_or(0);
433 let recent_tokens = estimate_total_tokens(recent_msgs);
434 let current_total = estimate_total_tokens(messages);
435 let target_tokens = (current_total as f64 * config.target_ratio) as u32;
436
437 if first_tokens + recent_tokens <= target_tokens {
439 let mut result: Vec<Message> = Vec::new();
441 if let Some(first) = first_msg {
442 result.push(first);
443 }
444 result.extend(recent_msgs.iter().cloned());
445 return Ok(result);
446 }
447
448 for drop_count in 0..recent_msgs.len() {
450 let candidate = &recent_msgs[drop_count..];
451 if estimate_total_tokens(candidate) <= target_tokens {
452 return Ok(candidate.to_vec());
453 }
454 }
455
456 Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
458}
459
460pub fn estimate_tokens(message: &Message) -> u32 {
466 let content_text = match &message.content {
467 MessageContent::Text(t) => t.clone(),
468 MessageContent::Blocks(blocks) => {
469 let mut text = String::new();
470 for block in blocks {
471 match block {
472 ContentBlock::Text { text: t } => {
473 text.push_str(t);
474 text.push(' '); }
476 ContentBlock::ToolUse { name, input, .. } => {
477 text.push_str(name);
478 text.push(' '); text.push_str(&input.to_string());
480 text.push(' '); }
482 ContentBlock::ToolResult { content, .. } => {
483 text.push_str(content);
484 text.push(' '); }
486 ContentBlock::Thinking { thinking, .. } => {
487 text.push_str(thinking);
488 text.push(' '); }
490 _ => {}
491 }
492 }
493 text
494 }
495 };
496
497 let content_tokens = count_tokens(&content_text);
499 let role_tokens = count_tokens(&format!("{:?}: ", message.role));
500
501 content_tokens + role_tokens + message_overhead()
503}
504
/// Counts the ASCII and non-ASCII characters of `s`.
///
/// Returns `(ascii, non_ascii)`, counted per `char`, not per byte.
#[deprecated(note = "Use estimate_tokens with tiktoken instead")]
fn count_chars(s: &str) -> (u32, u32) {
    s.chars().fold((0u32, 0u32), |(ascii, non_ascii), ch| {
        if ch.is_ascii() {
            (ascii + 1, non_ascii)
        } else {
            (ascii, non_ascii + 1)
        }
    })
}
519
520pub fn estimate_total_tokens(messages: &[Message]) -> u32 {
522 messages.iter().map(estimate_tokens).sum()
523}
524
525pub fn should_compress(
527 current_tokens: u32,
528 context_size: Option<u32>,
529 config: &CompressionConfig,
530) -> bool {
531 match context_size {
532 Some(size) => (current_tokens as f64 / size as f64) >= config.threshold,
533 None => false,
534 }
535}
536
537pub fn build_summary_prompt(messages: &[Message]) -> String {
539 let history = messages
540 .iter()
541 .map(|m| {
542 let role = match m.role {
543 Role::User => "用户",
544 Role::Assistant => "助手",
545 Role::Tool => "工具",
546 Role::System => "系统",
547 };
548 let preview = match &m.content {
549 MessageContent::Text(t) => truncate_with_suffix(t, 200),
550 MessageContent::Blocks(blocks) => blocks
551 .iter()
552 .map(|b| match b {
553 ContentBlock::Text { text } => truncate_with_suffix(text, 100),
554 ContentBlock::ToolUse { name, .. } => format!("[工具: {}]", name),
555 ContentBlock::ToolResult { content, .. } => {
556 truncate_with_suffix(content, 100)
557 }
558 _ => "[...]".to_string(),
559 })
560 .collect::<Vec<_>>()
561 .join(" | "),
562 };
563 format!("{}: {}", role, preview)
564 })
565 .collect::<Vec<_>>()
566 .join("\n");
567
568 format!(
569 "请将以下对话压缩为简洁摘要({} 条消息):\n{}",
570 messages.len(),
571 history
572 )
573}
574
575use super::pipeline::CompressionPipeline;
580use super::types::AiCompressionMode;
581
582pub async fn compress_messages_with_ai(
587 messages: &[Message],
588 config: &CompressionConfig,
589 ai_mode: AiCompressionMode,
590 fast_model: Option<Box<dyn Provider>>,
591 token_usage: u32,
592 context_window: u32,
593) -> Result<Vec<Message>> {
594 let mut pipeline = match (ai_mode, fast_model) {
595 (AiCompressionMode::None, _) => CompressionPipeline::new_rule_only(config.clone()),
596 (AiCompressionMode::Light | AiCompressionMode::Deep, Some(model)) => {
597 CompressionPipeline::new_with_ai(config.clone(), model)
598 }
599 _ => CompressionPipeline::new_rule_only(config.clone()),
600 };
601
602 let result = pipeline
603 .execute(messages, ai_mode, token_usage, context_window)
604 .await?;
605 Ok(result.messages)
606}
607
608pub async fn compress_messages_with_full_ai(
612 messages: &[Message],
613 config: &CompressionConfig,
614 ai_mode: AiCompressionMode,
615 fast_model: Box<dyn Provider>,
616 main_model: Box<dyn Provider>,
617 token_usage: u32,
618 context_window: u32,
619) -> Result<Vec<Message>> {
620 let mut pipeline =
621 CompressionPipeline::new_with_full_ai(config.clone(), fast_model, main_model);
622
623 let result = pipeline
624 .execute(messages, ai_mode, token_usage, context_window)
625 .await?;
626 Ok(result.messages)
627}
628
629pub fn score_messages_only(
633 messages: &[Message],
634 config: &CompressionConfig,
635) -> Vec<super::types::ScoredMessage> {
636 let pipeline = CompressionPipeline::new_rule_only(config.clone());
637 pipeline.score_only(messages)
638}
639
#[cfg(test)]
mod tests {
    use super::*;

    /// A short plain-text user message must cost at least a few tokens.
    #[test]
    fn test_estimate_tokens_simple() {
        let greeting = Message {
            role: Role::User,
            content: MessageContent::Text(String::from("Hello world")),
        };
        let tokens = estimate_tokens(&greeting);
        assert!(tokens >= 3);
    }

    /// With the default config, 50% usage triggers compression and 40% does
    /// not, so the default threshold lies in (0.4, 0.5].
    #[test]
    fn test_should_compress() {
        let config = CompressionConfig::default();
        let window = Some(200_000);
        assert!(should_compress(100_000, window, &config));
        assert!(!should_compress(80_000, window, &config));
    }
}