// Source: mollendorff_forge/bootstrap/engine.rs

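//! Bootstrap Resampling Engine
//!
//! Implements the non-parametric bootstrap (resampling the data with replacement) to
//! estimate standard errors, bias, and percentile confidence intervals.
//! Results are checked for agreement with R's `boot` package.
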
6use super::config::{BootstrapConfig, BootstrapStatistic};
7use rand::rngs::StdRng;
8use rand::{RngExt, SeedableRng};
9use serde::{Deserialize, Serialize};
10
11/// A confidence interval
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ConfidenceInterval {
14    /// Confidence level (e.g., 0.95)
15    pub level: f64,
16    /// Lower bound
17    pub lower: f64,
18    /// Upper bound
19    pub upper: f64,
20}
21
22impl ConfidenceInterval {
23    /// Create a new confidence interval
24    #[must_use]
25    pub const fn new(level: f64, lower: f64, upper: f64) -> Self {
26        Self {
27            level,
28            lower,
29            upper,
30        }
31    }
32
33    /// Width of the interval
34    #[must_use]
35    pub fn width(&self) -> f64 {
36        self.upper - self.lower
37    }
38}
39
40/// Bootstrap analysis result
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct BootstrapResult {
43    /// Original sample statistic
44    pub original_estimate: f64,
45    /// Bootstrap mean estimate
46    pub bootstrap_mean: f64,
47    /// Bootstrap standard error
48    pub bootstrap_std_error: f64,
49    /// Bias (bootstrap mean - original)
50    pub bias: f64,
51    /// Confidence intervals
52    pub confidence_intervals: Vec<ConfidenceInterval>,
53    /// Bootstrap distribution (all resampled statistics)
54    pub distribution: Vec<f64>,
55    /// Number of bootstrap iterations
56    pub iterations: usize,
57}
58
59impl BootstrapResult {
60    /// Export results to YAML format
61    #[must_use]
62    pub fn to_yaml(&self) -> String {
63        serde_yaml_ng::to_string(self).unwrap_or_else(|_| "# Error serializing results".to_string())
64    }
65
66    /// Export results to JSON format
67    ///
68    /// # Errors
69    ///
70    /// Returns an error if JSON serialization fails.
71    pub fn to_json(&self) -> Result<String, serde_json::Error> {
72        serde_json::to_string_pretty(self)
73    }
74
75    /// Get the bias-corrected estimate
76    #[must_use]
77    pub fn bias_corrected_estimate(&self) -> f64 {
78        self.original_estimate - self.bias
79    }
80}
81
82/// Bootstrap Resampling Engine
83pub struct BootstrapEngine {
84    config: BootstrapConfig,
85    rng: StdRng,
86}
87
88impl BootstrapEngine {
89    /// Create a new bootstrap engine
90    ///
91    /// # Errors
92    ///
93    /// Returns an error if the configuration is invalid (see [`BootstrapConfig::validate`]).
94    pub fn new(config: BootstrapConfig) -> Result<Self, String> {
95        config.validate()?;
96
97        let rng = config
98            .seed
99            .map_or_else(|| StdRng::from_rng(&mut rand::rng()), StdRng::seed_from_u64);
100
101        Ok(Self { config, rng })
102    }
103
104    /// Run the bootstrap analysis
105    ///
106    /// # Errors
107    ///
108    /// Returns an error if the analysis fails.
109    pub fn analyze(&mut self) -> Result<BootstrapResult, String> {
110        let data = &self.config.data;
111        let n = data.len();
112
113        // Calculate original estimate
114        let original_estimate = self.compute_statistic(data);
115
116        // Bootstrap resampling
117        let mut distribution = Vec::with_capacity(self.config.iterations);
118
119        for _ in 0..self.config.iterations {
120            // Resample with replacement
121            let sample: Vec<f64> = (0..n)
122                .map(|_| {
123                    let idx = self.rng.random_range(0..n);
124                    data[idx]
125                })
126                .collect();
127
128            let stat = self.compute_statistic(&sample);
129            distribution.push(stat);
130        }
131
132        // Sort for percentile calculation
133        distribution.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
134
135        // Calculate bootstrap statistics
136        let bootstrap_mean = distribution.iter().sum::<f64>() / distribution.len() as f64;
137        let variance: f64 = distribution
138            .iter()
139            .map(|x| (x - bootstrap_mean).powi(2))
140            .sum::<f64>()
141            / (distribution.len() - 1) as f64;
142        let bootstrap_std_error = variance.sqrt();
143        let bias = bootstrap_mean - original_estimate;
144
145        // Calculate confidence intervals
146        let confidence_intervals = self.calculate_confidence_intervals(&distribution);
147
148        Ok(BootstrapResult {
149            original_estimate,
150            bootstrap_mean,
151            bootstrap_std_error,
152            bias,
153            confidence_intervals,
154            distribution,
155            iterations: self.config.iterations,
156        })
157    }
158
159    /// Compute the statistic on a sample
160    fn compute_statistic(&self, sample: &[f64]) -> f64 {
161        if sample.is_empty() {
162            return 0.0;
163        }
164
165        match self.config.statistic {
166            BootstrapStatistic::Mean => sample.iter().sum::<f64>() / sample.len() as f64,
167            BootstrapStatistic::Median => {
168                let mut sorted = sample.to_vec();
169                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
170                let mid = sorted.len() / 2;
171                if sorted.len().is_multiple_of(2) {
172                    f64::midpoint(sorted[mid - 1], sorted[mid])
173                } else {
174                    sorted[mid]
175                }
176            },
177            BootstrapStatistic::Std => {
178                let mean = sample.iter().sum::<f64>() / sample.len() as f64;
179                let variance: f64 = sample.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
180                    / (sample.len() - 1) as f64;
181                variance.sqrt()
182            },
183            BootstrapStatistic::Var => {
184                let mean = sample.iter().sum::<f64>() / sample.len() as f64;
185                sample.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (sample.len() - 1) as f64
186            },
187            BootstrapStatistic::Percentile => {
188                let mut sorted = sample.to_vec();
189                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
190                // cast_possible_truncation: percentile index is always in [0, sorted.len()-1]
191                #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
192                let idx = ((self.config.percentile_value / 100.0) * (sorted.len() as f64 - 1.0))
193                    .round() as usize;
194                sorted[idx.min(sorted.len() - 1)]
195            },
196            BootstrapStatistic::Min => sample.iter().copied().fold(f64::INFINITY, f64::min),
197            BootstrapStatistic::Max => sample.iter().copied().fold(f64::NEG_INFINITY, f64::max),
198        }
199    }
200
201    /// Calculate confidence intervals using percentile method
202    fn calculate_confidence_intervals(&self, distribution: &[f64]) -> Vec<ConfidenceInterval> {
203        self.config
204            .confidence_levels
205            .iter()
206            .map(|&level| {
207                let alpha = 1.0 - level;
208                // cast_possible_truncation: indices are bounded by [0, distribution.len()]
209                #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
210                let lower_idx = ((alpha / 2.0) * distribution.len() as f64) as usize;
211                #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
212                let upper_idx = ((1.0 - alpha / 2.0) * distribution.len() as f64) as usize;
213
214                ConfidenceInterval::new(
215                    level,
216                    distribution[lower_idx.min(distribution.len() - 1)],
217                    distribution[upper_idx.min(distribution.len() - 1)],
218                )
219            })
220            .collect()
221    }
222
223    /// Get the configuration
224    #[must_use]
225    pub const fn config(&self) -> &BootstrapConfig {
226        &self.config
227    }
228}
229
#[cfg(test)]
mod engine_tests {
    use super::*;

    #[test]
    fn test_bootstrap_mean() {
        let config = BootstrapConfig::new()
            .with_data(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
            .with_iterations(5000)
            .with_seed(12345);

        let mut engine = BootstrapEngine::new(config).unwrap();
        let result = engine.analyze().unwrap();

        // Original mean should be 5.5
        assert!(
            (result.original_estimate - 5.5).abs() < 0.01,
            "Original mean should be 5.5"
        );

        // Bootstrap mean should be close to original
        assert!(
            (result.bootstrap_mean - 5.5).abs() < 0.5,
            "Bootstrap mean should be close to 5.5"
        );

        // Should have confidence intervals
        assert!(!result.confidence_intervals.is_empty());
    }

    #[test]
    fn test_bootstrap_median() {
        let config = BootstrapConfig::new()
            .with_data(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
            .with_statistic(BootstrapStatistic::Median)
            .with_iterations(5000)
            .with_seed(12345);

        let mut engine = BootstrapEngine::new(config).unwrap();
        let result = engine.analyze().unwrap();

        // Original median should be 5.5
        assert!(
            (result.original_estimate - 5.5).abs() < 0.01,
            "Original median should be 5.5"
        );
    }

    /// Std, Var, Min and Max evaluated on the original data 1..=10.
    ///
    /// Sum of squared deviations from 5.5 is 82.5, so the sample variance is 82.5 / 9 = 55 / 6.
    #[test]
    fn test_other_statistics_on_original_data() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let cases = [
            (BootstrapStatistic::Std, (55.0_f64 / 6.0).sqrt()),
            (BootstrapStatistic::Var, 55.0 / 6.0),
            (BootstrapStatistic::Min, 1.0),
            (BootstrapStatistic::Max, 10.0),
        ];

        for (statistic, expected) in cases {
            let config = BootstrapConfig::new()
                .with_data(data.clone())
                .with_statistic(statistic)
                .with_iterations(200)
                .with_seed(7);

            let mut engine = BootstrapEngine::new(config).unwrap();
            let result = engine.analyze().unwrap();

            assert!(
                (result.original_estimate - expected).abs() < 1e-9,
                "expected original estimate {expected}, got {}",
                result.original_estimate
            );
        }
    }

    #[test]
    fn test_confidence_intervals() {
        let config = BootstrapConfig::new()
            .with_data(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
            .with_confidence_levels(vec![0.90, 0.95])
            .with_iterations(10000)
            .with_seed(12345);

        let mut engine = BootstrapEngine::new(config).unwrap();
        let result = engine.analyze().unwrap();

        assert_eq!(result.confidence_intervals.len(), 2);

        // 95% CI should be wider than 90% CI
        let ci_90 = result
            .confidence_intervals
            .iter()
            .find(|ci| (ci.level - 0.90).abs() < 0.01)
            .unwrap();
        let ci_95 = result
            .confidence_intervals
            .iter()
            .find(|ci| (ci.level - 0.95).abs() < 0.01)
            .unwrap();

        assert!(
            ci_95.width() >= ci_90.width(),
            "95% CI should be >= 90% CI width"
        );
    }

    /// The distribution holds one sorted value per iteration, and every interval lies
    /// inside its range with `lower <= upper`.
    #[test]
    fn test_distribution_sorted_and_intervals_in_range() {
        let config = BootstrapConfig::new()
            .with_data(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0])
            .with_confidence_levels(vec![0.80, 0.95])
            .with_iterations(500)
            .with_seed(99);

        let mut engine = BootstrapEngine::new(config).unwrap();
        let result = engine.analyze().unwrap();

        assert_eq!(result.iterations, 500);
        assert_eq!(result.distribution.len(), 500);
        assert!(
            result.distribution.windows(2).all(|w| w[0] <= w[1]),
            "distribution should be sorted ascending"
        );

        let min = result.distribution[0];
        let max = result.distribution[result.distribution.len() - 1];
        for ci in &result.confidence_intervals {
            assert!(ci.lower <= ci.upper, "lower bound must not exceed upper bound");
            assert!(
                ci.lower >= min && ci.upper <= max,
                "interval must lie within the bootstrap distribution"
            );
        }
    }

    #[test]
    fn test_reproducibility() {
        let config1 = BootstrapConfig::new()
            .with_data(vec![1.0, 2.0, 3.0, 4.0, 5.0])
            .with_iterations(1000)
            .with_seed(42);

        let config2 = BootstrapConfig::new()
            .with_data(vec![1.0, 2.0, 3.0, 4.0, 5.0])
            .with_iterations(1000)
            .with_seed(42);

        let mut engine1 = BootstrapEngine::new(config1).unwrap();
        let mut engine2 = BootstrapEngine::new(config2).unwrap();

        let result1 = engine1.analyze().unwrap();
        let result2 = engine2.analyze().unwrap();

        assert!(
            (result1.bootstrap_mean - result2.bootstrap_mean).abs() < 0.0001,
            "Same seed should produce same results"
        );
        // The same seed must reproduce every replicate, not just the summary.
        assert_eq!(
            result1.distribution, result2.distribution,
            "Same seed should produce the same bootstrap distribution"
        );
    }

    /// Agreement with R's `boot` package on a small data set.
    ///
    /// The random streams differ from R's, so results are compared against their
    /// theoretical bootstrap targets with Monte-Carlo tolerances, not bit-for-bit.
    #[test]
    fn test_r_boot_equivalence() {
        // R reference (`boot()` has no seed argument; the seed is set globally):
        //   library(boot)
        //   data <- c(5, -2, 8, 3, -5, 12, 1, -1, 6, 4)
        //   mean_func <- function(d, i) mean(d[i])
        //   set.seed(12345)
        //   results <- boot(data, mean_func, R = 10000)
        //   boot.ci(results, type = "perc")
        //
        // For the mean, the ideal bootstrap standard error is sqrt(s2 / n), where s2 is
        // the plug-in variance (divisor n). The squared deviations sum to 228.9, so
        // s2 = 22.89 and SE = sqrt(2.289) ~= 1.513. The 95% percentile interval is then
        // roughly 3.1 -/+ 1.96 * 1.513 ~= (0.13, 6.07).
        let config = BootstrapConfig::new()
            .with_data(vec![5.0, -2.0, 8.0, 3.0, -5.0, 12.0, 1.0, -1.0, 6.0, 4.0])
            .with_iterations(10000)
            .with_seed(12345)
            .with_confidence_levels(vec![0.95]);

        let mut engine = BootstrapEngine::new(config).unwrap();
        let result = engine.analyze().unwrap();

        // Original mean = 3.1
        assert!(
            (result.original_estimate - 3.1).abs() < 0.01,
            "Original mean should be 3.1"
        );

        // Monte-Carlo error of the bootstrap mean is ~1.513 / 100 ~= 0.015.
        assert!(
            (result.bootstrap_mean - 3.1).abs() < 0.1,
            "Bootstrap mean should be close to 3.1, got {}",
            result.bootstrap_mean
        );

        // Monte-Carlo error of the SE estimate is ~1.513 / sqrt(2 * 10000) ~= 0.011.
        assert!(
            (result.bootstrap_std_error - 1.513).abs() < 0.1,
            "Standard error should be close to 1.513, got {}",
            result.bootstrap_std_error
        );

        let ci = &result.confidence_intervals[0];
        assert!(
            (ci.lower - 0.13).abs() < 0.4,
            "95% lower bound should be near 0.13, got {}",
            ci.lower
        );
        assert!(
            (ci.upper - 6.07).abs() < 0.4,
            "95% upper bound should be near 6.07, got {}",
            ci.upper
        );
    }

    #[test]
    fn test_yaml_export() {
        let config = BootstrapConfig::new()
            .with_data(vec![1.0, 2.0, 3.0, 4.0, 5.0])
            .with_iterations(100)
            .with_seed(42);

        let mut engine = BootstrapEngine::new(config).unwrap();
        let result = engine.analyze().unwrap();
        let yaml = result.to_yaml();

        assert!(yaml.contains("original_estimate"));
        assert!(yaml.contains("bootstrap_mean"));
        assert!(yaml.contains("confidence_intervals"));
    }

    #[test]
    fn test_json_export() {
        let config = BootstrapConfig::new()
            .with_data(vec![1.0, 2.0, 3.0, 4.0, 5.0])
            .with_iterations(100)
            .with_seed(42);

        let mut engine = BootstrapEngine::new(config).unwrap();
        let result = engine.analyze().unwrap();
        let json = result.to_json().unwrap();

        assert!(json.contains("bootstrap_std_error"));
        assert!(json.contains("distribution"));
    }
}