1use std::collections::HashMap;
173use std::convert::Infallible;
174use std::sync::{Arc, Mutex, OnceLock};
175
176use crate::estimate::EstimationError;
177use crate::mixture_link::{
178 beta_logistic_inverse_link_jet, component_inverse_link_jet, sas_inverse_link_jet,
179};
180use gam_math::probability::erfcx_nonnegative;
181use gam_math::special::stable_polynomial_times_exp_neg as cloglog_stable_poly_times_exp_neg;
182use gam_problem::types::{
183 InverseLink, LikelihoodScaleMetadata, LikelihoodSpec, LinkComponent, LinkFunction,
184 MixtureLinkState, ResponseFamily, SasLinkState, StandardLink,
185};
186use statrs::function::erf::erfc;
187
/// Node count of the fixed-size Gauss-Hermite rule stored in `GaussHermiteRule`.
const N_POINTS: usize = 7;
/// Local alias for the square root of two.
const SQRT_2: f64 = std::f64::consts::SQRT_2;
/// Largest argument `safe_exp` forwards to `exp`; anything above it is capped to this value.
const QUADRATURE_EXP_LOG_MAX: f64 = 700.0;
192
193#[inline]
196fn safe_exp(x: f64) -> f64 {
197 if x.is_nan() {
198 f64::NAN
199 } else {
200 x.min(QUADRATURE_EXP_LOG_MAX).exp()
201 }
202}
203
204#[inline]
205fn safe_expwith_saturation(x: f64) -> (f64, bool) {
206 (safe_exp(x), x > QUADRATURE_EXP_LOG_MAX)
207}
208
/// Complex number stored as separate real and imaginary `f64` parts.
#[derive(Debug, Default, Clone, Copy)]
struct Complex {
    /// Real part.
    re: f64,
    /// Imaginary part.
    im: f64,
}
214
215pub struct QuadratureContext {
217 gh_cache: OnceLock<GaussHermiteRule>,
218 gh15_cache: OnceLock<GaussHermiteRuleDynamic>,
219 gh21_cache: OnceLock<GaussHermiteRuleDynamic>,
220 gh31_cache: OnceLock<GaussHermiteRuleDynamic>,
221 gh51_cache: OnceLock<GaussHermiteRuleDynamic>,
222 cc_cache: Mutex<HashMap<usize, Arc<ClenshawCurtisRule>>>,
226}
227
/// Identifies which evaluation strategy produced an integrated expectation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IntegratedExpectationMode {
    /// Value came from a closed-form expression.
    ExactClosedForm,
    /// Value came from a special-function representation.
    ExactSpecialFunction,
    /// Value came from an asymptotic or series approximation.
    ControlledAsymptotic,
    /// Value came from numerical quadrature.
    QuadratureFallback,
}

impl IntegratedExpectationMode {
    /// Numeric rank of the mode: `0` for `ExactClosedForm` up to `3` for
    /// `QuadratureFallback`, following declaration order.
    #[inline]
    pub const fn rank(self) -> u8 {
        match self {
            IntegratedExpectationMode::QuadratureFallback => 3,
            IntegratedExpectationMode::ControlledAsymptotic => 2,
            IntegratedExpectationMode::ExactSpecialFunction => 1,
            IntegratedExpectationMode::ExactClosedForm => 0,
        }
    }
}
250
251#[derive(Clone, Copy, Debug)]
252pub struct IntegratedMeanDerivative {
253 pub mean: f64,
254 pub dmean_dmu: f64,
255 pub mode: IntegratedExpectationMode,
256}
257
258#[derive(Clone, Copy, Debug)]
259pub struct IntegratedInverseLinkJet {
260 pub mean: f64,
261 pub d1: f64,
262 pub d2: f64,
263 pub d3: f64,
264 pub mode: IntegratedExpectationMode,
265}
266
267#[derive(Clone, Copy, Debug)]
268pub(crate) struct IntegratedInverseLinkJet5 {
269 pub mean: f64,
270 pub d1: f64,
271 pub d2: f64,
272 pub d3: f64,
273 pub d4: f64,
274 pub d5: f64,
275 pub mode: IntegratedExpectationMode,
276}
277
278#[inline]
279pub(crate) fn validate_latent_cloglog_inputs(eta: f64, sigma: f64) -> Result<(), EstimationError> {
280 if !eta.is_finite() || !sigma.is_finite() || sigma < 0.0 {
281 crate::bail_invalid_estim!(
282 "latent cloglog jet requires finite eta and sigma >= 0, got eta={eta}, sigma={sigma}"
283 );
284 }
285 Ok::<(), _>(())
286}
287
288#[derive(Clone, Copy, Debug)]
293pub struct IntegratedMomentsJet {
294 pub mean: f64,
295 pub variance: f64,
296 pub d1: f64,
297 pub d2: f64,
298 pub d3: f64,
299 pub mode: IntegratedExpectationMode,
300}
301
// ---- Logit regime thresholds ----

/// At or below this sigma the logit expectation is evaluated as the plain sigmoid.
const LOGIT_SIGMA_DEGENERATE: f64 = 1e-10;
/// Below this sigma the logit small-sigma Taylor expansion is used; it is also the lower
/// sigma bound of the erfcx-series eligibility window.
const LOGIT_SIGMA_TAYLOR_MAX: f64 = 0.25;
/// Log-scale threshold at or below which the logit tail asymptotic is accepted.
const LOGIT_TAIL_LOG_MAX: f64 = -18.0;
/// Largest `|mu|` for which the erfcx series is attempted.
const LOGIT_ERFCX_MU_MAX: f64 = 40.0;
/// Largest sigma for which the erfcx series is attempted.
const LOGIT_ERFCX_SIGMA_MAX: f64 = 6.0;
/// Sigma bound for the logit jet GHQ path (its use is not visible in this part of the file).
const LOGIT_JET_GHQ_SIGMA_MAX: f64 = 1.0;

// ---- Cloglog regime thresholds ----

/// At or below this sigma the cloglog expectation is evaluated at `mu` directly.
const CLOGLOG_SIGMA_DEGENERATE: f64 = 1e-10;
/// Below this sigma the cloglog small-sigma Taylor expansion is used.
const CLOGLOG_SIGMA_TAYLOR_MAX: f64 = 0.25;
/// Sigma bound for the cloglog jet moment path (its use is not visible in this part of the file).
const CLOGLOG_JET_MOMENT_SIGMA_MAX: f64 = 1.0;
/// `mu + sigma^2 / 2` at or below this value selects the rare-event approximation.
const CLOGLOG_RARE_EVENT_LOG_MAX: f64 = -18.0;
/// At or above this sigma the Gumbel-form quadrature is used; it also caps the sigma that
/// drives the node-count choice of that quadrature.
const CLOGLOG_LARGE_SIGMA_ASYMPTOTIC_MIN: f64 = 8.0;
/// Positive saturation is declared when `mu - SIGMAS * sigma` reaches this edge.
const CLOGLOG_POSITIVE_SATURATION_EDGE: f64 = 5.0;
/// Number of sigmas subtracted from `mu` in the positive saturation test.
const CLOGLOG_POSITIVE_SATURATION_SIGMAS: f64 = 8.0;
/// Lower end of the eta interval of the Gumbel-form quadrature.
const CLOGLOG_GUMBEL_QUAD_ETA_LO: f64 = -40.0;
/// Upper end of the eta interval of the Gumbel-form quadrature.
const CLOGLOG_GUMBEL_QUAD_ETA_HI: f64 = 6.0;
/// Smallest node count used by the Gumbel-form quadrature.
const CLOGLOG_GUMBEL_QUAD_MIN_NODES: usize = 97;
/// Numerator of the `scale / sigma` node-count target of the Gumbel-form quadrature.
const CLOGLOG_GUMBEL_QUAD_NODE_SCALE: f64 = 320.0;
/// Largest node count used by the Gumbel-form quadrature.
const CLOGLOG_GUMBEL_QUAD_MAX_NODES: usize = 513;

// ---- Series controls ----

/// Count of consecutive small terms for series stopping (its use is not visible in this part
/// of the file).
const SERIES_CONSECUTIVE_SMALL_TERMS: usize = 6;
/// Largest number of terms the logit erfcx series may require before it is rejected.
const LOGIT_MAX_TERMS: usize = 160;
/// Accuracy target handed to the logit erfcx series cutoff estimate.
const LOGIT_ERFCX_ACCURACY_TARGET: f64 = 1.0e-11;

// ---- Cloglog Miles / gamma-reference / Clenshaw-Curtis backends ----
// The Miles and gamma-reference backends themselves are outside the visible part of the file.

const CLOGLOG_MILES_ALPHA: f64 = 60.0;
const CLOGLOG_MILES_MAX_TERMS: usize = 256;
const CLOGLOG_MILES_PEAK_LOG_MAX: f64 = 0.0;
const CLOGLOG_GAMMA_K_REF: f64 = 0.5;
/// Reference upper limit used to derive `CLOGLOG_GAMMA_SAMPLE_COUNT`.
const CLOGLOG_GAMMA_T_MAX_REF: f64 = 24.0;
/// Reference step used to derive `CLOGLOG_GAMMA_SAMPLE_COUNT`.
const CLOGLOG_GAMMA_H_REF: f64 = 0.01;
/// Tolerance passed to the Clenshaw-Curtis cloglog backend and its node-count estimate.
const CLOGLOG_CC_TOL: f64 = 1e-12;
const CLOGLOG_CC_NODE_CAP: usize = 1025;
/// Sample count implied by the gamma reference limit and step (`T_MAX / H + 1`).
const CLOGLOG_GAMMA_SAMPLE_COUNT: usize =
    (CLOGLOG_GAMMA_T_MAX_REF / CLOGLOG_GAMMA_H_REF) as usize + 1;
