use crate::providers::{Message, MessageContent, ContentBlock, Role};
use crate::compress::{ConversationFocus, FocusPoint, FocusManager};
use anyhow::Result;
/// Rule-based scorer that rates how relevant a [`Message`] is to the current
/// conversation focus, producing scores in `[0.0, 1.0]`.
///
/// Two thresholds turn a score into a decision: a score at or above
/// `preserve_threshold` makes `should_preserve` return `true`, and a score at or
/// above `high_priority_threshold` makes `is_high_priority` return `true`.
#[derive(Debug, Clone, PartialEq)]
pub struct FocusScoreEvaluator {
    /// Whether AI-assisted scoring was requested.
    /// NOTE(review): no code in this file reads this flag outside the tests;
    /// scoring is always rule-based.
    use_ai: bool,
    /// Minimum score for a message to count as high priority.
    high_priority_threshold: f32,
    /// Minimum score for a message to be preserved.
    preserve_threshold: f32,
}

impl Default for FocusScoreEvaluator {
    /// Rule-based scoring with a high-priority threshold of `0.7` and a
    /// preserve threshold of `0.3`.
    fn default() -> Self {
        Self {
            use_ai: false,
            high_priority_threshold: 0.7,
            preserve_threshold: 0.3,
        }
    }
}
impl FocusScoreEvaluator {
    /// Creates a rule-based evaluator with the default thresholds
    /// (high priority `0.7`, preserve `0.3`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an evaluator with the `use_ai` flag set and default thresholds.
    ///
    /// NOTE(review): nothing in this impl reads `use_ai` yet —
    /// [`Self::evaluate`] always uses the rule-based scorer.
    pub fn with_ai() -> Self {
        Self {
            use_ai: true,
            ..Self::default()
        }
    }

    /// Creates a rule-based evaluator with custom thresholds.
    ///
    /// * `high_priority` — minimum score accepted by [`Self::is_high_priority`].
    /// * `preserve` — minimum score accepted by [`Self::should_preserve`].
    pub fn with_thresholds(high_priority: f32, preserve: f32) -> Self {
        Self {
            use_ai: false,
            high_priority_threshold: high_priority,
            preserve_threshold: preserve,
        }
    }

    /// Scores how relevant `message` is to the conversation `focus`.
    ///
    /// Returns a value clamped to `[0.0, 1.0]`; currently always rule-based.
    pub fn evaluate(&self, message: &Message, focus: &ConversationFocus) -> f32 {
        self.evaluate_rule_based(message, focus)
    }

    /// Scores `message` against a single [`FocusPoint`].
    ///
    /// Each matching keyword adds `0.15`, each matching entity `0.25`, and a
    /// similar core question `0.4`. Matching is case-insensitive substring
    /// search, and blank terms are ignored. The sum is multiplied by the focus
    /// point's `importance` and `confidence`, then clamped to `[0.0, 1.0]`.
    pub fn evaluate_for_focus_point(&self, message: &Message, focus: &FocusPoint) -> f32 {
        let text = self.extract_text(message);
        let text_lower = text.to_lowercase();
        let mut score: f32 = 0.0;

        for keyword in &focus.keywords {
            if Self::contains_term(&text_lower, keyword) {
                score += 0.15;
            }
        }
        for entity in &focus.entities {
            if Self::contains_term(&text_lower, entity) {
                score += 0.25;
            }
        }
        if let Some(question) = &focus.core_question {
            if self.questions_similar(&text, question) {
                score += 0.4;
            }
        }

        score *= focus.importance;
        score *= focus.confidence;
        score.clamp(0.0, 1.0)
    }

    /// Scores `message` against every active focus tracked by `manager`.
    ///
    /// Returns the best score over the active foci, or `0.5` when there are
    /// none. If the manager's current focus alone scores above `0.5`, 20% of
    /// that score is added as a bonus, and the result is capped at `1.0`.
    pub fn evaluate_for_manager(&self, message: &Message, manager: &FocusManager) -> f32 {
        let active_foci = manager.get_active_foci();
        if active_foci.is_empty() {
            // No focus information: fall back to a fixed mid-range score.
            return 0.5;
        }
        let max_score = active_foci
            .iter()
            .map(|f| self.evaluate_for_focus_point(message, f))
            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or(0.5);
        if let Some(current) = manager.current_focus() {
            let current_score = self.evaluate_for_focus_point(message, current);
            if current_score > 0.5 {
                return (max_score + current_score * 0.2).min(1.0);
            }
        }
        max_score
    }

