1use crate::logsumexp::logcosh;
10use ndarray::Array2;
11
/// Per-probe output of `fitmixture`: one entry per row (probe) of the input
/// matrix in each field, all expressed on the log2 scale.
#[derive(Debug, Clone, PartialEq)]
pub struct FitMixture {
    /// Average log2 intensity (`A`) of the two pure samples for each probe.
    pub a: Vec<f64>,
    /// Log2 ratio (`M`) of the pure sample at `mixprop = 1` to the pure sample
    /// at `mixprop = 0` for each probe.
    pub m: Vec<f64>,
    /// Residual standard deviation of the fit for each probe, in log2 units.
    pub stdev: Vec<f64>,
}
21
22fn mub_and_offset(b: &[f64], z: &Array2<f64>, pm: &[f64]) -> (Array2<f64>, Vec<f64>) {
25 let (nprobes, narrays) = (z.nrows(), z.ncols());
26 let mut mub = Array2::<f64>::zeros((nprobes, narrays));
27 let mut a = vec![0.0; nprobes];
28 for p in 0..nprobes {
29 let half = b[p] / 2.0;
30 let lc = logcosh(half);
31 let th = half.tanh();
32 let mut sum = 0.0;
33 for j in 0..narrays {
34 let m = lc + (1.0 + th * pm[j]).ln();
35 mub[[p, j]] = m;
36 sum += z[[p, j]] - m;
37 }
38 a[p] = sum / narrays as f64;
39 }
40 (mub, a)
41}
42
43pub fn fitmixture(log2e: &Array2<f64>, mixprop: &[f64], niter: usize) -> FitMixture {
47 let nprobes = log2e.nrows();
48 let narrays = log2e.ncols();
49 let ln2 = std::f64::consts::LN_2;
50 let pm: Vec<f64> = mixprop.iter().map(|&m| 2.0 * m - 1.0).collect();
51
52 let mut xtx = [[0.0_f64; 2]; 2];
55 for &mp in mixprop {
56 let (x0, x1) = (mp, 1.0 - mp);
57 xtx[0][0] += x0 * x0;
58 xtx[0][1] += x0 * x1;
59 xtx[1][1] += x1 * x1;
60 }
61 xtx[1][0] = xtx[0][1];
62 let det = xtx[0][0] * xtx[1][1] - xtx[0][1] * xtx[1][0];
63 let inv = [
64 [xtx[1][1] / det, -xtx[0][1] / det],
65 [-xtx[1][0] / det, xtx[0][0] / det],
66 ];
67
68 let mut z = Array2::<f64>::zeros((nprobes, narrays));
69 let mut b = vec![0.0; nprobes];
70 for p in 0..nprobes {
71 let (mut xty0, mut xty1) = (0.0, 0.0);
72 for j in 0..narrays {
73 let l = log2e[[p, j]];
74 z[[p, j]] = l * ln2;
75 let y = l.exp2();
76 xty0 += mixprop[j] * y;
77 xty1 += (1.0 - mixprop[j]) * y;
78 }
79 let s0 = (inv[0][0] * xty0 + inv[0][1] * xty1).max(1.0);
80 let s1 = (inv[1][0] * xty0 + inv[1][1] * xty1).max(1.0);
81 b[p] = s0.ln() - s1.ln();
82 }
83
84 for _ in 0..niter {
86 let (mub, a) = mub_and_offset(&b, &z, &pm);
87 for p in 0..nprobes {
88 let th = (b[p] / 2.0).tanh();
89 let mut dmu = vec![0.0; narrays];
90 let mut dmu_mean = 0.0;
91 for j in 0..narrays {
92 dmu[j] = (th + pm[j]) / (1.0 + th * pm[j]) / 2.0;
93 dmu_mean += dmu[j];
94 }
95 dmu_mean /= narrays as f64;
96 let (mut num, mut den) = (0.0, 0.0);
97 for j in 0..narrays {
98 let mu = a[p] + mub[[p, j]];
99 num += dmu[j] * (z[[p, j]] - mu);
100 let dd = dmu[j] - dmu_mean;
101 den += dd * dd;
102 }
103 b[p] += (num / narrays as f64) / (1e-6 + den / narrays as f64);
104 }
105 }
106
107 let (mub, a) = mub_and_offset(&b, &z, &pm);
109 let scale = (narrays as f64 / (narrays as f64 - 2.0) / narrays as f64).sqrt();
110 let mut stdev = vec![0.0; nprobes];
111 for p in 0..nprobes {
112 let mut ss = 0.0;
113 for j in 0..narrays {
114 let r = z[[p, j]] - (a[p] + mub[[p, j]]);
115 ss += r * r;
116 }
117 stdev[p] = ss.sqrt() * scale / ln2;
118 }
119
120 FitMixture {
121 a: a.iter().map(|&v| v / ln2).collect(),
122 m: b.iter().map(|&v| v / ln2).collect(),
123 stdev,
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Asserts that `actual` and `expected` have the same length and agree
    /// element-wise within `tol` absolute plus `tol` relative to `expected`.
    /// On failure, names the first offending element and both values.
    fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: got {:?}, expected {:?}",
            actual,
            expected
        );
        for (i, (&x, &y)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (x - y).abs() <= tol + tol * y.abs(),
                "element {}: got {}, expected {} (tol {})",
                i,
                x,
                y,
                tol
            );
        }
    }

    /// Compares against reference output computed in R for the same inputs.
    #[test]
    fn fitmixture_matches_r() {
        let log2e = array![
            [
                4.3904760172447,
                5.80204869625303,
                6.58880543876819,
                7.0317975199035,
                7.34222007491416,
                7.62233273319441
            ],
            [
                8.61562128120492,
                8.44189831411776,
                8.13670008266351,
                7.55564296226014,
                6.81782517390952,
                5.63099272063628
            ],
            [
                6.34008451545423,
                8.0396611674378,
                8.73791188700199,
                9.17095797713186,
                9.66382947418352,
                9.87762613040235
            ],
            [
                6.03164313024805,
                5.82338721016494,
                5.36457218491286,
                4.80397607212745,
                4.32376814079242,
                3.02300486774156
            ],
            [
                4.92710401176557,
                6.38918171782614,
                7.10185838995849,
                7.65096816800767,
                8.03727417840249,
                8.19681894669787
            ]
        ];
        let mixprop = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
        let o = fitmixture(&log2e, &mixprop, 4);
        assert_close(
            &o.a,
            &[
                6.01927142601979,
                7.13664451612898,
                8.12130966885468,
                4.51486665208882,
                6.58763605034809,
            ],
            1e-9,
        );
        assert_close(
            &o.m,
            &[
                3.25818790500491,
                -3.0629176507105,
                3.54802170625494,
                -2.99367405470867,
                3.33559775914428,
            ],
            1e-9,
        );
        assert_close(
            &o.stdev,
            &[
                0.0360867991962196,
                0.0785007679371339,
                0.0508483212078436,
                0.0857081603616798,
                0.0535213699021616,
            ],
            1e-9,
        );
    }

    /// Noise-free data generated from the model itself must be recovered exactly.
    /// Pure samples have intensities 2^10 (mixprop = 1) and 2^6 (mixprop = 0),
    /// so A = (10 + 6) / 2 = 8 and M = 10 - 6 = 4, with zero residual.
    #[test]
    fn fitmixture_recovers_noiseless_mixture() {
        // Entries are log2(p * 1024 + (1 - p) * 64) for p = 0, 0.25, 0.5, 0.75, 1.
        let log2e = array![[
            6.0,
            304.0_f64.log2(),
            544.0_f64.log2(),
            784.0_f64.log2(),
            10.0
        ]];
        let mixprop = [0.0, 0.25, 0.5, 0.75, 1.0];
        let o = fitmixture(&log2e, &mixprop, 4);
        assert_close(&o.a, &[8.0], 1e-9);
        assert_close(&o.m, &[4.0], 1e-9);
        assert_eq!(o.stdev.len(), 1);
        assert!(o.stdev[0] < 1e-9, "stdev {} should be ~0", o.stdev[0]);
    }
}