Skip to main content

gam_models/bms/
install_flex.rs

1use super::family::{
2    append_deviation_function_penalty, require_probit_marginal_slope_link,
3    resolve_deviation_operator_orders,
4};
5use super::*;
6
// Pipeline implemented by `install_compiled_flex_block_into_runtime`:
//   1. `build_bms_flex_block_context` — validate inputs, densify anchors, stack
//      N_train, and assemble the row-Jacobian operators and row Hessian
//      needed for both the audit gate and the compile step.
//   2. `audit_identifiability_channel_aware` — structural rank gate using
//      the BMS K=1 row Jacobian; catches full aliasing before any install.
//   3. `gam_identifiability::families::compiler::compile` — W-metric Gram + eigendecomp,
//      produces the V selector and anchor-correction M.
//   4. Install V/M into the `DeviationRuntime` via `install_compiled_flex_block`,
//      rebuild the block's design + penalties, and return `FlexCompileOutcome`.
//
// The K=1 row-Jacobian math runs through `gam_identifiability::families::compiler::compile`,
// so there is exactly one cross-block residualisation math implementation in
// the codebase.
18
/// Lower bound a structural-monotonicity slack `A·β − b` may reach before it
/// counts as a real violation.
///
/// A slack in `[MONOTONICITY_SLACK_TOL, 0)` is treated as floating-point
/// round-off from the constraint inner products. The constraint rows are
/// O(1)-scaled deviation differences, so accumulated round-off stays well
/// inside this band, while genuine infeasibility is orders of magnitude larger.
pub(crate) const MONOTONICITY_SLACK_TOL: f64 = -1.0e-10;
26
27/// Assembled inputs for the BMS flex-block spec-builder → compile pipeline.
28///
29/// Produced by [`build_bms_flex_block_context`] and consumed by
30/// [`install_compiled_flex_block_into_runtime`].
31pub(crate) struct BmsFlexBlockContext {
32    /// Densified anchor blocks in parametric-before-flex order.
33    pub(super) anchor_dense_blocks: Vec<Array2<f64>>,
34    /// Per-anchor predict-time tags (same order as `anchor_dense_blocks`).
35    pub(super) anchor_components: Vec<super::deviation_runtime::AnchorComponentTag>,
36    /// Horizontally stacked anchor matrix N_train (n × d_total).
37    pub(super) n_train: Array2<f64>,
38    /// `BernoulliDenseDesignOperator` per anchor, then one for the candidate
39    /// (trailing). Indices align with `ordering`.
40    pub(super) operators:
41        Vec<std::sync::Arc<dyn gam_identifiability::families::compiler::RowJacobianOperator>>,
42    /// Block-order tags parallel to `operators`.
43    pub(super) ordering: Vec<gam_identifiability::families::compiler::BlockOrder>,
44    /// W-metric row Hessian built from the validated `training_row_weights`.
45    pub(super) row_hess: gam_identifiability::families::bernoulli::BernoulliRowHessian,
46    /// Dense candidate basis at training rows (n × p_candidate), cached to
47    /// avoid a second `design()` call after context construction.
48    pub(super) candidate_design_dense: Array2<f64>,
49    /// Number of training rows.
50    pub(super) n: usize,
51    /// Raw column count of the candidate block (= `candidate_design_dense.ncols()`).
52    pub(super) p_candidate: usize,
53    /// Total anchor columns (= `n_train.ncols()`).
54    pub(super) d_total: usize,
55}
56
57/// Validate inputs, densify anchors, stack N_train, and assemble the
58/// `BernoulliDenseDesignOperator` / `BlockOrder` / `BernoulliRowHessian`
59/// vectors needed by both [`audit_identifiability_channel_aware`] and
60/// [`identifiability::families::compiler::compile`].
61///
62/// Returns `Ok(None)` when the anchor union is empty (no-anchor fast path).
63pub(crate) fn build_bms_flex_block_context(
64    candidate: &DeviationPrepared,
65    candidate_arg_at_training_rows: &Array1<f64>,
66    parametric_anchors: &[(
67        &DesignMatrix,
68        super::deviation_runtime::ParametricAnchorBlock,
69    )],
70    flex_anchors: &[&Array2<f64>],
71    training_row_weights: &Array1<f64>,
72) -> Result<Option<BmsFlexBlockContext>, String> {
73    use super::deviation_runtime::AnchorComponentTag;
74    use gam_identifiability::families::bernoulli::{
75        BernoulliDenseDesignOperator, BernoulliRowHessian,
76    };
77    use gam_identifiability::families::compiler::{BlockOrder, RowJacobianOperator};
78
79    let candidate_design = candidate.runtime.design(candidate_arg_at_training_rows)?;
80    let n = candidate_design.nrows();
81    let p_candidate = candidate_design.ncols();
82
83    if training_row_weights.len() != n {
84        return Err(format!(
85            "cross-block identifiability: training_row_weights length {} does not match candidate row count {}",
86            training_row_weights.len(),
87            n,
88        ));
89    }
90    for (i, &w) in training_row_weights.iter().enumerate() {
91        if !w.is_finite() || w < 0.0 {
92            return Err(format!(
93                "cross-block identifiability: training_row_weights[{i}] = {w} is not finite/non-negative",
94            ));
95        }
96    }
97
98    // Densify parametric anchors (parametric-before-flex ordering).
99    let mut anchor_dense_blocks: Vec<Array2<f64>> = Vec::new();
100    let mut anchor_components: Vec<AnchorComponentTag> = Vec::new();
101    let mut total_anchor_cols = 0usize;
102    for (d, block_tag) in parametric_anchors {
103        if d.nrows() != n {
104            return Err(format!(
105                "cross-block identifiability: parametric anchor has {} rows, candidate has {}",
106                d.nrows(),
107                n,
108            ));
109        }
110        let p_a = d.ncols();
111        if p_a == 0 {
112            continue;
113        }
114        let dense = d
115            .try_to_dense_arc("cross-block parametric anchor")?
116            .as_ref()
117            .clone();
118        anchor_dense_blocks.push(dense);
119        anchor_components.push(AnchorComponentTag::Parametric {
120            block: *block_tag,
121            ncols: p_a,
122        });
123        total_anchor_cols += p_a;
124    }
125    for a in flex_anchors {
126        if a.nrows() != n {
127            return Err(format!(
128                "cross-block identifiability: flex anchor has {} rows, candidate has {}",
129                a.nrows(),
130                n,
131            ));
132        }
133        let p_a = a.ncols();
134        if p_a == 0 {
135            continue;
136        }
137        anchor_dense_blocks.push((*a).clone());
138        anchor_components.push(AnchorComponentTag::FlexEvaluation { ncols: p_a });
139        total_anchor_cols += p_a;
140    }
141    if total_anchor_cols == 0 {
142        return Ok(None);
143    }
144
145    let d_total = total_anchor_cols;
146    let mut n_train = Array2::<f64>::zeros((n, d_total));
147    {
148        let mut col_offset = 0usize;
149        for block in &anchor_dense_blocks {
150            let bc = block.ncols();
151            n_train
152                .slice_mut(s![.., col_offset..col_offset + bc])
153                .assign(block);
154            col_offset += bc;
155        }
156    }
157
158    // Build BernoulliDenseDesignOperator per anchor block, then one for the
159    // candidate (trailing, BlockOrder::LinkDev).
160    let mut operators: Vec<std::sync::Arc<dyn RowJacobianOperator>> =
161        Vec::with_capacity(anchor_dense_blocks.len() + 1);
162    let mut ordering: Vec<BlockOrder> = Vec::with_capacity(anchor_dense_blocks.len() + 1);
163    for dense in &anchor_dense_blocks {
164        operators.push(std::sync::Arc::new(BernoulliDenseDesignOperator::new(
165            dense.clone(),
166        )));
167        ordering.push(BlockOrder::Marginal);
168    }
169    operators.push(std::sync::Arc::new(BernoulliDenseDesignOperator::new(
170        candidate_design.clone(),
171    )));
172    ordering.push(BlockOrder::LinkDev);
173
174    let row_hess = BernoulliRowHessian::from_row_weights(training_row_weights.clone());
175
176    Ok(Some(BmsFlexBlockContext {
177        anchor_dense_blocks,
178        anchor_components,
179        n_train,
180        operators,
181        ordering,
182        row_hess,
183        candidate_design_dense: candidate_design,
184        n,
185        p_candidate,
186        d_total,
187    }))
188}
189
/// Result of [`install_compiled_flex_block_into_runtime`].
///
/// Callers use the variant to decide whether to keep the prepared block.
#[derive(Debug)]
pub enum FlexCompileOutcome {
    /// The candidate block can be kept. Either:
    ///
    /// * it was reparameterised in place, so its column span at the n
    ///   training rows is orthogonal to the anchor union (kept/dropped
    ///   direction counts go to the `[BMS cross-block identifiability]` log);
    ///   or
    /// * a fast path applied and the candidate was left untouched, because
    ///   the candidate design has zero columns or the anchor union
    ///   contributes zero columns.
    Reparameterised,
    /// Every direction in span(C) is reproducible by the anchor union
    /// (`(I − P_A) C` has numerical rank zero).
    ///
    /// The candidate carries no information independent of the anchors. The
    /// caller should drop it from the design with a structured warning rather
    /// than continue with a zero-rank block. The reparameterisation is never
    /// applied, so the candidate is still in its pre-call state and can be
    /// discarded safely.
    FullyAliased {
        /// Human-readable explanation of why the candidate was rejected.
        reason: String,
    },
}
210
/// Structured warning the BMS family surfaces when it drops a candidate flex
/// block because the block is fully aliased by its anchor union.
#[derive(Debug, Clone)]
pub struct CrossBlockIdentifiabilityWarning {
    /// Static label naming the dropped candidate block.
    pub candidate_label: &'static str,
    /// Free-form summary of the anchors; the format is chosen by the caller.
    pub anchor_summary: String,
    /// Explanation of why the candidate was dropped.
    pub reason: String,
}
219
220/// Enforce joint-design identifiability for a single flex block by
221/// reparameterising its basis so its column span at the n training rows
222/// is orthogonal to the union of every supplied anchor's column span.
223///
224/// This is the standard GAM `gam.side` convention generalised to multiple
225/// anchor sources. After applying the resulting reparameterisation `T`,
226/// the joint design `[anchor₁ | anchor₂ | … | candidate · T]` has full
227/// numerical column rank, so `σ_min(joint H+S) ≥ λ_min(S) > 0` for every
228/// β regardless of how the linear-predictor distribution shifts during
229/// PIRLS. This eliminates the near-null direction in the joint penalised
230/// Hessian that arises whenever the candidate flex block's column span
231/// overlaps an anchor's column span (parametric span aliasing, flex-flex
232/// aliasing, or both simultaneously).
233///
234/// # Math
235///
236/// Let `C ∈ ℝⁿˣᵖᶜ` be the candidate basis evaluated at the n training
237/// rows, `A ∈ ℝⁿˣᵈ` the horizontally stacked parametric anchors, and
238/// `W = diag(training_row_weights)`. The W-metric projector onto span(A)
239/// is `P_A = A (AᵀWA)⁻¹ AᵀW`; the W-residualised candidate is
240/// `C̃ = (I − P_A) C`. The joint reparameterisation
241///
242///   `Aβ_A + Cβ_C = A(β_A + Bβ_C) + (C − AB)β_C`     with `B = (AᵀWA)⁻¹AᵀWC`
243///
244/// is block-triangular, so dropping the columns of C̃ that have negligible
245/// `C̃ᵀ W C̃` eigenvalues drops exactly the directions span(C) shares with
246/// span(A) — under the actual Hessian row metric W = p(1−p), not the
247/// uniform metric. Concretely: factor `AᵀWA = U Λ Uᵀ` and let
248/// `R = U₊ Λ₊⁻½` so `Q_w = AR` is W-orthonormal under W. Then
249/// `K_w = Q_wᵀ W C = Rᵀ AᵀW C` and `C̃ = C − AR · K_w`. After selecting
250/// the kept eigenvector matrix V of `C̃ᵀ W C̃`, the residual
251/// `M = R K_w V` is what each evaluated row subtracts:
252/// `design_row(x) = pure_span_row(x) · V − n_row(x) · M`.
253///
254/// Why the old `null(AᵀC)` test is wrong: it asks "which candidate
255/// directions are *already* exactly orthogonal to A?" rather than "what
256/// remains after projecting A out?". `null(AᵀC) ≠ ∅` is NOT equivalent
257/// to `span(C) ⊆ span(A)` — the equivalence is
258/// `span(C) ⊆ span(A) ⇔ (I − P_A) C = 0`. Whenever d ≥ p_c (anchor
259/// wider than candidate), `null(AᵀC)` is generically empty even if C
260/// carries plenty of information independent of A.
261///
262/// # Cost
263///
264/// `AᵀWA` is `d × d` (d = total parametric anchor cols), built as one
265/// matmul on the sqrt-W-scaled `A`. `K_w` is one `Q_wᵀ · (W^½ C)` matmul
266/// of size `r × p_c`. `C̃ᵀ W C̃` is `p_c × p_c`. Two `eigh`s, both small
267/// (`d ≲ a few dozen`, `p_c ≲ 50`); negligible against the per-cycle
268/// dense Hessian build at large scale. `DesignMatrix` parametric
269/// anchors are densified once into a contiguous `n × d` block (a few
270/// dozen columns).
271///
272/// # `training_row_weights` (the W in the W-metric)
273///
274/// Callers **must** pass the IRLS Hessian row metric the joint Hessian
275/// will see during PIRLS, not bare sample weights. For the probit-style
276/// Bernoulli marginal slope family that is
277/// `w[i] = sample_weights[i] · φ(η_i)² / (μ_i·(1−μ_i))` at a β-independent
278/// pilot η. Passing uniform `spec.weights` instead makes A and C̃ merely
279/// Euclidean-orthogonal: `Aᵀ W_pirls C̃` is nonzero at PIRLS time, the
280/// joint Hessian carries a near-null direction along the W-metric alias,
281/// and REML can drive the flex block's λ small enough that the alias
282/// direction's joint Hessian eigenvalue collapses — manifesting as the
283/// well-known runaway (rho≈2.0, constant `step_inf`, growing `beta_inf`,
284/// inner loop hitting `inner_max_cycles` without satisfying the KKT
285/// residual). See `pilot_irls_hessian_row_metric_at_eta`.
286///
287/// # No-op fast paths
288///
289/// * Anchor list is empty, or every anchor has zero parametric columns.
290/// * `r = 0` — `AᵀWA` is numerically zero (degenerate weights).
291///
292/// # Hard error
293///
294/// `(I − P_A) C` has numerical rank zero — every direction in span(C) is
295/// reproducible by the anchors up to tolerance. The candidate flex block
296/// carries no information the parametric blocks do not already capture in
297/// their unpenalised span; the diagnostic surfaces this explicitly rather
298/// than letting the inner solver collide with the resulting rank-deficient
299/// Hessian.
300pub(crate) fn install_compiled_flex_block_into_runtime(
301    candidate: &mut DeviationPrepared,
302    candidate_arg_at_training_rows: &Array1<f64>,
303    candidate_cfg: &DeviationBlockConfig,
304    parametric_anchors: &[(
305        &DesignMatrix,
306        super::deviation_runtime::ParametricAnchorBlock,
307    )],
308    flex_anchors: &[&Array2<f64>],
309    training_row_weights: &Array1<f64>,
310) -> Result<FlexCompileOutcome, String> {
311    use gam_identifiability::audit::audit_identifiability_channel_aware;
312    use gam_identifiability::families::compiler::compile;
313
314    // Fast path: zero-column candidate carries nothing to residualise.
315    let p_check = candidate
316        .runtime
317        .design(candidate_arg_at_training_rows)?
318        .ncols();
319    if p_check == 0 {
320        return Ok(FlexCompileOutcome::Reparameterised);
321    }
322
323    // Step 1 — spec-builder: validate inputs, densify anchors, stack N_train,
324    // assemble operators + row_hess. Returns None when the anchor union is
325    // empty (no residualisation needed).
326    let ctx = match build_bms_flex_block_context(
327        candidate,
328        candidate_arg_at_training_rows,
329        parametric_anchors,
330        flex_anchors,
331        training_row_weights,
332    )? {
333        None => {
334            // No anchors — the candidate's per-block smoothness-null-space
335            // drop already handles intra-block aliasing.
336            return Ok(FlexCompileOutcome::Reparameterised);
337        }
338        Some(c) => c,
339    };
340    let BmsFlexBlockContext {
341        anchor_dense_blocks,
342        anchor_components,
343        n_train,
344        operators,
345        ordering,
346        row_hess,
347        candidate_design_dense,
348        n,
349        p_candidate,
350        d_total,
351    } = ctx;
352
353    // Step 2 — audit gate: `audit_identifiability_channel_aware` uses the
354    // structural BMS K=1 row Jacobian to detect full aliasing before any
355    // install. A fatal audit with effective_dim == 0 for the trailing
356    // (candidate) block means every direction in span(C) is reproducible by
357    // the anchor union; return FullyAliased immediately without touching the
358    // runtime.
359    let audit = audit_identifiability_channel_aware(
360        &{
361            // Build minimal ParameterBlockSpec wrappers so the audit can record
362            // block names and column counts. The specs are audit-only; no
363            // penalties or log-lambdas are needed here.
364            let mut specs = Vec::with_capacity(anchor_dense_blocks.len() + 1);
365            for (idx, dense) in anchor_dense_blocks.iter().enumerate() {
366                specs.push(crate::custom_family::ParameterBlockSpec {
367                    name: format!("anchor_{idx}"),
368                    design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
369                        dense.clone(),
370                    )),
371                    offset: Array1::<f64>::zeros(n),
372                    penalties: Vec::new(),
373                    nullspace_dims: Vec::new(),
374                    initial_log_lambdas: Array1::<f64>::zeros(0),
375                    initial_beta: None,
376                    gauge_priority: super::block_specs::GAUGE_PRIORITY_ANCHOR,
377                    jacobian_callback: None,
378                    stacked_design: None,
379                    stacked_offset: None,
380                });
381            }
382            specs.push(crate::custom_family::ParameterBlockSpec {
383                name: "candidate_flex".to_string(),
384                design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
385                    candidate_design_dense.clone(),
386                )),
387                offset: Array1::<f64>::zeros(n),
388                penalties: Vec::new(),
389                nullspace_dims: Vec::new(),
390                initial_log_lambdas: Array1::<f64>::zeros(0),
391                initial_beta: None,
392                gauge_priority: super::block_specs::GAUGE_PRIORITY_CANDIDATE_FLEX,
393                jacobian_callback: None,
394                stacked_design: None,
395                stacked_offset: None,
396            });
397            specs
398        },
399        &operators,
400        &row_hess,
401    )
402    .map_err(|e| format!("cross-block identifiability audit failed: {e}"))?;
403
404    if audit.fatal {
405        let candidate_block = audit.blocks.last();
406        let effective = candidate_block.map(|b| b.effective_dim).unwrap_or(0);
407        if effective == 0 {
408            let reason = format!(
409                "candidate flex basis ({p_candidate} cols) has zero directions remaining after \
410                 W-metric residualisation against the anchor union ({d_total} anchor cols) at the \
411                 {n} training rows. The channel-aware audit collapses every direction in \
412                 span(C) — every direction in span(C) is reproducible by the anchor union up to \
413                 numerical tolerance. Drop the flex block or remove the anchor term that reproduces \
414                 its argument; knot count is NOT the relevant lever for this failure mode.",
415            );
416            return Ok(FlexCompileOutcome::FullyAliased { reason });
417        }
418    }
419
420    // Step 3 — W-metric compile: Gram + eigendecomp → V selector (t_lw) and
421    // anchor-correction M. The compiler runs at K=1 (BMS row primary state =
422    // scalar η) using `BernoulliRowHessian` as the row metric. This is the
423    // single math implementation of the cross-block W-metric residualisation.
424    let compiled = compile(&operators, &row_hess, &ordering).map_err(|e| {
425        format!(
426            "cross-block identifiability: compile failed (n={n}, d_total={d_total}, p_c={p_candidate}): {e}",
427        )
428    })?;
429    let candidate_compiled = compiled
430        .blocks
431        .last()
432        .ok_or_else(|| "cross-block identifiability: compile returned no blocks".to_string())?;
433    let k_kept = candidate_compiled.t_lw.ncols();
434    if k_kept == 0 {
435        let reason = format!(
436            "candidate flex basis ({p_candidate} cols) has zero directions remaining after \
437             W-metric residualisation against the anchor union ({d_total} anchor cols) at the \
438             {n} training rows. The compiler's joint pre-fit audit collapses every direction in \
439             span(C) — every direction in span(C) is reproducible by the anchor union up to \
440             numerical tolerance. Drop the flex block or remove the anchor term that reproduces \
441             its argument; knot count is NOT the relevant lever for this failure mode.",
442        );
443        return Ok(FlexCompileOutcome::FullyAliased { reason });
444    }
445    // Shape contract: compile() must emit (d_total × k_kept) anchor_correction
446    // for the trailing candidate block.
447    {
448        let m = candidate_compiled
449            .anchor_correction
450            .as_ref()
451            .ok_or_else(|| {
452                "cross-block identifiability: compile returned no anchor_correction for the \
453             candidate block (expected for trailing block with non-empty anchor union)"
454                    .to_string()
455            })?;
456        if m.nrows() != d_total || m.ncols() != k_kept {
457            return Err(format!(
458                "cross-block identifiability: anchor_correction shape {}×{} does not match \
459                 expected d_total={d_total} × k_kept={k_kept}",
460                m.nrows(),
461                m.ncols(),
462            ));
463        }
464    }
465
466    // Step 4 — install: wrap compiled output into the runtime as an
467    // InstalledFlexBlock (anchor_correction M + anchor_components tags),
468    // cache N_train, apply selector V to span_c{0..3} + boundary/monotonicity
469    // rows, then rebuild the block's design + penalties in the new basis.
470    candidate.runtime.install_compiled_flex_block(
471        candidate_compiled,
472        anchor_components,
473        n_train,
474    )?;
475    let new_design = candidate
476        .runtime
477        .design_at_training_with_residual(candidate_arg_at_training_rows)?;
478    let new_p = new_design.ncols();
479    assert_eq!(new_p, k_kept);
480    assert_eq!(new_design.nrows(), n);
481    candidate.block.design =
482        DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(new_design));
483    candidate.block.penalties.clear();
484    candidate.block.nullspace_dims.clear();
485    let penalty_orders = resolve_deviation_operator_orders(candidate_cfg)?;
486    for order in penalty_orders {
487        append_deviation_function_penalty(&mut candidate.block, &candidate.runtime, order)?;
488    }
489    if candidate_cfg.double_penalty {
490        append_deviation_function_penalty(&mut candidate.block, &candidate.runtime, 0)?;
491    }
492    candidate.block.initial_beta = Some(Array1::zeros(new_p));
493
494    log::info!(
495        "[BMS cross-block identifiability] flex block reparameterised via compiler: \
496         kept {kept}/{p_candidate} directions (anchor union cols={d_total}, training rows={n}, \
497         joint_rank={joint_rank}, dropped_by_audit={dropped})",
498        kept = new_p,
499        p_candidate = p_candidate,
500        d_total = d_total,
501        n = n,
502        joint_rank = compiled.joint_rank,
503        dropped = compiled.dropped.len(),
504    );
505    Ok(FlexCompileOutcome::Reparameterised)
506}
507
508pub(crate) fn project_monotone_feasible_beta(
509    runtime: &DeviationRuntime,
510    current: &Array1<f64>,
511    proposed: &Array1<f64>,
512    label: &str,
513) -> Result<Array1<f64>, String> {
514    if current.len() != runtime.basis_dim() {
515        return Err(format!(
516            "{label} monotone projection current length mismatch: current={}, expected={}",
517            current.len(),
518            runtime.basis_dim()
519        ));
520    }
521    if proposed.len() != runtime.basis_dim() {
522        return Err(format!(
523            "{label} monotone projection length mismatch: proposed={}, expected={}",
524            proposed.len(),
525            runtime.basis_dim()
526        ));
527    }
528    for (idx, value) in current.iter().enumerate() {
529        if !value.is_finite() {
530            return Err(format!("{label} current coefficient {idx} is non-finite"));
531        }
532    }
533    for (idx, value) in proposed.iter().enumerate() {
534        if !value.is_finite() {
535            return Err(format!("{label} coefficient {idx} is non-finite"));
536        }
537    }
538    runtime.monotonicity_feasible(current, &format!("{label} current beta"))?;
539    if runtime
540        .monotonicity_feasible(proposed, &format!("{label} proposed beta"))
541        .is_ok()
542    {
543        return Ok(proposed.clone());
544    }
545
546    let constraints = runtime.structural_monotonicity_constraints();
547    let alpha = max_linear_constraint_segment_alpha(current, proposed, &constraints, label)?;
548    let direction = proposed - current;
549    let candidate = current + &direction.mapv(|value| value * alpha);
550    validate_monotone_structural_feasible(runtime, &candidate, &format!("{label} projected beta"))?;
551    Ok(candidate)
552}
553
554pub(crate) fn validate_monotone_structural_feasible(
555    runtime: &DeviationRuntime,
556    beta: &Array1<f64>,
557    label: &str,
558) -> Result<(), String> {
559    let constraints = runtime.structural_monotonicity_constraints();
560    if beta.len() != constraints.a.ncols() {
561        return Err(format!(
562            "{label} structural monotonicity length mismatch: beta={}, expected={}",
563            beta.len(),
564            constraints.a.ncols()
565        ));
566    }
567    if beta.iter().any(|value| !value.is_finite()) {
568        let bad = beta
569            .iter()
570            .enumerate()
571            .find(|(_, value)| !value.is_finite())
572            .map(|(idx, value)| format!("{label} coefficient {idx} is non-finite ({value})"))
573            .unwrap_or_else(|| format!("{label} coefficient is non-finite"));
574        return Err(bad);
575    }
576    let slack = constraints.a.dot(beta) - &constraints.b;
577    let mut min_slack = f64::INFINITY;
578    let mut min_row = 0usize;
579    for (row, &value) in slack.iter().enumerate() {
580        if value < min_slack {
581            min_slack = value;
582            min_row = row;
583        }
584    }
585    if min_slack < MONOTONICITY_SLACK_TOL {
586        return Err(format!(
587            "{label} violates structural monotonicity row {min_row}: slack={min_slack:.3e}; \
588             deviation monotonicity must be enforced by analytic linear constraints, not post-update projection"
589        ));
590    }
591    runtime.monotonicity_feasible(beta, label)
592}
593
594pub(crate) fn max_linear_constraint_segment_alpha(
595    current: &Array1<f64>,
596    proposed: &Array1<f64>,
597    constraints: &LinearInequalityConstraints,
598    label: &str,
599) -> Result<f64, String> {
600    if current.len() != proposed.len() || current.len() != constraints.a.ncols() {
601        return Err(format!(
602            "{label} linear-constraint segment dimension mismatch: current={}, proposed={}, constraints={}",
603            current.len(),
604            proposed.len(),
605            constraints.a.ncols()
606        ));
607    }
608    if constraints.a.nrows() != constraints.b.len() {
609        return Err(format!(
610            "{label} linear-constraint segment row mismatch: A rows={}, b len={}",
611            constraints.a.nrows(),
612            constraints.b.len()
613        ));
614    }
615    let direction = proposed - current;
616    let mut alpha = 1.0_f64;
617    for row in 0..constraints.a.nrows() {
618        let a_row = constraints.a.row(row);
619        let slack = a_row.dot(current) - constraints.b[row];
620        if slack < MONOTONICITY_SLACK_TOL {
621            return Err(format!(
622                "{label} current beta violates structural monotonicity row {row}: slack={slack:.3e}"
623            ));
624        }
625        let drift = a_row.dot(&direction);
626        if drift < 0.0 {
627            alpha = alpha.min((slack / -drift).clamp(0.0, 1.0));
628        }
629    }
630    Ok(alpha.clamp(0.0, 1.0))
631}
632
633pub(super) fn validate_spec(
634    data: ArrayView2<'_, f64>,
635    spec: &BernoulliMarginalSlopeTermSpec,
636) -> Result<(), String> {
637    let n = data.nrows();
638    if spec.y.len() != n
639        || spec.weights.len() != n
640        || spec.z.len() != n
641        || spec.marginal_offset.len() != n
642        || spec.logslope_offset.len() != n
643    {
644        return Err(format!(
645            "bernoulli-marginal-slope row mismatch: data={}, y={}, weights={}, z={}, marginal_offset={}, logslope_offset={}",
646            n,
647            spec.y.len(),
648            spec.weights.len(),
649            spec.z.len(),
650            spec.marginal_offset.len(),
651            spec.logslope_offset.len()
652        ));
653    }
654    if spec
655        .y
656        .iter()
657        .any(|&yi| !yi.is_finite() || ((yi - 0.0).abs() > 1e-9 && (yi - 1.0).abs() > 1e-9))
658    {
659        return Err("bernoulli-marginal-slope requires binary y in {0,1}".to_string());
660    }
661    if spec.weights.iter().any(|&w| !w.is_finite() || w < 0.0) {
662        return Err("bernoulli-marginal-slope requires finite non-negative weights".to_string());
663    }
664    if spec.z.iter().any(|&zi| !zi.is_finite()) {
665        return Err("bernoulli-marginal-slope requires finite z values".to_string());
666    }
667    if spec.marginal_offset.iter().any(|&value| !value.is_finite()) {
668        return Err("bernoulli-marginal-slope requires finite marginal offsets".to_string());
669    }
670    if spec.logslope_offset.iter().any(|&value| !value.is_finite()) {
671        return Err("bernoulli-marginal-slope requires finite logslope offsets".to_string());
672    }
673    if let Some(jac) = spec.score_influence_jacobian.as_ref() {
674        // #461: the absorbed Stage-1 influence Jacobian J = ∂z/∂θ₁ must be an
675        // n×p₁ matrix of finite entries co-indexed with the training rows.
676        if jac.nrows() != n {
677            return Err(format!(
678                "bernoulli-marginal-slope score_influence_jacobian has {} rows, expected {n}",
679                jac.nrows()
680            ));
681        }
682        if jac.iter().any(|&value| !value.is_finite()) {
683            return Err(
684                "bernoulli-marginal-slope score_influence_jacobian must be finite".to_string(),
685            );
686        }
687    }
688    require_probit_marginal_slope_link(&spec.base_link, "bernoulli-marginal-slope")?;
689    spec.frailty.validate_for_marginal_slope()?;
690    match &spec.frailty {
691        FrailtySpec::None => {}
692        FrailtySpec::GaussianShift { sigma_fixed } => {
693            if let Some(sigma) = sigma_fixed
694                && (!sigma.is_finite() || *sigma < 0.0)
695            {
696                return Err(format!(
697                    "bernoulli-marginal-slope requires GaussianShift sigma >= 0, got {sigma}"
698                ));
699            }
700        }
701        FrailtySpec::HazardMultiplier { .. } => {
702            return Err(
703                "bernoulli-marginal-slope does not support FrailtySpec::HazardMultiplier"
704                    .to_string(),
705            );
706        }
707    }
708    Ok(())
709}