limma-rust 0.1.0

Pure-Rust port of the Bioconductor limma differential-expression package
Documentation
//! Remove batch effects from an expression matrix. Port of limma's
//! `removeBatchEffect` (`removeBatchEffect.R`).
//!
//! The batch covariates are coded with sum-to-zero contrasts (`contr.sum`) and
//! the numeric covariates are mean-centred, exactly as in R. A single combined
//! linear model `lmFit(x, cbind(design, X.batch))` is fitted and the fitted
//! batch component `beta %*% t(X.batch)` is subtracted from the data.
//!
//! When a batch column is confounded with the design (or another batch column)
//! the joint model is rank-deficient; R's `lmFit` flags those coefficients as
//! `NA` and `removeBatchEffect` replaces them with zero. We reproduce this by
//! dropping the linearly dependent columns (the same first-occurrence-kept set
//! R's pivoted QR retains), fitting the reduced model, and treating the dropped
//! coefficients as zero.

use anyhow::{bail, Result};
use ndarray::{s, Array2, Axis};

use crate::fit::{lmfit, non_estimable};

/// Build the `model.matrix(~f)[,-1]` sum-to-zero contrast columns for a factor
/// given as per-sample labels. Levels are taken in sorted order (matching R's
/// `as.factor`). A factor with `k` levels yields `k-1` columns: level `i`
/// (`0 <= i < k-1`) maps to the `i`-th unit row, and the last level maps to a
/// row of `-1`s.
fn sum_contrasts(labels: &[String]) -> Array2<f64> {
    // Distinct levels in sorted order; the position of a level in this list is
    // its factor index.
    let mut levels: Vec<&str> = labels.iter().map(String::as_str).collect();
    levels.sort_unstable();
    levels.dedup();

    let n = labels.len();
    let k = levels.len();
    // A factor with fewer than two levels carries no contrast columns.
    if k < 2 {
        return Array2::zeros((n, 0));
    }
    let last = k - 1;

    let mut m = Array2::<f64>::zeros((n, last));
    for (mut row, lab) in m.rows_mut().into_iter().zip(labels) {
        // `levels` is sorted and deduplicated, so a binary search replaces the
        // per-label linear scan (O(n log k) instead of O(n·k)).
        let li = levels
            .binary_search(&lab.as_str())
            .expect("every label was collected into `levels`");
        if li == last {
            // The reference (last) level is coded as -1 in every column.
            row.fill(-1.0);
        } else {
            row[li] = 1.0;
        }
    }
    m
}

/// Centre each column of `cov` by subtracting its mean (R's
/// `t(t(covariates) - colMeans(covariates))`).
fn center_columns(cov: &Array2<f64>) -> Array2<f64> {
    let mut out = cov.clone();
    let n = cov.nrows() as f64;
    for mut col in out.columns_mut() {
        let mean = col.sum() / n;
        col.mapv_inplace(|v| v - mean);
    }
    out
}

/// Horizontally stack a list of `n`-row blocks.
fn hstack(n: usize, blocks: &[&Array2<f64>]) -> Array2<f64> {
    // Output width is the sum of all block widths; empty blocks are skipped.
    let width: usize = blocks.iter().map(|blk| blk.ncols()).sum();
    let mut stacked = Array2::<f64>::zeros((n, width));
    let mut start = 0;
    for blk in blocks.iter().filter(|blk| blk.ncols() > 0) {
        let end = start + blk.ncols();
        stacked.slice_mut(s![.., start..end]).assign(*blk);
        start = end;
    }
    stacked
}

