use std::collections::BTreeMap;
use std::path::{Path, PathBuf};
use api_debug_lab::cases::Case;
use api_debug_lab::report::Diagnosis;
use api_debug_lab::rules::diagnose;
/// A corpus fixture paired with the name used to identify it in test failure messages.
struct LabelledCase {
/// Copy of `case.name`, taken when the fixture is loaded.
name: String,
/// The fixture as returned by `Case::load`.
case: Case,
}
/// Returns the `fixtures` directory located directly under the crate's manifest directory.
///
/// The manifest directory is captured at compile time from `CARGO_MANIFEST_DIR`.
fn fixtures_root() -> PathBuf {
    let mut root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    root.push("fixtures");
    root
}
/// Loads every labelled case found under `fixtures/cases`, in sorted path order.
///
/// A `case.json` counts as labelled when its top-level JSON object contains an
/// `expected_rule_id` key (only the key's presence is checked, not its value).
/// Files without that key are skipped.
///
/// # Panics
///
/// Panics if the cases directory is missing, if a `case.json` cannot be read, parsed as JSON
/// or loaded through `Case::load`, if a fixture path is not valid UTF-8, or if fewer than 30
/// labelled cases are found.
fn load_corpus() -> Vec<LabelledCase> {
    let fixtures = fixtures_root();
    let root = fixtures.join("cases");
    // `walk` silently ignores unreadable directories, so report a missing corpus explicitly
    // instead of letting it surface only as "found 0" in the size assertion below.
    assert!(
        root.is_dir(),
        "corpus directory {} does not exist or is not a directory",
        root.display()
    );

    let mut paths: Vec<PathBuf> = Vec::new();
    walk(&root, &mut paths);
    // Directory iteration order is platform-dependent; sort for a stable corpus order.
    paths.sort();

    let mut out: Vec<LabelledCase> = Vec::with_capacity(paths.len());
    for path in paths {
        let raw = std::fs::read_to_string(&path)
            .unwrap_or_else(|e| panic!("read {}: {e}", path.display()));
        let v: serde_json::Value =
            serde_json::from_str(&raw).unwrap_or_else(|e| panic!("parse {}: {e}", path.display()));
        let labelled = v
            .as_object()
            .is_some_and(|o| o.contains_key("expected_rule_id"));
        if !labelled {
            continue;
        }
        // `Case::load` takes the path as `&str`; name the offending file if it is not UTF-8.
        let path_str = path
            .to_str()
            .unwrap_or_else(|| panic!("fixture path {} is not valid UTF-8", path.display()));
        let case = Case::load(path_str, &fixtures)
            .unwrap_or_else(|e| panic!("load {}: {e}", path.display()));
        out.push(LabelledCase {
            name: case.name.clone(),
            case,
        });
    }
    assert!(
        out.len() >= 30,
        "expected at least 30 labelled cases in the corpus, found {}",
        out.len()
    );
    out
}
/// Recursively appends the path of every file named exactly `case.json` under `dir` to `out`.
///
/// The traversal is best-effort: a directory that cannot be listed and individual entries
/// that fail to read are skipped without reporting an error. Results are appended in
/// directory-iteration order, which is not sorted.
fn walk(dir: &Path, out: &mut Vec<PathBuf>) {
    let listing = match std::fs::read_dir(dir) {
        Ok(listing) => listing,
        Err(_) => return,
    };
    for candidate in listing.filter_map(Result::ok).map(|entry| entry.path()) {
        if candidate.is_dir() {
            walk(&candidate, out);
            continue;
        }
        let file_name = candidate.file_name().and_then(|n| n.to_str());
        if matches!(file_name, Some("case.json")) {
            out.push(candidate);
        }
    }
}
/// Rule identifiers scored by the calibration tests in this file; every case contributes one
/// (predicted, ground-truth) pair per entry.
// NOTE(review): presumably this must list every rule id `diagnose` can emit — confirm when
// adding a rule, since ids missing from this list are not scored.
const ALL_RULES: [&str; 8] = [
"auth_missing",
"bad_json_payload",
"config_dns_error",
"idempotency_collision",
"rate_limited",
"timeout_retry",
"webhook_signature_mismatch",
"webhook_timestamp_stale",
];
/// Builds a rule-id to confidence table for a single case.
///
/// Every id in `ALL_RULES` starts at `0.0`; each diagnosis then overwrites the entry for its
/// `rule_id` with its `confidence`. Diagnoses are applied in slice order, so when an id occurs
/// more than once the last occurrence wins. Ids that are not in `ALL_RULES` are inserted too.
fn rule_probabilities(diagnoses: &[Diagnosis]) -> BTreeMap<String, f32> {
    let mut table = BTreeMap::new();
    for rule in ALL_RULES.iter() {
        table.insert(String::from(*rule), 0.0_f32);
    }
    table.extend(
        diagnoses
            .iter()
            .map(|diagnosis| (diagnosis.rule_id.clone(), diagnosis.confidence)),
    );
    table
}
/// Checks that, for every corpus case, the rule id of the primary diagnosis equals the case's
/// `expected_rule_id` (with "no primary diagnosis" matching an expected `None`).
///
/// All mismatches are gathered first so that one failure lists every offending case.
#[test]
fn primary_classification_is_perfect_on_corpus() {
    let corpus = load_corpus();
    let mismatches: Vec<String> = corpus
        .iter()
        .filter_map(|entry| {
            let report = diagnose(&entry.case);
            let actual = report.primary.as_ref().map(|d| d.rule_id.clone());
            (actual != entry.case.expected_rule_id).then(|| {
                format!(
                    "case {}: expected {:?}, got {:?}",
                    entry.name, entry.case.expected_rule_id, actual
                )
            })
        })
        .collect();
    assert!(
        mismatches.is_empty(),
        "primary classification mismatches:\n{}",
        mismatches.join("\n")
    );
}
/// Checks the corpus-wide Brier score: the mean squared difference between each rule's
/// predicted confidence and its one-hot ground truth, taken over every (case, rule) pair,
/// must not exceed `THRESHOLD`.
#[test]
fn brier_score_below_threshold() {
    const THRESHOLD: f32 = 0.05;

    let corpus = load_corpus();
    let mut squared_error_total = 0.0_f32;
    for entry in &corpus {
        let report = diagnose(&entry.case);
        // Primary first, then the alternatives, so later entries win on duplicate rule ids.
        let diagnoses: Vec<Diagnosis> = report
            .primary
            .iter()
            .chain(report.also_considered.iter())
            .cloned()
            .collect();
        let predicted = rule_probabilities(&diagnoses);
        for rule in &ALL_RULES {
            let predicted_p = predicted.get(*rule).copied().unwrap_or(0.0);
            let is_expected = matches!(&entry.case.expected_rule_id, Some(g) if g == rule);
            let ground_truth = if is_expected { 1.0_f32 } else { 0.0_f32 };
            let delta = predicted_p - ground_truth;
            squared_error_total += delta * delta;
        }
    }

    // One pair is scored per (case, rule) combination.
    let pair_count = corpus.len() * ALL_RULES.len();
    let brier = squared_error_total / pair_count as f32;
    assert!(
        brier <= THRESHOLD,
        "Brier score {:.4} exceeded threshold {:.4} (over {} (case, rule) pairs)",
        brier,
        THRESHOLD,
        pair_count
    );
}
/// Checks the Brier score separately for each rule in `ALL_RULES`: the mean squared difference
/// between that rule's predicted confidence and its ground truth across the corpus must not
/// exceed `PER_RULE_THRESHOLD`.
///
/// Violations for all rules are collected and reported together.
#[test]
fn per_rule_brier_below_threshold() {
    const PER_RULE_THRESHOLD: f32 = 0.08;

    let corpus = load_corpus();
    // Per rule: (running sum of squared errors, number of scored pairs).
    let mut tallies: BTreeMap<&str, (f32, usize)> = ALL_RULES
        .iter()
        .map(|r| (*r, (0.0_f32, 0_usize)))
        .collect();

    for entry in &corpus {
        let report = diagnose(&entry.case);
        let diagnoses: Vec<Diagnosis> = report
            .primary
            .iter()
            .chain(report.also_considered.iter())
            .cloned()
            .collect();
        let predicted = rule_probabilities(&diagnoses);
        for rule in &ALL_RULES {
            let predicted_p = predicted.get(*rule).copied().unwrap_or(0.0);
            let is_expected = matches!(&entry.case.expected_rule_id, Some(g) if g == rule);
            let ground_truth = if is_expected { 1.0_f32 } else { 0.0_f32 };
            let delta = predicted_p - ground_truth;
            let tally = tallies.get_mut(rule).unwrap();
            tally.0 += delta * delta;
            tally.1 += 1;
        }
    }

    let violations: Vec<String> = ALL_RULES
        .iter()
        .filter_map(|rule| {
            let (squared_error, n) = tallies[rule];
            if n == 0 {
                return None;
            }
            let brier = squared_error / n as f32;
            (brier > PER_RULE_THRESHOLD).then(|| {
                format!(
                    "rule {} per-rule Brier {:.4} exceeds threshold {:.4} over {} pairs",
                    rule, brier, PER_RULE_THRESHOLD, n
                )
            })
        })
        .collect();
    assert!(
        violations.is_empty(),
        "per-rule Brier violations:\n{}",
        violations.join("\n")
    );
}
/// Computes the expected calibration error (ECE) of `(prediction, outcome)` pairs using
/// `num_bins` equal-width bins over `[0, 1]`.
///
/// Each prediction is clamped to `[0, 1]` and assigned to bin `floor(p * num_bins)`, with
/// `p == 1.0` folded into the last bin. The result is the sum over non-empty bins of
/// `(bin size / total pairs) * |mean prediction - mean outcome|`.
///
/// Returns `0.0` when `pairs` is empty or `num_bins` is zero.
fn expected_calibration_error(pairs: &[(f32, f32)], num_bins: usize) -> f32 {
    if num_bins == 0 || pairs.is_empty() {
        return 0.0;
    }

    // Per bin: (sum of clamped predictions, sum of outcomes, number of samples).
    let mut bins: Vec<(f32, f32, usize)> = vec![(0.0, 0.0, 0); num_bins];
    for (pred, actual) in pairs.iter().copied() {
        let p = pred.clamp(0.0, 1.0);
        let raw = (p * num_bins as f32) as usize;
        // A prediction of exactly 1.0 lands one past the end; fold it into the last bin.
        let slot = if raw == num_bins { num_bins - 1 } else { raw };
        let bin = &mut bins[slot];
        bin.0 += p;
        bin.1 += actual;
        bin.2 += 1;
    }

    let total = pairs.len() as f32;
    bins.iter()
        .filter(|bin| bin.2 > 0)
        .fold(0.0_f32, |ece, &(sum_pred, sum_actual, count)| {
            let n = count as f32;
            let mean_pred = sum_pred / n;
            let mean_actual = sum_actual / n;
            ece + (n / total) * (mean_pred - mean_actual).abs()
        })
}
/// Checks that the expected calibration error over every (case, rule) pair in the corpus,
/// computed with `NUM_BINS` bins, does not exceed `THRESHOLD`.
#[test]
fn ece_below_threshold() {
    const THRESHOLD: f32 = 0.05;
    const NUM_BINS: usize = 10;

    let corpus = load_corpus();
    let mut pairs: Vec<(f32, f32)> = Vec::new();
    for entry in &corpus {
        let report = diagnose(&entry.case);
        let diagnoses: Vec<Diagnosis> = report
            .primary
            .iter()
            .chain(report.also_considered.iter())
            .cloned()
            .collect();
        let predicted = rule_probabilities(&diagnoses);
        // One (predicted confidence, one-hot ground truth) pair per rule.
        pairs.extend(ALL_RULES.iter().map(|rule| {
            let predicted_p = predicted.get(*rule).copied().unwrap_or(0.0);
            let is_expected = matches!(&entry.case.expected_rule_id, Some(g) if g == rule);
            let ground_truth = if is_expected { 1.0_f32 } else { 0.0_f32 };
            (predicted_p, ground_truth)
        }));
    }

    let ece = expected_calibration_error(&pairs, NUM_BINS);
    assert!(
        ece <= THRESHOLD,
        "ECE {:.4} exceeded threshold {:.4} (over {} pairs, {} bins)",
        ece,
        THRESHOLD,
        pairs.len(),
        NUM_BINS
    );
}
/// Checks that cases labelled as unclassified (`expected_rule_id` is `None`) never receive a
/// confident primary diagnosis.
///
/// NOTE: despite the test's name, the check is not "confidence is exactly zero". A primary
/// diagnosis is tolerated as long as its confidence is strictly below
/// `MAX_UNCLASSIFIED_CONFIDENCE`; a report with no primary diagnosis always passes.
///
/// Every offending case is collected and reported in a single failure, matching the other
/// corpus tests in this file.
#[test]
fn unclassified_cases_have_zero_primary_confidence() {
    /// Exclusive upper bound on the primary confidence allowed for an unclassified case.
    const MAX_UNCLASSIFIED_CONFIDENCE: f32 = 0.6;

    let corpus = load_corpus();
    let mut violations: Vec<String> = Vec::new();
    for entry in corpus
        .iter()
        .filter(|entry| entry.case.expected_rule_id.is_none())
    {
        let report = diagnose(&entry.case);
        let Some(p) = &report.primary else {
            continue;
        };
        // Written as a negated `<` so that a NaN confidence is reported as a violation too.
        if !(p.confidence < MAX_UNCLASSIFIED_CONFIDENCE) {
            violations.push(format!(
                "case {} expected unclassified but rule {} fired with confidence {:.2}",
                entry.name, p.rule_id, p.confidence
            ));
        }
    }
    assert!(
        violations.is_empty(),
        "unclassified cases with a confident primary diagnosis:\n{}",
        violations.join("\n")
    );
}