1use faer::Side;
27use gam_terms::construction::CanonicalPenalty;
28use gam_solve::estimate::reml::FirthDenseOperator;
29use gam_solve::estimate::reml::penalty_logdet::PenaltyPseudologdet;
30use gam_solve::estimate::{
31 EstimationError, UnifiedFitResult, validate_explicit_dense_hessian_for_whitening,
32};
33use gam_linalg::faer_ndarray::{FaerCholesky, FaerEigh, fast_ata_into, fast_atv, fast_av_into};
34use gam_models::wiggle::monotone_wiggle_basis_with_derivative_order;
35use crate::gpu_polya_gamma::{PgSeed, PolyaGammaBatchInput};
36use gam_linalg::triangular::back_substitution_lower_transpose_guarded_into;
37use gam_linalg::matrix::DesignMatrix;
38use gam_solve::mixture_link::{
39 InverseLinkKernel, LinkParamPartials, inverse_link_jet_for_inverse_link, softmax_last_fixedzero,
40};
41use gam_problem::types::{
42 InverseLink, LikelihoodSpec, ResponseFamily, RhoPrior, StandardLink, is_valid_tweedie_power,
43};
44use general_mcmc::generic_hmc::HamiltonianTarget;
45pub use general_mcmc::generic_nuts::NUTSMassMatrixConfig;
46use general_mcmc::generic_nuts::{GenericNUTS, MassMatrixAdaptation};
47use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, Axis};
48use rand::{RngExt, SeedableRng, rngs::StdRng};
49use serde::{Deserialize, Serialize};
50use std::cell::RefCell;
51use std::fmt;
52use std::sync::{Arc, Mutex};
53
54#[inline]
59fn likelihood_spec_supports_firth(spec: &LikelihoodSpec) -> bool {
60 spec.supports_firth()
61}
62
63#[inline]
66fn likelihood_spec_jeffreys_link(spec: &LikelihoodSpec) -> Option<InverseLink> {
67 if likelihood_spec_supports_firth(spec) {
68 Some(spec.link.clone())
69 } else {
70 None
71 }
72}
73
/// Errors raised while configuring or evaluating the HMC/NUTS posterior targets in this
/// module. Every variant carries a human-readable `reason`, which is also its `Display` text.
#[derive(Debug, Clone)]
pub enum HmcError {
    /// An input array (penalty, Hessian, mode, offset, ...) contained NaN or ±Inf.
    NonFiniteState { reason: String },
    /// A data or configuration value lies outside its valid domain, e.g. a non-binary
    /// binomial response, a non-integer count, or an out-of-range Tweedie power.
    InvalidConfig { reason: String },
    /// Array lengths or shapes disagree (e.g. offset or linear-predictor length vs. rows).
    DimensionMismatch { reason: String },
    /// A Firth/Jeffreys correction was requested for a likelihood that cannot provide it.
    FirthUnsupported { reason: String },
    /// NOTE(review): not constructed in this section; presumably a link that does not match
    /// the family or configuration — confirm at its construction sites.
    LinkMismatch { reason: String },
    /// The response family / link combination is not implemented for this sampler path.
    UnsupportedFamily { reason: String },
    /// A numerical step needed to evaluate the target failed (e.g. building the Firth
    /// operator).
    SamplingFailed { reason: String },
}
102
103impl fmt::Display for HmcError {
104 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105 match self {
106 HmcError::NonFiniteState { reason }
107 | HmcError::InvalidConfig { reason }
108 | HmcError::DimensionMismatch { reason }
109 | HmcError::FirthUnsupported { reason }
110 | HmcError::LinkMismatch { reason }
111 | HmcError::UnsupportedFamily { reason }
112 | HmcError::SamplingFailed { reason } => f.write_str(reason),
113 }
114 }
115}
116
117impl From<HmcError> for String {
118 fn from(err: HmcError) -> String {
119 err.to_string()
120 }
121}
122
/// Upper bound on the autocorrelation lag examined by the split-chain ESS estimator.
const MAX_AUTOCORRELATION_LAG: usize = 1000;

/// Lower bound applied to each split chain's lag-0 autocovariance before it is used as a
/// normalizer, so a constant chain does not divide by zero.
const AUTOCOVARIANCE_FLOOR: f64 = 1e-16;
133
134fn compute_split_rhat_and_ess(samples: &Array3<f64>) -> (f64, f64) {
139 let n_chains = samples.shape()[0];
140 let n_samples = samples.shape()[1];
141 let dim = samples.shape()[2];
142
143 if n_chains < 2 || n_samples < 4 {
144 return (1.0, n_chains as f64 * n_samples as f64 * 0.5);
145 }
146
147 let half = n_samples / 2;
149 let n_split_chains = n_chains * 2;
150 let n_split_samples = half;
151
152 let mut max_rhat = 0.0f64;
153 let mut min_ess = f64::INFINITY;
154
155 #[inline]
156 fn splitvalue(
157 samples: &Array3<f64>,
158 n_chains: usize,
159 half: usize,
160 dim: usize,
161 sc: usize,
162 t: usize,
163 ) -> f64 {
164 let chain = sc % n_chains;
165 if sc < n_chains {
166 samples[[chain, t, dim]]
167 } else {
168 samples[[chain, half + t, dim]]
169 }
170 }
171
172 fn ess_from_split_dimension(
173 samples: &Array3<f64>,
174 n_chains: usize,
175 half: usize,
176 dim: usize,
177 ) -> f64 {
178 let m = n_chains * 2;
179 let n = half;
180 if m == 0 || n < 4 {
181 return (m * n).max(1) as f64;
182 }
183
184 let mut means = vec![0.0_f64; m];
185 let mut gamma0 = vec![0.0_f64; m];
186 for sc in 0..m {
187 let mut sum = 0.0;
188 for t in 0..n {
189 sum += splitvalue(samples, n_chains, half, dim, sc, t);
190 }
191 let mean = sum / n as f64;
192 means[sc] = mean;
193 let mut g0 = 0.0;
194 for t in 0..n {
195 let d = splitvalue(samples, n_chains, half, dim, sc, t) - mean;
196 g0 += d * d;
197 }
198 gamma0[sc] = (g0 / n as f64).max(AUTOCOVARIANCE_FLOOR);
199 }
200
201 let max_lag = (n - 1).min(MAX_AUTOCORRELATION_LAG);
202 let mut tau = 1.0_f64;
203 let mut lag = 1usize;
204 while lag < max_lag {
205 let mut pair = 0.0_f64;
206 for l in [lag, lag + 1] {
207 if l > max_lag {
208 continue;
209 }
210 let mut rho_l = 0.0;
211 for sc in 0..m {
212 let mu = means[sc];
213 let mut cov = 0.0;
214 let denom = (n - l) as f64;
215 for t in 0..(n - l) {
216 let x0 = splitvalue(samples, n_chains, half, dim, sc, t);
217 let x1 = splitvalue(samples, n_chains, half, dim, sc, t + l);
218 cov += (x0 - mu) * (x1 - mu);
219 }
220 cov /= denom;
221 rho_l += cov / gamma0[sc];
222 }
223 rho_l /= m as f64;
224 pair += rho_l;
225 }
226 if !pair.is_finite() || pair <= 0.0 {
227 break;
228 }
229 tau += 2.0 * pair;
230 lag += 2;
231 }
232 if !tau.is_finite() || tau <= 0.0 {
233 return 1.0;
234 }
235 let total = (m * n) as f64;
236 (total / tau).clamp(1.0, total)
237 }
238
239 let mut chain_means = vec![0.0_f64; n_split_chains];
240 let mut chainvars = vec![0.0_f64; n_split_chains];
241 for d in 0..dim {
242 for chain in 0..n_chains {
243 let mut sum1 = 0.0;
245 for i in 0..half {
246 sum1 += samples[[chain, i, d]];
247 }
248 let mean1 = sum1 / half as f64;
249 let mut var1 = 0.0;
250 for i in 0..half {
251 let diff = samples[[chain, i, d]] - mean1;
252 var1 += diff * diff;
253 }
254 var1 /= (half - 1).max(1) as f64;
255 let first_idx = chain;
256 chain_means[first_idx] = mean1;
257 chainvars[first_idx] = var1;
258
259 let mut sum2 = 0.0;
261 for i in half..(2 * half) {
262 sum2 += samples[[chain, i, d]];
263 }
264 let mean2 = sum2 / half as f64;
265 let mut var2 = 0.0;
266 for i in half..(2 * half) {
267 let diff = samples[[chain, i, d]] - mean2;
268 var2 += diff * diff;
269 }
270 var2 /= (half - 1).max(1) as f64;
271 let second_idx = n_chains + chain;
272 chain_means[second_idx] = mean2;
273 chainvars[second_idx] = var2;
274 }
275
276 let w: f64 = chainvars.iter().copied().sum::<f64>() / n_split_chains as f64;
278
279 let overall_mean: f64 = chain_means.iter().copied().sum::<f64>() / n_split_chains as f64;
281 let b: f64 = chain_means
282 .iter()
283 .map(|m| (m - overall_mean).powi(2))
284 .sum::<f64>()
285 * n_split_samples as f64
286 / (n_split_chains - 1) as f64;
287
288 let var_hat = (n_split_samples as f64 - 1.0) / n_split_samples as f64 * w
290 + b / n_split_samples as f64;
291
292 let rhat_d = if w > 1e-10 { (var_hat / w).sqrt() } else { 1.0 };
294 max_rhat = max_rhat.max(rhat_d);
295
296 let ess_d = ess_from_split_dimension(samples, n_chains, half, d);
298 min_ess = min_ess.min(ess_d);
299 }
300
301 (max_rhat, min_ess.max(1.0))
302}
303
304fn solve_upper_triangular_transpose(l: &Array2<f64>, dim: usize) -> Array2<f64> {
325 let mut result = Array2::<f64>::zeros((dim, dim));
326 if dim == 0 {
327 return result;
328 }
329
330 let l_owned;
334 let l_rows: &[f64] = if let Some(s) = l.as_slice() {
335 s
336 } else {
337 l_owned = l.to_owned();
338 l_owned
339 .as_slice()
340 .expect("owned standard-layout Array2 has contiguous storage")
341 };
342
343 let mut y = vec![0.0_f64; dim];
345
346 for col in 0..dim {
347 let d_col = l_rows[col * dim + col];
350 let inv_d_col = if d_col.abs() > 1e-15 {
351 1.0 / d_col
352 } else {
353 0.0
354 };
355 y[col] = inv_d_col;
356
357 for i in (col + 1)..dim {
360 let row_off = i * dim;
361 let l_row = &l_rows[row_off + col..row_off + i];
362 let y_seg = &y[col..i];
363 let mut sum = 0.0_f64;
367 for k in 0..l_row.len() {
368 sum += l_row[k] * y_seg[k];
369 }
370 let d = l_rows[row_off + i];
371 y[i] = if d.abs() > 1e-15 { -sum / d } else { 0.0 };
372 }
373
374 let res_row_start = col * dim + col;
378 let res_row = &mut result.as_slice_mut().expect("owned Array2 contiguous")
379 [res_row_start..res_row_start + (dim - col)];
380 for (k, slot) in res_row.iter_mut().enumerate() {
381 *slot = y[col + k];
382 }
383
384 for slot in &mut y[col..dim] {
386 *slot = 0.0;
387 }
388 }
389
390 result
391}
392
/// Output of `hessian_whitening_transform`: a whitening matrix and its owned transpose.
struct WhiteningTransform {
    /// Whitening matrix `C = sqrt(cov_scale) · L⁻ᵀ`, where `hessian = L Lᵀ` (upper triangular).
    chol: Array2<f64>,
    /// Owned transpose `Cᵀ`.
    chol_t: Array2<f64>,
}
397
398fn hessian_whitening_transform(
399 hessian: ArrayView2<f64>,
400 dim: usize,
401 cov_scale: f64,
402 cholesky_error_prefix: &str,
403) -> Result<WhiteningTransform, String> {
404 let hessian_owned = hessian.to_owned();
405 let chol_factor = hessian_owned
406 .cholesky(Side::Lower)
407 .map_err(|e| format!("{cholesky_error_prefix}: {:?}", e))?;
408 let l_h = chol_factor.lower_triangular();
409 let mut chol = solve_upper_triangular_transpose(&l_h, dim);
410 let sqrt_cov_scale = cov_scale.max(0.0).sqrt();
411 if (sqrt_cov_scale - 1.0).abs() > 0.0 {
412 chol.mapv_inplace(|v| v * sqrt_cov_scale);
413 }
414 let chol_t = chol.t().to_owned();
415 Ok(WhiteningTransform { chol, chol_t })
416}
417
/// Observation data shared by the posterior targets in this module. Array fields live
/// behind `Arc`, so cloning is cheap.
#[derive(Clone)]
struct SharedData {
    /// Design matrix, one row per observation (`n_samples × dim`).
    x: Arc<Array2<f64>>,
    /// Response values.
    y: Arc<Array1<f64>>,
    /// Per-observation prior weights.
    weights: Arc<Array1<f64>>,
    /// Coefficient vector around which the sampler state is whitened.
    mode: Arc<Array1<f64>>,
    /// Optional per-observation offset added to the linear predictor.
    offset: Option<Arc<Array1<f64>>>,
    /// Family-specific scalar. `nuts_family_logp_and_grad_into` passes it as the Tweedie
    /// power and as the negative-binomial theta. NOTE(review): presumably also the Gamma
    /// shape — confirm in the Gamma kernel.
    gamma_shape: f64,
    /// Dispersion; the Gaussian kernel reads `inv_phi()` from it.
    dispersion: gam_solve::model_types::Dispersion,
    /// Number of observations (rows of `x`).
    n_samples: usize,
    /// Number of coefficients (columns of `x`).
    dim: usize,
}
458
thread_local! {
    /// Per-thread residual scratch buffer, initially empty.
    /// NOTE(review): its users are not visible in this section. Presumably it is resized to
    /// `n_samples` and passed as the `residual` argument of `compute_logp_and_grad_nd_into`
    /// to avoid a per-evaluation allocation — confirm.
    static NUTS_RESIDUAL_SCRATCH: RefCell<Array1<f64>> = RefCell::new(Array1::zeros(0));
}
462
/// Observation model (response family + link) supported by [`NutsPosterior`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NutsFamily {
    /// Gaussian response, identity link.
    Gaussian,
    /// Bernoulli response, logit link.
    BinomialLogit,
    /// Bernoulli response, probit link.
    BinomialProbit,
    /// Bernoulli response, complementary log-log link.
    BinomialCLogLog,
    /// Poisson response, log link.
    PoissonLog,
    /// Tweedie (quasi-)likelihood, log link; the power is taken from the shared data.
    TweedieLog,
    /// Negative-binomial response, log link; theta is taken from the shared data.
    NegativeBinomialLog,
    /// Gamma response, log link.
    GammaLog,
}
479
480impl NutsFamily {
481 #[inline]
482 fn likelihood_spec(self) -> LikelihoodSpec {
483 match self {
484 Self::Gaussian => LikelihoodSpec {
485 response: ResponseFamily::Gaussian,
486 link: InverseLink::Standard(StandardLink::Identity),
487 },
488 Self::BinomialLogit => LikelihoodSpec {
489 response: ResponseFamily::Binomial,
490 link: InverseLink::Standard(StandardLink::Logit),
491 },
492 Self::BinomialProbit => LikelihoodSpec {
493 response: ResponseFamily::Binomial,
494 link: InverseLink::Standard(StandardLink::Probit),
495 },
496 Self::BinomialCLogLog => LikelihoodSpec {
497 response: ResponseFamily::Binomial,
498 link: InverseLink::Standard(StandardLink::CLogLog),
499 },
500 Self::PoissonLog => LikelihoodSpec {
501 response: ResponseFamily::Poisson,
502 link: InverseLink::Standard(StandardLink::Log),
503 },
504 Self::TweedieLog => LikelihoodSpec {
505 response: ResponseFamily::Tweedie { p: 1.5 },
506 link: InverseLink::Standard(StandardLink::Log),
507 },
508 Self::NegativeBinomialLog => LikelihoodSpec {
509 response: ResponseFamily::NegativeBinomial {
510 theta: 1.0,
511 theta_fixed: false,
512 },
513 link: InverseLink::Standard(StandardLink::Log),
514 },
515 Self::GammaLog => LikelihoodSpec {
516 response: ResponseFamily::Gamma,
517 link: InverseLink::Standard(StandardLink::Log),
518 },
519 }
520 }
521
522 #[inline]
552 fn coefficient_covariance_scale(self, profiled_gaussian_phi: f64) -> f64 {
553 match self {
554 NutsFamily::Gaussian => profiled_gaussian_phi,
555 _ => 1.0,
556 }
557 }
558}
559
/// Whitened NUTS posterior over regression coefficients for a single [`NutsFamily`].
///
/// The sampler state is `z`; coefficients are `beta = mode + chol · z`. The quadratic
/// penalty `½ betaᵀ S beta` is stored pre-expanded in `z`-coordinates.
pub struct NutsPosterior {
    /// Shared observation data and the whitening centre (`mode`).
    data: SharedData,
    /// Whitening matrix `C` mapping the sampler state to a coefficient offset.
    chol: Array2<f64>,
    /// `Cᵀ`, kept owned to map coefficient-space gradients back to `z`.
    chol_t: Array2<f64>,
    /// Observation model used for the log-likelihood.
    nuts_family: NutsFamily,
    /// Whether the Jeffreys/Firth log-determinant term is added to the target.
    firth_enabled: bool,
    /// `Cᵀ S C`: quadratic coefficient of the penalty in `z`.
    penalty_z_quad: Array2<f64>,
    /// `Cᵀ S mode`: linear coefficient of the penalty in `z`.
    penalty_z_lin: Array1<f64>,
    /// `½ modeᵀ S mode`: constant part of the penalty.
    penalty_z_const: f64,
    /// Coefficient-covariance scale (the profiled dispersion for Gaussian, `1.0` otherwise).
    /// The penalty is divided by it.
    cov_scale: f64,
}
603
604impl NutsPosterior {
605 pub fn new(
620 x: ArrayView2<f64>,
621 y: ArrayView1<f64>,
622 weights: ArrayView1<f64>,
623 penalty_matrix: ArrayView2<f64>,
624 mode: ArrayView1<f64>,
625 hessian: ArrayView2<f64>,
626 nuts_family: NutsFamily,
627 gamma_shape: f64,
628 dispersion: gam_solve::model_types::Dispersion,
629 firth_enabled: bool,
630 ) -> Result<Self, String> {
631 let n_samples = x.nrows();
632 let dim = x.ncols();
633
634 if !penalty_matrix.iter().all(|x| x.is_finite()) {
636 return Err(HmcError::NonFiniteState {
637 reason: "Penalty matrix contains NaN or Inf values".to_string(),
638 }
639 .into());
640 }
641 if !hessian.iter().all(|x| x.is_finite()) {
642 return Err(HmcError::NonFiniteState {
643 reason: "Hessian matrix contains NaN or Inf values".to_string(),
644 }
645 .into());
646 }
647 if !mode.iter().all(|x| x.is_finite()) {
648 return Err(HmcError::NonFiniteState {
649 reason: "Mode vector contains NaN or Inf values".to_string(),
650 }
651 .into());
652 }
653
654 validate_firth_support(nuts_family, firth_enabled).map_err(String::from)?;
655 if nuts_family.likelihood_spec().is_binomial() {
656 validate_binary_responses("binomial NUTS", &y, &weights).map_err(String::from)?;
657 }
658 if matches!(nuts_family, NutsFamily::NegativeBinomialLog) {
659 validate_count_responses("negative-binomial NUTS", &y, &weights)
660 .map_err(String::from)?;
661 }
662
663 let cov_scale = nuts_family.coefficient_covariance_scale(dispersion.phi());
672 let whitening = hessian_whitening_transform(
673 hessian,
674 dim,
675 cov_scale,
676 "Hessian Cholesky decomposition failed",
677 )?;
678 let chol = whitening.chol;
679 let chol_t = whitening.chol_t;
680
681 let penalty_owned = penalty_matrix.to_owned();
689 let mode_owned = mode.to_owned();
690 let s_mu = penalty_owned.dot(&mode_owned);
691 let penalty_z_const = 0.5 * mode_owned.dot(&s_mu);
692 let penalty_z_lin = chol_t.dot(&s_mu);
693 let s_chol = penalty_owned.dot(&chol);
696 let penalty_z_quad = chol_t.dot(&s_chol);
697
698 let data = SharedData {
699 x: Arc::new(x.to_owned()),
700 y: Arc::new(y.to_owned()),
701 weights: Arc::new(weights.to_owned()),
702 mode: Arc::new(mode_owned),
703 offset: None,
704 gamma_shape,
705 dispersion,
706 n_samples,
707 dim,
708 };
709
710 Ok(Self {
711 data,
712 chol,
713 chol_t,
714 nuts_family,
715 firth_enabled,
716 penalty_z_quad,
717 penalty_z_lin,
718 penalty_z_const,
719 cov_scale,
720 })
721 }
722
723 fn with_offset(mut self, offset: ArrayView1<f64>) -> Result<Self, String> {
733 if offset.len() != self.data.n_samples {
734 return Err(HmcError::DimensionMismatch {
735 reason: format!(
736 "NUTS offset length {} does not match {} observations",
737 offset.len(),
738 self.data.n_samples
739 ),
740 }
741 .into());
742 }
743 if !offset.iter().all(|v| v.is_finite()) {
744 return Err(HmcError::NonFiniteState {
745 reason: "NUTS offset contains NaN or Inf values".to_string(),
746 }
747 .into());
748 }
749 self.data.offset = Some(Arc::new(offset.to_owned()));
750 Ok(self)
751 }
752
753 fn compute_logp_and_grad_nd_into(
754 &self,
755 z: &Array1<f64>,
756 residual: &mut Array1<f64>,
757 grad: &mut Array1<f64>,
758 ) -> f64 {
759 let beta = self.data.mode.as_ref() + &self.chol.dot(z);
762
763 let mut eta = gam_linalg::faer_ndarray::fast_av(self.data.x.as_ref(), &beta);
765 if let Some(offset) = self.data.offset.as_ref() {
766 eta += offset.as_ref();
767 }
768
769 let (ll, mut grad_ll_beta) = self.family_logp_and_grad_into(&eta, residual);
771
772 let mut firth_logdet = 0.0;
773 if self.firth_enabled {
774 match firth_jeffreys_logp_and_grad(self.nuts_family, &self.data, &eta) {
775 Ok((value, grad_beta_firth)) => {
776 firth_logdet = value;
777 grad_ll_beta += &grad_beta_firth;
778 }
779 Err(err) => {
780 log::warn!(
781 "[NUTS/Firth] Jeffreys target became invalid at the current state: {}",
782 err
783 );
784 grad.fill(0.0);
785 return f64::NEG_INFINITY;
786 }
787 }
788 }
789
790 let penalty_scale = 1.0 / self.cov_scale.max(1e-300);
814 let mz = self.penalty_z_quad.dot(z);
815 let lin_term = self.penalty_z_lin.dot(z);
816 let quad_term = 0.5 * z.dot(&mz);
817 let penalty = penalty_scale * (self.penalty_z_const + lin_term + quad_term);
818
819 fast_av_into(&self.chol_t, &grad_ll_beta, grad);
822 let lin_view = self.penalty_z_lin.view();
824 ndarray::Zip::from(grad)
825 .and(&lin_view)
826 .and(&mz)
827 .par_for_each(|g, &l, &m| {
828 *g -= penalty_scale * (l + m);
829 });
830
831 ll + firth_logdet - penalty
832 }
833
834 fn family_logp_and_grad_into(
835 &self,
836 eta: &Array1<f64>,
837 residual: &mut Array1<f64>,
838 ) -> (f64, Array1<f64>) {
839 nuts_family_logp_and_grad_into(self.nuts_family, &self.data, eta, residual)
840 }
841
842 pub fn chol(&self) -> &Array2<f64> {
844 &self.chol
845 }
846
847 pub fn mode(&self) -> &Array1<f64> {
849 &self.data.mode
850 }
851
852 pub fn dim(&self) -> usize {
854 self.data.dim
855 }
856}
857
/// `½ · ln(2π)`: the log normalizing constant of the standard normal density.
const HALF_LOG_2PI: f64 = 0.918_938_533_204_672_7;

/// Log-density of the standard normal distribution at `x`: `-x²/2 - ½ ln(2π)`.
#[inline]
fn standard_normal_log_pdf(x: f64) -> f64 {
    let half_square = 0.5 * x * x;
    -half_square - HALF_LOG_2PI
}
864
865#[inline]
867fn log_ndtr(x: f64) -> f64 {
868 let arg = -x * std::f64::consts::FRAC_1_SQRT_2;
869 let erfc_val = statrs::function::erf::erfc(arg);
870 if erfc_val > 0.0 {
871 erfc_val.ln() - std::f64::consts::LN_2
872 } else {
873 -0.5 * x * x - (-x).ln() - HALF_LOG_2PI
874 }
875}
876
877#[inline]
878fn validate_firth_support(family: NutsFamily, firth_enabled: bool) -> Result<(), HmcError> {
879 let spec = family.likelihood_spec();
880 if firth_enabled && !likelihood_spec_supports_firth(&spec) {
881 return Err(HmcError::FirthUnsupported {
882 reason: format!(
883 "NUTS with Firth requires a Binomial inverse link with a Fisher-weight jet; {} does not support it",
884 spec.pretty_name()
885 ),
886 });
887 }
888 Ok::<(), _>(())
889}
890
891#[inline]
892fn validate_firth_likelihood_support(
893 likelihood: &LikelihoodSpec,
894 firth_enabled: bool,
895) -> Result<(), HmcError> {
896 if firth_enabled && !likelihood_spec_supports_firth(likelihood) {
897 return Err(HmcError::FirthUnsupported {
898 reason: format!(
899 "Joint HMC with Firth requires a Binomial inverse link with a Fisher-weight jet; {} does not support it",
900 likelihood.pretty_name()
901 ),
902 });
903 }
904 Ok::<(), _>(())
905}
906
/// Returns `true` when `y` is a finite, non-negative value within `1e-9` of an integer.
#[inline]
fn valid_count_response(y: f64) -> bool {
    if !y.is_finite() || y < 0.0 {
        return false;
    }
    let distance_to_integer = (y - y.round()).abs();
    distance_to_integer <= 1e-9
}
911
912fn validate_count_responses(
913 family: &str,
914 y: &ArrayView1<'_, f64>,
915 weights: &ArrayView1<'_, f64>,
916) -> Result<(), HmcError> {
917 for (i, (&yi, &wi)) in y.iter().zip(weights.iter()).enumerate() {
918 if wi > 0.0 && !valid_count_response(yi) {
919 return Err(HmcError::InvalidConfig {
920 reason: format!(
921 "{family} response must be a finite non-negative integer at positive-weight row {i}; got {yi}"
922 ),
923 });
924 }
925 }
926 Ok(())
927}
928
929fn validate_binary_responses(
930 family: &str,
931 y: &ArrayView1<'_, f64>,
932 weights: &ArrayView1<'_, f64>,
933) -> Result<(), HmcError> {
934 for (i, (&yi, &wi)) in y.iter().zip(weights.iter()).enumerate() {
935 if wi > 0.0 && !(yi == 0.0 || yi == 1.0) {
936 return Err(HmcError::InvalidConfig {
937 reason: format!(
938 "{family} response must be exactly 0 or 1 at positive-weight row {i}; got {yi}"
939 ),
940 });
941 }
942 }
943 Ok(())
944}
945
946fn firth_jeffreys_logp_and_grad(
953 family: NutsFamily,
954 data: &SharedData,
955 eta: &Array1<f64>,
956) -> Result<(f64, Array1<f64>), HmcError> {
957 if eta.len() != data.n_samples {
958 return Err(HmcError::DimensionMismatch {
959 reason: format!(
960 "Firth Jeffreys term eta length {} != number of samples {}",
961 eta.len(),
962 data.n_samples
963 ),
964 });
965 }
966 if data.dim == 0 || data.n_samples == 0 {
967 return Ok((0.0, Array1::zeros(data.dim)));
968 }
969 validate_firth_support(family, true)?;
970 if data.weights.iter().all(|w| *w == 0.0) {
971 return Ok((0.0, Array1::zeros(data.dim)));
972 }
973
974 let jeffreys_link =
975 likelihood_spec_jeffreys_link(&family.likelihood_spec()).ok_or_else(|| {
976 HmcError::FirthUnsupported {
977 reason: format!(
978 "Firth Jeffreys term has no Fisher-weight jet for {}",
979 family.likelihood_spec().pretty_name()
980 ),
981 }
982 })?;
983 let op = if data.weights.iter().all(|&w| w == 1.0) {
984 FirthDenseOperator::build_for_link(&jeffreys_link, data.x.as_ref(), eta)
985 } else {
986 FirthDenseOperator::build_with_observation_weights_for_link(
987 &jeffreys_link,
988 data.x.as_ref(),
989 eta,
990 data.weights.view(),
991 )
992 }
993 .map_err(|e| HmcError::SamplingFailed {
994 reason: format!("Firth Jeffreys operator failed: {e}"),
995 })?;
996 Ok(op.jeffreys_logdet_and_beta_gradient())
997}
998
999fn nuts_family_logp_and_grad_into(
1008 family: NutsFamily,
1009 data: &SharedData,
1010 eta: &Array1<f64>,
1011 residual: &mut Array1<f64>,
1012) -> (f64, Array1<f64>) {
1013 match family {
1014 NutsFamily::BinomialLogit => logit_logp_and_grad_into(data, eta, residual),
1015 NutsFamily::BinomialProbit => probit_logp_and_grad_into(data, eta, residual),
1016 NutsFamily::BinomialCLogLog => cloglog_logp_and_grad_into(data, eta, residual),
1017 NutsFamily::Gaussian => gaussian_logp_and_grad_into(data, eta, residual),
1018 NutsFamily::PoissonLog => poisson_log_logp_and_grad(data, eta),
1019 NutsFamily::TweedieLog => tweedie_log_quasilogp_and_grad(data, eta, data.gamma_shape),
1022 NutsFamily::NegativeBinomialLog => {
1023 negative_binomial_log_logp_and_grad(data, eta, data.gamma_shape)
1026 }
1027 NutsFamily::GammaLog => gamma_log_logp_and_grad(data, eta),
1028 }
1029}
1030
/// Bernoulli log-probabilities and their derivatives for a binomial inverse link evaluated
/// at one linear-predictor value.
#[derive(Clone, Debug)]
struct BinomialLinkTerms {
    /// `ln(mu)`, or `-inf` when `mu == 0`.
    log_mu: f64,
    /// `ln(1 - mu)`, or `-inf` when `mu == 1`.
    log1m_mu: f64,
    /// `d ln(mu) / d eta`.
    dlog_mu_deta: f64,
    /// `d ln(1 - mu) / d eta`.
    dlog1m_mu_deta: f64,
    /// `d mu / d theta_k` for each adaptive link parameter (empty when there are none).
    dmu_dlink: Vec<f64>,
}

