1use std::fs::File;
2use std::io::{BufRead, BufReader, BufWriter, Read, Write};
3use std::path::{Path, PathBuf};
4
5use flate2::read::MultiGzDecoder;
6use rayon::prelude::*;
7use rsomics_common::{Result, RsomicsError};
8
9pub struct CountMatrix {
13 pub n_genes: usize,
14 pub n_cells: usize,
15 pub entries: Vec<Entry>,
16}
17
/// One stored value of a `CountMatrix`, addressed by 0-based gene and cell index.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Entry {
    /// 0-based gene (row) index.
    pub gene: u32,
    /// 0-based cell (column) index.
    pub cell: u32,
    /// Stored value; `parse_mtx` uses `1.0` for MatrixMarket `pattern` files.
    pub value: f64,
}
24
/// Options controlling `scale_dense`.
///
/// `ScaleParams::default()` disables clipping.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct ScaleParams {
    /// Symmetric clipping bound: when `Some(mx)`, every scaled value is clamped to
    /// `[-mx, mx]`. `None` leaves values unclipped.
    pub max_value: Option<f64>,
}
30
31pub fn open_mtx(dir: &Path) -> Result<Box<dyn Read>> {
32 for name in ["matrix.mtx.gz", "matrix.mtx"] {
33 let path = dir.join(name);
34 if path.exists() {
35 return open_maybe_gz(&path);
36 }
37 }
38 Err(RsomicsError::InvalidInput(format!(
39 "no matrix.mtx or matrix.mtx.gz in {}",
40 dir.display()
41 )))
42}
43
44fn open_maybe_gz(path: &Path) -> Result<Box<dyn Read>> {
45 let file = File::open(path)
46 .map_err(|e| RsomicsError::InvalidInput(format!("{}: {e}", path.display())))?;
47 if path.extension().is_some_and(|e| e == "gz") {
48 Ok(Box::new(MultiGzDecoder::new(file)))
49 } else {
50 Ok(Box::new(file))
51 }
52}
53
54pub fn parse_mtx(reader: impl Read) -> Result<CountMatrix> {
57 let mut reader = BufReader::new(reader);
58 let mut line = String::new();
59
60 reader.read_line(&mut line).map_err(RsomicsError::Io)?;
61 let banner = line.trim();
62 if !banner.starts_with("%%MatrixMarket") {
63 return Err(RsomicsError::InvalidInput(
64 "missing %%MatrixMarket banner".into(),
65 ));
66 }
67 let pattern = banner.contains("pattern");
68
69 let (n_genes, n_cells, nnz) = loop {
70 line.clear();
71 let n = reader.read_line(&mut line).map_err(RsomicsError::Io)?;
72 if n == 0 {
73 return Err(RsomicsError::InvalidInput("truncated MTX header".into()));
74 }
75 let t = line.trim();
76 if t.is_empty() || t.starts_with('%') {
77 continue;
78 }
79 let mut it = t.split_whitespace();
80 let rows = parse_usize(it.next())?;
81 let cols = parse_usize(it.next())?;
82 let nnz = parse_usize(it.next())?;
83 break (rows, cols, nnz);
84 };
85
86 let mut entries = Vec::with_capacity(nnz);
87 for raw in reader.lines() {
88 let raw = raw.map_err(RsomicsError::Io)?;
89 let t = raw.trim();
90 if t.is_empty() {
91 continue;
92 }
93 let mut it = t.split_whitespace();
94 let gene = parse_usize(it.next())?;
95 let cell = parse_usize(it.next())?;
96 let value = if pattern {
97 1.0
98 } else {
99 it.next()
100 .ok_or_else(|| RsomicsError::InvalidInput("MTX entry missing value".into()))?
101 .parse::<f64>()?
102 };
103 if gene == 0 || gene > n_genes || cell == 0 || cell > n_cells {
104 return Err(RsomicsError::InvalidInput(format!(
105 "MTX index out of bounds: ({gene}, {cell})"
106 )));
107 }
108 entries.push(Entry {
109 gene: (gene - 1) as u32,
110 cell: (cell - 1) as u32,
111 value,
112 });
113 }
114 if entries.len() != nnz {
115 return Err(RsomicsError::InvalidInput(format!(
116 "MTX declared {nnz} entries, found {}",
117 entries.len()
118 )));
119 }
120
121 Ok(CountMatrix {
122 n_genes,
123 n_cells,
124 entries,
125 })
126}
127
/// Per-gene summary statistics produced by `gene_stats`, one element per gene.
#[derive(Debug, Clone, PartialEq)]
pub struct GeneStats {
    /// Mean of each gene over all cells, implicit zeros included.
    pub mean: Vec<f64>,
    /// Sample standard deviation of each gene; genes whose deviation is zero hold `1.0`.
    pub std: Vec<f64>,
}
135
136pub fn gene_stats(m: &CountMatrix) -> GeneStats {
137 let n = m.n_cells as f64;
138 let mut sum = vec![0.0_f64; m.n_genes];
139 let mut sum_sq = vec![0.0_f64; m.n_genes];
140 for e in &m.entries {
141 let g = e.gene as usize;
142 sum[g] += e.value;
143 sum_sq[g] += e.value * e.value;
144 }
145
146 let mut mean = vec![0.0_f64; m.n_genes];
147 let mut std = vec![1.0_f64; m.n_genes];
148 let factor = if m.n_cells > 1 { n / (n - 1.0) } else { 1.0 };
149 for g in 0..m.n_genes {
150 let mu = sum[g] / n;
151 mean[g] = mu;
152 let var = (sum_sq[g] / n - mu * mu) * factor;
153 let s = var.max(0.0).sqrt();
154 std[g] = if s == 0.0 { 1.0 } else { s };
155 }
156 GeneStats { mean, std }
157}
158
159pub fn scale_dense(m: &CountMatrix, params: &ScaleParams) -> (GeneStats, Vec<f64>) {
164 let stats = gene_stats(m);
165 let g = m.n_genes;
166 let baseline: Vec<f64> = (0..g).map(|i| -stats.mean[i] / stats.std[i]).collect();
167
168 let mut dense = vec![0.0_f64; g * m.n_cells];
169 dense
170 .par_chunks_mut(g)
171 .for_each(|col| col.copy_from_slice(&baseline));
172 for e in &m.entries {
173 let i = e.gene as usize;
174 dense[e.cell as usize * g + i] = (e.value - stats.mean[i]) / stats.std[i];
175 }
176
177 if let Some(mx) = params.max_value {
178 dense.par_iter_mut().for_each(|v| *v = v.clamp(-mx, mx));
179 }
180 (stats, dense)
181}
182
183pub fn write_dense(n_genes: usize, n_cells: usize, dense: &[f64], out: impl Write) -> Result<()> {
187 let mut w = BufWriter::with_capacity(1 << 20, out);
188 w.write_all(b"%%MatrixMarket matrix array real general\n")
189 .map_err(RsomicsError::Io)?;
190 let mut header = format!("{n_genes} {n_cells}");
191 header.push('\n');
192 w.write_all(header.as_bytes()).map_err(RsomicsError::Io)?;
193
194 let mut fmt = ryu::Buffer::new();
195 let mut buf: Vec<u8> = Vec::with_capacity(1 << 16);
196 for &v in dense {
197 buf.extend_from_slice(fmt.format(v).as_bytes());
198 buf.push(b'\n');
199 if buf.len() >= 1 << 15 {
200 w.write_all(&buf).map_err(RsomicsError::Io)?;
201 buf.clear();
202 }
203 }
204 w.write_all(&buf).map_err(RsomicsError::Io)?;
205 w.flush().map_err(RsomicsError::Io)?;
206 Ok(())
207}
208
209fn parse_usize(tok: Option<&str>) -> Result<usize> {
210 tok.ok_or_else(|| RsomicsError::InvalidInput("MTX header missing a dimension".into()))?
211 .parse::<usize>()
212 .map_err(Into::into)
213}
214
215pub fn run(dir: &Path, params: &ScaleParams, out: impl Write) -> Result<(usize, usize)> {
217 let m = parse_mtx(open_mtx(dir)?)?;
218 let shape = (m.n_genes, m.n_cells);
219 let (_stats, dense) = scale_dense(&m, params);
220 write_dense(m.n_genes, m.n_cells, &dense, out)?;
221 Ok(shape)
222}
223
224pub fn parse_max_value(s: Option<&str>) -> Result<Option<f64>> {
226 let Some(s) = s else { return Ok(None) };
227 let v = s
228 .parse::<f64>()
229 .map_err(|_| RsomicsError::InvalidInput(format!("invalid --max-value '{s}'")))?;
230 if v <= 0.0 || !v.is_finite() {
231 return Err(RsomicsError::InvalidInput(
232 "--max-value must be a positive finite number".into(),
233 ));
234 }
235 Ok(Some(v))
236}
237
238pub fn open_output(path: &str) -> Result<Box<dyn Write>> {
239 if path == "-" {
240 Ok(Box::new(std::io::stdout().lock()))
241 } else {
242 Ok(Box::new(
243 File::create(PathBuf::from(path)).map_err(RsomicsError::Io)?,
244 ))
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance comparison for the floating-point assertions below.
    fn approx(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-12
    }

    /// 3 genes x 4 cells. Cell 3 has no entries, so it is all implicit zeros.
    fn tiny() -> CountMatrix {
        let triples: [(u32, u32, f64); 6] = [
            (0, 0, 3.0),
            (2, 0, 1.0),
            (1, 1, 5.0),
            (0, 2, 1.0),
            (1, 2, 1.0),
            (2, 2, 1.0),
        ];
        CountMatrix {
            n_genes: 3,
            n_cells: 4,
            entries: triples
                .iter()
                .map(|&(gene, cell, value)| Entry { gene, cell, value })
                .collect(),
        }
    }

    #[test]
    fn stats_ddof1_over_all_cells() {
        // Gene 0 is [3, 0, 1, 0] across the four cells: mean 1, sample variance 6 / 3 = 2.
        let stats = gene_stats(&tiny());
        assert!(approx(stats.mean[0], 1.0));
        assert!(approx(stats.std[0], 2.0_f64.sqrt()));
    }

    #[test]
    fn zero_variance_gene_keeps_std_one() {
        let constant = CountMatrix {
            n_genes: 1,
            n_cells: 3,
            entries: (0..3)
                .map(|cell| Entry {
                    gene: 0,
                    cell,
                    value: 2.0,
                })
                .collect(),
        };
        assert_eq!(gene_stats(&constant).std[0], 1.0);
        let (_, dense) = scale_dense(&constant, &ScaleParams { max_value: None });
        assert!(dense.iter().all(|v| v.abs() < 1e-12));
    }

    #[test]
    fn densifies_implicit_zeros() {
        let (_, dense) = scale_dense(&tiny(), &ScaleParams { max_value: None });
        assert_eq!(dense.len(), 3 * 4);
        // Index 3 is (gene 0, cell 1) in the column-major layout. That position is an
        // implicit zero, scaled to (0 - 1) / sqrt(2).
        assert!(approx(dense[3], -1.0 / 2.0_f64.sqrt()));
    }

    #[test]
    fn symmetric_clip() {
        let params = ScaleParams {
            max_value: Some(0.5),
        };
        let (_, dense) = scale_dense(&tiny(), &params);
        let bound = 0.5 + 1e-12;
        assert!(dense.iter().all(|&v| (-bound..=bound).contains(&v)));
    }

    #[test]
    fn max_value_parsing() {
        assert_eq!(parse_max_value(None).unwrap(), None);
        assert_eq!(parse_max_value(Some("10")).unwrap(), Some(10.0));
        for rejected in ["-1", "abc"] {
            assert!(parse_max_value(Some(rejected)).is_err());
        }
    }
}