1use crate::providers::{
4 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
5};
6use crate::truncate::truncate_with_suffix;
7use anyhow::Result;
8use async_trait::async_trait;
9use std::collections::HashSet;
10
11use super::config::{CompressionBias, CompressionConfig};
12use super::types::{CompressionStrategy, SummarizedSegment};
13
/// A strategy for condensing a run of conversation messages into a
/// [`SummarizedSegment`].
///
/// The `Send + Sync` supertraits allow implementations to be shared across
/// threads (e.g. behind an `Arc`).
#[async_trait]
pub trait Compressor: Send + Sync {
    /// Summarizes `messages` into a single [`SummarizedSegment`].
    ///
    /// # Errors
    /// Implementation-defined; [`AiCompressor`] propagates provider failures.
    async fn summarize(
        &self,
        messages: &[Message],
        config: &CompressionConfig,
    ) -> Result<SummarizedSegment>;

    /// Identifier of the model backing this compressor (for [`AiCompressor`],
    /// the name it was constructed with).
    fn model_name(&self) -> &str;
}
31
32pub struct AiCompressor {
34 provider: Box<dyn Provider>,
35 model: String,
36}
37
38impl AiCompressor {
39 pub fn new(provider: Box<dyn Provider>, model: String) -> Self {
40 Self { provider, model }
41 }
42}
43
/// System prompt (in Chinese) sent with every [`AiCompressor::summarize`]
/// request.
///
/// It asks the model to compress the conversation into a structured summary
/// with nine 【…】 sections, in this order:
/// - summary (one sentence, at most 50 characters)
/// - completed
/// - not yet done
/// - key decisions
/// - sensitive instructions (to be kept verbatim)
/// - tech stack
/// - file changes
/// - problems
/// - next steps
///
/// Each section is limited to 100 characters, and empty sections may be
/// omitted. `parse_summary_response` recognises exactly these section
/// headers, so keep the two in sync.
const SUMMARY_SYSTEM_PROMPT: &str = r#"你是一个对话历史压缩助手。将对话压缩为结构化摘要。

输出要求:
- 结构化:使用9个章节格式
- 关键:只保留重要信息,忽略无关细节
- 敏感:必须保留用户的敏感指令(禁止、必须等)
- 任务:必须保留未完成的待办事项
- 决策:必须保留关键方案选择和理由

9章节输出格式:
【摘要】一句话概括主要工作(50字以内)
【已完成】列出已完成的操作(工具调用、文件变更)
【未完成】列出待办任务和阻塞项
【关键决策】重要选择及理由(技术选型、方案决策)
【敏感指令】用户的禁止/必须指令(必须原样保留)
【技术栈】使用的语言、框架、库、工具
【文件变更】读取、修改、创建的文件路径
【问题记录】遇到的问题及解决方案
【下一步】建议的下一步操作

