use std::collections::BTreeMap;
use jammi_numerics::classification::ClassificationResult;
use jammi_numerics::ner::{Entity, NerMetrics};
use jammi_numerics::retrieval::{AggregateMetrics, QueryMetrics};
use jammi_numerics::stats::{bootstrap_ci, mann_whitney_u, Interval};
use serde::{Deserialize, Serialize};
/// Result of one embedding (retrieval) evaluation run: aggregate metrics plus per-query detail.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingEvalReport {
/// Identifier of the evaluation run; deserializes to an empty string when the field is absent.
#[serde(default)]
pub eval_run_id: String,
/// Retrieval metrics aggregated over the run.
pub aggregate: AggregateMetrics,
/// One record per evaluated query.
pub per_query: Vec<PerQueryRecord>,
}
/// Metrics and metadata recorded for a single query of an embedding evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerQueryRecord {
/// Query identifier; the key used to pair records across runs in `paired_metric`.
pub query_id: String,
/// Retrieval metrics for this query.
pub metrics: QueryMetrics,
/// Presumably `(k, recall@k)` pairs — confirm against the producer. Empty when absent on
/// deserialization.
#[serde(default)]
pub recall_at_ks: Vec<(usize, f64)>,
/// A per-query distance value; its exact meaning is not established in this file.
/// Defaults to `0.0` when absent on deserialization.
#[serde(default)]
pub distance: f64,
/// Cohort labels as `key -> value` pairs; empty when absent on deserialization.
#[serde(default)]
pub cohorts: BTreeMap<String, String>,
}
/// Rank cutoffs `k` for per-query recall; presumably the `k` values stored in
/// `PerQueryRecord::recall_at_ks` (the constant is not referenced in the code visible here).
pub const PER_QUERY_RECALL_KS: [usize; 4] = [1, 3, 5, 10];
/// Result of an inference evaluation: task-level aggregate metrics plus per-record predictions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceEvalReport {
/// Aggregate metrics, tagged by task.
pub aggregate: InferenceAggregate,
/// One entry per evaluated record.
pub per_record: Vec<PerRecordPrediction>,
}
/// Aggregate inference metrics for one task.
///
/// Serialized as an internally tagged enum: a `"task"` field carries the snake_case variant
/// name (`"classification"` or `"ner"`) alongside the wrapped metrics' own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "task", rename_all = "snake_case")]
pub enum InferenceAggregate {
/// Classification metrics.
Classification(ClassificationResult),
/// Named-entity-recognition metrics.
Ner(NerMetrics),
}
/// Predicted versus gold output for one evaluated record.
///
/// Serialized as an internally tagged enum with a `"task"` field holding the snake_case
/// variant name.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "task", rename_all = "snake_case")]
pub enum PerRecordPrediction {
/// A classification record: one predicted label and one gold label.
Classification {
record_id: String,
predicted: String,
gold: String,
},
/// An NER record: predicted entities and gold entities.
Ner {
record_id: String,
predicted: Vec<Entity>,
gold: Vec<Entity>,
},
}
/// Nominal level of the central predictive interval used for coverage and interval width (90%).
pub const CALIBRATION_INTERVAL_LEVEL: f64 = 0.90;
/// Bin count passed to `pit_calibration_error` when computing the aggregate `adaptive_ece`.
pub const CALIBRATION_ECE_BINS: usize = 10;
/// One held-out probabilistic prediction together with its observed outcome.
///
/// Serialized as an internally tagged enum: a `"shape"` field carries the snake_case variant
/// name (`"gaussian"` or `"sample"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "shape", rename_all = "snake_case")]
pub enum CalibrationPrediction {
/// Predictive distribution described by a mean and a standard deviation.
Gaussian {
record_id: String,
mean: f64,
sd: f64,
/// Observed value the prediction is scored against.
outcome: f64,
/// Cohort labels as `key -> value`; empty when absent on deserialization.
#[serde(default)]
cohorts: BTreeMap<String, String>,
},
/// Predictive distribution described by a set of draws.
Sample {
record_id: String,
draws: Vec<f64>,
/// Observed value the prediction is scored against.
outcome: f64,
/// Cohort labels as `key -> value`; empty when absent on deserialization.
#[serde(default)]
cohorts: BTreeMap<String, String>,
},
}
impl CalibrationPrediction {
    /// Record identifier, whichever variant carries it.
    fn record_id(&self) -> &str {
        match self {
            Self::Gaussian { record_id, .. } => record_id.as_str(),
            Self::Sample { record_id, .. } => record_id.as_str(),
        }
    }

    /// Observed outcome the prediction is scored against.
    fn outcome(&self) -> f64 {
        match self {
            Self::Gaussian { outcome, .. } => *outcome,
            Self::Sample { outcome, .. } => *outcome,
        }
    }

