1use super::family::{
2 append_deviation_function_penalty, require_probit_marginal_slope_link,
3 resolve_deviation_operator_orders,
4};
5use super::*;
6
/// Most negative structural-monotonicity slack (`a_i · beta - b_i`) still accepted
/// as feasible. The small negative allowance presumably absorbs floating-point
/// round-off at the constraint boundary.
pub(crate) const MONOTONICITY_SLACK_TOL: f64 = -1.0e-10;
26
27pub(crate) struct BmsFlexBlockContext {
32 pub(super) anchor_dense_blocks: Vec<Array2<f64>>,
34 pub(super) anchor_components: Vec<super::deviation_runtime::AnchorComponentTag>,
36 pub(super) n_train: Array2<f64>,
38 pub(super) operators:
41 Vec<std::sync::Arc<dyn gam_identifiability::families::compiler::RowJacobianOperator>>,
42 pub(super) ordering: Vec<gam_identifiability::families::compiler::BlockOrder>,
44 pub(super) row_hess: gam_identifiability::families::bernoulli::BernoulliRowHessian,
46 pub(super) candidate_design_dense: Array2<f64>,
49 pub(super) n: usize,
51 pub(super) p_candidate: usize,
53 pub(super) d_total: usize,
55}
56
57pub(crate) fn build_bms_flex_block_context(
64 candidate: &DeviationPrepared,
65 candidate_arg_at_training_rows: &Array1<f64>,
66 parametric_anchors: &[(
67 &DesignMatrix,
68 super::deviation_runtime::ParametricAnchorBlock,
69 )],
70 flex_anchors: &[&Array2<f64>],
71 training_row_weights: &Array1<f64>,
72) -> Result<Option<BmsFlexBlockContext>, String> {
73 use super::deviation_runtime::AnchorComponentTag;
74 use gam_identifiability::families::bernoulli::{
75 BernoulliDenseDesignOperator, BernoulliRowHessian,
76 };
77 use gam_identifiability::families::compiler::{BlockOrder, RowJacobianOperator};
78
79 let candidate_design = candidate.runtime.design(candidate_arg_at_training_rows)?;
80 let n = candidate_design.nrows();
81 let p_candidate = candidate_design.ncols();
82
83 if training_row_weights.len() != n {
84 return Err(format!(
85 "cross-block identifiability: training_row_weights length {} does not match candidate row count {}",
86 training_row_weights.len(),
87 n,
88 ));
89 }
90 for (i, &w) in training_row_weights.iter().enumerate() {
91 if !w.is_finite() || w < 0.0 {
92 return Err(format!(
93 "cross-block identifiability: training_row_weights[{i}] = {w} is not finite/non-negative",
94 ));
95 }
96 }
97
98 let mut anchor_dense_blocks: Vec<Array2<f64>> = Vec::new();
100 let mut anchor_components: Vec<AnchorComponentTag> = Vec::new();
101 let mut total_anchor_cols = 0usize;
102 for (d, block_tag) in parametric_anchors {
103 if d.nrows() != n {
104 return Err(format!(
105 "cross-block identifiability: parametric anchor has {} rows, candidate has {}",
106 d.nrows(),
107 n,
108 ));
109 }
110 let p_a = d.ncols();
111 if p_a == 0 {
112 continue;
113 }
114 let dense = d
115 .try_to_dense_arc("cross-block parametric anchor")?
116 .as_ref()
117 .clone();
118 anchor_dense_blocks.push(dense);
119 anchor_components.push(AnchorComponentTag::Parametric {
120 block: *block_tag,
121 ncols: p_a,
122 });
123 total_anchor_cols += p_a;
124 }
125 for a in flex_anchors {
126 if a.nrows() != n {
127 return Err(format!(
128 "cross-block identifiability: flex anchor has {} rows, candidate has {}",
129 a.nrows(),
130 n,
131 ));
132 }
133 let p_a = a.ncols();
134 if p_a == 0 {
135 continue;
136 }
137 anchor_dense_blocks.push((*a).clone());
138 anchor_components.push(AnchorComponentTag::FlexEvaluation { ncols: p_a });
139 total_anchor_cols += p_a;
140 }
141 if total_anchor_cols == 0 {
142 return Ok(None);
143 }
144
145 let d_total = total_anchor_cols;
146 let mut n_train = Array2::<f64>::zeros((n, d_total));
147 {
148 let mut col_offset = 0usize;
149 for block in &anchor_dense_blocks {
150 let bc = block.ncols();
151 n_train
152 .slice_mut(s![.., col_offset..col_offset + bc])
153 .assign(block);
154 col_offset += bc;
155 }
156 }
157
158 let mut operators: Vec<std::sync::Arc<dyn RowJacobianOperator>> =
161 Vec::with_capacity(anchor_dense_blocks.len() + 1);
162 let mut ordering: Vec<BlockOrder> = Vec::with_capacity(anchor_dense_blocks.len() + 1);
163 for dense in &anchor_dense_blocks {
164 operators.push(std::sync::Arc::new(BernoulliDenseDesignOperator::new(
165 dense.clone(),
166 )));
167 ordering.push(BlockOrder::Marginal);
168 }
169 operators.push(std::sync::Arc::new(BernoulliDenseDesignOperator::new(
170 candidate_design.clone(),
171 )));
172 ordering.push(BlockOrder::LinkDev);
173
174 let row_hess = BernoulliRowHessian::from_row_weights(training_row_weights.clone());
175
176 Ok(Some(BmsFlexBlockContext {
177 anchor_dense_blocks,
178 anchor_components,
179 n_train,
180 operators,
181 ordering,
182 row_hess,
183 candidate_design_dense: candidate_design,
184 n,
185 p_candidate,
186 d_total,
187 }))
188}
189
/// Outcome of installing the compiled cross-block reparameterisation for one
/// candidate flex block.
#[derive(Debug)]
pub enum FlexCompileOutcome {
    /// The candidate block is usable as-is. Either it was rebuilt on the kept
    /// directions, or there was nothing to do (no candidate columns or an empty
    /// anchor union).
    Reparameterised,
    /// No candidate direction survives residualisation against the anchor union.
    /// `reason` is a human-readable diagnostic.
    FullyAliased {
        /// Explanation of why the block is aliased and what to change.
        reason: String,
    },
}
210
/// Diagnostic record for a cross-block identifiability problem.
#[derive(Clone, Debug)]
pub struct CrossBlockIdentifiabilityWarning {
    /// Static label identifying the candidate block.
    pub candidate_label: &'static str,
    /// Human-readable description of the anchor terms involved.
    pub anchor_summary: String,
    /// Explanation of the detected problem.
    pub reason: String,
}
219
220pub(crate) fn install_compiled_flex_block_into_runtime(
301 candidate: &mut DeviationPrepared,
302 candidate_arg_at_training_rows: &Array1<f64>,
303 candidate_cfg: &DeviationBlockConfig,
304 parametric_anchors: &[(
305 &DesignMatrix,
306 super::deviation_runtime::ParametricAnchorBlock,
307 )],
308 flex_anchors: &[&Array2<f64>],
309 training_row_weights: &Array1<f64>,
310) -> Result<FlexCompileOutcome, String> {
311 use gam_identifiability::audit::audit_identifiability_channel_aware;
312 use gam_identifiability::families::compiler::compile;
313
314 let p_check = candidate
316 .runtime
317 .design(candidate_arg_at_training_rows)?
318 .ncols();
319 if p_check == 0 {
320 return Ok(FlexCompileOutcome::Reparameterised);
321 }
322
323 let ctx = match build_bms_flex_block_context(
327 candidate,
328 candidate_arg_at_training_rows,
329 parametric_anchors,
330 flex_anchors,
331 training_row_weights,
332 )? {
333 None => {
334 return Ok(FlexCompileOutcome::Reparameterised);
337 }
338 Some(c) => c,
339 };
340 let BmsFlexBlockContext {
341 anchor_dense_blocks,
342 anchor_components,
343 n_train,
344 operators,
345 ordering,
346 row_hess,
347 candidate_design_dense,
348 n,
349 p_candidate,
350 d_total,
351 } = ctx;
352
353 let audit = audit_identifiability_channel_aware(
360 &{
361 let mut specs = Vec::with_capacity(anchor_dense_blocks.len() + 1);
365 for (idx, dense) in anchor_dense_blocks.iter().enumerate() {
366 specs.push(crate::custom_family::ParameterBlockSpec {
367 name: format!("anchor_{idx}"),
368 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
369 dense.clone(),
370 )),
371 offset: Array1::<f64>::zeros(n),
372 penalties: Vec::new(),
373 nullspace_dims: Vec::new(),
374 initial_log_lambdas: Array1::<f64>::zeros(0),
375 initial_beta: None,
376 gauge_priority: super::block_specs::GAUGE_PRIORITY_ANCHOR,
377 jacobian_callback: None,
378 stacked_design: None,
379 stacked_offset: None,
380 });
381 }
382 specs.push(crate::custom_family::ParameterBlockSpec {
383 name: "candidate_flex".to_string(),
384 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
385 candidate_design_dense.clone(),
386 )),
387 offset: Array1::<f64>::zeros(n),
388 penalties: Vec::new(),
389 nullspace_dims: Vec::new(),
390 initial_log_lambdas: Array1::<f64>::zeros(0),
391 initial_beta: None,
392 gauge_priority: super::block_specs::GAUGE_PRIORITY_CANDIDATE_FLEX,
393 jacobian_callback: None,
394 stacked_design: None,
395 stacked_offset: None,
396 });
397 specs
398 },
399 &operators,
400 &row_hess,
401 )
402 .map_err(|e| format!("cross-block identifiability audit failed: {e}"))?;
403
404 if audit.fatal {
405 let candidate_block = audit.blocks.last();
406 let effective = candidate_block.map(|b| b.effective_dim).unwrap_or(0);
407 if effective == 0 {
408 let reason = format!(
409 "candidate flex basis ({p_candidate} cols) has zero directions remaining after \
410 W-metric residualisation against the anchor union ({d_total} anchor cols) at the \
411 {n} training rows. The channel-aware audit collapses every direction in \
412 span(C) — every direction in span(C) is reproducible by the anchor union up to \
413 numerical tolerance. Drop the flex block or remove the anchor term that reproduces \
414 its argument; knot count is NOT the relevant lever for this failure mode.",
415 );
416 return Ok(FlexCompileOutcome::FullyAliased { reason });
417 }
418 }
419
420 let compiled = compile(&operators, &row_hess, &ordering).map_err(|e| {
425 format!(
426 "cross-block identifiability: compile failed (n={n}, d_total={d_total}, p_c={p_candidate}): {e}",
427 )
428 })?;
429 let candidate_compiled = compiled
430 .blocks
431 .last()
432 .ok_or_else(|| "cross-block identifiability: compile returned no blocks".to_string())?;
433 let k_kept = candidate_compiled.t_lw.ncols();
434 if k_kept == 0 {
435 let reason = format!(
436 "candidate flex basis ({p_candidate} cols) has zero directions remaining after \
437 W-metric residualisation against the anchor union ({d_total} anchor cols) at the \
438 {n} training rows. The compiler's joint pre-fit audit collapses every direction in \
439 span(C) — every direction in span(C) is reproducible by the anchor union up to \
440 numerical tolerance. Drop the flex block or remove the anchor term that reproduces \
441 its argument; knot count is NOT the relevant lever for this failure mode.",
442 );
443 return Ok(FlexCompileOutcome::FullyAliased { reason });
444 }
445 {
448 let m = candidate_compiled
449 .anchor_correction
450 .as_ref()
451 .ok_or_else(|| {
452 "cross-block identifiability: compile returned no anchor_correction for the \
453 candidate block (expected for trailing block with non-empty anchor union)"
454 .to_string()
455 })?;
456 if m.nrows() != d_total || m.ncols() != k_kept {
457 return Err(format!(
458 "cross-block identifiability: anchor_correction shape {}×{} does not match \
459 expected d_total={d_total} × k_kept={k_kept}",
460 m.nrows(),
461 m.ncols(),
462 ));
463 }
464 }
465
466 candidate.runtime.install_compiled_flex_block(
471 candidate_compiled,
472 anchor_components,
473 n_train,
474 )?;
475 let new_design = candidate
476 .runtime
477 .design_at_training_with_residual(candidate_arg_at_training_rows)?;
478 let new_p = new_design.ncols();
479 assert_eq!(new_p, k_kept);
480 assert_eq!(new_design.nrows(), n);
481 candidate.block.design =
482 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(new_design));
483 candidate.block.penalties.clear();
484 candidate.block.nullspace_dims.clear();
485 let penalty_orders = resolve_deviation_operator_orders(candidate_cfg)?;
486 for order in penalty_orders {
487 append_deviation_function_penalty(&mut candidate.block, &candidate.runtime, order)?;
488 }
489 if candidate_cfg.double_penalty {
490 append_deviation_function_penalty(&mut candidate.block, &candidate.runtime, 0)?;
491 }
492 candidate.block.initial_beta = Some(Array1::zeros(new_p));
493
494 log::info!(
495 "[BMS cross-block identifiability] flex block reparameterised via compiler: \
496 kept {kept}/{p_candidate} directions (anchor union cols={d_total}, training rows={n}, \
497 joint_rank={joint_rank}, dropped_by_audit={dropped})",
498 kept = new_p,
499 p_candidate = p_candidate,
500 d_total = d_total,
501 n = n,
502 joint_rank = compiled.joint_rank,
503 dropped = compiled.dropped.len(),
504 );
505 Ok(FlexCompileOutcome::Reparameterised)
506}
507
508pub(crate) fn project_monotone_feasible_beta(
509 runtime: &DeviationRuntime,
510 current: &Array1<f64>,
511 proposed: &Array1<f64>,
512 label: &str,
513) -> Result<Array1<f64>, String> {
514 if current.len() != runtime.basis_dim() {
515 return Err(format!(
516 "{label} monotone projection current length mismatch: current={}, expected={}",
517 current.len(),
518 runtime.basis_dim()
519 ));
520 }
521 if proposed.len() != runtime.basis_dim() {
522 return Err(format!(
523 "{label} monotone projection length mismatch: proposed={}, expected={}",
524 proposed.len(),
525 runtime.basis_dim()
526 ));
527 }
528 for (idx, value) in current.iter().enumerate() {
529 if !value.is_finite() {
530 return Err(format!("{label} current coefficient {idx} is non-finite"));
531 }
532 }
533 for (idx, value) in proposed.iter().enumerate() {
534 if !value.is_finite() {
535 return Err(format!("{label} coefficient {idx} is non-finite"));
536 }
537 }
538 runtime.monotonicity_feasible(current, &format!("{label} current beta"))?;
539 if runtime
540 .monotonicity_feasible(proposed, &format!("{label} proposed beta"))
541 .is_ok()
542 {
543 return Ok(proposed.clone());
544 }
545
546 let constraints = runtime.structural_monotonicity_constraints();
547 let alpha = max_linear_constraint_segment_alpha(current, proposed, &constraints, label)?;
548 let direction = proposed - current;
549 let candidate = current + &direction.mapv(|value| value * alpha);
550 validate_monotone_structural_feasible(runtime, &candidate, &format!("{label} projected beta"))?;
551 Ok(candidate)
552}
553
554pub(crate) fn validate_monotone_structural_feasible(
555 runtime: &DeviationRuntime,
556 beta: &Array1<f64>,
557 label: &str,
558) -> Result<(), String> {
559 let constraints = runtime.structural_monotonicity_constraints();
560 if beta.len() != constraints.a.ncols() {
561 return Err(format!(
562 "{label} structural monotonicity length mismatch: beta={}, expected={}",
563 beta.len(),
564 constraints.a.ncols()
565 ));
566 }
567 if beta.iter().any(|value| !value.is_finite()) {
568 let bad = beta
569 .iter()
570 .enumerate()
571 .find(|(_, value)| !value.is_finite())
572 .map(|(idx, value)| format!("{label} coefficient {idx} is non-finite ({value})"))
573 .unwrap_or_else(|| format!("{label} coefficient is non-finite"));
574 return Err(bad);
575 }
576 let slack = constraints.a.dot(beta) - &constraints.b;
577 let mut min_slack = f64::INFINITY;
578 let mut min_row = 0usize;
579 for (row, &value) in slack.iter().enumerate() {
580 if value < min_slack {
581 min_slack = value;
582 min_row = row;
583 }
584 }
585 if min_slack < MONOTONICITY_SLACK_TOL {
586 return Err(format!(
587 "{label} violates structural monotonicity row {min_row}: slack={min_slack:.3e}; \
588 deviation monotonicity must be enforced by analytic linear constraints, not post-update projection"
589 ));
590 }
591 runtime.monotonicity_feasible(beta, label)
592}
593
594pub(crate) fn max_linear_constraint_segment_alpha(
595 current: &Array1<f64>,
596 proposed: &Array1<f64>,
597 constraints: &LinearInequalityConstraints,
598 label: &str,
599) -> Result<f64, String> {
600 if current.len() != proposed.len() || current.len() != constraints.a.ncols() {
601 return Err(format!(
602 "{label} linear-constraint segment dimension mismatch: current={}, proposed={}, constraints={}",
603 current.len(),
604 proposed.len(),
605 constraints.a.ncols()
606 ));
607 }
608 if constraints.a.nrows() != constraints.b.len() {
609 return Err(format!(
610 "{label} linear-constraint segment row mismatch: A rows={}, b len={}",
611 constraints.a.nrows(),
612 constraints.b.len()
613 ));
614 }
615 let direction = proposed - current;
616 let mut alpha = 1.0_f64;
617 for row in 0..constraints.a.nrows() {
618 let a_row = constraints.a.row(row);
619 let slack = a_row.dot(current) - constraints.b[row];
620 if slack < MONOTONICITY_SLACK_TOL {
621 return Err(format!(
622 "{label} current beta violates structural monotonicity row {row}: slack={slack:.3e}"
623 ));
624 }
625 let drift = a_row.dot(&direction);
626 if drift < 0.0 {
627 alpha = alpha.min((slack / -drift).clamp(0.0, 1.0));
628 }
629 }
630 Ok(alpha.clamp(0.0, 1.0))
631}
632
633pub(super) fn validate_spec(
634 data: ArrayView2<'_, f64>,
635 spec: &BernoulliMarginalSlopeTermSpec,
636) -> Result<(), String> {
637 let n = data.nrows();
638 if spec.y.len() != n
639 || spec.weights.len() != n
640 || spec.z.len() != n
641 || spec.marginal_offset.len() != n
642 || spec.logslope_offset.len() != n
643 {
644 return Err(format!(
645 "bernoulli-marginal-slope row mismatch: data={}, y={}, weights={}, z={}, marginal_offset={}, logslope_offset={}",
646 n,
647 spec.y.len(),
648 spec.weights.len(),
649 spec.z.len(),
650 spec.marginal_offset.len(),
651 spec.logslope_offset.len()
652 ));
653 }
654 if spec
655 .y
656 .iter()
657 .any(|&yi| !yi.is_finite() || ((yi - 0.0).abs() > 1e-9 && (yi - 1.0).abs() > 1e-9))
658 {
659 return Err("bernoulli-marginal-slope requires binary y in {0,1}".to_string());
660 }
661 if spec.weights.iter().any(|&w| !w.is_finite() || w < 0.0) {
662 return Err("bernoulli-marginal-slope requires finite non-negative weights".to_string());
663 }
664 if spec.z.iter().any(|&zi| !zi.is_finite()) {
665 return Err("bernoulli-marginal-slope requires finite z values".to_string());
666 }
667 if spec.marginal_offset.iter().any(|&value| !value.is_finite()) {
668 return Err("bernoulli-marginal-slope requires finite marginal offsets".to_string());
669 }
670 if spec.logslope_offset.iter().any(|&value| !value.is_finite()) {
671 return Err("bernoulli-marginal-slope requires finite logslope offsets".to_string());
672 }
673 if let Some(jac) = spec.score_influence_jacobian.as_ref() {
674 if jac.nrows() != n {
677 return Err(format!(
678 "bernoulli-marginal-slope score_influence_jacobian has {} rows, expected {n}",
679 jac.nrows()
680 ));
681 }
682 if jac.iter().any(|&value| !value.is_finite()) {
683 return Err(
684 "bernoulli-marginal-slope score_influence_jacobian must be finite".to_string(),
685 );
686 }
687 }
688 require_probit_marginal_slope_link(&spec.base_link, "bernoulli-marginal-slope")?;
689 spec.frailty.validate_for_marginal_slope()?;
690 match &spec.frailty {
691 FrailtySpec::None => {}
692 FrailtySpec::GaussianShift { sigma_fixed } => {
693 if let Some(sigma) = sigma_fixed
694 && (!sigma.is_finite() || *sigma < 0.0)
695 {
696 return Err(format!(
697 "bernoulli-marginal-slope requires GaussianShift sigma >= 0, got {sigma}"
698 ));
699 }
700 }
701 FrailtySpec::HazardMultiplier { .. } => {
702 return Err(
703 "bernoulli-marginal-slope does not support FrailtySpec::HazardMultiplier"
704 .to_string(),
705 );
706 }
707 }
708 Ok(())
709}