1use crate::gpu_polya_gamma::{PgSeed, PolyaGammaBatchInput};
27use faer::Side;
28use gam_linalg::faer_ndarray::{FaerCholesky, FaerEigh, fast_ata_into, fast_atv, fast_av_into};
29use gam_linalg::matrix::DesignMatrix;
30use gam_linalg::triangular::back_substitution_lower_transpose_guarded_into;
31use gam_models::wiggle::monotone_wiggle_basis_with_derivative_order;
32use gam_problem::types::{
33 InverseLink, LikelihoodSpec, ResponseFamily, RhoPrior, StandardLink, is_valid_tweedie_power,
34};
35use gam_solve::estimate::reml::FirthDenseOperator;
36use gam_solve::estimate::reml::penalty_logdet::PenaltyPseudologdet;
37use gam_solve::estimate::{
38 EstimationError, UnifiedFitResult, validate_explicit_dense_hessian_for_whitening,
39};
40use gam_solve::mixture_link::{
41 InverseLinkKernel, LinkParamPartials, inverse_link_jet_for_inverse_link, softmax_last_fixedzero,
42};
43use gam_terms::construction::CanonicalPenalty;
44use general_mcmc::generic_hmc::HamiltonianTarget;
45pub use general_mcmc::generic_nuts::NUTSMassMatrixConfig;
46use general_mcmc::generic_nuts::{GenericNUTS, MassMatrixAdaptation};
47use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, Axis};
48use rand::{RngExt, SeedableRng, rngs::StdRng};
49use serde::{Deserialize, Serialize};
50use std::cell::RefCell;
51use std::fmt;
52use std::sync::{Arc, Mutex};
53
54#[inline]
59fn likelihood_spec_supports_firth(spec: &LikelihoodSpec) -> bool {
60 spec.supports_firth()
61}
62
63#[inline]
66fn likelihood_spec_jeffreys_link(spec: &LikelihoodSpec) -> Option<InverseLink> {
67 if likelihood_spec_supports_firth(spec) {
68 Some(spec.link.clone())
69 } else {
70 None
71 }
72}
73
/// Errors raised while configuring or evaluating the NUTS/HMC posterior.
///
/// Every variant carries a human-readable `reason`, which is exactly what the
/// `Display` impl prints and what the `String` conversion yields.
#[derive(Debug, Clone)]
pub enum HmcError {
    /// An input (penalty, Hessian, mode, offset, ...) contains NaN or Inf.
    NonFiniteState { reason: String },
    /// A configuration value or response is invalid for the requested family.
    InvalidConfig { reason: String },
    /// Array lengths or shapes disagree with one another.
    DimensionMismatch { reason: String },
    /// Firth correction was requested for a likelihood that cannot support it.
    FirthUnsupported { reason: String },
    /// Link-related mismatch (not constructed in the visible code; presumably
    /// used by the joint link-parameter sampler — TODO confirm).
    LinkMismatch { reason: String },
    /// The response family / link combination has no HMC implementation.
    UnsupportedFamily { reason: String },
    /// A numerical component failed while evaluating the target.
    SamplingFailed { reason: String },
}

impl fmt::Display for HmcError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            HmcError::NonFiniteState { reason }
            | HmcError::InvalidConfig { reason }
            | HmcError::DimensionMismatch { reason }
            | HmcError::FirthUnsupported { reason }
            | HmcError::LinkMismatch { reason }
            | HmcError::UnsupportedFamily { reason }
            | HmcError::SamplingFailed { reason } => f.write_str(reason),
        }
    }
}

// Lets `HmcError` interoperate with `?`, `Box<dyn Error>` and error-reporting
// crates; `Debug` + `Display` already provide everything the trait needs.
impl std::error::Error for HmcError {}

/// Collapses an [`HmcError`] into its message, for APIs that report `String`
/// errors (used via `.map_err(String::from)` and `.into()`).
impl From<HmcError> for String {
    fn from(err: HmcError) -> String {
        err.to_string()
    }
}
122
/// Hard cap on the number of autocorrelation lags summed by the ESS estimator,
/// bounding its cost for very long chains.
const MAX_AUTOCORRELATION_LAG: usize = 1_000;

/// Lower bound applied to each split chain's lag-0 autocovariance so the
/// normalised autocorrelation never divides by zero.
const AUTOCOVARIANCE_FLOOR: f64 = 1.0e-16;
133
134fn compute_split_rhat_and_ess(samples: &Array3<f64>) -> (f64, f64) {
139 let n_chains = samples.shape()[0];
140 let n_samples = samples.shape()[1];
141 let dim = samples.shape()[2];
142
143 if n_chains < 2 || n_samples < 4 {
144 return (1.0, n_chains as f64 * n_samples as f64 * 0.5);
145 }
146
147 let half = n_samples / 2;
149 let n_split_chains = n_chains * 2;
150 let n_split_samples = half;
151
152 let mut max_rhat = 0.0f64;
153 let mut min_ess = f64::INFINITY;
154
155 #[inline]
156 fn splitvalue(
157 samples: &Array3<f64>,
158 n_chains: usize,
159 half: usize,
160 dim: usize,
161 sc: usize,
162 t: usize,
163 ) -> f64 {
164 let chain = sc % n_chains;
165 if sc < n_chains {
166 samples[[chain, t, dim]]
167 } else {
168 samples[[chain, half + t, dim]]
169 }
170 }
171
172 fn ess_from_split_dimension(
173 samples: &Array3<f64>,
174 n_chains: usize,
175 half: usize,
176 dim: usize,
177 ) -> f64 {
178 let m = n_chains * 2;
179 let n = half;
180 if m == 0 || n < 4 {
181 return (m * n).max(1) as f64;
182 }
183
184 let mut means = vec![0.0_f64; m];
185 let mut gamma0 = vec![0.0_f64; m];
186 for sc in 0..m {
187 let mut sum = 0.0;
188 for t in 0..n {
189 sum += splitvalue(samples, n_chains, half, dim, sc, t);
190 }
191 let mean = sum / n as f64;
192 means[sc] = mean;
193 let mut g0 = 0.0;
194 for t in 0..n {
195 let d = splitvalue(samples, n_chains, half, dim, sc, t) - mean;
196 g0 += d * d;
197 }
198 gamma0[sc] = (g0 / n as f64).max(AUTOCOVARIANCE_FLOOR);
199 }
200
201 let max_lag = (n - 1).min(MAX_AUTOCORRELATION_LAG);
202 let mut tau = 1.0_f64;
203 let mut lag = 1usize;
204 while lag < max_lag {
205 let mut pair = 0.0_f64;
206 for l in [lag, lag + 1] {
207 if l > max_lag {
208 continue;
209 }
210 let mut rho_l = 0.0;
211 for sc in 0..m {
212 let mu = means[sc];
213 let mut cov = 0.0;
214 let denom = (n - l) as f64;
215 for t in 0..(n - l) {
216 let x0 = splitvalue(samples, n_chains, half, dim, sc, t);
217 let x1 = splitvalue(samples, n_chains, half, dim, sc, t + l);
218 cov += (x0 - mu) * (x1 - mu);
219 }
220 cov /= denom;
221 rho_l += cov / gamma0[sc];
222 }
223 rho_l /= m as f64;
224 pair += rho_l;
225 }
226 if !pair.is_finite() || pair <= 0.0 {
227 break;
228 }
229 tau += 2.0 * pair;
230 lag += 2;
231 }
232 if !tau.is_finite() || tau <= 0.0 {
233 return 1.0;
234 }
235 let total = (m * n) as f64;
236 (total / tau).clamp(1.0, total)
237 }
238
239 let mut chain_means = vec![0.0_f64; n_split_chains];
240 let mut chainvars = vec![0.0_f64; n_split_chains];
241 for d in 0..dim {
242 for chain in 0..n_chains {
243 let mut sum1 = 0.0;
245 for i in 0..half {
246 sum1 += samples[[chain, i, d]];
247 }
248 let mean1 = sum1 / half as f64;
249 let mut var1 = 0.0;
250 for i in 0..half {
251 let diff = samples[[chain, i, d]] - mean1;
252 var1 += diff * diff;
253 }
254 var1 /= (half - 1).max(1) as f64;
255 let first_idx = chain;
256 chain_means[first_idx] = mean1;
257 chainvars[first_idx] = var1;
258
259 let mut sum2 = 0.0;
261 for i in half..(2 * half) {
262 sum2 += samples[[chain, i, d]];
263 }
264 let mean2 = sum2 / half as f64;
265 let mut var2 = 0.0;
266 for i in half..(2 * half) {
267 let diff = samples[[chain, i, d]] - mean2;
268 var2 += diff * diff;
269 }
270 var2 /= (half - 1).max(1) as f64;
271 let second_idx = n_chains + chain;
272 chain_means[second_idx] = mean2;
273 chainvars[second_idx] = var2;
274 }
275
276 let w: f64 = chainvars.iter().copied().sum::<f64>() / n_split_chains as f64;
278
279 let overall_mean: f64 = chain_means.iter().copied().sum::<f64>() / n_split_chains as f64;
281 let b: f64 = chain_means
282 .iter()
283 .map(|m| (m - overall_mean).powi(2))
284 .sum::<f64>()
285 * n_split_samples as f64
286 / (n_split_chains - 1) as f64;
287
288 let var_hat = (n_split_samples as f64 - 1.0) / n_split_samples as f64 * w
290 + b / n_split_samples as f64;
291
292 let rhat_d = if w > 1e-10 { (var_hat / w).sqrt() } else { 1.0 };
294 max_rhat = max_rhat.max(rhat_d);
295
296 let ess_d = ess_from_split_dimension(samples, n_chains, half, d);
298 min_ess = min_ess.min(ess_d);
299 }
300
301 (max_rhat, min_ess.max(1.0))
302}
303
304fn solve_upper_triangular_transpose(l: &Array2<f64>, dim: usize) -> Array2<f64> {
325 let mut result = Array2::<f64>::zeros((dim, dim));
326 if dim == 0 {
327 return result;
328 }
329
330 let l_owned;
334 let l_rows: &[f64] = if let Some(s) = l.as_slice() {
335 s
336 } else {
337 l_owned = l.to_owned();
338 l_owned
339 .as_slice()
340 .expect("owned standard-layout Array2 has contiguous storage")
341 };
342
343 let mut y = vec![0.0_f64; dim];
345
346 for col in 0..dim {
347 let d_col = l_rows[col * dim + col];
350 let inv_d_col = if d_col.abs() > 1e-15 {
351 1.0 / d_col
352 } else {
353 0.0
354 };
355 y[col] = inv_d_col;
356
357 for i in (col + 1)..dim {
360 let row_off = i * dim;
361 let l_row = &l_rows[row_off + col..row_off + i];
362 let y_seg = &y[col..i];
363 let mut sum = 0.0_f64;
367 for k in 0..l_row.len() {
368 sum += l_row[k] * y_seg[k];
369 }
370 let d = l_rows[row_off + i];
371 y[i] = if d.abs() > 1e-15 { -sum / d } else { 0.0 };
372 }
373
374 let res_row_start = col * dim + col;
378 let res_row = &mut result.as_slice_mut().expect("owned Array2 contiguous")
379 [res_row_start..res_row_start + (dim - col)];
380 for (k, slot) in res_row.iter_mut().enumerate() {
381 *slot = y[col + k];
382 }
383
384 for slot in &mut y[col..dim] {
386 *slot = 0.0;
387 }
388 }
389
390 result
391}
392
393struct WhiteningTransform {
394 chol: Array2<f64>,
395 chol_t: Array2<f64>,
396}
397
398fn hessian_whitening_transform(
399 hessian: ArrayView2<f64>,
400 dim: usize,
401 cov_scale: f64,
402 cholesky_error_prefix: &str,
403) -> Result<WhiteningTransform, String> {
404 let hessian_owned = hessian.to_owned();
405 let chol_factor = hessian_owned
406 .cholesky(Side::Lower)
407 .map_err(|e| format!("{cholesky_error_prefix}: {:?}", e))?;
408 let l_h = chol_factor.lower_triangular();
409 let mut chol = solve_upper_triangular_transpose(&l_h, dim);
410 let sqrt_cov_scale = cov_scale.max(0.0).sqrt();
411 if (sqrt_cov_scale - 1.0).abs() > 0.0 {
412 chol.mapv_inplace(|v| v * sqrt_cov_scale);
413 }
414 let chol_t = chol.t().to_owned();
415 Ok(WhiteningTransform { chol, chol_t })
416}
417
418#[derive(Clone)]
423struct SharedData {
424 x: Arc<Array2<f64>>,
426 y: Arc<Array1<f64>>,
428 weights: Arc<Array1<f64>>,
430 mode: Arc<Array1<f64>>,
432 offset: Option<Arc<Array1<f64>>>,
440 gamma_shape: f64,
442 dispersion: gam_solve::model_types::Dispersion,
453 n_samples: usize,
455 dim: usize,
457}
458
459thread_local! {
460 static NUTS_RESIDUAL_SCRATCH: RefCell<Array1<f64>> = RefCell::new(Array1::zeros(0));
461}
462
/// Response-family / link combinations the NUTS posterior can sample.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NutsFamily {
    /// Gaussian response with identity link.
    Gaussian,
    /// Bernoulli response with logit link.
    BinomialLogit,
    /// Bernoulli response with probit link.
    BinomialProbit,
    /// Bernoulli response with complementary log-log link.
    BinomialCLogLog,
    /// Poisson response with log link.
    PoissonLog,
    /// Tweedie response (quasi-likelihood) with log link.
    TweedieLog,
    /// Negative-binomial response with log link.
    NegativeBinomialLog,
    /// Gamma response with log link.
    GammaLog,
}
479
480impl NutsFamily {
481 #[inline]
482 fn likelihood_spec(self) -> LikelihoodSpec {
483 match self {
484 Self::Gaussian => LikelihoodSpec {
485 response: ResponseFamily::Gaussian,
486 link: InverseLink::Standard(StandardLink::Identity),
487 },
488 Self::BinomialLogit => LikelihoodSpec {
489 response: ResponseFamily::Binomial,
490 link: InverseLink::Standard(StandardLink::Logit),
491 },
492 Self::BinomialProbit => LikelihoodSpec {
493 response: ResponseFamily::Binomial,
494 link: InverseLink::Standard(StandardLink::Probit),
495 },
496 Self::BinomialCLogLog => LikelihoodSpec {
497 response: ResponseFamily::Binomial,
498 link: InverseLink::Standard(StandardLink::CLogLog),
499 },
500 Self::PoissonLog => LikelihoodSpec {
501 response: ResponseFamily::Poisson,
502 link: InverseLink::Standard(StandardLink::Log),
503 },
504 Self::TweedieLog => LikelihoodSpec {
505 response: ResponseFamily::Tweedie { p: 1.5 },
506 link: InverseLink::Standard(StandardLink::Log),
507 },
508 Self::NegativeBinomialLog => LikelihoodSpec {
509 response: ResponseFamily::NegativeBinomial {
510 theta: 1.0,
511 theta_fixed: false,
512 },
513 link: InverseLink::Standard(StandardLink::Log),
514 },
515 Self::GammaLog => LikelihoodSpec {
516 response: ResponseFamily::Gamma,
517 link: InverseLink::Standard(StandardLink::Log),
518 },
519 }
520 }
521
522 #[inline]
552 fn coefficient_covariance_scale(self, profiled_gaussian_phi: f64) -> f64 {
553 match self {
554 NutsFamily::Gaussian => profiled_gaussian_phi,
555 _ => 1.0,
556 }
557 }
558}
559
560pub struct NutsPosterior {
573 data: SharedData,
575 chol: Array2<f64>,
578 chol_t: Array2<f64>,
580 nuts_family: NutsFamily,
582 firth_enabled: bool,
585 penalty_z_quad: Array2<f64>,
591 penalty_z_lin: Array1<f64>,
594 penalty_z_const: f64,
596 cov_scale: f64,
602}
603
604impl NutsPosterior {
605 pub fn new(
620 x: ArrayView2<f64>,
621 y: ArrayView1<f64>,
622 weights: ArrayView1<f64>,
623 penalty_matrix: ArrayView2<f64>,
624 mode: ArrayView1<f64>,
625 hessian: ArrayView2<f64>,
626 nuts_family: NutsFamily,
627 gamma_shape: f64,
628 dispersion: gam_solve::model_types::Dispersion,
629 firth_enabled: bool,
630 ) -> Result<Self, String> {
631 let n_samples = x.nrows();
632 let dim = x.ncols();
633
634 if !penalty_matrix.iter().all(|x| x.is_finite()) {
636 return Err(HmcError::NonFiniteState {
637 reason: "Penalty matrix contains NaN or Inf values".to_string(),
638 }
639 .into());
640 }
641 if !hessian.iter().all(|x| x.is_finite()) {
642 return Err(HmcError::NonFiniteState {
643 reason: "Hessian matrix contains NaN or Inf values".to_string(),
644 }
645 .into());
646 }
647 if !mode.iter().all(|x| x.is_finite()) {
648 return Err(HmcError::NonFiniteState {
649 reason: "Mode vector contains NaN or Inf values".to_string(),
650 }
651 .into());
652 }
653
654 validate_firth_support(nuts_family, firth_enabled).map_err(String::from)?;
655 if nuts_family.likelihood_spec().is_binomial() {
656 validate_binary_responses("binomial NUTS", &y, &weights).map_err(String::from)?;
657 }
658 if matches!(nuts_family, NutsFamily::NegativeBinomialLog) {
659 validate_count_responses("negative-binomial NUTS", &y, &weights)
660 .map_err(String::from)?;
661 }
662
663 let cov_scale = nuts_family.coefficient_covariance_scale(dispersion.phi());
672 let whitening = hessian_whitening_transform(
673 hessian,
674 dim,
675 cov_scale,
676 "Hessian Cholesky decomposition failed",
677 )?;
678 let chol = whitening.chol;
679 let chol_t = whitening.chol_t;
680
681 let penalty_owned = penalty_matrix.to_owned();
689 let mode_owned = mode.to_owned();
690 let s_mu = penalty_owned.dot(&mode_owned);
691 let penalty_z_const = 0.5 * mode_owned.dot(&s_mu);
692 let penalty_z_lin = chol_t.dot(&s_mu);
693 let s_chol = penalty_owned.dot(&chol);
696 let penalty_z_quad = chol_t.dot(&s_chol);
697
698 let data = SharedData {
699 x: Arc::new(x.to_owned()),
700 y: Arc::new(y.to_owned()),
701 weights: Arc::new(weights.to_owned()),
702 mode: Arc::new(mode_owned),
703 offset: None,
704 gamma_shape,
705 dispersion,
706 n_samples,
707 dim,
708 };
709
710 Ok(Self {
711 data,
712 chol,
713 chol_t,
714 nuts_family,
715 firth_enabled,
716 penalty_z_quad,
717 penalty_z_lin,
718 penalty_z_const,
719 cov_scale,
720 })
721 }
722
723 fn with_offset(mut self, offset: ArrayView1<f64>) -> Result<Self, String> {
733 if offset.len() != self.data.n_samples {
734 return Err(HmcError::DimensionMismatch {
735 reason: format!(
736 "NUTS offset length {} does not match {} observations",
737 offset.len(),
738 self.data.n_samples
739 ),
740 }
741 .into());
742 }
743 if !offset.iter().all(|v| v.is_finite()) {
744 return Err(HmcError::NonFiniteState {
745 reason: "NUTS offset contains NaN or Inf values".to_string(),
746 }
747 .into());
748 }
749 self.data.offset = Some(Arc::new(offset.to_owned()));
750 Ok(self)
751 }
752
753 fn compute_logp_and_grad_nd_into(
754 &self,
755 z: &Array1<f64>,
756 residual: &mut Array1<f64>,
757 grad: &mut Array1<f64>,
758 ) -> f64 {
759 let beta = self.data.mode.as_ref() + &self.chol.dot(z);
762
763 let mut eta = gam_linalg::faer_ndarray::fast_av(self.data.x.as_ref(), &beta);
765 if let Some(offset) = self.data.offset.as_ref() {
766 eta += offset.as_ref();
767 }
768
769 let (ll, mut grad_ll_beta) = self.family_logp_and_grad_into(&eta, residual);
771
772 let mut firth_logdet = 0.0;
773 if self.firth_enabled {
774 match firth_jeffreys_logp_and_grad(self.nuts_family, &self.data, &eta) {
775 Ok((value, grad_beta_firth)) => {
776 firth_logdet = value;
777 grad_ll_beta += &grad_beta_firth;
778 }
779 Err(err) => {
780 log::warn!(
781 "[NUTS/Firth] Jeffreys target became invalid at the current state: {}",
782 err
783 );
784 grad.fill(0.0);
785 return f64::NEG_INFINITY;
786 }
787 }
788 }
789
790 let penalty_scale = 1.0 / self.cov_scale.max(1e-300);
814 let mz = self.penalty_z_quad.dot(z);
815 let lin_term = self.penalty_z_lin.dot(z);
816 let quad_term = 0.5 * z.dot(&mz);
817 let penalty = penalty_scale * (self.penalty_z_const + lin_term + quad_term);
818
819 fast_av_into(&self.chol_t, &grad_ll_beta, grad);
822 let lin_view = self.penalty_z_lin.view();
824 ndarray::Zip::from(grad)
825 .and(&lin_view)
826 .and(&mz)
827 .par_for_each(|g, &l, &m| {
828 *g -= penalty_scale * (l + m);
829 });
830
831 ll + firth_logdet - penalty
832 }
833
834 fn family_logp_and_grad_into(
835 &self,
836 eta: &Array1<f64>,
837 residual: &mut Array1<f64>,
838 ) -> (f64, Array1<f64>) {
839 nuts_family_logp_and_grad_into(self.nuts_family, &self.data, eta, residual)
840 }
841
842 pub fn chol(&self) -> &Array2<f64> {
844 &self.chol
845 }
846
847 pub fn mode(&self) -> &Array1<f64> {
849 &self.data.mode
850 }
851
852 pub fn dim(&self) -> usize {
854 self.data.dim
855 }
856}
857
/// `½·ln(2π)`, the log normalising constant of the standard normal density.
const HALF_LOG_2PI: f64 = 0.918_938_533_204_672_7;

/// Log-density of the standard normal distribution at `x`.
#[inline]
fn standard_normal_log_pdf(x: f64) -> f64 {
    let half_square = 0.5 * x * x;
    -half_square - HALF_LOG_2PI
}
864
865#[inline]
867fn log_ndtr(x: f64) -> f64 {
868 let arg = -x * std::f64::consts::FRAC_1_SQRT_2;
869 let erfc_val = statrs::function::erf::erfc(arg);
870 if erfc_val > 0.0 {
871 erfc_val.ln() - std::f64::consts::LN_2
872 } else {
873 -0.5 * x * x - (-x).ln() - HALF_LOG_2PI
874 }
875}
876
877#[inline]
878fn validate_firth_support(family: NutsFamily, firth_enabled: bool) -> Result<(), HmcError> {
879 let spec = family.likelihood_spec();
880 if firth_enabled && !likelihood_spec_supports_firth(&spec) {
881 return Err(HmcError::FirthUnsupported {
882 reason: format!(
883 "NUTS with Firth requires a Binomial inverse link with a Fisher-weight jet; {} does not support it",
884 spec.pretty_name()
885 ),
886 });
887 }
888 Ok::<(), _>(())
889}
890
891#[inline]
892fn validate_firth_likelihood_support(
893 likelihood: &LikelihoodSpec,
894 firth_enabled: bool,
895) -> Result<(), HmcError> {
896 if firth_enabled && !likelihood_spec_supports_firth(likelihood) {
897 return Err(HmcError::FirthUnsupported {
898 reason: format!(
899 "Joint HMC with Firth requires a Binomial inverse link with a Fisher-weight jet; {} does not support it",
900 likelihood.pretty_name()
901 ),
902 });
903 }
904 Ok::<(), _>(())
905}
906
/// True when `y` is a finite, non-negative value within `1e-9` of an integer.
#[inline]
fn valid_count_response(y: f64) -> bool {
    if !y.is_finite() || y < 0.0 {
        return false;
    }
    let distance_to_integer = (y - y.round()).abs();
    distance_to_integer <= 1e-9
}
911
912fn validate_count_responses(
913 family: &str,
914 y: &ArrayView1<'_, f64>,
915 weights: &ArrayView1<'_, f64>,
916) -> Result<(), HmcError> {
917 for (i, (&yi, &wi)) in y.iter().zip(weights.iter()).enumerate() {
918 if wi > 0.0 && !valid_count_response(yi) {
919 return Err(HmcError::InvalidConfig {
920 reason: format!(
921 "{family} response must be a finite non-negative integer at positive-weight row {i}; got {yi}"
922 ),
923 });
924 }
925 }
926 Ok(())
927}
928
929fn validate_binary_responses(
930 family: &str,
931 y: &ArrayView1<'_, f64>,
932 weights: &ArrayView1<'_, f64>,
933) -> Result<(), HmcError> {
934 for (i, (&yi, &wi)) in y.iter().zip(weights.iter()).enumerate() {
935 if wi > 0.0 && !(yi == 0.0 || yi == 1.0) {
936 return Err(HmcError::InvalidConfig {
937 reason: format!(
938 "{family} response must be exactly 0 or 1 at positive-weight row {i}; got {yi}"
939 ),
940 });
941 }
942 }
943 Ok(())
944}
945
946fn firth_jeffreys_logp_and_grad(
953 family: NutsFamily,
954 data: &SharedData,
955 eta: &Array1<f64>,
956) -> Result<(f64, Array1<f64>), HmcError> {
957 if eta.len() != data.n_samples {
958 return Err(HmcError::DimensionMismatch {
959 reason: format!(
960 "Firth Jeffreys term eta length {} != number of samples {}",
961 eta.len(),
962 data.n_samples
963 ),
964 });
965 }
966 if data.dim == 0 || data.n_samples == 0 {
967 return Ok((0.0, Array1::zeros(data.dim)));
968 }
969 validate_firth_support(family, true)?;
970 if data.weights.iter().all(|w| *w == 0.0) {
971 return Ok((0.0, Array1::zeros(data.dim)));
972 }
973
974 let jeffreys_link =
975 likelihood_spec_jeffreys_link(&family.likelihood_spec()).ok_or_else(|| {
976 HmcError::FirthUnsupported {
977 reason: format!(
978 "Firth Jeffreys term has no Fisher-weight jet for {}",
979 family.likelihood_spec().pretty_name()
980 ),
981 }
982 })?;
983 let op = if data.weights.iter().all(|&w| w == 1.0) {
984 FirthDenseOperator::build_for_link(&jeffreys_link, data.x.as_ref(), eta)
985 } else {
986 FirthDenseOperator::build_with_observation_weights_for_link(
987 &jeffreys_link,
988 data.x.as_ref(),
989 eta,
990 data.weights.view(),
991 )
992 }
993 .map_err(|e| HmcError::SamplingFailed {
994 reason: format!("Firth Jeffreys operator failed: {e}"),
995 })?;
996 Ok(op.jeffreys_logdet_and_beta_gradient())
997}
998
999fn nuts_family_logp_and_grad_into(
1008 family: NutsFamily,
1009 data: &SharedData,
1010 eta: &Array1<f64>,
1011 residual: &mut Array1<f64>,
1012) -> (f64, Array1<f64>) {
1013 match family {
1014 NutsFamily::BinomialLogit => logit_logp_and_grad_into(data, eta, residual),
1015 NutsFamily::BinomialProbit => probit_logp_and_grad_into(data, eta, residual),
1016 NutsFamily::BinomialCLogLog => cloglog_logp_and_grad_into(data, eta, residual),
1017 NutsFamily::Gaussian => gaussian_logp_and_grad_into(data, eta, residual),
1018 NutsFamily::PoissonLog => poisson_log_logp_and_grad(data, eta),
1019 NutsFamily::TweedieLog => tweedie_log_quasilogp_and_grad(data, eta, data.gamma_shape),
1022 NutsFamily::NegativeBinomialLog => {
1023 negative_binomial_log_logp_and_grad(data, eta, data.gamma_shape)
1026 }
1027 NutsFamily::GammaLog => gamma_log_logp_and_grad(data, eta),
1028 }
1029}
1030
/// Per-observation Bernoulli log-probability terms under a binomial inverse link.
#[derive(Clone, Debug)]
struct BinomialLinkTerms {
    /// `ln μ` (`-∞` when `μ = 0`).
    log_mu: f64,
    /// `ln(1 − μ)` (`-∞` when `μ = 1`).
    log1m_mu: f64,
    /// `d ln μ / dη = μ'/μ` (signed infinity when `μ = 0`).
    dlog_mu_deta: f64,
    /// `d ln(1 − μ) / dη = −μ'/(1 − μ)` (signed infinity when `μ = 1`).
    dlog1m_mu_deta: f64,
    /// `∂μ/∂θ_k` for each adaptive link parameter `θ_k`.
    dmu_dlink: Vec<f64>,
}

