1use super::family::*;
2use super::gradient_paths::*;
3use super::hessian_paths::{new_cell_moment_cache_stats, new_cell_moment_lru_cache};
4use super::install_flex::validate_spec;
5use super::*;
6use gam_linalg::faer_ndarray::{FaerEigh, fast_ab, fast_atb, fast_xt_diag_x};
7use crate::marginal_slope_orthogonal::influence_absorber_log_lambda;
8use faer::Side;
9
/// Threshold on the sup-norm `|η|∞` of the fitted marginal predictor at or
/// above which `bernoulli_marginal_slope_runaway_error_from_beta` reports a
/// probit runaway / separation diagnostic.
pub(crate) const BMS_PROBIT_SEPARATION_ETA_INF: f64 = 35.0;

// Gauge priorities, listed from the largest value (200) to the smallest (60).
// NOTE(review): the code that consumes these priorities is not visible in this
// part of the file; presumably a larger value takes precedence when blocks
// compete for a shared gauge direction — confirm against the call sites.
pub(super) const GAUGE_PRIORITY_ANCHOR: u8 = 200;
pub(super) const GAUGE_PRIORITY_MARGINAL: u8 = 150;
pub(super) const GAUGE_PRIORITY_LOGSLOPE: u8 = 120;
pub(super) const GAUGE_PRIORITY_CANDIDATE_FLEX: u8 = 100;
pub(super) const GAUGE_PRIORITY_SCORE_WARP_DEV: u8 = 80;
pub(super) const GAUGE_PRIORITY_DEVIATION_DEFAULT: u8 = 70;
pub(super) const GAUGE_PRIORITY_LINK_DEV: u8 = 60;