    /// Rule-based relevance of `message` to `focus`, clamped to `[0.0, 1.0]`.
    ///
    /// The score is built up as follows:
    /// * `+0.2` per comma-separated `current_topic` keyword found in the text,
    /// * `+0.35` if the text is similar to `current_question`,
    /// * `+0.1` per `recent_context` term found in the text.
    ///
    /// The total is then multiplied by `1.2` for user messages and by `1.1`
    /// for long messages.
    fn evaluate_rule_based(&self, message: &Message, focus: &ConversationFocus) -> f32 {
        let text = self.extract_text(message);
        let text_lower = text.to_lowercase();
        let mut score: f32 = 0.0;

        if let Some(topic) = &focus.current_topic {
            // Split on bare commas and trim, so "a,b" and "a, b" give the same
            // keywords and a trailing ", " does not yield an empty keyword.
            for kw in topic.split(',').map(str::trim) {
                if Self::contains_term(&text_lower, kw) {
                    score += 0.2;
                }
            }
        }
        if let Some(question) = &focus.current_question {
            if self.questions_similar(&text, question) {
                score += 0.35;
            }
        }
        for ctx in &focus.recent_context {
            if Self::contains_term(&text_lower, ctx) {
                score += 0.1;
            }
        }

        if matches!(message.role, Role::User) {
            score *= 1.2;
        }
        // `str::len` counts UTF-8 bytes, not characters: CJK text (3 bytes per
        // character) crosses this threshold at roughly 67 characters.
        if text.len() > 200 {
            score *= 1.1;
        }
        score.clamp(0.0, 1.0)
    }

    /// Case-insensitive substring test of `term` against already-lowercased text.
    ///
    /// Blank terms never match. `str::contains("")` is always `true`, so an
    /// empty keyword would otherwise add score to every message.
    fn contains_term(text_lower: &str, term: &str) -> bool {
        !term.trim().is_empty() && text_lower.contains(&term.to_lowercase())
    }

    /// Returns `true` when more than half of the significant words of
    /// `question` also appear in `text`.
    ///
    /// A question with no significant words is never considered similar.
    fn questions_similar(&self, text: &str, question: &str) -> bool {
        let question_words = self.extract_significant_words(question);
        if question_words.is_empty() {
            return false;
        }
        let text_words = self.extract_significant_words(text);
        let common = question_words.intersection(&text_words).count();
        common as f32 / question_words.len() as f32 > 0.5
    }

