Skip to main content

matrixcode_core/compress/
scorer.rs

//! Intelligent scoring for message preservation decisions.
//!
//! Combines rule-based scoring, optional AI assistance, and dependency
//! bonuses to determine which messages to keep during compression.
5
6use anyhow::Result;
7
8use crate::providers::{ContentBlock, Message, MessageContent, Provider, Role};
9
10use super::types::{AiCompressionMode, DependencyGraph, PhaseWeights, ScoredMessage};
11
12/// Scorer for message preservation decisions.
13pub struct Scorer {
14    /// Optional fast model for AI-assisted scoring.
15    fast_model: Option<Box<dyn Provider>>,
16}
17
18impl Scorer {
19    /// Create a new scorer without AI assistance.
20    pub fn new_rule_only() -> Self {
21        Self { fast_model: None }
22    }
23
24    /// Create a new scorer with AI assistance.
25    pub fn new_with_ai(fast_model: Box<dyn Provider>) -> Self {
26        Self { fast_model: Some(fast_model) }
27    }
28
29    /// Score all messages.
30    pub async fn score_all(
31        &self,
32        messages: &[Message],
33        weights: &PhaseWeights,
34        deps: &DependencyGraph,
35        ai_mode: AiCompressionMode,
36    ) -> Result<Vec<ScoredMessage>> {
37        let mut scored: Vec<ScoredMessage> = Vec::new();
38
39        // Phase 1: Rule-based scoring
40        for (idx, msg) in messages.iter().enumerate() {
41            let base_score = score_by_rules(msg, idx, weights);
42            scored.push(ScoredMessage::new(idx, msg.clone(), base_score));
43        }
44
45        // Phase 2: AI-assisted scoring (optional)
46        if ai_mode != AiCompressionMode::None && self.fast_model.is_some() {
47            for sm in &mut scored {
48                if should_ai_score(&sm.message) {
49                    let ai_score = self.score_with_ai(&sm.message, ai_mode).await?;
50                    sm.with_ai_score(ai_score);
51                }
52            }
53        }
54
55        // Phase 3: Dependency bonus
56        apply_dependency_bonus(&mut scored, deps, weights.dependency_pair_bonus);
57
58        Ok(scored)
59    }
60
61    /// Score a single message with AI assistance.
62    async fn score_with_ai(
63        &self,
64        message: &Message,
65        mode: AiCompressionMode,
66    ) -> Result<f64> {
67        if self.fast_model.is_none() {
68            return Ok(0.0);
69        }
70
71        let content_preview = get_content_preview(message, 500);
72        let prompt = build_ai_score_prompt(&content_preview, mode);
73
74        // Use fast model for quick judgment
75        let provider = self.fast_model.as_ref().unwrap();
76        let response = provider.chat(crate::providers::ChatRequest {
77            messages: vec![Message {
78                role: Role::User,
79                content: MessageContent::Text(prompt),
80            }],
81            tools: vec![],
82            system: Some(AI_SCORE_SYSTEM_PROMPT.to_string()),
83            think: false,
84            max_tokens: 100,
85            server_tools: vec![],
86            enable_caching: false,
87        }).await?;
88
89        // Extract score from response (0-30 range)
90        let score_text = extract_text_from_response(&response);
91        parse_ai_score(&score_text)
92    }
93}
94
95/// Rule-based scoring for a message (public for pipeline use).
96pub fn score_by_rules(message: &Message, index: usize, weights: &PhaseWeights) -> f64 {
97    let mut score: f64 = 10.0; // Base score
98
99    // First message gets highest priority
100    if index == 0 {
101        score += weights.first_msg_bonus;
102    }
103
104    // Role-based scoring
105    match message.role {
106        Role::User => {
107            score += weights.user_msg_bonus;
108        }
109        Role::Assistant => {
110            score += 5.0; // Lower base for assistant messages
111        }
112        Role::Tool => {
113            score += weights.tool_result_bonus;
114        }
115        Role::System => {
116            score += 40.0; // System messages are important
117        }
118    }
119
120    // Content-based scoring
121    score += content_score(&message.content, weights);
122
123    score
124}
125
126/// Score based on content blocks.
127fn content_score(content: &MessageContent, weights: &PhaseWeights) -> f64 {
128    let mut score: f64 = 0.0;
129
130    match content {
131        MessageContent::Text(text) => {
132            // Check for sensitive instructions
133            if contains_sensitive_instructions(text) {
134                score += 50.0;
135            }
136
137            // Check for important keywords
138            let keywords = ["决定", "decision", "重要", "important", "关键", "key", "完成", "done"];
139            for kw in keywords {
140                if text.to_lowercase().contains(kw) {
141                    score += 15.0;
142                }
143            }
144        }
145        MessageContent::Blocks(blocks) => {
146            for block in blocks {
147                match block {
148                    ContentBlock::ToolUse { name, .. } => {
149                        score += weights.tool_use_bonus;
150
151                        // Critical tools get extra bonus
152                        if is_critical_tool(name) {
153                            score += weights.critical_tool_bonus;
154                        }
155
156                        // todo_write is very important for task tracking
157                        if name == "todo_write" {
158                            score += 60.0;
159                        }
160
161                        // ask contains user decisions
162                        if name == "ask" {
163                            score += 50.0;
164                        }
165                    }
166                    ContentBlock::ToolResult { content, .. } => {
167                        score += weights.tool_result_bonus;
168
169                        // Preserve important results
170                        if contains_sensitive_instructions(content) {
171                            score += 30.0;
172                        }
173
174                        // todo_write results
175                        if content.contains("TodoWrite") || content.contains("todo") {
176                            score += 40.0;
177                        }
178
179                        // ask responses
180                        if content.contains("AskUserQuestion") || content.contains("answer") {
181                            score += 30.0;
182                        }
183                    }
184                    ContentBlock::Thinking { thinking, .. } => {
185                        // Thinking can contain key insights
186                        if thinking.contains("决定") || thinking.contains("问题") || thinking.contains("关键") {
187                            score += 30.0;
188                        }
189                    }
190                    ContentBlock::Text { text } => {
191                        if contains_sensitive_instructions(text) {
192                            score += 50.0;
193                        }
194                    }
195                    _ => {}
196                }
197            }
198        }
199    }
200
201    score
202}
203
204/// Apply dependency bonus to scored messages.
205fn apply_dependency_bonus(
206    scored: &mut [ScoredMessage],
207    deps: &DependencyGraph,
208    bonus: f64,
209) {
210    for dep in &deps.dependencies {
211        // Add bonus to ToolUse message
212        if let Some(sm) = scored.get_mut(dep.tool_use_idx) {
213            sm.with_dependency_bonus(bonus);
214        }
215
216        // Add bonus to ToolResult message
217        if let Some(sm) = scored.get_mut(dep.tool_result_idx) {
218            sm.with_dependency_bonus(bonus);
219        }
220
221        // Extra bonus for critical tools
222        if dep.is_critical {
223            if let Some(sm) = scored.get_mut(dep.tool_use_idx) {
224                sm.with_dependency_bonus(bonus * 0.5);
225            }
226            if let Some(sm) = scored.get_mut(dep.tool_result_idx) {
227                sm.with_dependency_bonus(bonus * 0.5);
228            }
229        }
230    }
231}
232
/// Returns `true` for tools treated as critical, i.e. ones that modify state:
/// `write`, `edit`, `multi_edit` and `bash`. The match is exact and
/// case-sensitive.
fn is_critical_tool(name: &str) -> bool {
    matches!(name, "write" | "edit" | "multi_edit" | "bash")
}
238
/// Returns `true` if `text` contains a prohibition, obligation or emphasis phrase.
///
/// The check is a case-insensitive substring search over a fixed list of
/// phrases. The Chinese entries mean "don't", "forbidden", "must" and
/// "not allowed"; the English entries are "never", "must not", "do not" and
/// "important".
fn contains_sensitive_instructions(text: &str) -> bool {
    const PATTERNS: [&str; 8] = [
        "不要", "禁止", "必须", "不允许",
        "never", "must not", "do not", "important",
    ];

    let haystack = text.to_lowercase();
    for needle in PATTERNS {
        if haystack.contains(needle) {
            return true;
        }
    }
    false
}
248
249/// Check if a message should be AI-scored.
250fn should_ai_score(message: &Message) -> bool {
251    // Only score longer user or assistant messages
252    match message.role {
253        Role::User | Role::Assistant => {
254            let len = estimate_content_length(&message.content);
255            len > 100 // Only AI-score substantial content
256        }
257        _ => false,
258    }
259}
260
261/// Estimate content length.
262fn estimate_content_length(content: &MessageContent) -> usize {
263    match content {
264        MessageContent::Text(text) => text.len(),
265        MessageContent::Blocks(blocks) => {
266            blocks.iter().map(|b| {
267                match b {
268                    ContentBlock::Text { text } => text.len(),
269                    ContentBlock::ToolUse { input, .. } => input.to_string().len(),
270                    ContentBlock::ToolResult { content, .. } => content.len(),
271                    ContentBlock::Thinking { thinking, .. } => thinking.len(),
272                    _ => 0,
273                }
274            }).sum()
275        }
276    }
277}
278
279/// Get content preview for AI scoring.
280fn get_content_preview(message: &Message, max_len: usize) -> String {
281    match &message.content {
282        MessageContent::Text(text) => {
283            if text.len() > max_len {
284                text[..max_len].to_string() + "..."
285            } else {
286                text.clone()
287            }
288        }
289        MessageContent::Blocks(blocks) => {
290            let preview: Vec<String> = blocks.iter().take(3).map(|b| {
291                match b {
292                    ContentBlock::Text { text } => text.chars().take(100).collect(),
293                    ContentBlock::ToolUse { name, .. } => format!("[Tool: {}]", name),
294                    ContentBlock::ToolResult { content, .. } => {
295                        content.chars().take(100).collect::<String>() + "..."
296                    },
297                    _ => "...".to_string(),
298                }
299            }).collect();
300            preview.join(" | ")
301        }
302    }
303}
304
305/// Build prompt for AI scoring.
306fn build_ai_score_prompt(content: &str, mode: AiCompressionMode) -> String {
307    match mode {
308        AiCompressionMode::Light => format!(
309            "判断这段内容对当前任务的重要性(0-30分,0=无关,30=关键):\n{}",
310            content
311        ),
312        AiCompressionMode::Deep => format!(
313            "深入分析这段内容的重要性,考虑:\n1. 是否包含关键决策\n2. 是否包含未完成任务\n3. 是否包含敏感指令\n输出重要性评分(0-30分):\n{}",
314            content
315        ),
316        AiCompressionMode::None => String::new(),
317    }
318}
319
320/// Extract text from response.
321fn extract_text_from_response(response: &crate::providers::ChatResponse) -> String {
322    response.content.iter()
323        .filter_map(|b| {
324            if let ContentBlock::Text { text } = b {
325                Some(text.clone())
326            } else {
327                None
328            }
329        })
330        .collect::<Vec<_>>()
331        .join("\n")
332}
333
334/// Parse AI score from text.
335fn parse_ai_score(text: &str) -> Result<f64> {
336    // Try to find a number in the text
337    let text = text.trim();
338
339    // Direct number
340    if let Ok(score) = text.parse::<f64>() {
341        return Ok(score.clamp(0.0, 30.0));
342    }
343
344    // Look for "评分: X" or "score: X"
345    for line in text.lines() {
346        let lower = line.to_lowercase();
347        if lower.contains("评分") || lower.contains("score") {
348            // Extract number
349            let nums: Vec<f64> = line
350                .split_whitespace()
351                .filter_map(|s| s.parse::<f64>().ok())
352                .collect();
353            if let Some(score) = nums.first() {
354                return Ok(score.clamp(0.0, 30.0));
355            }
356        }
357    }
358
359    // Default score
360    Ok(10.0)
361}
362
/// System prompt sent with every AI-scoring request in `Scorer::score_with_ai`.
///
/// In English, roughly: "You are a content-importance assessment assistant.
/// Quickly judge the importance of the content and output a score. Output
/// only a single number (0-30):
/// - 0 = not important at all, can be deleted;
/// - 10 = moderately important, may be kept or deleted;
/// - 20 = important, recommended to keep;
/// - 30 = critical, must be kept.
///
/// Output the score number directly."
///
/// The model's reply is parsed by `parse_ai_score`.
const AI_SCORE_SYSTEM_PROMPT: &str = r#"你是一个内容重要性评估助手。快速判断内容的重要性并输出评分。

