Skip to main content

limma/
fitmixture.rs

//! Mixture-model fit by nonlinear least squares.
//!
//! Pure-Rust port of limma's `fitmixture.R` ([`fitmixture`]): for a titration
//! series in which each array is a known mixture proportion of two samples,
//! fit the per-probe log-ratio `M`, average expression `A` and residual
//! standard deviation. The log-ratio is refined with the same
//! Gauss-Newton-style update as limma, with the per-array mean on the
//! natural-log scale written as `a + logcosh(b/2) + ln(1 + tanh(b/2)·(2·mixprop − 1))`
//! (`a = A·ln 2`, `b = M·ln 2`); this `log(cosh)` form keeps the per-probe
//! means well-conditioned.
8
9use crate::logsumexp::logcosh;
10use ndarray::Array2;
11
/// Per-probe result of [`fitmixture`].
///
/// Every vector has one entry per probe (row of the input matrix), in input
/// order, and all values are on the `log2` scale. `PartialEq` is derived so
/// whole fits can be compared directly (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq)]
pub struct FitMixture {
    /// Average expression `A`: the log2-scale midpoint of the two pure samples.
    pub a: Vec<f64>,
    /// Log-ratio `M`: log2 of sample 1 over sample 2.
    pub m: Vec<f64>,
    /// Residual standard deviation of the per-probe fit.
    pub stdev: Vec<f64>,
}
21
22/// `mub[p,a] = logcosh(b_p/2) + ln(1 + tanh(b_p/2)·pm_a)` together with the
23/// per-probe offset `a_p = mean_a(z - mub)`.
24fn mub_and_offset(b: &[f64], z: &Array2<f64>, pm: &[f64]) -> (Array2<f64>, Vec<f64>) {
25    let (nprobes, narrays) = (z.nrows(), z.ncols());
26    let mut mub = Array2::<f64>::zeros((nprobes, narrays));
27    let mut a = vec![0.0; nprobes];
28    for p in 0..nprobes {
29        let half = b[p] / 2.0;
30        let lc = logcosh(half);
31        let th = half.tanh();
32        let mut sum = 0.0;
33        for j in 0..narrays {
34            let m = lc + (1.0 + th * pm[j]).ln();
35            mub[[p, j]] = m;
36            sum += z[[p, j]] - m;
37        }
38        a[p] = sum / narrays as f64;
39    }
40    (mub, a)
41}
42
43/// `fitmixture(log2e, mixprop, niter)`: `log2e` is the `nprobes × narrays`
44/// matrix of log2 expression, `mixprop` the per-array mixing proportion of
45/// sample 1. Returns per-probe `A`, `M` and residual `stdev`.
46pub fn fitmixture(log2e: &Array2<f64>, mixprop: &[f64], niter: usize) -> FitMixture {
47    let nprobes = log2e.nrows();
48    let narrays = log2e.ncols();
49    let ln2 = std::f64::consts::LN_2;
50    let pm: Vec<f64> = mixprop.iter().map(|&m| 2.0 * m - 1.0).collect();
51
52    // Linear-model starting values: regress intensities y = 2^log2e on the two
53    // columns [mixprop, 1-mixprop], floor coefficients at 1, take b = log ratio.
54    let mut xtx = [[0.0_f64; 2]; 2];
55    for &mp in mixprop {
56        let (x0, x1) = (mp, 1.0 - mp);
57        xtx[0][0] += x0 * x0;
58        xtx[0][1] += x0 * x1;
59        xtx[1][1] += x1 * x1;
60    }
61    xtx[1][0] = xtx[0][1];
62    let det = xtx[0][0] * xtx[1][1] - xtx[0][1] * xtx[1][0];
63    let inv = [
64        [xtx[1][1] / det, -xtx[0][1] / det],
65        [-xtx[1][0] / det, xtx[0][0] / det],
66    ];
67
68    let mut z = Array2::<f64>::zeros((nprobes, narrays));
69    let mut b = vec![0.0; nprobes];
70    for p in 0..nprobes {
71        let (mut xty0, mut xty1) = (0.0, 0.0);
72        for j in 0..narrays {
73            let l = log2e[[p, j]];
74            z[[p, j]] = l * ln2;
75            let y = l.exp2();
76            xty0 += mixprop[j] * y;
77            xty1 += (1.0 - mixprop[j]) * y;
78        }
79        let s0 = (inv[0][0] * xty0 + inv[0][1] * xty1).max(1.0);
80        let s1 = (inv[1][0] * xty0 + inv[1][1] * xty1).max(1.0);
81        b[p] = s0.ln() - s1.ln();
82    }
83
84    // Gauss-Newton iterations on the per-probe b (= M·ln2).
85    for _ in 0..niter {
86        let (mub, a) = mub_and_offset(&b, &z, &pm);
87        for p in 0..nprobes {
88            let th = (b[p] / 2.0).tanh();
89            let mut dmu = vec![0.0; narrays];
90            let mut dmu_mean = 0.0;
91            for j in 0..narrays {
92                dmu[j] = (th + pm[j]) / (1.0 + th * pm[j]) / 2.0;
93                dmu_mean += dmu[j];
94            }
95            dmu_mean /= narrays as f64;
96            let (mut num, mut den) = (0.0, 0.0);
97            for j in 0..narrays {
98                let mu = a[p] + mub[[p, j]];
99                num += dmu[j] * (z[[p, j]] - mu);
100                let dd = dmu[j] - dmu_mean;
101                den += dd * dd;
102            }
103            b[p] += (num / narrays as f64) / (1e-6 + den / narrays as f64);
104        }
105    }
106
107    // Final offsets and residual standard deviations.
108    let (mub, a) = mub_and_offset(&b, &z, &pm);
109    let scale = (narrays as f64 / (narrays as f64 - 2.0) / narrays as f64).sqrt();
110    let mut stdev = vec![0.0; nprobes];
111    for p in 0..nprobes {
112        let mut ss = 0.0;
113        for j in 0..narrays {
114            let r = z[[p, j]] - (a[p] + mub[[p, j]]);
115            ss += r * r;
116        }
117        stdev[p] = ss.sqrt() * scale / ln2;
118    }
119
120    FitMixture {
121        a: a.iter().map(|&v| v / ln2).collect(),
122        m: b.iter().map(|&v| v / ln2).collect(),
123        stdev,
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array2};

    /// Asserts `got` and `want` have equal length and agree element-wise to
    /// `|x − y| <= tol · (1 + |y|)`, naming the first offending element.
    fn assert_close(what: &str, got: &[f64], want: &[f64], tol: f64) {
        assert_eq!(got.len(), want.len(), "{}: length mismatch", what);
        for (i, (&x, &y)) in got.iter().zip(want).enumerate() {
            assert!(
                (x - y).abs() <= tol + tol * y.abs(),
                "{}[{}]: got {}, want {} (tol {})",
                what,
                i,
                x,
                y,
                tol
            );
        }
    }

    #[test]
    fn fitmixture_matches_r() {
        // Reference: fitmixture(log2e, mixprop, niter=4) in limma 3.68.3, on a
        // model-consistent titration (mixprop 0..1) with mild noise.
        let log2e = array![
            [
                4.3904760172447,
                5.80204869625303,
                6.58880543876819,
                7.0317975199035,
                7.34222007491416,
                7.62233273319441
            ],
            [
                8.61562128120492,
                8.44189831411776,
                8.13670008266351,
                7.55564296226014,
                6.81782517390952,
                5.63099272063628
            ],
            [
                6.34008451545423,
                8.0396611674378,
                8.73791188700199,
                9.17095797713186,
                9.66382947418352,
                9.87762613040235
            ],
            [
                6.03164313024805,
                5.82338721016494,
                5.36457218491286,
                4.80397607212745,
                4.32376814079242,
                3.02300486774156
            ],
            [
                4.92710401176557,
                6.38918171782614,
                7.10185838995849,
                7.65096816800767,
                8.03727417840249,
                8.19681894669787
            ]
        ];
        let mixprop = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
        let o = fitmixture(&log2e, &mixprop, 4);
        assert_close(
            "A",
            &o.a,
            &[
                6.01927142601979,
                7.13664451612898,
                8.12130966885468,
                4.51486665208882,
                6.58763605034809,
            ],
            1e-9,
        );
        assert_close(
            "M",
            &o.m,
            &[
                3.25818790500491,
                -3.0629176507105,
                3.54802170625494,
                -2.99367405470867,
                3.33559775914428,
            ],
            1e-9,
        );
        assert_close(
            "stdev",
            &o.stdev,
            &[
                0.0360867991962196,
                0.0785007679371339,
                0.0508483212078436,
                0.0857081603616798,
                0.0535213699021616,
            ],
            1e-9,
        );
    }

    #[test]
    fn fitmixture_recovers_noiseless_titration() {
        // Each array is exactly mixprop·s1 + (1 − mixprop)·s2, so the fit must
        // reproduce M = log2(s1/s2) and A = (log2 s1 + log2 s2)/2 with an
        // essentially zero residual standard deviation.
        let pure = [(1024.0_f64, 64.0_f64), (32.0, 512.0)];
        let mixprop = [0.0, 0.25, 0.5, 0.75, 1.0];
        let log2e = Array2::from_shape_fn((pure.len(), mixprop.len()), |(p, j)| {
            let (s1, s2) = pure[p];
            (mixprop[j] * s1 + (1.0 - mixprop[j]) * s2).log2()
        });
        let o = fitmixture(&log2e, &mixprop, 4);
        assert_close("A", &o.a, &[8.0, 7.0], 1e-9);
        assert_close("M", &o.m, &[4.0, -4.0], 1e-9);
        assert!(
            o.stdev.iter().all(|&s| s.abs() < 1e-9),
            "stdev: {:?}",
            o.stdev
        );
    }
}