1use super::family::*;
2use super::gradient_paths::*;
3use super::hessian_paths::{new_cell_moment_cache_stats, new_cell_moment_lru_cache};
4use super::install_flex::validate_spec;
5use super::*;
6use gam_linalg::faer_ndarray::{FaerEigh, fast_ab, fast_atb, fast_xt_diag_x};
7use crate::marginal_slope_orthogonal::INFLUENCE_ABSORBER_FIXED_LOG_LAMBDA;
8use faer::Side;
9
/// Threshold on the fitted marginal predictor sup-norm `|η|∞`: when either the
/// parametric-only or the full marginal predictor reaches this value,
/// `bernoulli_marginal_slope_runaway_error_from_beta` reports a
/// marginal/logslope runaway.
pub(crate) const BMS_PROBIT_SEPARATION_ETA_INF: f64 = 35.0;
25
// Gauge-priority levels, one per block kind named in the constant. The values
// are strictly ordered ANCHOR > MARGINAL > LOGSLOPE > CANDIDATE_FLEX >
// SCORE_WARP_DEV > DEVIATION_DEFAULT > LINK_DEV.
// NOTE(review): the consumer lives outside this chunk; confirm that a higher
// value wins when cross-block gauge (identifiability) conflicts are resolved.

/// Gauge priority for the anchor block (the highest level in this set).
pub(super) const GAUGE_PRIORITY_ANCHOR: u8 = 200;
/// Gauge priority for the marginal block.
pub(super) const GAUGE_PRIORITY_MARGINAL: u8 = 150;
/// Gauge priority for the logslope block.
pub(super) const GAUGE_PRIORITY_LOGSLOPE: u8 = 120;
/// Gauge priority for candidate flex blocks.
pub(super) const GAUGE_PRIORITY_CANDIDATE_FLEX: u8 = 100;
/// Gauge priority for score-warp deviation blocks.
pub(super) const GAUGE_PRIORITY_SCORE_WARP_DEV: u8 = 80;
/// Default gauge priority for deviation blocks without a more specific level.
pub(super) const GAUGE_PRIORITY_DEVIATION_DEFAULT: u8 = 70;
/// Gauge priority for link deviation blocks (the lowest level in this set).
pub(super) const GAUGE_PRIORITY_LINK_DEV: u8 = 60;
58
/// Lower bound applied to an outer-optimization tolerance for exact spatial
/// fits. NOTE(review): the consumer is outside this chunk; confirm which
/// tolerance this floors.
pub(crate) const EXACT_SPATIAL_OUTER_TOL_FLOOR: f64 = 1e-6;
65
/// Effective-Jacobian provider for the marginal coefficient block of the BMS
/// (Bernoulli marginal-slope, per the `bernoulli_marginal_slope_*` helpers in
/// this file) family; see its `BlockEffectiveJacobian` impl.
pub struct BmsMarginalJacobian {
    /// Dense marginal design; its rows are scaled to form the Jacobian.
    pub marginal_dense: Arc<Array2<f64>>,
    /// Dense logslope design, used to evaluate the per-row logslope predictor.
    pub logslope_dense: Arc<Array2<f64>>,
    /// Per-row marginal offset. NOTE(review): not read by this type's
    /// `effective_jacobian_rows` in this file.
    pub offset_m: Array1<f64>,
    /// Per-row logslope offset added to the logslope predictor.
    pub offset_s: Array1<f64>,
    /// Number of leading coefficients that belong to the marginal block; the
    /// logslope coefficients follow them in the state's coefficient vector.
    pub p_marginal: usize,
}

impl BmsMarginalJacobian {
    /// Bundles the designs, offsets and marginal width. No validation is done.
    pub fn new(
        marginal_dense: Arc<Array2<f64>>,
        logslope_dense: Arc<Array2<f64>>,
        offset_m: Array1<f64>,
        offset_s: Array1<f64>,
        p_marginal: usize,
    ) -> Self {
        Self {
            marginal_dense,
            logslope_dense,
            offset_m,
            offset_s,
            p_marginal,
        }
    }
}
132
133impl BlockEffectiveJacobian for BmsMarginalJacobian {
134 fn effective_jacobian_rows(
135 &self,
136 state: &FamilyLinearizationState<'_>,
137 rows: std::ops::Range<usize>,
138 ) -> Result<Array2<f64>, String> {
139 let beta = state.beta;
140 let s = state.probit_frailty_scale;
141 let p_m = self.p_marginal;
142 let p_s_block = self.logslope_dense.ncols();
143 let beta_s_raw = if beta.len() > p_m {
144 &beta[p_m..]
145 } else {
146 &[][..]
147 };
148 let p_s_use = p_s_block.min(beta_s_raw.len());
149 let beta_s = &beta_s_raw[..p_s_use];
150 let n = self.marginal_dense.nrows();
151 let rows = rows.start.min(n)..rows.end.min(n);
152 let p_block = self.marginal_dense.ncols();
153
154 let mut out = Array2::<f64>::zeros((rows.end - rows.start, p_block));
164 for i in rows.clone() {
165 let g_i = self.offset_s[i]
166 + self
167 .logslope_dense
168 .row(i)
169 .slice(ndarray::s![..p_s_use])
170 .dot(&ArrayView1::from(beta_s));
171 let sg = s * g_i;
172 let c_i = (1.0 + sg * sg).sqrt();
173 let m_row = self.marginal_dense.row(i);
175 out.row_mut(i - rows.start).assign(&m_row.mapv(|x| c_i * x));
176 }
177 Ok(out)
178 }
179
180 fn n_outputs(&self) -> usize {
181 1
182 }
183}
184
/// Effective-Jacobian provider for the logslope coefficient block of the BMS
/// family; see its `BlockEffectiveJacobian` impl.
pub struct BmsLogslopeJacobian {
    /// Dense marginal design, used to evaluate the per-row marginal predictor.
    pub marginal_dense: Arc<Array2<f64>>,
    /// Dense logslope design; its rows are scaled to form the Jacobian.
    pub logslope_dense: Arc<Array2<f64>>,
    /// Per-row marginal offset added to the marginal predictor.
    pub offset_m: Array1<f64>,
    /// Per-row logslope offset added to the logslope predictor.
    pub offset_s: Array1<f64>,
    /// Per-row score values entering the `s * z_i` term of the row factor.
    pub z: Arc<Array1<f64>>,
    /// Number of leading coefficients that belong to the marginal block; the
    /// logslope coefficients follow them in the state's coefficient vector.
    pub p_marginal: usize,
}

