Skip to main content

rsomics_edger_exact_test/
lib.rs

//! edgeR exactTest: the classic two-group negative-binomial exact test for
//! differential expression (DE).
//! Method: Robinson & Smyth (2008), Biostatistics 9:321-332 — an overdispersed
//! generalization of Fisher's exact test that conditions each gene on its total
//! count. Per gene: logFC comes from per-group one-group NB fits, logCPM from
//! aveLogCPM, and PValue from the doubletail NB convolution after both groups
//! are quantile-adjusted to a common library size.
7
8mod special;
9
10use std::fs::File;
11use std::io::{BufRead, BufReader, BufWriter, Write};
12use std::path::Path;
13
14use rsomics_common::{Result, RsomicsError};
15
/// ln(2): converts natural-log coefficients to the log2 scale.
const LN2: f64 = std::f64::consts::LN_2;
/// prior.count used for the per-group logFC fits (edgeR exactTest's value).
const PRIOR_COUNT_FC: f64 = 0.125;
/// prior.count used for logCPM (edgeR aveLogCPM default).
const AVELOGCPM_PRIOR: f64 = 2.0;
/// Dispersion for the logCPM fit. This is aveLogCPM's own default, deliberately
/// not the dispersion supplied for the test.
const AVELOGCPM_DISP: f64 = 0.05;
/// When both rounded group sums exceed this, the beta approximation replaces
/// the exact NB convolution.
const BIG_COUNT: f64 = 900.0;
21
22pub struct Matrix {
23    pub header: String,
24    pub genes: Vec<String>,
25    pub counts: Vec<f64>,
26    pub n_samples: usize,
27}
28
29impl Matrix {
30    pub fn load(path: &Path) -> Result<Self> {
31        let file = File::open(path)
32            .map_err(|e| RsomicsError::InvalidInput(format!("{}: {e}", path.display())))?;
33        let mut lines = BufReader::new(file).lines();
34        let header = lines
35            .next()
36            .ok_or_else(|| RsomicsError::InvalidInput("empty count matrix".into()))?
37            .map_err(RsomicsError::Io)?;
38        let n_samples = header.split('\t').count() - 1;
39        if n_samples == 0 {
40            return Err(RsomicsError::InvalidInput(
41                "count matrix has no sample columns".into(),
42            ));
43        }
44
45        let mut genes = Vec::new();
46        let mut counts = Vec::new();
47        for line in lines {
48            let line = line.map_err(RsomicsError::Io)?;
49            if line.is_empty() {
50                continue;
51            }
52            let mut fields = line.split('\t');
53            let gene = fields
54                .next()
55                .ok_or_else(|| RsomicsError::InvalidInput("row without a gene id".into()))?;
56            genes.push(gene.to_string());
57            let before = counts.len();
58            for f in fields {
59                counts.push(f.parse::<f64>().map_err(|_| {
60                    RsomicsError::InvalidInput(format!("non-numeric count '{f}' for gene {gene}"))
61                })?);
62            }
63            if counts.len() - before != n_samples {
64                return Err(RsomicsError::InvalidInput(format!(
65                    "gene {gene}: {} values, header has {n_samples} samples",
66                    counts.len() - before
67                )));
68            }
69        }
70        Ok(Self {
71            header,
72            genes,
73            counts,
74            n_samples,
75        })
76    }
77
78    pub fn n_genes(&self) -> usize {
79        self.genes.len()
80    }
81    fn row(&self, g: usize) -> &[f64] {
82        &self.counts[g * self.n_samples..(g + 1) * self.n_samples]
83    }
84}
85
/// Options for [`exact_test`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ExactTestOpts {
    /// Common negative-binomial dispersion applied to every gene. `0` selects the
    /// Poisson-limit (exact binomial) test.
    pub dispersion: f64,
    /// When true, append a Benjamini–Hochberg `FDR` column to the output.
    pub fdr: bool,
}
90
91/// Split a `--group` spec like `a,a,b,b` into the two sample-index lists. Levels
92/// are taken in first-appearance order; logFC is reported level2 over level1.
93fn parse_groups(spec: &str, n_samples: usize) -> Result<(Vec<usize>, Vec<usize>)> {
94    let labels: Vec<&str> = spec.split(',').map(str::trim).collect();
95    if labels.len() != n_samples {
96        return Err(RsomicsError::InvalidInput(format!(
97            "--group has {} labels but matrix has {n_samples} samples",
98            labels.len()
99        )));
100    }
101    let mut levels: Vec<&str> = Vec::new();
102    for l in &labels {
103        if !levels.contains(l) {
104            levels.push(l);
105        }
106    }
107    if levels.len() != 2 {
108        return Err(RsomicsError::InvalidInput(format!(
109            "exactTest needs exactly two groups, found {}: {levels:?}",
110            levels.len()
111        )));
112    }
113    let mut g1 = Vec::new();
114    let mut g2 = Vec::new();
115    for (i, &lab) in labels.iter().enumerate() {
116        if lab == levels[0] { &mut g1 } else { &mut g2 }.push(i);
117    }
118    Ok((g1, g2))
119}
120
121fn load_norm_factors(path: &Path, n_samples: usize) -> Result<Vec<f64>> {
122    let file = File::open(path)
123        .map_err(|e| RsomicsError::InvalidInput(format!("{}: {e}", path.display())))?;
124    let mut factors = Vec::with_capacity(n_samples);
125    for line in BufReader::new(file).lines() {
126        let line = line.map_err(RsomicsError::Io)?;
127        let line = line.trim();
128        if line.is_empty() || line.starts_with('#') {
129            continue;
130        }
131        let val = line.rsplit('\t').next().unwrap_or(line);
132        factors.push(
133            val.parse::<f64>().map_err(|_| {
134                RsomicsError::InvalidInput(format!("non-numeric norm factor '{val}'"))
135            })?,
136        );
137    }
138    if factors.len() != n_samples {
139        return Err(RsomicsError::InvalidInput(format!(
140            "{} norm factors for {n_samples} samples",
141            factors.len()
142        )));
143    }
144    Ok(factors)
145}
146
/// Maximum number of Fisher-scoring iterations in `mglm_one_group`.
const MAXIT: usize = 50;
/// `mglm_one_group` stops once the absolute scoring step falls below this.
const TOL: f64 = 1e-10;
149
150/// One-group NB fit (edgeR mglmOneGroup / glm_one_group): Fisher scoring for the
151/// single coefficient beta where mu[j] = exp(beta + offset[j]). Returns beta on
152/// the natural-log scale; -inf when the total count is zero.
153fn mglm_one_group(row: &[f64], offset: &[f64], dispersion: f64) -> f64 {
154    let total: f64 = row.iter().sum();
155    if total == 0.0 {
156        return f64::NEG_INFINITY;
157    }
158    let mean_off = offset.iter().sum::<f64>() / offset.len() as f64;
159    let mut beta = (total / row.len() as f64).ln() - mean_off;
160    for _ in 0..MAXIT {
161        let mut dl = 0.0;
162        let mut info = 0.0;
163        for (&y, &off) in row.iter().zip(offset) {
164            let mu = (beta + off).exp();
165            let denom = 1.0 + mu * dispersion;
166            dl += (y - mu) / denom;
167            info += mu / denom;
168        }
169        let step = dl / info;
170        beta += step;
171        if step.abs() < TOL {
172            break;
173        }
174    }
175    beta
176}
177
178/// aveLogCPM for one gene across all libraries (edgeR aveLogCPM, prior.count=2,
179/// dispersion=0.05). Counts are augmented by a per-library-scaled prior; the fit
180/// runs against offsets that include twice that prior.
181fn ave_log_cpm_gene(row: &[f64], lib: &[f64]) -> f64 {
182    let mean_lib = lib.iter().sum::<f64>() / lib.len() as f64;
183    let prior: Vec<f64> = lib
184        .iter()
185        .map(|&l| AVELOGCPM_PRIOR * l / mean_lib)
186        .collect();
187    let off_aug: Vec<f64> = lib
188        .iter()
189        .zip(&prior)
190        .map(|(&l, &p)| (l + 2.0 * p).ln())
191        .collect();
192    let aug: Vec<f64> = row.iter().zip(&prior).map(|(&c, &p)| c + p).collect();
193    let beta = mglm_one_group(&aug, &off_aug, AVELOGCPM_DISP);
194    (beta + 1e6f64.ln()) / LN2
195}
196
197/// q2qnbinom: remap count x from an NB with mean `input_mean` to the NB with
198/// mean `output_mean` (same dispersion) by averaging a normal-quantile and a
199/// gamma-quantile transformation in the matched tail (Robinson & Smyth).
200fn q2qnbinom(x: f64, input_mean: f64, output_mean: f64, dispersion: f64) -> f64 {
201    let eps = 1e-14;
202    let (mut im, mut om) = (input_mean, output_mean);
203    if im < eps || om < eps {
204        im += 0.25;
205        om += 0.25;
206    }
207    let ri = 1.0 + dispersion * im;
208    let vi = im * ri;
209    let ro = 1.0 + dispersion * om;
210    let vo = om * ro;
211    let upper = x >= im;
212    let p1 = special::pnorm(x, im, vi.sqrt(), !upper, true);
213    let p2 = special::pgamma(x, im / ri, ri, !upper, true);
214    let q1 = special::qnorm(p1, om, vo.sqrt(), !upper, true);
215    let q2 = special::qgamma(p2, om / ro, ro, !upper, true);
216    (q1 + q2) / 2.0
217}
218
219/// Doubletail NB exact test on two unrounded group sums (n1 / n2 libraries),
220/// conditional on the total. Mirrors edgeR exactTestDoubleTail: Poisson limit
221/// via the exact binomial test, beta approximation for large counts (using the
222/// unrounded sums), otherwise the convolution of two NB pmfs summed over the
223/// equal-or-more-extreme integer tail and doubled.
224fn exact_doubletail(raw1: f64, raw2: f64, n1: f64, n2: f64, dispersion: f64) -> f64 {
225    let s1 = raw1.round();
226    let s2 = raw2.round();
227    let s = s1 + s2;
228    if s == 0.0 {
229        return 1.0;
230    }
231    if dispersion <= 0.0 {
232        return binom_test(s1, s2, n1 / (n1 + n2));
233    }
234    if s1 > BIG_COUNT && s2 > BIG_COUNT {
235        return beta_approx(raw1, raw2, n1, n2, dispersion);
236    }
237    let mu = s / (n1 + n2);
238    let mu1 = n1 * mu;
239    let mu2 = n2 * mu;
240    let size1 = n1 / dispersion;
241    let size2 = n2 / dispersion;
242    let p_bot = special::dnbinom_mu(s, (n1 + n2) / dispersion, s);
243    let p = if s1 < mu1 {
244        let mut top = 0.0;
245        let mut x = 0.0;
246        while x <= s1 {
247            top += special::dnbinom_mu(x, size1, mu1) * special::dnbinom_mu(s - x, size2, mu2);
248            x += 1.0;
249        }
250        2.0 * top / p_bot
251    } else if s1 > mu1 {
252        let mut top = 0.0;
253        let mut x = s1;
254        while x <= s {
255            top += special::dnbinom_mu(x, size1, mu1) * special::dnbinom_mu(s - x, size2, mu2);
256            x += 1.0;
257        }
258        2.0 * top / p_bot
259    } else {
260        1.0
261    };
262    p.min(1.0)
263}
264
265/// Exact two-sided binomial test (Poisson-limit case), edgeR binomTest with
266/// p != 0.5: sum the probabilities of all outcomes no more likely than observed.
267fn binom_test(y1: f64, y2: f64, p: f64) -> f64 {
268    let y1 = y1.round();
269    let size = (y1 + y2.round()) as i64;
270    if size == 0 {
271        return 1.0;
272    }
273    let mut d = vec![0.0f64; (size + 1) as usize];
274    for (k, dk) in d.iter_mut().enumerate() {
275        *dk = special::dbinom(k as f64, size as f64, p);
276    }
277    let observed = d[y1 as usize];
278    let tol = observed * (1.0 + 1e-7);
279    let pv: f64 = d.iter().filter(|&&v| v <= tol).sum();
280    pv.min(1.0)
281}
282
283/// Beta approximation for large counts (edgeR exactTestBetaApprox).
284fn beta_approx(s1: f64, s2: f64, n1: f64, n2: f64, dispersion: f64) -> f64 {
285    let y = s1 + s2;
286    if y <= 0.0 {
287        return 1.0;
288    }
289    let mu = y / (n1 + n2);
290    let alpha1 = n1 * mu / (1.0 + dispersion * mu);
291    let alpha2 = n2 / n1 * alpha1;
292    let med = special::qbeta(0.5, alpha1, alpha2);
293    if (s1 + 0.5) / y < med {
294        (2.0 * special::pbeta((s1 + 0.5) / y, alpha1, alpha2, true, false)).min(1.0)
295    } else if (s1 - 0.5) / y > med {
296        (2.0 * special::pbeta((s1 - 0.5) / y, alpha1, alpha2, false, false)).min(1.0)
297    } else {
298        1.0
299    }
300}
301
/// Benjamini–Hochberg adjusted p-values, matching R's `p.adjust(p, "BH")`.
///
/// NaN p-values are treated like R's `NA`: they stay NaN in the output and are
/// excluded from the number of tests `n`. The sort uses `total_cmp`, so NaN
/// input can no longer panic. Empty input gives an empty vector.
fn bh_fdr(pvals: &[f64]) -> Vec<f64> {
    let mut adj = vec![f64::NAN; pvals.len()];
    let mut order: Vec<usize> = (0..pvals.len()).filter(|&i| !pvals[i].is_nan()).collect();
    let n = order.len();
    // Largest p first, so a running minimum enforces monotone adjusted values.
    order.sort_by(|&a, &b| pvals[b].total_cmp(&pvals[a]));
    let mut cummin = f64::INFINITY;
    for (rank, &i) in order.iter().enumerate() {
        let m = n - rank;
        let v = (pvals[i] * n as f64 / m as f64).min(1.0);
        cummin = cummin.min(v);
        adj[i] = cummin;
    }
    adj
}
316
317pub fn exact_test(
318    counts_path: &Path,
319    group_spec: &str,
320    norm_factors_path: Option<&Path>,
321    opts: &ExactTestOpts,
322    output: &mut dyn Write,
323) -> Result<u64> {
324    let m = Matrix::load(counts_path)?;
325    let (g1, g2) = parse_groups(group_spec, m.n_samples)?;
326
327    let norm_factors = match norm_factors_path {
328        Some(p) => load_norm_factors(p, m.n_samples)?,
329        None => vec![1.0; m.n_samples],
330    };
331
332    let mut lib = vec![0.0f64; m.n_samples];
333    for row in m.counts.chunks_exact(m.n_samples) {
334        for (s, &c) in lib.iter_mut().zip(row) {
335            *s += c;
336        }
337    }
338    let eff_lib: Vec<f64> = lib
339        .iter()
340        .zip(&norm_factors)
341        .map(|(&l, &f)| l * f)
342        .collect();
343    let offset: Vec<f64> = eff_lib.iter().map(|&l| l.ln()).collect();
344    let lib_average = (offset.iter().sum::<f64>() / offset.len() as f64).exp();
345
346    let mean_eff = eff_lib.iter().sum::<f64>() / eff_lib.len() as f64;
347    let prior_fc: Vec<f64> = eff_lib
348        .iter()
349        .map(|&l| PRIOR_COUNT_FC * l / mean_eff)
350        .collect();
351    let offset_aug: Vec<f64> = eff_lib
352        .iter()
353        .zip(&prior_fc)
354        .map(|(&l, &p)| (l + 2.0 * p).ln())
355        .collect();
356
357    let disp = opts.dispersion;
358    let n1 = g1.len() as f64;
359    let n2 = g2.len() as f64;
360
361    let gene_col = m.header.split('\t').next().unwrap_or("gene");
362    let mut logfc = Vec::with_capacity(m.n_genes());
363    let mut logcpm = Vec::with_capacity(m.n_genes());
364    let mut pvals = Vec::with_capacity(m.n_genes());
365
366    let mut buf1 = vec![0.0f64; g1.len()];
367    let mut buf2 = vec![0.0f64; g2.len()];
368    let mut off1 = vec![0.0f64; g1.len()];
369    let mut off2 = vec![0.0f64; g2.len()];
370    for (j, &s) in g1.iter().enumerate() {
371        off1[j] = offset_aug[s];
372    }
373    for (j, &s) in g2.iter().enumerate() {
374        off2[j] = offset_aug[s];
375    }
376    let off1_g: Vec<f64> = g1.iter().map(|&s| offset[s]).collect();
377    let off2_g: Vec<f64> = g2.iter().map(|&s| offset[s]).collect();
378    let mut all_off = vec![0.0f64; g1.len() + g2.len()];
379    all_off[..g1.len()].copy_from_slice(&off1_g);
380    all_off[g1.len()..].copy_from_slice(&off2_g);
381    let mut all_row = vec![0.0f64; g1.len() + g2.len()];
382
383    for g in 0..m.n_genes() {
384        let row = m.row(g);
385
386        for (j, &s) in g1.iter().enumerate() {
387            buf1[j] = row[s] + prior_fc[s];
388        }
389        for (j, &s) in g2.iter().enumerate() {
390            buf2[j] = row[s] + prior_fc[s];
391        }
392        let ab1 = mglm_one_group(&buf1, &off1, disp);
393        let ab2 = mglm_one_group(&buf2, &off2, disp);
394        logfc.push((ab2 - ab1) / LN2);
395
396        for (j, &s) in g1.iter().enumerate() {
397            all_row[j] = row[s];
398        }
399        for (j, &s) in g2.iter().enumerate() {
400            all_row[g1.len() + j] = row[s];
401        }
402        logcpm.push(ave_log_cpm_gene(&all_row, &eff_lib));
403
404        let abundance = mglm_one_group(&all_row, &all_off, disp);
405        let e = abundance.exp();
406        let output_mean = e * lib_average;
407        let mut s1 = 0.0;
408        for &si in &g1 {
409            let im = e * eff_lib[si];
410            s1 += q2qnbinom(row[si], im, output_mean, disp);
411        }
412        let mut s2 = 0.0;
413        for &si in &g2 {
414            let im = e * eff_lib[si];
415            s2 += q2qnbinom(row[si], im, output_mean, disp);
416        }
417        pvals.push(exact_doubletail(s1, s2, n1, n2, disp));
418    }
419
420    let fdr = if opts.fdr { Some(bh_fdr(&pvals)) } else { None };
421
422    let mut out = BufWriter::new(output);
423    if let Some(_f) = &fdr {
424        writeln!(out, "{gene_col}\tlogFC\tlogCPM\tPValue\tFDR").map_err(RsomicsError::Io)?;
425    } else {
426        writeln!(out, "{gene_col}\tlogFC\tlogCPM\tPValue").map_err(RsomicsError::Io)?;
427    }
428    for g in 0..m.n_genes() {
429        write!(
430            out,
431            "{}\t{}\t{}\t{}",
432            m.genes[g],
433            fmt(logfc[g]),
434            fmt(logcpm[g]),
435            fmt_p(pvals[g])
436        )
437        .map_err(RsomicsError::Io)?;
438        if let Some(f) = &fdr {
439            write!(out, "\t{}", fmt_p(f[g])).map_err(RsomicsError::Io)?;
440        }
441        writeln!(out).map_err(RsomicsError::Io)?;
442    }
443    out.flush().map_err(RsomicsError::Io)?;
444    Ok(m.n_genes() as u64)
445}
446
/// Fixed-point with six decimals (used for logFC and logCPM).
fn fmt(v: f64) -> String {
    format!("{:.6}", v)
}
450
/// Scientific notation shaped like R's `formatC(format="e")`: six mantissa
/// decimals, then an exponent with a sign and at least two digits
/// (`4.676393e-02`), so goldens diff cleanly.
///
/// Non-finite values are written the way R prints them (`NaN`, `Inf`, `-Inf`).
/// Previously their missing exponent made this function panic.
fn fmt_p(v: f64) -> String {
    if v.is_nan() {
        return "NaN".to_string();
    }
    if v.is_infinite() {
        return if v > 0.0 { "Inf" } else { "-Inf" }.to_string();
    }
    let s = format!("{v:.6e}");
    let pos = match s.find('e') {
        Some(pos) => pos,
        // `{:e}` of a finite f64 always has an exponent; keep the raw text otherwise.
        None => return s,
    };
    let (mantissa, exp) = (&s[..pos], &s[pos + 1..]);
    let (sign, digits) = match exp.strip_prefix('-') {
        Some(d) => ('-', d),
        None => ('+', exp),
    };
    format!("{mantissa}e{sign}{digits:0>2}")
}
462
#[cfg(test)]
mod tests {
    use super::*;