/// The Clenshaw-Curtis backend is preferred when it needs at most this many nodes.
const CLOGLOG_CC_PREFER_THRESHOLD: usize = CLOGLOG_GAMMA_SAMPLE_COUNT / 3;
/// Smallest node count the Clenshaw-Curtis node estimate may return.
const CLOGLOG_CC_MIN_N: usize = 17;
422
423impl QuadratureContext {
424 pub fn new() -> Self {
425 Self {
426 gh_cache: OnceLock::new(),
427 gh15_cache: OnceLock::new(),
428 gh21_cache: OnceLock::new(),
429 gh31_cache: OnceLock::new(),
430 gh51_cache: OnceLock::new(),
431 cc_cache: Mutex::new(HashMap::new()),
432 }
433 }
434
435 fn gauss_hermite(&self) -> &GaussHermiteRule {
436 self.gh_cache.get_or_init(compute_gauss_hermite)
437 }
438
439 fn gauss_hermite_n(&self, n: usize) -> &GaussHermiteRuleDynamic {
440 match n {
441 7 => self.gh15_cache.get_or_init(|| compute_gauss_hermite_n(15)),
444 15 => self.gh15_cache.get_or_init(|| compute_gauss_hermite_n(15)),
445 21 => self.gh21_cache.get_or_init(|| compute_gauss_hermite_n(21)),
446 31 => self.gh31_cache.get_or_init(|| compute_gauss_hermite_n(31)),
447 51 => self.gh51_cache.get_or_init(|| compute_gauss_hermite_n(51)),
448 _ => self.gh21_cache.get_or_init(|| compute_gauss_hermite_n(21)),
449 }
450 }
451
452 fn clenshaw_curtis_n(&self, n: usize) -> Arc<ClenshawCurtisRule> {
453 let mut cache = match self.cc_cache.lock() {
454 Ok(guard) => guard,
455 Err(poisoned) => poisoned.into_inner(),
456 };
457 cache
458 .entry(n)
459 .or_insert_with(|| Arc::new(compute_clenshaw_curtis_n(n)))
460 .clone()
461 }
462}
463
464impl Default for QuadratureContext {
465 fn default() -> Self {
466 Self::new()
467 }
468}
469
470struct GaussHermiteRule {
472 nodes: [f64; N_POINTS],
474 weights: [f64; N_POINTS],
476}
477
/// Gauss-Hermite rule whose node count is chosen at run time.
pub(crate) struct GaussHermiteRuleDynamic {
    /// Quadrature nodes.
    pub(crate) nodes: Vec<f64>,
    /// Quadrature weights, one per node.
    pub(crate) weights: Vec<f64>,
}
482
/// Clenshaw-Curtis rule: nodes and matching weights.
#[derive(Clone)]
struct ClenshawCurtisRule {
    /// Quadrature nodes.
    nodes: Vec<f64>,
    /// Quadrature weights, one per node.
    weights: Vec<f64>,
}
488
489fn compute_clenshaw_curtis_n(n: usize) -> ClenshawCurtisRule {
490 assert!(
491 n >= 2,
492 "Clenshaw-Curtis rule requires at least two nodes: n={n}"
493 );
494 let m = n - 1;
509 let theta: Vec<f64> = (0..=m)
510 .map(|j| std::f64::consts::PI * (j as f64) / (m as f64))
511 .collect();
512 let nodes: Vec<f64> = theta.iter().map(|&th| th.cos()).collect();
513
514 if n == 2 {
515 return ClenshawCurtisRule {
516 nodes,
517 weights: vec![1.0, 1.0],
518 };
519 }
520
521 let mut weights = vec![0.0_f64; n];
522 let mut v = vec![1.0_f64; m - 1];
523
524 if m.is_multiple_of(2) {
525 let w0 = 1.0 / ((m * m - 1) as f64);
526 weights[0] = w0;
527 weights[m] = w0;
528 for k in 1..(m / 2) {
529 let denom = (4 * k * k - 1) as f64;
530 for j in 1..m {
531 v[j - 1] -= 2.0 * (2.0 * (k as f64) * theta[j]).cos() / denom;
532 }
533 }
534 for j in 1..m {
535 v[j - 1] -= ((m as f64) * theta[j]).cos() / ((m * m - 1) as f64);
536 }
537 } else {
538 let w0 = 1.0 / ((m * m) as f64);
539 weights[0] = w0;
540 weights[m] = w0;
541 for k in 1..=((m - 1) / 2) {
542 let denom = (4 * k * k - 1) as f64;
543 for j in 1..m {
544 v[j - 1] -= 2.0 * (2.0 * (k as f64) * theta[j]).cos() / denom;
545 }
546 }
547 }
548
549 for j in 1..m {
550 weights[j] = 2.0 * v[j - 1] / (m as f64);
551 }
552
553 for j in 0..=(m / 2) {
557 let jj = m - j;
558 let avg = 0.5 * (weights[j] + weights[jj]);
559 weights[j] = avg;
560 weights[jj] = avg;
561 }
562 let weight_sum: f64 = weights.iter().sum();
563 if weight_sum.is_finite() && weight_sum != 0.0 {
564 let scale = 2.0 / weight_sum;
565 for w in &mut weights {
566 *w *= scale;
567 }
568 }
569
570 ClenshawCurtisRule { nodes, weights }
571}
572
573fn cloglog_cc_required_nodes(mu: f64, sigma: f64, tol: f64) -> Result<usize, EstimationError> {
574 if !(mu.is_finite() && sigma.is_finite() && sigma > 0.0 && tol.is_finite() && tol > 0.0) {
575 crate::bail_invalid_estim!(
576 "CC cloglog backend requires finite mu, positive sigma, and positive tolerance"
577 .to_string(),
578 );
579 }
580
581 let p_tail = (tol / 8.0).clamp(1e-300, 0.25);
586 let a = gam_math::probability::standard_normal_quantile(p_tail)
587 .map(|z| -z)
588 .unwrap_or(8.0)
589 .max(1.0);
590
591 let ay = a * sigma;
592 let y = if ay > 0.0 {
593 1.0_f64.min(std::f64::consts::PI / (4.0 * ay))
594 } else {
595 1.0
596 };
597 let rho = y + (1.0 + y * y).sqrt();
598 let m_s = (0.5 * (a * y) * (a * y)).exp() / (2.0 * std::f64::consts::PI).sqrt();
599 let eps_quad = (tol / 4.0).max(1e-300);
600 let numer = ((8.0 * a * m_s) / ((rho - 1.0).max(1e-12) * eps_quad)).max(1.0);
601 let denom = rho.ln();
602 if !denom.is_finite() || denom <= 0.0 {
603 crate::bail_invalid_estim!("CC cloglog backend ellipse bound became degenerate");
604 }
605
606 let mut n = (1.0 + numer.ln() / denom).ceil() as usize;
607 n = n.max(CLOGLOG_CC_MIN_N);
608 if n.is_multiple_of(2) {
609 n += 1;
610 }
611 Ok(n)
612}
613
614#[inline]
615fn cloglog_should_prefer_cc(mu: f64, sigma: f64, tol: f64) -> bool {
616 match cloglog_cc_required_nodes(mu, sigma, tol) {
622 Ok(n) => n <= CLOGLOG_CC_PREFER_THRESHOLD,
623 Err(_) => false,
624 }
625}
626
627fn compute_gauss_hermite() -> GaussHermiteRule {
640 let mut diag = [0.0f64; N_POINTS]; let mut off_diag = [0.0f64; N_POINTS - 1];
646
647 for i in 0..(N_POINTS - 1) {
648 off_diag[i] = (((i + 1) as f64) / 2.0).sqrt();
650 }
651
652 let (eigenvalues, eigenvectors) = symmetric_tridiagonal_eigen(&mut diag, &mut off_diag);
655
656 let nodes = eigenvalues;
658 let mut weights = [0.0f64; N_POINTS];
659
660 let mu0 = std::f64::consts::PI.sqrt();
666 for i in 0..N_POINTS {
667 let v0 = eigenvectors[i][0];
668 weights[i] = mu0 * v0 * v0;
669 }
670
671 let mut indices: [usize; N_POINTS] = [0, 1, 2, 3, 4, 5, 6];
673 indices.sort_by(|&a, &b| nodes[a].total_cmp(&nodes[b]));
674
675 let sorted_nodes: [f64; N_POINTS] = std::array::from_fn(|i| nodes[indices[i]]);
676 let sortedweights: [f64; N_POINTS] = std::array::from_fn(|i| weights[indices[i]]);
677
678 GaussHermiteRule {
679 nodes: sorted_nodes,
680 weights: sortedweights,
681 }
682}
683
684pub(crate) fn compute_gauss_hermite_n(n: usize) -> GaussHermiteRuleDynamic {
685 let mut diag = vec![0.0f64; n];
686 let mut off_diag = vec![0.0f64; n.saturating_sub(1)];
687 for (i, od) in off_diag.iter_mut().enumerate() {
688 *od = (((i + 1) as f64) / 2.0).sqrt();
689 }
690 let (nodes, eigenvectors) = symmetric_tridiagonal_eigen_dynamic(&mut diag, &mut off_diag);
691 let mu0 = std::f64::consts::PI.sqrt();
692 let mut pairs = (0..n)
693 .map(|i| {
694 let v0 = eigenvectors[i][0];
695 (nodes[i], mu0 * v0 * v0)
696 })
697 .collect::<Vec<_>>();
698 pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
699 GaussHermiteRuleDynamic {
700 nodes: pairs.iter().map(|p| p.0).collect(),
701 weights: pairs.iter().map(|p| p.1).collect(),
702 }
703}
704
705fn symmetric_tridiagonal_eigen(
709 diag: &mut [f64; N_POINTS],
710 off_diag: &mut [f64; N_POINTS - 1],
711) -> ([f64; N_POINTS], [[f64; N_POINTS]; N_POINTS]) {
712 let mut diag_vec = diag.to_vec();
713 let mut off_diag_vec = off_diag.to_vec();
714 let (eigenvalues, eigenvectors) =
715 symmetric_tridiagonal_eigen_dynamic(&mut diag_vec, &mut off_diag_vec);
716
717 let mut values = [0.0; N_POINTS];
718 let mut vectors = [[0.0; N_POINTS]; N_POINTS];
719 values.copy_from_slice(&eigenvalues);
720 for i in 0..N_POINTS {
721 vectors[i].copy_from_slice(&eigenvectors[i]);
722 }
723 diag.copy_from_slice(&values);
724 off_diag.copy_from_slice(&off_diag_vec);
725 (values, vectors)
726}
727
728fn symmetric_tridiagonal_eigen_dynamic(
729 diag: &mut [f64],
730 off_diag: &mut [f64],
731) -> (Vec<f64>, Vec<Vec<f64>>) {
732 let dim = diag.len();
733 let mut z = vec![vec![0.0_f64; dim]; dim];
734 for (i, row) in z.iter_mut().enumerate().take(dim) {
735 row[i] = 1.0;
736 }
737 const DEFLATION_TOL: f64 = 1e-15;
743 const MAX_QL_SWEEPS: usize = 200;
744 let eps = DEFLATION_TOL;
745 let max_iter = MAX_QL_SWEEPS;
746 let mut t_norm = 0.0_f64;
752 for i in 0..dim {
753 let left = if i > 0 { off_diag[i - 1].abs() } else { 0.0 };
754 let right = if i + 1 < dim { off_diag[i].abs() } else { 0.0 };
755 let row_sum = diag[i].abs() + left + right;
756 if row_sum > t_norm {
757 t_norm = row_sum;
758 }
759 }
760 let mut n = dim;
761 while n > 1 {
762 let mut converged = false;
763 for _ in 0..max_iter {
764 let mut m = n - 1;
765 while m > 0 {
766 let row_scale = (diag[m - 1].abs() + diag[m].abs()).max(t_norm);
767 if off_diag[m - 1].abs() <= eps * row_scale {
768 off_diag[m - 1] = 0.0;
769 break;
770 }
771 m -= 1;
772 }
773 if m == n - 1 {
774 n -= 1;
775 converged = true;
776 break;
777 }
778 let shift = wilkinson_shift(diag[n - 2], diag[n - 1], off_diag[n - 2]);
779 let mut x = diag[m] - shift;
780 let mut y = off_diag[m];
781 for k in m..(n - 1) {
782 let (c, s) = if y.abs() > eps {
783 let r = x.hypot(y);
784 if r > 0.0 && r.is_finite() {
785 (x / r, -y / r)
786 } else {
787 (1.0, 0.0)
788 }
789 } else {
790 (1.0, 0.0)
791 };
792 if k > m {
793 off_diag[k - 1] = x.hypot(y);
794 }
795 let d1 = diag[k];
796 let d2 = diag[k + 1];
797 let e_k = off_diag[k];
798 diag[k] = c * c * d1 + s * s * d2 - 2.0 * c * s * e_k;
799 diag[k + 1] = s * s * d1 + c * c * d2 + 2.0 * c * s * e_k;
800 off_diag[k] = c * s * (d1 - d2) + (c * c - s * s) * e_k;
801 if k < n - 2 {
802 x = off_diag[k];
803 y = -s * off_diag[k + 1];
804 off_diag[k + 1] *= c;
805 }
806 for i in 0..dim {
807 let t = z[k][i];
808 z[k][i] = c * t - s * z[k + 1][i];
809 z[k + 1][i] = s * t + c * z[k + 1][i];
810 }
811 }
812 }
813 if !converged {
814 off_diag[n - 2] = 0.0;
815 n -= 1;
816 }
817 }
818 (diag.to_vec(), z)
819}
820
/// Wilkinson shift for the symmetric 2x2 block `[[a, b], [b, c]]`.
///
/// Evaluates `c - b^2 / (d + sign(d) * hypot(d, b))` with `d = (a - c) / 2`. When the
/// denominator is negligible relative to `max(hypot(d, b), 1)`, `c - hypot(d, b)` is
/// returned instead.
#[inline]
fn wilkinson_shift(a: f64, c: f64, b: f64) -> f64 {
    let half_gap = (a - c) * 0.5;
    let radius = half_gap.hypot(b);
    // Give the radius the sign of the half gap so the two terms cannot cancel.
    let signed_radius = if half_gap >= 0.0 { radius } else { -radius };
    let denom = half_gap + signed_radius;

    let denom_usable = denom.abs() > f64::EPSILON * radius.max(1.0);
    if denom_usable {
        c - (b * b) / denom
    } else {
        c - radius
    }
}
835
836#[inline]
847pub fn logit_posterior_mean(ctx: &QuadratureContext, eta: f64, se_eta: f64) -> f64 {
848 match logit_posterior_meanwith_deriv_controlled(eta, se_eta) {
849 Ok(out) => out.mean,
850 Err(_) => integrate_normal_ghq_adaptive(ctx, eta, se_eta, sigmoid),
851 }
852}
853
854#[inline]
862pub fn logit_posterior_meanwith_deriv(
863 eta: f64,
864 se_eta: f64,
865) -> Result<(f64, f64), EstimationError> {
866 let out = logit_posterior_meanwith_deriv_controlled(eta, se_eta)?;
877 Ok((out.mean, out.dmean_dmu))
878}
879
880#[inline]
881pub fn probit_posterior_meanwith_deriv_exact(mu: f64, sigma: f64) -> IntegratedMeanDerivative {
882 if !(mu.is_finite() && sigma.is_finite()) || sigma <= 1e-12 {
910 let mean = gam_math::probability::normal_cdf(mu);
911 let dmean_dmu = gam_math::probability::normal_pdf(mu);
912 return IntegratedMeanDerivative {
913 mean,
914 dmean_dmu,
915 mode: IntegratedExpectationMode::ExactClosedForm,
916 };
917 }
918 let denom = (1.0 + sigma * sigma).sqrt();
919 let z = mu / denom;
920 IntegratedMeanDerivative {
921 mean: gam_math::probability::normal_cdf(z),
922 dmean_dmu: gam_math::probability::normal_pdf(z) / denom,
923 mode: IntegratedExpectationMode::ExactClosedForm,
924 }
925}
926
927#[inline]
928fn logistic_normal_exact_eligible(mu: f64, sigma: f64) -> bool {
929 mu.is_finite()
930 && sigma.is_finite()
931 && mu.abs() <= LOGIT_ERFCX_MU_MAX
932 && (LOGIT_SIGMA_TAYLOR_MAX..=LOGIT_ERFCX_SIGMA_MAX).contains(&sigma)
933}
934
935#[inline]
979fn logistic_normal_series_cutoff(mu: f64, sigma: f64, target_accuracy: f64) -> Option<usize> {
980 assert!(sigma > 0.0);
981 assert!(target_accuracy > 0.0);
982 let m = mu.abs();
983 let s = sigma;
984 let gauss = (-(m * m) / (2.0 * s * s)).exp();
985 let coeff_mean = m * (2.0_f64 / std::f64::consts::PI).sqrt() * gauss / (s * s * s);
986 let coeff_deriv =
987 2.0 * gauss * (m * m - s * s).abs() / ((2.0 * std::f64::consts::PI).sqrt() * s.powi(5));
988 let asymptotic_index = |coeff: f64| -> f64 {
992 if !coeff.is_finite() || coeff <= target_accuracy {
993 0.0
994 } else {
995 (coeff / target_accuracy).sqrt() - 1.0
996 }
997 };
998 let peak_floor = m / (s * s) + 1.0;
1002 let required = asymptotic_index(coeff_mean)
1003 .max(asymptotic_index(coeff_deriv))
1004 .max(peak_floor);
1005 if !required.is_finite() || required > LOGIT_MAX_TERMS as f64 {
1006 return None;
1007 }
1008 Some((required.ceil() as usize).max(4))
1011}
1012
1013#[inline]
1014fn stable_sigmoidwith_derivative(x: f64) -> (f64, f64) {
1015 let x_clamped = x.clamp(-QUADRATURE_EXP_LOG_MAX, QUADRATURE_EXP_LOG_MAX);
1016 if x_clamped != x {
1017 return (sigmoid(x), 0.0);
1018 }
1019 if x_clamped >= 0.0 {
1020 let z = (-x_clamped).exp();
1021 let denom = 1.0 + z;
1022 (1.0 / denom, z / (denom * denom))
1023 } else {
1024 let z = x_clamped.exp();
1025 let denom = 1.0 + z;
1026 (z / denom, z / (denom * denom))
1027 }
1028}
1029
1030#[inline]
1031fn logit_small_sigma_taylor(mu: f64, sigma: f64) -> IntegratedMeanDerivative {
1032 let (mean0, d1, d2, d3) = component_point_jet(LinkComponent::Logit, mu);
1040 let s2 = sigma * sigma;
1041 IntegratedMeanDerivative {
1042 mean: (mean0 + 0.5 * s2 * d2).clamp(0.0, 1.0),
1043 dmean_dmu: (d1 + 0.5 * s2 * d3).max(0.0),
1044 mode: IntegratedExpectationMode::ControlledAsymptotic,
1045 }
1046}
1047
1048#[inline]
1049fn logit_tail_asymptotic(mu: f64, sigma: f64) -> Option<IntegratedMeanDerivative> {
1050 if mu <= 0.0 {
1055 let log_mean = mu + 0.5 * sigma * sigma;
1056 if log_mean <= LOGIT_TAIL_LOG_MAX {
1057 let mean = safe_exp(log_mean);
1058 return Some(IntegratedMeanDerivative {
1059 mean,
1060 dmean_dmu: mean,
1061 mode: IntegratedExpectationMode::ControlledAsymptotic,
1062 });
1063 }
1064 } else {
1065 let log_tail = -mu + 0.5 * sigma * sigma;
1066 if log_tail <= LOGIT_TAIL_LOG_MAX {
1067 let tail = safe_exp(log_tail);
1068 return Some(IntegratedMeanDerivative {
1069 mean: 1.0 - tail,
1070 dmean_dmu: tail,
1071 mode: IntegratedExpectationMode::ControlledAsymptotic,
1072 });
1073 }
1074 }
1075 None
1076}
1077
1078#[inline]
1079fn scaled_erfcx_termwith_derivative(m: f64, s: f64, x: f64, dxdm: f64) -> (f64, f64) {
1080 let pref = 0.5 * (-(m * m) / (2.0 * s * s)).exp();
1081 if x >= 0.0 {
1082 let ex = erfcx_nonnegative(x);
1083 let term = pref * ex;
1084 let ex_prime = 2.0 * x * ex - std::f64::consts::FRAC_2_SQRT_PI;
1085 let dterm = pref * ((-m / (s * s)) * ex + ex_prime * dxdm);
1086 (term, dterm)
1087 } else {
1088 let lead = (x * x - (m * m) / (2.0 * s * s)).exp();
1089 let dlead = lead * (2.0 * x * dxdm - m / (s * s));
1090 let (rest, drest) = scaled_erfcx_termwith_derivative(m, s, -x, -dxdm);
1091 (lead - rest, dlead - drest)
1092 }
1093}
1094
1095pub(crate) fn logit_posterior_meanwith_deriv_exact(
1096 mu: f64,
1097 sigma: f64,
1098) -> Result<IntegratedMeanDerivative, EstimationError> {
1099 if !(mu.is_finite() && sigma.is_finite()) {
1117 crate::bail_invalid_estim!("logit exact expectation requires finite mu and sigma");
1118 }
1119 if sigma <= LOGIT_SIGMA_DEGENERATE {
1120 let (mean, dmean_dmu) = stable_sigmoidwith_derivative(mu);
1121 return Ok(IntegratedMeanDerivative {
1122 mean,
1123 dmean_dmu,
1124 mode: IntegratedExpectationMode::ExactClosedForm,
1125 });
1126 }
1127 if let Some(out) = logit_tail_asymptotic(mu, sigma) {
1128 return Ok(out);
1129 }
1130 if sigma < LOGIT_SIGMA_TAYLOR_MAX {
1131 return Ok(logit_small_sigma_taylor(mu, sigma));
1132 }
1133 if logistic_normal_exact_eligible(mu, sigma)
1134 && let Ok(out) = logit_posterior_meanwith_deriv_exact_erfcx(mu, sigma)
1135 {
1136 return Ok(out);
1137 }
1138 Err(EstimationError::InvalidInput(
1147 "logit analytic expectation has no certified representation in this regime".to_string(),
1148 ))
1149}
1150
1151fn logit_posterior_meanwith_deriv_exact_erfcx(
1152 mu: f64,
1153 sigma: f64,
1154) -> Result<IntegratedMeanDerivative, EstimationError> {
1155 let m = mu.abs();
1181 let s = sigma;
1182 let z = SQRT_2 * s;
1183 let phi_term = gam_math::probability::normal_cdf(m / s);
1184 let phi_prime = gam_math::probability::normal_pdf(m / s) / s;
1185 let Some(max_k) = logistic_normal_series_cutoff(mu, sigma, LOGIT_ERFCX_ACCURACY_TARGET) else {
1186 crate::bail_invalid_estim!(
1187 "logit erfcx series truncation bound exceeds LOGIT_MAX_TERMS at the required accuracy"
1188 .to_string(),
1189 );
1190 };
1191
1192 let mut sum = 0.0_f64;
1193 let mut dsum = 0.0_f64;
1194 let mut k = 1usize;
1202 while k <= max_k {
1203 for kk in [k, k + 1].into_iter().filter(|kk| *kk <= max_k) {
1204 let kf = kk as f64;
1205 let a = (kf * s * s + m) / z;
1206 let b = (kf * s * s - m) / z;
1207 let sign = if kk % 2 == 1 { 1.0 } else { -1.0 };
1208 let (va, dva) = scaled_erfcx_termwith_derivative(m, s, a, 1.0 / z);
1209 let (vb, dvb) = scaled_erfcx_termwith_derivative(m, s, b, -1.0 / z);
1210 sum += sign * (va - vb);
1211 dsum += sign * (dva - dvb);
1212 }
1213 k += 2;
1214 }
1215
1216 let mut mean = phi_term + sum;
1217 let dmean = (phi_prime + dsum).max(0.0);
1218 if mu < 0.0 {
1219 mean = 1.0 - mean;
1220 }
1221 if !(mean.is_finite() && dmean.is_finite() && dmean >= 0.0) {
1222 crate::bail_invalid_estim!("logit erfcx expectation produced non-finite values");
1223 }
1224 Ok(IntegratedMeanDerivative {
1225 mean,
1226 dmean_dmu: dmean,
1227 mode: IntegratedExpectationMode::ExactSpecialFunction,
1228 })
1229}
1230
1231#[inline]
1236fn logit_posterior_meanwith_deriv_quadrature(mu: f64, sigma: f64) -> IntegratedMeanDerivative {
1237 let mean = integrate_normal_adaptive(mu, sigma, |x| stable_sigmoidwith_derivative(x).0);
1238 let dmean_dmu =
1239 integrate_normal_adaptive(mu, sigma, |x| stable_sigmoidwith_derivative(x).1).max(0.0);
1240 IntegratedMeanDerivative {
1241 mean,
1242 dmean_dmu,
1243 mode: IntegratedExpectationMode::QuadratureFallback,
1244 }
1245}
1246
1247#[inline]
1248fn logit_posterior_meanwith_deriv_controlled(
1249 mu: f64,
1250 sigma: f64,
1251) -> Result<IntegratedMeanDerivative, EstimationError> {
1252 if !(mu.is_finite() && sigma.is_finite()) {
1253 crate::bail_invalid_estim!("logit integrated moments require finite mu and sigma");
1254 }
1255 let candidate = match logit_posterior_meanwith_deriv_exact(mu, sigma) {
1256 Ok(out) => out,
1257 Err(_) => return Ok(logit_posterior_meanwith_deriv_quadrature(mu, sigma)),
1258 };
1259 match candidate.mode {
1271 IntegratedExpectationMode::ExactSpecialFunction
1272 | IntegratedExpectationMode::ControlledAsymptotic => {
1273 let reference = logit_posterior_meanwith_deriv_quadrature(mu, sigma);
1274 if integrated_mean_derivative_drift_exceeds(
1275 &candidate, &reference, 1e-6, 1e-4, 1e-7, 1e-3,
1276 ) {
1277 Ok(reference)
1278 } else {
1279 Ok(candidate)
1280 }
1281 }
1282 _ => Ok(candidate),
1283 }
1284}
1285
1286#[inline]
1287fn log_normal_cdf_stable(x: f64) -> f64 {
1288 if !x.is_finite() {
1289 return if x.is_sign_negative() {
1290 f64::NEG_INFINITY
1291 } else {
1292 0.0
1293 };
1294 }
1295 if x < -8.0 {
1296 let u = -x / SQRT_2;
1297 -u * u + (0.5 * erfcx_nonnegative(u)).ln()
1298 } else {
1299 gam_math::probability::normal_cdf(x).max(1e-300).ln()
1300 }
1301}
1302
1303#[inline]
1304fn cloglog_extreme_asymptotic(mu: f64, sigma: f64) -> Option<IntegratedMeanDerivative> {
1305 let rare_log = mu + 0.5 * sigma * sigma;
1315 if rare_log <= CLOGLOG_RARE_EVENT_LOG_MAX {
1316 let mean = safe_exp(rare_log);
1317 return Some(IntegratedMeanDerivative {
1318 mean,
1319 dmean_dmu: mean,
1320 mode: IntegratedExpectationMode::ControlledAsymptotic,
1321 });
1322 }
1323 if mu - CLOGLOG_POSITIVE_SATURATION_SIGMAS * sigma >= CLOGLOG_POSITIVE_SATURATION_EDGE {
1324 return Some(IntegratedMeanDerivative {
1325 mean: 1.0,
1326 dmean_dmu: 0.0,
1327 mode: IntegratedExpectationMode::ControlledAsymptotic,
1328 });
1329 }
1330 None
1336}
1337
1338#[inline]
1339fn cloglog_survival_extreme_asymptotic(
1340 mu: f64,
1341 sigma: f64,
1342) -> Option<(f64, IntegratedExpectationMode)> {
1343 let rare_log = mu + 0.5 * sigma * sigma;
1344 if rare_log <= CLOGLOG_RARE_EVENT_LOG_MAX {
1345 let mean = safe_exp(rare_log);
1346 return Some((
1347 (1.0 - mean).clamp(0.0, 1.0),
1348 IntegratedExpectationMode::ControlledAsymptotic,
1349 ));
1350 }
1351 if mu - CLOGLOG_POSITIVE_SATURATION_SIGMAS * sigma >= CLOGLOG_POSITIVE_SATURATION_EDGE {
1352 return Some((0.0, IntegratedExpectationMode::ControlledAsymptotic));
1357 }
1358 None
1362}
1363
1364#[inline]
1372fn cloglog_gumbel_quad_nodes(sigma: f64) -> usize {
1373 let target = (CLOGLOG_GUMBEL_QUAD_NODE_SCALE / sigma.min(CLOGLOG_LARGE_SIGMA_ASYMPTOTIC_MIN))
1374 .ceil() as usize;
1375 let n = target
1376 .max(CLOGLOG_GUMBEL_QUAD_MIN_NODES)
1377 .min(CLOGLOG_GUMBEL_QUAD_MAX_NODES);
1378 if n % 2 == 0 { n + 1 } else { n }
1379}
1380
1381fn cloglog_log_survival_gumbel_quadrature(ctx: &QuadratureContext, mu: f64, sigma: f64) -> f64 {
1404 let a = CLOGLOG_GUMBEL_QUAD_ETA_LO;
1405 let b = CLOGLOG_GUMBEL_QUAD_ETA_HI;
1406 let half = 0.5 * (b - a);
1407 let mid = 0.5 * (a + b);
1408 let rule = ctx.clenshaw_curtis_n(cloglog_gumbel_quad_nodes(sigma));
1409 let mut running_max = f64::NEG_INFINITY;
1412 let mut running_sum = 0.0_f64;
1413 for (&node, &weight) in rule.nodes.iter().zip(rule.weights.iter()) {
1414 let eta = half * node + mid;
1415 let summand = (weight * half).ln()
1416 + (eta - safe_exp(eta))
1417 + log_normal_cdf_stable((eta - mu) / sigma);
1418 if !summand.is_finite() {
1419 continue;
1420 }
1421 if summand > running_max {
1422 running_sum = running_sum * (running_max - summand).exp() + 1.0;
1423 running_max = summand;
1424 } else {
1425 running_sum += (summand - running_max).exp();
1426 }
1427 }
1428 if running_max == f64::NEG_INFINITY {
1429 f64::NEG_INFINITY
1430 } else {
1431 running_max + running_sum.ln()
1432 }
1433}
1434
1435pub(crate) fn cloglog_log_survival_term_controlled(
1444 ctx: &QuadratureContext,
1445 mu: f64,
1446 sigma: f64,
1447) -> (f64, IntegratedExpectationMode) {
1448 if !(mu.is_finite() && sigma.is_finite()) || sigma <= CLOGLOG_SIGMA_DEGENERATE {
1449 return (-safe_exp(mu), IntegratedExpectationMode::ExactClosedForm);
1451 }
1452 let rare_log = mu + 0.5 * sigma * sigma;
1453 if rare_log <= CLOGLOG_RARE_EVENT_LOG_MAX {
1454 return (
1457 (-safe_exp(rare_log)).ln_1p(),
1458 IntegratedExpectationMode::ControlledAsymptotic,
1459 );
1460 }
1461 if sigma >= CLOGLOG_LARGE_SIGMA_ASYMPTOTIC_MIN {
1462 return (
1463 cloglog_log_survival_gumbel_quadrature(ctx, mu, sigma),
1464 IntegratedExpectationMode::ControlledAsymptotic,
1465 );
1466 }
1467 let (value, mode) = cloglog_survival_term_controlled(ctx, mu, sigma);
1468 if value > 0.0 {
1469 (value.ln(), mode)
1470 } else {
1471 (
1474 cloglog_log_survival_gumbel_quadrature(ctx, mu, sigma),
1475 IntegratedExpectationMode::QuadratureFallback,
1476 )
1477 }
1478}
1479
1480#[inline]
1499fn gumbel_survival(x: f64) -> f64 {
1500 (-safe_exp(x)).exp()
1501}
1502
1503#[inline]
1508fn cloglog_mean_d1_exact(x: f64) -> f64 {
1509 let ex = safe_exp(x);
1510 if ex.is_infinite() {
1511 0.0
1512 } else {
1513 ex * (-ex).exp()
1514 }
1515}
1516
1517#[inline]
1527fn cloglog_mean_exact(x: f64) -> f64 {
1528 cloglog_negative_tail_mean(x)
1529}
1530
1531#[inline]
1547fn cloglog_negative_tail_mean(eta: f64) -> f64 {
1548 if eta < -745.0 {
1552 0.0
1554 } else {
1555 let ex = safe_exp(eta);
1558 -(-ex).exp_m1()
1559 }
1560}
1561
1562#[inline]
1567fn cloglog_small_sigma_taylor(mu: f64, sigma: f64) -> IntegratedMeanDerivative {
1568 if sigma <= CLOGLOG_SIGMA_DEGENERATE {
1592 return IntegratedMeanDerivative {
1593 mean: cloglog_mean_exact(mu),
1594 dmean_dmu: cloglog_mean_d1_exact(mu),
1595 mode: IntegratedExpectationMode::ExactClosedForm,
1596 };
1597 }
1598
1599 let ex = safe_exp(mu);
1600 if !ex.is_finite() {
1601 return IntegratedMeanDerivative {
1603 mean: 1.0,
1604 dmean_dmu: 0.0,
1605 mode: IntegratedExpectationMode::ControlledAsymptotic,
1606 };
1607 }
1608 let surv = (-ex).exp();
1609 if surv == 0.0 {
1610 return IntegratedMeanDerivative {
1612 mean: 1.0,
1613 dmean_dmu: 0.0,
1614 mode: IntegratedExpectationMode::ControlledAsymptotic,
1615 };
1616 }
1617
1618 let s2 = sigma * sigma;
1619 let s4 = s2 * s2;
1620 let s6 = s4 * s2;
1621 let s8 = s4 * s4;
1622 let e2x = ex * ex;
1623 let e3x = e2x * ex;
1624 let e4x = e3x * ex;
1625 let e5x = e4x * ex;
1626 let e6x = e5x * ex;
1627 let e7x = e6x * ex;
1628 let e8x = e7x * ex;
1629 let e9x = e8x * ex;
1630 let f0 = -(-ex).exp_m1();
1644 let f1 = ex * surv;
1645 let f2 = surv * (ex - e2x);
1646 let f3 = surv * (ex - 3.0 * e2x + e3x);
1647 let f4 = surv * (ex - 7.0 * e2x + 6.0 * e3x - e4x);
1648 let f5 = surv * (ex - 15.0 * e2x + 25.0 * e3x - 10.0 * e4x + e5x);
1649 let f6 = surv * (ex - 31.0 * e2x + 90.0 * e3x - 65.0 * e4x + 15.0 * e5x - e6x);
1650 let f7 = surv * (ex - 63.0 * e2x + 301.0 * e3x - 350.0 * e4x + 140.0 * e5x - 21.0 * e6x + e7x);
1651 let f8 = surv
1652 * (ex - 127.0 * e2x + 966.0 * e3x - 1701.0 * e4x + 1050.0 * e5x - 266.0 * e6x + 28.0 * e7x
1653 - e8x);
1654 let f9 = surv
1655 * (ex - 255.0 * e2x + 3025.0 * e3x - 7770.0 * e4x + 6951.0 * e5x - 2646.0 * e6x
1656 + 462.0 * e7x
1657 - 36.0 * e8x
1658 + e9x);
1659 IntegratedMeanDerivative {
1663 mean: f0 + 0.5 * s2 * f2 + (s4 / 8.0) * f4 + (s6 / 48.0) * f6 + (s8 / 384.0) * f8,
1664 dmean_dmu: (f1 + 0.5 * s2 * f3 + (s4 / 8.0) * f5 + (s6 / 48.0) * f7 + (s8 / 384.0) * f9)
1665 .max(0.0),
1666 mode: IntegratedExpectationMode::ControlledAsymptotic,
1667 }
1668}
1669
/// One recursive refinement step of adaptive Simpson integration of `g` over `[a, b]`.
///
/// `fa`, `fb` and `fm` are the values of `g` at `a`, `b` and the midpoint, and `whole` is
/// the Simpson estimate built from them. The interval is halved and the two half-interval
/// estimates are compared with `whole`. If they agree within `15 * tol`, or `depth` has
/// reached zero, the refined estimate plus the Richardson correction `(refined - whole) / 15`
/// is returned. Otherwise both halves are refined with half the tolerance and one level
/// less depth.
#[inline]
fn adaptive_simpson_refine(
    g: &impl Fn(f64) -> f64,
    a: f64,
    b: f64,
    fa: f64,
    fb: f64,
    fm: f64,
    whole: f64,
    tol: f64,
    depth: i32,
) -> f64 {
    let mid = 0.5 * (a + b);
    let left_mid = 0.5 * (a + mid);
    let right_mid = 0.5 * (mid + b);
    let f_left_mid = g(left_mid);
    let f_right_mid = g(right_mid);

    let left_panel = (mid - a) / 6.0 * (fa + 4.0 * f_left_mid + fm);
    let right_panel = (b - mid) / 6.0 * (fm + 4.0 * f_right_mid + fb);
    let refined = left_panel + right_panel;

    let accept = depth <= 0 || (refined - whole).abs() <= 15.0 * tol;
    if accept {
        refined + (refined - whole) / 15.0
    } else {
        let half_tol = 0.5 * tol;
        let next_depth = depth - 1;
        let left =
            adaptive_simpson_refine(g, a, mid, fa, fm, f_left_mid, left_panel, half_tol, next_depth);
        let right =
            adaptive_simpson_refine(g, mid, b, fm, fb, f_right_mid, right_panel, half_tol, next_depth);
        left + right
    }
}
1700
1701fn integrate_normal_adaptive(mu: f64, sigma: f64, f: impl Fn(f64) -> f64) -> f64 {
1716 if !(sigma.is_finite()) || sigma < 1e-10 {
1717 return f(mu);
1718 }
1719 const K: f64 = 15.0;
1720 const INITIAL_PANELS: usize = 24;
1721 const TOL: f64 = 1e-12;
1722 const MAX_DEPTH: i32 = 40;
1723 let inv_sqrt_2pi = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
1724 let g = |u: f64| f(mu + sigma * u) * inv_sqrt_2pi * (-0.5 * u * u).exp();
1728 let panel = 2.0 * K / INITIAL_PANELS as f64;
1729 let mut total = 0.0;
1730 for p in 0..INITIAL_PANELS {
1731 let a = -K + p as f64 * panel;
1732 let b = a + panel;
1733 let fa = g(a);
1734 let fb = g(b);
1735 let fm = g(0.5 * (a + b));
1736 let whole = (b - a) / 6.0 * (fa + 4.0 * fm + fb);
1737 total += adaptive_simpson_refine(&g, a, b, fa, fb, fm, whole, TOL, MAX_DEPTH);
1738 }
1739 total
1740}
1741
1742fn cloglog_posterior_meanwith_deriv_quadrature(mu: f64, sigma: f64) -> IntegratedMeanDerivative {
1743 if sigma < 1e-10 {
1744 return IntegratedMeanDerivative {
1745 mean: cloglog_mean_exact(mu),
1746 dmean_dmu: cloglog_mean_d1_exact(mu),
1747 mode: IntegratedExpectationMode::ExactClosedForm,
1748 };
1749 }
1750 let mean = cloglog_mean_from_survival(survival_posterior_mean_quadrature(mu, sigma));
1751 let dmean_dmu = integrate_normal_adaptive(mu, sigma, cloglog_mean_d1_exact).max(0.0);
1752 IntegratedMeanDerivative {
1753 mean,
1754 dmean_dmu,
1755 mode: IntegratedExpectationMode::QuadratureFallback,
1756 }
1757}
1758
1759#[inline]
1760fn survival_posterior_mean_quadrature(eta: f64, se_eta: f64) -> f64 {
1761 integrate_normal_adaptive(eta, se_eta, gumbel_survival).clamp(0.0, 1.0)
1762}
1763
1764fn cloglog_survival_term_controlled(
1765 ctx: &QuadratureContext,
1766 mu: f64,
1767 sigma: f64,
1768) -> (f64, IntegratedExpectationMode) {
1769 if !(mu.is_finite() && sigma.is_finite()) || sigma <= CLOGLOG_SIGMA_DEGENERATE {
1806 return (
1807 gumbel_survival(mu).clamp(0.0, 1.0),
1808 IntegratedExpectationMode::ExactClosedForm,
1809 );
1810 }
1811 if sigma < CLOGLOG_SIGMA_TAYLOR_MAX {
1812 let mean = cloglog_small_sigma_taylor(mu, sigma).mean;
1813 return (
1814 (1.0 - mean).clamp(0.0, 1.0),
1815 IntegratedExpectationMode::ControlledAsymptotic,
1816 );
1817 }
1818 if let Some(out) = cloglog_survival_extreme_asymptotic(mu, sigma) {
1819 return out;
1820 }
1821 if sigma >= CLOGLOG_LARGE_SIGMA_ASYMPTOTIC_MIN {
1822 let log_s = cloglog_log_survival_gumbel_quadrature(ctx, mu, sigma);
1829 return (
1830 safe_exp(log_s).clamp(0.0, 1.0),
1831 IntegratedExpectationMode::ControlledAsymptotic,
1832 );
1833 }
1834 if cloglog_survival_miles_is_reliable(mu, sigma)
1835 && let Ok(out) = cloglog_survival_miles(mu, sigma)
1836 {
1837 return (
1838 out.clamp(0.0, 1.0),
1839 IntegratedExpectationMode::ExactSpecialFunction,
1840 );
1841 }
1842 if cloglog_should_prefer_cc(mu, sigma, CLOGLOG_CC_TOL)
1843 && let Ok(out) = cloglog_survival_cc(ctx, mu, sigma, CLOGLOG_CC_TOL)
1844 {
1845 return (
1846 out.clamp(0.0, 1.0),
1847 IntegratedExpectationMode::ExactSpecialFunction,
1848 );
1849 }
1850 if let Ok(out) = cloglog_survival_gamma_reference(mu, sigma) {
1851 return (
1852 out.clamp(0.0, 1.0),
1853 IntegratedExpectationMode::ExactSpecialFunction,
1854 );
1855 }
1856 (
1857 survival_posterior_mean_quadrature(mu, sigma),
1858 IntegratedExpectationMode::QuadratureFallback,
1859 )
1860}
1861
1862#[inline]
1863fn lognormal_laplace_term_controlled(
1864 ctx: &QuadratureContext,
1865 z: f64,
1866 mu: f64,
1867 sigma: f64,
1868) -> (f64, IntegratedExpectationMode) {
1869 if !(z.is_finite() && z > 0.0) {
1895 return (f64::NAN, IntegratedExpectationMode::QuadratureFallback);
1896 }
1897 lognormal_laplace_unit_term_shared(ctx, mu + z.ln(), sigma)
1898}
1899
1900#[inline]
1901pub(crate) fn lognormal_laplace_unit_term_shared(
1902 ctx: &QuadratureContext,
1903 shifted_mu: f64,
1904 sigma: f64,
1905) -> (f64, IntegratedExpectationMode) {
1906 cloglog_survival_term_controlled(ctx, shifted_mu, sigma)
1907}
1908
1909#[inline]
1913pub fn lognormal_laplace_unit_log_term_shared(
1914 ctx: &QuadratureContext,
1915 shifted_mu: f64,
1916 sigma: f64,
1917) -> (f64, IntegratedExpectationMode) {
1918 cloglog_log_survival_term_controlled(ctx, shifted_mu, sigma)
1919}
1920
1921#[inline]
1922fn cloglog_survivalsecond_moment_controlled(
1923 ctx: &QuadratureContext,
1924 mu: f64,
1925 sigma: f64,
1926) -> (f64, IntegratedExpectationMode) {
1927 lognormal_laplace_term_controlled(ctx, 2.0, mu, sigma)
1942}
1943
1944#[inline]
1945fn cloglog_survival_pair_controlled(
1946 ctx: &QuadratureContext,
1947 mu: f64,
1948 sigma: f64,
1949) -> (
1950 (f64, IntegratedExpectationMode),
1951 (f64, IntegratedExpectationMode),
1952) {
1953 let shiftedmu = mu + sigma * sigma;
1954
1955 if cloglog_survival_miles_is_reliable(mu, sigma)
1966 && cloglog_survival_miles_is_reliable(shiftedmu, sigma)
1967 && let (Ok(base), Ok(shifted)) = (
1968 cloglog_survival_miles(mu, sigma),
1969 cloglog_survival_miles(shiftedmu, sigma),
1970 )
1971 {
1972 return (
1973 (
1974 base.clamp(0.0, 1.0),
1975 IntegratedExpectationMode::ExactSpecialFunction,
1976 ),
1977 (
1978 shifted.clamp(0.0, 1.0),
1979 IntegratedExpectationMode::ExactSpecialFunction,
1980 ),
1981 );
1982 }
1983
1984 if cloglog_should_prefer_cc(mu, sigma, CLOGLOG_CC_TOL)
1985 && cloglog_should_prefer_cc(shiftedmu, sigma, CLOGLOG_CC_TOL)
1986 && let (Ok(base), Ok(shifted)) = (
1987 cloglog_survival_cc(ctx, mu, sigma, CLOGLOG_CC_TOL),
1988 cloglog_survival_cc(ctx, shiftedmu, sigma, CLOGLOG_CC_TOL),
1989 )
1990 {
1991 return (
1992 (
1993 base.clamp(0.0, 1.0),
1994 IntegratedExpectationMode::ExactSpecialFunction,
1995 ),
1996 (
1997 shifted.clamp(0.0, 1.0),
1998 IntegratedExpectationMode::ExactSpecialFunction,
1999 ),
2000 );
2001 }
2002
2003 if let (Ok(base), Ok(shifted)) = (
2004 cloglog_survival_gamma_reference(mu, sigma),
2005 cloglog_survival_gamma_reference(shiftedmu, sigma),
2006 ) {
2007 return (
2008 (
2009 base.clamp(0.0, 1.0),
2010 IntegratedExpectationMode::ExactSpecialFunction,
2011 ),
2012 (
2013 shifted.clamp(0.0, 1.0),
2014 IntegratedExpectationMode::ExactSpecialFunction,
2015 ),
2016 );
2017 }
2018
2019 (
2020 cloglog_survival_term_controlled(ctx, mu, sigma),
2021 cloglog_survival_term_controlled(ctx, shiftedmu, sigma),
2022 )
2023}
2024
/// Converts a survival probability into the complementary mean `1 - survival`.
///
/// The input is first clamped into `[0, 1]`. For survival above `0.5` the
/// complement is formed as `-expm1(ln(survival))`; otherwise as `1 - survival`.
/// A NaN input yields NaN.
#[inline]
fn cloglog_mean_from_survival(survival: f64) -> f64 {
    let clamped = survival.clamp(0.0, 1.0);
    // Negated `>` keeps NaN on the plain-subtraction branch.
    if !(clamped > 0.5) {
        return 1.0 - clamped;
    }
    let log_survival = clamped.ln();
    -log_survival.exp_m1()
}
2043
2044#[inline]
2045fn cloglog_shift_identity_derivative(mu: f64, sigma: f64, shifted_survival: f64) -> f64 {
2046 if !(mu.is_finite() && sigma.is_finite()) || shifted_survival <= 0.0 {
2060 return 0.0;
2061 }
2062 cloglog_shift_identity_derivative_log(mu, sigma, shifted_survival.ln())
2063}
2064
2065#[inline]
2075fn cloglog_shift_identity_derivative_log(mu: f64, sigma: f64, log_shifted_survival: f64) -> f64 {
2076 if !(mu.is_finite() && sigma.is_finite()) || log_shifted_survival == f64::NEG_INFINITY {
2077 return 0.0;
2078 }
2079 let log_derivative = mu + 0.5 * sigma * sigma + log_shifted_survival;
2080 let upper = 1.0 / std::f64::consts::E;
2081 if !log_derivative.is_finite() {
2082 return upper;
2085 }
2086 safe_exp(log_derivative).clamp(0.0, upper)
2087}
2088
2089#[inline]
2090fn log_half_erfc_stable(u: f64) -> f64 {
2091 if u > 0.0 {
2101 -u * u + (0.5 * erfcx_nonnegative(u)).ln()
2102 } else {
2103 (0.5 * erfc(u)).ln()
2104 }
2105}
2106
2107#[inline]
2147fn cloglog_survival_miles_is_reliable(mu: f64, sigma: f64) -> bool {
2148 if !(mu.is_finite() && sigma.is_finite() && sigma > 0.0) {
2149 return false;
2150 }
2151 let alpha_ln = CLOGLOG_MILES_ALPHA.ln();
2152 let shifted = mu - alpha_ln;
2153 let peak_log = CLOGLOG_MILES_ALPHA - 0.5 * shifted * shifted / (sigma * sigma);
2154 peak_log.is_finite() && peak_log <= CLOGLOG_MILES_PEAK_LOG_MAX
2155}
2156
/// Evaluates the cloglog survival term with the alternating "Miles" series.
///
/// Term `n` has sign `(-1)^n` and log-magnitude
/// `n*mu + sigma^2 n^2 / 2 - ln Gamma(n + 1) + ln(erfc(u_n) / 2)` with
/// `u_n = (mu - ln ALPHA + sigma^2 n) / (sqrt(2) sigma)` and
/// `ALPHA = CLOGLOG_MILES_ALPHA`.
///
/// # Errors
/// Fails when any term's log-magnitude exceeds `QUADRATURE_EXP_LOG_MAX`, when the
/// series does not settle within `CLOGLOG_MILES_MAX_TERMS` terms, or when it
/// settles on a value outside `[-1e-10, 1 + 1e-10]`.
fn cloglog_survival_miles(mu: f64, sigma: f64) -> Result<f64, EstimationError> {
    let alpha_ln = CLOGLOG_MILES_ALPHA.ln();
    let mut s_sum = 0.0_f64;
    let mut stable_pairs = 0usize;

    // Terms are consumed two at a time (one even, one odd index) so that
    // convergence is judged on the net contribution of each alternating pair.
    for pair_start in (0..CLOGLOG_MILES_MAX_TERMS).step_by(2) {
        let mut pair_s = 0.0_f64;
        for n in pair_start..(pair_start + 2).min(CLOGLOG_MILES_MAX_TERMS) {
            let nf = n as f64;
            let sign = if n % 2 == 0 { 1.0 } else { -1.0 };
            let base_log = nf * mu + 0.5 * sigma * sigma * nf * nf
                - statrs::function::gamma::ln_gamma(nf + 1.0);
            let u = (mu - alpha_ln + sigma * sigma * nf) / (SQRT_2 * sigma);
            let log_half_erfc = log_half_erfc_stable(u);
            let term_log = base_log + log_half_erfc;
            // Refuse to continue once a single term would leave the safe exp range
            // (the macro presumably returns an error early — see its definition).
            if term_log > QUADRATURE_EXP_LOG_MAX {
                crate::bail_invalid_estim!("Miles cloglog series term exceeded finite exp range");
            }
            let term = sign * safe_exp(term_log);
            pair_s += term;
        }
        s_sum += pair_s;

        // Relative tolerance against the running sum, with an absolute floor of 1.
        let s_scale = s_sum.abs().max(1.0);
        if pair_s.abs() <= 2e-15 * s_scale {
            stable_pairs += 1;
            if stable_pairs >= SERIES_CONSECUTIVE_SMALL_TERMS {
                // Accept only a finite sum that is a probability up to 1e-10 slack.
                if s_sum.is_finite() && (-1e-10..=1.0 + 1e-10).contains(&s_sum) {
                    return Ok(s_sum.clamp(0.0, 1.0));
                }
                // Settled on an implausible value: fall through to the error below.
                break;
            }
        } else {
            // A non-negligible pair resets the consecutive-small counter.
            stable_pairs = 0;
        }
    }

    Err(EstimationError::InvalidInput(
        "Miles cloglog series did not converge safely".to_string(),
    ))
}
2227
2228fn cloglog_survival_cc(
2229 ctx: &QuadratureContext,
2230 mu: f64,
2231 sigma: f64,
2232 tol: f64,
2233) -> Result<f64, EstimationError> {
2234 if !(mu.is_finite() && sigma.is_finite() && sigma > 0.0 && tol.is_finite() && tol > 0.0) {
2235 crate::bail_invalid_estim!(
2236 "CC cloglog backend requires finite mu, positive sigma, and positive tolerance"
2237 .to_string(),
2238 );
2239 }
2240
2241 let p_tail = (tol / 8.0).clamp(1e-300, 0.25);
2279 let a = gam_math::probability::standard_normal_quantile(p_tail)
2280 .map(|z| -z)
2281 .unwrap_or(8.0)
2282 .max(1.0);
2283 let n = cloglog_cc_required_nodes(mu, sigma, tol)?;
2284 if n > CLOGLOG_CC_NODE_CAP {
2285 crate::bail_invalid_estim!("CC cloglog backend requires too many nodes");
2286 }
2287
2288 let rule = ctx.clenshaw_curtis_n(n);
2289 let inv_sqrt_2pi = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
2290 let mut sum = 0.0_f64;
2291 let mut c = 0.0_f64;
2292 for (&x, &w) in rule.nodes.iter().zip(rule.weights.iter()) {
2293 let t = a * x;
2294 let u = mu + sigma * t;
2295 let e = safe_exp(u);
2296 let w0 = (-0.5 * t * t).exp() * inv_sqrt_2pi;
2297 let yk = w * w0 * (-e).exp() - c;
2298 let tk = sum + yk;
2299 c = (tk - sum) - yk;
2300 sum = tk;
2301 }
2302
2303 let survival = (a * sum).clamp(0.0, 1.0);
2304 if !survival.is_finite() {
2305 crate::bail_invalid_estim!("CC cloglog backend produced non-finite values");
2306 }
2307 Ok(survival)
2308}
2309
2310#[inline]
2311fn complex_add(a: Complex, b: Complex) -> Complex {
2312 Complex {
2313 re: a.re + b.re,
2314 im: a.im + b.im,
2315 }
2316}
2317
2318#[inline]
2319fn complex_sub(a: Complex, b: Complex) -> Complex {
2320 Complex {
2321 re: a.re - b.re,
2322 im: a.im - b.im,
2323 }
2324}
2325
2326#[inline]
2327fn complexmul(a: Complex, b: Complex) -> Complex {
2328 Complex {
2329 re: a.re * b.re - a.im * b.im,
2330 im: a.re * b.im + a.im * b.re,
2331 }
2332}
2333
2334#[inline]
2335fn complex_div(a: Complex, b: Complex) -> Complex {
2336 let den = (b.re * b.re + b.im * b.im).max(1e-300);
2337 Complex {
2338 re: (a.re * b.re + a.im * b.im) / den,
2339 im: (a.im * b.re - a.re * b.im) / den,
2340 }
2341}
2342
2343#[inline]
2344fn complex_abs(z: Complex) -> f64 {
2345 z.re.hypot(z.im)
2346}
2347
2348#[inline]
2349fn complex_ln(z: Complex) -> Complex {
2350 Complex {
2351 re: complex_abs(z).ln(),
2352 im: z.im.atan2(z.re),
2353 }
2354}
2355
2356#[inline]
2357fn complex_exp(z: Complex) -> Complex {
2358 let e = z.re.exp();
2359 Complex {
2360 re: e * z.im.cos(),
2361 im: e * z.im.sin(),
2362 }
2363}
2364
2365#[inline]
2366fn complex_sin(z: Complex) -> Complex {
2367 Complex {
2368 re: z.re.sin() * z.im.cosh(),
2369 im: z.re.cos() * z.im.sinh(),
2370 }
2371}
2372
2373fn complex_log_gamma_lanczos(z: Complex) -> Complex {
2374 const G: f64 = 7.0;
2378 const COEFFS: [f64; 9] = [
2379 0.999_999_999_999_809_9,
2380 676.520_368_121_885_1,
2381 -1_259.139_216_722_402_8,
2382 771.323_428_777_653_1,
2383 -176.615_029_162_140_6,
2384 12.507_343_278_686_905,
2385 -0.138_571_095_265_720_12,
2386 9.984_369_578_019_572e-6,
2387 1.505_632_735_149_311_6e-7,
2388 ];
2389
2390 if z.re < 0.5 {
2391 let piz = Complex {
2392 re: std::f64::consts::PI * z.re,
2393 im: std::f64::consts::PI * z.im,
2394 };
2395 let one_minusz = Complex {
2396 re: 1.0 - z.re,
2397 im: -z.im,
2398 };
2399 return complex_sub(
2400 complex_sub(
2401 Complex {
2402 re: std::f64::consts::PI.ln(),
2403 im: 0.0,
2404 },
2405 complex_ln(complex_sin(piz)),
2406 ),
2407 complex_log_gamma_lanczos(one_minusz),
2408 );
2409 }
2410
2411 let z1 = Complex {
2412 re: z.re - 1.0,
2413 im: z.im,
2414 };
2415 let mut x = Complex {
2416 re: COEFFS[0],
2417 im: 0.0,
2418 };
2419 for (i, c) in COEFFS.iter().enumerate().skip(1) {
2420 x = complex_add(
2421 x,
2422 complex_div(
2423 Complex { re: *c, im: 0.0 },
2424 Complex {
2425 re: z1.re + i as f64,
2426 im: z1.im,
2427 },
2428 ),
2429 );
2430 }
2431 let t = Complex {
2432 re: z1.re + G + 0.5,
2433 im: z1.im,
2434 };
2435 complex_add(
2436 complex_add(
2437 Complex {
2438 re: 0.5 * (2.0 * std::f64::consts::PI).ln(),
2439 im: 0.0,
2440 },
2441 complexmul(
2442 Complex {
2443 re: z1.re + 0.5,
2444 im: z1.im,
2445 },
2446 complex_ln(t),
2447 ),
2448 ),
2449 complex_sub(complex_ln(x), t),
2450 )
2451}
2452
/// Reference evaluation of the cloglog survival term through a complex-gamma
/// line integral.
///
/// With `z = K + i t` (`K = CLOGLOG_GAMMA_K_REF`), integrates
/// `Re[exp(ln Gamma(z) + sigma^2 z^2 / 2 - mu z)]` over `t` in
/// `[0, CLOGLOG_GAMMA_T_MAX_REF]` using composite Simpson, divides by `pi`, and
/// clamps the result into `[0, 1]`.
///
/// # Errors
/// Fails for non-finite `mu`/`sigma`, non-positive `sigma`, or a non-finite result.
fn cloglog_survival_gamma_reference(mu: f64, sigma: f64) -> Result<f64, EstimationError> {
    if !(mu.is_finite() && sigma.is_finite()) || sigma <= 0.0 {
        crate::bail_invalid_estim!(
            "Gamma cloglog reference backend requires finite mu and positive sigma"
        );
    }

    // Panel count from the reference step size, rounded up to an even number as
    // composite Simpson requires; `h` is the step actually used.
    let n = (CLOGLOG_GAMMA_T_MAX_REF / CLOGLOG_GAMMA_H_REF).round() as usize;
    let n = if n.is_multiple_of(2) { n } else { n + 1 };
    let h = CLOGLOG_GAMMA_T_MAX_REF / n as f64;

    // Real part of the integrand at height `t` on the vertical line Re(z) = K.
    let eval = |t: f64| -> f64 {
        let z = Complex {
            re: CLOGLOG_GAMMA_K_REF,
            im: t,
        };
        let log_gamma = complex_log_gamma_lanczos(z);
        let z_sq = complexmul(z, z);
        // exponent = ln Gamma(z) + (sigma^2 / 2) z^2 - mu z
        let exponent = complex_sub(
            complex_add(
                log_gamma,
                Complex {
                    re: 0.5 * sigma * sigma * z_sq.re,
                    im: 0.5 * sigma * sigma * z_sq.im,
                },
            ),
            Complex {
                re: mu * z.re,
                im: mu * z.im,
            },
        );
        complex_exp(exponent).re
    };

    // Composite Simpson: endpoints weight 1, odd interior nodes 4, even interior 2.
    let f0 = eval(0.0);
    let fn_ = eval(CLOGLOG_GAMMA_T_MAX_REF);
    let mut sum_s = f0 + fn_;
    for i in 1..n {
        let t = i as f64 * h;
        let fi = eval(t);
        let w = if i % 2 == 0 { 2.0 } else { 4.0 };
        sum_s += w * fi;
    }
    // NaN survives `clamp`, so the finiteness check below still catches it.
    let sval = ((h / 3.0) * sum_s / std::f64::consts::PI).clamp(0.0, 1.0);
    if !sval.is_finite() {
        crate::bail_invalid_estim!("Gamma cloglog reference backend produced non-finite values");
    }
    Ok(sval)
}
2538
2539pub(crate) fn cloglog_posterior_meanwith_deriv_controlled(
2540 ctx: &QuadratureContext,
2541 mu: f64,
2542 sigma: f64,
2543) -> IntegratedMeanDerivative {
2544 if !(mu.is_finite() && sigma.is_finite()) || sigma <= CLOGLOG_SIGMA_DEGENERATE {
2589 return IntegratedMeanDerivative {
2590 mean: cloglog_mean_exact(mu),
2593 dmean_dmu: cloglog_mean_d1_exact(mu),
2596 mode: IntegratedExpectationMode::ExactClosedForm,
2597 };
2598 }
2599 if sigma >= CLOGLOG_LARGE_SIGMA_ASYMPTOTIC_MIN {
2600 let (log_base, base_mode) = cloglog_log_survival_term_controlled(ctx, mu, sigma);
2606 let (log_shift, shift_mode) =
2607 cloglog_log_survival_term_controlled(ctx, mu + sigma * sigma, sigma);
2608 let mean = (-log_base.exp_m1()).clamp(0.0, 1.0);
2610 let dmean = cloglog_shift_identity_derivative_log(mu, sigma, log_shift);
2611 return IntegratedMeanDerivative {
2612 mean,
2613 dmean_dmu: dmean.max(0.0),
2614 mode: worse_integrated_expectation_mode(base_mode, shift_mode),
2615 };
2616 }
2617 let candidate = if sigma < CLOGLOG_SIGMA_TAYLOR_MAX {
2618 cloglog_small_sigma_taylor(mu, sigma)
2619 } else if let Some(out) = cloglog_extreme_asymptotic(mu, sigma) {
2620 out
2621 } else {
2622 let ((survival, mode), (shifted_survival, shifted_mode)) =
2623 cloglog_survival_pair_controlled(ctx, mu, sigma);
2624 if matches!(mode, IntegratedExpectationMode::QuadratureFallback)
2625 || matches!(shifted_mode, IntegratedExpectationMode::QuadratureFallback)
2626 {
2627 return cloglog_posterior_meanwith_deriv_quadrature(mu, sigma);
2628 }
2629 let mean = cloglog_mean_from_survival(survival);
2630 let dmean = cloglog_shift_identity_derivative(mu, sigma, shifted_survival);
2631 let mode = if matches!(mode, IntegratedExpectationMode::ControlledAsymptotic)
2632 || matches!(
2633 shifted_mode,
2634 IntegratedExpectationMode::ControlledAsymptotic
2635 ) {
2636 IntegratedExpectationMode::ControlledAsymptotic
2637 } else {
2638 mode
2639 };
2640 IntegratedMeanDerivative {
2641 mean,
2642 dmean_dmu: dmean.max(0.0),
2643 mode,
2644 }
2645 };
2646 if matches!(
2652 candidate.mode,
2653 IntegratedExpectationMode::ControlledAsymptotic
2654 ) && sigma >= CLOGLOG_LARGE_SIGMA_ASYMPTOTIC_MIN
2655 {
2656 return candidate;
2657 }
2658 let ghq = cloglog_posterior_meanwith_deriv_quadrature(mu, sigma);
2659 if integrated_mean_derivative_drift_exceeds(&candidate, &ghq, 1e-6, 1e-4, 1e-7, 1e-3) {
2666 ghq
2667 } else {
2668 candidate
2669 }
2670}
2671
2672pub fn integrated_inverse_link_mean_and_derivative(
2673 quadctx: &QuadratureContext,
2674 link: LinkFunction,
2675 mu: f64,
2676 sigma: f64,
2677) -> Result<IntegratedMeanDerivative, EstimationError> {
2678 match link {
2707 LinkFunction::Log => {
2708 let (mean, saturated) = safe_expwith_saturation(mu + 0.5 * sigma * sigma);
2709 Ok(IntegratedMeanDerivative {
2710 mean,
2711 dmean_dmu: mean,
2712 mode: if saturated {
2713 IntegratedExpectationMode::ControlledAsymptotic
2714 } else {
2715 IntegratedExpectationMode::ExactClosedForm
2716 },
2717 })
2718 }
2719 LinkFunction::Probit => Ok(probit_posterior_meanwith_deriv_exact(mu, sigma)),
2720 LinkFunction::Logit => logit_posterior_meanwith_deriv_controlled(mu, sigma),
2721 LinkFunction::CLogLog => Ok(cloglog_posterior_meanwith_deriv_controlled(quadctx, mu, sigma)),
2722 LinkFunction::Sas => Err(EstimationError::InvalidInput(
2723 "state-less integrated SAS moments are unsupported; use SAS-aware prediction APIs with explicit (epsilon, log_delta)".to_string(),
2724 )),
2725 LinkFunction::BetaLogistic => Err(EstimationError::InvalidInput(
2726 "state-less integrated Beta-Logistic moments are unsupported; use link-aware prediction APIs with explicit (delta, epsilon)".to_string(),
2727 )),
2728 LinkFunction::Identity => Ok(IntegratedMeanDerivative {
2729 mean: mu,
2730 dmean_dmu: 1.0,
2731 mode: IntegratedExpectationMode::ExactClosedForm,
2732 }),
2733 }
2734}
2735
2736#[inline]
2737pub fn integrated_inverse_link_jet(
2738 quadctx: &QuadratureContext,
2739 link: LinkFunction,
2740 mu: f64,
2741 sigma: f64,
2742) -> Result<IntegratedInverseLinkJet, EstimationError> {
2743 match link {
2744 LinkFunction::Log => {
2745 let (mean, saturated) = safe_expwith_saturation(mu + 0.5 * sigma * sigma);
2746 Ok(IntegratedInverseLinkJet {
2747 mean,
2748 d1: mean,
2749 d2: mean,
2750 d3: mean,
2751 mode: if saturated {
2752 IntegratedExpectationMode::ControlledAsymptotic
2753 } else {
2754 IntegratedExpectationMode::ExactClosedForm
2755 },
2756 })
2757 }
2758 LinkFunction::Probit => Ok(integrated_probit_jet(mu, sigma)),
2759 LinkFunction::Logit => {
2760 if sigma > LOGIT_JET_GHQ_SIGMA_MAX {
2761 return logit_wide_sigma_jet(mu, sigma);
2765 }
2766 let (mean, d1, d2, d3) = integrate_normal_ghq_adaptive(quadctx, mu, sigma, |x| {
2771 component_point_jet(LinkComponent::Logit, x)
2772 });
2773 let mode = if sigma <= 1e-10 {
2774 IntegratedExpectationMode::ExactClosedForm
2775 } else {
2776 match logit_posterior_meanwith_deriv_controlled(mu, sigma) {
2780 Ok(scalar) => scalar.mode,
2781 Err(_) => IntegratedExpectationMode::QuadratureFallback,
2782 }
2783 };
2784 Ok(IntegratedInverseLinkJet {
2785 mean,
2786 d1: d1.max(0.0),
2787 d2,
2788 d3,
2789 mode,
2790 })
2791 }
2792 LinkFunction::CLogLog => {
2793 validate_latent_cloglog_inputs(mu, sigma)?;
2794 Ok(integrated_cloglog_inverse_link_jet_controlled(
2795 quadctx, mu, sigma,
2796 ))
2797 }
2798 LinkFunction::Sas => Err(EstimationError::InvalidInput(
2799 "state-less integrated SAS jet is unsupported; use SAS-aware prediction APIs with explicit (epsilon, log_delta)".to_string(),
2800 )),
2801 LinkFunction::BetaLogistic => Err(EstimationError::InvalidInput(
2802 "state-less integrated Beta-Logistic jet is unsupported; use link-aware prediction APIs with explicit (delta, epsilon)".to_string(),
2803 )),
2804 LinkFunction::Identity => Ok(IntegratedInverseLinkJet {
2805 mean: mu,
2806 d1: 1.0,
2807 d2: 0.0,
2808 d3: 0.0,
2809 mode: IntegratedExpectationMode::ExactClosedForm,
2810 }),
2811 }
2812}
2813
2814#[inline]
2825fn logit_wide_sigma_jet(mu: f64, sigma: f64) -> Result<IntegratedInverseLinkJet, EstimationError> {
2826 let scalar = logit_posterior_meanwith_deriv_controlled(mu, sigma)?;
2827 let d2 = integrate_normal_adaptive(mu, sigma, |x| {
2828 component_point_jet(LinkComponent::Logit, x).2
2829 });
2830 let d3 = integrate_normal_adaptive(mu, sigma, |x| {
2831 component_point_jet(LinkComponent::Logit, x).3
2832 });
2833 Ok(IntegratedInverseLinkJet {
2834 mean: scalar.mean,
2835 d1: scalar.dmean_dmu.max(0.0),
2836 d2,
2837 d3,
2838 mode: scalar.mode,
2839 })
2840}
2841
2842#[inline]
2843pub fn integrated_logit_inverse_link_jet_pirls(
2844 quadctx: &QuadratureContext,
2845 mu: f64,
2846 sigma: f64,
2847) -> Result<IntegratedInverseLinkJet, EstimationError> {
2848 if sigma <= 1e-10 {
2853 let (mean, d1, d2, d3) = component_point_jet(LinkComponent::Logit, mu);
2854 return Ok(IntegratedInverseLinkJet {
2855 mean,
2856 d1,
2857 d2,
2858 d3,
2859 mode: IntegratedExpectationMode::ExactClosedForm,
2860 });
2861 }
2862 if sigma > LOGIT_JET_GHQ_SIGMA_MAX {
2863 return logit_wide_sigma_jet(mu, sigma);
2864 }
2865 let (mean, d1, d2, d3) = integrate_normal_ghq_adaptive(quadctx, mu, sigma, |x| {
2866 component_point_jet(LinkComponent::Logit, x)
2867 });
2868 let mode = match logit_posterior_meanwith_deriv_controlled(mu, sigma) {
2869 Ok(scalar) => scalar.mode,
2870 Err(_) => IntegratedExpectationMode::QuadratureFallback,
2871 };
2872 Ok(IntegratedInverseLinkJet {
2873 mean,
2874 d1: d1.max(0.0),
2875 d2,
2876 d3,
2877 mode,
2878 })
2879}
2880
2881#[inline]
2882fn sas_point_jet(x: f64, epsilon: f64, log_delta: f64) -> (f64, f64, f64, f64) {
2883 let jet = sas_inverse_link_jet(x, epsilon, log_delta);
2884 (jet.mu, jet.d1, jet.d2, jet.d3)
2885}
2886
2887#[inline]
2888fn beta_logistic_point_jet(x: f64, log_shape_center: f64, epsilon: f64) -> (f64, f64, f64, f64) {
2889 let jet = beta_logistic_inverse_link_jet(x, log_shape_center, epsilon);
2890 (jet.mu, jet.d1, jet.d2, jet.d3)
2891}
2892
2893#[inline]
2894fn worse_integrated_expectation_mode(
2895 lhs: IntegratedExpectationMode,
2896 rhs: IntegratedExpectationMode,
2897) -> IntegratedExpectationMode {
2898 if lhs.rank() >= rhs.rank() { lhs } else { rhs }
2899}
2900
/// Reports whether `candidate` deviates from `reference` by more than the mixed
/// tolerance `max(abs_tol, rel_tol * max(|reference|, |candidate|))`.
///
/// Any non-finite operand counts as exceeding the tolerance.
#[inline]
fn integrated_scalar_drift_exceeds(
    candidate: f64,
    reference: f64,
    abs_tol: f64,
    rel_tol: f64,
) -> bool {
    let both_finite = candidate.is_finite() && reference.is_finite();
    if !both_finite {
        return true;
    }
    let gap = (candidate - reference).abs();
    let magnitude = reference.abs().max(candidate.abs());
    let allowed = abs_tol.max(rel_tol * magnitude);
    gap > allowed
}
2913
2914#[inline]
2915fn integrated_mean_derivative_drift_exceeds(
2916 candidate: &IntegratedMeanDerivative,
2917 reference: &IntegratedMeanDerivative,
2918 mean_abs_tol: f64,
2919 mean_rel_tol: f64,
2920 deriv_abs_tol: f64,
2921 deriv_rel_tol: f64,
2922) -> bool {
2923 integrated_scalar_drift_exceeds(candidate.mean, reference.mean, mean_abs_tol, mean_rel_tol)
2924 || integrated_scalar_drift_exceeds(
2925 candidate.dmean_dmu,
2926 reference.dmean_dmu,
2927 deriv_abs_tol,
2928 deriv_rel_tol,
2929 )
2930}
2931
2932#[inline]
2933fn component_point_jet(component: LinkComponent, x: f64) -> (f64, f64, f64, f64) {
2934 let jet = component_inverse_link_jet(component, x);
2937 (jet.mu, jet.d1, jet.d2, jet.d3)
2938}
2939
2940#[inline]
2941fn integrated_mixture_component_jet(
2942 ctx: &QuadratureContext,
2943 component: LinkComponent,
2944 mu: f64,
2945 sigma: f64,
2946) -> IntegratedInverseLinkJet {
2947 match component {
2952 LinkComponent::Logit => integrated_inverse_link_jet(ctx, LinkFunction::Logit, mu, sigma)
2953 .unwrap_or_else(|_| integrated_logit_jet_ghq(ctx, mu, sigma)),
2954 LinkComponent::Probit => integrated_probit_jet(mu, sigma),
2955 LinkComponent::CLogLog => integrated_cloglog_inverse_link_jet_controlled(ctx, mu, sigma),
2956 LinkComponent::LogLog | LinkComponent::Cauchit => {
2957 let (mean, d1, d2, d3) = integrate_normal_ghq_adaptive(ctx, mu, sigma, |x| {
2958 component_point_jet(component, x)
2959 });
2960 IntegratedInverseLinkJet {
2961 mean,
2962 d1: d1.max(0.0),
2963 d2,
2964 d3,
2965 mode: if sigma <= 1e-10 {
2966 IntegratedExpectationMode::ExactClosedForm
2967 } else {
2968 IntegratedExpectationMode::QuadratureFallback
2969 },
2970 }
2971 }
2972 }
2973}
2974
2975#[inline]
2976fn integrated_mixture_jet(
2977 ctx: &QuadratureContext,
2978 mu: f64,
2979 sigma: f64,
2980 mixture_state: &MixtureLinkState,
2981) -> Result<IntegratedInverseLinkJet, EstimationError> {
2982 if mixture_state.components.is_empty() {
2987 crate::bail_invalid_estim!(
2988 "integrated mixture-link jet requires at least one blended component"
2989 );
2990 }
2991 if mixture_state.components.len() != mixture_state.pi.len() {
2992 crate::bail_invalid_estim!(
2993 "integrated mixture-link jet requires matching component and weight counts"
2994 );
2995 }
2996
2997 let mut mean = 0.0_f64;
3002 let mut d1 = 0.0_f64;
3003 let mut d2 = 0.0_f64;
3004 let mut d3 = 0.0_f64;
3005 let mut mode = IntegratedExpectationMode::ExactClosedForm;
3006 let mut saw_positive_weight = false;
3007
3008 for (&component, &weight) in mixture_state.components.iter().zip(mixture_state.pi.iter()) {
3009 if weight <= 0.0 {
3010 continue;
3011 }
3012 let jet = integrated_mixture_component_jet(ctx, component, mu, sigma);
3013 mean += weight * jet.mean;
3014 d1 += weight * jet.d1;
3015 d2 += weight * jet.d2;
3016 d3 += weight * jet.d3;
3017 if jet.mode.rank() > mode.rank() {
3018 mode = jet.mode;
3019 }
3020 saw_positive_weight = true;
3021 }
3022
3023 if !saw_positive_weight {
3024 crate::bail_invalid_estim!(
3025 "integrated mixture-link jet requires at least one positive component weight"
3026 .to_string(),
3027 );
3028 }
3029
3030 Ok(IntegratedInverseLinkJet {
3031 mean,
3032 d1: d1.max(0.0),
3033 d2,
3034 d3,
3035 mode,
3036 })
3037}
3038
3039#[inline]
3040fn integrated_sas_jet_ghq(
3041 ctx: &QuadratureContext,
3042 mu: f64,
3043 sigma: f64,
3044 sas_state: &SasLinkState,
3045) -> IntegratedInverseLinkJet {
3046 let (mean, d1, d2, d3) = integrate_normal_ghq_adaptive(ctx, mu, sigma, |x| {
3047 sas_point_jet(x, sas_state.epsilon, sas_state.log_delta)
3048 });
3049 IntegratedInverseLinkJet {
3050 mean,
3051 d1: d1.max(0.0),
3052 d2,
3053 d3,
3054 mode: if sigma <= 1e-10 {
3055 IntegratedExpectationMode::ExactClosedForm
3056 } else {
3057 IntegratedExpectationMode::QuadratureFallback
3058 },
3059 }
3060}
3061
3062#[inline]
3063fn integrated_beta_logistic_jet_ghq(
3064 ctx: &QuadratureContext,
3065 mu: f64,
3066 sigma: f64,
3067 beta_state: &SasLinkState,
3068) -> IntegratedInverseLinkJet {
3069 let (mean, d1, d2, d3) = integrate_normal_ghq_adaptive(ctx, mu, sigma, |x| {
3070 beta_logistic_point_jet(x, beta_state.log_delta, beta_state.epsilon)
3071 });
3072 IntegratedInverseLinkJet {
3073 mean,
3074 d1: d1.max(0.0),
3075 d2,
3076 d3,
3077 mode: if sigma <= 1e-10 {
3078 IntegratedExpectationMode::ExactClosedForm
3079 } else {
3080 IntegratedExpectationMode::QuadratureFallback
3081 },
3082 }
3083}
3084
3085#[inline]
3087pub fn integrated_inverse_link_jetwith_state(
3088 quadctx: &QuadratureContext,
3089 link: LinkFunction,
3090 mu: f64,
3091 sigma: f64,
3092 mixture_link_state: Option<&MixtureLinkState>,
3093 sas_link_state: Option<&SasLinkState>,
3094) -> Result<IntegratedInverseLinkJet, EstimationError> {
3095 if let Some(state) = mixture_link_state {
3096 return integrated_mixture_jet(quadctx, mu, sigma, state);
3097 }
3098 if matches!(link, LinkFunction::Sas) {
3099 let sas = sas_link_state.ok_or_else(|| {
3100 EstimationError::InvalidInput(
3101 "state-less integrated SAS jet is unsupported; explicit SasLinkState is required"
3102 .to_string(),
3103 )
3104 })?;
3105 return Ok(integrated_sas_jet_ghq(quadctx, mu, sigma, sas));
3106 }
3107 if matches!(link, LinkFunction::BetaLogistic) {
3108 let state = sas_link_state.ok_or_else(|| {
3109 EstimationError::InvalidInput(
3110 "state-less integrated Beta-Logistic jet is unsupported; explicit link state is required"
3111 .to_string(),
3112 )
3113 })?;
3114 return Ok(integrated_beta_logistic_jet_ghq(quadctx, mu, sigma, state));
3115 }
3116 integrated_inverse_link_jet(quadctx, link, mu, sigma)
3117}
3118
3119#[inline]
3134pub fn integrated_family_moments_jet(
3135 quadctx: &QuadratureContext,
3136 likelihood: &LikelihoodSpec,
3137 scale: LikelihoodScaleMetadata,
3138 eta: f64,
3139 se_eta: f64,
3140) -> Result<IntegratedMomentsJet, EstimationError> {
3141 const PROB_EPS: f64 = 1e-12;
3142 if !(eta.is_finite() && (-700.0..=700.0).contains(&eta)) {
3143 crate::bail_invalid_estim!(
3144 "integrated moments eta must be finite and within [-700, 700]; got {eta}"
3145 );
3146 }
3147 let e = eta;
3148 let se = se_eta.max(0.0);
3149 let mixture_link_state: Option<&MixtureLinkState> = likelihood.link.mixture_state();
3153 let sas_link_state: Option<&SasLinkState> = likelihood.link.sas_state();
3154 match &likelihood.response {
3155 ResponseFamily::Binomial => match &likelihood.link {
3156 InverseLink::Standard(StandardLink::Logit) => {
3157 let jet = integrated_inverse_link_jet(quadctx, LinkFunction::Logit, e, se)?;
3158 let mean = jet.mean;
3159 Ok(IntegratedMomentsJet {
3160 mean,
3161 variance: (mean * (1.0 - mean)).max(PROB_EPS),
3162 d1: jet.d1,
3163 d2: jet.d2,
3164 d3: jet.d3,
3165 mode: jet.mode,
3166 })
3167 }
3168 InverseLink::Standard(StandardLink::Probit) => {
3169 let jet = integrated_inverse_link_jet(quadctx, LinkFunction::Probit, e, se)?;
3170 let mean = jet.mean;
3171 Ok(IntegratedMomentsJet {
3172 mean,
3173 variance: (mean * (1.0 - mean)).max(PROB_EPS),
3174 d1: jet.d1,
3175 d2: jet.d2,
3176 d3: jet.d3,
3177 mode: jet.mode,
3178 })
3179 }
3180 InverseLink::Standard(StandardLink::CLogLog) => {
3181 let jet = integrated_inverse_link_jet(quadctx, LinkFunction::CLogLog, e, se)?;
3182 let mean = jet.mean;
3183 Ok(IntegratedMomentsJet {
3184 mean,
3185 variance: (mean * (1.0 - mean)).max(PROB_EPS),
3186 d1: jet.d1,
3187 d2: jet.d2,
3188 d3: jet.d3,
3189 mode: jet.mode,
3190 })
3191 }
3192 InverseLink::LatentCLogLog(_) => Err(EstimationError::InvalidInput(
3193 "Binomial+LatentCLogLog integrated moments require an explicit latent cloglog inverse-link state"
3194 .to_string(),
3195 )),
3196 InverseLink::Sas(_) => {
3197 let jet = integrated_inverse_link_jetwith_state(
3198 quadctx,
3199 LinkFunction::Sas,
3200 e,
3201 se,
3202 mixture_link_state,
3203 sas_link_state,
3204 )?;
3205 let mean = jet.mean;
3206 Ok(IntegratedMomentsJet {
3207 mean,
3208 variance: (mean * (1.0 - mean)).max(PROB_EPS),
3209 d1: jet.d1,
3210 d2: jet.d2,
3211 d3: jet.d3,
3212 mode: jet.mode,
3213 })
3214 }
3215 InverseLink::BetaLogistic(_) => {
3216 let jet = integrated_inverse_link_jetwith_state(
3217 quadctx,
3218 LinkFunction::BetaLogistic,
3219 e,
3220 se,
3221 mixture_link_state,
3222 sas_link_state,
3223 )?;
3224 let mean = jet.mean;
3225 Ok(IntegratedMomentsJet {
3226 mean,
3227 variance: (mean * (1.0 - mean)).max(PROB_EPS),
3228 d1: jet.d1,
3229 d2: jet.d2,
3230 d3: jet.d3,
3231 mode: jet.mode,
3232 })
3233 }
3234 InverseLink::Mixture(state) => {
3235 let jet = integrated_mixture_jet(quadctx, e, se, state)?;
3236 let mean = jet.mean;
3237 Ok(IntegratedMomentsJet {
3238 mean,
3239 variance: (mean * (1.0 - mean)).max(PROB_EPS),
3240 d1: jet.d1,
3241 d2: jet.d2,
3242 d3: jet.d3,
3243 mode: jet.mode,
3244 })
3245 }
3246 InverseLink::Standard(other) => Err(EstimationError::InvalidInput(format!(
3247 "Binomial response paired with unsupported standard link {other:?} for integrated moments"
3248 ))),
3249 },
3250 ResponseFamily::Gaussian => Ok(IntegratedMomentsJet {
3251 mean: e,
3252 variance: 1.0,
3253 d1: 1.0,
3254 d2: 0.0,
3255 d3: 0.0,
3256 mode: IntegratedExpectationMode::ExactClosedForm,
3257 }),
3258 ResponseFamily::RoystonParmar => {
3259 let jet = integrated_inverse_link_jetwith_state(
3260 quadctx,
3261 LinkFunction::CLogLog,
3262 e,
3263 se,
3264 mixture_link_state,
3265 sas_link_state,
3266 )?;
3267 let mean = (1.0 - jet.mean).clamp(0.0, 1.0);
3268 Ok(IntegratedMomentsJet {
3269 mean,
3270 variance: (mean * (1.0 - mean)).max(PROB_EPS),
3271 d1: -jet.d1,
3272 d2: -jet.d2,
3273 d3: -jet.d3,
3274 mode: jet.mode,
3275 })
3276 }
3277 ResponseFamily::Beta { phi } => {
3278 let jet = integrated_inverse_link_jet(quadctx, LinkFunction::Logit, e, se)?;
3279 let mean = jet.mean.clamp(PROB_EPS, 1.0 - PROB_EPS);
3280 Ok(IntegratedMomentsJet {
3281 mean,
3282 variance: (mean * (1.0 - mean) / (1.0 + phi.max(1e-12))).max(PROB_EPS),
3283 d1: jet.d1,
3284 d2: jet.d2,
3285 d3: jet.d3,
3286 mode: jet.mode,
3287 })
3288 }
3289 ResponseFamily::Poisson
3290 | ResponseFamily::Tweedie { .. }
3291 | ResponseFamily::NegativeBinomial { .. }
3292 | ResponseFamily::Gamma => {
3293 let s2 = se * se;
3298 let (mean, saturated) = safe_expwith_saturation(e + 0.5 * s2);
3299 let variance = match &likelihood.response {
3310 ResponseFamily::Poisson => mean,
3311 ResponseFamily::Tweedie { p } => {
3312 let phi = scale.fixed_phi().ok_or_else(|| {
3313 EstimationError::InvalidInput(format!(
3314 "Tweedie integrated variance requires dispersion φ in the scale \
3315 metadata (Var = φ·μ^p); got {scale:?} with no φ"
3316 ))
3317 })?;
3318 phi * mean.powf(*p)
3319 }
3320 ResponseFamily::NegativeBinomial { theta, .. } => {
3321 mean + mean * mean / theta.max(1e-12)
3322 }
3323 ResponseFamily::Gamma => {
3324 let shape = scale.gamma_shape().ok_or_else(|| {
3325 EstimationError::InvalidInput(format!(
3326 "Gamma integrated variance requires the shape k in the scale \
3327 metadata (Var = μ²/k = φ·μ²); got {scale:?} with no shape"
3328 ))
3329 })?;
3330 mean * mean / shape.max(1e-12)
3331 }
3332 other => {
3336 return Err(EstimationError::InvalidInput(format!(
3337 "integrated log-normal moments reached unexpected family {other:?}"
3338 )));
3339 }
3340 };
3341 Ok(IntegratedMomentsJet {
3342 mean,
3343 variance,
3344 d1: mean,
3345 d2: mean,
3346 d3: mean,
3347 mode: if saturated {
3348 IntegratedExpectationMode::ControlledAsymptotic
3349 } else {
3350 IntegratedExpectationMode::ExactClosedForm
3351 },
3352 })
3353 }
3354 }
3355}
3356
3357pub fn logit_posterior_meanwith_deriv_batch(
3360 ctx: &QuadratureContext,
3361 eta: &ndarray::Array1<f64>,
3362 se_eta: &ndarray::Array1<f64>,
3363) -> Result<(ndarray::Array1<f64>, ndarray::Array1<f64>), EstimationError> {
3364 use rayon::iter::{IntoParallelIterator, ParallelIterator};
3365 let n = eta.len();
3366 let pairs: Result<Vec<(f64, f64)>, _> = (0..n)
3368 .into_par_iter()
3369 .map(|i| {
3370 let integrated = integrated_inverse_link_mean_and_derivative(
3371 ctx,
3372 LinkFunction::Logit,
3373 eta[i],
3374 se_eta[i],
3375 )?;
3376 Ok::<_, EstimationError>((integrated.mean, integrated.dmean_dmu))
3377 })
3378 .collect();
3379 let pairs = pairs?;
3380 let mut mu = ndarray::Array1::<f64>::zeros(n);
3381 let mut dmu = ndarray::Array1::<f64>::zeros(n);
3382 for (i, (m, d)) in pairs.into_iter().enumerate() {
3383 mu[i] = m;
3384 dmu[i] = d;
3385 }
3386
3387 Ok((mu, dmu))
3388}
3389
3390pub fn logit_posterior_mean_batch(
3394 ctx: &QuadratureContext,
3395 eta: &ndarray::Array1<f64>,
3396 se_eta: &ndarray::Array1<f64>,
3397) -> Result<ndarray::Array1<f64>, EstimationError> {
3398 use rayon::iter::{IntoParallelIterator, ParallelIterator};
3399 let n = eta.len();
3400 let values: Result<Vec<f64>, EstimationError> = (0..n)
3401 .into_par_iter()
3402 .map(|i| {
3403 integrated_inverse_link_mean_and_derivative(ctx, LinkFunction::Logit, eta[i], se_eta[i])
3404 .map(|integrated| integrated.mean)
3405 })
3406 .collect();
3407 Ok(ndarray::Array1::from_vec(values?))
3408}
3409
/// Value type that can be accumulated by the quadrature routines in this file.
///
/// The integrators start from [`GhqValue::zero`], fold in one weighted
/// integrand evaluation per node with [`GhqValue::addweighted`], and finally
/// apply a normalization constant with [`GhqValue::scale`].
pub trait GhqValue: Sized {
    /// Returns the starting value of an accumulation.
    fn zero() -> Self;
    /// Folds `value`, weighted by `weight`, into `self`.
    fn addweighted(&mut self, weight: f64, value: Self);
    /// Consumes the value and returns it rescaled by `factor`.
    fn scale(self, factor: f64) -> Self;
}
3415
3416impl GhqValue for f64 {
3417 #[inline]
3418 fn zero() -> Self {
3419 0.0
3420 }
3421
3422 #[inline]
3423 fn addweighted(&mut self, weight: f64, value: Self) {
3424 *self += weight * value;
3425 }
3426
3427 #[inline]
3428 fn scale(self, factor: f64) -> Self {
3429 self * factor
3430 }
3431}
3432
3433impl GhqValue for (f64, f64) {
3434 #[inline]
3435 fn zero() -> Self {
3436 (0.0, 0.0)
3437 }
3438
3439 #[inline]
3440 fn addweighted(&mut self, weight: f64, value: Self) {
3441 self.0 += weight * value.0;
3442 self.1 += weight * value.1;
3443 }
3444
3445 #[inline]
3446 fn scale(self, factor: f64) -> Self {
3447 (self.0 * factor, self.1 * factor)
3448 }
3449}
3450
3451impl GhqValue for (f64, f64, f64, f64) {
3452 #[inline]
3453 fn zero() -> Self {
3454 (0.0, 0.0, 0.0, 0.0)
3455 }
3456
3457 #[inline]
3458 fn addweighted(&mut self, weight: f64, value: Self) {
3459 self.0 += weight * value.0;
3460 self.1 += weight * value.1;
3461 self.2 += weight * value.2;
3462 self.3 += weight * value.3;
3463 }
3464
3465 #[inline]
3466 fn scale(self, factor: f64) -> Self {
3467 (
3468 self.0 * factor,
3469 self.1 * factor,
3470 self.2 * factor,
3471 self.3 * factor,
3472 )
3473 }
3474}
3475
3476impl GhqValue for (f64, f64, f64, f64, f64, f64) {
3477 #[inline]
3478 fn zero() -> Self {
3479 (0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
3480 }
3481
3482 #[inline]
3483 fn addweighted(&mut self, weight: f64, value: Self) {
3484 self.0 += weight * value.0;
3485 self.1 += weight * value.1;
3486 self.2 += weight * value.2;
3487 self.3 += weight * value.3;
3488 self.4 += weight * value.4;
3489 self.5 += weight * value.5;
3490 }
3491
3492 #[inline]
3493 fn scale(self, factor: f64) -> Self {
3494 (
3495 self.0 * factor,
3496 self.1 * factor,
3497 self.2 * factor,
3498 self.3 * factor,
3499 self.4 * factor,
3500 self.5 * factor,
3501 )
3502 }
3503}
3504
3505#[inline]
3506fn integrate_normal_ghq_adaptive<F, R>(ctx: &QuadratureContext, eta: f64, se_eta: f64, f: F) -> R
3507where
3508 F: Fn(f64) -> R,
3509 R: GhqValue,
3510{
3511 if se_eta < 1e-10 {
3512 return f(eta);
3513 }
3514 let n = adaptive_point_count_from_sd(se_eta.abs());
3515 with_gh_nodesweights(ctx, n, |nodes, weights| {
3516 let scale = SQRT_2 * se_eta;
3517 let mut sum = R::zero();
3518 for i in 0..n {
3519 sum.addweighted(weights[i], f(eta + scale * nodes[i]));
3520 }
3521 sum.scale(1.0 / std::f64::consts::PI.sqrt())
3522 })
3523}
3524
3525#[inline]
3526fn integrated_probit_jet(mu: f64, sigma: f64) -> IntegratedInverseLinkJet {
3527 if sigma <= 1e-10 {
3528 let z = mu.clamp(-30.0, 30.0);
3529 let clamp_active = z != mu;
3530 let pdf = gam_math::probability::normal_pdf(z);
3531 return IntegratedInverseLinkJet {
3532 mean: gam_math::probability::normal_cdf(z),
3533 d1: if clamp_active { 0.0 } else { pdf },
3534 d2: if clamp_active { 0.0 } else { -z * pdf },
3535 d3: if clamp_active {
3536 0.0
3537 } else {
3538 (z * z - 1.0) * pdf
3539 },
3540 mode: IntegratedExpectationMode::ExactClosedForm,
3541 };
3542 }
3543 let s = (1.0 + sigma * sigma).sqrt();
3544 let z = mu / s;
3545 let pdf = gam_math::probability::normal_pdf(z);
3546 IntegratedInverseLinkJet {
3547 mean: gam_math::probability::normal_cdf(z),
3548 d1: pdf / s,
3549 d2: -z * pdf / (s * s),
3550 d3: (z * z - 1.0) * pdf / (s * s * s),
3551 mode: IntegratedExpectationMode::ExactClosedForm,
3552 }
3553}
3554
3555#[inline]
3556fn integrated_logit_jet_ghq(
3557 ctx: &QuadratureContext,
3558 mu: f64,
3559 sigma: f64,
3560) -> IntegratedInverseLinkJet {
3561 let (mean, d1, d2, d3) = integrate_normal_ghq_adaptive(ctx, mu, sigma, |x| {
3562 component_point_jet(LinkComponent::Logit, x)
3563 });
3564 IntegratedInverseLinkJet {
3565 mean,
3566 d1: d1.max(0.0),
3567 d2,
3568 d3,
3569 mode: if sigma <= 1e-10 {
3570 IntegratedExpectationMode::ExactClosedForm
3571 } else {
3572 IntegratedExpectationMode::QuadratureFallback
3573 },
3574 }
3575}
3576
3577#[inline]
3578fn cloglog_inverse_link_controlled_values(
3579 ctx: &QuadratureContext,
3580 mu: f64,
3581 sigma: f64,
3582 max_order: usize,
3583) -> ([f64; 6], IntegratedExpectationMode) {
3584 assert!(max_order <= 5);
3585 if sigma <= 1e-10 {
3586 let (mean, d1, d2, d3, d4, d5) = cloglog_point_jet5(mu);
3587 return (
3588 [mean, d1, d2, d3, d4, d5],
3589 IntegratedExpectationMode::ExactClosedForm,
3590 );
3591 }
3592
3593 let (k, log_k0, mode) = latent_cloglog_kernel_terms(ctx, mu, sigma, max_order);
3594 let mut values = [0.0; 6];
3595 values[0] = if log_k0.is_finite() {
3596 -log_k0.exp_m1()
3597 } else {
3598 1.0
3599 };
3600 values[1] = k[1].max(0.0);
3601 if sigma > CLOGLOG_JET_MOMENT_SIGMA_MAX {
3602 if max_order >= 2 {
3603 values[2] = integrate_normal_adaptive(mu, sigma, |x| cloglog_point_jet5(x).2);
3604 }
3605 if max_order >= 3 {
3606 values[3] = integrate_normal_adaptive(mu, sigma, |x| cloglog_point_jet5(x).3);
3607 }
3608 if max_order >= 4 {
3609 values[4] = integrate_normal_adaptive(mu, sigma, |x| cloglog_point_jet5(x).4);
3610 }
3611 if max_order >= 5 {
3612 values[5] = integrate_normal_adaptive(mu, sigma, |x| cloglog_point_jet5(x).5);
3613 }
3614 return (
3615 values,
3616 worse_integrated_expectation_mode(mode, IntegratedExpectationMode::QuadratureFallback),
3617 );
3618 }
3619 if max_order >= 2 {
3620 values[2] = k[1] - k[2];
3621 }
3622 if max_order >= 3 {
3623 values[3] = k[1] - 3.0 * k[2] + k[3];
3624 }
3625 if max_order >= 4 {
3626 values[4] = k[1] - 7.0 * k[2] + 6.0 * k[3] - k[4];
3627 }
3628 if max_order >= 5 {
3629 values[5] = k[1] - 15.0 * k[2] + 25.0 * k[3] - 10.0 * k[4] + k[5];
3630 }
3631 (values, mode)
3632}
3633
3634#[inline]
3635pub(crate) fn latent_cloglog_inverse_link_jet5_controlled(
3636 ctx: &QuadratureContext,
3637 mu: f64,
3638 sigma: f64,
3639) -> IntegratedInverseLinkJet5 {
3640 let (values, mode) = cloglog_inverse_link_controlled_values(ctx, mu, sigma, 5);
3641 IntegratedInverseLinkJet5 {
3642 mean: values[0],
3643 d1: values[1],
3644 d2: values[2],
3645 d3: values[3],
3646 d4: values[4],
3647 d5: values[5],
3648 mode,
3649 }
3650}
3651
/// Public result of [`latent_cloglog_jet5`]: the integrated cloglog mean, its
/// jet entries of order one through five, and the evaluation mode that
/// produced them.
#[derive(Clone, Copy, Debug)]
pub struct LatentCLogLogJet5 {
    /// Integrated mean.
    pub mean: f64,
    /// First-order jet entry.
    pub d1: f64,
    /// Second-order jet entry.
    pub d2: f64,
    /// Third-order jet entry.
    pub d3: f64,
    /// Fourth-order jet entry.
    pub d4: f64,
    /// Fifth-order jet entry.
    pub d5: f64,
    /// How the values were obtained (closed form, special function, ...).
    pub mode: IntegratedExpectationMode,
}
3671
3672pub fn latent_cloglog_jet5(
3673 quadctx: &QuadratureContext,
3674 eta: f64,
3675 sigma: f64,
3676) -> Result<LatentCLogLogJet5, EstimationError> {
3677 validate_latent_cloglog_inputs(eta, sigma)?;
3678 let jet = latent_cloglog_inverse_link_jet5_controlled(quadctx, eta, sigma);
3684 Ok(LatentCLogLogJet5 {
3685 mean: jet.mean,
3686 d1: jet.d1,
3687 d2: jet.d2,
3688 d3: jet.d3,
3689 d4: jet.d4,
3690 d5: jet.d5,
3691 mode: jet.mode,
3692 })
3693}
3694
3695#[inline]
3696pub fn latent_cloglog_inverse_link_jet(
3697 quadctx: &QuadratureContext,
3698 eta: f64,
3699 sigma: f64,
3700) -> Result<IntegratedInverseLinkJet, EstimationError> {
3701 let jet = latent_cloglog_jet5(quadctx, eta, sigma)?;
3702 Ok(IntegratedInverseLinkJet {
3703 mean: jet.mean,
3704 d1: jet.d1,
3705 d2: jet.d2,
3706 d3: jet.d3,
3707 mode: jet.mode,
3708 })
3709}
3710
3711#[inline]
3712fn integrated_cloglog_inverse_link_jet_controlled(
3713 ctx: &QuadratureContext,
3714 mu: f64,
3715 sigma: f64,
3716) -> IntegratedInverseLinkJet {
3717 let (values, mode) = cloglog_inverse_link_controlled_values(ctx, mu, sigma, 3);
3718 IntegratedInverseLinkJet {
3719 mean: values[0],
3720 d1: values[1],
3721 d2: values[2],
3722 d3: values[3],
3723 mode,
3724 }
3725}
3726
3727#[inline]
3728fn latent_cloglog_kernel_terms(
3729 ctx: &QuadratureContext,
3730 mu: f64,
3731 sigma: f64,
3732 max_order: usize,
3733) -> ([f64; 6], f64, IntegratedExpectationMode) {
3734 let sigma2 = sigma * sigma;
3735 let mut k = [0.0; 6];
3736 let mut log_k0 = f64::NEG_INFINITY;
3737 let mut mode = IntegratedExpectationMode::ExactClosedForm;
3738
3739 for (order, out) in k.iter_mut().enumerate().take(max_order + 1) {
3740 let kf = order as f64;
3741 let shifted_mu = mu + kf * sigma2;
3742 let (log_survival, term_mode) =
3750 cloglog_log_survival_term_controlled(ctx, shifted_mu, sigma);
3751 mode = worse_integrated_expectation_mode(mode, term_mode);
3752
3753 let log_value = kf * mu + 0.5 * kf * kf * sigma2 + log_survival;
3754 if order == 0 {
3755 log_k0 = log_value;
3756 }
3757 if !log_value.is_finite() {
3758 *out = 0.0;
3759 continue;
3760 }
3761 let upper = if order == 0 {
3762 1.0
3763 } else {
3764 let k_over_e = kf / std::f64::consts::E;
3765 k_over_e.powf(kf)
3766 };
3767 *out = safe_exp(log_value).clamp(0.0, upper);
3768 }
3769
3770 (k, log_k0, mode)
3771}
3772
3773#[inline]
3774pub fn normal_expectation_1d_adaptive<F>(
3775 ctx: &QuadratureContext,
3776 eta: f64,
3777 se_eta: f64,
3778 f: F,
3779) -> f64
3780where
3781 F: Fn(f64) -> f64,
3782{
3783 integrate_normal_ghq_adaptive(ctx, eta, se_eta, f)
3784}
3785
3786#[inline]
3787pub fn normal_expectation_1d_adaptive_pair<F>(
3788 ctx: &QuadratureContext,
3789 eta: f64,
3790 se_eta: f64,
3791 f: F,
3792) -> (f64, f64)
3793where
3794 F: Fn(f64) -> (f64, f64),
3795{
3796 integrate_normal_ghq_adaptive(ctx, eta, se_eta, f)
3797}
3798
/// Chooses the Gauss-Hermite node count from a standard deviation.
///
/// Thresholds (strictly greater than): `2.5 -> 51`, `0.5 -> 31`, `0.35 -> 21`,
/// `0.1 -> 15`; everything else, including NaN and infinite inputs, uses 7.
fn adaptive_point_count_from_sd(max_sd: f64) -> usize {
    // NaN and infinities fall back to the smallest rule.
    if !max_sd.is_finite() {
        return 7;
    }
    match max_sd {
        sd if sd > 2.5 => 51,
        sd if sd > 0.5 => 31,
        sd if sd > 0.35 => 21,
        sd if sd > 0.1 => 15,
        _ => 7,
    }
}
3821
3822#[inline]
3823fn with_gh_nodesweights<R>(
3824 ctx: &QuadratureContext,
3825 n: usize,
3826 f: impl FnOnce(&[f64], &[f64]) -> R,
3827) -> R {
3828 if n == 7 {
3829 let gh = ctx.gauss_hermite();
3830 f(&gh.nodes, &gh.weights)
3831 } else {
3832 let gh = ctx.gauss_hermite_n(n);
3833 f(&gh.nodes, &gh.weights)
3834 }
3835}
3836
/// Lower-triangular Cholesky factor of a fixed-size matrix.
///
/// Only the lower triangle of `cov` is read. Returns `None` as soon as a
/// diagonal pivot is non-finite or not strictly positive; entries above the
/// diagonal of the returned factor are zero.
#[inline]
fn cholesky_static<const D: usize>(cov: &[[f64; D]; D]) -> Option<[[f64; D]; D]> {
    let mut factor = [[0.0_f64; D]; D];
    for row in 0..D {
        for col in 0..=row {
            // cov[row][col] minus the contribution of already-computed columns.
            let residual = (0..col).fold(cov[row][col], |acc, k| {
                acc - factor[row][k] * factor[col][k]
            });
            if row != col {
                factor[row][col] = residual / factor[col][col];
                continue;
            }
            if !(residual.is_finite() && residual > 0.0) {
                return None;
            }
            factor[row][col] = residual.sqrt();
        }
    }
    Some(factor)
}
3867
3868#[inline]
3871fn cholesky_static_with_jitter<const D: usize>(cov: &[[f64; D]; D]) -> Option<[[f64; D]; D]> {
3872 if D == 0 {
3873 return None;
3874 }
3875 for retry in 0..8 {
3876 let jitter = if retry == 0 {
3877 0.0
3878 } else {
3879 1e-12 * 10f64.powi(retry - 1)
3880 };
3881 if jitter == 0.0 {
3882 if let Some(l) = cholesky_static::<D>(cov) {
3883 return Some(l);
3884 }
3885 } else {
3886 let mut base = *cov;
3887 for i in 0..D {
3888 base[i][i] = cov[i][i] + jitter;
3889 }
3890 if let Some(l) = cholesky_static::<D>(&base) {
3891 return Some(l);
3892 }
3893 }
3894 }
3895 None
3896}
3897
3898#[inline]
3899fn adaptive_point_countwith_cap(max_sd: f64, max_n: usize) -> usize {
3900 adaptive_point_count_from_sd(max_sd).min(max_n)
3901}
3902
3903#[inline]
3904fn ghq_nd_integrate_try<const D: usize, F, R, E>(
3905 ctx: &QuadratureContext,
3906 mu: [f64; D],
3907 cov: [[f64; D]; D],
3908 max_n: usize,
3909 f: F,
3910) -> Result<Option<R>, E>
3911where
3912 F: Fn([f64; D]) -> Result<R, E>,
3913 R: GhqValue,
3914{
3915 let mut maxvar = 0.0_f64;
3916 for (i, row) in cov.iter().enumerate() {
3917 maxvar = maxvar.max(row[i]).max(0.0);
3918 }
3919 let n = adaptive_point_countwith_cap(maxvar.sqrt(), max_n);
3920
3921 let mut cov_arr = cov;
3926 for i in 0..D {
3927 cov_arr[i][i] = cov_arr[i][i].max(0.0);
3928 }
3929 let Some(l) = cholesky_static_with_jitter::<D>(&cov_arr) else {
3930 return Ok(None);
3931 };
3932 let norm = 1.0 / std::f64::consts::PI.powf(0.5 * D as f64);
3933
3934 with_gh_nodesweights(ctx, n, |nodes, weights| {
3935 let mut acc = R::zero();
3936 let mut idx = [0usize; D];
3937 loop {
3938 let mut z = [0.0_f64; D];
3939 let mut weight = 1.0_f64;
3940 for d in 0..D {
3941 z[d] = SQRT_2 * nodes[idx[d]];
3942 weight *= weights[idx[d]];
3943 }
3944
3945 let mut x = mu;
3946 for row in 0..D {
3947 let mut dot = 0.0_f64;
3948 for (col, zc) in z.iter().enumerate().take(row + 1) {
3949 dot += l[row][col] * *zc;
3950 }
3951 x[row] += dot;
3952 }
3953 acc.addweighted(weight, f(x)?);
3954
3955 let mut carry = true;
3956 for d in (0..D).rev() {
3957 idx[d] += 1;
3958 if idx[d] < n {
3959 carry = false;
3960 break;
3961 }
3962 idx[d] = 0;
3963 }
3964 if carry {
3965 break;
3966 }
3967 }
3968 Ok(Some(acc.scale(norm)))
3969 })
3970}
3971
3972#[inline]
3973fn ghq_nd_integrate<const D: usize, F, R>(
3974 ctx: &QuadratureContext,
3975 mu: [f64; D],
3976 cov: [[f64; D]; D],
3977 max_n: usize,
3978 f: F,
3979) -> Option<R>
3980where
3981 F: Fn([f64; D]) -> R,
3982 R: GhqValue,
3983{
3984 match ghq_nd_integrate_try::<D, _, R, Infallible>(ctx, mu, cov, max_n, |x| Ok(f(x))) {
3985 Ok(v) => v,
3986 Err(e) => match e {},
3987 }
3988}
3989
3990#[inline]
3991fn ghq_nd_integrate_result<const D: usize, F, R, E>(
3992 ctx: &QuadratureContext,
3993 mu: [f64; D],
3994 cov: [[f64; D]; D],
3995 max_n: usize,
3996 f: F,
3997) -> Result<Option<R>, E>
3998where
3999 F: Fn([f64; D]) -> Result<R, E>,
4000 R: GhqValue,
4001{
4002 ghq_nd_integrate_try::<D, _, R, E>(ctx, mu, cov, max_n, f)
4003}
4004
4005pub fn normal_expectation_nd_adaptive<const D: usize, F>(
4007 ctx: &QuadratureContext,
4008 mu: [f64; D],
4009 cov: [[f64; D]; D],
4010 max_n: usize,
4011 f: F,
4012) -> f64
4013where
4014 F: Fn([f64; D]) -> f64,
4015{
4016 match ghq_nd_integrate::<D, _, f64>(ctx, mu, cov, max_n, &f) {
4017 Some(v) => v,
4018 None => f(mu),
4019 }
4020}
4021
4022pub fn normal_expectation_nd_adaptive_result<const D: usize, F, R, E>(
4024 ctx: &QuadratureContext,
4025 mu: [f64; D],
4026 cov: [[f64; D]; D],
4027 max_n: usize,
4028 f: F,
4029) -> Result<R, E>
4030where
4031 F: Fn([f64; D]) -> Result<R, E>,
4032 R: GhqValue,
4033{
4034 match ghq_nd_integrate_result::<D, _, R, E>(ctx, mu, cov, max_n, &f)? {
4035 Some(v) => Ok(v),
4036 None => f(mu),
4037 }
4038}
4039
4040pub fn normal_expectation_2d_adaptive_result<F, E>(
4042 ctx: &QuadratureContext,
4043 mu: [f64; 2],
4044 cov: [[f64; 2]; 2],
4045 f: F,
4046) -> Result<f64, E>
4047where
4048 F: Fn(f64, f64) -> Result<f64, E>,
4049{
4050 normal_expectation_nd_adaptive_result::<2, _, _, E>(ctx, mu, cov, 21, |x| f(x[0], x[1]))
4051}
4052
4053pub fn normal_expectation_3d_adaptive<F>(
4055 ctx: &QuadratureContext,
4056 mu: [f64; 3],
4057 cov: [[f64; 3]; 3],
4058 f: F,
4059) -> f64
4060where
4061 F: Fn(f64, f64, f64) -> f64,
4062{
4063 normal_expectation_nd_adaptive::<3, _>(ctx, mu, cov, 15, |x| f(x[0], x[1], x[2]))
4065}
4066
4067#[inline]
4086pub fn probit_posterior_mean(eta: f64, se_eta: f64) -> f64 {
4087 if se_eta < 1e-10 {
4088 return gam_math::probability::normal_cdf(eta);
4089 }
4090 let denom = (1.0 + se_eta * se_eta).sqrt();
4091 gam_math::probability::normal_cdf(eta / denom)
4092}
4093
4094#[inline]
4095pub fn logit_posterior_meanvariance(ctx: &QuadratureContext, eta: f64, se_eta: f64) -> (f64, f64) {
4096 let m1 = integrate_normal_ghq_adaptive(ctx, eta, se_eta, sigmoid);
4097 let m2 = integrate_normal_ghq_adaptive(ctx, eta, se_eta, |x| {
4098 let p = sigmoid(x);
4099 p * p
4100 })
4101 .clamp(0.0, 1.0);
4102 (m1, (m2 - m1 * m1).max(0.0))
4103}
4104
4105#[inline]
4106pub fn probit_posterior_meanvariance(ctx: &QuadratureContext, eta: f64, se_eta: f64) -> (f64, f64) {
4107 let m1 = probit_posterior_mean(eta, se_eta);
4108 let m2 = integrate_normal_ghq_adaptive(ctx, eta, se_eta, |x| {
4109 let p = gam_math::probability::normal_cdf(x);
4110 p * p
4111 })
4112 .clamp(0.0, 1.0);
4113 (m1, (m2 - m1 * m1).max(0.0))
4114}
4115
4116#[inline]
4117pub fn cloglog_posterior_meanvariance(
4118 ctx: &QuadratureContext,
4119 eta: f64,
4120 se_eta: f64,
4121) -> (f64, f64) {
4122 if !(eta.is_finite() && se_eta.is_finite()) || se_eta <= CLOGLOG_SIGMA_DEGENERATE {
4142 return (cloglog_mean_exact(eta), 0.0);
4143 }
4144 let (survival, _) = cloglog_survival_term_controlled(ctx, eta, se_eta);
4145 let (survival_sq, _) = cloglog_survivalsecond_moment_controlled(ctx, eta, se_eta);
4146 let mean = cloglog_mean_from_survival(survival);
4147 let variance = (survival_sq - survival * survival).max(0.0);
4148 (mean, variance)
4149}
4150
4151#[inline]
4185pub fn cloglog_posterior_mean(ctx: &QuadratureContext, eta: f64, se_eta: f64) -> f64 {
4186 if !(eta.is_finite() && se_eta.is_finite()) || se_eta <= CLOGLOG_SIGMA_DEGENERATE {
4190 return cloglog_mean_exact(eta);
4191 }
4192 let (survival, _) = cloglog_survival_term_controlled(ctx, eta, se_eta);
4193 cloglog_mean_from_survival(survival)
4194}
4195
4196#[inline]
4210pub fn survival_posterior_mean(ctx: &QuadratureContext, eta: f64, se_eta: f64) -> f64 {
4211 cloglog_survival_term_controlled(ctx, eta, se_eta)
4212 .0
4213 .clamp(0.0, 1.0)
4214}
4215
4216#[inline]
4217pub fn survival_posterior_meanvariance(
4218 ctx: &QuadratureContext,
4219 eta: f64,
4220 se_eta: f64,
4221) -> (f64, f64) {
4222 let (m1, _) = cloglog_survival_term_controlled(ctx, eta, se_eta);
4223 let (m2, _) = cloglog_survivalsecond_moment_controlled(ctx, eta, se_eta);
4224 (m1.clamp(0.0, 1.0), (m2 - m1 * m1).max(0.0))
4225}
4226
4227pub fn logit_posterior_mean_exact(mu: f64, sigma: f64) -> f64 {
4303 if !(mu.is_finite() && sigma.is_finite()) || sigma <= 0.0 {
4304 return sigmoid(mu);
4305 }
4306 if sigma < LOGIT_SIGMA_DEGENERATE {
4307 return sigmoid(mu);
4310 }
4311
4312 let inv_sqrt_pi = 0.5 * std::f64::consts::FRAC_2_SQRT_PI; let sqrt2_sigma = SQRT_2 * sigma;
4314 let coeff = (2.0_f64 * std::f64::consts::PI).sqrt() / sigma; let c = -mu / sqrt2_sigma; let beta = std::f64::consts::PI / sqrt2_sigma; let r2 = FADDEEVA_ASYMPTOTIC_RADIUS * FADDEEVA_ASYMPTOTIC_RADIUS;
4318
4319 let mut corr = 0.0_f64;
4325 let mut n = 1usize;
4326 let tail_start = loop {
4327 let b = (2.0 * (n as f64) - 1.0) * beta;
4328 let abs_xi2 = c * c + b * b;
4329 if abs_xi2 > r2 && n >= FADDEEVA_TAIL_MIN_INDEX {
4330 break n;
4331 }
4332 let xi = Complex { re: c, im: b };
4333 let d = if abs_xi2 > r2 {
4334 inv_sqrt_pi * faddeeva_asymptotic_a(xi).re
4336 } else {
4337 faddeeva_upper_halfplane(xi).im - inv_sqrt_pi * c / abs_xi2
4338 };
4339 corr += d;
4340 n += 1;
4341 };
4342
4343 corr += faddeeva_pole_series_em_tail(c, beta, tail_start, inv_sqrt_pi);
4344
4345 sigmoid(mu) - coeff * corr
4346}
4347
/// Smallest series index at which `logit_posterior_mean_exact` may stop summing
/// terms explicitly and hand the remainder to the tail estimate.
const FADDEEVA_TAIL_MIN_INDEX: usize = 48;
/// Radius in the complex plane; `logit_posterior_mean_exact` switches to the
/// asymptotic expansion when `|xi|^2` exceeds the square of this value.
const FADDEEVA_ASYMPTOTIC_RADIUS: f64 = 7.0;
/// Number of terms summed by the asymptotic expansion and by the tail estimate.
const FADDEEVA_ASYMPTOTIC_TERMS: usize = 14;
4358
4359fn faddeeva_asymptotic_a(xi: Complex) -> Complex {
4363 let inv = complex_div(Complex { re: 1.0, im: 0.0 }, xi);
4364 let inv2 = complexmul(inv, inv);
4365 let mut xp = complexmul(inv2, inv); let mut cm = 0.5_f64; let mut s = Complex::default();
4368 for m in 1..=FADDEEVA_ASYMPTOTIC_TERMS {
4369 s = complex_add(
4370 s,
4371 Complex {
4372 re: cm * xp.re,
4373 im: cm * xp.im,
4374 },
4375 );
4376 cm *= (2.0 * (m as f64) + 1.0) / 2.0; xp = complexmul(xp, inv2);
4378 }
4379 s
4380}
4381
4382fn faddeeva_pole_series_em_tail(c: f64, beta: f64, tail_start: usize, inv_sqrt_pi: f64) -> f64 {
4391 let b_a = (2.0 * (tail_start as f64) - 1.0) * beta;
4392 let xi = Complex { re: c, im: b_a };
4393 let inv = complex_div(Complex { re: 1.0, im: 0.0 }, xi);
4394 let inv2 = complexmul(inv, inv);
4395 let two_i_beta = Complex {
4397 re: 0.0,
4398 im: 2.0 * beta,
4399 };
4400
4401 let mut s = Complex::default(); let mut a_acc = Complex::default(); let mut fp_inner = Complex::default(); let mut x2m = inv2; let mut x2m1 = complexmul(inv2, inv); let mut x2m2 = complexmul(inv2, inv2); let mut cm = 0.5_f64;
4409 for m in 1..=FADDEEVA_ASYMPTOTIC_TERMS {
4410 let mf = m as f64;
4411 let inv_4ibm = Complex {
4413 re: 0.0,
4414 im: -1.0 / (4.0 * beta * mf),
4415 };
4416 s = complex_add(
4417 s,
4418 complexmul(
4419 Complex {
4420 re: cm * x2m.re,
4421 im: cm * x2m.im,
4422 },
4423 inv_4ibm,
4424 ),
4425 );
4426 a_acc = complex_add(
4427 a_acc,
4428 Complex {
4429 re: cm * x2m1.re,
4430 im: cm * x2m1.im,
4431 },
4432 );
4433 let fc = cm * (-(2.0 * mf + 1.0));
4434 fp_inner = complex_add(
4435 fp_inner,
4436 Complex {
4437 re: fc * x2m2.re,
4438 im: fc * x2m2.im,
4439 },
4440 );
4441 cm *= (2.0 * mf + 1.0) / 2.0;
4442 x2m = complexmul(x2m, inv2);
4443 x2m1 = complexmul(x2m1, inv2);
4444 x2m2 = complexmul(x2m2, inv2);
4445 }
4446
4447 s = complex_add(
4449 s,
4450 Complex {
4451 re: 0.5 * a_acc.re,
4452 im: 0.5 * a_acc.im,
4453 },
4454 );
4455 let fprime = complexmul(two_i_beta, fp_inner);
4457 s = complex_add(
4458 s,
4459 Complex {
4460 re: -fprime.re / 12.0,
4461 im: -fprime.im / 12.0,
4462 },
4463 );
4464
4465 inv_sqrt_pi * s.re
4468}
4469
4470fn faddeeva_upper_halfplane(z: Complex) -> Complex {
4483 let (l, coeffs) = faddeeva_weideman_coeffs();
4484 let iz = Complex {
4485 re: -z.im,
4486 im: z.re,
4487 }; let l_minus = Complex {
4489 re: l - iz.re,
4490 im: -iz.im,
4491 }; let l_plus = Complex {
4493 re: l + iz.re,
4494 im: iz.im,
4495 }; let zz = complex_div(l_plus, l_minus); let mut p = Complex {
4499 re: coeffs[0],
4500 im: 0.0,
4501 };
4502 for &c in &coeffs[1..] {
4503 p = complex_add(complexmul(p, zz), Complex { re: c, im: 0.0 });
4504 }
4505 let l_minus_sq = complexmul(l_minus, l_minus);
4506 let term1 = complex_div(
4507 Complex {
4508 re: 2.0 * p.re,
4509 im: 2.0 * p.im,
4510 },
4511 l_minus_sq,
4512 );
4513 let inv_sqrt_pi = 0.5 * std::f64::consts::FRAC_2_SQRT_PI;
4514 let term2 = complex_div(
4515 Complex {
4516 re: inv_sqrt_pi,
4517 im: 0.0,
4518 },
4519 l_minus,
4520 );
4521 complex_add(term1, term2)
4522}
4523
/// Number of polynomial coefficients cached by `faddeeva_weideman_coeffs`; it
/// also determines the scale parameter `L = sqrt(N / sqrt(2))` computed there.
const FADDEEVA_WEIDEMAN_N: usize = 44;
4527
4528fn faddeeva_weideman_coeffs() -> &'static (f64, [f64; FADDEEVA_WEIDEMAN_N]) {
4534 static CACHE: OnceLock<(f64, [f64; FADDEEVA_WEIDEMAN_N])> = OnceLock::new();
4535 CACHE.get_or_init(|| {
4536 let n = FADDEEVA_WEIDEMAN_N;
4537 let l = (n as f64 / SQRT_2).sqrt();
4538 let m = 2 * n;
4539 let m2 = 2 * m; let mut f = vec![0.0_f64; m2];
4543 for (idx, fi) in f.iter_mut().enumerate().skip(1) {
4544 let k = (idx as isize - 1) - (m as isize - 1);
4545 let theta = (k as f64) * std::f64::consts::PI / (m as f64);
4546 let t = l * (0.5 * theta).tan();
4547 *fi = (-t * t).exp() * (l * l + t * t);
4548 }
4549 let half = m2 / 2;
4552 let mut coeffs = [0.0_f64; FADDEEVA_WEIDEMAN_N];
4553 for j in 1..=n {
4554 let mut acc = 0.0_f64;
4555 for (p, _) in f.iter().enumerate() {
4556 let fp = f[(p + half) % m2];
4557 if fp != 0.0 {
4558 acc += fp
4559 * (-2.0 * std::f64::consts::PI * (j as f64) * (p as f64) / (m2 as f64))
4560 .cos();
4561 }
4562 }
4563 coeffs[n - j] = acc / (m2 as f64);
4565 }
4566 (l, coeffs)
4567 })
4568}
4569
4570#[inline]
4572fn sigmoid(x: f64) -> f64 {
4573 let x_clamped = x.clamp(-QUADRATURE_EXP_LOG_MAX, QUADRATURE_EXP_LOG_MAX);
4574 1.0 / (1.0 + f64::exp(-x_clamped))
4575}
4576
/// Value and partial derivatives, up to total order four, of the
/// Gaussian-smoothed cloglog inverse link `L(mu, sigma)` as filled in by
/// `cloglog_ghq_derivatives`.
///
/// Field suffixes list the differentiation variables: `l_mumusigma` is the
/// derivative taken twice in `mu` and once in `sigma`.
#[derive(Clone, Copy, Debug)]
pub struct CLogLogConvolutionDerivatives {
    /// The smoothed value `L`.
    pub l: f64,

