use logp::{entropy_nats, jensen_shannon_divergence, kl_divergence};
/// Demonstrates detecting distribution shift between a fixed reference
/// distribution and several incoming batches, using divergences from `logp`.
///
/// For each batch it prints the Jensen–Shannon divergence, KL(batch || ref)
/// and KL(ref || batch), and marks an alert when the JS divergence exceeds a
/// fixed heuristic threshold.
fn main() {
    // NOTE(review): "403" is itself a 4xx status; presumably the "4xx" bucket
    // means "all other 4xx codes" so the buckets are disjoint — confirm.
    let categories = ["2xx", "3xx", "4xx", "403", "5xx", "timeout"];
    let reference = [0.70, 0.10, 0.10, 0.03, 0.05, 0.02];
    // Tolerance passed through to every `logp` call (presumably how far an
    // input may deviate from summing to 1 — confirm against logp docs).
    let tol = 1e-12;
    // Additive smoothing applied before computing divergences so no bin is
    // exactly zero; a zero bin in the second argument of KL makes it unbounded.
    let smoothing_eps = 1e-10;
    // Heuristic JS level above which a batch is flagged.
    let js_threshold = 0.01;

    println!("Distribution shift detection via divergences");
    println!();
    println!("Reference distribution (e.g., from training data):");
    for (cat, &p) in categories.iter().zip(&reference) {
        println!(" {cat:<8} {p:.2}");
    }
    let h_ref = entropy_nats(&reference, tol)
        .expect("reference is a hard-coded valid probability distribution");
    println!(" entropy: {h_ref:.4} nats");
    println!();

    // A fixed-size array is enough: the batches are only iterated, never grown.
    let batches: [(&str, [f64; 6]); 5] = [
        ("no_drift ", [0.70, 0.10, 0.10, 0.03, 0.05, 0.02]),
        ("mild_drift ", [0.65, 0.10, 0.10, 0.04, 0.08, 0.03]),
        ("moderate ", [0.55, 0.10, 0.12, 0.05, 0.12, 0.06]),
        ("severe_drift", [0.40, 0.08, 0.15, 0.07, 0.20, 0.10]),
        ("broken_etl ", [0.10, 0.05, 0.30, 0.15, 0.25, 0.15]),
    ];

    println!(
        "{:<14} {:>8} {:>8} {:>10} {:>10}",
        "batch", "JS", "KL(b|r)", "KL(r|b)", "alert?"
    );
    // 54 = column widths (14 + 8 + 8 + 10 + 10) plus four separating spaces.
    println!("{}", "-".repeat(54));

    // The reference never changes between batches, so smooth it only once.
    let ref_smoothed = smooth(&reference, smoothing_eps);
    for (name, batch_dist) in &batches {
        let batch_smoothed = smooth(batch_dist, smoothing_eps);
        let js = jensen_shannon_divergence(&ref_smoothed, &batch_smoothed, tol)
            .expect("smoothed inputs are valid probability distributions");
        let kl_batch_ref = kl_divergence(&batch_smoothed, &ref_smoothed, tol)
            .expect("smoothed inputs are valid probability distributions");
        let kl_ref_batch = kl_divergence(&ref_smoothed, &batch_smoothed, tol)
            .expect("smoothed inputs are valid probability distributions");
        let alert = if js > js_threshold { "YES" } else { "no" };
        println!(
            "{:<14} {:>8.5} {:>8.5} {:>10.5} {:>10}",
            name, js, kl_batch_ref, kl_ref_batch, alert
        );
    }

    println!();
    println!("JS divergence is symmetric and bounded [0, ln(2) ~ 0.693].");
    println!("KL(batch || ref) measures the information lost by using the reference");
    println!("to model the batch. Note KL is asymmetric: KL(b|r) != KL(r|b).");
    println!();
    println!("In practice, alert thresholds depend on batch size and domain.");
    println!("JS > {js_threshold} is a starting heuristic; calibrate on historical data.");
}
/// Applies additive smoothing to `p` and renormalizes the result.
///
/// Every entry is increased by `eps`, and the shifted entries are then
/// divided by their total so that the output sums to 1. This holds up to
/// floating-point rounding, provided the total is non-zero.
///
/// An empty input yields an empty vector.
fn smooth(p: &[f64], eps: f64) -> Vec<f64> {
    let total: f64 = p.iter().map(|&value| value + eps).sum();
    p.iter().map(|&value| (value + eps) / total).collect()
}