matrixcode_core/compress/
scorer.rs1use anyhow::Result;
7
8use crate::providers::{ContentBlock, Message, MessageContent, Provider, Role};
9
10use super::types::{AiCompressionMode, DependencyGraph, PhaseWeights, ScoredMessage};
11
12pub struct Scorer {
14 fast_model: Option<Box<dyn Provider>>,
16}
17
18impl Scorer {
19 pub fn new_rule_only() -> Self {
21 Self { fast_model: None }
22 }
23
24 pub fn new_with_ai(fast_model: Box<dyn Provider>) -> Self {
26 Self { fast_model: Some(fast_model) }
27 }
28
29 pub async fn score_all(
31 &self,
32 messages: &[Message],
33 weights: &PhaseWeights,
34 deps: &DependencyGraph,
35 ai_mode: AiCompressionMode,
36 ) -> Result<Vec<ScoredMessage>> {
37 let mut scored: Vec<ScoredMessage> = Vec::new();
38
39 for (idx, msg) in messages.iter().enumerate() {
41 let base_score = score_by_rules(msg, idx, weights);
42 scored.push(ScoredMessage::new(idx, msg.clone(), base_score));
43 }
44
45 if ai_mode != AiCompressionMode::None && self.fast_model.is_some() {
47 for sm in &mut scored {
48 if should_ai_score(&sm.message) {
49 let ai_score = self.score_with_ai(&sm.message, ai_mode).await?;
50 sm.with_ai_score(ai_score);
51 }
52 }
53 }
54
55 apply_dependency_bonus(&mut scored, deps, weights.dependency_pair_bonus);
57
58 Ok(scored)
59 }
60
61 async fn score_with_ai(
63 &self,
64 message: &Message,
65 mode: AiCompressionMode,
66 ) -> Result<f64> {
67 if self.fast_model.is_none() {
68 return Ok(0.0);
69 }
70
71 let content_preview = get_content_preview(message, 500);
72 let prompt = build_ai_score_prompt(&content_preview, mode);
73
74 let provider = self.fast_model.as_ref().unwrap();
76 let response = provider.chat(crate::providers::ChatRequest {
77 messages: vec![Message {
78 role: Role::User,
79 content: MessageContent::Text(prompt),
80 }],
81 tools: vec![],
82 system: Some(AI_SCORE_SYSTEM_PROMPT.to_string()),
83 think: false,
84 max_tokens: 100,
85 server_tools: vec![],
86 enable_caching: false,
87 }).await?;
88
89 let score_text = extract_text_from_response(&response);
91 parse_ai_score(&score_text)
92 }
93}
94
95pub fn score_by_rules(message: &Message, index: usize, weights: &PhaseWeights) -> f64 {
97 let mut score: f64 = 10.0; if index == 0 {
101 score += weights.first_msg_bonus;
102 }
103
104 match message.role {
106 Role::User => {
107 score += weights.user_msg_bonus;
108 }
109 Role::Assistant => {
110 score += 5.0; }
112 Role::Tool => {
113 score += weights.tool_result_bonus;
114 }
115 Role::System => {
116 score += 40.0; }
118 }
119
120 score += content_score(&message.content, weights);
122
123 score
124}
125
126fn content_score(content: &MessageContent, weights: &PhaseWeights) -> f64 {
128 let mut score: f64 = 0.0;
129
130 match content {
131 MessageContent::Text(text) => {
132 if contains_sensitive_instructions(text) {
134 score += 50.0;
135 }
136
137 let keywords = ["决定", "decision", "重要", "important", "关键", "key", "完成", "done"];
139 for kw in keywords {
140 if text.to_lowercase().contains(kw) {
141 score += 15.0;
142 }
143 }
144 }
145 MessageContent::Blocks(blocks) => {
146 for block in blocks {
147 match block {
148 ContentBlock::ToolUse { name, .. } => {
149 score += weights.tool_use_bonus;
150
151 if is_critical_tool(name) {
153 score += weights.critical_tool_bonus;
154 }
155
156 if name == "todo_write" {
158 score += 60.0;
159 }
160
161 if name == "ask" {
163 score += 50.0;
164 }
165 }
166 ContentBlock::ToolResult { content, .. } => {
167 score += weights.tool_result_bonus;
168
169 if contains_sensitive_instructions(content) {
171 score += 30.0;
172 }
173
174 if content.contains("TodoWrite") || content.contains("todo") {
176 score += 40.0;
177 }
178
179 if content.contains("AskUserQuestion") || content.contains("answer") {
181 score += 30.0;
182 }
183 }
184 ContentBlock::Thinking { thinking, .. } => {
185 if thinking.contains("决定") || thinking.contains("问题") || thinking.contains("关键") {
187 score += 30.0;
188 }
189 }
190 ContentBlock::Text { text } => {
191 if contains_sensitive_instructions(text) {
192 score += 50.0;
193 }
194 }
195 _ => {}
196 }
197 }
198 }
199 }
200
201 score
202}
203
204fn apply_dependency_bonus(
206 scored: &mut [ScoredMessage],
207 deps: &DependencyGraph,
208 bonus: f64,
209) {
210 for dep in &deps.dependencies {
211 if let Some(sm) = scored.get_mut(dep.tool_use_idx) {
213 sm.with_dependency_bonus(bonus);
214 }
215
216 if let Some(sm) = scored.get_mut(dep.tool_result_idx) {
218 sm.with_dependency_bonus(bonus);
219 }
220
221 if dep.is_critical {
223 if let Some(sm) = scored.get_mut(dep.tool_use_idx) {
224 sm.with_dependency_bonus(bonus * 0.5);
225 }
226 if let Some(sm) = scored.get_mut(dep.tool_result_idx) {
227 sm.with_dependency_bonus(bonus * 0.5);
228 }
229 }
230 }
231}
232
/// Returns `true` when `name` is exactly one of `write`, `edit`, `multi_edit`
/// or `bash`. The comparison is case-sensitive.
fn is_critical_tool(name: &str) -> bool {
    matches!(name, "write" | "edit" | "multi_edit" | "bash")
}
238
/// Returns `true` if `text`, lowercased, contains any instruction-like phrase
/// from a fixed Chinese/English list. This is a substring match, so e.g.
/// "nevermind" also matches "never".
fn contains_sensitive_instructions(text: &str) -> bool {
    // Chinese entries: "don't", "forbidden", "must", "not allowed".
    const PATTERNS: [&str; 8] = [
        "不要", "禁止", "必须", "不允许",
        "never", "must not", "do not", "important",
    ];

    let haystack = text.to_lowercase();
    PATTERNS
        .iter()
        .copied()
        .any(|needle| haystack.contains(needle))
}
248
249fn should_ai_score(message: &Message) -> bool {
251 match message.role {
253 Role::User | Role::Assistant => {
254 let len = estimate_content_length(&message.content);
255 len > 100 }
257 _ => false,
258 }
259}
260
261fn estimate_content_length(content: &MessageContent) -> usize {
263 match content {
264 MessageContent::Text(text) => text.len(),
265 MessageContent::Blocks(blocks) => {
266 blocks.iter().map(|b| {
267 match b {
268 ContentBlock::Text { text } => text.len(),
269 ContentBlock::ToolUse { input, .. } => input.to_string().len(),
270 ContentBlock::ToolResult { content, .. } => content.len(),
271 ContentBlock::Thinking { thinking, .. } => thinking.len(),
272 _ => 0,
273 }
274 }).sum()
275 }
276 }
277}
278
279fn get_content_preview(message: &Message, max_len: usize) -> String {
281 match &message.content {
282 MessageContent::Text(text) => {
283 if text.len() > max_len {
284 text[..max_len].to_string() + "..."
285 } else {
286 text.clone()
287 }
288 }
289 MessageContent::Blocks(blocks) => {
290 let preview: Vec<String> = blocks.iter().take(3).map(|b| {
291 match b {
292 ContentBlock::Text { text } => text.chars().take(100).collect(),
293 ContentBlock::ToolUse { name, .. } => format!("[Tool: {}]", name),
294 ContentBlock::ToolResult { content, .. } => {
295 content.chars().take(100).collect::<String>() + "..."
296 },
297 _ => "...".to_string(),
298 }
299 }).collect();
300 preview.join(" | ")
301 }
302 }
303}
304
305fn build_ai_score_prompt(content: &str, mode: AiCompressionMode) -> String {
307 match mode {
308 AiCompressionMode::Light => format!(
309 "判断这段内容对当前任务的重要性(0-30分,0=无关,30=关键):\n{}",
310 content
311 ),
312 AiCompressionMode::Deep => format!(
313 "深入分析这段内容的重要性,考虑:\n1. 是否包含关键决策\n2. 是否包含未完成任务\n3. 是否包含敏感指令\n输出重要性评分(0-30分):\n{}",
314 content
315 ),
316 AiCompressionMode::None => String::new(),
317 }
318}
319
320fn extract_text_from_response(response: &crate::providers::ChatResponse) -> String {
322 response.content.iter()
323 .filter_map(|b| {
324 if let ContentBlock::Text { text } = b {
325 Some(text.clone())
326 } else {
327 None
328 }
329 })
330 .collect::<Vec<_>>()
331 .join("\n")
332}
333
334fn parse_ai_score(text: &str) -> Result<f64> {
336 let text = text.trim();
338
339 if let Ok(score) = text.parse::<f64>() {
341 return Ok(score.clamp(0.0, 30.0));
342 }
343
344 for line in text.lines() {
346 let lower = line.to_lowercase();
347 if lower.contains("评分") || lower.contains("score") {
348 let nums: Vec<f64> = line
350 .split_whitespace()
351 .filter_map(|s| s.parse::<f64>().ok())
352 .collect();
353 if let Some(score) = nums.first() {
354 return Ok(score.clamp(0.0, 30.0));
355 }
356 }
357 }
358
359 Ok(10.0)
361}
362
/// System prompt sent with every AI scoring request made by `score_with_ai`.
///
/// English summary of the Chinese text: the model acts as a content-importance
/// assessor and must output only a single number from 0 to 30, where
/// 0 = completely unimportant (can be deleted), 10 = moderately important
/// (keep or delete), 20 = important (recommended to keep), and
/// 30 = critical (must keep). `parse_ai_score` consumes the reply.
const AI_SCORE_SYSTEM_PROMPT: &str = r#"你是一个内容重要性评估助手。快速判断内容的重要性并输出评分。

