matrixcode_core/compress/
scorer.rs1use anyhow::Result;
7
8use crate::providers::{ContentBlock, Message, MessageContent, Provider, Role};
9
10use super::types::{AiCompressionMode, DependencyGraph, PhaseWeights, ScoredMessage};
11
12pub struct Scorer {
14 fast_model: Option<Box<dyn Provider>>,
16}
17
18impl Scorer {
19 pub fn new_rule_only() -> Self {
21 Self { fast_model: None }
22 }
23
24 pub fn new_with_ai(fast_model: Box<dyn Provider>) -> Self {
26 Self {
27 fast_model: Some(fast_model),
28 }
29 }
30
31 pub async fn score_all(
33 &self,
34 messages: &[Message],
35 weights: &PhaseWeights,
36 deps: &DependencyGraph,
37 ai_mode: AiCompressionMode,
38 ) -> Result<Vec<ScoredMessage>> {
39 let mut scored: Vec<ScoredMessage> = Vec::new();
40
41 for (idx, msg) in messages.iter().enumerate() {
43 let base_score = score_by_rules(msg, idx, weights);
44 scored.push(ScoredMessage::new(idx, msg.clone(), base_score));
45 }
46
47 if ai_mode != AiCompressionMode::None && self.fast_model.is_some() {
49 for sm in &mut scored {
50 if should_ai_score(&sm.message) {
51 let ai_score = self.score_with_ai(&sm.message, ai_mode).await?;
52 sm.with_ai_score(ai_score);
53 }
54 }
55 }
56
57 apply_dependency_bonus(&mut scored, deps, weights.dependency_pair_bonus);
59
60 Ok(scored)
61 }
62
63 async fn score_with_ai(&self, message: &Message, mode: AiCompressionMode) -> Result<f64> {
65 if self.fast_model.is_none() {
66 return Ok(0.0);
67 }
68
69 let content_preview = get_content_preview(message, 500);
70 let prompt = build_ai_score_prompt(&content_preview, mode);
71
72 let provider = self.fast_model.as_ref().unwrap();
74 let response = provider
75 .chat(crate::providers::ChatRequest {
76 messages: vec![Message {
77 role: Role::User,
78 content: MessageContent::Text(prompt),
79 }],
80 tools: vec![],
81 system: Some(AI_SCORE_SYSTEM_PROMPT.to_string()),
82 think: false,
83 max_tokens: 100,
84 server_tools: vec![],
85 enable_caching: false,
86 })
87 .await?;
88
89 let score_text = extract_text_from_response(&response);
91 parse_ai_score(&score_text)
92 }
93}
94
95pub fn score_by_rules(message: &Message, index: usize, weights: &PhaseWeights) -> f64 {
97 let mut score: f64 = 10.0; if index == 0 {
101 score += weights.first_msg_bonus;
102 }
103
104 match message.role {
106 Role::User => {
107 score += weights.user_msg_bonus;
108 }
109 Role::Assistant => {
110 score += 5.0; }
112 Role::Tool => {
113 score += weights.tool_result_bonus;
114 }
115 Role::System => {
116 score += 40.0; }
118 }
119
120 score += content_score(&message.content, weights);
122
123 score
124}
125
126fn content_score(content: &MessageContent, weights: &PhaseWeights) -> f64 {
128 let mut score: f64 = 0.0;
129
130 match content {
131 MessageContent::Text(text) => {
132 if contains_sensitive_instructions(text) {
134 score += 50.0;
135 }
136
137 let keywords = [
139 "决定",
140 "decision",
141 "重要",
142 "important",
143 "关键",
144 "key",
145 "完成",
146 "done",
147 ];
148 for kw in keywords {
149 if text.to_lowercase().contains(kw) {
150 score += 15.0;
151 }
152 }
153 }
154 MessageContent::Blocks(blocks) => {
155 for block in blocks {
156 match block {
157 ContentBlock::ToolUse { name, .. } => {
158 score += weights.tool_use_bonus;
159
160 if is_critical_tool(name) {
162 score += weights.critical_tool_bonus;
163 }
164
165 if name == "todo_write" {
167 score += 60.0;
168 }
169
170 if name == "ask" {
172 score += 50.0;
173 }
174 }
175 ContentBlock::ToolResult { content, .. } => {
176 score += weights.tool_result_bonus;
177
178 if contains_sensitive_instructions(content) {
180 score += 30.0;
181 }
182
183 if content.contains("TodoWrite") || content.contains("todo") {
185 score += 40.0;
186 }
187
188 if content.contains("AskUserQuestion") || content.contains("answer") {
190 score += 30.0;
191 }
192 }
193 ContentBlock::Thinking { thinking, .. } => {
194 if thinking.contains("决定")
196 || thinking.contains("问题")
197 || thinking.contains("关键")
198 {
199 score += 30.0;
200 }
201 }
202 ContentBlock::Text { text } => {
203 if contains_sensitive_instructions(text) {
204 score += 50.0;
205 }
206 }
207 _ => {}
208 }
209 }
210 }
211 }
212
213 score
214}
215
216fn apply_dependency_bonus(scored: &mut [ScoredMessage], deps: &DependencyGraph, bonus: f64) {
218 for dep in &deps.dependencies {
219 if let Some(sm) = scored.get_mut(dep.tool_use_idx) {
221 sm.with_dependency_bonus(bonus);
222 }
223
224 if let Some(sm) = scored.get_mut(dep.tool_result_idx) {
226 sm.with_dependency_bonus(bonus);
227 }
228
229 if dep.is_critical {
231 if let Some(sm) = scored.get_mut(dep.tool_use_idx) {
232 sm.with_dependency_bonus(bonus * 0.5);
233 }
234 if let Some(sm) = scored.get_mut(dep.tool_result_idx) {
235 sm.with_dependency_bonus(bonus * 0.5);
236 }
237 }
238 }
239}
240
/// Returns `true` for tools whose calls earn the critical-tool bonus:
/// `write`, `edit`, `multi_edit` and `bash` (exact, case-sensitive match).
fn is_critical_tool(name: &str) -> bool {
    matches!(name, "write" | "edit" | "multi_edit" | "bash")
}
246
/// Reports whether `text` contains a prohibition or emphasis marker.
///
/// The text is lower-cased before matching, so the English patterns match
/// regardless of case. The Chinese patterns are unaffected by lower-casing.
fn contains_sensitive_instructions(text: &str) -> bool {
    const PATTERNS: [&str; 8] = [
        "不要",
        "禁止",
        "必须",
        "不允许",
        "never",
        "must not",
        "do not",
        "important",
    ];

    let haystack = text.to_lowercase();
    for pattern in PATTERNS {
        if haystack.contains(pattern) {
            return true;
        }
    }
    false
}
262
263fn should_ai_score(message: &Message) -> bool {
265 match message.role {
267 Role::User | Role::Assistant => {
268 let len = estimate_content_length(&message.content);
269 len > 100 }
271 _ => false,
272 }
273}
274
275fn estimate_content_length(content: &MessageContent) -> usize {
277 match content {
278 MessageContent::Text(text) => text.len(),
279 MessageContent::Blocks(blocks) => blocks
280 .iter()
281 .map(|b| match b {
282 ContentBlock::Text { text } => text.len(),
283 ContentBlock::ToolUse { input, .. } => input.to_string().len(),
284 ContentBlock::ToolResult { content, .. } => content.len(),
285 ContentBlock::Thinking { thinking, .. } => thinking.len(),
286 _ => 0,
287 })
288 .sum(),
289 }
290}
291
292fn get_content_preview(message: &Message, max_len: usize) -> String {
294 match &message.content {
295 MessageContent::Text(text) => {
296 if text.len() > max_len {
297 text[..max_len].to_string() + "..."
298 } else {
299 text.clone()
300 }
301 }
302 MessageContent::Blocks(blocks) => {
303 let preview: Vec<String> = blocks
304 .iter()
305 .take(3)
306 .map(|b| match b {
307 ContentBlock::Text { text } => text.chars().take(100).collect(),
308 ContentBlock::ToolUse { name, .. } => format!("[Tool: {}]", name),
309 ContentBlock::ToolResult { content, .. } => {
310 content.chars().take(100).collect::<String>() + "..."
311 }
312 _ => "...".to_string(),
313 })
314 .collect();
315 preview.join(" | ")
316 }
317 }
318}
319
320fn build_ai_score_prompt(content: &str, mode: AiCompressionMode) -> String {
322 match mode {
323 AiCompressionMode::Light => format!(
324 "判断这段内容对当前任务的重要性(0-30分,0=无关,30=关键):\n{}",
325 content
326 ),
327 AiCompressionMode::Deep => format!(
328 "深入分析这段内容的重要性,考虑:\n1. 是否包含关键决策\n2. 是否包含未完成任务\n3. 是否包含敏感指令\n输出重要性评分(0-30分):\n{}",
329 content
330 ),
331 AiCompressionMode::None => String::new(),
332 }
333}
334
335fn extract_text_from_response(response: &crate::providers::ChatResponse) -> String {
337 response
338 .content
339 .iter()
340 .filter_map(|b| {
341 if let ContentBlock::Text { text } = b {
342 Some(text.clone())
343 } else {
344 None
345 }
346 })
347 .collect::<Vec<_>>()
348 .join("\n")
349}
350
/// Parses the fast model's reply into a score clamped to `0.0..=30.0`.
///
/// Parsing order:
/// 1. The whole (trimmed) reply is tried as a number.
/// 2. Otherwise, each line mentioning `评分` or `score` (case-insensitive) is
///    scanned for its first number. Lines are split on anything that is not
///    an ASCII digit, `.` or `-`, so `评分:20`, `score=18/30` and `score: 25`
///    all parse.
/// 3. If nothing usable is found, the neutral default `10.0` is returned.
///
/// Non-finite values (e.g. `NaN`) are rejected. `f64::clamp` passes `NaN`
/// through, which would otherwise poison score comparisons downstream.
///
/// Currently this never returns `Err`; the `Result` is kept for callers.
fn parse_ai_score(text: &str) -> Result<f64> {
    const MIN_SCORE: f64 = 0.0;
    const MAX_SCORE: f64 = 30.0;
    const DEFAULT_SCORE: f64 = 10.0;

    let text = text.trim();

    // Fast path: the model followed instructions and replied with just a number.
    if let Some(score) = text.parse::<f64>().ok().filter(|s| s.is_finite()) {
        return Ok(score.clamp(MIN_SCORE, MAX_SCORE));
    }

    let labelled = text
        .lines()
        .filter(|line| {
            let lower = line.to_lowercase();
            lower.contains("评分") || lower.contains("score")
        })
        .find_map(first_finite_number);

    Ok(labelled.map_or(DEFAULT_SCORE, |s| s.clamp(MIN_SCORE, MAX_SCORE)))
}