/// Builds [`BinomialLinkTerms`] from `μ`, `dμ/dη` and link-parameter partials.
///
/// At the saturated ends the log-derivatives are the signed limits:
/// * `μ → 0`: `dμ/dη / μ` takes the sign of `dμ/dη`.
/// * `μ → 1`: `−dμ/dη / (1 − μ)` takes the sign of `−dμ/dη`. Previously
///   `NEG_INFINITY.copysign(dμ/dη)` returned `+∞` for increasing links.
///
/// # Errors
/// Returns a message when `μ` is not a finite value in `[0, 1]` or `dμ/dη` is
/// not finite.
#[inline]
fn log_terms_from_mu_and_dmu(
    mu: f64,
    dmu_deta: f64,
    dmu_dlink: Vec<f64>,
) -> Result<BinomialLinkTerms, String> {
    if !(mu.is_finite() && (0.0..=1.0).contains(&mu) && dmu_deta.is_finite()) {
        return Err(format!(
            "binomial inverse link returned invalid mu/deta derivative: mu={mu}, dmu_deta={dmu_deta}"
        ));
    }
    let one_minus_mu = 1.0 - mu;

    let (log_mu, dlog_mu_deta) = if mu == 0.0 {
        (f64::NEG_INFINITY, f64::INFINITY.copysign(dmu_deta))
    } else {
        (mu.ln(), dmu_deta / mu)
    };

    let (log1m_mu, dlog1m_mu_deta) = if one_minus_mu == 0.0 {
        // Limit of −μ'/(1 − μ) as 1 − μ → 0⁺ carries the sign of −μ'.
        (f64::NEG_INFINITY, f64::INFINITY.copysign(-dmu_deta))
    } else {
        (one_minus_mu.ln(), -dmu_deta / one_minus_mu)
    };

    Ok(BinomialLinkTerms {
        log_mu,
        log1m_mu,
        dlog_mu_deta,
        dlog1m_mu_deta,
        dmu_dlink,
    })
}
1080
1081#[inline]
1082fn binomial_link_terms(
1083 inverse_link: &InverseLink,
1084 eta: f64,
1085 n_link_params: usize,
1086) -> Result<BinomialLinkTerms, String> {
1087 let jet =
1088 inverse_link_jet_for_inverse_link(inverse_link, eta).map_err(|err| err.to_string())?;
1089 let mut dmu_dlink = vec![0.0; n_link_params];
1090 if n_link_params > 0 {
1091 match inverse_link
1092 .param_partials(eta)
1093 .map_err(|err| err.to_string())?
1094 {
1095 Some(LinkParamPartials::Sas(partials)) => {
1096 if n_link_params != 2 {
1097 return Err(format!(
1098 "SAS/Beta-Logistic link parameter dimension mismatch: expected 2, got {n_link_params}"
1099 ));
1100 }
1101 dmu_dlink[0] = partials.djet_depsilon.mu;
1102 dmu_dlink[1] = partials.djet_dlog_delta.mu;
1103 }
1104 Some(LinkParamPartials::Mixture(partials)) => {
1105 if partials.djet_drho.len() != n_link_params {
1106 return Err(format!(
1107 "mixture link parameter dimension mismatch: expected {}, got {n_link_params}",
1108 partials.djet_drho.len()
1109 ));
1110 }
1111 for (slot, partial) in dmu_dlink.iter_mut().zip(partials.djet_drho.iter()) {
1112 *slot = partial.mu;
1113 }
1114 }
1115 None => {
1116 return Err(format!(
1117 "joint HMC expected {n_link_params} adaptive link parameters, but the inverse link exposes none"
1118 ));
1119 }
1120 }
1121 }
1122 log_terms_from_mu_and_dmu(jet.mu, jet.d1, dmu_dlink)
1123}
1124
1125fn joint_binomial_logp_grad_and_link_grad(
1126 inverse_link: &InverseLink,
1127 data: &SharedData,
1128 eta: &Array1<f64>,
1129 n_link_params: usize,
1130) -> Result<(f64, Array1<f64>, Array1<f64>), String> {
1131 let n = data.n_samples;
1132 use rayon::iter::{IntoParallelIterator, ParallelIterator};
1136 let per_row: Result<Vec<(f64, f64, Vec<f64>)>, String> = (0..n)
1137 .into_par_iter()
1138 .map(|i| {
1139 let y_i = data.y[i];
1140 let w_i = data.weights[i];
1141 if w_i <= 0.0 {
1142 return Ok((0.0, 0.0, vec![0.0; n_link_params]));
1143 }
1144 let terms = binomial_link_terms(inverse_link, eta[i], n_link_params)?;
1145 if y_i == 1.0 {
1146 let inv_mu = terms.log_mu.exp().recip();
1147 let log_mu = terms.log_mu;
1148 let dlog_mu_deta = terms.dlog_mu_deta;
1149 let grad_link = terms
1150 .dmu_dlink
1151 .into_iter()
1152 .map(|dmu| w_i * dmu * inv_mu)
1153 .collect();
1154 Ok((w_i * log_mu, w_i * dlog_mu_deta, grad_link))
1155 } else if y_i == 0.0 {
1156 let inv_one_minus_mu = terms.log1m_mu.exp().recip();
1157 let log1m_mu = terms.log1m_mu;
1158 let dlog1m_mu_deta = terms.dlog1m_mu_deta;
1159 let grad_link = terms
1160 .dmu_dlink
1161 .into_iter()
1162 .map(|dmu| -w_i * dmu * inv_one_minus_mu)
1163 .collect();
1164 Ok((w_i * log1m_mu, w_i * dlog1m_mu_deta, grad_link))
1165 } else {
1166 Err(format!(
1167 "binomial joint HMC response must be exactly 0 or 1 after validation; got {y_i}"
1168 ))
1169 }
1170 })
1171 .collect();
1172 let per_row = per_row?;
1173 let mut residual = Array1::<f64>::zeros(n);
1174 let mut grad_link = Array1::<f64>::zeros(n_link_params);
1175 let mut ll = 0.0;
1176 for (i, (ll_i, residual_i, grad_link_i)) in per_row.into_iter().enumerate() {
1177 ll += ll_i;
1178 residual[i] = residual_i;
1179 for (slot, value) in grad_link.iter_mut().zip(grad_link_i.iter()) {
1180 *slot += *value;
1181 }
1182 }
1183
1184 Ok((ll, fast_atv(&data.x, &residual), grad_link))
1185}
1186
1187fn joint_binomial_logp_and_grad(
1188 likelihood: &LikelihoodSpec,
1189 data: &SharedData,
1190 eta: &Array1<f64>,
1191) -> Result<(f64, Array1<f64>), String> {
1192 if !matches!(likelihood.response, ResponseFamily::Binomial) {
1193 return Err(HmcError::UnsupportedFamily {
1194 reason: format!(
1195 "{} is not a binomial joint-HMC family",
1196 likelihood.pretty_name()
1197 ),
1198 }
1199 .into());
1200 }
1201 match &likelihood.link {
1202 InverseLink::Standard(StandardLink::Logit) => Ok(logit_logp_and_grad(data, eta)),
1203 InverseLink::Standard(StandardLink::Probit) => Ok(probit_logp_and_grad(data, eta)),
1204 InverseLink::Standard(StandardLink::CLogLog) => Ok(cloglog_logp_and_grad(data, eta)),
1205 InverseLink::LatentCLogLog(_)
1206 | InverseLink::Sas(_)
1207 | InverseLink::BetaLogistic(_)
1208 | InverseLink::Mixture(_) => {
1209 let (ll, grad_beta, _) =
1210 joint_binomial_logp_grad_and_link_grad(&likelihood.link, data, eta, 0)?;
1211 Ok((ll, grad_beta))
1212 }
1213 InverseLink::Standard(_) => Err(HmcError::UnsupportedFamily {
1214 reason: format!(
1215 "{} is not a binomial joint-HMC family",
1216 likelihood.pretty_name()
1217 ),
1218 }
1219 .into()),
1220 }
1221}
1222
1223fn joint_family_logp_grad_and_link_grad(
1224 likelihood: &LikelihoodSpec,
1225 data: &SharedData,
1226 eta: &Array1<f64>,
1227 n_link_params: usize,
1228) -> Result<(f64, Array1<f64>, Array1<f64>), String> {
1229 match (&likelihood.response, &likelihood.link) {
1230 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
1231 let (ll, grad) = logit_logp_and_grad(data, eta);
1232 Ok((ll, grad, Array1::zeros(n_link_params)))
1233 }
1234 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Probit)) => {
1235 let (ll, grad) = probit_logp_and_grad(data, eta);
1236 Ok((ll, grad, Array1::zeros(n_link_params)))
1237 }
1238 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::CLogLog)) => {
1239 let (ll, grad) = cloglog_logp_and_grad(data, eta);
1240 Ok((ll, grad, Array1::zeros(n_link_params)))
1241 }
1242 (
1243 ResponseFamily::Binomial,
1244 InverseLink::LatentCLogLog(_)
1245 | InverseLink::Sas(_)
1246 | InverseLink::BetaLogistic(_)
1247 | InverseLink::Mixture(_),
1248 ) => joint_binomial_logp_grad_and_link_grad(&likelihood.link, data, eta, n_link_params),
1249 _ => {
1250 let (ll, grad) = joint_family_logp_and_grad(likelihood, data, eta)?;
1251 Ok((ll, grad, Array1::zeros(n_link_params)))
1252 }
1253 }
1254}
1255
1256fn joint_family_logp_and_grad(
1257 likelihood: &LikelihoodSpec,
1258 data: &SharedData,
1259 eta: &Array1<f64>,
1260) -> Result<(f64, Array1<f64>), String> {
1261 match &likelihood.response {
1262 ResponseFamily::Binomial => joint_binomial_logp_and_grad(likelihood, data, eta),
1263 ResponseFamily::Gaussian => Ok(gaussian_logp_and_grad(data, eta)),
1264 ResponseFamily::Poisson => Ok(poisson_log_logp_and_grad(data, eta)),
1265 ResponseFamily::Tweedie { p } => {
1266 let p = *p;
1269 if !is_valid_tweedie_power(p) {
1270 return Err(HmcError::InvalidConfig {
1271 reason: format!(
1272 "Tweedie variance power must be finite and strictly between 1 and 2; got {p}"
1273 ),
1274 }
1275 .into());
1276 }
1277 Ok(tweedie_log_quasilogp_and_grad(data, eta, p))
1278 }
1279 ResponseFamily::NegativeBinomial { theta, .. } => {
1280 Ok(negative_binomial_log_logp_and_grad(data, eta, *theta))
1283 }
1284 ResponseFamily::Beta { .. } => Err(HmcError::UnsupportedFamily {
1285 reason: "Joint HMC fallback is not implemented for BetaLogit".to_string(),
1286 }
1287 .into()),
1288 ResponseFamily::Gamma => Ok(gamma_log_logp_and_grad(data, eta)),
1289 ResponseFamily::RoystonParmar => Err(HmcError::UnsupportedFamily {
1290 reason: "Joint HMC fallback is not implemented for RoystonParmar".to_string(),
1291 }
1292 .into()),
1293 }
1294}
1295
1296fn logit_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1300 let mut residual = Array1::<f64>::zeros(data.n_samples);
1301 logit_logp_and_grad_into(data, eta, &mut residual)
1302}
1303
1304fn logit_logp_and_grad_into(
1305 data: &SharedData,
1306 eta: &Array1<f64>,
1307 residual: &mut Array1<f64>,
1308) -> (f64, Array1<f64>) {
1309 use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1310 let n = data.n_samples;
1311 assert_eq!(residual.len(), n);
1312 let ll: f64 = residual
1316 .as_slice_mut()
1317 .unwrap()
1318 .par_iter_mut()
1319 .enumerate()
1320 .map(|(i, slot)| {
1321 let eta_i = eta[i];
1322 let y_i = data.y[i];
1323 let w_i = data.weights[i];
1324 let mu = gam_linalg::utils::stable_logistic(eta_i);
1325 *slot = w_i * (y_i - mu);
1326 w_i * (y_i * eta_i - gam_linalg::utils::stable_softplus(eta_i))
1327 })
1328 .sum();
1329
1330 let grad_ll = fast_atv(data.x.as_ref(), &*residual);
1331 (ll, grad_ll)
1332}
1333
1334fn probit_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1341 let mut residual = Array1::<f64>::zeros(data.n_samples);
1342 probit_logp_and_grad_into(data, eta, &mut residual)
1343}
1344
1345fn probit_logp_and_grad_into(
1346 data: &SharedData,
1347 eta: &Array1<f64>,
1348 residual: &mut Array1<f64>,
1349) -> (f64, Array1<f64>) {
1350 use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1351 let n = data.n_samples;
1352 assert_eq!(residual.len(), n);
1353 let ll: f64 = residual
1354 .as_slice_mut()
1355 .unwrap()
1356 .par_iter_mut()
1357 .enumerate()
1358 .map(|(i, slot)| {
1359 let eta_i = eta[i];
1360 let y_i = data.y[i];
1361 let w_i = data.weights[i];
1362 let log_phi_pos = log_ndtr(eta_i);
1363 let log_phi_neg = log_ndtr(-eta_i);
1364 let log_phi_val = standard_normal_log_pdf(eta_i);
1365 let ratio_pos = (log_phi_val - log_phi_pos).exp();
1366 let ratio_neg = (log_phi_val - log_phi_neg).exp();
1367 let grad_i = y_i * ratio_pos - (1.0 - y_i) * ratio_neg;
1368 *slot = w_i * grad_i;
1369 w_i * (y_i * log_phi_pos + (1.0 - y_i) * log_phi_neg)
1370 })
1371 .sum();
1372
1373 let grad_ll = fast_atv(data.x.as_ref(), &*residual);
1374 (ll, grad_ll)
1375}
1376
1377#[inline]
1383fn cloglog_bernoulli_logp_and_residual(eta: f64, y: f64) -> Result<(f64, f64), EstimationError> {
1384 if !(eta.is_finite() && (-700.0..=700.0).contains(&eta)) {
1385 gam_problem::bail_invalid_estim!(
1386 "cloglog eta must be finite and within [-700, 700]; got {eta}"
1387 );
1388 }
1389 let exp_eta = eta.exp();
1390 let log_mu = crate::probability::log1mexp_positive(exp_eta);
1393 let log_one_minus_mu = -exp_eta;
1394 let grad_log_mu = (eta - exp_eta - log_mu).exp();
1395 let ll_i = y * log_mu + (1.0 - y) * log_one_minus_mu;
1396 let residual_i = y * grad_log_mu - (1.0 - y) * exp_eta;
1397 Ok((ll_i, residual_i))
1398}
1399
1400fn cloglog_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1401 let mut residual = Array1::<f64>::zeros(data.n_samples);
1402 cloglog_logp_and_grad_into(data, eta, &mut residual)
1403}
1404
1405fn cloglog_logp_and_grad_into(
1406 data: &SharedData,
1407 eta: &Array1<f64>,
1408 residual: &mut Array1<f64>,
1409) -> (f64, Array1<f64>) {
1410 use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1411 let n = data.n_samples;
1412 assert_eq!(residual.len(), n);
1413 if eta
1414 .iter()
1415 .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1416 {
1417 return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1418 }
1419 let ll: f64 = residual
1420 .as_slice_mut()
1421 .unwrap()
1422 .par_iter_mut()
1423 .enumerate()
1424 .map(|(i, slot)| {
1425 let y_i = data.y[i];
1426 let w_i = data.weights[i];
1427 let (ll_i, residual_i) =
1428 cloglog_bernoulli_logp_and_residual(eta[i], y_i).expect("validated cloglog eta");
1429 *slot = w_i * residual_i;
1430 w_i * ll_i
1431 })
1432 .sum();
1433
1434 let grad_ll = fast_atv(data.x.as_ref(), &*residual);
1435 (ll, grad_ll)
1436}
1437
1438fn gaussian_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1450 let mut weighted_residual = Array1::<f64>::zeros(data.n_samples);
1451 gaussian_logp_and_grad_into(data, eta, &mut weighted_residual)
1452}
1453
1454fn gaussian_logp_and_grad_into(
1455 data: &SharedData,
1456 eta: &Array1<f64>,
1457 weighted_residual: &mut Array1<f64>,
1458) -> (f64, Array1<f64>) {
1459 use gam_problem::dispersion_cov::DispersionExt as _;
1460 use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1461 let n = data.n_samples;
1462 let inv_phi = data.dispersion.inv_phi();
1463 assert_eq!(weighted_residual.len(), n);
1464 let ll: f64 = weighted_residual
1467 .as_slice_mut()
1468 .unwrap()
1469 .par_iter_mut()
1470 .enumerate()
1471 .map(|(i, slot)| {
1472 let residual = data.y[i] - eta[i];
1473 let w_i = data.weights[i];
1474 let scaled = w_i * inv_phi;
1475 *slot = scaled * residual;
1476 -0.5 * scaled * residual * residual
1477 })
1478 .sum();
1479
1480 let grad_ll = fast_atv(data.x.as_ref(), &*weighted_residual);
1481 (ll, grad_ll)
1482}
1483
1484fn poisson_log_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1488 use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1489 let n = data.n_samples;
1490 if eta
1491 .iter()
1492 .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1493 {
1494 return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1495 }
1496 let mut residual = Array1::<f64>::zeros(n);
1497 let ll: f64 = residual
1498 .as_slice_mut()
1499 .unwrap()
1500 .par_iter_mut()
1501 .enumerate()
1502 .map(|(i, slot)| {
1503 let eta_i = eta[i];
1504 let mu_i = eta_i.exp();
1505 let y_i = data.y[i];
1506 let w_i = data.weights[i];
1507 *slot = w_i * (y_i - mu_i);
1508 w_i * (y_i * eta_i - mu_i)
1509 })
1510 .sum();
1511
1512 let grad_ll = fast_atv(&data.x, &residual);
1513 (ll, grad_ll)
1514}
1515
1516fn tweedie_log_quasilogp_and_grad(
1517 data: &SharedData,
1518 eta: &Array1<f64>,
1519 p: f64,
1520) -> (f64, Array1<f64>) {
1521 use gam_problem::dispersion_cov::DispersionExt as _;
1522 use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1523 let n = data.n_samples;
1524 if !is_valid_tweedie_power(p) {
1527 return (f64::NAN, Array1::from_elem(data.dim, f64::NAN));
1528 }
1529 if eta
1530 .iter()
1531 .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1532 {
1533 return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1534 }
1535 let inv_phi = data.dispersion.inv_phi();
1536 let mut residual = Array1::<f64>::zeros(n);
1537 let ll: f64 = residual
1538 .as_slice_mut()
1539 .unwrap()
1540 .par_iter_mut()
1541 .enumerate()
1542 .map(|(i, slot)| {
1543 let eta_i = eta[i];
1544 let mu_i = eta_i.exp().max(1e-300);
1545 let y_i = data.y[i];
1546 let w_i = data.weights[i] * inv_phi;
1547 *slot = w_i * (y_i - mu_i) * mu_i.powf(1.0 - p);
1548 let qll = y_i * mu_i.powf(1.0 - p) / (1.0 - p) - mu_i.powf(2.0 - p) / (2.0 - p);
1549 w_i * qll
1550 })
1551 .sum();
1552
1553 let grad_ll = fast_atv(&data.x, &residual);
1554 (ll, grad_ll)
1555}
1556
1557fn negative_binomial_log_logp_and_grad(
1558 data: &SharedData,
1559 eta: &Array1<f64>,
1560 theta: f64,
1561) -> (f64, Array1<f64>) {
1562 use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1563 let n = data.n_samples;
1564 if !(theta.is_finite() && theta > 0.0)
1565 || eta
1566 .iter()
1567 .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1568 || data
1569 .y
1570 .iter()
1571 .zip(data.weights.iter())
1572 .any(|(&y_i, &w_i)| w_i > 0.0 && !valid_count_response(y_i))
1573 {
1574 return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1575 }
1576 let mut residual = Array1::<f64>::zeros(n);
1577 let ll: f64 = residual
1578 .as_slice_mut()
1579 .unwrap()
1580 .par_iter_mut()
1581 .enumerate()
1582 .map(|(i, slot)| {
1583 let eta_i = eta[i];
1584 let mu_i = eta_i.exp().max(1e-12);
1585 let y_i = data.y[i];
1586 let w_i = data.weights[i];
1587 if w_i <= 0.0 {
1588 *slot = 0.0;
1589 return 0.0;
1590 }
1591 let log_mu_term = if y_i > 0.0 { y_i * mu_i.ln() } else { 0.0 };
1592 *slot = w_i * theta * (y_i - mu_i) / (theta + mu_i);
1593 w_i * (statrs::function::gamma::ln_gamma(y_i + theta)
1594 - statrs::function::gamma::ln_gamma(theta)
1595 - statrs::function::gamma::ln_gamma(y_i + 1.0)
1596 + theta * (theta.ln() - (theta + mu_i).ln())
1597 + log_mu_term
1598 - y_i * (theta + mu_i).ln())
1599 })
1600 .sum();
1601
1602 let grad_ll = fast_atv(&data.x, &residual);
1603 (ll, grad_ll)
1604}
1605
1606fn gamma_log_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1607 use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1608 let n = data.n_samples;
1609 if eta
1610 .iter()
1611 .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1612 {
1613 return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1614 }
1615 let shape = data.gamma_shape.max(1e-10);
1616 let shape_ln_shape = shape * shape.ln();
1621 let log_gamma_shape = statrs::function::gamma::ln_gamma(shape);
1622 let shape_minus_one = shape - 1.0;
1623 let mut residual = Array1::<f64>::zeros(n);
1624 let ll: f64 = residual
1625 .as_slice_mut()
1626 .unwrap()
1627 .par_iter_mut()
1628 .enumerate()
1629 .map(|(i, slot)| {
1630 let eta_i = eta[i];
1631 let mu_i = eta_i.exp();
1632 let y_i = data.y[i];
1633 let w_i = data.weights[i];
1634 let ll_i = w_i
1635 * (shape_ln_shape - log_gamma_shape - shape * eta_i
1636 + shape_minus_one * y_i.max(1e-12).ln()
1637 - shape * y_i / mu_i);
1638 *slot = w_i * shape * (y_i / mu_i - 1.0);
1639 ll_i
1640 })
1641 .sum();
1642
1643 let grad_ll = fast_atv(&data.x, &residual);
1644 (ll, grad_ll)
1645}
1646
1647#[cfg(test)]
1648mod tests {
1649 use super::{
1650 FamilyNutsInputs, GlmFlatInputs, JointBetaRhoInputs, JointBetaRhoPosterior,
1651 LinkWigglePosterior, LinkWiggleSplineArtifacts, NutsConfig, NutsFamily, NutsPosterior,
1652 SharedData, cloglog_bernoulli_logp_and_residual, firth_jeffreys_logp_and_grad,
1653 joint_family_logp_and_grad, laplace_directional_cubic_diagnostic,
1654 laplace_skewness_threshold, laplace_trustworthiness_from_skewness,
1655 run_joint_beta_rho_sampling, run_logit_polya_gamma_gibbs,
1656 run_nuts_sampling_flattened_family,
1657 };
1658 use gam_linalg::matrix::DesignMatrix;
1659 use gam_models::survival::{PenaltyBlocks, SurvivalMonotonicityPenalty, SurvivalSpec};
1660 use gam_problem::types::{
1661 InverseLink, LikelihoodScaleMetadata, LikelihoodSpec, LogLikelihoodNormalization,
1662 ResponseFamily, RhoPrior, StandardLink,
1663 };
1664 use gam_solve::estimate::{
1665 BlockRole, FitGeometry, FitInference, FittedBlock, FittedLinkState, UnifiedFitResult,
1666 UnifiedFitResultParts,
1667 };
1668 use gam_terms::construction::CanonicalPenalty;
1669 use general_mcmc::generic_hmc::HamiltonianTarget;
1670 use ndarray::{Array1, Array2, array};
1671 use std::sync::Arc;
1672
1673 impl NutsPosterior {
1674 pub(super) fn compute_logp_and_grad_nd(&self, z: &Array1<f64>) -> (f64, Array1<f64>) {
1676 let mut residual = Array1::<f64>::zeros(self.data.n_samples);
1677 let mut grad = Array1::<f64>::zeros(z.len());
1678 let logp = self.compute_logp_and_grad_nd_into(z, &mut residual, &mut grad);
1679 (logp, grad)
1680 }
1681 }
1682
1683 impl LinkWigglePosterior {
1684 pub(super) fn compute_logp_and_grad(&self, z: &Array1<f64>) -> (f64, Array1<f64>) {
1686 let dim = self.p_base + self.p_link;
1687 let mut grad = Array1::<f64>::zeros(dim);
1688 let logp = self.compute_logp_and_grad_into(z, &mut grad);
1689 (logp, grad)
1690 }
1691 }
1692
1693 impl JointBetaRhoPosterior {
1694 pub(super) fn compute_joint_logp_and_grad(
1696 &self,
1697 params: &Array1<f64>,
1698 ) -> (f64, Array1<f64>) {
1699 let total_dim = self.n_beta + self.n_rho + self.n_link_params;
1700 let mut grad = Array1::<f64>::zeros(total_dim);
1701 let logp = self.compute_joint_logp_and_grad_into(params, &mut grad);
1702 (logp, grad)
1703 }
1704 }
1705
1706 fn hmc_test_fit(
1707 blocks: Vec<FittedBlock>,
1708 inference: Option<FitInference>,
1709 geometry: Option<FitGeometry>,
1710 ) -> UnifiedFitResult {
1711 let lambdas = Array1::zeros(0);
1712 UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
1713 blocks,
1714 log_lambdas: lambdas.clone(),
1715 lambdas,
1716 likelihood_family: Some(LikelihoodSpec::new(
1717 ResponseFamily::Gaussian,
1718 InverseLink::Standard(StandardLink::Identity),
1719 )),
1720 likelihood_scale: LikelihoodScaleMetadata::ProfiledGaussian,
1721 log_likelihood_normalization: LogLikelihoodNormalization::Full,
1722 log_likelihood: -1.0,
1723 deviance: 2.0,
1724 reml_score: 0.0,
1725 stable_penalty_term: 0.0,
1726 penalized_objective: 0.0,
1727 used_device: false,
1728 outer_iterations: 1,
1729 outer_converged: true,
1730 outer_gradient_norm: None,
1731 standard_deviation: 1.0,
1732 covariance_conditional: None,
1733 covariance_corrected: None,
1734 inference,
1735 fitted_link: FittedLinkState::Standard(None),
1736 geometry,
1737 block_states: Vec::new(),
1738 pirls_status: gam_solve::pirls::PirlsStatus::Converged,
1739 max_abs_eta: 0.0,
1740 constraint_kkt: None,
1741 artifacts: Default::default(),
1742 inner_cycles: 0,
1743 })
1744 .expect("valid HMC handoff test fit")
1745 }
1746
1747 #[test]
1748 fn hmc_whitening_consumes_standard_fit_inference_hessian() {
1749 let hessian = array![[2.0, 0.1], [0.1, 1.6]];
1750 let fit = hmc_test_fit(
1751 vec![FittedBlock {
1752 beta: array![0.05, -0.1],
1753 role: BlockRole::Mean,
1754 edf: 2.0,
1755 lambdas: Array1::zeros(0),
1756 }],
1757 Some(FitInference {
1758 edf_by_block: vec![],
1759 penalty_block_trace: vec![],
1760 edf_total: 2.0,
1761 smoothing_correction: None,
1762 penalized_hessian: hessian.clone().into(),
1763 working_weights: array![1.0, 1.0, 1.0],
1764 working_response: array![0.0, 0.1, -0.2],
1765 reparam_qs: None,
1766 dispersion: gam_solve::estimate::Dispersion::Known(1.0),
1767 beta_covariance: None,
1768 beta_standard_errors: None,
1769 beta_covariance_corrected: None,
1770 beta_standard_errors_corrected: None,
1771 beta_covariance_frequentist: None,
1772 coefficient_influence: None,
1773 weighted_gram: None,
1774 bias_correction_beta: None,
1775 }),
1776 None,
1777 );
1778
1779 let explicit = super::explicit_fit_hessian_for_whitening(&fit, 2, "standard fit")
1780 .expect("standard fit exports explicit Hessian");
1781 assert_eq!(explicit, &hessian);
1782
1783 let x = array![[1.0, 0.0], [1.0, 0.5], [1.0, -0.5]];
1784 let y = array![0.0, 0.2, -0.1];
1785 let weights = Array1::ones(3);
1786 let penalty = Array2::eye(2);
1787 NutsPosterior::new(
1788 x.view(),
1789 y.view(),
1790 weights.view(),
1791 penalty.view(),
1792 fit.beta.view(),
1793 explicit.view(),
1794 NutsFamily::Gaussian,
1795 1.0,
1796 gam_solve::estimate::Dispersion::Known(1.0),
1797 false,
1798 )
1799 .expect("HMC target whitens with upstream Hessian");
1800 }
1801
1802 #[test]
1803 fn hmc_whitening_consumes_blockwise_geometry_hessian() {
1804 let hessian = array![[3.0, 0.2], [0.2, 2.0]];
1805 let fit = hmc_test_fit(
1806 vec![
1807 FittedBlock {
1808 beta: array![0.1],
1809 role: BlockRole::Location,
1810 edf: 1.0,
1811 lambdas: Array1::zeros(0),
1812 },
1813 FittedBlock {
1814 beta: array![-0.2],
1815 role: BlockRole::Scale,
1816 edf: 1.0,
1817 lambdas: Array1::zeros(0),
1818 },
1819 ],
1820 None,
1821 Some(FitGeometry {
1822 penalized_hessian: hessian.clone().into(),
1823 working_weights: array![1.0, 0.8],
1824 working_response: array![0.0, 0.1],
1825 }),
1826 );
1827
1828 let explicit = super::explicit_fit_hessian_for_whitening(&fit, 2, "blockwise fit")
1829 .expect("blockwise fit exports materialized Hessian");
1830 assert_eq!(explicit, &hessian);
1831 }
1832
1833 #[test]
1834 fn hmc_whitening_rejects_covariance_only_fit_without_synthesizing_hessian() {
1835 let fit = UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
1836 blocks: vec![FittedBlock {
1837 beta: array![0.0],
1838 role: BlockRole::Mean,
1839 edf: 1.0,
1840 lambdas: Array1::zeros(0),
1841 }],
1842 log_lambdas: Array1::zeros(0),
1843 lambdas: Array1::zeros(0),
1844 likelihood_family: Some(LikelihoodSpec::new(
1845 ResponseFamily::Gaussian,
1846 InverseLink::Standard(StandardLink::Identity),
1847 )),
1848 likelihood_scale: LikelihoodScaleMetadata::ProfiledGaussian,
1849 log_likelihood_normalization: LogLikelihoodNormalization::Full,
1850 log_likelihood: -1.0,
1851 deviance: 2.0,
1852 reml_score: 0.0,
1853 stable_penalty_term: 0.0,
1854 penalized_objective: 0.0,
1855 used_device: false,
1856 outer_iterations: 1,
1857 outer_converged: true,
1858 outer_gradient_norm: None,
1859 standard_deviation: 1.0,
1860 covariance_conditional: Some(array![[0.5]]),
1861 covariance_corrected: None,
1862 inference: None,
1863 fitted_link: FittedLinkState::Standard(None),
1864 geometry: None,
1865 block_states: Vec::new(),
1866 pirls_status: gam_solve::pirls::PirlsStatus::Converged,
1867 max_abs_eta: 0.0,
1868 constraint_kkt: None,
1869 artifacts: Default::default(),
1870 inner_cycles: 0,
1871 })
1872 .expect("covariance-only fit can exist for prediction");
1873
1874 let err = super::explicit_fit_hessian_for_whitening(&fit, 1, "covariance-only fit")
1875 .expect_err("HMC must not invert covariance as a Hessian fallback");
1876 assert!(
1877 err.contains("missing an explicit penalized Hessian"),
1878 "unexpected error: {err}"
1879 );
1880 }
1881
1882 #[test]
1883 fn log1pexp_is_finite_for_extreme_eta() {
1884 assert!(gam_linalg::utils::stable_softplus(1000.0).is_finite());
1885 assert!(gam_linalg::utils::stable_softplus(-1000.0).is_finite());
1886 assert!((gam_linalg::utils::stable_softplus(-1000.0) - 0.0).abs() < 1e-12);
1887 }
1888
1889 #[test]
1890 fn sigmoid_stable_behaves_at_extremes() {
1891 let hi = gam_linalg::utils::stable_logistic(1000.0);
1892 let lo = gam_linalg::utils::stable_logistic(-1000.0);
1893 assert!((1.0 - 1e-12..=1.0).contains(&hi));
1894 assert!((0.0..=1e-12).contains(&lo));
1895 }
1896
1897 #[test]
1898 fn cloglog_log_mu_uses_complementary_loglog_inverse_link() {
1899 let eta = -1.0_f64;
1900 let (ll_y1, residual_y1) =
1901 cloglog_bernoulli_logp_and_residual(eta, 1.0).expect("valid eta");
1902 let expected = (1.0 - (-eta.exp()).exp()).ln();
1903 let wrong_log_one_minus_exp_eta = (1.0 - eta.exp()).ln();
1904
1905 assert!((ll_y1 - expected).abs() < 1e-14);
1906 assert!((ll_y1 - wrong_log_one_minus_exp_eta).abs() > 0.5);
1907
1908 let eps = 1e-6;
1909 let (lp, _) = cloglog_bernoulli_logp_and_residual(eta + eps, 1.0).expect("valid eta");
1910 let (lm, _) = cloglog_bernoulli_logp_and_residual(eta - eps, 1.0).expect("valid eta");
1911 let fd = (lp - lm) / (2.0 * eps);
1912 assert!(
1913 (residual_y1 - fd).abs() < 1e-9,
1914 "cloglog residual is not the derivative of log μ: analytic={residual_y1}, fd={fd}"
1915 );
1916 }
1917
1918 #[test]
1919 fn link_wiggle_posterior_whitening_uses_supplied_explicit_joint_hessian() {
1920 let x = array![[1.0], [1.0], [1.0]];
1921 let y = array![0.0, 1.0, 1.0];
1922 let weights = Array1::ones(3);
1923 let penalty_base = Array2::zeros((1, 1));
1924 let penalty_link = Array2::zeros((1, 1));
1925 let mode_beta = array![0.2];
1926 let mode_theta = array![0.05];
1927 let hessian = array![[4.0, 1.0], [1.0, 3.0]];
1928 let spline = LinkWiggleSplineArtifacts {
1929 knot_range: (-1.0, 1.0),
1930 knot_vector: Array1::from_vec(vec![-1.0, -1.0, -1.0, 1.0, 1.0, 1.0]),
1931 degree: 2,
1932 };
1933
1934 let posterior = LinkWigglePosterior::new(
1935 x.view(),
1936 y.view(),
1937 weights.view(),
1938 penalty_base.view(),
1939 penalty_link.view(),
1940 mode_beta.view(),
1941 mode_theta.view(),
1942 hessian.view(),
1943 spline,
1944 NutsFamily::BinomialLogit,
1945 1.0,
1946 )
1947 .expect("link-wiggle posterior should accept explicit SPD joint Hessian");
1948
1949 let reconstructed_cov = posterior.chol().dot(&posterior.chol().t());
1950 let eye_from_hessian = hessian.dot(&reconstructed_cov);
1951 for r in 0..2 {
1952 for c in 0..2 {
1953 let expected = if r == c { 1.0 } else { 0.0 };
1954 assert!(
1955 (eye_from_hessian[[r, c]] - expected).abs() < 1e-10,
1956 "whitening did not use the supplied explicit joint Hessian at ({r},{c}): got {} expected {}",
1957 eye_from_hessian[[r, c]],
1958 expected
1959 );
1960 }
1961 }
1962 }
1963
1964 #[test]
1965 fn link_wiggle_cloglog_gradient_matches_its_log_likelihood() {
1966 let x = array![[1.0], [1.0], [1.0], [1.0]];
1967 let y = array![1.0, 0.0, 1.0, 0.0];
1968 let weights = array![1.0, 1.2, 0.8, 1.4];
1969 let penalty_base = Array2::zeros((1, 1));
1970 let penalty_link = Array2::zeros((1, 1));
1971 let mode_beta = array![-0.8];
1972 let mode_theta = array![0.04];
1973 let hessian = Array2::eye(2);
1974 let spline = LinkWiggleSplineArtifacts {
1975 knot_range: (-1.5, 0.5),
1976 knot_vector: Array1::from_vec(vec![-1.5, -1.5, -1.5, 0.5, 0.5, 0.5]),
1977 degree: 2,
1978 };
1979
1980 let posterior = LinkWigglePosterior::new(
1981 x.view(),
1982 y.view(),
1983 weights.view(),
1984 penalty_base.view(),
1985 penalty_link.view(),
1986 mode_beta.view(),
1987 mode_theta.view(),
1988 hessian.view(),
1989 spline,
1990 NutsFamily::BinomialCLogLog,
1991 1.0,
1992 )
1993 .expect("cloglog link-wiggle posterior");
1994
1995 let z = array![0.2, -0.03];
1996 let (_, grad) = posterior.compute_logp_and_grad(&z);
1997 let eps = 1e-6;
1998 for j in 0..z.len() {
1999 let mut z_plus = z.clone();
2000 let mut z_minus = z.clone();
2001 z_plus[j] += eps;
2002 z_minus[j] -= eps;
2003 let (lp, _) = posterior.compute_logp_and_grad(&z_plus);
2004 let (lm, _) = posterior.compute_logp_and_grad(&z_minus);
2005 let fd = (lp - lm) / (2.0 * eps);
2006 assert!(
2007 (grad[j] - fd).abs() < 1e-6,
2008 "link-wiggle cloglog gradient mismatch at {j}: analytic={}, fd={}",
2009 grad[j],
2010 fd
2011 );
2012 }
2013 }
2014
2015 #[test]
2016 fn nuts_logitgradient_matches_finite_difference() {
2017 let x = array![[1.0, -0.5], [0.2, 0.7], [-1.0, 0.3], [0.5, -1.2]];
2018 let y = array![1.0, 0.0, 1.0, 0.0];
2019 let w = array![1.0, 1.5, 0.8, 1.2];
2020 let penalty = array![[0.4, 0.0], [0.0, 0.6]];
2021 let mode = array![0.1, -0.2];
2022 let hessian = array![[2.0, 0.2], [0.2, 1.7]]; let posterior = NutsPosterior::new(
2025 x.view(),
2026 y.view(),
2027 w.view(),
2028 penalty.view(),
2029 mode.view(),
2030 hessian.view(),
2031 NutsFamily::BinomialLogit,
2032 1.0,
2033 gam_solve::estimate::Dispersion::Known(1.0),
2034 true,
2035 )
2036 .expect("posterior");
2037
2038 let z = array![0.15, -0.35];
2039 let (_, grad) = posterior.compute_logp_and_grad_nd(&z);
2040
2041 let eps = 1e-6;
2042 for j in 0..z.len() {
2043 let mut z_plus = z.clone();
2044 let mut z_minus = z.clone();
2045 z_plus[j] += eps;
2046 z_minus[j] -= eps;
2047 let (lp, _) = posterior.compute_logp_and_grad_nd(&z_plus);
2048 let (lm, _) = posterior.compute_logp_and_grad_nd(&z_minus);
2049 let fd = (lp - lm) / (2.0 * eps);
2050 assert_eq!(
2051 grad[j].signum(),
2052 fd.signum(),
2053 "gradient sign mismatch at {}: analytic={}, fd={}",
2054 j,
2055 grad[j],
2056 fd
2057 );
2058 assert!(
2059 (grad[j] - fd).abs() < 1e-5,
2060 "gradient mismatch at {}: analytic={}, fd={}",
2061 j,
2062 grad[j],
2063 fd
2064 );
2065 }
2066 }
2067
2068 #[test]
2069 fn gamma_log_logp_and_grad_uses_fitted_shape() {
2070 let x = array![[1.0_f64], [1.0_f64]];
2071 let y = array![1.5_f64, 2.5_f64];
2072 let weights = array![1.0_f64, 2.0_f64];
2073 let eta = array![0.2_f64, 0.4_f64];
2074 let shape = 3.5_f64;
2075 let data = SharedData {
2076 x: Arc::new(x.clone()),
2077 y: Arc::new(y.clone()),
2078 weights: Arc::new(weights.clone()),
2079 mode: Arc::new(Array1::zeros(1)),
2080 offset: None,
2081 gamma_shape: shape,
2082 dispersion: gam_solve::estimate::Dispersion::Known(1.0),
2083 n_samples: x.nrows(),
2084 dim: x.ncols(),
2085 };
2086
2087 let (ll, grad) = super::gamma_log_logp_and_grad(&data, &eta);
2088
2089 let mut expected_ll = 0.0;
2090 let mut expected_score = 0.0;
2091 for i in 0..eta.len() {
2092 let mu = eta[i].exp();
2093 expected_ll += weights[i]
2094 * (shape * shape.ln() - statrs::function::gamma::ln_gamma(shape) - shape * eta[i]
2095 + (shape - 1.0) * y[i].ln()
2096 - shape * y[i] / mu);
2097 expected_score += weights[i] * shape * (y[i] / mu - 1.0);
2098 }
2099
2100 assert!((ll - expected_ll).abs() < 1e-12);
2101 assert_eq!(grad.len(), 1);
2102 assert!((grad[0] - expected_score).abs() < 1e-12);
2103 }
2104
2105 fn gamma_log_observed_information(
2109 x: &Array2<f64>,
2110 mode: &Array1<f64>,
2111 y: &Array1<f64>,
2112 weights: &Array1<f64>,
2113 shape: f64,
2114 ) -> Array2<f64> {
2115 let p = x.ncols();
2116 let eta = x.dot(mode);
2117 let mut h = Array2::<f64>::zeros((p, p));
2118 for i in 0..x.nrows() {
2119 let mu = eta[i].exp();
2120 let wt = weights[i] * shape * y[i] / mu;
2121 for a in 0..p {
2122 for b in 0..p {
2123 h[[a, b]] += wt * x[[i, a]] * x[[i, b]];
2124 }
2125 }
2126 }
2127 h
2128 }
2129
2130 #[test]
2143 fn gamma_log_nuts_target_curvature_matches_unscaled_hessian_issue_680() {
2144 let x = array![[1.0, -0.7], [1.0, 0.3], [1.0, 1.1], [1.0, -0.2], [1.0, 0.8],];
2145 let mode = array![0.4_f64, -0.6_f64];
2146 let y = array![1.2_f64, 0.7, 2.3, 0.9, 1.6];
2147 let weights = array![1.0_f64, 1.5, 0.8, 1.2, 1.0];
2148 let shape = 4.0_f64;
2150 let p = x.ncols();
2151
2152 let h_data = gamma_log_observed_information(&x, &mode, &y, &weights, shape);
2153 let s = array![[0.5_f64, 0.1], [0.1, 0.9]];
2155 let hessian = &h_data + &s;
2156
2157 let target = NutsPosterior::new(
2158 x.view(),
2159 y.view(),
2160 weights.view(),
2161 s.view(),
2162 mode.view(),
2163 hessian.view(),
2164 NutsFamily::GammaLog,
2165 shape,
2166 gam_solve::estimate::Dispersion::Estimated(1.0 / shape),
2167 false,
2168 )
2169 .expect("GammaLog NUTS target builds");
2170
2171 let eps = 1e-6;
2174 let z0 = Array1::<f64>::zeros(p);
2175 let mut hz = Array2::<f64>::zeros((p, p));
2176 for j in 0..p {
2177 let mut zp = z0.clone();
2178 let mut zm = z0.clone();
2179 zp[j] += eps;
2180 zm[j] -= eps;
2181 let (_, gp) = target.compute_logp_and_grad_nd(&zp);
2182 let (_, gm) = target.compute_logp_and_grad_nd(&zm);
2183 for a in 0..p {
2184 hz[[a, j]] = -(gp[a] - gm[a]) / (2.0 * eps);
2185 }
2186 }
2187
2188 for a in 0..p {
2189 for b in 0..p {
2190 let expected = if a == b { 1.0 } else { 0.0 };
2191 assert!(
2192 (hz[[a, b]] - expected).abs() < 1e-4,
2193 "z-curvature[{a},{b}] = {} (expected {expected}); a non-identity \
2194 value means the GammaLog target re-introduced the #680 dispersion \
2195 double-count (penalty ×ν and/or whitening ×√φ)",
2196 hz[[a, b]]
2197 );
2198 }
2199 }
2200 let trace: f64 = (0..p).map(|i| hz[[i, i]]).sum();
2202 assert!(
2203 (trace - p as f64).abs() < 1e-3,
2204 "z-curvature trace {trace} ≠ {p}: dispersion double-count signature"
2205 );
2206 }
2207
2208 #[test]
2214 fn gamma_log_nuts_whitening_targets_unscaled_inverse_hessian_issue_680() {
2215 let x = array![[1.0, -0.4], [1.0, 0.6], [1.0, 0.1], [1.0, 1.3]];
2216 let mode = array![0.2_f64, 0.3_f64];
2217 let y = array![0.8_f64, 1.7, 1.1, 2.2];
2218 let weights = array![1.0_f64, 1.0, 1.5, 0.7];
2219 let shape = 6.25_f64; let p = x.ncols();
2221 let s = array![[0.3_f64, 0.0], [0.0, 0.7]];
2222 let hessian = &gamma_log_observed_information(&x, &mode, &y, &weights, shape) + &s;
2223
2224 let target = NutsPosterior::new(
2225 x.view(),
2226 y.view(),
2227 weights.view(),
2228 s.view(),
2229 mode.view(),
2230 hessian.view(),
2231 NutsFamily::GammaLog,
2232 shape,
2233 gam_solve::estimate::Dispersion::Estimated(1.0 / shape),
2234 false,
2235 )
2236 .expect("GammaLog NUTS target builds");
2237
2238 let l = target.chol();
2240 let llt = l.dot(&l.t());
2241 let prod = llt.dot(&hessian);
2242 for a in 0..p {
2243 for b in 0..p {
2244 let expected = if a == b { 1.0 } else { 0.0 };
2245 assert!(
2246 (prod[[a, b]] - expected).abs() < 1e-8,
2247 "L Lᵀ H[{a},{b}] = {} (expected {expected}); a φ·I result means \
2248 the Gamma whitening still scales by √φ (#680)",
2249 prod[[a, b]]
2250 );
2251 }
2252 }
2253 }
2254
2255 #[test]
2256 fn firth_jeffreys_logit_is_finite_for_rank_deficient_design() {
2257 let x = array![
2258 [1.0, -0.5, 1.0],
2259 [1.0, 0.3, 1.0],
2260 [1.0, 0.8, 1.0],
2261 [1.0, -1.2, 1.0],
2262 ];
2263 let y = array![1.0, 0.0, 1.0, 0.0];
2264 let weights = array![1.0, 2.0, 0.5, 1.5];
2265 let eta = array![0.2, -0.1, 0.4, -0.3];
2266
2267 let data = SharedData {
2268 x: Arc::new(x.clone()),
2269 y: Arc::new(y),
2270 weights: Arc::new(weights.clone()),
2271 mode: Arc::new(Array1::zeros(x.ncols())),
2272 offset: None,
2273 gamma_shape: 1.0,
2274 dispersion: gam_solve::estimate::Dispersion::Known(1.0),
2275 n_samples: x.nrows(),
2276 dim: x.ncols(),
2277 };
2278
2279 let (value, grad) =
2280 firth_jeffreys_logp_and_grad(NutsFamily::BinomialLogit, &data, &eta).expect("firth");
2281
2282 assert!(value.is_finite());
2283 assert_eq!(grad.len(), x.ncols());
2284 assert!(grad.iter().all(|v| v.is_finite()));
2285 }
2286
2287 #[test]
2288 fn logit_pg_gibbs_returns_finite_samples() {
2289 let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
2290 let y = array![1.0, 0.0, 1.0, 0.0];
2291 let w = array![1.0, 1.0, 1.0, 1.0];
2292 let penalty = array![[0.2, 0.0], [0.0, 0.4]];
2293 let mode = array![0.0, 0.0];
2294 let cfg = NutsConfig {
2295 n_samples: 30,
2296 nwarmup: 30,
2297 n_chains: 2,
2298 target_accept: 0.8,
2299 seed: 123,
2300 };
2301 let out = run_logit_polya_gamma_gibbs(
2302 x.view(),
2303 y.view(),
2304 w.view(),
2305 penalty.view(),
2306 mode.view(),
2307 &cfg,
2308 )
2309 .expect("pg gibbs should run");
2310 assert_eq!(out.samples.ncols(), 2);
2311 assert_eq!(out.samples.nrows(), cfg.n_samples * cfg.n_chains);
2312 assert!(out.samples.iter().all(|v| v.is_finite()));
2313 assert!(out.posterior_mean.iter().all(|v| v.is_finite()));
2314 assert!(out.posterior_std.iter().all(|v| v.is_finite()));
2315 }
2316
2317 #[test]
2318 fn family_pg_dispatch_rejects_non_bernoulli_response() {
2319 let x = array![[1.0], [1.0]];
2320 let y = array![2.0, 0.0];
2321 let w = array![1.0, 1.0];
2322 let penalty = array![[0.1]];
2323 let mode = array![0.0];
2324 let non_spd_hessian = array![[0.0]];
2325 let cfg = NutsConfig {
2326 n_samples: 1,
2327 nwarmup: 1,
2328 n_chains: 1,
2329 target_accept: 0.8,
2330 seed: 321,
2331 };
2332
2333 let result = run_nuts_sampling_flattened_family(
2334 LikelihoodSpec::binomial_logit(),
2335 FamilyNutsInputs::Glm(GlmFlatInputs {
2336 x: x.view(),
2337 y: y.view(),
2338 weights: w.view(),
2339 penalty_matrix: penalty.view(),
2340 mode: mode.view(),
2341 hessian: non_spd_hessian.view(),
2342 gamma_shape: None,
2343 dispersion: gam_solve::model_types::Dispersion::Known(1.0),
2344 firth_bias_reduction: false,
2345 offset: None,
2346 }),
2347 &cfg,
2348 );
2349
2350 let err = result.err().expect("PG dispatch should reject count rows");
2351 assert!(
2352 err.contains("response must be exactly 0 or 1"),
2353 "unexpected error: {err}"
2354 );
2355 }
2356
2357 #[test]
2358 fn family_dispatch_uses_pg_gibbs_for_standard_logit() {
2359 let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
2360 let y = array![1.0, 0.0, 1.0, 0.0];
2361 let w = array![1.0, 1.0, 1.0, 1.0];
2362 let penalty = array![[0.2, 0.0], [0.0, 0.4]];
2363 let mode = array![0.0, 0.0];
2364 let non_spdhessian = array![[0.0, 0.0], [0.0, 0.0]];
2365 let cfg = NutsConfig {
2366 n_samples: 20,
2367 nwarmup: 20,
2368 n_chains: 2,
2369 target_accept: 0.8,
2370 seed: 456,
2371 };
2372 let out = run_nuts_sampling_flattened_family(
2373 LikelihoodSpec {
2374 response: ResponseFamily::Binomial,
2375 link: InverseLink::Standard(StandardLink::Logit),
2376 },
2377 FamilyNutsInputs::Glm(GlmFlatInputs {
2378 x: x.view(),
2379 y: y.view(),
2380 weights: w.view(),
2381 penalty_matrix: penalty.view(),
2382 mode: mode.view(),
2383 hessian: non_spdhessian.view(),
2384 gamma_shape: None,
2385 dispersion: gam_solve::estimate::Dispersion::Known(1.0),
2386 firth_bias_reduction: false,
2387 offset: None,
2388 }),
2389 &cfg,
2390 )
2391 .expect("dispatch should use PG Gibbs and not require Hessian factorization");
2392 assert_eq!(out.samples.nrows(), cfg.n_samples * cfg.n_chains);
2393 assert!(out.samples.iter().all(|v| v.is_finite()));
2394 }
2395
2396 #[test]
2397 fn family_dispatch_routes_probit_to_nuts_path() {
2398 let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
2399 let y = array![1.0, 0.0, 1.0, 0.0];
2400 let w = array![1.0, 1.0, 1.0, 1.0];
2401 let penalty = array![[0.2, 0.0], [0.0, 0.4]];
2402 let mode = array![0.0, 0.0];
2403 let non_spdhessian = array![[0.0, 0.0], [0.0, 0.0]];
2404 let cfg = NutsConfig {
2405 n_samples: 20,
2406 nwarmup: 20,
2407 n_chains: 2,
2408 target_accept: 0.8,
2409 seed: 654,
2410 };
2411
2412 let err = match run_nuts_sampling_flattened_family(
2413 LikelihoodSpec {
2414 response: ResponseFamily::Binomial,
2415 link: InverseLink::Standard(StandardLink::Probit),
2416 },
2417 FamilyNutsInputs::Glm(GlmFlatInputs {
2418 x: x.view(),
2419 y: y.view(),
2420 weights: w.view(),
2421 penalty_matrix: penalty.view(),
2422 mode: mode.view(),
2423 hessian: non_spdhessian.view(),
2424 gamma_shape: None,
2425 dispersion: gam_solve::estimate::Dispersion::Known(1.0),
2426 firth_bias_reduction: false,
2427 offset: None,
2428 }),
2429 &cfg,
2430 ) {
2431 Ok(_) => panic!("non-SPD Hessian should fail after probit routes to the NUTS path"),
2432 Err(err) => err,
2433 };
2434
2435 assert!(
2436 err.contains("Hessian Cholesky decomposition failed"),
2437 "unexpected error: {err}"
2438 );
2439 }
2440
#[test]
fn family_dispatch_rejects_nonbinomial_firth_family() {
    // Count-valued Poisson/log data with Firth switched on. The NUTS
    // dispatcher only accepts Firth for Binomial links, so it must refuse the
    // request with an explicit error instead of returning draws.
    let design = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
    let counts = array![1.0, 2.0, 0.0, 3.0];
    let prior_weights = array![1.0, 1.0, 1.0, 1.0];
    let penalty = array![[0.2, 0.0], [0.0, 0.4]];
    let beta_mode = array![0.0, 0.0];
    let hessian = array![[1.5, 0.1], [0.1, 1.2]];
    let cfg = NutsConfig {
        seed: 111,
        n_samples: 20,
        nwarmup: 20,
        n_chains: 2,
        target_accept: 0.8,
    };

    let poisson_log = LikelihoodSpec {
        response: ResponseFamily::Poisson,
        link: InverseLink::Standard(StandardLink::Log),
    };
    let glm_inputs = FamilyNutsInputs::Glm(GlmFlatInputs {
        x: design.view(),
        y: counts.view(),
        weights: prior_weights.view(),
        penalty_matrix: penalty.view(),
        mode: beta_mode.view(),
        hessian: hessian.view(),
        gamma_shape: None,
        dispersion: gam_solve::estimate::Dispersion::Known(1.0),
        firth_bias_reduction: true,
        offset: None,
    });

    let Err(err) = run_nuts_sampling_flattened_family(poisson_log, glm_inputs, &cfg) else {
        panic!("Poisson Firth should be rejected explicitly");
    };

    let expected = "NUTS with Firth requires a Binomial inverse link with a Fisher-weight jet";
    assert!(err.contains(expected), "unexpected error: {err}");
}
2487
#[test]
fn run_nuts_sampling_rejects_invalid_target_accept() {
    // `target_accept` must be finite and lie in the open interval (0, 1).
    // Probe both boundaries, values beyond them, and non-finite inputs so that
    // a half-open range check (which would accept 0.0) or a NaN-blind
    // comparison cannot slip through. Every case must be rejected before any
    // sampling work starts.
    let x = array![[1.0], [1.0], [1.0]];
    let y = array![0.5, -0.5, 1.0];
    let weights = array![1.0, 1.0, 1.0];
    let penalty = array![[0.25]];
    let mode = array![0.0];
    let hessian = array![[1.25]];

    let invalid_values = [
        1.0,
        0.0,
        -0.25,
        1.5,
        f64::NAN,
        f64::INFINITY,
        f64::NEG_INFINITY,
    ];
    for bad_accept in invalid_values {
        let cfg = NutsConfig {
            n_samples: 10,
            nwarmup: 10,
            n_chains: 1,
            target_accept: bad_accept,
            seed: 222,
        };

        let err = super::run_nuts_sampling(
            x.view(),
            y.view(),
            weights.view(),
            penalty.view(),
            mode.view(),
            hessian.view(),
            NutsFamily::Gaussian,
            1.0,
            gam_solve::estimate::Dispersion::Known(1.0),
            false,
            None,
            &cfg,
        )
        .expect_err("invalid target_accept should be rejected before sampling");

        assert!(
            err.contains("target_accept must be finite and lie in (0, 1)"),
            "target_accept={bad_accept} gave unexpected error: {err}"
        );
    }
}
2525
#[test]
fn run_nuts_sampling_rejects_zero_or_too_few_samples() {
    // Intercept-only Gaussian problem in which only `n_samples` varies. Any
    // value below 4 must be refused before sampling begins.
    let design = array![[1.0], [1.0], [1.0]];
    let response = array![0.5, -0.5, 1.0];
    let prior_weights = array![1.0, 1.0, 1.0];
    let penalty = array![[0.25]];
    let beta_mode = array![0.0];
    let hessian = array![[1.25]];

    let too_few: [usize; 4] = [0, 1, 2, 3];
    for bad_samples in too_few {
        let cfg = NutsConfig {
            seed: 222,
            n_samples: bad_samples,
            nwarmup: 10,
            n_chains: 2,
            target_accept: 0.8,
        };

        let outcome = super::run_nuts_sampling(
            design.view(),
            response.view(),
            prior_weights.view(),
            penalty.view(),
            beta_mode.view(),
            hessian.view(),
            NutsFamily::Gaussian,
            1.0,
            gam_solve::estimate::Dispersion::Known(1.0),
            false,
            None,
            &cfg,
        );
        let err = outcome.expect_err("too-few samples must be rejected before sampling");

        assert!(
            err.contains("n_samples must be >= 4"),
            "n_samples={bad_samples} gave unexpected error: {err}"
        );
    }
}
2572
#[test]
fn polya_gamma_gibbs_rejects_degenerate_counts_but_accepts_single_chain() {
    // Intercept-only Bernoulli data. Zero chains and zero samples are
    // configuration errors that must be reported up front; a single chain is
    // valid and must return every requested draw.
    let x = array![[1.0], [1.0], [1.0], [1.0]];
    let y = array![1.0, 0.0, 1.0, 0.0];
    let weights = array![1.0, 1.0, 1.0, 1.0];
    let penalty = array![[0.25]];
    let mode = array![0.0];

    // Only the draw and chain counts differ between the three runs below.
    let config = |n_samples, n_chains| NutsConfig {
        n_samples,
        nwarmup: 10,
        n_chains,
        target_accept: 0.8,
        seed: 7,
    };
    let run_gibbs = |cfg: &NutsConfig| {
        super::run_logit_polya_gamma_gibbs(
            x.view(),
            y.view(),
            weights.view(),
            penalty.view(),
            mode.view(),
            cfg,
        )
    };

    let zero_chain_cfg = config(20, 0);
    let err = run_gibbs(&zero_chain_cfg)
        .expect_err("PG Gibbs must reject zero chains up front, not return an empty posterior");
    assert!(
        err.contains("n_chains must be >= 1"),
        "PG n_chains=0 gave unexpected error: {err}"
    );

    let zero_sample_cfg = config(0, 2);
    let err = run_gibbs(&zero_sample_cfg)
        .expect_err("PG Gibbs must reject zero samples up front, not return an empty posterior");
    assert!(
        err.contains("n_samples must be >= 4"),
        "PG n_samples=0 gave unexpected error: {err}"
    );

    let single_chain_cfg = config(20, 1);
    let posterior = run_gibbs(&single_chain_cfg)
        .expect("PG Gibbs must accept a single chain and return draws");
    assert_eq!(
        posterior.samples.nrows(),
        20,
        "single-chain PG run should return all 20 requested draws"
    );
}
2653
#[test]
fn run_nuts_sampling_rejects_zero_chains_but_accepts_single_chain() {
    // Zero chains is a configuration error that must surface before any
    // sampling. A single chain is supported and must return every requested
    // draw.
    let x = array![[1.0], [1.0], [1.0]];
    let y = array![0.5, -0.5, 1.0];
    let weights = array![1.0, 1.0, 1.0];
    let penalty = array![[0.25]];
    let mode = array![0.0];
    let hessian = array![[1.25]];

    // Both runs share every setting except the chain count.
    let with_chains = |n_chains| NutsConfig {
        n_samples: 50,
        nwarmup: 10,
        n_chains,
        target_accept: 0.8,
        seed: 222,
    };
    let sample = |cfg: &NutsConfig| {
        super::run_nuts_sampling(
            x.view(),
            y.view(),
            weights.view(),
            penalty.view(),
            mode.view(),
            hessian.view(),
            NutsFamily::Gaussian,
            1.0,
            gam_solve::estimate::Dispersion::Known(1.0),
            false,
            None,
            cfg,
        )
    };

    let zero_chain_cfg = with_chains(0);
    let err = sample(&zero_chain_cfg).expect_err("zero chains must be rejected before sampling");
    assert!(
        err.contains("n_chains must be >= 1"),
        "n_chains=0 gave unexpected error: {err}"
    );

    let single_chain_cfg = with_chains(1);
    let result = sample(&single_chain_cfg)
        .expect("a single chain is a supported configuration and must return draws");
    assert_eq!(
        result.samples.nrows(),
        50,
        "single-chain run should return all 50 requested draws"
    );
}
2731
#[test]
fn joint_hmc_boundary_rejects_nonbinomial_firth_family() {
    // Joint (beta, rho) HMC must refuse Firth for a non-Binomial family
    // (here Poisson/log) with an explicit error at its entry point.
    let design = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
    let counts = array![1.0, 2.0, 0.0, 3.0];
    let prior_weights = array![1.0, 1.0, 1.0, 1.0];
    let hessian = array![[1.5, 0.1], [0.1, 1.2]];
    let penalty_root = array![[0.4, 0.0], [0.0, 0.6]];
    let beta_mode = array![0.0, 0.0];
    let rho_mode = array![0.0];
    let cfg = NutsConfig {
        seed: 111,
        n_samples: 20,
        nwarmup: 20,
        n_chains: 2,
        target_accept: 0.8,
    };

    let root_cols = penalty_root.ncols();
    let inputs = JointBetaRhoInputs {
        x: design.view(),
        y: counts.view(),
        weights: prior_weights.view(),
        likelihood: LikelihoodSpec {
            response: ResponseFamily::Poisson,
            link: InverseLink::Standard(StandardLink::Log),
        },
        gamma_shape: None,
        mode: beta_mode.view(),
        hessian: hessian.view(),
        penalty_roots: vec![CanonicalPenalty::from_dense_root(penalty_root, root_cols)],
        rho_mode: rho_mode.view(),
        rho_prior: RhoPrior::default(),
        firth_bias_reduction: true,
        trigger_skewness: 0.75,
    };

    let Err(err) = run_joint_beta_rho_sampling(&inputs, &cfg) else {
        panic!("Poisson joint HMC Firth should be rejected explicitly");
    };

    let expected = "Joint HMC with Firth requires a Binomial inverse link with a Fisher-weight jet";
    assert!(err.contains(expected), "unexpected error: {err}");
}
2782
#[test]
fn joint_hmc_uses_combined_penalty_logdet_for_overlapping_penalties() {
    // Two penalties act on both coefficients: S1 = I and S2 = diag(2, 1).
    // At rho = 0 the combined penalty is S = diag(3, 2), so the rho-gradient
    // of 0.5 * log|S| is 0.5 * tr(S^{-1} S_k): 5/12 for k = 1 and 7/12 for
    // k = 2. Per-penalty log-determinants would give different values.
    let design = array![[0.0, 0.0]];
    let response = array![0.0];
    let prior_weights = array![0.0];
    let beta_mode = array![0.0, 0.0];
    let hessian = array![[1.0, 0.0], [0.0, 1.0]];
    let rho_mode = array![0.0, 0.0];

    let identity_root = array![[1.0, 0.0], [0.0, 1.0]];
    let scaled_root = array![[2.0_f64.sqrt(), 0.0], [0.0, 1.0]];
    let penalties = vec![
        CanonicalPenalty::from_dense_root(identity_root, 2),
        CanonicalPenalty::from_dense_root(scaled_root, 2),
    ];
    let gaussian_identity = LikelihoodSpec {
        response: ResponseFamily::Gaussian,
        link: InverseLink::Standard(StandardLink::Identity),
    };

    let target = JointBetaRhoPosterior::new(
        design.view(),
        response.view(),
        prior_weights.view(),
        beta_mode.view(),
        hessian.view(),
        penalties,
        rho_mode.view(),
        gaussian_identity,
        None,
        RhoPrior::Flat,
        false,
    )
    .expect("joint target");

    // params = [beta_0, beta_1, rho_1, rho_2], all at zero.
    let params = Array1::<f64>::zeros(4);
    let (_, grad) = target.compute_joint_logp_and_grad(&params);

    assert!(
        (grad[2] - 5.0 / 12.0).abs() < 1.0e-10,
        "expected overlapping-penalty gradient 5/12, got {}",
        grad[2]
    );
    assert!(
        (grad[3] - 7.0 / 12.0).abs() < 1.0e-10,
        "expected overlapping-penalty gradient 7/12, got {}",
        grad[3]
    );
}
2827
#[test]
fn joint_hmc_target_does_not_depend_on_rho_mode_when_prior_is_fixed() {
    // With an explicit Normal prior on rho, two targets that differ only in
    // `rho_mode` must produce the same log density and gradient.
    let design = array![[0.0]];
    let response = array![0.0];
    let prior_weights = array![0.0];
    let beta_mode = array![0.0];
    let hessian = array![[1.0]];
    let penalty = CanonicalPenalty::from_dense_root(array![[1.0]], 1);
    let prior = RhoPrior::Normal {
        mean: 0.25,
        sd: 1.7,
    };
    let gaussian_identity = || LikelihoodSpec {
        response: ResponseFamily::Gaussian,
        link: InverseLink::Standard(StandardLink::Identity),
    };
    let rho_mode_a = array![0.0];
    let rho_mode_b = array![2.5];

    let target_a = JointBetaRhoPosterior::new(
        design.view(),
        response.view(),
        prior_weights.view(),
        beta_mode.view(),
        hessian.view(),
        vec![penalty.clone()],
        rho_mode_a.view(),
        gaussian_identity(),
        None,
        prior.clone(),
        false,
    )
    .expect("target a");
    let target_b = JointBetaRhoPosterior::new(
        design.view(),
        response.view(),
        prior_weights.view(),
        beta_mode.view(),
        hessian.view(),
        vec![penalty],
        rho_mode_b.view(),
        gaussian_identity(),
        None,
        prior,
        false,
    )
    .expect("target b");

    let params = array![0.0, -0.4];
    let (lp_a, grad_a) = target_a.compute_joint_logp_and_grad(&params);
    let (lp_b, grad_b) = target_b.compute_joint_logp_and_grad(&params);

    assert!((lp_a - lp_b).abs() < 1.0e-12);
    for (i, ga) in grad_a.iter().enumerate() {
        let gb = grad_b[i];
        assert!(
            (ga - gb).abs() < 1.0e-12,
            "rho_mode leaked into target gradient at {}: {} vs {}",
            i,
            ga,
            gb
        );
    }
}
2890
#[test]
fn joint_hmc_binomial_sas_uses_runtime_link_state() {
    // A non-trivial SAS state (epsilon = 0.4, log_delta = -0.2) must be used
    // by the joint likelihood. If the adaptive link collapsed to a plain
    // logit, the two log-likelihoods below would coincide.
    let sas_state =
        gam_solve::mixture_link::state_from_sasspec(gam_problem::types::SasLinkSpec {
            initial_epsilon: 0.4,
            initial_log_delta: -0.2,
        })
        .expect("sas state");
    let data = SharedData {
        x: Arc::new(array![[1.0], [1.0]]),
        y: Arc::new(array![1.0, 0.0]),
        weights: Arc::new(array![1.0, 1.0]),
        mode: Arc::new(Array1::zeros(1)),
        offset: None,
        gamma_shape: 1.0,
        dispersion: gam_solve::estimate::Dispersion::Known(1.0),
        n_samples: 2,
        dim: 1,
    };
    let eta = array![0.3, -0.2];

    // Binomial log-likelihood at `eta` under the supplied inverse link; the
    // gradient half of the result is discarded.
    let binomial_loglik = |link: InverseLink| {
        joint_family_logp_and_grad(
            &LikelihoodSpec {
                response: ResponseFamily::Binomial,
                link,
            },
            &data,
            &eta,
        )
        .map(|(loglik, _grad)| loglik)
    };

    let ll_sas = binomial_loglik(InverseLink::Sas(sas_state)).expect("sas joint logp");
    let ll_logit = binomial_loglik(InverseLink::Standard(StandardLink::Logit))
        .expect("logit joint logp");

    assert!(
        (ll_sas - ll_logit).abs() > 1.0e-6,
        "adaptive SAS link should not collapse to the logit likelihood"
    );
}
2939
#[test]
fn directional_cubic_diagnostic_is_rotation_invariant_for_hessian_eigenvectors() {
    // Rotating the coefficient basis by an orthogonal Q (X -> X Q,
    // H -> Q^T H Q) must leave the directional cubic diagnostic unchanged,
    // up to the sign and ordering of the per-direction values.
    fn sorted_magnitudes<'a>(values: impl Iterator<Item = &'a f64>) -> Vec<f64> {
        let mut mags: Vec<f64> = values.map(|v| v.abs()).collect();
        mags.sort_by(|a, b| a.partial_cmp(b).expect("finite compare"));
        mags
    }

    let design = array![[1.0, 0.5], [-0.3, 1.4], [0.8, -1.1]];
    let c = array![0.7, -0.5, 0.2];
    let hessian = array![[4.0, 0.0], [0.0, 1.0]];
    let (sin_t, cos_t) = std::f64::consts::FRAC_PI_4.sin_cos();
    let rotation = array![[cos_t, -sin_t], [sin_t, cos_t]];

    // Build the rotated problem before `design` is moved into the matrix.
    let design_rot = design.dot(&rotation);
    let hessian_rot = rotation.t().dot(&hessian).dot(&rotation);

    let (base_max, base_vals) = laplace_directional_cubic_diagnostic(
        &hessian,
        &DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(design)),
        &c,
        true,
    )
    .expect("base diagnostic");
    let (rot_max, rot_vals) = laplace_directional_cubic_diagnostic(
        &hessian_rot,
        &DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(design_rot)),
        &c,
        true,
    )
    .expect("rotated diagnostic");

    let base_abs = sorted_magnitudes(base_vals.iter());
    let rot_abs = sorted_magnitudes(rot_vals.iter());

    assert!((base_max - rot_max).abs() < 1.0e-10);
    for (i, &lhs) in base_abs.iter().enumerate() {
        let rhs = rot_abs[i];
        assert!(
            (lhs - rhs).abs() < 1.0e-10,
            "directional diagnostic changed under rotation at {}: {} vs {}",
            i,
            lhs,
            rhs
        );
    }
}
2981
#[test]
fn joint_hmc_penalty_logdet_agrees_with_reml_path() {
    use gam_solve::estimate::reml::penalty_logdet::PenaltyPseudologdet;

    // Two overlapping rank-2 penalty roots on a 3-coefficient basis.
    let cp1 = CanonicalPenalty::from_dense_root(array![[1.0, 0.5, 0.0], [0.0, 0.8, 0.3]], 3);
    let cp2 = CanonicalPenalty::from_dense_root(array![[0.0, 0.7, 0.0], [0.0, 0.0, 1.2]], 3);
    let lambdas = [2.5_f64, 0.8];
    let penalties = [cp1.clone(), cp2.clone()];

    // Reference: the REML pseudo-log-determinant and its rho-derivatives.
    let pld =
        PenaltyPseudologdet::from_penalties(&penalties, &lambdas, 0.0, 3).expect("reml pld");
    let reml_value = pld.value();
    let (reml_d1, reml_d2) = pld.rho_derivatives_from_penalties(&penalties, &lambdas);

    // Joint target with zero-weight data and beta at the origin. The test
    // expects its logp and rho-gradient to equal half the REML quantities.
    let design = Array2::<f64>::zeros((1, 3));
    let response = array![0.0];
    let prior_weights = array![0.0];
    let beta_mode = Array1::<f64>::zeros(3);
    let hessian = Array2::<f64>::eye(3);
    let rho: Array1<f64> = lambdas.iter().map(|lambda| lambda.ln()).collect();
    let target = JointBetaRhoPosterior::new(
        design.view(),
        response.view(),
        prior_weights.view(),
        beta_mode.view(),
        hessian.view(),
        vec![cp1, cp2],
        rho.view(),
        LikelihoodSpec {
            response: ResponseFamily::Gaussian,
            link: InverseLink::Standard(StandardLink::Identity),
        },
        None,
        RhoPrior::Flat,
        false,
    )
    .expect("joint target");

    // params = [beta (three zeros), rho_1, rho_2].
    let params: Array1<f64> = std::iter::repeat(0.0)
        .take(3)
        .chain(rho.iter().copied())
        .collect();
    let (logp, grad) = target.compute_joint_logp_and_grad(&params);

    let half_value = 0.5 * reml_value;
    assert!(
        (logp - half_value).abs() < 1.0e-8,
        "joint HMC logdet value {} vs REML 0.5*{} = {}",
        logp,
        reml_value,
        half_value,
    );

    for k in 0..2usize {
        let joint_d1 = grad[3 + k];
        let half_d1 = 0.5 * reml_d1[k];
        assert!(
            (joint_d1 - half_d1).abs() < 1.0e-8,
            "joint HMC logdet gradient[{}] = {} vs REML 0.5*{} = {}",
            k,
            joint_d1,
            reml_d1[k],
            half_d1,
        );
    }

    // The REML second derivative over rho must be symmetric.
    assert!(
        (reml_d2[[0, 1]] - reml_d2[[1, 0]]).abs() < 1.0e-12,
        "REML penalty logdet Hessian not symmetric"
    );
}
3064
#[test]
fn joint_hmc_family_gating_never_remaps() {
    // The joint likelihood must:
    // - evaluate every supported (family, link) pair as given;
    // - accept adaptive links with their own runtime state;
    // - reject unsupported families (Royston-Parmar) instead of silently
    //   accepting them.
    let data = SharedData {
        x: Arc::new(array![[1.0], [1.0]]),
        y: Arc::new(array![1.0, 0.0]),
        weights: Arc::new(array![1.0, 1.0]),
        mode: Arc::new(Array1::zeros(1)),
        offset: None,
        gamma_shape: 1.0,
        dispersion: gam_solve::estimate::Dispersion::Known(1.0),
        n_samples: 2,
        dim: 1,
    };
    let eta = array![0.1, -0.1];

    let accepted = [
        LikelihoodSpec {
            response: ResponseFamily::Binomial,
            link: InverseLink::Standard(StandardLink::Logit),
        },
        LikelihoodSpec {
            response: ResponseFamily::Binomial,
            link: InverseLink::Standard(StandardLink::Probit),
        },
        LikelihoodSpec {
            response: ResponseFamily::Binomial,
            link: InverseLink::Standard(StandardLink::CLogLog),
        },
        LikelihoodSpec {
            response: ResponseFamily::Gaussian,
            link: InverseLink::Standard(StandardLink::Identity),
        },
        LikelihoodSpec {
            response: ResponseFamily::Poisson,
            link: InverseLink::Standard(StandardLink::Log),
        },
        LikelihoodSpec {
            response: ResponseFamily::Gamma,
            link: InverseLink::Standard(StandardLink::Log),
        },
    ];
    for spec in &accepted {
        let result = joint_family_logp_and_grad(spec, &data, &eta);
        assert!(
            result.is_ok(),
            "spec {:?} should be accepted but got error: {:?}",
            spec,
            result.err(),
        );
    }

    // Adaptive links at the initial state epsilon = 0, log_delta = 0. The
    // same state is shared by both variants, so it is built once and cloned.
    let initial_state =
        gam_solve::mixture_link::state_from_sasspec(gam_problem::types::SasLinkSpec {
            initial_epsilon: 0.0,
            initial_log_delta: 0.0,
        })
        .expect("sas state");
    let adaptive = [
        LikelihoodSpec {
            response: ResponseFamily::Binomial,
            link: InverseLink::Sas(initial_state.clone()),
        },
        LikelihoodSpec {
            response: ResponseFamily::Binomial,
            link: InverseLink::BetaLogistic(initial_state),
        },
    ];
    for spec in &adaptive {
        let result = joint_family_logp_and_grad(spec, &data, &eta);
        // Report the actual error, as the supported-spec loop above does,
        // so a rejection here is diagnosable from the failure message.
        assert!(
            result.is_ok(),
            "adaptive spec {:?} should be accepted with its real link, but got error: {:?}",
            spec,
            result.err(),
        );
    }

    let rp_result = joint_family_logp_and_grad(
        &LikelihoodSpec {
            response: ResponseFamily::RoystonParmar,
            link: InverseLink::Standard(StandardLink::Logit),
        },
        &data,
        &eta,
    );
    assert!(
        rp_result.is_err(),
        "RoystonParmar should be rejected, not silently accepted"
    );
}
3168
#[test]
fn directional_cubic_power_iteration_finds_larger_or_equal_skewness() {
    // The power-iteration maximum returned by the diagnostic must be at least
    // the largest magnitude seen along the Hessian eigenvectors (up to
    // round-off).
    let design = array![
        [2.0, 1.0],
        [-1.0, 2.0],
        [0.5, -0.5],
        [1.5, 0.3],
        [-0.8, 1.7],
    ];
    let c = array![1.0, -0.5, 0.3, -0.7, 0.4];
    let hessian = array![[3.0, 1.0], [1.0, 2.0]];
    let dense_design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(design));

    let (max_val, eigenvector_vals) =
        laplace_directional_cubic_diagnostic(&hessian, &dense_design, &c, true)
            .expect("diagnostic");

    let eig_max = eigenvector_vals
        .iter()
        .map(|v| v.abs())
        .fold(0.0_f64, f64::max);
    assert!(
        max_val >= eig_max - 1.0e-12,
        "power iteration result {} should be >= eigenvector max {}",
        max_val,
        eig_max,
    );
}
3205
#[test]
fn laplace_trustworthiness_is_block_local_and_threshold_shrinks_with_n() {
    // Two directions: one nearly symmetric, one strongly skewed.
    let skew = array![0.01, 0.9];

    // At n = 100 only the skewed direction (index 1) may be flagged.
    let verdict = laplace_trustworthiness_from_skewness(&skew, 100.0);
    assert_eq!(
        verdict.untrustworthy_directions,
        vec![1],
        "only the strongly-skewed direction should be flagged (block-local)",
    );
    assert!(verdict.fallback_required());
    assert!((verdict.max_abs_skewness - 0.9).abs() < 1e-12);

    // The skewness threshold must become stricter as the sample size grows.
    let (t_small, t_large) = (
        laplace_skewness_threshold(25.0),
        laplace_skewness_threshold(10_000.0),
    );
    assert!(
        t_large < t_small,
        "validity threshold must tighten with sample size: {t_large} !< {t_small}",
    );

    // With n = 0 the threshold is infinite and no fallback is requested.
    let degenerate = laplace_trustworthiness_from_skewness(&skew, 0.0);
    assert!(!degenerate.fallback_required());
    assert!(degenerate.threshold.is_infinite());
}
3242
3243 struct AnharmonicBlock {
3249 lambdas: Array1<f64>,
3250 a: f64,
3251 }
3252 impl super::BlockExcessTarget for AnharmonicBlock {
3253 fn block_dim(&self) -> usize {
3254 self.lambdas.len()
3255 }
3256 fn rho_dim(&self) -> usize {
3257 self.lambdas.len()
3258 }
3259 fn block_curvatures(&self) -> &Array1<f64> {
3260 &self.lambdas
3261 }
3262 fn excess(&self, t: &Array1<f64>) -> f64 {
3263 self.a * t.iter().map(|&x| x.powi(4)).sum::<f64>()
3264 }
3265 fn excess_rho_gradient(&self, t: &Array1<f64>) -> Array1<f64> {
3266 t.mapv(|x| self.a * x.powi(4))
3267 }
3268 fn displaced_neg_score(&self, t: &Array1<f64>) -> Array1<f64> {
3269 assert_eq!(t.len(), self.block_dim(), "displacement dim mismatch");
3273 Array1::zeros(0)
3274 }
3275 fn base_neg_score(&self) -> Array1<f64> {
3276 Array1::zeros(0)
3277 }
3278 }
3279
#[test]
fn block_sampled_marginal_is_zero_for_gaussian_block() {
    // With no quartic term the block is Gaussian, so the sampled marginal
    // correction and its rho-gradient must vanish while draws are still taken.
    let gaussian = AnharmonicBlock {
        lambdas: array![2.0, 0.5],
        a: 0.0,
    };
    let correction = super::block_sampled_marginal_correction(&gaussian).expect("correction");

    assert!(
        correction.value.abs() < 1e-12,
        "Gaussian block value {}",
        correction.value
    );
    assert!(correction.rho_gradient.iter().all(|&g| g.abs() < 1e-12));
    assert!(correction.n_draws > 0);
}
3299
#[test]
fn block_sampled_marginal_recovers_analytic_quartic_correction() {
    // One direction with curvature lambda and excess a * t^4. The reference
    // is log E_{t ~ N(0, 1/lambda)}[exp(-a t^4)], computed by trapezoidal
    // quadrature over +/- 8 standard deviations.
    let lambda = 3.0_f64;
    let a = 0.05_f64;
    let quartic = AnharmonicBlock {
        lambdas: array![lambda],
        a,
    };
    let out = super::block_sampled_marginal_correction(&quartic).expect("correction");

    let sigma = (1.0 / lambda).sqrt();
    let steps = 20_001;
    let (lo, hi) = (-8.0 * sigma, 8.0 * sigma);
    let h = (hi - lo) / (steps as f64 - 1.0);
    let normaliser = sigma * (2.0 * std::f64::consts::PI).sqrt();
    let integral: f64 = (0..steps)
        .map(|i| {
            let tt = lo + h * i as f64;
            let gauss = (-(tt * tt) / (2.0 * sigma * sigma)).exp() / normaliser;
            // Trapezoid rule: half weight on the two end points.
            let w = if i == 0 || i == steps - 1 { 0.5 } else { 1.0 };
            w * gauss * (-a * tt.powi(4)).exp() * h
        })
        .sum();
    let reference = integral.ln();

    assert!(
        (out.value - reference).abs() < 5e-3,
        "sampled Δ_b {} vs reference {}",
        out.value,
        reference,
    );
    assert!(out.value < 0.0, "quartic penalty must shrink block mass");
}
3340
3341 struct MatvecBlock {
3349 lambdas: Array1<f64>,
3350 x: Array2<f64>,
3351 v_b: Array2<f64>,
3352 y: Array1<f64>,
3353 batched: bool,
3354 }
3355 impl MatvecBlock {
3356 fn s_of(&self, t: &Array1<f64>) -> Array1<f64> {
3357 let delta = self.v_b.dot(t);
3358 gam_linalg::faer_ndarray::fast_av(&self.x, &delta)
3359 }
3360 fn excess_and_ngs(&self, s: &Array1<f64>) -> (f64, Array1<f64>) {
3362 let mut excess = 0.0;
3363 let mut ngs = Array1::<f64>::zeros(s.len());
3364 for i in 0..s.len() {
3365 let mu = (self.y[i] + s[i]).tanh();
3366 excess += 0.5 * s[i] * s[i] - 0.1 * mu;
3367 ngs[i] = mu - self.y[i];
3368 }
3369 (excess, ngs)
3370 }
3371 }
3372 impl super::BlockExcessTarget for MatvecBlock {
3373 fn block_dim(&self) -> usize {
3374 self.lambdas.len()
3375 }
3376 fn rho_dim(&self) -> usize {
3377 self.lambdas.len()
3378 }
3379 fn block_curvatures(&self) -> &Array1<f64> {
3380 &self.lambdas
3381 }
3382 fn excess(&self, t: &Array1<f64>) -> f64 {
3383 self.excess_and_ngs(&self.s_of(t)).0
3384 }
3385 fn excess_rho_gradient(&self, t: &Array1<f64>) -> Array1<f64> {
3386 t.mapv(|x| 0.01 * x)
3387 }
3388 fn displaced_neg_score(&self, t: &Array1<f64>) -> Array1<f64> {
3389 self.excess_and_ngs(&self.s_of(t)).1
3390 }
3391 fn base_neg_score(&self) -> Array1<f64> {
3392 self.excess_and_ngs(&self.s_of(&Array1::zeros(self.block_dim())))
3393 .1
3394 }
3395 fn excess_with_displaced_neg_score_batch(
3396 &self,
3397 draws: &Array2<f64>,
3398 ) -> Vec<(f64, Option<Array1<f64>>)> {
3399 if !self.batched {
3400 let mut out = Vec::with_capacity(draws.ncols());
3402 let mut t = Array1::<f64>::zeros(draws.nrows());
3403 for s in 0..draws.ncols() {
3404 t.assign(&draws.column(s));
3405 out.push(self.excess_with_displaced_neg_score(&t));
3406 }
3407 return out;
3408 }
3409 let delta_all = gam_linalg::faer_ndarray::fast_ab(&self.v_b, draws);
3411 let s_all = gam_linalg::faer_ndarray::fast_ab(&self.x, &delta_all);
3412 (0..draws.ncols())
3413 .map(|c| {
3414 let (e, ngs) = self.excess_and_ngs(&s_all.column(c).to_owned());
3415 if e.is_finite() {
3416 (e, Some(ngs))
3417 } else {
3418 (e, None)
3419 }
3420 })
3421 .collect()
3422 }
3423 }
3424
#[test]
fn block_sampled_marginal_batched_matches_serial_matvec() {
    // A deterministic pseudo-random design and block basis. The batched and
    // per-draw evaluation paths must produce the same correction, gradient
    // and moments to ~1e-10 relative accuracy.
    let (n, p, m) = (80usize, 40usize, 3usize);
    let x = Array2::from_shape_fn((n, p), |(i, j)| {
        ((i * 7 + j * 13) % 11) as f64 * 0.05 - 0.25
    });
    let v_b = Array2::from_shape_fn((p, m), |(i, r)| ((i * 3 + r * 5) % 7) as f64 * 0.1 - 0.3);
    let y = Array1::from_shape_fn(n, |i| ((i % 5) as f64) * 0.2);
    let lambdas = array![2.0, 1.0, 0.5];

    let serial = super::block_sampled_marginal_correction(&MatvecBlock {
        lambdas: lambdas.clone(),
        x: x.clone(),
        v_b: v_b.clone(),
        y: y.clone(),
        batched: false,
    })
    .expect("serial");
    let batched = super::block_sampled_marginal_correction(&MatvecBlock {
        lambdas,
        x,
        v_b,
        y,
        batched: true,
    })
    .expect("batched");

    // Relative tolerance measured against the serial (reference) value.
    let within_tol = |reference: f64, candidate: f64| {
        (reference - candidate).abs() <= 1e-10 * (1.0 + reference.abs())
    };

    assert_eq!(serial.n_draws, batched.n_draws);
    assert!(
        within_tol(serial.value, batched.value),
        "value serial {} vs batched {}",
        serial.value,
        batched.value
    );
    for (k, &serial_g) in serial.rho_gradient.iter().enumerate() {
        let batched_g = batched.rho_gradient[k];
        assert!(
            within_tol(serial_g, batched_g),
            "rho_gradient[{k}] serial {} vs batched {}",
            serial_g,
            batched_g
        );
    }

    let ms = serial.moments.expect("serial moments");
    let mb = batched.moments.expect("batched moments");
    for (a, b) in ms.e_t.iter().zip(mb.e_t.iter()) {
        assert!(within_tol(*a, *b), "e_t {a} vs {b}");
    }
    for (a, b) in ms.e_neg_score.iter().zip(mb.e_neg_score.iter()) {
        assert!(within_tol(*a, *b), "e_neg_score {a} vs {b}");
    }
    for (a, b) in ms.e_t_neg_score.iter().zip(mb.e_t_neg_score.iter()) {
        assert!(within_tol(*a, *b), "e_t_neg_score {a} vs {b}");
    }
}
3499
#[test]
fn logit_pg_rao_blackwell_returns_finite_terms() {
    // One penalty whose root R satisfies R^T R = diag(0.2, 0.4), the penalty
    // matrix. The single Rao-Blackwell term must be finite and non-negative.
    let design = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
    let labels = array![1.0, 0.0, 1.0, 0.0];
    let prior_weights = array![1.0, 1.0, 1.0, 1.0];
    let penalty = array![[0.2, 0.0], [0.0, 0.4]];
    let beta_mode = array![0.0, 0.0];
    let roots = vec![array![[0.2_f64.sqrt(), 0.0], [0.0, 0.4_f64.sqrt()]]];
    let cfg = NutsConfig {
        seed: 789,
        n_samples: 30,
        nwarmup: 30,
        n_chains: 2,
        target_accept: 0.8,
    };

    let terms = super::estimate_logit_pg_rao_blackwell_terms(
        design.view(),
        labels.view(),
        prior_weights.view(),
        penalty.view(),
        beta_mode.view(),
        &roots,
        &cfg,
    )
    .expect("rao-blackwell PG should run");

    assert_eq!(terms.len(), 1);
    let term = terms[0];
    assert!(term.is_finite());
    assert!(term >= 0.0);
}
3531
#[test]
fn logit_pg_rao_blackwell_rejects_non_bernoulli_response() {
    // A proportion (0.25) is not a valid Bernoulli label for Polya-Gamma
    // augmentation and must be rejected explicitly.
    //
    // The sampler configuration is deliberately valid in every other respect.
    // A config such as `n_samples: 1` is itself invalid for the PG Gibbs
    // sampler, which rejects `n_samples < 4`. With such a config, the outcome
    // would depend on which validation runs first instead of pinning the
    // response check.
    let x = array![[1.0], [1.0]];
    let y = array![0.25, 1.0];
    let w = array![1.0, 1.0];
    let penalty = array![[0.1]];
    let mode = array![0.0];
    let roots = vec![array![[0.1_f64.sqrt()]]];
    let cfg = NutsConfig {
        n_samples: 20,
        nwarmup: 10,
        n_chains: 1,
        target_accept: 0.8,
        seed: 654,
    };

    let result = super::estimate_logit_pg_rao_blackwell_terms(
        x.view(),
        y.view(),
        w.view(),
        penalty.view(),
        mode.view(),
        &roots,
        &cfg,
    );

    let err = result
        .err()
        .expect("PG Rao-Blackwell should reject proportion rows");
    assert!(
        err.contains("response must be exactly 0 or 1"),
        "unexpected error: {err}"
    );
}
3566
#[test]
fn logit_pg_rao_blackwell_matches_beta_quadratic_moment_sanity() {
    // The single Rao-Blackwell term is compared with the plain Monte Carlo
    // mean of beta^T S beta over PG Gibbs draws from the same configuration.
    // The two must agree within a loose tolerance.
    let design = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
    let labels = array![1.0, 0.0, 1.0, 0.0];
    let prior_weights = array![1.0, 1.0, 1.0, 1.0];
    let penalty = array![[0.2, 0.0], [0.0, 0.4]];
    let beta_mode = array![0.0, 0.0];
    let roots = vec![array![[0.2_f64.sqrt(), 0.0], [0.0, 0.4_f64.sqrt()]]];
    let cfg = NutsConfig {
        seed: 901,
        n_samples: 120,
        nwarmup: 80,
        n_chains: 2,
        target_accept: 0.8,
    };

    let gibbs = run_logit_polya_gamma_gibbs(
        design.view(),
        labels.view(),
        prior_weights.view(),
        penalty.view(),
        beta_mode.view(),
        &cfg,
    )
    .expect("pg gibbs should run");
    let n_draws = gibbs.samples.nrows() as f64;
    let quad_sum: f64 = gibbs
        .samples
        .rows()
        .into_iter()
        .map(|beta| beta.dot(&penalty.dot(&beta)))
        .sum();
    let mc_quad = quad_sum / n_draws;

    let rb = super::estimate_logit_pg_rao_blackwell_terms(
        design.view(),
        labels.view(),
        prior_weights.view(),
        penalty.view(),
        beta_mode.view(),
        &roots,
        &cfg,
    )
    .expect("rao-blackwell PG should run");

    let diff = (rb[0] - mc_quad).abs();
    assert!(
        diff < 0.35,
        "Rao-Blackwell vs beta-moment mismatch too large: rb={}, mc={}, diff={}",
        rb[0],
        mc_quad,
        diff
    );
}
3623
#[test]
fn survival_hmc_structural_monotonic_returns_finitevalues() {
    // A single subject with a target event at exit, an empty penalty set and
    // the structural-monotonic flag enabled. The log density and its gradient
    // at the origin must both be finite.
    let (age_entry, age_exit) = (array![1.0], array![2.0]);
    let (event_target, event_competing) = (array![1u8], array![0u8]);
    let sampleweight = array![1.0];
    let x_entry = array![[1.0, 0.2]];
    let x_exit = array![[1.0, 0.6]];
    let x_derivative = array![[0.0, 1.0]];
    let beta_mode = array![0.0, 0.0];
    let hessian = Array2::<f64>::eye(2);

    let posterior = super::survival_hmc::SurvivalPosterior::new(
        age_entry.view(),
        age_exit.view(),
        event_target.view(),
        event_competing.view(),
        sampleweight.view(),
        x_entry.view(),
        x_exit.view(),
        x_derivative.view(),
        None,
        None,
        None,
        PenaltyBlocks::new(Vec::new()),
        SurvivalMonotonicityPenalty { tolerance: 3.0 },
        SurvivalSpec::Net,
        true,
        2,
        beta_mode.view(),
        hessian.view(),
    )
    .expect("construct survival posterior");

    let origin = Array1::<f64>::zeros(2);
    let mut grad = Array1::<f64>::zeros(2);
    let logp = HamiltonianTarget::logp_and_grad(&posterior, &origin, &mut grad);

    assert!(logp.is_finite());
    assert!(grad.iter().all(|v| v.is_finite()));
}
3667
#[test]
fn survival_hmc_structural_monotonic_differs_from_linear_geometry() {
    // The same one-subject problem is evaluated at z = (ln 2, 0) in two
    // configurations:
    // - the linear fallback (flag off, order 0);
    // - the structural monotonic geometry (flag on, order 2).
    // The first gradient component must differ, and both must be finite.
    let age_entry = array![1.0];
    let age_exit = array![2.0];
    let event_target = array![1u8];
    let event_competing = array![0u8];
    let sampleweight = array![1.0];
    let x_entry = array![[0.2, 0.1]];
    let x_exit = array![[0.6, 0.3]];
    let x_derivative = array![[1.0, 0.0]];
    let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
    let mode = array![0.0, 0.0];
    let hessian = Array2::<f64>::eye(2);
    let z = array![std::f64::consts::LN_2, 0.0];

    // Builds a posterior with the given geometry switches and returns its
    // gradient at `z`.
    let gradient_at_z = |structural: bool, order, label: &str| {
        let posterior = super::survival_hmc::SurvivalPosterior::new(
            age_entry.view(),
            age_exit.view(),
            event_target.view(),
            event_competing.view(),
            sampleweight.view(),
            x_entry.view(),
            x_exit.view(),
            x_derivative.view(),
            None,
            None,
            None,
            PenaltyBlocks::new(Vec::new()),
            monotonicity,
            SurvivalSpec::Net,
            structural,
            order,
            mode.view(),
            hessian.view(),
        )
        .expect(label);
        let mut grad = Array1::<f64>::zeros(2);
        HamiltonianTarget::logp_and_grad(&posterior, &z, &mut grad);
        grad
    };

    let grad_linear = gradient_at_z(false, 0, "construct linear posterior");
    let grad_struct = gradient_at_z(true, 2, "construct structural posterior");

    assert!(
        (grad_struct[0] - grad_linear[0]).abs() > 1e-6,
        "expected structural and linear fallback gradients to differ"
    );
    assert!(grad_struct[0].is_finite());
    assert!(grad_linear[0].is_finite());
}
3738
    /// On the linear-fallback path, derivative offsets of 0.0 and 2.0 both sit
    /// below the 3.0 monotonicity tolerance. Both posteriors must therefore
    /// return a non-finite log-density and an all-zero gradient.
    #[test]
    fn survival_hmc_fallback_barrier_rejects_offsets_below_monotonicity_threshold() {
        let age_entry = array![1.0];
        let age_exit = array![2.0];
        let event_target = array![1u8];
        let event_competing = array![0u8];
        let sampleweight = array![1.0];
        let x_entry = array![[1.0, 0.0]];
        let x_exit = array![[1.0, 0.0]];
        // All-zero derivative design: at z = 0 any derivative signal presumably
        // comes from the offset argument alone — TODO confirm.
        let x_derivative = array![[0.0, 0.0]];
        let penalties = PenaltyBlocks::new(Vec::new());
        let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
        let mode = array![0.0, 0.0];
        let hessian = Array2::<f64>::eye(2);
        let z = array![0.0, 0.0];

        // NOTE(review): the third optional argument is treated by these tests as
        // a per-row derivative offset (per the test name); confirm the position.
        let posterior_no_offset = super::survival_hmc::SurvivalPosterior::new(
            age_entry.view(),
            age_exit.view(),
            event_target.view(),
            event_competing.view(),
            sampleweight.view(),
            x_entry.view(),
            x_exit.view(),
            x_derivative.view(),
            None,
            None,
            Some(array![0.0].view()),
            penalties.clone(),
            monotonicity,
            SurvivalSpec::Net,
            false,
            0,
            mode.view(),
            hessian.view(),
        )
        .expect("construct posterior without derivative offset");
        let mut grad_no_offset = Array1::<f64>::zeros(2);
        let logp_no_offset =
            HamiltonianTarget::logp_and_grad(&posterior_no_offset, &z, &mut grad_no_offset);

        // Offset 2.0 is still below the 3.0 tolerance.
        let posteriorwith_offset = super::survival_hmc::SurvivalPosterior::new(
            age_entry.view(),
            age_exit.view(),
            event_target.view(),
            event_competing.view(),
            sampleweight.view(),
            x_entry.view(),
            x_exit.view(),
            x_derivative.view(),
            None,
            None,
            Some(array![2.0].view()),
            penalties,
            monotonicity,
            SurvivalSpec::Net,
            false,
            0,
            mode.view(),
            hessian.view(),
        )
        .expect("construct posterior with derivative offset");
        let mut gradwith_offset = Array1::<f64>::zeros(2);
        let logpwith_offset =
            HamiltonianTarget::logp_and_grad(&posteriorwith_offset, &z, &mut gradwith_offset);

        // Barrier rejects both: non-finite log-density and a zeroed gradient.
        assert!(!logp_no_offset.is_finite());
        assert!(!logpwith_offset.is_finite());
        assert!(grad_no_offset.iter().all(|v| *v == 0.0));
        assert!(gradwith_offset.iter().all(|v| *v == 0.0));
    }