    // Reference values probed in R: exactTest(d, dispersion=0.1) on the
    // 5-gene fixture below (samples a1, a2 | b1, b2).
    const EXPECT_LOGFC: [f64; 5] = [
        2.07695720785,
        0.08819639348,
        5.71691109651,
        -2.08892448998,
        -0.13684694279,
    ];
    const EXPECT_LOGCPM: [f64; 5] = [
        15.60092586,
        17.29409726,
        13.03042119,
        12.44757336,
        19.58295511,
    ];
    const EXPECT_PVAL: [f64; 5] = [
        0.0003140189807,
        0.8671983532895,
        0.0010940667164,
        0.2522215919834,
        0.7703694926792,
    ];

    /// The probe fixture as (gene ids, row-major counts, number of samples).
    fn fixture() -> (Vec<String>, Vec<f64>, usize) {
        let counts = vec![
            10.0, 12.0, 40.0, 55.0, // g1
            100.0, 90.0, 95.0, 110.0, // g2
            0.0, 0.0, 5.0, 8.0, // g3
            3.0, 2.0, 0.0, 1.0, // g4
            500.0, 520.0, 480.0, 460.0, // g5
        ];
        let genes = (1..=5).map(|i| format!("g{i}")).collect();
        (genes, counts, 4)
    }

