1use anyhow::Result;
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6
7use crate::providers::{ContentBlock, Message, MessageContent, Provider, Role, ChatRequest, ChatResponse};
8
/// Default context-usage ratio at which compression is triggered
/// (the default for `CompressionConfig::threshold`).
pub const DEFAULT_COMPRESSION_THRESHOLD: f64 = 0.75;

/// Default number of trailing messages that compression tries to keep
/// (the default for `CompressionConfig::min_preserve_messages`).
pub const MIN_MESSAGES_TO_KEEP: usize = 8;

/// Default fraction of the estimated token total that compression aims for
/// (the default for `CompressionConfig::target_ratio`).
pub const DEFAULT_TARGET_RATIO: f64 = 0.4;

/// Model name used by `CompressionConfig::compressor_model_name` when none is configured.
pub const DEFAULT_COMPRESSOR_MODEL: &str = "claude-3-5-haiku-20241022";
20
21#[derive(Debug, Clone, Default)]
23pub struct CompressionBias {
24 pub preserve_tools: bool,
26 pub preserve_thinking: bool,
28 pub preserve_user_questions: bool,
30 pub compact_long_outputs: bool,
32 pub aggressive: bool,
34 pub preserve_keywords: Vec<String>,
36}
37
38impl CompressionBias {
39 pub fn balanced() -> Self {
41 Self {
42 preserve_tools: true,
43 preserve_thinking: false,
44 preserve_user_questions: true,
45 compact_long_outputs: false,
46 aggressive: false,
47 preserve_keywords: vec![
48 "决定".to_string(), "decision".to_string(),
49 "重要".to_string(), "important".to_string(),
50 "关键".to_string(), "key".to_string()
51 ],
52 }
53 }
54
55 pub fn preserve_important() -> Self {
57 Self {
58 preserve_tools: true,
59 preserve_thinking: true,
60 preserve_user_questions: true,
61 compact_long_outputs: true,
62 aggressive: false,
63 preserve_keywords: vec![
64 "决定".to_string(), "decision".to_string(),
65 "重要".to_string(), "important".to_string(),
66 "关键".to_string(), "key".to_string(),
67 "完成".to_string(), "done".to_string(),
68 "成功".to_string(), "success".to_string()
69 ],
70 }
71 }
72
73 pub fn aggressive() -> Self {
75 Self {
76 preserve_tools: false,
77 preserve_thinking: false,
78 preserve_user_questions: false,
79 compact_long_outputs: false,
80 aggressive: true,
81 preserve_keywords: vec![],
82 }
83 }
84
85 pub fn tool_focused() -> Self {
87 Self {
88 preserve_tools: true,
89 preserve_thinking: false,
90 preserve_user_questions: false,
91 compact_long_outputs: false,
92 aggressive: false,
93 preserve_keywords: vec![
94 "工具".to_string(), "tool".to_string(),
95 "执行".to_string(), "execute".to_string(),
96 "文件".to_string(), "file".to_string()
97 ],
98 }
99 }
100
101 pub fn parse(spec: &str) -> Result<Self> {
104 let spec = spec.trim().to_lowercase();
105
106 if spec == "balanced" || spec == "default" || spec.is_empty() {
107 return Ok(Self::balanced());
108 }
109 if spec == "aggressive" {
110 return Ok(Self::aggressive());
111 }
112 if spec == "preserve_important" || spec == "important" {
113 return Ok(Self::preserve_important());
114 }
115 if spec == "tool_focused" || spec == "tools" {
116 return Ok(Self::tool_focused());
117 }
118
119 let mut bias = Self::default();
121
122 for part in spec.split_whitespace() {
123 if let Some(preserve_list) = part.strip_prefix("preserve:") {
124 for item in preserve_list.split(',') {
125 match item.trim() {
126 "tools" | "tool" => bias.preserve_tools = true,
127 "thinking" | "think" => bias.preserve_thinking = true,
128 "user" | "questions" => bias.preserve_user_questions = true,
129 "compact" | "long" => bias.compact_long_outputs = true,
130 _ => {}
131 }
132 }
133 } else if let Some(keyword_list) = part.strip_prefix("keywords:") {
134 bias.preserve_keywords = keyword_list.split(',')
135 .map(|k| k.trim().to_string())
136 .filter(|k| !k.is_empty())
137 .collect();
138 } else if part == "aggressive" {
139 bias.aggressive = true;
140 }
141 }
142
143 Ok(bias)
144 }
145
146 pub fn format(&self) -> String {
148 let mut parts: Vec<String> = Vec::new();
149
150 if self.preserve_tools { parts.push("tools".to_string()); }
151 if self.preserve_thinking { parts.push("thinking".to_string()); }
152 if self.preserve_user_questions { parts.push("user".to_string()); }
153 if self.compact_long_outputs { parts.push("compact".to_string()); }
154 if self.aggressive { parts.push("aggressive".to_string()); }
155
156 if !self.preserve_keywords.is_empty() {
157 parts.push(format!("keywords:{}", self.preserve_keywords.join(",")));
158 }
159
160 if parts.is_empty() {
161 "default".to_string()
162 } else {
163 parts.join(", ")
164 }
165 }
166}
167
168#[derive(Debug, Clone)]
170pub struct CompressionConfig {
171 pub threshold: f64,
173 pub target_ratio: f64,
175 pub min_preserve_messages: usize,
177 pub use_summarization: bool,
179 pub compressor_model: Option<String>,
181 pub bias: CompressionBias,
183}
184
185impl Default for CompressionConfig {
186 fn default() -> Self {
187 Self {
188 threshold: DEFAULT_COMPRESSION_THRESHOLD,
189 target_ratio: DEFAULT_TARGET_RATIO,
190 min_preserve_messages: MIN_MESSAGES_TO_KEEP,
191 use_summarization: true,
192 compressor_model: None,
193 bias: CompressionBias::balanced(),
194 }
195 }
196}
197
198impl CompressionConfig {
199 pub fn compressor_model_name(&self) -> &str {
201 self.compressor_model.as_deref().unwrap_or(DEFAULT_COMPRESSOR_MODEL)
202 }
203}
204
205#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct CompressionResult {
208 pub original_count: usize,
210 pub new_count: usize,
212 pub tokens_saved: u32,
214 pub summary: Option<String>,
216 pub strategy: CompressionStrategy,
218 pub timestamp: DateTime<Utc>,
220}
221
222impl CompressionResult {
223 pub fn new(
225 original_count: usize,
226 new_count: usize,
227 tokens_saved: u32,
228 summary: Option<String>,
229 strategy: CompressionStrategy,
230 ) -> Self {
231 Self {
232 original_count,
233 new_count,
234 tokens_saved,
235 summary,
236 strategy,
237 timestamp: Utc::now(),
238 }
239 }
240
241 pub fn format_summary(&self) -> String {
243 let strategy_name = match self.strategy {
244 CompressionStrategy::Truncate => "truncate",
245 CompressionStrategy::SlidingWindow => "sliding window",
246 CompressionStrategy::Summarize => "AI summarize",
247 CompressionStrategy::BiasBased => "bias-based",
248 };
249 format!(
250 "{} messages → {} messages (saved ~{} tokens, {})",
251 self.original_count,
252 self.new_count,
253 format_tokens(self.tokens_saved),
254 strategy_name
255 )
256 }
257}
258
/// Formats a token count compactly: the plain number below 1 000, one decimal with a
/// `K` suffix below 10 000, and whole thousands with a `K` suffix from there on.
pub fn format_tokens(n: u32) -> String {
    let thousands = f64::from(n) / 1_000.0;
    match n {
        0..=999 => n.to_string(),
        1_000..=9_999 => format!("{:.1}K", thousands),
        _ => format!("{:.0}K", thousands),
    }
}
268
/// Strategy used to shrink a conversation history.
///
/// Serialized with `snake_case` variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompressionStrategy {
    /// Keep only the most recent messages.
    Truncate,
    /// Keep a suffix of the history that starts at a user message and fits the token target.
    SlidingWindow,
    /// AI summarization; the synchronous `compress_messages` handles it like `SlidingWindow`.
    Summarize,
    /// Keep the highest-scoring messages according to the configured `CompressionBias`.
    BiasBased,
}
282
283#[derive(Debug, Clone, Serialize, Deserialize)]
285pub struct SummarizedSegment {
286 pub time_range: (DateTime<Utc>, DateTime<Utc>),
288 pub original_count: usize,
290 pub summary: String,
292 pub key_points: Vec<String>,
294}
295
296impl SummarizedSegment {
297 pub fn to_message(&self) -> Message {
299 let key_points_text = if self.key_points.is_empty() {
300 "无".to_string()
301 } else {
302 self.key_points.iter().map(|p| format!("• {}", p)).collect::<Vec<_>>().join("\n")
303 };
304
305 let content = format!(
306 "[对话摘要 - 原 {} 条消息]\n\n{}\n\n关键要点:\n{}",
307 self.original_count,
308 self.summary,
309 key_points_text
310 );
311
312 Message {
313 role: Role::User,
314 content: MessageContent::Text(content),
315 }
316 }
317}
318
319#[derive(Debug, Clone, Serialize, Deserialize)]
321pub struct CompressionHistoryEntry {
322 pub timestamp: DateTime<Utc>,
324 pub strategy: CompressionStrategy,
326 pub original_count: usize,
328 pub new_count: usize,
330 pub tokens_saved: u32,
332 pub has_summary: bool,
334}
335
336impl CompressionHistoryEntry {
337 pub fn from_result(result: &CompressionResult) -> Self {
339 Self {
340 timestamp: result.timestamp,
341 strategy: result.strategy,
342 original_count: result.original_count,
343 new_count: result.new_count,
344 tokens_saved: result.tokens_saved,
345 has_summary: result.summary.is_some(),
346 }
347 }
348
349 pub fn format_line(&self) -> String {
351 let strategy_name = match self.strategy {
352 CompressionStrategy::Truncate => "truncate",
353 CompressionStrategy::SlidingWindow => "sliding window",
354 CompressionStrategy::Summarize => "AI summarize",
355 CompressionStrategy::BiasBased => "bias-based",
356 };
357 let summary_marker = if self.has_summary { "📝" } else { "✂️" };
358 format!(
359 "{} {} - {} msgs → {} msgs (~{} tokens saved) {}",
360 self.timestamp.format("%Y-%m-%d %H:%M"),
361 strategy_name,
362 self.original_count,
363 self.new_count,
364 format_tokens(self.tokens_saved),
365 summary_marker
366 )
367 }
368}
369
/// Produces a summary for a slice of conversation messages.
#[async_trait]
pub trait Compressor: Send + Sync {
    /// Summarizes `messages` into a single `SummarizedSegment`.
    async fn summarize(&self, messages: &[Message], config: &CompressionConfig) -> Result<SummarizedSegment>;

