Skip to main content

matrixcode_core/
compress.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6
7use crate::providers::{ContentBlock, Message, MessageContent, Provider, Role, ChatRequest, ChatResponse};
8
/// Fraction of the context window at which compression is triggered.
pub const DEFAULT_COMPRESSION_THRESHOLD: f64 = 0.75;

/// Lower bound on the number of messages retained by a compression pass.
pub const MIN_MESSAGES_TO_KEEP: usize = 8;

/// Fraction of the estimated token count a compression pass aims to keep.
pub const DEFAULT_TARGET_RATIO: f64 = 0.4;

/// Summarization model used when none is configured (picked for low cost).
pub const DEFAULT_COMPRESSOR_MODEL: &str = "claude-3-5-haiku-20241022";
20
/// Knobs that decide which parts of a conversation survive compression.
///
/// The derived `Default` has every flag off and no keywords; use
/// [`CompressionBias::balanced`] for the recommended baseline.
#[derive(Debug, Clone, Default)]
pub struct CompressionBias {
    /// Favor tool calls and their results (important operations).
    pub preserve_tools: bool,
    /// Favor thinking blocks (the reasoning process).
    pub preserve_thinking: bool,
    /// Favor user questions, even old ones.
    pub preserve_user_questions: bool,
    /// Compact long outputs instead of removing them.
    pub compact_long_outputs: bool,
    /// Aggressive mode: drop as much content as possible.
    pub aggressive: bool,
    /// Messages containing any of these keywords are favored.
    pub preserve_keywords: Vec<String>,
}

impl CompressionBias {
    /// Turn a list of string literals into owned keywords.
    fn keyword_list(words: &[&str]) -> Vec<String> {
        words.iter().map(|word| (*word).to_string()).collect()
    }

    /// Balanced preset: keeps tools and user questions, plus decision keywords.
    pub fn balanced() -> Self {
        Self {
            preserve_tools: true,
            preserve_thinking: false,
            preserve_user_questions: true,
            compact_long_outputs: false,
            aggressive: false,
            preserve_keywords: Self::keyword_list(&[
                "决定", "decision", "重要", "important", "关键", "key",
            ]),
        }
    }

    /// Preset that keeps everything considered important: tools, thinking,
    /// user questions, long outputs (compacted) and decision/completion keywords.
    pub fn preserve_important() -> Self {
        Self {
            preserve_tools: true,
            preserve_thinking: true,
            preserve_user_questions: true,
            compact_long_outputs: true,
            aggressive: false,
            preserve_keywords: Self::keyword_list(&[
                "决定", "decision", "重要", "important", "关键", "key", "完成", "done", "成功",
                "success",
            ]),
        }
    }

    /// Preset that preserves nothing in particular and enables aggressive mode.
    pub fn aggressive() -> Self {
        Self {
            preserve_tools: false,
            preserve_thinking: false,
            preserve_user_questions: false,
            compact_long_outputs: false,
            aggressive: true,
            preserve_keywords: Vec::new(),
        }
    }

    /// Preset centered on tool operations and tool-related keywords.
    pub fn tool_focused() -> Self {
        Self {
            preserve_tools: true,
            preserve_thinking: false,
            preserve_user_questions: false,
            compact_long_outputs: false,
            aggressive: false,
            preserve_keywords: Self::keyword_list(&[
                "工具", "tool", "执行", "execute", "文件", "file",
            ]),
        }
    }

    /// Build a bias from a textual specification.
    ///
    /// The input is trimmed and lower-cased first. It is either a preset name
    /// (`""`, `balanced`, `default`, `aggressive`, `preserve_important`,
    /// `important`, `tool_focused`, `tools`) or a whitespace-separated list of
    /// `preserve:tools,thinking,user,compact`, `keywords:a,b` and `aggressive`
    /// tokens applied on top of an all-off bias. Unrecognized tokens and
    /// `preserve:` items are ignored, so this currently never returns `Err`.
    pub fn parse(spec: &str) -> Result<Self> {
        let normalized = spec.trim().to_lowercase();

        let preset = match normalized.as_str() {
            "" | "balanced" | "default" => Some(Self::balanced()),
            "aggressive" => Some(Self::aggressive()),
            "preserve_important" | "important" => Some(Self::preserve_important()),
            "tool_focused" | "tools" => Some(Self::tool_focused()),
            _ => None,
        };
        if let Some(bias) = preset {
            return Ok(bias);
        }

        // Custom specification, e.g. "preserve:tools,thinking,user keywords:决定,重要".
        let mut bias = Self::default();
        for token in normalized.split_whitespace() {
            if let Some(flags) = token.strip_prefix("preserve:") {
                for flag in flags.split(',').map(str::trim) {
                    match flag {
                        "tools" | "tool" => bias.preserve_tools = true,
                        "thinking" | "think" => bias.preserve_thinking = true,
                        "user" | "questions" => bias.preserve_user_questions = true,
                        "compact" | "long" => bias.compact_long_outputs = true,
                        _ => {}
                    }
                }
            } else if let Some(words) = token.strip_prefix("keywords:") {
                // A later `keywords:` token replaces the list set by an earlier one.
                bias.preserve_keywords = words
                    .split(',')
                    .map(str::trim)
                    .filter(|word| !word.is_empty())
                    .map(str::to_string)
                    .collect();
            } else if token == "aggressive" {
                bias.aggressive = true;
            }
        }

        Ok(bias)
    }

