Skip to main content

kaizen/experiment/stats/
bootstrap.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2//! Bootstrap CI and winsorization helpers.
3use rand::rngs::SmallRng;
4use rand::{RngExt, SeedableRng};
5
/// Clamp every value of `xs` into the interval spanned by its `p_lo` and
/// `p_hi` quantiles.
///
/// Quantiles use the same nearest-rank rule as [`quantile`]: the element at
/// index `round((n - 1) * p)` of the sorted input, where `p < 0` (or NaN)
/// selects the minimum and `p > 1` the maximum.
///
/// - If `p_lo > p_hi` the two bounds are swapped instead of panicking.
/// - If either bound is NaN (possible when `xs` contains NaN), there is no
///   meaningful interval, so a copy of `xs` is returned unchanged.
/// - NaN elements of `xs` pass through as NaN.
///
/// Returns an empty vector for empty input.
pub fn winsorize(xs: &[f64], p_lo: f64, p_hi: f64) -> Vec<f64> {
    if xs.is_empty() {
        return Vec::new();
    }
    // Sort a single copy and read both bounds from it. `total_cmp` is a total
    // order, so the sort cannot panic on NaN input.
    let mut sorted = xs.to_vec();
    sorted.sort_unstable_by(f64::total_cmp);
    let last = sorted.len() - 1;
    // `as usize` saturates (negative/NaN -> 0); `min` caps p > 1 at the max.
    let at = |p: f64| sorted[(((last as f64) * p).round() as usize).min(last)];
    let (a, b) = (at(p_lo), at(p_hi));
    // `f64::clamp` panics when min > max or either bound is NaN: order the
    // bounds, and bail out if they are unordered (i.e. at least one is NaN).
    let (lo, hi) = if a <= b {
        (a, b)
    } else if b < a {
        (b, a)
    } else {
        return xs.to_vec();
    };
    xs.iter().map(|v| v.clamp(lo, hi)).collect()
}
19
20/// 95% percentile bootstrap CI on the median delta (treatment − control).
21pub fn bootstrap_ci(
22    control: &[f64],
23    treatment: &[f64],
24    seed: u64,
25    resamples: u32,
26) -> (Option<f64>, Option<f64>) {
27    let mut rng = SmallRng::seed_from_u64(seed);
28    let mut deltas: Vec<f64> = Vec::with_capacity(resamples as usize);
29    let mut buf_c = vec![0.0_f64; control.len()];
30    let mut buf_t = vec![0.0_f64; treatment.len()];
31    for _ in 0..resamples {
32        for slot in buf_c.iter_mut() {
33            *slot = control[rng.random_range(0..control.len())];
34        }
35        for slot in buf_t.iter_mut() {
36            *slot = treatment[rng.random_range(0..treatment.len())];
37        }
38        let (Some(mc), Some(mt)) = (median(&buf_c), median(&buf_t)) else {
39            continue;
40        };
41        deltas.push(mt - mc);
42    }
43    if deltas.is_empty() {
44        return (None, None);
45    }
46    deltas.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
47    let lo_i = ((deltas.len() as f64 * 0.025).round() as usize).min(deltas.len() - 1);
48    let hi_i = ((deltas.len() as f64 * 0.975).round() as usize).min(deltas.len() - 1);
49    (Some(deltas[lo_i]), Some(deltas[hi_i]))
50}
51
52/// Block-bootstrap CI where each element of `clusters_*` is one cluster's values.
53///
54/// Resamples whole clusters with replacement so within-cluster correlation
55/// doesn't inflate precision. Falls back to point-wise bootstrap when clusters
56/// are singletons (one session per cluster).
57pub fn cluster_bootstrap_ci(
58    clusters_control: &[Vec<f64>],
59    clusters_treatment: &[Vec<f64>],
60    seed: u64,
61    resamples: u32,
62) -> (Option<f64>, Option<f64>) {
63    if clusters_control.is_empty() || clusters_treatment.is_empty() {
64        return (None, None);
65    }
66    let mut rng = SmallRng::seed_from_u64(seed);
67    let mut deltas: Vec<f64> = Vec::with_capacity(resamples as usize);
68    for _ in 0..resamples {
69        let sample_c: Vec<f64> = (0..clusters_control.len())
70            .flat_map(|_| {
71                let idx = rng.random_range(0..clusters_control.len());
72                clusters_control[idx].iter().copied()
73            })
74            .collect();
75        let sample_t: Vec<f64> = (0..clusters_treatment.len())
76            .flat_map(|_| {
77                let idx = rng.random_range(0..clusters_treatment.len());
78                clusters_treatment[idx].iter().copied()
79            })
80            .collect();
81        let (Some(mc), Some(mt)) = (median(&sample_c), median(&sample_t)) else {
82            continue;
83        };
84        deltas.push(mt - mc);
85    }
86    if deltas.is_empty() {
87        return (None, None);
88    }
89    deltas.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
90    let lo_i = ((deltas.len() as f64 * 0.025).round() as usize).min(deltas.len() - 1);
91    let hi_i = ((deltas.len() as f64 * 0.975).round() as usize).min(deltas.len() - 1);
92    (Some(deltas[lo_i]), Some(deltas[hi_i]))
93}
94
/// Nearest-rank quantile of `xs`: the element at index `round((n - 1) * p)`
/// of the sorted values.
///
/// No interpolation is performed, so the result is always an element of `xs`.
/// `p` is expected in `[0, 1]`:
/// - `p < 0` or NaN selects the minimum;
/// - `p > 1` selects the maximum.
///
/// Values are ordered by [`f64::total_cmp`], so NaNs have a fixed position
/// (positive NaN after `+inf`, negative NaN before `-inf`) instead of making
/// the sort inconsistent.
///
/// Returns `None` when `xs` is empty.
pub fn quantile(xs: &[f64], p: f64) -> Option<f64> {
    if xs.is_empty() {
        return None;
    }
    let mut sorted = xs.to_vec();
    sorted.sort_unstable_by(f64::total_cmp);
    let last = sorted.len() - 1;
    // `as usize` saturates: negative/NaN products map to 0; `min` caps p > 1.
    let idx = (((last as f64) * p).round() as usize).min(last);
    Some(sorted[idx])
}
104
/// Median of `xs`.
///
/// - Odd length: the middle element.
/// - Even length: the mean of the two middle elements.
///
/// Values are ordered by [`f64::total_cmp`], so NaN input yields a
/// deterministic result instead of an inconsistent comparison.
///
/// Returns `None` when `xs` is empty.
pub fn median(xs: &[f64]) -> Option<f64> {
    if xs.is_empty() {
        return None;
    }
    let n = xs.len();
    let mid = n / 2;
    let mut scratch = xs.to_vec();
    // Partial selection places the sorted-order element at `mid` without a
    // full O(n log n) sort. Everything before it compares <= it.
    let (below, &mut upper, _) = scratch.select_nth_unstable_by(mid, f64::total_cmp);
    if n % 2 == 1 {
        return Some(upper);
    }
    // Even length: `below` holds the `mid >= 1` smallest values; its maximum
    // is the lower of the two middle elements.
    let lower = below
        .iter()
        .copied()
        .max_by(f64::total_cmp)
        .expect("even-length input leaves at least one element below the midpoint");
    Some((lower + upper) / 2.0)
}
118
/// Arithmetic mean of `xs`, or `None` when there is nothing to average.
pub fn mean(xs: &[f64]) -> Option<f64> {
    let count = xs.len();
    (count > 0).then(|| {
        let total: f64 = xs.iter().sum();
        total / count as f64
    })
}