mollendorff_forge/bootstrap/
engine.rs1use super::config::{BootstrapConfig, BootstrapStatistic};
7use rand::rngs::StdRng;
8use rand::{RngExt, SeedableRng};
9use serde::{Deserialize, Serialize};
10
/// A confidence interval: a confidence level paired with its lower and upper bounds.
///
/// Serializable and deserializable through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfidenceInterval {
    /// Confidence level as a fraction (the engine computes `alpha = 1.0 - level`,
    /// so `0.95` denotes a 95% interval).
    pub level: f64,
    /// Lower bound of the interval.
    pub lower: f64,
    /// Upper bound of the interval.
    pub upper: f64,
}
21
22impl ConfidenceInterval {
23 #[must_use]
25 pub const fn new(level: f64, lower: f64, upper: f64) -> Self {
26 Self {
27 level,
28 lower,
29 upper,
30 }
31 }
32
33 #[must_use]
35 pub fn width(&self) -> f64 {
36 self.upper - self.lower
37 }
38}
39
/// Outcome of a bootstrap analysis as produced by `BootstrapEngine::analyze`.
///
/// Serializable and deserializable through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BootstrapResult {
    /// The configured statistic evaluated on the original (non-resampled) data.
    pub original_estimate: f64,
    /// Arithmetic mean of the bootstrap distribution.
    pub bootstrap_mean: f64,
    /// Square root of the sample variance (divisor `len - 1`) of the bootstrap distribution.
    pub bootstrap_std_error: f64,
    /// `bootstrap_mean - original_estimate`.
    pub bias: f64,
    /// One interval per configured confidence level, in configuration order.
    pub confidence_intervals: Vec<ConfidenceInterval>,
    /// The statistic of every resample, sorted in ascending order.
    pub distribution: Vec<f64>,
    /// Number of resampling iterations taken from the configuration.
    pub iterations: usize,
}
58
59impl BootstrapResult {
60 #[must_use]
62 pub fn to_yaml(&self) -> String {
63 serde_yaml_ng::to_string(self).unwrap_or_else(|_| "# Error serializing results".to_string())
64 }
65
66 pub fn to_json(&self) -> Result<String, serde_json::Error> {
72 serde_json::to_string_pretty(self)
73 }
74
75 #[must_use]
77 pub fn bias_corrected_estimate(&self) -> f64 {
78 self.original_estimate - self.bias
79 }
80}
81
/// Bootstrap resampling engine: holds a configuration and the random number
/// generator used to draw resamples.
pub struct BootstrapEngine {
    /// Configuration supplied to `BootstrapEngine::new`.
    config: BootstrapConfig,
    /// Random number generator used for resampling; seeded from the configuration's
    /// `seed` when one is set.
    rng: StdRng,
}
87
88impl BootstrapEngine {
89 pub fn new(config: BootstrapConfig) -> Result<Self, String> {
95 config.validate()?;
96
97 let rng = config
98 .seed
99 .map_or_else(|| StdRng::from_rng(&mut rand::rng()), StdRng::seed_from_u64);
100
101 Ok(Self { config, rng })
102 }
103
104 pub fn analyze(&mut self) -> Result<BootstrapResult, String> {
110 let data = &self.config.data;
111 let n = data.len();
112
113 let original_estimate = self.compute_statistic(data);
115
116 let mut distribution = Vec::with_capacity(self.config.iterations);
118
119 for _ in 0..self.config.iterations {
120 let sample: Vec<f64> = (0..n)
122 .map(|_| {
123 let idx = self.rng.random_range(0..n);
124 data[idx]
125 })
126 .collect();
127
128 let stat = self.compute_statistic(&sample);
129 distribution.push(stat);
130 }
131
132 distribution.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
134
135 let bootstrap_mean = distribution.iter().sum::<f64>() / distribution.len() as f64;
137 let variance: f64 = distribution
138 .iter()
139 .map(|x| (x - bootstrap_mean).powi(2))
140 .sum::<f64>()
141 / (distribution.len() - 1) as f64;
142 let bootstrap_std_error = variance.sqrt();
143 let bias = bootstrap_mean - original_estimate;
144
145 let confidence_intervals = self.calculate_confidence_intervals(&distribution);
147
148 Ok(BootstrapResult {
149 original_estimate,
150 bootstrap_mean,
151 bootstrap_std_error,
152 bias,
153 confidence_intervals,
154 distribution,
155 iterations: self.config.iterations,
156 })
157 }
158
159 fn compute_statistic(&self, sample: &[f64]) -> f64 {
161 if sample.is_empty() {
162 return 0.0;
163 }
164
165 match self.config.statistic {
166 BootstrapStatistic::Mean => sample.iter().sum::<f64>() / sample.len() as f64,
167 BootstrapStatistic::Median => {
168 let mut sorted = sample.to_vec();
169 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
170 let mid = sorted.len() / 2;
171 if sorted.len().is_multiple_of(2) {
172 f64::midpoint(sorted[mid - 1], sorted[mid])
173 } else {
174 sorted[mid]
175 }
176 },
177 BootstrapStatistic::Std => {
178 let mean = sample.iter().sum::<f64>() / sample.len() as f64;
179 let variance: f64 = sample.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
180 / (sample.len() - 1) as f64;
181 variance.sqrt()
182 },
183 BootstrapStatistic::Var => {
184 let mean = sample.iter().sum::<f64>() / sample.len() as f64;
185 sample.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (sample.len() - 1) as f64
186 },
187 BootstrapStatistic::Percentile => {
188 let mut sorted = sample.to_vec();
189 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
190 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
192 let idx = ((self.config.percentile_value / 100.0) * (sorted.len() as f64 - 1.0))
193 .round() as usize;
194 sorted[idx.min(sorted.len() - 1)]
195 },
196 BootstrapStatistic::Min => sample.iter().copied().fold(f64::INFINITY, f64::min),
197 BootstrapStatistic::Max => sample.iter().copied().fold(f64::NEG_INFINITY, f64::max),
198 }
199 }
200
201 fn calculate_confidence_intervals(&self, distribution: &[f64]) -> Vec<ConfidenceInterval> {
203 self.config
204 .confidence_levels
205 .iter()
206 .map(|&level| {
207 let alpha = 1.0 - level;
208 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
210 let lower_idx = ((alpha / 2.0) * distribution.len() as f64) as usize;
211 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
212 let upper_idx = ((1.0 - alpha / 2.0) * distribution.len() as f64) as usize;
213
214 ConfidenceInterval::new(
215 level,
216 distribution[lower_idx.min(distribution.len() - 1)],
217 distribution[upper_idx.min(distribution.len() - 1)],
218 )
219 })
220 .collect()
221 }
222
223 #[must_use]
225 pub const fn config(&self) -> &BootstrapConfig {
226 &self.config
227 }
228}
229
#[cfg(test)]
mod engine_tests {
    use super::*;

