use std::fs::File;
use std::io::{BufRead, BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
use flate2::read::MultiGzDecoder;
use rayon::prelude::*;
use rsomics_common::{Result, RsomicsError};
/// A sparse matrix in coordinate (triplet) form, as produced by [`parse_mtx`].
///
/// Rows of the source file are genes and columns are cells. Only explicitly
/// stored values live in `entries`; every other position is an implicit zero.
#[derive(Debug, Clone)]
pub struct CountMatrix {
    /// Number of genes (matrix rows).
    pub n_genes: usize,
    /// Number of cells (matrix columns).
    pub n_cells: usize,
    /// Explicitly stored values, with 0-based indices.
    pub entries: Vec<Entry>,
}

/// One explicitly stored value of a [`CountMatrix`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Entry {
    /// 0-based gene (row) index.
    pub gene: u32,
    /// 0-based cell (column) index.
    pub cell: u32,
    /// Stored value; [`parse_mtx`] uses 1.0 for `pattern` matrices.
    pub value: f64,
}
/// Options controlling [`scale_dense`].
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct ScaleParams {
    /// When set, every scaled value is clipped to `[-max_value, max_value]`.
    /// [`parse_max_value`] only yields positive finite values; a negative or
    /// NaN value set directly makes `f64::clamp` panic inside [`scale_dense`].
    pub max_value: Option<f64>,
}
/// Opens the MatrixMarket matrix stored in `dir`.
///
/// `matrix.mtx.gz` is checked first and wins over `matrix.mtx` when both are
/// present; the gzipped file is decompressed on the fly.
///
/// # Errors
/// Returns `InvalidInput` when neither file exists in `dir`, or when the file
/// that was found cannot be opened.
pub fn open_mtx(dir: &Path) -> Result<Box<dyn Read>> {
    let found = ["matrix.mtx.gz", "matrix.mtx"]
        .iter()
        .map(|file_name| dir.join(file_name))
        .find(|candidate| candidate.exists());
    match found {
        Some(path) => open_maybe_gz(&path),
        None => Err(RsomicsError::InvalidInput(format!(
            "no matrix.mtx or matrix.mtx.gz in {}",
            dir.display()
        ))),
    }
}
/// Opens `path` for reading. If its extension is `gz`, the file is wrapped in
/// a multi-member gzip decoder.
///
/// # Errors
/// Returns `InvalidInput`, prefixed with the path, when the file cannot be
/// opened.
fn open_maybe_gz(path: &Path) -> Result<Box<dyn Read>> {
    let handle = match File::open(path) {
        Ok(handle) => handle,
        Err(err) => {
            return Err(RsomicsError::InvalidInput(format!(
                "{}: {err}",
                path.display()
            )))
        }
    };
    let gzipped = matches!(path.extension(), Some(ext) if ext == "gz");
    let reader: Box<dyn Read> = if gzipped {
        Box::new(MultiGzDecoder::new(handle))
    } else {
        Box::new(handle)
    };
    Ok(reader)
}
pub fn parse_mtx(reader: impl Read) -> Result<CountMatrix> {
let mut reader = BufReader::new(reader);
let mut line = String::new();
reader.read_line(&mut line).map_err(RsomicsError::Io)?;
let banner = line.trim();
if !banner.starts_with("%%MatrixMarket") {
return Err(RsomicsError::InvalidInput(
"missing %%MatrixMarket banner".into(),
));
}
let pattern = banner.contains("pattern");
let (n_genes, n_cells, nnz) = loop {
line.clear();
let n = reader.read_line(&mut line).map_err(RsomicsError::Io)?;
if n == 0 {
return Err(RsomicsError::InvalidInput("truncated MTX header".into()));
}
let t = line.trim();
if t.is_empty() || t.starts_with('%') {
continue;
}
let mut it = t.split_whitespace();
let rows = parse_usize(it.next())?;
let cols = parse_usize(it.next())?;
let nnz = parse_usize(it.next())?;
break (rows, cols, nnz);
};
let mut entries = Vec::with_capacity(nnz);
for raw in reader.lines() {
let raw = raw.map_err(RsomicsError::Io)?;
let t = raw.trim();
if t.is_empty() {
continue;
}
let mut it = t.split_whitespace();
let gene = parse_usize(it.next())?;
let cell = parse_usize(it.next())?;
let value = if pattern {
1.0
} else {
it.next()
.ok_or_else(|| RsomicsError::InvalidInput("MTX entry missing value".into()))?
.parse::<f64>()?
};
if gene == 0 || gene > n_genes || cell == 0 || cell > n_cells {
return Err(RsomicsError::InvalidInput(format!(
"MTX index out of bounds: ({gene}, {cell})"
)));
}
entries.push(Entry {
gene: (gene - 1) as u32,
cell: (cell - 1) as u32,
value,
});
}
if entries.len() != nnz {
return Err(RsomicsError::InvalidInput(format!(
"MTX declared {nnz} entries, found {}",
entries.len()
)));
}
Ok(CountMatrix {
n_genes,
n_cells,
entries,
})
}
/// Per-gene summary statistics as computed by [`gene_stats`].
#[derive(Debug, Clone, PartialEq)]
pub struct GeneStats {
    /// Mean of each gene over all cells, implicit zeros included.
    pub mean: Vec<f64>,
    /// Standard deviation of each gene over all cells. It uses the `n - 1`
    /// denominator when there are at least two cells. A deviation of exactly
    /// zero is stored as 1.0, so dividing by it is always safe.
    pub std: Vec<f64>,
}
/// Computes each gene's mean and standard deviation over all `n_cells` cells.
/// Every position missing from `m.entries` counts as an implicit zero.
///
/// The per-gene sums `Σx` and `Σx²` give the population variance
/// `Σx²/n - mean²`. That variance is rescaled by `n / (n - 1)` (ddof = 1)
/// when there are at least two cells. Negative variances caused by rounding
/// are clamped to zero. A gene whose standard deviation is exactly zero
/// reports 1.0 instead, so callers may divide by it unconditionally.
pub fn gene_stats(m: &CountMatrix) -> GeneStats {
    let cells = m.n_cells as f64;

    // Per-gene (Σx, Σx²), accumulated in entry order.
    let mut acc = vec![(0.0_f64, 0.0_f64); m.n_genes];
    m.entries.iter().for_each(|entry| {
        let slot = &mut acc[entry.gene as usize];
        slot.0 += entry.value;
        slot.1 += entry.value * entry.value;
    });

    let ddof_correction = match m.n_cells {
        0 | 1 => 1.0,
        _ => cells / (cells - 1.0),
    };

    let (mean, std): (Vec<f64>, Vec<f64>) = acc
        .into_iter()
        .map(|(total, total_sq)| {
            let mu = total / cells;
            let variance = (total_sq / cells - mu * mu) * ddof_correction;
            let sd = variance.max(0.0).sqrt();
            (mu, if sd == 0.0 { 1.0 } else { sd })
        })
        .unzip();

    GeneStats { mean, std }
}
/// Z-scores every gene of `m` into a dense matrix. Returns the statistics
/// that were used together with that matrix.
///
/// Each stored value `x` becomes `(x - mean) / std`, using [`gene_stats`].
/// Positions absent from `m.entries` are implicit zeros and become
/// `-mean / std`. The output is laid out cell by cell: gene `i` of cell `c`
/// is at index `c * n_genes + i`. If `params.max_value` is `Some(mx)`, every
/// value is then clipped to `[-mx, mx]`.
///
/// A matrix with zero genes or zero cells yields an empty dense vector.
///
/// NOTE(review): if a (gene, cell) pair appears more than once in
/// `m.entries`, the last occurrence wins here, while `gene_stats` counts
/// every occurrence. Confirm that inputs are duplicate-free.
///
/// # Panics
/// * If an entry's gene or cell index lies outside `n_genes` x `n_cells`
///   (out-of-range indexing).
/// * If the output is non-empty and `params.max_value` is negative or NaN
///   (`f64::clamp`).
pub fn scale_dense(m: &CountMatrix, params: &ScaleParams) -> (GeneStats, Vec<f64>) {
    let stats = gene_stats(m);
    let g = m.n_genes;
    let mut dense = vec![0.0_f64; g * m.n_cells];
    // rayon's `par_chunks_mut` panics on a chunk size of zero. With no genes
    // the dense output is empty anyway, so the fill is skipped entirely.
    if g > 0 {
        // Value that every implicit zero takes after scaling: (0 - mean) / std.
        let baseline: Vec<f64> = (0..g).map(|i| -stats.mean[i] / stats.std[i]).collect();
        // Each chunk of `g` values is one cell.
        dense
            .par_chunks_mut(g)
            .for_each(|col| col.copy_from_slice(&baseline));
    }
    // Overwrite the explicitly stored positions with their own z-scores.
    for e in &m.entries {
        let i = e.gene as usize;
        dense[e.cell as usize * g + i] = (e.value - stats.mean[i]) / stats.std[i];
    }
    if let Some(mx) = params.max_value {
        dense.par_iter_mut().for_each(|v| *v = v.clamp(-mx, mx));
    }
    (stats, dense)
}
/// Writes `dense` as a MatrixMarket `array real general` matrix with
/// `n_genes` rows and `n_cells` columns, one value per line.
///
/// `dense` must already be in the column-major order that the array format
/// expects: all genes of cell 0, then all genes of cell 1, and so on. This is
/// the layout produced by [`scale_dense`]. Values are formatted with `ryu`.
///
/// # Errors
/// * `InvalidInput` if `dense.len()` is not `n_genes * n_cells` (or the
///   product overflows). Writing anyway would produce a file whose header
///   disagrees with its body.
/// * `Io` if writing to or flushing `out` fails.
pub fn write_dense(n_genes: usize, n_cells: usize, dense: &[f64], out: impl Write) -> Result<()> {
    if n_genes.checked_mul(n_cells) != Some(dense.len()) {
        return Err(RsomicsError::InvalidInput(format!(
            "dense buffer has {} values, expected {n_genes} x {n_cells}",
            dense.len()
        )));
    }
    let mut w = BufWriter::with_capacity(1 << 20, out);
    w.write_all(b"%%MatrixMarket matrix array real general\n")
        .map_err(RsomicsError::Io)?;
    writeln!(w, "{n_genes} {n_cells}").map_err(RsomicsError::Io)?;
    // BufWriter already batches these small writes, so no separate staging
    // buffer is needed.
    let mut fmt = ryu::Buffer::new();
    for &v in dense {
        w.write_all(fmt.format(v).as_bytes())
            .map_err(RsomicsError::Io)?;
        w.write_all(b"\n").map_err(RsomicsError::Io)?;
    }
    // Flush explicitly: dropping a BufWriter would swallow a final write error.
    w.flush().map_err(RsomicsError::Io)?;
    Ok(())
}
/// Parses one whitespace-separated MTX token as a `usize`.
///
/// It is shared by the `rows cols nnz` size line and the indices of entry
/// lines, so the missing-token message names neither one specifically.
///
/// # Errors
/// Returns `InvalidInput` if the token is absent. Otherwise returns the
/// integer-parse error converted into `RsomicsError`.
fn parse_usize(tok: Option<&str>) -> Result<usize> {
    tok.ok_or_else(|| RsomicsError::InvalidInput("MTX line is missing an integer field".into()))?
        .parse::<usize>()
        .map_err(Into::into)
}
/// Runs the whole pipeline:
/// 1. read the matrix found in `dir`;
/// 2. z-score it with `params`;
/// 3. write the dense result to `out`.
///
/// Returns the matrix shape as `(n_genes, n_cells)`.
///
/// # Errors
/// Propagates any error from opening, parsing, or writing the matrix.
pub fn run(dir: &Path, params: &ScaleParams, out: impl Write) -> Result<(usize, usize)> {
    let source = open_mtx(dir)?;
    let matrix = parse_mtx(source)?;
    let (n_genes, n_cells) = (matrix.n_genes, matrix.n_cells);
    let (_, scaled) = scale_dense(&matrix, params);
    write_dense(n_genes, n_cells, &scaled, out)?;
    Ok((n_genes, n_cells))
}
/// Parses the optional `--max-value` argument.
///
/// `None` means no clipping. A provided string must parse as an `f64` that is
/// strictly positive and finite.
///
/// # Errors
/// Returns `InvalidInput` if the string is not a number, or if the number is
/// zero, negative, infinite, or NaN.
pub fn parse_max_value(s: Option<&str>) -> Result<Option<f64>> {
    match s {
        None => Ok(None),
        Some(raw) => {
            let parsed = raw
                .parse::<f64>()
                .map_err(|_| RsomicsError::InvalidInput(format!("invalid --max-value '{raw}'")))?;
            if parsed.is_finite() && parsed > 0.0 {
                Ok(Some(parsed))
            } else {
                Err(RsomicsError::InvalidInput(
                    "--max-value must be a positive finite number".into(),
                ))
            }
        }
    }
}
/// Opens the output sink. `"-"` selects a locked standard-output handle; any
/// other value creates (or truncates) a file at `path`.
///
/// No extra buffering is added here. [`write_dense`] wraps whatever it is
/// given in its own `BufWriter`.
///
/// # Errors
/// Returns `Io` if the file cannot be created.
pub fn open_output(path: &str) -> Result<Box<dyn Write>> {
    let sink: Box<dyn Write> = match path {
        "-" => Box::new(std::io::stdout().lock()),
        file => Box::new(File::create(PathBuf::from(file)).map_err(RsomicsError::Io)?),
    };
    Ok(sink)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 3-gene x 4-cell matrix. Cell 3 has no stored entries at all, so it
    /// is made of implicit zeros only.
    fn tiny() -> CountMatrix {
        let mut entries = Vec::new();
        let mut push = |g: u32, c: u32, val: f64| {
            entries.push(Entry {
                gene: g,
                cell: c,
                value: val,
            })
        };
        push(0, 0, 3.0);
        push(2, 0, 1.0);
        push(1, 1, 5.0);
        push(0, 2, 1.0);
        push(1, 2, 1.0);
        push(2, 2, 1.0);
        CountMatrix {
            n_genes: 3,
            n_cells: 4,
            entries,
        }
    }

    #[test]
    fn stats_ddof1_over_all_cells() {
        let s = gene_stats(&tiny());
        assert!((s.mean[0] - 1.0).abs() < 1e-12);
        assert!((s.std[0] - 2.0_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn zero_variance_gene_keeps_std_one() {
        let m = CountMatrix {
            n_genes: 1,
            n_cells: 3,
            entries: vec![
                Entry {
                    gene: 0,
                    cell: 0,
                    value: 2.0,
                },
                Entry {
                    gene: 0,
                    cell: 1,
                    value: 2.0,
                },
                Entry {
                    gene: 0,
                    cell: 2,
                    value: 2.0,
                },
            ],
        };
        let s = gene_stats(&m);
        assert_eq!(s.std[0], 1.0);
        let (_s, dense) = scale_dense(&m, &ScaleParams { max_value: None });
        for v in dense {
            assert!(v.abs() < 1e-12);
        }
    }

    #[test]
    fn densifies_implicit_zeros() {
        let (_s, dense) = scale_dense(&tiny(), &ScaleParams { max_value: None });
        assert_eq!(dense.len(), 3 * 4);
        // Index 3 is gene 0 of cell 1, an implicit zero: -mean / std.
        let v = dense[3];
        assert!((v - (-1.0 / 2.0_f64.sqrt())).abs() < 1e-12);
    }

    #[test]
    fn symmetric_clip() {
        let (_s, dense) = scale_dense(
            &tiny(),
            &ScaleParams {
                max_value: Some(0.5),
            },
        );
        for v in dense {
            assert!((-0.5 - 1e-12..=0.5 + 1e-12).contains(&v));
        }
    }

    #[test]
    fn max_value_parsing() {
        assert_eq!(parse_max_value(None).unwrap(), None);
        assert_eq!(parse_max_value(Some("10")).unwrap(), Some(10.0));
        assert!(parse_max_value(Some("-1")).is_err());
        assert!(parse_max_value(Some("abc")).is_err());
    }

    #[test]
    fn max_value_rejects_zero_and_non_finite() {
        assert!(parse_max_value(Some("0")).is_err());
        assert!(parse_max_value(Some("inf")).is_err());
        assert!(parse_max_value(Some("NaN")).is_err());
    }

    #[test]
    fn parses_coordinate_mtx_with_comments() {
        let src = "%%MatrixMarket matrix coordinate integer general\n\
                   % a comment\n\
                   3 2 2\n\
                   1 1 5\n\
                   3 2 7\n";
        let m = parse_mtx(src.as_bytes()).unwrap();
        assert_eq!((m.n_genes, m.n_cells), (3, 2));
        assert_eq!(m.entries.len(), 2);
        let first = m.entries[0];
        assert_eq!((first.gene, first.cell, first.value), (0, 0, 5.0));
        let second = m.entries[1];
        assert_eq!((second.gene, second.cell, second.value), (2, 1, 7.0));
    }

    #[test]
    fn pattern_entries_default_to_one() {
        let src = "%%MatrixMarket matrix coordinate pattern general\n2 2 1\n2 1\n";
        let m = parse_mtx(src.as_bytes()).unwrap();
        let e = m.entries[0];
        assert_eq!((e.gene, e.cell, e.value), (1, 0, 1.0));
    }

    #[test]
    fn rejects_malformed_mtx() {
        let banner = "%%MatrixMarket matrix coordinate real general\n";
        // No banner at all.
        assert!(parse_mtx("2 2 0\n".as_bytes()).is_err());
        // Row index outside the declared 2 x 2 shape.
        let out_of_bounds = format!("{banner}2 2 1\n3 1 1.0\n");
        assert!(parse_mtx(out_of_bounds.as_bytes()).is_err());
        // Fewer entries than declared.
        let short = format!("{banner}2 2 2\n1 1 1.0\n");
        assert!(parse_mtx(short.as_bytes()).is_err());
        // Input ends before the size line.
        let truncated = format!("{banner}% only a comment\n");
        assert!(parse_mtx(truncated.as_bytes()).is_err());
    }

    #[test]
    fn writes_column_major_array() {
        let mut out: Vec<u8> = Vec::new();
        write_dense(2, 1, &[1.0, -0.5], &mut out).unwrap();
        assert_eq!(
            String::from_utf8(out).unwrap(),
            "%%MatrixMarket matrix array real general\n2 1\n1.0\n-0.5\n"
        );
    }
}