1use super::family::*;
2use super::gradient_paths::*;
3use super::hessian_paths::{new_cell_moment_cache_stats, new_cell_moment_lru_cache};
4use super::install_flex::validate_spec;
5use super::*;
6use gam_linalg::faer_ndarray::{FaerEigh, fast_ab, fast_atb, fast_xt_diag_x};
7use crate::marginal_slope_orthogonal::influence_absorber_log_lambda;
8use faer::Side;
9
/// Threshold on the sup-norm of the fitted marginal predictor, `|η|∞`.
///
/// The runaway diagnostics in this file fire when the fitted predictor reaches
/// or exceeds this value; the diagnostic text reports it as the
/// "numerical-degeneracy threshold".
pub(crate) const BMS_PROBIT_SEPARATION_ETA_INF: f64 = 35.0;

// Gauge priorities, declared from the largest numeric value (200) down to the
// smallest (60).
// NOTE(review): none of these constants is read in this part of the file;
// presumably a larger value takes precedence when blocks compete for a shared
// gauge direction — confirm against the code that consumes them.
pub(super) const GAUGE_PRIORITY_ANCHOR: u8 = 200;
pub(super) const GAUGE_PRIORITY_MARGINAL: u8 = 150;
pub(super) const GAUGE_PRIORITY_LOGSLOPE: u8 = 120;
pub(super) const GAUGE_PRIORITY_CANDIDATE_FLEX: u8 = 100;
pub(super) const GAUGE_PRIORITY_SCORE_WARP_DEV: u8 = 80;
pub(super) const GAUGE_PRIORITY_DEVIATION_DEFAULT: u8 = 70;
pub(super) const GAUGE_PRIORITY_LINK_DEV: u8 = 60;