/// Assembles [`BinomialLinkTerms`] from `mu = g⁻¹(eta)` and `dmu/deta`.
///
/// At the boundaries the log-derivatives are returned as their signed limits. At `mu == 0`,
/// `dmu/mu` diverges with the sign of `dmu_deta`. At `mu == 1`, `-dmu/(1 - mu)` diverges with
/// the *opposite* sign, because `ln(1 - mu)` falls as `mu` rises.
///
/// # Errors
/// Returns a message when `mu` is not a finite value in `[0, 1]` or `dmu_deta` is not finite.
#[inline]
fn log_terms_from_mu_and_dmu(
    mu: f64,
    dmu_deta: f64,
    dmu_dlink: Vec<f64>,
) -> Result<BinomialLinkTerms, String> {
    if !(mu.is_finite() && (0.0..=1.0).contains(&mu) && dmu_deta.is_finite()) {
        return Err(format!(
            "binomial inverse link returned invalid mu/deta derivative: mu={mu}, dmu_deta={dmu_deta}"
        ));
    }
    let log_mu = if mu == 0.0 {
        f64::NEG_INFINITY
    } else {
        mu.ln()
    };
    let one_minus_mu = 1.0 - mu;
    let log1m_mu = if one_minus_mu == 0.0 {
        f64::NEG_INFINITY
    } else {
        one_minus_mu.ln()
    };
    let dlog_mu_deta = if mu == 0.0 {
        f64::INFINITY.copysign(dmu_deta)
    } else {
        dmu_deta / mu
    };
    // `copysign` keeps only the receiver's magnitude, so the sign flip must be applied to
    // the argument. `NEG_INFINITY.copysign(dmu_deta)` would yield +inf for an increasing link.
    let dlog1m_mu_deta = if one_minus_mu == 0.0 {
        f64::INFINITY.copysign(-dmu_deta)
    } else {
        -dmu_deta / one_minus_mu
    };
    Ok(BinomialLinkTerms {
        log_mu,
        log1m_mu,
        dlog_mu_deta,
        dlog1m_mu_deta,
        dmu_dlink,
    })
}
1080
1081#[inline]
1082fn binomial_link_terms(
1083 inverse_link: &InverseLink,
1084 eta: f64,
1085 n_link_params: usize,
1086) -> Result<BinomialLinkTerms, String> {
1087 let jet =
1088 inverse_link_jet_for_inverse_link(inverse_link, eta).map_err(|err| err.to_string())?;
1089 let mut dmu_dlink = vec![0.0; n_link_params];
1090 if n_link_params > 0 {
1091 match inverse_link
1092 .param_partials(eta)
1093 .map_err(|err| err.to_string())?
1094 {
1095 Some(LinkParamPartials::Sas(partials)) => {
1096 if n_link_params != 2 {
1097 return Err(format!(
1098 "SAS/Beta-Logistic link parameter dimension mismatch: expected 2, got {n_link_params}"
1099 ));
1100 }
1101 dmu_dlink[0] = partials.djet_depsilon.mu;
1102 dmu_dlink[1] = partials.djet_dlog_delta.mu;
1103 }
1104 Some(LinkParamPartials::Mixture(partials)) => {
1105 if partials.djet_drho.len() != n_link_params {
1106 return Err(format!(
1107 "mixture link parameter dimension mismatch: expected {}, got {n_link_params}",
1108 partials.djet_drho.len()
1109 ));
1110 }
1111 for (slot, partial) in dmu_dlink.iter_mut().zip(partials.djet_drho.iter()) {
1112 *slot = partial.mu;
1113 }
1114 }
1115 None => {
1116 return Err(format!(
1117 "joint HMC expected {n_link_params} adaptive link parameters, but the inverse link exposes none"
1118 ));
1119 }
1120 }
1121 }
1122 log_terms_from_mu_and_dmu(jet.mu, jet.d1, dmu_dlink)
1123}
1124
1125fn joint_binomial_logp_grad_and_link_grad(
1126 inverse_link: &InverseLink,
1127 data: &SharedData,
1128 eta: &Array1<f64>,
1129 n_link_params: usize,
1130) -> Result<(f64, Array1<f64>, Array1<f64>), String> {
1131 let n = data.n_samples;
1132 use rayon::iter::{IntoParallelIterator, ParallelIterator};
1136 let per_row: Result<Vec<(f64, f64, Vec<f64>)>, String> = (0..n)
1137 .into_par_iter()
1138 .map(|i| {
1139 let y_i = data.y[i];
1140 let w_i = data.weights[i];
1141 if w_i <= 0.0 {
1142 return Ok((0.0, 0.0, vec![0.0; n_link_params]));
1143 }
1144 let terms = binomial_link_terms(inverse_link, eta[i], n_link_params)?;
1145 if y_i == 1.0 {
1146 let inv_mu = terms.log_mu.exp().recip();
1147 let log_mu = terms.log_mu;
1148 let dlog_mu_deta = terms.dlog_mu_deta;
1149 let grad_link = terms
1150 .dmu_dlink
1151 .into_iter()
1152 .map(|dmu| w_i * dmu * inv_mu)
1153 .collect();
1154 Ok((w_i * log_mu, w_i * dlog_mu_deta, grad_link))
1155 } else if y_i == 0.0 {
1156 let inv_one_minus_mu = terms.log1m_mu.exp().recip();
1157 let log1m_mu = terms.log1m_mu;
1158 let dlog1m_mu_deta = terms.dlog1m_mu_deta;
1159 let grad_link = terms
1160 .dmu_dlink
1161 .into_iter()
1162 .map(|dmu| -w_i * dmu * inv_one_minus_mu)
1163 .collect();
1164 Ok((w_i * log1m_mu, w_i * dlog1m_mu_deta, grad_link))
1165 } else {
1166 Err(format!(
1167 "binomial joint HMC response must be exactly 0 or 1 after validation; got {y_i}"
1168 ))
1169 }
1170 })
1171 .collect();
1172 let per_row = per_row?;
1173 let mut residual = Array1::<f64>::zeros(n);
1174 let mut grad_link = Array1::<f64>::zeros(n_link_params);
1175 let mut ll = 0.0;
1176 for (i, (ll_i, residual_i, grad_link_i)) in per_row.into_iter().enumerate() {
1177 ll += ll_i;
1178 residual[i] = residual_i;
1179 for (slot, value) in grad_link.iter_mut().zip(grad_link_i.iter()) {
1180 *slot += *value;
1181 }
1182 }
1183
1184 Ok((ll, fast_atv(&data.x, &residual), grad_link))
1185}
1186
1187fn joint_binomial_logp_and_grad(
1188 likelihood: &LikelihoodSpec,
1189 data: &SharedData,
1190 eta: &Array1<f64>,
1191) -> Result<(f64, Array1<f64>), String> {
1192 if !matches!(likelihood.response, ResponseFamily::Binomial) {
1193 return Err(HmcError::UnsupportedFamily {
1194 reason: format!(
1195 "{} is not a binomial joint-HMC family",
1196 likelihood.pretty_name()
1197 ),
1198 }
1199 .into());
1200 }
1201 match &likelihood.link {
1202 InverseLink::Standard(StandardLink::Logit) => Ok(logit_logp_and_grad(data, eta)),
1203 InverseLink::Standard(StandardLink::Probit) => Ok(probit_logp_and_grad(data, eta)),
1204 InverseLink::Standard(StandardLink::CLogLog) => Ok(cloglog_logp_and_grad(data, eta)),
1205 InverseLink::LatentCLogLog(_)
1206 | InverseLink::Sas(_)
1207 | InverseLink::BetaLogistic(_)
1208 | InverseLink::Mixture(_) => {
1209 let (ll, grad_beta, _) =
1210 joint_binomial_logp_grad_and_link_grad(&likelihood.link, data, eta, 0)?;
1211 Ok((ll, grad_beta))
1212 }
1213 InverseLink::Standard(_) => Err(HmcError::UnsupportedFamily {
1214 reason: format!(
1215 "{} is not a binomial joint-HMC family",
1216 likelihood.pretty_name()
1217 ),
1218 }
1219 .into()),
1220 }
1221}
1222
1223fn joint_family_logp_grad_and_link_grad(
1224 likelihood: &LikelihoodSpec,
1225 data: &SharedData,
1226 eta: &Array1<f64>,
1227 n_link_params: usize,
1228) -> Result<(f64, Array1<f64>, Array1<f64>), String> {
1229 match (&likelihood.response, &likelihood.link) {
1230 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
1231 let (ll, grad) = logit_logp_and_grad(data, eta);
1232 Ok((ll, grad, Array1::zeros(n_link_params)))
1233 }
1234 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Probit)) => {
1235 let (ll, grad) = probit_logp_and_grad(data, eta);
1236 Ok((ll, grad, Array1::zeros(n_link_params)))
1237 }
1238 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::CLogLog)) => {
1239 let (ll, grad) = cloglog_logp_and_grad(data, eta);
1240 Ok((ll, grad, Array1::zeros(n_link_params)))
1241 }
1242 (
1243 ResponseFamily::Binomial,
1244 InverseLink::LatentCLogLog(_)
1245 | InverseLink::Sas(_)
1246 | InverseLink::BetaLogistic(_)
1247 | InverseLink::Mixture(_),
1248 ) => joint_binomial_logp_grad_and_link_grad(&likelihood.link, data, eta, n_link_params),
1249 _ => {
1250 let (ll, grad) = joint_family_logp_and_grad(likelihood, data, eta)?;
1251 Ok((ll, grad, Array1::zeros(n_link_params)))
1252 }
1253 }
1254}
1255
1256fn joint_family_logp_and_grad(
1257 likelihood: &LikelihoodSpec,
1258 data: &SharedData,
1259 eta: &Array1<f64>,
1260) -> Result<(f64, Array1<f64>), String> {
1261 match &likelihood.response {
1262 ResponseFamily::Binomial => joint_binomial_logp_and_grad(likelihood, data, eta),
1263 ResponseFamily::Gaussian => Ok(gaussian_logp_and_grad(data, eta)),
1264 ResponseFamily::Poisson => Ok(poisson_log_logp_and_grad(data, eta)),
1265 ResponseFamily::Tweedie { p } => {
1266 let p = *p;
1269 if !is_valid_tweedie_power(p) {
1270 return Err(HmcError::InvalidConfig {
1271 reason: format!(
1272 "Tweedie variance power must be finite and strictly between 1 and 2; got {p}"
1273 ),
1274 }
1275 .into());
1276 }
1277 Ok(tweedie_log_quasilogp_and_grad(data, eta, p))
1278 }
1279 ResponseFamily::NegativeBinomial { theta, .. } => {
1280 Ok(negative_binomial_log_logp_and_grad(data, eta, *theta))
1283 }
1284 ResponseFamily::Beta { .. } => Err(HmcError::UnsupportedFamily {
1285 reason: "Joint HMC fallback is not implemented for BetaLogit".to_string(),
1286 }
1287 .into()),
1288 ResponseFamily::Gamma => Ok(gamma_log_logp_and_grad(data, eta)),
1289 ResponseFamily::RoystonParmar => Err(HmcError::UnsupportedFamily {
1290 reason: "Joint HMC fallback is not implemented for RoystonParmar".to_string(),
1291 }
1292 .into()),
1293 }
1294}
1295
1296fn logit_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1300 let mut residual = Array1::<f64>::zeros(data.n_samples);
1301 logit_logp_and_grad_into(data, eta, &mut residual)
1302}
1303
1304fn logit_logp_and_grad_into(
1305 data: &SharedData,
1306 eta: &Array1<f64>,
1307 residual: &mut Array1<f64>,
1308) -> (f64, Array1<f64>) {
1309 use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1310 let n = data.n_samples;
1311 assert_eq!(residual.len(), n);
1312 let ll: f64 = residual
1316 .as_slice_mut()
1317 .unwrap()
1318 .par_iter_mut()
1319 .enumerate()
1320 .map(|(i, slot)| {
1321 let eta_i = eta[i];
1322 let y_i = data.y[i];
1323 let w_i = data.weights[i];
1324 let mu = gam_linalg::utils::stable_logistic(eta_i);
1325 *slot = w_i * (y_i - mu);
1326 w_i * (y_i * eta_i - gam_linalg::utils::stable_softplus(eta_i))
1327 })
1328 .sum();
1329
1330 let grad_ll = fast_atv(data.x.as_ref(), &*residual);
1331 (ll, grad_ll)
1332}
1333
1334fn probit_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1341 let mut residual = Array1::<f64>::zeros(data.n_samples);
1342 probit_logp_and_grad_into(data, eta, &mut residual)
1343}
1344
1345fn probit_logp_and_grad_into(
1346 data: &SharedData,
1347 eta: &Array1<f64>,
1348 residual: &mut Array1<f64>,
1349) -> (f64, Array1<f64>) {
1350 use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1351 let n = data.n_samples;
1352 assert_eq!(residual.len(), n);
1353 let ll: f64 = residual
1354 .as_slice_mut()
1355 .unwrap()
1356 .par_iter_mut()
1357 .enumerate()
1358 .map(|(i, slot)| {
1359 let eta_i = eta[i];
1360 let y_i = data.y[i];
1361 let w_i = data.weights[i];
1362 let log_phi_pos = log_ndtr(eta_i);
1363 let log_phi_neg = log_ndtr(-eta_i);
1364 let log_phi_val = standard_normal_log_pdf(eta_i);
1365 let ratio_pos = (log_phi_val - log_phi_pos).exp();
1366 let ratio_neg = (log_phi_val - log_phi_neg).exp();
1367 let grad_i = y_i * ratio_pos - (1.0 - y_i) * ratio_neg;
1368 *slot = w_i * grad_i;
1369 w_i * (y_i * log_phi_pos + (1.0 - y_i) * log_phi_neg)
1370 })
1371 .sum();
1372
1373 let grad_ll = fast_atv(data.x.as_ref(), &*residual);
1374 (ll, grad_ll)
1375}
1376
1377#[inline]
1383fn cloglog_bernoulli_logp_and_residual(eta: f64, y: f64) -> Result<(f64, f64), EstimationError> {
1384 if !(eta.is_finite() && (-700.0..=700.0).contains(&eta)) {
1385 gam_problem::bail_invalid_estim!("cloglog eta must be finite and within [-700, 700]; got {eta}");
1386 }
1387 let exp_eta = eta.exp();
1388 let log_mu = crate::probability::log1mexp_positive(exp_eta);
1391 let log_one_minus_mu = -exp_eta;
1392 let grad_log_mu = (eta - exp_eta - log_mu).exp();
1393 let ll_i = y * log_mu + (1.0 - y) * log_one_minus_mu;
1394 let residual_i = y * grad_log_mu - (1.0 - y) * exp_eta;
1395 Ok((ll_i, residual_i))
1396}
1397
1398fn cloglog_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1399 let mut residual = Array1::<f64>::zeros(data.n_samples);
1400 cloglog_logp_and_grad_into(data, eta, &mut residual)
1401}
1402
1403fn cloglog_logp_and_grad_into(
1404 data: &SharedData,
1405 eta: &Array1<f64>,
1406 residual: &mut Array1<f64>,
1407) -> (f64, Array1<f64>) {
1408 use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1409 let n = data.n_samples;
1410 assert_eq!(residual.len(), n);
1411 if eta
1412 .iter()
1413 .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1414 {
1415 return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1416 }
1417 let ll: f64 = residual
1418 .as_slice_mut()
1419 .unwrap()
1420 .par_iter_mut()
1421 .enumerate()
1422 .map(|(i, slot)| {
1423 let y_i = data.y[i];
1424 let w_i = data.weights[i];
1425 let (ll_i, residual_i) =
1426 cloglog_bernoulli_logp_and_residual(eta[i], y_i).expect("validated cloglog eta");
1427 *slot = w_i * residual_i;
1428 w_i * ll_i
1429 })
1430 .sum();
1431
1432 let grad_ll = fast_atv(data.x.as_ref(), &*residual);
1433 (ll, grad_ll)
1434}
1435
1436fn gaussian_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1448 let mut weighted_residual = Array1::<f64>::zeros(data.n_samples);
1449 gaussian_logp_and_grad_into(data, eta, &mut weighted_residual)
1450}
1451
1452fn gaussian_logp_and_grad_into(
1453 data: &SharedData,
1454 eta: &Array1<f64>,
1455 weighted_residual: &mut Array1<f64>,
1456) -> (f64, Array1<f64>) {
1457 use gam_problem::dispersion_cov::DispersionExt as _;
1458 use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1459 let n = data.n_samples;
1460 let inv_phi = data.dispersion.inv_phi();
1461 assert_eq!(weighted_residual.len(), n);
1462 let ll: f64 = weighted_residual
1465 .as_slice_mut()
1466 .unwrap()
1467 .par_iter_mut()
1468 .enumerate()
1469 .map(|(i, slot)| {
1470 let residual = data.y[i] - eta[i];
1471 let w_i = data.weights[i];
1472 let scaled = w_i * inv_phi;
1473 *slot = scaled * residual;
1474 -0.5 * scaled * residual * residual
1475 })
1476 .sum();
1477
1478 let grad_ll = fast_atv(data.x.as_ref(), &*weighted_residual);
1479 (ll, grad_ll)
1480}
1481
1482fn poisson_log_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1486 use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1487 let n = data.n_samples;
1488 if eta
1489 .iter()
1490 .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1491 {
1492 return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1493 }
1494 let mut residual = Array1::<f64>::zeros(n);
1495 let ll: f64 = residual
1496 .as_slice_mut()
1497 .unwrap()
1498 .par_iter_mut()
1499 .enumerate()
1500 .map(|(i, slot)| {
1501 let eta_i = eta[i];
1502 let mu_i = eta_i.exp();
1503 let y_i = data.y[i];
1504 let w_i = data.weights[i];
1505 *slot = w_i * (y_i - mu_i);
1506 w_i * (y_i * eta_i - mu_i)
1507 })
1508 .sum();
1509
1510 let grad_ll = fast_atv(&data.x, &residual);
1511 (ll, grad_ll)
1512}
1513
1514fn tweedie_log_quasilogp_and_grad(
1515 data: &SharedData,
1516 eta: &Array1<f64>,
1517 p: f64,
1518) -> (f64, Array1<f64>) {
1519 use gam_problem::dispersion_cov::DispersionExt as _;
1520 use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1521 let n = data.n_samples;
1522 if !is_valid_tweedie_power(p) {
1525 return (f64::NAN, Array1::from_elem(data.dim, f64::NAN));
1526 }
1527 if eta
1528 .iter()
1529 .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1530 {
1531 return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1532 }
1533 let inv_phi = data.dispersion.inv_phi();
1534 let mut residual = Array1::<f64>::zeros(n);
1535 let ll: f64 = residual
1536 .as_slice_mut()
1537 .unwrap()
1538 .par_iter_mut()
1539 .enumerate()
1540 .map(|(i, slot)| {
1541 let eta_i = eta[i];
1542 let mu_i = eta_i.exp().max(1e-300);
1543 let y_i = data.y[i];
1544 let w_i = data.weights[i] * inv_phi;
1545 *slot = w_i * (y_i - mu_i) * mu_i.powf(1.0 - p);
1546 let qll = y_i * mu_i.powf(1.0 - p) / (1.0 - p) - mu_i.powf(2.0 - p) / (2.0 - p);
1547 w_i * qll
1548 })
1549 .sum();
1550
1551 let grad_ll = fast_atv(&data.x, &residual);
1552 (ll, grad_ll)
1553}
1554
1555fn negative_binomial_log_logp_and_grad(
1556 data: &SharedData,
1557 eta: &Array1<f64>,
1558 theta: f64,
1559) -> (f64, Array1<f64>) {
1560 use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1561 let n = data.n_samples;
1562 if !(theta.is_finite() && theta > 0.0)
1563 || eta
1564 .iter()
1565 .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1566 || data
1567 .y
1568 .iter()
1569 .zip(data.weights.iter())
1570 .any(|(&y_i, &w_i)| w_i > 0.0 && !valid_count_response(y_i))
1571 {
1572 return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1573 }
1574 let mut residual = Array1::<f64>::zeros(n);
1575 let ll: f64 = residual
1576 .as_slice_mut()
1577 .unwrap()
1578 .par_iter_mut()
1579 .enumerate()
1580 .map(|(i, slot)| {
1581 let eta_i = eta[i];
1582 let mu_i = eta_i.exp().max(1e-12);
1583 let y_i = data.y[i];
1584 let w_i = data.weights[i];
1585 if w_i <= 0.0 {
1586 *slot = 0.0;
1587 return 0.0;
1588 }
1589 let log_mu_term = if y_i > 0.0 { y_i * mu_i.ln() } else { 0.0 };
1590 *slot = w_i * theta * (y_i - mu_i) / (theta + mu_i);
1591 w_i * (statrs::function::gamma::ln_gamma(y_i + theta)
1592 - statrs::function::gamma::ln_gamma(theta)
1593 - statrs::function::gamma::ln_gamma(y_i + 1.0)
1594 + theta * (theta.ln() - (theta + mu_i).ln())
1595 + log_mu_term
1596 - y_i * (theta + mu_i).ln())
1597 })
1598 .sum();
1599
1600 let grad_ll = fast_atv(&data.x, &residual);
1601 (ll, grad_ll)
1602}
1603
1604fn gamma_log_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1605 use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1606 let n = data.n_samples;
1607 if eta
1608 .iter()
1609 .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1610 {
1611 return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1612 }
1613 let shape = data.gamma_shape.max(1e-10);
1614 let shape_ln_shape = shape * shape.ln();
1619 let log_gamma_shape = statrs::function::gamma::ln_gamma(shape);
1620 let shape_minus_one = shape - 1.0;
1621 let mut residual = Array1::<f64>::zeros(n);
1622 let ll: f64 = residual
1623 .as_slice_mut()
1624 .unwrap()
1625 .par_iter_mut()
1626 .enumerate()
1627 .map(|(i, slot)| {
1628 let eta_i = eta[i];
1629 let mu_i = eta_i.exp();
1630 let y_i = data.y[i];
1631 let w_i = data.weights[i];
1632 let ll_i = w_i
1633 * (shape_ln_shape - log_gamma_shape - shape * eta_i
1634 + shape_minus_one * y_i.max(1e-12).ln()
1635 - shape * y_i / mu_i);
1636 *slot = w_i * shape * (y_i / mu_i - 1.0);
1637 ll_i
1638 })
1639 .sum();
1640
1641 let grad_ll = fast_atv(&data.x, &residual);
1642 (ll, grad_ll)
1643}
1644
1645#[cfg(test)]
1646mod tests {
1647 use super::{
1648 FamilyNutsInputs, GlmFlatInputs, JointBetaRhoInputs, JointBetaRhoPosterior,
1649 LinkWigglePosterior, LinkWiggleSplineArtifacts, NutsConfig, NutsFamily, NutsPosterior,
1650 SharedData, cloglog_bernoulli_logp_and_residual, firth_jeffreys_logp_and_grad,
1651 joint_family_logp_and_grad, laplace_directional_cubic_diagnostic,
1652 laplace_skewness_threshold, laplace_trustworthiness_from_skewness,
1653 run_joint_beta_rho_sampling, run_logit_polya_gamma_gibbs,
1654 run_nuts_sampling_flattened_family,
1655 };
1656 use gam_terms::construction::CanonicalPenalty;
1657 use gam_solve::estimate::{
1658 BlockRole, FitGeometry, FitInference, FittedBlock, FittedLinkState, UnifiedFitResult,
1659 UnifiedFitResultParts,
1660 };
1661 use gam_models::survival::{PenaltyBlocks, SurvivalMonotonicityPenalty, SurvivalSpec};
1662 use gam_linalg::matrix::DesignMatrix;
1663 use gam_problem::types::{
1664 InverseLink, LikelihoodScaleMetadata, LikelihoodSpec, LogLikelihoodNormalization,
1665 ResponseFamily, RhoPrior, StandardLink,
1666 };
1667 use general_mcmc::generic_hmc::HamiltonianTarget;
1668 use ndarray::{Array1, Array2, array};
1669 use std::sync::Arc;
1670
1671 impl NutsPosterior {
1672 pub(super) fn compute_logp_and_grad_nd(&self, z: &Array1<f64>) -> (f64, Array1<f64>) {
1674 let mut residual = Array1::<f64>::zeros(self.data.n_samples);
1675 let mut grad = Array1::<f64>::zeros(z.len());
1676 let logp = self.compute_logp_and_grad_nd_into(z, &mut residual, &mut grad);
1677 (logp, grad)
1678 }
1679 }
1680
1681 impl LinkWigglePosterior {
1682 pub(super) fn compute_logp_and_grad(&self, z: &Array1<f64>) -> (f64, Array1<f64>) {
1684 let dim = self.p_base + self.p_link;
1685 let mut grad = Array1::<f64>::zeros(dim);
1686 let logp = self.compute_logp_and_grad_into(z, &mut grad);
1687 (logp, grad)
1688 }
1689 }
1690
1691 impl JointBetaRhoPosterior {
1692 pub(super) fn compute_joint_logp_and_grad(
1694 &self,
1695 params: &Array1<f64>,
1696 ) -> (f64, Array1<f64>) {
1697 let total_dim = self.n_beta + self.n_rho + self.n_link_params;
1698 let mut grad = Array1::<f64>::zeros(total_dim);
1699 let logp = self.compute_joint_logp_and_grad_into(params, &mut grad);
1700 (logp, grad)
1701 }
1702 }
1703
/// Builds a minimal Gaussian/identity `UnifiedFitResult` for the HMC handoff
/// tests.
///
/// The caller supplies the fitted blocks and the optional `inference` and
/// `geometry` sections, which are passed through unchanged. Each test can
/// therefore decide where, if anywhere, a penalized Hessian is exported. All
/// other fields are fixed placeholder values, and the fit has no smoothing
/// parameters (`lambdas` is empty).
///
/// # Panics
/// Panics if `UnifiedFitResult::try_from_parts` rejects the assembled parts.
fn hmc_test_fit(
    blocks: Vec<FittedBlock>,
    inference: Option<FitInference>,
    geometry: Option<FitGeometry>,
) -> UnifiedFitResult {
    let lambdas = Array1::zeros(0);
    UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
        blocks,
        log_lambdas: lambdas.clone(),
        lambdas,
        likelihood_family: Some(LikelihoodSpec::new(
            ResponseFamily::Gaussian,
            InverseLink::Standard(StandardLink::Identity),
        )),
        likelihood_scale: LikelihoodScaleMetadata::ProfiledGaussian,
        log_likelihood_normalization: LogLikelihoodNormalization::Full,
        log_likelihood: -1.0,
        deviance: 2.0,
        reml_score: 0.0,
        stable_penalty_term: 0.0,
        penalized_objective: 0.0,
        used_device: false,
        outer_iterations: 1,
        outer_converged: true,
        outer_gradient_norm: None,
        standard_deviation: 1.0,
        // No covariance: whitening must come from an exported Hessian.
        covariance_conditional: None,
        covariance_corrected: None,
        inference,
        fitted_link: FittedLinkState::Standard(None),
        geometry,
        block_states: Vec::new(),
        pirls_status: gam_solve::pirls::PirlsStatus::Converged,
        max_abs_eta: 0.0,
        constraint_kkt: None,
        artifacts: Default::default(),
        inner_cycles: 0,
    })
    .expect("valid HMC handoff test fit")
}
1744
/// A standard (single mean block) fit exports its penalized Hessian through
/// `FitInference`. `explicit_fit_hessian_for_whitening` must return exactly
/// that matrix, and `NutsPosterior::new` must accept it for whitening.
#[test]
fn hmc_whitening_consumes_standard_fit_inference_hessian() {
    let hessian = array![[2.0, 0.1], [0.1, 1.6]];
    let fit = hmc_test_fit(
        vec![FittedBlock {
            beta: array![0.05, -0.1],
            role: BlockRole::Mean,
            edf: 2.0,
            lambdas: Array1::zeros(0),
        }],
        Some(FitInference {
            edf_by_block: vec![],
            penalty_block_trace: vec![],
            edf_total: 2.0,
            smoothing_correction: None,
            // The only Hessian source in this fit: the inference section.
            penalized_hessian: hessian.clone().into(),
            working_weights: array![1.0, 1.0, 1.0],
            working_response: array![0.0, 0.1, -0.2],
            reparam_qs: None,
            dispersion: gam_solve::estimate::Dispersion::Known(1.0),
            beta_covariance: None,
            beta_standard_errors: None,
            beta_covariance_corrected: None,
            beta_standard_errors_corrected: None,
            beta_covariance_frequentist: None,
            coefficient_influence: None,
            weighted_gram: None,
            bias_correction_beta: None,
        }),
        None,
    );

    let explicit = super::explicit_fit_hessian_for_whitening(&fit, 2, "standard fit")
        .expect("standard fit exports explicit Hessian");
    assert_eq!(explicit, &hessian);

    // The exported Hessian must be usable as-is to build a Gaussian NUTS target.
    let x = array![[1.0, 0.0], [1.0, 0.5], [1.0, -0.5]];
    let y = array![0.0, 0.2, -0.1];
    let weights = Array1::ones(3);
    let penalty = Array2::eye(2);
    NutsPosterior::new(
        x.view(),
        y.view(),
        weights.view(),
        penalty.view(),
        fit.beta.view(),
        explicit.view(),
        NutsFamily::Gaussian,
        1.0,
        gam_solve::estimate::Dispersion::Known(1.0),
        false,
    )
    .expect("HMC target whitens with upstream Hessian");
}
1799
/// A blockwise (location/scale) fit with no `FitInference` exports its
/// Hessian through `FitGeometry`. The whitening lookup must return that
/// materialized matrix unchanged.
#[test]
fn hmc_whitening_consumes_blockwise_geometry_hessian() {
    let hessian = array![[3.0, 0.2], [0.2, 2.0]];
    let fit = hmc_test_fit(
        vec![
            FittedBlock {
                beta: array![0.1],
                role: BlockRole::Location,
                edf: 1.0,
                lambdas: Array1::zeros(0),
            },
            FittedBlock {
                beta: array![-0.2],
                role: BlockRole::Scale,
                edf: 1.0,
                lambdas: Array1::zeros(0),
            },
        ],
        // No inference section: the geometry is the only Hessian source.
        None,
        Some(FitGeometry {
            penalized_hessian: hessian.clone().into(),
            working_weights: array![1.0, 0.8],
            working_response: array![0.0, 0.1],
        }),
    );

    let explicit = super::explicit_fit_hessian_for_whitening(&fit, 2, "blockwise fit")
        .expect("blockwise fit exports materialized Hessian");
    assert_eq!(explicit, &hessian);
}
1830
/// A fit that carries only a conditional covariance (no inference, no
/// geometry) must be rejected for whitening. HMC must not invert the
/// covariance to synthesize a Hessian.
#[test]
fn hmc_whitening_rejects_covariance_only_fit_without_synthesizing_hessian() {
    let fit = UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
        blocks: vec![FittedBlock {
            beta: array![0.0],
            role: BlockRole::Mean,
            edf: 1.0,
            lambdas: Array1::zeros(0),
        }],
        log_lambdas: Array1::zeros(0),
        lambdas: Array1::zeros(0),
        likelihood_family: Some(LikelihoodSpec::new(
            ResponseFamily::Gaussian,
            InverseLink::Standard(StandardLink::Identity),
        )),
        likelihood_scale: LikelihoodScaleMetadata::ProfiledGaussian,
        log_likelihood_normalization: LogLikelihoodNormalization::Full,
        log_likelihood: -1.0,
        deviance: 2.0,
        reml_score: 0.0,
        stable_penalty_term: 0.0,
        penalized_objective: 0.0,
        used_device: false,
        outer_iterations: 1,
        outer_converged: true,
        outer_gradient_norm: None,
        standard_deviation: 1.0,
        // The only curvature-like information present is a covariance.
        covariance_conditional: Some(array![[0.5]]),
        covariance_corrected: None,
        inference: None,
        fitted_link: FittedLinkState::Standard(None),
        geometry: None,
        block_states: Vec::new(),
        pirls_status: gam_solve::pirls::PirlsStatus::Converged,
        max_abs_eta: 0.0,
        constraint_kkt: None,
        artifacts: Default::default(),
        inner_cycles: 0,
    })
    .expect("covariance-only fit can exist for prediction");

    let err = super::explicit_fit_hessian_for_whitening(&fit, 1, "covariance-only fit")
        .expect_err("HMC must not invert covariance as a Hessian fallback");
    assert!(
        err.contains("missing an explicit penalized Hessian"),
        "unexpected error: {err}"
    );
}
1879
/// `stable_softplus(x) = ln(1 + eˣ)` must stay finite and accurate at both
/// extremes: ≈ x for large positive x, ≈ 0 for large negative x.
///
/// The previous version only checked finiteness at +1000, which a wrong but
/// finite value (e.g. a clamp) would also pass.
#[test]
fn log1pexp_is_finite_for_extreme_eta() {
    let hi = gam_linalg::utils::stable_softplus(1000.0);
    let lo = gam_linalg::utils::stable_softplus(-1000.0);
    assert!(hi.is_finite());
    assert!(lo.is_finite());
    // ln(1 + e¹⁰⁰⁰) = 1000 + ln(1 + e⁻¹⁰⁰⁰), which rounds to exactly 1000 in f64.
    assert!(
        (hi - 1000.0).abs() < 1e-9,
        "softplus(1000) = {hi}, expected ≈ 1000"
    );
    // ln(1 + e⁻¹⁰⁰⁰) underflows to (essentially) zero.
    assert!(lo.abs() < 1e-12, "softplus(-1000) = {lo}, expected ≈ 0");
}
1886
/// The stable logistic saturates to within 1e-12 of its limits at ±1000
/// without leaving `[0, 1]`.
#[test]
fn sigmoid_stable_behaves_at_extremes() {
    let upper = gam_linalg::utils::stable_logistic(1000.0);
    assert!((1.0 - 1e-12..=1.0).contains(&upper));

    let lower = gam_linalg::utils::stable_logistic(-1000.0);
    assert!((0.0..=1e-12).contains(&lower));
}
1894
/// The cloglog kernel must use `μ = 1 − exp(−e^η)` rather than the
/// look-alike `1 − e^η`, and its residual must be the η-derivative of `log μ`.
#[test]
fn cloglog_log_mu_uses_complementary_loglog_inverse_link() {
    let eta = -1.0_f64;
    let (ll_y1, residual_y1) =
        cloglog_bernoulli_logp_and_residual(eta, 1.0).expect("valid eta");

    // With y = 1 the log-likelihood is log μ. At η = −1 the correct and the
    // buggy formula differ by more than 0.5.
    let correct = (1.0 - (-eta.exp()).exp()).ln();
    let buggy = (1.0 - eta.exp()).ln();
    assert!((ll_y1 - correct).abs() < 1e-14);
    assert!((ll_y1 - buggy).abs() > 0.5);

    // Central finite difference of log μ against the analytic residual.
    let eps = 1e-6;
    let log_mu_at = |e: f64| {
        cloglog_bernoulli_logp_and_residual(e, 1.0)
            .expect("valid eta")
            .0
    };
    let fd = (log_mu_at(eta + eps) - log_mu_at(eta - eps)) / (2.0 * eps);
    assert!(
        (residual_y1 - fd).abs() < 1e-9,
        "cloglog residual is not the derivative of log μ: analytic={residual_y1}, fd={fd}"
    );
}
1915
/// `LinkWigglePosterior::new` must whiten with the joint Hessian it is given.
/// Its factor `L` must satisfy `H · (L Lᵀ) = I`.
#[test]
fn link_wiggle_posterior_whitening_uses_supplied_explicit_joint_hessian() {
    let hessian = array![[4.0, 1.0], [1.0, 3.0]];
    let x = array![[1.0], [1.0], [1.0]];
    let y = array![0.0, 1.0, 1.0];
    let weights = Array1::ones(3);
    let penalty_base = Array2::zeros((1, 1));
    let penalty_link = Array2::zeros((1, 1));
    let mode_beta = array![0.2];
    let mode_theta = array![0.05];
    let spline = LinkWiggleSplineArtifacts {
        knot_range: (-1.0, 1.0),
        knot_vector: Array1::from_vec(vec![-1.0, -1.0, -1.0, 1.0, 1.0, 1.0]),
        degree: 2,
    };

    let posterior = LinkWigglePosterior::new(
        x.view(),
        y.view(),
        weights.view(),
        penalty_base.view(),
        penalty_link.view(),
        mode_beta.view(),
        mode_theta.view(),
        hessian.view(),
        spline,
        NutsFamily::BinomialLogit,
        1.0,
    )
    .expect("link-wiggle posterior should accept explicit SPD joint Hessian");

    // L Lᵀ should be H⁻¹, so H · L Lᵀ should be the identity.
    let covariance = posterior.chol().dot(&posterior.chol().t());
    let product = hessian.dot(&covariance);
    for ((r, c), &got) in product.indexed_iter() {
        let expected = if r == c { 1.0 } else { 0.0 };
        assert!(
            (got - expected).abs() < 1e-10,
            "whitening did not use the supplied explicit joint Hessian at ({r},{c}): got {} expected {}",
            got,
            expected
        );
    }
}
1961
/// For the cloglog link-wiggle posterior, the analytic gradient in sampler
/// coordinates must match central finite differences of its own log-density.
#[test]
fn link_wiggle_cloglog_gradient_matches_its_log_likelihood() {
    let x = array![[1.0], [1.0], [1.0], [1.0]];
    let y = array![1.0, 0.0, 1.0, 0.0];
    let weights = array![1.0, 1.2, 0.8, 1.4];
    let penalty_base = Array2::zeros((1, 1));
    let penalty_link = Array2::zeros((1, 1));
    let mode_beta = array![-0.8];
    let mode_theta = array![0.04];
    let hessian = Array2::eye(2);
    let spline = LinkWiggleSplineArtifacts {
        knot_range: (-1.5, 0.5),
        knot_vector: Array1::from_vec(vec![-1.5, -1.5, -1.5, 0.5, 0.5, 0.5]),
        degree: 2,
    };

    let posterior = LinkWigglePosterior::new(
        x.view(),
        y.view(),
        weights.view(),
        penalty_base.view(),
        penalty_link.view(),
        mode_beta.view(),
        mode_theta.view(),
        hessian.view(),
        spline,
        NutsFamily::BinomialCLogLog,
        1.0,
    )
    .expect("cloglog link-wiggle posterior");

    let z = array![0.2, -0.03];
    let (_, grad) = posterior.compute_logp_and_grad(&z);

    // log-density with coordinate `j` nudged by `delta`.
    let eps = 1e-6;
    let logp_nudged = |j: usize, delta: f64| {
        let mut moved = z.clone();
        moved[j] += delta;
        posterior.compute_logp_and_grad(&moved).0
    };
    for j in 0..z.len() {
        let fd = (logp_nudged(j, eps) - logp_nudged(j, -eps)) / (2.0 * eps);
        assert!(
            (grad[j] - fd).abs() < 1e-6,
            "link-wiggle cloglog gradient mismatch at {j}: analytic={}, fd={}",
            grad[j],
            fd
        );
    }
}
2012
/// Central finite differences of the logit NUTS target must reproduce the
/// analytic gradient from `compute_logp_and_grad_nd`.
///
/// The trailing `true` passed to `NutsPosterior::new` is presumably the Firth
/// bias-reduction flag (the same slot receives `false` elsewhere in these
/// tests). Confirm against the constructor signature.
#[test]
fn nuts_logitgradient_matches_finite_difference() {
    let x = array![[1.0, -0.5], [0.2, 0.7], [-1.0, 0.3], [0.5, -1.2]];
    let y = array![1.0, 0.0, 1.0, 0.0];
    let w = array![1.0, 1.5, 0.8, 1.2];
    let penalty = array![[0.4, 0.0], [0.0, 0.6]];
    let mode = array![0.1, -0.2];
    let hessian = array![[2.0, 0.2], [0.2, 1.7]];
    let posterior = NutsPosterior::new(
        x.view(),
        y.view(),
        w.view(),
        penalty.view(),
        mode.view(),
        hessian.view(),
        NutsFamily::BinomialLogit,
        1.0,
        gam_solve::estimate::Dispersion::Known(1.0),
        true,
    )
    .expect("posterior");

    let z = array![0.15, -0.35];
    let (_, grad) = posterior.compute_logp_and_grad_nd(&z);

    let eps = 1e-6;
    for j in 0..z.len() {
        let mut z_plus = z.clone();
        let mut z_minus = z.clone();
        z_plus[j] += eps;
        z_minus[j] -= eps;
        let (lp, _) = posterior.compute_logp_and_grad_nd(&z_plus);
        let (lm, _) = posterior.compute_logp_and_grad_nd(&z_minus);
        let fd = (lp - lm) / (2.0 * eps);
        // The absolute tolerance already forces matching signs whenever
        // |fd| > 1e-5. A separate `signum` equality would only add spurious
        // failures for near-zero components, where FD noise can flip the sign.
        assert!(
            (grad[j] - fd).abs() < 1e-5,
            "gradient mismatch at {}: analytic={}, fd={}",
            j,
            grad[j],
            fd
        );
    }
}
2065
/// `gamma_log_logp_and_grad` must use `SharedData::gamma_shape` as ν in both
/// the log-likelihood and the score. It is checked against a hand-rolled
/// evaluation of the Gamma log-link density.
#[test]
fn gamma_log_logp_and_grad_uses_fitted_shape() {
    let design = array![[1.0_f64], [1.0_f64]];
    let response = array![1.5_f64, 2.5_f64];
    let prior_weights = array![1.0_f64, 2.0_f64];
    let eta = array![0.2_f64, 0.4_f64];
    let shape = 3.5_f64;
    let data = SharedData {
        x: Arc::new(design.clone()),
        y: Arc::new(response.clone()),
        weights: Arc::new(prior_weights.clone()),
        mode: Arc::new(Array1::zeros(1)),
        offset: None,
        gamma_shape: shape,
        dispersion: gam_solve::estimate::Dispersion::Known(1.0),
        n_samples: design.nrows(),
        dim: design.ncols(),
    };

    let (ll, grad) = super::gamma_log_logp_and_grad(&data, &eta);

    // The design is a single intercept column, so the β-gradient equals the
    // summed η-score.
    let (expected_ll, expected_score) =
        (0..eta.len()).fold((0.0_f64, 0.0_f64), |(acc_ll, acc_score), i| {
            let mu = eta[i].exp();
            let ll_i = prior_weights[i]
                * (shape * shape.ln() - statrs::function::gamma::ln_gamma(shape) - shape * eta[i]
                    + (shape - 1.0) * response[i].ln()
                    - shape * response[i] / mu);
            let score_i = prior_weights[i] * shape * (response[i] / mu - 1.0);
            (acc_ll + ll_i, acc_score + score_i)
        });

    assert!((ll - expected_ll).abs() < 1e-12);
    assert_eq!(grad.len(), 1);
    assert!((grad[0] - expected_score).abs() < 1e-12);
}
2102
2103 fn gamma_log_observed_information(
2107 x: &Array2<f64>,
2108 mode: &Array1<f64>,
2109 y: &Array1<f64>,
2110 weights: &Array1<f64>,
2111 shape: f64,
2112 ) -> Array2<f64> {
2113 let p = x.ncols();
2114 let eta = x.dot(mode);
2115 let mut h = Array2::<f64>::zeros((p, p));
2116 for i in 0..x.nrows() {
2117 let mu = eta[i].exp();
2118 let wt = weights[i] * shape * y[i] / mu;
2119 for a in 0..p {
2120 for b in 0..p {
2121 h[[a, b]] += wt * x[[i, a]] * x[[i, b]];
2122 }
2123 }
2124 }
2125 h
2126 }
2127
/// Regression test for #680. When the GammaLog target is whitened with the
/// unscaled Hessian `H = I_obs + S`, its curvature in sampler coordinates must
/// be the identity. Any dispersion double-count would break this: a penalty
/// scaled by ν, or whitening scaled by √φ.
#[test]
fn gamma_log_nuts_target_curvature_matches_unscaled_hessian_issue_680() {
    let x = array![[1.0, -0.7], [1.0, 0.3], [1.0, 1.1], [1.0, -0.2], [1.0, 0.8]];
    let mode = array![0.4_f64, -0.6_f64];
    let y = array![1.2_f64, 0.7, 2.3, 0.9, 1.6];
    let weights = array![1.0_f64, 1.5, 0.8, 1.2, 1.0];
    let shape = 4.0_f64;
    let p = x.ncols();

    // H = observed information at the mode plus the penalty, in β units.
    let penalty = array![[0.5_f64, 0.1], [0.1, 0.9]];
    let hessian = &gamma_log_observed_information(&x, &mode, &y, &weights, shape) + &penalty;

    let target = NutsPosterior::new(
        x.view(),
        y.view(),
        weights.view(),
        penalty.view(),
        mode.view(),
        hessian.view(),
        NutsFamily::GammaLog,
        shape,
        gam_solve::estimate::Dispersion::Estimated(1.0 / shape),
        false,
    )
    .expect("GammaLog NUTS target builds");

    // −∇² log p at z = 0, from central differences of the analytic gradient.
    let eps = 1e-6;
    let origin = Array1::<f64>::zeros(p);
    let grad_at = |j: usize, delta: f64| {
        let mut z = origin.clone();
        z[j] += delta;
        target.compute_logp_and_grad_nd(&z).1
    };
    let mut hz = Array2::<f64>::zeros((p, p));
    for j in 0..p {
        let g_plus = grad_at(j, eps);
        let g_minus = grad_at(j, -eps);
        for a in 0..p {
            hz[[a, j]] = -(g_plus[a] - g_minus[a]) / (2.0 * eps);
        }
    }

    for ((a, b), &value) in hz.indexed_iter() {
        let expected = if a == b { 1.0 } else { 0.0 };
        assert!(
            (value - expected).abs() < 1e-4,
            "z-curvature[{a},{b}] = {} (expected {expected}); a non-identity \
             value means the GammaLog target re-introduced the #680 dispersion \
             double-count (penalty ×ν and/or whitening ×√φ)",
            value
        );
    }
    let trace: f64 = hz.diag().iter().sum();
    assert!(
        (trace - p as f64).abs() < 1e-3,
        "z-curvature trace {trace} ≠ {p}: dispersion double-count signature"
    );
}
2205
/// Regression test for #680. The GammaLog whitening factor `L` must satisfy
/// `L Lᵀ = H⁻¹` for the unscaled Hessian. A result of `φ·I` instead of `I`
/// would mean the whitening still scales by √φ.
#[test]
fn gamma_log_nuts_whitening_targets_unscaled_inverse_hessian_issue_680() {
    let x = array![[1.0, -0.4], [1.0, 0.6], [1.0, 0.1], [1.0, 1.3]];
    let mode = array![0.2_f64, 0.3_f64];
    let y = array![0.8_f64, 1.7, 1.1, 2.2];
    let weights = array![1.0_f64, 1.0, 1.5, 0.7];
    let shape = 6.25_f64;
    let p = x.ncols();
    let penalty = array![[0.3_f64, 0.0], [0.0, 0.7]];
    let hessian = &gamma_log_observed_information(&x, &mode, &y, &weights, shape) + &penalty;

    let target = NutsPosterior::new(
        x.view(),
        y.view(),
        weights.view(),
        penalty.view(),
        mode.view(),
        hessian.view(),
        NutsFamily::GammaLog,
        shape,
        gam_solve::estimate::Dispersion::Estimated(1.0 / shape),
        false,
    )
    .expect("GammaLog NUTS target builds");

    let factor = target.chol();
    let covariance = factor.dot(&factor.t());
    let should_be_identity = covariance.dot(&hessian);
    for a in 0..p {
        for b in 0..p {
            let expected = if a == b { 1.0 } else { 0.0 };
            let got = should_be_identity[[a, b]];
            assert!(
                (got - expected).abs() < 1e-8,
                "L Lᵀ H[{a},{b}] = {} (expected {expected}); a φ·I result means \
                 the Gamma whitening still scales by √φ (#680)",
                got
            );
        }
    }
}
2252
/// The Firth/Jeffreys term for the logit family must stay finite, with a
/// gradient of full length, even when the design matrix is rank deficient.
#[test]
fn firth_jeffreys_logit_is_finite_for_rank_deficient_design() {
    // Columns 0 and 2 are identical (all ones), so XᵀWX is singular.
    let x = array![
        [1.0, -0.5, 1.0],
        [1.0, 0.3, 1.0],
        [1.0, 0.8, 1.0],
        [1.0, -1.2, 1.0],
    ];
    let y = array![1.0, 0.0, 1.0, 0.0];
    let weights = array![1.0, 2.0, 0.5, 1.5];
    let eta = array![0.2, -0.1, 0.4, -0.3];

    let data = SharedData {
        x: Arc::new(x.clone()),
        y: Arc::new(y),
        weights: Arc::new(weights.clone()),
        mode: Arc::new(Array1::zeros(x.ncols())),
        offset: None,
        gamma_shape: 1.0,
        dispersion: gam_solve::estimate::Dispersion::Known(1.0),
        n_samples: x.nrows(),
        dim: x.ncols(),
    };

    let (value, grad) =
        firth_jeffreys_logp_and_grad(NutsFamily::BinomialLogit, &data, &eta).expect("firth");

    assert!(value.is_finite());
    assert_eq!(grad.len(), x.ncols());
    assert!(grad.iter().all(|v| v.is_finite()));
}
2284
/// Smoke test for the Pólya-Gamma Gibbs sampler on a tiny logit problem. It
/// checks output shape (chains × draws rows, one column per coefficient) and
/// that every sample and summary statistic is finite.
#[test]
fn logit_pg_gibbs_returns_finite_samples() {
    let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
    let y = array![1.0, 0.0, 1.0, 0.0];
    let w = array![1.0, 1.0, 1.0, 1.0];
    let penalty = array![[0.2, 0.0], [0.0, 0.4]];
    let mode = array![0.0, 0.0];
    let cfg = NutsConfig {
        n_samples: 30,
        nwarmup: 30,
        n_chains: 2,
        target_accept: 0.8,
        seed: 123,
    };
    let out = run_logit_polya_gamma_gibbs(
        x.view(),
        y.view(),
        w.view(),
        penalty.view(),
        mode.view(),
        &cfg,
    )
    .expect("pg gibbs should run");
    assert_eq!(out.samples.ncols(), 2);
    assert_eq!(out.samples.nrows(), cfg.n_samples * cfg.n_chains);
    assert!(out.samples.iter().all(|v| v.is_finite()));
    assert!(out.posterior_mean.iter().all(|v| v.is_finite()));
    assert!(out.posterior_std.iter().all(|v| v.is_finite()));
}
2314
/// Standard binomial-logit dispatch routes to Pólya-Gamma Gibbs. That path
/// must reject a response that is not exactly 0/1 (here `y = 2`) with a
/// descriptive error.
#[test]
fn family_pg_dispatch_rejects_non_bernoulli_response() {
    let x = array![[1.0], [1.0]];
    let y = array![2.0, 0.0];
    let w = array![1.0, 1.0];
    let penalty = array![[0.1]];
    let mode = array![0.0];
    // Deliberately singular. The test expects the response check error below,
    // not a Hessian-factorization failure.
    let non_spd_hessian = array![[0.0]];
    let cfg = NutsConfig {
        n_samples: 1,
        nwarmup: 1,
        n_chains: 1,
        target_accept: 0.8,
        seed: 321,
    };

    let result = run_nuts_sampling_flattened_family(
        LikelihoodSpec::binomial_logit(),
        FamilyNutsInputs::Glm(GlmFlatInputs {
            x: x.view(),
            y: y.view(),
            weights: w.view(),
            penalty_matrix: penalty.view(),
            mode: mode.view(),
            hessian: non_spd_hessian.view(),
            gamma_shape: None,
            dispersion: gam_solve::model_types::Dispersion::Known(1.0),
            firth_bias_reduction: false,
            offset: None,
        }),
        &cfg,
    );

    let err = result.err().expect("PG dispatch should reject count rows");
    assert!(
        err.contains("response must be exactly 0 or 1"),
        "unexpected error: {err}"
    );
}
2354
/// Binomial with a standard logit link (no Firth) must be sampled by
/// Pólya-Gamma Gibbs. That path never factorizes the Hessian, so an all-zero
/// (non-SPD) Hessian must still succeed.
#[test]
fn family_dispatch_uses_pg_gibbs_for_standard_logit() {
    let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
    let y = array![1.0, 0.0, 1.0, 0.0];
    let w = array![1.0, 1.0, 1.0, 1.0];
    let penalty = array![[0.2, 0.0], [0.0, 0.4]];
    let mode = array![0.0, 0.0];
    let non_spdhessian = array![[0.0, 0.0], [0.0, 0.0]];
    let cfg = NutsConfig {
        n_samples: 20,
        nwarmup: 20,
        n_chains: 2,
        target_accept: 0.8,
        seed: 456,
    };
    let out = run_nuts_sampling_flattened_family(
        LikelihoodSpec {
            response: ResponseFamily::Binomial,
            link: InverseLink::Standard(StandardLink::Logit),
        },
        FamilyNutsInputs::Glm(GlmFlatInputs {
            x: x.view(),
            y: y.view(),
            weights: w.view(),
            penalty_matrix: penalty.view(),
            mode: mode.view(),
            hessian: non_spdhessian.view(),
            gamma_shape: None,
            dispersion: gam_solve::estimate::Dispersion::Known(1.0),
            firth_bias_reduction: false,
            offset: None,
        }),
        &cfg,
    )
    .expect("dispatch should use PG Gibbs and not require Hessian factorization");
    assert_eq!(out.samples.nrows(), cfg.n_samples * cfg.n_chains);
    assert!(out.samples.iter().all(|v| v.is_finite()));
}
2393
/// Probit has no Pólya-Gamma sampler, so dispatch must route it to NUTS. This
/// is observed indirectly: NUTS whitening needs a Cholesky factor, so the
/// all-zero Hessian must fail with a Cholesky error.
#[test]
fn family_dispatch_routes_probit_to_nuts_path() {
    let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
    let y = array![1.0, 0.0, 1.0, 0.0];
    let w = array![1.0, 1.0, 1.0, 1.0];
    let penalty = array![[0.2, 0.0], [0.0, 0.4]];
    let mode = array![0.0, 0.0];
    let non_spdhessian = array![[0.0, 0.0], [0.0, 0.0]];
    let cfg = NutsConfig {
        n_samples: 20,
        nwarmup: 20,
        n_chains: 2,
        target_accept: 0.8,
        seed: 654,
    };

    let err = match run_nuts_sampling_flattened_family(
        LikelihoodSpec {
            response: ResponseFamily::Binomial,
            link: InverseLink::Standard(StandardLink::Probit),
        },
        FamilyNutsInputs::Glm(GlmFlatInputs {
            x: x.view(),
            y: y.view(),
            weights: w.view(),
            penalty_matrix: penalty.view(),
            mode: mode.view(),
            hessian: non_spdhessian.view(),
            gamma_shape: None,
            dispersion: gam_solve::estimate::Dispersion::Known(1.0),
            firth_bias_reduction: false,
            offset: None,
        }),
        &cfg,
    ) {
        Ok(_) => panic!("non-SPD Hessian should fail after probit routes to the NUTS path"),
        Err(err) => err,
    };

    assert!(
        err.contains("Hessian Cholesky decomposition failed"),
        "unexpected error: {err}"
    );
}
2438
/// Requesting Firth bias reduction for a non-binomial family (Poisson/log)
/// must be rejected explicitly, not silently ignored.
#[test]
fn family_dispatch_rejects_nonbinomial_firth_family() {
    let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
    let y = array![1.0, 2.0, 0.0, 3.0];
    let w = array![1.0, 1.0, 1.0, 1.0];
    let penalty = array![[0.2, 0.0], [0.0, 0.4]];
    let mode = array![0.0, 0.0];
    // A valid SPD Hessian, so any failure must come from the Firth check.
    let hessian = array![[1.5, 0.1], [0.1, 1.2]];
    let cfg = NutsConfig {
        n_samples: 20,
        nwarmup: 20,
        n_chains: 2,
        target_accept: 0.8,
        seed: 111,
    };

    let err = match run_nuts_sampling_flattened_family(
        LikelihoodSpec {
            response: ResponseFamily::Poisson,
            link: InverseLink::Standard(StandardLink::Log),
        },
        FamilyNutsInputs::Glm(GlmFlatInputs {
            x: x.view(),
            y: y.view(),
            weights: w.view(),
            penalty_matrix: penalty.view(),
            mode: mode.view(),
            hessian: hessian.view(),
            gamma_shape: None,
            dispersion: gam_solve::estimate::Dispersion::Known(1.0),
            firth_bias_reduction: true,
            offset: None,
        }),
        &cfg,
    ) {
        Ok(_) => panic!("Poisson Firth should be rejected explicitly"),
        Err(err) => err,
    };

    assert!(
        err.contains(
            "NUTS with Firth requires a Binomial inverse link with a Fisher-weight jet"
        ),
        "unexpected error: {err}"
    );
}
2485
#[test]
fn run_nuts_sampling_rejects_invalid_target_accept() {
    // target_accept == 1.0 sits on the boundary of the open interval (0, 1)
    // and must be refused during config validation, before any sampling.
    let design = array![[1.0], [1.0], [1.0]];
    let response = array![0.5, -0.5, 1.0];
    let obs_weights = array![1.0, 1.0, 1.0];
    let s_matrix = array![[0.25]];
    let beta_mode = array![0.0];
    let h_matrix = array![[1.25]];
    let boundary_config = NutsConfig {
        n_samples: 10,
        nwarmup: 10,
        n_chains: 1,
        target_accept: 1.0,
        seed: 222,
    };

    let outcome = super::run_nuts_sampling(
        design.view(),
        response.view(),
        obs_weights.view(),
        s_matrix.view(),
        beta_mode.view(),
        h_matrix.view(),
        NutsFamily::Gaussian,
        1.0,
        gam_solve::estimate::Dispersion::Known(1.0),
        false,
        None,
        &boundary_config,
    );
    let message =
        outcome.expect_err("invalid target_accept should be rejected before sampling");

    assert!(
        message.contains("target_accept must be finite and lie in (0, 1)"),
        "unexpected error: {message}"
    );
}
2523
#[test]
fn run_nuts_sampling_rejects_zero_or_too_few_samples() {
    // Every n_samples below 4 must be refused by config validation up front.
    let design = array![[1.0], [1.0], [1.0]];
    let response = array![0.5, -0.5, 1.0];
    let obs_weights = array![1.0, 1.0, 1.0];
    let s_matrix = array![[0.25]];
    let beta_mode = array![0.0];
    let h_matrix = array![[1.25]];
    let config_with = |n_samples: usize| NutsConfig {
        n_samples,
        nwarmup: 10,
        n_chains: 2,
        target_accept: 0.8,
        seed: 222,
    };

    for bad_samples in 0usize..4 {
        let config = config_with(bad_samples);
        let message = super::run_nuts_sampling(
            design.view(),
            response.view(),
            obs_weights.view(),
            s_matrix.view(),
            beta_mode.view(),
            h_matrix.view(),
            NutsFamily::Gaussian,
            1.0,
            gam_solve::estimate::Dispersion::Known(1.0),
            false,
            None,
            &config,
        )
        .expect_err("too-few samples must be rejected before sampling");

        assert!(
            message.contains("n_samples must be >= 4"),
            "n_samples={bad_samples} gave unexpected error: {message}"
        );
    }
}
2570
#[test]
fn polya_gamma_gibbs_rejects_degenerate_counts_but_accepts_single_chain() {
    // PG Gibbs must refuse n_chains = 0 and n_samples = 0 up front, yet treat a
    // single chain as a valid configuration that returns every requested draw.
    let design = array![[1.0], [1.0], [1.0], [1.0]];
    let labels = array![1.0, 0.0, 1.0, 0.0];
    let obs_weights = array![1.0, 1.0, 1.0, 1.0];
    let s_matrix = array![[0.25]];
    let beta_mode = array![0.0];

    // Only n_samples / n_chains vary between the three scenarios.
    let config = |n_samples: usize, n_chains: usize| NutsConfig {
        n_samples,
        nwarmup: 10,
        n_chains,
        target_accept: 0.8,
        seed: 7,
    };
    let gibbs = |cfg: &NutsConfig| {
        super::run_logit_polya_gamma_gibbs(
            design.view(),
            labels.view(),
            obs_weights.view(),
            s_matrix.view(),
            beta_mode.view(),
            cfg,
        )
    };

    let err = gibbs(&config(20, 0))
        .expect_err("PG Gibbs must reject zero chains up front, not return an empty posterior");
    assert!(
        err.contains("n_chains must be >= 1"),
        "PG n_chains=0 gave unexpected error: {err}"
    );

    let err = gibbs(&config(0, 2))
        .expect_err("PG Gibbs must reject zero samples up front, not return an empty posterior");
    assert!(
        err.contains("n_samples must be >= 4"),
        "PG n_samples=0 gave unexpected error: {err}"
    );

    let posterior =
        gibbs(&config(20, 1)).expect("PG Gibbs must accept a single chain and return draws");
    assert_eq!(
        posterior.samples.nrows(),
        20,
        "single-chain PG run should return all 20 requested draws"
    );
}
2651
#[test]
fn run_nuts_sampling_rejects_zero_chains_but_accepts_single_chain() {
    // n_chains = 0 is a configuration error; n_chains = 1 must sample normally.
    let design = array![[1.0], [1.0], [1.0]];
    let response = array![0.5, -0.5, 1.0];
    let obs_weights = array![1.0, 1.0, 1.0];
    let s_matrix = array![[0.25]];
    let beta_mode = array![0.0];
    let h_matrix = array![[1.25]];

    let config_for_chains = |n_chains: usize| NutsConfig {
        n_samples: 50,
        nwarmup: 10,
        n_chains,
        target_accept: 0.8,
        seed: 222,
    };
    let sample = |cfg: &NutsConfig| {
        super::run_nuts_sampling(
            design.view(),
            response.view(),
            obs_weights.view(),
            s_matrix.view(),
            beta_mode.view(),
            h_matrix.view(),
            NutsFamily::Gaussian,
            1.0,
            gam_solve::estimate::Dispersion::Known(1.0),
            false,
            None,
            cfg,
        )
    };

    let err = sample(&config_for_chains(0))
        .expect_err("zero chains must be rejected before sampling");
    assert!(
        err.contains("n_chains must be >= 1"),
        "n_chains=0 gave unexpected error: {err}"
    );

    let draws = sample(&config_for_chains(1))
        .expect("a single chain is a supported configuration and must return draws");
    assert_eq!(
        draws.samples.nrows(),
        50,
        "single-chain run should return all 50 requested draws"
    );
}
2729
#[test]
fn joint_hmc_boundary_rejects_nonbinomial_firth_family() {
    // Joint (beta, rho) HMC with Firth on a Poisson/log model must be refused
    // at the entry point with an explicit message.
    let design = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
    let counts = array![1.0, 2.0, 0.0, 3.0];
    let prior_weights = array![1.0, 1.0, 1.0, 1.0];
    let h_matrix = array![[1.5, 0.1], [0.1, 1.2]];
    let penalty_root = array![[0.4, 0.0], [0.0, 0.6]];
    let beta_mode = array![0.0, 0.0];
    let rho_start = array![0.0];
    let config = NutsConfig {
        n_samples: 20,
        nwarmup: 20,
        n_chains: 2,
        target_accept: 0.8,
        seed: 111,
    };

    let root_cols = penalty_root.ncols();
    let inputs = JointBetaRhoInputs {
        x: design.view(),
        y: counts.view(),
        weights: prior_weights.view(),
        likelihood: LikelihoodSpec {
            response: ResponseFamily::Poisson,
            link: InverseLink::Standard(StandardLink::Log),
        },
        gamma_shape: None,
        mode: beta_mode.view(),
        hessian: h_matrix.view(),
        penalty_roots: vec![CanonicalPenalty::from_dense_root(penalty_root, root_cols)],
        rho_mode: rho_start.view(),
        rho_prior: RhoPrior::default(),
        firth_bias_reduction: true,
        trigger_skewness: 0.75,
    };

    // `.err()` extracts the error without requiring `Debug` on the success type.
    let err = run_joint_beta_rho_sampling(&inputs, &config)
        .err()
        .expect("Poisson joint HMC Firth should be rejected explicitly");

    let expected = "Joint HMC with Firth requires a Binomial inverse link with a Fisher-weight jet";
    assert!(err.contains(expected), "unexpected error: {err}");
}
2780
#[test]
fn joint_hmc_uses_combined_penalty_logdet_for_overlapping_penalties() {
    // Two penalties acting on the same coefficients. Assuming
    // `from_dense_root(R, p)` represents S = RᵀR, S1 = I and S2 = diag(2, 1).
    // At rho = 0 the combined penalty is therefore S = diag(3, 2), and the
    // rho-gradient of 0.5 * log|S| is 0.5 * tr(S⁻¹ S_k): 5/12 for k = 0 and
    // 7/12 for k = 1. A per-penalty logdet would give different values.
    let x = array![[0.0, 0.0]];
    let y = array![0.0];
    // Zero weight and a zero design row, so the data term carries no signal.
    let w = array![0.0];
    let mode = array![0.0, 0.0];
    let hessian = array![[1.0, 0.0], [0.0, 1.0]];
    let rho_mode = array![0.0, 0.0];
    let penalty_1 = array![[1.0, 0.0], [0.0, 1.0]];
    let penalty_2 = array![[2.0_f64.sqrt(), 0.0], [0.0, 1.0]];
    let target = JointBetaRhoPosterior::new(
        x.view(),
        y.view(),
        w.view(),
        mode.view(),
        hessian.view(),
        vec![
            CanonicalPenalty::from_dense_root(penalty_1, 2),
            CanonicalPenalty::from_dense_root(penalty_2, 2),
        ],
        rho_mode.view(),
        LikelihoodSpec {
            response: ResponseFamily::Gaussian,
            link: InverseLink::Standard(StandardLink::Identity),
        },
        None,
        RhoPrior::Flat,
        false,
    )
    .expect("joint target");

    // Parameter layout: [beta_0, beta_1, rho_0, rho_1], all at zero.
    let params = array![0.0, 0.0, 0.0, 0.0];
    let (_, grad) = target.compute_joint_logp_and_grad(&params);
    assert!(
        (grad[2] - 5.0 / 12.0).abs() < 1.0e-10,
        "expected overlapping-penalty gradient 5/12, got {}",
        grad[2]
    );
    assert!(
        (grad[3] - 7.0 / 12.0).abs() < 1.0e-10,
        "expected overlapping-penalty gradient 7/12, got {}",
        grad[3]
    );
}
2825
#[test]
fn joint_hmc_target_does_not_depend_on_rho_mode_when_prior_is_fixed() {
    // Two targets that differ only in `rho_mode` (0.0 vs 2.5), under the same
    // fixed Normal rho prior, must produce identical log density and gradient.
    let x = array![[0.0]];
    let y = array![0.0];
    let w = array![0.0];
    let mode = array![0.0];
    let hessian = array![[1.0]];
    let penalty = CanonicalPenalty::from_dense_root(array![[1.0]], 1);
    let prior = RhoPrior::Normal {
        mean: 0.25,
        sd: 1.7,
    };

    let target_a = JointBetaRhoPosterior::new(
        x.view(),
        y.view(),
        w.view(),
        mode.view(),
        hessian.view(),
        vec![penalty.clone()],
        array![0.0].view(),
        LikelihoodSpec {
            response: ResponseFamily::Gaussian,
            link: InverseLink::Standard(StandardLink::Identity),
        },
        None,
        prior.clone(),
        false,
    )
    .expect("target a");
    let target_b = JointBetaRhoPosterior::new(
        x.view(),
        y.view(),
        w.view(),
        mode.view(),
        hessian.view(),
        vec![penalty],
        array![2.5].view(),
        LikelihoodSpec {
            response: ResponseFamily::Gaussian,
            link: InverseLink::Standard(StandardLink::Identity),
        },
        None,
        prior,
        false,
    )
    .expect("target b");

    // Parameter layout: [beta_0, rho_0].
    let params = array![0.0, -0.4];
    let (lp_a, grad_a) = target_a.compute_joint_logp_and_grad(&params);
    let (lp_b, grad_b) = target_b.compute_joint_logp_and_grad(&params);
    assert!(
        (lp_a - lp_b).abs() < 1.0e-12,
        "rho_mode leaked into target log density: {} vs {}",
        lp_a,
        lp_b
    );
    // Compare lengths first so an extra gradient entry cannot go unnoticed.
    assert_eq!(
        grad_a.len(),
        grad_b.len(),
        "gradient length differs between rho_mode settings"
    );
    for i in 0..grad_a.len() {
        assert!(
            (grad_a[i] - grad_b[i]).abs() < 1.0e-12,
            "rho_mode leaked into target gradient at {}: {} vs {}",
            i,
            grad_a[i],
            grad_b[i]
        );
    }
}
2888
#[test]
fn joint_hmc_binomial_sas_uses_runtime_link_state() {
    // A non-trivial SAS link state (epsilon = 0.4, log_delta = -0.2) must be
    // honoured by the joint likelihood, giving a different log-likelihood
    // from plain logit at the same linear predictor.
    let design = array![[1.0], [1.0]];
    let outcomes = array![1.0, 0.0];
    let obs_weights = array![1.0, 1.0];
    let eta = array![0.3, -0.2];
    let sas_state = gam_solve::mixture_link::state_from_sasspec(gam_problem::types::SasLinkSpec {
        initial_epsilon: 0.4,
        initial_log_delta: -0.2,
    })
    .expect("sas state");
    let data = SharedData {
        x: Arc::new(design),
        y: Arc::new(outcomes),
        weights: Arc::new(obs_weights),
        mode: Arc::new(Array1::zeros(1)),
        offset: None,
        gamma_shape: 1.0,
        dispersion: gam_solve::estimate::Dispersion::Known(1.0),
        n_samples: 2,
        dim: 1,
    };

    // Binomial response; only the inverse link varies between the two calls.
    let binomial_loglik = |link: InverseLink| {
        joint_family_logp_and_grad(
            &LikelihoodSpec {
                response: ResponseFamily::Binomial,
                link,
            },
            &data,
            &eta,
        )
    };

    let (ll_sas, _) = binomial_loglik(InverseLink::Sas(sas_state)).expect("sas joint logp");
    let (ll_logit, _) = binomial_loglik(InverseLink::Standard(StandardLink::Logit))
        .expect("logit joint logp");

    let gap = (ll_sas - ll_logit).abs();
    assert!(
        gap > 1.0e-6,
        "adaptive SAS link should not collapse to the logit likelihood"
    );
}
2936
#[test]
fn directional_cubic_diagnostic_is_rotation_invariant_for_hessian_eigenvectors() {
    // Rotating coefficient space by an orthogonal Q (X -> XQ, H -> QᵀHQ) is a
    // reparameterisation. The diagnostic's maximum and the multiset of
    // |per-direction values| must therefore be unchanged.
    let x = array![[1.0, 0.5], [-0.3, 1.4], [0.8, -1.1]];
    let c = array![0.7, -0.5, 0.2];
    let h = array![[4.0, 0.0], [0.0, 1.0]];
    let theta = std::f64::consts::FRAC_PI_4;
    let q = array![[theta.cos(), -theta.sin()], [theta.sin(), theta.cos()],];
    let x_rot = x.dot(&q);
    let h_rot = q.t().dot(&h).dot(&q);

    let (base_max, base_vals) = laplace_directional_cubic_diagnostic(
        &h,
        &DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(x)),
        &c,
        true,
    )
    .expect("base diagnostic");
    let (rot_max, rot_vals) = laplace_directional_cubic_diagnostic(
        &h_rot,
        &DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(x_rot)),
        &c,
        true,
    )
    .expect("rotated diagnostic");

    // Eigenvector signs and order are arbitrary, so compare sorted magnitudes.
    let mut base_abs: Vec<f64> = base_vals.iter().map(|v| v.abs()).collect();
    let mut rot_abs: Vec<f64> = rot_vals.iter().map(|v| v.abs()).collect();
    base_abs.sort_by(|a, b| a.partial_cmp(b).expect("finite compare"));
    rot_abs.sort_by(|a, b| a.partial_cmp(b).expect("finite compare"));

    assert!(
        (base_max - rot_max).abs() < 1.0e-10,
        "directional diagnostic max changed under rotation: {} vs {}",
        base_max,
        rot_max
    );
    // Without this, extra rotated directions would be silently ignored.
    assert_eq!(
        base_abs.len(),
        rot_abs.len(),
        "directional diagnostic count changed under rotation"
    );
    for (i, (b, r)) in base_abs.iter().zip(rot_abs.iter()).enumerate() {
        assert!(
            (b - r).abs() < 1.0e-10,
            "directional diagnostic changed under rotation at {}: {} vs {}",
            i,
            b,
            r
        );
    }
}
2978
#[test]
fn joint_hmc_penalty_logdet_agrees_with_reml_path() {
    use gam_solve::estimate::reml::penalty_logdet::PenaltyPseudologdet;

    // Two overlapping 3x3 penalties supplied as 2x3 roots. Each root has only
    // two rows, so each penalty is individually rank-deficient.
    let root_1 = array![[1.0, 0.5, 0.0], [0.0, 0.8, 0.3]];
    let root_2 = array![[0.0, 0.7, 0.0], [0.0, 0.0, 1.2]];
    let cp1 = CanonicalPenalty::from_dense_root(root_1, 3);
    let cp2 = CanonicalPenalty::from_dense_root(root_2, 3);
    let lambdas = [2.5_f64, 0.8];
    let penalties = [cp1.clone(), cp2.clone()];

    // REML reference: pseudo-logdet of the lambda-weighted penalty sum and its
    // first and second rho-derivatives.
    let pld =
        PenaltyPseudologdet::from_penalties(&penalties, &lambdas, 0.0, 3).expect("reml pld");
    let reml_value = pld.value();
    let (reml_d1, reml_d2) = pld.rho_derivatives_from_penalties(&penalties, &lambdas);

    // Joint target with a zero design and zero weights. At beta = 0, logp and
    // the rho-gradient are expected to reduce to 0.5 * the REML logdet terms.
    let x = Array2::<f64>::zeros((1, 3));
    let y = array![0.0];
    let w = array![0.0];
    let mode = Array1::<f64>::zeros(3);
    let hessian = Array2::<f64>::eye(3);
    let rho = Array1::from_vec(lambdas.iter().map(|l| l.ln()).collect());
    let target = JointBetaRhoPosterior::new(
        x.view(),
        y.view(),
        w.view(),
        mode.view(),
        hessian.view(),
        vec![cp1, cp2],
        rho.view(),
        LikelihoodSpec {
            response: ResponseFamily::Gaussian,
            link: InverseLink::Standard(StandardLink::Identity),
        },
        None,
        RhoPrior::Flat,
        false,
    )
    .expect("joint target");

    // Parameter layout: [beta_0, beta_1, beta_2, rho_0, rho_1], with beta = 0.
    let mut params = Array1::<f64>::zeros(3 + 2);
    params[3] = rho[0];
    params[4] = rho[1];
    let (logp, grad) = target.compute_joint_logp_and_grad(&params);

    assert!(
        (logp - 0.5 * reml_value).abs() < 1.0e-8,
        "joint HMC logdet value {} vs REML 0.5*{} = {}",
        logp,
        reml_value,
        0.5 * reml_value,
    );

    for k in 0..2 {
        assert!(
            (grad[3 + k] - 0.5 * reml_d1[k]).abs() < 1.0e-8,
            "joint HMC logdet gradient[{}] = {} vs REML 0.5*{} = {}",
            k,
            grad[3 + k],
            reml_d1[k],
            0.5 * reml_d1[k],
        );
    }

    assert!(
        (reml_d2[[0, 1]] - reml_d2[[1, 0]]).abs() < 1.0e-12,
        "REML penalty logdet Hessian not symmetric"
    );
}
3061
#[test]
fn joint_hmc_family_gating_never_remaps() {
    // The joint HMC family gate must either evaluate a spec with its own
    // family/link or reject it; it must never substitute a different one.
    let data = SharedData {
        x: Arc::new(array![[1.0], [1.0]]),
        y: Arc::new(array![1.0, 0.0]),
        weights: Arc::new(array![1.0, 1.0]),
        mode: Arc::new(Array1::zeros(1)),
        offset: None,
        gamma_shape: 1.0,
        dispersion: gam_solve::estimate::Dispersion::Known(1.0),
        n_samples: 2,
        dim: 1,
    };
    let eta = array![0.1, -0.1];

    // Standard family/link pairs that must be accepted as-is.
    let accepted = [
        LikelihoodSpec {
            response: ResponseFamily::Binomial,
            link: InverseLink::Standard(StandardLink::Logit),
        },
        LikelihoodSpec {
            response: ResponseFamily::Binomial,
            link: InverseLink::Standard(StandardLink::Probit),
        },
        LikelihoodSpec {
            response: ResponseFamily::Binomial,
            link: InverseLink::Standard(StandardLink::CLogLog),
        },
        LikelihoodSpec {
            response: ResponseFamily::Gaussian,
            link: InverseLink::Standard(StandardLink::Identity),
        },
        LikelihoodSpec {
            response: ResponseFamily::Poisson,
            link: InverseLink::Standard(StandardLink::Log),
        },
        LikelihoodSpec {
            response: ResponseFamily::Gamma,
            link: InverseLink::Standard(StandardLink::Log),
        },
    ];
    for spec in &accepted {
        let result = joint_family_logp_and_grad(spec, &data, &eta);
        assert!(
            result.is_ok(),
            "spec {:?} should be accepted but got error: {:?}",
            spec,
            result.err(),
        );
    }

    // Adaptive Binomial links must be accepted with their own runtime state.
    let sas_state = gam_solve::mixture_link::state_from_sasspec(gam_problem::types::SasLinkSpec {
        initial_epsilon: 0.0,
        initial_log_delta: 0.0,
    })
    .expect("sas state");
    let adaptive = [
        LikelihoodSpec {
            response: ResponseFamily::Binomial,
            link: InverseLink::Sas(sas_state),
        },
        LikelihoodSpec {
            response: ResponseFamily::Binomial,
            link: InverseLink::BetaLogistic(
                gam_solve::mixture_link::state_from_sasspec(gam_problem::types::SasLinkSpec {
                    initial_epsilon: 0.0,
                    initial_log_delta: 0.0,
                })
                .expect("bl state"),
            ),
        },
    ];
    for spec in &adaptive {
        let result = joint_family_logp_and_grad(spec, &data, &eta);
        // Report the actual error, matching the accepted-spec loop above.
        assert!(
            result.is_ok(),
            "adaptive spec {:?} should be accepted with its real link, but got error: {:?}",
            spec,
            result.err(),
        );
    }

    // A family the joint path does not support must be rejected outright.
    let rp_result = joint_family_logp_and_grad(
        &LikelihoodSpec {
            response: ResponseFamily::RoystonParmar,
            link: InverseLink::Standard(StandardLink::Logit),
        },
        &data,
        &eta,
    );
    assert!(
        rp_result.is_err(),
        "RoystonParmar should be rejected, not silently accepted"
    );
}
3164
#[test]
fn directional_cubic_power_iteration_finds_larger_or_equal_skewness() {
    // The reported maximum (power-iteration search) must be at least as large
    // as the best value found along the Hessian eigenvectors alone.
    let design = array![
        [2.0, 1.0],
        [-1.0, 2.0],
        [0.5, -0.5],
        [1.5, 0.3],
        [-0.8, 1.7],
    ];
    let third_derivs = array![1.0, -0.5, 0.3, -0.7, 0.4];
    let h_matrix = array![[3.0, 1.0], [1.0, 2.0]];

    let (searched_max, per_eigenvector) = laplace_directional_cubic_diagnostic(
        &h_matrix,
        &DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(design)),
        &third_derivs,
        true,
    )
    .expect("diagnostic");

    let eig_max = per_eigenvector
        .iter()
        .map(|v| v.abs())
        .fold(0.0_f64, f64::max);
    assert!(
        searched_max >= eig_max - 1.0e-12,
        "power iteration result {} should be >= eigenvector max {}",
        searched_max,
        eig_max,
    );
}
3201
#[test]
fn laplace_trustworthiness_is_block_local_and_threshold_shrinks_with_n() {
    let skewness = array![0.01, 0.9];

    // At n = 100 only the strongly skewed direction (index 1) is flagged.
    let verdict = laplace_trustworthiness_from_skewness(&skewness, 100.0);
    assert_eq!(
        verdict.untrustworthy_directions,
        vec![1],
        "only the strongly-skewed direction should be flagged (block-local)",
    );
    assert!(verdict.fallback_required());
    assert!((verdict.max_abs_skewness - 0.9).abs() < 1e-12);

    // The acceptable skewness must shrink as the sample size grows.
    let (t_small, t_large) = (
        laplace_skewness_threshold(25.0),
        laplace_skewness_threshold(10_000.0),
    );
    assert!(
        t_large < t_small,
        "validity threshold must tighten with sample size: {t_large} !< {t_small}",
    );

    // n = 0 yields an infinite threshold and therefore no fallback.
    let unbounded = laplace_trustworthiness_from_skewness(&skewness, 0.0);
    assert!(!unbounded.fallback_required());
    assert!(unbounded.threshold.is_infinite());
}
3238
3239 struct AnharmonicBlock {
3245 lambdas: Array1<f64>,
3246 a: f64,
3247 }
3248 impl super::BlockExcessTarget for AnharmonicBlock {
3249 fn block_dim(&self) -> usize {
3250 self.lambdas.len()
3251 }
3252 fn rho_dim(&self) -> usize {
3253 self.lambdas.len()
3254 }
3255 fn block_curvatures(&self) -> &Array1<f64> {
3256 &self.lambdas
3257 }
3258 fn excess(&self, t: &Array1<f64>) -> f64 {
3259 self.a * t.iter().map(|&x| x.powi(4)).sum::<f64>()
3260 }
3261 fn excess_rho_gradient(&self, t: &Array1<f64>) -> Array1<f64> {
3262 t.mapv(|x| self.a * x.powi(4))
3263 }
3264 fn displaced_neg_score(&self, t: &Array1<f64>) -> Array1<f64> {
3265 assert_eq!(t.len(), self.block_dim(), "displacement dim mismatch");
3269 Array1::zeros(0)
3270 }
3271 fn base_neg_score(&self) -> Array1<f64> {
3272 Array1::zeros(0)
3273 }
3274 }
3275
#[test]
fn block_sampled_marginal_is_zero_for_gaussian_block() {
    // With a = 0 the excess vanishes identically, so the sampled correction
    // and its rho-gradient must both be zero up to round-off.
    let gaussian = AnharmonicBlock {
        lambdas: array![2.0, 0.5],
        a: 0.0,
    };
    let correction =
        super::block_sampled_marginal_correction(&gaussian).expect("correction");

    assert!(
        correction.value.abs() < 1e-12,
        "Gaussian block value {}",
        correction.value
    );
    assert!(correction.rho_gradient.iter().all(|g| g.abs() < 1e-12));
    assert!(correction.n_draws > 0);
}
3295
#[test]
fn block_sampled_marginal_recovers_analytic_quartic_correction() {
    let lambda = 3.0_f64;
    let a = 0.05_f64;
    let target = AnharmonicBlock {
        lambdas: array![lambda],
        a,
    };
    let out = super::block_sampled_marginal_correction(&target).expect("correction");

    // Reference: log E_{t ~ N(0, 1/lambda)}[exp(-a t^4)], computed with the
    // trapezoid rule on [-8σ, 8σ].
    let sigma = lambda.recip().sqrt();
    let n_nodes = 20_001;
    let lower = -8.0 * sigma;
    let upper = 8.0 * sigma;
    let spacing = (upper - lower) / (n_nodes as f64 - 1.0);
    let normaliser = sigma * (2.0 * std::f64::consts::PI).sqrt();
    let integral = (0..n_nodes).fold(0.0_f64, |acc, i| {
        let tt = lower + spacing * i as f64;
        let gauss = (-(tt * tt) / (2.0 * sigma * sigma)).exp() / normaliser;
        let end_weight = if i == 0 || i == n_nodes - 1 { 0.5 } else { 1.0 };
        acc + end_weight * gauss * (-a * tt.powi(4)).exp() * spacing
    });
    let reference = integral.ln();

    assert!(
        (out.value - reference).abs() < 5e-3,
        "sampled Δ_b {} vs reference {}",
        out.value,
        reference,
    );
    assert!(out.value < 0.0, "quartic penalty must shrink block mass");
}
3336
3337 struct MatvecBlock {
3345 lambdas: Array1<f64>,
3346 x: Array2<f64>,
3347 v_b: Array2<f64>,
3348 y: Array1<f64>,
3349 batched: bool,
3350 }
3351 impl MatvecBlock {
3352 fn s_of(&self, t: &Array1<f64>) -> Array1<f64> {
3353 let delta = self.v_b.dot(t);
3354 gam_linalg::faer_ndarray::fast_av(&self.x, &delta)
3355 }
3356 fn excess_and_ngs(&self, s: &Array1<f64>) -> (f64, Array1<f64>) {
3358 let mut excess = 0.0;
3359 let mut ngs = Array1::<f64>::zeros(s.len());
3360 for i in 0..s.len() {
3361 let mu = (self.y[i] + s[i]).tanh();
3362 excess += 0.5 * s[i] * s[i] - 0.1 * mu;
3363 ngs[i] = mu - self.y[i];
3364 }
3365 (excess, ngs)
3366 }
3367 }
3368 impl super::BlockExcessTarget for MatvecBlock {
3369 fn block_dim(&self) -> usize {
3370 self.lambdas.len()
3371 }
3372 fn rho_dim(&self) -> usize {
3373 self.lambdas.len()
3374 }
3375 fn block_curvatures(&self) -> &Array1<f64> {
3376 &self.lambdas
3377 }
3378 fn excess(&self, t: &Array1<f64>) -> f64 {
3379 self.excess_and_ngs(&self.s_of(t)).0
3380 }
3381 fn excess_rho_gradient(&self, t: &Array1<f64>) -> Array1<f64> {
3382 t.mapv(|x| 0.01 * x)
3383 }
3384 fn displaced_neg_score(&self, t: &Array1<f64>) -> Array1<f64> {
3385 self.excess_and_ngs(&self.s_of(t)).1
3386 }
3387 fn base_neg_score(&self) -> Array1<f64> {
3388 self.excess_and_ngs(&self.s_of(&Array1::zeros(self.block_dim())))
3389 .1
3390 }
3391 fn excess_with_displaced_neg_score_batch(
3392 &self,
3393 draws: &Array2<f64>,
3394 ) -> Vec<(f64, Option<Array1<f64>>)> {
3395 if !self.batched {
3396 let mut out = Vec::with_capacity(draws.ncols());
3398 let mut t = Array1::<f64>::zeros(draws.nrows());
3399 for s in 0..draws.ncols() {
3400 t.assign(&draws.column(s));
3401 out.push(self.excess_with_displaced_neg_score(&t));
3402 }
3403 return out;
3404 }
3405 let delta_all = gam_linalg::faer_ndarray::fast_ab(&self.v_b, draws);
3407 let s_all = gam_linalg::faer_ndarray::fast_ab(&self.x, &delta_all);
3408 (0..draws.ncols())
3409 .map(|c| {
3410 let (e, ngs) = self.excess_and_ngs(&s_all.column(c).to_owned());
3411 if e.is_finite() {
3412 (e, Some(ngs))
3413 } else {
3414 (e, None)
3415 }
3416 })
3417 .collect()
3418 }
3419 }
3420
#[test]
fn block_sampled_marginal_batched_matches_serial_matvec() {
    // The batched GEMM path and the serial per-draw path must give the same
    // correction, rho-gradient and moments to near machine precision.
    let n = 80usize;
    let p = 40usize;
    let m = 3usize;
    let mut x = Array2::<f64>::zeros((n, p));
    for i in 0..n {
        for j in 0..p {
            x[(i, j)] = ((i * 7 + j * 13) % 11) as f64 * 0.05 - 0.25;
        }
    }
    let mut v_b = Array2::<f64>::zeros((p, m));
    for i in 0..p {
        for r in 0..m {
            v_b[(i, r)] = ((i * 3 + r * 5) % 7) as f64 * 0.1 - 0.3;
        }
    }
    let y: Array1<f64> = (0..n).map(|i| ((i % 5) as f64) * 0.2).collect();
    let lambdas = array![2.0, 1.0, 0.5];

    let serial = super::block_sampled_marginal_correction(&MatvecBlock {
        lambdas: lambdas.clone(),
        x: x.clone(),
        v_b: v_b.clone(),
        y: y.clone(),
        batched: false,
    })
    .expect("serial");
    let batched = super::block_sampled_marginal_correction(&MatvecBlock {
        lambdas,
        x,
        v_b,
        y,
        batched: true,
    })
    .expect("batched");

    // Relative tolerance shared by every comparison below.
    let close = |a: f64, b: f64| (a - b).abs() <= 1e-10 * (1.0 + a.abs());

    assert_eq!(serial.n_draws, batched.n_draws);
    assert!(
        close(serial.value, batched.value),
        "value serial {} vs batched {}",
        serial.value,
        batched.value
    );
    assert_eq!(
        serial.rho_gradient.len(),
        batched.rho_gradient.len(),
        "rho_gradient length serial vs batched"
    );
    for k in 0..serial.rho_gradient.len() {
        assert!(
            close(serial.rho_gradient[k], batched.rho_gradient[k]),
            "rho_gradient[{k}] serial {} vs batched {}",
            serial.rho_gradient[k],
            batched.rho_gradient[k]
        );
    }

    let ms = serial.moments.expect("serial moments");
    let mb = batched.moments.expect("batched moments");
    // Length checks first: `zip` would otherwise silently drop unmatched entries.
    assert_eq!(ms.e_t.len(), mb.e_t.len(), "e_t length serial vs batched");
    for (a, b) in ms.e_t.iter().zip(mb.e_t.iter()) {
        assert!(close(*a, *b), "e_t {a} vs {b}");
    }
    assert_eq!(
        ms.e_neg_score.len(),
        mb.e_neg_score.len(),
        "e_neg_score length serial vs batched"
    );
    for (a, b) in ms.e_neg_score.iter().zip(mb.e_neg_score.iter()) {
        assert!(close(*a, *b), "e_neg_score {a} vs {b}");
    }
    assert_eq!(
        ms.e_t_neg_score.len(),
        mb.e_t_neg_score.len(),
        "e_t_neg_score length serial vs batched"
    );
    for (a, b) in ms.e_t_neg_score.iter().zip(mb.e_t_neg_score.iter()) {
        assert!(close(*a, *b), "e_t_neg_score {a} vs {b}");
    }
}
3495
#[test]
fn logit_pg_rao_blackwell_returns_finite_terms() {
    // One penalty in, one finite, non-negative Rao-Blackwell term out.
    let design = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
    let labels = array![1.0, 0.0, 1.0, 0.0];
    let obs_weights = array![1.0, 1.0, 1.0, 1.0];
    let s_matrix = array![[0.2, 0.0], [0.0, 0.4]];
    let beta_mode = array![0.0, 0.0];
    // A diagonal root R with RᵀR equal to `s_matrix`.
    let roots = vec![array![[0.2_f64.sqrt(), 0.0], [0.0, 0.4_f64.sqrt()]]];
    let config = NutsConfig {
        n_samples: 30,
        nwarmup: 30,
        n_chains: 2,
        target_accept: 0.8,
        seed: 789,
    };

    let terms = super::estimate_logit_pg_rao_blackwell_terms(
        design.view(),
        labels.view(),
        obs_weights.view(),
        s_matrix.view(),
        beta_mode.view(),
        &roots,
        &config,
    )
    .expect("rao-blackwell PG should run");

    assert_eq!(terms.len(), 1);
    let term = terms[0];
    assert!(term.is_finite());
    assert!(term >= 0.0);
}
3527
#[test]
fn logit_pg_rao_blackwell_rejects_non_bernoulli_response() {
    // A fractional response (0.25) must be rejected by the PG Rao-Blackwell
    // estimator with an explicit "exactly 0 or 1" message.
    let x = array![[1.0], [1.0]];
    let y = array![0.25, 1.0];
    let w = array![1.0, 1.0];
    let penalty = array![[0.1]];
    let mode = array![0.0];
    let roots = vec![array![[0.1_f64.sqrt()]]];
    // Use a valid sampler config (n_samples >= 4, n_chains >= 1) so the only
    // reason to fail is the response check. With n_samples: 1 the outcome
    // would depend on whether config or response validation runs first.
    let cfg = NutsConfig {
        n_samples: 20,
        nwarmup: 10,
        n_chains: 1,
        target_accept: 0.8,
        seed: 654,
    };

    let result = super::estimate_logit_pg_rao_blackwell_terms(
        x.view(),
        y.view(),
        w.view(),
        penalty.view(),
        mode.view(),
        &roots,
        &cfg,
    );

    let err = result
        .err()
        .expect("PG Rao-Blackwell should reject proportion rows");
    assert!(
        err.contains("response must be exactly 0 or 1"),
        "unexpected error: {err}"
    );
}
3562
#[test]
fn logit_pg_rao_blackwell_matches_beta_quadratic_moment_sanity() {
    // The Rao-Blackwellised term for the single penalty should agree loosely
    // with the plain Monte Carlo average of βᵀSβ over PG Gibbs draws.
    let design = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
    let labels = array![1.0, 0.0, 1.0, 0.0];
    let obs_weights = array![1.0, 1.0, 1.0, 1.0];
    let penalty = array![[0.2, 0.0], [0.0, 0.4]];
    let beta_mode = array![0.0, 0.0];
    let roots = vec![array![[0.2_f64.sqrt(), 0.0], [0.0, 0.4_f64.sqrt()]]];
    let config = NutsConfig {
        n_samples: 120,
        nwarmup: 80,
        n_chains: 2,
        target_accept: 0.8,
        seed: 901,
    };

    let gibbs = run_logit_polya_gamma_gibbs(
        design.view(),
        labels.view(),
        obs_weights.view(),
        penalty.view(),
        beta_mode.view(),
        &config,
    )
    .expect("pg gibbs should run");
    let n_draws = gibbs.samples.nrows() as f64;
    let quad_total = gibbs
        .samples
        .rows()
        .into_iter()
        .map(|beta| beta.dot(&penalty.dot(&beta.to_owned())))
        .sum::<f64>();
    let mc_quad = quad_total / n_draws;

    let rb = super::estimate_logit_pg_rao_blackwell_terms(
        design.view(),
        labels.view(),
        obs_weights.view(),
        penalty.view(),
        beta_mode.view(),
        &roots,
        &config,
    )
    .expect("rao-blackwell PG should run");

    let diff = (rb[0] - mc_quad).abs();
    assert!(
        diff < 0.35,
        "Rao-Blackwell vs beta-moment mismatch too large: rb={}, mc={}, diff={}",
        rb[0],
        mc_quad,
        diff
    );
}
3619
#[test]
fn survival_hmc_structural_monotonic_returns_finitevalues() {
    // A single event row evaluated at the origin with the structural monotone
    // path enabled must give a finite log density and gradient.
    let entry_age = array![1.0];
    let exit_age = array![2.0];
    let target_event = array![1u8];
    let competing_event = array![0u8];
    let row_weight = array![1.0];
    let design_entry = array![[1.0, 0.2]];
    let design_exit = array![[1.0, 0.6]];
    let design_deriv = array![[0.0, 1.0]];
    let no_penalties = PenaltyBlocks::new(Vec::new());
    let monotone = SurvivalMonotonicityPenalty { tolerance: 3.0 };
    let beta_mode = array![0.0, 0.0];
    let h_matrix = Array2::<f64>::eye(2);

    let posterior = super::survival_hmc::SurvivalPosterior::new(
        entry_age.view(),
        exit_age.view(),
        target_event.view(),
        competing_event.view(),
        row_weight.view(),
        design_entry.view(),
        design_exit.view(),
        design_deriv.view(),
        None,
        None,
        None,
        no_penalties,
        monotone,
        SurvivalSpec::Net,
        true,
        2,
        beta_mode.view(),
        h_matrix.view(),
    )
    .expect("construct survival posterior");

    let origin = array![0.0, 0.0];
    let mut gradient = Array1::<f64>::zeros(2);
    let log_density = HamiltonianTarget::logp_and_grad(&posterior, &origin, &mut gradient);
    assert!(log_density.is_finite());
    assert!(gradient.iter().all(|g| g.is_finite()));
}
3663
#[test]
fn survival_hmc_structural_monotonic_differs_from_linear_geometry() {
    // The same data evaluated with the linear fallback (flag false, 0) and the
    // structural monotone path (flag true, 2) must give different, finite
    // gradients in the first coordinate.
    let entry_age = array![1.0];
    let exit_age = array![2.0];
    let target_event = array![1u8];
    let competing_event = array![0u8];
    let row_weight = array![1.0];
    let design_entry = array![[0.2, 0.1]];
    let design_exit = array![[0.6, 0.3]];
    let design_deriv = array![[1.0, 0.0]];
    let monotone = SurvivalMonotonicityPenalty { tolerance: 3.0 };
    let beta_mode = array![0.0, 0.0];
    let h_matrix = Array2::<f64>::eye(2);
    let position = array![std::f64::consts::LN_2, 0.0];

    let linear = super::survival_hmc::SurvivalPosterior::new(
        entry_age.view(),
        exit_age.view(),
        target_event.view(),
        competing_event.view(),
        row_weight.view(),
        design_entry.view(),
        design_exit.view(),
        design_deriv.view(),
        None,
        None,
        None,
        PenaltyBlocks::new(Vec::new()),
        monotone,
        SurvivalSpec::Net,
        false,
        0,
        beta_mode.view(),
        h_matrix.view(),
    )
    .expect("construct linear posterior");
    let mut linear_grad = Array1::<f64>::zeros(2);
    HamiltonianTarget::logp_and_grad(&linear, &position, &mut linear_grad);

    let structural = super::survival_hmc::SurvivalPosterior::new(
        entry_age.view(),
        exit_age.view(),
        target_event.view(),
        competing_event.view(),
        row_weight.view(),
        design_entry.view(),
        design_exit.view(),
        design_deriv.view(),
        None,
        None,
        None,
        PenaltyBlocks::new(Vec::new()),
        monotone,
        SurvivalSpec::Net,
        true,
        2,
        beta_mode.view(),
        h_matrix.view(),
    )
    .expect("construct structural posterior");
    let mut structural_grad = Array1::<f64>::zeros(2);
    HamiltonianTarget::logp_and_grad(&structural, &position, &mut structural_grad);

    let gap = (structural_grad[0] - linear_grad[0]).abs();
    assert!(
        gap > 1e-6,
        "expected structural and linear fallback gradients to differ"
    );
    assert!(structural_grad[0].is_finite());
    assert!(linear_grad[0].is_finite());
}
3734
3735 #[test]
3736 fn survival_hmc_fallback_barrier_rejects_offsets_below_monotonicity_threshold() {
3737 let age_entry = array![1.0];
3738 let age_exit = array![2.0];
3739 let event_target = array![1u8];
3740 let event_competing = array![0u8];
3741 let sampleweight = array![1.0];
3742 let x_entry = array![[1.0, 0.0]];
3743 let x_exit = array![[1.0, 0.0]];
3744 let x_derivative = array![[0.0, 0.0]];
3746 let penalties = PenaltyBlocks::new(Vec::new());
3747 let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
3748 let mode = array![0.0, 0.0];
3749 let hessian = Array2::<f64>::eye(2);
3750 let z = array![0.0, 0.0];
3751
3752 let posterior_no_offset = super::survival_hmc::SurvivalPosterior::new(
3753 age_entry.view(),
3754 age_exit.view(),
3755 event_target.view(),
3756 event_competing.view(),
3757 sampleweight.view(),
3758 x_entry.view(),
3759 x_exit.view(),
3760 x_derivative.view(),
3761 None,
3762 None,
3763 Some(array![0.0].view()),
3764 penalties.clone(),
3765 monotonicity,
3766 SurvivalSpec::Net,
3767 false,
3768 0,
3769 mode.view(),
3770 hessian.view(),
3771 )
3772 .expect("construct posterior without derivative offset");
3773 let mut grad_no_offset = Array1::<f64>::zeros(2);
3774 let logp_no_offset =
3775 HamiltonianTarget::logp_and_grad(&posterior_no_offset, &z, &mut grad_no_offset);
3776
3777 let posteriorwith_offset = super::survival_hmc::SurvivalPosterior::new(
3778 age_entry.view(),
3779 age_exit.view(),
3780 event_target.view(),
3781 event_competing.view(),
3782 sampleweight.view(),
3783 x_entry.view(),
3784 x_exit.view(),
3785 x_derivative.view(),
3786 None,
3787 None,
3788 Some(array![2.0].view()),
3789 penalties,
3790 monotonicity,
3791 SurvivalSpec::Net,
3792 false,
3793 0,
3794 mode.view(),
3795 hessian.view(),
3796 )
3797 .expect("construct posterior with derivative offset");
3798 let mut gradwith_offset = Array1::<f64>::zeros(2);
3799 let logpwith_offset =
3800 HamiltonianTarget::logp_and_grad(&posteriorwith_offset, &z, &mut gradwith_offset);
3801
3802 assert!(!logp_no_offset.is_finite());
3803 assert!(!logpwith_offset.is_finite());
3804 assert!(grad_no_offset.iter().all(|v| *v == 0.0));
3805 assert!(gradwith_offset.iter().all(|v| *v == 0.0));
3806 }
3807
3808 #[test]
3809 fn survival_hmc_fallback_barrier_becomes_finite_once_offset_clears_guard() {
3810 let age_entry = array![1.0];
3811 let age_exit = array![2.0];
3812 let event_target = array![1u8];
3813 let event_competing = array![0u8];
3814 let sampleweight = array![1.0];
3815 let x_entry = array![[1.0, 0.0]];
3816 let x_exit = array![[1.0, 0.0]];
3817 let x_derivative = array![[0.0, 0.0]];
3818 let penalties = PenaltyBlocks::new(Vec::new());
3819 let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
3820 let mode = array![0.0, 0.0];
3821 let hessian = Array2::<f64>::eye(2);
3822 let z = array![0.0, 0.0];
3823
3824 let posterior_below_guard = super::survival_hmc::SurvivalPosterior::new(
3825 age_entry.view(),
3826 age_exit.view(),
3827 event_target.view(),
3828 event_competing.view(),
3829 sampleweight.view(),
3830 x_entry.view(),
3831 x_exit.view(),
3832 x_derivative.view(),
3833 None,
3834 None,
3835 Some(array![2.0].view()),
3836 penalties.clone(),
3837 monotonicity,
3838 SurvivalSpec::Net,
3839 false,
3840 0,
3841 mode.view(),
3842 hessian.view(),
3843 )
3844 .expect("construct posterior below derivative guard");
3845 let mut grad_below_guard = Array1::<f64>::zeros(2);
3846 let logp_below_guard =
3847 HamiltonianTarget::logp_and_grad(&posterior_below_guard, &z, &mut grad_below_guard);
3848
3849 let posterior_above_guard = super::survival_hmc::SurvivalPosterior::new(
3850 age_entry.view(),
3851 age_exit.view(),
3852 event_target.view(),
3853 event_competing.view(),
3854 sampleweight.view(),
3855 x_entry.view(),
3856 x_exit.view(),
3857 x_derivative.view(),
3858 None,
3859 None,
3860 Some(array![3.1].view()),
3861 penalties,
3862 monotonicity,
3863 SurvivalSpec::Net,
3864 false,
3865 0,
3866 mode.view(),
3867 hessian.view(),
3868 )
3869 .expect("construct posterior above derivative guard");
3870 let mut grad_above_guard = Array1::<f64>::zeros(2);
3871 let logp_above_guard =
3872 HamiltonianTarget::logp_and_grad(&posterior_above_guard, &z, &mut grad_above_guard);
3873
3874 assert!(!logp_below_guard.is_finite());
3875 assert!(logp_above_guard.is_finite());
3876 assert!(grad_below_guard.iter().all(|v| *v == 0.0));
3877 assert!(grad_above_guard.iter().all(|v| v.is_finite()));
3878 }
3879
3880 #[test]
3881 fn survival_hmc_structural_monotonic_handles_sparse_multirow_geometry() {
3882 let age_entry = array![1.0, 1.2];
3883 let age_exit = array![2.0, 2.4];
3884 let event_target = array![1u8, 1u8];
3885 let event_competing = array![0u8, 0u8];
3886 let sampleweight = array![1.0, 1.0];
3887 let x_entry = array![[0.1, 0.0, 0.2], [0.2, 0.1, 0.2]];
3888 let x_exit = array![[0.4, 0.2, 0.3], [0.6, 0.1, 0.3]];
3889 let x_derivative = array![[1.0, 0.0, 0.0], [0.5, 1.0, 0.0]];
3891 let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
3892 let mode = array![4.0, 2.0, 0.0];
3893 let hessian = Array2::<f64>::eye(3);
3894 let z = array![0.05, -0.1, 0.15];
3895
3896 let posterior = super::survival_hmc::SurvivalPosterior::new(
3897 age_entry.view(),
3898 age_exit.view(),
3899 event_target.view(),
3900 event_competing.view(),
3901 sampleweight.view(),
3902 x_entry.view(),
3903 x_exit.view(),
3904 x_derivative.view(),
3905 None,
3906 None,
3907 None,
3908 PenaltyBlocks::new(Vec::new()),
3909 monotonicity,
3910 SurvivalSpec::Net,
3911 true,
3912 2,
3913 mode.view(),
3914 hessian.view(),
3915 )
3916 .expect("construct structural posterior");
3917
3918 let mut grad = Array1::<f64>::zeros(3);
3919 let logp = HamiltonianTarget::logp_and_grad(&posterior, &z, &mut grad);
3920 assert!(logp.is_finite());
3921 assert!(grad.iter().all(|v| v.is_finite()));
3922 }
3923}
3924
3925impl HamiltonianTarget<Array1<f64>> for NutsPosterior {
3927 fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
3928 NUTS_RESIDUAL_SCRATCH.with(|scratch| {
3929 let mut residual = scratch.borrow_mut();
3930 if residual.len() != self.data.n_samples {
3931 *residual = Array1::<f64>::zeros(self.data.n_samples);
3932 }
3933 self.compute_logp_and_grad_nd_into(position, &mut residual, grad)
3934 })
3935 }
3936}
3937
/// Sampler configuration shared by the NUTS and Pólya–Gamma Gibbs routines in this module.
///
/// Validated by `validate_nuts_config` before any sampling starts.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NutsConfig {
    /// Number of retained (post-warmup) draws per chain.
    pub n_samples: usize,
    /// Number of warmup iterations per chain that are discarded. For NUTS this also
    /// drives step-size / mass-matrix adaptation.
    pub nwarmup: usize,
    /// Number of independent chains; must be at least `MIN_NUTS_CHAINS`.
    pub n_chains: usize,
    /// Requested NUTS target acceptance probability, which must lie in the open interval
    /// (0, 1). The effective value is clamped by `robust_target_accept`.
    pub target_accept: f64,
    /// Base RNG seed from which every per-chain / per-stream seed is derived.
    /// Defaults to `default_nuts_seed()` (42) when absent during deserialization.
    #[serde(default = "default_nuts_seed")]
    pub seed: u64,
}
3953
/// Serde default for `NutsConfig::seed`; matches the seed used by `NutsConfig::default`.
fn default_nuts_seed() -> u64 {
    const DEFAULT_SEED: u64 = 42;
    DEFAULT_SEED
}
3957
3958fn validate_nuts_target_accept(target_accept: f64) -> Result<(), HmcError> {
3959 if target_accept.is_finite() && target_accept > 0.0 && target_accept < 1.0 {
3960 Ok(())
3961 } else {
3962 Err(HmcError::InvalidConfig {
3963 reason: format!(
3964 "NUTS target_accept must be finite and lie in (0, 1), got {target_accept}"
3965 ),
3966 })
3967 }
3968}
3969
/// Minimum retained draws per chain accepted by `validate_nuts_draws`; the error message
/// there ties this floor to split-R-hat / ESS diagnostics being defined.
const MIN_NUTS_SAMPLES: usize = 4;

