Skip to main content

gam_models/bms/
deviation_runtime.rs

1use gam_terms::basis::create_ispline_derivative_dense;
2use gam_linalg::faer_ndarray::{FaerEigh, fast_ab};
3use crate::cubic_cell_kernel as exact_kernel;
4use gam_solve::pirls::LinearInequalityConstraints;
5use crate::util::span::span_index_for_breakpoints;
6use ndarray::{Array1, Array2, ArrayView2};
7
/// Check that `breakpoints` can be used for BMS span lookup.
///
/// The sequence must contain at least two entries, every entry must be
/// finite, and each entry must be strictly greater than its predecessor.
///
/// # Errors
///
/// Returns a message prefixed with `label` describing either the
/// insufficient length or the first offending adjacent pair.
fn validate_breakpoints(breakpoints: &[f64], label: &str) -> Result<(), String> {
    if breakpoints.len() < 2 {
        return Err(format!("{label} requires at least two breakpoints"));
    }
    for (idx, pair) in breakpoints.windows(2).enumerate() {
        let (lower, upper) = (pair[0], pair[1]);
        // A pair is acceptable only when both ends are finite and ordered;
        // NaN fails the finiteness test, so `lower < upper` is only ever
        // consulted for real numbers.
        let well_ordered = lower.is_finite() && upper.is_finite() && lower < upper;
        if !well_ordered {
            return Err(format!(
                "{label} requires strictly increasing finite breakpoints; breakpoints[{idx}]={:.6}, breakpoints[{}]={:.6}",
                lower,
                idx + 1,
                upper
            ));
        }
    }
    Ok(())
}
26
27/// Deduplicate an ordered BMS knot sequence into strictly increasing
28/// breakpoints.
29fn breakpoints_from_knots(knots: &[f64], label: &str) -> Result<Vec<f64>, String> {
30    let mut breakpoints = Vec::new();
31    for &knot in knots {
32        if breakpoints
33            .last()
34            .is_none_or(|prev: &f64| (knot - *prev).abs() > 1e-12)
35        {
36            breakpoints.push(knot);
37        }
38    }
39    validate_breakpoints(&breakpoints, label)?;
40    Ok(breakpoints)
41}
42
/// Floating-point allowance for the smallest monotonicity-derivative slack.
///
/// The monotonicity constraints themselves are built with a positive
/// required margin (`monotonicity_eps`). This constant is a separate, very
/// small *negative* bound whose only job is to absorb rounding error that
/// accumulates when the slack is evaluated at the I-spline breakpoints, so
/// that a coefficient vector feasible to within a few ulps is not rejected.
/// A slack below this value is treated as a real violation.
pub(crate) const MONOTONICITY_SLACK_ROUNDOFF_TOL: f64 = -1.0e-10;
50
/// Failure categories reported while building or evaluating the deviation
/// runtime in this module.
///
/// Every variant wraps an already-formatted `reason` message. The message is
/// kept verbatim so that the rendered text stays byte-equivalent to the
/// plain `format!(...)` strings this module produced before errors were
/// typed, while the variant itself lets callers branch on the kind of
/// failure without inspecting the text.
#[derive(Clone, Debug)]
pub enum DeviationRuntimeError {
    /// A value handed to the runtime violated its contract: a scalar
    /// setting, index, derivative order, runtime value or required metadata
    /// bundle was unusable (for example an out-of-range index, a non-finite
    /// number, missing support points, or a span of width <= 0).
    InvalidInput { reason: String },
    /// A matrix or vector did not have the dimension expected while
    /// composing transforms, checking anchors, or accepting a beta vector.
    DimensionMismatch { reason: String },
    /// A numerical routine (eigendecomposition, I-spline construction,
    /// monotonicity slack search) failed or returned nothing usable.
    NumericalFailure { reason: String },
}
71
// Expand the `impl_reason_error_boilerplate!` macro for the three
// `reason`-carrying variants of `DeviationRuntimeError`. The macro is not
// defined in the visible part of this file.
// NOTE(review): presumably this expansion supplies the `Display` impl and the
// `DeviationRuntimeError -> String` conversion that the `.into()` and
// `String::from(..)` call sites in this module rely on — confirm against the
// macro definition before changing the variant list.
impl_reason_error_boilerplate! {
    DeviationRuntimeError {
        InvalidInput,
        DimensionMismatch,
        NumericalFailure,
    }
}
79
80/// Installed cross-block flex block on the runtime.
81///
82/// Direct on-runtime image of `identifiability::families::compiler::CompiledBlock`:
83/// `anchor_correction` = `compiled.anchor_correction` (the d × k matrix M),
84/// `anchor_components` = the per-anchor predict-time tags (the parent
85/// predictor uses them to rebuild `n_row` at predict-time rows). The
86/// post-residualisation row evaluator is
87///
88///   design_row(x) = pure_span_row(x) − n_row · M
89///
90/// The compiler bakes the orthonormalising rotation into M, so no
91/// separate rotation matrix is stored on the install state.
92#[derive(Clone, Debug)]
93pub struct InstalledFlexBlock {
94    /// Anchor correction matrix `M ∈ R^{d × k}` from
95    /// `CompiledBlock::anchor_correction`. The design evaluator subtracts
96    /// `n_row · M` per row.
97    pub anchor_correction: Array2<f64>,
98    /// Per-anchor predict-time tags, in the order the anchors were stacked
99    /// (parametric before flex). `sum(ncols)` equals the row dimension of
100    /// `anchor_correction`.
101    pub anchor_components: Vec<AnchorComponentTag>,
102}
103
104#[derive(Clone, Debug)]
105pub enum AnchorComponentTag {
106    /// Parametric anchor — at predict time the parent predictor reconstructs
107    /// the per-row vector from the saved marginal/logslope blocks; the
108    /// runtime only needs to know which block and how many columns. The
109    /// `block` tag is consumed by the serde plumbing in
110    /// `inference::model::SavedAnchorComponent`.
111    Parametric {
112        block: ParametricAnchorBlock,
113        ncols: usize,
114    },
115    /// Flex-evaluation anchor — a sibling flex block's design at training
116    /// rows (post-reparameterisation, in the same coordinate frame the
117    /// predictor will use at predict time). The number of columns equals
118    /// the sibling block's reparameterised basis dimension.
119    FlexEvaluation { ncols: usize },
120}
121
122#[derive(Clone, Copy, Debug, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
123pub enum ParametricAnchorBlock {
124    Marginal,
125    Logslope,
126}
127
/// Integrate the product of two polynomials over `[0, width]`.
///
/// `left` and `right` list coefficients in ascending power order, so entry
/// `k` multiplies `t^k`. Each pair of monomials `l·t^p` and `r·t^q`
/// contributes `l · r · width^(p+q+1) / (p+q+1)`; the contributions are
/// accumulated in row-major `(p, q)` order. An empty slice on either side
/// yields `0.0`.
pub(crate) fn integrate_polynomial_product(left: &[f64], right: &[f64], width: f64) -> f64 {
    left.iter()
        .enumerate()
        .flat_map(|(p, &l)| {
            right
                .iter()
                .enumerate()
                .map(move |(q, &r)| (p + q + 1, l * r))
        })
        .fold(0.0, |integral, (exponent, coeff)| {
            integral + coeff * width.powi(exponent as i32) / exponent as f64
        })
}
138
139/// Precomputed per-span polynomial coefficient matrices for a structurally
140/// monotone anchored deviation basis.
141///
142/// Raw coefficients are monotone I-spline coefficients. The deviation
143/// derivative `w'(x)` is a nonnegative quadratic B-spline combination, so
144/// `w(x)` is a cubic I-spline combination with `C2` continuity at knots and
145/// constant tails. Zero coefficients still mean the identity map. The fitted
146/// coefficients live in the configured moment-anchor nullspace and are mapped
147/// back to these raw coefficients for monotonicity.
148///
149/// Monotonicity of the full transform `x + w(x)` is enforced by lower bounds
150/// on each span's quadratic Bernstein controls for `w'(x)`.
151#[derive(Clone, Debug)]
152pub struct DeviationRuntime {
153    pub(crate) degree: usize,
154    pub(crate) value_span_degree: usize,
155    pub(crate) basis_dim: usize,
156    pub(crate) monotonicity_eps: f64,
157    pub(crate) endpoint_points: Array1<f64>,
158    pub(crate) span_c0: Array2<f64>,
159    pub(crate) span_c1: Array2<f64>,
160    pub(crate) span_c2: Array2<f64>,
161    pub(crate) span_c3: Array2<f64>,
162    pub(crate) monotonicity_constraint_rows: Array2<f64>,
163    /// Deviation basis values at the rightmost breakpoint (1 × basis_dim).
164    /// Used for constant-tail continuation outside support: the deviation
165    /// saturates at this value for all z > right endpoint.
166    pub(crate) right_boundary_value_row: Array1<f64>,
167    /// Cross-block installed flex block. `None` until
168    /// `install_compiled_flex_block` is called.
169    pub(crate) installed_flex_block: Option<InstalledFlexBlock>,
170    /// Stacked parametric-anchor rows at training rows (n × d). Used by
171    /// `design_at_training_with_residual` to rebuild `block.design` after
172    /// orthogonalisation. Dropped before serialisation; predict-time
173    /// reconstruction rebuilds anchor rows fresh at the predict-time
174    /// feature rows.
175    pub(crate) anchor_rows_at_training: Option<Array2<f64>>,
176}
177
178/// Build the integrated derivative penalty matrix `P` on the *raw* I-spline
179/// coefficients (before any null-space transform), where
180/// `P_{ij} = ∫ b_i^(k)(x) b_j^(k)(x) dx` integrated piecewise over the knot
181/// support. The null space of `P` is the function-space null space of the
182/// k-th-derivative penalty: polynomials of degree < k. For k = 1 this is
183/// {constants}; for k = 2 it is {constants, linears}; for k = 3 it is
184/// {constants, linears, quadratics}. Dropping these directions from the
185/// basis at construction time is what gives the link-deviation block
186/// β-independent identifiability (the location block's intercept and any
187/// unpenalized location-linear absorb constants/linears in η; β_dev contains
188/// only the wiggle).
189///
190/// Mirrors `integrated_derivative_penalty_with_nullity` but operates on the
191/// raw cubic span coefficients, so it can be evaluated *before* the basis
192/// transform `Z` is constructed (which is what we need to compute `Z`
193/// itself).
194pub(crate) fn raw_integrated_derivative_penalty(
195    endpoint_points: &Array1<f64>,
196    raw_span_c0: &Array2<f64>,
197    raw_span_c1: &Array2<f64>,
198    raw_span_c2: &Array2<f64>,
199    raw_span_c3: &Array2<f64>,
200    derivative_order: usize,
201) -> Result<Array2<f64>, String> {
202    let raw_dim = raw_span_c0.ncols();
203    let n_spans = endpoint_points.len().saturating_sub(1);
204    if raw_span_c1.ncols() != raw_dim
205        || raw_span_c2.ncols() != raw_dim
206        || raw_span_c3.ncols() != raw_dim
207    {
208        return Err("raw smoothness penalty: span coefficient column dimensions disagree".into());
209    }
210    let mut penalty = Array2::<f64>::zeros((raw_dim, raw_dim));
211    for span_idx in 0..n_spans {
212        let left = endpoint_points[span_idx];
213        let right = endpoint_points[span_idx + 1];
214        let width = right - left;
215        if !width.is_finite() || width <= 0.0 {
216            return Err(format!(
217                "raw smoothness penalty span {span_idx} has invalid width {width}"
218            ));
219        }
220        for i in 0..raw_dim {
221            let ci = raw_span_derivative_polynomial_coefficients(
222                span_idx,
223                i,
224                derivative_order,
225                raw_span_c0,
226                raw_span_c1,
227                raw_span_c2,
228                raw_span_c3,
229            );
230            for j in i..raw_dim {
231                let cj = raw_span_derivative_polynomial_coefficients(
232                    span_idx,
233                    j,
234                    derivative_order,
235                    raw_span_c0,
236                    raw_span_c1,
237                    raw_span_c2,
238                    raw_span_c3,
239                );
240                let contribution = integrate_polynomial_product(&ci, &cj, width);
241                penalty[[i, j]] += contribution;
242                if i != j {
243                    penalty[[j, i]] += contribution;
244                }
245            }
246        }
247    }
248    Ok(penalty)
249}
250
251/// Per-span polynomial coefficients of the `derivative_order`-th derivative
252/// of raw basis function `basis_idx` on its parametric coordinate `t`. Mirrors
253/// `DeviationRuntime::span_derivative_polynomial_coefficients` but on raw
254/// coefficients so it's callable before `Z` exists.
255pub(crate) fn raw_span_derivative_polynomial_coefficients(
256    span_idx: usize,
257    basis_idx: usize,
258    derivative_order: usize,
259    raw_span_c0: &Array2<f64>,
260    raw_span_c1: &Array2<f64>,
261    raw_span_c2: &Array2<f64>,
262    raw_span_c3: &Array2<f64>,
263) -> Vec<f64> {
264    let c0 = raw_span_c0[[span_idx, basis_idx]];
265    let c1 = raw_span_c1[[span_idx, basis_idx]];
266    let c2 = raw_span_c2[[span_idx, basis_idx]];
267    let c3 = raw_span_c3[[span_idx, basis_idx]];
268    match derivative_order {
269        0 => vec![c0, c1, c2, c3],
270        1 => vec![c1, 2.0 * c2, 3.0 * c3],
271        2 => vec![2.0 * c2, 6.0 * c3],
272        3 => vec![6.0 * c3],
273        _ => Vec::new(),
274    }
275}
276
277/// Compute `Z` = orthonormal columns spanning the orthogonal complement of
278/// the null space of `P_raw` (the integrated derivative penalty in raw
279/// coordinates). Eigenvectors with strictly-positive eigenvalues are taken;
280/// near-zero eigenvalues correspond to functions with zero `derivative_order`-
281/// th derivative, i.e., polynomials of degree `< derivative_order` evaluated
282/// in the raw basis.
283///
284/// Returned `Z` has shape `raw_dim × (raw_dim − nullity)`. After applying it
285/// (`raw_basis · Z`), the transformed basis cannot represent any polynomial
286/// of degree < `derivative_order` — that direction is structurally absent
287/// from the parameterization. This is the β-independent identifiability
288/// constraint that replaces the data-distribution-dependent moment anchor.
289pub(crate) fn smoothness_nullspace_orthogonal_complement(
290    raw_penalty: &Array2<f64>,
291) -> Result<Array2<f64>, String> {
292    let n = raw_penalty.nrows();
293    if raw_penalty.ncols() != n {
294        return Err("smoothness penalty matrix must be square for null-space drop".to_string());
295    }
296    let (eigenvalues, eigenvectors) = raw_penalty
297        .eigh(faer::Side::Lower)
298        .map_err(|e| format!("raw smoothness penalty eigendecomposition failed: {e}"))?;
299    let evals = eigenvalues
300        .as_slice()
301        .ok_or_else(|| "raw smoothness penalty eigenvalues are not contiguous".to_string())?;
302    let threshold = gam_solve::estimate::reml::reml_outer_engine::positive_eigenvalue_threshold(evals);
303    let kept: Vec<usize> = evals
304        .iter()
305        .enumerate()
306        .filter_map(|(i, &v)| (v > threshold).then_some(i))
307        .collect();
308    if kept.is_empty() {
309        return Err(
310            "smoothness penalty has no positive eigenvalues; basis is entirely in the penalty's \
311             null space and cannot be identified after the smoothness null-space drop"
312                .to_string(),
313        );
314    }
315    if kept.len() == n {
316        return Err(
317            "smoothness penalty has no null directions; nothing to drop. The link-deviation \
318             basis was expected to carry a non-trivial null space (constants/linears) for \
319             absorption by the location block — check the configured penalty derivative order"
320                .to_string(),
321        );
322    }
323    let mut z = Array2::<f64>::zeros((n, kept.len()));
324    for (col_out, &col_in) in kept.iter().enumerate() {
325        z.column_mut(col_out).assign(&eigenvectors.column(col_in));
326    }
327    Ok(z)
328}
329
330pub(crate) fn build_quadratic_derivative_bernstein_constraints(
331    endpoint_points: &Array1<f64>,
332    span_c1: &Array2<f64>,
333    span_c2: &Array2<f64>,
334    span_c3: &Array2<f64>,
335) -> Result<Array2<f64>, String> {
336    let n_spans = endpoint_points.len().saturating_sub(1);
337    let basis_dim = span_c1.ncols();
338    let mut rows = Array2::<f64>::zeros((3 * n_spans, basis_dim));
339    for span_idx in 0..n_spans {
340        let width = endpoint_points[span_idx + 1] - endpoint_points[span_idx];
341        if !width.is_finite() || width <= 0.0 {
342            return Err(DeviationRuntimeError::InvalidInput {
343                reason: format!(
344                    "DeviationRuntime monotonicity span {span_idx} has invalid width {width}"
345                ),
346            }
347            .into());
348        }
349        let left_row = 3 * span_idx;
350        let mid_row = left_row + 1;
351        let right_row = left_row + 2;
352        for basis_idx in 0..basis_dim {
353            let c1 = span_c1[[span_idx, basis_idx]];
354            let c2 = span_c2[[span_idx, basis_idx]];
355            let c3 = span_c3[[span_idx, basis_idx]];
356            // For w(t)=c0+c1*t+c2*t^2+c3*t^3 on t in [0,h],
357            // w'(t)=c1+2*c2*t+3*c3*t^2. In quadratic Bernstein form over
358            // s=t/h, the controls are:
359            //   b0 = c1
360            //   b1 = c1 + c2*h
361            //   b2 = c1 + 2*c2*h + 3*c3*h^2
362            // Since Bernstein basis functions are non-negative and sum to 1,
363            // b_k >= eps-1 is a linear certificate for x + w(x) monotonicity.
364            // `exact_monotonicity_min_slack` below still checks the true
365            // quadratic minimum, including the interior vertex.
366            rows[[left_row, basis_idx]] = c1;
367            rows[[mid_row, basis_idx]] = c1 + c2 * width;
368            rows[[right_row, basis_idx]] = c1 + 2.0 * c2 * width + 3.0 * c3 * width * width;
369        }
370    }
371    Ok(rows)
372}
373
374impl DeviationRuntime {
375    /// Construct the link-deviation runtime with a smoothness-null-space-drop
376    /// basis transform. `max_penalty_derivative_order` is the highest
377    /// derivative order of any penalty that will subsequently be applied to
378    /// this block (computed by the caller from its `DeviationBlockConfig`).
379    /// The returned basis structurally excludes polynomials of degree
380    /// `< max_penalty_derivative_order`, so the configured smoothness
381    /// penalties have no null space on the transformed basis and the
382    /// joint Hessian + penalty system is well-conditioned at every PIRLS
383    /// iteration regardless of how β shifts the linear predictor distribution.
384    ///
385    /// This replaces the previous data-distribution moment anchor (at the
386    /// rigid-pilot η₀), which gave a β-dependent identifiability constraint
387    /// that drifted out of alignment with η_current during PIRLS and produced
388    /// near-singular joint Hessians (σ_min ≈ ridge_floor).
389    pub(crate) fn try_new(
390        knots: Array1<f64>,
391        monotonicity_eps: f64,
392        max_penalty_derivative_order: usize,
393    ) -> Result<Self, String> {
394        Self::try_new_with_smoothness_drop(knots, monotonicity_eps, max_penalty_derivative_order)
395    }
396
397    pub(super) fn try_new_with_smoothness_drop(
398        knots: Array1<f64>,
399        monotonicity_eps: f64,
400        max_penalty_derivative_order: usize,
401    ) -> Result<Self, String> {
402        if !monotonicity_eps.is_finite() || monotonicity_eps < 0.0 {
403            return Err(DeviationRuntimeError::InvalidInput {
404                reason: format!(
405                    "DeviationRuntime monotonicity_eps must be finite and non-negative, got {monotonicity_eps}"
406                ),
407            }
408            .into());
409        }
410
411        let bkpts = breakpoints_from_knots(
412            knots.as_slice().ok_or_else(|| {
413                String::from(DeviationRuntimeError::InvalidInput {
414                    reason: "DeviationRuntime knots are not contiguous".to_string(),
415                })
416            })?,
417            "DeviationRuntime breakpoints",
418        )?;
419        let endpoint_points = Array1::from_vec(bkpts);
420        if endpoint_points.len() < 3 {
421            return Err(DeviationRuntimeError::InvalidInput {
422                reason:
423                    "DeviationRuntime requires at least two active knot spans and one interior node"
424                        .to_string(),
425            }
426            .into());
427        }
428        let n_spans = endpoint_points.len() - 1;
429        for span_idx in 0..n_spans {
430            let left = endpoint_points[span_idx];
431            let right = endpoint_points[span_idx + 1];
432            let width = right - left;
433            if !width.is_finite() || width <= 0.0 {
434                return Err(DeviationRuntimeError::InvalidInput {
435                    reason: format!(
436                        "DeviationRuntime requires strictly increasing span endpoints at span {span_idx}: left={left}, right={right}"
437                    ),
438                }
439                .into());
440            }
441        }
442        let span_lefts = Array1::from_iter((0..n_spans).map(|idx| endpoint_points[idx]));
443        let span_midpoints = Array1::from_iter(
444            (0..n_spans).map(|idx| 0.5 * (endpoint_points[idx] + endpoint_points[idx + 1])),
445        );
446        let right_endpoint = Array1::from_vec(vec![endpoint_points[n_spans]]);
447        let internal_degree = 2usize;
448        let raw_span_c0 =
449            create_ispline_derivative_dense(span_lefts.view(), &knots, internal_degree, 0)
450                .map_err(|e| {
451                    String::from(DeviationRuntimeError::NumericalFailure {
452                        reason: format!("DeviationRuntime cubic I-spline values failed: {e}"),
453                    })
454                })?;
455        let raw_span_c1 =
456            create_ispline_derivative_dense(span_lefts.view(), &knots, internal_degree, 1)
457                .map_err(|e| {
458                    String::from(DeviationRuntimeError::NumericalFailure {
459                        reason: format!(
460                            "DeviationRuntime cubic I-spline first derivatives failed: {e}"
461                        ),
462                    })
463                })?;
464        let raw_span_c2 =
465            create_ispline_derivative_dense(span_lefts.view(), &knots, internal_degree, 2)
466                .map_err(|e| {
467                    String::from(DeviationRuntimeError::NumericalFailure {
468                        reason: format!(
469                            "DeviationRuntime cubic I-spline second derivatives failed: {e}"
470                        ),
471                    })
472                })?
473                .mapv(|value| 0.5 * value);
474        let raw_span_c3 =
475            create_ispline_derivative_dense(span_midpoints.view(), &knots, internal_degree, 3)
476                .map_err(|e| {
477                    String::from(DeviationRuntimeError::NumericalFailure {
478                        reason: format!(
479                            "DeviationRuntime cubic I-spline third derivatives failed: {e}"
480                        ),
481                    })
482                })?
483                .mapv(|value| value / 6.0);
484        let raw_right_boundary_values =
485            create_ispline_derivative_dense(right_endpoint.view(), &knots, internal_degree, 0)
486                .map_err(|e| {
487                    String::from(DeviationRuntimeError::NumericalFailure {
488                        reason: format!(
489                            "DeviationRuntime cubic I-spline right boundary failed: {e}"
490                        ),
491                    })
492                })?;
493        let raw_right_boundary_value_row = raw_right_boundary_values.row(0).to_owned();
494
495        if max_penalty_derivative_order == 0 {
496            return Err(
497                "DeviationRuntime requires max_penalty_derivative_order >= 1 so the basis can \
498                 drop the corresponding smoothness null space; an order-0 (mass) penalty alone \
499                 has no null space and would not require any drop"
500                    .to_string(),
501            );
502        }
503        if max_penalty_derivative_order > 3 {
504            return Err(format!(
505                "DeviationRuntime cubic basis supports derivative orders up to 3; got max \
506                 penalty derivative order {max_penalty_derivative_order}"
507            ));
508        }
509        let raw_smoothness_penalty = raw_integrated_derivative_penalty(
510            &endpoint_points,
511            &raw_span_c0,
512            &raw_span_c1,
513            &raw_span_c2,
514            &raw_span_c3,
515            max_penalty_derivative_order,
516        )?;
517        let coefficient_transform =
518            smoothness_nullspace_orthogonal_complement(&raw_smoothness_penalty)?;
519        let basis_dim = coefficient_transform.ncols();
520        let span_c0 = fast_ab(&raw_span_c0, &coefficient_transform);
521        let span_c1 = fast_ab(&raw_span_c1, &coefficient_transform);
522        let span_c2 = fast_ab(&raw_span_c2, &coefficient_transform);
523        let span_c3 = fast_ab(&raw_span_c3, &coefficient_transform);
524        let right_boundary_value_row = raw_right_boundary_value_row.dot(&coefficient_transform);
525        let monotonicity_constraint_rows = build_quadratic_derivative_bernstein_constraints(
526            &endpoint_points,
527            &span_c1,
528            &span_c2,
529            &span_c3,
530        )?;
531
532        Ok(Self {
533            degree: 3,
534            value_span_degree: 3,
535            basis_dim,
536            monotonicity_eps,
537            endpoint_points,
538            span_c0,
539            span_c1,
540            span_c2,
541            span_c3,
542            monotonicity_constraint_rows,
543            right_boundary_value_row,
544            installed_flex_block: None,
545            anchor_rows_at_training: None,
546        })
547    }
548
549    // The per-block `smoothness_nullspace_orthogonal_complement` transform
550    // above eliminates within-block polynomial aliasing (constants/linears in
551    // η_pilot) so the location block can carry the intercept. That handles
552    // single-flex-block configurations. When two flex blocks of η_pilot are
553    // simultaneously active (score-warp + linkwiggle), each is individually
554    // orthogonal to constants, but their column spans still overlap inside
555    // the orthogonal complement of constants — both are cubic I-spline bases
556    // of the same scalar argument. The overlap manifests as a near-null
557    // direction in the joint penalized Hessian: a linear combination of
558    // β_score_warp and β_link_dev that produces zero net η-contribution at
559    // the rigid-pilot training points yet costs only the (penalised) basis
560    // norm, so Newton steps along that direction blow up.
561    //
562    // Compose an external column transform `T` (shape `basis_dim × new_dim`)
563    // into the cubic span tables and monotonicity constraints. After this
564    // call every `design(...)`-style method returns matrices in the new
565    // `new_dim`-column parameterisation: `runtime.design(values) ==
566    // old_runtime.design(values) · T`. Penalties built later via
567    // `integrated_derivative_penalty_with_nullity` are also expressed in
568    // the new parameterisation.
569    //
570    // Used by `install_compiled_flex_block_into_runtime` to
571    // enforce the joint-design identifiability invariant in the W-metric
572    // (W = p(1−p) at training rows). With `A_train` the stacked parametric
573    // anchors and `C_train = span_eval(values)` the candidate basis at the
574    // training rows, the residualised candidate is
575    //
576    //     C̃_train = (I − P_A^{(W)}) C_train,    P_A^{(W)} = A(AᵀWA)⁻¹AᵀW
577    //
578    // and the kept directions are the eigenvectors of `C̃ᵀ W C̃` above the
579    // numerical noise floor. The block-triangular reparameterisation
580    // `Aβ_A + Cβ_C = A(β_A + Bβ_C) + (C − AB)β_C` with `B = (AᵀWA)⁻¹AᵀWC`
581    // means dropping a direction in C̃ drops *exactly* a direction
582    // span(C) shares with span(A) under W, leaving no aliasing in the
583    // joint design `[X_loc | X_logslope | A | C·V − N·M]` (full column
584    // rank up to numerical tolerance, so `σ_min(joint H+S) ≥ λ_min(S₊)`
585    // regardless of how β shifts the linear-predictor distribution).
586    //
587    // The old `T = null(A_trainᵀ C_train)` algorithm was wrong: that
588    // null-space is the candidate directions *already* exactly W-orthogonal
589    // to A (Gram = 0), not the directions left after projecting A out.
590    // `null(AᵀC) = ∅` does NOT imply `span(C) ⊆ span(A)` — counterexample
591    // `A = e₁`, `C = e₁ + e₂` has `AᵀC = 1 ≠ 0` (empty null space) yet
592    // `(I − P_A) C = e₂ ≠ 0`. Whenever the anchor is wider than the
593    // candidate (d ≥ p_c) the old test generically returned ∅ even when
594    // the residualised candidate had full rank, prompting a spurious
595    // "fully aliased" hard-error. The current code residualises and keeps
596    // exactly the surviving rank.
597    /// Compose a rank-reveal right-selector and an optional anchor-residual.
598    /// After this call, `design(x)` returns
599    ///   design_row(x) = span_eval(x) · V  −  n_row(x) · installed.anchor_correction
600    /// where V is `right_selector` (applied via right-multiplication into
601    /// `span_c{0..3}`). Only the `design()` path (derivative_order=0) subtracts
602    /// the residual: the anchor argument is a different scalar variable than
603    /// the candidate argument, so d/dx of `n_row(x)` w.r.t. the candidate
604    /// argument is identically zero.
605    pub(crate) fn compose_anchor_orthogonalisation(
606        &mut self,
607        right_selector: &Array2<f64>,
608        installed_flex_block: Option<InstalledFlexBlock>,
609    ) -> Result<(), String> {
610        let old_dim = self.basis_dim;
611        if right_selector.nrows() != old_dim {
612            return Err(DeviationRuntimeError::DimensionMismatch {
613                reason: format!(
614                    "DeviationRuntime cross-block transform shape mismatch: \
615                     transform rows={}, expected basis_dim={}",
616                    right_selector.nrows(),
617                    old_dim,
618                ),
619            }
620            .into());
621        }
622        let new_dim = right_selector.ncols();
623        if new_dim == 0 {
624            return Err(DeviationRuntimeError::DimensionMismatch {
625                reason: "DeviationRuntime cross-block transform reduces basis dim to 0; \
626                 the candidate's column span is fully aliased by the anchor block"
627                    .to_string(),
628            }
629            .into());
630        }
631        if new_dim > old_dim {
632            return Err(DeviationRuntimeError::DimensionMismatch {
633                reason: format!(
634                    "DeviationRuntime cross-block transform must not increase basis dim; \
635                     got new_dim={} from old_dim={}",
636                    new_dim, old_dim,
637                ),
638            }
639            .into());
640        }
641        if let Some(ref installed) = installed_flex_block {
642            let d_expected: usize = installed
643                .anchor_components
644                .iter()
645                .map(|c| match c {
646                    AnchorComponentTag::Parametric { ncols, .. } => *ncols,
647                    AnchorComponentTag::FlexEvaluation { ncols } => *ncols,
648                })
649                .sum();
650            if installed.anchor_correction.nrows() != d_expected {
651                return Err(DeviationRuntimeError::DimensionMismatch {
652                    reason: format!(
653                        "DeviationRuntime installed flex block: anchor_correction rows={}, expected sum-of-component-ncols={}",
654                        installed.anchor_correction.nrows(),
655                        d_expected,
656                    ),
657                }
658                .into());
659            }
660            if installed.anchor_correction.ncols() != new_dim {
661                return Err(DeviationRuntimeError::DimensionMismatch {
662                    reason: format!(
663                        "DeviationRuntime installed flex block: anchor_correction cols={}, expected new basis dim {}",
664                        installed.anchor_correction.ncols(),
665                        new_dim,
666                    ),
667                }
668                .into());
669            }
670        }
671        self.span_c0 = fast_ab(&self.span_c0, right_selector);
672        self.span_c1 = fast_ab(&self.span_c1, right_selector);
673        self.span_c2 = fast_ab(&self.span_c2, right_selector);
674        self.span_c3 = fast_ab(&self.span_c3, right_selector);
675        // `right_boundary_value_row` is a 1-D row vector of length basis_dim;
676        // right-multiplying by V (basis_dim × new_dim) gives the new row.
677        self.right_boundary_value_row = self.right_boundary_value_row.dot(right_selector);
678        // Monotonicity rows (n_constraints × basis_dim) follow the same
679        // right-multiplication. The constraint inequality `A β ≥ ε - 1`
680        // becomes `(A · V) β_new ≥ ε - 1` under the reparameterisation
681        // β = V β_new, so the row matrix is right-multiplied directly.
682        self.monotonicity_constraint_rows =
683            fast_ab(&self.monotonicity_constraint_rows, right_selector);
684        self.basis_dim = new_dim;
685        self.installed_flex_block = installed_flex_block;
686        Ok(())
687    }
688
689    /// Accessor for the installed flex block set via
690    /// `install_compiled_flex_block`. Save-time code uses this to snapshot
691    /// the install state into the saved model; predict-time code reconstructs
692    /// the per-row η correction `n_row · anchor_correction · β`.
693    pub fn installed_flex_block(&self) -> Option<&InstalledFlexBlock> {
694        self.installed_flex_block.as_ref()
695    }
696
697    /// Single-step install of a compiled flex block from
698    /// `identifiability::families::compiler::compile`.
699    ///
700    /// Semantics:
701    /// - `compiled.t_lw` is the right-selector `V` applied to `span_c{0..3}`,
702    ///   `right_boundary_value_row`, and `monotonicity_constraint_rows`.
703    /// - `compiled.anchor_correction` (always `Some` for non-empty anchor
704    ///   unions) is the d×k correction `M`.
705    /// - `anchor_components` records the per-anchor predict-time tags so
706    ///   the saved-model rebuild can replay the anchor row map.
707    /// - `n_train_at_training` is cached for
708    ///   `design_at_training_with_residual`.
709    pub(crate) fn install_compiled_flex_block(
710        &mut self,
711        compiled: &gam_identifiability::families::compiler::CompiledBlock,
712        anchor_components: Vec<AnchorComponentTag>,
713        n_train_at_training: Array2<f64>,
714    ) -> Result<(), String> {
715        let m = compiled.anchor_correction.as_ref().ok_or_else(|| {
716            "DeviationRuntime::install_compiled_flex_block: compiled block has no \
717             anchor_correction — install requires a non-empty anchor union"
718                .to_string()
719        })?;
720        let installed = InstalledFlexBlock {
721            anchor_correction: m.clone(),
722            anchor_components,
723        };
724        self.anchor_rows_at_training = Some(n_train_at_training);
725        self.compose_anchor_orthogonalisation(&compiled.t_lw, Some(installed))
726    }
727
728    /// Cached parametric-anchor matrix at training rows, installed by
729    /// `install_compiled_flex_block_into_runtime` when the
730    /// runtime is reparameterised against the parametric anchor union.
731    /// Used by per-row link-deviation evaluators that need the row's
732    /// anchor slice to apply `design_with_anchor_rows` correctly. Returns
733    /// `None` for runtimes that have not been reparameterised.
734    pub fn anchor_rows_at_training(&self) -> Option<&Array2<f64>> {
735        self.anchor_rows_at_training.as_ref()
736    }
737
738    /// Evaluate `design(values) - anchor_rows · M` where `anchor_rows` is
739    /// the n × d parametric-anchor matrix at the same rows as `values`.
740    /// Mandatory when an installed flex block is present; for runtimes
741    /// without one this is equivalent to `design(values)` and `anchor_rows`
742    /// must be `n × 0`.
743    pub fn design_with_anchor_rows(
744        &self,
745        values: &Array1<f64>,
746        anchor_rows: ArrayView2<f64>,
747    ) -> Result<Array2<f64>, String> {
748        let mut out = self.evaluate_span_polynomial_design_raw(values, 0)?;
749        if let Some(installed) = &self.installed_flex_block {
750            if anchor_rows.nrows() != values.len() {
751                return Err(DeviationRuntimeError::DimensionMismatch {
752                    reason: format!(
753                        "design_with_anchor_rows: anchor_rows has {} rows, expected {} (matching values)",
754                        anchor_rows.nrows(),
755                        values.len(),
756                    ),
757                }
758                .into());
759            }
760            if anchor_rows.ncols() != installed.anchor_correction.nrows() {
761                return Err(DeviationRuntimeError::DimensionMismatch {
762                    reason: format!(
763                        "design_with_anchor_rows: anchor_rows has {} cols, expected {} (sum of component ncols)",
764                        anchor_rows.ncols(),
765                        installed.anchor_correction.nrows(),
766                    ),
767                }
768                .into());
769            }
770            let subtract = anchor_rows.dot(&installed.anchor_correction);
771            out = out - subtract;
772        } else if anchor_rows.ncols() != 0 {
773            // Permit empty 0-col anchor rows without complaint; otherwise
774            // hard-error so callers don't silently pass mismatched rows.
775            return Err(DeviationRuntimeError::DimensionMismatch {
776                reason: format!(
777                    "design_with_anchor_rows: runtime has no installed flex block but anchor_rows has {} cols",
778                    anchor_rows.ncols(),
779                ),
780            }
781            .into());
782        }
783        Ok(out)
784    }
785
786    /// Rebuild the training-row design after orthogonalisation, using
787    /// `anchor_rows_at_training` cached at `install_compiled_flex_block` time.
788    pub(crate) fn design_at_training_with_residual(
789        &self,
790        values: &Array1<f64>,
791    ) -> Result<Array2<f64>, String> {
792        if let Some(rows) = self.anchor_rows_at_training.as_ref() {
793            self.design_with_anchor_rows(values, rows.view())
794        } else if self.installed_flex_block.is_some() {
795            Err(
796                "design_at_training_with_residual: runtime has installed_flex_block but no cached training anchor rows"
797                    .to_string(),
798            )
799        } else {
800            self.design(values)
801        }
802    }
803
804    // ── public field accessors ──
805
806    pub fn degree(&self) -> usize {
807        self.degree
808    }
809
810    pub fn value_span_degree(&self) -> usize {
811        self.value_span_degree
812    }
813
814    pub fn basis_dim(&self) -> usize {
815        self.basis_dim
816    }
817
818    pub fn monotonicity_eps(&self) -> f64 {
819        self.monotonicity_eps
820    }
821
822    pub fn span_c0(&self) -> &Array2<f64> {
823        &self.span_c0
824    }
825
826    pub fn span_c1(&self) -> &Array2<f64> {
827        &self.span_c1
828    }
829
830    pub fn span_c2(&self) -> &Array2<f64> {
831        &self.span_c2
832    }
833
834    pub fn span_c3(&self) -> &Array2<f64> {
835        &self.span_c3
836    }
837
838    // ── design evaluation ──
839
840    pub(super) fn validate_beta_shape(
841        &self,
842        beta: &Array1<f64>,
843        label: &str,
844    ) -> Result<(), String> {
845        if beta.len() != self.basis_dim {
846            return Err(DeviationRuntimeError::DimensionMismatch {
847                reason: format!(
848                    "{label} length mismatch: got {}, expected {}",
849                    beta.len(),
850                    self.basis_dim
851                ),
852            }
853            .into());
854        }
855        Ok::<(), _>(())
856    }
857
858    /// Raw cubic-span polynomial design evaluation, without any
859    /// anchor-residual subtraction. Internal — callers that need the
860    /// residualised design must go through `design()` (which asserts no
861    /// residual) or `design_with_anchor_rows()`.
862    pub(super) fn evaluate_span_polynomial_design_raw(
863        &self,
864        values: &Array1<f64>,
865        derivative_order: usize,
866    ) -> Result<Array2<f64>, String> {
867        let (left_ep, right_ep) = self.support_interval()?;
868        let mut out = Array2::<f64>::zeros((values.len(), self.basis_dim));
869        for (row_idx, &value) in values.iter().enumerate() {
870            if !value.is_finite() {
871                return Err(DeviationRuntimeError::InvalidInput {
872                    reason: format!(
873                        "deviation runtime design value at row {row_idx} is non-finite ({value})"
874                    ),
875                }
876                .into());
877            }
878            if value < left_ep {
879                if derivative_order == 0 {
880                    out.row_mut(row_idx).assign(&self.span_c0.row(0));
881                }
882                continue;
883            }
884            if value > right_ep {
885                if derivative_order == 0 {
886                    out.row_mut(row_idx)
887                        .assign(&self.right_boundary_value_row.view());
888                }
889                continue;
890            }
891            let span_idx = self.left_biased_span_index_for(value)?;
892            let left = self.endpoint_points[span_idx];
893            let t = value - left;
894            for basis_idx in 0..self.basis_dim {
895                let c0 = self.span_c0[[span_idx, basis_idx]];
896                let c1 = self.span_c1[[span_idx, basis_idx]];
897                let c2 = self.span_c2[[span_idx, basis_idx]];
898                let c3 = self.span_c3[[span_idx, basis_idx]];
899                out[[row_idx, basis_idx]] = match derivative_order {
900                    0 => c0 + c1 * t + c2 * t * t + c3 * t * t * t,
901                    1 => c1 + 2.0 * c2 * t + 3.0 * c3 * t * t,
902                    2 => 2.0 * c2 + 6.0 * c3 * t,
903                    3 => 6.0 * c3,
904                    4 => 0.0,
905                    other => {
906                        return Err(DeviationRuntimeError::InvalidInput {
907                            reason: format!(
908                                "deviation runtime only supports derivative orders up to 4, got {other}"
909                            ),
910                        }
911                        .into());
912                    }
913                };
914            }
915        }
916        Ok(out)
917    }
918
919    /// Pure-span design (no anchor-residual subtraction). Callers must
920    /// ensure the runtime has no anchor residual; otherwise use
921    /// `design_with_anchor_rows`. Derivative paths are unaffected: the
922    /// residual subtraction `n_row · M` is constant in the candidate
923    /// argument, so its derivatives are identically zero.
924    pub fn design(&self, values: &Array1<f64>) -> Result<Array2<f64>, String> {
925        assert!(
926            self.installed_flex_block.is_none(),
927            "DeviationRuntime::design called on a runtime with an installed flex block; \
928             use design_with_anchor_rows or design_at_training_with_residual instead"
929        );
930        self.evaluate_span_polynomial_design_raw(values, 0)
931    }
932
933    pub fn first_derivative_design(&self, values: &Array1<f64>) -> Result<Array2<f64>, String> {
934        self.evaluate_span_polynomial_design_raw(values, 1)
935    }
936
937    pub fn second_derivative_design(&self, values: &Array1<f64>) -> Result<Array2<f64>, String> {
938        self.evaluate_span_polynomial_design_raw(values, 2)
939    }
940
941    pub fn third_derivative_design(&self, values: &Array1<f64>) -> Result<Array2<f64>, String> {
942        self.evaluate_span_polynomial_design_raw(values, 3)
943    }
944
945    pub(crate) fn integrated_derivative_penalty_with_nullity(
946        &self,
947        derivative_order: usize,
948    ) -> Result<(Array2<f64>, usize), String> {
949        if derivative_order > self.value_span_degree {
950            return Err(DeviationRuntimeError::InvalidInput {
951                reason: format!(
952                    "deviation penalty derivative order {derivative_order} exceeds value-basis degree {}",
953                    self.value_span_degree
954                ),
955            }
956            .into());
957        }
958        let mut penalty = Array2::<f64>::zeros((self.basis_dim, self.basis_dim));
959        for span_idx in 0..self.span_count() {
960            let (left, right) = self.span_interval(span_idx)?;
961            let width = right - left;
962            if !width.is_finite() || width <= 0.0 {
963                return Err(DeviationRuntimeError::InvalidInput {
964                    reason: format!("deviation penalty span {span_idx} has invalid width {width}"),
965                }
966                .into());
967            }
968            for i in 0..self.basis_dim {
969                let ci =
970                    self.span_derivative_polynomial_coefficients(span_idx, i, derivative_order)?;
971                for j in i..self.basis_dim {
972                    let cj = self.span_derivative_polynomial_coefficients(
973                        span_idx,
974                        j,
975                        derivative_order,
976                    )?;
977                    let contribution = integrate_polynomial_product(&ci, &cj, width);
978                    penalty[[i, j]] += contribution;
979                    if i != j {
980                        penalty[[j, i]] += contribution;
981                    }
982                }
983            }
984        }
985        let (evals, _) = penalty.eigh(faer::Side::Lower).map_err(|e| {
986            String::from(DeviationRuntimeError::NumericalFailure {
987                reason: format!("deviation integrated penalty eigendecomposition failed: {e}"),
988            })
989        })?;
990        let threshold = gam_solve::estimate::reml::reml_outer_engine::positive_eigenvalue_threshold(
991            evals.as_slice().ok_or_else(|| {
992                String::from(DeviationRuntimeError::NumericalFailure {
993                    reason: "deviation penalty eigenvalues are not contiguous".to_string(),
994                })
995            })?,
996        );
997        let rank = evals.iter().filter(|&&value| value > threshold).count();
998        let nullity = self.basis_dim.saturating_sub(rank);
999        Ok((penalty, nullity))
1000    }
1001
1002    pub(crate) fn structural_monotonicity_constraints(&self) -> LinearInequalityConstraints {
1003        LinearInequalityConstraints {
1004            a: self.monotonicity_constraint_rows.clone(),
1005            b: Array1::from_elem(
1006                self.monotonicity_constraint_rows.nrows(),
1007                self.monotonicity_eps - 1.0,
1008            ),
1009        }
1010    }
1011
1012    // ── span geometry ──
1013
1014    pub(super) fn span_count(&self) -> usize {
1015        self.endpoint_points.len().saturating_sub(1)
1016    }
1017
1018    pub fn breakpoints(&self) -> &Array1<f64> {
1019        &self.endpoint_points
1020    }
1021
1022    pub(super) fn span_interval(&self, span_idx: usize) -> Result<(f64, f64), String> {
1023        if span_idx >= self.span_count() {
1024            return Err(DeviationRuntimeError::InvalidInput {
1025                reason: format!(
1026                    "deviation span index {} out of range for {} spans",
1027                    span_idx,
1028                    self.span_count()
1029                ),
1030            }
1031            .into());
1032        }
1033        Ok((
1034            self.endpoint_points[span_idx],
1035            self.endpoint_points[span_idx + 1],
1036        ))
1037    }
1038
1039    pub(super) fn span_index_for(&self, value: f64) -> Result<usize, String> {
1040        span_index_for_breakpoints(
1041            self.endpoint_points.as_slice().ok_or_else(|| {
1042                String::from(DeviationRuntimeError::InvalidInput {
1043                    reason: "deviation runtime breakpoints are not contiguous".to_string(),
1044                })
1045            })?,
1046            value,
1047            "deviation span lookup",
1048        )
1049    }
1050
1051    pub(super) fn left_biased_span_index_for(&self, value: f64) -> Result<usize, String> {
1052        let mut span_idx = self.span_index_for(value)?;
1053        // Bias to the LEFT-hand span at internal breakpoints. The cubic basis
1054        // is C², so value, first derivative, and second derivative are
1055        // unchanged; only the span-local third derivative needs a convention.
1056        if span_idx > 0 && value == self.endpoint_points[span_idx] {
1057            span_idx -= 1;
1058        }
1059        Ok(span_idx)
1060    }
1061
1062    pub(super) fn span_derivative_polynomial_coefficients(
1063        &self,
1064        span_idx: usize,
1065        basis_idx: usize,
1066        derivative_order: usize,
1067    ) -> Result<Vec<f64>, String> {
1068        if span_idx >= self.span_count() {
1069            return Err(DeviationRuntimeError::InvalidInput {
1070                reason: format!(
1071                    "deviation span index {} out of range for {} spans",
1072                    span_idx,
1073                    self.span_count()
1074                ),
1075            }
1076            .into());
1077        }
1078        if basis_idx >= self.basis_dim {
1079            return Err(DeviationRuntimeError::InvalidInput {
1080                reason: format!(
1081                    "deviation basis index {} out of range for {} coefficients",
1082                    basis_idx, self.basis_dim
1083                ),
1084            }
1085            .into());
1086        }
1087        let c0 = self.span_c0[[span_idx, basis_idx]];
1088        let c1 = self.span_c1[[span_idx, basis_idx]];
1089        let c2 = self.span_c2[[span_idx, basis_idx]];
1090        let c3 = self.span_c3[[span_idx, basis_idx]];
1091        match derivative_order {
1092            0 => Ok(vec![c0, c1, c2, c3]),
1093            1 => Ok(vec![c1, 2.0 * c2, 3.0 * c3]),
1094            2 => Ok(vec![2.0 * c2, 6.0 * c3]),
1095            3 => Ok(vec![6.0 * c3]),
1096            other => Err(DeviationRuntimeError::InvalidInput {
1097                reason: format!(
1098                    "deviation polynomial coefficients only support derivative orders up to 3, got {other}"
1099                ),
1100            }
1101            .into()),
1102        }
1103    }
1104
1105    // ── cubic Taylor extraction ──
1106
1107    pub(crate) fn local_cubic_on_span(
1108        &self,
1109        beta: &Array1<f64>,
1110        span_idx: usize,
1111    ) -> Result<exact_kernel::LocalSpanCubic, String> {
1112        self.validate_beta_shape(beta, "deviation local cubic coefficients")?;
1113        let (left, right) = self.span_interval(span_idx)?;
1114        Ok(exact_kernel::LocalSpanCubic {
1115            left,
1116            right,
1117            c0: self.span_c0.row(span_idx).dot(beta),
1118            c1: self.span_c1.row(span_idx).dot(beta),
1119            c2: self.span_c2.row(span_idx).dot(beta),
1120            c3: self.span_c3.row(span_idx).dot(beta),
1121        })
1122    }
1123
1124    pub fn basis_span_cubic(
1125        &self,
1126        span_idx: usize,
1127        basis_idx: usize,
1128    ) -> Result<exact_kernel::LocalSpanCubic, String> {
1129        if basis_idx >= self.basis_dim {
1130            return Err(DeviationRuntimeError::InvalidInput {
1131                reason: format!(
1132                    "deviation basis index {} out of range for {} coefficients",
1133                    basis_idx, self.basis_dim
1134                ),
1135            }
1136            .into());
1137        }
1138        let (left, right) = self.span_interval(span_idx)?;
1139        Ok(exact_kernel::LocalSpanCubic {
1140            left,
1141            right,
1142            c0: self.span_c0[[span_idx, basis_idx]],
1143            c1: self.span_c1[[span_idx, basis_idx]],
1144            c2: self.span_c2[[span_idx, basis_idx]],
1145            c3: self.span_c3[[span_idx, basis_idx]],
1146        })
1147    }
1148
1149    /// Return the correct per-basis `LocalSpanCubic` for any evaluation
1150    /// point. Strictly outside the knot support, returns a constant cubic
1151    /// (c1=c2=c3=0) at the saturated tail value. Interior breakpoints use the
1152    /// left span so span-local third derivatives match derivative designs.
1153    pub fn basis_cubic_at(
1154        &self,
1155        basis_idx: usize,
1156        value: f64,
1157    ) -> Result<exact_kernel::LocalSpanCubic, String> {
1158        if basis_idx >= self.basis_dim {
1159            return Err(DeviationRuntimeError::InvalidInput {
1160                reason: format!(
1161                    "deviation basis index {} out of range for {} coefficients",
1162                    basis_idx, self.basis_dim
1163                ),
1164            }
1165            .into());
1166        }
1167        let (left_ep, right_ep) = self.support_interval()?;
1168        if value < left_ep {
1169            return Ok(exact_kernel::LocalSpanCubic {
1170                left: left_ep,
1171                right: left_ep + 1.0,
1172                c0: self.span_c0[[0, basis_idx]],
1173                c1: 0.0,
1174                c2: 0.0,
1175                c3: 0.0,
1176            });
1177        }
1178        if value > right_ep {
1179            return Ok(exact_kernel::LocalSpanCubic {
1180                left: right_ep,
1181                right: right_ep + 1.0,
1182                c0: self.right_boundary_value_row[basis_idx],
1183                c1: 0.0,
1184                c2: 0.0,
1185                c3: 0.0,
1186            });
1187        }
1188        let span_idx = self.left_biased_span_index_for(value)?;
1189        self.basis_span_cubic(span_idx, basis_idx)
1190    }
1191
1192    pub fn for_each_basis_cubic_at<F>(&self, value: f64, mut visit: F) -> Result<(), String>
1193    where
1194        F: FnMut(usize, exact_kernel::LocalSpanCubic) -> Result<(), String>,
1195    {
1196        let (left_ep, right_ep) = self.support_interval()?;
1197        if value < left_ep {
1198            for basis_idx in 0..self.basis_dim {
1199                visit(
1200                    basis_idx,
1201                    exact_kernel::LocalSpanCubic {
1202                        left: left_ep,
1203                        right: left_ep + 1.0,
1204                        c0: self.span_c0[[0, basis_idx]],
1205                        c1: 0.0,
1206                        c2: 0.0,
1207                        c3: 0.0,
1208                    },
1209                )?;
1210            }
1211            return Ok(());
1212        }
1213        if value > right_ep {
1214            for basis_idx in 0..self.basis_dim {
1215                visit(
1216                    basis_idx,
1217                    exact_kernel::LocalSpanCubic {
1218                        left: right_ep,
1219                        right: right_ep + 1.0,
1220                        c0: self.right_boundary_value_row[basis_idx],
1221                        c1: 0.0,
1222                        c2: 0.0,
1223                        c3: 0.0,
1224                    },
1225                )?;
1226            }
1227            return Ok(());
1228        }
1229
1230        let span_idx = self.left_biased_span_index_for(value)?;
1231        let (left, right) = self.span_interval(span_idx)?;
1232        for basis_idx in 0..self.basis_dim {
1233            visit(
1234                basis_idx,
1235                exact_kernel::LocalSpanCubic {
1236                    left,
1237                    right,
1238                    c0: self.span_c0[[span_idx, basis_idx]],
1239                    c1: self.span_c1[[span_idx, basis_idx]],
1240                    c2: self.span_c2[[span_idx, basis_idx]],
1241                    c3: self.span_c3[[span_idx, basis_idx]],
1242                },
1243            )?;
1244        }
1245        Ok(())
1246    }
1247
1248    /// Return the correct composite `LocalSpanCubic` for any evaluation
1249    /// point. Strictly outside the knot support, returns a constant cubic
1250    /// (c1=c2=c3=0) at the saturated tail value. Interior breakpoints use the
1251    /// left span so span-local third derivatives match derivative designs.
1252    pub(crate) fn local_cubic_at(
1253        &self,
1254        beta: &Array1<f64>,
1255        value: f64,
1256    ) -> Result<exact_kernel::LocalSpanCubic, String> {
1257        self.validate_beta_shape(beta, "deviation local cubic")?;
1258        let (left_ep, right_ep) = self.support_interval()?;
1259        if value < left_ep {
1260            return Ok(exact_kernel::LocalSpanCubic {
1261                left: left_ep,
1262                right: left_ep + 1.0,
1263                c0: self.left_tail_value(beta),
1264                c1: 0.0,
1265                c2: 0.0,
1266                c3: 0.0,
1267            });
1268        }
1269        if value > right_ep {
1270            return Ok(exact_kernel::LocalSpanCubic {
1271                left: right_ep,
1272                right: right_ep + 1.0,
1273                c0: self.right_tail_value(beta),
1274                c1: 0.0,
1275                c2: 0.0,
1276                c3: 0.0,
1277            });
1278        }
1279        let span_idx = self.left_biased_span_index_for(value)?;
1280        self.local_cubic_on_span(beta, span_idx)
1281    }
1282
1283    // ── tail value helpers ──
1284
1285    /// Left-tail constant: deviation value at the leftmost breakpoint.
1286    /// For anchored I-spline bases this is the anchor value (typically 0).
1287    pub(super) fn left_tail_value(&self, beta: &Array1<f64>) -> f64 {
1288        self.span_c0.row(0).dot(beta)
1289    }
1290
1291    /// Right-tail constant: deviation value at the rightmost breakpoint.
1292    /// For I-spline bases this is the saturated integral value.
1293    pub(super) fn right_tail_value(&self, beta: &Array1<f64>) -> f64 {
1294        self.right_boundary_value_row.dot(beta)
1295    }
1296
1297    /// Conservative L1 sup-norm bound for the deviation value basis.
1298    ///
1299    /// For every evaluation point `x`, this returns a finite `K` such that
1300    /// `|B(x)·β| <= K * ||β||_∞`.  Each basis column is a cubic on each
1301    /// finite span and constant in the two tails, so the supremum is attained
1302    /// at a span endpoint, an interior root of the derivative, or a tail
1303    /// value.  Summing per-column suprema gives a conservative row-wise L1
1304    /// bound that is independent of `x`.
1305    pub(crate) fn value_basis_l1_sup_norm(&self) -> f64 {
1306        let mut total = 0.0;
1307        for basis_idx in 0..self.basis_dim {
1308            let mut col_sup = self.span_c0[[0, basis_idx]]
1309                .abs()
1310                .max(self.right_boundary_value_row[basis_idx].abs());
1311            for span_idx in 0..self.span_count() {
1312                let left = self.endpoint_points[span_idx];
1313                let right = self.endpoint_points[span_idx + 1];
1314                let width = right - left;
1315                if !width.is_finite() || width <= 0.0 {
1316                    continue;
1317                }
1318                let c0 = self.span_c0[[span_idx, basis_idx]];
1319                let c1 = self.span_c1[[span_idx, basis_idx]];
1320                let c2 = self.span_c2[[span_idx, basis_idx]];
1321                let c3 = self.span_c3[[span_idx, basis_idx]];
1322                let eval_abs = |t: f64| (c0 + c1 * t + c2 * t * t + c3 * t * t * t).abs();
1323                col_sup = col_sup.max(eval_abs(0.0)).max(eval_abs(width));
1324                let a = 3.0 * c3;
1325                let b = 2.0 * c2;
1326                let c = c1;
1327                if a.abs() <= f64::EPSILON {
1328                    if b.abs() > f64::EPSILON {
1329                        let t = -c / b;
1330                        if t > 0.0 && t < width {
1331                            col_sup = col_sup.max(eval_abs(t));
1332                        }
1333                    }
1334                } else {
1335                    let disc = b * b - 4.0 * a * c;
1336                    if disc >= 0.0 {
1337                        let sqrt_disc = disc.sqrt();
1338                        for t in [(-b - sqrt_disc) / (2.0 * a), (-b + sqrt_disc) / (2.0 * a)] {
1339                            if t > 0.0 && t < width {
1340                                col_sup = col_sup.max(eval_abs(t));
1341                            }
1342                        }
1343                    }
1344                }
1345            }
1346            total += col_sup;
1347        }
1348        total
1349    }
1350
1351    // ── monotonicity enforcement ──
1352
1353    pub(super) fn support_interval(&self) -> Result<(f64, f64), String> {
1354        match (self.endpoint_points.first(), self.endpoint_points.last()) {
1355            (Some(&left), Some(&right)) => Ok((left, right)),
1356            _ => Err(DeviationRuntimeError::InvalidInput {
1357                reason: "deviation runtime is missing monotonicity support points".to_string(),
1358            }
1359            .into()),
1360        }
1361    }
1362
1363    pub(crate) fn exact_monotonicity_min_slack(&self, beta: &Array1<f64>) -> Result<f64, String> {
1364        if beta.len() != self.basis_dim {
1365            return Err(DeviationRuntimeError::DimensionMismatch {
1366                reason: format!(
1367                    "deviation monotonicity length mismatch: got {}, expected {}",
1368                    beta.len(),
1369                    self.basis_dim
1370                ),
1371            }
1372            .into());
1373        }
1374        if beta.iter().any(|value| !value.is_finite()) {
1375            let bad = beta
1376                .iter()
1377                .enumerate()
1378                .find(|(_, value)| !value.is_finite())
1379                .map(|(idx, value)| format!("deviation coefficient {idx} is non-finite ({value})"))
1380                .unwrap_or_else(|| "deviation coefficient is non-finite".to_string());
1381            return Err(DeviationRuntimeError::InvalidInput { reason: bad }.into());
1382        }
1383
1384        let mut min_slack = f64::INFINITY;
1385        for span_idx in 0..self.span_count() {
1386            let left = self.endpoint_points[span_idx];
1387            let right = self.endpoint_points[span_idx + 1];
1388            let width = right - left;
1389            if !width.is_finite() || width <= 0.0 {
1390                continue;
1391            }
1392            let c1 = self.span_c1.row(span_idx).dot(beta);
1393            let c2 = self.span_c2.row(span_idx).dot(beta);
1394            let c3 = self.span_c3.row(span_idx).dot(beta);
1395            let d1_left = c1;
1396            let d1_right = c1 + 2.0 * c2 * width + 3.0 * c3 * width * width;
1397            let d2_left = 2.0 * c2;
1398            let d3 = 6.0 * c3;
1399            let left_slack = 1.0 + d1_left - self.monotonicity_eps;
1400            let right_slack = 1.0 + d1_right - self.monotonicity_eps;
1401            min_slack = min_slack.min(left_slack.min(right_slack));
1402
1403            if d3 > 0.0 {
1404                let t_star = -d2_left / d3;
1405                if t_star > 0.0 && t_star < width {
1406                    let interior = 1.0 + d1_left + d2_left * t_star + 0.5 * d3 * t_star * t_star
1407                        - self.monotonicity_eps;
1408                    min_slack = min_slack.min(interior);
1409                }
1410            }
1411        }
1412        if min_slack.is_finite() {
1413            Ok(min_slack)
1414        } else {
1415            Err(DeviationRuntimeError::NumericalFailure {
1416                reason: "deviation monotonicity slack computation produced no active spans"
1417                    .to_string(),
1418            }
1419            .into())
1420        }
1421    }
1422
1423    pub(crate) fn monotonicity_feasible(
1424        &self,
1425        beta: &Array1<f64>,
1426        context: &str,
1427    ) -> Result<(), String> {
1428        let slack = self.exact_monotonicity_min_slack(beta)?;
1429        if slack >= MONOTONICITY_SLACK_ROUNDOFF_TOL {
1430            Ok(())
1431        } else {
1432            let (left, right) = self.support_interval()?;
1433            Err(DeviationRuntimeError::NumericalFailure {
1434                reason: format!(
1435                    "{context} violates exact monotonicity on [{left:.6}, {right:.6}] (minimum derivative slack {slack:.3e}, eps={:.3e})",
1436                    self.monotonicity_eps
1437                ),
1438            }
1439            .into())
1440        }
1441    }
1442}