Skip to main content

limma/
bwss.rs

1//! Between- and within-group sums of squares.
2//!
3//! Pure-Rust port of limma's `bwss.R` ([`bwss`] and [`bwss_matrix`]): a one-way
4//! decomposition of total variation into between-group and within-group sums of
5//! squares, with the associated degrees of freedom. `NaN` entries are dropped
6//! (matching R's `na.rm`); groups left empty after dropping are discarded.
7
8use ndarray::Array2;
9use std::collections::HashMap;
10use std::hash::Hash;
11
/// Outcome of a one-way between/within sums-of-squares decomposition.
///
/// When there is no data to decompose, every field is `NaN`.
#[derive(Debug, Clone, Copy)]
pub struct Bwss {
    /// Between-group sum of squares.
    pub bss: f64,
    /// Within-group sum of squares.
    pub wss: f64,
    /// Degrees of freedom associated with `bss`.
    pub bdf: f64,
    /// Degrees of freedom associated with `wss`.
    pub wdf: f64,
}

impl Bwss {
    /// The "no data" result: all four fields are `NaN`.
    fn na() -> Self {
        const MISSING: f64 = f64::NAN;
        Self {
            bss: MISSING,
            wss: MISSING,
            bdf: MISSING,
            wdf: MISSING,
        }
    }
}
32
33fn bwss_core(groups: &[Vec<f64>]) -> Bwss {
34    let k = groups.len();
35    let mut ns = Vec::with_capacity(k);
36    let mut means = Vec::with_capacity(k);
37    let mut within = 0.0_f64;
38    for g in groups {
39        let n = g.len() as f64;
40        let mean = g.iter().sum::<f64>() / n;
41        within += g.iter().map(|&v| (v - mean).powi(2)).sum::<f64>();
42        ns.push(n);
43        means.push(mean);
44    }
45    let total_n: f64 = ns.iter().sum();
46    let grand = ns.iter().zip(&means).map(|(&n, &m)| n * m).sum::<f64>() / total_n;
47    let bss = ns
48        .iter()
49        .zip(&means)
50        .map(|(&n, &m)| n * (m - grand).powi(2))
51        .sum();
52    Bwss {
53        bss,
54        wss: within,
55        bdf: k as f64 - 1.0,
56        wdf: total_n - k as f64,
57    }
58}
59
60/// `bwss(x, group)`. Decompose the variation of `x` across the levels of
61/// `group`. Group labels are generic; their order does not affect the result.
62pub fn bwss<T: Eq + Hash + Clone>(x: &[f64], group: &[T]) -> Bwss {
63    assert_eq!(x.len(), group.len(), "x and group lengths differ");
64    let mut idx: HashMap<T, usize> = HashMap::new();
65    let mut groups: Vec<Vec<f64>> = Vec::new();
66    for (&xi, g) in x.iter().zip(group) {
67        if xi.is_nan() {
68            continue;
69        }
70        let i = *idx.entry(g.clone()).or_insert_with(|| {
71            groups.push(Vec::new());
72            groups.len() - 1
73        });
74        groups[i].push(xi);
75    }
76    if groups.is_empty() {
77        return Bwss::na();
78    }
79    bwss_core(&groups)
80}
81
82/// `bwss.matrix(x)`. Treat each column of `x` as a group and decompose the
83/// variation between and within columns. All-`NaN` columns are dropped.
84pub fn bwss_matrix(x: &Array2<f64>) -> Bwss {
85    let mut groups: Vec<Vec<f64>> = Vec::new();
86    for col in x.columns() {
87        let vals: Vec<f64> = col.iter().copied().filter(|v| !v.is_nan()).collect();
88        if !vals.is_empty() {
89            groups.push(vals);
90        }
91    }
92    if groups.is_empty() {
93        return Bwss::na();
94    }
95    bwss_core(&groups)
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Reports whether `b` agrees with `want`, given as `[bss, wss, bdf, wdf]`.
    ///
    /// Each field may differ from its target by at most `tol + tol * |target|`.
    /// A `NaN` on either side never passes.
    fn close(b: &Bwss, want: [f64; 4], tol: f64) -> bool {
        let got = [b.bss, b.wss, b.bdf, b.wdf];
        for (actual, expected) in got.iter().copied().zip(want) {
            let slack = tol + tol * expected.abs();
            // Kept as a positive `<=` test so that a NaN comparison fails.
            let within_slack = (actual - expected).abs() <= slack;
            if !within_slack {
                return false;
            }
        }
        true
    }

    #[test]
    fn grouped_matches_r() {
        // R reference: bwss(c(1,2,3,10,11,12,20,21), c(a,a,a,b,b,b,c,c)).
        let values = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 20.0, 21.0];
        let labels = ["a", "a", "a", "b", "b", "b", "c", "c"];
        let result = bwss(&values, &labels);
        assert!(close(&result, [415.5, 4.5, 2.0, 5.0], 1e-12));
    }

    #[test]
    fn grouped_with_na_matches_r() {
        // R reference: bwss(c(1,2,NA,4), c(a,a,b,b)); the NA observation is dropped.
        let values = [1.0, 2.0, f64::NAN, 4.0];
        let labels = ["a", "a", "b", "b"];
        let result = bwss(&values, &labels);
        let expected = [4.16666666666667, 0.5, 1.0, 1.0];
        assert!(close(&result, expected, 1e-12));
    }

    #[test]
    fn matrix_matches_r() {
        // R reference: bwss.matrix(matrix(c(1,2,3,10,11,12,20,21,NA), nrow=3)).
        let m = array![
            [1.0, 10.0, 20.0],
            [2.0, 11.0, 21.0],
            [3.0, 12.0, f64::NAN]
        ];
        let result = bwss_matrix(&m);
        assert!(close(&result, [415.5, 4.5, 2.0, 5.0], 1e-12));
    }

    #[test]
    fn matrix_drops_all_na_column() {
        // R reference: bwss.matrix(matrix(c(1,2,3,NA,NA,NA,5,7,9), nrow=3)).
        let m = array![
            [1.0, f64::NAN, 5.0],
            [2.0, f64::NAN, 7.0],
            [3.0, f64::NAN, 9.0]
        ];
        let result = bwss_matrix(&m);
        assert!(close(&result, [37.5, 10.0, 1.0, 4.0], 1e-12));
    }

    #[test]
    fn empty_is_nan() {
        // Every observation is NaN, so nothing is left to decompose.
        let b = bwss(&[f64::NAN, f64::NAN], &["a", "b"]);
        let fields = [b.bss, b.wss, b.bdf, b.wdf];
        assert!(fields.iter().all(|v| v.is_nan()));
    }
}