/// Lower bound (`1e-6`) for an outer tolerance.
/// NOTE(review): not referenced in the visible part of the file; by its name it
/// floors the outer-loop tolerance of the exact spatial path — confirm.
pub(crate) const EXACT_SPATIAL_OUTER_TOL_FLOOR: f64 = 1e-6;
65
66pub struct BmsMarginalJacobian {
105 pub marginal_dense: Arc<Array2<f64>>,
107 pub logslope_dense: Arc<Array2<f64>>,
109 pub offset_m: Array1<f64>,
110 pub offset_s: Array1<f64>,
111 pub p_marginal: usize,
113}
114
115impl BmsMarginalJacobian {
116 pub fn new(
117 marginal_dense: Arc<Array2<f64>>,
118 logslope_dense: Arc<Array2<f64>>,
119 offset_m: Array1<f64>,
120 offset_s: Array1<f64>,
121 p_marginal: usize,
122 ) -> Self {
123 Self {
124 marginal_dense,
125 logslope_dense,
126 offset_m,
127 offset_s,
128 p_marginal,
129 }
130 }
131}
132
133impl BlockEffectiveJacobian for BmsMarginalJacobian {
134 fn effective_jacobian_rows(
135 &self,
136 state: &FamilyLinearizationState<'_>,
137 rows: std::ops::Range<usize>,
138 ) -> Result<Array2<f64>, String> {
139 let beta = state.beta;
140 let s = state.probit_frailty_scale;
141 let p_m = self.p_marginal;
142 let p_s_block = self.logslope_dense.ncols();
143 let beta_s_raw = if beta.len() > p_m {
144 &beta[p_m..]
145 } else {
146 &[][..]
147 };
148 let p_s_use = p_s_block.min(beta_s_raw.len());
149 let beta_s = &beta_s_raw[..p_s_use];
150 let n = self.marginal_dense.nrows();
151 let rows = rows.start.min(n)..rows.end.min(n);
152 let p_block = self.marginal_dense.ncols();
153
154 let mut out = Array2::<f64>::zeros((rows.end - rows.start, p_block));
164 for i in rows.clone() {
165 let g_i = self.offset_s[i]
166 + self
167 .logslope_dense
168 .row(i)
169 .slice(ndarray::s![..p_s_use])
170 .dot(&ArrayView1::from(beta_s));
171 let sg = s * g_i;
172 let c_i = (1.0 + sg * sg).sqrt();
173 let m_row = self.marginal_dense.row(i);
175 out.row_mut(i - rows.start).assign(&m_row.mapv(|x| c_i * x));
176 }
177 Ok(out)
178 }
179
180 fn n_outputs(&self) -> usize {
181 1
182 }
183
184 fn locks_raw_width_reduction(&self) -> bool {
185 true
195 }
196}
197
198pub struct BmsLogslopeJacobian {
210 pub marginal_dense: Arc<Array2<f64>>,
212 pub logslope_dense: Arc<Array2<f64>>,
214 pub offset_m: Array1<f64>,
215 pub offset_s: Array1<f64>,
216 pub z: Arc<Array1<f64>>,
217 pub p_marginal: usize,
219}
220
221impl BmsLogslopeJacobian {
222 pub fn new(
223 marginal_dense: Arc<Array2<f64>>,
224 logslope_dense: Arc<Array2<f64>>,
225 offset_m: Array1<f64>,
226 offset_s: Array1<f64>,
227 z: Arc<Array1<f64>>,
228 p_marginal: usize,
229 ) -> Self {
230 Self {
231 marginal_dense,
232 logslope_dense,
233 offset_m,
234 offset_s,
235 z,
236 p_marginal,
237 }
238 }
239}
240
241impl BlockEffectiveJacobian for BmsLogslopeJacobian {
242 fn effective_jacobian_rows(
243 &self,
244 state: &FamilyLinearizationState<'_>,
245 rows: std::ops::Range<usize>,
246 ) -> Result<Array2<f64>, String> {
247 let beta = state.beta;
248 let s = state.probit_frailty_scale;
249 let p_m = self.p_marginal;
250 let p_m_use = p_m.min(beta.len());
251 let beta_m = &beta[..p_m_use];
252 let beta_s_raw = if beta.len() > p_m {
253 &beta[p_m..]
254 } else {
255 &[][..]
256 };
257 let p_s_block = self.logslope_dense.ncols();
258 let p_s_use = p_s_block.min(beta_s_raw.len());
259 let beta_s = &beta_s_raw[..p_s_use];
260 let n = self.logslope_dense.nrows();
261 let rows = rows.start.min(n)..rows.end.min(n);
262
263 let mut out = Array2::<f64>::zeros((rows.end - rows.start, p_s_block));
276 for i in rows.clone() {
277 let q_i = self.offset_m[i]
278 + self
279 .marginal_dense
280 .row(i)
281 .slice(ndarray::s![..p_m_use])
282 .dot(&ArrayView1::from(beta_m));
283 let g_i = self.offset_s[i]
284 + self
285 .logslope_dense
286 .row(i)
287 .slice(ndarray::s![..p_s_use])
288 .dot(&ArrayView1::from(beta_s));
289 let sg = s * g_i;
290 let c_i = (1.0 + sg * sg).sqrt();
291 let z_i = self.z[i];
292 let factor = q_i * s * s * g_i / c_i + s * z_i;
294 let g_row = self.logslope_dense.row(i);
296 out.row_mut(i - rows.start)
297 .assign(&g_row.mapv(|x| factor * x));
298 }
299 Ok(out)
300 }
301
302 fn n_outputs(&self) -> usize {
303 1
304 }
305
306 fn locks_raw_width_reduction(&self) -> bool {
307 true
318 }
319}
320
321pub(crate) fn widen_marginal_dense_with_influence(
331 marginal_dense: &Arc<Array2<f64>>,
332 influence_columns: Option<&Array2<f64>>,
333) -> Result<Arc<Array2<f64>>, String> {
334 let Some(z_infl) = influence_columns else {
335 return Ok(Arc::clone(marginal_dense));
336 };
337 let n = marginal_dense.nrows();
338 if z_infl.nrows() != n {
339 return Err(format!(
340 "influence block: residualised columns have {} rows, marginal design has {n}",
341 z_infl.nrows()
342 ));
343 }
344 let p_m = marginal_dense.ncols();
345 let p1 = z_infl.ncols();
346 let mut widened = Array2::<f64>::zeros((n, p_m + p1));
347 widened
348 .slice_mut(s![.., ..p_m])
349 .assign(marginal_dense.as_ref());
350 widened.slice_mut(s![.., p_m..]).assign(z_infl);
351 Ok(Arc::new(widened))
352}
353
/// Relative tolerance used by `reduced_logslope_transform_effective`: it scales
/// both the ridge added to the effective marginal Gram diagonal and the
/// eigenvalue cutoff below which a logslope direction is dropped.
pub(crate) const LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL: f64 = 1.0e-6;
362
363#[derive(Debug, Clone)]
415pub(super) struct ReducedLogslopeReparam {
416 transform: Array2<f64>,
419}
420
421impl ReducedLogslopeReparam {
422 #[inline]
424 pub(super) fn original_cols(&self) -> usize {
425 self.transform.nrows()
426 }
427
428 #[inline]
430 pub(super) fn reduced_cols(&self) -> usize {
431 self.transform.ncols()
432 }
433
434 pub(super) fn recover_original_logslope_beta(
438 &self,
439 beta_reduced: &Array1<f64>,
440 ) -> Result<Array1<f64>, String> {
441 if beta_reduced.len() != self.reduced_cols() {
442 return Err(format!(
443 "reduced logslope reparam: β' length ({}) != reduced width ({})",
444 beta_reduced.len(),
445 self.reduced_cols()
446 ));
447 }
448 Ok(self.transform.dot(beta_reduced))
449 }
450}
451
452fn build_reduced_logslope_reparam(
460 marginal_design: &TermCollectionDesign,
461 logslope_design: &TermCollectionDesign,
462 z: &Array1<f64>,
463 row_metric: &Array1<f64>,
464 marginal_offset: &Array1<f64>,
465 logslope_offset: &Array1<f64>,
466 marginal_baseline: f64,
467 logslope_baseline: f64,
468 probit_scale: f64,
469) -> Result<Option<ReducedLogslopeReparam>, String> {
470 let marginal = marginal_design
471 .design
472 .try_to_dense_arc("build_reduced_logslope_reparam::marginal")?;
473 let logslope = logslope_design
474 .design
475 .try_to_dense_arc("build_reduced_logslope_reparam::logslope")?;
476 let n = marginal.nrows();
477 if logslope.nrows() != n
478 || z.len() != n
479 || row_metric.len() != n
480 || marginal_offset.len() != n
481 || logslope_offset.len() != n
482 {
483 return Err(format!(
484 "reduced logslope reparam row mismatch: marginal={}, logslope={}, z={}, row_metric={}, marginal_offset={}, logslope_offset={}",
485 marginal.nrows(),
486 logslope.nrows(),
487 z.len(),
488 row_metric.len(),
489 marginal_offset.len(),
490 logslope_offset.len(),
491 ));
492 }
493 let p_m = marginal.ncols();
494 let p_g = logslope.ncols();
495 if p_m == 0 || p_g == 0 {
496 return Ok(None);
497 }
498 if !marginal_baseline.is_finite()
499 || !logslope_baseline.is_finite()
500 || !probit_scale.is_finite()
501 || probit_scale <= 0.0
502 || z.iter().any(|v| !v.is_finite())
503 || row_metric.iter().any(|v| !v.is_finite() || *v < 0.0)
504 || marginal_offset.iter().any(|v| !v.is_finite())
505 || logslope_offset.iter().any(|v| !v.is_finite())
506 {
507 return Err(
508 "reduced logslope reparam requires finite pilot geometry and finite non-negative row metric"
509 .to_string(),
510 );
511 }
512
513 match reduced_logslope_transform_effective(
524 marginal.view(),
525 logslope.view(),
526 z,
527 row_metric,
528 marginal_offset,
529 logslope_offset,
530 marginal_baseline,
531 logslope_baseline,
532 probit_scale,
533 )? {
534 Some(transform) => Ok(Some(ReducedLogslopeReparam { transform })),
535 None => Ok(None),
536 }
537}
538
539pub(crate) fn reduced_logslope_transform_effective(
558 marginal: ArrayView2<'_, f64>,
559 logslope: ArrayView2<'_, f64>,
560 z: &Array1<f64>,
561 row_metric: &Array1<f64>,
562 marginal_offset: &Array1<f64>,
563 logslope_offset: &Array1<f64>,
564 marginal_baseline: f64,
565 logslope_baseline: f64,
566 probit_scale: f64,
567) -> Result<Option<Array2<f64>>, String> {
568 let n = marginal.nrows();
569 let p_m = marginal.ncols();
570 let p_g = logslope.ncols();
571 if p_m == 0 || p_g == 0 {
572 return Ok(None);
573 }
574
575 let mut m_eff = Array2::<f64>::zeros((n, p_m));
577 let mut g_eff = Array2::<f64>::zeros((n, p_g));
578 for i in 0..n {
579 let q_i = marginal_offset[i] + marginal_baseline;
580 let g_i = logslope_offset[i] + logslope_baseline;
581 let sg = probit_scale * g_i;
582 let c_i = (1.0 + sg * sg).sqrt();
583 let f_i = q_i * probit_scale * probit_scale * g_i / c_i + probit_scale * z[i];
584 for j in 0..p_m {
585 m_eff[[i, j]] = c_i * marginal[[i, j]];
586 }
587 for j in 0..p_g {
588 g_eff[[i, j]] = f_i * logslope[[i, j]];
589 }
590 }
591
592 let c_gram = fast_xt_diag_x(&g_eff, row_metric);
595 let energy_scale = (0..p_g).map(|i| c_gram[[i, i]]).fold(0.0_f64, f64::max);
596 if !energy_scale.is_finite() || energy_scale <= 0.0 {
597 return Ok(None);
598 }
599
600 let mut a_gram = fast_xt_diag_x(&m_eff, row_metric);
604 let a_scale = (0..p_m).map(|i| a_gram[[i, i]]).fold(0.0_f64, f64::max);
605 let a_ridge = (a_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL).max(f64::EPSILON);
606 for i in 0..p_m {
607 a_gram[[i, i]] += a_ridge;
608 }
609
610 let b_cross = gam_linalg::faer_ndarray::fast_xt_diag_y(&m_eff, row_metric, &g_eff);
612 let a_view = gam_linalg::faer_ndarray::FaerArrayView::new(&a_gram);
613 let a_factor =
614 gam_linalg::faer_ndarray::factorize_symmetricwith_fallback(a_view.as_ref(), Side::Lower)
615 .map_err(|e| {
616 format!(
617 "reduced logslope reparam: effective marginal Gram factorization failed: {e}"
618 )
619 })?;
620 let b_view = gam_linalg::faer_ndarray::FaerArrayView::new(&b_cross);
621 let solved = a_factor.solve(b_view.as_ref()); let a_inv_b = Array2::from_shape_fn((p_m, p_g), |(i, j)| solved[(i, j)]);
623 let schur = fast_atb(&b_cross, &a_inv_b); let mut stt = &c_gram - &schur;
625 stt = (&stt + &stt.t()) * 0.5;
626 if stt.iter().any(|v| !v.is_finite()) {
627 return Err(
628 "reduced logslope reparam: effective Schur Gram produced non-finite entries"
629 .to_string(),
630 );
631 }
632
633 let (evals, evecs) = stt
634 .eigh(Side::Lower)
635 .map_err(|e| format!("reduced logslope reparam: eigendecomposition failed: {e:?}"))?;
636 let tol = energy_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL;
640 let mut kept: Vec<usize> = (0..evals.len()).filter(|&i| evals[i] > tol).collect();
641 kept.sort_by(|&a, &b| {
642 evals[b]
643 .partial_cmp(&evals[a])
644 .unwrap_or(std::cmp::Ordering::Equal)
645 });
646 let r = kept.len();
647 if r == p_g || r == 0 {
651 return Ok(None);
652 }
653 let mut transform = Array2::<f64>::zeros((p_g, r));
654 for (out_col, &src) in kept.iter().enumerate() {
655 transform.column_mut(out_col).assign(&evecs.column(src));
656 }
657 if transform.iter().any(|v| !v.is_finite()) {
658 return Err(
659 "reduced logslope reparam: reduced transform produced non-finite entries".to_string(),
660 );
661 }
662 Ok(Some(transform))
663}
664
665fn reparameterize_logslope_design_reduced(
672 logslope_design: &TermCollectionDesign,
673 reparam: &ReducedLogslopeReparam,
674) -> Result<TermCollectionDesign, String> {
675 let g = logslope_design
676 .design
677 .try_to_dense_arc("reparameterize_logslope_design_reduced::logslope")?;
678 let p_g = g.ncols();
679 if p_g != reparam.original_cols() {
680 return Err(format!(
681 "reduced logslope reparam width mismatch: design has {p_g} cols, transform expects {}",
682 reparam.original_cols()
683 ));
684 }
685 let t = &reparam.transform;
686 let r = reparam.reduced_cols();
687 let g_reduced = fast_ab(&g, t);
689
690 let mut new_penalties: Vec<gam_terms::smooth::BlockwisePenalty> =
693 Vec::with_capacity(logslope_design.penalties.len());
694 let mut new_nullspace_dims: Vec<usize> = Vec::with_capacity(logslope_design.penalties.len());
695 for bp in &logslope_design.penalties {
696 let mut full = Array2::<f64>::zeros((p_g, p_g));
697 full.slice_mut(s![bp.col_range.clone(), bp.col_range.clone()])
698 .assign(&bp.local);
699 let st = fast_ab(&full, t); let mut s_reduced = fast_atb(t, &st); s_reduced = (&s_reduced + &s_reduced.t()) * 0.5;
703 let (evals, _) = s_reduced
705 .eigh(Side::Lower)
706 .map_err(|e| format!("reduced logslope penalty eigendecomposition failed: {e:?}"))?;
707 let max_eval = evals.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
708 let pen_tol = (max_eval * 1.0e-12).max(f64::EPSILON);
709 let rank = evals.iter().filter(|&&v| v.abs() > pen_tol).count();
710 let nullspace_dim = r.saturating_sub(rank);
711 new_penalties.push(gam_terms::smooth::BlockwisePenalty::new(0..r, s_reduced));
712 new_nullspace_dims.push(nullspace_dim);
713 }
714
715 let new_design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(g_reduced));
716 Ok(TermCollectionDesign {
722 design: new_design,
723 penalties: new_penalties,
724 nullspace_dims: new_nullspace_dims,
725 penaltyinfo: Vec::new(),
726 dropped_penaltyinfo: Vec::new(),
727 coefficient_lower_bounds: None,
728 linear_constraints: None,
729 intercept_range: 0..0,
730 linear_ranges: Vec::new(),
731 random_effect_ranges: Vec::new(),
732 random_effect_levels: Vec::new(),
733 smooth: gam_terms::smooth::SmoothDesign {
734 term_designs: Vec::new(),
735 penalties: Vec::new(),
736 nullspace_dims: Vec::new(),
737 penaltyinfo: Vec::new(),
738 dropped_penaltyinfo: Vec::new(),
739 terms: Vec::new(),
740 coefficient_lower_bounds: None,
741 linear_constraints: None,
742 },
743 })
744}
745
746pub(crate) fn marginal_penalties_with_influence_ridge(
767 design: &TermCollectionDesign,
768 rho_marginal: &Array1<f64>,
769 influence_columns: Option<&Array2<f64>>,
770 influence_ridge_log_lambda: f64,
771) -> Result<(Vec<PenaltyMatrix>, Vec<usize>, Array1<f64>), String> {
772 let p_m = design.design.ncols();
773 let p1 = influence_columns.map(|z| z.ncols()).unwrap_or(0);
774 let total_dim = p_m + p1;
775 let mut penalties: Vec<PenaltyMatrix> = design
778 .penalties
779 .iter()
780 .map(|bp| bp.to_penalty_matrix(total_dim))
781 .collect();
782 let mut nullspace_dims = design.nullspace_dims.clone();
783 let mut log_lambdas = rho_marginal.to_vec();
784
785 if p1 > 0 {
789 penalties.push(
790 PenaltyMatrix::Blockwise {
791 local: Array2::<f64>::eye(p1),
792 col_range: p_m..total_dim,
793 total_dim,
794 }
795 .with_fixed_log_lambda(influence_ridge_log_lambda),
796 );
797 nullspace_dims.push(0);
798 log_lambdas.push(influence_ridge_log_lambda);
799 }
800
801 Ok((penalties, nullspace_dims, Array1::from_vec(log_lambdas)))
802}
803
804pub(crate) fn widen_marginal_beta_hint(
807 beta_hint: Option<Array1<f64>>,
808 p_marginal_widened: usize,
809) -> Option<Array1<f64>> {
810 beta_hint.map(|hint| {
811 if hint.len() == p_marginal_widened {
812 hint
813 } else {
814 let mut widened = Array1::<f64>::zeros(p_marginal_widened);
815 let copy = hint.len().min(p_marginal_widened);
816 widened
817 .slice_mut(s![..copy])
818 .assign(&hint.slice(s![..copy]));
819 widened
820 }
821 })
822}
823
824fn marginal_fitted_eta_sup_norm(design: &TermCollectionDesign, masked_beta: &Array1<f64>) -> f64 {
834 let x = &design.design;
835 let n = x.nrows();
836 if n == 0 || x.ncols() == 0 {
837 return 0.0;
838 }
839 let mut sup = 0.0_f64;
840 for row in 0..n {
841 let eta = x.dot_row_view(row, masked_beta.view());
842 if eta.is_finite() {
843 sup = sup.max(eta.abs());
844 }
845 }
846 sup
847}
848
849fn marginal_design_beta(
852 design: &TermCollectionDesign,
853 block_beta: ArrayView1<'_, f64>,
854) -> Array1<f64> {
855 let ncols = design.design.ncols();
856 let mut masked = Array1::<f64>::zeros(ncols);
857 let copy = ncols.min(block_beta.len());
858 masked
859 .slice_mut(s![..copy])
860 .assign(&block_beta.slice(s![..copy]));
861 masked
862}
863
864fn mask_parametric_columns(
870 design: &TermCollectionDesign,
871 spec: &TermCollectionSpec,
872 full: &Array1<f64>,
873) -> Array1<f64> {
874 let ncols = design.design.ncols();
875 let mut masked = Array1::<f64>::zeros(ncols);
876 if design.intercept_range.len() == 1 {
877 let idx = design.intercept_range.start;
878 if idx < ncols {
879 masked[idx] = full[idx];
880 }
881 }
882 for (linear, (_, range)) in spec.linear_terms.iter().zip(design.linear_ranges.iter()) {
883 if linear.double_penalty {
884 continue;
885 }
886 for col in range.clone() {
887 if col < ncols {
888 masked[col] = full[col];
889 }
890 }
891 }
892 masked
893}
894
895pub(crate) fn bernoulli_marginal_slope_runaway_error_from_beta(
906 block_beta: ArrayView1<'_, f64>,
907 design: &TermCollectionDesign,
908 spec: &TermCollectionSpec,
909 inner_converged: bool,
910 eval_label: &str,
911) -> Option<String> {
912 let full_beta = marginal_design_beta(design, block_beta);
913 let parametric_beta = mask_parametric_columns(design, spec, &full_beta);
914
915 let eta_parametric = marginal_fitted_eta_sup_norm(design, ¶metric_beta);
916 let eta_full = marginal_fitted_eta_sup_norm(design, &full_beta);
917
918 let (eta_inf, explanation) = if eta_parametric >= BMS_PROBIT_SEPARATION_ETA_INF {
919 (
920 eta_parametric,
921 "an unpenalized parametric marginal direction has no stable finite probit optimum and its fitted predictor has run to the probit underflow scale",
922 )
923 } else if eta_full >= BMS_PROBIT_SEPARATION_ETA_INF {
924 (
925 eta_full,
926 "a marginal direction is trading off against the logslope surface; this is the under-constrained marginal/logslope coupling that appears when the score is correlated with the shared surface covariates",
927 )
928 } else {
929 return None;
933 };
934
935 let inner_status = if inner_converged {
936 "the inner solve reached a KKT certificate at this separation-scale predictor"
937 } else {
938 "the inner solve failed while already carrying a separation-scale predictor"
939 };
940 let beta_abs = full_beta
942 .iter()
943 .copied()
944 .filter(|v| v.is_finite())
945 .fold(0.0_f64, |acc, v| acc.max(v.abs()));
946
947 Some(format!(
948 "bernoulli marginal-slope probit marginal/logslope runaway detected in block \
949 'marginal_surface' during {eval_label}: the fitted marginal predictor has \
950 |η|∞={eta_inf:.3e} (numerical-degeneracy threshold \
951 {BMS_PROBIT_SEPARATION_ETA_INF:.1}; raw |β|∞={beta_abs:.3e} is reported for \
952 context only and does not gate this diagnostic). The joint design is \
953 identifiable; {explanation}. {inner_status}. The robust Jeffreys curvature \
954 path is already installed for this fit, so this diagnostic means the current \
955 coupled surface still drives the linear predictor to the probit underflow \
956 scale rather than a request for an external bias-reduction prior. Reduce or \
957 reparameterize the coupled marginal/logslope surface, or use a \
958 lower-dimensional logslope interaction. This is not a \
959 Matérn/Duchon polynomial-nullspace or cross-block gauge-priority \
960 failure."
961 ))
962}
963
964pub(crate) fn bernoulli_marginal_slope_runaway_error(
965 warm_start: &CustomFamilyWarmStart,
966 design: &TermCollectionDesign,
967 spec: &TermCollectionSpec,
968 inner_converged: bool,
969 eval_label: &str,
970) -> Option<String> {
971 let block_beta = warm_start.block_beta_view(0)?;
972 bernoulli_marginal_slope_runaway_error_from_beta(
973 block_beta,
974 design,
975 spec,
976 inner_converged,
977 eval_label,
978 )
979}
980
981#[cfg(test)]
982mod runaway_tests {
983 use super::*;
984 use gam_linalg::faer_ndarray::{FaerArrayView, factorize_symmetricwith_fallback, fast_xt_diag_y};
985 use gam_terms::smooth::{LinearCoefficientGeometry, LinearTermSpec};
986
    /// Test-only helper: builds a marginal-block penalty measuring how much of
    /// the effective marginal design lies in the span of the effective logslope
    /// design, both evaluated at the pilot geometry.
    ///
    /// With `M_eff = diag(c) · marginal` and `G_eff = diag(f) · logslope` (same
    /// `c_i`, `f_i` formulas as the loop below), the result is the symmetrized
    /// `M_effᵀ W G_eff (G_effᵀ W G_eff + ridge·I)⁻¹ G_effᵀ W M_eff`, where `W` is
    /// `diag(row_metric)`.
    ///
    /// Returns `Ok(None)` when either block is empty, when the effective
    /// logslope design is numerically zero, when its Gram has no positive finite
    /// diagonal, or when the penalty's largest magnitude is at most `1e-12`.
    /// Returns `Err` on row-count mismatch, non-finite pilot inputs, or a failed
    /// factorization.
    pub(crate) fn marginal_logslope_overlap_penalty(
        marginal_design: &DesignMatrix,
        logslope_design: &DesignMatrix,
        z: &Array1<f64>,
        row_metric: &Array1<f64>,
        marginal_offset: &Array1<f64>,
        logslope_offset: &Array1<f64>,
        marginal_baseline: f64,
        logslope_baseline: f64,
        probit_scale: f64,
    ) -> Result<Option<Array2<f64>>, String> {
        let marginal =
            marginal_design.try_to_dense_arc("marginal_logslope_overlap_penalty::marginal")?;
        let logslope =
            logslope_design.try_to_dense_arc("marginal_logslope_overlap_penalty::logslope")?;
        let n = marginal.nrows();
        // All per-row inputs must match the marginal design's row count.
        if logslope.nrows() != n
            || z.len() != n
            || row_metric.len() != n
            || marginal_offset.len() != n
            || logslope_offset.len() != n
        {
            return Err(format!(
                "marginal/logslope overlap penalty row mismatch: marginal={}, logslope={}, z={}, row_metric={}, marginal_offset={}, logslope_offset={}",
                marginal.nrows(),
                logslope.nrows(),
                z.len(),
                row_metric.len(),
                marginal_offset.len(),
                logslope_offset.len(),
            ));
        }
        let p_m = marginal.ncols();
        let p_g = logslope.ncols();
        if p_m == 0 || p_g == 0 {
            return Ok(None);
        }
        if !marginal_baseline.is_finite()
            || !logslope_baseline.is_finite()
            || !probit_scale.is_finite()
            || probit_scale <= 0.0
            || z.iter().any(|v| !v.is_finite())
            || row_metric.iter().any(|v| !v.is_finite() || *v < 0.0)
            || marginal_offset.iter().any(|v| !v.is_finite())
            || logslope_offset.iter().any(|v| !v.is_finite())
        {
            return Err(
                "marginal/logslope overlap penalty requires finite pilot geometry and finite non-negative row metric"
                    .to_string(),
            );
        }

        // Row-scaled effective designs at the pilot point.
        let mut marginal_effective = Array2::<f64>::zeros((n, p_m));
        let mut effective_logslope = Array2::<f64>::zeros((n, p_g));
        for i in 0..n {
            let q_i = marginal_offset[i] + marginal_baseline;
            let g_i = logslope_offset[i] + logslope_baseline;
            let sg = probit_scale * g_i;
            let c_i = (1.0 + sg * sg).sqrt();
            let logslope_factor =
                q_i * probit_scale * probit_scale * g_i / c_i + probit_scale * z[i];
            for j in 0..p_m {
                marginal_effective[[i, j]] = c_i * marginal[[i, j]];
            }
            for j in 0..p_g {
                effective_logslope[[i, j]] = logslope_factor * logslope[[i, j]];
            }
        }
        if effective_logslope.iter().all(|v| v.abs() <= f64::EPSILON) {
            return Ok(None);
        }

        // Weighted logslope Gram with a small relative ridge so it can be solved.
        let mut gram = fast_xt_diag_x(&effective_logslope, row_metric);
        let gram_scale = gram.diag().iter().copied().fold(0.0_f64, f64::max);
        if !gram_scale.is_finite() || gram_scale <= 0.0 {
            return Ok(None);
        }
        let projection_ridge = (gram_scale * 1.0e-10).max(f64::EPSILON);
        for i in 0..p_g {
            gram[[i, i]] += projection_ridge;
        }
        // Weighted projection of the effective marginal design onto the
        // effective logslope span.
        let cross = fast_xt_diag_y(&effective_logslope, row_metric, &marginal_effective);
        let gram_view = FaerArrayView::new(&gram);
        let factor = factorize_symmetricwith_fallback(gram_view.as_ref(), Side::Lower)
            .map_err(|e| format!("marginal/logslope overlap Gram factorization failed: {e}"))?;
        let rhsview = FaerArrayView::new(&cross);
        let coeffs_mat = factor.solve(rhsview.as_ref());
        let coeffs = Array2::from_shape_fn((p_g, p_m), |(i, j)| coeffs_mat[(i, j)]);
        let projected_marginal = fast_ab(&effective_logslope, &coeffs);
        let mut penalty = fast_xt_diag_y(&marginal_effective, row_metric, &projected_marginal);
        // Symmetrize to remove rounding asymmetry.
        penalty = (&penalty + &penalty.t()) * 0.5;
        let max_abs = penalty.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
        if !max_abs.is_finite() || max_abs <= 1.0e-12 {
            return Ok(None);
        }
        Ok(Some(penalty))
    }
1089
    /// With zero offsets/baselines and unit scale, `c_i = 1` and `f_i = z_i`, so
    /// the effective logslope design is `z_i * g[i, ..]`. Here that makes the
    /// first logslope column constant (equal to the marginal column) while the
    /// second stays non-constant, so exactly one direction must survive.
    #[test]
    pub(crate) fn effective_reduction_drops_score_weighted_confound_raw_audit_misses() {
        let m = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
        // Raw columns [1, 2, 3] and [1, 2, 9]; multiplied by z they become
        // [1, 1, 1] and [1, 1, 3].
        let g = Array2::<f64>::from_shape_vec((3, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 9.0]).unwrap();
        let z = Array1::from_vec(vec![1.0, 0.5, 1.0 / 3.0]);
        let w = Array1::<f64>::ones(3);
        let zero = Array1::<f64>::zeros(3);

        let reparam = reduced_logslope_transform_effective(
            m.view(),
            g.view(),
            &z,
            &w,
            &zero,
            &zero,
            0.0,
            0.0,
            1.0,
        )
        .expect("effective reduction must succeed")
        .expect("effective audit must reduce the score-weighted confound (raw audit would not)");
        assert_eq!(
            reparam.ncols(),
            1,
            "exactly one effective-identifiable logslope direction should survive"
        );

        // Rebuild the effective logslope design to inspect the kept direction.
        let g_eff = {
            let mut e = Array2::<f64>::zeros((3, 2));
            for i in 0..3 {
                for j in 0..2 {
                    e[[i, j]] = z[i] * g[[i, j]];
                }
            }
            e
        };
        // The image of the kept direction must vary across rows (non-constant).
        let img = g_eff.dot(&reparam.column(0));
        let mean = img.iter().sum::<f64>() / 3.0;
        let var = img.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / 3.0;
        assert!(
            var > 1.0e-6,
            "kept direction must be the identifiable (non-constant) effective column, var={var}"
        );
    }
1147
    /// A single logslope column `[1, 2, 3]` weighted by `z = [1, 1/2, 1/3]`
    /// becomes constant, i.e. identical to the marginal column. The kept rank is
    /// then zero and the function must answer `None` instead of a 0-width
    /// transform.
    #[test]
    pub(crate) fn effective_reduction_fully_confounded_single_column_returns_none() {
        let m = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
        let g = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let z = Array1::from_vec(vec![1.0, 0.5, 1.0 / 3.0]);
        let w = Array1::<f64>::ones(3);
        let zero = Array1::<f64>::zeros(3);
        let reparam = reduced_logslope_transform_effective(
            m.view(),
            g.view(),
            &z,
            &w,
            &zero,
            &zero,
            0.0,
            0.0,
            1.0,
        )
        .expect("effective reduction must succeed");
        assert!(
            reparam.is_none(),
            "fully effective-confounded logslope must keep raw design (None), not a 0-width block"
        );
    }
1177
    /// When every logslope direction survives the cutoff (kept rank equals the
    /// full width) there is nothing to reduce, so the result must be `None`.
    #[test]
    pub(crate) fn effective_reduction_no_confound_returns_none() {
        let m = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
        // Logslope columns are the indicator vectors of rows 0 and 1.
        let g = Array2::<f64>::from_shape_vec((3, 2), vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();
        let z = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let w = Array1::<f64>::ones(3);
        let zero = Array1::<f64>::zeros(3);
        let reparam = reduced_logslope_transform_effective(
            m.view(),
            g.view(),
            &z,
            &w,
            &zero,
            &zero,
            0.0,
            0.0,
            1.0,
        )
        .expect("effective reduction must succeed");
        assert!(
            reparam.is_none(),
            "no effective confound ⇒ no reduction (raw design kept unchanged)"
        );
    }
1205
    /// Checks that the spatial joint setup reports a rho dimension of 6 for
    /// this fixture.
    /// NOTE(review): `joint_setup` is defined outside the visible part of the
    /// file; the positional arguments `2`, `3` and `&[0.4]` are presumably the
    /// marginal penalty count, the logslope penalty count and one auxiliary
    /// entry (2 + 3 + 1 = 6) — confirm against its signature.
    #[test]
    pub(crate) fn spatial_joint_setup_counts_only_learned_penalties_in_rho() {
        let data = Array2::<f64>::zeros((3, 1));
        let empty_terms = TermCollectionSpec {
            linear_terms: Vec::new(),
            random_effect_terms: Vec::new(),
            smooth_terms: Vec::new(),
        };
        let setup = joint_setup(
            data.view(),
            &empty_terms,
            &empty_terms,
            2,
            3,
            &[0.4],
            &SpatialLengthScaleOptimizationOptions::default(),
        );

        assert_eq!(
            setup.rho_dim(),
            6,
            "BMS spatial setup rho must contain only learned marginal/logslope/auxiliary penalties; fixed physical ridges are carried by PenaltyMatrix::Fixed"
        );
    }
1230
    /// The marginal column `[0, 1, 2, 3]` equals the effective logslope column
    /// (a constant logslope column weighted by `z = [0, 1, 2, 3]`), so the
    /// projection reproduces it and the penalty is its squared norm,
    /// `0 + 1 + 4 + 9 = 14`, up to the small projection ridge.
    #[test]
    pub(crate) fn overlap_penalty_targets_score_weighted_logslope_span() {
        let marginal = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
            Array2::from_shape_vec((4, 1), vec![0.0, 1.0, 2.0, 3.0]).unwrap(),
        ));
        let logslope = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
            Array2::from_shape_vec((4, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap(),
        ));
        let z = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
        let row_metric = Array1::ones(4);
        let offsets = Array1::zeros(4);

        let penalty = marginal_logslope_overlap_penalty(
            &marginal,
            &logslope,
            &z,
            &row_metric,
            &offsets,
            &offsets,
            0.0,
            0.0,
            1.0,
        )
        .expect("overlap penalty should build")
        .expect("marginal signal lies in the pilot logslope Jacobian span");

        assert_eq!(penalty.dim(), (1, 1));
        assert!((penalty[[0, 0]] - 14.0).abs() < 1.0e-6);
    }
