limma-rust 0.1.0

Pure-Rust port of the Bioconductor limma differential-expression package
Documentation
//! Between- and within-group sums of squares.
//!
//! Pure-Rust port of limma's `bwss.R` ([`bwss`] and [`bwss_matrix`]): a one-way
//! decomposition of total variation into between-group and within-group sums of
//! squares, with the associated degrees of freedom. `NaN` entries are dropped
//! (matching R's `na.rm`); groups left empty after dropping are discarded.

use ndarray::Array2;
use std::collections::HashMap;
use std::hash::Hash;

/// Outcome of a one-way between/within decomposition of variation.
///
/// Holds the two sums of squares together with their degrees of freedom.
/// When there is no data, every field is `NaN`.
#[derive(Debug, Clone, Copy)]
pub struct Bwss {
    /// Between-group sum of squares.
    pub bss: f64,
    /// Within-group sum of squares.
    pub wss: f64,
    /// Between-group degrees of freedom.
    pub bdf: f64,
    /// Within-group degrees of freedom.
    pub wdf: f64,
}

impl Bwss {
    /// Builds the "no data" result, in which all four components are `NaN`.
    fn na() -> Self {
        const MISSING: f64 = f64::NAN;
        Self {
            bss: MISSING,
            wss: MISSING,
            bdf: MISSING,
            wdf: MISSING,
        }
    }
}

/// Shared arithmetic behind the public entry points.
///
/// Each inner vector of `groups` holds the observations of one group. The
/// function expects at least one group and no empty groups: an empty group
/// would have a mean of `0/0 = NaN`, and an empty slice would divide by a
/// total count of zero. Both callers in this file filter those cases out
/// before calling.
fn bwss_core(groups: &[Vec<f64>]) -> Bwss {
    // One row per group: (size, mean, sum of squared deviations from the mean).
    let stats: Vec<(f64, f64, f64)> = groups
        .iter()
        .map(|obs| {
            let size = obs.len() as f64;
            let centre = obs.iter().sum::<f64>() / size;
            let spread = obs.iter().map(|&v| (v - centre).powi(2)).sum::<f64>();
            (size, centre, spread)
        })
        .collect();

    let n_groups = groups.len() as f64;
    let n_obs: f64 = stats.iter().map(|&(size, _, _)| size).sum();

    // Grand mean is the size-weighted average of the group means.
    let weighted_means: f64 = stats.iter().map(|&(size, centre, _)| size * centre).sum();
    let grand_mean = weighted_means / n_obs;

    // Within: accumulate the per-group spreads in group order.
    let wss = stats
        .iter()
        .fold(0.0_f64, |acc, &(_, _, spread)| acc + spread);

    // Between: squared distance of each group mean from the grand mean,
    // weighted by group size.
    let bss = stats
        .iter()
        .map(|&(size, centre, _)| size * (centre - grand_mean).powi(2))
        .sum::<f64>();

    Bwss {
        bss,
        wss,
        bdf: n_groups - 1.0,
        wdf: n_obs - n_groups,
    }
}

/// `bwss(x, group)`. Decompose the variation of `x` across the levels of
/// `group`.
///
/// `x[i]` belongs to the group labelled `group[i]`. Observations whose value
/// is `NaN` are skipped, so a label that only ever accompanies `NaN` values
/// never forms a group. Labels are only compared for equality; groups are
/// numbered in order of first appearance.
///
/// Returns a [`Bwss`] whose fields are all `NaN` when no non-`NaN`
/// observation remains.
///
/// # Panics
///
/// Panics if `x` and `group` have different lengths.
pub fn bwss<T: Eq + Hash>(x: &[f64], group: &[T]) -> Bwss {
    assert_eq!(x.len(), group.len(), "x and group lengths differ");
    // Key the lookup by reference: `group` outlives the map, so labels never
    // need to be cloned, which would cost an allocation per observation for
    // owned labels such as `String`.
    let mut slot_of: HashMap<&T, usize> = HashMap::new();
    let mut groups: Vec<Vec<f64>> = Vec::new();
    for (&xi, label) in x.iter().zip(group) {
        if xi.is_nan() {
            continue;
        }
        // First sighting of a label opens a new bucket at the end of `groups`.
        let slot = *slot_of.entry(label).or_insert_with(|| {
            groups.push(Vec::new());
            groups.len() - 1
        });
        groups[slot].push(xi);
    }
    if groups.is_empty() {
        return Bwss::na();
    }
    bwss_core(&groups)
}