/// Minimum number of chains accepted by `validate_nuts_draws`; with zero chains there are
/// no initial positions to run.
const MIN_NUTS_CHAINS: usize = 1;
3988
3989fn validate_nuts_draws(config: &NutsConfig) -> Result<(), HmcError> {
3994 if config.n_chains < MIN_NUTS_CHAINS {
3995 return Err(HmcError::InvalidConfig {
3996 reason: format!(
3997 "NUTS n_chains must be >= {MIN_NUTS_CHAINS}; with zero chains the \
3998 sampler has no initial positions to run, got {}",
3999 config.n_chains
4000 ),
4001 });
4002 }
4003 if config.n_samples < MIN_NUTS_SAMPLES {
4004 return Err(HmcError::InvalidConfig {
4005 reason: format!(
4006 "NUTS n_samples must be >= {MIN_NUTS_SAMPLES} so split-R-hat / ESS \
4007 diagnostics are defined, got {}",
4008 config.n_samples
4009 ),
4010 });
4011 }
4012 Ok(())
4013}
4014
4015pub(crate) fn validate_nuts_config(config: &NutsConfig) -> Result<(), HmcError> {
4019 validate_nuts_target_accept(config.target_accept)?;
4020 validate_nuts_draws(config)?;
4021 Ok(())
4022}
4023
4024#[inline]
4025fn splitmix64(x: u64) -> u64 {
4026 gam_linalg::utils::splitmix64_hash(x)
4027}
4028
4029#[inline]
4030fn chain_stream_seed(seed: u64, chain: usize, stream: u64) -> u64 {
4031 splitmix64(seed ^ stream ^ ((chain as u64).wrapping_mul(0xD1B5_4A32_D192_ED03)))
4032}
4033
4034#[inline]
4035fn nuts_transition_seed(seed: u64, stream: u64) -> u64 {
4036 splitmix64(seed ^ stream ^ 0xA24B_AED4_963E_E407)
4037}
4038
4039#[inline]
4040fn gibbs_pg_seed(seed: u64, chain: usize, stream: u64, iter: usize) -> u64 {
4041 chain_stream_seed(
4042 seed,
4043 chain,
4044 stream ^ ((iter as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15)),
4045 )
4046}
4047
4048fn draw_logit_pg1_omega(
4049 shapes: ArrayView1<'_, u32>,
4050 tilts: ArrayView1<'_, f64>,
4051 seed: u64,
4052 out: &mut Array1<f64>,
4053) -> Result<(), String> {
4054 if out.len() != tilts.len() {
4055 return Err(HmcError::DimensionMismatch {
4056 reason: "draw_logit_pg1_omega: output length mismatch".to_string(),
4057 }
4058 .into());
4059 }
4060 let draws = crate::gpu_polya_gamma::draw_batch(PolyaGammaBatchInput {
4061 shapes,
4062 tilts,
4063 seed: PgSeed(seed),
4064 })?;
4065 out.assign(&draws);
4066 out.mapv_inplace(|v| v.max(1.0e-12));
4067 Ok(())
4068}
4069
/// Parameter dimension above which the "high-dimensional" tuning applies, i.e. a stricter
/// acceptance floor and stronger mass-matrix regularization. Uses in this file compare
/// `dim > HIGH_DIM_THRESHOLD`.
const HIGH_DIM_THRESHOLD: usize = 50;

