use std::iter::IntoIterator;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use statrs::distribution::{ChiSquared, ContinuousCDF, Normal};
use crate::{Computation, Error, Float};
/// Pearson chi-squared test for normality.
///
/// The non-NaN observations are mapped through the CDF of a normal distribution
/// whose mean and (sample) standard deviation are estimated from the data, then
/// counted into `num_classes` equal-probability classes. The statistic is
/// `sum((observed - expected)^2 / expected)` with `expected = n / num_classes`,
/// and the p-value is the upper tail (`sf`) of a chi-squared distribution.
///
/// # Arguments
///
/// * `data` - the sample; NaN values are discarded before any computation.
/// * `n_classes` - number of classes; defaults to `ceil(2 * n^0.4)` where `n`
///   is the number of non-NaN observations.
/// * `adjust` - when `true`, two extra degrees of freedom are subtracted
///   (`df = num_classes - 3`); otherwise `df = num_classes - 1`.
///
/// # Errors
///
/// * [`Error::InsufficientSampleSize`] if fewer than 2 non-NaN values remain.
/// * [`Error::ZeroRange`] if the sample standard deviation is below `T::epsilon()`.
/// * [`Error::Other`] if the resulting degrees of freedom are less than 1.
/// * Errors from constructing the `statrs` distributions are propagated with `?`.
pub fn pearson_chi_squared<T: Float, I: IntoIterator<Item = T>>(
    data: I,
    n_classes: Option<usize>,
    adjust: bool,
) -> Result<Computation<T>, Error> {
    let clean_data: Vec<T> = data.into_iter().filter(|v| !v.is_nan()).collect();
    let n = clean_data.len();
    if n < 2 {
        return Err(Error::InsufficientSampleSize {
            given: n,
            needed: 2,
        });
    }
    let n_t = T::from(n).unwrap();
    let num_classes = n_classes.unwrap_or_else(|| (2.0 * (n as f64).powf(0.4)).ceil() as usize);
    let num_classes_t = T::from(num_classes).unwrap();

    let mean = iter_if_parallel!(&clean_data).copied().sum::<T>() / n_t;
    // Sample (n - 1 denominator) variance.
    let variance = iter_if_parallel!(&clean_data).map(|&x| (x - mean).powi(2)).sum::<T>()
        / T::from(n - 1).unwrap();
    let std_dev = variance.sqrt();
    if std_dev < T::epsilon() {
        return Err(Error::ZeroRange);
    }
    let normal_dist = Normal::new(mean.to_f64().unwrap(), std_dev.to_f64().unwrap())?;

    // Validate degrees of freedom before binning. df >= 1 implies num_classes >= 2,
    // so the class helper and `expected_count` below never see zero classes.
    let dfd = if adjust { 2 } else { 0 };
    let df = num_classes.saturating_sub(dfd).saturating_sub(1);
    if df < 1 {
        return Err(Error::Other("Degrees of freedom is less than 1".to_string()));
    }

    #[cfg(feature = "parallel")]
    let counts = clean_data
        .par_iter()
        .fold(
            || vec![0usize; num_classes],
            |mut local_counts, &x| {
                let p = normal_dist.cdf(x.to_f64().unwrap());
                if let Some(i) = pearson_class_index(p, num_classes) {
                    local_counts[i] += 1;
                }
                local_counts
            },
        )
        .reduce(
            || vec![0usize; num_classes],
            |mut a, b| {
                a.iter_mut().zip(b).for_each(|(acc, v)| *acc += v);
                a
            },
        );
    #[cfg(not(feature = "parallel"))]
    let counts = clean_data.iter().fold(vec![0usize; num_classes], |mut local_counts, &x| {
        let p = normal_dist.cdf(x.to_f64().unwrap());
        if let Some(i) = pearson_class_index(p, num_classes) {
            local_counts[i] += 1;
        }
        local_counts
    });

    let expected_count = n_t / num_classes_t;
    let chi_sq_stat = counts
        .iter()
        .map(|&count| {
            let diff = T::from(count).unwrap() - expected_count;
            diff.powi(2) / expected_count
        })
        .sum::<T>();

    let chi_sq_dist = ChiSquared::new(df as f64)?;
    let p_value = T::from(chi_sq_dist.sf(chi_sq_stat.to_f64().unwrap())).unwrap();
    Ok(Computation {
        statistic: chi_sq_stat,
        p_value,
    })
}

/// Maps a CDF value `p` in `[0, 1]` to a zero-based class index in `0..num_classes`,
/// using equal-probability classes of width `1 / num_classes` (bin = `floor(1 + k * p)`).
///
/// `p == 1.0` (reachable in `f64` for far-right outliers) would land one past the
/// last class; it is clamped into the last class so no observation is silently
/// dropped. Returns `None` only when the computed bin is 0 (e.g. a NaN `p`).
///
/// Callers must ensure `num_classes >= 1`.
fn pearson_class_index(p: f64, num_classes: usize) -> Option<usize> {
    let bin = (1.0 + num_classes as f64 * p).floor() as usize;
    bin.checked_sub(1).map(|i| i.min(num_classes - 1))
}