use rand::rngs::StdRng;
use rand::Rng;
use rand::SeedableRng;
use crate::errors::ElinorError;
use crate::errors::Result;
use crate::statistical_tests::student_t_test::compute_t_stat;
/// Outcome of a two-sided bootstrap test on paired samples.
///
/// Instances are produced by [`BootstrapTester::test`] or
/// [`BootstrapTest::from_paired_samples`]; the fields are read through the
/// accessor methods of the same names.
#[derive(Debug, Clone, Copy)]
pub struct BootstrapTest {
/// Number of bootstrap resamples that were drawn.
n_resamples: usize,
/// Seed that was actually used to initialize the random number generator.
random_state: u64,
/// Fraction of resamples whose absolute t-statistic was at least as large
/// as the absolute t-statistic of the original samples.
p_value: f64,
}
impl BootstrapTest {
    /// Runs the bootstrap test on `samples` with a default-configured
    /// [`BootstrapTester`].
    ///
    /// Each item of `samples` is one `(x, y)` pair of paired observations.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`BootstrapTester::test`] reports, e.g.
    /// [`ElinorError::InvalidArgument`] when fewer than two pairs are given.
    pub fn from_paired_samples<I: IntoIterator<Item = (f64, f64)>>(samples: I) -> Result<Self> {
        let tester = BootstrapTester::new();
        tester.test(samples)
    }

    /// Number of bootstrap resamples that were drawn.
    pub const fn n_resamples(&self) -> usize {
        self.n_resamples
    }

    /// Seed that was used to initialize the random number generator.
    pub const fn random_state(&self) -> u64 {
        self.random_state
    }

    /// Two-sided p-value estimated from the resamples.
    pub const fn p_value(&self) -> f64 {
        self.p_value
    }
}
/// Configurable runner for the two-sided paired bootstrap test.
///
/// Built with [`BootstrapTester::new`] and adjusted through the `with_*`
/// methods before calling [`BootstrapTester::test`].
#[derive(Debug, Clone, Copy)]
pub struct BootstrapTester {
/// Number of bootstrap resamples to draw per test.
n_resamples: usize,
/// Seed for the random number generator; when `None`, a fresh seed is
/// drawn from the thread-local generator on every call to `test`.
random_state: Option<u64>,
}
impl Default for BootstrapTester {
fn default() -> Self {
Self::new()
}
}
impl BootstrapTester {
pub const fn new() -> Self {
Self {
n_resamples: 10000,
random_state: None,
}
}
pub fn with_n_resamples(mut self, n_resamples: usize) -> Self {
self.n_resamples = n_resamples.max(1);
self
}
pub const fn with_random_state(mut self, random_state: u64) -> Self {
self.random_state = Some(random_state);
self
}
pub fn test<I>(&self, samples: I) -> Result<BootstrapTest>
where
I: IntoIterator<Item = (f64, f64)>,
{
let samples: Vec<f64> = samples.into_iter().map(|(x, y)| x - y).collect();
if samples.len() <= 1 {
return Err(ElinorError::InvalidArgument(
"The input must have at least two samples.".to_string(),
));
}
let random_state = self
.random_state
.map_or_else(|| rand::thread_rng().gen(), |seed| seed);
let mut rng = StdRng::seed_from_u64(random_state);
let (t_stat, mean, _) = compute_t_stat(&samples)?;
let samples: Vec<f64> = samples.iter().map(|x| x - mean).collect();
let mut count: usize = 0;
for _ in 0..self.n_resamples {
let resampled: Vec<f64> = (0..samples.len())
.map(|_| samples[rng.gen_range(0..samples.len())])
.collect();
let (resampled_t_stat, _, _) = compute_t_stat(&resampled).unwrap_or((0.0, 0.0, 0.0));
if resampled_t_stat.abs() >= t_stat.abs() {
count += 1;
}
}
let p_value = count as f64 / self.n_resamples as f64;
Ok(BootstrapTest {
n_resamples: self.n_resamples,
random_state,
p_value,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::relative_eq;

    /// Error expected whenever fewer than two paired samples are supplied.
    fn too_few_samples_error() -> ElinorError {
        ElinorError::InvalidArgument("The input must have at least two samples.".to_string())
    }

    // No samples at all must be rejected.
    #[test]
    fn test_bootstrap_test_from_samples_empty() {
        let outcome = BootstrapTest::from_paired_samples(Vec::new());
        assert_eq!(outcome.unwrap_err(), too_few_samples_error());
    }

    // A single pair is not enough either.
    #[test]
    fn test_bootstrap_test_from_samples_single() {
        let outcome = BootstrapTest::from_paired_samples(vec![(1.0, 1.0)]);
        assert_eq!(outcome.unwrap_err(), too_few_samples_error());
    }

    // Identical differences have zero variance, so the t-statistic of the
    // original samples cannot be computed.
    #[test]
    fn test_bootstrap_test_from_samples_zero_variance() {
        let outcome = BootstrapTest::from_paired_samples(vec![(1.0, 0.0), (1.0, 0.0)]);
        assert_eq!(
            outcome.unwrap_err(),
            ElinorError::Uncomputable("The variance is zero.".to_string())
        );
    }

    // The configured parameters are reported back in the result.
    #[test]
    fn test_bootstrap_tester_with_parameters() {
        let pairs = vec![(1.0, 0.0), (0.0, 1.0), (1.0, 3.0)];
        let outcome = BootstrapTester::new()
            .with_n_resamples(334)
            .with_random_state(42)
            .test(pairs)
            .unwrap();
        assert_eq!(outcome.n_resamples(), 334);
        assert_eq!(outcome.random_state(), 42);
    }

    // The same seed must yield the same p-value on every run.
    #[test]
    fn test_bootstrap_tester_with_random_state_consistency() {
        let pairs = vec![(1.0, 0.0), (0.0, 1.0), (1.0, 3.0)];
        let tester = BootstrapTester::new().with_random_state(42);
        let mut p_values: Vec<f64> = Vec::with_capacity(10);
        for _ in 0..10 {
            p_values.push(tester.test(pairs.clone()).unwrap().p_value());
        }
        let first = p_values[0];
        for &p_value in &p_values {
            assert!(relative_eq!(first, p_value));
        }
    }
}