// NOTE(review): not read in this part of the file; by its name this is a lower
// bound applied to an outer-loop tolerance for the exact spatial path — confirm
// at the use site.
pub(crate) const EXACT_SPATIAL_OUTER_TOL_FLOOR: f64 = 1e-6;
65
/// Data needed to evaluate the effective Jacobian of the marginal block.
///
/// The `BlockEffectiveJacobian` implementation for this type scales each row of
/// `marginal_dense` by `sqrt(1 + (s * g_i)^2)`, where `g_i` is the logslope
/// linear predictor of that row and `s` is the probit frailty scale taken from
/// the linearization state.
pub struct BmsMarginalJacobian {
    /// Dense marginal design; its column count is the width of the Jacobian.
    pub marginal_dense: Arc<Array2<f64>>,
    /// Dense logslope design, used to form the per-row logslope predictor.
    pub logslope_dense: Arc<Array2<f64>>,
    /// Per-row marginal offset. Stored here but not read by the Jacobian
    /// implementation visible in this file.
    pub offset_m: Array1<f64>,
    /// Per-row logslope offset, added to the logslope linear predictor.
    pub offset_s: Array1<f64>,
    /// Index in the flat coefficient vector at which the logslope coefficients
    /// start (everything before it is treated as marginal coefficients).
    pub p_marginal: usize,
}
114
115impl BmsMarginalJacobian {
116 pub fn new(
117 marginal_dense: Arc<Array2<f64>>,
118 logslope_dense: Arc<Array2<f64>>,
119 offset_m: Array1<f64>,
120 offset_s: Array1<f64>,
121 p_marginal: usize,
122 ) -> Self {
123 Self {
124 marginal_dense,
125 logslope_dense,
126 offset_m,
127 offset_s,
128 p_marginal,
129 }
130 }
131}
132
133impl BlockEffectiveJacobian for BmsMarginalJacobian {
134 fn effective_jacobian_rows(
135 &self,
136 state: &FamilyLinearizationState<'_>,
137 rows: std::ops::Range<usize>,
138 ) -> Result<Array2<f64>, String> {
139 let beta = state.beta;
140 let s = state.probit_frailty_scale;
141 let p_m = self.p_marginal;
142 let p_s_block = self.logslope_dense.ncols();
143 let beta_s_raw = if beta.len() > p_m {
144 &beta[p_m..]
145 } else {
146 &[][..]
147 };
148 let p_s_use = p_s_block.min(beta_s_raw.len());
149 let beta_s = &beta_s_raw[..p_s_use];
150 let n = self.marginal_dense.nrows();
151 let rows = rows.start.min(n)..rows.end.min(n);
152 let p_block = self.marginal_dense.ncols();
153
154 let mut out = Array2::<f64>::zeros((rows.end - rows.start, p_block));
164 for i in rows.clone() {
165 let g_i = self.offset_s[i]
166 + self
167 .logslope_dense
168 .row(i)
169 .slice(ndarray::s![..p_s_use])
170 .dot(&ArrayView1::from(beta_s));
171 let sg = s * g_i;
172 let c_i = (1.0 + sg * sg).sqrt();
173 let m_row = self.marginal_dense.row(i);
175 out.row_mut(i - rows.start).assign(&m_row.mapv(|x| c_i * x));
176 }
177 Ok(out)
178 }
179
180 fn n_outputs(&self) -> usize {
181 1
182 }
183}
184
/// Data needed to evaluate the effective Jacobian of the logslope block.
///
/// The `BlockEffectiveJacobian` implementation for this type scales each row of
/// `logslope_dense` by `q_i * s^2 * g_i / c_i + s * z_i`, where `q_i` and `g_i`
/// are the marginal and logslope linear predictors of that row,
/// `c_i = sqrt(1 + (s * g_i)^2)`, and `s` is the probit frailty scale taken from
/// the linearization state.
pub struct BmsLogslopeJacobian {
    /// Dense marginal design, used to form the per-row marginal predictor.
    pub marginal_dense: Arc<Array2<f64>>,
    /// Dense logslope design; its column count is the width of the Jacobian.
    pub logslope_dense: Arc<Array2<f64>>,
    /// Per-row marginal offset, added to the marginal linear predictor.
    pub offset_m: Array1<f64>,
    /// Per-row logslope offset, added to the logslope linear predictor.
    pub offset_s: Array1<f64>,
    /// Per-row value `z_i` entering the row factor as `s * z_i`.
    /// NOTE(review): presumably the score covariate — confirm with the caller.
    pub z: Arc<Array1<f64>>,
    /// Number of leading entries of the flat coefficient vector that belong to
    /// the marginal block; logslope coefficients start at this index.
    pub p_marginal: usize,
}
207
208impl BmsLogslopeJacobian {
209 pub fn new(
210 marginal_dense: Arc<Array2<f64>>,
211 logslope_dense: Arc<Array2<f64>>,
212 offset_m: Array1<f64>,
213 offset_s: Array1<f64>,
214 z: Arc<Array1<f64>>,
215 p_marginal: usize,
216 ) -> Self {
217 Self {
218 marginal_dense,
219 logslope_dense,
220 offset_m,
221 offset_s,
222 z,
223 p_marginal,
224 }
225 }
226}
227
228impl BlockEffectiveJacobian for BmsLogslopeJacobian {
229 fn effective_jacobian_rows(
230 &self,
231 state: &FamilyLinearizationState<'_>,
232 rows: std::ops::Range<usize>,
233 ) -> Result<Array2<f64>, String> {
234 let beta = state.beta;
235 let s = state.probit_frailty_scale;
236 let p_m = self.p_marginal;
237 let p_m_use = p_m.min(beta.len());
238 let beta_m = &beta[..p_m_use];
239 let beta_s_raw = if beta.len() > p_m {
240 &beta[p_m..]
241 } else {
242 &[][..]
243 };
244 let p_s_block = self.logslope_dense.ncols();
245 let p_s_use = p_s_block.min(beta_s_raw.len());
246 let beta_s = &beta_s_raw[..p_s_use];
247 let n = self.logslope_dense.nrows();
248 let rows = rows.start.min(n)..rows.end.min(n);
249
250 let mut out = Array2::<f64>::zeros((rows.end - rows.start, p_s_block));
263 for i in rows.clone() {
264 let q_i = self.offset_m[i]
265 + self
266 .marginal_dense
267 .row(i)
268 .slice(ndarray::s![..p_m_use])
269 .dot(&ArrayView1::from(beta_m));
270 let g_i = self.offset_s[i]
271 + self
272 .logslope_dense
273 .row(i)
274 .slice(ndarray::s![..p_s_use])
275 .dot(&ArrayView1::from(beta_s));
276 let sg = s * g_i;
277 let c_i = (1.0 + sg * sg).sqrt();
278 let z_i = self.z[i];
279 let factor = q_i * s * s * g_i / c_i + s * z_i;
281 let g_row = self.logslope_dense.row(i);
283 out.row_mut(i - rows.start)
284 .assign(&g_row.mapv(|x| factor * x));
285 }
286 Ok(out)
287 }
288
289 fn n_outputs(&self) -> usize {
290 1
291 }
292}
293
294pub(crate) fn widen_marginal_dense_with_influence(
304 marginal_dense: &Arc<Array2<f64>>,
305 influence_columns: Option<&Array2<f64>>,
306) -> Result<Arc<Array2<f64>>, String> {
307 let Some(z_infl) = influence_columns else {
308 return Ok(Arc::clone(marginal_dense));
309 };
310 let n = marginal_dense.nrows();
311 if z_infl.nrows() != n {
312 return Err(format!(
313 "influence block: residualised columns have {} rows, marginal design has {n}",
314 z_infl.nrows()
315 ));
316 }
317 let p_m = marginal_dense.ncols();
318 let p1 = z_infl.ncols();
319 let mut widened = Array2::<f64>::zeros((n, p_m + p1));
320 widened
321 .slice_mut(s![.., ..p_m])
322 .assign(marginal_dense.as_ref());
323 widened.slice_mut(s![.., p_m..]).assign(z_infl);
324 Ok(Arc::new(widened))
325}
326
/// Relative tolerance used by `reduced_logslope_transform_effective`.
///
/// It is applied twice there: multiplied by the largest diagonal entry of the
/// effective marginal Gram to form the ridge added before factorization, and
/// multiplied by the largest diagonal entry of the effective logslope Gram to
/// form the eigenvalue cutoff below which Schur-complement directions are
/// dropped.
pub(crate) const LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL: f64 = 1.0e-6;
335
/// Linear reparameterization of the logslope block onto a reduced basis.
#[derive(Debug, Clone)]
pub(super) struct ReducedLogslopeReparam {
    /// Matrix with one row per original logslope column and one column per
    /// reduced coefficient; multiplying it by a reduced coefficient vector
    /// yields coefficients in the original logslope basis.
    transform: Array2<f64>,
}
393
394impl ReducedLogslopeReparam {
395 #[inline]
397 pub(super) fn original_cols(&self) -> usize {
398 self.transform.nrows()
399 }
400
401 #[inline]
403 pub(super) fn reduced_cols(&self) -> usize {
404 self.transform.ncols()
405 }
406
407 pub(super) fn recover_original_logslope_beta(
411 &self,
412 beta_reduced: &Array1<f64>,
413 ) -> Result<Array1<f64>, String> {
414 if beta_reduced.len() != self.reduced_cols() {
415 return Err(format!(
416 "reduced logslope reparam: β' length ({}) != reduced width ({})",
417 beta_reduced.len(),
418 self.reduced_cols()
419 ));
420 }
421 Ok(self.transform.dot(beta_reduced))
422 }
423}
424
425fn build_reduced_logslope_reparam(
433 marginal_design: &TermCollectionDesign,
434 logslope_design: &TermCollectionDesign,
435 z: &Array1<f64>,
436 row_metric: &Array1<f64>,
437 marginal_offset: &Array1<f64>,
438 logslope_offset: &Array1<f64>,
439 marginal_baseline: f64,
440 logslope_baseline: f64,
441 probit_scale: f64,
442) -> Result<Option<ReducedLogslopeReparam>, String> {
443 let marginal = marginal_design
444 .design
445 .try_to_dense_arc("build_reduced_logslope_reparam::marginal")?;
446 let logslope = logslope_design
447 .design
448 .try_to_dense_arc("build_reduced_logslope_reparam::logslope")?;
449 let n = marginal.nrows();
450 if logslope.nrows() != n
451 || z.len() != n
452 || row_metric.len() != n
453 || marginal_offset.len() != n
454 || logslope_offset.len() != n
455 {
456 return Err(format!(
457 "reduced logslope reparam row mismatch: marginal={}, logslope={}, z={}, row_metric={}, marginal_offset={}, logslope_offset={}",
458 marginal.nrows(),
459 logslope.nrows(),
460 z.len(),
461 row_metric.len(),
462 marginal_offset.len(),
463 logslope_offset.len(),
464 ));
465 }
466 let p_m = marginal.ncols();
467 let p_g = logslope.ncols();
468 if p_m == 0 || p_g == 0 {
469 return Ok(None);
470 }
471 if !marginal_baseline.is_finite()
472 || !logslope_baseline.is_finite()
473 || !probit_scale.is_finite()
474 || probit_scale <= 0.0
475 || z.iter().any(|v| !v.is_finite())
476 || row_metric.iter().any(|v| !v.is_finite() || *v < 0.0)
477 || marginal_offset.iter().any(|v| !v.is_finite())
478 || logslope_offset.iter().any(|v| !v.is_finite())
479 {
480 return Err(
481 "reduced logslope reparam requires finite pilot geometry and finite non-negative row metric"
482 .to_string(),
483 );
484 }
485
486 match reduced_logslope_transform_effective(
497 marginal.view(),
498 logslope.view(),
499 z,
500 row_metric,
501 marginal_offset,
502 logslope_offset,
503 marginal_baseline,
504 logslope_baseline,
505 probit_scale,
506 )? {
507 Some(transform) => Ok(Some(ReducedLogslopeReparam { transform })),
508 None => Ok(None),
509 }
510}
511
512pub(crate) fn reduced_logslope_transform_effective(
531 marginal: ArrayView2<'_, f64>,
532 logslope: ArrayView2<'_, f64>,
533 z: &Array1<f64>,
534 row_metric: &Array1<f64>,
535 marginal_offset: &Array1<f64>,
536 logslope_offset: &Array1<f64>,
537 marginal_baseline: f64,
538 logslope_baseline: f64,
539 probit_scale: f64,
540) -> Result<Option<Array2<f64>>, String> {
541 let n = marginal.nrows();
542 let p_m = marginal.ncols();
543 let p_g = logslope.ncols();
544 if p_m == 0 || p_g == 0 {
545 return Ok(None);
546 }
547
548 let mut m_eff = Array2::<f64>::zeros((n, p_m));
550 let mut g_eff = Array2::<f64>::zeros((n, p_g));
551 for i in 0..n {
552 let q_i = marginal_offset[i] + marginal_baseline;
553 let g_i = logslope_offset[i] + logslope_baseline;
554 let sg = probit_scale * g_i;
555 let c_i = (1.0 + sg * sg).sqrt();
556 let f_i = q_i * probit_scale * probit_scale * g_i / c_i + probit_scale * z[i];
557 for j in 0..p_m {
558 m_eff[[i, j]] = c_i * marginal[[i, j]];
559 }
560 for j in 0..p_g {
561 g_eff[[i, j]] = f_i * logslope[[i, j]];
562 }
563 }
564
565 let c_gram = fast_xt_diag_x(&g_eff, row_metric);
568 let energy_scale = (0..p_g).map(|i| c_gram[[i, i]]).fold(0.0_f64, f64::max);
569 if !energy_scale.is_finite() || energy_scale <= 0.0 {
570 return Ok(None);
571 }
572
573 let mut a_gram = fast_xt_diag_x(&m_eff, row_metric);
577 let a_scale = (0..p_m).map(|i| a_gram[[i, i]]).fold(0.0_f64, f64::max);
578 let a_ridge = (a_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL).max(f64::EPSILON);
579 for i in 0..p_m {
580 a_gram[[i, i]] += a_ridge;
581 }
582
583 let b_cross = gam_linalg::faer_ndarray::fast_xt_diag_y(&m_eff, row_metric, &g_eff);
585 let a_view = gam_linalg::faer_ndarray::FaerArrayView::new(&a_gram);
586 let a_factor =
587 gam_linalg::faer_ndarray::factorize_symmetricwith_fallback(a_view.as_ref(), Side::Lower)
588 .map_err(|e| {
589 format!(
590 "reduced logslope reparam: effective marginal Gram factorization failed: {e}"
591 )
592 })?;
593 let b_view = gam_linalg::faer_ndarray::FaerArrayView::new(&b_cross);
594 let solved = a_factor.solve(b_view.as_ref()); let a_inv_b = Array2::from_shape_fn((p_m, p_g), |(i, j)| solved[(i, j)]);
596 let schur = fast_atb(&b_cross, &a_inv_b); let mut stt = &c_gram - &schur;
598 stt = (&stt + &stt.t()) * 0.5;
599 if stt.iter().any(|v| !v.is_finite()) {
600 return Err(
601 "reduced logslope reparam: effective Schur Gram produced non-finite entries"
602 .to_string(),
603 );
604 }
605
606 let (evals, evecs) = stt
607 .eigh(Side::Lower)
608 .map_err(|e| format!("reduced logslope reparam: eigendecomposition failed: {e:?}"))?;
609 let tol = energy_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL;
613 let mut kept: Vec<usize> = (0..evals.len()).filter(|&i| evals[i] > tol).collect();
614 kept.sort_by(|&a, &b| {
615 evals[b]
616 .partial_cmp(&evals[a])
617 .unwrap_or(std::cmp::Ordering::Equal)
618 });
619 let r = kept.len();
620 if r == p_g || r == 0 {
624 return Ok(None);
625 }
626 let mut transform = Array2::<f64>::zeros((p_g, r));
627 for (out_col, &src) in kept.iter().enumerate() {
628 transform.column_mut(out_col).assign(&evecs.column(src));
629 }
630 if transform.iter().any(|v| !v.is_finite()) {
631 return Err(
632 "reduced logslope reparam: reduced transform produced non-finite entries".to_string(),
633 );
634 }
635 Ok(Some(transform))
636}
637
638fn reparameterize_logslope_design_reduced(
645 logslope_design: &TermCollectionDesign,
646 reparam: &ReducedLogslopeReparam,
647) -> Result<TermCollectionDesign, String> {
648 let g = logslope_design
649 .design
650 .try_to_dense_arc("reparameterize_logslope_design_reduced::logslope")?;
651 let p_g = g.ncols();
652 if p_g != reparam.original_cols() {
653 return Err(format!(
654 "reduced logslope reparam width mismatch: design has {p_g} cols, transform expects {}",
655 reparam.original_cols()
656 ));
657 }
658 let t = &reparam.transform;
659 let r = reparam.reduced_cols();
660 let g_reduced = fast_ab(&g, t);
662
663 let mut new_penalties: Vec<gam_terms::smooth::BlockwisePenalty> =
666 Vec::with_capacity(logslope_design.penalties.len());
667 let mut new_nullspace_dims: Vec<usize> = Vec::with_capacity(logslope_design.penalties.len());
668 for bp in &logslope_design.penalties {
669 let mut full = Array2::<f64>::zeros((p_g, p_g));
670 full.slice_mut(s![bp.col_range.clone(), bp.col_range.clone()])
671 .assign(&bp.local);
672 let st = fast_ab(&full, t); let mut s_reduced = fast_atb(t, &st); s_reduced = (&s_reduced + &s_reduced.t()) * 0.5;
676 let (evals, _) = s_reduced
678 .eigh(Side::Lower)
679 .map_err(|e| format!("reduced logslope penalty eigendecomposition failed: {e:?}"))?;
680 let max_eval = evals.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
681 let pen_tol = (max_eval * 1.0e-12).max(f64::EPSILON);
682 let rank = evals.iter().filter(|&&v| v.abs() > pen_tol).count();
683 let nullspace_dim = r.saturating_sub(rank);
684 new_penalties.push(gam_terms::smooth::BlockwisePenalty::new(0..r, s_reduced));
685 new_nullspace_dims.push(nullspace_dim);
686 }
687
688 let new_design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(g_reduced));
689 Ok(TermCollectionDesign {
695 design: new_design,
696 penalties: new_penalties,
697 nullspace_dims: new_nullspace_dims,
698 penaltyinfo: Vec::new(),
699 dropped_penaltyinfo: Vec::new(),
700 coefficient_lower_bounds: None,
701 linear_constraints: None,
702 intercept_range: 0..0,
703 linear_ranges: Vec::new(),
704 random_effect_ranges: Vec::new(),
705 random_effect_levels: Vec::new(),
706 smooth: gam_terms::smooth::SmoothDesign {
707 term_designs: Vec::new(),
708 penalties: Vec::new(),
709 nullspace_dims: Vec::new(),
710 penaltyinfo: Vec::new(),
711 dropped_penaltyinfo: Vec::new(),
712 terms: Vec::new(),
713 coefficient_lower_bounds: None,
714 linear_constraints: None,
715 },
716 })
717}
718
719pub(crate) fn marginal_penalties_with_influence_ridge(
740 design: &TermCollectionDesign,
741 rho_marginal: &Array1<f64>,
742 influence_columns: Option<&Array2<f64>>,
743 influence_ridge_log_lambda: f64,
744) -> Result<(Vec<PenaltyMatrix>, Vec<usize>, Array1<f64>), String> {
745 let p_m = design.design.ncols();
746 let p1 = influence_columns.map(|z| z.ncols()).unwrap_or(0);
747 let total_dim = p_m + p1;
748 let mut penalties: Vec<PenaltyMatrix> = design
751 .penalties
752 .iter()
753 .map(|bp| bp.to_penalty_matrix(total_dim))
754 .collect();
755 let mut nullspace_dims = design.nullspace_dims.clone();
756 let mut log_lambdas = rho_marginal.to_vec();
757
758 if p1 > 0 {
762 penalties.push(
763 PenaltyMatrix::Blockwise {
764 local: Array2::<f64>::eye(p1),
765 col_range: p_m..total_dim,
766 total_dim,
767 }
768 .with_fixed_log_lambda(influence_ridge_log_lambda),
769 );
770 nullspace_dims.push(0);
771 log_lambdas.push(influence_ridge_log_lambda);
772 }
773
774 Ok((penalties, nullspace_dims, Array1::from_vec(log_lambdas)))
775}
776
777pub(crate) fn widen_marginal_beta_hint(
780 beta_hint: Option<Array1<f64>>,
781 p_marginal_widened: usize,
782) -> Option<Array1<f64>> {
783 beta_hint.map(|hint| {
784 if hint.len() == p_marginal_widened {
785 hint
786 } else {
787 let mut widened = Array1::<f64>::zeros(p_marginal_widened);
788 let copy = hint.len().min(p_marginal_widened);
789 widened
790 .slice_mut(s![..copy])
791 .assign(&hint.slice(s![..copy]));
792 widened
793 }
794 })
795}
796
797fn marginal_fitted_eta_sup_norm(design: &TermCollectionDesign, masked_beta: &Array1<f64>) -> f64 {
807 let x = &design.design;
808 let n = x.nrows();
809 if n == 0 || x.ncols() == 0 {
810 return 0.0;
811 }
812 let mut sup = 0.0_f64;
813 for row in 0..n {
814 let eta = x.dot_row_view(row, masked_beta.view());
815 if eta.is_finite() {
816 sup = sup.max(eta.abs());
817 }
818 }
819 sup
820}
821
822fn marginal_design_beta(
825 design: &TermCollectionDesign,
826 block_beta: ArrayView1<'_, f64>,
827) -> Array1<f64> {
828 let ncols = design.design.ncols();
829 let mut masked = Array1::<f64>::zeros(ncols);
830 let copy = ncols.min(block_beta.len());
831 masked
832 .slice_mut(s![..copy])
833 .assign(&block_beta.slice(s![..copy]));
834 masked
835}
836
837fn mask_parametric_columns(
843 design: &TermCollectionDesign,
844 spec: &TermCollectionSpec,
845 full: &Array1<f64>,
846) -> Array1<f64> {
847 let ncols = design.design.ncols();
848 let mut masked = Array1::<f64>::zeros(ncols);
849 if design.intercept_range.len() == 1 {
850 let idx = design.intercept_range.start;
851 if idx < ncols {
852 masked[idx] = full[idx];
853 }
854 }
855 for (linear, (_, range)) in spec.linear_terms.iter().zip(design.linear_ranges.iter()) {
856 if linear.double_penalty {
857 continue;
858 }
859 for col in range.clone() {
860 if col < ncols {
861 masked[col] = full[col];
862 }
863 }
864 }
865 masked
866}
867
868pub(crate) fn bernoulli_marginal_slope_runaway_error_from_beta(
879 block_beta: ArrayView1<'_, f64>,
880 design: &TermCollectionDesign,
881 spec: &TermCollectionSpec,
882 inner_converged: bool,
883 eval_label: &str,
884) -> Option<String> {
885 let full_beta = marginal_design_beta(design, block_beta);
886 let parametric_beta = mask_parametric_columns(design, spec, &full_beta);
887
888 let eta_parametric = marginal_fitted_eta_sup_norm(design, ¶metric_beta);
889 let eta_full = marginal_fitted_eta_sup_norm(design, &full_beta);
890
891 let (eta_inf, explanation) = if eta_parametric >= BMS_PROBIT_SEPARATION_ETA_INF {
892 (
893 eta_parametric,
894 "an unpenalized parametric marginal direction has no stable finite probit optimum and its fitted predictor has run to the probit underflow scale",
895 )
896 } else if eta_full >= BMS_PROBIT_SEPARATION_ETA_INF {
897 (
898 eta_full,
899 "a marginal direction is trading off against the logslope surface; this is the under-constrained marginal/logslope coupling that appears when the score is correlated with the shared surface covariates",
900 )
901 } else {
902 return None;
906 };
907
908 let inner_status = if inner_converged {
909 "the inner solve reached a KKT certificate at this separation-scale predictor"
910 } else {
911 "the inner solve failed while already carrying a separation-scale predictor"
912 };
913 let beta_abs = full_beta
915 .iter()
916 .copied()
917 .filter(|v| v.is_finite())
918 .fold(0.0_f64, |acc, v| acc.max(v.abs()));
919
920 Some(format!(
921 "bernoulli marginal-slope probit marginal/logslope runaway detected in block \
922 'marginal_surface' during {eval_label}: the fitted marginal predictor has \
923 |η|∞={eta_inf:.3e} (numerical-degeneracy threshold \
924 {BMS_PROBIT_SEPARATION_ETA_INF:.1}; raw |β|∞={beta_abs:.3e} is reported for \
925 context only and does not gate this diagnostic). The joint design is \
926 identifiable; {explanation}. {inner_status}. The robust Jeffreys curvature \
927 path is already installed for this fit, so this diagnostic means the current \
928 coupled surface still drives the linear predictor to the probit underflow \
929 scale rather than a request for an external bias-reduction prior. Reduce or \
930 reparameterize the coupled marginal/logslope surface, or use a \
931 lower-dimensional logslope interaction. This is not a \
932 Matérn/Duchon polynomial-nullspace or cross-block gauge-priority \
933 failure."
934 ))
935}
936
937pub(crate) fn bernoulli_marginal_slope_runaway_error(
938 warm_start: &CustomFamilyWarmStart,
939 design: &TermCollectionDesign,
940 spec: &TermCollectionSpec,
941 inner_converged: bool,
942 eval_label: &str,
943) -> Option<String> {
944 let block_beta = warm_start.block_beta_view(0)?;
945 bernoulli_marginal_slope_runaway_error_from_beta(
946 block_beta,
947 design,
948 spec,
949 inner_converged,
950 eval_label,
951 )
952}
953
954#[cfg(test)]
955mod runaway_tests {
956 use super::*;
957 use gam_linalg::faer_ndarray::{FaerArrayView, factorize_symmetricwith_fallback, fast_xt_diag_y};
958 use gam_terms::smooth::{LinearCoefficientGeometry, LinearTermSpec};
959
960 pub(crate) fn marginal_logslope_overlap_penalty(
966 marginal_design: &DesignMatrix,
967 logslope_design: &DesignMatrix,
968 z: &Array1<f64>,
969 row_metric: &Array1<f64>,
970 marginal_offset: &Array1<f64>,
971 logslope_offset: &Array1<f64>,
972 marginal_baseline: f64,
973 logslope_baseline: f64,
974 probit_scale: f64,
975 ) -> Result<Option<Array2<f64>>, String> {
976 let marginal =
977 marginal_design.try_to_dense_arc("marginal_logslope_overlap_penalty::marginal")?;
978 let logslope =
979 logslope_design.try_to_dense_arc("marginal_logslope_overlap_penalty::logslope")?;
980 let n = marginal.nrows();
981 if logslope.nrows() != n
982 || z.len() != n
983 || row_metric.len() != n
984 || marginal_offset.len() != n
985 || logslope_offset.len() != n
986 {
987 return Err(format!(
988 "marginal/logslope overlap penalty row mismatch: marginal={}, logslope={}, z={}, row_metric={}, marginal_offset={}, logslope_offset={}",
989 marginal.nrows(),
990 logslope.nrows(),
991 z.len(),
992 row_metric.len(),
993 marginal_offset.len(),
994 logslope_offset.len(),
995 ));
996 }
997 let p_m = marginal.ncols();
998 let p_g = logslope.ncols();
999 if p_m == 0 || p_g == 0 {
1000 return Ok(None);
1001 }
1002 if !marginal_baseline.is_finite()
1003 || !logslope_baseline.is_finite()
1004 || !probit_scale.is_finite()
1005 || probit_scale <= 0.0
1006 || z.iter().any(|v| !v.is_finite())
1007 || row_metric.iter().any(|v| !v.is_finite() || *v < 0.0)
1008 || marginal_offset.iter().any(|v| !v.is_finite())
1009 || logslope_offset.iter().any(|v| !v.is_finite())
1010 {
1011 return Err(
1012 "marginal/logslope overlap penalty requires finite pilot geometry and finite non-negative row metric"
1013 .to_string(),
1014 );
1015 }
1016
1017 let mut marginal_effective = Array2::<f64>::zeros((n, p_m));
1018 let mut effective_logslope = Array2::<f64>::zeros((n, p_g));
1019 for i in 0..n {
1020 let q_i = marginal_offset[i] + marginal_baseline;
1021 let g_i = logslope_offset[i] + logslope_baseline;
1022 let sg = probit_scale * g_i;
1023 let c_i = (1.0 + sg * sg).sqrt();
1024 let logslope_factor =
1025 q_i * probit_scale * probit_scale * g_i / c_i + probit_scale * z[i];
1026 for j in 0..p_m {
1027 marginal_effective[[i, j]] = c_i * marginal[[i, j]];
1028 }
1029 for j in 0..p_g {
1030 effective_logslope[[i, j]] = logslope_factor * logslope[[i, j]];
1031 }
1032 }
1033 if effective_logslope.iter().all(|v| v.abs() <= f64::EPSILON) {
1034 return Ok(None);
1035 }
1036
1037 let mut gram = fast_xt_diag_x(&effective_logslope, row_metric);
1038 let gram_scale = gram.diag().iter().copied().fold(0.0_f64, f64::max);
1039 if !gram_scale.is_finite() || gram_scale <= 0.0 {
1040 return Ok(None);
1041 }
1042 let projection_ridge = (gram_scale * 1.0e-10).max(f64::EPSILON);
1043 for i in 0..p_g {
1044 gram[[i, i]] += projection_ridge;
1045 }
1046 let cross = fast_xt_diag_y(&effective_logslope, row_metric, &marginal_effective);
1047 let gram_view = FaerArrayView::new(&gram);
1048 let factor = factorize_symmetricwith_fallback(gram_view.as_ref(), Side::Lower)
1049 .map_err(|e| format!("marginal/logslope overlap Gram factorization failed: {e}"))?;
1050 let rhsview = FaerArrayView::new(&cross);
1051 let coeffs_mat = factor.solve(rhsview.as_ref());
1052 let coeffs = Array2::from_shape_fn((p_g, p_m), |(i, j)| coeffs_mat[(i, j)]);
1053 let projected_marginal = fast_ab(&effective_logslope, &coeffs);
1054 let mut penalty = fast_xt_diag_y(&marginal_effective, row_metric, &projected_marginal);
1055 penalty = (&penalty + &penalty.t()) * 0.5;
1056 let max_abs = penalty.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
1057 if !max_abs.is_finite() || max_abs <= 1.0e-12 {
1058 return Ok(None);
1059 }
1060 Ok(Some(penalty))
1061 }
1062
1063 #[test]
1071 pub(crate) fn effective_reduction_drops_score_weighted_confound_raw_audit_misses() {
1072 let m = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
1074 let g = Array2::<f64>::from_shape_vec((3, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 9.0]).unwrap();
1075 let z = Array1::from_vec(vec![1.0, 0.5, 1.0 / 3.0]);
1076 let w = Array1::<f64>::ones(3);
1077 let zero = Array1::<f64>::zeros(3);
1078
1079 let reparam = reduced_logslope_transform_effective(
1082 m.view(),
1083 g.view(),
1084 &z,
1085 &w,
1086 &zero,
1087 &zero,
1088 0.0,
1089 0.0,
1090 1.0,
1091 )
1092 .expect("effective reduction must succeed")
1093 .expect("effective audit must reduce the score-weighted confound (raw audit would not)");
1094 assert_eq!(
1095 reparam.ncols(),
1096 1,
1097 "exactly one effective-identifiable logslope direction should survive"
1098 );
1099
1100 let g_eff = {
1104 let mut e = Array2::<f64>::zeros((3, 2));
1105 for i in 0..3 {
1106 for j in 0..2 {
1107 e[[i, j]] = z[i] * g[[i, j]];
1108 }
1109 }
1110 e
1111 };
1112 let img = g_eff.dot(&reparam.column(0));
1113 let mean = img.iter().sum::<f64>() / 3.0;
1114 let var = img.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / 3.0;
1115 assert!(
1116 var > 1.0e-6,
1117 "kept direction must be the identifiable (non-constant) effective column, var={var}"
1118 );
1119 }
1120
1121 #[test]
1127 pub(crate) fn effective_reduction_fully_confounded_single_column_returns_none() {
1128 let m = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
1129 let g = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
1130 let z = Array1::from_vec(vec![1.0, 0.5, 1.0 / 3.0]);
1131 let w = Array1::<f64>::ones(3);
1132 let zero = Array1::<f64>::zeros(3);
1133 let reparam = reduced_logslope_transform_effective(
1134 m.view(),
1135 g.view(),
1136 &z,
1137 &w,
1138 &zero,
1139 &zero,
1140 0.0,
1141 0.0,
1142 1.0,
1143 )
1144 .expect("effective reduction must succeed");
1145 assert!(
1146 reparam.is_none(),
1147 "fully effective-confounded logslope must keep raw design (None), not a 0-width block"
1148 );
1149 }
1150
1151 #[test]
1154 pub(crate) fn effective_reduction_no_confound_returns_none() {
1155 let m = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
1156 let g = Array2::<f64>::from_shape_vec((3, 2), vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();
1158 let z = Array1::from_vec(vec![1.0, 2.0, 3.0]);
1159 let w = Array1::<f64>::ones(3);
1160 let zero = Array1::<f64>::zeros(3);
1161 let reparam = reduced_logslope_transform_effective(
1162 m.view(),
1163 g.view(),
1164 &z,
1165 &w,
1166 &zero,
1167 &zero,
1168 0.0,
1169 0.0,
1170 1.0,
1171 )
1172 .expect("effective reduction must succeed");
1173 assert!(
1174 reparam.is_none(),
1175 "no effective confound ⇒ no reduction (raw design kept unchanged)"
1176 );
1177 }
1178
1179 #[test]
1180 pub(crate) fn spatial_joint_setup_counts_only_learned_penalties_in_rho() {
1181 let data = Array2::<f64>::zeros((3, 1));
1182 let empty_terms = TermCollectionSpec {
1183 linear_terms: Vec::new(),
1184 random_effect_terms: Vec::new(),
1185 smooth_terms: Vec::new(),
1186 };
1187 let setup = joint_setup(
1188 data.view(),
1189 &empty_terms,
1190 &empty_terms,
1191 2,
1192 3,
1193 &[0.4],
1194 &SpatialLengthScaleOptimizationOptions::default(),
1195 );
1196
1197 assert_eq!(
1198 setup.rho_dim(),
1199 6,
1200 "BMS spatial setup rho must contain only learned marginal/logslope/auxiliary penalties; fixed physical ridges are carried by PenaltyMatrix::Fixed"
1201 );
1202 }
1203
1204 #[test]
1205 pub(crate) fn overlap_penalty_targets_score_weighted_logslope_span() {
1206 let marginal = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
1207 Array2::from_shape_vec((4, 1), vec![0.0, 1.0, 2.0, 3.0]).unwrap(),
1208 ));
1209 let logslope = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
1210 Array2::from_shape_vec((4, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap(),
1211 ));
1212 let z = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
1213 let row_metric = Array1::ones(4);
1214 let offsets = Array1::zeros(4);
1215
1216 let penalty = marginal_logslope_overlap_penalty(
1217 &marginal,
1218 &logslope,
1219 &z,
1220 &row_metric,
1221 &offsets,
1222 &offsets,
1223 0.0,
1224 0.0,
1225 1.0,
1226 )
1227 .expect("overlap penalty should build")
1228 .expect("marginal signal lies in the pilot logslope Jacobian span");
1229
1230 assert_eq!(penalty.dim(), (1, 1));
1231 assert!((penalty[[0, 0]] - 14.0).abs() < 1.0e-6);
1232 }
1233
1234 #[test]
1235 pub(crate) fn overlap_penalty_skips_weight_orthogonal_channels() {
1236 let marginal = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
1237 Array2::from_shape_vec((4, 1), vec![-1.0, 1.0, -1.0, 1.0]).unwrap(),
1238 ));
1239 let logslope = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
1240 Array2::from_shape_vec((4, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap(),
1241 ));
1242 let z = Array1::ones(4);
1243 let row_metric = Array1::ones(4);
1244 let offsets = Array1::zeros(4);
1245
1246 let penalty = marginal_logslope_overlap_penalty(
1247 &marginal,
1248 &logslope,
1249 &z,
1250 &row_metric,
1251 &offsets,
1252 &offsets,
1253 0.0,
1254 0.0,
1255 1.0,
1256 )
1257 .expect("overlap penalty should build");
1258
1259 assert!(penalty.is_none());
1260 }
1261
1262 fn dense_marginal_design(
1270 x: Array2<f64>,
1271 intercept_range: std::ops::Range<usize>,
1272 linear_ranges: Vec<(String, std::ops::Range<usize>)>,
1273 ) -> TermCollectionDesign {
1274 TermCollectionDesign {
1275 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(x)),
1276 penalties: Vec::new(),
1277 nullspace_dims: Vec::new(),
1278 penaltyinfo: Vec::new(),
1279 dropped_penaltyinfo: Vec::new(),
1280 coefficient_lower_bounds: None,
1281 linear_constraints: None,
1282 intercept_range,
1283 linear_ranges,
1284 random_effect_ranges: Vec::new(),
1285 random_effect_levels: Vec::new(),
1286 smooth: gam_terms::smooth::SmoothDesign {
1287 term_designs: Vec::new(),
1288 penalties: Vec::new(),
1289 nullspace_dims: Vec::new(),
1290 penaltyinfo: Vec::new(),
1291 dropped_penaltyinfo: Vec::new(),
1292 terms: Vec::new(),
1293 coefficient_lower_bounds: None,
1294 linear_constraints: None,
1295 },
1296 }
1297 }
1298
1299 fn linear_term(name: &str, feature_col: usize) -> LinearTermSpec {
1300 LinearTermSpec {
1301 name: name.to_string(),
1302 feature_col,
1303 feature_cols: vec![feature_col],
1304 categorical_levels: vec![],
1305 double_penalty: false,
1306 coefficient_geometry: LinearCoefficientGeometry::default(),
1307 coefficient_min: None,
1308 coefficient_max: None,
1309 }
1310 }
1311
1312 fn empty_spec() -> TermCollectionSpec {
1313 TermCollectionSpec {
1314 linear_terms: Vec::new(),
1315 random_effect_terms: Vec::new(),
1316 smooth_terms: Vec::new(),
1317 }
1318 }
1319
1320 #[test]
1327 pub(crate) fn runaway_guard_silent_when_huge_beta_cancels_to_bounded_eta() {
1328 let x = Array2::<f64>::from_shape_vec((4, 2), vec![1.0; 8]).unwrap();
1330 let design = dense_marginal_design(x, 0..0, Vec::new());
1331 let beta = Array1::from_vec(vec![60.0, -60.0]);
1332
1333 let msg = bernoulli_marginal_slope_runaway_error_from_beta(
1334 beta.view(),
1335 &design,
1336 &empty_spec(),
1337 true,
1338 "regression-fixture",
1339 );
1340 assert!(
1341 msg.is_none(),
1342 "huge cancelling β with bounded fitted η must NOT trip the runaway guard; got {msg:?}"
1343 );
1344 }
1345
#[test]
/// A single all-ones column with coefficient 40 and an empty term spec: the
/// guard must return a message, and that message must contain every
/// diagnostic fragment asserted below.
pub(crate) fn runaway_guard_fires_when_fitted_eta_exceeds_threshold() {
    let ones_column = Array2::<f64>::ones((3, 1));
    let marginal = dense_marginal_design(ones_column, 0..0, Vec::new());
    let coefficients = Array1::from(vec![40.0]);
    let term_spec = empty_spec();

    let outcome = bernoulli_marginal_slope_runaway_error_from_beta(
        coefficients.view(),
        &marginal,
        &term_spec,
        true,
        "separation-fixture",
    );
    let msg = outcome.expect("fitted |η|∞=40 ≥ 35 must trip the runaway guard");

    // Headline and the reported sup-norm of the fitted predictor.
    assert!(msg.contains("marginal/logslope runaway"));
    assert!(msg.contains("|η|∞"));
    assert!(msg.contains("4.000e1"));
    // Explanatory text expected in the message body.
    assert!(msg.contains("score is correlated with the shared surface covariates"));
    assert!(msg.contains("not a Matérn/Duchon polynomial-nullspace"));
    assert!(msg.contains("KKT certificate"));
}
1371
#[test]
/// Same shape as the separation fixture, but the design and the term spec
/// both declare a linear term named "sex" on column 0. With coefficient 41
/// the guard must fire and its message must name the parametric direction.
pub(crate) fn runaway_guard_names_unpenalized_parametric_direction_via_fitted_eta() {
    let ones_column = Array2::<f64>::ones((3, 1));
    let named_ranges = vec![(String::from("sex"), 0..1)];
    let marginal = dense_marginal_design(ones_column, 0..0, named_ranges);
    let term_spec = TermCollectionSpec {
        linear_terms: vec![linear_term("sex", 0)],
        ..empty_spec()
    };
    let coefficients = Array1::from(vec![41.0]);

    let outcome = bernoulli_marginal_slope_runaway_error_from_beta(
        coefficients.view(),
        &marginal,
        &term_spec,
        true,
        "parametric-fixture",
    );
    let msg = outcome.expect("parametric fitted |η|∞=41 ≥ 35 must trip the runaway guard");

    assert!(msg.contains("unpenalized parametric marginal direction"));
    assert!(msg.contains("|η|∞"));
    assert!(msg.contains("robust Jeffreys curvature path is already installed"));
    assert!(msg.contains("not a Matérn/Duchon polynomial-nullspace"));
}
1397
#[test]
/// Convergence flag `false` with a small coefficient (5 on an all-ones
/// column): the guard is expected to return `None`, i.e. non-convergence
/// alone does not produce the separation error.
pub(crate) fn runaway_guard_silent_for_nonconverged_but_bounded_eta() {
    let ones_column = Array2::<f64>::ones((3, 1));
    let marginal = dense_marginal_design(ones_column, 0..0, Vec::new());
    let coefficients = Array1::from(vec![5.0]);
    let term_spec = empty_spec();
    let converged = false;

    let msg = bernoulli_marginal_slope_runaway_error_from_beta(
        coefficients.view(),
        &marginal,
        &term_spec,
        converged,
        "nonconverged-fixture",
    );

    assert!(
        msg.is_none(),
        "bounded fitted η must not raise the separation error even when the inner solve did not converge; got {msg:?}"
    );
}
1419
#[test]
/// Convergence flag `false` with coefficient 50 on an all-ones column: the
/// guard must fire, and the message must carry the non-convergence wording.
pub(crate) fn runaway_guard_fires_for_nonconverged_separating_eta() {
    let ones_column = Array2::<f64>::ones((3, 1));
    let marginal = dense_marginal_design(ones_column, 0..0, Vec::new());
    let coefficients = Array1::from(vec![50.0]);
    let term_spec = empty_spec();
    let converged = false;

    let outcome = bernoulli_marginal_slope_runaway_error_from_beta(
        coefficients.view(),
        &marginal,
        &term_spec,
        converged,
        "nonconverged-separating-fixture",
    );
    let msg = outcome.expect("separating |η|∞ at non-convergence must still trip the guard");

    let nonconverged_wording =
        "the inner solve failed while already carrying a separation-scale predictor";
    assert!(msg.contains(nonconverged_wording));
}
1441
#[test]
/// Evaluates both BMS block Jacobians with an empty coefficient slice and a
/// non-zero log-slope offset (`g_baseline`), and checks every entry against
/// the closed forms asserted below:
/// * marginal block: `sqrt(1 + (s * g_baseline)^2)` on every row;
/// * log-slope block: `s * z[i]` on row `i`, and finite.
pub(crate) fn bms_block_jacobians_self_compute_at_audit_empty_beta_nonzero_logslope_baseline() {
    use std::sync::Arc;

    let n = 4usize;
    let s = 1.0_f64;
    let g_baseline = 0.3_f64;

    // Single all-ones column for each surface.
    let marginal = Arc::new(Array2::<f64>::ones((n, 1)));
    let logslope = Arc::new(Array2::<f64>::ones((n, 1)));
    let offset_m = Array1::<f64>::zeros(n);
    let offset_s = Array1::<f64>::from_elem(n, g_baseline);
    let z = Arc::new(Array1::from_vec(vec![-0.7, 0.2, 0.9, 1.4]));

    // Empty coefficient vector: the callbacks only have the offsets to work from.
    let beta: Vec<f64> = Vec::new();
    let state = FamilyLinearizationState {
        beta: &beta,
        family_scalars: None,
        channel_hessian: None,
        probit_frailty_scale: s,
    };

    let j_m = BmsMarginalJacobian::new(
        Arc::clone(&marginal),
        Arc::clone(&logslope),
        offset_m.clone(),
        offset_s.clone(),
        1,
    )
    .effective_jacobian_rows(&state, 0..n)
    .expect("BMS marginal Jacobian must self-compute at audit empty β (gam#370)");
    let c_expected = (1.0 + (s * g_baseline).powi(2)).sqrt();
    assert_eq!(j_m.dim(), (n, 1));
    for (i, &entry) in j_m.column(0).iter().enumerate() {
        assert!(
            (entry - c_expected).abs() < 1e-12,
            "marginal J[{i}] = {} != closed-form c_i = {c_expected}",
            entry
        );
    }

    let j_s = BmsLogslopeJacobian::new(
        Arc::clone(&marginal),
        Arc::clone(&logslope),
        offset_m,
        offset_s,
        Arc::clone(&z),
        1,
    )
    .effective_jacobian_rows(&state, 0..n)
    .expect("BMS logslope Jacobian must self-compute at audit empty β (gam#370)");
    assert_eq!(j_s.dim(), (n, 1));
    for (i, (&entry, &z_i)) in j_s.column(0).iter().zip(z.iter()).enumerate() {
        let expected = s * z_i;
        assert!(
            (entry - expected).abs() < 1e-12,
            "logslope J[{i}] = {} != closed-form factor {expected}",
            entry
        );
        assert!(entry.is_finite());
    }
}
1524}
1525
1526pub(crate) fn build_marginal_blockspec_bms(
1527 design: &TermCollectionDesign,
1528 baseline: f64,
1529 offset: &Array1<f64>,
1530 rho: Array1<f64>,
1531 beta_hint: Option<Array1<f64>>,
1532 logslope_design: &TermCollectionDesign,
1533 logslope_offset: &Array1<f64>,
1534 logslope_baseline: f64,
1535 p_marginal: usize,
1536 influence_columns: Option<&Array2<f64>>,
1537 influence_ridge_log_lambda: f64,
1538) -> Result<ParameterBlockSpec, String> {
1539 let offset_m = offset + baseline;
1540 let offset_s = logslope_offset + logslope_baseline;
1541 let raw_marginal_dense = design
1542 .design
1543 .try_to_dense_arc("build_marginal_blockspec_bms::marginal")?;
1544 let marginal_dense =
1545 widen_marginal_dense_with_influence(&raw_marginal_dense, influence_columns)?;
1546 let logslope_dense = logslope_design
1547 .design
1548 .try_to_dense_arc("build_marginal_blockspec_bms::logslope")?;
1549 let callback: Arc<dyn BlockEffectiveJacobian> = Arc::new(BmsMarginalJacobian {
1550 marginal_dense: Arc::clone(&marginal_dense),
1551 logslope_dense,
1552 offset_m: offset_m.clone(),
1553 offset_s,
1554 p_marginal,
1555 });
1556 let (penalties, nullspace_dims, initial_log_lambdas) = marginal_penalties_with_influence_ridge(
1557 design,
1558 &rho,
1559 influence_columns,
1560 influence_ridge_log_lambda,
1561 )?;
1562 Ok(ParameterBlockSpec {
1563 name: "marginal_surface".to_string(),
1564 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
1565 (*marginal_dense).clone(),
1566 )),
1567 offset: offset_m,
1568 penalties,
1569 nullspace_dims,
1570 initial_log_lambdas,
1571 initial_beta: widen_marginal_beta_hint(beta_hint, p_marginal),
1572 gauge_priority: GAUGE_PRIORITY_MARGINAL,
1585 jacobian_callback: Some(callback),
1586 stacked_design: None,
1587 stacked_offset: None,
1588 })
1589}
1590
1591pub(crate) fn build_logslope_blockspec_bms(
1592 design: &TermCollectionDesign,
1593 baseline: f64,
1594 offset: &Array1<f64>,
1595 rho: Array1<f64>,
1596 beta_hint: Option<Array1<f64>>,
1597 marginal_design: &TermCollectionDesign,
1598 marginal_offset: &Array1<f64>,
1599 marginal_baseline: f64,
1600 z: Arc<Array1<f64>>,
1601 p_marginal: usize,
1602 influence_columns: Option<&Array2<f64>>,
1603) -> Result<ParameterBlockSpec, String> {
1604 let offset_s = offset + baseline;
1605 let offset_m = marginal_offset + marginal_baseline;
1606 let raw_marginal_dense = marginal_design
1607 .design
1608 .try_to_dense_arc("build_logslope_blockspec_bms::marginal")?;
1609 let marginal_dense =
1614 widen_marginal_dense_with_influence(&raw_marginal_dense, influence_columns)?;
1615 let logslope_dense = design
1616 .design
1617 .try_to_dense_arc("build_logslope_blockspec_bms::logslope")?;
1618 let callback: Arc<dyn BlockEffectiveJacobian> = Arc::new(BmsLogslopeJacobian {
1619 marginal_dense,
1620 logslope_dense: Arc::clone(&logslope_dense),
1621 offset_m,
1622 offset_s: offset_s.clone(),
1623 z,
1624 p_marginal,
1625 });
1626 Ok(ParameterBlockSpec {
1627 name: "logslope_surface".to_string(),
1628 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
1629 (*logslope_dense).clone(),
1630 )),
1631 offset: offset_s,
1632 penalties: design.penalties_as_penalty_matrix(),
1633 nullspace_dims: design.nullspace_dims.clone(),
1634 initial_log_lambdas: rho,
1635 initial_beta: beta_hint,
1636 gauge_priority: GAUGE_PRIORITY_LOGSLOPE,
1644 jacobian_callback: Some(callback),
1645 stacked_design: None,
1646 stacked_offset: None,
1647 })
1648}
1649
1650pub(crate) fn build_deviation_aux_blockspec(
1651 name: &str,
1652 prepared: &DeviationPrepared,
1653 rho: Array1<f64>,
1654 beta_hint: Option<Array1<f64>>,
1655) -> Result<ParameterBlockSpec, String> {
1656 let mut block = prepared.block.clone();
1657 block.initial_log_lambdas = Some(rho);
1658 let candidate_beta = beta_hint.or_else(|| Some(Array1::<f64>::zeros(block.design.ncols())));
1659 block.initial_beta = candidate_beta
1660 .map(|beta| {
1661 let zero = Array1::<f64>::zeros(beta.len());
1662 project_monotone_feasible_beta(&prepared.runtime, &zero, &beta, name)
1663 })
1664 .transpose()?;
1665 let mut spec = block.intospec(name)?;
1666 spec.gauge_priority = match name {
1676 "link_dev" => GAUGE_PRIORITY_LINK_DEV,
1677 "score_warp_dev" => GAUGE_PRIORITY_SCORE_WARP_DEV,
1684 _ => GAUGE_PRIORITY_DEVIATION_DEFAULT,
1685 };
1686 Ok(spec)
1687}
1688
1689pub(crate) fn push_deviation_aux_blockspecs(
1690 blocks: &mut Vec<ParameterBlockSpec>,
1691 rho: &Array1<f64>,
1692 cursor: &mut usize,
1693 score_warp_prepared: Option<&DeviationPrepared>,
1694 link_dev_prepared: Option<&DeviationPrepared>,
1695 score_warp_beta_hint: Option<Array1<f64>>,
1696 link_dev_beta_hint: Option<Array1<f64>>,
1697) -> Result<(), String> {
1698 if let Some(prepared) = score_warp_prepared {
1699 let rho_h = rho
1700 .slice(s![*cursor..*cursor + prepared.block.penalties.len()])
1701 .to_owned();
1702 *cursor += prepared.block.penalties.len();
1703 blocks.push(build_deviation_aux_blockspec(
1704 "score_warp_dev",
1705 prepared,
1706 rho_h,
1707 score_warp_beta_hint,
1708 )?);
1709 }
1710 if let Some(prepared) = link_dev_prepared {
1711 let rho_w = rho
1712 .slice(s![*cursor..*cursor + prepared.block.penalties.len()])
1713 .to_owned();
1714 blocks.push(build_deviation_aux_blockspec(
1715 "link_dev",
1716 prepared,
1717 rho_w,
1718 link_dev_beta_hint,
1719 )?);
1720 }
1721 Ok(())
1722}
1723
1724fn inner_fit(
1725 family: &BernoulliMarginalSlopeFamily,
1726 blocks: &[ParameterBlockSpec],
1727 options: &BlockwiseFitOptions,
1728) -> Result<UnifiedFitResult, String> {
1729 let mut options = options.clone();
1730 options.use_outer_hessian = false;
1735 options.outer_tol = options.outer_tol.max(2.0e-5);
1736 fit_custom_family(family, blocks, &options).map_err(|e| e.to_string())
1737}
1738
1739pub fn fit_bernoulli_marginal_slope_terms(
1740 data: ArrayView2<'_, f64>,
1741 spec: BernoulliMarginalSlopeTermSpec,
1742 options: &BlockwiseFitOptions,
1743 kappa_options: &SpatialLengthScaleOptimizationOptions,
1744 policy: &gam_runtime::resource::ResourcePolicy,
1745) -> Result<BernoulliMarginalSlopeFitResult, String> {
1746 let mut spec = spec;
1747 let data_view = data;
1748 validate_spec(data_view, &spec)?;
1749 let mjs_frozen_marginal =
1758 gam_terms::smooth::freeze_measure_jet_length_scale_learning(&mut spec.marginalspec);
1759 let mjs_frozen_logslope =
1760 gam_terms::smooth::freeze_measure_jet_length_scale_learning(&mut spec.logslopespec);
1761 if mjs_frozen_marginal + mjs_frozen_logslope > 0 {
1762 log::info!(
1763 "[BMS spatial] froze measure-jet length-scale learning on {} marginal + {} log-slope \
1764 term(s): the coupled surface keeps ℓ at its conditioned auto value (#1116)",
1765 mjs_frozen_marginal,
1766 mjs_frozen_logslope
1767 );
1768 }
1769 let mut effective_kappa_options = kappa_options.clone();
1770 let kappa_locked_marginal = gam_terms::smooth::all_spatial_terms_kappa_fixed(&spec.marginalspec);
1780 let kappa_locked_logslope = gam_terms::smooth::all_spatial_terms_kappa_fixed(&spec.logslopespec);
1781 if effective_kappa_options.enabled && kappa_locked_marginal && kappa_locked_logslope {
1782 log::info!(
1783 "[BMS spatial] disabling κ/ψ optimization: every spatial term has an \
1784 explicit length_scale and no anisotropy; user-supplied kernel scale is fixed"
1785 );
1786 effective_kappa_options.enabled = false;
1787 }
1788 let flex_spatial_pilot_path = (spec.score_warp.is_some() || spec.link_dev.is_some())
1789 && spec.y.len() >= BMS_FLEX_SPATIAL_OUTER_PILOT_ROW_THRESHOLD
1790 && effective_kappa_options.enabled;
1791 if flex_spatial_pilot_path {
1792 let marginal_terms = spatial_length_scale_term_indices(&spec.marginalspec);
1793 let logslope_terms = spatial_length_scale_term_indices(&spec.logslopespec);
1794 let marginal_updates = apply_spatial_anisotropy_pilot_initializer(
1795 data_view,
1796 &mut spec.marginalspec,
1797 &marginal_terms,
1798 effective_kappa_options.pilot_subsample_threshold,
1799 &effective_kappa_options,
1800 );
1801 let logslope_updates = apply_spatial_anisotropy_pilot_initializer(
1802 data_view,
1803 &mut spec.logslopespec,
1804 &logslope_terms,
1805 effective_kappa_options.pilot_subsample_threshold,
1806 &effective_kappa_options,
1807 );
1808 effective_kappa_options.enabled = false;
1809 log::info!(
1810 "[BMS spatial] n={} flex=true pilot_geometry_updates={} iterative_spatial_outer=false reason=large-flex-spatial-pilot",
1811 spec.y.len(),
1812 marginal_updates + logslope_updates,
1813 );
1814 }
1815 let (z_standardized, z_normalization) = standardize_latent_z_with_policy(
1816 &spec.z,
1817 &spec.weights,
1818 "bernoulli-marginal-slope",
1819 &spec.latent_z_policy,
1820 )?;
1821 spec.z = z_standardized;
1822 let sigma_learnable = matches!(
1823 &spec.frailty,
1824 FrailtySpec::GaussianShift { sigma_fixed: None }
1825 );
1826 let initial_sigma = match &spec.frailty {
1827 FrailtySpec::GaussianShift {
1828 sigma_fixed: Some(s),
1829 } => Some(*s),
1830 FrailtySpec::GaussianShift { sigma_fixed: None } => Some(0.5),
1831 FrailtySpec::None => None,
1832 FrailtySpec::HazardMultiplier { .. } => {
1833 return Err(
1834 "internal: validate_spec should have rejected unsupported marginal-slope frailty"
1835 .to_string(),
1836 );
1837 }
1838 };
1839 let probit_scale = probit_frailty_scale(initial_sigma);
1840 let (_raw_joint_designs, mut joint_specs) = build_term_collection_designs_and_freeze_joint(
1841 data_view,
1842 &[spec.marginalspec.clone(), spec.logslopespec.clone()],
1843 )
1844 .map_err(|e| e.to_string())?;
1845 let marginalspec_boot = joint_specs.remove(0);
1846 let logslopespec_boot = joint_specs.remove(0);
1847 let (mut joint_designs, _) = build_term_collection_designs_and_freeze_joint(
1864 data_view,
1865 &[marginalspec_boot.clone(), logslopespec_boot.clone()],
1866 )
1867 .map_err(|e| format!("failed to rebuild frozen probe BMS joint designs: {e}"))?;
1868 let marginal_design = joint_designs.remove(0);
1869 let logslope_design = joint_designs.remove(0);
1870 let absorber_active = spec
1877 .score_influence_jacobian
1878 .as_ref()
1879 .is_some_and(|j| j.ncols() > 0);
1880 let conditioning_dense = if absorber_active {
1881 None
1882 } else {
1883 Some(
1884 marginal_design
1885 .design
1886 .try_to_dense_arc("bernoulli marginal-slope conditional latent-z gate")?,
1887 )
1888 };
1889 let (latent_measure, latent_z_calibration) = build_latent_measure_with_geometry(
1890 &spec.z,
1891 &spec.weights,
1892 &spec.latent_z_policy,
1893 conditioning_dense.as_ref().map(|d| d.view()),
1894 )?;
1895 if latent_measure.is_empirical() && sigma_learnable {
1896 return Err("empirical latent-measure marginal-slope calibration requires fixed GaussianShift sigma; learnable sigma derivatives must be fit under the standard-normal latent measure"
1897 .to_string());
1898 }
1899
1900 let y = Arc::new(spec.y.clone());
1901 let weights = Arc::new(spec.weights.clone());
1902 let z = match &latent_z_calibration {
1907 LatentMeasureCalibration::None => Arc::new(spec.z.clone()),
1908 LatentMeasureCalibration::RankInverseNormal(cal) => {
1909 Arc::new(cal.apply_to_training(&spec.z)?)
1910 }
1911 LatentMeasureCalibration::ConditionalLocationScale(cal) => {
1912 let a_block = conditioning_dense.as_ref().ok_or_else(|| {
1915 "conditional latent calibration requires the marginal conditioning block"
1916 .to_string()
1917 })?;
1918 Arc::new(cal.apply(spec.z.view(), a_block.view())?)
1919 }
1920 };
1921 let z_train = z.as_ref();
1922 let pilot_baseline = pooled_probit_baseline(&spec.y, z_train, &spec.weights)?;
1923 let baseline = (
1924 bernoulli_marginal_slope_eta_from_probability(
1925 &spec.base_link,
1926 normal_cdf(pilot_baseline.0),
1927 "bernoulli marginal-slope baseline link inversion",
1928 )?,
1929 pilot_baseline.1 / probit_scale,
1930 );
1931
1932 let rigid_pilot_eta = rigid_pooled_probit_pilot_eta(
1975 &spec.base_link,
1976 z_train,
1977 &spec.marginal_offset,
1978 &spec.logslope_offset,
1979 baseline.0,
1980 baseline.1,
1981 probit_scale,
1982 )?;
1983 let cross_block_pilot_w_score_warp =
1984 pilot_irls_hessian_row_metric_at_eta(&rigid_pilot_eta, &spec.weights);
1985
1986 let influence_columns = if let Some(jac) = spec
1998 .score_influence_jacobian
1999 .as_ref()
2000 .filter(|j| j.ncols() > 0)
2001 {
2002 let protected_design = DesignMatrix::hstack(vec![
2003 marginal_design.design.clone(),
2004 logslope_design.design.clone(),
2005 ])
2006 .map_err(|e| {
2007 format!(
2008 "bernoulli marginal-slope influence-block protected projection stack failed to concatenate marginal + logslope design: {e}"
2009 )
2010 })?;
2011 let protected_dense_for_proj = protected_design
2012 .try_to_dense_arc("bernoulli marginal-slope influence-block protected projection")?;
2013 let protected_dense = protected_dense_for_proj.as_ref();
2014 if jac.nrows() != protected_dense.nrows() {
2015 return Err(format!(
2016 "influence block: Jacobian has {} rows, protected design has {}",
2017 jac.nrows(),
2018 protected_dense.nrows()
2019 ));
2020 }
2021 let rigid_logslope_at_rows = &spec.logslope_offset + baseline.1;
2032 let residualized = crate::marginal_slope_orthogonal::residualized_influence_block(
2033 jac,
2034 z_train,
2035 &rigid_logslope_at_rows,
2036 probit_scale,
2037 protected_dense.view(),
2038 &cross_block_pilot_w_score_warp,
2039 )?;
2040 Some(residualized)
2041 } else {
2042 None
2043 };
2044 let mut cross_block_warnings: Vec<CrossBlockIdentifiabilityWarning> = Vec::new();
2045 let score_warp_prepared = if let Some(cfg) = spec.score_warp.as_ref() {
2046 use super::deviation_runtime::ParametricAnchorBlock;
2047 let mut prepared = build_score_warp_deviation_block_from_seed(z_train, cfg)?;
2048 let outcome = install_compiled_flex_block_into_runtime(
2053 &mut prepared,
2054 z_train,
2055 cfg,
2056 &[
2057 (&marginal_design.design, ParametricAnchorBlock::Marginal),
2058 (&logslope_design.design, ParametricAnchorBlock::Logslope),
2059 ],
2060 &[],
2061 &cross_block_pilot_w_score_warp,
2062 )?;
2063 match outcome {
2064 FlexCompileOutcome::Reparameterised => Some(prepared),
2065 FlexCompileOutcome::FullyAliased { reason } => {
2066 cross_block_warnings.push(CrossBlockIdentifiabilityWarning {
2072 candidate_label: "score_warp",
2073 anchor_summary: "marginal+logslope".to_string(),
2074 reason,
2075 });
2076 Some(prepared)
2077 }
2078 }
2079 } else {
2080 None
2081 };
2082 let link_dev_prepared = if let Some(cfg) = spec.link_dev.as_ref() {
2108 let eta_pilot = pilot_eta_for_link_dev_orthogonalisation(
2109 &spec.base_link,
2110 &spec.y,
2111 z_train,
2112 &spec.weights,
2113 &marginal_design.design,
2114 &spec.marginal_offset,
2115 &spec.logslope_offset,
2116 baseline.0,
2117 baseline.1,
2118 probit_scale,
2119 )?;
2120 let link_dev_seed = padded_deviation_seed(&eta_pilot, 1.0, 0.5);
2121 let mut prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
2122 &link_dev_seed,
2123 &eta_pilot,
2124 cfg,
2125 )?;
2126 let score_warp_anchor_design = score_warp_prepared
2163 .as_ref()
2164 .map(|sw| sw.runtime.design_at_training_with_residual(z_train))
2165 .transpose()?;
2166 use super::deviation_runtime::ParametricAnchorBlock;
2167 let parametric_anchors: [(&DesignMatrix, ParametricAnchorBlock); 2] = [
2168 (&marginal_design.design, ParametricAnchorBlock::Marginal),
2169 (&logslope_design.design, ParametricAnchorBlock::Logslope),
2170 ];
2171 let flex_anchor_slot: Option<&Array2<f64>> = score_warp_anchor_design.as_ref();
2172 let flex_anchors: Vec<&Array2<f64>> = flex_anchor_slot.into_iter().collect();
2173 let cross_block_pilot_w_link_dev =
2178 pilot_irls_hessian_row_metric_at_eta(&eta_pilot, &spec.weights);
2179 let outcome = install_compiled_flex_block_into_runtime(
2180 &mut prepared,
2181 &eta_pilot,
2182 cfg,
2183 ¶metric_anchors,
2184 &flex_anchors,
2185 &cross_block_pilot_w_link_dev,
2186 )?;
2187 match outcome {
2188 FlexCompileOutcome::Reparameterised => Some(prepared),
2189 FlexCompileOutcome::FullyAliased { reason } => {
2190 cross_block_warnings.push(CrossBlockIdentifiabilityWarning {
2196 candidate_label: "link_deviation",
2197 anchor_summary: "marginal+logslope+score_warp".to_string(),
2198 reason,
2199 });
2200 Some(prepared)
2201 }
2202 }
2203 } else {
2204 None
2205 };
2206 let extra_rho0 = {
2207 let mut out = Vec::new();
2208 if let Some(ref prepared) = score_warp_prepared {
2209 out.extend(std::iter::repeat_n(0.0, prepared.block.penalties.len()));
2210 }
2211 if let Some(ref prepared) = link_dev_prepared {
2212 out.extend(std::iter::repeat_n(0.0, prepared.block.penalties.len()));
2213 }
2214 out
2215 };
2216 let logslope_reduced_reparam: Option<ReducedLogslopeReparam> = build_reduced_logslope_reparam(
2229 &marginal_design,
2230 &logslope_design,
2231 z.as_ref(),
2232 &cross_block_pilot_w_score_warp,
2233 &spec.marginal_offset,
2234 &spec.logslope_offset,
2235 baseline.0,
2236 baseline.1,
2237 probit_scale,
2238 )?;
2239 let reduce_logslope_design =
2245 |logslope_design: &TermCollectionDesign| -> Result<TermCollectionDesign, String> {
2246 match logslope_reduced_reparam.as_ref() {
2247 Some(reparam) => reparameterize_logslope_design_reduced(logslope_design, reparam),
2248 None => Ok(logslope_design.clone()),
2249 }
2250 };
2251
2252 let marginal_penalty_count = marginal_design.penalties.len();
2253 let setup = joint_setup(
2254 data_view,
2255 &marginalspec_boot,
2256 &logslopespec_boot,
2257 marginal_penalty_count,
2258 logslope_design.penalties.len(),
2259 &extra_rho0,
2260 &effective_kappa_options,
2261 );
2262 let setup = if sigma_learnable {
2263 setup.with_auxiliary(
2264 Array1::from_vec(vec![initial_sigma.expect("learnable sigma seed").ln()]),
2265 Array1::from_vec(vec![0.01_f64.ln()]),
2266 Array1::from_vec(vec![5.0_f64.ln()]),
2267 )
2268 } else {
2269 setup
2270 };
2271 let final_sigma_cell = std::cell::Cell::new(initial_sigma);
2272 let exact_warm_start = RefCell::new(None::<CustomFamilyWarmStart>);
2273 let runaway_error = RefCell::new(None::<String>);
2274 let pending_beta_seed = RefCell::new(None::<Array1<f64>>);
2281 let hints = RefCell::new(ThetaHints::default());
2282 let score_warp_runtime = score_warp_prepared.as_ref().map(|p| p.runtime.clone());
2283 let link_dev_runtime = link_dev_prepared.as_ref().map(|p| p.runtime.clone());
2284
2285 let build_blocks = |rho: &Array1<f64>,
2286 marginal_design: &TermCollectionDesign,
2287 logslope_design: &TermCollectionDesign|
2288 -> Result<Vec<ParameterBlockSpec>, String> {
2289 let hints = hints.borrow();
2290 let mut cursor = 0usize;
2291 let logslope_design_reduced = reduce_logslope_design(logslope_design)?;
2298 let logslope_design = &logslope_design_reduced;
2299 let rho_marginal = rho
2304 .slice(s![cursor..cursor + marginal_design.penalties.len()])
2305 .to_owned();
2306 cursor += marginal_design.penalties.len();
2307 let rho_logslope = rho
2308 .slice(s![cursor..cursor + logslope_design.penalties.len()])
2309 .to_owned();
2310 cursor += logslope_design.penalties.len();
2311 let p_m = marginal_design.design.ncols()
2312 + influence_columns.as_ref().map(|z| z.ncols()).unwrap_or(0);
2313 let mut blocks = vec![
2314 build_marginal_blockspec_bms(
2315 marginal_design,
2316 baseline.0,
2317 &spec.marginal_offset,
2318 rho_marginal,
2319 hints.marginal_beta.clone(),
2320 logslope_design,
2321 &spec.logslope_offset,
2322 baseline.1,
2323 p_m,
2324 influence_columns.as_ref(),
2325 influence_absorber_log_lambda(spec.z.len()),
2326 )?,
2327 build_logslope_blockspec_bms(
2328 logslope_design,
2329 baseline.1,
2330 &spec.logslope_offset,
2331 rho_logslope,
2332 hints.logslope_beta.clone(),
2333 marginal_design,
2334 &spec.marginal_offset,
2335 baseline.0,
2336 Arc::clone(&z),
2337 p_m,
2338 influence_columns.as_ref(),
2339 )?,
2340 ];
2341 push_deviation_aux_blockspecs(
2342 &mut blocks,
2343 rho,
2344 &mut cursor,
2345 score_warp_prepared.as_ref(),
2346 link_dev_prepared.as_ref(),
2347 hints.score_warp_beta.clone(),
2348 hints.link_dev_beta.clone(),
2349 )?;
2350 Ok(blocks)
2351 };
2352
2353 let intercept_warm_starts = new_intercept_warm_start_cache(y.len());
2354 let cell_moment_lru = new_cell_moment_lru_cache(policy);
2355 let cell_moment_cache_stats = new_cell_moment_cache_stats();
2356 let make_family = |marginal_design: &TermCollectionDesign,
2357 logslope_design: &TermCollectionDesign,
2358 sigma: Option<f64>|
2359 -> BernoulliMarginalSlopeFamily {
2360 let kernel_marginal_design = match influence_columns.as_ref() {
2366 Some(z_infl) => {
2367 let raw = marginal_design
2368 .design
2369 .try_to_dense_arc("make_family::widened-marginal")
2370 .expect("dense marginal design for influence widening");
2371 let widened = widen_marginal_dense_with_influence(&raw, Some(z_infl))
2372 .expect("widen marginal design with influence columns");
2373 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from((*widened).clone()))
2374 }
2375 None => marginal_design.design.clone(),
2376 };
2377 let kernel_logslope_design = reduce_logslope_design(logslope_design)
2383 .expect("reduce logslope design for family construction")
2384 .design;
2385 BernoulliMarginalSlopeFamily {
2386 y: Arc::clone(&y),
2387 weights: Arc::clone(&weights),
2388 z: Arc::clone(&z),
2389 latent_measure: latent_measure.clone(),
2390 gaussian_frailty_sd: sigma,
2391 base_link: spec.base_link.clone(),
2392 marginal_design: kernel_marginal_design,
2393 logslope_design: kernel_logslope_design,
2394 score_warp: score_warp_runtime.clone(),
2395 link_dev: link_dev_runtime.clone(),
2396 policy: policy.clone(),
2397 cell_moment_lru: Arc::clone(&cell_moment_lru),
2398 cell_moment_cache_stats: Arc::clone(&cell_moment_cache_stats),
2399 intercept_warm_starts: Some(Arc::clone(&intercept_warm_starts)),
2400 auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
2401 auto_subsample_last_rho: Arc::new(Mutex::new(None)),
2402 }
2403 };
2404
2405 let marginal_terms = spatial_length_scale_term_indices(&marginalspec_boot);
2406 let logslope_terms = spatial_length_scale_term_indices(&logslopespec_boot);
2407 let marginal_has_spatial = !marginal_terms.is_empty();
2408 let logslope_has_spatial = !logslope_terms.is_empty();
2409 let analytic_joint_derivatives_available =
2410 marginal_has_spatial || logslope_has_spatial || setup.log_kappa_dim() == 0;
2411 if setup.log_kappa_dim() > 0 && !analytic_joint_derivatives_available {
2412 return Err("exact bernoulli marginal-slope spatial optimization requires analytic joint psi derivatives"
2413 .to_string());
2414 }
2415 let initial_rho = setup.theta0().slice(s![..setup.rho_dim()]).to_owned();
2416 let initial_blocks = build_blocks(&initial_rho, &marginal_design, &logslope_design)?;
2417 let initial_family = make_family(&marginal_design, &logslope_design, initial_sigma);
2418 let (joint_gradient, joint_hessian) =
2419 custom_family_outer_derivatives(&initial_family, &initial_blocks, options);
2420 let analytic_joint_gradient_available = analytic_joint_derivatives_available
2421 && matches!(joint_gradient, gam_problem::Derivative::Analytic);
2422 let analytic_joint_hessian_available =
2428 analytic_joint_derivatives_available && joint_hessian.is_analytic();
2429 let kappa_options_ref: &SpatialLengthScaleOptimizationOptions = &effective_kappa_options;
2430 let sigma_from_theta = |theta: &Array1<f64>| -> Option<f64> {
2431 if sigma_learnable {
2432 Some(theta[setup.rho_dim() + setup.log_kappa_dim()].exp())
2433 } else {
2434 initial_sigma
2435 }
2436 };
2437 let derivative_block_cache = RefCell::new(
2438 None::<(
2439 Array1<f64>,
2440 Arc<Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>>,
2441 )>,
2442 );
2443 let theta_matches = |left: &Array1<f64>, right: &Array1<f64>| -> bool {
2444 left.len() == right.len()
2445 && left
2446 .iter()
2447 .zip(right.iter())
2448 .all(|(lhs, rhs)| (*lhs - *rhs).abs() <= 1e-12 * (1.0 + lhs.abs().max(rhs.abs())))
2449 };
2450 let get_derivative_blocks = |theta: &Array1<f64>,
2451 specs: &[TermCollectionSpec],
2452 designs: &[TermCollectionDesign]|
2453 -> Result<
2454 Arc<Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>>,
2455 String,
2456 > {
2457 if let Some((cached_theta, cached_blocks)) = derivative_block_cache.borrow().as_ref()
2458 && theta_matches(cached_theta, theta)
2459 {
2460 return Ok(Arc::clone(cached_blocks));
2461 }
2462
2463 let built = |specs: &[TermCollectionSpec],
2464 designs: &[TermCollectionDesign]|
2465 -> Result<
2466 Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>,
2467 String,
2468 > {
2469 let marginal_psi_derivs = if marginal_has_spatial {
2470 build_block_spatial_psi_derivatives(data_view, &specs[0], &designs[0])?.ok_or_else(
2471 || {
2472 "bernoulli marginal-slope: marginal block has spatial terms \
2473 but spatial psi derivatives are unavailable"
2474 .to_string()
2475 },
2476 )?
2477 } else {
2478 Vec::new()
2479 };
2480 let logslope_psi_derivs = if logslope_has_spatial {
2481 build_block_spatial_psi_derivatives(data_view, &specs[1], &designs[1])?.ok_or_else(
2482 || {
2483 "bernoulli marginal-slope: logslope block has spatial terms \
2484 but spatial psi derivatives are unavailable"
2485 .to_string()
2486 },
2487 )?
2488 } else {
2489 Vec::new()
2490 };
2491 let mut derivative_blocks = vec![marginal_psi_derivs, logslope_psi_derivs];
2492 if score_warp_runtime.is_some() {
2493 derivative_blocks.push(Vec::new());
2494 }
2495 if link_dev_runtime.is_some() {
2496 derivative_blocks.push(Vec::new());
2497 }
2498 if sigma_learnable {
2499 derivative_blocks
2500 .last_mut()
2501 .expect("bernoulli derivative block list is non-empty")
2502 .push(crate::custom_family::CustomFamilyBlockPsiDerivative::new(
2503 None,
2504 Array2::zeros((0, 0)),
2505 Array2::zeros((0, 0)),
2506 None,
2507 None,
2508 None,
2509 None,
2510 ));
2511 }
2512 Ok(derivative_blocks)
2513 }(specs, designs)?;
2514 let built = Arc::new(built);
2515 derivative_block_cache.replace(Some((theta.clone(), Arc::clone(&built))));
2516 Ok(built)
2517 };
2518
2519 let outer_policy = {
2524 let psi_dim = setup.theta0().len() - setup.rho_dim();
2525 initial_family.outer_derivative_policy(&initial_blocks, psi_dim, options)
2526 };
2527 let exact_spatial_outer_tol = kappa_options_ref.rel_tol.max(EXACT_SPATIAL_OUTER_TOL_FLOOR);
2528 let solved = optimize_spatial_length_scale_exact_joint(
2529 data_view,
2530 &[marginalspec_boot.clone(), logslopespec_boot.clone()],
2531 &[marginal_terms.clone(), logslope_terms.clone()],
2532 kappa_options_ref,
2533 &setup,
2534 gam_solve::seeding::SeedRiskProfile::GeneralizedLinear,
2535 analytic_joint_gradient_available,
2536 analytic_joint_hessian_available,
2537 true,
2538 None,
2539 outer_policy,
2540 |theta, specs: &[TermCollectionSpec], designs: &[TermCollectionDesign]| {
2541 if let Some(err) = runaway_error.borrow().as_ref().cloned() {
2542 return Err(err);
2543 }
2544 assert_eq!(
2545 specs.len(),
2546 designs.len(),
2547 "spatial joint optimizer must supply one spec per design",
2548 );
2549 let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
2550 let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
2551 let sigma = sigma_from_theta(theta);
2552 final_sigma_cell.set(sigma);
2553 let family = make_family(&designs[0], &designs[1], sigma);
2554 let fit = inner_fit(&family, &blocks, options)?;
2555 if let Some(block) = fit.block_states.first()
2556 && let Some(err) = bernoulli_marginal_slope_runaway_error_from_beta(
2557 block.beta.view(),
2558 &designs[0],
2559 &specs[0],
2560 fit.outer_converged,
2561 "final fit",
2562 )
2563 {
2564 runaway_error.replace(Some(err.clone()));
2565 return Err(err);
2566 }
2567 let mut hints_mut = hints.borrow_mut();
2568 let mut bidx = 0usize;
2569 if let Some(block) = fit.block_states.get(bidx) {
2570 hints_mut.marginal_beta = Some(block.beta.clone());
2571 }
2572 bidx += 1;
2573 if let Some(block) = fit.block_states.get(bidx) {
2574 hints_mut.logslope_beta = Some(block.beta.clone());
2575 }
2576 bidx += 1;
2577 if score_warp_prepared.is_some() {
2578 if let Some(block) = fit.block_states.get(bidx) {
2579 hints_mut.score_warp_beta = Some(block.beta.clone());
2580 }
2581 bidx += 1;
2582 }
2583 if link_dev_prepared.is_some()
2584 && let Some(block) = fit.block_states.get(bidx)
2585 {
2586 hints_mut.link_dev_beta = Some(block.beta.clone());
2587 }
2588 Ok(fit)
2589 },
2590 |theta,
2591 specs: &[TermCollectionSpec],
2592 designs: &[TermCollectionDesign],
2593 eval_mode,
2594 row_set: &crate::row_kernel::RowSet| {
2595 if let Some(err) = runaway_error.borrow().as_ref().cloned() {
2596 return Err(err);
2597 }
2598 use gam_problem::EvalMode;
2599 static BMS_OUTER_EVAL_ROWSET_LOGGED: std::sync::Once = std::sync::Once::new();
2606 BMS_OUTER_EVAL_ROWSET_LOGGED.call_once(|| {
2607 let row_set_rows = match row_set {
2608 crate::row_kernel::RowSet::All => spec.y.len(),
2609 crate::row_kernel::RowSet::Subsample { rows, .. } => rows.len(),
2610 };
2611 log::debug!(
2612 "[BMS exact outer eval] mode={eval_mode:?} row_set_rows={row_set_rows}"
2613 );
2614 });
2615 let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
2616 let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
2617 if let Some(beta_seed) = pending_beta_seed.borrow_mut().take() {
2621 let widths: Vec<usize> = blocks.iter().map(|b| b.design.ncols()).collect();
2622 match CustomFamilyWarmStart::from_cached_beta(&widths, &beta_seed) {
2623 Ok(ws) => {
2624 exact_warm_start.replace(Some(ws));
2625 }
2626 Err(e) => {
2627 log::warn!(
2628 "[BMS] outer ρ-cache β-warm-start rejected: {e}; falling back to cold β"
2629 );
2630 }
2631 }
2632 }
2633 let sigma = sigma_from_theta(theta);
2634 final_sigma_cell.set(sigma);
2635 let family = make_family(&designs[0], &designs[1], sigma);
2636 let derivative_blocks = get_derivative_blocks(theta, specs, designs)?;
2637 let effective_mode = match eval_mode {
2641 EvalMode::ValueGradientHessian if !analytic_joint_hessian_available => {
2642 EvalMode::ValueAndGradient
2643 }
2644 other => other,
2645 };
2646 let mut eval_options =
2647 joint_hyper_options_for_outer_tolerance(options, exact_spatial_outer_tol);
2648 if let crate::row_kernel::RowSet::Subsample { rows, n_full } = row_set {
2649 let subsample = crate::outer_subsample::OuterScoreSubsample::from_weighted_rows(
2650 rows.as_ref().clone(),
2651 *n_full,
2652 0,
2653 );
2654 eval_options.outer_score_subsample = Some(Arc::new(subsample));
2655 eval_options.auto_outer_subsample = false;
2656 }
2657 let eval = evaluate_custom_family_joint_hyper_shared(
2658 &family,
2659 &blocks,
2660 &eval_options,
2661 &rho,
2662 derivative_blocks,
2663 exact_warm_start.borrow().as_ref(),
2664 effective_mode,
2665 )?;
2666 if let Some(err) = bernoulli_marginal_slope_runaway_error(
2667 &eval.warm_start,
2668 &designs[0],
2669 &specs[0],
2670 eval.inner_converged,
2671 "exact outer evaluation",
2672 ) {
2673 runaway_error.replace(Some(err.clone()));
2674 return Err(err);
2675 }
2676 exact_warm_start.replace(Some(eval.warm_start.clone()));
2677 if !eval.inner_converged {
2678 return Err(
2679 "exact bernoulli marginal-slope inner solve did not converge".to_string(),
2680 );
2681 }
2682 if matches!(eval_mode, EvalMode::ValueGradientHessian)
2683 && analytic_joint_hessian_available
2684 && !eval.outer_hessian.is_analytic()
2685 {
2686 return Err("exact bernoulli marginal-slope joint [rho, psi] objective did not return an outer Hessian"
2687 .to_string());
2688 }
2689 Ok((eval.objective, eval.gradient, eval.outer_hessian))
2690 },
2691 |theta, specs: &[TermCollectionSpec], designs: &[TermCollectionDesign]| {
2692 if let Some(err) = runaway_error.borrow().as_ref().cloned() {
2693 return Err(err);
2694 }
2695 let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
2696 let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
2697 if let Some(beta_seed) = pending_beta_seed.borrow_mut().take() {
2698 let widths: Vec<usize> = blocks.iter().map(|b| b.design.ncols()).collect();
2699 match CustomFamilyWarmStart::from_cached_beta(&widths, &beta_seed) {
2700 Ok(ws) => {
2701 exact_warm_start.replace(Some(ws));
2702 }
2703 Err(e) => {
2704 log::warn!(
2705 "[BMS] outer ρ-cache β-warm-start rejected (efs): {e}; falling back to cold β"
2706 );
2707 }
2708 }
2709 }
2710 let sigma = sigma_from_theta(theta);
2711 final_sigma_cell.set(sigma);
2712 let family = make_family(&designs[0], &designs[1], sigma);
2713 let derivative_blocks = get_derivative_blocks(theta, specs, designs)?;
2714 let eval = evaluate_custom_family_joint_hyper_efs_shared(
2715 &family,
2716 &blocks,
2717 &joint_hyper_options_for_outer_tolerance(options, exact_spatial_outer_tol),
2718 &rho,
2719 derivative_blocks,
2720 exact_warm_start.borrow().as_ref(),
2721 )?;
2722 if let Some(err) = bernoulli_marginal_slope_runaway_error(
2723 &eval.warm_start,
2724 &designs[0],
2725 &specs[0],
2726 eval.inner_converged,
2727 "EFS outer evaluation",
2728 ) {
2729 runaway_error.replace(Some(err.clone()));
2730 return Err(err);
2731 }
2732 exact_warm_start.replace(Some(eval.warm_start.clone()));
2733 if !eval.inner_converged {
2734 return Err(
2735 "exact bernoulli marginal-slope EFS inner solve did not converge".to_string(),
2736 );
2737 }
2738 Ok(eval.efs_eval)
2739 },
2740 crate::marginal_slope_shared::make_beta_seed_validator(&pending_beta_seed),
2741 )?;
2742
2743 let mut resolved_specs = solved.resolved_specs;
2744 let mut designs = solved.designs;
2745 let mut solved_fit = solved.fit;
2757 if let Some(reparam) = logslope_reduced_reparam.as_ref() {
2758 let r = reparam.reduced_cols();
2759 if let Some(block) = solved_fit.blocks.get_mut(1)
2760 && block.beta.len() == r
2761 {
2762 block.beta = reparam.recover_original_logslope_beta(&block.beta)?;
2763 }
2764 if let Some(state) = solved_fit.block_states.get_mut(1)
2765 && state.beta.len() == r
2766 {
2767 state.beta = reparam.recover_original_logslope_beta(&state.beta)?;
2768 }
2769 }
2770 let (latent_z_rank_int_calibration, latent_z_conditional_calibration) =
2812 match latent_z_calibration {
2813 LatentMeasureCalibration::None => (None, None),
2814 LatentMeasureCalibration::RankInverseNormal(cal) => (Some(cal), None),
2815 LatentMeasureCalibration::ConditionalLocationScale(cal) => (None, Some(cal)),
2816 };
2817 if let Some(cal) = latent_z_conditional_calibration.as_ref()
2827 && let Some(vb) = solved_fit.covariance_conditional.clone()
2828 {
2829 let p_beta = vb.nrows();
2830 let marginal_dense = marginal_design
2831 .design
2832 .try_to_dense_arc("bms generated-regressor marginal design")?;
2833 let logslope_reduced = reduce_logslope_design(&logslope_design)?;
2834 let logslope_reduced_dense = logslope_reduced
2835 .design
2836 .try_to_dense_arc("bms generated-regressor reduced logslope design")?;
2837 let p_m = marginal_dense.ncols();
2838 let r = logslope_reduced_dense.ncols();
2839 if p_beta != vb.ncols() {
2840 return Err(format!(
2841 "bms generated-regressor: covariance_conditional must be square, got {}×{}",
2842 vb.nrows(),
2843 vb.ncols()
2844 ));
2845 }
2846 if p_beta == p_m + r {
2850 let marginal_eta = &solved_fit.block_states[0].eta;
2851 let slope_eta = &solved_fit.block_states[1].eta;
2852 let probit_scale = probit_frailty_scale(final_sigma_cell.get());
2853 let s = rigid_standard_normal_score_zeta_sensitivity(
2854 &spec.base_link,
2855 marginal_eta,
2856 slope_eta,
2857 z.as_ref(),
2858 y.as_ref(),
2859 weights.as_ref(),
2860 probit_scale,
2861 marginal_dense.view(),
2862 logslope_reduced_dense.view(),
2863 p_beta,
2864 )?;
2865 let correction = cal.generated_regressor_correction(
2873 s.view(),
2874 spec.z.view(),
2875 marginal_dense.view(),
2876 vb.view(),
2877 )?;
2878 if let Some(cov) = solved_fit.covariance_conditional.as_mut() {
2879 *cov = &*cov + &correction;
2880 }
2881 if let Some(cov) = solved_fit.covariance_corrected.as_mut() {
2882 *cov = &*cov + &correction;
2883 }
2884 log::info!(
2885 "[BMS latent-z] Murphy–Topel generated-regressor SE correction applied: \
2886 p_beta={p_beta} theta1_dim={} max_diag_inflation={:.3e}",
2887 cal.theta1_dim(),
2888 (0..p_beta)
2889 .map(|i| correction[[i, i]])
2890 .fold(0.0_f64, f64::max),
2891 );
2892 } else {
2893 log::info!(
2894 "[BMS latent-z] Murphy–Topel generated-regressor SE correction skipped: \
2895 aux deviation blocks present (p_beta={p_beta} > marginal({p_m})+logslope({r})); \
2896 rigid-kernel z-channel does not yet cover score_warp/link_dev deviations"
2897 );
2898 }
2899 }
2900 Ok(BernoulliMarginalSlopeFitResult {
2913 fit: solved_fit,
2914 marginalspec_resolved: resolved_specs.remove(0),
2915 logslopespec_resolved: resolved_specs.remove(0),
2916 marginal_design: designs.remove(0),
2917 logslope_design: designs.remove(0),
2918 baseline_marginal: baseline.0,
2919 baseline_logslope: baseline.1,
2920 z_normalization,
2921 latent_measure,
2922 score_warp_runtime,
2923 link_dev_runtime,
2924 gaussian_frailty_sd: final_sigma_cell.get(),
2925 cross_block_warnings,
2926 latent_z_rank_int_calibration,
2927 latent_z_conditional_calibration,
2928 })
2929}