    /// First order: d/dmu.
    pub l_mu: f64,
    /// First order: d/dsigma.
    pub l_sigma: f64,

    /// Second order: d2/dmu2.
    pub l_mumu: f64,
    /// Second order: d2/(dmu dsigma).
    pub l_musigma: f64,
    /// Second order: d2/dsigma2.
    pub l_sigmasigma: f64,

    /// Third order: d3/dmu3.
    pub l_mumumu: f64,
    /// Third order: d3/(dmu2 dsigma).
    pub l_mumusigma: f64,
    /// Third order: d3/(dmu dsigma2).
    pub l_musigmasigma: f64,
    /// Third order: d3/dsigma3.
    pub l_sigmasigmasigma: f64,

    /// Fourth order: d4/dmu4.
    pub l_mumumumu: f64,
    /// Fourth order: d4/(dmu3 dsigma).
    pub l_mumumusigma: f64,
    /// Fourth order: d4/(dmu2 dsigma2).
    pub l_mumusigmasigma: f64,
    /// Fourth order: d4/(dmu dsigma3).
    pub l_musigmasigmasigma: f64,
    /// Fourth order: d4/dsigma4.
    pub l_sigmasigmasigmasigma: f64,
}
4619
4620#[inline]
4621pub(crate) fn cloglog_point_jet5(t: f64) -> (f64, f64, f64, f64, f64, f64) {
4622 if t.is_nan() {
4623 return (f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN);
4624 }
4625 let et = safe_exp(t);
4626
4627 (
4628 -(-et).exp_m1(),
4629 cloglog_stable_poly_times_exp_neg(et, &[0.0, 1.0]),
4630 cloglog_stable_poly_times_exp_neg(et, &[0.0, 1.0, -1.0]),
4631 cloglog_stable_poly_times_exp_neg(et, &[0.0, 1.0, -3.0, 1.0]),
4632 cloglog_stable_poly_times_exp_neg(et, &[0.0, 1.0, -7.0, 6.0, -1.0]),
4633 cloglog_stable_poly_times_exp_neg(et, &[0.0, 1.0, -15.0, 25.0, -10.0, 1.0]),
4634 )
4635}
4636
4637#[inline]
4649fn cloglog_g_derivatives(t: f64) -> (f64, f64, f64, f64, f64) {
4650 let (g, g1, g2, g3, g4, _) = cloglog_point_jet5(t);
4651 (g, g1, g2, g3, g4)
4652}
4653
4654pub fn cloglog_ghq_value(ctx: &QuadratureContext, mu: f64, sigma: f64, n_nodes: usize) -> f64 {
4662 if sigma.abs() < 1e-14 {
4663 let (g, _, _, _, _) = cloglog_g_derivatives(mu);
4664 return g.clamp(0.0, 1.0);
4665 }
4666 let inv_sqrt_pi = 1.0 / std::f64::consts::PI.sqrt();
4667 let scale = SQRT_2 * sigma;
4668 with_gh_nodesweights(ctx, n_nodes, |nodes, weights| {
4669 let mut sum = 0.0_f64;
4670 for i in 0..nodes.len() {
4671 let t = mu + scale * nodes[i];
4672 let (g, _, _, _, _) = cloglog_g_derivatives(t);
4673 sum += weights[i] * g;
4674 }
4675 (sum * inv_sqrt_pi).clamp(0.0, 1.0)
4676 })
4677}
4678
4679pub fn cloglog_ghq_derivatives(
4690 ctx: &QuadratureContext,
4691 mu: f64,
4692 sigma: f64,
4693 n_nodes: usize,
4694) -> CLogLogConvolutionDerivatives {
4695 let inv_sqrt_pi = 1.0 / std::f64::consts::PI.sqrt();
4696
4697 if sigma.abs() < 1e-14 {
4704 let (g, g1, g2, g3, g4) = cloglog_g_derivatives(mu);
4705 return CLogLogConvolutionDerivatives {
4706 l: g,
4707 l_mu: g1,
4708 l_sigma: 0.0,
4709 l_mumu: g2,
4710 l_musigma: 0.0,
4711 l_sigmasigma: g2,
4712 l_mumumu: g3,
4713 l_mumusigma: 0.0,
4714 l_musigmasigma: g3,
4715 l_sigmasigmasigma: 0.0,
4716 l_mumumumu: g4,
4717 l_mumumusigma: 0.0,
4718 l_mumusigmasigma: g4,
4719 l_musigmasigmasigma: 0.0,
4720 l_sigmasigmasigmasigma: 3.0 * g4,
4721 };
4722 }
4723
4724 let scale = SQRT_2 * sigma;
4725 let sqrt2 = SQRT_2;
4726
4727 with_gh_nodesweights(ctx, n_nodes, |nodes, weights| {
4728 let mut s = [[0.0_f64; 5]; 5];
4740
4741 for i in 0..nodes.len() {
4742 let x = nodes[i];
4743 let t = mu + scale * x;
4744 let (g0, g1, g2, g3, g4) = cloglog_g_derivatives(t);
4745 let w = weights[i];
4746
4747 let x2 = x * x;
4749 let x3 = x2 * x;
4750 let x4 = x3 * x;
4751
4752 s[0][0] += w * g0;
4754
4755 s[1][0] += w * g1;
4757 s[1][1] += w * x * g1;
4758
4759 s[2][0] += w * g2;
4761 s[2][1] += w * x * g2;
4762 s[2][2] += w * x2 * g2;
4763
4764 s[3][0] += w * g3;
4766 s[3][1] += w * x * g3;
4767 s[3][2] += w * x2 * g3;
4768 s[3][3] += w * x3 * g3;
4769
4770 s[4][0] += w * g4;
4772 s[4][1] += w * x * g4;
4773 s[4][2] += w * x2 * g4;
4774 s[4][3] += w * x3 * g4;
4775 s[4][4] += w * x4 * g4;
4776 }
4777
4778 let sqrt2_1 = sqrt2;
4781 let sqrt2_2 = 2.0; let sqrt2_3 = 2.0 * sqrt2; let sqrt2_4 = 4.0; CLogLogConvolutionDerivatives {
4786 l: inv_sqrt_pi * s[0][0],
4788
4789 l_mu: inv_sqrt_pi * s[1][0],
4791 l_sigma: inv_sqrt_pi * sqrt2_1 * s[1][1],
4792
4793 l_mumu: inv_sqrt_pi * s[2][0],
4795 l_musigma: inv_sqrt_pi * sqrt2_1 * s[2][1],
4796 l_sigmasigma: inv_sqrt_pi * sqrt2_2 * s[2][2],
4797
4798 l_mumumu: inv_sqrt_pi * s[3][0],
4800 l_mumusigma: inv_sqrt_pi * sqrt2_1 * s[3][1],
4801 l_musigmasigma: inv_sqrt_pi * sqrt2_2 * s[3][2],
4802 l_sigmasigmasigma: inv_sqrt_pi * sqrt2_3 * s[3][3],
4803
4804 l_mumumumu: inv_sqrt_pi * s[4][0],
4806 l_mumumusigma: inv_sqrt_pi * sqrt2_1 * s[4][1],
4807 l_mumusigmasigma: inv_sqrt_pi * sqrt2_2 * s[4][2],
4808 l_musigmasigmasigma: inv_sqrt_pi * sqrt2_3 * s[4][3],
4809 l_sigmasigmasigmasigma: inv_sqrt_pi * sqrt2_4 * s[4][4],
4810 }
4811 })
4812}
4813
4814pub fn cloglog_ghq_derivatives_adaptive(
4820 ctx: &QuadratureContext,
4821 mu: f64,
4822 sigma: f64,
4823) -> CLogLogConvolutionDerivatives {
4824 let n = adaptive_point_count_from_sd(sigma.abs());
4825 cloglog_ghq_derivatives(ctx, mu, sigma, n)
4826}
4827
4828#[cfg(test)]
4829mod tests {
4830 use super::*;
4831 use approx::assert_relative_eq;
4832
4833 pub(crate) fn cloglog_posterior_meanwith_deriv_gamma_reference(
4834 mu: f64,
4835 sigma: f64,
4836 ) -> Result<IntegratedMeanDerivative, EstimationError> {
4837 let survival = cloglog_survival_gamma_reference(mu, sigma)?;
4840 let shifted_survival = cloglog_survival_gamma_reference(mu + sigma * sigma, sigma)?;
4841 let mean = cloglog_mean_from_survival(survival);
4842 let dmean = cloglog_shift_identity_derivative(mu, sigma, shifted_survival);
4843 if !(mean.is_finite() && dmean.is_finite()) {
4844 crate::bail_invalid_estim!(
4845 "Gamma cloglog reference backend produced non-finite values"
4846 );
4847 }
4848 Ok(IntegratedMeanDerivative {
4849 mean,
4850 dmean_dmu: dmean.max(0.0),
4851 mode: IntegratedExpectationMode::ExactSpecialFunction,
4852 })
4853 }
4854
4855 fn even_moment_exp_neg_x2(power: usize) -> f64 {
4856 assert!(power.is_multiple_of(2));
4857 let m = power / 2;
4858 let mut odd_double_factorial = 1.0_f64;
4859 for k in 0..m {
4860 odd_double_factorial *= (2 * k + 1) as f64;
4861 }
4862 odd_double_factorial * std::f64::consts::PI.sqrt() / 2.0_f64.powi(m as i32)
4863 }
4864
4865 fn normal_pdf(z: f64) -> f64 {
4866 (-(z * z) * 0.5).exp() / (2.0 * std::f64::consts::PI).sqrt()
4867 }
4868
4869 fn high_res_sigmoid_integral(eta: f64, se: f64) -> f64 {
4870 let a = -12.0_f64;
4872 let b = 12.0_f64;
4873 let n = 20_000usize; let h = (b - a) / n as f64;
4875
4876 let integrand = |z: f64| -> f64 { sigmoid(eta + se * z) * normal_pdf(z) };
4877
4878 let mut sum = integrand(a) + integrand(b);
4879 for i in 1..n {
4880 let x = a + (i as f64) * h;
4881 if i % 2 == 0 {
4882 sum += 2.0 * integrand(x);
4883 } else {
4884 sum += 4.0 * integrand(x);
4885 }
4886 }
4887 sum * h / 3.0
4888 }
4889
/// The cached Gauss-Hermite nodes must be mirror-symmetric about zero, with
/// the middle node at zero.
#[test]
fn test_computed_nodes_symmetric() {
    let ctx = QuadratureContext::new();
    let rule = ctx.gauss_hermite();
    let center = N_POINTS / 2;
    for lo in 0..center {
        let hi = N_POINTS - 1 - lo;
        assert_relative_eq!(rule.nodes[lo], -rule.nodes[hi], epsilon = 1e-12);
    }
    assert_relative_eq!(rule.nodes[center], 0.0, epsilon = 1e-12);
}
4902
/// The cached Gauss-Hermite weights must be equal at mirrored indices.
#[test]
fn test_computedweights_symmetric() {
    let ctx = QuadratureContext::new();
    let rule = ctx.gauss_hermite();
    let mirrored = (0..N_POINTS / 2).map(|left| (left, N_POINTS - 1 - left));
    for (left, right) in mirrored {
        assert_relative_eq!(rule.weights[left], rule.weights[right], epsilon = 1e-12);
    }
}
4913
/// The cached Gauss-Hermite weights must sum to `sqrt(pi)`.
#[test]
fn testweights_sum_to_sqrt_pi() {
    let ctx = QuadratureContext::new();
    let total = ctx.gauss_hermite().weights.iter().sum::<f64>();
    assert_relative_eq!(total, std::f64::consts::PI.sqrt(), epsilon = 1e-10);
}
4922
/// The rule from `compute_clenshaw_curtis_n(33)` must have antisymmetric
/// nodes, symmetric weights, and weights that sum to 2.
#[test]
fn test_clenshaw_curtisweights_are_symmetric_and_integrate_constants() {
    let rule = compute_clenshaw_curtis_n(33);
    let last = rule.weights.len() - 1;
    for lo in 0..=last / 2 {
        let hi = last - lo;
        assert_relative_eq!(rule.nodes[lo], -rule.nodes[hi], epsilon = 1e-14);
        assert_relative_eq!(rule.weights[lo], rule.weights[hi], epsilon = 1e-14);
    }
    let total: f64 = rule.weights.iter().sum();
    assert_relative_eq!(total, 2.0, epsilon = 1e-14, max_relative = 1e-14);
}
4934
/// `cloglog_should_prefer_cc` must accept the inputs `(-0.2, 0.8)` at
/// `CLOGLOG_CC_TOL`.
#[test]
fn test_cc_preference_prefers_moderate_central_case() {
    // Named by analogy with the rest of this module; confirm against the
    // callee's signature.
    let (mu, sigma) = (-0.2, 0.8);
    assert!(cloglog_should_prefer_cc(mu, sigma, CLOGLOG_CC_TOL));
}
4939
/// `cloglog_should_prefer_cc` must accept the inputs `(0.0, 2.0)` at
/// `CLOGLOG_CC_TOL`.
#[test]
fn test_cc_preference_prefers_moderately_large_case() {
    // Named by analogy with the rest of this module; confirm against the
    // callee's signature.
    let (mu, sigma) = (0.0, 2.0);
    assert!(cloglog_should_prefer_cc(mu, sigma, CLOGLOG_CC_TOL));
}
4944
/// `cloglog_should_prefer_cc` must reject the inputs `(0.0, 5.0)` at
/// `CLOGLOG_CC_TOL`.
#[test]
fn test_cc_preference_rejects_broad_case() {
    // Named by analogy with the rest of this module; confirm against the
    // callee's signature.
    let (mu, sigma) = (0.0, 5.0);
    assert!(!cloglog_should_prefer_cc(mu, sigma, CLOGLOG_CC_TOL));
}
4949
/// `wilkinson_shift(0.0, 0.0, 1.25)` must be finite and equal to `-1.25`.
#[test]
fn testwilkinson_shift_finitewhen_d_iszero() {
    let third_arg = 1.25;
    let shift = wilkinson_shift(0.0, 0.0, third_arg);
    assert!(shift.is_finite());
    assert_relative_eq!(shift, -third_arg, epsilon = 1e-14);
}
4958
/// The cached rule must reproduce the tabulated 7-point Gauss-Hermite nodes
/// and weights (Abramowitz & Stegun, per the test name) to 1e-12.
#[test]
fn test_matches_abramowitz_stegun_7_point_gauss_hermite_constants() {
    // (node, weight) pairs in ascending node order.
    let reference: [(f64, f64); 7] = [
        (-2.651_961_356_835_233_4, 0.000_971_781_245_099_519_1),
        (-1.673_551_628_767_471_4, 0.054_515_582_819_127_03),
        (-0.816_287_882_858_964_7, 0.425_607_252_610_127_8),
        (0.0, 0.810_264_617_556_807_3),
        (0.816_287_882_858_964_7, 0.425_607_252_610_127_8),
        (1.673_551_628_767_471_4, 0.054_515_582_819_127_03),
        (2.651_961_356_835_233_4, 0.000_971_781_245_099_519_1),
    ];

    let ctx = QuadratureContext::new();
    let gh = ctx.gauss_hermite();
    for (i, &(node, weight)) in reference.iter().enumerate() {
        assert_relative_eq!(gh.nodes[i], node, epsilon = 1e-12);
        assert_relative_eq!(gh.weights[i], weight, epsilon = 1e-12);
    }
}
4990
/// Checks which eigenvector index yields the Gauss-Hermite weights.
///
/// Builds a tridiagonal matrix with zero diagonal and off-diagonal entries
/// `sqrt((i + 1) / 2)`, then forms weights as `sqrt(pi) * v^2` using
/// `eigenvectors[i][0]` ("row") and `eigenvectors[0][i]` ("column"). The row
/// form must match the tabulated weights; the column form must miss them by
/// more than 1.0 in summed absolute error.
#[test]
fn test_gauss_hermite_weight_assembly_uses_eigenvector_rows() {
    let mut diag = [0.0_f64; N_POINTS];
    let mut off_diag = [0.0_f64; N_POINTS - 1];
    for (slot, k) in off_diag.iter_mut().zip(1usize..) {
        *slot = ((k as f64) / 2.0).sqrt();
    }
    let (nodes, eigenvectors) = symmetric_tridiagonal_eigen(&mut diag, &mut off_diag);
    let mu0 = std::f64::consts::PI.sqrt();

    // Pair each node with its candidate weight, then order by node.
    let sorted_by_node = |mut pairs: Vec<(f64, f64)>| -> Vec<(f64, f64)> {
        pairs.sort_by(|a, b| a.0.total_cmp(&b.0));
        pairs
    };
    let row_pairs = sorted_by_node(
        (0..N_POINTS)
            .map(|i| (nodes[i], mu0 * eigenvectors[i][0] * eigenvectors[i][0]))
            .collect(),
    );
    let column_pairs = sorted_by_node(
        (0..N_POINTS)
            .map(|i| (nodes[i], mu0 * eigenvectors[0][i] * eigenvectors[0][i]))
            .collect(),
    );

    let knownweights = [
        0.000_971_781_245_099_519_1,
        0.054_515_582_819_127_03,
        0.425_607_252_610_127_8,
        0.810_264_617_556_807_3,
        0.425_607_252_610_127_8,
        0.054_515_582_819_127_03,
        0.000_971_781_245_099_519_1,
    ];

    for (i, expected) in knownweights.iter().enumerate().take(N_POINTS) {
        assert_relative_eq!(row_pairs[i].1, *expected, epsilon = 1e-12);
    }

    let column_error: f64 = column_pairs
        .iter()
        .zip(knownweights.iter())
        .map(|(actual, expected)| (actual.1 - expected).abs())
        .sum();
    assert!(
        column_error > 1.0,
        "column-oriented eigenvector indexing unexpectedly matched A&S weights"
    );
}
5032
/// With a zero standard error, `logit_posterior_mean` must equal
/// `sigmoid(eta)`.
#[test]
fn testzero_se_returns_mode() {
    let ctx = QuadratureContext::new();
    let eta = 1.5;
    let mean = logit_posterior_mean(&ctx, eta, 0.0);
    assert_relative_eq!(mean, sigmoid(eta), epsilon = 1e-10);
}
5043
/// At `eta = 0` with unit standard error, the logit posterior mean must be
/// 0.5 to within 0.01.
#[test]
fn test_symmetric_atzero() {
    let ctx = QuadratureContext::new();
    let (eta, se) = (0.0, 1.0);
    assert_relative_eq!(logit_posterior_mean(&ctx, eta, se), 0.5, epsilon = 0.01);
}
5054
/// At `eta = 3`, `se = 1`, the logit posterior mean must be below
/// `sigmoid(eta)` but still above 0.8.
#[test]
fn test_shrinkage_at_extremes() {
    let ctx = QuadratureContext::new();
    let (eta, se) = (3.0, 1.0);
    let mode = sigmoid(eta);
    let mean = logit_posterior_mean(&ctx, eta, se);

    assert!(mean < mode, "Expected mean {} < mode {}", mean, mode);
    assert!(mean > 0.8, "Mean {} should still be high", mean);
}
5069
/// Compares `logit_posterior_mean` at `eta = 2`, `se = 0.8` with a
/// deterministic Monte Carlo estimate of the mean of `sigmoid(eta + se * z)`.
///
/// Normal draws come from a fixed-seed 64-bit LCG fed through the cosine
/// branch of the Box-Muller transform, so the test is reproducible.
#[test]
fn test_matches_monte_carlo() {
    let eta = 2.0;
    let se = 0.8;

    let ctx = QuadratureContext::new();
    let quad_mean = logit_posterior_mean(&ctx, eta, se);

    let mut rng_state = 12345u64;
    // Advances the LCG once and maps the state to [0, 1].
    let mut next_uniform = move || -> f64 {
        rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
        (rng_state as f64) / (u64::MAX as f64)
    };

    let n_samples = 100_000;
    let mut mc_sum = 0.0;
    for _ in 0..n_samples {
        // Floor the first uniform so its logarithm stays finite.
        let u1 = next_uniform().max(1e-10);
        let u2 = next_uniform();
        let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
        mc_sum += sigmoid(eta + se * z);
    }
    let mc_mean = mc_sum / (n_samples as f64);

    assert_relative_eq!(quad_mean, mc_mean, epsilon = 0.01);
}
5098
/// The rule's weighted sum of `x^2` over its nodes must equal `sqrt(pi) / 2`.
#[test]
fn test_quadrature_integrates_x_squared() {
    let ctx = QuadratureContext::new();
    let gh = ctx.gauss_hermite();
    let second_moment: f64 = (0..N_POINTS)
        .map(|i| gh.weights[i] * gh.nodes[i] * gh.nodes[i])
        .sum();
    assert_relative_eq!(
        second_moment,
        std::f64::consts::PI.sqrt() / 2.0,
        epsilon = 1e-10
    );
}
5112
/// The rule's weighted sum of `x^4` over its nodes must equal
/// `3 * sqrt(pi) / 4`.
#[test]
fn test_quadrature_integrates_x_fourth() {
    let ctx = QuadratureContext::new();
    let gh = ctx.gauss_hermite();
    let fourth_moment: f64 = (0..N_POINTS)
        .map(|i| {
            let x = gh.nodes[i];
            gh.weights[i] * x * x * x * x
        })
        .sum();
    assert_relative_eq!(
        fourth_moment,
        3.0 * std::f64::consts::PI.sqrt() / 4.0,
        epsilon = 1e-10
    );
}
5127
/// Every monomial moment of degree 0..=13 computed with the cached rule must
/// match its closed form: zero for odd degrees, `even_moment_exp_neg_x2` for
/// even ones. A degree passes if either the absolute or the scaled error is
/// at most 1e-10.
#[test]
fn test_moment_exactness_up_to_degree_13() {
    let ctx = QuadratureContext::new();
    let gh = ctx.gauss_hermite();

    for degree in 0..=13usize {
        let exponent = degree as i32;
        let approx: f64 = (0..N_POINTS)
            .map(|i| gh.weights[i] * gh.nodes[i].powi(exponent))
            .sum();

        let expected = if degree % 2 == 0 {
            even_moment_exp_neg_x2(degree)
        } else {
            0.0
        };

        let err = (approx - expected).abs();
        let rel_scale = approx.abs().max(expected.abs()).max(1.0);
        let within_abs = err <= 1e-10;
        let within_rel = err / rel_scale <= 1e-10;
        assert!(
            within_abs || within_rel,
            "degree={} approx={} expected={} abs_err={}",
            degree,
            approx,
            expected,
            err
        );
    }
}
5156
/// Compares `logit_posterior_mean` with `high_res_sigmoid_integral` on 20
/// pseudo-random `(eta, se)` pairs, with `eta` in `[-6, 6]` and `se` in
/// `[0.02, 1.52]`, drawn from a fixed-seed LCG.
#[test]
fn test_integrated_sigmoid_matches_high_res_integral_random_pairs() {
    let ctx = QuadratureContext::new();

    let mut rng_state = 0x4d595df4d0f33173u64;
    // Advances the LCG once and maps the state to [0, 1].
    let mut next_uniform = move || -> f64 {
        rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
        (rng_state as f64) / (u64::MAX as f64)
    };

    for _ in 0..20 {
        let u_eta = next_uniform();
        let u_se = next_uniform();
        let eta = -6.0 + 12.0 * u_eta;
        let se = 0.02 + 1.5 * u_se;

        let ghq = logit_posterior_mean(&ctx, eta, se);
        let numeric = high_res_sigmoid_integral(eta, se);
        assert_relative_eq!(ghq, numeric, epsilon = 2e-3);
    }
}
5176
/// At `eta = 20`, `se = 0`, the derivative from
/// `logit_posterior_meanwith_deriv` must be strictly positive and below 1e-6.
#[test]
fn test_logit_posterior_derivative_remains_positive_in_positive_tail() {
    let (eta, se) = (20.0, 0.0);
    let dmu = logit_posterior_meanwith_deriv(eta, se)
        .expect("logit posterior mean derivative should evaluate")
        .1;
    assert!(dmu > 0.0);
    assert!(
        dmu < 1e-6,
        "positive-tail derivative should stay tiny but nonzero, got {dmu}"
    );
}
5189
/// The derivative from `logit_posterior_meanwith_deriv` must agree in sign
/// and value with a central difference of `logit_posterior_mean` (step 1e-5)
/// at `eta = 1.7`, `se = 0.9`.
#[test]
fn test_logit_posterior_derivative_matches_central_difference() {
    let ctx = QuadratureContext::new();
    let (eta, se, h) = (1.7, 0.9, 1e-5);

    let (_, dmu) = logit_posterior_meanwith_deriv(eta, se)
        .expect("logit posterior mean derivative should evaluate");

    let mean_at = |location: f64| logit_posterior_mean(&ctx, location, se);
    let dmufd = (mean_at(eta + h) - mean_at(eta - h)) / (2.0 * h);

    assert_eq!(dmu.signum(), dmufd.signum());
    assert_relative_eq!(dmu, dmufd, epsilon = 5e-6, max_relative = 2e-4);
}
5206
5207 fn dense_sigmoid_normal_mean(mu: f64, sigma: f64) -> f64 {
5213 let a = -18.0_f64;
5214 let b = 18.0_f64;
5215 let n = 400_000usize; let h = (b - a) / n as f64;
5217 let integrand = |z: f64| -> f64 { sigmoid(mu + sigma * z) * normal_pdf(z) };
5218 let mut sum = integrand(a) + integrand(b);
5219 for i in 1..n {
5220 let z = a + (i as f64) * h;
5221 sum += if i % 2 == 0 { 2.0 } else { 4.0 } * integrand(z);
5222 }
5223 sum * h / 3.0
5224 }
5225
/// `logit_posterior_mean_exact(mu, sigma) + logit_posterior_mean_exact(-mu, sigma)`
/// must equal 1 to within 1e-12 for each tabulated `(mu, sigma)`.
#[test]
fn test_logit_posterior_mean_exact_symmetry_identity() {
    for (mu, sigma) in [(-3.0, 0.5), (-1.2, 1.7), (0.0, 2.2), (2.3, 0.8), (3.0, 0.05)] {
        let p = logit_posterior_mean_exact(mu, sigma);
        let q = logit_posterior_mean_exact(-mu, sigma);
        let defect = p + q - 1.0;
        assert!(
            defect.abs() < 1e-12,
            "symmetry broken at mu={mu} sigma={sigma}: p+q-1 = {:.3e}",
            defect
        );
    }
}
5247
/// `logit_posterior_mean_exact` must agree with `dense_sigmoid_normal_mean`
/// to within 1e-10 at each tabulated `(mu, sigma)`.
#[test]
fn test_logit_posterior_mean_exact_matches_high_res_integral() {
    let cases = [
        (-2.0, 0.4),
        (-0.7, 1.1),
        (0.8, 0.9),
        (2.4, 1.7),
        (3.0, 0.05),
        (3.0, 0.5),
        (-2.0, 2.0),
        (5.0, 3.0),
    ];
    for (mu, sigma) in cases {
        let exact = logit_posterior_mean_exact(mu, sigma);
        let numeric = dense_sigmoid_normal_mean(mu, sigma);
        let gap = (exact - numeric).abs();
        assert!(
            gap < 1e-10,
            "oracle ≠ dense reference at mu={mu} sigma={sigma}: \
             exact={exact:.13} ref={numeric:.13} err={:.3e}",
            gap
        );
    }
}
5274
/// Regression test (issue number taken from the test name and messages):
/// `logit_posterior_mean_exact` must stay within 1e-10 of
/// `dense_sigmoid_normal_mean` across a `mu` x `sigma` grid, and again at
/// `mu = 3` for three values of `sigma`.
#[test]
fn test_logit_posterior_mean_exact_no_truncation_bias_1459() {
    let table = [
        (1.0, 0.02),
        (1.0, 0.05),
        (1.0, 0.5),
        (1.0, 2.0),
        (3.0, 0.02),
        (3.0, 0.05),
        (3.0, 0.5),
        (3.0, 2.0),
        (-2.0, 0.02),
        (-2.0, 0.05),
        (-2.0, 0.5),
        (-2.0, 2.0),
    ];
    for (mu, sigma) in table {
        let err =
            (logit_posterior_mean_exact(mu, sigma) - dense_sigmoid_normal_mean(mu, sigma)).abs();
        assert!(
            err < 1e-10,
            "#1459 truncation bias resurfaced at mu={mu} sigma={sigma}: \
             err={err:.3e} (pre-fix bias here was ~{:.2e})",
            mu.abs() / (2.0 * std::f64::consts::PI.powi(2) * 4096.0)
        );
    }

    // Signed residuals at a fixed mu across several sigmas.
    let mu = 3.0;
    for s in [0.05, 0.5, 2.0] {
        let e = logit_posterior_mean_exact(mu, s) - dense_sigmoid_normal_mean(mu, s);
        assert!(
            e.abs() < 1e-10,
            "residual {e:.3e} at mu=3 — old σ-independent plateau was 3.71e-5"
        );
    }
}
5327
/// Checks `faddeeva_upper_halfplane` against tabulated values to 1e-13: the
/// origin, six points on the imaginary axis, three off-axis points, and one
/// point with a large imaginary part.
#[test]
fn test_faddeeva_weideman_matches_known_values() {
    const TOL: f64 = 1e-13;
    let w_at = |re: f64, im: f64| faddeeva_upper_halfplane(Complex { re, im });

    let w0 = w_at(0.0, 0.0);
    assert!(
        (w0.re - 1.0).abs() < TOL && w0.im.abs() < TOL,
        "w(0)={w0:?}"
    );

    // On the imaginary axis the tabulated values are purely real.
    let on_axis = [
        (0.1, 0.8964569799691268),
        (0.5, 0.6156903441929258),
        (1.0, 0.427583576155807),
        (2.0, 0.2553956763105058),
        (5.0, 0.11070463773306861),
        (9.0, 0.06230772403777468),
    ];
    for (y, want) in on_axis {
        let w = w_at(0.0, y);
        let real_err = (w.re - want).abs();
        assert!(
            real_err < TOL && w.im.abs() < TOL,
            "w(i·{y}): got {w:?}, want re={want}, err={:.2e}",
            real_err
        );
    }

    let off_axis = [
        ((0.7, 1.3), (0.31327301971562715, 0.12443489420104513)),
        ((-1.5, 0.8), (0.21066359024766423, -0.27001624496296617)),
        ((3.0, 0.4), (0.030278754646989155, 0.1957320888774461)),
    ];
    for ((re, im), (wre, wim)) in off_axis {
        let w = w_at(re, im);
        assert!(
            (w.re - wre).abs() < TOL && (w.im - wim).abs() < TOL,
            "w({re}+{im}i): got {w:?}, want ({wre},{wim})"
        );
    }

    let w = w_at(3.0, 40.0);
    assert!(
        (w.re - 0.01402158696172506).abs() < TOL && (w.im - 0.0010509664408184546).abs() < TOL,
        "tail value mismatch: w={w:?}"
    );
}
5383
/// `logit_posterior_mean` must stay within 1e-6 of
/// `logit_posterior_mean_exact` at each tabulated `(eta, se)`.
#[test]
fn test_integrated_logit_mean_close_to_exact_oracle() {
    let ctx = QuadratureContext::new();
    for (eta, se) in [(-3.0, 0.3), (-1.0, 0.8), (0.5, 1.2), (2.8, 1.0)] {
        let ghq = logit_posterior_mean(&ctx, eta, se);
        let exact = logit_posterior_mean_exact(eta, se);
        let gap = (ghq - exact).abs();
        assert!(
            gap < 1e-6,
            "production path drifts from oracle at eta={eta} se={se}: \
             ghq={ghq:.12} oracle={exact:.12} gap={:.3e}",
            gap
        );
    }
}
5402
/// With a zero standard error, `probit_posterior_mean` must equal
/// `normal_cdf(eta)`.
#[test]
fn test_probit_posterior_mean_reduces_to_map_atzero_se() {
    let eta = 1.25;
    let expected = gam_math::probability::normal_cdf(eta);
    assert_relative_eq!(probit_posterior_mean(eta, 0.0), expected, epsilon = 1e-12);
}
5410
/// A standard error of 2 must pull the probit posterior mean down at
/// `eta = 3` and up at `eta = -3`, relative to the zero-uncertainty value.
#[test]
fn test_probit_posterior_mean_shrinks_extremeswith_uncertainty() {
    let (hi_eta, lo_eta) = (3.0, -3.0);
    let (no_uncertainty, wide_uncertainty) = (0.0, 2.0);

    let p_hi_map = probit_posterior_mean(hi_eta, no_uncertainty);
    let p_hi_unc = probit_posterior_mean(hi_eta, wide_uncertainty);
    assert!(p_hi_unc < p_hi_map);

    let p_lo_map = probit_posterior_mean(lo_eta, no_uncertainty);
    let p_lo_unc = probit_posterior_mean(lo_eta, wide_uncertainty);
    assert!(p_lo_unc > p_lo_map);
}
5422
/// At `eta = 3`, `se = 1.5`, `survival_posterior_mean` must lie in `[0, 1]`
/// and exceed the plug-in value `exp(-exp(eta))`.
#[test]
fn test_survival_posterior_mean_is_bounded_and_shrinks_tail() {
    let ctx = QuadratureContext::new();
    let eta: f64 = 3.0;
    let plug_in = (-eta.exp()).exp();
    let pm = survival_posterior_mean(&ctx, eta, 1.5);
    assert!((0.0..=1.0).contains(&pm));
    assert!(pm > plug_in);
}
5432
/// `cloglog_posterior_mean + survival_posterior_mean` must equal 1 (to 2e-10)
/// at each tabulated `(eta, se)`, including zero and very large `se`.
#[test]
fn test_cloglog_and_survival_posterior_means_are_complements() {
    let ctx = QuadratureContext::new();
    let cases = [
        (-3.0, 0.0),
        (-0.2, 0.1),
        (0.4, 0.8),
        (2.0, 1.5),
        (10.0, 0.3),
        (0.0, 20.0),
        (10.0, 10.0),
        (-0.5, 100.0),
    ];
    for (eta, se) in cases {
        let total =
            cloglog_posterior_mean(&ctx, eta, se) + survival_posterior_mean(&ctx, eta, se);
        assert_relative_eq!(total, 1.0, epsilon = 2e-10, max_relative = 2e-10);
    }
}
5452
/// At `eta = -0.2`, `se = 0.8`, the integrated cloglog dispatch must report
/// `ExactSpecialFunction`, its mean must match `cloglog_posterior_mean`, and
/// the cloglog and survival means must sum to 1.
#[test]
fn test_cloglog_and_survival_share_large_sigmaspecial_function_path() {
    let ctx = QuadratureContext::new();
    let (eta, se) = (-0.2, 0.8);

    let integrated =
        integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::CLogLog, eta, se)
            .expect("cloglog integrated inverse-link moments should evaluate");
    assert_eq!(
        integrated.mode,
        IntegratedExpectationMode::ExactSpecialFunction
    );

    let clog = cloglog_posterior_mean(&ctx, eta, se);
    let surv = survival_posterior_mean(&ctx, eta, se);
    assert_relative_eq!(clog, integrated.mean, epsilon = 1e-12, max_relative = 1e-12);
    assert_relative_eq!(clog + surv, 1.0, epsilon = 1e-10, max_relative = 1e-10);
}
5470
/// The variance components of `cloglog_posterior_meanvariance` and
/// `survival_posterior_meanvariance` must agree at each tabulated `(eta, se)`.
#[test]
fn test_cloglog_and_survival_posteriorvariances_match() {
    let ctx = QuadratureContext::new();
    for (eta, se) in [(-3.0, 0.0), (-0.2, 0.1), (0.4, 0.8), (2.0, 1.5)] {
        let clogvar = cloglog_posterior_meanvariance(&ctx, eta, se).1;
        let survvar = survival_posterior_meanvariance(&ctx, eta, se).1;
        assert_relative_eq!(clogvar, survvar, epsilon = 1e-12, max_relative = 1e-12);
    }
}
5481
/// The variance from `survival_posterior_meanvariance` must equal
/// `max(E[S^2] - E[S]^2, 0)`, with both moments taken from the controlled
/// survival helpers.
#[test]
fn test_survivalvariance_uses_exactsecond_moment_shift() {
    let ctx = QuadratureContext::new();
    let (eta, se) = (-0.2, 0.8);

    let first_moment = cloglog_survival_term_controlled(&ctx, eta, se).0;
    let second_moment = cloglog_survivalsecond_moment_controlled(&ctx, eta, se).0;
    let expected_variance = (second_moment - first_moment * first_moment).max(0.0);

    let variance = survival_posterior_meanvariance(&ctx, eta, se).1;
    assert_relative_eq!(
        variance,
        expected_variance,
        epsilon = 1e-12,
        max_relative = 1e-12
    );
}
5497
/// `lognormal_laplace_term_controlled(z, mu, sigma)` must return the same pair
/// as `cloglog_survival_term_controlled(mu + ln z, sigma)`: equal second
/// components and first components equal to 1e-12.
#[test]
fn test_lognormal_laplace_shift_matches_explicitmu_plus_logz() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (-0.2, 0.8);
    let z: f64 = 2.0;

    let (shifted_value, shifted_tag) = lognormal_laplace_term_controlled(&ctx, z, mu, sigma);
    let (explicit_value, explicit_tag) =
        cloglog_survival_term_controlled(&ctx, mu + z.ln(), sigma);

    assert_eq!(shifted_tag, explicit_tag);
    assert_relative_eq!(
        shifted_value,
        explicit_value,
        epsilon = 1e-12,
        max_relative = 1e-12
    );
}
5509
/// The integrated dispatch for the probit link must report `ExactClosedForm`
/// and match `probit_posterior_meanwith_deriv_exact` in mean and derivative.
#[test]
fn test_integrated_dispatch_uses_closed_form_probit() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (0.7, 1.3);

    let out = integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::Probit, mu, sigma)
        .expect("probit integrated inverse-link moments should evaluate");
    assert_eq!(out.mode, IntegratedExpectationMode::ExactClosedForm);

    let direct = probit_posterior_meanwith_deriv_exact(mu, sigma);
    assert_relative_eq!(out.mean, direct.mean, epsilon = 1e-12);
    assert_relative_eq!(out.dmean_dmu, direct.dmean_dmu, epsilon = 1e-12);
}
5520
/// The integrated probit jet must match the closed form in
/// `z = mu / sqrt(1 + sigma^2)`: mean `Phi(z)`, `d1 = phi(z) / s`,
/// `d2 = -z * phi(z) / s^2`, `d3 = (z^2 - 1) * phi(z) / s^3`.
#[test]
fn test_integrated_probit_jet_matches_closed_form_derivatives() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (0.7, 1.3);
    let out = integrated_inverse_link_jet(&ctx, LinkFunction::Probit, mu, sigma)
        .expect("probit integrated inverse-link jet should evaluate");

    let s = (1.0 + sigma * sigma).sqrt();
    let z = mu / s;
    let pdf = gam_math::probability::normal_pdf(z);
    let cdf = gam_math::probability::normal_cdf(z);

    assert_relative_eq!(out.mean, cdf, epsilon = 1e-12);
    assert_relative_eq!(out.d1, pdf / s, epsilon = 1e-12);
    assert_relative_eq!(out.d2, -z * pdf / (s * s), epsilon = 1e-12);
    assert_relative_eq!(out.d3, (z * z - 1.0) * pdf / (s * s * s), epsilon = 1e-12);
}
5536
/// The integrated logit jet at `mu = 1.1`, `sigma = 0.8` must report either
/// `ExactSpecialFunction` or `QuadratureFallback`, and match the high-resolution
/// Simpson reference jet in the mean and all three derivatives.
#[test]
fn test_integrated_logit_jet_matches_central_differences() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (1.1, 0.8);

    let out = integrated_inverse_link_jet(&ctx, LinkFunction::Logit, mu, sigma)
        .expect("logit integrated inverse-link jet should evaluate");
    let mode_is_accepted = matches!(
        out.mode,
        IntegratedExpectationMode::ExactSpecialFunction
            | IntegratedExpectationMode::QuadratureFallback
    );
    assert!(mode_is_accepted);

    let reference = logit_reference_jet_highres_simpson(mu, sigma);
    assert_relative_eq!(out.mean, reference.0, epsilon = 1e-11, max_relative = 1e-10);
    assert_relative_eq!(out.d1, reference.1, epsilon = 1e-11, max_relative = 1e-10);
    assert_relative_eq!(out.d2, reference.2, epsilon = 1e-11, max_relative = 1e-10);
    assert_relative_eq!(out.d3, reference.3, epsilon = 1e-11, max_relative = 1e-10);
}
5567
/// `integrated_logit_inverse_link_jet_pirls` must report the same mode as the
/// general `integrated_inverse_link_jet` logit dispatch and agree with it on
/// the mean and all three derivatives.
#[test]
fn test_integrated_logit_pirls_jet_matches_general_dispatch() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (1.1, 0.8);

    let general = integrated_inverse_link_jet(&ctx, LinkFunction::Logit, mu, sigma)
        .expect("general logit jet");
    let pirls =
        integrated_logit_inverse_link_jet_pirls(&ctx, mu, sigma).expect("PIRLS logit jet");

    let mode_is_accepted = matches!(
        pirls.mode,
        IntegratedExpectationMode::ExactSpecialFunction
            | IntegratedExpectationMode::QuadratureFallback
    );
    assert!(mode_is_accepted);
    assert_eq!(pirls.mode, general.mode);

    // Tolerances loosen with derivative order.
    assert_relative_eq!(pirls.mean, general.mean, epsilon = 1e-12);
    assert_relative_eq!(pirls.d1, general.d1, epsilon = 1e-12);
    assert_relative_eq!(pirls.d2, general.d2, epsilon = 1e-10);
    assert_relative_eq!(pirls.d3, general.d3, epsilon = 1e-8);
}
5599
/// Each derivative in the integrated cloglog jet must agree, in sign and
/// value, with a central difference (step 1e-4) of the next-lower jet
/// component at `mu = 0.4`, `sigma = 0.6`.
#[test]
fn test_integrated_cloglog_jet_matches_central_differences() {
    let ctx = QuadratureContext::new();
    let (mu, sigma, h) = (0.4, 0.6, 1e-4);

    let jet_at = |location: f64| {
        integrated_inverse_link_jet(&ctx, LinkFunction::CLogLog, location, sigma)
            .expect("cloglog integrated inverse-link jet should evaluate")
    };
    let out = jet_at(mu);
    let plus = jet_at(mu + h);
    let minus = jet_at(mu - h);

    let two_h = 2.0 * h;
    let d1fd = (plus.mean - minus.mean) / two_h;
    let d2fd = (plus.d1 - minus.d1) / two_h;
    let d3fd = (plus.d2 - minus.d2) / two_h;

    assert_eq!(out.d1.signum(), d1fd.signum());
    assert_eq!(out.d2.signum(), d2fd.signum());
    assert_eq!(out.d3.signum(), d3fd.signum());
    assert_relative_eq!(out.d1, d1fd, epsilon = 2e-5, max_relative = 3e-4);
    assert_relative_eq!(out.d2, d2fd, epsilon = 4e-5, max_relative = 8e-4);
    assert_relative_eq!(out.d3, d3fd, epsilon = 8e-5, max_relative = 2e-3);
}
5622
/// For each wide-sigma case the integrated cloglog jet must report
/// `QuadratureFallback`, match the high-resolution Simpson reference jet, and
/// have a `d3` consistent with the central-difference slope of `d2`.
#[test]
fn test_integrated_cloglog_wide_sigma_d3_matches_simpson_and_d2_slope() {
    let ctx = QuadratureContext::new();
    let h = 1e-4;

    for (mu, sigma) in [(0.0, 4.0), (-1.0, 4.0), (2.0, 3.0), (3.0, 3.0)] {
        let jet_at = |location: f64| {
            integrated_inverse_link_jet(&ctx, LinkFunction::CLogLog, location, sigma)
                .expect("wide-sigma cloglog integrated jet should evaluate")
        };
        let out = jet_at(mu);
        let reference = cloglog_reference_jet_highres_simpson(mu, sigma);
        let d3fd = (jet_at(mu + h).d2 - jet_at(mu - h).d2) / (2.0 * h);

        assert_eq!(out.mode, IntegratedExpectationMode::QuadratureFallback);
        assert_relative_eq!(out.mean, reference.0, epsilon = 4e-8, max_relative = 4e-8);
        assert_relative_eq!(out.d1, reference.1, epsilon = 4e-8, max_relative = 4e-8);
        assert_relative_eq!(out.d2, reference.2, epsilon = 2e-9, max_relative = 2e-7);
        assert_relative_eq!(out.d3, reference.3, epsilon = 2e-9, max_relative = 2e-7);
        assert_relative_eq!(out.d3, d3fd, epsilon = 2e-7, max_relative = 4e-5);
    }
}
5647
/// The fourth and fifth derivatives from
/// `latent_cloglog_inverse_link_jet5_controlled` must agree, in sign and
/// value, with central differences (step 2e-4) of `d3` and `d4`.
#[test]
fn test_latent_cloglog_jet5_matches_higher_order_central_differences() {
    let ctx = QuadratureContext::new();
    let (mu, sigma, h) = (0.35, 0.7, 2e-4);

    let jet_at = |location: f64| latent_cloglog_inverse_link_jet5_controlled(&ctx, location, sigma);
    let out = jet_at(mu);
    let plus = jet_at(mu + h);
    let minus = jet_at(mu - h);

    let two_h = 2.0 * h;
    let d4fd = (plus.d3 - minus.d3) / two_h;
    let d5fd = (plus.d4 - minus.d4) / two_h;

    assert_eq!(out.d4.signum(), d4fd.signum());
    assert_eq!(out.d5.signum(), d5fd.signum());
    assert_relative_eq!(out.d4, d4fd, epsilon = 2e-4, max_relative = 5e-3);
    assert_relative_eq!(out.d5, d5fd, epsilon = 6e-4, max_relative = 2e-2);
}
5667
/// `logit_posterior_meanwith_deriv_controlled(1.1, 0.8)` must match the
/// high-resolution Simpson reference in mean and first derivative, and the
/// derivative must be strictly positive.
#[test]
fn test_logit_exact_derivative_matches_finite_difference() {
    let (mu, sigma) = (1.1, 0.8);
    let out = logit_posterior_meanwith_deriv_controlled(mu, sigma).expect("controlled logit");
    let reference = logit_reference_jet_highres_simpson(mu, sigma);

    assert_relative_eq!(out.mean, reference.0, epsilon = 1e-11, max_relative = 1e-10);
    assert!(out.dmean_dmu > 0.0);
    assert_relative_eq!(out.dmean_dmu, reference.1, epsilon = 1e-11, max_relative = 1e-10);
}
5685
/// At `mu = -710`, `sigma = 0`, the exact logit mean must be locally constant
/// (central difference exactly zero) and the reported derivative exactly zero.
#[test]
fn test_logit_exact_clamped_degenerate_branch_is_locally_flat() {
    let mu = -710.0;
    let h = 1e-6;

    let out = logit_posterior_meanwith_deriv_exact(mu, 0.0).expect("exact logit");
    let mean_plus = logit_posterior_meanwith_deriv_exact(mu + h, 0.0)
        .expect("exact logit plus")
        .mean;
    let mean_minus = logit_posterior_meanwith_deriv_exact(mu - h, 0.0)
        .expect("exact logit minus")
        .mean;

    let fd = (mean_plus - mean_minus) / (2.0 * h);
    assert_eq!(fd, 0.0);
    assert_eq!(out.dmean_dmu, 0.0);
}
5700
/// Composite Simpson approximation of the integral of `f` over `[a, b]` using
/// `n_intervals` equal subintervals.
///
/// Endpoints carry weight 1, odd interior nodes weight 4, and even interior
/// nodes weight 2; the weighted sum is scaled by `h / 3`.
///
/// # Panics
/// Panics when `n_intervals` is zero or odd. Zero is rejected explicitly
/// because it passes the evenness check yet makes the step `(b - a) / 0`
/// non-finite, which would otherwise surface only as a silent `inf`/`NaN`.
fn simpson_integrate<F>(a: f64, b: f64, n_intervals: usize, f: F) -> f64
where
    F: Fn(f64) -> f64,
{
    assert!(
        n_intervals > 0,
        "Simpson integration requires at least two intervals"
    );
    assert_eq!(n_intervals % 2, 0, "Simpson integration requires an even n");

    let h = (b - a) / n_intervals as f64;
    let endpoints = f(a) + f(b);
    let weighted_sum = (1..n_intervals).fold(endpoints, |acc, i| {
        let x = a + i as f64 * h;
        let w = if i % 2 == 0 { 2.0 } else { 4.0 };
        acc + w * f(x)
    });
    weighted_sum * h / 3.0
}
5715
5716 fn cloglog_reference_mean_and_derivative(mu: f64, sigma: f64) -> (f64, f64) {
5717 if sigma <= CLOGLOG_SIGMA_DEGENERATE {
5718 return (cloglog_mean_exact(mu), cloglog_mean_d1_exact(mu));
5719 }
5720
5721 let z_max = 12.0;
5725 let n_intervals = 4096;
5726 let inv_sqrt_2pi = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
5727 let mean = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5728 let eta = mu + sigma * z;
5729 inv_sqrt_2pi * (-0.5 * z * z).exp() * cloglog_mean_exact(eta)
5730 });
5731 let deriv = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5732 let eta = mu + sigma * z;
5733 inv_sqrt_2pi * (-0.5 * z * z).exp() * cloglog_mean_d1_exact(eta)
5734 });
5735 (mean, deriv)
5736 }
5737
5738 fn logit_reference_jet_highres_simpson(mu: f64, sigma: f64) -> (f64, f64, f64, f64) {
5751 let z_max = 14.0;
5752 let n_intervals = 16384;
5753 let inv_sqrt_2pi = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
5754 let phi = |z: f64| inv_sqrt_2pi * (-0.5 * z * z).exp();
5755 let mean = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5756 let eta = mu + sigma * z;
5757 let (p, _, _, _) = component_point_jet(LinkComponent::Logit, eta);
5758 phi(z) * p
5759 });
5760 let d1 = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5761 let eta = mu + sigma * z;
5762 let (_, p1, _, _) = component_point_jet(LinkComponent::Logit, eta);
5763 phi(z) * p1
5764 });
5765 let d2 = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5766 let eta = mu + sigma * z;
5767 let (_, _, p2, _) = component_point_jet(LinkComponent::Logit, eta);
5768 phi(z) * p2
5769 });
5770 let d3 = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5771 let eta = mu + sigma * z;
5772 let (_, _, _, p3) = component_point_jet(LinkComponent::Logit, eta);
5773 phi(z) * p3
5774 });
5775 (mean, d1, d2, d3)
5776 }
5777
5778 fn cloglog_reference_jet_highres_simpson(mu: f64, sigma: f64) -> (f64, f64, f64, f64) {
5779 let z_max = 14.0;
5780 let n_intervals = 16384;
5781 let inv_sqrt_2pi = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
5782 let phi = |z: f64| inv_sqrt_2pi * (-0.5 * z * z).exp();
5783 let mean = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5784 let eta = mu + sigma * z;
5785 let (g, _, _, _, _, _) = cloglog_point_jet5(eta);
5786 phi(z) * g
5787 });
5788 let d1 = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5789 let eta = mu + sigma * z;
5790 let (_, g1, _, _, _, _) = cloglog_point_jet5(eta);
5791 phi(z) * g1
5792 });
5793 let d2 = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5794 let eta = mu + sigma * z;
5795 let (_, _, g2, _, _, _) = cloglog_point_jet5(eta);
5796 phi(z) * g2
5797 });
5798 let d3 = simpson_integrate(-z_max, z_max, n_intervals, |z| {
5799 let eta = mu + sigma * z;
5800 let (_, _, _, g3, _, _) = cloglog_point_jet5(eta);
5801 phi(z) * g3
5802 });
5803 (mean, d1, d2, d3)
5804 }
5805
/// At `mu = -40`, `sigma = 0.1`, `cloglog_small_sigma_taylor` must return a
/// strictly positive derivative and match the Simpson reference in mean and
/// derivative to a relative 1e-12.
#[test]
fn test_cloglog_taylor_negative_tail_matches_mathematical_target() {
    let (mu, sigma) = (-40.0, 0.1);
    let out = cloglog_small_sigma_taylor(mu, sigma);
    let (expected_mean, expected_deriv) = cloglog_reference_mean_and_derivative(mu, sigma);

    assert!(
        out.dmean_dmu > 0.0,
        "negative-tail derivative should remain positive"
    );
    // The absolute epsilon is negligible, so the relative bound decides.
    for (got, want) in [(out.mean, expected_mean), (out.dmean_dmu, expected_deriv)] {
        assert_relative_eq!(got, want, epsilon = 1e-30, max_relative = 1e-12);
    }
}
5830
/// With `sigma = 0` at `mu = -40`, the controlled cloglog evaluation must
/// return a strictly positive derivative and match the pointwise
/// `cloglog_mean_exact` / `cloglog_mean_d1_exact` values to a relative 1e-15.
#[test]
fn test_cloglog_degenerate_negative_tail_matches_pointwise_target() {
    let ctx = QuadratureContext::new();
    let mu = -40.0;
    let out = cloglog_posterior_meanwith_deriv_controlled(&ctx, mu, 0.0);

    assert!(
        out.dmean_dmu > 0.0,
        "degenerate negative-tail derivative should remain positive"
    );
    // The absolute epsilon is negligible, so the relative bound decides.
    let checks = [
        (out.mean, cloglog_mean_exact(mu)),
        (out.dmean_dmu, cloglog_mean_d1_exact(mu)),
    ];
    for (got, want) in checks {
        assert_relative_eq!(got, want, epsilon = 1e-30, max_relative = 1e-15);
    }
}
5854
/// `integrated_probit_jet(-40, 0)` and the logit point jet at `-710` must
/// both report exactly zero for their first three derivatives.
#[test]
fn test_degenerate_probit_and_logit_jets_are_flat_on_active_clamps() {
    let probit = integrated_probit_jet(-40.0, 0.0);
    for derivative in [probit.d1, probit.d2, probit.d3] {
        assert_eq!(derivative, 0.0);
    }

    let logit = component_point_jet(LinkComponent::Logit, -710.0);
    for derivative in [logit.1, logit.2, logit.3] {
        assert_eq!(derivative, 0.0);
    }
}
5867
/// At `eta = -40`, the cloglog point jet must match the analytic expressions
/// in `t = exp(eta)` and `s = exp(-t)`: mean `-expm1(-t)`, `d1 = t * s`,
/// `d2 = (t - t^2) * s`, `d3 = (t - 3 t^2 + t^3) * s`, each to a relative
/// 1e-15, with `d1` strictly positive.
#[test]
fn test_degenerate_cloglog_component_jet_preserves_smooth_negative_tail() {
    let eta: f64 = -40.0;
    let t = eta.exp();
    let s = (-t).exp();
    let expected = [
        -(-t).exp_m1(),
        t * s,
        (t - t * t) * s,
        (t - 3.0 * t * t + t * t * t) * s,
    ];

    let cloglog = component_point_jet(LinkComponent::CLogLog, eta);
    assert!(cloglog.1 > 0.0, "negative-tail d1 should remain positive");

    // The absolute epsilon is negligible, so the relative bound decides.
    let actual = [cloglog.0, cloglog.1, cloglog.2, cloglog.3];
    for (got, want) in actual.iter().zip(expected.iter()) {
        assert_relative_eq!(*got, *want, epsilon = 1e-30, max_relative = 1e-15);
    }
}
5905
/// With `sigma = 0`, the integrated jet for logit (at `eta = 50`) and cloglog
/// (at `eta = -50`) must report `ExactClosedForm` and be bitwise equal to
/// `component_inverse_link_jet` in the mean and all three derivatives.
#[test]
fn test_zero_sigma_logit_and_cloglog_share_component_tail_jets() {
    let ctx = QuadratureContext::new();
    let cases = [
        (LinkFunction::Logit, LinkComponent::Logit, 50.0),
        (LinkFunction::CLogLog, LinkComponent::CLogLog, -50.0),
    ];
    for (link, component, eta) in cases {
        let point = component_inverse_link_jet(component, eta);
        let integrated =
            integrated_inverse_link_jet(&ctx, link, eta, 0.0).expect("degenerate integrated jet");

        assert_eq!(integrated.mode, IntegratedExpectationMode::ExactClosedForm);
        assert_eq!(integrated.mean, point.mu);
        assert_eq!(integrated.d1, point.d1);
        assert_eq!(integrated.d2, point.d2);
        assert_eq!(integrated.d3, point.d3);
    }
}
5923
/// Sweeps a grid of `(mu, sigma)` pairs with `sigma <= 0.24` and compares the
/// controlled cloglog posterior mean and its mu-derivative against
/// `cloglog_reference_mean_and_derivative`.
#[test]
fn test_cloglog_controlled_matches_mathematical_target_on_small_sigma_grid() {
    let ctx = QuadratureContext::new();
    let grid = [
        (-30.0, 1e-10),
        (-30.0, 0.1),
        (-10.0, 0.24),
        (-3.0, 0.2),
        (0.0, 0.05),
        (0.4, 0.1),
        (3.0, 0.24),
        (10.0, 0.1),
        (30.0, 0.24),
    ];

    for (mu, sigma) in grid {
        let controlled = cloglog_posterior_meanwith_deriv_controlled(&ctx, mu, sigma);
        let (reference_mean, reference_slope) =
            cloglog_reference_mean_and_derivative(mu, sigma);

        // The derivative tolerance (4e-3) is looser than the mean tolerance (2e-3).
        assert_relative_eq!(
            controlled.mean,
            reference_mean,
            epsilon = 1e-12,
            max_relative = 2e-3
        );
        assert_relative_eq!(
            controlled.dmean_dmu,
            reference_slope,
            epsilon = 1e-12,
            max_relative = 4e-3
        );
    }
}
5959
/// At `(mu, sigma) = (-0.2, 0.8)` the cloglog dispatcher must report
/// `ExactSpecialFunction` and return a finite mean with a finite,
/// non-negative derivative.
#[test]
fn test_cloglog_dispatch_uses_gamma_backend_for_large_sigma_central_regime() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (-0.2, 0.8);
    let moments =
        integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::CLogLog, mu, sigma)
            .expect("cloglog integrated inverse-link moments should evaluate");

    assert_eq!(moments.mode, IntegratedExpectationMode::ExactSpecialFunction);
    assert!(moments.mean.is_finite());
    assert!(moments.dmean_dmu.is_finite());
    assert!(moments.dmean_dmu >= 0.0);
}
5971
/// At `(mu, sigma) = (0.0, 20.0)` the cloglog dispatcher must report
/// `ControlledAsymptotic` and return a finite mean with a finite,
/// non-negative derivative.
#[test]
fn test_cloglog_dispatch_uses_large_sigma_asymptotic_without_ghq() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (0.0, 20.0);
    let moments =
        integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::CLogLog, mu, sigma)
            .expect("cloglog integrated inverse-link moments should evaluate");

    assert_eq!(moments.mode, IntegratedExpectationMode::ControlledAsymptotic);
    assert!(moments.mean.is_finite());
    assert!(moments.dmean_dmu.is_finite());
    assert!(moments.dmean_dmu >= 0.0);
}
5983
/// Cross-checks `cloglog_survival_cc` against `cloglog_survival_gamma_reference`
/// at `(mu, sigma) = (-0.2, 0.8)` to within 5e-6.
#[test]
fn test_cloglog_cc_matches_gamma_reference_on_central_case() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (-0.2, 0.8);

    let survival_cc =
        cloglog_survival_cc(&ctx, mu, sigma, CLOGLOG_CC_TOL).expect("cc backend");
    let survival_gamma =
        cloglog_survival_gamma_reference(mu, sigma).expect("gamma backend");

    assert_relative_eq!(
        survival_cc,
        survival_gamma,
        epsilon = 5e-6,
        max_relative = 5e-6
    );
}
5993
/// Compares the gamma-reference cloglog moments at `(mu, sigma) = (-0.2, 0.8)`
/// with a deterministic Monte Carlo estimate built from 300 000 draws.
///
/// The draws come from a fixed-seed 64-bit linear congruential recurrence whose
/// state is scaled to a uniform, and each pair of uniforms is mapped to one
/// normal deviate using the cosine branch of the Box–Muller transform.
#[test]
fn test_cloglog_gamma_reference_matches_seeded_monte_carlo_small_case() {
    let mu = -0.2;
    let sigma = 0.8;
    let gamma =
        cloglog_posterior_meanwith_deriv_gamma_reference(mu, sigma).expect("gamma reference");

    let mut lcg_state = 0x9e3779b97f4a7c15u64;
    // Advance the recurrence once and return a uniform clamped away from 0 and 1,
    // so that `ln` below never sees zero.
    let mut next_uniform = || {
        lcg_state = lcg_state.wrapping_mul(6364136223846793005).wrapping_add(1);
        ((lcg_state as f64) / (u64::MAX as f64)).clamp(1e-12, 1.0 - 1e-12)
    };

    let draws = 300_000usize;
    let mut mean_sum = 0.0f64;
    let mut deriv_sum = 0.0f64;
    for _ in 0..draws {
        let u1 = next_uniform();
        let u2 = next_uniform();
        let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
        let eta = mu + sigma * z;
        mean_sum += cloglog_mean_exact(eta);
        deriv_sum += cloglog_mean_d1_exact(eta);
    }
    let mean_mc = mean_sum / draws as f64;
    let deriv_mc = deriv_sum / draws as f64;

    assert_relative_eq!(gamma.mean, mean_mc, epsilon = 2e-3, max_relative = 2e-3);
    assert_relative_eq!(
        gamma.dmean_dmu,
        deriv_mc,
        epsilon = 2e-3,
        max_relative = 2e-3
    );
}
6024
/// At `(mu, sigma) = (35.0, 1.0)` the logit dispatcher must report
/// `ControlledAsymptotic` and return a finite mean with a finite,
/// non-negative derivative.
#[test]
fn test_logit_dispatch_uses_tail_asymptotic_outside_old_guard() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (35.0, 1.0);
    let moments =
        integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::Logit, mu, sigma)
            .expect("logit integrated inverse-link moments should evaluate");

    assert_eq!(moments.mode, IntegratedExpectationMode::ControlledAsymptotic);
    assert!(moments.mean.is_finite());
    assert!(moments.dmean_dmu.is_finite());
    assert!(moments.dmean_dmu >= 0.0);
}
6035
/// At `(mu, sigma) = (1.1, 0.8)` the logit dispatcher may report either
/// `ExactSpecialFunction` or `QuadratureFallback`; in both cases the mean and
/// derivative must agree with the high-resolution Simpson reference.
#[test]
fn test_logit_dispatch_prefers_erfcx_in_moderate_regime() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (1.1, 0.8);
    let moments =
        integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::Logit, mu, sigma)
            .expect("logit integrated inverse-link moments should evaluate");

    let mode_is_accepted = matches!(
        moments.mode,
        IntegratedExpectationMode::ExactSpecialFunction
            | IntegratedExpectationMode::QuadratureFallback
    );
    assert!(mode_is_accepted);
    assert!(moments.mean.is_finite());
    assert!(moments.dmean_dmu.is_finite());
    assert!(moments.dmean_dmu >= 0.0);

    let (reference_mean, reference_slope, _, _) =
        logit_reference_jet_highres_simpson(mu, sigma);
    assert_relative_eq!(
        moments.mean,
        reference_mean,
        epsilon = 1e-11,
        max_relative = 1e-10
    );
    assert_relative_eq!(
        moments.dmean_dmu,
        reference_slope,
        epsilon = 1e-11,
        max_relative = 1e-10
    );
}
6063
/// At `(mu, sigma) = (0.5, 20.0)` the logit dispatcher must report
/// `QuadratureFallback`, match the Simpson reference, and differ by more than
/// 1e-3 from the probit-scaled value `Phi(mu / sqrt(1 + pi * sigma^2 / 8))`.
#[test]
fn test_logit_dispatch_large_sigma_uses_accurate_quadrature_not_monahan() {
    let ctx = QuadratureContext::new();
    let out = integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::Logit, 0.5, 20.0)
        .expect("logit integrated inverse-link moments should evaluate");
    assert_eq!(out.mode, IntegratedExpectationMode::QuadratureFallback);

    let (reference_mean, reference_slope, _, _) =
        logit_reference_jet_highres_simpson(0.5, 20.0);
    assert_relative_eq!(out.mean, reference_mean, epsilon = 1e-9, max_relative = 1e-7);
    assert_relative_eq!(
        out.dmean_dmu,
        reference_slope,
        epsilon = 1e-9,
        max_relative = 1e-7
    );

    // Shrinkage factor of the probit-scaled approximation the dispatcher must not return.
    let inflation = 1.0 + std::f64::consts::PI * 20.0 * 20.0 / 8.0;
    let kappa = inflation.sqrt().recip();
    let monahan_mean = gam_math::probability::normal_cdf(0.5 * kappa);
    assert!(
        (out.mean - monahan_mean).abs() > 1e-3,
        "dispatcher must not return the inaccurate Monahan mean {monahan_mean}; got {}",
        out.mean
    );
}
6093
/// Calls the controlled logit path directly at `(mu, sigma) = (1.1, 0.8)`:
/// the mode must be `ExactSpecialFunction` or `QuadratureFallback`, and the
/// mean and derivative must match the high-resolution Simpson reference.
#[test]
fn test_logit_controlled_path_keeps_exact_backend_in_moderate_regime() {
    let (mu, sigma) = (1.1, 0.8);
    let controlled =
        logit_posterior_meanwith_deriv_controlled(mu, sigma).expect("logit controlled");

    let mode_is_accepted = matches!(
        controlled.mode,
        IntegratedExpectationMode::ExactSpecialFunction
            | IntegratedExpectationMode::QuadratureFallback
    );
    assert!(mode_is_accepted);

    let (reference_mean, reference_slope, _, _) =
        logit_reference_jet_highres_simpson(mu, sigma);
    assert_relative_eq!(
        controlled.mean,
        reference_mean,
        epsilon = 1e-11,
        max_relative = 1e-10
    );
    assert_relative_eq!(
        controlled.dmean_dmu,
        reference_slope,
        epsilon = 1e-11,
        max_relative = 1e-10
    );
}
6115
/// At `mu = 0` with `sigma` in {0.3, 0.4, 0.5}: the mean must be 0.5, the
/// derivative must not exceed 0.25 (plus 1e-9 slack), and the derivative must
/// match the high-resolution Simpson reference.
#[test]
fn test_logit_dispatch_derivative_correct_at_mu_zero_small_sigma() {
    let ctx = QuadratureContext::new();
    let small_sigma_cases = [(0.0, 0.3), (0.0, 0.4), (0.0, 0.5)];
    for (mu, sigma) in small_sigma_cases {
        let out =
            integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::Logit, mu, sigma)
                .expect("logit integrated inverse-link moments should evaluate");
        assert_relative_eq!(out.mean, 0.5, epsilon = 1e-10);

        let slope_ceiling = 0.25 + 1e-9;
        assert!(
            out.dmean_dmu <= slope_ceiling,
            "E[sigmoid'] must not exceed 0.25 at (μ={mu}, σ={sigma}); got {}",
            out.dmean_dmu
        );

        let (_, reference_slope, _, _) = logit_reference_jet_highres_simpson(mu, sigma);
        assert_relative_eq!(
            out.dmean_dmu,
            reference_slope,
            epsilon = 1e-9,
            max_relative = 1e-6
        );
    }
}
6143
/// The exact logit branch must succeed with mode `ExactSpecialFunction` and
/// match the Simpson reference at three `(mu, sigma)` points, and must return
/// an error at `(0.0, 0.3)`.
#[test]
fn test_logit_erfcx_exact_branch_is_self_certified() {
    let certified_cases = [(8.0, 1.0), (10.0, 1.0), (15.0, 2.0)];
    for (mu, sigma) in certified_cases {
        let exact = logit_posterior_meanwith_deriv_exact(mu, sigma)
            .expect("erfcx branch should certify");
        assert_eq!(exact.mode, IntegratedExpectationMode::ExactSpecialFunction);

        let (reference_mean, reference_slope, _, _) =
            logit_reference_jet_highres_simpson(mu, sigma);
        assert_relative_eq!(
            exact.mean,
            reference_mean,
            epsilon = 1e-9,
            max_relative = 1e-7
        );
        assert_relative_eq!(
            exact.dmean_dmu,
            reference_slope,
            epsilon = 1e-9,
            max_relative = 1e-7
        );
    }

    let uncertified = logit_posterior_meanwith_deriv_exact(0.0, 0.3);
    assert!(
        uncertified.is_err(),
        "erfcx branch must not claim ExactSpecialFunction when it cannot certify the derivative"
    );
}
6168
/// Symmetry check for the integrated logit moments: evaluating at `mu` and
/// `-mu` must give equal derivatives and means that sum to one.
#[test]
fn test_logit_integrated_derivative_is_even_in_mu() {
    let ctx = QuadratureContext::new();
    let symmetric_cases = [(0.3, 0.3), (1.1, 0.8), (10.0, 1.0), (3.0, 3.0), (35.0, 1.0)];
    for (mu, sigma) in symmetric_cases {
        let at_plus =
            integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::Logit, mu, sigma)
                .expect("logit moments (+μ)");
        let at_minus =
            integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::Logit, -mu, sigma)
                .expect("logit moments (-μ)");

        // The derivative is even in mu.
        assert_relative_eq!(
            at_plus.dmean_dmu,
            at_minus.dmean_dmu,
            epsilon = 1e-9,
            max_relative = 1e-7
        );
        // The mean reflects about one half.
        assert_relative_eq!(
            at_minus.mean,
            1.0 - at_plus.mean,
            epsilon = 1e-9,
            max_relative = 1e-7
        );
    }
}
6199
/// For each `(mu, sigma)` case, the returned `dmean_dmu` must agree (to 1e-5
/// absolute) with a central finite difference of the returned mean using step
/// `h = 1e-4`, and must lie in `[0, 0.25 + 1e-9]`.
#[test]
fn test_logit_dmean_dmu_equals_fd_of_mean_across_regimes() {
    let ctx = QuadratureContext::new();
    let h = 1e-4;
    let cases = [
        (0.0, 0.8),
        (0.7, 0.8),
        (1.5, 1.2),
        (-1.1, 0.9),
        (8.0, 1.0),
        (10.0, 1.5),
        (-9.0, 1.0),
        (0.5, 0.05),
        (0.5, 20.0),
    ];
    for (mu, sigma) in cases {
        let moments_at = |location: f64| {
            integrated_inverse_link_mean_and_derivative(
                &ctx,
                LinkFunction::Logit,
                location,
                sigma,
            )
            .expect("logit moments")
        };

        let out = moments_at(mu);
        let mean_above = moments_at(mu + h).mean;
        let mean_below = moments_at(mu - h).mean;
        let central_difference = (mean_above - mean_below) / (2.0 * h);

        assert!(
            (out.dmean_dmu - central_difference).abs() <= 1e-5,
            "dmean_dmu must equal d/dμ of mean at (μ={mu}, σ={sigma}): \
             returned {}, FD of mean {} (mode {:?})",
            out.dmean_dmu,
            central_difference,
            out.mode
        );
        assert!(
            (0.0..=0.25 + 1e-9).contains(&out.dmean_dmu),
            "dmean_dmu out of [0, 0.25] at (μ={mu}, σ={sigma}): {}",
            out.dmean_dmu
        );
    }
}
6254
/// For `sigma >= 3`, the scalar mean/derivative entry point must match both
/// the Simpson reference (relative 1e-8) and the jet entry point (relative 1e-12).
#[test]
fn test_logit_scalar_matches_jet_at_large_sigma() {
    let ctx = QuadratureContext::new();
    let wide_cases = [(3.0, 3.0), (4.0, 4.0), (2.0, 5.0), (5.0, 5.0)];
    for (mu, sigma) in wide_cases {
        let scalar =
            integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::Logit, mu, sigma)
                .expect("scalar logit moments");
        let jet = integrated_inverse_link_jet(&ctx, LinkFunction::Logit, mu, sigma)
            .expect("jet logit moments");
        let (reference_mean, reference_slope, _, _) =
            logit_reference_jet_highres_simpson(mu, sigma);

        // Scalar path against the independent reference.
        assert_relative_eq!(
            scalar.mean,
            reference_mean,
            epsilon = 1e-9,
            max_relative = 1e-8
        );
        assert_relative_eq!(
            scalar.dmean_dmu,
            reference_slope,
            epsilon = 1e-9,
            max_relative = 1e-8
        );

        // Scalar path against the jet path.
        assert_relative_eq!(scalar.mean, jet.mean, epsilon = 1e-12, max_relative = 1e-12);
        assert_relative_eq!(
            scalar.dmean_dmu,
            jet.d1,
            epsilon = 1e-12,
            max_relative = 1e-12
        );
    }
}
6296
/// For wide `sigma`, the logit jet must match the Simpson reference through
/// the third derivative, agree with the scalar entry point on mean and first
/// derivative, and agree with the PIRLS jet on every field including the mode.
#[test]
fn test_logit_jet_accurate_at_wide_sigma() {
    let ctx = QuadratureContext::new();
    let wide_cases = [(3.0, 3.0), (4.0, 4.0), (2.0, 5.0), (5.0, 5.0), (0.5, 20.0)];
    for (mu, sigma) in wide_cases {
        let jet = integrated_inverse_link_jet(&ctx, LinkFunction::Logit, mu, sigma)
            .expect("wide-σ logit jet");

        let (ref_mean, ref_d1, ref_d2, ref_d3) = logit_reference_jet_highres_simpson(mu, sigma);
        assert_relative_eq!(jet.mean, ref_mean, epsilon = 1e-8, max_relative = 1e-7);
        for (actual, expected) in [(jet.d1, ref_d1), (jet.d2, ref_d2), (jet.d3, ref_d3)] {
            assert_relative_eq!(actual, expected, epsilon = 1e-8, max_relative = 1e-6);
        }

        let scalar =
            integrated_inverse_link_mean_and_derivative(&ctx, LinkFunction::Logit, mu, sigma)
                .expect("scalar logit moments");
        assert_relative_eq!(jet.d1, scalar.dmean_dmu, epsilon = 1e-12);
        assert_relative_eq!(jet.mean, scalar.mean, epsilon = 1e-12);

        let pirls = integrated_logit_inverse_link_jet_pirls(&ctx, mu, sigma)
            .expect("wide-σ PIRLS logit jet");
        let pirls_pairs = [
            (pirls.mean, jet.mean),
            (pirls.d1, jet.d1),
            (pirls.d2, jet.d2),
            (pirls.d3, jet.d3),
        ];
        for (from_pirls, from_jet) in pirls_pairs {
            assert_relative_eq!(from_pirls, from_jet, epsilon = 1e-12);
        }
        assert_eq!(pirls.mode, jet.mode);
    }
}
6331
/// Evaluates the dispatched logit jet and `logit_wide_sigma_jet` at
/// `sigma = LOGIT_JET_GHQ_SIGMA_MAX` for several `mu` values and requires the
/// two to agree on the mean and the first three derivatives.
#[test]
fn test_logit_jet_continuous_across_ghq_simpson_seam() {
    let ctx = QuadratureContext::new();
    let sigma = LOGIT_JET_GHQ_SIGMA_MAX;
    let locations = [-2.0, -0.5, 0.0, 0.7, 1.3, 3.0];
    for mu in locations {
        let dispatched = integrated_inverse_link_jet(&ctx, LinkFunction::Logit, mu, sigma)
            .expect("jet at seam (GHQ dispatch)");
        let wide = logit_wide_sigma_jet(mu, sigma).expect("jet at seam (Simpson)");

        assert_relative_eq!(
            dispatched.mean,
            wide.mean,
            epsilon = 1e-9,
            max_relative = 1e-8
        );
        assert_relative_eq!(dispatched.d1, wide.d1, epsilon = 1e-9, max_relative = 1e-7);
        assert_relative_eq!(dispatched.d2, wide.d2, epsilon = 1e-9, max_relative = 1e-7);
        // The third derivative is held to a looser tolerance than the lower orders.
        assert_relative_eq!(dispatched.d3, wide.d3, epsilon = 1e-8, max_relative = 1e-6);
    }
}
6357
/// The two batch logit entry points must reproduce, element by element, the
/// mean and derivative returned by the scalar dispatcher for the same
/// `(eta, se)` pair.
#[test]
fn test_logit_batch_uses_same_dispatchvalues() {
    let ctx = QuadratureContext::new();
    let eta = ndarray::array![-2.0, 0.0, 1.25, 35.0];
    let se = ndarray::array![0.1, 0.5, 1.0, 1.0];

    let mean_only = logit_posterior_mean_batch(&ctx, &eta, &se)
        .expect("logit posterior mean batch should evaluate");
    let (mean_with_deriv, deriv) = logit_posterior_meanwith_deriv_batch(&ctx, &eta, &se)
        .expect("logit posterior mean derivative batch should evaluate");

    for (i, (&eta_i, &se_i)) in eta.iter().zip(se.iter()).enumerate() {
        let scalar = integrated_inverse_link_mean_and_derivative(
            &ctx,
            LinkFunction::Logit,
            eta_i,
            se_i,
        )
        .expect("logit integrated inverse-link moments should evaluate");

        assert_relative_eq!(mean_only[i], scalar.mean, epsilon = 1e-12);
        assert_relative_eq!(mean_with_deriv[i], scalar.mean, epsilon = 1e-12);
        assert_relative_eq!(deriv[i], scalar.dmean_dmu, epsilon = 1e-12);
    }
}
6380
/// At `eta = 50` with zero standard error, the exact logit branch must return
/// the derivative `z / (1 + z)^2` with `z = exp(-eta)` to within 1e-30.
///
/// Note: despite the "loses" in the test name, the assertion requires the
/// derivative to be retained.
#[test]
fn exact_logit_small_se_branch_loses_tail_derivative() {
    let eta = 50.0_f64;
    let decay = (-eta).exp();
    let reference_slope = decay / (1.0_f64 + decay).powi(2);
    // The reference itself must be representable as a positive number.
    assert!(reference_slope > 0.0);

    let exact = logit_posterior_meanwith_deriv_exact(eta, 0.0).expect("exact branch");
    let returned_slope = exact.dmean_dmu;
    assert!(
        (returned_slope - reference_slope).abs() < 1e-30,
        "exact logit small-se branch should use the stable derivative z/(1+z)^2 at eta={eta}; got {} vs {}",
        returned_slope,
        reference_slope
    );
}
6396
/// A binomial spec with a `LatentCLogLog` inverse link must make
/// `integrated_family_moments_jet` return an error whose text mentions
/// `LatentCLogLog`.
#[test]
fn integrated_family_moments_rejects_latent_cloglog_without_concrete_handler() {
    let ctx = QuadratureContext::new();
    let latent_state =
        gam_problem::types::LatentCLogLogState::new(0.4).expect("valid latent cloglog state");
    let spec = LikelihoodSpec::new(
        ResponseFamily::Binomial,
        InverseLink::LatentCLogLog(latent_state),
    );
    let scale = LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 };

    let err = integrated_family_moments_jet(&ctx, &spec, scale, 0.2, 0.5)
        .expect_err("latent cloglog moments should error in this dispatcher");
    assert!(err.to_string().contains("LatentCLogLog"));
}
6419
/// A binomial spec with an SAS inverse-link state must evaluate to a jet with
/// finite mean and derivatives, with the mean strictly inside `(0, 1)`.
#[test]
fn integrated_family_moments_supports_stateful_sas() {
    let ctx = QuadratureContext::new();
    let sas_spec = gam_problem::types::SasLinkSpec {
        initial_epsilon: 0.3,
        initial_log_delta: -0.2,
    };
    let sas_state = crate::mixture_link::state_from_sasspec(sas_spec)
        .expect("sas state should reconstruct from raw parameters");
    let spec = LikelihoodSpec::new(ResponseFamily::Binomial, InverseLink::Sas(sas_state));
    let scale = LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 };

    let jet = integrated_family_moments_jet(&ctx, &spec, scale, 0.2, 0.5)
        .expect("stateful SAS integrated moments should evaluate");

    assert!(jet.mean.is_finite());
    assert!(jet.d1.is_finite());
    assert!(jet.d2.is_finite());
    assert!(jet.d3.is_finite());
    assert!(jet.mean > 0.0 && jet.mean < 1.0);
}
6443
/// A mixture link with a single probit component must reproduce
/// `integrated_probit_jet` at the same `(0.7, 1.3)` point and be tagged
/// `ExactClosedForm`.
#[test]
fn integrated_family_moments_supports_pure_probit_mixture() {
    let ctx = QuadratureContext::new();
    let mixture_spec = gam_problem::types::MixtureLinkSpec {
        components: vec![gam_problem::types::LinkComponent::Probit],
        initial_rho: ndarray::Array1::<f64>::zeros(0),
    };
    let state = crate::mixture_link::state_fromspec(&mixture_spec)
        .expect("single-component probit mixture state");
    let spec = LikelihoodSpec::new(ResponseFamily::Binomial, InverseLink::Mixture(state));
    let scale = LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 };

    let via_family = integrated_family_moments_jet(&ctx, &spec, scale, 0.7, 1.3)
        .expect("pure probit mixture integrated moments should evaluate");
    let direct = integrated_probit_jet(0.7, 1.3);

    assert_relative_eq!(via_family.mean, direct.mean, epsilon = 1e-12);
    assert_relative_eq!(via_family.d1, direct.d1, epsilon = 1e-12);
    assert_relative_eq!(via_family.d2, direct.d2, epsilon = 1e-12);
    assert_relative_eq!(via_family.d3, direct.d3, epsilon = 1e-12);
    assert_eq!(via_family.mode, IntegratedExpectationMode::ExactClosedForm);
}
6468
/// A mixture link with a single logit component must reproduce the jet and
/// mode that `integrated_inverse_link_jet` returns for `LinkFunction::Logit`
/// at the same `(1.1, 0.8)` point.
#[test]
fn integrated_family_moments_supports_pure_logit_mixture() {
    let ctx = QuadratureContext::new();
    let mixture_spec = gam_problem::types::MixtureLinkSpec {
        components: vec![gam_problem::types::LinkComponent::Logit],
        initial_rho: ndarray::Array1::<f64>::zeros(0),
    };
    let state = crate::mixture_link::state_fromspec(&mixture_spec)
        .expect("single-component logit mixture state");
    let spec = LikelihoodSpec::new(ResponseFamily::Binomial, InverseLink::Mixture(state));
    let scale = LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 };

    let via_family = integrated_family_moments_jet(&ctx, &spec, scale, 1.1, 0.8)
        .expect("pure logit mixture integrated moments should evaluate");
    let canonical = integrated_inverse_link_jet(&ctx, LinkFunction::Logit, 1.1, 0.8)
        .expect("canonical integrated logit jet");

    assert_relative_eq!(via_family.mean, canonical.mean, epsilon = 1e-12);
    assert_relative_eq!(via_family.d1, canonical.d1, epsilon = 1e-12);
    assert_relative_eq!(via_family.d2, canonical.d2, epsilon = 1e-12);
    assert_relative_eq!(via_family.d3, canonical.d3, epsilon = 1e-12);
    assert_eq!(via_family.mode, canonical.mode);
}
6494
/// A two-component (logit + probit) mixture routed through
/// `integrated_family_moments_jet` must give the same jet and mode as calling
/// `integrated_mixture_jet` directly with the same state.
#[test]
fn integrated_family_moments_supports_stateful_mixture() {
    let ctx = QuadratureContext::new();
    let mixture_spec = gam_problem::types::MixtureLinkSpec {
        components: vec![
            gam_problem::types::LinkComponent::Logit,
            gam_problem::types::LinkComponent::Probit,
        ],
        initial_rho: ndarray::array![0.35],
    };
    let state = crate::mixture_link::state_fromspec(&mixture_spec)
        .expect("mixture state should reconstruct from rho");
    // The spec takes a clone so the original state stays available for the direct call.
    let spec = LikelihoodSpec::new(
        ResponseFamily::Binomial,
        InverseLink::Mixture(state.clone()),
    );
    let scale = LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 };

    let via_family = integrated_family_moments_jet(&ctx, &spec, scale, 0.2, 0.5)
        .expect("stateful mixture integrated moments should evaluate");
    let direct = integrated_mixture_jet(&ctx, 0.2, 0.5, &state)
        .expect("direct integrated mixture jet should evaluate");

    assert_relative_eq!(via_family.mean, direct.mean, epsilon = 1e-12);
    assert_relative_eq!(via_family.d1, direct.d1, epsilon = 1e-12);
    assert_relative_eq!(via_family.d2, direct.d2, epsilon = 1e-12);
    assert_relative_eq!(via_family.d3, direct.d3, epsilon = 1e-12);
    assert_eq!(via_family.mode, direct.mode);
}
6526
/// Checks the `variance` field of the integrated family moments for four
/// log-link families against closed forms in `m = exp(e + se^2 / 2)`:
///
/// * Tweedie with estimated `phi`: `phi * m^p`
/// * Gamma with estimated `shape`: `m^2 / shape`
/// * Poisson: `m`
/// * Negative binomial with estimated `theta`: `m + m^2 / theta`
///
/// Also checks that Gamma and Tweedie return the expected error text when the
/// scale metadata is `Unspecified`.
#[test]
fn integrated_family_moments_use_scale_dispersion_for_tweedie_and_gamma() {
    let ctx = QuadratureContext::new();
    let e = 0.3_f64;
    let se = 0.5_f64;
    let m = (e + 0.5 * se * se).exp();

    // Tweedie: variance scales with the estimated dispersion.
    let p = 1.5_f64;
    let phi = 2.0_f64;
    let tweedie = LikelihoodSpec::tweedie_log(p);
    let tweedie_out = integrated_family_moments_jet(
        &ctx,
        &tweedie,
        LikelihoodScaleMetadata::EstimatedTweediePhi { phi },
        e,
        se,
    )
    .expect("tweedie integrated moments should evaluate");
    let tweedie_expected = phi * m.powf(p);
    assert_relative_eq!(tweedie_out.variance, tweedie_expected, epsilon = 1e-12);
    assert_relative_eq!(tweedie_out.variance / m.powf(p), phi, epsilon = 1e-12);

    // Gamma: variance scales with the reciprocal of the estimated shape.
    let shape = 4.0_f64;
    let gamma = LikelihoodSpec::gamma_log();
    let gamma_out = integrated_family_moments_jet(
        &ctx,
        &gamma,
        LikelihoodScaleMetadata::EstimatedGammaShape { shape },
        e,
        se,
    )
    .expect("gamma integrated moments should evaluate");
    let gamma_expected = m * m / shape;
    assert_relative_eq!(gamma_out.variance, gamma_expected, epsilon = 1e-12);
    assert_relative_eq!(gamma_out.variance / (m * m), 1.0 / shape, epsilon = 1e-12);

    // Poisson: variance equals the mean.
    let poisson = LikelihoodSpec::poisson_log();
    let poisson_out = integrated_family_moments_jet(
        &ctx,
        &poisson,
        LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 },
        e,
        se,
    )
    .expect("poisson integrated moments should evaluate");
    assert_relative_eq!(poisson_out.variance, m, epsilon = 1e-12);

    // Negative binomial: mean plus the quadratic overdispersion term.
    let theta = 3.0_f64;
    let nb = LikelihoodSpec::negative_binomial_log(theta);
    let nb_out = integrated_family_moments_jet(
        &ctx,
        &nb,
        LikelihoodScaleMetadata::EstimatedNegBinTheta { theta },
        e,
        se,
    )
    .expect("negative-binomial integrated moments should evaluate");
    assert_relative_eq!(nb_out.variance, m + m * m / theta, epsilon = 1e-12);

    // Missing scale metadata must be reported rather than defaulted: Gamma first.
    let err = integrated_family_moments_jet(
        &ctx,
        &gamma,
        LikelihoodScaleMetadata::Unspecified,
        e,
        se,
    )
    .expect_err("gamma without a shape in the scale metadata must error");
    assert!(
        err.to_string()
            .contains("Gamma integrated variance requires the shape"),
        "unexpected error message: {err}"
    );

    // ... and then Tweedie.
    let err = integrated_family_moments_jet(
        &ctx,
        &tweedie,
        LikelihoodScaleMetadata::Unspecified,
        e,
        se,
    )
    .expect_err("tweedie without a φ in the scale metadata must error");
    assert!(
        err.to_string()
            .contains("Tweedie integrated variance requires dispersion"),
        "unexpected error message: {err}"
    );
}
6624
/// At `0.0`, `cloglog_g_derivatives` must return
/// `(1 - e^-1, e^-1, 0, -e^-1, -e^-1)` to within 1e-14.
#[test]
fn cloglog_g_derivatives_at_zero() {
    let inv_e = (-1.0_f64).exp();
    let (value, first, second, third, fourth) = cloglog_g_derivatives(0.0);

    assert_relative_eq!(value, 1.0 - inv_e, epsilon = 1e-14);
    assert_relative_eq!(first, inv_e, epsilon = 1e-14);
    assert_relative_eq!(second, 0.0, epsilon = 1e-14);
    assert_relative_eq!(third, -inv_e, epsilon = 1e-14);
    assert_relative_eq!(fourth, -inv_e, epsilon = 1e-14);
}
6643
/// Tail behaviour of `cloglog_g_derivatives`: at `50.0` the value is 1 (to
/// 1e-10) with all four derivatives exactly zero; at `-50.0` the value and all
/// four derivatives equal `exp(-50)` to relative 1e-10.
#[test]
fn cloglog_g_derivatives_saturation() {
    // Upper tail.
    let (upper_value, u1, u2, u3, u4) = cloglog_g_derivatives(50.0);
    assert_relative_eq!(upper_value, 1.0, epsilon = 1e-10);
    for derivative in [u1, u2, u3, u4] {
        assert_eq!(derivative, 0.0);
    }

    // Lower tail.
    let tiny = (-50.0_f64).exp();
    let (lower_value, l1, l2, l3, l4) = cloglog_g_derivatives(-50.0);
    for entry in [lower_value, l1, l2, l3, l4] {
        assert_relative_eq!(entry, tiny, max_relative = 1e-10);
    }
}
6664
/// With `sigma = 0`, `cloglog_ghq_value` (21 points) must equal the first
/// entry of `cloglog_g_derivatives` at the same `mu`.
#[test]
fn cloglog_ghq_value_sigma_zero_matches_pointwise() {
    let ctx = QuadratureContext::new();
    let locations = [-2.0, -1.0, 0.0, 0.5, 1.5];
    for mu in locations {
        let integrated = cloglog_ghq_value(&ctx, mu, 0.0, 21);
        let pointwise = cloglog_g_derivatives(mu).0;
        assert_relative_eq!(integrated, pointwise, epsilon = 1e-14);
    }
}
6675
/// `cloglog_ghq_value` (31 points) must stay inside `[0, 1]` over a grid of
/// `mu` and `sigma` values.
#[test]
fn cloglog_ghq_value_bounded_zero_one() {
    let ctx = QuadratureContext::new();
    let locations = [-5.0, -2.0, 0.0, 1.0, 3.0, 10.0];
    let spreads = [0.1, 0.5, 1.0, 2.0, 5.0];
    for mu in locations {
        for sigma in spreads {
            let val = cloglog_ghq_value(&ctx, mu, sigma, 31);
            let within_unit_interval = (0.0..=1.0).contains(&val);
            assert!(within_unit_interval, "L({mu},{sigma}) = {val}");
        }
    }
}
6687
/// With `sigma = 0` at `mu = 0.3`, `cloglog_ghq_derivatives` must satisfy:
///
/// * pure mu-derivatives equal the pointwise derivatives `g, g', ..., g''''`;
/// * every derivative of odd order in sigma is exactly zero;
/// * second-order-in-sigma derivatives equal the pointwise derivative two
///   orders higher, and the fourth sigma-derivative equals `3 * g''''`.
#[test]
fn cloglog_ghq_derivatives_sigma_zero_matches_pointwise() {
    let ctx = QuadratureContext::new();
    let mu = 0.3;
    let d = cloglog_ghq_derivatives(&ctx, mu, 0.0, 21);
    let (g, g1, g2, g3, g4) = cloglog_g_derivatives(mu);

    let mu_only = [
        (d.l, g),
        (d.l_mu, g1),
        (d.l_mumu, g2),
        (d.l_mumumu, g3),
        (d.l_mumumumu, g4),
    ];
    for (actual, expected) in mu_only {
        assert_relative_eq!(actual, expected, epsilon = 1e-14);
    }

    let odd_in_sigma = [
        d.l_sigma,
        d.l_musigma,
        d.l_mumusigma,
        d.l_mumumusigma,
        d.l_sigmasigmasigma,
        d.l_musigmasigmasigma,
    ];
    for entry in odd_in_sigma {
        assert_eq!(entry, 0.0);
    }

    let even_in_sigma = [
        (d.l_sigmasigma, g2),
        (d.l_musigmasigma, g3),
        (d.l_mumusigmasigma, g4),
        (d.l_sigmasigmasigmasigma, 3.0 * g4),
    ];
    for (actual, expected) in even_in_sigma {
        assert_relative_eq!(actual, expected, epsilon = 1e-14);
    }
}
6715
/// Central finite differences in `mu` (step 1e-6) at `(0.5, 0.8)` with 31
/// points: the difference of values must match `l_mu`, and the difference of
/// `l_mu` must match `l_mumu`.
#[test]
fn cloglog_ghq_derivatives_finite_difference_mu() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (0.5, 0.8);
    let h = 1e-6;
    let analytic = cloglog_ghq_derivatives(&ctx, mu, sigma, 31);

    // First order: difference the value.
    let value_above = cloglog_ghq_value(&ctx, mu + h, sigma, 31);
    let value_below = cloglog_ghq_value(&ctx, mu - h, sigma, 31);
    let fd_first = (value_above - value_below) / (2.0 * h);
    assert_relative_eq!(analytic.l_mu, fd_first, epsilon = 1e-5);

    // Second order: difference the analytic first derivative.
    let jet_above = cloglog_ghq_derivatives(&ctx, mu + h, sigma, 31);
    let jet_below = cloglog_ghq_derivatives(&ctx, mu - h, sigma, 31);
    let fd_second = (jet_above.l_mu - jet_below.l_mu) / (2.0 * h);
    assert_relative_eq!(analytic.l_mumu, fd_second, epsilon = 1e-4);
}
6735
/// Central finite difference in `sigma` (step 1e-6) at `(0.2, 1.0)` with 31
/// points must match the analytic `l_sigma`.
#[test]
fn cloglog_ghq_derivatives_finite_difference_sigma() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (0.2, 1.0);
    let h = 1e-6;
    let analytic = cloglog_ghq_derivatives(&ctx, mu, sigma, 31);

    let value_above = cloglog_ghq_value(&ctx, mu, sigma + h, 31);
    let value_below = cloglog_ghq_value(&ctx, mu, sigma - h, 31);
    let fd_sigma = (value_above - value_below) / (2.0 * h);

    assert_relative_eq!(analytic.l_sigma, fd_sigma, epsilon = 1e-5);
}
6749
/// Central finite difference of `l_mu` in `sigma` (step 1e-6) at `(-0.5, 0.6)`
/// with 31 points must match the analytic mixed derivative `l_musigma`.
#[test]
fn cloglog_ghq_derivatives_finite_difference_cross() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (-0.5, 0.6);
    let h = 1e-6;
    let analytic = cloglog_ghq_derivatives(&ctx, mu, sigma, 31);

    let jet_above = cloglog_ghq_derivatives(&ctx, mu, sigma + h, 31);
    let jet_below = cloglog_ghq_derivatives(&ctx, mu, sigma - h, 31);
    let fd_cross = (jet_above.l_mu - jet_below.l_mu) / (2.0 * h);

    assert_relative_eq!(analytic.l_musigma, fd_cross, epsilon = 1e-4);
}
6763
/// `l_mu` from `cloglog_ghq_derivatives` (21 points) must be non-negative, up
/// to a 1e-14 rounding allowance, over a grid of `mu` and `sigma` values.
#[test]
fn cloglog_ghq_l_mu_nonnegative() {
    let ctx = QuadratureContext::new();
    let locations = [-3.0, -1.0, 0.0, 1.0, 3.0];
    let spreads = [0.1, 0.5, 1.0, 2.0];
    for mu in locations {
        for sigma in spreads {
            let d = cloglog_ghq_derivatives(&ctx, mu, sigma, 21);
            assert!(
                d.l_mu >= -1e-14,
                "L_mu should be non-negative at mu={mu}, sigma={sigma}: got {}",
                d.l_mu
            );
        }
    }
}
6779
/// The adaptive entry point must agree with an explicit call that uses the
/// point count returned by `adaptive_point_count_from_sd(sigma)`.
#[test]
fn cloglog_ghq_adaptive_matches_explicit() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (0.7, 1.2);

    let adaptive = cloglog_ghq_derivatives_adaptive(&ctx, mu, sigma);
    let point_count = adaptive_point_count_from_sd(sigma);
    let explicit = cloglog_ghq_derivatives(&ctx, mu, sigma, point_count);

    let pairs = [
        (adaptive.l, explicit.l),
        (adaptive.l_mu, explicit.l_mu),
        (adaptive.l_sigma, explicit.l_sigma),
        (adaptive.l_mumu, explicit.l_mumu),
    ];
    for (from_adaptive, from_explicit) in pairs {
        assert_relative_eq!(from_adaptive, from_explicit, epsilon = 1e-15);
    }
}
6793
/// `cloglog_ghq_value` with 51 points must match the mean returned by
/// `cloglog_reference_mean_and_derivative` over a small central grid.
#[test]
fn cloglog_ghq_value_matches_mathematical_target_in_central_regime() {
    let ctx = QuadratureContext::new();
    let locations = [-1.0, 0.0, 0.5, 2.0];
    let spreads = [0.1, 0.5, 1.0];
    for mu in locations {
        for sigma in spreads {
            let quadrature_mean = cloglog_ghq_value(&ctx, mu, sigma, 51);
            let (reference_mean, _) = cloglog_reference_mean_and_derivative(mu, sigma);
            assert_relative_eq!(
                quadrature_mean,
                reference_mean,
                epsilon = 1e-12,
                max_relative = 2e-8
            );
        }
    }
}
6805
/// At `eta = -30`, `cloglog_negative_tail_mean` must agree with
/// `-expm1(-exp(eta))`. The relative tolerance of 1e-26 is far below f64
/// resolution, so in practice the two values must coincide.
#[test]
fn cloglog_negative_tail_mean_matches_exact_near_transition() {
    let eta: f64 = -30.0;
    let exact = -(-eta.exp()).exp_m1();
    let tail = cloglog_negative_tail_mean(eta);

    let tolerance = 1e-26 * exact.abs().max(1e-300);
    assert!(
        (exact - tail).abs() < tolerance,
        "tail mean at η={eta}: exact={exact:.6e} tail={tail:.6e}"
    );
}
6824
6825 #[inline]
6826 fn cloglog_negative_tail_derivative(eta: f64) -> f64 {
6827 if eta < -745.0 {
6829 0.0
6830 } else {
6831 let ex = safe_exp(eta);
6832 (ex * (-ex).exp()).max(0.0)
6833 }
6834 }
6835
/// At `eta = -30`, `cloglog_negative_tail_derivative` must agree with
/// `exp(eta) * exp(-exp(eta))`. The relative tolerance of 1e-26 is far below
/// f64 resolution, so in practice the two values must coincide.
#[test]
fn cloglog_negative_tail_derivative_matches_exact_near_transition() {
    let eta: f64 = -30.0;
    let hazard = eta.exp();
    let exact = hazard * (-hazard).exp();
    let tail = cloglog_negative_tail_derivative(eta);

    let tolerance = 1e-26 * exact.abs().max(1e-300);
    assert!(
        (exact - tail).abs() < tolerance,
        "tail derivative at η={eta}: exact={exact:.6e} tail={tail:.6e}"
    );
}
6848
/// With `sigma = 0` and `mu` straddling `-30`, the controlled cloglog path
/// must match `cloglog_mean_exact` and `cloglog_mean_d1_exact` to relative 1e-15.
#[test]
fn cloglog_negative_tail_degenerate_branch_matches_target_near_transition() {
    let ctx = QuadratureContext::default();
    let sigma = 0.0;
    let locations = [-30.001, -30.0, -29.999];
    for mu in locations {
        let controlled = cloglog_posterior_meanwith_deriv_controlled(&ctx, mu, sigma);
        let targets = [
            (controlled.mean, cloglog_mean_exact(mu)),
            (controlled.dmean_dmu, cloglog_mean_d1_exact(mu)),
        ];
        for (actual, expected) in targets {
            assert_relative_eq!(actual, expected, epsilon = 1e-28, max_relative = 1e-15);
        }
    }
}
6869
/// With `sigma = 0.1` and `mu` straddling `-30`, the controlled cloglog path
/// must match `cloglog_reference_mean_and_derivative` to relative 1e-10.
#[test]
fn cloglog_negative_tail_small_sigma_branch_matches_target_near_transition() {
    let ctx = QuadratureContext::default();
    let sigma = 0.1;
    let locations = [-30.001, -30.0, -29.999];
    for mu in locations {
        let controlled = cloglog_posterior_meanwith_deriv_controlled(&ctx, mu, sigma);
        let (reference_mean, reference_slope) =
            cloglog_reference_mean_and_derivative(mu, sigma);
        let targets = [
            (controlled.mean, reference_mean),
            (controlled.dmean_dmu, reference_slope),
        ];
        for (actual, expected) in targets {
            assert_relative_eq!(actual, expected, epsilon = 1e-24, max_relative = 1e-10);
        }
    }
}
6891
/// Heap-allocated reference Cholesky factorisation with escalating diagonal jitter.
///
/// Tries to factor `cov` as given, then with `1e-12, 1e-11, ..., 1e-6` added
/// to every diagonal entry (eight attempts in total), and returns the
/// lower-triangular factor of the first attempt that succeeds. Entries above
/// the diagonal of the returned factor are zero.
///
/// Returns `None` when `cov` is empty or not square, or when every attempt
/// meets a pivot that is non-finite or not strictly positive.
fn ref_cholesky_heap(cov: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
    /// Single Cholesky–Banachiewicz pass over a square matrix; `None` on a bad pivot.
    fn factor_lower(matrix: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
        let dim = matrix.len();
        let mut lower = vec![vec![0.0_f64; dim]; dim];
        for row in 0..dim {
            for col in 0..=row {
                // Subtract the already-computed inner products in ascending order.
                let residual = (0..col).fold(matrix[row][col], |acc, k| {
                    acc - lower[row][k] * lower[col][k]
                });
                if row == col {
                    if !(residual.is_finite() && residual > 0.0) {
                        return None;
                    }
                    lower[row][col] = residual.sqrt();
                } else {
                    lower[row][col] = residual / lower[col][col];
                }
            }
        }
        Some(lower)
    }

    let dim = cov.len();
    let is_square = cov.iter().all(|row| row.len() == dim);
    if dim == 0 || !is_square {
        return None;
    }

    let mut work = cov.to_vec();
    for attempt in 0..8 {
        // Attempt 0 uses the raw matrix; each later attempt multiplies the jitter by ten.
        let jitter = match attempt {
            0 => 0.0,
            _ => 1e-12 * 10f64.powi(attempt - 1),
        };
        if jitter > 0.0 {
            // Always offset from the original diagonal so jitter does not accumulate.
            for (idx, row) in work.iter_mut().enumerate() {
                row[idx] = cov[idx][idx] + jitter;
            }
        }
        if let Some(lower) = factor_lower(&work) {
            return Some(lower);
        }
    }
    None
}
6940
/// For several 2x2 matrices, the stack-based `cholesky_static_with_jitter::<2>`
/// must agree bit-for-bit with the heap reference `ref_cholesky_heap`.
#[test]
fn cholesky_static_matches_heap_d2() {
    let cases: &[[[f64; 2]; 2]] = &[
        [[1.0, 0.0], [0.0, 1.0]],
        [[2.5, 0.3], [0.3, 0.75]],
        [[1.0, 0.9999], [0.9999, 1.0]],
        [[1e-10, 0.0], [0.0, 1e-10]],
        [[4.0, -1.5], [-1.5, 2.25]],
    ];
    for cov in cases {
        let stack = cholesky_static_with_jitter::<2>(cov).expect("stack cholesky");
        let rows_on_heap: Vec<Vec<f64>> = cov.iter().map(|row| row.to_vec()).collect();
        let heap = ref_cholesky_heap(&rows_on_heap).expect("heap cholesky");

        // Compare raw bit patterns so that even last-bit differences are caught.
        for (i, j) in (0..2).flat_map(|i| (0..2).map(move |j| (i, j))) {
            assert_eq!(
                stack[i][j].to_bits(),
                heap[i][j].to_bits(),
                "mismatch at ({i},{j}) for cov={cov:?}"
            );
        }
    }
}
6967
/// For several 3x3 matrices, the stack-based `cholesky_static_with_jitter::<3>`
/// must agree bit-for-bit with the heap reference `ref_cholesky_heap`.
#[test]
fn cholesky_static_matches_heap_d3() {
    let cases: &[[[f64; 3]; 3]] = &[
        [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
        [[2.0, 0.5, 0.1], [0.5, 1.5, -0.2], [0.1, -0.2, 0.8]],
        [[4.0, 1.0, 0.5], [1.0, 3.0, 0.25], [0.5, 0.25, 2.0]],
    ];
    for cov in cases {
        let stack = cholesky_static_with_jitter::<3>(cov).expect("stack cholesky");
        let rows_on_heap: Vec<Vec<f64>> = cov.iter().map(|row| row.to_vec()).collect();
        let heap = ref_cholesky_heap(&rows_on_heap).expect("heap cholesky");

        // Compare raw bit patterns so that even last-bit differences are caught.
        for (i, j) in (0..3).flat_map(|i| (0..3).map(move |j| (i, j))) {
            assert_eq!(
                stack[i][j].to_bits(),
                heap[i][j].to_bits(),
                "mismatch at ({i},{j}) for cov={cov:?}"
            );
        }
    }
}
6990
/// One-dimensional cases for `cholesky_static_with_jitter`: `[[2.25]]` factors
/// to exactly `1.5`, `[[-1e-13]]` still yields a factor, and `[[-1e3]]` does not.
#[test]
fn cholesky_static_d1() {
    let factor = cholesky_static_with_jitter::<1>(&[[2.25]]).expect("d=1");
    assert_eq!(factor[0][0], 1.5);

    let barely_negative = cholesky_static_with_jitter::<1>(&[[-1.0e-13]]);
    assert!(barely_negative.is_some());

    let strongly_negative = cholesky_static_with_jitter::<1>(&[[-1.0e3]]);
    assert!(strongly_negative.is_none());
}
7009}