    /// Human-readable rendering: enabled flags joined by `", "`, followed by
    /// `keywords:…` when keywords are set, or `"default"` when nothing is set.
    pub fn format(&self) -> String {
        let flags = [
            (self.preserve_tools, "tools"),
            (self.preserve_thinking, "thinking"),
            (self.preserve_user_questions, "user"),
            (self.compact_long_outputs, "compact"),
            (self.aggressive, "aggressive"),
        ];

        let mut parts: Vec<String> = flags
            .iter()
            .filter(|(enabled, _)| *enabled)
            .map(|(_, label)| (*label).to_string())
            .collect();

        if !self.preserve_keywords.is_empty() {
            parts.push(format!("keywords:{}", self.preserve_keywords.join(",")));
        }

        if parts.is_empty() {
            "default".to_string()
        } else {
            parts.join(", ")
        }
    }
}
167
168/// Configuration for context compression.
169#[derive(Debug, Clone)]
170pub struct CompressionConfig {
171    /// Threshold (0.0-1.0) at which to trigger compression.
172    pub threshold: f64,
173    /// Maximum tokens to target after compression.
174    pub target_ratio: f64,
175    /// Minimum recent messages to always preserve.
176    pub min_preserve_messages: usize,
177    /// Whether to use AI summarization (requires a compressor model).
178    pub use_summarization: bool,
179    /// Optional model name for summarization (if different from main model).
180    pub compressor_model: Option<String>,
181    /// Compression bias - what to prioritize during compression.
182    pub bias: CompressionBias,
183}
184
185impl Default for CompressionConfig {
186    fn default() -> Self {
187        Self {
188            threshold: DEFAULT_COMPRESSION_THRESHOLD,
189            target_ratio: DEFAULT_TARGET_RATIO,
190            min_preserve_messages: MIN_MESSAGES_TO_KEEP,
191            use_summarization: true,
192            compressor_model: None,
193            bias: CompressionBias::balanced(),
194        }
195    }
196}
197
198impl CompressionConfig {
199    /// Get the compressor model name.
200    pub fn compressor_model_name(&self) -> &str {
201        self.compressor_model.as_deref().unwrap_or(DEFAULT_COMPRESSOR_MODEL)
202    }
203}
204
205/// Result of a compression operation.
206#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct CompressionResult {
208    /// Original message count.
209    pub original_count: usize,
210    /// New message count after compression.
211    pub new_count: usize,
212    /// Estimated token reduction.
213    pub tokens_saved: u32,
214    /// Summary of removed content (if summarization was used).
215    pub summary: Option<String>,
216    /// Strategy used for compression.
217    pub strategy: CompressionStrategy,
218    /// When the compression occurred.
219    pub timestamp: DateTime<Utc>,
220}
221
222impl CompressionResult {
223    /// Create a new compression result.
224    pub fn new(
225        original_count: usize,
226        new_count: usize,
227        tokens_saved: u32,
228        summary: Option<String>,
229        strategy: CompressionStrategy,
230    ) -> Self {
231        Self {
232            original_count,
233            new_count,
234            tokens_saved,
235            summary,
236            strategy,
237            timestamp: Utc::now(),
238        }
239    }
240
241    /// Format for display.
242    pub fn format_summary(&self) -> String {
243        let strategy_name = match self.strategy {
244            CompressionStrategy::Truncate => "truncate",
245            CompressionStrategy::SlidingWindow => "sliding window",
246            CompressionStrategy::Summarize => "AI summarize",
247            CompressionStrategy::BiasBased => "bias-based",
248        };
249        format!(
250            "{} messages → {} messages (saved ~{} tokens, {})",
251            self.original_count,
252            self.new_count,
253            format_tokens(self.tokens_saved),
254            strategy_name
255        )
256    }
257}
258
/// Render a token count compactly: the plain number below 1 000, one decimal
/// with a `K` suffix below 10 000, and whole thousands with `K` from there on.
pub fn format_tokens(n: u32) -> String {
    let thousands = n as f64 / 1_000.0;
    match n {
        0..=999 => n.to_string(),
        1_000..=9_999 => format!("{:.1}K", thousands),
        _ => format!("{:.0}K", thousands),
    }
}
268
269/// Strategy for compressing conversation history.
270#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
271#[serde(rename_all = "snake_case")]
272pub enum CompressionStrategy {
273    /// Remove oldest messages, keep recent ones.
274    Truncate,
275    /// Use sliding window - keep last N message pairs.
276    SlidingWindow,
277    /// Summarize old messages into a compact summary block.
278    Summarize,
279    /// Use bias-based scoring to prioritize what to keep.
280    BiasBased,
281}
282
283/// A segment of conversation history that has been summarized.
284#[derive(Debug, Clone, Serialize, Deserialize)]
285pub struct SummarizedSegment {
286    /// Timestamp range of the summarized messages.
287    pub time_range: (DateTime<Utc>, DateTime<Utc>),
288    /// Number of original messages in this segment.
289    pub original_count: usize,
290    /// The summary text.
291    pub summary: String,
292    /// Key decisions or actions taken during this segment.
293    pub key_points: Vec<String>,
294}
295
296impl SummarizedSegment {
297    /// Render as a system message for context injection.
298    pub fn to_message(&self) -> Message {
299        let key_points_text = if self.key_points.is_empty() {
300            "无".to_string()
301        } else {
302            self.key_points.iter().map(|p| format!("• {}", p)).collect::<Vec<_>>().join("\n")
303        };
304        
305        let content = format!(
306            "[对话摘要 - 原 {} 条消息]\n\n{}\n\n关键要点:\n{}",
307            self.original_count,
308            self.summary,
309            key_points_text
310        );
311        
312        Message {
313            role: Role::User,
314            content: MessageContent::Text(content),
315        }
316    }
317}
318
319/// Compression history entry for session metadata.
320#[derive(Debug, Clone, Serialize, Deserialize)]
321pub struct CompressionHistoryEntry {
322    /// When the compression occurred.
323    pub timestamp: DateTime<Utc>,
324    /// Strategy used.
325    pub strategy: CompressionStrategy,
326    /// Original message count.
327    pub original_count: usize,
328    /// New message count.
329    pub new_count: usize,
330    /// Estimated tokens saved.
331    pub tokens_saved: u32,
332    /// Whether summary was generated.
333    pub has_summary: bool,
334}
335
336impl CompressionHistoryEntry {
337    /// Create from a CompressionResult.
338    pub fn from_result(result: &CompressionResult) -> Self {
339        Self {
340            timestamp: result.timestamp,
341            strategy: result.strategy,
342            original_count: result.original_count,
343            new_count: result.new_count,
344            tokens_saved: result.tokens_saved,
345            has_summary: result.summary.is_some(),
346        }
347    }
348
349    /// Format for display.
350    pub fn format_line(&self) -> String {
351        let strategy_name = match self.strategy {
352            CompressionStrategy::Truncate => "truncate",
353            CompressionStrategy::SlidingWindow => "sliding window",
354            CompressionStrategy::Summarize => "AI summarize",
355            CompressionStrategy::BiasBased => "bias-based",
356        };
357        let summary_marker = if self.has_summary { "📝" } else { "✂️" };
358        format!(
359            "{} {} - {} msgs → {} msgs (~{} tokens saved) {}",
360            self.timestamp.format("%Y-%m-%d %H:%M"),
361            strategy_name,
362            self.original_count,
363            self.new_count,
364            format_tokens(self.tokens_saved),
365            summary_marker
366        )
367    }
368}
369
370/// Compressor trait for different compression implementations.
371#[async_trait]
372pub trait Compressor: Send + Sync {
373    /// Compress messages using AI summarization.
374    async fn summarize(&self, messages: &[Message], config: &CompressionConfig) -> Result<SummarizedSegment>;
375    
376    /// Get the model name used for summarization.
377    fn model_name(&self) -> &str;
378}
379
380/// AI-based compressor using a Provider.
381pub struct AiCompressor {
382    provider: Box<dyn Provider>,
383    model: String,
384}
385
386impl AiCompressor {
387    /// Create a new AI compressor.
388    pub fn new(provider: Box<dyn Provider>, model: String) -> Self {
389        Self { provider, model }
390    }
391}
392
393#[async_trait]
394impl Compressor for AiCompressor {
395    async fn summarize(&self, messages: &[Message], _config: &CompressionConfig) -> Result<SummarizedSegment> {
396        let prompt = build_summary_prompt(messages);
397        
398        let request = ChatRequest {
399            messages: vec![Message {
400                role: Role::User,
401                content: MessageContent::Text(prompt),
402            }],
403            tools: vec![], // No tools for summarization
404            system: Some(SUMMARY_SYSTEM_PROMPT.to_string()),
405            think: false, // No extended thinking for summarization
406            max_tokens: 1024, // Short summary
407            server_tools: vec![],
408            enable_caching: false, // No caching for summarization
409        };
410        
411        let response = self.provider.chat(request).await?;
412        
413        // Extract text from response
414        let summary_text = extract_text_from_response(&response);
415        
416        // Parse the summary into structured format
417        let (summary, key_points) = parse_summary_response(&summary_text);
418        
419        Ok(SummarizedSegment {
420            time_range: (Utc::now(), Utc::now()), // Approximate
421            original_count: messages.len(),
422            summary,
423            key_points,
424        })
425    }
426    
427    fn model_name(&self) -> &str {
428        &self.model
429    }
430}
431
/// System prompt sent with every summarization request. It asks (in Chinese)
/// for a concise, structured summary that keeps key operations, decisions,
/// user preferences and explicit do/don't instructions.
const SUMMARY_SYSTEM_PROMPT: &str = r#"你是一个对话历史压缩助手。你的任务是将对话历史压缩为简洁的摘要,保留关键信息。