impl BmsLogslopeJacobian {
    /// Bundles the designs, offsets, score vector and marginal width. No
    /// validation is done.
    pub fn new(
        marginal_dense: Arc<Array2<f64>>,
        logslope_dense: Arc<Array2<f64>>,
        offset_m: Array1<f64>,
        offset_s: Array1<f64>,
        z: Arc<Array1<f64>>,
        p_marginal: usize,
    ) -> Self {
        Self {
            marginal_dense,
            logslope_dense,
            offset_m,
            offset_s,
            z,
            p_marginal,
        }
    }
}
227
228impl BlockEffectiveJacobian for BmsLogslopeJacobian {
229 fn effective_jacobian_rows(
230 &self,
231 state: &FamilyLinearizationState<'_>,
232 rows: std::ops::Range<usize>,
233 ) -> Result<Array2<f64>, String> {
234 let beta = state.beta;
235 let s = state.probit_frailty_scale;
236 let p_m = self.p_marginal;
237 let p_m_use = p_m.min(beta.len());
238 let beta_m = &beta[..p_m_use];
239 let beta_s_raw = if beta.len() > p_m {
240 &beta[p_m..]
241 } else {
242 &[][..]
243 };
244 let p_s_block = self.logslope_dense.ncols();
245 let p_s_use = p_s_block.min(beta_s_raw.len());
246 let beta_s = &beta_s_raw[..p_s_use];
247 let n = self.logslope_dense.nrows();
248 let rows = rows.start.min(n)..rows.end.min(n);
249
250 let mut out = Array2::<f64>::zeros((rows.end - rows.start, p_s_block));
263 for i in rows.clone() {
264 let q_i = self.offset_m[i]
265 + self
266 .marginal_dense
267 .row(i)
268 .slice(ndarray::s![..p_m_use])
269 .dot(&ArrayView1::from(beta_m));
270 let g_i = self.offset_s[i]
271 + self
272 .logslope_dense
273 .row(i)
274 .slice(ndarray::s![..p_s_use])
275 .dot(&ArrayView1::from(beta_s));
276 let sg = s * g_i;
277 let c_i = (1.0 + sg * sg).sqrt();
278 let z_i = self.z[i];
279 let factor = q_i * s * s * g_i / c_i + s * z_i;
281 let g_row = self.logslope_dense.row(i);
283 out.row_mut(i - rows.start)
284 .assign(&g_row.mapv(|x| factor * x));
285 }
286 Ok(out)
287 }
288
289 fn n_outputs(&self) -> usize {
290 1
291 }
292}
293
294pub(crate) fn widen_marginal_dense_with_influence(
304 marginal_dense: &Arc<Array2<f64>>,
305 influence_columns: Option<&Array2<f64>>,
306) -> Result<Arc<Array2<f64>>, String> {
307 let Some(z_infl) = influence_columns else {
308 return Ok(Arc::clone(marginal_dense));
309 };
310 let n = marginal_dense.nrows();
311 if z_infl.nrows() != n {
312 return Err(format!(
313 "influence block: residualised columns have {} rows, marginal design has {n}",
314 z_infl.nrows()
315 ));
316 }
317 let p_m = marginal_dense.ncols();
318 let p1 = z_infl.ncols();
319 let mut widened = Array2::<f64>::zeros((n, p_m + p1));
320 widened
321 .slice_mut(s![.., ..p_m])
322 .assign(marginal_dense.as_ref());
323 widened.slice_mut(s![.., p_m..]).assign(z_infl);
324 Ok(Arc::new(widened))
325}
326
/// Relative tolerance used by `reduced_logslope_transform_effective`.
///
/// It has two roles:
/// - scaled by the largest diagonal of the effective marginal Gram, it sets
///   the ridge added to that Gram (floored at `f64::EPSILON`);
/// - scaled by the largest diagonal of the effective logslope Gram, it is the
///   eigenvalue cutoff for keeping Schur-complement directions.
pub(crate) const LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL: f64 = 1.0e-6;
335
/// Column reduction of the logslope design onto its effective-identifiable
/// directions.
#[derive(Debug, Clone)]
pub(super) struct ReducedLogslopeReparam {
    /// `p_g × r` basis. Its rows index the original logslope columns and its
    /// columns the `r` retained directions (ordered by decreasing Schur
    /// eigenvalue when built by `reduced_logslope_transform_effective`).
    transform: Array2<f64>,
}
393
394impl ReducedLogslopeReparam {
395 #[inline]
397 pub(super) fn original_cols(&self) -> usize {
398 self.transform.nrows()
399 }
400
401 #[inline]
403 pub(super) fn reduced_cols(&self) -> usize {
404 self.transform.ncols()
405 }
406
407 pub(super) fn recover_original_logslope_beta(
411 &self,
412 beta_reduced: &Array1<f64>,
413 ) -> Result<Array1<f64>, String> {
414 if beta_reduced.len() != self.reduced_cols() {
415 return Err(format!(
416 "reduced logslope reparam: β' length ({}) != reduced width ({})",
417 beta_reduced.len(),
418 self.reduced_cols()
419 ));
420 }
421 Ok(self.transform.dot(beta_reduced))
422 }
423}
424
425fn build_reduced_logslope_reparam(
433 marginal_design: &TermCollectionDesign,
434 logslope_design: &TermCollectionDesign,
435 z: &Array1<f64>,
436 row_metric: &Array1<f64>,
437 marginal_offset: &Array1<f64>,
438 logslope_offset: &Array1<f64>,
439 marginal_baseline: f64,
440 logslope_baseline: f64,
441 probit_scale: f64,
442) -> Result<Option<ReducedLogslopeReparam>, String> {
443 let marginal = marginal_design
444 .design
445 .try_to_dense_arc("build_reduced_logslope_reparam::marginal")?;
446 let logslope = logslope_design
447 .design
448 .try_to_dense_arc("build_reduced_logslope_reparam::logslope")?;
449 let n = marginal.nrows();
450 if logslope.nrows() != n
451 || z.len() != n
452 || row_metric.len() != n
453 || marginal_offset.len() != n
454 || logslope_offset.len() != n
455 {
456 return Err(format!(
457 "reduced logslope reparam row mismatch: marginal={}, logslope={}, z={}, row_metric={}, marginal_offset={}, logslope_offset={}",
458 marginal.nrows(),
459 logslope.nrows(),
460 z.len(),
461 row_metric.len(),
462 marginal_offset.len(),
463 logslope_offset.len(),
464 ));
465 }
466 let p_m = marginal.ncols();
467 let p_g = logslope.ncols();
468 if p_m == 0 || p_g == 0 {
469 return Ok(None);
470 }
471 if !marginal_baseline.is_finite()
472 || !logslope_baseline.is_finite()
473 || !probit_scale.is_finite()
474 || probit_scale <= 0.0
475 || z.iter().any(|v| !v.is_finite())
476 || row_metric.iter().any(|v| !v.is_finite() || *v < 0.0)
477 || marginal_offset.iter().any(|v| !v.is_finite())
478 || logslope_offset.iter().any(|v| !v.is_finite())
479 {
480 return Err(
481 "reduced logslope reparam requires finite pilot geometry and finite non-negative row metric"
482 .to_string(),
483 );
484 }
485
486 match reduced_logslope_transform_effective(
497 marginal.view(),
498 logslope.view(),
499 z,
500 row_metric,
501 marginal_offset,
502 logslope_offset,
503 marginal_baseline,
504 logslope_baseline,
505 probit_scale,
506 )? {
507 Some(transform) => Ok(Some(ReducedLogslopeReparam { transform })),
508 None => Ok(None),
509 }
510}
511
512pub(crate) fn reduced_logslope_transform_effective(
531 marginal: ArrayView2<'_, f64>,
532 logslope: ArrayView2<'_, f64>,
533 z: &Array1<f64>,
534 row_metric: &Array1<f64>,
535 marginal_offset: &Array1<f64>,
536 logslope_offset: &Array1<f64>,
537 marginal_baseline: f64,
538 logslope_baseline: f64,
539 probit_scale: f64,
540) -> Result<Option<Array2<f64>>, String> {
541 let n = marginal.nrows();
542 let p_m = marginal.ncols();
543 let p_g = logslope.ncols();
544 if p_m == 0 || p_g == 0 {
545 return Ok(None);
546 }
547
548 let mut m_eff = Array2::<f64>::zeros((n, p_m));
550 let mut g_eff = Array2::<f64>::zeros((n, p_g));
551 for i in 0..n {
552 let q_i = marginal_offset[i] + marginal_baseline;
553 let g_i = logslope_offset[i] + logslope_baseline;
554 let sg = probit_scale * g_i;
555 let c_i = (1.0 + sg * sg).sqrt();
556 let f_i = q_i * probit_scale * probit_scale * g_i / c_i + probit_scale * z[i];
557 for j in 0..p_m {
558 m_eff[[i, j]] = c_i * marginal[[i, j]];
559 }
560 for j in 0..p_g {
561 g_eff[[i, j]] = f_i * logslope[[i, j]];
562 }
563 }
564
565 let c_gram = fast_xt_diag_x(&g_eff, row_metric);
568 let energy_scale = (0..p_g).map(|i| c_gram[[i, i]]).fold(0.0_f64, f64::max);
569 if !energy_scale.is_finite() || energy_scale <= 0.0 {
570 return Ok(None);
571 }
572
573 let mut a_gram = fast_xt_diag_x(&m_eff, row_metric);
577 let a_scale = (0..p_m).map(|i| a_gram[[i, i]]).fold(0.0_f64, f64::max);
578 let a_ridge = (a_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL).max(f64::EPSILON);
579 for i in 0..p_m {
580 a_gram[[i, i]] += a_ridge;
581 }
582
583 let b_cross = gam_linalg::faer_ndarray::fast_xt_diag_y(&m_eff, row_metric, &g_eff);
585 let a_view = gam_linalg::faer_ndarray::FaerArrayView::new(&a_gram);
586 let a_factor =
587 gam_linalg::faer_ndarray::factorize_symmetricwith_fallback(a_view.as_ref(), Side::Lower)
588 .map_err(|e| {
589 format!(
590 "reduced logslope reparam: effective marginal Gram factorization failed: {e}"
591 )
592 })?;
593 let b_view = gam_linalg::faer_ndarray::FaerArrayView::new(&b_cross);
594 let solved = a_factor.solve(b_view.as_ref()); let a_inv_b = Array2::from_shape_fn((p_m, p_g), |(i, j)| solved[(i, j)]);
596 let schur = fast_atb(&b_cross, &a_inv_b); let mut stt = &c_gram - &schur;
598 stt = (&stt + &stt.t()) * 0.5;
599 if stt.iter().any(|v| !v.is_finite()) {
600 return Err(
601 "reduced logslope reparam: effective Schur Gram produced non-finite entries"
602 .to_string(),
603 );
604 }
605
606 let (evals, evecs) = stt
607 .eigh(Side::Lower)
608 .map_err(|e| format!("reduced logslope reparam: eigendecomposition failed: {e:?}"))?;
609 let tol = energy_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL;
613 let mut kept: Vec<usize> = (0..evals.len()).filter(|&i| evals[i] > tol).collect();
614 kept.sort_by(|&a, &b| {
615 evals[b]
616 .partial_cmp(&evals[a])
617 .unwrap_or(std::cmp::Ordering::Equal)
618 });
619 let r = kept.len();
620 if r == p_g || r == 0 {
624 return Ok(None);
625 }
626 let mut transform = Array2::<f64>::zeros((p_g, r));
627 for (out_col, &src) in kept.iter().enumerate() {
628 transform.column_mut(out_col).assign(&evecs.column(src));
629 }
630 if transform.iter().any(|v| !v.is_finite()) {
631 return Err(
632 "reduced logslope reparam: reduced transform produced non-finite entries".to_string(),
633 );
634 }
635 Ok(Some(transform))
636}
637
638fn reparameterize_logslope_design_reduced(
645 logslope_design: &TermCollectionDesign,
646 reparam: &ReducedLogslopeReparam,
647) -> Result<TermCollectionDesign, String> {
648 let g = logslope_design
649 .design
650 .try_to_dense_arc("reparameterize_logslope_design_reduced::logslope")?;
651 let p_g = g.ncols();
652 if p_g != reparam.original_cols() {
653 return Err(format!(
654 "reduced logslope reparam width mismatch: design has {p_g} cols, transform expects {}",
655 reparam.original_cols()
656 ));
657 }
658 let t = &reparam.transform;
659 let r = reparam.reduced_cols();
660 let g_reduced = fast_ab(&g, t);
662
663 let mut new_penalties: Vec<gam_terms::smooth::BlockwisePenalty> =
666 Vec::with_capacity(logslope_design.penalties.len());
667 let mut new_nullspace_dims: Vec<usize> = Vec::with_capacity(logslope_design.penalties.len());
668 for bp in &logslope_design.penalties {
669 let mut full = Array2::<f64>::zeros((p_g, p_g));
670 full.slice_mut(s![bp.col_range.clone(), bp.col_range.clone()])
671 .assign(&bp.local);
672 let st = fast_ab(&full, t); let mut s_reduced = fast_atb(t, &st); s_reduced = (&s_reduced + &s_reduced.t()) * 0.5;
676 let (evals, _) = s_reduced
678 .eigh(Side::Lower)
679 .map_err(|e| format!("reduced logslope penalty eigendecomposition failed: {e:?}"))?;
680 let max_eval = evals.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
681 let pen_tol = (max_eval * 1.0e-12).max(f64::EPSILON);
682 let rank = evals.iter().filter(|&&v| v.abs() > pen_tol).count();
683 let nullspace_dim = r.saturating_sub(rank);
684 new_penalties.push(gam_terms::smooth::BlockwisePenalty::new(0..r, s_reduced));
685 new_nullspace_dims.push(nullspace_dim);
686 }
687
688 let new_design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(g_reduced));
689 Ok(TermCollectionDesign {
695 design: new_design,
696 penalties: new_penalties,
697 nullspace_dims: new_nullspace_dims,
698 penaltyinfo: Vec::new(),
699 dropped_penaltyinfo: Vec::new(),
700 coefficient_lower_bounds: None,
701 linear_constraints: None,
702 intercept_range: 0..0,
703 linear_ranges: Vec::new(),
704 random_effect_ranges: Vec::new(),
705 random_effect_levels: Vec::new(),
706 smooth: gam_terms::smooth::SmoothDesign {
707 term_designs: Vec::new(),
708 penalties: Vec::new(),
709 nullspace_dims: Vec::new(),
710 penaltyinfo: Vec::new(),
711 dropped_penaltyinfo: Vec::new(),
712 terms: Vec::new(),
713 coefficient_lower_bounds: None,
714 linear_constraints: None,
715 },
716 })
717}
718
719pub(crate) fn marginal_penalties_with_influence_ridge(
740 design: &TermCollectionDesign,
741 rho_marginal: &Array1<f64>,
742 influence_columns: Option<&Array2<f64>>,
743 influence_ridge_log_lambda: f64,
744) -> Result<(Vec<PenaltyMatrix>, Vec<usize>, Array1<f64>), String> {
745 let p_m = design.design.ncols();
746 let p1 = influence_columns.map(|z| z.ncols()).unwrap_or(0);
747 let total_dim = p_m + p1;
748 let mut penalties: Vec<PenaltyMatrix> = design
751 .penalties
752 .iter()
753 .map(|bp| bp.to_penalty_matrix(total_dim))
754 .collect();
755 let mut nullspace_dims = design.nullspace_dims.clone();
756 let mut log_lambdas = rho_marginal.to_vec();
757
758 if p1 > 0 {
762 penalties.push(
763 PenaltyMatrix::Blockwise {
764 local: Array2::<f64>::eye(p1),
765 col_range: p_m..total_dim,
766 total_dim,
767 }
768 .with_fixed_log_lambda(influence_ridge_log_lambda),
769 );
770 nullspace_dims.push(0);
771 log_lambdas.push(influence_ridge_log_lambda);
772 }
773
774 Ok((penalties, nullspace_dims, Array1::from_vec(log_lambdas)))
775}
776
777pub(crate) fn widen_marginal_beta_hint(
780 beta_hint: Option<Array1<f64>>,
781 p_marginal_widened: usize,
782) -> Option<Array1<f64>> {
783 beta_hint.map(|hint| {
784 if hint.len() == p_marginal_widened {
785 hint
786 } else {
787 let mut widened = Array1::<f64>::zeros(p_marginal_widened);
788 let copy = hint.len().min(p_marginal_widened);
789 widened
790 .slice_mut(s![..copy])
791 .assign(&hint.slice(s![..copy]));
792 widened
793 }
794 })
795}
796
797fn marginal_fitted_eta_sup_norm(design: &TermCollectionDesign, masked_beta: &Array1<f64>) -> f64 {
807 let x = &design.design;
808 let n = x.nrows();
809 if n == 0 || x.ncols() == 0 {
810 return 0.0;
811 }
812 let mut sup = 0.0_f64;
813 for row in 0..n {
814 let eta = x.dot_row_view(row, masked_beta.view());
815 if eta.is_finite() {
816 sup = sup.max(eta.abs());
817 }
818 }
819 sup
820}
821
822fn marginal_design_beta(
825 design: &TermCollectionDesign,
826 block_beta: ArrayView1<'_, f64>,
827) -> Array1<f64> {
828 let ncols = design.design.ncols();
829 let mut masked = Array1::<f64>::zeros(ncols);
830 let copy = ncols.min(block_beta.len());
831 masked
832 .slice_mut(s![..copy])
833 .assign(&block_beta.slice(s![..copy]));
834 masked
835}
836
837fn mask_parametric_columns(
843 design: &TermCollectionDesign,
844 spec: &TermCollectionSpec,
845 full: &Array1<f64>,
846) -> Array1<f64> {
847 let ncols = design.design.ncols();
848 let mut masked = Array1::<f64>::zeros(ncols);
849 if design.intercept_range.len() == 1 {
850 let idx = design.intercept_range.start;
851 if idx < ncols {
852 masked[idx] = full[idx];
853 }
854 }
855 for (linear, (_, range)) in spec.linear_terms.iter().zip(design.linear_ranges.iter()) {
856 if linear.double_penalty {
857 continue;
858 }
859 for col in range.clone() {
860 if col < ncols {
861 masked[col] = full[col];
862 }
863 }
864 }
865 masked
866}
867
868pub(crate) fn bernoulli_marginal_slope_runaway_error_from_beta(
879 block_beta: ArrayView1<'_, f64>,
880 design: &TermCollectionDesign,
881 spec: &TermCollectionSpec,
882 inner_converged: bool,
883 eval_label: &str,
884) -> Option<String> {
885 let full_beta = marginal_design_beta(design, block_beta);
886 let parametric_beta = mask_parametric_columns(design, spec, &full_beta);
887
888 let eta_parametric = marginal_fitted_eta_sup_norm(design, ¶metric_beta);
889 let eta_full = marginal_fitted_eta_sup_norm(design, &full_beta);
890
891 let (eta_inf, explanation) = if eta_parametric >= BMS_PROBIT_SEPARATION_ETA_INF {
892 (
893 eta_parametric,
894 "an unpenalized parametric marginal direction has no stable finite probit optimum and its fitted predictor has run to the probit underflow scale",
895 )
896 } else if eta_full >= BMS_PROBIT_SEPARATION_ETA_INF {
897 (
898 eta_full,
899 "a marginal direction is trading off against the logslope surface; this is the under-constrained marginal/logslope coupling that appears when the score is correlated with the shared surface covariates",
900 )
901 } else {
902 return None;
906 };
907
908 let inner_status = if inner_converged {
909 "the inner solve reached a KKT certificate at this separation-scale predictor"
910 } else {
911 "the inner solve failed while already carrying a separation-scale predictor"
912 };
913 let beta_abs = full_beta
915 .iter()
916 .copied()
917 .filter(|v| v.is_finite())
918 .fold(0.0_f64, |acc, v| acc.max(v.abs()));
919
920 Some(format!(
921 "bernoulli marginal-slope probit marginal/logslope runaway detected in block \
922 'marginal_surface' during {eval_label}: the fitted marginal predictor has \
923 |η|∞={eta_inf:.3e} (numerical-degeneracy threshold \
924 {BMS_PROBIT_SEPARATION_ETA_INF:.1}; raw |β|∞={beta_abs:.3e} is reported for \
925 context only and does not gate this diagnostic). The joint design is \
926 identifiable; {explanation}. {inner_status}. The robust Jeffreys curvature \
927 path is already installed for this fit, so this diagnostic means the current \
928 coupled surface still drives the linear predictor to the probit underflow \
929 scale rather than a request for an external bias-reduction prior. Reduce or \
930 reparameterize the coupled marginal/logslope surface, or use a \
931 lower-dimensional logslope interaction. This is not a \
932 Matérn/Duchon polynomial-nullspace or cross-block gauge-priority \
933 failure."
934 ))
935}
936
937pub(crate) fn bernoulli_marginal_slope_runaway_error(
938 warm_start: &CustomFamilyWarmStart,
939 design: &TermCollectionDesign,
940 spec: &TermCollectionSpec,
941 inner_converged: bool,
942 eval_label: &str,
943) -> Option<String> {
944 let block_beta = warm_start.block_beta_view(0)?;
945 bernoulli_marginal_slope_runaway_error_from_beta(
946 block_beta,
947 design,
948 spec,
949 inner_converged,
950 eval_label,
951 )
952}
953
954#[cfg(test)]
955mod runaway_tests {
956 use super::*;
957 use gam_linalg::faer_ndarray::{FaerArrayView, factorize_symmetricwith_fallback, fast_xt_diag_y};
958 use gam_terms::smooth::{LinearCoefficientGeometry, LinearTermSpec};
959
    /// Test-only helper: a quadratic penalty measuring how much of the
    /// effective marginal design lies in the weighted span of the effective
    /// logslope design at a pilot point.
    ///
    /// It uses the same pilot scaling as `reduced_logslope_transform_effective`:
    /// marginal rows are scaled by `c_i` and logslope rows by
    /// `q_i s² g_i / c_i + s z_i`. Under `W = diag(row_metric)`, it regresses
    /// the effective marginal columns on the effective logslope columns with a
    /// ridged Gram solve, then returns the symmetrized `M_effᵀ W P`, where `P`
    /// is the fitted projection.
    ///
    /// Returns `Ok(None)` when:
    /// - either design has no columns;
    /// - every effective logslope entry is within `f64::EPSILON` of zero;
    /// - the Gram scale is not positive and finite;
    /// - the penalty's max |entry| is non-finite or ≤ 1e-12.
    pub(crate) fn marginal_logslope_overlap_penalty(
        marginal_design: &DesignMatrix,
        logslope_design: &DesignMatrix,
        z: &Array1<f64>,
        row_metric: &Array1<f64>,
        marginal_offset: &Array1<f64>,
        logslope_offset: &Array1<f64>,
        marginal_baseline: f64,
        logslope_baseline: f64,
        probit_scale: f64,
    ) -> Result<Option<Array2<f64>>, String> {
        let marginal =
            marginal_design.try_to_dense_arc("marginal_logslope_overlap_penalty::marginal")?;
        let logslope =
            logslope_design.try_to_dense_arc("marginal_logslope_overlap_penalty::logslope")?;
        let n = marginal.nrows();
        if logslope.nrows() != n
            || z.len() != n
            || row_metric.len() != n
            || marginal_offset.len() != n
            || logslope_offset.len() != n
        {
            return Err(format!(
                "marginal/logslope overlap penalty row mismatch: marginal={}, logslope={}, z={}, row_metric={}, marginal_offset={}, logslope_offset={}",
                marginal.nrows(),
                logslope.nrows(),
                z.len(),
                row_metric.len(),
                marginal_offset.len(),
                logslope_offset.len(),
            ));
        }
        let p_m = marginal.ncols();
        let p_g = logslope.ncols();
        if p_m == 0 || p_g == 0 {
            return Ok(None);
        }
        if !marginal_baseline.is_finite()
            || !logslope_baseline.is_finite()
            || !probit_scale.is_finite()
            || probit_scale <= 0.0
            || z.iter().any(|v| !v.is_finite())
            || row_metric.iter().any(|v| !v.is_finite() || *v < 0.0)
            || marginal_offset.iter().any(|v| !v.is_finite())
            || logslope_offset.iter().any(|v| !v.is_finite())
        {
            return Err(
                "marginal/logslope overlap penalty requires finite pilot geometry and finite non-negative row metric"
                    .to_string(),
            );
        }

        // Pilot-point effective designs (same scaling as the production
        // reduction).
        let mut marginal_effective = Array2::<f64>::zeros((n, p_m));
        let mut effective_logslope = Array2::<f64>::zeros((n, p_g));
        for i in 0..n {
            let q_i = marginal_offset[i] + marginal_baseline;
            let g_i = logslope_offset[i] + logslope_baseline;
            let sg = probit_scale * g_i;
            let c_i = (1.0 + sg * sg).sqrt();
            let logslope_factor =
                q_i * probit_scale * probit_scale * g_i / c_i + probit_scale * z[i];
            for j in 0..p_m {
                marginal_effective[[i, j]] = c_i * marginal[[i, j]];
            }
            for j in 0..p_g {
                effective_logslope[[i, j]] = logslope_factor * logslope[[i, j]];
            }
        }
        if effective_logslope.iter().all(|v| v.abs() <= f64::EPSILON) {
            return Ok(None);
        }

        // Ridged weighted Gram of the effective logslope columns; the ridge is
        // relative to the largest diagonal.
        let mut gram = fast_xt_diag_x(&effective_logslope, row_metric);
        let gram_scale = gram.diag().iter().copied().fold(0.0_f64, f64::max);
        if !gram_scale.is_finite() || gram_scale <= 0.0 {
            return Ok(None);
        }
        let projection_ridge = (gram_scale * 1.0e-10).max(f64::EPSILON);
        for i in 0..p_g {
            gram[[i, i]] += projection_ridge;
        }
        // Weighted least-squares coefficients of the marginal columns on the
        // logslope columns, then the fitted projection of the marginal design.
        let cross = fast_xt_diag_y(&effective_logslope, row_metric, &marginal_effective);
        let gram_view = FaerArrayView::new(&gram);
        let factor = factorize_symmetricwith_fallback(gram_view.as_ref(), Side::Lower)
            .map_err(|e| format!("marginal/logslope overlap Gram factorization failed: {e}"))?;
        let rhsview = FaerArrayView::new(&cross);
        let coeffs_mat = factor.solve(rhsview.as_ref());
        let coeffs = Array2::from_shape_fn((p_g, p_m), |(i, j)| coeffs_mat[(i, j)]);
        let projected_marginal = fast_ab(&effective_logslope, &coeffs);
        let mut penalty = fast_xt_diag_y(&marginal_effective, row_metric, &projected_marginal);
        penalty = (&penalty + &penalty.t()) * 0.5;
        let max_abs = penalty.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
        if !max_abs.is_finite() || max_abs <= 1.0e-12 {
            return Ok(None);
        }
        Ok(Some(penalty))
    }
1062
    #[test]
    /// Raw logslope columns are [1, 2, 3] and [1, 2, 9], and neither is
    /// constant. With zero offsets and baselines and `probit_scale = 1`, the
    /// effective scaling is `z_i`:
    /// - z∘col0 = [1, 1, 1], which equals the constant marginal column;
    /// - z∘col1 = [1, 1, 3].
    ///
    /// Only the effective (score-weighted) audit sees the confound, so exactly
    /// one direction should survive.
    pub(crate) fn effective_reduction_drops_score_weighted_confound_raw_audit_misses() {
        let m = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
        let g = Array2::<f64>::from_shape_vec((3, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 9.0]).unwrap();
        let z = Array1::from_vec(vec![1.0, 0.5, 1.0 / 3.0]);
        let w = Array1::<f64>::ones(3);
        let zero = Array1::<f64>::zeros(3);

        let reparam = reduced_logslope_transform_effective(
            m.view(),
            g.view(),
            &z,
            &w,
            &zero,
            &zero,
            0.0,
            0.0,
            1.0,
        )
        .expect("effective reduction must succeed")
        .expect("effective audit must reduce the score-weighted confound (raw audit would not)");
        assert_eq!(
            reparam.ncols(),
            1,
            "exactly one effective-identifiable logslope direction should survive"
        );

        // Rebuild the effective logslope design (z∘g at this pilot point) and
        // check that the kept direction is not the constant (confounded) one.
        let g_eff = {
            let mut e = Array2::<f64>::zeros((3, 2));
            for i in 0..3 {
                for j in 0..2 {
                    e[[i, j]] = z[i] * g[[i, j]];
                }
            }
            e
        };
        let img = g_eff.dot(&reparam.column(0));
        let mean = img.iter().sum::<f64>() / 3.0;
        let var = img.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / 3.0;
        assert!(
            var > 1.0e-6,
            "kept direction must be the identifiable (non-constant) effective column, var={var}"
        );
    }
1120
    #[test]
    /// The single logslope column z∘[1, 2, 3] = [1, 1, 1] coincides with the
    /// constant marginal column. No direction survives, and the transform must
    /// report `None` (keep the raw design) rather than a 0-width basis.
    pub(crate) fn effective_reduction_fully_confounded_single_column_returns_none() {
        let m = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
        let g = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let z = Array1::from_vec(vec![1.0, 0.5, 1.0 / 3.0]);
        let w = Array1::<f64>::ones(3);
        let zero = Array1::<f64>::zeros(3);
        let reparam = reduced_logslope_transform_effective(
            m.view(),
            g.view(),
            &z,
            &w,
            &zero,
            &zero,
            0.0,
            0.0,
            1.0,
        )
        .expect("effective reduction must succeed");
        assert!(
            reparam.is_none(),
            "fully effective-confounded logslope must keep raw design (None), not a 0-width block"
        );
    }
1150
    #[test]
    /// The effective logslope columns are z∘e1 = [1, 0, 0] and z∘e2 = [0, 2, 0].
    /// No combination of them is constant, so both directions survive
    /// (`r == p_g`) and no reduction is returned.
    pub(crate) fn effective_reduction_no_confound_returns_none() {
        let m = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
        let g = Array2::<f64>::from_shape_vec((3, 2), vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();
        let z = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let w = Array1::<f64>::ones(3);
        let zero = Array1::<f64>::zeros(3);
        let reparam = reduced_logslope_transform_effective(
            m.view(),
            g.view(),
            &z,
            &w,
            &zero,
            &zero,
            0.0,
            0.0,
            1.0,
        )
        .expect("effective reduction must succeed");
        assert!(
            reparam.is_none(),
            "no effective confound ⇒ no reduction (raw design kept unchanged)"
        );
    }
1178
    #[test]
    /// Pins the rho dimension reported by `joint_setup` for empty term
    /// collections with the given arguments to 6.
    /// NOTE(review): `joint_setup` and the meaning of its positional arguments
    /// are outside this chunk.
    pub(crate) fn spatial_joint_setup_counts_only_learned_penalties_in_rho() {
        let data = Array2::<f64>::zeros((3, 1));
        let empty_terms = TermCollectionSpec {
            linear_terms: Vec::new(),
            random_effect_terms: Vec::new(),
            smooth_terms: Vec::new(),
        };
        let setup = joint_setup(
            data.view(),
            &empty_terms,
            &empty_terms,
            2,
            3,
            &[0.4],
            &SpatialLengthScaleOptimizationOptions::default(),
        );

        assert_eq!(
            setup.rho_dim(),
            6,
            "BMS spatial setup rho must contain only learned marginal/logslope/auxiliary penalties; fixed physical ridges are carried by PenaltyMatrix::Fixed"
        );
    }
1203
    #[test]
    /// With a constant logslope column and zero pilot offsets, the effective
    /// logslope column is `z = [0, 1, 2, 3]`. This equals the marginal column,
    /// so the marginal signal lies entirely in that span and the penalty is
    /// `Σ z² = 14` (up to the tiny projection ridge).
    pub(crate) fn overlap_penalty_targets_score_weighted_logslope_span() {
        let marginal = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
            Array2::from_shape_vec((4, 1), vec![0.0, 1.0, 2.0, 3.0]).unwrap(),
        ));
        let logslope = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
            Array2::from_shape_vec((4, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap(),
        ));
        let z = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
        let row_metric = Array1::ones(4);
        let offsets = Array1::zeros(4);

        let penalty = marginal_logslope_overlap_penalty(
            &marginal,
            &logslope,
            &z,
            &row_metric,
            &offsets,
            &offsets,
            0.0,
            0.0,
            1.0,
        )
        .expect("overlap penalty should build")
        .expect("marginal signal lies in the pilot logslope Jacobian span");

        assert_eq!(penalty.dim(), (1, 1));
        assert!((penalty[[0, 0]] - 14.0).abs() < 1.0e-6);
    }
