Skip to main content

kaizen/experiment/stats/
mod.rs

// SPDX-License-Identifier: AGPL-3.0-or-later
//! Non-parametric stats for experiment reports.
//!
//! Effect size = median(treatment) − median(control). The CI comes from a
//! percentile bootstrap (default 10k resamples, 95%). Both arms are
//! winsorized at p1/p99 before resampling to blunt skew.
7
8pub mod bootstrap;
9pub mod cuped;
10pub mod power;
11pub mod sequential;
12pub mod srm;
13
14pub use bootstrap::winsorize;
15pub use srm::has_srm;
16
17use bootstrap::{bootstrap_ci, mean, median};
18use serde::{Deserialize, Serialize};
19
/// Default number of bootstrap resamples (10k, as stated in the module docs).
///
/// [`summarize`] takes its `resamples` count as an explicit argument and does
/// not read this constant itself; it is exported for callers to pass in.
pub const DEFAULT_RESAMPLES: u32 = 10_000;
/// Per-arm sample size below which [`summarize`] sets
/// `Summary::small_sample_warning` (compared against the smaller arm).
pub const MIN_SAMPLE: usize = 30;
22
23#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
24pub struct Summary {
25    pub n_control: usize,
26    pub n_treatment: usize,
27    pub median_control: Option<f64>,
28    pub median_treatment: Option<f64>,
29    pub mean_control: Option<f64>,
30    pub mean_treatment: Option<f64>,
31    pub delta_median: Option<f64>,
32    pub delta_pct: Option<f64>,
33    pub ci95_lo: Option<f64>,
34    pub ci95_hi: Option<f64>,
35    pub small_sample_warning: bool,
36    /// Set when observed arm counts deviate from expected 50/50 at p < 0.001.
37    pub srm_warning: bool,
38}
39
40/// Pure stats for a metric. Deterministic given `seed`.
41pub fn summarize(control: &[f64], treatment: &[f64], seed: u64, resamples: u32) -> Summary {
42    let c = winsorize(control, 0.01, 0.99);
43    let t = winsorize(treatment, 0.01, 0.99);
44    let median_c = median(&c);
45    let median_t = median(&t);
46    let mean_c = mean(&c);
47    let mean_t = mean(&t);
48    let delta = match (median_c, median_t) {
49        (Some(a), Some(b)) => Some(b - a),
50        _ => None,
51    };
52    let delta_pct = match (median_c, delta) {
53        (Some(a), Some(d)) if a != 0.0 => Some(100.0 * d / a),
54        _ => None,
55    };
56    let (lo, hi) = if c.is_empty() || t.is_empty() {
57        (None, None)
58    } else {
59        bootstrap_ci(&c, &t, seed, resamples)
60    };
61    Summary {
62        n_control: control.len(),
63        n_treatment: treatment.len(),
64        median_control: median_c,
65        median_treatment: median_t,
66        mean_control: mean_c,
67        mean_treatment: mean_t,
68        delta_median: delta,
69        delta_pct,
70        ci95_lo: lo,
71        ci95_hi: hi,
72        small_sample_warning: control.len().min(treatment.len()) < MIN_SAMPLE,
73        srm_warning: has_srm(control.len(), treatment.len()),
74    }
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    // A constant +100 shift between balanced arms must show up as the median
    // delta, with a CI that sits strictly above zero and no SRM flag.
    #[test]
    fn known_positive_shift_detected() {
        let control: Vec<f64> = vec![10.0; 100];
        let treatment: Vec<f64> = vec![110.0; 100];
        let s = summarize(&control, &treatment, 42, 1000);
        assert_eq!(s.delta_median, Some(100.0));
        // 100 / 10 baseline => +1000 %; compared with a tolerance because the
        // medians come out of floating-point helpers.
        let pct = s.delta_pct.expect("non-zero control median yields a percentage");
        assert!((pct - 1000.0).abs() < 1e-9, "unexpected delta_pct: {pct}");
        let lo = s.ci95_lo.unwrap();
        let hi = s.ci95_hi.unwrap();
        assert!(lo > 0.0, "CI should exclude zero above, got {lo}");
        assert!(hi >= lo);
        assert!(!s.srm_warning);
    }

    // An 800:200 split is far from 50/50 and must raise the SRM warning.
    #[test]
    fn srm_warning_on_imbalance() {
        let control: Vec<f64> = vec![1.0; 800];
        let treatment: Vec<f64> = vec![1.0; 200];
        let s = summarize(&control, &treatment, 0, 100);
        assert!(s.srm_warning, "should flag SRM for 800:200 split");
    }

    // Three observations per arm is below MIN_SAMPLE.
    #[test]
    fn small_sample_warns() {
        let c: Vec<f64> = vec![1.0, 2.0, 3.0];
        let t: Vec<f64> = vec![4.0, 5.0, 6.0];
        let s = summarize(&c, &t, 1, 100);
        assert!(s.small_sample_warning);
    }

    // A single extreme value must not survive p1/p99 winsorization.
    #[test]
    fn winsorize_clips_outliers() {
        let mut xs: Vec<f64> = (0..200_u32).map(f64::from).collect();
        xs.push(10_000.0);
        let w = winsorize(&xs, 0.01, 0.99);
        // NEG_INFINITY is the identity for `max`, unlike f64::MIN.
        let max_w = w.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        assert!(max_w < 10_000.0, "extreme still present: {max_w}");
    }

    // Empty arms must not panic and must leave the derived fields unset.
    #[test]
    fn empty_inputs_safe() {
        let s = summarize(&[], &[], 0, 10);
        assert_eq!(s.n_control, 0);
        assert_eq!(s.n_treatment, 0);
        assert!(s.delta_median.is_none());
        assert!(s.delta_pct.is_none());
        assert!(s.ci95_lo.is_none());
        assert!(s.ci95_hi.is_none());
        assert!(s.small_sample_warning);
    }
}