/// `bwss.matrix(x)`. Decompose the variation between and within the columns
/// of `x`, with every column acting as one group.
///
/// `NaN` entries are removed column by column; a column with nothing left is
/// not counted as a group. If no column survives, all four fields of the
/// result are `NaN`.
pub fn bwss_matrix(x: &Array2<f64>) -> Bwss {
    let groups: Vec<Vec<f64>> = x
        .columns()
        .into_iter()
        .map(|column| {
            column
                .iter()
                .copied()
                .filter(|v| !v.is_nan())
                .collect::<Vec<f64>>()
        })
        .filter(|kept| !kept.is_empty())
        .collect();

    if groups.is_empty() {
        Bwss::na()
    } else {
        bwss_core(&groups)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Asserts that every component of `got` is within a mixed
    /// absolute/relative tolerance of the matching entry of `want`
    /// (ordered `bss`, `wss`, `bdf`, `wdf`). On failure the message names the
    /// offending component and shows both values.
    #[track_caller]
    fn assert_close(got: &Bwss, want: [f64; 4], tol: f64) {
        let names = ["bss", "wss", "bdf", "wdf"];
        let got = [got.bss, got.wss, got.bdf, got.wdf];
        for ((name, g), w) in names.iter().zip(got).zip(want) {
            assert!(
                (g - w).abs() <= tol + tol * w.abs(),
                "{}: got {}, want {}",
                name,
                g,
                w
            );
        }
    }

    /// True when all four components are `NaN`, i.e. the "no data" result.
    fn all_nan(b: &Bwss) -> bool {
        b.bss.is_nan() && b.wss.is_nan() && b.bdf.is_nan() && b.wdf.is_nan()
    }

    #[test]
    fn grouped_matches_r() {
        // Reference: bwss(c(1,2,3,10,11,12,20,21), c(a,a,a,b,b,b,c,c)).
        let x = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 20.0, 21.0];
        let g = ["a", "a", "a", "b", "b", "b", "c", "c"];
        assert_close(&bwss(&x, &g), [415.5, 4.5, 2.0, 5.0], 1e-12);
    }

    #[test]
    fn grouped_interleaved_labels() {
        // Labels need not be contiguous: a = {1, 2}, b = {10, 11}.
        // Group means 1.5 and 10.5, grand mean 6, so bss = 2 * 2 * 4.5^2 = 81.
        let x = [1.0, 10.0, 2.0, 11.0];
        let g = ["a", "b", "a", "b"];
        assert_close(&bwss(&x, &g), [81.0, 1.0, 1.0, 2.0], 1e-12);
    }

    #[test]
    fn grouped_with_na_matches_r() {
        // Reference: bwss(c(1,2,NA,4), c(a,a,b,b)) drops the NA observation.
        let x = [1.0, 2.0, f64::NAN, 4.0];
        let g = ["a", "a", "b", "b"];
        assert_close(&bwss(&x, &g), [4.16666666666667, 0.5, 1.0, 1.0], 1e-12);
    }

    #[test]
    #[should_panic(expected = "x and group lengths differ")]
    fn grouped_length_mismatch_panics() {
        let _ = bwss(&[1.0, 2.0], &["a"]);
    }

    #[test]
    fn matrix_matches_r() {
        // Reference: bwss.matrix(matrix(c(1,2,3,10,11,12,20,21,NA), nrow=3)).
        let m = array![[1.0, 10.0, 20.0], [2.0, 11.0, 21.0], [3.0, 12.0, f64::NAN]];
        assert_close(&bwss_matrix(&m), [415.5, 4.5, 2.0, 5.0], 1e-12);
    }

    #[test]
    fn matrix_drops_all_na_column() {
        // Reference: bwss.matrix(matrix(c(1,2,3,NA,NA,NA,5,7,9), nrow=3)).
        let m = array![
            [1.0, f64::NAN, 5.0],
            [2.0, f64::NAN, 7.0],
            [3.0, f64::NAN, 9.0]
        ];
        assert_close(&bwss_matrix(&m), [37.5, 10.0, 1.0, 4.0], 1e-12);
    }

    #[test]
    fn matrix_all_na_is_nan() {
        // Every column is dropped, so there is nothing left to decompose.
        let m = array![[f64::NAN, f64::NAN], [f64::NAN, f64::NAN]];
        assert!(all_nan(&bwss_matrix(&m)));
    }

    #[test]
    fn empty_is_nan() {
        let b = bwss(&[f64::NAN, f64::NAN], &["a", "b"]);
        assert!(all_nan(&b));
    }
}