    /// Name of the model backing this compressor.
    fn model_name(&self) -> &str;
}
379
380pub struct AiCompressor {
382 provider: Box<dyn Provider>,
383 model: String,
384}
385
386impl AiCompressor {
387 pub fn new(provider: Box<dyn Provider>, model: String) -> Self {
389 Self { provider, model }
390 }
391}
392
393#[async_trait]
394impl Compressor for AiCompressor {
395 async fn summarize(&self, messages: &[Message], _config: &CompressionConfig) -> Result<SummarizedSegment> {
396 let prompt = build_summary_prompt(messages);
397
398 let request = ChatRequest {
399 messages: vec![Message {
400 role: Role::User,
401 content: MessageContent::Text(prompt),
402 }],
403 tools: vec![], system: Some(SUMMARY_SYSTEM_PROMPT.to_string()),
405 think: false, max_tokens: 1024, server_tools: vec![],
408 enable_caching: false, };
410
411 let response = self.provider.chat(request).await?;
412
413 let summary_text = extract_text_from_response(&response);
415
416 let (summary, key_points) = parse_summary_response(&summary_text);
418
419 Ok(SummarizedSegment {
420 time_range: (Utc::now(), Utc::now()), original_count: messages.len(),
422 summary,
423 key_points,
424 })
425 }
426
427 fn model_name(&self) -> &str {
428 &self.model
429 }
430}
431
/// System prompt (in Chinese) sent with every summarization request. It asks for a
/// short, structured summary that keeps key operations and decisions, the user's
/// explicit constraints and the user's preferences.
const SUMMARY_SYSTEM_PROMPT: &str = r#"你是一个对话历史压缩助手。你的任务是将对话历史压缩为简洁的摘要,保留关键信息。