/// Returns the first finite number in `line`.
///
/// The line is split on every character that cannot belong to a plain
/// decimal number.
fn first_finite_number(line: &str) -> Option<f64> {
    line.split(|c: char| !(c.is_ascii_digit() || c == '.' || c == '-'))
        .filter(|token| !token.is_empty())
        .find_map(|token| token.parse::<f64>().ok().filter(|n| n.is_finite()))
}
379
/// System prompt sent with every AI scoring request.
///
/// It instructs the model to reply with a single number from 0 to 30, with
/// anchor meanings for 0, 10, 20 and 30. The reply is parsed by
/// `parse_ai_score`.
const AI_SCORE_SYSTEM_PROMPT: &str = r#"你是一个内容重要性评估助手。快速判断内容的重要性并输出评分。

输出要求:
- 仅输出一个数字(0-30)
- 0 = 完全不重要,可以删除
- 10 = 一般重要,可保留可删除
- 20 = 重要,建议保留
- 30 = 关键,必须保留

请直接输出评分数字。"#;
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain-text user message for the rule-scoring tests.
    fn user_text_message(text: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_score_by_rules_first_message() {
        let weights = PhaseWeights::balanced();
        let score = score_by_rules(&user_text_message("Hello"), 0, &weights);
        // The first-message bonus should push even trivial content above 100.
        assert!(score > 100.0);
    }

    #[test]
    fn test_score_by_rules_sensitive() {
        let weights = PhaseWeights::balanced();
        let score = score_by_rules(&user_text_message("不要删除这个文件"), 5, &weights);
        // Sensitive instructions alone add 50 points on top of the base score.
        assert!(score > 50.0);
    }

    #[test]
    fn test_contains_sensitive_instructions() {
        assert!(contains_sensitive_instructions("不要删除"));
        assert!(contains_sensitive_instructions("must not do this"));
        assert!(!contains_sensitive_instructions("普通文本"));
    }

    #[test]
    fn test_is_critical_tool() {
        assert!(is_critical_tool("write"));
        assert!(is_critical_tool("bash"));
        assert!(!is_critical_tool("read"));
    }

    #[test]
    fn test_parse_ai_score() {
        let cases = [
            ("15", 15.0),
            ("评分: 20", 20.0),
            ("score: 25", 25.0),
            // Unparseable replies fall back to the neutral default.
            ("unknown", 10.0),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_ai_score(input).unwrap(), expected, "input: {}", input);
        }
    }
}