3811
    /// Companion to the rejection test. With the same setup, an offset of 2.0
    /// (below the 3.0 tolerance) is still rejected, while 3.1 (above it) yields
    /// a finite log-density and a finite gradient.
    #[test]
    fn survival_hmc_fallback_barrier_becomes_finite_once_offset_clears_guard() {
        let age_entry = array![1.0];
        let age_exit = array![2.0];
        let event_target = array![1u8];
        let event_competing = array![0u8];
        let sampleweight = array![1.0];
        let x_entry = array![[1.0, 0.0]];
        let x_exit = array![[1.0, 0.0]];
        let x_derivative = array![[0.0, 0.0]];
        let penalties = PenaltyBlocks::new(Vec::new());
        let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
        let mode = array![0.0, 0.0];
        let hessian = Array2::<f64>::eye(2);
        let z = array![0.0, 0.0];

        // Offset 2.0 < tolerance 3.0: expected to hit the barrier.
        let posterior_below_guard = super::survival_hmc::SurvivalPosterior::new(
            age_entry.view(),
            age_exit.view(),
            event_target.view(),
            event_competing.view(),
            sampleweight.view(),
            x_entry.view(),
            x_exit.view(),
            x_derivative.view(),
            None,
            None,
            Some(array![2.0].view()),
            penalties.clone(),
            monotonicity,
            SurvivalSpec::Net,
            false,
            0,
            mode.view(),
            hessian.view(),
        )
        .expect("construct posterior below derivative guard");
        let mut grad_below_guard = Array1::<f64>::zeros(2);
        let logp_below_guard =
            HamiltonianTarget::logp_and_grad(&posterior_below_guard, &z, &mut grad_below_guard);

        // Offset 3.1 > tolerance 3.0: expected to pass the barrier.
        let posterior_above_guard = super::survival_hmc::SurvivalPosterior::new(
            age_entry.view(),
            age_exit.view(),
            event_target.view(),
            event_competing.view(),
            sampleweight.view(),
            x_entry.view(),
            x_exit.view(),
            x_derivative.view(),
            None,
            None,
            Some(array![3.1].view()),
            penalties,
            monotonicity,
            SurvivalSpec::Net,
            false,
            0,
            mode.view(),
            hessian.view(),
        )
        .expect("construct posterior above derivative guard");
        let mut grad_above_guard = Array1::<f64>::zeros(2);
        let logp_above_guard =
            HamiltonianTarget::logp_and_grad(&posterior_above_guard, &z, &mut grad_above_guard);

        assert!(!logp_below_guard.is_finite());
        assert!(logp_above_guard.is_finite());
        assert!(grad_below_guard.iter().all(|v| *v == 0.0));
        assert!(grad_above_guard.iter().all(|v| v.is_finite()));
    }
