Skip to main content

gam_terms/smooth/
term_specs.rs

1use coefficient_transforms::{
2    convex_divided_difference_transform_matrix, cumulative_exp, cumulative_sum_transform_matrix,
3    second_cumulative_exp,
4};
5
6pub use error::SmoothError;
7
8use input_standardization::{
9    apply_input_standardization, compensate_length_scale_for_standardization,
10    compensate_optional_length_scale_for_standardization, compute_spatial_input_scales,
11};
12
13use shape_constraints::{
14    build_shape_constraint_design_1d, build_shape_linear_constraints_1d,
15    merge_linear_constraints_global, shape_lower_bounds_local, shape_order_and_sign,
16    shape_supports_basis, shape_uses_box_reparameterization,
17};
18
19pub fn describe_thin_plate_center_request(strategy: &CenterStrategy) -> String {
20    match strategy {
21        CenterStrategy::Auto(inner) => describe_thin_plate_center_request(inner),
22        CenterStrategy::UserProvided(centers) => format!("{} centers", centers.nrows()),
23        CenterStrategy::EqualMass { num_centers }
24        | CenterStrategy::EqualMassCovarRepresentative { num_centers }
25        | CenterStrategy::FarthestPoint { num_centers }
26        | CenterStrategy::KMeans { num_centers, .. } => format!("{num_centers} centers"),
27        CenterStrategy::UniformGrid { points_per_dim } => {
28            format!("uniform grid with {points_per_dim} points per dimension")
29        }
30    }
31}
32
33pub fn rewrite_thin_plate_knots_error(
34    err: BasisError,
35    termname: &str,
36    feature_count: usize,
37    spec: &ThinPlateBasisSpec,
38) -> BasisError {
39    match err {
40        // Polynomial-nullspace shortfall reported directly by the kernel
41        // builder ("thin-plate spline requires at least N centers to span ...").
42        BasisError::InvalidInput(msg)
43            if msg.contains("thin-plate spline requires at least")
44                && (msg.contains("centers to span") || msg.contains("knots to span")) =>
45        {
46            let min_centers = crate::basis::thin_plate_polynomial_basis_dimension(feature_count);
47            let requested = describe_thin_plate_center_request(&spec.center_strategy);
48            BasisError::InvalidInput(format!(
49                "joint TPS term '{termname}' over {feature_count} covariates with {requested} is invalid; minimum centers is {min_centers}"
50            ))
51        }
52        // Insufficient-rows shortfall raised by `select_thin_plate_knots` when
53        // the requested center count exceeds the available row count. Rewrite
54        // it in term language so the diagnostic points at the smooth term and
55        // the polynomial-nullspace minimum the user needs to satisfy.
56        BasisError::InvalidInput(msg)
57            if msg.starts_with("requested ") && msg.contains(" knots but only ") =>
58        {
59            let min_centers = crate::basis::thin_plate_polynomial_basis_dimension(feature_count);
60            let requested = describe_thin_plate_center_request(&spec.center_strategy);
61            BasisError::InvalidInput(format!(
62                "joint TPS term '{termname}' over {feature_count} covariates with {requested} is invalid; minimum centers is {min_centers}"
63            ))
64        }
65        other => other,
66    }
67}
68
69#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
70pub enum ShapeConstraint {
71    None,
72    MonotoneIncreasing,
73    MonotoneDecreasing,
74    Convex,
75    Concave,
76}
77
78/// Parse a shape-constraint string into a [`ShapeConstraint`].
79///
80/// This is the single source of truth shared by the formula DSL
81/// (`s(x, shape=...)`) and the `smooths={...}` override path
82/// (`Smooth.shape_constraint`). The accepted spellings cover the canonical
83/// Python `ShapeConstraintLiteral` strings exactly
84/// (`"none"` / `"monotone_increasing"` / `"monotone_decreasing"` /
85/// `"convex"` / `"concave"`) plus a few common aliases. Hyphens and case are
86/// normalized, so `"Monotone-Increasing"` and `"mono_inc"` both resolve to
87/// [`ShapeConstraint::MonotoneIncreasing`].
88pub fn parse_shape_constraint(raw: &str) -> Result<ShapeConstraint, String> {
89    let normalized = raw.trim().to_ascii_lowercase().replace('-', "_");
90    match normalized.as_str() {
91        "" | "none" => Ok(ShapeConstraint::None),
92        "monotone_increasing" | "monotonic_increasing" | "increasing" | "mono_inc" | "mpi" => {
93            Ok(ShapeConstraint::MonotoneIncreasing)
94        }
95        "monotone_decreasing" | "monotonic_decreasing" | "decreasing" | "mono_dec" | "mpd" => {
96            Ok(ShapeConstraint::MonotoneDecreasing)
97        }
98        "convex" | "cvx" => Ok(ShapeConstraint::Convex),
99        "concave" | "ccv" => Ok(ShapeConstraint::Concave),
100        other => Err(format!(
101            "unknown shape constraint {other:?}; expected one of \
102             \"none\", \"monotone_increasing\", \"monotone_decreasing\", \
103             \"convex\", \"concave\""
104        )),
105    }
106}
107
108impl ShapeConstraint {
109    /// Canonical formula-DSL spelling, i.e. the text emitted into
110    /// `s(x, shape=...)`. Round-trips through [`parse_shape_constraint`].
111    pub fn dsl_str(&self) -> &'static str {
112        match self {
113            ShapeConstraint::None => "none",
114            ShapeConstraint::MonotoneIncreasing => "monotone_increasing",
115            ShapeConstraint::MonotoneDecreasing => "monotone_decreasing",
116            ShapeConstraint::Convex => "convex",
117            ShapeConstraint::Concave => "concave",
118        }
119    }
120}
121
/// Head keywords that introduce a smooth term in the formula DSL.
///
/// A `shape=` option may be attached to any term whose call head is one of
/// these words (e.g. `s(x, shape=convex)`).
pub const SMOOTH_HEAD_KEYWORDS: [&str; 11] = [
    // Generic and tensor-product heads.
    "s", "smooth", "te", "tensor",
    // Radial / kernel heads.
    "thinplate", "tps", "duchon", "matern", "sphere",
    // B-spline heads.
    "bs", "bspline",
];
137
138/// Rewrite smooth-term calls in `formula` so each named smooth carries a
139/// `shape=<kind>` option understood by the formula DSL.
140///
141/// `constraints` pairs the smooth-term text as it appears in the formula
142/// (e.g. `"s(x)"` or `"s(x, type=duchon, centers=8)"`) with a shape-constraint
143/// spelling accepted by [`parse_shape_constraint`]; comparison is exact after
144/// whitespace removal. A `"none"` constraint is a no-op. Referencing a term not
145/// present in the formula is an error.
146///
147/// This is the single source of truth for the `gamfit.fit(..., constraints=…)`
148/// rewrite — the Python wrapper only marshals the mapping across the FFI and
149/// holds no formula-parsing or alias-normalization logic of its own.
150pub fn apply_shape_constraints_to_formula(
151    formula: &str,
152    constraints: &[(String, String)],
153) -> Result<String, String> {
154    use std::collections::{BTreeMap, BTreeSet};
155
156    if constraints.is_empty() {
157        return Ok(formula.to_string());
158    }
159    let strip_ws = |s: &str| -> String { s.chars().filter(|c| !c.is_whitespace()).collect() };
160
161    // Whitespace-stripped term text -> canonical shape spelling.
162    let mut wanted: BTreeMap<String, &'static str> = BTreeMap::new();
163    // Whitespace-stripped term text -> original key (for error labels).
164    let mut originals: BTreeMap<String, String> = BTreeMap::new();
165    for (key, kind_raw) in constraints {
166        let kind = parse_shape_constraint(kind_raw)?;
167        let nk = strip_ws(key);
168        originals.entry(nk.clone()).or_insert_with(|| key.clone());
169        if kind != ShapeConstraint::None {
170            wanted.insert(nk, kind.dsl_str());
171        }
172    }
173    if wanted.is_empty() {
174        return Ok(formula.to_string());
175    }
176
177    let chars: Vec<char> = formula.chars().collect();
178    let n = chars.len();
179    let is_ident = |c: char| c.is_ascii_alphanumeric() || c == '_';
180
181    let mut out = String::with_capacity(formula.len() + 32);
182    let mut matched: BTreeSet<String> = BTreeSet::new();
183    let mut i = 0usize;
184    while i < n {
185        // Locate the next smooth-term head (`<keyword> \s* (`) at or after `i`,
186        // respecting word boundaries so `abs(` never matches the `s(` head.
187        let mut head: Option<(usize, usize)> = None; // (head_start, paren_index)
188        let mut p = i;
189        while p < n {
190            let boundary = p == 0 || !is_ident(chars[p - 1]);
191            if boundary {
192                for kw in SMOOTH_HEAD_KEYWORDS.iter() {
193                    let klen = kw.chars().count();
194                    if p + klen > n || chars[p..p + klen].iter().collect::<String>() != **kw {
195                        continue;
196                    }
197                    let mut q = p + klen;
198                    while q < n && chars[q].is_whitespace() {
199                        q += 1;
200                    }
201                    if q < n && chars[q] == '(' {
202                        head = Some((p, q));
203                        break;
204                    }
205                }
206            }
207            if head.is_some() {
208                break;
209            }
210            p += 1;
211        }
212        let (head_start, paren_open) = match head {
213            Some(h) => h,
214            None => {
215                out.extend(chars[i..].iter());
216                break;
217            }
218        };
219        out.extend(chars[i..head_start].iter());
220
221        // Find the matching close paren, honoring nesting and string literals.
222        let body_start = paren_open + 1;
223        let mut depth = 1i32;
224        let mut j = body_start;
225        let mut in_str: Option<char> = None;
226        let mut closed = false;
227        while j < n {
228            let ch = chars[j];
229            if let Some(quote) = in_str {
230                if ch == quote {
231                    in_str = None;
232                }
233            } else if ch == '\'' || ch == '"' {
234                in_str = Some(ch);
235            } else if ch == '(' {
236                depth += 1;
237            } else if ch == ')' {
238                depth -= 1;
239                if depth == 0 {
240                    closed = true;
241                    break;
242                }
243            }
244            j += 1;
245        }
246
247        if !closed {
248            // Unbalanced — emit the remainder verbatim; the DSL parser will
249            // produce the canonical error.
250            out.extend(chars[head_start..].iter());
251            break;
252        }
253
254        let term_text: String = chars[head_start..=j].iter().collect();
255
256        let key_norm = strip_ws(&term_text);
257
258        match wanted.get(&key_norm) {
259            None => out.extend(chars[head_start..=j].iter()),
260            Some(kind) => {
261                let head_paren: String = chars[head_start..body_start].iter().collect();
262                let inside: String = chars[body_start..j].iter().collect();
263                let inside = inside.trim();
264                if inside.is_empty() {
265                    out.push_str(&format!("{head_paren}shape={kind})"));
266                } else {
267                    out.push_str(&format!("{head_paren}{inside}, shape={kind})"));
268                }
269                matched.insert(key_norm);
270            }
271        }
272
273        i = j + 1;
274    }
275
276    let mut missing: Vec<String> = wanted
277        .keys()
278        .filter(|k| !matched.contains(*k))
279        .map(|k| originals.get(k).cloned().unwrap_or_else(|| k.clone()))
280        .collect();
281
282    if !missing.is_empty() {
283        missing.sort();
284        return Err(format!(
285            "shape constraints referenced smooth term(s) not found in formula: {}",
286            missing.join(", ")
287        ));
288    }
289
290    Ok(out)
291}
292
293#[derive(Debug, Clone, Serialize, Deserialize)]
294pub enum BySmoothKind {
295    Numeric,
296    Level { level_bits: u64 },
297}
298
299#[derive(Debug, Clone, Serialize, Deserialize)]
300pub enum SmoothBasisSpec {
301    /// Row-gated wrapper used for mgcv-style ``by=`` smooths.
302    ///
303    /// ``ByNumeric`` multiplies the inner smooth by a numeric column.
304    /// ``ByLevel`` keeps the inner smooth active only for rows whose encoded
305    /// categorical value has the stored bit pattern.  Unordered factor-by
306    /// smooths are represented as one independent ``ByLevel`` term per level.
307    ///
308    /// `kind` preserves the compact structural discriminator, while `by`
309    /// carries the full row-gating spec used to build the local design.
310    ByVariable {
311        inner: Box<SmoothBasisSpec>,
312        by_col: usize,
313        kind: BySmoothKind,
314        by: ByVariableSpec,
315    },
316    /// Sum-to-zero factor smooth (`bs="sz"`): with L levels, estimate L-1
317    /// deviation coefficient blocks and use the final level as the negative
318    /// sum of the others, enforcing coefficient-wise zero sums across levels.
319    FactorSumToZero {
320        inner: Box<SmoothBasisSpec>,
321        by_col: usize,
322        levels: Vec<u64>,
323        /// Global-orthogonality column map `Z` captured at fit time when this
324        /// term overlapped an owner smooth (`s(x) + s(g, x, bs=sz)`, #978):
325        /// the hierarchical-ownership pass residualized this term's realized
326        /// design as `X ← X·Z`, shrinking its coefficient block. `Z` depends
327        /// on the *training-row* owner designs, so prediction cannot rederive
328        /// it — it must be persisted and replayed
329        /// (`apply_global_smooth_identifiability` consumes it verbatim).
330        /// Chart convention: `Z` lives in the post-restack, post-joint-null-Q
331        /// coordinates — the raw `sz` rebuild reapplies `Q` deterministically
332        /// (#700), then `Z` applies on top. `None` for non-overlapping terms.
333        #[serde(default)]
334        frozen_global_orthogonality: Option<Array2<f64>>,
335    },
336    BSpline1D {
337        feature_col: usize,
338        spec: BSplineBasisSpec,
339    },
340    /// A smooth modulated by a `by=` variable. Numeric `by` scales one inner
341    /// smooth; factor `by` replicates the inner smooth by level.
342    BySmooth {
343        smooth: Box<SmoothBasisSpec>,
344        by_kind: ByVarKind,
345    },
346    /// Factor-smooth interaction families (`bs="fs"`, `bs="sz"`) and
347    /// random slopes (`bs="re"`).
348    FactorSmooth { spec: FactorSmoothSpec },
349    ThinPlate {
350        feature_cols: Vec<usize>,
351        spec: ThinPlateBasisSpec,
352        /// Per-column standard deviations used to standardize input dimensions
353        /// before kernel evaluation when d > 1. `None` means no standardization
354        /// (either d == 1 or explicitly disabled).
355        #[serde(default)]
356        input_scales: Option<Vec<f64>>,
357    },
358    Sphere {
359        feature_cols: Vec<usize>,
360        spec: SphericalSplineBasisSpec,
361    },
362    /// Constant-curvature (`M_κ`) geodesic-kernel smooth over κ-stereographic
363    /// chart coordinates (#944): one construction interpolating
364    /// S^d → ℝ^d → H^d through the spec's fixed κ. The Wahba S² smooth is the
365    /// structural template; the geometry comes from
366    /// `geometry::constant_curvature::ConstantCurvature`.
367    ConstantCurvature {
368        feature_cols: Vec<usize>,
369        spec: ConstantCurvatureBasisSpec,
370    },
371    Matern {
372        feature_cols: Vec<usize>,
373        spec: MaternBasisSpec,
374        #[serde(default)]
375        input_scales: Option<Vec<f64>>,
376    },
377    /// Measure-jet spline smooth: multiscale local-jet-residual energy of the
378    /// empirical measure (centers as μ-quadrature, masses as μ-weights — no
379    /// graph, mesh, or neighbor set inside the statistical object). The
380    /// feature columns are ambient coordinates of data concentrated near an
381    /// unknown low-dimensional, possibly stratified set.
382    MeasureJet {
383        feature_cols: Vec<usize>,
384        spec: MeasureJetBasisSpec,
385        #[serde(default)]
386        input_scales: Option<Vec<f64>>,
387    },
388    Duchon {
389        feature_cols: Vec<usize>,
390        spec: DuchonBasisSpec,
391        #[serde(default)]
392        input_scales: Option<Vec<f64>>,
393    },
394    Pca {
395        feature_cols: Vec<usize>,
396        basis_matrix: Array2<f64>,
397        centered: bool,
398        #[serde(default = "default_pca_smooth_penalty")]
399        smooth_penalty: f64,
400        #[serde(default)]
401        center_mean: Option<Array1<f64>>,
402        #[serde(default)]
403        pca_basis_path: Option<PathBuf>,
404        #[serde(default = "default_pca_chunk_size")]
405        chunk_size: usize,
406    },
407    /// Tensor-product smooth built from 1D B-spline marginals.
408    ///
409    /// This is the `te()`-style construction used when axes have different units/scales
410    /// (for example, space x time) and isotropic radial kernels are not appropriate.
411    TensorBSpline {
412        feature_cols: Vec<usize>,
413        spec: TensorBSplineSpec,
414    },
415}
416
417impl SmoothBasisSpec {
418    /// Conservative lower bound on the number of sample rows needed for this
419    /// smooth basis to have a well-posed REML fit.
420    ///
421    /// Each basis kind answers the question for itself, so the workflow does
422    /// not have to know how many columns a B-spline, tensor product, PCA
423    /// projection, or spatial kernel emits. The contract is a *lower bound*:
424    /// returning too small a number is permitted (the inner solver will catch
425    /// any genuine n-vs-rank failure that slips past); returning too large a
426    /// number is a regression because it rejects legitimate fits.
427    ///
428    /// Rationale: B-spline / tensor / PCA bases have a closed-form column
429    /// count, so we use the exact dimension. Radial bases (TPS, Matern,
430    /// Duchon, Sphere) and factor smooths choose their column count from the
431    /// data (`heuristic_centers`, `unique_count`); we fall back to a small
432    /// constant floor because a fit on fewer than five rows cannot stabilise
433    /// any radial smooth regardless of the configured kernel scale.
434    pub fn min_sample_rows(&self) -> usize {
435        // Floor used for data-driven bases whose column count is not known
436        // from the spec alone. Five rows is the minimum at which the inner
437        // pivot/QR + REML smoothing-parameter search has any chance of being
438        // well-posed for a non-parametric smooth.
439        const RADIAL_FLOOR: usize = 5;
440
441        match self {
442            Self::ByVariable { inner, .. } => inner.min_sample_rows(),
443            Self::FactorSumToZero { inner, levels, .. } => {
444                // L-1 independent deviation blocks each carrying the inner
445                // basis dimension. Skip the levels-multiplier if it doesn't
446                // bring more rows; we want the *lower bound* not the rank.
447                let inner_min = inner.min_sample_rows();
448                let lvls = levels.len().saturating_sub(1).max(1);
449                inner_min.saturating_mul(lvls)
450            }
451            Self::BSpline1D { spec, .. } => bspline_basis_min_rows(spec),
452            Self::BySmooth { smooth, .. } => smooth.min_sample_rows(),
453            Self::FactorSmooth { spec } => {
454                // Replicates the marginal once per level; without a known
455                // level count we conservatively require at least the marginal
456                // basis dimension.
457                bspline_basis_min_rows(&spec.marginal)
458            }
459            Self::ThinPlate { .. }
460            | Self::Sphere { .. }
461            | Self::ConstantCurvature { .. }
462            | Self::Matern { .. }
463            | Self::MeasureJet { .. }
464            | Self::Duchon { .. } => RADIAL_FLOOR,
465            Self::Pca { basis_matrix, .. } => basis_matrix.ncols().max(1),
466            Self::TensorBSpline { spec, .. } => {
467                // A `te(...)` smooth is *penalized*: each margin carries a
468                // difference (wiggliness) penalty and the tensor inherits a
469                // Kronecker-sum penalty `S = Σ_i I ⊗ … ⊗ S_i ⊗ … ⊗ I`. The raw
470                // column count is the *product* of the per-marginal column
471                // counts, but that product is the lower bound for an
472                // *unpenalized* tensor regression — it is the number of rows you
473                // would need to identify every interaction column with no
474                // regularization. The penalty regularizes all of those
475                // interaction directions; only the combined penalty *null space*
476                // (the tensor product of the per-margin polynomial trends, a
477                // handful of columns) must be identified by the data, and the
478                // smoothing-parameter search shrinks the rest. The effective
479                // degrees of freedom of the fitted `te()` are therefore a small
480                // fraction of the column product, which is exactly why mgcv
481                // fits a default `te(x, y)` on a couple hundred rows.
482                //
483                // The honest *penalized* lower bound is the **sum** of the
484                // per-marginal column counts, not their product: a row floor of
485                // `Σ_i k_i` still guarantees enough data to identify each
486                // margin's additive main-effect (the largest sub-block the
487                // penalty cannot shrink to zero), while no longer conflating
488                // unpenalized column-count identifiability with penalized
489                // well-posedness. This accepts moderate-`n` penalized tensors
490                // (e.g. a 20×20 default basis on n=200) yet still rejects a
491                // genuinely undersized fit where `n < Σ_i k_i` and even the
492                // additive part is rank-deficient.
493                //
494                // Binary / low-cardinality margins (#724): gam will accept a
495                // `te(x, badh)` whose `badh ∈ {0, 1}` margin nominally requests
496                // more basis columns than `badh` has unique values, where mgcv
497                // refuses the unpenalized term as ill-posed ("badh has
498                // insufficient unique values to support k knots"). This is
499                // correct-by-design, *not* a degenerate fit: the marginal
500                // wiggliness penalty on the `badh` axis has a null space that is
501                // exactly its identifiable trend (the two cell means of a binary
502                // covariate), and the Kronecker-sum penalty shrinks every tensor
503                // column outside that null space toward zero. The resulting fit
504                // is the well-posed "per-level `x` smooth + binary main effect"
505                // that mgcv reaches only after manually collapsing the basis —
506                // gam reaches it automatically because the penalty, not the raw
507                // column count, sets the effective rank. A genuinely
508                // rank-deficient design (penalty null space wider than the data
509                // can support) is still caught downstream by the inner pivoted
510                // factorization, which owns the exact n-vs-rank decision; this
511                // pre-fit gate only refuses the grossly-undersized formula.
512                let mut total: usize = 0;
513                for marginal in &spec.marginalspecs {
514                    let m = bspline_basis_min_rows(marginal);
515                    total = total.saturating_add(m.max(1));
516                }
517                total.max(RADIAL_FLOOR)
518            }
519        }
520    }
521
522    /// Stable structural discriminant for warm-start cache keying (#869).
523    ///
524    /// Two smooths that produce different bases / penalty structures must map
525    /// to different strings here so they cannot collide on the persistent
526    /// warm-start `cache_key` (which is otherwise blind to topology: it hashes
527    /// only the raw input column count, so e.g. `sphere` vs `torus` vs
528    /// `euclidean` candidates fit on the *same* data would otherwise share one
529    /// key and cross-contaminate each other's β/ρ seed). The string is the
530    /// topology identity, not the fitted coefficients, so same-topology refits
531    /// (the screen→full-refit cascade) still hit the same key and reuse work.
532    pub fn structural_kind(&self) -> &'static str {
533        match self {
534            Self::ByVariable { .. } => "by_variable",
535            Self::FactorSumToZero { .. } => "factor_sum_to_zero",
536            Self::BSpline1D { .. } => "bspline_1d",
537            Self::BySmooth { .. } => "by_smooth",
538            Self::FactorSmooth { .. } => "factor_smooth",
539            Self::ThinPlate { .. } => "thin_plate",
540            Self::Sphere { .. } => "sphere",
541            Self::ConstantCurvature { .. } => "constant_curvature",
542            Self::Matern { .. } => "matern",
543            Self::MeasureJet { .. } => "measurejet",
544            Self::Duchon { .. } => "duchon",
545            Self::Pca { .. } => "pca",
546            Self::TensorBSpline { .. } => "tensor_bspline",
547        }
548    }
549
550    /// True for a tensor-product smooth that is only *marginally* centered
551    /// (`ti(...)`, [`TensorBSplineIdentifiability::MarginalSumToZero`]): its
552    /// per-margin sum-to-zero reparameterization `(B_xZ_x)⊗(B_zZ_z)` has ALREADY
553    /// removed each axis's main effect analytically (mgcv-identical), so its
554    /// main-effect removal is complete and it must take NO additional
555    /// owner-residualization block. Residualizing it a second time against the
556    /// realized main-effect designs is a grid-fragile no-op on an exact tensor
557    /// grid but eats genuine pure-interaction curvature off-grid (#1470).
558    pub fn is_marginally_centered_tensor(&self) -> bool {
559        matches!(
560            self,
561            Self::TensorBSpline { spec, .. }
562                if matches!(spec.identifiability, TensorBSplineIdentifiability::MarginalSumToZero)
563        )
564    }
565
566    /// Feature columns this basis consumes, used alongside [`structural_kind`]
567    /// to disambiguate two same-kind smooths on different axes. Wrapper
568    /// variants delegate to their inner basis.
569    pub fn structural_feature_cols(&self) -> Vec<usize> {
570        match self {
571            Self::ByVariable { inner, .. } | Self::FactorSumToZero { inner, .. } => {
572                inner.structural_feature_cols()
573            }
574            Self::BySmooth { smooth, .. } => smooth.structural_feature_cols(),
575            Self::FactorSmooth { .. } => Vec::new(),
576            Self::BSpline1D { feature_col, .. } => vec![*feature_col],
577            Self::ThinPlate { feature_cols, .. }
578            | Self::Sphere { feature_cols, .. }
579            | Self::ConstantCurvature { feature_cols, .. }
580            | Self::Matern { feature_cols, .. }
581            | Self::MeasureJet { feature_cols, .. }
582            | Self::Duchon { feature_cols, .. }
583            | Self::Pca { feature_cols, .. }
584            | Self::TensorBSpline { feature_cols, .. } => feature_cols.clone(),
585        }
586    }
587}
588
589/// Lower bound on the number of sample rows a 1D B-spline smooth needs for a
590/// well-posed *penalized* REML fit. Used as the per-smooth row floor in
591/// [`SmoothBasisSpec::min_sample_rows`].
592///
593/// For a *singly*-penalized smooth the floor is the full column count: the
594/// wiggliness penalty leaves the order-`m` polynomial trend unpenalized, and
595/// gam's original gate conservatively required enough rows for the whole basis.
596/// That conservative floor is kept here unchanged.
597///
598/// A *double*-penalized smooth (mgcv `select=TRUE`) is different: it adds a
599/// second penalty on the wiggliness penalty's null space, so even the
600/// polynomial trend is shrinkable toward zero and *nothing* in the basis
601/// requires unpenalized identification by the data — exactly the reasoning the
602/// `TensorBSpline` arm of [`SmoothBasisSpec::min_sample_rows`] already applies
603/// to a penalized tensor. Its honest floor is therefore a small stabilization
604/// constant, not the column count. This is what lets mgcv (and now gam) fit
605/// several `select=TRUE` smooths on a dataset whose row count is below the
606/// summed basis width (e.g. the n≈30 `wine_gamair` fold, 5 `ps` smooths,
607/// p≈51): the penalties, not the data, set the effective rank. The bounded
608/// outer REML loop still terminates, and the genuine n-vs-rank decision is
609/// owned downstream by the inner pivoted factorization. Without this, gam
610/// rejected the fit outright (or, before the gate existed, the outer REML loop
611/// wandered the flat overparameterized surface until the benchmark wall budget
612/// killed it — #1089).
613pub fn bspline_basis_min_rows(spec: &crate::basis::BSplineBasisSpec) -> usize {
614    use crate::basis::BSplineKnotSpec;
615    let columns = match &spec.knotspec {
616        BSplineKnotSpec::Generate {
617            num_internal_knots, ..
618        } => *num_internal_knots + spec.degree + 1,
619        BSplineKnotSpec::Automatic {
620            num_internal_knots: Some(k),
621            ..
622        } => *k + spec.degree + 1,
623        BSplineKnotSpec::Automatic {
624            num_internal_knots: None,
625            ..
626        } => {
627            // Knot count is data-derived (`default_internal_knot_count_for_data`).
628            // A minimal cubic basis is `degree + 2` columns; below that the
629            // basis cannot represent a non-parametric smooth.
630            spec.degree + 2
631        }
632        BSplineKnotSpec::Provided(knots) => knots.len().saturating_sub(spec.degree + 1).max(1),
633        // cr basis dimension equals the knot count (no degree offset).
634        BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
635        BSplineKnotSpec::PeriodicUniform { num_basis, .. } => *num_basis,
636    };
637    let columns = columns.max(spec.degree + 2);
638
639    if spec.double_penalty {
640        // Fully shrinkable basis: only a small stabilization floor must be
641        // identified by the data, capped by the actual column count.
642        const DOUBLE_PENALTY_FLOOR: usize = 2;
643        DOUBLE_PENALTY_FLOOR.min(columns).max(1)
644    } else {
645        columns
646    }
647}
648
649#[derive(Debug, Clone, Serialize, Deserialize)]
650pub enum ByVariableSpec {
651    Numeric,
652    Level { value_bits: u64, label: String },
653}
654
655
656#[derive(Debug, Clone, Serialize, Deserialize)]
657pub enum ByVarKind {
658    Numeric {
659        feature_col: usize,
660    },
661    Factor {
662        feature_col: usize,
663        ordered: bool,
664        frozen_levels: Option<Vec<u64>>,
665    },
666}
667
/// Specification of a factor-smooth interaction term: a smooth of the
/// continuous covariate(s) combined with a grouping factor, built in one of
/// the [`FactorSmoothFlavour`] variants.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FactorSmoothSpec {
    /// Data columns holding the continuous covariate(s).
    pub continuous_cols: Vec<usize>,
    /// Data column holding the grouping factor.
    pub group_col: usize,
    /// B-spline basis specification for the continuous margin.
    pub marginal: BSplineBasisSpec,
    /// Which factor-smooth construction to use.
    pub flavour: FactorSmoothFlavour,
    /// NOTE(review): presumably the group-level set (as `f64` bit patterns)
    /// frozen at fit time so prediction reproduces the same columns — confirm.
    pub group_frozen_levels: Option<Vec<u64>>,
    /// Fit-time global-orthogonality chart `Z` for this term (`s(x) + fs(x, g)`
    /// overlap residualization, #978), in the post-joint-null-`Q` coordinates
    /// (the raw rebuild recomputes any `Q` itself; `fs` penalties are
    /// typically full-rank so `Q` is absent). Training-row dependent, hence
    /// persisted; replayed verbatim by `apply_global_smooth_identifiability`.
    #[serde(default)]
    pub frozen_global_orthogonality: Option<Array2<f64>>,
}
683
/// Construction variant for a [`FactorSmoothSpec`].
///
/// NOTE(review): the variant names match mgcv's `bs = "fs"`, `"sz"` and `"re"`
/// factor-smooth bases; confirm the exact correspondence against the builder.
/// `Fs` carries `m_null_penalty_orders`, presumably the derivative orders of
/// the extra penalties placed on the marginal penalty null space.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FactorSmoothFlavour {
    Fs { m_null_penalty_orders: Vec<usize> },
    Sz,
    Re,
}
690
/// Specification of a tensor-product B-spline smooth. The identifiability
/// scheme (`te`/`ti`-style) and penalty layout (`te`/`t2`-style) are chosen
/// by `identifiability` and `penalty_decomposition`.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct TensorBSplineSpec {
    /// One B-spline basis specification per tensor margin.
    pub marginalspecs: Vec<BSplineBasisSpec>,
    /// Optional period per margin (`None` = non-periodic). Empty when absent
    /// from serialized input. NOTE(review): presumably indexed in parallel
    /// with `marginalspecs` — confirm.
    #[serde(default)]
    pub periods: Vec<Option<f64>>,
    /// NOTE(review): presumably requests an extra null-space shrinkage
    /// penalty so the whole term can be shrunk away — confirm in the builder.
    pub double_penalty: bool,
    /// Identifiability constraint; defaults to [`TensorBSplineIdentifiability::SumToZero`].
    #[serde(default)]
    pub identifiability: TensorBSplineIdentifiability,
    /// Penalty layout; defaults to
    /// [`TensorBSplinePenaltyDecomposition::MarginalKroneckerSum`].
    #[serde(default)]
    pub penalty_decomposition: TensorBSplinePenaltyDecomposition,
}
702
/// Identifiability constraint applied to a tensor-product B-spline smooth.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub enum TensorBSplineIdentifiability {
    /// No identifiability constraint.
    None,
    /// Default. NOTE(review): presumably a single sum-to-zero constraint on
    /// the full tensor-product basis (mgcv `te(...)` semantics) — confirm.
    #[default]
    SumToZero,
    /// mgcv `ti(...)` semantics: a *tensor interaction* smooth that excludes the
    /// marginal main effects. A sum-to-zero constraint is applied to **each
    /// marginal basis independently** before forming the tensor product, so the
    /// resulting column space contains no function of a single variable alone —
    /// only the pure interaction survives. The realized identifiability
    /// transform is the Kronecker product `Z = Z₀ ⊗ Z₁ ⊗ … ⊗ Z_{d-1}` of the
    /// per-margin sum-to-zero null-space bases, which is exactly the
    /// reparameterization that turns the full-tensor design into the tensor
    /// product of the centered margins.
    MarginalSumToZero,
    /// An explicit identifiability transform. NOTE(review): presumably
    /// captured at fit time and replayed verbatim at prediction — confirm.
    FrozenTransform {
        transform: Array2<f64>,
    },
}
722
/// How the penalties of a tensor-product B-spline smooth are built from the
/// marginal penalties.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TensorBSplinePenaltyDecomposition {
    /// mgcv `te(...)`: one overlapping Kronecker-product penalty per margin,
    /// `S_j` embedded against identities in the other tensor factors.
    #[default]
    MarginalKroneckerSum,
    /// mgcv `t2(...)`: split every marginal coefficient space into penalized
    /// range and penalty-null subspaces, then emit one disjoint tensor-subspace
    /// penalty for every non-empty penalized/null combination.
    Separable,
}
734
/// Serializable description of one smooth term: its name, basis
/// specification, shape constraint and (when captured at fit time) the
/// joint-null absorption rotation that prediction must replay.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmoothTermSpec {
    /// Term name.
    pub name: String,
    /// Basis specification used to build this term's design.
    pub basis: SmoothBasisSpec,
    /// Shape constraint requested for this term.
    pub shape: ShapeConstraint,
    /// Joint-null absorption rotation captured at fit time. `Some(Q)` means
    /// the fitted coefficient vector lives in `γ`-coordinates with
    /// `β_raw = Q · γ`; prediction must rotate the raw-basis design via
    /// `X_new = X_new_raw · Q` to match. `None` means either the smooth had
    /// no joint null space (penalty already full-rank) or rotation was
    /// suppressed (smooth carries shape constraints whose cone geometry
    /// would not survive an arbitrary orthogonal rotation). Persisted so
    /// `save → load → predict` is bit-equivalent to in-memory prediction.
    #[serde(default)]
    pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
}
751
/// A realized (built) smooth term: its coefficient layout, term-local
/// penalties, shape constraints, and the fit-time transforms that prediction
/// must replay.
#[derive(Debug, Clone)]
pub struct SmoothTerm {
    /// Term name (also used in diagnostics such as rotation-replay errors).
    pub name: String,
    /// Range of this term's coefficients. Its length is the term-local width
    /// `p_local` used by [`SmoothTerm::wald_unpenalized_dim`].
    /// NOTE(review): presumably indexes the concatenated smooth block — confirm.
    pub coeff_range: Range<usize>,
    /// Shape constraint requested for this term.
    pub shape: ShapeConstraint,
    /// Term-local penalty matrices, one per penalty, in this term's (possibly
    /// rotated) coefficient coordinates. These are usually full
    /// `p_local × p_local` blocks, but some (e.g. Kronecker tensor factors)
    /// are not materialized at that size.
    pub penalties_local: Vec<Array2<f64>>,
    /// Null-space dimension of each entry of `penalties_local`, taken
    /// individually. This is NOT the joint null space of their sum; use
    /// [`SmoothTerm::wald_unpenalized_dim`] for that.
    pub nullspace_dims: Vec<usize>,
    /// Descriptive information for each term-local penalty.
    pub penaltyinfo_local: Vec<PenaltyInfo>,
    /// Basis metadata recorded when the term was built.
    pub metadata: BasisMetadata,
    /// Optional term-local lower bounds for constrained coefficients.
    /// `-inf` means unconstrained.
    pub lower_bounds_local: Option<Array1<f64>>,
    /// Optional term-local inequality constraints in local coefficient coordinates.
    /// `A_local * beta_local >= b_local`.
    pub linear_constraints_local: Option<LinearInequalityConstraints>,
    /// Optional factored tensor-product representation preserved for operator-backed
    /// assembly in the main design builder.
    pub kronecker_factored: Option<KroneckerFactoredBasis>,
    /// Joint-null absorption rotation. `Some(Q)` records the orthonormal
    /// `(p_local × p_local)` matrix that was applied to this term's design
    /// and per-block penalties at construction time:
    /// `term_design ← X_raw · Q`, `penalties_local[k] ← Qᵀ · S_raw · Q`.
    /// The smooth's coefficient vector therefore lives in the rotated
    /// (`γ`) coordinate system, with `β_raw = Q · γ` recovering the raw
    /// pre-rotation parameterization. `None` means either no joint null
    /// space (penalty already full-rank) or rotation was suppressed —
    /// suppression fires when the smooth carries shape constraints
    /// (lower bounds or local linear inequalities) that would lose their
    /// cone geometry under a general orthogonal rotation.
    ///
    /// Prediction-side replay: callers building a new-data design `X_new_raw`
    /// from the *raw* basis must call [`SmoothTerm::apply_rotation_to_predict`]
    /// (or equivalent) to obtain `X_new = X_new_raw · Q` matching this
    /// term's coefficient system.
    ///
    /// Persistence replay: `freeze_term_collection_from_design` copies this
    /// rotation into `SmoothTermSpec`, which is serialized with fitted-model
    /// payloads and reused by the predict-time basis builder. Saved models
    /// therefore replay the same `X_new_raw · Q` transform as in-memory
    /// prediction.
    pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
    /// Global-orthogonality transform that `apply_global_smooth_identifiability`
    /// applied to this term's design but could NOT embed into `metadata`
    /// (factor-smooth kinds: `sz` metadata is per-marginal, `fs` metadata has
    /// no transform slot — #978). `freeze_term_collection_from_design` copies
    /// it onto the term's basis spec (`frozen_global_orthogonality`) so the
    /// predict-side rebuild replays it instead of emitting the unresidualized
    /// (wider) design that the fitted coefficients no longer match.
    /// Chart convention is per kind: post-`Q` `Z` for `sz` (the raw rebuild
    /// reapplies `Q` itself, #700), full `Q·Z` chart for `fs`.
    pub unabsorbed_global_orthogonality: Option<Array2<f64>>,
}
804
805impl SmoothTerm {
806    /// Apply the joint-null absorption rotation to a raw new-data design
807    /// matrix, returning `X_new_raw · Q` when this term was rotated at
808    /// fit time, or `X_new_raw` unchanged when no rotation was applied.
809    ///
810    /// Callers in the prediction path: after building the smooth's basis
811    /// at new data via the *raw* basis builder (the same builder used at
812    /// fit time, applied to `x_new` instead of the training rows), call
813    /// this method on the resulting matrix before forming `X · β`. The
814    /// fitted `β` lives in `γ`-coordinates if Q was applied; multiplying
815    /// the un-rotated `X_new_raw` by `β` would give a wrong η.
816    ///
817    /// Returns an error if the raw design's column count does not match
818    /// the rotation's `p_local`. The width invariant must hold: the raw
819    /// basis builder MUST emit the same `p_local` columns that the
820    /// fit-time builder did, and the rotation is `(p_local × p_local)`.
821    pub fn apply_rotation_to_predict(
822        &self,
823        x_new_raw: Array2<f64>,
824    ) -> Result<Array2<f64>, BasisError> {
825        let Some(rot) = self.joint_null_rotation.as_ref() else {
826            return Ok(x_new_raw);
827        };
828        let p_local = rot.rotation.nrows();
829        if x_new_raw.ncols() != p_local {
830            crate::bail_dim_basis!(
831                "joint-null rotation replay for term '{}': raw design has {} columns, \
832                 rotation expects {} (the raw basis builder must emit the same column \
833                 count as at fit time)",
834                self.name,
835                x_new_raw.ncols(),
836                p_local,
837            );
838        }
839        Ok(gam_linalg::faer_ndarray::fast_ab(
840            &x_new_raw,
841            &rot.rotation,
842        ))
843    }
844
845    /// Dimension of the **joint** null space of this term's active penalties:
846    /// the coefficient directions penalized by *no* penalty. The smooth-component
847    /// Wald test ([`crate::inference::smooth_test::wood_smooth_test`]) treats this
848    /// many leading coefficients as genuine unpenalized fixed effects and tests
849    /// them at full rank; the remainder is the penalized sub-block tested with a
850    /// rank-`≈EDF` truncated pseudo-inverse.
851    ///
852    /// Because every penalty block `S_k` is positive semi-definite,
853    /// `vᵀ(Σ_k S_k)v = Σ_k vᵀ S_k v = 0` iff `S_k v = 0` for *every* `k`; the
854    /// joint null space is therefore exactly `null(Σ_k S_k)`, of dimension
855    /// `p_local − rank(Σ_k S_k)`. This is the **intersection** of the per-penalty
856    /// null spaces, not their sum.
857    ///
858    /// Summing the per-penalty `nullspace_dims` instead (the historical defect
859    /// behind #1360) *unions* the null spaces and badly over-counts: a
860    /// double-penalty smooth carries a bending penalty (null space = its
861    /// polynomial part) plus a complementary null-space ridge (which penalizes
862    /// exactly that polynomial part), so the two null spaces are disjoint and the
863    /// joint null space is empty — yet the per-penalty dims sum to nearly
864    /// `p_local`. Feeding that inflated count to the Wald test makes it test
865    /// almost the whole shrunk block at full rank, manufacturing overwhelming
866    /// "significance" for a term the fit drove to ~0 EDF.
867    pub fn wald_unpenalized_dim(&self) -> usize {
868        joint_unpenalized_dim(
869            self.coeff_range.len(),
870            &self.penalties_local,
871            &self.nullspace_dims,
872        )
873    }
874}
875
876/// Numeric core of [`SmoothTerm::wald_unpenalized_dim`]: the dimension of the
877/// joint null space `∩_k null(S_k) = null(Σ_k S_k)` of a term's local penalty
878/// blocks, with a conservative fallback when a penalty is not materialized as a
879/// full `p_local × p_local` matrix (e.g. a Kronecker tensor factor).
880pub fn joint_unpenalized_dim(
881    p_local: usize,
882    penalties_local: &[Array2<f64>],
883    nullspace_dims: &[usize],
884) -> usize {
885    use gam_linalg::faer_ndarray::FaerEigh;
886    if p_local == 0 {
887        return 0;
888    }
889    if penalties_local.is_empty() {
890        // No penalty ⇒ a wholly unpenalized (fixed-effect) block.
891        return p_local;
892    }
893    // Sum the penalties that are materialized as full `p_local × p_local`
894    // blocks (the common smooth case). The covariance block the Wald test
895    // slices lives in this same coefficient basis (post joint-null rotation),
896    // so the rank is computed in the right metric.
897    let mut s_total = Array2::<f64>::zeros((p_local, p_local));
898    let mut materialized = 0usize;
899    for s in penalties_local {
900        if s.nrows() == p_local && s.ncols() == p_local {
901            s_total += s;
902            materialized += 1;
903        }
904    }
905    if materialized == penalties_local.len() {
906        let symmetric = {
907            let transpose = s_total.t().to_owned();
908            (&s_total + &transpose) * 0.5
909        };
910        if let Ok((evals, _)) = symmetric.eigh(faer::Side::Lower) {
911            let max_abs = evals.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
912            if max_abs == 0.0 {
913                // All penalties identically zero ⇒ unpenalized block.
914                return p_local;
915            }
916            let tol = max_abs * (p_local as f64) * 1e-12;
917            let rank = evals.iter().filter(|&&v| v > tol).count();
918            return p_local.saturating_sub(rank);
919        }
920    }
921    // Conservative fallback when a penalty is not a materialized full block
922    // (e.g. a Kronecker tensor factor): with ≥2 active penalties the joint
923    // null space is almost always empty (the only over-rejecting direction);
924    // with a single penalty it is exactly that penalty's own null space.
925    if penalties_local.len() >= 2 {
926        0
927    } else {
928        nullspace_dims
929            .iter()
930            .copied()
931            .min()
932            .unwrap_or(0)
933            .min(p_local)
934    }
935}
936
/// Bookkeeping for one retained penalty block of the assembled design.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PenaltyBlockInfo {
    /// Index of this penalty in the global (whole-design) penalty list.
    /// NOTE(review): inferred from the field name — confirm the indexing space.
    pub global_index: usize,
    /// Name of the owning term, when the penalty belongs to a named term.
    pub termname: Option<String>,
    /// Descriptive information about the penalty itself.
    pub penalty: PenaltyInfo,
}
943
/// Bookkeeping for a penalty block that was dropped rather than kept in the
/// assembled design. It has no `global_index` because it occupies no slot in
/// the retained penalty list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DroppedPenaltyBlockInfo {
    /// Name of the owning term, when the penalty belonged to a named term.
    pub termname: Option<String>,
    /// Descriptive information about the dropped penalty.
    pub penalty: PenaltyInfo,
}
949
/// Assembled design for all smooth terms: per-term design blocks, block-local
/// penalties with their bookkeeping, and optional smooth-block constraints.
#[derive(Debug, Clone)]
pub struct SmoothDesign {
    /// One design matrix per smooth term. The smooth block is their
    /// column-wise concatenation, with width `total_smooth_cols()`.
    pub term_designs: Vec<DesignMatrix>,
    /// Per-term block-local penalties.  Each `col_range` is relative to the
    /// smooth block (i.e. indexing into the concatenation of `term_designs`).
    pub penalties: Vec<BlockwisePenalty>,
    /// Null-space dimension per penalty.
    /// NOTE(review): presumably parallel to `penalties` — confirm.
    pub nullspace_dims: Vec<usize>,
    /// Bookkeeping for each retained penalty block.
    pub penaltyinfo: Vec<PenaltyBlockInfo>,
    /// Bookkeeping for penalty blocks that were dropped.
    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    /// Realized per-term metadata, one entry per smooth term.
    pub terms: Vec<SmoothTerm>,
    /// Optional smooth-block lower bounds in smooth coefficient coordinates.
    /// Length equals `total_smooth_cols()` when present.
    pub coefficient_lower_bounds: Option<Array1<f64>>,
    /// Optional smooth-block inequality constraints:
    /// `A_smooth * beta_smooth >= b`.
    pub linear_constraints: Option<LinearInequalityConstraints>,
}
967
968impl SmoothDesign {
969    pub fn total_smooth_cols(&self) -> usize {
970        self.term_designs.iter().map(DesignMatrix::ncols).sum()
971    }
972    pub fn nrows(&self) -> usize {
973        self.term_designs.first().map_or(0, DesignMatrix::nrows)
974    }
975}
976
/// Smooth-block design with exactly the same fields and meaning as
/// [`SmoothDesign`]. It converts field-for-field via
/// `From<RawSmoothDesign> for SmoothDesign`.
///
/// NOTE(review): presumably the design as produced by the raw basis builders,
/// before later global identifiability processing — confirm.
#[derive(Debug, Clone)]
pub struct RawSmoothDesign {
    /// One design matrix per smooth term, concatenated column-wise.
    pub term_designs: Vec<DesignMatrix>,
    /// Per-term block-local penalties.  Each `col_range` is relative to the
    /// smooth block (i.e. indexing into the concatenation of `term_designs`).
    pub penalties: Vec<BlockwisePenalty>,
    /// Null-space dimension per penalty.
    pub nullspace_dims: Vec<usize>,
    /// Bookkeeping for each retained penalty block.
    pub penaltyinfo: Vec<PenaltyBlockInfo>,
    /// Bookkeeping for penalty blocks that were dropped.
    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    /// Realized per-term metadata, one entry per smooth term.
    pub terms: Vec<SmoothTerm>,
    /// Optional smooth-block lower bounds in smooth coefficient coordinates.
    pub coefficient_lower_bounds: Option<Array1<f64>>,
    /// Optional smooth-block inequality constraints `A * beta >= b`.
    pub linear_constraints: Option<LinearInequalityConstraints>,
}
990
991impl RawSmoothDesign {
992    pub fn total_smooth_cols(&self) -> usize {
993        self.term_designs.iter().map(DesignMatrix::ncols).sum()
994    }
995    pub fn nrows(&self) -> usize {
996        self.term_designs.first().map_or(0, DesignMatrix::nrows)
997    }
998}
999
1000impl From<RawSmoothDesign> for SmoothDesign {
1001    fn from(value: RawSmoothDesign) -> Self {
1002        Self {
1003            term_designs: value.term_designs,
1004            penalties: value.penalties,
1005            nullspace_dims: value.nullspace_dims,
1006            penaltyinfo: value.penaltyinfo,
1007            dropped_penaltyinfo: value.dropped_penaltyinfo,
1008            terms: value.terms,
1009            coefficient_lower_bounds: value.coefficient_lower_bounds,
1010            linear_constraints: value.linear_constraints,
1011        }
1012    }
1013}
1014
/// Prior placed on a bounded linear coefficient
/// (see [`LinearCoefficientGeometry::Bounded`]).
///
/// `TermCollectionSpec::validate_frozen` requires the `Beta` parameters `a`
/// and `b` to be finite and at least `1.0`.
/// NOTE(review): `Beta` is presumably a Beta(a, b) density on the coefficient
/// rescaled to `[min, max]` — confirm against the prior implementation.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub enum BoundedCoefficientPriorSpec {
    /// No prior (default).
    #[default]
    None,
    /// Flat prior over the bounded interval.
    Uniform,
    /// Beta-shaped prior with shape parameters `a` and `b`.
    Beta {
        a: f64,
        b: f64,
    },
}
1025
/// Geometry of a linear term's coefficient: free, or confined to an interval.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum LinearCoefficientGeometry {
    /// Coefficient is unrestricted (default).
    #[default]
    Unconstrained,
    /// Coefficient is restricted to `[min, max]` and carries an optional prior.
    /// `TermCollectionSpec::validate_frozen` requires finite bounds with
    /// `min < max`. `prior` defaults to no prior when absent from serialized
    /// input.
    Bounded {
        min: f64,
        max: f64,
        #[serde(default)]
        prior: BoundedCoefficientPriorSpec,
    },
}
1037
/// Specification of a parametric linear term: one design column formed from
/// one or more numeric feature columns, optionally gated by categorical
/// levels, with optional ridge and coefficient constraints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearTermSpec {
    /// Term name.
    pub name: String,
    /// Primary feature column index. For Wilkinson-Rogers `:` interaction
    /// terms (`a:b[:c...]`) this is the first column in `feature_cols`; the
    /// realized design column is the elementwise product across every entry
    /// of `feature_cols`. Plain (non-interaction) linear terms set
    /// `feature_cols == vec![feature_col]`.
    pub feature_col: usize,
    /// Full list of columns whose elementwise product yields this term's
    /// design column. `len() >= 1`; `len() == 1` is a plain linear effect.
    #[serde(default)]
    pub feature_cols: Vec<usize>,
    /// Categorical-level gates for a factor-aware `:` interaction.
    ///
    /// Each `(col, level_bits)` multiplies the realized design column by the
    /// indicator `1[data[row, col].to_bits() == level_bits]`. This is how a
    /// `factor:x` (or `factor:factor`) interaction is expanded: `build_termspec`
    /// emits one `LinearTermSpec` per surviving cell of the categorical
    /// operand(s) (treatment-coded, first level dropped per factor), each
    /// carrying the numeric operands in `feature_cols` and the cell's level
    /// gate(s) here. Empty for a plain numeric `:` interaction or main effect,
    /// in which case the realized column is exactly the numeric product.
    #[serde(default)]
    pub categorical_levels: Vec<(usize, u64)>,
    /// Optional ridge (`S = I`, REML-selected `λ`) on this linear coefficient.
    /// A parametric linear term carries no wiggliness, so it is **unpenalized by
    /// default** — gam reports the MLE, matching mgcv/glm/survreg/VGAM (which
    /// penalize parametric terms only under an explicit `paraPen`). Set `true`
    /// to opt into an explicit shrinkage ridge (a zero-mean Gaussian prior
    /// `β ~ N(0, λ⁻¹)`); doing so adds one outer REML smoothing coordinate.
    #[serde(default = "default_linear_term_double_penalty")]
    pub double_penalty: bool,
    /// Interval geometry (and optional prior) for the coefficient; defaults to
    /// unconstrained.
    #[serde(default)]
    pub coefficient_geometry: LinearCoefficientGeometry,
    /// Optional lower bound on the coefficient. `TermCollectionSpec::validate_frozen`
    /// rejects non-finite values and (together with `coefficient_max`) `min > max`.
    #[serde(default)]
    pub coefficient_min: Option<f64>,
    /// Optional upper bound on the coefficient; validated like `coefficient_min`.
    #[serde(default)]
    pub coefficient_max: Option<f64>,
}
1078
1079impl LinearTermSpec {
1080    /// Return the effective list of feature columns. Backfills from
1081    /// `feature_col` for legacy specs that predate the multi-column field.
1082    pub fn effective_feature_cols(&self) -> Vec<usize> {
1083        if self.feature_cols.is_empty() {
1084            vec![self.feature_col]
1085        } else {
1086            self.feature_cols.clone()
1087        }
1088    }
1089
1090    /// True when this term is a Wilkinson-Rogers `:` interaction (multi-col).
1091    pub fn is_interaction(&self) -> bool {
1092        self.feature_cols.len() > 1 || !self.categorical_levels.is_empty()
1093    }
1094
1095    /// Realize this linear term's `(n,)` design column from `data`.
1096    ///
1097    /// The column is the elementwise product of every numeric feature column
1098    /// (`effective_feature_cols`) gated by the categorical-level indicators in
1099    /// `categorical_levels`: each `(col, level_bits)` multiplies the running
1100    /// column by `1[data[row, col].to_bits() == level_bits]`. A plain numeric
1101    /// term (no `categorical_levels`) reduces to the bare product, matching the
1102    /// historical behaviour. A pure categorical interaction (empty
1103    /// `feature_cols`, non-empty `categorical_levels`) reduces to the cell
1104    /// indicator. Bounds are validated here; the returned column has length
1105    /// `data.nrows()`.
1106    pub fn realized_design_column(&self, data: ArrayView2<'_, f64>) -> Result<Array1<f64>, String> {
1107        let n = data.nrows();
1108        let p = data.ncols();
1109        let bounds = |col: usize| -> Result<(), String> {
1110            if col >= p {
1111                Err(format!(
1112                    "linear term '{}' feature column {} out of bounds for {} columns",
1113                    self.name, col, p
1114                ))
1115            } else {
1116                Ok(())
1117            }
1118        };
1119
1120        // Numeric operands. When `categorical_levels` is set we treat
1121        // `feature_cols` as the (possibly empty) numeric operand list and start
1122        // from a column of ones; otherwise we preserve the legacy backfill from
1123        // `feature_col` so a plain term with no `feature_cols` still resolves.
1124        let mut column = if self.categorical_levels.is_empty() {
1125            let cols = self.effective_feature_cols();
1126            for &c in &cols {
1127                bounds(c)?;
1128            }
1129            let mut acc = data.column(cols[0]).to_owned();
1130            for &c in cols.iter().skip(1) {
1131                acc *= &data.column(c);
1132            }
1133            acc
1134        } else {
1135            let mut acc = Array1::<f64>::ones(n);
1136            for &c in &self.feature_cols {
1137                bounds(c)?;
1138                acc *= &data.column(c);
1139            }
1140            acc
1141        };
1142
1143        for &(col, level_bits) in &self.categorical_levels {
1144            bounds(col)?;
1145            let gate = data.column(col);
1146            for (out, &v) in column.iter_mut().zip(gate.iter()) {
1147                if v.to_bits() != level_bits {
1148                    *out = 0.0;
1149                }
1150            }
1151        }
1152
1153        Ok(column)
1154    }
1155}
1156
/// Serde default for [`LinearTermSpec::double_penalty`]: `false`.
///
/// A single linear coefficient has no roughness for a smoothing penalty to
/// control. The historical `S = I` ridge with REML-selected `λ` pulled every
/// linear coefficient away from the MLE and added a spurious outer smoothing
/// coordinate (#749). Parametric terms are therefore unpenalized by default,
/// as in mgcv/glm/survreg/VGAM, and gam reports the MLE.
pub const fn default_linear_term_double_penalty() -> bool {
    // Opt-in only: setting `double_penalty = true` on a term still adds the ridge.
    false
}
1167
/// Default smoothing-penalty weight for PCA-based terms: unit weight.
pub const fn default_pca_smooth_penalty() -> f64 {
    1.0_f64
}
1171
/// Default chunk size for PCA-based terms: 4096 (`2^12`).
/// NOTE(review): presumably rows per processing chunk — confirm at the call site.
pub const fn default_pca_chunk_size() -> usize {
    1 << 12
}
1175
/// Random-effects term specification.
///
/// The selected feature column is interpreted as a categorical grouping variable.
/// The term contributes a one-hot dummy block with an identity penalty on group
/// coefficients, equivalent to i.i.d. Gaussian random effects.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomEffectTermSpec {
    /// Term name.
    pub name: String,
    /// Data column holding the grouping variable.
    pub feature_col: usize,
    /// If true, drop the lexicographically first group level to use treatment coding.
    /// If false, keep all levels (full one-hot block, still identifiable under ridge).
    /// NOTE(review): "first" presumably follows the same `f64` bit-pattern
    /// ordering used for `frozen_levels` — confirm.
    pub drop_first_level: bool,
    /// If true, add a ridge penalty and estimate this block as a random effect.
    /// If false, leave the one-hot/treatment-coded block unpenalized so it is a
    /// fixed categorical main effect.  The default preserves older saved models.
    #[serde(default = "default_random_effect_penalized")]
    pub penalized: bool,
    /// Optional fixed kept-level set (sorted by f64 bit pattern) captured at fit time.
    /// When present, prediction uses exactly these columns to avoid design drift.
    #[serde(default)]
    pub frozen_levels: Option<Vec<u64>>,
}
1198
/// Serde default for [`RandomEffectTermSpec::penalized`]: `true`. Saved
/// models written before the field existed therefore still load as
/// ridge-penalized random effects.
///
/// Declared `const fn`, like the other serde-default helpers in this file
/// (`default_linear_term_double_penalty`, `default_pca_*`), so it can also be
/// used in constant contexts.
pub const fn default_random_effect_penalized() -> bool {
    true
}
1202
1203pub fn validate_measure_jet_positive_vec_len(
1204    label: &str,
1205    term_name: &str,
1206    field: &str,
1207    values: &[f64],
1208    expected: usize,
1209) -> Result<(), String> {
1210    if values.len() != expected {
1211        return Err(SmoothError::invalid_config(format!(
1212            "{label} term '{term_name}' frozen MeasureJet {field} has length {}, expected {expected}",
1213            values.len()
1214        ))
1215        .into());
1216    }
1217    if values
1218        .iter()
1219        .any(|value| !(value.is_finite() && *value > 0.0))
1220    {
1221        return Err(SmoothError::invalid_config(format!(
1222            "{label} term '{term_name}' frozen MeasureJet {field} values must be positive and finite"
1223        ))
1224        .into());
1225    }
1226    Ok(())
1227}
1228
/// Serializable specification of every term in a model, grouped by kind.
/// `validate_frozen` checks that the collection is fully frozen for replay,
/// and `write_structural_shape_hash` fingerprints its topology.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TermCollectionSpec {
    /// Parametric linear terms.
    pub linear_terms: Vec<LinearTermSpec>,
    /// Categorical random-effect (or fixed categorical) terms.
    pub random_effect_terms: Vec<RandomEffectTermSpec>,
    /// Smooth terms.
    pub smooth_terms: Vec<SmoothTermSpec>,
}
1235
1236pub fn validate_smooth_basis_frozen(
1237    basis: &SmoothBasisSpec,
1238    label: &str,
1239    term_name: &str,
1240) -> Result<(), String> {
1241    match basis {
1242        SmoothBasisSpec::ByVariable { inner, .. }
1243        | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
1244            validate_smooth_basis_frozen(inner, label, term_name)
1245        }
1246        SmoothBasisSpec::BSpline1D { spec, .. } => {
1247            if !matches!(
1248                spec.knotspec,
1249                BSplineKnotSpec::Provided(_)
1250                    | BSplineKnotSpec::PeriodicUniform { .. }
1251                    | BSplineKnotSpec::NaturalCubicRegression { .. }
1252            ) {
1253                return Err(format!(
1254                    "{label} term '{term_name}' is not frozen: BSpline knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression"
1255                ));
1256            }
1257            Ok(())
1258        }
1259        SmoothBasisSpec::ThinPlate { spec, .. } => {
1260            if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1261                return Err(format!(
1262                    "{label} term '{term_name}' is not frozen: ThinPlate centers must be UserProvided"
1263                ));
1264            }
1265            if matches!(
1266                spec.identifiability,
1267                SpatialIdentifiability::OrthogonalToParametric
1268            ) {
1269                return Err(format!(
1270                    "{label} term '{term_name}' is not frozen: ThinPlate identifiability must be FrozenTransform or None"
1271                ));
1272            }
1273            Ok(())
1274        }
1275        _ => Ok(()),
1276    }
1277}
1278
1279impl TermCollectionSpec {
1280    /// Write this collection's topology identity into a warm-start cache
1281    /// fingerprint (#869).
1282    ///
1283    /// The persistent warm-start `cache_key` hashes only family + raw input
1284    /// dimensions, so two fits on the same data that differ *only* in their
1285    /// smooth topology (the `s(..., type=AUTO)` candidate enumeration: sphere
1286    /// vs torus vs euclidean vs duchon) collide on one key and seed each other
1287    /// with geometrically incompatible β/ρ. Folding the per-term structural
1288    /// kind + feature columns + linear/random-effect counts into the shape hash
1289    /// gives each candidate its own key, so the screen→full-refit reuse of one
1290    /// candidate is preserved while cross-candidate contamination is removed.
1291    /// Only the structural identity is hashed (not fitted coefficients or
1292    /// frozen knot values), so a refit of the *same* topology still hits.
1293    pub fn write_structural_shape_hash(&self, h: &mut gam_runtime::warm_start::Fingerprinter) {
1294        h.write_str("term-collection");
1295        h.write_usize(self.linear_terms.len());
1296        for linear in &self.linear_terms {
1297            h.write_str(&linear.name);
1298        }
1299        h.write_usize(self.random_effect_terms.len());
1300        h.write_usize(self.smooth_terms.len());
1301        for smooth in &self.smooth_terms {
1302            h.write_str(&smooth.name);
1303            h.write_str(smooth.basis.structural_kind());
1304            for col in smooth.basis.structural_feature_cols() {
1305                h.write_usize(col);
1306            }
1307        }
1308    }
1309
1310    /// Validate that a term collection spec represents a fully frozen model
1311    /// (i.e. all knots/centers are pre-computed, identifiability transforms are
1312    /// baked in, and random-effect levels are fixed).
1313    pub fn validate_frozen(&self, label: &str) -> Result<(), String> {
1314        for linear in &self.linear_terms {
1315            if let (Some(min), Some(max)) = (linear.coefficient_min, linear.coefficient_max)
1316                && (!min.is_finite() || !max.is_finite() || min > max)
1317            {
1318                return Err(SmoothError::invalid_config(format!(
1319                    "{label} linear term '{}' has invalid coefficient constraint [{min}, {max}]",
1320                    linear.name
1321                ))
1322                .into());
1323            }
1324            if let Some(min) = linear.coefficient_min
1325                && !min.is_finite()
1326            {
1327                return Err(SmoothError::invalid_config(format!(
1328                    "{label} linear term '{}' has non-finite coefficient minimum {min}",
1329                    linear.name
1330                ))
1331                .into());
1332            }
1333            if let Some(max) = linear.coefficient_max
1334                && !max.is_finite()
1335            {
1336                return Err(SmoothError::invalid_config(format!(
1337                    "{label} linear term '{}' has non-finite coefficient maximum {max}",
1338                    linear.name
1339                ))
1340                .into());
1341            }
1342            if let LinearCoefficientGeometry::Bounded { min, max, prior } =
1343                &linear.coefficient_geometry
1344            {
1345                if !min.is_finite() || !max.is_finite() || min >= max {
1346                    return Err(SmoothError::invalid_config(format!(
1347                        "{label} bounded term '{}' has invalid bounds [{min}, {max}]",
1348                        linear.name
1349                    ))
1350                    .into());
1351                }
1352                match prior {
1353                    BoundedCoefficientPriorSpec::None | BoundedCoefficientPriorSpec::Uniform => {}
1354                    BoundedCoefficientPriorSpec::Beta { a, b } => {
1355                        if !a.is_finite() || !b.is_finite() || *a < 1.0 || *b < 1.0 {
1356                            return Err(SmoothError::invalid_config(format!(
1357                                "{label} bounded term '{}' has invalid Beta prior ({a}, {b})",
1358                                linear.name
1359                            ))
1360                            .into());
1361                        }
1362                    }
1363                }
1364            }
1365        }
1366        for st in &self.smooth_terms {
1367            match &st.basis {
1368                SmoothBasisSpec::ByVariable { inner, .. } => {
1369                    validate_smooth_basis_frozen(inner, label, &st.name)?;
1370                    let nested = SmoothTermSpec {
1371                        name: st.name.clone(),
1372                        basis: (**inner).clone(),
1373                        shape: st.shape,
1374                        joint_null_rotation: None,
1375                    };
1376                    TermCollectionSpec {
1377                        linear_terms: Vec::new(),
1378                        random_effect_terms: Vec::new(),
1379                        smooth_terms: vec![nested],
1380                    }
1381                    .validate_frozen(label)?;
1382                }
1383                SmoothBasisSpec::FactorSumToZero { inner, levels, .. } => {
1384                    if levels.len() < 2 {
1385                        return Err(format!(
1386                            "{label} term '{}' has invalid frozen sz levels",
1387                            st.name
1388                        ));
1389                    }
1390                    validate_smooth_basis_frozen(inner, label, &st.name)?;
1391                }
1392                SmoothBasisSpec::BSpline1D { spec, .. } => {
1393                    if !matches!(
1394                        spec.knotspec,
1395                        BSplineKnotSpec::Provided(_)
1396                            | BSplineKnotSpec::PeriodicUniform { .. }
1397                            | BSplineKnotSpec::NaturalCubicRegression { .. }
1398                    ) {
1399                        return Err(SmoothError::invalid_config(format!(
1400                            "{label} term '{}' is not frozen: BSpline knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression",
1401                            st.name
1402                        ))
1403                        .into());
1404                    }
1405                }
1406                SmoothBasisSpec::ThinPlate { spec, .. } => {
1407                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1408                        return Err(SmoothError::invalid_config(format!(
1409                            "{label} term '{}' is not frozen: ThinPlate centers must be UserProvided",
1410                            st.name
1411                        ))
1412                        .into());
1413                    }
1414                    if matches!(
1415                        spec.identifiability,
1416                        SpatialIdentifiability::OrthogonalToParametric
1417                    ) {
1418                        return Err(SmoothError::invalid_config(format!(
1419                            "{label} term '{}' is not frozen: ThinPlate identifiability must be FrozenTransform or None",
1420                            st.name
1421                        ))
1422                        .into());
1423                    }
1424                }
1425                SmoothBasisSpec::Sphere { spec, .. } => {
1426                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1427                        return Err(SmoothError::invalid_config(format!(
1428                            "{label} term '{}' is not frozen: Sphere centers must be UserProvided",
1429                            st.name
1430                        ))
1431                        .into());
1432                    }
1433                    if matches!(spec.method, crate::basis::SphereMethod::Harmonic)
1434                        && spec.max_degree.is_none_or(|d| d == 0)
1435                    {
1436                        return Err(format!(
1437                            "{label} term '{}' is not frozen: sphere max_degree must be positive",
1438                            st.name
1439                        ));
1440                    }
1441                }
1442                SmoothBasisSpec::ConstantCurvature { spec, .. } => {
1443                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1444                        return Err(SmoothError::invalid_config(format!(
1445                            "{label} term '{}' is not frozen: ConstantCurvature centers must be UserProvided",
1446                            st.name
1447                        ))
1448                        .into());
1449                    }
1450                    if !(spec.length_scale.is_finite() && spec.length_scale > 0.0) {
1451                        return Err(SmoothError::invalid_config(format!(
1452                            "{label} term '{}' is not frozen: ConstantCurvature length_scale must be the realized positive value",
1453                            st.name
1454                        ))
1455                        .into());
1456                    }
1457                }
1458                SmoothBasisSpec::MeasureJet { spec, .. } => {
1459                    let centers = match &spec.center_strategy {
1460                        CenterStrategy::UserProvided(centers) => centers,
1461                        _ => {
1462                            return Err(SmoothError::invalid_config(format!(
1463                                "{label} term '{}' is not frozen: MeasureJet centers must be UserProvided",
1464                                st.name
1465                            ))
1466                            .into());
1467                        }
1468                    };
1469                    if centers.nrows() == 0 {
1470                        return Err(SmoothError::invalid_config(format!(
1471                            "{label} term '{}' is not frozen: MeasureJet centers are empty",
1472                            st.name
1473                        ))
1474                        .into());
1475                    }
1476                    if !(spec.length_scale.is_finite() && spec.length_scale > 0.0) {
1477                        return Err(SmoothError::invalid_config(format!(
1478                            "{label} term '{}' is not frozen: MeasureJet length_scale must be the realized positive value",
1479                            st.name
1480                        ))
1481                        .into());
1482                    }
1483                    // Exact replay needs the fit-data penalty quadrature and
1484                    // normalization payload (`BasisMetadata::MeasureJet`).
1485                    let frozen = spec.frozen_quadrature.as_ref().ok_or_else(|| {
1486                        SmoothError::invalid_config(format!(
1487                            "{label} term '{}' is not frozen: MeasureJet frozen_quadrature payload is missing",
1488                            st.name
1489                        ))
1490                    })?;
1491                    if frozen.masses.len() != centers.nrows() {
1492                        return Err(SmoothError::invalid_config(format!(
1493                            "{label} term '{}' frozen MeasureJet has {} masses for {} centers",
1494                            st.name,
1495                            frozen.masses.len(),
1496                            centers.nrows()
1497                        ))
1498                        .into());
1499                    }
1500                    let total_mass = frozen.masses.sum();
1501                    if frozen
1502                        .masses
1503                        .iter()
1504                        .any(|mass| !(mass.is_finite() && *mass >= 0.0))
1505                        || !(total_mass.is_finite() && total_mass > 0.0)
1506                    {
1507                        return Err(SmoothError::invalid_config(format!(
1508                            "{label} term '{}' frozen MeasureJet masses must be finite, nonnegative, and have positive total mass",
1509                            st.name
1510                        ))
1511                        .into());
1512                    }
1513                    let n_levels = frozen.eps_band.len();
1514                    if n_levels == 0
1515                        || frozen
1516                            .eps_band
1517                            .iter()
1518                            .any(|eps| !(eps.is_finite() && *eps > 0.0))
1519                    {
1520                        return Err(SmoothError::invalid_config(format!(
1521                            "{label} term '{}' frozen MeasureJet eps_band must be nonempty, finite, and positive",
1522                            st.name
1523                        ))
1524                        .into());
1525                    }
1526                    for (idx, pair) in frozen.eps_band.windows(2).enumerate() {
1527                        if pair[1] <= pair[0] {
1528                            return Err(SmoothError::invalid_config(format!(
1529                                "{label} term '{}' frozen MeasureJet eps_band is not strictly ascending at {idx}: {} then {}",
1530                                st.name,
1531                                pair[0],
1532                                pair[1]
1533                            ))
1534                            .into());
1535                        }
1536                    }
1537                    validate_measure_jet_positive_vec_len(
1538                        label,
1539                        &st.name,
1540                        "support_means",
1541                        &frozen.support_means,
1542                        n_levels,
1543                    )?;
1544                    // Mode predicate MUST match the builder's
1545                    // (`measure_jet_multiscale_mode`): per-level/multiscale is the
1546                    // explicit `spec.multiscale` opt-in (#1116). In single-scale
1547                    // mode the builder emits a single FUSED penalty (empty
1548                    // per-level scales + `fused_penalty_normalization_scale:
1549                    // Some`); only the multiscale opt-in carries `n_levels`
1550                    // per-level scales.
1551                    let per_level = crate::basis::measure_jet_multiscale_mode(spec);
1552                    if per_level {
1553                        validate_measure_jet_positive_vec_len(
1554                            label,
1555                            &st.name,
1556                            "penalty_normalization_scales",
1557                            &frozen.penalty_normalization_scales,
1558                            n_levels,
1559                        )?;
1560                        validate_measure_jet_positive_vec_len(
1561                            label,
1562                            &st.name,
1563                            "raw_penalty_normalization_scales",
1564                            &frozen.raw_penalty_normalization_scales,
1565                            n_levels,
1566                        )?;
1567                        if frozen.fused_penalty_normalization_scale.is_some() {
1568                            return Err(SmoothError::invalid_config(format!(
1569                                "{label} term '{}' per-level MeasureJet must not carry a fused penalty normalization scale",
1570                                st.name
1571                            ))
1572                            .into());
1573                        }
1574                    } else {
1575                        if !frozen.penalty_normalization_scales.is_empty()
1576                            || !frozen.raw_penalty_normalization_scales.is_empty()
1577                        {
1578                            return Err(SmoothError::invalid_config(format!(
1579                                "{label} term '{}' fused MeasureJet must not carry per-level penalty normalization scales",
1580                                st.name
1581                            ))
1582                            .into());
1583                        }
1584                        match frozen.fused_penalty_normalization_scale {
1585                            Some(scale) if scale.is_finite() && scale > 0.0 => {}
1586                            Some(scale) => {
1587                                return Err(SmoothError::invalid_config(format!(
1588                                    "{label} term '{}' fused MeasureJet penalty normalization scale must be positive and finite, got {scale}",
1589                                    st.name
1590                                ))
1591                                .into());
1592                            }
1593                            None => {
1594                                return Err(SmoothError::invalid_config(format!(
1595                                    "{label} term '{}' fused MeasureJet is missing its penalty normalization scale",
1596                                    st.name
1597                                ))
1598                                .into());
1599                            }
1600                        }
1601                    }
1602                }
1603                SmoothBasisSpec::Matern { spec, .. } => {
1604                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1605                        return Err(SmoothError::invalid_config(format!(
1606                            "{label} term '{}' is not frozen: Matern centers must be UserProvided",
1607                            st.name
1608                        ))
1609                        .into());
1610                    }
1611                }
1612                SmoothBasisSpec::Duchon { spec, .. } => {
1613                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1614                        return Err(SmoothError::invalid_config(format!(
1615                            "{label} term '{}' is not frozen: Duchon centers must be UserProvided",
1616                            st.name
1617                        ))
1618                        .into());
1619                    }
1620                    if matches!(
1621                        spec.identifiability,
1622                        SpatialIdentifiability::OrthogonalToParametric
1623                    ) {
1624                        return Err(SmoothError::invalid_config(format!(
1625                            "{label} term '{}' is not frozen: Duchon identifiability must be FrozenTransform or None",
1626                            st.name
1627                        ))
1628                        .into());
1629                    }
1630                }
1631                SmoothBasisSpec::Pca {
1632                    centered,
1633                    center_mean,
1634                    pca_basis_path,
1635                    ..
1636                } => {
1637                    if *centered && center_mean.is_none() && pca_basis_path.is_none() {
1638                        return Err(SmoothError::invalid_config(format!(
1639                            "{label} term '{}' is not frozen: centered Pca missing center_mean",
1640                            st.name
1641                        ))
1642                        .into());
1643                    }
1644                }
1645                SmoothBasisSpec::BySmooth { smooth, by_kind } => {
1646                    if let SmoothBasisSpec::BySmooth { .. } = smooth.as_ref() {
1647                        return Err(format!("{label} term '{}' has nested by-smooths", st.name));
1648                    }
1649                    match by_kind {
1650                        ByVarKind::Numeric { .. } => {}
1651                        ByVarKind::Factor { frozen_levels, .. } if frozen_levels.is_none() => {
1652                            return Err(format!(
1653                                "{label} term '{}' is not frozen: by-factor levels missing",
1654                                st.name
1655                            ));
1656                        }
1657                        ByVarKind::Factor { .. } => {}
1658                    }
1659                    let nested = TermCollectionSpec {
1660                        linear_terms: vec![],
1661                        random_effect_terms: vec![],
1662                        smooth_terms: vec![SmoothTermSpec {
1663                            name: st.name.clone(),
1664                            basis: (**smooth).clone(),
1665                            shape: st.shape,
1666                            joint_null_rotation: None,
1667                        }],
1668                    };
1669                    nested.validate_frozen(label)?;
1670                }
1671                SmoothBasisSpec::FactorSmooth { spec } => {
1672                    if spec.group_frozen_levels.is_none() {
1673                        return Err(format!(
1674                            "{label} term '{}' is not frozen: factor-smooth levels missing",
1675                            st.name
1676                        ));
1677                    }
1678                    if !matches!(
1679                        spec.marginal.knotspec,
1680                        BSplineKnotSpec::Provided(_)
1681                            | BSplineKnotSpec::PeriodicUniform { .. }
1682                            // mgcv's `bs="sz"` default marginal is a cubic
1683                            // regression spline (#1074), and the freeze step
1684                            // restores it as a `NaturalCubicRegression` knotspec
1685                            // carrying its `k` value-knots (spatial_optimization.rs
1686                            // `marginal_is_cr` branch) — the SAME treatment the
1687                            // tensor margin already gets in the arm below. Without
1688                            // this variant a frozen `sz` factor smooth fails its own
1689                            // predict-time freeze check ("factor-smooth marginal
1690                            // knots missing") even though its knots are fully
1691                            // materialized; the validation simply was not updated
1692                            // when the cr marginal landed.
1693                            | BSplineKnotSpec::NaturalCubicRegression { .. }
1694                    ) {
1695                        return Err(format!(
1696                            "{label} term '{}' is not frozen: factor-smooth marginal knots missing",
1697                            st.name
1698                        ));
1699                    }
1700                }
1701                SmoothBasisSpec::TensorBSpline { spec, .. } => {
1702                    for (dim, marginal) in spec.marginalspecs.iter().enumerate() {
1703                        if !matches!(
1704                            marginal.knotspec,
1705                            BSplineKnotSpec::Provided(_)
1706                                | BSplineKnotSpec::PeriodicUniform { .. }
1707                                | BSplineKnotSpec::NaturalCubicRegression { .. }
1708                        ) {
1709                            return Err(SmoothError::invalid_config(format!(
1710                                "{label} term '{}' dim {} is not frozen: tensor marginal knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression",
1711                                st.name, dim
1712                            ))
1713                            .into());
1714                        }
1715                    }
1716                    if matches!(
1717                        spec.identifiability,
1718                        TensorBSplineIdentifiability::SumToZero
1719                            | TensorBSplineIdentifiability::MarginalSumToZero
1720                    ) {
1721                        return Err(SmoothError::invalid_config(format!(
1722                            "{label} term '{}' is not frozen: tensor identifiability must be FrozenTransform or None",
1723                            st.name
1724                        ))
1725                        .into());
1726                    }
1727                }
1728            }
1729        }
1730
1731        for rt in &self.random_effect_terms {
1732            if rt.frozen_levels.is_none() {
1733                return Err(SmoothError::invalid_config(format!(
1734                    "{label} random-effect term '{}' is not frozen: missing frozen_levels",
1735                    rt.name
1736                ))
1737                .into());
1738            }
1739        }
1740
1741        Ok(())
1742    }
1743
1744    /// Re-resolve every stored feature-column index through `remap`, returning a
1745    /// spec that addresses a different column layout.
1746    ///
1747    /// A frozen `TermCollectionSpec` stores feature columns as *absolute indices
1748    /// into the training table*. To replay it on a fresh dataset whose columns
1749    /// sit at different positions — the common case at prediction time, where
1750    /// the response column is unknown and may be absent entirely — every index
1751    /// must be re-resolved against the new layout. `remap` receives each
1752    /// training-table index and returns its position in the runtime table;
1753    /// callers typically implement it as "look the name up in the training
1754    /// headers, then resolve that name against the prediction dataset".
1755    ///
1756    /// This is the single authority on *which* fields carry a column index
1757    /// across every basis variant (linear, random-effect, the `by=` column of
1758    /// `ByVariable`/`FactorSumToZero`/`BySmooth`, the continuous and group
1759    /// columns of a `FactorSmooth`, and the multi-axis `feature_cols` of every
1760    /// spatial/tensor basis), so a predict-time realignment cannot silently miss
1761    /// one and dereference a stale training index.
1762    pub fn remap_feature_columns<E, F>(&self, mut remap: F) -> Result<TermCollectionSpec, E>
1763    where
1764        F: FnMut(usize) -> Result<usize, E>,
1765    {
1766        let mut out = self.clone();
1767        for lt in &mut out.linear_terms {
1768            lt.feature_col = remap(lt.feature_col)?;
1769            // Also remap the full interaction-factor list. The design builder
1770            // (`build_term_collection_design_inner`) materializes the column from
1771            // `effective_feature_cols()` — which returns `feature_cols` whenever
1772            // it is non-empty (i.e. essentially always, including a plain linear
1773            // term where `feature_cols == [feature_col]`). Remapping only the
1774            // singular `feature_col` left these at their saved *training* indices
1775            // at predict time, so a parametric `Surv(...) ~ x` (and any `:`
1776            // interaction) bailed with "feature column N out of bounds" once the
1777            // response/time columns shift the runtime layout (issue #898).
1778            for fc in lt.feature_cols.iter_mut() {
1779                *fc = remap(*fc)?;
1780            }
1781            // A factor-aware `:` interaction also gates on categorical columns;
1782            // those indices live in the same training-time layout and must be
1783            // realigned to the runtime table alongside the numeric operands, or
1784            // the predict-time level indicator would dereference a stale column.
1785            for (col, _bits) in lt.categorical_levels.iter_mut() {
1786                *col = remap(*col)?;
1787            }
1788        }
1789        for rt in &mut out.random_effect_terms {
1790            rt.feature_col = remap(rt.feature_col)?;
1791        }
1792        for st in &mut out.smooth_terms {
1793            remap_smooth_basis_feature_columns(&mut st.basis, &mut remap)?;
1794        }
1795        Ok(out)
1796    }
1797}
1798
1799/// Walk a `SmoothBasisSpec` tree, re-resolving every column index through
1800/// `remap`. Shared by all predict-time column realignment (see
1801/// [`TermCollectionSpec::remap_feature_columns`]); kept exhaustive so a newly
1802/// added index-bearing variant fails to compile until it is handled here.
1803pub fn remap_smooth_basis_feature_columns<E, F>(
1804    basis: &mut SmoothBasisSpec,
1805    remap: &mut F,
1806) -> Result<(), E>
1807where
1808    F: FnMut(usize) -> Result<usize, E>,
1809{
1810    match basis {
1811        SmoothBasisSpec::ByVariable { inner, by_col, .. }
1812        | SmoothBasisSpec::FactorSumToZero { inner, by_col, .. } => {
1813            *by_col = remap(*by_col)?;
1814            remap_smooth_basis_feature_columns(inner, remap)?;
1815        }
1816        SmoothBasisSpec::BSpline1D { feature_col, .. } => {
1817            *feature_col = remap(*feature_col)?;
1818        }
1819        SmoothBasisSpec::BySmooth { smooth, by_kind } => {
1820            let by_feature_col = match by_kind {
1821                ByVarKind::Numeric { feature_col } | ByVarKind::Factor { feature_col, .. } => {
1822                    feature_col
1823                }
1824            };
1825            *by_feature_col = remap(*by_feature_col)?;
1826            remap_smooth_basis_feature_columns(smooth, remap)?;
1827        }
1828        SmoothBasisSpec::FactorSmooth { spec } => {
1829            for fc in spec.continuous_cols.iter_mut() {
1830                *fc = remap(*fc)?;
1831            }
1832            spec.group_col = remap(spec.group_col)?;
1833        }
1834        SmoothBasisSpec::ThinPlate { feature_cols, .. }
1835        | SmoothBasisSpec::Sphere { feature_cols, .. }
1836        | SmoothBasisSpec::ConstantCurvature { feature_cols, .. }
1837        | SmoothBasisSpec::Matern { feature_cols, .. }
1838        | SmoothBasisSpec::MeasureJet { feature_cols, .. }
1839        | SmoothBasisSpec::Duchon { feature_cols, .. }
1840        | SmoothBasisSpec::Pca { feature_cols, .. }
1841        | SmoothBasisSpec::TensorBSpline { feature_cols, .. } => {
1842            for fc in feature_cols.iter_mut() {
1843                *fc = remap(*fc)?;
1844            }
1845        }
1846    }
1847    Ok(())
1848}
1849
1850#[derive(Debug, Clone)]
1851pub enum PenaltyStructureHint {
1852    Ridge(f64),
1853    Kronecker(Vec<Array2<f64>>),
1854}
1855
1856/// A penalty matrix stored at its natural block size together with the
1857/// column range it occupies in the global coefficient vector.
1858///
1859/// Instead of embedding every penalty into a full `p_total × p_total` dense
1860/// matrix filled with zeros, we keep the compact local matrix and reconstruct
1861/// the global view only when a downstream consumer explicitly requires it.
1862#[derive(Clone)]
1863pub struct BlockwisePenalty {
1864    /// Column range in the global coefficient vector that this penalty covers.
1865    pub col_range: Range<usize>,
1866    /// The local penalty matrix — dimensions `block_p × block_p` where
1867    /// `block_p = col_range.len()`.
1868    pub local: Array2<f64>,
1869    /// Optional nonzero centering vector for this coefficient block.
1870    pub prior_mean: gam_problem::CoefficientPriorMean,
1871    /// Optional structural hint so downstream spectral/logdet code can stay
1872    /// block-local or factorized without reverse-engineering the matrix.
1873    pub structure_hint: Option<PenaltyStructureHint>,
1874    /// Optional operator-form handle bit-equivalent to `local`. Populated when
1875    /// the originating closed-form factory emitted an op-form penalty so exact
1876    /// operator algebra can use matvec instead of materializing the dense
1877    /// `block_p × block_p` Gram. `None` for ordinary dense penalties.
1878    pub op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
1879}
1880
1881impl std::fmt::Debug for BlockwisePenalty {
1882    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1883        f.debug_struct("BlockwisePenalty")
1884            .field("col_range", &self.col_range)
1885            .field(
1886                "local",
1887                &format_args!("{}×{}", self.local.nrows(), self.local.ncols()),
1888            )
1889            .field("prior_mean", &self.prior_mean)
1890            .field("structure_hint", &self.structure_hint)
1891            .field("op", &self.op.as_ref().map(|o| o.dim()))
1892            .finish()
1893    }
1894}
1895
1896impl BlockwisePenalty {
1897    /// Create a new blockwise penalty.
1898    pub fn new(col_range: Range<usize>, local: Array2<f64>) -> Self {
1899        assert_eq!(col_range.len(), local.nrows());
1900        assert_eq!(col_range.len(), local.ncols());
1901        Self {
1902            col_range,
1903            local,
1904            prior_mean: gam_problem::CoefficientPriorMean::Zero,
1905            structure_hint: None,
1906            op: None,
1907        }
1908    }
1909
1910    pub fn with_prior_mean(
1911        mut self,
1912        prior_mean: gam_problem::CoefficientPriorMean,
1913    ) -> Self {
1914        self.prior_mean = prior_mean;
1915        self
1916    }
1917
1918    /// Attach an op-form penalty handle bit-equivalent to `local`.
1919    pub fn with_op(
1920        mut self,
1921        op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
1922    ) -> Self {
1923        self.op = op;
1924        self
1925    }
1926
1927    pub fn ridge(col_range: Range<usize>, scale: f64) -> Self {
1928        let block_size = col_range.len();
1929        let mut local = Array2::<f64>::zeros((block_size, block_size));
1930        for i in 0..block_size {
1931            local[[i, i]] = scale;
1932        }
1933        Self {
1934            col_range,
1935            local,
1936            prior_mean: gam_problem::CoefficientPriorMean::Zero,
1937            structure_hint: Some(PenaltyStructureHint::Ridge(scale)),
1938            op: None,
1939        }
1940    }
1941
1942    pub fn kronecker(
1943        col_range: Range<usize>,
1944        local: Array2<f64>,
1945        factors: Vec<Array2<f64>>,
1946    ) -> Self {
1947        assert_eq!(col_range.len(), local.nrows());
1948        assert_eq!(col_range.len(), local.ncols());
1949        Self {
1950            col_range,
1951            local,
1952            prior_mean: gam_problem::CoefficientPriorMean::Zero,
1953            structure_hint: Some(PenaltyStructureHint::Kronecker(factors)),
1954            op: None,
1955        }
1956    }
1957
1958    /// Expand this blockwise penalty into a full `p_total × p_total` dense
1959    /// matrix (mostly zeros). Use sparingly — the whole point of blockwise
1960    /// storage is to avoid this allocation.
1961    pub fn to_global(&self, p_total: usize) -> Array2<f64> {
1962        let mut g = Array2::<f64>::zeros((p_total, p_total));
1963        let r = &self.col_range;
1964        assert!(
1965            r.end <= p_total && self.local.nrows() == r.len() && self.local.ncols() == r.len(),
1966            "BlockwisePenalty::to_global shape invariant violated: \
1967             col_range={}..{}, local={}x{}, p_total={}",
1968            r.start,
1969            r.end,
1970            self.local.nrows(),
1971            self.local.ncols(),
1972            p_total,
1973        );
1974        g.slice_mut(s![r.start..r.end, r.start..r.end])
1975            .assign(&self.local);
1976        g
1977    }
1978
1979    /// Convert into a blockwise [`gam_problem::PenaltyMatrix`] without
1980    /// expanding to full dimensions.
1981    pub fn to_penalty_matrix(
1982        &self,
1983        total_dim: usize,
1984    ) -> gam_problem::PenaltyMatrix {
1985        gam_problem::PenaltyMatrix::Blockwise {
1986            local: self.local.clone(),
1987            col_range: self.col_range.clone(),
1988            total_dim,
1989        }
1990    }
1991
1992    /// The block size of this penalty.
1993    #[inline]
1994    pub fn block_size(&self) -> usize {
1995        self.col_range.len()
1996    }
1997}
1998
1999/// Compute `Σ_k λ_k S_k` directly from blockwise penalties, accumulating
2000/// into a pre-allocated `p_total × p_total` output without ever materializing
2001/// individual global matrices.
2002pub fn weighted_blockwise_penalty_sum(
2003    penalties: &[BlockwisePenalty],
2004    lambdas: &[f64],
2005    p_total: usize,
2006) -> Array2<f64> {
2007    assert_eq!(penalties.len(), lambdas.len());
2008    // Smoothing parameters λ_k must be non-negative and finite. A negative
2009    // λ would flip the sign of the corresponding block S_k, turning the
2010    // total penalty matrix indefinite and silently corrupting every
2011    // downstream Cholesky / PIRLS / REML / pseudo-logdet computation that
2012    // assumes S_λ ⪰ 0. Catch this at the boundary rather than after it
2013    // has propagated.
2014    for (idx, &lam) in lambdas.iter().enumerate() {
2015        assert!(
2016            lam.is_finite() && lam >= 0.0,
2017            "weighted_blockwise_penalty_sum: lambdas[{idx}] = {lam} is invalid (must be finite and non-negative; negative smoothing parameters violate S_λ ⪰ 0)",
2018        );
2019    }
2020    // Block column ranges must also fit inside the declared total parameter
2021    // dimension; an out-of-bounds slice would otherwise panic from ndarray
2022    // with a far less informative message.
2023    for (idx, bp) in penalties.iter().enumerate() {
2024        let r = &bp.col_range;
2025        assert!(
2026            r.end <= p_total,
2027            "weighted_blockwise_penalty_sum: penalties[{idx}] col_range {:?} exceeds p_total = {p_total}",
2028            r,
2029        );
2030    }
2031    let mut out = Array2::<f64>::zeros((p_total, p_total));
2032    for (bp, &lam) in penalties.iter().zip(lambdas.iter()) {
2033        let r = &bp.col_range;
2034        let mut slice = out.slice_mut(s![r.start..r.end, r.start..r.end]);
2035        slice.scaled_add(lam, &bp.local);
2036    }
2037    out
2038}
2039
2040// ---------------------------------------------------------------------------
2041// KroneckerPenaltySystem — factored tensor-product penalty representation
2042// ---------------------------------------------------------------------------
2043
/// Factored representation of tensor-product penalties with precomputed
/// marginal eigensystems for O(∏q_j) logdet and penalty operations.
///
/// The full `∏q_j × ∏q_j` penalty is never stored. Each joint eigenvalue is
/// assembled on the fly from one eigenvalue per margin (see
/// `logdet_rank_and_derivatives`), optionally plus a double-penalty term on
/// modes that every marginal penalty leaves unpenalized.
#[derive(Debug, Clone)]
pub struct KroneckerPenaltySystem {
    /// Marginal penalty matrices: `marginal_penalties[k]` is `(q_k, q_k)`.
    pub marginal_penalties: Vec<Array2<f64>>,
    /// Precomputed eigensystems: `(eigenvalues, eigenvectors)` per marginal.
    /// Only the eigenvalues are read by the log-determinant routines in this
    /// module.
    pub marginal_eigensystems: Vec<(Array1<f64>, Array2<f64>)>,
    /// Marginal basis dimensions; their product is the total coefficient
    /// count (`p_total`).
    pub marginal_dims: Vec<usize>,
    /// Whether a global ridge (double) penalty is present. When true it adds
    /// one extra smoothing parameter after the `d` marginal ones.
    pub has_double_penalty: bool,
}
2057
2058impl KroneckerPenaltySystem {
2059    pub fn new(
2060        marginal_penalties: Vec<Array2<f64>>,
2061        marginal_dims: Vec<usize>,
2062        has_double_penalty: bool,
2063    ) -> Result<Self, BasisError> {
2064        if marginal_penalties.len() != marginal_dims.len() {
2065            crate::bail_dim_basis!(
2066                "KroneckerPenaltySystem: {} penalties vs {} dims",
2067                marginal_penalties.len(),
2068                marginal_dims.len()
2069            );
2070        }
2071        let eigensystems =
2072            kronecker_marginal_eigensystems(&marginal_penalties, "KroneckerPenaltySystem")
2073                .map_err(|e| BasisError::InvalidInput(e.to_string()))?;
2074        Ok(Self {
2075            marginal_penalties,
2076            marginal_eigensystems: eigensystems,
2077            marginal_dims,
2078            has_double_penalty,
2079        })
2080    }
2081
2082    pub fn p_total(&self) -> usize {
2083        self.marginal_dims.iter().copied().product()
2084    }
2085
2086    pub fn ndim(&self) -> usize {
2087        self.marginal_dims.len()
2088    }
2089
2090    pub fn num_penalties(&self) -> usize {
2091        self.marginal_dims.len() + if self.has_double_penalty { 1 } else { 0 }
2092    }
2093
2094    /// Compute `log|S|₊` and its first/second derivatives w.r.t. `ρ_k = log(λ_k)`.
2095    ///
2096    /// Iterates over the ∏q_j multi-index grid. Cost: O(d · ∏q_j), no O(p²) storage.
2097    pub fn logdet_and_derivatives(
2098        &self,
2099        lambdas: &[f64],
2100        ridge: f64,
2101    ) -> (f64, Array1<f64>, Array2<f64>) {
2102        let n_pen = self.num_penalties();
2103        assert_eq!(lambdas.len(), n_pen, "lambda count mismatch");
2104        let marginal_evals: Vec<_> = self
2105            .marginal_eigensystems
2106            .iter()
2107            .map(|(evals, _)| evals.view())
2108            .collect();
2109        kronecker_logdet_and_derivatives(
2110            &marginal_evals,
2111            &self.marginal_dims,
2112            lambdas,
2113            self.has_double_penalty,
2114            ridge,
2115        )
2116    }
2117
2118    pub fn logdet_rank_and_derivatives(
2119        &self,
2120        lambdas: &[f64],
2121        ridge: f64,
2122    ) -> (f64, usize, Array1<f64>, Array2<f64>) {
2123        let n_pen = self.num_penalties();
2124        assert_eq!(lambdas.len(), n_pen, "lambda count mismatch");
2125        let d = self.marginal_dims.len();
2126        let mut logdet = 0.0;
2127        let mut rank = 0usize;
2128        let mut grad = Array1::<f64>::zeros(n_pen);
2129        let mut hess = Array2::<f64>::zeros((n_pen, n_pen));
2130        // Positivity floor for a penalized eigenvalue `σ`: below this the mode
2131        // is treated as an unpenalized (null-space) direction and excluded from
2132        // both the rank count and the pseudo-log-determinant.
2133        const EIGENVALUE_POSITIVITY_FLOOR: f64 = 1e-12;
2134        // Floor on the *structural* eigenvalue sum (λ-independent) used to decide
2135        // whether a mode lives in the penalty range space and so should receive
2136        // the stabilizing ridge; a structurally-null mode gets no ridge.
2137        const STRUCTURAL_ZERO_FLOOR: f64 = 1e-12;
2138        let mut multi_idx = vec![0usize; d];
2139        loop {
2140            let mut sigma = 0.0;
2141            let mut structural_sigma = 0.0;
2142            for k in 0..d {
2143                let marginal_eigenvalue = self.marginal_eigensystems[k].0[multi_idx[k]];
2144                structural_sigma += marginal_eigenvalue;
2145                sigma += lambdas[k] * marginal_eigenvalue;
2146            }
2147            let joint_null = structural_sigma <= STRUCTURAL_ZERO_FLOOR;
2148            if self.has_double_penalty && joint_null {
2149                sigma += lambdas[d];
2150            }
2151            if structural_sigma > STRUCTURAL_ZERO_FLOOR {
2152                sigma += ridge;
2153            }
2154
2155            if sigma > EIGENVALUE_POSITIVITY_FLOOR {
2156                rank += 1;
2157                logdet += sigma.ln();
2158                let inv_sigma = 1.0 / sigma;
2159                let inv_sigma2 = inv_sigma * inv_sigma;
2160                for k in 0..n_pen {
2161                    let ck = if k < d {
2162                        lambdas[k] * self.marginal_eigensystems[k].0[multi_idx[k]]
2163                    } else if joint_null {
2164                        lambdas[d]
2165                    } else {
2166                        0.0
2167                    };
2168                    grad[k] += ck * inv_sigma;
2169                    hess[[k, k]] += ck * inv_sigma - ck * ck * inv_sigma2;
2170                    for l in (k + 1)..n_pen {
2171                        let cl = if l < d {
2172                            lambdas[l] * self.marginal_eigensystems[l].0[multi_idx[l]]
2173                        } else if joint_null {
2174                            lambdas[d]
2175                        } else {
2176                            0.0
2177                        };
2178                        let off = -ck * cl * inv_sigma2;
2179                        hess[[k, l]] += off;
2180                        hess[[l, k]] += off;
2181                    }
2182                }
2183            }
2184
2185            let mut carry = true;
2186            for dim in (0..d).rev() {
2187                if carry {
2188                    multi_idx[dim] += 1;
2189                    if multi_idx[dim] < self.marginal_dims[dim] {
2190                        carry = false;
2191                    } else {
2192                        multi_idx[dim] = 0;
2193                    }
2194                }
2195            }
2196            if carry {
2197                break;
2198            }
2199        }
2200        (logdet, rank, grad, hess)
2201    }
2202}
2203
#[cfg(test)]
mod joint_unpenalized_dim_tests {
    use super::joint_unpenalized_dim;
    use ndarray::{Array2, array};

    #[test]
    fn no_penalty_is_fully_unpenalized() {
        assert_eq!(joint_unpenalized_dim(4, &[], &[]), 4);
    }

    #[test]
    fn single_penalty_returns_its_own_null_space() {
        // A 3×3 penalty that penalizes only the last coordinate ⇒ 2-dim null
        // space (the first two coordinates are unpenalized).
        let s = array![[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 5.0]];
        assert_eq!(joint_unpenalized_dim(3, std::slice::from_ref(&s), &[2]), 2);
    }

    #[test]
    fn complementary_double_penalty_has_empty_joint_null_space() {
        // The #1360 case in miniature: a "bending" penalty that leaves the
        // first coordinate null (the real case has a 2-dim bending null space;
        // this miniature uses 1 dim), plus a complementary "null-space ridge"
        // that penalizes exactly that coordinate. Per-penalty null dims are
        // {1, 2} and sum to 3 (= p), but the INTERSECTION is empty: every
        // coordinate is penalized by some penalty, so the joint unpenalized
        // dim is 0.
        let bending = array![[0.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]];
        let ridge = array![[2.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]];
        assert_eq!(joint_unpenalized_dim(3, &[bending, ridge], &[1, 2]), 0);
    }

    #[test]
    fn partial_overlap_keeps_shared_null_direction() {
        // Two penalties that BOTH leave coordinate 0 unpenalized ⇒ the shared
        // null direction survives the intersection (joint unpenalized dim 1),
        // even though naively summing the per-penalty dims would give 4.
        let a = array![[0.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 0.0]];
        let b = array![[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]];
        assert_eq!(joint_unpenalized_dim(3, &[a, b], &[2, 2]), 1);
    }

    #[test]
    fn non_materialized_penalty_falls_back_conservatively() {
        // A penalty whose stored block is not p_local × p_local (e.g. a
        // Kronecker tensor factor). With ≥2 penalties the conservative joint
        // dim is 0 (never over-rejecting).
        let full: Array2<f64> = array![[0.0, 0.0], [0.0, 1.0]];
        let factor: Array2<f64> = array![[1.0]]; // wrong shape for p_local=2
        assert_eq!(
            joint_unpenalized_dim(2, &[full, factor.clone()], &[1, 0]),
            0
        );
        // With a single non-materialized penalty, fall back to its own null dim.
        assert_eq!(joint_unpenalized_dim(4, std::slice::from_ref(&factor), &[2]), 2);
    }
}
2260
#[cfg(test)]
mod kronecker_penalty_system_tests {
    use super::KroneckerPenaltySystem;
    use ndarray::array;

    #[test]
    fn double_penalty_rank_derivatives_use_only_joint_null_space() {
        let penalties = vec![
            array![[0.0, 0.0], [0.0, 2.0]],
            array![[0.0, 0.0], [0.0, 3.0]],
        ];
        let system = KroneckerPenaltySystem::new(penalties, vec![2usize, 2usize], true).unwrap();
        let lambdas = vec![5.0, 7.0, 11.0];

        let (logdet, rank, grad, hess) = system.logdet_rank_and_derivatives(&lambdas, 0.0);

        let expected_diag = [11.0_f64, 21.0, 10.0, 31.0];
        let expected_logdet: f64 = expected_diag.iter().map(|v| v.ln()).sum();
        assert_eq!(rank, 4);
        assert!((logdet - expected_logdet).abs() <= 1e-12);
        assert!(
            (grad[2] - 1.0).abs() <= 1e-12,
            "double-penalty rank derivative must count only the joint null mode, got {}",
            grad[2]
        );
        assert!(hess[[2, 2]].abs() <= 1e-12);
    }

    #[test]
    fn ridge_skips_structural_null_space_without_double_penalty() {
        // Same marginals, no double penalty, non-zero ridge. The (0, 0) mode is
        // structurally null in both margins, so it receives neither a double
        // penalty nor the ridge and drops out of the rank; every range-space
        // mode is shifted by the ridge.
        let penalties = vec![
            array![[0.0, 0.0], [0.0, 2.0]],
            array![[0.0, 0.0], [0.0, 3.0]],
        ];
        let system = KroneckerPenaltySystem::new(penalties, vec![2usize, 2usize], false).unwrap();
        let lambdas = [5.0_f64, 7.0];
        let ridge = 0.5;

        let (logdet, rank, grad, hess) = system.logdet_rank_and_derivatives(&lambdas, ridge);

        // Range-space modes as (σ, λ₀e₀, λ₁e₁).
        let modes = [
            (21.5_f64, 0.0_f64, 21.0_f64),
            (10.5, 10.0, 0.0),
            (31.5, 10.0, 21.0),
        ];
        let expected_logdet: f64 = modes.iter().map(|&(s, _, _)| s.ln()).sum();
        let expected_g0: f64 = modes.iter().map(|&(s, c0, _)| c0 / s).sum();
        let expected_g1: f64 = modes.iter().map(|&(s, _, c1)| c1 / s).sum();
        let expected_h00: f64 = modes
            .iter()
            .map(|&(s, c0, _)| c0 / s - c0 * c0 / (s * s))
            .sum();
        let expected_h01: f64 = modes.iter().map(|&(s, c0, c1)| -c0 * c1 / (s * s)).sum();

        assert_eq!(rank, 3);
        assert!((logdet - expected_logdet).abs() <= 1e-10);
        assert!((grad[0] - expected_g0).abs() <= 1e-10);
        assert!((grad[1] - expected_g1).abs() <= 1e-10);
        assert!((hess[[0, 0]] - expected_h00).abs() <= 1e-10);
        assert!((hess[[0, 1]] - expected_h01).abs() <= 1e-10);
        assert_eq!(hess[[0, 1]], hess[[1, 0]]);
    }
}
2289
/// Assembled design matrix, penalties, and coefficient-layout metadata for a
/// term collection.
#[derive(Clone, Debug)]
pub struct TermCollectionDesign {
    /// The full design matrix.
    ///
    /// Prefer a true sparse matrix when every block is sparse-compatible.
    /// If the collection already contains intrinsically sparse blocks, preserve
    /// that storage and let PIRLS decide later whether the penalized system is
    /// sparse-native eligible. Purely dense materialized blocks still fall back
    /// to the lazy block operator when sparse storage would just re-encode a
    /// dense matrix.
    pub design: DesignMatrix,
    /// Penalty blocks, each a local matrix plus the column range it occupies;
    /// `penalties_as_penalty_matrix` expands them against `design.ncols()`.
    pub penalties: Vec<BlockwisePenalty>,
    /// Null-space dimension per penalty — presumably parallel to `penalties`;
    /// confirm against the assembly code.
    pub nullspace_dims: Vec<usize>,
    /// Metadata for the retained penalty blocks (assumed parallel to
    /// `penalties` — TODO confirm).
    pub penaltyinfo: Vec<PenaltyBlockInfo>,
    /// Metadata for penalty blocks that were not kept in `penalties`
    /// (inferred from the type name — confirm).
    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    /// Optional global coefficient lower bounds for constrained fitting.
    /// Length equals `design.ncols()` when present. Unconstrained entries are `-inf`.
    pub coefficient_lower_bounds: Option<Array1<f64>>,
    /// Optional global inequality constraints:
    /// `A * beta >= b`.
    pub linear_constraints: Option<LinearInequalityConstraints>,
    /// Coefficient columns of the intercept (presumably columns of `design`).
    pub intercept_range: Range<usize>,
    /// `(name, column range)` for each linear term (presumably columns of `design`).
    pub linear_ranges: Vec<(String, Range<usize>)>,
    /// `(name, column range)` for each random-effect term.
    pub random_effect_ranges: Vec<(String, Range<usize>)>,
    /// `(name, level identifiers)` for each random-effect term; the `u64`
    /// level encoding is not established here — confirm with the builder.
    pub random_effect_levels: Vec<(String, Vec<u64>)>,
    /// Smooth-term design; `kronecker_penalty_system` inspects `smooth.terms`.
    pub smooth: SmoothDesign,
}
2317
2318impl TermCollectionDesign {
2319    /// Convert blockwise penalties to `PenaltyMatrix::Blockwise` without
2320    /// expanding to `p_total × p_total`. This is the preferred path for
2321    /// family modules that accept `Vec<PenaltyMatrix>`.
2322    pub fn penalties_as_penalty_matrix(&self) -> Vec<gam_problem::PenaltyMatrix> {
2323        let p = self.design.ncols();
2324        self.penalties
2325            .iter()
2326            .map(|bp| bp.to_penalty_matrix(p))
2327            .collect()
2328    }
2329
2330    /// Number of penalty blocks.
2331    #[inline]
2332    pub fn num_penalties(&self) -> usize {
2333        self.penalties.len()
2334    }
2335
2336    /// Resolve coefficient groups against this design's global coefficient
2337    /// layout and append their penalties after the existing term penalties.
2338    pub fn realize_coefficient_groups(
2339        &self,
2340        groups: &[CoefficientGroupSpec],
2341        base_prior: &gam_spec::RhoPrior,
2342    ) -> Result<RealizedCoefficientGroups, BasisError> {
2343        realize_coefficient_groups(self, groups, base_prior)
2344    }
2345
2346    /// Extract a `KroneckerPenaltySystem` when the model's *only* smooth term is
2347    /// a single Kronecker-factored tensor.
2348    ///
2349    /// This is a deliberate single-tensor fast path, not a partial feature: any
2350    /// other shape — zero Kronecker terms, several of them, or a tensor mixed
2351    /// with non-tensor smooth terms — is served correctly by the standard
2352    /// block-separable assembly, so this returns `None` and the caller falls
2353    /// back to it. The two former conditions (`len != 1` and "a non-Kronecker
2354    /// smooth term exists") are jointly equivalent to "the sole smooth term is
2355    /// Kronecker", which the slice pattern below expresses directly in one pass.
2356    pub fn kronecker_penalty_system(&self) -> Option<KroneckerPenaltySystem> {
2357        let [only_term] = self.smooth.terms.as_slice() else {
2358            return None;
2359        };
2360        let kron = only_term.kronecker_factored.as_ref()?;
2361        // A genuine tensor product needs at least two margins, and the marginal
2362        // design / penalty / dim collections must agree in length. A degenerate
2363        // (single-margin) or internally inconsistent factored basis cannot feed
2364        // the Kronecker fast path, so fall back to the standard assembly rather
2365        // than construct a malformed `KroneckerPenaltySystem` from it.
2366        if kron.marginal_dims.len() < 2
2367            || kron.marginal_penalties.len() != kron.marginal_dims.len()
2368            || kron.marginal_designs.len() != kron.marginal_dims.len()
2369        {
2370            return None;
2371        }
2372        KroneckerPenaltySystem::new(
2373            kron.marginal_penalties.clone(),
2374            kron.marginal_dims.clone(),
2375            kron.has_double_penalty,
2376        )
2377        .ok()
2378    }
2379}
2380
2381// `FittedTermCollection`, `SpatialLengthScaleOptimizationTiming`, and
2382// `FittedTermCollectionWithSpec` were relocated with the GAM fit-orchestration
2383// drivers to `gam-models` (`crate::fit_orchestration::drivers`) — they hold a
2384// `gam_solve::UnifiedFitResult` and are consumed only by those drivers (#1521).
2385
/// Configuration bundle for a latent-coordinate smooth term. Field meanings
/// below are inferred from names and types — confirm against the consumer.
#[derive(Clone)]
pub struct StandardLatentCoordConfig {
    /// Shared latent coordinate values (reference-counted, cheap to clone).
    pub values: std::sync::Arc<crate::latent::LatentCoordValues>,
    /// Index of the smooth term these coordinates feed.
    pub term_index: gam_problem::types::SmoothTermIdx,
    /// Feature column indices — presumably the data columns the latent
    /// coordinates occupy; confirm.
    pub feature_cols: Vec<usize>,
    /// Manifold the latent coordinates live on.
    pub manifold: crate::latent::LatentManifold,
    /// Presumably true when `manifold` was chosen automatically — confirm.
    pub manifold_auto: bool,
    /// Registry of retractions used for latent-coordinate updates.
    pub retraction_registry: gam_problem::LatentRetractionRegistry,
    /// Optional shared registry of analytic penalties.
    pub analytic_penalties: Option<std::sync::Arc<crate::AnalyticPenaltyRegistry>>,
}
2396
/// Per-term output of adaptive spatial regularization (serializable).
/// Weight semantics below are inferred from field names — confirm with the
/// adaptive-regularization driver.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AdaptiveSpatialMap {
    /// Name of the smooth term this map belongs to.
    pub termname: String,
    /// Feature column indices of the term.
    pub feature_cols: Vec<usize>,
    /// Points at which the weights below are evaluated (presumably one row
    /// per point — confirm).
    pub collocation_points: Array2<f64>,
    /// Presumably inverse weights for the magnitude penalty, one per
    /// collocation point — confirm.
    pub inv_magweight: Array1<f64>,
    /// Presumably inverse weights for the gradient penalty — confirm.
    pub invgradweight: Array1<f64>,
    /// Presumably inverse weights for the Laplacian penalty — confirm.
    pub inv_lapweight: Array1<f64>,
}
2406
/// Serializable diagnostics from an adaptive regularization run. The
/// ε-parameter meanings are not established in this chunk — confirm.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AdaptiveRegularizationDiagnostics {
    /// Final ε₀ value (meaning not shown here — TODO confirm).
    pub epsilon_0: f64,
    /// Final ε_g value (presumably gradient-related — confirm).
    pub epsilon_g: f64,
    /// Final ε_c value (meaning not shown here — TODO confirm).
    pub epsilon_c: f64,
    /// Number of outer ε iterations performed.
    pub epsilon_outer_iterations: usize,
    /// Number of inner MM iterations performed (presumably
    /// majorize–minimize — confirm).
    pub mm_iterations: usize,
    /// Whether the procedure reported convergence.
    pub converged: bool,
    /// One adaptive map per participating term.
    pub maps: Vec<AdaptiveSpatialMap>,
}
2417
/// Conditioning applied to one linear design column. Presumably the column
/// is centered by `mean` and divided by `scale` — confirm against the
/// code that applies/undoes it.
#[derive(Debug, Clone)]
pub struct LinearColumnConditioning {
    /// Index of the conditioned column.
    col_idx: usize,
    /// Centering value for the column.
    mean: f64,
    /// Scaling value for the column.
    scale: f64,
}
2424
/// Conditioning metadata for the linear part of a fit: the intercept column
/// plus per-column conditioning records.
#[derive(Debug, Clone, Default)]
pub struct LinearFitConditioning {
    /// Column index of the intercept (presumably the column that absorbs the
    /// centering shifts — confirm).
    pub intercept_idx: usize,
    /// One record per conditioned linear column.
    pub columns: Vec<LinearColumnConditioning>,
}
2430
/// Design and penalty derivatives of one spatial term with respect to a
/// single ψ coordinate.
#[derive(Clone)]
pub struct SpatialPsiDerivative {
    // These are derivatives with respect to psi = log(kappa), not log(length_scale).
    /// Primary penalty index for this coordinate (relationship to
    /// `penalty_indices` not shown here — confirm).
    pub penalty_index: usize,
    /// Indices of the penalties whose derivatives are listed in
    /// `s_psi_components_local` (assumed aligned — TODO confirm).
    pub penalty_indices: Vec<usize>,
    /// Presumably the term's column range in the global coefficient vector.
    pub global_range: Range<usize>,
    /// Presumably the total global coefficient count.
    pub total_p: usize,
    /// First design derivative `∂X/∂ψ` on the term's local columns; may be
    /// zero-sized when `implicit_operator` is used instead.
    pub x_psi_local: Array2<f64>,
    /// First penalty derivatives `∂S/∂ψ`, one local matrix per component.
    pub s_psi_components_local: Vec<Array2<f64>>,
    /// Second design derivative `∂²X/∂ψ²`; may be zero-sized when
    /// `implicit_operator` is used instead.
    pub x_psi_psi_local: Array2<f64>,
    /// Second penalty derivatives `∂²S/∂ψ²`, one local matrix per component.
    pub s_psi_psi_components_local: Vec<Array2<f64>>,
    /// Anisotropy group this axis belongs to, if any.
    pub aniso_group_id: Option<usize>,
    /// Pre-computed cross-derivative design matrices for other axes
    /// in the same aniso group: Vec of (axis_offset_in_group, matrix).
    pub aniso_cross_designs: Option<Vec<(usize, Array2<f64>)>>,
    /// On-demand cross-penalty second derivatives ∂²S_m/∂ψ_a∂ψ_b for axes in
    /// the same anisotropy group. The input is the other axis offset in the
    /// group, and the output is one local penalty matrix per active penalty.
    pub aniso_cross_penalty_provider: Option<
        std::sync::Arc<
            dyn Fn(usize) -> Result<Vec<Array2<f64>>, EstimationError> + Send + Sync + 'static,
        >,
    >,
    /// Optional implicit design-derivative operator (shared across all axes
    /// in the same aniso group). When present, `x_psi_local` and
    /// `x_psi_psi_local` may be zero-sized, and design-derivative matvecs
    /// should go through this operator using `implicit_axis` as the axis index.
    pub implicit_operator: Option<std::sync::Arc<crate::basis::ImplicitDesignPsiDerivative>>,
    /// Which axis in the implicit operator this entry corresponds to.
    pub implicit_axis: usize,
}
2462
/// Flattened ψ coordinates for all spatial terms, together with the number of
/// slots each term occupies. Invariant (checked by `new_with_dims`):
/// `values.len() == dims_per_term.iter().sum()`.
#[derive(Debug, Clone)]
pub struct SpatialLogKappaCoords {
    /// Flattened ψ values. For isotropic terms, one entry per term.
    /// For anisotropic terms, d entries per term (one ψ_a per axis).
    pub values: Array1<f64>,
    /// Dimensionality of each term: 1 for isotropic, d for anisotropic.
    pub dims_per_term: Vec<usize>,
}
2471
/// Which end of the ψ bound the shared `aniso_bounds_from_data` helper is
/// computing. The lower end uses `-max_length_scale.ln()` as the pure-Duchon
/// fallback and the `.0` element of `spatial_term_psi_bounds`; the upper end
/// uses `-min_length_scale.ln()` and `.1`. Everything else is identical.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AnisoBoundEnd {
    /// The lower (ψ_lo) end of the window.
    Lower,
    /// The upper (ψ_hi) end of the window.
    Upper,
}
2481
2482impl SpatialLogKappaCoords {
2483    /// Construct from an explicit dims layout plus values.
2484    pub fn new_with_dims(values: Array1<f64>, dims_per_term: Vec<usize>) -> Self {
2485        assert_eq!(
2486            values.len(),
2487            dims_per_term.iter().sum::<usize>(),
2488            "SpatialLogKappaCoords: values length {} != sum of dims_per_term {}",
2489            values.len(),
2490            dims_per_term.iter().sum::<usize>(),
2491        );
2492        Self {
2493            values,
2494            dims_per_term,
2495        }
2496    }
2497
2498    /// Isotropic initialization (backward-compatible path).
2499    pub fn from_length_scales(
2500        spec: &TermCollectionSpec,
2501        term_indices: &[usize],
2502        options: &SpatialLengthScaleOptimizationOptions,
2503    ) -> Self {
2504        let mut out = Array1::<f64>::zeros(term_indices.len());
2505        for (slot, &term_idx) in term_indices.iter().enumerate() {
2506            // Constant-curvature: the single ψ slot is the raw signed κ, seeded
2507            // from the spec (default κ = 0). The −ln(length_scale) convention is
2508            // log-κ semantics and must not touch the raw-κ coordinate; the κ
2509            // window projection happens later via `clamp_to_bounds`. Mirrors the
2510            // aniso constructor's κ branch.
2511            if let Some(cc) = constant_curvature_term_spec(spec, term_idx) {
2512                out[slot] = cc.kappa;
2513                continue;
2514            }
2515            let length_scale = get_spatial_length_scale(spec, term_idx)
2516                .unwrap_or(options.min_length_scale)
2517                .clamp(options.min_length_scale, options.max_length_scale);
2518            out[slot] = -length_scale.ln();
2519        }
2520        Self {
2521            values: out,
2522            dims_per_term: vec![1; term_indices.len()],
2523        }
2524    }
2525
2526    /// Anisotropic-aware initialization.
2527    ///
2528    /// Initialization strategy (per math team recommendation): standardize the
2529    /// knot cloud axiswise, then run the existing isotropic κ initializer in
2530    /// the standardized space. This reuses the trusted isotropic initializer
2531    /// and gives initial η_a = −ln(σ_a) + mean(ln(σ_a)), which satisfies
2532    /// Ση_a = 0 by construction.
2533    ///
2534    /// For each term, checks whether it has `aniso_log_scales` set on its basis spec.
2535    /// - If isotropic (no aniso_log_scales, or 1-D): 1 entry = −ln(length_scale).
2536    /// - If anisotropic with a scalar length scale: d entries, one ψ_a per axis.
2537    ///   Initialized as ψ_a = −ln(length_scale) + η_a  where η_a are the existing
2538    ///   aniso_log_scales (which sum to zero). Multi-dimensional terms without
2539    ///   explicit anisotropy stay scalar here so the seed dimensionality matches
2540    ///   `spatial_dims_per_term`.
2541    /// - If pure Duchon anisotropic: d - 1 free entries store the leading η_a
2542    ///   values directly; the final axis is reconstructed to keep Ση_a = 0.
2543    pub fn from_length_scales_aniso(
2544        spec: &TermCollectionSpec,
2545        term_indices: &[usize],
2546        options: &SpatialLengthScaleOptimizationOptions,
2547    ) -> Self {
2548        let mut vals = Vec::new();
2549        let mut dims = Vec::new();
2550        for &term_idx in term_indices {
2551            // Measure-jet: dial coordinates seeded directly from the term's
2552            // realized (α, τ[, s]); the −ln(length_scale) convention below is
2553            // κ-semantics and never applies to dials.
2554            if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
2555                let seed = measure_jet_psi_seed(mj);
2556                dims.push(seed.len());
2557                vals.extend(seed);
2558                continue;
2559            }
2560            // Constant-curvature: one signed κ slot seeded from the spec's κ
2561            // (clamped feasible). The −ln(length_scale) convention below is
2562            // log-κ semantics and must not touch the raw-κ coordinate. Bounds
2563            // are unavailable here (no data view), so this is the raw spec κ;
2564            // `reseed_from_data` / `clamp_to_bounds` later project it feasible.
2565            if let Some(cc) = constant_curvature_term_spec(spec, term_idx) {
2566                vals.push(cc.kappa);
2567                dims.push(1);
2568                continue;
2569            }
2570            let length_scale = get_spatial_length_scale(spec, term_idx)
2571                .unwrap_or(options.min_length_scale)
2572                .clamp(options.min_length_scale, options.max_length_scale);
2573            let psi_bar = -length_scale.ln(); // global scale = −ln(length_scale)
2574
2575            if spatial_term_uses_per_axis_psi(spec, term_idx) {
2576                // Per-axis anisotropy is enrolled in the joint outer vector:
2577                // ψ_a = ψ̄ + η_a, one slot per axis. The hyper_dirs builder
2578                // produces matching per-axis derivatives in
2579                // `try_build_spatial_term_log_kappa_aniso_derivativeinfos`.
2580                let d = get_spatial_feature_dim(spec, term_idx).unwrap_or(1);
2581                let eta_raw = get_spatial_aniso_log_scales(spec, term_idx)
2582                    .expect("predicate guarantees aniso_log_scales is Some");
2583                let eta = center_aniso_log_scales(&eta_raw);
2584                for &eta_a in &eta {
2585                    vals.push(psi_bar + eta_a);
2586                }
2587                dims.push(d);
2588            } else {
2589                // Isotropic enrollment — either a 1-D term, a multi-D term
2590                // without explicit anisotropy, or a basis (e.g. Duchon) whose
2591                // η is a fixed geometry parameter rather than a REML hyper
2592                // axis. Exactly one ψ̄ slot, matching the single
2593                // `SpatialPsiDerivative` produced by
2594                // `try_build_spatial_term_log_kappa_derivativeinfo`.
2595                vals.push(psi_bar);
2596                dims.push(1);
2597            }
2598        }
2599        Self {
2600            values: Array1::from_vec(vals),
2601            dims_per_term: dims,
2602        }
2603    }
2604
2605    /// Isotropic lower bounds derived from per-term data geometry.
2606    /// Each entry gets the ψ_lo bound returned by `spatial_term_psi_bounds`
2607    /// for the corresponding term, intersected with the options window.
2608    pub fn lower_bounds_from_data(
2609        data: ArrayView2<'_, f64>,
2610        spec: &TermCollectionSpec,
2611        term_indices: &[usize],
2612        options: &SpatialLengthScaleOptimizationOptions,
2613    ) -> Self {
2614        let mut values = Array1::<f64>::zeros(term_indices.len());
2615        for (slot, &term_idx) in term_indices.iter().enumerate() {
2616            values[slot] = spatial_term_psi_bounds(data, spec, term_idx, options).0;
2617        }
2618        Self {
2619            values,
2620            dims_per_term: vec![1; term_indices.len()],
2621        }
2622    }
2623
2624    /// Isotropic upper bounds derived from per-term data geometry.
2625    pub fn upper_bounds_from_data(
2626        data: ArrayView2<'_, f64>,
2627        spec: &TermCollectionSpec,
2628        term_indices: &[usize],
2629        options: &SpatialLengthScaleOptimizationOptions,
2630    ) -> Self {
2631        let mut values = Array1::<f64>::zeros(term_indices.len());
2632        for (slot, &term_idx) in term_indices.iter().enumerate() {
2633            values[slot] = spatial_term_psi_bounds(data, spec, term_idx, options).1;
2634        }
2635        Self {
2636            values,
2637            dims_per_term: vec![1; term_indices.len()],
2638        }
2639    }
2640
2641    /// Anisotropic-aware lower bounds derived from per-term data geometry.
2642    /// For hybrid anisotropic terms the scalar ψ_lo bound applies to the
2643    /// mean `ψ̄`, not directly to every raw axis coordinate `ψ_a = ψ̄ + η_a`.
2644    /// Shift each axis by the current centered `η_a` so projecting/clamping
2645    /// the seed moves only the global scale direction and does not silently
2646    /// shrink anisotropy that is already consistent with the current
2647    /// `length_scale`.
2648    ///
2649    /// Pure Duchon anisotropy is structurally different: its stored
2650    /// coordinates are (d-1) free η_a values representing log axis-scale
2651    /// ratios, NOT log-κ. For those terms the κ-range geometry bound is
2652    /// over-restrictive (η_a = ±5 is normal, but that corresponds to 7+
2653    /// orders of magnitude in κ-space and would be rejected by the data
2654    /// window). Fall back to the options window `[-ln(max_ls), -ln(min_ls)]`
2655    /// for those coordinates — that's the same bound the pre-data-geometry
2656    /// code used, which is calibrated to allow legitimate anisotropy.
2657    pub fn lower_bounds_aniso_from_data(
2658        data: ArrayView2<'_, f64>,
2659        spec: &TermCollectionSpec,
2660        term_indices: &[usize],
2661        dims_per_term: &[usize],
2662        options: &SpatialLengthScaleOptimizationOptions,
2663    ) -> Self {
2664        Self::aniso_bounds_from_data(
2665            data,
2666            spec,
2667            term_indices,
2668            dims_per_term,
2669            options,
2670            AnisoBoundEnd::Lower,
2671        )
2672    }
2673
2674    /// Anisotropic-aware upper bounds derived from per-term data geometry.
2675    /// See `lower_bounds_aniso_from_data` for the hybrid-aniso offsetting and
2676    /// pure-Duchon dispatch rationale.
2677    pub fn upper_bounds_aniso_from_data(
2678        data: ArrayView2<'_, f64>,
2679        spec: &TermCollectionSpec,
2680        term_indices: &[usize],
2681        dims_per_term: &[usize],
2682        options: &SpatialLengthScaleOptimizationOptions,
2683    ) -> Self {
2684        Self::aniso_bounds_from_data(
2685            data,
2686            spec,
2687            term_indices,
2688            dims_per_term,
2689            options,
2690            AnisoBoundEnd::Upper,
2691        )
2692    }
2693
2694    /// Shared implementation for the lower/upper aniso bounds. The bound end
2695    /// only changes which options scale (`max_length_scale` vs
2696    /// `min_length_scale`) becomes the pure-Duchon fallback bound and which
2697    /// element of the `(lo, hi)` data-geometry tuple is consumed; the
2698    /// per-term cursor walk and aniso-offset handling are identical.
2699    fn aniso_bounds_from_data(
2700        data: ArrayView2<'_, f64>,
2701        spec: &TermCollectionSpec,
2702        term_indices: &[usize],
2703        dims_per_term: &[usize],
2704        options: &SpatialLengthScaleOptimizationOptions,
2705        end: AnisoBoundEnd,
2706    ) -> Self {
2707        assert_eq!(term_indices.len(), dims_per_term.len());
2708        let total: usize = dims_per_term.iter().sum();
2709        let mut values = Array1::<f64>::zeros(total);
2710        let mut cursor = 0;
2711        for (slot, &term_idx) in term_indices.iter().enumerate() {
2712            let d = dims_per_term[slot];
2713            // Measure-jet: per-coordinate dial boxes, never κ-window geometry
2714            // (which would reject legitimate dial values outright).
2715            if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
2716                let bounds = measure_jet_psi_bound_values(mj, matches!(end, AnisoBoundEnd::Upper));
2717                for (offset, bound) in bounds.into_iter().enumerate() {
2718                    if offset < d {
2719                        values[cursor + offset] = bound;
2720                    }
2721                }
2722                cursor += d;
2723                continue;
2724            }
2725            // Constant-curvature: the single signed-κ box from the data chart
2726            // window (symmetric about κ = 0), never a κ = log-scale window.
2727            if constant_curvature_term_spec(spec, term_idx).is_some() {
2728                let (lo, hi) = constant_curvature_kappa_bounds(data, spec, term_idx);
2729                if d >= 1 {
2730                    values[cursor] = match end {
2731                        AnisoBoundEnd::Lower => lo,
2732                        AnisoBoundEnd::Upper => hi,
2733                    };
2734                }
2735                cursor += d;
2736                continue;
2737            }
2738            let psi_bound = {
2739                let (lo, hi) = spatial_term_psi_bounds(data, spec, term_idx, options);
2740                match end {
2741                    AnisoBoundEnd::Lower => lo,
2742                    AnisoBoundEnd::Upper => hi,
2743                }
2744            };
2745            let axis_offsets = if d <= 1 {
2746                vec![0.0; d]
2747            } else {
2748                get_spatial_aniso_log_scales(spec, term_idx)
2749                    .filter(|eta| eta.len() == d)
2750                    .map(|eta| center_aniso_log_scales(&eta))
2751                    .unwrap_or_else(|| vec![0.0; d])
2752            };
2753            for offset in 0..d {
2754                values[cursor + offset] = psi_bound + axis_offsets[offset];
2755            }
2756            cursor += d;
2757        }
2758        Self {
2759            values,
2760            dims_per_term: dims_per_term.to_vec(),
2761        }
2762    }
2763
2764    /// Rewrite any ψ entries whose originating term lacks an explicit
2765    /// `length_scale` so they sit at the midpoint of the per-term data-derived
2766    /// ψ window. Used so the outer optimizer starts inside the physically
2767    /// meaningful region instead of at an arbitrary `options.max_length_scale`
2768    /// derived seed. For terms with an explicit length_scale, the user's
2769    /// choice is respected. Anisotropy offsets η_a (those stored by
2770    /// `from_length_scales_aniso`) are preserved: we re-center around the new
2771    /// ψ̄, keeping Ση_a = 0.
2772    pub fn reseed_from_data(
2773        mut self,
2774        data: ArrayView2<'_, f64>,
2775        spec: &TermCollectionSpec,
2776        term_indices: &[usize],
2777        options: &SpatialLengthScaleOptimizationOptions,
2778    ) -> Self {
2779        assert_eq!(term_indices.len(), self.dims_per_term.len());
2780        let mut cursor = 0;
2781        for (slot, &term_idx) in term_indices.iter().enumerate() {
2782            let d = self.dims_per_term[slot];
2783            // Measure-jet dials are seeded from the realized spec and must
2784            // not be recentered into a κ data window.
2785            if measure_jet_term_spec(spec, term_idx).is_some() {
2786                cursor += d;
2787                continue;
2788            }
2789            // Constant-curvature κ is seeded from the spec (the user's curvature
2790            // hint, default κ = 0); `clamp_to_bounds` projects it feasible. It
2791            // is not a log-scale, so the log-κ recenter below never applies.
2792            if constant_curvature_term_spec(spec, term_idx).is_some() {
2793                cursor += d;
2794                continue;
2795            }
2796            let Some(psi_bar_new) = spatial_term_psi_seed(data, spec, term_idx, options) else {
2797                cursor += d;
2798                continue;
2799            };
2800            if d == 0 {
2801                continue;
2802            }
2803            let current: Vec<f64> = self.values.slice(s![cursor..cursor + d]).to_vec();
2804            let psi_bar_old = current.iter().sum::<f64>() / d as f64;
2805            for (offset, &old_value) in current.iter().enumerate() {
2806                self.values[cursor + offset] = psi_bar_new + (old_value - psi_bar_old);
2807            }
2808            cursor += d;
2809        }
2810        self
2811    }
2812
2813    /// Project ψ values into `[lower, upper]` element-wise. Used after
2814    /// `from_length_scales*` + `reseed_from_data` when a user-supplied
2815    /// `spec.length_scale` falls outside the data-derived ψ window set by
2816    /// `{lower,upper}_bounds*_from_data`. BFGS requires theta0 ∈ [lower,
2817    /// upper]; projecting is the unique closest feasible seed. The user's
2818    /// length_scale was always a hint for the outer optimizer (the optimizer
2819    /// is authoritative for κ), not a hard constraint — so clipping preserves
2820    /// their intent as far as the geometry allows. Emits `log::info!` when
2821    /// any coordinate moves, so the outside-window case is diagnostically
2822    /// visible (not silent).
2823    pub fn clamp_to_bounds(
2824        mut self,
2825        lower: &SpatialLogKappaCoords,
2826        upper: &SpatialLogKappaCoords,
2827    ) -> Self {
2828        assert_eq!(self.values.len(), lower.values.len());
2829        assert_eq!(self.values.len(), upper.values.len());
2830        let mut n_projected = 0usize;
2831        let mut worst_delta = 0.0_f64;
2832        for idx in 0..self.values.len() {
2833            let lo = lower.values[idx];
2834            let hi = upper.values[idx];
2835            if !(lo.is_finite() && hi.is_finite()) {
2836                continue;
2837            }
2838            let v = self.values[idx];
2839            if v < lo {
2840                worst_delta = worst_delta.max(lo - v);
2841                self.values[idx] = lo;
2842                n_projected += 1;
2843            } else if v > hi {
2844                worst_delta = worst_delta.max(v - hi);
2845                self.values[idx] = hi;
2846                n_projected += 1;
2847            }
2848        }
2849        if n_projected > 0 {
2850            log::info!(
2851                "[spatial-kappa] projected {n_projected}/{} ψ seed coords into data-derived bounds \
2852                 (worst excess={worst_delta:.3} log units); user length_scale falls outside \
2853                 [{KERNEL_RANGE_MIN_DIAMETER_FRACTION}/r_max, {KERNEL_RANGE_MAX_SPACING_MULTIPLE}/r_min] geometry window",
2854                self.values.len()
2855            );
2856        }
2857        self
2858    }
2859
2860    /// Reconstruct from theta tail with known dimensionality layout.
2861    pub fn from_theta_tail_with_dims(
2862        theta: &Array1<f64>,
2863        start: usize,
2864        dims_per_term: Vec<usize>,
2865    ) -> Self {
2866        let total: usize = dims_per_term.iter().sum();
2867        Self {
2868            values: theta.slice(s![start..start + total]).to_owned(),
2869            dims_per_term,
2870        }
2871    }
2872
2873    /// Total number of ψ values in the flat array (= sum of dims_per_term).
2874    pub fn len(&self) -> usize {
2875        self.values.len()
2876    }
2877
2878    /// Dimensionality layout: how many ψ values each term contributes.
2879    pub fn dims_per_term(&self) -> &[usize] {
2880        &self.dims_per_term
2881    }
2882
2883    /// Get the offset into the flat array for logical term i.
2884    fn term_offset(&self, term_idx: usize) -> usize {
2885        self.dims_per_term[..term_idx].iter().sum()
2886    }
2887
2888    /// Get the slice of ψ values for logical term i.
2889    pub fn term_slice(&self, term_idx: usize) -> &[f64] {
2890        let offset = self.term_offset(term_idx);
2891        let d = self.dims_per_term[term_idx];
2892        &self.values.as_slice().unwrap()[offset..offset + d]
2893    }
2894
2895    pub fn as_array(&self) -> &Array1<f64> {
2896        &self.values
2897    }
2898
2899    /// #1464: overwrite the single ψ value of a scalar (1-D) logical term by its
2900    /// position `slot` in this coords vector (the same ordering as the
2901    /// `term_indices` slice the constructors were built from). Used to inject the
2902    /// fixed-κ sign-basin seed into a constant-curvature term's raw-κ slot before
2903    /// the joint solve. No-op (returns `false`) when the slot is not scalar.
2904    pub fn set_scalar_slot(&mut self, slot: usize, value: f64) -> bool {
2905        if slot >= self.dims_per_term.len() || self.dims_per_term[slot] != 1 {
2906            return false;
2907        }
2908        let offset = self.term_offset(slot);
2909        self.values[offset] = value;
2910        true
2911    }
2912
2913    /// Split at a logical-term boundary. `mid` is the number of terms in the
2914    /// first half (not a flat-array index).
2915    pub fn split_at(&self, mid: usize) -> (Self, Self) {
2916        let flat_mid: usize = self.dims_per_term[..mid].iter().sum();
2917        (
2918            Self {
2919                values: self.values.slice(s![0..flat_mid]).to_owned(),
2920                dims_per_term: self.dims_per_term[..mid].to_vec(),
2921            },
2922            Self {
2923                values: self.values.slice(s![flat_mid..]).to_owned(),
2924                dims_per_term: self.dims_per_term[mid..].to_vec(),
2925            },
2926        )
2927    }
2928
2929    /// Apply optimized ψ values back to the spec.
2930    ///
2931    /// For isotropic terms (dims=1): sets scalar length_scale = exp(−ψ).
2932    /// For anisotropic terms (dims=d): hybrid/isotropic families set
2933    /// length_scale = exp(−ψ̄) with centered η_a = ψ_a − ψ̄, while pure Duchon
2934    /// writes only centered η_a and leaves length_scale = None.
2935    pub fn apply_tospec(
2936        &self,
2937        spec: &TermCollectionSpec,
2938        term_indices: &[usize],
2939    ) -> Result<TermCollectionSpec, EstimationError> {
2940        if term_indices.len() != self.dims_per_term.len() {
2941            crate::bail_invalid_estim!(
2942                "SpatialLogKappaCoords::apply_tospec: term count mismatch: \
2943                 term_indices={} dims_per_term={}",
2944                term_indices.len(),
2945                self.dims_per_term.len()
2946            );
2947        }
2948        let mut updated = spec.clone();
2949        for (slot, &term_idx) in term_indices.iter().enumerate() {
2950            let psi = self.term_slice(slot);
2951            let d = self.dims_per_term[slot];
2952            // Measure-jet: write the dial coordinates straight back; the
2953            // κ-translation below would misread them as log-scales.
2954            if measure_jet_term_spec(&updated, term_idx).is_some() {
2955                set_measure_jet_psi_dials(&mut updated, term_idx, psi)?;
2956                continue;
2957            }
2958            // Constant-curvature: write the optimized signed κ straight back;
2959            // the −exp(ψ) length-scale translation below is log-κ semantics and
2960            // would misread the raw curvature.
2961            if constant_curvature_term_spec(&updated, term_idx).is_some() {
2962                set_constant_curvature_kappa(&mut updated, term_idx, psi)?;
2963                continue;
2964            }
2965            let (next_length_scale, next_aniso) = spatial_term_psi_to_length_scale_and_aniso(psi);
2966            if (d == 1 || next_length_scale.is_some())
2967                && let Some(length_scale) = next_length_scale
2968            {
2969                set_spatial_length_scale(&mut updated, term_idx, length_scale)?;
2970            }
2971            if let Some(eta) = next_aniso {
2972                set_spatial_aniso_log_scales(&mut updated, term_idx, eta)?;
2973            }
2974        }
2975        Ok(updated)
2976    }
2977}
2978
/// Subtract the mean from a vector of per-axis anisotropy log-scales so that
/// the result sums to (numerically) zero.
///
/// Inputs with fewer than two entries carry no contrast and are returned
/// unchanged. Centred values with magnitude at most `1e-15` are snapped to
/// exactly `0.0` to suppress round-off noise. A NaN anywhere in the input
/// propagates to every output.
pub fn center_aniso_log_scales(eta: &[f64]) -> Vec<f64> {
    if eta.len() < 2 {
        return eta.to_vec();
    }
    let mean = eta.iter().sum::<f64>() / eta.len() as f64;
    let mut centered = Vec::with_capacity(eta.len());
    for &log_scale in eta {
        let deviation = log_scale - mean;
        centered.push(if deviation.abs() <= 1e-15 { 0.0 } else { deviation });
    }
    centered
}
2995
2996/// Whether a spatial term contributes per-axis ψ entries to the outer joint
2997/// hyperparameter vector.
2998pub fn spatial_term_uses_per_axis_psi(resolvedspec: &TermCollectionSpec, term_idx: usize) -> bool {
2999    if let Some(mj) = measure_jet_term_spec(resolvedspec, term_idx) {
3000        return measure_jet_enrolls_psi(mj);
3001    }
3002    let Some(d) = get_spatial_feature_dim(resolvedspec, term_idx) else {
3003        return false;
3004    };
3005    if d <= 1 {
3006        return false;
3007    }
3008    let Some(eta) = get_spatial_aniso_log_scales(resolvedspec, term_idx) else {
3009        return false;
3010    };
3011    if eta.len() != d {
3012        return false;
3013    }
3014    !matches!(
3015        resolvedspec.smooth_terms.get(term_idx).map(|term| &term.basis),
3016        Some(SmoothBasisSpec::Duchon { .. })
3017    )
3018}
3019
3020pub fn set_spatial_length_scale(
3021    spec: &mut TermCollectionSpec,
3022    term_idx: usize,
3023    length_scale: f64,
3024) -> Result<(), EstimationError> {
3025    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3026        crate::bail_invalid_estim!("spatial length-scale term index {term_idx} out of range");
3027    };
3028    match &mut term.basis {
3029        SmoothBasisSpec::ThinPlate { spec, .. } => {
3030            spec.length_scale = length_scale;
3031            Ok(())
3032        }
3033        SmoothBasisSpec::Matern { spec, .. } => {
3034            spec.length_scale = length_scale;
3035            Ok(())
3036        }
3037        SmoothBasisSpec::Duchon { spec, .. } => {
3038            spec.length_scale = Some(length_scale);
3039            Ok(())
3040        }
3041        _ => Err(EstimationError::InvalidInput(format!(
3042            "term '{}' does not expose a spatial length scale",
3043            term.name
3044        ))),
3045    }
3046}
3047
3048pub fn get_spatial_length_scale(spec: &TermCollectionSpec, term_idx: usize) -> Option<f64> {
3049    spec.smooth_terms
3050        .get(term_idx)
3051        .and_then(|term| match &term.basis {
3052            SmoothBasisSpec::ThinPlate { spec, .. } => Some(spec.length_scale),
3053            SmoothBasisSpec::Matern { spec, .. } => Some(spec.length_scale),
3054            SmoothBasisSpec::Duchon { spec, .. } => spec.length_scale,
3055            _ => None,
3056        })
3057}
3058
3059pub fn spatial_term_supports_hyper_optimization(spec: &TermCollectionSpec, term_idx: usize) -> bool {
3060    // Ordinary penalized thin-plate regression splines do not have an
3061    // identifiable kernel scale once REML is already learning the smoothing
3062    // penalty. Treat the resolved length scale as fixed geometry; enrolling a
3063    // scalar TPS kappa axis creates the flat ρ/κ valleys reported in #718,
3064    // #721, #731, and #732.
3065    if let Some(term) = spec.smooth_terms.get(term_idx)
3066        && let SmoothBasisSpec::ThinPlate { .. } = &term.basis
3067    {
3068        return false;
3069    }
3070
3071    // Duchon anisotropy η is a FIXED, geometry-derived basis parameter, NOT a
3072    // REML hyper axis: the metric is estimated once from the knot-cloud spread
3073    // (`auto_seed_aniso_contrasts`, applied on every Duchon basis build) and the
3074    // Hilbert-scale λ's carry all learned smoothness. So a pure Duchon (no κ)
3075    // contributes no outer optimization axis even when `scale_dims` is on —
3076    // "standardize the geometry, then learn the smoothness." Only an explicit
3077    // kernel length scale κ (the Matérn / hybrid path) is optimized here.
3078    //
3079    // ISOTROPIC Matérn: the *default* `matern(x1, x2)` is isotropic
3080    // (`scale_dims=false` → `aniso_log_scales = None`). It contributes exactly
3081    // ONE κ optimization axis — its scalar log-κ. The shared GAMLSS /
3082    // location-scale exact-joint ψ engine and the spatial-κ joint outer solver
3083    // both require an isotropic Matérn block to expose this single isotropic κ
3084    // axis (#822/#851); without it the per-block ψ-derivative lists are empty
3085    // and the joint-ψ hooks degenerate to `None`. The isotropic κ is the lone
3086    // kernel hyper axis here, mirroring the per-axis ψ ARD that the anisotropic
3087    // path exposes (just collapsed to one dimension).
3088    //
3089    // ANISOTROPIC Matérn (`scale_dims=true` → `aniso_log_scales = Some`) keeps
3090    // its per-axis kernel-η ARD: the d-dimensional ψ search is the *point* of
3091    // the anisotropic request ("Matérn keeps its kernel-η ARD").
3092    //
3093    // Either way a Matérn term always enrolls a κ/ψ axis (1 isotropic, or d
3094    // anisotropic), so `spatial_dims_per_term` reports the correct count.
3095    if let Some(term) = spec.smooth_terms.get(term_idx)
3096        && let SmoothBasisSpec::Matern { .. } = &term.basis
3097    {
3098        return true;
3099    }
3100
3101    // Measure-jet geometry dials are outer ψ coordinates; enrollment is
3102    // owned by `measure_jet_enrolls_psi`.
3103    if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
3104        return measure_jet_enrolls_psi(mj);
3105    }
3106
3107    // Constant-curvature smooths always enroll their single signed curvature κ
3108    // as an outer ψ-coordinate (#944 stage 3): κ̂ is the headline estimand, so
3109    // unlike a fixed-ℓ kernel it is fitted by default, not gated on a
3110    // user-supplied scale. The coordinate is raw κ (interior κ = 0), and its
3111    // exact design/penalty κ-derivatives come from
3112    // `build_constant_curvature_basis_kappa_derivatives`.
3113    if constant_curvature_term_spec(spec, term_idx).is_some() {
3114        return true;
3115    }
3116
3117    get_spatial_length_scale(spec, term_idx).is_some()
3118}
3119
3120/// The measure-jet term's spec, when `term_idx` is a measure-jet smooth.
3121/// Single accessor for every dial-plumbing dispatch below.
3122pub fn measure_jet_term_spec(
3123    spec: &TermCollectionSpec,
3124    term_idx: usize,
3125) -> Option<&crate::basis::MeasureJetBasisSpec> {
3126    spec.smooth_terms
3127        .get(term_idx)
3128        .and_then(|term| match &term.basis {
3129            SmoothBasisSpec::MeasureJet { spec, .. } => Some(spec),
3130            _ => None,
3131        })
3132}
3133
3134/// Single source for measure-jet outer-ψ enrollment: the lnτ dial is
3135/// undefined in the τ = 0 pseudo-inverse oracle mode (see
3136/// `build_measure_jet_basis_psi_derivatives`), so only a positive ridge
3137/// enrolls the dial group. `spatial_term_supports_hyper_optimization` and
3138/// `spatial_term_uses_per_axis_psi` both defer here so the θ-layout
3139/// sources cannot disagree.
3140pub fn measure_jet_enrolls_psi(mj: &crate::basis::MeasureJetBasisSpec) -> bool {
3141    // Two independent enrollment sources (#1116), both explicit:
3142    //   * the design-moving representer length-scale ℓ (`learn_length_scale`),
3143    //     available in every mode when the spec opts in;
3144    //   * the multiscale penalty dials (s, α, lnτ): the per-scale spectral
3145    //     split's (α, lnτ) ride the explicit `multiscale` opt-in, and the lnτ
3146    //     channel additionally needs a positive ridge (τ = 0 is the
3147    //     pseudo-inverse oracle mode where lnτ is undefined).
3148    // A term enrolls if EITHER source is active.
3149    measure_jet_learns_length_scale(mj)
3150        || (mj.tau0 > 0.0 && crate::basis::measure_jet_multiscale_mode(mj))
3151}
3152
3153/// Whether the design-moving ℓ dial is enrolled for this term. ℓ is fixed by
3154/// default and learnable in every mode only when `learn_length_scale = true`.
3155pub fn measure_jet_learns_length_scale(mj: &crate::basis::MeasureJetBasisSpec) -> bool {
3156    mj.learn_length_scale
3157}
3158
3159pub fn freeze_measure_jet_length_scale_learning(spec: &mut TermCollectionSpec) -> usize {
3160    let mut frozen = 0;
3161    for term in spec.smooth_terms.iter_mut() {
3162        if let SmoothBasisSpec::MeasureJet { spec: mj, .. } = &mut term.basis
3163            && mj.learn_length_scale
3164        {
3165            mj.learn_length_scale = false;
3166            frozen += 1;
3167        }
3168    }
3169    frozen
3170}
3171
/// Measure-jet ψ dial boxes.
///
/// The dials are NOT log-kernel-scales, so the κ-window machinery never
/// applies. The energy order `s` has no box: it is either the pinned explicit
/// value or absorbed by the REML-learned per-scale amplitudes (see
/// `measure_jet_penalty_psi_dim`).
///
/// α box: spans density-weighted normalization (α = 0) through
/// past-Coifman–Lafon normalization (α > 1).
pub const MEASURE_JET_PSI_ALPHA_BOUNDS: (f64, f64) = (-1.0, 3.0);

/// lnτ box: `(ln 1e-8, ln 1e2)` = `(-18.4206…, 4.6051…)`. This covers the
/// ridge from numerically-exact projection to heavy noise-floor damping.
pub const MEASURE_JET_PSI_LN_TAU_BOUNDS: (f64, f64) =
    (-18.420_680_743_952_367, 4.605_170_185_988_092);

/// Log-ℓ box for the design-moving representer length-scale dial (#1116).
///
/// An ABSOLUTE window in the data coordinate scale, `ln ℓ` with
/// `ℓ ∈ [1e-3, 1e2]` (`ln(1e-3) = -6.9077…`, `ln(1e2) = 4.6051…`). It is used
/// only when the spec explicitly enrolls the learned representer range. The
/// window is absolute rather than seed-relative so the bound producer needs no
/// data view, matching the other dial boxes.
pub const MEASURE_JET_PSI_LN_LENGTH_SCALE_BOUNDS: (f64, f64) =
    (-6.907_755_278_982_137, 4.605_170_185_988_092);
3189
3190/// Number of multiscale PENALTY dials (excluding the design-moving ℓ):
3191/// multiscale (per-scale spectral) mode carries (α, lnτ) = 2 — the order is
3192/// either the pinned explicit `s` or absorbed by the REML-learned per-scale
3193/// amplitudes, so it is NOT a dial; single-scale (the default) carries none.
3194/// MUST agree with the penalty-coordinate layout of
3195/// `build_measure_jet_basis_psi_derivatives` (its `per_level` branch always
3196/// emits exactly the (α, lnτ) coordinate pair).
3197pub fn measure_jet_penalty_psi_dim(mj: &crate::basis::MeasureJetBasisSpec) -> usize {
3198    if crate::basis::measure_jet_multiscale_mode(mj) {
3199        2
3200    } else {
3201        0
3202    }
3203}
3204
3205/// ψ dimension of a measure-jet term. The design-moving ℓ dial (when enrolled)
3206/// is coordinate 0; the multiscale penalty dials follow. MUST agree with the
3207/// coordinate layout of `build_measure_jet_basis_psi_derivatives` (ℓ first).
3208pub fn measure_jet_psi_dim(mj: &crate::basis::MeasureJetBasisSpec) -> usize {
3209    usize::from(measure_jet_learns_length_scale(mj)) + measure_jet_penalty_psi_dim(mj)
3210}
3211
3212/// Seed ψ from the term's realized dials, in producer coordinate order: ℓ first
3213/// (when enrolled), then the multiscale penalty dials. The ℓ seed is the
3214/// realized representer range `ln(length_scale)` (the resolved spec carries the
3215/// concrete auto value after the design build/freeze).
3216pub fn measure_jet_psi_seed(mj: &crate::basis::MeasureJetBasisSpec) -> Vec<f64> {
3217    let mut seed = Vec::with_capacity(measure_jet_psi_dim(mj));
3218    if measure_jet_learns_length_scale(mj) {
3219        // length_scale > 0 after resolution; the 0.0 sentinel (pre-resolution)
3220        // falls back to the centre of the log-ℓ box so the optimizer still
3221        // starts feasible and the first data-aware reseed corrects it.
3222        let ell = if mj.length_scale > 0.0 {
3223            mj.length_scale
3224        } else {
3225            1.0
3226        };
3227        seed.push(ell.ln());
3228    }
3229    if measure_jet_penalty_psi_dim(mj) > 0 {
3230        // Multiscale penalty dials, producer order: (α, lnτ).
3231        let ln_tau = mj.tau0.max(f64::MIN_POSITIVE).ln();
3232        seed.extend_from_slice(&[mj.alpha, ln_tau]);
3233    }
3234    seed
3235}
3236
3237/// One end of the per-coordinate dial boxes, in producer coordinate order
3238/// (ℓ first when enrolled, then the multiscale penalty dials).
3239pub fn measure_jet_psi_bound_values(mj: &crate::basis::MeasureJetBasisSpec, upper: bool) -> Vec<f64> {
3240    let pick = |b: (f64, f64)| if upper { b.1 } else { b.0 };
3241    let mut bounds = Vec::with_capacity(measure_jet_psi_dim(mj));
3242    if measure_jet_learns_length_scale(mj) {
3243        bounds.push(pick(MEASURE_JET_PSI_LN_LENGTH_SCALE_BOUNDS));
3244    }
3245    if measure_jet_penalty_psi_dim(mj) > 0 {
3246        // Multiscale penalty dials, producer order: (α, lnτ).
3247        bounds.push(pick(MEASURE_JET_PSI_ALPHA_BOUNDS));
3248        bounds.push(pick(MEASURE_JET_PSI_LN_TAU_BOUNDS));
3249    }
3250    bounds
3251}
3252
3253/// Write optimized ψ dials back into a measure-jet spec. Returns `true` when
3254/// any dial actually moved. The geometry (centers, masses, band, ℓ, z) is
3255/// ψ-FIXED by contract — only the dials change, so frozen-quadrature
3256/// rebuilds reproduce the identical penalty layout at the new dials.
3257pub fn apply_measure_jet_psi(
3258    mj: &mut crate::basis::MeasureJetBasisSpec,
3259    psi: &[f64],
3260) -> Result<bool, EstimationError> {
3261    if psi.len() != measure_jet_psi_dim(mj) {
3262        crate::bail_invalid_estim!(
3263            "measure-jet ψ write-back dimension mismatch: got {} values for a {}-dial term",
3264            psi.len(),
3265            measure_jet_psi_dim(mj)
3266        );
3267    }
3268    let mut changed = false;
3269    // Coordinate 0 (when enrolled) is the design-moving ln(ℓ); the multiscale
3270    // penalty dials follow. Same order as `measure_jet_psi_seed` and the
3271    // producer (`build_measure_jet_basis_psi_derivatives`).
3272    let mut cursor = 0usize;
3273    if measure_jet_learns_length_scale(mj) {
3274        let next_ell = psi[cursor].exp();
3275        cursor += 1;
3276        if !(next_ell.is_finite() && next_ell > 0.0) {
3277            crate::bail_invalid_estim!(
3278                "measure-jet ψ write-back produced a non-finite/non-positive length_scale (ℓ={next_ell})"
3279            );
3280        }
3281        if next_ell != mj.length_scale {
3282            mj.length_scale = next_ell;
3283            changed = true;
3284        }
3285    }
3286    if measure_jet_penalty_psi_dim(mj) > 0 {
3287        // Multiscale penalty dials, producer order: (α, lnτ). The order `s` is
3288        // not a dial (pinned explicit or absorbed by the per-scale amplitudes).
3289        let next_alpha = psi[cursor];
3290        let next_tau = psi[cursor + 1].exp();
3291        if !(next_alpha.is_finite() && next_tau.is_finite() && next_tau > 0.0) {
3292            crate::bail_invalid_estim!(
3293                "measure-jet ψ write-back produced non-finite dials (alpha={next_alpha}, tau={next_tau})"
3294            );
3295        }
3296        if next_alpha != mj.alpha {
3297            mj.alpha = next_alpha;
3298            changed = true;
3299        }
3300        if next_tau != mj.tau0 {
3301            mj.tau0 = next_tau;
3302            changed = true;
3303        }
3304    }
3305    Ok(changed)
3306}
3307
3308/// Collection-level measure-jet dial write-back (the `apply_tospec` /
3309/// realizer-side entry). Returns whether anything moved.
3310pub fn set_measure_jet_psi_dials(
3311    spec: &mut TermCollectionSpec,
3312    term_idx: usize,
3313    psi: &[f64],
3314) -> Result<bool, EstimationError> {
3315    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3316        crate::bail_invalid_estim!("measure-jet ψ write-back: term index {term_idx} out of range");
3317    };
3318    set_single_term_measure_jet_psi_dials(term, psi)
3319}
3320
3321/// Single-term dial write-back: the shared match+apply core, also used
3322/// directly on the cached per-trial build spec (whose caller has already
3323/// change-checked at the collection level and rebuilds regardless of the
3324/// moved flag).
3325pub fn set_single_term_measure_jet_psi_dials(
3326    term: &mut SmoothTermSpec,
3327    psi: &[f64],
3328) -> Result<bool, EstimationError> {
3329    let SmoothBasisSpec::MeasureJet { spec: mj, .. } = &mut term.basis else {
3330        crate::bail_invalid_estim!("measure-jet ψ write-back targeted a non-measure-jet term");
3331    };
3332    apply_measure_jet_psi(mj, psi)
3333}
3334
3335/// The constant-curvature smooth's spec, when `term_idx` is one. Single
3336/// accessor for every κ-ψ dispatch below, mirroring `measure_jet_term_spec`.
3337pub fn constant_curvature_term_spec(
3338    spec: &TermCollectionSpec,
3339    term_idx: usize,
3340) -> Option<&crate::basis::ConstantCurvatureBasisSpec> {
3341    spec.smooth_terms
3342        .get(term_idx)
3343        .and_then(|term| match &term.basis {
3344            SmoothBasisSpec::ConstantCurvature { spec, .. } => Some(spec),
3345            _ => None,
3346        })
3347}
3348
/// Fraction of the chart-edge curvature used as the hard cap on |κ|.
///
/// The κ-stereographic chart is valid where `1 + κ‖x‖² > 0`. Let `R²` be the
/// largest squared chart radius in the data. For κ ≥ 0 the gauge stays
/// positive; on the negative-κ side it reaches zero for the farthest point at
/// `κ = −1/R²`.
///
/// The optimizer is boxed to `±FRACTION / R²`, a safe fraction of that scale
/// applied symmetrically on both sides. κ = 0 (flat) is thus an interior
/// point of the `S^d ← ℝ^d → H^d` family, which is exactly the reachability
/// the raw-κ (not log-κ) coordinate exists to preserve.
pub const CONSTANT_CURVATURE_KAPPA_CHART_FRACTION: f64 = 0.5_f64;

/// Floor on the data's maximum squared chart radius used to size the κ
/// window. A point cloud concentrated near the origin therefore still yields
/// a finite bracket (at most `±FRACTION / 1e-8`) instead of an unbounded one.
pub const CONSTANT_CURVATURE_MIN_CHART_RADIUS2: f64 = 1.0e-8;
3362
3363/// `(κ_min, κ_max)` outer-optimization window for a constant-curvature term,
3364/// derived from the data's maximum squared chart radius `R²` so the κ-jets
3365/// never leave the κ-stereographic chart. Symmetric about κ = 0:
3366/// `±CONSTANT_CURVATURE_KAPPA_CHART_FRACTION / R²`.
3367pub fn constant_curvature_kappa_bounds(
3368    data: ArrayView2<'_, f64>,
3369    spec: &TermCollectionSpec,
3370    term_idx: usize,
3371) -> (f64, f64) {
3372    let feature_cols = match spec.smooth_terms.get(term_idx).map(|t| &t.basis) {
3373        Some(SmoothBasisSpec::ConstantCurvature { feature_cols, .. }) => feature_cols,
3374        _ => return (-1.0, 1.0),
3375    };
3376    let mut max_r2 = CONSTANT_CURVATURE_MIN_CHART_RADIUS2;
3377    for row in data.outer_iter() {
3378        let mut r2 = 0.0_f64;
3379        for &c in feature_cols.iter() {
3380            if let Some(&v) = row.get(c)
3381                && v.is_finite()
3382            {
3383                r2 += v * v;
3384            }
3385        }
3386        if r2 > max_r2 {
3387            max_r2 = r2;
3388        }
3389    }
3390    let half = CONSTANT_CURVATURE_KAPPA_CHART_FRACTION / max_r2;
3391    (-half, half)
3392}
3393
3394/// Write the optimized κ back into a constant-curvature term spec. Returns
3395/// `true` when κ moved. Centers, ℓ, and the constraint transform `z` are
3396/// κ-FIXED by the basis κ-contract, so only `kappa` changes.
3397pub fn set_constant_curvature_kappa(
3398    spec: &mut TermCollectionSpec,
3399    term_idx: usize,
3400    psi: &[f64],
3401) -> Result<bool, EstimationError> {
3402    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3403        crate::bail_invalid_estim!(
3404            "constant-curvature κ write-back: term index {term_idx} out of range"
3405        );
3406    };
3407    set_single_term_constant_curvature_kappa(term, psi)
3408}
3409
3410/// Single-term κ write-back: the shared validate+apply core, also used directly
3411/// on the cached per-trial build spec in the incremental realizer (whose caller
3412/// has already change-checked at the collection level and rebuilds regardless
3413/// of the moved flag). Mirrors [`set_single_term_measure_jet_psi_dials`].
3414pub fn set_single_term_constant_curvature_kappa(
3415    term: &mut SmoothTermSpec,
3416    psi: &[f64],
3417) -> Result<bool, EstimationError> {
3418    if psi.len() != 1 {
3419        crate::bail_invalid_estim!(
3420            "constant-curvature κ write-back expects exactly one value, got {}",
3421            psi.len()
3422        );
3423    }
3424    let next_kappa = psi[0];
3425    if !next_kappa.is_finite() {
3426        crate::bail_invalid_estim!(
3427            "constant-curvature κ write-back produced a non-finite κ = {next_kappa}"
3428        );
3429    }
3430    let SmoothBasisSpec::ConstantCurvature { spec: cc, .. } = &mut term.basis else {
3431        crate::bail_invalid_estim!(
3432            "constant-curvature κ write-back targeted a non-constant-curvature term"
3433        );
3434    };
3435    if cc.kappa != next_kappa {
3436        cc.kappa = next_kappa;
3437        Ok(true)
3438    } else {
3439        Ok(false)
3440    }
3441}
3442
3443/// Returns `true` when a spatial term has NO outer optimization axes — i.e.
3444/// the user provided an explicit `length_scale` and the term does not enroll
3445/// REML-side per-axis ψ contrasts, so both the scalar κ and any fixed geometry
3446/// anisotropy are anchored.
3447///
3448/// This is the per-term predicate that distinguishes "fixed kernel scale"
3449/// from "optimize the kernel scale" within the family entry points that
3450/// want to honor an explicit user-supplied scale (e.g. Bernoulli
3451/// marginal-slope, where the joint-spatial outer solver otherwise spends
3452/// ~80 iters stalled on the user's chosen ρ at high gradient).
3453pub fn spatial_term_has_locked_kappa(spec: &TermCollectionSpec, term_idx: usize) -> bool {
3454    get_spatial_length_scale(spec, term_idx).is_some()
3455        && !spatial_term_uses_per_axis_psi(spec, term_idx)
3456}
3457
3458pub fn all_spatial_terms_kappa_fixed(spec: &TermCollectionSpec) -> bool {
3459    spec.smooth_terms.iter().enumerate().all(|(idx, _)| {
3460        !spatial_term_supports_hyper_optimization(spec, idx)
3461            || spatial_term_has_locked_kappa(spec, idx)
3462    })
3463}
3464
3465pub fn spatial_identifiability_policy(termspec: &SmoothTermSpec) -> Option<&SpatialIdentifiability> {
3466    match &termspec.basis {
3467        SmoothBasisSpec::ThinPlate { spec, .. } => Some(&spec.identifiability),
3468        SmoothBasisSpec::Duchon { spec, .. } => Some(&spec.identifiability),
3469        _ => None,
3470    }
3471}
3472
/// Standard deviation of the wide, weakly informative, zero-mean `Normal` prior
/// put on a relaxable double-penalty smooth's `DoublePenaltyNullspace`
/// selection coordinate when the fit is well-determined. It also serves as the
/// marker that [`is_nullspace_degeneracy_prior`] recognises.
pub const NULLSPACE_WELLDET_DEGENERACY_RHO_SD: f64 = 15.0;
3477
3478/// True iff `prior` is the well-determined double-penalty null-space
3479/// degeneracy prior placed on a `DoublePenaltyNullspace` selection coordinate.
3480pub fn is_nullspace_degeneracy_prior(prior: &gam_spec::RhoPrior) -> bool {
3481    matches!(
3482        prior,
3483        gam_spec::RhoPrior::Normal { mean, sd }
3484            if *mean == 0.0 && *sd == NULLSPACE_WELLDET_DEGENERACY_RHO_SD
3485    )
3486}
3487
3488/// Per-term data-derived ψ = log κ bounds.
3489///
3490/// Uses the same safe operating range documented in
3491/// [`crate::basis::build_matern_basis`] / [`crate::basis::build_duchon_basis`]:
3492///   κ ∈ [2 / r_max, 1e2 / r_min]
3493/// where (r_min, r_max) are pairwise-distance extrema of the term's resolved
3494/// centers (post-fit) or the standardized feature data columns (pre-fit).
3495/// Lower edge of the data-derived kernel-range window, as a fraction of the
3496/// maximum pairwise distance `r_max`: length scales below `2/r_max` resolve
3497/// structure finer than the closest center pair, so the kernel range floor is
3498/// set at twice the maximum spacing.
3499pub const KERNEL_RANGE_MIN_DIAMETER_FRACTION: f64 = 2.0;
3500
3501/// Upper edge of the data-derived kernel-range window, as a multiple of the
3502/// minimum pairwise distance `r_min`: beyond `100/r_min` the radial columns go
3503/// nearly collinear with the polynomial nullspace, so the kernel range is
3504/// capped here to keep the basis geometry well-conditioned.
3505pub const KERNEL_RANGE_MAX_SPACING_MULTIPLE: f64 = 1e2;
3506
3507
3508/// Returns ψ-space bounds (ψ_lo = ln(κ_lo), ψ_hi = ln(κ_hi)).
3509///
3510/// When geometry is unavailable (e.g., fewer than 2 distinct points), falls
3511/// back to the scalar `options.min_length_scale` / `options.max_length_scale`
3512/// window so the outer optimizer never sees NaN bounds.
3513///
3514/// The returned window is intersected with the options window so user-set
3515/// `min_length_scale` / `max_length_scale` remain hard limits.
3516pub fn spatial_term_psi_bounds(
3517    data: ArrayView2<'_, f64>,
3518    spec: &TermCollectionSpec,
3519    term_idx: usize,
3520    options: &SpatialLengthScaleOptimizationOptions,
3521) -> (f64, f64) {
3522    let fallback = (
3523        -options.max_length_scale.ln(),
3524        -options.min_length_scale.ln(),
3525    );
3526    // Constant-curvature: the ψ coordinate is the raw signed κ, so its window is
3527    // the chart-feasible κ bracket, NOT a log-ℓ window. Mirrors the aniso bounds
3528    // path's `constant_curvature_kappa_bounds` branch so the isotropic
3529    // (non-aniso) seed clamp projects κ into the right interval.
3530    if constant_curvature_term_spec(spec, term_idx).is_some() {
3531        return constant_curvature_kappa_bounds(data, spec, term_idx);
3532    }
3533    let Some(term) = spec.smooth_terms.get(term_idx) else {
3534        return fallback;
3535    };
3536    // Prefer resolved centers (post-fit) since they live in the same standardized
3537    // space the kernel actually sees. Centers are capped at `default_num_centers`
3538    // (<=2000), so exact pairwise bounds are cheap (<4M ops). If centers are
3539    // not yet UserProvided, fall back to the standardized feature data columns
3540    // with the capped-sample path (O(K²·d), K=1024) — the sample is
3541    // conservative for κ bounds (see `pairwise_distance_bounds_sampled`
3542    // docs): it never excludes a feasible κ the exact method would include.
3543    //
3544    // Under anisotropy the kernel metric is y-space (y_a = exp(η_a) x_a),
3545    // so r_min/r_max must be y-space distances. This matters only when the
3546    // spec already carries calibrated η_a at setup time (e.g., warm-start
3547    // or refit paths); for fresh optimization η_a starts at 0 and y = x.
3548    let aniso = get_spatial_aniso_log_scales(spec, term_idx);
3549    let r_bounds = match spatial_term_center_strategy(term) {
3550        Some(CenterStrategy::UserProvided(centers)) if centers.nrows() >= 2 => {
3551            match aniso.as_deref() {
3552                Some(eta) if eta.len() == centers.ncols() => {
3553                    let y = points_in_aniso_y_space(centers.view(), eta);
3554                    pairwise_distance_bounds(y.view())
3555                }
3556                _ => pairwise_distance_bounds(centers.view()),
3557            }
3558        }
3559        _ => standardized_spatial_term_data(data, term)
3560            .ok()
3561            .and_then(|x| match aniso.as_deref() {
3562                Some(eta) if eta.len() == x.ncols() => {
3563                    let y = points_in_aniso_y_space(x.view(), eta);
3564                    pairwise_distance_bounds_sampled(y.view())
3565                }
3566                _ => pairwise_distance_bounds_sampled(x.view()),
3567            }),
3568    };
3569    let Some((r_min, r_max)) = r_bounds else {
3570        return fallback;
3571    };
3572    // Length scales substantially larger than the data diameter make radial
3573    // TPS/Matern columns nearly collinear with their polynomial nullspace.
3574    // The nullspace already carries constant/linear low-frequency structure,
3575    // so cap the kernel range at the diameter scale instead of letting the
3576    // optimizer enter a numerically degenerate basis geometry.
3577    let psi_lo_data = (KERNEL_RANGE_MIN_DIAMETER_FRACTION / r_max).ln();
3578    let psi_hi_data = (KERNEL_RANGE_MAX_SPACING_MULTIPLE / r_min).ln();
3579    // #1074: the Matérn-specific length-scale ceiling that used to live here was
3580    // deleted. It was masking, not fixing, the real defect: a hard upper bound on
3581    // the kernel range that pinned the κ-optimizer short rather than letting the
3582    // optimizer find the REML optimum. Matérn now shares the same generic geometry
3583    // window as Duchon / TPS (`KERNEL_RANGE_MIN_DIAMETER_FRACTION / r_max` floor,
3584    // `KERNEL_RANGE_MAX_SPACING_MULTIPLE / r_min` ceiling); the #1357 fully-flat
3585    // collapse corner is guarded by the EDF-collapse guard in
3586    // `spatial_optimization.rs`, which acts on the realized fit, not on a clamp.
3587    // Intersect with the options window so min/max_length_scale remain hard caps.
3588    let psi_lo = psi_lo_data.max(fallback.0);
3589    let psi_hi = psi_hi_data.min(fallback.1);
3590    if psi_lo >= psi_hi {
3591        // Degenerate intersection — fall back to the options window to keep the
3592        // outer optimizer from collapsing to a point.
3593        return fallback;
3594    }
3595    (psi_lo, psi_hi)
3596}
3597
3598/// Data-derived ψ seed for a spatial term when the user has not set an
3599/// explicit length_scale on its basis spec. Uses the geometric mean of the
3600/// data-informed kappa range (i.e., the midpoint of the ψ window).
3601pub fn spatial_term_psi_seed(
3602    data: ArrayView2<'_, f64>,
3603    spec: &TermCollectionSpec,
3604    term_idx: usize,
3605    options: &SpatialLengthScaleOptimizationOptions,
3606) -> Option<f64> {
3607    if get_spatial_length_scale(spec, term_idx).is_some() {
3608        return None; // user/spec-provided length_scale wins
3609    }
3610    let (psi_lo, psi_hi) = spatial_term_psi_bounds(data, spec, term_idx, options);
3611    Some(0.5 * (psi_lo + psi_hi))
3612}
3613
/// Split ψ = ln κ coordinates into a global length scale and per-axis
/// contrasts.
///
/// - No coordinates: treated as ψ = 0, giving `(Some(1.0), None)`.
/// - One coordinate ψ: gives `(Some(exp(-ψ)), None)`.
/// - Several coordinates with mean ψ̄: gives `(Some(exp(-ψ̄)), Some(ψ_a - ψ̄))`.
///   The contrasts sum to (numerically) zero.
pub fn spatial_term_psi_to_length_scale_and_aniso(psi: &[f64]) -> (Option<f64>, Option<Vec<f64>>) {
    match psi {
        [] => (Some(1.0), None),
        [only] => (Some((-*only).exp()), None),
        _ => {
            let mean = psi.iter().sum::<f64>() / psi.len() as f64;
            let contrasts: Vec<f64> = psi.iter().map(|&value| value - mean).collect();
            (Some((-mean).exp()), Some(contrasts))
        }
    }
}
3625
3626/// Get the `aniso_log_scales` from a spatial term, if present.
3627pub fn get_spatial_aniso_log_scales(
3628    spec: &TermCollectionSpec,
3629    term_idx: usize,
3630) -> Option<Vec<f64>> {
3631    spec.smooth_terms
3632        .get(term_idx)
3633        .and_then(|term| match &term.basis {
3634            SmoothBasisSpec::Matern { spec, .. } => spec.aniso_log_scales.clone(),
3635            SmoothBasisSpec::Duchon { spec, .. } => spec.aniso_log_scales.clone(),
3636            _ => None,
3637        })
3638}
3639
3640/// Per-axis response-structure score for anisotropy seeding.
3641///
3642/// For each spatial axis `a`, sort the response `y` by the axis coordinate
3643/// `x_a` and measure the total squared successive variation of the sorted
3644/// response, `tv_a = Σ_i (y_{σ(i+1)} − y_{σ(i)})²` where `σ` orders rows by
3645/// `x_a`. An axis that carries real (possibly nonlinear) signal makes `y` vary
3646/// SMOOTHLY when the rows are walked in that axis's order, so `tv_a` is SMALL;
3647/// a pure-nuisance axis leaves `y` looking unordered, so `tv_a` is LARGE.
3648///
3649/// This deliberately does NOT use a linear correlation `corr(x_a, y)`: for an
3650/// odd, symmetric signal such as `sin(2·x1)` over a symmetric domain the linear
3651/// correlation is ~0 on the *signal* axis, which would misdirect the seed. The
3652/// total-variation-of-sorted-response score captures nonlinear association.
3653///
3654/// Returns `score_a = −½·ln(tv_a + ε)` (larger ⇒ more signal on axis `a`),
3655/// centered to sum to zero, or `None` when the data is degenerate (too few
3656/// rows, non-finite, or all axes equally (un)structured). The caller adds a
3657/// BOUNDED multiple of this to the geometry seed — it is a conservative nudge,
3658/// never a hard override.
3659pub fn response_aware_axis_contrasts(
3660    x: ndarray::ArrayView2<'_, f64>,
3661    y: ndarray::ArrayView1<'_, f64>,
3662) -> Option<Vec<f64>> {
3663    let n = x.nrows();
3664    let d = x.ncols();
3665    if d <= 1 || n < 4 || y.len() != n {
3666        return None;
3667    }
3668    if x.iter().any(|v| !v.is_finite()) || y.iter().any(|v| !v.is_finite()) {
3669        return None;
3670    }
3671    let mut scores = Vec::with_capacity(d);
3672    for a in 0..d {
3673        let mut order: Vec<usize> = (0..n).collect();
3674        let col = x.column(a);
3675        order.sort_by(|&i, &j| {
3676            col[i]
3677                .partial_cmp(&col[j])
3678                .unwrap_or(std::cmp::Ordering::Equal)
3679        });
3680        let mut tv = 0.0_f64;
3681        for w in order.windows(2) {
3682            let diff = y[w[1]] - y[w[0]];
3683            tv += diff * diff;
3684        }
3685        // ε guards against ln(0) on a perfectly flat / constant response.
3686        scores.push(-0.5 * (tv + 1e-12).ln());
3687    }
3688    if scores.iter().any(|v| !v.is_finite()) {
3689        return None;
3690    }
3691    let mean = scores.iter().sum::<f64>() / d as f64;
3692    let centered: Vec<f64> = scores.iter().map(|&s| s - mean).collect();
3693    // If every axis is equally structured the centered scores are ~0 and the
3694    // nudge is a no-op — return None so the geometry seed is used unchanged.
3695    if centered.iter().all(|&v| v.abs() < 1e-9) {
3696        return None;
3697    }
3698    Some(centered)
3699}
3700
3701/// Conservative, response-aware anisotropy seed nudge applied before the κ outer
3702/// loop. For each anisotropic spatial term it adds a BOUNDED multiple of the
3703/// per-axis response-structure contrast (`response_aware_axis_contrasts`) on top
3704/// of the existing geometry seed, so the optimizer starts in the correct basin
3705/// instead of at a response-blind near-symmetric point (the #1376 under-recovery
3706/// where a signal axis and a nuisance axis with equal coordinate spread seed to
3707/// ~[0,0]). The nudge is clamped to keep this a perturbation, never a hard
3708/// override, so shared aniso Matérn/Duchon fits cannot be destabilized by it.
3709pub fn apply_response_aware_anisotropy_seed(
3710    data: ArrayView2<'_, f64>,
3711    y: ndarray::ArrayView1<'_, f64>,
3712    spec: &mut TermCollectionSpec,
3713    spatial_terms: &[usize],
3714) {
3715    // Bound on the per-axis contrast nudge (in η units). One LN_2 ≈ 0.69 halves
3716    // the effective per-axis length scale; capping at LN_2 keeps the seed within
3717    // one optimizer log-step of the geometry seed while still breaking the
3718    // symmetric-seed trap.
3719    const MAX_NUDGE: f64 = std::f64::consts::LN_2;
3720    for &term_idx in spatial_terms {
3721        let Some(current_eta) = get_spatial_aniso_log_scales(spec, term_idx) else {
3722            continue;
3723        };
3724        let d = current_eta.len();
3725        if d <= 1 {
3726            continue;
3727        }
3728        let Some(term) = spec.smooth_terms.get(term_idx) else {
3729            continue;
3730        };
3731        let feature_cols = term.basis.structural_feature_cols();
3732        if feature_cols.len() != d {
3733            continue;
3734        }
3735        let Ok(x) = select_columns(data, &feature_cols) else {
3736            continue;
3737        };
3738        let Some(contrast) = response_aware_axis_contrasts(x.view(), y) else {
3739            continue;
3740        };
3741        let nudged: Vec<f64> = current_eta
3742            .iter()
3743            .zip(contrast.iter())
3744            .map(|(&eta_a, &c_a)| eta_a + c_a.clamp(-MAX_NUDGE, MAX_NUDGE))
3745            .collect();
3746        // `set_spatial_aniso_log_scales` re-centers to Σ η = 0. A term that does
3747        // not support aniso scales is silently skipped (the seed is optional).
3748        if let Err(err) = set_spatial_aniso_log_scales(spec, term_idx, nudged) {
3749            log::debug!(
3750                "[spatial-kappa] response-aware anisotropy seed skipped for term {term_idx}: {err}"
3751            );
3752        }
3753    }
3754}
3755
3756/// Get the number of feature columns (spatial dimensionality) for a spatial term.
3757pub fn get_spatial_feature_dim(spec: &TermCollectionSpec, term_idx: usize) -> Option<usize> {
3758    spec.smooth_terms
3759        .get(term_idx)
3760        .and_then(|term| match &term.basis {
3761            SmoothBasisSpec::ThinPlate { feature_cols, .. } => Some(feature_cols.len()),
3762            SmoothBasisSpec::Matern { feature_cols, .. } => Some(feature_cols.len()),
3763            SmoothBasisSpec::Duchon { feature_cols, .. } => Some(feature_cols.len()),
3764            _ => None,
3765        })
3766}
3767
3768/// Log the learned per-axis spatial anisotropy for all spatial terms that
3769/// have `aniso_log_scales` set after optimization.
3770///
3771/// For scalar-scale families this reports eta, effective per-axis length
3772/// scales, and per-axis kappa values. For pure Duchon it reports the centered
3773/// eta contrasts only.
3774pub fn log_spatial_aniso_scales(spec: &TermCollectionSpec) {
3775    for (term_idx, term) in spec.smooth_terms.iter().enumerate() {
3776        let (aniso, length_scale) = match &term.basis {
3777            SmoothBasisSpec::Matern { spec, .. } => {
3778                (spec.aniso_log_scales.as_ref(), Some(spec.length_scale))
3779            }
3780            SmoothBasisSpec::Duchon { spec, .. } => {
3781                (spec.aniso_log_scales.as_ref(), spec.length_scale)
3782            }
3783            _ => (None, None),
3784        };
3785        let Some(eta) = aniso else { continue };
3786        if eta.is_empty() {
3787            continue;
3788        }
3789        let mut lines = match length_scale {
3790            Some(ls) => format!(
3791                "[spatial-kappa] term {} (\"{}\"): anisotropic length scales optimized (global length_scale={:.4})",
3792                term_idx, term.name, ls
3793            ),
3794            None => format!(
3795                "[spatial-kappa] term {} (\"{}\"): pure Duchon shape anisotropy optimized",
3796                term_idx, term.name
3797            ),
3798        };
3799        for (a, &eta_a) in eta.iter().enumerate() {
3800            if let Some(ls) = length_scale {
3801                let length_a = ls * (-eta_a).exp();
3802                let kappa_a = (1.0 / ls) * eta_a.exp();
3803                lines.push_str(&format!(
3804                    "\n  axis {}: eta={:+.4}, length={:.4}, kappa={:.4}",
3805                    a, eta_a, length_a, kappa_a
3806                ));
3807            } else {
3808                lines.push_str(&format!("\n  axis {}: eta={:+.4}", a, eta_a));
3809            }
3810        }
3811        log::info!("{}", lines);
3812    }
3813}
3814
3815/// Set `aniso_log_scales` on a spatial term's basis spec.
3816pub fn set_spatial_aniso_log_scales(
3817    spec: &mut TermCollectionSpec,
3818    term_idx: usize,
3819    eta: Vec<f64>,
3820) -> Result<(), EstimationError> {
3821    let eta = center_aniso_log_scales(&eta);
3822    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3823        crate::bail_invalid_estim!("spatial aniso_log_scales term index {term_idx} out of range");
3824    };
3825    match &mut term.basis {
3826        SmoothBasisSpec::Matern { spec, .. } => {
3827            spec.aniso_log_scales = Some(eta);
3828            Ok(())
3829        }
3830        SmoothBasisSpec::Duchon { spec, .. } => {
3831            spec.aniso_log_scales = Some(eta);
3832            Ok(())
3833        }
3834        _ => Err(EstimationError::InvalidInput(format!(
3835            "term '{}' does not support aniso_log_scales",
3836            term.name
3837        ))),
3838    }
3839}
3840
3841/// Sync knot-cloud-derived anisotropy contrasts from basis metadata back into
3842/// the mutable spec so the optimizer starts from the correct eta values.
3843///
3844/// Call this after building the smooth design but before initializing the
3845/// optimizer's psi coordinates. For each spatial term whose metadata contains
3846/// computed `aniso_log_scales`, this writes them into the spec.
3847pub fn sync_aniso_contrasts_from_metadata(
3848    spec: &mut TermCollectionSpec,
3849    design: &SmoothDesign,
3850) {
3851    for (term_idx, term) in design.terms.iter().enumerate() {
3852        let meta_aniso = match &term.metadata {
3853            BasisMetadata::Matern {
3854                aniso_log_scales, ..
3855            } => aniso_log_scales.clone(),
3856            BasisMetadata::Duchon {
3857                aniso_log_scales, ..
3858            } => aniso_log_scales.clone(),
3859            _ => None,
3860        };
3861        if let Some(eta) = meta_aniso
3862            && eta.len() > 1
3863        {
3864            set_spatial_aniso_log_scales(spec, term_idx, eta).ok();
3865        }
3866    }
3867}
3868
/// Options for the outer-loop search over spatial κ (= 1 / length_scale) and,
/// for terms that enroll them, per-axis anisotropy contrasts.
///
/// `Default` supplies the standard settings. Options built from external input
/// (CLI, config, Python API) should be checked with
/// [`SpatialLengthScaleOptimizationOptions::validate`] before they reach the
/// fitter.
#[derive(Debug, Clone)]
pub struct SpatialLengthScaleOptimizationOptions {
    /// Enable outer-loop optimization over spatial κ (= 1 / length_scale)
    /// for supported radial-kernel smooths.
    /// This applies to ThinPlate, Matérn, and Duchon terms.
    pub enabled: bool,
    /// Maximum number of outer iterations in the exact joint [rho, psi] solve.
    pub max_outer_iter: usize,
    /// Relative improvement threshold for terminating the outer solve.
    /// `validate` requires it to be finite and > 0.
    pub rel_tol: f64,
    /// Initial log(length_scale) perturbation used for seed construction.
    /// `validate` requires it to be finite and > 0.
    pub log_step: f64,
    /// Minimum allowed length_scale during κ search.
    /// Must be finite, > 0, and strictly below `max_length_scale`. Its
    /// negative log, `-ln(min_length_scale)`, is the upper ψ = ln κ limit of
    /// the fallback ψ window.
    pub min_length_scale: f64,
    /// Maximum allowed length_scale during κ search.
    /// Must be finite and > 0. Its negative log, `-ln(max_length_scale)`, is
    /// the lower ψ = ln κ limit of the fallback ψ window.
    pub max_length_scale: f64,
    /// Automatic geometry-initializer threshold for large-scale spatial fits.
    ///
    /// When n exceeds twice this value, the fitter uses a spatially stratified
    /// subsample only to seed κ/anisotropy geometry: centers are resolved,
    /// axis contrasts are initialized from center/data spread, and one or two
    /// cheap ψ reseeding updates are applied. It never runs PIRLS, REML, ARC,
    /// BFGS, or any recursive optimizer on the pilot.
    ///
    /// The final coefficients, smoothing parameters, and spatial geometry are
    /// always optimized on the full dataset.
    ///
    /// Set to 0 to skip the pilot geometry initializer.
    pub pilot_subsample_threshold: usize,
    /// Optional wall-clock budget (seconds) for the whole outer smoothing search
    /// (gam#979). When a family arms the global deadline from this, an outer
    /// search that cannot certify convergence (survival marginal-slope's
    /// monotonicity-pinned constrained joint-Newton) returns its best-so-far
    /// iterate (or a catchable error) within the budget instead of hanging.
    /// `None` keeps the legacy unbounded behavior; the survival marginal-slope
    /// path applies a generous default when this is `None`.
    pub outer_wall_clock_budget_secs: Option<f64>,
}
3907
3908impl Default for SpatialLengthScaleOptimizationOptions {
3909    fn default() -> Self {
3910        Self {
3911            enabled: true,
3912            max_outer_iter: 80,
3913            rel_tol: 1e-4,
3914            log_step: std::f64::consts::LN_2,
3915            min_length_scale: 1e-3,
3916            max_length_scale: 1e3,
3917            pilot_subsample_threshold: 10_000,
3918            outer_wall_clock_budget_secs: None,
3919        }
3920    }
3921}
3922
3923impl SpatialLengthScaleOptimizationOptions {
3924    /// Validate the struct's invariants. Callers that construct these options
3925    /// from external input (CLI, config, Python API) should call this before
3926    /// passing the options into the fitter. Returns `Err` with a descriptive
3927    /// message when an invariant is violated; the fitter then panics or
3928    /// returns `EstimationError` at its own boundary.
3929    ///
3930    /// Invariants:
3931    ///   * `min_length_scale > 0`, finite
3932    ///   * `max_length_scale > 0`, finite
3933    ///   * `min_length_scale < max_length_scale`
3934    ///   * `rel_tol > 0`, finite
3935    ///   * `log_step > 0`, finite
3936    ///
3937    /// These invariants are what the downstream κ-bound and ψ-window code
3938    /// assumes (`-log(max_ls)` must be finite, `(min,max)` must not be
3939    /// inverted, etc.). Without validation, invalid options produce silent
3940    /// NaN-propagation inside the outer optimizer.
3941    pub fn validate(&self) -> Result<(), String> {
3942        if !self.min_length_scale.is_finite() || self.min_length_scale <= 0.0 {
3943            return Err(SmoothError::invalid_config(format!(
3944                "SpatialLengthScaleOptimizationOptions::min_length_scale must be > 0 and finite, got {}",
3945                self.min_length_scale
3946            ))
3947            .into());
3948        }
3949        if !self.max_length_scale.is_finite() || self.max_length_scale <= 0.0 {
3950            return Err(SmoothError::invalid_config(format!(
3951                "SpatialLengthScaleOptimizationOptions::max_length_scale must be > 0 and finite, got {}",
3952                self.max_length_scale
3953            ))
3954            .into());
3955        }
3956        if self.min_length_scale >= self.max_length_scale {
3957            return Err(SmoothError::invalid_config(format!(
3958                "SpatialLengthScaleOptimizationOptions requires min_length_scale < max_length_scale, got min={} max={}",
3959                self.min_length_scale, self.max_length_scale
3960            ))
3961            .into());
3962        }
3963        if !self.rel_tol.is_finite() || self.rel_tol <= 0.0 {
3964            return Err(SmoothError::invalid_config(format!(
3965                "SpatialLengthScaleOptimizationOptions::rel_tol must be > 0 and finite, got {}",
3966                self.rel_tol
3967            ))
3968            .into());
3969        }
3970        if !self.log_step.is_finite() || self.log_step <= 0.0 {
3971            return Err(SmoothError::invalid_config(format!(
3972                "SpatialLengthScaleOptimizationOptions::log_step must be > 0 and finite, got {}",
3973                self.log_step
3974            ))
3975            .into());
3976        }
3977        Ok(())
3978    }
3979}
3980
/// Random-effect term stored as one optional group label per observation (see
/// `group_ids`) instead of an explicit indicator matrix.
#[derive(Debug, Clone)]
pub struct RandomEffectBlock {
    /// Term name. NOTE(review): presumably the user-facing label of the
    /// random-effect term; confirm against the constructor.
    pub name: String,
    /// O(n) group-label vector: group_ids[i] = column index in [0, num_groups).
    /// `None` if the observation's level is not in the kept set.
    pub group_ids: Vec<Option<usize>>,
    /// Number of indicator columns. It is the exclusive upper bound on every
    /// `Some` entry of `group_ids`.
    pub num_groups: usize,
    /// The kept set of levels. NOTE(review): presumably `kept_levels[j]` is
    /// the level encoded by column `j`; verify against the constructor.
    pub kept_levels: Vec<u64>,
}
3990
/// Magnitude threshold for structural zeros in sparse design assembly.
///
/// When counting or emitting entries, only values with
/// `|v| > BLOCK_SPARSE_ZERO_EPS` are kept.
pub const BLOCK_SPARSE_ZERO_EPS: f64 = 1.0e-12;

/// Largest fill fraction, `nnz / (nrows · ncols)`, at which a collection of
/// blocks without intrinsic sparse structure is still assembled as a sparse
/// design. Collections containing sparse or random-effect blocks bypass this
/// cap.
pub const BLOCK_SPARSE_MAX_DENSITY: f64 = 0.2;
3994
3995pub fn blocks_have_intrinsic_sparse_structure(blocks: &[DesignBlock]) -> bool {
3996    blocks
3997        .iter()
3998        .any(|block| matches!(block, DesignBlock::Sparse(_) | DesignBlock::RandomEffect(_)))
3999}
4000
4001pub fn sparse_compatible_block_nnz(block: &DesignBlock) -> Option<usize> {
4002    match block {
4003        DesignBlock::Intercept(n) => Some(*n),
4004        DesignBlock::RandomEffect(op) => {
4005            Some(op.group_ids.iter().filter(|gid| gid.is_some()).count())
4006        }
4007        DesignBlock::Sparse(sparse) => Some(sparse.val().len()),
4008        DesignBlock::Dense(dense) => dense.as_dense_ref().map(|matrix| {
4009            matrix
4010                .iter()
4011                .filter(|&&value| value.abs() > BLOCK_SPARSE_ZERO_EPS)
4012                .count()
4013        }),
4014    }
4015}
4016
4017pub fn try_build_sparse_design_from_blocks(
4018    blocks: &[DesignBlock],
4019) -> Result<Option<DesignMatrix>, BasisError> {
4020    if blocks.is_empty() {
4021        return Ok(None);
4022    }
4023    let nrows = blocks[0].nrows();
4024    let ncols: usize = blocks.iter().map(DesignBlock::ncols).sum();
4025    if nrows == 0 || ncols == 0 || ncols <= 32 {
4026        return Ok(None);
4027    }
4028
4029    let preserve_sparse_storage = blocks_have_intrinsic_sparse_structure(blocks);
4030    let sparse_nnz_limit = if preserve_sparse_storage {
4031        usize::MAX
4032    } else {
4033        let total_cells = nrows.saturating_mul(ncols);
4034        ((total_cells as f64) * BLOCK_SPARSE_MAX_DENSITY).floor() as usize
4035    };
4036    let mut nnz = 0usize;
4037    for block in blocks {
4038        let block_nnz = if let Some(block_nnz) = sparse_compatible_block_nnz(block) {
4039            block_nnz
4040        } else {
4041            return Ok(None);
4042        };
4043        nnz = nnz.saturating_add(block_nnz);
4044        if nnz > sparse_nnz_limit {
4045            return Ok(None);
4046        }
4047    }
4048
4049    let mut triplets = Vec::<Triplet<usize, usize, f64>>::with_capacity(nnz);
4050    let mut col_offset = 0usize;
4051    for block in blocks {
4052        match block {
4053            DesignBlock::Intercept(n) => {
4054                for row in 0..*n {
4055                    triplets.push(Triplet::new(row, col_offset, 1.0));
4056                }
4057            }
4058            DesignBlock::RandomEffect(op) => {
4059                for (row, group_id) in op.group_ids.iter().enumerate() {
4060                    if let Some(group) = group_id {
4061                        triplets.push(Triplet::new(row, col_offset + group, 1.0));
4062                    }
4063                }
4064            }
4065            DesignBlock::Sparse(sparse) => {
4066                let (symbolic, values) = sparse.parts();
4067                let col_ptr = symbolic.col_ptr();
4068                let row_idx = symbolic.row_idx();
4069                for col in 0..sparse.ncols() {
4070                    for idx in col_ptr[col]..col_ptr[col + 1] {
4071                        let value = values[idx];
4072                        if value.abs() > BLOCK_SPARSE_ZERO_EPS {
4073                            triplets.push(Triplet::new(row_idx[idx], col_offset + col, value));
4074                        }
4075                    }
4076                }
4077            }
4078            DesignBlock::Dense(dense) => {
4079                let matrix = dense.as_dense_ref().ok_or_else(|| {
4080                    BasisError::InvalidInput(
4081                        "sparse-compatible block assembly requires materialized dense blocks"
4082                            .to_string(),
4083                    )
4084                })?;
4085                for row in 0..matrix.nrows() {
4086                    for col in 0..matrix.ncols() {
4087                        let value = matrix[[row, col]];
4088                        if value.abs() > BLOCK_SPARSE_ZERO_EPS {
4089                            triplets.push(Triplet::new(row, col_offset + col, value));
4090                        }
4091                    }
4092                }
4093            }
4094        }
4095        col_offset += block.ncols();
4096    }
4097
4098    let sparse = SparseColMat::try_new_from_triplets(nrows, ncols, &triplets).map_err(|_| {
4099        BasisError::SparseCreation("failed to assemble sparse term-collection design".to_string())
4100    })?;
4101    Ok(Some(DesignMatrix::Sparse(
4102        gam_linalg::matrix::SparseDesignMatrix::new(sparse),
4103    )))
4104}
4105
4106pub fn assemble_term_collection_design_matrix(
4107    blocks: Vec<DesignBlock>,
4108) -> Result<DesignMatrix, BasisError> {
4109    if let Some(sparse) = try_build_sparse_design_from_blocks(&blocks)? {
4110        return Ok(sparse);
4111    }
4112    let block_op = BlockDesignOperator::new(blocks).map_err(|e| {
4113        BasisError::InvalidInput(format!("failed to build block design operator: {e}"))
4114    })?;
4115    Ok(DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
4116        Arc::new(block_op),
4117    )))
4118}
4119
4120pub fn select_columns(data: ArrayView2<'_, f64>, cols: &[usize]) -> Result<Array2<f64>, BasisError> {
4121    let n = data.nrows();
4122    let p = data.ncols();
4123    for &c in cols {
4124        if c >= p {
4125            crate::bail_dim_basis!("feature column {c} is out of bounds for data with {p} columns");
4126        }
4127    }
4128    let mut out = Array2::<f64>::zeros((n, cols.len()));
4129    for (j, &c) in cols.iter().enumerate() {
4130        out.column_mut(j).assign(&data.column(c));
4131    }
4132    Ok(out)
4133}
4134
/// Short label for a non-finite `f64`, used in validation error messages.
///
/// Returns `"NaN"` for any NaN. Every other value is classified only by its
/// sign bit: `"+Inf"` when the sign is positive and `"-Inf"` when it is
/// negative. The function does not itself check that the value is infinite.
pub fn nonfinite_value_label(value: f64) -> &'static str {
    match (value.is_nan(), value.is_sign_negative()) {
        (true, _) => "NaN",
        (false, false) => "+Inf",
        (false, true) => "-Inf",
    }
}
4144
4145pub fn validate_term_feature_column_finite(
4146    data: ArrayView2<'_, f64>,
4147    term_kind: &str,
4148    term_name: &str,
4149    feature_col: usize,
4150) -> Result<(), BasisError> {
4151    let p = data.ncols();
4152    if feature_col >= p {
4153        crate::bail_dim_basis!(
4154            "{term_kind} term '{term_name}' feature column {feature_col} out of bounds for {p} columns"
4155        );
4156    }
4157    for (row, &value) in data.column(feature_col).iter().enumerate() {
4158        if !value.is_finite() {
4159            crate::bail_invalid_basis!(
4160                "{term_kind} term '{term_name}' feature column {feature_col} row {row} contains non-finite value {}",
4161                nonfinite_value_label(value)
4162            );
4163        }
4164    }
4165    Ok(())
4166}
4167
4168pub fn validate_smooth_terms_finite_inputs(
4169    data: ArrayView2<'_, f64>,
4170    terms: &[SmoothTermSpec],
4171) -> Result<(), BasisError> {
4172    for term in terms {
4173        for feature_col in smooth_term_feature_cols(term) {
4174            validate_term_feature_column_finite(data, "smooth", &term.name, feature_col)?;
4175        }
4176    }
4177    Ok(())
4178}
4179
4180pub fn validate_term_collection_finite_inputs(
4181    data: ArrayView2<'_, f64>,
4182    spec: &TermCollectionSpec,
4183) -> Result<(), BasisError> {
4184    for term in &spec.linear_terms {
4185        validate_term_feature_column_finite(data, "linear", &term.name, term.feature_col)?;
4186    }
4187    for term in &spec.random_effect_terms {
4188        validate_term_feature_column_finite(data, "random-effect", &term.name, term.feature_col)?;
4189    }
4190    validate_smooth_terms_finite_inputs(data, &spec.smooth_terms)
4191}
4192
/// Grouping key for joint spatial-center planning.
///
/// Spatial smooths whose keys compare equal are candidates for sharing one
/// center set. The derived `Ord` compares fields in declaration order, and
/// that order fixes the iteration order of the `BTreeMap` of groups used by
/// the planner.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct JointSpatialCenterGroupKey {
    /// Data columns the spatial smooth reads, in term order.
    feature_cols: Vec<usize>,
    /// Kind of center-selection strategy, as reported by `center_strategy_kind`.
    strategy_kind: CenterStrategyKind,
    /// Strategy-specific integer: the k-means `max_iter` or the uniform-grid
    /// `points_per_dim` (looking through one `Auto` wrapper), otherwise 0.
    strategy_aux: usize,
    /// Center count requested by the term's strategy.
    requested_num_centers: usize,
    /// IEEE-754 bit patterns of the frozen input scales, if any. Stored as
    /// bits because `f64` is not `Eq`/`Ord`.
    input_scale_bits: Option<Vec<u64>>,
}
4201
4202pub fn spatial_term_min_center_count(term: &SmoothTermSpec) -> usize {
4203    match &term.basis {
4204        SmoothBasisSpec::ThinPlate { feature_cols, .. } => feature_cols.len() + 1,
4205        SmoothBasisSpec::Duchon {
4206            feature_cols, spec, ..
4207        } => match spec.nullspace_order {
4208            crate::basis::DuchonNullspaceOrder::Zero => 1,
4209            crate::basis::DuchonNullspaceOrder::Linear => feature_cols.len() + 1,
4210            crate::basis::DuchonNullspaceOrder::Degree(degree) => {
4211                crate::basis::duchon_nullspace_dimension(feature_cols.len(), degree)
4212            }
4213        },
4214        SmoothBasisSpec::Matern { .. } => 1,
4215        _ => 1,
4216    }
4217}
4218
4219pub fn spatial_term_group_key(term: &SmoothTermSpec) -> Option<JointSpatialCenterGroupKey> {
4220    let (feature_cols, strategy, input_scales) = match &term.basis {
4221        SmoothBasisSpec::ThinPlate {
4222            feature_cols,
4223            spec,
4224            input_scales,
4225        } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4226        SmoothBasisSpec::Matern {
4227            feature_cols,
4228            spec,
4229            input_scales,
4230        } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4231        SmoothBasisSpec::Duchon {
4232            feature_cols,
4233            spec,
4234            input_scales,
4235        } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4236        _ => return None,
4237    };
4238    let strategy_kind = center_strategy_kind(strategy);
4239    let strategy_aux = match strategy {
4240        CenterStrategy::Auto(inner) => match inner.as_ref() {
4241            CenterStrategy::KMeans { max_iter, .. } => *max_iter,
4242            CenterStrategy::UniformGrid { points_per_dim } => *points_per_dim,
4243            _ => 0,
4244        },
4245        CenterStrategy::KMeans { max_iter, .. } => *max_iter,
4246        CenterStrategy::UniformGrid { points_per_dim } => *points_per_dim,
4247        _ => 0,
4248    };
4249    Some(JointSpatialCenterGroupKey {
4250        feature_cols: feature_cols.clone(),
4251        strategy_kind,
4252        strategy_aux,
4253        requested_num_centers: center_strategy_num_centers(strategy)?,
4254        input_scale_bits: input_scales
4255            .map(|values| values.iter().map(|value| value.to_bits()).collect()),
4256    })
4257}
4258
4259pub fn spatial_term_center_strategy(term: &SmoothTermSpec) -> Option<&CenterStrategy> {
4260    match &term.basis {
4261        SmoothBasisSpec::ThinPlate { spec, .. } => Some(&spec.center_strategy),
4262        SmoothBasisSpec::Matern { spec, .. } => Some(&spec.center_strategy),
4263        SmoothBasisSpec::Duchon { spec, .. } => Some(&spec.center_strategy),
4264        _ => None,
4265    }
4266}
4267
4268pub fn set_spatial_term_centers(
4269    term: &mut SmoothTermSpec,
4270    centers: Array2<f64>,
4271) -> Result<(), BasisError> {
4272    match &mut term.basis {
4273        SmoothBasisSpec::ThinPlate { spec, .. } => {
4274            spec.center_strategy = CenterStrategy::UserProvided(centers);
4275            Ok(())
4276        }
4277        SmoothBasisSpec::Matern { spec, .. } => {
4278            spec.center_strategy = CenterStrategy::UserProvided(centers);
4279            Ok(())
4280        }
4281        SmoothBasisSpec::Duchon { spec, .. } => {
4282            spec.center_strategy = CenterStrategy::UserProvided(centers);
4283            Ok(())
4284        }
4285        _ => Err(BasisError::InvalidInput(format!(
4286            "term '{}' does not support spatial center planning",
4287            term.name
4288        ))),
4289    }
4290}
4291
4292pub fn standardized_spatial_term_data(
4293    data: ArrayView2<'_, f64>,
4294    term: &SmoothTermSpec,
4295) -> Result<Array2<f64>, BasisError> {
4296    let (feature_cols, input_scales) = match &term.basis {
4297        SmoothBasisSpec::ThinPlate {
4298            feature_cols,
4299            input_scales,
4300            ..
4301        }
4302        | SmoothBasisSpec::Matern {
4303            feature_cols,
4304            input_scales,
4305            ..
4306        }
4307        | SmoothBasisSpec::Duchon {
4308            feature_cols,
4309            input_scales,
4310            ..
4311        } => (feature_cols, input_scales.as_ref()),
4312        _ => {
4313            crate::bail_invalid_basis!("term '{}' is not a spatial smooth", term.name);
4314        }
4315    };
4316    let mut x = select_columns(data, feature_cols)?;
4317    if let Some(scales) = input_scales {
4318        apply_input_standardization(&mut x, scales);
4319    } else if let Some(scales) = compute_spatial_input_scales(x.view()) {
4320        apply_input_standardization(&mut x, &scales);
4321    }
4322    Ok(x)
4323}
4324
4325pub fn plan_joint_spatial_centers_for_term_blocks(
4326    data: ArrayView2<'_, f64>,
4327    term_blocks: &[Vec<SmoothTermSpec>],
4328) -> Result<Vec<Vec<SmoothTermSpec>>, BasisError> {
4329    let mut planned_blocks = term_blocks.to_vec();
4330    let n = data.nrows();
4331    let mut groups: BTreeMap<JointSpatialCenterGroupKey, Vec<(usize, usize)>> = BTreeMap::new();
4332
4333    for (block_idx, terms) in planned_blocks.iter().enumerate() {
4334        for (term_idx, term) in terms.iter().enumerate() {
4335            let Some(strategy) = spatial_term_center_strategy(term) else {
4336                continue;
4337            };
4338            if !center_strategy_is_auto(strategy) {
4339                continue;
4340            }
4341            let Some(group_key) = spatial_term_group_key(term) else {
4342                continue;
4343            };
4344            if !matches!(
4345                group_key.strategy_kind,
4346                CenterStrategyKind::EqualMass
4347                    | CenterStrategyKind::EqualMassCovarRepresentative
4348                    | CenterStrategyKind::FarthestPoint
4349                    | CenterStrategyKind::KMeans
4350            ) {
4351                continue;
4352            }
4353            if center_strategy_num_centers(strategy).is_none() {
4354                continue;
4355            }
4356            groups
4357                .entry(group_key)
4358                .or_default()
4359                .push((block_idx, term_idx));
4360        }
4361    }
4362
4363    for (group_key, members) in groups {
4364        if members.len() < 2 {
4365            continue;
4366        }
4367        let min_required = members
4368            .iter()
4369            .map(|&(block_idx, term_idx)| {
4370                spatial_term_min_center_count(&planned_blocks[block_idx][term_idx])
4371            })
4372            .max()
4373            .unwrap_or(1);
4374        let joint_centers = group_key
4375            .requested_num_centers
4376            .max(min_required)
4377            .min(n.max(1));
4378        let (first_block_idx, first_term_idx) = members[0];
4379        let prototype = &planned_blocks[first_block_idx][first_term_idx];
4380        let standardized = standardized_spatial_term_data(data, prototype)?;
4381        let strategy = spatial_term_center_strategy(prototype).ok_or_else(|| {
4382            BasisError::InvalidInput(format!(
4383                "term '{}' lost its spatial center strategy during joint planning",
4384                prototype.name
4385            ))
4386        })?;
4387        let joint_strategy = center_strategy_with_num_centers(strategy, joint_centers)?;
4388        let shared_centers = select_centers_by_strategy(standardized.view(), &joint_strategy)?;
4389        log::info!(
4390            "sharing {} spatial centers across {} smooth terms over columns {:?} (requested {} centers)",
4391            shared_centers.nrows(),
4392            members.len(),
4393            group_key.feature_cols,
4394            group_key.requested_num_centers,
4395        );
4396        for (block_idx, term_idx) in members {
4397            set_spatial_term_centers(
4398                &mut planned_blocks[block_idx][term_idx],
4399                shared_centers.clone(),
4400            )?;
4401        }
4402    }
4403
4404    // Sentinel auto-init: Matern and thin-plate builders write length_scale =
4405    // 0.0 when the user didn't pass `length_scale=...`. Replace those with a
4406    // data-driven initialization here so REML starts in a regime where it can
4407    // escape; the hard-coded 1.0 default was a basin from which ν ≥ 5/2 Matern
4408    // could not recover on high-frequency truths, silently collapsing the fit
4409    // to a near-constant prediction.
4410    for block in planned_blocks.iter_mut() {
4411        for term in block.iter_mut() {
4412            auto_init_length_scale_in_place(data, term);
4413        }
4414    }
4415
4416    Ok(planned_blocks)
4417}
4418
4419/// Compute a data-driven initial length scale from the per-axis range of the
4420/// feature columns. The heuristic `max_range / sqrt(n)` puts the kernel on
4421/// the wiggly side of REML's basin so the optimizer can grow it back if the
4422/// signal is smooth, but is small enough that high-frequency truths remain
4423/// reachable for smoother kernels (ν ≥ 5/2). Clamped to a tiny positive
4424/// floor so degenerate constant-input columns can't produce 0.
4425pub fn auto_initial_length_scale(data: ArrayView2<'_, f64>, feature_cols: &[usize]) -> f64 {
4426    /// Tiny positive floor for the auto length scale, guarding against a zero
4427    /// kernel range when every feature column is (near-)constant.
4428    const LENGTH_SCALE_FLOOR: f64 = 1e-6;
4429    let n = data.nrows();
4430    if n == 0 || feature_cols.is_empty() {
4431        return 1.0;
4432    }
4433    let mut max_range = 0.0_f64;
4434    for &c in feature_cols {
4435        if c >= data.ncols() {
4436            continue;
4437        }
4438        let col = data.column(c);
4439        let mut lo = f64::INFINITY;
4440        let mut hi = f64::NEG_INFINITY;
4441        for &v in col.iter() {
4442            if v.is_finite() {
4443                if v < lo {
4444                    lo = v;
4445                }
4446                if v > hi {
4447                    hi = v;
4448                }
4449            }
4450        }
4451        if hi > lo {
4452            let r = hi - lo;
4453            if r > max_range {
4454                max_range = r;
4455            }
4456        }
4457    }
4458    if !max_range.is_finite() || max_range <= 0.0 {
4459        return 1.0;
4460    }
4461    let init = max_range / (n as f64).sqrt();
4462    init.max(LENGTH_SCALE_FLOOR).min(max_range)
4463}
4464
4465/// Walk a term and, if it is a Matern or thin-plate smooth whose length_scale
4466/// was left at the auto sentinel (`0.0`), overwrite it with
4467/// [`auto_initial_length_scale`].
4468pub fn auto_init_length_scale_in_place(data: ArrayView2<'_, f64>, term: &mut SmoothTermSpec) {
4469    auto_init_length_scale_in_basis(data, &mut term.basis);
4470}
4471
4472/// Replace the `0.0` auto-init length-scale sentinel with a data-derived value
4473/// for any Matern / thin-plate kernel reachable from this basis — including the
4474/// inner kernel of a `by=`/factor-smooth wrapper.
4475///
4476/// `by=<factor>` and the sum-to-zero factor smooth wrap a spatial kernel inside
4477/// `SmoothBasisSpec::ByVariable` / `SmoothBasisSpec::FactorSumToZero` /
4478/// `SmoothBasisSpec::BySmooth`, so the wrapper variant is what the planner sees.
4479/// Without recursing into the wrapped basis the inner ThinPlate/Matern keeps the
4480/// `0.0` sentinel (the post-`1605b3a6e` builder default), which makes the kernel
4481/// distance divide by `length_scale² = 0`, producing a non-finite design at both
4482/// fit and predict time. Recurse so the inner kernel is initialized identically
4483/// to a top-level one.
4484pub fn auto_init_length_scale_in_basis(data: ArrayView2<'_, f64>, basis: &mut SmoothBasisSpec) {
4485    match basis {
4486        SmoothBasisSpec::Matern {
4487            feature_cols, spec, ..
4488        } => {
4489            if spec.length_scale == 0.0 {
4490                spec.length_scale = auto_initial_length_scale(data, feature_cols);
4491            }
4492        }
4493        SmoothBasisSpec::ThinPlate {
4494            feature_cols, spec, ..
4495        } => {
4496            if spec.length_scale == 0.0 {
4497                spec.length_scale = auto_initial_length_scale(data, feature_cols);
4498            }
4499        }
4500        SmoothBasisSpec::ByVariable { inner, .. }
4501        | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
4502            auto_init_length_scale_in_basis(data, inner);
4503        }
4504        SmoothBasisSpec::BySmooth { smooth, .. } => {
4505            auto_init_length_scale_in_basis(data, smooth);
4506        }
4507        _ => {}
4508    }
4509}
4510
4511impl LinearFitConditioning {
4512    pub fn from_columns(design: &TermCollectionDesign, selected_cols: &[usize]) -> Self {
4513        const SCALE_EPS: f64 = 1e-12;
4514        let n = design.design.nrows();
4515        let p = design.design.ncols();
4516        let mut columns = Vec::with_capacity(selected_cols.len());
4517        if n == 0 || selected_cols.is_empty() {
4518            return Self {
4519                intercept_idx: design.intercept_range.start,
4520                columns,
4521            };
4522        }
4523        let chunk_rows = gam_linalg::utils::row_chunk_for_byte_budget(n, p);
4524        // Two-pass mean/variance so operator-backed designs don't need to
4525        // materialize the full dense matrix. Pass 1 accumulates per-column
4526        // sums; pass 2 accumulates the sum of squared deviations from the
4527        // pass-1 mean. This matches the original `Σ (x − mean)² / n` formula
4528        // without the catastrophic cancellation of `E[X²] − E[X]²`.
4529        let mut sums = vec![0.0_f64; selected_cols.len()];
4530        for start in (0..n).step_by(chunk_rows) {
4531            let end = (start + chunk_rows).min(n);
4532            let chunk = design
4533                .design
4534                .try_row_chunk(start..end)
4535                .expect("LinearFitConditioning::from_columns row chunk failed");
4536            for (k, &col_idx) in selected_cols.iter().enumerate() {
4537                let column = chunk.column(col_idx);
4538                for &v in column.iter() {
4539                    sums[k] += v;
4540                }
4541            }
4542        }
4543        let inv_n = 1.0_f64 / n as f64;
4544        let means: Vec<f64> = sums.iter().map(|&s| s * inv_n).collect();
4545        let mut sq_devs = vec![0.0_f64; selected_cols.len()];
4546        for start in (0..n).step_by(chunk_rows) {
4547            let end = (start + chunk_rows).min(n);
4548            let chunk = design
4549                .design
4550                .try_row_chunk(start..end)
4551                .expect("LinearFitConditioning::from_columns row chunk failed");
4552            for (k, &col_idx) in selected_cols.iter().enumerate() {
4553                let mean_k = means[k];
4554                let column = chunk.column(col_idx);
4555                for &v in column.iter() {
4556                    let d = v - mean_k;
4557                    sq_devs[k] += d * d;
4558                }
4559            }
4560        }
4561        for (k, &col_idx) in selected_cols.iter().enumerate() {
4562            let mean = means[k];
4563            let var = sq_devs[k] * inv_n;
4564            let (mean, scale) = if var.is_finite() && var > SCALE_EPS * SCALE_EPS {
4565                (mean, var.sqrt())
4566            } else {
4567                // Leave nearly-constant columns untouched; centering them would collapse
4568                // the design column to ~0 and change the model rather than just condition it.
4569                (0.0, 1.0)
4570            };
4571            columns.push(LinearColumnConditioning {
4572                col_idx,
4573                mean,
4574                scale,
4575            });
4576        }
4577        Self {
4578            intercept_idx: design.intercept_range.start,
4579            columns,
4580        }
4581    }
4582
4583    pub fn apply_to_design(&self, design: &Array2<f64>) -> Array2<f64> {
4584        let mut out = design.clone();
4585        for col in &self.columns {
4586            {
4587                let mut dst = out.column_mut(col.col_idx);
4588                dst -= col.mean;
4589            }
4590            if col.scale != 1.0 {
4591                out.column_mut(col.col_idx).mapv_inplace(|v| v / col.scale);
4592            }
4593        }
4594        out
4595    }
4596
4597    fn transform_matrix_columnswith_a(&self, mat: &Array2<f64>) -> Array2<f64> {
4598        let mut out = mat.clone();
4599        let intercept = self.intercept_idx;
4600        for col in &self.columns {
4601            let intercept_col = out.column(intercept).to_owned();
4602            let mut target = out.column_mut(col.col_idx);
4603            target -= &(intercept_col * col.mean);
4604            if col.scale != 1.0 {
4605                target.mapv_inplace(|v| v / col.scale);
4606            }
4607        }
4608        out
4609    }
4610
4611    fn transform_matrixrowswith_a_transpose(&self, mat: &Array2<f64>) -> Array2<f64> {
4612        let mut out = mat.clone();
4613        let intercept = self.intercept_idx;
4614        for col in &self.columns {
4615            let interceptrow = out.row(intercept).to_owned();
4616            let mut target = out.row_mut(col.col_idx);
4617            target -= &(interceptrow * col.mean);
4618            if col.scale != 1.0 {
4619                target.mapv_inplace(|v| v / col.scale);
4620            }
4621        }
4622        out
4623    }
4624
4625    /// Left-multiply `mat_internal` by `M⁻ᵀ` where `M⁻¹[intercept, j] = mean_j`
4626    /// and `M⁻¹[j, j] = scale_j` for each conditioned column. Used together
4627    /// with [`Self::right_multiply_by_m_inv`] to back-transform an internal
4628    /// penalized Hessian to the original coefficient basis.
4629    fn left_multiply_by_m_inv_transpose(&self, mat_internal: &Array2<f64>) -> Array2<f64> {
4630        let mut out = mat_internal.clone();
4631        let intercept = self.intercept_idx;
4632        let interceptrow_snapshot = mat_internal.row(intercept).to_owned();
4633        for col in &self.columns {
4634            if col.scale != 1.0 {
4635                out.row_mut(col.col_idx).mapv_inplace(|v| v * col.scale);
4636            }
4637            if col.mean != 0.0 {
4638                let mut target = out.row_mut(col.col_idx);
4639                target += &(&interceptrow_snapshot * col.mean);
4640            }
4641        }
4642        out
4643    }
4644
4645    /// Right-multiply `mat_internal` by `M⁻¹`. Mirror of
4646    /// [`Self::left_multiply_by_m_inv_transpose`] on columns.
4647    fn right_multiply_by_m_inv(&self, mat_internal: &Array2<f64>) -> Array2<f64> {
4648        let mut out = mat_internal.clone();
4649        let intercept = self.intercept_idx;
4650        let intercept_col_snapshot = mat_internal.column(intercept).to_owned();
4651        for col in &self.columns {
4652            if col.scale != 1.0 {
4653                out.column_mut(col.col_idx).mapv_inplace(|v| v * col.scale);
4654            }
4655            if col.mean != 0.0 {
4656                let mut target = out.column_mut(col.col_idx);
4657                target += &(&intercept_col_snapshot * col.mean);
4658            }
4659        }
4660        out
4661    }
4662
4663    /// Transform blockwise penalties through the conditioning.
4664    ///
4665    /// For block-local penalties whose `col_range` does not overlap with any
4666    /// conditioning column, the transform is identity (the conditioning only
4667    /// affects unpenalized linear columns). In that common case the penalty
4668    /// passes through unchanged, avoiding O(p²) materialization entirely.
4669    pub fn transform_blockwise_penalties_to_internal(
4670        &self,
4671        penalties: &[BlockwisePenalty],
4672        p: usize,
4673    ) -> Vec<crate::penalty_spec::PenaltySpec> {
4674        let conditioning_cols: std::collections::HashSet<usize> =
4675            self.columns.iter().map(|c| c.col_idx).collect();
4676        penalties
4677            .iter()
4678            .map(|bp| {
4679                let overlaps =
4680                    (bp.col_range.start..bp.col_range.end).any(|j| conditioning_cols.contains(&j));
4681                if overlaps {
4682                    // Rare: penalty block overlaps conditioning columns.
4683                    // Fall back to dense transform.
4684                    let global = bp.to_global(p);
4685                    let right = self.transform_matrix_columnswith_a(&global);
4686                    let transformed = self.transform_matrixrowswith_a_transpose(&right);
4687                    crate::penalty_spec::PenaltySpec::Dense(transformed)
4688                } else {
4689                    // Common: smooth penalty block doesn't touch linear columns.
4690                    // The conditioning is identity on this block.
4691                    crate::penalty_spec::PenaltySpec::from_blockwise(bp.clone())
4692                }
4693            })
4694            .collect()
4695    }
4696
4697    pub fn backtransform_beta(&self, beta_internal: &Array1<f64>) -> Array1<f64> {
4698        let mut beta = beta_internal.clone();
4699        let intercept = self.intercept_idx;
4700        for col in &self.columns {
4701            beta[intercept] -= beta_internal[col.col_idx] * col.mean / col.scale;
4702            beta[col.col_idx] = beta_internal[col.col_idx] / col.scale;
4703        }
4704        beta
4705    }
4706
4707    /// `H_orig = M⁻ᵀ · H_int · M⁻¹`, derived from
4708    /// `L_int(β_int) = L_orig(M · β_int)` via the chain rule.
4709    pub fn transform_penalized_hessian_to_original(&self, h_internal: &Array2<f64>) -> Array2<f64> {
4710        let right = self.right_multiply_by_m_inv(h_internal);
4711        self.left_multiply_by_m_inv_transpose(&right)
4712    }
4713
4714    pub fn internal_bounds_for(&self, col_idx: usize, min: f64, max: f64) -> (f64, f64) {
4715        if let Some(col) = self.columns.iter().find(|c| c.col_idx == col_idx) {
4716            (min * col.scale, max * col.scale)
4717        } else {
4718            (min, max)
4719        }
4720    }
4721}
4722
4723pub fn freeze_raw_spatial_metadata(metadata: BasisMetadata, raw_cols: usize) -> BasisMetadata {
4724    match metadata {
4725        BasisMetadata::ThinPlate {
4726            centers,
4727            length_scale,
4728            periodic,
4729            identifiability_transform: None,
4730            input_scales,
4731            radial_reparam,
4732        } => BasisMetadata::ThinPlate {
4733            centers,
4734            length_scale,
4735            periodic,
4736            identifiability_transform: Some(Array2::eye(raw_cols)),
4737            input_scales,
4738            radial_reparam,
4739        },
4740        BasisMetadata::Duchon {
4741            centers,
4742            length_scale,
4743            periodic,
4744            power,
4745            nullspace_order,
4746            identifiability_transform: None,
4747            input_scales,
4748            aniso_log_scales,
4749            operator_collocation_points,
4750            radial_reparam,
4751        } => BasisMetadata::Duchon {
4752            centers,
4753            length_scale,
4754            periodic,
4755            power,
4756            nullspace_order,
4757            identifiability_transform: Some(Array2::eye(raw_cols)),
4758            input_scales,
4759            aniso_log_scales,
4760            operator_collocation_points,
4761            radial_reparam,
4762        },
4763        other => other,
4764    }
4765}
4766
4767pub fn matern_operator_penalty_triplet_from_metadata(
4768    metadata: &BasisMetadata,
4769) -> Result<(Vec<Array2<f64>>, Vec<usize>, Vec<PenaltyInfo>), BasisError> {
4770    let BasisMetadata::Matern {
4771        centers,
4772        length_scale,
4773        periodic,
4774        nu,
4775        include_intercept,
4776        identifiability_transform,
4777        aniso_log_scales,
4778        input_scales,
4779        ..
4780    } = metadata
4781    else {
4782        crate::bail_invalid_basis!("Matérn operator penalties require Matérn metadata");
4783    };
4784    // The metadata records `length_scale` in *original* (un-standardized) data
4785    // coordinates, while `centers` live in the *standardized* coordinate frame
4786    // (per-axis division by `input_scales`). The realized design built the
4787    // kernel against those standardized centers using the σ_geom-compensated
4788    // effective length scale `length_scale / σ_geom`. The collocation operators
4789    // here are evaluated on the same standardized centers, so they must use the
4790    // SAME effective length scale — otherwise the penalty regularizes a
4791    // different RKHS range than the design lives in, leaving rough coefficient
4792    // directions effectively unpenalized. That mismatch is benign in 1-D
4793    // (no standardization) but produces a catastrophic out-of-sample blow-up in
4794    // d ≥ 2 where σ_geom ≠ 1 (#706).
4795    let penalty_length_scale = match input_scales.as_deref() {
4796        Some(scales) => compensate_length_scale_for_standardization(*length_scale, scales),
4797        None => *length_scale,
4798    };
4799    matern_operator_penalty_triplet_at_length_scale(
4800        centers.view(),
4801        periodic.as_deref(),
4802        identifiability_transform.as_ref(),
4803        *nu,
4804        *include_intercept,
4805        aniso_log_scales.as_deref(),
4806        penalty_length_scale,
4807    )
4808}
4809
4810/// Build the canonical Matérn operator-penalty triplet (mass / tension /
4811/// stiffness) at an explicit **effective** length scale — i.e. the
4812/// σ_geom-compensated, standardized-frame scale the design's kernel was built
4813/// against (NOT the original-coordinate `length_scale` stored in metadata).
4814///
4815/// This is the SINGLE source of truth for the Matérn penalty topology. Two
4816/// callers route through it and must therefore stay byte-for-byte consistent:
4817///   * the cold/slow design rebuild (`matern_operator_penalty_triplet_from_metadata`,
4818///     compensating the frozen metadata `length_scale`), and
4819///   * the n-free κ-optimizer re-key (`FrozenTermCollectionIncrementalRealizer::
4820///     canonical_penalties_at_psi`, compensating the trial `ψ → exp(-ψ)` scale).
4821///
4822/// Sharing the body makes the penalty BLOCK COUNT and the per-block numerics
4823/// one deterministic function of `(geometry, ν, η, ℓ_eff)`. The active-operator
4824/// gate is `m = ν + d/2`, which is independent of ℓ, so the block count is
4825/// **ψ-stable by construction**: the re-key can never produce a different number
4826/// of blocks than the frozen design (the desync that #1270 hard-errored on).
4827pub fn matern_operator_penalty_triplet_at_length_scale(
4828    centers: ArrayView2<'_, f64>,
4829    periodic: Option<&[Option<f64>]>,
4830    identifiability_transform: Option<&Array2<f64>>,
4831    nu: crate::basis::MaternNu,
4832    include_intercept: bool,
4833    aniso_log_scales: Option<&[f64]>,
4834    effective_length_scale: f64,
4835) -> Result<(Vec<Array2<f64>>, Vec<usize>, Vec<PenaltyInfo>), BasisError> {
4836    let penalty_centers = crate::basis::expand_periodic_centers(&centers.to_owned(), periodic)?;
4837    let ops = build_matern_collocation_operator_matrices(
4838        penalty_centers.view(),
4839        None,
4840        effective_length_scale,
4841        nu,
4842        include_intercept,
4843        identifiability_transform.map(|z| z.view()),
4844        aniso_log_scales,
4845    )?;
4846    // Gate the operator dials on the Matérn-ν RKHS Sobolev order m = ν + d/2:
4847    // mass (j=0) is always on, tension (j=1) is on for m > 1, stiffness (j=2)
4848    // is on for m > 2. The threshold is strict so the roughest kernel ν=1/2 in
4849    // d=1 (m=1, the exponential/OU H¹ process) sheds both higher operators —
4850    // its kernel already encodes the H¹ control, so adding an extra tension
4851    // dial over-smooths the oscillation it is meant to track (#707). The
4852    // matching gate lives at `DuchonOperatorPenaltySpec::matern_for_smoothness`.
4853    const ORDER_EPS: f64 = 1e-9;
4854    let d = penalty_centers.ncols();
4855    let m = nu.half_integer_value() + 0.5 * d as f64;
4856    let mut candidates = Vec::with_capacity(3);
4857    for (raw, source, min_order) in [
4858        (ops.d0.t().dot(&ops.d0), PenaltySource::OperatorMass, 0.0),
4859        (ops.d1.t().dot(&ops.d1), PenaltySource::OperatorTension, 1.0),
4860        (
4861            ops.d2.t().dot(&ops.d2),
4862            PenaltySource::OperatorStiffness,
4863            2.0,
4864        ),
4865    ] {
4866        if min_order > 0.0 && m <= min_order + ORDER_EPS {
4867            continue;
4868        }
4869        let sym = (&raw + &raw.t()) * 0.5;
4870        let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&sym);
4871        candidates.push(PenaltyCandidate {
4872            matrix,
4873            nullspace_dim_hint: 0,
4874            source,
4875            normalization_scale,
4876            kronecker_factors: None,
4877            op: None,
4878        });
4879    }
4880    filter_active_penalty_candidates(candidates)
4881}
4882
4883pub fn normalize_penalty_in_constrained_space(matrix: &Array2<f64>) -> (Array2<f64>, f64) {
4884    // Constrained-space normalization:
4885    //   c = ||S_con||_F,  S_tilde = S_con / c.
4886    // This is the only normalization coherent with a REML objective that is
4887    // evaluated entirely in constrained coordinates.
4888    let matrix = (matrix + &matrix.t().to_owned()) * 0.5;
4889    // Clamp noise-floor negative eigenvalues so β'Sβ is non-negative as a contract, not just in exact arithmetic.
4890    let matrix = crate::basis::project_penalty_to_psd_cone(&matrix);
4891    let c = matrix.iter().map(|v| v * v).sum::<f64>().sqrt();
4892    if c.is_finite() && c > 0.0 {
4893        (matrix.mapv(|v| v / c), c)
4894    } else {
4895        (matrix, 1.0)
4896    }
4897}
4898
4899pub fn tensor_product_design_from_sparse_marginals(
4900    marginal_sparse: &[&SparseColMat<usize, f64>],
4901) -> Result<SparseColMat<usize, f64>, BasisError> {
4902    if marginal_sparse.is_empty() {
4903        crate::bail_invalid_basis!("TensorBSpline requires at least one marginal basis");
4904    }
4905    let n = marginal_sparse[0].nrows();
4906    for (i, m) in marginal_sparse.iter().enumerate().skip(1) {
4907        if m.nrows() != n {
4908            crate::bail_dim_basis!(
4909                "tensor sparse marginal row mismatch at dim {i}: expected {n}, got {}",
4910                m.nrows()
4911            );
4912        }
4913    }
4914    let dims: Vec<usize> = marginal_sparse.iter().map(|m| m.ncols()).collect();
4915    let total_cols = dims.iter().try_fold(1usize, |acc, &q| {
4916        acc.checked_mul(q)
4917            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))
4918    })?;
4919    let mut strides = vec![1usize; dims.len()];
4920    for d in (0..dims.len().saturating_sub(1)).rev() {
4921        strides[d] = strides[d + 1]
4922            .checked_mul(dims[d + 1])
4923            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))?;
4924    }
4925
4926    use faer::sparse::SparseRowMat;
4927    let csrs: Vec<SparseRowMat<usize, f64>> = marginal_sparse
4928        .iter()
4929        .enumerate()
4930        .map(|(d, m)| {
4931            m.as_ref().to_row_major().map_err(|e| {
4932                BasisError::SparseCreation(format!(
4933                    "tensor sparse marginal {d} CSR conversion failed: {e:?}"
4934                ))
4935            })
4936        })
4937        .collect::<Result<Vec<_>, _>>()?;
4938    let row_ptrs: Vec<&[usize]> = csrs.iter().map(|c| c.symbolic().row_ptr()).collect();
4939    let col_idxs: Vec<&[usize]> = csrs.iter().map(|c| c.symbolic().col_idx()).collect();
4940    let vals: Vec<&[f64]> = csrs.iter().map(|c| c.val()).collect();
4941
4942    use rayon::prelude::*;
4943    const CHUNK: usize = 1024;
4944    let num_chunks = n.div_ceil(CHUNK);
4945    let per_chunk: Vec<Vec<Triplet<usize, usize, f64>>> = (0..num_chunks)
4946        .into_par_iter()
4947        .map(|chunk_idx| {
4948            let row_start = chunk_idx * CHUNK;
4949            let row_end = (row_start + CHUNK).min(n);
4950            let mut chunk_triplets = Vec::<Triplet<usize, usize, f64>>::new();
4951            let mut cur_cols = Vec::<usize>::with_capacity(64);
4952            let mut cur_vals = Vec::<f64>::with_capacity(64);
4953            let mut next_cols = Vec::<usize>::with_capacity(64);
4954            let mut next_vals = Vec::<f64>::with_capacity(64);
4955            for i in row_start..row_end {
4956                cur_cols.clear();
4957                cur_vals.clear();
4958                cur_cols.push(0);
4959                cur_vals.push(1.0);
4960                let mut row_is_zero = false;
4961                for d in 0..dims.len() {
4962                    let row_start_d = row_ptrs[d][i];
4963                    let row_end_d = row_ptrs[d][i + 1];
4964                    if row_start_d == row_end_d {
4965                        row_is_zero = true;
4966                        break;
4967                    }
4968                    let stride = strides[d];
4969                    next_cols.clear();
4970                    next_vals.clear();
4971                    next_cols.reserve(cur_cols.len() * (row_end_d - row_start_d));
4972                    next_vals.reserve(cur_vals.len() * (row_end_d - row_start_d));
4973                    for (&prev_col, &prev_val) in cur_cols.iter().zip(cur_vals.iter()) {
4974                        for ptr in row_start_d..row_end_d {
4975                            let cj = col_idxs[d][ptr];
4976                            let vj = vals[d][ptr];
4977                            next_cols.push(prev_col + cj * stride);
4978                            next_vals.push(prev_val * vj);
4979                        }
4980                    }
4981                    std::mem::swap(&mut cur_cols, &mut next_cols);
4982                    std::mem::swap(&mut cur_vals, &mut next_vals);
4983                }
4984                if row_is_zero {
4985                    continue;
4986                }
4987                for (&col, &val) in cur_cols.iter().zip(cur_vals.iter()) {
4988                    chunk_triplets.push(Triplet::new(i, col, val));
4989                }
4990            }
4991            chunk_triplets
4992        })
4993        .collect();
4994    let total_nnz: usize = per_chunk.iter().map(Vec::len).sum();
4995    let mut triplets = Vec::<Triplet<usize, usize, f64>>::with_capacity(total_nnz);
4996    for chunk in per_chunk {
4997        triplets.extend(chunk);
4998    }
4999    SparseColMat::try_new_from_triplets(n, total_cols, &triplets).map_err(|e| {
5000        BasisError::SparseCreation(format!(
5001            "failed to assemble sparse tensor product design: {e:?}"
5002        ))
5003    })
5004}
5005
5006pub fn dense_local_margin_to_sparse(
5007    dense: &Array2<f64>,
5008) -> Result<SparseColMat<usize, f64>, BasisError> {
5009    let expected_row_nnz = dense.ncols().min(4);
5010    let mut triplets =
5011        Vec::<Triplet<usize, usize, f64>>::with_capacity(dense.nrows() * expected_row_nnz);
5012    for ((row, col), &value) in dense.indexed_iter() {
5013        if value != 0.0 {
5014            triplets.push(Triplet::new(row, col, value));
5015        }
5016    }
5017    SparseColMat::try_new_from_triplets(dense.nrows(), dense.ncols(), &triplets).map_err(|e| {
5018        BasisError::SparseCreation(format!(
5019            "failed to convert tensor marginal design to sparse form: {e:?}"
5020        ))
5021    })
5022}
5023
/// Per-margin projectors onto the penalized (range) and unpenalized (null)
/// eigen-subspaces of that margin's normalized penalty.
///
/// Produced by `tensor_margin_range_null_projectors`. The separable (`t2`)
/// tensor decomposition forms Kronecker products of these factors, choosing
/// `range` or `null` per margin. Each field is `U Uᵀ` over a subset of
/// eigenvector columns; it is an orthogonal projector provided those columns
/// are orthonormal (NOTE(review): relies on `analyze_penalty_block` returning
/// orthonormal eigenvectors — confirm).
pub struct TensorMarginRangeNullProjectors {
    // `U_r U_rᵀ` over eigenvectors whose eigenvalue exceeds the analysis tolerance.
    range: Array2<f64>,
    // `U_n U_nᵀ` over the remaining eigenvectors (eigenvalue at or below tolerance).
    null: Array2<f64>,
}
5028
5029pub fn projector_from_columns(columns: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
5030    if indices.is_empty() {
5031        return Array2::<f64>::zeros((columns.nrows(), columns.nrows()));
5032    }
5033    let basis = columns.select(Axis(1), indices);
5034    basis.dot(&basis.t())
5035}
5036
5037pub fn tensor_margin_range_null_projectors(
5038    normalized_marginal_penalties: &[(Array2<f64>, f64)],
5039) -> Result<Vec<TensorMarginRangeNullProjectors>, BasisError> {
5040    normalized_marginal_penalties
5041        .iter()
5042        .enumerate()
5043        .map(|(dim, (penalty, _))| {
5044            let analysis = crate::basis::analyze_penalty_block(penalty)?;
5045            if analysis.rank == 0 {
5046                crate::bail_invalid_basis!(
5047                    "t2 separable tensor penalty margin {dim} has rank-zero penalty; \
5048                     cannot split penalized and null subspaces"
5049                );
5050            }
5051            let mut range_idx = Vec::<usize>::new();
5052            let mut null_idx = Vec::<usize>::new();
5053            for (idx, &ev) in analysis.eigenvalues.iter().enumerate() {
5054                if ev > analysis.tol {
5055                    range_idx.push(idx);
5056                } else {
5057                    null_idx.push(idx);
5058                }
5059            }
5060            Ok(TensorMarginRangeNullProjectors {
5061                range: projector_from_columns(&analysis.eigenvectors, &range_idx),
5062                null: projector_from_columns(&analysis.eigenvectors, &null_idx),
5063            })
5064        })
5065        .collect()
5066}
5067
5068pub fn build_tensor_bspline_basis(
5069    data: ArrayView2<'_, f64>,
5070    feature_cols: &[usize],
5071    spec: &TensorBSplineSpec,
5072) -> Result<BasisBuildResult, BasisError> {
5073    if feature_cols.is_empty() {
5074        crate::bail_invalid_basis!("TensorBSpline requires at least one feature column");
5075    }
5076    if feature_cols.len() != spec.marginalspecs.len() {
5077        crate::bail_dim_basis!(
5078            "TensorBSpline feature/spec mismatch: feature_cols={}, marginalspecs={}",
5079            feature_cols.len(),
5080            spec.marginalspecs.len()
5081        );
5082    }
5083    if !spec.periods.is_empty() && spec.periods.len() != feature_cols.len() {
5084        crate::bail_dim_basis!(
5085            "TensorBSpline periods length {} does not match feature count {}",
5086            spec.periods.len(),
5087            feature_cols.len()
5088        );
5089    }
5090    let p = data.ncols();
5091    for &c in feature_cols {
5092        if c >= p {
5093            crate::bail_dim_basis!(
5094                "tensor feature column {c} is out of bounds for data with {p} columns"
5095            );
5096        }
5097    }
5098
5099    let mut marginal_knots = Vec::<Array1<f64>>::with_capacity(feature_cols.len());
5100    // Per-margin cr flag (#1074): `true` when the margin is a natural cubic
5101    // regression spline, so the tensor freeze rebuilds the cr knotspec.
5102    let mut marginal_is_cr_flags = Vec::<bool>::with_capacity(feature_cols.len());
5103    let mut marginal_degrees = Vec::<usize>::with_capacity(feature_cols.len());
5104    let mut marginalnum_basis = Vec::<usize>::with_capacity(feature_cols.len());
5105    let mut marginal_penalties = Vec::<Array2<f64>>::with_capacity(feature_cols.len());
5106    let mut marginal_designs = Vec::<Array2<f64>>::with_capacity(feature_cols.len());
5107    // Per-margin effective period: either user-set via `spec.periods` or
5108    // implied by a `PeriodicUniform` marginal knotspec (which the 1D B-spline
5109    // builder realizes as a cyclic B-spline basis).
5110    // Captured here so freeze→reload round-trips both routes back to a
5111    // `PeriodicUniform` marginal knotspec; otherwise a `PeriodicUniform`
5112    // margin specified without `spec.periods` would freeze as a plain
5113    // `Provided(knots)` open spline and lose its wrap-around at predict time.
5114    let mut marginal_effective_periods = Vec::<Option<f64>>::with_capacity(feature_cols.len());
5115    // Per-marginal sparse representation, populated when the 1D builder returned
5116    // a `DesignMatrix::Sparse`. Used to assemble the Khatri-Rao tensor product
5117    // sparsely (only ∏(degree+1) nonzeros per row) instead of densifying to
5118    // shape (n, ∏ q_j) up front. Periodic B-spline margins are local-support
5119    // bases too; when the 1D builder returns them densely, we convert that
5120    // marginal back to sparse form so cylinder/torus tensor products keep the
5121    // same scale behavior as open tensor products.
5122    let mut marginal_sparse =
5123        Vec::<Option<SparseColMat<usize, f64>>>::with_capacity(feature_cols.len());
5124
5125    // Reuse the robust 1D builder to ensure the same knot validation and
5126    // marginal difference-penalty construction as standalone smooth terms.
5127    for (dim, (&col, marginalspec)) in feature_cols
5128        .iter()
5129        .zip(spec.marginalspecs.iter())
5130        .enumerate()
5131    {
5132        // Tensor basis uses raw marginal knot-product columns. Applying 1D
5133        // identifiability constraints here would change marginal penalty sizes
5134        // without changing the tensor design construction, causing dimension
5135        // mismatch. Keep marginal builders unconstrained at this stage.
5136        let mut marginal_unconstrained = marginalspec.clone();
5137        marginal_unconstrained.identifiability = BSplineIdentifiability::None;
5138        let built = build_bspline_basis_1d(data.column(col), &marginal_unconstrained)?;
5139        // A cr (`NaturalCubicRegression`) margin emits `CubicRegression1D`
5140        // metadata whose `knots` are the k value-knots; a B-spline margin emits
5141        // `BSpline1D` with the clamped knot vector. Capture either so the
5142        // tensor freeze can rebuild the exact same marginal knotspec (#1074).
5143        let (knots, marginal_is_cr) = match built.metadata {
5144            BasisMetadata::BSpline1D { knots, .. } => (knots, false),
5145            BasisMetadata::CubicRegression1D { knots, .. } => (knots, true),
5146            _ => {
5147                crate::bail_invalid_basis!(
5148                    "internal TensorBSpline error at dim {dim}: expected BSpline1D or CubicRegression1D metadata"
5149                );
5150            }
5151        };
5152        let metadata_knots = match marginalspec.knotspec {
5153            BSplineKnotSpec::PeriodicUniform {
5154                data_range,
5155                num_basis,
5156            } => Array1::linspace(data_range.0, data_range.1, num_basis),
5157            _ => knots,
5158        };
5159        marginal_knots.push(metadata_knots);
5160        marginal_is_cr_flags.push(marginal_is_cr);
5161        marginal_degrees.push(marginalspec.degree);
5162        marginalnum_basis.push(built.design.ncols());
5163        // Capture the sparse representation of this marginal (when the
5164        // 1D builder produced one) before densifying for the dense
5165        // marginal cache used by `tensor_product_design_from_marginals`
5166        // and `TensorProductDesignOperator`.
5167        let dense_marginal = built.design.to_dense();
5168        let sparse_view: Option<SparseColMat<usize, f64>> = match built.design.as_sparse() {
5169            Some(sd) => {
5170                let inner: &SparseColMat<usize, f64> = sd;
5171                Some(inner.clone())
5172            }
5173            None => match marginalspec.knotspec {
5174                BSplineKnotSpec::PeriodicUniform { .. } => {
5175                    Some(dense_local_margin_to_sparse(&dense_marginal)?)
5176                }
5177                _ => None,
5178            },
5179        };
5180        marginal_sparse.push(sparse_view);
5181        marginal_designs.push(dense_marginal);
5182        marginal_penalties.push(
5183            built
5184                .penalties
5185                .first()
5186                .ok_or_else(|| {
5187                    BasisError::InvalidInput(format!(
5188                        "internal TensorBSpline error at dim {dim}: missing marginal penalty"
5189                    ))
5190                })?
5191                .clone(),
5192        );
5193        built.nullspace_dims.first().ok_or_else(|| {
5194            BasisError::InvalidInput(format!(
5195                "internal TensorBSpline error at dim {dim}: missing marginal nullspace dim"
5196            ))
5197        })?;
5198        // A `PeriodicUniform` marginal knotspec implies the margin is
5199        // wrap-around: the 1D builder already realized it as a periodic
5200        // basis, so the tensor product inherits that periodicity. Record
5201        // the period derived from the knotspec's data range so freeze
5202        // restores `PeriodicUniform` on the marginal — otherwise the
5203        // round-trip downgrades it to `Provided(knots)` (an open spline)
5204        // and predict-time wraps disappear.
5205        let implied_period = match marginalspec.knotspec {
5206            BSplineKnotSpec::PeriodicUniform { data_range, .. } => {
5207                Some(data_range.1 - data_range.0)
5208            }
5209            _ => spec.periods.get(dim).and_then(|p| *p),
5210        };
5211        marginal_effective_periods.push(implied_period);
5212    }
5213
5214    let total_cols: usize = marginalnum_basis.iter().product();
5215    let mut dense_design = (!matches!(spec.identifiability, TensorBSplineIdentifiability::None))
5216        .then(|| tensor_product_design_from_marginals(&marginal_designs))
5217        .transpose()?;
5218    let mut candidates = Vec::<PenaltyCandidate>::with_capacity(
5219        match spec.penalty_decomposition {
5220            TensorBSplinePenaltyDecomposition::MarginalKroneckerSum => marginal_penalties.len(),
5221            TensorBSplinePenaltyDecomposition::Separable => marginal_penalties.len() * 2,
5222        } + if spec.double_penalty { 1 } else { 0 },
5223    );
5224
5225    // Tensor-product smoothing parameters are one-per-margin.  Therefore the
5226    // physical penalty attached to a margin must be normalized in that margin's
5227    // own working coordinates before it is embedded in the full tensor product.
5228    // Normalizing only the already-Kroneckered matrix would fold arbitrary
5229    // dimension-dependent identity factors into the margin's lambda and would
5230    // make anisotropic REML/LAML smoothing depend on the other margins' basis
5231    // sizes rather than on the marginal roughness operator itself.
5232    let normalized_marginal_penalties: Vec<(Array2<f64>, f64)> = marginal_penalties
5233        .iter()
5234        .map(normalize_penalty_in_constrained_space)
5235        .collect();
5236    let mut kronecker_marginal_penalties =
5237        Vec::<Array2<f64>>::with_capacity(normalized_marginal_penalties.len());
5238
5239    match spec.penalty_decomposition {
5240        TensorBSplinePenaltyDecomposition::MarginalKroneckerSum => {
5241            // Accumulate the Kronecker-sum of the per-margin penalties,
5242            // `Σ_dim S_dim`, whose null space is exactly the *joint* null space
5243            // of all marginal penalties — the tensor of marginal polynomial
5244            // null spaces. The tensor double penalty (below) shrinks only this
5245            // joint null, never the already-penalized interaction range.
5246            let mut marginal_kron_sum = Array2::<f64>::zeros((total_cols, total_cols));
5247
5248            for dim in 0..normalized_marginal_penalties.len() {
5249                let mut s_dim = Array2::<f64>::eye(1);
5250                let mut factors = Vec::<Array2<f64>>::with_capacity(marginalnum_basis.len());
5251                for (j, &qj) in marginalnum_basis.iter().enumerate() {
5252                    let factor = if j == dim {
5253                        normalized_marginal_penalties[j].0.clone()
5254                    } else {
5255                        Array2::<f64>::eye(qj)
5256                    };
5257                    factors.push(factor.clone());
5258                    s_dim = kronecker_product(&s_dim, &factor);
5259                }
5260                if dim == kronecker_marginal_penalties.len() {
5261                    kronecker_marginal_penalties.push(normalized_marginal_penalties[dim].0.clone());
5262                }
5263                marginal_kron_sum += &s_dim;
5264
5265                candidates.push(PenaltyCandidate {
5266                    matrix: s_dim,
5267                    nullspace_dim_hint: 0,
5268                    source: PenaltySource::TensorMarginal { dim },
5269                    normalization_scale: normalized_marginal_penalties[dim].1,
5270                    kronecker_factors: Some(factors),
5271                    op: None,
5272                });
5273            }
5274
5275            if spec.double_penalty
5276                && let Some(shrink) =
5277                    crate::basis::build_nullspace_shrinkage_penalty(&marginal_kron_sum)?
5278            {
5279                let (matrix, normalization_scale) =
5280                    normalize_penalty_in_constrained_space(&shrink.sym_penalty);
5281                candidates.push(PenaltyCandidate {
5282                    matrix,
5283                    nullspace_dim_hint: 0,
5284                    source: PenaltySource::TensorGlobalRidge,
5285                    normalization_scale,
5286                    kronecker_factors: None,
5287                    op: None,
5288                });
5289            }
5290        }
5291        TensorBSplinePenaltyDecomposition::Separable => {
5292            let projectors = tensor_margin_range_null_projectors(&normalized_marginal_penalties)?;
5293            let n_masks = 1usize.checked_shl(projectors.len() as u32).ok_or_else(|| {
5294                BasisError::InvalidInput(format!(
5295                    "t2 separable tensor penalty supports at most {} margins, got {}",
5296                    usize::BITS - 1,
5297                    projectors.len()
5298                ))
5299            })?;
5300            for mask in 1..n_masks {
5301                let mut matrix = Array2::<f64>::eye(1);
5302                let mut factors = Vec::<Array2<f64>>::with_capacity(projectors.len());
5303                let mut penalized_margins = Vec::<usize>::new();
5304                for (dim, projector) in projectors.iter().enumerate() {
5305                    let use_range = ((mask >> dim) & 1) == 1;
5306                    let factor = if use_range {
5307                        penalized_margins.push(dim);
5308                        projector.range.clone()
5309                    } else {
5310                        projector.null.clone()
5311                    };
5312                    matrix = kronecker_product(&matrix, &factor);
5313                    factors.push(factor);
5314                }
5315                let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&matrix);
5316                candidates.push(PenaltyCandidate {
5317                    matrix,
5318                    nullspace_dim_hint: 0,
5319                    source: PenaltySource::TensorSeparable { penalized_margins },
5320                    normalization_scale,
5321                    kronecker_factors: Some(factors),
5322                    op: None,
5323                });
5324            }
5325
5326            if spec.double_penalty {
5327                let mut matrix = Array2::<f64>::eye(1);
5328                let mut factors = Vec::<Array2<f64>>::with_capacity(projectors.len());
5329                for projector in &projectors {
5330                    matrix = kronecker_product(&matrix, &projector.null);
5331                    factors.push(projector.null.clone());
5332                }
5333                let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&matrix);
5334                candidates.push(PenaltyCandidate {
5335                    matrix,
5336                    nullspace_dim_hint: 0,
5337                    source: PenaltySource::TensorGlobalRidge,
5338                    normalization_scale,
5339                    kronecker_factors: Some(factors),
5340                    op: None,
5341                });
5342            }
5343        }
5344    }
5345
5346    let z_opt = match &spec.identifiability {
5347        TensorBSplineIdentifiability::None => None,
5348        TensorBSplineIdentifiability::SumToZero => {
5349            if total_cols < 2 {
5350                crate::bail_invalid_basis!(
5351                    "TensorBSpline requires at least 2 basis coefficients to enforce sum-to-zero identifiability"
5352                );
5353            }
5354            let dense_design_ref = dense_design.as_ref().ok_or_else(|| {
5355                BasisError::InvalidInput(
5356                    "tensor sum-to-zero identifiability requires a realized basis".to_string(),
5357                )
5358            })?;
5359            let (_, z) = apply_sum_to_zero_constraint(dense_design_ref.view(), None)?;
5360            let gauge = gam_problem::Gauge::sum_to_zero(z);
5361            Some(gauge.block_transform(0))
5362        }
5363        TensorBSplineIdentifiability::MarginalSumToZero => {
5364            // `ti(...)`: drop the marginal main effects by centering every
5365            // margin independently, then form the tensor product of the
5366            // centered margins. Concretely, each margin `j` is reparameterized
5367            // by its own sum-to-zero null basis `Z_j` (so the constant — i.e.
5368            // the marginal intercept — is removed from that axis), and the
5369            // combined reparameterization is the Kronecker product
5370            // `Z = Z₀ ⊗ Z₁ ⊗ … ⊗ Z_{d-1}`. Applying `Z` to the full-tensor
5371            // design `B = B₀ ⊗ … ⊗ B_{d-1}` yields `B Z = (B₀ Z₀) ⊗ … ⊗
5372            // (B_{d-1} Z_{d-1})`, the tensor product of the centered margins,
5373            // which by construction contains no pure main effect.
5374            if marginal_designs.len() < 2 {
5375                crate::bail_invalid_basis!(
5376                    "tensor interaction (ti) identifiability requires at least 2 margins"
5377                );
5378            }
5379            let mut z = Array2::<f64>::eye(1);
5380            for (dim, marginal) in marginal_designs.iter().enumerate() {
5381                if marginal.ncols() < 2 {
5382                    crate::bail_invalid_basis!(
5383                        "tensor interaction (ti) margin {dim} has fewer than 2 basis functions; \
5384                         cannot remove its marginal main effect"
5385                    );
5386                }
5387                let (_, z_dim) = apply_sum_to_zero_constraint(marginal.view(), None)?;
5388                let gauge_dim = gam_problem::Gauge::sum_to_zero(z_dim);
5389                let z_dim = gauge_dim.block_transform(0);
5390                z = kronecker_product(&z, &z_dim);
5391            }
5392            Some(z)
5393        }
5394        TensorBSplineIdentifiability::FrozenTransform { transform } => {
5395            if transform.nrows() != total_cols {
5396                crate::bail_dim_basis!(
5397                    "frozen tensor identifiability transform mismatch: design has {} columns but transform has {} rows",
5398                    total_cols,
5399                    transform.nrows()
5400                );
5401            }
5402            Some(transform.clone())
5403        }
5404    };
5405
5406    if let Some(z) = z_opt.as_ref() {
5407        let gauge = gam_problem::Gauge::from_block_transforms(&[z.clone()]);
5408        let dense = dense_design.as_mut().ok_or_else(|| {
5409            BasisError::InvalidInput(
5410                "tensor identifiability transform requires a realized basis".to_string(),
5411            )
5412        })?;
5413        let restricted_design = gauge.restrict_design(dense);
5414        *dense = restricted_design;
5415        candidates = candidates
5416            .into_iter()
5417            .map(|candidate| -> Result<PenaltyCandidate, BasisError> {
5418                let matrix = gauge.restrict_penalty(&candidate.matrix);
5419                // Re-normalize in the *actual* coefficient chart used by the
5420                // fit.  The tensor sum-to-zero transform is not norm-preserving
5421                // for each overlapping marginal penalty, so carrying the raw
5422                // marginal Frobenius scale into the restricted space changes the
5423                // relative amount of smoothing seen by the LAML/REML optimizer.
5424                // Keep the physical scale in metadata and give the optimizer
5425                // unit-scale constrained penalties for every tensor margin.
5426                let (matrix, c_new) = normalize_penalty_in_constrained_space(&matrix);
5427                Ok(PenaltyCandidate {
5428                    nullspace_dim_hint: candidate.nullspace_dim_hint,
5429                    matrix,
5430                    source: candidate.source,
5431                    normalization_scale: candidate.normalization_scale * c_new,
5432                    // Z^T S Z is no longer a Kronecker product of the original
5433                    // marginal factors, so the Kronecker fast path in construction.rs
5434                    // must not be taken. Clearing kronecker_factors forces the generic
5435                    // block-local eigendecomposition path, which operates on the
5436                    // transformed matrix and is correct.
5437                    kronecker_factors: None,
5438                    op: candidate.op.clone(),
5439                })
5440            })
5441            .collect::<Result<Vec<_>, _>>()?;
5442    }
5443
5444    let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
5445        filter_active_penalty_candidates_with_ops(candidates)?;
5446    let identifiability_is_none =
5447        matches!(spec.identifiability, TensorBSplineIdentifiability::None);
5448    // All marginals expose a sparse representation iff each `marginal_sparse`
5449    // slot is `Some(...)`. Currently this is true when every marginal is a
5450    // free-boundary, non-periodic 1D B-spline returned as
5451    // `DesignMatrix::Sparse` from `build_bspline_basis_1d`. Periodic B-splines
5452    // and other dense-only marginals leave a `None` and trigger the fall-back
5453    // path. Identifiability transforms (`SumToZero`, `FrozenTransform`) make
5454    // the tensor design dense in general, so we also gate on that.
5455    let all_marginals_sparse = marginal_sparse.iter().all(Option::is_some);
5456    let design = if let Some(dense_design) = dense_design {
5457        DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense_design))
5458    } else if identifiability_is_none && all_marginals_sparse {
5459        // Sparse Khatri-Rao path: assemble the (n, ∏ q_j) tensor product
5460        // directly as a SparseColMat, preserving the ∏(degree_j+1) nonzero
5461        // structure per row instead of densifying to ∏ q_j columns. This is
5462        // mathematically identical to `tensor_product_design_from_marginals`
5463        // applied to the corresponding dense marginals.
5464        let sparse_marginals: Vec<&SparseColMat<usize, f64>> = marginal_sparse
5465            .iter()
5466            .map(|m| m.as_ref().expect("all_marginals_sparse just verified"))
5467            .collect();
5468        let sparse_design = tensor_product_design_from_sparse_marginals(&sparse_marginals)?;
5469        DesignMatrix::Sparse(gam_linalg::matrix::SparseDesignMatrix::new(sparse_design))
5470    } else {
5471        let marginals: Vec<Arc<Array2<f64>>> = marginal_designs
5472            .iter()
5473            .map(|m| Arc::new(m.clone()))
5474            .collect();
5475        let op = TensorProductDesignOperator::new(marginals).map_err(|e| {
5476            BasisError::InvalidInput(format!("TensorProductDesignOperator build failed: {e}"))
5477        })?;
5478        DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(op)))
5479    };
5480
5481    Ok(BasisBuildResult {
5482        design,
5483        penalties,
5484        nullspace_dims,
5485        penaltyinfo,
5486        ops,
5487        null_eigenvectors,
5488        joint_null_rotation: None,
5489        metadata: BasisMetadata::TensorBSpline {
5490            feature_cols: feature_cols.to_vec(),
5491            knots: marginal_knots,
5492            degrees: marginal_degrees,
5493            // Prefer the per-margin effective period derived in the loop —
5494            // it captures both the explicit `spec.periods` route and the
5495            // implied period from a `PeriodicUniform` marginal knotspec.
5496            // Falling back to `spec.periods` when populated keeps any
5497            // user-supplied explicit period authoritative even if the
5498            // marginal knotspec carried no periodicity hint.
5499            periods: marginal_effective_periods,
5500            is_cr: marginal_is_cr_flags,
5501            identifiability_transform: z_opt,
5502        },
5503        kronecker_factored: if matches!(spec.identifiability, TensorBSplineIdentifiability::None)
5504            && matches!(
5505                spec.penalty_decomposition,
5506                TensorBSplinePenaltyDecomposition::MarginalKroneckerSum
5507            ) {
5508            Some(KroneckerFactoredBasis::new(
5509                marginal_designs,
5510                kronecker_marginal_penalties,
5511                marginalnum_basis.clone(),
5512                spec.double_penalty,
5513            ))
5514        } else {
5515            None
5516        },
5517    })
5518}
5519
5520pub fn tensor_product_design_from_marginals(
5521    marginal_designs: &[Array2<f64>],
5522) -> Result<Array2<f64>, BasisError> {
5523    if marginal_designs.is_empty() {
5524        crate::bail_invalid_basis!("TensorBSpline requires at least one marginal basis");
5525    }
5526    let n = marginal_designs[0].nrows();
5527    for (i, b) in marginal_designs.iter().enumerate().skip(1) {
5528        if b.nrows() != n {
5529            crate::bail_dim_basis!(
5530                "tensor marginal row mismatch at dim {i}: expected {n}, got {}",
5531                b.nrows()
5532            );
5533        }
5534    }
5535    let total_cols = marginal_designs.iter().try_fold(1usize, |acc, b| {
5536        acc.checked_mul(b.ncols())
5537            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))
5538    })?;
5539    // Tensor-product Khatri-Rao: design[i, j] = Π_d marginal_d[i, j_d]
5540    // where j is the multi-index (j_1, ..., j_D) flattened. Independent
5541    // across rows; parallelize row chunks and fill the pre-allocated
5542    // contiguous Array2 in place (no Vec-flatten-collect intermediate,
5543    // which doubled the peak memory at large-scale N).
5544    use ndarray::parallel::prelude::*;
5545    use rayon::iter::{IntoParallelIterator, ParallelIterator};
5546    let mut design = Array2::<f64>::zeros((n, total_cols));
5547    design
5548        .axis_chunks_iter_mut(ndarray::Axis(0), 1024)
5549        .into_par_iter()
5550        .enumerate()
5551        .for_each(|(chunk_idx, mut block)| {
5552            let row_offset = chunk_idx * 1024;
5553            // Scratch buffers reused across rows in this chunk.
5554            let mut cur = Vec::<f64>::with_capacity(total_cols);
5555            let mut next = Vec::<f64>::with_capacity(total_cols);
5556            for (local_i, mut out_row) in block.outer_iter_mut().enumerate() {
5557                let i = row_offset + local_i;
5558                cur.clear();
5559                cur.push(1.0);
5560                for b in marginal_designs {
5561                    let q = b.ncols();
5562                    next.clear();
5563                    next.resize(cur.len() * q, 0.0);
5564                    // Hoist the row view out of the inner `col` loop so the
5565                    // q reads per `a_idx` reuse a single contiguous slice
5566                    // instead of recomputing `b[[i, col]]` strides per cell.
5567                    let b_row = b.row(i);
5568                    let b_slice = b_row
5569                        .as_slice()
5570                        .expect("Array2 row from outer_iter is contiguous");
5571                    for (a_idx, &aval) in cur.iter().enumerate() {
5572                        let off = a_idx * q;
5573                        let dst = &mut next[off..off + q];
5574                        for col in 0..q {
5575                            dst[col] = aval * b_slice[col];
5576                        }
5577                    }
5578                    std::mem::swap(&mut cur, &mut next);
5579                }
5580                // `out_row` is a row of the contiguous C-major `design`
5581                // Array2, so it is backed by a contiguous slice. Use a
5582                // bulk slice copy instead of an element-by-element write
5583                // loop.
5584                let out_slice = out_row
5585                    .as_slice_mut()
5586                    .expect("design row is contiguous in C-major Array2");
5587                out_slice.copy_from_slice(&cur);
5588            }
5589        });
5590    Ok(design)
5591}
5592
5593pub fn build_random_effect_block(
5594    data: ArrayView2<'_, f64>,
5595    spec: &RandomEffectTermSpec,
5596) -> Result<RandomEffectBlock, BasisError> {
5597    let n = data.nrows();
5598    let p = data.ncols();
5599    if spec.feature_col >= p {
5600        crate::bail_dim_basis!(
5601            "random-effect term '{}' feature column {} out of bounds for {} columns",
5602            spec.name,
5603            spec.feature_col,
5604            p
5605        );
5606    }
5607
5608    let col = data.column(spec.feature_col);
5609    if col.iter().any(|v| !v.is_finite()) {
5610        crate::bail_invalid_basis!(
5611            "random-effect term '{}' contains non-finite group values",
5612            spec.name
5613        );
5614    }
5615
5616    let kept_levels: Vec<u64> = if let Some(levels) = spec.frozen_levels.as_ref() {
5617        if levels.is_empty() {
5618            crate::bail_invalid_basis!(
5619                "random-effect term '{}' has empty frozen_levels",
5620                spec.name
5621            );
5622        }
5623        levels.clone()
5624    } else {
5625        let mut levels_set = BTreeSet::<u64>::new();
5626        for &v in col {
5627            levels_set.insert(v.to_bits());
5628        }
5629        if levels_set.is_empty() {
5630            crate::bail_invalid_basis!("random-effect term '{}' has no observed levels", spec.name);
5631        }
5632        let levels: Vec<u64> = levels_set.into_iter().collect();
5633        let start_idx = if spec.drop_first_level && levels.len() > 1 {
5634            1usize
5635        } else {
5636            0usize
5637        };
5638        levels[start_idx..].to_vec()
5639    };
5640
5641    if kept_levels.is_empty() {
5642        crate::bail_invalid_basis!(
5643            "random-effect term '{}' drops all levels; keep at least one level",
5644            spec.name
5645        );
5646    }
5647
5648    let q = kept_levels.len();
5649    let mut level_to_col = BTreeMap::<u64, usize>::new();
5650    for (idx, &bits) in kept_levels.iter().enumerate() {
5651        if level_to_col.insert(bits, idx).is_some() {
5652            crate::bail_invalid_basis!(
5653                "random-effect term '{}' has duplicate frozen level bits {bits}",
5654                spec.name
5655            );
5656        }
5657    }
5658    let mut group_ids = Vec::with_capacity(n);
5659    for &v in col {
5660        let bits = v.to_bits();
5661        group_ids.push(level_to_col.get(&bits).copied());
5662    }
5663
5664    Ok(RandomEffectBlock {
5665        name: spec.name.clone(),
5666        group_ids,
5667        num_groups: q,
5668        kept_levels,
5669    })
5670}
5671
5672impl SmoothDesign {
5673    /// Map an unconstrained term coefficient vector to its constrained shape space.
5674    /// This is useful for nonlinear fits that optimize unconstrained parameters.
5675    pub fn map_term_coefficients(
5676        unconstrained: &Array1<f64>,
5677        shape: ShapeConstraint,
5678    ) -> Result<Array1<f64>, BasisError> {
5679        if unconstrained.is_empty() {
5680            crate::bail_invalid_basis!("unconstrained coefficient vector cannot be empty");
5681        }
5682        let mapped = match shape {
5683            ShapeConstraint::None => unconstrained.clone(),
5684            ShapeConstraint::MonotoneIncreasing => cumulative_exp(unconstrained, 1.0),
5685            ShapeConstraint::MonotoneDecreasing => cumulative_exp(unconstrained, -1.0),
5686            ShapeConstraint::Convex => second_cumulative_exp(unconstrained, 1.0),
5687            ShapeConstraint::Concave => second_cumulative_exp(unconstrained, -1.0),
5688        };
5689        Ok(mapped)
5690    }
5691}
5692
/// Result of building a single smooth term in its own local coordinate system.
///
/// The `penalties`, `ops`, `nullspaces` and `null_eigenvectors` vectors are
/// parallel: entry `k` of each describes the same active penalty block.
pub struct LocalSmoothTermBuild {
    /// NOTE(review): meaning not visible here; presumably the smooth's input
    /// dimensionality. Confirm against the builder.
    pub dim: usize,
    /// Design block for this smooth (dense, sparse, or operator-backed).
    pub design: DesignMatrix,
    /// Active penalty matrices in this smooth's local coefficient space.
    pub penalties: Vec<Array2<f64>>,
    /// Optional operator form of each active penalty, parallel to `penalties`.
    pub ops: Vec<Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>>,
    /// Null-space dimension of each active penalty, parallel to `penalties`.
    pub nullspaces: Vec<usize>,
    /// Per-active-penalty null-space eigenvector matrices, parallel to
    /// `penalties` / `ops` / `nullspaces`. `Some(U_null)` when
    /// `nullspaces[k] > 0`, with `U_null` orthonormal columns spanning
    /// `null(penalties[k])` in this smooth's local coordinate system; `None`
    /// when the active block is full-rank. Stage 1 plumbing; Stage 2
    /// consumes this to absorb the smooth's null space into the parametric
    /// block at `TermCollectionDesign` construction.
    pub null_eigenvectors: Vec<Option<Array2<f64>>>,
    /// Joint-null absorption rotation for this smooth. `Some(rotation)`
    /// records `Q = [U_range | U_null]` spanning `null(Σ_k penalties[k])`,
    /// the joint null across all active penalty blocks on this smooth.
    /// `None` means the joint penalty is full-rank (joint nullity = 0) or
    /// there are no penalties. Stage-2 commit A: plumbing only — populated
    /// by commit B, applied by commit D.
    pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
    /// Descriptive info for each active penalty.
    pub penaltyinfo: Vec<PenaltyInfo>,
    /// NOTE(review): presumably info for candidate penalties that were
    /// filtered out before becoming active. Confirm with the builder.
    pub pre_dropped_penaltyinfo: Vec<PenaltyInfo>,
    /// Basis metadata describing how this smooth was built.
    pub metadata: BasisMetadata,
    /// Optional linear inequality constraints on this smooth's coefficients
    /// (presumably from shape constraints; verify).
    pub linear_constraints: Option<LinearInequalityConstraints>,
    /// NOTE(review): presumably marks a box (bound) reparameterization of
    /// shape-constrained coefficients. Confirm.
    pub box_reparam: bool,
    /// Kronecker-factored view of the basis when available. By-variable row
    /// gating resets this to `None`.
    pub kronecker_factored: Option<KroneckerFactoredBasis>,
}
5721
/// Lazy design operator over a memory-mapped 2-D float64 `.npy` score matrix.
///
/// Elements are decoded on demand from the mapping (row-major, little-endian
/// f64), so the full `nrows × ncols` matrix is never held in memory unless a
/// caller explicitly materializes it.
#[derive(Clone)]
pub struct PcaScoresMemmapDesignOperator {
    /// Read-only mapping of the whole `.npy` file; the `Arc` makes clones cheap.
    mmap: Arc<memmap2::Mmap>,
    /// Byte offset of the first f64 element, i.e. just past the `.npy` header.
    data_offset: usize,
    /// Number of rows (observations) in the stored matrix.
    nrows: usize,
    /// Number of columns (one per design coefficient).
    ncols: usize,
    /// Rows per traversal chunk; `open` clamps it to at least 1.
    chunk_size: usize,
}
5730
5731impl PcaScoresMemmapDesignOperator {
5732    fn open(path: PathBuf, chunk_size: usize) -> Result<Self, BasisError> {
5733        let file = File::open(&path).map_err(|err| {
5734            BasisError::InvalidInput(format!(
5735                "failed to open lazy Pca .npy scores '{}': {err}",
5736                path.display()
5737            ))
5738        })?;
5739        // The .npy scores file is read-only training-cache data; this
5740        // module never mutates it. The error path below converts mmap
5741        // failure to a typed `BasisError::InvalidInput`.
5742        // SAFETY: `memmap2::Mmap::map` requires no concurrent writers; the
5743        // contract is held by this module's read-only access pattern.
5744        let mmap = unsafe {
5745            memmap2::Mmap::map(&file).map_err(|err| {
5746                BasisError::InvalidInput(format!(
5747                    "failed to memmap lazy Pca .npy scores '{}': {err}",
5748                    path.display()
5749                ))
5750            })?
5751        };
5752        let (data_offset, nrows, ncols) = parse_f64_2d_npy_header(&mmap, &path)?;
5753        let expected = data_offset
5754            .checked_add(nrows.saturating_mul(ncols).saturating_mul(8))
5755            .ok_or_else(|| {
5756                BasisError::InvalidInput(format!(
5757                    "lazy Pca .npy scores '{}' shape is too large",
5758                    path.display()
5759                ))
5760            })?;
5761        if mmap.len() < expected {
5762            crate::bail_invalid_basis!(
5763                "lazy Pca .npy scores '{}' is truncated: header expects {} bytes, file has {}",
5764                path.display(),
5765                expected,
5766                mmap.len()
5767            );
5768        }
5769        Ok(Self {
5770            mmap: Arc::new(mmap),
5771            data_offset,
5772            nrows,
5773            ncols,
5774            chunk_size: chunk_size.max(1),
5775        })
5776    }
5777
5778    fn value(&self, row: usize, col: usize) -> f64 {
5779        let offset = self.data_offset + (row * self.ncols + col) * 8;
5780        let mut bytes = [0_u8; 8];
5781        bytes.copy_from_slice(&self.mmap[offset..offset + 8]);
5782        f64::from_le_bytes(bytes)
5783    }
5784
5785    fn chunk_rows(&self) -> usize {
5786        self.chunk_size.min(self.nrows.max(1))
5787    }
5788}
5789
5790impl LinearOperator for PcaScoresMemmapDesignOperator {
5791    fn nrows(&self) -> usize {
5792        self.nrows
5793    }
5794
5795    fn ncols(&self) -> usize {
5796        self.ncols
5797    }
5798
5799    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
5800        assert_eq!(
5801            vector.len(),
5802            self.ncols,
5803            "lazy Pca apply vector length mismatch"
5804        );
5805        let mut out = Array1::<f64>::zeros(self.nrows);
5806        for start in (0..self.nrows).step_by(self.chunk_rows()) {
5807            let end = (start + self.chunk_rows()).min(self.nrows);
5808            for row in start..end {
5809                let mut acc = 0.0;
5810                for col in 0..self.ncols {
5811                    acc += self.value(row, col) * vector[col];
5812                }
5813                out[row] = acc;
5814            }
5815        }
5816        out
5817    }
5818
5819    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
5820        assert_eq!(
5821            vector.len(),
5822            self.nrows,
5823            "lazy Pca apply_transpose vector length mismatch"
5824        );
5825        let mut out = Array1::<f64>::zeros(self.ncols);
5826        for start in (0..self.nrows).step_by(self.chunk_rows()) {
5827            let end = (start + self.chunk_rows()).min(self.nrows);
5828            for row in start..end {
5829                let scale = vector[row];
5830                if scale == 0.0 {
5831                    continue;
5832                }
5833                for col in 0..self.ncols {
5834                    out[col] += scale * self.value(row, col);
5835                }
5836            }
5837        }
5838        out
5839    }
5840
5841    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
5842        if weights.len() != self.nrows {
5843            return Err(format!(
5844                "lazy Pca diag_xtw_x weight length mismatch: weights={}, nrows={}",
5845                weights.len(),
5846                self.nrows
5847            ));
5848        }
5849        let mut gram = Array2::<f64>::zeros((self.ncols, self.ncols));
5850        for start in (0..self.nrows).step_by(self.chunk_rows()) {
5851            let end = (start + self.chunk_rows()).min(self.nrows);
5852            for row in start..end {
5853                let w = weights[row];
5854                if w == 0.0 {
5855                    continue;
5856                }
5857                for a in 0..self.ncols {
5858                    let xa = self.value(row, a);
5859                    if xa == 0.0 {
5860                        continue;
5861                    }
5862                    for b in a..self.ncols {
5863                        gram[[a, b]] += w * xa * self.value(row, b);
5864                    }
5865                }
5866            }
5867        }
5868        for a in 0..self.ncols {
5869            for b in 0..a {
5870                gram[[a, b]] = gram[[b, a]];
5871            }
5872        }
5873        Ok(gram)
5874    }
5875
5876    fn apply_weighted_normal(
5877        &self,
5878        weights: &Array1<f64>,
5879        vector: &Array1<f64>,
5880        penalty: Option<&Array2<f64>>,
5881        ridge: f64,
5882    ) -> Array1<f64> {
5883        assert_eq!(
5884            weights.len(),
5885            self.nrows,
5886            "lazy Pca weighted-normal weight mismatch"
5887        );
5888        assert_eq!(
5889            vector.len(),
5890            self.ncols,
5891            "lazy Pca weighted-normal vector mismatch"
5892        );
5893        let mut out = Array1::<f64>::zeros(self.ncols);
5894        for start in (0..self.nrows).step_by(self.chunk_rows()) {
5895            let end = (start + self.chunk_rows()).min(self.nrows);
5896            for row in start..end {
5897                let w = weights[row].max(0.0);
5898                if w == 0.0 {
5899                    continue;
5900                }
5901                let mut row_dot = 0.0;
5902                for col in 0..self.ncols {
5903                    row_dot += self.value(row, col) * vector[col];
5904                }
5905                if row_dot == 0.0 {
5906                    continue;
5907                }
5908                let scaled = w * row_dot;
5909                for col in 0..self.ncols {
5910                    out[col] += scaled * self.value(row, col);
5911                }
5912            }
5913        }
5914        if let Some(pen) = penalty {
5915            out += &pen.dot(vector);
5916        }
5917        if ridge > 0.0 {
5918            out += &vector.mapv(|x| ridge * x);
5919        }
5920        out
5921    }
5922}
5923
5924impl DenseDesignOperator for PcaScoresMemmapDesignOperator {
5925    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
5926        if weights.len() != self.nrows || y.len() != self.nrows {
5927            return Err(format!(
5928                "lazy Pca compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
5929                weights.len(),
5930                y.len(),
5931                self.nrows
5932            ));
5933        }
5934        let mut out = Array1::<f64>::zeros(self.ncols);
5935        for start in (0..self.nrows).step_by(self.chunk_rows()) {
5936            let end = (start + self.chunk_rows()).min(self.nrows);
5937            for row in start..end {
5938                let scale = weights[row] * y[row];
5939                if scale == 0.0 {
5940                    continue;
5941                }
5942                for col in 0..self.ncols {
5943                    out[col] += scale * self.value(row, col);
5944                }
5945            }
5946        }
5947        Ok(out)
5948    }
5949
5950    fn row_chunk_into(
5951        &self,
5952        rows: Range<usize>,
5953        mut out: ArrayViewMut2<'_, f64>,
5954    ) -> Result<(), MatrixMaterializationError> {
5955        if rows.end > self.nrows || rows.start > rows.end {
5956            return Err(MatrixMaterializationError::MissingRowChunk {
5957                context: "lazy Pca row range out of bounds",
5958            });
5959        }
5960        if out.nrows() != rows.end - rows.start || out.ncols() != self.ncols {
5961            return Err(MatrixMaterializationError::MissingRowChunk {
5962                context: "lazy Pca row_chunk_into shape mismatch",
5963            });
5964        }
5965        for (local, row) in (rows.start..rows.end).enumerate() {
5966            for col in 0..self.ncols {
5967                out[[local, col]] = self.value(row, col);
5968            }
5969        }
5970        Ok(())
5971    }
5972
5973    fn to_dense(&self) -> Array2<f64> {
5974        let mut out = Array2::<f64>::zeros((self.nrows, self.ncols));
5975        self.row_chunk_into(0..self.nrows, out.view_mut())
5976            .expect("lazy Pca full materialization failed");
5977        out
5978    }
5979}
5980
5981pub fn parse_f64_2d_npy_header(
5982    bytes: &[u8],
5983    path: &PathBuf,
5984) -> Result<(usize, usize, usize), BasisError> {
5985    if bytes.len() < 10 || &bytes[0..6] != b"\x93NUMPY" {
5986        crate::bail_invalid_basis!("lazy Pca scores '{}' is not a .npy file", path.display());
5987    }
5988    let major = bytes[6];
5989    let header_len = match major {
5990        1 => u16::from_le_bytes([bytes[8], bytes[9]]) as usize,
5991        2 | 3 => {
5992            if bytes.len() < 12 {
5993                crate::bail_invalid_basis!(
5994                    "lazy Pca scores '{}' has a truncated .npy header",
5995                    path.display()
5996                );
5997            }
5998            u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize
5999        }
6000        other => {
6001            crate::bail_invalid_basis!(
6002                "lazy Pca scores '{}' uses unsupported .npy version {}",
6003                path.display(),
6004                other
6005            );
6006        }
6007    };
6008    let header_start = if major == 1 { 10 } else { 12 };
6009    let data_offset = header_start + header_len;
6010    if bytes.len() < data_offset {
6011        crate::bail_invalid_basis!(
6012            "lazy Pca scores '{}' has a truncated .npy header",
6013            path.display()
6014        );
6015    }
6016    let header = std::str::from_utf8(&bytes[header_start..data_offset]).map_err(|err| {
6017        BasisError::InvalidInput(format!(
6018            "lazy Pca scores '{}' has a non-UTF8 .npy header: {err}",
6019            path.display()
6020        ))
6021    })?;
6022    if !(header.contains("'descr': '<f8'")
6023        || header.contains("\"descr\": \"<f8\"")
6024        || header.contains("'descr': '|f8'")
6025        || header.contains("\"descr\": \"|f8\""))
6026    {
6027        crate::bail_invalid_basis!(
6028            "lazy Pca scores '{}' must be float64 little-endian .npy",
6029            path.display()
6030        );
6031    }
6032    if header.contains("True") {
6033        crate::bail_invalid_basis!(
6034            "lazy Pca scores '{}' must be C-contiguous, not Fortran-ordered",
6035            path.display()
6036        );
6037    }
6038    let shape_pos = header.find("shape").ok_or_else(|| {
6039        BasisError::InvalidInput(format!(
6040            "lazy Pca scores '{}' .npy header is missing shape",
6041            path.display()
6042        ))
6043    })?;
6044    let open = header[shape_pos..].find('(').ok_or_else(|| {
6045        BasisError::InvalidInput(format!(
6046            "lazy Pca scores '{}' .npy header has malformed shape",
6047            path.display()
6048        ))
6049    })? + shape_pos;
6050    let close = header[open..].find(')').ok_or_else(|| {
6051        BasisError::InvalidInput(format!(
6052            "lazy Pca scores '{}' .npy header has malformed shape",
6053            path.display()
6054        ))
6055    })? + open;
6056    let dims = header[open + 1..close]
6057        .split(',')
6058        .map(str::trim)
6059        .filter(|part| !part.is_empty())
6060        .map(|part| part.parse::<usize>())
6061        .collect::<Result<Vec<_>, _>>()
6062        .map_err(|err| {
6063            BasisError::InvalidInput(format!(
6064                "lazy Pca scores '{}' .npy shape is not integral: {err}",
6065                path.display()
6066            ))
6067        })?;
6068    if dims.len() != 2 {
6069        crate::bail_invalid_basis!(
6070            "lazy Pca scores '{}' must have shape (N, K), got {:?}",
6071            path.display(),
6072            dims
6073        );
6074    }
6075    Ok((data_offset, dims[0], dims[1]))
6076}
6077
6078pub fn pca_center_mean(x: ArrayView2<'_, f64>) -> Result<Array1<f64>, BasisError> {
6079    if x.nrows() == 0 {
6080        crate::bail_invalid_basis!("Pca basis requires at least one row to compute center mean");
6081    }
6082    let mut mean = Array1::<f64>::zeros(x.ncols());
6083    for row in x.rows() {
6084        mean += &row;
6085    }
6086    mean.mapv_inplace(|v| v / x.nrows() as f64);
6087    Ok(mean)
6088}
6089
6090pub fn build_pca_smooth_basis(
6091    data: ArrayView2<'_, f64>,
6092    feature_cols: &[usize],
6093    basis_matrix: &Array2<f64>,
6094    centered: bool,
6095    smooth_penalty: f64,
6096    center_mean: Option<&Array1<f64>>,
6097    pca_basis_path: Option<&PathBuf>,
6098    chunk_size: usize,
6099) -> Result<BasisBuildResult, BasisError> {
6100    if let Some(path) = pca_basis_path {
6101        let op = PcaScoresMemmapDesignOperator::open(path.clone(), chunk_size)?;
6102        if op.nrows != data.nrows() {
6103            crate::bail_dim_basis!(
6104                "lazy Pca scores row mismatch: .npy has {}, data has {}",
6105                op.nrows,
6106                data.nrows()
6107            );
6108        }
6109        let k = op.ncols;
6110        let mut penalty = Array2::<f64>::eye(k);
6111        penalty.mapv_inplace(|v| v * smooth_penalty);
6112        let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
6113            filter_active_penalty_candidates_with_ops(vec![PenaltyCandidate {
6114                matrix: penalty,
6115                nullspace_dim_hint: 0,
6116                source: PenaltySource::Other("PcaRidge".to_string()),
6117                normalization_scale: 1.0,
6118                kronecker_factors: None,
6119                op: None,
6120            }])?;
6121        return Ok(BasisBuildResult {
6122            design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(op))),
6123            penalties,
6124            nullspace_dims,
6125            penaltyinfo,
6126            ops,
6127            null_eigenvectors,
6128            joint_null_rotation: None,
6129            metadata: BasisMetadata::Pca {
6130                feature_cols: feature_cols.to_vec(),
6131                basis_matrix: basis_matrix.clone(),
6132                centered,
6133                smooth_penalty,
6134                center_mean: center_mean.cloned(),
6135                pca_basis_path: Some(path.clone()),
6136                chunk_size: chunk_size.max(1),
6137            },
6138            kronecker_factored: None,
6139        });
6140    }
6141    if basis_matrix.nrows() != feature_cols.len() {
6142        crate::bail_dim_basis!(
6143            "Pca basis row mismatch: basis rows={}, feature columns={}",
6144            basis_matrix.nrows(),
6145            feature_cols.len()
6146        );
6147    }
6148    let mut x = select_columns(data, feature_cols)?;
6149    let mean = if centered {
6150        match center_mean {
6151            Some(mean) => mean.clone(),
6152            None => pca_center_mean(x.view())?,
6153        }
6154    } else {
6155        Array1::<f64>::zeros(feature_cols.len())
6156    };
6157    if centered {
6158        for mut row in x.rows_mut() {
6159            row -= &mean;
6160        }
6161    }
6162    let design = fast_ab(&x, basis_matrix);
6163    let k = basis_matrix.ncols();
6164    let mut penalty = Array2::<f64>::eye(k);
6165    penalty.mapv_inplace(|v| v * smooth_penalty);
6166    let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
6167        filter_active_penalty_candidates_with_ops(vec![PenaltyCandidate {
6168            matrix: penalty,
6169            nullspace_dim_hint: 0,
6170            source: PenaltySource::Other("PcaRidge".to_string()),
6171            normalization_scale: 1.0,
6172            kronecker_factors: None,
6173            op: None,
6174        }])?;
6175    Ok(BasisBuildResult {
6176        design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(design)),
6177        penalties,
6178        nullspace_dims,
6179        penaltyinfo,
6180        ops,
6181        null_eigenvectors,
6182        joint_null_rotation: None,
6183        metadata: BasisMetadata::Pca {
6184            feature_cols: feature_cols.to_vec(),
6185            basis_matrix: basis_matrix.clone(),
6186            centered,
6187            smooth_penalty,
6188            center_mean: centered.then_some(mean),
6189            pca_basis_path: None,
6190            chunk_size: chunk_size.max(1),
6191        },
6192        kronecker_factored: None,
6193    })
6194}
6195
6196/// A factor-level `by=` wrapper owns the model-space centering of its inner
6197/// smooth: it gates the raw/structurally-constrained basis to the level rows
6198/// and then centers that gated block exactly once against the level indicator
6199/// (`build_parametric_constraint_block_for_term` in `design_construction`).
6200/// Leaving the inner B-spline's default pooled weighted-sum-to-zero active here
6201/// would impose two generically-independent constraints — the pooled column
6202/// moment `m = Σ_h m_h` and the per-level moment `m_g` — so a raw `k`-column
6203/// basis collapses to `k-2` columns per level instead of `k-1`, deleting one
6204/// genuine nonconstant spline direction *before REML runs* (#1427). The group
6205/// main effect carries only the constant, so it cannot restore that direction.
6206///
6207/// Only the *default model-space* centering is deferred. Explicit structural or
6208/// frozen transforms (`RemoveLinearTrend`, `OrthogonalToDesignColumns`,
6209/// `FrozenTransform`, `None`) are user/structural choices and are preserved
6210/// verbatim.
6211pub fn defer_inner_model_centering_to_factor_level_wrapper(basis: &mut SmoothBasisSpec) {
6212    if let SmoothBasisSpec::BSpline1D { spec, .. } = basis
6213        && matches!(
6214            spec.identifiability,
6215            BSplineIdentifiability::WeightedSumToZero { .. }
6216        )
6217    {
6218        spec.identifiability = BSplineIdentifiability::None;
6219    }
6220}
6221
6222pub fn apply_by_variable_to_local_build(
6223    mut built: LocalSmoothTermBuild,
6224    data: ArrayView2<'_, f64>,
6225    by_col: usize,
6226    by: &ByVariableSpec,
6227    term_name: &str,
6228) -> Result<LocalSmoothTermBuild, BasisError> {
6229    if by_col >= data.ncols() {
6230        crate::bail_dim_basis!(
6231            "by-variable smooth term '{term_name}' references column {by_col}, but data has {} columns",
6232            data.ncols()
6233        );
6234    }
6235    let weights = match by {
6236        ByVariableSpec::Numeric => data.column(by_col).to_owned(),
6237        ByVariableSpec::Level { value_bits, .. } => data.column(by_col).mapv(|value| {
6238            if value.to_bits() == *value_bits {
6239                1.0
6240            } else {
6241                0.0
6242            }
6243        }),
6244    };
6245    if weights.iter().any(|value| !value.is_finite()) {
6246        crate::bail_invalid_basis!(
6247            "by-variable smooth term '{term_name}' has non-finite by-column values"
6248        );
6249    }
6250
6251    let mut dense = built
6252        .design
6253        .try_to_dense_by_chunks("by-variable smooth row gating")
6254        .map_err(BasisError::InvalidInput)?;
6255    for (mut row, &weight) in dense.rows_mut().into_iter().zip(weights.iter()) {
6256        row.mapv_inplace(|value| value * weight);
6257    }
6258    built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense));
6259    built.kronecker_factored = None;
6260    Ok(built)
6261}
6262
6263/// Build the local smooth term for a `BySmooth` spec, which unifies numeric-by
6264/// and factor-by modulation into a single `SmoothTermSpec`.
6265///
6266/// For a **numeric** by-variable the inner smooth is built once and every row
6267/// is multiplied by the by-column value (identical to `ByVariable::Numeric`).
6268///
6269/// For a **factor** by-variable the inner smooth is built once and gated per
6270/// level into side-by-side column blocks, producing a `n × (L * p)` design
6271/// matrix.  The penalties are block-diagonalised (one copy of the inner penalty
6272/// per level) exactly as `build_factor_smooth` does for `bs="fs"/"sz"`.
6273pub fn build_by_smooth_local(
6274    data: ArrayView2<'_, f64>,
6275    term: &SmoothTermSpec,
6276    smooth: &SmoothBasisSpec,
6277    by_kind: &ByVarKind,
6278    workspace: &mut crate::basis::BasisWorkspace,
6279) -> Result<LocalSmoothTermBuild, BasisError> {
6280    let inner_term = SmoothTermSpec {
6281        name: term.name.clone(),
6282        basis: (*smooth).clone(),
6283        shape: term.shape,
6284        joint_null_rotation: None,
6285    };
6286    let inner = build_single_local_smooth_term(data, &inner_term, workspace)?;
6287
6288    match by_kind {
6289        ByVarKind::Numeric { feature_col } => {
6290            let inner_meta = inner.metadata.clone();
6291            let mut built = apply_by_variable_to_local_build(
6292                inner,
6293                data,
6294                *feature_col,
6295                &ByVariableSpec::Numeric,
6296                &term.name,
6297            )?;
6298            built.metadata = BasisMetadata::BySmooth {
6299                inner: Box::new(inner_meta),
6300                by_col: *feature_col,
6301                levels: None,
6302                ordered: false,
6303            };
6304            Ok(built)
6305        }
6306        ByVarKind::Factor {
6307            feature_col,
6308            frozen_levels,
6309            ordered,
6310        } => {
6311            // Collect factor levels: prefer the frozen set (replay path), else
6312            // scan the data column (first-fit path).
6313            let level_bits: Vec<u64> = if let Some(fl) = frozen_levels {
6314                fl.clone()
6315            } else {
6316                let col = data.column(*feature_col);
6317                let mut seen = BTreeSet::<u64>::new();
6318                for &v in col.iter() {
6319                    if v.is_finite() {
6320                        seen.insert(v.to_bits());
6321                    }
6322                }
6323                seen.into_iter().collect()
6324            };
6325            let n_levels = level_bits.len();
6326            if n_levels == 0 {
6327                crate::bail_invalid_basis!(
6328                    "by-factor smooth term '{}': factor column {} has no observed levels",
6329                    term.name,
6330                    feature_col
6331                );
6332            }
6333            let p = inner.dim;
6334            let q = n_levels * p;
6335            let n = data.nrows();
6336
6337            let inner_dense = inner
6338                .design
6339                .try_to_dense_by_chunks("by-factor smooth design gating")
6340                .map_err(BasisError::InvalidInput)?;
6341
6342            // Gate each level into its own p-wide column block.
6343            let mut combined = Array2::<f64>::zeros((n, q));
6344            for (lvl_idx, &bits) in level_bits.iter().enumerate() {
6345                let col_start = lvl_idx * p;
6346                for row in 0..n {
6347                    if data[[row, *feature_col]].to_bits() == bits {
6348                        combined
6349                            .slice_mut(s![row, col_start..col_start + p])
6350                            .assign(&inner_dense.row(row));
6351                    }
6352                }
6353            }
6354
6355            // Build per-level INDEPENDENT penalties (#1427): one copy of each
6356            // inner penalty per level, but each confined to that single level's
6357            // diagonal block, so every (level, inner-penalty) pair is its OWN
6358            // smoothing-parameter coordinate. `s(x, by=g)` selects the per-group
6359            // curve wiggliness independently — the design is block-diagonal and
6360            // block-separable, so a correct REML must reproduce gamfit's own
6361            // independent per-group fits. Tiling a single inner penalty across
6362            // every level (as the `bs="fs"` shared-λ random-effect construction
6363            // does) collapses all groups onto ONE λ, which cannot match uneven
6364            // per-level smoothness and degrades as data grows (under-recovery up
6365            // to ~16× at n=2000). Emit `n_levels * n_penalties` blocks instead.
6366            let inner_meta = inner.metadata.clone();
6367            let n_penalties = inner.penalties.len();
6368            let n_blocks = n_penalties.saturating_mul(n_levels);
6369            let mut penalties = Vec::<Array2<f64>>::with_capacity(n_blocks);
6370            let mut penaltyinfo = Vec::<PenaltyInfo>::with_capacity(n_blocks);
6371            let mut nullspaces = Vec::<usize>::with_capacity(n_blocks);
6372            for (pen_pos, s_inner) in inner.penalties.iter().enumerate() {
6373                for lvl in 0..n_levels {
6374                    let off = lvl * p;
6375                    let mut s_big = Array2::<f64>::zeros((q, q));
6376                    s_big
6377                        .slice_mut(s![off..off + p, off..off + p])
6378                        .assign(s_inner);
6379                    let (s_big, scale) = normalize_penalty_in_constrained_space(&s_big);
6380                    let mut info = inner.penaltyinfo[pen_pos].clone();
6381                    // Distinct original_index per (penalty, level) so each λ is a
6382                    // separate identifiable coordinate downstream.
6383                    info.original_index = pen_pos * n_levels + lvl;
6384                    info.normalization_scale *= scale;
6385                    // Each block now spans exactly ONE level → per-level nullity,
6386                    // not the tiled (× n_levels) hint of the shared construction.
6387                    info.kronecker_factors = None;
6388                    penalties.push(s_big);
6389                    penaltyinfo.push(info);
6390                    nullspaces.push(inner.nullspaces[pen_pos]);
6391                }
6392            }
6393
6394            let null_eigenvectors = vec![None; penalties.len()];
6395            let ops = vec![None; penalties.len()];
6396
6397            Ok(LocalSmoothTermBuild {
6398                dim: q,
6399                design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(combined)),
6400                penalties,
6401                ops,
6402                nullspaces,
6403                null_eigenvectors,
6404                joint_null_rotation: None,
6405                penaltyinfo,
6406                pre_dropped_penaltyinfo: inner.pre_dropped_penaltyinfo,
6407                metadata: BasisMetadata::BySmooth {
6408                    inner: Box::new(inner_meta),
6409                    by_col: *feature_col,
6410                    levels: Some(level_bits),
6411                    ordered: *ordered,
6412                },
6413                linear_constraints: None,
6414                box_reparam: false,
6415                kronecker_factored: None,
6416            })
6417        }
6418    }
6419}
6420
6421pub fn ensure_by_variable_specs_match(
6422    kind: &BySmoothKind,
6423    by: &ByVariableSpec,
6424    term_name: &str,
6425) -> Result<(), BasisError> {
6426    match (kind, by) {
6427        (BySmoothKind::Numeric, ByVariableSpec::Numeric) => Ok(()),
6428        (BySmoothKind::Level { level_bits }, ByVariableSpec::Level { value_bits, .. })
6429            if level_bits == value_bits =>
6430        {
6431            Ok(())
6432        }
6433        _ => Err(BasisError::InvalidInput(format!(
6434            "by-variable smooth term '{term_name}' has inconsistent by-variable specifications"
6435        ))),
6436    }
6437}
6438
/// Build a factor-smooth interaction basis (`bs="fs"`/`"sz"`/`"re"`).
///
/// A factor smooth replicates a shared marginal smooth in the continuous
/// covariate(s) once per level of a grouping factor. All level blocks are
/// coupled through a *single* set of smoothing parameters (one per marginal
/// penalty).
///
/// This is mgcv's `smooth.construct.fs.smooth.spec` realization and the
/// random-effect interpretation of a smooth. The per-level deviations are an
/// exchangeable family whose joint wiggliness/shrinkage is governed by the
/// shared λ, so the construction scales to many levels with a fixed
/// parameter count. Exactly one continuous covariate is supported at present.
///
/// Flavours:
/// * `Fs` — full random factor-smooth. Besides its wiggliness penalty, the
///   marginal's null space is penalized. The replicated design is then a
///   proper full-rank random effect: each level's curve is shrunk toward zero
///   (intercept + linear trend included).
///   - With `m_null_penalty_orders` order ≥ 1 (the default), this adds one
///     rank-1 penalty per null-space dimension `z_k`, replicated as
///     `I_L ⊗ (z_k z_kᵀ)` (mgcv's `bs="fs"` structure).
///   - With `m=0`, the marginal's own `double_penalty` setting (the legacy
///     combined null-space penalty) is kept instead.
///
///   Every marginal penalty `S_j` is replicated as `I_L ⊗ S_j`.
/// * `Sz` — sum-to-zero factor smooth. Delegates to the existing
///   [`SmoothBasisSpec::FactorSumToZero`] construction (`L-1` deviation
///   blocks, coefficient-wise zero sum across levels).
/// * `Re` — pure random effect / random slope (`bs="re"`). A degree-1
///   marginal gives the per-level `[1, x]` span. The penalty is the identity
///   over each level block (iid Gaussian coefficients), matching mgcv's
///   `bs="re"` ridge.
///   NOTE(review): this function does not itself force degree 1. It uses
///   `spec.marginal` as given, with only identifiability cleared. Confirm
///   that the spec constructor pins the degree.
///
/// The grouping levels are resolved once at fit time, as the sorted unique
/// bit patterns of the factor column (see `resolve_factor_smooth_levels`).
/// They are frozen into the returned metadata so the predict-time rebuild
/// evaluates every row against its own level's block.
///
/// # Errors
///
/// Returns an error if:
/// * `continuous_cols` does not hold exactly one column;
/// * a referenced column is out of range;
/// * fewer than two grouping levels are found (`Fs`/`Re`; for `Sz` this
///   check happens inside the delegated `FactorSumToZero` build);
/// * a row carries a grouping level outside the resolved set;
/// * the marginal build yields an unexpected metadata variant;
/// * any nested build or penalty analysis fails.
pub fn build_factor_smooth(
    data: ArrayView2<'_, f64>,
    spec: &FactorSmoothSpec,
    term_name: &str,
    workspace: &mut crate::basis::BasisWorkspace,
) -> Result<LocalSmoothTermBuild, BasisError> {
    if spec.continuous_cols.len() != 1 {
        crate::bail_invalid_basis!(
            "factor smooth term '{}' currently supports exactly one continuous covariate; found {}",
            term_name,
            spec.continuous_cols.len()
        );
    }
    let feature_col = spec.continuous_cols[0];
    let group_col = spec.group_col;
    if feature_col >= data.ncols() || group_col >= data.ncols() {
        crate::bail_dim_basis!(
            "factor smooth term '{}' references columns ({}, {}) out of bounds for {} columns",
            term_name,
            feature_col,
            group_col,
            data.ncols()
        );
    }

    // `Sz` is exactly the existing sum-to-zero factor smooth: reuse it verbatim
    // so there is a single source of truth for the zero-sum construction.
    // The ≥ 2-level requirement is enforced by that delegated build.
    if matches!(spec.flavour, FactorSmoothFlavour::Sz) {
        let levels = resolve_factor_smooth_levels(data, group_col, spec, term_name)?;
        let inner = SmoothBasisSpec::BSpline1D {
            feature_col,
            spec: factor_smooth_marginal_for_replay(&spec.marginal),
        };
        let sz_term = SmoothTermSpec {
            name: term_name.to_string(),
            basis: SmoothBasisSpec::FactorSumToZero {
                inner: Box::new(inner),
                by_col: group_col,
                levels: levels.clone(),
                frozen_global_orthogonality: None,
            },
            shape: ShapeConstraint::None,
            joint_null_rotation: None,
        };
        let mut built = build_single_local_smooth_term(data, &sz_term, workspace)?;
        // The delegated `FactorSumToZero` build returns the BARE inner
        // B-spline metadata (`BasisMetadata::BSpline1D`), but the term that
        // owns this build carries a `SmoothBasisSpec::FactorSmooth { Sz }`
        // spec. Two things break if we hand that mismatched pair downstream:
        //   1. `freeze_smooth_basis_from_metadata` matches on (spec, metadata)
        //      and has no `(FactorSmooth, BSpline1D)` arm. Any refit or
        //      spatial re-optimization that freezes the basis then aborts
        //      with a "smooth metadata/spec type mismatch" error.
        //   2. The bare B-spline metadata carries no grouping levels, so a
        //      predict-time rebuild cannot replay the SAME replicated design.
        // Re-wrap the marginal geometry as `FactorSmooth` metadata exactly as
        // the Fs/Re path below does. All three factor-smooth flavours then
        // share a single, freeze-consistent metadata shape that also pins the
        // levels.
        // The delegated marginal may be a B-spline (`bs="ps"`-style) OR a
        // cubic regression spline (`NaturalCubicRegression`, mgcv's
        // `bs="sz"` default, #1074). Capture either so the predict-time
        // freeze restores the SAME marginal class.
        let (knots, degree, periodic, marginal_is_cr) = match &built.metadata {
            BasisMetadata::BSpline1D {
                knots,
                periodic,
                degree,
                ..
            } => (
                knots.clone(),
                degree.unwrap_or(spec.marginal.degree),
                *periodic,
                false,
            ),
            BasisMetadata::CubicRegression1D { knots, .. } => {
                (knots.clone(), spec.marginal.degree, None, true)
            }
            other => {
                crate::bail_invalid_basis!(
                    "sz factor smooth term '{}' produced an unexpected marginal metadata variant {:?}",
                    term_name,
                    other
                );
            }
        };
        built.metadata = BasisMetadata::FactorSmooth {
            continuous_cols: spec.continuous_cols.clone(),
            group_col,
            knots,
            degree,
            periodic,
            group_levels: levels,
            flavour: "sz".to_string(),
            marginal_is_cr,
        };
        return Ok(built);
    }

    let levels = resolve_factor_smooth_levels(data, group_col, spec, term_name)?;
    let n_levels = levels.len();
    if n_levels < 2 {
        crate::bail_invalid_basis!(
            "factor smooth term '{}' requires at least two grouping levels; found {}",
            term_name,
            n_levels
        );
    }

    // `Fs` (order ≥ 1, the default) is the random-effect flavour. It
    // penalizes each null-space dimension of the marginal wiggliness penalty
    // separately below (mgcv's `bs="fs"` construction). That replaces the
    // marginal's single *combined* double penalty, so disable the latter here
    // to avoid penalizing the null space twice (once combined, once per
    // dimension). The explicit `m=0` opt-out keeps the legacy combined double
    // penalty and adds no per-dimension penalties.
    let use_per_dim_null = matches!(
        &spec.flavour,
        FactorSmoothFlavour::Fs { m_null_penalty_orders }
            if m_null_penalty_orders.iter().copied().max().unwrap_or(0) >= 1
    );

    // Build the shared marginal design + penalties from the 1-D B-spline.
    // `Re` replaces the marginal wiggliness with an identity ridge below and
    // relies on a degree-1 marginal for its linear span.
    // NOTE(review): the degree is not forced here — `spec.marginal` is used
    // as given. Confirm that the spec constructor pins it.
    // `Fs` keeps the user's marginal (cubic by default). Under the
    // per-dimension null path, its null space is penalized one dimension at
    // a time after replication.
    let mut marginal_spec = factor_smooth_marginal_for_replay(&spec.marginal);
    if use_per_dim_null {
        marginal_spec.double_penalty = false;
    }
    let inner_term = SmoothTermSpec {
        name: format!("{term_name}::marginal"),
        basis: SmoothBasisSpec::BSpline1D {
            feature_col,
            spec: marginal_spec,
        },
        shape: ShapeConstraint::None,
        joint_null_rotation: None,
    };
    let inner = build_single_local_smooth_term(data, &inner_term, workspace)?;
    let base = inner
        .design
        .try_to_dense_by_chunks("factor smooth marginal")
        .map_err(BasisError::InvalidInput)?;
    let n = base.nrows();
    let p = base.ncols();
    let q = p * n_levels;

    // Block-diagonal replicated design: row i contributes its marginal row to
    // the column block owned by its grouping level, zeros elsewhere.
    let mut dense = Array2::<f64>::zeros((n, q));
    for i in 0..n {
        let bits = data[[i, group_col]].to_bits();
        let level_idx = levels.iter().position(|b| *b == bits).ok_or_else(|| {
            BasisError::InvalidInput(format!(
                "factor smooth term '{term_name}' saw an unseen grouping level at row {}",
                i + 1
            ))
        })?;
        let start = level_idx * p;
        dense
            .slice_mut(s![i, start..start + p])
            .assign(&base.row(i));
    }

    // Penalties: replicate each marginal penalty into a block-diagonal
    // `I_L ⊗ S_j`, so every level shares the same smoothing parameter λ_j
    // (one λ per marginal penalty) — the defining feature of a factor smooth.
    // For `Re` the marginal penalty is replaced by an identity ridge, so each
    // per-level coefficient is an iid Gaussian random effect.
    let marginal_penalties: Vec<Array2<f64>> = if matches!(spec.flavour, FactorSmoothFlavour::Re) {
        vec![Array2::<f64>::eye(p)]
    } else {
        inner.penalties.clone()
    };
    let marginal_penaltyinfo: Vec<PenaltyInfo> = if matches!(spec.flavour, FactorSmoothFlavour::Re)
    {
        vec![PenaltyInfo {
            source: PenaltySource::Primary,
            original_index: 0,
            active: true,
            effective_rank: p,
            dropped_reason: None,
            nullspace_dim_hint: 0,
            normalization_scale: 1.0,
            kronecker_factors: None,
        }]
    } else {
        inner.penaltyinfo.clone()
    };
    if marginal_penalties.len() != marginal_penaltyinfo.len() {
        crate::bail_invalid_basis!(
            "internal factor-smooth penalty metadata mismatch for term '{}': penalties={}, infos={}",
            term_name,
            marginal_penalties.len(),
            marginal_penaltyinfo.len()
        );
    }

    let mut penalties = Vec::<Array2<f64>>::with_capacity(marginal_penalties.len());
    let mut penaltyinfo = Vec::<PenaltyInfo>::with_capacity(marginal_penalties.len());
    for (penalty_pos, s_inner) in marginal_penalties.iter().enumerate() {
        let mut s_big = Array2::<f64>::zeros((q, q));
        for level in 0..n_levels {
            let start = level * p;
            s_big
                .slice_mut(s![start..start + p, start..start + p])
                .assign(s_inner);
        }
        let (s_big, factor_smooth_scale) = normalize_penalty_in_constrained_space(&s_big);
        let mut info = marginal_penaltyinfo[penalty_pos].clone();
        info.original_index = penalty_pos;
        info.normalization_scale *= factor_smooth_scale;
        // The tiled penalty's null space is the per-level null space repeated
        // across all `n_levels` blocks.
        info.nullspace_dim_hint = info.nullspace_dim_hint.saturating_mul(n_levels);
        info.kronecker_factors = None;
        penalties.push(s_big);
        penaltyinfo.push(info);
    }

    let mut nullspaces: Vec<usize> = if matches!(spec.flavour, FactorSmoothFlavour::Re) {
        vec![0]
    } else {
        inner
            .nullspaces
            .iter()
            .map(|ns| ns.saturating_mul(n_levels))
            .collect()
    };

    // `Fs` is the random-effect flavour of a smooth: the per-group curve is
    // an exchangeable Gaussian *function*. EVERY coefficient — including the
    // {const, linear} null space of the marginal wiggliness penalty — must
    // therefore be shrinkable toward zero under its own shared variance.
    //
    // The wiggliness penalty `S_wiggle` shapes curvature but leaves the
    // per-group intercept and slope (its null space) completely UNPENALIZED.
    // With the null space free, each group fits its own intercept and slope
    // with NO partial pooling. The held-out per-subject forecast then
    // inherits the full no-pooling variance and curves away from the true
    // per-group line (gam#712 real arm, gam#713; the gam#903 sleepstudy
    // forecast ran ~74% over the lme4 BLUP bar).
    //
    // mgcv's `bs="fs"` fixes this by penalizing each null-space dimension
    // SEPARATELY: `smooth.construct.fs.smooth.spec` adds one rank-1 penalty
    // per null coordinate, each replicated block-diagonally across levels
    // under a single shared smoothing parameter. REML then fits a distinct
    // random-intercept variance and random-slope variance — the partial
    // pooling that makes the forecast track lme4's correlated random-effect
    // BLUP. A single *combined* null penalty (one λ for intercept and slope
    // together) cannot express their typically very different variances,
    // which is the residual forecast gap.
    //
    // We mirror mgcv exactly: for each orthonormal null direction `z_k` of
    // the marginal wiggliness penalty, add `I_L ⊗ (z_k z_kᵀ)` as its own
    // penalty. The marginal's combined double penalty was disabled above, so
    // the null space is penalized once, per dimension.
    //
    // This is data-adaptive, not a cap. With linear data, REML drives the
    // curvature λ up and degrades `fs` to a linear random slope
    // (edf → ≈2/group). With genuine curvature, the wiggliness λ stays small
    // and the wiggle survives.
    //
    // Gated by `m_null_penalty_orders`: order ≥ 1 (default) enables the
    // per-dimension null penalties; `m=0` keeps the legacy combined double
    // penalty and adds nothing here.
    //
    // If the first marginal penalty exposes no null eigenvectors (or they do
    // not have `p` rows), nothing is added here. Because the combined double
    // penalty was already disabled under `use_per_dim_null`, the marginal
    // null space then stays unpenalized.
    // NOTE(review): confirm that silent fallback is intended.
    if use_per_dim_null
        && let Some(Some(z)) = inner.null_eigenvectors.first()
        && z.nrows() == p
    {
        for k in 0..z.ncols() {
            // Rank-1 marginal penalty `z_k z_kᵀ`, replicated block-diagonally
            // across levels into `I_L ⊗ (z_k z_kᵀ)`. Its own λ is one shared
            // variance for this null component (intercept or slope) across all
            // groups — the random-effect structure of mgcv `fs`.
            let zk = z.column(k);
            let mut p_k = Array2::<f64>::zeros((p, p));
            for a in 0..p {
                for b in 0..p {
                    p_k[[a, b]] = zk[a] * zk[b];
                }
            }
            let mut s_null = Array2::<f64>::zeros((q, q));
            for level in 0..n_levels {
                let start = level * p;
                s_null
                    .slice_mut(s![start..start + p, start..start + p])
                    .assign(&p_k);
            }
            let (s_null, null_scale) = normalize_penalty_in_constrained_space(&s_null);
            let null_block = crate::basis::analyze_penalty_block_with_op(&s_null, None)?;
            // Penalties analyzed as rank 0 are skipped entirely.
            if null_block.rank > 0 {
                let original_index = penalties.len();
                penalties.push(null_block.sym_penalty);
                nullspaces.push(null_block.nullity);
                penaltyinfo.push(PenaltyInfo {
                    source: PenaltySource::Primary,
                    original_index,
                    active: true,
                    effective_rank: null_block.rank,
                    dropped_reason: None,
                    nullspace_dim_hint: null_block.nullity,
                    normalization_scale: null_scale,
                    kronecker_factors: None,
                });
            }
        }
    }
    let null_eigenvectors = crate::basis::recompute_null_eigenvectors(&penalties)?;
    let joint_null_rotation = crate::basis::compute_joint_null_rotation(&penalties)?;

    // Metadata: carry the marginal knot geometry + frozen levels so prediction
    // reconstructs an identical replicated design.
    let (knots, degree, periodic) = match &inner.metadata {
        BasisMetadata::BSpline1D {
            knots,
            periodic,
            degree,
            ..
        } => (
            knots.clone(),
            degree.unwrap_or(spec.marginal.degree),
            *periodic,
        ),
        other => {
            crate::bail_invalid_basis!(
                "factor smooth term '{}' produced an unexpected marginal metadata variant {:?}",
                term_name,
                other
            );
        }
    };
    let flavour_tag = match &spec.flavour {
        FactorSmoothFlavour::Fs { .. } => "fs",
        FactorSmoothFlavour::Sz => "sz",
        FactorSmoothFlavour::Re => "re",
    }
    .to_string();
    let metadata = BasisMetadata::FactorSmooth {
        continuous_cols: spec.continuous_cols.clone(),
        group_col,
        knots,
        degree,
        periodic,
        group_levels: levels,
        flavour: flavour_tag,
        // fs/re marginals are always B-spline; the cr marginal is sz-only and
        // handled on the dedicated Sz path above.
        marginal_is_cr: false,
    };

    let ops = vec![None; penalties.len()];
    Ok(LocalSmoothTermBuild {
        dim: q,
        design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense)),
        penalties,
        ops,
        nullspaces,
        null_eigenvectors,
        joint_null_rotation,
        penaltyinfo,
        pre_dropped_penaltyinfo: Vec::new(),
        metadata,
        linear_constraints: None,
        box_reparam: false,
        kronecker_factored: None,
    })
}
6826
6827/// Resolve the grouping levels for a factor smooth: replay the frozen level
6828/// list when present (predict path), otherwise discover the sorted unique bit
6829/// patterns of the factor column (fit path).
6830pub fn resolve_factor_smooth_levels(
6831    data: ArrayView2<'_, f64>,
6832    group_col: usize,
6833    spec: &FactorSmoothSpec,
6834    term_name: &str,
6835) -> Result<Vec<u64>, BasisError> {
6836    if let Some(frozen) = &spec.group_frozen_levels {
6837        if frozen.is_empty() {
6838            crate::bail_invalid_basis!(
6839                "factor smooth term '{}' has an empty frozen level list",
6840                term_name
6841            );
6842        }
6843        return Ok(frozen.clone());
6844    }
6845    let mut bits: Vec<u64> = data.column(group_col).iter().map(|v| v.to_bits()).collect();
6846    bits.sort_by(|a, b| {
6847        f64::from_bits(*a)
6848            .partial_cmp(&f64::from_bits(*b))
6849            .unwrap_or(std::cmp::Ordering::Equal)
6850    });
6851    bits.dedup();
6852    Ok(bits)
6853}
6854
6855/// Marginal B-spline spec for a factor-smooth block. The marginal always builds
6856/// without an identifiability constraint (the per-level replication, not a
6857/// sum-to-zero side constraint, provides identifiability against the parametric
6858/// block). At predict time the marginal's knot geometry has already been pinned
6859/// into `marginal.knotspec` by the metadata replay, so the spec is used
6860/// verbatim aside from clearing the identifiability transform.
6861pub fn factor_smooth_marginal_for_replay(marginal: &BSplineBasisSpec) -> BSplineBasisSpec {
6862    let mut m = marginal.clone();
6863    m.identifiability = BSplineIdentifiability::None;
6864    m
6865}
6866
6867pub fn build_single_local_smooth_term(
6868    data: ArrayView2<'_, f64>,
6869    term: &SmoothTermSpec,
6870    workspace: &mut crate::basis::BasisWorkspace,
6871) -> Result<LocalSmoothTermBuild, BasisError> {
6872    if term.shape != ShapeConstraint::None && !shape_supports_basis(term) {
6873        crate::bail_invalid_basis!(
6874            "ShapeConstraint::{:?} is unsupported for term '{}'",
6875            term.shape,
6876            term.name
6877        );
6878    }
6879    if let SmoothBasisSpec::ByVariable {
6880        inner,
6881        by_col,
6882        kind,
6883        by,
6884    } = &term.basis
6885    {
6886        ensure_by_variable_specs_match(kind, by, &term.name)?;
6887        let mut inner_basis = (**inner).clone();
6888        // Factor-level `by=` owns model-space centering (it centers the gated
6889        // block against the level indicator downstream). Defer the inner
6890        // basis's default pooled centering so the level block is not
6891        // double-centered down to `k-2` columns (#1427). Numeric-by smooths are
6892        // untouched: they are not row-gated to a level and keep ordinary
6893        // intercept centering.
6894        if matches!(by, ByVariableSpec::Level { .. }) {
6895            defer_inner_model_centering_to_factor_level_wrapper(&mut inner_basis);
6896        }
6897        let inner_term = SmoothTermSpec {
6898            name: term.name.clone(),
6899            basis: inner_basis,
6900            shape: term.shape,
6901            joint_null_rotation: None,
6902        };
6903        let built = build_single_local_smooth_term(data, &inner_term, workspace)?;
6904        return apply_by_variable_to_local_build(built, data, *by_col, by, &term.name);
6905    }
6906
6907    // BySmooth: a `by=` smooth that unifies numeric or factor modulation into a
6908    // single term.  Lower it here so the downstream match does not need an arm.
6909    if let SmoothBasisSpec::BySmooth { smooth, by_kind } = &term.basis {
6910        return build_by_smooth_local(data, term, smooth, by_kind, workspace);
6911    }
6912
6913    let mut shape_axis_col: Option<usize> = None;
6914    let mut built: BasisBuildResult = match &term.basis {
6915        SmoothBasisSpec::FactorSumToZero {
6916            inner,
6917            by_col,
6918            levels,
6919            ..
6920        } => {
6921            if *by_col >= data.ncols() {
6922                crate::bail_dim_basis!(
6923                    "term '{}' by column {} out of bounds for {} columns",
6924                    term.name,
6925                    by_col,
6926                    data.ncols()
6927                );
6928            }
6929            if levels.len() < 2 {
6930                crate::bail_invalid_basis!(
6931                    "sum-to-zero factor smooth term '{}' requires at least two levels",
6932                    term.name
6933                );
6934            }
6935            if term.shape != ShapeConstraint::None {
6936                crate::bail_invalid_basis!(
6937                    "ShapeConstraint::{:?} is unsupported for sum-to-zero factor smooth term '{}'",
6938                    term.shape,
6939                    term.name
6940                );
6941            }
6942            let inner_term = SmoothTermSpec {
6943                name: format!("{}::inner", term.name),
6944                basis: (**inner).clone(),
6945                shape: ShapeConstraint::None,
6946                joint_null_rotation: None,
6947            };
6948            let mut inner_built = build_single_local_smooth_term(data, &inner_term, workspace)?;
6949            let base = inner_built
6950                .design
6951                .try_to_dense_by_chunks("sum-to-zero factor smooth")
6952                .map_err(BasisError::InvalidInput)?;
6953            let n = base.nrows();
6954            let p = base.ncols();
6955            let l_minus_one = levels.len() - 1;
6956            let mut dense = Array2::<f64>::zeros((n, p * l_minus_one));
6957            for i in 0..n {
6958                let bits = data[[i, *by_col]].to_bits();
6959                let level_idx = levels.iter().position(|b| *b == bits).ok_or_else(|| {
6960                    BasisError::InvalidInput(format!(
6961                        "sum-to-zero factor smooth term '{}' saw an unseen level at row {}",
6962                        term.name,
6963                        i + 1
6964                    ))
6965                })?;
6966                if level_idx < l_minus_one {
6967                    let start = level_idx * p;
6968                    dense
6969                        .slice_mut(s![i, start..start + p])
6970                        .assign(&base.row(i));
6971                } else {
6972                    for level in 0..l_minus_one {
6973                        let start = level * p;
6974                        dense
6975                            .slice_mut(s![i, start..start + p])
6976                            .assign(&base.row(i).mapv(|v| -v));
6977                    }
6978                }
6979            }
6980            let mut penalties = Vec::<Array2<f64>>::with_capacity(inner_built.penalties.len());
6981            let active_penalty_indices = inner_built
6982                .penaltyinfo
6983                .iter()
6984                .enumerate()
6985                .filter_map(|(idx, info)| info.active.then_some(idx))
6986                .collect::<Vec<_>>();
6987            if active_penalty_indices.len() != inner_built.penalties.len() {
6988                crate::bail_invalid_basis!(
6989                    "internal sz penalty metadata mismatch: activeinfos={}, penalties={}",
6990                    active_penalty_indices.len(),
6991                    inner_built.penalties.len()
6992                );
6993            }
6994            for (penalty_pos, s_inner) in inner_built.penalties.iter().enumerate() {
6995                let mut s_big = Array2::<f64>::zeros((p * l_minus_one, p * l_minus_one));
6996                for a in 0..l_minus_one {
6997                    for b in 0..l_minus_one {
6998                        let factor = if a == b { 2.0 } else { 1.0 };
6999                        let mut block = s_big.slice_mut(s![a * p..(a + 1) * p, b * p..(b + 1) * p]);
7000                        block.assign(&s_inner.mapv(|v| v * factor));
7001                    }
7002                }
7003                let (s_big, factor_smooth_scale) = normalize_penalty_in_constrained_space(&s_big);
7004                let info_idx = active_penalty_indices[penalty_pos];
7005                inner_built.penaltyinfo[info_idx].normalization_scale *= factor_smooth_scale;
7006                penalties.push(s_big);
7007            }
7008            inner_built.dim = p * l_minus_one;
7009            inner_built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense));
7010            inner_built.penalties = penalties;
7011            inner_built.ops = vec![None; inner_built.penalties.len()];
7012            inner_built.nullspaces = inner_built
7013                .nullspaces
7014                .iter()
7015                .map(|ns| ns.saturating_mul(l_minus_one))
7016                .collect();
7017            // Invariant: `null_eigenvectors[k]` must mirror `penalties[k]`'s
7018            // spectral null space. We just rebuilt `inner_built.penalties` from
7019            // Kronecker-like `S_big` blocks, so the previously-plumbed
7020            // `null_eigenvectors` (still parallel to the OLD per-level penalty)
7021            // is stale. Recompute from the rebuilt penalties to restore the
7022            // invariant; ditto for the joint-null absorption rotation.
7023            inner_built.null_eigenvectors =
7024                crate::basis::recompute_null_eigenvectors(&inner_built.penalties)?;
7025            inner_built.joint_null_rotation =
7026                crate::basis::compute_joint_null_rotation(&inner_built.penalties)?;
7027            inner_built.kronecker_factored = None;
7028            return Ok(inner_built);
7029        }
7030        SmoothBasisSpec::BSpline1D { feature_col, spec } => {
7031            if *feature_col >= data.ncols() {
7032                crate::bail_dim_basis!(
7033                    "term '{}' feature column {} out of bounds for {} columns",
7034                    term.name,
7035                    feature_col,
7036                    data.ncols()
7037                );
7038            }
7039            let mut spec_local = spec.clone();
7040            if term.shape != ShapeConstraint::None {
7041                // Shape-constrained B-splines are anchored by construction.
7042                // Sum-to-zero side constraints conflict with monotonic/convex cones.
7043                spec_local.identifiability = BSplineIdentifiability::None;
7044            }
7045            // Endpoint boundary conditions are structural for B-splines: the
7046            // basis builder bakes their homogeneous nullspace transform into
7047            // the design, penalties, and stored raw-basis transform.
7048            build_bspline_basis_1d(data.column(*feature_col), &spec_local)?
7049        }
7050        SmoothBasisSpec::ThinPlate {
7051            feature_cols,
7052            spec,
7053            input_scales,
7054        } => {
7055            if term.shape != ShapeConstraint::None {
7056                if feature_cols.len() != 1 {
7057                    crate::bail_invalid_basis!(
7058                        "ShapeConstraint::{:?} for term '{}' on ThinPlate basis requires exactly 1 feature axis; found {}",
7059                        term.shape,
7060                        term.name,
7061                        feature_cols.len()
7062                    );
7063                }
7064                shape_axis_col = Some(feature_cols[0]);
7065            }
7066            let mut x = select_columns(data, feature_cols)?;
7067            // Auto-standardize multivariate inputs: use stored scales (prediction)
7068            // or compute fresh ones (training). Same standardization-vs-
7069            // length-scale compensation as Matérn / hybrid Duchon: divide
7070            // the user's L by σ_geom so kernel(‖x_std − c_std‖/L_eff)
7071            // matches the original-coord kernel for uniform σ.
7072            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7073                apply_input_standardization(&mut x, s);
7074                (
7075                    Some(s.clone()),
7076                    compensate_length_scale_for_standardization(spec.length_scale, s),
7077                )
7078            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7079                apply_input_standardization(&mut x, &s);
7080                let l_eff = compensate_length_scale_for_standardization(spec.length_scale, &s);
7081                (Some(s), l_eff)
7082            } else {
7083                (None, spec.length_scale)
7084            };
7085            let mut spec_local = spec.clone();
7086            spec_local.length_scale = length_scale_eff;
7087            if matches!(
7088                spec_local.identifiability,
7089                SpatialIdentifiability::OrthogonalToParametric
7090            ) {
7091                spec_local.identifiability = SpatialIdentifiability::None;
7092            }
7093            let mut result = build_thin_plate_basis(x.view(), &spec_local).map_err(|err| {
7094                rewrite_thin_plate_knots_error(err, &term.name, feature_cols.len(), spec)
7095            })?;
7096            // Inject input scales into metadata; also restore the user's
7097            // original length_scale (not the σ_geom-compensated one) so a
7098            // metadata-driven rebuild that re-applies compensation does not
7099            // double-divide. The build may auto-promote to Duchon when
7100            // canonical TPS is infeasible (k < polynomial-nullspace size);
7101            // in that case patch the Duchon metadata variant so predict-time
7102            // round-trips through the same standardized data path.
7103            match &mut result.metadata {
7104                BasisMetadata::ThinPlate {
7105                    input_scales: ms,
7106                    length_scale,
7107                    ..
7108                } => {
7109                    *ms = scales;
7110                    *length_scale = spec.length_scale;
7111                }
7112                BasisMetadata::Duchon {
7113                    input_scales: ms,
7114                    length_scale,
7115                    ..
7116                } => {
7117                    *ms = scales;
7118                    // The ThinPlate auto-promotion path delegates to
7119                    // `build_duchon_basis` with `Some(spec_local.length_scale)`,
7120                    // which is the σ_geom-compensated value. The metadata
7121                    // therefore records the compensated kernel range, but the
7122                    // freeze→replay round trip plugs that value back into a
7123                    // user-facing `DuchonBasisSpec.length_scale` whose builder
7124                    // applies the σ_geom compensation a second time. Restore
7125                    // the user-facing scale here so replay re-compensates
7126                    // exactly once and reproduces the realized fit-time basis.
7127                    *length_scale = Some(spec.length_scale);
7128                }
7129                _ => {}
7130            }
7131            result
7132        }
7133        SmoothBasisSpec::Sphere { feature_cols, spec } => {
7134            if term.shape != ShapeConstraint::None {
7135                crate::bail_invalid_basis!(
7136                    "ShapeConstraint::{:?} for term '{}' is not supported on spherical splines",
7137                    term.shape,
7138                    term.name
7139                );
7140            }
7141            let x = select_columns(data, feature_cols)?;
7142            build_spherical_spline_basis(x.view(), spec)?
7143        }
7144        SmoothBasisSpec::ConstantCurvature { feature_cols, spec } => {
7145            if term.shape != ShapeConstraint::None {
7146                crate::bail_invalid_basis!(
7147                    "ShapeConstraint::{:?} for term '{}' is not supported on constant-curvature smooths",
7148                    term.shape,
7149                    term.name
7150                );
7151            }
7152            // Chart coordinates are consumed verbatim: NO auto-standardization.
7153            // Rescaling axes would change the chart gauge `1 + κ‖x‖²` and
7154            // silently redefine which curvature κ refers to (the same point
7155            // cloud at a different chart scale has a different κ̂); the user's
7156            // coordinates ARE the geometry here, exactly as for the sphere
7157            // smooth's (lat, lon).
7158            let x = select_columns(data, feature_cols)?;
7159            build_constant_curvature_basis(x.view(), spec)?
7160        }
7161        SmoothBasisSpec::MeasureJet {
7162            feature_cols,
7163            spec,
7164            input_scales,
7165        } => {
7166            if term.shape != ShapeConstraint::None {
7167                crate::bail_invalid_basis!(
7168                    "ShapeConstraint::{:?} for term '{}' is not supported on measure-jet smooths",
7169                    term.shape,
7170                    term.name
7171                );
7172            }
7173            let mut x = select_columns(data, feature_cols)?;
7174            // Matern-style per-axis standardization; the realized σ vector is
7175            // persisted into the metadata for predict-time replay.
7176            //
7177            // Length-scale round-trip contract (owning statement; the freeze
7178            // and frozen-validation arms reference it): `input_scales: Some`
7179            // marks the REPLAY path — the frozen length_scale is already the
7180            // realized post-standardization value and passes through
7181            // verbatim. Fresh path: an explicit user length_scale is in
7182            // ORIGINAL coordinates and gets the σ_geom compensation; the 0.0
7183            // auto sentinel passes through (auto-derivation runs inside the
7184            // builder, post-standardization).
7185            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7186                apply_input_standardization(&mut x, s);
7187                (Some(s.clone()), spec.length_scale)
7188            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7189                apply_input_standardization(&mut x, &s);
7190                let l_eff = if spec.length_scale > 0.0 {
7191                    compensate_length_scale_for_standardization(spec.length_scale, &s)
7192                } else {
7193                    spec.length_scale
7194                };
7195                (Some(s), l_eff)
7196            } else {
7197                (None, spec.length_scale)
7198            };
7199            let mut spec_local = spec.clone();
7200            spec_local.length_scale = length_scale_eff;
7201            let mut result = build_measure_jet_basis(x.view(), &spec_local)?;
7202            if let BasisMetadata::MeasureJet {
7203                input_scales: ms, ..
7204            } = &mut result.metadata
7205            {
7206                *ms = scales;
7207            }
7208            result
7209        }
7210        SmoothBasisSpec::Matern {
7211            feature_cols,
7212            spec,
7213            input_scales,
7214        } => {
7215            if term.shape != ShapeConstraint::None {
7216                if feature_cols.len() != 1 {
7217                    crate::bail_invalid_basis!(
7218                        "ShapeConstraint::{:?} for term '{}' on Matern basis requires exactly 1 feature axis; found {}",
7219                        term.shape,
7220                        term.name,
7221                        feature_cols.len()
7222                    );
7223                }
7224                shape_axis_col = Some(feature_cols[0]);
7225            }
7226            let mut x = select_columns(data, feature_cols)?;
7227            // Auto-standardization (per-axis division by σ_a) reinterprets
7228            // the user's `length_scale` from original data coordinates
7229            // into post-standardization coordinates: for uniform σ_a = σ,
7230            // `kernel(‖x_std − c_std‖/L)` equals `kernel(‖x − c‖/(σ·L))`,
7231            // so the effective kernel range shrinks by σ. To keep
7232            // `length_scale` consistently expressed in *original* data
7233            // coordinates regardless of axis variances, we standardize
7234            // and divide L by σ_geom = (∏σ_a)^(1/d). For uniform σ this
7235            // recovers the user's kernel exactly; for anisotropic data
7236            // the resulting per-axis effective scales σ_a / σ_geom are
7237            // the standard Mahalanobis preconditioning and preserve the
7238            // geometric-mean kernel range. Storing the σ vector in
7239            // metadata.input_scales makes the same transformation
7240            // replayable at predict time.
7241            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7242                apply_input_standardization(&mut x, s);
7243                (
7244                    Some(s.clone()),
7245                    compensate_length_scale_for_standardization(spec.length_scale, s),
7246                )
7247            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7248                apply_input_standardization(&mut x, &s);
7249                let l_eff = compensate_length_scale_for_standardization(spec.length_scale, &s);
7250                (Some(s), l_eff)
7251            } else {
7252                (None, spec.length_scale)
7253            };
7254            let mut spec_local = spec.clone();
7255            spec_local.length_scale = length_scale_eff;
7256            let mut result = build_matern_basiswithworkspace(x.view(), &spec_local, workspace)?;
7257            if let BasisMetadata::Matern {
7258                input_scales,
7259                length_scale,
7260                ..
7261            } = &mut result.metadata
7262            {
7263                *input_scales = scales;
7264                *length_scale = spec.length_scale;
7265            }
7266            result
7267        }
7268        SmoothBasisSpec::Duchon {
7269            feature_cols,
7270            spec,
7271            input_scales,
7272        } => {
7273            if term.shape != ShapeConstraint::None {
7274                if feature_cols.len() != 1 {
7275                    crate::bail_invalid_basis!(
7276                        "ShapeConstraint::{:?} for term '{}' on Duchon basis requires exactly 1 feature axis; found {}",
7277                        term.shape,
7278                        term.name,
7279                        feature_cols.len()
7280                    );
7281                }
7282                shape_axis_col = Some(feature_cols[0]);
7283            }
7284            let mut x = select_columns(data, feature_cols)?;
7285            // Hybrid Duchon (length_scale=Some) is governed by the same
7286            // standardization-vs-length-scale equivalence as Matérn: the
7287            // user's `length_scale` is interpreted in original data
7288            // coordinates, but auto-standardization (per-axis division by
7289            // σ_a) reinterprets it as σ_geom · L. Pre-multiply by 1/σ_geom
7290            // so kernel(‖x_std − c_std‖/L_eff) reproduces the user's
7291            // original-coord kernel exactly for uniform σ_a, and reduces
7292            // to standard Mahalanobis preconditioning for anisotropic σ.
7293            // Pure Duchon (length_scale=None) is scale-free and needs no
7294            // compensation.
7295            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7296                apply_input_standardization(&mut x, s);
7297                (
7298                    Some(s.clone()),
7299                    compensate_optional_length_scale_for_standardization(spec.length_scale, s),
7300                )
7301            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7302                apply_input_standardization(&mut x, &s);
7303                let l_eff =
7304                    compensate_optional_length_scale_for_standardization(spec.length_scale, &s);
7305                (Some(s), l_eff)
7306            } else {
7307                (None, spec.length_scale)
7308            };
7309            let mut spec_local = spec.clone();
7310            spec_local.length_scale = length_scale_eff;
7311            if matches!(
7312                spec_local.identifiability,
7313                SpatialIdentifiability::OrthogonalToParametric
7314            ) {
7315                spec_local.identifiability = SpatialIdentifiability::None;
7316            }
7317            let mut result = build_duchon_basiswithworkspace(x.view(), &spec_local, workspace)?;
7318            if let BasisMetadata::Duchon {
7319                input_scales,
7320                length_scale,
7321                ..
7322            } = &mut result.metadata
7323            {
7324                *input_scales = scales;
7325                *length_scale = spec.length_scale;
7326            }
7327            result
7328        }
7329        SmoothBasisSpec::Pca {
7330            feature_cols,
7331            basis_matrix,
7332            centered,
7333            smooth_penalty,
7334            center_mean,
7335            pca_basis_path,
7336            chunk_size,
7337        } => {
7338            if term.shape != ShapeConstraint::None {
7339                crate::bail_invalid_basis!(
7340                    "ShapeConstraint::{:?} for term '{}' is not supported on Pca basis",
7341                    term.shape,
7342                    term.name
7343                );
7344            }
7345            build_pca_smooth_basis(
7346                data,
7347                feature_cols,
7348                basis_matrix,
7349                *centered,
7350                *smooth_penalty,
7351                center_mean.as_ref(),
7352                pca_basis_path.as_ref(),
7353                *chunk_size,
7354            )?
7355        }
7356        SmoothBasisSpec::TensorBSpline { feature_cols, spec } => {
7357            build_tensor_bspline_basis(data, feature_cols, spec)?
7358        }
7359        SmoothBasisSpec::ByVariable { .. } => {
7360            crate::bail_invalid_basis!(
7361                "internal: ByVariable smooths must return before inner basis dispatch"
7362            );
7363        }
7364        SmoothBasisSpec::BySmooth { .. } => {
7365            crate::bail_invalid_basis!("internal: BySmooth smooths must be lowered to ByVariable before inner basis dispatch"
7366                    .to_string(),);
7367        }
7368        SmoothBasisSpec::FactorSmooth { spec } => {
7369            if term.shape != ShapeConstraint::None {
7370                crate::bail_invalid_basis!(
7371                    "ShapeConstraint::{:?} is unsupported for factor smooth term '{}'",
7372                    term.shape,
7373                    term.name
7374                );
7375            }
7376            return build_factor_smooth(data, spec, &term.name, workspace);
7377        }
7378    };
7379
7380    // The Matérn design ALWAYS uses the operator-collocation {mass, tension,
7381    // stiffness} penalty triplet, overriding whatever penalty
7382    // `build_matern_basis_seeded` produced for the `double_penalty` flag.
7383    //
7384    // #1074 investigated swapping this for the genuine RKHS kernel penalty
7385    // `β' K_CC β` (mgcv `bs="gp"` / fields kriging) on the theory that the
7386    // operator triplet under-smooths the rougher half-integer kernels. MSI
7387    // truth-recovery measurement REFUTED that: the kernel penalty did NOT
7388    // improve ν=3/2 recovery (`matern(x,nu=1.5)` RMSE-vs-truth stayed 0.0554)
7389    // and it REGRESSED the high-frequency-init guard — `matern(x,nu≥5/2)` on
7390    // sin(2π·8·x) collapsed (span 0.53, RMSE 0.70) because the single RKHS
7391    // norm over-smooths a high-frequency truth where the Sobolev-order operator
7392    // dials do not. The operator triplet is therefore retained as the Matérn
7393    // penalty, and the κ-optimizer re-key / ψ-derivative paths route through the
7394    // same triplet builder so the block count stays ψ-stable (#1270).
7395    if let SmoothBasisSpec::Matern { .. } = &term.basis {
7396        let (penalties, nullspace_dims, penaltyinfo) =
7397            matern_operator_penalty_triplet_from_metadata(&built.metadata)?;
7398        built.penalties = penalties;
7399        built.nullspace_dims = nullspace_dims;
7400        built.penaltyinfo = penaltyinfo;
7401    }
7402
7403    let p_local = built.design.ncols();
7404    let mut metadata = built.metadata.clone();
7405    // Extract factored Kronecker representation before consuming fields.
7406    // Invalidate it if shape transforms will be applied (they break structure).
7407    let kron_factored = if term.shape == ShapeConstraint::None {
7408        built.kronecker_factored
7409    } else {
7410        None
7411    };
7412    let mut design_t = built.design;
7413    let mut penalties_t: Vec<Array2<f64>> = built.penalties;
7414    // Ops vector parallel to `penalties_t`. Survives unchanged through the
7415    // identity path; nulled element-wise when `T^T S T` reparametrization
7416    // is applied (operator no longer bit-equivalent to the transformed
7417    // matrix); wrapped in `ScaledPenaltyOp` after Frobenius normalization.
7418    let mut ops_t: Vec<Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>> =
7419        built.ops;
7420    if matches!(
7421        spatial_identifiability_policy(term),
7422        Some(SpatialIdentifiability::OrthogonalToParametric)
7423    ) {
7424        metadata = freeze_raw_spatial_metadata(metadata, design_t.ncols());
7425    }
7426
7427    let active_penaltyinfo_t = built
7428        .penaltyinfo
7429        .iter()
7430        .filter(|info| info.active)
7431        .cloned()
7432        .collect::<Vec<_>>();
7433    let pre_dropped_penaltyinfo_t = built
7434        .penaltyinfo
7435        .iter()
7436        .filter(|info| !info.active)
7437        .cloned()
7438        .collect::<Vec<_>>();
7439    let use_box_reparam =
7440        term.shape != ShapeConstraint::None && shape_uses_box_reparameterization(&term.basis);
7441    if let Some((order, sign)) = shape_order_and_sign(term.shape)
7442        && use_box_reparam
7443    {
7444        // Order 1 (monotone): the plain first-difference cone θ_{i+1}−θ_i ≥ 0 is
7445        // the control-polygon monotonicity criterion, which is independent of
7446        // Greville-abscissa spacing (it only fixes the *sign* of consecutive
7447        // control-point gaps), so the integer-difference transform is exact.
7448        //
7449        // Order 2 (convex/concave): the plain second-difference cone is only
7450        // correct for evenly spaced Greville abscissae. gam's B-splines are
7451        // clamped (and may use quantile knots), so the abscissae are not
7452        // uniform and the geometrically-correct cone is the second *divided*
7453        // difference. Build the Greville-scaled transform so γ_{≥2} ≥ 0
7454        // certifies convexity of the function, not of the raw coefficient
7455        // index. Periodic B-splines use uniform interior knots (uniform
7456        // abscissae), where the divided differences coincide with the integer
7457        // differences up to scale, so the plain path stays exact there.
7458        let t = if order == 2 {
7459            let bspline_meta = match &metadata {
7460                BasisMetadata::BSpline1D {
7461                    knots,
7462                    degree,
7463                    periodic,
7464                    ..
7465                } if periodic.is_none() => Some((knots.clone(), degree.unwrap_or(0))),
7466                _ => None,
7467            };
7468            match bspline_meta {
7469                Some((knots, degree)) if degree >= 1 => {
7470                    let greville = crate::basis::compute_greville_abscissae(&knots, degree)?;
7471                    if greville.len() != p_local {
7472                        crate::bail_invalid_basis!(
7473                            "shape-constraint Greville abscissae count {} does not match basis dim {} for term '{}'",
7474                            greville.len(),
7475                            p_local,
7476                            term.name
7477                        );
7478                    }
7479                    convex_divided_difference_transform_matrix(&greville, sign)?
7480                }
7481                _ => cumulative_sum_transform_matrix(p_local, order, sign),
7482            }
7483        } else {
7484            cumulative_sum_transform_matrix(p_local, order, sign)
7485        };
7486        // Coefficient-side transform: wrap the design in an operator that
7487        // applies T on the coefficient side, preserving sparsity/operator
7488        // structure of the inner design.
7489        let inner_dense = match design_t {
7490            DesignMatrix::Dense(d) => d,
7491            DesignMatrix::Sparse(sp) => gam_linalg::matrix::DenseDesignMatrix::from(
7492                sp.try_to_dense_arc("shape-constrained coefficient transform")
7493                    .map_err(BasisError::InvalidInput)?,
7494            ),
7495        };
7496        let coeff_op = gam_linalg::matrix::CoefficientTransformOperator::new(inner_dense, t.clone())
7497            .map_err(|e| BasisError::InvalidInput(format!("CoefficientTransformOperator: {e}")))?;
7498        design_t = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(coeff_op)));
7499        if penalties_t.len() != active_penaltyinfo_t.len() {
7500            crate::bail_invalid_basis!(
7501                "internal box-reparam penalty/info mismatch for term '{}': penalties={}, infos={}",
7502                term.name,
7503                penalties_t.len(),
7504                active_penaltyinfo_t.len()
7505            );
7506        }
7507        // Wiggliness penalties undergo the exact congruence `S → TᵀST` (PSD
7508        // preserving). The double-penalty *nullspace shrinkage* ridge must NOT:
7509        // it is a unit-eigenvalue projector `ZZᵀ` onto null(S_wiggle) in the
7510        // β (B-spline coefficient) coordinates, and the congruence
7511        // `Tᵀ(ZZᵀ)T = (TᵀZ)(TᵀZ)ᵀ` is no longer a projector — its eigenvalues
7512        // blow up by the conditioning of the cumulative-sum `T` (cond(T) grows
7513        // with the basis dim), concentrating an enormous penalty on the leading
7514        // γ₀ "level" coordinate. REML then drives the shared λ to its ceiling
7515        // and the smooth collapses to a flat constant (#509, the over-smoothing
7516        // face). The principled fix keeps mgcv's double-penalty semantics in the
7517        // *reparametrized* space: rebuild the ridge as the unit-eigenvalue
7518        // nullspace projector of the transformed wiggliness penalty `TᵀST`, so
7519        // the double penalty shrinks exactly the unpenalized polynomial
7520        // directions of the γ-space smooth with eigenvalue 1, identical in
7521        // conditioning to the unconstrained fit.
7522        let transformed_wiggliness = penalties_t
7523            .iter()
7524            .zip(active_penaltyinfo_t.iter())
7525            .find(|(_, info)| !matches!(info.source, PenaltySource::DoublePenaltyNullspace))
7526            .map(|(s_local, _)| {
7527                let tt_s = fast_atb(&t, s_local);
7528                fast_ab(&tt_s, &t)
7529            });
7530        let mut rebuilt = Vec::with_capacity(penalties_t.len());
7531        for (s_local, info) in penalties_t.iter().zip(active_penaltyinfo_t.iter()) {
7532            if matches!(info.source, PenaltySource::DoublePenaltyNullspace) {
7533                let s_wiggle_t = transformed_wiggliness.as_ref().ok_or_else(|| {
7534                    BasisError::InvalidInput(format!(
7535                        "box-reparam term '{}' has a double-penalty ridge but no primary wiggliness penalty to derive its nullspace from",
7536                        term.name
7537                    ))
7538                })?;
7539                let ridge = crate::basis::build_nullspace_shrinkage_penalty(s_wiggle_t)?
7540                    .map(|shrink| shrink.sym_penalty)
7541                    .unwrap_or_else(|| Array2::<f64>::zeros((p_local, p_local)));
7542                rebuilt.push(ridge);
7543            } else {
7544                let tt_s = fast_atb(&t, s_local);
7545                rebuilt.push(fast_ab(&tt_s, &t));
7546            }
7547        }
7548        penalties_t = rebuilt;
7549        // T^T S T (and the rebuilt γ-space ridge) invalidate op-form
7550        // bit-equivalence; drop ops here.
7551        ops_t = vec![None; penalties_t.len()];
7552    }
7553    if penalties_t.len() != active_penaltyinfo_t.len() {
7554        crate::bail_invalid_basis!(
7555            "internal penalty metadata mismatch for term '{}': active penalties={}, active infos={}",
7556            term.name,
7557            penalties_t.len(),
7558            active_penaltyinfo_t.len()
7559        );
7560    }
7561    if ops_t.len() != penalties_t.len() {
7562        ops_t = vec![None; penalties_t.len()];
7563    }
7564    let penalty_candidates = penalties_t
7565        .into_iter()
7566        .zip(active_penaltyinfo_t.into_iter())
7567        .zip(ops_t.into_iter())
7568        .map(
7569            |((matrix, info), op_in)| -> Result<PenaltyCandidate, BasisError> {
7570                let (matrix, c_new) = normalize_penalty_in_constrained_space(&matrix);
7571                let normalization_scale = info.normalization_scale * c_new;
7572                let op_scale = 1.0 / c_new;
7573                let kronecker_scale = 1.0 / c_new;
7574                // Frobenius rescale: wrap inner op in `ScaledPenaltyOp(1/c_new)`
7575                // so `op.as_dense() == matrix` post-normalization.
7576                let scaled_op = if op_scale > 0.0 && op_scale.is_finite() {
7577                    op_in.map(|op| {
7578                        std::sync::Arc::new(crate::analytic_penalties::ScaledPenaltyOp::new(
7579                            op, op_scale,
7580                        ))
7581                            as std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>
7582                    })
7583                } else {
7584                    None
7585                };
7586                let kronecker_factors = info.kronecker_factors.map(|mut factors| {
7587                    if let Some(first) = factors.first_mut() {
7588                        first.mapv_inplace(|v| v * kronecker_scale);
7589                    }
7590                    factors
7591                });
7592                Ok(PenaltyCandidate {
7593                    nullspace_dim_hint: info.nullspace_dim_hint,
7594                    matrix,
7595                    source: info.source,
7596                    normalization_scale,
7597                    kronecker_factors,
7598                    op: scaled_op,
7599                })
7600            },
7601        )
7602        .collect::<Result<Vec<_>, _>>()?;
7603    let (penalties_t, nullspaces_t, penaltyinfo_t, null_eigenvectors_t, ops_t) =
7604        crate::basis::filter_active_penalty_candidates_with_ops(penalty_candidates)?;
7605    let shape_linear_constraints = if term.shape != ShapeConstraint::None && !use_box_reparam {
7606        let axis = shape_axis_col.ok_or_else(|| {
7607            BasisError::InvalidInput(format!(
7608                "internal shape-constraint axis missing for term '{}'",
7609                term.name
7610            ))
7611        })?;
7612        let (x_shape_eval, design_shape_eval) =
7613            build_shape_constraint_design_1d(data, term, &metadata, axis)?;
7614        build_shape_linear_constraints_1d(
7615            x_shape_eval.view(),
7616            design_shape_eval.view(),
7617            term.shape,
7618        )?
7619    } else {
7620        None
7621    };
7622    let linear_constraints_local = merge_linear_constraints_global(shape_linear_constraints, None);
7623
7624    // Joint-null absorption rotation. Fresh fit specs compute Q from the final
7625    // per-smooth penalty set (after all in-smooth reparameterizations have
7626    // already been applied). Frozen specs already carry the complete realized
7627    // coefficient chart in their `FrozenTransform`; recomputing Q there would
7628    // rotate an already-frozen chart a second time and desynchronize value
7629    // rebuilds from derivative operators.
7630    //
7631    // Kronecker-factored smooths (tensor B-splines under `TensorBSplineIdentifiability::None`)
7632    // carry their joint penalty as `Σ_d S_d` with `S_d = I ⊗ … ⊗ S_d^{1D} ⊗ … ⊗ I`.
7633    // The joint null space is the tensor of marginal nulls and is handled directly
7634    // by the REML runtime's `kronecker_penalty_system` path (see
7635    // `runtime.rs:8334-8344`). Applying a dense (p × p) Q here would densify
7636    // `X_raw = mx ⊗ my` into `X_raw · Q`, destroying the Kronecker product
7637    // structure that the runtime relies on for fast log-det/derivative
7638    // assembly — and the rotation block at the wrapper site also unconditionally
7639    // wipes `kronecker_factored`, leaving the runtime to fall back to the
7640    // dense per-block log-det. Skip the rotation for Kronecker-factored terms
7641    // so the factored representation survives end-to-end.
7642    let joint_null_rotation = match term.joint_null_rotation.clone() {
7643        Some(persisted) => Some(persisted),
7644        None if smooth_has_frozen_identifiability(term) => None,
7645        None if kron_factored.is_some() => None,
7646        None => crate::basis::compute_joint_null_rotation(&penalties_t)?,
7647    };
7648
7649    Ok(LocalSmoothTermBuild {
7650        dim: p_local,
7651        design: design_t,
7652        penalties: penalties_t,
7653        ops: ops_t,
7654        nullspaces: nullspaces_t,
7655        null_eigenvectors: null_eigenvectors_t,
7656        joint_null_rotation,
7657        penaltyinfo: penaltyinfo_t,
7658        pre_dropped_penaltyinfo: pre_dropped_penaltyinfo_t,
7659        metadata,
7660        linear_constraints: linear_constraints_local,
7661        box_reparam: use_box_reparam,
7662        kronecker_factored: kron_factored,
7663    })
7664}
7665
7666pub fn build_smooth_design(
7667    data: ArrayView2<'_, f64>,
7668    terms: &[SmoothTermSpec],
7669) -> Result<RawSmoothDesign, BasisError> {
7670    let mut ws = crate::basis::BasisWorkspace::new();
7671    build_smooth_design_withworkspace(data, terms, &mut ws)
7672}
7673
7674/// Like `build_smooth_design`, but honors the caller workspace policy while
7675/// building each planned smooth term with an independent per-term workspace.
7676///
7677/// Independent workspaces avoid shared mutable distance-cache state during the
7678/// parallel term build; the final design, penalties, and metadata are assembled
7679/// in the original smooth-term order.
7680pub fn build_smooth_design_withworkspace(
7681    data: ArrayView2<'_, f64>,
7682    terms: &[SmoothTermSpec],
7683    workspace: &mut crate::basis::BasisWorkspace,
7684) -> Result<RawSmoothDesign, BasisError> {
7685    validate_smooth_terms_finite_inputs(data, terms)?;
7686    build_smooth_design_withworkspace_unvalidated(data, terms, workspace)
7687}
7688
7689pub fn build_smooth_design_withworkspace_unvalidated(
7690    data: ArrayView2<'_, f64>,
7691    terms: &[SmoothTermSpec],
7692    workspace: &mut crate::basis::BasisWorkspace,
7693) -> Result<RawSmoothDesign, BasisError> {
7694    let mut planned_blocks = plan_joint_spatial_centers_for_term_blocks(data, &[terms.to_vec()])?;
7695    let planned_terms = planned_blocks.pop().ok_or_else(|| {
7696        BasisError::InvalidInput(
7697            "joint spatial center planner returned no smooth blocks".to_string(),
7698        )
7699    })?;
7700    let policy = workspace.policy().clone();
7701    let local_builds: Vec<LocalSmoothTermBuild> = {
7702        use rayon::iter::{IntoParallelIterator, ParallelIterator};
7703        planned_terms
7704            .into_par_iter()
7705            .map(|term| {
7706                let mut term_workspace = crate::basis::BasisWorkspace::with_policy(policy.clone());
7707                build_single_local_smooth_term(data, &term, &mut term_workspace)
7708            })
7709            .collect::<Result<Vec<_>, _>>()?
7710    };
7711
7712    let total_p: usize = local_builds.iter().map(|built| built.dim).sum();
7713
7714    let mut local_designs: Vec<DesignMatrix> = Vec::with_capacity(local_builds.len());
7715    let mut terms_out = Vec::<SmoothTerm>::with_capacity(terms.len());
7716    let mut penalties_global = Vec::<BlockwisePenalty>::new();
7717    let mut nullspace_dims_global = Vec::<usize>::new();
7718    let mut penaltyinfo_global = Vec::<PenaltyBlockInfo>::new();
7719    let mut dropped_penaltyinfo_global = Vec::<DroppedPenaltyBlockInfo>::new();
7720    let mut coefficient_lower_bounds = Array1::<f64>::from_elem(total_p, f64::NEG_INFINITY);
7721    let mut any_bounds = false;
7722    // Each linear-constraint row only touches the current term's column slice.
7723    // Track `(col_start, col_end, local_row_values)` and assemble the final
7724    // dense `Array2` in one pass, avoiding per-row `Array1::zeros(total_p)`
7725    // allocation plus a row-by-row copy at the end.
7726    let mut linear_constraintsrows: Vec<(usize, usize, Array1<f64>)> = Vec::new();
7727    let mut linear_constraints_b: Vec<f64> = Vec::new();
7728
7729    let mut col_start = 0usize;
7730    for (term, mut built) in terms.iter().zip(local_builds.into_iter()) {
7731        let p_local = built.dim;
7732        let col_end = col_start + p_local;
7733        let lb_local = if built.box_reparam {
7734            shape_lower_bounds_local(term.shape, p_local)
7735        } else {
7736            None
7737        };
7738
7739        // Stage-2 joint-null absorption rotation. Fired *before* the
7740        // penalty / design / global aggregation loops below so that every
7741        // subsequent reference to `built.penalties`, `built.design`, and
7742        // `built.ops` sees the post-rotation values.
7743        //
7744        // The math: when the smooth's joint penalty `Σ_k S_k` has a
7745        // non-trivial null space, eigh selects `Q = [U_range | U_null]`
7746        // with null columns at the tail. Setting `β_raw = Q · γ` and
7747        // applying:
7748        //     design        ← X · Q
7749        //     penalties[k]  ← Qᵀ · S_k · Q   (block-diag, zero null tail)
7750        // yields a model whose fitted γ is invariant to the rotation
7751        // (since likelihood depends only on `X · β_raw = X · Q · γ`), but
7752        // whose penalty is full-rank on the range columns. The large-scale
7753        // failing case (cert refusal in the joint-Newton inner solve)
7754        // resolves because `H_pen = H_loglik + S` becomes full rank on
7755        // the smooth's range columns.
7756        //
7757        // Rotation is suppressed when the smooth carries coordinate-wise
7758        // shape constraints (`lb_local` or `built.linear_constraints`):
7759        // those encode a cone in the original coordinate system and a
7760        // general orthogonal rotation breaks the cone geometry. Smooths
7761        // with shape constraints typically have full-rank joint penalty
7762        // (their structural shape comes from the cone, not from null
7763        // directions in the penalty), so suppression is rarely a loss.
7764        //
7765        // `applied_rotation` carries the Q that was applied (or `None`
7766        // if no rotation fired). It is persisted onto `SmoothTerm` below
7767        // so prediction-side `X_new_raw · Q` replay can reproduce the
7768        // exact rotation. Persistence through the saved-model artifact
7769        // is a follow-up — see the doc on `SmoothTerm.joint_null_rotation`.
7770        let applied_rotation: Option<crate::basis::JointNullRotation> = match (
7771            built.joint_null_rotation.take(),
7772            lb_local.is_some(),
7773            built.linear_constraints.is_some(),
7774        ) {
7775            (Some(rot), false, false) => {
7776                let q = &rot.rotation;
7777                let dense = built
7778                    .design
7779                    .try_to_dense_by_chunks("joint-null absorption rotation")
7780                    .map_err(BasisError::InvalidInput)?;
7781                let rotated = gam_linalg::faer_ndarray::fast_ab(&dense, q);
7782                built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(rotated));
7783                built.penalties = built
7784                    .penalties
7785                    .into_iter()
7786                    .map(|s_local| {
7787                        let qt_s = gam_linalg::faer_ndarray::fast_atb(q, &s_local);
7788                        gam_linalg::faer_ndarray::fast_ab(&qt_s, q)
7789                    })
7790                    .collect();
7791                built.ops = vec![None; built.penalties.len()];
7792                built.kronecker_factored = None;
7793                Some(rot)
7794            }
7795            (Some(_), _, _) => None,
7796            (None, _, _) => None,
7797        };
7798
7799        let activeinfos = built
7800            .penaltyinfo
7801            .iter()
7802            .filter(|info| info.active)
7803            .collect::<Vec<_>>();
7804        if activeinfos.len() != built.penalties.len() {
7805            crate::bail_invalid_basis!(
7806                "internal penalty info mismatch for term '{}': activeinfos={}, penalties={}",
7807                term.name,
7808                activeinfos.len(),
7809                built.penalties.len()
7810            );
7811        }
7812        for (((s_local, &ns), info), op_local) in built
7813            .penalties
7814            .iter()
7815            .zip(built.nullspaces.iter())
7816            .zip(activeinfos.into_iter())
7817            .zip(built.ops.iter())
7818        {
7819            let global_index = penalties_global.len();
7820            penalties_global.push(
7821                BlockwisePenalty::new(col_start..col_end, s_local.clone())
7822                    .with_op(op_local.clone()),
7823            );
7824            nullspace_dims_global.push(ns);
7825            let mut penalty = info.clone();
7826            penalty.nullspace_dim_hint = ns;
7827            penaltyinfo_global.push(PenaltyBlockInfo {
7828                global_index,
7829                termname: Some(term.name.clone()),
7830                penalty,
7831            });
7832        }
7833        for info in built.penaltyinfo.iter().filter(|info| !info.active) {
7834            dropped_penaltyinfo_global.push(DroppedPenaltyBlockInfo {
7835                termname: Some(term.name.clone()),
7836                penalty: info.clone(),
7837            });
7838        }
7839        for info in &built.pre_dropped_penaltyinfo {
7840            dropped_penaltyinfo_global.push(DroppedPenaltyBlockInfo {
7841                termname: Some(term.name.clone()),
7842                penalty: info.clone(),
7843            });
7844        }
7845
7846        if let Some(lin_local) = &built.linear_constraints {
7847            for r in 0..lin_local.a.nrows() {
7848                linear_constraintsrows.push((col_start, col_end, lin_local.a.row(r).to_owned()));
7849                linear_constraints_b.push(lin_local.b[r]);
7850            }
7851        }
7852        if let Some(lb_local) = &lb_local {
7853            coefficient_lower_bounds
7854                .slice_mut(s![col_start..col_end])
7855                .assign(lb_local);
7856            any_bounds = true;
7857        }
7858
7859        // Move the per-term design out of `built` rather than cloning it.
7860        local_designs.push(built.design);
7861
7862        terms_out.push(SmoothTerm {
7863            name: term.name.clone(),
7864            coeff_range: col_start..col_end,
7865            shape: term.shape,
7866            penalties_local: built.penalties,
7867            nullspace_dims: built.nullspaces,
7868            penaltyinfo_local: built.penaltyinfo,
7869            metadata: built.metadata,
7870            lower_bounds_local: lb_local,
7871            linear_constraints_local: built.linear_constraints,
7872            kronecker_factored: built.kronecker_factored.take(),
7873            joint_null_rotation: applied_rotation,
7874            unabsorbed_global_orthogonality: None,
7875        });
7876
7877        col_start = col_end;
7878    }
7879
7880    assert_eq!(
7881        penalties_global.len(),
7882        nullspace_dims_global.len(),
7883        "global smooth penalty/nullspace bookkeeping diverged"
7884    );
7885    assert_eq!(
7886        penalties_global.len(),
7887        penaltyinfo_global.len(),
7888        "global smooth penalty metadata bookkeeping diverged"
7889    );
7890
7891    Ok(RawSmoothDesign {
7892        term_designs: local_designs,
7893        penalties: penalties_global,
7894        nullspace_dims: nullspace_dims_global,
7895        penaltyinfo: penaltyinfo_global,
7896        dropped_penaltyinfo: dropped_penaltyinfo_global,
7897        terms: terms_out,
7898        coefficient_lower_bounds: if any_bounds {
7899            Some(coefficient_lower_bounds)
7900        } else {
7901            None
7902        },
7903        linear_constraints: if linear_constraintsrows.is_empty() {
7904            None
7905        } else {
7906            let mut a = Array2::<f64>::zeros((linear_constraintsrows.len(), total_p));
7907            for (i, (cs, ce, values)) in linear_constraintsrows.iter().enumerate() {
7908                a.row_mut(i).slice_mut(s![*cs..*ce]).assign(values);
7909            }
7910            Some(LinearInequalityConstraints {
7911                a,
7912                b: Array1::from_vec(linear_constraints_b),
7913            })
7914        },
7915    })
7916}