输出要求:
- 简洁:摘要控制在 200 字以内
- 关键:只保留重要操作和决策
- 结构化:使用清晰格式
- 敏感:必须保留用户的敏感指令(如"不要..."、"必须..."、"禁止..."等)
- 偏好:保留用户的偏好设置和决策

请直接输出摘要内容。"#;
443
444/// Extract text content from a chat response.
445fn extract_text_from_response(response: &ChatResponse) -> String {
446    response.content
447        .iter()
448        .filter_map(|block| {
449            if let ContentBlock::Text { text } = block {
450                Some(text.clone())
451            } else {
452                None
453            }
454        })
455        .collect::<Vec<_>>()
456        .join("\n")
457}
458
/// Split a free-form summary reply into `(overview, key_points)`.
///
/// Each non-empty line is classified in order:
/// * bullet lines (`•`, `-`, `*`) become key points;
/// * lines starting with the labels `已完成` / `操作` become key points holding
///   the text after the label's colon (ASCII `:` or full-width `:`); without
///   such a colon the whole line is kept, and a bare label is dropped;
/// * the first remaining line starts the overview, and later ones are appended
///   while no key point has been seen and the overview is under 200 bytes.
///
/// When no overview was found, the first three lines of `text` are joined and
/// truncated to 200 bytes instead.
fn parse_summary_response(text: &str) -> (String, Vec<String>) {
    let mut summary = String::new();
    let mut key_points: Vec<String> = Vec::new();

    for line in text.lines().map(str::trim).filter(|line| !line.is_empty()) {
        if line.starts_with(['•', '-', '*']) {
            let point = line.trim_start_matches(['•', '-', '*']).trim();
            if !point.is_empty() {
                key_points.push(point.to_string());
            }
        } else if let Some(payload) = labeled_section_payload(line) {
            if !payload.is_empty() {
                key_points.push(payload.to_string());
            }
        } else if summary.is_empty() {
            // First free-text line is the overview.
            summary = line.to_string();
        } else if key_points.is_empty() && summary.len() < 200 {
            // Keep extending the overview until the key points begin.
            summary.push(' ');
            summary.push_str(line);
        }
    }

    // Nothing qualified as an overview: fall back to the head of the raw text.
    if summary.is_empty() && !text.is_empty() {
        let head = text.lines().take(3).collect::<Vec<_>>().join(" ");
        summary = truncate_text(&head, 200);
    }

    (summary, key_points)
}

