gam_terms/basis/types.rs
1use super::*;
2
/// Wrapper to send a raw `*mut f64` across thread boundaries for parallel
/// buffer fills.
///
/// The wrapper only *transports* the pointer; it never dereferences it.
///
/// # Safety contract for users
/// Every `SendPtr` value must be built from live, properly aligned `f64`
/// storage whose mutable borrow is held until all worker threads finish.
/// Callers may only dereference offsets that are in-bounds and disjoint across
/// workers.
#[derive(Clone, Copy)]
pub(crate) struct SendPtr(pub(crate) *mut f64);

// SAFETY: SendPtr only grants raw-pointer transport. Actual dereferences occur
// at call sites after row-chunk partitioning proves each thread writes a
// distinct in-bounds element of the backing Array/Vec allocation.
unsafe impl Send for SendPtr {}

// SAFETY: shared references to SendPtr are sound because the pointee is never
// accessed through the wrapper without the call-site disjoint-offset proof.
unsafe impl Sync for SendPtr {}

impl SendPtr {
    /// Returns the address `offset` elements past the wrapped base pointer.
    ///
    /// This is a *safe* function, so it must not be able to cause undefined
    /// behavior on its own. It therefore uses `wrapping_add`, whose address
    /// computation is defined for every `offset`, instead of `ptr::add`, which
    /// is UB as soon as the result leaves the backing allocation. For in-bounds
    /// offsets both produce the same pointer.
    ///
    /// Dereferencing the returned pointer remains `unsafe`: the caller must
    /// prove the target element is in-bounds and uniquely owned by that
    /// worker's chunk for the whole parallel region.
    #[inline(always)]
    pub(crate) fn add(self, offset: usize) -> *mut f64 {
        self.0.wrapping_add(offset)
    }
}
28
29/// Re-export of the neutral basis-error contract. #1521: `BasisError` lives
30/// in `gam-problem` so `EstimationError` can wrap it (`#[from]`) without a
31/// back-edge; gam-terms re-exports it to preserve `gam_terms::basis::BasisError`.
32pub use gam_problem::BasisError;
33
34// ============================================================================
35// Unified Basis Generation API
36// ============================================================================
37
/// Evaluation options for basis generation: which derivative to take and
/// which 1D spline family to evaluate.
#[derive(Debug, Default, Clone, Copy)]
pub struct BasisOptions {
    /// Derivative order: 0 = value (default), 1 = first derivative,
    /// 2 = second derivative.
    pub derivative_order: usize,
    /// Basis family to evaluate.
    pub basis_family: BasisFamily,
}

impl BasisOptions {
    /// Shared constructor behind the named presets below.
    const fn preset(derivative_order: usize, basis_family: BasisFamily) -> Self {
        Self {
            derivative_order,
            basis_family,
        }
    }

    /// Options for evaluating B-spline basis functions (no derivative).
    pub const fn value() -> Self {
        Self::preset(0, BasisFamily::BSpline)
    }

    /// Options for evaluating first derivatives of B-spline basis functions.
    pub const fn first_derivative() -> Self {
        Self::preset(1, BasisFamily::BSpline)
    }

    /// Options for evaluating second derivatives of B-spline basis functions.
    pub const fn second_derivative() -> Self {
        Self::preset(2, BasisFamily::BSpline)
    }

    /// Options for evaluating M-spline basis values (no derivative).
    pub const fn m_spline() -> Self {
        Self::preset(0, BasisFamily::MSpline)
    }

    /// Options for evaluating I-spline basis values (no derivative).
    pub const fn i_spline() -> Self {
        Self::preset(0, BasisFamily::ISpline)
    }
}

