1use ndarray::Array2;
9use std::collections::HashMap;
10use std::hash::Hash;
11
/// Between- and within-group sums of squares together with their degrees of
/// freedom, as produced by [`bwss`] and [`bwss_matrix`].
///
/// Every field is `NaN` when there were no non-`NaN` observations to
/// summarise. Because the fields are `f64`, such a result never compares equal
/// to anything (not even a copy of itself) under `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Bwss {
    /// Between-group sum of squares: sum over groups of
    /// `size * (group_mean - grand_mean)^2`.
    pub bss: f64,
    /// Within-group sum of squares: sum over groups of the squared deviations
    /// of each observation from its group's mean.
    pub wss: f64,
    /// Between-group degrees of freedom: number of groups minus one.
    pub bdf: f64,
    /// Within-group degrees of freedom: number of observations minus number of
    /// groups.
    pub wdf: f64,
}
21
22impl Bwss {
23 fn na() -> Self {
24 Bwss {
25 bss: f64::NAN,
26 wss: f64::NAN,
27 bdf: f64::NAN,
28 wdf: f64::NAN,
29 }
30 }
31}
32
33fn bwss_core(groups: &[Vec<f64>]) -> Bwss {
34 let k = groups.len();
35 let mut ns = Vec::with_capacity(k);
36 let mut means = Vec::with_capacity(k);
37 let mut within = 0.0_f64;
38 for g in groups {
39 let n = g.len() as f64;
40 let mean = g.iter().sum::<f64>() / n;
41 within += g.iter().map(|&v| (v - mean).powi(2)).sum::<f64>();
42 ns.push(n);
43 means.push(mean);
44 }
45 let total_n: f64 = ns.iter().sum();
46 let grand = ns.iter().zip(&means).map(|(&n, &m)| n * m).sum::<f64>() / total_n;
47 let bss = ns
48 .iter()
49 .zip(&means)
50 .map(|(&n, &m)| n * (m - grand).powi(2))
51 .sum();
52 Bwss {
53 bss,
54 wss: within,
55 bdf: k as f64 - 1.0,
56 wdf: total_n - k as f64,
57 }
58}
59
60pub fn bwss<T: Eq + Hash + Clone>(x: &[f64], group: &[T]) -> Bwss {
63 assert_eq!(x.len(), group.len(), "x and group lengths differ");
64 let mut idx: HashMap<T, usize> = HashMap::new();
65 let mut groups: Vec<Vec<f64>> = Vec::new();
66 for (&xi, g) in x.iter().zip(group) {
67 if xi.is_nan() {
68 continue;
69 }
70 let i = *idx.entry(g.clone()).or_insert_with(|| {
71 groups.push(Vec::new());
72 groups.len() - 1
73 });
74 groups[i].push(xi);
75 }
76 if groups.is_empty() {
77 return Bwss::na();
78 }
79 bwss_core(&groups)
80}
81
82pub fn bwss_matrix(x: &Array2<f64>) -> Bwss {
85 let mut groups: Vec<Vec<f64>> = Vec::new();
86 for col in x.columns() {
87 let vals: Vec<f64> = col.iter().copied().filter(|v| !v.is_nan()).collect();
88 if !vals.is_empty() {
89 groups.push(vals);
90 }
91 }
92 if groups.is_empty() {
93 return Bwss::na();
94 }
95 bwss_core(&groups)
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Returns `true` when each field of `b`, taken in the order
    /// `[bss, wss, bdf, wdf]`, is within `tol` absolute plus `tol` relative
    /// (to the expected value) of `want`. A `NaN` on either side is never close.
    fn close(b: &Bwss, want: [f64; 4], tol: f64) -> bool {
        let got = [b.bss, b.wss, b.bdf, b.wdf];
        got.iter()
            .zip(want)
            .all(|(&x, y)| (x - y).abs() <= tol + tol * y.abs())
    }

    /// Returns `true` when every field of `b` is `NaN`.
    fn all_nan(b: &Bwss) -> bool {
        b.bss.is_nan() && b.wss.is_nan() && b.bdf.is_nan() && b.wdf.is_nan()
    }

    #[test]
    fn grouped_matches_r() {
        // Groups a = [1, 2, 3], b = [10, 11, 12], c = [20, 21]; grand mean 10.
        let x = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 20.0, 21.0];
        let g = ["a", "a", "a", "b", "b", "b", "c", "c"];
        assert!(close(&bwss(&x, &g), [415.5, 4.5, 2.0, 5.0], 1e-12));
    }

    #[test]
    fn grouped_interleaved_labels() {
        // Same groups as `grouped_matches_r`, but the labels are not contiguous.
        let x = [1.0, 10.0, 2.0, 11.0, 3.0, 12.0, 20.0, 21.0];
        let g = ["a", "b", "a", "b", "a", "b", "c", "c"];
        assert!(close(&bwss(&x, &g), [415.5, 4.5, 2.0, 5.0], 1e-12));
    }

    #[test]
    fn grouped_owned_string_keys() {
        let x = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 20.0, 21.0];
        let g: Vec<String> = ["a", "a", "a", "b", "b", "b", "c", "c"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        assert!(close(&bwss(&x, &g), [415.5, 4.5, 2.0, 5.0], 1e-12));
    }

    #[test]
    fn grouped_with_na_matches_r() {
        // The NaN is dropped, leaving a = [1, 2] and b = [4].
        let x = [1.0, 2.0, f64::NAN, 4.0];
        let g = ["a", "a", "b", "b"];
        assert!(close(
            &bwss(&x, &g),
            [4.16666666666667, 0.5, 1.0, 1.0],
            1e-12
        ));
    }

    #[test]
    fn single_group_has_zero_between() {
        // One group [1, 2, 3]: mean 2, so bss = 0 and wss = 1 + 0 + 1.
        assert!(close(
            &bwss(&[1.0, 2.0, 3.0], &[7, 7, 7]),
            [0.0, 2.0, 0.0, 2.0],
            1e-12
        ));
    }

    #[test]
    fn group_with_only_nan_is_dropped() {
        // Label "a" has no usable value, so only "b" = [1, 3] counts.
        let b = bwss(&[f64::NAN, 1.0, 3.0], &["a", "b", "b"]);
        assert!(close(&b, [0.0, 2.0, 0.0, 1.0], 1e-12));
    }

    #[test]
    #[should_panic(expected = "x and group lengths differ")]
    fn length_mismatch_panics() {
        let _ = bwss(&[1.0, 2.0], &["a"]);
    }

    #[test]
    fn matrix_matches_r() {
        // Columns are the groups of `grouped_matches_r`; the NaN is dropped.
        let m = array![[1.0, 10.0, 20.0], [2.0, 11.0, 21.0], [3.0, 12.0, f64::NAN]];
        assert!(close(&bwss_matrix(&m), [415.5, 4.5, 2.0, 5.0], 1e-12));
    }

    #[test]
    fn matrix_drops_all_na_column() {
        // The all-NaN middle column contributes no group: k = 2, n = 6.
        let m = array![
            [1.0, f64::NAN, 5.0],
            [2.0, f64::NAN, 7.0],
            [3.0, f64::NAN, 9.0]
        ];
        assert!(close(&bwss_matrix(&m), [37.5, 10.0, 1.0, 4.0], 1e-12));
    }

    #[test]
    fn matrix_all_na_is_nan() {
        let m = array![[f64::NAN, f64::NAN], [f64::NAN, f64::NAN]];
        assert!(all_nan(&bwss_matrix(&m)));
    }

    #[test]
    fn empty_is_nan() {
        let b = bwss(&[f64::NAN, f64::NAN], &["a", "b"]);
        assert!(all_nan(&b));
    }

    #[test]
    fn empty_input_is_nan() {
        assert!(all_nan(&bwss::<&str>(&[], &[])));
    }
}
153}