1use std::collections::BTreeMap;
2use std::fs::File;
3use std::io::{BufRead, BufReader, BufWriter, Read, Write};
4use std::path::{Path, PathBuf};
5
6use flate2::read::MultiGzDecoder;
7use rayon::prelude::*;
8use rsomics_common::{Result, RsomicsError};
9
/// Sparse count matrix in coordinate (triplet) form, as built by [`parse_mtx`].
#[derive(Debug, Clone, PartialEq)]
pub struct CountMatrix {
    /// Number of rows (genes) declared on the MTX size line.
    pub n_genes: usize,
    /// Number of columns (cells) declared on the MTX size line.
    pub n_cells: usize,
    /// Stored values, in the order they appeared in the file.
    pub entries: Vec<Entry>,
}

/// One stored matrix value addressed by zero-based coordinates.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Entry {
    /// Zero-based row (gene) index.
    pub gene: u32,
    /// Zero-based column (cell) index.
    pub cell: u32,
    /// Value stored at (`gene`, `cell`).
    pub value: f64,
}
25
26pub fn open_mtx(dir: &Path) -> Result<Box<dyn Read>> {
27 for name in ["matrix.mtx.gz", "matrix.mtx"] {
28 let path = dir.join(name);
29 if path.exists() {
30 return open_maybe_gz(&path);
31 }
32 }
33 Err(RsomicsError::InvalidInput(format!(
34 "no matrix.mtx or matrix.mtx.gz in {}",
35 dir.display()
36 )))
37}
38
39fn open_maybe_gz(path: &Path) -> Result<Box<dyn Read>> {
40 let file = File::open(path)
41 .map_err(|e| RsomicsError::InvalidInput(format!("{}: {e}", path.display())))?;
42 if path.extension().is_some_and(|e| e == "gz") {
43 Ok(Box::new(MultiGzDecoder::new(file)))
44 } else {
45 Ok(Box::new(file))
46 }
47}
48
49pub fn parse_mtx(reader: impl Read) -> Result<CountMatrix> {
51 let mut reader = BufReader::new(reader);
52 let mut line = String::new();
53
54 reader.read_line(&mut line).map_err(RsomicsError::Io)?;
55 let banner = line.trim();
56 if !banner.starts_with("%%MatrixMarket") {
57 return Err(RsomicsError::InvalidInput(
58 "missing %%MatrixMarket banner".into(),
59 ));
60 }
61 let pattern = banner.contains("pattern");
62
63 let (n_genes, n_cells, nnz) = loop {
64 line.clear();
65 let n = reader.read_line(&mut line).map_err(RsomicsError::Io)?;
66 if n == 0 {
67 return Err(RsomicsError::InvalidInput("truncated MTX header".into()));
68 }
69 let t = line.trim();
70 if t.is_empty() || t.starts_with('%') {
71 continue;
72 }
73 let mut it = t.split_whitespace();
74 let rows = parse_usize(it.next())?;
75 let cols = parse_usize(it.next())?;
76 let nnz = parse_usize(it.next())?;
77 break (rows, cols, nnz);
78 };
79
80 let mut entries = Vec::with_capacity(nnz);
81 for raw in reader.lines() {
82 let raw = raw.map_err(RsomicsError::Io)?;
83 let t = raw.trim();
84 if t.is_empty() {
85 continue;
86 }
87 let mut it = t.split_whitespace();
88 let gene = parse_usize(it.next())?;
89 let cell = parse_usize(it.next())?;
90 let value = if pattern {
91 1.0
92 } else {
93 it.next()
94 .ok_or_else(|| RsomicsError::InvalidInput("MTX entry missing value".into()))?
95 .parse::<f64>()?
96 };
97 if gene == 0 || gene > n_genes || cell == 0 || cell > n_cells {
98 return Err(RsomicsError::InvalidInput(format!(
99 "MTX index out of bounds: ({gene}, {cell})"
100 )));
101 }
102 entries.push(Entry {
103 gene: (gene - 1) as u32,
104 cell: (cell - 1) as u32,
105 value,
106 });
107 }
108 if entries.len() != nnz {
109 return Err(RsomicsError::InvalidInput(format!(
110 "MTX declared {nnz} entries, found {}",
111 entries.len()
112 )));
113 }
114
115 Ok(CountMatrix {
116 n_genes,
117 n_cells,
118 entries,
119 })
120}
121
122fn densify_gene_major(m: &CountMatrix) -> Vec<f64> {
126 let mut dense = vec![0.0_f64; m.n_genes * m.n_cells];
127 let nc = m.n_cells;
128 for e in &m.entries {
129 dense[e.gene as usize * nc + e.cell as usize] = e.value;
130 }
131 dense
132}
133
134pub fn read_batch_labels(
139 path: &Path,
140 barcodes: &[String],
141 key: Option<&str>,
142) -> Result<(Vec<usize>, Vec<String>)> {
143 let f = File::open(path)
144 .map_err(|e| RsomicsError::InvalidInput(format!("{}: {e}", path.display())))?;
145 let reader = BufReader::new(f);
146 let mut lines = Vec::new();
147 for raw in reader.lines() {
148 let raw = raw.map_err(RsomicsError::Io)?;
149 let t = raw.trim_end_matches(['\n', '\r']).to_string();
150 if !t.is_empty() {
151 lines.push(t);
152 }
153 }
154 if lines.is_empty() {
155 return Err(RsomicsError::InvalidInput("empty batch TSV".into()));
156 }
157
158 let first: Vec<&str> = lines[0].split('\t').collect();
159 let has_header = is_header(&first, key);
160 let label_col = match (key, has_header) {
161 (Some(k), true) => first.iter().position(|c| *c == k).ok_or_else(|| {
162 RsomicsError::InvalidInput(format!("key {k:?} not in batch TSV header"))
163 })?,
164 (Some(k), false) => {
165 return Err(RsomicsError::InvalidInput(format!(
166 "--key {k:?} given but batch TSV has no header row"
167 )));
168 }
169 (None, _) => 1,
170 };
171
172 let mut by_barcode: BTreeMap<String, String> = BTreeMap::new();
173 for line in lines.iter().skip(usize::from(has_header)) {
174 let cols: Vec<&str> = line.split('\t').collect();
175 if label_col >= cols.len() {
176 return Err(RsomicsError::InvalidInput(
177 "batch TSV row shorter than the selected key column".into(),
178 ));
179 }
180 by_barcode.insert(cols[0].to_string(), cols[label_col].to_string());
181 }
182
183 let mut distinct: Vec<String> = by_barcode.values().cloned().collect();
188 distinct.sort();
189 distinct.dedup();
190 let level_of: BTreeMap<&str, usize> = distinct
191 .iter()
192 .enumerate()
193 .map(|(i, s)| (s.as_str(), i))
194 .collect();
195
196 let mut batch_of_cell = Vec::with_capacity(barcodes.len());
197 for bc in barcodes {
198 let label = by_barcode.get(bc).ok_or_else(|| {
199 RsomicsError::InvalidInput(format!("barcode {bc:?} missing from batch TSV"))
200 })?;
201 batch_of_cell.push(level_of[label.as_str()]);
202 }
203
204 if distinct.len() < 2 {
205 return Err(RsomicsError::InvalidInput(
206 "ComBat needs at least 2 batches".into(),
207 ));
208 }
209 Ok((batch_of_cell, distinct))
210}
211
/// Decides whether `cols` (the fields of the first TSV row) form a header row.
///
/// A row is a header when it contains the requested `key` column name, or when any
/// field equals `barcode` ignoring ASCII case. The `barcode` check also applies when a
/// key is given, so that a header lacking the key is still recognised as a header and
/// the caller can report the key as missing rather than claiming there is no header.
fn is_header(cols: &[&str], key: Option<&str>) -> bool {
    let names_key = key.is_some_and(|k| cols.contains(&k));
    names_key || cols.iter().any(|c| c.eq_ignore_ascii_case("barcode"))
}
218
219pub fn read_barcodes(dir: &Path) -> Result<Vec<String>> {
221 for name in ["barcodes.tsv.gz", "barcodes.tsv"] {
222 let path = dir.join(name);
223 if path.exists() {
224 let r = open_maybe_gz(&path)?;
225 let reader = BufReader::new(r);
226 let mut out = Vec::new();
227 for raw in reader.lines() {
228 let raw = raw.map_err(RsomicsError::Io)?;
229 let t = raw.trim();
230 if !t.is_empty() {
231 out.push(t.split('\t').next().unwrap().to_string());
232 }
233 }
234 return Ok(out);
235 }
236 }
237 Err(RsomicsError::InvalidInput(format!(
238 "no barcodes.tsv in {}",
239 dir.display()
240 )))
241}
242
/// Convergence threshold for `it_sol`: iteration stops once the largest relative
/// parameter change is no longer greater than this value.
const CONV: f64 = 1.0e-4;
244
245pub fn combat(dense: &mut [f64], n_genes: usize, n_cells: usize, batch_of_cell: &[usize]) {
249 let n_batch = batch_of_cell.iter().copied().max().unwrap() + 1;
250 let mut batch_cells: Vec<Vec<usize>> = vec![Vec::new(); n_batch];
251 for (cell, &b) in batch_of_cell.iter().enumerate() {
252 batch_cells[b].push(cell);
253 }
254 let n_b: Vec<f64> = batch_cells.iter().map(|c| c.len() as f64).collect();
255 let n_array = n_cells as f64;
256
257 let mut var_pooled = vec![0.0_f64; n_genes];
260 let mut stand_mean = vec![0.0_f64; n_genes];
261 let nc = n_cells;
262 dense
263 .par_chunks_mut(nc)
264 .zip(var_pooled.par_iter_mut())
265 .zip(stand_mean.par_iter_mut())
266 .for_each(|((row, vp), sm)| {
267 let mut bmean = vec![0.0_f64; n_batch];
268 for (b, cells) in batch_cells.iter().enumerate() {
269 let mut s = 0.0;
270 for &c in cells {
271 s += row[c];
272 }
273 bmean[b] = s / n_b[b];
274 }
275 let grand: f64 = (0..n_batch).map(|b| n_b[b] / n_array * bmean[b]).sum();
276 let mut ss = 0.0;
277 for (b, cells) in batch_cells.iter().enumerate() {
278 for &c in cells {
279 let d = row[c] - bmean[b];
280 ss += d * d;
281 }
282 }
283 let vp_g = ss / n_array;
284 *vp = vp_g;
285 *sm = grand;
286 let denom = vp_g.sqrt();
287 if vp_g == 0.0 {
288 for v in row.iter_mut() {
289 *v = 0.0;
290 }
291 } else {
292 for v in row.iter_mut() {
293 *v = (*v - grand) / denom;
294 }
295 }
296 });
297
298 let mut gamma_hat = vec![vec![0.0_f64; n_genes]; n_batch];
301 let mut delta_hat = vec![vec![0.0_f64; n_genes]; n_batch];
302 for b in 0..n_batch {
303 let cells = &batch_cells[b];
304 let nb = cells.len() as f64;
305 let gh = &mut gamma_hat[b];
306 let dh = &mut delta_hat[b];
307 dense
308 .par_chunks(nc)
309 .zip(gh.par_iter_mut())
310 .zip(dh.par_iter_mut())
311 .for_each(|((row, g), d)| {
312 let mut s = 0.0;
313 for &c in cells {
314 s += row[c];
315 }
316 let mean = s / nb;
317 *g = mean;
318 let mut ss = 0.0;
319 for &c in cells {
320 let e = row[c] - mean;
321 ss += e * e;
322 }
323 *d = if nb > 1.0 { ss / (nb - 1.0) } else { 0.0 };
324 });
325 }
326
327 let mut gamma_star = vec![vec![0.0_f64; n_genes]; n_batch];
330 let mut delta_star = vec![vec![0.0_f64; n_genes]; n_batch];
331 for b in 0..n_batch {
332 let gh = &gamma_hat[b];
333 let dh = &delta_hat[b];
334 let gamma_bar = mean(gh);
335 let t2 = var_ddof(gh, 0);
336 let a_prior = aprior(dh);
337 let b_prior = bprior(dh);
338 let cells = &batch_cells[b];
339
340 let std_rows: Vec<&[f64]> = (0..n_genes).map(|g| &dense[g * nc..g * nc + nc]).collect();
341 it_sol(
342 &std_rows,
343 cells,
344 gh,
345 dh,
346 gamma_bar,
347 t2,
348 a_prior,
349 b_prior,
350 &mut gamma_star[b],
351 &mut delta_star[b],
352 );
353 }
354
355 dense.par_chunks_mut(nc).enumerate().for_each(|(g, row)| {
358 let vpsq = var_pooled[g].sqrt();
359 let sm = stand_mean[g];
360 for b in 0..n_batch {
361 let dsq = delta_star[b][g].sqrt();
362 let gs = gamma_star[b][g];
363 for &c in &batch_cells[b] {
364 row[c] = (row[c] - gs) / dsq * vpsq + sm;
365 }
366 }
367 });
368}
369
370#[allow(clippy::too_many_arguments)]
377fn it_sol(
378 std_rows: &[&[f64]],
379 cells: &[usize],
380 g_hat: &[f64],
381 d_hat: &[f64],
382 g_bar: f64,
383 t2: f64,
384 a: f64,
385 b: f64,
386 g_out: &mut [f64],
387 d_out: &mut [f64],
388) {
389 let n = cells.len() as f64;
390 let n_genes = g_hat.len();
391 g_out.copy_from_slice(g_hat);
392 d_out.copy_from_slice(d_hat);
393 let mut g_new = vec![0.0_f64; n_genes];
394 let mut d_new = vec![0.0_f64; n_genes];
395
396 loop {
397 let mut g_change = f64::NEG_INFINITY;
398 let mut d_change = f64::NEG_INFINITY;
399 for i in 0..n_genes {
400 let gn = (t2 * n * g_hat[i] + d_out[i] * g_bar) / (t2 * n + d_out[i]);
401 let row = std_rows[i];
402 let mut sum2 = 0.0;
403 for &c in cells {
404 let e = row[c] - gn;
405 sum2 += e * e;
406 }
407 let dn = (0.5 * sum2 + b) / (n / 2.0 + a - 1.0);
408 g_change = numpy_max(g_change, (gn - g_out[i]).abs() / g_out[i].abs());
409 d_change = numpy_max(d_change, (dn - d_out[i]).abs() / d_out[i].abs());
410 g_new[i] = gn;
411 d_new[i] = dn;
412 }
413 g_out.copy_from_slice(&g_new);
414 d_out.copy_from_slice(&d_new);
415 let change = python_max(g_change, d_change);
416 #[allow(clippy::neg_cmp_op_on_partial_ord)]
419 if !(change > CONV) {
420 break;
421 }
422 }
423}
424
/// Maximum of two values that propagates NaN: the result is NaN as soon as either
/// operand is NaN (plain `f64::max` would ignore the NaN operand instead).
fn numpy_max(acc: f64, x: f64) -> f64 {
    match (acc.is_nan(), x.is_nan()) {
        (false, false) => acc.max(x),
        _ => f64::NAN,
    }
}
433
/// Two-argument maximum that returns `b` only when `b > a` holds and `a` otherwise.
///
/// Because comparisons with NaN are false, a NaN in `a` is returned as-is while a NaN
/// in `b` is ignored.
fn python_max(a: f64, b: f64) -> f64 {
    if b > a {
        return b;
    }
    a
}
439
/// Arithmetic mean of `x`; an empty slice yields NaN (zero divided by zero).
fn mean(x: &[f64]) -> f64 {
    let total: f64 = x.iter().sum();
    let count = x.len() as f64;
    total / count
}
443
444fn var_ddof(x: &[f64], ddof: usize) -> f64 {
445 let n = x.len() as f64;
446 let m = mean(x);
447 let ss: f64 = x.iter().map(|&v| (v - m) * (v - m)).sum();
448 ss / (n - ddof as f64)
449}
450
451fn aprior(delta_hat: &[f64]) -> f64 {
452 let m = mean(delta_hat);
453 let s2 = var_ddof(delta_hat, 1);
454 (2.0 * s2 + m * m) / s2
455}
456
457fn bprior(delta_hat: &[f64]) -> f64 {
458 let m = mean(delta_hat);
459 let s2 = var_ddof(delta_hat, 1);
460 (m * s2 + m * m * m) / s2
461}
462
463pub fn write_dense_gene_major(
466 n_genes: usize,
467 n_cells: usize,
468 dense: &[f64],
469 out: impl Write,
470) -> Result<()> {
471 let mut w = BufWriter::with_capacity(1 << 20, out);
472 w.write_all(b"%%MatrixMarket matrix array real general\n")
473 .map_err(RsomicsError::Io)?;
474 let mut header = format!("{n_genes} {n_cells}");
475 header.push('\n');
476 w.write_all(header.as_bytes()).map_err(RsomicsError::Io)?;
477
478 let mut fmt = ryu::Buffer::new();
479 let mut buf: Vec<u8> = Vec::with_capacity(1 << 16);
480 for cell in 0..n_cells {
481 for gene in 0..n_genes {
482 buf.extend_from_slice(fmt.format(dense[gene * n_cells + cell]).as_bytes());
483 buf.push(b'\n');
484 if buf.len() >= 1 << 15 {
485 w.write_all(&buf).map_err(RsomicsError::Io)?;
486 buf.clear();
487 }
488 }
489 }
490 w.write_all(&buf).map_err(RsomicsError::Io)?;
491 w.flush().map_err(RsomicsError::Io)?;
492 Ok(())
493}
494
495fn parse_usize(tok: Option<&str>) -> Result<usize> {
496 tok.ok_or_else(|| RsomicsError::InvalidInput("MTX header missing a dimension".into()))?
497 .parse::<usize>()
498 .map_err(Into::into)
499}
500
501pub fn open_output(path: &str) -> Result<Box<dyn Write>> {
502 if path == "-" {
503 Ok(Box::new(std::io::stdout().lock()))
504 } else {
505 Ok(Box::new(
506 File::create(PathBuf::from(path)).map_err(RsomicsError::Io)?,
507 ))
508 }
509}
510
511pub fn run(
514 dir: &Path,
515 batch_tsv: &Path,
516 key: Option<&str>,
517 out: impl Write,
518) -> Result<(usize, usize, usize)> {
519 let m = parse_mtx(open_mtx(dir)?)?;
520 let barcodes = read_barcodes(dir)?;
521 if barcodes.len() != m.n_cells {
522 return Err(RsomicsError::InvalidInput(format!(
523 "{} barcodes but matrix has {} cells",
524 barcodes.len(),
525 m.n_cells
526 )));
527 }
528 let (batch_of_cell, levels) = read_batch_labels(batch_tsv, &barcodes, key)?;
529
530 let mut dense = densify_gene_major(&m);
531 combat(&mut dense, m.n_genes, m.n_cells, &batch_of_cell);
532 write_dense_gene_major(m.n_genes, m.n_cells, &dense, out)?;
533 Ok((m.n_genes, m.n_cells, levels.len()))
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: three genes by six cells (gene-major), cells 0-2 in batch 0 and
    /// cells 3-5 in batch 1.
    fn two_batch() -> (Vec<f64>, usize, usize, Vec<usize>) {
        let vals: [[f64; 6]; 3] = [
            [1.0, 2.0, 1.5, 4.0, 5.0, 4.5],
            [2.0, 2.5, 3.0, 1.0, 0.5, 1.2],
            [0.5, 0.7, 0.6, 0.55, 0.62, 0.58],
        ];
        // Rows laid end to end are exactly the gene-major layout.
        let dense = vals.concat();
        (dense, vals.len(), vals[0].len(), vec![0, 0, 0, 1, 1, 1])
    }

    #[test]
    fn corrected_means_converge_across_batches() {
        let (mut dense, ng, nc, batch) = two_batch();
        combat(&mut dense, ng, nc, &batch);
        // Gene 0 occupies dense[0..6]: first three cells are batch 0, the rest batch 1.
        let m0a = dense[..3].iter().sum::<f64>() / 3.0;
        let m0b = dense[3..6].iter().sum::<f64>() / 3.0;
        assert!(
            (m0a - m0b).abs() < 1.0,
            "batch means not pulled together: {m0a} vs {m0b}"
        );
    }

    #[test]
    fn zero_variance_gene_collapses_to_grand_mean() {
        let batch = vec![0, 0, 0, 1, 1, 1];
        // Gene 0 varies; gene 1 is constant at 2.0 in every cell.
        let mut dense: Vec<f64> = Vec::with_capacity(12);
        dense.extend_from_slice(&[1.0, 2.0, 1.5, 4.0, 5.0, 4.5]);
        dense.extend_from_slice(&[2.0; 6]);
        combat(&mut dense, 2, 6, &batch);
        for &v in &dense {
            assert!(v.is_finite(), "non-finite ComBat output: {v}");
        }
        for &v in &dense[6..] {
            assert!(
                (v - 2.0).abs() < 1e-12,
                "zero-var gene not at grand mean: {v}"
            );
        }
    }

    #[test]
    fn priors_match_numpy_moments() {
        let d = [1.0, 2.0, 3.0, 4.0];
        let close = |got: f64, want: f64| (got - want).abs() < 1e-12;
        assert!(close(mean(&d), 2.5));
        // Sample variance (ddof = 1) and population variance (ddof = 0).
        assert!(close(var_ddof(&d, 1), 5.0 / 3.0));
        assert!(close(var_ddof(&d, 0), 1.25));
    }
}