输出要求:
- 简洁:摘要控制在 200 字以内
- 关键:只保留重要操作和决策
- 结构化:使用清晰格式
- 敏感:必须保留用户的敏感指令(如"不要..."、"必须..."、"禁止..."等)
- 偏好:保留用户的偏好设置和决策

请直接输出摘要内容。"#;
443
444fn extract_text_from_response(response: &ChatResponse) -> String {
446 response.content
447 .iter()
448 .filter_map(|block| {
449 if let ContentBlock::Text { text } = block {
450 Some(text.clone())
451 } else {
452 None
453 }
454 })
455 .collect::<Vec<_>>()
456 .join("\n")
457}
458
459fn parse_summary_response(text: &str) -> (String, Vec<String>) {
461 let mut summary = String::new();
462 let mut key_points: Vec<String> = Vec::new();
463
464 for line in text.lines() {
465 let line = line.trim();
466
467 if line.starts_with("•") || line.starts_with("-") || line.starts_with("*") {
469 let point = line.trim_start_matches(['•', '-', '*']).trim();
470 if !point.is_empty() {
471 key_points.push(point.to_string());
472 }
473 } else if line.starts_with("已完成") || line.starts_with("操作") {
474 let ops = line.trim_start_matches(|c: char| c.is_alphabetic() || c == ':' || c == ':').trim();
476 if !ops.is_empty() && ops != ":" && ops != ":" {
477 key_points.push(ops.to_string());
478 }
479 } else if !line.is_empty() && summary.is_empty() {
480 summary = line.to_string();
482 } else if !line.is_empty() {
483 if key_points.is_empty() && summary.len() < 200 {
485 summary.push(' ');
486 summary.push_str(line);
487 }
488 }
489 }
490
491 if summary.is_empty() && !text.is_empty() {
493 summary = text.lines().take(3).collect::<Vec<_>>().join(" ");
494 if summary.len() > 200 {
495 summary = truncate_text(&summary, 200);
496 }
497 }
498
499 (summary, key_points)
500}
501
/// Returns `s` unchanged when it is at most `max` bytes long; otherwise cuts it at the
/// last character boundary not beyond `max` bytes and appends `...`.
fn truncate_text(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }

    // Index 0 is always a character boundary, so the search cannot come up empty.
    let cut = (0..=max)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);

    format!("{}...", &s[..cut])
}
513
514pub fn compress_messages(
516 messages: &[Message],
517 strategy: CompressionStrategy,
518 config: &CompressionConfig,
519) -> Result<Vec<Message>> {
520 match strategy {
521 CompressionStrategy::Truncate => truncate_compress(messages, config),
522 CompressionStrategy::SlidingWindow => sliding_window_compress(messages, config),
523 CompressionStrategy::Summarize => {
524 sliding_window_compress(messages, config)
526 }
527 CompressionStrategy::BiasBased => compress_with_bias(messages, config),
528 }
529}
530
531pub fn compress_with_bias(
533 messages: &[Message],
534 config: &CompressionConfig,
535) -> Result<Vec<Message>> {
536 if messages.len() <= config.min_preserve_messages {
537 return Ok(messages.to_vec());
538 }
539
540 let scored_messages: Vec<(usize, Message, f64)> = messages
542 .iter()
543 .enumerate()
544 .map(|(idx, msg)| (idx, msg.clone(), calculate_preservation_score(msg, idx, messages.len(), &config.bias)))
545 .collect();
546
547 let mut scored_with_recency: Vec<(usize, Message, f64)> = scored_messages
550 .into_iter()
551 .map(|(idx, msg, score)| {
552 let recency_bonus = if idx >= messages.len() - config.min_preserve_messages {
554 100.0 } else {
556 (idx as f64 / messages.len() as f64) * 20.0 };
558 (idx, msg, score + recency_bonus)
559 })
560 .collect();
561
562 scored_with_recency.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
564
565 let target_count = if config.bias.aggressive {
567 config.min_preserve_messages
568 } else {
569 let estimated_tokens = estimate_total_tokens(messages);
570 let target_tokens = (estimated_tokens as f64 * config.target_ratio) as u32;
571 let avg_tokens_per_msg = estimated_tokens / messages.len() as u32;
572 let calculated = (target_tokens / avg_tokens_per_msg.max(1)) as usize;
573 calculated.max(config.min_preserve_messages)
574 };
575
576 let to_keep_indices: HashSet<usize> = scored_with_recency
578 .iter()
579 .take(target_count)
580 .map(|(idx, _, _)| *idx)
581 .collect();
582
583 let compressed: Vec<Message> = messages
585 .iter()
586 .enumerate()
587 .filter(|(idx, _)| to_keep_indices.contains(idx))
588 .map(|(_, msg)| msg.clone())
589 .collect();
590
591 Ok(compressed)
592}
593
594fn calculate_preservation_score(message: &Message, _index: usize, _total: usize, bias: &CompressionBias) -> f64 {
596 let mut score: f64 = 10.0; match message.role {
600 Role::User => {
601 if bias.preserve_user_questions {
602 score += 30.0;
603 }
604 }
605 Role::Assistant => {
606 score += 5.0;
607 }
608 Role::Tool => {
609 if bias.preserve_tools {
610 score += 25.0;
611 }
612 }
613 Role::System => {
614 score += 40.0; }
616 }
617
618 match &message.content {
620 MessageContent::Text(text) => {
621 for keyword in &bias.preserve_keywords {
623 if text.to_lowercase().contains(&keyword.to_lowercase()) {
624 score += 15.0;
625 }
626 }
627
628 if contains_sensitive_instructions(text) {
630 score += 50.0; }
632
633 if !bias.compact_long_outputs && text.len() > 2000 {
635 score -= 10.0;
636 }
637 }
638 MessageContent::Blocks(blocks) => {
639 for block in blocks {
640 match block {
641 ContentBlock::ToolUse { name, .. } => {
642 if bias.preserve_tools {
643 score += 20.0;
644 }
645 if name == "write" || name == "edit" || name == "bash" {
647 score += 10.0;
648 }
649 }
650 ContentBlock::ToolResult { content, .. } => {
651 if bias.preserve_tools {
652 score += 20.0;
653 }
654 for keyword in &bias.preserve_keywords {
656 if content.to_lowercase().contains(&keyword.to_lowercase()) {
657 score += 10.0;
658 }
659 }
660 if contains_sensitive_instructions(content) {
662 score += 30.0;
663 }
664 }
665 ContentBlock::Thinking { .. } => {
666 if bias.preserve_thinking {
667 score += 25.0;
668 } else {
669 score -= 5.0; }
671 }
672 ContentBlock::Text { text } => {
673 for keyword in &bias.preserve_keywords {
674 if text.to_lowercase().contains(&keyword.to_lowercase()) {
675 score += 15.0;
676 }
677 }
678 if contains_sensitive_instructions(text) {
680 score += 50.0;
681 }
682 }
683 _ => {}
684 }
685 }
686 }
687 }
688
689 score
690}
691
692fn contains_sensitive_instructions(text: &str) -> bool {
695 let text_lower = text.to_lowercase();
696
697 let sensitive_patterns = [
699 "不要", "禁止", "不能", "千万别", "禁止使用",
701 "never do", "must not", "should not", "cannot", "avoid",
702
703 "必须", "一定要", "务必", "必须使用",
705 "must", "required", "mandatory",
706
707 "敏感", "隐私", "密码", "secret", "password", "credential",
709 "private", "sensitive", "confidential",
710
711 "决定", "决策", "critical", "important", "关键",
713
714 "偏好", "我喜欢", "我习惯", "prefer", "preference",
716
717 "严格按照", "遵循", "按原样", "strictly", "exactly",
719 "不要修改", "不要改动", "keep original", "as is",
720 ];
721
722 for pattern in &sensitive_patterns {
723 if text_lower.contains(pattern) {
724 return true;
725 }
726 }
727
728 false
729}
730
731pub async fn compress_messages_with_ai(
733 messages: &[Message],
734 compressor: &dyn Compressor,
735 config: &CompressionConfig,
736) -> Result<(Vec<Message>, Option<SummarizedSegment>)> {
737 if messages.len() <= config.min_preserve_messages {
738 return Ok((messages.to_vec(), None));
739 }
740
741 let preserve_count = config.min_preserve_messages;
743 let summarize_messages = &messages[..messages.len() - preserve_count];
744 let keep_messages = &messages[messages.len() - preserve_count..];
745
746 let segment = compressor.summarize(summarize_messages, config).await?;
748
749 let summary_msg = segment.to_message();
751 let mut compressed = vec![summary_msg];
752 compressed.extend(keep_messages.to_vec());
753
754 Ok((compressed, Some(segment)))
755}
756
757fn truncate_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
759 if messages.len() <= config.min_preserve_messages {
760 return Ok(messages.to_vec());
761 }
762
763 let keep_count = config.min_preserve_messages;
764 let start_idx = messages.len().saturating_sub(keep_count);
765
766 Ok(messages[start_idx..].to_vec())
767}
768
769fn sliding_window_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
772 if messages.len() <= config.min_preserve_messages {
773 return Ok(messages.to_vec());
774 }
775
776 let total_tokens = estimate_total_tokens(messages);
778 let target_tokens = (total_tokens as f64 * config.target_ratio) as u32;
779
780 let mut turn_boundaries: Vec<usize> = Vec::new();
782 for (i, msg) in messages.iter().enumerate() {
783 if msg.role == Role::User {
784 turn_boundaries.push(i);
785 }
786 }
787
788 let min_start_idx = messages.len().saturating_sub(config.min_preserve_messages);
790
791 for &start_idx in turn_boundaries.iter() {
796 if messages.len() - start_idx < config.min_preserve_messages {
798 continue;
799 }
800
801 let candidate_messages = &messages[start_idx..];
802 let candidate_tokens = estimate_total_tokens(candidate_messages);
803
804 if candidate_tokens <= target_tokens {
806 return Ok(candidate_messages.to_vec());
807 }
808 }
809
810 Ok(messages[min_start_idx..].to_vec())
812}
813
814pub fn estimate_tokens(message: &Message) -> u32 {
820 let (ascii_count, non_ascii_count) = match &message.content {
821 MessageContent::Text(t) => count_chars(t),
822 MessageContent::Blocks(blocks) => {
823 let mut ascii = 0;
824 let mut non_ascii = 0;
825 for block in blocks {
826 match block {
827 ContentBlock::Text { text } => {
828 let (a, n) = count_chars(text);
829 ascii += a;
830 non_ascii += n;
831 }
832 ContentBlock::ToolUse { name, input, .. } => {
833 let (a, n) = count_chars(name);
834 ascii += a;
835 non_ascii += n;
836 let json_str = input.to_string();
838 let (ja, jn) = count_chars(&json_str);
839 ascii += ja;
840 non_ascii += jn;
841 }
842 ContentBlock::ToolResult { content, .. } => {
843 let (a, n) = count_chars(content);
844 ascii += a;
845 non_ascii += n;
846 }
847 ContentBlock::Thinking { thinking, .. } => {
848 let (a, n) = count_chars(thinking);
849 ascii += a;
850 non_ascii += n;
851 }
852 _ => {}
853 }
854 }
855 (ascii, non_ascii)
856 }
857 };
858
859 let ascii_tokens = (ascii_count as f64 * 0.25).ceil() as u32;
862 let non_ascii_tokens = (non_ascii_count as f64 * 0.67).ceil() as u32;
863 let total = ascii_tokens + non_ascii_tokens + 10; total.max(1)
866}
867
/// Returns `(ascii, non_ascii)` character counts for `s`.
fn count_chars(s: &str) -> (u32, u32) {
    s.chars().fold((0u32, 0u32), |(ascii, other), ch| {
        if ch.is_ascii() {
            (ascii + 1, other)
        } else {
            (ascii, other + 1)
        }
    })
}
881
882pub fn estimate_total_tokens(messages: &[Message]) -> u32 {
884 messages.iter().map(estimate_tokens).sum()
885}
886
887pub fn should_compress(
889 current_tokens: u32,
890 context_size: Option<u32>,
891 config: &CompressionConfig,
892) -> bool {
893 match context_size {
894 Some(size) => {
895 let ratio = current_tokens as f64 / size as f64;
896 ratio >= config.threshold
897 }
898 None => false,
899 }
900}
901
902pub fn build_summary_prompt(messages: &[Message]) -> String {
904 let history_text = messages
905 .iter()
906 .map(|m| {
907 let role = match m.role {
908 Role::User => "用户",
909 Role::Assistant => "助手",
910 Role::Tool => "工具",
911 Role::System => "系统",
912 };
913 let content_preview = match &m.content {
914 MessageContent::Text(t) => truncate_for_summary(t, 200),
915 MessageContent::Blocks(blocks) => {
916 let preview: Vec<String> = blocks
917 .iter()
918 .map(|b| match b {
919 ContentBlock::Text { text } => truncate_for_summary(text, 100),
920 ContentBlock::ToolUse { name, .. } => format!("[工具: {}]", name),
921 ContentBlock::ToolResult { content, .. } => truncate_for_summary(content, 100),
922 _ => "[...]".to_string(),
923 })
924 .collect();
925 preview.join(" | ")
926 }
927 };
928 format!("{}: {}", role, content_preview)
929 })
930 .collect::<Vec<_>>()
931 .join("\n");
932
933 format!(
934 r#"请将以下对话历史压缩为简洁摘要:
935
936对话历史({} 条消息):
937{}
938
939请输出:
9401. 概述(一句话描述主要任务)
9412. 已完成的关键操作(2-3 条)
9423. 当前状态(如果有)"#,
943 messages.len(),
944 history_text
945 )
946}
947
/// Shortens `s` to at most `max` bytes for use in the summary prompt by delegating to
/// `truncate_text`.
fn truncate_for_summary(s: &str, max: usize) -> String {
    truncate_text(s, max)
}
951
// Unit tests for token estimation, the compression strategies, summary parsing and
// the bias presets.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // A short ASCII message still yields a non-trivial estimate.
    #[test]
    fn test_estimate_tokens_simple() {
        let msg = Message {
            role: Role::User,
            content: MessageContent::Text("Hello world".to_string()),
        };
        assert!(estimate_tokens(&msg) >= 3);
    }

    // 50% usage is below the default 0.75 threshold.
    #[test]
    fn test_should_compress_below_threshold() {
        let config = CompressionConfig::default();
        assert!(!should_compress(100_000, Some(200_000), &config));
    }

    // 80% usage is above the default 0.75 threshold.
    #[test]
    fn test_should_compress_above_threshold() {
        let config = CompressionConfig::default();
        assert!(should_compress(160_000, Some(200_000), &config));
    }

    // Truncation keeps exactly the last `min_preserve_messages` messages.
    #[test]
    fn test_truncate_compress_keeps_minimum() {
        let messages: Vec<Message> = (0..10)
            .map(|i| Message {
                role: Role::User,
                content: MessageContent::Text(format!("Message {}", i)),
            })
            .collect();

        let config = CompressionConfig {
            min_preserve_messages: 4,
            ..Default::default()
        };

        let compressed = truncate_compress(&messages, &config).unwrap();
        assert_eq!(compressed.len(), 4);
        assert_eq!(compressed[0].content, MessageContent::Text("Message 6".to_string()));
    }

    // The sliding window never drops below the minimum and keeps at least one user turn.
    #[test]
    fn test_sliding_window_preserves_turns() {
        let messages: Vec<Message> = vec![
            Message { role: Role::User, content: MessageContent::Text("Q1 - this is a longer question to test token estimation".to_string()) },
            Message { role: Role::Assistant, content: MessageContent::Text("A1 - this is a longer answer with more content for token estimation".to_string()) },
            Message { role: Role::User, content: MessageContent::Text("Q2 - another longer question for testing".to_string()) },
            Message { role: Role::Assistant, content: MessageContent::Text("A2 - another longer answer for testing token estimation properly".to_string()) },
            Message { role: Role::User, content: MessageContent::Text("Q3 - the third question in this test".to_string()) },
            Message { role: Role::Assistant, content: MessageContent::Text("A3 - the third answer with sufficient content".to_string()) },
        ];

        let config = CompressionConfig {
            min_preserve_messages: 4,
            target_ratio: 0.5,
            ..Default::default()
        };

        let compressed = sliding_window_compress(&messages, &config).unwrap();
        assert!(compressed.len() >= config.min_preserve_messages);
        assert!(compressed.iter().any(|m| m.role == Role::User));
    }

    // A reply with a summary line, a header and two bullets yields a summary and key points.
    #[test]
    fn test_parse_summary_response() {
        let text = "用户请求实现登录功能。\n已完成操作:\n• 创建了 login.rs 文件\n• 添加了密码验证逻辑\n当前状态:测试中";
        let (summary, key_points) = parse_summary_response(text);

        assert!(!summary.is_empty());
        assert!(key_points.len() >= 2);
    }

    // The formatted summary mentions both counts and the strategy label.
    #[test]
    fn test_compression_result_format() {
        let result = CompressionResult::new(
            20,
            8,
            5000,
            Some("摘要内容".to_string()),
            CompressionStrategy::Summarize,
        );

        let formatted = result.format_summary();
        assert!(formatted.contains("20"));
        assert!(formatted.contains("8"));
        assert!(formatted.contains("AI summarize"));
    }

    // A history entry mirrors the result's strategy and records the missing summary.
    #[test]
    fn test_compression_history_entry() {
        let result = CompressionResult::new(
            15,
            6,
            3000,
            None,
            CompressionStrategy::SlidingWindow,
        );

        let entry = CompressionHistoryEntry::from_result(&result);
        assert_eq!(entry.strategy, CompressionStrategy::SlidingWindow);
        assert!(!entry.has_summary);
    }

    // Preset names and their aliases resolve to the expected flag combinations.
    #[test]
    fn test_compression_bias_parse() {
        let balanced = CompressionBias::parse("balanced").unwrap();
        assert!(balanced.preserve_tools);
        assert!(balanced.preserve_user_questions);

        let aggressive = CompressionBias::parse("aggressive").unwrap();
        assert!(!aggressive.preserve_tools);
        assert!(aggressive.aggressive);

        let important = CompressionBias::parse("important").unwrap();
        assert!(important.preserve_thinking);
        assert!(important.preserve_tools);

        let tools = CompressionBias::parse("tools").unwrap();
        assert!(tools.preserve_tools);
        assert!(!tools.preserve_thinking);
    }

    // The balanced preset lists its active flags when formatted.
    #[test]
    fn test_compression_bias_format() {
        let bias = CompressionBias::balanced();
        let formatted = bias.format();
        assert!(formatted.contains("tools"));
        assert!(formatted.contains("user"));
    }

    // With a tool-focused bias, the tool call survives or almost nothing is dropped.
    #[test]
    fn test_compress_with_bias_preserves_tools() {
        let messages: Vec<Message> = vec![
            Message { role: Role::User, content: MessageContent::Text("Q1".to_string()) },
            Message {
                role: Role::Assistant,
                content: MessageContent::Blocks(vec![
                    ContentBlock::ToolUse { id: "1".to_string(), name: "read".to_string(), input: json!({}) }
                ])
            },
            Message { role: Role::Tool, content: MessageContent::Blocks(vec![
                ContentBlock::ToolResult { tool_use_id: "1".to_string(), content: "file content".to_string() }
            ])},
            Message { role: Role::User, content: MessageContent::Text("Q2".to_string()) },
            Message { role: Role::Assistant, content: MessageContent::Text("A2".to_string()) },
            Message { role: Role::User, content: MessageContent::Text("Q3".to_string()) },
            Message { role: Role::Assistant, content: MessageContent::Text("A3".to_string()) },
        ];

        let config = CompressionConfig {
            min_preserve_messages: 2,
            bias: CompressionBias::tool_focused(),
            ..Default::default()
        };

        let compressed = compress_with_bias(&messages, &config).unwrap();

        let has_tool_use = compressed.iter().any(|m| {
            matches!(&m.content, MessageContent::Blocks(blocks) if
                blocks.iter().any(|b| matches!(b, ContentBlock::ToolUse { .. })))
        });
        assert!(has_tool_use || compressed.len() >= messages.len() - 2);
    }

    // The bias-based strategy never grows the history nor drops below the minimum.
    #[test]
    fn test_bias_based_strategy() {
        let messages: Vec<Message> = (0..10)
            .map(|i| Message {
                role: if i % 2 == 0 { Role::User } else { Role::Assistant },
                content: MessageContent::Text(format!("Message {}", i)),
            })
            .collect();

        let config = CompressionConfig {
            min_preserve_messages: 4,
            bias: CompressionBias::aggressive(),
            ..Default::default()
        };

        let compressed = compress_messages(&messages, CompressionStrategy::BiasBased, &config).unwrap();
        assert!(compressed.len() <= messages.len());
        assert!(compressed.len() >= config.min_preserve_messages);
    }
}