Skip to main content

limma/
zscorehyper.rs

1//! Mid-p z-score equivalents of hypergeometric deviates.
2//!
3//! Pure-Rust port of limma's `zscoreHyper.R` ([`zscore_hyper`]): the signed
4//! standard-normal deviate matching the mid-p-value of a hypergeometric count.
5//! Tail probabilities are accumulated in log space (via [`logsumexp`]) so the
6//! result stays accurate deep into either tail, as in R's `log.p = TRUE` path.
7
8use crate::logsumexp::logsumexp;
9use crate::special::ln_gamma;
10use crate::zscore::zscore_from_log_tails;
11use std::f64::consts::LN_2;
12
13/// `ln C(n, k)` (log binomial coefficient); `-inf` when `k > n`.
14fn lchoose(n: u64, k: u64) -> f64 {
15    if k > n {
16        return f64::NEG_INFINITY;
17    }
18    ln_gamma((n + 1) as f64) - ln_gamma((k + 1) as f64) - ln_gamma((n - k + 1) as f64)
19}
20
21/// `ln dhyper(x; m, n, k)`: log hypergeometric pmf for drawing `x` white balls
22/// in `k` draws from an urn of `m` white and `n` black. `-inf` outside the
23/// support `max(0, k-n) <= x <= min(m, k)`.
24fn ln_dhyper(x: i64, m: u64, n: u64, k: u64) -> f64 {
25    if x < 0 {
26        return f64::NEG_INFINITY;
27    }
28    let x = x as u64;
29    if x > m || x > k || k - x > n {
30        return f64::NEG_INFINITY;
31    }
32    lchoose(m, x) + lchoose(n, k - x) - lchoose(m + n, k)
33}
34
35/// `zscoreHyper(q, m, n, k)`: z-score equivalent of a hypergeometric deviate `q`
36/// (white balls drawn) for an urn of `m` white and `n` black with `k` draws.
37/// Uses the mid-p-value — the full tail plus half the point mass at `q` — and
38/// maps the smaller tail through the normal quantile in log space.
39pub fn zscore_hyper(q: i64, m: u64, n: u64, k: u64) -> f64 {
40    let lo = k.saturating_sub(n) as i64;
41    let hi = m.min(k) as i64;
42
43    let ln_d = ln_dhyper(q, m, n, k) - LN_2;
44
45    // ln P(X > q) and ln P(X < q), summed over the support in log space.
46    let mut ln_pupper = f64::NEG_INFINITY;
47    for x in (q + 1).max(lo)..=hi {
48        ln_pupper = logsumexp(ln_pupper, ln_dhyper(x, m, n, k));
49    }
50    let mut ln_plower = f64::NEG_INFINITY;
51    for x in lo..=(q - 1).min(hi) {
52        ln_plower = logsumexp(ln_plower, ln_dhyper(x, m, n, k));
53    }
54
55    // Add half the point mass to each tail (mid-p), preserving log accuracy.
56    let pmidupper = logsumexp(ln_pupper, ln_d);
57    let pmidlower = logsumexp(ln_plower, ln_d);
58
59    zscore_from_log_tails(pmidlower, pmidupper)
60}
61
#[cfg(test)]
mod tests {
    use super::*;

    /// Mixed absolute/relative comparison: true when
    /// `|a - b| <= tol + tol * |b|`.
    fn close(a: f64, b: f64, tol: f64) -> bool {
        let allowed = tol + tol * b.abs();
        (a - b).abs() <= allowed
    }

    #[test]
    fn zscore_hyper_matches_r() {
        // Reference: zscoreHyper(q, m, n, k) in limma 3.68.3.
        // Each entry pairs a deviate `q` with the expected z-score.
        let cases: [(i64, f64); 5] = [
            (5, 1.85530391853958),
            (1, -1.35476719565062),
            (8, 4.47368039807182),
            (0, -2.29868959056222),
            (3, 0.281691411313753),
        ];
        for &(q, want) in cases.iter() {
            let got = zscore_hyper(q, 10, 20, 8);
            assert!(
                close(got, want, 1e-9),
                "q={}: got {got}, want {}",
                q,
                want
            );
        }
    }

    #[test]
    fn zscore_hyper_is_symmetric_about_mean() {
        // m=n, k=20: the mean is q=10, where z is ~0, and q=5/15 are mirror
        // images. Reference: zscoreHyper(c(15,10,5), 50, 50, 20).
        let cases: [(i64, f64); 3] = [
            (15, 2.45887335894704),
            (10, 5.56583284934354e-16),
            (5, -2.45887335894704),
        ];
        for &(q, want) in cases.iter() {
            let got = zscore_hyper(q, 50, 50, 20);
            assert!(close(got, want, 1e-9), "q={}: got {got}", q);
        }
    }
}