use anyhow::Result;
use crate::providers::{ContentBlock, Message, MessageContent, Provider, Role};
use super::types::{AiCompressionMode, DependencyGraph, PhaseWeights, ScoredMessage};
/// Scores conversation messages by importance so that less important ones can be
/// dropped when the context is compressed.
///
/// Scoring is always rule-based; a scorer built with [`Scorer::new_with_ai`] can
/// additionally ask a fast model to rate longer user/assistant messages.
pub struct Scorer {
    /// Model used for AI-assisted scoring; `None` means rule-only scoring.
    fast_model: Option<Box<dyn Provider>>,
}

impl Scorer {
    /// Creates a scorer that uses only the deterministic rule-based heuristics.
    pub fn new_rule_only() -> Self {
        Self { fast_model: None }
    }

    /// Creates a scorer that can refine rule-based scores with `fast_model`.
    pub fn new_with_ai(fast_model: Box<dyn Provider>) -> Self {
        Self { fast_model: Some(fast_model) }
    }

    /// Scores every message in `messages`, preserving their order and indices.
    ///
    /// 1. Each message gets a rule-based score from `score_by_rules`.
    /// 2. If `ai_mode` is not `AiCompressionMode::None` and a fast model is configured,
    ///    messages selected by `should_ai_score` are also rated by the model.
    /// 3. Dependency bonuses from `deps` are applied via `apply_dependency_bonus`.
    ///
    /// # Errors
    /// Returns the first error produced by an AI scoring request. AI requests are made
    /// sequentially, and messages after the failing one are not AI-scored.
    pub async fn score_all(
        &self,
        messages: &[Message],
        weights: &PhaseWeights,
        deps: &DependencyGraph,
        ai_mode: AiCompressionMode,
    ) -> Result<Vec<ScoredMessage>> {
        // `enumerate` + `map` keep an exact size hint, so `collect` allocates once.
        let mut scored: Vec<ScoredMessage> = messages
            .iter()
            .enumerate()
            .map(|(idx, msg)| {
                let base_score = score_by_rules(msg, idx, weights);
                ScoredMessage::new(idx, msg.clone(), base_score)
            })
            .collect();

        if ai_mode != AiCompressionMode::None && self.fast_model.is_some() {
            for sm in &mut scored {
                if !should_ai_score(&sm.message) {
                    continue;
                }
                let ai_score = self.score_with_ai(&sm.message, ai_mode).await?;
                sm.with_ai_score(ai_score);
            }
        }

        apply_dependency_bonus(&mut scored, deps, weights.dependency_pair_bonus);
        Ok(scored)
    }