    /// Cohort labels attached to the prediction.
    fn cohorts(&self) -> &BTreeMap<String, String> {
        match self {
            Self::Gaussian { cohorts, .. } => cohorts,
            Self::Sample { cohorts, .. } => cohorts,
        }
    }
}
/// Calibration scores for a single prediction, as produced by `score_prediction`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerRecordCalibration {
pub record_id: String,
/// CRPS of the prediction against its outcome.
pub crps: f64,
/// Negative log-likelihood of the outcome under the prediction.
pub nll: f64,
/// Probability integral transform value; for sample predictions, the fraction of draws
/// less than or equal to the outcome.
pub pit: f64,
/// Whether the outcome lies inside the central predictive interval (bounds inclusive).
pub covered: bool,
/// Width (upper minus lower) of the central predictive interval.
pub interval_width: f64,
/// Cohort labels copied from the prediction; empty when absent on deserialization.
#[serde(default)]
pub cohorts: BTreeMap<String, String>,
}
/// Calibration metrics aggregated over every scored record (see `compute_calibration`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationAggregate {
/// Number of scored records.
pub n: usize,
/// Mean per-record CRPS.
pub crps: f64,
/// Mean per-record negative log-likelihood.
pub nll: f64,
/// Result of `pit_calibration_error` over the per-record PIT values.
pub adaptive_ece: f64,
/// Mean width of the central predictive interval.
pub sharpness: f64,
/// Fraction of records whose outcome fell inside the central predictive interval.
pub coverage: f64,
}
/// Calibration metrics for the records sharing one cohort `(key, value)` label.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CohortCalibration {
/// Cohort key.
pub key: String,
/// Cohort value under that key.
pub value: String,
/// Number of records in the cohort.
pub n: usize,
/// Mean CRPS over the cohort.
pub crps: f64,
/// Lower bootstrap bound on the cohort's mean CRPS; `None` for cohorts with fewer than two
/// records or when the bootstrap fails, and when absent on deserialization.
#[serde(default)]
pub crps_ci_lower: Option<f64>,
/// Upper bootstrap bound on the cohort's mean CRPS; `None` under the same conditions.
#[serde(default)]
pub crps_ci_upper: Option<f64>,
/// Fraction of the cohort's records that were covered by their predictive interval.
pub coverage: f64,
}
/// Full calibration evaluation: aggregate, per-cohort and per-record results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationEvalReport {
/// Identifier of the evaluation run; deserializes to an empty string when the field is absent.
#[serde(default)]
pub eval_run_id: String,
/// Metrics over all records.
pub aggregate: CalibrationAggregate,
/// One entry per distinct cohort `(key, value)` pair found on the records.
pub per_cohort: Vec<CohortCalibration>,
/// One entry per scored prediction.
pub per_record: Vec<PerRecordCalibration>,
}
impl CalibrationEvalReport {
    /// Significance of this report's per-record CRPS relative to `baseline`.
    ///
    /// `self` plays the treatment role and `baseline` the control role in
    /// `calibration_significance`, whose result is returned unchanged.
    pub fn significance_vs(&self, baseline: &CalibrationEvalReport) -> Option<MetricSignificance> {
        let control = &baseline.per_record;
        let candidate = &self.per_record;
        calibration_significance(control, candidate)
    }
}
/// Side-by-side embedding evaluation of several tables.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompareEvalReport {
/// One entry per evaluated table.
pub per_table: Vec<TableEvalReport>,
}
/// Embedding evaluation of one table within a comparison.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableEvalReport {
pub table_name: String,
/// The table's own embedding evaluation.
pub embedding_eval: EmbeddingEvalReport,
/// Metric deltas for this table; presumably `None` for the baseline table itself, as in
/// this file's test fixture — confirm against the producer.
pub delta: Option<AggregateDelta>,
}
/// Per-metric deltas of one table's aggregate metrics, with optional significance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregateDelta {
pub recall_at_k: MetricDelta,
pub precision_at_k: MetricDelta,
pub mrr: MetricDelta,
pub ndcg: MetricDelta,
/// Significance of each delta; `None` when it could not be computed and when the field is
/// absent on deserialization.
#[serde(default)]
pub significance: Option<DeltaSignificance>,
}
/// Change in one metric between two runs.
///
/// NOTE(review): the values are computed outside the code visible here; presumably
/// `absolute` is a difference and `relative` that difference scaled by the baseline — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct MetricDelta {
pub absolute: f64,
pub relative: f64,
}
/// Significance results for each retrieval metric, as built by `delta_significance`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct DeltaSignificance {
pub recall_at_k: MetricSignificance,
pub precision_at_k: MetricSignificance,
pub mrr: MetricSignificance,
pub ndcg: MetricSignificance,
}
/// Significance of the difference in one metric between a baseline and a treatment run.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct MetricSignificance {
/// p-value reported by `mann_whitney_u` on the baseline versus treatment values.
pub p_value: f64,
/// Lower bootstrap bound on the mean of the paired differences (treatment minus baseline).
pub ci_lower: f64,
/// Upper bootstrap bound on the mean of the paired differences (treatment minus baseline).
pub ci_upper: f64,
}
/// Iteration count passed to every `bootstrap_ci` call in this file.
pub const BOOTSTRAP_ITERATIONS: usize = 10_000;
/// Alpha passed to `bootstrap_ci`; presumably yields a 95% interval — confirm against
/// `bootstrap_ci`'s documentation.
pub const BOOTSTRAP_ALPHA: f64 = 0.05;
/// Fixed seed passed to `bootstrap_ci` so repeated computations give identical intervals.
/// The bytes spell the ASCII string "jammi_p1".
pub const BOOTSTRAP_SEED: u64 = 0x6a616d6d695f7031;
/// One metric value observed for the same query in the baseline and the treatment run.
struct PairedMetric {
baseline: f64,
treatment: f64,
}
/// Pairs one per-query metric across two runs by `query_id`.
///
/// Only queries present in both `baseline` and `treatment` produce a pair. When `treatment`
/// repeats a `query_id`, its last record wins; every matching `baseline` record yields a pair.
///
/// Pairs are returned in ascending `query_id` order rather than in emission order, so that
/// downstream consumers see the same sequence however the producer ordered its records.
/// NOTE(review): this assumes the fixed-seed `bootstrap_ci` is sensitive to input order —
/// confirm; if it already canonicalises its input the sort is merely redundant.
fn paired_metric<F>(
    baseline: &[PerQueryRecord],
    treatment: &[PerQueryRecord],
    extract: F,
) -> Vec<PairedMetric>
where
    F: Fn(&QueryMetrics) -> f64,
{
    let treatment_by_id: BTreeMap<&str, &QueryMetrics> = treatment
        .iter()
        .map(|r| (r.query_id.as_str(), &r.metrics))
        .collect();

    // Stable sort: records sharing a query_id keep their relative emission order.
    let mut ordered: Vec<&PerQueryRecord> = baseline.iter().collect();
    ordered.sort_by(|a, b| a.query_id.cmp(&b.query_id));

    ordered
        .into_iter()
        .filter_map(|b| {
            treatment_by_id
                .get(b.query_id.as_str())
                .map(|t| PairedMetric {
                    baseline: extract(&b.metrics),
                    treatment: extract(t),
                })
        })
        .collect()
}
/// Computes the significance of a paired metric.
///
/// The confidence interval is a bootstrap over the mean of `treatment - baseline`; the p-value
/// comes from `mann_whitney_u` on the baseline values versus the treatment values. Returns
/// `None` when there are no pairs or when either statistic fails.
fn metric_significance(paired: &[PairedMetric]) -> Option<MetricSignificance> {
    if paired.is_empty() {
        return None;
    }

    let mut deltas = Vec::with_capacity(paired.len());
    let mut control = Vec::with_capacity(paired.len());
    let mut candidate = Vec::with_capacity(paired.len());
    for pair in paired {
        deltas.push(pair.treatment - pair.baseline);
        control.push(pair.baseline);
        candidate.push(pair.treatment);
    }

    let interval = bootstrap_ci(
        &deltas,
        |resample| resample.iter().sum::<f64>() / resample.len() as f64,
        BOOTSTRAP_ITERATIONS,
        BOOTSTRAP_ALPHA,
        BOOTSTRAP_SEED,
    )
    .ok()?;
    let rank_test = mann_whitney_u(&control, &candidate).ok()?;

    Some(MetricSignificance {
        p_value: rank_test.p_value,
        ci_lower: interval.lower,
        ci_upper: interval.upper,
    })
}
/// Significance of every retrieval metric between two runs, pairing records by `query_id`.
///
/// Returns `None` as soon as any single metric's significance cannot be computed (for
/// example when the two runs share no query ids).
pub fn delta_significance(
    baseline: &[PerQueryRecord],
    treatment: &[PerQueryRecord],
) -> Option<DeltaSignificance> {
    let recall_pairs = paired_metric(baseline, treatment, |m| m.recall);
    let recall_at_k = metric_significance(&recall_pairs)?;

    let precision_pairs = paired_metric(baseline, treatment, |m| m.precision);
    let precision_at_k = metric_significance(&precision_pairs)?;

    let mrr_pairs = paired_metric(baseline, treatment, |m| m.mrr);
    let mrr = metric_significance(&mrr_pairs)?;

    let ndcg_pairs = paired_metric(baseline, treatment, |m| m.ndcg);
    let ndcg = metric_significance(&ndcg_pairs)?;

    Some(DeltaSignificance {
        recall_at_k,
        precision_at_k,
        mrr,
        ndcg,
    })
}
/// Standard-normal quantile at 0.95: the half-width, in standard deviations, of a Gaussian's
/// central 90% interval. Must be kept in sync with `CALIBRATION_INTERVAL_LEVEL`.
const CALIBRATION_INTERVAL_Z: f64 = 1.6448536269514722;
/// Scores one prediction: CRPS, NLL, PIT, and coverage/width of its central interval.
///
/// Gaussian predictions use the closed-form scorers and the interval
/// `mean ± CALIBRATION_INTERVAL_Z * sd`; sample predictions use the sample scorers, an
/// empirical PIT (share of draws at or below the outcome) and an empirical central interval.
///
/// # Errors
/// Any numerics failure is reported as `JammiError::Eval` naming the offending record.
fn score_prediction(
    prediction: &CalibrationPrediction,
) -> jammi_db::error::Result<PerRecordCalibration> {
    use jammi_numerics::calibration::{
        crps_gaussian, crps_sample, gaussian_nll, pit_values, sample_nll,
    };

    let id = prediction.record_id();
    let observed = prediction.outcome();
    let scoring_error = |e: jammi_numerics::NumericsError| {
        jammi_db::error::JammiError::Eval(format!(
            "calibration scoring failed for record '{}': {e}",
            id
        ))
    };

    let (crps, nll, pit, lower, upper) = match prediction {
        CalibrationPrediction::Gaussian { mean, sd, .. } => {
            let (mu, sigma) = (*mean, *sd);
            let crps = crps_gaussian(observed, mu, sigma).map_err(scoring_error)?;
            let nll = gaussian_nll(observed, mu, sigma).map_err(scoring_error)?;
            let pit = pit_values(&[mu], &[sigma], &[observed]).map_err(scoring_error)?[0];
            let half_width = CALIBRATION_INTERVAL_Z * sigma;
            (crps, nll, pit, mu - half_width, mu + half_width)
        }
        CalibrationPrediction::Sample { draws, .. } => {
            let crps = crps_sample(draws, observed).map_err(scoring_error)?;
            let nll = sample_nll(draws, observed).map_err(scoring_error)?;
            let at_or_below = draws.iter().filter(|draw| **draw <= observed).count();
            let pit = at_or_below as f64 / draws.len() as f64;
            let (lower, upper) = empirical_central_interval(draws);
            (crps, nll, pit, lower, upper)
        }
    };

    Ok(PerRecordCalibration {
        record_id: id.to_string(),
        crps,
        nll,
        pit,
        // Inclusive on both ends, i.e. `lower <= observed && observed <= upper`.
        covered: (lower..=upper).contains(&observed),
        interval_width: upper - lower,
        cohorts: prediction.cohorts().clone(),
    })
}
/// Central `CALIBRATION_INTERVAL_LEVEL` interval of `draws`, as `(lower, upper)` quantiles.
///
/// The draws are copied and sorted with `f64::total_cmp`; the input slice is left untouched.
fn empirical_central_interval(draws: &[f64]) -> (f64, f64) {
    let tail = (1.0 - CALIBRATION_INTERVAL_LEVEL) / 2.0;

    // Values equal under `total_cmp` are bit-identical, so an unstable sort yields the
    // same sequence as a stable one.
    let mut ordered = draws.to_vec();
    ordered.sort_unstable_by(f64::total_cmp);

    let lower = empirical_quantile(&ordered, tail);
    let upper = empirical_quantile(&ordered, 1.0 - tail);
    (lower, upper)
}
/// Linearly interpolated quantile of an ascending-sorted slice.
///
/// The quantile position is `q * (n - 1)`; when it falls between two elements the result is
/// their linear interpolation.
///
/// * `sorted` — values in ascending order (the caller sorts).
/// * `q` — quantile in `[0, 1]`; values outside that range are clamped to it.
///
/// Returns `NaN` for an empty slice instead of panicking on the `n - 1` underflow.
fn empirical_quantile(sorted: &[f64], q: f64) -> f64 {
    let n = sorted.len();
    if n == 0 {
        return f64::NAN;
    }
    if n == 1 {
        return sorted[0];
    }
    // Clamping keeps `hi` within bounds for q > 1; it is the identity for valid q.
    let pos = q.clamp(0.0, 1.0) * (n - 1) as f64;
    let lo = pos.floor() as usize;
    let hi = pos.ceil() as usize;
    if lo == hi {
        return sorted[lo];
    }
    let frac = pos - lo as f64;
    sorted[lo] * (1.0 - frac) + sorted[hi] * frac
}
pub fn compute_calibration(
eval_run_id: String,
predictions: &[CalibrationPrediction],
) -> jammi_db::error::Result<CalibrationEvalReport> {
if predictions.is_empty() {
return Err(jammi_db::error::JammiError::Eval(
"eval_calibration requires at least one held-out prediction".into(),
));
}
let per_record: Vec<PerRecordCalibration> = predictions
.iter()
.map(score_prediction)
.collect::<jammi_db::error::Result<_>>()?;
let n = per_record.len();
let n_f = n as f64;
let crps = per_record.iter().map(|r| r.crps).sum::<f64>() / n_f;
let nll = per_record.iter().map(|r| r.nll).sum::<f64>() / n_f;
let sharpness = per_record.iter().map(|r| r.interval_width).sum::<f64>() / n_f;
let coverage = per_record.iter().filter(|r| r.covered).count() as f64 / n_f;
let pit: Vec<f64> = per_record.iter().map(|r| r.pit).collect();
let adaptive_ece =
jammi_numerics::calibration::pit_calibration_error(&pit, CALIBRATION_ECE_BINS).map_err(
|e| jammi_db::error::JammiError::Eval(format!("calibration diagnostic failed: {e}")),
)?;
let aggregate = CalibrationAggregate {
n,
crps,
nll,
adaptive_ece,
sharpness,
coverage,
};
let per_cohort = compute_cohort_calibration(&per_record);
Ok(CalibrationEvalReport {
eval_run_id,
aggregate,
per_cohort,
per_record,
})
}
/// Groups scored records by each cohort `(key, value)` label and summarises every group.
///
/// A record with several cohort labels contributes to several groups. Groups are returned in
/// ascending `(key, value)` order; within a group, records keep their input order. The CRPS
/// confidence interval is only attempted for groups of two or more records and is left as
/// `None` when the bootstrap fails.
fn compute_cohort_calibration(per_record: &[PerRecordCalibration]) -> Vec<CohortCalibration> {
    let mut members: BTreeMap<(&str, &str), Vec<&PerRecordCalibration>> = BTreeMap::new();
    for record in per_record {
        for (key, value) in &record.cohorts {
            members
                .entry((key.as_str(), value.as_str()))
                .or_default()
                .push(record);
        }
    }

    members
        .into_iter()
        .map(|((key, value), records)| {
            let scores: Vec<f64> = records.iter().map(|r| r.crps).collect();
            let n = scores.len();
            let size = n as f64;
            let crps = scores.iter().sum::<f64>() / size;
            let coverage = records.iter().filter(|r| r.covered).count() as f64 / size;

            let interval = if n >= 2 {
                bootstrap_ci(
                    &scores,
                    |resample| resample.iter().sum::<f64>() / resample.len() as f64,
                    BOOTSTRAP_ITERATIONS,
                    BOOTSTRAP_ALPHA,
                    BOOTSTRAP_SEED,
                )
                .ok()
            } else {
                None
            };
            let (crps_ci_lower, crps_ci_upper) = match interval {
                Some(Interval { lower, upper }) => (Some(lower), Some(upper)),
                None => (None, None),
            };

            CohortCalibration {
                key: key.to_string(),
                value: value.to_string(),
                n,
                crps,
                crps_ci_lower,
                crps_ci_upper,
                coverage,
            }
        })
        .collect()
}
pub(crate) fn calibration_significance(
baseline: &[PerRecordCalibration],
treatment: &[PerRecordCalibration],
) -> Option<MetricSignificance> {
let treatment_by_id: BTreeMap<&str, f64> = treatment
.iter()
.map(|r| (r.record_id.as_str(), r.crps))
.collect();
let paired: Vec<(f64, f64)> = baseline
.iter()
.filter_map(|b| {
treatment_by_id
.get(b.record_id.as_str())
.map(|&t| (b.crps, t))
})
.collect();
if paired.is_empty() {
return None;
}
let differences: Vec<f64> = paired.iter().map(|(b, t)| t - b).collect();
let base: Vec<f64> = paired.iter().map(|(b, _)| *b).collect();
let treat: Vec<f64> = paired.iter().map(|(_, t)| *t).collect();
let Interval { lower, upper } = bootstrap_ci(
&differences,
|resample| resample.iter().sum::<f64>() / resample.len() as f64,
BOOTSTRAP_ITERATIONS,
BOOTSTRAP_ALPHA,
BOOTSTRAP_SEED,
)
.ok()?;
let p_value = mann_whitney_u(&base, &treat).ok()?.p_value;
Some(MetricSignificance {
p_value,
ci_lower: lower,
ci_upper: upper,
})
}
// Unit tests for calibration scoring, cohort slicing and CRPS significance.
#[cfg(test)]
mod calibration_tests {
use super::*;
// Builds a Gaussian prediction with no cohort labels.
fn gaussian(id: &str, mean: f64, sd: f64, outcome: f64) -> CalibrationPrediction {
CalibrationPrediction::Gaussian {
record_id: id.to_string(),
mean,
sd,
outcome,
cohorts: BTreeMap::new(),
}
}
// Builds a Gaussian prediction carrying a single cohort label `key -> value`.
fn gaussian_with_cohort(
id: &str,
mean: f64,
sd: f64,
outcome: f64,
key: &str,
value: &str,
) -> CalibrationPrediction {
let mut cohorts = BTreeMap::new();
cohorts.insert(key.to_string(), value.to_string());
CalibrationPrediction::Gaussian {
record_id: id.to_string(),
mean,
sd,
outcome,
cohorts,
}
}
// `n` standard-normal predictions whose outcomes are the standard-normal quantiles at the
// midpoints (i + 0.5) / n, i.e. outcomes spread exactly as the predictor claims.
fn calibrated_standard_normal(n: usize) -> Vec<CalibrationPrediction> {
use statrs::distribution::{ContinuousCDF, Normal};
let normal = Normal::standard();
(0..n)
.map(|i| {
let p = (i as f64 + 0.5) / n as f64;
let outcome = normal.inverse_cdf(p);
gaussian(&format!("r{i}"), 0.0, 1.0, outcome)
})
.collect()
}
// An empty prediction set is rejected with an error.
#[test]
fn empty_predictions_is_error() {
assert!(compute_calibration("run".into(), &[]).is_err());
}
// The aggregate CRPS equals the mean of the numerics crate's per-record Gaussian CRPS.
#[test]
fn gaussian_proper_scores_match_numerics() {
let preds = vec![gaussian("a", 0.0, 1.0, 0.0), gaussian("b", 0.0, 1.0, 1.0)];
let report = compute_calibration("run".into(), &preds).unwrap();
let expected_crps = (jammi_numerics::calibration::crps_gaussian(0.0, 0.0, 1.0).unwrap()
+ jammi_numerics::calibration::crps_gaussian(1.0, 0.0, 1.0).unwrap())
/ 2.0;
assert!((report.aggregate.crps - expected_crps).abs() < 1e-12);
assert_eq!(report.aggregate.n, 2);
}
// A calibrated predictor lands within 0.02 of the nominal coverage and has a small ECE.
#[test]
fn calibrated_predictor_hits_nominal_coverage() {
let preds = calibrated_standard_normal(400);
let report = compute_calibration("run".into(), &preds).unwrap();
assert!(
(report.aggregate.coverage - CALIBRATION_INTERVAL_LEVEL).abs() < 0.02,
"coverage {} should be near {CALIBRATION_INTERVAL_LEVEL}",
report.aggregate.coverage
);
assert!(
report.aggregate.adaptive_ece < 0.1,
"calibrated predictor ECE should be small: {}",
report.aggregate.adaptive_ece
);
}
// Predictions with sd 0.1 against outcomes spanning [-10, 10) cover far fewer than 20%.
#[test]
fn overconfident_predictor_undercovers() {
let preds: Vec<CalibrationPrediction> = (0..100)
.map(|i| {
let outcome = (i as f64 - 50.0) / 5.0; gaussian(&format!("r{i}"), 0.0, 0.1, outcome)
})
.collect();
let report = compute_calibration("run".into(), &preds).unwrap();
assert!(
report.aggregate.coverage < 0.2,
"overconfident predictor should badly under-cover: {}",
report.aggregate.coverage
);
}
// A 5000-draw ensemble placed at standard-normal quantiles scores a CRPS within 2e-2 of
// the closed-form Gaussian CRPS for the same outcome.
#[test]
fn sample_and_gaussian_families_agree_on_a_wide_ensemble() {
use statrs::distribution::{ContinuousCDF, Normal};
let normal = Normal::standard();
let draws: Vec<f64> = (0..5000)
.map(|i| normal.inverse_cdf((i as f64 + 0.5) / 5000.0))
.collect();
let sample = CalibrationPrediction::Sample {
record_id: "s".into(),
draws,
outcome: 0.5,
cohorts: BTreeMap::new(),
};
let gauss = gaussian("g", 0.0, 1.0, 0.5);
let sample_report =
compute_calibration("run".into(), std::slice::from_ref(&sample)).unwrap();
let gauss_report = compute_calibration("run".into(), std::slice::from_ref(&gauss)).unwrap();
assert!(
(sample_report.aggregate.crps - gauss_report.aggregate.crps).abs() < 2e-2,
"sample CRPS {} vs gaussian CRPS {}",
sample_report.aggregate.crps,
gauss_report.aggregate.crps
);
}
// Two cohorts under the same key are reported separately: the "easy" cohort (sd 1.0,
// outcomes inside (-1, 1)) covers more than the "hard" cohort (sd 0.1, outcomes in
// [-30, 28]), and a 30-record cohort carries a CI that brackets its mean CRPS.
#[test]
fn cohorts_slice_coverage_and_crps() {
let mut preds = Vec::new();
for i in 0..30 {
let outcome = ((i as f64 + 0.5) / 30.0 - 0.5) * 2.0;
preds.push(gaussian_with_cohort(
&format!("e{i}"),
0.0,
1.0,
outcome,
"tier",
"easy",
));
}
for i in 0..30 {
let outcome = (i as f64 - 15.0) * 2.0; preds.push(gaussian_with_cohort(
&format!("h{i}"),
0.0,
0.1,
outcome,
"tier",
"hard",
));
}
let report = compute_calibration("run".into(), &preds).unwrap();
assert_eq!(report.per_cohort.len(), 2);
let easy = report
.per_cohort
.iter()
.find(|c| c.value == "easy")
.unwrap();
let hard = report
.per_cohort
.iter()
.find(|c| c.value == "hard")
.unwrap();
assert_eq!(easy.n, 30);
assert_eq!(hard.n, 30);
assert!(
easy.coverage > hard.coverage,
"easy cohort {} should cover more than hard {}",
easy.coverage,
hard.coverage
);
assert!(easy.crps_ci_lower.is_some(), "n=30 cohort carries a CI");
assert!(easy.crps_ci_lower.unwrap() <= easy.crps);
assert!(easy.crps <= easy.crps_ci_upper.unwrap());
}
// A cohort with a single record reports no confidence interval.
#[test]
fn singleton_cohort_has_no_ci() {
let preds = vec![gaussian_with_cohort("a", 0.0, 1.0, 0.0, "tier", "solo")];
let report = compute_calibration("run".into(), &preds).unwrap();
let cohort = &report.per_cohort[0];
assert_eq!(cohort.n, 1);
assert!(cohort.crps_ci_lower.is_none());
}
// Two runs over the same input produce exactly equal aggregates.
#[test]
fn compute_is_deterministic() {
let preds = calibrated_standard_normal(60);
let a = compute_calibration("run".into(), &preds).unwrap();
let b = compute_calibration("run".into(), &preds).unwrap();
assert_eq!(a.aggregate.crps, b.aggregate.crps);
assert_eq!(a.aggregate.adaptive_ece, b.aggregate.adaptive_ece);
}
// Same record ids and outcomes in both runs; the baseline uses sd 0.1 and the treatment
// sd 5.0. The CI on the mean CRPS difference (treatment - baseline) must lie below zero.
#[test]
fn crps_significance_pairs_by_record_id() {
let baseline: Vec<PerRecordCalibration> = compute_calibration(
"b".into(),
&(0..40)
.map(|i| gaussian(&format!("r{i}"), 0.0, 0.1, (i as f64 - 20.0) / 4.0))
.collect::<Vec<_>>(),
)
.unwrap()
.per_record;
let treatment: Vec<PerRecordCalibration> = compute_calibration(
"t".into(),
&(0..40)
.map(|i| gaussian(&format!("r{i}"), 0.0, 5.0, (i as f64 - 20.0) / 4.0))
.collect::<Vec<_>>(),
)
.unwrap()
.per_record;
let sig = calibration_significance(&baseline, &treatment).expect("shared ids");
assert!(
sig.ci_upper < 0.0,
"treatment should be significantly better-scored: ci_upper {}",
sig.ci_upper
);
}
// Runs with disjoint record ids yield no significance result.
#[test]
fn crps_significance_none_without_shared_ids() {
let baseline = compute_calibration("b".into(), &[gaussian("a", 0.0, 1.0, 0.0)])
.unwrap()
.per_record;
let treatment = compute_calibration("t".into(), &[gaussian("z", 0.0, 1.0, 0.0)])
.unwrap()
.per_record;
assert!(calibration_significance(&baseline, &treatment).is_none());
}
}
// Pins the serde JSON shape of the report types against a committed fixture file.
#[cfg(test)]
mod projection_fixture_tests {
use super::*;
use jammi_numerics::classification::{ClassMetrics, ClassificationResult};
use jammi_numerics::ner::{Entity, NerMetrics, TypeMetrics};
use jammi_numerics::retrieval::AggregateMetrics;
// Path of the fixture, resolved relative to this crate's manifest directory.
const FIXTURE: &str = concat!(
env!("CARGO_MANIFEST_DIR"),
"/../../tests/fixtures/eval_report_projection.json"
);
// Builds a `PerQueryRecord`, converting the borrowed cohort pairs into owned strings.
fn per_query(
query_id: &str,
metrics: QueryMetrics,
recall_at_ks: Vec<(usize, f64)>,
distance: f64,
cohorts: &[(&str, &str)],
) -> PerQueryRecord {
PerQueryRecord {
query_id: query_id.to_string(),
metrics,
recall_at_ks,
distance,
cohorts: cohorts
.iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect(),
}
}
// Embedding report with two queries: one with cohorts, one with none.
fn embedding_report() -> EmbeddingEvalReport {
EmbeddingEvalReport {
eval_run_id: "run-embedding-fixture".into(),
aggregate: AggregateMetrics {
recall_at_k: 0.75,
precision_at_k: 0.5,
mrr: 0.625,
ndcg: 0.8125,
},
per_query: vec![
per_query(
"q1",
QueryMetrics {
recall: 1.0,
precision: 0.5,
mrr: 1.0,
ndcg: 1.0,
},
vec![(1, 0.5), (3, 1.0), (5, 1.0), (10, 1.0)],
0.125,
&[("family", "A"), ("split", "val")],
),
per_query(
"q2",
QueryMetrics {
recall: 0.5,
precision: 0.5,
mrr: 0.25,
ndcg: 0.625,
},
vec![(1, 0.0), (3, 0.5), (5, 0.5), (10, 0.5)],
0.0,
&[],
),
],
}
}
// Inference report for the classification task with two per-record predictions.
fn classification_report() -> InferenceEvalReport {
InferenceEvalReport {
aggregate: InferenceAggregate::Classification(ClassificationResult {
accuracy: 0.75,
f1: 0.5,
per_class: [
(
"spam".to_string(),
ClassMetrics {
precision: 1.0,
recall: 0.5,
f1: 0.625,
},
),
(
"ham".to_string(),
ClassMetrics {
precision: 0.5,
recall: 1.0,
f1: 0.625,
},
),
]
.into(),
}),
per_record: vec![
PerRecordPrediction::Classification {
record_id: "r1".into(),
predicted: "spam".into(),
gold: "spam".into(),
},
PerRecordPrediction::Classification {
record_id: "r2".into(),
predicted: "ham".into(),
gold: "spam".into(),
},
],
}
}
// Inference report for the NER task with a single record.
fn ner_report() -> InferenceEvalReport {
InferenceEvalReport {
aggregate: InferenceAggregate::Ner(NerMetrics {
precision: 0.5,
recall: 0.25,
f1: 0.375,
per_type: [(
"PER".to_string(),
TypeMetrics {
precision: 0.5,
recall: 0.25,
f1: 0.375,
support: 4,
},
)]
.into(),
}),
per_record: vec![PerRecordPrediction::Ner {
record_id: "n1".into(),
predicted: vec![Entity {
label: "PER".into(),
start: 0,
end: 5,
text: "Alice".into(),
confidence: 0.5,
}],
gold: vec![Entity {
label: "PER".into(),
start: 0,
end: 5,
text: String::new(),
confidence: 0.0,
}],
}],
}
}
// Single-query embedding report whose recall (aggregate and per-query) is `recall`.
fn compare_member(run_id: &str, recall: f64) -> EmbeddingEvalReport {
EmbeddingEvalReport {
eval_run_id: run_id.into(),
aggregate: AggregateMetrics {
recall_at_k: recall,
precision_at_k: 0.5,
mrr: 0.5,
ndcg: 0.5,
},
per_query: vec![per_query(
"q1",
QueryMetrics {
recall,
precision: 0.5,
mrr: 0.5,
ndcg: 0.5,
},
vec![(1, recall)],
0.25,
&[],
)],
}
}
// Shorthand constructor for `MetricSignificance`.
fn sig(p_value: f64, ci_lower: f64, ci_upper: f64) -> MetricSignificance {
MetricSignificance {
p_value,
ci_lower,
ci_upper,
}
}
// Comparison covering the three delta shapes: no delta, delta with significance, and
// delta without significance.
fn compare_report() -> CompareEvalReport {
let delta = |significance| AggregateDelta {
recall_at_k: MetricDelta {
absolute: 0.25,
relative: 0.5,
},
precision_at_k: MetricDelta {
absolute: 0.0,
relative: 0.0,
},
mrr: MetricDelta {
absolute: -0.125,
relative: -0.25,
},
ndcg: MetricDelta {
absolute: 0.125,
relative: 0.25,
},
significance,
};
CompareEvalReport {
per_table: vec![
TableEvalReport {
table_name: "emb_baseline".into(),
embedding_eval: compare_member("run-baseline", 0.5),
delta: None,
},
TableEvalReport {
table_name: "emb_paired".into(),
embedding_eval: compare_member("run-paired", 0.75),
delta: Some(delta(Some(DeltaSignificance {
recall_at_k: sig(0.0625, 0.125, 0.375),
precision_at_k: sig(1.0, 0.0, 0.0),
mrr: sig(0.5, -0.25, 0.0),
ndcg: sig(0.25, 0.0, 0.25),
}))),
},
TableEvalReport {
table_name: "emb_unpaired".into(),
embedding_eval: compare_member("run-unpaired", 0.25),
delta: Some(delta(None)),
},
],
}
}
// Serializes each report to a JSON value and compares it with the fixture entry stored
// under the same key; on mismatch the failure message prints the actual JSON.
#[test]
fn report_serde_shapes_match_the_committed_fixture() {
let fixture: serde_json::Value =
serde_json::from_str(&std::fs::read_to_string(FIXTURE).expect("read fixture"))
.expect("fixture is valid JSON");
for (key, actual) in [
(
"embedding",
serde_json::to_value(embedding_report()).unwrap(),
),
(
"inference_classification",
serde_json::to_value(classification_report()).unwrap(),
),
("inference_ner", serde_json::to_value(ner_report()).unwrap()),
("compare", serde_json::to_value(compare_report()).unwrap()),
] {
assert_eq!(
actual,
fixture[key],
"serde shape of '{key}' drifted from the fixture; actual:\n{}",
serde_json::to_string_pretty(&actual).unwrap()
);
}
}
}
// Unit tests for query pairing and retrieval-metric significance.
#[cfg(test)]
mod significance_tests {
use super::*;
// Builds a per-query record with the given metrics and no recall@k, distance or cohorts.
fn record(query_id: &str, recall: f64, precision: f64, mrr: f64, ndcg: f64) -> PerQueryRecord {
PerQueryRecord {
query_id: query_id.to_string(),
metrics: QueryMetrics {
recall,
precision,
mrr,
ndcg,
},
recall_at_ks: Vec::new(),
distance: 0.0,
cohorts: BTreeMap::new(),
}
}
// Twenty shared queries: every metric is 0.2 in the baseline and 0.8 in the treatment.
fn improving_pair() -> (Vec<PerQueryRecord>, Vec<PerQueryRecord>) {
let baseline: Vec<PerQueryRecord> = (0..20)
.map(|i| record(&format!("q{i}"), 0.2, 0.2, 0.2, 0.2))
.collect();
let treatment: Vec<PerQueryRecord> = (0..20)
.map(|i| record(&format!("q{i}"), 0.8, 0.8, 0.8, 0.8))
.collect();
(baseline, treatment)
}
// Queries present in only one run are dropped from the pairing.
#[test]
fn pairs_only_shared_query_ids() {
let baseline = vec![
record("a", 0.1, 0.1, 0.1, 0.1),
record("b", 0.2, 0.2, 0.2, 0.2),
];
let treatment = vec![
record("b", 0.5, 0.5, 0.5, 0.5),
record("c", 0.9, 0.9, 0.9, 0.9),
];
let paired = paired_metric(&baseline, &treatment, |m| m.recall);
assert_eq!(paired.len(), 1, "only query 'b' is shared");
assert_eq!(paired[0].baseline, 0.2);
assert_eq!(paired[0].treatment, 0.5);
}
// Runs with disjoint query ids yield no significance result.
#[test]
fn no_shared_queries_yields_none() {
let baseline = vec![record("a", 0.1, 0.1, 0.1, 0.1)];
let treatment = vec![record("z", 0.9, 0.9, 0.9, 0.9)];
assert!(delta_significance(&baseline, &treatment).is_none());
}
// Repeated computation over the same input gives exactly equal p-values and CI bounds.
#[test]
fn deterministic_under_pinned_seed() {
let (baseline, treatment) = improving_pair();
let first = delta_significance(&baseline, &treatment).expect("paired");
let second = delta_significance(&baseline, &treatment).expect("paired");
assert_eq!(first.recall_at_k.p_value, second.recall_at_k.p_value);
assert_eq!(first.recall_at_k.ci_lower, second.recall_at_k.ci_lower);
assert_eq!(first.recall_at_k.ci_upper, second.recall_at_k.ci_upper);
assert_eq!(first.ndcg.ci_upper, second.ndcg.ci_upper);
}
// A uniform lift from 0.2 to 0.8 gives a CI above zero and a p-value below 0.01.
#[test]
fn improvement_is_significant() {
let (baseline, treatment) = improving_pair();
let sig = delta_significance(&baseline, &treatment).expect("paired");
assert!(
sig.recall_at_k.ci_lower > 0.0,
"CI lower bound should exclude zero: {}",
sig.recall_at_k.ci_lower
);
assert!(
sig.recall_at_k.p_value < 0.01,
"p-value should be small for a clear lift: {}",
sig.recall_at_k.p_value
);
}
// The same records emitted in two different orders must give bit-identical p-values and
// CI bounds for every metric. The lifts vary per query so the CI is not degenerate.
#[test]
fn significance_is_invariant_to_per_query_emission_order() {
let lifts = [0.05, 0.40, -0.10, 0.30, 0.15, 0.55, -0.05, 0.25, 0.45, 0.10];
let make = |order: &[usize]| -> (Vec<PerQueryRecord>, Vec<PerQueryRecord>) {
let baseline: Vec<PerQueryRecord> = order
.iter()
.map(|&i| {
let b = 0.2 + 0.03 * i as f64;
record(&format!("q{i}"), b, b, b, b)
})
.collect();
let treatment: Vec<PerQueryRecord> = order
.iter()
.map(|&i| {
let b = 0.2 + 0.03 * i as f64;
let t = b + lifts[i];
record(&format!("q{i}"), t, t, t, t)
})
.collect();
(baseline, treatment)
};
let ascending: Vec<usize> = (0..lifts.len()).collect();
let mut shuffled = ascending.clone();
shuffled.reverse();
shuffled.rotate_left(3);
assert_ne!(ascending, shuffled, "the two emission orders must differ");
let (b1, t1) = make(&ascending);
let (b2, t2) = make(&shuffled);
let a = delta_significance(&b1, &t1).expect("paired");
let c = delta_significance(&b2, &t2).expect("paired");
assert!(
a.recall_at_k.ci_lower != 0.0 || a.recall_at_k.ci_upper != 0.0,
"fixture must be non-degenerate: CI is [{}, {}]",
a.recall_at_k.ci_lower,
a.recall_at_k.ci_upper
);
for (x, y) in [
(a.recall_at_k, c.recall_at_k),
(a.precision_at_k, c.precision_at_k),
(a.mrr, c.mrr),
(a.ndcg, c.ndcg),
] {
assert_eq!(
x.ci_lower.to_bits(),
y.ci_lower.to_bits(),
"ci_lower must be byte-identical across emission orders"
);
assert_eq!(
x.ci_upper.to_bits(),
y.ci_upper.to_bits(),
"ci_upper must be byte-identical across emission orders"
);
assert_eq!(x.p_value.to_bits(), y.p_value.to_bits());
}
}
// Comparing a run with itself gives a CI of exactly [0, 0] and a p-value above 0.99.
#[test]
fn identical_runs_ci_brackets_zero() {
let baseline: Vec<PerQueryRecord> = (0..20)
.map(|i| record(&format!("q{i}"), 0.5, 0.5, 0.5, 0.5))
.collect();
let treatment = baseline.clone();
let sig = delta_significance(&baseline, &treatment).expect("paired");
assert_eq!(sig.mrr.ci_lower, 0.0);
assert_eq!(sig.mrr.ci_upper, 0.0);
assert!(
sig.mrr.p_value > 0.99,
"identical distributions should be indistinguishable: {}",
sig.mrr.p_value
);
}
}