Skip to main content

matrixcode_core/compress/
scorer.rs

1//! Intelligent scoring for message preservation decisions.
2//!
3//! Combines rule-based scoring, optional AI assistance, and dependency
4//! bonuses to determine which messages to keep during compression.
5
6use anyhow::Result;
7
8use crate::providers::{ContentBlock, Message, MessageContent, Provider, Role};
9
10use super::types::{AiCompressionMode, DependencyGraph, PhaseWeights, ScoredMessage};
11
12/// Scorer for message preservation decisions.
13pub struct Scorer {
14    /// Optional fast model for AI-assisted scoring.
15    fast_model: Option<Box<dyn Provider>>,
16}
17
18impl Scorer {
19    /// Create a new scorer without AI assistance.
20    pub fn new_rule_only() -> Self {
21        Self { fast_model: None }
22    }
23
24    /// Create a new scorer with AI assistance.
25    pub fn new_with_ai(fast_model: Box<dyn Provider>) -> Self {
26        Self {
27            fast_model: Some(fast_model),
28        }
29    }
30
31    /// Score all messages.
32    pub async fn score_all(
33        &self,
34        messages: &[Message],
35        weights: &PhaseWeights,
36        deps: &DependencyGraph,
37        ai_mode: AiCompressionMode,
38    ) -> Result<Vec<ScoredMessage>> {
39        let mut scored: Vec<ScoredMessage> = Vec::new();
40
41        // Phase 1: Rule-based scoring
42        for (idx, msg) in messages.iter().enumerate() {
43            let base_score = score_by_rules(msg, idx, weights);
44            scored.push(ScoredMessage::new(idx, msg.clone(), base_score));
45        }
46
47        // Phase 2: AI-assisted scoring (optional)
48        if ai_mode != AiCompressionMode::None && self.fast_model.is_some() {
49            for sm in &mut scored {
50                if should_ai_score(&sm.message) {
51                    let ai_score = self.score_with_ai(&sm.message, ai_mode).await?;
52                    sm.with_ai_score(ai_score);
53                }
54            }
55        }
56
57        // Phase 3: Dependency bonus
58        apply_dependency_bonus(&mut scored, deps, weights.dependency_pair_bonus);
59
60        Ok(scored)
61    }
62
63    /// Score a single message with AI assistance.
64    async fn score_with_ai(&self, message: &Message, mode: AiCompressionMode) -> Result<f64> {
65        if self.fast_model.is_none() {
66            return Ok(0.0);
67        }
68
69        let content_preview = get_content_preview(message, 500);
70        let prompt = build_ai_score_prompt(&content_preview, mode);
71
72        // Use fast model for quick judgment
73        let provider = self.fast_model.as_ref().unwrap();
74        let response = provider
75            .chat(crate::providers::ChatRequest {
76                messages: vec![Message {
77                    role: Role::User,
78                    content: MessageContent::Text(prompt),
79                }],
80                tools: vec![],
81                system: Some(AI_SCORE_SYSTEM_PROMPT.to_string()),
82                think: false,
83                max_tokens: 100,
84                server_tools: vec![],
85                enable_caching: false,
86            })
87            .await?;
88
89        // Extract score from response (0-30 range)
90        let score_text = extract_text_from_response(&response);
91        parse_ai_score(&score_text)
92    }
93}
94
95/// Rule-based scoring for a message (public for pipeline use).
96pub fn score_by_rules(message: &Message, index: usize, weights: &PhaseWeights) -> f64 {
97    let mut score: f64 = 10.0; // Base score
98
99    // First message gets highest priority
100    if index == 0 {
101        score += weights.first_msg_bonus;
102    }
103
104    // Role-based scoring
105    match message.role {
106        Role::User => {
107            score += weights.user_msg_bonus;
108        }
109        Role::Assistant => {
110            score += 5.0; // Lower base for assistant messages
111        }
112        Role::Tool => {
113            score += weights.tool_result_bonus;
114        }
115        Role::System => {
116            score += 40.0; // System messages are important
117        }
118    }
119
120    // Content-based scoring
121    score += content_score(&message.content, weights);
122
123    score
124}
125
126/// Score based on content blocks.
127fn content_score(content: &MessageContent, weights: &PhaseWeights) -> f64 {
128    let mut score: f64 = 0.0;
129
130    match content {
131        MessageContent::Text(text) => {
132            // Check for sensitive instructions
133            if contains_sensitive_instructions(text) {
134                score += 50.0;
135            }
136
137            // Check for important keywords
138            let keywords = [
139                "决定",
140                "decision",
141                "重要",
142                "important",
143                "关键",
144                "key",
145                "完成",
146                "done",
147            ];
148            for kw in keywords {
149                if text.to_lowercase().contains(kw) {
150                    score += 15.0;
151                }
152            }
153        }
154        MessageContent::Blocks(blocks) => {
155            for block in blocks {
156                match block {
157                    ContentBlock::ToolUse { name, .. } => {
158                        score += weights.tool_use_bonus;
159
160                        // Critical tools get extra bonus
161                        if is_critical_tool(name) {
162                            score += weights.critical_tool_bonus;
163                        }
164
165                        // todo_write is very important for task tracking
166                        if name == "todo_write" {
167                            score += 60.0;
168                        }
169
170                        // ask contains user decisions
171                        if name == "ask" {
172                            score += 50.0;
173                        }
174                    }
175                    ContentBlock::ToolResult { content, .. } => {
176                        score += weights.tool_result_bonus;
177
178                        // Preserve important results
179                        if contains_sensitive_instructions(content) {
180                            score += 30.0;
181                        }
182
183                        // todo_write results
184                        if content.contains("TodoWrite") || content.contains("todo") {
185                            score += 40.0;
186                        }
187
188                        // ask responses
189                        if content.contains("AskUserQuestion") || content.contains("answer") {
190                            score += 30.0;
191                        }
192                    }
193                    ContentBlock::Thinking { thinking, .. } => {
194                        // Thinking can contain key insights
195                        if thinking.contains("决定")
196                            || thinking.contains("问题")
197                            || thinking.contains("关键")
198                        {
199                            score += 30.0;
200                        }
201                    }
202                    ContentBlock::Text { text } => {
203                        if contains_sensitive_instructions(text) {
204                            score += 50.0;
205                        }
206                    }
207                    _ => {}
208                }
209            }
210        }
211    }
212
213    score
214}
215
216/// Apply dependency bonus to scored messages.
217fn apply_dependency_bonus(scored: &mut [ScoredMessage], deps: &DependencyGraph, bonus: f64) {
218    for dep in &deps.dependencies {
219        // Add bonus to ToolUse message
220        if let Some(sm) = scored.get_mut(dep.tool_use_idx) {
221            sm.with_dependency_bonus(bonus);
222        }
223
224        // Add bonus to ToolResult message
225        if let Some(sm) = scored.get_mut(dep.tool_result_idx) {
226            sm.with_dependency_bonus(bonus);
227        }
228
229        // Extra bonus for critical tools
230        if dep.is_critical {
231            if let Some(sm) = scored.get_mut(dep.tool_use_idx) {
232                sm.with_dependency_bonus(bonus * 0.5);
233            }
234            if let Some(sm) = scored.get_mut(dep.tool_result_idx) {
235                sm.with_dependency_bonus(bonus * 0.5);
236            }
237        }
238    }
239}
240
/// Whether `name` is one of the critical, state-modifying tools
/// (`write`, `edit`, `multi_edit`, `bash`). Matching is exact and case-sensitive.
fn is_critical_tool(name: &str) -> bool {
    matches!(name, "write" | "edit" | "multi_edit" | "bash")
}
246
/// Whether `text` contains a sensitive-instruction marker.
///
/// The check is a case-insensitive substring search. The markers are
/// 不要 / 禁止 / 必须 / 不允许 ("do not" / "forbidden" / "must" / "not allowed"),
/// plus "never", "must not", "do not" and "important".
fn contains_sensitive_instructions(text: &str) -> bool {
    const MARKERS: [&str; 8] = [
        "不要",
        "禁止",
        "必须",
        "不允许",
        "never",
        "must not",
        "do not",
        "important",
    ];

    let haystack = text.to_lowercase();
    for marker in MARKERS.iter() {
        if haystack.contains(*marker) {
            return true;
        }
    }
    false
}
262
263/// Check if a message should be AI-scored.
264fn should_ai_score(message: &Message) -> bool {
265    // Only score longer user or assistant messages
266    match message.role {
267        Role::User | Role::Assistant => {
268            let len = estimate_content_length(&message.content);
269            len > 100 // Only AI-score substantial content
270        }
271        _ => false,
272    }
273}
274
275/// Estimate content length.
276fn estimate_content_length(content: &MessageContent) -> usize {
277    match content {
278        MessageContent::Text(text) => text.len(),
279        MessageContent::Blocks(blocks) => blocks
280            .iter()
281            .map(|b| match b {
282                ContentBlock::Text { text } => text.len(),
283                ContentBlock::ToolUse { input, .. } => input.to_string().len(),
284                ContentBlock::ToolResult { content, .. } => content.len(),
285                ContentBlock::Thinking { thinking, .. } => thinking.len(),
286                _ => 0,
287            })
288            .sum(),
289    }
290}
291
292/// Get content preview for AI scoring.
293fn get_content_preview(message: &Message, max_len: usize) -> String {
294    match &message.content {
295        MessageContent::Text(text) => {
296            if text.len() > max_len {
297                text[..max_len].to_string() + "..."
298            } else {
299                text.clone()
300            }
301        }
302        MessageContent::Blocks(blocks) => {
303            let preview: Vec<String> = blocks
304                .iter()
305                .take(3)
306                .map(|b| match b {
307                    ContentBlock::Text { text } => text.chars().take(100).collect(),
308                    ContentBlock::ToolUse { name, .. } => format!("[Tool: {}]", name),
309                    ContentBlock::ToolResult { content, .. } => {
310                        content.chars().take(100).collect::<String>() + "..."
311                    }
312                    _ => "...".to_string(),
313                })
314                .collect();
315            preview.join(" | ")
316        }
317    }
318}
319
320/// Build prompt for AI scoring.
321fn build_ai_score_prompt(content: &str, mode: AiCompressionMode) -> String {
322    match mode {
323        AiCompressionMode::Light => format!(
324            "判断这段内容对当前任务的重要性(0-30分,0=无关,30=关键):\n{}",
325            content
326        ),
327        AiCompressionMode::Deep => format!(
328            "深入分析这段内容的重要性,考虑:\n1. 是否包含关键决策\n2. 是否包含未完成任务\n3. 是否包含敏感指令\n输出重要性评分(0-30分):\n{}",
329            content
330        ),
331        AiCompressionMode::None => String::new(),
332    }
333}
334
335/// Extract text from response.
336fn extract_text_from_response(response: &crate::providers::ChatResponse) -> String {
337    response
338        .content
339        .iter()
340        .filter_map(|b| {
341            if let ContentBlock::Text { text } = b {
342                Some(text.clone())
343            } else {
344                None
345            }
346        })
347        .collect::<Vec<_>>()
348        .join("\n")
349}
350
/// Parse the fast model's reply into a score clamped to `0.0..=30.0`.
///
/// Strategies, in order:
/// 1. The whole trimmed reply is a number, or begins with one
///    (e.g. `"15"`, `"20分"`, `"25/30"`).
/// 2. On a line mentioning `评分` ("score") or `score` (case-insensitive), the
///    first whitespace-separated token that is a number or begins with one
///    (e.g. `"评分: 20"`, `"Score: 18分"`).
/// 3. Otherwise the neutral default `10.0`.
///
/// `NaN` is never accepted as a score: `clamp` would keep it as `NaN`, and that
/// would poison later score comparisons. Infinities are clamped like any other
/// value.
///
/// # Errors
///
/// Currently never returns `Err`.
fn parse_ai_score(text: &str) -> Result<f64> {
    let text = text.trim();

    // Whole reply is (or starts with) a number.
    if let Some(score) = parse_score_token(text) {
        return Ok(score.clamp(0.0, 30.0));
    }

    // Look for "评分: X" or "score: X".
    for line in text.lines() {
        let lower = line.to_lowercase();
        if lower.contains("评分") || lower.contains("score") {
            if let Some(score) = line.split_whitespace().find_map(parse_score_token) {
                return Ok(score.clamp(0.0, 30.0));
            }
        }
    }

    // Default score.
    Ok(10.0)
}

