Skip to main content

rsomics_tmm_norm/
lib.rs

1use std::fs;
2use std::io::{BufWriter, Write};
3use std::path::Path;
4
5use rsomics_common::{Result, RsomicsError};
6
/// Trim fraction applied to the log-ratio ranks in `calc_factor_tmm`: with `n` usable genes,
/// only ranks in `[floor(n * LOGRATIO_TRIM) + 1, n - floor(n * LOGRATIO_TRIM)]` are kept.
const LOGRATIO_TRIM: f64 = 0.3;
/// Trim fraction applied the same way to the mean-log-abundance ranks in `calc_factor_tmm`.
const SUM_TRIM: f64 = 0.05;
/// Genes whose mean log2 abundance is not strictly above this cutoff are dropped in
/// `calc_factor_tmm`; at -1e10 it effectively only removes non-finite abundances.
const ACUTOFF: f64 = -1e10;
10
/// A count matrix with one row per feature and one column per sample.
///
/// `read_matrix` builds it from a tab-separated file; callers may also construct it
/// directly, in which case every row is expected to hold `samples.len()` values.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    /// Sample names, in column order.
    pub samples: Vec<String>,
    /// Count rows; `counts[g][j]` is the count of row `g` in sample `j`.
    pub counts: Vec<Vec<f64>>,
}
15
16pub fn read_matrix(path: &Path) -> Result<Matrix> {
17    let bytes = fs::read(path)
18        .map_err(|e| RsomicsError::InvalidInput(format!("{}: {e}", path.display())))?;
19    let mut lines = bytes.split(|&b| b == b'\n');
20
21    let header = lines
22        .next()
23        .ok_or_else(|| RsomicsError::InvalidInput("empty count matrix".into()))?;
24    let samples: Vec<String> = header
25        .split(|&b| b == b'\t')
26        .skip(1)
27        .map(|s| String::from_utf8_lossy(s).into_owned())
28        .collect();
29    if samples.is_empty() {
30        return Err(RsomicsError::InvalidInput(
31            "count matrix has no sample columns".into(),
32        ));
33    }
34
35    let mut counts: Vec<Vec<f64>> = Vec::new();
36    for line in lines {
37        if line.is_empty() {
38            continue;
39        }
40        let mut row = Vec::with_capacity(samples.len());
41        for cell in line.split(|&b| b == b'\t').skip(1) {
42            row.push(parse_count(cell)?);
43        }
44        if row.len() != samples.len() {
45            return Err(RsomicsError::InvalidInput(format!(
46                "row has {} values but header has {} samples",
47                row.len(),
48                samples.len()
49            )));
50        }
51        counts.push(row);
52    }
53    Ok(Matrix { samples, counts })
54}
55
56/// Counts are typically non-negative integers; take a no-allocation byte path
57/// for that case and fall back to a full f64 parse for decimals.
58fn parse_count(cell: &[u8]) -> Result<f64> {
59    if cell.is_empty() {
60        return Err(RsomicsError::InvalidInput("empty count cell".into()));
61    }
62    let mut acc: u64 = 0;
63    let mut all_digits = true;
64    for &b in cell {
65        if b.is_ascii_digit() {
66            acc = acc.wrapping_mul(10).wrapping_add((b - b'0') as u64);
67        } else {
68            all_digits = false;
69            break;
70        }
71    }
72    if all_digits {
73        return Ok(acc as f64);
74    }
75    let s = std::str::from_utf8(cell)
76        .map_err(|_| RsomicsError::InvalidInput("non-UTF8 count cell".into()))?;
77    let v: f64 = s
78        .parse()
79        .map_err(|_| RsomicsError::InvalidInput(format!("non-numeric count '{s}'")))?;
80    if v < 0.0 {
81        return Err(RsomicsError::InvalidInput(
82            "negative counts not allowed".into(),
83        ));
84    }
85    Ok(v)
86}
87
88/// edgeR calcNormFactors(method="TMM"). Returns one factor per sample,
89/// scaled so their geometric mean is 1.
90pub fn tmm_factors(m: &Matrix) -> Vec<f64> {
91    let n_samples = m.samples.len();
92    if n_samples == 0 {
93        return Vec::new();
94    }
95
96    let kept: Vec<&Vec<f64>> = m
97        .counts
98        .iter()
99        .filter(|row| row.iter().any(|&c| c > 0.0))
100        .collect();
101
102    if kept.is_empty() || n_samples == 1 {
103        return vec![1.0; n_samples];
104    }
105
106    let n_genes = kept.len();
107    let mut lib_size = vec![0.0f64; n_samples];
108    for row in &kept {
109        for (j, &c) in row.iter().enumerate() {
110            lib_size[j] += c;
111        }
112    }
113
114    let ref_col = reference_column(&kept, &lib_size, n_genes, n_samples);
115
116    let mut obs = vec![0.0f64; n_genes];
117    let reference: Vec<f64> = (0..n_genes).map(|g| kept[g][ref_col]).collect();
118
119    let mut f = vec![0.0f64; n_samples];
120    for (j, fj) in f.iter_mut().enumerate() {
121        for (g, o) in obs.iter_mut().enumerate() {
122            *o = kept[g][j];
123        }
124        *fj = calc_factor_tmm(&obs, &reference, lib_size[j], lib_size[ref_col]);
125    }
126
127    let log_mean: f64 = f.iter().map(|x| x.ln()).sum::<f64>() / n_samples as f64;
128    let scale = log_mean.exp();
129    for fj in &mut f {
130        *fj /= scale;
131    }
132    f
133}
134
135fn reference_column(
136    kept: &[&Vec<f64>],
137    lib_size: &[f64],
138    n_genes: usize,
139    n_samples: usize,
140) -> usize {
141    let mut f75 = vec![0.0f64; n_samples];
142    let mut col = vec![0.0f64; n_genes];
143    for (j, slot) in f75.iter_mut().enumerate() {
144        for (g, c) in col.iter_mut().enumerate() {
145            *c = kept[g][j];
146        }
147        *slot = quantile_type7(&mut col, 0.75) / lib_size[j];
148    }
149
150    let mut sorted = f75.clone();
151    let median = median_sorted(&mut sorted);
152    if median < 1e-20 {
153        // degenerate libraries: largest sqrt-mass column
154        let mut sqrt_mass = vec![0.0f64; n_samples];
155        for row in kept {
156            for (j, m) in sqrt_mass.iter_mut().enumerate() {
157                *m += row[j].sqrt();
158            }
159        }
160        return argmax(&sqrt_mass);
161    }
162
163    let mean: f64 = f75.iter().sum::<f64>() / n_samples as f64;
164    let mut best = 0usize;
165    let mut best_dist = f64::INFINITY;
166    for (j, &v) in f75.iter().enumerate() {
167        let d = (v - mean).abs();
168        if d < best_dist {
169            best_dist = d;
170            best = j;
171        }
172    }
173    best
174}
175
/// Index of the largest value in `x`; the first one wins on ties.
///
/// Returns 0 for an empty slice and when no element compares greater than negative
/// infinity (for example when every element is NaN).
fn argmax(x: &[f64]) -> usize {
    let mut best_index = 0usize;
    let mut best_value = f64::NEG_INFINITY;
    for (index, &value) in x.iter().enumerate() {
        // Strict `>` keeps the earliest maximum and never selects a NaN.
        if value > best_value {
            best_index = index;
            best_value = value;
        }
    }
    best_index
}
184
185fn calc_factor_tmm(obs: &[f64], reference: &[f64], n_o: f64, n_r: f64) -> f64 {
186    let mut log_r = Vec::with_capacity(obs.len());
187    let mut abs_e = Vec::with_capacity(obs.len());
188    let mut var = Vec::with_capacity(obs.len());
189
190    for (&o, &r) in obs.iter().zip(reference.iter()) {
191        let lr = ((o / n_o) / (r / n_r)).log2();
192        let ae = ((o / n_o).log2() + (r / n_r).log2()) / 2.0;
193        let v = (n_o - o) / n_o / o + (n_r - r) / n_r / r;
194        if lr.is_finite() && ae.is_finite() && ae > ACUTOFF {
195            log_r.push(lr);
196            abs_e.push(ae);
197            var.push(v);
198        }
199    }
200
201    if log_r.is_empty() {
202        return 1.0;
203    }
204    if log_r.iter().fold(0.0f64, |m, &x| m.max(x.abs())) < 1e-6 {
205        return 1.0;
206    }
207
208    let n = log_r.len();
209    let lo_l = (n as f64 * LOGRATIO_TRIM).floor() + 1.0;
210    let hi_l = n as f64 + 1.0 - lo_l;
211    let lo_s = (n as f64 * SUM_TRIM).floor() + 1.0;
212    let hi_s = n as f64 + 1.0 - lo_s;
213
214    let rank_r = average_rank(&log_r);
215    let rank_e = average_rank(&abs_e);
216
217    let mut num = 0.0f64;
218    let mut den = 0.0f64;
219    for i in 0..n {
220        if rank_r[i] >= lo_l && rank_r[i] <= hi_l && rank_e[i] >= lo_s && rank_e[i] <= hi_s {
221            let w = 1.0 / var[i];
222            if w.is_finite() && (log_r[i] / var[i]).is_finite() {
223                num += log_r[i] / var[i];
224                den += w;
225            }
226        }
227    }
228
229    let f = if den == 0.0 { 0.0 } else { num / den };
230    let f = if f.is_nan() { 0.0 } else { f };
231    2.0f64.powf(f)
232}
233
/// Ranks of `x` in the sense of R's default `rank` (`ties.method = "average"`): ranks are
/// 1-based and tied values all receive the mean of the positions they occupy.
///
/// # Panics
///
/// Panics if `x` holds a NaN that has to be compared with another element.
fn average_rank(x: &[f64]) -> Vec<f64> {
    let mut by_value: Vec<usize> = (0..x.len()).collect();
    by_value.sort_by(|&a, &b| x[a].partial_cmp(&x[b]).unwrap());

    let mut ranks = vec![0.0f64; x.len()];
    let mut start = 0;
    while start < by_value.len() {
        let head = x[by_value[start]];
        // Length of the run of values equal to `head`; the head itself always counts.
        let run = 1 + by_value[start + 1..]
            .iter()
            .take_while(|&&idx| x[idx] == head)
            .count();
        let end = start + run;
        // Sorted positions start+1 ..= end (1-based) are tied; their mean is the rank.
        let shared = (start + 1 + end) as f64 / 2.0;
        for &idx in &by_value[start..end] {
            ranks[idx] = shared;
        }
        start = end;
    }
    ranks
}
256
/// Quantile `p` of `x` by R's type 7 rule (R's default): linear interpolation between the
/// two order statistics around position `(n - 1) * p`.
///
/// Sorts `x` in place. Returns 0.0 for an empty slice.
///
/// # Panics
///
/// Panics if `x` holds a NaN that has to be compared with another element.
fn quantile_type7(x: &mut [f64], p: f64) -> f64 {
    if x.is_empty() {
        return 0.0;
    }
    x.sort_by(|a, b| a.partial_cmp(b).unwrap());

    let position = (x.len() - 1) as f64 * p;
    let below = position.floor() as usize;
    let weight = position - below as f64;
    match x.get(below + 1) {
        Some(&upper) => x[below] + weight * (upper - x[below]),
        // `below` is already the last element: nothing to interpolate towards.
        None => x[below],
    }
}
273
/// Median of `x`. The slice is sorted in place first, so the caller's data is reordered.
///
/// Returns 0.0 for an empty slice; for an even length, the mean of the two middle values.
///
/// # Panics
///
/// Panics if `x` holds a NaN that has to be compared with another element.
fn median_sorted(x: &mut [f64]) -> f64 {
    x.sort_by(|a, b| a.partial_cmp(b).unwrap());
    match x.len() {
        0 => 0.0,
        len if len % 2 == 1 => x[len / 2],
        len => (x[len / 2 - 1] + x[len / 2]) / 2.0,
    }
}
286
287pub fn write_factors(samples: &[String], factors: &[f64], output: &mut dyn Write) -> Result<()> {
288    let mut out = BufWriter::new(output);
289    writeln!(out, "sample\tnorm.factor").map_err(RsomicsError::Io)?;
290    for (s, f) in samples.iter().zip(factors.iter()) {
291        writeln!(out, "{s}\t{f:.10}").map_err(RsomicsError::Io)?;
292    }
293    out.flush().map_err(RsomicsError::Io)?;
294    Ok(())
295}
296
297pub fn run(counts_path: &Path, output: &mut dyn Write) -> Result<usize> {
298    let m = read_matrix(counts_path)?;
299    let factors = tmm_factors(&m);
300    write_factors(&m.samples, &factors, output)?;
301    Ok(m.samples.len())
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Ties share the mean of the positions they occupy.
    #[test]
    fn average_rank_handles_ties() {
        // R: rank(c(3,1,1,2)) == c(4.0, 1.5, 1.5, 3.0)
        let input = [3.0, 1.0, 1.0, 2.0];
        let expected = vec![4.0, 1.5, 1.5, 3.0];
        assert_eq!(average_rank(&input), expected);
    }

    /// Interpolated quantiles agree with R's default `quantile`.
    #[test]
    fn quantile_type7_matches_r() {
        // R: quantile(c(1,2,3,4), 0.75) == 3.25
        let mut four = vec![1.0, 2.0, 3.0, 4.0];
        let q = quantile_type7(&mut four, 0.75);
        assert!((q - 3.25).abs() < 1e-12);

        // R: quantile(c(10,20,30), 0.75) == 25
        let mut three = vec![30.0, 10.0, 20.0];
        let q = quantile_type7(&mut three, 0.75);
        assert!((q - 25.0).abs() < 1e-12);
    }

    /// Samples with identical counts need no correction.
    #[test]
    fn identical_columns_give_unit_factors() {
        let counts = [100.0, 50.0, 10.0, 200.0]
            .iter()
            .map(|&c| vec![c; 3])
            .collect();
        let m = Matrix {
            samples: vec!["a".into(), "b".into(), "c".into()],
            counts,
        };
        let factors = tmm_factors(&m);
        assert!(factors.iter().all(|f| (f - 1.0).abs() < 1e-9));
    }

    /// The factors are scaled so that their logs average to zero.
    #[test]
    fn factors_have_unit_geometric_mean() {
        let m = Matrix {
            samples: vec!["a".into(), "b".into()],
            counts: vec![
                vec![100.0, 200.0],
                vec![50.0, 80.0],
                vec![10.0, 5.0],
                vec![300.0, 600.0],
                vec![5.0, 0.0],
            ],
        };
        let factors = tmm_factors(&m);
        let log_sum: f64 = factors.iter().map(|x| x.ln()).sum();
        let log_mean = log_sum / factors.len() as f64;
        assert!(log_mean.abs() < 1e-9);
    }
}
357}