Skip to main content

gam_terms/basis/
types.rs

1use super::*;
2
/// Copyable raw-pointer handle used to hand one `f64` buffer to parallel workers.
///
/// Raw pointers are neither `Send` nor `Sync`; this newtype opts into both so
/// the base address of a buffer can be captured by worker closures during a
/// parallel fill.
///
/// SAFETY: every `SendPtr` value must be built from live, properly aligned `f64`
/// storage whose mutable borrow is held until all worker threads finish; callers
/// may only dereference offsets that are in-bounds and disjoint across workers.
#[derive(Clone, Copy)]
pub(crate) struct SendPtr(pub(crate) *mut f64);

// SAFETY: SendPtr only grants raw-pointer transport. Actual dereferences occur
// at call sites after row-chunk partitioning proves each thread writes a
// distinct in-bounds element of the backing Array/Vec allocation.
unsafe impl Send for SendPtr {}

// SAFETY: shared references to SendPtr are sound because the pointee is never
// accessed through the wrapper without the call-site disjoint-offset proof.
unsafe impl Sync for SendPtr {}

impl SendPtr {
    /// Returns the address `offset` elements past the wrapped base pointer.
    ///
    /// The address is computed with `wrapping_add`, which is defined for every
    /// `offset`. `pointer::add` requires the result to stay inside the backing
    /// allocation, a precondition this safe signature cannot enforce, so an
    /// out-of-range offset used to be undefined behaviour at the call itself.
    ///
    /// Dereferencing the returned pointer is still `unsafe`: the caller must
    /// have proven that the target element is in-bounds and uniquely owned by
    /// that worker's chunk for the whole parallel region.
    #[inline(always)]
    pub(crate) fn add(self, offset: usize) -> *mut f64 {
        self.0.wrapping_add(offset)
    }
}
28
29/// Re-export of the neutral basis-error contract. #1521: `BasisError` lives
30/// in `gam-problem` so `EstimationError` can wrap it (`#[from]`) without a
31/// back-edge; gam-terms re-exports it to preserve `gam_terms::basis::BasisError`.
32pub use gam_problem::BasisError;
33
34// ============================================================================
35// Unified Basis Generation API
36// ============================================================================
37
38/// Options for basis generation, controlling derivative order.
39#[derive(Clone, Copy, Debug, Default)]
40pub struct BasisOptions {
41    /// Derivative order: 0 = value (default), 1 = first derivative, 2 = second derivative
42    pub derivative_order: usize,
43    /// Basis family to evaluate.
44    pub basis_family: BasisFamily,
45}
46
47impl BasisOptions {
48    /// Create options for evaluating basis functions (no derivative).
49    pub const fn value() -> Self {
50        Self {
51            derivative_order: 0,
52            basis_family: BasisFamily::BSpline,
53        }
54    }
55
56    /// Create options for evaluating first derivatives of basis functions.
57    pub const fn first_derivative() -> Self {
58        Self {
59            derivative_order: 1,
60            basis_family: BasisFamily::BSpline,
61        }
62    }
63
64    /// Create options for evaluating second derivatives of basis functions.
65    pub const fn second_derivative() -> Self {
66        Self {
67            derivative_order: 2,
68            basis_family: BasisFamily::BSpline,
69        }
70    }
71
72    /// Create options for evaluating M-spline basis values.
73    pub const fn m_spline() -> Self {
74        Self {
75            derivative_order: 0,
76            basis_family: BasisFamily::MSpline,
77        }
78    }
79
80    /// Create options for evaluating I-spline basis values.
81    pub const fn i_spline() -> Self {
82        Self {
83            derivative_order: 0,
84            basis_family: BasisFamily::ISpline,
85        }
86    }
87}
88
/// Which family of one-dimensional spline basis functions to evaluate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BasisFamily {
    /// Standard B-splines (the default family).
    #[default]
    BSpline,
    /// M-splines, i.e. normalized B-splines:
    /// `M_i = ((k+1)/(t_{i+k+1}-t_i)) B_i`.
    MSpline,
    /// I-splines, i.e. integrated M-splines. They are implemented as
    /// right-cumulative sums of B-splines at degree `k+1`.
    ISpline,
}
101
102/// Specifies the source of knots for basis generation.
103#[derive(Clone, Debug)]
104pub enum KnotSource<'a> {
105    /// Use a pre-computed knot vector.
106    Provided(ArrayView1<'a, f64>),
107    /// Generate uniformly spaced knots based on data range.
108    Generate {
109        /// Data range (min, max) for knot placement.
110        data_range: (f64, f64),
111        /// Number of internal knots to place between boundaries.
112        num_internal_knots: usize,
113    },
114}
115/// Thin-plate regression spline basis and penalty (order m=2).
116///
117/// The returned basis has columns `[K_c | P]` where:
118/// - `K_c` is the constrained radial basis block (`K * Z`) with
119///   `P(knots)^T * α = 0` enforced via nullspace projection
120/// - `P` is the TPS polynomial null-space block containing all monomials of
121///   total degree `< m`, where `m = thin_plate_penalty_order(d)` (so `P` is
122///   just `[1, x_1, ..., x_d]` for `d <= 3`)
123///
124/// The returned penalty matrix is block-diagonal with:
125/// - upper-left `Omega_c = Z^T Omega Z` for the constrained radial block
126/// - zero lower-right block for unpenalized polynomial terms.
127///
128/// For double-penalty GAMs, a second ridge penalty `I` is also returned so the
129/// caller can optimize `(lambda_bending, lambdaridge)` jointly.
130#[derive(Debug, Clone)]
131pub struct ThinPlateSplineBasis {
132    pub basis: Array2<f64>,
133    pub penalty_bending: Array2<f64>,
134    pub penalty_ridge: Array2<f64>,
135    pub num_kernel_basis: usize,
136    pub num_polynomial_basis: usize,
137    pub dimension: usize,
138    /// Wood-TPRS radial reparameterization matrix `V`.
139    ///
140    /// Rows live in the side-constrained radial coefficient space. Columns are
141    /// the retained positive bending eigendirections of `Z' Ω Z`; numerically
142    /// near-null radial directions are dropped before the basis is exposed.
143    /// Therefore `V` can be rectangular: design columns are `Φ Z V`, and the
144    /// radial penalty is `diag(Λ_retained)`.
145    pub radial_reparam: Array2<f64>,
146}
147
148/// Matérn smoothness parameter `nu` (half-integer variants with closed forms).
149#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
150pub enum MaternNu {
151    Half,
152    ThreeHalves,
153    FiveHalves,
154    SevenHalves,
155    NineHalves,
156}
157
158impl MaternNu {
159    /// The half-integer smoothness value ν as an `f64` (0.5, 1.5, …).
160    pub const fn half_integer_value(self) -> f64 {
161        match self {
162            MaternNu::Half => 0.5,
163            MaternNu::ThreeHalves => 1.5,
164            MaternNu::FiveHalves => 2.5,
165            MaternNu::SevenHalves => 3.5,
166            MaternNu::NineHalves => 4.5,
167        }
168    }
169}
170
171/// Matérn radial basis and penalties.
172#[derive(Debug, Clone)]
173pub struct MaternSplineBasis {
174    pub basis: Array2<f64>,
175    pub penalty_kernel: Array2<f64>,
176    pub penalty_ridge: Array2<f64>,
177    pub num_kernel_basis: usize,
178    pub num_polynomial_basis: usize,
179    pub dimension: usize,
180}
181
182#[derive(Debug, Clone)]
183pub(crate) struct DuchonBasisDesign {
184    pub(crate) basis: Array2<f64>,
185}
186
187/// Boundary-condition policy for one-dimensional smooth bases.
188#[derive(Debug, Clone, Serialize, Deserialize, Default)]
189pub enum OneDimensionalBoundary {
190    /// Ordinary open interval basis with clamped endpoint behavior.
191    #[default]
192    Open,
193    /// Periodic/cyclic basis over the half-open interval `[start, end)`.
194    ///
195    /// Values are evaluated modulo `period = end - start`; the basis and its
196    /// first `degree - 1` derivatives agree at the two endpoints for B-splines.
197    Cyclic { start: f64, end: f64 },
198}
199
200impl OneDimensionalBoundary {
201    pub(crate) fn period(&self) -> Option<(f64, f64, f64)> {
202        match *self {
203            OneDimensionalBoundary::Open => None,
204            OneDimensionalBoundary::Cyclic { start, end } if end > start => {
205                Some((start, end, end - start))
206            }
207            OneDimensionalBoundary::Cyclic { .. } => None,
208        }
209    }
210}
211
212/// Which knot strategy to use for 1D B-spline bases.
213#[derive(Debug, Clone, Serialize, Deserialize)]
214pub enum BSplineKnotSpec {
215    Generate {
216        data_range: (f64, f64),
217        num_internal_knots: usize,
218    },
219    /// Uniform cyclic B-spline basis on `[data_range.0, data_range.1)`.
220    ///
221    /// The first and last endpoints are identified, so evaluating at `x` and
222    /// `x + m * period` gives identical rows. `num_basis` is the number of
223    /// periodic control sites around the loop and must be at least
224    /// `degree + 1` for an unaliased local support stencil.
225    PeriodicUniform {
226        data_range: (f64, f64),
227        num_basis: usize,
228    },
229    Automatic {
230        num_internal_knots: Option<usize>,
231        placement: BSplineKnotPlacement,
232    },
233    Provided(Array1<f64>),
234    /// Natural cubic regression spline (`bs="cr"`/`"cs"`) knot set (#1074).
235    ///
236    /// Unlike the open-spline variants above, these `knots` are the `k`
237    /// Lancaster–Salkauskas knots `x*_1 < … < x*_k` that *directly* index the
238    /// basis values `β_i = f(x*_i)` — the basis dimension equals `knots.len()`
239    /// (not `knots.len() - degree - 1`). The 1-D builder routes this variant to
240    /// the cubic-regression builder; the cr identity therefore round-trips
241    /// through freeze/reload by virtue of the variant itself (no separate
242    /// metadata marker is required), and tensor margins inherit cr by carrying
243    /// this knotspec into `build_bspline_basis_1d`.
244    NaturalCubicRegression {
245        knots: Array1<f64>,
246    },
247}
248
249/// Internal-knot placement strategy when knots are automatically inferred.
250#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
251pub enum BSplineKnotPlacement {
252    Uniform,
253    Quantile,
254}
255
256/// 1D B-spline basis configuration.
257#[derive(Debug, Clone, Serialize, Deserialize)]
258pub struct BSplineBasisSpec {
259    pub degree: usize,
260    pub penalty_order: usize,
261    pub knotspec: BSplineKnotSpec,
262    pub double_penalty: bool,
263    pub identifiability: BSplineIdentifiability,
264    #[serde(default)]
265    pub boundary: OneDimensionalBoundary,
266    /// Optional endpoint boundary constraints (Hermite-style pin of value and/or
267    /// derivative at the left/right knot extents). Default = `Free` on both
268    /// sides which is a no-op.
269    #[serde(default)]
270    pub boundary_conditions: BSplineBoundaryConditions,
271}
272
273/// Per-endpoint boundary constraint policy for B-spline 1D bases.
274#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
275pub enum BSplineEndpointBoundaryCondition {
276    /// No endpoint constraint.
277    #[default]
278    Free,
279    /// Pin the first derivative to zero at this endpoint.
280    Clamped,
281    /// Pin the value at this endpoint to `value` (currently only `value == 0`
282    /// is accepted in the builder; non-zero anchors require an affine offset).
283    Anchored { value: f64 },
284}
285
286/// Left/right pair of B-spline endpoint constraints.
287#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
288pub struct BSplineBoundaryConditions {
289    #[serde(default)]
290    pub left: BSplineEndpointBoundaryCondition,
291    #[serde(default)]
292    pub right: BSplineEndpointBoundaryCondition,
293}
294
295impl BSplineBoundaryConditions {
296    pub const fn is_free(&self) -> bool {
297        matches!(self.left, BSplineEndpointBoundaryCondition::Free)
298            && matches!(self.right, BSplineEndpointBoundaryCondition::Free)
299    }
300}
301
302/// Per-smooth identifiability policy for 1D B-spline bases.
303///
304/// These constraints are applied directly in the builder via a reparameterization
305/// `B_constrained = B * Z`, and every penalty matrix is projected as
306/// `S_constrained = Z' S Z`, so solver geometry stays consistent.
307#[derive(Debug, Clone, Serialize, Deserialize)]
308pub enum BSplineIdentifiability {
309    /// Keep unconstrained basis columns.
310    None,
311    /// Enforce weighted sum-to-zero: `B' w = 0` (or unweighted when `weights=None`).
312    // Smooth terms are centered by default to avoid intercept confounding.
313    WeightedSumToZero { weights: Option<Array1<f64>> },
314    /// Remove intercept + linear trend in coefficient space using Greville geometry.
315    RemoveLinearTrend,
316    /// Enforce orthogonality to supplied design columns `C` (n x q):
317    /// `B_c' W C = 0` (or unweighted when `weights=None`).
318    ///
319    /// To enforce `[intercept, x, ...]`, provide `columns` with those columns.
320    OrthogonalToDesignColumns {
321        columns: Array2<f64>,
322        weights: Option<Array1<f64>>,
323    },
324    /// Apply an explicit coefficient-space transform `Z` learned at fit time.
325    ///
326    /// This freezes identifiability behavior so prediction cannot drift based on
327    /// new-data distribution. The constrained basis is `B * Z`.
328    FrozenTransform { transform: Array2<f64> },
329}
330
331impl Default for BSplineIdentifiability {
332    fn default() -> Self {
333        BSplineIdentifiability::WeightedSumToZero { weights: None }
334    }
335}
336
337/// Spatial center selection strategy.
338///
339/// `num_centers` is the exact number of knot/center rows selected by the
340/// strategy. Polynomial nullspace columns are added separately by each basis
341/// builder and must never be folded into this count.
342#[derive(Debug, Clone, Serialize, Deserialize)]
343pub enum CenterStrategy {
344    Auto(Box<CenterStrategy>),
345    UserProvided(Array2<f64>),
346    /// Joint multidimensional equal-mass partitioning in the full smooth space.
347    EqualMass {
348        num_centers: usize,
349    },
350    /// Covariate-representative equal-mass partitioning along one selected axis.
351    EqualMassCovarRepresentative {
352        num_centers: usize,
353    },
354    FarthestPoint {
355        num_centers: usize,
356    },
357    KMeans {
358        num_centers: usize,
359        max_iter: usize,
360    },
361    UniformGrid {
362        points_per_dim: usize,
363    },
364}
365
/// Payload-free tag naming a concrete [`CenterStrategy`] variant.
///
/// There is no `Auto` tag: an automatically chosen strategy is reported as the
/// kind of the strategy it wraps. The derived ordering follows the declaration
/// order of the variants.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum CenterStrategyKind {
    UserProvided,
    EqualMass,
    EqualMassCovarRepresentative,
    FarthestPoint,
    KMeans,
    UniformGrid,
}
375
/// Adaptive default center count for spatial smooths (TPS, Duchon, Matérn).
///
/// Use this when the user has not explicitly specified a knot/center count.
/// The basis size is the sub-linear `ceil(8 * d_factor * n^0.4)` with
/// `d_factor = 1 + 0.15 * (max(d, 1) - 1)`, clamped above at `K_MAX = 2000`
/// and below at a *data-proportional* floor `min(200, n/8)` so the floor only
/// engages once there are enough observations to support a rich basis. The
/// result is additionally capped at `n/4` so the penalty matrices stay
/// well-conditioned relative to the data:
///
/// | n      | d=1  | d=2  | d=5  |
/// |--------|------|------|------|
/// | 800    | 116  | 134  | 186  |
/// | 1 000  | 127  | 146  | 203  |
/// | 2 000  | 200  | 200  | 268  |
/// | 10 000 | 319  | 367  | 510  |
/// | 100 000| 801  | 921  | 1281 |
/// | 400 000| 1393 | 1602 | 2000 |
/// | 1 000 000| 2000 | 2000 | 2000 |
///
/// The flat `200` floor used to inflate moderate-`n` spatial smooths (a few
/// hundred to ~2000 rows) up to a dense 200-column design even though the raw
/// sub-linear count — and the mesh/knot density that mgcv and R-INLA use on the
/// same data — is far smaller. On ~800 rows that turned a single 2-D thin-plate
/// REML fit into an `O(n·p² + p³)` grind at `p ≈ 200` (#718). Smoothness is
/// already controlled by REML's penalty weight λ, not by the center count, so a
/// data-proportional floor recovers the same surface at a fraction of the cost.
///
/// # Arguments
/// * `n` - sample size (number of observations)
/// * `d` - covariate dimensionality (number of input variables in the smooth);
///   `0` is treated like `1`
///
/// # Returns
/// The center count, always `<= n / 4`. Because of that integer cap the result
/// is `0` whenever `n < 4`; callers that need at least one center must handle
/// that case themselves.
pub fn default_num_centers(n: usize, d: usize) -> usize {
    const K_MIN: usize = 200;
    const K_MAX: usize = 2000;
    const ALPHA: f64 = 0.4;
    const C: f64 = 8.0;
    /// Per-extra-dimension growth in the center count: each covariate axis
    /// beyond the first widens the basis by 15% to keep the per-axis mesh
    /// density roughly constant as the smooth's domain dimensionality grows.
    const PER_DIM_GROWTH: f64 = 0.15;
    /// Divisor for the data-proportional floor: the `K_MIN` floor only engages
    /// once `n` exceeds `K_MIN * FLOOR_N_DIVISOR`, so small samples are not
    /// forced up to a dense `K_MIN`-column design.
    const FLOOR_N_DIVISOR: usize = 8;
    /// Divisor for the conditioning cap: the center count never exceeds `n /
    /// COND_N_DIVISOR`, keeping the penalty matrices well-conditioned relative
    /// to the data.
    const COND_N_DIVISOR: usize = 4;

    let d_factor = 1.0 + PER_DIM_GROWTH * (d.max(1) - 1) as f64;
    let raw = (C * d_factor * (n as f64).powf(ALPHA)).ceil() as usize;

    // Data-proportional floor: never inflate beyond n/FLOOR_N_DIVISOR, so the
    // K_MIN-center floor only takes effect once n is large enough (~1600) to
    // genuinely support that many basis columns. `floor <= K_MIN < K_MAX`, so
    // the clamp bounds are always ordered.
    let floor = K_MIN.min(n / FLOOR_N_DIVISOR);
    let k = raw.clamp(floor, K_MAX);

    // Cap at n/COND_N_DIVISOR to keep the penalty matrices well-conditioned
    // relative to the data. This also guarantees the count never exceeds `n`
    // itself, so no separate `min(n)` is needed.
    k.min(n / COND_N_DIVISOR)
}
437
438/// Conservative center count for a *secondary* (distributional) predictor's
439/// spatial smooth — e.g. the log-σ scale model in a Gaussian location-scale
440/// fit.
441///
442/// The mean is identified directly by the response, so it warrants the
443/// generous [`default_num_centers`] basis. A scale/shape predictor is
444/// identified only through (noisy) squared residuals: handing it a basis sized
445/// for the mean lets REML/LAML smoothing selection over-fit it, because where
446/// the fitted scale is driven small the *observed* information collapses and
447/// the determinant penalty stops holding the wiggle down (#501). This mirrors
448/// standard GAMLSS/mgcv practice of giving distribution parameters a modest
449/// default (mgcv's `k = 10` for a 1-D `s()`), grown gently with dimensionality
450/// and never exceeding the generous primary-predictor default.
451pub fn conservative_secondary_centers(n: usize, d: usize) -> usize {
452    const BASE_1D_CENTERS: usize = 10;
453    let modest = BASE_1D_CENTERS.saturating_mul(d.max(1));
454    default_num_centers(n, d).min(modest).max(1)
455}
456
457/// Resource-aware plan for a spatial smooth (Duchon / Matérn / TPS).
458///
459/// Returned by [`plan_spatial_basis`]. Captures the resolved center count,
460/// final basis dimension `p`, the dense byte cost for the value matrix and
461/// each derivative tier, and a recommended storage mode that is consistent
462/// with the supplied [`gam_runtime::resource::ResourcePolicy`].
463#[derive(Clone, Debug)]
464pub struct SpatialBasisPlan {
465    pub n: usize,
466    pub d: usize,
467    pub centers: usize,
468    pub p_final_estimate: usize,
469    pub dense_design_bytes: usize,
470    pub first_derivative_dense_bytes: usize,
471    pub second_derivative_dense_bytes: usize,
472    pub recommended_storage: SpatialStorageMode,
473}
474
/// Storage recommendation produced by [`plan_spatial_basis`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SpatialStorageMode {
    /// The value design and its derivative matrices all fit under the
    /// policy's single-materialization budget.
    DenseValueDenseDerivatives,
    /// The value design fits dense but the derivative matrices do not, so the
    /// derivatives switch to the implicit operator.
    LazyValueImplicitDerivatives,
    /// Neither the design nor its derivatives fit; everything must be
    /// operator-backed.
    OperatorOnly,
}
489
/// Tells [`plan_spatial_basis`] how to choose the number of spatial centers.
#[derive(Debug, Clone, Copy)]
pub enum CenterCountRequest {
    /// Take the count from the [`default_num_centers`] heuristic.
    Default,
    /// Take exactly the count supplied by the caller.
    Explicit(usize),
    /// Take the [`default_num_centers`] heuristic, limited to at most `cap`
    /// so the dense cost stays bounded.
    HeuristicCapped { cap: usize },
}
500
501/// Build a resource-aware plan for a spatial smooth basis.
502///
503/// Computes the resolved center count, final basis dimension, dense byte
504/// estimates for the value design and first/second derivative tiers, and a
505/// recommended [`SpatialStorageMode`] derived from `policy`. This is the
506/// resource-aware replacement for ad-hoc calls to [`default_num_centers`] /
507/// [`heuristic_centers`](crate::term_builder::heuristic_centers).
508pub fn plan_spatial_basis(
509    n: usize,
510    d: usize,
511    requested_centers: CenterCountRequest,
512    nullspace_order: DuchonNullspaceOrder,
513    scale_dims: bool,
514    policy: &gam_runtime::resource::ResourcePolicy,
515) -> Result<SpatialBasisPlan, BasisError> {
516    if n == 0 {
517        crate::bail_invalid_basis!("plan_spatial_basis: n must be >= 1");
518    }
519    if d == 0 {
520        crate::bail_invalid_basis!("plan_spatial_basis: d must be >= 1");
521    }
522
523    // 1. Resolve center count.
524    let centers = match requested_centers {
525        CenterCountRequest::Default => default_num_centers(n, d),
526        CenterCountRequest::Explicit(k) => k,
527        CenterCountRequest::HeuristicCapped { cap } => default_num_centers(n, d).min(cap),
528    };
529
530    // 2. Nullspace dimension (Duchon polynomial null space of degree p-1).
531    //    `duchon_p_from_nullspace_order` returns m such that the null space is
532    //    polynomials of total degree < m, matching `duchon_nullspace_dimension`'s
533    //    `max_total_degree = m - 1` argument.
534    let m = duchon_p_from_nullspace_order(nullspace_order);
535    let nullspace_dim = if m == 0 {
536        0
537    } else {
538        duchon_nullspace_dimension(d, m - 1)
539    };
540
541    let p = centers.saturating_add(nullspace_dim);
542
543    // 3. Dense byte estimates.
544    let derivative_axes = if scale_dims { d } else { 0 };
545    let bytes_per_f64 = std::mem::size_of::<f64>();
546    let dense_design_bytes = bytes_per_f64.saturating_mul(n).saturating_mul(p);
547    let first_derivative_dense_bytes = dense_design_bytes.saturating_mul(derivative_axes);
548    // Diagonal second derivatives are also (D × n × p); off-diagonal cross terms
549    // would scale as D^2 but the planner reports the diagonal tier here.
550    let second_derivative_dense_bytes = first_derivative_dense_bytes;
551
552    // 4. Pick storage mode based on policy.
553    let recommended_storage = match policy.derivative_storage_mode {
554        gam_runtime::resource::DerivativeStorageMode::AnalyticOperatorRequired => {
555            SpatialStorageMode::OperatorOnly
556        }
557        gam_runtime::resource::DerivativeStorageMode::MaterializeIfSmall => {
558            let budget = policy.max_single_materialization_bytes;
559            if derivative_axes == 0 {
560                if dense_design_bytes <= budget {
561                    SpatialStorageMode::DenseValueDenseDerivatives
562                } else {
563                    SpatialStorageMode::LazyValueImplicitDerivatives
564                }
565            } else {
566                let total = dense_design_bytes
567                    .saturating_add(first_derivative_dense_bytes)
568                    .saturating_add(second_derivative_dense_bytes);
569                if total <= budget {
570                    SpatialStorageMode::DenseValueDenseDerivatives
571                } else if dense_design_bytes <= budget {
572                    SpatialStorageMode::LazyValueImplicitDerivatives
573                } else {
574                    SpatialStorageMode::OperatorOnly
575                }
576            }
577        }
578        gam_runtime::resource::DerivativeStorageMode::DiagnosticsOnly => {
579            // Diagnostic mode still prefers analytic storage for correctness.
580            SpatialStorageMode::OperatorOnly
581        }
582    };
583
584    Ok(SpatialBasisPlan {
585        n,
586        d,
587        centers,
588        p_final_estimate: p,
589        dense_design_bytes,
590        first_derivative_dense_bytes,
591        second_derivative_dense_bytes,
592        recommended_storage,
593    })
594}
595
596pub const fn default_spatial_center_strategy(num_centers: usize, d: usize) -> CenterStrategy {
597    if d <= 3 {
598        CenterStrategy::FarthestPoint { num_centers }
599    } else {
600        CenterStrategy::EqualMassCovarRepresentative { num_centers }
601    }
602}
603
604pub fn auto_spatial_center_strategy(num_centers: usize, d: usize) -> CenterStrategy {
605    let strategy = if d == 1 {
606        // In one dimension, farthest-point selection is the deterministic
607        // maximin grid over the observed domain. Equal-mass midpoints leave the
608        // low-frequency Duchon radial block slightly under-resolved at the
609        // boundaries, and REML then compensates with an over-smooth λ on
610        // low-noise signals (#504). The maximin grid matches the native
611        // reproducing-kernel interpolation geometry. The default strategy below
612        // extends the same space-filling contract to low-dimensional spatial
613        // GP bases, where kriging accuracy is governed by fill distance rather
614        // than marginal quantile balance.
615        CenterStrategy::FarthestPoint { num_centers }
616    } else {
617        default_spatial_center_strategy(num_centers, d)
618    };
619    CenterStrategy::Auto(Box::new(strategy))
620}
621
622pub const fn center_strategy_is_auto(strategy: &CenterStrategy) -> bool {
623    matches!(strategy, CenterStrategy::Auto(_))
624}
625
626pub(crate) fn realized_center_strategy(strategy: &CenterStrategy) -> &CenterStrategy {
627    match strategy {
628        CenterStrategy::Auto(inner) => inner.as_ref(),
629        other => other,
630    }
631}
632
633pub fn center_strategy_kind(strategy: &CenterStrategy) -> CenterStrategyKind {
634    match strategy {
635        CenterStrategy::Auto(inner) => center_strategy_kind(inner.as_ref()),
636        CenterStrategy::UserProvided(_) => CenterStrategyKind::UserProvided,
637        CenterStrategy::EqualMass { .. } => CenterStrategyKind::EqualMass,
638        CenterStrategy::EqualMassCovarRepresentative { .. } => {
639            CenterStrategyKind::EqualMassCovarRepresentative
640        }
641        CenterStrategy::FarthestPoint { .. } => CenterStrategyKind::FarthestPoint,
642        CenterStrategy::KMeans { .. } => CenterStrategyKind::KMeans,
643        CenterStrategy::UniformGrid { .. } => CenterStrategyKind::UniformGrid,
644    }
645}
646
647pub fn center_strategy_num_centers(strategy: &CenterStrategy) -> Option<usize> {
648    match strategy {
649        CenterStrategy::Auto(inner) => center_strategy_num_centers(inner.as_ref()),
650        CenterStrategy::UserProvided(centers) => Some(centers.nrows()),
651        CenterStrategy::EqualMass { num_centers }
652        | CenterStrategy::EqualMassCovarRepresentative { num_centers }
653        | CenterStrategy::FarthestPoint { num_centers }
654        | CenterStrategy::KMeans { num_centers, .. } => Some(*num_centers),
655        CenterStrategy::UniformGrid { .. } => None,
656    }
657}
658
659pub fn center_strategy_with_num_centers(
660    strategy: &CenterStrategy,
661    num_centers: usize,
662) -> Result<CenterStrategy, BasisError> {
663    validate_center_count(num_centers)?;
664    fn rebuild_inner(
665        strategy: &CenterStrategy,
666        num_centers: usize,
667    ) -> Result<CenterStrategy, BasisError> {
668        match strategy {
669            CenterStrategy::Auto(inner) => rebuild_inner(inner.as_ref(), num_centers),
670            CenterStrategy::EqualMass { .. } => Ok(CenterStrategy::EqualMass { num_centers }),
671            CenterStrategy::EqualMassCovarRepresentative { .. } => {
672                Ok(CenterStrategy::EqualMassCovarRepresentative { num_centers })
673            }
674            CenterStrategy::FarthestPoint { .. } => {
675                Ok(CenterStrategy::FarthestPoint { num_centers })
676            }
677            CenterStrategy::KMeans { max_iter, .. } => Ok(CenterStrategy::KMeans {
678                num_centers,
679                max_iter: *max_iter,
680            }),
681            CenterStrategy::UserProvided(_) | CenterStrategy::UniformGrid { .. } => {
682                Err(BasisError::InvalidInput(format!(
683                    "cannot replace center count for {:?} strategy",
684                    center_strategy_kind(strategy)
685                )))
686            }
687        }
688    }
689    let rebuilt = rebuild_inner(strategy, num_centers)?;
690    Ok(match strategy {
691        CenterStrategy::Auto(_) => CenterStrategy::Auto(Box::new(rebuilt)),
692        _ => rebuilt,
693    })
694}
695
696/// Thin-plate basis configuration.
697#[derive(Debug, Clone, Serialize, Deserialize)]
698pub struct ThinPlateBasisSpec {
699    pub center_strategy: CenterStrategy,
700    #[serde(default)]
701    pub periodic: Option<Vec<Option<f64>>>,
702    pub length_scale: f64,
703    pub double_penalty: bool,
704    #[serde(default)]
705    pub identifiability: SpatialIdentifiability,
706    /// Frozen Wood-TPRS radial reparameterization. When `Some`, the builder
707    /// reuses this `(raw_radial_cols) × (kept_radial_cols)` matrix instead of
708    /// recomputing it from the constrained kernel penalty eigensystem. The
709    /// rectangular case is the truncated regression-spline path; carrying it
710    /// into prediction guarantees identical radial modes to fit-time.
711    #[serde(default)]
712    pub radial_reparam: Option<Array2<f64>>,
713}
714
715/// Per-smooth identifiability policy for spatial (TPS / Duchon) bases.
716///
717/// For a raw local basis `B` and parametric design block `C`, the orthogonalized
718/// basis is `B_c = B Z` where columns of `Z` span `null((B^T C)^T)`. This enforces:
719///   `B_c^T C = 0`
720/// in the unweighted inner product, so spatial effects cannot absorb parametric
721/// directions that actually exist in the model. The standalone basis builder has
722/// only an implicit intercept available, so it centers smooths against that
723/// intercept. The term-collection builder augments `C` with explicit linear
724/// terms when those terms are present in the formula.
725#[derive(Debug, Default, Clone, Serialize, Deserialize)]
726pub enum SpatialIdentifiability {
727    /// Keep unconstrained basis columns.
728    None,
729    /// Orthogonalize the smooth against model-owned parametric columns.
730    // "Magic" default for modular GAMs with explicit parametric block:
731    // keep spatial smooth orthogonal to intercept/linear terms.
732    // ApproxKind: Exact (orthogonalization is an exact projection).
733    #[default]
734    OrthogonalToParametric,
735    /// Freeze a fit-time transform `Z`; prediction uses `B_new * Z` unchanged.
736    FrozenTransform { transform: Array2<f64> },
737}
738
739pub(crate) use sphere_kernels::{
740    wahba_sphere_kernel_derivative_dcos_kind, wahba_sphere_kernel_from_cos_kind,
741    wahba_sphere_kernel_from_cos_simd_kind, wahba_sphere_kernel_sobolev_derivative_dcos,
742};
743
744pub use sphere_spectral::{
745    pseudo_s2_truncated_coefficients, sobolev_s2_truncated_coefficients,
746    sphere_truncated_spectral_eval,
747};
748
/// Matérn basis configuration.
///
/// Fields marked `#[serde(default)]` may be absent from serialized input and
/// then take their type's `Default` value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MaternBasisSpec {
    /// How the kernel centers are selected (see `CenterStrategy`).
    pub center_strategy: CenterStrategy,
    /// Optional periodicity description.
    // NOTE(review): presumably one entry per covariate axis, with `Some(period)`
    // marking a periodic axis — confirm against the basis builder.
    #[serde(default)]
    pub periodic: Option<Vec<Option<f64>>>,
    /// Kernel length scale.
    // NOTE(review): the anisotropy notes below use a global κ; its exact relation
    // to `length_scale` is defined by the builder, not in this file section.
    pub length_scale: f64,
    /// Matérn smoothness parameter ν.
    pub nu: MaternNu,
    /// Whether the basis carries its own intercept column (`false` when omitted
    /// from serialized input).
    #[serde(default)]
    pub include_intercept: bool,
    /// Enables the double-penalty path; `nullspace_shrinkage_survived` is only
    /// consulted when this is `true`.
    pub double_penalty: bool,
    /// Identifiability constraint applied to the kernel coefficients (see
    /// `MaternIdentifiability`).
    #[serde(default)]
    pub identifiability: MaternIdentifiability,
    /// Per-axis anisotropy log-scales η_a (contrasts with Ση_a = 0).
    ///
    /// This implements geometric anisotropy: Λ = κA where A = diag(exp(η_a)),
    /// det(A) = 1. The kernel is evaluated at r = κ|Ah| instead of r = κ|h|.
    /// The decomposition preserves the isotropic scaling law for global κ
    /// and adds d−1 shape parameters for directional relevance.
    ///
    /// Conditional positive definiteness is preserved under any invertible
    /// linear coordinate transform (Schoenberg), so the kernel remains valid.
    ///
    /// When Some, the distance is r = √(Σ_a exp(2η_a) · (x_a - c_a)²).
    /// When None, isotropic distance r = ‖x - c‖ is used.
    #[serde(default)]
    pub aniso_log_scales: Option<Vec<f64>>,
    /// Frozen double-penalty nullspace-shrinkage decision (gam#787/#860).
    ///
    /// `None` (the default, and the cold-build value) = decide whether to emit
    /// the `DoublePenaltyNullspace` candidate via the κ-dependent spectral test in
    /// `build_nullspace_shrinkage_penalty`. `Some(b)` = force the decision (set by
    /// the freeze step from the bootstrap-κ build, mirrored from
    /// `MaternIdentifiability::FrozenTransform`) so the learned-penalty count stays
    /// invariant as the κ-optimizer rebuilds the design at each trial length-scale.
    /// Only consulted when `double_penalty` is true.
    #[serde(default)]
    pub nullspace_shrinkage_survived: Option<bool>,
}
788
/// Per-smooth identifiability policy for Matérn kernel coefficients.
///
/// These constraints are geometric (center-based), so they are stable across
/// train/predict and do not depend on response weights.
///
/// `CenterSumToZero` is the `Default` variant.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub enum MaternIdentifiability {
    /// Keep the unconstrained kernel coefficient space.
    None,
    /// Enforce `1^T alpha = 0` at center locations (removes constant drift).
    // Safe default with model intercepts: prevent the kernel block from
    // absorbing a global mean level.
    #[default]
    CenterSumToZero,
    /// Enforce orthogonality to `[1, c_1, ..., c_d]` at centers.
    /// Use this when explicit linear terms should own global trends.
    CenterLinearOrthogonal,
    /// Freeze a fit-time transform `Z` so prediction cannot drift.
    ///
    /// `nullspace_shrinkage_survived` freezes the double-penalty
    /// nullspace-shrinkage decision alongside the transform (gam#787/#860).
    ///
    /// Why it must be frozen: the Matérn double-penalty path emits a
    /// `DoublePenaltyNullspace` candidate iff
    /// `build_nullspace_shrinkage_penalty(&projected_kernel)` finds a near-zero
    /// eigenvalue — but that spectral test is κ-DEPENDENT (its tolerance scales
    /// with `λ_max`), so a near-zero eigenvalue can cross the threshold as the
    /// κ-optimizer rebuilds the design at each trial length-scale. That flips the
    /// learned-penalty count 6↔7 across the rebuild, and the rebuilt design's ρ
    /// dimension then disagrees with the frozen joint setup ("joint hyper rho
    /// dimension mismatch" → every κ seed fails startup validation).
    ///
    /// Values:
    /// - `Some(true)`: always emit the shrinkage candidate.
    /// - `Some(false)`: never emit it.
    /// - `None`: decide via the spectral test (the non-frozen / cold-build
    ///   behavior; also the serde back-compat default for transforms frozen
    ///   before this field existed).
    ///
    /// Freezing the bootstrap-κ decision keeps the penalty count INVARIANT across
    /// the κ rebuild so κ actually optimizes.
    FrozenTransform {
        transform: Array2<f64>,
        #[serde(default)]
        nullspace_shrinkage_survived: Option<bool>,
    },
}
828
/// Duchon null-space polynomial degree.
///
/// Controls the polynomial null space of the Duchon / polyharmonic spline. The
/// Duchon seminorm `‖D^m f‖²` annihilates all polynomials of total degree
/// `< m`, so those polynomials must be handled as explicit unpenalized columns.
///
/// The user-facing `order` knob selects the polynomial degree cutoff `r`, and
/// the resulting polynomial null space has dimension `C(d + r, r)` where `d`
/// is the covariate dimension.  In the `duchon(...)` formula DSL:
///
/// | `order=` | Variant         | max total degree | null-space dim  |
/// |----------|-----------------|------------------|-----------------|
/// | `0`      | `Zero`          | 0                | `C(d+0,0) = 1`  |
/// | `1`      | `Linear`        | 1                | `C(d+1,1) = d+1`|
/// | `k≥2`    | `Degree(k)`     | k                | `C(d+k,k)`      |
///
/// **How the polynomial null space is consumed during basis construction:**
///
/// 1. `polynomial_block_from_order` materialises an `(n, C(d+r,r))` block `P`
///    of monomials up to total degree `r` at the selected `centers`.
/// 2. `kernel_constraint_nullspace` computes `Z = null(P_centers^T)`, a
///    `(k, k − C(d+r,r))` matrix. Reparameterising the radial kernel
///    coefficients as `α = Z γ` enforces the side condition `P_centers^T α = 0`
///    and yields `k − C(d+r,r)` free kernel parameters.
/// 3. The polynomial block `P_data` evaluated at the data rows is appended to
///    the kernel block `Φ Z`, giving a total of
///    `(k − C(d+r,r)) + C(d+r,r) = k` columns before the spatial
///    identifiability transform.  Crucially, the total width equals the
///    requested center count `k`, **not** `k + C(d+r,r)`.
///
/// **Example — `duchon(PC1, PC2, PC3, centers=10, order=1)` (d=3):**
///
/// - Polynomial null space: `C(3+1,1) = 4` monomials `{1, x₁, x₂, x₃}`.
/// - Kernel columns after constraint: `10 − 4 = 6`.
/// - Appended polynomial block: 4 columns.
/// - Pre-identifiability total: `6 + 4 = 10` columns, i.e. exactly `centers`.
///
/// The variant naming matches the Duchon `m` parameter:
/// `Zero` → `m=1`, `Linear` → `m=2`, `Degree(k)` → `m=k+1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DuchonNullspaceOrder {
    /// Constants only (max total degree 0); `order=0`, Duchon `m = 1`.
    Zero,
    /// Constants plus linear monomials (max total degree 1); `order=1`,
    /// Duchon `m = 2`.
    Linear,
    /// Monomials up to total degree `k`; `order=k`, Duchon `m = k + 1`.
    ///
    /// The table above pairs this variant with `k ≥ 2`, but the type does not
    /// forbid `Degree(0)` or `Degree(1)`; derived equality treats those as
    /// distinct from `Zero` / `Linear`.
    Degree(usize),
}
874
/// Duchon-like basis configuration with explicit low-frequency null-space
/// control and explicit spectral power.
///
/// Deserialization rejects unknown keys (`deny_unknown_fields`); fields marked
/// `#[serde(default)]` may be omitted and then take their `Default` value.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct DuchonBasisSpec {
    /// How the kernel centers are selected (see `CenterStrategy`).
    pub center_strategy: CenterStrategy,
    /// Optional periodicity description.
    // NOTE(review): presumably one entry per covariate axis, with `Some(period)`
    // marking a periodic axis — confirm against the basis builder.
    #[serde(default)]
    pub periodic: Option<Vec<Option<f64>>>,
    /// Optional hybrid Matérn width. `None` means pure scale-free Duchon with
    /// spectrum `||w||^(2p + 2s)`. `Some(length_scale)` enables the hybrid
    /// spectrum `||w||^(2p) * (kappa^2 + ||w||^2)^s`, `kappa = 1/length_scale`.
    pub length_scale: Option<f64>,
    /// Literal Duchon spectral power `s` (`f64`, fractional values fully
    /// threaded end-to-end). The pure-Duchon kernel exponent is `2(p + s) − d`,
    /// so this is the knob that sets `φ(r)`: `s = 0` is the integer-order Duchon
    /// kernel `r^{2p−d}` (its `r²·log r` log case in even `d`, ≡ the thin-plate
    /// kernel); `s = (d − 1)/2` gives the cubic `r³` in every dimension.
    ///
    /// This field is taken LITERALLY by the basis builder — `power = 0` means
    /// `s = 0`, NOT "use a default". The magic cubic default (applied when the
    /// user gives no explicit power) is a request-layer choice resolved by the
    /// formula / CLI / pyffi front-ends via [`duchon_cubic_default`]; by the time
    /// a spec reaches the builder this value is the final intended `s`. The
    /// hybrid Duchon–Matérn path (`length_scale = Some`) still requires an
    /// integer `s` (read via `spec.power_as_usize()`).
    pub power: f64,
    /// Polynomial null-space degree handled as explicit unpenalized columns
    /// (see `DuchonNullspaceOrder`).
    pub nullspace_order: DuchonNullspaceOrder,
    /// Identifiability policy for this smooth (see `SpatialIdentifiability`).
    #[serde(default)]
    pub identifiability: SpatialIdentifiability,
    /// Per-axis anisotropy log-scales η_a.
    ///
    /// For hybrid Duchon (`length_scale=Some`), these are centered contrasts in
    /// the decomposition Λ = κA with det(A)=1. For pure Duchon
    /// (`length_scale=None`), they parameterize shape-only axis warping on the
    /// public path and are centered before basis evaluation/writeback so no
    /// global length scale is introduced.
    ///
    /// When Some, the distance is r = √(Σ_a exp(2η_a) · (x_a - c_a)²).
    /// When None, isotropic distance r = ‖x - c‖ is used.
    #[serde(default)]
    pub aniso_log_scales: Option<Vec<f64>>,
    /// Mass / tension / stiffness operator-penalty dials (see
    /// `DuchonOperatorPenaltySpec`).
    #[serde(default)]
    pub operator_penalties: DuchonOperatorPenaltySpec,
    /// Boundary handling for the one-dimensional case.
    // NOTE(review): semantics live on `OneDimensionalBoundary`, which is not
    // visible in this section — confirm before relying on a specific default.
    #[serde(default)]
    pub boundary: OneDimensionalBoundary,
    /// Data-metric radial reparameterization `V` (#1355), mirroring the
    /// thin-plate Wood-TPRS reparam. When `Some`, the constrained kernel
    /// transform is folded to `Z·V` so the realized design columns rotate into
    /// the `G_c`-orthonormal generalized eigenbasis of `Ω_c v = μ G_c v` and the
    /// native penalty becomes the diagonal curvature-per-unit-data-variance
    /// spectrum (mgcv's cliff), preventing the REML over-smoothing collapse to
    /// EDF = 1. Frozen at the cold dense build and replayed verbatim by the
    /// predict / κ-trial / ψ-derivative paths so they stay bit-consistent with
    /// the fit-time design. `None` on the lazy/streaming path (huge `n`), which
    /// retains the original constrained basis.
    #[serde(default)]
    pub radial_reparam: Option<Array2<f64>>,
}
933
impl DuchonBasisSpec {
    /// Integer view of `power` for the existing integer-only downstream chain.
    ///
    /// Delegates to [`duchon_power_to_usize`]: non-finite, negative, or
    /// non-integer values fall back to `0` (the integer-only validators
    /// downstream already reject this case with a clear message).
    pub fn power_as_usize(&self) -> usize {
        duchon_power_to_usize(self.power)
    }
}
942
/// Integer view of a Duchon spectral power `s`, as consumed by the closed-form
/// (integer-only) code paths.
///
/// Returns the nearest integer when `power` is finite, non-negative, and within
/// `1e-9` (absolute) of that integer. Every other input — NaN, ±∞, negative, or
/// genuinely fractional — maps to `0`, leaving it to the downstream validator to
/// emit the canonical error.
pub fn duchon_power_to_usize(power: f64) -> usize {
    let nearest = power.round();
    let non_negative_finite = power.is_finite() && power >= 0.0;
    let integral = (nearest - power).abs() <= 1e-9;
    if non_negative_finite && integral {
        // Float → integer `as` casts saturate, so huge finite values cap at
        // `usize::MAX` instead of wrapping.
        nearest as usize
    } else {
        0
    }
}
956
/// The three lower-order operator-penalty dials of a Duchon-style smooth.
///
/// Each dial is independently `Active` or `Disabled` (see
/// `OperatorPenaltySpec`). Per the notes on this type's constructors, the dials
/// correspond to derivative orders 0, 1 and 2.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DuchonOperatorPenaltySpec {
    /// Mass dial (`D0`, value / amplitude, derivative order 0).
    pub mass: OperatorPenaltySpec,
    /// Tension dial (`D1`, gradient / first-order roughness, derivative order 1).
    pub tension: OperatorPenaltySpec,
    /// Stiffness dial (`D2`, Hessian / curvature, derivative order 2).
    pub stiffness: OperatorPenaltySpec,
}
963
/// On/off state of a single operator penalty, plus its starting configuration
/// when on.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum OperatorPenaltySpec {
    /// The penalty participates in the fit.
    Active {
        /// Initial value of the penalty's log smoothing parameter.
        // NOTE(review): presumably the optimizer's starting point for log λ;
        // every constructor in this file uses `0.0` — confirm with the consumer.
        initial_log_lambda: f64,
        /// Optional prior for this penalty (`None` = no prior supplied).
        // NOTE(review): `RhoPrior` is defined elsewhere; semantics not shown here.
        prior: Option<RhoPrior>,
    },
    /// The penalty is switched off.
    Disabled,
}
972
973impl Default for DuchonOperatorPenaltySpec {
974    fn default() -> Self {
975        // ALL ON. The Duchon penalty is a Hilbert scale: curvature is the
976        // always-on exact RKHS `Primary` Gram and the trend ridge is always on;
977        // the lower orders — mass (amplitude `Σ(f−f̄)²`) and tension (first-order
978        // roughness `Σ‖∇f‖²`) — are active here, collocated on a density-blind
979        // data-support sample. REML deselects any the data don't support (SPEC:
980        // recover the null by default, opt INTO overfitting). Stiffness (`D2`)
981        // stays off — `Primary` is the exact, superior curvature. (The Matérn
982        // collocation overlay builds its own `all_active()`; SAE atoms, which
983        // ship only `Primary`, use `all_disabled()`.)
984        Self {
985            mass: OperatorPenaltySpec::Active {
986                initial_log_lambda: 0.0,
987                prior: None,
988            },
989            tension: OperatorPenaltySpec::Active {
990                initial_log_lambda: 0.0,
991                prior: None,
992            },
993            stiffness: OperatorPenaltySpec::Disabled,
994        }
995    }
996}
997
998impl DuchonOperatorPenaltySpec {
999    pub fn all_disabled() -> Self {
1000        Self {
1001            mass: OperatorPenaltySpec::Disabled,
1002            tension: OperatorPenaltySpec::Disabled,
1003            stiffness: OperatorPenaltySpec::Disabled,
1004        }
1005    }
1006
1007    /// All three operator dials active — used by the Matérn collocation overlay.
1008    pub fn all_active() -> Self {
1009        let active = || OperatorPenaltySpec::Active {
1010            initial_log_lambda: 0.0,
1011            prior: None,
1012        };
1013        Self {
1014            mass: active(),
1015            tension: active(),
1016            stiffness: active(),
1017        }
1018    }
1019
1020    /// Operator-penalty dials appropriate for a Matérn-ν kernel in dimension `d`.
1021    ///
1022    /// The Matérn-ν RKHS is the Sobolev space `H^m` with `m = ν + d/2`: its
1023    /// squared norm controls the order-`j` derivative in L2 exactly when
1024    /// `j ≤ m`. The collocation overlay penalizes the squared L2 norms of the
1025    /// value (mass, `D0`, j=0), gradient (tension, `D1`, j=1) and Hessian
1026    /// (stiffness, `D2`, j=2). Activating a penalty whose derivative order
1027    /// exceeds the RKHS smoothness (`j > m`) imposes a roughness constraint the
1028    /// true kernel does NOT — it over-smooths the reduced-rank fit relative to
1029    /// the exact GP (mgcv `bs="gp"`, GpGp).
1030    ///
1031    /// Concretely the roughest Matérn, ν=1/2 in d=1 (`m = 1`), is the
1032    /// Ornstein–Uhlenbeck/exponential kernel: an H¹ process whose sample paths
1033    /// are continuous but non-differentiable. Although `∫(f')²` is finite on
1034    /// its RKHS, the kernel itself already encodes the H¹ control; layering an
1035    /// extra tension dial on top biases the reduced-rank fit toward the smooth
1036    /// `C¹` functions the kernel does not favour (and stiffness `D2` toward
1037    /// `C²`), collapsing held-out oscillation (#707). We therefore gate each
1038    /// operator on `j < m` STRICTLY: mass (j=0) is always on, tension (j=1) is
1039    /// on for `m > 1`, stiffness (j=2) is on for `m > 2`. For ν ≥ 3/2 (or any
1040    /// d ≥ 2) every dial is active, recovering `all_active`; only the
1041    /// genuinely rough ν=1/2 (d=1) kernel — where the Sobolev order sits
1042    /// exactly on a derivative boundary — drops the higher operators.
1043    pub fn matern_for_smoothness(nu: MaternNu, d: usize) -> Self {
1044        let m = nu.half_integer_value() + 0.5 * d as f64;
1045        // Tolerance so an exact half-integer Sobolev order (e.g. m = 1.0 for
1046        // ν=1/2, d=1) reliably DISABLES the matching-order operator instead
1047        // of flipping on a float-equality knife-edge.
1048        const ORDER_EPS: f64 = 1e-9;
1049        let active = || OperatorPenaltySpec::Active {
1050            initial_log_lambda: 0.0,
1051            prior: None,
1052        };
1053        let gate = |order: f64| {
1054            if m > order + ORDER_EPS {
1055                active()
1056            } else {
1057                OperatorPenaltySpec::Disabled
1058            }
1059        };
1060        Self {
1061            mass: active(),
1062            tension: gate(1.0),
1063            stiffness: gate(2.0),
1064        }
1065    }
1066}
1067
1068pub fn minimum_duchon_power_for_operator_penalties(
1069    dim: usize,
1070    nullspace_order: DuchonNullspaceOrder,
1071    max_operator_derivative_order: usize,
1072) -> usize {
1073    let p = duchon_p_from_nullspace_order(nullspace_order);
1074    let mut s = 0usize;
1075    while 2 * (p + s) <= dim + max_operator_derivative_order {
1076        s += 1;
1077    }
1078    s
1079}
1080
1081/// Resolve a fully admissible Duchon `(nullspace_order, power)` pair.
1082///
1083/// Three constraints fold into one resolution:
1084///   (a) operator collocation up to `max_op`:        `2(p + s) > d + max_op`
1085///   (b) pure-mode CPD vs polynomial nullspace P_p:  `2s < d`
1086///       (Wendland Thm 8.17: pure polyharmonic kernel of order m = p+s in
1087///        R^d is CPD of order `m − ⌊d/2⌋ + 1[d even, log] / m − (d−1)/2
1088///        [d odd]`, and Duchon interpolation against P_p is well-posed iff
1089///        CPD order ≤ p, which collapses to `2s < d` since 2s, d are
1090///        integers and 2s is even.)
1091///   (a) implies the kernel-existence condition `2(p + s) > d`.
1092///   (b) is dropped when `length_scale` is `Some` (hybrid Matérn-blended
1093///       kernel is strictly PD, CPD order 0).
1094///
1095/// Strategy: at the requested `nullspace_order`, take the smallest `s`
1096/// satisfying (a). If that `s` violates (b) in pure mode, escalate the
1097/// nullspace order by one and retry. Termination: at `p ≥ ⌈(d+max_op)/2⌉ + 1`
1098/// the operator constraint (a) admits `s = 0`, and `0 < d` satisfies (b)
1099/// for any `d ≥ 1`, so escalation always converges.
1100///
1101/// The returned nullspace order is monotone in the request: it never
1102/// decreases the user's requested order — only strengthens it when pure-mode
1103/// CPD requires a richer polynomial absorption space.
1104pub fn resolve_duchon_orders(
1105    dim: usize,
1106    requested_nullspace_order: DuchonNullspaceOrder,
1107    max_operator_derivative_order: usize,
1108    length_scale: Option<f64>,
1109) -> (DuchonNullspaceOrder, usize) {
1110    assert!(dim >= 1, "Duchon basis requires dim >= 1");
1111    let pure = length_scale.is_none();
1112    let mut nullspace = requested_nullspace_order;
1113    // Bounded loop: escalation terminates by the argument above.
1114    for _ in 0..=(dim + max_operator_derivative_order + 1) {
1115        let p = duchon_p_from_nullspace_order(nullspace);
1116        // Smallest s with 2(p + s) > d + max_op:
1117        //   2p > d + max_op            ⇒ s = 0
1118        //   else s = ⌈(d + max_op + 1 − 2p) / 2⌉ = (d + max_op + 2 − 2p) / 2
1119        let s_op = if 2 * p > dim + max_operator_derivative_order {
1120            0
1121        } else {
1122            (dim + max_operator_derivative_order + 2 - 2 * p) / 2
1123        };
1124        if !pure || 2 * s_op < dim {
1125            return (nullspace, s_op);
1126        }
1127        nullspace = duchon_next_nullspace_order(nullspace);
1128    }
1129    // Bounded-loop fallback: by the analysis in the docstring, for
1130    // `p >= ceil((dim + max_op) / 2) + 1` the operator constraint admits
1131    // `s = 0` and (in pure mode) `0 < dim` satisfies the kernel-existence
1132    // condition. The loop above always reaches that regime within the bound,
1133    // so returning the last `nullspace` with `s = 0` is a valid answer.
1134    (nullspace, 0)
1135}
1136
1137#[inline]
1138pub(crate) fn duchon_next_nullspace_order(order: DuchonNullspaceOrder) -> DuchonNullspaceOrder {
1139    match order {
1140        DuchonNullspaceOrder::Zero => DuchonNullspaceOrder::Linear,
1141        DuchonNullspaceOrder::Linear => DuchonNullspaceOrder::Degree(2),
1142        DuchonNullspaceOrder::Degree(k) => DuchonNullspaceOrder::Degree(k + 1),
1143    }
1144}
1145
1146pub(crate) fn duchon_previous_nullspace_order(order: DuchonNullspaceOrder) -> DuchonNullspaceOrder {
1147    match order {
1148        DuchonNullspaceOrder::Zero => DuchonNullspaceOrder::Zero,
1149        DuchonNullspaceOrder::Linear => DuchonNullspaceOrder::Zero,
1150        DuchonNullspaceOrder::Degree(2) => DuchonNullspaceOrder::Linear,
1151        DuchonNullspaceOrder::Degree(k) => DuchonNullspaceOrder::Degree(k - 1),
1152    }
1153}
1154
1155/// Returns the maximum derivative order required by the *active* operator
1156/// penalties: 2 if stiffness is Active, else 1 if tension is Active, else 0.
1157/// Mass-only (or no active operator) penalties only require kernel validity
1158/// (`2(p+s) > d`), tension requires D1 collocation (`2(p+s) > d+1`), and
1159/// stiffness requires D2 collocation (`2(p+s) > d+2`).
1160pub fn duchon_max_active_operator_derivative_order(
1161    operator_penalties: &DuchonOperatorPenaltySpec,
1162) -> usize {
1163    if matches!(
1164        operator_penalties.stiffness,
1165        OperatorPenaltySpec::Active { .. }
1166    ) {
1167        2
1168    } else if matches!(
1169        operator_penalties.tension,
1170        OperatorPenaltySpec::Active { .. }
1171    ) {
1172        1
1173    } else {
1174        0
1175    }
1176}
1177
/// Metadata returned by generic basis builders.
///
/// Each variant records what a later rebuild (e.g. at prediction time) needs in
/// order to reproduce the basis of the corresponding family. This type is
/// transient builder output and carries no serde derives.
#[derive(Debug, Clone)]
pub enum BasisMetadata {
    /// One-dimensional B-spline basis.
    BSpline1D {
        /// Knot vector of the basis.
        knots: Array1<f64>,
        /// Captured identifiability transform, if one was applied.
        identifiability_transform: Option<Array2<f64>>,
        /// Periodic configuration, if any.
        // NOTE(review): the meaning of the `(f64, f64, usize)` tuple is not shown
        // in this section — presumably (lower bound, upper bound, count); confirm.
        periodic: Option<(f64, f64, usize)>,
        /// Effective B-spline polynomial degree carried by `knots`.
        ///
        /// Persisted alongside `knots` so prediction can reconstruct an
        /// evaluator that matches fit-time geometry, even when the fit-time
        /// auto-shrink (issue #340) reduced the user's requested degree to
        /// fit the available data (`n` too small for cubic ⇒ quadratic ⇒
        /// linear). When `None` the consumer should fall back to the
        /// upstream `BSplineBasisSpec.degree` (legacy / non-shrunk path).
        degree: Option<usize>,
        /// Human-readable description of an automatic basis shrink (issue #340)
        /// when the user's requested `(degree, num_internal_knots)` exceeded the
        /// available evaluation count `n`. `Some(note)` records the before→after
        /// configuration; `None` means no auto-shrink occurred for this basis.
        auto_shrink_note: Option<String>,
    },
    /// Natural cubic regression spline (`bs="cr"`/`"cs"`) metadata (#1074).
    ///
    /// `knots` are the `k` Lancaster–Salkauskas knots that index the basis
    /// values directly (basis dim = `knots.len()`). Predict-time rebuilds
    /// reconstruct the cr geometry from `knots` and replay the captured
    /// `identifiability_transform` exactly, mirroring `BSpline1D`.
    CubicRegression1D {
        knots: Array1<f64>,
        identifiability_transform: Option<Array2<f64>>,
    },
    /// Thin-plate spline basis.
    ThinPlate {
        centers: Array2<f64>,
        length_scale: f64,
        periodic: Option<Vec<Option<f64>>>,
        identifiability_transform: Option<Array2<f64>>,
        /// Per-column standard deviations used for input standardization (d > 1).
        input_scales: Option<Vec<f64>>,
        /// Wood-TPRS radial reparameterization carried into prediction so the
        /// rotated radial basis at predict-time matches fit-time exactly. `None`
        /// in the lazy/streaming path which retains the original basis.
        radial_reparam: Option<Array2<f64>>,
    },
    /// Sphere smooth.
    // NOTE(review): the field semantics are defined by `SphereMethod` /
    // `SphereWahbaKernel`, which are not visible in this section.
    Sphere {
        centers: Array2<f64>,
        penalty_order: usize,
        method: SphereMethod,
        max_degree: Option<usize>,
        wahba_kernel: SphereWahbaKernel,
        constraint_transform: Option<Array2<f64>>,
    },
    /// Constant-curvature (`M_κ`) geodesic-kernel smooth (#944). `kappa` and
    /// the realized `length_scale` are persisted so predict-time (and the
    /// future ψ-channel per-trial) rebuilds replay the exact fit-time
    /// geometry; `constraint_transform` is the composed `z · z_parametric`
    /// frozen by the global identifiability pipeline (#532 pattern).
    ConstantCurvature {
        centers: Array2<f64>,
        kappa: f64,
        length_scale: f64,
        constraint_transform: Option<Array2<f64>>,
    },
    /// Measure-jet spline smooth: multiscale local-jet-residual energy of the
    /// empirical measure, quadratured on the center set. `centers` are the
    /// REALIZED barycenter nodes; `order_s` stores the spec's order sentinel
    /// verbatim as the mode marker (0.0 = per-level/spectral, > 0 = fused
    /// pin — persisting a realized default would flip the rebuilt mode). The
    /// penalty depends on the FIT data through `masses`, the realized
    /// `eps_band`, the support anchors, and the normalization scales, so all
    /// of them are persisted and replayed verbatim by predict-time (and
    /// per-ψ-trial) rebuilds — recomputing any of them from predict rows would
    /// change the penalty the coefficients were estimated under.
    /// `constraint_transform` is the composed `z · z_parametric` frozen by the
    /// global identifiability pipeline (#532 pattern).
    MeasureJet {
        centers: Array2<f64>,
        input_scales: Option<Vec<f64>>,
        length_scale: f64,
        eps_band: Vec<f64>,
        order_s: f64,
        alpha: f64,
        tau0: f64,
        masses: Array1<f64>,
        support_means: Vec<f64>,
        penalty_normalization_scales: Vec<f64>,
        raw_penalty_normalization_scales: Vec<f64>,
        fused_penalty_normalization_scale: Option<f64>,
        constraint_transform: Option<Array2<f64>>,
    },
    /// Matérn kernel basis.
    Matern {
        centers: Array2<f64>,
        length_scale: f64,
        periodic: Option<Vec<Option<f64>>>,
        nu: MaternNu,
        include_intercept: bool,
        identifiability_transform: Option<Array2<f64>>,
        /// Per-column standard deviations used for input standardization (d > 1).
        input_scales: Option<Vec<f64>>,
        /// Per-axis anisotropy log-scales η_a for geometric anisotropy.
        /// When Some, distance is r = √(Σ_a exp(2η_a) · (x_a - c_a)²).
        aniso_log_scales: Option<Vec<f64>>,
        /// Realized double-penalty nullspace-shrinkage decision at this build
        /// (gam#787/#860). The freeze step pins this into
        /// `MaternIdentifiability::FrozenTransform::nullspace_shrinkage_survived`
        /// so the κ-optimizer's per-trial rebuilds keep the learned-penalty count
        /// invariant (otherwise the κ-dependent spectral test flips it 6↔7 → "joint
        /// hyper rho dimension mismatch").
        nullspace_shrinkage_survived: bool,
    },
    /// Duchon / polyharmonic (optionally hybrid Matérn) basis.
    Duchon {
        centers: Array2<f64>,
        /// `None` for pure Duchon, `Some` for the hybrid Matérn-blended kernel.
        length_scale: Option<f64>,
        periodic: Option<Vec<Option<f64>>>,
        /// Spectral power `s`.
        power: f64,
        nullspace_order: DuchonNullspaceOrder,
        identifiability_transform: Option<Array2<f64>>,
        /// Per-column standard deviations used for input standardization (d > 1).
        input_scales: Option<Vec<f64>>,
        /// Per-axis anisotropy log-scales η_a, stored for prediction.
        aniso_log_scales: Option<Vec<f64>>,
        /// Support points used to build the active lower-order operator
        /// penalties (mass/tension/stiffness). Stored so runtime adaptive
        /// caches can rebuild the exact same operator rows instead of guessing
        /// from centers.
        operator_collocation_points: Option<Array2<f64>>,
        /// Data-metric radial reparameterization `V` (#1355). When `Some`, the
        /// constrained kernel transform is folded to `Z·V` so predict-time and
        /// κ-trial rebuilds replay the exact fit-time rotated radial basis.
        /// `None` on the lazy/streaming path (original constrained basis).
        radial_reparam: Option<Array2<f64>>,
    },
    /// PCA-derived basis.
    // NOTE(review): field semantics (e.g. how `basis_matrix`, `center_mean` and
    // `pca_basis_path` interact) are defined by the PCA builder, not shown here.
    Pca {
        feature_cols: Vec<usize>,
        basis_matrix: Array2<f64>,
        centered: bool,
        smooth_penalty: f64,
        center_mean: Option<Array1<f64>>,
        pca_basis_path: Option<std::path::PathBuf>,
        chunk_size: usize,
    },
    /// Tensor-product spline basis; the per-margin vectors (`knots`, `degrees`,
    /// `periods`, `is_cr`) are presumably aligned by margin index.
    TensorBSpline {
        feature_cols: Vec<usize>,
        knots: Vec<Array1<f64>>,
        degrees: Vec<usize>,
        periods: Vec<Option<f64>>,
        /// Per-margin flag: `true` when that margin is a natural cubic
        /// regression spline (`NaturalCubicRegression` knotspec) rather than an
        /// open/periodic B-spline (#1074). Persisted so the tensor freeze
        /// rebuilds the cr marginal knotspec (value-at-knot) instead of an open
        /// `Provided(knots)` B-spline, keeping predict-time marginals identical
        /// to the fit-time cr margins. Defaults to all-`false` (legacy B-spline
        /// tensors) when deserialized from an older persisted model (the
        /// older-model default is applied on the persisted `SmoothBasisSpec`
        /// side; `BasisMetadata` itself is transient builder output and is not
        /// serde-serialized, so it carries no `#[serde]` attributes).
        is_cr: Vec<bool>,
        identifiability_transform: Option<Array2<f64>>,
    },
    /// Spherical-harmonics basis up to `max_degree`.
    // NOTE(review): `radians` presumably flags that the input angles are in
    // radians rather than degrees — confirm against the builder.
    SphereHarmonics {
        max_degree: usize,
        radians: bool,
    },
    /// Wrap an inner basis metadata to record a multiplicative `by` (continuous or
    /// factor) along a column of the dataset.
    BySmooth {
        inner: Box<BasisMetadata>,
        by_col: usize,
        levels: Option<Vec<u64>>,
        ordered: bool,
    },
    /// Factor-by-smooth (mgcv-style `s(x, by=g, bs="fs"|"sz"|"re")`).
    FactorSmooth {
        continuous_cols: Vec<usize>,
        group_col: usize,
        knots: Array1<f64>,
        degree: usize,
        periodic: Option<(f64, f64, usize)>,
        group_levels: Vec<u64>,
        flavour: String,
        /// `true` when the per-level marginal is a cubic regression spline
        /// (`NaturalCubicRegression` knotspec, mgcv's `bs="sz"` default marginal,
        /// #1074). Predict-time freeze must then restore a cr knotspec from the
        /// stored value-knots rather than treating them as a B-spline knot
        /// vector. Defaults to `false` (B-spline marginal) for backward compat.
        marginal_is_cr: bool,
    },
}
1366
/// Standardized basis build result for engine-level composition.
///
/// Bundles the design, its dense penalties, and the per-penalty side tables
/// (`ops`, `null_eigenvectors`) that are documented as running parallel to
/// `penalties`. `Debug` is implemented by hand so the operator handles and
/// null-space matrices are summarized instead of dumped.
#[derive(Clone)]
pub struct BasisBuildResult {
    /// Design matrix for this basis.
    pub design: DesignMatrix,
    /// Dense penalty matrices. `ops` and `null_eigenvectors` are indexed in
    /// parallel with this vector.
    pub penalties: Vec<Array2<f64>>,
    /// Null-space dimension per penalty: entry `k` equals
    /// `null_eigenvectors[k].ncols()` whenever that entry is `Some`.
    pub nullspace_dims: Vec<usize>,
    /// Bookkeeping records for the penalties (see [`PenaltyInfo`]).
    /// NOTE(review): presumably covers dropped candidates as well as active
    /// ones, since `PenaltyInfo` carries `active` / `dropped_reason` —
    /// confirm against the builder.
    pub penaltyinfo: Vec<PenaltyInfo>,
    /// Basis metadata produced alongside the design.
    pub metadata: BasisMetadata,
    /// Optional factored rowwise-Kronecker representation for tensor-product
    /// bases. When present, downstream code can keep the design operator-backed
    /// instead of forcing a fully materialized `n x prod(q_j)` block.
    pub kronecker_factored: Option<KroneckerFactoredBasis>,
    /// Per-active-penalty operator handles (parallel to `penalties`). Each
    /// entry is `Some(op)` when the closed-form factory emitted an op-form
    /// penalty bit-equivalent to the dense matrix, `None` for ordinary dense
    /// penalties. Downstream consumers route through the `Some` entries to
    /// avoid materializing dense `p x p` Grams in exact operator algebra.
    pub ops: Vec<Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>>,
    /// Per-active-penalty null-space eigenvector matrices (parallel to
    /// `penalties`). Each entry is `Some(U_null)` with `U_null.ncols() ==
    /// nullspace_dims[k]` when the active block has a non-trivial null space
    /// (eigenvalues ≤ spectral tolerance), and `None` when the block is
    /// already full-rank. The columns of `U_null` are the eigenvectors of
    /// `sym_penalty` at the (near-)zero eigenvalues — i.e., an orthonormal
    /// basis of `null(S_block)` in the block's own coordinate system.
    ///
    /// This is the raw spectral data that the construction pipeline uses to
    /// absorb each smooth's penalty null space into the parametric block
    /// (reparameterize-and-split). Without absorption the inner Newton solve
    /// cannot converge on data whose unpenalized signal lies along a null
    /// direction of `S` (phantom-multiplier refusal at the KKT certificate).
    pub null_eigenvectors: Vec<Option<Array2<f64>>>,
    /// Joint-null absorption rotation for this basis, when the basis carries
    /// any penalties with a non-trivial joint null space.
    ///
    /// `Some(rotation)` records `Q = [U_range | U_null]` where `U_null` spans
    /// the joint null space `null(Σ_k S_k)` over this basis's active
    /// penalties (unscaled — the structural joint null is independent of
    /// `λ`). After the basis pipeline applies this rotation, the design
    /// becomes `X · Q` and each penalty becomes `Qᵀ S_k Q`, block-diagonal
    /// with a guaranteed-zero null tail. The same `Q` must be replayed at
    /// prediction time, so it is persisted in the fitted model. `None`
    /// indicates either no penalties on this basis, or a full-rank joint
    /// penalty (joint nullity = 0). A `Some` value is never recorded with
    /// `joint_nullity == 0` — the `None` discriminant is canonical for
    /// "nothing to absorb".
    ///
    /// Stage-2 commit A: this field is plumbed into the struct but neither
    /// computed nor applied yet. Stage-2 commit B populates it; Stage-2
    /// commit D applies the rotation to `design` and `penalties`.
    pub joint_null_rotation: Option<JointNullRotation>,
}
1419
/// Joint-null absorption rotation, attached to a smooth's basis when the
/// basis's joint penalty `Σ_k S_k` has a non-trivial null space.
///
/// The `rotation` field stores the orthonormal eigenvector matrix
/// `Q = [U_range | U_null]` of the symmetric joint penalty: the first
/// `range_dim = rotation.ncols() - joint_nullity` columns span
/// `range(Σ_k S_k)`; the remaining `joint_nullity` columns span
/// `null(Σ_k S_k)`. After the pipeline applies the rotation, the smooth's
/// coefficient vector satisfies `β = Q · γ`, the design becomes `X · Q`,
/// and each per-block penalty `S_k` becomes `Qᵀ S_k Q`, which is guaranteed
/// block-diagonal with a zero `(joint_nullity × joint_nullity)` tail
/// (because the joint null annihilates every active `S_k`).
///
/// The type is serde-serializable. Its `Debug` impl is hand-written and
/// prints only the shape of `rotation`, not its entries.
#[derive(Clone, Serialize, Deserialize)]
pub struct JointNullRotation {
    /// `(p_smooth × p_smooth)` orthonormal matrix; range columns first,
    /// joint-null columns last.
    pub rotation: Array2<f64>,
    /// Number of columns at the tail of `rotation` that span the joint
    /// null space. Always `> 0` when wrapped in `Some` — the value `0`
    /// is encoded as `None`.
    pub joint_nullity: usize,
}
1442
1443impl std::fmt::Debug for JointNullRotation {
1444    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1445        f.debug_struct("JointNullRotation")
1446            .field(
1447                "rotation",
1448                &format_args!("{}×{}", self.rotation.nrows(), self.rotation.ncols()),
1449            )
1450            .field("joint_nullity", &self.joint_nullity)
1451            .finish()
1452    }
1453}
1454
1455impl std::fmt::Debug for BasisBuildResult {
1456    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1457        f.debug_struct("BasisBuildResult")
1458            .field("design", &self.design)
1459            .field("penalties", &self.penalties)
1460            .field("nullspace_dims", &self.nullspace_dims)
1461            .field("penaltyinfo", &self.penaltyinfo)
1462            .field("metadata", &self.metadata)
1463            .field("kronecker_factored", &self.kronecker_factored)
1464            .field(
1465                "ops",
1466                &format_args!(
1467                    "[{}]",
1468                    self.ops
1469                        .iter()
1470                        .map(|o| if o.is_some() { "Some" } else { "None" })
1471                        .collect::<Vec<_>>()
1472                        .join(", ")
1473                ),
1474            )
1475            .field(
1476                "null_eigenvectors",
1477                &format_args!(
1478                    "[{}]",
1479                    self.null_eigenvectors
1480                        .iter()
1481                        .map(|u| match u {
1482                            Some(m) => format!("Some({}x{})", m.nrows(), m.ncols()),
1483                            None => "None".to_string(),
1484                        })
1485                        .collect::<Vec<_>>()
1486                        .join(", ")
1487                ),
1488            )
1489            .field("joint_null_rotation", &self.joint_null_rotation)
1490            .finish()
1491    }
1492}
1493
/// Factored tensor-product basis metadata for operator-backed downstream use.
///
/// Holds the fixed marginal designs/penalties plus a private, lazily filled
/// cache of the λ-invariant tensor structure derived from them.
#[derive(Debug)]
pub struct KroneckerFactoredBasis {
    /// Marginal design matrices: `marginal_designs[j]` is `(n, q_j)`.
    pub marginal_designs: Vec<Array2<f64>>,
    /// Marginal penalty matrices: `marginal_penalties[k]` is `(q_k, q_k)`.
    pub marginal_penalties: Vec<Array2<f64>>,
    /// Marginal basis dimensions: `[q_0, ..., q_{d-1}]`.
    pub marginal_dims: Vec<usize>,
    /// Whether the system includes a global ridge (double) penalty.
    pub has_double_penalty: bool,
    /// λ-invariant tensor structure (marginal eigensystems, reparameterized
    /// marginals, shrinkage scale), memoized once per fit. The marginal
    /// designs/penalties are fixed for the whole fit, so the expensive marginal
    /// `eigh()` and `B_k·U_k` GEMMs only need to run once — every outer REML
    /// iterate (50+ on the #1082 tensor cases) then reuses this. Filled lazily
    /// on first use via [`Self::invariant_structure`]. This type has no serde
    /// derives, so the cache is never serialized. On `Clone` an already
    /// populated cache is carried over (the clone shares the same `Arc`); an
    /// unpopulated cache stays empty and is recomputed on first demand. It is
    /// purely a within-fit performance cache, so either way results are
    /// bit-identical.
    invariant: std::sync::OnceLock<std::sync::Arc<crate::kronecker::KroneckerInvariantStructure>>,
}
1515
1516impl Clone for KroneckerFactoredBasis {
1517    fn clone(&self) -> Self {
1518        Self {
1519            marginal_designs: self.marginal_designs.clone(),
1520            marginal_penalties: self.marginal_penalties.clone(),
1521            marginal_dims: self.marginal_dims.clone(),
1522            has_double_penalty: self.has_double_penalty,
1523            // Propagate the memoized structure when present so a clone made
1524            // mid-fit keeps the hoist; otherwise start empty (recomputed on
1525            // first demand, identical result).
1526            invariant: match self.invariant.get() {
1527                Some(s) => {
1528                    let cell = std::sync::OnceLock::new();
1529                    cell.get_or_init(|| std::sync::Arc::clone(s));
1530                    cell
1531                }
1532                None => std::sync::OnceLock::new(),
1533            },
1534        }
1535    }
1536}
1537
1538impl KroneckerFactoredBasis {
1539    /// Construct from the fixed marginal data with an empty invariant cache.
1540    pub fn new(
1541        marginal_designs: Vec<Array2<f64>>,
1542        marginal_penalties: Vec<Array2<f64>>,
1543        marginal_dims: Vec<usize>,
1544        has_double_penalty: bool,
1545    ) -> Self {
1546        Self {
1547            marginal_designs,
1548            marginal_penalties,
1549            marginal_dims,
1550            has_double_penalty,
1551            invariant: std::sync::OnceLock::new(),
1552        }
1553    }
1554
1555    /// Lazily compute (once) and return the λ-invariant tensor structure
1556    /// (marginal eigensystems, reparameterized marginals, shrinkage scale).
1557    ///
1558    /// Computed from the fixed marginal designs/penalties, so the result is the
1559    /// same on every call within a fit; the first call pays the `eigh()` cost
1560    /// and every later call is a pointer load. Because the cache is keyed on the
1561    /// fixed marginal data and `marginal_penalties`/`marginal_designs` are
1562    /// immutable for the fit's lifetime, no invalidation is needed.
1563    pub fn invariant_structure(
1564        &self,
1565    ) -> Result<std::sync::Arc<crate::kronecker::KroneckerInvariantStructure>, BasisError> {
1566        // Fast path: already memoized.
1567        if let Some(s) = self.invariant.get() {
1568            return Ok(std::sync::Arc::clone(s));
1569        }
1570        // Compute outside the cell (fallible) and install via `get_or_init`. If a
1571        // concurrent racer already won, `get_or_init` drops our `computed` and
1572        // returns the stored one; either way the value is the unique function of
1573        // the fixed marginal data, so the returned Arc is correct.
1574        let computed = std::sync::Arc::new(crate::kronecker::KroneckerInvariantStructure::compute(
1575            &self.marginal_designs,
1576            &self.marginal_penalties,
1577            &self.marginal_dims,
1578        )?);
1579        let installed = self.invariant.get_or_init(|| computed);
1580        Ok(std::sync::Arc::clone(installed))
1581    }
1582}
1583
/// Provenance tag for a penalty matrix.
///
/// Carried in the `source` field of both [`PenaltyInfo`] and
/// [`PenaltyCandidate`]. Comparable and serde-serializable.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PenaltySource {
    /// NOTE(review): presumably the smooth's main penalty — confirm.
    Primary,
    /// NOTE(review): presumably the extra null-space penalty added under
    /// double penalization — confirm against the builder.
    DoublePenaltyNullspace,
    // NOTE(review): the three `Operator*` unit variants below look like
    // mass / tension / stiffness operator penalties; their exact definitions
    // are not visible in this file section.
    OperatorMass,
    OperatorTension,
    OperatorStiffness,
    /// One per input axis `a` of a multivariate Duchon smooth: the gradient
    /// energy along axis `a`, `Σ(∂f/∂x_a)²`, each with its own REML λ_a. REML
    /// shrinks an axis's contribution toward flat only when it does not earn
    /// its keep — penalty-based ARD / variable relevance, the replacement for
    /// brittle kernel-η optimization. Emitted when `scale_dims` is on.
    OperatorRelevance {
        axis: usize,
    },
    /// Tensor-product penalty tied to a single margin index `dim`.
    /// NOTE(review): semantics inferred from the name — confirm.
    TensorMarginal {
        dim: usize,
    },
    /// Tensor-product penalty identified by the list of margins it penalizes.
    /// NOTE(review): semantics inferred from the field name — confirm.
    TensorSeparable {
        penalized_margins: Vec<usize>,
    },
    /// NOTE(review): presumably the global ridge term flagged by
    /// `KroneckerFactoredBasis::has_double_penalty` — confirm.
    TensorGlobalRidge,
    /// Free-form label for sources not covered by the variants above.
    Other(String),
}
1608
/// Reason recorded in `PenaltyInfo::dropped_reason` for a penalty.
///
/// NOTE(review): the variant meanings below are inferred from their names;
/// the code that assigns them is outside this section.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PenaltyDropReason {
    /// Presumably: the penalty matrix was identically zero.
    ZeroMatrix,
    /// Presumably: the penalty's numerical rank came out as zero.
    NumericalRankZero,
}
1614
/// Serde default for `PenaltyInfo::normalization_scale`: a neutral scale of
/// one, used when the field is absent from serialized input.
fn default_normalization_scale() -> f64 {
    const UNIT_SCALE: f64 = 1.0;
    UNIT_SCALE
}
1618
/// Serializable bookkeeping record describing one penalty of a basis.
///
/// NOTE(review): field meanings marked "presumably" are inferred from names;
/// the code that populates them is outside this section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PenaltyInfo {
    /// Which construction produced the penalty.
    pub source: PenaltySource,
    /// Presumably the penalty's position in the original candidate list.
    pub original_index: usize,
    /// Whether the penalty is active (as opposed to dropped).
    pub active: bool,
    /// Presumably the numerically determined rank of the penalty.
    pub effective_rank: usize,
    /// Why the penalty was dropped, if it was.
    pub dropped_reason: Option<PenaltyDropReason>,
    /// Hint for the penalty's null-space dimension (mirrors
    /// `PenaltyCandidate::nullspace_dim_hint`).
    pub nullspace_dim_hint: usize,
    /// Normalization scale associated with the penalty. Deserializes to `1.0`
    /// when the field is missing from the input.
    #[serde(default = "default_normalization_scale")]
    pub normalization_scale: f64,
    /// Kronecker factors preserved from tensor penalty construction.
    /// When present, spectral decomposition can use per-factor eigendecomposition.
    /// Skipped by serde: never written out, and `None` after deserialization.
    #[serde(skip)]
    pub kronecker_factors: Option<Vec<Array2<f64>>>,
}
1634
/// A candidate penalty: dense matrix plus provenance and optional structured
/// (Kronecker / operator) representations of the same penalty.
///
/// `Debug` is implemented by hand and prints shapes/counts rather than the
/// matrix contents.
#[derive(Clone)]
pub struct PenaltyCandidate {
    /// Dense penalty matrix.
    pub matrix: Array2<f64>,
    /// Hint for the dimension of the penalty's null space.
    pub nullspace_dim_hint: usize,
    /// Which construction produced this penalty.
    pub source: PenaltySource,
    /// Normalization scale associated with `matrix`.
    /// NOTE(review): whether `matrix` is stored pre- or post-scaling is not
    /// visible here — confirm against the factories.
    pub normalization_scale: f64,
    /// Optional Kronecker factors whose product equals `matrix`.
    /// When present, spectral decomposition can be done per-factor
    /// (O(Σ q_j³) instead of O((Π q_j)³)).
    pub kronecker_factors: Option<Vec<Array2<f64>>>,
    /// Optional operator-form handle whose `as_dense()` matches `matrix`. When
    /// populated by the closed-form factories, this is propagated through to
    /// `CanonicalPenaltyBlock` so downstream consumers can use exact matvec
    /// algebra without rebuilding the dense Gram. When `None`, only the dense
    /// `matrix` path is available.
    pub op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
}
1652
1653impl std::fmt::Debug for PenaltyCandidate {
1654    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1655        f.debug_struct("PenaltyCandidate")
1656            .field(
1657                "matrix",
1658                &format_args!("{}×{}", self.matrix.nrows(), self.matrix.ncols()),
1659            )
1660            .field("nullspace_dim_hint", &self.nullspace_dim_hint)
1661            .field("source", &self.source)
1662            .field("normalization_scale", &self.normalization_scale)
1663            .field(
1664                "kronecker_factors",
1665                &self.kronecker_factors.as_ref().map(|v| v.len()),
1666            )
1667            .field("op", &self.op.as_ref().map(|o| o.dim()))
1668            .finish()
1669    }
1670}
1671
/// A penalty block together with its retained spectral decomposition and the
/// rank bookkeeping derived from it.
///
/// `Debug` is implemented by hand and prints the matrices as shapes.
#[derive(Clone)]
pub struct CanonicalPenaltyBlock {
    /// Dense penalty matrix for the block.
    /// NOTE(review): the name suggests it has been symmetrized — confirm.
    pub sym_penalty: Array2<f64>,
    /// Eigenvalues from spectral decomposition (retained to avoid recomputation).
    pub eigenvalues: Array1<f64>,
    /// Eigenvectors from spectral decomposition (retained to avoid recomputation).
    pub eigenvectors: Array2<f64>,
    /// Rank of the block.
    /// NOTE(review): presumably the count of eigenvalues above `tol` — confirm.
    pub rank: usize,
    /// Null-space dimension of the block.
    /// NOTE(review): presumably the count of eigenvalues within `tol` of
    /// zero — confirm.
    pub nullity: usize,
    /// Number of genuine negative-curvature eigendirections (`ev < -tol`).
    /// A non-PSD penalty has `negative_dim > 0`; these directions are
    /// neither range nor null and are never absorbed as unpenalized (#1425).
    pub negative_dim: usize,
    /// Spectral tolerance used to classify eigenvalues (see `negative_dim`).
    pub tol: f64,
    /// Flag marking the block as a zero penalty.
    /// NOTE(review): exact criterion is not visible here — confirm.
    pub iszero: bool,
    /// Optional operator-form handle that is bit-equivalent to `sym_penalty`.
    /// Propagated from `PenaltyCandidate.op` when present so downstream
    /// consumers can use matvec without rebuilding the dense Gram.
    pub op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
}
1692
1693impl std::fmt::Debug for CanonicalPenaltyBlock {
1694    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1695        f.debug_struct("CanonicalPenaltyBlock")
1696            .field(
1697                "sym_penalty",
1698                &format_args!("{}×{}", self.sym_penalty.nrows(), self.sym_penalty.ncols()),
1699            )
1700            .field("eigenvalues", &self.eigenvalues)
1701            .field(
1702                "eigenvectors",
1703                &format_args!(
1704                    "{}×{}",
1705                    self.eigenvectors.nrows(),
1706                    self.eigenvectors.ncols()
1707                ),
1708            )
1709            .field("rank", &self.rank)
1710            .field("nullity", &self.nullity)
1711            .field("negative_dim", &self.negative_dim)
1712            .field("tol", &self.tol)
1713            .field("iszero", &self.iszero)
1714            .field("op", &self.op.as_ref().map(|o| o.dim()))
1715            .finish()
1716    }
1717}
1718
/// First-derivative payload of a basis with respect to `psi`.
///
/// NOTE(review): `psi` is presumably the log-scale kernel parameter used by
/// the spatial terms (cf. `psi_a = log(kappa_a)` on
/// `AnisoBasisPsiDerivatives`) — confirm for the isotropic case.
#[derive(Debug)]
pub struct BasisPsiDerivativeResult {
    /// Dense first derivative of the design matrix.
    pub design_derivative: Array2<f64>,
    /// Dense first derivatives of the penalty matrices, one per penalty.
    pub penalties_derivative: Vec<Array2<f64>>,
    /// Operator-backed design derivative for standalone first-derivative
    /// callers. Bundled first+second callers receive the shared operator on
    /// `BasisPsiDerivativeBundle` instead.
    pub implicit_operator: Option<ImplicitDesignPsiDerivative>,
}
1728
/// Second-derivative payload of a basis with respect to `psi`; the
/// second-order counterpart of `BasisPsiDerivativeResult`.
#[derive(Debug)]
pub struct BasisPsiSecondDerivativeResult {
    /// Dense second derivative of the design matrix.
    pub designsecond_derivative: Array2<f64>,
    /// Dense second derivatives of the penalty matrices, one per penalty.
    pub penaltiessecond_derivative: Vec<Array2<f64>>,
    /// Operator-backed design derivative for standalone second-derivative
    /// callers. Bundled first+second callers receive the shared operator on
    /// `BasisPsiDerivativeBundle` instead.
    pub implicit_operator: Option<ImplicitDesignPsiDerivative>,
}
1738
/// First and second `psi` derivative payloads delivered together, with a
/// single operator shared between them.
#[derive(Debug)]
pub struct BasisPsiDerivativeBundle {
    /// First-derivative payload.
    pub first: BasisPsiDerivativeResult,
    /// Second-derivative payload.
    pub second: BasisPsiSecondDerivativeResult,
    /// Shared operator-backed design derivative for the first and second
    /// psi derivatives. Bundled callers consume this once instead of storing
    /// duplicate materialized/streaming operators in both derivative payloads.
    pub implicit_operator: Option<ImplicitDesignPsiDerivative>,
}
1748
/// Per-axis psi_a derivative package for anisotropic spatial terms.
///
/// For a d-dimensional anisotropic term, the kernel phi(r) depends on
/// the anisotropic distance r = |Lambda h| where Lambda = diag(kappa_a). Each axis a
/// has its own log-scale psi_a = log(kappa_a), yielding d first derivatives,
/// d diagonal second derivatives, and d*(d-1)/2 cross second derivatives.
///
/// The cross second derivative d2 phi/(d psi_a d psi_b) = t * s_a * s_b (a != b)
/// is rank-1, so we store the t_values and s_components vectors rather
/// than materializing d^2 matrices.
///
/// The dense design vectors may be empty, in which case `implicit_operator`
/// is the only representation (see that field).
#[derive(Clone)]
pub struct AnisoBasisPsiDerivatives {
    /// d matrices, each (n x p_smooth): dX/d psi_a.
    pub design_first: Vec<Array2<f64>>,
    /// d matrices, each (n x p_smooth): d2X/d psi_a^2 (diagonal second derivatives).
    pub design_second_diag: Vec<Array2<f64>>,
    /// Cross second derivatives d2X/(d psi_a d psi_b) for a < b.
    pub design_second_cross: Vec<Array2<f64>>,
    /// Axis-pair indices corresponding to `design_second_cross`
    /// (entry `i` names the `(a, b)` pair of `design_second_cross[i]`).
    pub design_second_cross_pairs: Vec<(usize, usize)>,
    /// d x num_penalties: dS_m/d psi_a for each axis a and penalty m.
    pub penalties_first: Vec<Vec<Array2<f64>>>,
    /// d x num_penalties: d2S_m/d psi_a^2 for each axis a and penalty m.
    pub penalties_second_diag: Vec<Vec<Array2<f64>>>,
    /// The (a, b) axis pairs supported by the on-demand cross-penalty
    /// provider. Only the upper triangle (a < b) is stored.
    pub penalties_cross_pairs: Vec<(usize, usize)>,
    /// On-demand cross-penalty second-derivative provider. Exact anisotropic
    /// cross-axis penalty seconds are streamed one pair at a time rather than
    /// stored as a dense upper triangle of blocks.
    pub penalties_cross_provider: Option<AnisoPenaltyCrossProvider>,
    /// Shared operator-backed representation of the anisotropic kernel-side
    /// design derivatives. When `design_first` / `design_second_diag` are empty,
    /// callers must use this operator directly; when they are present, this
    /// operator still provides exact cross-axis second derivatives without
    /// duplicating separate `t` / `s_a` storage layouts.
    pub implicit_operator: Option<ImplicitDesignPsiDerivative>,
}
1787
/// Cloneable handle around a shared closure that maps an axis pair
/// `(axis_a, axis_b)` to a vector of matrices (or a `BasisError`).
///
/// The closure is held in an `Arc`, so cloning the provider shares the same
/// callback; the `Send + Sync + 'static` bounds let the handle cross threads.
/// Used as `AnisoBasisPsiDerivatives::penalties_cross_provider` to stream
/// cross-axis penalty second derivatives one pair at a time.
#[derive(Clone)]
pub struct AnisoPenaltyCrossProvider(
    std::sync::Arc<
        dyn Fn(usize, usize) -> Result<Vec<Array2<f64>>, BasisError> + Send + Sync + 'static,
    >,
);
1794
1795impl AnisoPenaltyCrossProvider {
1796    pub(crate) fn new<F>(f: F) -> Self
1797    where
1798        F: Fn(usize, usize) -> Result<Vec<Array2<f64>>, BasisError> + Send + Sync + 'static,
1799    {
1800        Self(std::sync::Arc::new(f))
1801    }
1802
1803    pub fn evaluate(&self, axis_a: usize, axis_b: usize) -> Result<Vec<Array2<f64>>, BasisError> {
1804        (self.0)(axis_a, axis_b)
1805    }
1806}
1807
1808// ═══════════════════════════════════════════════════════════════════════════
1809//  Implicit derivative operator for scalable anisotropic REML gradients
1810// ═══════════════════════════════════════════════════════════════════════════
1811
/// Byte cap that `assert_spatial_centers_below_large_scale_cap` applies to
/// both the `K x d_pc` center storage and the `K x K` center-center matrix.
pub(crate) const SPATIAL_CENTER_CENTER_MAX_BYTES: usize = 512 * 1024 * 1024; // 512 MiB
/// Number of design rows processed per chunk by `design_cross_and_gram`.
pub(crate) const DESIGN_CROSS_CHUNK_SIZE: usize = 1024;
1814
1815/// Determine whether implicit operators should be used based on problem size
1816/// and the supplied [`ResourcePolicy`].
1817///
1818/// Returns `true` when the dense materialization of D first-derivative
1819/// matrices would exceed `policy.max_single_materialization_bytes`.
1820///
1821/// For D axes with n data points and p_smooth basis columns, the dense path
1822/// allocates D * n * p_smooth * 8 bytes for first-derivative matrices alone
1823/// (plus a similar amount for second derivatives). The implicit path stores
1824/// only the compact (n * n_knots) radial jets plus (n * n_knots * D) axis
1825/// fractions, which is O(n * k * D) instead of O(n * p * D).
1826pub fn should_use_implicit_operators_with_policy(
1827    n: usize,
1828    p: usize,
1829    d: usize,
1830    policy: &gam_runtime::resource::ResourcePolicy,
1831) -> bool {
1832    // Each first-derivative matrix is (n x p) f64 → n*p*8 bytes.
1833    // We need D of them for first derivatives, D for second diag, plus
1834    // the cross-t matrix and s_components. Conservative estimate: 3*D matrices.
1835    let dense_bytes = 3usize
1836        .saturating_mul(n)
1837        .saturating_mul(p)
1838        .saturating_mul(d)
1839        .saturating_mul(8);
1840    dense_bytes > policy.max_single_materialization_bytes
1841}
1842
/// Byte size of the implicit radial cache: `n * k * (n_axes + 3)` f64 values,
/// with every step saturating at `usize::MAX` rather than overflowing.
///
/// NOTE(review): the `+ 3` presumably covers the per-(row, knot) radial jet
/// components stored next to the `n_axes` axis fractions — confirm against
/// the cache layout.
pub(crate) fn implicit_radial_cache_bytes(n: usize, k: usize, n_axes: usize) -> usize {
    let values_per_pair = n_axes.saturating_add(3);
    let pair_count = n.saturating_mul(k);
    pair_count
        .saturating_mul(values_per_pair)
        .saturating_mul(std::mem::size_of::<f64>())
}
1848
1849pub(crate) fn should_cache_implicit_radial_components(
1850    n: usize,
1851    k: usize,
1852    n_axes: usize,
1853    policy: &gam_runtime::resource::ResourcePolicy,
1854) -> bool {
1855    implicit_radial_cache_bytes(n, k, n_axes) <= policy.max_operator_cache_bytes
1856}
1857
1858pub fn assert_no_dense_derivative_materialization(n: usize, p: usize, d_pc: usize) {
1859    let first = dense_design_bytes(n, p).saturating_mul(d_pc);
1860    let second = dense_design_bytes(n, p).saturating_mul(d_pc.saturating_mul(d_pc));
1861    // Consult the library default ResourcePolicy. Production large-scale runs
1862    // configure `AnalyticOperatorRequired`, which still refuses every dense
1863    // materialization here. The default `MaterializeIfSmall` mode lets tiny
1864    // problems (and small-data/test usage) materialize as long as the combined
1865    // first- and second-order dense bytes fit under the single-materialization
1866    // byte budget. `DiagnosticsOnly` is treated like `MaterializeIfSmall` for
1867    // this guard: it permits dense materialization under the same byte cap.
1868    let policy = gam_runtime::resource::ResourcePolicy::default_library();
1869    let budget = policy.max_single_materialization_bytes;
1870    let needed = first.saturating_add(second);
1871    match policy.derivative_storage_mode {
1872        gam_runtime::resource::DerivativeStorageMode::AnalyticOperatorRequired => {
1873            // SAFETY: this assertion helper exists specifically to enforce
1874            // the large-scale invariant that spatial-PC Duchon derivative
1875            // designs never persist as dense `Array2<f64>` storage. When the
1876            // resource policy is `AnalyticOperatorRequired`, any caller that
1877            // reached this point has materialized something the strict
1878            // operator contract forbids.
1879            // SAFETY: AnalyticOperatorRequired forbids dense derivative materialization.
1880            panic!(
1881                "spatial PC Duchon derivative designs must remain operator-backed; refused persistent dense derivative materialization (n={n}, p={p}, d_pc={d_pc}, first_order={:.1} MiB, second_order={:.1} MiB)",
1882                first as f64 / (1024.0 * 1024.0),
1883                second as f64 / (1024.0 * 1024.0),
1884            );
1885        }
1886        gam_runtime::resource::DerivativeStorageMode::MaterializeIfSmall
1887        | gam_runtime::resource::DerivativeStorageMode::DiagnosticsOnly => {
1888            // SAFETY: exceeding the single-materialization budget here is a
1889            // contract violation by an upstream caller that must route through
1890            // the operator-backed path; failing loudly surfaces it rather than
1891            // silently materializing an oversized dense derivative design.
1892            assert!(
1893                needed <= budget,
1894                "spatial PC Duchon derivative designs would exceed the single-materialization budget; refused persistent dense derivative materialization (n={n}, p={p}, d_pc={d_pc}, first_order={:.1} MiB, second_order={:.1} MiB, budget={:.1} MiB)",
1895                first as f64 / (1024.0 * 1024.0),
1896                second as f64 / (1024.0 * 1024.0),
1897                budget as f64 / (1024.0 * 1024.0),
1898            );
1899        }
1900    }
1901}
1902
1903pub fn assert_spatial_centers_below_large_scale_cap(
1904    d_pc: usize,
1905    centers: ArrayView2<'_, f64>,
1906) -> Result<(), BasisError> {
1907    if centers.ncols() != d_pc {
1908        crate::bail_dim_basis!(
1909            "spatial PC center dimension mismatch: centers have {} columns, expected {d_pc}",
1910            centers.ncols()
1911        );
1912    }
1913    let k = centers.nrows();
1914    let centers_bytes = dense_design_bytes(k, d_pc);
1915    let center_center_bytes = dense_design_bytes(k, k);
1916    if centers_bytes > SPATIAL_CENTER_CENTER_MAX_BYTES {
1917        crate::bail_invalid_basis!(
1918            "spatial PC centers exceed center storage cap: K={k}, d_pc={d_pc}, centers={:.1} MiB, cap={:.1} MiB",
1919            centers_bytes as f64 / (1024.0 * 1024.0),
1920            SPATIAL_CENTER_CENTER_MAX_BYTES as f64 / (1024.0 * 1024.0),
1921        );
1922    }
1923    if center_center_bytes > SPATIAL_CENTER_CENTER_MAX_BYTES {
1924        crate::bail_invalid_basis!(
1925            "spatial PC centers exceed center-center large-scale cap: K={k}, d_pc={d_pc}, KxK={:.1} MiB, cap={:.1} MiB",
1926            center_center_bytes as f64 / (1024.0 * 1024.0),
1927            SPATIAL_CENTER_CENTER_MAX_BYTES as f64 / (1024.0 * 1024.0),
1928        );
1929    }
1930    Ok(())
1931}
1932
/// Bytes needed to store a dense `n x p` matrix of `f64`, saturating at
/// `usize::MAX` instead of overflowing.
pub(crate) fn dense_design_bytes(n: usize, p: usize) -> usize {
    const F64_BYTES: usize = std::mem::size_of::<f64>();
    let entries = n.saturating_mul(p);
    entries.saturating_mul(F64_BYTES)
}
1937
1938pub(crate) fn should_use_lazy_spatial_design(
1939    n: usize,
1940    p: usize,
1941    policy: &gam_runtime::resource::ResourcePolicy,
1942) -> bool {
1943    dense_design_bytes(n, p) > policy.max_single_materialization_bytes
1944}
1945
1946pub(crate) fn wrap_dense_design_with_transform(
1947    design: DesignMatrix,
1948    transform: &Array2<f64>,
1949    label: &str,
1950) -> Result<DesignMatrix, BasisError> {
1951    match design {
1952        DesignMatrix::Dense(inner) => {
1953            let op = CoefficientTransformOperator::new(inner, transform.clone()).map_err(|e| {
1954                BasisError::InvalidInput(format!("{label} coefficient transform failed: {e}"))
1955            })?;
1956            Ok(DesignMatrix::Dense(
1957                gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(op)),
1958            ))
1959        }
1960        DesignMatrix::Sparse(_) => Err(BasisError::InvalidInput(format!(
1961            "{label} coefficient transform requires a dense/operator-backed design"
1962        ))),
1963    }
1964}
1965
1966/// Single-pass `(Bᵀ(W·C), BᵀB)` accumulation over the streamed design.
1967///
1968/// Materialises each row chunk of the design **once** and reuses it for both
1969/// the constraint cross `Bᵀ(W·C)` and the Gram `BᵀB`. On the lazy chunked
1970/// spatial path each `try_row_chunk` re-evaluates all kernel columns for the
1971/// chunk, so accumulating both products in a single sweep halves the per-build
1972/// kernel re-evaluation work (the dominant cost at large scale) versus two
1973/// independent streaming passes — without changing the result beyond
1974/// floating-point reassociation. The cross is masked off (`q == 0`) by the
1975/// caller, which never invokes this when there is no constraint block.
1976pub(crate) fn design_cross_and_gram(
1977    design: &DesignMatrix,
1978    constraint_matrix: ArrayView2<'_, f64>,
1979    weights: Option<ArrayView1<'_, f64>>,
1980) -> Result<(Array2<f64>, Array2<f64>), BasisError> {
1981    let n = design.nrows();
1982    let k = design.ncols();
1983    if constraint_matrix.nrows() != n {
1984        return Err(BasisError::ConstraintMatrixRowMismatch {
1985            basisrows: n,
1986            constraintrows: constraint_matrix.nrows(),
1987        });
1988    }
1989    if let Some(w) = weights
1990        && w.len() != n
1991    {
1992        return Err(BasisError::WeightsDimensionMismatch {
1993            expected: n,
1994            found: w.len(),
1995        });
1996    }
1997    let q = constraint_matrix.ncols();
1998    let mut cross = Array2::<f64>::zeros((k, q));
1999    let mut gram = Array2::<f64>::zeros((k, k));
2000    for start in (0..n).step_by(DESIGN_CROSS_CHUNK_SIZE) {
2001        let end = (start + DESIGN_CROSS_CHUNK_SIZE).min(n);
2002        let basis_chunk = design
2003            .try_row_chunk(start..end)
2004            .map_err(|e| BasisError::InvalidInput(e.to_string()))?;
2005        let mut constraint_chunk = constraint_matrix.slice(s![start..end, ..]).to_owned();
2006        if let Some(w) = weights {
2007            for (mut row, &weight) in constraint_chunk
2008                .axis_iter_mut(Axis(0))
2009                .zip(w.slice(s![start..end]).iter())
2010            {
2011                row *= weight;
2012            }
2013        }
2014        cross += &fast_atb(&basis_chunk, &constraint_chunk);
2015        gram += &fast_atb(&basis_chunk, &basis_chunk);
2016    }
2017    Ok((cross, gram))
2018}
2019
2020pub(crate) fn positive_spectral_whitener_from_gram(
2021    gram: &Array2<f64>,
2022) -> Result<Array2<f64>, BasisError> {
2023    // Inverse-square-root for the positive part of `gram`. Eigenvalues at or
2024    // below the relative rank tolerance `α·ε·n·max_eval` are *dropped*: the
2025    // returned whitener has shape `(n × keep)` where `keep` counts strictly
2026    // positive eigendirections of `gram`.
2027    //
2028    // Dropping (rather than ridging) is what makes the result a true
2029    // square-root inverse on the column space of `gram`. This whitener is
2030    // used by `stabilized_orthogonality_transform_from_gram` to make a
2031    // pre-existing transform `K_raw` orthonormal under the W-inner product:
2032    // when some columns of `K_raw` map to zero (or near-zero) under `B`, the
2033    // constrained Gram `K_raw^T G K_raw` is rank-deficient. Ridging those
2034    // tail directions with `1/sqrt(ε)` produced spurious basis columns
2035    // whose coefficient norms blew up to `~1/sqrt(ε)` while their image in
2036    // `B` was floating-point zero, contaminating downstream linear algebra
2037    // (in particular it forced `smooth.rs` to widen the post-transform
2038    // orthogonality residual tolerance to absorb a `cond ≈ 1/sqrt(ε)`
2039    // rounding floor). Dropping these directions is the right behavior:
2040    // they contribute nothing to `B`'s column space, and removing them
2041    // tightens the orthogonality residual back down to the genuine
2042    // floating-point limit.
2043    let (eigenvalues, eigenvectors) = gram.eigh(Side::Lower).map_err(BasisError::LinalgError)?;
2044    let n = gram.nrows();
2045    let max_eval = eigenvalues.iter().copied().fold(0.0_f64, f64::max);
2046    let tol = (default_rrqr_rank_alpha() * f64::EPSILON * (n.max(1) as f64) * max_eval.max(1.0))
2047        .max(f64::EPSILON);
2048    let keep = eigenvalues.iter().filter(|&&ev| ev > tol).count();
2049    if keep == 0 {
2050        let min_ev = eigenvalues.iter().copied().fold(f64::INFINITY, f64::min);
2051        return Err(BasisError::ConstraintNullspaceCollapsed {
2052            site: "positive_spectral_whitener_from_gram",
2053            cross_rank: 0,
2054            coeff_dim: gram.nrows(),
2055            cross_frobenius: gram.iter().map(|v| v * v).sum::<f64>().sqrt(),
2056            gram_spectrum: format!(
2057                "max eigenvalue {max_eval:.3e} (min {min_ev:.3e}, spectral tolerance {tol:.3e})"
2058            ),
2059        });
2060    }
2061    // `eigh` returns eigenvalues in ascending order, so the largest `keep`
2062    // eigenvalues live at the tail.
2063    let eig_start = eigenvalues.len() - keep;
2064    let kept_vectors = eigenvectors.slice(s![.., eig_start..]).to_owned();
2065    let mut inv_sqrt = Array2::<f64>::zeros((keep, keep));
2066    for (out_i, eig_i) in (eig_start..eigenvalues.len()).enumerate() {
2067        inv_sqrt[[out_i, out_i]] = 1.0 / eigenvalues[eig_i].sqrt();
2068    }
2069    Ok(fast_ab(&kept_vectors, &inv_sqrt))
2070}
2071
2072pub(crate) fn stabilized_orthogonality_transform_from_gram(
2073    gram: &Array2<f64>,
2074    transform: &Array2<f64>,
2075) -> Result<Array2<f64>, BasisError> {
2076    let constrained_gram = {
2077        let gt = fast_ab(gram, transform);
2078        fast_atb(transform, &gt)
2079    };
2080    let whitening = positive_spectral_whitener_from_gram(&constrained_gram)?;
2081    Ok(fast_ab(transform, &whitening))
2082}
2083
2084pub(crate) fn orthogonality_transform_from_cross_and_gram(
2085    constraint_cross: &Array2<f64>,
2086    gram: &Array2<f64>,
2087) -> Result<Array2<f64>, BasisError> {
2088    // Compute null(M^T) directly on M = B^T W C (k × q) via column-pivoted QR.
2089    // Working in the original k-dim coefficient space rather than first
2090    // whitening by B^T B avoids a fundamental failure mode: when B is heavily
2091    // collinear, `positive_spectral_whitener_from_gram` truncates the design
2092    // column-space to a `keep`-dim subspace, and if `keep <= q` the subsequent
2093    // nullspace search has no room — even though dim null(M^T) = k - rank(M)
2094    // ≥ k - q is always positive when k > q. The constraint nullspace is a
2095    // property of M alone; conditioning of the design only matters for the
2096    // downstream stabilization of B*K_raw.
2097    let k = constraint_cross.nrows();
2098    if k == 0 {
2099        return Err(BasisError::InsufficientColumnsForConstraint { found: 0 });
2100    }
2101    let (transform_raw, rank) = rrqr_nullspace_basis(constraint_cross, default_rrqr_rank_alpha())
2102        .map_err(BasisError::LinalgError)?;
2103    if rank >= k || transform_raw.ncols() == 0 {
2104        return Err(BasisError::ConstraintNullspaceCollapsed {
2105            site: "orthogonality_transform_from_cross_and_gram",
2106            cross_rank: rank,
2107            coeff_dim: k,
2108            cross_frobenius: constraint_cross.iter().map(|v| v * v).sum::<f64>().sqrt(),
2109            gram_spectrum: "not computed (structural cross-rank collapse: null(Mᵀ) is empty, \
2110                            so no constrained design exists to eigendecompose)"
2111                .to_string(),
2112        });
2113    }
2114
2115    // Make the constrained design B*K_raw orthonormal under the W-inner product.
2116    // If the constrained Gram K_raw^T G K_raw is rank-deficient (because some
2117    // directions in null(M^T) collapse under B), the spectral whitener drops
2118    // them — that is the right behavior: a degenerate column never contributes
2119    // to B's column space and shouldn't appear in the reparameterized basis.
2120    stabilized_orthogonality_transform_from_gram(gram, &transform_raw)
2121}
2122
2123pub fn orthogonality_transform_for_design(
2124    design: &DesignMatrix,
2125    constraint_matrix: ArrayView2<'_, f64>,
2126    weights: Option<ArrayView1<'_, f64>>,
2127) -> Result<Array2<f64>, BasisError> {
2128    let k = design.ncols();
2129    if k == 0 {
2130        return Err(BasisError::InsufficientColumnsForConstraint { found: 0 });
2131    }
2132    let q = constraint_matrix.ncols();
2133    if q == 0 {
2134        return Ok(Array2::eye(k));
2135    }
2136    let (constraint_cross, gram) = design_cross_and_gram(design, constraint_matrix, weights)?;
2137    orthogonality_transform_from_cross_and_gram(&constraint_cross, &gram)
2138}