use num_traits::{Float, FromPrimitive};
use rand::thread_rng;
use statrs::distribution::{Normal, ContinuousCDF};
use crate::{CDF, Flip, Flipper, Mean, Re, Sample, SignBitFlip, Statistic};
/// Configuration for a resampling-based test of a sample mean.
///
/// Both fields are public, so a value can also be built with a struct
/// literal instead of going through a constructor.
#[derive(Clone, Copy, Debug)]
pub struct MeanTest<F> {
    /// Value subtracted from every observation before the statistic is computed.
    pub null_mean: F,
    /// Upper bound on the number of resamples drawn when computing the p-value.
    pub n_permutations: usize,
}
/// Outcome of a test: the statistic observed on the data and its p-value.
///
/// The type parameter carries no trait bound on the struct itself; the impl
/// blocks that need numeric behaviour state their own bounds.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TestResult<F> {
    /// Statistic computed on the original (un-resampled) data.
    pub observed_statistic: F,
    /// P-value attached to `observed_statistic`.
    pub p_value: F,
}
/// Runs the mean test described by a [`MeanTest`] on any slice-like data.
impl<D, F> Statistic<D, TestResult<F>> for MeanTest<F>
where
    D: AsRef<[F]> + Clone,
    F: Float + FromPrimitive + Copy,
    SignBitFlip: Flip<F>,
{
    /// Computes the observed mean of `data - null_mean` and a two-sided p-value.
    ///
    /// The p-value is `(k + 1) / (n_permutations + 1)`, where `k` is the number
    /// of resampled means whose absolute value is at least the absolute
    /// observed mean. Resamples come from `Flipper::sign(thread_rng())`
    /// (presumably random sign flips of the centred values — confirm against
    /// `Flipper`), so repeated calls are not expected to be bit-for-bit
    /// reproducible.
    ///
    /// Empty input short-circuits to a statistic of zero and a p-value of one.
    ///
    /// # Panics
    ///
    /// Panics if a count cannot be converted to `F`.
    fn compute(&self, data: &D) -> TestResult<F> {
        let data_slice = data.as_ref();
        if data_slice.is_empty() {
            return TestResult {
                observed_statistic: F::zero(),
                p_value: F::one(),
            };
        }

        // Shift the data so the null hypothesis corresponds to a mean of zero.
        let centered: Sample<F> = data_slice
            .iter()
            .map(|&x| x - self.null_mean)
            .collect();

        let observed_stat = Mean.compute(&centered);
        let observed_abs = observed_stat.abs();

        let flipper = Flipper::sign(thread_rng());

        // Count the extreme resamples lazily; the individual resampled
        // statistics are never needed again, so they are not stored.
        let extreme_count = flipper
            .re(&centered)
            .map(|resample| -> F { Mean.compute(&resample) })
            .take(self.n_permutations)
            .filter(|&stat| stat.abs() >= observed_abs)
            .count();

        // The +1 in numerator and denominator keeps the p-value strictly
        // positive even when no resample is as extreme as the observation.
        let p_value = F::from(extreme_count + 1).expect("extreme_count + 1 fits in float")
            / F::from(self.n_permutations + 1).expect("n_permutations + 1 fits in float");

        TestResult {
            observed_statistic: observed_stat,
            p_value,
        }
    }
}
/// Constructors for [`MeanTest`].
impl<F: Float + FromPrimitive> MeanTest<F> {
    /// Builds a test against `null_mean` with an explicit resample count.
    ///
    /// # Panics
    ///
    /// Panics if `n_permutations` is zero.
    pub fn new(null_mean: F, n_permutations: usize) -> Self {
        assert!(n_permutations > 0, "n_permutations must be positive");
        Self {
            n_permutations,
            null_mean,
        }
    }

    /// Builds a test whose resample count is derived from a target accuracy.
    ///
    /// The count is `ceil(z² · 0.25 / accuracy²)`, where `z` is the standard
    /// normal quantile at `1 - (1 - confidence_level) / 2`, and the result is
    /// clamped to the range `100..=10_000_000`. The `0.25` factor is presumably
    /// the worst-case variance `p(1 - p)` of an estimated proportion at
    /// `p = 0.5` — confirm against the intended derivation.
    ///
    /// # Panics
    ///
    /// Panics unless `0 < accuracy < 0.5` and `0.5 < confidence_level < 1`.
    pub fn from_absolute_accuracy(
        null_mean: F,
        accuracy: f64,
        confidence_level: f64,
    ) -> Self {
        const MIN_PERMUTATIONS: usize = 100;
        const MAX_PERMUTATIONS: usize = 10_000_000;

        assert!(
            0.0 < accuracy && accuracy < 0.5,
            "accuracy must be in (0, 0.5), got {}",
            accuracy
        );
        assert!(
            0.5 < confidence_level && confidence_level < 1.0,
            "confidence_level must be in (0.5, 1.0), got {}",
            confidence_level
        );

        let tail_mass = 1.0 - confidence_level;
        let standard_normal = Normal::new(0.0, 1.0).expect("Valid N(0,1) distribution");
        let z = standard_normal.inverse_cdf(1.0 - tail_mass / 2.0);

        // Float-to-integer `as` saturates, so an enormous requirement simply
        // ends up at the upper clamp bound.
        let required = ((z * z * 0.25) / (accuracy * accuracy)).ceil() as usize;

        Self {
            null_mean,
            n_permutations: required.clamp(MIN_PERMUTATIONS, MAX_PERMUTATIONS),
        }
    }

    /// Shorthand for a test against a null mean of zero at a confidence level of `0.95`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`MeanTest::from_absolute_accuracy`].
    pub fn zero(accuracy: f64) -> Self
    where
        F: FromPrimitive,
    {
        let null_mean = F::zero();
        Self::from_absolute_accuracy(null_mean, accuracy, 0.95)
    }
}