use tracing::debug;
use crate::profiles::Tier;
/// Per-signal weights used by `classify`: entry `i` scales `signals[i]`.
///
/// The entries sum to 1.0 (asserted by `test_weights_sum_to_one`).
const WEIGHTS: [f64; 15] = [
    0.08, // 0: prompt length
    0.15, // 1: code presence
    0.18, // 2: reasoning vocabulary
    0.10, // 3: technical vocabulary
    0.05, // 4: creative vocabulary
    0.02, // 5: simple-request vocabulary (signal is negated)
    0.12, // 6: multi-step markers
    0.05, // 7: question marks
    0.04, // 8: agentic / tool verbs
    0.06, // 9: math presence
    0.04, // 10: average word length
    0.03, // 11: conversation length
    0.04, // 12: tools offered
    0.02, // 13: structured-output formats
    0.02, // 14: specialist domains
];
/// Classifies a chat request into a [`Tier`] using cheap lexical heuristics.
///
/// Fifteen signals are derived from three inputs:
/// - the lowercased, space-joined text of the user messages,
/// - the total message count,
/// - `has_tools`.
///
/// Each signal is multiplied by the matching entry in `WEIGHTS`, and the
/// products are summed. The sum is mapped to a tier with `Tier::from_score`.
///
/// Signal indices:
/// - 0 length
/// - 1 code
/// - 2 reasoning vocabulary
/// - 3 technical vocabulary
/// - 4 creative vocabulary
/// - 5 simple-request vocabulary (negated)
/// - 6 multi-step markers
/// - 7 question marks
/// - 8 agentic verbs
/// - 9 math
/// - 10 average word length
/// - 11 conversation length
/// - 12 tools offered
/// - 13 structured-output formats
/// - 14 specialist domains
///
/// # Arguments
///
/// * `messages` - the conversation. Only `Role::User` messages contribute
///   text, but every message counts towards the conversation-length signal.
/// * `has_tools` - whether the request offers tools to the model.
///
/// The score, tier, word count and message count are also logged at `debug`
/// level.
pub fn classify(messages: &[solvela_protocol::ChatMessage], has_tools: bool) -> ScorerResult {
    let text = concatenate_user_content(messages);
    let msg_count = messages.len();

    // Compute the word count and total word length in a single pass.
    // Lengths are measured in chars, not UTF-8 bytes. `str::len` would count
    // every non-ASCII character as 2-4 units and inflate signal 10 for
    // non-English text.
    let (word_count, word_chars) = text
        .split_whitespace()
        .fold((0_usize, 0_usize), |(words, chars), word| {
            (words + 1, chars + word.chars().count())
        });

    let mut signals = [0.0_f64; 15];

    // 0: prompt length. Short prompts lower the score; long ones raise it.
    signals[0] = if word_count < 15 {
        -0.5
    } else if word_count < 50 {
        -0.2
    } else if word_count > 200 {
        0.5
    } else {
        0.0
    };

    // 1: code fences or code keywords.
    signals[1] = score_code_presence(&text);

    // 2: vocabulary asking for explicit reasoning.
    signals[2] = score_keyword_density(
        &text,
        &[
            "prove",
            "theorem",
            "step by step",
            "reason",
            "analyze",
            "evaluate",
            "compare and contrast",
            "think through",
            "explain why",
        ],
    );

    // 3: technical / systems vocabulary.
    signals[3] = score_keyword_density(
        &text,
        &[
            "algorithm",
            "kubernetes",
            "database",
            "architecture",
            "distributed",
            "concurrent",
            "protocol",
            "optimization",
            "benchmark",
        ],
    );

    // 4: creative-writing vocabulary.
    signals[4] = score_keyword_density(
        &text,
        &[
            "story",
            "poem",
            "brainstorm",
            "creative",
            "imagine",
            "fiction",
            "narrative",
        ],
    );

    // 5: markers of trivial requests. Negated so they pull the score down.
    signals[5] = -score_keyword_density(
        &text,
        &[
            "what is",
            "define",
            "translate",
            "hello",
            "hi",
            "thanks",
            "yes",
            "no",
        ],
    );

    // 6: multi-step / sequencing markers.
    signals[6] = score_keyword_density(
        &text,
        &[
            "first", "then", "next", "finally", "step 1", "step 2", "1.", "2.", "3.",
        ],
    );

    // 7: number of questions asked.
    let question_marks = text.matches('?').count();
    signals[7] = match question_marks {
        0 => 0.0,
        1 => 0.1,
        2..=3 => 0.4,
        _ => 0.8,
    };

    // 8: agentic verbs (file / command / deployment actions).
    signals[8] = score_keyword_density(
        &text,
        &[
            "read file",
            "write file",
            "edit",
            "deploy",
            "execute",
            "run command",
            "install",
        ],
    );

    // 9: mathematical symbols or wording.
    signals[9] = score_math_presence(&text);

    // 10: average word length in characters, a rough proxy for jargon.
    let avg_word_len = if word_count > 0 {
        word_chars as f64 / word_count as f64
    } else {
        0.0
    };
    signals[10] = if avg_word_len > 7.0 {
        0.6
    } else if avg_word_len > 5.5 {
        0.2
    } else {
        0.0
    };

    // 11: conversation length. Counts all messages, not only user ones.
    signals[11] = match msg_count {
        0..=2 => 0.0,
        3..=5 => 0.3,
        6..=10 => 0.6,
        _ => 1.0,
    };

    // 12: tools offered with the request.
    signals[12] = if has_tools { 0.8 } else { 0.0 };

    // 13: structured-output formats.
    signals[13] = score_keyword_density(&text, &["json", "csv", "xml", "markdown", "structured"]);

    // 14: specialist / high-stakes domains.
    signals[14] = score_keyword_density(
        &text,
        &[
            "medical",
            "legal",
            "scientific",
            "clinical",
            "regulatory",
            "compliance",
            "diagnosis",
        ],
    );

    let score: f64 = signals.iter().zip(WEIGHTS.iter()).map(|(s, w)| s * w).sum();
    let tier = Tier::from_score(score);
    debug!(
        score = format!("{score:.4}"),
        tier = ?tier,
        word_count,
        msg_count,
        "request classified"
    );
    ScorerResult {
        score,
        tier,
        signals,
    }
}
/// Outcome of [`classify`]: the weighted score, the tier it maps to, and the
/// raw per-signal values that produced it.
#[derive(Debug, Clone, PartialEq)]
pub struct ScorerResult {
    /// Sum of `signals[i] * WEIGHTS[i]` over all signals.
    pub score: f64,
    /// Tier chosen by `Tier::from_score(score)`.
    pub tier: Tier,
    /// Unweighted signal values, in the same index order as `WEIGHTS`.
    pub signals: [f64; 15],
}
/// Joins the content of every `Role::User` message with single spaces and
/// returns the result lowercased. Messages with any other role are skipped.
fn concatenate_user_content(messages: &[solvela_protocol::ChatMessage]) -> String {
    let user_texts = messages
        .iter()
        .filter(|msg| msg.role == solvela_protocol::Role::User)
        .map(|msg| msg.content.as_str());

    let mut joined = String::new();
    for (idx, piece) in user_texts.enumerate() {
        if idx > 0 {
            joined.push(' ');
        }
        joined.push_str(piece);
    }
    joined.to_lowercase()
}
/// Maps how many of `keywords` occur in `text` onto a coarse 0.0-1.0 score.
///
/// Each keyword counts at most once. Returns 0.0 for no matches, 0.3 for
/// one, 0.6 for two, and 1.0 for three or more.
///
/// A keyword only counts when it starts a word. This stops "hi" from firing
/// on "this"/"which", "no" on "know", and "edit" on "credit". The right edge
/// is deliberately left open so stems still match inflected forms
/// ("algorithm" -> "algorithms", "deploy" -> "deployment").
///
/// NOTE(review): short keywords can still match longer words that start with
/// them ("no" -> "not", "hi" -> "his").
fn score_keyword_density(text: &str, keywords: &[&str]) -> f64 {
    let matches = keywords
        .iter()
        .filter(|kw| contains_at_word_start(text, **kw))
        .count();
    match matches {
        0 => 0.0,
        1 => 0.3,
        2 => 0.6,
        _ => 1.0,
    }
}