3883
    /// The structural-monotonic posterior on a two-row, three-column problem
    /// with a non-zero mode must produce a finite log-density and gradient.
    #[test]
    fn survival_hmc_structural_monotonic_handles_sparse_multirow_geometry() {
        let age_entry = array![1.0, 1.2];
        let age_exit = array![2.0, 2.4];
        let event_target = array![1u8, 1u8];
        let event_competing = array![0u8, 0u8];
        let sampleweight = array![1.0, 1.0];
        let x_entry = array![[0.1, 0.0, 0.2], [0.2, 0.1, 0.2]];
        let x_exit = array![[0.4, 0.2, 0.3], [0.6, 0.1, 0.3]];
        // Derivative design with several zero entries (the "sparse" geometry).
        let x_derivative = array![[1.0, 0.0, 0.0], [0.5, 1.0, 0.0]];
        let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
        let mode = array![4.0, 2.0, 0.0];
        let hessian = Array2::<f64>::eye(3);
        let z = array![0.05, -0.1, 0.15];

        // NOTE(review): `true, 2` presumably selects the structural monotonic
        // path — confirm against `SurvivalPosterior::new`.
        let posterior = super::survival_hmc::SurvivalPosterior::new(
            age_entry.view(),
            age_exit.view(),
            event_target.view(),
            event_competing.view(),
            sampleweight.view(),
            x_entry.view(),
            x_exit.view(),
            x_derivative.view(),
            None,
            None,
            None,
            PenaltyBlocks::new(Vec::new()),
            monotonicity,
            SurvivalSpec::Net,
            true,
            2,
            mode.view(),
            hessian.view(),
        )
        .expect("construct structural posterior");

        let mut grad = Array1::<f64>::zeros(3);
        let logp = HamiltonianTarget::logp_and_grad(&posterior, &z, &mut grad);
        assert!(logp.is_finite());
        assert!(grad.iter().all(|v| v.is_finite()));
    }