1233
    #[test]
    /// With `z = 1`, the effective logslope column is constant. The marginal
    /// column [-1, 1, -1, 1] is orthogonal to it under unit weights, so the
    /// projected penalty is zero and the helper returns `None`.
    pub(crate) fn overlap_penalty_skips_weight_orthogonal_channels() {
        let marginal = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
            Array2::from_shape_vec((4, 1), vec![-1.0, 1.0, -1.0, 1.0]).unwrap(),
        ));
        let logslope = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
            Array2::from_shape_vec((4, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap(),
        ));
        let z = Array1::ones(4);
        let row_metric = Array1::ones(4);
        let offsets = Array1::zeros(4);

        let penalty = marginal_logslope_overlap_penalty(
            &marginal,
            &logslope,
            &z,
            &row_metric,
            &offsets,
            &offsets,
            0.0,
            0.0,
            1.0,
        )
        .expect("overlap penalty should build");

        assert!(penalty.is_none());
    }
1261
    /// Test fixture: wraps a dense matrix as a `TermCollectionDesign`.
    ///
    /// Only the given intercept and linear ranges are set; all penalties,
    /// smooth metadata and random-effect information are empty.
    fn dense_marginal_design(
        x: Array2<f64>,
        intercept_range: std::ops::Range<usize>,
        linear_ranges: Vec<(String, std::ops::Range<usize>)>,
    ) -> TermCollectionDesign {
        TermCollectionDesign {
            design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(x)),
            penalties: Vec::new(),
            nullspace_dims: Vec::new(),
            penaltyinfo: Vec::new(),
            dropped_penaltyinfo: Vec::new(),
            coefficient_lower_bounds: None,
            linear_constraints: None,
            intercept_range,
            linear_ranges,
            random_effect_ranges: Vec::new(),
            random_effect_levels: Vec::new(),
            smooth: gam_terms::smooth::SmoothDesign {
                term_designs: Vec::new(),
                penalties: Vec::new(),
                nullspace_dims: Vec::new(),
                penaltyinfo: Vec::new(),
                dropped_penaltyinfo: Vec::new(),
                terms: Vec::new(),
                coefficient_lower_bounds: None,
                linear_constraints: None,
            },
        }
    }
