1use std::fs;
2use std::io::{BufWriter, Write};
3use std::path::Path;
4
5use rsomics_common::{Result, RsomicsError};
6
/// Fraction of features trimmed from each end of the log-ratio ranking in
/// `calc_factor_tmm` (ranks outside `floor(n * LOGRATIO_TRIM) + 1 ..= n - floor(n * LOGRATIO_TRIM)`
/// are discarded).
const LOGRATIO_TRIM: f64 = 0.3;
/// Fraction of features trimmed from each end of the mean-log-expression ranking in
/// `calc_factor_tmm`, using the same rank-window rule as `LOGRATIO_TRIM`.
const SUM_TRIM: f64 = 0.05;
/// Lower bound on the mean log2 expression of a feature; features at or below this
/// value are dropped before trimming in `calc_factor_tmm`.
const ACUTOFF: f64 = -1e10;
10
/// A count matrix: one row per feature, one column per sample.
///
/// `read_matrix` guarantees that every row in `counts` has exactly
/// `samples.len()` entries; code that builds a `Matrix` by hand must uphold the
/// same invariant, because `tmm_factors` indexes rows by sample position.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Matrix {
    /// Sample names, in column order (the header row without its first cell).
    pub samples: Vec<String>,
    /// Count rows; `counts[g][j]` is the count of feature `g` in sample `j`.
    pub counts: Vec<Vec<f64>>,
}
15
16pub fn read_matrix(path: &Path) -> Result<Matrix> {
17 let bytes = fs::read(path)
18 .map_err(|e| RsomicsError::InvalidInput(format!("{}: {e}", path.display())))?;
19 let mut lines = bytes.split(|&b| b == b'\n');
20
21 let header = lines
22 .next()
23 .ok_or_else(|| RsomicsError::InvalidInput("empty count matrix".into()))?;
24 let samples: Vec<String> = header
25 .split(|&b| b == b'\t')
26 .skip(1)
27 .map(|s| String::from_utf8_lossy(s).into_owned())
28 .collect();
29 if samples.is_empty() {
30 return Err(RsomicsError::InvalidInput(
31 "count matrix has no sample columns".into(),
32 ));
33 }
34
35 let mut counts: Vec<Vec<f64>> = Vec::new();
36 for line in lines {
37 if line.is_empty() {
38 continue;
39 }
40 let mut row = Vec::with_capacity(samples.len());
41 for cell in line.split(|&b| b == b'\t').skip(1) {
42 row.push(parse_count(cell)?);
43 }
44 if row.len() != samples.len() {
45 return Err(RsomicsError::InvalidInput(format!(
46 "row has {} values but header has {} samples",
47 row.len(),
48 samples.len()
49 )));
50 }
51 counts.push(row);
52 }
53 Ok(Matrix { samples, counts })
54}
55
56fn parse_count(cell: &[u8]) -> Result<f64> {
59 if cell.is_empty() {
60 return Err(RsomicsError::InvalidInput("empty count cell".into()));
61 }
62 let mut acc: u64 = 0;
63 let mut all_digits = true;
64 for &b in cell {
65 if b.is_ascii_digit() {
66 acc = acc.wrapping_mul(10).wrapping_add((b - b'0') as u64);
67 } else {
68 all_digits = false;
69 break;
70 }
71 }
72 if all_digits {
73 return Ok(acc as f64);
74 }
75 let s = std::str::from_utf8(cell)
76 .map_err(|_| RsomicsError::InvalidInput("non-UTF8 count cell".into()))?;
77 let v: f64 = s
78 .parse()
79 .map_err(|_| RsomicsError::InvalidInput(format!("non-numeric count '{s}'")))?;
80 if v < 0.0 {
81 return Err(RsomicsError::InvalidInput(
82 "negative counts not allowed".into(),
83 ));
84 }
85 Ok(v)
86}
87
88pub fn tmm_factors(m: &Matrix) -> Vec<f64> {
91 let n_samples = m.samples.len();
92 if n_samples == 0 {
93 return Vec::new();
94 }
95
96 let kept: Vec<&Vec<f64>> = m
97 .counts
98 .iter()
99 .filter(|row| row.iter().any(|&c| c > 0.0))
100 .collect();
101
102 if kept.is_empty() || n_samples == 1 {
103 return vec![1.0; n_samples];
104 }
105
106 let n_genes = kept.len();
107 let mut lib_size = vec![0.0f64; n_samples];
108 for row in &kept {
109 for (j, &c) in row.iter().enumerate() {
110 lib_size[j] += c;
111 }
112 }
113
114 let ref_col = reference_column(&kept, &lib_size, n_genes, n_samples);
115
116 let mut obs = vec![0.0f64; n_genes];
117 let reference: Vec<f64> = (0..n_genes).map(|g| kept[g][ref_col]).collect();
118
119 let mut f = vec![0.0f64; n_samples];
120 for (j, fj) in f.iter_mut().enumerate() {
121 for (g, o) in obs.iter_mut().enumerate() {
122 *o = kept[g][j];
123 }
124 *fj = calc_factor_tmm(&obs, &reference, lib_size[j], lib_size[ref_col]);
125 }
126
127 let log_mean: f64 = f.iter().map(|x| x.ln()).sum::<f64>() / n_samples as f64;
128 let scale = log_mean.exp();
129 for fj in &mut f {
130 *fj /= scale;
131 }
132 f
133}
134
135fn reference_column(
136 kept: &[&Vec<f64>],
137 lib_size: &[f64],
138 n_genes: usize,
139 n_samples: usize,
140) -> usize {
141 let mut f75 = vec![0.0f64; n_samples];
142 let mut col = vec![0.0f64; n_genes];
143 for (j, slot) in f75.iter_mut().enumerate() {
144 for (g, c) in col.iter_mut().enumerate() {
145 *c = kept[g][j];
146 }
147 *slot = quantile_type7(&mut col, 0.75) / lib_size[j];
148 }
149
150 let mut sorted = f75.clone();
151 let median = median_sorted(&mut sorted);
152 if median < 1e-20 {
153 let mut sqrt_mass = vec![0.0f64; n_samples];
155 for row in kept {
156 for (j, m) in sqrt_mass.iter_mut().enumerate() {
157 *m += row[j].sqrt();
158 }
159 }
160 return argmax(&sqrt_mass);
161 }
162
163 let mean: f64 = f75.iter().sum::<f64>() / n_samples as f64;
164 let mut best = 0usize;
165 let mut best_dist = f64::INFINITY;
166 for (j, &v) in f75.iter().enumerate() {
167 let d = (v - mean).abs();
168 if d < best_dist {
169 best_dist = d;
170 best = j;
171 }
172 }
173 best
174}
175
/// Returns the index of the first maximum of `x`.
///
/// Only values strictly greater than the running maximum (starting at negative
/// infinity) replace it, so NaN entries are never selected and the result is
/// `0` for an empty slice or one without any value above negative infinity.
fn argmax(x: &[f64]) -> usize {
    let mut winner = 0usize;
    let mut highest = f64::NEG_INFINITY;
    for (idx, &value) in x.iter().enumerate() {
        if value > highest {
            winner = idx;
            highest = value;
        }
    }
    winner
}
184
185fn calc_factor_tmm(obs: &[f64], reference: &[f64], n_o: f64, n_r: f64) -> f64 {
186 let mut log_r = Vec::with_capacity(obs.len());
187 let mut abs_e = Vec::with_capacity(obs.len());
188 let mut var = Vec::with_capacity(obs.len());
189
190 for (&o, &r) in obs.iter().zip(reference.iter()) {
191 let lr = ((o / n_o) / (r / n_r)).log2();
192 let ae = ((o / n_o).log2() + (r / n_r).log2()) / 2.0;
193 let v = (n_o - o) / n_o / o + (n_r - r) / n_r / r;
194 if lr.is_finite() && ae.is_finite() && ae > ACUTOFF {
195 log_r.push(lr);
196 abs_e.push(ae);
197 var.push(v);
198 }
199 }
200
201 if log_r.is_empty() {
202 return 1.0;
203 }
204 if log_r.iter().fold(0.0f64, |m, &x| m.max(x.abs())) < 1e-6 {
205 return 1.0;
206 }
207
208 let n = log_r.len();
209 let lo_l = (n as f64 * LOGRATIO_TRIM).floor() + 1.0;
210 let hi_l = n as f64 + 1.0 - lo_l;
211 let lo_s = (n as f64 * SUM_TRIM).floor() + 1.0;
212 let hi_s = n as f64 + 1.0 - lo_s;
213
214 let rank_r = average_rank(&log_r);
215 let rank_e = average_rank(&abs_e);
216
217 let mut num = 0.0f64;
218 let mut den = 0.0f64;
219 for i in 0..n {
220 if rank_r[i] >= lo_l && rank_r[i] <= hi_l && rank_e[i] >= lo_s && rank_e[i] <= hi_s {
221 let w = 1.0 / var[i];
222 if w.is_finite() && (log_r[i] / var[i]).is_finite() {
223 num += log_r[i] / var[i];
224 den += w;
225 }
226 }
227 }
228
229 let f = if den == 0.0 { 0.0 } else { num / den };
230 let f = if f.is_nan() { 0.0 } else { f };
231 2.0f64.powf(f)
232}
233
/// Returns the 1-based rank of every element of `x`, giving tied values the
/// average of the ranks they span (e.g. two values tied for ranks 1 and 2 both
/// get 1.5).
///
/// Sorting uses the IEEE 754 total order, so NaN input cannot panic; a NaN is
/// never equal to anything and therefore always receives a rank of its own.
fn average_rank(x: &[f64]) -> Vec<f64> {
    let n = x.len();
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&a, &b| x[a].total_cmp(&x[b]));

    let mut ranks = vec![0.0f64; n];
    let mut i = 0;
    while i < n {
        // Extend `j` over the run of values equal to the one at sorted position `i`.
        let mut j = i + 1;
        while j < n && x[order[j]] == x[order[i]] {
            j += 1;
        }
        // Sorted positions i..j hold ranks i+1 ..= j, whose mean is (i + 1 + j) / 2.
        let avg = ((i + 1 + j) as f64) / 2.0;
        for &idx in &order[i..j] {
            ranks[idx] = avg;
        }
        i = j;
    }
    ranks
}
256
/// Returns the `p`-quantile of `x` by linear interpolation between order
/// statistics at position `(n - 1) * p` (the scheme R calls "type 7").
///
/// `x` is sorted in place. `p` is clamped to `[0, 1]`, so an out-of-range
/// probability yields the minimum or maximum instead of indexing past the end.
/// An empty slice yields `0.0`.
///
/// Sorting uses the IEEE 754 total order, so NaN input cannot panic.
fn quantile_type7(x: &mut [f64], p: f64) -> f64 {
    let n = x.len();
    if n == 0 {
        return 0.0;
    }
    x.sort_by(f64::total_cmp);
    let h = (n - 1) as f64 * p.clamp(0.0, 1.0);
    let lo = h.floor() as usize;
    let frac = h - lo as f64;
    // Interpolate only when the position falls strictly between two order
    // statistics; otherwise `0.0 * (inf - x)` would turn an exact hit into NaN.
    if frac > 0.0 && lo + 1 < n {
        x[lo] + frac * (x[lo + 1] - x[lo])
    } else {
        x[lo]
    }
}
273
/// Sorts `x` in place and returns its median: the middle element for an odd
/// length, the mean of the two middle elements for an even length, and `0.0`
/// for an empty slice.
///
/// Sorting uses the IEEE 754 total order, so NaN input cannot panic; where a
/// NaN lands in the sorted order depends on its sign bit.
fn median_sorted(x: &mut [f64]) -> f64 {
    x.sort_by(f64::total_cmp);
    let n = x.len();
    if n == 0 {
        return 0.0;
    }
    if n % 2 == 1 {
        x[n / 2]
    } else {
        (x[n / 2 - 1] + x[n / 2]) / 2.0
    }
}
286
287pub fn write_factors(samples: &[String], factors: &[f64], output: &mut dyn Write) -> Result<()> {
288 let mut out = BufWriter::new(output);
289 writeln!(out, "sample\tnorm.factor").map_err(RsomicsError::Io)?;
290 for (s, f) in samples.iter().zip(factors.iter()) {
291 writeln!(out, "{s}\t{f:.10}").map_err(RsomicsError::Io)?;
292 }
293 out.flush().map_err(RsomicsError::Io)?;
294 Ok(())
295}
296
297pub fn run(counts_path: &Path, output: &mut dyn Write) -> Result<usize> {
298 let m = read_matrix(counts_path)?;
299 let factors = tmm_factors(&m);
300 write_factors(&m.samples, &factors, output)?;
301 Ok(m.samples.len())
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Matrix` fixture from sample names and count rows.
    fn matrix(names: &[&str], rows: &[&[f64]]) -> Matrix {
        Matrix {
            samples: names.iter().map(|name| (*name).into()).collect(),
            counts: rows.iter().map(|row| row.to_vec()).collect(),
        }
    }

    #[test]
    fn average_rank_handles_ties() {
        // The two 1.0 values share ranks 1 and 2, so each gets 1.5.
        let ranks = average_rank(&[3.0, 1.0, 1.0, 2.0]);
        assert_eq!(ranks, vec![4.0, 1.5, 1.5, 3.0]);
    }

    #[test]
    fn quantile_type7_matches_r() {
        // Position (4 - 1) * 0.75 = 2.25: a quarter of the way from 3.0 to 4.0.
        let mut four = vec![1.0, 2.0, 3.0, 4.0];
        let q = quantile_type7(&mut four, 0.75);
        assert!((q - 3.25).abs() < 1e-12);

        // Unsorted input; position (3 - 1) * 0.75 = 1.5, halfway from 20.0 to 30.0.
        let mut three = vec![30.0, 10.0, 20.0];
        let q = quantile_type7(&mut three, 0.75);
        assert!((q - 25.0).abs() < 1e-12);
    }

    #[test]
    fn identical_columns_give_unit_factors() {
        let m = matrix(
            &["a", "b", "c"],
            &[
                &[100.0, 100.0, 100.0],
                &[50.0, 50.0, 50.0],
                &[10.0, 10.0, 10.0],
                &[200.0, 200.0, 200.0],
            ],
        );
        let factors = tmm_factors(&m);
        assert!(factors.iter().all(|f| (f - 1.0).abs() < 1e-9));
    }

    #[test]
    fn factors_have_unit_geometric_mean() {
        let m = matrix(
            &["a", "b"],
            &[
                &[100.0, 200.0],
                &[50.0, 80.0],
                &[10.0, 5.0],
                &[300.0, 600.0],
                &[5.0, 0.0],
            ],
        );
        let factors = tmm_factors(&m);
        let mean_log = factors.iter().map(|x| x.ln()).sum::<f64>() / factors.len() as f64;
        assert!(mean_log.abs() < 1e-9);
    }
}