/// Basis-family selector for 1D spline evaluation.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum BasisFamily {
    /// Standard B-splines (the default family).
    #[default]
    BSpline,
    /// M-splines: normalized B-splines, M_i = ((k+1)/(t_{i+k+1}-t_i)) B_i.
    MSpline,
    /// I-splines: integrated M-splines, implemented by right-cumulative
    /// sums of B-splines at degree k+1.
    ISpline,
}
101
102/// Specifies the source of knots for basis generation.
///
/// The lifetime `'a` belongs to the knot view carried by `Provided`; the
/// `Generate` variant owns all of its data.
103#[derive(Clone, Debug)]
104pub enum KnotSource<'a> {
105 /// Use a pre-computed knot vector.
106 Provided(ArrayView1<'a, f64>),
107 /// Generate uniformly spaced knots based on data range.
108 Generate {
109 /// Data range (min, max) for knot placement.
110 data_range: (f64, f64),
111 /// Number of internal knots to place between boundaries.
112 num_internal_knots: usize,
113 },
114}
115/// Thin-plate regression spline basis and penalty (order m=2).
116///
117/// The returned basis has columns `[K_c | P]` where:
118/// - `K_c` is the constrained radial basis block (`K * Z`) with
119/// `P(knots)^T * α = 0` enforced via nullspace projection
120/// - `P` is the TPS polynomial null-space block containing all monomials of
121/// total degree `< m`, where `m = thin_plate_penalty_order(d)` (so `P` is
122/// just `[1, x_1, ..., x_d]` for `d <= 3`)
123///
124/// The returned penalty matrix is block-diagonal with:
125/// - upper-left `Omega_c = Z^T Omega Z` for the constrained radial block
126/// - zero lower-right block for unpenalized polynomial terms.
127///
128/// For double-penalty GAMs, a second ridge penalty `I` is also returned so the
129/// caller can optimize `(lambda_bending, lambda_ridge)` jointly.
130#[derive(Debug, Clone)]
131pub struct ThinPlateSplineBasis {
    /// Design matrix with the `[K_c | P]` column layout described above.
132 pub basis: Array2<f64>,
    /// Block-diagonal bending penalty (radial block penalized, polynomial
    /// block zero), as described above.
133 pub penalty_bending: Array2<f64>,
    /// Second (ridge) penalty used by the double-penalty formulation.
134 pub penalty_ridge: Array2<f64>,
    /// NOTE(review): presumably the number of radial (`K_c`) columns in
    /// `basis` — confirm against the builder.
135 pub num_kernel_basis: usize,
    /// NOTE(review): presumably the number of polynomial (`P`) columns in
    /// `basis` — confirm against the builder.
136 pub num_polynomial_basis: usize,
    /// NOTE(review): presumably the covariate dimension `d` of the smooth —
    /// confirm against the builder.
137 pub dimension: usize,
138 /// Wood-TPRS radial reparameterization matrix `V`.
139 ///
140 /// Rows live in the side-constrained radial coefficient space. Columns are
141 /// the retained positive bending eigendirections of `Z' Ω Z`; numerically
142 /// near-null radial directions are dropped before the basis is exposed.
143 /// Therefore `V` can be rectangular: design columns are `Φ Z V`, and the
144 /// radial penalty is `diag(Λ_retained)`.
145 pub radial_reparam: Array2<f64>,
146}
147
148/// Matérn smoothness parameter `nu` (half-integer variants with closed forms).
///
/// The numeric value of each variant is available through
/// `MaternNu::half_integer_value`.
149#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
150pub enum MaternNu {
    /// ν = 1/2.
151 Half,
    /// ν = 3/2.
152 ThreeHalves,
    /// ν = 5/2.
153 FiveHalves,
    /// ν = 7/2.
154 SevenHalves,
    /// ν = 9/2.
155 NineHalves,
156}
157
158impl MaternNu {
159 /// The half-integer smoothness value ν as an `f64` (0.5, 1.5, …).
160 pub const fn half_integer_value(self) -> f64 {
161 match self {
162 MaternNu::Half => 0.5,
163 MaternNu::ThreeHalves => 1.5,
164 MaternNu::FiveHalves => 2.5,
165 MaternNu::SevenHalves => 3.5,
166 MaternNu::NineHalves => 4.5,
167 }
168 }
169}
170
171/// Matérn radial basis and penalties.
///
/// NOTE(review): the field meanings below are inferred from the field names
/// and from the parallel `ThinPlateSplineBasis` layout — confirm against the
/// Matérn builder.
172#[derive(Debug, Clone)]
173pub struct MaternSplineBasis {
    /// Design matrix of the smooth.
174 pub basis: Array2<f64>,
    /// Penalty associated with the kernel block (presumably).
175 pub penalty_kernel: Array2<f64>,
    /// Ridge penalty (presumably the double-penalty companion).
176 pub penalty_ridge: Array2<f64>,
    /// Count of kernel columns in `basis` (presumably).
177 pub num_kernel_basis: usize,
    /// Count of polynomial columns in `basis` (presumably).
178 pub num_polynomial_basis: usize,
    /// Covariate dimension of the smooth (presumably).
179 pub dimension: usize,
180}
181
/// Crate-internal holder for a single dense design matrix.
///
/// NOTE(review): judging by the name this carries the Duchon basis design —
/// confirm at the construction site.
182#[derive(Debug, Clone)]
183pub(crate) struct DuchonBasisDesign {
    /// The dense design matrix.
184 pub(crate) basis: Array2<f64>,
185}
186
187/// Boundary-condition policy for one-dimensional smooth bases.
///
/// `Open` is the default, both for `Default::default()` and for fields
/// deserialized with `#[serde(default)]`.
188#[derive(Debug, Clone, Serialize, Deserialize, Default)]
189pub enum OneDimensionalBoundary {
190 /// Ordinary open interval basis with clamped endpoint behavior.
191 #[default]
192 Open,
193 /// Periodic/cyclic basis over the half-open interval `[start, end)`.
194 ///
195 /// Values are evaluated modulo `period = end - start`; the basis and its
196 /// first `degree - 1` derivatives agree at the two endpoints for B-splines.
    ///
    /// `period()` reports `None` unless `end > start`, so a reversed, empty or
    /// NaN interval is treated as having no period.
197 Cyclic { start: f64, end: f64 },
198}
199
200impl OneDimensionalBoundary {
201 pub(crate) fn period(&self) -> Option<(f64, f64, f64)> {
202 match *self {
203 OneDimensionalBoundary::Open => None,
204 OneDimensionalBoundary::Cyclic { start, end } if end > start => {
205 Some((start, end, end - start))
206 }
207 OneDimensionalBoundary::Cyclic { .. } => None,
208 }
209 }
210}
211
212/// Which knot strategy to use for 1D B-spline bases.
213#[derive(Debug, Clone, Serialize, Deserialize)]
214pub enum BSplineKnotSpec {
    /// Generate knots from an explicit data range and internal-knot count.
    ///
    /// NOTE(review): presumably uniformly spaced, mirroring
    /// `KnotSource::Generate` — confirm in the builder.
215 Generate {
216 data_range: (f64, f64),
217 num_internal_knots: usize,
218 },
219 /// Uniform cyclic B-spline basis on `[data_range.0, data_range.1)`.
220 ///
221 /// The first and last endpoints are identified, so evaluating at `x` and
222 /// `x + m * period` gives identical rows. `num_basis` is the number of
223 /// periodic control sites around the loop and must be at least
224 /// `degree + 1` for an unaliased local support stencil.
225 PeriodicUniform {
226 data_range: (f64, f64),
227 num_basis: usize,
228 },
    /// Infer the knots automatically; `placement` selects the internal-knot
    /// placement strategy and `num_internal_knots` optionally fixes the count.
229 Automatic {
230 num_internal_knots: Option<usize>,
231 placement: BSplineKnotPlacement,
232 },
    /// Use an explicit, owned knot vector.
233 Provided(Array1<f64>),
234 /// Natural cubic regression spline (`bs="cr"`/`"cs"`) knot set (#1074).
235 ///
236 /// Unlike the open-spline variants above, these `knots` are the `k`
237 /// Lancaster–Salkauskas knots `x*_1 < … < x*_k` that *directly* index the
238 /// basis values `β_i = f(x*_i)` — the basis dimension equals `knots.len()`
239 /// (not `knots.len() - degree - 1`). The 1-D builder routes this variant to
240 /// the cubic-regression builder; the cr identity therefore round-trips
241 /// through freeze/reload by virtue of the variant itself (no separate
242 /// metadata marker is required), and tensor margins inherit cr by carrying
243 /// this knotspec into `build_bspline_basis_1d`.
244 NaturalCubicRegression {
245 knots: Array1<f64>,
246 },
247}
248
249/// Internal-knot placement strategy when knots are automatically inferred.
///
/// Consumed by `BSplineKnotSpec::Automatic`.
250#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
251pub enum BSplineKnotPlacement {
    /// NOTE(review): presumably evenly spaced internal knots — confirm in the
    /// knot generator.
252 Uniform,
    /// NOTE(review): presumably internal knots at data quantiles — confirm in
    /// the knot generator.
253 Quantile,
254}
255
256/// 1D B-spline basis configuration.
257#[derive(Debug, Clone, Serialize, Deserialize)]
258pub struct BSplineBasisSpec {
    /// Spline degree.
259 pub degree: usize,
    /// Order of the smoothing penalty (NOTE(review): presumably a difference
    /// or derivative order — confirm in the penalty builder).
260 pub penalty_order: usize,
    /// Knot strategy for this basis.
261 pub knotspec: BSplineKnotSpec,
    /// Whether the double-penalty formulation is requested.
262 pub double_penalty: bool,
    /// Identifiability policy applied by the builder.
263 pub identifiability: BSplineIdentifiability,
    /// Open vs. cyclic boundary policy; falls back to the enum default
    /// (`Open`) when absent from serialized input.
264 #[serde(default)]
265 pub boundary: OneDimensionalBoundary,
266 /// Optional endpoint boundary constraints (Hermite-style pin of value and/or
267 /// derivative at the left/right knot extents). Default = `Free` on both
268 /// sides which is a no-op.
269 #[serde(default)]
270 pub boundary_conditions: BSplineBoundaryConditions,
271}
272
273/// Per-endpoint boundary constraint policy for B-spline 1D bases.
///
/// The default is `Free`. The type is `PartialEq` but not `Eq` because
/// `Anchored` carries an `f64`.
274#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
275pub enum BSplineEndpointBoundaryCondition {
276 /// No endpoint constraint.
277 #[default]
278 Free,
279 /// Pin the first derivative to zero at this endpoint.
280 Clamped,
281 /// Pin the value at this endpoint to `value` (currently only `value == 0`
282 /// is accepted in the builder; non-zero anchors require an affine offset).
283 Anchored { value: f64 },
284}
285
286/// Left/right pair of B-spline endpoint constraints.
///
/// Both sides default to `Free`, including when a side is missing from
/// serialized input.
287#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
288pub struct BSplineBoundaryConditions {
    /// Constraint applied at the left endpoint.
289 #[serde(default)]
290 pub left: BSplineEndpointBoundaryCondition,
    /// Constraint applied at the right endpoint.
291 #[serde(default)]
292 pub right: BSplineEndpointBoundaryCondition,
293}
294
295impl BSplineBoundaryConditions {
296 pub const fn is_free(&self) -> bool {
297 matches!(self.left, BSplineEndpointBoundaryCondition::Free)
298 && matches!(self.right, BSplineEndpointBoundaryCondition::Free)
299 }
300}
301
302/// Per-smooth identifiability policy for 1D B-spline bases.
303///
304/// These constraints are applied directly in the builder via a reparameterization
305/// `B_constrained = B * Z`, and every penalty matrix is projected as
306/// `S_constrained = Z' S Z`, so solver geometry stays consistent.
///
/// `Default::default()` is `WeightedSumToZero { weights: None }`, i.e. an
/// unweighted sum-to-zero constraint.
307#[derive(Debug, Clone, Serialize, Deserialize)]
308pub enum BSplineIdentifiability {
309 /// Keep unconstrained basis columns.
310 None,
311 /// Enforce weighted sum-to-zero: `B' w = 0` (or unweighted when `weights=None`).
312 // Smooth terms are centered by default to avoid intercept confounding.
313 WeightedSumToZero { weights: Option<Array1<f64>> },
314 /// Remove intercept + linear trend in coefficient space using Greville geometry.
315 RemoveLinearTrend,
316 /// Enforce orthogonality to supplied design columns `C` (n x q):
317 /// `B_c' W C = 0` (or unweighted when `weights=None`).
318 ///
319 /// To enforce `[intercept, x, ...]`, provide `columns` with those columns.
320 OrthogonalToDesignColumns {
321 columns: Array2<f64>,
322 weights: Option<Array1<f64>>,
323 },
324 /// Apply an explicit coefficient-space transform `Z` learned at fit time.
325 ///
326 /// This freezes identifiability behavior so prediction cannot drift based on
327 /// new-data distribution. The constrained basis is `B * Z`.
328 FrozenTransform { transform: Array2<f64> },
329}
330
331impl Default for BSplineIdentifiability {
332 fn default() -> Self {
333 BSplineIdentifiability::WeightedSumToZero { weights: None }
334 }
335}
336
337/// Spatial center selection strategy.
338///
339/// `num_centers` is the exact number of knot/center rows selected by the
340/// strategy. Polynomial nullspace columns are added separately by each basis
341/// builder and must never be folded into this count.
342#[derive(Debug, Clone, Serialize, Deserialize)]
343pub enum CenterStrategy {
    /// Marker wrapper for a strategy that was chosen automatically (see
    /// `auto_spatial_center_strategy`); the boxed value is the strategy that
    /// is actually realized.
344 Auto(Box<CenterStrategy>),
    /// Explicit centers; `center_strategy_num_centers` counts its rows, so
    /// each row is one center.
345 UserProvided(Array2<f64>),
346 /// Joint multidimensional equal-mass partitioning in the full smooth space.
347 EqualMass {
348 num_centers: usize,
349 },
350 /// Covariate-representative equal-mass partitioning along one selected axis.
351 EqualMassCovarRepresentative {
352 num_centers: usize,
353 },
    /// Farthest-point (maximin, space-filling) selection of `num_centers`
    /// centers.
354 FarthestPoint {
355 num_centers: usize,
356 },
    /// K-means selection of `num_centers` centers (NOTE(review): `max_iter`
    /// is presumably the iteration cap — confirm in the selector).
357 KMeans {
358 num_centers: usize,
359 max_iter: usize,
360 },
    /// Regular grid with `points_per_dim` points per axis; no total count is
    /// stored, so `center_strategy_num_centers` returns `None` for it.
361 UniformGrid {
362 points_per_dim: usize,
363 },
364}
365
/// Payload-free tag naming a `CenterStrategy` variant.
///
/// There is no `Auto` kind: `center_strategy_kind` reports the kind of the
/// strategy wrapped inside `CenterStrategy::Auto`.
366#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
367pub enum CenterStrategyKind {
368 UserProvided,
369 EqualMass,
370 EqualMassCovarRepresentative,
371 FarthestPoint,
372 KMeans,
373 UniformGrid,
374}
375
/// Adaptive default center count for spatial smooths (TPS, Duchon, Matérn).
///
/// Use this when the user has not explicitly specified a knot/center count.
/// The basis size is the sub-linear `ceil(8 * d_factor * n^0.4)` with
/// `d_factor = 1 + 0.15 * (max(d, 1) - 1)`, clamped above at `K_MAX = 2000`
/// and below at a *data-proportional* floor `min(200, n/8)` so the floor only
/// engages once there are enough observations to support a rich basis. The
/// result is additionally capped at `n/4` (integer division) so the penalty
/// matrices stay well-conditioned relative to the data:
///
/// | n         | d=1  | d=2  | d=5  |
/// |-----------|------|------|------|
/// | 800       | 116  | 134  | 186  |
/// | 1 000     | 127  | 146  | 203  |
/// | 2 000     | 200  | 200  | 268  |
/// | 10 000    | 319  | 367  | 510  |
/// | 100 000   | 801  | 921  | 1281 |
/// | 400 000   | 1393 | 1602 | 2000 |
/// | 1 000 000 | 2000 | 2000 | 2000 |
///
/// The flat `200` floor used to inflate moderate-`n` spatial smooths (a few
/// hundred to ~2000 rows) up to a dense 200-column design even though the raw
/// sub-linear count — and the mesh/knot density that mgcv and R-INLA use on the
/// same data — is far smaller. On ~800 rows that turned a single 2-D thin-plate
/// REML fit into an `O(n·p² + p³)` grind at `p ≈ 200` (#718). Smoothness is
/// already controlled by REML's penalty weight λ, not by the center count, so a
/// data-proportional floor recovers the same surface at a fraction of the cost.
///
/// # Arguments
/// * `n` - sample size (number of observations)
/// * `d` - covariate dimensionality (number of input variables in the smooth);
///   `d = 0` is treated like `d = 1`
///
/// # Returns
/// The center count. Because of the `n/4` cap the result is `0` whenever
/// `n < 4`; callers that need at least one center must handle that case.
pub fn default_num_centers(n: usize, d: usize) -> usize {
    const K_MIN: usize = 200;
    const K_MAX: usize = 2000;
    const ALPHA: f64 = 0.4;
    const C: f64 = 8.0;
    /// Per-extra-dimension growth in the center count: each covariate axis
    /// beyond the first widens the basis by 15% to keep the per-axis mesh
    /// density roughly constant as the smooth's domain dimensionality grows.
    const PER_DIM_GROWTH: f64 = 0.15;
    /// Divisor for the data-proportional floor: the `K_MIN` floor only engages
    /// once `n` exceeds `K_MIN * FLOOR_N_DIVISOR`, so small samples are not
    /// forced up to a dense `K_MIN`-column design.
    const FLOOR_N_DIVISOR: usize = 8;
    /// Divisor for the conditioning cap: the center count never exceeds `n /
    /// COND_N_DIVISOR`, keeping the penalty matrices well-conditioned relative
    /// to the data.
    const COND_N_DIVISOR: usize = 4;

    let d_factor = 1.0 + PER_DIM_GROWTH * (d.max(1) - 1) as f64;
    let raw = (C * d_factor * (n as f64).powf(ALPHA)).ceil() as usize;

    // Data-proportional floor: never inflate beyond n/FLOOR_N_DIVISOR, so the
    // K_MIN-center floor only takes effect once n is large enough (~1600) to
    // genuinely support that many basis columns. `floor <= K_MIN < K_MAX`, so
    // the clamp bounds are always ordered.
    let floor = K_MIN.min(n / FLOOR_N_DIVISOR);
    let k = raw.clamp(floor, K_MAX);

    // Conditioning cap. `n / COND_N_DIVISOR <= n`, so this single bound also
    // guarantees the count never exceeds the sample size.
    k.min(n / COND_N_DIVISOR)
}
437
438/// Conservative center count for a *secondary* (distributional) predictor's
439/// spatial smooth — e.g. the log-σ scale model in a Gaussian location-scale
440/// fit.
441///
442/// The mean is identified directly by the response, so it warrants the
443/// generous [`default_num_centers`] basis. A scale/shape predictor is
444/// identified only through (noisy) squared residuals: handing it a basis sized
445/// for the mean lets REML/LAML smoothing selection over-fit it, because where
446/// the fitted scale is driven small the *observed* information collapses and
447/// the determinant penalty stops holding the wiggle down (#501). This mirrors
448/// standard GAMLSS/mgcv practice of giving distribution parameters a modest
449/// default (mgcv's `k = 10` for a 1-D `s()`), grown gently with dimensionality
450/// and never exceeding the generous primary-predictor default.
451pub fn conservative_secondary_centers(n: usize, d: usize) -> usize {
452 const BASE_1D_CENTERS: usize = 10;
453 let modest = BASE_1D_CENTERS.saturating_mul(d.max(1));
454 default_num_centers(n, d).min(modest).max(1)
455}
456
457/// Resource-aware plan for a spatial smooth (Duchon / Matérn / TPS).
458///
459/// Returned by [`plan_spatial_basis`]. Captures the resolved center count,
460/// final basis dimension `p`, the dense byte cost for the value matrix and
461/// each derivative tier, and a recommended storage mode that is consistent
462/// with the supplied [`gam_runtime::resource::ResourcePolicy`].
463#[derive(Clone, Debug)]
464pub struct SpatialBasisPlan {
    /// Sample size the plan was computed for.
465 pub n: usize,
    /// Covariate dimension the plan was computed for.
466 pub d: usize,
    /// Resolved center count (default heuristic, explicit, or capped).
467 pub centers: usize,
    /// Estimated basis width: `centers` plus the polynomial null-space
    /// dimension (saturating add).
468 pub p_final_estimate: usize,
    /// `size_of::<f64>() * n * p` (saturating): dense value-design bytes.
469 pub dense_design_bytes: usize,
    /// `dense_design_bytes` times the number of derivative axes, which is `d`
    /// when `scale_dims` was set and `0` otherwise.
470 pub first_derivative_dense_bytes: usize,
    /// Diagonal second-derivative tier; the planner sets it equal to
    /// `first_derivative_dense_bytes`.
471 pub second_derivative_dense_bytes: usize,
    /// Storage mode recommended for the supplied policy.
472 pub recommended_storage: SpatialStorageMode,
473}
474
475/// Storage mode recommended by [`plan_spatial_basis`].
476///
477/// * `DenseValueDenseDerivatives` — both the value design and its derivative
478/// matrices fit under the policy's single-materialization budget.
479/// * `LazyValueImplicitDerivatives` — the value design fits dense but the
480/// derivative matrices do not; switch derivatives to the implicit operator.
481/// * `OperatorOnly` — neither the design nor its derivatives fit; everything
482/// must be operator-backed.
///
/// NOTE(review): when there are no derivative axes and the value design
/// exceeds the budget, `plan_spatial_basis` returns
/// `LazyValueImplicitDerivatives`, which does not match the "value design
/// fits dense" wording above — confirm which of the two is intended.
483#[derive(Clone, Copy, Debug, PartialEq, Eq)]
484pub enum SpatialStorageMode {
485 DenseValueDenseDerivatives,
486 LazyValueImplicitDerivatives,
487 OperatorOnly,
488}
489
490/// How [`plan_spatial_basis`] should pick the spatial center count.
491#[derive(Clone, Copy, Debug)]
492pub enum CenterCountRequest {
493 /// Use the heuristic [`default_num_centers`].
494 Default,
495 /// Use the caller-supplied count exactly.
    ///
    /// `plan_spatial_basis` neither validates nor caps this value.
496 Explicit(usize),
497 /// Use [`default_num_centers`] but cap at `cap` to bound dense cost.
498 HeuristicCapped { cap: usize },
499}
500
501/// Build a resource-aware plan for a spatial smooth basis.
502///
503/// Computes the resolved center count, final basis dimension, dense byte
504/// estimates for the value design and first/second derivative tiers, and a
505/// recommended [`SpatialStorageMode`] derived from `policy`. This is the
506/// resource-aware replacement for ad-hoc calls to [`default_num_centers`] /
507/// [`heuristic_centers`](crate::term_builder::heuristic_centers).
508pub fn plan_spatial_basis(
509 n: usize,
510 d: usize,
511 requested_centers: CenterCountRequest,
512 nullspace_order: DuchonNullspaceOrder,
513 scale_dims: bool,
514 policy: &gam_runtime::resource::ResourcePolicy,
515) -> Result<SpatialBasisPlan, BasisError> {
516 if n == 0 {
517 crate::bail_invalid_basis!("plan_spatial_basis: n must be >= 1");
518 }
519 if d == 0 {
520 crate::bail_invalid_basis!("plan_spatial_basis: d must be >= 1");
521 }
522
523 // 1. Resolve center count.
524 let centers = match requested_centers {
525 CenterCountRequest::Default => default_num_centers(n, d),
526 CenterCountRequest::Explicit(k) => k,
527 CenterCountRequest::HeuristicCapped { cap } => default_num_centers(n, d).min(cap),
528 };
529
530 // 2. Nullspace dimension (Duchon polynomial null space of degree p-1).
531 // `duchon_p_from_nullspace_order` returns m such that the null space is
532 // polynomials of total degree < m, matching `duchon_nullspace_dimension`'s
533 // `max_total_degree = m - 1` argument.
534 let m = duchon_p_from_nullspace_order(nullspace_order);
535 let nullspace_dim = if m == 0 {
536 0
537 } else {
538 duchon_nullspace_dimension(d, m - 1)
539 };
540
541 let p = centers.saturating_add(nullspace_dim);
542
543 // 3. Dense byte estimates.
544 let derivative_axes = if scale_dims { d } else { 0 };
545 let bytes_per_f64 = std::mem::size_of::<f64>();
546 let dense_design_bytes = bytes_per_f64.saturating_mul(n).saturating_mul(p);
547 let first_derivative_dense_bytes = dense_design_bytes.saturating_mul(derivative_axes);
548 // Diagonal second derivatives are also (D × n × p); off-diagonal cross terms
549 // would scale as D^2 but the planner reports the diagonal tier here.
550 let second_derivative_dense_bytes = first_derivative_dense_bytes;
551
552 // 4. Pick storage mode based on policy.
553 let recommended_storage = match policy.derivative_storage_mode {
554 gam_runtime::resource::DerivativeStorageMode::AnalyticOperatorRequired => {
555 SpatialStorageMode::OperatorOnly
556 }
557 gam_runtime::resource::DerivativeStorageMode::MaterializeIfSmall => {
558 let budget = policy.max_single_materialization_bytes;
559 if derivative_axes == 0 {
560 if dense_design_bytes <= budget {
561 SpatialStorageMode::DenseValueDenseDerivatives
562 } else {
563 SpatialStorageMode::LazyValueImplicitDerivatives
564 }
565 } else {
566 let total = dense_design_bytes
567 .saturating_add(first_derivative_dense_bytes)
568 .saturating_add(second_derivative_dense_bytes);
569 if total <= budget {
570 SpatialStorageMode::DenseValueDenseDerivatives
571 } else if dense_design_bytes <= budget {
572 SpatialStorageMode::LazyValueImplicitDerivatives
573 } else {
574 SpatialStorageMode::OperatorOnly
575 }
576 }
577 }
578 gam_runtime::resource::DerivativeStorageMode::DiagnosticsOnly => {
579 // Diagnostic mode still prefers analytic storage for correctness.
580 SpatialStorageMode::OperatorOnly
581 }
582 };
583
584 Ok(SpatialBasisPlan {
585 n,
586 d,
587 centers,
588 p_final_estimate: p,
589 dense_design_bytes,
590 first_derivative_dense_bytes,
591 second_derivative_dense_bytes,
592 recommended_storage,
593 })
594}
595
596pub const fn default_spatial_center_strategy(num_centers: usize, d: usize) -> CenterStrategy {
597 if d <= 3 {
598 CenterStrategy::FarthestPoint { num_centers }
599 } else {
600 CenterStrategy::EqualMassCovarRepresentative { num_centers }
601 }
602}
603
604pub fn auto_spatial_center_strategy(num_centers: usize, d: usize) -> CenterStrategy {
605 let strategy = if d == 1 {
606 // In one dimension, farthest-point selection is the deterministic
607 // maximin grid over the observed domain. Equal-mass midpoints leave the
608 // low-frequency Duchon radial block slightly under-resolved at the
609 // boundaries, and REML then compensates with an over-smooth λ on
610 // low-noise signals (#504). The maximin grid matches the native
611 // reproducing-kernel interpolation geometry. The default strategy below
612 // extends the same space-filling contract to low-dimensional spatial
613 // GP bases, where kriging accuracy is governed by fill distance rather
614 // than marginal quantile balance.
615 CenterStrategy::FarthestPoint { num_centers }
616 } else {
617 default_spatial_center_strategy(num_centers, d)
618 };
619 CenterStrategy::Auto(Box::new(strategy))
620}
621
622pub const fn center_strategy_is_auto(strategy: &CenterStrategy) -> bool {
623 matches!(strategy, CenterStrategy::Auto(_))
624}
625
626pub(crate) fn realized_center_strategy(strategy: &CenterStrategy) -> &CenterStrategy {
627 match strategy {
628 CenterStrategy::Auto(inner) => inner.as_ref(),
629 other => other,
630 }
631}
632
633pub fn center_strategy_kind(strategy: &CenterStrategy) -> CenterStrategyKind {
634 match strategy {
635 CenterStrategy::Auto(inner) => center_strategy_kind(inner.as_ref()),
636 CenterStrategy::UserProvided(_) => CenterStrategyKind::UserProvided,
637 CenterStrategy::EqualMass { .. } => CenterStrategyKind::EqualMass,
638 CenterStrategy::EqualMassCovarRepresentative { .. } => {
639 CenterStrategyKind::EqualMassCovarRepresentative
640 }
641 CenterStrategy::FarthestPoint { .. } => CenterStrategyKind::FarthestPoint,
642 CenterStrategy::KMeans { .. } => CenterStrategyKind::KMeans,
643 CenterStrategy::UniformGrid { .. } => CenterStrategyKind::UniformGrid,
644 }
645}
646
647pub fn center_strategy_num_centers(strategy: &CenterStrategy) -> Option<usize> {
648 match strategy {
649 CenterStrategy::Auto(inner) => center_strategy_num_centers(inner.as_ref()),
650 CenterStrategy::UserProvided(centers) => Some(centers.nrows()),
651 CenterStrategy::EqualMass { num_centers }
652 | CenterStrategy::EqualMassCovarRepresentative { num_centers }
653 | CenterStrategy::FarthestPoint { num_centers }
654 | CenterStrategy::KMeans { num_centers, .. } => Some(*num_centers),
655 CenterStrategy::UniformGrid { .. } => None,
656 }
657}
658
659pub fn center_strategy_with_num_centers(
660 strategy: &CenterStrategy,
661 num_centers: usize,
662) -> Result<CenterStrategy, BasisError> {
663 validate_center_count(num_centers)?;
664 fn rebuild_inner(
665 strategy: &CenterStrategy,
666 num_centers: usize,
667 ) -> Result<CenterStrategy, BasisError> {
668 match strategy {
669 CenterStrategy::Auto(inner) => rebuild_inner(inner.as_ref(), num_centers),
670 CenterStrategy::EqualMass { .. } => Ok(CenterStrategy::EqualMass { num_centers }),
671 CenterStrategy::EqualMassCovarRepresentative { .. } => {
672 Ok(CenterStrategy::EqualMassCovarRepresentative { num_centers })
673 }
674 CenterStrategy::FarthestPoint { .. } => {
675 Ok(CenterStrategy::FarthestPoint { num_centers })
676 }
677 CenterStrategy::KMeans { max_iter, .. } => Ok(CenterStrategy::KMeans {
678 num_centers,
679 max_iter: *max_iter,
680 }),
681 CenterStrategy::UserProvided(_) | CenterStrategy::UniformGrid { .. } => {
682 Err(BasisError::InvalidInput(format!(
683 "cannot replace center count for {:?} strategy",
684 center_strategy_kind(strategy)
685 )))
686 }
687 }
688 }
689 let rebuilt = rebuild_inner(strategy, num_centers)?;
690 Ok(match strategy {
691 CenterStrategy::Auto(_) => CenterStrategy::Auto(Box::new(rebuilt)),
692 _ => rebuilt,
693 })
694}
695
696/// Thin-plate basis configuration.
697#[derive(Debug, Clone, Serialize, Deserialize)]
698pub struct ThinPlateBasisSpec {
    /// How the spatial centers (knots) are chosen.
699 pub center_strategy: CenterStrategy,
    /// NOTE(review): presumably a per-axis optional period (`None` entry =
    /// non-periodic axis) — confirm in the builder. Defaults to `None` when
    /// absent from serialized input.
700 #[serde(default)]
701 pub periodic: Option<Vec<Option<f64>>>,
    /// Kernel length scale (NOTE(review): units/convention not visible here).
702 pub length_scale: f64,
    /// Whether the double-penalty formulation is requested.
703 pub double_penalty: bool,
    /// Identifiability policy; falls back to the enum default when absent
    /// from serialized input.
704 #[serde(default)]
705 pub identifiability: SpatialIdentifiability,
706 /// Frozen Wood-TPRS radial reparameterization. When `Some`, the builder
707 /// reuses this `(raw_radial_cols) × (kept_radial_cols)` matrix instead of
708 /// recomputing it from the constrained kernel penalty eigensystem. The
709 /// rectangular case is the truncated regression-spline path; carrying it
710 /// into prediction guarantees identical radial modes to fit-time.
711 #[serde(default)]
712 pub radial_reparam: Option<Array2<f64>>,
713}
714
715/// Per-smooth identifiability policy for spatial (TPS / Duchon) bases.
716///
717/// For a raw local basis `B` and parametric design block `C`, the orthogonalized
718/// basis is `B_c = B Z` where columns of `Z` span `null((B^T C)^T)`. This enforces:
719/// `B_c^T C = 0`
720/// in the unweighted inner product, so spatial effects cannot absorb parametric
721/// directions that actually exist in the model. The standalone basis builder has
722/// only an implicit intercept available, so it centers smooths against that
723/// intercept. The term-collection builder augments `C` with explicit linear
724/// terms when those terms are present in the formula.
///
/// The default variant is `OrthogonalToParametric`.
725#[derive(Debug, Default, Clone, Serialize, Deserialize)]
726pub enum SpatialIdentifiability {
727 /// Keep unconstrained basis columns.
728 None,
729 /// Orthogonalize the smooth against model-owned parametric columns.
730 // "Magic" default for modular GAMs with explicit parametric block:
731 // keep spatial smooth orthogonal to intercept/linear terms.
732 // ApproxKind: Exact (orthogonalization is an exact projection).
733 #[default]
734 OrthogonalToParametric,
735 /// Freeze a fit-time transform `Z`; prediction uses `B_new * Z` unchanged.
736 FrozenTransform { transform: Array2<f64> },
737}
738
739pub(crate) use sphere_kernels::{
740 wahba_sphere_kernel_derivative_dcos_kind, wahba_sphere_kernel_from_cos_kind,
741 wahba_sphere_kernel_from_cos_simd_kind, wahba_sphere_kernel_sobolev_derivative_dcos,
742};
743
744pub use sphere_spectral::{
745 pseudo_s2_truncated_coefficients, sobolev_s2_truncated_coefficients,
746 sphere_truncated_spectral_eval,
747};
748
749/// Matérn basis configuration.
750#[derive(Debug, Clone, Serialize, Deserialize)]
751pub struct MaternBasisSpec {
    /// How the spatial centers are chosen.
752 pub center_strategy: CenterStrategy,
    /// NOTE(review): presumably a per-axis optional period (`None` entry =
    /// non-periodic axis) — confirm in the builder. Defaults to `None` when
    /// absent from serialized input.
753 #[serde(default)]
754 pub periodic: Option<Vec<Option<f64>>>,
    /// Kernel length scale (NOTE(review): its relation to κ in the comments
    /// below is not visible here).
755 pub length_scale: f64,
    /// Half-integer Matérn smoothness ν.
756 pub nu: MaternNu,
    /// NOTE(review): presumably adds an intercept column to the basis —
    /// confirm in the builder. Defaults to `false` when absent from serialized
    /// input.
757 #[serde(default)]
758 pub include_intercept: bool,
    /// Whether the double-penalty formulation is requested.
759 pub double_penalty: bool,
    /// Identifiability policy; falls back to the enum default when absent
    /// from serialized input.
760 #[serde(default)]
761 pub identifiability: MaternIdentifiability,
762 /// Per-axis anisotropy log-scales η_a (contrasts with Ση_a = 0).
763 ///
764 /// This implements geometric anisotropy: Λ = κA where A = diag(exp(η_a)),
765 /// det(A) = 1. The kernel is evaluated at r = κ|Ah| instead of r = κ|h|.
766 /// The decomposition preserves the isotropic scaling law for global κ
767 /// and adds d−1 shape parameters for directional relevance.
768 ///
769 /// Conditional positive definiteness is preserved under any invertible
770 /// linear coordinate transform (Schoenberg), so the kernel remains valid.
771 ///
772 /// When Some, the distance is r = √(Σ_a exp(2η_a) · (x_a - c_a)²).
773 /// When None, isotropic distance r = ‖x - c‖ is used.
774 #[serde(default)]
775 pub aniso_log_scales: Option<Vec<f64>>,
776 /// Frozen double-penalty nullspace-shrinkage decision (gam#787/#860).
777 ///
778 /// `None` (the default, and the cold-build value) = decide whether to emit
779 /// the `DoublePenaltyNullspace` candidate via the κ-dependent spectral test in
780 /// `build_nullspace_shrinkage_penalty`. `Some(b)` = force the decision (set by
781 /// the freeze step from the bootstrap-κ build, mirrored from
782 /// `MaternIdentifiability::FrozenTransform`) so the learned-penalty count stays
783 /// invariant as the κ-optimizer rebuilds the design at each trial length-scale.
784 /// Only consulted when `double_penalty` is true.
785 #[serde(default)]
786 pub nullspace_shrinkage_survived: Option<bool>,
787}
788
789/// Per-smooth identifiability policy for Matérn kernel coefficients.
790///
791/// These constraints are geometric (center-based), so they are stable across
792/// train/predict and do not depend on response weights.
///
/// The default variant is `CenterSumToZero`.
793#[derive(Debug, Default, Clone, Serialize, Deserialize)]
794pub enum MaternIdentifiability {
795 /// Keep the unconstrained kernel coefficient space.
796 None,
797 /// Enforce `1^T alpha = 0` at center locations (removes constant drift).
798 // Safe default with model intercepts: prevent kernel block from absorbing
799 // a global mean level.
800 #[default]
801 CenterSumToZero,
802 /// Enforce orthogonality to `[1, c_1, ..., c_d]` at centers.
803 /// Use this when explicit linear terms should own global trends.
804 CenterLinearOrthogonal,
805 /// Freeze a fit-time transform `Z` so prediction cannot drift.
806 ///
807 /// `nullspace_shrinkage_survived` freezes the double-penalty
808 /// nullspace-shrinkage decision alongside the transform (gam#787/#860). The
809 /// matern double-penalty path emits a `DoublePenaltyNullspace` candidate iff
810 /// `build_nullspace_shrinkage_penalty(&projected_kernel)` finds a near-zero
811 /// eigenvalue — but that spectral test is κ-DEPENDENT (its tolerance scales
812 /// with `λ_max`), so a near-zero eigenvalue can cross the threshold as the
813 /// κ-optimizer rebuilds the design at each trial length-scale. That flips the
814 /// learned-penalty count 6↔7 across the rebuild and the rebuilt design's ρ
815 /// dimension then disagrees with the frozen joint setup ("joint hyper rho
816 /// dimension mismatch" → every κ seed fails startup validation). Freezing the
817 /// bootstrap-κ decision here (`Some(true)` = always emit the shrinkage
818 /// candidate, `Some(false)` = never) keeps the penalty count INVARIANT across
819 /// the κ rebuild so κ actually optimizes. `None` = decide via the spectral
820 /// test (the non-frozen / cold-build behavior; also the serde back-compat
821 /// default for transforms frozen before this field existed).
822 FrozenTransform {
    /// The frozen coefficient-space transform `Z`.
823 transform: Array2<f64>,
    /// Frozen shrinkage decision described above; `None` when absent from
    /// serialized input.
824 #[serde(default)]
825 nullspace_shrinkage_survived: Option<bool>,
826 },
827}
828
/// Duchon null-space polynomial degree.
///
/// Controls the polynomial null space of the Duchon / polyharmonic spline. The
/// Duchon seminorm `‖D^m f‖²` annihilates all polynomials of total degree
/// `< m`, so those polynomials must be handled as explicit unpenalized columns.
///
/// The user-facing `order` knob selects the polynomial degree cutoff `r`, and
/// the resulting polynomial null space has dimension `C(d + r, r)` where `d`
/// is the covariate dimension. In the `duchon(...)` formula DSL:
///
/// | `order=` | Variant     | max total degree | null-space dim     |
/// |----------|-------------|------------------|--------------------|
/// | `0`      | `Zero`      | 0                | `C(d+0,0) = 1`     |
/// | `1`      | `Linear`    | 1                | `C(d+1,1) = d+1`   |
/// | `k≥2`    | `Degree(k)` | k                | `C(d+k,k)`         |
///
/// **How the polynomial null space is consumed during basis construction:**
///
/// 1. `polynomial_block_from_order` materialises an `(n, C(d+r,r))` block `P`
///    of monomials up to total degree `r` at the selected `centers`.
/// 2. `kernel_constraint_nullspace` computes `Z = null(P_centers^T)`, a
///    `(k, k − C(d+r,r))` matrix. Reparameterising the radial kernel
///    coefficients as `α = Z γ` enforces the side condition `P_centers^T α = 0`
///    and yields `k − C(d+r,r)` free kernel parameters.
/// 3. The polynomial block `P_data` evaluated at the data rows is appended to
///    the kernel block `Φ Z`, giving a total of
///    `(k − C(d+r,r)) + C(d+r,r) = k` columns before the spatial
///    identifiability transform. Crucially, the total width equals the
///    requested center count `k`, **not** `k + C(d+r,r)`.
///
/// **Example — `duchon(PC1, PC2, PC3, centers=10, order=1)` (d=3):**
///
/// - Polynomial null space: `C(3+1,1) = 4` monomials `{1, x₁, x₂, x₃}`.
/// - Kernel columns after constraint: `10 − 4 = 6`.
/// - Appended polynomial block: 4 columns.
/// - Pre-identifiability total: `6 + 4 = 10` columns, i.e. exactly `centers`.
///
/// The variant naming matches the Duchon `m` parameter:
/// `Zero` → `m=1`, `Linear` → `m=2`, `Degree(k)` → `m=k+1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DuchonNullspaceOrder {
    /// Constants only (maximum total degree 0).
    Zero,
    /// Constants plus linear monomials (maximum total degree 1).
    Linear,
    /// Monomials up to total degree `k`. The formula DSL uses this for
    /// `order = k ≥ 2`, but the type itself does not forbid `Degree(0)` or
    /// `Degree(1)`.
    // NOTE(review): whether `Degree(0)` / `Degree(1)` are treated as aliases of
    // `Zero` / `Linear` downstream is not visible here — confirm before
    // relying on equality comparisons between them.
    Degree(usize),
}
874
/// Duchon-like basis configuration with explicit low-frequency null-space
/// control and explicit spectral power.
///
/// Deserialization rejects unknown fields (`deny_unknown_fields`); fields
/// marked `#[serde(default)]` may be omitted from the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct DuchonBasisSpec {
    /// How the kernel centers are chosen.
    pub center_strategy: CenterStrategy,
    /// Optional periodicity description.
    // NOTE(review): presumably one entry per covariate axis, with `Some(period)`
    // marking a periodic axis — confirm against the basis builder.
    #[serde(default)]
    pub periodic: Option<Vec<Option<f64>>>,
    /// Optional hybrid Matérn width. `None` means pure scale-free Duchon with
    /// spectrum `||w||^(2p + 2s)`. `Some(length_scale)` enables the hybrid
    /// spectrum `||w||^(2p) * (kappa^2 + ||w||^2)^s`, `kappa = 1/length_scale`.
    pub length_scale: Option<f64>,
    /// Literal Duchon spectral power `s` (`f64`, fractional values fully
    /// threaded end-to-end). The pure-Duchon kernel exponent is `2(p + s) − d`,
    /// so this is the knob that sets `φ(r)`: `s = 0` is the integer-order Duchon
    /// kernel `r^{2p−d}` (its `r²·log r` log case in even `d`, ≡ the thin-plate
    /// kernel); `s = (d − 1)/2` gives the cubic `r³` in every dimension.
    ///
    /// This field is taken LITERALLY by the basis builder — `power = 0` means
    /// `s = 0`, NOT "use a default". The magic cubic default (applied when the
    /// user gives no explicit power) is a request-layer choice resolved by the
    /// formula / CLI / pyffi front-ends via [`duchon_cubic_default`]; by the time
    /// a spec reaches the builder this value is the final intended `s`. The
    /// hybrid Duchon–Matérn path (`length_scale = Some`) still requires an
    /// integer `s` (read via `spec.power_as_usize()`).
    pub power: f64,
    /// Polynomial null-space degree handled as explicit unpenalized columns.
    pub nullspace_order: DuchonNullspaceOrder,
    /// Spatial identifiability policy; omitted values take the type's default.
    #[serde(default)]
    pub identifiability: SpatialIdentifiability,
    /// Per-axis anisotropy log-scales η_a.
    ///
    /// For hybrid Duchon (`length_scale=Some`), these are centered contrasts in
    /// the decomposition Λ = κA with det(A)=1. For pure Duchon
    /// (`length_scale=None`), they parameterize shape-only axis warping on the
    /// public path and are centered before basis evaluation/writeback so no
    /// global length scale is introduced.
    ///
    /// When Some, the distance is r = √(Σ_a exp(2η_a) · (x_a - c_a)²).
    /// When None, isotropic distance r = ‖x - c‖ is used.
    #[serde(default)]
    pub aniso_log_scales: Option<Vec<f64>>,
    /// Mass / tension / stiffness operator-penalty dials. When omitted, the
    /// `Default` impl applies (mass and tension active, stiffness disabled).
    #[serde(default)]
    pub operator_penalties: DuchonOperatorPenaltySpec,
    /// Boundary handling; omitted values take the type's default.
    // NOTE(review): the type name suggests this only applies to
    // one-dimensional smooths — confirm against the builder.
    #[serde(default)]
    pub boundary: OneDimensionalBoundary,
    /// Data-metric radial reparameterization `V` (#1355), mirroring the
    /// thin-plate Wood-TPRS reparam. When `Some`, the constrained kernel
    /// transform is folded to `Z·V` so the realized design columns rotate into
    /// the `G_c`-orthonormal generalized eigenbasis of `Ω_c v = μ G_c v` and the
    /// native penalty becomes the diagonal curvature-per-unit-data-variance
    /// spectrum (mgcv's cliff), preventing the REML over-smoothing collapse to
    /// EDF = 1. Frozen at the cold dense build and replayed verbatim by the
    /// predict / κ-trial / ψ-derivative paths so they stay bit-consistent with
    /// the fit-time design. `None` on the lazy/streaming path (huge `n`), which
    /// retains the original constrained basis.
    #[serde(default)]
    pub radial_reparam: Option<Array2<f64>>,
}
933
934impl DuchonBasisSpec {
935 /// Integer view of `power` for the existing integer-only downstream chain.
936 /// Non-finite or non-integer values fall back to `0` (the integer-only
937 /// validators downstream already reject this case with a clear message).
938 pub fn power_as_usize(&self) -> usize {
939 duchon_power_to_usize(self.power)
940 }
941}
942
/// Map a Duchon spectral power given as `f64` onto the `usize` used by the
/// closed-form (integer-only) code paths.
///
/// Returns `0` whenever `power` is non-finite, negative, or further than
/// `1e-9` from the nearest integer, so that the downstream validator is the
/// one place that reports the canonical error. Otherwise returns the nearest
/// integer; the `f64 → usize` cast saturates, so finite values beyond the
/// `usize` range yield `usize::MAX`.
pub fn duchon_power_to_usize(power: f64) -> usize {
    /// Largest distance from an integer still accepted as "integral".
    const INTEGRALITY_TOL: f64 = 1e-9;

    let admissible = power.is_finite() && power >= 0.0;
    if !admissible {
        return 0;
    }

    let nearest = power.round();
    if (nearest - power).abs() <= INTEGRALITY_TOL {
        nearest as usize
    } else {
        0
    }
}
956
/// The three lower-order operator-penalty dials that can accompany a Duchon
/// (or Matérn collocation) smooth. Each dial is independently `Active` or
/// `Disabled`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DuchonOperatorPenaltySpec {
    /// Mass dial: penalty on the function value (`D0`, derivative order 0).
    pub mass: OperatorPenaltySpec,
    /// Tension dial: penalty on the gradient (`D1`, derivative order 1).
    pub tension: OperatorPenaltySpec,
    /// Stiffness dial: penalty on the Hessian (`D2`, derivative order 2).
    pub stiffness: OperatorPenaltySpec,
}
963
/// State of a single operator-penalty dial.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum OperatorPenaltySpec {
    /// The penalty participates in the fit.
    Active {
        /// Initial value of the penalty's log smoothing parameter.
        // NOTE(review): presumably the optimizer's starting point for log λ —
        // confirm against the consumer of this spec.
        initial_log_lambda: f64,
        /// Optional prior attached to this dial; `None` means no prior.
        prior: Option<RhoPrior>,
    },
    /// The penalty is switched off.
    Disabled,
}
972
973impl Default for DuchonOperatorPenaltySpec {
974 fn default() -> Self {
975 // ALL ON. The Duchon penalty is a Hilbert scale: curvature is the
976 // always-on exact RKHS `Primary` Gram and the trend ridge is always on;
977 // the lower orders — mass (amplitude `Σ(f−f̄)²`) and tension (first-order
978 // roughness `Σ‖∇f‖²`) — are active here, collocated on a density-blind
979 // data-support sample. REML deselects any the data don't support (SPEC:
980 // recover the null by default, opt INTO overfitting). Stiffness (`D2`)
981 // stays off — `Primary` is the exact, superior curvature. (The Matérn
982 // collocation overlay builds its own `all_active()`; SAE atoms, which
983 // ship only `Primary`, use `all_disabled()`.)
984 Self {
985 mass: OperatorPenaltySpec::Active {
986 initial_log_lambda: 0.0,
987 prior: None,
988 },
989 tension: OperatorPenaltySpec::Active {
990 initial_log_lambda: 0.0,
991 prior: None,
992 },
993 stiffness: OperatorPenaltySpec::Disabled,
994 }
995 }
996}
997
998impl DuchonOperatorPenaltySpec {
999 pub fn all_disabled() -> Self {
1000 Self {
1001 mass: OperatorPenaltySpec::Disabled,
1002 tension: OperatorPenaltySpec::Disabled,
1003 stiffness: OperatorPenaltySpec::Disabled,
1004 }
1005 }
1006
1007 /// All three operator dials active — used by the Matérn collocation overlay.
1008 pub fn all_active() -> Self {
1009 let active = || OperatorPenaltySpec::Active {
1010 initial_log_lambda: 0.0,
1011 prior: None,
1012 };
1013 Self {
1014 mass: active(),
1015 tension: active(),
1016 stiffness: active(),
1017 }
1018 }
1019
1020 /// Operator-penalty dials appropriate for a Matérn-ν kernel in dimension `d`.
1021 ///
1022 /// The Matérn-ν RKHS is the Sobolev space `H^m` with `m = ν + d/2`: its
1023 /// squared norm controls the order-`j` derivative in L2 exactly when
1024 /// `j ≤ m`. The collocation overlay penalizes the squared L2 norms of the
1025 /// value (mass, `D0`, j=0), gradient (tension, `D1`, j=1) and Hessian
1026 /// (stiffness, `D2`, j=2). Activating a penalty whose derivative order
1027 /// exceeds the RKHS smoothness (`j > m`) imposes a roughness constraint the
1028 /// true kernel does NOT — it over-smooths the reduced-rank fit relative to
1029 /// the exact GP (mgcv `bs="gp"`, GpGp).
1030 ///
1031 /// Concretely the roughest Matérn, ν=1/2 in d=1 (`m = 1`), is the
1032 /// Ornstein–Uhlenbeck/exponential kernel: an H¹ process whose sample paths
1033 /// are continuous but non-differentiable. Although `∫(f')²` is finite on
1034 /// its RKHS, the kernel itself already encodes the H¹ control; layering an
1035 /// extra tension dial on top biases the reduced-rank fit toward the smooth
1036 /// `C¹` functions the kernel does not favour (and stiffness `D2` toward
1037 /// `C²`), collapsing held-out oscillation (#707). We therefore gate each
1038 /// operator on `j < m` STRICTLY: mass (j=0) is always on, tension (j=1) is
1039 /// on for `m > 1`, stiffness (j=2) is on for `m > 2`. For ν ≥ 3/2 (or any
1040 /// d ≥ 2) every dial is active, recovering `all_active`; only the
1041 /// genuinely rough ν=1/2 (d=1) kernel — where the Sobolev order sits
1042 /// exactly on a derivative boundary — drops the higher operators.
1043 pub fn matern_for_smoothness(nu: MaternNu, d: usize) -> Self {
1044 let m = nu.half_integer_value() + 0.5 * d as f64;
1045 // Tolerance so an exact half-integer Sobolev order (e.g. m = 1.0 for
1046 // ν=1/2, d=1) reliably DISABLES the matching-order operator instead
1047 // of flipping on a float-equality knife-edge.
1048 const ORDER_EPS: f64 = 1e-9;
1049 let active = || OperatorPenaltySpec::Active {
1050 initial_log_lambda: 0.0,
1051 prior: None,
1052 };
1053 let gate = |order: f64| {
1054 if m > order + ORDER_EPS {
1055 active()
1056 } else {
1057 OperatorPenaltySpec::Disabled
1058 }
1059 };
1060 Self {
1061 mass: active(),
1062 tension: gate(1.0),
1063 stiffness: gate(2.0),
1064 }
1065 }
1066}
1067
1068pub fn minimum_duchon_power_for_operator_penalties(
1069 dim: usize,
1070 nullspace_order: DuchonNullspaceOrder,
1071 max_operator_derivative_order: usize,
1072) -> usize {
1073 let p = duchon_p_from_nullspace_order(nullspace_order);
1074 let mut s = 0usize;
1075 while 2 * (p + s) <= dim + max_operator_derivative_order {
1076 s += 1;
1077 }
1078 s
1079}
1080
1081/// Resolve a fully admissible Duchon `(nullspace_order, power)` pair.
1082///
1083/// Three constraints fold into one resolution:
1084/// (a) operator collocation up to `max_op`: `2(p + s) > d + max_op`
1085/// (b) pure-mode CPD vs polynomial nullspace P_p: `2s < d`
1086/// (Wendland Thm 8.17: pure polyharmonic kernel of order m = p+s in
1087/// R^d is CPD of order `m − ⌊d/2⌋ + 1[d even, log] / m − (d−1)/2
1088/// [d odd]`, and Duchon interpolation against P_p is well-posed iff
1089/// CPD order ≤ p, which collapses to `2s < d` since 2s, d are
1090/// integers and 2s is even.)
1091/// (a) implies the kernel-existence condition `2(p + s) > d`.
1092/// (b) is dropped when `length_scale` is `Some` (hybrid Matérn-blended
1093/// kernel is strictly PD, CPD order 0).
1094///
1095/// Strategy: at the requested `nullspace_order`, take the smallest `s`
1096/// satisfying (a). If that `s` violates (b) in pure mode, escalate the
1097/// nullspace order by one and retry. Termination: at `p ≥ ⌈(d+max_op)/2⌉ + 1`
1098/// the operator constraint (a) admits `s = 0`, and `0 < d` satisfies (b)
1099/// for any `d ≥ 1`, so escalation always converges.
1100///
1101/// The returned nullspace order is monotone in the request: it never
1102/// decreases the user's requested order — only strengthens it when pure-mode
1103/// CPD requires a richer polynomial absorption space.
1104pub fn resolve_duchon_orders(
1105 dim: usize,
1106 requested_nullspace_order: DuchonNullspaceOrder,
1107 max_operator_derivative_order: usize,
1108 length_scale: Option<f64>,
1109) -> (DuchonNullspaceOrder, usize) {
1110 assert!(dim >= 1, "Duchon basis requires dim >= 1");
1111 let pure = length_scale.is_none();
1112 let mut nullspace = requested_nullspace_order;
1113 // Bounded loop: escalation terminates by the argument above.
1114 for _ in 0..=(dim + max_operator_derivative_order + 1) {
1115 let p = duchon_p_from_nullspace_order(nullspace);
1116 // Smallest s with 2(p + s) > d + max_op:
1117 // 2p > d + max_op ⇒ s = 0
1118 // else s = ⌈(d + max_op + 1 − 2p) / 2⌉ = (d + max_op + 2 − 2p) / 2
1119 let s_op = if 2 * p > dim + max_operator_derivative_order {
1120 0
1121 } else {
1122 (dim + max_operator_derivative_order + 2 - 2 * p) / 2
1123 };
1124 if !pure || 2 * s_op < dim {
1125 return (nullspace, s_op);
1126 }
1127 nullspace = duchon_next_nullspace_order(nullspace);
1128 }
1129 // Bounded-loop fallback: by the analysis in the docstring, for
1130 // `p >= ceil((dim + max_op) / 2) + 1` the operator constraint admits
1131 // `s = 0` and (in pure mode) `0 < dim` satisfies the kernel-existence
1132 // condition. The loop above always reaches that regime within the bound,
1133 // so returning the last `nullspace` with `s = 0` is a valid answer.
1134 (nullspace, 0)
1135}
1136
1137#[inline]
1138pub(crate) fn duchon_next_nullspace_order(order: DuchonNullspaceOrder) -> DuchonNullspaceOrder {
1139 match order {
1140 DuchonNullspaceOrder::Zero => DuchonNullspaceOrder::Linear,
1141 DuchonNullspaceOrder::Linear => DuchonNullspaceOrder::Degree(2),
1142 DuchonNullspaceOrder::Degree(k) => DuchonNullspaceOrder::Degree(k + 1),
1143 }
1144}
1145
1146pub(crate) fn duchon_previous_nullspace_order(order: DuchonNullspaceOrder) -> DuchonNullspaceOrder {
1147 match order {
1148 DuchonNullspaceOrder::Zero => DuchonNullspaceOrder::Zero,
1149 DuchonNullspaceOrder::Linear => DuchonNullspaceOrder::Zero,
1150 DuchonNullspaceOrder::Degree(2) => DuchonNullspaceOrder::Linear,
1151 DuchonNullspaceOrder::Degree(k) => DuchonNullspaceOrder::Degree(k - 1),
1152 }
1153}
1154
1155/// Returns the maximum derivative order required by the *active* operator
1156/// penalties: 2 if stiffness is Active, else 1 if tension is Active, else 0.
1157/// Mass-only (or no active operator) penalties only require kernel validity
1158/// (`2(p+s) > d`), tension requires D1 collocation (`2(p+s) > d+1`), and
1159/// stiffness requires D2 collocation (`2(p+s) > d+2`).
1160pub fn duchon_max_active_operator_derivative_order(
1161 operator_penalties: &DuchonOperatorPenaltySpec,
1162) -> usize {
1163 if matches!(
1164 operator_penalties.stiffness,
1165 OperatorPenaltySpec::Active { .. }
1166 ) {
1167 2
1168 } else if matches!(
1169 operator_penalties.tension,
1170 OperatorPenaltySpec::Active { .. }
1171 ) {
1172 1
1173 } else {
1174 0
1175 }
1176}
1177
/// Metadata returned by generic basis builders.
///
/// One variant per basis family. Each variant carries what its own docs
/// describe as needed to rebuild the same basis later (e.g. at prediction
/// time). The type is transient builder output: it derives only `Debug` and
/// `Clone` and is not serde-serialized.
#[derive(Debug, Clone)]
pub enum BasisMetadata {
    /// One-dimensional B-spline basis.
    BSpline1D {
        /// Knot vector of the basis.
        knots: Array1<f64>,
        /// Identifiability transform captured at fit time, if any.
        identifiability_transform: Option<Array2<f64>>,
        /// Periodic configuration, if the basis is periodic.
        // NOTE(review): the meaning of the `(f64, f64, usize)` tuple is not
        // visible here — presumably (lower, upper, count); confirm.
        periodic: Option<(f64, f64, usize)>,
        /// Effective B-spline polynomial degree carried by `knots`.
        ///
        /// Persisted alongside `knots` so prediction can reconstruct an
        /// evaluator that matches fit-time geometry, even when the fit-time
        /// auto-shrink (issue #340) reduced the user's requested degree to
        /// fit the available data (`n` too small for cubic ⇒ quadratic ⇒
        /// linear). When `None` the consumer should fall back to the
        /// upstream `BSplineBasisSpec.degree` (legacy / non-shrunk path).
        degree: Option<usize>,
        /// Human-readable description of an automatic basis shrink (issue #340)
        /// when the user's requested `(degree, num_internal_knots)` exceeded the
        /// available evaluation count `n`. `Some(note)` records the before→after
        /// configuration; `None` means no auto-shrink occurred for this basis.
        auto_shrink_note: Option<String>,
    },
    /// Natural cubic regression spline (`bs="cr"`/`"cs"`) metadata (#1074).
    ///
    /// `knots` are the `k` Lancaster–Salkauskas knots that index the basis
    /// values directly (basis dim = `knots.len()`). Predict-time rebuilds
    /// reconstruct the cr geometry from `knots` and replay the captured
    /// `identifiability_transform` exactly, mirroring `BSpline1D`.
    CubicRegression1D {
        /// Value-at-knot locations; the basis has one function per knot.
        knots: Array1<f64>,
        /// Identifiability transform captured at fit time, if any.
        identifiability_transform: Option<Array2<f64>>,
    },
    /// Thin-plate spline basis.
    ThinPlate {
        /// Kernel centers.
        centers: Array2<f64>,
        /// Realized length scale.
        length_scale: f64,
        /// Optional periodicity description.
        // NOTE(review): presumably per-axis periods — confirm.
        periodic: Option<Vec<Option<f64>>>,
        /// Identifiability transform captured at fit time, if any.
        identifiability_transform: Option<Array2<f64>>,
        /// Per-column standard deviations used for input standardization (d > 1).
        input_scales: Option<Vec<f64>>,
        /// Wood-TPRS radial reparameterization carried into prediction so the
        /// rotated radial basis at predict-time matches fit-time exactly. `None`
        /// in the lazy/streaming path which retains the original basis.
        radial_reparam: Option<Array2<f64>>,
    },
    /// Spherical smooth built on a set of centers.
    // NOTE(review): the roles of `penalty_order`, `method`, `max_degree` and
    // `wahba_kernel` are defined by the sphere builder, which is not visible
    // here.
    Sphere {
        /// Kernel centers.
        centers: Array2<f64>,
        penalty_order: usize,
        method: SphereMethod,
        max_degree: Option<usize>,
        wahba_kernel: SphereWahbaKernel,
        /// Constraint transform captured at fit time, if any.
        constraint_transform: Option<Array2<f64>>,
    },
    /// Constant-curvature (`M_κ`) geodesic-kernel smooth (#944). `kappa` and
    /// the realized `length_scale` are persisted so predict-time (and the
    /// future ψ-channel per-trial) rebuilds replay the exact fit-time
    /// geometry; `constraint_transform` is the composed `z · z_parametric`
    /// frozen by the global identifiability pipeline (#532 pattern).
    ConstantCurvature {
        /// Kernel centers.
        centers: Array2<f64>,
        /// Curvature of the model space `M_κ`.
        kappa: f64,
        /// Realized kernel length scale.
        length_scale: f64,
        /// Composed `z · z_parametric` transform, if any.
        constraint_transform: Option<Array2<f64>>,
    },
    /// Measure-jet spline smooth: multiscale local-jet-residual energy of the
    /// empirical measure, quadratured on the center set. `centers` are the
    /// REALIZED barycenter nodes; `order_s` stores the spec's order sentinel
    /// verbatim as the mode marker (0.0 = per-level/spectral, > 0 = fused
    /// pin — persisting a realized default would flip the rebuilt mode). The
    /// penalty depends on the FIT data through `masses`, the realized
    /// `eps_band`, the support anchors, and the normalization scales, so all
    /// are persisted and replayed verbatim by predict-time (and per-ψ-trial)
    /// rebuilds — recomputing any of them from predict rows would change the
    /// penalty the coefficients were estimated under. `constraint_transform`
    /// is the composed `z · z_parametric` frozen by the global
    /// identifiability pipeline (#532 pattern).
    MeasureJet {
        /// Realized barycenter nodes.
        centers: Array2<f64>,
        /// Input standardization scales, if any.
        // NOTE(review): presumably per-column standard deviations, as in the
        // other variants — confirm.
        input_scales: Option<Vec<f64>>,
        length_scale: f64,
        /// Realized `eps_band`, replayed verbatim on rebuild.
        eps_band: Vec<f64>,
        /// Order sentinel from the spec (0.0 = per-level/spectral, > 0 = fused pin).
        order_s: f64,
        alpha: f64,
        tau0: f64,
        /// Fit-data masses the penalty depends on.
        masses: Array1<f64>,
        // NOTE(review): presumably the "support anchors" mentioned in the
        // variant docs — confirm.
        support_means: Vec<f64>,
        /// Normalization scales, persisted so rebuilds reproduce the penalty.
        penalty_normalization_scales: Vec<f64>,
        raw_penalty_normalization_scales: Vec<f64>,
        fused_penalty_normalization_scale: Option<f64>,
        /// Composed `z · z_parametric` transform, if any.
        constraint_transform: Option<Array2<f64>>,
    },
    /// Matérn kernel basis.
    Matern {
        /// Kernel centers.
        centers: Array2<f64>,
        /// Realized kernel length scale.
        length_scale: f64,
        /// Optional periodicity description.
        // NOTE(review): presumably per-axis periods — confirm.
        periodic: Option<Vec<Option<f64>>>,
        /// Matérn smoothness parameter ν.
        nu: MaternNu,
        include_intercept: bool,
        /// Identifiability transform captured at fit time, if any.
        identifiability_transform: Option<Array2<f64>>,
        /// Per-column standard deviations used for input standardization (d > 1).
        input_scales: Option<Vec<f64>>,
        /// Per-axis anisotropy log-scales η_a for geometric anisotropy.
        /// When Some, distance is r = √(Σ_a exp(2η_a) · (x_a - c_a)²).
        aniso_log_scales: Option<Vec<f64>>,
        /// Realized double-penalty nullspace-shrinkage decision at this build
        /// (gam#787/#860). The freeze step pins this into
        /// `MaternIdentifiability::FrozenTransform::nullspace_shrinkage_survived`
        /// so the κ-optimizer's per-trial rebuilds keep the learned-penalty count
        /// invariant (otherwise the κ-dependent spectral test flips it 6↔7 → "joint
        /// hyper rho dimension mismatch").
        nullspace_shrinkage_survived: bool,
    },
    /// Duchon / polyharmonic basis (pure or hybrid Duchon–Matérn).
    Duchon {
        /// Kernel centers.
        centers: Array2<f64>,
        /// Hybrid Matérn width; `None` for the pure scale-free Duchon kernel.
        length_scale: Option<f64>,
        /// Optional periodicity description.
        // NOTE(review): presumably per-axis periods — confirm.
        periodic: Option<Vec<Option<f64>>>,
        /// Spectral power `s`.
        power: f64,
        /// Polynomial null-space degree.
        nullspace_order: DuchonNullspaceOrder,
        /// Identifiability transform captured at fit time, if any.
        identifiability_transform: Option<Array2<f64>>,
        /// Per-column standard deviations used for input standardization (d > 1).
        input_scales: Option<Vec<f64>>,
        /// Per-axis anisotropy log-scales η_a, stored for prediction.
        aniso_log_scales: Option<Vec<f64>>,
        /// Support points used to build the active lower-order operator
        /// penalties (mass/tension/stiffness). Stored so runtime adaptive
        /// caches can rebuild the exact same operator rows instead of guessing
        /// from centers.
        operator_collocation_points: Option<Array2<f64>>,
        /// Data-metric radial reparameterization `V` (#1355). When `Some`, the
        /// constrained kernel transform is folded to `Z·V` so predict-time and
        /// κ-trial rebuilds replay the exact fit-time rotated radial basis.
        /// `None` on the lazy/streaming path (original constrained basis).
        radial_reparam: Option<Array2<f64>>,
    },
    /// PCA-based basis over a set of feature columns.
    // NOTE(review): field semantics below are inferred from names only; the
    // PCA builder is not visible here.
    Pca {
        /// Dataset columns the basis is built from.
        feature_cols: Vec<usize>,
        basis_matrix: Array2<f64>,
        centered: bool,
        smooth_penalty: f64,
        center_mean: Option<Array1<f64>>,
        pca_basis_path: Option<std::path::PathBuf>,
        chunk_size: usize,
    },
    /// Tensor-product smooth built from one-dimensional spline margins.
    TensorBSpline {
        /// Dataset columns of the margins.
        feature_cols: Vec<usize>,
        /// One knot vector per margin.
        knots: Vec<Array1<f64>>,
        /// Per-margin polynomial degree.
        degrees: Vec<usize>,
        /// Per-margin period (`None` for a non-periodic margin).
        periods: Vec<Option<f64>>,
        /// Per-margin flag: `true` when that margin is a natural cubic
        /// regression spline (`NaturalCubicRegression` knotspec) rather than an
        /// open/periodic B-spline (#1074). Persisted so the tensor freeze
        /// rebuilds the cr marginal knotspec (value-at-knot) instead of an open
        /// `Provided(knots)` B-spline, keeping predict-time marginals identical
        /// to the fit-time cr margins. Defaults to all-`false` (legacy B-spline
        /// tensors) when deserialized from an older persisted model (the
        /// older-model default is applied on the persisted `SmoothBasisSpec`
        /// side; `BasisMetadata` itself is transient builder output and is not
        /// serde-serialized, so it carries no `#[serde]` attributes).
        is_cr: Vec<bool>,
        /// Identifiability transform captured at fit time, if any.
        identifiability_transform: Option<Array2<f64>>,
    },
    /// Spherical-harmonics basis.
    SphereHarmonics {
        /// Maximum harmonic degree.
        max_degree: usize,
        // NOTE(review): presumably `true` when input angles are in radians
        // rather than degrees — confirm.
        radians: bool,
    },
    /// Wrap an inner basis metadata to record a multiplicative `by` (continuous or
    /// factor) along a column of the dataset.
    BySmooth {
        /// Metadata of the wrapped smooth.
        inner: Box<BasisMetadata>,
        /// Dataset column holding the `by` variable.
        by_col: usize,
        /// Factor levels of the `by` variable.
        // NOTE(review): presumably `None` for a continuous `by` — confirm.
        levels: Option<Vec<u64>>,
        ordered: bool,
    },
    /// Factor-by-smooth (mgcv-style `s(x, by=g, bs="fs"|"sz"|"re")`).
    FactorSmooth {
        /// Dataset columns of the continuous covariate(s).
        continuous_cols: Vec<usize>,
        /// Dataset column holding the grouping factor.
        group_col: usize,
        /// Knots of the per-level marginal (B-spline knot vector, or cr
        /// value-knots when `marginal_is_cr` is `true`).
        knots: Array1<f64>,
        degree: usize,
        // NOTE(review): same undocumented `(f64, f64, usize)` tuple as
        // `BSpline1D::periodic` — confirm its meaning.
        periodic: Option<(f64, f64, usize)>,
        /// Levels of the grouping factor.
        group_levels: Vec<u64>,
        /// Which factor-smooth flavour this is.
        // NOTE(review): presumably one of "fs" / "sz" / "re" — confirm.
        flavour: String,
        /// `true` when the per-level marginal is a cubic regression spline
        /// (`NaturalCubicRegression` knotspec, mgcv's `bs="sz"` default marginal,
        /// #1074). Predict-time freeze must then restore a cr knotspec from the
        /// stored value-knots rather than treating them as a B-spline knot
        /// vector. Defaults to `false` (B-spline marginal) for backward compat.
        marginal_is_cr: bool,
    },
}
1366
/// Standardized basis build result for engine-level composition.
///
/// `penalties`, `ops` and `null_eigenvectors` are documented as parallel
/// collections with one entry per active penalty, and
/// `null_eigenvectors[k].ncols()` matches `nullspace_dims[k]` when present.
/// `Debug` is implemented by hand so that `ops` and `null_eigenvectors` print
/// as compact summaries.
#[derive(Clone)]
pub struct BasisBuildResult {
    /// The design for this basis.
    pub design: DesignMatrix,
    /// Active penalty matrices for this basis.
    pub penalties: Vec<Array2<f64>>,
    /// Null-space dimension of each active penalty.
    pub nullspace_dims: Vec<usize>,
    /// Descriptive information about the penalties.
    // NOTE(review): presumably one entry per active penalty, parallel to
    // `penalties` — confirm.
    pub penaltyinfo: Vec<PenaltyInfo>,
    /// Family-specific metadata describing how the basis was built.
    pub metadata: BasisMetadata,
    /// Optional factored rowwise-Kronecker representation for tensor-product
    /// bases. When present, downstream code can keep the design operator-backed
    /// instead of forcing a fully materialized `n x prod(q_j)` block.
    pub kronecker_factored: Option<KroneckerFactoredBasis>,
    /// Per-active-penalty operator handles (parallel to `penalties`). Each
    /// entry is `Some(op)` when the closed-form factory emitted an op-form
    /// penalty bit-equivalent to the dense matrix, `None` for ordinary dense
    /// penalties. Downstream consumers route through the `Some` entries to
    /// avoid materializing dense `p x p` Grams in exact operator algebra.
    pub ops: Vec<Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>>,
    /// Per-active-penalty null-space eigenvector matrices (parallel to
    /// `penalties`). Each entry is `Some(U_null)` with `U_null.ncols() ==
    /// nullspace_dims[k]` when the active block has a non-trivial null space
    /// (eigenvalues ≤ spectral tolerance), and `None` when the block is
    /// already full-rank. The columns of `U_null` are the eigenvectors of
    /// `sym_penalty` at the (near-)zero eigenvalues — i.e., an orthonormal
    /// basis of `null(S_block)` in the block's own coordinate system.
    ///
    /// This is the raw spectral data that the construction pipeline uses to
    /// absorb each smooth's penalty null space into the parametric block
    /// (reparameterize-and-split). Without absorption the inner Newton solve
    /// cannot converge on data whose unpenalized signal lies along a null
    /// direction of `S` (phantom-multiplier refusal at the KKT certificate).
    pub null_eigenvectors: Vec<Option<Array2<f64>>>,
    /// Joint-null absorption rotation for this basis, when the basis carries
    /// any penalties with a non-trivial joint null space.
    ///
    /// `Some(rotation)` records `Q = [U_range | U_null]` where `U_null` spans
    /// the joint null space `null(Σ_k S_k)` over this basis's active
    /// penalties (unscaled — the structural joint null is independent of
    /// `λ`). After the basis pipeline applies this rotation, the design
    /// becomes `X · Q` and each penalty becomes `Qᵀ S_k Q`, block-diagonal
    /// with a guaranteed-zero null tail. The same `Q` must be replayed at
    /// prediction time, so it is persisted in the fitted model. `None`
    /// indicates either no penalties on this basis, or a full-rank joint
    /// penalty (joint nullity = 0). A `Some` value is never recorded with
    /// `joint_nullity == 0` — the `None` discriminant is canonical for
    /// "nothing to absorb".
    ///
    /// Stage-2 commit A: this field is plumbed into the struct but neither
    /// computed nor applied yet. Stage-2 commit B populates it; Stage-2
    /// commit D applies the rotation to `design` and `penalties`.
    // NOTE(review): the staging note above may be stale once commits B/D have
    // landed — confirm and trim.
    pub joint_null_rotation: Option<JointNullRotation>,
}
1419
/// Joint-null absorption rotation, attached to a smooth's basis when the
/// basis's joint penalty `Σ_k S_k` has a non-trivial null space.
///
/// The `rotation` field stores the orthonormal eigenvector matrix
/// `Q = [U_range | U_null]` of the symmetric joint penalty: the first
/// `range_dim = rotation.ncols() - joint_nullity` columns span
/// `range(Σ_k S_k)`; the remaining `joint_nullity` columns span
/// `null(Σ_k S_k)`. After the pipeline applies the rotation, the smooth's
/// coefficient vector satisfies `β = Q · γ`, the design becomes `X · Q`,
/// and each per-block penalty `S_k` becomes `Qᵀ S_k Q`, which is guaranteed
/// block-diagonal with a zero `(joint_nullity × joint_nullity)` tail
/// (because the joint null annihilates every active `S_k`).
///
/// `Debug` is implemented by hand and prints only the shape of `rotation`
/// rather than its entries.
#[derive(Clone, Serialize, Deserialize)]
pub struct JointNullRotation {
    /// `(p_smooth × p_smooth)` orthonormal matrix; range columns first,
    /// joint-null columns last.
    pub rotation: Array2<f64>,
    /// Number of columns at the tail of `rotation` that span the joint
    /// null space. Always `> 0` when wrapped in `Some` — the value `0`
    /// is encoded as `None`.
    pub joint_nullity: usize,
}
1442
1443impl std::fmt::Debug for JointNullRotation {
1444 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1445 f.debug_struct("JointNullRotation")
1446 .field(
1447 "rotation",
1448 &format_args!("{}×{}", self.rotation.nrows(), self.rotation.ncols()),
1449 )
1450 .field("joint_nullity", &self.joint_nullity)
1451 .finish()
1452 }
1453}
1454
1455impl std::fmt::Debug for BasisBuildResult {
1456 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1457 f.debug_struct("BasisBuildResult")
1458 .field("design", &self.design)
1459 .field("penalties", &self.penalties)
1460 .field("nullspace_dims", &self.nullspace_dims)
1461 .field("penaltyinfo", &self.penaltyinfo)
1462 .field("metadata", &self.metadata)
1463 .field("kronecker_factored", &self.kronecker_factored)
1464 .field(
1465 "ops",
1466 &format_args!(
1467 "[{}]",
1468 self.ops
1469 .iter()
1470 .map(|o| if o.is_some() { "Some" } else { "None" })
1471 .collect::<Vec<_>>()
1472 .join(", ")
1473 ),
1474 )
1475 .field(
1476 "null_eigenvectors",
1477 &format_args!(
1478 "[{}]",
1479 self.null_eigenvectors
1480 .iter()
1481 .map(|u| match u {
1482 Some(m) => format!("Some({}x{})", m.nrows(), m.ncols()),
1483 None => "None".to_string(),
1484 })
1485 .collect::<Vec<_>>()
1486 .join(", ")
1487 ),
1488 )
1489 .field("joint_null_rotation", &self.joint_null_rotation)
1490 .finish()
1491 }
1492}
1493
/// Factored tensor-product basis metadata for operator-backed downstream use.
///
/// Holds the fixed marginal data plus a private, lazily filled cache of the
/// λ-invariant tensor structure. `Clone` is implemented by hand.
#[derive(Debug)]
pub struct KroneckerFactoredBasis {
    /// Marginal design matrices: `marginal_designs[j]` is `(n, q_j)`.
    pub marginal_designs: Vec<Array2<f64>>,
    /// Marginal penalty matrices: `marginal_penalties[k]` is `(q_k, q_k)`.
    pub marginal_penalties: Vec<Array2<f64>>,
    /// Marginal basis dimensions: `[q_0, ..., q_{d-1}]`.
    pub marginal_dims: Vec<usize>,
    /// Whether the system includes a global ridge (double) penalty.
    pub has_double_penalty: bool,
    /// λ-invariant tensor structure (marginal eigensystems, reparameterized
    /// marginals, shrinkage scale), memoized once per fit. The marginal
    /// designs/penalties are fixed for the whole fit, so the expensive marginal
    /// `eigh()` and `B_k·U_k` GEMMs only need to run once — every outer REML
    /// iterate (50+ on the #1082 tensor cases) then reuses this. Filled lazily
    /// on first use via [`Self::invariant_structure`]. NOT serialized. On
    /// `Clone`, an already-memoized structure is shared with the copy (the
    /// `Arc` is cloned, not the structure); an empty cache stays empty and is
    /// recomputed on first demand. It is purely a within-fit performance
    /// cache, so results are bit-identical either way.
    invariant: std::sync::OnceLock<std::sync::Arc<crate::kronecker::KroneckerInvariantStructure>>,
}
1515
1516impl Clone for KroneckerFactoredBasis {
1517 fn clone(&self) -> Self {
1518 Self {
1519 marginal_designs: self.marginal_designs.clone(),
1520 marginal_penalties: self.marginal_penalties.clone(),
1521 marginal_dims: self.marginal_dims.clone(),
1522 has_double_penalty: self.has_double_penalty,
1523 // Propagate the memoized structure when present so a clone made
1524 // mid-fit keeps the hoist; otherwise start empty (recomputed on
1525 // first demand, identical result).
1526 invariant: match self.invariant.get() {
1527 Some(s) => {
1528 let cell = std::sync::OnceLock::new();
1529 cell.get_or_init(|| std::sync::Arc::clone(s));
1530 cell
1531 }
1532 None => std::sync::OnceLock::new(),
1533 },
1534 }
1535 }
1536}
1537
1538impl KroneckerFactoredBasis {
1539 /// Construct from the fixed marginal data with an empty invariant cache.
1540 pub fn new(
1541 marginal_designs: Vec<Array2<f64>>,
1542 marginal_penalties: Vec<Array2<f64>>,
1543 marginal_dims: Vec<usize>,
1544 has_double_penalty: bool,
1545 ) -> Self {
1546 Self {
1547 marginal_designs,
1548 marginal_penalties,
1549 marginal_dims,
1550 has_double_penalty,
1551 invariant: std::sync::OnceLock::new(),
1552 }
1553 }
1554
1555 /// Lazily compute (once) and return the λ-invariant tensor structure
1556 /// (marginal eigensystems, reparameterized marginals, shrinkage scale).
1557 ///
1558 /// Computed from the fixed marginal designs/penalties, so the result is the
1559 /// same on every call within a fit; the first call pays the `eigh()` cost
1560 /// and every later call is a pointer load. Because the cache is keyed on the
1561 /// fixed marginal data and `marginal_penalties`/`marginal_designs` are
1562 /// immutable for the fit's lifetime, no invalidation is needed.
1563 pub fn invariant_structure(
1564 &self,
1565 ) -> Result<std::sync::Arc<crate::kronecker::KroneckerInvariantStructure>, BasisError> {
1566 // Fast path: already memoized.
1567 if let Some(s) = self.invariant.get() {
1568 return Ok(std::sync::Arc::clone(s));
1569 }
1570 // Compute outside the cell (fallible) and install via `get_or_init`. If a
1571 // concurrent racer already won, `get_or_init` drops our `computed` and
1572 // returns the stored one; either way the value is the unique function of
1573 // the fixed marginal data, so the returned Arc is correct.
1574 let computed = std::sync::Arc::new(crate::kronecker::KroneckerInvariantStructure::compute(
1575 &self.marginal_designs,
1576 &self.marginal_penalties,
1577 &self.marginal_dims,
1578 )?);
1579 let installed = self.invariant.get_or_init(|| computed);
1580 Ok(std::sync::Arc::clone(installed))
1581 }
1582}
1583
/// Tag recording which construction produced a penalty. Stored as the
/// `source` field of `PenaltyInfo` and `PenaltyCandidate`.
///
/// NOTE(review): the per-variant meanings below (other than
/// `OperatorRelevance`) are read off the variant names only — confirm against
/// the code that emits them.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PenaltySource {
    // Presumably the main smoothing penalty of a term.
    Primary,
    // Presumably the extra null-space penalty added by double-penalty shrinkage.
    DoublePenaltyNullspace,
    // Operator-derived penalties; presumably mass / tension / stiffness terms.
    OperatorMass,
    OperatorTension,
    OperatorStiffness,
    /// One per input axis `a` of a multivariate Duchon smooth: the gradient
    /// energy along axis `a`, `Σ(∂f/∂x_a)²`, each with its own REML λ_a. REML
    /// shrinks an axis's contribution toward flat only when it does not earn
    /// its keep — penalty-based ARD / variable relevance, the replacement for
    /// brittle kernel-η optimization. Emitted when `scale_dims` is on.
    OperatorRelevance {
        axis: usize,
    },
    // Tensor-product penalty tied to a single margin `dim` (TODO confirm).
    TensorMarginal {
        dim: usize,
    },
    // Tensor-product penalty carrying the list of margins it penalizes.
    TensorSeparable {
        penalized_margins: Vec<usize>,
    },
    // Presumably the global ridge of a double-penalized tensor smooth.
    TensorGlobalRidge,
    // Free-form label for sources not covered by the variants above.
    Other(String),
}
1608
/// Reason attached to a penalty via `PenaltyInfo::dropped_reason`.
///
/// NOTE(review): variant meanings are inferred from their names — confirm
/// against the code that sets `dropped_reason`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PenaltyDropReason {
    // Presumably: the penalty matrix was identically zero.
    ZeroMatrix,
    // Presumably: the penalty had no eigenvalue above the numerical tolerance.
    NumericalRankZero,
}
1614
/// Serde fallback for `PenaltyInfo::normalization_scale`: a payload that
/// omits the field deserializes with a scale of one.
fn default_normalization_scale() -> f64 {
    1.0_f64
}
1618
/// Serializable per-penalty bookkeeping record.
///
/// NOTE(review): field semantics not spelled out by an existing comment are
/// inferred from names and types — confirm against the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PenaltyInfo {
    /// Which construction produced this penalty.
    pub source: PenaltySource,
    // Presumably the penalty's position in the candidate list before any
    // filtering/dropping.
    pub original_index: usize,
    // Presumably `false` when the penalty was dropped (see `dropped_reason`).
    pub active: bool,
    // Presumably the numerical rank of the penalty matrix.
    pub effective_rank: usize,
    /// Why the penalty was dropped, if it was.
    pub dropped_reason: Option<PenaltyDropReason>,
    // Presumably mirrors `PenaltyCandidate::nullspace_dim_hint`.
    pub nullspace_dim_hint: usize,
    /// Falls back to `1.0` (via `default_normalization_scale`) when the field
    /// is missing from a deserialized payload.
    #[serde(default = "default_normalization_scale")]
    pub normalization_scale: f64,
    /// Kronecker factors preserved from tensor penalty construction.
    /// When present, spectral decomposition can use per-factor eigendecomposition.
    /// Excluded from (de)serialization by `#[serde(skip)]`.
    #[serde(skip)]
    pub kronecker_factors: Option<Vec<Array2<f64>>>,
}
1634
/// A penalty as proposed by a basis builder: the dense matrix plus optional
/// structured representations (`kronecker_factors`, `op`) of the same
/// penalty.
#[derive(Clone)]
pub struct PenaltyCandidate {
    /// Dense penalty matrix.
    pub matrix: Array2<f64>,
    // Presumably the expected null-space dimension of `matrix` — TODO confirm
    // how consumers use the hint.
    pub nullspace_dim_hint: usize,
    /// Which construction produced this penalty.
    pub source: PenaltySource,
    // NOTE(review): presumably the factor by which `matrix` was rescaled
    // during normalization — confirm direction (multiplied vs divided).
    pub normalization_scale: f64,
    /// Optional Kronecker factors whose product equals `matrix`.
    /// When present, spectral decomposition can be done per-factor
    /// (O(Σ q_j³) instead of O((Π q_j)³)).
    pub kronecker_factors: Option<Vec<Array2<f64>>>,
    /// Optional operator-form handle whose `as_dense()` matches `matrix`. When
    /// populated by the closed-form factories, this is propagated through to
    /// `CanonicalPenaltyBlock` so downstream consumers can use exact matvec
    /// algebra without rebuilding the dense Gram. When `None`, only the dense
    /// `matrix` path is available.
    pub op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
}
1652
1653impl std::fmt::Debug for PenaltyCandidate {
1654 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1655 f.debug_struct("PenaltyCandidate")
1656 .field(
1657 "matrix",
1658 &format_args!("{}×{}", self.matrix.nrows(), self.matrix.ncols()),
1659 )
1660 .field("nullspace_dim_hint", &self.nullspace_dim_hint)
1661 .field("source", &self.source)
1662 .field("normalization_scale", &self.normalization_scale)
1663 .field(
1664 "kronecker_factors",
1665 &self.kronecker_factors.as_ref().map(|v| v.len()),
1666 )
1667 .field("op", &self.op.as_ref().map(|o| o.dim()))
1668 .finish()
1669 }
1670}
1671
/// A penalty together with its retained spectral decomposition and the
/// rank / nullity / negative-direction counts derived from it.
#[derive(Clone)]
pub struct CanonicalPenaltyBlock {
    // Presumably the symmetrized dense penalty matrix — TODO confirm.
    pub sym_penalty: Array2<f64>,
    /// Eigenvalues from spectral decomposition (retained to avoid recomputation).
    pub eigenvalues: Array1<f64>,
    /// Eigenvectors from spectral decomposition (retained to avoid recomputation).
    pub eigenvectors: Array2<f64>,
    // Presumably the count of eigenvalues above `tol` (range directions).
    pub rank: usize,
    // Presumably the count of eigenvalues within `tol` of zero (null directions).
    pub nullity: usize,
    /// Number of genuine negative-curvature eigendirections (`ev < -tol`).
    /// A non-PSD penalty has `negative_dim > 0`; these directions are
    /// neither range nor null and are never absorbed as unpenalized (#1425).
    pub negative_dim: usize,
    // Eigenvalue tolerance used for the classification above (see the
    // `ev < -tol` rule documented on `negative_dim`).
    pub tol: f64,
    // Presumably set when the penalty is numerically the zero matrix.
    pub iszero: bool,
    /// Optional operator-form handle that is bit-equivalent to `sym_penalty`.
    /// Propagated from `PenaltyCandidate.op` when present so downstream
    /// consumers can use matvec without rebuilding the dense Gram.
    pub op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
}
1692
1693impl std::fmt::Debug for CanonicalPenaltyBlock {
1694 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1695 f.debug_struct("CanonicalPenaltyBlock")
1696 .field(
1697 "sym_penalty",
1698 &format_args!("{}×{}", self.sym_penalty.nrows(), self.sym_penalty.ncols()),
1699 )
1700 .field("eigenvalues", &self.eigenvalues)
1701 .field(
1702 "eigenvectors",
1703 &format_args!(
1704 "{}×{}",
1705 self.eigenvectors.nrows(),
1706 self.eigenvectors.ncols()
1707 ),
1708 )
1709 .field("rank", &self.rank)
1710 .field("nullity", &self.nullity)
1711 .field("negative_dim", &self.negative_dim)
1712 .field("tol", &self.tol)
1713 .field("iszero", &self.iszero)
1714 .field("op", &self.op.as_ref().map(|o| o.dim()))
1715 .finish()
1716 }
1717}
1718
/// First-order psi-derivative payload: the design derivative, one penalty
/// derivative per penalty, and an optional operator-backed form of the design
/// derivative.
///
/// NOTE(review): psi is presumably the log length-scale parameter (cf. the
/// `psi_a = log(kappa_a)` convention documented on `AnisoBasisPsiDerivatives`)
/// — confirm for the isotropic case.
#[derive(Debug)]
pub struct BasisPsiDerivativeResult {
    /// Dense first derivative of the design with respect to psi.
    pub design_derivative: Array2<f64>,
    /// Dense first derivatives of the penalty matrices with respect to psi.
    pub penalties_derivative: Vec<Array2<f64>>,
    /// Operator-backed design derivative for standalone first-derivative
    /// callers. Bundled first+second callers receive the shared operator on
    /// `BasisPsiDerivativeBundle` instead.
    pub implicit_operator: Option<ImplicitDesignPsiDerivative>,
}
1728
/// Second-order psi-derivative payload: the design second derivative, one
/// penalty second derivative per penalty, and an optional operator-backed
/// form of the design derivative.
#[derive(Debug)]
pub struct BasisPsiSecondDerivativeResult {
    /// Dense second derivative of the design with respect to psi.
    pub designsecond_derivative: Array2<f64>,
    /// Dense second derivatives of the penalty matrices with respect to psi.
    pub penaltiessecond_derivative: Vec<Array2<f64>>,
    /// Operator-backed design derivative for standalone second-derivative
    /// callers. Bundled first+second callers receive the shared operator on
    /// `BasisPsiDerivativeBundle` instead.
    pub implicit_operator: Option<ImplicitDesignPsiDerivative>,
}
1738
/// First and second psi-derivative payloads delivered together, with a single
/// shared operator-backed design derivative.
#[derive(Debug)]
pub struct BasisPsiDerivativeBundle {
    /// First-order payload.
    pub first: BasisPsiDerivativeResult,
    /// Second-order payload.
    pub second: BasisPsiSecondDerivativeResult,
    /// Shared operator-backed design derivative for the first and second
    /// psi derivatives. Bundled callers consume this once instead of storing
    /// duplicate materialized/streaming operators in both derivative payloads.
    pub implicit_operator: Option<ImplicitDesignPsiDerivative>,
}
1748
/// Per-axis psi_a derivative package for anisotropic spatial terms.
///
/// For a d-dimensional anisotropic term, the kernel phi(r) depends on
/// the anisotropic distance r = |Lambda h| where Lambda = diag(kappa_a). Each axis a
/// has its own log-scale psi_a = log(kappa_a), yielding d first derivatives,
/// d diagonal second derivatives, and d*(d-1)/2 cross second derivatives.
///
/// The cross second derivative d2 phi/(d psi_a d psi_b) = t * s_a * s_b (a != b)
/// is rank-1, so we store the t_values and s_components vectors rather
/// than materializing d^2 matrices.
#[derive(Clone)]
pub struct AnisoBasisPsiDerivatives {
    /// d matrices, each (n x p_smooth): dX/d psi_a.
    pub design_first: Vec<Array2<f64>>,
    /// d matrices, each (n x p_smooth): d2X/d psi_a^2 (diagonal second derivatives).
    pub design_second_diag: Vec<Array2<f64>>,
    /// Cross second derivatives d2X/(d psi_a d psi_b) for a < b.
    pub design_second_cross: Vec<Array2<f64>>,
    /// Axis-pair indices corresponding to `design_second_cross`.
    // NOTE(review): presumably index-aligned, i.e. entry `i` names the axes of
    // `design_second_cross[i]` — confirm against the producer.
    pub design_second_cross_pairs: Vec<(usize, usize)>,
    /// d x num_penalties: dS_m/d psi_a for each axis a and penalty m.
    pub penalties_first: Vec<Vec<Array2<f64>>>,
    /// d x num_penalties: d2S_m/d psi_a^2 for each axis a and penalty m.
    pub penalties_second_diag: Vec<Vec<Array2<f64>>>,
    /// The (a, b) axis pairs supported by the on-demand cross-penalty
    /// provider. Only the upper triangle (a < b) is stored.
    pub penalties_cross_pairs: Vec<(usize, usize)>,
    /// On-demand cross-penalty second-derivative provider. Exact anisotropic
    /// cross-axis penalty seconds are streamed one pair at a time rather than
    /// stored as a dense upper triangle of blocks.
    pub penalties_cross_provider: Option<AnisoPenaltyCrossProvider>,
    /// Shared operator-backed representation of the anisotropic kernel-side
    /// design derivatives. When `design_first` / `design_second_diag` are empty,
    /// callers must use this operator directly; when they are present, this
    /// operator still provides exact cross-axis second derivatives without
    /// duplicating separate `t` / `s_a` storage layouts.
    pub implicit_operator: Option<ImplicitDesignPsiDerivative>,
}
1787
/// Cheaply clonable handle around a shared callback that maps an axis pair
/// `(axis_a, axis_b)` to a list of matrices or a `BasisError`.
///
/// The callback is held in an `Arc`, so clones share one underlying closure,
/// and it is bounded `Send + Sync + 'static`.
#[derive(Clone)]
pub struct AnisoPenaltyCrossProvider(
    std::sync::Arc<
        dyn Fn(usize, usize) -> Result<Vec<Array2<f64>>, BasisError> + Send + Sync + 'static,
    >,
);
1794
1795impl AnisoPenaltyCrossProvider {
1796 pub(crate) fn new<F>(f: F) -> Self
1797 where
1798 F: Fn(usize, usize) -> Result<Vec<Array2<f64>>, BasisError> + Send + Sync + 'static,
1799 {
1800 Self(std::sync::Arc::new(f))
1801 }
1802
1803 pub fn evaluate(&self, axis_a: usize, axis_b: usize) -> Result<Vec<Array2<f64>>, BasisError> {
1804 (self.0)(axis_a, axis_b)
1805 }
1806}
1807
1808// ═══════════════════════════════════════════════════════════════════════════
1809// Implicit derivative operator for scalable anisotropic REML gradients
1810// ═══════════════════════════════════════════════════════════════════════════
1811
/// Byte cap enforced by `assert_spatial_centers_below_large_scale_cap` on both
/// the dense `K × d_pc` center storage and the dense `K × K` center-center
/// matrix.
pub(crate) const SPATIAL_CENTER_CENTER_MAX_BYTES: usize = 512 * 1024 * 1024; // 512 MiB
/// Number of design rows materialized per chunk when `design_cross_and_gram`
/// streams over the design.
pub(crate) const DESIGN_CROSS_CHUNK_SIZE: usize = 1024;
1814
1815/// Determine whether implicit operators should be used based on problem size
1816/// and the supplied [`ResourcePolicy`].
1817///
1818/// Returns `true` when the dense materialization of D first-derivative
1819/// matrices would exceed `policy.max_single_materialization_bytes`.
1820///
1821/// For D axes with n data points and p_smooth basis columns, the dense path
1822/// allocates D * n * p_smooth * 8 bytes for first-derivative matrices alone
1823/// (plus a similar amount for second derivatives). The implicit path stores
1824/// only the compact (n * n_knots) radial jets plus (n * n_knots * D) axis
1825/// fractions, which is O(n * k * D) instead of O(n * p * D).
1826pub fn should_use_implicit_operators_with_policy(
1827 n: usize,
1828 p: usize,
1829 d: usize,
1830 policy: &gam_runtime::resource::ResourcePolicy,
1831) -> bool {
1832 // Each first-derivative matrix is (n x p) f64 → n*p*8 bytes.
1833 // We need D of them for first derivatives, D for second diag, plus
1834 // the cross-t matrix and s_components. Conservative estimate: 3*D matrices.
1835 let dense_bytes = 3usize
1836 .saturating_mul(n)
1837 .saturating_mul(p)
1838 .saturating_mul(d)
1839 .saturating_mul(8);
1840 dense_bytes > policy.max_single_materialization_bytes
1841}
1842
/// Estimated byte size of the implicit radial cache:
/// `n · k · (n_axes + 3) · 8`, with every step saturating at `usize::MAX`.
///
/// NOTE(review): the `+ 3` presumably accounts for three per-(row, knot)
/// radial quantities stored alongside the `n_axes` axis components — confirm
/// against the cache layout.
pub(crate) fn implicit_radial_cache_bytes(n: usize, k: usize, n_axes: usize) -> usize {
    let cells = n.saturating_mul(k);
    let values_per_cell = n_axes.saturating_add(3);
    cells.saturating_mul(values_per_cell).saturating_mul(8)
}
1848
1849pub(crate) fn should_cache_implicit_radial_components(
1850 n: usize,
1851 k: usize,
1852 n_axes: usize,
1853 policy: &gam_runtime::resource::ResourcePolicy,
1854) -> bool {
1855 implicit_radial_cache_bytes(n, k, n_axes) <= policy.max_operator_cache_bytes
1856}
1857
1858pub fn assert_no_dense_derivative_materialization(n: usize, p: usize, d_pc: usize) {
1859 let first = dense_design_bytes(n, p).saturating_mul(d_pc);
1860 let second = dense_design_bytes(n, p).saturating_mul(d_pc.saturating_mul(d_pc));
1861 // Consult the library default ResourcePolicy. Production large-scale runs
1862 // configure `AnalyticOperatorRequired`, which still refuses every dense
1863 // materialization here. The default `MaterializeIfSmall` mode lets tiny
1864 // problems (and small-data/test usage) materialize as long as the combined
1865 // first- and second-order dense bytes fit under the single-materialization
1866 // byte budget. `DiagnosticsOnly` is treated like `MaterializeIfSmall` for
1867 // this guard: it permits dense materialization under the same byte cap.
1868 let policy = gam_runtime::resource::ResourcePolicy::default_library();
1869 let budget = policy.max_single_materialization_bytes;
1870 let needed = first.saturating_add(second);
1871 match policy.derivative_storage_mode {
1872 gam_runtime::resource::DerivativeStorageMode::AnalyticOperatorRequired => {
1873 // SAFETY: this assertion helper exists specifically to enforce
1874 // the large-scale invariant that spatial-PC Duchon derivative
1875 // designs never persist as dense `Array2<f64>` storage. When the
1876 // resource policy is `AnalyticOperatorRequired`, any caller that
1877 // reached this point has materialized something the strict
1878 // operator contract forbids.
1879 // SAFETY: AnalyticOperatorRequired forbids dense derivative materialization.
1880 panic!(
1881 "spatial PC Duchon derivative designs must remain operator-backed; refused persistent dense derivative materialization (n={n}, p={p}, d_pc={d_pc}, first_order={:.1} MiB, second_order={:.1} MiB)",
1882 first as f64 / (1024.0 * 1024.0),
1883 second as f64 / (1024.0 * 1024.0),
1884 );
1885 }
1886 gam_runtime::resource::DerivativeStorageMode::MaterializeIfSmall
1887 | gam_runtime::resource::DerivativeStorageMode::DiagnosticsOnly => {
1888 // SAFETY: exceeding the single-materialization budget here is a
1889 // contract violation by an upstream caller that must route through
1890 // the operator-backed path; failing loudly surfaces it rather than
1891 // silently materializing an oversized dense derivative design.
1892 assert!(
1893 needed <= budget,
1894 "spatial PC Duchon derivative designs would exceed the single-materialization budget; refused persistent dense derivative materialization (n={n}, p={p}, d_pc={d_pc}, first_order={:.1} MiB, second_order={:.1} MiB, budget={:.1} MiB)",
1895 first as f64 / (1024.0 * 1024.0),
1896 second as f64 / (1024.0 * 1024.0),
1897 budget as f64 / (1024.0 * 1024.0),
1898 );
1899 }
1900 }
1901}
1902
1903pub fn assert_spatial_centers_below_large_scale_cap(
1904 d_pc: usize,
1905 centers: ArrayView2<'_, f64>,
1906) -> Result<(), BasisError> {
1907 if centers.ncols() != d_pc {
1908 crate::bail_dim_basis!(
1909 "spatial PC center dimension mismatch: centers have {} columns, expected {d_pc}",
1910 centers.ncols()
1911 );
1912 }
1913 let k = centers.nrows();
1914 let centers_bytes = dense_design_bytes(k, d_pc);
1915 let center_center_bytes = dense_design_bytes(k, k);
1916 if centers_bytes > SPATIAL_CENTER_CENTER_MAX_BYTES {
1917 crate::bail_invalid_basis!(
1918 "spatial PC centers exceed center storage cap: K={k}, d_pc={d_pc}, centers={:.1} MiB, cap={:.1} MiB",
1919 centers_bytes as f64 / (1024.0 * 1024.0),
1920 SPATIAL_CENTER_CENTER_MAX_BYTES as f64 / (1024.0 * 1024.0),
1921 );
1922 }
1923 if center_center_bytes > SPATIAL_CENTER_CENTER_MAX_BYTES {
1924 crate::bail_invalid_basis!(
1925 "spatial PC centers exceed center-center large-scale cap: K={k}, d_pc={d_pc}, KxK={:.1} MiB, cap={:.1} MiB",
1926 center_center_bytes as f64 / (1024.0 * 1024.0),
1927 SPATIAL_CENTER_CENTER_MAX_BYTES as f64 / (1024.0 * 1024.0),
1928 );
1929 }
1930 Ok(())
1931}
1932
/// Bytes needed to store a dense `n × p` matrix of `f64`, saturating at
/// `usize::MAX` instead of overflowing.
pub(crate) fn dense_design_bytes(n: usize, p: usize) -> usize {
    let entries = n.saturating_mul(p);
    entries.saturating_mul(std::mem::size_of::<f64>())
}
1937
1938pub(crate) fn should_use_lazy_spatial_design(
1939 n: usize,
1940 p: usize,
1941 policy: &gam_runtime::resource::ResourcePolicy,
1942) -> bool {
1943 dense_design_bytes(n, p) > policy.max_single_materialization_bytes
1944}
1945
1946pub(crate) fn wrap_dense_design_with_transform(
1947 design: DesignMatrix,
1948 transform: &Array2<f64>,
1949 label: &str,
1950) -> Result<DesignMatrix, BasisError> {
1951 match design {
1952 DesignMatrix::Dense(inner) => {
1953 let op = CoefficientTransformOperator::new(inner, transform.clone()).map_err(|e| {
1954 BasisError::InvalidInput(format!("{label} coefficient transform failed: {e}"))
1955 })?;
1956 Ok(DesignMatrix::Dense(
1957 gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(op)),
1958 ))
1959 }
1960 DesignMatrix::Sparse(_) => Err(BasisError::InvalidInput(format!(
1961 "{label} coefficient transform requires a dense/operator-backed design"
1962 ))),
1963 }
1964}
1965
1966/// Single-pass `(Bᵀ(W·C), BᵀB)` accumulation over the streamed design.
1967///
1968/// Materialises each row chunk of the design **once** and reuses it for both
1969/// the constraint cross `Bᵀ(W·C)` and the Gram `BᵀB`. On the lazy chunked
1970/// spatial path each `try_row_chunk` re-evaluates all kernel columns for the
1971/// chunk, so accumulating both products in a single sweep halves the per-build
1972/// kernel re-evaluation work (the dominant cost at large scale) versus two
1973/// independent streaming passes — without changing the result beyond
1974/// floating-point reassociation. The cross is masked off (`q == 0`) by the
1975/// caller, which never invokes this when there is no constraint block.
1976pub(crate) fn design_cross_and_gram(
1977 design: &DesignMatrix,
1978 constraint_matrix: ArrayView2<'_, f64>,
1979 weights: Option<ArrayView1<'_, f64>>,
1980) -> Result<(Array2<f64>, Array2<f64>), BasisError> {
1981 let n = design.nrows();
1982 let k = design.ncols();
1983 if constraint_matrix.nrows() != n {
1984 return Err(BasisError::ConstraintMatrixRowMismatch {
1985 basisrows: n,
1986 constraintrows: constraint_matrix.nrows(),
1987 });
1988 }
1989 if let Some(w) = weights
1990 && w.len() != n
1991 {
1992 return Err(BasisError::WeightsDimensionMismatch {
1993 expected: n,
1994 found: w.len(),
1995 });
1996 }
1997 let q = constraint_matrix.ncols();
1998 let mut cross = Array2::<f64>::zeros((k, q));
1999 let mut gram = Array2::<f64>::zeros((k, k));
2000 for start in (0..n).step_by(DESIGN_CROSS_CHUNK_SIZE) {
2001 let end = (start + DESIGN_CROSS_CHUNK_SIZE).min(n);
2002 let basis_chunk = design
2003 .try_row_chunk(start..end)
2004 .map_err(|e| BasisError::InvalidInput(e.to_string()))?;
2005 let mut constraint_chunk = constraint_matrix.slice(s![start..end, ..]).to_owned();
2006 if let Some(w) = weights {
2007 for (mut row, &weight) in constraint_chunk
2008 .axis_iter_mut(Axis(0))
2009 .zip(w.slice(s![start..end]).iter())
2010 {
2011 row *= weight;
2012 }
2013 }
2014 cross += &fast_atb(&basis_chunk, &constraint_chunk);
2015 gram += &fast_atb(&basis_chunk, &basis_chunk);
2016 }
2017 Ok((cross, gram))
2018}
2019
2020pub(crate) fn positive_spectral_whitener_from_gram(
2021 gram: &Array2<f64>,
2022) -> Result<Array2<f64>, BasisError> {
2023 // Inverse-square-root for the positive part of `gram`. Eigenvalues at or
2024 // below the relative rank tolerance `α·ε·n·max_eval` are *dropped*: the
2025 // returned whitener has shape `(n × keep)` where `keep` counts strictly
2026 // positive eigendirections of `gram`.
2027 //
2028 // Dropping (rather than ridging) is what makes the result a true
2029 // square-root inverse on the column space of `gram`. This whitener is
2030 // used by `stabilized_orthogonality_transform_from_gram` to make a
2031 // pre-existing transform `K_raw` orthonormal under the W-inner product:
2032 // when some columns of `K_raw` map to zero (or near-zero) under `B`, the
2033 // constrained Gram `K_raw^T G K_raw` is rank-deficient. Ridging those
2034 // tail directions with `1/sqrt(ε)` produced spurious basis columns
2035 // whose coefficient norms blew up to `~1/sqrt(ε)` while their image in
2036 // `B` was floating-point zero, contaminating downstream linear algebra
2037 // (in particular it forced `smooth.rs` to widen the post-transform
2038 // orthogonality residual tolerance to absorb a `cond ≈ 1/sqrt(ε)`
2039 // rounding floor). Dropping these directions is the right behavior:
2040 // they contribute nothing to `B`'s column space, and removing them
2041 // tightens the orthogonality residual back down to the genuine
2042 // floating-point limit.
2043 let (eigenvalues, eigenvectors) = gram.eigh(Side::Lower).map_err(BasisError::LinalgError)?;
2044 let n = gram.nrows();
2045 let max_eval = eigenvalues.iter().copied().fold(0.0_f64, f64::max);
2046 let tol = (default_rrqr_rank_alpha() * f64::EPSILON * (n.max(1) as f64) * max_eval.max(1.0))
2047 .max(f64::EPSILON);
2048 let keep = eigenvalues.iter().filter(|&&ev| ev > tol).count();
2049 if keep == 0 {
2050 let min_ev = eigenvalues.iter().copied().fold(f64::INFINITY, f64::min);
2051 return Err(BasisError::ConstraintNullspaceCollapsed {
2052 site: "positive_spectral_whitener_from_gram",
2053 cross_rank: 0,
2054 coeff_dim: gram.nrows(),
2055 cross_frobenius: gram.iter().map(|v| v * v).sum::<f64>().sqrt(),
2056 gram_spectrum: format!(
2057 "max eigenvalue {max_eval:.3e} (min {min_ev:.3e}, spectral tolerance {tol:.3e})"
2058 ),
2059 });
2060 }
2061 // `eigh` returns eigenvalues in ascending order, so the largest `keep`
2062 // eigenvalues live at the tail.
2063 let eig_start = eigenvalues.len() - keep;
2064 let kept_vectors = eigenvectors.slice(s![.., eig_start..]).to_owned();
2065 let mut inv_sqrt = Array2::<f64>::zeros((keep, keep));
2066 for (out_i, eig_i) in (eig_start..eigenvalues.len()).enumerate() {
2067 inv_sqrt[[out_i, out_i]] = 1.0 / eigenvalues[eig_i].sqrt();
2068 }
2069 Ok(fast_ab(&kept_vectors, &inv_sqrt))
2070}
2071
2072pub(crate) fn stabilized_orthogonality_transform_from_gram(
2073 gram: &Array2<f64>,
2074 transform: &Array2<f64>,
2075) -> Result<Array2<f64>, BasisError> {
2076 let constrained_gram = {
2077 let gt = fast_ab(gram, transform);
2078 fast_atb(transform, >)
2079 };
2080 let whitening = positive_spectral_whitener_from_gram(&constrained_gram)?;
2081 Ok(fast_ab(transform, &whitening))
2082}
2083
2084pub(crate) fn orthogonality_transform_from_cross_and_gram(
2085 constraint_cross: &Array2<f64>,
2086 gram: &Array2<f64>,
2087) -> Result<Array2<f64>, BasisError> {
2088 // Compute null(M^T) directly on M = B^T W C (k × q) via column-pivoted QR.
2089 // Working in the original k-dim coefficient space rather than first
2090 // whitening by B^T B avoids a fundamental failure mode: when B is heavily
2091 // collinear, `positive_spectral_whitener_from_gram` truncates the design
2092 // column-space to a `keep`-dim subspace, and if `keep <= q` the subsequent
2093 // nullspace search has no room — even though dim null(M^T) = k - rank(M)
2094 // ≥ k - q is always positive when k > q. The constraint nullspace is a
2095 // property of M alone; conditioning of the design only matters for the
2096 // downstream stabilization of B*K_raw.
2097 let k = constraint_cross.nrows();
2098 if k == 0 {
2099 return Err(BasisError::InsufficientColumnsForConstraint { found: 0 });
2100 }
2101 let (transform_raw, rank) = rrqr_nullspace_basis(constraint_cross, default_rrqr_rank_alpha())
2102 .map_err(BasisError::LinalgError)?;
2103 if rank >= k || transform_raw.ncols() == 0 {
2104 return Err(BasisError::ConstraintNullspaceCollapsed {
2105 site: "orthogonality_transform_from_cross_and_gram",
2106 cross_rank: rank,
2107 coeff_dim: k,
2108 cross_frobenius: constraint_cross.iter().map(|v| v * v).sum::<f64>().sqrt(),
2109 gram_spectrum: "not computed (structural cross-rank collapse: null(Mᵀ) is empty, \
2110 so no constrained design exists to eigendecompose)"
2111 .to_string(),
2112 });
2113 }
2114
2115 // Make the constrained design B*K_raw orthonormal under the W-inner product.
2116 // If the constrained Gram K_raw^T G K_raw is rank-deficient (because some
2117 // directions in null(M^T) collapse under B), the spectral whitener drops
2118 // them — that is the right behavior: a degenerate column never contributes
2119 // to B's column space and shouldn't appear in the reparameterized basis.
2120 stabilized_orthogonality_transform_from_gram(gram, &transform_raw)
2121}
2122
2123pub fn orthogonality_transform_for_design(
2124 design: &DesignMatrix,
2125 constraint_matrix: ArrayView2<'_, f64>,
2126 weights: Option<ArrayView1<'_, f64>>,
2127) -> Result<Array2<f64>, BasisError> {
2128 let k = design.ncols();
2129 if k == 0 {
2130 return Err(BasisError::InsufficientColumnsForConstraint { found: 0 });
2131 }
2132 let q = constraint_matrix.ncols();
2133 if q == 0 {
2134 return Ok(Array2::eye(k));
2135 }
2136 let (constraint_cross, gram) = design_cross_and_gram(design, constraint_matrix, weights)?;
2137 orthogonality_transform_from_cross_and_gram(&constraint_cross, &gram)
2138}