/// Returns `true` if `keyword` occurs in `text` at a position not preceded
/// by an alphanumeric character.
///
/// The boundary is only enforced when the keyword itself begins with an
/// alphanumeric character. Punctuation-led keywords fall back to a plain
/// substring test. An empty keyword always matches, like `str::contains("")`.
fn contains_at_word_start(text: &str, keyword: &str) -> bool {
    let first = match keyword.chars().next() {
        Some(c) => c,
        None => return true,
    };
    if !first.is_alphanumeric() {
        return text.contains(keyword);
    }
    let mut from = 0;
    while let Some(offset) = text[from..].find(keyword) {
        let start = from + offset;
        let glued = matches!(text[..start].chars().next_back(), Some(c) if c.is_alphanumeric());
        if !glued {
            return true;
        }
        // Step past this candidate's first char (a valid UTF-8 boundary) so
        // overlapping occurrences are still examined.
        from = start + first.len_utf8();
    }
    false
}
/// Heuristic 0.0-1.0 score for how code-like `text` is.
///
/// Scoring:
/// - Any backtick (inline code or a fenced block) contributes 0.4.
/// - Each distinct code keyword found adds 0.15, capped at 0.6.
///
/// Keywords are matched as plain substrings.
fn score_code_presence(text: &str) -> f64 {
    const CODE_KEYWORDS: [&str; 13] = [
        "function", "class", "def ", "fn ", "impl ", "struct ", "const ", "let ", "var ", "import",
        "return", "async", "await",
    ];

    // A ``` fence always contains a single backtick, so one check covers both.
    let fence_score = if text.contains('`') { 0.4 } else { 0.0 };
    let keyword_hits = CODE_KEYWORDS
        .iter()
        .filter(|kw| text.contains(**kw))
        .count();
    let keyword_score = (keyword_hits as f64 * 0.15).min(0.6);

    (fence_score + keyword_score).min(1.0)
}
/// Heuristic 0.0-1.0 score for mathematical content in `text`.
///
/// Scoring:
/// - Each distinct operator or math symbol present adds 0.15, capped at 0.5.
/// - Mentioning "equation", "formula" or "calculate" adds a further 0.4.
///
/// Symbols are matched as plain substrings, so hyphens and slashes in prose
/// also count.
fn score_math_presence(text: &str) -> f64 {
    const MATH_SYMBOLS: [&str; 11] = ["=", "+", "-", "*", "/", "∑", "∫", "∀", "∃", "≥", "≤"];

    let symbol_hits = MATH_SYMBOLS
        .iter()
        .filter(|sym| text.contains(**sym))
        .count();
    let symbol_score = (symbol_hits as f64 * 0.15).min(0.5);

    let mentions_math = ["equation", "formula", "calculate"]
        .iter()
        .any(|word| text.contains(*word));
    let wording_score = if mentions_math { 0.4 } else { 0.0 };

    (symbol_score + wording_score).min(1.0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use solvela_protocol::{ChatMessage, Role};

    /// Builds a plain user message with no name or tool metadata.
    fn user_message(text: &str) -> ChatMessage {
        ChatMessage {
            role: Role::User,
            content: String::from(text),
            name: None,
            tool_calls: None,
            tool_call_id: None,
        }
    }

    #[test]
    fn test_simple_greeting() {
        let conversation = [user_message("Hello!")];
        assert_eq!(classify(&conversation, false).tier, Tier::Simple);
    }

    #[test]
    fn test_code_request() {
        let conversation = [user_message(
            "Write a function that implements a distributed consensus algorithm with async/await",
        )];
        let tier = classify(&conversation, false).tier;
        assert!(tier == Tier::Complex || tier == Tier::Medium, "got {:?}", tier);
    }

    #[test]
    fn test_reasoning_request() {
        let conversation = [user_message(
            "Prove step by step that the algorithm is correct. Analyze the time complexity and evaluate whether it's optimal. \
            Compare and contrast with alternative approaches, then explain why the chosen algorithm is better. \
            Think through edge cases and reason about correctness guarantees.",
        )];
        let outcome = classify(&conversation, false);
        assert!(
            outcome.tier == Tier::Reasoning || outcome.tier == Tier::Complex,
            "got {:?} with score {:.4}",
            outcome.tier,
            outcome.score
        );
    }

    #[test]
    fn test_tool_usage_boosts_score() {
        let conversation = [user_message("Search the web for recent news")];
        let baseline = classify(&conversation, false).score;
        let boosted = classify(&conversation, true).score;
        assert!(boosted > baseline);
    }

    #[test]
    fn test_weights_sum_to_one() {
        let sum = WEIGHTS.iter().copied().sum::<f64>();
        assert!((sum - 1.0).abs() < 1e-10, "weights sum to {sum}, expected 1.0");
    }
}