Skip to main content

rsomics_sc_scale/
lib.rs

1use std::fs::File;
2use std::io::{BufRead, BufReader, BufWriter, Read, Write};
3use std::path::{Path, PathBuf};
4
5use flate2::read::MultiGzDecoder;
6use rayon::prelude::*;
7use rsomics_common::{Result, RsomicsError};
8
/// A single-cell count matrix in 10x MatrixMarket layout: rows are genes,
/// columns are cells, stored as coordinate triplets. Counts are held as f64
/// because scanpy promotes the integer matrix to float before scaling.
#[derive(Debug, Clone, PartialEq)]
pub struct CountMatrix {
    /// Number of rows (genes) in the matrix.
    pub n_genes: usize,
    /// Number of columns (cells) in the matrix.
    pub n_cells: usize,
    /// Explicitly stored coordinates; any coordinate not listed is an
    /// implicit zero count.
    pub entries: Vec<Entry>,
}

/// One stored coordinate of a [`CountMatrix`]. Both indices are zero-based.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Entry {
    /// Zero-based row (gene) index.
    pub gene: u32,
    /// Zero-based column (cell) index.
    pub cell: u32,
    /// Count stored at `(gene, cell)`.
    pub value: f64,
}
24
/// Options controlling how the matrix is scaled. The `Default` value applies
/// no clipping.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct ScaleParams {
    /// Symmetric clip applied to the z-scores: values are bounded to
    /// `[-max_value, max_value]`. `None` reproduces scanpy's default.
    pub max_value: Option<f64>,
}
30
31pub fn open_mtx(dir: &Path) -> Result<Box<dyn Read>> {
32    for name in ["matrix.mtx.gz", "matrix.mtx"] {
33        let path = dir.join(name);
34        if path.exists() {
35            return open_maybe_gz(&path);
36        }
37    }
38    Err(RsomicsError::InvalidInput(format!(
39        "no matrix.mtx or matrix.mtx.gz in {}",
40        dir.display()
41    )))
42}
43
44fn open_maybe_gz(path: &Path) -> Result<Box<dyn Read>> {
45    let file = File::open(path)
46        .map_err(|e| RsomicsError::InvalidInput(format!("{}: {e}", path.display())))?;
47    if path.extension().is_some_and(|e| e == "gz") {
48        Ok(Box::new(MultiGzDecoder::new(file)))
49    } else {
50        Ok(Box::new(file))
51    }
52}
53
54/// Parse a MatrixMarket coordinate file (real, integer, or pattern; general).
55/// 10x stores genes on rows, cells on columns.
56pub fn parse_mtx(reader: impl Read) -> Result<CountMatrix> {
57    let mut reader = BufReader::new(reader);
58    let mut line = String::new();
59
60    reader.read_line(&mut line).map_err(RsomicsError::Io)?;
61    let banner = line.trim();
62    if !banner.starts_with("%%MatrixMarket") {
63        return Err(RsomicsError::InvalidInput(
64            "missing %%MatrixMarket banner".into(),
65        ));
66    }
67    let pattern = banner.contains("pattern");
68
69    let (n_genes, n_cells, nnz) = loop {
70        line.clear();
71        let n = reader.read_line(&mut line).map_err(RsomicsError::Io)?;
72        if n == 0 {
73            return Err(RsomicsError::InvalidInput("truncated MTX header".into()));
74        }
75        let t = line.trim();
76        if t.is_empty() || t.starts_with('%') {
77            continue;
78        }
79        let mut it = t.split_whitespace();
80        let rows = parse_usize(it.next())?;
81        let cols = parse_usize(it.next())?;
82        let nnz = parse_usize(it.next())?;
83        break (rows, cols, nnz);
84    };
85
86    let mut entries = Vec::with_capacity(nnz);
87    for raw in reader.lines() {
88        let raw = raw.map_err(RsomicsError::Io)?;
89        let t = raw.trim();
90        if t.is_empty() {
91            continue;
92        }
93        let mut it = t.split_whitespace();
94        let gene = parse_usize(it.next())?;
95        let cell = parse_usize(it.next())?;
96        let value = if pattern {
97            1.0
98        } else {
99            it.next()
100                .ok_or_else(|| RsomicsError::InvalidInput("MTX entry missing value".into()))?
101                .parse::<f64>()?
102        };
103        if gene == 0 || gene > n_genes || cell == 0 || cell > n_cells {
104            return Err(RsomicsError::InvalidInput(format!(
105                "MTX index out of bounds: ({gene}, {cell})"
106            )));
107        }
108        entries.push(Entry {
109            gene: (gene - 1) as u32,
110            cell: (cell - 1) as u32,
111            value,
112        });
113    }
114    if entries.len() != nnz {
115        return Err(RsomicsError::InvalidInput(format!(
116            "MTX declared {nnz} entries, found {}",
117            entries.len()
118        )));
119    }
120
121    Ok(CountMatrix {
122        n_genes,
123        n_cells,
124        entries,
125    })
126}
127
/// Per-gene mean and standard deviation over all cells. The variance uses the
/// ddof=1 (sample) convention scanpy enforces: `var = (E[x²] - E[x]²)·n/(n-1)`.
/// A zero-variance gene's std collapses to 1 so its centered row stays at 0.
#[derive(Debug, Clone, PartialEq)]
pub struct GeneStats {
    /// Mean count of each gene, indexed by zero-based gene index.
    pub mean: Vec<f64>,
    /// Standard deviation of each gene, indexed like `mean`.
    pub std: Vec<f64>,
}
135
136pub fn gene_stats(m: &CountMatrix) -> GeneStats {
137    let n = m.n_cells as f64;
138    let mut sum = vec![0.0_f64; m.n_genes];
139    let mut sum_sq = vec![0.0_f64; m.n_genes];
140    for e in &m.entries {
141        let g = e.gene as usize;
142        sum[g] += e.value;
143        sum_sq[g] += e.value * e.value;
144    }
145
146    let mut mean = vec![0.0_f64; m.n_genes];
147    let mut std = vec![1.0_f64; m.n_genes];
148    let factor = if m.n_cells > 1 { n / (n - 1.0) } else { 1.0 };
149    for g in 0..m.n_genes {
150        let mu = sum[g] / n;
151        mean[g] = mu;
152        let var = (sum_sq[g] / n - mu * mu) * factor;
153        let s = var.max(0.0).sqrt();
154        std[g] = if s == 0.0 { 1.0 } else { s };
155    }
156    GeneStats { mean, std }
157}
158
159/// Z-score the matrix per gene and densify into a `genes × cells` buffer in
160/// column-major (cell-major) order: every gene of cell 0, then cell 1, … This
161/// is the MatrixMarket `array` layout the writer emits. An implicit zero count
162/// becomes `-mean/std`, which is why scaling densifies the matrix.
163pub fn scale_dense(m: &CountMatrix, params: &ScaleParams) -> (GeneStats, Vec<f64>) {
164    let stats = gene_stats(m);
165    let g = m.n_genes;
166    let baseline: Vec<f64> = (0..g).map(|i| -stats.mean[i] / stats.std[i]).collect();
167
168    let mut dense = vec![0.0_f64; g * m.n_cells];
169    dense
170        .par_chunks_mut(g)
171        .for_each(|col| col.copy_from_slice(&baseline));
172    for e in &m.entries {
173        let i = e.gene as usize;
174        dense[e.cell as usize * g + i] = (e.value - stats.mean[i]) / stats.std[i];
175    }
176
177    if let Some(mx) = params.max_value {
178        dense.par_iter_mut().for_each(|v| *v = v.clamp(-mx, mx));
179    }
180    (stats, dense)
181}
182
183/// Write the dense scaled matrix in MatrixMarket `array real general` layout:
184/// banner, `n_genes n_cells`, then one value per line in column-major order
185/// (matching scipy's dense MatrixMarket and `dense`'s memory layout).
186pub fn write_dense(n_genes: usize, n_cells: usize, dense: &[f64], out: impl Write) -> Result<()> {
187    let mut w = BufWriter::with_capacity(1 << 20, out);
188    w.write_all(b"%%MatrixMarket matrix array real general\n")
189        .map_err(RsomicsError::Io)?;
190    let mut header = format!("{n_genes} {n_cells}");
191    header.push('\n');
192    w.write_all(header.as_bytes()).map_err(RsomicsError::Io)?;
193
194    let mut fmt = ryu::Buffer::new();
195    let mut buf: Vec<u8> = Vec::with_capacity(1 << 16);
196    for &v in dense {
197        buf.extend_from_slice(fmt.format(v).as_bytes());
198        buf.push(b'\n');
199        if buf.len() >= 1 << 15 {
200            w.write_all(&buf).map_err(RsomicsError::Io)?;
201            buf.clear();
202        }
203    }
204    w.write_all(&buf).map_err(RsomicsError::Io)?;
205    w.flush().map_err(RsomicsError::Io)?;
206    Ok(())
207}
208
209fn parse_usize(tok: Option<&str>) -> Result<usize> {
210    tok.ok_or_else(|| RsomicsError::InvalidInput("MTX header missing a dimension".into()))?
211        .parse::<usize>()
212        .map_err(Into::into)
213}
214
215/// End-to-end: read the 10x matrix from `dir`, scale, write a dense matrix.
216pub fn run(dir: &Path, params: &ScaleParams, out: impl Write) -> Result<(usize, usize)> {
217    let m = parse_mtx(open_mtx(dir)?)?;
218    let shape = (m.n_genes, m.n_cells);
219    let (_stats, dense) = scale_dense(&m, params);
220    write_dense(m.n_genes, m.n_cells, &dense, out)?;
221    Ok(shape)
222}
223
224/// `--max-value` accepts a positive float, or is absent for no clipping.
225pub fn parse_max_value(s: Option<&str>) -> Result<Option<f64>> {
226    let Some(s) = s else { return Ok(None) };
227    let v = s
228        .parse::<f64>()
229        .map_err(|_| RsomicsError::InvalidInput(format!("invalid --max-value '{s}'")))?;
230    if v <= 0.0 || !v.is_finite() {
231        return Err(RsomicsError::InvalidInput(
232            "--max-value must be a positive finite number".into(),
233        ));
234    }
235    Ok(Some(v))
236}
237
238pub fn open_output(path: &str) -> Result<Box<dyn Write>> {
239    if path == "-" {
240        Ok(Box::new(std::io::stdout().lock()))
241    } else {
242        Ok(Box::new(
243            File::create(PathBuf::from(path)).map_err(RsomicsError::Io)?,
244        ))
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// 3 genes × 4 cells fixture. Dense counts per gene across the cells:
    /// gene 0 = [3,0,1,0], gene 1 = [0,5,1,0], gene 2 = [1,0,1,0].
    fn tiny() -> CountMatrix {
        let triplets = [
            (0, 0, 3.0),
            (2, 0, 1.0),
            (1, 1, 5.0),
            (0, 2, 1.0),
            (1, 2, 1.0),
            (2, 2, 1.0),
        ];
        CountMatrix {
            n_genes: 3,
            n_cells: 4,
            entries: triplets
                .iter()
                .map(|&(gene, cell, value)| Entry { gene, cell, value })
                .collect(),
        }
    }

    #[test]
    fn stats_ddof1_over_all_cells() {
        // gene 0 counts across 4 cells: [3,0,1,0] -> mean 1.0,
        // sample variance = ((3-1)^2+(0-1)^2+(1-1)^2+(0-1)^2)/3 = 6/3 = 2.
        let s = gene_stats(&tiny());
        assert!((s.mean[0] - 1.0).abs() < 1e-12);
        assert!((s.std[0] - 2.0_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn zero_variance_gene_keeps_std_one() {
        // One gene with the same count in every cell.
        let m = CountMatrix {
            n_genes: 1,
            n_cells: 3,
            entries: (0..3)
                .map(|cell| Entry {
                    gene: 0,
                    cell,
                    value: 2.0,
                })
                .collect(),
        };
        let s = gene_stats(&m);
        assert_eq!(s.std[0], 1.0);
        let (_s, dense) = scale_dense(&m, &ScaleParams { max_value: None });
        for v in dense {
            assert!(v.abs() < 1e-12);
        }
    }

    #[test]
    fn densifies_implicit_zeros() {
        let (_s, dense) = scale_dense(&tiny(), &ScaleParams { max_value: None });
        assert_eq!(dense.len(), 3 * 4);
        // gene 0, cell 1 is an implicit zero -> (0-1)/sqrt(2).
        // Column-major: index = cell * n_genes + gene = 1 * 3 + 0.
        let v = dense[3];
        assert!((v - (-1.0 / 2.0_f64.sqrt())).abs() < 1e-12);
    }

    #[test]
    fn symmetric_clip() {
        let (_s, dense) = scale_dense(
            &tiny(),
            &ScaleParams {
                max_value: Some(0.5),
            },
        );
        for v in dense {
            assert!((-0.5 - 1e-12..=0.5 + 1e-12).contains(&v));
        }
    }

    #[test]
    fn max_value_parsing() {
        assert_eq!(parse_max_value(None).unwrap(), None);
        assert_eq!(parse_max_value(Some("10")).unwrap(), Some(10.0));
        assert_eq!(parse_max_value(Some("2.5")).unwrap(), Some(2.5));
        assert!(parse_max_value(Some("-1")).is_err());
        assert!(parse_max_value(Some("abc")).is_err());
        // Boundary values: zero and non-finite inputs are rejected.
        assert!(parse_max_value(Some("0")).is_err());
        assert!(parse_max_value(Some("inf")).is_err());
        assert!(parse_max_value(Some("nan")).is_err());
    }

    #[test]
    fn parses_coordinate_mtx() {
        // Comment and blank lines before the size line are skipped, and the
        // one-based file indices come back zero-based.
        let text = "%%MatrixMarket matrix coordinate integer general\n\
                    % a comment\n\
                    \n\
                    2 3 2\n\
                    1 1 4\n\
                    2 3 5\n";
        let m = parse_mtx(text.as_bytes()).unwrap();
        assert_eq!((m.n_genes, m.n_cells), (2, 3));
        assert_eq!(m.entries.len(), 2);
        assert_eq!((m.entries[0].gene, m.entries[0].cell), (0, 0));
        assert_eq!(m.entries[0].value, 4.0);
        assert_eq!((m.entries[1].gene, m.entries[1].cell), (1, 2));
        assert_eq!(m.entries[1].value, 5.0);
    }

    #[test]
    fn pattern_entries_count_as_one() {
        let text = "%%MatrixMarket matrix coordinate pattern general\n1 2 1\n1 2\n";
        let m = parse_mtx(text.as_bytes()).unwrap();
        assert_eq!(m.entries.len(), 1);
        assert_eq!((m.entries[0].gene, m.entries[0].cell), (0, 1));
        assert_eq!(m.entries[0].value, 1.0);
    }

    #[test]
    fn rejects_malformed_mtx() {
        let banner = "%%MatrixMarket matrix coordinate real general\n";
        // Missing banner.
        assert!(parse_mtx("2 2 1\n1 1 1.0\n".as_bytes()).is_err());
        // Banner only: the size line never arrives.
        assert!(parse_mtx(banner.as_bytes()).is_err());
        // Row index beyond the declared shape.
        assert!(parse_mtx(format!("{banner}2 2 1\n3 1 1.0\n").as_bytes()).is_err());
        // Zero is not a valid one-based index.
        assert!(parse_mtx(format!("{banner}2 2 1\n0 1 1.0\n").as_bytes()).is_err());
        // Fewer entries than declared.
        assert!(parse_mtx(format!("{banner}2 2 2\n1 1 1.0\n").as_bytes()).is_err());
        // Value column missing on a non-pattern matrix.
        assert!(parse_mtx(format!("{banner}2 2 1\n1 1\n").as_bytes()).is_err());
    }

    #[test]
    fn writes_array_layout() {
        let mut out: Vec<u8> = Vec::new();
        write_dense(1, 2, &[1.5, -0.25], &mut out).unwrap();
        assert_eq!(
            String::from_utf8(out).unwrap(),
            "%%MatrixMarket matrix array real general\n1 2\n1.5\n-0.25\n"
        );
    }
}
339        assert_eq!(parse_max_value(Some("10")).unwrap(), Some(10.0));
340        assert!(parse_max_value(Some("-1")).is_err());
341        assert!(parse_max_value(Some("abc")).is_err());
342    }
343}