1260
    /// The marginal column `[-1, 1, -1, 1]` sums to zero against the constant
    /// effective logslope column under unit weights, so the overlap vanishes and
    /// no penalty is produced.
    #[test]
    pub(crate) fn overlap_penalty_skips_weight_orthogonal_channels() {
        let marginal = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
            Array2::from_shape_vec((4, 1), vec![-1.0, 1.0, -1.0, 1.0]).unwrap(),
        ));
        let logslope = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
            Array2::from_shape_vec((4, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap(),
        ));
        let z = Array1::ones(4);
        let row_metric = Array1::ones(4);
        let offsets = Array1::zeros(4);

        let penalty = marginal_logslope_overlap_penalty(
            &marginal,
            &logslope,
            &z,
            &row_metric,
            &offsets,
            &offsets,
            0.0,
            0.0,
            1.0,
        )
        .expect("overlap penalty should build");

        assert!(penalty.is_none());
    }
1288
    /// Test fixture: wraps a dense matrix in a `TermCollectionDesign` with the
    /// given intercept range and linear-term ranges; every penalty, constraint
    /// and smooth/random-effect field is left empty.
    fn dense_marginal_design(
        x: Array2<f64>,
        intercept_range: std::ops::Range<usize>,
        linear_ranges: Vec<(String, std::ops::Range<usize>)>,
    ) -> TermCollectionDesign {
        TermCollectionDesign {
            design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(x)),
            penalties: Vec::new(),
            nullspace_dims: Vec::new(),
            penaltyinfo: Vec::new(),
            dropped_penaltyinfo: Vec::new(),
            coefficient_lower_bounds: None,
            linear_constraints: None,
            intercept_range,
            linear_ranges,
            random_effect_ranges: Vec::new(),
            random_effect_levels: Vec::new(),
            smooth: gam_terms::smooth::SmoothDesign {
                term_designs: Vec::new(),
                penalties: Vec::new(),
                nullspace_dims: Vec::new(),
                penaltyinfo: Vec::new(),
                dropped_penaltyinfo: Vec::new(),
                terms: Vec::new(),
                coefficient_lower_bounds: None,
                linear_constraints: None,
            },
        }
    }
