Skip to main content

limma/
batch.rs

1//! Remove batch effects from an expression matrix. Port of limma's
2//! `removeBatchEffect` (`removeBatchEffect.R`).
3//!
4//! The batch covariates are coded with sum-to-zero contrasts (`contr.sum`) and
5//! the numeric covariates are mean-centred, exactly as in R. A single combined
6//! linear model `lmFit(x, cbind(design, X.batch))` is fitted and the fitted
7//! batch component `beta %*% t(X.batch)` is subtracted from the data.
8//!
9//! When a batch column is confounded with the design (or another batch column)
10//! the joint model is rank-deficient; R's `lmFit` flags those coefficients as
11//! `NA` and `removeBatchEffect` replaces them with zero. We reproduce this by
12//! dropping the linearly dependent columns (the same first-occurrence-kept set
13//! R's pivoted QR retains), fitting the reduced model, and treating the dropped
14//! coefficients as zero.
15
16use anyhow::{bail, Result};
17use ndarray::{s, Array2, Axis};
18
19use crate::fit::{lmfit, non_estimable};
20
21/// Build the `model.matrix(~f)[,-1]` sum-to-zero contrast columns for a factor
22/// given as per-sample labels. Levels are taken in sorted order (matching R's
23/// `as.factor`). A factor with `k` levels yields `k-1` columns: level `i`
24/// (`0 <= i < k-1`) maps to the `i`-th unit row, and the last level maps to a
25/// row of `-1`s.
26fn sum_contrasts(labels: &[String]) -> Array2<f64> {
27    let mut levels: Vec<&String> = labels.iter().collect();
28    levels.sort();
29    levels.dedup();
30    let k = levels.len();
31    let n = labels.len();
32    if k < 2 {
33        return Array2::zeros((n, 0));
34    }
35    let level_index = |lab: &String| levels.iter().position(|&l| l == lab).unwrap();
36    let mut m = Array2::<f64>::zeros((n, k - 1));
37    for (row, lab) in labels.iter().enumerate() {
38        let li = level_index(lab);
39        if li == k - 1 {
40            for c in 0..(k - 1) {
41                m[[row, c]] = -1.0;
42            }
43        } else {
44            m[[row, li]] = 1.0;
45        }
46    }
47    m
48}
49
50/// Centre each column of `cov` by subtracting its mean (R's
51/// `t(t(covariates) - colMeans(covariates))`).
52fn center_columns(cov: &Array2<f64>) -> Array2<f64> {
53    let mut out = cov.clone();
54    let n = cov.nrows() as f64;
55    for mut col in out.columns_mut() {
56        let mean = col.sum() / n;
57        col.mapv_inplace(|v| v - mean);
58    }
59    out
60}
61
62/// Horizontally stack a list of `n`-row blocks.
63fn hstack(n: usize, blocks: &[&Array2<f64>]) -> Array2<f64> {
64    let total: usize = blocks.iter().map(|b| b.ncols()).sum();
65    let mut out = Array2::<f64>::zeros((n, total));
66    let mut off = 0usize;
67    for b in blocks {
68        let w = b.ncols();
69        if w > 0 {
70            out.slice_mut(s![.., off..off + w]).assign(b);
71            off += w;
72        }
73    }
74    out
75}
76
77/// Remove batch effects from `x` (`n_genes x n_samples`).
78///
79/// * `batch`, `batch2` — optional per-sample factor labels (length `n_samples`)
80///   for one or two blocking factors, coded with sum-to-zero contrasts.
81/// * `covariates` — optional `n_samples x k` numeric covariates to remove
82///   (mean-centred before fitting).
83/// * `design` — optional `n_samples x p` design matrix of experimental
84///   conditions to preserve. Defaults to a single intercept column (one-group
85///   experiment) when `None`.
86///
87/// Returns the batch-corrected matrix (`n_genes x n_samples`). With all of
88/// `batch`, `batch2`, `covariates` `None` the input is returned unchanged.
89pub fn remove_batch_effect(
90    x: &Array2<f64>,
91    batch: Option<&[String]>,
92    batch2: Option<&[String]>,
93    covariates: Option<&Array2<f64>>,
94    design: Option<&Array2<f64>>,
95) -> Result<Array2<f64>> {
96    let n_samples = x.ncols();
97
98    if batch.is_none() && batch2.is_none() && covariates.is_none() {
99        return Ok(x.clone());
100    }
101
102    // Build the batch-covariate block X.batch = cbind(batch, batch2, covariates).
103    let mut blocks: Vec<Array2<f64>> = Vec::new();
104    for b in [batch, batch2].into_iter().flatten() {
105        if b.len() != n_samples {
106            bail!(
107                "batch length ({}) does not match number of samples ({})",
108                b.len(),
109                n_samples
110            );
111        }
112        blocks.push(sum_contrasts(b));
113    }
114    if let Some(cov) = covariates {
115        if cov.nrows() != n_samples {
116            bail!(
117                "covariates rows ({}) does not match number of samples ({})",
118                cov.nrows(),
119                n_samples
120            );
121        }
122        blocks.push(center_columns(cov));
123    }
124    let block_refs: Vec<&Array2<f64>> = blocks.iter().collect();
125    let x_batch = hstack(n_samples, &block_refs);
126
127    // Design of interest (default: one-group intercept).
128    let design_owned;
129    let design = match design {
130        Some(d) => {
131            if d.nrows() != n_samples {
132                bail!(
133                    "design rows ({}) does not match number of samples ({})",
134                    d.nrows(),
135                    n_samples
136                );
137            }
138            d
139        }
140        None => {
141            design_owned = Array2::<f64>::ones((n_samples, 1));
142            &design_owned
143        }
144    };
145    let n_design = design.ncols();
146
147    // Combined model cbind(design, X.batch); drop columns that are not
148    // estimable (confounded), fit the reduced model, and treat dropped
149    // coefficients as zero.
150    let full = hstack(n_samples, &[design, &x_batch]);
151    let n_total = full.ncols();
152    let kept: Vec<usize> = match non_estimable(&full) {
153        None => (0..n_total).collect(),
154        Some(dep) => (0..n_total).filter(|j| !dep.contains(j)).collect(),
155    };
156    let reduced = full.select(Axis(1), &kept);
157
158    let gene_names: Vec<String> = (0..x.nrows()).map(|i| i.to_string()).collect();
159    let coef_names: Vec<String> = kept.iter().map(|j| j.to_string()).collect();
160    let fit = lmfit(x, &reduced, gene_names, coef_names)?;
161
162    // Scatter the reduced coefficients back to full width (dropped -> 0), then
163    // keep only the X.batch columns.
164    let n_genes = x.nrows();
165    let mut beta_full = Array2::<f64>::zeros((n_genes, n_total));
166    for (col, &j) in kept.iter().enumerate() {
167        beta_full
168            .slice_mut(s![.., j])
169            .assign(&fit.coefficients.slice(s![.., col]));
170    }
171    let beta_batch = beta_full.slice(s![.., n_design..]).to_owned();
172
173    Ok(x - &beta_batch.dot(&x_batch.t()))
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Absolute tolerance used when comparing against the reference values.
    const TOL: f64 = 1e-9;

    /// Expression matrix shared by every test: 3 genes x 6 samples.
    fn fixture() -> Array2<f64> {
        array![
            [5.1, 4.8, 6.2, 5.5, 4.9, 6.0],
            [2.3, 3.1, 2.8, 3.5, 2.0, 3.9],
            [7.7, 7.2, 8.1, 6.9, 7.5, 8.4],
        ]
    }

    /// Turn string literals into the owned labels the API expects.
    fn labels(v: &[&str]) -> Vec<String> {
        v.iter().copied().map(String::from).collect()
    }

    /// Primary batch factor used by cases A, B and C.
    fn batch_ab() -> Vec<String> {
        labels(&["a", "a", "b", "b", "a", "b"])
    }

    /// `model.matrix(~group)` with group = g1 g2 g1 g2 g1 g2.
    fn two_group_design() -> Array2<f64> {
        array![
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
        ]
    }

    /// Single numeric covariate, one value per sample.
    fn linear_covariate() -> Array2<f64> {
        array![[0.1], [0.2], [0.3], [0.4], [0.5], [0.6]]
    }

    /// Require equal shapes and element-wise agreement within `TOL`.
    fn assert_close(got: &Array2<f64>, want: &Array2<f64>) {
        assert_eq!(got.dim(), want.dim());
        got.iter().zip(want.iter()).for_each(|(a, b)| {
            assert!((a - b).abs() < TOL, "got {a} want {b}");
        });
    }

    // Reference matrices from R limma 3.68.3 (scratch/rbe_ref.R).

    #[test]
    fn case_a_batch_only() {
        let batch = batch_ab();
        let got = remove_batch_effect(&fixture(), Some(batch.as_slice()), None, None, None)
            .unwrap();
        #[rustfmt::skip]
        let want = array![
            [5.583333333333, 5.283333333333, 5.716666666667, 5.016666666667, 5.383333333333, 5.516666666667],
            [2.766666666667, 3.566666666667, 2.333333333333, 3.033333333333, 2.466666666667, 3.433333333333],
            [7.866666666667, 7.366666666667, 7.933333333333, 6.733333333333, 7.666666666667, 8.233333333333],
        ];
        assert_close(&got, &want);
    }

    #[test]
    fn case_b_batch_and_design() {
        let batch = batch_ab();
        let design = two_group_design();
        let got = remove_batch_effect(
            &fixture(),
            Some(batch.as_slice()),
            None,
            None,
            Some(&design),
        )
        .unwrap();
        let want = array![
            [5.6375, 5.3375, 5.6625, 4.9625, 5.4375, 5.4625],
            [2.6125, 3.4125, 2.4875, 3.1875, 2.3125, 3.5875],
            [7.9375, 7.4375, 7.8625, 6.6625, 7.7375, 8.1625],
        ];
        assert_close(&got, &want);
    }

    /// batch2 is confounded with the design here (x/y tracks g1/g2), so R drops
    /// the `batch21` coefficient (NA -> 0). Exercises the rank-deficient path.
    #[test]
    fn case_c_confounded_full() {
        let batch = batch_ab();
        let batch2 = labels(&["x", "y", "x", "y", "x", "y"]);
        let covs = linear_covariate();
        let design = two_group_design();
        let got = remove_batch_effect(
            &fixture(),
            Some(batch.as_slice()),
            Some(batch2.as_slice()),
            Some(&covs),
            Some(&design),
        )
        .unwrap();
        #[rustfmt::skip]
        let want = array![
            [5.617307692308, 5.328846153846, 5.648076923077, 4.959615384615, 5.463461538462, 5.482692307692],
            [2.578846153846, 3.398076923077, 2.463461538462, 3.182692307692, 2.355769230769, 3.621153846154],
            [8.078846153846, 7.498076923077, 7.963461538462, 6.682692307692, 7.555769230769, 8.021153846154],
        ];
        assert_close(&got, &want);
    }

    #[test]
    fn case_d_covariates_only() {
        let covs = linear_covariate();
        let got = remove_batch_effect(&fixture(), None, None, Some(&covs), None).unwrap();
        #[rustfmt::skip]
        let want = array![
            [5.392857142857, 4.975714285714, 6.258571428571, 5.441428571429, 4.724285714286, 5.707142857143],
            [2.685714285714, 3.331428571429, 2.877142857143, 3.422857142857, 1.768571428571, 3.514285714286],
            [7.928571428571, 7.337142857143, 8.145714285714, 6.854285714286, 7.362857142857, 8.171428571429],
        ];
        assert_close(&got, &want);
    }

    #[test]
    fn all_none_returns_input() {
        let x = fixture();
        let got = remove_batch_effect(&x, None, None, None, None).unwrap();
        assert_close(&got, &x);
    }
}