/// Lower bound on the effective `target_accept` when `dim > HIGH_DIM_THRESHOLD`
/// (see `robust_target_accept`).
const HIGH_DIM_TARGET_ACCEPT_FLOOR: f64 = 0.92;
/// Lower bound on the effective `target_accept` when `dim <= HIGH_DIM_THRESHOLD`.
const LOW_DIM_TARGET_ACCEPT_FLOOR: f64 = 0.90;
/// Upper cap on the effective `target_accept`, regardless of dimension.
const MAX_TARGET_ACCEPT: f64 = 0.95;

/// Warmup length below which mass-matrix adaptation is disabled entirely
/// (`NUTSMassMatrixConfig::disabled()`).
const MIN_WARMUP_FOR_MASS_ADAPT: usize = 80;

/// Value passed as `dense_max_dim` to every mass-matrix config built here. Those configs
/// request `MassMatrixAdaptation::Diagonal`, so this limit presumably only matters if
/// dense adaptation is selected. NOTE(review): confirm against `general_mcmc`.
const DENSE_MASS_MATRIX_MAX_DIM: usize = 75;

/// `regularize` strength for the generic NUTS mass-matrix config when
/// `dim > HIGH_DIM_THRESHOLD`.
const MASS_REGULARIZE_HIGH_DIM: f64 = 0.14;
/// `regularize` strength for the generic NUTS mass-matrix config otherwise.
const MASS_REGULARIZE_LOW_DIM: f64 = 0.10;
/// Survival-model counterpart of `MASS_REGULARIZE_HIGH_DIM` (slightly stronger).
const SURVIVAL_MASS_REGULARIZE_HIGH_DIM: f64 = 0.18;
/// Survival-model counterpart of `MASS_REGULARIZE_LOW_DIM` (slightly stronger).
const SURVIVAL_MASS_REGULARIZE_LOW_DIM: f64 = 0.12;

