use std::collections::HashMap;
use async_trait::async_trait;
use rand::Rng;
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use uuid::Uuid;
use crate::types::RequestLog;
/// Coarse risk classification derived from the share of degraded samples.
///
/// Serialized with lowercase variant names (`"low"`, `"medium"`, `"high"`).
/// The percentage thresholds live in `RiskBand::from_degraded_pct`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RiskBand {
/// Degraded percentage of at most 5.
Low,
/// Degraded percentage above 5 and at most 15.
Medium,
/// Any other degraded percentage.
High,
}
impl RiskBand {
    /// Maps a degraded percentage onto a [`RiskBand`].
    ///
    /// Both thresholds are inclusive: values `<= 5.0` are [`RiskBand::Low`],
    /// values `<= 15.0` are [`RiskBand::Medium`], and everything else is
    /// [`RiskBand::High`]. A NaN input fails both comparisons and therefore
    /// lands in `High`.
    #[must_use]
    pub fn from_degraded_pct(p: f64) -> Self {
        match p {
            pct if pct <= 5.0 => Self::Low,
            pct if pct <= 15.0 => Self::Medium,
            _ => Self::High,
        }
    }
}
/// Outcome a judge assigns to one sampled request.
///
/// Serialized with lowercase variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum JudgeVerdict {
/// Counted in `QualityResult::acceptable_count`.
Acceptable,
/// Counted in `QualityResult::degraded_count`.
Degraded,
/// Counted in `QualityResult::unclear_count`; excluded from the
/// denominator of the degraded percentage.
Unclear,
}
/// Judge result for a single sampled request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SampleScore {
/// Id of the request that was judged.
pub request_id: Uuid,
/// Verdict returned by the judge.
pub verdict: JudgeVerdict,
/// Judge-provided explanation; `score_quality` truncates it to at most
/// 200 bytes on a character boundary.
pub reason: String,
}
/// Aggregate outcome of a `score_quality` run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityResult {
/// Number of requests that were actually judged (length of
/// `sampled_examples`).
pub sample_size: u32,
/// Number of `Acceptable` verdicts.
pub acceptable_count: u32,
/// Number of `Degraded` verdicts.
pub degraded_count: u32,
/// Number of `Unclear` verdicts.
pub unclear_count: u32,
/// `degraded / (acceptable + degraded) * 100`, or `0.0` when no sample
/// was classified as acceptable or degraded.
pub degraded_pct: f64,
/// Band derived from `degraded_pct`.
pub risk_band: RiskBand,
/// Per-request verdicts, one entry per judged request.
pub sampled_examples: Vec<SampleScore>,
/// Human-readable warnings about the reliability of the result.
pub caveats: Vec<String>,
}
/// Inputs controlling a `score_quality` run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityConfig {
/// Must be `true`; otherwise `score_quality` fails with
/// `QualityError::BodyLoggingDisabled`.
pub body_logging_enabled: bool,
/// Number of requests to sample; also used to project the judge cost.
pub total_samples: u32,
/// Upper bound for the projected judge cost, in USD.
pub budget_usd: f64,
/// Cost of one judge call, in USD.
pub cost_per_judge_call_usd: f64,
/// Seed passed to the sampler's random number generator.
pub seed: u64,
}
/// Failures reported by quality scoring.
#[derive(Debug, Error)]
pub enum QualityError {
/// `QualityConfig::body_logging_enabled` was `false`.
#[error("body logging not opted in by org — Tier 3 quality scoring requires raw bodies")]
BodyLoggingDisabled,
/// `cost_per_judge_call_usd * total_samples` is greater than the budget.
#[error("projected judge cost ${cost:.4} exceeds budget ${budget:.4}")]
OverBudget {
/// Projected cost in USD.
cost: f64,
/// Configured budget in USD.
budget: f64,
},
/// Error message produced by a `JudgeProvider`.
#[error("judge: {0}")]
Judge(String),
/// No sampled request ended up being judged.
#[error("no sampled requests carry both prompt + response bodies")]
NoScorable,
}
/// Pluggable judge that compares an original response with a proposed one.
///
/// Implementations must be `Send + Sync`.
#[async_trait]
pub trait JudgeProvider: Send + Sync {
/// Judges one sample.
///
/// `score_quality` passes the request body as `input_body`, the logged
/// response body as `original_response`, and the caller-supplied
/// alternative as `proposed_response`.
///
/// Returns the verdict together with a free-form reason string.
///
/// # Errors
///
/// Returns a `QualityError` when the judge cannot produce a verdict;
/// `score_quality` propagates it unchanged.
async fn judge(
&self,
input_body: &str,
original_response: &str,
proposed_response: &str,
) -> Result<(JudgeVerdict, String), QualityError>;
}
/// `JudgeProvider` that returns the same canned answer for every call.
pub struct MockJudge {
/// Verdict returned from every `judge` call.
pub verdict: JudgeVerdict,
/// Reason returned (cloned) from every `judge` call.
pub reason: String,
}
#[async_trait]
impl JudgeProvider for MockJudge {
    /// Ignores all three bodies and answers with the configured verdict and
    /// a fresh copy of the configured reason. Never fails.
    async fn judge(
        &self,
        _input: &str,
        _orig: &str,
        _prop: &str,
    ) -> Result<(JudgeVerdict, String), QualityError> {
        let canned = (self.verdict, self.reason.clone());
        Ok(canned)
    }
}
/// Computes the stratum key for a request: its tag paired with a coarse
/// input-size bucket.
///
/// Buckets by `input_tokens`: `0..=500` is `"small"`, `501..=4000` is
/// `"medium"`, and any other value is `"large"`.
fn stratify(req: &RequestLog) -> (Option<String>, &'static str) {
    let tokens = req.input_tokens;
    let size_bucket = if (0..=500).contains(&tokens) {
        "small"
    } else if (501..=4000).contains(&tokens) {
        "medium"
    } else {
        "large"
    };
    (req.tag.clone(), size_bucket)
}
/// Draws a deterministic, proportionally stratified sample of request ids.
///
/// Requests are grouped by [`stratify`] (tag plus input-size bucket). Each
/// stratum receives a share of the sample proportional to its size, using
/// largest-remainder allocation so the shares always add up to
/// `min(n, requests.len())`. Independent per-stratum rounding would not
/// guarantee that: many small strata could each round down to zero and yield
/// an empty sample, or round up and overshoot `n`.
///
/// Within a stratum, ids are picked by a seeded Fisher–Yates shuffle, so the
/// same `requests`, `n` and `seed` always produce the same result regardless
/// of the order of `requests`.
///
/// Returns the chosen ids sorted ascending and de-duplicated. The result is
/// empty when `n == 0` or `requests` is empty, and may be shorter than
/// `min(n, requests.len())` only if `requests` contains duplicate ids.
#[must_use]
pub fn stratified_sample(requests: &[RequestLog], n: u32, seed: u64) -> Vec<Uuid> {
    if n == 0 || requests.is_empty() {
        return Vec::new();
    }

    // Order by id first so the outcome does not depend on the caller's slice order.
    let mut sorted: Vec<&RequestLog> = requests.iter().collect();
    sorted.sort_by_key(|r| r.id);

    let mut by_stratum: HashMap<(Option<String>, &'static str), Vec<Uuid>> = HashMap::new();
    for r in &sorted {
        by_stratum.entry(stratify(r)).or_default().push(r.id);
    }

    // HashMap iteration order is unspecified; visit strata in sorted key order
    // so the RNG stream is consumed reproducibly.
    let mut keys: Vec<_> = by_stratum.keys().cloned().collect();
    keys.sort();

    // Largest-remainder allocation in integer arithmetic. u128 keeps
    // `stratum_len * target` from overflowing on any platform.
    let total = requests.len() as u128;
    let target = u128::from(n).min(total);
    let mut allocs: Vec<usize> = Vec::with_capacity(keys.len());
    let mut remainders: Vec<(u128, usize)> = Vec::with_capacity(keys.len());
    let mut assigned: u128 = 0;
    for (pos, k) in keys.iter().enumerate() {
        let len = by_stratum[k].len();
        let share = len as u128 * target;
        let base = share / total;
        assigned += base;
        // `target <= total` implies `base <= len`, so the conversion cannot fail.
        allocs.push(usize::try_from(base).unwrap_or(len));
        remainders.push((share % total, pos));
    }
    // Hand the leftover seats to the largest fractional parts; ties go to the
    // earlier key so the allocation stays deterministic.
    remainders.sort_by(|a, b| b.0.cmp(&a.0).then(a.1.cmp(&b.1)));
    let leftover = usize::try_from(target - assigned).unwrap_or(usize::MAX);
    for &(_, pos) in remainders.iter().take(leftover) {
        allocs[pos] += 1;
    }

    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let mut out = Vec::new();
    for (k, alloc) in keys.iter().zip(allocs) {
        let stratum = &by_stratum[k];
        let alloc = alloc.min(stratum.len());
        if alloc == 0 {
            continue;
        }
        // Fisher–Yates shuffle of the indices, then keep the first `alloc`.
        let mut idx: Vec<usize> = (0..stratum.len()).collect();
        for i in (1..idx.len()).rev() {
            let j = rng.gen_range(0..=i);
            idx.swap(i, j);
        }
        out.extend(idx.into_iter().take(alloc).map(|i| stratum[i]));
    }

    out.sort();
    out.dedup();
    // Defensive only: the allocation above never hands out more than `n` ids.
    out.truncate(n as usize);
    out
}
pub async fn score_quality<F>(
requests: &[RequestLog],
config: &QualityConfig,
judge: &dyn JudgeProvider,
proposed_response_for: F,
) -> Result<QualityResult, QualityError>
where
F: Fn(&Uuid) -> Option<String>,
{
if !config.body_logging_enabled {
return Err(QualityError::BodyLoggingDisabled);
}
let projected_cost = config.cost_per_judge_call_usd * f64::from(config.total_samples);
if projected_cost > config.budget_usd {
return Err(QualityError::OverBudget {
cost: projected_cost,
budget: config.budget_usd,
});
}
let sampled_ids = stratified_sample(requests, config.total_samples, config.seed);
let by_id: HashMap<Uuid, &RequestLog> = requests.iter().map(|r| (r.id, r)).collect();
let mut scores = Vec::new();
let mut acceptable: u32 = 0;
let mut degraded: u32 = 0;
let mut unclear: u32 = 0;
for id in &sampled_ids {
let Some(req) = by_id.get(id) else { continue };
let Some(input) = req.body.as_ref() else {
continue;
};
let Some(original) = req.response_body.as_ref() else {
continue;
};
let Some(proposed) = proposed_response_for(id) else {
continue;
};
let (verdict, mut reason) = judge.judge(input, original, &proposed).await?;
if reason.len() > 200 {
let mut cut = 200;
while cut > 0 && !reason.is_char_boundary(cut) {
cut -= 1;
}
reason.truncate(cut);
}
match verdict {
JudgeVerdict::Acceptable => acceptable += 1,
JudgeVerdict::Degraded => degraded += 1,
JudgeVerdict::Unclear => unclear += 1,
}
scores.push(SampleScore {
request_id: *id,
verdict,
reason,
});
}
if scores.is_empty() {
return Err(QualityError::NoScorable);
}
let total_classified = f64::from(acceptable + degraded);
let degraded_pct = if total_classified > 0.0 {
(f64::from(degraded) / total_classified) * 100.0
} else {
0.0
};
let risk_band = RiskBand::from_degraded_pct(degraded_pct);
let mut caveats = Vec::new();
let unclear_share = f64::from(unclear) / scores.len() as f64;
if unclear_share > 0.20 {
caveats.push(format!(
"{:.0}% of sampled requests were Unclear — the judge couldn't classify. \
Risk band may be unreliable; consider a stronger judge model.",
unclear_share * 100.0
));
}
if scores.len() < 30 {
caveats.push(format!(
"Small quality sample ({} scored) — risk band has wide uncertainty.",
scores.len()
));
}
Ok(QualityResult {
sample_size: scores.len() as u32,
acceptable_count: acceptable,
degraded_count: degraded,
unclear_count: unclear,
degraded_pct,
risk_band,
sampled_examples: scores,
caveats,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins both inclusive boundaries of `RiskBand::from_degraded_pct` and the
    /// values immediately above them.
    #[test]
    fn risk_band_thresholds() {
        let cases = [
            (0.0, RiskBand::Low),
            (5.0, RiskBand::Low),
            (5.0001, RiskBand::Medium),
            (15.0, RiskBand::Medium),
            (15.0001, RiskBand::High),
            (100.0, RiskBand::High),
        ];
        for (pct, expected) in cases {
            assert_eq!(RiskBand::from_degraded_pct(pct), expected);
        }
    }
}