use std::collections::BTreeMap;
use std::fs::File;
use std::io::{BufRead, BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
use flate2::read::MultiGzDecoder;
use rayon::prelude::*;
use rsomics_common::{Result, RsomicsError};
/// Sparse gene-by-cell matrix stored as a flat list of coordinate entries.
///
/// `n_genes` and `n_cells` are the row and column counts; every [`Entry`] in
/// `entries` carries zero-based indices into that shape.
#[derive(Debug, Clone, PartialEq)]
pub struct CountMatrix {
    /// Number of rows (genes).
    pub n_genes: usize,
    /// Number of columns (cells).
    pub n_cells: usize,
    /// Stored values, in the order they were read.
    pub entries: Vec<Entry>,
}

/// A single stored value of a [`CountMatrix`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Entry {
    /// Zero-based gene (row) index.
    pub gene: u32,
    /// Zero-based cell (column) index.
    pub cell: u32,
    /// Value stored at `(gene, cell)`.
    pub value: f64,
}
/// Opens the matrix file stored in `dir`.
///
/// `matrix.mtx.gz` is tried first, then `matrix.mtx`; the first candidate that
/// exists is opened through `open_maybe_gz`.
///
/// # Errors
///
/// Returns `RsomicsError::InvalidInput` when neither file exists, and forwards
/// any error produced while opening the chosen file.
pub fn open_mtx(dir: &Path) -> Result<Box<dyn Read>> {
    let located = ["matrix.mtx.gz", "matrix.mtx"]
        .iter()
        .map(|candidate| dir.join(candidate))
        .find(|candidate| candidate.exists());

    match located {
        Some(path) => open_maybe_gz(&path),
        None => Err(RsomicsError::InvalidInput(format!(
            "no matrix.mtx or matrix.mtx.gz in {}",
            dir.display()
        ))),
    }
}
/// Opens `path` for reading, decompressing on the fly when its extension is `gz`.
///
/// # Errors
///
/// Returns `RsomicsError::InvalidInput` (carrying the path and the OS error)
/// when the file cannot be opened.
fn open_maybe_gz(path: &Path) -> Result<Box<dyn Read>> {
    let handle = match File::open(path) {
        Ok(handle) => handle,
        Err(e) => {
            return Err(RsomicsError::InvalidInput(format!(
                "{}: {e}",
                path.display()
            )));
        }
    };

    // A non-UTF-8 extension can never equal "gz", so comparing via `to_str` is exact.
    let gzipped = matches!(path.extension().and_then(|ext| ext.to_str()), Some("gz"));
    let reader: Box<dyn Read> = if gzipped {
        Box::new(MultiGzDecoder::new(handle))
    } else {
        Box::new(handle)
    };
    Ok(reader)
}
/// Parses a Matrix Market coordinate stream into a [`CountMatrix`].
///
/// Layout expected from `reader`:
///
/// 1. a banner line starting with `%%MatrixMarket`; when it contains the word
///    `pattern`, entries carry no value and are stored as `1.0`;
/// 2. any number of blank lines or `%` comment lines;
/// 3. a size line `rows cols nnz`;
/// 4. one `row col [value]` line per entry, with one-based indices (blank
///    lines are skipped).
///
/// Indices are converted to zero-based `u32` values.
///
/// # Errors
///
/// Returns `RsomicsError::InvalidInput` when the banner is missing, the header
/// is truncated, the dimensions do not fit the 32-bit index range, an entry is
/// missing its value, an index is out of bounds, or the number of entries
/// differs from the declared `nnz`. I/O failures are reported as
/// `RsomicsError::Io`, and number-parsing failures are converted into the
/// crate error type.
pub fn parse_mtx(reader: impl Read) -> Result<CountMatrix> {
    // The declared entry count comes straight from the file, so it is only
    // used as a preallocation hint up to this bound; a corrupt or hostile
    // header must not be able to request an arbitrarily large allocation.
    const MAX_PREALLOC_ENTRIES: usize = 1 << 20;

    let mut reader = BufReader::new(reader);
    let mut line = String::new();
    reader.read_line(&mut line).map_err(RsomicsError::Io)?;
    let banner = line.trim();
    if !banner.starts_with("%%MatrixMarket") {
        return Err(RsomicsError::InvalidInput(
            "missing %%MatrixMarket banner".into(),
        ));
    }
    let pattern = banner.contains("pattern");

    // Skip comments and blank lines until the size line shows up.
    let (n_genes, n_cells, nnz) = loop {
        line.clear();
        let n = reader.read_line(&mut line).map_err(RsomicsError::Io)?;
        if n == 0 {
            return Err(RsomicsError::InvalidInput("truncated MTX header".into()));
        }
        let t = line.trim();
        if t.is_empty() || t.starts_with('%') {
            continue;
        }
        let mut it = t.split_whitespace();
        let rows = parse_usize(it.next())?;
        let cols = parse_usize(it.next())?;
        let nnz = parse_usize(it.next())?;
        break (rows, cols, nnz);
    };

    // Entries store zero-based indices as u32, so the largest index
    // (dimension - 1) has to fit; reject larger shapes instead of letting the
    // narrowing casts below wrap around silently.
    let max_dim = (u32::MAX as usize).saturating_add(1);
    if n_genes > max_dim || n_cells > max_dim {
        return Err(RsomicsError::InvalidInput(format!(
            "MTX dimensions {n_genes} x {n_cells} exceed the 32-bit index range"
        )));
    }

    let mut entries = Vec::with_capacity(nnz.min(MAX_PREALLOC_ENTRIES));
    for raw in reader.lines() {
        let raw = raw.map_err(RsomicsError::Io)?;
        let t = raw.trim();
        if t.is_empty() {
            continue;
        }
        let mut it = t.split_whitespace();
        let gene = parse_usize(it.next())?;
        let cell = parse_usize(it.next())?;
        let value = if pattern {
            1.0
        } else {
            it.next()
                .ok_or_else(|| RsomicsError::InvalidInput("MTX entry missing value".into()))?
                .parse::<f64>()?
        };
        if gene == 0 || gene > n_genes || cell == 0 || cell > n_cells {
            return Err(RsomicsError::InvalidInput(format!(
                "MTX index out of bounds: ({gene}, {cell})"
            )));
        }
        // Lossless: both indices are within the dimensions validated above.
        entries.push(Entry {
            gene: (gene - 1) as u32,
            cell: (cell - 1) as u32,
            value,
        });
    }

    if entries.len() != nnz {
        return Err(RsomicsError::InvalidInput(format!(
            "MTX declared {nnz} entries, found {}",
            entries.len()
        )));
    }
    Ok(CountMatrix {
        n_genes,
        n_cells,
        entries,
    })
}
/// Expands the sparse matrix into a dense buffer laid out gene-major.
///
/// The value for `(gene, cell)` lands at `gene * n_cells + cell`; positions
/// without an entry stay `0.0`. When the same coordinate appears more than
/// once, the entry read last wins.
fn densify_gene_major(m: &CountMatrix) -> Vec<f64> {
    let width = m.n_cells;
    let mut grid = vec![0.0_f64; m.n_genes * width];
    m.entries.iter().for_each(|entry| {
        let row = entry.gene as usize;
        let col = entry.cell as usize;
        grid[row * width + col] = entry.value;
    });
    grid
}
/// Reads a tab-separated batch annotation file and maps every barcode to a batch level.
///
/// The first column of each row is the barcode. The label column is picked as
/// follows:
///
/// * with `key`, the first row must be a header containing `key`, and the
///   column under that name is used;
/// * without `key`, column 1 (the second column) is used, and the first row is
///   skipped only when one of its fields equals `barcode` (ASCII
///   case-insensitive).
///
/// Empty lines are ignored. When a barcode occurs on several rows, the last
/// row wins.
///
/// Returns `(batch_of_cell, levels)`: `levels` is the sorted list of distinct
/// labels carried by the requested `barcodes`, and `batch_of_cell[i]` is the
/// index into `levels` for `barcodes[i]`. Labels that only occur on rows for
/// barcodes outside `barcodes` are not turned into levels, so every returned
/// level owns at least one cell.
///
/// # Errors
///
/// Returns `RsomicsError::InvalidInput` when the file cannot be opened or is
/// empty, `key` is given but absent from the header (or there is no header), a
/// row is shorter than the selected column, a barcode has no row, or the
/// requested barcodes cover fewer than two distinct labels. Read failures are
/// reported as `RsomicsError::Io`.
pub fn read_batch_labels(
    path: &Path,
    barcodes: &[String],
    key: Option<&str>,
) -> Result<(Vec<usize>, Vec<String>)> {
    let f = File::open(path)
        .map_err(|e| RsomicsError::InvalidInput(format!("{}: {e}", path.display())))?;
    let mut lines = Vec::new();
    for raw in BufReader::new(f).lines() {
        let raw = raw.map_err(RsomicsError::Io)?;
        let t = raw.trim_end_matches(['\n', '\r']);
        if !t.is_empty() {
            lines.push(t.to_string());
        }
    }
    if lines.is_empty() {
        return Err(RsomicsError::InvalidInput("empty batch TSV".into()));
    }

    let first: Vec<&str> = lines[0].split('\t').collect();
    let has_header = is_header(&first, key);
    let label_col = match (key, has_header) {
        (Some(k), true) => first.iter().position(|c| *c == k).ok_or_else(|| {
            RsomicsError::InvalidInput(format!("key {k:?} not in batch TSV header"))
        })?,
        (Some(k), false) => {
            return Err(RsomicsError::InvalidInput(format!(
                "--key {k:?} given but batch TSV has no header row"
            )));
        }
        (None, _) => 1,
    };

    // barcode -> label, borrowing from `lines` to avoid per-row allocations.
    let mut by_barcode: BTreeMap<&str, &str> = BTreeMap::new();
    for line in lines.iter().skip(usize::from(has_header)) {
        let cols: Vec<&str> = line.split('\t').collect();
        if label_col >= cols.len() {
            return Err(RsomicsError::InvalidInput(
                "batch TSV row shorter than the selected key column".into(),
            ));
        }
        by_barcode.insert(cols[0], cols[label_col]);
    }

    // Resolve the label of every requested barcode first: the batch levels
    // must come from these labels only. Deriving them from the whole TSV would
    // let a label that belongs solely to absent barcodes create a batch with
    // no cells, and would let the two-batch check pass on unused labels.
    let mut labels: Vec<&str> = Vec::with_capacity(barcodes.len());
    for bc in barcodes {
        let label = by_barcode.get(bc.as_str()).copied().ok_or_else(|| {
            RsomicsError::InvalidInput(format!("barcode {bc:?} missing from batch TSV"))
        })?;
        labels.push(label);
    }

    let mut distinct: Vec<&str> = labels.clone();
    distinct.sort_unstable();
    distinct.dedup();
    if distinct.len() < 2 {
        return Err(RsomicsError::InvalidInput(
            "ComBat needs at least 2 batches".into(),
        ));
    }

    let level_of: BTreeMap<&str, usize> = distinct
        .iter()
        .enumerate()
        .map(|(i, s)| (*s, i))
        .collect();
    let batch_of_cell: Vec<usize> = labels.iter().map(|label| level_of[*label]).collect();
    let levels: Vec<String> = distinct.into_iter().map(str::to_owned).collect();
    Ok((batch_of_cell, levels))
}
/// Decides whether `cols` (the fields of the first TSV row) form a header row.
///
/// With a `key`, the row is a header exactly when one field equals the key.
/// Without one, it is a header when some field equals `barcode`, compared
/// ASCII case-insensitively.
fn is_header(cols: &[&str], key: Option<&str>) -> bool {
    match key {
        Some(wanted) => cols.iter().any(|field| *field == wanted),
        None => cols
            .iter()
            .any(|field| field.eq_ignore_ascii_case("barcode")),
    }
}
/// Reads the cell barcodes stored in `dir`.
///
/// `barcodes.tsv.gz` is tried first, then `barcodes.tsv`. Each non-blank line
/// contributes the text before its first tab (after trimming surrounding
/// whitespace), in file order.
///
/// # Errors
///
/// Returns `RsomicsError::InvalidInput` when neither file exists, forwards
/// errors from opening the chosen file, and reports read failures as
/// `RsomicsError::Io`.
pub fn read_barcodes(dir: &Path) -> Result<Vec<String>> {
    let located = ["barcodes.tsv.gz", "barcodes.tsv"]
        .iter()
        .map(|candidate| dir.join(candidate))
        .find(|candidate| candidate.exists());
    let Some(path) = located else {
        return Err(RsomicsError::InvalidInput(format!(
            "no barcodes.tsv in {}",
            dir.display()
        )));
    };

    let mut barcodes = Vec::new();
    for line in BufReader::new(open_maybe_gz(&path)?).lines() {
        let line = line.map_err(RsomicsError::Io)?;
        let trimmed = line.trim();
        if trimmed.is_empty() {
            continue;
        }
        // Keep only the first tab-separated field; a line without tabs is kept whole.
        let first_field = trimmed
            .split_once('\t')
            .map_or(trimmed, |(head, _)| head);
        barcodes.push(first_field.to_string());
    }
    Ok(barcodes)
}
/// Convergence threshold used by `it_sol`: its loop stops as soon as the
/// largest relative parameter change of a pass is no longer above this value.
const CONV: f64 = 1e-4;
/// Applies ComBat-style empirical-Bayes batch correction to a dense gene-major matrix, in place.
///
/// `dense` holds `n_genes` rows of `n_cells` values each, and
/// `batch_of_cell[c]` is the zero-based batch index of cell `c`. The routine
/// works in four stages:
///
/// 1. Per gene: compute the per-batch means, their size-weighted grand mean
///    and the pooled within-batch variance (sum of squares divided by
///    `n_cells`), then standardize the row. A row whose pooled variance is
///    exactly zero is set to all zeros.
/// 2. Per batch and gene: take the mean (`gamma_hat`) and the `ddof = 1`
///    variance (`delta_hat`) of the standardized values (variance `0.0` for a
///    single-cell batch).
/// 3. Per batch: derive the prior hyper-parameters from those estimates and
///    refine them with `it_sol` into `gamma_star` / `delta_star`.
/// 4. Per gene: remove each batch's refined shift and scale, then map the row
///    back using its pooled standard deviation and grand mean.
///
/// Returns immediately, leaving `dense` untouched, when there are no cells.
///
/// Every batch index in `0..=max(batch_of_cell)` is expected to own at least
/// one cell: a batch without cells has a `0/0` mean, and that NaN propagates
/// into the grand mean of every gene.
///
/// # Panics
///
/// Panics when `batch_of_cell.len() != n_cells` or
/// `dense.len() != n_genes * n_cells`.
pub fn combat(dense: &mut [f64], n_genes: usize, n_cells: usize, batch_of_cell: &[usize]) {
    assert_eq!(
        batch_of_cell.len(),
        n_cells,
        "combat: batch_of_cell must have one entry per cell"
    );
    assert_eq!(
        dense.len(),
        n_genes * n_cells,
        "combat: dense must hold n_genes * n_cells values"
    );
    // No cells means nothing to correct; this also keeps a zero chunk size
    // away from the parallel chunk iterators below.
    let Some(max_batch) = batch_of_cell.iter().copied().max() else {
        return;
    };
    let n_batch = max_batch + 1;

    // Column indices of the cells belonging to each batch.
    let mut batch_cells: Vec<Vec<usize>> = vec![Vec::new(); n_batch];
    for (cell, &b) in batch_of_cell.iter().enumerate() {
        batch_cells[b].push(cell);
    }
    let n_b: Vec<f64> = batch_cells.iter().map(|c| c.len() as f64).collect();
    let n_array = n_cells as f64;
    let mut var_pooled = vec![0.0_f64; n_genes];
    let mut stand_mean = vec![0.0_f64; n_genes];
    let nc = n_cells;

    // Stage 1: standardize every gene row, remembering its pooled variance and grand mean.
    dense
        .par_chunks_mut(nc)
        .zip(var_pooled.par_iter_mut())
        .zip(stand_mean.par_iter_mut())
        .for_each(|((row, vp), sm)| {
            let mut bmean = vec![0.0_f64; n_batch];
            for (b, cells) in batch_cells.iter().enumerate() {
                let mut s = 0.0;
                for &c in cells {
                    s += row[c];
                }
                bmean[b] = s / n_b[b];
            }
            let grand: f64 = (0..n_batch).map(|b| n_b[b] / n_array * bmean[b]).sum();
            let mut ss = 0.0;
            for (b, cells) in batch_cells.iter().enumerate() {
                for &c in cells {
                    let d = row[c] - bmean[b];
                    ss += d * d;
                }
            }
            let vp_g = ss / n_array;
            *vp = vp_g;
            *sm = grand;
            let denom = vp_g.sqrt();
            if vp_g == 0.0 {
                // Zero-variance gene: avoid 0/0 by pinning the standardized row to zero.
                for v in row.iter_mut() {
                    *v = 0.0;
                }
            } else {
                for v in row.iter_mut() {
                    *v = (*v - grand) / denom;
                }
            }
        });

    // Stage 2: per-batch location and scale estimates on the standardized data.
    let mut gamma_hat = vec![vec![0.0_f64; n_genes]; n_batch];
    let mut delta_hat = vec![vec![0.0_f64; n_genes]; n_batch];
    for b in 0..n_batch {
        let cells = &batch_cells[b];
        let nb = cells.len() as f64;
        let gh = &mut gamma_hat[b];
        let dh = &mut delta_hat[b];
        dense
            .par_chunks(nc)
            .zip(gh.par_iter_mut())
            .zip(dh.par_iter_mut())
            .for_each(|((row, g), d)| {
                let mut s = 0.0;
                for &c in cells {
                    s += row[c];
                }
                let mean = s / nb;
                *g = mean;
                let mut ss = 0.0;
                for &c in cells {
                    let e = row[c] - mean;
                    ss += e * e;
                }
                *d = if nb > 1.0 { ss / (nb - 1.0) } else { 0.0 };
            });
    }

    // Stage 3: shrink the estimates batch by batch.
    let mut gamma_star = vec![vec![0.0_f64; n_genes]; n_batch];
    let mut delta_star = vec![vec![0.0_f64; n_genes]; n_batch];
    {
        // The standardized rows do not change during this stage, so the row
        // views are built once instead of once per batch. The inner scope ends
        // the shared borrow of `dense` before stage 4 mutates it.
        let std_rows: Vec<&[f64]> = (0..n_genes).map(|g| &dense[g * nc..g * nc + nc]).collect();
        for b in 0..n_batch {
            let gh = &gamma_hat[b];
            let dh = &delta_hat[b];
            let gamma_bar = mean(gh);
            let t2 = var_ddof(gh, 0);
            let a_prior = aprior(dh);
            let b_prior = bprior(dh);
            let cells = &batch_cells[b];
            it_sol(
                &std_rows,
                cells,
                gh,
                dh,
                gamma_bar,
                t2,
                a_prior,
                b_prior,
                &mut gamma_star[b],
                &mut delta_star[b],
            );
        }
    }

    // Stage 4: undo the batch effect and return to the original scale.
    dense.par_chunks_mut(nc).enumerate().for_each(|(g, row)| {
        let vpsq = var_pooled[g].sqrt();
        let sm = stand_mean[g];
        for b in 0..n_batch {
            let dsq = delta_star[b][g].sqrt();
            let gs = gamma_star[b][g];
            for &c in &batch_cells[b] {
                row[c] = (row[c] - gs) / dsq * vpsq + sm;
            }
        }
    });
}
/// Iteratively refines one batch's per-gene shift (`g_out`) and scale (`d_out`) estimates.
///
/// * `std_rows` - one slice per gene, indexed by cell.
/// * `cells` - cell indices belonging to the batch being solved.
/// * `g_hat`, `d_hat` - starting per-gene estimates; copied into the outputs first.
/// * `g_bar`, `t2` - scalars combined with `g_hat` in the shift update.
/// * `a`, `b` - scalars used in the scale update.
/// * `g_out`, `d_out` - receive the refined values; must have `g_hat.len()` elements
///   (`copy_from_slice` panics otherwise).
///
/// Each pass recomputes every gene from the previous pass's `d_out`, then
/// replaces both outputs at once. The loop has no iteration cap: it ends only
/// when the largest relative change is not above `CONV`.
#[allow(clippy::too_many_arguments)]
fn it_sol(
std_rows: &[&[f64]],
cells: &[usize],
g_hat: &[f64],
d_hat: &[f64],
g_bar: f64,
t2: f64,
a: f64,
b: f64,
g_out: &mut [f64],
d_out: &mut [f64],
) {
// Number of cells in this batch, as a float for the formulas below.
let n = cells.len() as f64;
let n_genes = g_hat.len();
g_out.copy_from_slice(g_hat);
d_out.copy_from_slice(d_hat);
// Scratch buffers so a pass reads only the previous pass's values.
let mut g_new = vec![0.0_f64; n_genes];
let mut d_new = vec![0.0_f64; n_genes];
loop {
// Largest relative change seen in this pass; `numpy_max` makes NaN sticky.
let mut g_change = f64::NEG_INFINITY;
let mut d_change = f64::NEG_INFINITY;
for i in 0..n_genes {
// Shift update: weighted blend of the gene's own estimate and `g_bar`.
let gn = (t2 * n * g_hat[i] + d_out[i] * g_bar) / (t2 * n + d_out[i]);
let row = std_rows[i];
// Sum of squared residuals of the batch's cells around the new shift.
let mut sum2 = 0.0;
for &c in cells {
let e = row[c] - gn;
sum2 += e * e;
}
// Scale update from the residual sum and the `a` / `b` scalars.
let dn = (0.5 * sum2 + b) / (n / 2.0 + a - 1.0);
// Relative change against the previous value; a zero previous value
// gives 0/0 = NaN (or x/0 = inf), which is carried through on purpose.
g_change = numpy_max(g_change, (gn - g_out[i]).abs() / g_out[i].abs());
d_change = numpy_max(d_change, (dn - d_out[i]).abs() / d_out[i].abs());
g_new[i] = gn;
d_new[i] = dn;
}
g_out.copy_from_slice(&g_new);
d_out.copy_from_slice(&d_new);
let change = python_max(g_change, d_change);
// Written as a negated `>` so that a NaN `change` also ends the loop
// (any comparison with NaN is false); `change <= CONV` would keep looping.
#[allow(clippy::neg_cmp_op_on_partial_ord)]
if !(change > CONV) {
break;
}
}
}
/// Maximum of two floats that propagates NaN: the result is NaN as soon as
/// either argument is NaN, otherwise it is the larger value.
fn numpy_max(acc: f64, x: f64) -> f64 {
    match (acc.is_nan(), x.is_nan()) {
        (false, false) => acc.max(x),
        _ => f64::NAN,
    }
}
/// Order-sensitive maximum: returns `b` only when `b > a` holds, otherwise `a`.
///
/// Because every comparison with NaN is false, a NaN in `a` is returned as-is
/// while a NaN in `b` is ignored.
fn python_max(a: f64, b: f64) -> f64 {
    match b.partial_cmp(&a) {
        Some(std::cmp::Ordering::Greater) => b,
        _ => a,
    }
}
/// Arithmetic mean of `x`; an empty slice yields NaN (zero divided by zero).
fn mean(x: &[f64]) -> f64 {
    let total: f64 = x.iter().sum();
    let count = x.len() as f64;
    total / count
}
/// Variance of `x` with `ddof` delta degrees of freedom: the sum of squared
/// deviations from the mean divided by `x.len() - ddof`.
fn var_ddof(x: &[f64], ddof: usize) -> f64 {
    let center = mean(x);
    let squared_deviations: f64 = x
        .iter()
        .map(|&value| {
            let offset = value - center;
            offset * offset
        })
        .sum();
    let divisor = x.len() as f64 - ddof as f64;
    squared_deviations / divisor
}
/// Computes `(2 * s2 + m^2) / s2`, where `m` is the mean of `delta_hat` and
/// `s2` its `ddof = 1` variance.
fn aprior(delta_hat: &[f64]) -> f64 {
    let center = mean(delta_hat);
    let spread = var_ddof(delta_hat, 1);
    let numerator = 2.0 * spread + center * center;
    numerator / spread
}
/// Computes `(m * s2 + m^3) / s2`, where `m` is the mean of `delta_hat` and
/// `s2` its `ddof = 1` variance.
fn bprior(delta_hat: &[f64]) -> f64 {
    let center = mean(delta_hat);
    let spread = var_ddof(delta_hat, 1);
    let numerator = center * spread + center * center * center;
    numerator / spread
}
/// Writes a gene-major dense buffer as a Matrix Market `array real general` file.
///
/// After the banner and the `n_genes n_cells` size line, values are emitted
/// one per line, cell by cell: for each cell, every gene in order, reading
/// `dense[gene * n_cells + cell]`.
///
/// # Errors
///
/// Any write or flush failure is returned as `RsomicsError::Io`.
///
/// # Panics
///
/// Panics if `dense` is shorter than `n_genes * n_cells`.
pub fn write_dense_gene_major(
    n_genes: usize,
    n_cells: usize,
    dense: &[f64],
    out: impl Write,
) -> Result<()> {
    // Staged bytes are handed to the writer once they reach this size.
    const HAND_OFF_AT: usize = 1 << 15;

    let mut sink = BufWriter::with_capacity(1 << 20, out);
    sink.write_all(b"%%MatrixMarket matrix array real general\n")
        .map_err(RsomicsError::Io)?;
    let size_line = format!("{n_genes} {n_cells}\n");
    sink.write_all(size_line.as_bytes())
        .map_err(RsomicsError::Io)?;

    let mut number = ryu::Buffer::new();
    let mut staged: Vec<u8> = Vec::with_capacity(1 << 16);
    // Walk the gene-major buffer column by column.
    let positions =
        (0..n_cells).flat_map(|cell| (0..n_genes).map(move |gene| gene * n_cells + cell));
    for position in positions {
        let text = number.format(dense[position]);
        staged.extend_from_slice(text.as_bytes());
        staged.push(b'\n');
        if staged.len() >= HAND_OFF_AT {
            sink.write_all(&staged).map_err(RsomicsError::Io)?;
            staged.clear();
        }
    }

    sink.write_all(&staged).map_err(RsomicsError::Io)?;
    sink.flush().map_err(RsomicsError::Io)?;
    Ok(())
}
/// Parses an optional whitespace-separated token as a `usize`.
///
/// # Errors
///
/// Returns `RsomicsError::InvalidInput` when the token is absent; a token that
/// is not a valid `usize` yields the parse error converted into the crate
/// error type.
fn parse_usize(tok: Option<&str>) -> Result<usize> {
    let Some(text) = tok else {
        return Err(RsomicsError::InvalidInput(
            "MTX header missing a dimension".into(),
        ));
    };
    text.parse::<usize>().map_err(Into::into)
}
/// Opens the output destination: `-` selects locked standard output, anything
/// else is created (or truncated) as a file at that path.
///
/// # Errors
///
/// Returns `RsomicsError::Io` when the file cannot be created.
pub fn open_output(path: &str) -> Result<Box<dyn Write>> {
    let sink: Box<dyn Write> = match path {
        "-" => Box::new(std::io::stdout().lock()),
        file_path => {
            let created = File::create(PathBuf::from(file_path)).map_err(RsomicsError::Io)?;
            Box::new(created)
        }
    };
    Ok(sink)
}
/// Runs the whole pipeline: load the matrix and barcodes from `dir`, read the
/// batch labels from `batch_tsv`, batch-correct the densified matrix and write
/// it to `out`.
///
/// Returns `(n_genes, n_cells, n_batches)`.
///
/// # Errors
///
/// Forwards any error from loading, label resolution or writing, and returns
/// `RsomicsError::InvalidInput` when the number of barcodes differs from the
/// matrix's cell count.
pub fn run(
    dir: &Path,
    batch_tsv: &Path,
    key: Option<&str>,
    out: impl Write,
) -> Result<(usize, usize, usize)> {
    let source = open_mtx(dir)?;
    let matrix = parse_mtx(source)?;
    let (gene_count, cell_count) = (matrix.n_genes, matrix.n_cells);

    let barcodes = read_barcodes(dir)?;
    let barcode_count = barcodes.len();
    if barcode_count != cell_count {
        return Err(RsomicsError::InvalidInput(format!(
            "{} barcodes but matrix has {} cells",
            barcode_count, cell_count
        )));
    }

    let (batch_of_cell, levels) = read_batch_labels(batch_tsv, &barcodes, key)?;

    let mut corrected = densify_gene_major(&matrix);
    combat(&mut corrected, gene_count, cell_count, &batch_of_cell);
    write_dense_gene_major(gene_count, cell_count, &corrected, out)?;

    Ok((gene_count, cell_count, levels.len()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 3-gene x 6-cell gene-major fixture whose first three cells are
    /// in batch 0 and last three in batch 1.
    fn two_batch() -> (Vec<f64>, usize, usize, Vec<usize>) {
        let rows: [[f64; 6]; 3] = [
            [1.0, 2.0, 1.5, 4.0, 5.0, 4.5],
            [2.0, 2.5, 3.0, 1.0, 0.5, 1.2],
            [0.5, 0.7, 0.6, 0.55, 0.62, 0.58],
        ];
        let dense: Vec<f64> = rows.iter().flat_map(|row| row.iter().copied()).collect();
        (dense, rows.len(), 6, vec![0, 0, 0, 1, 1, 1])
    }

    /// After correction the two batch means of gene 0 must be close together.
    #[test]
    fn corrected_means_converge_across_batches() {
        let (mut dense, ng, nc, batch) = two_batch();
        combat(&mut dense, ng, nc, &batch);
        let m0a = dense[..3].iter().sum::<f64>() / 3.0;
        let m0b = dense[3..6].iter().sum::<f64>() / 3.0;
        let gap = (m0a - m0b).abs();
        assert!(
            gap < 1.0,
            "batch means not pulled together: {m0a} vs {m0b}"
        );
    }

    /// A constant gene must stay finite and end up at its grand mean.
    #[test]
    fn zero_variance_gene_collapses_to_grand_mean() {
        let varying = [1.0, 2.0, 1.5, 4.0, 5.0, 4.5];
        let constant = [2.0; 6];
        let mut dense: Vec<f64> = varying.iter().chain(constant.iter()).copied().collect();
        let batch = vec![0, 0, 0, 1, 1, 1];
        combat(&mut dense, 2, 6, &batch);

        if let Some(v) = dense.iter().find(|v| !v.is_finite()) {
            panic!("non-finite ComBat output: {v}");
        }
        if let Some(v) = dense[6..].iter().find(|v| !((**v - 2.0).abs() < 1e-12)) {
            panic!("zero-var gene not at grand mean: {v}");
        }
    }

    /// Mean and both variance flavours on a small hand-computed sample.
    #[test]
    fn priors_match_numpy_moments() {
        let d = [1.0, 2.0, 3.0, 4.0];
        let close = |actual: f64, expected: f64| (actual - expected).abs() < 1e-12;
        assert!(close(mean(&d), 2.5));
        assert!(close(var_ddof(&d, 1), 5.0 / 3.0));
        assert!(close(var_ddof(&d, 0), 1.25));
    }
}