1use std::collections::HashMap;
173use std::convert::Infallible;
174use std::sync::{Arc, Mutex, OnceLock};
175
176use crate::estimate::EstimationError;
177use crate::mixture_link::{
178 beta_logistic_inverse_link_jet, component_inverse_link_jet, sas_inverse_link_jet,
179};
180use gam_math::probability::erfcx_nonnegative;
181use gam_math::special::stable_polynomial_times_exp_neg as cloglog_stable_poly_times_exp_neg;
182use gam_problem::types::{
183 InverseLink, LikelihoodScaleMetadata, LikelihoodSpec, LinkComponent, LinkFunction,
184 MixtureLinkState, ResponseFamily, SasLinkState, StandardLink,
185};
186use statrs::function::erf::erfc;
187
/// Number of nodes in the fixed-size Gauss-Hermite rule; it sizes the arrays
/// held by `GaussHermiteRule`.
const N_POINTS: usize = 7;
/// Square root of two, aliased from the standard library.
const SQRT_2: f64 = std::f64::consts::SQRT_2;
/// Largest argument `safe_exp` passes to `exp`; larger inputs are clamped to it.
const QUADRATURE_EXP_LOG_MAX: f64 = 700.0;
192
193#[inline]
196fn safe_exp(x: f64) -> f64 {
197 if x.is_nan() {
198 f64::NAN
199 } else {
200 x.min(QUADRATURE_EXP_LOG_MAX).exp()
201 }
202}
203
204#[inline]
205fn safe_expwith_saturation(x: f64) -> (f64, bool) {
206 (safe_exp(x), x > QUADRATURE_EXP_LOG_MAX)
207}
208
/// Pair of `f64` components; both default to zero via the derived `Default`.
/// NOTE(review): presumably a complex number with real part `re` and imaginary
/// part `im` — no use of the type is visible in this part of the file.
#[derive(Clone, Copy, Debug, Default)]
struct Complex {
    re: f64,
    im: f64,
}
214
/// Holder of lazily computed quadrature rules.
///
/// Each Gauss-Hermite cache is filled at most once on first request; the
/// Clenshaw-Curtis cache is a mutex-protected map keyed by node count.
pub struct QuadratureContext {
    /// Fixed-size rule produced by `compute_gauss_hermite`.
    gh_cache: OnceLock<GaussHermiteRule>,
    /// 15-node rule produced by `compute_gauss_hermite_n(15)`.
    gh15_cache: OnceLock<GaussHermiteRuleDynamic>,
    /// 21-node rule produced by `compute_gauss_hermite_n(21)`.
    gh21_cache: OnceLock<GaussHermiteRuleDynamic>,
    /// 31-node rule produced by `compute_gauss_hermite_n(31)`.
    gh31_cache: OnceLock<GaussHermiteRuleDynamic>,
    /// 51-node rule produced by `compute_gauss_hermite_n(51)`.
    gh51_cache: OnceLock<GaussHermiteRuleDynamic>,
    /// Clenshaw-Curtis rules keyed by their node count, shared through `Arc`.
    cc_cache: Mutex<HashMap<usize, Arc<ClenshawCurtisRule>>>,
}
227
/// Label describing which evaluation route produced an integrated expectation.
///
/// The explicit discriminants are the values returned by [`Self::rank`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IntegratedExpectationMode {
    /// Result came from a closed-form expression.
    ExactClosedForm = 0,
    /// Result came from a special-function or series representation.
    ExactSpecialFunction = 1,
    /// Result came from an asymptotic or truncated expansion.
    ControlledAsymptotic = 2,
    /// Result came from numerical quadrature.
    QuadratureFallback = 3,
}

impl IntegratedExpectationMode {
    /// Numeric rank of the mode: `0` for `ExactClosedForm` up to `3` for
    /// `QuadratureFallback`.
    #[inline]
    pub const fn rank(self) -> u8 {
        // The discriminants above are pinned to the rank values.
        self as u8
    }
}
250
/// Integrated mean together with its derivative in the location parameter and
/// the evaluation route that produced them.
#[derive(Clone, Copy, Debug)]
pub struct IntegratedMeanDerivative {
    /// Expected value of the inverse link under the normal perturbation.
    pub mean: f64,
    /// Derivative of `mean` with respect to the location `mu`.
    pub dmean_dmu: f64,
    /// Route used to compute the two values above.
    pub mode: IntegratedExpectationMode,
}
257
/// Integrated inverse-link value with three derivative slots and the
/// evaluation route.
/// NOTE(review): `d1`..`d3` are presumably the first to third derivatives of
/// `mean` in the location parameter — no constructor is visible here; confirm.
#[derive(Clone, Copy, Debug)]
pub struct IntegratedInverseLinkJet {
    pub mean: f64,
    pub d1: f64,
    pub d2: f64,
    pub d3: f64,
    /// Route used to compute the values above.
    pub mode: IntegratedExpectationMode,
}
266
/// Crate-internal variant of the integrated jet with five derivative slots.
/// NOTE(review): `d1`..`d5` are presumably the first to fifth derivatives of
/// `mean` in the location parameter — no constructor is visible here; confirm.
#[derive(Clone, Copy, Debug)]
pub(crate) struct IntegratedInverseLinkJet5 {
    pub mean: f64,
    pub d1: f64,
    pub d2: f64,
    pub d3: f64,
    pub d4: f64,
    pub d5: f64,
    /// Route used to compute the values above.
    pub mode: IntegratedExpectationMode,
}
277
278#[inline]
279pub(crate) fn validate_latent_cloglog_inputs(eta: f64, sigma: f64) -> Result<(), EstimationError> {
280 if !eta.is_finite() || !sigma.is_finite() || sigma < 0.0 {
281 crate::bail_invalid_estim!(
282 "latent cloglog jet requires finite eta and sigma >= 0, got eta={eta}, sigma={sigma}"
283 );
284 }
285 Ok::<(), _>(())
286}
287
/// Integrated mean and variance with three derivative slots and the
/// evaluation route.
/// NOTE(review): `d1`..`d3` are presumably derivatives of `mean` in the
/// location parameter — no constructor is visible here; confirm.
#[derive(Clone, Copy, Debug)]
pub struct IntegratedMomentsJet {
    pub mean: f64,
    pub variance: f64,
    pub d1: f64,
    pub d2: f64,
    pub d3: f64,
    /// Route used to compute the values above.
    pub mode: IntegratedExpectationMode,
}
301
// ---- Logit regime thresholds ----
/// At or below this sigma the logit expectation is evaluated at the point `mu`.
const LOGIT_SIGMA_DEGENERATE: f64 = 1e-10;
/// Below this sigma the small-sigma Taylor expansion is used; it is also the
/// lower sigma bound for the erfcx series.
const LOGIT_SIGMA_TAYLOR_MAX: f64 = 2.5e-1;
/// Upper bound on the log tail mass for which the logit tail asymptotic is used.
const LOGIT_TAIL_LOG_MAX: f64 = -18.0;
/// Largest `|mu|` for which the erfcx series is attempted.
const LOGIT_ERFCX_MU_MAX: f64 = 40.0;
/// Largest sigma for which the erfcx series is attempted.
const LOGIT_ERFCX_SIGMA_MAX: f64 = 6.0;
// NOTE(review): not referenced in this part of the file.
const LOGIT_JET_GHQ_SIGMA_MAX: f64 = 1.0;
// ---- Cloglog regime thresholds ----
/// At or below this sigma the cloglog quantities are evaluated at the point `mu`.
const CLOGLOG_SIGMA_DEGENERATE: f64 = 1e-10;
/// Below this sigma the cloglog small-sigma Taylor expansion is used.
const CLOGLOG_SIGMA_TAYLOR_MAX: f64 = 0.25;
// NOTE(review): not referenced in this part of the file.
const CLOGLOG_JET_MOMENT_SIGMA_MAX: f64 = 1.0;
/// Upper bound on `mu + sigma^2 / 2` for which the rare-event approximation
/// `exp(mu + sigma^2 / 2)` is used.
const CLOGLOG_RARE_EVENT_LOG_MAX: f64 = -18.0;
/// From this sigma upward the log-survival Gumbel quadrature is used.
const CLOGLOG_LARGE_SIGMA_ASYMPTOTIC_MIN: f64 = 8.0;
/// The mean is treated as saturated at 1 when
/// `mu - CLOGLOG_POSITIVE_SATURATION_SIGMAS * sigma` reaches this edge.
const CLOGLOG_POSITIVE_SATURATION_EDGE: f64 = 5.0;
/// Number of sigmas subtracted from `mu` in the saturation test.
const CLOGLOG_POSITIVE_SATURATION_SIGMAS: f64 = 8.0;
/// Lower integration limit of the Gumbel quadrature.
const CLOGLOG_GUMBEL_QUAD_ETA_LO: f64 = -40.0;
/// Upper integration limit of the Gumbel quadrature.
const CLOGLOG_GUMBEL_QUAD_ETA_HI: f64 = 6.0;
/// Minimum node count of the Gumbel quadrature rule.
const CLOGLOG_GUMBEL_QUAD_MIN_NODES: usize = 97;
/// Node-count target is this scale divided by the (capped) sigma.
const CLOGLOG_GUMBEL_QUAD_NODE_SCALE: f64 = 320.0;
/// Maximum node count of the Gumbel quadrature rule.
const CLOGLOG_GUMBEL_QUAD_MAX_NODES: usize = 513;
// NOTE(review): not referenced in this part of the file.
const SERIES_CONSECUTIVE_SMALL_TERMS: usize = 6;
/// Largest number of erfcx series terms accepted by the cutoff estimate.
const LOGIT_MAX_TERMS: usize = 160;
/// Accuracy target handed to the erfcx series cutoff estimate.
const LOGIT_ERFCX_ACCURACY_TARGET: f64 = 1.0e-11;
// NOTE(review): the three MILES constants are not referenced in this part of
// the file.
const CLOGLOG_MILES_ALPHA: f64 = 60.0;
const CLOGLOG_MILES_MAX_TERMS: usize = 256;
const CLOGLOG_MILES_PEAK_LOG_MAX: f64 = 0.0;
// NOTE(review): CLOGLOG_GAMMA_K_REF is not referenced in this part of the
// file; T_MAX_REF and H_REF only feed CLOGLOG_GAMMA_SAMPLE_COUNT below.
const CLOGLOG_GAMMA_K_REF: f64 = 0.5;
const CLOGLOG_GAMMA_T_MAX_REF: f64 = 24.0;
const CLOGLOG_GAMMA_H_REF: f64 = 0.01;
/// Tolerance passed to the Clenshaw-Curtis cloglog backend.
const CLOGLOG_CC_TOL: f64 = 1e-12;
// NOTE(review): not referenced in this part of the file.
const CLOGLOG_CC_NODE_CAP: usize = 1025;
/// `T_MAX_REF / H_REF` truncated to an integer, plus one.
const CLOGLOG_GAMMA_SAMPLE_COUNT: usize =
    (CLOGLOG_GAMMA_T_MAX_REF / CLOGLOG_GAMMA_H_REF) as usize + 1;
