1use anyhow::Result;
16use ndarray::Array1;
17
18use crate::ebayes::ebayes;
19use crate::fit::MArrayLM;
20use crate::fitgamma::fit_gamma_intercept;
21use crate::proptruenull::{prop_true_null, PropTrueNullMethod};
22
23fn is_varying(v: &Array1<f64>) -> bool {
27 let first = v[0];
28 v.iter().any(|&x| x != first)
29}
30
/// Converts a log-odds value into a probability, `e^lods / (1 + e^lods)`.
///
/// Inputs above 700 short-circuit to exactly `1.0`. `exp` overflows `f64` shortly after
/// that point, and the ratio would otherwise turn into `inf / inf` (NaN).
fn prob_de(lods: f64) -> f64 {
    const SATURATION: f64 = 700.0;
    if lods > SATURATION {
        return 1.0;
    }
    let odds = lods.exp();
    odds / (odds + 1.0)
}
41
42pub fn pred_fcm(
47 fit: &MArrayLM,
48 coef: usize,
49 var_indep_of_fc: bool,
50 all_de: bool,
51 method: PropTrueNullMethod,
52) -> Result<Vec<f64>> {
53 let mut base = fit.clone();
56 if base.p_value.is_none() {
57 ebayes(&mut base, 0.01, (0.1, 4.0), false, false)?;
58 }
59 let ng = base.coefficients.nrows();
60
61 let pcol: Vec<f64> = base.p_value.as_ref().unwrap().column(coef).to_vec();
63 let mut p = 1.0 - prop_true_null(&pcol, method, 20);
64 if p == 0.0 {
65 p = 1e-8;
66 }
67
68 let trend = base.s2_prior.as_ref().is_some_and(is_varying);
72 let robust = base.df_prior.as_ref().is_some_and(is_varying);
73
74 let mut f = fit.clone();
76 ebayes(&mut f, p, (0.1, 4.0), trend, robust)?;
77
78 let v = f.cov_coefficients[[coef, coef]];
79 let beta: Array1<f64> = f.coefficients.column(coef).to_owned();
80 let s2post = f.s2_post.as_ref().expect("eBayes fills s2_post");
81
82 let a = p / (1.0 - p);
83
84 let pfc: Vec<f64> = if var_indep_of_fc {
85 let y2: Vec<f64> = beta.iter().map(|&b| b * b).collect();
86 let offset: Vec<f64> = s2post.iter().map(|&s| v * s).collect();
87 let mut v0 = fit_gamma_intercept(&y2, &offset, 1000);
88 if v0 < 0.0 {
89 v0 = 1e-8;
90 }
91 let base = beta
92 .iter()
93 .zip(s2post.iter())
94 .map(|(&b, &s)| b * v0 / (v0 + v * s));
95 if all_de {
96 base.collect()
97 } else {
98 base.zip(beta.iter().zip(s2post.iter()))
99 .map(|(pf, (&bb, &s))| {
100 let vs = v * s;
101 let bfac = (vs / (vs + v0)).sqrt();
102 let cfac = (bb * bb * v0 / (2.0 * v * v * s * s + 2.0 * v * v0 * s)).exp();
103 pf * prob_de((a * bfac * cfac).ln())
104 })
105 .collect()
106 }
107 } else {
108 let b2: Vec<f64> = beta
109 .iter()
110 .zip(s2post.iter())
111 .map(|(&b, &s)| b * b / s)
112 .collect();
113 let offset = vec![v; ng];
114 let v0 = fit_gamma_intercept(&b2, &offset, 1000).min(1e-8);
116 let base = beta.iter().map(|&b| b * v0 / (v0 + v));
117 if all_de {
118 base.collect()
119 } else {
120 let bfac = (v / (v + v0)).sqrt();
121 base.zip(beta.iter().zip(s2post.iter()))
122 .map(|(pf, (&bb, &s))| {
123 let cfac = (bb * bb * v0 / (2.0 * v * v * s + 2.0 * v * v0 * s)).exp();
124 pf * prob_de((a * bfac * cfac).ln())
125 })
126 .collect()
127 }
128 };
129
130 Ok(pfc)
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fit::lmfit;
    use ndarray::{array, Array2};

    /// Fits a fixed 20 x 6 matrix against a two-column design: an intercept plus a 0/1 group
    /// indicator, 3 vs 3 samples. Coefficient `1` is therefore the group difference.
    fn build_fit() -> MArrayLM {
        let design: Array2<f64> = array![
            [1.0, 0.0],
            [1.0, 0.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 1.0],
        ];
        let y: Array2<f64> = array![
            [
                2.28724716134052, 0.839750359624071, 1.21855053450735,
                2.57637147386057, 2.8425853501473, 3.01905836527295
            ],
            [
                -1.19677168222235, 0.7053418309055, -0.699317078685514,
                -1.84084472173727, -1.99575176374787, -1.41246029509811
            ],
            [
                -0.694292510435459, 1.30596472081169, -0.285432751528726,
                2.3436741847097, 1.82921984200165, 1.72066693894479
            ],
            [
                -0.412292951136803, -1.38799621659285, -1.31155267260939,
                -0.795192647386953, -1.8934234291213, -2.67436101486938
            ],
            [
                -0.970673341119483, 1.27291686425524, -0.391012431449258,
                3.31896914255711, 2.20729543721107, 3.30872211763837
            ],
            [
                -0.947279945228108, 0.184192771235767, -0.401526613094972,
                -1.49075021102894, -2.91170186529517, -4.20387854268349
            ],
            [
                0.748139340290551, 0.752279895740033, 1.35051758092295,
                0.769154194657112, -0.34606859178086, 0.991289625052712
            ],
            [
                -0.116955225887152, 0.591745052462727, 0.591190027089221,
                1.15347367477894, -0.304607588245324, 1.02322044470037
            ],
            [
                0.152657626282234, -0.983052595771021, 0.100525455628569,
                1.26068350268094, -1.78589348744452, 0.840145438883483
            ],
            [
                2.18997810732938, -0.276063955112006, 0.931071995520097,
                0.700623506572324, 0.58727467185797, 0.12007860795572
            ],
            [
                0.356986230329022, -0.870851022568591, -0.262742348566532,
                0.432627160845955, 1.63579443444659, -0.426255055445707
            ],
            [
                2.71675178313072, 0.718710553084245, -0.00766810471266396,
                -0.922601718256921, -0.645423473634397, 0.458926243504027
            ],
            [
                2.28145192598956, 0.110652877769336, 0.367153006545634,
                -0.615584206630919, 0.61899216878734, 0.645047947693915
            ],
            [
                0.324020540138516, -0.0784667679717042, 1.70716254513761,
                -0.866659688251375, 0.236393598401322, 0.611530549290199
            ],
            [
                1.89606706680993, -0.420490459341998, 0.72374026252838,
                -1.63951708718114, 0.846500898751643, -0.889211293868794
            ],
            [
                0.467680511321698, -0.562125876285266, 0.481036048707917,
                -1.32583924384341, -0.573645738849693, 1.54389234869222
            ],
            [
                -0.893800723085444, 0.997513444755305, -1.56786824422525,
                -0.88903672763797, 1.11799320399617, -1.24176360488504
            ],
            [
                -0.307328299537195, -1.10513005881326, 0.318250283480828,
                -0.557602330302113, -1.54000113193302, 1.10344734039128
            ],
            [
                -0.00482242226757041, -0.142287830774585, 0.16599145067735,
                -0.0624023088383481, -0.438123899300085, 0.982772356675575
            ],
            [
                0.988164149499945, 0.314994904887913, -0.899907629628172,
                2.42269297715943, -0.150672970896448, 0.304327174033201
            ],
        ];
        lmfit(&y, &design, vec![String::new(); 20], vec![String::new(); 2]).unwrap()
    }

    /// Element-wise comparison with a 1e-7 relative tolerance plus a 1e-13 absolute floor.
    fn assert_close(got: &[f64], want: &[f64]) {
        assert_eq!(got.len(), want.len());
        got.iter().zip(want).for_each(|(g, w)| {
            let tolerance = 1e-7 * w.abs() + 1e-13;
            assert!((g - w).abs() <= tolerance, "got {g}, want {w}");
        });
    }

    /// Compares `pred_fcm` on coefficient 1 against R reference outputs for every
    /// `(var_indep_of_fc, all_de)` combination. The `Mean` estimator is additionally
    /// required to reproduce the `Lfdr` values in the `(true, true)` case.
    #[test]
    fn pred_fcm_matches_r() {
        let fit = build_fit();

        let vi1_ad1 = [
            0.806899941410833, -0.800165248392108, 1.09769272440758, -0.443849844795292,
            1.75947115517052, -1.46726184458905, -0.28324270522318, 0.158937728669327,
            0.206001182143709, -0.283330991201421, 0.476902630464977, -0.894526143117239,
            -0.416180722805507, -0.38870547681697, -0.765312904916374, -0.14633414417854,
            0.0889910610860675, 0.0197269550114635, 0.0913603309470277, 0.428463022546148,
        ];
        let vi1_ad0 = [
            0.270478766277483, -0.265522699426538, 0.57398255588566, -0.093477821284407,
            1.65297623275505, -1.18175391914552, -0.0523086424303369, 0.0275529195325538,
            0.0364280464603553, -0.0523279383852657, 0.103863229886988, -0.342837057870809,
            -0.085364720344582, -0.0777824277550811, -0.241150197241235, -0.0252550990394734,
            0.0151196325151701, 0.00332235440767211, 0.0155298924292991, 0.0889042770881764,
        ];
        let vi0_ad1 = [
            2.04623353621094e-08, -2.02915489485312e-08, 2.78366071164941e-08,
            -1.12556760863579e-08, 4.46187773593186e-08, -3.72085836014727e-08,
            -7.18280783738075e-09, 4.03053332738792e-09, 5.22402476154024e-09,
            -7.18504669898158e-09, 1.20938682218567e-08, -2.2684465559181e-08,
            -1.05540093439608e-08, -9.85725914146247e-09, -1.94077214703572e-08,
            -3.71091653306243e-09, 2.25674193629162e-09, 5.00259757625001e-10,
            2.31682471975736e-09, 1.08654786147003e-08,
        ];
        let vi0_ad0 = [
            4.92171467490767e-09, -4.88063617568995e-09, 6.69541561943082e-09,
            -2.70727776232544e-09, 1.07319571258011e-08, -8.94961577592611e-09,
            -1.72764885189419e-09, 9.69446269885002e-10, 1.25651146134861e-09,
            -1.72818735532886e-09, 2.90888439514644e-09, -5.45619380724241e-09,
            -2.53850897651131e-09, -2.3709227445188e-09, -4.66805306961375e-09,
            -8.92570262603364e-10, 5.42804108910441e-10, 1.20325256308392e-10,
            5.57255553814274e-10, 2.61342529709022e-09,
        ];

        // (var_indep_of_fc, all_de, null-proportion estimator, expected output)
        let cases: [(bool, bool, PropTrueNullMethod, &[f64]); 5] = [
            (true, true, PropTrueNullMethod::Lfdr, &vi1_ad1[..]),
            (true, true, PropTrueNullMethod::Mean, &vi1_ad1[..]),
            (true, false, PropTrueNullMethod::Lfdr, &vi1_ad0[..]),
            (false, true, PropTrueNullMethod::Lfdr, &vi0_ad1[..]),
            (false, false, PropTrueNullMethod::Lfdr, &vi0_ad0[..]),
        ];
        for (var_indep, all_de, method, want) in cases {
            let got = pred_fcm(&fit, 1, var_indep, all_de, method).unwrap();
            assert_close(&got, want);
        }
    }
}