use std::collections::VecDeque;
use serde::{Deserialize, Serialize};
use crate::any::AnyProvider;
use crate::provider::{LlmProvider, Message, Role};
/// Strategy used to assess the quality of a model response.
///
/// Serialized by serde with lowercase variant names (`"heuristic"`, `"judge"`).
// NOTE(review): the code that dispatches on this mode is not in this part of the
// file; presumably `Heuristic` maps to `heuristic_score` and `Judge` to
// `judge_score` -- confirm against the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ClassifierMode {
/// The default mode.
#[default]
Heuristic,
/// Alternative, non-default mode.
Judge,
}
/// Outcome of scoring a single response.
#[derive(Debug, Clone)]
pub struct QualityVerdict {
/// Quality score; `heuristic_score` produces values in `0.0..=1.0`, higher is better.
pub score: f64,
/// Whether the response should be escalated. `heuristic_score` always sets this to
/// `false`, so the decision is left to whoever consumes the verdict.
pub should_escalate: bool,
/// Human-readable explanation of how the score was reached.
pub reason: String,
}
/// Recent quality scores for one provider, kept as a FIFO window.
///
/// New scores are appended at the back and old ones are evicted from the front
/// (see `push`). Serializable so the history can be persisted.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ProviderQualityHistory {
// Oldest score at the front, newest at the back.
scores: VecDeque<f64>,
}
impl ProviderQualityHistory {
    /// Appends `score` as the newest entry, then evicts the oldest entries until at
    /// most `window` scores remain.
    ///
    /// Eviction loops instead of removing a single entry, so a history that already
    /// holds more than `window` scores (for example one deserialized from state that
    /// was written with a larger window) is trimmed to the window in one call rather
    /// than staying oversized forever.
    ///
    /// A `window` of `0` leaves the history empty after every call.
    pub fn push(&mut self, score: f64, window: usize) {
        self.scores.push_back(score);
        while self.scores.len() > window {
            self.scores.pop_front();
        }
    }