/// The Clenshaw-Curtis backend is preferred when its required node count does
/// not exceed this threshold (one third of the gamma sample count).
const CLOGLOG_CC_PREFER_THRESHOLD: usize = CLOGLOG_GAMMA_SAMPLE_COUNT / 3;
/// Minimum node count returned by `cloglog_cc_required_nodes`.
const CLOGLOG_CC_MIN_N: usize = 17;
422
423impl QuadratureContext {
424 pub fn new() -> Self {
425 Self {
426 gh_cache: OnceLock::new(),
427 gh15_cache: OnceLock::new(),
428 gh21_cache: OnceLock::new(),
429 gh31_cache: OnceLock::new(),
430 gh51_cache: OnceLock::new(),
431 cc_cache: Mutex::new(HashMap::new()),
432 }
433 }
434
435 fn gauss_hermite(&self) -> &GaussHermiteRule {
436 self.gh_cache.get_or_init(compute_gauss_hermite)
437 }
438
439 fn gauss_hermite_n(&self, n: usize) -> &GaussHermiteRuleDynamic {
440 match n {
441 7 => self.gh15_cache.get_or_init(|| compute_gauss_hermite_n(15)),
444 15 => self.gh15_cache.get_or_init(|| compute_gauss_hermite_n(15)),
445 21 => self.gh21_cache.get_or_init(|| compute_gauss_hermite_n(21)),
446 31 => self.gh31_cache.get_or_init(|| compute_gauss_hermite_n(31)),
447 51 => self.gh51_cache.get_or_init(|| compute_gauss_hermite_n(51)),
448 _ => self.gh21_cache.get_or_init(|| compute_gauss_hermite_n(21)),
449 }
450 }
451
452 fn clenshaw_curtis_n(&self, n: usize) -> Arc<ClenshawCurtisRule> {
453 let mut cache = match self.cc_cache.lock() {
454 Ok(guard) => guard,
455 Err(poisoned) => poisoned.into_inner(),
456 };
457 cache
458 .entry(n)
459 .or_insert_with(|| Arc::new(compute_clenshaw_curtis_n(n)))
460 .clone()
461 }
462}
463
/// `Default` delegates to [`QuadratureContext::new`], i.e. all caches empty.
impl Default for QuadratureContext {
    fn default() -> Self {
        Self::new()
    }
}
469
/// Fixed-size Gauss-Hermite rule with `N_POINTS` nodes.
struct GaussHermiteRule {
    /// Quadrature nodes, stored in ascending order by `compute_gauss_hermite`.
    nodes: [f64; N_POINTS],
    /// Weight paired with each node (same ordering as `nodes`).
    weights: [f64; N_POINTS],
}
477
/// Gauss-Hermite rule of run-time size, as built by `compute_gauss_hermite_n`.
pub(crate) struct GaussHermiteRuleDynamic {
    /// Quadrature nodes, sorted ascending by the constructor.
    pub(crate) nodes: Vec<f64>,
    /// Weight paired with each node (same ordering as `nodes`).
    pub(crate) weights: Vec<f64>,
}
482
/// Clenshaw-Curtis rule on `[-1, 1]`, as built by `compute_clenshaw_curtis_n`.
#[derive(Clone)]
struct ClenshawCurtisRule {
    /// Nodes `cos(pi * j / (n - 1))`, running from `1` down to `-1`.
    nodes: Vec<f64>,
    /// Weight paired with each node; the constructor rescales them to sum to 2.
    weights: Vec<f64>,
}
488
489fn compute_clenshaw_curtis_n(n: usize) -> ClenshawCurtisRule {
490 assert!(
491 n >= 2,
492 "Clenshaw-Curtis rule requires at least two nodes: n={n}"
493 );
494 let m = n - 1;
509 let theta: Vec<f64> = (0..=m)
510 .map(|j| std::f64::consts::PI * (j as f64) / (m as f64))
511 .collect();
512 let nodes: Vec<f64> = theta.iter().map(|&th| th.cos()).collect();
513
514 if n == 2 {
515 return ClenshawCurtisRule {
516 nodes,
517 weights: vec![1.0, 1.0],
518 };
519 }
520
521 let mut weights = vec![0.0_f64; n];
522 let mut v = vec![1.0_f64; m - 1];
523
524 if m.is_multiple_of(2) {
525 let w0 = 1.0 / ((m * m - 1) as f64);
526 weights[0] = w0;
527 weights[m] = w0;
528 for k in 1..(m / 2) {
529 let denom = (4 * k * k - 1) as f64;
530 for j in 1..m {
531 v[j - 1] -= 2.0 * (2.0 * (k as f64) * theta[j]).cos() / denom;
532 }
533 }
534 for j in 1..m {
535 v[j - 1] -= ((m as f64) * theta[j]).cos() / ((m * m - 1) as f64);
536 }
537 } else {
538 let w0 = 1.0 / ((m * m) as f64);
539 weights[0] = w0;
540 weights[m] = w0;
541 for k in 1..=((m - 1) / 2) {
542 let denom = (4 * k * k - 1) as f64;
543 for j in 1..m {
544 v[j - 1] -= 2.0 * (2.0 * (k as f64) * theta[j]).cos() / denom;
545 }
546 }
547 }
548
549 for j in 1..m {
550 weights[j] = 2.0 * v[j - 1] / (m as f64);
551 }
552
553 for j in 0..=(m / 2) {
557 let jj = m - j;
558 let avg = 0.5 * (weights[j] + weights[jj]);
559 weights[j] = avg;
560 weights[jj] = avg;
561 }
562 let weight_sum: f64 = weights.iter().sum();
563 if weight_sum.is_finite() && weight_sum != 0.0 {
564 let scale = 2.0 / weight_sum;
565 for w in &mut weights {
566 *w *= scale;
567 }
568 }
569
570 ClenshawCurtisRule { nodes, weights }
571}
572
573fn cloglog_cc_required_nodes(mu: f64, sigma: f64, tol: f64) -> Result<usize, EstimationError> {
574 if !(mu.is_finite() && sigma.is_finite() && sigma > 0.0 && tol.is_finite() && tol > 0.0) {
575 crate::bail_invalid_estim!(
576 "CC cloglog backend requires finite mu, positive sigma, and positive tolerance"
577 .to_string(),
578 );
579 }
580
581 let p_tail = (tol / 8.0).clamp(1e-300, 0.25);
586 let a = gam_math::probability::standard_normal_quantile(p_tail)
587 .map(|z| -z)
588 .unwrap_or(8.0)
589 .max(1.0);
590
591 let ay = a * sigma;
592 let y = if ay > 0.0 {
593 1.0_f64.min(std::f64::consts::PI / (4.0 * ay))
594 } else {
595 1.0
596 };
597 let rho = y + (1.0 + y * y).sqrt();
598 let m_s = (0.5 * (a * y) * (a * y)).exp() / (2.0 * std::f64::consts::PI).sqrt();
599 let eps_quad = (tol / 4.0).max(1e-300);
600 let numer = ((8.0 * a * m_s) / ((rho - 1.0).max(1e-12) * eps_quad)).max(1.0);
601 let denom = rho.ln();
602 if !denom.is_finite() || denom <= 0.0 {
603 crate::bail_invalid_estim!("CC cloglog backend ellipse bound became degenerate");
604 }
605
606 let mut n = (1.0 + numer.ln() / denom).ceil() as usize;
607 n = n.max(CLOGLOG_CC_MIN_N);
608 if n.is_multiple_of(2) {
609 n += 1;
610 }
611 Ok(n)
612}
613
614#[inline]
615fn cloglog_should_prefer_cc(mu: f64, sigma: f64, tol: f64) -> bool {
616 match cloglog_cc_required_nodes(mu, sigma, tol) {
622 Ok(n) => n <= CLOGLOG_CC_PREFER_THRESHOLD,
623 Err(_) => false,
624 }
625}
626
627fn compute_gauss_hermite() -> GaussHermiteRule {
640 let mut diag = [0.0f64; N_POINTS]; let mut off_diag = [0.0f64; N_POINTS - 1];
646
647 for i in 0..(N_POINTS - 1) {
648 off_diag[i] = (((i + 1) as f64) / 2.0).sqrt();
650 }
651
652 let (eigenvalues, eigenvectors) = symmetric_tridiagonal_eigen(&mut diag, &mut off_diag);
655
656 let nodes = eigenvalues;
658 let mut weights = [0.0f64; N_POINTS];
659
660 let mu0 = std::f64::consts::PI.sqrt();
666 for i in 0..N_POINTS {
667 let v0 = eigenvectors[i][0];
668 weights[i] = mu0 * v0 * v0;
669 }
670
671 let mut indices: [usize; N_POINTS] = [0, 1, 2, 3, 4, 5, 6];
673 indices.sort_by(|&a, &b| nodes[a].total_cmp(&nodes[b]));
674
675 let sorted_nodes: [f64; N_POINTS] = std::array::from_fn(|i| nodes[indices[i]]);
676 let sortedweights: [f64; N_POINTS] = std::array::from_fn(|i| weights[indices[i]]);
677
678 GaussHermiteRule {
679 nodes: sorted_nodes,
680 weights: sortedweights,
681 }
682}
683
684pub(crate) fn compute_gauss_hermite_n(n: usize) -> GaussHermiteRuleDynamic {
685 let mut diag = vec![0.0f64; n];
686 let mut off_diag = vec![0.0f64; n.saturating_sub(1)];
687 for (i, od) in off_diag.iter_mut().enumerate() {
688 *od = (((i + 1) as f64) / 2.0).sqrt();
689 }
690 let (nodes, eigenvectors) = symmetric_tridiagonal_eigen_dynamic(&mut diag, &mut off_diag);
691 let mu0 = std::f64::consts::PI.sqrt();
692 let mut pairs = (0..n)
693 .map(|i| {
694 let v0 = eigenvectors[i][0];
695 (nodes[i], mu0 * v0 * v0)
696 })
697 .collect::<Vec<_>>();
698 pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
699 GaussHermiteRuleDynamic {
700 nodes: pairs.iter().map(|p| p.0).collect(),
701 weights: pairs.iter().map(|p| p.1).collect(),
702 }
703}
704
705fn symmetric_tridiagonal_eigen(
709 diag: &mut [f64; N_POINTS],
710 off_diag: &mut [f64; N_POINTS - 1],
711) -> ([f64; N_POINTS], [[f64; N_POINTS]; N_POINTS]) {
712 let mut diag_vec = diag.to_vec();
713 let mut off_diag_vec = off_diag.to_vec();
714 let (eigenvalues, eigenvectors) =
715 symmetric_tridiagonal_eigen_dynamic(&mut diag_vec, &mut off_diag_vec);
716
717 let mut values = [0.0; N_POINTS];
718 let mut vectors = [[0.0; N_POINTS]; N_POINTS];
719 values.copy_from_slice(&eigenvalues);
720 for i in 0..N_POINTS {
721 vectors[i].copy_from_slice(&eigenvectors[i]);
722 }
723 diag.copy_from_slice(&values);
724 off_diag.copy_from_slice(&off_diag_vec);
725 (values, vectors)
726}
727
/// Eigen-decomposition of a symmetric tridiagonal matrix by shifted plane
/// rotations with deflation.
///
/// `diag` holds the diagonal and `off_diag` the `diag.len() - 1` off-diagonal
/// entries; both are overwritten in place. Returns the final diagonal (the
/// eigenvalue estimates) and the accumulated rotation matrix `z`, whose row `i`
/// is paired with eigenvalue `i` by the callers in this file.
///
/// If a trailing eigenvalue does not deflate within `MAX_QL_SWEEPS` sweeps, its
/// off-diagonal is forced to zero and the current diagonal value is accepted.
fn symmetric_tridiagonal_eigen_dynamic(
    diag: &mut [f64],
    off_diag: &mut [f64],
) -> (Vec<f64>, Vec<Vec<f64>>) {
    let dim = diag.len();
    // Start the accumulated rotations from the identity.
    let mut z = vec![vec![0.0_f64; dim]; dim];
    for (i, row) in z.iter_mut().enumerate().take(dim) {
        row[i] = 1.0;
    }
    // Relative threshold below which an off-diagonal entry is treated as zero.
    const DEFLATION_TOL: f64 = 1e-15;
    // Sweep budget per eigenvalue before deflation is forced.
    const MAX_QL_SWEEPS: usize = 200;
    let eps = DEFLATION_TOL;
    let max_iter = MAX_QL_SWEEPS;
    // Largest absolute row sum of the input matrix; used as a floor for the
    // deflation scale so a zero diagonal does not make the test too strict.
    let mut t_norm = 0.0_f64;
    for i in 0..dim {
        let left = if i > 0 { off_diag[i - 1].abs() } else { 0.0 };
        let right = if i + 1 < dim { off_diag[i].abs() } else { 0.0 };
        let row_sum = diag[i].abs() + left + right;
        if row_sum > t_norm {
            t_norm = row_sum;
        }
    }
    // `n` is the size of the still-active leading block.
    let mut n = dim;
    while n > 1 {
        let mut converged = false;
        for _ in 0..max_iter {
            // Scan upward from the bottom for the first negligible
            // off-diagonal; the active unreduced block is rows m..n.
            let mut m = n - 1;
            while m > 0 {
                let row_scale = (diag[m - 1].abs() + diag[m].abs()).max(t_norm);
                if off_diag[m - 1].abs() <= eps * row_scale {
                    off_diag[m - 1] = 0.0;
                    break;
                }
                m -= 1;
            }
            // The last row has decoupled: its diagonal entry is an eigenvalue.
            if m == n - 1 {
                n -= 1;
                converged = true;
                break;
            }
            // Shift taken from the trailing 2x2 block of the active matrix.
            let shift = wilkinson_shift(diag[n - 2], diag[n - 1], off_diag[n - 2]);
            let mut x = diag[m] - shift;
            let mut y = off_diag[m];
            // Chase the bulge down the block with one rotation per row pair.
            for k in m..(n - 1) {
                // Rotation that zeroes `y` against `x`; identity when `y` is
                // negligible or the norm is unusable.
                let (c, s) = if y.abs() > eps {
                    let r = x.hypot(y);
                    if r > 0.0 && r.is_finite() {
                        (x / r, -y / r)
                    } else {
                        (1.0, 0.0)
                    }
                } else {
                    (1.0, 0.0)
                };
                if k > m {
                    off_diag[k - 1] = x.hypot(y);
                }
                // Apply the rotation as a similarity transform to the 2x2
                // block at rows/columns k and k + 1.
                let d1 = diag[k];
                let d2 = diag[k + 1];
                let e_k = off_diag[k];
                diag[k] = c * c * d1 + s * s * d2 - 2.0 * c * s * e_k;
                diag[k + 1] = s * s * d1 + c * c * d2 + 2.0 * c * s * e_k;
                off_diag[k] = c * s * (d1 - d2) + (c * c - s * s) * e_k;
                // The rotation creates a bulge next to off_diag[k + 1]; carry
                // it into the next step as (x, y).
                if k < n - 2 {
                    x = off_diag[k];
                    y = -s * off_diag[k + 1];
                    off_diag[k + 1] *= c;
                }
                // Accumulate the same rotation into rows k and k + 1 of z.
                for i in 0..dim {
                    let t = z[k][i];
                    z[k][i] = c * t - s * z[k + 1][i];
                    z[k + 1][i] = s * t + c * z[k + 1][i];
                }
            }
        }
        // Sweep budget exhausted: force deflation of the last row.
        if !converged {
            off_diag[n - 2] = 0.0;
            n -= 1;
        }
    }
    (diag.to_vec(), z)
}
820
/// Shift for the symmetric 2x2 block `[[a, b], [b, c]]`.
///
/// With `d = (a - c) / 2` and `t = hypot(d, b)`, returns
/// `c - b^2 / (d + sign(d) * t)`. When that denominator is negligible relative
/// to `max(t, 1)` the fallback `c - t` is returned instead.
#[inline]
fn wilkinson_shift(a: f64, c: f64, b: f64) -> f64 {
    let half_gap = 0.5 * (a - c);
    let radius = half_gap.hypot(b);
    // Give the radius the sign of the half gap so the sum cannot cancel.
    let signed_radius = if half_gap >= 0.0 { radius } else { -radius };
    let denom = half_gap + signed_radius;

    let usable = denom.abs() > f64::EPSILON * radius.max(1.0);
    if usable {
        c - (b * b) / denom
    } else {
        c - radius
    }
}
835
836#[inline]
847pub fn logit_posterior_mean(ctx: &QuadratureContext, eta: f64, se_eta: f64) -> f64 {
848 match logit_posterior_meanwith_deriv_controlled(eta, se_eta) {
849 Ok(out) => out.mean,
850 Err(_) => integrate_normal_ghq_adaptive(ctx, eta, se_eta, sigmoid),
851 }
852}
853
854#[inline]
862pub fn logit_posterior_meanwith_deriv(
863 eta: f64,
864 se_eta: f64,
865) -> Result<(f64, f64), EstimationError> {
866 let out = logit_posterior_meanwith_deriv_controlled(eta, se_eta)?;
877 Ok((out.mean, out.dmean_dmu))
878}
879
880#[inline]
881pub fn probit_posterior_meanwith_deriv_exact(mu: f64, sigma: f64) -> IntegratedMeanDerivative {
882 if !(mu.is_finite() && sigma.is_finite()) || sigma <= 1e-12 {
910 let mean = gam_math::probability::normal_cdf(mu);
911 let dmean_dmu = gam_math::probability::normal_pdf(mu);
912 return IntegratedMeanDerivative {
913 mean,
914 dmean_dmu,
915 mode: IntegratedExpectationMode::ExactClosedForm,
916 };
917 }
918 let denom = (1.0 + sigma * sigma).sqrt();
919 let z = mu / denom;
920 IntegratedMeanDerivative {
921 mean: gam_math::probability::normal_cdf(z),
922 dmean_dmu: gam_math::probability::normal_pdf(z) / denom,
923 mode: IntegratedExpectationMode::ExactClosedForm,
924 }
925}
926
927#[inline]
928fn logistic_normal_exact_eligible(mu: f64, sigma: f64) -> bool {
929 mu.is_finite()
930 && sigma.is_finite()
931 && mu.abs() <= LOGIT_ERFCX_MU_MAX
932 && (LOGIT_SIGMA_TAYLOR_MAX..=LOGIT_ERFCX_SIGMA_MAX).contains(&sigma)
933}
934
935#[inline]
979fn logistic_normal_series_cutoff(mu: f64, sigma: f64, target_accuracy: f64) -> Option<usize> {
980 assert!(sigma > 0.0);
981 assert!(target_accuracy > 0.0);
982 let m = mu.abs();
983 let s = sigma;
984 let gauss = (-(m * m) / (2.0 * s * s)).exp();
985 let coeff_mean = m * (2.0_f64 / std::f64::consts::PI).sqrt() * gauss / (s * s * s);
986 let coeff_deriv =
987 2.0 * gauss * (m * m - s * s).abs() / ((2.0 * std::f64::consts::PI).sqrt() * s.powi(5));
988 let asymptotic_index = |coeff: f64| -> f64 {
992 if !coeff.is_finite() || coeff <= target_accuracy {
993 0.0
994 } else {
995 (coeff / target_accuracy).sqrt() - 1.0
996 }
997 };
998 let peak_floor = m / (s * s) + 1.0;
1002 let required = asymptotic_index(coeff_mean)
1003 .max(asymptotic_index(coeff_deriv))
1004 .max(peak_floor);
1005 if !required.is_finite() || required > LOGIT_MAX_TERMS as f64 {
1006 return None;
1007 }
1008 Some((required.ceil() as usize).max(4))
1011}
1012
1013#[inline]
1014fn stable_sigmoidwith_derivative(x: f64) -> (f64, f64) {
1015 let x_clamped = x.clamp(-QUADRATURE_EXP_LOG_MAX, QUADRATURE_EXP_LOG_MAX);
1016 if x_clamped != x {
1017 return (sigmoid(x), 0.0);
1018 }
1019 if x_clamped >= 0.0 {
1020 let z = (-x_clamped).exp();
1021 let denom = 1.0 + z;
1022 (1.0 / denom, z / (denom * denom))
1023 } else {
1024 let z = x_clamped.exp();
1025 let denom = 1.0 + z;
1026 (z / denom, z / (denom * denom))
1027 }
1028}
1029
1030#[inline]
1031fn logit_small_sigma_taylor(mu: f64, sigma: f64) -> IntegratedMeanDerivative {
1032 let (mean0, d1, d2, d3) = component_point_jet(LinkComponent::Logit, mu);
1040 let s2 = sigma * sigma;
1041 IntegratedMeanDerivative {
1042 mean: (mean0 + 0.5 * s2 * d2).clamp(0.0, 1.0),
1043 dmean_dmu: (d1 + 0.5 * s2 * d3).max(0.0),
1044 mode: IntegratedExpectationMode::ControlledAsymptotic,
1045 }
1046}
1047
1048#[inline]
1049fn logit_tail_asymptotic(mu: f64, sigma: f64) -> Option<IntegratedMeanDerivative> {
1050 if mu <= 0.0 {
1055 let log_mean = mu + 0.5 * sigma * sigma;
1056 if log_mean <= LOGIT_TAIL_LOG_MAX {
1057 let mean = safe_exp(log_mean);
1058 return Some(IntegratedMeanDerivative {
1059 mean,
1060 dmean_dmu: mean,
1061 mode: IntegratedExpectationMode::ControlledAsymptotic,
1062 });
1063 }
1064 } else {
1065 let log_tail = -mu + 0.5 * sigma * sigma;
1066 if log_tail <= LOGIT_TAIL_LOG_MAX {
1067 let tail = safe_exp(log_tail);
1068 return Some(IntegratedMeanDerivative {
1069 mean: 1.0 - tail,
1070 dmean_dmu: tail,
1071 mode: IntegratedExpectationMode::ControlledAsymptotic,
1072 });
1073 }
1074 }
1075 None
1076}
1077
1078#[inline]
1079fn scaled_erfcx_termwith_derivative(m: f64, s: f64, x: f64, dxdm: f64) -> (f64, f64) {
1080 let pref = 0.5 * (-(m * m) / (2.0 * s * s)).exp();
1081 if x >= 0.0 {
1082 let ex = erfcx_nonnegative(x);
1083 let term = pref * ex;
1084 let ex_prime = 2.0 * x * ex - std::f64::consts::FRAC_2_SQRT_PI;
1085 let dterm = pref * ((-m / (s * s)) * ex + ex_prime * dxdm);
1086 (term, dterm)
1087 } else {
1088 let lead = (x * x - (m * m) / (2.0 * s * s)).exp();
1089 let dlead = lead * (2.0 * x * dxdm - m / (s * s));
1090 let (rest, drest) = scaled_erfcx_termwith_derivative(m, s, -x, -dxdm);
1091 (lead - rest, dlead - drest)
1092 }
1093}
1094
1095pub(crate) fn logit_posterior_meanwith_deriv_exact(
1096 mu: f64,
1097 sigma: f64,
1098) -> Result<IntegratedMeanDerivative, EstimationError> {
1099 if !(mu.is_finite() && sigma.is_finite()) {
1117 crate::bail_invalid_estim!("logit exact expectation requires finite mu and sigma");
1118 }
1119 if sigma <= LOGIT_SIGMA_DEGENERATE {
1120 let (mean, dmean_dmu) = stable_sigmoidwith_derivative(mu);
1121 return Ok(IntegratedMeanDerivative {
1122 mean,
1123 dmean_dmu,
1124 mode: IntegratedExpectationMode::ExactClosedForm,
1125 });
1126 }
1127 if let Some(out) = logit_tail_asymptotic(mu, sigma) {
1128 return Ok(out);
1129 }
1130 if sigma < LOGIT_SIGMA_TAYLOR_MAX {
1131 return Ok(logit_small_sigma_taylor(mu, sigma));
1132 }
1133 if logistic_normal_exact_eligible(mu, sigma)
1134 && let Ok(out) = logit_posterior_meanwith_deriv_exact_erfcx(mu, sigma)
1135 {
1136 return Ok(out);
1137 }
1138 Err(EstimationError::InvalidInput(
1147 "logit analytic expectation has no certified representation in this regime".to_string(),
1148 ))
1149}
1150
1151fn logit_posterior_meanwith_deriv_exact_erfcx(
1152 mu: f64,
1153 sigma: f64,
1154) -> Result<IntegratedMeanDerivative, EstimationError> {
1155 let m = mu.abs();
1181 let s = sigma;
1182 let z = SQRT_2 * s;
1183 let phi_term = gam_math::probability::normal_cdf(m / s);
1184 let phi_prime = gam_math::probability::normal_pdf(m / s) / s;
1185 let Some(max_k) = logistic_normal_series_cutoff(mu, sigma, LOGIT_ERFCX_ACCURACY_TARGET) else {
1186 crate::bail_invalid_estim!(
1187 "logit erfcx series truncation bound exceeds LOGIT_MAX_TERMS at the required accuracy"
1188 .to_string(),
1189 );
1190 };
1191
1192 let mut sum = 0.0_f64;
1193 let mut dsum = 0.0_f64;
1194 let mut k = 1usize;
1202 while k <= max_k {
1203 for kk in [k, k + 1].into_iter().filter(|kk| *kk <= max_k) {
1204 let kf = kk as f64;
1205 let a = (kf * s * s + m) / z;
1206 let b = (kf * s * s - m) / z;
1207 let sign = if kk % 2 == 1 { 1.0 } else { -1.0 };
1208 let (va, dva) = scaled_erfcx_termwith_derivative(m, s, a, 1.0 / z);
1209 let (vb, dvb) = scaled_erfcx_termwith_derivative(m, s, b, -1.0 / z);
1210 sum += sign * (va - vb);
1211 dsum += sign * (dva - dvb);
1212 }
1213 k += 2;
1214 }
1215
1216 let mut mean = phi_term + sum;
1217 let dmean = (phi_prime + dsum).max(0.0);
1218 if mu < 0.0 {
1219 mean = 1.0 - mean;
1220 }
1221 if !(mean.is_finite() && dmean.is_finite() && dmean >= 0.0) {
1222 crate::bail_invalid_estim!("logit erfcx expectation produced non-finite values");
1223 }
1224 Ok(IntegratedMeanDerivative {
1225 mean,
1226 dmean_dmu: dmean,
1227 mode: IntegratedExpectationMode::ExactSpecialFunction,
1228 })
1229}
1230
1231#[inline]
1236fn logit_posterior_meanwith_deriv_quadrature(mu: f64, sigma: f64) -> IntegratedMeanDerivative {
1237 let mean = integrate_normal_adaptive(mu, sigma, |x| stable_sigmoidwith_derivative(x).0);
1238 let dmean_dmu =
1239 integrate_normal_adaptive(mu, sigma, |x| stable_sigmoidwith_derivative(x).1).max(0.0);
1240 IntegratedMeanDerivative {
1241 mean,
1242 dmean_dmu,
1243 mode: IntegratedExpectationMode::QuadratureFallback,
1244 }
1245}
1246
1247#[inline]
1248fn logit_posterior_meanwith_deriv_controlled(
1249 mu: f64,
1250 sigma: f64,
1251) -> Result<IntegratedMeanDerivative, EstimationError> {
1252 if !(mu.is_finite() && sigma.is_finite()) {
1253 crate::bail_invalid_estim!("logit integrated moments require finite mu and sigma");
1254 }
1255 let candidate = match logit_posterior_meanwith_deriv_exact(mu, sigma) {
1256 Ok(out) => out,
1257 Err(_) => return Ok(logit_posterior_meanwith_deriv_quadrature(mu, sigma)),
1258 };
1259 match candidate.mode {
1271 IntegratedExpectationMode::ExactSpecialFunction
1272 | IntegratedExpectationMode::ControlledAsymptotic => {
1273 let reference = logit_posterior_meanwith_deriv_quadrature(mu, sigma);
1274 if integrated_mean_derivative_drift_exceeds(
1275 &candidate, &reference, 1e-6, 1e-4, 1e-7, 1e-3,
1276 ) {
1277 Ok(reference)
1278 } else {
1279 Ok(candidate)
1280 }
1281 }
1282 _ => Ok(candidate),
1283 }
1284}
1285
1286#[inline]
1287fn log_normal_cdf_stable(x: f64) -> f64 {
1288 if !x.is_finite() {
1289 return if x.is_sign_negative() {
1290 f64::NEG_INFINITY
1291 } else {
1292 0.0
1293 };
1294 }
1295 if x < -8.0 {
1296 let u = -x / SQRT_2;
1297 -u * u + (0.5 * erfcx_nonnegative(u)).ln()
1298 } else {
1299 gam_math::probability::normal_cdf(x).max(1e-300).ln()
1300 }
1301}
1302
1303#[inline]
1304fn cloglog_extreme_asymptotic(mu: f64, sigma: f64) -> Option<IntegratedMeanDerivative> {
1305 let rare_log = mu + 0.5 * sigma * sigma;
1315 if rare_log <= CLOGLOG_RARE_EVENT_LOG_MAX {
1316 let mean = safe_exp(rare_log);
1317 return Some(IntegratedMeanDerivative {
1318 mean,
1319 dmean_dmu: mean,
1320 mode: IntegratedExpectationMode::ControlledAsymptotic,
1321 });
1322 }
1323 if mu - CLOGLOG_POSITIVE_SATURATION_SIGMAS * sigma >= CLOGLOG_POSITIVE_SATURATION_EDGE {
1324 return Some(IntegratedMeanDerivative {
1325 mean: 1.0,
1326 dmean_dmu: 0.0,
1327 mode: IntegratedExpectationMode::ControlledAsymptotic,
1328 });
1329 }
1330 None
1336}
1337
1338#[inline]
1339fn cloglog_survival_extreme_asymptotic(
1340 mu: f64,
1341 sigma: f64,
1342) -> Option<(f64, IntegratedExpectationMode)> {
1343 let rare_log = mu + 0.5 * sigma * sigma;
1344 if rare_log <= CLOGLOG_RARE_EVENT_LOG_MAX {
1345 let mean = safe_exp(rare_log);
1346 return Some((
1347 (1.0 - mean).clamp(0.0, 1.0),
1348 IntegratedExpectationMode::ControlledAsymptotic,
1349 ));
1350 }
1351 if mu - CLOGLOG_POSITIVE_SATURATION_SIGMAS * sigma >= CLOGLOG_POSITIVE_SATURATION_EDGE {
1352 return Some((0.0, IntegratedExpectationMode::ControlledAsymptotic));
1357 }
1358 None
1362}
1363
1364#[inline]
1372fn cloglog_gumbel_quad_nodes(sigma: f64) -> usize {
1373 let target = (CLOGLOG_GUMBEL_QUAD_NODE_SCALE / sigma.min(CLOGLOG_LARGE_SIGMA_ASYMPTOTIC_MIN))
1374 .ceil() as usize;
1375 let n = target
1376 .max(CLOGLOG_GUMBEL_QUAD_MIN_NODES)
1377 .min(CLOGLOG_GUMBEL_QUAD_MAX_NODES);
1378 if n % 2 == 0 { n + 1 } else { n }
1379}
1380
1381fn cloglog_log_survival_gumbel_quadrature(ctx: &QuadratureContext, mu: f64, sigma: f64) -> f64 {
1404 let a = CLOGLOG_GUMBEL_QUAD_ETA_LO;
1405 let b = CLOGLOG_GUMBEL_QUAD_ETA_HI;
1406 let half = 0.5 * (b - a);
1407 let mid = 0.5 * (a + b);
1408 let rule = ctx.clenshaw_curtis_n(cloglog_gumbel_quad_nodes(sigma));
1409 let mut running_max = f64::NEG_INFINITY;
1412 let mut running_sum = 0.0_f64;
1413 for (&node, &weight) in rule.nodes.iter().zip(rule.weights.iter()) {
1414 let eta = half * node + mid;
1415 let summand = (weight * half).ln()
1416 + (eta - safe_exp(eta))
1417 + log_normal_cdf_stable((eta - mu) / sigma);
1418 if !summand.is_finite() {
1419 continue;
1420 }
1421 if summand > running_max {
1422 running_sum = running_sum * (running_max - summand).exp() + 1.0;
1423 running_max = summand;
1424 } else {
1425 running_sum += (summand - running_max).exp();
1426 }
1427 }
1428 if running_max == f64::NEG_INFINITY {
1429 f64::NEG_INFINITY
1430 } else {
1431 running_max + running_sum.ln()
1432 }
1433}
1434
1435pub(crate) fn cloglog_log_survival_term_controlled(
1444 ctx: &QuadratureContext,
1445 mu: f64,
1446 sigma: f64,
1447) -> (f64, IntegratedExpectationMode) {
1448 if !(mu.is_finite() && sigma.is_finite()) || sigma <= CLOGLOG_SIGMA_DEGENERATE {
1449 return (-safe_exp(mu), IntegratedExpectationMode::ExactClosedForm);
1451 }
1452 let rare_log = mu + 0.5 * sigma * sigma;
1453 if rare_log <= CLOGLOG_RARE_EVENT_LOG_MAX {
1454 return (
1457 (-safe_exp(rare_log)).ln_1p(),
1458 IntegratedExpectationMode::ControlledAsymptotic,
1459 );
1460 }
1461 if sigma >= CLOGLOG_LARGE_SIGMA_ASYMPTOTIC_MIN {
1462 return (
1463 cloglog_log_survival_gumbel_quadrature(ctx, mu, sigma),
1464 IntegratedExpectationMode::ControlledAsymptotic,
1465 );
1466 }
1467 let (value, mode) = cloglog_survival_term_controlled(ctx, mu, sigma);
1468 if value > 0.0 {
1469 (value.ln(), mode)
1470 } else {
1471 (
1474 cloglog_log_survival_gumbel_quadrature(ctx, mu, sigma),
1475 IntegratedExpectationMode::QuadratureFallback,
1476 )
1477 }
1478}
1479
1480#[inline]
1499fn gumbel_survival(x: f64) -> f64 {
1500 (-safe_exp(x)).exp()
1501}
1502
1503#[inline]
1508fn cloglog_mean_d1_exact(x: f64) -> f64 {
1509 let ex = safe_exp(x);
1510 if ex.is_infinite() {
1511 0.0
1512 } else {
1513 ex * (-ex).exp()
1514 }
1515}
1516
1517#[inline]
1527fn cloglog_mean_exact(x: f64) -> f64 {
1528 cloglog_negative_tail_mean(x)
1529}
1530
1531#[inline]
1547fn cloglog_negative_tail_mean(eta: f64) -> f64 {
1548 if eta < -745.0 {
1552 0.0
1554 } else {
1555 let ex = safe_exp(eta);
1558 -(-ex).exp_m1()
1559 }
1560}
1561
/// Small-sigma Taylor expansion of the integrated cloglog mean and slope.
///
/// With `f(x) = 1 - exp(-exp(x))` and `f_k` its `k`-th derivative at `mu`, the
/// mean is `f0 + s^2 f2 / 2 + s^4 f4 / 8 + s^6 f6 / 48 + s^8 f8 / 384` and the
/// slope is the same series on the odd derivatives `f1, f3, ..., f9`, clamped
/// at zero from below. `sigma <= CLOGLOG_SIGMA_DEGENERATE` returns the point
/// values, and a vanished survival `exp(-exp(mu))` returns mean `1`, slope `0`.
#[inline]
fn cloglog_small_sigma_taylor(mu: f64, sigma: f64) -> IntegratedMeanDerivative {
    // Degenerate perturbation: no expansion needed.
    if sigma <= CLOGLOG_SIGMA_DEGENERATE {
        return IntegratedMeanDerivative {
            mean: cloglog_mean_exact(mu),
            dmean_dmu: cloglog_mean_d1_exact(mu),
            mode: IntegratedExpectationMode::ExactClosedForm,
        };
    }

    let ex = safe_exp(mu);
    // Guard against a non-finite exponential (also covers NaN `mu`).
    if !ex.is_finite() {
        return IntegratedMeanDerivative {
            mean: 1.0,
            dmean_dmu: 0.0,
            mode: IntegratedExpectationMode::ControlledAsymptotic,
        };
    }
    let surv = (-ex).exp();
    // Survival underflowed to zero: every derivative term below would vanish.
    if surv == 0.0 {
        return IntegratedMeanDerivative {
            mean: 1.0,
            dmean_dmu: 0.0,
            mode: IntegratedExpectationMode::ControlledAsymptotic,
        };
    }

    // Even powers of sigma and integer powers of exp(mu).
    let s2 = sigma * sigma;
    let s4 = s2 * s2;
    let s6 = s4 * s2;
    let s8 = s4 * s4;
    let e2x = ex * ex;
    let e3x = e2x * ex;
    let e4x = e3x * ex;
    let e5x = e4x * ex;
    let e6x = e5x * ex;
    let e7x = e6x * ex;
    let e8x = e7x * ex;
    let e9x = e8x * ex;
    // f0 is the point value; f_k = surv * sum_j (-1)^(j+1) S(k, j) e^{j x},
    // where the magnitudes S(k, j) are Stirling numbers of the second kind.
    let f0 = -(-ex).exp_m1();
    let f1 = ex * surv;
    let f2 = surv * (ex - e2x);
    let f3 = surv * (ex - 3.0 * e2x + e3x);
    let f4 = surv * (ex - 7.0 * e2x + 6.0 * e3x - e4x);
    let f5 = surv * (ex - 15.0 * e2x + 25.0 * e3x - 10.0 * e4x + e5x);
    let f6 = surv * (ex - 31.0 * e2x + 90.0 * e3x - 65.0 * e4x + 15.0 * e5x - e6x);
    let f7 = surv * (ex - 63.0 * e2x + 301.0 * e3x - 350.0 * e4x + 140.0 * e5x - 21.0 * e6x + e7x);
    let f8 = surv
        * (ex - 127.0 * e2x + 966.0 * e3x - 1701.0 * e4x + 1050.0 * e5x - 266.0 * e6x + 28.0 * e7x
            - e8x);
    let f9 = surv
        * (ex - 255.0 * e2x + 3025.0 * e3x - 7770.0 * e4x + 6951.0 * e5x - 2646.0 * e6x
            + 462.0 * e7x
            - 36.0 * e8x
            + e9x);
    IntegratedMeanDerivative {
        // Coefficients 1/2, 1/8, 1/48, 1/384 are E[Z^(2k)] / (2k)! for a
        // standard normal Z.
        mean: f0 + 0.5 * s2 * f2 + (s4 / 8.0) * f4 + (s6 / 48.0) * f6 + (s8 / 384.0) * f8,
        dmean_dmu: (f1 + 0.5 * s2 * f3 + (s4 / 8.0) * f5 + (s6 / 48.0) * f7 + (s8 / 384.0) * f9)
            .max(0.0),
        mode: IntegratedExpectationMode::ControlledAsymptotic,
    }
}
1669
/// One recursive step of adaptive Simpson integration of `g` over `[a, b]`.
///
/// `fa`, `fb`, `fm` are `g` at the endpoints and midpoint, and `whole` is the
/// Simpson estimate built from them. The interval is split in two and both
/// halves re-estimated. When the two-half estimate agrees with `whole` to
/// within `15 * tol`, or `depth` is exhausted, the Richardson-corrected value
/// is returned. Otherwise each half is refined with half the tolerance.
#[inline]
fn adaptive_simpson_refine(
    g: &impl Fn(f64) -> f64,
    a: f64,
    b: f64,
    fa: f64,
    fb: f64,
    fm: f64,
    whole: f64,
    tol: f64,
    depth: i32,
) -> f64 {
    let mid = 0.5 * (a + b);
    let f_left_mid = g(0.5 * (a + mid));
    let f_right_mid = g(0.5 * (mid + b));
    let left = (mid - a) / 6.0 * (fa + 4.0 * f_left_mid + fm);
    let right = (b - mid) / 6.0 * (fm + 4.0 * f_right_mid + fb);
    let refined = left + right;
    let delta = refined - whole;

    if depth <= 0 || delta.abs() <= 15.0 * tol {
        return refined + delta / 15.0;
    }

    let half_tol = 0.5 * tol;
    let remaining = depth - 1;
    adaptive_simpson_refine(g, a, mid, fa, fm, f_left_mid, left, half_tol, remaining)
        + adaptive_simpson_refine(g, mid, b, fm, fb, f_right_mid, right, half_tol, remaining)
}
1700
1701fn integrate_normal_adaptive(mu: f64, sigma: f64, f: impl Fn(f64) -> f64) -> f64 {
1716 if !(sigma.is_finite()) || sigma < 1e-10 {
1717 return f(mu);
1718 }
1719 const K: f64 = 15.0;
1720 const INITIAL_PANELS: usize = 24;
1721 const TOL: f64 = 1e-12;
1722 const MAX_DEPTH: i32 = 40;
1723 let inv_sqrt_2pi = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
1724 let g = |u: f64| f(mu + sigma * u) * inv_sqrt_2pi * (-0.5 * u * u).exp();
1728 let panel = 2.0 * K / INITIAL_PANELS as f64;
1729 let mut total = 0.0;
1730 for p in 0..INITIAL_PANELS {
1731 let a = -K + p as f64 * panel;
1732 let b = a + panel;
1733 let fa = g(a);
1734 let fb = g(b);
1735 let fm = g(0.5 * (a + b));
1736 let whole = (b - a) / 6.0 * (fa + 4.0 * fm + fb);
1737 total += adaptive_simpson_refine(&g, a, b, fa, fb, fm, whole, TOL, MAX_DEPTH);
1738 }
1739 total
1740}
1741
1742fn cloglog_posterior_meanwith_deriv_quadrature(mu: f64, sigma: f64) -> IntegratedMeanDerivative {
1743 if sigma < 1e-10 {
1744 return IntegratedMeanDerivative {
1745 mean: cloglog_mean_exact(mu),
1746 dmean_dmu: cloglog_mean_d1_exact(mu),
1747 mode: IntegratedExpectationMode::ExactClosedForm,
1748 };
1749 }
1750 let mean = cloglog_mean_from_survival(survival_posterior_mean_quadrature(mu, sigma));
1751 let dmean_dmu = integrate_normal_adaptive(mu, sigma, cloglog_mean_d1_exact).max(0.0);
1752 IntegratedMeanDerivative {
1753 mean,
1754 dmean_dmu,
1755 mode: IntegratedExpectationMode::QuadratureFallback,
1756 }
1757}
1758
1759#[inline]
1760fn survival_posterior_mean_quadrature(eta: f64, se_eta: f64) -> f64 {
1761 integrate_normal_adaptive(eta, se_eta, gumbel_survival).clamp(0.0, 1.0)
1762}
1763
1764fn cloglog_survival_term_controlled(
1765 ctx: &QuadratureContext,
1766 mu: f64,
1767 sigma: f64,
1768) -> (f64, IntegratedExpectationMode) {
1769 if !(mu.is_finite() && sigma.is_finite()) || sigma <= CLOGLOG_SIGMA_DEGENERATE {
1806 return (
1807 gumbel_survival(mu).clamp(0.0, 1.0),
1808 IntegratedExpectationMode::ExactClosedForm,
1809 );
1810 }
1811 if sigma < CLOGLOG_SIGMA_TAYLOR_MAX {
1812 let mean = cloglog_small_sigma_taylor(mu, sigma).mean;
1813 return (
1814 (1.0 - mean).clamp(0.0, 1.0),
1815 IntegratedExpectationMode::ControlledAsymptotic,
1816 );
1817 }
1818 if let Some(out) = cloglog_survival_extreme_asymptotic(mu, sigma) {
1819 return out;
1820 }
1821 if sigma >= CLOGLOG_LARGE_SIGMA_ASYMPTOTIC_MIN {
1822 let log_s = cloglog_log_survival_gumbel_quadrature(ctx, mu, sigma);
1829 return (
1830 safe_exp(log_s).clamp(0.0, 1.0),
1831 IntegratedExpectationMode::ControlledAsymptotic,
1832 );
1833 }
1834 if cloglog_survival_miles_is_reliable(mu, sigma)
1835 && let Ok(out) = cloglog_survival_miles(mu, sigma)
1836 {
1837 return (
1838 out.clamp(0.0, 1.0),
1839 IntegratedExpectationMode::ExactSpecialFunction,
1840 );
1841 }
1842 if cloglog_should_prefer_cc(mu, sigma, CLOGLOG_CC_TOL)
1843 && let Ok(out) = cloglog_survival_cc(ctx, mu, sigma, CLOGLOG_CC_TOL)
1844 {
1845 return (
1846 out.clamp(0.0, 1.0),
1847 IntegratedExpectationMode::ExactSpecialFunction,
1848 );
1849 }
1850 if let Ok(out) = cloglog_survival_gamma_reference(mu, sigma) {
1851 return (
1852 out.clamp(0.0, 1.0),
1853 IntegratedExpectationMode::ExactSpecialFunction,
1854 );
1855 }
1856 (
1857 survival_posterior_mean_quadrature(mu, sigma),
1858 IntegratedExpectationMode::QuadratureFallback,
1859 )
1860}
1861
1862#[inline]
1863fn lognormal_laplace_term_controlled(
1864 ctx: &QuadratureContext,
1865 z: f64,
1866 mu: f64,
1867 sigma: f64,
1868) -> (f64, IntegratedExpectationMode) {
1869 if !(z.is_finite() && z > 0.0) {
1895 return (f64::NAN, IntegratedExpectationMode::QuadratureFallback);
1896 }
1897 lognormal_laplace_unit_term_shared(ctx, mu + z.ln(), sigma)
1898}
1899
1900#[inline]
1901pub(crate) fn lognormal_laplace_unit_term_shared(
1902 ctx: &QuadratureContext,
1903 shifted_mu: f64,
1904 sigma: f64,
1905) -> (f64, IntegratedExpectationMode) {
1906 cloglog_survival_term_controlled(ctx, shifted_mu, sigma)
1907}
1908
1909#[inline]
1913pub fn lognormal_laplace_unit_log_term_shared(
1914 ctx: &QuadratureContext,
1915 shifted_mu: f64,
1916 sigma: f64,
1917) -> (f64, IntegratedExpectationMode) {
1918 cloglog_log_survival_term_controlled(ctx, shifted_mu, sigma)
1919}
1920
1921#[inline]
1922fn cloglog_survivalsecond_moment_controlled(
1923 ctx: &QuadratureContext,
1924 mu: f64,
1925 sigma: f64,
1926) -> (f64, IntegratedExpectationMode) {
1927 lognormal_laplace_term_controlled(ctx, 2.0, mu, sigma)
1942}
1943
1944#[inline]
1945fn cloglog_survival_pair_controlled(
1946 ctx: &QuadratureContext,
1947 mu: f64,
1948 sigma: f64,
1949) -> (
1950 (f64, IntegratedExpectationMode),
1951 (f64, IntegratedExpectationMode),
1952) {
1953 let shiftedmu = mu + sigma * sigma;
1954
1955 if cloglog_survival_miles_is_reliable(mu, sigma)
1966 && cloglog_survival_miles_is_reliable(shiftedmu, sigma)
1967 && let (Ok(base), Ok(shifted)) = (
1968 cloglog_survival_miles(mu, sigma),
1969 cloglog_survival_miles(shiftedmu, sigma),
1970 )
1971 {
1972 return (
1973 (
1974 base.clamp(0.0, 1.0),
1975 IntegratedExpectationMode::ExactSpecialFunction,
1976 ),
1977 (
1978 shifted.clamp(0.0, 1.0),
1979 IntegratedExpectationMode::ExactSpecialFunction,
1980 ),
1981 );
1982 }
1983
1984 if cloglog_should_prefer_cc(mu, sigma, CLOGLOG_CC_TOL)
1985 && cloglog_should_prefer_cc(shiftedmu, sigma, CLOGLOG_CC_TOL)
1986 && let (Ok(base), Ok(shifted)) = (
1987 cloglog_survival_cc(ctx, mu, sigma, CLOGLOG_CC_TOL),
1988 cloglog_survival_cc(ctx, shiftedmu, sigma, CLOGLOG_CC_TOL),
1989 )
1990 {
1991 return (
1992 (
1993 base.clamp(0.0, 1.0),
1994 IntegratedExpectationMode::ExactSpecialFunction,
1995 ),
1996 (
1997 shifted.clamp(0.0, 1.0),
1998 IntegratedExpectationMode::ExactSpecialFunction,
1999 ),
2000 );
2001 }
2002
2003 if let (Ok(base), Ok(shifted)) = (
2004 cloglog_survival_gamma_reference(mu, sigma),
2005 cloglog_survival_gamma_reference(shiftedmu, sigma),
2006 ) {
2007 return (
2008 (
2009 base.clamp(0.0, 1.0),
2010 IntegratedExpectationMode::ExactSpecialFunction,
2011 ),
2012 (
2013 shifted.clamp(0.0, 1.0),
2014 IntegratedExpectationMode::ExactSpecialFunction,
2015 ),
2016 );
2017 }
2018
2019 (
2020 cloglog_survival_term_controlled(ctx, mu, sigma),
2021 cloglog_survival_term_controlled(ctx, shiftedmu, sigma),
2022 )
2023}
2024
/// Converts a survival probability into the cloglog mean `1 - survival`.
///
/// The input is clamped to `[0, 1]` first; a NaN input yields NaN.
///
/// The complement is taken by direct subtraction. For a clamped survival in
/// `[0.5, 1]` the subtraction `1 - s` is exact in IEEE-754 arithmetic
/// (Sterbenz lemma), so routing it through `ln`/`exp_m1` can only add
/// rounding error, and it also produced `-0.0` when `survival == 1`.
#[inline]
fn cloglog_mean_from_survival(survival: f64) -> f64 {
    1.0 - survival.clamp(0.0, 1.0)
}
2043
2044#[inline]
2045fn cloglog_shift_identity_derivative(mu: f64, sigma: f64, shifted_survival: f64) -> f64 {
2046 if !(mu.is_finite() && sigma.is_finite()) || shifted_survival <= 0.0 {
2060 return 0.0;
2061 }
2062 cloglog_shift_identity_derivative_log(mu, sigma, shifted_survival.ln())
2063}
2064
2065#[inline]
2075fn cloglog_shift_identity_derivative_log(mu: f64, sigma: f64, log_shifted_survival: f64) -> f64 {
2076 if !(mu.is_finite() && sigma.is_finite()) || log_shifted_survival == f64::NEG_INFINITY {
2077 return 0.0;
2078 }
2079 let log_derivative = mu + 0.5 * sigma * sigma + log_shifted_survival;
2080 let upper = 1.0 / std::f64::consts::E;
2081 if !log_derivative.is_finite() {
2082 return upper;
2085 }
2086 safe_exp(log_derivative).clamp(0.0, upper)
2087}
2088
2089#[inline]
2090fn log_half_erfc_stable(u: f64) -> f64 {
2091 if u > 0.0 {
2101 -u * u + (0.5 * erfcx_nonnegative(u)).ln()
2102 } else {
2103 (0.5 * erfc(u)).ln()
2104 }
2105}
2106
2107#[inline]
2147fn cloglog_survival_miles_is_reliable(mu: f64, sigma: f64) -> bool {
2148 if !(mu.is_finite() && sigma.is_finite() && sigma > 0.0) {
2149 return false;
2150 }
2151 let alpha_ln = CLOGLOG_MILES_ALPHA.ln();
2152 let shifted = mu - alpha_ln;
2153 let peak_log = CLOGLOG_MILES_ALPHA - 0.5 * shifted * shifted / (sigma * sigma);
2154 peak_log.is_finite() && peak_log <= CLOGLOG_MILES_PEAK_LOG_MAX
2155}
2156
2157fn cloglog_survival_miles(mu: f64, sigma: f64) -> Result<f64, EstimationError> {
2158 let alpha_ln = CLOGLOG_MILES_ALPHA.ln();
2188 let mut s_sum = 0.0_f64;
2189 let mut stable_pairs = 0usize;
2190
2191 for pair_start in (0..CLOGLOG_MILES_MAX_TERMS).step_by(2) {
2192 let mut pair_s = 0.0_f64;
2193 for n in pair_start..(pair_start + 2).min(CLOGLOG_MILES_MAX_TERMS) {
2194 let nf = n as f64;
2195 let sign = if n % 2 == 0 { 1.0 } else { -1.0 };
2196 let base_log = nf * mu + 0.5 * sigma * sigma * nf * nf
2197 - statrs::function::gamma::ln_gamma(nf + 1.0);
2198 let u = (mu - alpha_ln + sigma * sigma * nf) / (SQRT_2 * sigma);
2199 let log_half_erfc = log_half_erfc_stable(u);
2200 let term_log = base_log + log_half_erfc;
2201 if term_log > QUADRATURE_EXP_LOG_MAX {
2202 crate::bail_invalid_estim!("Miles cloglog series term exceeded finite exp range");
2203 }
2204 let term = sign * safe_exp(term_log);
2205 pair_s += term;
2206 }
2207 s_sum += pair_s;
2208
2209 let s_scale = s_sum.abs().max(1.0);
2210 if pair_s.abs() <= 2e-15 * s_scale {
2211 stable_pairs += 1;
2212 if stable_pairs >= SERIES_CONSECUTIVE_SMALL_TERMS {
2213 if s_sum.is_finite() && (-1e-10..=1.0 + 1e-10).contains(&s_sum) {
2214 return Ok(s_sum.clamp(0.0, 1.0));
2215 }
2216 break;
2217 }
2218 } else {
2219 stable_pairs = 0;
2220 }
2221 }
2222
2223 Err(EstimationError::InvalidInput(
2224 "Miles cloglog series did not converge safely".to_string(),
2225 ))
2226}
2227
2228fn cloglog_survival_cc(
2229 ctx: &QuadratureContext,
2230 mu: f64,
2231 sigma: f64,
2232 tol: f64,
2233) -> Result<f64, EstimationError> {
2234 if !(mu.is_finite() && sigma.is_finite() && sigma > 0.0 && tol.is_finite() && tol > 0.0) {
2235 crate::bail_invalid_estim!(
2236 "CC cloglog backend requires finite mu, positive sigma, and positive tolerance"
2237 .to_string(),
2238 );
2239 }
2240
2241 let p_tail = (tol / 8.0).clamp(1e-300, 0.25);
2279 let a = gam_math::probability::standard_normal_quantile(p_tail)
2280 .map(|z| -z)
2281 .unwrap_or(8.0)
2282 .max(1.0);
2283 let n = cloglog_cc_required_nodes(mu, sigma, tol)?;
2284 if n > CLOGLOG_CC_NODE_CAP {
2285 crate::bail_invalid_estim!("CC cloglog backend requires too many nodes");
2286 }
2287
2288 let rule = ctx.clenshaw_curtis_n(n);
2289 let inv_sqrt_2pi = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
2290 let mut sum = 0.0_f64;
2291 let mut c = 0.0_f64;
2292 for (&x, &w) in rule.nodes.iter().zip(rule.weights.iter()) {
2293 let t = a * x;
2294 let u = mu + sigma * t;
2295 let e = safe_exp(u);
2296 let w0 = (-0.5 * t * t).exp() * inv_sqrt_2pi;
2297 let yk = w * w0 * (-e).exp() - c;
2298 let tk = sum + yk;
2299 c = (tk - sum) - yk;
2300 sum = tk;
2301 }
2302
2303 let survival = (a * sum).clamp(0.0, 1.0);
2304 if !survival.is_finite() {
2305 crate::bail_invalid_estim!("CC cloglog backend produced non-finite values");
2306 }
2307 Ok(survival)
2308}
2309
2310#[inline]
2311fn complex_add(a: Complex, b: Complex) -> Complex {
2312 Complex {
2313 re: a.re + b.re,
2314 im: a.im + b.im,
2315 }
2316}
2317
2318#[inline]
2319fn complex_sub(a: Complex, b: Complex) -> Complex {
2320 Complex {
2321 re: a.re - b.re,
2322 im: a.im - b.im,
2323 }
2324}
2325
2326#[inline]
2327fn complexmul(a: Complex, b: Complex) -> Complex {
2328 Complex {
2329 re: a.re * b.re - a.im * b.im,
2330 im: a.re * b.im + a.im * b.re,
2331 }
2332}
2333
2334#[inline]
2335fn complex_div(a: Complex, b: Complex) -> Complex {
2336 let den = (b.re * b.re + b.im * b.im).max(1e-300);
2337 Complex {
2338 re: (a.re * b.re + a.im * b.im) / den,
2339 im: (a.im * b.re - a.re * b.im) / den,
2340 }
2341}
2342
2343#[inline]
2344fn complex_abs(z: Complex) -> f64 {
2345 z.re.hypot(z.im)
2346}
2347
2348#[inline]
2349fn complex_ln(z: Complex) -> Complex {
2350 Complex {
2351 re: complex_abs(z).ln(),
2352 im: z.im.atan2(z.re),
2353 }
2354}
2355
2356#[inline]
2357fn complex_exp(z: Complex) -> Complex {
2358 let e = z.re.exp();
2359 Complex {
2360 re: e * z.im.cos(),
2361 im: e * z.im.sin(),
2362 }
2363}
2364
2365#[inline]
2366fn complex_sin(z: Complex) -> Complex {
2367 Complex {
2368 re: z.re.sin() * z.im.cosh(),
2369 im: z.re.cos() * z.im.sinh(),
2370 }
2371}
2372
2373fn complex_log_gamma_lanczos(z: Complex) -> Complex {
2374 const G: f64 = 7.0;
2378 const COEFFS: [f64; 9] = [
2379 0.999_999_999_999_809_9,
2380 676.520_368_121_885_1,
2381 -1_259.139_216_722_402_8,
2382 771.323_428_777_653_1,
2383 -176.615_029_162_140_6,
2384 12.507_343_278_686_905,
2385 -0.138_571_095_265_720_12,
2386 9.984_369_578_019_572e-6,
2387 1.505_632_735_149_311_6e-7,
2388 ];
2389
2390 if z.re < 0.5 {
2391 let piz = Complex {
2392 re: std::f64::consts::PI * z.re,
2393 im: std::f64::consts::PI * z.im,
2394 };
2395 let one_minusz = Complex {
2396 re: 1.0 - z.re,
2397 im: -z.im,
2398 };
2399 return complex_sub(
2400 complex_sub(
2401 Complex {
2402 re: std::f64::consts::PI.ln(),
2403 im: 0.0,
2404 },
2405 complex_ln(complex_sin(piz)),
2406 ),
2407 complex_log_gamma_lanczos(one_minusz),
2408 );
2409 }
2410
2411 let z1 = Complex {
2412 re: z.re - 1.0,
2413 im: z.im,
2414 };
2415 let mut x = Complex {
2416 re: COEFFS[0],
2417 im: 0.0,
2418 };
2419 for (i, c) in COEFFS.iter().enumerate().skip(1) {
2420 x = complex_add(
2421 x,
2422 complex_div(
2423 Complex { re: *c, im: 0.0 },
2424 Complex {
2425 re: z1.re + i as f64,
2426 im: z1.im,
2427 },
2428 ),
2429 );
2430 }
2431 let t = Complex {
2432 re: z1.re + G + 0.5,
2433 im: z1.im,
2434 };
2435 complex_add(
2436 complex_add(
2437 Complex {
2438 re: 0.5 * (2.0 * std::f64::consts::PI).ln(),
2439 im: 0.0,
2440 },
2441 complexmul(
2442 Complex {
2443 re: z1.re + 0.5,
2444 im: z1.im,
2445 },
2446 complex_ln(t),
2447 ),
2448 ),
2449 complex_sub(complex_ln(x), t),
2450 )
2451}
2452
2453fn cloglog_survival_gamma_reference(mu: f64, sigma: f64) -> Result<f64, EstimationError> {
2457 if !(mu.is_finite() && sigma.is_finite()) || sigma <= 0.0 {
2458 crate::bail_invalid_estim!(
2459 "Gamma cloglog reference backend requires finite mu and positive sigma"
2460 );
2461 }
2462
2463 let n = (CLOGLOG_GAMMA_T_MAX_REF / CLOGLOG_GAMMA_H_REF).round() as usize;
2497 let n = if n.is_multiple_of(2) { n } else { n + 1 };
2498 let h = CLOGLOG_GAMMA_T_MAX_REF / n as f64;
2499
2500 let eval = |t: f64| -> f64 {
2501 let z = Complex {
2502 re: CLOGLOG_GAMMA_K_REF,
2503 im: t,
2504 };
2505 let log_gamma = complex_log_gamma_lanczos(z);
2506 let z_sq = complexmul(z, z);
2507 let exponent = complex_sub(
2508 complex_add(
2509 log_gamma,
2510 Complex {
2511 re: 0.5 * sigma * sigma * z_sq.re,
2512 im: 0.5 * sigma * sigma * z_sq.im,
2513 },
2514 ),
2515 Complex {
2516 re: mu * z.re,
2517 im: mu * z.im,
2518 },
2519 );
2520 complex_exp(exponent).re
2521 };
2522
2523 let f0 = eval(0.0);
2524 let fn_ = eval(CLOGLOG_GAMMA_T_MAX_REF);
2525 let mut sum_s = f0 + fn_;
2526 for i in 1..n {
2527 let t = i as f64 * h;
2528 let fi = eval(t);
2529 let w = if i % 2 == 0 { 2.0 } else { 4.0 };
2530 sum_s += w * fi;
2531 }
2532 let sval = ((h / 3.0) * sum_s / std::f64::consts::PI).clamp(0.0, 1.0);
2533 if !sval.is_finite() {
2534 crate::bail_invalid_estim!("Gamma cloglog reference backend produced non-finite values");
2535 }
2536 Ok(sval)
2537}
2538
2539pub(crate) fn cloglog_posterior_meanwith_deriv_controlled(
2540 ctx: &QuadratureContext,
2541 mu: f64,
2542 sigma: f64,
2543) -> IntegratedMeanDerivative {
2544 if !(mu.is_finite() && sigma.is_finite()) || sigma <= CLOGLOG_SIGMA_DEGENERATE {
2589 return IntegratedMeanDerivative {
2590 mean: cloglog_mean_exact(mu),
2593 dmean_dmu: cloglog_mean_d1_exact(mu),
2596 mode: IntegratedExpectationMode::ExactClosedForm,
2597 };
2598 }
2599 if sigma >= CLOGLOG_LARGE_SIGMA_ASYMPTOTIC_MIN {
2600 let (log_base, base_mode) = cloglog_log_survival_term_controlled(ctx, mu, sigma);
2606 let (log_shift, shift_mode) =
2607 cloglog_log_survival_term_controlled(ctx, mu + sigma * sigma, sigma);
2608 let mean = (-log_base.exp_m1()).clamp(0.0, 1.0);
2610 let dmean = cloglog_shift_identity_derivative_log(mu, sigma, log_shift);
2611 return IntegratedMeanDerivative {
2612 mean,
2613 dmean_dmu: dmean.max(0.0),
2614 mode: worse_integrated_expectation_mode(base_mode, shift_mode),
2615 };
2616 }
2617 let candidate = if sigma < CLOGLOG_SIGMA_TAYLOR_MAX {
2618 cloglog_small_sigma_taylor(mu, sigma)
2619 } else if let Some(out) = cloglog_extreme_asymptotic(mu, sigma) {
2620 out
2621 } else {
2622 let ((survival, mode), (shifted_survival, shifted_mode)) =
2623 cloglog_survival_pair_controlled(ctx, mu, sigma);
2624 if matches!(mode, IntegratedExpectationMode::QuadratureFallback)
2625 || matches!(shifted_mode, IntegratedExpectationMode::QuadratureFallback)
2626 {
2627 return cloglog_posterior_meanwith_deriv_quadrature(mu, sigma);
2628 }
2629 let mean = cloglog_mean_from_survival(survival);
2630 let dmean = cloglog_shift_identity_derivative(mu, sigma, shifted_survival);
2631 let mode = if matches!(mode, IntegratedExpectationMode::ControlledAsymptotic)
2632 || matches!(
2633 shifted_mode,
2634 IntegratedExpectationMode::ControlledAsymptotic
2635 ) {
2636 IntegratedExpectationMode::ControlledAsymptotic
2637 } else {
2638 mode
2639 };
2640 IntegratedMeanDerivative {
2641 mean,
2642 dmean_dmu: dmean.max(0.0),
2643 mode,
2644 }
2645 };
2646 if matches!(
2652 candidate.mode,
2653 IntegratedExpectationMode::ControlledAsymptotic
2654 ) && sigma >= CLOGLOG_LARGE_SIGMA_ASYMPTOTIC_MIN
2655 {
2656 return candidate;
2657 }
2658 let ghq = cloglog_posterior_meanwith_deriv_quadrature(mu, sigma);
2659 if integrated_mean_derivative_drift_exceeds(&candidate, &ghq, 1e-6, 1e-4, 1e-7, 1e-3) {
2666 ghq
2667 } else {
2668 candidate
2669 }
2670}
2671
2672pub fn integrated_inverse_link_mean_and_derivative(
2673 quadctx: &QuadratureContext,
2674 link: LinkFunction,
2675 mu: f64,
2676 sigma: f64,
2677) -> Result<IntegratedMeanDerivative, EstimationError> {
2678 match link {
2707 LinkFunction::Log => {
2708 let (mean, saturated) = safe_expwith_saturation(mu + 0.5 * sigma * sigma);
2709 Ok(IntegratedMeanDerivative {
2710 mean,
2711 dmean_dmu: mean,
2712 mode: if saturated {
2713 IntegratedExpectationMode::ControlledAsymptotic
2714 } else {
2715 IntegratedExpectationMode::ExactClosedForm
2716 },
2717 })
2718 }
2719 LinkFunction::Probit => Ok(probit_posterior_meanwith_deriv_exact(mu, sigma)),
2720 LinkFunction::Logit => logit_posterior_meanwith_deriv_controlled(mu, sigma),
2721 LinkFunction::CLogLog => Ok(cloglog_posterior_meanwith_deriv_controlled(quadctx, mu, sigma)),
2722 LinkFunction::LogLog | LinkFunction::Cauchit => {
2723 let component = if matches!(link, LinkFunction::LogLog) {
2725 LinkComponent::LogLog
2726 } else {
2727 LinkComponent::Cauchit
2728 };
2729 let (mean, dmean_dmu, _, _) = integrate_normal_ghq_adaptive(quadctx, mu, sigma, |x| {
2730 component_point_jet(component, x)
2731 });
2732 Ok(IntegratedMeanDerivative {
2733 mean,
2734 dmean_dmu,
2735 mode: if sigma <= 1e-10 {
2736 IntegratedExpectationMode::ExactClosedForm
2737 } else {
2738 IntegratedExpectationMode::QuadratureFallback
2739 },
2740 })
2741 }
2742 LinkFunction::Sas => Err(EstimationError::InvalidInput(
2743 "state-less integrated SAS moments are unsupported; use SAS-aware prediction APIs with explicit (epsilon, log_delta)".to_string(),
2744 )),
2745 LinkFunction::BetaLogistic => Err(EstimationError::InvalidInput(
2746 "state-less integrated Beta-Logistic moments are unsupported; use link-aware prediction APIs with explicit (delta, epsilon)".to_string(),
2747 )),
2748 LinkFunction::Identity => Ok(IntegratedMeanDerivative {
2749 mean: mu,
2750 dmean_dmu: 1.0,
2751 mode: IntegratedExpectationMode::ExactClosedForm,
2752 }),
2753 }
2754}
2755
2756#[inline]
2757pub fn integrated_inverse_link_jet(
2758 quadctx: &QuadratureContext,
2759 link: LinkFunction,
2760 mu: f64,
2761 sigma: f64,
2762) -> Result<IntegratedInverseLinkJet, EstimationError> {
2763 match link {
2764 LinkFunction::Log => {
2765 let (mean, saturated) = safe_expwith_saturation(mu + 0.5 * sigma * sigma);
2766 Ok(IntegratedInverseLinkJet {
2767 mean,
2768 d1: mean,
2769 d2: mean,
2770 d3: mean,
2771 mode: if saturated {
2772 IntegratedExpectationMode::ControlledAsymptotic
2773 } else {
2774 IntegratedExpectationMode::ExactClosedForm
2775 },
2776 })
2777 }
2778 LinkFunction::Probit => Ok(integrated_probit_jet(mu, sigma)),
2779 LinkFunction::Logit => {
2780 if sigma > LOGIT_JET_GHQ_SIGMA_MAX {
2781 return logit_wide_sigma_jet(mu, sigma);
2785 }
2786 let (mean, d1, d2, d3) = integrate_normal_ghq_adaptive(quadctx, mu, sigma, |x| {
2791 component_point_jet(LinkComponent::Logit, x)
2792 });
2793 let mode = if sigma <= 1e-10 {
2794 IntegratedExpectationMode::ExactClosedForm
2795 } else {
2796 match logit_posterior_meanwith_deriv_controlled(mu, sigma) {
2800 Ok(scalar) => scalar.mode,
2801 Err(_) => IntegratedExpectationMode::QuadratureFallback,
2802 }
2803 };
2804 Ok(IntegratedInverseLinkJet {
2805 mean,
2806 d1: d1.max(0.0),
2807 d2,
2808 d3,
2809 mode,
2810 })
2811 }
2812 LinkFunction::CLogLog => {
2813 validate_latent_cloglog_inputs(mu, sigma)?;
2814 Ok(integrated_cloglog_inverse_link_jet_controlled(
2815 quadctx, mu, sigma,
2816 ))
2817 }
2818 LinkFunction::LogLog | LinkFunction::Cauchit => {
2819 let component = if matches!(link, LinkFunction::LogLog) {
2821 LinkComponent::LogLog
2822 } else {
2823 LinkComponent::Cauchit
2824 };
2825 let (mean, d1, d2, d3) = integrate_normal_ghq_adaptive(quadctx, mu, sigma, |x| {
2826 component_point_jet(component, x)
2827 });
2828 Ok(IntegratedInverseLinkJet {
2829 mean,
2830 d1,
2831 d2,
2832 d3,
2833 mode: if sigma <= 1e-10 {
2834 IntegratedExpectationMode::ExactClosedForm
2835 } else {
2836 IntegratedExpectationMode::QuadratureFallback
2837 },
2838 })
2839 }
2840 LinkFunction::Sas => Err(EstimationError::InvalidInput(
2841 "state-less integrated SAS jet is unsupported; use SAS-aware prediction APIs with explicit (epsilon, log_delta)".to_string(),
2842 )),
2843 LinkFunction::BetaLogistic => Err(EstimationError::InvalidInput(
2844 "state-less integrated Beta-Logistic jet is unsupported; use link-aware prediction APIs with explicit (delta, epsilon)".to_string(),
2845 )),
2846 LinkFunction::Identity => Ok(IntegratedInverseLinkJet {
2847 mean: mu,
2848 d1: 1.0,
2849 d2: 0.0,
2850 d3: 0.0,
2851 mode: IntegratedExpectationMode::ExactClosedForm,
2852 }),
2853 }
2854}
2855
2856#[inline]
2867fn logit_wide_sigma_jet(mu: f64, sigma: f64) -> Result<IntegratedInverseLinkJet, EstimationError> {
2868 let scalar = logit_posterior_meanwith_deriv_controlled(mu, sigma)?;
2869 let d2 = integrate_normal_adaptive(mu, sigma, |x| {
2870 component_point_jet(LinkComponent::Logit, x).2
2871 });
2872 let d3 = integrate_normal_adaptive(mu, sigma, |x| {
2873 component_point_jet(LinkComponent::Logit, x).3
2874 });
2875 Ok(IntegratedInverseLinkJet {
2876 mean: scalar.mean,
2877 d1: scalar.dmean_dmu.max(0.0),
2878 d2,
2879 d3,
2880 mode: scalar.mode,
2881 })
2882}
2883
2884#[inline]
2885pub fn integrated_logit_inverse_link_jet_pirls(
2886 quadctx: &QuadratureContext,
2887 mu: f64,
2888 sigma: f64,
2889) -> Result<IntegratedInverseLinkJet, EstimationError> {
2890 if sigma <= 1e-10 {
2895 let (mean, d1, d2, d3) = component_point_jet(LinkComponent::Logit, mu);
2896 return Ok(IntegratedInverseLinkJet {
2897 mean,
2898 d1,
2899 d2,
2900 d3,
2901 mode: IntegratedExpectationMode::ExactClosedForm,
2902 });
2903 }
2904 if sigma > LOGIT_JET_GHQ_SIGMA_MAX {
2905 return logit_wide_sigma_jet(mu, sigma);
2906 }
2907 let (mean, d1, d2, d3) = integrate_normal_ghq_adaptive(quadctx, mu, sigma, |x| {
2908 component_point_jet(LinkComponent::Logit, x)
2909 });
2910 let mode = match logit_posterior_meanwith_deriv_controlled(mu, sigma) {
2911 Ok(scalar) => scalar.mode,
2912 Err(_) => IntegratedExpectationMode::QuadratureFallback,
2913 };
2914 Ok(IntegratedInverseLinkJet {
2915 mean,
2916 d1: d1.max(0.0),
2917 d2,
2918 d3,
2919 mode,
2920 })
2921}
2922
2923#[inline]
2924fn sas_point_jet(x: f64, epsilon: f64, log_delta: f64) -> (f64, f64, f64, f64) {
2925 let jet = sas_inverse_link_jet(x, epsilon, log_delta);
2926 (jet.mu, jet.d1, jet.d2, jet.d3)
2927}
2928
2929#[inline]
2930fn beta_logistic_point_jet(x: f64, log_shape_center: f64, epsilon: f64) -> (f64, f64, f64, f64) {
2931 let jet = beta_logistic_inverse_link_jet(x, log_shape_center, epsilon);
2932 (jet.mu, jet.d1, jet.d2, jet.d3)
2933}
2934
2935#[inline]
2936fn worse_integrated_expectation_mode(
2937 lhs: IntegratedExpectationMode,
2938 rhs: IntegratedExpectationMode,
2939) -> IntegratedExpectationMode {
2940 if lhs.rank() >= rhs.rank() { lhs } else { rhs }
2941}
2942
/// Reports whether `candidate` deviates from `reference` by more than
/// `max(abs_tol, rel_tol * max(|reference|, |candidate|))`.
///
/// A non-finite value on either side always counts as drift.
#[inline]
fn integrated_scalar_drift_exceeds(
    candidate: f64,
    reference: f64,
    abs_tol: f64,
    rel_tol: f64,
) -> bool {
    let both_finite = candidate.is_finite() && reference.is_finite();
    if !both_finite {
        return true;
    }
    let magnitude = reference.abs().max(candidate.abs());
    let allowed = abs_tol.max(rel_tol * magnitude);
    let gap = (candidate - reference).abs();
    gap > allowed
}
2955
2956#[inline]
2957fn integrated_mean_derivative_drift_exceeds(
2958 candidate: &IntegratedMeanDerivative,
2959 reference: &IntegratedMeanDerivative,
2960 mean_abs_tol: f64,
2961 mean_rel_tol: f64,
2962 deriv_abs_tol: f64,
2963 deriv_rel_tol: f64,
2964) -> bool {
2965 integrated_scalar_drift_exceeds(candidate.mean, reference.mean, mean_abs_tol, mean_rel_tol)
2966 || integrated_scalar_drift_exceeds(
2967 candidate.dmean_dmu,
2968 reference.dmean_dmu,
2969 deriv_abs_tol,
2970 deriv_rel_tol,
2971 )
2972}
2973
2974#[inline]
2975fn component_point_jet(component: LinkComponent, x: f64) -> (f64, f64, f64, f64) {
2976 let jet = component_inverse_link_jet(component, x);
2979 (jet.mu, jet.d1, jet.d2, jet.d3)
2980}
2981
2982#[inline]
2983fn integrated_mixture_component_jet(
2984 ctx: &QuadratureContext,
2985 component: LinkComponent,
2986 mu: f64,
2987 sigma: f64,
2988) -> IntegratedInverseLinkJet {
2989 match component {
2994 LinkComponent::Logit => integrated_inverse_link_jet(ctx, LinkFunction::Logit, mu, sigma)
2995 .unwrap_or_else(|_| integrated_logit_jet_ghq(ctx, mu, sigma)),
2996 LinkComponent::Probit => integrated_probit_jet(mu, sigma),
2997 LinkComponent::CLogLog => integrated_cloglog_inverse_link_jet_controlled(ctx, mu, sigma),
2998 LinkComponent::LogLog | LinkComponent::Cauchit => {
2999 let (mean, d1, d2, d3) = integrate_normal_ghq_adaptive(ctx, mu, sigma, |x| {
3000 component_point_jet(component, x)
3001 });
3002 IntegratedInverseLinkJet {
3003 mean,
3004 d1: d1.max(0.0),
3005 d2,
3006 d3,
3007 mode: if sigma <= 1e-10 {
3008 IntegratedExpectationMode::ExactClosedForm
3009 } else {
3010 IntegratedExpectationMode::QuadratureFallback
3011 },
3012 }
3013 }
3014 }
3015}
3016
3017#[inline]
3018fn integrated_mixture_jet(
3019 ctx: &QuadratureContext,
3020 mu: f64,
3021 sigma: f64,
3022 mixture_state: &MixtureLinkState,
3023) -> Result<IntegratedInverseLinkJet, EstimationError> {
3024 if mixture_state.components.is_empty() {
3029 crate::bail_invalid_estim!(
3030 "integrated mixture-link jet requires at least one blended component"
3031 );
3032 }
3033 if mixture_state.components.len() != mixture_state.pi.len() {
3034 crate::bail_invalid_estim!(
3035 "integrated mixture-link jet requires matching component and weight counts"
3036 );
3037 }
3038
3039 let mut mean = 0.0_f64;
3044 let mut d1 = 0.0_f64;
3045 let mut d2 = 0.0_f64;
3046 let mut d3 = 0.0_f64;
3047 let mut mode = IntegratedExpectationMode::ExactClosedForm;
3048 let mut saw_positive_weight = false;
3049
3050 for (&component, &weight) in mixture_state.components.iter().zip(mixture_state.pi.iter()) {
3051 if weight <= 0.0 {
3052 continue;
3053 }
3054 let jet = integrated_mixture_component_jet(ctx, component, mu, sigma);
3055 mean += weight * jet.mean;
3056 d1 += weight * jet.d1;
3057 d2 += weight * jet.d2;
3058 d3 += weight * jet.d3;
3059 if jet.mode.rank() > mode.rank() {
3060 mode = jet.mode;
3061 }
3062 saw_positive_weight = true;
3063 }
3064
3065 if !saw_positive_weight {
3066 crate::bail_invalid_estim!(
3067 "integrated mixture-link jet requires at least one positive component weight"
3068 .to_string(),
3069 );
3070 }
3071
3072 Ok(IntegratedInverseLinkJet {
3073 mean,
3074 d1: d1.max(0.0),
3075 d2,
3076 d3,
3077 mode,
3078 })
3079}
3080
3081#[inline]
3082fn integrated_sas_jet_ghq(
3083 ctx: &QuadratureContext,
3084 mu: f64,
3085 sigma: f64,
3086 sas_state: &SasLinkState,
3087) -> IntegratedInverseLinkJet {
3088 let (mean, d1, d2, d3) = integrate_normal_ghq_adaptive(ctx, mu, sigma, |x| {
3089 sas_point_jet(x, sas_state.epsilon, sas_state.log_delta)
3090 });
3091 IntegratedInverseLinkJet {
3092 mean,
3093 d1: d1.max(0.0),
3094 d2,
3095 d3,
3096 mode: if sigma <= 1e-10 {
3097 IntegratedExpectationMode::ExactClosedForm
3098 } else {
3099 IntegratedExpectationMode::QuadratureFallback
3100 },
3101 }
3102}
3103
3104#[inline]
3105fn integrated_beta_logistic_jet_ghq(
3106 ctx: &QuadratureContext,
3107 mu: f64,
3108 sigma: f64,
3109 beta_state: &SasLinkState,
3110) -> IntegratedInverseLinkJet {
3111 let (mean, d1, d2, d3) = integrate_normal_ghq_adaptive(ctx, mu, sigma, |x| {
3112 beta_logistic_point_jet(x, beta_state.log_delta, beta_state.epsilon)
3113 });
3114 IntegratedInverseLinkJet {
3115 mean,
3116 d1: d1.max(0.0),
3117 d2,
3118 d3,
3119 mode: if sigma <= 1e-10 {
3120 IntegratedExpectationMode::ExactClosedForm
3121 } else {
3122 IntegratedExpectationMode::QuadratureFallback
3123 },
3124 }
3125}
3126
3127#[inline]
3129pub fn integrated_inverse_link_jetwith_state(
3130 quadctx: &QuadratureContext,
3131 link: LinkFunction,
3132 mu: f64,
3133 sigma: f64,
3134 mixture_link_state: Option<&MixtureLinkState>,
3135 sas_link_state: Option<&SasLinkState>,
3136) -> Result<IntegratedInverseLinkJet, EstimationError> {
3137 if let Some(state) = mixture_link_state {
3138 return integrated_mixture_jet(quadctx, mu, sigma, state);
3139 }
3140 if matches!(link, LinkFunction::Sas) {
3141 let sas = sas_link_state.ok_or_else(|| {
3142 EstimationError::InvalidInput(
3143 "state-less integrated SAS jet is unsupported; explicit SasLinkState is required"
3144 .to_string(),
3145 )
3146 })?;
3147 return Ok(integrated_sas_jet_ghq(quadctx, mu, sigma, sas));
3148 }
3149 if matches!(link, LinkFunction::BetaLogistic) {
3150 let state = sas_link_state.ok_or_else(|| {
3151 EstimationError::InvalidInput(
3152 "state-less integrated Beta-Logistic jet is unsupported; explicit link state is required"
3153 .to_string(),
3154 )
3155 })?;
3156 return Ok(integrated_beta_logistic_jet_ghq(quadctx, mu, sigma, state));
3157 }
3158 integrated_inverse_link_jet(quadctx, link, mu, sigma)
3159}
3160
3161#[inline]
3176pub fn integrated_family_moments_jet(
3177 quadctx: &QuadratureContext,
3178 likelihood: &LikelihoodSpec,
3179 scale: LikelihoodScaleMetadata,
3180 eta: f64,
3181 se_eta: f64,
3182) -> Result<IntegratedMomentsJet, EstimationError> {
3183 const PROB_EPS: f64 = 1e-12;
3184 if !(eta.is_finite() && (-700.0..=700.0).contains(&eta)) {
3185 crate::bail_invalid_estim!(
3186 "integrated moments eta must be finite and within [-700, 700]; got {eta}"
3187 );
3188 }
3189 let e = eta;
3190 let se = se_eta.max(0.0);
3191 let mixture_link_state: Option<&MixtureLinkState> = likelihood.link.mixture_state();
3195 let sas_link_state: Option<&SasLinkState> = likelihood.link.sas_state();
3196 match &likelihood.response {
3197 ResponseFamily::Binomial => match &likelihood.link {
3198 InverseLink::Standard(StandardLink::Logit) => {
3199 let jet = integrated_inverse_link_jet(quadctx, LinkFunction::Logit, e, se)?;
3200 let mean = jet.mean;
3201 Ok(IntegratedMomentsJet {
3202 mean,
3203 variance: (mean * (1.0 - mean)).max(PROB_EPS),
3204 d1: jet.d1,
3205 d2: jet.d2,
3206 d3: jet.d3,
3207 mode: jet.mode,
3208 })
3209 }
3210 InverseLink::Standard(StandardLink::Probit) => {
3211 let jet = integrated_inverse_link_jet(quadctx, LinkFunction::Probit, e, se)?;
3212 let mean = jet.mean;
3213 Ok(IntegratedMomentsJet {
3214 mean,
3215 variance: (mean * (1.0 - mean)).max(PROB_EPS),
3216 d1: jet.d1,
3217 d2: jet.d2,
3218 d3: jet.d3,
3219 mode: jet.mode,
3220 })
3221 }
3222 InverseLink::Standard(StandardLink::CLogLog) => {
3223 let jet = integrated_inverse_link_jet(quadctx, LinkFunction::CLogLog, e, se)?;
3224 let mean = jet.mean;
3225 Ok(IntegratedMomentsJet {
3226 mean,
3227 variance: (mean * (1.0 - mean)).max(PROB_EPS),
3228 d1: jet.d1,
3229 d2: jet.d2,
3230 d3: jet.d3,
3231 mode: jet.mode,
3232 })
3233 }
3234 InverseLink::LatentCLogLog(_) => Err(EstimationError::InvalidInput(
3235 "Binomial+LatentCLogLog integrated moments require an explicit latent cloglog inverse-link state"
3236 .to_string(),
3237 )),
3238 InverseLink::Sas(_) => {
3239 let jet = integrated_inverse_link_jetwith_state(
3240 quadctx,
3241 LinkFunction::Sas,
3242 e,
3243 se,
3244 mixture_link_state,
3245 sas_link_state,
3246 )?;
3247 let mean = jet.mean;
3248 Ok(IntegratedMomentsJet {
3249 mean,
3250 variance: (mean * (1.0 - mean)).max(PROB_EPS),
3251 d1: jet.d1,
3252 d2: jet.d2,
3253 d3: jet.d3,
3254 mode: jet.mode,
3255 })
3256 }
3257 InverseLink::BetaLogistic(_) => {
3258 let jet = integrated_inverse_link_jetwith_state(
3259 quadctx,
3260 LinkFunction::BetaLogistic,
3261 e,
3262 se,
3263 mixture_link_state,
3264 sas_link_state,
3265 )?;
3266 let mean = jet.mean;
3267 Ok(IntegratedMomentsJet {
3268 mean,
3269 variance: (mean * (1.0 - mean)).max(PROB_EPS),
3270 d1: jet.d1,
3271 d2: jet.d2,
3272 d3: jet.d3,
3273 mode: jet.mode,
3274 })
3275 }
3276 InverseLink::Mixture(state) => {
3277 let jet = integrated_mixture_jet(quadctx, e, se, state)?;
3278 let mean = jet.mean;
3279 Ok(IntegratedMomentsJet {
3280 mean,
3281 variance: (mean * (1.0 - mean)).max(PROB_EPS),
3282 d1: jet.d1,
3283 d2: jet.d2,
3284 d3: jet.d3,
3285 mode: jet.mode,
3286 })
3287 }
3288 InverseLink::Standard(other) => Err(EstimationError::InvalidInput(format!(
3289 "Binomial response paired with unsupported standard link {other:?} for integrated moments"
3290 ))),
3291 },
3292 ResponseFamily::Gaussian => Ok(IntegratedMomentsJet {
3293 mean: e,
3294 variance: 1.0,
3295 d1: 1.0,
3296 d2: 0.0,
3297 d3: 0.0,
3298 mode: IntegratedExpectationMode::ExactClosedForm,
3299 }),
3300 ResponseFamily::RoystonParmar => {
3301 let jet = integrated_inverse_link_jetwith_state(
3302 quadctx,
3303 LinkFunction::CLogLog,
3304 e,
3305 se,
3306 mixture_link_state,
3307 sas_link_state,
3308 )?;
3309 let mean = (1.0 - jet.mean).clamp(0.0, 1.0);
3310 Ok(IntegratedMomentsJet {
3311 mean,
3312 variance: (mean * (1.0 - mean)).max(PROB_EPS),
3313 d1: -jet.d1,
3314 d2: -jet.d2,
3315 d3: -jet.d3,
3316 mode: jet.mode,
3317 })
3318 }
3319 ResponseFamily::Beta { phi } => {
3320 let jet = integrated_inverse_link_jet(quadctx, LinkFunction::Logit, e, se)?;
3321 let mean = jet.mean.clamp(PROB_EPS, 1.0 - PROB_EPS);
3322 Ok(IntegratedMomentsJet {
3323 mean,
3324 variance: (mean * (1.0 - mean) / (1.0 + phi.max(1e-12))).max(PROB_EPS),
3325 d1: jet.d1,
3326 d2: jet.d2,
3327 d3: jet.d3,
3328 mode: jet.mode,
3329 })
3330 }
3331 ResponseFamily::Poisson
3332 | ResponseFamily::Tweedie { .. }
3333 | ResponseFamily::NegativeBinomial { .. }
3334 | ResponseFamily::Gamma => {
3335 let s2 = se * se;
3340 let (mean, saturated) = safe_expwith_saturation(e + 0.5 * s2);
3341 let variance = match &likelihood.response {
3352 ResponseFamily::Poisson => mean,
3353 ResponseFamily::Tweedie { p } => {
3354 let phi = scale.fixed_phi().ok_or_else(|| {
3355 EstimationError::InvalidInput(format!(
3356 "Tweedie integrated variance requires dispersion φ in the scale \
3357 metadata (Var = φ·μ^p); got {scale:?} with no φ"
3358 ))
3359 })?;
3360 phi * mean.powf(*p)
3361 }
3362 ResponseFamily::NegativeBinomial { theta, .. } => {
3363 mean + mean * mean / theta.max(1e-12)
3364 }
3365 ResponseFamily::Gamma => {
3366 let shape = scale.gamma_shape().ok_or_else(|| {
3367 EstimationError::InvalidInput(format!(
3368 "Gamma integrated variance requires the shape k in the scale \
3369 metadata (Var = μ²/k = φ·μ²); got {scale:?} with no shape"
3370 ))
3371 })?;
3372 mean * mean / shape.max(1e-12)
3373 }
3374 other => {
3378 return Err(EstimationError::InvalidInput(format!(
3379 "integrated log-normal moments reached unexpected family {other:?}"
3380 )));
3381 }
3382 };
3383 Ok(IntegratedMomentsJet {
3384 mean,
3385 variance,
3386 d1: mean,
3387 d2: mean,
3388 d3: mean,
3389 mode: if saturated {
3390 IntegratedExpectationMode::ControlledAsymptotic
3391 } else {
3392 IntegratedExpectationMode::ExactClosedForm
3393 },
3394 })
3395 }
3396 }
3397}
3398
3399pub fn logit_posterior_meanwith_deriv_batch(
3402 ctx: &QuadratureContext,
3403 eta: &ndarray::Array1<f64>,
3404 se_eta: &ndarray::Array1<f64>,
3405) -> Result<(ndarray::Array1<f64>, ndarray::Array1<f64>), EstimationError> {
3406 use rayon::iter::{IntoParallelIterator, ParallelIterator};
3407 let n = eta.len();
3408 let pairs: Result<Vec<(f64, f64)>, _> = (0..n)
3410 .into_par_iter()
3411 .map(|i| {
3412 let integrated = integrated_inverse_link_mean_and_derivative(
3413 ctx,
3414 LinkFunction::Logit,
3415 eta[i],
3416 se_eta[i],
3417 )?;
3418 Ok::<_, EstimationError>((integrated.mean, integrated.dmean_dmu))
3419 })
3420 .collect();
3421 let pairs = pairs?;
3422 let mut mu = ndarray::Array1::<f64>::zeros(n);
3423 let mut dmu = ndarray::Array1::<f64>::zeros(n);
3424 for (i, (m, d)) in pairs.into_iter().enumerate() {
3425 mu[i] = m;
3426 dmu[i] = d;
3427 }
3428
3429 Ok((mu, dmu))
3430}
3431
3432pub fn logit_posterior_mean_batch(
3436 ctx: &QuadratureContext,
3437 eta: &ndarray::Array1<f64>,
3438 se_eta: &ndarray::Array1<f64>,
3439) -> Result<ndarray::Array1<f64>, EstimationError> {
3440 use rayon::iter::{IntoParallelIterator, ParallelIterator};
3441 let n = eta.len();
3442 let values: Result<Vec<f64>, EstimationError> = (0..n)
3443 .into_par_iter()
3444 .map(|i| {
3445 integrated_inverse_link_mean_and_derivative(ctx, LinkFunction::Logit, eta[i], se_eta[i])
3446 .map(|integrated| integrated.mean)
3447 })
3448 .collect();
3449 Ok(ndarray::Array1::from_vec(values?))
3450}
3451
/// Accumulator contract used by the Gauss–Hermite sums in this file: a value that can be
/// seeded at zero, absorb weighted samples, and be rescaled once at the end.
pub trait GhqValue: Sized {
    /// Additive identity used to seed an accumulation.
    fn zero() -> Self;
    /// Adds `weight * value` to `self`, component by component.
    fn addweighted(&mut self, weight: f64, value: Self);
    /// Returns `self` with every component multiplied by `factor`.
    fn scale(self, factor: f64) -> Self;
}

