1use crate::special::ln_norm_cdf;
15use ndarray::Array2;
16use std::f64::consts::PI;
17
/// Natural logarithm of 2π; used as the normalising term of the Gaussian
/// log-density evaluated in `normexp_signal`.
const LN_2PI: f64 = 1.837_877_066_409_345_3;
20
/// Outcome of a normal+exponential fit as produced by `normexp_fit_saddle`.
#[derive(Debug, Clone)]
pub struct NormexpFit {
    /// Fitted parameters in the order `[mu, ln(sigma), ln(alpha)]`; this is
    /// the layout `normexp_signal` reads (it exponentiates entries 1 and 2).
    pub par: [f64; 3],
    /// Objective value at `par` (minus twice the saddle-point
    /// log-likelihood); `NaN` when all observations are identical and no
    /// optimisation was run.
    pub m2loglik: f64,
}
29
/// Background-correction strategy selected by callers of
/// `background_correct_matrix`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackgroundMethod {
    /// Leave the foreground intensities untouched.
    None,
    /// Plain foreground minus background.
    Subtract,
    /// Foreground minus background, with results below 0.5 raised to 0.5.
    Half,
    /// Foreground minus background; per column, values below `1e-18` are
    /// replaced by half of the smallest value that is at least `1e-18`.
    Minimum,
    /// Foreground minus a 3x3 moving minimum of the background.
    MovingMin,
    /// Log-linear interpolation of low values (see `edwards`).
    Edwards,
    /// Normal+exponential convolution model fitted per column.
    Normexp,
}
41
42pub fn normexp_signal(par: &[f64; 3], x: &[f64]) -> Vec<f64> {
46 let mu = par[0];
47 let sigma = par[1].exp();
48 let sigma2 = sigma * sigma;
49 let alpha = par[2].exp();
50 assert!(alpha > 0.0, "alpha must be positive");
51 assert!(sigma > 0.0, "sigma must be positive");
52
53 let mut signal: Vec<f64> = x
54 .iter()
55 .map(|&xi| {
56 let mu_sf = xi - mu - sigma2 / alpha;
57 let z = mu_sf / sigma;
58 let log_dnorm0 = -0.5 * LN_2PI - sigma.ln() - 0.5 * z * z;
60 let log_pupper = ln_norm_cdf(z);
62 mu_sf + sigma2 * (log_dnorm0 - log_pupper).exp()
63 })
64 .collect();
65
66 let any_neg = signal.iter().any(|&s| !s.is_nan() && s < 0.0);
68 if any_neg {
69 for s in signal.iter_mut() {
70 if !s.is_nan() {
71 *s = s.max(1e-6);
72 }
73 }
74 }
75 signal
76}
77
78pub fn normexp_fit_saddle(x: &[f64]) -> NormexpFit {
81 let xv: Vec<f64> = x.iter().copied().filter(|v| !v.is_nan()).collect();
82 assert!(
83 xv.len() >= 4,
84 "Not enough data: need at least 4 non-missing corrected intensities"
85 );
86
87 let q = quantile_type7(&xv, &[0.0, 0.05, 0.10, 1.0]);
89 if q[0] == q[3] {
90 return NormexpFit {
91 par: [q[0], f64::NEG_INFINITY, f64::NEG_INFINITY],
92 m2loglik: f64::NAN,
93 };
94 }
95 let mu = if q[1] > q[0] {
96 q[1]
97 } else if q[2] > q[0] {
98 q[2]
99 } else {
100 q[0] + 0.05 * (q[3] - q[0])
101 };
102 let below: Vec<f64> = xv.iter().copied().filter(|&v| v < mu).collect();
103 let sigma2 = below.iter().map(|&v| (v - mu) * (v - mu)).sum::<f64>() / below.len() as f64;
104 let mut alpha = xv.iter().sum::<f64>() / xv.len() as f64 - mu;
105 if alpha <= 0.0 {
106 alpha = 1e-6;
107 }
108 let par0 = [mu, 0.5 * sigma2.ln(), alpha.ln()];
109
110 let (par, m2loglik) = nmmin(
114 &par0,
115 |p| normexp_m2loglik_saddle(p, &xv),
116 -1e308,
117 1.490_116e-08,
118 1.0,
119 0.5,
120 2.0,
121 500,
122 );
123 NormexpFit {
124 par: [par[0], par[1], par[2]],
125 m2loglik,
126 }
127}
128
/// Minus twice the second-order saddle-point approximation to the
/// normal+exponential log-likelihood of `x`.
///
/// `par` is `[mu, ln(sigma), ln(alpha)]`. For each observation the
/// saddle-point `theta` is started from the smaller root of a quadratic
/// (capped by an upper bound) and refined by Newton steps until the step is
/// below `1e-10` or 51 passes have been made.
fn normexp_m2loglik_saddle(par: &[f64], x: &[f64]) -> f64 {
    const STEP_TOL: f64 = 1e-10;
    const MAX_PASSES: u32 = 51;

    let mu = par[0];
    let sigma = par[1].exp();
    let sigma2 = sigma * sigma;
    let alpha = par[2].exp();
    let alpha2 = alpha * alpha;
    let alpha3 = alpha * alpha2;
    let alpha4 = alpha2 * alpha2;
    let c2 = sigma2 * alpha;

    // Solves the saddle-point equation for a single observation. Each
    // observation's Newton iteration is independent of the others.
    let solve_theta = |xi: f64| -> f64 {
        let err = xi - mu;
        let bound = ((err - alpha) / (alpha * err.abs()))
            .max(0.0)
            .min(err / sigma2);
        let c1 = -sigma2 - err * alpha;
        let c0 = -alpha + err;
        let quadratic_root = (-c1 - (c1 * c1 - 4.0 * c0 * c2).sqrt()) / (2.0 * c2);

        let mut theta = quadratic_root.min(bound);
        for pass in 1..=MAX_PASSES {
            let omat = 1.0 - alpha * theta;
            let dk = mu + sigma2 * theta + alpha / omat;
            let ddk = sigma2 + alpha2 / (omat * omat);
            let delta = (xi - dk) / ddk;
            theta += delta;
            // Only the first Newton step is projected back under the bound.
            if pass == 1 {
                theta = theta.min(bound);
            }
            if delta.abs() < STEP_TOL {
                break;
            }
        }
        theta
    };

    let mut loglik = 0.0;
    for &xi in x {
        let theta = solve_theta(xi);
        let omat = 1.0 - alpha * theta;
        let omat2 = omat * omat;
        // Cumulant generating function and its derivatives at theta.
        let k1 = mu * theta + 0.5 * sigma2 * theta * theta - omat.ln();
        let k2 = sigma2 + alpha2 / omat2;
        let k3 = 2.0 * alpha3 / (omat * omat2);
        let k4 = 6.0 * alpha4 / (omat2 * omat2);
        let mut logf = -0.5 * (2.0 * PI * k2).ln() - xi * theta + k1;
        // Second-order correction term.
        logf += k4 / (8.0 * k2 * k2) - (5.0 * k3 * k3) / (24.0 * k2 * k2 * k2);
        loglik += logf;
    }
    -2.0 * loglik
}
199
/// Nelder–Mead simplex minimiser.
///
/// The simplex is stored column-wise in `p`: column `j` holds vertex `j` in
/// rows `0..n` and its objective value in row `n`; the extra last column is
/// scratch space for the centroid / reflected point.
///
/// * `start` – initial parameter vector (`n = start.len()`).
/// * `objective` – function to minimise; non-finite values returned for
///   trial points are replaced by a large constant (`1e35`).
/// * `abstol` – stop once the best value is `<= abstol`.
/// * `intol` – relative tolerance; iteration stops once
///   `worst <= best + intol * (|f(start)| + intol)`.
/// * `refl`, `contract`, `extend` – reflection, contraction and extension
///   coefficients.
/// * `maxit` – budget on the number of objective evaluations.
///
/// Returns the best vertex found and the objective value stored for it.
///
/// If a shrink step fails to reduce the simplex size, the simplex is rebuilt
/// around the current best vertex and the search continues while budget
/// remains.
// The index loops mirror the 1-based bookkeeping of the algorithm (`l`, `h`
// are 1-based vertex numbers), hence the `needless_range_loop` allowance.
#[allow(clippy::too_many_arguments, clippy::needless_range_loop)]
fn nmmin(
    start: &[f64],
    objective: impl Fn(&[f64]) -> f64,
    abstol: f64,
    intol: f64,
    refl: f64,
    contract: f64,
    extend: f64,
    maxit: i32,
) -> (Vec<f64>, f64) {
    const BIG: f64 = 1.0e35;
    let n = start.len();
    let n1 = n + 1;
    let c = n + 2;
    let mut p = vec![vec![0.0_f64; c]; n1];
    let mut bvec = start.to_vec();

    // Objective with non-finite results mapped to BIG so comparisons stay
    // well-defined for trial points.
    let eval = |point: &[f64]| -> f64 {
        let value = objective(point);
        if value.is_finite() {
            value
        } else {
            BIG
        }
    };

    // The starting value is used as-is (no BIG substitution).
    let f0 = objective(&bvec);
    let mut funcount = 1i32;
    let convtol = intol * (f0.abs() + intol);
    p[n1 - 1][0] = f0;
    for i in 0..n {
        p[i][0] = bvec[i];
    }

    // `l` is the 1-based index of the current best vertex.
    let mut l = 1usize;
    let mut size = 0.0;
    // Initial step: 10% of the largest |coordinate|, or 0.1 if all are zero.
    let mut step = 0.0;
    for i in 0..n {
        if 0.1 * bvec[i].abs() > step {
            step = 0.1 * bvec[i].abs();
        }
    }
    if step == 0.0 {
        step = 0.1;
    }

    let mut vl;
    let mut vh;
    let mut shrinkfail = false;

    'outer: loop {
        // Build vertices 2..=n+1 by perturbing one coordinate of `bvec` each.
        // Vertex 1 (column 0) is expected to hold `bvec` itself.
        for j in 2..=n1 {
            for i in 0..n {
                p[i][j - 1] = bvec[i];
            }
            let mut trystep = step;
            // Grow the step until it actually changes the coordinate.
            while p[j - 2][j - 1] == bvec[j - 2] {
                p[j - 2][j - 1] = bvec[j - 2] + trystep;
                trystep *= 10.0;
            }
            size += trystep;
        }
        let mut oldsize = size;
        let mut calcvert = true;

        loop {
            if calcvert {
                // (Re)evaluate every vertex except the best one, whose value
                // is already known.
                for j in 0..n1 {
                    if j + 1 != l {
                        for i in 0..n {
                            bvec[i] = p[i][j];
                        }
                        p[n1 - 1][j] = eval(&bvec);
                        funcount += 1;
                    }
                }
                calcvert = false;
            }

            // Locate the best (`l`, `vl`) and worst (`h`, `vh`) vertices.
            vl = p[n1 - 1][l - 1];
            vh = vl;
            let mut h = l;
            for j in 1..=n1 {
                if j != l {
                    let fj = p[n1 - 1][j - 1];
                    if fj < vl {
                        l = j;
                        vl = fj;
                    }
                    if fj > vh {
                        h = j;
                        vh = fj;
                    }
                }
            }

            if vh > vl + convtol && vl > abstol {
                // Centroid of all vertices except the worst.
                for i in 0..n {
                    let mut temp = -p[i][h - 1];
                    for j in 0..n1 {
                        temp += p[i][j];
                    }
                    p[i][c - 1] = temp / n as f64;
                }
                // Reflection of the worst vertex through the centroid.
                for i in 0..n {
                    bvec[i] = (1.0 + refl) * p[i][c - 1] - refl * p[i][h - 1];
                }
                let vr = eval(&bvec);
                funcount += 1;

                if vr < vl {
                    // Reflection is a new best: try extending further.
                    p[n1 - 1][c - 1] = vr;
                    for i in 0..n {
                        let extended = extend * bvec[i] + (1.0 - extend) * p[i][c - 1];
                        p[i][c - 1] = bvec[i];
                        bvec[i] = extended;
                    }
                    let fe = eval(&bvec);
                    funcount += 1;
                    if fe < vr {
                        for i in 0..n {
                            p[i][h - 1] = bvec[i];
                        }
                        p[n1 - 1][h - 1] = fe;
                    } else {
                        // Keep the reflected point (saved in the scratch column).
                        for i in 0..n {
                            p[i][h - 1] = p[i][c - 1];
                        }
                        p[n1 - 1][h - 1] = vr;
                    }
                } else {
                    if vr < vh {
                        // Reflection beats the worst vertex: accept it before
                        // contracting.
                        for i in 0..n {
                            p[i][h - 1] = bvec[i];
                        }
                        p[n1 - 1][h - 1] = vr;
                    }
                    // Contraction towards the centroid.
                    for i in 0..n {
                        bvec[i] = (1.0 - contract) * p[i][h - 1] + contract * p[i][c - 1];
                    }
                    let fc = eval(&bvec);
                    funcount += 1;
                    if fc < p[n1 - 1][h - 1] {
                        for i in 0..n {
                            p[i][h - 1] = bvec[i];
                        }
                        p[n1 - 1][h - 1] = fc;
                    } else if vr >= vh {
                        // Neither reflection nor contraction helped: shrink
                        // every vertex towards the best one.
                        calcvert = true;
                        size = 0.0;
                        for j in 0..n1 {
                            if j + 1 != l {
                                for i in 0..n {
                                    p[i][j] = contract * (p[i][j] - p[i][l - 1]) + p[i][l - 1];
                                    size += (p[i][j] - p[i][l - 1]).abs();
                                }
                            }
                        }
                        if size < oldsize {
                            shrinkfail = false;
                            oldsize = size;
                        } else {
                            shrinkfail = true;
                        }
                    }
                }
            }

            if !(vh > vl + convtol && vl > abstol && !shrinkfail && funcount <= maxit) {
                break;
            }
        }

        if shrinkfail && funcount <= maxit && vh > vl + convtol && vl > abstol {
            // Restart around the best vertex. The rebuild overwrites columns
            // 1..=n and the re-evaluation skips column `l - 1`, so the best
            // vertex and its value must first be moved into column 0;
            // otherwise a perturbed point would keep the best vertex's stale
            // objective value.
            let best_value = p[n1 - 1][l - 1];
            for i in 0..n {
                bvec[i] = p[i][l - 1];
                p[i][0] = bvec[i];
            }
            p[n1 - 1][0] = best_value;
            l = 1;
            // The rebuilt simplex gets a fresh size measure.
            size = 0.0;
            shrinkfail = false;
            continue 'outer;
        }
        break 'outer;
    }

    let fmin = p[n1 - 1][l - 1];
    let best: Vec<f64> = (0..n).map(|i| p[i][l - 1]).collect();
    (best, fmin)
}
409
410pub fn background_correct_matrix(
414 e: &Array2<f64>,
415 eb: Option<&Array2<f64>>,
416 method: BackgroundMethod,
417 offset: f64,
418) -> Array2<f64> {
419 use BackgroundMethod::*;
420
421 let method = if eb.is_none() {
422 match method {
423 Subtract | Half | Minimum | MovingMin | Edwards => None,
424 other => other,
425 }
426 } else {
427 method
428 };
429
430 let mut out = match method {
431 None => e.clone(),
432 Subtract => e - eb.unwrap(),
433 Half => (e - eb.unwrap()).mapv(|v| v.max(0.5)),
434 Minimum => {
435 let mut m = e - eb.unwrap();
436 for mut col in m.columns_mut() {
437 let lo = col
438 .iter()
439 .copied()
440 .filter(|&v| v >= 1e-18)
441 .fold(f64::INFINITY, f64::min);
442 if lo.is_finite() {
443 for v in col.iter_mut() {
444 if *v < 1e-18 {
445 *v = lo / 2.0;
446 }
447 }
448 }
449 }
450 m
451 }
452 MovingMin => e - &ma3x3_min(eb.unwrap()),
453 Edwards => edwards(e, eb.unwrap()),
454 Normexp => {
455 let eb_sub = match eb {
456 Some(bg) => e - bg,
457 Option::None => e.clone(),
458 };
459 let ncol = eb_sub.ncols();
460 let solve = |j: usize| -> Vec<f64> {
466 let x: Vec<f64> = eb_sub.column(j).to_vec();
467 let fit = normexp_fit_saddle(&x);
468 normexp_signal(&fit.par, &x)
469 };
470 #[cfg(feature = "parallel")]
471 let cols: Vec<Vec<f64>> = {
472 use rayon::prelude::*;
473 (0..ncol).into_par_iter().map(solve).collect()
474 };
475 #[cfg(not(feature = "parallel"))]
476 let cols: Vec<Vec<f64>> = (0..ncol).map(solve).collect();
477 let mut m = eb_sub;
480 for (j, sig) in cols.into_iter().enumerate() {
481 for (i, s) in sig.into_iter().enumerate() {
482 m[[i, j]] = s;
483 }
484 }
485 m
486 }
487 };
488
489 if offset != 0.0 {
490 out.mapv_inplace(|v| v + offset);
491 }
492 out
493}
494
495fn edwards(e: &Array2<f64>, eb: &Array2<f64>) -> Array2<f64> {
497 let sub = e - eb;
498 let mut out = sub.clone();
499 for (j, col) in sub.columns().into_iter().enumerate() {
500 let d: Vec<f64> = col.to_vec();
501 let frac = d.iter().filter(|&&v| v < 1e-16).count() as f64 / d.len() as f64;
502 let prob = (frac * 1.1).min(1.0);
503 let delta = quantile_type7(&d, &[prob])[0];
504 for i in 0..d.len() {
505 let s = sub[[i, j]];
506 out[[i, j]] = if s < delta {
507 delta * (1.0 - (eb[[i, j]] + delta) / e[[i, j]]).exp()
508 } else {
509 s
510 };
511 }
512 }
513 out
514}
515
516fn ma3x3_min(x: &Array2<f64>) -> Array2<f64> {
519 let (nr, nc) = x.dim();
520 let mut out = Array2::<f64>::zeros((nr, nc));
521 for r in 0..nr {
522 for col in 0..nc {
523 let mut m = f64::INFINITY;
524 for dr in -1i64..=1 {
525 for dc in -1i64..=1 {
526 let rr = r as i64 + dr;
527 let cc = col as i64 + dc;
528 if rr >= 0 && rr < nr as i64 && cc >= 0 && cc < nc as i64 {
529 let v = x[[rr as usize, cc as usize]];
530 if !v.is_nan() && v < m {
531 m = v;
532 }
533 }
534 }
535 }
536 out[[r, col]] = m;
537 }
538 }
539 out
540}
541
/// Sample quantiles by linear interpolation between order statistics
/// (position `(n - 1) * p` in the sorted, zero-based sample).
///
/// `NaN` values in `x` are ignored. Returns one value per entry of `probs`;
/// each is `NaN` when no non-`NaN` data remain. `probs` are expected to lie
/// in `[0, 1]`.
///
/// Quantiles that fall exactly on an order statistic, or between two equal
/// order statistics, are returned directly, so infinite data values do not
/// produce `NaN` through `inf - inf` or `0 * inf`.
fn quantile_type7(x: &[f64], probs: &[f64]) -> Vec<f64> {
    let mut s: Vec<f64> = x.iter().copied().filter(|v| !v.is_nan()).collect();
    s.sort_by(|a, b| a.partial_cmp(b).expect("NaN values were filtered out"));
    let n = s.len();
    probs
        .iter()
        .map(|&p| match n {
            0 => f64::NAN,
            1 => s[0],
            _ => {
                let hpos = (n as f64 - 1.0) * p;
                let lo = hpos.floor() as usize;
                let hi = (lo + 1).min(n - 1);
                let frac = hpos - lo as f64;
                let (below, above) = (s[lo], s[hi]);
                if frac == 0.0 || below == above {
                    // No interpolation needed; also avoids NaN from
                    // arithmetic on infinite neighbours.
                    below
                } else if below.is_finite() && above.is_finite() {
                    below + frac * (above - below)
                } else {
                    // Weighted form stays well-defined when exactly one
                    // neighbour is infinite.
                    (1.0 - frac) * below + frac * above
                }
            }
        })
        .collect()
}
563
// Regression tests comparing this module against stored reference values.
// NOTE(review): the `_matches_r` test names suggest the expected numbers were
// generated with R; the generating script is not part of this file.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// 24 sample intensities shared by all tests.
    fn xvec() -> Vec<f64> {
        vec![
            2.1, 3.5, 2.8, 10.2, 5.6, 4.1, 8.9, 3.3, 6.7, 12.5, 2.9, 4.8, 7.2, 3.1, 9.4, 5.0, 4.4,
            6.1, 3.8, 11.0, 2.5, 5.9, 7.7, 4.6,
        ]
    }

    /// 12x2 foreground matrix: column 0 is the first half of `xvec()`,
    /// column 1 the second half.
    fn emat() -> Array2<f64> {
        Array2::from_shape_vec((12, 2), {
            let x = xvec();
            let mut v = vec![0.0; 24];
            // Interleave the two halves: the buffer is filled row by row.
            for i in 0..12 {
                v[i * 2] = x[i];
                v[i * 2 + 1] = x[12 + i];
            }
            v
        })
        .unwrap()
    }

    /// 12x2 background matrix, specified column by column and interleaved
    /// the same way as in `emat()`.
    fn ebmat() -> Array2<f64> {
        let col_major = [
            1.0, 4.0, 2.0, 3.0, 1.5, 4.5, 2.0, 3.5, 1.0, 2.0, 3.0, 5.0, 2.0, 3.5, 1.0, 2.0, 1.0,
            2.0, 4.0, 1.5, 3.0, 1.0, 2.0, 5.0,
        ];
        let mut v = vec![0.0; 24];
        for i in 0..12 {
            v[i * 2] = col_major[i];
            v[i * 2 + 1] = col_major[12 + i];
        }
        Array2::from_shape_vec((12, 2), v).unwrap()
    }

    /// Flattens a matrix column by column so it can be compared with the
    /// expected vectors, which list all of column 0 before column 1.
    fn col_major(m: &Array2<f64>) -> Vec<f64> {
        let (nr, nc) = m.dim();
        let mut v = Vec::with_capacity(nr * nc);
        for j in 0..nc {
            for i in 0..nr {
                v.push(m[[i, j]]);
            }
        }
        v
    }

    /// Element-wise comparison with a mixed tolerance: `tol` absolute plus
    /// `tol` relative to the expected value.
    fn assert_close(a: &[f64], b: &[f64], tol: f64) {
        assert_eq!(a.len(), b.len(), "length mismatch");
        for (k, (&x, &y)) in a.iter().zip(b.iter()).enumerate() {
            assert!(
                (x - y).abs() <= tol + tol * y.abs(),
                "index {k}: got {x}, want {y} (diff {})",
                (x - y).abs()
            );
        }
    }

    // Fitted parameters and objective value for `xvec()` against the
    // reference fit.
    #[test]
    fn normexp_fit_saddle_matches_r() {
        let fit = normexp_fit_saddle(&xvec());
        assert!((fit.par[0] - 2.367_538_459_991_54).abs() < 1e-5);
        assert!((fit.par[1] - -1.030_352_153_905_64).abs() < 1e-5);
        assert!((fit.par[2] - 1.219_722_868_953_23).abs() < 1e-5);
        assert!((fit.m2loglik - 111.423_810_138_321).abs() < 1e-5);
    }

    // Signal conversion at the reference parameters (no fitting involved).
    #[test]
    fn normexp_signal_matches_r() {
        let par = [
            2.367_538_459_991_54,
            -1.030_352_153_905_64,
            1.219_722_868_953_23,
        ];
        let sig = normexp_signal(&par, &xvec());
        let want = [
            0.198163450107075,
            1.09613822556942,
            0.484025139654586,
            7.79484935368556,
            3.19484935368556,
            1.69485115668388,
            6.49484935368557,
            0.901027384489598,
            4.29484935368557,
            10.0948493536856,
            0.554205422821398,
            2.39484935370929,
            4.79484935368557,
            0.716807903958157,
            6.99484935368557,
            2.59484935368604,
            1.99484937706265,
            3.69484935368556,
            1.39491795493823,
            8.59484935368556,
            0.322091491194509,
            3.49484935368557,
            5.29484935368557,
            2.19484935455685,
        ];
        assert_close(&sig, &want, 1e-6);
    }

    // Normexp without a background matrix, with and without an offset; the
    // offset result must equal the zero-offset result shifted by 16.
    #[test]
    fn background_correct_normexp_offsets_match_r() {
        let e = emat();
        let bc0 = background_correct_matrix(&e, None, BackgroundMethod::Normexp, 0.0);
        let want0 = [
            1.99403319184388e-07,
            1.40000019940332,
            0.700000199403319,
            8.10000019940332,
            3.50000019940332,
            2.00000019940332,
            6.80000019940332,
            1.20000019940332,
            4.60000019940332,
            10.4000001994033,
            0.800000199403319,
            2.70000019940332,
            3.38853053495406,
            0.614450546313025,
            5.58655635912811,
            1.43189569191768,
            1.07932597446681,
            2.32320650063882,
            0.821622365173466,
            7.18655615956755,
            0.492789993991968,
            2.142117281233,
            3.88691863546145,
            1.18562095914263,
        ];
        assert_close(&col_major(&bc0), &want0, 1e-6);

        let bc16 = background_correct_matrix(&e, None, BackgroundMethod::Normexp, 16.0);
        let want16: Vec<f64> = want0.iter().map(|v| v + 16.0).collect();
        assert_close(&col_major(&bc16), &want16, 1e-6);
    }

    // Every method that uses the background matrix, one after another, on
    // the same foreground/background pair.
    #[test]
    fn background_correct_eb_methods_match_r() {
        let e = emat();
        let eb = ebmat();

        let sub = background_correct_matrix(&e, Some(&eb), BackgroundMethod::Subtract, 0.0);
        assert_close(
            &col_major(&sub),
            &[
                1.1, -0.5, 0.8, 7.2, 4.1, -0.4, 6.9, -0.2, 5.7, 10.5, -0.1, -0.2, 5.2, -0.4, 8.4,
                3.0, 3.4, 4.1, -0.2, 9.5, -0.5, 4.9, 5.7, -0.4,
            ],
            1e-12,
        );

        let half = background_correct_matrix(&e, Some(&eb), BackgroundMethod::Half, 0.0);
        assert_close(
            &col_major(&half),
            &[
                1.1, 0.5, 0.8, 7.2, 4.1, 0.5, 6.9, 0.5, 5.7, 10.5, 0.5, 0.5, 5.2, 0.5, 8.4, 3.0,
                3.4, 4.1, 0.5, 9.5, 0.5, 4.9, 5.7, 0.5,
            ],
            1e-12,
        );

        let mn = background_correct_matrix(&e, Some(&eb), BackgroundMethod::Minimum, 0.0);
        assert_close(
            &col_major(&mn),
            &[
                1.1, 0.4, 0.8, 7.2, 4.1, 0.4, 6.9, 0.4, 5.7, 10.5, 0.4, 0.4, 5.2, 1.5, 8.4, 3.0,
                3.4, 4.1, 1.5, 9.5, 1.5, 4.9, 5.7, 1.5,
            ],
            1e-12,
        );

        let ed = background_correct_matrix(&e, Some(&eb), BackgroundMethod::Edwards, 0.0);
        assert_close(
            &col_major(&ed),
            &[
                1.1,
                0.558422539017665,
                0.808880852322533,
                7.2,
                4.1,
                0.604489443607101,
                6.9,
                0.597824798774299,
                5.7,
                10.5,
                0.593157962311005,
                0.657982551965439,
                5.2,
                1.00197356183341,
                8.4,
                3.00530848233458,
                3.4,
                4.1,
                1.29360494333044,
                9.5,
                0.73912631360648,
                4.9,
                5.7,
                1.43478914370254,
            ],
            1e-9,
        );

        let mm = background_correct_matrix(&e, Some(&eb), BackgroundMethod::MovingMin, 0.0);
        assert_close(
            &col_major(&mm),
            &[
                1.1, 2.5, 1.8, 9.2, 4.6, 3.1, 7.4, 2.3, 5.7, 11.5, 1.9, 2.8, 6.2, 2.1, 8.4, 4.0,
                3.4, 5.1, 2.3, 10.0, 1.5, 4.9, 6.7, 2.6,
            ],
            1e-12,
        );

        let nx = background_correct_matrix(&e, Some(&eb), BackgroundMethod::Normexp, 0.0);
        let want_nx = [
            1.60000003144471,
            3.1444714387814e-08,
            1.30000003144471,
            7.70000003144471,
            4.60000003144471,
            0.100000031444714,
            7.40000003144471,
            0.300000031444714,
            6.20000003144471,
            11.0000000314447,
            0.400000031444714,
            0.300000031444714,
            5.70004586914821,
            0.100045869148213,
            8.90004586914821,
            3.50004586914821,
            3.90004586914821,
            4.60004586914821,
            0.300045869148213,
            10.0000458691482,
            4.58691482126749e-05,
            5.40004586914821,
            6.20004586914821,
            0.100045869148212,
        ];
        assert_close(&col_major(&nx), &want_nx, 1e-6);
    }
}