1use super::family::*;
2use super::gradient_paths::*;
3use super::hessian_paths::{new_cell_moment_cache_stats, new_cell_moment_lru_cache};
4use super::install_flex::validate_spec;
5use super::*;
6use gam_linalg::faer_ndarray::{FaerEigh, fast_ab, fast_atb, fast_xt_diag_x};
7use crate::marginal_slope_orthogonal::INFLUENCE_ABSORBER_FIXED_LOG_LAMBDA;
8use faer::Side;
9
/// Threshold on the sup-norm `|η|∞` of the fitted marginal probit predictor.
///
/// When either the parametric-only or the full fitted predictor reaches this
/// value (comparison is `>=`), the Bernoulli marginal-slope runaway
/// diagnostic is raised.
pub(crate) const BMS_PROBIT_SEPARATION_ETA_INF: f64 = 35.0;
25
26pub(super) const GAUGE_PRIORITY_ANCHOR: u8 = 200;
41pub(super) const GAUGE_PRIORITY_MARGINAL: u8 = 150;
44pub(super) const GAUGE_PRIORITY_LOGSLOPE: u8 = 120;
46pub(super) const GAUGE_PRIORITY_CANDIDATE_FLEX: u8 = 100;
49pub(super) const GAUGE_PRIORITY_SCORE_WARP_DEV: u8 = 80;
52pub(super) const GAUGE_PRIORITY_DEVIATION_DEFAULT: u8 = 70;
56pub(super) const GAUGE_PRIORITY_LINK_DEV: u8 = 60;
58
/// Floor for the outer-loop tolerance of exact spatial fits.
///
/// NOTE(review): no consumer is visible in this chunk; presumably the outer
/// optimizer clamps its tolerance from below with this value — confirm at the
/// call site.
pub(crate) const EXACT_SPATIAL_OUTER_TOL_FLOOR: f64 = 1e-6;
65
66pub struct BmsMarginalJacobian {
105 pub marginal_dense: Arc<Array2<f64>>,
107 pub logslope_dense: Arc<Array2<f64>>,
109 pub offset_m: Array1<f64>,
110 pub offset_s: Array1<f64>,
111 pub p_marginal: usize,
113}
114
115impl BmsMarginalJacobian {
116 pub fn new(
117 marginal_dense: Arc<Array2<f64>>,
118 logslope_dense: Arc<Array2<f64>>,
119 offset_m: Array1<f64>,
120 offset_s: Array1<f64>,
121 p_marginal: usize,
122 ) -> Self {
123 Self {
124 marginal_dense,
125 logslope_dense,
126 offset_m,
127 offset_s,
128 p_marginal,
129 }
130 }
131}
132
133impl BlockEffectiveJacobian for BmsMarginalJacobian {
134 fn effective_jacobian_rows(
135 &self,
136 state: &FamilyLinearizationState<'_>,
137 rows: std::ops::Range<usize>,
138 ) -> Result<Array2<f64>, String> {
139 let beta = state.beta;
140 let s = state.probit_frailty_scale;
141 let p_m = self.p_marginal;
142 let p_s_block = self.logslope_dense.ncols();
143 let beta_s_raw = if beta.len() > p_m {
144 &beta[p_m..]
145 } else {
146 &[][..]
147 };
148 let p_s_use = p_s_block.min(beta_s_raw.len());
149 let beta_s = &beta_s_raw[..p_s_use];
150 let n = self.marginal_dense.nrows();
151 let rows = rows.start.min(n)..rows.end.min(n);
152 let p_block = self.marginal_dense.ncols();
153
154 let mut out = Array2::<f64>::zeros((rows.end - rows.start, p_block));
164 for i in rows.clone() {
165 let g_i = self.offset_s[i]
166 + self
167 .logslope_dense
168 .row(i)
169 .slice(ndarray::s![..p_s_use])
170 .dot(&ArrayView1::from(beta_s));
171 let sg = s * g_i;
172 let c_i = (1.0 + sg * sg).sqrt();
173 let m_row = self.marginal_dense.row(i);
175 out.row_mut(i - rows.start).assign(&m_row.mapv(|x| c_i * x));
176 }
177 Ok(out)
178 }
179
180 fn n_outputs(&self) -> usize {
181 1
182 }
183}
184
185pub struct BmsLogslopeJacobian {
197 pub marginal_dense: Arc<Array2<f64>>,
199 pub logslope_dense: Arc<Array2<f64>>,
201 pub offset_m: Array1<f64>,
202 pub offset_s: Array1<f64>,
203 pub z: Arc<Array1<f64>>,
204 pub p_marginal: usize,
206}
207
208impl BmsLogslopeJacobian {
209 pub fn new(
210 marginal_dense: Arc<Array2<f64>>,
211 logslope_dense: Arc<Array2<f64>>,
212 offset_m: Array1<f64>,
213 offset_s: Array1<f64>,
214 z: Arc<Array1<f64>>,
215 p_marginal: usize,
216 ) -> Self {
217 Self {
218 marginal_dense,
219 logslope_dense,
220 offset_m,
221 offset_s,
222 z,
223 p_marginal,
224 }
225 }
226}
227
228impl BlockEffectiveJacobian for BmsLogslopeJacobian {
229 fn effective_jacobian_rows(
230 &self,
231 state: &FamilyLinearizationState<'_>,
232 rows: std::ops::Range<usize>,
233 ) -> Result<Array2<f64>, String> {
234 let beta = state.beta;
235 let s = state.probit_frailty_scale;
236 let p_m = self.p_marginal;
237 let p_m_use = p_m.min(beta.len());
238 let beta_m = &beta[..p_m_use];
239 let beta_s_raw = if beta.len() > p_m {
240 &beta[p_m..]
241 } else {
242 &[][..]
243 };
244 let p_s_block = self.logslope_dense.ncols();
245 let p_s_use = p_s_block.min(beta_s_raw.len());
246 let beta_s = &beta_s_raw[..p_s_use];
247 let n = self.logslope_dense.nrows();
248 let rows = rows.start.min(n)..rows.end.min(n);
249
250 let mut out = Array2::<f64>::zeros((rows.end - rows.start, p_s_block));
263 for i in rows.clone() {
264 let q_i = self.offset_m[i]
265 + self
266 .marginal_dense
267 .row(i)
268 .slice(ndarray::s![..p_m_use])
269 .dot(&ArrayView1::from(beta_m));
270 let g_i = self.offset_s[i]
271 + self
272 .logslope_dense
273 .row(i)
274 .slice(ndarray::s![..p_s_use])
275 .dot(&ArrayView1::from(beta_s));
276 let sg = s * g_i;
277 let c_i = (1.0 + sg * sg).sqrt();
278 let z_i = self.z[i];
279 let factor = q_i * s * s * g_i / c_i + s * z_i;
281 let g_row = self.logslope_dense.row(i);
283 out.row_mut(i - rows.start)
284 .assign(&g_row.mapv(|x| factor * x));
285 }
286 Ok(out)
287 }
288
289 fn n_outputs(&self) -> usize {
290 1
291 }
292}
293
294pub(crate) fn widen_marginal_dense_with_influence(
304 marginal_dense: &Arc<Array2<f64>>,
305 influence_columns: Option<&Array2<f64>>,
306) -> Result<Arc<Array2<f64>>, String> {
307 let Some(z_infl) = influence_columns else {
308 return Ok(Arc::clone(marginal_dense));
309 };
310 let n = marginal_dense.nrows();
311 if z_infl.nrows() != n {
312 return Err(format!(
313 "influence block: residualised columns have {} rows, marginal design has {n}",
314 z_infl.nrows()
315 ));
316 }
317 let p_m = marginal_dense.ncols();
318 let p1 = z_infl.ncols();
319 let mut widened = Array2::<f64>::zeros((n, p_m + p1));
320 widened
321 .slice_mut(s![.., ..p_m])
322 .assign(marginal_dense.as_ref());
323 widened.slice_mut(s![.., p_m..]).assign(z_infl);
324 Ok(Arc::new(widened))
325}
326
/// Relative tolerance of the reduced-logslope construction.
///
/// It is used in two places. First, it scales the ridge added to the
/// effective marginal Gram (relative to its largest diagonal). Second, it
/// sets the eigenvalue cut-off for the Schur complement (relative to the
/// largest diagonal of the effective logslope Gram).
pub(crate) const LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL: f64 = 1.0e-6;
335
336#[derive(Debug, Clone)]
388pub(super) struct ReducedLogslopeReparam {
389 transform: Array2<f64>,
392}
393
394impl ReducedLogslopeReparam {
395 #[inline]
397 pub(super) fn original_cols(&self) -> usize {
398 self.transform.nrows()
399 }
400
401 #[inline]
403 pub(super) fn reduced_cols(&self) -> usize {
404 self.transform.ncols()
405 }
406
407 pub(super) fn recover_original_logslope_beta(
411 &self,
412 beta_reduced: &Array1<f64>,
413 ) -> Result<Array1<f64>, String> {
414 if beta_reduced.len() != self.reduced_cols() {
415 return Err(format!(
416 "reduced logslope reparam: β' length ({}) != reduced width ({})",
417 beta_reduced.len(),
418 self.reduced_cols()
419 ));
420 }
421 Ok(self.transform.dot(beta_reduced))
422 }
423}
424
425fn build_reduced_logslope_reparam(
433 marginal_design: &TermCollectionDesign,
434 logslope_design: &TermCollectionDesign,
435 z: &Array1<f64>,
436 row_metric: &Array1<f64>,
437 marginal_offset: &Array1<f64>,
438 logslope_offset: &Array1<f64>,
439 marginal_baseline: f64,
440 logslope_baseline: f64,
441 probit_scale: f64,
442) -> Result<Option<ReducedLogslopeReparam>, String> {
443 let marginal = marginal_design
444 .design
445 .try_to_dense_arc("build_reduced_logslope_reparam::marginal")?;
446 let logslope = logslope_design
447 .design
448 .try_to_dense_arc("build_reduced_logslope_reparam::logslope")?;
449 let n = marginal.nrows();
450 if logslope.nrows() != n
451 || z.len() != n
452 || row_metric.len() != n
453 || marginal_offset.len() != n
454 || logslope_offset.len() != n
455 {
456 return Err(format!(
457 "reduced logslope reparam row mismatch: marginal={}, logslope={}, z={}, row_metric={}, marginal_offset={}, logslope_offset={}",
458 marginal.nrows(),
459 logslope.nrows(),
460 z.len(),
461 row_metric.len(),
462 marginal_offset.len(),
463 logslope_offset.len(),
464 ));
465 }
466 let p_m = marginal.ncols();
467 let p_g = logslope.ncols();
468 if p_m == 0 || p_g == 0 {
469 return Ok(None);
470 }
471 if !marginal_baseline.is_finite()
472 || !logslope_baseline.is_finite()
473 || !probit_scale.is_finite()
474 || probit_scale <= 0.0
475 || z.iter().any(|v| !v.is_finite())
476 || row_metric.iter().any(|v| !v.is_finite() || *v < 0.0)
477 || marginal_offset.iter().any(|v| !v.is_finite())
478 || logslope_offset.iter().any(|v| !v.is_finite())
479 {
480 return Err(
481 "reduced logslope reparam requires finite pilot geometry and finite non-negative row metric"
482 .to_string(),
483 );
484 }
485
486 match reduced_logslope_transform_effective(
497 marginal.view(),
498 logslope.view(),
499 z,
500 row_metric,
501 marginal_offset,
502 logslope_offset,
503 marginal_baseline,
504 logslope_baseline,
505 probit_scale,
506 )? {
507 Some(transform) => Ok(Some(ReducedLogslopeReparam { transform })),
508 None => Ok(None),
509 }
510}
511
512pub(crate) fn reduced_logslope_transform_effective(
531 marginal: ArrayView2<'_, f64>,
532 logslope: ArrayView2<'_, f64>,
533 z: &Array1<f64>,
534 row_metric: &Array1<f64>,
535 marginal_offset: &Array1<f64>,
536 logslope_offset: &Array1<f64>,
537 marginal_baseline: f64,
538 logslope_baseline: f64,
539 probit_scale: f64,
540) -> Result<Option<Array2<f64>>, String> {
541 let n = marginal.nrows();
542 let p_m = marginal.ncols();
543 let p_g = logslope.ncols();
544 if p_m == 0 || p_g == 0 {
545 return Ok(None);
546 }
547
548 let mut m_eff = Array2::<f64>::zeros((n, p_m));
550 let mut g_eff = Array2::<f64>::zeros((n, p_g));
551 for i in 0..n {
552 let q_i = marginal_offset[i] + marginal_baseline;
553 let g_i = logslope_offset[i] + logslope_baseline;
554 let sg = probit_scale * g_i;
555 let c_i = (1.0 + sg * sg).sqrt();
556 let f_i = q_i * probit_scale * probit_scale * g_i / c_i + probit_scale * z[i];
557 for j in 0..p_m {
558 m_eff[[i, j]] = c_i * marginal[[i, j]];
559 }
560 for j in 0..p_g {
561 g_eff[[i, j]] = f_i * logslope[[i, j]];
562 }
563 }
564
565 let c_gram = fast_xt_diag_x(&g_eff, row_metric);
568 let energy_scale = (0..p_g).map(|i| c_gram[[i, i]]).fold(0.0_f64, f64::max);
569 if !energy_scale.is_finite() || energy_scale <= 0.0 {
570 return Ok(None);
571 }
572
573 let mut a_gram = fast_xt_diag_x(&m_eff, row_metric);
577 let a_scale = (0..p_m).map(|i| a_gram[[i, i]]).fold(0.0_f64, f64::max);
578 let a_ridge = (a_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL).max(f64::EPSILON);
579 for i in 0..p_m {
580 a_gram[[i, i]] += a_ridge;
581 }
582
583 let b_cross = gam_linalg::faer_ndarray::fast_xt_diag_y(&m_eff, row_metric, &g_eff);
585 let a_view = gam_linalg::faer_ndarray::FaerArrayView::new(&a_gram);
586 let a_factor =
587 gam_linalg::faer_ndarray::factorize_symmetricwith_fallback(a_view.as_ref(), Side::Lower)
588 .map_err(|e| {
589 format!(
590 "reduced logslope reparam: effective marginal Gram factorization failed: {e}"
591 )
592 })?;
593 let b_view = gam_linalg::faer_ndarray::FaerArrayView::new(&b_cross);
594 let solved = a_factor.solve(b_view.as_ref()); let a_inv_b = Array2::from_shape_fn((p_m, p_g), |(i, j)| solved[(i, j)]);
596 let schur = fast_atb(&b_cross, &a_inv_b); let mut stt = &c_gram - &schur;
598 stt = (&stt + &stt.t()) * 0.5;
599 if stt.iter().any(|v| !v.is_finite()) {
600 return Err(
601 "reduced logslope reparam: effective Schur Gram produced non-finite entries"
602 .to_string(),
603 );
604 }
605
606 let (evals, evecs) = stt
607 .eigh(Side::Lower)
608 .map_err(|e| format!("reduced logslope reparam: eigendecomposition failed: {e:?}"))?;
609 let tol = energy_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL;
613 let mut kept: Vec<usize> = (0..evals.len()).filter(|&i| evals[i] > tol).collect();
614 kept.sort_by(|&a, &b| {
615 evals[b]
616 .partial_cmp(&evals[a])
617 .unwrap_or(std::cmp::Ordering::Equal)
618 });
619 let r = kept.len();
620 if r == p_g || r == 0 {
624 return Ok(None);
625 }
626 let mut transform = Array2::<f64>::zeros((p_g, r));
627 for (out_col, &src) in kept.iter().enumerate() {
628 transform.column_mut(out_col).assign(&evecs.column(src));
629 }
630 if transform.iter().any(|v| !v.is_finite()) {
631 return Err(
632 "reduced logslope reparam: reduced transform produced non-finite entries".to_string(),
633 );
634 }
635 Ok(Some(transform))
636}
637
638fn reparameterize_logslope_design_reduced(
645 logslope_design: &TermCollectionDesign,
646 reparam: &ReducedLogslopeReparam,
647) -> Result<TermCollectionDesign, String> {
648 let g = logslope_design
649 .design
650 .try_to_dense_arc("reparameterize_logslope_design_reduced::logslope")?;
651 let p_g = g.ncols();
652 if p_g != reparam.original_cols() {
653 return Err(format!(
654 "reduced logslope reparam width mismatch: design has {p_g} cols, transform expects {}",
655 reparam.original_cols()
656 ));
657 }
658 let t = &reparam.transform;
659 let r = reparam.reduced_cols();
660 let g_reduced = fast_ab(&g, t);
662
663 let mut new_penalties: Vec<gam_terms::smooth::BlockwisePenalty> =
666 Vec::with_capacity(logslope_design.penalties.len());
667 let mut new_nullspace_dims: Vec<usize> = Vec::with_capacity(logslope_design.penalties.len());
668 for bp in &logslope_design.penalties {
669 let mut full = Array2::<f64>::zeros((p_g, p_g));
670 full.slice_mut(s![bp.col_range.clone(), bp.col_range.clone()])
671 .assign(&bp.local);
672 let st = fast_ab(&full, t); let mut s_reduced = fast_atb(t, &st); s_reduced = (&s_reduced + &s_reduced.t()) * 0.5;
676 let (evals, _) = s_reduced
678 .eigh(Side::Lower)
679 .map_err(|e| format!("reduced logslope penalty eigendecomposition failed: {e:?}"))?;
680 let max_eval = evals.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
681 let pen_tol = (max_eval * 1.0e-12).max(f64::EPSILON);
682 let rank = evals.iter().filter(|&&v| v.abs() > pen_tol).count();
683 let nullspace_dim = r.saturating_sub(rank);
684 new_penalties.push(gam_terms::smooth::BlockwisePenalty::new(0..r, s_reduced));
685 new_nullspace_dims.push(nullspace_dim);
686 }
687
688 let new_design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(g_reduced));
689 Ok(TermCollectionDesign {
695 design: new_design,
696 penalties: new_penalties,
697 nullspace_dims: new_nullspace_dims,
698 penaltyinfo: Vec::new(),
699 dropped_penaltyinfo: Vec::new(),
700 coefficient_lower_bounds: None,
701 linear_constraints: None,
702 intercept_range: 0..0,
703 linear_ranges: Vec::new(),
704 random_effect_ranges: Vec::new(),
705 random_effect_levels: Vec::new(),
706 smooth: gam_terms::smooth::SmoothDesign {
707 term_designs: Vec::new(),
708 penalties: Vec::new(),
709 nullspace_dims: Vec::new(),
710 penaltyinfo: Vec::new(),
711 dropped_penaltyinfo: Vec::new(),
712 terms: Vec::new(),
713 coefficient_lower_bounds: None,
714 linear_constraints: None,
715 },
716 })
717}
718
719pub(crate) fn marginal_penalties_with_influence_ridge(
740 design: &TermCollectionDesign,
741 rho_marginal: &Array1<f64>,
742 influence_columns: Option<&Array2<f64>>,
743 influence_ridge_log_lambda: f64,
744) -> Result<(Vec<PenaltyMatrix>, Vec<usize>, Array1<f64>), String> {
745 let p_m = design.design.ncols();
746 let p1 = influence_columns.map(|z| z.ncols()).unwrap_or(0);
747 let total_dim = p_m + p1;
748 let mut penalties: Vec<PenaltyMatrix> = design
751 .penalties
752 .iter()
753 .map(|bp| bp.to_penalty_matrix(total_dim))
754 .collect();
755 let mut nullspace_dims = design.nullspace_dims.clone();
756 let mut log_lambdas = rho_marginal.to_vec();
757
758 if p1 > 0 {
762 penalties.push(
763 PenaltyMatrix::Blockwise {
764 local: Array2::<f64>::eye(p1),
765 col_range: p_m..total_dim,
766 total_dim,
767 }
768 .with_fixed_log_lambda(influence_ridge_log_lambda),
769 );
770 nullspace_dims.push(0);
771 log_lambdas.push(influence_ridge_log_lambda);
772 }
773
774 Ok((penalties, nullspace_dims, Array1::from_vec(log_lambdas)))
775}
776
777pub(crate) fn widen_marginal_beta_hint(
780 beta_hint: Option<Array1<f64>>,
781 p_marginal_widened: usize,
782) -> Option<Array1<f64>> {
783 beta_hint.map(|hint| {
784 if hint.len() == p_marginal_widened {
785 hint
786 } else {
787 let mut widened = Array1::<f64>::zeros(p_marginal_widened);
788 let copy = hint.len().min(p_marginal_widened);
789 widened
790 .slice_mut(s![..copy])
791 .assign(&hint.slice(s![..copy]));
792 widened
793 }
794 })
795}
796
797fn marginal_fitted_eta_sup_norm(design: &TermCollectionDesign, masked_beta: &Array1<f64>) -> f64 {
807 let x = &design.design;
808 let n = x.nrows();
809 if n == 0 || x.ncols() == 0 {
810 return 0.0;
811 }
812 let mut sup = 0.0_f64;
813 for row in 0..n {
814 let eta = x.dot_row_view(row, masked_beta.view());
815 if eta.is_finite() {
816 sup = sup.max(eta.abs());
817 }
818 }
819 sup
820}
821
822fn marginal_design_beta(
825 design: &TermCollectionDesign,
826 block_beta: ArrayView1<'_, f64>,
827) -> Array1<f64> {
828 let ncols = design.design.ncols();
829 let mut masked = Array1::<f64>::zeros(ncols);
830 let copy = ncols.min(block_beta.len());
831 masked
832 .slice_mut(s![..copy])
833 .assign(&block_beta.slice(s![..copy]));
834 masked
835}
836
837fn mask_parametric_columns(
843 design: &TermCollectionDesign,
844 spec: &TermCollectionSpec,
845 full: &Array1<f64>,
846) -> Array1<f64> {
847 let ncols = design.design.ncols();
848 let mut masked = Array1::<f64>::zeros(ncols);
849 if design.intercept_range.len() == 1 {
850 let idx = design.intercept_range.start;
851 if idx < ncols {
852 masked[idx] = full[idx];
853 }
854 }
855 for (linear, (_, range)) in spec.linear_terms.iter().zip(design.linear_ranges.iter()) {
856 if linear.double_penalty {
857 continue;
858 }
859 for col in range.clone() {
860 if col < ncols {
861 masked[col] = full[col];
862 }
863 }
864 }
865 masked
866}
867
868pub(crate) fn bernoulli_marginal_slope_runaway_error_from_beta(
879 block_beta: ArrayView1<'_, f64>,
880 design: &TermCollectionDesign,
881 spec: &TermCollectionSpec,
882 inner_converged: bool,
883 eval_label: &str,
884) -> Option<String> {
885 let full_beta = marginal_design_beta(design, block_beta);
886 let parametric_beta = mask_parametric_columns(design, spec, &full_beta);
887
888 let eta_parametric = marginal_fitted_eta_sup_norm(design, ¶metric_beta);
889 let eta_full = marginal_fitted_eta_sup_norm(design, &full_beta);
890
891 let (eta_inf, explanation) = if eta_parametric >= BMS_PROBIT_SEPARATION_ETA_INF {
892 (
893 eta_parametric,
894 "an unpenalized parametric marginal direction has no stable finite probit optimum and its fitted predictor has run to the probit underflow scale",
895 )
896 } else if eta_full >= BMS_PROBIT_SEPARATION_ETA_INF {
897 (
898 eta_full,
899 "a marginal direction is trading off against the logslope surface; this is the under-constrained marginal/logslope coupling that appears when the score is correlated with the shared surface covariates",
900 )
901 } else {
902 return None;
906 };
907
908 let inner_status = if inner_converged {
909 "the inner solve reached a KKT certificate at this separation-scale predictor"
910 } else {
911 "the inner solve failed while already carrying a separation-scale predictor"
912 };
913 let beta_abs = full_beta
915 .iter()
916 .copied()
917 .filter(|v| v.is_finite())
918 .fold(0.0_f64, |acc, v| acc.max(v.abs()));
919
920 Some(format!(
921 "bernoulli marginal-slope probit marginal/logslope runaway detected in block \
922 'marginal_surface' during {eval_label}: the fitted marginal predictor has \
923 |η|∞={eta_inf:.3e} (numerical-degeneracy threshold \
924 {BMS_PROBIT_SEPARATION_ETA_INF:.1}; raw |β|∞={beta_abs:.3e} is reported for \
925 context only and does not gate this diagnostic). The joint design is \
926 identifiable; {explanation}. {inner_status}. The robust Jeffreys curvature \
927 path is already installed for this fit, so this diagnostic means the current \
928 coupled surface still drives the linear predictor to the probit underflow \
929 scale rather than a request for an external bias-reduction prior. Reduce or \
930 reparameterize the coupled marginal/logslope surface, or use a \
931 lower-dimensional logslope interaction. This is not a \
932 Matérn/Duchon polynomial-nullspace or cross-block gauge-priority \
933 failure."
934 ))
935}
936
937pub(crate) fn bernoulli_marginal_slope_runaway_error(
938 warm_start: &CustomFamilyWarmStart,
939 design: &TermCollectionDesign,
940 spec: &TermCollectionSpec,
941 inner_converged: bool,
942 eval_label: &str,
943) -> Option<String> {
944 let block_beta = warm_start.block_beta_view(0)?;
945 bernoulli_marginal_slope_runaway_error_from_beta(
946 block_beta,
947 design,
948 spec,
949 inner_converged,
950 eval_label,
951 )
952}
953
954#[cfg(test)]
955mod runaway_tests {
956 use super::*;
957 use gam_linalg::faer_ndarray::{FaerArrayView, factorize_symmetricwith_fallback, fast_xt_diag_y};
958 use gam_terms::smooth::{LinearCoefficientGeometry, LinearTermSpec};
959
960 pub(crate) fn marginal_logslope_overlap_penalty(
966 marginal_design: &DesignMatrix,
967 logslope_design: &DesignMatrix,
968 z: &Array1<f64>,
969 row_metric: &Array1<f64>,
970 marginal_offset: &Array1<f64>,
971 logslope_offset: &Array1<f64>,
972 marginal_baseline: f64,
973 logslope_baseline: f64,
974 probit_scale: f64,
975 ) -> Result<Option<Array2<f64>>, String> {
976 let marginal =
977 marginal_design.try_to_dense_arc("marginal_logslope_overlap_penalty::marginal")?;
978 let logslope =
979 logslope_design.try_to_dense_arc("marginal_logslope_overlap_penalty::logslope")?;
980 let n = marginal.nrows();
981 if logslope.nrows() != n
982 || z.len() != n
983 || row_metric.len() != n
984 || marginal_offset.len() != n
985 || logslope_offset.len() != n
986 {
987 return Err(format!(
988 "marginal/logslope overlap penalty row mismatch: marginal={}, logslope={}, z={}, row_metric={}, marginal_offset={}, logslope_offset={}",
989 marginal.nrows(),
990 logslope.nrows(),
991 z.len(),
992 row_metric.len(),
993 marginal_offset.len(),
994 logslope_offset.len(),
995 ));
996 }
997 let p_m = marginal.ncols();
998 let p_g = logslope.ncols();
999 if p_m == 0 || p_g == 0 {
1000 return Ok(None);
1001 }
1002 if !marginal_baseline.is_finite()
1003 || !logslope_baseline.is_finite()
1004 || !probit_scale.is_finite()
1005 || probit_scale <= 0.0
1006 || z.iter().any(|v| !v.is_finite())
1007 || row_metric.iter().any(|v| !v.is_finite() || *v < 0.0)
1008 || marginal_offset.iter().any(|v| !v.is_finite())
1009 || logslope_offset.iter().any(|v| !v.is_finite())
1010 {
1011 return Err(
1012 "marginal/logslope overlap penalty requires finite pilot geometry and finite non-negative row metric"
1013 .to_string(),
1014 );
1015 }
1016
1017 let mut marginal_effective = Array2::<f64>::zeros((n, p_m));
1018 let mut effective_logslope = Array2::<f64>::zeros((n, p_g));
1019 for i in 0..n {
1020 let q_i = marginal_offset[i] + marginal_baseline;
1021 let g_i = logslope_offset[i] + logslope_baseline;
1022 let sg = probit_scale * g_i;
1023 let c_i = (1.0 + sg * sg).sqrt();
1024 let logslope_factor =
1025 q_i * probit_scale * probit_scale * g_i / c_i + probit_scale * z[i];
1026 for j in 0..p_m {
1027 marginal_effective[[i, j]] = c_i * marginal[[i, j]];
1028 }
1029 for j in 0..p_g {
1030 effective_logslope[[i, j]] = logslope_factor * logslope[[i, j]];
1031 }
1032 }
1033 if effective_logslope.iter().all(|v| v.abs() <= f64::EPSILON) {
1034 return Ok(None);
1035 }
1036
1037 let mut gram = fast_xt_diag_x(&effective_logslope, row_metric);
1038 let gram_scale = gram.diag().iter().copied().fold(0.0_f64, f64::max);
1039 if !gram_scale.is_finite() || gram_scale <= 0.0 {
1040 return Ok(None);
1041 }
1042 let projection_ridge = (gram_scale * 1.0e-10).max(f64::EPSILON);
1043 for i in 0..p_g {
1044 gram[[i, i]] += projection_ridge;
1045 }
1046 let cross = fast_xt_diag_y(&effective_logslope, row_metric, &marginal_effective);
1047 let gram_view = FaerArrayView::new(&gram);
1048 let factor = factorize_symmetricwith_fallback(gram_view.as_ref(), Side::Lower)
1049 .map_err(|e| format!("marginal/logslope overlap Gram factorization failed: {e}"))?;
1050 let rhsview = FaerArrayView::new(&cross);
1051 let coeffs_mat = factor.solve(rhsview.as_ref());
1052 let coeffs = Array2::from_shape_fn((p_g, p_m), |(i, j)| coeffs_mat[(i, j)]);
1053 let projected_marginal = fast_ab(&effective_logslope, &coeffs);
1054 let mut penalty = fast_xt_diag_y(&marginal_effective, row_metric, &projected_marginal);
1055 penalty = (&penalty + &penalty.t()) * 0.5;
1056 let max_abs = penalty.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
1057 if !max_abs.is_finite() || max_abs <= 1.0e-12 {
1058 return Ok(None);
1059 }
1060 Ok(Some(penalty))
1061 }
1062
1063 #[test]
1071 pub(crate) fn effective_reduction_drops_score_weighted_confound_raw_audit_misses() {
1072 let m = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
1074 let g = Array2::<f64>::from_shape_vec((3, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 9.0]).unwrap();
1075 let z = Array1::from_vec(vec![1.0, 0.5, 1.0 / 3.0]);
1076 let w = Array1::<f64>::ones(3);
1077 let zero = Array1::<f64>::zeros(3);
1078
1079 let reparam = reduced_logslope_transform_effective(
1082 m.view(),
1083 g.view(),
1084 &z,
1085 &w,
1086 &zero,
1087 &zero,
1088 0.0,
1089 0.0,
1090 1.0,
1091 )
1092 .expect("effective reduction must succeed")
1093 .expect("effective audit must reduce the score-weighted confound (raw audit would not)");
1094 assert_eq!(
1095 reparam.ncols(),
1096 1,
1097 "exactly one effective-identifiable logslope direction should survive"
1098 );
1099
1100 let g_eff = {
1104 let mut e = Array2::<f64>::zeros((3, 2));
1105 for i in 0..3 {
1106 for j in 0..2 {
1107 e[[i, j]] = z[i] * g[[i, j]];
1108 }
1109 }
1110 e
1111 };
1112 let img = g_eff.dot(&reparam.column(0));
1113 let mean = img.iter().sum::<f64>() / 3.0;
1114 let var = img.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / 3.0;
1115 assert!(
1116 var > 1.0e-6,
1117 "kept direction must be the identifiable (non-constant) effective column, var={var}"
1118 );
1119 }
1120
1121 #[test]
1127 pub(crate) fn effective_reduction_fully_confounded_single_column_returns_none() {
1128 let m = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
1129 let g = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
1130 let z = Array1::from_vec(vec![1.0, 0.5, 1.0 / 3.0]);
1131 let w = Array1::<f64>::ones(3);
1132 let zero = Array1::<f64>::zeros(3);
1133 let reparam = reduced_logslope_transform_effective(
1134 m.view(),
1135 g.view(),
1136 &z,
1137 &w,
1138 &zero,
1139 &zero,
1140 0.0,
1141 0.0,
1142 1.0,
1143 )
1144 .expect("effective reduction must succeed");
1145 assert!(
1146 reparam.is_none(),
1147 "fully effective-confounded logslope must keep raw design (None), not a 0-width block"
1148 );
1149 }
1150
1151 #[test]
1154 pub(crate) fn effective_reduction_no_confound_returns_none() {
1155 let m = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
1156 let g = Array2::<f64>::from_shape_vec((3, 2), vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();
1158 let z = Array1::from_vec(vec![1.0, 2.0, 3.0]);
1159 let w = Array1::<f64>::ones(3);
1160 let zero = Array1::<f64>::zeros(3);
1161 let reparam = reduced_logslope_transform_effective(
1162 m.view(),
1163 g.view(),
1164 &z,
1165 &w,
1166 &zero,
1167 &zero,
1168 0.0,
1169 0.0,
1170 1.0,
1171 )
1172 .expect("effective reduction must succeed");
1173 assert!(
1174 reparam.is_none(),
1175 "no effective confound ⇒ no reduction (raw design kept unchanged)"
1176 );
1177 }
1178
1179 #[test]
1180 pub(crate) fn spatial_joint_setup_counts_only_learned_penalties_in_rho() {
1181 let data = Array2::<f64>::zeros((3, 1));
1182 let empty_terms = TermCollectionSpec {
1183 linear_terms: Vec::new(),
1184 random_effect_terms: Vec::new(),
1185 smooth_terms: Vec::new(),
1186 };
1187 let setup = joint_setup(
1188 data.view(),
1189 &empty_terms,
1190 &empty_terms,
1191 2,
1192 3,
1193 &[0.4],
1194 &SpatialLengthScaleOptimizationOptions::default(),
1195 );
1196
1197 assert_eq!(
1198 setup.rho_dim(),
1199 6,
1200 "BMS spatial setup rho must contain only learned marginal/logslope/auxiliary penalties; fixed physical ridges are carried by PenaltyMatrix::Fixed"
1201 );
1202 }
1203
1204 #[test]
1205 pub(crate) fn overlap_penalty_targets_score_weighted_logslope_span() {
1206 let marginal = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
1207 Array2::from_shape_vec((4, 1), vec![0.0, 1.0, 2.0, 3.0]).unwrap(),
1208 ));
1209 let logslope = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
1210 Array2::from_shape_vec((4, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap(),
1211 ));
1212 let z = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
1213 let row_metric = Array1::ones(4);
1214 let offsets = Array1::zeros(4);
1215
1216 let penalty = marginal_logslope_overlap_penalty(
1217 &marginal,
1218 &logslope,
1219 &z,
1220 &row_metric,
1221 &offsets,
1222 &offsets,
1223 0.0,
1224 0.0,
1225 1.0,
1226 )
1227 .expect("overlap penalty should build")
1228 .expect("marginal signal lies in the pilot logslope Jacobian span");
1229
1230 assert_eq!(penalty.dim(), (1, 1));
1231 assert!((penalty[[0, 0]] - 14.0).abs() < 1.0e-6);
1232 }
1233
1234 #[test]
1235 pub(crate) fn overlap_penalty_skips_weight_orthogonal_channels() {
1236 let marginal = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
1237 Array2::from_shape_vec((4, 1), vec![-1.0, 1.0, -1.0, 1.0]).unwrap(),
1238 ));
1239 let logslope = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
1240 Array2::from_shape_vec((4, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap(),
1241 ));
1242 let z = Array1::ones(4);
1243 let row_metric = Array1::ones(4);
1244 let offsets = Array1::zeros(4);
1245
1246 let penalty = marginal_logslope_overlap_penalty(
1247 &marginal,
1248 &logslope,
1249 &z,
1250 &row_metric,
1251 &offsets,
1252 &offsets,
1253 0.0,
1254 0.0,
1255 1.0,
1256 )
1257 .expect("overlap penalty should build");
1258
1259 assert!(penalty.is_none());
1260 }
1261
1262 fn dense_marginal_design(
1270 x: Array2<f64>,
1271 intercept_range: std::ops::Range<usize>,
1272 linear_ranges: Vec<(String, std::ops::Range<usize>)>,
1273 ) -> TermCollectionDesign {
1274 TermCollectionDesign {
1275 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(x)),
1276 penalties: Vec::new(),
1277 nullspace_dims: Vec::new(),
1278 penaltyinfo: Vec::new(),
1279 dropped_penaltyinfo: Vec::new(),
1280 coefficient_lower_bounds: None,
1281 linear_constraints: None,
1282 intercept_range,
1283 linear_ranges,
1284 random_effect_ranges: Vec::new(),
1285 random_effect_levels: Vec::new(),
1286 smooth: gam_terms::smooth::SmoothDesign {
1287 term_designs: Vec::new(),
1288 penalties: Vec::new(),
1289 nullspace_dims: Vec::new(),
1290 penaltyinfo: Vec::new(),
1291 dropped_penaltyinfo: Vec::new(),
1292 terms: Vec::new(),
1293 coefficient_lower_bounds: None,
1294 linear_constraints: None,
1295 },
1296 }
1297 }
1298
1299 fn linear_term(name: &str, feature_col: usize) -> LinearTermSpec {
1300 LinearTermSpec {
1301 name: name.to_string(),
1302 feature_col,
1303 feature_cols: vec![feature_col],
1304 categorical_levels: vec![],
1305 double_penalty: false,
1306 coefficient_geometry: LinearCoefficientGeometry::default(),
1307 coefficient_min: None,
1308 coefficient_max: None,
1309 }
1310 }
1311
1312 fn empty_spec() -> TermCollectionSpec {
1313 TermCollectionSpec {
1314 linear_terms: Vec::new(),
1315 random_effect_terms: Vec::new(),
1316 smooth_terms: Vec::new(),
1317 }
1318 }
1319
1320 #[test]
1327 pub(crate) fn runaway_guard_silent_when_huge_beta_cancels_to_bounded_eta() {
1328 let x = Array2::<f64>::from_shape_vec((4, 2), vec![1.0; 8]).unwrap();
1330 let design = dense_marginal_design(x, 0..0, Vec::new());
1331 let beta = Array1::from_vec(vec![60.0, -60.0]);
1332
1333 let msg = bernoulli_marginal_slope_runaway_error_from_beta(
1334 beta.view(),
1335 &design,
1336 &empty_spec(),
1337 true,
1338 "regression-fixture",
1339 );
1340 assert!(
1341 msg.is_none(),
1342 "huge cancelling β with bounded fitted η must NOT trip the runaway guard; got {msg:?}"
1343 );
1344 }
1345
1346 #[test]
1350 pub(crate) fn runaway_guard_fires_when_fitted_eta_exceeds_threshold() {
1351 let x = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
1352 let design = dense_marginal_design(x, 0..0, Vec::new());
1353 let beta = Array1::from_vec(vec![40.0]);
1354
1355 let msg = bernoulli_marginal_slope_runaway_error_from_beta(
1356 beta.view(),
1357 &design,
1358 &empty_spec(),
1359 true,
1360 "separation-fixture",
1361 )
1362 .expect("fitted |η|∞=40 ≥ 35 must trip the runaway guard");
1363
1364 assert!(msg.contains("marginal/logslope runaway"));
1365 assert!(msg.contains("|η|∞"));
1366 assert!(msg.contains("4.000e1"));
1367 assert!(msg.contains("score is correlated with the shared surface covariates"));
1368 assert!(msg.contains("not a Matérn/Duchon polynomial-nullspace"));
1369 assert!(msg.contains("KKT certificate"));
1370 }
1371
1372 #[test]
1376 pub(crate) fn runaway_guard_names_unpenalized_parametric_direction_via_fitted_eta() {
1377 let x = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
1378 let design = dense_marginal_design(x, 0..0, vec![("sex".to_string(), 0..1)]);
1379 let mut spec = empty_spec();
1380 spec.linear_terms.push(linear_term("sex", 0));
1381 let beta = Array1::from_vec(vec![41.0]);
1382
1383 let msg = bernoulli_marginal_slope_runaway_error_from_beta(
1384 beta.view(),
1385 &design,
1386 &spec,
1387 true,
1388 "parametric-fixture",
1389 )
1390 .expect("parametric fitted |η|∞=41 ≥ 35 must trip the runaway guard");
1391
1392 assert!(msg.contains("unpenalized parametric marginal direction"));
1393 assert!(msg.contains("|η|∞"));
1394 assert!(msg.contains("robust Jeffreys curvature path is already installed"));
1395 assert!(msg.contains("not a Matérn/Duchon polynomial-nullspace"));
1396 }
1397
1398 #[test]
1402 pub(crate) fn runaway_guard_silent_for_nonconverged_but_bounded_eta() {
1403 let x = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
1404 let design = dense_marginal_design(x, 0..0, Vec::new());
1405 let beta = Array1::from_vec(vec![5.0]);
1406
1407 let msg = bernoulli_marginal_slope_runaway_error_from_beta(
1408 beta.view(),
1409 &design,
1410 &empty_spec(),
1411 false,
1412 "nonconverged-fixture",
1413 );
1414 assert!(
1415 msg.is_none(),
1416 "bounded fitted η must not raise the separation error even when the inner solve did not converge; got {msg:?}"
1417 );
1418 }
1419
1420 #[test]
1423 pub(crate) fn runaway_guard_fires_for_nonconverged_separating_eta() {
1424 let x = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
1425 let design = dense_marginal_design(x, 0..0, Vec::new());
1426 let beta = Array1::from_vec(vec![50.0]);
1427
1428 let msg = bernoulli_marginal_slope_runaway_error_from_beta(
1429 beta.view(),
1430 &design,
1431 &empty_spec(),
1432 false,
1433 "nonconverged-separating-fixture",
1434 )
1435 .expect("separating |η|∞ at non-convergence must still trip the guard");
1436
1437 assert!(msg.contains(
1438 "the inner solve failed while already carrying a separation-scale predictor"
1439 ));
1440 }
1441}
1442
1443pub(crate) fn build_marginal_blockspec_bms(
1444 design: &TermCollectionDesign,
1445 baseline: f64,
1446 offset: &Array1<f64>,
1447 rho: Array1<f64>,
1448 beta_hint: Option<Array1<f64>>,
1449 logslope_design: &TermCollectionDesign,
1450 logslope_offset: &Array1<f64>,
1451 logslope_baseline: f64,
1452 p_marginal: usize,
1453 influence_columns: Option<&Array2<f64>>,
1454 influence_ridge_log_lambda: f64,
1455) -> Result<ParameterBlockSpec, String> {
1456 let offset_m = offset + baseline;
1457 let offset_s = logslope_offset + logslope_baseline;
1458 let raw_marginal_dense = design
1459 .design
1460 .try_to_dense_arc("build_marginal_blockspec_bms::marginal")?;
1461 let marginal_dense =
1462 widen_marginal_dense_with_influence(&raw_marginal_dense, influence_columns)?;
1463 let logslope_dense = logslope_design
1464 .design
1465 .try_to_dense_arc("build_marginal_blockspec_bms::logslope")?;
1466 let callback: Arc<dyn BlockEffectiveJacobian> = Arc::new(BmsMarginalJacobian {
1467 marginal_dense: Arc::clone(&marginal_dense),
1468 logslope_dense,
1469 offset_m: offset_m.clone(),
1470 offset_s,
1471 p_marginal,
1472 });
1473 let (penalties, nullspace_dims, initial_log_lambdas) = marginal_penalties_with_influence_ridge(
1474 design,
1475 &rho,
1476 influence_columns,
1477 influence_ridge_log_lambda,
1478 )?;
1479 Ok(ParameterBlockSpec {
1480 name: "marginal_surface".to_string(),
1481 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
1482 (*marginal_dense).clone(),
1483 )),
1484 offset: offset_m,
1485 penalties,
1486 nullspace_dims,
1487 initial_log_lambdas,
1488 initial_beta: widen_marginal_beta_hint(beta_hint, p_marginal),
1489 gauge_priority: GAUGE_PRIORITY_MARGINAL,
1502 jacobian_callback: Some(callback),
1503 stacked_design: None,
1504 stacked_offset: None,
1505 })
1506}
1507
1508pub(crate) fn build_logslope_blockspec_bms(
1509 design: &TermCollectionDesign,
1510 baseline: f64,
1511 offset: &Array1<f64>,
1512 rho: Array1<f64>,
1513 beta_hint: Option<Array1<f64>>,
1514 marginal_design: &TermCollectionDesign,
1515 marginal_offset: &Array1<f64>,
1516 marginal_baseline: f64,
1517 z: Arc<Array1<f64>>,
1518 p_marginal: usize,
1519 influence_columns: Option<&Array2<f64>>,
1520) -> Result<ParameterBlockSpec, String> {
1521 let offset_s = offset + baseline;
1522 let offset_m = marginal_offset + marginal_baseline;
1523 let raw_marginal_dense = marginal_design
1524 .design
1525 .try_to_dense_arc("build_logslope_blockspec_bms::marginal")?;
1526 let marginal_dense =
1531 widen_marginal_dense_with_influence(&raw_marginal_dense, influence_columns)?;
1532 let logslope_dense = design
1533 .design
1534 .try_to_dense_arc("build_logslope_blockspec_bms::logslope")?;
1535 let callback: Arc<dyn BlockEffectiveJacobian> = Arc::new(BmsLogslopeJacobian {
1536 marginal_dense,
1537 logslope_dense: Arc::clone(&logslope_dense),
1538 offset_m,
1539 offset_s: offset_s.clone(),
1540 z,
1541 p_marginal,
1542 });
1543 Ok(ParameterBlockSpec {
1544 name: "logslope_surface".to_string(),
1545 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
1546 (*logslope_dense).clone(),
1547 )),
1548 offset: offset_s,
1549 penalties: design.penalties_as_penalty_matrix(),
1550 nullspace_dims: design.nullspace_dims.clone(),
1551 initial_log_lambdas: rho,
1552 initial_beta: beta_hint,
1553 gauge_priority: GAUGE_PRIORITY_LOGSLOPE,
1561 jacobian_callback: Some(callback),
1562 stacked_design: None,
1563 stacked_offset: None,
1564 })
1565}
1566
1567pub(crate) fn build_deviation_aux_blockspec(
1568 name: &str,
1569 prepared: &DeviationPrepared,
1570 rho: Array1<f64>,
1571 beta_hint: Option<Array1<f64>>,
1572) -> Result<ParameterBlockSpec, String> {
1573 let mut block = prepared.block.clone();
1574 block.initial_log_lambdas = Some(rho);
1575 let candidate_beta = beta_hint.or_else(|| Some(Array1::<f64>::zeros(block.design.ncols())));
1576 block.initial_beta = candidate_beta
1577 .map(|beta| {
1578 let zero = Array1::<f64>::zeros(beta.len());
1579 project_monotone_feasible_beta(&prepared.runtime, &zero, &beta, name)
1580 })
1581 .transpose()?;
1582 let mut spec = block.intospec(name)?;
1583 spec.gauge_priority = match name {
1593 "link_dev" => GAUGE_PRIORITY_LINK_DEV,
1594 "score_warp_dev" => GAUGE_PRIORITY_SCORE_WARP_DEV,
1601 _ => GAUGE_PRIORITY_DEVIATION_DEFAULT,
1602 };
1603 Ok(spec)
1604}
1605
1606pub(crate) fn push_deviation_aux_blockspecs(
1607 blocks: &mut Vec<ParameterBlockSpec>,
1608 rho: &Array1<f64>,
1609 cursor: &mut usize,
1610 score_warp_prepared: Option<&DeviationPrepared>,
1611 link_dev_prepared: Option<&DeviationPrepared>,
1612 score_warp_beta_hint: Option<Array1<f64>>,
1613 link_dev_beta_hint: Option<Array1<f64>>,
1614) -> Result<(), String> {
1615 if let Some(prepared) = score_warp_prepared {
1616 let rho_h = rho
1617 .slice(s![*cursor..*cursor + prepared.block.penalties.len()])
1618 .to_owned();
1619 *cursor += prepared.block.penalties.len();
1620 blocks.push(build_deviation_aux_blockspec(
1621 "score_warp_dev",
1622 prepared,
1623 rho_h,
1624 score_warp_beta_hint,
1625 )?);
1626 }
1627 if let Some(prepared) = link_dev_prepared {
1628 let rho_w = rho
1629 .slice(s![*cursor..*cursor + prepared.block.penalties.len()])
1630 .to_owned();
1631 blocks.push(build_deviation_aux_blockspec(
1632 "link_dev",
1633 prepared,
1634 rho_w,
1635 link_dev_beta_hint,
1636 )?);
1637 }
1638 Ok(())
1639}
1640
1641fn inner_fit(
1642 family: &BernoulliMarginalSlopeFamily,
1643 blocks: &[ParameterBlockSpec],
1644 options: &BlockwiseFitOptions,
1645) -> Result<UnifiedFitResult, String> {
1646 let mut options = options.clone();
1647 options.use_outer_hessian = false;
1652 options.outer_tol = options.outer_tol.max(2.0e-5);
1653 fit_custom_family(family, blocks, &options).map_err(|e| e.to_string())
1654}
1655
1656pub fn fit_bernoulli_marginal_slope_terms(
1657 data: ArrayView2<'_, f64>,
1658 spec: BernoulliMarginalSlopeTermSpec,
1659 options: &BlockwiseFitOptions,
1660 kappa_options: &SpatialLengthScaleOptimizationOptions,
1661 policy: &gam_runtime::resource::ResourcePolicy,
1662) -> Result<BernoulliMarginalSlopeFitResult, String> {
1663 let mut spec = spec;
1664 let data_view = data;
1665 validate_spec(data_view, &spec)?;
1666 let mjs_frozen_marginal =
1675 gam_terms::smooth::freeze_measure_jet_length_scale_learning(&mut spec.marginalspec);
1676 let mjs_frozen_logslope =
1677 gam_terms::smooth::freeze_measure_jet_length_scale_learning(&mut spec.logslopespec);
1678 if mjs_frozen_marginal + mjs_frozen_logslope > 0 {
1679 log::info!(
1680 "[BMS spatial] froze measure-jet length-scale learning on {} marginal + {} log-slope \
1681 term(s): the coupled surface keeps ℓ at its conditioned auto value (#1116)",
1682 mjs_frozen_marginal,
1683 mjs_frozen_logslope
1684 );
1685 }
1686 let mut effective_kappa_options = kappa_options.clone();
1687 let kappa_locked_marginal = gam_terms::smooth::all_spatial_terms_kappa_fixed(&spec.marginalspec);
1697 let kappa_locked_logslope = gam_terms::smooth::all_spatial_terms_kappa_fixed(&spec.logslopespec);
1698 if effective_kappa_options.enabled && kappa_locked_marginal && kappa_locked_logslope {
1699 log::info!(
1700 "[BMS spatial] disabling κ/ψ optimization: every spatial term has an \
1701 explicit length_scale and no anisotropy; user-supplied kernel scale is fixed"
1702 );
1703 effective_kappa_options.enabled = false;
1704 }
1705 let flex_spatial_pilot_path = (spec.score_warp.is_some() || spec.link_dev.is_some())
1706 && spec.y.len() >= BMS_FLEX_SPATIAL_OUTER_PILOT_ROW_THRESHOLD
1707 && effective_kappa_options.enabled;
1708 if flex_spatial_pilot_path {
1709 let marginal_terms = spatial_length_scale_term_indices(&spec.marginalspec);
1710 let logslope_terms = spatial_length_scale_term_indices(&spec.logslopespec);
1711 let marginal_updates = apply_spatial_anisotropy_pilot_initializer(
1712 data_view,
1713 &mut spec.marginalspec,
1714 &marginal_terms,
1715 effective_kappa_options.pilot_subsample_threshold,
1716 &effective_kappa_options,
1717 );
1718 let logslope_updates = apply_spatial_anisotropy_pilot_initializer(
1719 data_view,
1720 &mut spec.logslopespec,
1721 &logslope_terms,
1722 effective_kappa_options.pilot_subsample_threshold,
1723 &effective_kappa_options,
1724 );
1725 effective_kappa_options.enabled = false;
1726 log::info!(
1727 "[BMS spatial] n={} flex=true pilot_geometry_updates={} iterative_spatial_outer=false reason=large-flex-spatial-pilot",
1728 spec.y.len(),
1729 marginal_updates + logslope_updates,
1730 );
1731 }
1732 let (z_standardized, z_normalization) = standardize_latent_z_with_policy(
1733 &spec.z,
1734 &spec.weights,
1735 "bernoulli-marginal-slope",
1736 &spec.latent_z_policy,
1737 )?;
1738 spec.z = z_standardized;
1739 let sigma_learnable = matches!(
1740 &spec.frailty,
1741 FrailtySpec::GaussianShift { sigma_fixed: None }
1742 );
1743 let initial_sigma = match &spec.frailty {
1744 FrailtySpec::GaussianShift {
1745 sigma_fixed: Some(s),
1746 } => Some(*s),
1747 FrailtySpec::GaussianShift { sigma_fixed: None } => Some(0.5),
1748 FrailtySpec::None => None,
1749 FrailtySpec::HazardMultiplier { .. } => {
1750 return Err(
1751 "internal: validate_spec should have rejected unsupported marginal-slope frailty"
1752 .to_string(),
1753 );
1754 }
1755 };
1756 let probit_scale = probit_frailty_scale(initial_sigma);
1757 let (_raw_joint_designs, mut joint_specs) = build_term_collection_designs_and_freeze_joint(
1758 data_view,
1759 &[spec.marginalspec.clone(), spec.logslopespec.clone()],
1760 )
1761 .map_err(|e| e.to_string())?;
1762 let marginalspec_boot = joint_specs.remove(0);
1763 let logslopespec_boot = joint_specs.remove(0);
1764 let (mut joint_designs, _) = build_term_collection_designs_and_freeze_joint(
1781 data_view,
1782 &[marginalspec_boot.clone(), logslopespec_boot.clone()],
1783 )
1784 .map_err(|e| format!("failed to rebuild frozen probe BMS joint designs: {e}"))?;
1785 let marginal_design = joint_designs.remove(0);
1786 let logslope_design = joint_designs.remove(0);
1787 let absorber_active = spec
1794 .score_influence_jacobian
1795 .as_ref()
1796 .is_some_and(|j| j.ncols() > 0);
1797 let conditioning_dense = if absorber_active {
1798 None
1799 } else {
1800 Some(
1801 marginal_design
1802 .design
1803 .try_to_dense_arc("bernoulli marginal-slope conditional latent-z gate")?,
1804 )
1805 };
1806 let (latent_measure, latent_z_calibration) = build_latent_measure_with_geometry(
1807 &spec.z,
1808 &spec.weights,
1809 &spec.latent_z_policy,
1810 conditioning_dense.as_ref().map(|d| d.view()),
1811 )?;
1812 if latent_measure.is_empirical() && sigma_learnable {
1813 return Err("empirical latent-measure marginal-slope calibration requires fixed GaussianShift sigma; learnable sigma derivatives must be fit under the standard-normal latent measure"
1814 .to_string());
1815 }
1816
1817 let y = Arc::new(spec.y.clone());
1818 let weights = Arc::new(spec.weights.clone());
1819 let z = match &latent_z_calibration {
1824 LatentMeasureCalibration::None => Arc::new(spec.z.clone()),
1825 LatentMeasureCalibration::RankInverseNormal(cal) => {
1826 Arc::new(cal.apply_to_training(&spec.z)?)
1827 }
1828 LatentMeasureCalibration::ConditionalLocationScale(cal) => {
1829 let a_block = conditioning_dense.as_ref().ok_or_else(|| {
1832 "conditional latent calibration requires the marginal conditioning block"
1833 .to_string()
1834 })?;
1835 Arc::new(cal.apply(spec.z.view(), a_block.view())?)
1836 }
1837 };
1838 let z_train = z.as_ref();
1839 let pilot_baseline = pooled_probit_baseline(&spec.y, z_train, &spec.weights)?;
1840 let baseline = (
1841 bernoulli_marginal_slope_eta_from_probability(
1842 &spec.base_link,
1843 normal_cdf(pilot_baseline.0),
1844 "bernoulli marginal-slope baseline link inversion",
1845 )?,
1846 pilot_baseline.1 / probit_scale,
1847 );
1848
1849 let rigid_pilot_eta = rigid_pooled_probit_pilot_eta(
1892 &spec.base_link,
1893 z_train,
1894 &spec.marginal_offset,
1895 &spec.logslope_offset,
1896 baseline.0,
1897 baseline.1,
1898 probit_scale,
1899 )?;
1900 let cross_block_pilot_w_score_warp =
1901 pilot_irls_hessian_row_metric_at_eta(&rigid_pilot_eta, &spec.weights);
1902
1903 let influence_columns = if let Some(jac) = spec
1914 .score_influence_jacobian
1915 .as_ref()
1916 .filter(|j| j.ncols() > 0)
1917 {
1918 let marginal_dense_for_proj = marginal_design
1919 .design
1920 .try_to_dense_arc("bernoulli marginal-slope influence-block marginal projection")?;
1921 let marginal_dense = marginal_dense_for_proj.as_ref();
1922 if jac.nrows() != marginal_dense.nrows() {
1923 return Err(format!(
1924 "influence block: Jacobian has {} rows, marginal design has {}",
1925 jac.nrows(),
1926 marginal_dense.nrows()
1927 ));
1928 }
1929 let rigid_logslope_at_rows = &spec.logslope_offset + baseline.1;
1936 let residualized =
1937 crate::marginal_slope_orthogonal::residualized_influence_block(
1938 jac,
1939 z_train,
1940 &rigid_logslope_at_rows,
1941 probit_scale,
1942 marginal_dense.view(),
1943 &cross_block_pilot_w_score_warp,
1944 )?;
1945 Some(residualized)
1946 } else {
1947 None
1948 };
1949 let mut cross_block_warnings: Vec<CrossBlockIdentifiabilityWarning> = Vec::new();
1950 let score_warp_prepared = if let Some(cfg) = spec.score_warp.as_ref() {
1951 use super::deviation_runtime::ParametricAnchorBlock;
1952 let mut prepared = build_score_warp_deviation_block_from_seed(z_train, cfg)?;
1953 let outcome = install_compiled_flex_block_into_runtime(
1958 &mut prepared,
1959 z_train,
1960 cfg,
1961 &[
1962 (&marginal_design.design, ParametricAnchorBlock::Marginal),
1963 (&logslope_design.design, ParametricAnchorBlock::Logslope),
1964 ],
1965 &[],
1966 &cross_block_pilot_w_score_warp,
1967 )?;
1968 match outcome {
1969 FlexCompileOutcome::Reparameterised => Some(prepared),
1970 FlexCompileOutcome::FullyAliased { reason } => {
1971 cross_block_warnings.push(CrossBlockIdentifiabilityWarning {
1977 candidate_label: "score_warp",
1978 anchor_summary: "marginal+logslope".to_string(),
1979 reason,
1980 });
1981 Some(prepared)
1982 }
1983 }
1984 } else {
1985 None
1986 };
1987 let link_dev_prepared = if let Some(cfg) = spec.link_dev.as_ref() {
2013 let eta_pilot = pilot_eta_for_link_dev_orthogonalisation(
2014 &spec.base_link,
2015 &spec.y,
2016 z_train,
2017 &spec.weights,
2018 &marginal_design.design,
2019 &spec.marginal_offset,
2020 &spec.logslope_offset,
2021 baseline.0,
2022 baseline.1,
2023 probit_scale,
2024 )?;
2025 let link_dev_seed = padded_deviation_seed(&eta_pilot, 1.0, 0.5);
2026 let mut prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
2027 &link_dev_seed,
2028 &eta_pilot,
2029 cfg,
2030 )?;
2031 let score_warp_anchor_design = score_warp_prepared
2068 .as_ref()
2069 .map(|sw| sw.runtime.design_at_training_with_residual(z_train))
2070 .transpose()?;
2071 use super::deviation_runtime::ParametricAnchorBlock;
2072 let parametric_anchors: [(&DesignMatrix, ParametricAnchorBlock); 2] = [
2073 (&marginal_design.design, ParametricAnchorBlock::Marginal),
2074 (&logslope_design.design, ParametricAnchorBlock::Logslope),
2075 ];
2076 let flex_anchor_slot: Option<&Array2<f64>> = score_warp_anchor_design.as_ref();
2077 let flex_anchors: Vec<&Array2<f64>> = flex_anchor_slot.into_iter().collect();
2078 let cross_block_pilot_w_link_dev =
2083 pilot_irls_hessian_row_metric_at_eta(&eta_pilot, &spec.weights);
2084 let outcome = install_compiled_flex_block_into_runtime(
2085 &mut prepared,
2086 &eta_pilot,
2087 cfg,
2088 ¶metric_anchors,
2089 &flex_anchors,
2090 &cross_block_pilot_w_link_dev,
2091 )?;
2092 match outcome {
2093 FlexCompileOutcome::Reparameterised => Some(prepared),
2094 FlexCompileOutcome::FullyAliased { reason } => {
2095 cross_block_warnings.push(CrossBlockIdentifiabilityWarning {
2101 candidate_label: "link_deviation",
2102 anchor_summary: "marginal+logslope+score_warp".to_string(),
2103 reason,
2104 });
2105 Some(prepared)
2106 }
2107 }
2108 } else {
2109 None
2110 };
2111 let extra_rho0 = {
2112 let mut out = Vec::new();
2113 if let Some(ref prepared) = score_warp_prepared {
2114 out.extend(std::iter::repeat_n(0.0, prepared.block.penalties.len()));
2115 }
2116 if let Some(ref prepared) = link_dev_prepared {
2117 out.extend(std::iter::repeat_n(0.0, prepared.block.penalties.len()));
2118 }
2119 out
2120 };
2121 let logslope_reduced_reparam: Option<ReducedLogslopeReparam> = build_reduced_logslope_reparam(
2134 &marginal_design,
2135 &logslope_design,
2136 z.as_ref(),
2137 &cross_block_pilot_w_score_warp,
2138 &spec.marginal_offset,
2139 &spec.logslope_offset,
2140 baseline.0,
2141 baseline.1,
2142 probit_scale,
2143 )?;
2144 let reduce_logslope_design =
2150 |logslope_design: &TermCollectionDesign| -> Result<TermCollectionDesign, String> {
2151 match logslope_reduced_reparam.as_ref() {
2152 Some(reparam) => reparameterize_logslope_design_reduced(logslope_design, reparam),
2153 None => Ok(logslope_design.clone()),
2154 }
2155 };
2156
2157 let marginal_penalty_count = marginal_design.penalties.len();
2158 let setup = joint_setup(
2159 data_view,
2160 &marginalspec_boot,
2161 &logslopespec_boot,
2162 marginal_penalty_count,
2163 logslope_design.penalties.len(),
2164 &extra_rho0,
2165 &effective_kappa_options,
2166 );
2167 let setup = if sigma_learnable {
2168 setup.with_auxiliary(
2169 Array1::from_vec(vec![initial_sigma.expect("learnable sigma seed").ln()]),
2170 Array1::from_vec(vec![0.01_f64.ln()]),
2171 Array1::from_vec(vec![5.0_f64.ln()]),
2172 )
2173 } else {
2174 setup
2175 };
2176 let final_sigma_cell = std::cell::Cell::new(initial_sigma);
2177 let exact_warm_start = RefCell::new(None::<CustomFamilyWarmStart>);
2178 let runaway_error = RefCell::new(None::<String>);
2179 let pending_beta_seed = RefCell::new(None::<Array1<f64>>);
2186 let hints = RefCell::new(ThetaHints::default());
2187 let score_warp_runtime = score_warp_prepared.as_ref().map(|p| p.runtime.clone());
2188 let link_dev_runtime = link_dev_prepared.as_ref().map(|p| p.runtime.clone());
2189
2190 let build_blocks = |rho: &Array1<f64>,
2191 marginal_design: &TermCollectionDesign,
2192 logslope_design: &TermCollectionDesign|
2193 -> Result<Vec<ParameterBlockSpec>, String> {
2194 let hints = hints.borrow();
2195 let mut cursor = 0usize;
2196 let logslope_design_reduced = reduce_logslope_design(logslope_design)?;
2203 let logslope_design = &logslope_design_reduced;
2204 let rho_marginal = rho
2209 .slice(s![cursor..cursor + marginal_design.penalties.len()])
2210 .to_owned();
2211 cursor += marginal_design.penalties.len();
2212 let rho_logslope = rho
2213 .slice(s![cursor..cursor + logslope_design.penalties.len()])
2214 .to_owned();
2215 cursor += logslope_design.penalties.len();
2216 let p_m = marginal_design.design.ncols()
2217 + influence_columns.as_ref().map(|z| z.ncols()).unwrap_or(0);
2218 let mut blocks = vec![
2219 build_marginal_blockspec_bms(
2220 marginal_design,
2221 baseline.0,
2222 &spec.marginal_offset,
2223 rho_marginal,
2224 hints.marginal_beta.clone(),
2225 logslope_design,
2226 &spec.logslope_offset,
2227 baseline.1,
2228 p_m,
2229 influence_columns.as_ref(),
2230 INFLUENCE_ABSORBER_FIXED_LOG_LAMBDA,
2231 )?,
2232 build_logslope_blockspec_bms(
2233 logslope_design,
2234 baseline.1,
2235 &spec.logslope_offset,
2236 rho_logslope,
2237 hints.logslope_beta.clone(),
2238 marginal_design,
2239 &spec.marginal_offset,
2240 baseline.0,
2241 Arc::clone(&z),
2242 p_m,
2243 influence_columns.as_ref(),
2244 )?,
2245 ];
2246 push_deviation_aux_blockspecs(
2247 &mut blocks,
2248 rho,
2249 &mut cursor,
2250 score_warp_prepared.as_ref(),
2251 link_dev_prepared.as_ref(),
2252 hints.score_warp_beta.clone(),
2253 hints.link_dev_beta.clone(),
2254 )?;
2255 Ok(blocks)
2256 };
2257
2258 let intercept_warm_starts = new_intercept_warm_start_cache(y.len());
2259 let cell_moment_lru = new_cell_moment_lru_cache(policy);
2260 let cell_moment_cache_stats = new_cell_moment_cache_stats();
2261 let make_family = |marginal_design: &TermCollectionDesign,
2262 logslope_design: &TermCollectionDesign,
2263 sigma: Option<f64>|
2264 -> BernoulliMarginalSlopeFamily {
2265 let kernel_marginal_design = match influence_columns.as_ref() {
2271 Some(z_infl) => {
2272 let raw = marginal_design
2273 .design
2274 .try_to_dense_arc("make_family::widened-marginal")
2275 .expect("dense marginal design for influence widening");
2276 let widened = widen_marginal_dense_with_influence(&raw, Some(z_infl))
2277 .expect("widen marginal design with influence columns");
2278 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from((*widened).clone()))
2279 }
2280 None => marginal_design.design.clone(),
2281 };
2282 let kernel_logslope_design = reduce_logslope_design(logslope_design)
2288 .expect("reduce logslope design for family construction")
2289 .design;
2290 BernoulliMarginalSlopeFamily {
2291 y: Arc::clone(&y),
2292 weights: Arc::clone(&weights),
2293 z: Arc::clone(&z),
2294 latent_measure: latent_measure.clone(),
2295 gaussian_frailty_sd: sigma,
2296 base_link: spec.base_link.clone(),
2297 marginal_design: kernel_marginal_design,
2298 logslope_design: kernel_logslope_design,
2299 score_warp: score_warp_runtime.clone(),
2300 link_dev: link_dev_runtime.clone(),
2301 policy: policy.clone(),
2302 cell_moment_lru: Arc::clone(&cell_moment_lru),
2303 cell_moment_cache_stats: Arc::clone(&cell_moment_cache_stats),
2304 intercept_warm_starts: Some(Arc::clone(&intercept_warm_starts)),
2305 auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
2306 auto_subsample_last_rho: Arc::new(Mutex::new(None)),
2307 }
2308 };
2309
2310 let marginal_terms = spatial_length_scale_term_indices(&marginalspec_boot);
2311 let logslope_terms = spatial_length_scale_term_indices(&logslopespec_boot);
2312 let marginal_has_spatial = !marginal_terms.is_empty();
2313 let logslope_has_spatial = !logslope_terms.is_empty();
2314 let analytic_joint_derivatives_available =
2315 marginal_has_spatial || logslope_has_spatial || setup.log_kappa_dim() == 0;
2316 if setup.log_kappa_dim() > 0 && !analytic_joint_derivatives_available {
2317 return Err("exact bernoulli marginal-slope spatial optimization requires analytic joint psi derivatives"
2318 .to_string());
2319 }
2320 let initial_rho = setup.theta0().slice(s![..setup.rho_dim()]).to_owned();
2321 let initial_blocks = build_blocks(&initial_rho, &marginal_design, &logslope_design)?;
2322 let initial_family = make_family(&marginal_design, &logslope_design, initial_sigma);
2323 let (joint_gradient, joint_hessian) =
2324 custom_family_outer_derivatives(&initial_family, &initial_blocks, options);
2325 let analytic_joint_gradient_available = analytic_joint_derivatives_available
2326 && matches!(joint_gradient, gam_problem::Derivative::Analytic);
2327 let analytic_joint_hessian_available =
2333 analytic_joint_derivatives_available && joint_hessian.is_analytic();
2334 let kappa_options_ref: &SpatialLengthScaleOptimizationOptions = &effective_kappa_options;
2335 let sigma_from_theta = |theta: &Array1<f64>| -> Option<f64> {
2336 if sigma_learnable {
2337 Some(theta[setup.rho_dim() + setup.log_kappa_dim()].exp())
2338 } else {
2339 initial_sigma
2340 }
2341 };
2342 let derivative_block_cache = RefCell::new(
2343 None::<(
2344 Array1<f64>,
2345 Arc<Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>>,
2346 )>,
2347 );
2348 let theta_matches = |left: &Array1<f64>, right: &Array1<f64>| -> bool {
2349 left.len() == right.len()
2350 && left
2351 .iter()
2352 .zip(right.iter())
2353 .all(|(lhs, rhs)| (*lhs - *rhs).abs() <= 1e-12 * (1.0 + lhs.abs().max(rhs.abs())))
2354 };
2355 let get_derivative_blocks = |theta: &Array1<f64>,
2356 specs: &[TermCollectionSpec],
2357 designs: &[TermCollectionDesign]|
2358 -> Result<
2359 Arc<Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>>,
2360 String,
2361 > {
2362 if let Some((cached_theta, cached_blocks)) = derivative_block_cache.borrow().as_ref()
2363 && theta_matches(cached_theta, theta)
2364 {
2365 return Ok(Arc::clone(cached_blocks));
2366 }
2367
2368 let built = |specs: &[TermCollectionSpec],
2369 designs: &[TermCollectionDesign]|
2370 -> Result<
2371 Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>,
2372 String,
2373 > {
2374 let marginal_psi_derivs = if marginal_has_spatial {
2375 build_block_spatial_psi_derivatives(data_view, &specs[0], &designs[0])?.ok_or_else(
2376 || {
2377 "bernoulli marginal-slope: marginal block has spatial terms \
2378 but spatial psi derivatives are unavailable"
2379 .to_string()
2380 },
2381 )?
2382 } else {
2383 Vec::new()
2384 };
2385 let logslope_psi_derivs = if logslope_has_spatial {
2386 build_block_spatial_psi_derivatives(data_view, &specs[1], &designs[1])?.ok_or_else(
2387 || {
2388 "bernoulli marginal-slope: logslope block has spatial terms \
2389 but spatial psi derivatives are unavailable"
2390 .to_string()
2391 },
2392 )?
2393 } else {
2394 Vec::new()
2395 };
2396 let mut derivative_blocks = vec![marginal_psi_derivs, logslope_psi_derivs];
2397 if score_warp_runtime.is_some() {
2398 derivative_blocks.push(Vec::new());
2399 }
2400 if link_dev_runtime.is_some() {
2401 derivative_blocks.push(Vec::new());
2402 }
2403 if sigma_learnable {
2404 derivative_blocks
2405 .last_mut()
2406 .expect("bernoulli derivative block list is non-empty")
2407 .push(crate::custom_family::CustomFamilyBlockPsiDerivative::new(
2408 None,
2409 Array2::zeros((0, 0)),
2410 Array2::zeros((0, 0)),
2411 None,
2412 None,
2413 None,
2414 None,
2415 ));
2416 }
2417 Ok(derivative_blocks)
2418 }(specs, designs)?;
2419 let built = Arc::new(built);
2420 derivative_block_cache.replace(Some((theta.clone(), Arc::clone(&built))));
2421 Ok(built)
2422 };
2423
2424 let outer_policy = {
2429 let psi_dim = setup.theta0().len() - setup.rho_dim();
2430 initial_family.outer_derivative_policy(&initial_blocks, psi_dim, options)
2431 };
2432 let exact_spatial_outer_tol = kappa_options_ref.rel_tol.max(EXACT_SPATIAL_OUTER_TOL_FLOOR);
2433 let solved = optimize_spatial_length_scale_exact_joint(
2434 data_view,
2435 &[marginalspec_boot.clone(), logslopespec_boot.clone()],
2436 &[marginal_terms.clone(), logslope_terms.clone()],
2437 kappa_options_ref,
2438 &setup,
2439 gam_solve::seeding::SeedRiskProfile::GeneralizedLinear,
2440 analytic_joint_gradient_available,
2441 analytic_joint_hessian_available,
2442 true,
2443 None,
2444 outer_policy,
2445 |theta, specs: &[TermCollectionSpec], designs: &[TermCollectionDesign]| {
2446 if let Some(err) = runaway_error.borrow().as_ref().cloned() {
2447 return Err(err);
2448 }
2449 assert_eq!(
2450 specs.len(),
2451 designs.len(),
2452 "spatial joint optimizer must supply one spec per design",
2453 );
2454 let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
2455 let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
2456 let sigma = sigma_from_theta(theta);
2457 final_sigma_cell.set(sigma);
2458 let family = make_family(&designs[0], &designs[1], sigma);
2459 let fit = inner_fit(&family, &blocks, options)?;
2460 if let Some(block) = fit.block_states.first()
2461 && let Some(err) = bernoulli_marginal_slope_runaway_error_from_beta(
2462 block.beta.view(),
2463 &designs[0],
2464 &specs[0],
2465 fit.outer_converged,
2466 "final fit",
2467 )
2468 {
2469 runaway_error.replace(Some(err.clone()));
2470 return Err(err);
2471 }
2472 let mut hints_mut = hints.borrow_mut();
2473 let mut bidx = 0usize;
2474 if let Some(block) = fit.block_states.get(bidx) {
2475 hints_mut.marginal_beta = Some(block.beta.clone());
2476 }
2477 bidx += 1;
2478 if let Some(block) = fit.block_states.get(bidx) {
2479 hints_mut.logslope_beta = Some(block.beta.clone());
2480 }
2481 bidx += 1;
2482 if score_warp_prepared.is_some() {
2483 if let Some(block) = fit.block_states.get(bidx) {
2484 hints_mut.score_warp_beta = Some(block.beta.clone());
2485 }
2486 bidx += 1;
2487 }
2488 if link_dev_prepared.is_some()
2489 && let Some(block) = fit.block_states.get(bidx)
2490 {
2491 hints_mut.link_dev_beta = Some(block.beta.clone());
2492 }
2493 Ok(fit)
2494 },
2495 |theta,
2496 specs: &[TermCollectionSpec],
2497 designs: &[TermCollectionDesign],
2498 eval_mode,
2499 row_set: &crate::row_kernel::RowSet| {
2500 if let Some(err) = runaway_error.borrow().as_ref().cloned() {
2501 return Err(err);
2502 }
2503 use gam_problem::EvalMode;
2504 static BMS_OUTER_EVAL_ROWSET_LOGGED: std::sync::Once = std::sync::Once::new();
2511 BMS_OUTER_EVAL_ROWSET_LOGGED.call_once(|| {
2512 let row_set_rows = match row_set {
2513 crate::row_kernel::RowSet::All => spec.y.len(),
2514 crate::row_kernel::RowSet::Subsample { rows, .. } => rows.len(),
2515 };
2516 log::debug!(
2517 "[BMS exact outer eval] mode={eval_mode:?} row_set_rows={row_set_rows}"
2518 );
2519 });
2520 let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
2521 let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
2522 if let Some(beta_seed) = pending_beta_seed.borrow_mut().take() {
2526 let widths: Vec<usize> = blocks.iter().map(|b| b.design.ncols()).collect();
2527 match CustomFamilyWarmStart::from_cached_beta(&widths, &beta_seed) {
2528 Ok(ws) => {
2529 exact_warm_start.replace(Some(ws));
2530 }
2531 Err(e) => {
2532 log::warn!(
2533 "[BMS] outer ρ-cache β-warm-start rejected: {e}; falling back to cold β"
2534 );
2535 }
2536 }
2537 }
2538 let sigma = sigma_from_theta(theta);
2539 final_sigma_cell.set(sigma);
2540 let family = make_family(&designs[0], &designs[1], sigma);
2541 let derivative_blocks = get_derivative_blocks(theta, specs, designs)?;
2542 let effective_mode = match eval_mode {
2546 EvalMode::ValueGradientHessian if !analytic_joint_hessian_available => {
2547 EvalMode::ValueAndGradient
2548 }
2549 other => other,
2550 };
2551 let mut eval_options =
2552 joint_hyper_options_for_outer_tolerance(options, exact_spatial_outer_tol);
2553 if let crate::row_kernel::RowSet::Subsample { rows, n_full } = row_set {
2554 let subsample = crate::outer_subsample::OuterScoreSubsample::from_weighted_rows(
2555 rows.as_ref().clone(),
2556 *n_full,
2557 0,
2558 );
2559 eval_options.outer_score_subsample = Some(Arc::new(subsample));
2560 eval_options.auto_outer_subsample = false;
2561 }
2562 let eval = evaluate_custom_family_joint_hyper_shared(
2563 &family,
2564 &blocks,
2565 &eval_options,
2566 &rho,
2567 derivative_blocks,
2568 exact_warm_start.borrow().as_ref(),
2569 effective_mode,
2570 )?;
2571 if let Some(err) = bernoulli_marginal_slope_runaway_error(
2572 &eval.warm_start,
2573 &designs[0],
2574 &specs[0],
2575 eval.inner_converged,
2576 "exact outer evaluation",
2577 ) {
2578 runaway_error.replace(Some(err.clone()));
2579 return Err(err);
2580 }
2581 exact_warm_start.replace(Some(eval.warm_start.clone()));
2582 if !eval.inner_converged {
2583 return Err(
2584 "exact bernoulli marginal-slope inner solve did not converge".to_string(),
2585 );
2586 }
2587 if matches!(eval_mode, EvalMode::ValueGradientHessian)
2588 && analytic_joint_hessian_available
2589 && !eval.outer_hessian.is_analytic()
2590 {
2591 return Err("exact bernoulli marginal-slope joint [rho, psi] objective did not return an outer Hessian"
2592 .to_string());
2593 }
2594 Ok((eval.objective, eval.gradient, eval.outer_hessian))
2595 },
2596 |theta, specs: &[TermCollectionSpec], designs: &[TermCollectionDesign]| {
2597 if let Some(err) = runaway_error.borrow().as_ref().cloned() {
2598 return Err(err);
2599 }
2600 let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
2601 let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
2602 if let Some(beta_seed) = pending_beta_seed.borrow_mut().take() {
2603 let widths: Vec<usize> = blocks.iter().map(|b| b.design.ncols()).collect();
2604 match CustomFamilyWarmStart::from_cached_beta(&widths, &beta_seed) {
2605 Ok(ws) => {
2606 exact_warm_start.replace(Some(ws));
2607 }
2608 Err(e) => {
2609 log::warn!(
2610 "[BMS] outer ρ-cache β-warm-start rejected (efs): {e}; falling back to cold β"
2611 );
2612 }
2613 }
2614 }
2615 let sigma = sigma_from_theta(theta);
2616 final_sigma_cell.set(sigma);
2617 let family = make_family(&designs[0], &designs[1], sigma);
2618 let derivative_blocks = get_derivative_blocks(theta, specs, designs)?;
2619 let eval = evaluate_custom_family_joint_hyper_efs_shared(
2620 &family,
2621 &blocks,
2622 &joint_hyper_options_for_outer_tolerance(options, exact_spatial_outer_tol),
2623 &rho,
2624 derivative_blocks,
2625 exact_warm_start.borrow().as_ref(),
2626 )?;
2627 if let Some(err) = bernoulli_marginal_slope_runaway_error(
2628 &eval.warm_start,
2629 &designs[0],
2630 &specs[0],
2631 eval.inner_converged,
2632 "EFS outer evaluation",
2633 ) {
2634 runaway_error.replace(Some(err.clone()));
2635 return Err(err);
2636 }
2637 exact_warm_start.replace(Some(eval.warm_start.clone()));
2638 if !eval.inner_converged {
2639 return Err(
2640 "exact bernoulli marginal-slope EFS inner solve did not converge".to_string(),
2641 );
2642 }
2643 Ok(eval.efs_eval)
2644 },
2645 crate::marginal_slope_shared::make_beta_seed_validator(&pending_beta_seed),
2646 )?;
2647
2648 let mut resolved_specs = solved.resolved_specs;
2649 let mut designs = solved.designs;
2650 let mut solved_fit = solved.fit;
2662 if let Some(reparam) = logslope_reduced_reparam.as_ref() {
2663 let r = reparam.reduced_cols();
2664 if let Some(block) = solved_fit.blocks.get_mut(1)
2665 && block.beta.len() == r
2666 {
2667 block.beta = reparam.recover_original_logslope_beta(&block.beta)?;
2668 }
2669 if let Some(state) = solved_fit.block_states.get_mut(1)
2670 && state.beta.len() == r
2671 {
2672 state.beta = reparam.recover_original_logslope_beta(&state.beta)?;
2673 }
2674 }
2675 let (latent_z_rank_int_calibration, latent_z_conditional_calibration) =
2717 match latent_z_calibration {
2718 LatentMeasureCalibration::None => (None, None),
2719 LatentMeasureCalibration::RankInverseNormal(cal) => (Some(cal), None),
2720 LatentMeasureCalibration::ConditionalLocationScale(cal) => (None, Some(cal)),
2721 };
2722 if let Some(cal) = latent_z_conditional_calibration.as_ref()
2732 && let Some(vb) = solved_fit.covariance_conditional.clone()
2733 {
2734 let p_beta = vb.nrows();
2735 let marginal_dense = marginal_design
2736 .design
2737 .try_to_dense_arc("bms generated-regressor marginal design")?;
2738 let logslope_reduced = reduce_logslope_design(&logslope_design)?;
2739 let logslope_reduced_dense = logslope_reduced
2740 .design
2741 .try_to_dense_arc("bms generated-regressor reduced logslope design")?;
2742 let p_m = marginal_dense.ncols();
2743 let r = logslope_reduced_dense.ncols();
2744 if p_beta != vb.ncols() {
2745 return Err(format!(
2746 "bms generated-regressor: covariance_conditional must be square, got {}×{}",
2747 vb.nrows(),
2748 vb.ncols()
2749 ));
2750 }
2751 if p_beta == p_m + r {
2755 let marginal_eta = &solved_fit.block_states[0].eta;
2756 let slope_eta = &solved_fit.block_states[1].eta;
2757 let probit_scale = probit_frailty_scale(final_sigma_cell.get());
2758 let s = rigid_standard_normal_score_zeta_sensitivity(
2759 &spec.base_link,
2760 marginal_eta,
2761 slope_eta,
2762 z.as_ref(),
2763 y.as_ref(),
2764 weights.as_ref(),
2765 probit_scale,
2766 marginal_dense.view(),
2767 logslope_reduced_dense.view(),
2768 p_beta,
2769 )?;
2770 let correction = cal.generated_regressor_correction(
2778 s.view(),
2779 spec.z.view(),
2780 marginal_dense.view(),
2781 vb.view(),
2782 )?;
2783 if let Some(cov) = solved_fit.covariance_conditional.as_mut() {
2784 *cov = &*cov + &correction;
2785 }
2786 if let Some(cov) = solved_fit.covariance_corrected.as_mut() {
2787 *cov = &*cov + &correction;
2788 }
2789 log::info!(
2790 "[BMS latent-z] Murphy–Topel generated-regressor SE correction applied: \
2791 p_beta={p_beta} theta1_dim={} max_diag_inflation={:.3e}",
2792 cal.theta1_dim(),
2793 (0..p_beta)
2794 .map(|i| correction[[i, i]])
2795 .fold(0.0_f64, f64::max),
2796 );
2797 } else {
2798 log::info!(
2799 "[BMS latent-z] Murphy–Topel generated-regressor SE correction skipped: \
2800 aux deviation blocks present (p_beta={p_beta} > marginal({p_m})+logslope({r})); \
2801 rigid-kernel z-channel does not yet cover score_warp/link_dev deviations"
2802 );
2803 }
2804 }
2805 Ok(BernoulliMarginalSlopeFitResult {
2818 fit: solved_fit,
2819 marginalspec_resolved: resolved_specs.remove(0),
2820 logslopespec_resolved: resolved_specs.remove(0),
2821 marginal_design: designs.remove(0),
2822 logslope_design: designs.remove(0),
2823 baseline_marginal: baseline.0,
2824 baseline_logslope: baseline.1,
2825 z_normalization,
2826 latent_measure,
2827 score_warp_runtime,
2828 link_dev_runtime,
2829 gaussian_frailty_sd: final_sigma_cell.get(),
2830 cross_block_warnings,
2831 latent_z_rank_int_calibration,
2832 latent_z_conditional_calibration,
2833 })
2834}