/// Scalar accumulator.
impl GhqValue for f64 {
    #[inline]
    fn zero() -> Self {
        0.0
    }

    #[inline]
    fn addweighted(&mut self, weight: f64, value: Self) {
        let contribution = weight * value;
        *self += contribution;
    }

    #[inline]
    fn scale(self, factor: f64) -> Self {
        self * factor
    }
}

/// Pair accumulator; each component is accumulated independently.
impl GhqValue for (f64, f64) {
    #[inline]
    fn zero() -> Self {
        (0.0, 0.0)
    }

    #[inline]
    fn addweighted(&mut self, weight: f64, value: Self) {
        let (v0, v1) = value;
        self.0 += weight * v0;
        self.1 += weight * v1;
    }

    #[inline]
    fn scale(self, factor: f64) -> Self {
        let (a, b) = self;
        (a * factor, b * factor)
    }
}

/// Four-component accumulator; each component is accumulated independently.
impl GhqValue for (f64, f64, f64, f64) {
    #[inline]
    fn zero() -> Self {
        (0.0, 0.0, 0.0, 0.0)
    }

    #[inline]
    fn addweighted(&mut self, weight: f64, value: Self) {
        let (v0, v1, v2, v3) = value;
        self.0 += weight * v0;
        self.1 += weight * v1;
        self.2 += weight * v2;
        self.3 += weight * v3;
    }

