use num_traits::{Float, FromPrimitive};
use rand::thread_rng;
use statrs::distribution::{Normal, ContinuousCDF};
use crate::{CDF, Flip, Flipper, Mean, Re, Sample, SignBitFlip, Statistic, Variance};
/// Configuration for a sign-flip resampling test of a sample's variance
/// against a hypothesised value.
///
/// Fields are public, so a value built with a struct literal bypasses the
/// validation performed by `VarianceTest::new`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct VarianceTest<F> {
    /// Variance under the null hypothesis; the test statistic is the absolute
    /// distance `|variance - null_variance|`.
    pub null_variance: F,
    /// Number of sign-flip resamples drawn when computing the p-value.
    pub n_permutations: usize,
}
/// Outcome of running a [`VarianceTest`] on a sample.
///
/// No trait bound is placed on `F` here; impls that need numeric behaviour
/// declare their own bounds.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TestResult<F> {
    /// Sample variance of the observed data (NaN when it cannot be computed).
    pub observed_statistic: F,
    /// Estimated p-value of the test.
    pub p_value: F,
}
impl<D, F> Statistic<D, TestResult<F>> for VarianceTest<F>
where
    D: AsRef<[F]> + Clone,
    F: Float + FromPrimitive + Copy,
    SignBitFlip: Flip<F>,
{
    /// Tests whether the variance of `data` differs from `self.null_variance`.
    ///
    /// The data are centred on their sample mean and `self.n_permutations`
    /// sign-flipped resamples are drawn using `thread_rng()`. The p-value is
    /// `(k + 1) / (n_permutations + 1)`, where `k` counts resamples whose
    /// deviation `|variance - null_variance|` is not smaller than the observed one.
    ///
    /// Returns `p_value = 1` and a NaN statistic when fewer than two
    /// observations are supplied. If a deviation is NaN (for example, because
    /// the data contain NaN), the comparison treats it as "at least as extreme".
    /// This gives a conservative p-value instead of a spuriously small one.
    ///
    /// NOTE(review): sign flips keep every `centered[i]^2` unchanged. Assuming
    /// `Variance` is the usual mean-subtracted sample variance, a flipped
    /// resample can never exceed the observed variance. The resampling
    /// distribution is therefore not obviously a null distribution for
    /// `null_variance` — confirm the intended design.
    fn compute(&self, data: &D) -> TestResult<F> {
        let data_slice = data.as_ref();
        let n = data_slice.len();

        // A sample variance needs at least two observations.
        if n < 2 {
            return TestResult {
                p_value: F::one(),
                observed_statistic: F::nan(),
            };
        }

        // Centre on the sample mean so the sign flips act on deviations.
        let sample_mean = Mean.compute(&data);
        let centered: Sample<F> = data_slice.iter().map(|&x| x - sample_mean).collect();

        let observed_var = Variance::default().compute(&centered);
        let observed_deviation = (observed_var - self.null_variance).abs();

        // Count extreme resamples lazily; there is no need to materialise up
        // to `n_permutations` deviations just to count them.
        let flipper = Flipper::sign(thread_rng());
        let extreme_count = flipper
            .re(&centered)
            .map(|resample| {
                let var: F = Variance::default().compute(&resample);
                (var - self.null_variance).abs()
            })
            .take(self.n_permutations)
            // "Not strictly less extreme": an unordered (NaN) comparison is
            // counted as extreme, so NaN input cannot yield a tiny p-value.
            .filter(|dev| dev.partial_cmp(&observed_deviation) != Some(std::cmp::Ordering::Less))
            .count();

        let p_value = F::from(extreme_count + 1).expect("extreme_count + 1 fits in float")
            / F::from(self.n_permutations + 1).expect("n_permutations + 1 fits in float");

        TestResult {
            p_value,
            observed_statistic: observed_var,
        }
    }
}
impl<F: Float + FromPrimitive> VarianceTest<F> {
    /// Creates a test of `null_variance` using `n_permutations` sign-flip resamples.
    ///
    /// # Panics
    ///
    /// Panics if `n_permutations` is zero.
    pub fn new(null_variance: F, n_permutations: usize) -> Self {
        assert!(n_permutations > 0, "n_permutations must be positive");
        Self {
            null_variance,
            n_permutations,
        }
    }

    /// Creates a test whose resample count targets a given absolute accuracy
    /// of the estimated p-value at the given two-sided confidence level.
    ///
    /// The resample count is clamped to `[100, 10_000_000]`.
    ///
    /// # Panics
    ///
    /// Panics if `accuracy` is not in `(0, 0.5)` or `confidence_level` is not
    /// in `(0.5, 1.0)`. NaN fails both checks.
    pub fn from_absolute_accuracy(null_variance: F, accuracy: f64, confidence_level: f64) -> Self {
        assert!(
            accuracy > 0.0 && accuracy < 0.5,
            "accuracy must be in (0, 0.5), got {}",
            accuracy
        );
        assert!(
            confidence_level > 0.5 && confidence_level < 1.0,
            "confidence_level must be in (0.5, 1.0), got {}",
            confidence_level
        );

        // Two-sided standard-normal critical value.
        let alpha = 1.0 - confidence_level;
        let z = Normal::new(0.0, 1.0)
            .expect("Valid N(0,1) distribution")
            .inverse_cdf(1.0 - alpha / 2.0);

        // The p-value is estimated as a proportion of resamples. The
        // worst-case proportion variance p(1 - p) <= 1/4 gives the sample size
        // z^2 / (4 * accuracy^2).
        let n_min = (z * z * 0.25) / (accuracy * accuracy);
        let n_permutations = (n_min.ceil() as usize).clamp(100, 10_000_000);

        // `n_permutations >= 100`, so `new`'s positivity check always holds.
        Self::new(null_variance, n_permutations)
    }

    /// Tests a null variance of 1 at 95% confidence with the given absolute
    /// p-value accuracy.
    ///
    /// # Panics
    ///
    /// Panics if `accuracy` is not in `(0, 0.5)`.
    pub fn unit(accuracy: f64) -> Self {
        Self::from_absolute_accuracy(F::one(), accuracy, 0.95)
    }
}