3927}
3928
3929impl HamiltonianTarget<Array1<f64>> for NutsPosterior {
3931 fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
3932 NUTS_RESIDUAL_SCRATCH.with(|scratch| {
3933 let mut residual = scratch.borrow_mut();
3934 if residual.len() != self.data.n_samples {
3935 *residual = Array1::<f64>::zeros(self.data.n_samples);
3936 }
3937 self.compute_logp_and_grad_nd_into(position, &mut residual, grad)
3938 })
3939 }
3940}
3941
/// Run-length and tuning settings shared by the NUTS and Pólya-Gamma Gibbs
/// samplers in this module.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NutsConfig {
    /// Number of retained draws per chain.
    pub n_samples: usize,
    /// Number of warmup iterations per chain; these are run and then discarded.
    pub nwarmup: usize,
    /// Number of independent chains.
    pub n_chains: usize,
    /// Requested acceptance target. `validate_nuts_target_accept` requires a
    /// finite value in (0, 1); NUTS runs further clamp it through
    /// `robust_target_accept`.
    pub target_accept: f64,
    /// Base RNG seed; `default_nuts_seed` (42) fills it when missing from
    /// serialized input.
    #[serde(default = "default_nuts_seed")]
    pub seed: u64,
}
3957
/// Serde default for `NutsConfig::seed`: the fixed seed `42`.
fn default_nuts_seed() -> u64 {
    const DEFAULT_SEED: u64 = 42;
    DEFAULT_SEED
}
3961
3962fn validate_nuts_target_accept(target_accept: f64) -> Result<(), HmcError> {
3963 if target_accept.is_finite() && target_accept > 0.0 && target_accept < 1.0 {
3964 Ok(())
3965 } else {
3966 Err(HmcError::InvalidConfig {
3967 reason: format!(
3968 "NUTS target_accept must be finite and lie in (0, 1), got {target_accept}"
3969 ),
3970 })
3971 }
3972}
3973
/// Minimum retained draws per chain, enforced by `validate_nuts_draws` so
/// that the split-R-hat / ESS diagnostics are defined.
const MIN_NUTS_SAMPLES: usize = 4;
3982
/// Minimum chain count, enforced by `validate_nuts_draws`. With zero chains
/// the sampler would have no initial positions.
const MIN_NUTS_CHAINS: usize = 1;
3992
3993fn validate_nuts_draws(config: &NutsConfig) -> Result<(), HmcError> {
3998 if config.n_chains < MIN_NUTS_CHAINS {
3999 return Err(HmcError::InvalidConfig {
4000 reason: format!(
4001 "NUTS n_chains must be >= {MIN_NUTS_CHAINS}; with zero chains the \
4002 sampler has no initial positions to run, got {}",
4003 config.n_chains
4004 ),
4005 });
4006 }
4007 if config.n_samples < MIN_NUTS_SAMPLES {
4008 return Err(HmcError::InvalidConfig {
4009 reason: format!(
4010 "NUTS n_samples must be >= {MIN_NUTS_SAMPLES} so split-R-hat / ESS \
4011 diagnostics are defined, got {}",
4012 config.n_samples
4013 ),
4014 });
4015 }
4016 Ok(())
4017}
4018
4019pub(crate) fn validate_nuts_config(config: &NutsConfig) -> Result<(), HmcError> {
4023 validate_nuts_target_accept(config.target_accept)?;
4024 validate_nuts_draws(config)?;
4025 Ok(())
4026}
4027
4028#[inline]
4029fn splitmix64(x: u64) -> u64 {
4030 gam_linalg::utils::splitmix64_hash(x)
4031}
4032
4033#[inline]
4034fn chain_stream_seed(seed: u64, chain: usize, stream: u64) -> u64 {
4035 splitmix64(seed ^ stream ^ ((chain as u64).wrapping_mul(0xD1B5_4A32_D192_ED03)))
4036}
4037
4038#[inline]
4039fn nuts_transition_seed(seed: u64, stream: u64) -> u64 {
4040 splitmix64(seed ^ stream ^ 0xA24B_AED4_963E_E407)
4041}
4042
4043#[inline]
4044fn gibbs_pg_seed(seed: u64, chain: usize, stream: u64, iter: usize) -> u64 {
4045 chain_stream_seed(
4046 seed,
4047 chain,
4048 stream ^ ((iter as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15)),
4049 )
4050}
4051
4052fn draw_logit_pg1_omega(
4053 shapes: ArrayView1<'_, u32>,
4054 tilts: ArrayView1<'_, f64>,
4055 seed: u64,
4056 out: &mut Array1<f64>,
4057) -> Result<(), String> {
4058 if out.len() != tilts.len() {
4059 return Err(HmcError::DimensionMismatch {
4060 reason: "draw_logit_pg1_omega: output length mismatch".to_string(),
4061 }
4062 .into());
4063 }
4064 let draws = crate::gpu_polya_gamma::draw_batch(PolyaGammaBatchInput {
4065 shapes,
4066 tilts,
4067 seed: PgSeed(seed),
4068 })?;
4069 out.assign(&draws);
4070 out.mapv_inplace(|v| v.max(1.0e-12));
4071 Ok(())
4072}
4073
/// Parameter count above which sampler tuning switches to its
/// high-dimensional settings (acceptance floor and mass-matrix regularisation).
const HIGH_DIM_THRESHOLD: usize = 50;

/// Acceptance-target floor applied by `robust_target_accept` when
/// `dim > HIGH_DIM_THRESHOLD`.
const HIGH_DIM_TARGET_ACCEPT_FLOOR: f64 = 0.92;
/// Acceptance-target floor applied by `robust_target_accept` otherwise.
const LOW_DIM_TARGET_ACCEPT_FLOOR: f64 = 0.90;
/// Upper cap applied after the floor in `robust_target_accept`.
const MAX_TARGET_ACCEPT: f64 = 0.95;

/// Warmup lengths below this disable mass-matrix adaptation
/// (`NUTSMassMatrixConfig::disabled()`).
const MIN_WARMUP_FOR_MASS_ADAPT: usize = 80;

/// Passed through as `NUTSMassMatrixConfig::dense_max_dim`. Its meaning is
/// defined by `general_mcmc`, presumably the largest dimension allowed a
/// dense mass matrix — TODO confirm.
const DENSE_MASS_MATRIX_MAX_DIM: usize = 75;

/// `regularize` for the generic mass-matrix config in high dimensions.
const MASS_REGULARIZE_HIGH_DIM: f64 = 0.14;
/// `regularize` for the generic mass-matrix config otherwise.
const MASS_REGULARIZE_LOW_DIM: f64 = 0.10;
/// `regularize` for the survival mass-matrix config in high dimensions.
const SURVIVAL_MASS_REGULARIZE_HIGH_DIM: f64 = 0.18;
/// `regularize` for the survival mass-matrix config otherwise.
const SURVIVAL_MASS_REGULARIZE_LOW_DIM: f64 = 0.12;

