use anyhow::{bail, Result};
use ndarray::{s, Array2, Axis};
use crate::fit::{lmfit, non_estimable};
/// Encodes a categorical vector with sum-to-zero contrasts.
///
/// The distinct labels are ordered lexicographically. With `k` distinct levels the result is a
/// `labels.len() x (k - 1)` matrix:
/// - a sample whose label is level `i` (for `i < k - 1`) gets `1.0` in column `i` and `0.0`
///   elsewhere;
/// - a sample carrying the last level gets `-1.0` in every column.
///
/// When there are fewer than two distinct labels the result has zero columns.
fn sum_contrasts(labels: &[String]) -> Array2<f64> {
    let n_rows = labels.len();
    let mut distinct: Vec<&str> = labels.iter().map(String::as_str).collect();
    distinct.sort_unstable();
    distinct.dedup();
    if distinct.len() < 2 {
        return Array2::zeros((n_rows, 0));
    }
    // The last (largest) level is the reference and is coded as -1 in every column.
    let reference = distinct.len() - 1;
    let mut coded = Array2::<f64>::zeros((n_rows, reference));
    for (mut row, label) in coded.rows_mut().into_iter().zip(labels) {
        // `distinct` is sorted and deduplicated, so every label is found at a unique index.
        let level = distinct
            .binary_search(&label.as_str())
            .expect("each label appears among its own distinct levels");
        if level == reference {
            row.iter_mut().for_each(|cell| *cell = -1.0);
        } else {
            row[level] = 1.0;
        }
    }
    coded
}
/// Returns a copy of `cov` with each column shifted so that its mean becomes zero.
///
/// A column's mean is its sum divided by the number of rows of `cov`.
fn center_columns(cov: &Array2<f64>) -> Array2<f64> {
    let row_count = cov.nrows() as f64;
    let mut centered = cov.clone();
    centered.columns_mut().into_iter().for_each(|mut column| {
        let column_mean = column.sum() / row_count;
        column.mapv_inplace(|value| value - column_mean);
    });
    centered
}
/// Concatenates `blocks` column-wise, left to right, into a new `n`-row matrix.
///
/// Blocks with zero columns contribute nothing. An empty `blocks` slice yields an `n x 0`
/// matrix. Each non-empty block is written with `assign`, so its row count must be
/// compatible with `n`.
fn hstack(n: usize, blocks: &[&Array2<f64>]) -> Array2<f64> {
    let width = blocks.iter().fold(0usize, |acc, block| acc + block.ncols());
    let mut stacked = Array2::<f64>::zeros((n, width));
    let mut start = 0usize;
    for block in blocks.iter().filter(|block| block.ncols() > 0) {
        let end = start + block.ncols();
        stacked.slice_mut(s![.., start..end]).assign(*block);
        start = end;
    }
    stacked
}
/// Removes batch effects, and optionally linear covariate effects, from a data matrix.
///
/// `x` has one row per gene/feature and one column per sample. The model is built as follows:
/// - Each supplied batch factor (`batch`, `batch2`) is coded with sum-to-zero contrasts.
/// - `covariates` (one row per sample) is column-centred.
/// - These form the nuisance matrix `X_batch`.
/// - Each row of `x` is fitted by `lmfit` against `[design | X_batch]`. When `design` is
///   `None` it defaults to a single intercept column of ones.
///
/// The result is `x - beta_batch * X_batch^T`, i.e. only the nuisance part of the fit is
/// subtracted, and effects described by `design` are preserved.
///
/// Any column of `[design | X_batch]` that `non_estimable` reports is dropped before fitting,
/// and its coefficients are taken as zero. Any batch coefficient that comes back NaN is also
/// taken as zero. This matches limma's `removeBatchEffect` (`beta[is.na(beta)] <- 0`), so an
/// affected gene is left uncorrected rather than turned entirely into NaN.
///
/// If `batch`, `batch2` and `covariates` are all `None`, a copy of `x` is returned unchanged.
///
/// # Errors
/// Returns an error if any of the following holds:
/// - a batch vector's length differs from `x.ncols()`;
/// - the row count of `covariates` differs from `x.ncols()`;
/// - the row count of `design` differs from `x.ncols()`;
/// - `lmfit` itself fails.
pub fn remove_batch_effect(
    x: &Array2<f64>,
    batch: Option<&[String]>,
    batch2: Option<&[String]>,
    covariates: Option<&Array2<f64>>,
    design: Option<&Array2<f64>>,
) -> Result<Array2<f64>> {
    let n_samples = x.ncols();
    if batch.is_none() && batch2.is_none() && covariates.is_none() {
        return Ok(x.clone());
    }
    let x_batch = rbe_nuisance_matrix(n_samples, batch, batch2, covariates)?;

    // Deferred initialisation lets the default design live long enough to be borrowed.
    let intercept_only;
    let design = match design {
        Some(d) => {
            rbe_require_sample_count("design rows", d.nrows(), n_samples)?;
            d
        }
        None => {
            intercept_only = Array2::<f64>::ones((n_samples, 1));
            &intercept_only
        }
    };
    let n_design = design.ncols();
    let full = hstack(n_samples, &[design, &x_batch]);
    let n_total = full.ncols();

    // Fit only the estimable columns; the dropped ones keep a zero coefficient below.
    let kept: Vec<usize> = match non_estimable(&full) {
        None => (0..n_total).collect(),
        Some(dep) => (0..n_total).filter(|j| !dep.contains(j)).collect(),
    };
    let reduced = full.select(Axis(1), &kept);
    let n_genes = x.nrows();
    let gene_names: Vec<String> = (0..n_genes).map(|i| i.to_string()).collect();
    let coef_names: Vec<String> = kept.iter().map(|j| j.to_string()).collect();
    let fit = lmfit(x, &reduced, gene_names, coef_names)?;

    // Scatter the fitted coefficients back to their original column positions.
    let mut beta_full = Array2::<f64>::zeros((n_genes, n_total));
    for (col, &j) in kept.iter().enumerate() {
        beta_full
            .slice_mut(s![.., j])
            .assign(&fit.coefficients.slice(s![.., col]));
    }
    let mut beta_batch = beta_full.slice(s![.., n_design..]).to_owned();
    // As in limma (`beta[is.na(beta)] <- 0`): a NaN coefficient means "no correction".
    beta_batch.mapv_inplace(|b| if b.is_nan() { 0.0 } else { b });
    Ok(x - &beta_batch.dot(&x_batch.t()))
}