    #[inline]
    fn scale(self, factor: f64) -> Self {
        let (a, b, c, d) = self;
        (a * factor, b * factor, c * factor, d * factor)
    }
}

/// Six-component accumulator; each component is accumulated independently.
impl GhqValue for (f64, f64, f64, f64, f64, f64) {
    #[inline]
    fn zero() -> Self {
        (0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    }

    #[inline]
    fn addweighted(&mut self, weight: f64, value: Self) {
        let (v0, v1, v2, v3, v4, v5) = value;
        self.0 += weight * v0;
        self.1 += weight * v1;
        self.2 += weight * v2;
        self.3 += weight * v3;
        self.4 += weight * v4;
        self.5 += weight * v5;
    }

    #[inline]
    fn scale(self, factor: f64) -> Self {
        let (a, b, c, d, e, g) = self;
        (
            a * factor,
            b * factor,
            c * factor,
            d * factor,
            e * factor,
            g * factor,
        )
    }
}
3546
3547#[inline]
3548fn integrate_normal_ghq_adaptive<F, R>(ctx: &QuadratureContext, eta: f64, se_eta: f64, f: F) -> R
3549where
3550 F: Fn(f64) -> R,
3551 R: GhqValue,
3552{
3553 if se_eta < 1e-10 {
3554 return f(eta);
3555 }
3556 let n = adaptive_point_count_from_sd(se_eta.abs());
3557 with_gh_nodesweights(ctx, n, |nodes, weights| {
3558 let scale = SQRT_2 * se_eta;
3559 let mut sum = R::zero();
3560 for i in 0..n {
3561 sum.addweighted(weights[i], f(eta + scale * nodes[i]));
3562 }
3563 sum.scale(1.0 / std::f64::consts::PI.sqrt())
3564 })
3565}
3566
3567#[inline]
3568fn integrated_probit_jet(mu: f64, sigma: f64) -> IntegratedInverseLinkJet {
3569 if sigma <= 1e-10 {
3570 let z = mu.clamp(-30.0, 30.0);
3571 let clamp_active = z != mu;
3572 let pdf = gam_math::probability::normal_pdf(z);
3573 return IntegratedInverseLinkJet {
3574 mean: gam_math::probability::normal_cdf(z),
3575 d1: if clamp_active { 0.0 } else { pdf },
3576 d2: if clamp_active { 0.0 } else { -z * pdf },
3577 d3: if clamp_active {
3578 0.0
3579 } else {
3580 (z * z - 1.0) * pdf
3581 },
3582 mode: IntegratedExpectationMode::ExactClosedForm,
3583 };
3584 }
3585 let s = (1.0 + sigma * sigma).sqrt();
3586 let z = mu / s;
3587 let pdf = gam_math::probability::normal_pdf(z);
3588 IntegratedInverseLinkJet {
3589 mean: gam_math::probability::normal_cdf(z),
3590 d1: pdf / s,
3591 d2: -z * pdf / (s * s),
3592 d3: (z * z - 1.0) * pdf / (s * s * s),
3593 mode: IntegratedExpectationMode::ExactClosedForm,
3594 }
3595}
3596
3597#[inline]
3598fn integrated_logit_jet_ghq(
3599 ctx: &QuadratureContext,
3600 mu: f64,
3601 sigma: f64,
3602) -> IntegratedInverseLinkJet {
3603 let (mean, d1, d2, d3) = integrate_normal_ghq_adaptive(ctx, mu, sigma, |x| {
3604 component_point_jet(LinkComponent::Logit, x)
3605 });
3606 IntegratedInverseLinkJet {
3607 mean,
3608 d1: d1.max(0.0),
3609 d2,
3610 d3,
3611 mode: if sigma <= 1e-10 {
3612 IntegratedExpectationMode::ExactClosedForm
3613 } else {
3614 IntegratedExpectationMode::QuadratureFallback
3615 },
3616 }
3617}
3618
3619#[inline]
3620fn cloglog_inverse_link_controlled_values(
3621 ctx: &QuadratureContext,
3622 mu: f64,
3623 sigma: f64,
3624 max_order: usize,
3625) -> ([f64; 6], IntegratedExpectationMode) {
3626 assert!(max_order <= 5);
3627 if sigma <= 1e-10 {
3628 let (mean, d1, d2, d3, d4, d5) = cloglog_point_jet5(mu);
3629 return (
3630 [mean, d1, d2, d3, d4, d5],
3631 IntegratedExpectationMode::ExactClosedForm,
3632 );
3633 }
3634
3635 let (k, log_k0, mode) = latent_cloglog_kernel_terms(ctx, mu, sigma, max_order);
3636 let mut values = [0.0; 6];
3637 values[0] = if log_k0.is_finite() {
3638 -log_k0.exp_m1()
3639 } else {
3640 1.0
3641 };
3642 values[1] = k[1].max(0.0);
3643 if sigma > CLOGLOG_JET_MOMENT_SIGMA_MAX {
3644 if max_order >= 2 {
3645 values[2] = integrate_normal_adaptive(mu, sigma, |x| cloglog_point_jet5(x).2);
3646 }
3647 if max_order >= 3 {
3648 values[3] = integrate_normal_adaptive(mu, sigma, |x| cloglog_point_jet5(x).3);
3649 }
3650 if max_order >= 4 {
3651 values[4] = integrate_normal_adaptive(mu, sigma, |x| cloglog_point_jet5(x).4);
3652 }
3653 if max_order >= 5 {
3654 values[5] = integrate_normal_adaptive(mu, sigma, |x| cloglog_point_jet5(x).5);
3655 }
3656 return (
3657 values,
3658 worse_integrated_expectation_mode(mode, IntegratedExpectationMode::QuadratureFallback),
3659 );
3660 }
3661 if max_order >= 2 {
3662 values[2] = k[1] - k[2];
3663 }
3664 if max_order >= 3 {
3665 values[3] = k[1] - 3.0 * k[2] + k[3];
3666 }
3667 if max_order >= 4 {
3668 values[4] = k[1] - 7.0 * k[2] + 6.0 * k[3] - k[4];
3669 }
3670 if max_order >= 5 {
3671 values[5] = k[1] - 15.0 * k[2] + 25.0 * k[3] - 10.0 * k[4] + k[5];
3672 }
3673 (values, mode)
3674}
3675
3676#[inline]
3677pub(crate) fn latent_cloglog_inverse_link_jet5_controlled(
3678 ctx: &QuadratureContext,
3679 mu: f64,
3680 sigma: f64,
3681) -> IntegratedInverseLinkJet5 {
3682 let (values, mode) = cloglog_inverse_link_controlled_values(ctx, mu, sigma, 5);
3683 IntegratedInverseLinkJet5 {
3684 mean: values[0],
3685 d1: values[1],
3686 d2: values[2],
3687 d3: values[3],
3688 d4: values[4],
3689 d5: values[5],
3690 mode,
3691 }
3692}
3693
/// Mean and first five derivatives of the integrated latent cloglog inverse link, tagged
/// with the evaluation mode that produced them. Built by `latent_cloglog_jet5`.
#[derive(Clone, Copy, Debug)]
pub struct LatentCLogLogJet5 {
    /// Integrated mean.
    pub mean: f64,
    /// First derivative of the mean.
    pub d1: f64,
    /// Second derivative of the mean.
    pub d2: f64,
    /// Third derivative of the mean.
    pub d3: f64,
    /// Fourth derivative of the mean.
    pub d4: f64,
    /// Fifth derivative of the mean.
    pub d5: f64,
    /// How the values were obtained (closed form, special function, asymptotic, quadrature).
    pub mode: IntegratedExpectationMode,
}
3713
3714pub fn latent_cloglog_jet5(
3715 quadctx: &QuadratureContext,
3716 eta: f64,
3717 sigma: f64,
3718) -> Result<LatentCLogLogJet5, EstimationError> {
3719 validate_latent_cloglog_inputs(eta, sigma)?;
3720 let jet = latent_cloglog_inverse_link_jet5_controlled(quadctx, eta, sigma);
3726 Ok(LatentCLogLogJet5 {
3727 mean: jet.mean,
3728 d1: jet.d1,
3729 d2: jet.d2,
3730 d3: jet.d3,
3731 d4: jet.d4,
3732 d5: jet.d5,
3733 mode: jet.mode,
3734 })
3735}
3736
3737#[inline]
3738pub fn latent_cloglog_inverse_link_jet(
3739 quadctx: &QuadratureContext,
3740 eta: f64,
3741 sigma: f64,
3742) -> Result<IntegratedInverseLinkJet, EstimationError> {
3743 let jet = latent_cloglog_jet5(quadctx, eta, sigma)?;
3744 Ok(IntegratedInverseLinkJet {
3745 mean: jet.mean,
3746 d1: jet.d1,
3747 d2: jet.d2,
3748 d3: jet.d3,
3749 mode: jet.mode,
3750 })
3751}
3752
3753#[inline]
3754fn integrated_cloglog_inverse_link_jet_controlled(
3755 ctx: &QuadratureContext,
3756 mu: f64,
3757 sigma: f64,
3758) -> IntegratedInverseLinkJet {
3759 let (values, mode) = cloglog_inverse_link_controlled_values(ctx, mu, sigma, 3);
3760 IntegratedInverseLinkJet {
3761 mean: values[0],
3762 d1: values[1],
3763 d2: values[2],
3764 d3: values[3],
3765 mode,
3766 }
3767}
3768
3769#[inline]
3770fn latent_cloglog_kernel_terms(
3771 ctx: &QuadratureContext,
3772 mu: f64,
3773 sigma: f64,
3774 max_order: usize,
3775) -> ([f64; 6], f64, IntegratedExpectationMode) {
3776 let sigma2 = sigma * sigma;
3777 let mut k = [0.0; 6];
3778 let mut log_k0 = f64::NEG_INFINITY;
3779 let mut mode = IntegratedExpectationMode::ExactClosedForm;
3780
3781 for (order, out) in k.iter_mut().enumerate().take(max_order + 1) {
3782 let kf = order as f64;
3783 let shifted_mu = mu + kf * sigma2;
3784 let (log_survival, term_mode) =
3792 cloglog_log_survival_term_controlled(ctx, shifted_mu, sigma);
3793 mode = worse_integrated_expectation_mode(mode, term_mode);
3794
3795 let log_value = kf * mu + 0.5 * kf * kf * sigma2 + log_survival;
3796 if order == 0 {
3797 log_k0 = log_value;
3798 }
3799 if !log_value.is_finite() {
3800 *out = 0.0;
3801 continue;
3802 }
3803 let upper = if order == 0 {
3804 1.0
3805 } else {
3806 let k_over_e = kf / std::f64::consts::E;
3807 k_over_e.powf(kf)
3808 };
3809 *out = safe_exp(log_value).clamp(0.0, upper);
3810 }
3811
3812 (k, log_k0, mode)
3813}
3814
/// Public scalar entry point for the adaptive one-dimensional Gauss–Hermite expectation.
///
/// Forwards its arguments unchanged to `integrate_normal_ghq_adaptive`, which evaluates `f`
/// at `eta` directly when `se_eta < 1e-10` and otherwise integrates it against a normal
/// with mean `eta` and standard deviation `se_eta`.
#[inline]
pub fn normal_expectation_1d_adaptive<F>(
    ctx: &QuadratureContext,
    eta: f64,
    se_eta: f64,
    f: F,
) -> f64
where
    F: Fn(f64) -> f64,
{
    integrate_normal_ghq_adaptive(ctx, eta, se_eta, f)
}
3827
/// Public pair-valued entry point for the adaptive one-dimensional Gauss–Hermite expectation.
///
/// Forwards its arguments unchanged to `integrate_normal_ghq_adaptive`, so both components
/// returned by `f` are integrated over the same nodes in a single pass.
#[inline]
pub fn normal_expectation_1d_adaptive_pair<F>(
    ctx: &QuadratureContext,
    eta: f64,
    se_eta: f64,
    f: F,
) -> (f64, f64)
where
    F: Fn(f64) -> (f64, f64),
{
    integrate_normal_ghq_adaptive(ctx, eta, se_eta, f)
}
3840
/// Picks the Gauss–Hermite node count for a given standard deviation.
///
/// Returns 51 above 2.5, 31 above 0.5, 21 above 0.35, 15 above 0.1, and 7 otherwise. All
/// comparisons are strict, and any non-finite input (NaN or infinity) maps to 7.
fn adaptive_point_count_from_sd(max_sd: f64) -> usize {
    // Ordered from the widest spread down; the first threshold exceeded wins.
    const THRESHOLDS: [(f64, usize); 4] = [(2.5, 51), (0.5, 31), (0.35, 21), (0.1, 15)];
    const BASE_COUNT: usize = 7;

    if !max_sd.is_finite() {
        return BASE_COUNT;
    }
    THRESHOLDS
        .iter()
        .find(|(threshold, _)| max_sd > *threshold)
        .map_or(BASE_COUNT, |&(_, count)| count)
}
3863
3864#[inline]
3865fn with_gh_nodesweights<R>(
3866 ctx: &QuadratureContext,
3867 n: usize,
3868 f: impl FnOnce(&[f64], &[f64]) -> R,
3869) -> R {
3870 if n == 7 {
3871 let gh = ctx.gauss_hermite();
3872 f(&gh.nodes, &gh.weights)
3873 } else {
3874 let gh = ctx.gauss_hermite_n(n);
3875 f(&gh.nodes, &gh.weights)
3876 }
3877}
3878
/// Lower-triangular Cholesky factor of a fixed-size matrix.
///
/// Only the lower triangle of `cov` is read. Returns `None` as soon as a pivot is
/// non-finite or not strictly positive; otherwise returns `L` with `L * L^T == cov` (up to
/// rounding) and zeros above the diagonal.
#[inline]
fn cholesky_static<const D: usize>(cov: &[[f64; D]; D]) -> Option<[[f64; D]; D]> {
    let mut factor = [[0.0_f64; D]; D];
    for row in 0..D {
        for col in 0..=row {
            // Remove the contribution of the columns already factored.
            let residual = (0..col).fold(cov[row][col], |acc, k| {
                acc - factor[row][k] * factor[col][k]
            });
            factor[row][col] = if row == col {
                if !(residual.is_finite() && residual > 0.0) {
                    return None;
                }
                residual.sqrt()
            } else {
                residual / factor[col][col]
            };
        }
    }
    Some(factor)
}
3909
3910#[inline]
3913fn cholesky_static_with_jitter<const D: usize>(cov: &[[f64; D]; D]) -> Option<[[f64; D]; D]> {
3914 if D == 0 {
3915 return None;
3916 }
3917 for retry in 0..8 {
3918 let jitter = if retry == 0 {
3919 0.0
3920 } else {
3921 1e-12 * 10f64.powi(retry - 1)
3922 };
3923 if jitter == 0.0 {
3924 if let Some(l) = cholesky_static::<D>(cov) {
3925 return Some(l);
3926 }
3927 } else {
3928 let mut base = *cov;
3929 for i in 0..D {
3930 base[i][i] = cov[i][i] + jitter;
3931 }
3932 if let Some(l) = cholesky_static::<D>(&base) {
3933 return Some(l);
3934 }
3935 }
3936 }
3937 None
3938}
3939
3940#[inline]
3941fn adaptive_point_countwith_cap(max_sd: f64, max_n: usize) -> usize {
3942 adaptive_point_count_from_sd(max_sd).min(max_n)
3943}
3944
3945#[inline]
3946fn ghq_nd_integrate_try<const D: usize, F, R, E>(
3947 ctx: &QuadratureContext,
3948 mu: [f64; D],
3949 cov: [[f64; D]; D],
3950 max_n: usize,
3951 f: F,
3952) -> Result<Option<R>, E>
3953where
3954 F: Fn([f64; D]) -> Result<R, E>,
3955 R: GhqValue,
3956{
3957 let mut maxvar = 0.0_f64;
3958 for (i, row) in cov.iter().enumerate() {
3959 maxvar = maxvar.max(row[i]).max(0.0);
3960 }
3961 let n = adaptive_point_countwith_cap(maxvar.sqrt(), max_n);
3962
3963 let mut cov_arr = cov;
3968 for i in 0..D {
3969 cov_arr[i][i] = cov_arr[i][i].max(0.0);
3970 }
3971 let Some(l) = cholesky_static_with_jitter::<D>(&cov_arr) else {
3972 return Ok(None);
3973 };
3974 let norm = 1.0 / std::f64::consts::PI.powf(0.5 * D as f64);
3975
3976 with_gh_nodesweights(ctx, n, |nodes, weights| {
3977 let mut acc = R::zero();
3978 let mut idx = [0usize; D];
3979 loop {
3980 let mut z = [0.0_f64; D];
3981 let mut weight = 1.0_f64;
3982 for d in 0..D {
3983 z[d] = SQRT_2 * nodes[idx[d]];
3984 weight *= weights[idx[d]];
3985 }
3986
3987 let mut x = mu;
3988 for row in 0..D {
3989 let mut dot = 0.0_f64;
3990 for (col, zc) in z.iter().enumerate().take(row + 1) {
3991 dot += l[row][col] * *zc;
3992 }
3993 x[row] += dot;
3994 }
3995 acc.addweighted(weight, f(x)?);
3996
3997 let mut carry = true;
3998 for d in (0..D).rev() {
3999 idx[d] += 1;
4000 if idx[d] < n {
4001 carry = false;
4002 break;
4003 }
4004 idx[d] = 0;
4005 }
4006 if carry {
4007 break;
4008 }
4009 }
4010 Ok(Some(acc.scale(norm)))
4011 })
4012}
4013
4014#[inline]
4015fn ghq_nd_integrate<const D: usize, F, R>(
4016 ctx: &QuadratureContext,
4017 mu: [f64; D],
4018 cov: [[f64; D]; D],
4019 max_n: usize,
4020 f: F,
4021) -> Option<R>
4022where
4023 F: Fn([f64; D]) -> R,
4024 R: GhqValue,
4025{
4026 match ghq_nd_integrate_try::<D, _, R, Infallible>(ctx, mu, cov, max_n, |x| Ok(f(x))) {
4027 Ok(v) => v,
4028 Err(e) => match e {},
4029 }
4030}
4031
/// Fallible N-dimensional Gauss–Hermite integration.
///
/// Forwards every argument unchanged to `ghq_nd_integrate_try`: `Ok(None)` signals that the
/// covariance could not be factored, and `Err` carries the first error returned by `f`.
#[inline]
fn ghq_nd_integrate_result<const D: usize, F, R, E>(
    ctx: &QuadratureContext,
    mu: [f64; D],
    cov: [[f64; D]; D],
    max_n: usize,
    f: F,
) -> Result<Option<R>, E>
where
    F: Fn([f64; D]) -> Result<R, E>,
    R: GhqValue,
{
    ghq_nd_integrate_try::<D, _, R, E>(ctx, mu, cov, max_n, f)
}
4046
4047pub fn normal_expectation_nd_adaptive<const D: usize, F>(
4049 ctx: &QuadratureContext,
4050 mu: [f64; D],
4051 cov: [[f64; D]; D],
4052 max_n: usize,
4053 f: F,
4054) -> f64
4055where
4056 F: Fn([f64; D]) -> f64,
4057{
4058 match ghq_nd_integrate::<D, _, f64>(ctx, mu, cov, max_n, &f) {
4059 Some(v) => v,
4060 None => f(mu),
4061 }
4062}
4063
4064pub fn normal_expectation_nd_adaptive_result<const D: usize, F, R, E>(
4066 ctx: &QuadratureContext,
4067 mu: [f64; D],
4068 cov: [[f64; D]; D],
4069 max_n: usize,
4070 f: F,
4071) -> Result<R, E>
4072where
4073 F: Fn([f64; D]) -> Result<R, E>,
4074 R: GhqValue,
4075{
4076 match ghq_nd_integrate_result::<D, _, R, E>(ctx, mu, cov, max_n, &f)? {
4077 Some(v) => Ok(v),
4078 None => f(mu),
4079 }
4080}
4081
4082pub fn normal_expectation_2d_adaptive_result<F, E>(
4084 ctx: &QuadratureContext,
4085 mu: [f64; 2],
4086 cov: [[f64; 2]; 2],
4087 f: F,
4088) -> Result<f64, E>
4089where
4090 F: Fn(f64, f64) -> Result<f64, E>,
4091{
4092 normal_expectation_nd_adaptive_result::<2, _, _, E>(ctx, mu, cov, 21, |x| f(x[0], x[1]))
4093}
4094
4095pub fn normal_expectation_3d_adaptive<F>(
4097 ctx: &QuadratureContext,
4098 mu: [f64; 3],
4099 cov: [[f64; 3]; 3],
4100 f: F,
4101) -> f64
4102where
4103 F: Fn(f64, f64, f64) -> f64,
4104{
4105 normal_expectation_nd_adaptive::<3, _>(ctx, mu, cov, 15, |x| f(x[0], x[1], x[2]))
4107}
4108
4109#[inline]
4128pub fn probit_posterior_mean(eta: f64, se_eta: f64) -> f64 {
4129 if se_eta < 1e-10 {
4130 return gam_math::probability::normal_cdf(eta);
4131 }
4132 let denom = (1.0 + se_eta * se_eta).sqrt();
4133 gam_math::probability::normal_cdf(eta / denom)
4134}
4135
4136#[inline]
4137pub fn logit_posterior_meanvariance(ctx: &QuadratureContext, eta: f64, se_eta: f64) -> (f64, f64) {
4138 let (m1, m2) = integrate_normal_ghq_adaptive(ctx, eta, se_eta, |x| {
4139 let p = sigmoid(x);
4140 (p, p * p)
4141 });
4142 let m1 = m1.clamp(0.0, 1.0);
4143 let m2 = m2.clamp(0.0, 1.0);
4144 (m1, (m2 - m1 * m1).max(0.0))
4145}
4146
4147#[inline]
4148pub fn probit_posterior_meanvariance(ctx: &QuadratureContext, eta: f64, se_eta: f64) -> (f64, f64) {
4149 let m1 = probit_posterior_mean(eta, se_eta);
4150 let m2 = integrate_normal_ghq_adaptive(ctx, eta, se_eta, |x| {
4151 let p = gam_math::probability::normal_cdf(x);
4152 p * p
4153 })
4154 .clamp(0.0, 1.0);
4155 (m1, (m2 - m1 * m1).max(0.0))
4156}
4157
4158#[inline]
4159pub fn cloglog_posterior_meanvariance(
4160 ctx: &QuadratureContext,
4161 eta: f64,
4162 se_eta: f64,
4163) -> (f64, f64) {
4164 if !(eta.is_finite() && se_eta.is_finite()) || se_eta <= CLOGLOG_SIGMA_DEGENERATE {
4184 return (cloglog_mean_exact(eta), 0.0);
4185 }
4186 let (survival, _) = cloglog_survival_term_controlled(ctx, eta, se_eta);
4187 let (survival_sq, _) = cloglog_survivalsecond_moment_controlled(ctx, eta, se_eta);
4188 let mean = cloglog_mean_from_survival(survival);
4189 let variance = (survival_sq - survival * survival).max(0.0);
4190 (mean, variance)
4191}
4192
4193#[inline]
4227pub fn cloglog_posterior_mean(ctx: &QuadratureContext, eta: f64, se_eta: f64) -> f64 {
4228 if !(eta.is_finite() && se_eta.is_finite()) || se_eta <= CLOGLOG_SIGMA_DEGENERATE {
4232 return cloglog_mean_exact(eta);
4233 }
4234 let (survival, _) = cloglog_survival_term_controlled(ctx, eta, se_eta);
4235 cloglog_mean_from_survival(survival)
4236}
4237
4238#[inline]
4252pub fn survival_posterior_mean(ctx: &QuadratureContext, eta: f64, se_eta: f64) -> f64 {
4253 cloglog_survival_term_controlled(ctx, eta, se_eta)
4254 .0
4255 .clamp(0.0, 1.0)
4256}
4257
4258#[inline]
4259pub fn survival_posterior_meanvariance(
4260 ctx: &QuadratureContext,
4261 eta: f64,
4262 se_eta: f64,
4263) -> (f64, f64) {
4264 let (m1, _) = cloglog_survival_term_controlled(ctx, eta, se_eta);
4265 let (m2, _) = cloglog_survivalsecond_moment_controlled(ctx, eta, se_eta);
4266 (m1.clamp(0.0, 1.0), (m2 - m1 * m1).max(0.0))
4267}
4268
4269pub fn logit_posterior_mean_exact(mu: f64, sigma: f64) -> f64 {
4345 if !(mu.is_finite() && sigma.is_finite()) || sigma <= 0.0 {
4346 return sigmoid(mu);
4347 }
4348 if sigma < LOGIT_SIGMA_DEGENERATE {
4349 return sigmoid(mu);
4352 }
4353
4354 let inv_sqrt_pi = 0.5 * std::f64::consts::FRAC_2_SQRT_PI; let sqrt2_sigma = SQRT_2 * sigma;
4356 let coeff = (2.0_f64 * std::f64::consts::PI).sqrt() / sigma; let c = -mu / sqrt2_sigma; let beta = std::f64::consts::PI / sqrt2_sigma; let r2 = FADDEEVA_ASYMPTOTIC_RADIUS * FADDEEVA_ASYMPTOTIC_RADIUS;
4360
4361 let mut corr = 0.0_f64;
4367 let mut n = 1usize;
4368 let tail_start = loop {
4369 let b = (2.0 * (n as f64) - 1.0) * beta;
4370 let abs_xi2 = c * c + b * b;
4371 if abs_xi2 > r2 && n >= FADDEEVA_TAIL_MIN_INDEX {
4372 break n;
4373 }
4374 let xi = Complex { re: c, im: b };
4375 let d = if abs_xi2 > r2 {
4376 inv_sqrt_pi * faddeeva_asymptotic_a(xi).re
4378 } else {
4379 faddeeva_upper_halfplane(xi).im - inv_sqrt_pi * c / abs_xi2
4380 };
4381 corr += d;
4382 n += 1;
4383 };
4384
4385 corr += faddeeva_pole_series_em_tail(c, beta, tail_start, inv_sqrt_pi);
4386
4387 sigmoid(mu) - coeff * corr
4388}
4389
/// Smallest pole index at which `logit_posterior_mean_exact` may stop its explicit sum and
/// hand the remainder to `faddeeva_pole_series_em_tail`.
const FADDEEVA_TAIL_MIN_INDEX: usize = 48;
/// Radius whose square is compared with `|xi|^2` to choose between the direct Faddeeva
/// evaluation (inside) and the asymptotic series (outside).
const FADDEEVA_ASYMPTOTIC_RADIUS: f64 = 7.0;
/// Number of terms summed by the series loops in `faddeeva_asymptotic_a` and
/// `faddeeva_pole_series_em_tail`.
const FADDEEVA_ASYMPTOTIC_TERMS: usize = 14;
4400
4401fn faddeeva_asymptotic_a(xi: Complex) -> Complex {
4405 let inv = complex_div(Complex { re: 1.0, im: 0.0 }, xi);
4406 let inv2 = complexmul(inv, inv);
4407 let mut xp = complexmul(inv2, inv); let mut cm = 0.5_f64; let mut s = Complex::default();
4410 for m in 1..=FADDEEVA_ASYMPTOTIC_TERMS {
4411 s = complex_add(
4412 s,
4413 Complex {
4414 re: cm * xp.re,
4415 im: cm * xp.im,
4416 },
4417 );
4418 cm *= (2.0 * (m as f64) + 1.0) / 2.0; xp = complexmul(xp, inv2);
4420 }
4421 s
4422}
4423
4424fn faddeeva_pole_series_em_tail(c: f64, beta: f64, tail_start: usize, inv_sqrt_pi: f64) -> f64 {
4433 let b_a = (2.0 * (tail_start as f64) - 1.0) * beta;
4434 let xi = Complex { re: c, im: b_a };
4435 let inv = complex_div(Complex { re: 1.0, im: 0.0 }, xi);
4436 let inv2 = complexmul(inv, inv);
4437 let two_i_beta = Complex {
4439 re: 0.0,
4440 im: 2.0 * beta,
4441 };
4442
4443 let mut s = Complex::default(); let mut a_acc = Complex::default(); let mut fp_inner = Complex::default(); let mut x2m = inv2; let mut x2m1 = complexmul(inv2, inv); let mut x2m2 = complexmul(inv2, inv2); let mut cm = 0.5_f64;
4451 for m in 1..=FADDEEVA_ASYMPTOTIC_TERMS {
4452 let mf = m as f64;
4453 let inv_4ibm = Complex {
4455 re: 0.0,
4456 im: -1.0 / (4.0 * beta * mf),
4457 };
4458 s = complex_add(
4459 s,
4460 complexmul(
4461 Complex {
4462 re: cm * x2m.re,
4463 im: cm * x2m.im,
4464 },
4465 inv_4ibm,
4466 ),
4467 );
4468 a_acc = complex_add(
4469 a_acc,
4470 Complex {
4471 re: cm * x2m1.re,
4472 im: cm * x2m1.im,
4473 },
4474 );
4475 let fc = cm * (-(2.0 * mf + 1.0));
4476 fp_inner = complex_add(
4477 fp_inner,
4478 Complex {
4479 re: fc * x2m2.re,
4480 im: fc * x2m2.im,
4481 },
4482 );
4483 cm *= (2.0 * mf + 1.0) / 2.0;
4484 x2m = complexmul(x2m, inv2);
4485 x2m1 = complexmul(x2m1, inv2);
4486 x2m2 = complexmul(x2m2, inv2);
4487 }
4488
4489 s = complex_add(
4491 s,
4492 Complex {
4493 re: 0.5 * a_acc.re,
4494 im: 0.5 * a_acc.im,
4495 },
4496 );
4497 let fprime = complexmul(two_i_beta, fp_inner);
4499 s = complex_add(
4500 s,
4501 Complex {
4502 re: -fprime.re / 12.0,
4503 im: -fprime.im / 12.0,
4504 },
4505 );
4506
4507 inv_sqrt_pi * s.re
4510}
4511
4512fn faddeeva_upper_halfplane(z: Complex) -> Complex {
4525 let (l, coeffs) = faddeeva_weideman_coeffs();
4526 let iz = Complex {
4527 re: -z.im,
4528 im: z.re,
4529 }; let l_minus = Complex {
4531 re: l - iz.re,
4532 im: -iz.im,
4533 }; let l_plus = Complex {
4535 re: l + iz.re,
4536 im: iz.im,
4537 }; let zz = complex_div(l_plus, l_minus); let mut p = Complex {
4541 re: coeffs[0],
4542 im: 0.0,
4543 };
4544 for &c in &coeffs[1..] {
4545 p = complex_add(complexmul(p, zz), Complex { re: c, im: 0.0 });
4546 }
4547 let l_minus_sq = complexmul(l_minus, l_minus);
4548 let term1 = complex_div(
4549 Complex {
4550 re: 2.0 * p.re,
4551 im: 2.0 * p.im,
4552 },
4553 l_minus_sq,
4554 );
4555 let inv_sqrt_pi = 0.5 * std::f64::consts::FRAC_2_SQRT_PI;
4556 let term2 = complex_div(
4557 Complex {
4558 re: inv_sqrt_pi,
4559 im: 0.0,
4560 },
4561 l_minus,
4562 );
4563 complex_add(term1, term2)
4564}
4565
/// Number of coefficients in the table built by `faddeeva_weideman_coeffs`; it also fixes
/// the scale `L = sqrt(N / sqrt(2))` used there.
const FADDEEVA_WEIDEMAN_N: usize = 44;
4569
4570fn faddeeva_weideman_coeffs() -> &'static (f64, [f64; FADDEEVA_WEIDEMAN_N]) {
4576 static CACHE: OnceLock<(f64, [f64; FADDEEVA_WEIDEMAN_N])> = OnceLock::new();
4577 CACHE.get_or_init(|| {
4578 let n = FADDEEVA_WEIDEMAN_N;
4579 let l = (n as f64 / SQRT_2).sqrt();
4580 let m = 2 * n;
4581 let m2 = 2 * m; let mut f = vec![0.0_f64; m2];
4585 for (idx, fi) in f.iter_mut().enumerate().skip(1) {
4586 let k = (idx as isize - 1) - (m as isize - 1);
4587 let theta = (k as f64) * std::f64::consts::PI / (m as f64);
4588 let t = l * (0.5 * theta).tan();
4589 *fi = (-t * t).exp() * (l * l + t * t);
4590 }
4591 let half = m2 / 2;
4594 let mut coeffs = [0.0_f64; FADDEEVA_WEIDEMAN_N];
4595 for j in 1..=n {
4596 let mut acc = 0.0_f64;
4597 for (p, _) in f.iter().enumerate() {
4598 let fp = f[(p + half) % m2];
4599 if fp != 0.0 {
4600 acc += fp
4601 * (-2.0 * std::f64::consts::PI * (j as f64) * (p as f64) / (m2 as f64))
4602 .cos();
4603 }
4604 }
4605 coeffs[n - j] = acc / (m2 as f64);
4607 }
4608 (l, coeffs)
4609 })
4610}
4611
4612#[inline]
4614fn sigmoid(x: f64) -> f64 {
4615 let x_clamped = x.clamp(-QUADRATURE_EXP_LOG_MAX, QUADRATURE_EXP_LOG_MAX);
4616 1.0 / (1.0 + f64::exp(-x_clamped))
4617}
4618
/// Value and partial derivatives, up to total order four, of the cloglog mean integrated
/// over a normal, taken with respect to the normal's mean (`mu`) and standard deviation
/// (`sigma`). Produced by `cloglog_ghq_derivatives`.
///
/// Field names list the differentiation variables in order: `l_mumusigma` is the second
/// derivative in `mu` followed by one derivative in `sigma`.
#[derive(Clone, Copy, Debug)]
pub struct CLogLogConvolutionDerivatives {
    /// Integrated value.
    pub l: f64,