/// `jitter` value shared by every mass-matrix config built in this file.
const MASS_MATRIX_JITTER: f64 = 1e-5;
4111
4112#[inline]
4113fn robust_target_accept(requested: f64, dim: usize) -> f64 {
4114 let floor = if dim > HIGH_DIM_THRESHOLD {
4115 HIGH_DIM_TARGET_ACCEPT_FLOOR
4116 } else {
4117 LOW_DIM_TARGET_ACCEPT_FLOOR
4118 };
4119 requested.max(floor).min(MAX_TARGET_ACCEPT)
4120}
4121
4122fn jittered_initial_positions(
4123 config: &NutsConfig,
4124 dim: usize,
4125 scale: f64,
4126 stream: u64,
4127) -> Vec<Array1<f64>> {
4128 (0..config.n_chains)
4129 .map(|chain| {
4130 let mut rng = StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, stream));
4131 Array1::from_shape_fn(dim, |_| sample_standard_normal(&mut rng) * scale)
4132 })
4133 .collect()
4134}
4135
4136fn robust_mass_matrix_config(dim: usize, nwarmup: usize) -> NUTSMassMatrixConfig {
4137 if nwarmup < MIN_WARMUP_FOR_MASS_ADAPT {
4138 return NUTSMassMatrixConfig::disabled();
4139 }
4140 let start_buffer = (nwarmup / 8).clamp(35, 180);
4141 let end_buffer = (nwarmup / 5).clamp(50, 250);
4142 let initial_window = (nwarmup / 20).clamp(10, 60);
4143 NUTSMassMatrixConfig {
4144 adaptation: MassMatrixAdaptation::Diagonal,
4145 start_buffer,
4146 end_buffer,
4147 initial_window,
4148 regularize: if dim > HIGH_DIM_THRESHOLD {
4149 MASS_REGULARIZE_HIGH_DIM
4150 } else {
4151 MASS_REGULARIZE_LOW_DIM
4152 },
4153 jitter: MASS_MATRIX_JITTER,
4154 dense_max_dim: DENSE_MASS_MATRIX_MAX_DIM,
4155 }
4156}
4157
4158fn robust_survival_mass_matrix_config(dim: usize, nwarmup: usize) -> NUTSMassMatrixConfig {
4159 if nwarmup < MIN_WARMUP_FOR_MASS_ADAPT {
4160 return NUTSMassMatrixConfig::disabled();
4161 }
4162 let start_buffer = (nwarmup / 7).clamp(40, 200);
4165 let end_buffer = (nwarmup / 4).clamp(60, 280);
4166 let initial_window = (nwarmup / 20).clamp(10, 60);
4167 NUTSMassMatrixConfig {
4168 adaptation: MassMatrixAdaptation::Diagonal,
4169 start_buffer,
4170 end_buffer,
4171 initial_window,
4172 regularize: if dim > HIGH_DIM_THRESHOLD {
4173 SURVIVAL_MASS_REGULARIZE_HIGH_DIM
4174 } else {
4175 SURVIVAL_MASS_REGULARIZE_LOW_DIM
4176 },
4177 jitter: MASS_MATRIX_JITTER,
4178 dense_max_dim: DENSE_MASS_MATRIX_MAX_DIM,
4179 }
4180}
4181
4182impl Default for NutsConfig {
4183 fn default() -> Self {
4184 Self {
4185 n_samples: 1000,
4186 nwarmup: 500,
4187 n_chains: 4,
4188 target_accept: 0.9,
4189 seed: 42,
4190 }
4191 }
4192}
4193
4194impl NutsConfig {
4195 pub fn for_dimension(n_params: usize) -> Self {
4203 let effective_autocorr = (n_params as f64).sqrt().max(1.0);
4205
4206 let target_ess = 100 * n_params;
4208
4209 let raw_samples = (target_ess as f64 * (1.0 + 2.0 * effective_autocorr) * 1.5) as usize;
4211
4212 let n_samples = raw_samples.clamp(500, 10_000);
4214
4215 let nwarmup = n_samples;
4217
4218 let n_chains = if n_params > 50 { 4 } else { 2 };
4220
4221 Self {
4222 n_samples,
4223 nwarmup,
4224 n_chains,
4225 target_accept: 0.9,
4226 seed: 42,
4227 }
4228 }
4229}
4230
/// Pooled output of a sampler run, in the original (un-whitened) coefficient space.
#[derive(Clone, Debug)]
pub struct NutsResult {
    /// All retained draws, one row per draw (`n_chains * n_samples` rows). Rows are ordered
    /// chain-major: every draw of chain 0, then chain 1, and so on.
    pub samples: Array2<f64>,
    /// Column-wise mean of `samples`.
    pub posterior_mean: Array1<f64>,
    /// Column-wise standard deviation of `samples` (`std_axis` with `ddof = 0`).
    pub posterior_std: Array1<f64>,
    /// Scalar R-hat convergence diagnostic reported by the producing routine. The NUTS
    /// paths here compute it with `compute_split_rhat_and_ess`.
    pub rhat: f64,
    /// Scalar effective-sample-size diagnostic paired with `rhat`.
    pub ess: f64,
    /// Whether `rhat` / `ess` met the producing routine's convergence thresholds.
    pub converged: bool,
}
4247
/// Pass/fail thresholds for the convergence flag stored in [`NutsResult::converged`].
#[derive(Clone, Copy)]
struct NutsConvergenceThresholds {
    /// R-hat must be strictly below this value.
    max_rhat: f64,
    /// If `Some`, ESS must be strictly above this value; `None` skips the ESS check.
    min_ess: Option<f64>,
}
4253
4254impl NutsConvergenceThresholds {
4255 #[inline]
4256 fn converged(self, rhat: f64, ess: f64) -> bool {
4257 let rhat_ok = rhat < self.max_rhat;
4258 match self.min_ess {
4259 Some(min_ess) => rhat_ok && ess > min_ess,
4260 None => rhat_ok,
4261 }
4262 }
4263}
4264
4265fn run_whitened_nuts_samples<Target>(
4266 target: Target,
4267 initial_positions: Vec<Array1<f64>>,
4268 config: &NutsConfig,
4269 dim: usize,
4270 mass_cfg: NUTSMassMatrixConfig,
4271 transition_seed_stream: u64,
4272 sampling_error_label: &str,
4273) -> Result<(Array3<f64>, String), String>
4274where
4275 Target: HamiltonianTarget<Array1<f64>> + Sync + Send,
4276{
4277 let mut sampler = GenericNUTS::new_with_mass_matrix(
4278 target,
4279 initial_positions,
4280 robust_target_accept(config.target_accept, dim),
4281 mass_cfg,
4282 )
4283 .set_seed(nuts_transition_seed(config.seed, transition_seed_stream));
4284
4285 let (samples_array, run_stats) = sampler
4286 .run_progress(config.n_samples, config.nwarmup)
4287 .map_err(|e| format!("{sampling_error_label}: {e}"))?;
4288 Ok((samples_array, run_stats.to_string()))
4289}
4290
4291fn unwhiten_samples(
4292 samples_array: &Array3<f64>,
4293 mode: &Array1<f64>,
4294 chol: &Array2<f64>,
4295 dim: usize,
4296 z_start: usize,
4297) -> Array2<f64> {
4298 let shape = samples_array.shape();
4299 let n_chains = shape[0];
4300 let n_samples_out = shape[1];
4301 let total_samples = n_chains * n_samples_out;
4302
4303 let mut samples = Array2::<f64>::zeros((total_samples, dim));
4304 let mut z_buffer = Array1::<f64>::zeros(dim);
4305 for chain in 0..n_chains {
4306 for sample_i in 0..n_samples_out {
4307 let zview = samples_array.slice(ndarray::s![chain, sample_i, z_start..z_start + dim]);
4308 z_buffer.assign(&zview);
4309 let beta = mode + &chol.dot(&z_buffer);
4310 let sample_idx = chain * n_samples_out + sample_i;
4311 samples.row_mut(sample_idx).assign(&beta);
4312 }
4313 }
4314
4315 samples
4316}
4317
4318fn summarize_unwhitened_nuts_samples(
4319 samples: Array2<f64>,
4320 samples_array: &Array3<f64>,
4321 empty_mean: Array1<f64>,
4322 convergence: NutsConvergenceThresholds,
4323) -> NutsResult {
4324 let posterior_mean = samples.mean_axis(Axis(0)).unwrap_or(empty_mean);
4325 let posterior_std = samples.std_axis(Axis(0), 0.0);
4326 let (rhat, ess) = compute_split_rhat_and_ess(samples_array);
4327 let converged = convergence.converged(rhat, ess);
4328
4329 NutsResult {
4330 samples,
4331 posterior_mean,
4332 posterior_std,
4333 rhat,
4334 ess,
4335 converged,
4336 }
4337}
4338
4339fn run_whitened_nuts_result<Target>(
4340 target: Target,
4341 mode: &Array1<f64>,
4342 chol: &Array2<f64>,
4343 initial_positions: Vec<Array1<f64>>,
4344 config: &NutsConfig,
4345 dim: usize,
4346 mass_cfg: NUTSMassMatrixConfig,
4347 transition_seed_stream: u64,
4348 sampling_error_label: &str,
4349 empty_mean: Array1<f64>,
4350 convergence: NutsConvergenceThresholds,
4351) -> Result<(NutsResult, String), String>
4352where
4353 Target: HamiltonianTarget<Array1<f64>> + Sync + Send,
4354{
4355 let (samples_array, run_stats) = run_whitened_nuts_samples(
4356 target,
4357 initial_positions,
4358 config,
4359 dim,
4360 mass_cfg,
4361 transition_seed_stream,
4362 sampling_error_label,
4363 )?;
4364 let samples = unwhiten_samples(&samples_array, mode, chol, dim, 0);
4365 let result =
4366 summarize_unwhitened_nuts_samples(samples, &samples_array, empty_mean, convergence);
4367 Ok((result, run_stats))
4368}
4369
4370impl NutsResult {
4371 pub fn posterior_mean_of<F>(&self, f: F) -> f64
4374 where
4375 F: Fn(ArrayView1<f64>) -> f64 + Sync,
4376 {
4377 let n = self.samples.nrows();
4378 if n == 0 {
4379 return 0.0;
4380 }
4381 use rayon::iter::{IntoParallelIterator, ParallelIterator};
4384 let sum: f64 = (0..n).into_par_iter().map(|i| f(self.samples.row(i))).sum();
4385 sum / n as f64
4386 }
4387
4388 pub fn posterior_interval_of<F>(&self, f: F, lower_pct: f64, upper_pct: f64) -> (f64, f64)
4390 where
4391 F: Fn(ArrayView1<f64>) -> f64,
4392 {
4393 let n = self.samples.nrows();
4394 if n == 0 {
4395 return (0.0, 0.0);
4396 }
4397 let mut values: Vec<f64> = (0..n).map(|i| f(self.samples.row(i))).collect();
4398 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
4399
4400 let lower_idx = ((lower_pct / 100.0) * n as f64).floor() as usize;
4401 let upper_idx = ((upper_pct / 100.0) * n as f64).ceil() as usize;
4402
4403 (
4404 values[lower_idx.min(n.saturating_sub(1))],
4405 values[upper_idx.min(n.saturating_sub(1))],
4406 )
4407 }
4408}
4409
4410#[inline]
4411fn sample_standard_normal<R: rand::Rng + ?Sized>(rng: &mut R) -> f64 {
4412 let u1 = rng.random::<f64>().max(1e-16);
4413 let u2 = rng.random::<f64>();
4414 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
4415}
4416
4417pub fn run_logit_polya_gamma_gibbs(
4426 x: ArrayView2<f64>,
4427 y: ArrayView1<f64>,
4428 weights: ArrayView1<f64>,
4429 penalty_matrix: ArrayView2<f64>,
4430 mode: ArrayView1<f64>,
4431 config: &NutsConfig,
4432) -> Result<NutsResult, String> {
4433 let n = x.nrows();
4434 let p = x.ncols();
4435 if y.len() != n || weights.len() != n {
4436 return Err(HmcError::DimensionMismatch {
4437 reason: "run_logit_polya_gamma_gibbs: input length mismatch".to_string(),
4438 }
4439 .into());
4440 }
4441 if mode.len() != p || penalty_matrix.nrows() != p || penalty_matrix.ncols() != p {
4442 return Err(HmcError::DimensionMismatch {
4443 reason: "run_logit_polya_gamma_gibbs: coefficient/penalty dimension mismatch"
4444 .to_string(),
4445 }
4446 .into());
4447 }
4448 if !weights.iter().all(|w| (*w - 1.0).abs() <= 1e-10) {
4449 return Err(HmcError::InvalidConfig {
4450 reason: "run_logit_polya_gamma_gibbs requires unit weights (PG(1,·)); use NUTS for non-unit weights".to_string(),
4451 }
4452 .into());
4453 }
4454 validate_binary_responses("run_logit_polya_gamma_gibbs", &y, &weights).map_err(String::from)?;
4455 validate_nuts_config(config).map_err(String::from)?;
4462
4463 let n_iter = config.nwarmup + config.n_samples;
4464
4465 let kappa = y.mapv(|v| v - 0.5);
4467 let rhs_b = fast_atv(&x, &kappa);
4468
4469 let mut samples_array = Array3::<f64>::zeros((config.n_chains, config.n_samples, p));
4470 let mut eta = Array1::<f64>::zeros(n);
4471 let mut omega = Array1::<f64>::ones(n);
4472 let pg_shapes = Array1::<u32>::from_elem(n, 1);
4473 let mut xw = x.to_owned();
4474 let mut xt_omega_x = Array2::<f64>::zeros((p, p));
4475 let penalty = penalty_matrix.to_owned();
4476 let mut q = Array2::<f64>::zeros((p, p));
4477 let mut mean = Array1::<f64>::zeros(p);
4478 let mut z = Array1::<f64>::zeros(p);
4479 let mut noise = Array1::<f64>::zeros(p);
4480
4481 for chain in 0..config.n_chains {
4482 let mut init_rng =
4483 StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0xB3C4_5A1F_8E9D_7632));
4484 let mut draw_rng =
4485 StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0x17A9_26D5_4C1B_E083));
4486 let mut beta = mode.to_owned();
4487 for j in 0..p {
4489 beta[j] += 0.05 * sample_standard_normal(&mut init_rng);
4490 }
4491
4492 for iter in 0..n_iter {
4493 eta.assign(&gam_linalg::faer_ndarray::fast_av(&x, &beta));
4494 draw_logit_pg1_omega(
4495 pg_shapes.view(),
4496 eta.view(),
4497 gibbs_pg_seed(config.seed, chain, 0x4D94_DF4E_5D72_81AB, iter),
4498 &mut omega,
4499 )?;
4500
4501 ndarray::Zip::indexed(xw.rows_mut())
4504 .and(x.rows())
4505 .and(&omega)
4506 .par_for_each(|_idx, mut xw_row, x_row, omega_i| {
4507 let s = omega_i.sqrt();
4508 for j in 0..p {
4509 xw_row[j] = x_row[j] * s;
4510 }
4511 });
4512 fast_ata_into(&xw, &mut xt_omega_x);
4513
4514 q.assign(&penalty);
4515 q += &xt_omega_x;
4516
4517 let factor = q
4519 .cholesky(Side::Lower)
4520 .map_err(|e| format!("PG Gibbs failed to factor Q: {:?}", e))?;
4521 mean.assign(&factor.solvevec(&rhs_b));
4522
4523 for j in 0..p {
4524 z[j] = sample_standard_normal(&mut draw_rng);
4525 }
4526 let l = factor.lower_triangular();
4527 back_substitution_lower_transpose_guarded_into(&l, &z, &mut noise);
4528 beta.assign(&(&mean + &noise));
4529
4530 if iter >= config.nwarmup {
4531 let keep_idx = iter - config.nwarmup;
4532 samples_array
4533 .slice_mut(ndarray::s![chain, keep_idx, ..])
4534 .assign(&beta);
4535 }
4536 }
4537 }
4538
4539 let total_samples = config.n_chains * config.n_samples;
4540 let mut samples = Array2::<f64>::zeros((total_samples, p));
4541 for chain in 0..config.n_chains {
4542 for s in 0..config.n_samples {
4543 let idx = chain * config.n_samples + s;
4544 samples
4545 .row_mut(idx)
4546 .assign(&samples_array.slice(ndarray::s![chain, s, ..]));
4547 }
4548 }
4549
4550 let posterior_mean = samples
4551 .mean_axis(Axis(0))
4552 .unwrap_or_else(|| Array1::zeros(p));
4553 let posterior_std = samples.std_axis(Axis(0), 0.0);
4554 let (rhat, ess) = if config.n_chains >= 2 && config.n_samples >= 4 {
4555 compute_split_rhat_and_ess(&samples_array)
4556 } else {
4557 (1.0, (total_samples as f64) * 0.5)
4558 };
4559 let converged = rhat < 1.1 && ess > 100.0;
4560
4561 Ok(NutsResult {
4562 samples,
4563 posterior_mean,
4564 posterior_std,
4565 rhat,
4566 ess,
4567 converged,
4568 })
4569}
4570
4571pub fn estimate_logit_pg_rao_blackwell_terms(
4581 x: ArrayView2<f64>,
4582 y: ArrayView1<f64>,
4583 weights: ArrayView1<f64>,
4584 penalty_matrix: ArrayView2<f64>,
4585 mode: ArrayView1<f64>,
4586 penalty_roots: &[Array2<f64>],
4587 config: &NutsConfig,
4588) -> Result<Array1<f64>, String> {
4589 let n = x.nrows();
4590 let p = x.ncols();
4591 if y.len() != n || weights.len() != n {
4592 return Err(HmcError::DimensionMismatch {
4593 reason: "estimate_logit_pg_rao_blackwell_terms: input length mismatch".to_string(),
4594 }
4595 .into());
4596 }
4597 if mode.len() != p || penalty_matrix.nrows() != p || penalty_matrix.ncols() != p {
4598 return Err(HmcError::DimensionMismatch {
4599 reason: "estimate_logit_pg_rao_blackwell_terms: coefficient/penalty dimension mismatch"
4600 .to_string(),
4601 }
4602 .into());
4603 }
4604 if !weights.iter().all(|w| (*w - 1.0).abs() <= 1e-10) {
4605 return Err(HmcError::InvalidConfig {
4606 reason: "estimate_logit_pg_rao_blackwell_terms requires unit weights (PG(1,·))"
4607 .to_string(),
4608 }
4609 .into());
4610 }
4611 validate_binary_responses("estimate_logit_pg_rao_blackwell_terms", &y, &weights)
4612 .map_err(String::from)?;
4613 if penalty_roots.iter().any(|r| r.ncols() != p) {
4614 return Err(HmcError::DimensionMismatch {
4615 reason: "estimate_logit_pg_rao_blackwell_terms: root width mismatch".to_string(),
4616 }
4617 .into());
4618 }
4619 let penalty_roots_t: Vec<Array2<f64>> =
4622 penalty_roots.iter().map(|r| r.t().to_owned()).collect();
4623
4624 let n_iter = config.nwarmup + config.n_samples;
4625
4626 let kappa = y.mapv(|v| v - 0.5);
4629 let rhs_b = fast_atv(&x, &kappa);
4630
4631 let penalty = penalty_matrix.to_owned();
4632 let mut eta = Array1::<f64>::zeros(n);
4633 let mut omega = Array1::<f64>::ones(n);
4634 let pg_shapes = Array1::<u32>::from_elem(n, 1);
4635 let mut xw = x.to_owned();
4636 let mut xt_omega_x = Array2::<f64>::zeros((p, p));
4637 let mut q = Array2::<f64>::zeros((p, p));
4638 let mut mean = Array1::<f64>::zeros(p);
4639 let mut rb_sum = Array1::<f64>::zeros(penalty_roots.len());
4640 let mut z = Array1::<f64>::zeros(p);
4641 let mut noise = Array1::<f64>::zeros(p);
4642
4643 let mut kept = 0usize;
4644 for chain in 0..config.n_chains {
4645 let mut init_rng =
4646 StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0x28F0_7B65_1A4D_C93E));
4647 let mut draw_rng =
4648 StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0xC642_6E35_B5A9_1D80));
4649 let mut beta = mode.to_owned();
4650 for j in 0..p {
4651 beta[j] += 0.05 * sample_standard_normal(&mut init_rng);
4652 }
4653
4654 for iter in 0..n_iter {
4655 eta.assign(&gam_linalg::faer_ndarray::fast_av(&x, &beta));
4656 draw_logit_pg1_omega(
4657 pg_shapes.view(),
4658 eta.view(),
4659 gibbs_pg_seed(config.seed, chain, 0x83F1_56C9_A7E0_2D4B, iter),
4660 &mut omega,
4661 )?;
4662
4663 ndarray::Zip::from(xw.rows_mut())
4664 .and(x.rows())
4665 .and(&omega)
4666 .par_for_each(|mut xw_row, x_row, &omega_i| {
4667 let s = omega_i.sqrt();
4668 for j in 0..p {
4669 xw_row[j] = x_row[j] * s;
4670 }
4671 });
4672 fast_ata_into(&xw, &mut xt_omega_x);
4673
4674 q.assign(&penalty);
4677 q += &xt_omega_x;
4678
4679 let factor = q
4680 .cholesky(Side::Lower)
4681 .map_err(|e| format!("PG Rao-Blackwell failed to factor Q: {:?}", e))?;
4682 mean.assign(&factor.solvevec(&rhs_b));
4685
4686 for j in 0..p {
4688 z[j] = sample_standard_normal(&mut draw_rng);
4689 }
4690 let l = factor.lower_triangular();
4691 back_substitution_lower_transpose_guarded_into(&l, &z, &mut noise);
4692 beta.assign(&(&mean + &noise));
4693
4694 if iter < config.nwarmup {
4695 continue;
4696 }
4697 kept += 1;
4698
4699 for (k, r_k) in penalty_roots.iter().enumerate() {
4700 if r_k.nrows() == 0 {
4701 continue;
4702 }
4703
4704 let rmu = r_k.dot(&mean);
4706 let mu_quad = rmu.dot(&rmu);
4707
4708 let solved_mat = factor.solve_mat(&penalty_roots_t[k]); let solved_t = solved_mat.t();
4713 let mut trace_term = 0.0_f64;
4714 for (&a, &b) in r_k.iter().zip(solved_t.iter()) {
4715 trace_term += a * b;
4716 }
4717
4718 rb_sum[k] += trace_term + mu_quad;
4719 }
4720 }
4721 }
4722
4723 if kept == 0 {
4724 return Err(HmcError::SamplingFailed {
4725 reason: "estimate_logit_pg_rao_blackwell_terms: no retained samples".to_string(),
4726 }
4727 .into());
4728 }
4729 let out = rb_sum.mapv(|v| v / (kept as f64));
4730 if !out.iter().all(|v| v.is_finite()) {
4731 return Err(HmcError::NonFiniteState {
4732 reason: "estimate_logit_pg_rao_blackwell_terms: non-finite expectation".to_string(),
4733 }
4734 .into());
4735 }
4736 Ok(out)
4737}
4738
4739pub(crate) fn run_nuts_sampling(
4752 x: ArrayView2<f64>,
4753 y: ArrayView1<f64>,
4754 weights: ArrayView1<f64>,
4755 penalty_matrix: ArrayView2<f64>,
4756 mode: ArrayView1<f64>,
4757 hessian: ArrayView2<f64>,
4758 nuts_family: NutsFamily,
4759 gamma_shape: f64,
4760 dispersion: gam_solve::model_types::Dispersion,
4761 firth_bias_reduction: bool,
4762 offset: Option<ArrayView1<f64>>,
4763 config: &NutsConfig,
4764) -> Result<NutsResult, String> {
4765 validate_firth_support(nuts_family, firth_bias_reduction).map_err(String::from)?;
4766 validate_nuts_config(config).map_err(String::from)?;
4767 if nuts_family == NutsFamily::TweedieLog && !is_valid_tweedie_power(gamma_shape) {
4768 return Err(format!(
4769 "Tweedie variance power must be finite and strictly between 1 and 2; got {gamma_shape}"
4770 ));
4771 }
4772 let dim = mode.len();
4773
4774 let target = NutsPosterior::new(
4777 x,
4778 y,
4779 weights,
4780 penalty_matrix,
4781 mode,
4782 hessian,
4783 nuts_family,
4784 gamma_shape,
4785 dispersion,
4786 firth_bias_reduction,
4787 )?;
4788 let target = match offset {
4789 Some(offset) => target.with_offset(offset)?,
4790 None => target,
4791 };
4792
4793 let chol = target.chol().clone();
4795 let mode_arr = target.mode().clone();
4796
4797 let initial_positions = jittered_initial_positions(config, dim, 0.1, 0x0F65_83B2_BC71_4D9E);
4798 let mass_cfg = robust_mass_matrix_config(dim, config.nwarmup);
4799 let (result, run_stats) = run_whitened_nuts_result(
4800 target,
4801 &mode_arr,
4802 &chol,
4803 initial_positions,
4804 config,
4805 dim,
4806 mass_cfg,
4807 0xF1D3_C2B5_A697_804E,
4808 "NUTS sampling failed",
4809 Array1::zeros(dim),
4810 NutsConvergenceThresholds {
4811 max_rhat: 1.1,
4812 min_ess: Some(100.0),
4813 },
4814 )?;
4815 log::info!("NUTS sampling complete: {}", run_stats);
4816
4817 Ok(result)
4818}
4819
/// Standard-normal target in whitened coordinates: `log p(z) = -½‖z‖²` (up to a
/// constant), with gradient `-z`.
///
/// Used by `sample_gaussian_mode_posterior`, which maps draws back via `mode + chol · z`.
/// That presumably yields `N(mode, H⁻¹)` if `chol` is a factor of `H⁻¹`; confirm against
/// `hessian_whitening_transform`.
struct GaussianModeTarget;
4842
4843impl HamiltonianTarget<Array1<f64>> for GaussianModeTarget {
4844 #[inline]
4845 fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
4846 let mut quad = 0.0;
4851 for (g, &zi) in grad.iter_mut().zip(position.iter()) {
4852 *g = -zi;
4853 quad += zi * zi;
4854 }
4855 -0.5 * quad
4856 }
4857}
4858
4859pub fn sample_gaussian_mode_posterior(
4874 mode: ArrayView1<f64>,
4875 hessian: ArrayView2<f64>,
4876 config: &NutsConfig,
4877) -> Result<GaussianModePosterior, String> {
4878 validate_nuts_config(config).map_err(String::from)?;
4879 let dim = mode.len();
4880 if hessian.nrows() != dim || hessian.ncols() != dim {
4881 return Err(format!(
4882 "Gaussian-posterior fallback: hessian shape {:?} does not match mode dim {dim}",
4883 hessian.dim()
4884 ));
4885 }
4886 if dim == 0 {
4887 return Err("Gaussian-posterior fallback: zero-dimensional posterior".to_string());
4888 }
4889
4890 let mut h = hessian.to_owned();
4896 for i in 0..dim {
4897 for j in (i + 1)..dim {
4898 let avg = 0.5 * (h[[i, j]] + h[[j, i]]);
4899 h[[i, j]] = avg;
4900 h[[j, i]] = avg;
4901 }
4902 }
4903 let diag_scale = (0..dim).map(|i| h[[i, i]].abs()).fold(0.0_f64, f64::max);
4904 let jitter = (diag_scale * 1e-10).max(1e-12);
4905 for i in 0..dim {
4906 h[[i, i]] += jitter;
4907 }
4908
4909 let mode_owned = mode.to_owned();
4910 let whitening = hessian_whitening_transform(
4911 h.view(),
4912 dim,
4913 1.0,
4914 "Gaussian-posterior fallback Cholesky failed",
4915 )?;
4916 let chol = whitening.chol;
4917 let target = GaussianModeTarget;
4918 let initial_positions = jittered_initial_positions(config, dim, 0.1, 0x51A6_2C73_90E4_1DBF);
4919 let mass_cfg = robust_mass_matrix_config(dim, config.nwarmup);
4920 let (result, run_stats) = run_whitened_nuts_result(
4921 target,
4922 &mode_owned,
4923 &chol,
4924 initial_positions,
4925 config,
4926 dim,
4927 mass_cfg,
4928 0x7C19_5A3E_82D6_44B1,
4929 "Gaussian-posterior fallback NUTS sampling failed",
4930 mode_owned.clone(),
4931 NutsConvergenceThresholds {
4932 max_rhat: 1.1,
4933 min_ess: None,
4934 },
4935 )?;
4936 log::info!(
4937 "never-fail Gaussian-posterior fallback: sampling complete dim={dim} {}",
4938 run_stats
4939 );
4940
4941 Ok(GaussianModePosterior {
4942 samples: result.samples,
4943 posterior_mean: result.posterior_mean,
4944 posterior_std: result.posterior_std,
4945 rhat: result.rhat,
4946 ess: result.ess,
4947 })
4948}
4949
/// Amount subtracted from the whitened log density when the rho criterion is infeasible,
/// non-finite, or returns a gradient of the wrong length.
///
/// The target then falls back to `-½‖z‖² - RHO_NUTS_INFEASIBLE_LOGP_PENALTY` with
/// gradient `-z`. This strongly disfavours such states while keeping the returned value
/// and gradient finite.
const RHO_NUTS_INFEASIBLE_LOGP_PENALTY: f64 = 1.0e8;
4956
/// HMC target over whitened smoothing parameters `z`, with `rho = mode + chol · z`.
///
/// The log density is `-(cost(rho) - cost_hat)`, where `cost` is the user criterion.
struct WhitenedRhoCriterionTarget<F> {
    /// User criterion returning `Some((cost, d cost / d rho))`, or `None` when infeasible.
    /// It is locked for the duration of every evaluation. The mutex presumably lets an
    /// `FnMut` closure satisfy the `Sync` bound the sampler requires.
    criterion_and_grad: Mutex<F>,
    /// Centre of the whitening transform in rho space (`rho_hat`).
    mode: Array1<f64>,
    /// Whitening factor mapping `z` to a rho displacement.
    chol: Array2<f64>,
    /// Transpose of `chol`, used to pull the rho-space gradient back to `z`.
    chol_t: Array2<f64>,
    /// Criterion value at `mode`. It is subtracted so the log density is 0 at `z = 0`.
    cost_hat: f64,
}
4984
4985impl<F> HamiltonianTarget<Array1<f64>> for WhitenedRhoCriterionTarget<F>
4986where
4987 F: FnMut(&Array1<f64>) -> Option<(f64, Array1<f64>)> + Send,
4988{
4989 fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
4990 let rho = &self.mode + &self.chol.dot(position);
4991 let eval = {
4992 let mut criterion = self
4993 .criterion_and_grad
4994 .lock()
4995 .expect("rho-criterion mutex poisoned");
4996 (*criterion)(&rho)
4997 };
4998 match eval {
4999 Some((cost, g))
5000 if cost.is_finite()
5001 && g.len() == position.len()
5002 && g.iter().all(|v| v.is_finite()) =>
5003 {
5004 let grad_z = self.chol_t.dot(&g);
5005 for (gi, &v) in grad.iter_mut().zip(grad_z.iter()) {
5006 *gi = -v;
5007 }
5008 -(cost - self.cost_hat)
5009 }
5010 _ => {
5011 let mut quad = 0.0;
5013 for (gi, &zi) in grad.iter_mut().zip(position.iter()) {
5014 *gi = -zi;
5015 quad += zi * zi;
5016 }
5017 -0.5 * quad - RHO_NUTS_INFEASIBLE_LOGP_PENALTY
5018 }
5019 }
5020 }
5021}
5022
5023pub fn run_rho_criterion_nuts<F>(
5038 rho_hat: ArrayView1<f64>,
5039 outer_hessian: ArrayView2<f64>,
5040 mut criterion_and_grad: F,
5041 config: &NutsConfig,
5042) -> Result<NutsResult, String>
5043where
5044 F: FnMut(&Array1<f64>) -> Option<(f64, Array1<f64>)> + Send,
5045{
5046 validate_nuts_config(config).map_err(String::from)?;
5047 let dim = rho_hat.len();
5048 if dim == 0 {
5049 return Err("rho-posterior NUTS: zero-dimensional rho".to_string());
5050 }
5051 if outer_hessian.nrows() != dim || outer_hessian.ncols() != dim {
5052 return Err(format!(
5053 "rho-posterior NUTS: outer Hessian shape {:?} does not match rho dim {dim}",
5054 outer_hessian.dim()
5055 ));
5056 }
5057
5058 let mut h = outer_hessian.to_owned();
5061 for i in 0..dim {
5062 for j in (i + 1)..dim {
5063 let avg = 0.5 * (h[[i, j]] + h[[j, i]]);
5064 h[[i, j]] = avg;
5065 h[[j, i]] = avg;
5066 }
5067 }
5068 let diag_scale = (0..dim).map(|i| h[[i, i]].abs()).fold(0.0_f64, f64::max);
5069 let jitter = (diag_scale * 1e-10).max(1e-12);
5070 for i in 0..dim {
5071 h[[i, i]] += jitter;
5072 }
5073
5074 let mode = rho_hat.to_owned();
5075 let whitening = hessian_whitening_transform(
5076 h.view(),
5077 dim,
5078 1.0,
5079 "rho-posterior NUTS: outer-Hessian Cholesky failed",
5080 )?;
5081
5082 let cost_hat = match criterion_and_grad(&mode) {
5083 Some((cost, _)) if cost.is_finite() => cost,
5084 _ => {
5085 return Err(
5086 "rho-posterior NUTS: criterion is infeasible at rho_hat itself".to_string(),
5087 );
5088 }
5089 };
5090
5091 let chol = whitening.chol;
5092 let target = WhitenedRhoCriterionTarget {
5093 criterion_and_grad: Mutex::new(criterion_and_grad),
5094 mode: mode.clone(),
5095 chol: chol.clone(),
5096 chol_t: whitening.chol_t,
5097 cost_hat,
5098 };
5099 let initial_positions = jittered_initial_positions(config, dim, 0.1, 0x3D8A_91C4_E27B_5F60);
5100 let mass_cfg = NUTSMassMatrixConfig::disabled();
5105 let (result, run_stats) = run_whitened_nuts_result(
5106 target,
5107 &mode,
5108 &chol,
5109 initial_positions,
5110 config,
5111 dim,
5112 mass_cfg,
5113 0x6B42_E9A1_05D7_C83F,
5114 "rho-posterior NUTS sampling failed",
5115 mode.clone(),
5116 NutsConvergenceThresholds {
5117 max_rhat: 1.1,
5118 min_ess: None,
5119 },
5120 )?;
5121 log::info!("rho-posterior NUTS (#938 tier 2): sampling complete dim={dim} {run_stats}");
5122 Ok(result)
5123}
5124
/// Borrowed, flattened inputs for NUTS sampling of a GLM posterior.
///
/// Consumed by [`run_nuts_sampling_flattened_family`] via
/// [`FamilyNutsInputs::Glm`].
pub struct GlmFlatInputs<'a> {
    /// Design matrix.
    pub x: ArrayView2<'a, f64>,
    /// Response values.
    pub y: ArrayView1<'a, f64>,
    /// Per-observation weights. All-ones weights (within `1e-10`) make a
    /// non-Firth, offset-free logit fit eligible for Pólya–Gamma Gibbs.
    pub weights: ArrayView1<'a, f64>,
    /// Coefficient penalty matrix.
    pub penalty_matrix: ArrayView2<'a, f64>,
    /// Coefficient vector the sampler is centred on (presumably the fitted
    /// posterior mode).
    pub mode: ArrayView1<'a, f64>,
    /// Hessian forwarded to the sampler. Presumably the penalised Hessian at
    /// `mode`, used for whitening.
    pub hessian: ArrayView2<'a, f64>,
    /// Gamma shape parameter; `None` falls back to `1.0` for the Gamma family.
    pub gamma_shape: Option<f64>,
    /// Dispersion forwarded to the sampler.
    pub dispersion: gam_solve::model_types::Dispersion,
    /// Request the Firth bias-reduction term. Rejected up front for
    /// likelihoods that do not support Firth.
    pub firth_bias_reduction: bool,
    /// Optional linear-predictor offset.
    pub offset: Option<ArrayView1<'a, f64>>,
}
5148
/// Borrowed per-observation arrays for Royston–Parmar survival NUTS.
///
/// Every field is forwarded unchanged to
/// `survival_hmc::run_survival_nuts_sampling`.
pub struct SurvivalFlatInputs<'a> {
    /// Entry ages (left-truncation times).
    pub age_entry: ArrayView1<'a, f64>,
    /// Exit ages (event or censoring times).
    pub age_exit: ArrayView1<'a, f64>,
    /// Indicator for the target event.
    pub event_target: ArrayView1<'a, u8>,
    /// Indicator for a competing event.
    pub event_competing: ArrayView1<'a, u8>,
    /// Per-observation weights.
    pub weights: ArrayView1<'a, f64>,
    /// Design evaluated at entry ages.
    pub x_entry: ArrayView2<'a, f64>,
    /// Design evaluated at exit ages.
    pub x_exit: ArrayView2<'a, f64>,
    /// Time-derivative design (presumably evaluated at exit ages).
    pub x_derivative: ArrayView2<'a, f64>,
    /// Optional linear-predictor offset at entry.
    pub eta_offset_entry: Option<ArrayView1<'a, f64>>,
    /// Optional linear-predictor offset at exit.
    pub eta_offset_exit: Option<ArrayView1<'a, f64>>,
    /// Optional offset for the time-derivative term at exit.
    pub derivative_offset_exit: Option<ArrayView1<'a, f64>>,
}
5163
/// Complete inputs for Royston–Parmar survival NUTS sampling.
///
/// All fields are forwarded to `survival_hmc::run_survival_nuts_sampling`.
pub struct SurvivalNutsInputs<'a> {
    /// Per-observation arrays.
    pub flat: SurvivalFlatInputs<'a>,
    /// Coefficient penalty blocks.
    pub penalties: gam_models::survival::PenaltyBlocks,
    /// Monotonicity penalty configuration.
    pub monotonicity: gam_models::survival::SurvivalMonotonicityPenalty,
    /// Survival model specification.
    pub spec: gam_models::survival::SurvivalSpec,
    /// Whether monotonicity is guaranteed structurally by the basis.
    pub structurally_monotonic: bool,
    /// Number of time columns covered by the structural monotonicity.
    pub structural_time_columns: usize,
    /// Coefficient vector the sampler is centred on (presumably the mode).
    pub mode: ArrayView1<'a, f64>,
    /// Hessian forwarded to the sampler, presumably used for whitening.
    pub hessian: ArrayView2<'a, f64>,
}
5175
/// Family-specific flattened inputs for [`run_nuts_sampling_flattened_family`].
pub enum FamilyNutsInputs<'a> {
    /// GLM-style inputs: every response family except Royston–Parmar.
    Glm(GlmFlatInputs<'a>),
    /// Royston–Parmar survival inputs. Boxed because the payload is large.
    Survival(Box<SurvivalNutsInputs<'a>>),
}
5181
5182pub fn explicit_fit_hessian_for_whitening<'a>(
5191 fit: &'a UnifiedFitResult,
5192 expected_dim: usize,
5193 label: &str,
5194) -> Result<&'a Array2<f64>, String> {
5195 let hessian = fit.penalized_hessian().ok_or_else(|| {
5196 format!(
5197 "{label}: fit result is missing an explicit penalized Hessian for HMC/NUTS whitening"
5198 )
5199 })?;
5200 validate_explicit_dense_hessian_for_whitening(
5201 &format!("{label} penalized Hessian"),
5202 hessian,
5203 expected_dim,
5204 )
5205 .map_err(|err| err.to_string())?;
5206 Ok(hessian)
5207}
5208
5209pub fn run_nuts_sampling_flattened_family(
5211 likelihood: LikelihoodSpec,
5212 inputs: FamilyNutsInputs<'_>,
5213 config: &NutsConfig,
5214) -> Result<NutsResult, String> {
5215 if let FamilyNutsInputs::Glm(glm) = &inputs
5216 && glm.firth_bias_reduction
5217 && !likelihood_spec_supports_firth(&likelihood)
5218 {
5219 return Err(HmcError::FirthUnsupported {
5220 reason: format!(
5221 "NUTS with Firth requires a Binomial inverse link with a Fisher-weight jet; {} does not support it",
5222 likelihood.pretty_name()
5223 ),
5224 }
5225 .into());
5226 }
5227
5228 match (likelihood.response.clone(), likelihood.link.clone(), inputs) {
5229 (
5230 ResponseFamily::Gaussian,
5231 InverseLink::Standard(StandardLink::Identity),
5232 FamilyNutsInputs::Glm(glm),
5233 ) => run_nuts_sampling(
5234 glm.x,
5235 glm.y,
5236 glm.weights,
5237 glm.penalty_matrix,
5238 glm.mode,
5239 glm.hessian,
5240 NutsFamily::Gaussian,
5241 1.0,
5242 glm.dispersion,
5243 glm.firth_bias_reduction,
5244 glm.offset,
5245 config,
5246 ),
5247 (
5248 ResponseFamily::Binomial,
5249 InverseLink::Standard(StandardLink::Logit),
5250 FamilyNutsInputs::Glm(glm),
5251 ) => {
5252 if !glm.firth_bias_reduction
5259 && glm.offset.is_none()
5260 && glm.weights.iter().all(|w| (*w - 1.0).abs() <= 1e-10)
5261 {
5262 run_logit_polya_gamma_gibbs(
5263 glm.x,
5264 glm.y,
5265 glm.weights,
5266 glm.penalty_matrix,
5267 glm.mode,
5268 config,
5269 )
5270 } else {
5271 run_nuts_sampling(
5272 glm.x,
5273 glm.y,
5274 glm.weights,
5275 glm.penalty_matrix,
5276 glm.mode,
5277 glm.hessian,
5278 NutsFamily::BinomialLogit,
5279 1.0,
5280 glm.dispersion,
5281 glm.firth_bias_reduction,
5282 glm.offset,
5283 config,
5284 )
5285 }
5286 }
5287 (
5288 ResponseFamily::Binomial,
5289 InverseLink::Standard(StandardLink::Probit),
5290 FamilyNutsInputs::Glm(glm),
5291 ) => run_nuts_sampling(
5292 glm.x,
5293 glm.y,
5294 glm.weights,
5295 glm.penalty_matrix,
5296 glm.mode,
5297 glm.hessian,
5298 NutsFamily::BinomialProbit,
5299 1.0,
5300 glm.dispersion,
5301 glm.firth_bias_reduction,
5302 glm.offset,
5303 config,
5304 ),
5305 (
5306 ResponseFamily::Binomial,
5307 InverseLink::Standard(StandardLink::CLogLog),
5308 FamilyNutsInputs::Glm(glm),
5309 ) => run_nuts_sampling(
5310 glm.x,
5311 glm.y,
5312 glm.weights,
5313 glm.penalty_matrix,
5314 glm.mode,
5315 glm.hessian,
5316 NutsFamily::BinomialCLogLog,
5317 1.0,
5318 glm.dispersion,
5319 glm.firth_bias_reduction,
5320 glm.offset,
5321 config,
5322 ),
5323 (
5324 ResponseFamily::Binomial,
5325 InverseLink::LatentCLogLog(_),
5326 FamilyNutsInputs::Glm(glm),
5327 ) => run_nuts_sampling(
5328 glm.x,
5329 glm.y,
5330 glm.weights,
5331 glm.penalty_matrix,
5332 glm.mode,
5333 glm.hessian,
5334 NutsFamily::BinomialCLogLog,
5335 1.0,
5336 glm.dispersion,
5337 glm.firth_bias_reduction,
5338 glm.offset,
5339 config,
5340 ),
5341 (ResponseFamily::Binomial, InverseLink::Mixture(_), FamilyNutsInputs::Glm(_)) => Err(
5342 "BinomialMixture NUTS is not implemented yet; use fit_gam/predict_gam for blended inverse-link models"
5343 .to_string(),
5344 ),
5345 (ResponseFamily::Binomial, InverseLink::Sas(_), FamilyNutsInputs::Glm(_)) => Err(
5346 "BinomialSas NUTS is not implemented yet; use fit_gam/predict_gam for SAS-link models"
5347 .to_string(),
5348 ),
5349 (ResponseFamily::Binomial, InverseLink::BetaLogistic(_), FamilyNutsInputs::Glm(_)) => Err(
5350 "BinomialBetaLogistic NUTS is not implemented yet; use fit_gam/predict_gam for beta-logistic-link models"
5351 .to_string(),
5352 ),
5353 (ResponseFamily::Binomial, InverseLink::Standard(_), FamilyNutsInputs::Glm(_)) => Err(
5354 "NUTS sampling is not implemented for this binomial inverse link".to_string(),
5355 ),
5356 (ResponseFamily::RoystonParmar, _, FamilyNutsInputs::Survival(survival)) => {
5357 survival_hmc::run_survival_nuts_sampling(
5358 survival.flat.age_entry,
5359 survival.flat.age_exit,
5360 survival.flat.event_target,
5361 survival.flat.event_competing,
5362 survival.flat.weights,
5363 survival.flat.x_entry,
5364 survival.flat.x_exit,
5365 survival.flat.x_derivative,
5366 survival.flat.eta_offset_entry,
5367 survival.flat.eta_offset_exit,
5368 survival.flat.derivative_offset_exit,
5369 survival.penalties,
5370 survival.monotonicity,
5371 survival.spec,
5372 survival.structurally_monotonic,
5373 survival.structural_time_columns,
5374 survival.mode,
5375 survival.hessian,
5376 config,
5377 )
5378 }
5379 (ResponseFamily::RoystonParmar, _, FamilyNutsInputs::Glm(_)) => Err(
5380 "RoystonParmar family requires FamilyNutsInputs::Survival flattened inputs".to_string(),
5381 ),
5382 (_, _, FamilyNutsInputs::Survival(_)) => Err(
5383 "Survival flattened inputs are only valid for the Royston-Parmar response family"
5384 .to_string(),
5385 ),
5386 (ResponseFamily::Poisson, _, FamilyNutsInputs::Glm(glm)) => run_nuts_sampling(
5387 glm.x,
5388 glm.y,
5389 glm.weights,
5390 glm.penalty_matrix,
5391 glm.mode,
5392 glm.hessian,
5393 NutsFamily::PoissonLog,
5394 1.0,
5395 glm.dispersion,
5396 glm.firth_bias_reduction,
5397 glm.offset,
5398 config,
5399 ),
5400 (ResponseFamily::Tweedie { p }, _, FamilyNutsInputs::Glm(glm)) => {
5401 if !is_valid_tweedie_power(p) {
5404 return Err(format!(
5405 "Tweedie variance power must be finite and strictly between 1 and 2; got {p}"
5406 ));
5407 }
5408 run_nuts_sampling(
5409 glm.x,
5410 glm.y,
5411 glm.weights,
5412 glm.penalty_matrix,
5413 glm.mode,
5414 glm.hessian,
5415 NutsFamily::TweedieLog,
5416 p,
5417 glm.dispersion,
5418 glm.firth_bias_reduction,
5419 glm.offset,
5420 config,
5421 )
5422 }
5423 (ResponseFamily::NegativeBinomial { theta, .. }, _, FamilyNutsInputs::Glm(glm)) => {
5424 run_nuts_sampling(
5427 glm.x,
5428 glm.y,
5429 glm.weights,
5430 glm.penalty_matrix,
5431 glm.mode,
5432 glm.hessian,
5433 NutsFamily::NegativeBinomialLog,
5434 theta,
5435 glm.dispersion,
5436 glm.firth_bias_reduction,
5437 glm.offset,
5438 config,
5439 )
5440 }
5441 (ResponseFamily::Beta { .. }, _, FamilyNutsInputs::Glm(_)) => Err(
5442 "NUTS sampling is not implemented for beta-regression logit".to_string(),
5443 ),
5444 (ResponseFamily::Gamma, _, FamilyNutsInputs::Glm(glm)) => run_nuts_sampling(
5445 glm.x,
5446 glm.y,
5447 glm.weights,
5448 glm.penalty_matrix,
5449 glm.mode,
5450 glm.hessian,
5451 NutsFamily::GammaLog,
5452 glm.gamma_shape.unwrap_or(1.0),
5453 glm.dispersion,
5454 glm.firth_bias_reduction,
5455 glm.offset,
5456 config,
5457 ),
5458 (ResponseFamily::Gaussian, _, FamilyNutsInputs::Glm(_)) => Err(
5459 "NUTS sampling is only implemented for Gaussian with identity link".to_string(),
5460 ),
5461 }
5462}
5463
/// Spline description for the monotone link "wiggle" `B(z) θ`.
///
/// `z` is the linear predictor standardised to `[0, 1]` over `knot_range`.
#[derive(Clone)]
pub struct LinkWiggleSplineArtifacts {
    /// `(min, max)` of the linear predictor used to standardise it. Values
    /// outside this range are handled by linear extrapolation.
    pub knot_range: (f64, f64),
    /// Knot vector passed to `monotone_wiggle_basis_with_derivative_order`.
    pub knot_vector: Array1<f64>,
    /// Spline degree passed to `monotone_wiggle_basis_with_derivative_order`.
    pub degree: usize,
}
5493
/// Joint posterior over base coefficients `β` and link-wiggle coefficients
/// `θ`, where `η = u + B(z(u)) θ` and `u = X β`.
///
/// Sampled in whitened coordinates `z_w` with `(β, θ) = mode + chol · z_w`.
/// Large arrays are held in `Arc`s so clones are cheap.
#[derive(Clone)]
pub struct LinkWigglePosterior {
    /// Base design matrix (`n_samples × p_base`).
    x: Arc<Array2<f64>>,
    /// Responses, one per design row.
    y: Arc<Array1<f64>>,
    /// Per-observation weights.
    weights: Arc<Array1<f64>>,
    /// Penalty on `β`, scaled by `1 / cov_scale` in the log density.
    penalty_base: Arc<Array2<f64>>,
    /// Penalty on `θ`, scaled by `1 / cov_scale` in the log density.
    penalty_link: Arc<Array2<f64>>,
    /// Base-coefficient part of the whitening centre.
    mode_beta: Arc<Array1<f64>>,
    /// Link-coefficient part of the whitening centre.
    mode_theta: Arc<Array1<f64>>,
    /// Link spline definition.
    spline: LinkWiggleSplineArtifacts,
    /// Whitening factor mapping whitened positions to `(β, θ)` offsets.
    chol: Array2<f64>,
    /// Transpose of `chol` (from the whitening transform); pulls
    /// `(β, θ)`-gradients back to whitened coordinates.
    chol_t: Array2<f64>,
    /// Number of base coefficients (`x.ncols()`).
    p_base: usize,
    /// Number of link-wiggle coefficients.
    p_link: usize,
    /// Number of observations (`x.nrows()`).
    n_samples: usize,
    /// Likelihood family.
    nuts_family: NutsFamily,
    /// Family parameter:
    /// - Gaussian: standard deviation;
    /// - Tweedie: power `p`;
    /// - negative binomial: `θ`;
    /// - Gamma: shape;
    /// - unused for binomial and Poisson.
    scale: f64,
    /// Covariance scale used for whitening and penalty scaling: `scale²` for
    /// Gaussian, `1.0` otherwise.
    cov_scale: f64,
}
5525
5526impl LinkWigglePosterior {
5527 #[inline]
5529 fn standardized_z(&self, u: &Array1<f64>) -> (Array1<f64>, Array1<f64>, f64) {
5530 let (min_u, max_u) = self.spline.knot_range;
5531 let rw = (max_u - min_u).max(1e-6);
5532 let z_raw: Array1<f64> = u.mapv(|v| (v - min_u) / rw);
5533 let z_c: Array1<f64> = z_raw.mapv(|z| z.clamp(0.0, 1.0));
5534 (z_raw, z_c, rw)
5535 }
5536
5537 pub fn new(
5539 x: ArrayView2<f64>,
5540 y: ArrayView1<f64>,
5541 weights: ArrayView1<f64>,
5542 penalty_base: ArrayView2<f64>,
5543 penalty_link: ArrayView2<f64>,
5544 mode_beta: ArrayView1<f64>,
5545 mode_theta: ArrayView1<f64>,
5546 hessian: ArrayView2<f64>,
5547 spline: LinkWiggleSplineArtifacts,
5548 nuts_family: NutsFamily,
5549 scale: f64,
5550 ) -> Result<Self, String> {
5551 let n_samples = x.nrows();
5552 let p_base = x.ncols();
5553 let p_link = mode_theta.len();
5554 let dim = p_base + p_link;
5555 if hessian.nrows() != dim || hessian.ncols() != dim {
5556 return Err(HmcError::DimensionMismatch {
5557 reason: format!(
5558 "LinkWigglePosterior: Hessian dim mismatch: {}x{} vs expected {}x{}",
5559 hessian.nrows(),
5560 hessian.ncols(),
5561 dim,
5562 dim,
5563 ),
5564 }
5565 .into());
5566 }
5567 if nuts_family.likelihood_spec().is_binomial() {
5568 validate_binary_responses("binomial link-wiggle NUTS", &y, &weights)
5569 .map_err(String::from)?;
5570 }
5571 if matches!(nuts_family, NutsFamily::NegativeBinomialLog) {
5572 validate_count_responses("negative-binomial link-wiggle NUTS", &y, &weights)
5573 .map_err(String::from)?;
5574 }
5575 let cov_scale = match nuts_family {
5584 NutsFamily::Gaussian => scale * scale,
5585 _ => 1.0,
5586 };
5587 let whitening = hessian_whitening_transform(
5588 hessian,
5589 dim,
5590 cov_scale,
5591 "LinkWigglePosterior Cholesky failed",
5592 )?;
5593 let chol = whitening.chol;
5594 let chol_t = whitening.chol_t;
5595 Ok(Self {
5596 x: Arc::new(x.to_owned()),
5597 y: Arc::new(y.to_owned()),
5598 weights: Arc::new(weights.to_owned()),
5599 penalty_base: Arc::new(penalty_base.to_owned()),
5600 penalty_link: Arc::new(penalty_link.to_owned()),
5601 mode_beta: Arc::new(mode_beta.to_owned()),
5602 mode_theta: Arc::new(mode_theta.to_owned()),
5603 spline,
5604 chol,
5605 chol_t,
5606 p_base,
5607 p_link,
5608 n_samples,
5609 nuts_family,
5610 scale,
5611 cov_scale,
5612 })
5613 }
5614
5615 fn evaluate_link(&self, u: &Array1<f64>, theta: &Array1<f64>) -> (Array2<f64>, Array1<f64>) {
5617 let n = u.len();
5618 if theta.is_empty() {
5619 return (Array2::zeros((n, 0)), u.clone());
5620 }
5621
5622 let (z_raw, z_c, _) = self.standardized_z(u);
5623 let Ok(mut basis) = monotone_wiggle_basis_with_derivative_order(
5624 z_c.view(),
5625 &self.spline.knot_vector,
5626 self.spline.degree,
5627 0,
5628 ) else {
5629 return (Array2::zeros((n, theta.len())), u.clone());
5630 };
5631 if basis.ncols() != theta.len() {
5632 return (Array2::zeros((n, theta.len())), u.clone());
5633 }
5634
5635 let mut needs_ext = false;
5638 for i in 0..n {
5639 if (z_raw[i] - z_c[i]).abs() > 1e-12 {
5640 needs_ext = true;
5641 break;
5642 }
5643 }
5644 if needs_ext
5645 && let Ok(b_prime) = monotone_wiggle_basis_with_derivative_order(
5646 z_c.view(),
5647 &self.spline.knot_vector,
5648 self.spline.degree,
5649 1,
5650 )
5651 {
5652 for i in 0..n {
5653 let dz = z_raw[i] - z_c[i];
5654 if dz.abs() <= 1e-12 {
5655 continue;
5656 }
5657 for j in 0..basis.ncols().min(b_prime.ncols()) {
5658 basis[[i, j]] += dz * b_prime[[i, j]];
5659 }
5660 }
5661 }
5662 (
5663 basis.clone(),
5664 u + &gam_linalg::faer_ndarray::fast_av(&basis, theta),
5665 )
5666 }
5667
5668 fn compute_g_prime(&self, u: &Array1<f64>, theta: &Array1<f64>) -> Array1<f64> {
5670 let n = u.len();
5671 let mut g = Array1::<f64>::ones(n);
5672 let (_, z_c, rw) = self.standardized_z(u);
5673 if theta.is_empty() {
5674 return g;
5675 }
5676
5677 let Ok(b_prime_constrained) = monotone_wiggle_basis_with_derivative_order(
5678 z_c.view(),
5679 &self.spline.knot_vector,
5680 self.spline.degree,
5681 1,
5682 ) else {
5683 return g;
5684 };
5685 if b_prime_constrained.ncols() != theta.len() {
5686 return g;
5687 }
5688 let dwiggle_dz = gam_linalg::faer_ndarray::fast_av(&b_prime_constrained, theta);
5689 ndarray::Zip::from(&mut g)
5690 .and(&dwiggle_dz)
5691 .par_for_each(|gi, &dw| *gi = 1.0 + dw / rw);
5692 g
5693 }
5694
5695 fn compute_logp_and_grad_into(&self, z: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
5696 let dim = self.p_base + self.p_link;
5697
5698 let mut mode = Array1::<f64>::zeros(dim);
5700 mode.slice_mut(ndarray::s![0..self.p_base])
5701 .assign(&self.mode_beta);
5702 mode.slice_mut(ndarray::s![self.p_base..])
5703 .assign(&self.mode_theta);
5704 let q = &mode + &self.chol.dot(z);
5705 let beta = q.slice(ndarray::s![0..self.p_base]).to_owned();
5706 let theta = q.slice(ndarray::s![self.p_base..]).to_owned();
5707
5708 let u = gam_linalg::faer_ndarray::fast_av(self.x.as_ref(), &beta);
5710 let (bwiggle, eta) = self.evaluate_link(&u, &theta);
5711
5712 let ll;
5714 let mut residual = Array1::<f64>::zeros(self.n_samples);
5715 match self.nuts_family {
5716 NutsFamily::Gaussian => {
5717 let inv_scale_sq = 1.0 / (self.scale * self.scale).max(1e-10);
5718 let mut ll_acc = 0.0;
5719 for i in 0..self.n_samples {
5720 let r = self.y[i] - eta[i];
5721 let w = self.weights[i];
5722 ll_acc -= 0.5 * w * r * r * inv_scale_sq;
5723 residual[i] = w * r * inv_scale_sq;
5724 }
5725 ll = ll_acc;
5726 }
5727 NutsFamily::BinomialLogit => {
5728 let mut ll_acc = 0.0;
5729 for i in 0..self.n_samples {
5730 let eta_i = eta[i];
5731 let (y_i, w_i) = (self.y[i], self.weights[i]);
5732 ll_acc += w_i * (y_i * eta_i - gam_linalg::utils::stable_softplus(eta_i));
5733 let mu = gam_linalg::utils::stable_logistic(eta_i);
5734 residual[i] = w_i * (y_i - mu);
5735 }
5736 ll = ll_acc;
5737 }
5738 NutsFamily::BinomialProbit => {
5739 let mut ll_acc = 0.0;
5740 for i in 0..self.n_samples {
5741 let eta_i = eta[i];
5742 let (y_i, w_i) = (self.y[i], self.weights[i]);
5743 let log_phi_pos = log_ndtr(eta_i);
5744 let log_phi_neg = log_ndtr(-eta_i);
5745 ll_acc += w_i * (y_i * log_phi_pos + (1.0 - y_i) * log_phi_neg);
5746 let log_phi = standard_normal_log_pdf(eta_i);
5747 let ratio_pos = (log_phi - log_phi_pos).exp();
5748 let ratio_neg = (log_phi - log_phi_neg).exp();
5749 residual[i] = w_i * (y_i * ratio_pos - (1.0 - y_i) * ratio_neg);
5750 }
5751 ll = ll_acc;
5752 }
5753 NutsFamily::BinomialCLogLog => {
5754 let mut ll_acc = 0.0;
5755 for i in 0..self.n_samples {
5756 let eta_i = eta[i];
5757 if !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)) {
5758 grad.fill(0.0);
5759 return f64::NEG_INFINITY;
5760 }
5761 let (y_i, w_i) = (self.y[i], self.weights[i]);
5762 let (ll_i, residual_i) = match cloglog_bernoulli_logp_and_residual(eta_i, y_i) {
5763 Ok(values) => values,
5764 Err(_) => {
5765 grad.fill(0.0);
5766 return f64::NEG_INFINITY;
5767 }
5768 };
5769 ll_acc += w_i * ll_i;
5770 residual[i] = w_i * residual_i;
5771 }
5772 ll = ll_acc;
5773 }
5774 NutsFamily::PoissonLog => {
5775 let mut ll_acc = 0.0;
5776 for i in 0..self.n_samples {
5777 let eta_i = eta[i];
5778 if !(eta_i.is_finite() && (-30.0..=30.0).contains(&eta_i)) {
5779 grad.fill(0.0);
5780 return f64::NEG_INFINITY;
5781 }
5782 let (y_i, w_i) = (self.y[i], self.weights[i]);
5783 let mu = eta_i.exp();
5784 ll_acc += w_i * (y_i * eta_i - mu);
5785 residual[i] = w_i * (y_i - mu);
5786 }
5787 ll = ll_acc;
5788 }
5789 NutsFamily::TweedieLog => {
5790 let mut ll_acc = 0.0;
5791 if !is_valid_tweedie_power(self.scale) {
5794 grad.fill(0.0);
5795 return f64::NEG_INFINITY;
5796 }
5797 let p = self.scale;
5798 for i in 0..self.n_samples {
5799 let eta_i = eta[i];
5800 if !(eta_i.is_finite() && (-30.0..=30.0).contains(&eta_i)) {
5801 grad.fill(0.0);
5802 return f64::NEG_INFINITY;
5803 }
5804 let (y_i, w_i) = (self.y[i], self.weights[i]);
5805 let mu = eta_i.exp().max(1e-300);
5806 ll_acc +=
5807 w_i * (y_i * mu.powf(1.0 - p) / (1.0 - p) - mu.powf(2.0 - p) / (2.0 - p));
5808 residual[i] = w_i * (y_i - mu) * mu.powf(1.0 - p);
5809 }
5810 ll = ll_acc;
5811 }
5812 NutsFamily::NegativeBinomialLog => {
5813 let mut ll_acc = 0.0;
5814 if !(self.scale.is_finite() && self.scale > 0.0) {
5817 grad.fill(0.0);
5818 return f64::NEG_INFINITY;
5819 }
5820 let theta = self.scale;
5821 for i in 0..self.n_samples {
5822 let eta_i = eta[i];
5823 if !(eta_i.is_finite() && (-30.0..=30.0).contains(&eta_i)) {
5824 grad.fill(0.0);
5825 return f64::NEG_INFINITY;
5826 }
5827 let (y_i, w_i) = (self.y[i], self.weights[i]);
5828 if w_i <= 0.0 {
5829 residual[i] = 0.0;
5830 continue;
5831 }
5832 let mu = eta_i.exp().max(1e-12);
5833 let log_mu_term = if y_i > 0.0 { y_i * mu.ln() } else { 0.0 };
5834 ll_acc += w_i
5835 * (statrs::function::gamma::ln_gamma(y_i + theta)
5836 - statrs::function::gamma::ln_gamma(theta)
5837 - statrs::function::gamma::ln_gamma(y_i + 1.0)
5838 + theta * (theta.ln() - (theta + mu).ln())
5839 + log_mu_term
5840 - y_i * (theta + mu).ln());
5841 residual[i] = w_i * theta * (y_i - mu) / (theta + mu);
5842 }
5843 ll = ll_acc;
5844 }
5845 NutsFamily::GammaLog => {
5846 let mut ll_acc = 0.0;
5847 let shape = self.scale.max(1e-10);
5848 for i in 0..self.n_samples {
5849 let eta_i = eta[i];
5850 if !(eta_i.is_finite() && (-30.0..=30.0).contains(&eta_i)) {
5851 grad.fill(0.0);
5852 return f64::NEG_INFINITY;
5853 }
5854 let (y_i, w_i) = (self.y[i], self.weights[i]);
5855 let mu = eta_i.exp();
5856 ll_acc += w_i * shape * (-y_i / mu - eta_i);
5857 residual[i] = w_i * shape * (y_i / mu - 1.0);
5858 }
5859 ll = ll_acc;
5860 }
5861 }
5862
5863 let penalty_scale = 1.0 / self.cov_scale.max(1e-300);
5876
5877 let s_link_theta = self.penalty_link.dot(&theta);
5879 let grad_theta = &fast_atv(&bwiggle, &residual) - &(&s_link_theta * penalty_scale);
5880
5881 let g_prime = self.compute_g_prime(&u, &theta);
5884 let r_scaled: Array1<f64> = residual
5885 .iter()
5886 .zip(g_prime.iter())
5887 .map(|(&r, &g)| r * g)
5888 .collect();
5889 let s_base_beta = self.penalty_base.dot(&beta);
5890 let grad_beta = &fast_atv(&self.x, &r_scaled) - &(&s_base_beta * penalty_scale);
5891
5892 let penalty =
5894 penalty_scale * (0.5 * beta.dot(&s_base_beta) + 0.5 * theta.dot(&s_link_theta));
5895
5896 let mut grad_q = Array1::<f64>::zeros(dim);
5898 grad_q
5899 .slice_mut(ndarray::s![0..self.p_base])
5900 .assign(&grad_beta);
5901 grad_q
5902 .slice_mut(ndarray::s![self.p_base..])
5903 .assign(&grad_theta);
5904 fast_av_into(&self.chol_t, &grad_q, grad);
5905 ll - penalty
5906 }
5907
5908 pub fn chol(&self) -> &Array2<f64> {
5910 &self.chol
5911 }
5912
5913 pub fn mode_joint(&self) -> Array1<f64> {
5915 let dim = self.p_base + self.p_link;
5916 let mut mode = Array1::<f64>::zeros(dim);
5917 mode.slice_mut(ndarray::s![0..self.p_base])
5918 .assign(&self.mode_beta);
5919 mode.slice_mut(ndarray::s![self.p_base..])
5920 .assign(&self.mode_theta);
5921 mode
5922 }
5923}
5924
5925impl HamiltonianTarget<Array1<f64>> for LinkWigglePosterior {
5926 fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
5927 self.compute_logp_and_grad_into(position, grad)
5928 }
5929}
5930
5931pub fn run_link_wiggle_nuts_sampling(
5933 x: ArrayView2<f64>,
5934 y: ArrayView1<f64>,
5935 weights: ArrayView1<f64>,
5936 penalty_base: ArrayView2<f64>,
5937 penalty_link: ArrayView2<f64>,
5938 mode_beta: ArrayView1<f64>,
5939 mode_theta: ArrayView1<f64>,
5940 hessian: ArrayView2<f64>,
5941 spline: LinkWiggleSplineArtifacts,
5942 nuts_family: NutsFamily,
5943 scale: f64,
5944 config: &NutsConfig,
5945) -> Result<NutsResult, String> {
5946 validate_nuts_config(config).map_err(String::from)?;
5947 let dim = mode_beta.len() + mode_theta.len();
5948 let target = LinkWigglePosterior::new(
5949 x,
5950 y,
5951 weights,
5952 penalty_base,
5953 penalty_link,
5954 mode_beta,
5955 mode_theta,
5956 hessian,
5957 spline,
5958 nuts_family,
5959 scale,
5960 )?;
5961 let chol = target.chol().clone();
5962 let mode_arr = target.mode_joint();
5963
5964 let initial_positions = jittered_initial_positions(config, dim, 0.1, 0x8C48_0F65_3A2B_D917);
5965
5966 let mass_cfg = robust_mass_matrix_config(dim, config.nwarmup);
5967 let (result, run_stats) = run_whitened_nuts_result(
5968 target,
5969 &mode_arr,
5970 &chol,
5971 initial_positions,
5972 config,
5973 dim,
5974 mass_cfg,
5975 0x2E31_A4B6_C908_F57D,
5976 "Link-wiggle NUTS sampling failed",
5977 Array1::zeros(dim),
5978 NutsConvergenceThresholds {
5979 max_rhat: 1.1,
5980 min_ess: Some(100.0),
5981 },
5982 )?;
5983 log::info!("Link-wiggle NUTS sampling complete: {}", run_stats);
5984
5985 Ok(result)
5986}
5987
5988pub fn laplace_directional_cubic_diagnostic(
6031 hessian: &Array2<f64>,
6032 design: &DesignMatrix,
6033 c_weights: &Array1<f64>,
6034 refine_supremum: bool,
6035) -> Result<(f64, Array1<f64>), String> {
6036 let p = hessian.nrows();
6037 if p == 0 || hessian.ncols() != p {
6038 return Ok((0.0, Array1::zeros(0)));
6039 }
6040
6041 let sym_h = (hessian + &hessian.t()) * 0.5;
6042 let (evals, evecs) = sym_h
6043 .eigh(Side::Lower)
6044 .map_err(|e| format!("directional cubic diagnostic eigendecomposition failed: {e}"))?;
6045 let max_eval = evals.iter().fold(0.0_f64, |acc, &ev| acc.max(ev.abs()));
6046 let tol = (max_eval * 1.0e-12).max(1.0e-14);
6047 let mut directional = Array1::<f64>::zeros(p);
6048 let mut max_abs = 0.0_f64;
6049
6050 for r in 0..p {
6058 let lambda = evals[r];
6059 if lambda <= tol {
6060 continue;
6061 }
6062 let v = evecs.column(r);
6063 let gamma = directional_cubic_contraction(design, c_weights, &v) / lambda.powf(1.5);
6064 directional[r] = if gamma.is_finite() { gamma } else { 0.0 };
6065 max_abs = max_abs.max(directional[r].abs());
6066 }
6067
6068 if refine_supremum && p >= 2 {
6080 let positive_mask: Vec<bool> = evals.iter().map(|&ev| ev > tol).collect();
6084 let n_pos = positive_mask.iter().filter(|&&m| m).count();
6085 if n_pos >= 2 {
6086 let max_abs_from_probes = cubic_power_iteration_refinement(
6087 design,
6088 c_weights,
6089 &evals,
6090 &evecs,
6091 &positive_mask,
6092 n_pos,
6093 );
6094 if max_abs_from_probes > max_abs {
6095 max_abs = max_abs_from_probes;
6096 }
6097 }
6098 }
6099
6100 Ok((max_abs, directional))
6101}
6102
6103fn directional_cubic_contraction(
6105 design: &DesignMatrix,
6106 c_weights: &Array1<f64>,
6107 v: &ArrayView1<f64>,
6108) -> f64 {
6109 match design.as_sparse() {
6110 Some(x_sparse) => {
6111 let (symbolic, values) = x_sparse.as_ref().parts();
6112 let col_ptr = symbolic.col_ptr();
6113 let row_idx = symbolic.row_idx();
6114 let mut row_scores = vec![0.0_f64; x_sparse.nrows()];
6115 for col in 0..x_sparse.ncols() {
6116 let coeff = v[col];
6117 for ptr in col_ptr[col]..col_ptr[col + 1] {
6118 row_scores[row_idx[ptr]] += values[ptr] * coeff;
6119 }
6120 }
6121 let mut cubic = 0.0_f64;
6122 for i in 0..row_scores.len().min(c_weights.len()) {
6123 cubic += c_weights[i] * row_scores[i].powi(3);
6124 }
6125 cubic
6126 }
6127 None => {
6128 let x_dense = design.to_dense_cow();
6129 let x_dense = x_dense.as_ref();
6130 let mut cubic = 0.0_f64;
6131 for i in 0..x_dense.nrows().min(c_weights.len()) {
6132 let proj = x_dense.row(i).dot(v);
6133 cubic += c_weights[i] * proj.powi(3);
6134 }
6135 cubic
6136 }
6137 }
6138}
6139
6140fn directional_cubic_gradient(
6143 design: &DesignMatrix,
6144 c_weights: &Array1<f64>,
6145 v: &Array1<f64>,
6146) -> Array1<f64> {
6147 let p = v.len();
6148 match design.as_sparse() {
6149 Some(x_sparse) => {
6150 let (symbolic, values) = x_sparse.as_ref().parts();
6151 let col_ptr = symbolic.col_ptr();
6152 let row_idx = symbolic.row_idx();
6153 let n = x_sparse.nrows();
6154 let mut row_scores = vec![0.0_f64; n];
6155 for col in 0..x_sparse.ncols() {
6156 let coeff = v[col];
6157 for ptr in col_ptr[col]..col_ptr[col + 1] {
6158 row_scores[row_idx[ptr]] += values[ptr] * coeff;
6159 }
6160 }
6161 let mut quad_weights = vec![0.0_f64; n];
6163 for i in 0..n.min(c_weights.len()) {
6164 quad_weights[i] = 3.0 * c_weights[i] * row_scores[i] * row_scores[i];
6165 }
6166 let mut grad = Array1::<f64>::zeros(p);
6168 for col in 0..x_sparse.ncols() {
6169 let mut acc = 0.0_f64;
6170 for ptr in col_ptr[col]..col_ptr[col + 1] {
6171 acc += values[ptr] * quad_weights[row_idx[ptr]];
6172 }
6173 grad[col] = acc;
6174 }
6175 grad
6176 }
6177 None => {
6178 let x_dense = design.to_dense_cow();
6179 let x_dense = x_dense.as_ref();
6180 let n = x_dense.nrows();
6181 let mut grad = Array1::<f64>::zeros(p);
6182 for i in 0..n.min(c_weights.len()) {
6183 let proj = x_dense.row(i).dot(v);
6184 let w = 3.0 * c_weights[i] * proj * proj;
6185 let row = x_dense.row(i);
6187 for j in 0..p {
6188 grad[j] += w * row[j];
6189 }
6190 }
6191 grad
6192 }
6193 }
6194}
6195
6196fn cubic_power_iteration_refinement(
6202 design: &DesignMatrix,
6203 c_weights: &Array1<f64>,
6204 evals: &Array1<f64>,
6205 evecs: &Array2<f64>,
6206 positive_mask: &[bool],
6207 n_pos: usize,
6208) -> f64 {
6209 let p = evals.len();
6210 let max_probes = 8;
6211 let max_iters = 5;
6212
6213 let to_original = |u: &Array1<f64>| -> Array1<f64> {
6216 let mut v = Array1::<f64>::zeros(p);
6217 let mut idx = 0;
6218 for r in 0..p {
6219 if positive_mask[r] {
6220 let scale = u[idx] / evals[r].sqrt();
6221 let col = evecs.column(r);
6222 for j in 0..p {
6223 v[j] += scale * col[j];
6224 }
6225 idx += 1;
6226 }
6227 }
6228 v
6229 };
6230
6231 let to_whitened = |g: &Array1<f64>| -> Array1<f64> {
6233 let mut u = Array1::<f64>::zeros(n_pos);
6234 let mut idx = 0;
6235 for r in 0..p {
6236 if positive_mask[r] {
6237 u[idx] = evals[r].sqrt() * evecs.column(r).dot(g);
6238 idx += 1;
6239 }
6240 }
6241 u
6242 };
6243
6244 let eval_gamma = |u: &Array1<f64>| -> f64 {
6246 let norm = u.dot(u).sqrt();
6247 if norm < 1e-30 {
6248 return 0.0;
6249 }
6250 let u_normed: Array1<f64> = u / norm;
6251 let v = to_original(&u_normed);
6252 let cubic = directional_cubic_contraction(design, c_weights, &v.view());
6254 if cubic.is_finite() { cubic.abs() } else { 0.0 }
6255 };
6256
6257 let refine_step = |u: &Array1<f64>| -> Array1<f64> {
6259 let norm = u.dot(u).sqrt();
6260 if norm < 1e-30 {
6261 return u.clone();
6262 }
6263 let u_normed: Array1<f64> = u / norm;
6264 let v = to_original(&u_normed);
6265 let grad_v = directional_cubic_gradient(design, c_weights, &v);
6267 let mut grad_u = to_whitened(&grad_v);
6269 let dot = grad_u.dot(&u_normed);
6271 grad_u.scaled_add(-dot, &u_normed);
6272 let cubic_val = directional_cubic_contraction(design, c_weights, &v.view());
6274 let sign = if cubic_val >= 0.0 { 1.0 } else { -1.0 };
6275 let step_size = 0.3;
6276 let mut u_new = &u_normed + &(&grad_u * (sign * step_size));
6277 let new_norm = u_new.dot(&u_new).sqrt();
6278 if new_norm > 1e-30 {
6279 u_new /= new_norm;
6280 }
6281 u_new
6282 };
6283
6284 let mut best = 0.0_f64;
6285
6286 let mut seeds: Vec<Array1<f64>> = Vec::with_capacity(max_probes);
6292
6293 let mut best_eig_idx = 0;
6296 let mut best_eig_gamma = 0.0_f64;
6297 for j in 0..n_pos {
6298 let mut u = Array1::<f64>::zeros(n_pos);
6299 u[j] = 1.0;
6300 let g = eval_gamma(&u);
6301 if g > best_eig_gamma {
6302 best_eig_gamma = g;
6303 best_eig_idx = j;
6304 }
6305 }
6306 best = best.max(best_eig_gamma);
6307 let mut u_best = Array1::<f64>::zeros(n_pos);
6308 u_best[best_eig_idx] = 1.0;
6309 seeds.push(u_best);
6310
6311 let n_top = n_pos.min(4);
6313 for i in 0..n_top {
6314 for j in (i + 1)..n_top {
6315 if seeds.len() >= max_probes {
6316 break;
6317 }
6318 let inv_sqrt2 = std::f64::consts::FRAC_1_SQRT_2;
6319 let mut u_plus = Array1::<f64>::zeros(n_pos);
6320 u_plus[i] = inv_sqrt2;
6321 u_plus[j] = inv_sqrt2;
6322 seeds.push(u_plus);
6323 if seeds.len() < max_probes {
6324 let mut u_minus = Array1::<f64>::zeros(n_pos);
6325 u_minus[i] = inv_sqrt2;
6326 u_minus[j] = -inv_sqrt2;
6327 seeds.push(u_minus);
6328 }
6329 }
6330 }
6331
6332 for seed in &seeds {
6334 let mut u = seed.clone();
6335 for _ in 0..max_iters {
6336 u = refine_step(&u);
6337 }
6338 let g = eval_gamma(&u);
6339 best = best.max(g);
6340 }
6341
6342 best
6343}
6344
6345pub use gam_problem::laplace_sampler_contract::{
6355 BlockExcessTarget, BlockSampledMarginal, BlockSampledMoments, GaussianModePosterior,
6356 LaplaceTrustworthiness, laplace_skewness_threshold, laplace_trustworthiness_from_skewness,
6357};
6358
6359pub struct HmcIoLaplaceMarginalSampler;
6365
6366impl gam_problem::laplace_sampler_contract::LaplaceMarginalSampler for HmcIoLaplaceMarginalSampler {
6367 fn directional_cubic_diagnostic(
6368 &self,
6369 hessian: &Array2<f64>,
6370 design: &DesignMatrix,
6371 c_weights: &Array1<f64>,
6372 refine_supremum: bool,
6373 ) -> Result<(f64, Array1<f64>), String> {
6374 laplace_directional_cubic_diagnostic(hessian, design, c_weights, refine_supremum)
6375 }
6376
6377 fn block_sampled_marginal_correction(
6378 &self,
6379 target: &dyn BlockExcessTarget,
6380 ) -> Result<BlockSampledMarginal, String> {
6381 block_sampled_marginal_correction(target)
6382 }
6383}
6384
6385pub struct HmcIoGaussianModePosteriorSampler;
6392
6393impl gam_problem::laplace_sampler_contract::GaussianModePosteriorSampler
6394 for HmcIoGaussianModePosteriorSampler
6395{
6396 fn sample_gaussian_mode_posterior(
6397 &self,
6398 mode: ArrayView1<f64>,
6399 precision: ArrayView2<f64>,
6400 ) -> Result<GaussianModePosterior, String> {
6401 let config = NutsConfig::for_dimension(mode.len());
6402 sample_gaussian_mode_posterior(mode, precision, &config)
6403 }
6404}
6405
/// Number of importance draws used to sample a block of dimension `block_dim`.
///
/// The budget grows linearly as `256 + 256 * block_dim` and is capped at 4096.
/// Saturating arithmetic ensures an absurdly large `block_dim` yields the cap
/// instead of overflowing, which would panic in debug builds and wrap in release.
fn block_sampling_draws(block_dim: usize) -> usize {
    const BASE: usize = 256;
    const PER_DIM: usize = 256;
    const CAP: usize = 4096;
    PER_DIM
        .saturating_mul(block_dim)
        .saturating_add(BASE)
        .min(CAP)
}
6419
6420pub fn block_sampled_marginal_correction<T: BlockExcessTarget + ?Sized>(
6446 target: &T,
6447) -> Result<BlockSampledMarginal, String> {
6448 use rand::SeedableRng;
6449 use rand::rngs::StdRng;
6450
6451 let m = target.block_dim();
6452 let k = target.rho_dim();
6453 if m == 0 {
6454 return Ok(BlockSampledMarginal {
6455 value: 0.0,
6456 rho_gradient: Array1::zeros(k),
6457 importance_ess: 0.0,
6458 n_draws: 0,
6459 moments: None,
6460 });
6461 }
6462 let lambdas = target.block_curvatures();
6463 if lambdas.len() != m {
6464 return Err(format!(
6465 "block_sampled_marginal_correction: block_curvatures len {} != block_dim {m}",
6466 lambdas.len()
6467 ));
6468 }
6469 let inv_sqrt_lambda: Array1<f64> = lambdas.mapv(|l| {
6470 if l > 0.0 {
6471 1.0 / l.sqrt()
6472 } else {
6473 f64::NAN
6477 }
6478 });
6479 if inv_sqrt_lambda.iter().any(|v| !v.is_finite()) {
6480 return Err(
6481 "block_sampled_marginal_correction: non-positive block curvature (mode is not a \
6482 strict local minimum in a sampled direction)"
6483 .to_string(),
6484 );
6485 }
6486
6487 let n_draws = block_sampling_draws(m);
6488 let mut seed_bits: u64 = 0x9E37_79B9_7F4A_7C15;
6505 seed_bits ^= (m as u64).rotate_left(17);
6506 seed_bits = seed_bits.wrapping_mul(0x1000_0000_01B3);
6507 seed_bits ^= (k as u64).rotate_left(31);
6508 seed_bits = seed_bits.wrapping_mul(0x1000_0000_01B3);
6509 let mut rng = StdRng::seed_from_u64(seed_bits);
6510
6511 let n_obs = target.base_neg_score().len();
6521 let mut max_lw = f64::NEG_INFINITY;
6522 let mut sum_w = 0.0_f64;
6523 let mut sum_w2 = 0.0_f64;
6524 let mut grad_acc = Array1::<f64>::zeros(k);
6525 let mut e_t_acc = Array1::<f64>::zeros(m);
6526 let mut e_tt_acc = Array2::<f64>::zeros((m, m));
6527 let mut e_ngs_acc = Array1::<f64>::zeros(n_obs);
6528 let mut e_t_ngs_acc = Array2::<f64>::zeros((n_obs, m));
6529
6530 let mut draws = Array2::<f64>::zeros((m, n_draws));
6538 for s in 0..n_draws {
6539 let mut col = draws.column_mut(s);
6540 for r in 0..m {
6541 let z = sample_standard_normal(&mut rng);
6542 col[r] = z * inv_sqrt_lambda[r];
6543 }
6544 }
6545 let batched = target.excess_with_displaced_neg_score_batch(&draws);
6546
6547 let mut t = Array1::<f64>::zeros(m);
6548 for (sidx, (excess, displaced_ngs)) in batched.into_iter().enumerate() {
6549 t.assign(&draws.column(sidx));
6550 if !excess.is_finite() {
6551 continue;
6552 }
6553 let Some(ngs) = displaced_ngs else {
6554 continue;
6556 };
6557 let lw = -excess;
6558 if lw > max_lw {
6559 let rescale = (max_lw - lw).exp();
6562 sum_w *= rescale;
6563 sum_w2 *= rescale * rescale;
6564 grad_acc *= rescale;
6565 e_t_acc *= rescale;
6566 e_tt_acc *= rescale;
6567 e_ngs_acc *= rescale;
6568 e_t_ngs_acc *= rescale;
6569 max_lw = lw;
6570 }
6571 let w = (lw - max_lw).exp();
6572 sum_w += w;
6573 sum_w2 += w * w;
6574 grad_acc.scaled_add(-w, &target.excess_rho_gradient(&t));
6576 if ngs.len() != n_obs {
6578 return Err(format!(
6579 "block_sampled_marginal_correction: displaced_neg_score len {} != {n_obs}",
6580 ngs.len()
6581 ));
6582 }
6583 e_t_acc.scaled_add(w, &t);
6584 e_ngs_acc.scaled_add(w, &ngs);
6585 for r in 0..m {
6586 let wt_r = w * t[r];
6587 for q in 0..m {
6588 e_tt_acc[(q, r)] += wt_r * t[q];
6589 }
6590 e_t_ngs_acc.column_mut(r).scaled_add(wt_r, &ngs);
6591 }
6592 }
6593 if !max_lw.is_finite() {
6594 return Err(
6595 "block_sampled_marginal_correction: all importance draws were infeasible".to_string(),
6596 );
6597 }
6598 let value = max_lw + (sum_w / n_draws as f64).ln();
6599 let (rho_gradient, moments) = if sum_w > 0.0 {
6601 (
6602 grad_acc / sum_w,
6603 Some(BlockSampledMoments {
6604 e_t: e_t_acc / sum_w,
6605 e_tt: e_tt_acc / sum_w,
6606 e_neg_score: e_ngs_acc / sum_w,
6607 e_t_neg_score: e_t_ngs_acc / sum_w,
6608 }),
6609 )
6610 } else {
6611 (Array1::zeros(k), None)
6612 };
6613 let importance_ess = if sum_w2 > 0.0 {
6615 (sum_w * sum_w) / sum_w2
6616 } else {
6617 0.0
6618 };
6619
6620 if !value.is_finite() || rho_gradient.iter().any(|v| !v.is_finite()) {
6621 return Err(
6622 "block_sampled_marginal_correction: produced a non-finite correction or gradient"
6623 .to_string(),
6624 );
6625 }
6626 if let Some(mo) = moments.as_ref()
6627 && (mo.e_t.iter().any(|v| !v.is_finite())
6628 || mo.e_tt.iter().any(|v| !v.is_finite())
6629 || mo.e_neg_score.iter().any(|v| !v.is_finite())
6630 || mo.e_t_neg_score.iter().any(|v| !v.is_finite()))
6631 {
6632 return Err(
6633 "block_sampled_marginal_correction: produced non-finite gradient-channel moments"
6634 .to_string(),
6635 );
6636 }
6637
6638 Ok(BlockSampledMarginal {
6639 value,
6640 rho_gradient,
6641 importance_ess,
6642 n_draws,
6643 moments,
6644 })
6645}
6646
/// Output of `run_joint_beta_rho_sampling`: posterior draws and summaries for
/// coefficients, log smoothing parameters and adaptive inverse-link parameters.
#[derive(Clone, Debug)]
pub struct JointBetaRhoResult {
    /// Coefficient draws on the original (un-whitened) scale, one row per retained draw.
    pub beta_samples: Array2<f64>,
    /// Log smoothing-parameter (ρ) draws, shape `(n_chains * n_samples, n_rho)`,
    /// laid out chain-major.
    pub rho_samples: Array2<f64>,
    /// Column mean of `beta_samples` (zeros if there are no draws).
    pub beta_mean: Array1<f64>,
    /// Inverse-link parameter draws, shape `(n_chains * n_samples, n_link_params)`;
    /// zero columns for links without adaptive parameters.
    pub link_param_samples: Array2<f64>,
    /// Column mean of `link_param_samples`.
    pub link_param_mean: Array1<f64>,
    /// Column mean of `rho_samples`.
    pub rho_mean: Array1<f64>,
    /// Split-R̂ summary returned by `compute_split_rhat_and_ess` on the whitened
    /// joint samples.
    pub rhat: f64,
    /// Effective-sample-size summary from the same diagnostic call.
    pub ess: f64,
    /// Whether `rhat`/`ess` pass the convergence thresholds used by
    /// `run_joint_beta_rho_sampling` (`max_rhat = 1.1`, `min_ess = 50`).
    pub converged: bool,
    /// Skewness value that triggered the joint fallback, echoed from the inputs.
    pub trigger_skewness: f64,
}
6671
/// Joint log-posterior over whitened coefficients `z`, log smoothing parameters
/// `ρ` and adaptive inverse-link parameters, laid out as one position vector
/// `[z, ρ, link_params]`.
struct JointBetaRhoPosterior {
    /// Design, response, weights and the coefficient mode `β̂` (shared, read-only).
    data: SharedData,
    /// Whitening factor: coefficients are reconstructed as `β = β̂ + chol · z`.
    chol: Array2<f64>,
    /// Matrix applied to the β-gradient to obtain the z-gradient (by the chain rule
    /// this is expected to be `cholᵀ`; produced by `hessian_whitening_transform`).
    chol_t: Array2<f64>,
    /// Response family plus the inverse link at the mode; per-step links are
    /// rebuilt from the sampled link parameters.
    likelihood: LikelihoodSpec,
    /// Number of coefficients (columns of the design).
    n_beta: usize,
    /// Number of smoothing parameters (one per canonical penalty).
    n_rho: usize,
    /// Number of adaptive inverse-link parameters (0 for fixed links).
    n_link_params: usize,
    /// Inverse-link parameters at the mode; chains are initialised around them.
    link_param_mode: Array1<f64>,
    /// Penalty roots `R_k` acting on column range `col_range`; the penalty
    /// contribution is `½ λ_k ‖R_k β_block‖²`.
    penalty_canonical: Vec<gam_terms::construction::CanonicalPenalty>,
    /// Prior placed on each ρ_k.
    rho_prior: RhoPrior,
    /// ρ values supplied by the caller; chains are initialised around them.
    rho_mode: Array1<f64>,
    /// Whether the Firth/Jeffreys term is added to the log-likelihood.
    firth_enabled: bool,
    /// Single-entry cache for the last evaluated ρ:
    /// `(hash of ρ bit patterns, log|S_λ|₊, ∂ log|S_λ|₊ / ∂ρ)`.
    penalty_logdet_cache: Mutex<Option<(u64, f64, Array1<f64>)>>,
}
6714
6715impl JointBetaRhoPosterior {
6716 fn new(
6717 x: ArrayView2<f64>,
6718 y: ArrayView1<f64>,
6719 weights: ArrayView1<f64>,
6720 mode: ArrayView1<f64>,
6721 hessian: ArrayView2<f64>,
6722 penalty_canonical: Vec<gam_terms::construction::CanonicalPenalty>,
6723 rho_mode: ArrayView1<f64>,
6724 likelihood: LikelihoodSpec,
6725 gamma_shape: Option<f64>,
6726 rho_prior: RhoPrior,
6727 firth_enabled: bool,
6728 ) -> Result<Self, String> {
6729 let n_samples = x.nrows();
6730 let n_beta = x.ncols();
6731 let n_rho = penalty_canonical.len();
6732
6733 if rho_mode.len() != n_rho {
6734 return Err(HmcError::DimensionMismatch {
6735 reason: format!(
6736 "rho_mode length {} != penalty count {}",
6737 rho_mode.len(),
6738 n_rho
6739 ),
6740 }
6741 .into());
6742 }
6743
6744 match (&likelihood.response, &likelihood.link) {
6745 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {}
6746 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Probit)) => {}
6747 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::CLogLog)) => {}
6748 (ResponseFamily::Binomial, InverseLink::LatentCLogLog(_)) => {}
6749 (ResponseFamily::Binomial, InverseLink::Sas(_)) => {}
6750 (ResponseFamily::Binomial, InverseLink::BetaLogistic(_)) => {}
6751 (ResponseFamily::Binomial, InverseLink::Mixture(_)) => {}
6752 (ResponseFamily::Binomial, InverseLink::Standard(other)) => {
6753 return Err(HmcError::LinkMismatch {
6754 reason: format!(
6755 "Joint HMC binomial response requires a binomial-compatible inverse link; got {:?}",
6756 other
6757 ),
6758 }
6759 .into());
6760 }
6761 (ResponseFamily::Gaussian, InverseLink::Standard(StandardLink::Identity)) => {}
6762 (ResponseFamily::Gaussian, _) => {
6763 return Err(HmcError::LinkMismatch {
6764 reason: "Joint HMC Gaussian requires an identity inverse link".to_string(),
6765 }
6766 .into());
6767 }
6768 (
6769 ResponseFamily::Poisson
6770 | ResponseFamily::Tweedie { .. }
6771 | ResponseFamily::NegativeBinomial { .. }
6772 | ResponseFamily::Gamma,
6773 InverseLink::Standard(StandardLink::Log),
6774 ) => {}
6775 (
6776 ResponseFamily::Poisson
6777 | ResponseFamily::Tweedie { .. }
6778 | ResponseFamily::NegativeBinomial { .. }
6779 | ResponseFamily::Gamma,
6780 _,
6781 ) => {
6782 return Err(HmcError::LinkMismatch {
6783 reason: "Joint HMC log-link family requires a log inverse link".to_string(),
6784 }
6785 .into());
6786 }
6787 (ResponseFamily::Beta { .. }, InverseLink::Standard(StandardLink::Logit)) => {}
6788 (ResponseFamily::Beta { .. }, _) => {
6789 return Err(HmcError::LinkMismatch {
6790 reason: "Joint HMC Beta requires a logit inverse link".to_string(),
6791 }
6792 .into());
6793 }
6794 (ResponseFamily::RoystonParmar, _) => {
6795 return Err(HmcError::UnsupportedFamily {
6796 reason: "Joint HMC fallback is not implemented for RoystonParmar".to_string(),
6797 }
6798 .into());
6799 }
6800 }
6801
6802 validate_firth_likelihood_support(&likelihood, firth_enabled).map_err(String::from)?;
6803 if matches!(likelihood.response, ResponseFamily::NegativeBinomial { .. }) {
6804 validate_count_responses("negative-binomial joint HMC", &y, &weights)
6805 .map_err(String::from)?;
6806 }
6807 if likelihood.is_binomial() {
6808 validate_binary_responses("binomial joint HMC", &y, &weights).map_err(String::from)?;
6809 }
6810
6811 let whitening = hessian_whitening_transform(
6812 hessian,
6813 n_beta,
6814 1.0,
6815 "Joint HMC: Hessian Cholesky failed",
6816 )?;
6817 let chol = whitening.chol;
6818 let chol_t = whitening.chol_t;
6819
6820 let data = SharedData {
6821 x: Arc::new(x.to_owned()),
6822 y: Arc::new(y.to_owned()),
6823 weights: Arc::new(weights.to_owned()),
6824 mode: Arc::new(mode.to_owned()),
6825 offset: None,
6826 gamma_shape: gamma_shape.unwrap_or(1.0),
6827 dispersion: gam_solve::model_types::Dispersion::Known(1.0),
6832 n_samples,
6833 dim: n_beta,
6834 };
6835 let link_param_mode = Self::link_param_mode(&likelihood.link);
6836
6837 Ok(Self {
6838 data,
6839 chol,
6840 chol_t,
6841 likelihood,
6842 n_beta,
6843 n_rho,
6844 n_link_params: link_param_mode.len(),
6845 link_param_mode,
6846 penalty_canonical,
6847 rho_prior,
6848 rho_mode: rho_mode.to_owned(),
6849 firth_enabled,
6850 penalty_logdet_cache: Mutex::new(None),
6851 })
6852 }
6853
6854 #[inline]
6860 fn hash_rho(rho: ndarray::ArrayView1<f64>) -> u64 {
6861 let mut h: u64 = 0xcbf2_9ce4_8422_2325;
6862 for &x in rho.iter() {
6863 h ^= x.to_bits();
6864 h = h.wrapping_mul(0x0000_0100_0000_01b3);
6865 }
6866 h
6867 }
6868
6869 fn link_param_mode(inverse_link: &InverseLink) -> Array1<f64> {
6870 match inverse_link {
6871 InverseLink::Sas(state) | InverseLink::BetaLogistic(state) => {
6872 Array1::from_vec(vec![state.epsilon, state.log_delta])
6873 }
6874 InverseLink::Mixture(state) => state.rho.clone(),
6875 InverseLink::Standard(_) | InverseLink::LatentCLogLog(_) => Array1::zeros(0),
6876 }
6877 }
6878
6879 fn inverse_link_with_params(
6880 &self,
6881 link_params: ndarray::ArrayView1<'_, f64>,
6882 ) -> Result<InverseLink, String> {
6883 match &self.likelihood.link {
6884 InverseLink::Sas(_) => {
6885 if link_params.len() != 2 {
6886 return Err(format!(
6887 "SAS link parameter length must be 2, got {}",
6888 link_params.len()
6889 ));
6890 }
6891 Ok(InverseLink::Sas(
6892 gam_solve::mixture_link::sas_link_state_from_raw(
6893 link_params[0],
6894 link_params[1],
6895 )?,
6896 ))
6897 }
6898 InverseLink::BetaLogistic(_) => {
6899 if link_params.len() != 2 || !link_params.iter().all(|v| v.is_finite()) {
6900 return Err(
6901 "Beta-Logistic link parameters must be finite with length 2".to_string()
6902 );
6903 }
6904 Ok(InverseLink::BetaLogistic(gam_problem::types::SasLinkState {
6905 epsilon: link_params[0],
6906 log_delta: link_params[1],
6907 delta: link_params[1].exp(),
6908 }))
6909 }
6910 InverseLink::Mixture(state) => {
6911 let rho = link_params.to_owned();
6912 Ok(InverseLink::Mixture(gam_problem::types::MixtureLinkState {
6913 components: state.components.clone(),
6914 pi: softmax_last_fixedzero(&rho),
6915 rho,
6916 }))
6917 }
6918 InverseLink::Standard(_) | InverseLink::LatentCLogLog(_) => {
6919 Ok(self.likelihood.link.clone())
6920 }
6921 }
6922 }
6923
6924 fn compute_joint_logp_and_grad_into(
6937 &self,
6938 params: &Array1<f64>,
6939 out_grad: &mut Array1<f64>,
6940 ) -> f64 {
6941 let n_beta = self.n_beta;
6942 let n_rho = self.n_rho;
6943 let n_link_params = self.n_link_params;
6944
6945 let z = params.slice(ndarray::s![..n_beta]);
6948 let rho = params.slice(ndarray::s![n_beta..n_beta + n_rho]);
6949 let link_params = params.slice(ndarray::s![n_beta + n_rho..]);
6950 let lambdas: Array1<f64> = rho.mapv(f64::exp);
6951
6952 let inverse_link = match self.inverse_link_with_params(link_params) {
6953 Ok(link) => link,
6954 Err(err) => {
6955 log::warn!(
6956 "[Joint HMC] adaptive inverse-link parameters are invalid: {}",
6957 err
6958 );
6959 out_grad.fill(0.0);
6960 return f64::NEG_INFINITY;
6961 }
6962 };
6963
6964 let beta = self.data.mode.as_ref() + &self.chol.dot(&z);
6966
6967 let eta = gam_linalg::faer_ndarray::fast_av(self.data.x.as_ref(), &beta);
6969
6970 let step_likelihood = LikelihoodSpec {
6972 response: self.likelihood.response.clone(),
6973 link: inverse_link,
6974 };
6975 let (ll, mut grad_ll_beta, grad_link) = match joint_family_logp_grad_and_link_grad(
6976 &step_likelihood,
6977 &self.data,
6978 &eta,
6979 n_link_params,
6980 ) {
6981 Ok(value) => value,
6982 Err(err) => {
6983 log::warn!(
6984 "[Joint HMC] likelihood target became invalid at the current state: {}",
6985 err
6986 );
6987 out_grad.fill(0.0);
6988 return f64::NEG_INFINITY;
6989 }
6990 };
6991
6992 let mut firth_logdet = 0.0;
6993 if self.firth_enabled {
6994 match firth_jeffreys_logp_and_grad(NutsFamily::BinomialLogit, &self.data, &eta) {
6995 Ok((value, grad_beta_firth)) => {
6996 firth_logdet = value;
6997 grad_ll_beta += &grad_beta_firth;
6998 }
6999 Err(err) => {
7000 log::warn!(
7001 "[Joint HMC/Firth] Jeffreys target became invalid at the current state: {}",
7002 err
7003 );
7004 out_grad.fill(0.0);
7005 return f64::NEG_INFINITY;
7006 }
7007 }
7008 }
7009
7010 let mut penalty_val = 0.0;
7014 let mut s_beta = Array1::<f64>::zeros(n_beta);
7015 let mut grad_rho = Array1::<f64>::zeros(n_rho);
7016
7017 let max_rank = self
7021 .penalty_canonical
7022 .iter()
7023 .map(|cp| cp.rank())
7024 .max()
7025 .unwrap_or(0);
7026 let mut r_beta_scratch = Array1::<f64>::zeros(max_rank);
7027
7028 for (k, cp) in self.penalty_canonical.iter().enumerate() {
7029 let r = &cp.col_range;
7031 let beta_block = beta.slice(ndarray::s![r.start..r.end]);
7032 let rank_k = cp.rank();
7033 gam_linalg::faer_ndarray::fast_av_view_into(
7034 &cp.root,
7035 &beta_block,
7036 r_beta_scratch.slice_mut(ndarray::s![..rank_k]),
7037 );
7038 let r_beta = r_beta_scratch.slice(ndarray::s![..rank_k]);
7039 let quad_k = r_beta.dot(&r_beta);
7040 penalty_val += 0.5 * lambdas[k] * quad_k;
7041
7042 for a in 0..cp.block_dim() {
7044 let val: f64 = (0..rank_k).map(|row| cp.root[[row, a]] * r_beta[row]).sum();
7045 s_beta[r.start + a] += lambdas[k] * val;
7046 }
7047
7048 grad_rho[k] = -0.5 * lambdas[k] * quad_k;
7050 }
7051
7052 let log_det_s = if self.penalty_canonical.is_empty() {
7060 0.0
7061 } else {
7062 let rho_hash = Self::hash_rho(rho);
7063 let cached = self.penalty_logdet_cache.lock().ok().and_then(|guard| {
7064 guard.as_ref().and_then(|(h, v, g)| {
7065 if *h == rho_hash && g.len() == n_rho {
7066 for k in 0..n_rho {
7067 grad_rho[k] += 0.5 * g[k];
7068 }
7069 Some(*v)
7070 } else {
7071 None
7072 }
7073 })
7074 });
7075 if let Some(hit) = cached {
7076 hit
7077 } else {
7078 match PenaltyPseudologdet::from_penalties(
7079 &self.penalty_canonical,
7080 lambdas.as_slice().unwrap_or(&[]),
7081 0.0,
7082 n_beta,
7083 ) {
7084 Ok(pld) => {
7085 let (det1, _) = pld.rho_derivatives_from_penalties(
7086 &self.penalty_canonical,
7087 lambdas.as_slice().unwrap_or(&[]),
7088 );
7089 let value = pld.value();
7090 if let Ok(mut guard) = self.penalty_logdet_cache.lock() {
7091 *guard = Some((rho_hash, value, det1.clone()));
7092 }
7093 for k in 0..n_rho {
7094 grad_rho[k] += 0.5 * det1[k];
7095 }
7096 value
7097 }
7098 Err(err) => {
7099 log::warn!(
7100 "[Joint HMC] structural penalty logdet became invalid at the current state: {}",
7101 err
7102 );
7103 out_grad.fill(0.0);
7104 return f64::NEG_INFINITY;
7105 }
7106 }
7107 }
7108 };
7109
7110 let mut rho_prior = 0.0;
7112 match &self.rho_prior {
7113 RhoPrior::Flat => {}
7114 RhoPrior::Normal { mean, sd } => {
7115 let inv_var = 1.0 / (*sd * *sd);
7116 for k in 0..n_rho {
7117 let d = rho[k] - *mean;
7118 rho_prior -= 0.5 * inv_var * d * d;
7119 grad_rho[k] -= inv_var * d;
7120 }
7121 }
7122 RhoPrior::GammaPrecision { shape, rate } => {
7123 for k in 0..n_rho {
7124 let lambda = rho[k].exp();
7125 rho_prior += *shape * rho[k] - *rate * lambda;
7127 grad_rho[k] += *shape - *rate * lambda;
7128 }
7129 }
7130 RhoPrior::PenalizedComplexity { upper, tail_prob } => {
7131 if !pc_prior_params_valid(*upper, *tail_prob) {
7132 out_grad.fill(0.0);
7133 return f64::NEG_INFINITY;
7134 }
7135 let theta = -tail_prob.ln() / *upper;
7136 for k in 0..n_rho {
7137 let e = (-0.5 * rho[k]).exp();
7139 rho_prior += -0.5 * rho[k] - theta * e;
7140 grad_rho[k] += -0.5 + 0.5 * theta * e;
7141 }
7142 }
7143 RhoPrior::Independent(priors) => {
7144 if priors.len() != n_rho {
7145 out_grad.fill(0.0);
7146 return f64::NEG_INFINITY;
7147 }
7148 for k in 0..n_rho {
7149 match &priors[k] {
7150 RhoPrior::Flat => {}
7151 RhoPrior::Normal { mean, sd } => {
7152 let inv_var = 1.0 / (*sd * *sd);
7153 let d = rho[k] - *mean;
7154 rho_prior -= 0.5 * inv_var * d * d;
7155 grad_rho[k] -= inv_var * d;
7156 }
7157 RhoPrior::GammaPrecision { shape, rate } => {
7158 let lambda = rho[k].exp();
7159 rho_prior += *shape * rho[k] - *rate * lambda;
7161 grad_rho[k] += *shape - *rate * lambda;
7162 }
7163 RhoPrior::PenalizedComplexity { upper, tail_prob } => {
7164 if !pc_prior_params_valid(*upper, *tail_prob) {
7165 out_grad.fill(0.0);
7166 return f64::NEG_INFINITY;
7167 }
7168 let theta = -tail_prob.ln() / *upper;
7169 let e = (-0.5 * rho[k]).exp();
7170 rho_prior += -0.5 * rho[k] - theta * e;
7171 grad_rho[k] += -0.5 + 0.5 * theta * e;
7172 }
7173 RhoPrior::Independent(_) => {
7174 out_grad.fill(0.0);
7175 return f64::NEG_INFINITY;
7176 }
7177 }
7178 }
7179 }
7180 }
7181
7182 let logp = ll + firth_logdet - penalty_val + 0.5 * log_det_s + rho_prior;
7184
7185 let grad_beta = &grad_ll_beta - &s_beta;
7187
7188 gam_linalg::faer_ndarray::fast_av_view_into(
7190 &self.chol_t,
7191 &grad_beta,
7192 out_grad.slice_mut(ndarray::s![..n_beta]),
7193 );
7194 out_grad
7195 .slice_mut(ndarray::s![n_beta..n_beta + n_rho])
7196 .assign(&grad_rho);
7197 out_grad
7198 .slice_mut(ndarray::s![n_beta + n_rho..])
7199 .assign(&grad_link);
7200
7201 logp
7202 }
7203}
7204
/// Whether penalized-complexity prior parameters are usable.
///
/// Requires `upper` to be finite and strictly positive, and `tail_prob` to be
/// finite and strictly inside `(0, 1)`. NaN in either argument is rejected.
fn pc_prior_params_valid(upper: f64, tail_prob: f64) -> bool {
    let upper_ok = upper.is_finite() && upper > 0.0;
    let tail_ok = tail_prob.is_finite() && tail_prob > 0.0 && tail_prob < 1.0;
    upper_ok && tail_ok
}
7213
7214impl HamiltonianTarget<Array1<f64>> for JointBetaRhoPosterior {
7215 fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
7216 self.compute_joint_logp_and_grad_into(position, grad)
7217 }
7218}
7219
/// Borrowed inputs for `run_joint_beta_rho_sampling`.
pub struct JointBetaRhoInputs<'a> {
    /// Design matrix (rows = observations, columns = coefficients).
    pub x: ArrayView2<'a, f64>,
    /// Response vector.
    pub y: ArrayView1<'a, f64>,
    /// Observation weights.
    pub weights: ArrayView1<'a, f64>,
    /// Response family and inverse link to sample under.
    pub likelihood: LikelihoodSpec,
    /// Gamma shape parameter; `None` is treated as 1.0 by the target.
    pub gamma_shape: Option<f64>,
    /// Coefficient mode; its length defines the number of β parameters.
    pub mode: ArrayView1<'a, f64>,
    /// Hessian at the mode, used to build the whitening transform.
    pub hessian: ArrayView2<'a, f64>,
    /// Canonical penalty roots, one per smoothing parameter ρ_k.
    pub penalty_roots: Vec<CanonicalPenalty>,
    /// ρ values the chains are initialised around.
    pub rho_mode: ArrayView1<'a, f64>,
    /// Prior on the log smoothing parameters.
    pub rho_prior: RhoPrior,
    /// Whether to add the Firth/Jeffreys term to the log-likelihood.
    pub firth_bias_reduction: bool,
    /// Skewness that triggered the joint fallback; only logged and echoed in the result.
    pub trigger_skewness: f64,
}
7236
7237pub fn run_joint_beta_rho_sampling(
7243 inputs: &JointBetaRhoInputs<'_>,
7244 config: &NutsConfig,
7245) -> Result<JointBetaRhoResult, String> {
7246 validate_firth_likelihood_support(&inputs.likelihood, inputs.firth_bias_reduction)
7247 .map_err(String::from)?;
7248 validate_nuts_config(config).map_err(String::from)?;
7249 let n_beta = inputs.mode.len();
7250 let n_rho = inputs.penalty_roots.len();
7251 let n_link_params = JointBetaRhoPosterior::link_param_mode(&inputs.likelihood.link).len();
7252 let total_dim = n_beta + n_rho + n_link_params;
7253
7254 log::info!(
7255 "[Joint HMC] Sampling (β, ρ, link) jointly: {} β-params + {} ρ-params + {} link-params = {} total (triggered by skewness {:.3})",
7256 n_beta,
7257 n_rho,
7258 n_link_params,
7259 total_dim,
7260 inputs.trigger_skewness,
7261 );
7262
7263 let target = JointBetaRhoPosterior::new(
7264 inputs.x,
7265 inputs.y,
7266 inputs.weights,
7267 inputs.mode,
7268 inputs.hessian,
7269 inputs.penalty_roots.clone(),
7270 inputs.rho_mode,
7271 inputs.likelihood.clone(),
7272 inputs.gamma_shape,
7273 inputs.rho_prior.clone(),
7274 inputs.firth_bias_reduction,
7275 )?;
7276
7277 let chol = target.chol.clone();
7278 let mode_arr = target.data.mode.clone();
7279 let rho_mode = target.rho_mode.clone();
7280 let link_param_mode = target.link_param_mode.clone();
7281
7282 let initial_positions: Vec<Array1<f64>> = (0..config.n_chains)
7284 .map(|chain| {
7285 let mut rng =
7286 StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0x9B51_6E37_F2D0_A48C));
7287 let mut pos = Array1::<f64>::zeros(total_dim);
7288 for j in 0..n_beta {
7290 pos[j] = sample_standard_normal(&mut rng) * 0.1;
7291 }
7292 for k in 0..n_rho {
7294 pos[n_beta + k] = rho_mode[k] + sample_standard_normal(&mut rng) * 0.2;
7295 }
7296 for k in 0..n_link_params {
7298 pos[n_beta + n_rho + k] =
7299 link_param_mode[k] + sample_standard_normal(&mut rng) * 0.05;
7300 }
7301 pos
7302 })
7303 .collect();
7304
7305 let mass_cfg = robust_mass_matrix_config(total_dim, config.nwarmup);
7308
7309 let (samples_array, run_stats) = run_whitened_nuts_samples(
7310 target,
7311 initial_positions,
7312 config,
7313 total_dim,
7314 mass_cfg,
7315 0x63AF_175B_D820_C94E,
7316 "Joint (β,ρ) NUTS sampling failed",
7317 )?;
7318 log::info!("[Joint HMC] Sampling complete: {}", run_stats);
7319
7320 let shape = samples_array.shape();
7322 let n_chains = shape[0];
7323 let n_samples_out = shape[1];
7324 let total_samples = n_chains * n_samples_out;
7325
7326 let beta_samples = unwhiten_samples(&samples_array, mode_arr.as_ref(), &chol, n_beta, 0);
7327 let mut rho_samples = Array2::<f64>::zeros((total_samples, n_rho));
7328 let mut link_param_samples = Array2::<f64>::zeros((total_samples, n_link_params));
7329
7330 for chain in 0..n_chains {
7331 for sample_i in 0..n_samples_out {
7332 let sample_idx = chain * n_samples_out + sample_i;
7333 let zview = samples_array.slice(ndarray::s![chain, sample_i, ..]);
7334
7335 let rho_slice = zview.slice(ndarray::s![n_beta..n_beta + n_rho]);
7337 rho_samples.row_mut(sample_idx).assign(&rho_slice);
7338 let link_slice = zview.slice(ndarray::s![n_beta + n_rho..]);
7339 link_param_samples.row_mut(sample_idx).assign(&link_slice);
7340 }
7341 }
7342
7343 let beta_mean = beta_samples
7344 .mean_axis(Axis(0))
7345 .unwrap_or_else(|| Array1::zeros(n_beta));
7346 let rho_mean = rho_samples
7347 .mean_axis(Axis(0))
7348 .unwrap_or_else(|| Array1::zeros(n_rho));
7349 let link_param_mean = link_param_samples
7350 .mean_axis(Axis(0))
7351 .unwrap_or_else(|| Array1::zeros(n_link_params));
7352
7353 let (rhat, ess) = compute_split_rhat_and_ess(&samples_array);
7354
7355 let converged = NutsConvergenceThresholds {
7356 max_rhat: 1.1,
7357 min_ess: Some(50.0),
7358 }
7359 .converged(rhat, ess);
7360 if !converged {
7361 log::warn!(
7362 "[Joint HMC] Convergence warning: R-hat={:.3}, ESS={:.1}",
7363 rhat,
7364 ess,
7365 );
7366 }
7367
7368 Ok(JointBetaRhoResult {
7369 beta_samples,
7370 rho_samples,
7371 beta_mean,
7372 link_param_samples,
7373 link_param_mean,
7374 rho_mean,
7375 rhat,
7376 ess,
7377 converged,
7378 trigger_skewness: inputs.trigger_skewness,
7379 })
7380}
7381
7382mod survival_hmc {
7387 use super::*;
7388 use gam_models::survival::{
7389 PenaltyBlocks, SurvivalEngineInputs, SurvivalMonotonicityPenalty, SurvivalSpec,
7390 WorkingModelSurvival,
7391 };
7392
7393 #[derive(Clone)]
7395 struct SharedSurvivalData {
7396 base_model: Arc<WorkingModelSurvival>,
7398 mode: Arc<Array1<f64>>,
7400 }
7401
7402 #[derive(Clone)]
7404 pub struct SurvivalPosterior {
7405 data: SharedSurvivalData,
7407 chol: Array2<f64>,
7409 chol_t: Array2<f64>,
7411 }
7412
7413 impl SurvivalPosterior {
7414 pub fn new(
7416 age_entry: ArrayView1<'_, f64>,
7417 age_exit: ArrayView1<'_, f64>,
7418 event_target: ArrayView1<'_, u8>,
7419 event_competing: ArrayView1<'_, u8>,
7420 sampleweight: ArrayView1<'_, f64>,
7421 x_entry: ArrayView2<'_, f64>,
7422 x_exit: ArrayView2<'_, f64>,
7423 x_derivative: ArrayView2<'_, f64>,
7424 offset_eta_entry: Option<ArrayView1<'_, f64>>,
7425 offset_eta_exit: Option<ArrayView1<'_, f64>>,
7426 offset_derivative_exit: Option<ArrayView1<'_, f64>>,
7427 penalties: PenaltyBlocks,
7428 monotonicity: SurvivalMonotonicityPenalty,
7429 spec: SurvivalSpec,
7430 structurally_monotonic: bool,
7431 structural_time_columns: usize,
7432 mode: ArrayView1<f64>,
7433 hessian: ArrayView2<f64>,
7434 ) -> Result<Self, String> {
7435 let n = age_entry.len();
7436 let off_eta_entry = offset_eta_entry
7437 .map(|v| v.to_owned())
7438 .unwrap_or_else(|| Array1::zeros(n));
7439 let off_eta_exit = offset_eta_exit
7440 .map(|v| v.to_owned())
7441 .unwrap_or_else(|| Array1::zeros(n));
7442 let off_deriv_exit = offset_derivative_exit
7443 .map(|v| v.to_owned())
7444 .unwrap_or_else(|| Array1::zeros(n));
7445
7446 let mut base_model = WorkingModelSurvival::from_engine_inputswith_offsets(
7447 SurvivalEngineInputs {
7448 age_entry,
7449 age_exit,
7450 event_target,
7451 event_competing,
7452 sampleweight,
7453 x_entry,
7454 x_exit,
7455 x_derivative,
7456 monotonicity_constraint_rows: None,
7457 monotonicity_constraint_offsets: None,
7458 },
7459 Some(gam_models::survival::SurvivalBaselineOffsets {
7460 eta_entry: off_eta_entry.view(),
7461 eta_exit: off_eta_exit.view(),
7462 derivative_exit: off_deriv_exit.view(),
7463 }),
7464 penalties,
7465 monotonicity,
7466 spec,
7467 )
7468 .map_err(|e| format!("Survival state construction failed: {:?}", e))?;
7469 if structurally_monotonic {
7470 base_model
7471 .set_structural_monotonicity(true, structural_time_columns)
7472 .map_err(|e| {
7473 format!("Failed to enable structural monotonicity in survival HMC: {e}")
7474 })?;
7475 }
7476
7477 let sampler_mode = mode.to_owned();
7478 let dim = sampler_mode.len();
7479
7480 let whitening = hessian_whitening_transform(
7481 hessian,
7482 dim,
7483 1.0,
7484 "Hessian Cholesky decomposition failed",
7485 )?;
7486 let chol = whitening.chol;
7487 let chol_t = whitening.chol_t;
7488
7489 let data = SharedSurvivalData {
7490 base_model: Arc::new(base_model),
7491 mode: Arc::new(sampler_mode),
7492 };
7493
7494 Ok(Self { data, chol, chol_t })
7495 }
7496
7497 fn compute_logp_and_grad_into(
7498 &self,
7499 z: &Array1<f64>,
7500 grad: &mut Array1<f64>,
7501 ) -> Result<f64, String> {
7502 let sampler_position = self.data.mode.as_ref() + &self.chol.dot(z);
7503 let state = self
7504 .data
7505 .base_model
7506 .update_state(&sampler_position)
7507 .map_err(|e| format!("Survival state update failed: {:?}", e))?;
7508 let logp = state.log_likelihood - state.penalty_term;
7509 let grad_beta = state.gradient.mapv(|g| -g);
7510 fast_av_into(&self.chol_t, &grad_beta, grad);
7511 Ok(logp)
7512 }
7513
7514 pub fn chol(&self) -> &Array2<f64> {
7516 &self.chol
7517 }
7518
7519 pub fn mode(&self) -> &Array1<f64> {
7521 &self.data.mode
7522 }
7523 }
7524
7525 impl HamiltonianTarget<Array1<f64>> for SurvivalPosterior {
7526 fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
7527 match self.compute_logp_and_grad_into(position, grad) {
7528 Ok(logp) => logp,
7529 Err(e) => {
7530 log::warn!("Survival posterior evaluation failed: {}", e);
7531 grad.fill(0.0);
7532 f64::NEG_INFINITY
7533 }
7534 }
7535 }
7536 }
7537
7538 pub(crate) fn run_survival_nuts_sampling(
7540 age_entry: ArrayView1<'_, f64>,
7541 age_exit: ArrayView1<'_, f64>,
7542 event_target: ArrayView1<'_, u8>,
7543 event_competing: ArrayView1<'_, u8>,
7544 sampleweight: ArrayView1<'_, f64>,
7545 x_entry: ArrayView2<'_, f64>,
7546 x_exit: ArrayView2<'_, f64>,
7547 x_derivative: ArrayView2<'_, f64>,
7548 eta_offset_entry: Option<ArrayView1<'_, f64>>,
7549 eta_offset_exit: Option<ArrayView1<'_, f64>>,
7550 derivative_offset_exit: Option<ArrayView1<'_, f64>>,
7551 penalties: PenaltyBlocks,
7552 monotonicity: SurvivalMonotonicityPenalty,
7553 spec: SurvivalSpec,
7554 structurally_monotonic: bool,
7555 structural_time_columns: usize,
7556 mode: ArrayView1<f64>,
7557 hessian: ArrayView2<f64>,
7558 config: &NutsConfig,
7559 ) -> Result<NutsResult, String> {
7560 validate_nuts_config(config).map_err(String::from)?;
7561 let target = SurvivalPosterior::new(
7563 age_entry,
7564 age_exit,
7565 event_target,
7566 event_competing,
7567 sampleweight,
7568 x_entry,
7569 x_exit,
7570 x_derivative,
7571 eta_offset_entry,
7572 eta_offset_exit,
7573 derivative_offset_exit,
7574 penalties,
7575 monotonicity,
7576 spec,
7577 structurally_monotonic,
7578 structural_time_columns,
7579 mode,
7580 hessian,
7581 )?;
7582
7583 let chol = target.chol().clone();
7585 let mode_arr = target.mode().clone();
7586 let dim = mode_arr.len();
7587
7588 let initial_positions = jittered_initial_positions(config, dim, 0.1, 0xEC2D_7A9B_4051_F638);
7589
7590 let mass_cfg = robust_survival_mass_matrix_config(dim, config.nwarmup);
7591 let (result, run_stats) = run_whitened_nuts_result(
7592 target,
7593 &mode_arr,
7594 &chol,
7595 initial_positions,
7596 config,
7597 dim,
7598 mass_cfg,
7599 0x731B_60D4_AE52_9C8F,
7600 "NUTS sampling failed",
7601 Array1::zeros(dim),
7602 NutsConvergenceThresholds {
7603 max_rhat: 1.1,
7604 min_ess: None,
7605 },
7606 )?;
7607
7608 log::info!("Survival NUTS sampling complete: {}", run_stats);
7609
7610 Ok(result)
7611 }
7612}
7613
7614pub fn run_survival_nuts_sampling_flattened<'a>(
7616 flat: SurvivalFlatInputs<'a>,
7617 penalties: gam_models::survival::PenaltyBlocks,
7618 monotonicity: gam_models::survival::SurvivalMonotonicityPenalty,
7619 spec: gam_models::survival::SurvivalSpec,
7620 structurally_monotonic: bool,
7621 structural_time_columns: usize,
7622 mode: ArrayView1<'a, f64>,
7623 hessian: ArrayView2<'a, f64>,
7624 config: &NutsConfig,
7625) -> Result<NutsResult, String> {
7626 run_nuts_sampling_flattened_family(
7627 LikelihoodSpec {
7628 response: ResponseFamily::RoystonParmar,
7629 link: InverseLink::Standard(StandardLink::Identity),
7630 },
7631 FamilyNutsInputs::Survival(Box::new(SurvivalNutsInputs {
7632 flat,
7633 penalties,
7634 monotonicity,
7635 spec,
7636 structurally_monotonic,
7637 structural_time_columns,
7638 mode,
7639 hessian,
7640 })),
7641 config,
7642 )
7643}