Skip to main content

fluxbench_stats/
bootstrap.rs

//! Bootstrap Resampling
//!
//! Implements both percentile and BCa (Bias-Corrected and Accelerated) bootstrap
//! methods for computing confidence intervals of the sample mean.
5
6use crate::{BCA_THRESHOLD, DEFAULT_BOOTSTRAP_ITERATIONS, DEFAULT_CONFIDENCE_LEVEL};
7use rand::Rng;
8use rand::thread_rng;
9use rayon::prelude::*;
10use thiserror::Error;
11
/// Tuning knobs for [`compute_bootstrap`].
#[derive(Clone, Debug)]
pub struct BootstrapConfig {
    /// How many bootstrap resamples to draw
    /// (default: [`DEFAULT_BOOTSTRAP_ITERATIONS`](crate::DEFAULT_BOOTSTRAP_ITERATIONS)).
    pub iterations: usize,
    /// Target coverage of the interval, strictly between 0 and 1
    /// (default: [`DEFAULT_CONFIDENCE_LEVEL`](crate::DEFAULT_CONFIDENCE_LEVEL)).
    pub confidence_level: f64,
    /// Draw the resamples on the Rayon thread pool instead of the calling thread.
    pub parallel: bool,
    /// Use the BCa method even when the sample count reaches
    /// [`BCA_THRESHOLD`](crate::BCA_THRESHOLD).
    pub force_bca: bool,
}
24
25impl Default for BootstrapConfig {
26    fn default() -> Self {
27        Self {
28            iterations: DEFAULT_BOOTSTRAP_ITERATIONS,
29            confidence_level: DEFAULT_CONFIDENCE_LEVEL,
30            parallel: true,
31            force_bca: false,
32        }
33    }
34}
35
/// Interval construction that [`compute_bootstrap`] actually used.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BootstrapMethod {
    /// Plain percentile interval, selected automatically once the sample count reaches
    /// [`BCA_THRESHOLD`](crate::BCA_THRESHOLD).
    Percentile,
    /// Bias-corrected and accelerated interval, used for small samples or when
    /// [`BootstrapConfig::force_bca`] is set.
    Bca,
}
44
/// Two-sided interval `[lower, upper]` built for a particular confidence level.
#[derive(Clone, Copy, Debug)]
pub struct ConfidenceInterval {
    /// Lower endpoint of the interval.
    pub lower: f64,
    /// Upper endpoint of the interval.
    pub upper: f64,
    /// Confidence level the interval was built for (e.g. `0.95` for a 95% interval).
    pub level: f64,
}
55
56/// Result of bootstrap analysis
57#[derive(Debug, Clone)]
58pub struct BootstrapResult {
59    /// Point estimate (sample mean)
60    pub point_estimate: f64,
61    /// Confidence interval
62    pub confidence_interval: ConfidenceInterval,
63    /// Standard error of the mean
64    pub standard_error: f64,
65    /// Which method was used
66    pub method: BootstrapMethod,
67    /// Warning message if any
68    pub warning: Option<String>,
69}
70
71/// Errors that can occur during bootstrap
72#[derive(Debug, Error)]
73#[non_exhaustive]
74pub enum BootstrapError {
75    #[error("Not enough samples: got {got}, need at least {min}")]
76    NotEnoughSamples { got: usize, min: usize },
77
78    #[error("Invalid confidence level: {0} (must be between 0 and 1)")]
79    InvalidConfidenceLevel(f64),
80
81    #[error("All samples have the same value")]
82    NoVariance,
83}
84
85/// Compute bootstrap confidence interval for the mean
86///
87/// Automatically selects BCa method for small samples (N < 100).
88///
89/// # Examples
90///
91/// ```ignore
92/// # use fluxbench_stats::{compute_bootstrap, BootstrapConfig};
93/// let samples = vec![100.0, 102.0, 98.0, 101.0, 99.0];
94/// let config = BootstrapConfig::default();
95/// let result = compute_bootstrap(&samples, &config).unwrap();
96/// println!("Mean: {}", result.point_estimate);
97/// println!("95% CI: [{}, {}]",
98///     result.confidence_interval.lower,
99///     result.confidence_interval.upper);
100/// ```
101pub fn compute_bootstrap(
102    samples: &[f64],
103    config: &BootstrapConfig,
104) -> Result<BootstrapResult, BootstrapError> {
105    // Validate inputs
106    if samples.len() < 3 {
107        return Err(BootstrapError::NotEnoughSamples {
108            got: samples.len(),
109            min: 3,
110        });
111    }
112
113    if config.confidence_level <= 0.0 || config.confidence_level >= 1.0 {
114        return Err(BootstrapError::InvalidConfidenceLevel(
115            config.confidence_level,
116        ));
117    }
118
119    let n = samples.len();
120    let point_estimate = mean(samples);
121
122    // Check for zero variance
123    let variance = samples
124        .iter()
125        .map(|x| (x - point_estimate).powi(2))
126        .sum::<f64>()
127        / n as f64;
128    if variance == 0.0 {
129        return Ok(BootstrapResult {
130            point_estimate,
131            confidence_interval: ConfidenceInterval {
132                lower: point_estimate,
133                upper: point_estimate,
134                level: config.confidence_level,
135            },
136            standard_error: 0.0,
137            method: BootstrapMethod::Percentile,
138            warning: Some("All samples have identical values".to_string()),
139        });
140    }
141
142    // Select method based on sample size
143    let use_bca = config.force_bca || n < BCA_THRESHOLD;
144
145    // Generate bootstrap distribution
146    let bootstrap_means = if config.parallel {
147        generate_bootstrap_means_parallel(samples, config.iterations)
148    } else {
149        generate_bootstrap_means_serial(samples, config.iterations)
150    };
151
152    // Sort bootstrap means once (shared by both CI methods and avoids per-call allocation)
153    let mut sorted_means = bootstrap_means.clone();
154    sorted_means.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
155
156    // Compute confidence interval
157    let (ci, method) = if use_bca {
158        let ci = bca_interval(samples, &sorted_means, config.confidence_level);
159        (ci, BootstrapMethod::Bca)
160    } else {
161        let ci = percentile_interval_sorted(&sorted_means, config.confidence_level);
162        (ci, BootstrapMethod::Percentile)
163    };
164
165    // Compute standard error from bootstrap distribution.
166    // By the bootstrap property, mean(bootstrap_means) == point_estimate,
167    // so we reuse it directly to avoid a redundant O(N) pass.
168    let se = (bootstrap_means
169        .iter()
170        .map(|x| (x - point_estimate).powi(2))
171        .sum::<f64>()
172        / bootstrap_means.len() as f64)
173        .sqrt();
174
175    let warning = if n < 10 {
176        Some("Very small sample size may lead to unreliable estimates".to_string())
177    } else {
178        None
179    };
180
181    Ok(BootstrapResult {
182        point_estimate,
183        confidence_interval: ConfidenceInterval {
184            lower: ci.0,
185            upper: ci.1,
186            level: config.confidence_level,
187        },
188        standard_error: se,
189        method,
190        warning,
191    })
192}
193
194/// Generate bootstrap means using parallel iteration (Rayon).
195///
196/// # Panics
197///
198/// Panics if `samples` is empty. Callers must validate this before calling.
199fn generate_bootstrap_means_parallel(samples: &[f64], iterations: usize) -> Vec<f64> {
200    assert!(!samples.is_empty(), "samples must not be empty");
201    let n = samples.len();
202    (0..iterations)
203        .into_par_iter()
204        .map_init(thread_rng, |rng, _| {
205            let mut sum = 0.0;
206            for _ in 0..n {
207                // SAFETY: index is always in bounds because gen_range(0..n) < n
208                sum += samples[rng.gen_range(0..n)];
209            }
210            sum / n as f64
211        })
212        .collect()
213}
214
215/// Generate bootstrap means serially (for testing or small samples).
216///
217/// # Panics
218///
219/// Panics if `samples` is empty. Callers must validate this before calling.
220fn generate_bootstrap_means_serial(samples: &[f64], iterations: usize) -> Vec<f64> {
221    assert!(!samples.is_empty(), "samples must not be empty");
222    let n = samples.len();
223    let mut rng = thread_rng();
224    (0..iterations)
225        .map(|_| {
226            let mut sum = 0.0;
227            for _ in 0..n {
228                sum += samples[rng.gen_range(0..n)];
229            }
230            sum / n as f64
231        })
232        .collect()
233}
234
/// Standard percentile interval from pre-sorted bootstrap means.
///
/// The function sets `alpha = (1 - confidence) / 2` and `B = sorted.len()`. It trims the
/// same number of order statistics, `floor(alpha * B)`, from each end of the bootstrap
/// distribution and returns the extreme remaining values.
///
/// # Arguments
/// * `sorted` — Bootstrap means, **must be pre-sorted in ascending order**.
/// * `confidence` — Confidence level (e.g. 0.95).
///
/// # Panics
///
/// Panics if `sorted` is empty.
fn percentile_interval_sorted(sorted: &[f64], confidence: f64) -> (f64, f64) {
    assert!(!sorted.is_empty(), "bootstrap distribution must not be empty");
    let n = sorted.len();
    let alpha = (1.0 - confidence) / 2.0;

    // Number of values excluded from *each* tail. Deriving both bounds from one count keeps
    // the interval symmetric in rank. Computing the upper index as floor((1 - alpha) * B)
    // instead excludes one value fewer from the top (25 below vs 24 above for B = 1000 at
    // 95%). The epsilon absorbs rounding in `1 - confidence`: (1.0 - 0.9) / 2.0 * 100.0 is
    // slightly below 5. The cap keeps lower <= upper for tiny distributions.
    let tail = ((alpha * n as f64 + 1e-9).floor() as usize).min((n - 1) / 2);

    (sorted[tail], sorted[n - 1 - tail])
}
249
250/// BCa (Bias-Corrected and Accelerated) interval.
251///
252/// More accurate for small samples and skewed distributions.
253///
254/// # Arguments
255/// * `samples` — Original sample data (unsorted).
256/// * `sorted` — Bootstrap means, **must be pre-sorted in ascending order**.
257/// * `confidence` — Confidence level (e.g. 0.95).
258fn bca_interval(samples: &[f64], sorted: &[f64], confidence: f64) -> (f64, f64) {
259    let n = samples.len();
260    let b = sorted.len();
261
262    let theta_hat = mean(samples);
263
264    // Bias correction factor (z0)
265    let count_below = sorted.iter().filter(|&&x| x < theta_hat).count();
266    let prop = count_below as f64 / b as f64;
267    let z0 = normal_quantile(prop.clamp(0.0001, 0.9999));
268
269    // Acceleration factor (a) via jackknife
270    let jackknife_means: Vec<f64> = (0..n)
271        .map(|i| {
272            let sum: f64 = samples
273                .iter()
274                .enumerate()
275                .filter(|(j, _)| *j != i)
276                .map(|(_, &v)| v)
277                .sum();
278            sum / (n - 1) as f64
279        })
280        .collect();
281
282    let jack_mean = mean(&jackknife_means);
283    let numerator: f64 = jackknife_means
284        .iter()
285        .map(|x| (jack_mean - x).powi(3))
286        .sum();
287    let denominator: f64 = jackknife_means
288        .iter()
289        .map(|x| (jack_mean - x).powi(2))
290        .sum();
291
292    let a = if denominator.abs() < 1e-10 {
293        0.0
294    } else {
295        numerator / (6.0 * denominator.powf(1.5))
296    };
297
298    // Adjusted percentiles
299    let alpha = (1.0 - confidence) / 2.0;
300    let z_alpha = normal_quantile(alpha);
301    let z_1_alpha = normal_quantile(1.0 - alpha);
302
303    let alpha1 = normal_cdf(z0 + (z0 + z_alpha) / (1.0 - a * (z0 + z_alpha)));
304    let alpha2 = normal_cdf(z0 + (z0 + z_1_alpha) / (1.0 - a * (z0 + z_1_alpha)));
305
306    let lower_idx = ((alpha1 * b as f64).floor() as usize).clamp(0, b - 1);
307    let upper_idx = ((alpha2 * b as f64).floor() as usize).clamp(0, b - 1);
308
309    (sorted[lower_idx], sorted[upper_idx])
310}
311
/// Arithmetic mean of `samples`; an empty slice yields `0.0`.
fn mean(samples: &[f64]) -> f64 {
    match samples.len() {
        0 => 0.0,
        len => samples.iter().sum::<f64>() / len as f64,
    }
}
319
/// Standard normal quantile (inverse CDF).
///
/// Uses the rational approximation of Abramowitz and Stegun (26.2.23). It is evaluated on
/// the smaller tail probability `q = min(p, 1 - p)` and mirrored for `p >= 0.5`.
///
/// Returns `-inf` for `p <= 0` and `+inf` for `p >= 1`. Other inputs are clamped to
/// `[1e-10, 1 - 1e-10]` first, and NaN propagates.
fn normal_quantile(p: f64) -> f64 {
    const NUM: [f64; 3] = [2.515517, 0.802853, 0.010328];
    const DEN: [f64; 3] = [1.432788, 0.189269, 0.001308];

    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }

    let p = p.clamp(1e-10, 1.0 - 1e-10);
    // Work on the smaller tail probability; `sign` restores the side of zero.
    let (q, sign) = if p < 0.5 { (p, -1.0) } else { (1.0 - p, 1.0) };

    let t = (-2.0 * q.ln()).sqrt();
    let numerator = NUM[0] + NUM[1] * t + NUM[2] * t * t;
    let denominator = 1.0 + DEN[0] * t + DEN[1] * t * t + DEN[2] * t * t * t;

    sign * (t - numerator / denominator)
}
350
351/// Standard normal CDF
352fn normal_cdf(x: f64) -> f64 {
353    0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
354}
355
/// Error function approximation (Abramowitz and Stegun, formula 7.1.26).
fn erf(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    const A: [f64; 5] = [
        0.254829592,
        -0.284496736,
        1.421413741,
        -1.453152027,
        1.061405429,
    ];

    // erf is odd: evaluate on |x| and reapply the sign at the end.
    let sign = if x >= 0.0 { 1.0 } else { -1.0 };
    let ax = x.abs();

    let t = 1.0 / (1.0 + P * ax);
    // Horner form of t * (a1 + a2 t + a3 t^2 + a4 t^3 + a5 t^4), innermost coefficient first.
    let poly = ((((A[4] * t + A[3]) * t + A[2]) * t + A[1]) * t + A[0]) * t;

    sign * (1.0 - poly * (-ax * ax).exp())
}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration with a small iteration count so the tests stay fast.
    fn quick_config() -> BootstrapConfig {
        BootstrapConfig {
            iterations: 1000,
            ..Default::default()
        }
    }

    #[test]
    fn test_basic_bootstrap() {
        let samples: Vec<f64> = (0..100).map(|x| x as f64).collect();
        let result = compute_bootstrap(&samples, &quick_config()).unwrap();

        // Mean should be around 49.5
        assert!((result.point_estimate - 49.5).abs() < 0.1);

        // CI should contain the mean
        assert!(result.confidence_interval.lower < result.point_estimate);
        assert!(result.confidence_interval.upper > result.point_estimate);
    }

    #[test]
    fn test_bca_for_small_samples() {
        let samples = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let result = compute_bootstrap(&samples, &quick_config()).unwrap();

        // Should use BCa for small sample
        assert_eq!(result.method, BootstrapMethod::Bca);
        // Fewer than 10 samples triggers the small-sample warning.
        assert!(result.warning.is_some());
    }

    #[test]
    fn test_percentile_for_large_samples() {
        let samples: Vec<f64> = (0..200).map(|x| x as f64).collect();
        let config = BootstrapConfig {
            force_bca: false,
            ..quick_config()
        };

        let result = compute_bootstrap(&samples, &config).unwrap();

        // Should use percentile for large sample
        assert_eq!(result.method, BootstrapMethod::Percentile);
    }

    #[test]
    fn test_force_bca() {
        let samples: Vec<f64> = (0..200).map(|x| x as f64).collect();
        let config = BootstrapConfig {
            force_bca: true,
            ..quick_config()
        };

        let result = compute_bootstrap(&samples, &config).unwrap();

        // Should use BCa when forced
        assert_eq!(result.method, BootstrapMethod::Bca);
    }

    #[test]
    fn test_not_enough_samples() {
        let samples = vec![1.0, 2.0];
        let config = BootstrapConfig::default();

        let result = compute_bootstrap(&samples, &config);
        assert!(matches!(
            result,
            Err(BootstrapError::NotEnoughSamples { got: 2, min: 3 })
        ));
    }

    #[test]
    fn test_invalid_confidence_level() {
        let samples = vec![1.0, 2.0, 3.0, 4.0];
        for &level in &[0.0, 1.0, -0.5, 1.5] {
            let config = BootstrapConfig {
                confidence_level: level,
                ..quick_config()
            };
            assert!(matches!(
                compute_bootstrap(&samples, &config),
                Err(BootstrapError::InvalidConfidenceLevel(_))
            ));
        }
    }

    #[test]
    fn test_identical_samples_short_circuit() {
        let samples = vec![5.0; 20];
        let result = compute_bootstrap(&samples, &quick_config()).unwrap();

        assert_eq!(result.point_estimate, 5.0);
        assert_eq!(result.confidence_interval.lower, 5.0);
        assert_eq!(result.confidence_interval.upper, 5.0);
        assert_eq!(result.standard_error, 0.0);
        assert!(result.warning.is_some());
    }

    #[test]
    fn test_serial_path_contains_mean() {
        let samples: Vec<f64> = (0..50).map(|x| x as f64).collect();
        let config = BootstrapConfig {
            parallel: false,
            ..quick_config()
        };

        let result = compute_bootstrap(&samples, &config).unwrap();

        assert!(result.confidence_interval.lower < result.point_estimate);
        assert!(result.confidence_interval.upper > result.point_estimate);
    }

    #[test]
    fn test_ci_reports_requested_level() {
        let samples: Vec<f64> = (1..=30).map(|x| x as f64).collect();
        let config = BootstrapConfig {
            confidence_level: 0.9,
            ..quick_config()
        };

        let result = compute_bootstrap(&samples, &config).unwrap();

        assert_eq!(result.confidence_interval.level, 0.9);
        assert!(result.confidence_interval.lower <= result.confidence_interval.upper);
    }

    #[test]
    fn test_standard_error_matches_theory() {
        // Plug-in SE of the mean for 0..100 is sigma / sqrt(n) = sqrt(9999 / 12) / 10 ≈ 2.887.
        let samples: Vec<f64> = (0..100).map(|x| x as f64).collect();
        let config = BootstrapConfig {
            iterations: 2000,
            ..Default::default()
        };

        let result = compute_bootstrap(&samples, &config).unwrap();

        assert!(
            (result.standard_error - 2.887).abs() < 0.3,
            "standard error = {}",
            result.standard_error
        );
    }

    #[test]
    fn test_generators_stay_within_sample_range() {
        let samples = [1.0, 2.0, 3.0];
        let serial = generate_bootstrap_means_serial(&samples, 200);
        let parallel = generate_bootstrap_means_parallel(&samples, 200);

        for means in vec![serial, parallel] {
            assert_eq!(means.len(), 200);
            assert!(means.iter().all(|&m| (1.0..=3.0).contains(&m)));
        }
    }

    #[test]
    fn test_mean_empty_is_zero() {
        assert_eq!(mean(&[]), 0.0);
        assert_eq!(mean(&[1.0, 2.0, 3.0]), 2.0);
    }

    #[test]
    fn test_normal_quantile() {
        // Test known values
        assert!((normal_quantile(0.5) - 0.0).abs() < 0.01);
        assert!((normal_quantile(0.975) - 1.96).abs() < 0.01);
        assert!((normal_quantile(0.025) - (-1.96)).abs() < 0.01);
    }

    #[test]
    fn test_normal_cdf() {
        // Test known values
        assert!((normal_cdf(0.0) - 0.5).abs() < 0.01);
        assert!((normal_cdf(1.96) - 0.975).abs() < 0.01);
        assert!((normal_cdf(-1.96) - 0.025).abs() < 0.01);
    }

    #[test]
    fn test_erf_odd_and_known_values() {
        assert!(erf(0.0).abs() < 1e-6);
        assert!((erf(1.0) - 0.842_700_79).abs() < 1e-6);
        assert_eq!(erf(0.7), -erf(-0.7));
    }
}
464        assert!((normal_cdf(0.0) - 0.5).abs() < 0.01);
465        assert!((normal_cdf(1.96) - 0.975).abs() < 0.01);
466        assert!((normal_cdf(-1.96) - 0.025).abs() < 0.01);
467    }
468}