    /// Asks the fast model to rate `message` and parses its reply with `parse_ai_score`.
    ///
    /// Returns `Ok(0.0)` without making a request when no fast model is configured.
    ///
    /// # Errors
    /// Propagates any error returned by the provider's `chat` call.
    async fn score_with_ai(&self, message: &Message, mode: AiCompressionMode) -> Result<f64> {
        // Length budget handed to `get_content_preview` for the prompt body.
        const PREVIEW_LEN: usize = 500;

        let provider = match self.fast_model.as_ref() {
            Some(provider) => provider,
            None => return Ok(0.0),
        };

        let content_preview = get_content_preview(message, PREVIEW_LEN);
        let prompt = build_ai_score_prompt(&content_preview, mode);

        let response = provider
            .chat(crate::providers::ChatRequest {
                messages: vec![Message {
                    role: Role::User,
                    content: MessageContent::Text(prompt),
                }],
                tools: vec![],
                system: Some(AI_SCORE_SYSTEM_PROMPT.to_string()),
                think: false,
                // The reply should be a bare number, so a small token cap suffices.
                max_tokens: 100,
                server_tools: vec![],
                enable_caching: false,
            })
            .await?;

        let score_text = extract_text_from_response(&response);
        parse_ai_score(&score_text)
    }
}
/// Computes the deterministic, rule-based importance score of a single message.
///
/// The result is a base of 10, plus `weights.first_msg_bonus` when `index == 0`,
/// plus a role-dependent bonus, plus the content-based score from `content_score`.
pub fn score_by_rules(message: &Message, index: usize, weights: &PhaseWeights) -> f64 {
    const BASE_SCORE: f64 = 10.0;
    const ASSISTANT_BONUS: f64 = 5.0;
    const SYSTEM_BONUS: f64 = 40.0;

    let mut total = BASE_SCORE;
    if index == 0 {
        total += weights.first_msg_bonus;
    }
    total += match message.role {
        Role::User => weights.user_msg_bonus,
        Role::Assistant => ASSISTANT_BONUS,
        Role::Tool => weights.tool_result_bonus,
        Role::System => SYSTEM_BONUS,
    };
    total + content_score(&message.content, weights)
}
/// Scores a message body by the signals it contains. Role and position are scored
/// separately by `score_by_rules`.
///
/// Plain-text content gets +50 if it contains a sensitive instruction, and +15 for
/// each entry of `KEYWORDS` it contains (case-insensitive).
///
/// Block content is scored per block:
/// - tool call: `tool_use_bonus`, plus `critical_tool_bonus` when `is_critical_tool`
///   accepts the name, +60 for `todo_write`, +50 for `ask`;
/// - tool result: `tool_result_bonus`, +30 for a sensitive instruction, +40 if it
///   mentions `TodoWrite`/`todo`, +30 if it mentions `AskUserQuestion`/`answer`;
/// - thinking: +30 if it mentions 决定 (decision), 问题 (problem) or 关键 (key);
/// - text: +50 for a sensitive instruction;
/// - any other block: nothing.
fn content_score(content: &MessageContent, weights: &PhaseWeights) -> f64 {
    // Words hinting at decisions, importance or completion (Chinese and English).
    const KEYWORDS: [&str; 8] = [
        "决定", "decision", "重要", "important", "关键", "key", "完成", "done",
    ];

    let mut score: f64 = 0.0;
    match content {
        MessageContent::Text(text) => {
            if contains_sensitive_instructions(text) {
                score += 50.0;
            }
            // Lowercase once instead of allocating a new lowercase copy per keyword.
            let lower = text.to_lowercase();
            for kw in KEYWORDS {
                if lower.contains(kw) {
                    score += 15.0;
                }
            }
        }
        MessageContent::Blocks(blocks) => {
            for block in blocks {
                match block {
                    ContentBlock::ToolUse { name, .. } => {
                        score += weights.tool_use_bonus;
                        if is_critical_tool(name) {
                            score += weights.critical_tool_bonus;
                        }
                        if name == "todo_write" {
                            score += 60.0;
                        }
                        if name == "ask" {
                            score += 50.0;
                        }
                    }
                    ContentBlock::ToolResult { content: result, .. } => {
                        score += weights.tool_result_bonus;
                        if contains_sensitive_instructions(result) {
                            score += 30.0;
                        }
                        if result.contains("TodoWrite") || result.contains("todo") {
                            score += 40.0;
                        }
                        if result.contains("AskUserQuestion") || result.contains("answer") {
                            score += 30.0;
                        }
                    }
                    ContentBlock::Thinking { thinking, .. } => {
                        if thinking.contains("决定") || thinking.contains("问题") || thinking.contains("关键") {
                            score += 30.0;
                        }
                    }
                    ContentBlock::Text { text } => {
                        if contains_sensitive_instructions(text) {
                            score += 50.0;
                        }
                    }
                    _ => {}
                }
            }
        }
    }
    score
}
/// Raises the score of both halves of every tool-use / tool-result pair in `deps`.
///
/// Each message referenced by a dependency receives `bonus`. Messages in a critical
/// dependency receive `bonus * 1.5`. Indices outside `scored` are ignored.
fn apply_dependency_bonus(
    scored: &mut [ScoredMessage],
    deps: &DependencyGraph,
    bonus: f64,
) {
    // Extra weight for critical pairs, on top of the regular pair bonus.
    const CRITICAL_MULTIPLIER: f64 = 1.5;

    for dep in &deps.dependencies {
        // Apply the combined bonus in one call. The previous two-call form only gave
        // 1.5x if `with_dependency_bonus` accumulates (an overwriting setter would
        // leave critical pairs at 0.5x). The single call is identical when it accumulates.
        // NOTE(review): confirm `with_dependency_bonus` semantics.
        let effective = if dep.is_critical {
            bonus * CRITICAL_MULTIPLIER
        } else {
            bonus
        };
        for idx in [dep.tool_use_idx, dep.tool_result_idx] {
            if let Some(sm) = scored.get_mut(idx) {
                sm.with_dependency_bonus(effective);
            }
        }
    }
}
/// Returns `true` for tool names treated as critical: `write`, `edit`, `multi_edit`
/// and `bash`. Matching is exact and case-sensitive.
fn is_critical_tool(name: &str) -> bool {
    matches!(name, "write" | "edit" | "multi_edit" | "bash")
}
/// Returns `true` if `text` contains a prohibition or mandate phrase, matched
/// case-insensitively as a substring.
///
/// The Chinese patterns mean "do not", "forbidden", "must" and "not allowed".
fn contains_sensitive_instructions(text: &str) -> bool {
    const PATTERNS: [&str; 8] = [
        "不要", "禁止", "必须", "不允许",
        "never", "must not", "do not", "important",
    ];
    let haystack = text.to_lowercase();
    for pattern in PATTERNS {
        if haystack.contains(pattern) {
            return true;
        }
    }
    false
}
/// Decides whether a message is worth sending to the fast model for AI scoring.
///
/// Only user and assistant messages qualify, and only when `estimate_content_length`
/// reports more than 100. Every other role is skipped without measuring content.
fn should_ai_score(message: &Message) -> bool {
    const MIN_CONTENT_LEN: usize = 100;
    let conversational = matches!(message.role, Role::User | Role::Assistant);
    conversational && estimate_content_length(&message.content) > MIN_CONTENT_LEN
}
/// Roughly measures the size of a message body.
///
/// Plain text, text blocks, tool results and thinking blocks count by their `len()`.
/// Tool calls count by the length of `input.to_string()`. Other block kinds count as 0.
fn estimate_content_length(content: &MessageContent) -> usize {
    let blocks = match content {
        MessageContent::Text(text) => return text.len(),
        MessageContent::Blocks(blocks) => blocks,
    };
    let mut total = 0;
    for block in blocks.iter() {
        total += match block {
            ContentBlock::Text { text } => text.len(),
            ContentBlock::ToolUse { input, .. } => input.to_string().len(),
            ContentBlock::ToolResult { content, .. } => content.len(),
            ContentBlock::Thinking { thinking, .. } => thinking.len(),
            _ => 0,
        };
    }
    total
}
/// Builds a short textual preview of `message` for the AI scoring prompt.
///
/// - Plain text longer than `max_len` characters is cut to its first `max_len`
///   characters and suffixed with `"..."`. Shorter text is returned unchanged.
/// - Block content is summarised from its first three blocks, joined by `" | "`.
///   Text blocks contribute up to 100 characters, tool calls `[Tool: <name>]`,
///   tool results up to 100 characters followed by `"..."`, and other blocks `"..."`.
fn get_content_preview(message: &Message, max_len: usize) -> String {
    match &message.content {
        // Cut on a char boundary. Byte-slicing at `max_len` panics whenever that
        // offset lands inside a multi-byte (e.g. CJK) character.
        MessageContent::Text(text) => match text.char_indices().nth(max_len) {
            Some((cut, _)) => format!("{}...", &text[..cut]),
            None => text.clone(),
        },
        MessageContent::Blocks(blocks) => blocks
            .iter()
            .take(3)
            .map(|block| match block {
                ContentBlock::Text { text } => text.chars().take(100).collect(),
                ContentBlock::ToolUse { name, .. } => format!("[Tool: {}]", name),
                ContentBlock::ToolResult { content, .. } => {
                    content.chars().take(100).collect::<String>() + "..."
                }
                _ => "...".to_string(),
            })
            .collect::<Vec<String>>()
            .join(" | "),
    }
}
/// Builds the user prompt that asks the fast model to rate `content` from 0 to 30.
///
/// `Light` asks for a quick relevance rating. `Deep` asks the model to consider key
/// decisions, unfinished tasks and sensitive instructions first. `None` yields an
/// empty prompt.
fn build_ai_score_prompt(content: &str, mode: AiCompressionMode) -> String {
    let header = match mode {
        AiCompressionMode::None => return String::new(),
        AiCompressionMode::Light => "判断这段内容对当前任务的重要性(0-30分,0=无关,30=关键):\n",
        AiCompressionMode::Deep => "深入分析这段内容的重要性,考虑:\n1. 是否包含关键决策\n2. 是否包含未完成任务\n3. 是否包含敏感指令\n输出重要性评分(0-30分):\n",
    };
    let mut prompt = String::with_capacity(header.len() + content.len());
    prompt.push_str(header);
    prompt.push_str(content);
    prompt
}
/// Concatenates the text blocks of `response`, separated by newlines.
/// Non-text blocks are ignored.
fn extract_text_from_response(response: &crate::providers::ChatResponse) -> String {
    let mut pieces: Vec<&str> = Vec::new();
    for block in response.content.iter() {
        if let ContentBlock::Text { text } = block {
            pieces.push(text);
        }
    }
    pieces.join("\n")
}
/// Parses the fast model's reply into an importance score in `0.0..=30.0`.
///
/// The following are tried in order:
/// 1. The whole trimmed reply as a number, e.g. `"15"`.
/// 2. The first number on the first line that mentions `评分` ("score") or `score`
///    (case-insensitive) and contains a number, e.g. `"评分：20"` or `"Score: 25/30"`.
/// 3. The number in the reply, if it contains exactly one, e.g. `"20分"` or `"**20**"`.
///
/// Parsed values are clamped to `0.0..=30.0`, and NaN is rejected. If nothing usable
/// is found, the neutral default `10.0` is returned.
///
/// # Errors
/// Currently never fails. The `Result` return type is kept for callers.
fn parse_ai_score(text: &str) -> Result<f64> {
    const MIN_SCORE: f64 = 0.0;
    const MAX_SCORE: f64 = 30.0;
    const DEFAULT_SCORE: f64 = 10.0;

    let text = text.trim();
    // `"NaN".parse::<f64>()` succeeds and `clamp` passes NaN through unchanged,
    // so reject it explicitly to keep scores comparable.
    if let Ok(score) = text.parse::<f64>() {
        if !score.is_nan() {
            return Ok(score.clamp(MIN_SCORE, MAX_SCORE));
        }
    }
    for line in text.lines() {
        let lower = line.to_lowercase();
        if lower.contains("评分") || lower.contains("score") {
            if let Some(&score) = extract_numbers(line).first() {
                return Ok(score.clamp(MIN_SCORE, MAX_SCORE));
            }
        }
    }
    // The system prompt asks for a bare number. Tolerate decorations such as a unit
    // suffix or markdown emphasis, but only when the reply is unambiguous.
    if let [only] = extract_numbers(text).as_slice() {
        return Ok((*only).clamp(MIN_SCORE, MAX_SCORE));
    }
    Ok(DEFAULT_SCORE)
}