    #[test]
    fn matches_edger_probe() {
        let (genes, counts, ns) = fixture();
        let m = Matrix {
            header: "gene\ta1\ta2\tb1\tb2".into(),
            genes,
            counts,
            n_samples: ns,
        };
        let (g1, g2) = parse_groups("a,a,b,b", ns).unwrap();
        let disp = 0.1;

        // Library sizes are the column sums (no normalization factors here).
        let lib: Vec<f64> = (0..ns)
            .map(|s| m.counts.chunks_exact(ns).map(|r| r[s]).sum())
            .collect();
        let log_lib: Vec<f64> = lib.iter().map(|l| l.ln()).collect();
        let lib_average = (log_lib.iter().sum::<f64>() / ns as f64).exp();
        let mean_lib = lib.iter().sum::<f64>() / ns as f64;
        let prior_fc: Vec<f64> = lib
            .iter()
            .map(|&l| PRIOR_COUNT_FC * l / mean_lib)
            .collect();
        let aug_off = |group: &[usize]| -> Vec<f64> {
            group
                .iter()
                .map(|&s| (lib[s] + 2.0 * prior_fc[s]).ln())
                .collect()
        };
        let (off1, off2) = (aug_off(&g1), aug_off(&g2));

        let expected = EXPECT_LOGFC
            .iter()
            .zip(EXPECT_LOGCPM.iter())
            .zip(EXPECT_PVAL.iter());
        for (g, ((&want_lfc, &want_lcpm), &want_p)) in expected.enumerate() {
            let row = m.row(g);

            let with_prior = |group: &[usize]| -> Vec<f64> {
                group.iter().map(|&s| row[s] + prior_fc[s]).collect()
            };
            let lfc = (mglm_one_group(&with_prior(&g2), &off2, disp)
                - mglm_one_group(&with_prior(&g1), &off1, disp))
                / LN2;
            assert!((lfc - want_lfc).abs() < 1e-6, "logFC g{g}: {lfc}");

            let lcpm = ave_log_cpm_gene(row, &lib);
            assert!(
                (lcpm - want_lcpm).abs() < 1e-5,
                "logCPM g{g}: {lcpm}"
            );

            let e = mglm_one_group(row, &log_lib, disp).exp();
            let om = e * lib_average;
            let group_sum = |group: &[usize]| -> f64 {
                group
                    .iter()
                    .map(|&s| q2qnbinom(row[s], e * lib[s], om, disp))
                    .sum()
            };
            let p = exact_doubletail(group_sum(&g1), group_sum(&g2), 2.0, 2.0, disp);
            let rel = (p - want_p).abs() / (1.0 + want_p.abs());
            assert!(rel < 1e-5, "PValue g{g}: {p} vs {}", want_p);
        }
    }

    #[test]
    fn bh_matches_r_padjust() {
        // p.adjust(p, "BH") in R, applied to the probed p-values above:
        let expect = [
            0.001570095,
            0.867198353,
            0.002735167,
            0.420369320,
            0.867198353,
        ];
        let adj = bh_fdr(&EXPECT_PVAL);
        for (a, e) in adj.iter().zip(expect) {
            assert!((a - e).abs() < 1e-6, "{a} vs {e}");
        }
    }
}