    /// First derivative in `mu`.
    pub l_mu: f64,
    /// First derivative in `sigma`.
    pub l_sigma: f64,

    /// Second derivative in `mu`.
    pub l_mumu: f64,
    /// Mixed second derivative, once in `mu` and once in `sigma`.
    pub l_musigma: f64,
    /// Second derivative in `sigma`.
    pub l_sigmasigma: f64,

    /// Third derivative in `mu`.
    pub l_mumumu: f64,
    /// Third derivative, twice in `mu` and once in `sigma`.
    pub l_mumusigma: f64,
    /// Third derivative, once in `mu` and twice in `sigma`.
    pub l_musigmasigma: f64,
    /// Third derivative in `sigma`.
    pub l_sigmasigmasigma: f64,

    /// Fourth derivative in `mu`.
    pub l_mumumumu: f64,
    /// Fourth derivative, three times in `mu` and once in `sigma`.
    pub l_mumumusigma: f64,
    /// Fourth derivative, twice in `mu` and twice in `sigma`.
    pub l_mumusigmasigma: f64,
    /// Fourth derivative, once in `mu` and three times in `sigma`.
    pub l_musigmasigmasigma: f64,
    /// Fourth derivative in `sigma`.
    pub l_sigmasigmasigmasigma: f64,
}
4661
4662#[inline]
4663pub(crate) fn cloglog_point_jet5(t: f64) -> (f64, f64, f64, f64, f64, f64) {
4664 if t.is_nan() {
4665 return (f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN);
4666 }
4667 let et = safe_exp(t);
4668
4669 (
4670 -(-et).exp_m1(),
4671 cloglog_stable_poly_times_exp_neg(et, &[0.0, 1.0]),
4672 cloglog_stable_poly_times_exp_neg(et, &[0.0, 1.0, -1.0]),
4673 cloglog_stable_poly_times_exp_neg(et, &[0.0, 1.0, -3.0, 1.0]),
4674 cloglog_stable_poly_times_exp_neg(et, &[0.0, 1.0, -7.0, 6.0, -1.0]),
4675 cloglog_stable_poly_times_exp_neg(et, &[0.0, 1.0, -15.0, 25.0, -10.0, 1.0]),
4676 )
4677}
4678
4679#[inline]
4691fn cloglog_g_derivatives(t: f64) -> (f64, f64, f64, f64, f64) {
4692 let (g, g1, g2, g3, g4, _) = cloglog_point_jet5(t);
4693 (g, g1, g2, g3, g4)
4694}
4695
4696pub fn cloglog_ghq_value(ctx: &QuadratureContext, mu: f64, sigma: f64, n_nodes: usize) -> f64 {
4704 if sigma.abs() < 1e-14 {
4705 let (g, _, _, _, _) = cloglog_g_derivatives(mu);
4706 return g.clamp(0.0, 1.0);
4707 }
4708 let inv_sqrt_pi = 1.0 / std::f64::consts::PI.sqrt();
4709
4710 let inv_sig2 = 1.0 / (sigma * sigma);
4742 let mut eta_hat = mu;
4743 let mut converged = false;
4744 for _ in 0..100 {
4745 let (g, g1, g2, _, _, _) = cloglog_point_jet5(eta_hat);
4746 if !(g > 0.0) || !g1.is_finite() || !g2.is_finite() {
4747 break;
4748 }
4749 let r = g1 / g;
4750 let lp = r - (eta_hat - mu) * inv_sig2;
4751 let lpp = g2 / g - r * r - inv_sig2;
4752 if !lpp.is_finite() || lpp >= 0.0 {
4753 break;
4754 }
4755 let step = lp / lpp;
4756 eta_hat -= step;
4757 if step.abs() <= 1e-13 * (1.0 + eta_hat.abs()) {
4758 converged = true;
4759 break;
4760 }
4761 }
4762
4763 let tau = if converged {
4767 let (g, g1, g2, _, _, _) = cloglog_point_jet5(eta_hat);
4768 if g > 0.0 {
4769 let r = g1 / g;
4770 let lpp = g2 / g - r * r - inv_sig2;
4771 let tau2 = -1.0 / lpp;
4772 if tau2.is_finite() && tau2 > 0.0 {
4773 Some(tau2.sqrt())
4774 } else {
4775 None
4776 }
4777 } else {
4778 None
4779 }
4780 } else {
4781 None
4782 };
4783
4784 let eval_at = |n: usize| -> f64 {
4786 match tau {
4787 Some(tau) => {
4788 let pref = tau * inv_sqrt_pi / sigma;
4789 with_gh_nodesweights(ctx, n, |nodes, weights| {
4790 let mut sum = 0.0_f64;
4791 for i in 0..nodes.len() {
4792 let t = nodes[i];
4793 let eta_i = eta_hat + SQRT_2 * tau * t;
4794 let (g, _, _, _, _, _) = cloglog_point_jet5(eta_i);
4795 let dev = eta_i - mu;
4796 sum += weights[i] * (t * t - 0.5 * dev * dev * inv_sig2).exp() * g;
4797 }
4798 (pref * sum).clamp(0.0, 1.0)
4799 })
4800 }
4801 None => {
4802 let scale = SQRT_2 * sigma;
4803 with_gh_nodesweights(ctx, n, |nodes, weights| {
4804 let mut sum = 0.0_f64;
4805 for i in 0..nodes.len() {
4806 let t = mu + scale * nodes[i];
4807 let (g, _, _, _, _) = cloglog_g_derivatives(t);
4808 sum += weights[i] * g;
4809 }
4810 (sum * inv_sqrt_pi).clamp(0.0, 1.0)
4811 })
4812 }
4813 }
4814 };
4815
4816 const CLOGLOG_GHQ_ORDER_LADDER: [usize; 5] = [7, 15, 21, 31, 51];
4821 const CLOGLOG_GHQ_CONV_TOL: f64 = 1e-10;
4822 let floor = n_nodes.min(*CLOGLOG_GHQ_ORDER_LADDER.last().unwrap());
4823 let mut prev: Option<f64> = None;
4824 let mut result = 0.0_f64;
4825 for &n in CLOGLOG_GHQ_ORDER_LADDER.iter().filter(|&&n| n >= floor) {
4826 let cur = eval_at(n);
4827 result = cur;
4828 if let Some(p) = prev
4829 && (cur - p).abs() < CLOGLOG_GHQ_CONV_TOL
4830 {
4831 break;
4832 }
4833 prev = Some(cur);
4834 }
4835 result
4836}
4837
4838pub fn cloglog_ghq_derivatives(
4849 ctx: &QuadratureContext,
4850 mu: f64,
4851 sigma: f64,
4852 n_nodes: usize,
4853) -> CLogLogConvolutionDerivatives {
4854 let inv_sqrt_pi = 1.0 / std::f64::consts::PI.sqrt();
4855
4856 if sigma.abs() < 1e-14 {
4863 let (g, g1, g2, g3, g4) = cloglog_g_derivatives(mu);
4864 return CLogLogConvolutionDerivatives {
4865 l: g,
4866 l_mu: g1,
4867 l_sigma: 0.0,
4868 l_mumu: g2,
4869 l_musigma: 0.0,
4870 l_sigmasigma: g2,
4871 l_mumumu: g3,
4872 l_mumusigma: 0.0,
4873 l_musigmasigma: g3,
4874 l_sigmasigmasigma: 0.0,
4875 l_mumumumu: g4,
4876 l_mumumusigma: 0.0,
4877 l_mumusigmasigma: g4,
4878 l_musigmasigmasigma: 0.0,
4879 l_sigmasigmasigmasigma: 3.0 * g4,
4880 };
4881 }
4882
4883 let scale = SQRT_2 * sigma;
4884 let sqrt2 = SQRT_2;
4885
4886 with_gh_nodesweights(ctx, n_nodes, |nodes, weights| {
4887 let mut s = [[0.0_f64; 5]; 5];
4899
4900 for i in 0..nodes.len() {
4901 let x = nodes[i];
4902 let t = mu + scale * x;
4903 let (g0, g1, g2, g3, g4) = cloglog_g_derivatives(t);
4904 let w = weights[i];
4905
4906 let x2 = x * x;
4908 let x3 = x2 * x;
4909 let x4 = x3 * x;
4910
4911 s[0][0] += w * g0;
4913
4914 s[1][0] += w * g1;
4916 s[1][1] += w * x * g1;
4917
4918 s[2][0] += w * g2;
4920 s[2][1] += w * x * g2;
4921 s[2][2] += w * x2 * g2;
4922
4923 s[3][0] += w * g3;
4925 s[3][1] += w * x * g3;
4926 s[3][2] += w * x2 * g3;
4927 s[3][3] += w * x3 * g3;
4928
4929 s[4][0] += w * g4;
4931 s[4][1] += w * x * g4;
4932 s[4][2] += w * x2 * g4;
4933 s[4][3] += w * x3 * g4;
4934 s[4][4] += w * x4 * g4;
4935 }
4936
4937 let sqrt2_1 = sqrt2;
4940 let sqrt2_2 = 2.0; let sqrt2_3 = 2.0 * sqrt2; let sqrt2_4 = 4.0; CLogLogConvolutionDerivatives {
4945 l: inv_sqrt_pi * s[0][0],
4947
4948 l_mu: inv_sqrt_pi * s[1][0],
4950 l_sigma: inv_sqrt_pi * sqrt2_1 * s[1][1],
4951
4952 l_mumu: inv_sqrt_pi * s[2][0],
4954 l_musigma: inv_sqrt_pi * sqrt2_1 * s[2][1],
4955 l_sigmasigma: inv_sqrt_pi * sqrt2_2 * s[2][2],
4956
4957 l_mumumu: inv_sqrt_pi * s[3][0],
4959 l_mumusigma: inv_sqrt_pi * sqrt2_1 * s[3][1],
4960 l_musigmasigma: inv_sqrt_pi * sqrt2_2 * s[3][2],
4961 l_sigmasigmasigma: inv_sqrt_pi * sqrt2_3 * s[3][3],
4962
4963 l_mumumumu: inv_sqrt_pi * s[4][0],
4965 l_mumumusigma: inv_sqrt_pi * sqrt2_1 * s[4][1],
4966 l_mumusigmasigma: inv_sqrt_pi * sqrt2_2 * s[4][2],
4967 l_musigmasigmasigma: inv_sqrt_pi * sqrt2_3 * s[4][3],
4968 l_sigmasigmasigmasigma: inv_sqrt_pi * sqrt2_4 * s[4][4],
4969 }
4970 })
4971}
4972
4973pub fn cloglog_ghq_derivatives_adaptive(
4979 ctx: &QuadratureContext,
4980 mu: f64,
4981 sigma: f64,
4982) -> CLogLogConvolutionDerivatives {
4983 let n = adaptive_point_count_from_sd(sigma.abs());
4984 cloglog_ghq_derivatives(ctx, mu, sigma, n)
4985}
4986
4987#[cfg(test)]
4988mod tests {
4989 use super::*;
4990 use approx::assert_relative_eq;
4991
4992 pub(crate) fn cloglog_posterior_meanwith_deriv_gamma_reference(
4993 mu: f64,
4994 sigma: f64,
4995 ) -> Result<IntegratedMeanDerivative, EstimationError> {
4996 let survival = cloglog_survival_gamma_reference(mu, sigma)?;
4999 let shifted_survival = cloglog_survival_gamma_reference(mu + sigma * sigma, sigma)?;
5000 let mean = cloglog_mean_from_survival(survival);
5001 let dmean = cloglog_shift_identity_derivative(mu, sigma, shifted_survival);
5002 if !(mean.is_finite() && dmean.is_finite()) {
5003 crate::bail_invalid_estim!(
5004 "Gamma cloglog reference backend produced non-finite values"
5005 );
5006 }
5007 Ok(IntegratedMeanDerivative {
5008 mean,
5009 dmean_dmu: dmean.max(0.0),
5010 mode: IntegratedExpectationMode::ExactSpecialFunction,
5011 })
5012 }
5013
/// Returns `(2m - 1)!! * sqrt(pi) / 2^m` with `m = power / 2`, i.e. the value
/// of the integral of `x^power * exp(-x^2)` over the real line.
///
/// # Panics
/// Panics when `power` is odd.
fn even_moment_exp_neg_x2(power: usize) -> f64 {
    assert!(power.is_multiple_of(2));
    let half = power / 2;
    // Product of the first `half` odd integers: 1 * 3 * 5 * ...
    let odd_double_factorial: f64 = (0..half).map(|k| (2 * k + 1) as f64).product();
    odd_double_factorial * std::f64::consts::PI.sqrt() / 2.0_f64.powi(half as i32)
}
5023
/// Standard normal density `exp(-z^2 / 2) / sqrt(2 * pi)`.
fn normal_pdf(z: f64) -> f64 {
    let exponent = -(z * z) * 0.5;
    let normalizer = (2.0 * std::f64::consts::PI).sqrt();
    exponent.exp() / normalizer
}
5027
5028 fn high_res_sigmoid_integral(eta: f64, se: f64) -> f64 {
5029 let a = -12.0_f64;
5031 let b = 12.0_f64;
5032 let n = 20_000usize; let h = (b - a) / n as f64;
5034
5035 let integrand = |z: f64| -> f64 { sigmoid(eta + se * z) * normal_pdf(z) };
5036
5037 let mut sum = integrand(a) + integrand(b);
5038 for i in 1..n {
5039 let x = a + (i as f64) * h;
5040 if i % 2 == 0 {
5041 sum += 2.0 * integrand(x);
5042 } else {
5043 sum += 4.0 * integrand(x);
5044 }
5045 }
5046 sum * h / 3.0
5047 }
5048
/// Gauss–Hermite nodes must be mirror images about zero, with the middle node
/// equal to zero.
#[test]
fn test_computed_nodes_symmetric() {
    let ctx = QuadratureContext::new();
    let rule = ctx.gauss_hermite();
    let middle = N_POINTS / 2;
    for lo in 0..middle {
        let hi = N_POINTS - 1 - lo;
        assert_relative_eq!(rule.nodes[lo], -rule.nodes[hi], epsilon = 1e-12);
    }
    assert_relative_eq!(rule.nodes[middle], 0.0, epsilon = 1e-12);
}
5061
/// Gauss–Hermite weights at mirrored positions must be equal.
#[test]
fn test_computedweights_symmetric() {
    let ctx = QuadratureContext::new();
    let rule = ctx.gauss_hermite();
    for (lo, hi) in (0..N_POINTS / 2).map(|lo| (lo, N_POINTS - 1 - lo)) {
        assert_relative_eq!(rule.weights[lo], rule.weights[hi], epsilon = 1e-12);
    }
}
5072
/// The Gauss–Hermite weights must add up to `sqrt(pi)`.
#[test]
fn testweights_sum_to_sqrt_pi() {
    let ctx = QuadratureContext::new();
    let rule = ctx.gauss_hermite();
    let total = rule.weights.iter().sum::<f64>();
    let expected = std::f64::consts::PI.sqrt();
    assert_relative_eq!(total, expected, epsilon = 1e-10);
}
5081
/// The rule from `compute_clenshaw_curtis_n(33)` must have antisymmetric
/// nodes, symmetric weights, and weights summing to 2.
#[test]
fn test_clenshaw_curtisweights_are_symmetric_and_integrate_constants() {
    let rule = compute_clenshaw_curtis_n(33);
    let last = rule.weights.len() - 1;
    for left in 0..=last / 2 {
        let right = last - left;
        assert_relative_eq!(rule.nodes[left], -rule.nodes[right], epsilon = 1e-14);
        assert_relative_eq!(rule.weights[left], rule.weights[right], epsilon = 1e-14);
    }
    let total = rule.weights.iter().sum::<f64>();
    assert_relative_eq!(total, 2.0, epsilon = 1e-14, max_relative = 1e-14);
}
5093
/// `cloglog_should_prefer_cc` must return true for `(-0.2, 0.8)`.
#[test]
fn test_cc_preference_prefers_moderate_central_case() {
    let prefers_cc = cloglog_should_prefer_cc(-0.2, 0.8, CLOGLOG_CC_TOL);
    assert!(prefers_cc);
}
5098
/// `cloglog_should_prefer_cc` must return true for `(0.0, 2.0)`.
#[test]
fn test_cc_preference_prefers_moderately_large_case() {
    let prefers_cc = cloglog_should_prefer_cc(0.0, 2.0, CLOGLOG_CC_TOL);
    assert!(prefers_cc);
}
5103
/// `cloglog_should_prefer_cc` must return false for `(0.0, 5.0)`.
#[test]
fn test_cc_preference_rejects_broad_case() {
    let prefers_cc = cloglog_should_prefer_cc(0.0, 5.0, CLOGLOG_CC_TOL);
    assert!(!prefers_cc);
}
5108
/// `wilkinson_shift(0.0, 0.0, 1.25)` must be finite and equal to `-1.25`.
#[test]
fn testwilkinson_shift_finitewhen_d_iszero() {
    let computed = wilkinson_shift(0.0, 0.0, 1.25);
    assert!(computed.is_finite());
    assert_relative_eq!(computed, -1.25, epsilon = 1e-14);
}
5117
/// The computed 7-point rule must reproduce the tabulated nodes and weights
/// listed below to within `1e-12`.
#[test]
fn test_matches_abramowitz_stegun_7_point_gauss_hermite_constants() {
    let ctx = QuadratureContext::new();
    let gh = ctx.gauss_hermite();

    let known_nodes = [
        -2.651_961_356_835_233_4,
        -1.673_551_628_767_471_4,
        -0.816_287_882_858_964_7,
        0.0,
        0.816_287_882_858_964_7,
        1.673_551_628_767_471_4,
        2.651_961_356_835_233_4,
    ];
    let knownweights = [
        0.000_971_781_245_099_519_1,
        0.054_515_582_819_127_03,
        0.425_607_252_610_127_8,
        0.810_264_617_556_807_3,
        0.425_607_252_610_127_8,
        0.054_515_582_819_127_03,
        0.000_971_781_245_099_519_1,
    ];

    let mut idx = 0;
    while idx < N_POINTS {
        assert_relative_eq!(gh.nodes[idx], known_nodes[idx], epsilon = 1e-12);
        assert_relative_eq!(gh.weights[idx], knownweights[idx], epsilon = 1e-12);
        idx += 1;
    }
}
5149
/// Weights assembled as `sqrt(pi) * eigenvectors[i][0]^2` must match the
/// tabulated 7-point weights, while the transposed indexing
/// `eigenvectors[0][i]` must miss them by a summed absolute error above 1.
#[test]
fn test_gauss_hermite_weight_assembly_uses_eigenvector_rows() {
    let mut diagonal = [0.0_f64; N_POINTS];
    let mut off_diagonal = [0.0_f64; N_POINTS - 1];
    for (k, entry) in off_diagonal.iter_mut().enumerate() {
        *entry = (((k + 1) as f64) / 2.0).sqrt();
    }
    let (nodes, eigenvectors) = symmetric_tridiagonal_eigen(&mut diagonal, &mut off_diagonal);
    let mu0 = std::f64::consts::PI.sqrt();

    // Pair each node with the weight obtained from either indexing convention.
    let mut row_pairs: Vec<(f64, f64)> = Vec::with_capacity(N_POINTS);
    let mut column_pairs: Vec<(f64, f64)> = Vec::with_capacity(N_POINTS);
    for i in 0..N_POINTS {
        let by_row = eigenvectors[i][0];
        let by_column = eigenvectors[0][i];
        row_pairs.push((nodes[i], mu0 * by_row * by_row));
        column_pairs.push((nodes[i], mu0 * by_column * by_column));
    }
    // Order both lists by ascending node so they line up with the table.
    row_pairs.sort_by(|a, b| a.0.total_cmp(&b.0));
    column_pairs.sort_by(|a, b| a.0.total_cmp(&b.0));

    let knownweights = [
        0.000_971_781_245_099_519_1,
        0.054_515_582_819_127_03,
        0.425_607_252_610_127_8,
        0.810_264_617_556_807_3,
        0.425_607_252_610_127_8,
        0.054_515_582_819_127_03,
        0.000_971_781_245_099_519_1,
    ];

    for i in 0..N_POINTS {
        assert_relative_eq!(row_pairs[i].1, knownweights[i], epsilon = 1e-12);
    }

    let column_error: f64 = column_pairs
        .iter()
        .zip(knownweights.iter())
        .map(|(actual, expected)| (actual.1 - expected).abs())
        .sum();
    assert!(
        column_error > 1.0,
        "column-oriented eigenvector indexing unexpectedly matched A&S weights"
    );
}
5191
/// With `se = 0` the logit posterior mean must equal `sigmoid(eta)`.
#[test]
fn testzero_se_returns_mode() {
    let ctx = QuadratureContext::new();
    let (eta, se) = (1.5, 0.0);
    let posterior_mean = logit_posterior_mean(&ctx, eta, se);
    let plug_in = sigmoid(eta);
    assert_relative_eq!(posterior_mean, plug_in, epsilon = 1e-10);
}
5202
/// At `eta = 0`, `se = 1` the logit posterior mean must be within `0.01` of
/// one half.
#[test]
fn test_symmetric_atzero() {
    let ctx = QuadratureContext::new();
    let posterior_mean = logit_posterior_mean(&ctx, 0.0, 1.0);
    assert_relative_eq!(posterior_mean, 0.5, epsilon = 0.01);
}
5213
/// At `eta = 3`, `se = 1` the posterior mean must fall below `sigmoid(eta)`
/// while staying above `0.8`.
#[test]
fn test_shrinkage_at_extremes() {
    let ctx = QuadratureContext::new();
    let (eta, se) = (3.0, 1.0);
    let posterior_mean = logit_posterior_mean(&ctx, eta, se);
    let plug_in = sigmoid(eta);

    assert!(
        posterior_mean < plug_in,
        "Expected mean {} < mode {}",
        posterior_mean,
        plug_in
    );
    assert!(
        posterior_mean > 0.8,
        "Mean {} should still be high",
        posterior_mean
    );
}
5228
/// Compares the quadrature posterior mean at `eta = 2.0`, `se = 0.8` with the
/// average of `sigmoid` over 100,000 Box–Muller draws driven by a fixed-seed
/// linear congruential generator; they must agree to within `0.01`.
#[test]
fn test_matches_monte_carlo() {
    let eta = 2.0;
    let se = 0.8;

    let ctx = QuadratureContext::new();
    let quad_mean = logit_posterior_mean(&ctx, eta, se);

    // Deterministic uniform stream: advance the LCG, then scale by u64::MAX.
    let mut state = 12345u64;
    let mut next_uniform = || {
        state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        (state as f64) / (u64::MAX as f64)
    };

    let draws = 100_000;
    let mut running_total = 0.0;
    for _ in 0..draws {
        // Floor the first uniform so the logarithm stays finite.
        let u1 = next_uniform().max(1e-10);
        let u2 = next_uniform();
        let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
        running_total += sigmoid(eta + se * z);
    }
    let mc_mean = running_total / (draws as f64);

    assert_relative_eq!(quad_mean, mc_mean, epsilon = 0.01);
}
5257
/// The weighted sum of squared nodes must equal `sqrt(pi) / 2`.
#[test]
fn test_quadrature_integrates_x_squared() {
    let ctx = QuadratureContext::new();
    let gh = ctx.gauss_hermite();
    let second_moment = (0..N_POINTS).fold(0.0, |acc, i| {
        acc + gh.weights[i] * gh.nodes[i] * gh.nodes[i]
    });
    let expected = std::f64::consts::PI.sqrt() / 2.0;
    assert_relative_eq!(second_moment, expected, epsilon = 1e-10);
}
5271
/// The weighted sum of fourth powers of the nodes must equal
/// `3 * sqrt(pi) / 4`.
#[test]
fn test_quadrature_integrates_x_fourth() {
    let ctx = QuadratureContext::new();
    let gh = ctx.gauss_hermite();
    let fourth_moment = (0..N_POINTS).fold(0.0, |acc, i| {
        let x = gh.nodes[i];
        acc + gh.weights[i] * x * x * x * x
    });
    let expected = 3.0 * std::f64::consts::PI.sqrt() / 4.0;
    assert_relative_eq!(fourth_moment, expected, epsilon = 1e-10);
}
5286
/// For every degree from 0 to 13 the rule's weighted sum of `node^degree`
/// must match the expected moment (zero for odd degrees,
/// `even_moment_exp_neg_x2` otherwise) to `1e-10` absolute or relative error.
#[test]
fn test_moment_exactness_up_to_degree_13() {
    let ctx = QuadratureContext::new();
    let gh = ctx.gauss_hermite();

    for degree in 0..=13usize {
        let approx = (0..N_POINTS)
            .map(|i| gh.weights[i] * gh.nodes[i].powi(degree as i32))
            .sum::<f64>();

        let expected = match degree % 2 {
            1 => 0.0,
            _ => even_moment_exp_neg_x2(degree),
        };

        let err = (approx - expected).abs();
        let rel_scale = approx.abs().max(expected.abs()).max(1.0);
        let within_tolerance = err <= 1e-10 || err / rel_scale <= 1e-10;
        assert!(
            within_tolerance,
            "degree={} approx={} expected={} abs_err={}",
            degree,
            approx,
            expected,
            err
        );
    }
}
5315
/// For 20 pseudo-random `(eta, se)` pairs with `eta` in `[-6, 6]` and `se` in
/// `[0.02, 1.52]`, `logit_posterior_mean` must agree with
/// `high_res_sigmoid_integral` to within `2e-3`.
#[test]
fn test_integrated_sigmoid_matches_high_res_integral_random_pairs() {
    let ctx = QuadratureContext::new();

    // Fixed-seed LCG so the sampled pairs are reproducible.
    let mut state = 0x4d595df4d0f33173u64;
    let mut next_uniform = || {
        state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        (state as f64) / (u64::MAX as f64)
    };

    for _ in 0..20 {
        let u_eta = next_uniform();
        let u_se = next_uniform();

        let eta = -6.0 + 12.0 * u_eta;
        let se = 0.02 + 1.5 * u_se;

        let from_quadrature = logit_posterior_mean(&ctx, eta, se);
        let from_simpson = high_res_sigmoid_integral(eta, se);
        assert_relative_eq!(from_quadrature, from_simpson, epsilon = 2e-3);
    }
}
5335
/// At `eta = 20`, `se = 0` the derivative returned by
/// `logit_posterior_meanwith_deriv` must be strictly positive and below
/// `1e-6`.
#[test]
fn test_logit_posterior_derivative_remains_positive_in_positive_tail() {
    let (eta, se) = (20.0, 0.0);
    let (_, dmu) = logit_posterior_meanwith_deriv(eta, se)
        .expect("logit posterior mean derivative should evaluate");
    assert!(dmu > 0.0);
    assert!(
        dmu < 1e-6,
        "positive-tail derivative should stay tiny but nonzero, got {dmu}"
    );
}
5348
/// The analytic `mu`-derivative at `(1.7, 0.9)` must share its sign with, and
/// be close to, a central difference of `logit_posterior_mean` with step
/// `1e-5`.
#[test]
fn test_logit_posterior_derivative_matches_central_difference() {
    let ctx = QuadratureContext::new();
    let (eta, se, h) = (1.7, 0.9, 1e-5);

    let (_, dmu) = logit_posterior_meanwith_deriv(eta, se)
        .expect("logit posterior mean derivative should evaluate");

    let forward = logit_posterior_mean(&ctx, eta + h, se);
    let backward = logit_posterior_mean(&ctx, eta - h, se);
    let dmufd = (forward - backward) / (2.0 * h);

    assert_eq!(dmu.signum(), dmufd.signum());
    assert_relative_eq!(dmu, dmufd, epsilon = 5e-6, max_relative = 2e-4);
}
5365
5366 fn dense_sigmoid_normal_mean(mu: f64, sigma: f64) -> f64 {
5372 let a = -18.0_f64;
5373 let b = 18.0_f64;
5374 let n = 400_000usize; let h = (b - a) / n as f64;
5376 let integrand = |z: f64| -> f64 { sigmoid(mu + sigma * z) * normal_pdf(z) };
5377 let mut sum = integrand(a) + integrand(b);
5378 for i in 1..n {
5379 let z = a + (i as f64) * h;
5380 sum += if i % 2 == 0 { 2.0 } else { 4.0 } * integrand(z);
5381 }
5382 sum * h / 3.0
5383 }
5384
/// `logit_posterior_mean_exact(mu, sigma)` and its value at `-mu` must sum to
/// one within `1e-12` for each listed case.
#[test]
fn test_logit_posterior_mean_exact_symmetry_identity() {
    let cases = [
        (-3.0, 0.5),
        (-1.2, 1.7),
        (0.0, 2.2),
        (2.3, 0.8),
        (3.0, 0.05),
    ];
    for (mu, sigma) in cases {
        let at_mu = logit_posterior_mean_exact(mu, sigma);
        let at_negated_mu = logit_posterior_mean_exact(-mu, sigma);
        let defect = at_mu + at_negated_mu - 1.0;
        assert!(
            defect.abs() < 1e-12,
            "symmetry broken at mu={mu} sigma={sigma}: p+q-1 = {:.3e}",
            defect
        );
    }
}
5406
/// `logit_posterior_mean_exact` must agree with `dense_sigmoid_normal_mean`
/// to within `1e-10` on each listed `(mu, sigma)` pair.
#[test]
fn test_logit_posterior_mean_exact_matches_high_res_integral() {
    let cases = [
        (-2.0, 0.4),
        (-0.7, 1.1),
        (0.8, 0.9),
        (2.4, 1.7),
        (3.0, 0.05),
        (3.0, 0.5),
        (-2.0, 2.0),
        (5.0, 3.0),
    ];
    for (mu, sigma) in cases {
        let exact = logit_posterior_mean_exact(mu, sigma);
        let numeric = dense_sigmoid_normal_mean(mu, sigma);
        let gap = (exact - numeric).abs();
        assert!(
            gap < 1e-10,
            "oracle ≠ dense reference at mu={mu} sigma={sigma}: \
             exact={exact:.13} ref={numeric:.13} err={:.3e}",
            gap
        );
    }
}
5433
/// Regression check for issue #1459: across the grid below, and again at
/// `mu = 3` for three values of sigma, `logit_posterior_mean_exact` must stay
/// within `1e-10` of `dense_sigmoid_normal_mean`.
#[test]
fn test_logit_posterior_mean_exact_no_truncation_bias_1459() {
    let table = [
        (1.0, 0.02),
        (1.0, 0.05),
        (1.0, 0.5),
        (1.0, 2.0),
        (3.0, 0.02),
        (3.0, 0.05),
        (3.0, 0.5),
        (3.0, 2.0),
        (-2.0, 0.02),
        (-2.0, 0.05),
        (-2.0, 0.5),
        (-2.0, 2.0),
    ];
    for (mu, sigma) in table {
        let exact = logit_posterior_mean_exact(mu, sigma);
        let reference = dense_sigmoid_normal_mean(mu, sigma);
        let err = (exact - reference).abs();
        // Quantity reported as the pre-fix bias in the failure message.
        let old_bias = mu.abs() / (2.0 * std::f64::consts::PI.powi(2) * 4096.0);
        assert!(
            err < 1e-10,
            "#1459 truncation bias resurfaced at mu={mu} sigma={sigma}: \
             err={err:.3e} (pre-fix bias here was ~{:.2e})",
            old_bias
        );
    }

    // Signed residuals at fixed mu = 3 across three values of sigma.
    let mu = 3.0;
    for s in [0.05, 0.5, 2.0] {
        let e = logit_posterior_mean_exact(mu, s) - dense_sigmoid_normal_mean(mu, s);
        assert!(
            e.abs() < 1e-10,
            "residual {e:.3e} at mu=3 — old σ-independent plateau was 3.71e-5"
        );
    }
}
5486
/// `faddeeva_upper_halfplane` must reproduce the hard-coded reference values
/// below to within `1e-13`: at the origin, on the imaginary axis, off the
/// axis, and at `3 + 40i`.
#[test]
fn test_faddeeva_weideman_matches_known_values() {
    let eval = |re: f64, im: f64| faddeeva_upper_halfplane(Complex { re, im });

    // Origin: the expected value is exactly 1 + 0i.
    let w0 = eval(0.0, 0.0);
    assert!(
        (w0.re - 1.0).abs() < 1e-13 && w0.im.abs() < 1e-13,
        "w(0)={w0:?}"
    );

    // Purely imaginary arguments: the expected result is real.
    let on_axis = [
        (0.1, 0.8964569799691268),
        (0.5, 0.6156903441929258),
        (1.0, 0.427583576155807),
        (2.0, 0.2553956763105058),
        (5.0, 0.11070463773306861),
        (9.0, 0.06230772403777468),
    ];
    for (y, want) in on_axis {
        let w = eval(0.0, y);
        let real_gap = (w.re - want).abs();
        assert!(
            real_gap < 1e-13 && w.im.abs() < 1e-13,
            "w(i·{y}): got {w:?}, want re={want}, err={:.2e}",
            real_gap
        );
    }

    // General arguments with both real and imaginary parts.
    let off_axis = [
        ((0.7, 1.3), (0.31327301971562715, 0.12443489420104513)),
        ((-1.5, 0.8), (0.21066359024766423, -0.27001624496296617)),
        ((3.0, 0.4), (0.030278754646989155, 0.1957320888774461)),
    ];
    for ((re, im), (wre, wim)) in off_axis {
        let w = eval(re, im);
        assert!(
            (w.re - wre).abs() < 1e-13 && (w.im - wim).abs() < 1e-13,
            "w({re}+{im}i): got {w:?}, want ({wre},{wim})"
        );
    }

    // Large imaginary part.
    let w = eval(3.0, 40.0);
    assert!(
        (w.re - 0.01402158696172506).abs() < 1e-13
            && (w.im - 0.0010509664408184546).abs() < 1e-13,
        "tail value mismatch: w={w:?}"
    );
}
5542
/// `logit_posterior_mean` must stay within `1e-6` of
/// `logit_posterior_mean_exact` on each listed case.
#[test]
fn test_integrated_logit_mean_close_to_exact_oracle() {
    let ctx = QuadratureContext::new();
    for (eta, se) in [(-3.0, 0.3), (-1.0, 0.8), (0.5, 1.2), (2.8, 1.0)] {
        let ghq = logit_posterior_mean(&ctx, eta, se);
        let exact = logit_posterior_mean_exact(eta, se);
        let gap = (ghq - exact).abs();
        assert!(
            gap < 1e-6,
            "production path drifts from oracle at eta={eta} se={se}: \
             ghq={ghq:.12} oracle={exact:.12} gap={:.3e}",
            gap
        );
    }
}
5561
/// With zero standard error, `probit_posterior_mean(eta, 0)` must equal the
/// normal CDF at `eta`.
#[test]
fn test_probit_posterior_mean_reduces_to_map_atzero_se() {
    let eta = 1.25;
    let posterior = probit_posterior_mean(eta, 0.0);
    let plug_in = gam_math::probability::normal_cdf(eta);
    assert_relative_eq!(posterior, plug_in, epsilon = 1e-12);
}
5569
/// Raising the standard error from 0 to 2 must lower the probit posterior
/// mean at `eta = 3` and raise it at `eta = -3`.
#[test]
fn test_probit_posterior_mean_shrinks_extremeswith_uncertainty() {
    let (hi_eta, lo_eta) = (3.0, -3.0);

    let hi_certain = probit_posterior_mean(hi_eta, 0.0);
    let hi_uncertain = probit_posterior_mean(hi_eta, 2.0);
    let lo_certain = probit_posterior_mean(lo_eta, 0.0);
    let lo_uncertain = probit_posterior_mean(lo_eta, 2.0);

    assert!(hi_uncertain < hi_certain);
    assert!(lo_uncertain > lo_certain);
}
5581
/// At `eta = 3`, `se = 1.5` the survival posterior mean must lie in `[0, 1]`
/// and exceed the pointwise value `exp(-exp(eta))`.
#[test]
fn test_survival_posterior_mean_is_bounded_and_shrinks_tail() {
    let ctx = QuadratureContext::new();
    let eta: f64 = 3.0;
    let pointwise = (-(eta.exp())).exp();
    let posterior = survival_posterior_mean(&ctx, eta, 1.5);
    assert!((0.0..=1.0).contains(&posterior));
    assert!(posterior > pointwise);
}
5591
/// For each listed `(eta, se)` the cloglog and survival posterior means must
/// add up to one within `2e-10`.
#[test]
fn test_cloglog_and_survival_posterior_means_are_complements() {
    let ctx = QuadratureContext::new();
    let cases = [
        (-3.0, 0.0),
        (-0.2, 0.1),
        (0.4, 0.8),
        (2.0, 1.5),
        (10.0, 0.3),
        (0.0, 20.0),
        (10.0, 10.0),
        (-0.5, 100.0),
    ];
    for (eta, se) in cases {
        let event_mean = cloglog_posterior_mean(&ctx, eta, se);
        let survival_mean = survival_posterior_mean(&ctx, eta, se);
        let total = event_mean + survival_mean;
        assert_relative_eq!(total, 1.0, epsilon = 2e-10, max_relative = 2e-10);
    }
}
5611
/// At `(-0.2, 0.8)` the integrated cloglog dispatch must report
/// `ExactSpecialFunction`, agree with `cloglog_posterior_mean`, and the
/// cloglog and survival means must sum to one.
#[test]
fn test_cloglog_and_survival_share_large_sigmaspecial_function_path() {
    let ctx = QuadratureContext::new();
    let (eta, se) = (-0.2, 0.8);

    let event_mean = cloglog_posterior_mean(&ctx, eta, se);
    let survival_mean = survival_posterior_mean(&ctx, eta, se);
    let integrated =
        integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::CLogLog, eta, se)
            .expect("cloglog integrated inverse-link moments should evaluate");

    assert_eq!(
        integrated.mode,
        IntegratedExpectationMode::ExactSpecialFunction
    );
    assert_relative_eq!(
        event_mean,
        integrated.mean,
        epsilon = 1e-12,
        max_relative = 1e-12
    );
    assert_relative_eq!(
        event_mean + survival_mean,
        1.0,
        epsilon = 1e-10,
        max_relative = 1e-10
    );
}
5629
/// The variance returned by `cloglog_posterior_meanvariance` must equal the
/// one from `survival_posterior_meanvariance` on each listed case.
#[test]
fn test_cloglog_and_survival_posteriorvariances_match() {
    let ctx = QuadratureContext::new();
    for (eta, se) in [(-3.0, 0.0), (-0.2, 0.1), (0.4, 0.8), (2.0, 1.5)] {
        let (_, event_variance) = cloglog_posterior_meanvariance(&ctx, eta, se);
        let (_, survival_variance) = survival_posterior_meanvariance(&ctx, eta, se);
        assert_relative_eq!(
            event_variance,
            survival_variance,
            epsilon = 1e-12,
            max_relative = 1e-12
        );
    }
}
5640
/// The survival posterior variance at `(-0.2, 0.8)` must equal the
/// second-moment term minus the squared first-moment term, floored at zero.
#[test]
fn test_survivalvariance_uses_exactsecond_moment_shift() {
    let ctx = QuadratureContext::new();
    let (eta, se) = (-0.2, 0.8);

    let (first_moment, _) = cloglog_survival_term_controlled(&ctx, eta, se);
    let (second_moment, _) = cloglog_survivalsecond_moment_controlled(&ctx, eta, se);
    let (_, variance) = survival_posterior_meanvariance(&ctx, eta, se);

    let expected = (second_moment - first_moment * first_moment).max(0.0);
    assert_relative_eq!(variance, expected, epsilon = 1e-12, max_relative = 1e-12);
}
5656
/// `lognormal_laplace_term_controlled(z, mu, sigma)` must return the same
/// pair as `cloglog_survival_term_controlled` evaluated at `mu + ln(z)`.
#[test]
fn test_lognormal_laplace_shift_matches_explicitmu_plus_logz() {
    let ctx = QuadratureContext::new();
    let mu = -0.2;
    let sigma = 0.8;
    let z = 2.0;

    let via_shift = lognormal_laplace_term_controlled(&ctx, z, mu, sigma);
    let via_explicit = cloglog_survival_term_controlled(&ctx, mu + z.ln(), sigma);

    assert_eq!(via_shift.1, via_explicit.1);
    assert_relative_eq!(
        via_shift.0,
        via_explicit.0,
        epsilon = 1e-12,
        max_relative = 1e-12
    );
}
5668
/// The probit dispatch at `(0.7, 1.3)` must report `ExactClosedForm` and
/// match `probit_posterior_meanwith_deriv_exact` in mean and derivative.
#[test]
fn test_integrated_dispatch_uses_closed_form_probit() {
    let ctx = QuadratureContext::new();
    let dispatched =
        integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::Probit, 0.7, 1.3)
            .expect("probit integrated inverse-link moments should evaluate");
    assert_eq!(dispatched.mode, IntegratedExpectationMode::ExactClosedForm);

    let closed_form = probit_posterior_meanwith_deriv_exact(0.7, 1.3);
    assert_relative_eq!(dispatched.mean, closed_form.mean, epsilon = 1e-12);
    assert_relative_eq!(dispatched.dmean_dmu, closed_form.dmean_dmu, epsilon = 1e-12);
}
5679
/// The probit jet at `(0.7, 1.3)` must match the closed forms written in
/// terms of `s = sqrt(1 + sigma^2)` and `z = mu / s`: `Phi(z)`, `phi(z) / s`,
/// `-z * phi(z) / s^2` and `(z^2 - 1) * phi(z) / s^3`.
#[test]
fn test_integrated_probit_jet_matches_closed_form_derivatives() {
    let ctx = QuadratureContext::new();
    let mu = 0.7;
    let sigma = 1.3;
    let jet = integrated_inverse_link_jet(&ctx, LinkFunction::Probit, mu, sigma)
        .expect("probit integrated inverse-link jet should evaluate");

    let s = (1.0 + sigma * sigma).sqrt();
    let z = mu / s;
    let pdf = gam_math::probability::normal_pdf(z);
    let cdf = gam_math::probability::normal_cdf(z);

    assert_relative_eq!(jet.mean, cdf, epsilon = 1e-12);
    assert_relative_eq!(jet.d1, pdf / s, epsilon = 1e-12);
    assert_relative_eq!(jet.d2, -z * pdf / (s * s), epsilon = 1e-12);
    assert_relative_eq!(jet.d3, (z * z - 1.0) * pdf / (s * s * s), epsilon = 1e-12);
}
5699
/// The logit jet at `(1.1, 0.8)` must report either `ExactSpecialFunction` or
/// `QuadratureFallback`, and match `logit_reference_jet_highres_simpson` in
/// the mean and the first three derivatives.
#[test]
fn test_integrated_logit_jet_matches_central_differences() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (1.1, 0.8);
    let jet = integrated_inverse_link_jet(&ctx, LinkFunction::Logit, mu, sigma)
        .expect("logit integrated inverse-link jet should evaluate");

    let mode_is_allowed = matches!(
        jet.mode,
        IntegratedExpectationMode::ExactSpecialFunction
            | IntegratedExpectationMode::QuadratureFallback
    );
    assert!(mode_is_allowed);

    let (ref_mean, ref_d1, ref_d2, ref_d3) = logit_reference_jet_highres_simpson(mu, sigma);
    assert_relative_eq!(jet.mean, ref_mean, epsilon = 1e-11, max_relative = 1e-10);
    assert_relative_eq!(jet.d1, ref_d1, epsilon = 1e-11, max_relative = 1e-10);
    assert_relative_eq!(jet.d2, ref_d2, epsilon = 1e-11, max_relative = 1e-10);
    assert_relative_eq!(jet.d3, ref_d3, epsilon = 1e-11, max_relative = 1e-10);
}
5730
/// `integrated_logit_inverse_link_jet_pirls` must report the same mode as the
/// general logit dispatch at `(1.1, 0.8)` and agree with it in the mean and
/// the first three derivatives.
#[test]
fn test_integrated_logit_pirls_jet_matches_general_dispatch() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (1.1, 0.8);

    let pirls =
        integrated_logit_inverse_link_jet_pirls(&ctx, mu, sigma).expect("PIRLS logit jet");
    let general = integrated_inverse_link_jet(&ctx, LinkFunction::Logit, mu, sigma)
        .expect("general logit jet");

    let mode_is_allowed = matches!(
        pirls.mode,
        IntegratedExpectationMode::ExactSpecialFunction
            | IntegratedExpectationMode::QuadratureFallback
    );
    assert!(mode_is_allowed);
    assert_eq!(pirls.mode, general.mode);

    // Tolerances loosen with derivative order.
    assert_relative_eq!(pirls.mean, general.mean, epsilon = 1e-12);
    assert_relative_eq!(pirls.d1, general.d1, epsilon = 1e-12);
    assert_relative_eq!(pirls.d2, general.d2, epsilon = 1e-10);
    assert_relative_eq!(pirls.d3, general.d3, epsilon = 1e-8);
}
5762
/// Each of the first three derivatives of the cloglog jet at `(0.4, 0.6)`
/// must match, in sign and value, a central difference of the next-lower
/// entry of the jet with step `1e-4`.
#[test]
fn test_integrated_cloglog_jet_matches_central_differences() {
    let ctx = QuadratureContext::new();
    let mu = 0.4;
    let sigma = 0.6;
    let h = 1e-4;

    let jet_at = |location| {
        integrated_inverse_link_jet(&ctx, LinkFunction::CLogLog, location, sigma)
            .expect("cloglog integrated inverse-link jet should evaluate")
    };
    let out = jet_at(mu);
    let plus = jet_at(mu + h);
    let minus = jet_at(mu - h);

    let d1fd = (plus.mean - minus.mean) / (2.0 * h);
    let d2fd = (plus.d1 - minus.d1) / (2.0 * h);
    let d3fd = (plus.d2 - minus.d2) / (2.0 * h);

    assert_eq!(out.d1.signum(), d1fd.signum());
    assert_eq!(out.d2.signum(), d2fd.signum());
    assert_eq!(out.d3.signum(), d3fd.signum());
    assert_relative_eq!(out.d1, d1fd, epsilon = 2e-5, max_relative = 3e-4);
    assert_relative_eq!(out.d2, d2fd, epsilon = 4e-5, max_relative = 8e-4);
    assert_relative_eq!(out.d3, d3fd, epsilon = 8e-5, max_relative = 2e-3);
}
5785
/// For each wide-sigma case the cloglog jet must report `QuadratureFallback`,
/// match `cloglog_reference_jet_highres_simpson` through the third
/// derivative, and have a third derivative consistent with a central
/// difference of its second derivative.
#[test]
fn test_integrated_cloglog_wide_sigma_d3_matches_simpson_and_d2_slope() {
    let ctx = QuadratureContext::new();
    let h = 1e-4;

    for (mu, sigma) in [(0.0, 4.0), (-1.0, 4.0), (2.0, 3.0), (3.0, 3.0)] {
        let jet_at = |location| {
            integrated_inverse_link_jet(&ctx, LinkFunction::CLogLog, location, sigma)
                .expect("wide-sigma cloglog integrated jet should evaluate")
        };

        let out = jet_at(mu);
        let reference = cloglog_reference_jet_highres_simpson(mu, sigma);
        let plus = jet_at(mu + h);
        let minus = jet_at(mu - h);
        let d3fd = (plus.d2 - minus.d2) / (2.0 * h);

        assert_eq!(out.mode, IntegratedExpectationMode::QuadratureFallback);
        assert_relative_eq!(out.mean, reference.0, epsilon = 4e-8, max_relative = 4e-8);
        assert_relative_eq!(out.d1, reference.1, epsilon = 4e-8, max_relative = 4e-8);
        assert_relative_eq!(out.d2, reference.2, epsilon = 2e-9, max_relative = 2e-7);
        assert_relative_eq!(out.d3, reference.3, epsilon = 2e-9, max_relative = 2e-7);
        assert_relative_eq!(out.d3, d3fd, epsilon = 2e-7, max_relative = 4e-5);
    }
}
5810
/// The fourth and fifth derivatives from
/// `latent_cloglog_inverse_link_jet5_controlled` at `(0.35, 0.7)` must match,
/// in sign and value, central differences of the third and fourth derivatives
/// with step `2e-4`.
#[test]
fn test_latent_cloglog_jet5_matches_higher_order_central_differences() {
    let ctx = QuadratureContext::new();
    let (mu, sigma, h) = (0.35, 0.7, 2e-4);

    let centre = latent_cloglog_inverse_link_jet5_controlled(&ctx, mu, sigma);
    let forward = latent_cloglog_inverse_link_jet5_controlled(&ctx, mu + h, sigma);
    let backward = latent_cloglog_inverse_link_jet5_controlled(&ctx, mu - h, sigma);

    let d4fd = (forward.d3 - backward.d3) / (2.0 * h);
    let d5fd = (forward.d4 - backward.d4) / (2.0 * h);

    assert_eq!(centre.d4.signum(), d4fd.signum());
    assert_eq!(centre.d5.signum(), d5fd.signum());
    assert_relative_eq!(centre.d4, d4fd, epsilon = 2e-4, max_relative = 5e-3);
    assert_relative_eq!(centre.d5, d5fd, epsilon = 6e-4, max_relative = 2e-2);
}
5830
/// `logit_posterior_meanwith_deriv_controlled(1.1, 0.8)` must match the mean
/// and first derivative of `logit_reference_jet_highres_simpson`, with a
/// strictly positive derivative.
#[test]
fn test_logit_exact_derivative_matches_finite_difference() {
    let controlled =
        logit_posterior_meanwith_deriv_controlled(1.1, 0.8).expect("controlled logit");
    let (ref_mean, ref_d1, _, _) = logit_reference_jet_highres_simpson(1.1, 0.8);

    assert_relative_eq!(
        controlled.mean,
        ref_mean,
        epsilon = 1e-11,
        max_relative = 1e-10
    );
    assert!(controlled.dmean_dmu > 0.0);
    assert_relative_eq!(
        controlled.dmean_dmu,
        ref_d1,
        epsilon = 1e-11,
        max_relative = 1e-10
    );
}
5848
/// At `mu = -710`, `sigma = 0` both the reported derivative and a central
/// difference of the mean (step `1e-6`) must be exactly zero.
#[test]
fn test_logit_exact_clamped_degenerate_branch_is_locally_flat() {
    let centre = logit_posterior_meanwith_deriv_exact(-710.0, 0.0).expect("exact logit");
    let h = 1e-6;

    let mean_above = logit_posterior_meanwith_deriv_exact(-710.0 + h, 0.0)
        .expect("exact logit plus")
        .mean;
    let mean_below = logit_posterior_meanwith_deriv_exact(-710.0 - h, 0.0)
        .expect("exact logit minus")
        .mean;
    let slope = (mean_above - mean_below) / (2.0 * h);

    assert_eq!(slope, 0.0);
    assert_eq!(centre.dmean_dmu, 0.0);
}
5863
/// Composite Simpson approximation of the integral of `f` over `[a, b]` using
/// `n_intervals` equal sub-intervals.
///
/// # Panics
/// Panics when `n_intervals` is odd.
fn simpson_integrate<F>(a: f64, b: f64, n_intervals: usize, f: F) -> f64
where
    F: Fn(f64) -> f64,
{
    assert_eq!(n_intervals % 2, 0, "Simpson integration requires an even n");
    let step = (b - a) / n_intervals as f64;
    // Endpoints carry weight 1; interior points alternate 4, 2, 4, ...
    let endpoints = f(a) + f(b);
    let total = (1..n_intervals).fold(endpoints, |acc, i| {
        let x = a + i as f64 * step;
        let weight = if i % 2 == 0 { 2.0 } else { 4.0 };
        acc + weight * f(x)
    });
    total * step / 3.0
}
5878
5879 fn cloglog_reference_mean_and_derivative(mu: f64, sigma: f64) -> (f64, f64) {
5880 if sigma <= CLOGLOG_SIGMA_DEGENERATE {
5881 return (cloglog_mean_exact(mu), cloglog_mean_d1_exact(mu));
5882 }
5883
5884 let z_max = 12.0;
5888 let n_intervals = 4096;
5889 let inv_sqrt_2pi = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
5890 let mean = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5891 let eta = mu + sigma * z;
5892 inv_sqrt_2pi * (-0.5 * z * z).exp() * cloglog_mean_exact(eta)
5893 });
5894 let deriv = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5895 let eta = mu + sigma * z;
5896 inv_sqrt_2pi * (-0.5 * z * z).exp() * cloglog_mean_d1_exact(eta)
5897 });
5898 (mean, deriv)
5899 }
5900
5901 fn logit_reference_jet_highres_simpson(mu: f64, sigma: f64) -> (f64, f64, f64, f64) {
5914 let z_max = 14.0;
5915 let n_intervals = 16384;
5916 let inv_sqrt_2pi = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
5917 let phi = |z: f64| inv_sqrt_2pi * (-0.5 * z * z).exp();
5918 let mean = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5919 let eta = mu + sigma * z;
5920 let (p, _, _, _) = component_point_jet(LinkComponent::Logit, eta);
5921 phi(z) * p
5922 });
5923 let d1 = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5924 let eta = mu + sigma * z;
5925 let (_, p1, _, _) = component_point_jet(LinkComponent::Logit, eta);
5926 phi(z) * p1
5927 });
5928 let d2 = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5929 let eta = mu + sigma * z;
5930 let (_, _, p2, _) = component_point_jet(LinkComponent::Logit, eta);
5931 phi(z) * p2
5932 });
5933 let d3 = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5934 let eta = mu + sigma * z;
5935 let (_, _, _, p3) = component_point_jet(LinkComponent::Logit, eta);
5936 phi(z) * p3
5937 });
5938 (mean, d1, d2, d3)
5939 }
5940
5941 fn cloglog_reference_jet_highres_simpson(mu: f64, sigma: f64) -> (f64, f64, f64, f64) {
5942 let z_max = 14.0;
5943 let n_intervals = 16384;
5944 let inv_sqrt_2pi = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
5945 let phi = |z: f64| inv_sqrt_2pi * (-0.5 * z * z).exp();
5946 let mean = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5947 let eta = mu + sigma * z;
5948 let (g, _, _, _, _, _) = cloglog_point_jet5(eta);
5949 phi(z) * g
5950 });
5951 let d1 = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5952 let eta = mu + sigma * z;
5953 let (_, g1, _, _, _, _) = cloglog_point_jet5(eta);
5954 phi(z) * g1
5955 });
5956 let d2 = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5957 let eta = mu + sigma * z;
5958 let (_, _, g2, _, _, _) = cloglog_point_jet5(eta);
5959 phi(z) * g2
5960 });
5961 let d3 = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5962 let eta = mu + sigma * z;
5963 let (_, _, _, g3, _, _) = cloglog_point_jet5(eta);
5964 phi(z) * g3
5965 });
5966 (mean, d1, d2, d3)
5967 }
5968
/// At mu = -40, sigma = 0.1, `cloglog_small_sigma_taylor` must return a
/// strictly positive derivative and agree with
/// `cloglog_reference_mean_and_derivative` to a relative 1e-12.
#[test]
fn test_cloglog_taylor_negative_tail_matches_mathematical_target() {
    let (mu, sigma) = (-40.0, 0.1);
    let taylor = cloglog_small_sigma_taylor(mu, sigma);
    let reference = cloglog_reference_mean_and_derivative(mu, sigma);

    assert!(
        taylor.dmean_dmu > 0.0,
        "negative-tail derivative should remain positive"
    );
    assert_relative_eq!(
        taylor.mean,
        reference.0,
        epsilon = 1e-30,
        max_relative = 1e-12
    );
    assert_relative_eq!(
        taylor.dmean_dmu,
        reference.1,
        epsilon = 1e-30,
        max_relative = 1e-12
    );
}
5993
/// With sigma = 0 at mu = -40, the controlled cloglog path must reproduce
/// `cloglog_mean_exact` / `cloglog_mean_d1_exact` to a relative 1e-15 and
/// keep the derivative strictly positive.
#[test]
fn test_cloglog_degenerate_negative_tail_matches_pointwise_target() {
    let mu = -40.0;
    let ctx = QuadratureContext::new();
    let degenerate = cloglog_posterior_meanwith_deriv_controlled(&ctx, mu, 0.0);

    assert!(
        degenerate.dmean_dmu > 0.0,
        "degenerate negative-tail derivative should remain positive"
    );

    let pointwise_mean = cloglog_mean_exact(mu);
    assert_relative_eq!(
        degenerate.mean,
        pointwise_mean,
        epsilon = 1e-30,
        max_relative = 1e-15
    );

    let pointwise_d1 = cloglog_mean_d1_exact(mu);
    assert_relative_eq!(
        degenerate.dmean_dmu,
        pointwise_d1,
        epsilon = 1e-30,
        max_relative = 1e-15
    );
}
6017
/// The probit jet at (mu = -40, sigma = 0) and the logit component jet at
/// eta = -710 are both asserted to have exactly zero first, second and third
/// derivatives.
#[test]
fn test_degenerate_probit_and_logit_jets_are_flat_on_active_clamps() {
    let probit_jet = integrated_probit_jet(-40.0, 0.0);
    for derivative in [probit_jet.d1, probit_jet.d2, probit_jet.d3] {
        assert_eq!(derivative, 0.0);
    }

    let (_, logit_d1, logit_d2, logit_d3) = component_point_jet(LinkComponent::Logit, -710.0);
    for derivative in [logit_d1, logit_d2, logit_d3] {
        assert_eq!(derivative, 0.0);
    }
}
6030
/// Compares the cloglog component jet at eta = -40 with the closed forms built
/// from t = exp(eta) and s = exp(-t):
/// mean = -expm1(-t), d1 = t*s, d2 = (t - t^2)*s, d3 = (t - 3t^2 + t^3)*s.
#[test]
fn test_degenerate_cloglog_component_jet_preserves_smooth_negative_tail() {
    let eta: f64 = -40.0;
    let rate = eta.exp();
    let survival = (-rate).exp();
    let (mean, d1, d2, d3) = component_point_jet(LinkComponent::CLogLog, eta);

    let want_mean = -(-rate).exp_m1();
    let want_d1 = rate * survival;
    let want_d2 = (rate - rate * rate) * survival;
    let want_d3 = (rate - 3.0 * rate * rate + rate * rate * rate) * survival;

    assert!(d1 > 0.0, "negative-tail d1 should remain positive");
    assert_relative_eq!(mean, want_mean, epsilon = 1e-30, max_relative = 1e-15);
    assert_relative_eq!(d1, want_d1, epsilon = 1e-30, max_relative = 1e-15);
    assert_relative_eq!(d2, want_d2, epsilon = 1e-30, max_relative = 1e-15);
    assert_relative_eq!(d3, want_d3, epsilon = 1e-30, max_relative = 1e-15);
}
6068
/// For sigma = 0, `integrated_inverse_link_jet` must report `ExactClosedForm`
/// and return exactly the values of `component_inverse_link_jet` for logit
/// at eta = 50 and cloglog at eta = -50.
#[test]
fn test_zero_sigma_logit_and_cloglog_share_component_tail_jets() {
    let ctx = QuadratureContext::new();
    let tail_cases = [
        (LinkFunction::Logit, LinkComponent::Logit, 50.0),
        (LinkFunction::CLogLog, LinkComponent::CLogLog, -50.0),
    ];
    for (link, component, eta) in tail_cases {
        let jet = integrated_inverse_link_jet(&ctx, link, eta, 0.0)
            .expect("degenerate integrated jet");
        let pointwise = component_inverse_link_jet(component, eta);

        assert_eq!(jet.mode, IntegratedExpectationMode::ExactClosedForm);
        assert_eq!(jet.mean, pointwise.mu);
        assert_eq!(jet.d1, pointwise.d1);
        assert_eq!(jet.d2, pointwise.d2);
        assert_eq!(jet.d3, pointwise.d3);
    }
}
6086
/// Sweeps a grid of (mu, sigma) pairs with sigma <= 0.24 and checks the
/// controlled cloglog path against `cloglog_reference_mean_and_derivative`
/// (relative 2e-3 on the mean, 4e-3 on the derivative).
#[test]
fn test_cloglog_controlled_matches_mathematical_target_on_small_sigma_grid() {
    const GRID: [(f64, f64); 9] = [
        (-30.0, 1e-10),
        (-30.0, 0.1),
        (-10.0, 0.24),
        (-3.0, 0.2),
        (0.0, 0.05),
        (0.4, 0.1),
        (3.0, 0.24),
        (10.0, 0.1),
        (30.0, 0.24),
    ];
    let ctx = QuadratureContext::new();

    for (mu, sigma) in GRID {
        let controlled = cloglog_posterior_meanwith_deriv_controlled(&ctx, mu, sigma);
        let (target_mean, target_d1) = cloglog_reference_mean_and_derivative(mu, sigma);
        assert_relative_eq!(
            controlled.mean,
            target_mean,
            epsilon = 1e-12,
            max_relative = 2e-3
        );
        assert_relative_eq!(
            controlled.dmean_dmu,
            target_d1,
            epsilon = 1e-12,
            max_relative = 4e-3
        );
    }
}
6122
/// The cloglog dispatcher at (mu = -0.2, sigma = 0.8) must report
/// `ExactSpecialFunction` and return a finite mean with a finite,
/// non-negative derivative.
#[test]
fn test_cloglog_dispatch_uses_gamma_backend_for_large_sigma_central_regime() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (-0.2, 0.8);
    let moments =
        integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::CLogLog, mu, sigma)
            .expect("cloglog integrated inverse-link moments should evaluate");

    assert_eq!(moments.mode, IntegratedExpectationMode::ExactSpecialFunction);
    assert!(moments.mean.is_finite());
    assert!(moments.dmean_dmu.is_finite());
    assert!(moments.dmean_dmu >= 0.0);
}
6134
/// The cloglog dispatcher at (mu = 0, sigma = 20) must report
/// `ControlledAsymptotic` and return a finite mean with a finite,
/// non-negative derivative.
#[test]
fn test_cloglog_dispatch_uses_large_sigma_asymptotic_without_ghq() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (0.0, 20.0);
    let moments =
        integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::CLogLog, mu, sigma)
            .expect("cloglog integrated inverse-link moments should evaluate");

    assert_eq!(moments.mode, IntegratedExpectationMode::ControlledAsymptotic);
    assert!(moments.mean.is_finite());
    assert!(moments.dmean_dmu.is_finite());
    assert!(moments.dmean_dmu >= 0.0);
}
6146
/// At (mu = -0.2, sigma = 0.8), `cloglog_survival_cc` (run with
/// `CLOGLOG_CC_TOL`) and `cloglog_survival_gamma_reference` must agree to 5e-6.
#[test]
fn test_cloglog_cc_matches_gamma_reference_on_central_case() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (-0.2, 0.8);
    let via_cc = cloglog_survival_cc(&ctx, mu, sigma, CLOGLOG_CC_TOL).expect("cc backend");
    let via_gamma = cloglog_survival_gamma_reference(mu, sigma).expect("gamma backend");
    assert_relative_eq!(via_cc, via_gamma, epsilon = 5e-6, max_relative = 5e-6);
}
6156
/// Cross-checks `cloglog_posterior_meanwith_deriv_gamma_reference` at
/// (mu = -0.2, sigma = 0.8) against a deterministic Monte Carlo estimate.
///
/// Uniforms come from a fixed-seed 64-bit LCG; each sample consumes two of
/// them and uses the cosine branch of Box-Muller to form a normal deviate.
/// Agreement is required to 2e-3.
#[test]
fn test_cloglog_gamma_reference_matches_seeded_monte_carlo_small_case() {
    let (mu, sigma) = (-0.2, 0.8);
    let gamma =
        cloglog_posterior_meanwith_deriv_gamma_reference(mu, sigma).expect("gamma reference");

    let mut rng_state = 0x9e3779b97f4a7c15u64;
    // Advance the LCG one step and map the state into (1e-12, 1 - 1e-12).
    let mut next_uniform = || {
        rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
        ((rng_state as f64) / (u64::MAX as f64)).clamp(1e-12, 1.0 - 1e-12)
    };

    let n_samples = 300_000usize;
    let mut mean_sum = 0.0f64;
    let mut deriv_sum = 0.0f64;
    for _ in 0..n_samples {
        let u1 = next_uniform();
        let u2 = next_uniform();
        let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
        let eta = mu + sigma * z;
        mean_sum += cloglog_mean_exact(eta);
        deriv_sum += cloglog_mean_d1_exact(eta);
    }
    let mean_mc = mean_sum / n_samples as f64;
    let deriv_mc = deriv_sum / n_samples as f64;

    assert_relative_eq!(gamma.mean, mean_mc, epsilon = 2e-3, max_relative = 2e-3);
    assert_relative_eq!(
        gamma.dmean_dmu,
        deriv_mc,
        epsilon = 2e-3,
        max_relative = 2e-3
    );
}
6187
/// The logit dispatcher at (mu = 35, sigma = 1) must report
/// `ControlledAsymptotic` and return a finite mean with a finite,
/// non-negative derivative.
#[test]
fn test_logit_dispatch_uses_tail_asymptotic_outside_old_guard() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (35.0, 1.0);
    let moments = integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::Logit, mu, sigma)
        .expect("logit integrated inverse-link moments should evaluate");

    assert_eq!(moments.mode, IntegratedExpectationMode::ControlledAsymptotic);
    assert!(moments.mean.is_finite());
    assert!(moments.dmean_dmu.is_finite());
    assert!(moments.dmean_dmu >= 0.0);
}
6198
/// At (mu = 1.1, sigma = 0.8) the logit dispatcher may answer with either
/// `ExactSpecialFunction` or `QuadratureFallback`, but its mean and derivative
/// must match the high-resolution Simpson reference to a relative 1e-10.
#[test]
fn test_logit_dispatch_prefers_erfcx_in_moderate_regime() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (1.1, 0.8);
    let moments = integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::Logit, mu, sigma)
        .expect("logit integrated inverse-link moments should evaluate");

    assert!(matches!(
        moments.mode,
        IntegratedExpectationMode::ExactSpecialFunction
            | IntegratedExpectationMode::QuadratureFallback
    ));
    assert!(moments.mean.is_finite());
    assert!(moments.dmean_dmu.is_finite());
    assert!(moments.dmean_dmu >= 0.0);

    let (want_mean, want_d1, _, _) = logit_reference_jet_highres_simpson(mu, sigma);
    assert_relative_eq!(moments.mean, want_mean, epsilon = 1e-11, max_relative = 1e-10);
    assert_relative_eq!(
        moments.dmean_dmu,
        want_d1,
        epsilon = 1e-11,
        max_relative = 1e-10
    );
}
6226
/// At (mu = 0.5, sigma = 20) the logit dispatcher must report
/// `QuadratureFallback`, match the Simpson reference to a relative 1e-7, and
/// differ by more than 1e-3 from `normal_cdf(mu * kappa)` with
/// `kappa = 1 / sqrt(1 + pi * sigma^2 / 8)`.
#[test]
fn test_logit_dispatch_large_sigma_uses_accurate_quadrature_not_monahan() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (0.5_f64, 20.0_f64);
    let moments = integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::Logit, mu, sigma)
        .expect("logit integrated inverse-link moments should evaluate");
    assert_eq!(moments.mode, IntegratedExpectationMode::QuadratureFallback);

    let (want_mean, want_d1, _, _) = logit_reference_jet_highres_simpson(mu, sigma);
    assert_relative_eq!(moments.mean, want_mean, epsilon = 1e-9, max_relative = 1e-7);
    assert_relative_eq!(
        moments.dmean_dmu,
        want_d1,
        epsilon = 1e-9,
        max_relative = 1e-7
    );

    let kappa = (1.0 + std::f64::consts::PI * sigma * sigma / 8.0)
        .sqrt()
        .recip();
    let monahan_mean = gam_math::probability::normal_cdf(mu * kappa);
    assert!(
        (moments.mean - monahan_mean).abs() > 1e-3,
        "dispatcher must not return the inaccurate Monahan mean {monahan_mean}; got {}",
        moments.mean
    );
}
6256
/// `logit_posterior_meanwith_deriv_controlled(1.1, 0.8)` may report
/// `ExactSpecialFunction` or `QuadratureFallback`; either way it must match
/// the high-resolution Simpson reference to a relative 1e-10.
#[test]
fn test_logit_controlled_path_keeps_exact_backend_in_moderate_regime() {
    let (mu, sigma) = (1.1, 0.8);
    let controlled =
        logit_posterior_meanwith_deriv_controlled(mu, sigma).expect("logit controlled");

    assert!(matches!(
        controlled.mode,
        IntegratedExpectationMode::ExactSpecialFunction
            | IntegratedExpectationMode::QuadratureFallback
    ));

    let (want_mean, want_d1, _, _) = logit_reference_jet_highres_simpson(mu, sigma);
    assert_relative_eq!(
        controlled.mean,
        want_mean,
        epsilon = 1e-11,
        max_relative = 1e-10
    );
    assert_relative_eq!(
        controlled.dmean_dmu,
        want_d1,
        epsilon = 1e-11,
        max_relative = 1e-10
    );
}
6278
/// At mu = 0 with sigma in {0.3, 0.4, 0.5} the integrated logit mean must be
/// 0.5, the derivative must not exceed 0.25 (plus 1e-9 slack), and the
/// derivative must match the Simpson reference to a relative 1e-6.
#[test]
fn test_logit_dispatch_derivative_correct_at_mu_zero_small_sigma() {
    let ctx = QuadratureContext::new();
    let mu = 0.0;
    for sigma in [0.3, 0.4, 0.5] {
        let moments =
            integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::Logit, mu, sigma)
                .expect("logit integrated inverse-link moments should evaluate");
        assert_relative_eq!(moments.mean, 0.5, epsilon = 1e-10);
        assert!(
            moments.dmean_dmu <= 0.25 + 1e-9,
            "E[sigmoid'] must not exceed 0.25 at (μ={mu}, σ={sigma}); got {}",
            moments.dmean_dmu
        );
        let (_, want_d1, _, _) = logit_reference_jet_highres_simpson(mu, sigma);
        assert_relative_eq!(
            moments.dmean_dmu,
            want_d1,
            epsilon = 1e-9,
            max_relative = 1e-6
        );
    }
}
6306
/// `logit_posterior_meanwith_deriv_exact` must succeed with
/// `ExactSpecialFunction` and match the Simpson reference at
/// (8, 1), (10, 1) and (15, 2), and must return an error at (0, 0.3).
#[test]
fn test_logit_erfcx_exact_branch_is_self_certified() {
    let certified_points = [(8.0, 1.0), (10.0, 1.0), (15.0, 2.0)];
    for (mu, sigma) in certified_points {
        let exact = logit_posterior_meanwith_deriv_exact(mu, sigma)
            .expect("erfcx branch should certify");
        assert_eq!(exact.mode, IntegratedExpectationMode::ExactSpecialFunction);

        let (want_mean, want_d1, _, _) = logit_reference_jet_highres_simpson(mu, sigma);
        assert_relative_eq!(exact.mean, want_mean, epsilon = 1e-9, max_relative = 1e-7);
        assert_relative_eq!(
            exact.dmean_dmu,
            want_d1,
            epsilon = 1e-9,
            max_relative = 1e-7
        );
    }

    let uncertifiable = logit_posterior_meanwith_deriv_exact(0.0, 0.3);
    assert!(
        uncertifiable.is_err(),
        "erfcx branch must not claim ExactSpecialFunction when it cannot certify the derivative"
    );
}
6331
/// Symmetry check for the integrated logit moments: for each (mu, sigma) the
/// derivative at +mu and -mu must agree, and the mean at -mu must equal
/// one minus the mean at +mu (both to a relative 1e-7).
#[test]
fn test_logit_integrated_derivative_is_even_in_mu() {
    let ctx = QuadratureContext::new();
    let moments_at = |location: f64, sigma: f64, failure: &str| {
        integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::Logit, location, sigma)
            .expect(failure)
    };

    let points = [(0.3, 0.3), (1.1, 0.8), (10.0, 1.0), (3.0, 3.0), (35.0, 1.0)];
    for (mu, sigma) in points {
        let pos = moments_at(mu, sigma, "logit moments (+μ)");
        let neg = moments_at(-mu, sigma, "logit moments (-μ)");
        assert_relative_eq!(
            pos.dmean_dmu,
            neg.dmean_dmu,
            epsilon = 1e-9,
            max_relative = 1e-7
        );
        assert_relative_eq!(
            neg.mean,
            1.0 - pos.mean,
            epsilon = 1e-9,
            max_relative = 1e-7
        );
    }
}
6362
/// For each (mu, sigma) case, the reported `dmean_dmu` must be within 1e-5 of
/// the central finite difference of the reported mean (step 1e-4), and must
/// lie in [0, 0.25 + 1e-9].
#[test]
fn test_logit_dmean_dmu_equals_fd_of_mean_across_regimes() {
    let ctx = QuadratureContext::new();
    let h = 1e-4;
    let cases = [
        (0.0, 0.8),
        (0.7, 0.8),
        (1.5, 1.2),
        (-1.1, 0.9),
        (8.0, 1.0),
        (10.0, 1.5),
        (-9.0, 1.0),
        (0.5, 0.05),
        (0.5, 20.0),
    ];
    for (mu, sigma) in cases {
        let moments_at = |m: f64| {
            integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::Logit, m, sigma)
                .expect("logit moments")
        };
        let out = moments_at(mu);
        let mean_above = moments_at(mu + h).mean;
        let mean_below = moments_at(mu - h).mean;
        let fd = (mean_above - mean_below) / (2.0 * h);
        assert!(
            (out.dmean_dmu - fd).abs() <= 1e-5,
            "dmean_dmu must equal d/dμ of mean at (μ={mu}, σ={sigma}): \
             returned {}, FD of mean {} (mode {:?})",
            out.dmean_dmu,
            fd,
            out.mode
        );
        assert!(
            (0.0..=0.25 + 1e-9).contains(&out.dmean_dmu),
            "dmean_dmu out of [0, 0.25] at (μ={mu}, σ={sigma}): {}",
            out.dmean_dmu
        );
    }
}
6417
/// For sigma in 3..=5 the scalar logit moments must match the Simpson
/// reference to a relative 1e-8 and agree with the logit jet's `mean` / `d1`
/// to 1e-12.
#[test]
fn test_logit_scalar_matches_jet_at_large_sigma() {
    let ctx = QuadratureContext::new();
    let wide_points = [(3.0, 3.0), (4.0, 4.0), (2.0, 5.0), (5.0, 5.0)];
    for (mu, sigma) in wide_points {
        let scalar =
            integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::Logit, mu, sigma)
                .expect("scalar logit moments");
        let jet = integrated_inverse_link_jet(&ctx, LinkFunction::Logit, mu, sigma)
            .expect("jet logit moments");
        let (want_mean, want_d1, _, _) = logit_reference_jet_highres_simpson(mu, sigma);

        // Scalar path against the independent reference.
        assert_relative_eq!(scalar.mean, want_mean, epsilon = 1e-9, max_relative = 1e-8);
        assert_relative_eq!(
            scalar.dmean_dmu,
            want_d1,
            epsilon = 1e-9,
            max_relative = 1e-8
        );

        // Scalar path against the jet path.
        assert_relative_eq!(scalar.mean, jet.mean, epsilon = 1e-12, max_relative = 1e-12);
        assert_relative_eq!(
            scalar.dmean_dmu,
            jet.d1,
            epsilon = 1e-12,
            max_relative = 1e-12
        );
    }
}
6459
/// For wide sigma (3 through 20) the logit jet must match the Simpson
/// reference through third order, agree with the scalar moments to 1e-12, and
/// agree with `integrated_logit_inverse_link_jet_pirls` in values and mode.
#[test]
fn test_logit_jet_accurate_at_wide_sigma() {
    let ctx = QuadratureContext::new();
    let wide_points = [(3.0, 3.0), (4.0, 4.0), (2.0, 5.0), (5.0, 5.0), (0.5, 20.0)];
    for (mu, sigma) in wide_points {
        let jet = integrated_inverse_link_jet(&ctx, LinkFunction::Logit, mu, sigma)
            .expect("wide-σ logit jet");

        // Jet against the independent reference.
        let (want_mean, want_d1, want_d2, want_d3) =
            logit_reference_jet_highres_simpson(mu, sigma);
        assert_relative_eq!(jet.mean, want_mean, epsilon = 1e-8, max_relative = 1e-7);
        assert_relative_eq!(jet.d1, want_d1, epsilon = 1e-8, max_relative = 1e-6);
        assert_relative_eq!(jet.d2, want_d2, epsilon = 1e-8, max_relative = 1e-6);
        assert_relative_eq!(jet.d3, want_d3, epsilon = 1e-8, max_relative = 1e-6);

        // Jet against the scalar dispatcher.
        let scalar =
            integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::Logit, mu, sigma)
                .expect("scalar logit moments");
        assert_relative_eq!(jet.d1, scalar.dmean_dmu, epsilon = 1e-12);
        assert_relative_eq!(jet.mean, scalar.mean, epsilon = 1e-12);

        // Jet against the PIRLS entry point.
        let pirls = integrated_logit_inverse_link_jet_pirls(&ctx, mu, sigma)
            .expect("wide-σ PIRLS logit jet");
        assert_relative_eq!(pirls.mean, jet.mean, epsilon = 1e-12);
        assert_relative_eq!(pirls.d1, jet.d1, epsilon = 1e-12);
        assert_relative_eq!(pirls.d2, jet.d2, epsilon = 1e-12);
        assert_relative_eq!(pirls.d3, jet.d3, epsilon = 1e-12);
        assert_eq!(pirls.mode, jet.mode);
    }
}
6494
/// At sigma = `LOGIT_JET_GHQ_SIGMA_MAX`, the dispatched logit jet and
/// `logit_wide_sigma_jet` must agree through third order for several mu.
#[test]
fn test_logit_jet_continuous_across_ghq_simpson_seam() {
    let ctx = QuadratureContext::new();
    let sigma = LOGIT_JET_GHQ_SIGMA_MAX;
    let locations = [-2.0, -0.5, 0.0, 0.7, 1.3, 3.0];
    for &mu in &locations {
        let dispatched = integrated_inverse_link_jet(&ctx, LinkFunction::Logit, mu, sigma)
            .expect("jet at seam (GHQ dispatch)");
        let wide = logit_wide_sigma_jet(mu, sigma).expect("jet at seam (Simpson)");

        assert_relative_eq!(
            dispatched.mean,
            wide.mean,
            epsilon = 1e-9,
            max_relative = 1e-8
        );
        assert_relative_eq!(dispatched.d1, wide.d1, epsilon = 1e-9, max_relative = 1e-7);
        assert_relative_eq!(dispatched.d2, wide.d2, epsilon = 1e-9, max_relative = 1e-7);
        assert_relative_eq!(dispatched.d3, wide.d3, epsilon = 1e-8, max_relative = 1e-6);
    }
}
6520
/// The batch logit helpers must return, element by element, the same mean and
/// derivative as `integrated_inverse_link_mean_and_derivative` (to 1e-12).
#[test]
fn test_logit_batch_uses_same_dispatchvalues() {
    let ctx = QuadratureContext::new();
    let eta = ndarray::array![-2.0, 0.0, 1.25, 35.0];
    let se = ndarray::array![0.1, 0.5, 1.0, 1.0];

    let mean_only = logit_posterior_mean_batch(&ctx, &eta, &se)
        .expect("logit posterior mean batch should evaluate");
    let (mean_with_deriv, deriv) = logit_posterior_meanwith_deriv_batch(&ctx, &eta, &se)
        .expect("logit posterior mean derivative batch should evaluate");

    for (i, (&eta_i, &se_i)) in eta.iter().zip(se.iter()).enumerate() {
        let direct =
            integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::Logit, eta_i, se_i)
                .expect("logit integrated inverse-link moments should evaluate");
        assert_relative_eq!(mean_only[i], direct.mean, epsilon = 1e-12);
        assert_relative_eq!(mean_with_deriv[i], direct.mean, epsilon = 1e-12);
        assert_relative_eq!(deriv[i], direct.dmean_dmu, epsilon = 1e-12);
    }
}
6543
/// At eta = 50 with se = 0, `logit_posterior_meanwith_deriv_exact` must return
/// the stable derivative z / (1 + z)^2 with z = exp(-eta), within 1e-30
/// absolute.
///
/// NOTE(review): despite "loses" in the name, the assertion requires the tail
/// derivative to be retained.
#[test]
fn exact_logit_small_se_branch_loses_tail_derivative() {
    let eta = 50.0_f64;
    let z = (-eta).exp();
    let stable_dmu = z / (1.0_f64 + z).powi(2);
    assert!(stable_dmu > 0.0);

    let exact = logit_posterior_meanwith_deriv_exact(eta, 0.0).expect("exact branch");
    let gap = (exact.dmean_dmu - stable_dmu).abs();
    assert!(
        gap < 1e-30,
        "exact logit small-se branch should use the stable derivative z/(1+z)^2 at eta={eta}; got {} vs {}",
        exact.dmean_dmu,
        stable_dmu
    );
}
6559
/// `integrated_family_moments_jet` must fail for a binomial spec with a
/// `LatentCLogLog` inverse link, and the error text must mention
/// "LatentCLogLog".
#[test]
fn integrated_family_moments_rejects_latent_cloglog_without_concrete_handler() {
    let ctx = QuadratureContext::new();
    let latent_state =
        gam_problem::types::LatentCLogLogState::new(0.4).expect("valid latent cloglog state");
    let spec = LikelihoodSpec::new(
        ResponseFamily::Binomial,
        InverseLink::LatentCLogLog(latent_state),
    );
    let scale = LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 };

    let failure = integrated_family_moments_jet(&ctx, &spec, scale, 0.2, 0.5)
        .expect_err("latent cloglog moments should error in this dispatcher");
    assert!(failure.to_string().contains("LatentCLogLog"));
}
6582
/// A binomial spec with an SAS inverse link (epsilon = 0.3, log-delta = -0.2)
/// must yield a finite jet whose mean lies strictly inside (0, 1).
#[test]
fn integrated_family_moments_supports_stateful_sas() {
    let ctx = QuadratureContext::new();
    let sas_spec = gam_problem::types::SasLinkSpec {
        initial_epsilon: 0.3,
        initial_log_delta: -0.2,
    };
    let sas_state = crate::mixture_link::state_from_sasspec(sas_spec)
        .expect("sas state should reconstruct from raw parameters");
    let spec = LikelihoodSpec::new(ResponseFamily::Binomial, InverseLink::Sas(sas_state));
    let scale = LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 };

    let jet = integrated_family_moments_jet(&ctx, &spec, scale, 0.2, 0.5)
        .expect("stateful SAS integrated moments should evaluate");

    assert!(jet.mean.is_finite());
    assert!(jet.d1.is_finite());
    assert!(jet.d2.is_finite());
    assert!(jet.d3.is_finite());
    assert!(jet.mean > 0.0 && jet.mean < 1.0);
}
6606
/// A single-component probit mixture must reproduce `integrated_probit_jet`
/// at (0.7, 1.3) to 1e-12 and report `ExactClosedForm`.
#[test]
fn integrated_family_moments_supports_pure_probit_mixture() {
    let ctx = QuadratureContext::new();
    let mixture_spec = gam_problem::types::MixtureLinkSpec {
        components: vec![gam_problem::types::LinkComponent::Probit],
        initial_rho: ndarray::Array1::<f64>::zeros(0),
    };
    let mixture_state = crate::mixture_link::state_fromspec(&mixture_spec)
        .expect("single-component probit mixture state");
    let spec = LikelihoodSpec::new(
        ResponseFamily::Binomial,
        InverseLink::Mixture(mixture_state),
    );
    let scale = LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 };

    let via_family = integrated_family_moments_jet(&ctx, &spec, scale, 0.7, 1.3)
        .expect("pure probit mixture integrated moments should evaluate");
    let via_probit = integrated_probit_jet(0.7, 1.3);

    assert_relative_eq!(via_family.mean, via_probit.mean, epsilon = 1e-12);
    assert_relative_eq!(via_family.d1, via_probit.d1, epsilon = 1e-12);
    assert_relative_eq!(via_family.d2, via_probit.d2, epsilon = 1e-12);
    assert_relative_eq!(via_family.d3, via_probit.d3, epsilon = 1e-12);
    assert_eq!(via_family.mode, IntegratedExpectationMode::ExactClosedForm);
}
6631
/// A single-component logit mixture must reproduce the logit
/// `integrated_inverse_link_jet` at (1.1, 0.8) to 1e-12, including its mode.
#[test]
fn integrated_family_moments_supports_pure_logit_mixture() {
    let ctx = QuadratureContext::new();
    let mixture_spec = gam_problem::types::MixtureLinkSpec {
        components: vec![gam_problem::types::LinkComponent::Logit],
        initial_rho: ndarray::Array1::<f64>::zeros(0),
    };
    let mixture_state = crate::mixture_link::state_fromspec(&mixture_spec)
        .expect("single-component logit mixture state");
    let spec = LikelihoodSpec::new(
        ResponseFamily::Binomial,
        InverseLink::Mixture(mixture_state),
    );
    let scale = LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 };

    let via_family = integrated_family_moments_jet(&ctx, &spec, scale, 1.1, 0.8)
        .expect("pure logit mixture integrated moments should evaluate");
    let via_logit = integrated_inverse_link_jet(&ctx, LinkFunction::Logit, 1.1, 0.8)
        .expect("canonical integrated logit jet");

    assert_relative_eq!(via_family.mean, via_logit.mean, epsilon = 1e-12);
    assert_relative_eq!(via_family.d1, via_logit.d1, epsilon = 1e-12);
    assert_relative_eq!(via_family.d2, via_logit.d2, epsilon = 1e-12);
    assert_relative_eq!(via_family.d3, via_logit.d3, epsilon = 1e-12);
    assert_eq!(via_family.mode, via_logit.mode);
}
6657
/// A two-component (logit + probit) mixture with rho = [0.35] must give the
/// same jet and mode through `integrated_family_moments_jet` as through a
/// direct `integrated_mixture_jet` call at (0.2, 0.5).
#[test]
fn integrated_family_moments_supports_stateful_mixture() {
    let ctx = QuadratureContext::new();
    let mixture_spec = gam_problem::types::MixtureLinkSpec {
        components: vec![
            gam_problem::types::LinkComponent::Logit,
            gam_problem::types::LinkComponent::Probit,
        ],
        initial_rho: ndarray::array![0.35],
    };
    let state = crate::mixture_link::state_fromspec(&mixture_spec)
        .expect("mixture state should reconstruct from rho");
    let spec = LikelihoodSpec::new(
        ResponseFamily::Binomial,
        InverseLink::Mixture(state.clone()),
    );
    let scale = LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 };

    let via_family = integrated_family_moments_jet(&ctx, &spec, scale, 0.2, 0.5)
        .expect("stateful mixture integrated moments should evaluate");
    let via_mixture = integrated_mixture_jet(&ctx, 0.2, 0.5, &state)
        .expect("direct integrated mixture jet should evaluate");

    assert_relative_eq!(via_family.mean, via_mixture.mean, epsilon = 1e-12);
    assert_relative_eq!(via_family.d1, via_mixture.d1, epsilon = 1e-12);
    assert_relative_eq!(via_family.d2, via_mixture.d2, epsilon = 1e-12);
    assert_relative_eq!(via_family.d3, via_mixture.d3, epsilon = 1e-12);
    assert_eq!(via_family.mode, via_mixture.mode);
}
6689
/// Checks the `variance` reported by `integrated_family_moments_jet` for four
/// log-link families against closed forms in m = exp(e + se^2 / 2):
///
/// * Tweedie (power p, dispersion phi): phi * m^p
/// * Gamma (shape): m^2 / shape
/// * Poisson: m
/// * Negative binomial (theta): m + m^2 / theta
///
/// It also requires Gamma and Tweedie to fail with a specific message when the
/// scale metadata is `Unspecified`.
#[test]
fn integrated_family_moments_use_scale_dispersion_for_tweedie_and_gamma() {
    let ctx = QuadratureContext::new();
    let e = 0.3_f64;
    let se = 0.5_f64;
    let m = (e + 0.5 * se * se).exp();

    // Tweedie: variance scales with the supplied dispersion.
    let p = 1.5_f64;
    let phi = 2.0_f64;
    let tweedie = LikelihoodSpec::tweedie_log(p);
    let tweedie_moments = integrated_family_moments_jet(
        &ctx,
        &tweedie,
        LikelihoodScaleMetadata::EstimatedTweediePhi { phi },
        e,
        se,
    )
    .expect("tweedie integrated moments should evaluate");
    let m_pow_p = m.powf(p);
    assert_relative_eq!(tweedie_moments.variance, phi * m_pow_p, epsilon = 1e-12);
    assert_relative_eq!(tweedie_moments.variance / m_pow_p, phi, epsilon = 1e-12);

    // Gamma: variance scales with the reciprocal of the supplied shape.
    let shape = 4.0_f64;
    let gamma = LikelihoodSpec::gamma_log();
    let gamma_moments = integrated_family_moments_jet(
        &ctx,
        &gamma,
        LikelihoodScaleMetadata::EstimatedGammaShape { shape },
        e,
        se,
    )
    .expect("gamma integrated moments should evaluate");
    assert_relative_eq!(gamma_moments.variance, m * m / shape, epsilon = 1e-12);
    assert_relative_eq!(
        gamma_moments.variance / (m * m),
        1.0 / shape,
        epsilon = 1e-12
    );

    // Poisson with unit dispersion.
    let poisson = LikelihoodSpec::poisson_log();
    let poisson_moments = integrated_family_moments_jet(
        &ctx,
        &poisson,
        LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 },
        e,
        se,
    )
    .expect("poisson integrated moments should evaluate");
    assert_relative_eq!(poisson_moments.variance, m, epsilon = 1e-12);

    // Negative binomial.
    let theta = 3.0_f64;
    let nb = LikelihoodSpec::negative_binomial_log(theta);
    let nb_moments = integrated_family_moments_jet(
        &ctx,
        &nb,
        LikelihoodScaleMetadata::EstimatedNegBinTheta { theta },
        e,
        se,
    )
    .expect("negative-binomial integrated moments should evaluate");
    assert_relative_eq!(nb_moments.variance, m + m * m / theta, epsilon = 1e-12);

    // Gamma without a shape must be rejected.
    {
        let err = integrated_family_moments_jet(
            &ctx,
            &gamma,
            LikelihoodScaleMetadata::Unspecified,
            e,
            se,
        )
        .expect_err("gamma without a shape in the scale metadata must error");
        assert!(
            format!("{err}").contains("Gamma integrated variance requires the shape"),
            "unexpected error message: {err}"
        );
    }

    // Tweedie without a dispersion must be rejected.
    {
        let err = integrated_family_moments_jet(
            &ctx,
            &tweedie,
            LikelihoodScaleMetadata::Unspecified,
            e,
            se,
        )
        .expect_err("tweedie without a φ in the scale metadata must error");
        assert!(
            format!("{err}").contains("Tweedie integrated variance requires dispersion"),
            "unexpected error message: {err}"
        );
    }
}
6787
/// At eta = 0, `cloglog_g_derivatives` must return
/// (1 - e^-1, e^-1, 0, -e^-1, -e^-1) to 1e-14.
#[test]
fn cloglog_g_derivatives_at_zero() {
    let inv_e = (-1.0_f64).exp();
    let (value, first, second, third, fourth) = cloglog_g_derivatives(0.0);

    assert_relative_eq!(value, 1.0 - inv_e, epsilon = 1e-14);
    assert_relative_eq!(first, inv_e, epsilon = 1e-14);
    assert_relative_eq!(second, 0.0, epsilon = 1e-14);
    assert_relative_eq!(third, -inv_e, epsilon = 1e-14);
    assert_relative_eq!(fourth, -inv_e, epsilon = 1e-14);
}
6806
/// Tail behaviour of `cloglog_g_derivatives`: at eta = 50 the value is 1 (to
/// 1e-10) with exactly zero derivatives; at eta = -50 the value and all four
/// derivatives equal exp(-50) to a relative 1e-10.
#[test]
fn cloglog_g_derivatives_saturation() {
    let (upper_g, upper_g1, upper_g2, upper_g3, upper_g4) = cloglog_g_derivatives(50.0);
    assert_relative_eq!(upper_g, 1.0, epsilon = 1e-10);
    assert_eq!(upper_g1, 0.0);
    assert_eq!(upper_g2, 0.0);
    assert_eq!(upper_g3, 0.0);
    assert_eq!(upper_g4, 0.0);

    let (lower_g, lower_g1, lower_g2, lower_g3, lower_g4) = cloglog_g_derivatives(-50.0);
    let tiny = (-50.0_f64).exp();
    assert_relative_eq!(lower_g, tiny, max_relative = 1e-10);
    assert_relative_eq!(lower_g1, tiny, max_relative = 1e-10);
    assert_relative_eq!(lower_g2, tiny, max_relative = 1e-10);
    assert_relative_eq!(lower_g3, tiny, max_relative = 1e-10);
    assert_relative_eq!(lower_g4, tiny, max_relative = 1e-10);
}
6827
/// With sigma = 0, `cloglog_ghq_value` (21 points) must equal the first entry
/// of `cloglog_g_derivatives(mu)` to 1e-14.
#[test]
fn cloglog_ghq_value_sigma_zero_matches_pointwise() {
    let ctx = QuadratureContext::new();
    for mu in [-2.0, -1.0, 0.0, 0.5, 1.5] {
        let integrated = cloglog_ghq_value(&ctx, mu, 0.0, 21);
        let pointwise = cloglog_g_derivatives(mu).0;
        assert_relative_eq!(integrated, pointwise, epsilon = 1e-14);
    }
}
6838
/// `cloglog_ghq_value` with 31 points must stay inside [0, 1] over a grid of
/// mu in [-5, 10] and sigma in [0.1, 5].
#[test]
fn cloglog_ghq_value_bounded_zero_one() {
    let ctx = QuadratureContext::new();
    let mus = [-5.0, -2.0, 0.0, 1.0, 3.0, 10.0];
    let sigmas = [0.1, 0.5, 1.0, 2.0, 5.0];
    for mu in mus {
        for sigma in sigmas {
            let val = cloglog_ghq_value(&ctx, mu, sigma, 31);
            let in_unit_interval = (0.0..=1.0).contains(&val);
            assert!(in_unit_interval, "L({mu},{sigma}) = {val}");
        }
    }
}
6850
/// With sigma = 0 at mu = 0.3, `cloglog_ghq_derivatives` must satisfy:
/// pure-mu derivatives equal the pointwise `cloglog_g_derivatives`; every
/// derivative of odd order in sigma is exactly zero; and the even-in-sigma
/// terms equal g2, g3, g4 and 3 * g4 respectively.
#[test]
fn cloglog_ghq_derivatives_sigma_zero_matches_pointwise() {
    let ctx = QuadratureContext::new();
    let mu = 0.3;
    let jet = cloglog_ghq_derivatives(&ctx, mu, 0.0, 21);
    let (g0, g1, g2, g3, g4) = cloglog_g_derivatives(mu);

    // Pure mu-derivatives.
    assert_relative_eq!(jet.l, g0, epsilon = 1e-14);
    assert_relative_eq!(jet.l_mu, g1, epsilon = 1e-14);
    assert_relative_eq!(jet.l_mumu, g2, epsilon = 1e-14);
    assert_relative_eq!(jet.l_mumumu, g3, epsilon = 1e-14);
    assert_relative_eq!(jet.l_mumumumu, g4, epsilon = 1e-14);

    // Odd order in sigma.
    assert_eq!(jet.l_sigma, 0.0);
    assert_eq!(jet.l_musigma, 0.0);
    assert_eq!(jet.l_mumusigma, 0.0);
    assert_eq!(jet.l_mumumusigma, 0.0);
    assert_eq!(jet.l_sigmasigmasigma, 0.0);
    assert_eq!(jet.l_musigmasigmasigma, 0.0);

    // Even order in sigma.
    assert_relative_eq!(jet.l_sigmasigma, g2, epsilon = 1e-14);
    assert_relative_eq!(jet.l_musigmasigma, g3, epsilon = 1e-14);
    assert_relative_eq!(jet.l_mumusigmasigma, g4, epsilon = 1e-14);
    assert_relative_eq!(jet.l_sigmasigmasigmasigma, 3.0 * g4, epsilon = 1e-14);
}
6878
/// Central finite differences in mu (step 1e-6) at (mu = 0.5, sigma = 0.8),
/// 31 points: `l_mu` against differences of the value (1e-5), and `l_mumu`
/// against differences of `l_mu` (1e-4).
#[test]
fn cloglog_ghq_derivatives_finite_difference_mu() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (0.5, 0.8);
    let h = 1e-6;
    let analytic = cloglog_ghq_derivatives(&ctx, mu, sigma, 31);

    let value_above = cloglog_ghq_value(&ctx, mu + h, sigma, 31);
    let value_below = cloglog_ghq_value(&ctx, mu - h, sigma, 31);
    let fd_first = (value_above - value_below) / (2.0 * h);
    assert_relative_eq!(analytic.l_mu, fd_first, epsilon = 1e-5);

    let jet_above = cloglog_ghq_derivatives(&ctx, mu + h, sigma, 31);
    let jet_below = cloglog_ghq_derivatives(&ctx, mu - h, sigma, 31);
    let fd_second = (jet_above.l_mu - jet_below.l_mu) / (2.0 * h);
    assert_relative_eq!(analytic.l_mumu, fd_second, epsilon = 1e-4);
}
6898
/// Central finite difference in sigma (step 1e-6) at (mu = 0.2, sigma = 1.0),
/// 31 points: `l_sigma` must match the difference of values to 1e-5.
#[test]
fn cloglog_ghq_derivatives_finite_difference_sigma() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (0.2, 1.0);
    let h = 1e-6;
    let analytic = cloglog_ghq_derivatives(&ctx, mu, sigma, 31);

    let value_above = cloglog_ghq_value(&ctx, mu, sigma + h, 31);
    let value_below = cloglog_ghq_value(&ctx, mu, sigma - h, 31);
    let fd = (value_above - value_below) / (2.0 * h);
    assert_relative_eq!(analytic.l_sigma, fd, epsilon = 1e-5);
}
6912
/// Mixed derivative check at (mu = -0.5, sigma = 0.6), 31 points:
/// `l_musigma` must match the central difference of `l_mu` in sigma
/// (step 1e-6) to 1e-4.
#[test]
fn cloglog_ghq_derivatives_finite_difference_cross() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (-0.5, 0.6);
    let h = 1e-6;
    let analytic = cloglog_ghq_derivatives(&ctx, mu, sigma, 31);

    let jet_above = cloglog_ghq_derivatives(&ctx, mu, sigma + h, 31);
    let jet_below = cloglog_ghq_derivatives(&ctx, mu, sigma - h, 31);
    let fd = (jet_above.l_mu - jet_below.l_mu) / (2.0 * h);
    assert_relative_eq!(analytic.l_musigma, fd, epsilon = 1e-4);
}
6926
/// `l_mu` from `cloglog_ghq_derivatives` (21 points) must be at least -1e-14
/// over a grid of mu in [-3, 3] and sigma in [0.1, 2].
#[test]
fn cloglog_ghq_l_mu_nonnegative() {
    let ctx = QuadratureContext::new();
    let mus = [-3.0, -1.0, 0.0, 1.0, 3.0];
    let sigmas = [0.1, 0.5, 1.0, 2.0];
    for mu in mus {
        for sigma in sigmas {
            let jet = cloglog_ghq_derivatives(&ctx, mu, sigma, 21);
            assert!(
                jet.l_mu >= -1e-14,
                "L_mu should be non-negative at mu={mu}, sigma={sigma}: got {}",
                jet.l_mu
            );
        }
    }
}
6942
/// `cloglog_ghq_derivatives_adaptive` must give the same `l`, `l_mu`,
/// `l_sigma` and `l_mumu` (to 1e-15) as `cloglog_ghq_derivatives` called with
/// the point count from `adaptive_point_count_from_sd(sigma)`.
#[test]
fn cloglog_ghq_adaptive_matches_explicit() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (0.7, 1.2);
    let adaptive = cloglog_ghq_derivatives_adaptive(&ctx, mu, sigma);
    let point_count = adaptive_point_count_from_sd(sigma);
    let explicit = cloglog_ghq_derivatives(&ctx, mu, sigma, point_count);

    assert_relative_eq!(adaptive.l, explicit.l, epsilon = 1e-15);
    assert_relative_eq!(adaptive.l_mu, explicit.l_mu, epsilon = 1e-15);
    assert_relative_eq!(adaptive.l_sigma, explicit.l_sigma, epsilon = 1e-15);
    assert_relative_eq!(adaptive.l_mumu, explicit.l_mumu, epsilon = 1e-15);
}
6956
/// `cloglog_ghq_value` with 51 points must match the mean from
/// `cloglog_reference_mean_and_derivative` to a relative 2e-8 for mu in
/// [-1, 2] and sigma in [0.1, 1].
#[test]
fn cloglog_ghq_value_matches_mathematical_target_in_central_regime() {
    let ctx = QuadratureContext::new();
    let mus = [-1.0, 0.0, 0.5, 2.0];
    let sigmas = [0.1, 0.5, 1.0];
    for mu in mus {
        for sigma in sigmas {
            let quadrature = cloglog_ghq_value(&ctx, mu, sigma, 51);
            let target = cloglog_reference_mean_and_derivative(mu, sigma).0;
            assert_relative_eq!(quadrature, target, epsilon = 1e-12, max_relative = 2e-8);
        }
    }
}
6968
/// At eta = -30, `cloglog_negative_tail_mean` must match -expm1(-exp(eta)).
///
/// The tolerance is 1e-26 relative to the exact value, which is smaller than
/// one f64 ulp at this magnitude, so in practice the two must be identical.
#[test]
fn cloglog_negative_tail_mean_matches_exact_near_transition() {
    let eta: f64 = -30.0;
    let neg_rate = -eta.exp();
    let exact = -neg_rate.exp_m1();
    let tail = cloglog_negative_tail_mean(eta);

    let tolerance = 1e-26 * exact.abs().max(1e-300);
    assert!(
        (exact - tail).abs() < tolerance,
        "tail mean at η={eta}: exact={exact:.6e} tail={tail:.6e}"
    );
}
6987
6988 #[inline]
6989 fn cloglog_negative_tail_derivative(eta: f64) -> f64 {
6990 if eta < -745.0 {
6992 0.0
6993 } else {
6994 let ex = safe_exp(eta);
6995 (ex * (-ex).exp()).max(0.0)
6996 }
6997 }
6998
/// At eta = -30, `cloglog_negative_tail_derivative` must match
/// exp(eta) * exp(-exp(eta)).
///
/// The tolerance is 1e-26 relative to the exact value, which is smaller than
/// one f64 ulp at this magnitude, so in practice the two must be identical.
#[test]
fn cloglog_negative_tail_derivative_matches_exact_near_transition() {
    let eta: f64 = -30.0;
    let rate = eta.exp();
    let exact = rate * (-rate).exp();
    let tail = cloglog_negative_tail_derivative(eta);

    let tolerance = 1e-26 * exact.abs().max(1e-300);
    assert!(
        (exact - tail).abs() < tolerance,
        "tail derivative at η={eta}: exact={exact:.6e} tail={tail:.6e}"
    );
}
7011
7012 #[test]
7013 fn cloglog_negative_tail_degenerate_branch_matches_target_near_transition() {
7014 let ctx = QuadratureContext::default();
7015 let sigma = 0.0;
7016 for &mu in &[-30.001, -30.0, -29.999] {
7017 let out = cloglog_posterior_meanwith_deriv_controlled(&ctx, mu, sigma);
7018 assert_relative_eq!(
7019 out.mean,
7020 cloglog_mean_exact(mu),
7021 epsilon = 1e-28,
7022 max_relative = 1e-15
7023 );
7024 assert_relative_eq!(
7025 out.dmean_dmu,
7026 cloglog_mean_d1_exact(mu),
7027 epsilon = 1e-28,
7028 max_relative = 1e-15
7029 );
7030 }
7031 }
7032
7033 #[test]
7034 fn cloglog_negative_tail_small_sigma_branch_matches_target_near_transition() {
7035 let ctx = QuadratureContext::default();
7036 let sigma = 0.1;
7037 for &mu in &[-30.001, -30.0, -29.999] {
7038 let out = cloglog_posterior_meanwith_deriv_controlled(&ctx, mu, sigma);
7039 let (expected_mean, expected_deriv) = cloglog_reference_mean_and_derivative(mu, sigma);
7040 assert_relative_eq!(
7041 out.mean,
7042 expected_mean,
7043 epsilon = 1e-24,
7044 max_relative = 1e-10
7045 );
7046 assert_relative_eq!(
7047 out.dmean_dmu,
7048 expected_deriv,
7049 epsilon = 1e-24,
7050 max_relative = 1e-10
7051 );
7052 }
7053 }
7054
7055 fn ref_cholesky_heap(cov: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
7059 let n = cov.len();
7060 if n == 0 || cov.iter().any(|r| r.len() != n) {
7061 return None;
7062 }
7063 let mut base = cov.to_vec();
7064 for retry in 0..8 {
7065 let jitter = if retry == 0 {
7066 0.0
7067 } else {
7068 1e-12 * 10f64.powi(retry - 1)
7069 };
7070 if jitter > 0.0 {
7071 for i in 0..n {
7072 base[i][i] = cov[i][i] + jitter;
7073 }
7074 }
7075 let mut l = vec![vec![0.0_f64; n]; n];
7076 let mut ok = true;
7077 for i in 0..n {
7078 for j in 0..=i {
7079 let mut sum = base[i][j];
7080 for k in 0..j {
7081 sum -= l[i][k] * l[j][k];
7082 }
7083 if i == j {
7084 if !sum.is_finite() || sum <= 0.0 {
7085 ok = false;
7086 break;
7087 }
7088 l[i][j] = sum.sqrt();
7089 } else {
7090 l[i][j] = sum / l[j][j];
7091 }
7092 }
7093 if !ok {
7094 break;
7095 }
7096 }
7097 if ok {
7098 return Some(l);
7099 }
7100 }
7101 None
7102 }
7103
7104 #[test]
7105 fn cholesky_static_matches_heap_d2() {
7106 let cases: &[[[f64; 2]; 2]] = &[
7109 [[1.0, 0.0], [0.0, 1.0]],
7110 [[2.5, 0.3], [0.3, 0.75]],
7111 [[1.0, 0.9999], [0.9999, 1.0]],
7112 [[1e-10, 0.0], [0.0, 1e-10]],
7113 [[4.0, -1.5], [-1.5, 2.25]],
7114 ];
7115 for cov in cases {
7116 let stack = cholesky_static_with_jitter::<2>(cov).expect("stack cholesky");
7117 let heap_in: Vec<Vec<f64>> = cov.iter().map(|r| r.to_vec()).collect();
7118 let heap = ref_cholesky_heap(&heap_in).expect("heap cholesky");
7119 for i in 0..2 {
7120 for j in 0..2 {
7121 assert_eq!(
7122 stack[i][j].to_bits(),
7123 heap[i][j].to_bits(),
7124 "mismatch at ({i},{j}) for cov={cov:?}"
7125 );
7126 }
7127 }
7128 }
7129 }
7130
7131 #[test]
7132 fn cholesky_static_matches_heap_d3() {
7133 let cases: &[[[f64; 3]; 3]] = &[
7134 [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
7135 [[2.0, 0.5, 0.1], [0.5, 1.5, -0.2], [0.1, -0.2, 0.8]],
7136 [[4.0, 1.0, 0.5], [1.0, 3.0, 0.25], [0.5, 0.25, 2.0]],
7137 ];
7138 for cov in cases {
7139 let stack = cholesky_static_with_jitter::<3>(cov).expect("stack cholesky");
7140 let heap_in: Vec<Vec<f64>> = cov.iter().map(|r| r.to_vec()).collect();
7141 let heap = ref_cholesky_heap(&heap_in).expect("heap cholesky");
7142 for i in 0..3 {
7143 for j in 0..3 {
7144 assert_eq!(
7145 stack[i][j].to_bits(),
7146 heap[i][j].to_bits(),
7147 "mismatch at ({i},{j}) for cov={cov:?}"
7148 );
7149 }
7150 }
7151 }
7152 }
7153
7154 #[test]
7155 fn cholesky_static_d1() {
7156 let l = cholesky_static_with_jitter::<1>(&[[2.25]]).expect("d=1");
7157 assert_eq!(l[0][0], 1.5);
7158 assert!(cholesky_static_with_jitter::<1>(&[[-1.0e-13]]).is_some());
7168 assert!(cholesky_static_with_jitter::<1>(&[[-1.0e3]]).is_none());
7171 }
7172}