输出要求:
- 仅输出一个数字(0-30)
- 0 = 完全不重要,可以删除
- 10 = 一般重要,可保留可删除
- 20 = 重要,建议保留
- 30 = 关键,必须保留

请直接输出评分数字。"#;
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a plain-text user message for the rule-based tests.
    fn user_text(body: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(body.to_string()),
        }
    }

    #[test]
    fn test_score_by_rules_first_message() {
        let weights = PhaseWeights::balanced();
        // Index 0 picks up `first_msg_bonus`.
        let score = score_by_rules(&user_text("Hello"), 0, &weights);
        assert!(score > 100.0);
    }

    #[test]
    fn test_score_by_rules_sensitive() {
        let weights = PhaseWeights::balanced();
        // "不要" ("do not") triggers the sensitive-instruction bonus.
        let score = score_by_rules(&user_text("不要删除这个文件"), 5, &weights);
        assert!(score > 50.0);
    }

    #[test]
    fn test_contains_sensitive_instructions() {
        for text in ["不要删除", "must not do this"] {
            assert!(contains_sensitive_instructions(text), "expected sensitive: {}", text);
        }
        // "Ordinary text" — no sensitive phrase.
        assert!(!contains_sensitive_instructions("普通文本"));
    }

    #[test]
    fn test_is_critical_tool() {
        for name in ["write", "bash"] {
            assert!(is_critical_tool(name), "expected critical: {}", name);
        }
        assert!(!is_critical_tool("read"));
    }

    #[test]
    fn test_parse_ai_score() {
        // The last case exercises the fallback default.
        let cases = [
            ("15", 15.0),
            ("评分: 20", 20.0),
            ("score: 25", 25.0),
            ("unknown", 10.0),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_ai_score(input).unwrap(), expected, "input: {}", input);
        }
    }
}