/// Returns every decimal number literal in `s`, in order of appearance.
///
/// A number is a run of ASCII digits and `.` that starts with a digit, optionally
/// preceded by `-`. A trailing `.` is ignored, and runs that still fail to parse
/// (e.g. `1.2.3`) are skipped.
fn extract_numbers(s: &str) -> Vec<f64> {
    let bytes = s.as_bytes();
    let mut numbers = Vec::new();
    let mut i = 0;
    while i < bytes.len() {
        if !bytes[i].is_ascii_digit() {
            i += 1;
            continue;
        }
        let start = if i > 0 && bytes[i - 1] == b'-' { i - 1 } else { i };
        let mut end = i;
        while end < bytes.len() && (bytes[end].is_ascii_digit() || bytes[end] == b'.') {
            end += 1;
        }
        // `start` points at an ASCII byte and `end` follows one, so both are char
        // boundaries even when `s` contains multi-byte characters.
        if let Ok(n) = s[start..end].trim_end_matches('.').parse::<f64>() {
            numbers.push(n);
        }
        i = end;
    }
    numbers
}
/// System prompt sent with every AI scoring request (the `system` field built in
/// `Scorer::score_with_ai`).
///
/// English translation: "You are a content-importance assessment assistant. Quickly
/// judge the importance of the content and output a score. Output requirements:
/// output only a single number (0-30); 0 = completely unimportant, can be deleted;
/// 10 = moderately important, may be kept or deleted; 20 = important, recommended
/// to keep; 30 = critical, must be kept. Output the score number directly."
const AI_SCORE_SYSTEM_PROMPT: &str = r#"你是一个内容重要性评估助手。快速判断内容的重要性并输出评分。
输出要求:
- 仅输出一个数字(0-30)
- 0 = 完全不重要,可以删除
- 10 = 一般重要,可保留可删除
- 20 = 重要,建议保留
- 30 = 关键,必须保留
请直接输出评分数字。"#;
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain-text user message for the rule-scoring tests.
    fn user_text(text: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_score_by_rules_first_message() {
        let weights = PhaseWeights::balanced();
        let score = score_by_rules(&user_text("Hello"), 0, &weights);
        // NOTE(review): relies on PhaseWeights::balanced() having a large
        // first-message and/or user bonus.
        assert!(score > 100.0);
    }

    #[test]
    fn test_score_by_rules_sensitive() {
        let weights = PhaseWeights::balanced();
        // "不要删除这个文件" = "do not delete this file", a sensitive instruction.
        let score = score_by_rules(&user_text("不要删除这个文件"), 5, &weights);
        assert!(score > 50.0);
    }

    #[test]
    fn test_contains_sensitive_instructions() {
        for positive in ["不要删除", "must not do this"] {
            assert!(contains_sensitive_instructions(positive), "{}", positive);
        }
        // "普通文本" = "ordinary text".
        assert!(!contains_sensitive_instructions("普通文本"));
    }

    #[test]
    fn test_is_critical_tool() {
        for tool in ["write", "bash"] {
            assert!(is_critical_tool(tool), "{}", tool);
        }
        assert!(!is_critical_tool("read"));
    }

    #[test]
    fn test_parse_ai_score() {
        let cases = [
            ("15", 15.0),
            ("评分: 20", 20.0),
            ("score: 25", 25.0),
            // Unparseable replies fall back to the neutral default.
            ("unknown", 10.0),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_ai_score(input).unwrap(), expected, "input: {}", input);
        }
    }
}