每章节控制在100字以内,空章节可省略。
请直接输出内容。"#;
66
67#[async_trait]
68impl Compressor for AiCompressor {
69 async fn summarize(
70 &self,
71 messages: &[Message],
72 _config: &CompressionConfig,
73 ) -> Result<SummarizedSegment> {
74 let prompt = build_summary_prompt(messages);
75
76 let request = ChatRequest {
77 messages: vec![Message {
78 role: Role::User,
79 content: MessageContent::Text(prompt),
80 }],
81 tools: vec![],
82 system: Some(SUMMARY_SYSTEM_PROMPT.to_string()),
83 think: false,
84 max_tokens: 1024,
85 server_tools: vec![],
86 enable_caching: false,
87 };
88
89 let response = self.provider.chat(request).await?;
90 let summary_text = extract_text_from_response(&response);
91 let (summary, key_points) = parse_summary_response(&summary_text);
92
93 Ok(SummarizedSegment {
94 time_range: (chrono::Utc::now(), chrono::Utc::now()),
95 original_count: messages.len(),
96 summary,
97 key_points,
98 })
99 }
100
101 fn model_name(&self) -> &str {
102 &self.model
103 }
104}
105
106fn extract_text_from_response(response: &ChatResponse) -> String {
107 response
108 .content
109 .iter()
110 .filter_map(|block| {
111 if let ContentBlock::Text { text } = block {
112 Some(text.clone())
113 } else {
114 None
115 }
116 })
117 .collect::<Vec<_>>()
118 .join("\n")
119}
120
121fn parse_summary_response(text: &str) -> (String, Vec<String>) {
122 let mut summary = String::new();
123 let mut key_points: Vec<String> = Vec::new();
124
125 let sections = [
127 "【摘要】", "【已完成】", "【未完成】", "【关键决策】",
128 "【敏感指令】", "【技术栈】", "【文件变更】", "【问题记录】", "【下一步】"
129 ];
130
131 for line in text.lines() {
132 let line = line.trim();
133
134 let is_header = sections.iter().any(|s| line.starts_with(s));
136
137 if is_header {
138 for section in §ions {
140 if line.starts_with(section) {
141 let replaced = line.replace(section, "");
142 let content = replaced.trim();
143 if !content.is_empty() {
144 if *section == "【摘要】" {
145 summary = content.to_string();
146 } else {
147 key_points.push(format!("{}{}", section, content));
148 }
149 }
150 break;
151 }
152 }
153 } else if !line.is_empty() {
154 if line.starts_with("•") || line.starts_with("-") || line.starts_with("*") {
156 let point = line.trim_start_matches(['•', '-', '*']).trim();
157 if !point.is_empty() {
158 key_points.push(point.to_string());
159 }
160 } else if summary.is_empty() {
161 summary = line.to_string();
163 }
164 }
165 }
166
167 if summary.is_empty() && !text.is_empty() {
169 summary = text.lines().take(3).collect::<Vec<_>>().join(" ");
170 if summary.len() > 200 {
171 summary = truncate_with_suffix(&summary, 200);
172 }
173 }
174
175 (summary, key_points)
176}
177
178pub fn compress_messages(
184 messages: &[Message],
185 strategy: CompressionStrategy,
186 config: &CompressionConfig,
187) -> Result<Vec<Message>> {
188 match strategy {
189 CompressionStrategy::Truncate => truncate_compress(messages, config),
190 CompressionStrategy::SlidingWindow => sliding_window_compress(messages, config),
191 CompressionStrategy::Summarize => sliding_window_compress(messages, config),
192 CompressionStrategy::BiasBased => compress_with_bias(messages, config),
193 }
194}
195
196pub fn compress_with_bias(
198 messages: &[Message],
199 config: &CompressionConfig,
200) -> Result<Vec<Message>> {
201 if messages.len() <= config.min_preserve_messages {
202 return Ok(messages.to_vec());
203 }
204
205 let scored: Vec<(usize, Message, f64)> = messages
206 .iter()
207 .enumerate()
208 .map(|(idx, msg)| {
209 (
210 idx,
211 msg.clone(),
212 calculate_preservation_score(msg, idx, messages.len(), &config.bias),
213 )
214 })
215 .collect();
216
217 let mut scored_with_recency: Vec<(usize, Message, f64)> = scored
218 .into_iter()
219 .map(|(idx, msg, score)| {
220 let recency_bonus = if idx >= messages.len() - config.min_preserve_messages {
221 100.0
222 } else {
223 (idx as f64 / messages.len() as f64) * 20.0
224 };
225 (idx, msg, score + recency_bonus)
226 })
227 .collect();
228
229 scored_with_recency.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
230
231 let target_count = if config.bias.aggressive {
232 config.min_preserve_messages
233 } else {
234 let estimated = estimate_total_tokens(messages);
235 let target_tokens = (estimated as f64 * config.target_ratio) as u32;
236 let avg = estimated / messages.len() as u32;
237 (target_tokens / avg.max(1)) as usize
238 };
239
240 let to_keep: HashSet<usize> = scored_with_recency
241 .iter()
242 .take(target_count)
243 .map(|(idx, _, _)| *idx)
244 .collect();
245
246 let compressed: Vec<Message> = messages
247 .iter()
248 .enumerate()
249 .filter(|(idx, _)| to_keep.contains(idx))
250 .map(|(_, msg)| msg.clone())
251 .collect();
252
253 Ok(compressed)
254}
255
256fn calculate_preservation_score(
257 message: &Message,
258 index: usize,
259 _total: usize, bias: &CompressionBias,
261) -> f64 {
262 let mut score: f64 = 10.0;
263
264 if index == 0 {
266 score += 100.0;
267 }
268
269 match message.role {
270 Role::User => {
271 if bias.preserve_user_questions {
272 score += 30.0;
273 }
274 }
275 Role::Assistant => {
276 score += 5.0;
277 }
278 Role::Tool => {
279 if bias.preserve_tools {
280 score += 25.0;
281 }
282 }
283 Role::System => {
284 score += 40.0;
285 }
286 }
287
288 match &message.content {
289 MessageContent::Text(text) => {
290 for keyword in &bias.preserve_keywords {
291 if text.to_lowercase().contains(&keyword.to_lowercase()) {
292 score += 15.0;
293 }
294 }
295 if contains_sensitive_instructions(text) {
296 score += 50.0;
297 }
298 }
299 MessageContent::Blocks(blocks) => {
300 for block in blocks {
301 match block {
302 ContentBlock::ToolUse { name, .. } => {
303 if bias.preserve_tools {
304 score += 20.0;
305 }
306 if name == "write" || name == "edit" || name == "bash" {
307 score += 10.0;
308 }
309 if name == "todo_write" {
311 score += 60.0;
312 }
313 if name == "ask" {
315 score += 50.0;
316 }
317 }
318 ContentBlock::ToolResult { content, .. } => {
319 if bias.preserve_tools {
320 score += 20.0;
321 }
322 if contains_sensitive_instructions(content) {
323 score += 30.0;
324 }
325 if content.contains("TodoWrite") || content.contains("todo") {
327 score += 40.0;
328 }
329 if content.contains("AskUserQuestion") || content.contains("answer") {
331 score += 30.0;
332 }
333 }
334 ContentBlock::Thinking { .. } => {
335 if bias.preserve_thinking {
336 score += 25.0;
337 } else {
338 score -= 5.0;
339 }
340 }
341 ContentBlock::Text { text } => {
342 if contains_sensitive_instructions(text) {
343 score += 50.0;
344 }
345 }
346 _ => {}
347 }
348 }
349 }
350 }
351
352 score
353}
354
/// Returns `true` when `text` contains a prohibition or requirement marker.
///
/// The check is a case-insensitive substring search for the Chinese markers
/// 不要 / 禁止 / 必须 / 不允许 and the English markers "never" / "must not" /
/// "do not". Because it is a plain substring match, words such as
/// "nevertheless" also match.
fn contains_sensitive_instructions(text: &str) -> bool {
    const MARKERS: [&str; 7] = [
        "不要",
        "禁止",
        "必须",
        "不允许",
        "never",
        "must not",
        "do not",
    ];
    let haystack = text.to_lowercase();
    MARKERS.iter().any(|&marker| haystack.contains(marker))
}
368
369fn truncate_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
370 if messages.len() <= config.min_preserve_messages {
371 return Ok(messages.to_vec());
372 }
373 Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
374}
375
376fn sliding_window_compress(
377 messages: &[Message],
378 config: &CompressionConfig,
379) -> Result<Vec<Message>> {
380 if messages.len() <= config.min_preserve_messages {
381 return Ok(messages.to_vec());
382 }
383
384 let first_msg = messages.first().cloned();
390 let recent_start = messages.len().saturating_sub(config.min_preserve_messages);
391 let recent_msgs = &messages[recent_start..];
392
393 let first_tokens = first_msg.as_ref().map(|m| estimate_tokens(m)).unwrap_or(0);
395 let recent_tokens = estimate_total_tokens(recent_msgs);
396 let current_total = estimate_total_tokens(messages);
397 let target_tokens = (current_total as f64 * config.target_ratio) as u32;
398
399 if first_tokens + recent_tokens <= target_tokens {
401 let mut result: Vec<Message> = Vec::new();
403 if let Some(first) = first_msg {
404 result.push(first);
405 }
406 result.extend(recent_msgs.iter().cloned());
407 return Ok(result);
408 }
409
410 for drop_count in 0..recent_msgs.len() {
412 let candidate = &recent_msgs[drop_count..];
413 if estimate_total_tokens(candidate) <= target_tokens {
414 return Ok(candidate.to_vec());
415 }
416 }
417
418 Ok(messages[messages.len() - config.min_preserve_messages..].to_vec())
420}
421
422pub fn estimate_tokens(message: &Message) -> u32 {
428 let (ascii, non_ascii) = match &message.content {
429 MessageContent::Text(t) => count_chars(t),
430 MessageContent::Blocks(blocks) => {
431 let mut a = 0u32;
432 let mut n = 0u32;
433 for block in blocks {
434 match block {
435 ContentBlock::Text { text } => {
436 let (ca, cn) = count_chars(text);
437 a += ca;
438 n += cn;
439 }
440 ContentBlock::ToolUse { name, input, .. } => {
441 let (ca, cn) = count_chars(name);
442 a += ca;
443 n += cn;
444 let (ja, jn) = count_chars(&input.to_string());
445 a += ja;
446 n += jn;
447 }
448 ContentBlock::ToolResult { content, .. } => {
449 let (ca, cn) = count_chars(content);
450 a += ca;
451 n += cn;
452 }
453 ContentBlock::Thinking { thinking, .. } => {
454 let (ca, cn) = count_chars(thinking);
455 a += ca;
456 n += cn;
457 }
458 _ => {}
459 }
460 }
461 (a, n)
462 }
463 };
464
465 let ascii_tokens = (ascii as f64 * 0.25).ceil() as u32;
466 let non_ascii_tokens = (non_ascii as f64 * 0.67).ceil() as u32;
467 (ascii_tokens + non_ascii_tokens + 10).max(1)
468}
469
/// Counts the characters of `s` as `(ascii, non_ascii)`.
fn count_chars(s: &str) -> (u32, u32) {
    s.chars().fold((0u32, 0u32), |(ascii, other), c| {
        if c.is_ascii() {
            (ascii + 1, other)
        } else {
            (ascii, other + 1)
        }
    })
}
482
483pub fn estimate_total_tokens(messages: &[Message]) -> u32 {
485 messages.iter().map(estimate_tokens).sum()
486}
487
488pub fn should_compress(
490 current_tokens: u32,
491 context_size: Option<u32>,
492 config: &CompressionConfig,
493) -> bool {
494 match context_size {
495 Some(size) => (current_tokens as f64 / size as f64) >= config.threshold,
496 None => false,
497 }
498}
499
500pub fn build_summary_prompt(messages: &[Message]) -> String {
502 let history = messages
503 .iter()
504 .map(|m| {
505 let role = match m.role {
506 Role::User => "用户",
507 Role::Assistant => "助手",
508 Role::Tool => "工具",
509 Role::System => "系统",
510 };
511 let preview = match &m.content {
512 MessageContent::Text(t) => truncate_with_suffix(t, 200),
513 MessageContent::Blocks(blocks) => blocks
514 .iter()
515 .map(|b| match b {
516 ContentBlock::Text { text } => truncate_with_suffix(text, 100),
517 ContentBlock::ToolUse { name, .. } => format!("[工具: {}]", name),
518 ContentBlock::ToolResult { content, .. } => {
519 truncate_with_suffix(content, 100)
520 }
521 _ => "[...]".to_string(),
522 })
523 .collect::<Vec<_>>()
524 .join(" | "),
525 };
526 format!("{}: {}", role, preview)
527 })
528 .collect::<Vec<_>>()
529 .join("\n");
530
531 format!(
532 "请将以下对话压缩为简洁摘要({} 条消息):\n{}",
533 messages.len(),
534 history
535 )
536}
537
538use super::pipeline::CompressionPipeline;
543use super::types::AiCompressionMode;
544
545pub async fn compress_messages_with_ai(
550 messages: &[Message],
551 config: &CompressionConfig,
552 ai_mode: AiCompressionMode,
553 fast_model: Option<Box<dyn Provider>>,
554 token_usage: u32,
555 context_window: u32,
556) -> Result<Vec<Message>> {
557 let mut pipeline = match (ai_mode, fast_model) {
558 (AiCompressionMode::None, _) => CompressionPipeline::new_rule_only(config.clone()),
559 (AiCompressionMode::Light | AiCompressionMode::Deep, Some(model)) => {
560 CompressionPipeline::new_with_ai(config.clone(), model)
561 }
562 _ => CompressionPipeline::new_rule_only(config.clone()),
563 };
564
565 let result = pipeline.execute(messages, ai_mode, token_usage, context_window).await?;
566 Ok(result.messages)
567}
568
569pub async fn compress_messages_with_full_ai(
573 messages: &[Message],
574 config: &CompressionConfig,
575 ai_mode: AiCompressionMode,
576 fast_model: Box<dyn Provider>,
577 main_model: Box<dyn Provider>,
578 token_usage: u32,
579 context_window: u32,
580) -> Result<Vec<Message>> {
581 let mut pipeline = CompressionPipeline::new_with_full_ai(
582 config.clone(),
583 fast_model,
584 main_model,
585 );
586
587 let result = pipeline.execute(messages, ai_mode, token_usage, context_window).await?;
588 Ok(result.messages)
589}
590
591pub fn score_messages_only(
595 messages: &[Message],
596 config: &CompressionConfig,
597) -> Vec<super::types::ScoredMessage> {
598 let pipeline = CompressionPipeline::new_rule_only(config.clone());
599 pipeline.score_only(messages)
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain-text message for tests.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_estimate_tokens_simple() {
        let msg = text_message(Role::User, "Hello world");
        assert!(estimate_tokens(&msg) >= 3);
    }

    #[test]
    fn test_estimate_tokens_mixed_scripts() {
        // 11 ASCII chars -> ceil(11 * 0.25) = 3, plus the fixed 10-token overhead.
        assert_eq!(estimate_tokens(&text_message(Role::User, "Hello world")), 13);
        // 2 non-ASCII chars -> ceil(2 * 0.67) = 2, plus the fixed 10-token overhead.
        assert_eq!(estimate_tokens(&text_message(Role::User, "你好")), 12);
    }

    #[test]
    fn test_should_compress() {
        let config = CompressionConfig::default();
        // 50% of the window reaches the default threshold...
        assert!(should_compress(100_000, Some(200_000), &config));
        // ...while 40% stays below it.
        assert!(!should_compress(80_000, Some(200_000), &config));
        // An unknown window size never triggers compression.
        assert!(!should_compress(100_000, None, &config));
    }

    #[test]
    fn test_count_chars() {
        assert_eq!(count_chars(""), (0, 0));
        assert_eq!(count_chars("abc中文"), (3, 2));
    }

    #[test]
    fn test_contains_sensitive_instructions() {
        assert!(contains_sensitive_instructions("请不要修改配置"));
        assert!(contains_sensitive_instructions("You MUST NOT push to main"));
        assert!(!contains_sensitive_instructions("please refactor the parser"));
    }

    #[test]
    fn test_parse_summary_response_sections() {
        let text = "【摘要】完成了重构\n【已完成】修改了 main.rs\n- 额外要点\n【下一步】";
        let (summary, key_points) = parse_summary_response(text);
        assert_eq!(summary, "完成了重构");
        assert_eq!(
            key_points,
            vec!["【已完成】修改了 main.rs".to_string(), "额外要点".to_string()]
        );
    }

    #[test]
    fn test_truncate_compress_keeps_recent_tail() {
        let mut config = CompressionConfig::default();
        config.min_preserve_messages = 2;
        let messages: Vec<Message> = ["a", "b", "c"]
            .iter()
            .map(|t| text_message(Role::User, t))
            .collect();

        let kept = truncate_compress(&messages, &config).unwrap();
        assert_eq!(kept.len(), 2);
        assert!(matches!(&kept[0].content, MessageContent::Text(t) if t == "b"));
        assert!(matches!(&kept[1].content, MessageContent::Text(t) if t == "c"));
    }
}