1use anyhow::{bail, Result};
17use ndarray::{s, Array2, Axis};
18
19use crate::fit::{lmfit, non_estimable};
20
21fn sum_contrasts(labels: &[String]) -> Array2<f64> {
27 let mut levels: Vec<&String> = labels.iter().collect();
28 levels.sort();
29 levels.dedup();
30 let k = levels.len();
31 let n = labels.len();
32 if k < 2 {
33 return Array2::zeros((n, 0));
34 }
35 let level_index = |lab: &String| levels.iter().position(|&l| l == lab).unwrap();
36 let mut m = Array2::<f64>::zeros((n, k - 1));
37 for (row, lab) in labels.iter().enumerate() {
38 let li = level_index(lab);
39 if li == k - 1 {
40 for c in 0..(k - 1) {
41 m[[row, c]] = -1.0;
42 }
43 } else {
44 m[[row, li]] = 1.0;
45 }
46 }
47 m
48}
49
50fn center_columns(cov: &Array2<f64>) -> Array2<f64> {
53 let mut out = cov.clone();
54 let n = cov.nrows() as f64;
55 for mut col in out.columns_mut() {
56 let mean = col.sum() / n;
57 col.mapv_inplace(|v| v - mean);
58 }
59 out
60}
61
62fn hstack(n: usize, blocks: &[&Array2<f64>]) -> Array2<f64> {
64 let total: usize = blocks.iter().map(|b| b.ncols()).sum();
65 let mut out = Array2::<f64>::zeros((n, total));
66 let mut off = 0usize;
67 for b in blocks {
68 let w = b.ncols();
69 if w > 0 {
70 out.slice_mut(s![.., off..off + w]).assign(b);
71 off += w;
72 }
73 }
74 out
75}
76
77pub fn remove_batch_effect(
90 x: &Array2<f64>,
91 batch: Option<&[String]>,
92 batch2: Option<&[String]>,
93 covariates: Option<&Array2<f64>>,
94 design: Option<&Array2<f64>>,
95) -> Result<Array2<f64>> {
96 let n_samples = x.ncols();
97
98 if batch.is_none() && batch2.is_none() && covariates.is_none() {
99 return Ok(x.clone());
100 }
101
102 let mut blocks: Vec<Array2<f64>> = Vec::new();
104 for b in [batch, batch2].into_iter().flatten() {
105 if b.len() != n_samples {
106 bail!(
107 "batch length ({}) does not match number of samples ({})",
108 b.len(),
109 n_samples
110 );
111 }
112 blocks.push(sum_contrasts(b));
113 }
114 if let Some(cov) = covariates {
115 if cov.nrows() != n_samples {
116 bail!(
117 "covariates rows ({}) does not match number of samples ({})",
118 cov.nrows(),
119 n_samples
120 );
121 }
122 blocks.push(center_columns(cov));
123 }
124 let block_refs: Vec<&Array2<f64>> = blocks.iter().collect();
125 let x_batch = hstack(n_samples, &block_refs);
126
127 let design_owned;
129 let design = match design {
130 Some(d) => {
131 if d.nrows() != n_samples {
132 bail!(
133 "design rows ({}) does not match number of samples ({})",
134 d.nrows(),
135 n_samples
136 );
137 }
138 d
139 }
140 None => {
141 design_owned = Array2::<f64>::ones((n_samples, 1));
142 &design_owned
143 }
144 };
145 let n_design = design.ncols();
146
147 let full = hstack(n_samples, &[design, &x_batch]);
151 let n_total = full.ncols();
152 let kept: Vec<usize> = match non_estimable(&full) {
153 None => (0..n_total).collect(),
154 Some(dep) => (0..n_total).filter(|j| !dep.contains(j)).collect(),
155 };
156 let reduced = full.select(Axis(1), &kept);
157
158 let gene_names: Vec<String> = (0..x.nrows()).map(|i| i.to_string()).collect();
159 let coef_names: Vec<String> = kept.iter().map(|j| j.to_string()).collect();
160 let fit = lmfit(x, &reduced, gene_names, coef_names)?;
161
162 let n_genes = x.nrows();
165 let mut beta_full = Array2::<f64>::zeros((n_genes, n_total));
166 for (col, &j) in kept.iter().enumerate() {
167 beta_full
168 .slice_mut(s![.., j])
169 .assign(&fit.coefficients.slice(s![.., col]));
170 }
171 let beta_batch = beta_full.slice(s![.., n_design..]).to_owned();
172
173 Ok(x - &beta_batch.dot(&x_batch.t()))
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Three rows (features) by six samples, shared by the regression cases.
    fn fixture() -> Array2<f64> {
        array![
            [5.1, 4.8, 6.2, 5.5, 4.9, 6.0],
            [2.3, 3.1, 2.8, 3.5, 2.0, 3.9],
            [7.7, 7.2, 8.1, 6.9, 7.5, 8.4],
        ]
    }

    fn labels(v: &[&str]) -> Vec<String> {
        v.iter().map(|s| s.to_string()).collect()
    }

    /// Element-wise comparison with an absolute tolerance of 1e-9. NaN never matches.
    fn assert_close(got: &Array2<f64>, want: &Array2<f64>) {
        assert_eq!(got.dim(), want.dim());
        for (a, b) in got.iter().zip(want.iter()) {
            assert!((a - b).abs() < 1e-9, "got {a} want {b}");
        }
    }

    #[test]
    fn case_a_batch_only() {
        let x = fixture();
        let batch = labels(&["a", "a", "b", "b", "a", "b"]);
        let got = remove_batch_effect(&x, Some(&batch), None, None, None).unwrap();
        let want = array![
            [
                5.583333333333,
                5.283333333333,
                5.716666666667,
                5.016666666667,
                5.383333333333,
                5.516666666667
            ],
            [
                2.766666666667,
                3.566666666667,
                2.333333333333,
                3.033333333333,
                2.466666666667,
                3.433333333333
            ],
            [
                7.866666666667,
                7.366666666667,
                7.933333333333,
                6.733333333333,
                7.666666666667,
                8.233333333333
            ],
        ];
        assert_close(&got, &want);
    }

    #[test]
    fn case_b_batch_and_design() {
        let x = fixture();
        let batch = labels(&["a", "a", "b", "b", "a", "b"]);
        let design = array![
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
        ];
        let got = remove_batch_effect(&x, Some(&batch), None, None, Some(&design)).unwrap();
        let want = array![
            [
                5.637500000000,
                5.337500000000,
                5.662500000000,
                4.962500000000,
                5.437500000000,
                5.462500000000
            ],
            [
                2.612500000000,
                3.412500000000,
                2.487500000000,
                3.187500000000,
                2.312500000000,
                3.587500000000
            ],
            [
                7.937500000000,
                7.437500000000,
                7.862500000000,
                6.662500000000,
                7.737500000000,
                8.162500000000
            ],
        ];
        assert_close(&got, &want);
    }

    #[test]
    fn case_c_confounded_full() {
        let x = fixture();
        let batch = labels(&["a", "a", "b", "b", "a", "b"]);
        // batch2 duplicates the design's second column, so it is non-estimable and dropped.
        let batch2 = labels(&["x", "y", "x", "y", "x", "y"]);
        let covs = array![[0.1], [0.2], [0.3], [0.4], [0.5], [0.6]];
        let design = array![
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
        ];
        let got = remove_batch_effect(&x, Some(&batch), Some(&batch2), Some(&covs), Some(&design))
            .unwrap();
        let want = array![
            [
                5.617307692308,
                5.328846153846,
                5.648076923077,
                4.959615384615,
                5.463461538462,
                5.482692307692
            ],
            [
                2.578846153846,
                3.398076923077,
                2.463461538462,
                3.182692307692,
                2.355769230769,
                3.621153846154
            ],
            [
                8.078846153846,
                7.498076923077,
                7.963461538462,
                6.682692307692,
                7.555769230769,
                8.021153846154
            ],
        ];
        assert_close(&got, &want);
    }

    #[test]
    fn case_d_covariates_only() {
        let x = fixture();
        let covs = array![[0.1], [0.2], [0.3], [0.4], [0.5], [0.6]];
        let got = remove_batch_effect(&x, None, None, Some(&covs), None).unwrap();
        let want = array![
            [
                5.392857142857,
                4.975714285714,
                6.258571428571,
                5.441428571429,
                4.724285714286,
                5.707142857143
            ],
            [
                2.685714285714,
                3.331428571429,
                2.877142857143,
                3.422857142857,
                1.768571428571,
                3.514285714286
            ],
            [
                7.928571428571,
                7.337142857143,
                8.145714285714,
                6.854285714286,
                7.362857142857,
                8.171428571429
            ],
        ];
        assert_close(&got, &want);
    }

    #[test]
    fn all_none_returns_input() {
        let x = fixture();
        let got = remove_batch_effect(&x, None, None, None, None).unwrap();
        assert_close(&got, &x);
    }

    #[test]
    fn design_is_not_checked_when_nothing_to_remove() {
        let x = fixture();
        let design = Array2::<f64>::ones((4, 1));
        let got = remove_batch_effect(&x, None, None, None, Some(&design)).unwrap();
        assert_close(&got, &x);
    }

    #[test]
    fn single_level_batch_leaves_data_unchanged() {
        let x = fixture();
        let batch = labels(&["a"; 6]);
        let got = remove_batch_effect(&x, Some(&batch), None, None, None).unwrap();
        assert_close(&got, &x);
    }

    #[test]
    fn mismatched_batch_length_is_rejected() {
        let x = fixture();
        let short = labels(&["a", "b", "a"]);
        let err = remove_batch_effect(&x, Some(&short), None, None, None).unwrap_err();
        assert!(err.to_string().contains("batch length (3)"), "{err}");
        let err = remove_batch_effect(&x, None, Some(&short), None, None).unwrap_err();
        assert!(err.to_string().contains("batch length (3)"), "{err}");
    }

    #[test]
    fn mismatched_covariate_rows_are_rejected() {
        let x = fixture();
        let covs = array![[0.1], [0.2]];
        let err = remove_batch_effect(&x, None, None, Some(&covs), None).unwrap_err();
        assert!(err.to_string().contains("covariates rows (2)"), "{err}");
    }

    #[test]
    fn mismatched_design_rows_are_rejected() {
        let x = fixture();
        let batch = labels(&["a", "a", "b", "b", "a", "b"]);
        let design = Array2::<f64>::ones((4, 1));
        let err = remove_batch_effect(&x, Some(&batch), None, None, Some(&design)).unwrap_err();
        assert!(err.to_string().contains("design rows (4)"), "{err}");
    }

    #[test]
    fn sum_contrasts_uses_last_sorted_level_as_reference() {
        let got = sum_contrasts(&labels(&["b", "a", "c", "a"]));
        let want = array![[0.0, 1.0], [1.0, 0.0], [-1.0, -1.0], [1.0, 0.0]];
        assert_close(&got, &want);
    }

    #[test]
    fn sum_contrasts_single_level_has_no_columns() {
        assert_eq!(sum_contrasts(&labels(&["a", "a", "a"])).dim(), (3, 0));
    }

    #[test]
    fn center_columns_removes_column_means() {
        let got = center_columns(&array![[1.0, 10.0], [3.0, 20.0]]);
        assert_close(&got, &array![[-1.0, -5.0], [1.0, 5.0]]);
    }

    #[test]
    fn hstack_skips_empty_blocks() {
        let a = array![[1.0], [2.0]];
        let empty = Array2::<f64>::zeros((2, 0));
        let b = array![[3.0, 4.0], [5.0, 6.0]];
        let got = hstack(2, &[&a, &empty, &b]);
        assert_close(&got, &array![[1.0, 3.0, 4.0], [2.0, 5.0, 6.0]]);
    }
}