/// True for the colons that may terminate a section label (ASCII or full-width).
fn is_label_colon(c: char) -> bool {
    c == ':' || c == '\u{FF1A}'
}

/// For a line starting with a known section label, return the text that
/// belongs to it; `None` when the line does not start with a label.
fn labeled_section_payload(line: &str) -> Option<&str> {
    const LABELS: [&str; 2] = ["已完成", "操作"];
    let rest = LABELS.iter().find_map(|label| line.strip_prefix(label))?;

    let payload = match rest.split_once(is_label_colon) {
        // The colon ends the heading only if the heading is a single word;
        // otherwise it belongs to the sentence (e.g. a time such as 12:30).
        Some((heading, tail)) if !heading.trim().contains(char::is_whitespace) => tail,
        // Bare label without content.
        _ if rest.trim().is_empty() => "",
        // No heading colon: the line itself is the statement.
        _ => line,
    };

    // CJK characters are alphabetic, so the payload is cut at the colon rather
    // than by skipping "letters", which would swallow Chinese content.
    Some(payload.trim().trim_start_matches(is_label_colon).trim_start())
}

/// Truncate `s` to at most `max` bytes, backing up to the nearest character
/// boundary, and append `...` when anything was cut.
fn truncate_text(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }
    // Index 0 is always a boundary, so the search cannot fail.
    let end = (0..=max).rev().find(|&idx| s.is_char_boundary(idx)).unwrap_or(0);
    format!("{}...", &s[..end])
}
513
514/// Compress messages synchronously (for non-AI strategies).
515pub fn compress_messages(
516    messages: &[Message],
517    strategy: CompressionStrategy,
518    config: &CompressionConfig,
519) -> Result<Vec<Message>> {
520    match strategy {
521        CompressionStrategy::Truncate => truncate_compress(messages, config),
522        CompressionStrategy::SlidingWindow => sliding_window_compress(messages, config),
523        CompressionStrategy::Summarize => {
524            // Summarize requires async AI call, fall back to sliding window
525            sliding_window_compress(messages, config)
526        }
527        CompressionStrategy::BiasBased => compress_with_bias(messages, config),
528    }
529}
530
531/// Compress messages with bias - prioritized removal based on configuration.
532pub fn compress_with_bias(
533    messages: &[Message],
534    config: &CompressionConfig,
535) -> Result<Vec<Message>> {
536    if messages.len() <= config.min_preserve_messages {
537        return Ok(messages.to_vec());
538    }
539
540    // Calculate preservation score for each message
541    let scored_messages: Vec<(usize, Message, f64)> = messages
542        .iter()
543        .enumerate()
544        .map(|(idx, msg)| (idx, msg.clone(), calculate_preservation_score(msg, idx, messages.len(), &config.bias)))
545        .collect();
546
547    // Sort by score (higher score = more important to keep)
548    // Also factor in recency - recent messages get bonus score
549    let mut scored_with_recency: Vec<(usize, Message, f64)> = scored_messages
550        .into_iter()
551        .map(|(idx, msg, score)| {
552            // Recency bonus: later messages get higher score
553            let recency_bonus = if idx >= messages.len() - config.min_preserve_messages {
554                100.0 // Always keep recent messages
555            } else {
556                (idx as f64 / messages.len() as f64) * 20.0 // Up to 20 points for being recent
557            };
558            (idx, msg, score + recency_bonus)
559        })
560        .collect();
561
562    // Sort by score descending
563    scored_with_recency.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
564
565    // Determine how many to keep based on target_ratio and aggressive mode
566    let target_count = if config.bias.aggressive {
567        config.min_preserve_messages
568    } else {
569        let estimated_tokens = estimate_total_tokens(messages);
570        let target_tokens = (estimated_tokens as f64 * config.target_ratio) as u32;
571        let avg_tokens_per_msg = estimated_tokens / messages.len() as u32;
572        let calculated = (target_tokens / avg_tokens_per_msg.max(1)) as usize;
573        calculated.max(config.min_preserve_messages)
574    };
575
576    // Keep top-scored messages, but maintain chronological order
577    let to_keep_indices: HashSet<usize> = scored_with_recency
578        .iter()
579        .take(target_count)
580        .map(|(idx, _, _)| *idx)
581        .collect();
582
583    // Rebuild message list in original order
584    let compressed: Vec<Message> = messages
585        .iter()
586        .enumerate()
587        .filter(|(idx, _)| to_keep_indices.contains(idx))
588        .map(|(_, msg)| msg.clone())
589        .collect();
590
591    Ok(compressed)
592}
593
594/// Calculate preservation score for a message (higher = more important to keep).
595fn calculate_preservation_score(message: &Message, _index: usize, _total: usize, bias: &CompressionBias) -> f64 {
596    let mut score: f64 = 10.0; // Base score
597
598    // Role-based scoring
599    match message.role {
600        Role::User => {
601            if bias.preserve_user_questions {
602                score += 30.0;
603            }
604        }
605        Role::Assistant => {
606            score += 5.0;
607        }
608        Role::Tool => {
609            if bias.preserve_tools {
610                score += 25.0;
611            }
612        }
613        Role::System => {
614            score += 40.0; // System messages are usually important
615        }
616    }
617
618    // Content-based scoring
619    match &message.content {
620        MessageContent::Text(text) => {
621            // Check for keywords
622            for keyword in &bias.preserve_keywords {
623                if text.to_lowercase().contains(&keyword.to_lowercase()) {
624                    score += 15.0;
625                }
626            }
627            
628            // Check for sensitive instructions (Claude Code inspired)
629            if contains_sensitive_instructions(text) {
630                score += 50.0; // Highly preserve sensitive instructions
631            }
632            
633            // Penalize very long messages if not compacting
634            if !bias.compact_long_outputs && text.len() > 2000 {
635                score -= 10.0;
636            }
637        }
638        MessageContent::Blocks(blocks) => {
639            for block in blocks {
640                match block {
641                    ContentBlock::ToolUse { name, .. } => {
642                        if bias.preserve_tools {
643                            score += 20.0;
644                        }
645                        // Certain tools are more important
646                        if name == "write" || name == "edit" || name == "bash" {
647                            score += 10.0;
648                        }
649                    }
650                    ContentBlock::ToolResult { content, .. } => {
651                        if bias.preserve_tools {
652                            score += 20.0;
653                        }
654                        // Check for keywords in result
655                        for keyword in &bias.preserve_keywords {
656                            if content.to_lowercase().contains(&keyword.to_lowercase()) {
657                                score += 10.0;
658                            }
659                        }
660                        // Check for sensitive instructions in result
661                        if contains_sensitive_instructions(content) {
662                            score += 30.0;
663                        }
664                    }
665                    ContentBlock::Thinking { .. } => {
666                        if bias.preserve_thinking {
667                            score += 25.0;
668                        } else {
669                            score -= 5.0; // Thinking blocks can be verbose
670                        }
671                    }
672                    ContentBlock::Text { text } => {
673                        for keyword in &bias.preserve_keywords {
674                            if text.to_lowercase().contains(&keyword.to_lowercase()) {
675                                score += 15.0;
676                            }
677                        }
678                        // Check for sensitive instructions
679                        if contains_sensitive_instructions(text) {
680                            score += 50.0;
681                        }
682                    }
683                    _ => {}
684                }
685            }
686        }
687    }
688
689    score
690}
691
/// Heuristic: does `text` contain wording that looks like a user instruction
/// which must survive compression (prohibitions, obligations, security terms,
/// decisions, preferences, strict constraints)?
///
/// Matching is a case-insensitive substring search, so a pattern also matches
/// inside longer words.
fn contains_sensitive_instructions(text: &str) -> bool {
    const SENSITIVE_PATTERNS: &[&str] = &[
        // Prohibitions (must NOT do something)
        "不要", "禁止", "不能", "千万别", "禁止使用",
        "never do", "must not", "should not", "cannot", "avoid",
        // Obligations (MUST do something)
        "必须", "一定要", "务必", "必须使用",
        "must", "required", "mandatory",
        // Security / privacy
        "敏感", "隐私", "密码", "secret", "password", "credential",
        "private", "sensitive", "confidential",
        // Critical decisions
        "决定", "决策", "critical", "important", "关键",
        // User preferences
        "偏好", "我喜欢", "我习惯", "prefer", "preference",
        // Strict constraints
        "严格按照", "遵循", "按原样", "strictly", "exactly",
        "不要修改", "不要改动", "keep original", "as is",
    ];

    let haystack = text.to_lowercase();
    SENSITIVE_PATTERNS
        .iter()
        .any(|needle| haystack.contains(*needle))
}
730
731/// Compress messages with AI summarization (async version).
732pub async fn compress_messages_with_ai(
733    messages: &[Message],
734    compressor: &dyn Compressor,
735    config: &CompressionConfig,
736) -> Result<(Vec<Message>, Option<SummarizedSegment>)> {
737    if messages.len() <= config.min_preserve_messages {
738        return Ok((messages.to_vec(), None));
739    }
740    
741    // Determine split point: messages to summarize vs messages to keep
742    let preserve_count = config.min_preserve_messages;
743    let summarize_messages = &messages[..messages.len() - preserve_count];
744    let keep_messages = &messages[messages.len() - preserve_count..];
745    
746    // Generate summary
747    let segment = compressor.summarize(summarize_messages, config).await?;
748    
749    // Build new message list: summary message + kept messages
750    let summary_msg = segment.to_message();
751    let mut compressed = vec![summary_msg];
752    compressed.extend(keep_messages.to_vec());
753    
754    Ok((compressed, Some(segment)))
755}
756
757/// Simple truncation: remove oldest messages.
758fn truncate_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
759    if messages.len() <= config.min_preserve_messages {
760        return Ok(messages.to_vec());
761    }
762
763    let keep_count = config.min_preserve_messages;
764    let start_idx = messages.len().saturating_sub(keep_count);
765
766    Ok(messages[start_idx..].to_vec())
767}
768
769/// Sliding window: preserve complete conversation turns.
770/// Now uses token-based target instead of turn count for more stable compression.
771fn sliding_window_compress(messages: &[Message], config: &CompressionConfig) -> Result<Vec<Message>> {
772    if messages.len() <= config.min_preserve_messages {
773        return Ok(messages.to_vec());
774    }
775
776    // Estimate total tokens
777    let total_tokens = estimate_total_tokens(messages);
778    let target_tokens = (total_tokens as f64 * config.target_ratio) as u32;
779    
780    // Find turn boundaries (user messages mark start of each turn)
781    let mut turn_boundaries: Vec<usize> = Vec::new();
782    for (i, msg) in messages.iter().enumerate() {
783        if msg.role == Role::User {
784            turn_boundaries.push(i);
785        }
786    }
787
788    // Minimum start index to ensure we keep at least min_preserve_messages
789    let min_start_idx = messages.len().saturating_sub(config.min_preserve_messages);
790    
791    // Try to find a turn that:
792    // 1. Starts at or after min_start_idx (ensures enough messages)
793    // 2. Fits within token target
794    // Iterate from the earliest acceptable turn
795    for &start_idx in turn_boundaries.iter() {
796        // Must have enough messages
797        if messages.len() - start_idx < config.min_preserve_messages {
798            continue;
799        }
800        
801        let candidate_messages = &messages[start_idx..];
802        let candidate_tokens = estimate_total_tokens(candidate_messages);
803        
804        // If this turn fits within token target, use it
805        if candidate_tokens <= target_tokens {
806            return Ok(candidate_messages.to_vec());
807        }
808    }
809
810    // Fallback: keep exactly min_preserve_messages from the end
811    Ok(messages[min_start_idx..].to_vec())
812}
813
814/// Estimate token count for a message (improved approximation).
815/// For mixed content (code, Chinese, English), use a weighted estimate:
816/// - ASCII chars: ~4 chars per token (0.25 tokens/char)
817/// - Non-ASCII (Chinese, etc): ~1.5 chars per token (0.67 tokens/char)
818/// - JSON/structured data: typically more tokens per char
819pub fn estimate_tokens(message: &Message) -> u32 {
820    let (ascii_count, non_ascii_count) = match &message.content {
821        MessageContent::Text(t) => count_chars(t),
822        MessageContent::Blocks(blocks) => {
823            let mut ascii = 0;
824            let mut non_ascii = 0;
825            for block in blocks {
826                match block {
827                    ContentBlock::Text { text } => {
828                        let (a, n) = count_chars(text);
829                        ascii += a;
830                        non_ascii += n;
831                    }
832                    ContentBlock::ToolUse { name, input, .. } => {
833                        let (a, n) = count_chars(name);
834                        ascii += a;
835                        non_ascii += n;
836                        // JSON input typically has more tokens per char
837                        let json_str = input.to_string();
838                        let (ja, jn) = count_chars(&json_str);
839                        ascii += ja;
840                        non_ascii += jn;
841                    }
842                    ContentBlock::ToolResult { content, .. } => {
843                        let (a, n) = count_chars(content);
844                        ascii += a;
845                        non_ascii += n;
846                    }
847                    ContentBlock::Thinking { thinking, .. } => {
848                        let (a, n) = count_chars(thinking);
849                        ascii += a;
850                        non_ascii += n;
851                    }
852                    _ => {}
853                }
854            }
855            (ascii, non_ascii)
856        }
857    };
858
859    // Calculate tokens: ASCII uses ~0.25 tokens/char, non-ASCII uses ~0.67 tokens/char
860    // Add overhead for message structure (~10 tokens per message)
861    let ascii_tokens = (ascii_count as f64 * 0.25).ceil() as u32;
862    let non_ascii_tokens = (non_ascii_count as f64 * 0.67).ceil() as u32;
863    let total = ascii_tokens + non_ascii_tokens + 10;  // Add overhead
864
865    total.max(1)
866}
867
/// Tally the characters of `s` by class.
///
/// Returns `(ascii, non_ascii)`, counting Unicode scalar values (not bytes).
fn count_chars(s: &str) -> (u32, u32) {
    s.chars().fold((0u32, 0u32), |(ascii, other), ch| {
        if ch.is_ascii() {
            (ascii + 1, other)
        } else {
            (ascii, other + 1)
        }
    })
}
881
882/// Estimate total tokens for a message list.
883pub fn estimate_total_tokens(messages: &[Message]) -> u32 {
884    messages.iter().map(estimate_tokens).sum()
885}
886
887/// Check if compression should be triggered.
888pub fn should_compress(
889    current_tokens: u32,
890    context_size: Option<u32>,
891    config: &CompressionConfig,
892) -> bool {
893    match context_size {
894        Some(size) => {
895            let ratio = current_tokens as f64 / size as f64;
896            ratio >= config.threshold
897        }
898        None => false,
899    }
900}
901
902/// Build a prompt for AI-based summarization.
903pub fn build_summary_prompt(messages: &[Message]) -> String {
904    let history_text = messages
905        .iter()
906        .map(|m| {
907            let role = match m.role {
908                Role::User => "用户",
909                Role::Assistant => "助手",
910                Role::Tool => "工具",
911                Role::System => "系统",
912            };
913            let content_preview = match &m.content {
914                MessageContent::Text(t) => truncate_for_summary(t, 200),
915                MessageContent::Blocks(blocks) => {
916                    let preview: Vec<String> = blocks
917                        .iter()
918                        .map(|b| match b {
919                            ContentBlock::Text { text } => truncate_for_summary(text, 100),
920                            ContentBlock::ToolUse { name, .. } => format!("[工具: {}]", name),
921                            ContentBlock::ToolResult { content, .. } => truncate_for_summary(content, 100),
922                            _ => "[...]".to_string(),
923                        })
924                        .collect();
925                    preview.join(" | ")
926                }
927            };
928            format!("{}: {}", role, content_preview)
929        })
930        .collect::<Vec<_>>()
931        .join("\n");
932
933    format!(
934        r#"请将以下对话历史压缩为简洁摘要:
935
936对话历史({} 条消息):
937{}
938
939请输出:
9401. 概述(一句话描述主要任务)
9412. 已完成的关键操作(2-3 条)
9423. 当前状态(如果有)"#,
943        messages.len(),
944        history_text
945    )
946}
947
948fn truncate_for_summary(s: &str, max: usize) -> String {
949    truncate_text(s, max)
950}
951
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a plain-text message with the given role.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_estimate_tokens_simple() {
        let msg = text_message(Role::User, "Hello world");
        // 11 ASCII chars -> ceil(11 * 0.25) = 3 tokens, plus 10 tokens of overhead.
        // (A `>= 3` check would be vacuous: the overhead alone already exceeds it.)
        assert_eq!(estimate_tokens(&msg), 13);
    }

    #[test]
    fn test_estimate_tokens_non_ascii() {
        let msg = text_message(Role::User, "你好");
        // 2 non-ASCII chars -> ceil(2 * 0.67) = 2 tokens, plus 10 tokens of overhead.
        assert_eq!(estimate_tokens(&msg), 12);
    }

    #[test]
    fn test_should_compress_below_threshold() {
        let config = CompressionConfig::default();
        assert!(!should_compress(100_000, Some(200_000), &config));
    }

    #[test]
    fn test_should_compress_above_threshold() {
        let config = CompressionConfig::default();
        assert!(should_compress(160_000, Some(200_000), &config));
    }

    #[test]
    fn test_should_compress_unknown_context_size() {
        // Without a known context size, compression is never triggered.
        let config = CompressionConfig::default();
        assert!(!should_compress(u32::MAX, None, &config));
    }

    #[test]
    fn test_truncate_compress_keeps_minimum() {
        let messages: Vec<Message> = (0..10)
            .map(|i| text_message(Role::User, &format!("Message {}", i)))
            .collect();

        let config = CompressionConfig {
            min_preserve_messages: 4,
            ..Default::default()
        };

        let compressed = truncate_compress(&messages, &config).unwrap();
        assert_eq!(compressed.len(), 4);
        assert_eq!(compressed[0].content, MessageContent::Text("Message 6".to_string()));
    }

    #[test]
    fn test_sliding_window_preserves_turns() {
        // Create messages with longer content to test token-based compression
        let messages: Vec<Message> = vec![
            text_message(Role::User, "Q1 - this is a longer question to test token estimation"),
            text_message(Role::Assistant, "A1 - this is a longer answer with more content for token estimation"),
            text_message(Role::User, "Q2 - another longer question for testing"),
            text_message(Role::Assistant, "A2 - another longer answer for testing token estimation properly"),
            text_message(Role::User, "Q3 - the third question in this test"),
            text_message(Role::Assistant, "A3 - the third answer with sufficient content"),
        ];

        let config = CompressionConfig {
            min_preserve_messages: 4,
            target_ratio: 0.5,
            ..Default::default()
        };

        let compressed = sliding_window_compress(&messages, &config).unwrap();
        // Should preserve at least min_preserve_messages
        assert!(compressed.len() >= config.min_preserve_messages);
        // At least one user message must survive compression
        assert!(compressed.iter().any(|m| m.role == Role::User));
    }

    #[test]
    fn test_parse_summary_response() {
        let text = "用户请求实现登录功能。\n已完成操作:\n• 创建了 login.rs 文件\n• 添加了密码验证逻辑\n当前状态:测试中";
        let (summary, key_points) = parse_summary_response(text);

        assert!(!summary.is_empty());
        assert!(key_points.len() >= 2);
    }

    #[test]
    fn test_compression_result_format() {
        let result = CompressionResult::new(
            20,
            8,
            5000,
            Some("摘要内容".to_string()),
            CompressionStrategy::Summarize,
        );

        let formatted = result.format_summary();
        assert!(formatted.contains("20"));
        assert!(formatted.contains("8"));
        assert!(formatted.contains("AI summarize"));
    }

    #[test]
    fn test_compression_history_entry() {
        let result = CompressionResult::new(
            15,
            6,
            3000,
            None,
            CompressionStrategy::SlidingWindow,
        );

        let entry = CompressionHistoryEntry::from_result(&result);
        assert_eq!(entry.strategy, CompressionStrategy::SlidingWindow);
        assert!(!entry.has_summary);
    }

    #[test]
    fn test_compression_bias_parse() {
        // Test preset biases
        let balanced = CompressionBias::parse("balanced").unwrap();
        assert!(balanced.preserve_tools);
        assert!(balanced.preserve_user_questions);

        let aggressive = CompressionBias::parse("aggressive").unwrap();
        assert!(!aggressive.preserve_tools);
        assert!(aggressive.aggressive);

        let important = CompressionBias::parse("important").unwrap();
        assert!(important.preserve_thinking);
        assert!(important.preserve_tools);

        let tools = CompressionBias::parse("tools").unwrap();
        assert!(tools.preserve_tools);
        assert!(!tools.preserve_thinking);
    }

    #[test]
    fn test_compression_bias_format() {
        let bias = CompressionBias::balanced();
        let formatted = bias.format();
        assert!(formatted.contains("tools"));
        assert!(formatted.contains("user"));
    }

    #[test]
    fn test_compress_with_bias_preserves_tools() {
        let messages: Vec<Message> = vec![
            text_message(Role::User, "Q1"),
            Message {
                role: Role::Assistant,
                content: MessageContent::Blocks(vec![
                    ContentBlock::ToolUse { id: "1".to_string(), name: "read".to_string(), input: json!({}) }
                ]),
            },
            Message {
                role: Role::Tool,
                content: MessageContent::Blocks(vec![
                    ContentBlock::ToolResult { tool_use_id: "1".to_string(), content: "file content".to_string() }
                ]),
            },
            text_message(Role::User, "Q2"),
            text_message(Role::Assistant, "A2"),
            text_message(Role::User, "Q3"),
            text_message(Role::Assistant, "A3"),
        ];

        let config = CompressionConfig {
            min_preserve_messages: 2,
            bias: CompressionBias::tool_focused(),
            ..Default::default()
        };

        let compressed = compress_with_bias(&messages, &config).unwrap();

        // Tool-focused bias should preserve tool calls
        let has_tool_use = compressed.iter().any(|m| {
            matches!(&m.content, MessageContent::Blocks(blocks) if
                blocks.iter().any(|b| matches!(b, ContentBlock::ToolUse { .. })))
        });
        assert!(has_tool_use || compressed.len() >= messages.len() - 2);
    }

    #[test]
    fn test_bias_based_strategy() {
        let messages: Vec<Message> = (0..10)
            .map(|i| {
                let role = if i % 2 == 0 { Role::User } else { Role::Assistant };
                text_message(role, &format!("Message {}", i))
            })
            .collect();

        let config = CompressionConfig {
            min_preserve_messages: 4,
            bias: CompressionBias::aggressive(),
            ..Default::default()
        };

        let compressed = compress_messages(&messages, CompressionStrategy::BiasBased, &config).unwrap();
        assert!(compressed.len() <= messages.len());
        assert!(compressed.len() >= config.min_preserve_messages);
    }
}