/// `jitter` used by every adapted mass-matrix config.
const MASS_MATRIX_JITTER: f64 = 1e-5;
4115
4116#[inline]
4117fn robust_target_accept(requested: f64, dim: usize) -> f64 {
4118 let floor = if dim > HIGH_DIM_THRESHOLD {
4119 HIGH_DIM_TARGET_ACCEPT_FLOOR
4120 } else {
4121 LOW_DIM_TARGET_ACCEPT_FLOOR
4122 };
4123 requested.max(floor).min(MAX_TARGET_ACCEPT)
4124}
4125
4126fn jittered_initial_positions(
4127 config: &NutsConfig,
4128 dim: usize,
4129 scale: f64,
4130 stream: u64,
4131) -> Vec<Array1<f64>> {
4132 (0..config.n_chains)
4133 .map(|chain| {
4134 let mut rng = StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, stream));
4135 Array1::from_shape_fn(dim, |_| sample_standard_normal(&mut rng) * scale)
4136 })
4137 .collect()
4138}
4139
4140fn robust_mass_matrix_config(dim: usize, nwarmup: usize) -> NUTSMassMatrixConfig {
4141 if nwarmup < MIN_WARMUP_FOR_MASS_ADAPT {
4142 return NUTSMassMatrixConfig::disabled();
4143 }
4144 let start_buffer = (nwarmup / 8).clamp(35, 180);
4145 let end_buffer = (nwarmup / 5).clamp(50, 250);
4146 let initial_window = (nwarmup / 20).clamp(10, 60);
4147 NUTSMassMatrixConfig {
4148 adaptation: MassMatrixAdaptation::Diagonal,
4149 start_buffer,
4150 end_buffer,
4151 initial_window,
4152 regularize: if dim > HIGH_DIM_THRESHOLD {
4153 MASS_REGULARIZE_HIGH_DIM
4154 } else {
4155 MASS_REGULARIZE_LOW_DIM
4156 },
4157 jitter: MASS_MATRIX_JITTER,
4158 dense_max_dim: DENSE_MASS_MATRIX_MAX_DIM,
4159 }
4160}
4161
4162fn robust_survival_mass_matrix_config(dim: usize, nwarmup: usize) -> NUTSMassMatrixConfig {
4163 if nwarmup < MIN_WARMUP_FOR_MASS_ADAPT {
4164 return NUTSMassMatrixConfig::disabled();
4165 }
4166 let start_buffer = (nwarmup / 7).clamp(40, 200);
4169 let end_buffer = (nwarmup / 4).clamp(60, 280);
4170 let initial_window = (nwarmup / 20).clamp(10, 60);
4171 NUTSMassMatrixConfig {
4172 adaptation: MassMatrixAdaptation::Diagonal,
4173 start_buffer,
4174 end_buffer,
4175 initial_window,
4176 regularize: if dim > HIGH_DIM_THRESHOLD {
4177 SURVIVAL_MASS_REGULARIZE_HIGH_DIM
4178 } else {
4179 SURVIVAL_MASS_REGULARIZE_LOW_DIM
4180 },
4181 jitter: MASS_MATRIX_JITTER,
4182 dense_max_dim: DENSE_MASS_MATRIX_MAX_DIM,
4183 }
4184}
4185
4186impl Default for NutsConfig {
4187 fn default() -> Self {
4188 Self {
4189 n_samples: 1000,
4190 nwarmup: 500,
4191 n_chains: 4,
4192 target_accept: 0.9,
4193 seed: 42,
4194 }
4195 }
4196}
4197
4198impl NutsConfig {
4199 pub fn for_dimension(n_params: usize) -> Self {
4207 let effective_autocorr = (n_params as f64).sqrt().max(1.0);
4209
4210 let target_ess = 100 * n_params;
4212
4213 let raw_samples = (target_ess as f64 * (1.0 + 2.0 * effective_autocorr) * 1.5) as usize;
4215
4216 let n_samples = raw_samples.clamp(500, 10_000);
4218
4219 let nwarmup = n_samples;
4221
4222 let n_chains = if n_params > 50 { 4 } else { 2 };
4224
4225 Self {
4226 n_samples,
4227 nwarmup,
4228 n_chains,
4229 target_accept: 0.9,
4230 seed: 42,
4231 }
4232 }
4233}
4234
/// Pooled posterior draws plus summary statistics from one sampler run.
#[derive(Clone, Debug)]
pub struct NutsResult {
    /// Draws in coefficient space, one row per draw. Chains are stacked in
    /// order, so row `chain * n_samples + draw` holds that draw.
    pub samples: Array2<f64>,
    /// Column means of `samples`; a caller-supplied fallback when there are no rows.
    pub posterior_mean: Array1<f64>,
    /// Column standard deviations of `samples` (divisor `n`, i.e. `ddof = 0`).
    pub posterior_std: Array1<f64>,
    /// Split R-hat diagnostic (normally from `compute_split_rhat_and_ess`).
    pub rhat: f64,
    /// Effective sample size diagnostic (normally from `compute_split_rhat_and_ess`).
    pub ess: f64,
    /// Whether `rhat` / `ess` met the run's convergence thresholds.
    pub converged: bool,
}
4251
/// Pass/fail thresholds applied to a run's R-hat and ESS diagnostics.
#[derive(Clone, Copy)]
struct NutsConvergenceThresholds {
    /// `rhat` must be strictly below this value.
    max_rhat: f64,
    /// When `Some`, `ess` must strictly exceed this value; `None` skips the ESS check.
    min_ess: Option<f64>,
}
4257
4258impl NutsConvergenceThresholds {
4259 #[inline]
4260 fn converged(self, rhat: f64, ess: f64) -> bool {
4261 let rhat_ok = rhat < self.max_rhat;
4262 match self.min_ess {
4263 Some(min_ess) => rhat_ok && ess > min_ess,
4264 None => rhat_ok,
4265 }
4266 }
4267}
4268
4269fn run_whitened_nuts_samples<Target>(
4270 target: Target,
4271 initial_positions: Vec<Array1<f64>>,
4272 config: &NutsConfig,
4273 dim: usize,
4274 mass_cfg: NUTSMassMatrixConfig,
4275 transition_seed_stream: u64,
4276 sampling_error_label: &str,
4277) -> Result<(Array3<f64>, String), String>
4278where
4279 Target: HamiltonianTarget<Array1<f64>> + Sync + Send,
4280{
4281 let mut sampler = GenericNUTS::new_with_mass_matrix(
4282 target,
4283 initial_positions,
4284 robust_target_accept(config.target_accept, dim),
4285 mass_cfg,
4286 )
4287 .set_seed(nuts_transition_seed(config.seed, transition_seed_stream));
4288
4289 let (samples_array, run_stats) = sampler
4290 .run_progress(config.n_samples, config.nwarmup)
4291 .map_err(|e| format!("{sampling_error_label}: {e}"))?;
4292 Ok((samples_array, run_stats.to_string()))
4293}
4294
4295fn unwhiten_samples(
4296 samples_array: &Array3<f64>,
4297 mode: &Array1<f64>,
4298 chol: &Array2<f64>,
4299 dim: usize,
4300 z_start: usize,
4301) -> Array2<f64> {
4302 let shape = samples_array.shape();
4303 let n_chains = shape[0];
4304 let n_samples_out = shape[1];
4305 let total_samples = n_chains * n_samples_out;
4306
4307 let mut samples = Array2::<f64>::zeros((total_samples, dim));
4308 let mut z_buffer = Array1::<f64>::zeros(dim);
4309 for chain in 0..n_chains {
4310 for sample_i in 0..n_samples_out {
4311 let zview = samples_array.slice(ndarray::s![chain, sample_i, z_start..z_start + dim]);
4312 z_buffer.assign(&zview);
4313 let beta = mode + &chol.dot(&z_buffer);
4314 let sample_idx = chain * n_samples_out + sample_i;
4315 samples.row_mut(sample_idx).assign(&beta);
4316 }
4317 }
4318
4319 samples
4320}
4321
4322fn summarize_unwhitened_nuts_samples(
4323 samples: Array2<f64>,
4324 samples_array: &Array3<f64>,
4325 empty_mean: Array1<f64>,
4326 convergence: NutsConvergenceThresholds,
4327) -> NutsResult {
4328 let posterior_mean = samples.mean_axis(Axis(0)).unwrap_or(empty_mean);
4329 let posterior_std = samples.std_axis(Axis(0), 0.0);
4330 let (rhat, ess) = compute_split_rhat_and_ess(samples_array);
4331 let converged = convergence.converged(rhat, ess);
4332
4333 NutsResult {
4334 samples,
4335 posterior_mean,
4336 posterior_std,
4337 rhat,
4338 ess,
4339 converged,
4340 }
4341}
4342
4343fn run_whitened_nuts_result<Target>(
4344 target: Target,
4345 mode: &Array1<f64>,
4346 chol: &Array2<f64>,
4347 initial_positions: Vec<Array1<f64>>,
4348 config: &NutsConfig,
4349 dim: usize,
4350 mass_cfg: NUTSMassMatrixConfig,
4351 transition_seed_stream: u64,
4352 sampling_error_label: &str,
4353 empty_mean: Array1<f64>,
4354 convergence: NutsConvergenceThresholds,
4355) -> Result<(NutsResult, String), String>
4356where
4357 Target: HamiltonianTarget<Array1<f64>> + Sync + Send,
4358{
4359 let (samples_array, run_stats) = run_whitened_nuts_samples(
4360 target,
4361 initial_positions,
4362 config,
4363 dim,
4364 mass_cfg,
4365 transition_seed_stream,
4366 sampling_error_label,
4367 )?;
4368 let samples = unwhiten_samples(&samples_array, mode, chol, dim, 0);
4369 let result =
4370 summarize_unwhitened_nuts_samples(samples, &samples_array, empty_mean, convergence);
4371 Ok((result, run_stats))
4372}
4373
4374impl NutsResult {
4375 pub fn posterior_mean_of<F>(&self, f: F) -> f64
4378 where
4379 F: Fn(ArrayView1<f64>) -> f64 + Sync,
4380 {
4381 let n = self.samples.nrows();
4382 if n == 0 {
4383 return 0.0;
4384 }
4385 use rayon::iter::{IntoParallelIterator, ParallelIterator};
4388 let sum: f64 = (0..n).into_par_iter().map(|i| f(self.samples.row(i))).sum();
4389 sum / n as f64
4390 }
4391
4392 pub fn posterior_interval_of<F>(&self, f: F, lower_pct: f64, upper_pct: f64) -> (f64, f64)
4394 where
4395 F: Fn(ArrayView1<f64>) -> f64,
4396 {
4397 let n = self.samples.nrows();
4398 if n == 0 {
4399 return (0.0, 0.0);
4400 }
4401 let mut values: Vec<f64> = (0..n).map(|i| f(self.samples.row(i))).collect();
4402 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
4403
4404 let lower_idx = ((lower_pct / 100.0) * n as f64).floor() as usize;
4405 let upper_idx = ((upper_pct / 100.0) * n as f64).ceil() as usize;
4406
4407 (
4408 values[lower_idx.min(n.saturating_sub(1))],
4409 values[upper_idx.min(n.saturating_sub(1))],
4410 )
4411 }
4412}
4413
4414#[inline]
4415fn sample_standard_normal<R: rand::Rng + ?Sized>(rng: &mut R) -> f64 {
4416 let u1 = rng.random::<f64>().max(1e-16);
4417 let u2 = rng.random::<f64>();
4418 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
4419}
4420
4421pub fn run_logit_polya_gamma_gibbs(
4430 x: ArrayView2<f64>,
4431 y: ArrayView1<f64>,
4432 weights: ArrayView1<f64>,
4433 penalty_matrix: ArrayView2<f64>,
4434 mode: ArrayView1<f64>,
4435 config: &NutsConfig,
4436) -> Result<NutsResult, String> {
4437 let n = x.nrows();
4438 let p = x.ncols();
4439 if y.len() != n || weights.len() != n {
4440 return Err(HmcError::DimensionMismatch {
4441 reason: "run_logit_polya_gamma_gibbs: input length mismatch".to_string(),
4442 }
4443 .into());
4444 }
4445 if mode.len() != p || penalty_matrix.nrows() != p || penalty_matrix.ncols() != p {
4446 return Err(HmcError::DimensionMismatch {
4447 reason: "run_logit_polya_gamma_gibbs: coefficient/penalty dimension mismatch"
4448 .to_string(),
4449 }
4450 .into());
4451 }
4452 if !weights.iter().all(|w| (*w - 1.0).abs() <= 1e-10) {
4453 return Err(HmcError::InvalidConfig {
4454 reason: "run_logit_polya_gamma_gibbs requires unit weights (PG(1,·)); use NUTS for non-unit weights".to_string(),
4455 }
4456 .into());
4457 }
4458 validate_binary_responses("run_logit_polya_gamma_gibbs", &y, &weights).map_err(String::from)?;
4459 validate_nuts_config(config).map_err(String::from)?;
4466
4467 let n_iter = config.nwarmup + config.n_samples;
4468
4469 let kappa = y.mapv(|v| v - 0.5);
4471 let rhs_b = fast_atv(&x, &kappa);
4472
4473 let mut samples_array = Array3::<f64>::zeros((config.n_chains, config.n_samples, p));
4474 let mut eta = Array1::<f64>::zeros(n);
4475 let mut omega = Array1::<f64>::ones(n);
4476 let pg_shapes = Array1::<u32>::from_elem(n, 1);
4477 let mut xw = x.to_owned();
4478 let mut xt_omega_x = Array2::<f64>::zeros((p, p));
4479 let penalty = penalty_matrix.to_owned();
4480 let mut q = Array2::<f64>::zeros((p, p));
4481 let mut mean = Array1::<f64>::zeros(p);
4482 let mut z = Array1::<f64>::zeros(p);
4483 let mut noise = Array1::<f64>::zeros(p);
4484
4485 for chain in 0..config.n_chains {
4486 let mut init_rng =
4487 StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0xB3C4_5A1F_8E9D_7632));
4488 let mut draw_rng =
4489 StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0x17A9_26D5_4C1B_E083));
4490 let mut beta = mode.to_owned();
4491 for j in 0..p {
4493 beta[j] += 0.05 * sample_standard_normal(&mut init_rng);
4494 }
4495
4496 for iter in 0..n_iter {
4497 eta.assign(&gam_linalg::faer_ndarray::fast_av(&x, &beta));
4498 draw_logit_pg1_omega(
4499 pg_shapes.view(),
4500 eta.view(),
4501 gibbs_pg_seed(config.seed, chain, 0x4D94_DF4E_5D72_81AB, iter),
4502 &mut omega,
4503 )?;
4504
4505 ndarray::Zip::indexed(xw.rows_mut())
4508 .and(x.rows())
4509 .and(&omega)
4510 .par_for_each(|_idx, mut xw_row, x_row, omega_i| {
4511 let s = omega_i.sqrt();
4512 for j in 0..p {
4513 xw_row[j] = x_row[j] * s;
4514 }
4515 });
4516 fast_ata_into(&xw, &mut xt_omega_x);
4517
4518 q.assign(&penalty);
4519 q += &xt_omega_x;
4520
4521 let factor = q
4523 .cholesky(Side::Lower)
4524 .map_err(|e| format!("PG Gibbs failed to factor Q: {:?}", e))?;
4525 mean.assign(&factor.solvevec(&rhs_b));
4526
4527 for j in 0..p {
4528 z[j] = sample_standard_normal(&mut draw_rng);
4529 }
4530 let l = factor.lower_triangular();
4531 back_substitution_lower_transpose_guarded_into(&l, &z, &mut noise);
4532 beta.assign(&(&mean + &noise));
4533
4534 if iter >= config.nwarmup {
4535 let keep_idx = iter - config.nwarmup;
4536 samples_array
4537 .slice_mut(ndarray::s![chain, keep_idx, ..])
4538 .assign(&beta);
4539 }
4540 }
4541 }
4542
4543 let total_samples = config.n_chains * config.n_samples;
4544 let mut samples = Array2::<f64>::zeros((total_samples, p));
4545 for chain in 0..config.n_chains {
4546 for s in 0..config.n_samples {
4547 let idx = chain * config.n_samples + s;
4548 samples
4549 .row_mut(idx)
4550 .assign(&samples_array.slice(ndarray::s![chain, s, ..]));
4551 }
4552 }
4553
4554 let posterior_mean = samples
4555 .mean_axis(Axis(0))
4556 .unwrap_or_else(|| Array1::zeros(p));
4557 let posterior_std = samples.std_axis(Axis(0), 0.0);
4558 let (rhat, ess) = if config.n_chains >= 2 && config.n_samples >= 4 {
4559 compute_split_rhat_and_ess(&samples_array)
4560 } else {
4561 (1.0, (total_samples as f64) * 0.5)
4562 };
4563 let converged = rhat < 1.1 && ess > 100.0;
4564
4565 Ok(NutsResult {
4566 samples,
4567 posterior_mean,
4568 posterior_std,
4569 rhat,
4570 ess,
4571 converged,
4572 })
4573}
4574
4575pub fn estimate_logit_pg_rao_blackwell_terms(
4585 x: ArrayView2<f64>,
4586 y: ArrayView1<f64>,
4587 weights: ArrayView1<f64>,
4588 penalty_matrix: ArrayView2<f64>,
4589 mode: ArrayView1<f64>,
4590 penalty_roots: &[Array2<f64>],
4591 config: &NutsConfig,
4592) -> Result<Array1<f64>, String> {
4593 let n = x.nrows();
4594 let p = x.ncols();
4595 if y.len() != n || weights.len() != n {
4596 return Err(HmcError::DimensionMismatch {
4597 reason: "estimate_logit_pg_rao_blackwell_terms: input length mismatch".to_string(),
4598 }
4599 .into());
4600 }
4601 if mode.len() != p || penalty_matrix.nrows() != p || penalty_matrix.ncols() != p {
4602 return Err(HmcError::DimensionMismatch {
4603 reason: "estimate_logit_pg_rao_blackwell_terms: coefficient/penalty dimension mismatch"
4604 .to_string(),
4605 }
4606 .into());
4607 }
4608 if !weights.iter().all(|w| (*w - 1.0).abs() <= 1e-10) {
4609 return Err(HmcError::InvalidConfig {
4610 reason: "estimate_logit_pg_rao_blackwell_terms requires unit weights (PG(1,·))"
4611 .to_string(),
4612 }
4613 .into());
4614 }
4615 validate_binary_responses("estimate_logit_pg_rao_blackwell_terms", &y, &weights)
4616 .map_err(String::from)?;
4617 if penalty_roots.iter().any(|r| r.ncols() != p) {
4618 return Err(HmcError::DimensionMismatch {
4619 reason: "estimate_logit_pg_rao_blackwell_terms: root width mismatch".to_string(),
4620 }
4621 .into());
4622 }
4623 let penalty_roots_t: Vec<Array2<f64>> =
4626 penalty_roots.iter().map(|r| r.t().to_owned()).collect();
4627
4628 let n_iter = config.nwarmup + config.n_samples;
4629
4630 let kappa = y.mapv(|v| v - 0.5);
4633 let rhs_b = fast_atv(&x, &kappa);
4634
4635 let penalty = penalty_matrix.to_owned();
4636 let mut eta = Array1::<f64>::zeros(n);
4637 let mut omega = Array1::<f64>::ones(n);
4638 let pg_shapes = Array1::<u32>::from_elem(n, 1);
4639 let mut xw = x.to_owned();
4640 let mut xt_omega_x = Array2::<f64>::zeros((p, p));
4641 let mut q = Array2::<f64>::zeros((p, p));
4642 let mut mean = Array1::<f64>::zeros(p);
4643 let mut rb_sum = Array1::<f64>::zeros(penalty_roots.len());
4644 let mut z = Array1::<f64>::zeros(p);
4645 let mut noise = Array1::<f64>::zeros(p);
4646
4647 let mut kept = 0usize;
4648 for chain in 0..config.n_chains {
4649 let mut init_rng =
4650 StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0x28F0_7B65_1A4D_C93E));
4651 let mut draw_rng =
4652 StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0xC642_6E35_B5A9_1D80));
4653 let mut beta = mode.to_owned();
4654 for j in 0..p {
4655 beta[j] += 0.05 * sample_standard_normal(&mut init_rng);
4656 }
4657
4658 for iter in 0..n_iter {
4659 eta.assign(&gam_linalg::faer_ndarray::fast_av(&x, &beta));
4660 draw_logit_pg1_omega(
4661 pg_shapes.view(),
4662 eta.view(),
4663 gibbs_pg_seed(config.seed, chain, 0x83F1_56C9_A7E0_2D4B, iter),
4664 &mut omega,
4665 )?;
4666
4667 ndarray::Zip::from(xw.rows_mut())
4668 .and(x.rows())
4669 .and(&omega)
4670 .par_for_each(|mut xw_row, x_row, &omega_i| {
4671 let s = omega_i.sqrt();
4672 for j in 0..p {
4673 xw_row[j] = x_row[j] * s;
4674 }
4675 });
4676 fast_ata_into(&xw, &mut xt_omega_x);
4677
4678 q.assign(&penalty);
4681 q += &xt_omega_x;
4682
4683 let factor = q
4684 .cholesky(Side::Lower)
4685 .map_err(|e| format!("PG Rao-Blackwell failed to factor Q: {:?}", e))?;
4686 mean.assign(&factor.solvevec(&rhs_b));
4689
4690 for j in 0..p {
4692 z[j] = sample_standard_normal(&mut draw_rng);
4693 }
4694 let l = factor.lower_triangular();
4695 back_substitution_lower_transpose_guarded_into(&l, &z, &mut noise);
4696 beta.assign(&(&mean + &noise));
4697
4698 if iter < config.nwarmup {
4699 continue;
4700 }
4701 kept += 1;
4702
4703 for (k, r_k) in penalty_roots.iter().enumerate() {
4704 if r_k.nrows() == 0 {
4705 continue;
4706 }
4707
4708 let rmu = r_k.dot(&mean);
4710 let mu_quad = rmu.dot(&rmu);
4711
4712 let solved_mat = factor.solve_mat(&penalty_roots_t[k]); let solved_t = solved_mat.t();
4717 let mut trace_term = 0.0_f64;
4718 for (&a, &b) in r_k.iter().zip(solved_t.iter()) {
4719 trace_term += a * b;
4720 }
4721
4722 rb_sum[k] += trace_term + mu_quad;
4723 }
4724 }
4725 }
4726
4727 if kept == 0 {
4728 return Err(HmcError::SamplingFailed {
4729 reason: "estimate_logit_pg_rao_blackwell_terms: no retained samples".to_string(),
4730 }
4731 .into());
4732 }
4733 let out = rb_sum.mapv(|v| v / (kept as f64));
4734 if !out.iter().all(|v| v.is_finite()) {
4735 return Err(HmcError::NonFiniteState {
4736 reason: "estimate_logit_pg_rao_blackwell_terms: non-finite expectation".to_string(),
4737 }
4738 .into());
4739 }
4740 Ok(out)
4741}
4742
4743pub(crate) fn run_nuts_sampling(
4756 x: ArrayView2<f64>,
4757 y: ArrayView1<f64>,
4758 weights: ArrayView1<f64>,
4759 penalty_matrix: ArrayView2<f64>,
4760 mode: ArrayView1<f64>,
4761 hessian: ArrayView2<f64>,
4762 nuts_family: NutsFamily,
4763 gamma_shape: f64,
4764 dispersion: gam_solve::model_types::Dispersion,
4765 firth_bias_reduction: bool,
4766 offset: Option<ArrayView1<f64>>,
4767 config: &NutsConfig,
4768) -> Result<NutsResult, String> {
4769 validate_firth_support(nuts_family, firth_bias_reduction).map_err(String::from)?;
4770 validate_nuts_config(config).map_err(String::from)?;
4771 if nuts_family == NutsFamily::TweedieLog && !is_valid_tweedie_power(gamma_shape) {
4772 return Err(format!(
4773 "Tweedie variance power must be finite and strictly between 1 and 2; got {gamma_shape}"
4774 ));
4775 }
4776 let dim = mode.len();
4777
4778 let target = NutsPosterior::new(
4781 x,
4782 y,
4783 weights,
4784 penalty_matrix,
4785 mode,
4786 hessian,
4787 nuts_family,
4788 gamma_shape,
4789 dispersion,
4790 firth_bias_reduction,
4791 )?;
4792 let target = match offset {
4793 Some(offset) => target.with_offset(offset)?,
4794 None => target,
4795 };
4796
4797 let chol = target.chol().clone();
4799 let mode_arr = target.mode().clone();
4800
4801 let initial_positions = jittered_initial_positions(config, dim, 0.1, 0x0F65_83B2_BC71_4D9E);
4802 let mass_cfg = robust_mass_matrix_config(dim, config.nwarmup);
4803 let (result, run_stats) = run_whitened_nuts_result(
4804 target,
4805 &mode_arr,
4806 &chol,
4807 initial_positions,
4808 config,
4809 dim,
4810 mass_cfg,
4811 0xF1D3_C2B5_A697_804E,
4812 "NUTS sampling failed",
4813 Array1::zeros(dim),
4814 NutsConvergenceThresholds {
4815 max_rhat: 1.1,
4816 min_ess: Some(100.0),
4817 },
4818 )?;
4819 log::info!("NUTS sampling complete: {}", run_stats);
4820
4821 Ok(result)
4822}
4823
/// Standard-normal target in whitened coordinates: `log p(z) = -|z|^2 / 2`
/// with gradient `-z`. `sample_gaussian_mode_posterior` samples it and maps
/// the draws back through `mode + chol * z`.
struct GaussianModeTarget;
4846
4847impl HamiltonianTarget<Array1<f64>> for GaussianModeTarget {
4848 #[inline]
4849 fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
4850 let mut quad = 0.0;
4855 for (g, &zi) in grad.iter_mut().zip(position.iter()) {
4856 *g = -zi;
4857 quad += zi * zi;
4858 }
4859 -0.5 * quad
4860 }
4861}
4862
4863pub fn sample_gaussian_mode_posterior(
4878 mode: ArrayView1<f64>,
4879 hessian: ArrayView2<f64>,
4880 config: &NutsConfig,
4881) -> Result<GaussianModePosterior, String> {
4882 validate_nuts_config(config).map_err(String::from)?;
4883 let dim = mode.len();
4884 if hessian.nrows() != dim || hessian.ncols() != dim {
4885 return Err(format!(
4886 "Gaussian-posterior fallback: hessian shape {:?} does not match mode dim {dim}",
4887 hessian.dim()
4888 ));
4889 }
4890 if dim == 0 {
4891 return Err("Gaussian-posterior fallback: zero-dimensional posterior".to_string());
4892 }
4893
4894 let mut h = hessian.to_owned();
4900 for i in 0..dim {
4901 for j in (i + 1)..dim {
4902 let avg = 0.5 * (h[[i, j]] + h[[j, i]]);
4903 h[[i, j]] = avg;
4904 h[[j, i]] = avg;
4905 }
4906 }
4907 let diag_scale = (0..dim).map(|i| h[[i, i]].abs()).fold(0.0_f64, f64::max);
4908 let jitter = (diag_scale * 1e-10).max(1e-12);
4909 for i in 0..dim {
4910 h[[i, i]] += jitter;
4911 }
4912
4913 let mode_owned = mode.to_owned();
4914 let whitening = hessian_whitening_transform(
4915 h.view(),
4916 dim,
4917 1.0,
4918 "Gaussian-posterior fallback Cholesky failed",
4919 )?;
4920 let chol = whitening.chol;
4921 let target = GaussianModeTarget;
4922 let initial_positions = jittered_initial_positions(config, dim, 0.1, 0x51A6_2C73_90E4_1DBF);
4923 let mass_cfg = robust_mass_matrix_config(dim, config.nwarmup);
4924 let (result, run_stats) = run_whitened_nuts_result(
4925 target,
4926 &mode_owned,
4927 &chol,
4928 initial_positions,
4929 config,
4930 dim,
4931 mass_cfg,
4932 0x7C19_5A3E_82D6_44B1,
4933 "Gaussian-posterior fallback NUTS sampling failed",
4934 mode_owned.clone(),
4935 NutsConvergenceThresholds {
4936 max_rhat: 1.1,
4937 min_ess: None,
4938 },
4939 )?;
4940 log::info!(
4941 "never-fail Gaussian-posterior fallback: sampling complete dim={dim} {}",
4942 run_stats
4943 );
4944
4945 Ok(GaussianModePosterior {
4946 samples: result.samples,
4947 posterior_mean: result.posterior_mean,
4948 posterior_std: result.posterior_std,
4949 rhat: result.rhat,
4950 ess: result.ess,
4951 })
4952}
4953
/// Amount subtracted from the log-density whenever the rho criterion is
/// infeasible or non-finite. The point stays differentiable (a standard-normal
/// surface) but is effectively never accepted.
const RHO_NUTS_INFEASIBLE_LOGP_PENALTY: f64 = 1.0e8;
4960
/// NUTS target over smoothing parameters `rho`, expressed in whitened
/// coordinates `z` with `rho = mode + chol * z`.
struct WhitenedRhoCriterionTarget<F> {
    /// Criterion evaluator returning `Some((cost, d cost / d rho))`, or `None`
    /// when infeasible. It is wrapped in a `Mutex` because `F: FnMut` has to be
    /// called through the `&self` receiver of `logp_and_grad`.
    criterion_and_grad: Mutex<F>,
    /// Whitening centre, i.e. the `rho_hat` supplied by the caller.
    mode: Array1<f64>,
    /// Maps whitened coordinates to `rho` offsets: `rho = mode + chol * z`.
    chol: Array2<f64>,
    /// Transpose of `chol`, used for the chain rule `grad_z = chol_t * grad_rho`.
    chol_t: Array2<f64>,
    /// Criterion value at `mode`; subtracted so that `log p(0) = 0`.
    cost_hat: f64,
}
4988
4989impl<F> HamiltonianTarget<Array1<f64>> for WhitenedRhoCriterionTarget<F>
4990where
4991 F: FnMut(&Array1<f64>) -> Option<(f64, Array1<f64>)> + Send,
4992{
4993 fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
4994 let rho = &self.mode + &self.chol.dot(position);
4995 let eval = {
4996 let mut criterion = self
4997 .criterion_and_grad
4998 .lock()
4999 .expect("rho-criterion mutex poisoned");
5000 (*criterion)(&rho)
5001 };
5002 match eval {
5003 Some((cost, g))
5004 if cost.is_finite()
5005 && g.len() == position.len()
5006 && g.iter().all(|v| v.is_finite()) =>
5007 {
5008 let grad_z = self.chol_t.dot(&g);
5009 for (gi, &v) in grad.iter_mut().zip(grad_z.iter()) {
5010 *gi = -v;
5011 }
5012 -(cost - self.cost_hat)
5013 }
5014 _ => {
5015 let mut quad = 0.0;
5017 for (gi, &zi) in grad.iter_mut().zip(position.iter()) {
5018 *gi = -zi;
5019 quad += zi * zi;
5020 }
5021 -0.5 * quad - RHO_NUTS_INFEASIBLE_LOGP_PENALTY
5022 }
5023 }
5024 }
5025}
5026
5027pub fn run_rho_criterion_nuts<F>(
5042 rho_hat: ArrayView1<f64>,
5043 outer_hessian: ArrayView2<f64>,
5044 mut criterion_and_grad: F,
5045 config: &NutsConfig,
5046) -> Result<NutsResult, String>
5047where
5048 F: FnMut(&Array1<f64>) -> Option<(f64, Array1<f64>)> + Send,
5049{
5050 validate_nuts_config(config).map_err(String::from)?;
5051 let dim = rho_hat.len();
5052 if dim == 0 {
5053 return Err("rho-posterior NUTS: zero-dimensional rho".to_string());
5054 }
5055 if outer_hessian.nrows() != dim || outer_hessian.ncols() != dim {
5056 return Err(format!(
5057 "rho-posterior NUTS: outer Hessian shape {:?} does not match rho dim {dim}",
5058 outer_hessian.dim()
5059 ));
5060 }
5061
5062 let mut h = outer_hessian.to_owned();
5065 for i in 0..dim {
5066 for j in (i + 1)..dim {
5067 let avg = 0.5 * (h[[i, j]] + h[[j, i]]);
5068 h[[i, j]] = avg;
5069 h[[j, i]] = avg;
5070 }
5071 }
5072 let diag_scale = (0..dim).map(|i| h[[i, i]].abs()).fold(0.0_f64, f64::max);
5073 let jitter = (diag_scale * 1e-10).max(1e-12);
5074 for i in 0..dim {
5075 h[[i, i]] += jitter;
5076 }
5077
5078 let mode = rho_hat.to_owned();
5079 let whitening = hessian_whitening_transform(
5080 h.view(),
5081 dim,
5082 1.0,
5083 "rho-posterior NUTS: outer-Hessian Cholesky failed",
5084 )?;
5085
5086 let cost_hat = match criterion_and_grad(&mode) {
5087 Some((cost, _)) if cost.is_finite() => cost,
5088 _ => {
5089 return Err(
5090 "rho-posterior NUTS: criterion is infeasible at rho_hat itself".to_string(),
5091 );
5092 }
5093 };
5094
5095 let chol = whitening.chol;
5096 let target = WhitenedRhoCriterionTarget {
5097 criterion_and_grad: Mutex::new(criterion_and_grad),
5098 mode: mode.clone(),
5099 chol: chol.clone(),
5100 chol_t: whitening.chol_t,
5101 cost_hat,
5102 };
5103 let initial_positions = jittered_initial_positions(config, dim, 0.1, 0x3D8A_91C4_E27B_5F60);
5104 let mass_cfg = NUTSMassMatrixConfig::disabled();
5109 let (result, run_stats) = run_whitened_nuts_result(
5110 target,
5111 &mode,
5112 &chol,
5113 initial_positions,
5114 config,
5115 dim,
5116 mass_cfg,
5117 0x6B42_E9A1_05D7_C83F,
5118 "rho-posterior NUTS sampling failed",
5119 mode.clone(),
5120 NutsConvergenceThresholds {
5121 max_rhat: 1.1,
5122 min_ess: None,
5123 },
5124 )?;
5125 log::info!("rho-posterior NUTS (#938 tier 2): sampling complete dim={dim} {run_stats}");
5126 Ok(result)
5127}
5128
/// Flattened GLM inputs consumed by `run_nuts_sampling_flattened_family`.
pub struct GlmFlatInputs<'a> {
    /// Design matrix (rows = observations).
    pub x: ArrayView2<'a, f64>,
    /// Response values.
    pub y: ArrayView1<'a, f64>,
    /// Per-observation prior weights.
    pub weights: ArrayView1<'a, f64>,
    /// Penalty matrix on the coefficients.
    pub penalty_matrix: ArrayView2<'a, f64>,
    /// Fitted coefficients the sampler is centered on.
    pub mode: ArrayView1<'a, f64>,
    /// Hessian at `mode`, forwarded to `run_nuts_sampling` (presumably for
    /// whitening — confirm against that function).
    pub hessian: ArrayView2<'a, f64>,
    /// Gamma shape; `None` is treated as 1.0 for the Gamma family.
    pub gamma_shape: Option<f64>,
    /// Dispersion forwarded unchanged to the sampler.
    pub dispersion: gam_solve::model_types::Dispersion,
    /// Whether Firth bias reduction is requested; rejected up front for
    /// likelihoods that do not support Firth.
    pub firth_bias_reduction: bool,
    /// Optional linear-predictor offset.
    pub offset: Option<ArrayView1<'a, f64>>,
}
5152
/// Flattened per-observation survival data forwarded to
/// `survival_hmc::run_survival_nuts_sampling`.
pub struct SurvivalFlatInputs<'a> {
    /// Entry ages (left truncation times).
    pub age_entry: ArrayView1<'a, f64>,
    /// Exit ages (event or censoring times).
    pub age_exit: ArrayView1<'a, f64>,
    /// Target-event indicators (presumably 0/1 — confirm with the survival sampler).
    pub event_target: ArrayView1<'a, u8>,
    /// Competing-event indicators (presumably 0/1 — confirm with the survival sampler).
    pub event_competing: ArrayView1<'a, u8>,
    /// Per-observation weights.
    pub weights: ArrayView1<'a, f64>,
    /// Design evaluated at entry ages.
    pub x_entry: ArrayView2<'a, f64>,
    /// Design evaluated at exit ages.
    pub x_exit: ArrayView2<'a, f64>,
    /// Time-derivative design at exit ages.
    pub x_derivative: ArrayView2<'a, f64>,
    /// Optional linear-predictor offset at entry.
    pub eta_offset_entry: Option<ArrayView1<'a, f64>>,
    /// Optional linear-predictor offset at exit.
    pub eta_offset_exit: Option<ArrayView1<'a, f64>>,
    /// Optional offset for the time-derivative term at exit.
    pub derivative_offset_exit: Option<ArrayView1<'a, f64>>,
}
5167
/// Complete Royston–Parmar survival NUTS inputs: flattened data plus the
/// model-level penalty, monotonicity, and spec objects.
pub struct SurvivalNutsInputs<'a> {
    /// Per-observation survival data.
    pub flat: SurvivalFlatInputs<'a>,
    /// Penalty blocks of the survival model.
    pub penalties: gam_models::survival::PenaltyBlocks,
    /// Monotonicity penalty configuration.
    pub monotonicity: gam_models::survival::SurvivalMonotonicityPenalty,
    /// Survival model specification.
    pub spec: gam_models::survival::SurvivalSpec,
    /// Whether the time basis is monotone by construction.
    pub structurally_monotonic: bool,
    /// Number of leading structurally-monotone time columns (assumed — confirm
    /// against `run_survival_nuts_sampling`).
    pub structural_time_columns: usize,
    /// Fitted coefficients the sampler is centered on.
    pub mode: ArrayView1<'a, f64>,
    /// Hessian at `mode`.
    pub hessian: ArrayView2<'a, f64>,
}
5179
/// Family-specific flattened inputs for `run_nuts_sampling_flattened_family`.
///
/// `Survival` is only accepted with the Royston–Parmar response family, and
/// Royston–Parmar only accepts `Survival`. Every other family expects `Glm`.
pub enum FamilyNutsInputs<'a> {
    /// Standard GLM-style inputs.
    Glm(GlmFlatInputs<'a>),
    /// Survival inputs (boxed to keep the enum small).
    Survival(Box<SurvivalNutsInputs<'a>>),
}
5185
5186pub fn explicit_fit_hessian_for_whitening<'a>(
5195 fit: &'a UnifiedFitResult,
5196 expected_dim: usize,
5197 label: &str,
5198) -> Result<&'a Array2<f64>, String> {
5199 let hessian = fit.penalized_hessian().ok_or_else(|| {
5200 format!(
5201 "{label}: fit result is missing an explicit penalized Hessian for HMC/NUTS whitening"
5202 )
5203 })?;
5204 validate_explicit_dense_hessian_for_whitening(
5205 &format!("{label} penalized Hessian"),
5206 hessian,
5207 expected_dim,
5208 )
5209 .map_err(|err| err.to_string())?;
5210 Ok(hessian)
5211}
5212
5213pub fn run_nuts_sampling_flattened_family(
5215 likelihood: LikelihoodSpec,
5216 inputs: FamilyNutsInputs<'_>,
5217 config: &NutsConfig,
5218) -> Result<NutsResult, String> {
5219 if let FamilyNutsInputs::Glm(glm) = &inputs
5220 && glm.firth_bias_reduction
5221 && !likelihood_spec_supports_firth(&likelihood)
5222 {
5223 return Err(HmcError::FirthUnsupported {
5224 reason: format!(
5225 "NUTS with Firth requires a Binomial inverse link with a Fisher-weight jet; {} does not support it",
5226 likelihood.pretty_name()
5227 ),
5228 }
5229 .into());
5230 }
5231
5232 match (likelihood.response.clone(), likelihood.link.clone(), inputs) {
5233 (
5234 ResponseFamily::Gaussian,
5235 InverseLink::Standard(StandardLink::Identity),
5236 FamilyNutsInputs::Glm(glm),
5237 ) => run_nuts_sampling(
5238 glm.x,
5239 glm.y,
5240 glm.weights,
5241 glm.penalty_matrix,
5242 glm.mode,
5243 glm.hessian,
5244 NutsFamily::Gaussian,
5245 1.0,
5246 glm.dispersion,
5247 glm.firth_bias_reduction,
5248 glm.offset,
5249 config,
5250 ),
5251 (
5252 ResponseFamily::Binomial,
5253 InverseLink::Standard(StandardLink::Logit),
5254 FamilyNutsInputs::Glm(glm),
5255 ) => {
5256 if !glm.firth_bias_reduction
5263 && glm.offset.is_none()
5264 && glm.weights.iter().all(|w| (*w - 1.0).abs() <= 1e-10)
5265 {
5266 run_logit_polya_gamma_gibbs(
5267 glm.x,
5268 glm.y,
5269 glm.weights,
5270 glm.penalty_matrix,
5271 glm.mode,
5272 config,
5273 )
5274 } else {
5275 run_nuts_sampling(
5276 glm.x,
5277 glm.y,
5278 glm.weights,
5279 glm.penalty_matrix,
5280 glm.mode,
5281 glm.hessian,
5282 NutsFamily::BinomialLogit,
5283 1.0,
5284 glm.dispersion,
5285 glm.firth_bias_reduction,
5286 glm.offset,
5287 config,
5288 )
5289 }
5290 }
5291 (
5292 ResponseFamily::Binomial,
5293 InverseLink::Standard(StandardLink::Probit),
5294 FamilyNutsInputs::Glm(glm),
5295 ) => run_nuts_sampling(
5296 glm.x,
5297 glm.y,
5298 glm.weights,
5299 glm.penalty_matrix,
5300 glm.mode,
5301 glm.hessian,
5302 NutsFamily::BinomialProbit,
5303 1.0,
5304 glm.dispersion,
5305 glm.firth_bias_reduction,
5306 glm.offset,
5307 config,
5308 ),
5309 (
5310 ResponseFamily::Binomial,
5311 InverseLink::Standard(StandardLink::CLogLog),
5312 FamilyNutsInputs::Glm(glm),
5313 ) => run_nuts_sampling(
5314 glm.x,
5315 glm.y,
5316 glm.weights,
5317 glm.penalty_matrix,
5318 glm.mode,
5319 glm.hessian,
5320 NutsFamily::BinomialCLogLog,
5321 1.0,
5322 glm.dispersion,
5323 glm.firth_bias_reduction,
5324 glm.offset,
5325 config,
5326 ),
5327 (
5328 ResponseFamily::Binomial,
5329 InverseLink::LatentCLogLog(_),
5330 FamilyNutsInputs::Glm(glm),
5331 ) => run_nuts_sampling(
5332 glm.x,
5333 glm.y,
5334 glm.weights,
5335 glm.penalty_matrix,
5336 glm.mode,
5337 glm.hessian,
5338 NutsFamily::BinomialCLogLog,
5339 1.0,
5340 glm.dispersion,
5341 glm.firth_bias_reduction,
5342 glm.offset,
5343 config,
5344 ),
5345 (ResponseFamily::Binomial, InverseLink::Mixture(_), FamilyNutsInputs::Glm(_)) => Err(
5346 "BinomialMixture NUTS is not implemented yet; use fit_gam/predict_gam for blended inverse-link models"
5347 .to_string(),
5348 ),
5349 (ResponseFamily::Binomial, InverseLink::Sas(_), FamilyNutsInputs::Glm(_)) => Err(
5350 "BinomialSas NUTS is not implemented yet; use fit_gam/predict_gam for SAS-link models"
5351 .to_string(),
5352 ),
5353 (ResponseFamily::Binomial, InverseLink::BetaLogistic(_), FamilyNutsInputs::Glm(_)) => Err(
5354 "BinomialBetaLogistic NUTS is not implemented yet; use fit_gam/predict_gam for beta-logistic-link models"
5355 .to_string(),
5356 ),
5357 (ResponseFamily::Binomial, InverseLink::Standard(_), FamilyNutsInputs::Glm(_)) => Err(
5358 "NUTS sampling is not implemented for this binomial inverse link".to_string(),
5359 ),
5360 (ResponseFamily::RoystonParmar, _, FamilyNutsInputs::Survival(survival)) => {
5361 survival_hmc::run_survival_nuts_sampling(
5362 survival.flat.age_entry,
5363 survival.flat.age_exit,
5364 survival.flat.event_target,
5365 survival.flat.event_competing,
5366 survival.flat.weights,
5367 survival.flat.x_entry,
5368 survival.flat.x_exit,
5369 survival.flat.x_derivative,
5370 survival.flat.eta_offset_entry,
5371 survival.flat.eta_offset_exit,
5372 survival.flat.derivative_offset_exit,
5373 survival.penalties,
5374 survival.monotonicity,
5375 survival.spec,
5376 survival.structurally_monotonic,
5377 survival.structural_time_columns,
5378 survival.mode,
5379 survival.hessian,
5380 config,
5381 )
5382 }
5383 (ResponseFamily::RoystonParmar, _, FamilyNutsInputs::Glm(_)) => Err(
5384 "RoystonParmar family requires FamilyNutsInputs::Survival flattened inputs".to_string(),
5385 ),
5386 (_, _, FamilyNutsInputs::Survival(_)) => Err(
5387 "Survival flattened inputs are only valid for the Royston-Parmar response family"
5388 .to_string(),
5389 ),
5390 (ResponseFamily::Poisson, _, FamilyNutsInputs::Glm(glm)) => run_nuts_sampling(
5391 glm.x,
5392 glm.y,
5393 glm.weights,
5394 glm.penalty_matrix,
5395 glm.mode,
5396 glm.hessian,
5397 NutsFamily::PoissonLog,
5398 1.0,
5399 glm.dispersion,
5400 glm.firth_bias_reduction,
5401 glm.offset,
5402 config,
5403 ),
5404 (ResponseFamily::Tweedie { p }, _, FamilyNutsInputs::Glm(glm)) => {
5405 if !is_valid_tweedie_power(p) {
5408 return Err(format!(
5409 "Tweedie variance power must be finite and strictly between 1 and 2; got {p}"
5410 ));
5411 }
5412 run_nuts_sampling(
5413 glm.x,
5414 glm.y,
5415 glm.weights,
5416 glm.penalty_matrix,
5417 glm.mode,
5418 glm.hessian,
5419 NutsFamily::TweedieLog,
5420 p,
5421 glm.dispersion,
5422 glm.firth_bias_reduction,
5423 glm.offset,
5424 config,
5425 )
5426 }
5427 (ResponseFamily::NegativeBinomial { theta, .. }, _, FamilyNutsInputs::Glm(glm)) => {
5428 run_nuts_sampling(
5431 glm.x,
5432 glm.y,
5433 glm.weights,
5434 glm.penalty_matrix,
5435 glm.mode,
5436 glm.hessian,
5437 NutsFamily::NegativeBinomialLog,
5438 theta,
5439 glm.dispersion,
5440 glm.firth_bias_reduction,
5441 glm.offset,
5442 config,
5443 )
5444 }
5445 (ResponseFamily::Beta { .. }, _, FamilyNutsInputs::Glm(_)) => Err(
5446 "NUTS sampling is not implemented for beta-regression logit".to_string(),
5447 ),
5448 (ResponseFamily::Gamma, _, FamilyNutsInputs::Glm(glm)) => run_nuts_sampling(
5449 glm.x,
5450 glm.y,
5451 glm.weights,
5452 glm.penalty_matrix,
5453 glm.mode,
5454 glm.hessian,
5455 NutsFamily::GammaLog,
5456 glm.gamma_shape.unwrap_or(1.0),
5457 glm.dispersion,
5458 glm.firth_bias_reduction,
5459 glm.offset,
5460 config,
5461 ),
5462 (ResponseFamily::Gaussian, _, FamilyNutsInputs::Glm(_)) => Err(
5463 "NUTS sampling is only implemented for Gaussian with identity link".to_string(),
5464 ),
5465 }
5466}
5467
/// Spline metadata describing the monotone link-wiggle basis.
#[derive(Clone)]
pub struct LinkWiggleSplineArtifacts {
    /// `(min, max)` of the linear predictor `u`; used to map `u` onto the
    /// standardized `[0, 1]` spline axis.
    pub knot_range: (f64, f64),
    /// Knot vector passed to `monotone_wiggle_basis_with_derivative_order`.
    pub knot_vector: Array1<f64>,
    /// Spline degree passed to `monotone_wiggle_basis_with_derivative_order`.
    pub degree: usize,
}
5497
/// Joint posterior over base coefficients `beta` and link-wiggle coefficients
/// `theta`, evaluated in whitened coordinates `z` with
/// `(beta, theta) = mode + chol · z`.
///
/// The linear predictor is `eta = u + B(u) · theta` where `u = x · beta` and
/// `B` is the monotone wiggle basis described by `spline`.
#[derive(Clone)]
pub struct LinkWigglePosterior {
    /// Base design matrix (`n_samples x p_base`); `u = x · beta`.
    x: Arc<Array2<f64>>,
    /// Responses, one per row of `x`.
    y: Arc<Array1<f64>>,
    /// Per-observation weights, one per row of `x`.
    weights: Arc<Array1<f64>>,
    /// Quadratic penalty on `beta`.
    penalty_base: Arc<Array2<f64>>,
    /// Quadratic penalty on `theta`.
    penalty_link: Arc<Array2<f64>>,
    /// Mode of `beta`.
    mode_beta: Arc<Array1<f64>>,
    /// Mode of `theta`.
    mode_theta: Arc<Array1<f64>>,
    /// Spline definition for the wiggle basis.
    spline: LinkWiggleSplineArtifacts,
    /// Whitening factor mapping `z` to a `(beta, theta)` offset.
    chol: Array2<f64>,
    /// Factor pulling `(beta, theta)` gradients back to `z`.
    chol_t: Array2<f64>,
    /// Number of base coefficients.
    p_base: usize,
    /// Number of wiggle coefficients.
    p_link: usize,
    /// Number of observations (rows of `x`).
    n_samples: usize,
    /// Likelihood family used for the log-density.
    nuts_family: NutsFamily,
    /// Family parameter: Gaussian scale (sigma), Tweedie power, negative-binomial
    /// theta, or Gamma shape, depending on `nuts_family`.
    scale: f64,
    /// `scale^2` for Gaussian, 1.0 otherwise; penalties are divided by it.
    cov_scale: f64,
}
5529
5530impl LinkWigglePosterior {
5531 #[inline]
5533 fn standardized_z(&self, u: &Array1<f64>) -> (Array1<f64>, Array1<f64>, f64) {
5534 let (min_u, max_u) = self.spline.knot_range;
5535 let rw = (max_u - min_u).max(1e-6);
5536 let z_raw: Array1<f64> = u.mapv(|v| (v - min_u) / rw);
5537 let z_c: Array1<f64> = z_raw.mapv(|z| z.clamp(0.0, 1.0));
5538 (z_raw, z_c, rw)
5539 }
5540
5541 pub fn new(
5543 x: ArrayView2<f64>,
5544 y: ArrayView1<f64>,
5545 weights: ArrayView1<f64>,
5546 penalty_base: ArrayView2<f64>,
5547 penalty_link: ArrayView2<f64>,
5548 mode_beta: ArrayView1<f64>,
5549 mode_theta: ArrayView1<f64>,
5550 hessian: ArrayView2<f64>,
5551 spline: LinkWiggleSplineArtifacts,
5552 nuts_family: NutsFamily,
5553 scale: f64,
5554 ) -> Result<Self, String> {
5555 let n_samples = x.nrows();
5556 let p_base = x.ncols();
5557 let p_link = mode_theta.len();
5558 let dim = p_base + p_link;
5559 if hessian.nrows() != dim || hessian.ncols() != dim {
5560 return Err(HmcError::DimensionMismatch {
5561 reason: format!(
5562 "LinkWigglePosterior: Hessian dim mismatch: {}x{} vs expected {}x{}",
5563 hessian.nrows(),
5564 hessian.ncols(),
5565 dim,
5566 dim,
5567 ),
5568 }
5569 .into());
5570 }
5571 if nuts_family.likelihood_spec().is_binomial() {
5572 validate_binary_responses("binomial link-wiggle NUTS", &y, &weights)
5573 .map_err(String::from)?;
5574 }
5575 if matches!(nuts_family, NutsFamily::NegativeBinomialLog) {
5576 validate_count_responses("negative-binomial link-wiggle NUTS", &y, &weights)
5577 .map_err(String::from)?;
5578 }
5579 let cov_scale = match nuts_family {
5588 NutsFamily::Gaussian => scale * scale,
5589 _ => 1.0,
5590 };
5591 let whitening = hessian_whitening_transform(
5592 hessian,
5593 dim,
5594 cov_scale,
5595 "LinkWigglePosterior Cholesky failed",
5596 )?;
5597 let chol = whitening.chol;
5598 let chol_t = whitening.chol_t;
5599 Ok(Self {
5600 x: Arc::new(x.to_owned()),
5601 y: Arc::new(y.to_owned()),
5602 weights: Arc::new(weights.to_owned()),
5603 penalty_base: Arc::new(penalty_base.to_owned()),
5604 penalty_link: Arc::new(penalty_link.to_owned()),
5605 mode_beta: Arc::new(mode_beta.to_owned()),
5606 mode_theta: Arc::new(mode_theta.to_owned()),
5607 spline,
5608 chol,
5609 chol_t,
5610 p_base,
5611 p_link,
5612 n_samples,
5613 nuts_family,
5614 scale,
5615 cov_scale,
5616 })
5617 }
5618
5619 fn evaluate_link(&self, u: &Array1<f64>, theta: &Array1<f64>) -> (Array2<f64>, Array1<f64>) {
5621 let n = u.len();
5622 if theta.is_empty() {
5623 return (Array2::zeros((n, 0)), u.clone());
5624 }
5625
5626 let (z_raw, z_c, _) = self.standardized_z(u);
5627 let Ok(mut basis) = monotone_wiggle_basis_with_derivative_order(
5628 z_c.view(),
5629 &self.spline.knot_vector,
5630 self.spline.degree,
5631 0,
5632 ) else {
5633 return (Array2::zeros((n, theta.len())), u.clone());
5634 };
5635 if basis.ncols() != theta.len() {
5636 return (Array2::zeros((n, theta.len())), u.clone());
5637 }
5638
5639 let mut needs_ext = false;
5642 for i in 0..n {
5643 if (z_raw[i] - z_c[i]).abs() > 1e-12 {
5644 needs_ext = true;
5645 break;
5646 }
5647 }
5648 if needs_ext
5649 && let Ok(b_prime) = monotone_wiggle_basis_with_derivative_order(
5650 z_c.view(),
5651 &self.spline.knot_vector,
5652 self.spline.degree,
5653 1,
5654 )
5655 {
5656 for i in 0..n {
5657 let dz = z_raw[i] - z_c[i];
5658 if dz.abs() <= 1e-12 {
5659 continue;
5660 }
5661 for j in 0..basis.ncols().min(b_prime.ncols()) {
5662 basis[[i, j]] += dz * b_prime[[i, j]];
5663 }
5664 }
5665 }
5666 (
5667 basis.clone(),
5668 u + &gam_linalg::faer_ndarray::fast_av(&basis, theta),
5669 )
5670 }
5671
5672 fn compute_g_prime(&self, u: &Array1<f64>, theta: &Array1<f64>) -> Array1<f64> {
5674 let n = u.len();
5675 let mut g = Array1::<f64>::ones(n);
5676 let (_, z_c, rw) = self.standardized_z(u);
5677 if theta.is_empty() {
5678 return g;
5679 }
5680
5681 let Ok(b_prime_constrained) = monotone_wiggle_basis_with_derivative_order(
5682 z_c.view(),
5683 &self.spline.knot_vector,
5684 self.spline.degree,
5685 1,
5686 ) else {
5687 return g;
5688 };
5689 if b_prime_constrained.ncols() != theta.len() {
5690 return g;
5691 }
5692 let dwiggle_dz = gam_linalg::faer_ndarray::fast_av(&b_prime_constrained, theta);
5693 ndarray::Zip::from(&mut g)
5694 .and(&dwiggle_dz)
5695 .par_for_each(|gi, &dw| *gi = 1.0 + dw / rw);
5696 g
5697 }
5698
5699 fn compute_logp_and_grad_into(&self, z: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
5700 let dim = self.p_base + self.p_link;
5701
5702 let mut mode = Array1::<f64>::zeros(dim);
5704 mode.slice_mut(ndarray::s![0..self.p_base])
5705 .assign(&self.mode_beta);
5706 mode.slice_mut(ndarray::s![self.p_base..])
5707 .assign(&self.mode_theta);
5708 let q = &mode + &self.chol.dot(z);
5709 let beta = q.slice(ndarray::s![0..self.p_base]).to_owned();
5710 let theta = q.slice(ndarray::s![self.p_base..]).to_owned();
5711
5712 let u = gam_linalg::faer_ndarray::fast_av(self.x.as_ref(), &beta);
5714 let (bwiggle, eta) = self.evaluate_link(&u, &theta);
5715
5716 let ll;
5718 let mut residual = Array1::<f64>::zeros(self.n_samples);
5719 match self.nuts_family {
5720 NutsFamily::Gaussian => {
5721 let inv_scale_sq = 1.0 / (self.scale * self.scale).max(1e-10);
5722 let mut ll_acc = 0.0;
5723 for i in 0..self.n_samples {
5724 let r = self.y[i] - eta[i];
5725 let w = self.weights[i];
5726 ll_acc -= 0.5 * w * r * r * inv_scale_sq;
5727 residual[i] = w * r * inv_scale_sq;
5728 }
5729 ll = ll_acc;
5730 }
5731 NutsFamily::BinomialLogit => {
5732 let mut ll_acc = 0.0;
5733 for i in 0..self.n_samples {
5734 let eta_i = eta[i];
5735 let (y_i, w_i) = (self.y[i], self.weights[i]);
5736 ll_acc += w_i * (y_i * eta_i - gam_linalg::utils::stable_softplus(eta_i));
5737 let mu = gam_linalg::utils::stable_logistic(eta_i);
5738 residual[i] = w_i * (y_i - mu);
5739 }
5740 ll = ll_acc;
5741 }
5742 NutsFamily::BinomialProbit => {
5743 let mut ll_acc = 0.0;
5744 for i in 0..self.n_samples {
5745 let eta_i = eta[i];
5746 let (y_i, w_i) = (self.y[i], self.weights[i]);
5747 let log_phi_pos = log_ndtr(eta_i);
5748 let log_phi_neg = log_ndtr(-eta_i);
5749 ll_acc += w_i * (y_i * log_phi_pos + (1.0 - y_i) * log_phi_neg);
5750 let log_phi = standard_normal_log_pdf(eta_i);
5751 let ratio_pos = (log_phi - log_phi_pos).exp();
5752 let ratio_neg = (log_phi - log_phi_neg).exp();
5753 residual[i] = w_i * (y_i * ratio_pos - (1.0 - y_i) * ratio_neg);
5754 }
5755 ll = ll_acc;
5756 }
5757 NutsFamily::BinomialCLogLog => {
5758 let mut ll_acc = 0.0;
5759 for i in 0..self.n_samples {
5760 let eta_i = eta[i];
5761 if !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)) {
5762 grad.fill(0.0);
5763 return f64::NEG_INFINITY;
5764 }
5765 let (y_i, w_i) = (self.y[i], self.weights[i]);
5766 let (ll_i, residual_i) = match cloglog_bernoulli_logp_and_residual(eta_i, y_i) {
5767 Ok(values) => values,
5768 Err(_) => {
5769 grad.fill(0.0);
5770 return f64::NEG_INFINITY;
5771 }
5772 };
5773 ll_acc += w_i * ll_i;
5774 residual[i] = w_i * residual_i;
5775 }
5776 ll = ll_acc;
5777 }
5778 NutsFamily::PoissonLog => {
5779 let mut ll_acc = 0.0;
5780 for i in 0..self.n_samples {
5781 let eta_i = eta[i];
5782 if !(eta_i.is_finite() && (-30.0..=30.0).contains(&eta_i)) {
5783 grad.fill(0.0);
5784 return f64::NEG_INFINITY;
5785 }
5786 let (y_i, w_i) = (self.y[i], self.weights[i]);
5787 let mu = eta_i.exp();
5788 ll_acc += w_i * (y_i * eta_i - mu);
5789 residual[i] = w_i * (y_i - mu);
5790 }
5791 ll = ll_acc;
5792 }
5793 NutsFamily::TweedieLog => {
5794 let mut ll_acc = 0.0;
5795 if !is_valid_tweedie_power(self.scale) {
5798 grad.fill(0.0);
5799 return f64::NEG_INFINITY;
5800 }
5801 let p = self.scale;
5802 for i in 0..self.n_samples {
5803 let eta_i = eta[i];
5804 if !(eta_i.is_finite() && (-30.0..=30.0).contains(&eta_i)) {
5805 grad.fill(0.0);
5806 return f64::NEG_INFINITY;
5807 }
5808 let (y_i, w_i) = (self.y[i], self.weights[i]);
5809 let mu = eta_i.exp().max(1e-300);
5810 ll_acc +=
5811 w_i * (y_i * mu.powf(1.0 - p) / (1.0 - p) - mu.powf(2.0 - p) / (2.0 - p));
5812 residual[i] = w_i * (y_i - mu) * mu.powf(1.0 - p);
5813 }
5814 ll = ll_acc;
5815 }
5816 NutsFamily::NegativeBinomialLog => {
5817 let mut ll_acc = 0.0;
5818 if !(self.scale.is_finite() && self.scale > 0.0) {
5821 grad.fill(0.0);
5822 return f64::NEG_INFINITY;
5823 }
5824 let theta = self.scale;
5825 for i in 0..self.n_samples {
5826 let eta_i = eta[i];
5827 if !(eta_i.is_finite() && (-30.0..=30.0).contains(&eta_i)) {
5828 grad.fill(0.0);
5829 return f64::NEG_INFINITY;
5830 }
5831 let (y_i, w_i) = (self.y[i], self.weights[i]);
5832 if w_i <= 0.0 {
5833 residual[i] = 0.0;
5834 continue;
5835 }
5836 let mu = eta_i.exp().max(1e-12);
5837 let log_mu_term = if y_i > 0.0 { y_i * mu.ln() } else { 0.0 };
5838 ll_acc += w_i
5839 * (statrs::function::gamma::ln_gamma(y_i + theta)
5840 - statrs::function::gamma::ln_gamma(theta)
5841 - statrs::function::gamma::ln_gamma(y_i + 1.0)
5842 + theta * (theta.ln() - (theta + mu).ln())
5843 + log_mu_term
5844 - y_i * (theta + mu).ln());
5845 residual[i] = w_i * theta * (y_i - mu) / (theta + mu);
5846 }
5847 ll = ll_acc;
5848 }
5849 NutsFamily::GammaLog => {
5850 let mut ll_acc = 0.0;
5851 let shape = self.scale.max(1e-10);
5852 for i in 0..self.n_samples {
5853 let eta_i = eta[i];
5854 if !(eta_i.is_finite() && (-30.0..=30.0).contains(&eta_i)) {
5855 grad.fill(0.0);
5856 return f64::NEG_INFINITY;
5857 }
5858 let (y_i, w_i) = (self.y[i], self.weights[i]);
5859 let mu = eta_i.exp();
5860 ll_acc += w_i * shape * (-y_i / mu - eta_i);
5861 residual[i] = w_i * shape * (y_i / mu - 1.0);
5862 }
5863 ll = ll_acc;
5864 }
5865 }
5866
5867 let penalty_scale = 1.0 / self.cov_scale.max(1e-300);
5880
5881 let s_link_theta = self.penalty_link.dot(&theta);
5883 let grad_theta = &fast_atv(&bwiggle, &residual) - &(&s_link_theta * penalty_scale);
5884
5885 let g_prime = self.compute_g_prime(&u, &theta);
5888 let r_scaled: Array1<f64> = residual
5889 .iter()
5890 .zip(g_prime.iter())
5891 .map(|(&r, &g)| r * g)
5892 .collect();
5893 let s_base_beta = self.penalty_base.dot(&beta);
5894 let grad_beta = &fast_atv(&self.x, &r_scaled) - &(&s_base_beta * penalty_scale);
5895
5896 let penalty =
5898 penalty_scale * (0.5 * beta.dot(&s_base_beta) + 0.5 * theta.dot(&s_link_theta));
5899
5900 let mut grad_q = Array1::<f64>::zeros(dim);
5902 grad_q
5903 .slice_mut(ndarray::s![0..self.p_base])
5904 .assign(&grad_beta);
5905 grad_q
5906 .slice_mut(ndarray::s![self.p_base..])
5907 .assign(&grad_theta);
5908 fast_av_into(&self.chol_t, &grad_q, grad);
5909 ll - penalty
5910 }
5911
5912 pub fn chol(&self) -> &Array2<f64> {
5914 &self.chol
5915 }
5916
5917 pub fn mode_joint(&self) -> Array1<f64> {
5919 let dim = self.p_base + self.p_link;
5920 let mut mode = Array1::<f64>::zeros(dim);
5921 mode.slice_mut(ndarray::s![0..self.p_base])
5922 .assign(&self.mode_beta);
5923 mode.slice_mut(ndarray::s![self.p_base..])
5924 .assign(&self.mode_theta);
5925 mode
5926 }
5927}
5928
5929impl HamiltonianTarget<Array1<f64>> for LinkWigglePosterior {
5930 fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
5931 self.compute_logp_and_grad_into(position, grad)
5932 }
5933}
5934
5935pub fn run_link_wiggle_nuts_sampling(
5937 x: ArrayView2<f64>,
5938 y: ArrayView1<f64>,
5939 weights: ArrayView1<f64>,
5940 penalty_base: ArrayView2<f64>,
5941 penalty_link: ArrayView2<f64>,
5942 mode_beta: ArrayView1<f64>,
5943 mode_theta: ArrayView1<f64>,
5944 hessian: ArrayView2<f64>,
5945 spline: LinkWiggleSplineArtifacts,
5946 nuts_family: NutsFamily,
5947 scale: f64,
5948 config: &NutsConfig,
5949) -> Result<NutsResult, String> {
5950 validate_nuts_config(config).map_err(String::from)?;
5951 let dim = mode_beta.len() + mode_theta.len();
5952 let target = LinkWigglePosterior::new(
5953 x,
5954 y,
5955 weights,
5956 penalty_base,
5957 penalty_link,
5958 mode_beta,
5959 mode_theta,
5960 hessian,
5961 spline,
5962 nuts_family,
5963 scale,
5964 )?;
5965 let chol = target.chol().clone();
5966 let mode_arr = target.mode_joint();
5967
5968 let initial_positions = jittered_initial_positions(config, dim, 0.1, 0x8C48_0F65_3A2B_D917);
5969
5970 let mass_cfg = robust_mass_matrix_config(dim, config.nwarmup);
5971 let (result, run_stats) = run_whitened_nuts_result(
5972 target,
5973 &mode_arr,
5974 &chol,
5975 initial_positions,
5976 config,
5977 dim,
5978 mass_cfg,
5979 0x2E31_A4B6_C908_F57D,
5980 "Link-wiggle NUTS sampling failed",
5981 Array1::zeros(dim),
5982 NutsConvergenceThresholds {
5983 max_rhat: 1.1,
5984 min_ess: Some(100.0),
5985 },
5986 )?;
5987 log::info!("Link-wiggle NUTS sampling complete: {}", run_stats);
5988
5989 Ok(result)
5990}
5991
5992pub fn laplace_directional_cubic_diagnostic(
6035 hessian: &Array2<f64>,
6036 design: &DesignMatrix,
6037 c_weights: &Array1<f64>,
6038 refine_supremum: bool,
6039) -> Result<(f64, Array1<f64>), String> {
6040 let p = hessian.nrows();
6041 if p == 0 || hessian.ncols() != p {
6042 return Ok((0.0, Array1::zeros(0)));
6043 }
6044
6045 let sym_h = (hessian + &hessian.t()) * 0.5;
6046 let (evals, evecs) = sym_h
6047 .eigh(Side::Lower)
6048 .map_err(|e| format!("directional cubic diagnostic eigendecomposition failed: {e}"))?;
6049 let max_eval = evals.iter().fold(0.0_f64, |acc, &ev| acc.max(ev.abs()));
6050 let tol = (max_eval * 1.0e-12).max(1.0e-14);
6051 let mut directional = Array1::<f64>::zeros(p);
6052 let mut max_abs = 0.0_f64;
6053
6054 for r in 0..p {
6062 let lambda = evals[r];
6063 if lambda <= tol {
6064 continue;
6065 }
6066 let v = evecs.column(r);
6067 let gamma = directional_cubic_contraction(design, c_weights, &v) / lambda.powf(1.5);
6068 directional[r] = if gamma.is_finite() { gamma } else { 0.0 };
6069 max_abs = max_abs.max(directional[r].abs());
6070 }
6071
6072 if refine_supremum && p >= 2 {
6084 let positive_mask: Vec<bool> = evals.iter().map(|&ev| ev > tol).collect();
6088 let n_pos = positive_mask.iter().filter(|&&m| m).count();
6089 if n_pos >= 2 {
6090 let max_abs_from_probes = cubic_power_iteration_refinement(
6091 design,
6092 c_weights,
6093 &evals,
6094 &evecs,
6095 &positive_mask,
6096 n_pos,
6097 );
6098 if max_abs_from_probes > max_abs {
6099 max_abs = max_abs_from_probes;
6100 }
6101 }
6102 }
6103
6104 Ok((max_abs, directional))
6105}
6106
6107fn directional_cubic_contraction(
6109 design: &DesignMatrix,
6110 c_weights: &Array1<f64>,
6111 v: &ArrayView1<f64>,
6112) -> f64 {
6113 match design.as_sparse() {
6114 Some(x_sparse) => {
6115 let (symbolic, values) = x_sparse.as_ref().parts();
6116 let col_ptr = symbolic.col_ptr();
6117 let row_idx = symbolic.row_idx();
6118 let mut row_scores = vec![0.0_f64; x_sparse.nrows()];
6119 for col in 0..x_sparse.ncols() {
6120 let coeff = v[col];
6121 for ptr in col_ptr[col]..col_ptr[col + 1] {
6122 row_scores[row_idx[ptr]] += values[ptr] * coeff;
6123 }
6124 }
6125 let mut cubic = 0.0_f64;
6126 for i in 0..row_scores.len().min(c_weights.len()) {
6127 cubic += c_weights[i] * row_scores[i].powi(3);
6128 }
6129 cubic
6130 }
6131 None => {
6132 let x_dense = design.to_dense_cow();
6133 let x_dense = x_dense.as_ref();
6134 let mut cubic = 0.0_f64;
6135 for i in 0..x_dense.nrows().min(c_weights.len()) {
6136 let proj = x_dense.row(i).dot(v);
6137 cubic += c_weights[i] * proj.powi(3);
6138 }
6139 cubic
6140 }
6141 }
6142}
6143
6144fn directional_cubic_gradient(
6147 design: &DesignMatrix,
6148 c_weights: &Array1<f64>,
6149 v: &Array1<f64>,
6150) -> Array1<f64> {
6151 let p = v.len();
6152 match design.as_sparse() {
6153 Some(x_sparse) => {
6154 let (symbolic, values) = x_sparse.as_ref().parts();
6155 let col_ptr = symbolic.col_ptr();
6156 let row_idx = symbolic.row_idx();
6157 let n = x_sparse.nrows();
6158 let mut row_scores = vec![0.0_f64; n];
6159 for col in 0..x_sparse.ncols() {
6160 let coeff = v[col];
6161 for ptr in col_ptr[col]..col_ptr[col + 1] {
6162 row_scores[row_idx[ptr]] += values[ptr] * coeff;
6163 }
6164 }
6165 let mut quad_weights = vec![0.0_f64; n];
6167 for i in 0..n.min(c_weights.len()) {
6168 quad_weights[i] = 3.0 * c_weights[i] * row_scores[i] * row_scores[i];
6169 }
6170 let mut grad = Array1::<f64>::zeros(p);
6172 for col in 0..x_sparse.ncols() {
6173 let mut acc = 0.0_f64;
6174 for ptr in col_ptr[col]..col_ptr[col + 1] {
6175 acc += values[ptr] * quad_weights[row_idx[ptr]];
6176 }
6177 grad[col] = acc;
6178 }
6179 grad
6180 }
6181 None => {
6182 let x_dense = design.to_dense_cow();
6183 let x_dense = x_dense.as_ref();
6184 let n = x_dense.nrows();
6185 let mut grad = Array1::<f64>::zeros(p);
6186 for i in 0..n.min(c_weights.len()) {
6187 let proj = x_dense.row(i).dot(v);
6188 let w = 3.0 * c_weights[i] * proj * proj;
6189 let row = x_dense.row(i);
6191 for j in 0..p {
6192 grad[j] += w * row[j];
6193 }
6194 }
6195 grad
6196 }
6197 }
6198}
6199
6200fn cubic_power_iteration_refinement(
6206 design: &DesignMatrix,
6207 c_weights: &Array1<f64>,
6208 evals: &Array1<f64>,
6209 evecs: &Array2<f64>,
6210 positive_mask: &[bool],
6211 n_pos: usize,
6212) -> f64 {
6213 let p = evals.len();
6214 let max_probes = 8;
6215 let max_iters = 5;
6216
6217 let to_original = |u: &Array1<f64>| -> Array1<f64> {
6220 let mut v = Array1::<f64>::zeros(p);
6221 let mut idx = 0;
6222 for r in 0..p {
6223 if positive_mask[r] {
6224 let scale = u[idx] / evals[r].sqrt();
6225 let col = evecs.column(r);
6226 for j in 0..p {
6227 v[j] += scale * col[j];
6228 }
6229 idx += 1;
6230 }
6231 }
6232 v
6233 };
6234
6235 let to_whitened = |g: &Array1<f64>| -> Array1<f64> {
6237 let mut u = Array1::<f64>::zeros(n_pos);
6238 let mut idx = 0;
6239 for r in 0..p {
6240 if positive_mask[r] {
6241 u[idx] = evals[r].sqrt() * evecs.column(r).dot(g);
6242 idx += 1;
6243 }
6244 }
6245 u
6246 };
6247
6248 let eval_gamma = |u: &Array1<f64>| -> f64 {
6250 let norm = u.dot(u).sqrt();
6251 if norm < 1e-30 {
6252 return 0.0;
6253 }
6254 let u_normed: Array1<f64> = u / norm;
6255 let v = to_original(&u_normed);
6256 let cubic = directional_cubic_contraction(design, c_weights, &v.view());
6258 if cubic.is_finite() { cubic.abs() } else { 0.0 }
6259 };
6260
6261 let refine_step = |u: &Array1<f64>| -> Array1<f64> {
6263 let norm = u.dot(u).sqrt();
6264 if norm < 1e-30 {
6265 return u.clone();
6266 }
6267 let u_normed: Array1<f64> = u / norm;
6268 let v = to_original(&u_normed);
6269 let grad_v = directional_cubic_gradient(design, c_weights, &v);
6271 let mut grad_u = to_whitened(&grad_v);
6273 let dot = grad_u.dot(&u_normed);
6275 grad_u.scaled_add(-dot, &u_normed);
6276 let cubic_val = directional_cubic_contraction(design, c_weights, &v.view());
6278 let sign = if cubic_val >= 0.0 { 1.0 } else { -1.0 };
6279 let step_size = 0.3;
6280 let mut u_new = &u_normed + &(&grad_u * (sign * step_size));
6281 let new_norm = u_new.dot(&u_new).sqrt();
6282 if new_norm > 1e-30 {
6283 u_new /= new_norm;
6284 }
6285 u_new
6286 };
6287
6288 let mut best = 0.0_f64;
6289
6290 let mut seeds: Vec<Array1<f64>> = Vec::with_capacity(max_probes);
6296
6297 let mut best_eig_idx = 0;
6300 let mut best_eig_gamma = 0.0_f64;
6301 for j in 0..n_pos {
6302 let mut u = Array1::<f64>::zeros(n_pos);
6303 u[j] = 1.0;
6304 let g = eval_gamma(&u);
6305 if g > best_eig_gamma {
6306 best_eig_gamma = g;
6307 best_eig_idx = j;
6308 }
6309 }
6310 best = best.max(best_eig_gamma);
6311 let mut u_best = Array1::<f64>::zeros(n_pos);
6312 u_best[best_eig_idx] = 1.0;
6313 seeds.push(u_best);
6314
6315 let n_top = n_pos.min(4);
6317 for i in 0..n_top {
6318 for j in (i + 1)..n_top {
6319 if seeds.len() >= max_probes {
6320 break;
6321 }
6322 let inv_sqrt2 = std::f64::consts::FRAC_1_SQRT_2;
6323 let mut u_plus = Array1::<f64>::zeros(n_pos);
6324 u_plus[i] = inv_sqrt2;
6325 u_plus[j] = inv_sqrt2;
6326 seeds.push(u_plus);
6327 if seeds.len() < max_probes {
6328 let mut u_minus = Array1::<f64>::zeros(n_pos);
6329 u_minus[i] = inv_sqrt2;
6330 u_minus[j] = -inv_sqrt2;
6331 seeds.push(u_minus);
6332 }
6333 }
6334 }
6335
6336 for seed in &seeds {
6338 let mut u = seed.clone();
6339 for _ in 0..max_iters {
6340 u = refine_step(&u);
6341 }
6342 let g = eval_gamma(&u);
6343 best = best.max(g);
6344 }
6345
6346 best
6347}
6348
6349pub use gam_problem::laplace_sampler_contract::{
6359 BlockExcessTarget, BlockSampledMarginal, BlockSampledMoments, GaussianModePosterior,
6360 LaplaceTrustworthiness, laplace_skewness_threshold, laplace_trustworthiness_from_skewness,
6361};
6362
6363pub struct HmcIoLaplaceMarginalSampler;
6369
6370impl gam_problem::laplace_sampler_contract::LaplaceMarginalSampler for HmcIoLaplaceMarginalSampler {
6371 fn directional_cubic_diagnostic(
6372 &self,
6373 hessian: &Array2<f64>,
6374 design: &DesignMatrix,
6375 c_weights: &Array1<f64>,
6376 refine_supremum: bool,
6377 ) -> Result<(f64, Array1<f64>), String> {
6378 laplace_directional_cubic_diagnostic(hessian, design, c_weights, refine_supremum)
6379 }
6380
6381 fn block_sampled_marginal_correction(
6382 &self,
6383 target: &dyn BlockExcessTarget,
6384 ) -> Result<BlockSampledMarginal, String> {
6385 block_sampled_marginal_correction(target)
6386 }
6387}
6388
6389pub struct HmcIoGaussianModePosteriorSampler;
6396
6397impl gam_problem::laplace_sampler_contract::GaussianModePosteriorSampler
6398 for HmcIoGaussianModePosteriorSampler
6399{
6400 fn sample_gaussian_mode_posterior(
6401 &self,
6402 mode: ArrayView1<f64>,
6403 precision: ArrayView2<f64>,
6404 ) -> Result<GaussianModePosterior, String> {
6405 let config = NutsConfig::for_dimension(mode.len());
6406 sample_gaussian_mode_posterior(mode, precision, &config)
6407 }
6408}
6409
/// Number of importance draws used for a sampled block of dimension `block_dim`.
///
/// Grows linearly (`256 + 256 * block_dim`) and is capped at 4096. Saturating arithmetic keeps
/// absurdly large `block_dim` values at the cap. Without it, the product would overflow: a panic
/// in debug builds, or a wrap to a tiny draw count in release builds.
fn block_sampling_draws(block_dim: usize) -> usize {
    const BASE: usize = 256;
    const PER_DIM: usize = 256;
    const CAP: usize = 4096;
    PER_DIM
        .saturating_mul(block_dim)
        .saturating_add(BASE)
        .min(CAP)
}
6423
6424pub fn block_sampled_marginal_correction<T: BlockExcessTarget + ?Sized>(
6450 target: &T,
6451) -> Result<BlockSampledMarginal, String> {
6452 use rand::SeedableRng;
6453 use rand::rngs::StdRng;
6454
6455 let m = target.block_dim();
6456 let k = target.rho_dim();
6457 if m == 0 {
6458 return Ok(BlockSampledMarginal {
6459 value: 0.0,
6460 rho_gradient: Array1::zeros(k),
6461 importance_ess: 0.0,
6462 n_draws: 0,
6463 moments: None,
6464 });
6465 }
6466 let lambdas = target.block_curvatures();
6467 if lambdas.len() != m {
6468 return Err(format!(
6469 "block_sampled_marginal_correction: block_curvatures len {} != block_dim {m}",
6470 lambdas.len()
6471 ));
6472 }
6473 let inv_sqrt_lambda: Array1<f64> = lambdas.mapv(|l| {
6474 if l > 0.0 {
6475 1.0 / l.sqrt()
6476 } else {
6477 f64::NAN
6481 }
6482 });
6483 if inv_sqrt_lambda.iter().any(|v| !v.is_finite()) {
6484 return Err(
6485 "block_sampled_marginal_correction: non-positive block curvature (mode is not a \
6486 strict local minimum in a sampled direction)"
6487 .to_string(),
6488 );
6489 }
6490
6491 let n_draws = block_sampling_draws(m);
6492 let mut seed_bits: u64 = 0x9E37_79B9_7F4A_7C15;
6509 seed_bits ^= (m as u64).rotate_left(17);
6510 seed_bits = seed_bits.wrapping_mul(0x1000_0000_01B3);
6511 seed_bits ^= (k as u64).rotate_left(31);
6512 seed_bits = seed_bits.wrapping_mul(0x1000_0000_01B3);
6513 let mut rng = StdRng::seed_from_u64(seed_bits);
6514
6515 let n_obs = target.base_neg_score().len();
6525 let mut max_lw = f64::NEG_INFINITY;
6526 let mut sum_w = 0.0_f64;
6527 let mut sum_w2 = 0.0_f64;
6528 let mut grad_acc = Array1::<f64>::zeros(k);
6529 let mut e_t_acc = Array1::<f64>::zeros(m);
6530 let mut e_tt_acc = Array2::<f64>::zeros((m, m));
6531 let mut e_ngs_acc = Array1::<f64>::zeros(n_obs);
6532 let mut e_t_ngs_acc = Array2::<f64>::zeros((n_obs, m));
6533
6534 let mut draws = Array2::<f64>::zeros((m, n_draws));
6542 for s in 0..n_draws {
6543 let mut col = draws.column_mut(s);
6544 for r in 0..m {
6545 let z = sample_standard_normal(&mut rng);
6546 col[r] = z * inv_sqrt_lambda[r];
6547 }
6548 }
6549 let batched = target.excess_with_displaced_neg_score_batch(&draws);
6550
6551 let mut t = Array1::<f64>::zeros(m);
6552 for (sidx, (excess, displaced_ngs)) in batched.into_iter().enumerate() {
6553 t.assign(&draws.column(sidx));
6554 if !excess.is_finite() {
6555 continue;
6556 }
6557 let Some(ngs) = displaced_ngs else {
6558 continue;
6560 };
6561 let lw = -excess;
6562 if lw > max_lw {
6563 let rescale = (max_lw - lw).exp();
6566 sum_w *= rescale;
6567 sum_w2 *= rescale * rescale;
6568 grad_acc *= rescale;
6569 e_t_acc *= rescale;
6570 e_tt_acc *= rescale;
6571 e_ngs_acc *= rescale;
6572 e_t_ngs_acc *= rescale;
6573 max_lw = lw;
6574 }
6575 let w = (lw - max_lw).exp();
6576 sum_w += w;
6577 sum_w2 += w * w;
6578 grad_acc.scaled_add(-w, &target.excess_rho_gradient(&t));
6580 if ngs.len() != n_obs {
6582 return Err(format!(
6583 "block_sampled_marginal_correction: displaced_neg_score len {} != {n_obs}",
6584 ngs.len()
6585 ));
6586 }
6587 e_t_acc.scaled_add(w, &t);
6588 e_ngs_acc.scaled_add(w, &ngs);
6589 for r in 0..m {
6590 let wt_r = w * t[r];
6591 for q in 0..m {
6592 e_tt_acc[(q, r)] += wt_r * t[q];
6593 }
6594 e_t_ngs_acc.column_mut(r).scaled_add(wt_r, &ngs);
6595 }
6596 }
6597 if !max_lw.is_finite() {
6598 return Err(
6599 "block_sampled_marginal_correction: all importance draws were infeasible".to_string(),
6600 );
6601 }
6602 let value = max_lw + (sum_w / n_draws as f64).ln();
6603 let (rho_gradient, moments) = if sum_w > 0.0 {
6605 (
6606 grad_acc / sum_w,
6607 Some(BlockSampledMoments {
6608 e_t: e_t_acc / sum_w,
6609 e_tt: e_tt_acc / sum_w,
6610 e_neg_score: e_ngs_acc / sum_w,
6611 e_t_neg_score: e_t_ngs_acc / sum_w,
6612 }),
6613 )
6614 } else {
6615 (Array1::zeros(k), None)
6616 };
6617 let importance_ess = if sum_w2 > 0.0 {
6619 (sum_w * sum_w) / sum_w2
6620 } else {
6621 0.0
6622 };
6623
6624 if !value.is_finite() || rho_gradient.iter().any(|v| !v.is_finite()) {
6625 return Err(
6626 "block_sampled_marginal_correction: produced a non-finite correction or gradient"
6627 .to_string(),
6628 );
6629 }
6630 if let Some(mo) = moments.as_ref()
6631 && (mo.e_t.iter().any(|v| !v.is_finite())
6632 || mo.e_tt.iter().any(|v| !v.is_finite())
6633 || mo.e_neg_score.iter().any(|v| !v.is_finite())
6634 || mo.e_t_neg_score.iter().any(|v| !v.is_finite()))
6635 {
6636 return Err(
6637 "block_sampled_marginal_correction: produced non-finite gradient-channel moments"
6638 .to_string(),
6639 );
6640 }
6641
6642 Ok(BlockSampledMarginal {
6643 value,
6644 rho_gradient,
6645 importance_ess,
6646 n_draws,
6647 moments,
6648 })
6649}
6650
/// Output of `run_joint_beta_rho_sampling`: pooled draws from all chains for the coefficients,
/// the log smoothing parameters and any adaptive inverse-link parameters.
#[derive(Clone, Debug)]
pub struct JointBetaRhoResult {
    /// Coefficient draws mapped back from whitened coordinates. Each row is one draw;
    /// `beta_mean` averages over rows.
    pub beta_samples: Array2<f64>,
    /// Log smoothing-parameter (ρ) draws, shape `(n_chains * n_samples, n_rho)`, chain-major.
    pub rho_samples: Array2<f64>,
    /// Row mean of `beta_samples` (zeros when there are no draws).
    pub beta_mean: Array1<f64>,
    /// Adaptive inverse-link parameter draws, shape `(n_chains * n_samples, n_link_params)`.
    /// There are zero columns for links without free parameters.
    pub link_param_samples: Array2<f64>,
    /// Row mean of `link_param_samples` (zeros when there are no draws).
    pub link_param_mean: Array1<f64>,
    /// Row mean of `rho_samples` (zeros when there are no draws).
    pub rho_mean: Array1<f64>,
    /// Split R-hat summary computed on the raw sampler output.
    pub rhat: f64,
    /// Effective-sample-size summary computed on the raw sampler output.
    pub ess: f64,
    /// Whether `rhat`/`ess` pass the thresholds used by the sampler (max R-hat 1.1, min ESS 50).
    pub converged: bool,
    /// Skewness value that triggered this joint run; echoed from the inputs for reporting.
    pub trigger_skewness: f64,
}
6675
/// Joint log-posterior over whitened coefficients `z`, log smoothing parameters `ρ` and adaptive
/// inverse-link parameters. Positions are laid out as `[z (n_beta) | ρ (n_rho) | link params]`.
struct JointBetaRhoPosterior {
    /// Design, response, weights and coefficient mode shared with the likelihood kernels.
    data: SharedData,
    /// Whitening factor: coefficients are recovered as `β = mode + chol · z`.
    chol: Array2<f64>,
    /// Transpose-side factor used to map the β-gradient into the z-gradient.
    chol_t: Array2<f64>,
    /// Response family and the base inverse link; link parameters are substituted per step.
    likelihood: LikelihoodSpec,
    /// Number of coefficients (columns of the design).
    n_beta: usize,
    /// Number of smoothing parameters (one per canonical penalty).
    n_rho: usize,
    /// Number of free inverse-link parameters (0 for fixed links).
    n_link_params: usize,
    /// Link parameters at the fitted mode; used to centre the initial positions.
    link_param_mode: Array1<f64>,
    /// Canonical penalty roots, one per ρ coordinate.
    penalty_canonical: Vec<gam_terms::construction::CanonicalPenalty>,
    /// Prior on ρ (shared across coordinates, or `Independent` per coordinate).
    rho_prior: RhoPrior,
    /// ρ at the fitted mode; used to centre the initial positions.
    rho_mode: Array1<f64>,
    /// Whether the Jeffreys (Firth) term is added to the log-likelihood.
    firth_enabled: bool,
    /// Single-entry cache of the penalty pseudo-log-determinant: `(hash of ρ, log|S|₊, ∂/∂ρ)`.
    /// It is keyed only by the 64-bit hash of ρ.
    penalty_logdet_cache: Mutex<Option<(u64, f64, Array1<f64>)>>,
}
6718
6719impl JointBetaRhoPosterior {
6720 fn new(
6721 x: ArrayView2<f64>,
6722 y: ArrayView1<f64>,
6723 weights: ArrayView1<f64>,
6724 mode: ArrayView1<f64>,
6725 hessian: ArrayView2<f64>,
6726 penalty_canonical: Vec<gam_terms::construction::CanonicalPenalty>,
6727 rho_mode: ArrayView1<f64>,
6728 likelihood: LikelihoodSpec,
6729 gamma_shape: Option<f64>,
6730 rho_prior: RhoPrior,
6731 firth_enabled: bool,
6732 ) -> Result<Self, String> {
6733 let n_samples = x.nrows();
6734 let n_beta = x.ncols();
6735 let n_rho = penalty_canonical.len();
6736
6737 if rho_mode.len() != n_rho {
6738 return Err(HmcError::DimensionMismatch {
6739 reason: format!(
6740 "rho_mode length {} != penalty count {}",
6741 rho_mode.len(),
6742 n_rho
6743 ),
6744 }
6745 .into());
6746 }
6747
6748 match (&likelihood.response, &likelihood.link) {
6749 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {}
6750 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Probit)) => {}
6751 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::CLogLog)) => {}
6752 (ResponseFamily::Binomial, InverseLink::LatentCLogLog(_)) => {}
6753 (ResponseFamily::Binomial, InverseLink::Sas(_)) => {}
6754 (ResponseFamily::Binomial, InverseLink::BetaLogistic(_)) => {}
6755 (ResponseFamily::Binomial, InverseLink::Mixture(_)) => {}
6756 (ResponseFamily::Binomial, InverseLink::Standard(other)) => {
6757 return Err(HmcError::LinkMismatch {
6758 reason: format!(
6759 "Joint HMC binomial response requires a binomial-compatible inverse link; got {:?}",
6760 other
6761 ),
6762 }
6763 .into());
6764 }
6765 (ResponseFamily::Gaussian, InverseLink::Standard(StandardLink::Identity)) => {}
6766 (ResponseFamily::Gaussian, _) => {
6767 return Err(HmcError::LinkMismatch {
6768 reason: "Joint HMC Gaussian requires an identity inverse link".to_string(),
6769 }
6770 .into());
6771 }
6772 (
6773 ResponseFamily::Poisson
6774 | ResponseFamily::Tweedie { .. }
6775 | ResponseFamily::NegativeBinomial { .. }
6776 | ResponseFamily::Gamma,
6777 InverseLink::Standard(StandardLink::Log),
6778 ) => {}
6779 (
6780 ResponseFamily::Poisson
6781 | ResponseFamily::Tweedie { .. }
6782 | ResponseFamily::NegativeBinomial { .. }
6783 | ResponseFamily::Gamma,
6784 _,
6785 ) => {
6786 return Err(HmcError::LinkMismatch {
6787 reason: "Joint HMC log-link family requires a log inverse link".to_string(),
6788 }
6789 .into());
6790 }
6791 (ResponseFamily::Beta { .. }, InverseLink::Standard(StandardLink::Logit)) => {}
6792 (ResponseFamily::Beta { .. }, _) => {
6793 return Err(HmcError::LinkMismatch {
6794 reason: "Joint HMC Beta requires a logit inverse link".to_string(),
6795 }
6796 .into());
6797 }
6798 (ResponseFamily::RoystonParmar, _) => {
6799 return Err(HmcError::UnsupportedFamily {
6800 reason: "Joint HMC fallback is not implemented for RoystonParmar".to_string(),
6801 }
6802 .into());
6803 }
6804 }
6805
6806 validate_firth_likelihood_support(&likelihood, firth_enabled).map_err(String::from)?;
6807 if matches!(likelihood.response, ResponseFamily::NegativeBinomial { .. }) {
6808 validate_count_responses("negative-binomial joint HMC", &y, &weights)
6809 .map_err(String::from)?;
6810 }
6811 if likelihood.is_binomial() {
6812 validate_binary_responses("binomial joint HMC", &y, &weights).map_err(String::from)?;
6813 }
6814
6815 let whitening = hessian_whitening_transform(
6816 hessian,
6817 n_beta,
6818 1.0,
6819 "Joint HMC: Hessian Cholesky failed",
6820 )?;
6821 let chol = whitening.chol;
6822 let chol_t = whitening.chol_t;
6823
6824 let data = SharedData {
6825 x: Arc::new(x.to_owned()),
6826 y: Arc::new(y.to_owned()),
6827 weights: Arc::new(weights.to_owned()),
6828 mode: Arc::new(mode.to_owned()),
6829 offset: None,
6830 gamma_shape: gamma_shape.unwrap_or(1.0),
6831 dispersion: gam_solve::model_types::Dispersion::Known(1.0),
6836 n_samples,
6837 dim: n_beta,
6838 };
6839 let link_param_mode = Self::link_param_mode(&likelihood.link);
6840
6841 Ok(Self {
6842 data,
6843 chol,
6844 chol_t,
6845 likelihood,
6846 n_beta,
6847 n_rho,
6848 n_link_params: link_param_mode.len(),
6849 link_param_mode,
6850 penalty_canonical,
6851 rho_prior,
6852 rho_mode: rho_mode.to_owned(),
6853 firth_enabled,
6854 penalty_logdet_cache: Mutex::new(None),
6855 })
6856 }
6857
6858 #[inline]
6864 fn hash_rho(rho: ndarray::ArrayView1<f64>) -> u64 {
6865 let mut h: u64 = 0xcbf2_9ce4_8422_2325;
6866 for &x in rho.iter() {
6867 h ^= x.to_bits();
6868 h = h.wrapping_mul(0x0000_0100_0000_01b3);
6869 }
6870 h
6871 }
6872
6873 fn link_param_mode(inverse_link: &InverseLink) -> Array1<f64> {
6874 match inverse_link {
6875 InverseLink::Sas(state) | InverseLink::BetaLogistic(state) => {
6876 Array1::from_vec(vec![state.epsilon, state.log_delta])
6877 }
6878 InverseLink::Mixture(state) => state.rho.clone(),
6879 InverseLink::Standard(_) | InverseLink::LatentCLogLog(_) => Array1::zeros(0),
6880 }
6881 }
6882
6883 fn inverse_link_with_params(
6884 &self,
6885 link_params: ndarray::ArrayView1<'_, f64>,
6886 ) -> Result<InverseLink, String> {
6887 match &self.likelihood.link {
6888 InverseLink::Sas(_) => {
6889 if link_params.len() != 2 {
6890 return Err(format!(
6891 "SAS link parameter length must be 2, got {}",
6892 link_params.len()
6893 ));
6894 }
6895 Ok(InverseLink::Sas(
6896 gam_solve::mixture_link::sas_link_state_from_raw(
6897 link_params[0],
6898 link_params[1],
6899 )?,
6900 ))
6901 }
6902 InverseLink::BetaLogistic(_) => {
6903 if link_params.len() != 2 || !link_params.iter().all(|v| v.is_finite()) {
6904 return Err(
6905 "Beta-Logistic link parameters must be finite with length 2".to_string()
6906 );
6907 }
6908 Ok(InverseLink::BetaLogistic(
6909 gam_problem::types::SasLinkState {
6910 epsilon: link_params[0],
6911 log_delta: link_params[1],
6912 delta: link_params[1].exp(),
6913 },
6914 ))
6915 }
6916 InverseLink::Mixture(state) => {
6917 let rho = link_params.to_owned();
6918 Ok(InverseLink::Mixture(gam_problem::types::MixtureLinkState {
6919 components: state.components.clone(),
6920 pi: softmax_last_fixedzero(&rho),
6921 rho,
6922 }))
6923 }
6924 InverseLink::Standard(_) | InverseLink::LatentCLogLog(_) => {
6925 Ok(self.likelihood.link.clone())
6926 }
6927 }
6928 }
6929
6930 fn compute_joint_logp_and_grad_into(
6943 &self,
6944 params: &Array1<f64>,
6945 out_grad: &mut Array1<f64>,
6946 ) -> f64 {
6947 let n_beta = self.n_beta;
6948 let n_rho = self.n_rho;
6949 let n_link_params = self.n_link_params;
6950
6951 let z = params.slice(ndarray::s![..n_beta]);
6954 let rho = params.slice(ndarray::s![n_beta..n_beta + n_rho]);
6955 let link_params = params.slice(ndarray::s![n_beta + n_rho..]);
6956 let lambdas: Array1<f64> = rho.mapv(f64::exp);
6957
6958 let inverse_link = match self.inverse_link_with_params(link_params) {
6959 Ok(link) => link,
6960 Err(err) => {
6961 log::warn!(
6962 "[Joint HMC] adaptive inverse-link parameters are invalid: {}",
6963 err
6964 );
6965 out_grad.fill(0.0);
6966 return f64::NEG_INFINITY;
6967 }
6968 };
6969
6970 let beta = self.data.mode.as_ref() + &self.chol.dot(&z);
6972
6973 let eta = gam_linalg::faer_ndarray::fast_av(self.data.x.as_ref(), &beta);
6975
6976 let step_likelihood = LikelihoodSpec {
6978 response: self.likelihood.response.clone(),
6979 link: inverse_link,
6980 };
6981 let (ll, mut grad_ll_beta, grad_link) = match joint_family_logp_grad_and_link_grad(
6982 &step_likelihood,
6983 &self.data,
6984 &eta,
6985 n_link_params,
6986 ) {
6987 Ok(value) => value,
6988 Err(err) => {
6989 log::warn!(
6990 "[Joint HMC] likelihood target became invalid at the current state: {}",
6991 err
6992 );
6993 out_grad.fill(0.0);
6994 return f64::NEG_INFINITY;
6995 }
6996 };
6997
6998 let mut firth_logdet = 0.0;
6999 if self.firth_enabled {
7000 match firth_jeffreys_logp_and_grad(NutsFamily::BinomialLogit, &self.data, &eta) {
7001 Ok((value, grad_beta_firth)) => {
7002 firth_logdet = value;
7003 grad_ll_beta += &grad_beta_firth;
7004 }
7005 Err(err) => {
7006 log::warn!(
7007 "[Joint HMC/Firth] Jeffreys target became invalid at the current state: {}",
7008 err
7009 );
7010 out_grad.fill(0.0);
7011 return f64::NEG_INFINITY;
7012 }
7013 }
7014 }
7015
7016 let mut penalty_val = 0.0;
7020 let mut s_beta = Array1::<f64>::zeros(n_beta);
7021 let mut grad_rho = Array1::<f64>::zeros(n_rho);
7022
7023 let max_rank = self
7027 .penalty_canonical
7028 .iter()
7029 .map(|cp| cp.rank())
7030 .max()
7031 .unwrap_or(0);
7032 let mut r_beta_scratch = Array1::<f64>::zeros(max_rank);
7033
7034 for (k, cp) in self.penalty_canonical.iter().enumerate() {
7035 let r = &cp.col_range;
7037 let beta_block = beta.slice(ndarray::s![r.start..r.end]);
7038 let rank_k = cp.rank();
7039 gam_linalg::faer_ndarray::fast_av_view_into(
7040 &cp.root,
7041 &beta_block,
7042 r_beta_scratch.slice_mut(ndarray::s![..rank_k]),
7043 );
7044 let r_beta = r_beta_scratch.slice(ndarray::s![..rank_k]);
7045 let quad_k = r_beta.dot(&r_beta);
7046 penalty_val += 0.5 * lambdas[k] * quad_k;
7047
7048 for a in 0..cp.block_dim() {
7050 let val: f64 = (0..rank_k).map(|row| cp.root[[row, a]] * r_beta[row]).sum();
7051 s_beta[r.start + a] += lambdas[k] * val;
7052 }
7053
7054 grad_rho[k] = -0.5 * lambdas[k] * quad_k;
7056 }
7057
7058 let log_det_s = if self.penalty_canonical.is_empty() {
7066 0.0
7067 } else {
7068 let rho_hash = Self::hash_rho(rho);
7069 let cached = self.penalty_logdet_cache.lock().ok().and_then(|guard| {
7070 guard.as_ref().and_then(|(h, v, g)| {
7071 if *h == rho_hash && g.len() == n_rho {
7072 for k in 0..n_rho {
7073 grad_rho[k] += 0.5 * g[k];
7074 }
7075 Some(*v)
7076 } else {
7077 None
7078 }
7079 })
7080 });
7081 if let Some(hit) = cached {
7082 hit
7083 } else {
7084 match PenaltyPseudologdet::from_penalties(
7085 &self.penalty_canonical,
7086 lambdas.as_slice().unwrap_or(&[]),
7087 0.0,
7088 n_beta,
7089 ) {
7090 Ok(pld) => {
7091 let (det1, _) = pld.rho_derivatives_from_penalties(
7092 &self.penalty_canonical,
7093 lambdas.as_slice().unwrap_or(&[]),
7094 );
7095 let value = pld.value();
7096 if let Ok(mut guard) = self.penalty_logdet_cache.lock() {
7097 *guard = Some((rho_hash, value, det1.clone()));
7098 }
7099 for k in 0..n_rho {
7100 grad_rho[k] += 0.5 * det1[k];
7101 }
7102 value
7103 }
7104 Err(err) => {
7105 log::warn!(
7106 "[Joint HMC] structural penalty logdet became invalid at the current state: {}",
7107 err
7108 );
7109 out_grad.fill(0.0);
7110 return f64::NEG_INFINITY;
7111 }
7112 }
7113 }
7114 };
7115
7116 let mut rho_prior = 0.0;
7118 match &self.rho_prior {
7119 RhoPrior::Flat => {}
7120 RhoPrior::Normal { mean, sd } => {
7121 let inv_var = 1.0 / (*sd * *sd);
7122 for k in 0..n_rho {
7123 let d = rho[k] - *mean;
7124 rho_prior -= 0.5 * inv_var * d * d;
7125 grad_rho[k] -= inv_var * d;
7126 }
7127 }
7128 RhoPrior::GammaPrecision { shape, rate } => {
7129 for k in 0..n_rho {
7130 let lambda = rho[k].exp();
7131 rho_prior += *shape * rho[k] - *rate * lambda;
7133 grad_rho[k] += *shape - *rate * lambda;
7134 }
7135 }
7136 RhoPrior::PenalizedComplexity { upper, tail_prob } => {
7137 if !pc_prior_params_valid(*upper, *tail_prob) {
7138 out_grad.fill(0.0);
7139 return f64::NEG_INFINITY;
7140 }
7141 let theta = -tail_prob.ln() / *upper;
7142 for k in 0..n_rho {
7143 let e = (-0.5 * rho[k]).exp();
7145 rho_prior += -0.5 * rho[k] - theta * e;
7146 grad_rho[k] += -0.5 + 0.5 * theta * e;
7147 }
7148 }
7149 RhoPrior::Independent(priors) => {
7150 if priors.len() != n_rho {
7151 out_grad.fill(0.0);
7152 return f64::NEG_INFINITY;
7153 }
7154 for k in 0..n_rho {
7155 match &priors[k] {
7156 RhoPrior::Flat => {}
7157 RhoPrior::Normal { mean, sd } => {
7158 let inv_var = 1.0 / (*sd * *sd);
7159 let d = rho[k] - *mean;
7160 rho_prior -= 0.5 * inv_var * d * d;
7161 grad_rho[k] -= inv_var * d;
7162 }
7163 RhoPrior::GammaPrecision { shape, rate } => {
7164 let lambda = rho[k].exp();
7165 rho_prior += *shape * rho[k] - *rate * lambda;
7167 grad_rho[k] += *shape - *rate * lambda;
7168 }
7169 RhoPrior::PenalizedComplexity { upper, tail_prob } => {
7170 if !pc_prior_params_valid(*upper, *tail_prob) {
7171 out_grad.fill(0.0);
7172 return f64::NEG_INFINITY;
7173 }
7174 let theta = -tail_prob.ln() / *upper;
7175 let e = (-0.5 * rho[k]).exp();
7176 rho_prior += -0.5 * rho[k] - theta * e;
7177 grad_rho[k] += -0.5 + 0.5 * theta * e;
7178 }
7179 RhoPrior::Independent(_) => {
7180 out_grad.fill(0.0);
7181 return f64::NEG_INFINITY;
7182 }
7183 }
7184 }
7185 }
7186 }
7187
7188 let logp = ll + firth_logdet - penalty_val + 0.5 * log_det_s + rho_prior;
7190
7191 let grad_beta = &grad_ll_beta - &s_beta;
7193
7194 gam_linalg::faer_ndarray::fast_av_view_into(
7196 &self.chol_t,
7197 &grad_beta,
7198 out_grad.slice_mut(ndarray::s![..n_beta]),
7199 );
7200 out_grad
7201 .slice_mut(ndarray::s![n_beta..n_beta + n_rho])
7202 .assign(&grad_rho);
7203 out_grad
7204 .slice_mut(ndarray::s![n_beta + n_rho..])
7205 .assign(&grad_link);
7206
7207 logp
7208 }
7209}
7210
/// Whether penalized-complexity prior parameters are usable.
///
/// `upper` must be a positive finite bound and `tail_prob` must lie strictly inside `(0, 1)`.
/// Every comparison with NaN is false, so NaN inputs are rejected without an explicit check.
fn pc_prior_params_valid(upper: f64, tail_prob: f64) -> bool {
    let upper_ok = upper > 0.0 && upper < f64::INFINITY;
    let tail_ok = tail_prob > 0.0 && tail_prob < 1.0;
    upper_ok && tail_ok
}
7219
7220impl HamiltonianTarget<Array1<f64>> for JointBetaRhoPosterior {
7221 fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
7222 self.compute_joint_logp_and_grad_into(position, grad)
7223 }
7224}
7225
/// Borrowed inputs for `run_joint_beta_rho_sampling`.
pub struct JointBetaRhoInputs<'a> {
    /// Design matrix; its column count must match `mode`.
    pub x: ArrayView2<'a, f64>,
    /// Response values.
    pub y: ArrayView1<'a, f64>,
    /// Observation weights.
    pub weights: ArrayView1<'a, f64>,
    /// Response family and inverse link; adaptive links add sampled link parameters.
    pub likelihood: LikelihoodSpec,
    /// Optional Gamma-family shape; the posterior substitutes 1.0 when absent.
    pub gamma_shape: Option<f64>,
    /// Coefficients at the fitted mode; its length defines the number of β coordinates.
    pub mode: ArrayView1<'a, f64>,
    /// Hessian at the mode, used to build the whitening transform.
    pub hessian: ArrayView2<'a, f64>,
    /// Canonical penalty roots, one per smoothing parameter ρ.
    pub penalty_roots: Vec<CanonicalPenalty>,
    /// Log smoothing parameters at the mode; chain initial positions are jittered around it.
    pub rho_mode: ArrayView1<'a, f64>,
    /// Prior placed on ρ.
    pub rho_prior: RhoPrior,
    /// Whether to add the Jeffreys (Firth) term to the log-likelihood.
    pub firth_bias_reduction: bool,
    /// Skewness value that triggered joint sampling; logged and echoed into the result.
    pub trigger_skewness: f64,
}
7242
7243pub fn run_joint_beta_rho_sampling(
7249 inputs: &JointBetaRhoInputs<'_>,
7250 config: &NutsConfig,
7251) -> Result<JointBetaRhoResult, String> {
7252 validate_firth_likelihood_support(&inputs.likelihood, inputs.firth_bias_reduction)
7253 .map_err(String::from)?;
7254 validate_nuts_config(config).map_err(String::from)?;
7255 let n_beta = inputs.mode.len();
7256 let n_rho = inputs.penalty_roots.len();
7257 let n_link_params = JointBetaRhoPosterior::link_param_mode(&inputs.likelihood.link).len();
7258 let total_dim = n_beta + n_rho + n_link_params;
7259
7260 log::info!(
7261 "[Joint HMC] Sampling (β, ρ, link) jointly: {} β-params + {} ρ-params + {} link-params = {} total (triggered by skewness {:.3})",
7262 n_beta,
7263 n_rho,
7264 n_link_params,
7265 total_dim,
7266 inputs.trigger_skewness,
7267 );
7268
7269 let target = JointBetaRhoPosterior::new(
7270 inputs.x,
7271 inputs.y,
7272 inputs.weights,
7273 inputs.mode,
7274 inputs.hessian,
7275 inputs.penalty_roots.clone(),
7276 inputs.rho_mode,
7277 inputs.likelihood.clone(),
7278 inputs.gamma_shape,
7279 inputs.rho_prior.clone(),
7280 inputs.firth_bias_reduction,
7281 )?;
7282
7283 let chol = target.chol.clone();
7284 let mode_arr = target.data.mode.clone();
7285 let rho_mode = target.rho_mode.clone();
7286 let link_param_mode = target.link_param_mode.clone();
7287
7288 let initial_positions: Vec<Array1<f64>> = (0..config.n_chains)
7290 .map(|chain| {
7291 let mut rng =
7292 StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0x9B51_6E37_F2D0_A48C));
7293 let mut pos = Array1::<f64>::zeros(total_dim);
7294 for j in 0..n_beta {
7296 pos[j] = sample_standard_normal(&mut rng) * 0.1;
7297 }
7298 for k in 0..n_rho {
7300 pos[n_beta + k] = rho_mode[k] + sample_standard_normal(&mut rng) * 0.2;
7301 }
7302 for k in 0..n_link_params {
7304 pos[n_beta + n_rho + k] =
7305 link_param_mode[k] + sample_standard_normal(&mut rng) * 0.05;
7306 }
7307 pos
7308 })
7309 .collect();
7310
7311 let mass_cfg = robust_mass_matrix_config(total_dim, config.nwarmup);
7314
7315 let (samples_array, run_stats) = run_whitened_nuts_samples(
7316 target,
7317 initial_positions,
7318 config,
7319 total_dim,
7320 mass_cfg,
7321 0x63AF_175B_D820_C94E,
7322 "Joint (β,ρ) NUTS sampling failed",
7323 )?;
7324 log::info!("[Joint HMC] Sampling complete: {}", run_stats);
7325
7326 let shape = samples_array.shape();
7328 let n_chains = shape[0];
7329 let n_samples_out = shape[1];
7330 let total_samples = n_chains * n_samples_out;
7331
7332 let beta_samples = unwhiten_samples(&samples_array, mode_arr.as_ref(), &chol, n_beta, 0);
7333 let mut rho_samples = Array2::<f64>::zeros((total_samples, n_rho));
7334 let mut link_param_samples = Array2::<f64>::zeros((total_samples, n_link_params));
7335
7336 for chain in 0..n_chains {
7337 for sample_i in 0..n_samples_out {
7338 let sample_idx = chain * n_samples_out + sample_i;
7339 let zview = samples_array.slice(ndarray::s![chain, sample_i, ..]);
7340
7341 let rho_slice = zview.slice(ndarray::s![n_beta..n_beta + n_rho]);
7343 rho_samples.row_mut(sample_idx).assign(&rho_slice);
7344 let link_slice = zview.slice(ndarray::s![n_beta + n_rho..]);
7345 link_param_samples.row_mut(sample_idx).assign(&link_slice);
7346 }
7347 }
7348
7349 let beta_mean = beta_samples
7350 .mean_axis(Axis(0))
7351 .unwrap_or_else(|| Array1::zeros(n_beta));
7352 let rho_mean = rho_samples
7353 .mean_axis(Axis(0))
7354 .unwrap_or_else(|| Array1::zeros(n_rho));
7355 let link_param_mean = link_param_samples
7356 .mean_axis(Axis(0))
7357 .unwrap_or_else(|| Array1::zeros(n_link_params));
7358
7359 let (rhat, ess) = compute_split_rhat_and_ess(&samples_array);
7360
7361 let converged = NutsConvergenceThresholds {
7362 max_rhat: 1.1,
7363 min_ess: Some(50.0),
7364 }
7365 .converged(rhat, ess);
7366 if !converged {
7367 log::warn!(
7368 "[Joint HMC] Convergence warning: R-hat={:.3}, ESS={:.1}",
7369 rhat,
7370 ess,
7371 );
7372 }
7373
7374 Ok(JointBetaRhoResult {
7375 beta_samples,
7376 rho_samples,
7377 beta_mean,
7378 link_param_samples,
7379 link_param_mean,
7380 rho_mean,
7381 rhat,
7382 ess,
7383 converged,
7384 trigger_skewness: inputs.trigger_skewness,
7385 })
7386}
7387
7388mod survival_hmc {
7393 use super::*;
7394 use gam_models::survival::{
7395 PenaltyBlocks, SurvivalEngineInputs, SurvivalMonotonicityPenalty, SurvivalSpec,
7396 WorkingModelSurvival,
7397 };
7398
7399 #[derive(Clone)]
7401 struct SharedSurvivalData {
7402 base_model: Arc<WorkingModelSurvival>,
7404 mode: Arc<Array1<f64>>,
7406 }
7407
7408 #[derive(Clone)]
7410 pub struct SurvivalPosterior {
7411 data: SharedSurvivalData,
7413 chol: Array2<f64>,
7415 chol_t: Array2<f64>,
7417 }
7418
7419 impl SurvivalPosterior {
7420 pub fn new(
7422 age_entry: ArrayView1<'_, f64>,
7423 age_exit: ArrayView1<'_, f64>,
7424 event_target: ArrayView1<'_, u8>,
7425 event_competing: ArrayView1<'_, u8>,
7426 sampleweight: ArrayView1<'_, f64>,
7427 x_entry: ArrayView2<'_, f64>,
7428 x_exit: ArrayView2<'_, f64>,
7429 x_derivative: ArrayView2<'_, f64>,
7430 offset_eta_entry: Option<ArrayView1<'_, f64>>,
7431 offset_eta_exit: Option<ArrayView1<'_, f64>>,
7432 offset_derivative_exit: Option<ArrayView1<'_, f64>>,
7433 penalties: PenaltyBlocks,
7434 monotonicity: SurvivalMonotonicityPenalty,
7435 spec: SurvivalSpec,
7436 structurally_monotonic: bool,
7437 structural_time_columns: usize,
7438 mode: ArrayView1<f64>,
7439 hessian: ArrayView2<f64>,
7440 ) -> Result<Self, String> {
7441 let n = age_entry.len();
7442 let off_eta_entry = offset_eta_entry
7443 .map(|v| v.to_owned())
7444 .unwrap_or_else(|| Array1::zeros(n));
7445 let off_eta_exit = offset_eta_exit
7446 .map(|v| v.to_owned())
7447 .unwrap_or_else(|| Array1::zeros(n));
7448 let off_deriv_exit = offset_derivative_exit
7449 .map(|v| v.to_owned())
7450 .unwrap_or_else(|| Array1::zeros(n));
7451
7452 let mut base_model = WorkingModelSurvival::from_engine_inputswith_offsets(
7453 SurvivalEngineInputs {
7454 age_entry,
7455 age_exit,
7456 event_target,
7457 event_competing,
7458 sampleweight,
7459 x_entry,
7460 x_exit,
7461 x_derivative,
7462 monotonicity_constraint_rows: None,
7463 monotonicity_constraint_offsets: None,
7464 },
7465 Some(gam_models::survival::SurvivalBaselineOffsets {
7466 eta_entry: off_eta_entry.view(),
7467 eta_exit: off_eta_exit.view(),
7468 derivative_exit: off_deriv_exit.view(),
7469 }),
7470 penalties,
7471 monotonicity,
7472 spec,
7473 )
7474 .map_err(|e| format!("Survival state construction failed: {:?}", e))?;
7475 if structurally_monotonic {
7476 base_model
7477 .set_structural_monotonicity(true, structural_time_columns)
7478 .map_err(|e| {
7479 format!("Failed to enable structural monotonicity in survival HMC: {e}")
7480 })?;
7481 }
7482
7483 let sampler_mode = mode.to_owned();
7484 let dim = sampler_mode.len();
7485
7486 let whitening = hessian_whitening_transform(
7487 hessian,
7488 dim,
7489 1.0,
7490 "Hessian Cholesky decomposition failed",
7491 )?;
7492 let chol = whitening.chol;
7493 let chol_t = whitening.chol_t;
7494
7495 let data = SharedSurvivalData {
7496 base_model: Arc::new(base_model),
7497 mode: Arc::new(sampler_mode),
7498 };
7499
7500 Ok(Self { data, chol, chol_t })
7501 }
7502
7503 fn compute_logp_and_grad_into(
7504 &self,
7505 z: &Array1<f64>,
7506 grad: &mut Array1<f64>,
7507 ) -> Result<f64, String> {
7508 let sampler_position = self.data.mode.as_ref() + &self.chol.dot(z);
7509 let state = self
7510 .data
7511 .base_model
7512 .update_state(&sampler_position)
7513 .map_err(|e| format!("Survival state update failed: {:?}", e))?;
7514 let logp = state.log_likelihood - state.penalty_term;
7515 let grad_beta = state.gradient.mapv(|g| -g);
7516 fast_av_into(&self.chol_t, &grad_beta, grad);
7517 Ok(logp)
7518 }
7519
7520 pub fn chol(&self) -> &Array2<f64> {
7522 &self.chol
7523 }
7524
7525 pub fn mode(&self) -> &Array1<f64> {
7527 &self.data.mode
7528 }
7529 }
7530
7531 impl HamiltonianTarget<Array1<f64>> for SurvivalPosterior {
7532 fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
7533 match self.compute_logp_and_grad_into(position, grad) {
7534 Ok(logp) => logp,
7535 Err(e) => {
7536 log::warn!("Survival posterior evaluation failed: {}", e);
7537 grad.fill(0.0);
7538 f64::NEG_INFINITY
7539 }
7540 }
7541 }
7542 }
7543
7544 pub(crate) fn run_survival_nuts_sampling(
7546 age_entry: ArrayView1<'_, f64>,
7547 age_exit: ArrayView1<'_, f64>,
7548 event_target: ArrayView1<'_, u8>,
7549 event_competing: ArrayView1<'_, u8>,
7550 sampleweight: ArrayView1<'_, f64>,
7551 x_entry: ArrayView2<'_, f64>,
7552 x_exit: ArrayView2<'_, f64>,
7553 x_derivative: ArrayView2<'_, f64>,
7554 eta_offset_entry: Option<ArrayView1<'_, f64>>,
7555 eta_offset_exit: Option<ArrayView1<'_, f64>>,
7556 derivative_offset_exit: Option<ArrayView1<'_, f64>>,
7557 penalties: PenaltyBlocks,
7558 monotonicity: SurvivalMonotonicityPenalty,
7559 spec: SurvivalSpec,
7560 structurally_monotonic: bool,
7561 structural_time_columns: usize,
7562 mode: ArrayView1<f64>,
7563 hessian: ArrayView2<f64>,
7564 config: &NutsConfig,
7565 ) -> Result<NutsResult, String> {
7566 validate_nuts_config(config).map_err(String::from)?;
7567 let target = SurvivalPosterior::new(
7569 age_entry,
7570 age_exit,
7571 event_target,
7572 event_competing,
7573 sampleweight,
7574 x_entry,
7575 x_exit,
7576 x_derivative,
7577 eta_offset_entry,
7578 eta_offset_exit,
7579 derivative_offset_exit,
7580 penalties,
7581 monotonicity,
7582 spec,
7583 structurally_monotonic,
7584 structural_time_columns,
7585 mode,
7586 hessian,
7587 )?;
7588
7589 let chol = target.chol().clone();
7591 let mode_arr = target.mode().clone();
7592 let dim = mode_arr.len();
7593
7594 let initial_positions = jittered_initial_positions(config, dim, 0.1, 0xEC2D_7A9B_4051_F638);
7595
7596 let mass_cfg = robust_survival_mass_matrix_config(dim, config.nwarmup);
7597 let (result, run_stats) = run_whitened_nuts_result(
7598 target,
7599 &mode_arr,
7600 &chol,
7601 initial_positions,
7602 config,
7603 dim,
7604 mass_cfg,
7605 0x731B_60D4_AE52_9C8F,
7606 "NUTS sampling failed",
7607 Array1::zeros(dim),
7608 NutsConvergenceThresholds {
7609 max_rhat: 1.1,
7610 min_ess: None,
7611 },
7612 )?;
7613
7614 log::info!("Survival NUTS sampling complete: {}", run_stats);
7615
7616 Ok(result)
7617 }
7618}
7619
7620pub fn run_survival_nuts_sampling_flattened<'a>(
7622 flat: SurvivalFlatInputs<'a>,
7623 penalties: gam_models::survival::PenaltyBlocks,
7624 monotonicity: gam_models::survival::SurvivalMonotonicityPenalty,
7625 spec: gam_models::survival::SurvivalSpec,
7626 structurally_monotonic: bool,
7627 structural_time_columns: usize,
7628 mode: ArrayView1<'a, f64>,
7629 hessian: ArrayView2<'a, f64>,
7630 config: &NutsConfig,
7631) -> Result<NutsResult, String> {
7632 run_nuts_sampling_flattened_family(
7633 LikelihoodSpec {
7634 response: ResponseFamily::RoystonParmar,
7635 link: InverseLink::Standard(StandardLink::Identity),
7636 },
7637 FamilyNutsInputs::Survival(Box::new(SurvivalNutsInputs {
7638 flat,
7639 penalties,
7640 monotonicity,
7641 spec,
7642 structurally_monotonic,
7643 structural_time_columns,
7644 mode,
7645 hessian,
7646 })),
7647 config,
7648 )
7649}