    /// The values 1.0 through 10.0, shared by several tests.
    fn one_to_ten() -> Vec<f64> {
        vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
    }

    /// The values 1.0 through 5.0, shared by several tests.
    fn one_to_five() -> Vec<f64> {
        vec![1.0, 2.0, 3.0, 4.0, 5.0]
    }

    /// Builds an engine for `config` and runs one analysis, panicking on any error.
    fn run(config: BootstrapConfig) -> BootstrapResult {
        BootstrapEngine::new(config).unwrap().analyze().unwrap()
    }

    /// Finds the interval whose level is within 0.01 of `level`, panicking if absent.
    fn interval_at(result: &BootstrapResult, level: f64) -> &ConfidenceInterval {
        result
            .confidence_intervals
            .iter()
            .find(|ci| (ci.level - level).abs() < 0.01)
            .unwrap()
    }

    /// The default statistic on 1..=10 reproduces the sample mean, and the bootstrap
    /// mean stays close to it.
    #[test]
    fn test_bootstrap_mean() {
        let result = run(BootstrapConfig::new()
            .with_data(one_to_ten())
            .with_iterations(5000)
            .with_seed(12345));

        let original_error = (result.original_estimate - 5.5).abs();
        assert!(original_error < 0.01, "Original mean should be 5.5");

        let bootstrap_error = (result.bootstrap_mean - 5.5).abs();
        assert!(bootstrap_error < 0.5, "Bootstrap mean should be close to 5.5");

        assert!(!result.confidence_intervals.is_empty());
    }

    /// The median of an even-length sample is the midpoint of the two central values.
    #[test]
    fn test_bootstrap_median() {
        let result = run(BootstrapConfig::new()
            .with_data(one_to_ten())
            .with_statistic(BootstrapStatistic::Median)
            .with_iterations(5000)
            .with_seed(12345));

        let original_error = (result.original_estimate - 5.5).abs();
        assert!(original_error < 0.01, "Original median should be 5.5");
    }

    /// One interval is produced per requested level, and a higher level is never
    /// narrower than a lower one.
    #[test]
    fn test_confidence_intervals() {
        let result = run(BootstrapConfig::new()
            .with_data(one_to_ten())
            .with_confidence_levels(vec![0.90, 0.95])
            .with_iterations(10000)
            .with_seed(12345));

        assert_eq!(result.confidence_intervals.len(), 2);

        let ci_90 = interval_at(&result, 0.90);
        let ci_95 = interval_at(&result, 0.95);

        assert!(
            ci_95.width() >= ci_90.width(),
            "95% CI should be >= 90% CI width"
        );
    }

    /// Two engines built from identical seeded configurations agree on the bootstrap mean.
    #[test]
    fn test_reproducibility() {
        let seeded_config = || {
            BootstrapConfig::new()
                .with_data(one_to_five())
                .with_iterations(1000)
                .with_seed(42)
        };

        let first = run(seeded_config());
        let second = run(seeded_config());

        let difference = (first.bootstrap_mean - second.bootstrap_mean).abs();
        assert!(difference < 0.0001, "Same seed should produce same results");
    }

    /// Sanity-checks the mean and standard error on a mixed-sign data set.
    ///
    /// NOTE(review): the test name points at R's `boot` package, but no reference
    /// values from R are visible here; only loose bounds are asserted.
    #[test]
    fn test_r_boot_equivalence() {
        let result = run(BootstrapConfig::new()
            .with_data(vec![5.0, -2.0, 8.0, 3.0, -5.0, 12.0, 1.0, -1.0, 6.0, 4.0])
            .with_iterations(10000)
            .with_seed(12345)
            .with_confidence_levels(vec![0.95]));

        let original_error = (result.original_estimate - 3.1).abs();
        assert!(original_error < 0.01, "Original mean should be 3.1");

        let bootstrap_error = (result.bootstrap_mean - 3.1).abs();
        assert!(bootstrap_error < 1.0, "Bootstrap mean should be close to 3.1");

        let std_error = result.bootstrap_std_error;
        assert!(
            std_error > 0.0 && std_error < 5.0,
            "Standard error should be reasonable"
        );
    }

    /// The YAML export names the main result fields.
    #[test]
    fn test_yaml_export() {
        let result = run(BootstrapConfig::new()
            .with_data(one_to_five())
            .with_iterations(100)
            .with_seed(42));
        let yaml = result.to_yaml();

        for key in ["original_estimate", "bootstrap_mean", "confidence_intervals"] {
            assert!(yaml.contains(key));
        }
    }
}