/// Interpret one token as a score: the exact float if it parses and is not
/// `NaN`, otherwise the number the token begins with, if any.
fn parse_score_token(token: &str) -> Option<f64> {
    token
        .parse::<f64>()
        .ok()
        .filter(|value| !value.is_nan())
        .or_else(|| leading_number(token))
}

/// Parse the run of ASCII digits and `.` at the very start of `s`
/// (e.g. `"20分"` -> 20.0, `"25/30"` -> 25.0).
///
/// Returns `None` when `s` does not start with a digit, or when the run is not
/// a valid number (e.g. `"1.2.3"`).
fn leading_number(s: &str) -> Option<f64> {
    let end = s
        .char_indices()
        .find(|&(_, c)| !(c.is_ascii_digit() || c == '.'))
        .map_or(s.len(), |(i, _)| i);
    // `end` comes from `char_indices` (or is `s.len()`), so it is a char boundary.
    let prefix = &s[..end];
    if !prefix.starts_with(|c: char| c.is_ascii_digit()) {
        return None;
    }
    prefix.parse::<f64>().ok()
}
379
/// System prompt sent with every AI scoring request (in Chinese).
///
/// It tells the model it is a content-importance assessor and must reply with
/// only a number from 0 to 30:
/// 0 = not important, can be deleted; 10 = moderately important, keep or drop;
/// 20 = important, keeping recommended; 30 = critical, must be kept.
/// The reply is parsed by `parse_ai_score`.
const AI_SCORE_SYSTEM_PROMPT: &str = r#"你是一个内容重要性评估助手。快速判断内容的重要性并输出评分。