1298
    /// Test fixture: a single-column linear term without `double_penalty`
    /// (so `mask_parametric_columns` treats it as unpenalized parametric),
    /// using default coefficient geometry and no coefficient bounds.
    fn linear_term(name: &str, feature_col: usize) -> LinearTermSpec {
        LinearTermSpec {
            name: name.to_string(),
            feature_col,
            feature_cols: vec![feature_col],
            categorical_levels: vec![],
            double_penalty: false,
            coefficient_geometry: LinearCoefficientGeometry::default(),
            coefficient_min: None,
            coefficient_max: None,
        }
    }
1311
    /// Test fixture: a term-collection spec with no linear, random-effect or
    /// smooth terms.
    fn empty_spec() -> TermCollectionSpec {
        TermCollectionSpec {
            linear_terms: Vec::new(),
            random_effect_terms: Vec::new(),
            smooth_terms: Vec::new(),
        }
    }
1319
    #[test]
    /// Two identical all-ones columns with β = [60, -60] give η = 0 on every
    /// row. The guard is driven by the fitted |η|∞, not by |β|∞, so it must
    /// stay silent.
    pub(crate) fn runaway_guard_silent_when_huge_beta_cancels_to_bounded_eta() {
        let x = Array2::<f64>::from_shape_vec((4, 2), vec![1.0; 8]).unwrap();
        let design = dense_marginal_design(x, 0..0, Vec::new());
        let beta = Array1::from_vec(vec![60.0, -60.0]);

        let msg = bernoulli_marginal_slope_runaway_error_from_beta(
            beta.view(),
            &design,
            &empty_spec(),
            true,
            "regression-fixture",
        );
        assert!(
            msg.is_none(),
            "huge cancelling β with bounded fitted η must NOT trip the runaway guard; got {msg:?}"
        );
    }
1345
    #[test]
    /// With no parametric terms, the full-block predictor η = 40 exceeds the
    /// threshold of 35. The message must use the coupling explanation and
    /// report |η|∞ formatted with `{:.3e}` ("4.000e1").
    pub(crate) fn runaway_guard_fires_when_fitted_eta_exceeds_threshold() {
        let x = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
        let design = dense_marginal_design(x, 0..0, Vec::new());
        let beta = Array1::from_vec(vec![40.0]);

        let msg = bernoulli_marginal_slope_runaway_error_from_beta(
            beta.view(),
            &design,
            &empty_spec(),
            true,
            "separation-fixture",
        )
        .expect("fitted |η|∞=40 ≥ 35 must trip the runaway guard");

        assert!(msg.contains("marginal/logslope runaway"));
        assert!(msg.contains("|η|∞"));
        assert!(msg.contains("4.000e1"));
        assert!(msg.contains("score is correlated with the shared surface covariates"));
        assert!(msg.contains("not a Matérn/Duchon polynomial-nullspace"));
        assert!(msg.contains("KKT certificate"));
    }
