use std::iter::IntoIterator;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use statrs::distribution::{ContinuousCDF, Normal};
use crate::{Computation, Error, Float};
/// Runs a D'Agostino-style normality test driven by the sample skewness.
///
/// The input is collected into a vector, the sample skewness `m3 / m2^1.5` is
/// computed from the second and third central moments, and that skewness is
/// mapped to a score `z` that is evaluated against the standard normal
/// distribution. The returned [`Computation`] carries `z` as `statistic` and
/// the two-sided tail probability of `z` as `p_value`.
///
/// NOTE(review): despite the `k_squared` name, no fourth-moment term is
/// computed in this function; the statistic is the skewness score only.
///
/// # Errors
///
/// * [`Error::InsufficientSampleSize`] when fewer than 8 values are supplied.
/// * [`Error::ExcessiveSampleSize`] when more than 46340 values are supplied.
/// * [`Error::ContainsNaN`] when any value is NaN.
/// * [`Error::ZeroRange`] when the sum of squared deviations from the mean is
///   smaller than `T::epsilon()`.
/// * Any error raised while constructing the standard normal distribution is
///   propagated.
///
/// # Panics
///
/// Panics if a numeric conversion into or out of `T` fails.
pub fn dagostino_k_squared<T: Float, I: IntoIterator<Item = T>>(
    data: I,
) -> Result<Computation<T>, Error> {
    // Lifts a floating-point literal into the working float type.
    let lit = |value: f64| T::from(value).unwrap();

    let samples: Vec<T> = data.into_iter().collect();

    // Sample-size guards.
    // NOTE(review): 46340 equals floor(sqrt(2^31 - 1)); presumably inherited
    // from a reference implementation - confirm.
    let count = match samples.len() {
        len if len < 8 => {
            return Err(Error::InsufficientSampleSize {
                given: len,
                needed: 8,
            })
        }
        len if len > 46340 => {
            return Err(Error::ExcessiveSampleSize {
                given: len,
                needed: 46340,
            })
        }
        len => len,
    };

    if samples.iter().copied().any(|value| value.is_nan()) {
        return Err(Error::ContainsNaN);
    }

    let count_t = T::from(count).unwrap();
    let total = iter_if_parallel!(&samples).copied().sum::<T>();
    let mean = total / count_t;

    // Accumulate the sums of squared and cubed deviations from the mean in a
    // single pass over the data.
    #[cfg(feature = "parallel")]
    let (sq_dev_total, cube_dev_total) = samples
        .par_iter()
        .map(|&value| {
            let dev = value - mean;
            (dev.powi(2), dev.powi(3))
        })
        .reduce(
            || (T::zero(), T::zero()),
            |lhs, rhs| (lhs.0 + rhs.0, lhs.1 + rhs.1),
        );
    #[cfg(not(feature = "parallel"))]
    let (sq_dev_total, cube_dev_total) = {
        let mut sq_acc = T::zero();
        let mut cube_acc = T::zero();
        for &value in &samples {
            let dev = value - mean;
            sq_acc = sq_acc + dev.powi(2);
            cube_acc = cube_acc + dev.powi(3);
        }
        (sq_acc, cube_acc)
    };

    // Without spread the skewness below would divide by (nearly) zero.
    if sq_dev_total < T::epsilon() {
        return Err(Error::ZeroRange);
    }

    // Sample skewness: third central moment over the second raised to 3/2.
    let third_moment = cube_dev_total / count_t;
    let second_moment = sq_dev_total / count_t;
    let skewness = third_moment / second_moment.powf(lit(1.5));

    // Rescale the skewness by a factor that depends only on the sample size.
    let scale_num = (count_t + T::one()) * (count_t + lit(3.0));
    let scale_den = lit(6.0) * (count_t - lit(2.0));
    let y = skewness * (scale_num / scale_den).sqrt();

    // Sample-size-dependent coefficient feeding the transformation parameters.
    let count_sq = count_t * count_t;
    let poly = count_sq + lit(27.0) * count_t - lit(70.0);
    let beta2_num = lit(3.0) * poly * (count_t + T::one()) * (count_t + lit(3.0));
    let beta2_den = (count_t - lit(2.0))
        * (count_t + lit(5.0))
        * (count_t + lit(7.0))
        * (count_t + lit(9.0));
    let beta2 = beta2_num / beta2_den;

    let w_squared = (lit(2.0) * (beta2 - T::one())).sqrt() - T::one();
    let w = w_squared.sqrt();
    let delta = T::one() / w.ln().sqrt();
    let alpha = (lit(2.0) / (w_squared - T::one())).sqrt();

    // z = delta * asinh(y / alpha), with the inverse hyperbolic sine written
    // out as ln(x + sqrt(x^2 + 1)).
    let ratio = y / alpha;
    let z = delta * (ratio + (ratio.powi(2) + T::one()).sqrt()).ln();

    // Two-sided p-value: double the upper tail and fold values above one back
    // into range, which covers negative `z`.
    let standard_normal = Normal::new(0.0, 1.0)?;
    let doubled_tail = lit(2.0) * T::from(standard_normal.sf(z.to_f64().unwrap())).unwrap();
    let p_value = if doubled_tail > T::one() {
        lit(2.0) - doubled_tail
    } else {
        doubled_tail
    };

    Ok(Computation {
        statistic: z,
        p_value,
    })
}