1325
    /// Test fixture: a linear term on a single feature column with
    /// `double_penalty` off, default coefficient geometry and no bounds.
    fn linear_term(name: &str, feature_col: usize) -> LinearTermSpec {
        LinearTermSpec {
            name: name.to_string(),
            feature_col,
            feature_cols: vec![feature_col],
            categorical_levels: vec![],
            double_penalty: false,
            coefficient_geometry: LinearCoefficientGeometry::default(),
            coefficient_min: None,
            coefficient_max: None,
        }
    }
1338
    /// Test fixture: a term collection spec with no linear, random-effect or
    /// smooth terms.
    fn empty_spec() -> TermCollectionSpec {
        TermCollectionSpec {
            linear_terms: Vec::new(),
            random_effect_terms: Vec::new(),
            smooth_terms: Vec::new(),
        }
    }
1346
    /// The guard is gated on the fitted predictor, not on raw coefficients: two
    /// identical all-ones columns with β = (60, -60) give η = 0 on every row, so
    /// no diagnostic may be raised even though |β|∞ = 60.
    #[test]
    pub(crate) fn runaway_guard_silent_when_huge_beta_cancels_to_bounded_eta() {
        let x = Array2::<f64>::from_shape_vec((4, 2), vec![1.0; 8]).unwrap();
        let design = dense_marginal_design(x, 0..0, Vec::new());
        let beta = Array1::from_vec(vec![60.0, -60.0]);

        let msg = bernoulli_marginal_slope_runaway_error_from_beta(
            beta.view(),
            &design,
            &empty_spec(),
            true,
            "regression-fixture",
        );
        assert!(
            msg.is_none(),
            "huge cancelling β with bounded fitted η must NOT trip the runaway guard; got {msg:?}"
        );
    }
1372
    /// A single all-ones column with β = 40 gives |η|∞ = 40 ≥ 35. The design
    /// declares no intercept or linear term, so the parametric predictor is zero
    /// and the message must carry the marginal/logslope coupling explanation.
    #[test]
    pub(crate) fn runaway_guard_fires_when_fitted_eta_exceeds_threshold() {
        let x = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
        let design = dense_marginal_design(x, 0..0, Vec::new());
        let beta = Array1::from_vec(vec![40.0]);

        let msg = bernoulli_marginal_slope_runaway_error_from_beta(
            beta.view(),
            &design,
            &empty_spec(),
            true,
            "separation-fixture",
        )
        .expect("fitted |η|∞=40 ≥ 35 must trip the runaway guard");

        assert!(msg.contains("marginal/logslope runaway"));
        assert!(msg.contains("|η|∞"));
        // `{:.3e}` renders 40.0 as "4.000e1".
        assert!(msg.contains("4.000e1"));
        assert!(msg.contains("score is correlated with the shared surface covariates"));
        assert!(msg.contains("not a Matérn/Duchon polynomial-nullspace"));
        // `inner_converged = true` selects the KKT-certificate wording.
        assert!(msg.contains("KKT certificate"));
    }