1371
1372 #[test]
1376 pub(crate) fn runaway_guard_names_unpenalized_parametric_direction_via_fitted_eta() {
1377 let x = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
1378 let design = dense_marginal_design(x, 0..0, vec![("sex".to_string(), 0..1)]);
1379 let mut spec = empty_spec();
1380 spec.linear_terms.push(linear_term("sex", 0));
1381 let beta = Array1::from_vec(vec![41.0]);
1382
1383 let msg = bernoulli_marginal_slope_runaway_error_from_beta(
1384 beta.view(),
1385 &design,
1386 &spec,
1387 true,
1388 "parametric-fixture",
1389 )
1390 .expect("parametric fitted |η|∞=41 ≥ 35 must trip the runaway guard");
1391
1392 assert!(msg.contains("unpenalized parametric marginal direction"));
1393 assert!(msg.contains("|η|∞"));
1394 assert!(msg.contains("robust Jeffreys curvature path is already installed"));
1395 assert!(msg.contains("not a Matérn/Duchon polynomial-nullspace"));
1396 }
1397
1398 #[test]
1402 pub(crate) fn runaway_guard_silent_for_nonconverged_but_bounded_eta() {
1403 let x = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
1404 let design = dense_marginal_design(x, 0..0, Vec::new());
1405 let beta = Array1::from_vec(vec![5.0]);
1406
1407 let msg = bernoulli_marginal_slope_runaway_error_from_beta(
1408 beta.view(),
1409 &design,
1410 &empty_spec(),
1411 false,
1412 "nonconverged-fixture",
1413 );
1414 assert!(
1415 msg.is_none(),
1416 "bounded fitted η must not raise the separation error even when the inner solve did not converge; got {msg:?}"
1417 );
1418 }
1419
1420 #[test]
1423 pub(crate) fn runaway_guard_fires_for_nonconverged_separating_eta() {
1424 let x = Array2::<f64>::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
1425 let design = dense_marginal_design(x, 0..0, Vec::new());
1426 let beta = Array1::from_vec(vec![50.0]);
1427
1428 let msg = bernoulli_marginal_slope_runaway_error_from_beta(
1429 beta.view(),
1430 &design,
1431 &empty_spec(),
1432 false,
1433 "nonconverged-separating-fixture",
1434 )
1435 .expect("separating |η|∞ at non-convergence must still trip the guard");
1436
1437 assert!(msg.contains(
1438 "the inner solve failed while already carrying a separation-scale predictor"
1439 ));
1440 }
1441
1442 #[test]
1455 pub(crate) fn bms_block_jacobians_self_compute_at_audit_empty_beta_nonzero_logslope_baseline() {
1456 use std::sync::Arc;
1457 let n = 4usize;
1458 let marginal =
1459 Arc::new(Array2::<f64>::from_shape_vec((n, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap());
1460 let logslope =
1461 Arc::new(Array2::<f64>::from_shape_vec((n, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap());
1462 let offset_m = Array1::<f64>::zeros(n);
1463 let g_baseline = 0.3_f64;
1466 let offset_s = Array1::<f64>::from_elem(n, g_baseline);
1467 let z = Arc::new(Array1::from_vec(vec![-0.7, 0.2, 0.9, 1.4]));
1468 let s = 1.0_f64;
1469
1470 let beta: Vec<f64> = Vec::new();
1472 let state = FamilyLinearizationState {
1473 beta: &beta,
1474 family_scalars: None,
1475 channel_hessian: None,
1476 probit_frailty_scale: s,
1477 };
1478
1479 let marginal_jac = BmsMarginalJacobian::new(
1480 Arc::clone(&marginal),
1481 Arc::clone(&logslope),
1482 offset_m.clone(),
1483 offset_s.clone(),
1484 1,
1485 );
1486 let j_m = marginal_jac
1487 .effective_jacobian_rows(&state, 0..n)
1488 .expect("BMS marginal Jacobian must self-compute at audit empty β (gam#370)");
1489 let c_expected = (1.0 + (s * g_baseline).powi(2)).sqrt();
1491 assert_eq!(j_m.dim(), (n, 1));
1492 for i in 0..n {
1493 assert!(
1494 (j_m[[i, 0]] - c_expected).abs() < 1e-12,
1495 "marginal J[{i}] = {} != closed-form c_i = {c_expected}",
1496 j_m[[i, 0]]
1497 );
1498 }
1499
1500 let logslope_jac = BmsLogslopeJacobian::new(
1501 Arc::clone(&marginal),
1502 Arc::clone(&logslope),
1503 offset_m,
1504 offset_s,
1505 Arc::clone(&z),
1506 1,
1507 );
1508 let j_s = logslope_jac
1509 .effective_jacobian_rows(&state, 0..n)
1510 .expect("BMS logslope Jacobian must self-compute at audit empty β (gam#370)");
1511 assert_eq!(j_s.dim(), (n, 1));
1514 for i in 0..n {
1515 let expected = s * z[i];
1516 assert!(
1517 (j_s[[i, 0]] - expected).abs() < 1e-12,
1518 "logslope J[{i}] = {} != closed-form factor {expected}",
1519 j_s[[i, 0]]
1520 );
1521 assert!(j_s[[i, 0]].is_finite());
1522 }
1523 }
1524}
1525
1526pub(crate) fn build_marginal_blockspec_bms(
1527 design: &TermCollectionDesign,
1528 baseline: f64,
1529 offset: &Array1<f64>,
1530 rho: Array1<f64>,
1531 beta_hint: Option<Array1<f64>>,
1532 logslope_design: &TermCollectionDesign,
1533 logslope_offset: &Array1<f64>,
1534 logslope_baseline: f64,
1535 p_marginal: usize,
1536 influence_columns: Option<&Array2<f64>>,
1537 influence_ridge_log_lambda: f64,
1538) -> Result<ParameterBlockSpec, String> {
1539 let offset_m = offset + baseline;
1540 let offset_s = logslope_offset + logslope_baseline;
1541 let raw_marginal_dense = design
1542 .design
1543 .try_to_dense_arc("build_marginal_blockspec_bms::marginal")?;
1544 let marginal_dense =
1545 widen_marginal_dense_with_influence(&raw_marginal_dense, influence_columns)?;
1546 let logslope_dense = logslope_design
1547 .design
1548 .try_to_dense_arc("build_marginal_blockspec_bms::logslope")?;
1549 let callback: Arc<dyn BlockEffectiveJacobian> = Arc::new(BmsMarginalJacobian {
1550 marginal_dense: Arc::clone(&marginal_dense),
1551 logslope_dense,
1552 offset_m: offset_m.clone(),
1553 offset_s,
1554 p_marginal,
1555 });
1556 let (penalties, nullspace_dims, initial_log_lambdas) = marginal_penalties_with_influence_ridge(
1557 design,
1558 &rho,
1559 influence_columns,
1560 influence_ridge_log_lambda,
1561 )?;
1562 Ok(ParameterBlockSpec {
1563 name: "marginal_surface".to_string(),
1564 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
1565 (*marginal_dense).clone(),
1566 )),
1567 offset: offset_m,
1568 penalties,
1569 nullspace_dims,
1570 initial_log_lambdas,
1571 initial_beta: widen_marginal_beta_hint(beta_hint, p_marginal),
1572 gauge_priority: GAUGE_PRIORITY_MARGINAL,
1585 jacobian_callback: Some(callback),
1586 stacked_design: None,
1587 stacked_offset: None,
1588 })
1589}
1590
1591pub(crate) fn build_logslope_blockspec_bms(
1592 design: &TermCollectionDesign,
1593 baseline: f64,
1594 offset: &Array1<f64>,
1595 rho: Array1<f64>,
1596 beta_hint: Option<Array1<f64>>,
1597 marginal_design: &TermCollectionDesign,
1598 marginal_offset: &Array1<f64>,
1599 marginal_baseline: f64,
1600 z: Arc<Array1<f64>>,
1601 p_marginal: usize,
1602 influence_columns: Option<&Array2<f64>>,
1603) -> Result<ParameterBlockSpec, String> {
1604 let offset_s = offset + baseline;
1605 let offset_m = marginal_offset + marginal_baseline;
1606 let raw_marginal_dense = marginal_design
1607 .design
1608 .try_to_dense_arc("build_logslope_blockspec_bms::marginal")?;
1609 let marginal_dense =
1614 widen_marginal_dense_with_influence(&raw_marginal_dense, influence_columns)?;
1615 let logslope_dense = design
1616 .design
1617 .try_to_dense_arc("build_logslope_blockspec_bms::logslope")?;
1618 let callback: Arc<dyn BlockEffectiveJacobian> = Arc::new(BmsLogslopeJacobian {
1619 marginal_dense,
1620 logslope_dense: Arc::clone(&logslope_dense),
1621 offset_m,
1622 offset_s: offset_s.clone(),
1623 z,
1624 p_marginal,
1625 });
1626 Ok(ParameterBlockSpec {
1627 name: "logslope_surface".to_string(),
1628 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
1629 (*logslope_dense).clone(),
1630 )),
1631 offset: offset_s,
1632 penalties: design.penalties_as_penalty_matrix(),
1633 nullspace_dims: design.nullspace_dims.clone(),
1634 initial_log_lambdas: rho,
1635 initial_beta: beta_hint,
1636 gauge_priority: GAUGE_PRIORITY_LOGSLOPE,
1644 jacobian_callback: Some(callback),
1645 stacked_design: None,
1646 stacked_offset: None,
1647 })
1648}
1649
1650pub(crate) fn build_deviation_aux_blockspec(
1651 name: &str,
1652 prepared: &DeviationPrepared,
1653 rho: Array1<f64>,
1654 beta_hint: Option<Array1<f64>>,
1655) -> Result<ParameterBlockSpec, String> {
1656 let mut block = prepared.block.clone();
1657 block.initial_log_lambdas = Some(rho);
1658 let candidate_beta = beta_hint.or_else(|| Some(Array1::<f64>::zeros(block.design.ncols())));
1659 block.initial_beta = candidate_beta
1660 .map(|beta| {
1661 let zero = Array1::<f64>::zeros(beta.len());
1662 project_monotone_feasible_beta(&prepared.runtime, &zero, &beta, name)
1663 })
1664 .transpose()?;
1665 let mut spec = block.intospec(name)?;
1666 spec.gauge_priority = match name {
1676 "link_dev" => GAUGE_PRIORITY_LINK_DEV,
1677 "score_warp_dev" => GAUGE_PRIORITY_SCORE_WARP_DEV,
1684 _ => GAUGE_PRIORITY_DEVIATION_DEFAULT,
1685 };
1686 Ok(spec)
1687}
1688
1689pub(crate) fn push_deviation_aux_blockspecs(
1690 blocks: &mut Vec<ParameterBlockSpec>,
1691 rho: &Array1<f64>,
1692 cursor: &mut usize,
1693 score_warp_prepared: Option<&DeviationPrepared>,
1694 link_dev_prepared: Option<&DeviationPrepared>,
1695 score_warp_beta_hint: Option<Array1<f64>>,
1696 link_dev_beta_hint: Option<Array1<f64>>,
1697) -> Result<(), String> {
1698 if let Some(prepared) = score_warp_prepared {
1699 let rho_h = rho
1700 .slice(s![*cursor..*cursor + prepared.block.penalties.len()])
1701 .to_owned();
1702 *cursor += prepared.block.penalties.len();
1703 blocks.push(build_deviation_aux_blockspec(
1704 "score_warp_dev",
1705 prepared,
1706 rho_h,
1707 score_warp_beta_hint,
1708 )?);
1709 }
1710 if let Some(prepared) = link_dev_prepared {
1711 let rho_w = rho
1712 .slice(s![*cursor..*cursor + prepared.block.penalties.len()])
1713 .to_owned();
1714 blocks.push(build_deviation_aux_blockspec(
1715 "link_dev",
1716 prepared,
1717 rho_w,
1718 link_dev_beta_hint,
1719 )?);
1720 }
1721 Ok(())
1722}
1723
1724fn inner_fit(
1725 family: &BernoulliMarginalSlopeFamily,
1726 blocks: &[ParameterBlockSpec],
1727 options: &BlockwiseFitOptions,
1728) -> Result<UnifiedFitResult, String> {
1729 let mut options = options.clone();
1730 options.use_outer_hessian = false;
1735 options.outer_tol = options.outer_tol.max(2.0e-5);
1736 fit_custom_family(family, blocks, &options).map_err(|e| e.to_string())
1737}
1738
1739pub fn fit_bernoulli_marginal_slope_terms(
1740 data: ArrayView2<'_, f64>,
1741 spec: BernoulliMarginalSlopeTermSpec,
1742 options: &BlockwiseFitOptions,
1743 kappa_options: &SpatialLengthScaleOptimizationOptions,
1744 policy: &gam_runtime::resource::ResourcePolicy,
1745) -> Result<BernoulliMarginalSlopeFitResult, String> {
1746 let mut spec = spec;
1747 let data_view = data;
1748 validate_spec(data_view, &spec)?;
1749 let mjs_frozen_marginal =
1758 gam_terms::smooth::freeze_measure_jet_length_scale_learning(&mut spec.marginalspec);
1759 let mjs_frozen_logslope =
1760 gam_terms::smooth::freeze_measure_jet_length_scale_learning(&mut spec.logslopespec);
1761 if mjs_frozen_marginal + mjs_frozen_logslope > 0 {
1762 log::info!(
1763 "[BMS spatial] froze measure-jet length-scale learning on {} marginal + {} log-slope \
1764 term(s): the coupled surface keeps ℓ at its conditioned auto value (#1116)",
1765 mjs_frozen_marginal,
1766 mjs_frozen_logslope
1767 );
1768 }
1769 let mut effective_kappa_options = kappa_options.clone();
1770 let kappa_locked_marginal = gam_terms::smooth::all_spatial_terms_kappa_fixed(&spec.marginalspec);
1780 let kappa_locked_logslope = gam_terms::smooth::all_spatial_terms_kappa_fixed(&spec.logslopespec);
1781 if effective_kappa_options.enabled && kappa_locked_marginal && kappa_locked_logslope {
1782 log::info!(
1783 "[BMS spatial] disabling κ/ψ optimization: every spatial term has an \
1784 explicit length_scale and no anisotropy; user-supplied kernel scale is fixed"
1785 );
1786 effective_kappa_options.enabled = false;
1787 }
1788 let flex_spatial_pilot_path = (spec.score_warp.is_some() || spec.link_dev.is_some())
1789 && spec.y.len() >= BMS_FLEX_SPATIAL_OUTER_PILOT_ROW_THRESHOLD
1790 && effective_kappa_options.enabled;
1791 if flex_spatial_pilot_path {
1792 let marginal_terms = spatial_length_scale_term_indices(&spec.marginalspec);
1793 let logslope_terms = spatial_length_scale_term_indices(&spec.logslopespec);
1794 let marginal_updates = apply_spatial_anisotropy_pilot_initializer(
1795 data_view,
1796 &mut spec.marginalspec,
1797 &marginal_terms,
1798 effective_kappa_options.pilot_subsample_threshold,
1799 &effective_kappa_options,
1800 );
1801 let logslope_updates = apply_spatial_anisotropy_pilot_initializer(
1802 data_view,
1803 &mut spec.logslopespec,
1804 &logslope_terms,
1805 effective_kappa_options.pilot_subsample_threshold,
1806 &effective_kappa_options,
1807 );
1808 effective_kappa_options.enabled = false;
1809 log::info!(
1810 "[BMS spatial] n={} flex=true pilot_geometry_updates={} iterative_spatial_outer=false reason=large-flex-spatial-pilot",
1811 spec.y.len(),
1812 marginal_updates + logslope_updates,
1813 );
1814 }
1815 let (z_standardized, z_normalization) = standardize_latent_z_with_policy(
1816 &spec.z,
1817 &spec.weights,
1818 "bernoulli-marginal-slope",
1819 &spec.latent_z_policy,
1820 )?;
1821 spec.z = z_standardized;
1822 let sigma_learnable = matches!(
1823 &spec.frailty,
1824 FrailtySpec::GaussianShift { sigma_fixed: None }
1825 );
1826 let initial_sigma = match &spec.frailty {
1827 FrailtySpec::GaussianShift {
1828 sigma_fixed: Some(s),
1829 } => Some(*s),
1830 FrailtySpec::GaussianShift { sigma_fixed: None } => Some(0.5),
1831 FrailtySpec::None => None,
1832 FrailtySpec::HazardMultiplier { .. } => {
1833 return Err(
1834 "internal: validate_spec should have rejected unsupported marginal-slope frailty"
1835 .to_string(),
1836 );
1837 }
1838 };
1839 let probit_scale = probit_frailty_scale(initial_sigma);
1840 let (_raw_joint_designs, mut joint_specs) = build_term_collection_designs_and_freeze_joint(
1841 data_view,
1842 &[spec.marginalspec.clone(), spec.logslopespec.clone()],
1843 )
1844 .map_err(|e| e.to_string())?;
1845 let marginalspec_boot = joint_specs.remove(0);
1846 let logslopespec_boot = joint_specs.remove(0);
1847 let (mut joint_designs, _) = build_term_collection_designs_and_freeze_joint(
1864 data_view,
1865 &[marginalspec_boot.clone(), logslopespec_boot.clone()],
1866 )
1867 .map_err(|e| format!("failed to rebuild frozen probe BMS joint designs: {e}"))?;
1868 let marginal_design = joint_designs.remove(0);
1869 let logslope_design = joint_designs.remove(0);
1870 let absorber_active = spec
1877 .score_influence_jacobian
1878 .as_ref()
1879 .is_some_and(|j| j.ncols() > 0);
1880 let conditioning_dense = if absorber_active {
1881 None
1882 } else {
1883 Some(
1884 marginal_design
1885 .design
1886 .try_to_dense_arc("bernoulli marginal-slope conditional latent-z gate")?,
1887 )
1888 };
1889 let (latent_measure, latent_z_calibration) = build_latent_measure_with_geometry(
1890 &spec.z,
1891 &spec.weights,
1892 &spec.latent_z_policy,
1893 conditioning_dense.as_ref().map(|d| d.view()),
1894 )?;
1895 if latent_measure.is_empirical() && sigma_learnable {
1896 return Err("empirical latent-measure marginal-slope calibration requires fixed GaussianShift sigma; learnable sigma derivatives must be fit under the standard-normal latent measure"
1897 .to_string());
1898 }
1899
1900 let y = Arc::new(spec.y.clone());
1901 let weights = Arc::new(spec.weights.clone());
1902 let z = match &latent_z_calibration {
1907 LatentMeasureCalibration::None => Arc::new(spec.z.clone()),
1908 LatentMeasureCalibration::RankInverseNormal(cal) => {
1909 Arc::new(cal.apply_to_training(&spec.z)?)
1910 }
1911 LatentMeasureCalibration::ConditionalLocationScale(cal) => {
1912 let a_block = conditioning_dense.as_ref().ok_or_else(|| {
1915 "conditional latent calibration requires the marginal conditioning block"
1916 .to_string()
1917 })?;
1918 Arc::new(cal.apply(spec.z.view(), a_block.view())?)
1919 }
1920 };
1921 let z_train = z.as_ref();
1922 let pilot_baseline = pooled_probit_baseline(&spec.y, z_train, &spec.weights)?;
1923 let baseline = (
1924 bernoulli_marginal_slope_eta_from_probability(
1925 &spec.base_link,
1926 normal_cdf(pilot_baseline.0),
1927 "bernoulli marginal-slope baseline link inversion",
1928 )?,
1929 pilot_baseline.1 / probit_scale,
1930 );
1931
1932 let rigid_pilot_eta = rigid_pooled_probit_pilot_eta(
1975 &spec.base_link,
1976 z_train,
1977 &spec.marginal_offset,
1978 &spec.logslope_offset,
1979 baseline.0,
1980 baseline.1,
1981 probit_scale,
1982 )?;
1983 let cross_block_pilot_w_score_warp =
1984 pilot_irls_hessian_row_metric_at_eta(&rigid_pilot_eta, &spec.weights);
1985
1986 let influence_columns = if let Some(jac) = spec
1997 .score_influence_jacobian
1998 .as_ref()
1999 .filter(|j| j.ncols() > 0)
2000 {
2001 let marginal_dense_for_proj = marginal_design
2002 .design
2003 .try_to_dense_arc("bernoulli marginal-slope influence-block marginal projection")?;
2004 let marginal_dense = marginal_dense_for_proj.as_ref();
2005 if jac.nrows() != marginal_dense.nrows() {
2006 return Err(format!(
2007 "influence block: Jacobian has {} rows, marginal design has {}",
2008 jac.nrows(),
2009 marginal_dense.nrows()
2010 ));
2011 }
2012 let rigid_logslope_at_rows = &spec.logslope_offset + baseline.1;
2019 let residualized =
2020 crate::marginal_slope_orthogonal::residualized_influence_block(
2021 jac,
2022 z_train,
2023 &rigid_logslope_at_rows,
2024 probit_scale,
2025 marginal_dense.view(),
2026 &cross_block_pilot_w_score_warp,
2027 )?;
2028 Some(residualized)
2029 } else {
2030 None
2031 };
2032 let mut cross_block_warnings: Vec<CrossBlockIdentifiabilityWarning> = Vec::new();
2033 let score_warp_prepared = if let Some(cfg) = spec.score_warp.as_ref() {
2034 use super::deviation_runtime::ParametricAnchorBlock;
2035 let mut prepared = build_score_warp_deviation_block_from_seed(z_train, cfg)?;
2036 let outcome = install_compiled_flex_block_into_runtime(
2041 &mut prepared,
2042 z_train,
2043 cfg,
2044 &[
2045 (&marginal_design.design, ParametricAnchorBlock::Marginal),
2046 (&logslope_design.design, ParametricAnchorBlock::Logslope),
2047 ],
2048 &[],
2049 &cross_block_pilot_w_score_warp,
2050 )?;
2051 match outcome {
2052 FlexCompileOutcome::Reparameterised => Some(prepared),
2053 FlexCompileOutcome::FullyAliased { reason } => {
2054 cross_block_warnings.push(CrossBlockIdentifiabilityWarning {
2060 candidate_label: "score_warp",
2061 anchor_summary: "marginal+logslope".to_string(),
2062 reason,
2063 });
2064 Some(prepared)
2065 }
2066 }
2067 } else {
2068 None
2069 };
2070 let link_dev_prepared = if let Some(cfg) = spec.link_dev.as_ref() {
2096 let eta_pilot = pilot_eta_for_link_dev_orthogonalisation(
2097 &spec.base_link,
2098 &spec.y,
2099 z_train,
2100 &spec.weights,
2101 &marginal_design.design,
2102 &spec.marginal_offset,
2103 &spec.logslope_offset,
2104 baseline.0,
2105 baseline.1,
2106 probit_scale,
2107 )?;
2108 let link_dev_seed = padded_deviation_seed(&eta_pilot, 1.0, 0.5);
2109 let mut prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
2110 &link_dev_seed,
2111 &eta_pilot,
2112 cfg,
2113 )?;
2114 let score_warp_anchor_design = score_warp_prepared
2151 .as_ref()
2152 .map(|sw| sw.runtime.design_at_training_with_residual(z_train))
2153 .transpose()?;
2154 use super::deviation_runtime::ParametricAnchorBlock;
2155 let parametric_anchors: [(&DesignMatrix, ParametricAnchorBlock); 2] = [
2156 (&marginal_design.design, ParametricAnchorBlock::Marginal),
2157 (&logslope_design.design, ParametricAnchorBlock::Logslope),
2158 ];
2159 let flex_anchor_slot: Option<&Array2<f64>> = score_warp_anchor_design.as_ref();
2160 let flex_anchors: Vec<&Array2<f64>> = flex_anchor_slot.into_iter().collect();
2161 let cross_block_pilot_w_link_dev =
2166 pilot_irls_hessian_row_metric_at_eta(&eta_pilot, &spec.weights);
2167 let outcome = install_compiled_flex_block_into_runtime(
2168 &mut prepared,
2169 &eta_pilot,
2170 cfg,
2171 ¶metric_anchors,
2172 &flex_anchors,
2173 &cross_block_pilot_w_link_dev,
2174 )?;
2175 match outcome {
2176 FlexCompileOutcome::Reparameterised => Some(prepared),
2177 FlexCompileOutcome::FullyAliased { reason } => {
2178 cross_block_warnings.push(CrossBlockIdentifiabilityWarning {
2184 candidate_label: "link_deviation",
2185 anchor_summary: "marginal+logslope+score_warp".to_string(),
2186 reason,
2187 });
2188 Some(prepared)
2189 }
2190 }
2191 } else {
2192 None
2193 };
2194 let extra_rho0 = {
2195 let mut out = Vec::new();
2196 if let Some(ref prepared) = score_warp_prepared {
2197 out.extend(std::iter::repeat_n(0.0, prepared.block.penalties.len()));
2198 }
2199 if let Some(ref prepared) = link_dev_prepared {
2200 out.extend(std::iter::repeat_n(0.0, prepared.block.penalties.len()));
2201 }
2202 out
2203 };
2204 let logslope_reduced_reparam: Option<ReducedLogslopeReparam> = build_reduced_logslope_reparam(
2217 &marginal_design,
2218 &logslope_design,
2219 z.as_ref(),
2220 &cross_block_pilot_w_score_warp,
2221 &spec.marginal_offset,
2222 &spec.logslope_offset,
2223 baseline.0,
2224 baseline.1,
2225 probit_scale,
2226 )?;
2227 let reduce_logslope_design =
2233 |logslope_design: &TermCollectionDesign| -> Result<TermCollectionDesign, String> {
2234 match logslope_reduced_reparam.as_ref() {
2235 Some(reparam) => reparameterize_logslope_design_reduced(logslope_design, reparam),
2236 None => Ok(logslope_design.clone()),
2237 }
2238 };
2239
2240 let marginal_penalty_count = marginal_design.penalties.len();
2241 let setup = joint_setup(
2242 data_view,
2243 &marginalspec_boot,
2244 &logslopespec_boot,
2245 marginal_penalty_count,
2246 logslope_design.penalties.len(),
2247 &extra_rho0,
2248 &effective_kappa_options,
2249 );
2250 let setup = if sigma_learnable {
2251 setup.with_auxiliary(
2252 Array1::from_vec(vec![initial_sigma.expect("learnable sigma seed").ln()]),
2253 Array1::from_vec(vec![0.01_f64.ln()]),
2254 Array1::from_vec(vec![5.0_f64.ln()]),
2255 )
2256 } else {
2257 setup
2258 };
2259 let final_sigma_cell = std::cell::Cell::new(initial_sigma);
2260 let exact_warm_start = RefCell::new(None::<CustomFamilyWarmStart>);
2261 let runaway_error = RefCell::new(None::<String>);
2262 let pending_beta_seed = RefCell::new(None::<Array1<f64>>);
2269 let hints = RefCell::new(ThetaHints::default());
2270 let score_warp_runtime = score_warp_prepared.as_ref().map(|p| p.runtime.clone());
2271 let link_dev_runtime = link_dev_prepared.as_ref().map(|p| p.runtime.clone());
2272
2273 let build_blocks = |rho: &Array1<f64>,
2274 marginal_design: &TermCollectionDesign,
2275 logslope_design: &TermCollectionDesign|
2276 -> Result<Vec<ParameterBlockSpec>, String> {
2277 let hints = hints.borrow();
2278 let mut cursor = 0usize;
2279 let logslope_design_reduced = reduce_logslope_design(logslope_design)?;
2286 let logslope_design = &logslope_design_reduced;
2287 let rho_marginal = rho
2292 .slice(s![cursor..cursor + marginal_design.penalties.len()])
2293 .to_owned();
2294 cursor += marginal_design.penalties.len();
2295 let rho_logslope = rho
2296 .slice(s![cursor..cursor + logslope_design.penalties.len()])
2297 .to_owned();
2298 cursor += logslope_design.penalties.len();
2299 let p_m = marginal_design.design.ncols()
2300 + influence_columns.as_ref().map(|z| z.ncols()).unwrap_or(0);
2301 let mut blocks = vec![
2302 build_marginal_blockspec_bms(
2303 marginal_design,
2304 baseline.0,
2305 &spec.marginal_offset,
2306 rho_marginal,
2307 hints.marginal_beta.clone(),
2308 logslope_design,
2309 &spec.logslope_offset,
2310 baseline.1,
2311 p_m,
2312 influence_columns.as_ref(),
2313 INFLUENCE_ABSORBER_FIXED_LOG_LAMBDA,
2314 )?,
2315 build_logslope_blockspec_bms(
2316 logslope_design,
2317 baseline.1,
2318 &spec.logslope_offset,
2319 rho_logslope,
2320 hints.logslope_beta.clone(),
2321 marginal_design,
2322 &spec.marginal_offset,
2323 baseline.0,
2324 Arc::clone(&z),
2325 p_m,
2326 influence_columns.as_ref(),
2327 )?,
2328 ];
2329 push_deviation_aux_blockspecs(
2330 &mut blocks,
2331 rho,
2332 &mut cursor,
2333 score_warp_prepared.as_ref(),
2334 link_dev_prepared.as_ref(),
2335 hints.score_warp_beta.clone(),
2336 hints.link_dev_beta.clone(),
2337 )?;
2338 Ok(blocks)
2339 };
2340
2341 let intercept_warm_starts = new_intercept_warm_start_cache(y.len());
2342 let cell_moment_lru = new_cell_moment_lru_cache(policy);
2343 let cell_moment_cache_stats = new_cell_moment_cache_stats();
2344 let make_family = |marginal_design: &TermCollectionDesign,
2345 logslope_design: &TermCollectionDesign,
2346 sigma: Option<f64>|
2347 -> BernoulliMarginalSlopeFamily {
2348 let kernel_marginal_design = match influence_columns.as_ref() {
2354 Some(z_infl) => {
2355 let raw = marginal_design
2356 .design
2357 .try_to_dense_arc("make_family::widened-marginal")
2358 .expect("dense marginal design for influence widening");
2359 let widened = widen_marginal_dense_with_influence(&raw, Some(z_infl))
2360 .expect("widen marginal design with influence columns");
2361 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from((*widened).clone()))
2362 }
2363 None => marginal_design.design.clone(),
2364 };
2365 let kernel_logslope_design = reduce_logslope_design(logslope_design)
2371 .expect("reduce logslope design for family construction")
2372 .design;
2373 BernoulliMarginalSlopeFamily {
2374 y: Arc::clone(&y),
2375 weights: Arc::clone(&weights),
2376 z: Arc::clone(&z),
2377 latent_measure: latent_measure.clone(),
2378 gaussian_frailty_sd: sigma,
2379 base_link: spec.base_link.clone(),
2380 marginal_design: kernel_marginal_design,
2381 logslope_design: kernel_logslope_design,
2382 score_warp: score_warp_runtime.clone(),
2383 link_dev: link_dev_runtime.clone(),
2384 policy: policy.clone(),
2385 cell_moment_lru: Arc::clone(&cell_moment_lru),
2386 cell_moment_cache_stats: Arc::clone(&cell_moment_cache_stats),
2387 intercept_warm_starts: Some(Arc::clone(&intercept_warm_starts)),
2388 auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
2389 auto_subsample_last_rho: Arc::new(Mutex::new(None)),
2390 }
2391 };
2392
2393 let marginal_terms = spatial_length_scale_term_indices(&marginalspec_boot);
2394 let logslope_terms = spatial_length_scale_term_indices(&logslopespec_boot);
2395 let marginal_has_spatial = !marginal_terms.is_empty();
2396 let logslope_has_spatial = !logslope_terms.is_empty();
2397 let analytic_joint_derivatives_available =
2398 marginal_has_spatial || logslope_has_spatial || setup.log_kappa_dim() == 0;
2399 if setup.log_kappa_dim() > 0 && !analytic_joint_derivatives_available {
2400 return Err("exact bernoulli marginal-slope spatial optimization requires analytic joint psi derivatives"
2401 .to_string());
2402 }
2403 let initial_rho = setup.theta0().slice(s![..setup.rho_dim()]).to_owned();
2404 let initial_blocks = build_blocks(&initial_rho, &marginal_design, &logslope_design)?;
2405 let initial_family = make_family(&marginal_design, &logslope_design, initial_sigma);
2406 let (joint_gradient, joint_hessian) =
2407 custom_family_outer_derivatives(&initial_family, &initial_blocks, options);
2408 let analytic_joint_gradient_available = analytic_joint_derivatives_available
2409 && matches!(joint_gradient, gam_problem::Derivative::Analytic);
2410 let analytic_joint_hessian_available =
2416 analytic_joint_derivatives_available && joint_hessian.is_analytic();
2417 let kappa_options_ref: &SpatialLengthScaleOptimizationOptions = &effective_kappa_options;
2418 let sigma_from_theta = |theta: &Array1<f64>| -> Option<f64> {
2419 if sigma_learnable {
2420 Some(theta[setup.rho_dim() + setup.log_kappa_dim()].exp())
2421 } else {
2422 initial_sigma
2423 }
2424 };
2425 let derivative_block_cache = RefCell::new(
2426 None::<(
2427 Array1<f64>,
2428 Arc<Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>>,
2429 )>,
2430 );
2431 let theta_matches = |left: &Array1<f64>, right: &Array1<f64>| -> bool {
2432 left.len() == right.len()
2433 && left
2434 .iter()
2435 .zip(right.iter())
2436 .all(|(lhs, rhs)| (*lhs - *rhs).abs() <= 1e-12 * (1.0 + lhs.abs().max(rhs.abs())))
2437 };
2438 let get_derivative_blocks = |theta: &Array1<f64>,
2439 specs: &[TermCollectionSpec],
2440 designs: &[TermCollectionDesign]|
2441 -> Result<
2442 Arc<Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>>,
2443 String,
2444 > {
2445 if let Some((cached_theta, cached_blocks)) = derivative_block_cache.borrow().as_ref()
2446 && theta_matches(cached_theta, theta)
2447 {
2448 return Ok(Arc::clone(cached_blocks));
2449 }
2450
2451 let built = |specs: &[TermCollectionSpec],
2452 designs: &[TermCollectionDesign]|
2453 -> Result<
2454 Vec<Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>>,
2455 String,
2456 > {
2457 let marginal_psi_derivs = if marginal_has_spatial {
2458 build_block_spatial_psi_derivatives(data_view, &specs[0], &designs[0])?.ok_or_else(
2459 || {
2460 "bernoulli marginal-slope: marginal block has spatial terms \
2461 but spatial psi derivatives are unavailable"
2462 .to_string()
2463 },
2464 )?
2465 } else {
2466 Vec::new()
2467 };
2468 let logslope_psi_derivs = if logslope_has_spatial {
2469 build_block_spatial_psi_derivatives(data_view, &specs[1], &designs[1])?.ok_or_else(
2470 || {
2471 "bernoulli marginal-slope: logslope block has spatial terms \
2472 but spatial psi derivatives are unavailable"
2473 .to_string()
2474 },
2475 )?
2476 } else {
2477 Vec::new()
2478 };
2479 let mut derivative_blocks = vec![marginal_psi_derivs, logslope_psi_derivs];
2480 if score_warp_runtime.is_some() {
2481 derivative_blocks.push(Vec::new());
2482 }
2483 if link_dev_runtime.is_some() {
2484 derivative_blocks.push(Vec::new());
2485 }
2486 if sigma_learnable {
2487 derivative_blocks
2488 .last_mut()
2489 .expect("bernoulli derivative block list is non-empty")
2490 .push(crate::custom_family::CustomFamilyBlockPsiDerivative::new(
2491 None,
2492 Array2::zeros((0, 0)),
2493 Array2::zeros((0, 0)),
2494 None,
2495 None,
2496 None,
2497 None,
2498 ));
2499 }
2500 Ok(derivative_blocks)
2501 }(specs, designs)?;
2502 let built = Arc::new(built);
2503 derivative_block_cache.replace(Some((theta.clone(), Arc::clone(&built))));
2504 Ok(built)
2505 };
2506
2507 let outer_policy = {
2512 let psi_dim = setup.theta0().len() - setup.rho_dim();
2513 initial_family.outer_derivative_policy(&initial_blocks, psi_dim, options)
2514 };
2515 let exact_spatial_outer_tol = kappa_options_ref.rel_tol.max(EXACT_SPATIAL_OUTER_TOL_FLOOR);
2516 let solved = optimize_spatial_length_scale_exact_joint(
2517 data_view,
2518 &[marginalspec_boot.clone(), logslopespec_boot.clone()],
2519 &[marginal_terms.clone(), logslope_terms.clone()],
2520 kappa_options_ref,
2521 &setup,
2522 gam_solve::seeding::SeedRiskProfile::GeneralizedLinear,
2523 analytic_joint_gradient_available,
2524 analytic_joint_hessian_available,
2525 true,
2526 None,
2527 outer_policy,
2528 |theta, specs: &[TermCollectionSpec], designs: &[TermCollectionDesign]| {
2529 if let Some(err) = runaway_error.borrow().as_ref().cloned() {
2530 return Err(err);
2531 }
2532 assert_eq!(
2533 specs.len(),
2534 designs.len(),
2535 "spatial joint optimizer must supply one spec per design",
2536 );
2537 let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
2538 let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
2539 let sigma = sigma_from_theta(theta);
2540 final_sigma_cell.set(sigma);
2541 let family = make_family(&designs[0], &designs[1], sigma);
2542 let fit = inner_fit(&family, &blocks, options)?;
2543 if let Some(block) = fit.block_states.first()
2544 && let Some(err) = bernoulli_marginal_slope_runaway_error_from_beta(
2545 block.beta.view(),
2546 &designs[0],
2547 &specs[0],
2548 fit.outer_converged,
2549 "final fit",
2550 )
2551 {
2552 runaway_error.replace(Some(err.clone()));
2553 return Err(err);
2554 }
2555 let mut hints_mut = hints.borrow_mut();
2556 let mut bidx = 0usize;
2557 if let Some(block) = fit.block_states.get(bidx) {
2558 hints_mut.marginal_beta = Some(block.beta.clone());
2559 }
2560 bidx += 1;
2561 if let Some(block) = fit.block_states.get(bidx) {
2562 hints_mut.logslope_beta = Some(block.beta.clone());
2563 }
2564 bidx += 1;
2565 if score_warp_prepared.is_some() {
2566 if let Some(block) = fit.block_states.get(bidx) {
2567 hints_mut.score_warp_beta = Some(block.beta.clone());
2568 }
2569 bidx += 1;
2570 }
2571 if link_dev_prepared.is_some()
2572 && let Some(block) = fit.block_states.get(bidx)
2573 {
2574 hints_mut.link_dev_beta = Some(block.beta.clone());
2575 }
2576 Ok(fit)
2577 },
2578 |theta,
2579 specs: &[TermCollectionSpec],
2580 designs: &[TermCollectionDesign],
2581 eval_mode,
2582 row_set: &crate::row_kernel::RowSet| {
2583 if let Some(err) = runaway_error.borrow().as_ref().cloned() {
2584 return Err(err);
2585 }
2586 use gam_problem::EvalMode;
2587 static BMS_OUTER_EVAL_ROWSET_LOGGED: std::sync::Once = std::sync::Once::new();
2594 BMS_OUTER_EVAL_ROWSET_LOGGED.call_once(|| {
2595 let row_set_rows = match row_set {
2596 crate::row_kernel::RowSet::All => spec.y.len(),
2597 crate::row_kernel::RowSet::Subsample { rows, .. } => rows.len(),
2598 };
2599 log::debug!(
2600 "[BMS exact outer eval] mode={eval_mode:?} row_set_rows={row_set_rows}"
2601 );
2602 });
2603 let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
2604 let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
2605 if let Some(beta_seed) = pending_beta_seed.borrow_mut().take() {
2609 let widths: Vec<usize> = blocks.iter().map(|b| b.design.ncols()).collect();
2610 match CustomFamilyWarmStart::from_cached_beta(&widths, &beta_seed) {
2611 Ok(ws) => {
2612 exact_warm_start.replace(Some(ws));
2613 }
2614 Err(e) => {
2615 log::warn!(
2616 "[BMS] outer ρ-cache β-warm-start rejected: {e}; falling back to cold β"
2617 );
2618 }
2619 }
2620 }
2621 let sigma = sigma_from_theta(theta);
2622 final_sigma_cell.set(sigma);
2623 let family = make_family(&designs[0], &designs[1], sigma);
2624 let derivative_blocks = get_derivative_blocks(theta, specs, designs)?;
2625 let effective_mode = match eval_mode {
2629 EvalMode::ValueGradientHessian if !analytic_joint_hessian_available => {
2630 EvalMode::ValueAndGradient
2631 }
2632 other => other,
2633 };
2634 let mut eval_options =
2635 joint_hyper_options_for_outer_tolerance(options, exact_spatial_outer_tol);
2636 if let crate::row_kernel::RowSet::Subsample { rows, n_full } = row_set {
2637 let subsample = crate::outer_subsample::OuterScoreSubsample::from_weighted_rows(
2638 rows.as_ref().clone(),
2639 *n_full,
2640 0,
2641 );
2642 eval_options.outer_score_subsample = Some(Arc::new(subsample));
2643 eval_options.auto_outer_subsample = false;
2644 }
2645 let eval = evaluate_custom_family_joint_hyper_shared(
2646 &family,
2647 &blocks,
2648 &eval_options,
2649 &rho,
2650 derivative_blocks,
2651 exact_warm_start.borrow().as_ref(),
2652 effective_mode,
2653 )?;
2654 if let Some(err) = bernoulli_marginal_slope_runaway_error(
2655 &eval.warm_start,
2656 &designs[0],
2657 &specs[0],
2658 eval.inner_converged,
2659 "exact outer evaluation",
2660 ) {
2661 runaway_error.replace(Some(err.clone()));
2662 return Err(err);
2663 }
2664 exact_warm_start.replace(Some(eval.warm_start.clone()));
2665 if !eval.inner_converged {
2666 return Err(
2667 "exact bernoulli marginal-slope inner solve did not converge".to_string(),
2668 );
2669 }
2670 if matches!(eval_mode, EvalMode::ValueGradientHessian)
2671 && analytic_joint_hessian_available
2672 && !eval.outer_hessian.is_analytic()
2673 {
2674 return Err("exact bernoulli marginal-slope joint [rho, psi] objective did not return an outer Hessian"
2675 .to_string());
2676 }
2677 Ok((eval.objective, eval.gradient, eval.outer_hessian))
2678 },
2679 |theta, specs: &[TermCollectionSpec], designs: &[TermCollectionDesign]| {
2680 if let Some(err) = runaway_error.borrow().as_ref().cloned() {
2681 return Err(err);
2682 }
2683 let rho = theta.slice(s![..setup.rho_dim()]).to_owned();
2684 let blocks = build_blocks(&rho, &designs[0], &designs[1])?;
2685 if let Some(beta_seed) = pending_beta_seed.borrow_mut().take() {
2686 let widths: Vec<usize> = blocks.iter().map(|b| b.design.ncols()).collect();
2687 match CustomFamilyWarmStart::from_cached_beta(&widths, &beta_seed) {
2688 Ok(ws) => {
2689 exact_warm_start.replace(Some(ws));
2690 }
2691 Err(e) => {
2692 log::warn!(
2693 "[BMS] outer ρ-cache β-warm-start rejected (efs): {e}; falling back to cold β"
2694 );
2695 }
2696 }
2697 }
2698 let sigma = sigma_from_theta(theta);
2699 final_sigma_cell.set(sigma);
2700 let family = make_family(&designs[0], &designs[1], sigma);
2701 let derivative_blocks = get_derivative_blocks(theta, specs, designs)?;
2702 let eval = evaluate_custom_family_joint_hyper_efs_shared(
2703 &family,
2704 &blocks,
2705 &joint_hyper_options_for_outer_tolerance(options, exact_spatial_outer_tol),
2706 &rho,
2707 derivative_blocks,
2708 exact_warm_start.borrow().as_ref(),
2709 )?;
2710 if let Some(err) = bernoulli_marginal_slope_runaway_error(
2711 &eval.warm_start,
2712 &designs[0],
2713 &specs[0],
2714 eval.inner_converged,
2715 "EFS outer evaluation",
2716 ) {
2717 runaway_error.replace(Some(err.clone()));
2718 return Err(err);
2719 }
2720 exact_warm_start.replace(Some(eval.warm_start.clone()));
2721 if !eval.inner_converged {
2722 return Err(
2723 "exact bernoulli marginal-slope EFS inner solve did not converge".to_string(),
2724 );
2725 }
2726 Ok(eval.efs_eval)
2727 },
2728 crate::marginal_slope_shared::make_beta_seed_validator(&pending_beta_seed),
2729 )?;
2730
2731 let mut resolved_specs = solved.resolved_specs;
2732 let mut designs = solved.designs;
2733 let mut solved_fit = solved.fit;
2745 if let Some(reparam) = logslope_reduced_reparam.as_ref() {
2746 let r = reparam.reduced_cols();
2747 if let Some(block) = solved_fit.blocks.get_mut(1)
2748 && block.beta.len() == r
2749 {
2750 block.beta = reparam.recover_original_logslope_beta(&block.beta)?;
2751 }
2752 if let Some(state) = solved_fit.block_states.get_mut(1)
2753 && state.beta.len() == r
2754 {
2755 state.beta = reparam.recover_original_logslope_beta(&state.beta)?;
2756 }
2757 }
2758 let (latent_z_rank_int_calibration, latent_z_conditional_calibration) =
2800 match latent_z_calibration {
2801 LatentMeasureCalibration::None => (None, None),
2802 LatentMeasureCalibration::RankInverseNormal(cal) => (Some(cal), None),
2803 LatentMeasureCalibration::ConditionalLocationScale(cal) => (None, Some(cal)),
2804 };
2805 if let Some(cal) = latent_z_conditional_calibration.as_ref()
2815 && let Some(vb) = solved_fit.covariance_conditional.clone()
2816 {
2817 let p_beta = vb.nrows();
2818 let marginal_dense = marginal_design
2819 .design
2820 .try_to_dense_arc("bms generated-regressor marginal design")?;
2821 let logslope_reduced = reduce_logslope_design(&logslope_design)?;
2822 let logslope_reduced_dense = logslope_reduced
2823 .design
2824 .try_to_dense_arc("bms generated-regressor reduced logslope design")?;
2825 let p_m = marginal_dense.ncols();
2826 let r = logslope_reduced_dense.ncols();
2827 if p_beta != vb.ncols() {
2828 return Err(format!(
2829 "bms generated-regressor: covariance_conditional must be square, got {}×{}",
2830 vb.nrows(),
2831 vb.ncols()
2832 ));
2833 }
2834 if p_beta == p_m + r {
2838 let marginal_eta = &solved_fit.block_states[0].eta;
2839 let slope_eta = &solved_fit.block_states[1].eta;
2840 let probit_scale = probit_frailty_scale(final_sigma_cell.get());
2841 let s = rigid_standard_normal_score_zeta_sensitivity(
2842 &spec.base_link,
2843 marginal_eta,
2844 slope_eta,
2845 z.as_ref(),
2846 y.as_ref(),
2847 weights.as_ref(),
2848 probit_scale,
2849 marginal_dense.view(),
2850 logslope_reduced_dense.view(),
2851 p_beta,
2852 )?;
2853 let correction = cal.generated_regressor_correction(
2861 s.view(),
2862 spec.z.view(),
2863 marginal_dense.view(),
2864 vb.view(),
2865 )?;
2866 if let Some(cov) = solved_fit.covariance_conditional.as_mut() {
2867 *cov = &*cov + &correction;
2868 }
2869 if let Some(cov) = solved_fit.covariance_corrected.as_mut() {
2870 *cov = &*cov + &correction;
2871 }
2872 log::info!(
2873 "[BMS latent-z] Murphy–Topel generated-regressor SE correction applied: \
2874 p_beta={p_beta} theta1_dim={} max_diag_inflation={:.3e}",
2875 cal.theta1_dim(),
2876 (0..p_beta)
2877 .map(|i| correction[[i, i]])
2878 .fold(0.0_f64, f64::max),
2879 );
2880 } else {
2881 log::info!(
2882 "[BMS latent-z] Murphy–Topel generated-regressor SE correction skipped: \
2883 aux deviation blocks present (p_beta={p_beta} > marginal({p_m})+logslope({r})); \
2884 rigid-kernel z-channel does not yet cover score_warp/link_dev deviations"
2885 );
2886 }
2887 }
2888 Ok(BernoulliMarginalSlopeFitResult {
2901 fit: solved_fit,
2902 marginalspec_resolved: resolved_specs.remove(0),
2903 logslopespec_resolved: resolved_specs.remove(0),
2904 marginal_design: designs.remove(0),
2905 logslope_design: designs.remove(0),
2906 baseline_marginal: baseline.0,
2907 baseline_logslope: baseline.1,
2908 z_normalization,
2909 latent_measure,
2910 score_warp_runtime,
2911 link_dev_runtime,
2912 gaussian_frailty_sd: final_sigma_cell.get(),
2913 cross_block_warnings,
2914 latent_z_rank_int_calibration,
2915 latent_z_conditional_calibration,
2916 })
2917}