    /// Returns the arithmetic mean of the stored scores, or the neutral value `0.5`
    /// when no scores are stored.
    #[must_use]
    pub fn mean(&self) -> f64 {
        if self.scores.is_empty() {
            return 0.5;
        }
        // The element count is converted to f64 only to act as the divisor.
        #[allow(clippy::cast_precision_loss)]
        let len = self.scores.len() as f64;
        self.scores.iter().sum::<f64>() / len
    }
}
/// Per-provider quality tracking state.
///
/// Note that the derived `Default` yields `window_size == 0`; with that value
/// `record` evicts every score right after storing it, so nothing is retained.
/// Use `CascadeState::new` to pick a real window.
#[derive(Debug, Clone, Default)]
pub struct CascadeState {
/// Score history keyed by the provider name passed to `record`.
pub provider_quality: std::collections::HashMap<String, ProviderQualityHistory>,
/// Window length handed to `ProviderQualityHistory::push` on every `record` call.
pub window_size: usize,
}
impl CascadeState {
#[must_use]
pub fn new(window_size: usize) -> Self {
Self {
provider_quality: std::collections::HashMap::new(),
window_size,
}
}
pub fn record(&mut self, provider: &str, score: f64) {
let window = self.window_size;
self.provider_quality
.entry(provider.to_owned())
.or_default()
.push(score, window);
}
#[must_use]
pub fn mean(&self, provider: &str) -> f64 {
self.provider_quality
.get(provider)
.map_or(0.5, ProviderQualityHistory::mean)
}
}
/// Scores `response` with cheap text heuristics; no model call is involved.
///
/// * Responses whose trimmed text is shorter than 10 bytes are rejected up front
///   with a score of `0.0` (empty) or `0.1` (non-empty).
/// * Otherwise the score is the equal-weight blend of the length and coherence
///   signals, multiplied by `0.3` when more than half of the word trigrams are
///   repeated.
///
/// The returned score lies in `0.0..=1.0`. `should_escalate` is always `false`.
/// `reason` names the first failing signal (length, then repetition, then
/// coherence), or summarises all three signals when none fails.
#[must_use]
pub fn heuristic_score(response: &str) -> QualityVerdict {
    let trimmed = response.trim();
    if trimmed.len() < 10 {
        return QualityVerdict {
            score: if trimmed.is_empty() { 0.0 } else { 0.1 },
            should_escalate: false,
            reason: String::from("response too short or empty"),
        };
    }

    let length_score = length_signal(response);
    let coherence_score = coherence_signal(response);
    let repeated_fraction = repetition_ratio(response);
    // Inverted so that, like the other signals, higher means better.
    let repetition_score = 1.0 - repeated_fraction;

    let blended = (length_score * 0.50 + coherence_score * 0.50).clamp(0.0, 1.0);
    let penalised = if repeated_fraction > 0.5 {
        blended * 0.3
    } else {
        blended
    };
    let score = penalised.clamp(0.0, 1.0);

    let reason = match () {
        () if length_score < 0.3 => String::from("response too short or empty"),
        () if repetition_score < 0.5 => String::from("high trigram repetition detected"),
        () if coherence_score < 0.3 => String::from("incoherent / fragmented response"),
        () => format!(
            "heuristic ok (length={length_score:.2}, rep={repetition_score:.2}, coh={coherence_score:.2})"
        ),
    };

    QualityVerdict {
        score,
        should_escalate: false,
        reason,
    }
}
/// Maps the byte length of the trimmed response onto a coarse length score.
///
/// Tiers (inclusive upper bounds): 0 bytes -> `0.0`, up to 10 -> `0.1`,
/// up to 30 -> `0.3`, up to 50 -> `0.6`, anything longer -> `1.0`.
fn length_signal(response: &str) -> f64 {
    // (inclusive upper bound in bytes, score), in ascending order.
    const TIERS: [(usize, f64); 4] = [(0, 0.0), (10, 0.1), (30, 0.3), (50, 0.6)];

    let byte_len = response.trim().len();
    TIERS
        .iter()
        .find(|(upper, _)| byte_len <= *upper)
        .map_or(1.0, |(_, score)| *score)
}
/// Fraction of word trigrams that belong to a trigram occurring more than once.
///
/// Words are whitespace-separated. Responses with fewer than four words report
/// `0.0`. The result lies in `0.0..=1.0`; `1.0` means every trigram is a repeat.
fn repetition_ratio(response: &str) -> f64 {
    let tokens: Vec<&str> = response.split_whitespace().collect();
    if tokens.len() < 4 {
        return 0.0;
    }

    // Occurrence count per distinct three-word window.
    let mut occurrences = std::collections::HashMap::<&[&str], usize>::new();
    for window in tokens.windows(3) {
        *occurrences.entry(window).or_default() += 1;
    }

    // One pass: all occurrences, and those belonging to windows seen more than once.
    let (total, repeated) = occurrences
        .values()
        .fold((0_usize, 0_usize), |(all, dup), &count| {
            (all + count, if count > 1 { dup + count } else { dup })
        });
    if total == 0 {
        return 0.0;
    }

    #[allow(clippy::cast_precision_loss)]
    let ratio = repeated as f64 / total as f64;
    ratio.clamp(0.0, 1.0)
}
/// Rough coherence score derived from the average number of words per sentence.
///
/// Sentences are the non-blank pieces obtained by splitting the trimmed text on
/// `.`, `!`, `?` and newlines. Results: `0.0` for blank text, `0.2` for fewer than
/// three words, `0.4` when sentences average fewer than three words, else `1.0`.
fn coherence_signal(response: &str) -> f64 {
    let body = response.trim();

    let words = body.split_whitespace().count();
    match words {
        0 => return 0.0,
        1 | 2 => return 0.2,
        _ => {}
    }

    let sentences = body
        .split(|c: char| matches!(c, '.' | '!' | '?' | '\n'))
        .filter(|piece| piece.chars().any(|c| !c.is_whitespace()))
        .count();

    // Text made only of terminators has no sentences; treat it as a single one so
    // the average equals the word count.
    let divisor = sentences.max(1);
    #[allow(clippy::cast_precision_loss)]
    let words_per_sentence = words as f64 / divisor as f64;

    if words_per_sentence >= 3.0 { 1.0 } else { 0.4 }
}
/// Asks the `judge` provider to rate `response` and converts its reply to a score.
///
/// The response text is embedded verbatim in a single user message that asks for a
/// rating from 0 to 10. Returns `None` when the chat call fails or when no number
/// can be extracted from the reply; otherwise the value produced by
/// `parse_judge_score`.
pub async fn judge_score(judge: &AnyProvider, response: &str) -> Option<f64> {
    let prompt = format!(
        "Rate the following AI response on a scale from 0 to 10, \
         where 0 is completely useless or degenerate and 10 is high quality and coherent. \
         Reply with ONLY a single number (integer or decimal, e.g. 7 or 8.5). \
         Do not add any explanation.\n\nResponse to rate:\n{response}"
    );
    let messages = vec![Message::from_legacy(Role::User, prompt)];

    // A failed judge call is not an error for the caller: it just yields no score.
    match judge.chat(&messages).await {
        Ok(reply) => parse_judge_score(&reply),
        Err(_) => None,
    }
}
/// Extracts the first rating found in a judge reply and rescales it from the
/// 0-10 scale to `0.0..=1.0`.
///
/// The reply is scanned token by token (whitespace-separated). Within a token only
/// the first contiguous run of ASCII digits and `.` is considered, with trailing
/// dots removed, so surrounding punctuation and "out of ten" suffixes do not leak
/// into the number:
///
/// * `"7/10"` is read as 7 (not 710), `"8.5/10"` as 8.5;
/// * `"3.5."` and `"10."` (sentence-final period) are read as 3.5 and 10.
///
/// Tokens whose numeric run does not parse (e.g. `"1.2.3"`) or overflows to
/// infinity are skipped and scanning continues with the next token. A leading sign
/// is not part of the run, so values are never negative. Ratings above 10 clamp to
/// `1.0`. Returns `None` when no token yields a number.
fn parse_judge_score(reply: &str) -> Option<f64> {
    fn is_numeric(c: char) -> bool {
        c.is_ascii_digit() || c == '.'
    }

    reply.split_whitespace().find_map(|token| {
        // `find` returns byte offsets on char boundaries, so the slices are safe.
        let start = token.find(is_numeric)?;
        let rest = &token[start..];
        let end = rest
            .find(|c: char| !is_numeric(c))
            .unwrap_or(rest.len());
        let candidate = rest[..end].trim_end_matches('.');

        candidate
            .parse::<f64>()
            .ok()
            .filter(|n| n.is_finite())
            .map(|n| (n / 10.0).clamp(0.0, 1.0))
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Heuristic score of `text`, dropping the rest of the verdict.
    fn score_of(text: &str) -> f64 {
        heuristic_score(text).score
    }

    /// Parses `reply`, panicking when no score can be extracted.
    fn parsed(reply: &str) -> f64 {
        parse_judge_score(reply).unwrap()
    }

    // --- heuristic_score ---------------------------------------------------

    #[test]
    fn empty_response_scores_zero() {
        let score = score_of("");
        assert!(score < 0.15, "empty response score={score}");
    }

    #[test]
    fn very_short_response_scores_low() {
        let score = score_of("ok");
        assert!(score < 0.5, "short response score={score}");
    }

    #[test]
    fn normal_response_scores_high() {
        let text = "The answer to your question is straightforward. \
                    First, consider the context. Then analyze the options available. \
                    Finally, choose the best approach for your use case.";
        let score = score_of(text);
        assert!(score >= 0.7, "normal response score={score}");
    }

    #[test]
    fn highly_repetitive_response_scores_low() {
        // Thirty copies of the same word separated by single spaces.
        let text = vec!["word"; 30].join(" ");
        let score = score_of(&text);
        assert!(score < 0.5, "repetitive response score={score}");
    }

    #[test]
    fn heuristic_score_never_panics_on_unicode() {
        let long_input = "a ".repeat(1000);
        let inputs = [
            "Привет мир!",
            "こんにちは",
            "🦀🦀🦀",
            "\0\0\0",
            long_input.as_str(),
        ];
        for input in inputs {
            let score = score_of(input);
            assert!(
                (0.0..=1.0).contains(&score),
                "score out of range for input: {input:?}"
            );
        }
    }

    // --- CascadeState ------------------------------------------------------

    #[test]
    fn cascade_state_records_and_retrieves() {
        let mut state = CascadeState::new(5);
        for score in [0.3, 0.7] {
            state.record("ollama", score);
        }
        assert!((state.mean("ollama") - 0.5).abs() < 0.01);
    }

    #[test]
    fn cascade_state_window_evicts_old_scores() {
        let mut state = CascadeState::new(3);
        // The fourth score pushes the first one out of the three-slot window.
        for score in [0.0, 0.0, 0.0, 1.0] {
            state.record("p", score);
        }
        let mean = state.mean("p");
        assert!(
            (mean - (1.0 / 3.0)).abs() < 0.01,
            "expected ~0.333, got {mean}"
        );
    }

    #[test]
    fn cascade_state_unknown_provider_returns_neutral() {
        let state = CascadeState::new(10);
        let mean = state.mean("unknown");
        assert!((mean - 0.5).abs() < f64::EPSILON);
    }

    // --- ClassifierMode ----------------------------------------------------

    #[test]
    fn classifier_mode_serde_roundtrip() {
        let cases = [
            (ClassifierMode::Heuristic, r#""heuristic""#),
            (ClassifierMode::Judge, r#""judge""#),
        ];
        for (mode, expected_json) in cases {
            let json = serde_json::to_string(&mode).unwrap();
            assert_eq!(json, expected_json);
            let back: ClassifierMode = serde_json::from_str(&json).unwrap();
            assert_eq!(back, mode);
        }
    }

    #[test]
    fn classifier_mode_default_is_heuristic() {
        assert_eq!(ClassifierMode::default(), ClassifierMode::Heuristic);
    }

    // --- parse_judge_score -------------------------------------------------

    #[test]
    fn parse_judge_score_integer() {
        let score = parsed("7");
        assert!(
            (score - 0.7).abs() < f64::EPSILON,
            "expected 0.7, got {score}"
        );
    }

    #[test]
    fn parse_judge_score_decimal() {
        let score = parsed("8.5");
        assert!((score - 0.85).abs() < 1e-9, "expected 0.85, got {score}");
    }

    #[test]
    fn parse_judge_score_with_surrounding_text() {
        let score = parsed("I would rate this response a 6 out of 10.");
        assert!(
            (score - 0.6).abs() < f64::EPSILON,
            "expected 0.6, got {score}"
        );
    }

    #[test]
    fn parse_judge_score_ten_clamps_to_one() {
        let score = parsed("10");
        assert!(
            (score - 1.0).abs() < f64::EPSILON,
            "expected 1.0, got {score}"
        );
    }

    #[test]
    fn parse_judge_score_zero_is_valid() {
        let score = parsed("0");
        assert!(score.abs() < f64::EPSILON, "expected 0.0, got {score}");
    }

    #[test]
    fn parse_judge_score_garbage_returns_none() {
        for reply in ["no number here", ""] {
            assert!(parse_judge_score(reply).is_none());
        }
    }

    // --- repetition_ratio --------------------------------------------------

    #[test]
    fn repetition_ratio_no_repetition() {
        let ratio = repetition_ratio("the quick brown fox jumps over the lazy dog");
        assert!(ratio < 0.3, "expected low repetition, got {ratio}");
    }

    #[test]
    fn repetition_ratio_full_repetition() {
        let ratio = repetition_ratio("abc abc abc abc abc abc abc abc abc abc");
        assert!(ratio > 0.5, "expected high repetition, got {ratio}");
    }
}