Skip to main content

kaizen/experiment/
stats.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2//! Non-parametric stats for experiment reports.
3//!
4//! Effect size = median(treatment) − median(control). CI via
5//! percentile bootstrap (default 10k resamples, 95%). Winsorize p1/p99
6//! before resampling to blunt skew.
7
8use rand::rngs::SmallRng;
9use rand::{RngExt, SeedableRng};
10use serde::{Deserialize, Serialize};
11
/// Bootstrap resample count matching the module's documented default (10k).
/// Not referenced in this file; presumably what callers pass as
/// `summarize`'s `resamples` argument — TODO confirm against callers.
pub const DEFAULT_RESAMPLES: u32 = 10_000;
/// Minimum per-group sample size. `summarize` sets `small_sample_warning`
/// when either group has fewer raw observations than this.
pub const MIN_SAMPLE: usize = 30;
14
/// Result of comparing one metric between a control and a treatment group,
/// as built by `summarize`.
///
/// Every `Option` field is `None` when its statistic cannot be computed
/// (for example, when a group is empty).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Summary {
    /// Number of control observations passed in (`control.len()`).
    pub n_control: usize,
    /// Number of treatment observations passed in (`treatment.len()`).
    pub n_treatment: usize,
    /// Median of the control values after p1/p99 winsorization.
    pub median_control: Option<f64>,
    /// Median of the treatment values after p1/p99 winsorization.
    pub median_treatment: Option<f64>,
    /// Arithmetic mean of the winsorized control values.
    pub mean_control: Option<f64>,
    /// Arithmetic mean of the winsorized treatment values.
    pub mean_treatment: Option<f64>,
    /// `median_treatment - median_control`.
    pub delta_median: Option<f64>,
    /// `100 * delta_median / median_control`; `None` when the control median is 0.
    pub delta_pct: Option<f64>,
    /// Lower bound of the 95% percentile-bootstrap CI for `delta_median`.
    pub ci95_lo: Option<f64>,
    /// Upper bound of the 95% percentile-bootstrap CI for `delta_median`.
    pub ci95_hi: Option<f64>,
    /// True when either group has fewer than `MIN_SAMPLE` observations.
    pub small_sample_warning: bool,
}
29
30/// Pure stats for a metric. Deterministic given `seed`.
31pub fn summarize(control: &[f64], treatment: &[f64], seed: u64, resamples: u32) -> Summary {
32    let c = winsorize(control, 0.01, 0.99);
33    let t = winsorize(treatment, 0.01, 0.99);
34    let median_c = median(&c);
35    let median_t = median(&t);
36    let mean_c = mean(&c);
37    let mean_t = mean(&t);
38    let delta = match (median_c, median_t) {
39        (Some(a), Some(b)) => Some(b - a),
40        _ => None,
41    };
42    let delta_pct = match (median_c, delta) {
43        (Some(a), Some(d)) if a != 0.0 => Some(100.0 * d / a),
44        _ => None,
45    };
46    let (lo, hi) = if c.is_empty() || t.is_empty() {
47        (None, None)
48    } else {
49        bootstrap_ci(&c, &t, seed, resamples)
50    };
51    Summary {
52        n_control: control.len(),
53        n_treatment: treatment.len(),
54        median_control: median_c,
55        median_treatment: median_t,
56        mean_control: mean_c,
57        mean_treatment: mean_t,
58        delta_median: delta,
59        delta_pct,
60        ci95_lo: lo,
61        ci95_hi: hi,
62        small_sample_warning: control.len().min(treatment.len()) < MIN_SAMPLE,
63    }
64}
65
66/// Clamp values to `[p_lo quantile, p_hi quantile]`.
67pub fn winsorize(xs: &[f64], p_lo: f64, p_hi: f64) -> Vec<f64> {
68    if xs.is_empty() {
69        return Vec::new();
70    }
71    let Some(lo) = quantile(xs, p_lo) else {
72        return xs.to_vec();
73    };
74    let Some(hi) = quantile(xs, p_hi) else {
75        return xs.to_vec();
76    };
77    xs.iter().map(|v| v.clamp(lo, hi)).collect()
78}
79
/// Rounded-index quantile of `xs`.
///
/// Returns the element at index `round((n - 1) * p)` of the sorted values,
/// with the index clamped to the last element. Returns `None` for an empty
/// slice.
///
/// `p` is expected in `[0, 1]`. Values below 0 (or NaN) select the minimum,
/// because the float-to-usize cast saturates at 0, and values above 1 select
/// the maximum. Sorting uses `f64::total_cmp`, a true total order even when
/// NaNs are present (positive NaN sorts after `+inf`), so the sort cannot
/// panic or misorder.
fn quantile(xs: &[f64], p: f64) -> Option<f64> {
    if xs.is_empty() {
        return None;
    }
    let mut sorted = xs.to_vec();
    sorted.sort_unstable_by(f64::total_cmp);
    let last = sorted.len() - 1;
    let idx = ((last as f64 * p).round() as usize).min(last);
    Some(sorted[idx])
}
89
/// Median of `xs`: the middle value, or the mean of the two middle values
/// when the length is even.
///
/// Returns `None` for an empty slice. Sorting uses `f64::total_cmp`, a true
/// total order even when NaNs are present (positive NaN sorts after `+inf`),
/// so the comparator can never make the sort panic or misorder.
fn median(xs: &[f64]) -> Option<f64> {
    if xs.is_empty() {
        return None;
    }
    let mut sorted = xs.to_vec();
    sorted.sort_unstable_by(f64::total_cmp);
    let mid = sorted.len() / 2;
    let median = if sorted.len() % 2 == 1 {
        sorted[mid]
    } else {
        (sorted[mid - 1] + sorted[mid]) / 2.0
    };
    Some(median)
}
103
/// Arithmetic mean of `xs`, or `None` when the slice is empty.
fn mean(xs: &[f64]) -> Option<f64> {
    let count = xs.len();
    // Lazily evaluated, so the empty case never divides by zero.
    (count > 0).then(|| xs.iter().sum::<f64>() / count as f64)
}
110
111fn bootstrap_ci(
112    control: &[f64],
113    treatment: &[f64],
114    seed: u64,
115    resamples: u32,
116) -> (Option<f64>, Option<f64>) {
117    let mut rng = SmallRng::seed_from_u64(seed);
118    let mut deltas: Vec<f64> = Vec::with_capacity(resamples as usize);
119    let mut buf_c = vec![0.0_f64; control.len()];
120    let mut buf_t = vec![0.0_f64; treatment.len()];
121    for _ in 0..resamples {
122        for slot in buf_c.iter_mut() {
123            *slot = control[rng.random_range(0..control.len())];
124        }
125        for slot in buf_t.iter_mut() {
126            *slot = treatment[rng.random_range(0..treatment.len())];
127        }
128        let (Some(mc), Some(mt)) = (median(&buf_c), median(&buf_t)) else {
129            continue;
130        };
131        deltas.push(mt - mc);
132    }
133    if deltas.is_empty() {
134        return (None, None);
135    }
136    deltas.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
137    let lo_i = ((deltas.len() as f64 * 0.025).round() as usize).min(deltas.len() - 1);
138    let hi_i = ((deltas.len() as f64 * 0.975).round() as usize).min(deltas.len() - 1);
139    (Some(deltas[lo_i]), Some(deltas[hi_i]))
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn known_positive_shift_detected() {
        // Constant groups 100 apart: every bootstrap delta is exactly 100,
        // so the CI must sit well above zero.
        let control = vec![10.0_f64; 100];
        let treatment = vec![110.0_f64; 100];
        let summary = summarize(&control, &treatment, 42, 1000);
        assert_eq!(summary.delta_median, Some(100.0));
        let (lo, hi) = (summary.ci95_lo.unwrap(), summary.ci95_hi.unwrap());
        assert!(lo > 0.0, "CI should exclude zero above, got {lo}");
        assert!(hi >= lo);
    }

    #[test]
    fn small_sample_warns() {
        let summary = summarize(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], 1, 100);
        assert!(summary.small_sample_warning);
    }

    #[test]
    fn winsorize_clips_outliers() {
        // 200 ordinary values plus one extreme; the p99 bound lands below it.
        let xs: Vec<f64> = (0..200)
            .map(f64::from)
            .chain(std::iter::once(10_000.0))
            .collect();
        let clipped = winsorize(&xs, 0.01, 0.99);
        let max_w = clipped.iter().copied().fold(f64::MIN, f64::max);
        assert!(max_w < 10_000.0, "extreme still present: {max_w}");
    }

    #[test]
    fn empty_inputs_safe() {
        let summary = summarize(&[], &[], 0, 10);
        assert_eq!(summary.n_control, 0);
        assert!(summary.delta_median.is_none());
        assert!(summary.ci95_lo.is_none());
    }
}