/// Remove batch effects from `x` (`n_genes x n_samples`).
///
/// * `batch`, `batch2` — optional per-sample factor labels (length `n_samples`)
///   for one or two blocking factors, coded with sum-to-zero contrasts.
/// * `covariates` — optional `n_samples x k` numeric covariates to remove
///   (mean-centred before fitting).
/// * `design` — optional `n_samples x p` design matrix of experimental
///   conditions to preserve. Defaults to a single intercept column (one-group
///   experiment) when `None`.
///
/// Returns the batch-corrected matrix (`n_genes x n_samples`). With all of
/// `batch`, `batch2`, `covariates` `None` the input is returned unchanged.
pub fn remove_batch_effect(
    x: &Array2<f64>,
    batch: Option<&[String]>,
    batch2: Option<&[String]>,
    covariates: Option<&Array2<f64>>,
    design: Option<&Array2<f64>>,
) -> Result<Array2<f64>> {
    let n_samples = x.ncols();

    // Nothing to remove: hand back the data unchanged.
    if batch.is_none() && batch2.is_none() && covariates.is_none() {
        return Ok(x.clone());
    }

    // Build the batch-covariate block X.batch = cbind(batch, batch2, covariates).
    let mut blocks: Vec<Array2<f64>> = Vec::new();
    for (name, factor) in [("batch", batch), ("batch2", batch2)] {
        if let Some(factor) = factor {
            if factor.len() != n_samples {
                // Name the offending argument so a bad `batch2` is not
                // reported as a bad `batch`.
                bail!(
                    "{} length ({}) does not match number of samples ({})",
                    name,
                    factor.len(),
                    n_samples
                );
            }
            blocks.push(sum_contrasts(factor));
        }
    }
    if let Some(cov) = covariates {
        if cov.nrows() != n_samples {
            bail!(
                "covariates rows ({}) does not match number of samples ({})",
                cov.nrows(),
                n_samples
            );
        }
        blocks.push(center_columns(cov));
    }
    let block_refs: Vec<&Array2<f64>> = blocks.iter().collect();
    let x_batch = hstack(n_samples, &block_refs);

    // Design of interest (default: one-group intercept).
    let design_owned;
    let design = match design {
        Some(d) => {
            if d.nrows() != n_samples {
                bail!(
                    "design rows ({}) does not match number of samples ({})",
                    d.nrows(),
                    n_samples
                );
            }
            d
        }
        None => {
            design_owned = Array2::<f64>::ones((n_samples, 1));
            &design_owned
        }
    };
    let n_design = design.ncols();

    // Combined model cbind(design, X.batch); drop columns that are not
    // estimable (confounded), fit the reduced model, and treat dropped
    // coefficients as zero.
    let full = hstack(n_samples, &[design, &x_batch]);
    let n_total = full.ncols();
    let kept: Vec<usize> = match non_estimable(&full) {
        None => (0..n_total).collect(),
        Some(dep) => (0..n_total).filter(|j| !dep.contains(j)).collect(),
    };
    let reduced = full.select(Axis(1), &kept);

    let gene_names: Vec<String> = (0..x.nrows()).map(|i| i.to_string()).collect();
    let coef_names: Vec<String> = kept.iter().map(|j| j.to_string()).collect();
    let fit = lmfit(x, &reduced, gene_names, coef_names)?;

    // Scatter the fitted X.batch coefficients straight into a
    // (n_genes x ncol(X.batch)) matrix. Design coefficients are skipped, and
    // columns dropped as non-estimable stay zero.
    let mut beta_batch = Array2::<f64>::zeros((x.nrows(), x_batch.ncols()));
    for (col, &j) in kept.iter().enumerate() {
        if let Some(bj) = j.checked_sub(n_design) {
            beta_batch
                .slice_mut(s![.., bj])
                .assign(&fit.coefficients.slice(s![.., col]));
        }
    }
    // R runs `beta[is.na(beta)] <- 0` over every batch coefficient. Any NaN
    // reported by `lmfit` is zeroed too, so one non-estimable coefficient does
    // not turn the whole gene row into NaN.
    beta_batch.mapv_inplace(|b| if b.is_nan() { 0.0 } else { b });

    Ok(x - &beta_batch.dot(&x_batch.t()))
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// 3 genes x 6 samples expression matrix shared by the reference cases.
    fn fixture() -> Array2<f64> {
        array![
            [5.1, 4.8, 6.2, 5.5, 4.9, 6.0],
            [2.3, 3.1, 2.8, 3.5, 2.0, 3.9],
            [7.7, 7.2, 8.1, 6.9, 7.5, 8.4],
        ]
    }

    fn labels(v: &[&str]) -> Vec<String> {
        v.iter().map(|s| s.to_string()).collect()
    }

    /// Element-wise comparison with an absolute tolerance of 1e-9.
    fn assert_close(got: &Array2<f64>, want: &Array2<f64>) {
        assert_eq!(got.dim(), want.dim());
        for (a, b) in got.iter().zip(want.iter()) {
            assert!((a - b).abs() < 1e-9, "got {a} want {b}");
        }
    }

    // Reference matrices from R limma 3.68.3 (scratch/rbe_ref.R).

    #[test]
    fn case_a_batch_only() {
        let x = fixture();
        let batch = labels(&["a", "a", "b", "b", "a", "b"]);
        let got = remove_batch_effect(&x, Some(&batch), None, None, None).unwrap();
        let want = array![
            [
                5.583333333333,
                5.283333333333,
                5.716666666667,
                5.016666666667,
                5.383333333333,
                5.516666666667
            ],
            [
                2.766666666667,
                3.566666666667,
                2.333333333333,
                3.033333333333,
                2.466666666667,
                3.433333333333
            ],
            [
                7.866666666667,
                7.366666666667,
                7.933333333333,
                6.733333333333,
                7.666666666667,
                8.233333333333
            ],
        ];
        assert_close(&got, &want);
    }

    #[test]
    fn case_b_batch_and_design() {
        let x = fixture();
        let batch = labels(&["a", "a", "b", "b", "a", "b"]);
        // model.matrix(~group), group = g1 g2 g1 g2 g1 g2.
        let design = array![
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
        ];
        let got = remove_batch_effect(&x, Some(&batch), None, None, Some(&design)).unwrap();
        let want = array![
            [
                5.637500000000,
                5.337500000000,
                5.662500000000,
                4.962500000000,
                5.437500000000,
                5.462500000000
            ],
            [
                2.612500000000,
                3.412500000000,
                2.487500000000,
                3.187500000000,
                2.312500000000,
                3.587500000000
            ],
            [
                7.937500000000,
                7.437500000000,
                7.862500000000,
                6.662500000000,
                7.737500000000,
                8.162500000000
            ],
        ];
        assert_close(&got, &want);
    }

    /// batch2 is confounded with the design here (x/y tracks g1/g2), so R drops
    /// the `batch21` coefficient (NA -> 0). Exercises the rank-deficient path.
    #[test]
    fn case_c_confounded_full() {
        let x = fixture();
        let batch = labels(&["a", "a", "b", "b", "a", "b"]);
        let batch2 = labels(&["x", "y", "x", "y", "x", "y"]);
        let covs = array![[0.1], [0.2], [0.3], [0.4], [0.5], [0.6]];
        let design = array![
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
        ];
        let got = remove_batch_effect(&x, Some(&batch), Some(&batch2), Some(&covs), Some(&design))
            .unwrap();
        let want = array![
            [
                5.617307692308,
                5.328846153846,
                5.648076923077,
                4.959615384615,
                5.463461538462,
                5.482692307692
            ],
            [
                2.578846153846,
                3.398076923077,
                2.463461538462,
                3.182692307692,
                2.355769230769,
                3.621153846154
            ],
            [
                8.078846153846,
                7.498076923077,
                7.963461538462,
                6.682692307692,
                7.555769230769,
                8.021153846154
            ],
        ];
        assert_close(&got, &want);
    }

    #[test]
    fn case_d_covariates_only() {
        let x = fixture();
        let covs = array![[0.1], [0.2], [0.3], [0.4], [0.5], [0.6]];
        let got = remove_batch_effect(&x, None, None, Some(&covs), None).unwrap();
        let want = array![
            [
                5.392857142857,
                4.975714285714,
                6.258571428571,
                5.441428571429,
                4.724285714286,
                5.707142857143
            ],
            [
                2.685714285714,
                3.331428571429,
                2.877142857143,
                3.422857142857,
                1.768571428571,
                3.514285714286
            ],
            [
                7.928571428571,
                7.337142857143,
                8.145714285714,
                6.854285714286,
                7.362857142857,
                8.171428571429
            ],
        ];
        assert_close(&got, &want);
    }

    #[test]
    fn all_none_returns_input() {
        let x = fixture();
        let got = remove_batch_effect(&x, None, None, None, None).unwrap();
        assert_close(&got, &x);
    }

    /// Levels are sorted (a, b, c); the last level is coded as all -1s.
    #[test]
    fn sum_contrasts_coding() {
        let got = sum_contrasts(&labels(&["b", "a", "c", "a"]));
        let want = array![[0.0, 1.0], [1.0, 0.0], [-1.0, -1.0], [1.0, 0.0]];
        assert_eq!(got, want);
    }

    /// A single-level factor contributes no contrast columns.
    #[test]
    fn sum_contrasts_single_level_is_empty() {
        let got = sum_contrasts(&labels(&["a", "a", "a"]));
        assert_eq!(got.dim(), (3, 0));
    }

    #[test]
    fn center_columns_subtracts_means() {
        let cov = array![[1.0, 10.0], [3.0, 20.0]];
        let got = center_columns(&cov);
        assert_eq!(got, array![[-1.0, -5.0], [1.0, 5.0]]);
    }

    /// Renaming levels so a different one sorts last spans the same column
    /// space, so the removed batch component must not change.
    #[test]
    fn batch_label_order_is_irrelevant() {
        let x = fixture();
        let batch = labels(&["a", "a", "b", "b", "a", "b"]);
        let renamed = labels(&["z", "z", "b", "b", "z", "b"]);
        let got = remove_batch_effect(&x, Some(&batch), None, None, None).unwrap();
        let got_renamed = remove_batch_effect(&x, Some(&renamed), None, None, None).unwrap();
        assert_close(&got_renamed, &got);
    }

    #[test]
    fn mismatched_lengths_are_rejected() {
        let x = fixture();
        let short = labels(&["a", "b", "a", "b", "a"]);
        assert!(remove_batch_effect(&x, Some(&short), None, None, None).is_err());
        assert!(remove_batch_effect(&x, None, Some(&short), None, None).is_err());

        let covs = array![[0.1], [0.2], [0.3]];
        assert!(remove_batch_effect(&x, None, None, Some(&covs), None).is_err());

        let batch = labels(&["a", "a", "b", "b", "a", "b"]);
        let design = Array2::<f64>::ones((5, 1));
        assert!(remove_batch_effect(&x, Some(&batch), None, None, Some(&design)).is_err());
    }
}