1398
/// With a linear term named "sex" registered both in the design's named
/// column ranges (`0..1`) and in the spec, β = 41 on an all-ones column must
/// produce a message that names the unpenalized parametric direction.
#[test]
pub(crate) fn runaway_guard_names_unpenalized_parametric_direction_via_fitted_eta() {
    let beta = Array1::<f64>::from_vec(vec![41.0]);
    let design = dense_marginal_design(
        Array2::<f64>::from_elem((3, 1), 1.0),
        0..0,
        vec![(String::from("sex"), 0..1)],
    );
    let mut spec = empty_spec();
    spec.linear_terms
        .extend(std::iter::once(linear_term("sex", 0)));

    let Some(msg) = bernoulli_marginal_slope_runaway_error_from_beta(
        beta.view(),
        &design,
        &spec,
        true,
        "parametric-fixture",
    ) else {
        panic!("parametric fitted |η|∞=41 ≥ 35 must trip the runaway guard");
    };

    // Keeping the fragment literal inside each `assert!` means a failure still
    // prints which fragment was missing.
    let mentions = |fragment: &str| msg.contains(fragment);
    assert!(mentions("unpenalized parametric marginal direction"));
    assert!(mentions("|η|∞"));
    assert!(mentions("robust Jeffreys curvature path is already installed"));
    assert!(mentions("not a Matérn/Duchon polynomial-nullspace"));
}
1424
/// An all-ones column with β = 5 must not produce a runaway message, even
/// when the boolean flag handed to the guard is `false` (the production call
/// site passes the fit's `outer_converged` there).
#[test]
pub(crate) fn runaway_guard_silent_for_nonconverged_but_bounded_eta() {
    let beta = Array1::<f64>::from_vec(vec![5.0]);
    let design = dense_marginal_design(
        Array2::<f64>::from_elem((3, 1), 1.0),
        0..0,
        Vec::new(),
    );
    let spec = empty_spec();

    let msg = bernoulli_marginal_slope_runaway_error_from_beta(
        beta.view(),
        &design,
        &spec,
        false,
        "nonconverged-fixture",
    );

    assert!(
        msg.is_none(),
        "bounded fitted η must not raise the separation error even when the inner solve did not converge; got {msg:?}"
    );
}
1446
/// An all-ones column with β = 50 and the boolean flag set to `false` must
/// still yield a message, and that message must carry the "inner solve
/// failed" wording checked below.
#[test]
pub(crate) fn runaway_guard_fires_for_nonconverged_separating_eta() {
    let beta = Array1::<f64>::from_vec(vec![50.0]);
    let design = dense_marginal_design(
        Array2::<f64>::from_elem((3, 1), 1.0),
        0..0,
        Vec::new(),
    );
    let spec = empty_spec();

    let Some(msg) = bernoulli_marginal_slope_runaway_error_from_beta(
        beta.view(),
        &design,
        &spec,
        false,
        "nonconverged-separating-fixture",
    ) else {
        panic!("separating |η|∞ at non-convergence must still trip the guard");
    };

    let mentions = |fragment: &str| msg.contains(fragment);
    assert!(mentions(
        "the inner solve failed while already carrying a separation-scale predictor"
    ));
}
1468
/// Evaluates both BMS block Jacobians with an empty β slice and a non-zero
/// log-slope offset `g_baseline`.
///
/// Expectations pinned by the assertions:
/// * the marginal Jacobian is `n × 1` with every entry equal to
///   `sqrt(1 + (s · g_baseline)²)`;
/// * the log-slope Jacobian is `n × 1` with entry `i` equal to `s · z[i]`,
///   and every entry is finite.
#[test]
pub(crate) fn bms_block_jacobians_self_compute_at_audit_empty_beta_nonzero_logslope_baseline() {
    use std::sync::Arc;

    let n = 4usize;
    let s = 1.0_f64;
    let g_baseline = 0.3_f64;

    // Both designs are a single all-ones column.
    let marginal = Arc::new(Array2::<f64>::from_elem((n, 1), 1.0));
    let logslope = Arc::new(Array2::<f64>::from_elem((n, 1), 1.0));
    let offset_m = Array1::<f64>::zeros(n);
    let offset_s = Array1::<f64>::from_elem(n, g_baseline);
    let z = Arc::new(Array1::<f64>::from_vec(vec![-0.7, 0.2, 0.9, 1.4]));

    // Empty coefficient vector: the Jacobians have only the offsets to work with.
    let beta = Vec::<f64>::new();
    let state = FamilyLinearizationState {
        beta: &beta,
        family_scalars: None,
        channel_hessian: None,
        probit_frailty_scale: s,
    };

    let j_m = BmsMarginalJacobian::new(
        Arc::clone(&marginal),
        Arc::clone(&logslope),
        offset_m.clone(),
        offset_s.clone(),
        1,
    )
    .effective_jacobian_rows(&state, 0..n)
    .expect("BMS marginal Jacobian must self-compute at audit empty β (gam#370)");
    let c_expected = (1.0 + (s * g_baseline).powi(2)).sqrt();
    assert_eq!(j_m.dim(), (n, 1));
    for (i, &got) in j_m.column(0).iter().enumerate() {
        assert!(
            (got - c_expected).abs() < 1e-12,
            "marginal J[{i}] = {} != closed-form c_i = {c_expected}",
            got
        );
    }

    let j_s = BmsLogslopeJacobian::new(
        Arc::clone(&marginal),
        Arc::clone(&logslope),
        offset_m,
        offset_s,
        Arc::clone(&z),
        1,
    )
    .effective_jacobian_rows(&state, 0..n)
    .expect("BMS logslope Jacobian must self-compute at audit empty β (gam#370)");
    assert_eq!(j_s.dim(), (n, 1));
    for (i, (&got, &z_i)) in j_s.column(0).iter().zip(z.iter()).enumerate() {
        let expected = s * z_i;
        assert!(
            (got - expected).abs() < 1e-12,
            "logslope J[{i}] = {} != closed-form factor {expected}",
            got
        );
        assert!(got.is_finite());
    }
}
1551}
1552
1553pub(crate) fn build_marginal_blockspec_bms(
1554 design: &TermCollectionDesign,
1555 baseline: f64,
1556 offset: &Array1<f64>,
1557 rho: Array1<f64>,
1558 beta_hint: Option<Array1<f64>>,
1559 logslope_design: &TermCollectionDesign,
1560 logslope_offset: &Array1<f64>,
1561 logslope_baseline: f64,
1562 p_marginal: usize,
1563 influence_columns: Option<&Array2<f64>>,
1564 influence_ridge_log_lambda: f64,
1565) -> Result<ParameterBlockSpec, String> {
1566 let offset_m = offset + baseline;
1567 let offset_s = logslope_offset + logslope_baseline;
1568 let raw_marginal_dense = design
1569 .design
1570 .try_to_dense_arc("build_marginal_blockspec_bms::marginal")?;
1571 let marginal_dense =
1572 widen_marginal_dense_with_influence(&raw_marginal_dense, influence_columns)?;
1573 let logslope_dense = logslope_design
1574 .design
1575 .try_to_dense_arc("build_marginal_blockspec_bms::logslope")?;
1576 let callback: Arc<dyn BlockEffectiveJacobian> = Arc::new(BmsMarginalJacobian {
1577 marginal_dense: Arc::clone(&marginal_dense),
1578 logslope_dense,
1579 offset_m: offset_m.clone(),
1580 offset_s,
1581 p_marginal,
1582 });
1583 let (penalties, nullspace_dims, initial_log_lambdas) = marginal_penalties_with_influence_ridge(
1584 design,
1585 &rho,
1586 influence_columns,
1587 influence_ridge_log_lambda,
1588 )?;
1589 Ok(ParameterBlockSpec {
1590 name: "marginal_surface".to_string(),
1591 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
1592 (*marginal_dense).clone(),
1593 )),
1594 offset: offset_m,
1595 penalties,
1596 nullspace_dims,
1597 initial_log_lambdas,
1598 initial_beta: widen_marginal_beta_hint(beta_hint, p_marginal),
1599 gauge_priority: GAUGE_PRIORITY_MARGINAL,
1612 jacobian_callback: Some(callback),
1613 stacked_design: None,
1614 stacked_offset: None,
1615 })
1616}
1617
1618pub(crate) fn build_logslope_blockspec_bms(
1619 design: &TermCollectionDesign,
1620 baseline: f64,
1621 offset: &Array1<f64>,
1622 rho: Array1<f64>,
1623 beta_hint: Option<Array1<f64>>,
1624 marginal_design: &TermCollectionDesign,
1625 marginal_offset: &Array1<f64>,
1626 marginal_baseline: f64,
1627 z: Arc<Array1<f64>>,
1628 p_marginal: usize,
1629 influence_columns: Option<&Array2<f64>>,
1630) -> Result<ParameterBlockSpec, String> {
1631 let offset_s = offset + baseline;
1632 let offset_m = marginal_offset + marginal_baseline;
1633 let raw_marginal_dense = marginal_design
1634 .design
1635 .try_to_dense_arc("build_logslope_blockspec_bms::marginal")?;
1636 let marginal_dense =
1641 widen_marginal_dense_with_influence(&raw_marginal_dense, influence_columns)?;
1642 let logslope_dense = design
1643 .design
1644 .try_to_dense_arc("build_logslope_blockspec_bms::logslope")?;
1645 let callback: Arc<dyn BlockEffectiveJacobian> = Arc::new(BmsLogslopeJacobian {
1646 marginal_dense,
1647 logslope_dense: Arc::clone(&logslope_dense),
1648 offset_m,
1649 offset_s: offset_s.clone(),
1650 z,
1651 p_marginal,
1652 });
1653 Ok(ParameterBlockSpec {
1654 name: "logslope_surface".to_string(),
1655 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
1656 (*logslope_dense).clone(),
1657 )),
1658 offset: offset_s,
1659 penalties: design.penalties_as_penalty_matrix(),
1660 nullspace_dims: design.nullspace_dims.clone(),
1661 initial_log_lambdas: rho,
1662 initial_beta: beta_hint,
1663 gauge_priority: GAUGE_PRIORITY_LOGSLOPE,
1671 jacobian_callback: Some(callback),
1672 stacked_design: None,
1673 stacked_offset: None,
1674 })
1675}
1676
1677pub(crate) fn build_deviation_aux_blockspec(
1678 name: &str,
1679 prepared: &DeviationPrepared,
1680 rho: Array1<f64>,
1681 beta_hint: Option<Array1<f64>>,
1682) -> Result<ParameterBlockSpec, String> {
1683 let mut block = prepared.block.clone();
1684 block.initial_log_lambdas = Some(rho);
1685 let candidate_beta = beta_hint.or_else(|| Some(Array1::<f64>::zeros(block.design.ncols())));
1686 block.initial_beta = candidate_beta
1687 .map(|beta| {
1688 let zero = Array1::<f64>::zeros(beta.len());
1689 project_monotone_feasible_beta(&prepared.runtime, &zero, &beta, name)
1690 })
1691 .transpose()?;
1692 let mut spec = block.intospec(name)?;
1693 spec.gauge_priority = match name {
1703 "link_dev" => GAUGE_PRIORITY_LINK_DEV,
1704 "score_warp_dev" => GAUGE_PRIORITY_SCORE_WARP_DEV,
1711 _ => GAUGE_PRIORITY_DEVIATION_DEFAULT,
1712 };
1713 Ok(spec)
1714}
1715
1716pub(crate) fn push_deviation_aux_blockspecs(
1717 blocks: &mut Vec<ParameterBlockSpec>,
1718 rho: &Array1<f64>,
1719 cursor: &mut usize,
1720 score_warp_prepared: Option<&DeviationPrepared>,
1721 link_dev_prepared: Option<&DeviationPrepared>,
1722 score_warp_beta_hint: Option<Array1<f64>>,
1723 link_dev_beta_hint: Option<Array1<f64>>,
1724) -> Result<(), String> {
1725 if let Some(prepared) = score_warp_prepared {
1726 let rho_h = rho
1727 .slice(s![*cursor..*cursor + prepared.block.penalties.len()])
1728 .to_owned();
1729 *cursor += prepared.block.penalties.len();
1730 blocks.push(build_deviation_aux_blockspec(
1731 "score_warp_dev",
1732 prepared,
1733 rho_h,
1734 score_warp_beta_hint,
1735 )?);
1736 }
1737 if let Some(prepared) = link_dev_prepared {
1738 let rho_w = rho
1739 .slice(s![*cursor..*cursor + prepared.block.penalties.len()])
1740 .to_owned();
1741 blocks.push(build_deviation_aux_blockspec(
1742 "link_dev",
1743 prepared,
1744 rho_w,
1745 link_dev_beta_hint,
1746 )?);
1747 }
1748 Ok(())
1749}
1750
1751fn inner_fit(
1752 family: &BernoulliMarginalSlopeFamily,
1753 blocks: &[ParameterBlockSpec],
1754 options: &BlockwiseFitOptions,
1755) -> Result<UnifiedFitResult, String> {
1756 let mut options = options.clone();
1757 options.use_outer_hessian = false;
1762 options.outer_tol = options.outer_tol.max(2.0e-5);
1763 fit_custom_family(family, blocks, &options).map_err(|e| e.to_string())
1764}
1765
1766pub fn fit_bernoulli_marginal_slope_terms(
1767 data: ArrayView2<'_, f64>,
1768 spec: BernoulliMarginalSlopeTermSpec,
1769 options: &BlockwiseFitOptions,
1770 kappa_options: &SpatialLengthScaleOptimizationOptions,
1771 policy: &gam_runtime::resource::ResourcePolicy,
1772) -> Result<BernoulliMarginalSlopeFitResult, String> {
1773 let mut spec = spec;
1774 let data_view = data;
1775 validate_spec(data_view, &spec)?;
1776 let mjs_frozen_marginal =
1785 gam_terms::smooth::freeze_measure_jet_length_scale_learning(&mut spec.marginalspec);
1786 let mjs_frozen_logslope =
1787 gam_terms::smooth::freeze_measure_jet_length_scale_learning(&mut spec.logslopespec);
1788 if mjs_frozen_marginal + mjs_frozen_logslope > 0 {
1789 log::info!(
1790 "[BMS spatial] froze measure-jet length-scale learning on {} marginal + {} log-slope \
1791 term(s): the coupled surface keeps ℓ at its conditioned auto value (#1116)",
1792 mjs_frozen_marginal,
1793 mjs_frozen_logslope
1794 );
1795 }
1796 let mut effective_kappa_options = kappa_options.clone();
1797 let kappa_locked_marginal = gam_terms::smooth::all_spatial_terms_kappa_fixed(&spec.marginalspec);
1807 let kappa_locked_logslope = gam_terms::smooth::all_spatial_terms_kappa_fixed(&spec.logslopespec);
1808 if effective_kappa_options.enabled && kappa_locked_marginal && kappa_locked_logslope {
1809 log::info!(
1810 "[BMS spatial] disabling κ/ψ optimization: every spatial term has an \
1811 explicit length_scale and no anisotropy; user-supplied kernel scale is fixed"
1812 );
1813 effective_kappa_options.enabled = false;
1814 }
1815 let flex_spatial_pilot_path = (spec.score_warp.is_some() || spec.link_dev.is_some())
1816 && spec.y.len() >= BMS_FLEX_SPATIAL_OUTER_PILOT_ROW_THRESHOLD
1817 && effective_kappa_options.enabled;
1818 if flex_spatial_pilot_path {
1819 let marginal_terms = spatial_length_scale_term_indices(&spec.marginalspec);
1820 let logslope_terms = spatial_length_scale_term_indices(&spec.logslopespec);
1821 let marginal_updates = apply_spatial_anisotropy_pilot_initializer(
1822 data_view,
1823 &mut spec.marginalspec,
1824 &marginal_terms,
1825 effective_kappa_options.pilot_subsample_threshold,
1826 &effective_kappa_options,
1827 );
1828 let logslope_updates = apply_spatial_anisotropy_pilot_initializer(
1829 data_view,
1830 &mut spec.logslopespec,
1831 &logslope_terms,
1832 effective_kappa_options.pilot_subsample_threshold,
1833 &effective_kappa_options,
1834 );
1835 effective_kappa_options.enabled = false;
1836 log::info!(
1837 "[BMS spatial] n={} flex=true pilot_geometry_updates={} iterative_spatial_outer=false reason=large-flex-spatial-pilot",
1838 spec.y.len(),
1839 marginal_updates + logslope_updates,
1840 );
1841 }
1842 let (z_standardized, z_normalization) = standardize_latent_z_with_policy(
1843 &spec.z,
1844 &spec.weights,
1845 "bernoulli-marginal-slope",
1846 &spec.latent_z_policy,
1847 )?;
1848 spec.z = z_standardized;
1849 let sigma_learnable = matches!(
1850 &spec.frailty,
1851 FrailtySpec::GaussianShift { sigma_fixed: None }
1852 );
1853 let initial_sigma = match &spec.frailty {
1854 FrailtySpec::GaussianShift {
1855 sigma_fixed: Some(s),
1856 } => Some(*s),
1857 FrailtySpec::GaussianShift { sigma_fixed: None } => Some(0.5),
1858 FrailtySpec::None => None,
1859 FrailtySpec::HazardMultiplier { .. } => {
1860 return Err(
1861 "internal: validate_spec should have rejected unsupported marginal-slope frailty"
1862 .to_string(),
1863 );
1864 }
1865 };
1866 let probit_scale = probit_frailty_scale(initial_sigma);
1867 let (_raw_joint_designs, mut joint_specs) = build_term_collection_designs_and_freeze_joint(
1868 data_view,
1869 &[spec.marginalspec.clone(), spec.logslopespec.clone()],
1870 )
1871 .map_err(|e| e.to_string())?;
1872 let marginalspec_boot = joint_specs.remove(0);
1873 let logslopespec_boot = joint_specs.remove(0);
1874 let (mut joint_designs, _) = build_term_collection_designs_and_freeze_joint(
1891 data_view,
1892 &[marginalspec_boot.clone(), logslopespec_boot.clone()],
1893 )
1894 .map_err(|e| format!("failed to rebuild frozen probe BMS joint designs: {e}"))?;
1895 let marginal_design = joint_designs.remove(0);
1896 let logslope_design = joint_designs.remove(0);
1897 let absorber_active = spec
1904 .score_influence_jacobian
1905 .as_ref()
1906 .is_some_and(|j| j.ncols() > 0);
1907 let conditioning_dense = if absorber_active {
1908 None
1909 } else {
1910 Some(
1911 marginal_design
1912 .design
1913 .try_to_dense_arc("bernoulli marginal-slope conditional latent-z gate")?,
1914 )
1915 };
1916 let (latent_measure, latent_z_calibration) = build_latent_measure_with_geometry(
1917 &spec.z,
1918 &spec.weights,
1919 &spec.latent_z_policy,
1920 conditioning_dense.as_ref().map(|d| d.view()),
1921 )?;
1922 if latent_measure.is_empirical() && sigma_learnable {
1923 return Err("empirical latent-measure marginal-slope calibration requires fixed GaussianShift sigma; learnable sigma derivatives must be fit under the standard-normal latent measure"
1924 .to_string());
1925 }
1926
1927 let y = Arc::new(spec.y.clone());
1928 let weights = Arc::new(spec.weights.clone());
1929 let z = match &latent_z_calibration {
1934 LatentMeasureCalibration::None => Arc::new(spec.z.clone()),
1935 LatentMeasureCalibration::RankInverseNormal(cal) => {
1936 Arc::new(cal.apply_to_training(&spec.z)?)
1937 }
1938 LatentMeasureCalibration::ConditionalLocationScale(cal) => {
1939 let a_block = conditioning_dense.as_ref().ok_or_else(|| {
1942 "conditional latent calibration requires the marginal conditioning block"
1943 .to_string()
1944 })?;
1945 Arc::new(cal.apply(spec.z.view(), a_block.view())?)
1946 }
1947 };
1948 let z_train = z.as_ref();
1949 let pilot_baseline = pooled_probit_baseline(&spec.y, z_train, &spec.weights)?;
1950 let baseline = (
1951 bernoulli_marginal_slope_eta_from_probability(
1952 &spec.base_link,
1953 normal_cdf(pilot_baseline.0),
1954 "bernoulli marginal-slope baseline link inversion",
1955 )?,
1956 pilot_baseline.1 / probit_scale,
1957 );
1958
1959 let rigid_pilot_eta = rigid_pooled_probit_pilot_eta(
2002 &spec.base_link,
2003 z_train,
2004 &spec.marginal_offset,
2005 &spec.logslope_offset,
2006 baseline.0,
2007 baseline.1,
2008 probit_scale,
2009 )?;
2010 let cross_block_pilot_w_score_warp =
2011 pilot_irls_hessian_row_metric_at_eta(&rigid_pilot_eta, &spec.weights);
2012
2013 let influence_columns = if let Some(jac) = spec
2025 .score_influence_jacobian
2026 .as_ref()
2027 .filter(|j| j.ncols() > 0)
2028 {
2029 let protected_design = DesignMatrix::hstack(vec![
2030 marginal_design.design.clone(),
2031 logslope_design.design.clone(),
2032 ])
2033 .map_err(|e| {
2034 format!(
2035 "bernoulli marginal-slope influence-block protected projection stack failed to concatenate marginal + logslope design: {e}"
2036 )
2037 })?;
2038 let protected_dense_for_proj = protected_design
2039 .try_to_dense_arc("bernoulli marginal-slope influence-block protected projection")?;
2040 let protected_dense = protected_dense_for_proj.as_ref();
2041 if jac.nrows() != protected_dense.nrows() {
2042 return Err(format!(
2043 "influence block: Jacobian has {} rows, protected design has {}",
2044 jac.nrows(),
2045 protected_dense.nrows()
2046 ));
2047 }
2048 let rigid_logslope_at_rows = &spec.logslope_offset + baseline.1;
2059 let residualized = crate::marginal_slope_orthogonal::residualized_influence_block(
2060 jac,
2061 z_train,
2062 &rigid_logslope_at_rows,
2063 probit_scale,
2064 protected_dense.view(),
2065 &cross_block_pilot_w_score_warp,
2066 )?;
2067 Some(residualized)
2068 } else {
2069 None
2070 };
2071 let mut cross_block_warnings: Vec<CrossBlockIdentifiabilityWarning> = Vec::new();
2072 let score_warp_prepared = if let Some(cfg) = spec.score_warp.as_ref() {
2073 use super::deviation_runtime::ParametricAnchorBlock;
2074 let mut prepared = build_score_warp_deviation_block_from_seed(z_train, cfg)?;
2075 let outcome = install_compiled_flex_block_into_runtime(
2080 &mut prepared,
2081 z_train,
2082 cfg,
2083 &[
2084 (&marginal_design.design, ParametricAnchorBlock::Marginal),
2085 (&logslope_design.design, ParametricAnchorBlock::Logslope),
2086 ],
2087 &[],
2088 &cross_block_pilot_w_score_warp,
2089 )?;
2090 match outcome {
2091 FlexCompileOutcome::Reparameterised => Some(prepared),
2092 FlexCompileOutcome::FullyAliased { reason } => {
2093 cross_block_warnings.push(CrossBlockIdentifiabilityWarning {
2099 candidate_label: "score_warp",
2100 anchor_summary: "marginal+logslope".to_string(),
2101 reason,
2102 });
2103 Some(prepared)
2104 }
2105 }
2106 } else {
2107 None
2108 };
2109 let link_dev_prepared = if let Some(cfg) = spec.link_dev.as_ref() {
2135 let eta_pilot = pilot_eta_for_link_dev_orthogonalisation(
2136 &spec.base_link,
2137 &spec.y,
2138 z_train,
2139 &spec.weights,
2140 &marginal_design.design,
2141 &spec.marginal_offset,
2142 &spec.logslope_offset,
2143 baseline.0,
2144 baseline.1,
2145 probit_scale,
2146 )?;
2147 let link_dev_seed = padded_deviation_seed(&eta_pilot, 1.0, 0.5);
2148 let mut prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
2149 &link_dev_seed,
2150 &eta_pilot,
2151 cfg,
2152 )?;
2153 let score_warp_anchor_design = score_warp_prepared
2190 .as_ref()
2191 .map(|sw| sw.runtime.design_at_training_with_residual(z_train))
2192 .transpose()?;
2193 use super::deviation_runtime::ParametricAnchorBlock;
2194 let parametric_anchors: [(&DesignMatrix, ParametricAnchorBlock); 2] = [
2195 (&marginal_design.design, ParametricAnchorBlock::Marginal),
2196 (&logslope_design.design, ParametricAnchorBlock::Logslope),
2197 ];
2198 let flex_anchor_slot: Option<&Array2<f64>> = score_warp_anchor_design.as_ref();
2199 let flex_anchors: Vec<&Array2<f64>> = flex_anchor_slot.into_iter().collect();
2200 let cross_block_pilot_w_link_dev =
2205 pilot_irls_hessian_row_metric_at_eta(&eta_pilot, &spec.weights);
2206 let outcome = install_compiled_flex_block_into_runtime(
2207 &mut prepared,
2208 &eta_pilot,
2209 cfg,
2210 ¶metric_anchors,
2211 &flex_anchors,
2212 &cross_block_pilot_w_link_dev,
2213 )?;
2214 match outcome {
2215 FlexCompileOutcome::Reparameterised => Some(prepared),
2216 FlexCompileOutcome::FullyAliased { reason } => {
2217 cross_block_warnings.push(CrossBlockIdentifiabilityWarning {
2223 candidate_label: "link_deviation",
2224 anchor_summary: "marginal+logslope+score_warp".to_string(),
2225 reason,
2226 });
2227 Some(prepared)
2228 }
2229 }
2230 } else {
2231 None
2232 };
2233 let extra_rho0 = {
2234 let mut out = Vec::new();
2235 if let Some(ref prepared) = score_warp_prepared {
2236 out.extend(std::iter::repeat_n(0.0, prepared.block.penalties.len()));
2237 }
2238 if let Some(ref prepared) = link_dev_prepared {
2239 out.extend(std::iter::repeat_n(0.0, prepared.block.penalties.len()));
2240 }
2241 out
2242 };
2243 let logslope_reduced_reparam: Option<ReducedLogslopeReparam> = build_reduced_logslope_reparam(
2256 &marginal_design,
2257 &logslope_design,
2258 z.as_ref(),
2259 &cross_block_pilot_w_score_warp,
2260 &spec.marginal_offset,
2261 &spec.logslope_offset,
2262 baseline.0,
2263 baseline.1,
2264 probit_scale,
2265 )?;
2266 let reduce_logslope_design =
2272 |logslope_design: &TermCollectionDesign| -> Result<TermCollectionDesign, String> {
2273 match logslope_reduced_reparam.as_ref() {
2274 Some(reparam) => reparameterize_logslope_design_reduced(logslope_design, reparam),
2275 None => Ok(logslope_design.clone()),
2276 }
2277 };
2278
2279 let marginal_penalty_count = marginal_design.penalties.len();
2280 let setup = joint_setup(
2281 data_view,
2282 &marginalspec_boot,
2283 &logslopespec_boot,
2284 marginal_penalty_count,
2285 logslope_design.penalties.len(),
2286 &extra_rho0,
2287 &effective_kappa_options,
2288 );
2289 let setup = if sigma_learnable {
2290 setup.with_auxiliary(
2291 Array1::from_vec(vec![initial_sigma.expect("learnable sigma seed").ln()]),
2292 Array1::from_vec(vec![0.01_f64.ln()]),
2293 Array1::from_vec(vec![5.0_f64.ln()]),
2294 )
2295 } else {
2296 setup
2297 };
2298 let final_sigma_cell = std::cell::Cell::new(initial_sigma);
2299 let exact_warm_start = RefCell::new(None::<CustomFamilyWarmStart>);
2300 let runaway_error = RefCell::new(None::<String>);
2301 let pending_beta_seed = RefCell::new(None::<Array1<f64>>);
2308 let hints = RefCell::new(ThetaHints::default());
2309 let score_warp_runtime = score_warp_prepared.as_ref().map(|p| p.runtime.clone());
2310 let link_dev_runtime = link_dev_prepared.as_ref().map(|p| p.runtime.clone());
2311
2312 let build_blocks = |rho: &Array1<f64>,
2313 marginal_design: &TermCollectionDesign,
2314 logslope_design: &TermCollectionDesign|
2315 -> Result<Vec<ParameterBlockSpec>, String> {
2316 let hints = hints.borrow();
2317 let mut cursor = 0usize;
2318 let logslope_design_reduced = reduce_logslope_design(logslope_design)?;
2325 let logslope_design = &logslope_design_reduced;
2326 let rho_marginal = rho
2331 .slice(s![cursor..cursor + marginal_design.penalties.len()])
2332 .to_owned();
2333 cursor += marginal_design.penalties.len();
2334 let rho_logslope = rho
2335 .slice(s![cursor..cursor + logslope_design.penalties.len()])
2336 .to_owned();
2337 cursor += logslope_design.penalties.len();
2338 let p_m = marginal_design.design.ncols()
2339 + influence_columns.as_ref().map(|z| z.ncols()).unwrap_or(0);
2340 let mut blocks = vec![
2341 build_marginal_blockspec_bms(
2342 marginal_design,
2343 baseline.0,
2344 &spec.marginal_offset,
2345 rho_marginal,
2346 hints.marginal_beta.clone(),
2347 logslope_design,
2348 &spec.logslope_offset,
2349 baseline.1,
2350 p_m,
2351 influence_columns.as_ref(),
2352 influence_absorber_log_lambda(spec.z.len()),
2353 )?,
2354 build_logslope_blockspec_bms(
2355 logslope_design,
2356 baseline.1,
2357 &spec.logslope_offset,
2358 rho_logslope,
2359 hints.logslope_beta.clone(),
2360 marginal_design,
2361 &spec.marginal_offset,
2362 baseline.0,
2363 Arc::clone(&z),
2364 p_m,
2365 influence_columns.as_ref(),
2366 )?,
2367 ];
2368 push_deviation_aux_blockspecs(
2369 &mut blocks,
2370 rho,
2371 &mut cursor,
2372 score_warp_prepared.as_ref(),
2373 link_dev_prepared.as_ref(),
2374 hints.score_warp_beta.clone(),
2375 hints.link_dev_beta.clone(),
2376 )?;
2377 Ok(blocks)
2378 };
2379
2380 let intercept_warm_starts = new_intercept_warm_start_cache(y.len());
2381 let cell_moment_lru = new_cell_moment_lru_cache(policy);
2382 let cell_moment_cache_stats = new_cell_moment_cache_stats();
2383 let make_family = |marginal_design: &TermCollectionDesign,
2384 logslope_design: &TermCollectionDesign,
2385 sigma: Option<f64>|
2386 -> BernoulliMarginalSlopeFamily {
2387 let kernel_marginal_design = match influence_columns.as_ref() {
2393 Some(z_infl) => {
2394 let raw = marginal_design
2395 .design
2396 .try_to_dense_arc("make_family::widened-marginal")
2397 .expect("dense marginal design for influence widening");
2398 let widened = widen_marginal_dense_with_influence(&raw, Some(z_infl))
2399 .expect("widen marginal design with influence columns");
2400 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from((*widened).clone()))
2401 }
2402 None => marginal_design.design.clone(),
2403 };
2404 let kernel_logslope_design = reduce_logslope_design(logslope_design)
2410 .expect("reduce logslope design for family construction")
2411 .design;
2412 BernoulliMarginalSlopeFamily {
2413 y: Arc::clone(&y),
2414 weights: Arc::clone(&weights),
2415 z: Arc::clone(&z),
2416 latent_measure: latent_measure.clone(),
2417 gaussian_frailty_sd: sigma,
2418 base_link: spec.base_link.clone(),
2419 marginal_design: kernel_marginal_design,
2420 logslope_design: kernel_logslope_design,
2421 score_warp: score_warp_runtime.clone(),
2422 link_dev: link_dev_runtime.clone(),
2423 policy: policy.clone(),
2424 cell_moment_lru: Arc::clone(&cell_moment_lru),
2425 cell_moment_cache_stats: Arc::clone(&cell_moment_cache_stats),
2426 intercept_warm_starts: Some(Arc::clone(&intercept_warm_starts)),
2427 auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
2428 auto_subsample_last_rho: Arc::new(Mutex::new(None)),
2429 }
2430 };
2431
2432 let marginal_terms = spatial_length_scale_term_indices(&marginalspec_boot);
2433 let logslope_terms = spatial_length_scale_term_indices(&logslopespec_boot);
2434 let marginal_has_spatial = !marginal_terms.is_empty();
2435 let logslope_has_spatial = !logslope_terms.is_empty();
2436 let analytic_joint_derivatives_available =
2437 marginal_has_spatial || logslope_has_spatial || setup.log_kappa_dim() == 0;
2438 if setup.log_kappa_dim() > 0 && !analytic_joint_derivatives_available {
2439 return Err("exact bernoulli marginal-slope spatial optimization requires analytic joint psi derivatives"
2440 .to_string());
2441 }
2442 let initial_rho = setup.theta0().slice(s![..setup.rho_dim()]).to_owned();
2443 let initial_blocks = build_blocks(&initial_rho, &marginal_design, &logslope_design)?;
2444 let initial_family = make_family(&marginal_design, &logslope_design, initial_sigma);
2445 let (joint_gradient, joint_hessian) =
2446 custom_family_outer_derivatives(&initial_family, &initial_blocks, options);
2447 let analytic_joint_gradient_available = analytic_joint_derivatives_available
2448 && matches!(joint_gradient, gam_problem::Derivative::Analytic);
2449 let analytic_joint_hessian_available =
2455 analytic_joint_derivatives_available && joint_hessian.is_analytic();
2456 let kappa_options_ref: &SpatialLengthScaleOptimizationOptions = &effective_kappa_options;
2457 let sigma_from_theta = |theta: &Array1<f64>| -> Option<f64> {
2458 if sigma_learnable {
2459 Some(theta[setup.rho_dim() + setup.log_kappa_dim()].exp())
2460 } else {
2461 initial_sigma
2462 }
2463 };
2464 let derivative_block_cache = RefCell::new(
2465 None::<(
2466 Array1<f64>,
2467 Arc<Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>>,
2468 )>,
2469 );
2470 let theta_matches = |left: &Array1<f64>, right: &Array1<f64>| -> bool {
2471 left.len() == right.len()
2472 && left
2473 .iter()
2474 .zip(right.iter())
2475 .all(|(lhs, rhs)| (*lhs - *rhs).abs() <= 1e-12 * (1.0 + lhs.abs().max(rhs.abs())))
2476 };
2477 let get_derivative_blocks = |theta: &Array1<f64>,
2478 specs: &[TermCollectionSpec],
2479 designs: &[TermCollectionDesign]|
2480 -> Result<
2481 Arc<Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>>,
2482 String,
2483 > {
2484 if let Some((cached_theta, cached_blocks)) = derivative_block_cache.borrow().as_ref()
2485 && theta_matches(cached_theta, theta)
2486 {
2487 return Ok(Arc::clone(cached_blocks));
2488 }
2489
2490 let built = |specs: &[TermCollectionSpec],
2491 designs: &[TermCollectionDesign]|
2492 -> Result<
2493 Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>,
2494 String,
2495 > {
2496 let marginal_psi_derivs = if marginal_has_spatial {
2497 build_block_spatial_psi_derivatives(data_view, &specs[0], &designs[0])?.ok_or_else(
2498 || {
2499 "bernoulli marginal-slope: marginal block has spatial terms \
2500 but spatial psi derivatives are unavailable"
2501 .to_string()
2502 },
2503 )?
2504 } else {
2505 Vec::new()
2506 };
2507 let logslope_psi_derivs = if logslope_has_spatial {
2508 build_block_spatial_psi_derivatives(data_view, &specs[1], &designs[1])?.ok_or_else(
2509 || {
2510 "bernoulli marginal-slope: logslope block has spatial terms \
2511 but spatial psi derivatives are unavailable"
2512 .to_string()
2513 },
2514 )?
2515 } else {
2516 Vec::new()
2517 };
2518 let mut derivative_blocks = vec![marginal_psi_derivs, logslope_psi_derivs];
2519 if score_warp_runtime.is_some() {
2520 derivative_blocks.push(Vec::new());
2521 }
2522 if link_dev_runtime.is_some() {
2523 derivative_blocks.push(Vec::new());
2524 }
2525 if sigma_learnable {
2526 derivative_blocks
2527 .last_mut()
2528 .expect("bernoulli derivative block list is non-empty")
2529 .push(crate::custom_family::CustomFamilyBlockPsiDerivative::new(
2530 None,
2531 Array2::zeros((0, 0)),
2532 Array2::zeros((0, 0)),
2533 None,
2534 None,
2535 None,
2536 None,
2537 ));
2538 }
2539 Ok(derivative_blocks)
2540 }(specs, designs)?;
2541 let built = Arc::new(built);
2542 derivative_block_cache.replace(Some((theta.clone(), Arc::clone(&built))));
2543 Ok(built)
2544 };
2545
2546 let outer_policy = {
2551 let psi_dim = setup.theta0().len() - setup.rho_dim();
2552 initial_family.outer_derivative_policy(&initial_blocks, psi_dim, options)
2553 };
2554 let exact_spatial_outer_tol = kappa_options_ref.rel_tol.max(EXACT_SPATIAL_OUTER_TOL_FLOOR);
2555 let solved = optimize_spatial_length_scale_exact_joint(
2556 data_view,
2557 &[marginalspec_boot.clone(), logslopespec_boot.clone()],
2558 &[marginal_terms.clone(), logslope_terms.clone()],
2559 kappa_options_ref,
2560 &setup,
2561 gam_solve::seeding::SeedRiskProfile::GeneralizedLinear,
2562 analytic_joint_gradient_available,
2563 analytic_joint_hessian_available,
2564 true,
2565 None,
2566 outer_policy,
2567 |theta, specs: &[TermCollectionSpec], designs: &[TermCollectionDesign]| {
2568 if let Some(err) = runaway_error.borrow().as_ref().cloned() {
2569 return Err(err);
2570 }
2571 assert_eq!(
2572 specs.len(),
2573 designs.len(),
2574 "spatial joint optimizer must supply one spec per design",
2575 );
2576 let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
2577 let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
2578 let sigma = sigma_from_theta(theta);
2579 final_sigma_cell.set(sigma);
2580 let family = make_family(&designs[0], &designs[1], sigma);
2581 let fit = inner_fit(&family, &blocks, options)?;
2582 if let Some(block) = fit.block_states.first()
2583 && let Some(err) = bernoulli_marginal_slope_runaway_error_from_beta(
2584 block.beta.view(),
2585 &designs[0],
2586 &specs[0],
2587 fit.outer_converged,
2588 "final fit",
2589 )
2590 {
2591 runaway_error.replace(Some(err.clone()));
2592 return Err(err);
2593 }
2594 let mut hints_mut = hints.borrow_mut();
2595 let mut bidx = 0usize;
2596 if let Some(block) = fit.block_states.get(bidx) {
2597 hints_mut.marginal_beta = Some(block.beta.clone());
2598 }
2599 bidx += 1;
2600 if let Some(block) = fit.block_states.get(bidx) {
2601 hints_mut.logslope_beta = Some(block.beta.clone());
2602 }
2603 bidx += 1;
2604 if score_warp_prepared.is_some() {
2605 if let Some(block) = fit.block_states.get(bidx) {
2606 hints_mut.score_warp_beta = Some(block.beta.clone());
2607 }
2608 bidx += 1;
2609 }
2610 if link_dev_prepared.is_some()
2611 && let Some(block) = fit.block_states.get(bidx)
2612 {
2613 hints_mut.link_dev_beta = Some(block.beta.clone());
2614 }
2615 Ok(fit)
2616 },
2617 |theta,
2618 specs: &[TermCollectionSpec],
2619 designs: &[TermCollectionDesign],
2620 eval_mode,
2621 row_set: &crate::row_kernel::RowSet| {
2622 if let Some(err) = runaway_error.borrow().as_ref().cloned() {
2623 return Err(err);
2624 }
2625 use gam_problem::EvalMode;
2626 static BMS_OUTER_EVAL_ROWSET_LOGGED: std::sync::Once = std::sync::Once::new();
2633 BMS_OUTER_EVAL_ROWSET_LOGGED.call_once(|| {
2634 let row_set_rows = match row_set {
2635 crate::row_kernel::RowSet::All => spec.y.len(),
2636 crate::row_kernel::RowSet::Subsample { rows, .. } => rows.len(),
2637 };
2638 log::debug!(
2639 "[BMS exact outer eval] mode={eval_mode:?} row_set_rows={row_set_rows}"
2640 );
2641 });
2642 let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
2643 let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
2644 if let Some(beta_seed) = pending_beta_seed.borrow_mut().take() {
2648 let widths: Vec<usize> = blocks.iter().map(|b| b.design.ncols()).collect();
2649 match CustomFamilyWarmStart::from_cached_beta(&widths, &beta_seed) {
2650 Ok(ws) => {
2651 exact_warm_start.replace(Some(ws));
2652 }
2653 Err(e) => {
2654 log::warn!(
2655 "[BMS] outer ρ-cache β-warm-start rejected: {e}; falling back to cold β"
2656 );
2657 }
2658 }
2659 }
2660 let sigma = sigma_from_theta(theta);
2661 final_sigma_cell.set(sigma);
2662 let family = make_family(&designs[0], &designs[1], sigma);
2663 let derivative_blocks = get_derivative_blocks(theta, specs, designs)?;
2664 let effective_mode = match eval_mode {
2668 EvalMode::ValueGradientHessian if !analytic_joint_hessian_available => {
2669 EvalMode::ValueAndGradient
2670 }
2671 other => other,
2672 };
2673 let mut eval_options =
2674 joint_hyper_options_for_outer_tolerance(options, exact_spatial_outer_tol);
2675 if let crate::row_kernel::RowSet::Subsample { rows, n_full } = row_set {
2676 let subsample = crate::outer_subsample::OuterScoreSubsample::from_weighted_rows(
2677 rows.as_ref().clone(),
2678 *n_full,
2679 0,
2680 );
2681 eval_options.outer_score_subsample = Some(Arc::new(subsample));
2682 eval_options.auto_outer_subsample = false;
2683 }
2684 let eval = evaluate_custom_family_joint_hyper_shared(
2685 &family,
2686 &blocks,
2687 &eval_options,
2688 &rho,
2689 derivative_blocks,
2690 exact_warm_start.borrow().as_ref(),
2691 effective_mode,
2692 )?;
2693 if let Some(err) = bernoulli_marginal_slope_runaway_error(
2694 &eval.warm_start,
2695 &designs[0],
2696 &specs[0],
2697 eval.inner_converged,
2698 "exact outer evaluation",
2699 ) {
2700 runaway_error.replace(Some(err.clone()));
2701 return Err(err);
2702 }
2703 exact_warm_start.replace(Some(eval.warm_start.clone()));
2704 if !eval.inner_converged {
2705 return Err(
2706 "exact bernoulli marginal-slope inner solve did not converge".to_string(),
2707 );
2708 }
2709 if matches!(eval_mode, EvalMode::ValueGradientHessian)
2710 && analytic_joint_hessian_available
2711 && !eval.outer_hessian.is_analytic()
2712 {
2713 return Err("exact bernoulli marginal-slope joint [rho, psi] objective did not return an outer Hessian"
2714 .to_string());
2715 }
2716 Ok((eval.objective, eval.gradient, eval.outer_hessian))
2717 },
2718 |theta, specs: &[TermCollectionSpec], designs: &[TermCollectionDesign]| {
2719 if let Some(err) = runaway_error.borrow().as_ref().cloned() {
2720 return Err(err);
2721 }
2722 let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
2723 let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
2724 if let Some(beta_seed) = pending_beta_seed.borrow_mut().take() {
2725 let widths: Vec<usize> = blocks.iter().map(|b| b.design.ncols()).collect();
2726 match CustomFamilyWarmStart::from_cached_beta(&widths, &beta_seed) {
2727 Ok(ws) => {
2728 exact_warm_start.replace(Some(ws));
2729 }
2730 Err(e) => {
2731 log::warn!(
2732 "[BMS] outer ρ-cache β-warm-start rejected (efs): {e}; falling back to cold β"
2733 );
2734 }
2735 }
2736 }
2737 let sigma = sigma_from_theta(theta);
2738 final_sigma_cell.set(sigma);
2739 let family = make_family(&designs[0], &designs[1], sigma);
2740 let derivative_blocks = get_derivative_blocks(theta, specs, designs)?;
2741 let eval = evaluate_custom_family_joint_hyper_efs_shared(
2742 &family,
2743 &blocks,
2744 &joint_hyper_options_for_outer_tolerance(options, exact_spatial_outer_tol),
2745 &rho,
2746 derivative_blocks,
2747 exact_warm_start.borrow().as_ref(),
2748 )?;
2749 if let Some(err) = bernoulli_marginal_slope_runaway_error(
2750 &eval.warm_start,
2751 &designs[0],
2752 &specs[0],
2753 eval.inner_converged,
2754 "EFS outer evaluation",
2755 ) {
2756 runaway_error.replace(Some(err.clone()));
2757 return Err(err);
2758 }
2759 exact_warm_start.replace(Some(eval.warm_start.clone()));
2760 if !eval.inner_converged {
2761 return Err(
2762 "exact bernoulli marginal-slope EFS inner solve did not converge".to_string(),
2763 );
2764 }
2765 Ok(eval.efs_eval)
2766 },
2767 crate::marginal_slope_shared::make_beta_seed_validator(&pending_beta_seed),
2768 )?;
2769
2770 let mut resolved_specs = solved.resolved_specs;
2771 let mut designs = solved.designs;
2772 let mut solved_fit = solved.fit;
2784 if let Some(reparam) = logslope_reduced_reparam.as_ref() {
2785 let r = reparam.reduced_cols();
2786 if let Some(block) = solved_fit.blocks.get_mut(1)
2787 && block.beta.len() == r
2788 {
2789 block.beta = reparam.recover_original_logslope_beta(&block.beta)?;
2790 }
2791 if let Some(state) = solved_fit.block_states.get_mut(1)
2792 && state.beta.len() == r
2793 {
2794 state.beta = reparam.recover_original_logslope_beta(&state.beta)?;
2795 }
2796 }
2797 let (latent_z_rank_int_calibration, latent_z_conditional_calibration) =
2839 match latent_z_calibration {
2840 LatentMeasureCalibration::None => (None, None),
2841 LatentMeasureCalibration::RankInverseNormal(cal) => (Some(cal), None),
2842 LatentMeasureCalibration::ConditionalLocationScale(cal) => (None, Some(cal)),
2843 };
2844 if let Some(cal) = latent_z_conditional_calibration.as_ref()
2854 && let Some(vb) = solved_fit.covariance_conditional.clone()
2855 {
2856 let p_beta = vb.nrows();
2857 let marginal_dense = marginal_design
2858 .design
2859 .try_to_dense_arc("bms generated-regressor marginal design")?;
2860 let logslope_reduced = reduce_logslope_design(&logslope_design)?;
2861 let logslope_reduced_dense = logslope_reduced
2862 .design
2863 .try_to_dense_arc("bms generated-regressor reduced logslope design")?;
2864 let p_m = marginal_dense.ncols();
2865 let r = logslope_reduced_dense.ncols();
2866 if p_beta != vb.ncols() {
2867 return Err(format!(
2868 "bms generated-regressor: covariance_conditional must be square, got {}×{}",
2869 vb.nrows(),
2870 vb.ncols()
2871 ));
2872 }
2873 if p_beta == p_m + r {
2877 let marginal_eta = &solved_fit.block_states[0].eta;
2878 let slope_eta = &solved_fit.block_states[1].eta;
2879 let probit_scale = probit_frailty_scale(final_sigma_cell.get());
2880 let s = rigid_standard_normal_score_zeta_sensitivity(
2881 &spec.base_link,
2882 marginal_eta,
2883 slope_eta,
2884 z.as_ref(),
2885 y.as_ref(),
2886 weights.as_ref(),
2887 probit_scale,
2888 marginal_dense.view(),
2889 logslope_reduced_dense.view(),
2890 p_beta,
2891 )?;
2892 let correction = cal.generated_regressor_correction(
2900 s.view(),
2901 spec.z.view(),
2902 marginal_dense.view(),
2903 vb.view(),
2904 )?;
2905 if let Some(cov) = solved_fit.covariance_conditional.as_mut() {
2906 *cov = &*cov + &correction;
2907 }
2908 if let Some(cov) = solved_fit.covariance_corrected.as_mut() {
2909 *cov = &*cov + &correction;
2910 }
2911 log::info!(
2912 "[BMS latent-z] Murphy–Topel generated-regressor SE correction applied: \
2913 p_beta={p_beta} theta1_dim={} max_diag_inflation={:.3e}",
2914 cal.theta1_dim(),
2915 (0..p_beta)
2916 .map(|i| correction[[i, i]])
2917 .fold(0.0_f64, f64::max),
2918 );
2919 } else {
2920 log::info!(
2921 "[BMS latent-z] Murphy–Topel generated-regressor SE correction skipped: \
2922 aux deviation blocks present (p_beta={p_beta} > marginal({p_m})+logslope({r})); \
2923 rigid-kernel z-channel does not yet cover score_warp/link_dev deviations"
2924 );
2925 }
2926 }
2927 Ok(BernoulliMarginalSlopeFitResult {
2940 fit: solved_fit,
2941 marginalspec_resolved: resolved_specs.remove(0),
2942 logslopespec_resolved: resolved_specs.remove(0),
2943 marginal_design: designs.remove(0),
2944 logslope_design: designs.remove(0),
2945 baseline_marginal: baseline.0,
2946 baseline_logslope: baseline.1,
2947 z_normalization,
2948 latent_measure,
2949 score_warp_runtime,
2950 link_dev_runtime,
2951 gaussian_frailty_sd: final_sigma_cell.get(),
2952 cross_block_warnings,
2953 latent_z_rank_int_calibration,
2954 latent_z_conditional_calibration,
2955 })
2956}