输出要求:
- 仅输出一个数字(0-30)
- 0 = 完全不重要,可以删除
- 10 = 一般重要,可保留可删除
- 20 = 重要,建议保留
- 30 = 关键,必须保留

请直接输出评分数字。"#;
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain-text user message for the rule-scoring tests.
    fn user_text(body: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(body.to_string()),
        }
    }

    #[test]
    fn test_score_by_rules_first_message() {
        let weights = PhaseWeights::balanced();
        // Index 0 picks up `first_msg_bonus`; with balanced weights the total
        // is expected to exceed 100.
        let score = score_by_rules(&user_text("Hello"), 0, &weights);
        assert!(score > 100.0);
    }

    #[test]
    fn test_score_by_rules_sensitive() {
        let weights = PhaseWeights::balanced();
        // "Do not delete this file": a sensitive instruction, not first in line.
        let score = score_by_rules(&user_text("不要删除这个文件"), 5, &weights);
        assert!(score > 50.0);
    }

    #[test]
    fn test_contains_sensitive_instructions() {
        for flagged in ["不要删除", "must not do this"] {
            assert!(contains_sensitive_instructions(flagged));
        }
        assert!(!contains_sensitive_instructions("普通文本"));
    }

    #[test]
    fn test_is_critical_tool() {
        for critical in ["write", "bash"] {
            assert!(is_critical_tool(critical));
        }
        assert!(!is_critical_tool("read"));
    }

    #[test]
    fn test_parse_ai_score() {
        let cases = [
            ("15", 15.0),
            ("评分: 20", 20.0),
            ("score: 25", 25.0),
            ("unknown", 10.0),
        ];
        for (reply, expected) in cases {
            assert_eq!(parse_ai_score(reply).unwrap(), expected);
        }
    }
}