输出要求:
- 仅输出一个数字(0-30)
- 0 = 完全不重要,可以删除
- 10 = 一般重要,可保留可删除
- 20 = 重要,建议保留
- 30 = 关键,必须保留

请直接输出评分数字。"#;
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a plain-text user message for the rule-scoring tests.
    fn user_text(body: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(body.to_string()),
        }
    }

    #[test]
    fn test_score_by_rules_first_message() {
        let weights = PhaseWeights::balanced();
        let score = score_by_rules(&user_text("Hello"), 0, &weights);
        // Index 0 must pick up `first_msg_bonus`.
        assert!(score > 100.0);
    }

    #[test]
    fn test_score_by_rules_sensitive() {
        let weights = PhaseWeights::balanced();
        let score = score_by_rules(&user_text("不要删除这个文件"), 5, &weights);
        // "不要" ("do not") triggers the sensitive-instruction bonus.
        assert!(score > 50.0);
    }

    #[test]
    fn test_contains_sensitive_instructions() {
        for hit in ["不要删除", "must not do this"] {
            assert!(contains_sensitive_instructions(hit), "expected match: {}", hit);
        }
        assert!(!contains_sensitive_instructions("普通文本"));
    }

    #[test]
    fn test_is_critical_tool() {
        for tool in ["write", "bash"] {
            assert!(is_critical_tool(tool), "expected critical: {}", tool);
        }
        assert!(!is_critical_tool("read"));
    }

    #[test]
    fn test_parse_ai_score() {
        let cases = [
            ("15", 15.0),
            ("评分: 20", 20.0),
            ("score: 25", 25.0),
            // Nothing recognisable falls back to the default score.
            ("unknown", 10.0),
        ];
        for (reply, expected) in cases {
            assert_eq!(parse_ai_score(reply).unwrap(), expected, "reply: {}", reply);
        }
    }
}