use ndarray::Array2;
use std::collections::HashMap;
use std::hash::Hash;
/// Between/within sum-of-squares decomposition of grouped observations.
///
/// All four quantities are stored as `f64` so that an undefined result can be
/// represented by NaN in every field.
#[derive(Debug, Clone, Copy)]
pub struct Bwss {
    /// Between-group sum of squares.
    pub bss: f64,
    /// Within-group sum of squares.
    pub wss: f64,
    /// Between-group degrees of freedom (number of groups minus one).
    pub bdf: f64,
    /// Within-group degrees of freedom (number of observations minus number of groups).
    pub wdf: f64,
}

impl Bwss {
    /// Builds the "not available" result: every field is NaN.
    fn na() -> Self {
        const MISSING: f64 = f64::NAN;
        Self {
            bss: MISSING,
            wss: MISSING,
            bdf: MISSING,
            wdf: MISSING,
        }
    }
}
/// Computes the between/within sum-of-squares decomposition for already
/// partitioned data.
///
/// Each element of `groups` holds the observations of one group. The callers
/// in this file only pass non-empty groups and a non-empty slice; an empty
/// group would make its mean `0.0 / 0.0` (NaN), which then propagates into
/// the result.
///
/// Returns a [`Bwss`] with
/// * `wss`: sum over groups of squared deviations from the group mean,
/// * `bss`: sum over groups of `n * (group mean - grand mean)^2`,
/// * `bdf`: number of groups minus one,
/// * `wdf`: total number of observations minus number of groups.
fn bwss_core(groups: &[Vec<f64>]) -> Bwss {
    let group_count = groups.len();

    // One `(size, mean, squared-deviation sum)` triple per group.
    let stats: Vec<(f64, f64, f64)> = groups
        .iter()
        .map(|values| {
            let size = values.len() as f64;
            let mean = values.iter().sum::<f64>() / size;
            let ss = values.iter().map(|&v| (v - mean).powi(2)).sum::<f64>();
            (size, mean, ss)
        })
        .collect();

    // Accumulate in group order, starting from 0.0.
    let wss = stats.iter().fold(0.0_f64, |acc, &(_, _, ss)| acc + ss);

    let total_n: f64 = stats.iter().map(|&(size, _, _)| size).sum();

    // The grand mean is the size-weighted mean of the group means.
    let weighted_sum: f64 = stats.iter().map(|&(size, mean, _)| size * mean).sum();
    let grand = weighted_sum / total_n;

    let bss: f64 = stats
        .iter()
        .map(|&(size, mean, _)| size * (mean - grand).powi(2))
        .sum();

    Bwss {
        bss,
        wss,
        bdf: group_count as f64 - 1.0,
        wdf: total_n - group_count as f64,
    }
}
/// Computes the between/within sum-of-squares decomposition of `x`, grouped
/// by the parallel slice of labels `group`.
///
/// `x[i]` belongs to the group labelled `group[i]`. Observations whose value
/// is NaN are skipped before grouping, so a label that only ever appears with
/// NaN values does not form a group. Groups are numbered in order of first
/// appearance.
///
/// Returns an all-NaN [`Bwss`] when no non-NaN observation remains.
///
/// # Panics
///
/// Panics if `x` and `group` have different lengths.
pub fn bwss<T: Eq + Hash + Clone>(x: &[f64], group: &[T]) -> Bwss {
    assert_eq!(x.len(), group.len(), "x and group lengths differ");

    // Maps each label to the index of its bucket in `groups`.
    let mut slot_of: HashMap<T, usize> = HashMap::new();
    let mut groups: Vec<Vec<f64>> = Vec::new();

    for (&value, label) in x.iter().zip(group) {
        if value.is_nan() {
            continue;
        }
        // Look up by reference first: the label is cloned only the first time
        // it is seen, not once per observation.
        let slot = match slot_of.get(label).copied() {
            Some(slot) => slot,
            None => {
                let slot = groups.len();
                groups.push(Vec::new());
                slot_of.insert(label.clone(), slot);
                slot
            }
        };
        groups[slot].push(value);
    }

    if groups.is_empty() {
        return Bwss::na();
    }
    bwss_core(&groups)
}
/// Computes the between/within sum-of-squares decomposition of a matrix,
/// treating each column as one group.
///
/// NaN entries are removed from every column, and a column left with no
/// values is dropped entirely. Returns an all-NaN [`Bwss`] when no column
/// has any non-NaN value.
pub fn bwss_matrix(x: &Array2<f64>) -> Bwss {
    let groups: Vec<Vec<f64>> = x
        .columns()
        .into_iter()
        .map(|column| {
            column
                .iter()
                .copied()
                .filter(|v| !v.is_nan())
                .collect::<Vec<f64>>()
        })
        .filter(|kept| !kept.is_empty())
        .collect();

    if groups.is_empty() {
        Bwss::na()
    } else {
        bwss_core(&groups)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Returns `true` when `[bss, wss, bdf, wdf]` of `b` each match the
    /// corresponding entry of `want` within `tol + tol * |want|`.
    ///
    /// Any NaN makes the comparison fail, because `<=` is false for NaN.
    fn close(b: &Bwss, want: [f64; 4], tol: f64) -> bool {
        let got = [b.bss, b.wss, b.bdf, b.wdf];
        got.iter()
            .zip(want.iter())
            .map(|(actual, expected)| ((actual - expected).abs(), tol + tol * expected.abs()))
            .all(|(err, bound)| err <= bound)
    }

    #[test]
    fn grouped_matches_r() {
        let x = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 20.0, 21.0];
        let g = ["a", "a", "a", "b", "b", "b", "c", "c"];
        let expected = [415.5, 4.5, 2.0, 5.0];
        let result = bwss(&x, &g);
        assert!(close(&result, expected, 1e-12));
    }

    #[test]
    fn grouped_with_na_matches_r() {
        // The NaN observation is dropped, leaving groups a = [1, 2], b = [4].
        let x = [1.0, 2.0, f64::NAN, 4.0];
        let g = ["a", "a", "b", "b"];
        let expected = [4.16666666666667, 0.5, 1.0, 1.0];
        let result = bwss(&x, &g);
        assert!(close(&result, expected, 1e-12));
    }

    #[test]
    fn matrix_matches_r() {
        // Same data as `grouped_matches_r`, laid out one group per column.
        let m = array![[1.0, 10.0, 20.0], [2.0, 11.0, 21.0], [3.0, 12.0, f64::NAN]];
        let expected = [415.5, 4.5, 2.0, 5.0];
        let result = bwss_matrix(&m);
        assert!(close(&result, expected, 1e-12));
    }

    #[test]
    fn matrix_drops_all_na_column() {
        // The middle column is entirely NaN and must not count as a group.
        let m = array![
            [1.0, f64::NAN, 5.0],
            [2.0, f64::NAN, 7.0],
            [3.0, f64::NAN, 9.0]
        ];
        let expected = [37.5, 10.0, 1.0, 4.0];
        let result = bwss_matrix(&m);
        assert!(close(&result, expected, 1e-12));
    }

    #[test]
    fn empty_is_nan() {
        let b = bwss(&[f64::NAN, f64::NAN], &["a", "b"]);
        let fields = [b.bss, b.wss, b.bdf, b.wdf];
        assert!(fields.iter().all(|f| f.is_nan()));
    }
}