    /// Lower-cased, whitespace-separated words of `text`, with surrounding
    /// punctuation stripped, keeping only words longer than 3 bytes.
    ///
    /// Stripping edge punctuation lets "latency?" match "latency" (and
    /// "延迟？" match "延迟"). The length filter is in UTF-8 bytes: it drops
    /// short ASCII words such as "api" but keeps two-character CJK words
    /// (6 bytes).
    fn extract_significant_words(&self, text: &str) -> std::collections::HashSet<String> {
        text.to_lowercase()
            .split_whitespace()
            .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()))
            .filter(|w| w.len() > 3)
            .map(str::to_owned)
            .collect()
    }

    /// Concatenates the textual parts of `message`.
    ///
    /// Plain-text content is returned unchanged. For block content, only
    /// `ContentBlock::Text` blocks are kept and they are joined with newlines.
    fn extract_text(&self, message: &Message) -> String {
        match &message.content {
            MessageContent::Text(t) => t.clone(),
            MessageContent::Blocks(blocks) => blocks
                .iter()
                .filter_map(|b| {
                    if let ContentBlock::Text { text } = b {
                        Some(text.clone())
                    } else {
                        None
                    }
                })
                .collect::<Vec<_>>()
                .join("\n"),
        }
    }

    /// Returns `true` if `score` reaches the preserve threshold (inclusive).
    pub fn should_preserve(&self, score: f32) -> bool {
        score >= self.preserve_threshold
    }

    /// Returns `true` if `score` reaches the high-priority threshold (inclusive).
    pub fn is_high_priority(&self, score: f32) -> bool {
        score >= self.high_priority_threshold
    }

    /// The configured preserve threshold.
    pub fn preserve_threshold(&self) -> f32 {
        self.preserve_threshold
    }

    /// The configured high-priority threshold.
    pub fn high_priority_threshold(&self) -> f32 {
        self.high_priority_threshold
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compress::FocusManager;

    /// Builds a plain-text message with the given role.
    fn create_test_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    /// A conversation focus about API performance / latency.
    fn create_test_focus() -> ConversationFocus {
        ConversationFocus {
            current_topic: Some("API 性能优化".to_string()),
            current_question: Some("如何减少 API 响应延迟?".to_string()),
            recent_context: vec!["API".to_string(), "性能".to_string(), "延迟".to_string()],
            topic_transitions: Vec::new(),
            detected_at: 10,
        }
    }

    /// A focus point with keywords, code entities and a core question.
    fn create_test_focus_point() -> FocusPoint {
        FocusPoint::new(
            "focus-1".to_string(),
            "API 性能优化".to_string(),
            vec!["API".to_string(), "性能".to_string(), "延迟".to_string()],
            vec!["api.rs".to_string(), "handler()".to_string()],
            Some("如何减少延迟?".to_string()),
            0,
        )
        .with_importance(0.8)
        .with_confidence(0.85)
    }

    #[test]
    fn test_evaluator_creation() {
        let evaluator = FocusScoreEvaluator::new();
        assert!(!evaluator.use_ai);
        assert_eq!(evaluator.high_priority_threshold, 0.7);
        assert_eq!(evaluator.preserve_threshold, 0.3);
    }

    #[test]
    fn test_evaluator_with_ai() {
        let evaluator = FocusScoreEvaluator::with_ai();
        let defaults = FocusScoreEvaluator::default();
        assert!(evaluator.use_ai);
        assert_eq!(evaluator.high_priority_threshold, defaults.high_priority_threshold);
        assert_eq!(evaluator.preserve_threshold, defaults.preserve_threshold);
    }

    #[test]
    fn test_evaluator_with_thresholds() {
        let evaluator = FocusScoreEvaluator::with_thresholds(0.8, 0.4);
        assert!(!evaluator.use_ai);
        assert_eq!(evaluator.high_priority_threshold, 0.8);
        assert_eq!(evaluator.preserve_threshold, 0.4);
    }

    #[test]
    fn test_evaluator_high_relevance() {
        let evaluator = FocusScoreEvaluator::new();
        let focus = create_test_focus();
        let message = create_test_message(Role::User, "API 性能 延迟 优化");
        let score = evaluator.evaluate(&message, &focus);
        assert!(score > 0.0, "Score should be positive for message with keywords");
    }

    #[test]
    fn test_evaluator_low_relevance() {
        let evaluator = FocusScoreEvaluator::new();
        let focus = create_test_focus();
        let message = create_test_message(Role::User, "今天天气晴朗阳光明媚");
        let score = evaluator.evaluate(&message, &focus);
        assert!(score < 0.3, "Score should be low for irrelevant message");
    }

    #[test]
    fn test_evaluator_medium_relevance() {
        let evaluator = FocusScoreEvaluator::new();
        let focus = create_test_focus();
        let message = create_test_message(Role::Assistant, "代码中 API 调用需要注意错误处理");
        let score = evaluator.evaluate(&message, &focus);
        assert!((0.1..=0.5).contains(&score), "Score should be medium, got {}", score);
    }

    #[test]
    fn test_evaluator_for_focus_point() {
        let evaluator = FocusScoreEvaluator::new();
        let focus = create_test_focus_point();
        let message = create_test_message(Role::User, "api.rs 中的 handler 函数延迟太高");
        let score = evaluator.evaluate_for_focus_point(&message, &focus);
        assert!(score > 0.3, "Should match keywords and entities");
    }

    #[test]
    fn test_evaluator_for_manager() {
        let evaluator = FocusScoreEvaluator::new();
        let mut manager = FocusManager::new();
        manager.add_focus(create_test_focus_point());
        let message = create_test_message(Role::User, "如何优化 API 性能?");
        let score = evaluator.evaluate_for_manager(&message, &manager);
        assert!(score > 0.0);
    }

    #[test]
    fn test_should_preserve() {
        let evaluator = FocusScoreEvaluator::with_thresholds(0.7, 0.3);
        assert!(evaluator.should_preserve(0.5));
        assert!(evaluator.should_preserve(0.3), "threshold is inclusive");
        assert!(!evaluator.should_preserve(0.2));
    }

    #[test]
    fn test_is_high_priority() {
        let evaluator = FocusScoreEvaluator::with_thresholds(0.7, 0.3);
        assert!(evaluator.is_high_priority(0.8));
        assert!(evaluator.is_high_priority(0.7), "threshold is inclusive");
        assert!(!evaluator.is_high_priority(0.6));
    }

    #[test]
    fn test_questions_similar() {
        let evaluator = FocusScoreEvaluator::new();
        assert!(evaluator.questions_similar(
            "API 响应 延迟 太高怎么办",
            "如何减少 API 响应 延迟"
        ));
        assert!(!evaluator.questions_similar(
            "天气很好",
            "API 响应 延迟 太高怎么办"
        ));
        // A question with no significant words never matches.
        assert!(!evaluator.questions_similar("API 响应 延迟", ""));
    }

    #[test]
    fn test_user_message_boost() {
        let evaluator = FocusScoreEvaluator::new();
        let focus = create_test_focus();
        let user_msg = create_test_message(Role::User, "API 性能问题");
        let assistant_msg = create_test_message(Role::Assistant, "API 性能问题");
        let user_score = evaluator.evaluate(&user_msg, &focus);
        let assistant_score = evaluator.evaluate(&assistant_msg, &focus);
        // Strict: identical text with a non-zero base score must be boosted for users.
        assert!(
            user_score > assistant_score,
            "User message should have higher score ({} vs {})",
            user_score,
            assistant_score
        );
    }

    #[test]
    fn test_extract_significant_words() {
        let evaluator = FocusScoreEvaluator::new();
        let words = evaluator.extract_significant_words("about API performance optimization");
        assert!(words.contains("about"));
        assert!(words.contains("performance"));
        assert!(words.contains("optimization"));
        // Words of 3 bytes or fewer (after lower-casing) are dropped.
        assert!(!words.contains("api"));
        assert_eq!(words.len(), 3);
    }
}