1use crate::{BCA_THRESHOLD, DEFAULT_BOOTSTRAP_ITERATIONS, DEFAULT_CONFIDENCE_LEVEL};
7use rand::Rng;
8use rand::thread_rng;
9use rayon::prelude::*;
10use thiserror::Error;
11
/// Settings that control how the bootstrap in this module resamples and
/// which interval method it applies.
#[derive(Clone, Debug)]
pub struct BootstrapConfig {
    /// Number of bootstrap resamples to draw.
    pub iterations: usize,
    /// Confidence level of the reported interval; `compute_bootstrap`
    /// rejects values outside the open interval (0, 1).
    pub confidence_level: f64,
    /// When `true`, resamples are generated through rayon's parallel
    /// iterator; when `false`, they are generated on the calling thread.
    pub parallel: bool,
    /// When `true`, the BCa interval is used regardless of sample size.
    pub force_bca: bool,
}
24
25impl Default for BootstrapConfig {
26 fn default() -> Self {
27 Self {
28 iterations: DEFAULT_BOOTSTRAP_ITERATIONS,
29 confidence_level: DEFAULT_CONFIDENCE_LEVEL,
30 parallel: true,
31 force_bca: false,
32 }
33 }
34}
35
/// Identifies which interval construction produced a bootstrap result.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BootstrapMethod {
    /// Interval read directly from the quantiles of the bootstrap means.
    Percentile,
    /// Bias-corrected and accelerated interval.
    Bca,
}
44
/// A two-sided interval together with the confidence level it was built for.
#[derive(Clone, Copy, Debug)]
pub struct ConfidenceInterval {
    /// Lower bound of the interval.
    pub lower: f64,
    /// Upper bound of the interval.
    pub upper: f64,
    /// Confidence level the bounds correspond to (copied from the
    /// configuration that produced them).
    pub level: f64,
}
55
56#[derive(Debug, Clone)]
58pub struct BootstrapResult {
59 pub point_estimate: f64,
61 pub confidence_interval: ConfidenceInterval,
63 pub standard_error: f64,
65 pub method: BootstrapMethod,
67 pub warning: Option<String>,
69}
70
71#[derive(Debug, Error)]
73#[non_exhaustive]
74pub enum BootstrapError {
75 #[error("Not enough samples: got {got}, need at least {min}")]
76 NotEnoughSamples { got: usize, min: usize },
77
78 #[error("Invalid confidence level: {0} (must be between 0 and 1)")]
79 InvalidConfidenceLevel(f64),
80
81 #[error("All samples have the same value")]
82 NoVariance,
83}
84
85pub fn compute_bootstrap(
102 samples: &[f64],
103 config: &BootstrapConfig,
104) -> Result<BootstrapResult, BootstrapError> {
105 if samples.len() < 3 {
107 return Err(BootstrapError::NotEnoughSamples {
108 got: samples.len(),
109 min: 3,
110 });
111 }
112
113 if config.confidence_level <= 0.0 || config.confidence_level >= 1.0 {
114 return Err(BootstrapError::InvalidConfidenceLevel(
115 config.confidence_level,
116 ));
117 }
118
119 let n = samples.len();
120 let point_estimate = mean(samples);
121
122 let variance = samples
124 .iter()
125 .map(|x| (x - point_estimate).powi(2))
126 .sum::<f64>()
127 / n as f64;
128 if variance == 0.0 {
129 return Ok(BootstrapResult {
130 point_estimate,
131 confidence_interval: ConfidenceInterval {
132 lower: point_estimate,
133 upper: point_estimate,
134 level: config.confidence_level,
135 },
136 standard_error: 0.0,
137 method: BootstrapMethod::Percentile,
138 warning: Some("All samples have identical values".to_string()),
139 });
140 }
141
142 let use_bca = config.force_bca || n < BCA_THRESHOLD;
144
145 let bootstrap_means = if config.parallel {
147 generate_bootstrap_means_parallel(samples, config.iterations)
148 } else {
149 generate_bootstrap_means_serial(samples, config.iterations)
150 };
151
152 let mut sorted_means = bootstrap_means.clone();
154 sorted_means.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
155
156 let (ci, method) = if use_bca {
158 let ci = bca_interval(samples, &sorted_means, config.confidence_level);
159 (ci, BootstrapMethod::Bca)
160 } else {
161 let ci = percentile_interval_sorted(&sorted_means, config.confidence_level);
162 (ci, BootstrapMethod::Percentile)
163 };
164
165 let se = (bootstrap_means
169 .iter()
170 .map(|x| (x - point_estimate).powi(2))
171 .sum::<f64>()
172 / bootstrap_means.len() as f64)
173 .sqrt();
174
175 let warning = if n < 10 {
176 Some("Very small sample size may lead to unreliable estimates".to_string())
177 } else {
178 None
179 };
180
181 Ok(BootstrapResult {
182 point_estimate,
183 confidence_interval: ConfidenceInterval {
184 lower: ci.0,
185 upper: ci.1,
186 level: config.confidence_level,
187 },
188 standard_error: se,
189 method,
190 warning,
191 })
192}
193
194fn generate_bootstrap_means_parallel(samples: &[f64], iterations: usize) -> Vec<f64> {
200 assert!(!samples.is_empty(), "samples must not be empty");
201 let n = samples.len();
202 (0..iterations)
203 .into_par_iter()
204 .map_init(thread_rng, |rng, _| {
205 let mut sum = 0.0;
206 for _ in 0..n {
207 sum += samples[rng.gen_range(0..n)];
209 }
210 sum / n as f64
211 })
212 .collect()
213}
214
215fn generate_bootstrap_means_serial(samples: &[f64], iterations: usize) -> Vec<f64> {
221 assert!(!samples.is_empty(), "samples must not be empty");
222 let n = samples.len();
223 let mut rng = thread_rng();
224 (0..iterations)
225 .map(|_| {
226 let mut sum = 0.0;
227 for _ in 0..n {
228 sum += samples[rng.gen_range(0..n)];
229 }
230 sum / n as f64
231 })
232 .collect()
233}
234
/// Reads a two-sided percentile interval off an ascending-sorted slice.
///
/// With `alpha = (1 - confidence) / 2`, the bounds are the elements at
/// indices `floor(alpha * n)` and `floor((1 - alpha) * n)`, each capped at the
/// last valid index.
///
/// Returns `(NaN, NaN)` when `sorted` is empty, since no quantile exists.
/// The caller is responsible for `sorted` actually being sorted.
fn percentile_interval_sorted(sorted: &[f64], confidence: f64) -> (f64, f64) {
    // Guard first: `n - 1` below would underflow and the indexing would panic.
    if sorted.is_empty() {
        return (f64::NAN, f64::NAN);
    }

    let n = sorted.len();
    let last = n - 1;
    let alpha = (1.0 - confidence) / 2.0;

    // Float-to-usize casts saturate, so out-of-range products still land on a
    // valid index after the cap.
    let index_at = |quantile: f64| ((quantile * n as f64).floor() as usize).min(last);

    (sorted[index_at(alpha)], sorted[index_at(1.0 - alpha)])
}
249
250fn bca_interval(samples: &[f64], sorted: &[f64], confidence: f64) -> (f64, f64) {
259 let n = samples.len();
260 let b = sorted.len();
261
262 let theta_hat = mean(samples);
263
264 let count_below = sorted.iter().filter(|&&x| x < theta_hat).count();
266 let prop = count_below as f64 / b as f64;
267 let z0 = normal_quantile(prop.clamp(0.0001, 0.9999));
268
269 let jackknife_means: Vec<f64> = (0..n)
271 .map(|i| {
272 let sum: f64 = samples
273 .iter()
274 .enumerate()
275 .filter(|(j, _)| *j != i)
276 .map(|(_, &v)| v)
277 .sum();
278 sum / (n - 1) as f64
279 })
280 .collect();
281
282 let jack_mean = mean(&jackknife_means);
283 let numerator: f64 = jackknife_means
284 .iter()
285 .map(|x| (jack_mean - x).powi(3))
286 .sum();
287 let denominator: f64 = jackknife_means
288 .iter()
289 .map(|x| (jack_mean - x).powi(2))
290 .sum();
291
292 let a = if denominator.abs() < 1e-10 {
293 0.0
294 } else {
295 numerator / (6.0 * denominator.powf(1.5))
296 };
297
298 let alpha = (1.0 - confidence) / 2.0;
300 let z_alpha = normal_quantile(alpha);
301 let z_1_alpha = normal_quantile(1.0 - alpha);
302
303 let alpha1 = normal_cdf(z0 + (z0 + z_alpha) / (1.0 - a * (z0 + z_alpha)));
304 let alpha2 = normal_cdf(z0 + (z0 + z_1_alpha) / (1.0 - a * (z0 + z_1_alpha)));
305
306 let lower_idx = ((alpha1 * b as f64).floor() as usize).clamp(0, b - 1);
307 let upper_idx = ((alpha2 * b as f64).floor() as usize).clamp(0, b - 1);
308
309 (sorted[lower_idx], sorted[upper_idx])
310}
311
/// Arithmetic mean of `samples`; an empty slice yields `0.0`.
fn mean(samples: &[f64]) -> f64 {
    match samples.len() {
        0 => 0.0,
        count => samples.iter().sum::<f64>() / count as f64,
    }
}
319
/// Approximates the standard normal quantile (inverse CDF) at probability `p`.
///
/// Uses a rational approximation in `t = sqrt(-2 ln(tail))`; the
/// coefficients match Abramowitz & Stegun formula 26.2.23. The file's own
/// tests check it to within 0.01 at p = 0.025, 0.5 and 0.975.
///
/// `p <= 0` returns negative infinity and `p >= 1` returns positive infinity;
/// other inputs are clamped to `[1e-10, 1 - 1e-10]` before evaluation.
fn normal_quantile(p: f64) -> f64 {
    const NUMERATOR: [f64; 3] = [2.515517, 0.802853, 0.010328];
    const DENOMINATOR: [f64; 3] = [1.432788, 0.189269, 0.001308];
    const EDGE: f64 = 1e-10;

    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }

    let clamped = p.clamp(EDGE, 1.0 - EDGE);

    // The approximation is defined on the lower tail; the upper half is
    // handled by symmetry.
    let in_lower_half = clamped < 0.5;
    let tail = if in_lower_half { clamped } else { 1.0 - clamped };

    let t = (-2.0 * tail.ln()).sqrt();

    let [c0, c1, c2] = NUMERATOR;
    let [d1, d2, d3] = DENOMINATOR;
    let top = c0 + c1 * t + c2 * t * t;
    let bottom = 1.0 + d1 * t + d2 * t * t + d3 * t * t * t;
    let magnitude = t - top / bottom;

    if in_lower_half {
        -magnitude
    } else {
        magnitude
    }
}
350
351fn normal_cdf(x: f64) -> f64 {
353 0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
354}
355
/// Approximates the error function with a fifth-degree polynomial in
/// `t = 1 / (1 + p|x|)` multiplied by `exp(-x^2)`; the coefficients match
/// Abramowitz & Stegun formula 7.1.26.
///
/// The polynomial is evaluated for `|x|` and the sign of `x` is restored at
/// the end, so the result is odd in `x`.
fn erf(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    const LEADING: f64 = 1.061405429;
    // Remaining coefficients in Horner order (highest degree first).
    const REST: [f64; 4] = [-1.453152027, 1.421413741, -0.284496736, 0.254829592];

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);

    let poly = REST.iter().fold(LEADING, |acc, &coeff| acc * t + coeff);
    let value = 1.0 - poly * t * (-magnitude * magnitude).exp();

    if x >= 0.0 {
        value
    } else {
        -value
    }
}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Samples `0.0, 1.0, ..., count - 1`.
    fn ramp(count: usize) -> Vec<f64> {
        (0..count).map(|v| v as f64).collect()
    }

    /// Default configuration with 1000 iterations and the given BCa override.
    fn config_with(force_bca: bool) -> BootstrapConfig {
        BootstrapConfig {
            iterations: 1000,
            force_bca,
            ..Default::default()
        }
    }

    /// True when `actual` is within 0.01 of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 0.01
    }

    #[test]
    fn test_basic_bootstrap() {
        let outcome = compute_bootstrap(&ramp(100), &config_with(false)).unwrap();

        // Mean of 0..=99 is 49.5.
        assert!((outcome.point_estimate - 49.5).abs() < 0.1);

        // The interval must bracket the point estimate.
        let interval = outcome.confidence_interval;
        assert!(interval.lower < outcome.point_estimate);
        assert!(interval.upper > outcome.point_estimate);
    }

    #[test]
    fn test_bca_for_small_samples() {
        let samples = [1.0, 2.0, 3.0, 4.0, 5.0];

        let outcome = compute_bootstrap(&samples, &config_with(false)).unwrap();

        assert_eq!(outcome.method, BootstrapMethod::Bca);
    }

    #[test]
    fn test_percentile_for_large_samples() {
        let outcome = compute_bootstrap(&ramp(200), &config_with(false)).unwrap();

        assert_eq!(outcome.method, BootstrapMethod::Percentile);
    }

    #[test]
    fn test_force_bca() {
        let outcome = compute_bootstrap(&ramp(200), &config_with(true)).unwrap();

        assert_eq!(outcome.method, BootstrapMethod::Bca);
    }

    #[test]
    fn test_not_enough_samples() {
        let samples = [1.0, 2.0];

        let outcome = compute_bootstrap(&samples, &BootstrapConfig::default());

        assert!(matches!(
            outcome,
            Err(BootstrapError::NotEnoughSamples { .. })
        ));
    }

    #[test]
    fn test_normal_quantile() {
        assert!(close(normal_quantile(0.5), 0.0));
        assert!(close(normal_quantile(0.975), 1.96));
        assert!(close(normal_quantile(0.025), -1.96));
    }

    #[test]
    fn test_normal_cdf() {
        assert!(close(normal_cdf(0.0), 0.5));
        assert!(close(normal_cdf(1.96), 0.975));
        assert!(close(normal_cdf(-1.96), 0.025));
    }
}