/// Fails with "`what` (`got`) does not match number of samples (`n_samples`)" when the counts differ.
fn rbe_require_sample_count(what: &str, got: usize, n_samples: usize) -> Result<()> {
    if got != n_samples {
        bail!("{what} ({got}) does not match number of samples ({n_samples})");
    }
    Ok(())
}

/// Builds the nuisance matrix column-wise: sum-to-zero codings of `batch` then `batch2`,
/// followed by the column-centred `covariates`; absent inputs contribute no columns.
fn rbe_nuisance_matrix(
    n_samples: usize,
    batch: Option<&[String]>,
    batch2: Option<&[String]>,
    covariates: Option<&Array2<f64>>,
) -> Result<Array2<f64>> {
    let mut blocks: Vec<Array2<f64>> = Vec::new();
    for b in [batch, batch2].into_iter().flatten() {
        rbe_require_sample_count("batch length", b.len(), n_samples)?;
        blocks.push(sum_contrasts(b));
    }
    if let Some(cov) = covariates {
        rbe_require_sample_count("covariates rows", cov.nrows(), n_samples)?;
        blocks.push(center_columns(cov));
    }
    let block_refs: Vec<&Array2<f64>> = blocks.iter().collect();
    Ok(hstack(n_samples, &block_refs))
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// A 3-gene x 6-sample expression matrix shared by the end-to-end cases.
    fn fixture() -> Array2<f64> {
        array![
            [5.1, 4.8, 6.2, 5.5, 4.9, 6.0],
            [2.3, 3.1, 2.8, 3.5, 2.0, 3.9],
            [7.7, 7.2, 8.1, 6.9, 7.5, 8.4],
        ]
    }

    /// Converts string literals into the owned label vector the API expects.
    fn labels(v: &[&str]) -> Vec<String> {
        v.iter().map(|s| s.to_string()).collect()
    }

    /// Asserts equal shapes and element-wise agreement within an absolute tolerance of 1e-9.
    fn assert_close(got: &Array2<f64>, want: &Array2<f64>) {
        assert_eq!(got.dim(), want.dim());
        for (a, b) in got.iter().zip(want.iter()) {
            assert!((a - b).abs() < 1e-9, "got {a} want {b}");
        }
    }

    #[test]
    fn case_a_batch_only() {
        let x = fixture();
        let batch = labels(&["a", "a", "b", "b", "a", "b"]);
        let got = remove_batch_effect(&x, Some(&batch), None, None, None).unwrap();
        let want = array![
            [
                5.583333333333,
                5.283333333333,
                5.716666666667,
                5.016666666667,
                5.383333333333,
                5.516666666667
            ],
            [
                2.766666666667,
                3.566666666667,
                2.333333333333,
                3.033333333333,
                2.466666666667,
                3.433333333333
            ],
            [
                7.866666666667,
                7.366666666667,
                7.933333333333,
                6.733333333333,
                7.666666666667,
                8.233333333333
            ],
        ];
        assert_close(&got, &want);
    }

    #[test]
    fn case_b_batch_and_design() {
        let x = fixture();
        let batch = labels(&["a", "a", "b", "b", "a", "b"]);
        let design = array![
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
        ];
        let got = remove_batch_effect(&x, Some(&batch), None, None, Some(&design)).unwrap();
        let want = array![
            [
                5.637500000000,
                5.337500000000,
                5.662500000000,
                4.962500000000,
                5.437500000000,
                5.462500000000
            ],
            [
                2.612500000000,
                3.412500000000,
                2.487500000000,
                3.187500000000,
                2.312500000000,
                3.587500000000
            ],
            [
                7.937500000000,
                7.437500000000,
                7.862500000000,
                6.662500000000,
                7.737500000000,
                8.162500000000
            ],
        ];
        assert_close(&got, &want);
    }

    /// `batch2` duplicates the second design column, so a non-estimable column must be dropped.
    #[test]
    fn case_c_confounded_full() {
        let x = fixture();
        let batch = labels(&["a", "a", "b", "b", "a", "b"]);
        let batch2 = labels(&["x", "y", "x", "y", "x", "y"]);
        let covs = array![[0.1], [0.2], [0.3], [0.4], [0.5], [0.6]];
        let design = array![
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
        ];
        let got = remove_batch_effect(&x, Some(&batch), Some(&batch2), Some(&covs), Some(&design))
            .unwrap();
        let want = array![
            [
                5.617307692308,
                5.328846153846,
                5.648076923077,
                4.959615384615,
                5.463461538462,
                5.482692307692
            ],
            [
                2.578846153846,
                3.398076923077,
                2.463461538462,
                3.182692307692,
                2.355769230769,
                3.621153846154
            ],
            [
                8.078846153846,
                7.498076923077,
                7.963461538462,
                6.682692307692,
                7.555769230769,
                8.021153846154
            ],
        ];
        assert_close(&got, &want);
    }

    #[test]
    fn case_d_covariates_only() {
        let x = fixture();
        let covs = array![[0.1], [0.2], [0.3], [0.4], [0.5], [0.6]];
        let got = remove_batch_effect(&x, None, None, Some(&covs), None).unwrap();
        let want = array![
            [
                5.392857142857,
                4.975714285714,
                6.258571428571,
                5.441428571429,
                4.724285714286,
                5.707142857143
            ],
            [
                2.685714285714,
                3.331428571429,
                2.877142857143,
                3.422857142857,
                1.768571428571,
                3.514285714286
            ],
            [
                7.928571428571,
                7.337142857143,
                8.145714285714,
                6.854285714286,
                7.362857142857,
                8.171428571429
            ],
        ];
        assert_close(&got, &want);
    }

    #[test]
    fn all_none_returns_input() {
        let x = fixture();
        let got = remove_batch_effect(&x, None, None, None, None).unwrap();
        assert_close(&got, &x);
    }

    #[test]
    fn batch_length_mismatch_is_rejected() {
        let x = fixture();
        let batch = labels(&["a", "b"]);
        let err = remove_batch_effect(&x, Some(&batch), None, None, None).unwrap_err();
        assert_eq!(
            err.to_string(),
            "batch length (2) does not match number of samples (6)"
        );
    }

    #[test]
    fn batch2_length_mismatch_is_rejected() {
        let x = fixture();
        let batch = labels(&["a", "a", "b", "b", "a", "b"]);
        let batch2 = labels(&["x", "y", "x"]);
        let err = remove_batch_effect(&x, Some(&batch), Some(&batch2), None, None).unwrap_err();
        assert_eq!(
            err.to_string(),
            "batch length (3) does not match number of samples (6)"
        );
    }

    #[test]
    fn covariate_rows_mismatch_is_rejected() {
        let x = fixture();
        let covs = array![[0.1], [0.2]];
        let err = remove_batch_effect(&x, None, None, Some(&covs), None).unwrap_err();
        assert_eq!(
            err.to_string(),
            "covariates rows (2) does not match number of samples (6)"
        );
    }

    #[test]
    fn design_rows_mismatch_is_rejected() {
        let x = fixture();
        let batch = labels(&["a", "a", "b", "b", "a", "b"]);
        let design = Array2::<f64>::ones((3, 1));
        let err = remove_batch_effect(&x, Some(&batch), None, None, Some(&design)).unwrap_err();
        assert_eq!(
            err.to_string(),
            "design rows (3) does not match number of samples (6)"
        );
    }

    #[test]
    fn sum_contrasts_codes_last_level_as_minus_one() {
        let m = sum_contrasts(&labels(&["b", "a", "c", "a"]));
        let want = array![[0.0, 1.0], [1.0, 0.0], [-1.0, -1.0], [1.0, 0.0]];
        assert_eq!(m, want);
    }

    #[test]
    fn sum_contrasts_single_level_has_no_columns() {
        assert_eq!(sum_contrasts(&labels(&["x", "x", "x"])).dim(), (3, 0));
    }

    #[test]
    fn center_columns_zeroes_column_means() {
        let centered = center_columns(&array![[1.0, 10.0], [3.0, 20.0]]);
        assert_close(&centered, &array![[-1.0, -5.0], [1.0, 5.0]]);
    }

    #[test]
    fn hstack_skips_empty_blocks() {
        let a = array![[1.0], [2.0]];
        let empty = Array2::<f64>::zeros((2, 0));
        let b = array![[3.0, 4.0], [5.0, 6.0]];
        let stacked = hstack(2, &[&a, &empty, &b]);
        assert_eq!(stacked, array![[1.0, 3.0, 4.0], [2.0, 5.0, 6.0]]);
    }
}