Skip to main content

rsomics_sc_combat/
lib.rs

1use std::collections::BTreeMap;
2use std::fs::File;
3use std::io::{BufRead, BufReader, BufWriter, Read, Write};
4use std::path::{Path, PathBuf};
5
6use flate2::read::MultiGzDecoder;
7use rayon::prelude::*;
8use rsomics_common::{Result, RsomicsError};
9
/// Single-cell counts in 10x MatrixMarket layout: genes on rows, cells on
/// columns, coordinate triplets. Counts are f64 because ComBat runs on the
/// log-normalized float matrix scanpy feeds it.
#[derive(Debug, Clone, PartialEq)]
pub struct CountMatrix {
    /// Number of rows (genes) declared by the matrix.
    pub n_genes: usize,
    /// Number of columns (cells) declared by the matrix.
    pub n_cells: usize,
    /// Non-zero entries as zero-based coordinate triplets.
    pub entries: Vec<Entry>,
}

/// One non-zero coordinate triplet of a [`CountMatrix`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Entry {
    /// Zero-based gene (row) index.
    pub gene: u32,
    /// Zero-based cell (column) index.
    pub cell: u32,
    /// Stored value at `(gene, cell)`.
    pub value: f64,
}
25
26pub fn open_mtx(dir: &Path) -> Result<Box<dyn Read>> {
27    for name in ["matrix.mtx.gz", "matrix.mtx"] {
28        let path = dir.join(name);
29        if path.exists() {
30            return open_maybe_gz(&path);
31        }
32    }
33    Err(RsomicsError::InvalidInput(format!(
34        "no matrix.mtx or matrix.mtx.gz in {}",
35        dir.display()
36    )))
37}
38
39fn open_maybe_gz(path: &Path) -> Result<Box<dyn Read>> {
40    let file = File::open(path)
41        .map_err(|e| RsomicsError::InvalidInput(format!("{}: {e}", path.display())))?;
42    if path.extension().is_some_and(|e| e == "gz") {
43        Ok(Box::new(MultiGzDecoder::new(file)))
44    } else {
45        Ok(Box::new(file))
46    }
47}
48
49/// Parse a 10x MatrixMarket coordinate file (real/integer/pattern, general).
50pub fn parse_mtx(reader: impl Read) -> Result<CountMatrix> {
51    let mut reader = BufReader::new(reader);
52    let mut line = String::new();
53
54    reader.read_line(&mut line).map_err(RsomicsError::Io)?;
55    let banner = line.trim();
56    if !banner.starts_with("%%MatrixMarket") {
57        return Err(RsomicsError::InvalidInput(
58            "missing %%MatrixMarket banner".into(),
59        ));
60    }
61    let pattern = banner.contains("pattern");
62
63    let (n_genes, n_cells, nnz) = loop {
64        line.clear();
65        let n = reader.read_line(&mut line).map_err(RsomicsError::Io)?;
66        if n == 0 {
67            return Err(RsomicsError::InvalidInput("truncated MTX header".into()));
68        }
69        let t = line.trim();
70        if t.is_empty() || t.starts_with('%') {
71            continue;
72        }
73        let mut it = t.split_whitespace();
74        let rows = parse_usize(it.next())?;
75        let cols = parse_usize(it.next())?;
76        let nnz = parse_usize(it.next())?;
77        break (rows, cols, nnz);
78    };
79
80    let mut entries = Vec::with_capacity(nnz);
81    for raw in reader.lines() {
82        let raw = raw.map_err(RsomicsError::Io)?;
83        let t = raw.trim();
84        if t.is_empty() {
85            continue;
86        }
87        let mut it = t.split_whitespace();
88        let gene = parse_usize(it.next())?;
89        let cell = parse_usize(it.next())?;
90        let value = if pattern {
91            1.0
92        } else {
93            it.next()
94                .ok_or_else(|| RsomicsError::InvalidInput("MTX entry missing value".into()))?
95                .parse::<f64>()?
96        };
97        if gene == 0 || gene > n_genes || cell == 0 || cell > n_cells {
98            return Err(RsomicsError::InvalidInput(format!(
99                "MTX index out of bounds: ({gene}, {cell})"
100            )));
101        }
102        entries.push(Entry {
103            gene: (gene - 1) as u32,
104            cell: (cell - 1) as u32,
105            value,
106        });
107    }
108    if entries.len() != nnz {
109        return Err(RsomicsError::InvalidInput(format!(
110            "MTX declared {nnz} entries, found {}",
111            entries.len()
112        )));
113    }
114
115    Ok(CountMatrix {
116        n_genes,
117        n_cells,
118        entries,
119    })
120}
121
122/// Densify the sparse matrix to genes × cells in gene-major rows: gene `g`
123/// occupies `dense[g*n_cells .. (g+1)*n_cells]`. ComBat works gene-at-a-time,
124/// so a contiguous per-gene row is the cache-friendly layout.
125fn densify_gene_major(m: &CountMatrix) -> Vec<f64> {
126    let mut dense = vec![0.0_f64; m.n_genes * m.n_cells];
127    let nc = m.n_cells;
128    for e in &m.entries {
129        dense[e.gene as usize * nc + e.cell as usize] = e.value;
130    }
131    dense
132}
133
134/// Read a barcode → batch-label TSV. Two columns (barcode, label); a header
135/// line `barcode<TAB>colname` is honored when `key` names that second column,
136/// otherwise the first data row sets the schema. Returns the per-cell batch
137/// index aligned to `barcodes`, plus the ordered distinct labels.
138pub fn read_batch_labels(
139    path: &Path,
140    barcodes: &[String],
141    key: Option<&str>,
142) -> Result<(Vec<usize>, Vec<String>)> {
143    let f = File::open(path)
144        .map_err(|e| RsomicsError::InvalidInput(format!("{}: {e}", path.display())))?;
145    let reader = BufReader::new(f);
146    let mut lines = Vec::new();
147    for raw in reader.lines() {
148        let raw = raw.map_err(RsomicsError::Io)?;
149        let t = raw.trim_end_matches(['\n', '\r']).to_string();
150        if !t.is_empty() {
151            lines.push(t);
152        }
153    }
154    if lines.is_empty() {
155        return Err(RsomicsError::InvalidInput("empty batch TSV".into()));
156    }
157
158    let first: Vec<&str> = lines[0].split('\t').collect();
159    let has_header = is_header(&first, key);
160    let label_col = match (key, has_header) {
161        (Some(k), true) => first.iter().position(|c| *c == k).ok_or_else(|| {
162            RsomicsError::InvalidInput(format!("key {k:?} not in batch TSV header"))
163        })?,
164        (Some(k), false) => {
165            return Err(RsomicsError::InvalidInput(format!(
166                "--key {k:?} given but batch TSV has no header row"
167            )));
168        }
169        (None, _) => 1,
170    };
171
172    let mut by_barcode: BTreeMap<String, String> = BTreeMap::new();
173    for line in lines.iter().skip(usize::from(has_header)) {
174        let cols: Vec<&str> = line.split('\t').collect();
175        if label_col >= cols.len() {
176            return Err(RsomicsError::InvalidInput(
177                "batch TSV row shorter than the selected key column".into(),
178            ));
179        }
180        by_barcode.insert(cols[0].to_string(), cols[label_col].to_string());
181    }
182
183    // scanpy sanitizes the obs column to a pandas Categorical, whose default
184    // categories are the sorted unique labels; groupby then iterates them in
185    // that order. The numeric result is order-invariant, but matching the
186    // ordering keeps the reported levels identical.
187    let mut distinct: Vec<String> = by_barcode.values().cloned().collect();
188    distinct.sort();
189    distinct.dedup();
190    let level_of: BTreeMap<&str, usize> = distinct
191        .iter()
192        .enumerate()
193        .map(|(i, s)| (s.as_str(), i))
194        .collect();
195
196    let mut batch_of_cell = Vec::with_capacity(barcodes.len());
197    for bc in barcodes {
198        let label = by_barcode.get(bc).ok_or_else(|| {
199            RsomicsError::InvalidInput(format!("barcode {bc:?} missing from batch TSV"))
200        })?;
201        batch_of_cell.push(level_of[label.as_str()]);
202    }
203
204    if distinct.len() < 2 {
205        return Err(RsomicsError::InvalidInput(
206            "ComBat needs at least 2 batches".into(),
207        ));
208    }
209    Ok((batch_of_cell, distinct))
210}
211
/// Decide whether the first TSV row is a header.
///
/// With a `key`, the row is a header exactly when one column equals the key.
/// Without one, it is a header when any column is `barcode`, compared
/// ASCII case-insensitively.
fn is_header(cols: &[&str], key: Option<&str>) -> bool {
    match key {
        Some(wanted) => cols.iter().any(|col| *col == wanted),
        None => cols
            .iter()
            .any(|col| col.eq_ignore_ascii_case("barcode")),
    }
}
218
219/// Read the 10x barcodes file (`barcodes.tsv[.gz]`), one per cell.
220pub fn read_barcodes(dir: &Path) -> Result<Vec<String>> {
221    for name in ["barcodes.tsv.gz", "barcodes.tsv"] {
222        let path = dir.join(name);
223        if path.exists() {
224            let r = open_maybe_gz(&path)?;
225            let reader = BufReader::new(r);
226            let mut out = Vec::new();
227            for raw in reader.lines() {
228                let raw = raw.map_err(RsomicsError::Io)?;
229                let t = raw.trim();
230                if !t.is_empty() {
231                    out.push(t.split('\t').next().unwrap().to_string());
232                }
233            }
234            return Ok(out);
235        }
236    }
237    Err(RsomicsError::InvalidInput(format!(
238        "no barcodes.tsv in {}",
239        dir.display()
240    )))
241}
242
/// Convergence threshold of the empirical-Bayes iteration in `it_sol`: the
/// loop keeps going only while the batch-wide relative change is strictly
/// greater than this value.
const CONV: f64 = 1e-4;
244
245/// Parametric ComBat (Johnson, Li & Rabinovic 2007) with batch as the only
246/// model term. Operates in place on a gene-major dense matrix and overwrites
247/// it with the corrected values, mirroring scanpy's `_combat`.
248pub fn combat(dense: &mut [f64], n_genes: usize, n_cells: usize, batch_of_cell: &[usize]) {
249    let n_batch = batch_of_cell.iter().copied().max().unwrap() + 1;
250    let mut batch_cells: Vec<Vec<usize>> = vec![Vec::new(); n_batch];
251    for (cell, &b) in batch_of_cell.iter().enumerate() {
252        batch_cells[b].push(cell);
253    }
254    let n_b: Vec<f64> = batch_cells.iter().map(|c| c.len() as f64).collect();
255    let n_array = n_cells as f64;
256
257    // var_pooled (population, ddof=0) and grand mean per gene, plus the
258    // standardized matrix overwriting `dense`.
259    let mut var_pooled = vec![0.0_f64; n_genes];
260    let mut stand_mean = vec![0.0_f64; n_genes];
261    let nc = n_cells;
262    dense
263        .par_chunks_mut(nc)
264        .zip(var_pooled.par_iter_mut())
265        .zip(stand_mean.par_iter_mut())
266        .for_each(|((row, vp), sm)| {
267            let mut bmean = vec![0.0_f64; n_batch];
268            for (b, cells) in batch_cells.iter().enumerate() {
269                let mut s = 0.0;
270                for &c in cells {
271                    s += row[c];
272                }
273                bmean[b] = s / n_b[b];
274            }
275            let grand: f64 = (0..n_batch).map(|b| n_b[b] / n_array * bmean[b]).sum();
276            let mut ss = 0.0;
277            for (b, cells) in batch_cells.iter().enumerate() {
278                for &c in cells {
279                    let d = row[c] - bmean[b];
280                    ss += d * d;
281                }
282            }
283            let vp_g = ss / n_array;
284            *vp = vp_g;
285            *sm = grand;
286            let denom = vp_g.sqrt();
287            if vp_g == 0.0 {
288                for v in row.iter_mut() {
289                    *v = 0.0;
290                }
291            } else {
292                for v in row.iter_mut() {
293                    *v = (*v - grand) / denom;
294                }
295            }
296        });
297
298    // gamma_hat[batch][gene] = batch mean of standardized data;
299    // delta_hat[batch][gene] = batch sample variance (ddof=1).
300    let mut gamma_hat = vec![vec![0.0_f64; n_genes]; n_batch];
301    let mut delta_hat = vec![vec![0.0_f64; n_genes]; n_batch];
302    for b in 0..n_batch {
303        let cells = &batch_cells[b];
304        let nb = cells.len() as f64;
305        let gh = &mut gamma_hat[b];
306        let dh = &mut delta_hat[b];
307        dense
308            .par_chunks(nc)
309            .zip(gh.par_iter_mut())
310            .zip(dh.par_iter_mut())
311            .for_each(|((row, g), d)| {
312                let mut s = 0.0;
313                for &c in cells {
314                    s += row[c];
315                }
316                let mean = s / nb;
317                *g = mean;
318                let mut ss = 0.0;
319                for &c in cells {
320                    let e = row[c] - mean;
321                    ss += e * e;
322                }
323                *d = if nb > 1.0 { ss / (nb - 1.0) } else { 0.0 };
324            });
325    }
326
327    // EB hyperparameters per batch. gamma_bar/t2 are numpy moments over genes
328    // (t2 ddof=0); a_prior/b_prior come from delta_hat moments (ddof=1).
329    let mut gamma_star = vec![vec![0.0_f64; n_genes]; n_batch];
330    let mut delta_star = vec![vec![0.0_f64; n_genes]; n_batch];
331    for b in 0..n_batch {
332        let gh = &gamma_hat[b];
333        let dh = &delta_hat[b];
334        let gamma_bar = mean(gh);
335        let t2 = var_ddof(gh, 0);
336        let a_prior = aprior(dh);
337        let b_prior = bprior(dh);
338        let cells = &batch_cells[b];
339
340        let std_rows: Vec<&[f64]> = (0..n_genes).map(|g| &dense[g * nc..g * nc + nc]).collect();
341        it_sol(
342            &std_rows,
343            cells,
344            gh,
345            dh,
346            gamma_bar,
347            t2,
348            a_prior,
349            b_prior,
350            &mut gamma_star[b],
351            &mut delta_star[b],
352        );
353    }
354
355    // De-standardize: subtract the additive effect, divide by sqrt(delta*),
356    // rescale by sqrt(var_pooled) and add back the gene-wise mean.
357    dense.par_chunks_mut(nc).enumerate().for_each(|(g, row)| {
358        let vpsq = var_pooled[g].sqrt();
359        let sm = stand_mean[g];
360        for b in 0..n_batch {
361            let dsq = delta_star[b][g].sqrt();
362            let gs = gamma_star[b][g];
363            for &c in &batch_cells[b] {
364                row[c] = (row[c] - gs) / dsq * vpsq + sm;
365            }
366        }
367    });
368}
369
/// Iterative EB posterior for γ and δ, a faithful port of scanpy's vectorized
/// `_it_sol`. All genes step together off the previous pass's δ; convergence
/// is one batch-wide scalar `change`. NaN matters: a zero-variance gene yields
/// `0/0` in the γ relative-change, numpy's reduction propagates it, and
/// `while change > conv` then halts after a single pass — so we reproduce the
/// numpy-max (NaN-propagating) and Python-max (left-biased) reductions exactly.
///
/// * `std_rows` — one slice per gene, indexed by cell.
/// * `cells` — cell indices belonging to the batch being solved.
/// * `g_hat` / `d_hat` — per-gene starting estimates of γ and δ.
/// * `g_bar`, `t2`, `a`, `b` — scalar prior hyperparameters.
/// * `g_out` / `d_out` — per-gene results; both are overwritten and must be
///   as long as `g_hat` / `d_hat` (`copy_from_slice` panics otherwise).
///
/// NOTE(review): both relative changes divide by the absolute value of the
/// previous estimate; confirm the upstream implementation does the same, as a
/// signed denominator would behave differently for negative γ.
#[allow(clippy::too_many_arguments)]
fn it_sol(
    std_rows: &[&[f64]],
    cells: &[usize],
    g_hat: &[f64],
    d_hat: &[f64],
    g_bar: f64,
    t2: f64,
    a: f64,
    b: f64,
    g_out: &mut [f64],
    d_out: &mut [f64],
) {
    let n = cells.len() as f64;
    let n_genes = g_hat.len();
    // `g_out` / `d_out` double as the "previous pass" state during iteration.
    g_out.copy_from_slice(g_hat);
    d_out.copy_from_slice(d_hat);
    let mut g_new = vec![0.0_f64; n_genes];
    let mut d_new = vec![0.0_f64; n_genes];

    loop {
        // Largest relative change seen over all genes in this pass.
        let mut g_change = f64::NEG_INFINITY;
        let mut d_change = f64::NEG_INFINITY;
        for i in 0..n_genes {
            // γ update reads the previous pass's δ (`d_out`), not `d_new`.
            let gn = (t2 * n * g_hat[i] + d_out[i] * g_bar) / (t2 * n + d_out[i]);
            let row = std_rows[i];
            // Sum of squared residuals of this batch's cells around the new γ.
            let mut sum2 = 0.0;
            for &c in cells {
                let e = row[c] - gn;
                sum2 += e * e;
            }
            let dn = (0.5 * sum2 + b) / (n / 2.0 + a - 1.0);
            g_change = numpy_max(g_change, (gn - g_out[i]).abs() / g_out[i].abs());
            d_change = numpy_max(d_change, (dn - d_out[i]).abs() / d_out[i].abs());
            g_new[i] = gn;
            d_new[i] = dn;
        }
        // Publish the pass only after every gene has been stepped, so all
        // genes advance from the same previous state.
        g_out.copy_from_slice(&g_new);
        d_out.copy_from_slice(&d_new);
        let change = python_max(g_change, d_change);
        // `change > CONV` (not `<= CONV`): a NaN change must stop the loop,
        // mirroring numpy's `while change > conv`.
        #[allow(clippy::neg_cmp_op_on_partial_ord)]
        if !(change > CONV) {
            break;
        }
    }
}
424
/// One step of a numpy-style max reduction: the result is NaN as soon as
/// either the accumulator or the new value is NaN, otherwise the larger one.
fn numpy_max(acc: f64, x: f64) -> f64 {
    match (acc.is_nan(), x.is_nan()) {
        (false, false) => acc.max(x),
        _ => f64::NAN,
    }
}
433
/// Two-argument max with CPython's semantics: `a` is kept unless `b` compares
/// strictly greater. Since every comparison against NaN is false, a NaN `b` is
/// ignored while a NaN `a` is returned unchanged.
fn python_max(a: f64, b: f64) -> f64 {
    match b.partial_cmp(&a) {
        Some(std::cmp::Ordering::Greater) => b,
        _ => a,
    }
}
439
/// Arithmetic mean of `x`; NaN for an empty slice (`0/0`).
fn mean(x: &[f64]) -> f64 {
    let total: f64 = x.iter().sum();
    let count = x.len() as f64;
    total / count
}
443
444fn var_ddof(x: &[f64], ddof: usize) -> f64 {
445    let n = x.len() as f64;
446    let m = mean(x);
447    let ss: f64 = x.iter().map(|&v| (v - m) * (v - m)).sum();
448    ss / (n - ddof as f64)
449}
450
451fn aprior(delta_hat: &[f64]) -> f64 {
452    let m = mean(delta_hat);
453    let s2 = var_ddof(delta_hat, 1);
454    (2.0 * s2 + m * m) / s2
455}
456
457fn bprior(delta_hat: &[f64]) -> f64 {
458    let m = mean(delta_hat);
459    let s2 = var_ddof(delta_hat, 1);
460    (m * s2 + m * m * m) / s2
461}
462
463/// Write the dense matrix in MatrixMarket `array real general` layout, one
464/// value per line in column-major order, from a gene-major buffer.
465pub fn write_dense_gene_major(
466    n_genes: usize,
467    n_cells: usize,
468    dense: &[f64],
469    out: impl Write,
470) -> Result<()> {
471    let mut w = BufWriter::with_capacity(1 << 20, out);
472    w.write_all(b"%%MatrixMarket matrix array real general\n")
473        .map_err(RsomicsError::Io)?;
474    let mut header = format!("{n_genes} {n_cells}");
475    header.push('\n');
476    w.write_all(header.as_bytes()).map_err(RsomicsError::Io)?;
477
478    let mut fmt = ryu::Buffer::new();
479    let mut buf: Vec<u8> = Vec::with_capacity(1 << 16);
480    for cell in 0..n_cells {
481        for gene in 0..n_genes {
482            buf.extend_from_slice(fmt.format(dense[gene * n_cells + cell]).as_bytes());
483            buf.push(b'\n');
484            if buf.len() >= 1 << 15 {
485                w.write_all(&buf).map_err(RsomicsError::Io)?;
486                buf.clear();
487            }
488        }
489    }
490    w.write_all(&buf).map_err(RsomicsError::Io)?;
491    w.flush().map_err(RsomicsError::Io)?;
492    Ok(())
493}
494
495fn parse_usize(tok: Option<&str>) -> Result<usize> {
496    tok.ok_or_else(|| RsomicsError::InvalidInput("MTX header missing a dimension".into()))?
497        .parse::<usize>()
498        .map_err(Into::into)
499}
500
501pub fn open_output(path: &str) -> Result<Box<dyn Write>> {
502    if path == "-" {
503        Ok(Box::new(std::io::stdout().lock()))
504    } else {
505        Ok(Box::new(
506            File::create(PathBuf::from(path)).map_err(RsomicsError::Io)?,
507        ))
508    }
509}
510
511/// End-to-end: read the 10x matrix and barcodes from `dir`, read batch labels,
512/// run ComBat, write the corrected dense matrix.
513pub fn run(
514    dir: &Path,
515    batch_tsv: &Path,
516    key: Option<&str>,
517    out: impl Write,
518) -> Result<(usize, usize, usize)> {
519    let m = parse_mtx(open_mtx(dir)?)?;
520    let barcodes = read_barcodes(dir)?;
521    if barcodes.len() != m.n_cells {
522        return Err(RsomicsError::InvalidInput(format!(
523            "{} barcodes but matrix has {} cells",
524            barcodes.len(),
525            m.n_cells
526        )));
527    }
528    let (batch_of_cell, levels) = read_batch_labels(batch_tsv, &barcodes, key)?;
529
530    let mut dense = densify_gene_major(&m);
531    combat(&mut dense, m.n_genes, m.n_cells, &batch_of_cell);
532    write_dense_gene_major(m.n_genes, m.n_cells, &dense, out)?;
533    Ok((m.n_genes, m.n_cells, levels.len()))
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// Batch assignment shared by the fixtures: three cells in each batch.
    const BATCHES: [usize; 6] = [0, 0, 0, 1, 1, 1];

    /// Three genes × six cells, flattened gene-major.
    fn two_batch() -> (Vec<f64>, usize, usize, Vec<usize>) {
        let rows: [[f64; 6]; 3] = [
            [1.0, 2.0, 1.5, 4.0, 5.0, 4.5],
            [2.0, 2.5, 3.0, 1.0, 0.5, 1.2],
            [0.5, 0.7, 0.6, 0.55, 0.62, 0.58],
        ];
        let dense: Vec<f64> = rows.iter().flatten().copied().collect();
        (dense, rows.len(), rows[0].len(), BATCHES.to_vec())
    }

    #[test]
    fn corrected_means_converge_across_batches() {
        let (mut dense, ng, nc, batch) = two_batch();
        combat(&mut dense, ng, nc, &batch);
        // Gene 0 starts with batch means 1.5 and 4.5; the correction has to
        // bring them much closer together than that.
        let m0a = dense[..3].iter().sum::<f64>() / 3.0;
        let m0b = dense[3..6].iter().sum::<f64>() / 3.0;
        assert!(
            (m0a - m0b).abs() < 1.0,
            "batch means not pulled together: {m0a} vs {m0b}"
        );
    }

    #[test]
    fn zero_variance_gene_collapses_to_grand_mean() {
        // Gene 0 varies; gene 1 is the constant 2.0 in every cell.
        let varying = [1.0, 2.0, 1.5, 4.0, 5.0, 4.5];
        let constant = [2.0_f64; 6];
        let mut dense: Vec<f64> = varying.iter().chain(constant.iter()).copied().collect();
        combat(&mut dense, 2, 6, &BATCHES);

        for &v in &dense {
            assert!(v.is_finite(), "non-finite ComBat output: {v}");
        }
        // With sqrt(var_pooled) = 0 the rescaled term vanishes, leaving every
        // cell at the gene's grand mean whatever the EB estimates are.
        for &v in &dense[6..] {
            assert!(
                (v - 2.0).abs() < 1e-12,
                "zero-var gene not at grand mean: {v}"
            );
        }
    }

    #[test]
    fn priors_match_numpy_moments() {
        let d = [1.0, 2.0, 3.0, 4.0];
        assert!((mean(&d) - 2.5).abs() < 1e-12);
        // Sample variance (ddof=1) of [1,2,3,4] is 5/3 ...
        let sample = var_ddof(&d, 1);
        assert!((sample - 5.0 / 3.0).abs() < 1e-12);
        // ... and the population variance (ddof=0) is 1.25.
        let population = var_ddof(&d, 0);
        assert!((population - 1.25).abs() < 1e-12);
    }
}