// gam_terms/smooth/term_specs.rs

1use coefficient_transforms::{
2    convex_divided_difference_transform_matrix, cumulative_exp, cumulative_sum_transform_matrix,
3    second_cumulative_exp,
4};
5
6pub use error::SmoothError;
7
8use input_standardization::{
9    apply_input_standardization, compensate_length_scale_for_standardization,
10    compensate_optional_length_scale_for_standardization, compute_spatial_input_scales,
11};
12
13use shape_constraints::{
14    build_shape_constraint_design_1d, build_shape_linear_constraints_1d,
15    merge_linear_constraints_global, shape_lower_bounds_local, shape_order_and_sign,
16    shape_supports_basis, shape_uses_box_reparameterization,
17};
18
19pub fn describe_thin_plate_center_request(strategy: &CenterStrategy) -> String {
20    match strategy {
21        CenterStrategy::Auto(inner) => describe_thin_plate_center_request(inner),
22        CenterStrategy::UserProvided(centers) => format!("{} centers", centers.nrows()),
23        CenterStrategy::EqualMass { num_centers }
24        | CenterStrategy::EqualMassCovarRepresentative { num_centers }
25        | CenterStrategy::FarthestPoint { num_centers }
26        | CenterStrategy::KMeans { num_centers, .. } => format!("{num_centers} centers"),
27        CenterStrategy::UniformGrid { points_per_dim } => {
28            format!("uniform grid with {points_per_dim} points per dimension")
29        }
30    }
31}
32
33pub fn rewrite_thin_plate_knots_error(
34    err: BasisError,
35    termname: &str,
36    feature_count: usize,
37    spec: &ThinPlateBasisSpec,
38) -> BasisError {
39    match err {
40        // Polynomial-nullspace shortfall reported directly by the kernel
41        // builder ("thin-plate spline requires at least N centers to span ...").
42        BasisError::InvalidInput(msg)
43            if msg.contains("thin-plate spline requires at least")
44                && (msg.contains("centers to span") || msg.contains("knots to span")) =>
45        {
46            let min_centers = crate::basis::thin_plate_polynomial_basis_dimension(feature_count);
47            let requested = describe_thin_plate_center_request(&spec.center_strategy);
48            BasisError::InvalidInput(format!(
49                "joint TPS term '{termname}' over {feature_count} covariates with {requested} is invalid; minimum centers is {min_centers}"
50            ))
51        }
52        // Insufficient-rows shortfall raised by `select_thin_plate_knots` when
53        // the requested center count exceeds the available row count. Rewrite
54        // it in term language so the diagnostic points at the smooth term and
55        // the polynomial-nullspace minimum the user needs to satisfy.
56        BasisError::InvalidInput(msg)
57            if msg.starts_with("requested ") && msg.contains(" knots but only ") =>
58        {
59            let min_centers = crate::basis::thin_plate_polynomial_basis_dimension(feature_count);
60            let requested = describe_thin_plate_center_request(&spec.center_strategy);
61            BasisError::InvalidInput(format!(
62                "joint TPS term '{termname}' over {feature_count} covariates with {requested} is invalid; minimum centers is {min_centers}"
63            ))
64        }
65        other => other,
66    }
67}
68
69#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
70pub enum ShapeConstraint {
71    None,
72    MonotoneIncreasing,
73    MonotoneDecreasing,
74    Convex,
75    Concave,
76}
77
78/// Parse a shape-constraint string into a [`ShapeConstraint`].
79///
80/// This is the single source of truth shared by the formula DSL
81/// (`s(x, shape=...)`) and the `smooths={...}` override path
82/// (`Smooth.shape_constraint`). The accepted spellings cover the canonical
83/// Python `ShapeConstraintLiteral` strings exactly
84/// (`"none"` / `"monotone_increasing"` / `"monotone_decreasing"` /
85/// `"convex"` / `"concave"`) plus a few common aliases. Hyphens and case are
86/// normalized, so `"Monotone-Increasing"` and `"mono_inc"` both resolve to
87/// [`ShapeConstraint::MonotoneIncreasing`].
88pub fn parse_shape_constraint(raw: &str) -> Result<ShapeConstraint, String> {
89    let normalized = raw.trim().to_ascii_lowercase().replace('-', "_");
90    match normalized.as_str() {
91        "" | "none" => Ok(ShapeConstraint::None),
92        "monotone_increasing" | "monotonic_increasing" | "increasing" | "mono_inc" | "mpi" => {
93            Ok(ShapeConstraint::MonotoneIncreasing)
94        }
95        "monotone_decreasing" | "monotonic_decreasing" | "decreasing" | "mono_dec" | "mpd" => {
96            Ok(ShapeConstraint::MonotoneDecreasing)
97        }
98        "convex" | "cvx" => Ok(ShapeConstraint::Convex),
99        "concave" | "ccv" => Ok(ShapeConstraint::Concave),
100        other => Err(format!(
101            "unknown shape constraint {other:?}; expected one of \
102             \"none\", \"monotone_increasing\", \"monotone_decreasing\", \
103             \"convex\", \"concave\""
104        )),
105    }
106}
107
108impl ShapeConstraint {
109    /// Canonical formula-DSL spelling, i.e. the text emitted into
110    /// `s(x, shape=...)`. Round-trips through [`parse_shape_constraint`].
111    pub fn dsl_str(&self) -> &'static str {
112        match self {
113            ShapeConstraint::None => "none",
114            ShapeConstraint::MonotoneIncreasing => "monotone_increasing",
115            ShapeConstraint::MonotoneDecreasing => "monotone_decreasing",
116            ShapeConstraint::Convex => "convex",
117            ShapeConstraint::Concave => "concave",
118        }
119    }
120}
121
/// Smooth-term head keywords recognised by the formula DSL. A `shape=` option
/// may be attached to any term whose head is one of these.
///
/// [`apply_shape_constraints_to_formula`] matches a keyword only when it
/// starts at a word boundary and is followed (after optional whitespace) by
/// `(`, so the order of entries here does not affect which head is found.
pub const SMOOTH_HEAD_KEYWORDS: [&str; 11] = [
    "s",
    "smooth",
    "te",
    "tensor",
    "thinplate",
    "tps",
    "duchon",
    "matern",
    "sphere",
    "bs",
    "bspline",
];
137
138/// Rewrite smooth-term calls in `formula` so each named smooth carries a
139/// `shape=<kind>` option understood by the formula DSL.
140///
141/// `constraints` pairs the smooth-term text as it appears in the formula
142/// (e.g. `"s(x)"` or `"s(x, type=duchon, centers=8)"`) with a shape-constraint
143/// spelling accepted by [`parse_shape_constraint`]; comparison is exact after
144/// whitespace removal. A `"none"` constraint is a no-op. Referencing a term not
145/// present in the formula is an error.
146///
147/// This is the single source of truth for the `gamfit.fit(..., constraints=…)`
148/// rewrite — the Python wrapper only marshals the mapping across the FFI and
149/// holds no formula-parsing or alias-normalization logic of its own.
150pub fn apply_shape_constraints_to_formula(
151    formula: &str,
152    constraints: &[(String, String)],
153) -> Result<String, String> {
154    use std::collections::{BTreeMap, BTreeSet};
155
156    if constraints.is_empty() {
157        return Ok(formula.to_string());
158    }
159    let strip_ws = |s: &str| -> String { s.chars().filter(|c| !c.is_whitespace()).collect() };
160
161    // Whitespace-stripped term text -> canonical shape spelling.
162    let mut wanted: BTreeMap<String, &'static str> = BTreeMap::new();
163    // Whitespace-stripped term text -> original key (for error labels).
164    let mut originals: BTreeMap<String, String> = BTreeMap::new();
165    for (key, kind_raw) in constraints {
166        let kind = parse_shape_constraint(kind_raw)?;
167        let nk = strip_ws(key);
168        originals.entry(nk.clone()).or_insert_with(|| key.clone());
169        if kind != ShapeConstraint::None {
170            wanted.insert(nk, kind.dsl_str());
171        }
172    }
173    if wanted.is_empty() {
174        return Ok(formula.to_string());
175    }
176
177    let chars: Vec<char> = formula.chars().collect();
178    let n = chars.len();
179    let is_ident = |c: char| c.is_ascii_alphanumeric() || c == '_';
180
181    let mut out = String::with_capacity(formula.len() + 32);
182    let mut matched: BTreeSet<String> = BTreeSet::new();
183    let mut i = 0usize;
184    while i < n {
185        // Locate the next smooth-term head (`<keyword> \s* (`) at or after `i`,
186        // respecting word boundaries so `abs(` never matches the `s(` head.
187        let mut head: Option<(usize, usize)> = None; // (head_start, paren_index)
188        let mut p = i;
189        while p < n {
190            let boundary = p == 0 || !is_ident(chars[p - 1]);
191            if boundary {
192                for kw in SMOOTH_HEAD_KEYWORDS.iter() {
193                    let klen = kw.chars().count();
194                    if p + klen > n || chars[p..p + klen].iter().collect::<String>() != **kw {
195                        continue;
196                    }
197                    let mut q = p + klen;
198                    while q < n && chars[q].is_whitespace() {
199                        q += 1;
200                    }
201                    if q < n && chars[q] == '(' {
202                        head = Some((p, q));
203                        break;
204                    }
205                }
206            }
207            if head.is_some() {
208                break;
209            }
210            p += 1;
211        }
212        let (head_start, paren_open) = match head {
213            Some(h) => h,
214            None => {
215                out.extend(chars[i..].iter());
216                break;
217            }
218        };
219        out.extend(chars[i..head_start].iter());
220
221        // Find the matching close paren, honoring nesting and string literals.
222        let body_start = paren_open + 1;
223        let mut depth = 1i32;
224        let mut j = body_start;
225        let mut in_str: Option<char> = None;
226        let mut closed = false;
227        while j < n {
228            let ch = chars[j];
229            if let Some(quote) = in_str {
230                if ch == quote {
231                    in_str = None;
232                }
233            } else if ch == '\'' || ch == '"' {
234                in_str = Some(ch);
235            } else if ch == '(' {
236                depth += 1;
237            } else if ch == ')' {
238                depth -= 1;
239                if depth == 0 {
240                    closed = true;
241                    break;
242                }
243            }
244            j += 1;
245        }
246
247        if !closed {
248            // Unbalanced — emit the remainder verbatim; the DSL parser will
249            // produce the canonical error.
250            out.extend(chars[head_start..].iter());
251            break;
252        }
253
254        let term_text: String = chars[head_start..=j].iter().collect();
255
256        let key_norm = strip_ws(&term_text);
257
258        match wanted.get(&key_norm) {
259            None => out.extend(chars[head_start..=j].iter()),
260            Some(kind) => {
261                let head_paren: String = chars[head_start..body_start].iter().collect();
262                let inside: String = chars[body_start..j].iter().collect();
263                let inside = inside.trim();
264                if inside.is_empty() {
265                    out.push_str(&format!("{head_paren}shape={kind})"));
266                } else {
267                    out.push_str(&format!("{head_paren}{inside}, shape={kind})"));
268                }
269                matched.insert(key_norm);
270            }
271        }
272
273        i = j + 1;
274    }
275
276    let mut missing: Vec<String> = wanted
277        .keys()
278        .filter(|k| !matched.contains(*k))
279        .map(|k| originals.get(k).cloned().unwrap_or_else(|| k.clone()))
280        .collect();
281
282    if !missing.is_empty() {
283        missing.sort();
284        return Err(format!(
285            "shape constraints referenced smooth term(s) not found in formula: {}",
286            missing.join(", ")
287        ));
288    }
289
290    Ok(out)
291}
292
293#[derive(Debug, Clone, Serialize, Deserialize)]
294pub enum BySmoothKind {
295    Numeric,
296    Level { level_bits: u64 },
297}
298
299#[derive(Debug, Clone, Serialize, Deserialize)]
300pub enum SmoothBasisSpec {
301    /// Row-gated wrapper used for mgcv-style ``by=`` smooths.
302    ///
303    /// ``ByNumeric`` multiplies the inner smooth by a numeric column.
304    /// ``ByLevel`` keeps the inner smooth active only for rows whose encoded
305    /// categorical value has the stored bit pattern.  Unordered factor-by
306    /// smooths are represented as one independent ``ByLevel`` term per level.
307    ///
308    /// `kind` preserves the compact structural discriminator, while `by`
309    /// carries the full row-gating spec used to build the local design.
310    ByVariable {
311        inner: Box<SmoothBasisSpec>,
312        by_col: usize,
313        kind: BySmoothKind,
314        by: ByVariableSpec,
315    },
316    /// Sum-to-zero factor smooth (`bs="sz"`): with L levels, estimate L-1
317    /// deviation coefficient blocks and use the final level as the negative
318    /// sum of the others, enforcing coefficient-wise zero sums across levels.
319    FactorSumToZero {
320        inner: Box<SmoothBasisSpec>,
321        by_col: usize,
322        levels: Vec<u64>,
323        /// Global-orthogonality column map `Z` captured at fit time when this
324        /// term overlapped an owner smooth (`s(x) + s(g, x, bs=sz)`, #978):
325        /// the hierarchical-ownership pass residualized this term's realized
326        /// design as `X ← X·Z`, shrinking its coefficient block. `Z` depends
327        /// on the *training-row* owner designs, so prediction cannot rederive
328        /// it — it must be persisted and replayed
329        /// (`apply_global_smooth_identifiability` consumes it verbatim).
330        /// Chart convention: `Z` lives in the post-restack, post-joint-null-Q
331        /// coordinates — the raw `sz` rebuild reapplies `Q` deterministically
332        /// (#700), then `Z` applies on top. `None` for non-overlapping terms.
333        #[serde(default)]
334        frozen_global_orthogonality: Option<Array2<f64>>,
335    },
336    BSpline1D {
337        feature_col: usize,
338        spec: BSplineBasisSpec,
339    },
340    /// A smooth modulated by a `by=` variable. Numeric `by` scales one inner
341    /// smooth; factor `by` replicates the inner smooth by level.
342    BySmooth {
343        smooth: Box<SmoothBasisSpec>,
344        by_kind: ByVarKind,
345    },
346    /// Factor-smooth interaction families (`bs="fs"`, `bs="sz"`) and
347    /// random slopes (`bs="re"`).
348    FactorSmooth { spec: FactorSmoothSpec },
349    ThinPlate {
350        feature_cols: Vec<usize>,
351        spec: ThinPlateBasisSpec,
352        /// Per-column standard deviations used to standardize input dimensions
353        /// before kernel evaluation when d > 1. `None` means no standardization
354        /// (either d == 1 or explicitly disabled).
355        #[serde(default)]
356        input_scales: Option<Vec<f64>>,
357    },
358    Sphere {
359        feature_cols: Vec<usize>,
360        spec: SphericalSplineBasisSpec,
361    },
362    /// Constant-curvature (`M_κ`) geodesic-kernel smooth over κ-stereographic
363    /// chart coordinates (#944): one construction interpolating
364    /// S^d → ℝ^d → H^d through the spec's fixed κ. The Wahba S² smooth is the
365    /// structural template; the geometry comes from
366    /// `geometry::constant_curvature::ConstantCurvature`.
367    ConstantCurvature {
368        feature_cols: Vec<usize>,
369        spec: ConstantCurvatureBasisSpec,
370    },
371    Matern {
372        feature_cols: Vec<usize>,
373        spec: MaternBasisSpec,
374        #[serde(default)]
375        input_scales: Option<Vec<f64>>,
376    },
377    /// Measure-jet spline smooth: multiscale local-jet-residual energy of the
378    /// empirical measure (centers as μ-quadrature, masses as μ-weights — no
379    /// graph, mesh, or neighbor set inside the statistical object). The
380    /// feature columns are ambient coordinates of data concentrated near an
381    /// unknown low-dimensional, possibly stratified set.
382    MeasureJet {
383        feature_cols: Vec<usize>,
384        spec: MeasureJetBasisSpec,
385        #[serde(default)]
386        input_scales: Option<Vec<f64>>,
387    },
388    Duchon {
389        feature_cols: Vec<usize>,
390        spec: DuchonBasisSpec,
391        #[serde(default)]
392        input_scales: Option<Vec<f64>>,
393    },
394    Pca {
395        feature_cols: Vec<usize>,
396        basis_matrix: Array2<f64>,
397        centered: bool,
398        #[serde(default = "default_pca_smooth_penalty")]
399        smooth_penalty: f64,
400        #[serde(default)]
401        center_mean: Option<Array1<f64>>,
402        #[serde(default)]
403        pca_basis_path: Option<PathBuf>,
404        #[serde(default = "default_pca_chunk_size")]
405        chunk_size: usize,
406    },
407    /// Tensor-product smooth built from 1D B-spline marginals.
408    ///
409    /// This is the `te()`-style construction used when axes have different units/scales
410    /// (for example, space x time) and isotropic radial kernels are not appropriate.
411    TensorBSpline {
412        feature_cols: Vec<usize>,
413        spec: TensorBSplineSpec,
414    },
415}
416
417impl SmoothBasisSpec {
418    /// Conservative lower bound on the number of sample rows needed for this
419    /// smooth basis to have a well-posed REML fit.
420    ///
421    /// Each basis kind answers the question for itself, so the workflow does
422    /// not have to know how many columns a B-spline, tensor product, PCA
423    /// projection, or spatial kernel emits. The contract is a *lower bound*:
424    /// returning too small a number is permitted (the inner solver will catch
425    /// any genuine n-vs-rank failure that slips past); returning too large a
426    /// number is a regression because it rejects legitimate fits.
427    ///
428    /// Rationale: B-spline / tensor / PCA bases have a closed-form column
429    /// count, so we use the exact dimension. Radial bases (TPS, Matern,
430    /// Duchon, Sphere) and factor smooths choose their column count from the
431    /// data (`heuristic_centers`, `unique_count`); we fall back to a small
432    /// constant floor because a fit on fewer than five rows cannot stabilise
433    /// any radial smooth regardless of the configured kernel scale.
434    pub fn min_sample_rows(&self) -> usize {
435        // Floor used for data-driven bases whose column count is not known
436        // from the spec alone. Five rows is the minimum at which the inner
437        // pivot/QR + REML smoothing-parameter search has any chance of being
438        // well-posed for a non-parametric smooth.
439        const RADIAL_FLOOR: usize = 5;
440
441        match self {
442            Self::ByVariable { inner, .. } => inner.min_sample_rows(),
443            Self::FactorSumToZero { inner, levels, .. } => {
444                // L-1 independent deviation blocks each carrying the inner
445                // basis dimension. Skip the levels-multiplier if it doesn't
446                // bring more rows; we want the *lower bound* not the rank.
447                let inner_min = inner.min_sample_rows();
448                let lvls = levels.len().saturating_sub(1).max(1);
449                inner_min.saturating_mul(lvls)
450            }
451            Self::BSpline1D { spec, .. } => bspline_basis_min_rows(spec),
452            Self::BySmooth { smooth, .. } => smooth.min_sample_rows(),
453            Self::FactorSmooth { spec } => {
454                // Replicates the marginal once per level; without a known
455                // level count we conservatively require at least the marginal
456                // basis dimension.
457                bspline_basis_min_rows(&spec.marginal)
458            }
459            Self::ThinPlate { .. }
460            | Self::Sphere { .. }
461            | Self::ConstantCurvature { .. }
462            | Self::Matern { .. }
463            | Self::MeasureJet { .. }
464            | Self::Duchon { .. } => RADIAL_FLOOR,
465            Self::Pca { basis_matrix, .. } => basis_matrix.ncols().max(1),
466            Self::TensorBSpline { spec, .. } => {
467                // A `te(...)` smooth is *penalized*: each margin carries a
468                // difference (wiggliness) penalty and the tensor inherits a
469                // Kronecker-sum penalty `S = Σ_i I ⊗ … ⊗ S_i ⊗ … ⊗ I`. The raw
470                // column count is the *product* of the per-marginal column
471                // counts, but that product is the lower bound for an
472                // *unpenalized* tensor regression — it is the number of rows you
473                // would need to identify every interaction column with no
474                // regularization. The penalty regularizes all of those
475                // interaction directions; only the combined penalty *null space*
476                // (the tensor product of the per-margin polynomial trends, a
477                // handful of columns) must be identified by the data, and the
478                // smoothing-parameter search shrinks the rest. The effective
479                // degrees of freedom of the fitted `te()` are therefore a small
480                // fraction of the column product, which is exactly why mgcv
481                // fits a default `te(x, y)` on a couple hundred rows.
482                //
483                // The honest *penalized* lower bound is the **sum** of the
484                // per-marginal column counts, not their product: a row floor of
485                // `Σ_i k_i` still guarantees enough data to identify each
486                // margin's additive main-effect (the largest sub-block the
487                // penalty cannot shrink to zero), while no longer conflating
488                // unpenalized column-count identifiability with penalized
489                // well-posedness. This accepts moderate-`n` penalized tensors
490                // (e.g. a 20×20 default basis on n=200) yet still rejects a
491                // genuinely undersized fit where `n < Σ_i k_i` and even the
492                // additive part is rank-deficient.
493                //
494                // Binary / low-cardinality margins (#724): gam will accept a
495                // `te(x, badh)` whose `badh ∈ {0, 1}` margin nominally requests
496                // more basis columns than `badh` has unique values, where mgcv
497                // refuses the unpenalized term as ill-posed ("badh has
498                // insufficient unique values to support k knots"). This is
499                // correct-by-design, *not* a degenerate fit: the marginal
500                // wiggliness penalty on the `badh` axis has a null space that is
501                // exactly its identifiable trend (the two cell means of a binary
502                // covariate), and the Kronecker-sum penalty shrinks every tensor
503                // column outside that null space toward zero. The resulting fit
504                // is the well-posed "per-level `x` smooth + binary main effect"
505                // that mgcv reaches only after manually collapsing the basis —
506                // gam reaches it automatically because the penalty, not the raw
507                // column count, sets the effective rank. A genuinely
508                // rank-deficient design (penalty null space wider than the data
509                // can support) is still caught downstream by the inner pivoted
510                // factorization, which owns the exact n-vs-rank decision; this
511                // pre-fit gate only refuses the grossly-undersized formula.
512                let mut total: usize = 0;
513                for marginal in &spec.marginalspecs {
514                    let m = bspline_basis_min_rows(marginal);
515                    total = total.saturating_add(m.max(1));
516                }
517                total.max(RADIAL_FLOOR)
518            }
519        }
520    }
521
522    /// Stable structural discriminant for warm-start cache keying (#869).
523    ///
524    /// Two smooths that produce different bases / penalty structures must map
525    /// to different strings here so they cannot collide on the persistent
526    /// warm-start `cache_key` (which is otherwise blind to topology: it hashes
527    /// only the raw input column count, so e.g. `sphere` vs `torus` vs
528    /// `euclidean` candidates fit on the *same* data would otherwise share one
529    /// key and cross-contaminate each other's β/ρ seed). The string is the
530    /// topology identity, not the fitted coefficients, so same-topology refits
531    /// (the screen→full-refit cascade) still hit the same key and reuse work.
532    pub fn structural_kind(&self) -> &'static str {
533        match self {
534            Self::ByVariable { .. } => "by_variable",
535            Self::FactorSumToZero { .. } => "factor_sum_to_zero",
536            Self::BSpline1D { .. } => "bspline_1d",
537            Self::BySmooth { .. } => "by_smooth",
538            Self::FactorSmooth { .. } => "factor_smooth",
539            Self::ThinPlate { .. } => "thin_plate",
540            Self::Sphere { .. } => "sphere",
541            Self::ConstantCurvature { .. } => "constant_curvature",
542            Self::Matern { .. } => "matern",
543            Self::MeasureJet { .. } => "measurejet",
544            Self::Duchon { .. } => "duchon",
545            Self::Pca { .. } => "pca",
546            Self::TensorBSpline { .. } => "tensor_bspline",
547        }
548    }
549
550    /// True for a tensor-product smooth that is only *marginally* centered
551    /// (`ti(...)`, [`TensorBSplineIdentifiability::MarginalSumToZero`]): its
552    /// per-margin sum-to-zero reparameterization `(B_xZ_x)⊗(B_zZ_z)` has ALREADY
553    /// removed each axis's main effect analytically (mgcv-identical), so its
554    /// main-effect removal is complete and it must take NO additional
555    /// owner-residualization block. Residualizing it a second time against the
556    /// realized main-effect designs is a grid-fragile no-op on an exact tensor
557    /// grid but eats genuine pure-interaction curvature off-grid (#1470).
558    pub fn is_marginally_centered_tensor(&self) -> bool {
559        matches!(
560            self,
561            Self::TensorBSpline { spec, .. }
562                if matches!(spec.identifiability, TensorBSplineIdentifiability::MarginalSumToZero)
563        )
564    }
565
566    /// A sum-to-zero factor smooth (`bs="sz"`) has ALREADY removed the
567    /// cross-group main effect analytically, in coefficient space, via its
568    /// `Σ_g d_g(x) ≡ 0` reparameterization (`L-1` deviation blocks with the
569    /// reference level the negative sum of the others) — exactly mgcv's `sz`
570    /// construction, which is self-identifiable against an overlapping `s(x)`
571    /// with no further constraint. Residualizing it a SECOND time against the
572    /// realized B-spline span of the explicit `s(x)` smooth is redundant in
573    /// exact arithmetic (the common-to-all-groups component is zero by
574    /// construction) and actively HARMFUL on finite data: each deviation block
575    /// is a B-spline in `x` whose realized columns share `s(x)`'s span, so the
576    /// joint residualization collapses the full `L·k`-column deviation design to
577    /// `L·k − rank(s(x))` columns and eats the within-group curvature `s(x)`
578    /// cannot represent. REML then rails the deviation smoothing parameter and
579    /// the factor smooth under-recovers (#1605). This is the exact analogue of
580    /// the marginally-centered tensor (`ti`) exemption (#1470), so such a term
581    /// takes NO owner-residualization block.
582    pub fn is_sum_to_zero_factor_smooth(&self) -> bool {
583        matches!(
584            self,
585            Self::FactorSumToZero { .. }
586                | Self::FactorSmooth {
587                    spec: FactorSmoothSpec {
588                        flavour: FactorSmoothFlavour::Sz,
589                        ..
590                    }
591                }
592        )
593    }
594
595    /// Feature columns this basis consumes, used alongside [`structural_kind`]
596    /// to disambiguate two same-kind smooths on different axes. Wrapper
597    /// variants delegate to their inner basis.
598    pub fn structural_feature_cols(&self) -> Vec<usize> {
599        match self {
600            Self::ByVariable { inner, .. } | Self::FactorSumToZero { inner, .. } => {
601                inner.structural_feature_cols()
602            }
603            Self::BySmooth { smooth, .. } => smooth.structural_feature_cols(),
604            Self::FactorSmooth { .. } => Vec::new(),
605            Self::BSpline1D { feature_col, .. } => vec![*feature_col],
606            Self::ThinPlate { feature_cols, .. }
607            | Self::Sphere { feature_cols, .. }
608            | Self::ConstantCurvature { feature_cols, .. }
609            | Self::Matern { feature_cols, .. }
610            | Self::MeasureJet { feature_cols, .. }
611            | Self::Duchon { feature_cols, .. }
612            | Self::Pca { feature_cols, .. }
613            | Self::TensorBSpline { feature_cols, .. } => feature_cols.clone(),
614        }
615    }
616}
617
618/// Lower bound on the number of sample rows a 1D B-spline smooth needs for a
619/// well-posed *penalized* REML fit. Used as the per-smooth row floor in
620/// [`SmoothBasisSpec::min_sample_rows`].
621///
622/// For a *singly*-penalized smooth the floor is the full column count: the
623/// wiggliness penalty leaves the order-`m` polynomial trend unpenalized, and
624/// gam's original gate conservatively required enough rows for the whole basis.
625/// That conservative floor is kept here unchanged.
626///
627/// A *double*-penalized smooth (mgcv `select=TRUE`) is different: it adds a
628/// second penalty on the wiggliness penalty's null space, so even the
629/// polynomial trend is shrinkable toward zero and *nothing* in the basis
630/// requires unpenalized identification by the data — exactly the reasoning the
631/// `TensorBSpline` arm of [`SmoothBasisSpec::min_sample_rows`] already applies
632/// to a penalized tensor. Its honest floor is therefore a small stabilization
633/// constant, not the column count. This is what lets mgcv (and now gam) fit
634/// several `select=TRUE` smooths on a dataset whose row count is below the
635/// summed basis width (e.g. the n≈30 `wine_gamair` fold, 5 `ps` smooths,
636/// p≈51): the penalties, not the data, set the effective rank. The bounded
637/// outer REML loop still terminates, and the genuine n-vs-rank decision is
638/// owned downstream by the inner pivoted factorization. Without this, gam
639/// rejected the fit outright (or, before the gate existed, the outer REML loop
640/// wandered the flat overparameterized surface until the benchmark wall budget
641/// killed it — #1089).
642pub fn bspline_basis_min_rows(spec: &crate::basis::BSplineBasisSpec) -> usize {
643    use crate::basis::BSplineKnotSpec;
644    let columns = match &spec.knotspec {
645        BSplineKnotSpec::Generate {
646            num_internal_knots, ..
647        } => *num_internal_knots + spec.degree + 1,
648        BSplineKnotSpec::Automatic {
649            num_internal_knots: Some(k),
650            ..
651        } => *k + spec.degree + 1,
652        BSplineKnotSpec::Automatic {
653            num_internal_knots: None,
654            ..
655        } => {
656            // Knot count is data-derived (`default_internal_knot_count_for_data`).
657            // A minimal cubic basis is `degree + 2` columns; below that the
658            // basis cannot represent a non-parametric smooth.
659            spec.degree + 2
660        }
661        BSplineKnotSpec::Provided(knots) => knots.len().saturating_sub(spec.degree + 1).max(1),
662        // cr basis dimension equals the knot count (no degree offset).
663        BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
664        BSplineKnotSpec::PeriodicUniform { num_basis, .. } => *num_basis,
665    };
666    let columns = columns.max(spec.degree + 2);
667
668    if spec.double_penalty {
669        // Fully shrinkable basis: only a small stabilization floor must be
670        // identified by the data, capped by the actual column count.
671        const DOUBLE_PENALTY_FLOOR: usize = 2;
672        DOUBLE_PENALTY_FLOOR.min(columns).max(1)
673    } else {
674        columns
675    }
676}
677
/// Selector describing how a `by` variable modulates a smooth.
///
/// NOTE(review): the semantics below are inferred from the variant shapes
/// only; confirm against the code that consumes this enum.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ByVariableSpec {
    /// Numeric `by` variable; carries no payload.
    Numeric,
    /// A single level of a categorical `by` variable, identified by
    /// `value_bits` and carrying a human-readable `label`.
    /// NOTE(review): `value_bits` is presumably the `f64::to_bits` pattern of
    /// the level's encoded value — confirm.
    Level { value_bits: u64, label: String },
}
683
684
/// Kind of `by` variable attached to a smooth, together with the data column
/// it is read from.
///
/// NOTE(review): field meanings are inferred from their names and types;
/// confirm against the builder that consumes this enum.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ByVarKind {
    /// Numeric `by` variable read from column `feature_col`.
    Numeric {
        feature_col: usize,
    },
    /// Categorical `by` variable read from column `feature_col`.
    ///
    /// `ordered` flags an ordered factor. `frozen_levels`, when present, is a
    /// fixed level set — presumably `f64` bit patterns captured at fit time so
    /// prediction replays the same levels (TODO confirm).
    Factor {
        feature_col: usize,
        ordered: bool,
        frozen_levels: Option<Vec<u64>>,
    },
}
696
/// Specification of a factor-smooth term: a smooth of continuous covariate(s)
/// combined with a grouping column.
///
/// NOTE(review): per-field notes marked "presumably" are inferred from names
/// and types only; confirm against the factor-smooth basis builder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FactorSmoothSpec {
    /// Data column indices of the continuous covariate(s).
    pub continuous_cols: Vec<usize>,
    /// Data column index of the grouping (factor) variable.
    pub group_col: usize,
    /// B-spline basis specification used for the smooth — presumably the
    /// per-group marginal basis (TODO confirm).
    pub marginal: BSplineBasisSpec,
    /// Which factor-smooth construction this term uses.
    pub flavour: FactorSmoothFlavour,
    /// Optional fixed set of group levels — presumably `f64` bit patterns
    /// captured at fit time (TODO confirm).
    pub group_frozen_levels: Option<Vec<u64>>,
    /// Fit-time global-orthogonality chart `Z` for this term (`s(x) + fs(x, g)`
    /// overlap residualization, #978), in the post-joint-null-`Q` coordinates
    /// (the raw rebuild recomputes any `Q` itself; `fs` penalties are
    /// typically full-rank so `Q` is absent). Training-row dependent, hence
    /// persisted; replayed verbatim by `apply_global_smooth_identifiability`.
    #[serde(default)]
    pub frozen_global_orthogonality: Option<Array2<f64>>,
}
712
/// Construction variant of a factor smooth.
///
/// NOTE(review): the variant names presumably mirror mgcv's `bs = "fs"`,
/// `bs = "sz"` and `bs = "re"` bases — confirm against the formula parser.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FactorSmoothFlavour {
    /// `fs` flavour. `m_null_penalty_orders` presumably lists the orders of
    /// the extra null-space penalties (TODO confirm).
    Fs { m_null_penalty_orders: Vec<usize> },
    /// `sz` flavour; carries no payload.
    Sz,
    /// `re` flavour; carries no payload.
    Re,
}
719
/// Specification of a tensor-product B-spline smooth built from one marginal
/// B-spline basis per dimension.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct TensorBSplineSpec {
    /// Marginal basis specifications, one per tensor factor.
    pub marginalspecs: Vec<BSplineBasisSpec>,
    /// Optional per-margin period; defaults to an empty vector when absent
    /// from a serialized payload.
    /// NOTE(review): presumably `Some(p)` marks a periodic margin with period
    /// `p` and the vector aligns index-wise with `marginalspecs` — confirm.
    #[serde(default)]
    pub periods: Vec<Option<f64>>,
    /// Whether the tensor smooth carries a second (null-space) penalty.
    /// NOTE(review): inferred from the name and the double-penalty discussion
    /// elsewhere in this file — confirm.
    pub double_penalty: bool,
    /// Identifiability constraint applied to the tensor basis; defaults to
    /// `TensorBSplineIdentifiability::SumToZero` when absent.
    #[serde(default)]
    pub identifiability: TensorBSplineIdentifiability,
    /// How the tensor penalty is decomposed; defaults to
    /// `TensorBSplinePenaltyDecomposition::MarginalKroneckerSum` when absent.
    #[serde(default)]
    pub penalty_decomposition: TensorBSplinePenaltyDecomposition,
}
731
/// Identifiability constraint applied to a tensor-product B-spline smooth.
/// The default variant is [`TensorBSplineIdentifiability::SumToZero`].
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub enum TensorBSplineIdentifiability {
    /// No identifiability constraint.
    None,
    /// Sum-to-zero constraint (the default).
    /// NOTE(review): presumably applied to the full tensor basis, in contrast
    /// to the per-margin constraint of `MarginalSumToZero` — confirm.
    #[default]
    SumToZero,
    /// mgcv `ti(...)` semantics: a *tensor interaction* smooth that excludes the
    /// marginal main effects. A sum-to-zero constraint is applied to **each
    /// marginal basis independently** before forming the tensor product, so the
    /// resulting column space contains no function of a single variable alone —
    /// only the pure interaction survives. The realized identifiability
    /// transform is the Kronecker product `Z = Z₀ ⊗ Z₁ ⊗ … ⊗ Z_{d-1}` of the
    /// per-margin sum-to-zero null-space bases, which is exactly the
    /// reparameterization that turns the full-tensor design into the tensor
    /// product of the centered margins.
    MarginalSumToZero,
    /// A stored identifiability transform matrix.
    /// NOTE(review): presumably captured at fit time and replayed verbatim at
    /// prediction — confirm.
    FrozenTransform {
        transform: Array2<f64>,
    },
}
751
/// How the penalty of a tensor-product B-spline smooth is assembled from its
/// marginal penalties. The default is
/// [`TensorBSplinePenaltyDecomposition::MarginalKroneckerSum`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TensorBSplinePenaltyDecomposition {
    /// mgcv `te(...)`: one overlapping Kronecker-product penalty per margin,
    /// `S_j` embedded against identities in the other tensor factors.
    #[default]
    MarginalKroneckerSum,
    /// mgcv `t2(...)`: split every marginal coefficient space into penalized
    /// range and penalty-null subspaces, then emit one disjoint tensor-subspace
    /// penalty for every non-empty penalized/null combination.
    Separable,
}
763
/// Serializable specification of one smooth term: its name, basis, shape
/// constraint, and the fit-time joint-null rotation (if any).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmoothTermSpec {
    /// Term name; used in validation error messages and in the structural
    /// warm-start hash.
    pub name: String,
    /// Basis specification for this smooth.
    pub basis: SmoothBasisSpec,
    /// Shape constraint carried by this smooth.
    pub shape: ShapeConstraint,
    /// Joint-null absorption rotation captured at fit time. `Some(Q)` means
    /// the fitted coefficient vector lives in `γ`-coordinates with
    /// `β_raw = Q · γ`; prediction must rotate the raw-basis design via
    /// `X_new = X_new_raw · Q` to match. `None` means either the smooth had
    /// no joint null space (penalty already full-rank) or rotation was
    /// suppressed (smooth carries shape constraints whose cone geometry
    /// would not survive an arbitrary orthogonal rotation). Persisted so
    /// `save → load → predict` is bit-equivalent to in-memory prediction.
    #[serde(default)]
    pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
}
780
/// A realized smooth term: its coefficient slot, local penalties, basis
/// metadata, optional constraints, and the coordinate transforms that were
/// applied to its design at construction time.
#[derive(Debug, Clone)]
pub struct SmoothTerm {
    /// Term name; quoted in rotation-replay error messages.
    pub name: String,
    /// Range of coefficient indices owned by this term. Its length is the
    /// term's local width `p_local` (see `wald_unpenalized_dim`).
    /// NOTE(review): presumably relative to the smooth block — confirm.
    pub coeff_range: Range<usize>,
    /// Shape constraint carried by this term.
    pub shape: ShapeConstraint,
    /// Penalty blocks in this term's local coefficient coordinates. When a
    /// joint-null rotation `Q` was applied these are `Qᵀ · S_raw · Q`.
    pub penalties_local: Vec<Array2<f64>>,
    /// Per-penalty null-space dimensions. Their sum over-counts the joint
    /// null space; use `wald_unpenalized_dim` for that quantity.
    pub nullspace_dims: Vec<usize>,
    /// Descriptive info for the local penalties.
    /// NOTE(review): presumably aligned index-wise with `penalties_local` —
    /// confirm.
    pub penaltyinfo_local: Vec<PenaltyInfo>,
    /// Basis metadata produced when the term's basis was built.
    pub metadata: BasisMetadata,
    /// Optional term-local lower bounds for constrained coefficients.
    /// `-inf` means unconstrained.
    pub lower_bounds_local: Option<Array1<f64>>,
    /// Optional term-local inequality constraints in local coefficient coordinates.
    /// `A_local * beta_local >= b_local`.
    pub linear_constraints_local: Option<LinearInequalityConstraints>,
    /// Optional factored tensor-product representation preserved for operator-backed
    /// assembly in the main design builder.
    pub kronecker_factored: Option<KroneckerFactoredBasis>,
    /// Joint-null absorption rotation. `Some(Q)` records the orthonormal
    /// `(p_local × p_local)` matrix that was applied to this term's design
    /// and per-block penalties at construction time:
    /// `term_design ← X_raw · Q`, `penalties_local[k] ← Qᵀ · S_raw · Q`.
    /// The smooth's coefficient vector therefore lives in the rotated
    /// (`γ`) coordinate system, with `β_raw = Q · γ` recovering the raw
    /// pre-rotation parameterization. `None` means either no joint null
    /// space (penalty already full-rank) or rotation was suppressed —
    /// suppression fires when the smooth carries shape constraints
    /// (lower bounds or local linear inequalities) that would lose their
    /// cone geometry under a general orthogonal rotation.
    ///
    /// Prediction-side replay: callers building a new-data design `X_new_raw`
    /// from the *raw* basis must call [`SmoothTerm::apply_rotation_to_predict`]
    /// (or equivalent) to obtain `X_new = X_new_raw · Q` matching this
    /// term's coefficient system.
    ///
    /// Persistence replay: `freeze_term_collection_from_design` copies this
    /// rotation into `SmoothTermSpec`, which is serialized with fitted-model
    /// payloads and reused by the predict-time basis builder. Saved models
    /// therefore replay the same `X_new_raw · Q` transform as in-memory
    /// prediction.
    pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
    /// Global-orthogonality transform that `apply_global_smooth_identifiability`
    /// applied to this term's design but could NOT embed into `metadata`
    /// (factor-smooth kinds: `sz` metadata is per-marginal, `fs` metadata has
    /// no transform slot — #978). `freeze_term_collection_from_design` copies
    /// it onto the term's basis spec (`frozen_global_orthogonality`) so the
    /// predict-side rebuild replays it instead of emitting the unresidualized
    /// (wider) design that the fitted coefficients no longer match.
    /// Chart convention is per kind: post-`Q` `Z` for `sz` (the raw rebuild
    /// reapplies `Q` itself, #700), full `Q·Z` chart for `fs`.
    pub unabsorbed_global_orthogonality: Option<Array2<f64>>,
}
833
834impl SmoothTerm {
835    /// Apply the joint-null absorption rotation to a raw new-data design
836    /// matrix, returning `X_new_raw · Q` when this term was rotated at
837    /// fit time, or `X_new_raw` unchanged when no rotation was applied.
838    ///
839    /// Callers in the prediction path: after building the smooth's basis
840    /// at new data via the *raw* basis builder (the same builder used at
841    /// fit time, applied to `x_new` instead of the training rows), call
842    /// this method on the resulting matrix before forming `X · β`. The
843    /// fitted `β` lives in `γ`-coordinates if Q was applied; multiplying
844    /// the un-rotated `X_new_raw` by `β` would give a wrong η.
845    ///
846    /// Returns an error if the raw design's column count does not match
847    /// the rotation's `p_local`. The width invariant must hold: the raw
848    /// basis builder MUST emit the same `p_local` columns that the
849    /// fit-time builder did, and the rotation is `(p_local × p_local)`.
850    pub fn apply_rotation_to_predict(
851        &self,
852        x_new_raw: Array2<f64>,
853    ) -> Result<Array2<f64>, BasisError> {
854        let Some(rot) = self.joint_null_rotation.as_ref() else {
855            return Ok(x_new_raw);
856        };
857        let p_local = rot.rotation.nrows();
858        if x_new_raw.ncols() != p_local {
859            crate::bail_dim_basis!(
860                "joint-null rotation replay for term '{}': raw design has {} columns, \
861                 rotation expects {} (the raw basis builder must emit the same column \
862                 count as at fit time)",
863                self.name,
864                x_new_raw.ncols(),
865                p_local,
866            );
867        }
868        Ok(gam_linalg::faer_ndarray::fast_ab(
869            &x_new_raw,
870            &rot.rotation,
871        ))
872    }
873
874    /// Dimension of the **joint** null space of this term's active penalties:
875    /// the coefficient directions penalized by *no* penalty. The smooth-component
876    /// Wald test ([`crate::inference::smooth_test::wood_smooth_test`]) treats this
877    /// many leading coefficients as genuine unpenalized fixed effects and tests
878    /// them at full rank; the remainder is the penalized sub-block tested with a
879    /// rank-`≈EDF` truncated pseudo-inverse.
880    ///
881    /// Because every penalty block `S_k` is positive semi-definite,
882    /// `vᵀ(Σ_k S_k)v = Σ_k vᵀ S_k v = 0` iff `S_k v = 0` for *every* `k`; the
883    /// joint null space is therefore exactly `null(Σ_k S_k)`, of dimension
884    /// `p_local − rank(Σ_k S_k)`. This is the **intersection** of the per-penalty
885    /// null spaces, not their sum.
886    ///
887    /// Summing the per-penalty `nullspace_dims` instead (the historical defect
888    /// behind #1360) *unions* the null spaces and badly over-counts: a
889    /// double-penalty smooth carries a bending penalty (null space = its
890    /// polynomial part) plus a complementary null-space ridge (which penalizes
891    /// exactly that polynomial part), so the two null spaces are disjoint and the
892    /// joint null space is empty — yet the per-penalty dims sum to nearly
893    /// `p_local`. Feeding that inflated count to the Wald test makes it test
894    /// almost the whole shrunk block at full rank, manufacturing overwhelming
895    /// "significance" for a term the fit drove to ~0 EDF.
896    pub fn wald_unpenalized_dim(&self) -> usize {
897        joint_unpenalized_dim(
898            self.coeff_range.len(),
899            &self.penalties_local,
900            &self.nullspace_dims,
901        )
902    }
903}
904
905/// Numeric core of [`SmoothTerm::wald_unpenalized_dim`]: the dimension of the
906/// joint null space `∩_k null(S_k) = null(Σ_k S_k)` of a term's local penalty
907/// blocks, with a conservative fallback when a penalty is not materialized as a
908/// full `p_local × p_local` matrix (e.g. a Kronecker tensor factor).
909pub fn joint_unpenalized_dim(
910    p_local: usize,
911    penalties_local: &[Array2<f64>],
912    nullspace_dims: &[usize],
913) -> usize {
914    use gam_linalg::faer_ndarray::FaerEigh;
915    if p_local == 0 {
916        return 0;
917    }
918    if penalties_local.is_empty() {
919        // No penalty ⇒ a wholly unpenalized (fixed-effect) block.
920        return p_local;
921    }
922    // Sum the penalties that are materialized as full `p_local × p_local`
923    // blocks (the common smooth case). The covariance block the Wald test
924    // slices lives in this same coefficient basis (post joint-null rotation),
925    // so the rank is computed in the right metric.
926    let mut s_total = Array2::<f64>::zeros((p_local, p_local));
927    let mut materialized = 0usize;
928    for s in penalties_local {
929        if s.nrows() == p_local && s.ncols() == p_local {
930            s_total += s;
931            materialized += 1;
932        }
933    }
934    if materialized == penalties_local.len() {
935        let symmetric = {
936            let transpose = s_total.t().to_owned();
937            (&s_total + &transpose) * 0.5
938        };
939        if let Ok((evals, _)) = symmetric.eigh(faer::Side::Lower) {
940            let max_abs = evals.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
941            if max_abs == 0.0 {
942                // All penalties identically zero ⇒ unpenalized block.
943                return p_local;
944            }
945            let tol = max_abs * (p_local as f64) * 1e-12;
946            let rank = evals.iter().filter(|&&v| v > tol).count();
947            return p_local.saturating_sub(rank);
948        }
949    }
950    // Conservative fallback when a penalty is not a materialized full block
951    // (e.g. a Kronecker tensor factor): with ≥2 active penalties the joint
952    // null space is almost always empty (the only over-rejecting direction);
953    // with a single penalty it is exactly that penalty's own null space.
954    if penalties_local.len() >= 2 {
955        0
956    } else {
957        nullspace_dims
958            .iter()
959            .copied()
960            .min()
961            .unwrap_or(0)
962            .min(p_local)
963    }
964}
965
/// Descriptive record for one penalty block that is part of a smooth design.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PenaltyBlockInfo {
    /// Index of this penalty.
    /// NOTE(review): presumably its position in the assembled (global)
    /// penalty list — confirm against the design builder.
    pub global_index: usize,
    /// Name of the term that owns the penalty, when known.
    pub termname: Option<String>,
    /// Descriptive info for the penalty itself.
    pub penalty: PenaltyInfo,
}
972
/// Descriptive record for a penalty block that was dropped; unlike
/// [`PenaltyBlockInfo`] it carries no `global_index`.
/// NOTE(review): the reasons a penalty gets dropped are not visible here —
/// see the design builder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DroppedPenaltyBlockInfo {
    /// Name of the term that owned the dropped penalty, when known.
    pub termname: Option<String>,
    /// Descriptive info for the dropped penalty.
    pub penalty: PenaltyInfo,
}
978
/// Assembled smooth block: per-term design matrices, block-local penalties,
/// penalty bookkeeping, and optional coefficient constraints.
#[derive(Debug, Clone)]
pub struct SmoothDesign {
    /// One design matrix per smooth term; the smooth block is their
    /// column-wise concatenation.
    pub term_designs: Vec<DesignMatrix>,
    /// Per-term block-local penalties.  Each `col_range` is relative to the
    /// smooth block (i.e. indexing into the concatenation of `term_designs`).
    pub penalties: Vec<BlockwisePenalty>,
    /// Null-space dimensions of the penalties.
    /// NOTE(review): presumably aligned index-wise with `penalties` — confirm.
    pub nullspace_dims: Vec<usize>,
    /// Descriptive info for the penalties that were kept.
    pub penaltyinfo: Vec<PenaltyBlockInfo>,
    /// Descriptive info for the penalties that were dropped.
    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    /// Realized smooth terms.
    pub terms: Vec<SmoothTerm>,
    /// Optional smooth-block lower bounds in smooth coefficient coordinates.
    /// Length equals `total_smooth_cols()` when present.
    pub coefficient_lower_bounds: Option<Array1<f64>>,
    /// Optional smooth-block inequality constraints:
    /// `A_smooth * beta_smooth >= b`.
    pub linear_constraints: Option<LinearInequalityConstraints>,
}
996
997impl SmoothDesign {
998    pub fn total_smooth_cols(&self) -> usize {
999        self.term_designs.iter().map(DesignMatrix::ncols).sum()
1000    }
1001    pub fn nrows(&self) -> usize {
1002        self.term_designs.first().map_or(0, DesignMatrix::nrows)
1003    }
1004}
1005
/// Smooth block with the same field layout as [`SmoothDesign`]; converts into
/// it field-for-field via `From<RawSmoothDesign> for SmoothDesign`.
/// NOTE(review): "raw" presumably means "before global identifiability /
/// post-processing" — confirm against the design builder.
#[derive(Debug, Clone)]
pub struct RawSmoothDesign {
    /// One design matrix per smooth term; the smooth block is their
    /// column-wise concatenation.
    pub term_designs: Vec<DesignMatrix>,
    /// Per-term block-local penalties.  Each `col_range` is relative to the
    /// smooth block (i.e. indexing into the concatenation of `term_designs`).
    pub penalties: Vec<BlockwisePenalty>,
    /// Null-space dimensions of the penalties.
    /// NOTE(review): presumably aligned index-wise with `penalties` — confirm.
    pub nullspace_dims: Vec<usize>,
    /// Descriptive info for the penalties that were kept.
    pub penaltyinfo: Vec<PenaltyBlockInfo>,
    /// Descriptive info for the penalties that were dropped.
    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    /// Realized smooth terms.
    pub terms: Vec<SmoothTerm>,
    /// Optional smooth-block lower bounds in smooth coefficient coordinates.
    pub coefficient_lower_bounds: Option<Array1<f64>>,
    /// Optional smooth-block inequality constraints.
    pub linear_constraints: Option<LinearInequalityConstraints>,
}
1019
1020impl RawSmoothDesign {
1021    pub fn total_smooth_cols(&self) -> usize {
1022        self.term_designs.iter().map(DesignMatrix::ncols).sum()
1023    }
1024    pub fn nrows(&self) -> usize {
1025        self.term_designs.first().map_or(0, DesignMatrix::nrows)
1026    }
1027}
1028
1029impl From<RawSmoothDesign> for SmoothDesign {
1030    fn from(value: RawSmoothDesign) -> Self {
1031        Self {
1032            term_designs: value.term_designs,
1033            penalties: value.penalties,
1034            nullspace_dims: value.nullspace_dims,
1035            penaltyinfo: value.penaltyinfo,
1036            dropped_penaltyinfo: value.dropped_penaltyinfo,
1037            terms: value.terms,
1038            coefficient_lower_bounds: value.coefficient_lower_bounds,
1039            linear_constraints: value.linear_constraints,
1040        }
1041    }
1042}
1043
/// Prior attached to a bounded linear coefficient
/// (see `LinearCoefficientGeometry::Bounded`). Defaults to `None`.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub enum BoundedCoefficientPriorSpec {
    /// No prior (the default).
    #[default]
    None,
    /// Uniform prior.
    /// NOTE(review): presumably uniform over the `[min, max]` interval —
    /// confirm.
    Uniform,
    /// Beta prior with shape parameters `a` and `b`.
    /// `TermCollectionSpec::validate_frozen` rejects non-finite values and
    /// values below `1.0` for either parameter.
    Beta {
        a: f64,
        b: f64,
    },
}
1054
/// Geometry of a linear coefficient: free, or confined to an interval.
/// Defaults to `Unconstrained`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum LinearCoefficientGeometry {
    /// The coefficient is unconstrained (the default).
    #[default]
    Unconstrained,
    /// The coefficient is confined to the interval from `min` to `max`.
    /// `TermCollectionSpec::validate_frozen` requires both bounds to be
    /// finite with `min < max`. `prior` defaults to
    /// `BoundedCoefficientPriorSpec::None` when absent from a serialized
    /// payload.
    Bounded {
        min: f64,
        max: f64,
        #[serde(default)]
        prior: BoundedCoefficientPriorSpec,
    },
}
1066
/// Serializable specification of one parametric (linear) term: a single
/// design column realized as a product of feature columns, optionally gated
/// by categorical levels, with an optional ridge and coefficient constraints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearTermSpec {
    /// Term name; quoted in validation and bounds error messages.
    pub name: String,
    /// Primary feature column index. For Wilkinson-Rogers `:` interaction
    /// terms (`a:b[:c...]`) this is the first column in `feature_cols`; the
    /// realized design column is the elementwise product across every entry
    /// of `feature_cols`. Plain (non-interaction) linear terms set
    /// `feature_cols == vec![feature_col]`.
    pub feature_col: usize,
    /// Full list of columns whose elementwise product yields this term's
    /// design column. `len() >= 1`; `len() == 1` is a plain linear effect.
    /// May be empty in legacy payloads (serde default), in which case
    /// `effective_feature_cols` backfills from `feature_col`.
    #[serde(default)]
    pub feature_cols: Vec<usize>,
    /// Categorical-level gates for a factor-aware `:` interaction.
    ///
    /// Each `(col, level_bits)` multiplies the realized design column by the
    /// indicator `1[data[row, col].to_bits() == level_bits]`. This is how a
    /// `factor:x` (or `factor:factor`) interaction is expanded: `build_termspec`
    /// emits one `LinearTermSpec` per surviving cell of the categorical
    /// operand(s) (treatment-coded, first level dropped per factor), each
    /// carrying the numeric operands in `feature_cols` and the cell's level
    /// gate(s) here. Empty for a plain numeric `:` interaction or main effect,
    /// in which case the realized column is exactly the numeric product.
    #[serde(default)]
    pub categorical_levels: Vec<(usize, u64)>,
    /// Optional ridge (`S = I`, REML-selected `λ`) on this linear coefficient.
    /// A parametric linear term carries no wiggliness, so it is **unpenalized by
    /// default** — gam reports the MLE, matching mgcv/glm/survreg/VGAM (which
    /// penalize parametric terms only under an explicit `paraPen`). Set `true`
    /// to opt into an explicit shrinkage ridge (a zero-mean Gaussian prior
    /// `β ~ N(0, λ⁻¹)`); doing so adds one outer REML smoothing coordinate.
    #[serde(default = "default_linear_term_double_penalty")]
    pub double_penalty: bool,
    /// Coefficient geometry (unconstrained or bounded with an optional
    /// prior); defaults to `LinearCoefficientGeometry::Unconstrained`.
    #[serde(default)]
    pub coefficient_geometry: LinearCoefficientGeometry,
    /// Optional lower coefficient constraint. `validate_frozen` requires it
    /// to be finite and, when `coefficient_max` is also set, not above it.
    #[serde(default)]
    pub coefficient_min: Option<f64>,
    /// Optional upper coefficient constraint. `validate_frozen` requires it
    /// to be finite and, when `coefficient_min` is also set, not below it.
    #[serde(default)]
    pub coefficient_max: Option<f64>,
}
1107
1108impl LinearTermSpec {
1109    /// Return the effective list of feature columns. Backfills from
1110    /// `feature_col` for legacy specs that predate the multi-column field.
1111    pub fn effective_feature_cols(&self) -> Vec<usize> {
1112        if self.feature_cols.is_empty() {
1113            vec![self.feature_col]
1114        } else {
1115            self.feature_cols.clone()
1116        }
1117    }
1118
1119    /// True when this term is a Wilkinson-Rogers `:` interaction (multi-col).
1120    pub fn is_interaction(&self) -> bool {
1121        self.feature_cols.len() > 1 || !self.categorical_levels.is_empty()
1122    }
1123
1124    /// Realize this linear term's `(n,)` design column from `data`.
1125    ///
1126    /// The column is the elementwise product of every numeric feature column
1127    /// (`effective_feature_cols`) gated by the categorical-level indicators in
1128    /// `categorical_levels`: each `(col, level_bits)` multiplies the running
1129    /// column by `1[data[row, col].to_bits() == level_bits]`. A plain numeric
1130    /// term (no `categorical_levels`) reduces to the bare product, matching the
1131    /// historical behaviour. A pure categorical interaction (empty
1132    /// `feature_cols`, non-empty `categorical_levels`) reduces to the cell
1133    /// indicator. Bounds are validated here; the returned column has length
1134    /// `data.nrows()`.
1135    pub fn realized_design_column(&self, data: ArrayView2<'_, f64>) -> Result<Array1<f64>, String> {
1136        let n = data.nrows();
1137        let p = data.ncols();
1138        let bounds = |col: usize| -> Result<(), String> {
1139            if col >= p {
1140                Err(format!(
1141                    "linear term '{}' feature column {} out of bounds for {} columns",
1142                    self.name, col, p
1143                ))
1144            } else {
1145                Ok(())
1146            }
1147        };
1148
1149        // Numeric operands. When `categorical_levels` is set we treat
1150        // `feature_cols` as the (possibly empty) numeric operand list and start
1151        // from a column of ones; otherwise we preserve the legacy backfill from
1152        // `feature_col` so a plain term with no `feature_cols` still resolves.
1153        let mut column = if self.categorical_levels.is_empty() {
1154            let cols = self.effective_feature_cols();
1155            for &c in &cols {
1156                bounds(c)?;
1157            }
1158            let mut acc = data.column(cols[0]).to_owned();
1159            for &c in cols.iter().skip(1) {
1160                acc *= &data.column(c);
1161            }
1162            acc
1163        } else {
1164            let mut acc = Array1::<f64>::ones(n);
1165            for &c in &self.feature_cols {
1166                bounds(c)?;
1167                acc *= &data.column(c);
1168            }
1169            acc
1170        };
1171
1172        for &(col, level_bits) in &self.categorical_levels {
1173            bounds(col)?;
1174            let gate = data.column(col);
1175            for (out, &v) in column.iter_mut().zip(gate.iter()) {
1176                if v.to_bits() != level_bits {
1177                    *out = 0.0;
1178                }
1179            }
1180        }
1181
1182        Ok(column)
1183    }
1184}
1185
/// Serde default for `LinearTermSpec::double_penalty`: `false`.
///
/// Parametric/linear terms are unpenalized by default. A single linear
/// coefficient has no roughness for a smoothing penalty to control, so the
/// historical `S = I`, REML-selected `λ` shrank every linear coefficient off
/// the MLE and injected a spurious outer smoothing coordinate (#749). Mature
/// tools (mgcv/glm/survreg/VGAM) leave parametric terms unpenalized; gam now
/// matches that and reports the MLE. An explicit `double_penalty = true`
/// still opts a term into a ridge.
pub const fn default_linear_term_double_penalty() -> bool {
    false
}
1196
/// Default PCA smooth penalty: the constant `1.0`.
/// NOTE(review): presumably wired in as a serde default on a PCA basis spec —
/// confirm at the use site.
pub const fn default_pca_smooth_penalty() -> f64 {
    1.0_f64
}
1200
/// Default PCA chunk size: the constant `4096`.
/// NOTE(review): presumably the number of rows processed per chunk, wired in
/// as a serde default — confirm at the use site.
pub const fn default_pca_chunk_size() -> usize {
    4_096
}
1204
/// Random-effects term specification.
///
/// The selected feature column is interpreted as a categorical grouping variable.
/// The term contributes a one-hot dummy block with an identity penalty on group
/// coefficients, equivalent to i.i.d. Gaussian random effects. When
/// `penalized` is `false` the block is left unpenalized and acts as a fixed
/// categorical main effect instead.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomEffectTermSpec {
    /// Term name.
    pub name: String,
    /// Data column index of the grouping variable.
    pub feature_col: usize,
    /// If true, drop the lexicographically first group level to use treatment coding.
    /// If false, keep all levels (full one-hot block, still identifiable under ridge).
    pub drop_first_level: bool,
    /// If true, add a ridge penalty and estimate this block as a random effect.
    /// If false, leave the one-hot/treatment-coded block unpenalized so it is a
    /// fixed categorical main effect.  The default preserves older saved models:
    /// payloads without this field deserialize via
    /// `default_random_effect_penalized`, which returns `true`.
    #[serde(default = "default_random_effect_penalized")]
    pub penalized: bool,
    /// Optional fixed kept-level set (sorted by f64 bit pattern) captured at fit time.
    /// When present, prediction uses exactly these columns to avoid design drift.
    #[serde(default)]
    pub frozen_levels: Option<Vec<u64>>,
}
1227
/// Serde default for `RandomEffectTermSpec::penalized`: `true`.
///
/// Payloads that omit the field therefore deserialize as a penalized
/// (random-effect) block.
///
/// Declared `const fn` like the other serde-default helpers in this file
/// (`default_linear_term_double_penalty`, `default_pca_smooth_penalty`,
/// `default_pca_chunk_size`), so it can also be evaluated in const contexts.
/// Existing callers, including `#[serde(default = "...")]`, are unaffected.
pub const fn default_random_effect_penalized() -> bool {
    true
}
1231
1232pub fn validate_measure_jet_positive_vec_len(
1233    label: &str,
1234    term_name: &str,
1235    field: &str,
1236    values: &[f64],
1237    expected: usize,
1238) -> Result<(), String> {
1239    if values.len() != expected {
1240        return Err(SmoothError::invalid_config(format!(
1241            "{label} term '{term_name}' frozen MeasureJet {field} has length {}, expected {expected}",
1242            values.len()
1243        ))
1244        .into());
1245    }
1246    if values
1247        .iter()
1248        .any(|value| !(value.is_finite() && *value > 0.0))
1249    {
1250        return Err(SmoothError::invalid_config(format!(
1251            "{label} term '{term_name}' frozen MeasureJet {field} values must be positive and finite"
1252        ))
1253        .into());
1254    }
1255    Ok(())
1256}
1257
/// Serializable specification of a model's terms, grouped by kind:
/// linear terms, random-effect terms, and smooth terms.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TermCollectionSpec {
    /// Parametric (linear) terms.
    pub linear_terms: Vec<LinearTermSpec>,
    /// Random-effect (categorical grouping) terms.
    pub random_effect_terms: Vec<RandomEffectTermSpec>,
    /// Smooth terms.
    pub smooth_terms: Vec<SmoothTermSpec>,
}
1264
1265pub fn validate_smooth_basis_frozen(
1266    basis: &SmoothBasisSpec,
1267    label: &str,
1268    term_name: &str,
1269) -> Result<(), String> {
1270    match basis {
1271        SmoothBasisSpec::ByVariable { inner, .. }
1272        | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
1273            validate_smooth_basis_frozen(inner, label, term_name)
1274        }
1275        SmoothBasisSpec::BSpline1D { spec, .. } => {
1276            if !matches!(
1277                spec.knotspec,
1278                BSplineKnotSpec::Provided(_)
1279                    | BSplineKnotSpec::PeriodicUniform { .. }
1280                    | BSplineKnotSpec::NaturalCubicRegression { .. }
1281            ) {
1282                return Err(format!(
1283                    "{label} term '{term_name}' is not frozen: BSpline knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression"
1284                ));
1285            }
1286            Ok(())
1287        }
1288        SmoothBasisSpec::ThinPlate { spec, .. } => {
1289            if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1290                return Err(format!(
1291                    "{label} term '{term_name}' is not frozen: ThinPlate centers must be UserProvided"
1292                ));
1293            }
1294            if matches!(
1295                spec.identifiability,
1296                SpatialIdentifiability::OrthogonalToParametric
1297            ) {
1298                return Err(format!(
1299                    "{label} term '{term_name}' is not frozen: ThinPlate identifiability must be FrozenTransform or None"
1300                ));
1301            }
1302            Ok(())
1303        }
1304        _ => Ok(()),
1305    }
1306}
1307
1308impl TermCollectionSpec {
1309    /// Write this collection's topology identity into a warm-start cache
1310    /// fingerprint (#869).
1311    ///
1312    /// The persistent warm-start `cache_key` hashes only family + raw input
1313    /// dimensions, so two fits on the same data that differ *only* in their
1314    /// smooth topology (the `s(..., type=AUTO)` candidate enumeration: sphere
1315    /// vs torus vs euclidean vs duchon) collide on one key and seed each other
1316    /// with geometrically incompatible β/ρ. Folding the per-term structural
1317    /// kind + feature columns + linear/random-effect counts into the shape hash
1318    /// gives each candidate its own key, so the screen→full-refit reuse of one
1319    /// candidate is preserved while cross-candidate contamination is removed.
1320    /// Only the structural identity is hashed (not fitted coefficients or
1321    /// frozen knot values), so a refit of the *same* topology still hits.
1322    pub fn write_structural_shape_hash(&self, h: &mut gam_runtime::warm_start::Fingerprinter) {
1323        h.write_str("term-collection");
1324        h.write_usize(self.linear_terms.len());
1325        for linear in &self.linear_terms {
1326            h.write_str(&linear.name);
1327        }
1328        h.write_usize(self.random_effect_terms.len());
1329        h.write_usize(self.smooth_terms.len());
1330        for smooth in &self.smooth_terms {
1331            h.write_str(&smooth.name);
1332            h.write_str(smooth.basis.structural_kind());
1333            for col in smooth.basis.structural_feature_cols() {
1334                h.write_usize(col);
1335            }
1336        }
1337    }
1338
1339    /// Validate that a term collection spec represents a fully frozen model
1340    /// (i.e. all knots/centers are pre-computed, identifiability transforms are
1341    /// baked in, and random-effect levels are fixed).
1342    pub fn validate_frozen(&self, label: &str) -> Result<(), String> {
1343        for linear in &self.linear_terms {
1344            if let (Some(min), Some(max)) = (linear.coefficient_min, linear.coefficient_max)
1345                && (!min.is_finite() || !max.is_finite() || min > max)
1346            {
1347                return Err(SmoothError::invalid_config(format!(
1348                    "{label} linear term '{}' has invalid coefficient constraint [{min}, {max}]",
1349                    linear.name
1350                ))
1351                .into());
1352            }
1353            if let Some(min) = linear.coefficient_min
1354                && !min.is_finite()
1355            {
1356                return Err(SmoothError::invalid_config(format!(
1357                    "{label} linear term '{}' has non-finite coefficient minimum {min}",
1358                    linear.name
1359                ))
1360                .into());
1361            }
1362            if let Some(max) = linear.coefficient_max
1363                && !max.is_finite()
1364            {
1365                return Err(SmoothError::invalid_config(format!(
1366                    "{label} linear term '{}' has non-finite coefficient maximum {max}",
1367                    linear.name
1368                ))
1369                .into());
1370            }
1371            if let LinearCoefficientGeometry::Bounded { min, max, prior } =
1372                &linear.coefficient_geometry
1373            {
1374                if !min.is_finite() || !max.is_finite() || min >= max {
1375                    return Err(SmoothError::invalid_config(format!(
1376                        "{label} bounded term '{}' has invalid bounds [{min}, {max}]",
1377                        linear.name
1378                    ))
1379                    .into());
1380                }
1381                match prior {
1382                    BoundedCoefficientPriorSpec::None | BoundedCoefficientPriorSpec::Uniform => {}
1383                    BoundedCoefficientPriorSpec::Beta { a, b } => {
1384                        if !a.is_finite() || !b.is_finite() || *a < 1.0 || *b < 1.0 {
1385                            return Err(SmoothError::invalid_config(format!(
1386                                "{label} bounded term '{}' has invalid Beta prior ({a}, {b})",
1387                                linear.name
1388                            ))
1389                            .into());
1390                        }
1391                    }
1392                }
1393            }
1394        }
1395        for st in &self.smooth_terms {
1396            match &st.basis {
1397                SmoothBasisSpec::ByVariable { inner, .. } => {
1398                    validate_smooth_basis_frozen(inner, label, &st.name)?;
1399                    let nested = SmoothTermSpec {
1400                        name: st.name.clone(),
1401                        basis: (**inner).clone(),
1402                        shape: st.shape,
1403                        joint_null_rotation: None,
1404                    };
1405                    TermCollectionSpec {
1406                        linear_terms: Vec::new(),
1407                        random_effect_terms: Vec::new(),
1408                        smooth_terms: vec![nested],
1409                    }
1410                    .validate_frozen(label)?;
1411                }
1412                SmoothBasisSpec::FactorSumToZero { inner, levels, .. } => {
1413                    if levels.len() < 2 {
1414                        return Err(format!(
1415                            "{label} term '{}' has invalid frozen sz levels",
1416                            st.name
1417                        ));
1418                    }
1419                    validate_smooth_basis_frozen(inner, label, &st.name)?;
1420                }
1421                SmoothBasisSpec::BSpline1D { spec, .. } => {
1422                    if !matches!(
1423                        spec.knotspec,
1424                        BSplineKnotSpec::Provided(_)
1425                            | BSplineKnotSpec::PeriodicUniform { .. }
1426                            | BSplineKnotSpec::NaturalCubicRegression { .. }
1427                    ) {
1428                        return Err(SmoothError::invalid_config(format!(
1429                            "{label} term '{}' is not frozen: BSpline knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression",
1430                            st.name
1431                        ))
1432                        .into());
1433                    }
1434                }
1435                SmoothBasisSpec::ThinPlate { spec, .. } => {
1436                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1437                        return Err(SmoothError::invalid_config(format!(
1438                            "{label} term '{}' is not frozen: ThinPlate centers must be UserProvided",
1439                            st.name
1440                        ))
1441                        .into());
1442                    }
1443                    if matches!(
1444                        spec.identifiability,
1445                        SpatialIdentifiability::OrthogonalToParametric
1446                    ) {
1447                        return Err(SmoothError::invalid_config(format!(
1448                            "{label} term '{}' is not frozen: ThinPlate identifiability must be FrozenTransform or None",
1449                            st.name
1450                        ))
1451                        .into());
1452                    }
1453                }
1454                SmoothBasisSpec::Sphere { spec, .. } => {
1455                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1456                        return Err(SmoothError::invalid_config(format!(
1457                            "{label} term '{}' is not frozen: Sphere centers must be UserProvided",
1458                            st.name
1459                        ))
1460                        .into());
1461                    }
1462                    if matches!(spec.method, crate::basis::SphereMethod::Harmonic)
1463                        && spec.max_degree.is_none_or(|d| d == 0)
1464                    {
1465                        return Err(format!(
1466                            "{label} term '{}' is not frozen: sphere max_degree must be positive",
1467                            st.name
1468                        ));
1469                    }
1470                }
1471                SmoothBasisSpec::ConstantCurvature { spec, .. } => {
1472                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1473                        return Err(SmoothError::invalid_config(format!(
1474                            "{label} term '{}' is not frozen: ConstantCurvature centers must be UserProvided",
1475                            st.name
1476                        ))
1477                        .into());
1478                    }
1479                    if !(spec.length_scale.is_finite() && spec.length_scale > 0.0) {
1480                        return Err(SmoothError::invalid_config(format!(
1481                            "{label} term '{}' is not frozen: ConstantCurvature length_scale must be the realized positive value",
1482                            st.name
1483                        ))
1484                        .into());
1485                    }
1486                }
1487                SmoothBasisSpec::MeasureJet { spec, .. } => {
1488                    let centers = match &spec.center_strategy {
1489                        CenterStrategy::UserProvided(centers) => centers,
1490                        _ => {
1491                            return Err(SmoothError::invalid_config(format!(
1492                                "{label} term '{}' is not frozen: MeasureJet centers must be UserProvided",
1493                                st.name
1494                            ))
1495                            .into());
1496                        }
1497                    };
1498                    if centers.nrows() == 0 {
1499                        return Err(SmoothError::invalid_config(format!(
1500                            "{label} term '{}' is not frozen: MeasureJet centers are empty",
1501                            st.name
1502                        ))
1503                        .into());
1504                    }
1505                    if !(spec.length_scale.is_finite() && spec.length_scale > 0.0) {
1506                        return Err(SmoothError::invalid_config(format!(
1507                            "{label} term '{}' is not frozen: MeasureJet length_scale must be the realized positive value",
1508                            st.name
1509                        ))
1510                        .into());
1511                    }
1512                    // Exact replay needs the fit-data penalty quadrature and
1513                    // normalization payload (`BasisMetadata::MeasureJet`).
1514                    let frozen = spec.frozen_quadrature.as_ref().ok_or_else(|| {
1515                        SmoothError::invalid_config(format!(
1516                            "{label} term '{}' is not frozen: MeasureJet frozen_quadrature payload is missing",
1517                            st.name
1518                        ))
1519                    })?;
1520                    if frozen.masses.len() != centers.nrows() {
1521                        return Err(SmoothError::invalid_config(format!(
1522                            "{label} term '{}' frozen MeasureJet has {} masses for {} centers",
1523                            st.name,
1524                            frozen.masses.len(),
1525                            centers.nrows()
1526                        ))
1527                        .into());
1528                    }
1529                    let total_mass = frozen.masses.sum();
1530                    if frozen
1531                        .masses
1532                        .iter()
1533                        .any(|mass| !(mass.is_finite() && *mass >= 0.0))
1534                        || !(total_mass.is_finite() && total_mass > 0.0)
1535                    {
1536                        return Err(SmoothError::invalid_config(format!(
1537                            "{label} term '{}' frozen MeasureJet masses must be finite, nonnegative, and have positive total mass",
1538                            st.name
1539                        ))
1540                        .into());
1541                    }
1542                    let n_levels = frozen.eps_band.len();
1543                    if n_levels == 0
1544                        || frozen
1545                            .eps_band
1546                            .iter()
1547                            .any(|eps| !(eps.is_finite() && *eps > 0.0))
1548                    {
1549                        return Err(SmoothError::invalid_config(format!(
1550                            "{label} term '{}' frozen MeasureJet eps_band must be nonempty, finite, and positive",
1551                            st.name
1552                        ))
1553                        .into());
1554                    }
1555                    for (idx, pair) in frozen.eps_band.windows(2).enumerate() {
1556                        if pair[1] <= pair[0] {
1557                            return Err(SmoothError::invalid_config(format!(
1558                                "{label} term '{}' frozen MeasureJet eps_band is not strictly ascending at {idx}: {} then {}",
1559                                st.name,
1560                                pair[0],
1561                                pair[1]
1562                            ))
1563                            .into());
1564                        }
1565                    }
1566                    validate_measure_jet_positive_vec_len(
1567                        label,
1568                        &st.name,
1569                        "support_means",
1570                        &frozen.support_means,
1571                        n_levels,
1572                    )?;
1573                    // Mode predicate MUST match the builder's
1574                    // (`measure_jet_multiscale_mode`): per-level/multiscale is the
1575                    // explicit `spec.multiscale` opt-in (#1116). In single-scale
1576                    // mode the builder emits a single FUSED penalty (empty
1577                    // per-level scales + `fused_penalty_normalization_scale:
1578                    // Some`); only the multiscale opt-in carries `n_levels`
1579                    // per-level scales.
1580                    let per_level = crate::basis::measure_jet_multiscale_mode(spec);
1581                    if per_level {
1582                        validate_measure_jet_positive_vec_len(
1583                            label,
1584                            &st.name,
1585                            "penalty_normalization_scales",
1586                            &frozen.penalty_normalization_scales,
1587                            n_levels,
1588                        )?;
1589                        validate_measure_jet_positive_vec_len(
1590                            label,
1591                            &st.name,
1592                            "raw_penalty_normalization_scales",
1593                            &frozen.raw_penalty_normalization_scales,
1594                            n_levels,
1595                        )?;
1596                        if frozen.fused_penalty_normalization_scale.is_some() {
1597                            return Err(SmoothError::invalid_config(format!(
1598                                "{label} term '{}' per-level MeasureJet must not carry a fused penalty normalization scale",
1599                                st.name
1600                            ))
1601                            .into());
1602                        }
1603                    } else {
1604                        if !frozen.penalty_normalization_scales.is_empty()
1605                            || !frozen.raw_penalty_normalization_scales.is_empty()
1606                        {
1607                            return Err(SmoothError::invalid_config(format!(
1608                                "{label} term '{}' fused MeasureJet must not carry per-level penalty normalization scales",
1609                                st.name
1610                            ))
1611                            .into());
1612                        }
1613                        match frozen.fused_penalty_normalization_scale {
1614                            Some(scale) if scale.is_finite() && scale > 0.0 => {}
1615                            Some(scale) => {
1616                                return Err(SmoothError::invalid_config(format!(
1617                                    "{label} term '{}' fused MeasureJet penalty normalization scale must be positive and finite, got {scale}",
1618                                    st.name
1619                                ))
1620                                .into());
1621                            }
1622                            None => {
1623                                return Err(SmoothError::invalid_config(format!(
1624                                    "{label} term '{}' fused MeasureJet is missing its penalty normalization scale",
1625                                    st.name
1626                                ))
1627                                .into());
1628                            }
1629                        }
1630                    }
1631                }
1632                SmoothBasisSpec::Matern { spec, .. } => {
1633                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1634                        return Err(SmoothError::invalid_config(format!(
1635                            "{label} term '{}' is not frozen: Matern centers must be UserProvided",
1636                            st.name
1637                        ))
1638                        .into());
1639                    }
1640                }
1641                SmoothBasisSpec::Duchon { spec, .. } => {
1642                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1643                        return Err(SmoothError::invalid_config(format!(
1644                            "{label} term '{}' is not frozen: Duchon centers must be UserProvided",
1645                            st.name
1646                        ))
1647                        .into());
1648                    }
1649                    if matches!(
1650                        spec.identifiability,
1651                        SpatialIdentifiability::OrthogonalToParametric
1652                    ) {
1653                        return Err(SmoothError::invalid_config(format!(
1654                            "{label} term '{}' is not frozen: Duchon identifiability must be FrozenTransform or None",
1655                            st.name
1656                        ))
1657                        .into());
1658                    }
1659                }
1660                SmoothBasisSpec::Pca {
1661                    centered,
1662                    center_mean,
1663                    pca_basis_path,
1664                    ..
1665                } => {
1666                    if *centered && center_mean.is_none() && pca_basis_path.is_none() {
1667                        return Err(SmoothError::invalid_config(format!(
1668                            "{label} term '{}' is not frozen: centered Pca missing center_mean",
1669                            st.name
1670                        ))
1671                        .into());
1672                    }
1673                }
1674                SmoothBasisSpec::BySmooth { smooth, by_kind } => {
1675                    if let SmoothBasisSpec::BySmooth { .. } = smooth.as_ref() {
1676                        return Err(format!("{label} term '{}' has nested by-smooths", st.name));
1677                    }
1678                    match by_kind {
1679                        ByVarKind::Numeric { .. } => {}
1680                        ByVarKind::Factor { frozen_levels, .. } if frozen_levels.is_none() => {
1681                            return Err(format!(
1682                                "{label} term '{}' is not frozen: by-factor levels missing",
1683                                st.name
1684                            ));
1685                        }
1686                        ByVarKind::Factor { .. } => {}
1687                    }
1688                    let nested = TermCollectionSpec {
1689                        linear_terms: vec![],
1690                        random_effect_terms: vec![],
1691                        smooth_terms: vec![SmoothTermSpec {
1692                            name: st.name.clone(),
1693                            basis: (**smooth).clone(),
1694                            shape: st.shape,
1695                            joint_null_rotation: None,
1696                        }],
1697                    };
1698                    nested.validate_frozen(label)?;
1699                }
1700                SmoothBasisSpec::FactorSmooth { spec } => {
1701                    if spec.group_frozen_levels.is_none() {
1702                        return Err(format!(
1703                            "{label} term '{}' is not frozen: factor-smooth levels missing",
1704                            st.name
1705                        ));
1706                    }
1707                    if !matches!(
1708                        spec.marginal.knotspec,
1709                        BSplineKnotSpec::Provided(_)
1710                            | BSplineKnotSpec::PeriodicUniform { .. }
1711                            // mgcv's `bs="sz"` default marginal is a cubic
1712                            // regression spline (#1074), and the freeze step
1713                            // restores it as a `NaturalCubicRegression` knotspec
1714                            // carrying its `k` value-knots (spatial_optimization.rs
1715                            // `marginal_is_cr` branch) — the SAME treatment the
1716                            // tensor margin already gets in the arm below. Without
1717                            // this variant a frozen `sz` factor smooth fails its own
1718                            // predict-time freeze check ("factor-smooth marginal
1719                            // knots missing") even though its knots are fully
1720                            // materialized; the validation simply was not updated
1721                            // when the cr marginal landed.
1722                            | BSplineKnotSpec::NaturalCubicRegression { .. }
1723                    ) {
1724                        return Err(format!(
1725                            "{label} term '{}' is not frozen: factor-smooth marginal knots missing",
1726                            st.name
1727                        ));
1728                    }
1729                }
1730                SmoothBasisSpec::TensorBSpline { spec, .. } => {
1731                    for (dim, marginal) in spec.marginalspecs.iter().enumerate() {
1732                        if !matches!(
1733                            marginal.knotspec,
1734                            BSplineKnotSpec::Provided(_)
1735                                | BSplineKnotSpec::PeriodicUniform { .. }
1736                                | BSplineKnotSpec::NaturalCubicRegression { .. }
1737                        ) {
1738                            return Err(SmoothError::invalid_config(format!(
1739                                "{label} term '{}' dim {} is not frozen: tensor marginal knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression",
1740                                st.name, dim
1741                            ))
1742                            .into());
1743                        }
1744                    }
1745                    if matches!(
1746                        spec.identifiability,
1747                        TensorBSplineIdentifiability::SumToZero
1748                            | TensorBSplineIdentifiability::MarginalSumToZero
1749                    ) {
1750                        return Err(SmoothError::invalid_config(format!(
1751                            "{label} term '{}' is not frozen: tensor identifiability must be FrozenTransform or None",
1752                            st.name
1753                        ))
1754                        .into());
1755                    }
1756                }
1757            }
1758        }
1759
1760        for rt in &self.random_effect_terms {
1761            if rt.frozen_levels.is_none() {
1762                return Err(SmoothError::invalid_config(format!(
1763                    "{label} random-effect term '{}' is not frozen: missing frozen_levels",
1764                    rt.name
1765                ))
1766                .into());
1767            }
1768        }
1769
1770        Ok(())
1771    }
1772
1773    /// Re-resolve every stored feature-column index through `remap`, returning a
1774    /// spec that addresses a different column layout.
1775    ///
1776    /// A frozen `TermCollectionSpec` stores feature columns as *absolute indices
1777    /// into the training table*. To replay it on a fresh dataset whose columns
1778    /// sit at different positions — the common case at prediction time, where
1779    /// the response column is unknown and may be absent entirely — every index
1780    /// must be re-resolved against the new layout. `remap` receives each
1781    /// training-table index and returns its position in the runtime table;
1782    /// callers typically implement it as "look the name up in the training
1783    /// headers, then resolve that name against the prediction dataset".
1784    ///
1785    /// This is the single authority on *which* fields carry a column index
1786    /// across every basis variant (linear, random-effect, the `by=` column of
1787    /// `ByVariable`/`FactorSumToZero`/`BySmooth`, the continuous and group
1788    /// columns of a `FactorSmooth`, and the multi-axis `feature_cols` of every
1789    /// spatial/tensor basis), so a predict-time realignment cannot silently miss
1790    /// one and dereference a stale training index.
1791    pub fn remap_feature_columns<E, F>(&self, mut remap: F) -> Result<TermCollectionSpec, E>
1792    where
1793        F: FnMut(usize) -> Result<usize, E>,
1794    {
1795        let mut out = self.clone();
1796        for lt in &mut out.linear_terms {
1797            lt.feature_col = remap(lt.feature_col)?;
1798            // Also remap the full interaction-factor list. The design builder
1799            // (`build_term_collection_design_inner`) materializes the column from
1800            // `effective_feature_cols()` — which returns `feature_cols` whenever
1801            // it is non-empty (i.e. essentially always, including a plain linear
1802            // term where `feature_cols == [feature_col]`). Remapping only the
1803            // singular `feature_col` left these at their saved *training* indices
1804            // at predict time, so a parametric `Surv(...) ~ x` (and any `:`
1805            // interaction) bailed with "feature column N out of bounds" once the
1806            // response/time columns shift the runtime layout (issue #898).
1807            for fc in lt.feature_cols.iter_mut() {
1808                *fc = remap(*fc)?;
1809            }
1810            // A factor-aware `:` interaction also gates on categorical columns;
1811            // those indices live in the same training-time layout and must be
1812            // realigned to the runtime table alongside the numeric operands, or
1813            // the predict-time level indicator would dereference a stale column.
1814            for (col, _bits) in lt.categorical_levels.iter_mut() {
1815                *col = remap(*col)?;
1816            }
1817        }
1818        for rt in &mut out.random_effect_terms {
1819            rt.feature_col = remap(rt.feature_col)?;
1820        }
1821        for st in &mut out.smooth_terms {
1822            remap_smooth_basis_feature_columns(&mut st.basis, &mut remap)?;
1823        }
1824        Ok(out)
1825    }
1826}
1827
1828/// Walk a `SmoothBasisSpec` tree, re-resolving every column index through
1829/// `remap`. Shared by all predict-time column realignment (see
1830/// [`TermCollectionSpec::remap_feature_columns`]); kept exhaustive so a newly
1831/// added index-bearing variant fails to compile until it is handled here.
1832pub fn remap_smooth_basis_feature_columns<E, F>(
1833    basis: &mut SmoothBasisSpec,
1834    remap: &mut F,
1835) -> Result<(), E>
1836where
1837    F: FnMut(usize) -> Result<usize, E>,
1838{
1839    match basis {
1840        SmoothBasisSpec::ByVariable { inner, by_col, .. }
1841        | SmoothBasisSpec::FactorSumToZero { inner, by_col, .. } => {
1842            *by_col = remap(*by_col)?;
1843            remap_smooth_basis_feature_columns(inner, remap)?;
1844        }
1845        SmoothBasisSpec::BSpline1D { feature_col, .. } => {
1846            *feature_col = remap(*feature_col)?;
1847        }
1848        SmoothBasisSpec::BySmooth { smooth, by_kind } => {
1849            let by_feature_col = match by_kind {
1850                ByVarKind::Numeric { feature_col } | ByVarKind::Factor { feature_col, .. } => {
1851                    feature_col
1852                }
1853            };
1854            *by_feature_col = remap(*by_feature_col)?;
1855            remap_smooth_basis_feature_columns(smooth, remap)?;
1856        }
1857        SmoothBasisSpec::FactorSmooth { spec } => {
1858            for fc in spec.continuous_cols.iter_mut() {
1859                *fc = remap(*fc)?;
1860            }
1861            spec.group_col = remap(spec.group_col)?;
1862        }
1863        SmoothBasisSpec::ThinPlate { feature_cols, .. }
1864        | SmoothBasisSpec::Sphere { feature_cols, .. }
1865        | SmoothBasisSpec::ConstantCurvature { feature_cols, .. }
1866        | SmoothBasisSpec::Matern { feature_cols, .. }
1867        | SmoothBasisSpec::MeasureJet { feature_cols, .. }
1868        | SmoothBasisSpec::Duchon { feature_cols, .. }
1869        | SmoothBasisSpec::Pca { feature_cols, .. }
1870        | SmoothBasisSpec::TensorBSpline { feature_cols, .. } => {
1871            for fc in feature_cols.iter_mut() {
1872                *fc = remap(*fc)?;
1873            }
1874        }
1875    }
1876    Ok(())
1877}
1878
/// Structural tag that can be attached to a [`BlockwisePenalty`] (through its
/// `structure_hint` field) to describe how the local penalty matrix is built.
#[derive(Debug, Clone)]
pub enum PenaltyStructureHint {
    /// The local matrix is a scaled identity; the payload is the diagonal
    /// value, as set by `BlockwisePenalty::ridge`.
    Ridge(f64),
    /// The factor matrices supplied to `BlockwisePenalty::kronecker`.
    /// NOTE(review): presumably the Kronecker factors of the local matrix —
    /// the constructor stores them without checking, so confirm with callers.
    Kronecker(Vec<Array2<f64>>),
}
1884
/// A penalty matrix stored at its natural block size together with the
/// column range it occupies in the global coefficient vector.
///
/// Instead of embedding every penalty into a full `p_total × p_total` dense
/// matrix filled with zeros, we keep the compact local matrix and reconstruct
/// the global view only when a downstream consumer explicitly requires it.
///
/// All fields are public, so the shape invariant (`local` is square with side
/// `col_range.len()`) is only checked by the constructors and by `to_global`.
#[derive(Clone)]
pub struct BlockwisePenalty {
    /// Column range in the global coefficient vector that this penalty covers.
    pub col_range: Range<usize>,
    /// The local penalty matrix — dimensions `block_p × block_p` where
    /// `block_p = col_range.len()`.
    pub local: Array2<f64>,
    /// Centering vector for this coefficient block; the constructors
    /// initialise it to `CoefficientPriorMean::Zero`.
    pub prior_mean: gam_problem::CoefficientPriorMean,
    /// Optional structural hint so downstream spectral/logdet code can stay
    /// block-local or factorized without reverse-engineering the matrix.
    pub structure_hint: Option<PenaltyStructureHint>,
    /// Optional operator-form handle bit-equivalent to `local`. Populated when
    /// the originating closed-form factory emitted an op-form penalty so exact
    /// operator algebra can use matvec instead of materializing the dense
    /// `block_p × block_p` Gram. `None` for ordinary dense penalties.
    pub op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
}
1909
1910impl std::fmt::Debug for BlockwisePenalty {
1911    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1912        f.debug_struct("BlockwisePenalty")
1913            .field("col_range", &self.col_range)
1914            .field(
1915                "local",
1916                &format_args!("{}×{}", self.local.nrows(), self.local.ncols()),
1917            )
1918            .field("prior_mean", &self.prior_mean)
1919            .field("structure_hint", &self.structure_hint)
1920            .field("op", &self.op.as_ref().map(|o| o.dim()))
1921            .finish()
1922    }
1923}
1924
1925impl BlockwisePenalty {
1926    /// Create a new blockwise penalty.
1927    pub fn new(col_range: Range<usize>, local: Array2<f64>) -> Self {
1928        assert_eq!(col_range.len(), local.nrows());
1929        assert_eq!(col_range.len(), local.ncols());
1930        Self {
1931            col_range,
1932            local,
1933            prior_mean: gam_problem::CoefficientPriorMean::Zero,
1934            structure_hint: None,
1935            op: None,
1936        }
1937    }
1938
1939    pub fn with_prior_mean(
1940        mut self,
1941        prior_mean: gam_problem::CoefficientPriorMean,
1942    ) -> Self {
1943        self.prior_mean = prior_mean;
1944        self
1945    }
1946
1947    /// Attach an op-form penalty handle bit-equivalent to `local`.
1948    pub fn with_op(
1949        mut self,
1950        op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
1951    ) -> Self {
1952        self.op = op;
1953        self
1954    }
1955
1956    pub fn ridge(col_range: Range<usize>, scale: f64) -> Self {
1957        let block_size = col_range.len();
1958        let mut local = Array2::<f64>::zeros((block_size, block_size));
1959        for i in 0..block_size {
1960            local[[i, i]] = scale;
1961        }
1962        Self {
1963            col_range,
1964            local,
1965            prior_mean: gam_problem::CoefficientPriorMean::Zero,
1966            structure_hint: Some(PenaltyStructureHint::Ridge(scale)),
1967            op: None,
1968        }
1969    }
1970
1971    pub fn kronecker(
1972        col_range: Range<usize>,
1973        local: Array2<f64>,
1974        factors: Vec<Array2<f64>>,
1975    ) -> Self {
1976        assert_eq!(col_range.len(), local.nrows());
1977        assert_eq!(col_range.len(), local.ncols());
1978        Self {
1979            col_range,
1980            local,
1981            prior_mean: gam_problem::CoefficientPriorMean::Zero,
1982            structure_hint: Some(PenaltyStructureHint::Kronecker(factors)),
1983            op: None,
1984        }
1985    }
1986
1987    /// Expand this blockwise penalty into a full `p_total × p_total` dense
1988    /// matrix (mostly zeros). Use sparingly — the whole point of blockwise
1989    /// storage is to avoid this allocation.
1990    pub fn to_global(&self, p_total: usize) -> Array2<f64> {
1991        let mut g = Array2::<f64>::zeros((p_total, p_total));
1992        let r = &self.col_range;
1993        assert!(
1994            r.end <= p_total && self.local.nrows() == r.len() && self.local.ncols() == r.len(),
1995            "BlockwisePenalty::to_global shape invariant violated: \
1996             col_range={}..{}, local={}x{}, p_total={}",
1997            r.start,
1998            r.end,
1999            self.local.nrows(),
2000            self.local.ncols(),
2001            p_total,
2002        );
2003        g.slice_mut(s![r.start..r.end, r.start..r.end])
2004            .assign(&self.local);
2005        g
2006    }
2007
    /// Convert into a blockwise [`gam_problem::PenaltyMatrix`] without
    /// expanding to full dimensions.
    ///
    /// The local block and the column range are cloned into the returned
    /// `Blockwise` variant; `total_dim` is stored as given and is not
    /// validated against `col_range` here.
    pub fn to_penalty_matrix(
        &self,
        total_dim: usize,
    ) -> gam_problem::PenaltyMatrix {
        gam_problem::PenaltyMatrix::Blockwise {
            local: self.local.clone(),
            col_range: self.col_range.clone(),
            total_dim,
        }
    }
2020
    /// The block size of this penalty, i.e. the number of coefficient
    /// columns covered by `col_range`.
    #[inline]
    pub fn block_size(&self) -> usize {
        self.col_range.len()
    }
2026}
2027
2028/// Compute `Σ_k λ_k S_k` directly from blockwise penalties, accumulating
2029/// into a pre-allocated `p_total × p_total` output without ever materializing
2030/// individual global matrices.
2031pub fn weighted_blockwise_penalty_sum(
2032    penalties: &[BlockwisePenalty],
2033    lambdas: &[f64],
2034    p_total: usize,
2035) -> Array2<f64> {
2036    assert_eq!(penalties.len(), lambdas.len());
2037    // Smoothing parameters λ_k must be non-negative and finite. A negative
2038    // λ would flip the sign of the corresponding block S_k, turning the
2039    // total penalty matrix indefinite and silently corrupting every
2040    // downstream Cholesky / PIRLS / REML / pseudo-logdet computation that
2041    // assumes S_λ ⪰ 0. Catch this at the boundary rather than after it
2042    // has propagated.
2043    for (idx, &lam) in lambdas.iter().enumerate() {
2044        assert!(
2045            lam.is_finite() && lam >= 0.0,
2046            "weighted_blockwise_penalty_sum: lambdas[{idx}] = {lam} is invalid (must be finite and non-negative; negative smoothing parameters violate S_λ ⪰ 0)",
2047        );
2048    }
2049    // Block column ranges must also fit inside the declared total parameter
2050    // dimension; an out-of-bounds slice would otherwise panic from ndarray
2051    // with a far less informative message.
2052    for (idx, bp) in penalties.iter().enumerate() {
2053        let r = &bp.col_range;
2054        assert!(
2055            r.end <= p_total,
2056            "weighted_blockwise_penalty_sum: penalties[{idx}] col_range {:?} exceeds p_total = {p_total}",
2057            r,
2058        );
2059    }
2060    let mut out = Array2::<f64>::zeros((p_total, p_total));
2061    for (bp, &lam) in penalties.iter().zip(lambdas.iter()) {
2062        let r = &bp.col_range;
2063        let mut slice = out.slice_mut(s![r.start..r.end, r.start..r.end]);
2064        slice.scaled_add(lam, &bp.local);
2065    }
2066    out
2067}
2068
2069// ---------------------------------------------------------------------------
2070// KroneckerPenaltySystem — factored tensor-product penalty representation
2071// ---------------------------------------------------------------------------
2072
/// Factored representation of tensor-product penalties with precomputed
/// marginal eigensystems for O(∏q_j) logdet and penalty operations.
///
/// All fields are public, so the invariants the methods rely on (one
/// eigensystem and one dimension per marginal penalty) are only established
/// when the value is built through [`KroneckerPenaltySystem::new`].
#[derive(Debug, Clone)]
pub struct KroneckerPenaltySystem {
    /// Marginal penalty matrices: `marginal_penalties[k]` is `(q_k, q_k)`.
    pub marginal_penalties: Vec<Array2<f64>>,
    /// Precomputed eigensystems: `(eigenvalues, eigenvectors)` per marginal.
    pub marginal_eigensystems: Vec<(Array1<f64>, Array2<f64>)>,
    /// Marginal basis dimensions.
    pub marginal_dims: Vec<usize>,
    /// Whether a global ridge (double) penalty is present. When set, the
    /// system carries one extra smoothing parameter after the marginal ones
    /// (see `num_penalties`).
    pub has_double_penalty: bool,
}
2086
2087impl KroneckerPenaltySystem {
2088    pub fn new(
2089        marginal_penalties: Vec<Array2<f64>>,
2090        marginal_dims: Vec<usize>,
2091        has_double_penalty: bool,
2092    ) -> Result<Self, BasisError> {
2093        if marginal_penalties.len() != marginal_dims.len() {
2094            crate::bail_dim_basis!(
2095                "KroneckerPenaltySystem: {} penalties vs {} dims",
2096                marginal_penalties.len(),
2097                marginal_dims.len()
2098            );
2099        }
2100        let eigensystems =
2101            kronecker_marginal_eigensystems(&marginal_penalties, "KroneckerPenaltySystem")
2102                .map_err(|e| BasisError::InvalidInput(e.to_string()))?;
2103        Ok(Self {
2104            marginal_penalties,
2105            marginal_eigensystems: eigensystems,
2106            marginal_dims,
2107            has_double_penalty,
2108        })
2109    }
2110
2111    pub fn p_total(&self) -> usize {
2112        self.marginal_dims.iter().copied().product()
2113    }
2114
    /// Number of tensor margins, i.e. the length of `marginal_dims`.
    pub fn ndim(&self) -> usize {
        self.marginal_dims.len()
    }
2118
2119    pub fn num_penalties(&self) -> usize {
2120        self.marginal_dims.len() + if self.has_double_penalty { 1 } else { 0 }
2121    }
2122
2123    /// Compute `log|S|₊` and its first/second derivatives w.r.t. `ρ_k = log(λ_k)`.
2124    ///
2125    /// Iterates over the ∏q_j multi-index grid. Cost: O(d · ∏q_j), no O(p²) storage.
2126    pub fn logdet_and_derivatives(
2127        &self,
2128        lambdas: &[f64],
2129        ridge: f64,
2130    ) -> (f64, Array1<f64>, Array2<f64>) {
2131        let n_pen = self.num_penalties();
2132        assert_eq!(lambdas.len(), n_pen, "lambda count mismatch");
2133        let marginal_evals: Vec<_> = self
2134            .marginal_eigensystems
2135            .iter()
2136            .map(|(evals, _)| evals.view())
2137            .collect();
2138        kronecker_logdet_and_derivatives(
2139            &marginal_evals,
2140            &self.marginal_dims,
2141            lambdas,
2142            self.has_double_penalty,
2143            ridge,
2144        )
2145    }
2146
2147    pub fn logdet_rank_and_derivatives(
2148        &self,
2149        lambdas: &[f64],
2150        ridge: f64,
2151    ) -> (f64, usize, Array1<f64>, Array2<f64>) {
2152        let n_pen = self.num_penalties();
2153        assert_eq!(lambdas.len(), n_pen, "lambda count mismatch");
2154        let d = self.marginal_dims.len();
2155        let mut logdet = 0.0;
2156        let mut rank = 0usize;
2157        let mut grad = Array1::<f64>::zeros(n_pen);
2158        let mut hess = Array2::<f64>::zeros((n_pen, n_pen));
2159        // Positivity floor for a penalized eigenvalue `σ`: below this the mode
2160        // is treated as an unpenalized (null-space) direction and excluded from
2161        // both the rank count and the pseudo-log-determinant.
2162        const EIGENVALUE_POSITIVITY_FLOOR: f64 = 1e-12;
2163        // Floor on the *structural* eigenvalue sum (λ-independent) used to decide
2164        // whether a mode lives in the penalty range space and so should receive
2165        // the stabilizing ridge; a structurally-null mode gets no ridge.
2166        const STRUCTURAL_ZERO_FLOOR: f64 = 1e-12;
2167        let mut multi_idx = vec![0usize; d];
2168        loop {
2169            let mut sigma = 0.0;
2170            let mut structural_sigma = 0.0;
2171            for k in 0..d {
2172                let marginal_eigenvalue = self.marginal_eigensystems[k].0[multi_idx[k]];
2173                structural_sigma += marginal_eigenvalue;
2174                sigma += lambdas[k] * marginal_eigenvalue;
2175            }
2176            let joint_null = structural_sigma <= STRUCTURAL_ZERO_FLOOR;
2177            if self.has_double_penalty && joint_null {
2178                sigma += lambdas[d];
2179            }
2180            if structural_sigma > STRUCTURAL_ZERO_FLOOR {
2181                sigma += ridge;
2182            }
2183
2184            if sigma > EIGENVALUE_POSITIVITY_FLOOR {
2185                rank += 1;
2186                logdet += sigma.ln();
2187                let inv_sigma = 1.0 / sigma;
2188                let inv_sigma2 = inv_sigma * inv_sigma;
2189                for k in 0..n_pen {
2190                    let ck = if k < d {
2191                        lambdas[k] * self.marginal_eigensystems[k].0[multi_idx[k]]
2192                    } else if joint_null {
2193                        lambdas[d]
2194                    } else {
2195                        0.0
2196                    };
2197                    grad[k] += ck * inv_sigma;
2198                    hess[[k, k]] += ck * inv_sigma - ck * ck * inv_sigma2;
2199                    for l in (k + 1)..n_pen {
2200                        let cl = if l < d {
2201                            lambdas[l] * self.marginal_eigensystems[l].0[multi_idx[l]]
2202                        } else if joint_null {
2203                            lambdas[d]
2204                        } else {
2205                            0.0
2206                        };
2207                        let off = -ck * cl * inv_sigma2;
2208                        hess[[k, l]] += off;
2209                        hess[[l, k]] += off;
2210                    }
2211                }
2212            }
2213
2214            let mut carry = true;
2215            for dim in (0..d).rev() {
2216                if carry {
2217                    multi_idx[dim] += 1;
2218                    if multi_idx[dim] < self.marginal_dims[dim] {
2219                        carry = false;
2220                    } else {
2221                        multi_idx[dim] = 0;
2222                    }
2223                }
2224            }
2225            if carry {
2226                break;
2227            }
2228        }
2229        (logdet, rank, grad, hess)
2230    }
2231}
2232
// Unit tests for `joint_unpenalized_dim`. Each case passes a local block
// size, a slice of penalty matrices and a slice of per-penalty null-space
// dimensions, and pins the expected joint unpenalized dimension.
#[cfg(test)]
mod joint_unpenalized_dim_tests {
    use super::joint_unpenalized_dim;
    use ndarray::{Array2, array};

    #[test]
    fn no_penalty_is_fully_unpenalized() {
        // With no penalties at all, every one of the 4 coordinates is free.
        assert_eq!(joint_unpenalized_dim(4, &[], &[]), 4);
    }

    #[test]
    fn single_penalty_returns_its_own_null_space() {
        // A 3×3 penalty that penalizes only the last coordinate ⇒ 2-dim null
        // space (the first two coordinates are unpenalized).
        let s = array![[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 5.0]];
        assert_eq!(joint_unpenalized_dim(3, std::slice::from_ref(&s), &[2]), 2);
    }

    #[test]
    fn complementary_double_penalty_has_empty_joint_null_space() {
        // The #1360 case in miniature: a "bending" penalty that leaves the
        // first coordinate null (a 1-dim null space here), plus a
        // complementary "null-space ridge" that penalizes exactly that
        // coordinate. Per-penalty null dims are {1, 2} and sum to 3 (≈ p),
        // but the INTERSECTION is empty: every coordinate is penalized by
        // someone, so the joint unpenalized dim is 0.
        let bending = array![[0.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]];
        let ridge = array![[2.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]];
        assert_eq!(joint_unpenalized_dim(3, &[bending, ridge], &[1, 2]), 0);
    }

    #[test]
    fn partial_overlap_keeps_shared_null_direction() {
        // Two penalties that BOTH leave coordinate 0 unpenalized ⇒ the shared
        // null direction survives the intersection (joint unpenalized dim 1),
        // even though naively summing the per-penalty dims would give 4.
        let a = array![[0.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 0.0]];
        let b = array![[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]];
        assert_eq!(joint_unpenalized_dim(3, &[a, b], &[2, 2]), 1);
    }

    #[test]
    fn non_materialized_penalty_falls_back_conservatively() {
        // A penalty whose stored block is not p_local × p_local (e.g. a
        // Kronecker tensor factor). With ≥2 penalties the conservative joint
        // dim is 0 (never over-rejecting).
        let full: Array2<f64> = array![[0.0, 0.0], [0.0, 1.0]];
        let factor: Array2<f64> = array![[1.0]]; // wrong shape for p_local=2
        assert_eq!(
            joint_unpenalized_dim(2, &[full, factor.clone()], &[1, 0]),
            0
        );
        // With a single non-materialized penalty, fall back to its own null dim.
        assert_eq!(joint_unpenalized_dim(4, std::slice::from_ref(&factor), &[2]), 2);
    }
}
2289
// Unit tests for `KroneckerPenaltySystem::logdet_rank_and_derivatives`.
#[cfg(test)]
mod kronecker_penalty_system_tests {
    use super::KroneckerPenaltySystem;
    use ndarray::array;

    #[test]
    fn double_penalty_rank_derivatives_use_only_joint_null_space() {
        // Two 2×2 diagonal margins with eigenvalues {0, 2} and {0, 3}, plus a
        // double penalty; λ = (5, 7) for the margins and 11 for the ridge.
        let penalties = vec![
            array![[0.0, 0.0], [0.0, 2.0]],
            array![[0.0, 0.0], [0.0, 3.0]],
        ];
        let system = KroneckerPenaltySystem::new(penalties, vec![2usize, 2usize], true).unwrap();
        let lambdas = vec![5.0, 7.0, 11.0];

        let (logdet, rank, grad, hess) = system.logdet_rank_and_derivatives(&lambdas, 0.0);

        // Per-mode σ over the 2×2 eigenvalue grid (ridge = 0):
        //   (0, 0): joint null mode, receives only the double penalty → 11
        //   (0, 3): 7·3 = 21
        //   (2, 0): 5·2 = 10
        //   (2, 3): 5·2 + 7·3 = 31
        let expected_diag = [11.0_f64, 21.0, 10.0, 31.0];
        let expected_logdet: f64 = expected_diag.iter().map(|v| v.ln()).sum();
        assert_eq!(rank, 4);
        assert!((logdet - expected_logdet).abs() <= 1e-12);
        // Only the joint null mode depends on λ_2, and there σ = λ_2, so
        // d log σ / dρ_2 = 1 and the second derivative 1 − 1 vanishes.
        assert!(
            (grad[2] - 1.0).abs() <= 1e-12,
            "double-penalty rank derivative must count only the joint null mode, got {}",
            grad[2]
        );
        assert!(hess[[2, 2]].abs() <= 1e-12);
    }
}
2318
/// Assembled design for a collection of terms: the global design matrix, its
/// blockwise penalties with their bookkeeping, optional coefficient
/// constraints, and the column ranges of the individual term groups.
#[derive(Clone, Debug)]
pub struct TermCollectionDesign {
    /// The full design matrix.
    ///
    /// Prefer a true sparse matrix when every block is sparse-compatible.
    /// If the collection already contains intrinsically sparse blocks, preserve
    /// that storage and let PIRLS decide later whether the penalized system is
    /// sparse-native eligible. Purely dense materialized blocks still fall back
    /// to the lazy block operator when sparse storage would just re-encode a
    /// dense matrix.
    pub design: DesignMatrix,
    /// Blockwise penalties over the global coefficient layout; their count is
    /// reported by `num_penalties`.
    pub penalties: Vec<BlockwisePenalty>,
    // NOTE(review): presumably one null-space dimension per entry of
    // `penalties` — confirm against the assembly code.
    pub nullspace_dims: Vec<usize>,
    /// Penalty metadata; `leading_penalty_blocks_before_smooth` treats this
    /// sequence as the authoritative penalty-block layout.
    pub penaltyinfo: Vec<PenaltyBlockInfo>,
    // NOTE(review): presumably metadata for penalty blocks that were dropped
    // during assembly — confirm.
    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    /// Optional global coefficient lower bounds for constrained fitting.
    /// Length equals `design.ncols()` when present. Unconstrained entries are `-inf`.
    pub coefficient_lower_bounds: Option<Array1<f64>>,
    /// Optional global inequality constraints:
    /// `A * beta >= b`.
    pub linear_constraints: Option<LinearInequalityConstraints>,
    // NOTE(review): the ranges below look like global column ranges of the
    // intercept, the named linear terms and the named random-effect terms —
    // confirm against the builder.
    pub intercept_range: Range<usize>,
    pub linear_ranges: Vec<(String, Range<usize>)>,
    pub random_effect_ranges: Vec<(String, Range<usize>)>,
    // NOTE(review): presumably the level identifiers of each named
    // random-effect term — confirm.
    pub random_effect_levels: Vec<(String, Vec<u64>)>,
    /// Smooth-term part of the design; its `terms` are inspected by
    /// `kronecker_penalty_system`.
    pub smooth: SmoothDesign,
}
2346
2347impl TermCollectionDesign {
2348    /// Number of global penalty blocks that precede the smooth-term penalty
2349    /// blocks in the flat smoothing-parameter / EDF-trace layout.
2350    ///
2351    /// The prefix is not the same as `random_effect_ranges.len()`: unpenalized
2352    /// (or empty) random-effect ranges contribute coefficient columns but no
2353    /// penalty block. The authoritative layout is the recorded `penaltyinfo`
2354    /// sequence assembled alongside `penalties`.
2355    pub fn leading_penalty_blocks_before_smooth(&self) -> usize {
2356        self.penaltyinfo
2357            .iter()
2358            .take_while(|info| {
2359                matches!(
2360                    &info.penalty.source,
2361                    crate::basis::PenaltySource::Other(source)
2362                        if source == "LinearTermRidge"
2363                            || source.starts_with("RandomEffectRidge(")
2364                )
2365            })
2366            .count()
2367    }
2368
2369    /// Convert blockwise penalties to `PenaltyMatrix::Blockwise` without
2370    /// expanding to `p_total × p_total`. This is the preferred path for
2371    /// family modules that accept `Vec<PenaltyMatrix>`.
2372    pub fn penalties_as_penalty_matrix(&self) -> Vec<gam_problem::PenaltyMatrix> {
2373        let p = self.design.ncols();
2374        self.penalties
2375            .iter()
2376            .map(|bp| bp.to_penalty_matrix(p))
2377            .collect()
2378    }
2379
    /// Number of penalty blocks, i.e. the length of `penalties`.
    #[inline]
    pub fn num_penalties(&self) -> usize {
        self.penalties.len()
    }
2385
    /// Resolve coefficient groups against this design's global coefficient
    /// layout and append their penalties after the existing term penalties.
    ///
    /// Method-form convenience: forwards `self`, `groups` and `base_prior`
    /// unchanged to the free function `realize_coefficient_groups` and returns
    /// its result, including any `BasisError` it reports.
    pub fn realize_coefficient_groups(
        &self,
        groups: &[CoefficientGroupSpec],
        base_prior: &gam_spec::RhoPrior,
    ) -> Result<RealizedCoefficientGroups, BasisError> {
        realize_coefficient_groups(self, groups, base_prior)
    }
2395
2396    /// Extract a `KroneckerPenaltySystem` when the model's *only* smooth term is
2397    /// a single Kronecker-factored tensor.
2398    ///
2399    /// This is a deliberate single-tensor fast path, not a partial feature: any
2400    /// other shape — zero Kronecker terms, several of them, or a tensor mixed
2401    /// with non-tensor smooth terms — is served correctly by the standard
2402    /// block-separable assembly, so this returns `None` and the caller falls
2403    /// back to it. The two former conditions (`len != 1` and "a non-Kronecker
2404    /// smooth term exists") are jointly equivalent to "the sole smooth term is
2405    /// Kronecker", which the slice pattern below expresses directly in one pass.
2406    pub fn kronecker_penalty_system(&self) -> Option<KroneckerPenaltySystem> {
2407        let [only_term] = self.smooth.terms.as_slice() else {
2408            return None;
2409        };
2410        let kron = only_term.kronecker_factored.as_ref()?;
2411        // A genuine tensor product needs at least two margins, and the marginal
2412        // design / penalty / dim collections must agree in length. A degenerate
2413        // (single-margin) or internally inconsistent factored basis cannot feed
2414        // the Kronecker fast path, so fall back to the standard assembly rather
2415        // than construct a malformed `KroneckerPenaltySystem` from it.
2416        if kron.marginal_dims.len() < 2
2417            || kron.marginal_penalties.len() != kron.marginal_dims.len()
2418            || kron.marginal_designs.len() != kron.marginal_dims.len()
2419        {
2420            return None;
2421        }
2422        KroneckerPenaltySystem::new(
2423            kron.marginal_penalties.clone(),
2424            kron.marginal_dims.clone(),
2425            kron.has_double_penalty,
2426        )
2427        .ok()
2428    }
2429}
2430
2431// `FittedTermCollection`, `SpatialLengthScaleOptimizationTiming`, and
2432// `FittedTermCollectionWithSpec` were relocated with the GAM fit-orchestration
2433// drivers to `gam-models` (`crate::fit_orchestration::drivers`) — they hold a
2434// `gam_solve::UnifiedFitResult` and are consumed only by those drivers (#1521).
2435
/// Configuration bundle tying latent-coordinate values to one smooth term.
///
/// NOTE(review): the field meanings below are inferred from names and types
/// only — confirm against the latent-coordinate fitting code.
#[derive(Clone)]
pub struct StandardLatentCoordConfig {
    // Shared handle to the latent coordinate values.
    pub values: std::sync::Arc<crate::latent::LatentCoordValues>,
    // Index of the smooth term this configuration refers to.
    pub term_index: gam_problem::types::SmoothTermIdx,
    // Presumably the data columns that feed the term — TODO confirm.
    pub feature_cols: Vec<usize>,
    pub manifold: crate::latent::LatentManifold,
    // Presumably whether the manifold was selected automatically — TODO confirm.
    pub manifold_auto: bool,
    pub retraction_registry: gam_problem::LatentRetractionRegistry,
    // Optional shared registry of analytic penalties.
    pub analytic_penalties: Option<std::sync::Arc<crate::AnalyticPenaltyRegistry>>,
}
2446
/// Serializable per-term record of adaptive spatial weights.
///
/// NOTE(review): field semantics are inferred from names only; presumably the
/// three weight vectors hold inverse magnitude / gradient / Laplacian weights
/// with one entry per collocation point — confirm against the producer.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AdaptiveSpatialMap {
    // Name of the term this map belongs to.
    pub termname: String,
    // Presumably the data columns that feed the term — TODO confirm.
    pub feature_cols: Vec<usize>,
    pub collocation_points: Array2<f64>,
    pub inv_magweight: Array1<f64>,
    pub invgradweight: Array1<f64>,
    pub inv_lapweight: Array1<f64>,
}
2456
/// Serializable diagnostics of an adaptive-regularization run, together with
/// the per-term [`AdaptiveSpatialMap`]s.
///
/// NOTE(review): the `epsilon_*` values and iteration counters are documented
/// from their names only — confirm their exact meaning against the solver.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AdaptiveRegularizationDiagnostics {
    pub epsilon_0: f64,
    pub epsilon_g: f64,
    pub epsilon_c: f64,
    // Presumably the number of outer iterations over epsilon — TODO confirm.
    pub epsilon_outer_iterations: usize,
    // Presumably the number of inner MM iterations — TODO confirm.
    pub mm_iterations: usize,
    // Whether the run reported convergence.
    pub converged: bool,
    // One adaptive spatial map per recorded term.
    pub maps: Vec<AdaptiveSpatialMap>,
}
2467
/// Conditioning record for a single linear design column.
///
/// NOTE(review): presumably the column at `col_idx` is centered by `mean` and
/// divided by `scale` during fitting — confirm against the code that applies
/// and undoes the conditioning. Fields are private to this module.
#[derive(Debug, Clone)]
pub struct LinearColumnConditioning {
    col_idx: usize,
    mean: f64,
    scale: f64,
}
2474
/// Conditioning applied to the linear part of a fit: the intercept column
/// index plus one [`LinearColumnConditioning`] per conditioned column.
///
/// The derived `Default` is index `0` with no conditioned columns.
#[derive(Debug, Clone, Default)]
pub struct LinearFitConditioning {
    // Index of the intercept column (presumably global — TODO confirm).
    pub intercept_idx: usize,
    pub columns: Vec<LinearColumnConditioning>,
}
2480
/// First- and second-order ψ-derivative data for one spatial hyper axis:
/// local design and penalty derivative blocks plus the optional machinery
/// needed for anisotropic groups and implicit design-derivative operators.
#[derive(Clone)]
pub struct SpatialPsiDerivative {
    // These are derivatives with respect to psi = log(kappa), not log(length_scale).
    pub penalty_index: usize,
    // NOTE(review): presumably the global indices of every penalty affected
    // by this ψ, with `penalty_index` being one of them — confirm.
    pub penalty_indices: Vec<usize>,
    // NOTE(review): presumably the term's column range in the global
    // coefficient layout, with `total_p` the global coefficient count — confirm.
    pub global_range: Range<usize>,
    pub total_p: usize,
    // First ψ-derivative of the local design block (may be zero-sized when
    // `implicit_operator` is present, see below).
    pub x_psi_local: Array2<f64>,
    // First ψ-derivatives of the local penalty components.
    pub s_psi_components_local: Vec<Array2<f64>>,
    // Second ψ-derivative of the local design block (may be zero-sized when
    // `implicit_operator` is present, see below).
    pub x_psi_psi_local: Array2<f64>,
    // Second ψ-derivatives of the local penalty components.
    pub s_psi_psi_components_local: Vec<Array2<f64>>,
    // Identifier of the anisotropy group this axis belongs to, if any.
    pub aniso_group_id: Option<usize>,
    /// Pre-computed cross-derivative design matrices for other axes
    /// in the same aniso group: Vec of (axis_offset_in_group, matrix).
    pub aniso_cross_designs: Option<Vec<(usize, Array2<f64>)>>,
    /// On-demand cross-penalty second derivatives ∂²S_m/∂ψ_a∂ψ_b for axes in
    /// the same anisotropy group. The input is the other axis offset in the
    /// group, and the output is one local penalty matrix per active penalty.
    pub aniso_cross_penalty_provider: Option<
        std::sync::Arc<
            dyn Fn(usize) -> Result<Vec<Array2<f64>>, EstimationError> + Send + Sync + 'static,
        >,
    >,
    /// Optional implicit design-derivative operator (shared across all axes
    /// in the same aniso group). When present, `x_psi_local` and
    /// `x_psi_psi_local` may be zero-sized, and design-derivative matvecs
    /// should go through this operator using `implicit_axis` as the axis index.
    pub implicit_operator: Option<std::sync::Arc<crate::basis::ImplicitDesignPsiDerivative>>,
    /// Which axis in the implicit operator this entry corresponds to.
    pub implicit_axis: usize,
}
2512
/// Flat vector of spatial ψ coordinates together with the per-term layout
/// needed to slice it back into terms.
///
/// `new_with_dims` enforces `values.len() == dims_per_term.iter().sum()`; the
/// other constructors build both fields together.
#[derive(Debug, Clone)]
pub struct SpatialLogKappaCoords {
    /// Flattened ψ values. For isotropic terms, one entry per term.
    /// For anisotropic terms, d entries per term (one ψ_a per axis).
    pub values: Array1<f64>,
    /// Dimensionality of each term: 1 for isotropic, d for anisotropic.
    pub dims_per_term: Vec<usize>,
}
2521
/// Which end of the ψ bound the shared `aniso_bounds_from_data` helper is
/// computing. The lower end uses `-max_length_scale.ln()` as the pure-Duchon
/// fallback and the `.0` element of `spatial_term_psi_bounds`; the upper end
/// uses `-min_length_scale.ln()` and `.1`. Everything else is identical.
#[derive(Clone, Copy)]
pub enum AnisoBoundEnd {
    /// Compute lower bounds (selected by `lower_bounds_aniso_from_data`).
    Lower,
    /// Compute upper bounds (selected by `upper_bounds_aniso_from_data`).
    Upper,
}
2531
2532impl SpatialLogKappaCoords {
2533    /// Construct from an explicit dims layout plus values.
2534    pub fn new_with_dims(values: Array1<f64>, dims_per_term: Vec<usize>) -> Self {
2535        assert_eq!(
2536            values.len(),
2537            dims_per_term.iter().sum::<usize>(),
2538            "SpatialLogKappaCoords: values length {} != sum of dims_per_term {}",
2539            values.len(),
2540            dims_per_term.iter().sum::<usize>(),
2541        );
2542        Self {
2543            values,
2544            dims_per_term,
2545        }
2546    }
2547
2548    /// Isotropic initialization (backward-compatible path).
2549    pub fn from_length_scales(
2550        spec: &TermCollectionSpec,
2551        term_indices: &[usize],
2552        options: &SpatialLengthScaleOptimizationOptions,
2553    ) -> Self {
2554        let mut out = Array1::<f64>::zeros(term_indices.len());
2555        for (slot, &term_idx) in term_indices.iter().enumerate() {
2556            // Constant-curvature: the single ψ slot is the raw signed κ, seeded
2557            // from the spec (default κ = 0). The −ln(length_scale) convention is
2558            // log-κ semantics and must not touch the raw-κ coordinate; the κ
2559            // window projection happens later via `clamp_to_bounds`. Mirrors the
2560            // aniso constructor's κ branch.
2561            if let Some(cc) = constant_curvature_term_spec(spec, term_idx) {
2562                out[slot] = cc.kappa;
2563                continue;
2564            }
2565            let length_scale = get_spatial_length_scale(spec, term_idx)
2566                .unwrap_or(options.min_length_scale)
2567                .clamp(options.min_length_scale, options.max_length_scale);
2568            out[slot] = -length_scale.ln();
2569        }
2570        Self {
2571            values: out,
2572            dims_per_term: vec![1; term_indices.len()],
2573        }
2574    }
2575
2576    /// Anisotropic-aware initialization.
2577    ///
2578    /// Initialization strategy (per math team recommendation): standardize the
2579    /// knot cloud axiswise, then run the existing isotropic κ initializer in
2580    /// the standardized space. This reuses the trusted isotropic initializer
2581    /// and gives initial η_a = −ln(σ_a) + mean(ln(σ_a)), which satisfies
2582    /// Ση_a = 0 by construction.
2583    ///
2584    /// For each term, checks whether it has `aniso_log_scales` set on its basis spec.
2585    /// - If isotropic (no aniso_log_scales, or 1-D): 1 entry = −ln(length_scale).
2586    /// - If anisotropic with a scalar length scale: d entries, one ψ_a per axis.
2587    ///   Initialized as ψ_a = −ln(length_scale) + η_a  where η_a are the existing
2588    ///   aniso_log_scales (which sum to zero). Multi-dimensional terms without
2589    ///   explicit anisotropy stay scalar here so the seed dimensionality matches
2590    ///   `spatial_dims_per_term`.
2591    /// - If pure Duchon anisotropic: d - 1 free entries store the leading η_a
2592    ///   values directly; the final axis is reconstructed to keep Ση_a = 0.
2593    pub fn from_length_scales_aniso(
2594        spec: &TermCollectionSpec,
2595        term_indices: &[usize],
2596        options: &SpatialLengthScaleOptimizationOptions,
2597    ) -> Self {
2598        let mut vals = Vec::new();
2599        let mut dims = Vec::new();
2600        for &term_idx in term_indices {
2601            // Measure-jet: dial coordinates seeded directly from the term's
2602            // realized (α, τ[, s]); the −ln(length_scale) convention below is
2603            // κ-semantics and never applies to dials.
2604            if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
2605                let seed = measure_jet_psi_seed(mj);
2606                dims.push(seed.len());
2607                vals.extend(seed);
2608                continue;
2609            }
2610            // Constant-curvature: one signed κ slot seeded from the spec's κ
2611            // (clamped feasible). The −ln(length_scale) convention below is
2612            // log-κ semantics and must not touch the raw-κ coordinate. Bounds
2613            // are unavailable here (no data view), so this is the raw spec κ;
2614            // `reseed_from_data` / `clamp_to_bounds` later project it feasible.
2615            if let Some(cc) = constant_curvature_term_spec(spec, term_idx) {
2616                vals.push(cc.kappa);
2617                dims.push(1);
2618                continue;
2619            }
2620            let length_scale = get_spatial_length_scale(spec, term_idx)
2621                .unwrap_or(options.min_length_scale)
2622                .clamp(options.min_length_scale, options.max_length_scale);
2623            let psi_bar = -length_scale.ln(); // global scale = −ln(length_scale)
2624
2625            if spatial_term_uses_per_axis_psi(spec, term_idx) {
2626                // Per-axis anisotropy is enrolled in the joint outer vector:
2627                // ψ_a = ψ̄ + η_a, one slot per axis. The hyper_dirs builder
2628                // produces matching per-axis derivatives in
2629                // `try_build_spatial_term_log_kappa_aniso_derivativeinfos`.
2630                let d = get_spatial_feature_dim(spec, term_idx).unwrap_or(1);
2631                let eta_raw = get_spatial_aniso_log_scales(spec, term_idx)
2632                    .expect("predicate guarantees aniso_log_scales is Some");
2633                let eta = center_aniso_log_scales(&eta_raw);
2634                for &eta_a in &eta {
2635                    vals.push(psi_bar + eta_a);
2636                }
2637                dims.push(d);
2638            } else {
2639                // Isotropic enrollment — either a 1-D term, a multi-D term
2640                // without explicit anisotropy, or a basis (e.g. Duchon) whose
2641                // η is a fixed geometry parameter rather than a REML hyper
2642                // axis. Exactly one ψ̄ slot, matching the single
2643                // `SpatialPsiDerivative` produced by
2644                // `try_build_spatial_term_log_kappa_derivativeinfo`.
2645                vals.push(psi_bar);
2646                dims.push(1);
2647            }
2648        }
2649        Self {
2650            values: Array1::from_vec(vals),
2651            dims_per_term: dims,
2652        }
2653    }
2654
2655    /// Isotropic lower bounds derived from per-term data geometry.
2656    /// Each entry gets the ψ_lo bound returned by `spatial_term_psi_bounds`
2657    /// for the corresponding term, intersected with the options window.
2658    pub fn lower_bounds_from_data(
2659        data: ArrayView2<'_, f64>,
2660        spec: &TermCollectionSpec,
2661        term_indices: &[usize],
2662        options: &SpatialLengthScaleOptimizationOptions,
2663    ) -> Self {
2664        let mut values = Array1::<f64>::zeros(term_indices.len());
2665        for (slot, &term_idx) in term_indices.iter().enumerate() {
2666            values[slot] = spatial_term_psi_bounds(data, spec, term_idx, options).0;
2667        }
2668        Self {
2669            values,
2670            dims_per_term: vec![1; term_indices.len()],
2671        }
2672    }
2673
2674    /// Isotropic upper bounds derived from per-term data geometry.
2675    pub fn upper_bounds_from_data(
2676        data: ArrayView2<'_, f64>,
2677        spec: &TermCollectionSpec,
2678        term_indices: &[usize],
2679        options: &SpatialLengthScaleOptimizationOptions,
2680    ) -> Self {
2681        let mut values = Array1::<f64>::zeros(term_indices.len());
2682        for (slot, &term_idx) in term_indices.iter().enumerate() {
2683            values[slot] = spatial_term_psi_bounds(data, spec, term_idx, options).1;
2684        }
2685        Self {
2686            values,
2687            dims_per_term: vec![1; term_indices.len()],
2688        }
2689    }
2690
    /// Anisotropic-aware lower bounds derived from per-term data geometry.
    /// For hybrid anisotropic terms the scalar ψ_lo bound applies to the
    /// mean `ψ̄`, not directly to every raw axis coordinate `ψ_a = ψ̄ + η_a`.
    /// Shift each axis by the current centered `η_a` so projecting/clamping
    /// the seed moves only the global scale direction and does not silently
    /// shrink anisotropy that is already consistent with the current
    /// `length_scale`.
    ///
    /// Pure Duchon anisotropy is structurally different: its stored
    /// coordinates are (d-1) free η_a values representing log axis-scale
    /// ratios, NOT log-κ. For those terms the κ-range geometry bound is
    /// over-restrictive (η_a = ±5 is normal, but that corresponds to 7+
    /// orders of magnitude in κ-space and would be rejected by the data
    /// window). Fall back to the options window `[-ln(max_ls), -ln(min_ls)]`
    /// for those coordinates — that's the same bound the pre-data-geometry
    /// code used, which is calibrated to allow legitimate anisotropy.
    ///
    /// This method only selects the bound end: all arguments are forwarded
    /// unchanged to `aniso_bounds_from_data` with `AnisoBoundEnd::Lower`.
    pub fn lower_bounds_aniso_from_data(
        data: ArrayView2<'_, f64>,
        spec: &TermCollectionSpec,
        term_indices: &[usize],
        dims_per_term: &[usize],
        options: &SpatialLengthScaleOptimizationOptions,
    ) -> Self {
        Self::aniso_bounds_from_data(
            data,
            spec,
            term_indices,
            dims_per_term,
            options,
            AnisoBoundEnd::Lower,
        )
    }
2723
2724    /// Anisotropic-aware upper bounds derived from per-term data geometry.
2725    /// See `lower_bounds_aniso_from_data` for the hybrid-aniso offsetting and
2726    /// pure-Duchon dispatch rationale.
2727    pub fn upper_bounds_aniso_from_data(
2728        data: ArrayView2<'_, f64>,
2729        spec: &TermCollectionSpec,
2730        term_indices: &[usize],
2731        dims_per_term: &[usize],
2732        options: &SpatialLengthScaleOptimizationOptions,
2733    ) -> Self {
2734        Self::aniso_bounds_from_data(
2735            data,
2736            spec,
2737            term_indices,
2738            dims_per_term,
2739            options,
2740            AnisoBoundEnd::Upper,
2741        )
2742    }
2743
2744    /// Shared implementation for the lower/upper aniso bounds. The bound end
2745    /// only changes which options scale (`max_length_scale` vs
2746    /// `min_length_scale`) becomes the pure-Duchon fallback bound and which
2747    /// element of the `(lo, hi)` data-geometry tuple is consumed; the
2748    /// per-term cursor walk and aniso-offset handling are identical.
2749    fn aniso_bounds_from_data(
2750        data: ArrayView2<'_, f64>,
2751        spec: &TermCollectionSpec,
2752        term_indices: &[usize],
2753        dims_per_term: &[usize],
2754        options: &SpatialLengthScaleOptimizationOptions,
2755        end: AnisoBoundEnd,
2756    ) -> Self {
2757        assert_eq!(term_indices.len(), dims_per_term.len());
2758        let total: usize = dims_per_term.iter().sum();
2759        let mut values = Array1::<f64>::zeros(total);
2760        let mut cursor = 0;
2761        for (slot, &term_idx) in term_indices.iter().enumerate() {
2762            let d = dims_per_term[slot];
2763            // Measure-jet: per-coordinate dial boxes, never κ-window geometry
2764            // (which would reject legitimate dial values outright).
2765            if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
2766                let bounds = measure_jet_psi_bound_values(mj, matches!(end, AnisoBoundEnd::Upper));
2767                for (offset, bound) in bounds.into_iter().enumerate() {
2768                    if offset < d {
2769                        values[cursor + offset] = bound;
2770                    }
2771                }
2772                cursor += d;
2773                continue;
2774            }
2775            // Constant-curvature: the single signed-κ box from the data chart
2776            // window (symmetric about κ = 0), never a κ = log-scale window.
2777            if constant_curvature_term_spec(spec, term_idx).is_some() {
2778                let (lo, hi) = constant_curvature_kappa_bounds(data, spec, term_idx);
2779                if d >= 1 {
2780                    values[cursor] = match end {
2781                        AnisoBoundEnd::Lower => lo,
2782                        AnisoBoundEnd::Upper => hi,
2783                    };
2784                }
2785                cursor += d;
2786                continue;
2787            }
2788            let psi_bound = {
2789                let (lo, hi) = spatial_term_psi_bounds(data, spec, term_idx, options);
2790                match end {
2791                    AnisoBoundEnd::Lower => lo,
2792                    AnisoBoundEnd::Upper => hi,
2793                }
2794            };
2795            let axis_offsets = if d <= 1 {
2796                vec![0.0; d]
2797            } else {
2798                get_spatial_aniso_log_scales(spec, term_idx)
2799                    .filter(|eta| eta.len() == d)
2800                    .map(|eta| center_aniso_log_scales(&eta))
2801                    .unwrap_or_else(|| vec![0.0; d])
2802            };
2803            for offset in 0..d {
2804                values[cursor + offset] = psi_bound + axis_offsets[offset];
2805            }
2806            cursor += d;
2807        }
2808        Self {
2809            values,
2810            dims_per_term: dims_per_term.to_vec(),
2811        }
2812    }
2813
2814    /// Rewrite any ψ entries whose originating term lacks an explicit
2815    /// `length_scale` so they sit at the midpoint of the per-term data-derived
2816    /// ψ window. Used so the outer optimizer starts inside the physically
2817    /// meaningful region instead of at an arbitrary `options.max_length_scale`
2818    /// derived seed. For terms with an explicit length_scale, the user's
2819    /// choice is respected. Anisotropy offsets η_a (those stored by
2820    /// `from_length_scales_aniso`) are preserved: we re-center around the new
2821    /// ψ̄, keeping Ση_a = 0.
2822    pub fn reseed_from_data(
2823        mut self,
2824        data: ArrayView2<'_, f64>,
2825        spec: &TermCollectionSpec,
2826        term_indices: &[usize],
2827        options: &SpatialLengthScaleOptimizationOptions,
2828    ) -> Self {
2829        assert_eq!(term_indices.len(), self.dims_per_term.len());
2830        let mut cursor = 0;
2831        for (slot, &term_idx) in term_indices.iter().enumerate() {
2832            let d = self.dims_per_term[slot];
2833            // Measure-jet dials are seeded from the realized spec and must
2834            // not be recentered into a κ data window.
2835            if measure_jet_term_spec(spec, term_idx).is_some() {
2836                cursor += d;
2837                continue;
2838            }
2839            // Constant-curvature κ is seeded from the spec (the user's curvature
2840            // hint, default κ = 0); `clamp_to_bounds` projects it feasible. It
2841            // is not a log-scale, so the log-κ recenter below never applies.
2842            if constant_curvature_term_spec(spec, term_idx).is_some() {
2843                cursor += d;
2844                continue;
2845            }
2846            let Some(psi_bar_new) = spatial_term_psi_seed(data, spec, term_idx, options) else {
2847                cursor += d;
2848                continue;
2849            };
2850            if d == 0 {
2851                continue;
2852            }
2853            let current: Vec<f64> = self.values.slice(s![cursor..cursor + d]).to_vec();
2854            let psi_bar_old = current.iter().sum::<f64>() / d as f64;
2855            for (offset, &old_value) in current.iter().enumerate() {
2856                self.values[cursor + offset] = psi_bar_new + (old_value - psi_bar_old);
2857            }
2858            cursor += d;
2859        }
2860        self
2861    }
2862
2863    /// Project ψ values into `[lower, upper]` element-wise. Used after
2864    /// `from_length_scales*` + `reseed_from_data` when a user-supplied
2865    /// `spec.length_scale` falls outside the data-derived ψ window set by
2866    /// `{lower,upper}_bounds*_from_data`. BFGS requires theta0 ∈ [lower,
2867    /// upper]; projecting is the unique closest feasible seed. The user's
2868    /// length_scale was always a hint for the outer optimizer (the optimizer
2869    /// is authoritative for κ), not a hard constraint — so clipping preserves
2870    /// their intent as far as the geometry allows. Emits `log::info!` when
2871    /// any coordinate moves, so the outside-window case is diagnostically
2872    /// visible (not silent).
2873    pub fn clamp_to_bounds(
2874        mut self,
2875        lower: &SpatialLogKappaCoords,
2876        upper: &SpatialLogKappaCoords,
2877    ) -> Self {
2878        assert_eq!(self.values.len(), lower.values.len());
2879        assert_eq!(self.values.len(), upper.values.len());
2880        let mut n_projected = 0usize;
2881        let mut worst_delta = 0.0_f64;
2882        for idx in 0..self.values.len() {
2883            let lo = lower.values[idx];
2884            let hi = upper.values[idx];
2885            if !(lo.is_finite() && hi.is_finite()) {
2886                continue;
2887            }
2888            let v = self.values[idx];
2889            if v < lo {
2890                worst_delta = worst_delta.max(lo - v);
2891                self.values[idx] = lo;
2892                n_projected += 1;
2893            } else if v > hi {
2894                worst_delta = worst_delta.max(v - hi);
2895                self.values[idx] = hi;
2896                n_projected += 1;
2897            }
2898        }
2899        if n_projected > 0 {
2900            log::info!(
2901                "[spatial-kappa] projected {n_projected}/{} ψ seed coords into data-derived bounds \
2902                 (worst excess={worst_delta:.3} log units); user length_scale falls outside \
2903                 [{KERNEL_RANGE_MIN_DIAMETER_FRACTION}/r_max, {KERNEL_RANGE_MAX_SPACING_MULTIPLE}/r_min] geometry window",
2904                self.values.len()
2905            );
2906        }
2907        self
2908    }
2909
2910    /// Reconstruct from theta tail with known dimensionality layout.
2911    pub fn from_theta_tail_with_dims(
2912        theta: &Array1<f64>,
2913        start: usize,
2914        dims_per_term: Vec<usize>,
2915    ) -> Self {
2916        let total: usize = dims_per_term.iter().sum();
2917        Self {
2918            values: theta.slice(s![start..start + total]).to_owned(),
2919            dims_per_term,
2920        }
2921    }
2922
2923    /// Total number of ψ values in the flat array (= sum of dims_per_term).
2924    pub fn len(&self) -> usize {
2925        self.values.len()
2926    }
2927
2928    /// Dimensionality layout: how many ψ values each term contributes.
2929    pub fn dims_per_term(&self) -> &[usize] {
2930        &self.dims_per_term
2931    }
2932
2933    /// Get the offset into the flat array for logical term i.
2934    fn term_offset(&self, term_idx: usize) -> usize {
2935        self.dims_per_term[..term_idx].iter().sum()
2936    }
2937
2938    /// Get the slice of ψ values for logical term i.
2939    pub fn term_slice(&self, term_idx: usize) -> &[f64] {
2940        let offset = self.term_offset(term_idx);
2941        let d = self.dims_per_term[term_idx];
2942        &self.values.as_slice().unwrap()[offset..offset + d]
2943    }
2944
2945    pub fn as_array(&self) -> &Array1<f64> {
2946        &self.values
2947    }
2948
2949    /// #1464: overwrite the single ψ value of a scalar (1-D) logical term by its
2950    /// position `slot` in this coords vector (the same ordering as the
2951    /// `term_indices` slice the constructors were built from). Used to inject the
2952    /// fixed-κ sign-basin seed into a constant-curvature term's raw-κ slot before
2953    /// the joint solve. No-op (returns `false`) when the slot is not scalar.
2954    pub fn set_scalar_slot(&mut self, slot: usize, value: f64) -> bool {
2955        if slot >= self.dims_per_term.len() || self.dims_per_term[slot] != 1 {
2956            return false;
2957        }
2958        let offset = self.term_offset(slot);
2959        self.values[offset] = value;
2960        true
2961    }
2962
2963    /// Split at a logical-term boundary. `mid` is the number of terms in the
2964    /// first half (not a flat-array index).
2965    pub fn split_at(&self, mid: usize) -> (Self, Self) {
2966        let flat_mid: usize = self.dims_per_term[..mid].iter().sum();
2967        (
2968            Self {
2969                values: self.values.slice(s![0..flat_mid]).to_owned(),
2970                dims_per_term: self.dims_per_term[..mid].to_vec(),
2971            },
2972            Self {
2973                values: self.values.slice(s![flat_mid..]).to_owned(),
2974                dims_per_term: self.dims_per_term[mid..].to_vec(),
2975            },
2976        )
2977    }
2978
2979    /// Apply optimized ψ values back to the spec.
2980    ///
2981    /// For isotropic terms (dims=1): sets scalar length_scale = exp(−ψ).
2982    /// For anisotropic terms (dims=d): hybrid/isotropic families set
2983    /// length_scale = exp(−ψ̄) with centered η_a = ψ_a − ψ̄, while pure Duchon
2984    /// writes only centered η_a and leaves length_scale = None.
2985    pub fn apply_tospec(
2986        &self,
2987        spec: &TermCollectionSpec,
2988        term_indices: &[usize],
2989    ) -> Result<TermCollectionSpec, EstimationError> {
2990        if term_indices.len() != self.dims_per_term.len() {
2991            crate::bail_invalid_estim!(
2992                "SpatialLogKappaCoords::apply_tospec: term count mismatch: \
2993                 term_indices={} dims_per_term={}",
2994                term_indices.len(),
2995                self.dims_per_term.len()
2996            );
2997        }
2998        let mut updated = spec.clone();
2999        for (slot, &term_idx) in term_indices.iter().enumerate() {
3000            let psi = self.term_slice(slot);
3001            let d = self.dims_per_term[slot];
3002            // Measure-jet: write the dial coordinates straight back; the
3003            // κ-translation below would misread them as log-scales.
3004            if measure_jet_term_spec(&updated, term_idx).is_some() {
3005                set_measure_jet_psi_dials(&mut updated, term_idx, psi)?;
3006                continue;
3007            }
3008            // Constant-curvature: write the optimized signed κ straight back;
3009            // the −exp(ψ) length-scale translation below is log-κ semantics and
3010            // would misread the raw curvature.
3011            if constant_curvature_term_spec(&updated, term_idx).is_some() {
3012                set_constant_curvature_kappa(&mut updated, term_idx, psi)?;
3013                continue;
3014            }
3015            let (next_length_scale, next_aniso) = spatial_term_psi_to_length_scale_and_aniso(psi);
3016            if (d == 1 || next_length_scale.is_some())
3017                && let Some(length_scale) = next_length_scale
3018            {
3019                set_spatial_length_scale(&mut updated, term_idx, length_scale)?;
3020            }
3021            if let Some(eta) = next_aniso {
3022                set_spatial_aniso_log_scales(&mut updated, term_idx, eta)?;
3023            }
3024        }
3025        Ok(updated)
3026    }
3027}
3028
/// Subtract the mean from a vector of anisotropy log-scales so the result
/// sums to zero.
///
/// Inputs with zero or one element are returned unchanged. Centered values
/// whose magnitude is at most `1e-15` are snapped to exactly `0.0` so
/// round-off does not masquerade as anisotropy.
pub fn center_aniso_log_scales(eta: &[f64]) -> Vec<f64> {
    const SNAP_TO_ZERO: f64 = 1e-15;

    let n = eta.len();
    if n <= 1 {
        return eta.to_vec();
    }
    let mean = eta.iter().sum::<f64>() / n as f64;
    let mut centered = Vec::with_capacity(n);
    for &value in eta {
        let delta = value - mean;
        centered.push(if delta.abs() <= SNAP_TO_ZERO { 0.0 } else { delta });
    }
    centered
}
3045
3046/// Whether a spatial term contributes per-axis ψ entries to the outer joint
3047/// hyperparameter vector.
3048pub fn spatial_term_uses_per_axis_psi(resolvedspec: &TermCollectionSpec, term_idx: usize) -> bool {
3049    if let Some(mj) = measure_jet_term_spec(resolvedspec, term_idx) {
3050        return measure_jet_enrolls_psi(mj);
3051    }
3052    let Some(d) = get_spatial_feature_dim(resolvedspec, term_idx) else {
3053        return false;
3054    };
3055    if d <= 1 {
3056        return false;
3057    }
3058    let Some(eta) = get_spatial_aniso_log_scales(resolvedspec, term_idx) else {
3059        return false;
3060    };
3061    if eta.len() != d {
3062        return false;
3063    }
3064    !matches!(
3065        resolvedspec.smooth_terms.get(term_idx).map(|term| &term.basis),
3066        Some(SmoothBasisSpec::Duchon { .. })
3067    )
3068}
3069
3070pub fn set_spatial_length_scale(
3071    spec: &mut TermCollectionSpec,
3072    term_idx: usize,
3073    length_scale: f64,
3074) -> Result<(), EstimationError> {
3075    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3076        crate::bail_invalid_estim!("spatial length-scale term index {term_idx} out of range");
3077    };
3078    match &mut term.basis {
3079        SmoothBasisSpec::ThinPlate { spec, .. } => {
3080            spec.length_scale = length_scale;
3081            Ok(())
3082        }
3083        SmoothBasisSpec::Matern { spec, .. } => {
3084            spec.length_scale = length_scale;
3085            Ok(())
3086        }
3087        SmoothBasisSpec::Duchon { spec, .. } => {
3088            spec.length_scale = Some(length_scale);
3089            Ok(())
3090        }
3091        _ => Err(EstimationError::InvalidInput(format!(
3092            "term '{}' does not expose a spatial length scale",
3093            term.name
3094        ))),
3095    }
3096}
3097
3098pub fn get_spatial_length_scale(spec: &TermCollectionSpec, term_idx: usize) -> Option<f64> {
3099    spec.smooth_terms
3100        .get(term_idx)
3101        .and_then(|term| match &term.basis {
3102            SmoothBasisSpec::ThinPlate { spec, .. } => Some(spec.length_scale),
3103            SmoothBasisSpec::Matern { spec, .. } => Some(spec.length_scale),
3104            SmoothBasisSpec::Duchon { spec, .. } => spec.length_scale,
3105            _ => None,
3106        })
3107}
3108
3109pub fn spatial_term_supports_hyper_optimization(spec: &TermCollectionSpec, term_idx: usize) -> bool {
3110    // Ordinary penalized thin-plate regression splines do not have an
3111    // identifiable kernel scale once REML is already learning the smoothing
3112    // penalty. Treat the resolved length scale as fixed geometry; enrolling a
3113    // scalar TPS kappa axis creates the flat ρ/κ valleys reported in #718,
3114    // #721, #731, and #732.
3115    if let Some(term) = spec.smooth_terms.get(term_idx)
3116        && let SmoothBasisSpec::ThinPlate { .. } = &term.basis
3117    {
3118        return false;
3119    }
3120
3121    // Duchon anisotropy η is a FIXED, geometry-derived basis parameter, NOT a
3122    // REML hyper axis: the metric is estimated once from the knot-cloud spread
3123    // (`auto_seed_aniso_contrasts`, applied on every Duchon basis build) and the
3124    // Hilbert-scale λ's carry all learned smoothness. So a pure Duchon (no κ)
3125    // contributes no outer optimization axis even when `scale_dims` is on —
3126    // "standardize the geometry, then learn the smoothness." Only an explicit
3127    // kernel length scale κ (the Matérn / hybrid path) is optimized here.
3128    //
3129    // ISOTROPIC Matérn: the *default* `matern(x1, x2)` is isotropic
3130    // (`scale_dims=false` → `aniso_log_scales = None`). It contributes exactly
3131    // ONE κ optimization axis — its scalar log-κ. The shared GAMLSS /
3132    // location-scale exact-joint ψ engine and the spatial-κ joint outer solver
3133    // both require an isotropic Matérn block to expose this single isotropic κ
3134    // axis (#822/#851); without it the per-block ψ-derivative lists are empty
3135    // and the joint-ψ hooks degenerate to `None`. The isotropic κ is the lone
3136    // kernel hyper axis here, mirroring the per-axis ψ ARD that the anisotropic
3137    // path exposes (just collapsed to one dimension).
3138    //
3139    // ANISOTROPIC Matérn (`scale_dims=true` → `aniso_log_scales = Some`) keeps
3140    // its per-axis kernel-η ARD: the d-dimensional ψ search is the *point* of
3141    // the anisotropic request ("Matérn keeps its kernel-η ARD").
3142    //
3143    // Either way a Matérn term always enrolls a κ/ψ axis (1 isotropic, or d
3144    // anisotropic), so `spatial_dims_per_term` reports the correct count.
3145    if let Some(term) = spec.smooth_terms.get(term_idx)
3146        && let SmoothBasisSpec::Matern { .. } = &term.basis
3147    {
3148        return true;
3149    }
3150
3151    // Measure-jet geometry dials are outer ψ coordinates; enrollment is
3152    // owned by `measure_jet_enrolls_psi`.
3153    if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
3154        return measure_jet_enrolls_psi(mj);
3155    }
3156
3157    // Constant-curvature smooths always enroll their single signed curvature κ
3158    // as an outer ψ-coordinate (#944 stage 3): κ̂ is the headline estimand, so
3159    // unlike a fixed-ℓ kernel it is fitted by default, not gated on a
3160    // user-supplied scale. The coordinate is raw κ (interior κ = 0), and its
3161    // exact design/penalty κ-derivatives come from
3162    // `build_constant_curvature_basis_kappa_derivatives`.
3163    if constant_curvature_term_spec(spec, term_idx).is_some() {
3164        return true;
3165    }
3166
3167    get_spatial_length_scale(spec, term_idx).is_some()
3168}
3169
3170/// The measure-jet term's spec, when `term_idx` is a measure-jet smooth.
3171/// Single accessor for every dial-plumbing dispatch below.
3172pub fn measure_jet_term_spec(
3173    spec: &TermCollectionSpec,
3174    term_idx: usize,
3175) -> Option<&crate::basis::MeasureJetBasisSpec> {
3176    spec.smooth_terms
3177        .get(term_idx)
3178        .and_then(|term| match &term.basis {
3179            SmoothBasisSpec::MeasureJet { spec, .. } => Some(spec),
3180            _ => None,
3181        })
3182}
3183
3184/// Single source for measure-jet outer-ψ enrollment: the lnτ dial is
3185/// undefined in the τ = 0 pseudo-inverse oracle mode (see
3186/// `build_measure_jet_basis_psi_derivatives`), so only a positive ridge
3187/// enrolls the dial group. `spatial_term_supports_hyper_optimization` and
3188/// `spatial_term_uses_per_axis_psi` both defer here so the θ-layout
3189/// sources cannot disagree.
3190pub fn measure_jet_enrolls_psi(mj: &crate::basis::MeasureJetBasisSpec) -> bool {
3191    // Two independent enrollment sources (#1116), both explicit:
3192    //   * the design-moving representer length-scale ℓ (`learn_length_scale`),
3193    //     available in every mode when the spec opts in;
3194    //   * the multiscale penalty dials (s, α, lnτ): the per-scale spectral
3195    //     split's (α, lnτ) ride the explicit `multiscale` opt-in, and the lnτ
3196    //     channel additionally needs a positive ridge (τ = 0 is the
3197    //     pseudo-inverse oracle mode where lnτ is undefined).
3198    // A term enrolls if EITHER source is active.
3199    measure_jet_learns_length_scale(mj)
3200        || (mj.tau0 > 0.0 && crate::basis::measure_jet_multiscale_mode(mj))
3201}
3202
3203/// Whether the design-moving ℓ dial is enrolled for this term. ℓ is fixed by
3204/// default and learnable in every mode only when `learn_length_scale = true`.
3205pub fn measure_jet_learns_length_scale(mj: &crate::basis::MeasureJetBasisSpec) -> bool {
3206    mj.learn_length_scale
3207}
3208
3209pub fn freeze_measure_jet_length_scale_learning(spec: &mut TermCollectionSpec) -> usize {
3210    let mut frozen = 0;
3211    for term in spec.smooth_terms.iter_mut() {
3212        if let SmoothBasisSpec::MeasureJet { spec: mj, .. } = &mut term.basis
3213            && mj.learn_length_scale
3214        {
3215            mj.learn_length_scale = false;
3216            frozen += 1;
3217        }
3218    }
3219    frozen
3220}
3221
/// `(lower, upper)` box for the measure-jet `α` dial.
///
/// Measure-jet dials are NOT log-kernel-scales, so the κ-window machinery
/// never applies to them. `α` spans density-weighted (0) through
/// past-Coifman–Lafon (>1) normalization. (The energy order `s` is the
/// pinned explicit value or absorbed by the REML-learned per-scale
/// amplitudes — see `measure_jet_penalty_psi_dim` — so it carries no dial
/// box.)
pub const MEASURE_JET_PSI_ALPHA_BOUNDS: (f64, f64) = (-1.0, 3.0);

/// `(lower, upper)` box for the measure-jet `lnτ` dial, covering the ridge
/// from numerically-exact projection to heavy noise-floor damping.
/// Numerically `ln(1e-8) = -18.4206…` and `ln(1e2) = 4.6051…`.
pub const MEASURE_JET_PSI_LN_TAU_BOUNDS: (f64, f64) = (-18.420680743952367, 4.605170185988092);

/// Log-ℓ box for the design-moving representer length-scale dial (#1116).
///
/// An ABSOLUTE window in the data coordinate scale (ln of ℓ ∈ [1e-3, 1e2]),
/// used only when the spec explicitly enrolls the learned representer
/// range. It is absolute rather than seed-relative so the bound producer
/// needs no data view, matching the other dial boxes.
/// `ln(1e-3) = -6.9077…`, `ln(1e2) = 4.6051…`.
pub const MEASURE_JET_PSI_LN_LENGTH_SCALE_BOUNDS: (f64, f64) = (-6.907755278982137, 4.605170185988092);
3239
3240/// Number of multiscale PENALTY dials (excluding the design-moving ℓ):
3241/// multiscale (per-scale spectral) mode carries (α, lnτ) = 2 — the order is
3242/// either the pinned explicit `s` or absorbed by the REML-learned per-scale
3243/// amplitudes, so it is NOT a dial; single-scale (the default) carries none.
3244/// MUST agree with the penalty-coordinate layout of
3245/// `build_measure_jet_basis_psi_derivatives` (its `per_level` branch always
3246/// emits exactly the (α, lnτ) coordinate pair).
3247pub fn measure_jet_penalty_psi_dim(mj: &crate::basis::MeasureJetBasisSpec) -> usize {
3248    if crate::basis::measure_jet_multiscale_mode(mj) {
3249        2
3250    } else {
3251        0
3252    }
3253}
3254
3255/// ψ dimension of a measure-jet term. The design-moving ℓ dial (when enrolled)
3256/// is coordinate 0; the multiscale penalty dials follow. MUST agree with the
3257/// coordinate layout of `build_measure_jet_basis_psi_derivatives` (ℓ first).
3258pub fn measure_jet_psi_dim(mj: &crate::basis::MeasureJetBasisSpec) -> usize {
3259    usize::from(measure_jet_learns_length_scale(mj)) + measure_jet_penalty_psi_dim(mj)
3260}
3261
3262/// Seed ψ from the term's realized dials, in producer coordinate order: ℓ first
3263/// (when enrolled), then the multiscale penalty dials. The ℓ seed is the
3264/// realized representer range `ln(length_scale)` (the resolved spec carries the
3265/// concrete auto value after the design build/freeze).
3266pub fn measure_jet_psi_seed(mj: &crate::basis::MeasureJetBasisSpec) -> Vec<f64> {
3267    let mut seed = Vec::with_capacity(measure_jet_psi_dim(mj));
3268    if measure_jet_learns_length_scale(mj) {
3269        // length_scale > 0 after resolution; the 0.0 sentinel (pre-resolution)
3270        // falls back to the centre of the log-ℓ box so the optimizer still
3271        // starts feasible and the first data-aware reseed corrects it.
3272        let ell = if mj.length_scale > 0.0 {
3273            mj.length_scale
3274        } else {
3275            1.0
3276        };
3277        seed.push(ell.ln());
3278    }
3279    if measure_jet_penalty_psi_dim(mj) > 0 {
3280        // Multiscale penalty dials, producer order: (α, lnτ).
3281        let ln_tau = mj.tau0.max(f64::MIN_POSITIVE).ln();
3282        seed.extend_from_slice(&[mj.alpha, ln_tau]);
3283    }
3284    seed
3285}
3286
3287/// One end of the per-coordinate dial boxes, in producer coordinate order
3288/// (ℓ first when enrolled, then the multiscale penalty dials).
3289pub fn measure_jet_psi_bound_values(mj: &crate::basis::MeasureJetBasisSpec, upper: bool) -> Vec<f64> {
3290    let pick = |b: (f64, f64)| if upper { b.1 } else { b.0 };
3291    let mut bounds = Vec::with_capacity(measure_jet_psi_dim(mj));
3292    if measure_jet_learns_length_scale(mj) {
3293        bounds.push(pick(MEASURE_JET_PSI_LN_LENGTH_SCALE_BOUNDS));
3294    }
3295    if measure_jet_penalty_psi_dim(mj) > 0 {
3296        // Multiscale penalty dials, producer order: (α, lnτ).
3297        bounds.push(pick(MEASURE_JET_PSI_ALPHA_BOUNDS));
3298        bounds.push(pick(MEASURE_JET_PSI_LN_TAU_BOUNDS));
3299    }
3300    bounds
3301}
3302
3303/// Write optimized ψ dials back into a measure-jet spec. Returns `true` when
3304/// any dial actually moved. The geometry (centers, masses, band, ℓ, z) is
3305/// ψ-FIXED by contract — only the dials change, so frozen-quadrature
3306/// rebuilds reproduce the identical penalty layout at the new dials.
3307pub fn apply_measure_jet_psi(
3308    mj: &mut crate::basis::MeasureJetBasisSpec,
3309    psi: &[f64],
3310) -> Result<bool, EstimationError> {
3311    if psi.len() != measure_jet_psi_dim(mj) {
3312        crate::bail_invalid_estim!(
3313            "measure-jet ψ write-back dimension mismatch: got {} values for a {}-dial term",
3314            psi.len(),
3315            measure_jet_psi_dim(mj)
3316        );
3317    }
3318    let mut changed = false;
3319    // Coordinate 0 (when enrolled) is the design-moving ln(ℓ); the multiscale
3320    // penalty dials follow. Same order as `measure_jet_psi_seed` and the
3321    // producer (`build_measure_jet_basis_psi_derivatives`).
3322    let mut cursor = 0usize;
3323    if measure_jet_learns_length_scale(mj) {
3324        let next_ell = psi[cursor].exp();
3325        cursor += 1;
3326        if !(next_ell.is_finite() && next_ell > 0.0) {
3327            crate::bail_invalid_estim!(
3328                "measure-jet ψ write-back produced a non-finite/non-positive length_scale (ℓ={next_ell})"
3329            );
3330        }
3331        if next_ell != mj.length_scale {
3332            mj.length_scale = next_ell;
3333            changed = true;
3334        }
3335    }
3336    if measure_jet_penalty_psi_dim(mj) > 0 {
3337        // Multiscale penalty dials, producer order: (α, lnτ). The order `s` is
3338        // not a dial (pinned explicit or absorbed by the per-scale amplitudes).
3339        let next_alpha = psi[cursor];
3340        let next_tau = psi[cursor + 1].exp();
3341        if !(next_alpha.is_finite() && next_tau.is_finite() && next_tau > 0.0) {
3342            crate::bail_invalid_estim!(
3343                "measure-jet ψ write-back produced non-finite dials (alpha={next_alpha}, tau={next_tau})"
3344            );
3345        }
3346        if next_alpha != mj.alpha {
3347            mj.alpha = next_alpha;
3348            changed = true;
3349        }
3350        if next_tau != mj.tau0 {
3351            mj.tau0 = next_tau;
3352            changed = true;
3353        }
3354    }
3355    Ok(changed)
3356}
3357
3358/// Collection-level measure-jet dial write-back (the `apply_tospec` /
3359/// realizer-side entry). Returns whether anything moved.
3360pub fn set_measure_jet_psi_dials(
3361    spec: &mut TermCollectionSpec,
3362    term_idx: usize,
3363    psi: &[f64],
3364) -> Result<bool, EstimationError> {
3365    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3366        crate::bail_invalid_estim!("measure-jet ψ write-back: term index {term_idx} out of range");
3367    };
3368    set_single_term_measure_jet_psi_dials(term, psi)
3369}
3370
3371/// Single-term dial write-back: the shared match+apply core, also used
3372/// directly on the cached per-trial build spec (whose caller has already
3373/// change-checked at the collection level and rebuilds regardless of the
3374/// moved flag).
3375pub fn set_single_term_measure_jet_psi_dials(
3376    term: &mut SmoothTermSpec,
3377    psi: &[f64],
3378) -> Result<bool, EstimationError> {
3379    let SmoothBasisSpec::MeasureJet { spec: mj, .. } = &mut term.basis else {
3380        crate::bail_invalid_estim!("measure-jet ψ write-back targeted a non-measure-jet term");
3381    };
3382    apply_measure_jet_psi(mj, psi)
3383}
3384
3385/// The constant-curvature smooth's spec, when `term_idx` is one. Single
3386/// accessor for every κ-ψ dispatch below, mirroring `measure_jet_term_spec`.
3387pub fn constant_curvature_term_spec(
3388    spec: &TermCollectionSpec,
3389    term_idx: usize,
3390) -> Option<&crate::basis::ConstantCurvatureBasisSpec> {
3391    spec.smooth_terms
3392        .get(term_idx)
3393        .and_then(|term| match &term.basis {
3394            SmoothBasisSpec::ConstantCurvature { spec, .. } => Some(spec),
3395            _ => None,
3396        })
3397}
3398
/// Fraction of the chart-edge curvature scale `1/R²` that bounds |κ|, where
/// `R²` is the data's maximum squared chart radius:
/// [`constant_curvature_kappa_bounds`] returns the window
/// `±CONSTANT_CURVATURE_KAPPA_CHART_FRACTION / R²`.
///
/// The κ-stereographic chart is valid for `1 + κ‖x‖² > 0`; at `|κ| = 1/R²`
/// the gauge `1 + κ‖x‖²` reaches the chart edge for the farthest data point,
/// so the optimizer is boxed to a safe fraction of that scale on both sides.
/// κ = 0 (flat) is the centre of the window, an interior point of the
/// `S^d ← ℝ^d → H^d` family — exactly the reachability the raw-κ (not log-κ)
/// coordinate exists to preserve.
pub const CONSTANT_CURVATURE_KAPPA_CHART_FRACTION: f64 = 0.5;
3407
/// Floor on the data's squared chart radius used to scale the κ window, so a
/// degenerate (near-origin) point cloud still yields a finite, usable bracket
/// rather than an unbounded one.
///
/// [`constant_curvature_kappa_bounds`] starts its running maximum of `R²` at
/// this value, so the divisor in `FRACTION / R²` is never smaller than it.
pub const CONSTANT_CURVATURE_MIN_CHART_RADIUS2: f64 = 1e-8;
3412
3413/// `(κ_min, κ_max)` outer-optimization window for a constant-curvature term,
3414/// derived from the data's maximum squared chart radius `R²` so the κ-jets
3415/// never leave the κ-stereographic chart. Symmetric about κ = 0:
3416/// `±CONSTANT_CURVATURE_KAPPA_CHART_FRACTION / R²`.
3417pub fn constant_curvature_kappa_bounds(
3418    data: ArrayView2<'_, f64>,
3419    spec: &TermCollectionSpec,
3420    term_idx: usize,
3421) -> (f64, f64) {
3422    let feature_cols = match spec.smooth_terms.get(term_idx).map(|t| &t.basis) {
3423        Some(SmoothBasisSpec::ConstantCurvature { feature_cols, .. }) => feature_cols,
3424        _ => return (-1.0, 1.0),
3425    };
3426    let mut max_r2 = CONSTANT_CURVATURE_MIN_CHART_RADIUS2;
3427    for row in data.outer_iter() {
3428        let mut r2 = 0.0_f64;
3429        for &c in feature_cols.iter() {
3430            if let Some(&v) = row.get(c)
3431                && v.is_finite()
3432            {
3433                r2 += v * v;
3434            }
3435        }
3436        if r2 > max_r2 {
3437            max_r2 = r2;
3438        }
3439    }
3440    let half = CONSTANT_CURVATURE_KAPPA_CHART_FRACTION / max_r2;
3441    (-half, half)
3442}
3443
3444/// Write the optimized κ back into a constant-curvature term spec. Returns
3445/// `true` when κ moved. Centers, ℓ, and the constraint transform `z` are
3446/// κ-FIXED by the basis κ-contract, so only `kappa` changes.
3447pub fn set_constant_curvature_kappa(
3448    spec: &mut TermCollectionSpec,
3449    term_idx: usize,
3450    psi: &[f64],
3451) -> Result<bool, EstimationError> {
3452    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3453        crate::bail_invalid_estim!(
3454            "constant-curvature κ write-back: term index {term_idx} out of range"
3455        );
3456    };
3457    set_single_term_constant_curvature_kappa(term, psi)
3458}
3459
3460/// Single-term κ write-back: the shared validate+apply core, also used directly
3461/// on the cached per-trial build spec in the incremental realizer (whose caller
3462/// has already change-checked at the collection level and rebuilds regardless
3463/// of the moved flag). Mirrors [`set_single_term_measure_jet_psi_dials`].
3464pub fn set_single_term_constant_curvature_kappa(
3465    term: &mut SmoothTermSpec,
3466    psi: &[f64],
3467) -> Result<bool, EstimationError> {
3468    if psi.len() != 1 {
3469        crate::bail_invalid_estim!(
3470            "constant-curvature κ write-back expects exactly one value, got {}",
3471            psi.len()
3472        );
3473    }
3474    let next_kappa = psi[0];
3475    if !next_kappa.is_finite() {
3476        crate::bail_invalid_estim!(
3477            "constant-curvature κ write-back produced a non-finite κ = {next_kappa}"
3478        );
3479    }
3480    let SmoothBasisSpec::ConstantCurvature { spec: cc, .. } = &mut term.basis else {
3481        crate::bail_invalid_estim!(
3482            "constant-curvature κ write-back targeted a non-constant-curvature term"
3483        );
3484    };
3485    if cc.kappa != next_kappa {
3486        cc.kappa = next_kappa;
3487        Ok(true)
3488    } else {
3489        Ok(false)
3490    }
3491}
3492
3493/// Returns `true` when a spatial term has NO outer optimization axes — i.e.
3494/// the user provided an explicit `length_scale` and the term does not enroll
3495/// REML-side per-axis ψ contrasts, so both the scalar κ and any fixed geometry
3496/// anisotropy are anchored.
3497///
3498/// This is the per-term predicate that distinguishes "fixed kernel scale"
3499/// from "optimize the kernel scale" within the family entry points that
3500/// want to honor an explicit user-supplied scale (e.g. Bernoulli
3501/// marginal-slope, where the joint-spatial outer solver otherwise spends
3502/// ~80 iters stalled on the user's chosen ρ at high gradient).
3503pub fn spatial_term_has_locked_kappa(spec: &TermCollectionSpec, term_idx: usize) -> bool {
3504    get_spatial_length_scale(spec, term_idx).is_some()
3505        && !spatial_term_uses_per_axis_psi(spec, term_idx)
3506}
3507
3508pub fn all_spatial_terms_kappa_fixed(spec: &TermCollectionSpec) -> bool {
3509    spec.smooth_terms.iter().enumerate().all(|(idx, _)| {
3510        !spatial_term_supports_hyper_optimization(spec, idx)
3511            || spatial_term_has_locked_kappa(spec, idx)
3512    })
3513}
3514
3515pub fn spatial_identifiability_policy(termspec: &SmoothTermSpec) -> Option<&SpatialIdentifiability> {
3516    match &termspec.basis {
3517        SmoothBasisSpec::ThinPlate { spec, .. } => Some(&spec.identifiability),
3518        SmoothBasisSpec::Duchon { spec, .. } => Some(&spec.identifiability),
3519        _ => None,
3520    }
3521}
3522
/// Standard deviation of the wide, weakly-informative symmetric `Normal` prior
/// placed on a relaxable double-penalty smooth's `DoublePenaltyNullspace`
/// selection coordinate when the fit is well-determined.
///
/// [`is_nullspace_degeneracy_prior`] recognizes that prior by comparing a
/// `Normal` prior's `sd` against this value with exact floating-point
/// equality, so the prior must be constructed from this constant itself.
pub const NULLSPACE_WELLDET_DEGENERACY_RHO_SD: f64 = 15.0;
3527
3528/// True iff `prior` is the well-determined double-penalty null-space
3529/// degeneracy prior placed on a `DoublePenaltyNullspace` selection coordinate.
3530pub fn is_nullspace_degeneracy_prior(prior: &gam_spec::RhoPrior) -> bool {
3531    matches!(
3532        prior,
3533        gam_spec::RhoPrior::Normal { mean, sd }
3534            if *mean == 0.0 && *sd == NULLSPACE_WELLDET_DEGENERACY_RHO_SD
3535    )
3536}
3537
/// Numerator of the lower κ edge of the per-term, data-derived ψ = log κ
/// window: [`spatial_term_psi_bounds`] sets
/// `ψ_lo = ln(KERNEL_RANGE_MIN_DIAMETER_FRACTION / r_max)`, where `r_max` is
/// the maximum pairwise distance of the term's geometry.
///
/// Together with [`KERNEL_RANGE_MAX_SPACING_MULTIPLE`] this gives the safe
/// operating range documented in [`crate::basis::build_matern_basis`] /
/// [`crate::basis::build_duchon_basis`]:
///   κ ∈ [2 / r_max, 1e2 / r_min]
/// where (r_min, r_max) are pairwise-distance extrema of the term's resolved
/// centers (post-fit) or the standardized feature data columns (pre-fit).
///
/// Since κ = 1 / length_scale, this edge caps the length scale at
/// `r_max / 2`. The comment inside `spatial_term_psi_bounds` gives the reason:
/// length scales substantially larger than the data diameter make the radial
/// columns nearly collinear with their polynomial nullspace.
pub const KERNEL_RANGE_MIN_DIAMETER_FRACTION: f64 = 2.0;
3550
/// Numerator of the upper κ edge of the data-derived kernel-range window:
/// [`spatial_term_psi_bounds`] sets
/// `ψ_hi = ln(KERNEL_RANGE_MAX_SPACING_MULTIPLE / r_min)`, where `r_min` is
/// the minimum pairwise distance of the term's geometry. Since
/// κ = 1 / length_scale, this floors the length scale at `r_min / 100`.
///
/// NOTE(review): an earlier version of this comment attributed the cap to the
/// radial columns going nearly collinear with the polynomial nullspace, but
/// the comment inside `spatial_term_psi_bounds` ties that collinearity to
/// LARGE length scales (the lower κ edge). Confirm the intended rationale for
/// this upper edge.
pub const KERNEL_RANGE_MAX_SPACING_MULTIPLE: f64 = 1e2;
3556
3557
3558/// Returns ψ-space bounds (ψ_lo = ln(κ_lo), ψ_hi = ln(κ_hi)).
3559///
3560/// When geometry is unavailable (e.g., fewer than 2 distinct points), falls
3561/// back to the scalar `options.min_length_scale` / `options.max_length_scale`
3562/// window so the outer optimizer never sees NaN bounds.
3563///
3564/// The returned window is intersected with the options window so user-set
3565/// `min_length_scale` / `max_length_scale` remain hard limits.
3566pub fn spatial_term_psi_bounds(
3567    data: ArrayView2<'_, f64>,
3568    spec: &TermCollectionSpec,
3569    term_idx: usize,
3570    options: &SpatialLengthScaleOptimizationOptions,
3571) -> (f64, f64) {
3572    let fallback = (
3573        -options.max_length_scale.ln(),
3574        -options.min_length_scale.ln(),
3575    );
3576    // Constant-curvature: the ψ coordinate is the raw signed κ, so its window is
3577    // the chart-feasible κ bracket, NOT a log-ℓ window. Mirrors the aniso bounds
3578    // path's `constant_curvature_kappa_bounds` branch so the isotropic
3579    // (non-aniso) seed clamp projects κ into the right interval.
3580    if constant_curvature_term_spec(spec, term_idx).is_some() {
3581        return constant_curvature_kappa_bounds(data, spec, term_idx);
3582    }
3583    let Some(term) = spec.smooth_terms.get(term_idx) else {
3584        return fallback;
3585    };
3586    // Prefer resolved centers (post-fit) since they live in the same standardized
3587    // space the kernel actually sees. Centers are capped at `default_num_centers`
3588    // (<=2000), so exact pairwise bounds are cheap (<4M ops). If centers are
3589    // not yet UserProvided, fall back to the standardized feature data columns
3590    // with the capped-sample path (O(K²·d), K=1024) — the sample is
3591    // conservative for κ bounds (see `pairwise_distance_bounds_sampled`
3592    // docs): it never excludes a feasible κ the exact method would include.
3593    //
3594    // Under anisotropy the kernel metric is y-space (y_a = exp(η_a) x_a),
3595    // so r_min/r_max must be y-space distances. This matters only when the
3596    // spec already carries calibrated η_a at setup time (e.g., warm-start
3597    // or refit paths); for fresh optimization η_a starts at 0 and y = x.
3598    let aniso = get_spatial_aniso_log_scales(spec, term_idx);
3599    let r_bounds = match spatial_term_center_strategy(term) {
3600        Some(CenterStrategy::UserProvided(centers)) if centers.nrows() >= 2 => {
3601            match aniso.as_deref() {
3602                Some(eta) if eta.len() == centers.ncols() => {
3603                    let y = points_in_aniso_y_space(centers.view(), eta);
3604                    pairwise_distance_bounds(y.view())
3605                }
3606                _ => pairwise_distance_bounds(centers.view()),
3607            }
3608        }
3609        _ => standardized_spatial_term_data(data, term)
3610            .ok()
3611            .and_then(|x| match aniso.as_deref() {
3612                Some(eta) if eta.len() == x.ncols() => {
3613                    let y = points_in_aniso_y_space(x.view(), eta);
3614                    pairwise_distance_bounds_sampled(y.view())
3615                }
3616                _ => pairwise_distance_bounds_sampled(x.view()),
3617            }),
3618    };
3619    let Some((r_min, r_max)) = r_bounds else {
3620        return fallback;
3621    };
3622    // Length scales substantially larger than the data diameter make radial
3623    // TPS/Matern columns nearly collinear with their polynomial nullspace.
3624    // The nullspace already carries constant/linear low-frequency structure,
3625    // so cap the kernel range at the diameter scale instead of letting the
3626    // optimizer enter a numerically degenerate basis geometry.
3627    let psi_lo_data = (KERNEL_RANGE_MIN_DIAMETER_FRACTION / r_max).ln();
3628    let psi_hi_data = (KERNEL_RANGE_MAX_SPACING_MULTIPLE / r_min).ln();
3629    // #1074: the Matérn-specific length-scale ceiling that used to live here was
3630    // deleted. It was masking, not fixing, the real defect: a hard upper bound on
3631    // the kernel range that pinned the κ-optimizer short rather than letting the
3632    // optimizer find the REML optimum. Matérn now shares the same generic geometry
3633    // window as Duchon / TPS (`KERNEL_RANGE_MIN_DIAMETER_FRACTION / r_max` floor,
3634    // `KERNEL_RANGE_MAX_SPACING_MULTIPLE / r_min` ceiling); the #1357 fully-flat
3635    // collapse corner is guarded by the EDF-collapse guard in
3636    // `spatial_optimization.rs`, which acts on the realized fit, not on a clamp.
3637    // Intersect with the options window so min/max_length_scale remain hard caps.
3638    let psi_lo = psi_lo_data.max(fallback.0);
3639    let psi_hi = psi_hi_data.min(fallback.1);
3640    if psi_lo >= psi_hi {
3641        // Degenerate intersection — fall back to the options window to keep the
3642        // outer optimizer from collapsing to a point.
3643        return fallback;
3644    }
3645    (psi_lo, psi_hi)
3646}
3647
3648/// Data-derived ψ seed for a spatial term when the user has not set an
3649/// explicit length_scale on its basis spec. Uses the geometric mean of the
3650/// data-informed kappa range (i.e., the midpoint of the ψ window).
3651pub fn spatial_term_psi_seed(
3652    data: ArrayView2<'_, f64>,
3653    spec: &TermCollectionSpec,
3654    term_idx: usize,
3655    options: &SpatialLengthScaleOptimizationOptions,
3656) -> Option<f64> {
3657    if get_spatial_length_scale(spec, term_idx).is_some() {
3658        return None; // user/spec-provided length_scale wins
3659    }
3660    let (psi_lo, psi_hi) = spatial_term_psi_bounds(data, spec, term_idx, options);
3661    Some(0.5 * (psi_lo + psi_hi))
3662}
3663
/// Splits a ψ vector into a global length scale and per-axis contrasts.
///
/// * Empty `psi`: length scale `exp(-0) = 1`, no contrasts.
/// * One entry: length scale `exp(-psi[0])`, no contrasts.
/// * Several entries: length scale `exp(-mean(psi))`, and contrasts
///   `psi[a] - mean(psi)` for each axis.
///
/// The length scale slot is always `Some`.
pub fn spatial_term_psi_to_length_scale_and_aniso(psi: &[f64]) -> (Option<f64>, Option<Vec<f64>>) {
    let (log_kappa, contrasts) = match psi {
        [] => (0.0, None),
        [only] => (*only, None),
        axes => {
            let mean = axes.iter().sum::<f64>() / axes.len() as f64;
            (mean, Some(axes.iter().map(|&value| value - mean).collect()))
        }
    };
    (Some((-log_kappa).exp()), contrasts)
}
3675
3676/// Get the `aniso_log_scales` from a spatial term, if present.
3677pub fn get_spatial_aniso_log_scales(
3678    spec: &TermCollectionSpec,
3679    term_idx: usize,
3680) -> Option<Vec<f64>> {
3681    spec.smooth_terms
3682        .get(term_idx)
3683        .and_then(|term| match &term.basis {
3684            SmoothBasisSpec::Matern { spec, .. } => spec.aniso_log_scales.clone(),
3685            SmoothBasisSpec::Duchon { spec, .. } => spec.aniso_log_scales.clone(),
3686            _ => None,
3687        })
3688}
3689
3690/// Per-axis response-structure score for anisotropy seeding.
3691///
3692/// For each spatial axis `a`, sort the response `y` by the axis coordinate
3693/// `x_a` and measure the total squared successive variation of the sorted
3694/// response, `tv_a = Σ_i (y_{σ(i+1)} − y_{σ(i)})²` where `σ` orders rows by
3695/// `x_a`. An axis that carries real (possibly nonlinear) signal makes `y` vary
3696/// SMOOTHLY when the rows are walked in that axis's order, so `tv_a` is SMALL;
3697/// a pure-nuisance axis leaves `y` looking unordered, so `tv_a` is LARGE.
3698///
3699/// This deliberately does NOT use a linear correlation `corr(x_a, y)`: for an
3700/// odd, symmetric signal such as `sin(2·x1)` over a symmetric domain the linear
3701/// correlation is ~0 on the *signal* axis, which would misdirect the seed. The
3702/// total-variation-of-sorted-response score captures nonlinear association.
3703///
3704/// Returns `score_a = −½·ln(tv_a + ε)` (larger ⇒ more signal on axis `a`),
3705/// centered to sum to zero, or `None` when the data is degenerate (too few
3706/// rows, non-finite, or all axes equally (un)structured). The caller adds a
3707/// BOUNDED multiple of this to the geometry seed — it is a conservative nudge,
3708/// never a hard override.
3709pub fn response_aware_axis_contrasts(
3710    x: ndarray::ArrayView2<'_, f64>,
3711    y: ndarray::ArrayView1<'_, f64>,
3712) -> Option<Vec<f64>> {
3713    let n = x.nrows();
3714    let d = x.ncols();
3715    if d <= 1 || n < 4 || y.len() != n {
3716        return None;
3717    }
3718    if x.iter().any(|v| !v.is_finite()) || y.iter().any(|v| !v.is_finite()) {
3719        return None;
3720    }
3721    let mut scores = Vec::with_capacity(d);
3722    for a in 0..d {
3723        let mut order: Vec<usize> = (0..n).collect();
3724        let col = x.column(a);
3725        order.sort_by(|&i, &j| {
3726            col[i]
3727                .partial_cmp(&col[j])
3728                .unwrap_or(std::cmp::Ordering::Equal)
3729        });
3730        let mut tv = 0.0_f64;
3731        for w in order.windows(2) {
3732            let diff = y[w[1]] - y[w[0]];
3733            tv += diff * diff;
3734        }
3735        // ε guards against ln(0) on a perfectly flat / constant response.
3736        scores.push(-0.5 * (tv + 1e-12).ln());
3737    }
3738    if scores.iter().any(|v| !v.is_finite()) {
3739        return None;
3740    }
3741    let mean = scores.iter().sum::<f64>() / d as f64;
3742    let centered: Vec<f64> = scores.iter().map(|&s| s - mean).collect();
3743    // If every axis is equally structured the centered scores are ~0 and the
3744    // nudge is a no-op — return None so the geometry seed is used unchanged.
3745    if centered.iter().all(|&v| v.abs() < 1e-9) {
3746        return None;
3747    }
3748    Some(centered)
3749}
3750
3751/// Conservative, response-aware anisotropy seed nudge applied before the κ outer
3752/// loop. For each anisotropic spatial term it adds a BOUNDED multiple of the
3753/// per-axis response-structure contrast (`response_aware_axis_contrasts`) on top
3754/// of the existing geometry seed, so the optimizer starts in the correct basin
3755/// instead of at a response-blind near-symmetric point (the #1376 under-recovery
3756/// where a signal axis and a nuisance axis with equal coordinate spread seed to
3757/// ~[0,0]). The nudge is clamped to keep this a perturbation, never a hard
3758/// override, so shared aniso Matérn/Duchon fits cannot be destabilized by it.
3759pub fn apply_response_aware_anisotropy_seed(
3760    data: ArrayView2<'_, f64>,
3761    y: ndarray::ArrayView1<'_, f64>,
3762    spec: &mut TermCollectionSpec,
3763    spatial_terms: &[usize],
3764) {
3765    // Bound on the per-axis contrast nudge (in η units). One LN_2 ≈ 0.69 halves
3766    // the effective per-axis length scale; capping at LN_2 keeps the seed within
3767    // one optimizer log-step of the geometry seed while still breaking the
3768    // symmetric-seed trap.
3769    const MAX_NUDGE: f64 = std::f64::consts::LN_2;
3770    for &term_idx in spatial_terms {
3771        let Some(current_eta) = get_spatial_aniso_log_scales(spec, term_idx) else {
3772            continue;
3773        };
3774        let d = current_eta.len();
3775        if d <= 1 {
3776            continue;
3777        }
3778        let Some(term) = spec.smooth_terms.get(term_idx) else {
3779            continue;
3780        };
3781        let feature_cols = term.basis.structural_feature_cols();
3782        if feature_cols.len() != d {
3783            continue;
3784        }
3785        let Ok(x) = select_columns(data, &feature_cols) else {
3786            continue;
3787        };
3788        let Some(contrast) = response_aware_axis_contrasts(x.view(), y) else {
3789            continue;
3790        };
3791        let nudged: Vec<f64> = current_eta
3792            .iter()
3793            .zip(contrast.iter())
3794            .map(|(&eta_a, &c_a)| eta_a + c_a.clamp(-MAX_NUDGE, MAX_NUDGE))
3795            .collect();
3796        // `set_spatial_aniso_log_scales` re-centers to Σ η = 0. A term that does
3797        // not support aniso scales is silently skipped (the seed is optional).
3798        if let Err(err) = set_spatial_aniso_log_scales(spec, term_idx, nudged) {
3799            log::debug!(
3800                "[spatial-kappa] response-aware anisotropy seed skipped for term {term_idx}: {err}"
3801            );
3802        }
3803    }
3804}
3805
3806/// Get the number of feature columns (spatial dimensionality) for a spatial term.
3807pub fn get_spatial_feature_dim(spec: &TermCollectionSpec, term_idx: usize) -> Option<usize> {
3808    spec.smooth_terms
3809        .get(term_idx)
3810        .and_then(|term| match &term.basis {
3811            SmoothBasisSpec::ThinPlate { feature_cols, .. } => Some(feature_cols.len()),
3812            SmoothBasisSpec::Matern { feature_cols, .. } => Some(feature_cols.len()),
3813            SmoothBasisSpec::Duchon { feature_cols, .. } => Some(feature_cols.len()),
3814            _ => None,
3815        })
3816}
3817
3818/// Log the learned per-axis spatial anisotropy for all spatial terms that
3819/// have `aniso_log_scales` set after optimization.
3820///
3821/// For scalar-scale families this reports eta, effective per-axis length
3822/// scales, and per-axis kappa values. For pure Duchon it reports the centered
3823/// eta contrasts only.
3824pub fn log_spatial_aniso_scales(spec: &TermCollectionSpec) {
3825    for (term_idx, term) in spec.smooth_terms.iter().enumerate() {
3826        let (aniso, length_scale) = match &term.basis {
3827            SmoothBasisSpec::Matern { spec, .. } => {
3828                (spec.aniso_log_scales.as_ref(), Some(spec.length_scale))
3829            }
3830            SmoothBasisSpec::Duchon { spec, .. } => {
3831                (spec.aniso_log_scales.as_ref(), spec.length_scale)
3832            }
3833            _ => (None, None),
3834        };
3835        let Some(eta) = aniso else { continue };
3836        if eta.is_empty() {
3837            continue;
3838        }
3839        let mut lines = match length_scale {
3840            Some(ls) => format!(
3841                "[spatial-kappa] term {} (\"{}\"): anisotropic length scales optimized (global length_scale={:.4})",
3842                term_idx, term.name, ls
3843            ),
3844            None => format!(
3845                "[spatial-kappa] term {} (\"{}\"): pure Duchon shape anisotropy optimized",
3846                term_idx, term.name
3847            ),
3848        };
3849        for (a, &eta_a) in eta.iter().enumerate() {
3850            if let Some(ls) = length_scale {
3851                let length_a = ls * (-eta_a).exp();
3852                let kappa_a = (1.0 / ls) * eta_a.exp();
3853                lines.push_str(&format!(
3854                    "\n  axis {}: eta={:+.4}, length={:.4}, kappa={:.4}",
3855                    a, eta_a, length_a, kappa_a
3856                ));
3857            } else {
3858                lines.push_str(&format!("\n  axis {}: eta={:+.4}", a, eta_a));
3859            }
3860        }
3861        log::info!("{}", lines);
3862    }
3863}
3864
3865/// Set `aniso_log_scales` on a spatial term's basis spec.
3866pub fn set_spatial_aniso_log_scales(
3867    spec: &mut TermCollectionSpec,
3868    term_idx: usize,
3869    eta: Vec<f64>,
3870) -> Result<(), EstimationError> {
3871    let eta = center_aniso_log_scales(&eta);
3872    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3873        crate::bail_invalid_estim!("spatial aniso_log_scales term index {term_idx} out of range");
3874    };
3875    match &mut term.basis {
3876        SmoothBasisSpec::Matern { spec, .. } => {
3877            spec.aniso_log_scales = Some(eta);
3878            Ok(())
3879        }
3880        SmoothBasisSpec::Duchon { spec, .. } => {
3881            spec.aniso_log_scales = Some(eta);
3882            Ok(())
3883        }
3884        _ => Err(EstimationError::InvalidInput(format!(
3885            "term '{}' does not support aniso_log_scales",
3886            term.name
3887        ))),
3888    }
3889}
3890
3891/// Sync knot-cloud-derived anisotropy contrasts from basis metadata back into
3892/// the mutable spec so the optimizer starts from the correct eta values.
3893///
3894/// Call this after building the smooth design but before initializing the
3895/// optimizer's psi coordinates. For each spatial term whose metadata contains
3896/// computed `aniso_log_scales`, this writes them into the spec.
3897pub fn sync_aniso_contrasts_from_metadata(
3898    spec: &mut TermCollectionSpec,
3899    design: &SmoothDesign,
3900) {
3901    for (term_idx, term) in design.terms.iter().enumerate() {
3902        let meta_aniso = match &term.metadata {
3903            BasisMetadata::Matern {
3904                aniso_log_scales, ..
3905            } => aniso_log_scales.clone(),
3906            BasisMetadata::Duchon {
3907                aniso_log_scales, ..
3908            } => aniso_log_scales.clone(),
3909            _ => None,
3910        };
3911        if let Some(eta) = meta_aniso
3912            && eta.len() > 1
3913        {
3914            set_spatial_aniso_log_scales(spec, term_idx, eta).ok();
3915        }
3916    }
3917}
3918
/// Options controlling the outer-loop optimization of spatial kernel length
/// scales (κ = 1 / length_scale).
///
/// The numeric invariants expected by downstream code are checked by
/// `SpatialLengthScaleOptimizationOptions::validate`; the `Default` values
/// satisfy them.
#[derive(Debug, Clone)]
pub struct SpatialLengthScaleOptimizationOptions {
    /// Enable outer-loop optimization over spatial κ (= 1 / length_scale)
    /// for supported radial-kernel smooths.
    /// This applies to ThinPlate, Matérn, and Duchon terms.
    pub enabled: bool,
    /// Maximum number of outer iterations in the exact joint [rho, psi] solve.
    pub max_outer_iter: usize,
    /// Relative improvement threshold for terminating the outer solve.
    /// Must be finite and > 0 to pass `validate`.
    pub rel_tol: f64,
    /// Initial log(length_scale) perturbation used for seed construction.
    /// Must be finite and > 0 to pass `validate`.
    pub log_step: f64,
    /// Minimum allowed length_scale during κ search.
    /// Must be finite, > 0, and strictly below `max_length_scale`.
    pub min_length_scale: f64,
    /// Maximum allowed length_scale during κ search.
    /// Must be finite, > 0, and strictly above `min_length_scale`.
    pub max_length_scale: f64,
    /// Automatic geometry-initializer threshold for large-scale spatial fits.
    ///
    /// When n exceeds twice this value, the fitter uses a spatially stratified
    /// subsample only to seed κ/anisotropy geometry: centers are resolved,
    /// axis contrasts are initialized from center/data spread, and one or two
    /// cheap ψ reseeding updates are applied. It never runs PIRLS, REML, ARC,
    /// BFGS, or any recursive optimizer on the pilot.
    ///
    /// The final coefficients, smoothing parameters, and spatial geometry are
    /// always optimized on the full dataset.
    ///
    /// Set to 0 to skip the pilot geometry initializer.
    pub pilot_subsample_threshold: usize,
}
3949
impl Default for SpatialLengthScaleOptimizationOptions {
    /// Default options: optimization enabled, at most 80 outer iterations, a
    /// relative tolerance of `1e-4`, a seed perturbation of `ln 2` in
    /// log(length_scale), a length-scale window of `[1e-3, 1e3]`, and a pilot
    /// subsample threshold of 10 000.
    fn default() -> Self {
        Self {
            enabled: true,
            max_outer_iter: 80,
            rel_tol: 1e-4,
            // ln 2 in log(length_scale) is a factor-of-two change in length scale.
            log_step: std::f64::consts::LN_2,
            min_length_scale: 1e-3,
            max_length_scale: 1e3,
            pilot_subsample_threshold: 10_000,
        }
    }
}
3963
3964impl SpatialLengthScaleOptimizationOptions {
3965    /// Validate the struct's invariants. Callers that construct these options
3966    /// from external input (CLI, config, Python API) should call this before
3967    /// passing the options into the fitter. Returns `Err` with a descriptive
3968    /// message when an invariant is violated; the fitter then panics or
3969    /// returns `EstimationError` at its own boundary.
3970    ///
3971    /// Invariants:
3972    ///   * `min_length_scale > 0`, finite
3973    ///   * `max_length_scale > 0`, finite
3974    ///   * `min_length_scale < max_length_scale`
3975    ///   * `rel_tol > 0`, finite
3976    ///   * `log_step > 0`, finite
3977    ///
3978    /// These invariants are what the downstream κ-bound and ψ-window code
3979    /// assumes (`-log(max_ls)` must be finite, `(min,max)` must not be
3980    /// inverted, etc.). Without validation, invalid options produce silent
3981    /// NaN-propagation inside the outer optimizer.
3982    pub fn validate(&self) -> Result<(), String> {
3983        if !self.min_length_scale.is_finite() || self.min_length_scale <= 0.0 {
3984            return Err(SmoothError::invalid_config(format!(
3985                "SpatialLengthScaleOptimizationOptions::min_length_scale must be > 0 and finite, got {}",
3986                self.min_length_scale
3987            ))
3988            .into());
3989        }
3990        if !self.max_length_scale.is_finite() || self.max_length_scale <= 0.0 {
3991            return Err(SmoothError::invalid_config(format!(
3992                "SpatialLengthScaleOptimizationOptions::max_length_scale must be > 0 and finite, got {}",
3993                self.max_length_scale
3994            ))
3995            .into());
3996        }
3997        if self.min_length_scale >= self.max_length_scale {
3998            return Err(SmoothError::invalid_config(format!(
3999                "SpatialLengthScaleOptimizationOptions requires min_length_scale < max_length_scale, got min={} max={}",
4000                self.min_length_scale, self.max_length_scale
4001            ))
4002            .into());
4003        }
4004        if !self.rel_tol.is_finite() || self.rel_tol <= 0.0 {
4005            return Err(SmoothError::invalid_config(format!(
4006                "SpatialLengthScaleOptimizationOptions::rel_tol must be > 0 and finite, got {}",
4007                self.rel_tol
4008            ))
4009            .into());
4010        }
4011        if !self.log_step.is_finite() || self.log_step <= 0.0 {
4012            return Err(SmoothError::invalid_config(format!(
4013                "SpatialLengthScaleOptimizationOptions::log_step must be > 0 and finite, got {}",
4014                self.log_step
4015            ))
4016            .into());
4017        }
4018        Ok(())
4019    }
4020}
4021
/// Per-observation group assignment for one named random-effect block.
#[derive(Debug, Clone)]
pub struct RandomEffectBlock {
    /// Name of the block (presumably the random-effect term's name — confirm
    /// against the code that constructs it).
    pub name: String,
    /// O(n) group-label vector: group_ids[i] = column index in [0, num_groups).
    /// `None` if the observation's level is not in the kept set.
    pub group_ids: Vec<Option<usize>>,
    /// Number of groups; the exclusive upper bound on the indices stored in
    /// `group_ids`.
    pub num_groups: usize,
    /// The kept level codes (presumably one per group, in column order —
    /// TODO confirm against the builder).
    pub kept_levels: Vec<u64>,
}
4031
/// Magnitude at or below which a stored entry is treated as zero during
/// sparse block assembly: `sparse_compatible_block_nnz` counts only dense
/// entries with `|value|` above it, and `try_build_sparse_design_from_blocks`
/// emits triplets only for sparse and dense entries above it.
pub const BLOCK_SPARSE_ZERO_EPS: f64 = 1e-12;
4033
/// Largest non-zero fraction of the assembled design for which
/// `try_build_sparse_design_from_blocks` still builds sparse storage when no
/// block is intrinsically sparse: the non-zero budget is
/// `floor(nrows * ncols * BLOCK_SPARSE_MAX_DENSITY)`. The budget is lifted
/// entirely when `blocks_have_intrinsic_sparse_structure` is true.
pub const BLOCK_SPARSE_MAX_DENSITY: f64 = 0.20;
4035
4036pub fn blocks_have_intrinsic_sparse_structure(blocks: &[DesignBlock]) -> bool {
4037    blocks
4038        .iter()
4039        .any(|block| matches!(block, DesignBlock::Sparse(_) | DesignBlock::RandomEffect(_)))
4040}
4041
4042pub fn sparse_compatible_block_nnz(block: &DesignBlock) -> Option<usize> {
4043    match block {
4044        DesignBlock::Intercept(n) => Some(*n),
4045        DesignBlock::RandomEffect(op) => {
4046            Some(op.group_ids.iter().filter(|gid| gid.is_some()).count())
4047        }
4048        DesignBlock::Sparse(sparse) => Some(sparse.val().len()),
4049        DesignBlock::Dense(dense) => dense.as_dense_ref().map(|matrix| {
4050            matrix
4051                .iter()
4052                .filter(|&&value| value.abs() > BLOCK_SPARSE_ZERO_EPS)
4053                .count()
4054        }),
4055    }
4056}
4057
4058pub fn try_build_sparse_design_from_blocks(
4059    blocks: &[DesignBlock],
4060) -> Result<Option<DesignMatrix>, BasisError> {
4061    if blocks.is_empty() {
4062        return Ok(None);
4063    }
4064    let nrows = blocks[0].nrows();
4065    let ncols: usize = blocks.iter().map(DesignBlock::ncols).sum();
4066    if nrows == 0 || ncols == 0 || ncols <= 32 {
4067        return Ok(None);
4068    }
4069
4070    let preserve_sparse_storage = blocks_have_intrinsic_sparse_structure(blocks);
4071    let sparse_nnz_limit = if preserve_sparse_storage {
4072        usize::MAX
4073    } else {
4074        let total_cells = nrows.saturating_mul(ncols);
4075        ((total_cells as f64) * BLOCK_SPARSE_MAX_DENSITY).floor() as usize
4076    };
4077    let mut nnz = 0usize;
4078    for block in blocks {
4079        let block_nnz = if let Some(block_nnz) = sparse_compatible_block_nnz(block) {
4080            block_nnz
4081        } else {
4082            return Ok(None);
4083        };
4084        nnz = nnz.saturating_add(block_nnz);
4085        if nnz > sparse_nnz_limit {
4086            return Ok(None);
4087        }
4088    }
4089
4090    let mut triplets = Vec::<Triplet<usize, usize, f64>>::with_capacity(nnz);
4091    let mut col_offset = 0usize;
4092    for block in blocks {
4093        match block {
4094            DesignBlock::Intercept(n) => {
4095                for row in 0..*n {
4096                    triplets.push(Triplet::new(row, col_offset, 1.0));
4097                }
4098            }
4099            DesignBlock::RandomEffect(op) => {
4100                for (row, group_id) in op.group_ids.iter().enumerate() {
4101                    if let Some(group) = group_id {
4102                        triplets.push(Triplet::new(row, col_offset + group, 1.0));
4103                    }
4104                }
4105            }
4106            DesignBlock::Sparse(sparse) => {
4107                let (symbolic, values) = sparse.parts();
4108                let col_ptr = symbolic.col_ptr();
4109                let row_idx = symbolic.row_idx();
4110                for col in 0..sparse.ncols() {
4111                    for idx in col_ptr[col]..col_ptr[col + 1] {
4112                        let value = values[idx];
4113                        if value.abs() > BLOCK_SPARSE_ZERO_EPS {
4114                            triplets.push(Triplet::new(row_idx[idx], col_offset + col, value));
4115                        }
4116                    }
4117                }
4118            }
4119            DesignBlock::Dense(dense) => {
4120                let matrix = dense.as_dense_ref().ok_or_else(|| {
4121                    BasisError::InvalidInput(
4122                        "sparse-compatible block assembly requires materialized dense blocks"
4123                            .to_string(),
4124                    )
4125                })?;
4126                for row in 0..matrix.nrows() {
4127                    for col in 0..matrix.ncols() {
4128                        let value = matrix[[row, col]];
4129                        if value.abs() > BLOCK_SPARSE_ZERO_EPS {
4130                            triplets.push(Triplet::new(row, col_offset + col, value));
4131                        }
4132                    }
4133                }
4134            }
4135        }
4136        col_offset += block.ncols();
4137    }
4138
4139    let sparse = SparseColMat::try_new_from_triplets(nrows, ncols, &triplets).map_err(|_| {
4140        BasisError::SparseCreation("failed to assemble sparse term-collection design".to_string())
4141    })?;
4142    Ok(Some(DesignMatrix::Sparse(
4143        gam_linalg::matrix::SparseDesignMatrix::new(sparse),
4144    )))
4145}
4146
4147pub fn assemble_term_collection_design_matrix(
4148    blocks: Vec<DesignBlock>,
4149) -> Result<DesignMatrix, BasisError> {
4150    if let Some(sparse) = try_build_sparse_design_from_blocks(&blocks)? {
4151        return Ok(sparse);
4152    }
4153    let block_op = BlockDesignOperator::new(blocks).map_err(|e| {
4154        BasisError::InvalidInput(format!("failed to build block design operator: {e}"))
4155    })?;
4156    Ok(DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
4157        Arc::new(block_op),
4158    )))
4159}
4160
4161pub fn select_columns(data: ArrayView2<'_, f64>, cols: &[usize]) -> Result<Array2<f64>, BasisError> {
4162    let n = data.nrows();
4163    let p = data.ncols();
4164    for &c in cols {
4165        if c >= p {
4166            crate::bail_dim_basis!("feature column {c} is out of bounds for data with {p} columns");
4167        }
4168    }
4169    let mut out = Array2::<f64>::zeros((n, cols.len()));
4170    for (j, &c) in cols.iter().enumerate() {
4171        out.column_mut(j).assign(&data.column(c));
4172    }
4173    Ok(out)
4174}
4175
/// Short label for a non-finite float, for use in error messages.
///
/// NaN (of either sign) maps to `"NaN"`. Any other value is labelled purely
/// by its sign bit: `"+Inf"` when positive, `"-Inf"` when negative. The
/// function does not check that the value is actually infinite, so it is
/// only meaningful for values already known to be non-finite.
pub fn nonfinite_value_label(value: f64) -> &'static str {
    match value {
        v if v.is_nan() => "NaN",
        v if v.is_sign_positive() => "+Inf",
        _ => "-Inf",
    }
}
4185
4186pub fn validate_term_feature_column_finite(
4187    data: ArrayView2<'_, f64>,
4188    term_kind: &str,
4189    term_name: &str,
4190    feature_col: usize,
4191) -> Result<(), BasisError> {
4192    let p = data.ncols();
4193    if feature_col >= p {
4194        crate::bail_dim_basis!(
4195            "{term_kind} term '{term_name}' feature column {feature_col} out of bounds for {p} columns"
4196        );
4197    }
4198    for (row, &value) in data.column(feature_col).iter().enumerate() {
4199        if !value.is_finite() {
4200            crate::bail_invalid_basis!(
4201                "{term_kind} term '{term_name}' feature column {feature_col} row {row} contains non-finite value {}",
4202                nonfinite_value_label(value)
4203            );
4204        }
4205    }
4206    Ok(())
4207}
4208
4209pub fn validate_smooth_terms_finite_inputs(
4210    data: ArrayView2<'_, f64>,
4211    terms: &[SmoothTermSpec],
4212) -> Result<(), BasisError> {
4213    for term in terms {
4214        for feature_col in smooth_term_feature_cols(term) {
4215            validate_term_feature_column_finite(data, "smooth", &term.name, feature_col)?;
4216        }
4217    }
4218    Ok(())
4219}
4220
4221pub fn validate_term_collection_finite_inputs(
4222    data: ArrayView2<'_, f64>,
4223    spec: &TermCollectionSpec,
4224) -> Result<(), BasisError> {
4225    for term in &spec.linear_terms {
4226        validate_term_feature_column_finite(data, "linear", &term.name, term.feature_col)?;
4227    }
4228    for term in &spec.random_effect_terms {
4229        validate_term_feature_column_finite(data, "random-effect", &term.name, term.feature_col)?;
4230    }
4231    validate_smooth_terms_finite_inputs(data, &spec.smooth_terms)
4232}
4233
/// Grouping key for spatial smooths that may share one set of centers.
///
/// Two terms land in the same group only when every field compares equal.
/// The type derives `Ord` because it is used as a `BTreeMap` key during joint
/// center planning; the derived ordering follows the field declaration order
/// below, so reordering fields changes the group iteration order.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct JointSpatialCenterGroupKey {
    /// Data columns the smooth is built on.
    feature_cols: Vec<usize>,
    /// Kind tag of the term's center strategy.
    strategy_kind: CenterStrategyKind,
    /// Extra strategy parameter: `max_iter` for k-means, `points_per_dim` for
    /// a uniform grid, `0` for every other strategy.
    strategy_aux: usize,
    /// Center count requested by the term's strategy.
    requested_num_centers: usize,
    /// Bit patterns (`f64::to_bits`) of the term's explicit input scales, or
    /// `None` when the term carries none. Bits are stored because `f64`
    /// implements neither `Eq` nor `Ord`.
    input_scale_bits: Option<Vec<u64>>,
}
4242
4243pub fn spatial_term_min_center_count(term: &SmoothTermSpec) -> usize {
4244    match &term.basis {
4245        SmoothBasisSpec::ThinPlate { feature_cols, .. } => feature_cols.len() + 1,
4246        SmoothBasisSpec::Duchon {
4247            feature_cols, spec, ..
4248        } => match spec.nullspace_order {
4249            crate::basis::DuchonNullspaceOrder::Zero => 1,
4250            crate::basis::DuchonNullspaceOrder::Linear => feature_cols.len() + 1,
4251            crate::basis::DuchonNullspaceOrder::Degree(degree) => {
4252                crate::basis::duchon_nullspace_dimension(feature_cols.len(), degree)
4253            }
4254        },
4255        SmoothBasisSpec::Matern { .. } => 1,
4256        _ => 1,
4257    }
4258}
4259
4260pub fn spatial_term_group_key(term: &SmoothTermSpec) -> Option<JointSpatialCenterGroupKey> {
4261    let (feature_cols, strategy, input_scales) = match &term.basis {
4262        SmoothBasisSpec::ThinPlate {
4263            feature_cols,
4264            spec,
4265            input_scales,
4266        } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4267        SmoothBasisSpec::Matern {
4268            feature_cols,
4269            spec,
4270            input_scales,
4271        } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4272        SmoothBasisSpec::Duchon {
4273            feature_cols,
4274            spec,
4275            input_scales,
4276        } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4277        _ => return None,
4278    };
4279    let strategy_kind = center_strategy_kind(strategy);
4280    let strategy_aux = match strategy {
4281        CenterStrategy::Auto(inner) => match inner.as_ref() {
4282            CenterStrategy::KMeans { max_iter, .. } => *max_iter,
4283            CenterStrategy::UniformGrid { points_per_dim } => *points_per_dim,
4284            _ => 0,
4285        },
4286        CenterStrategy::KMeans { max_iter, .. } => *max_iter,
4287        CenterStrategy::UniformGrid { points_per_dim } => *points_per_dim,
4288        _ => 0,
4289    };
4290    Some(JointSpatialCenterGroupKey {
4291        feature_cols: feature_cols.clone(),
4292        strategy_kind,
4293        strategy_aux,
4294        requested_num_centers: center_strategy_num_centers(strategy)?,
4295        input_scale_bits: input_scales
4296            .map(|values| values.iter().map(|value| value.to_bits()).collect()),
4297    })
4298}
4299
4300pub fn spatial_term_center_strategy(term: &SmoothTermSpec) -> Option<&CenterStrategy> {
4301    match &term.basis {
4302        SmoothBasisSpec::ThinPlate { spec, .. } => Some(&spec.center_strategy),
4303        SmoothBasisSpec::Matern { spec, .. } => Some(&spec.center_strategy),
4304        SmoothBasisSpec::Duchon { spec, .. } => Some(&spec.center_strategy),
4305        _ => None,
4306    }
4307}
4308
4309pub fn set_spatial_term_centers(
4310    term: &mut SmoothTermSpec,
4311    centers: Array2<f64>,
4312) -> Result<(), BasisError> {
4313    match &mut term.basis {
4314        SmoothBasisSpec::ThinPlate { spec, .. } => {
4315            spec.center_strategy = CenterStrategy::UserProvided(centers);
4316            Ok(())
4317        }
4318        SmoothBasisSpec::Matern { spec, .. } => {
4319            spec.center_strategy = CenterStrategy::UserProvided(centers);
4320            Ok(())
4321        }
4322        SmoothBasisSpec::Duchon { spec, .. } => {
4323            spec.center_strategy = CenterStrategy::UserProvided(centers);
4324            Ok(())
4325        }
4326        _ => Err(BasisError::InvalidInput(format!(
4327            "term '{}' does not support spatial center planning",
4328            term.name
4329        ))),
4330    }
4331}
4332
4333pub fn standardized_spatial_term_data(
4334    data: ArrayView2<'_, f64>,
4335    term: &SmoothTermSpec,
4336) -> Result<Array2<f64>, BasisError> {
4337    let (feature_cols, input_scales) = match &term.basis {
4338        SmoothBasisSpec::ThinPlate {
4339            feature_cols,
4340            input_scales,
4341            ..
4342        }
4343        | SmoothBasisSpec::Matern {
4344            feature_cols,
4345            input_scales,
4346            ..
4347        }
4348        | SmoothBasisSpec::Duchon {
4349            feature_cols,
4350            input_scales,
4351            ..
4352        } => (feature_cols, input_scales.as_ref()),
4353        _ => {
4354            crate::bail_invalid_basis!("term '{}' is not a spatial smooth", term.name);
4355        }
4356    };
4357    let mut x = select_columns(data, feature_cols)?;
4358    if let Some(scales) = input_scales {
4359        apply_input_standardization(&mut x, scales);
4360    } else if let Some(scales) = compute_spatial_input_scales(x.view()) {
4361        apply_input_standardization(&mut x, &scales);
4362    }
4363    Ok(x)
4364}
4365
4366pub fn plan_joint_spatial_centers_for_term_blocks(
4367    data: ArrayView2<'_, f64>,
4368    term_blocks: &[Vec<SmoothTermSpec>],
4369) -> Result<Vec<Vec<SmoothTermSpec>>, BasisError> {
4370    let mut planned_blocks = term_blocks.to_vec();
4371    let n = data.nrows();
4372    let mut groups: BTreeMap<JointSpatialCenterGroupKey, Vec<(usize, usize)>> = BTreeMap::new();
4373
4374    for (block_idx, terms) in planned_blocks.iter().enumerate() {
4375        for (term_idx, term) in terms.iter().enumerate() {
4376            let Some(strategy) = spatial_term_center_strategy(term) else {
4377                continue;
4378            };
4379            if !center_strategy_is_auto(strategy) {
4380                continue;
4381            }
4382            let Some(group_key) = spatial_term_group_key(term) else {
4383                continue;
4384            };
4385            if !matches!(
4386                group_key.strategy_kind,
4387                CenterStrategyKind::EqualMass
4388                    | CenterStrategyKind::EqualMassCovarRepresentative
4389                    | CenterStrategyKind::FarthestPoint
4390                    | CenterStrategyKind::KMeans
4391            ) {
4392                continue;
4393            }
4394            if center_strategy_num_centers(strategy).is_none() {
4395                continue;
4396            }
4397            groups
4398                .entry(group_key)
4399                .or_default()
4400                .push((block_idx, term_idx));
4401        }
4402    }
4403
4404    for (group_key, members) in groups {
4405        if members.len() < 2 {
4406            continue;
4407        }
4408        let min_required = members
4409            .iter()
4410            .map(|&(block_idx, term_idx)| {
4411                spatial_term_min_center_count(&planned_blocks[block_idx][term_idx])
4412            })
4413            .max()
4414            .unwrap_or(1);
4415        let joint_centers = group_key
4416            .requested_num_centers
4417            .max(min_required)
4418            .min(n.max(1));
4419        let (first_block_idx, first_term_idx) = members[0];
4420        let prototype = &planned_blocks[first_block_idx][first_term_idx];
4421        let standardized = standardized_spatial_term_data(data, prototype)?;
4422        let strategy = spatial_term_center_strategy(prototype).ok_or_else(|| {
4423            BasisError::InvalidInput(format!(
4424                "term '{}' lost its spatial center strategy during joint planning",
4425                prototype.name
4426            ))
4427        })?;
4428        let joint_strategy = center_strategy_with_num_centers(strategy, joint_centers)?;
4429        let shared_centers = select_centers_by_strategy(standardized.view(), &joint_strategy)?;
4430        log::info!(
4431            "sharing {} spatial centers across {} smooth terms over columns {:?} (requested {} centers)",
4432            shared_centers.nrows(),
4433            members.len(),
4434            group_key.feature_cols,
4435            group_key.requested_num_centers,
4436        );
4437        for (block_idx, term_idx) in members {
4438            set_spatial_term_centers(
4439                &mut planned_blocks[block_idx][term_idx],
4440                shared_centers.clone(),
4441            )?;
4442        }
4443    }
4444
4445    // Sentinel auto-init: Matern and thin-plate builders write length_scale =
4446    // 0.0 when the user didn't pass `length_scale=...`. Replace those with a
4447    // data-driven initialization here so REML starts in a regime where it can
4448    // escape; the hard-coded 1.0 default was a basin from which ν ≥ 5/2 Matern
4449    // could not recover on high-frequency truths, silently collapsing the fit
4450    // to a near-constant prediction.
4451    for block in planned_blocks.iter_mut() {
4452        for term in block.iter_mut() {
4453            auto_init_length_scale_in_place(data, term);
4454        }
4455    }
4456
4457    Ok(planned_blocks)
4458}
4459
/// Tiny positive floor for the auto length-scale seeds, guarding against a
/// vanishing kernel range when the feature columns are nearly constant.
///
/// The seed helpers apply it as `seed.max(FLOOR).min(max_range)`, so the cap
/// at `max_range` is applied last and takes precedence when the data range is
/// itself smaller than this floor.
const AUTO_LENGTH_SCALE_FLOOR: f64 = 1e-6;
4463
4464/// Widest per-axis range of the selected feature columns. Returns `None` when
4465/// every selected column is constant / non-finite (no usable spatial scale).
4466fn feature_columns_max_range(data: ArrayView2<'_, f64>, feature_cols: &[usize]) -> Option<f64> {
4467    let mut max_range = 0.0_f64;
4468    for &c in feature_cols {
4469        if c >= data.ncols() {
4470            continue;
4471        }
4472        let col = data.column(c);
4473        let mut lo = f64::INFINITY;
4474        let mut hi = f64::NEG_INFINITY;
4475        for &v in col.iter() {
4476            if v.is_finite() {
4477                if v < lo {
4478                    lo = v;
4479                }
4480                if v > hi {
4481                    hi = v;
4482                }
4483            }
4484        }
4485        if hi > lo {
4486            let r = hi - lo;
4487            if r > max_range {
4488                max_range = r;
4489            }
4490        }
4491    }
4492    if max_range.is_finite() && max_range > 0.0 {
4493        Some(max_range)
4494    } else {
4495        None
4496    }
4497}
4498
4499/// Compute a data-driven initial length scale from the per-axis range of the
4500/// feature columns. The heuristic `max_range / sqrt(n)` puts the kernel on
4501/// the wiggly side of REML's basin so the optimizer can grow it back if the
4502/// signal is smooth, but is small enough that high-frequency truths remain
4503/// reachable for smoother kernels (ν ≥ 5/2). Clamped to a tiny positive
4504/// floor so degenerate constant-input columns can't produce 0.
4505pub fn auto_initial_length_scale(data: ArrayView2<'_, f64>, feature_cols: &[usize]) -> f64 {
4506    let n = data.nrows();
4507    if n == 0 || feature_cols.is_empty() {
4508        return 1.0;
4509    }
4510    let Some(max_range) = feature_columns_max_range(data, feature_cols) else {
4511        return 1.0;
4512    };
4513    let init = max_range / (n as f64).sqrt();
4514    init.max(AUTO_LENGTH_SCALE_FLOOR).min(max_range)
4515}
4516
4517/// Density-adaptive auto length scale for a kernel basis with `num_centers`
4518/// requested centers (#1731).
4519///
4520/// The plain [`auto_initial_length_scale`] seed `max_range / sqrt(n)` is the
4521/// fill distance of the *n data points*; it is independent of the requested
4522/// center count `k`. For a radial kernel at a FIXED length scale, packing more
4523/// centers into the same cloud makes neighbouring basis functions overlap and
4524/// go numerically collinear, so the realized basis saturates in rank (the
4525/// `matern_rank_reduce_centers` cap) and a richer `k` becomes a no-op — or even
4526/// shrinks the basis. The kernel stays well-conditioned only while the length
4527/// scale tracks the *center* spacing, not the data spacing.
4528///
4529/// We seed the length scale at the fill distance of `max(n, k)` points,
4530/// `max_range / sqrt(max(n, k))`. When `n ≥ k` (the usual case) this is exactly
4531/// the existing `max_range / sqrt(n)` seed, so every current result and small-`k`
4532/// basis size is preserved bit-for-bit (in every covariate dimension). When
4533/// `k > n` (a dense center request on a small cloud, the regime where an
4534/// `n`-sized seed sits above the center spacing and over-smooths the centers
4535/// into collinearity) the seed shrinks with `k` to the center spacing, keeping
4536/// the requested centers numerically independent. This is the Matérn analogue of
4537/// the Duchon-promotion "length_scale from center spacing" rule
4538/// (`hybrid_duchon_promotion_length_scale`).
4539pub fn auto_initial_length_scale_for_centers(
4540    data: ArrayView2<'_, f64>,
4541    feature_cols: &[usize],
4542    num_centers: usize,
4543) -> f64 {
4544    let n = data.nrows();
4545    if n == 0 || feature_cols.is_empty() {
4546        return 1.0;
4547    }
4548    let Some(max_range) = feature_columns_max_range(data, feature_cols) else {
4549        return 1.0;
4550    };
4551    // Resolution density: at least the data points, but no coarser than the
4552    // center spacing once more centers than data are requested. Using the same
4553    // `sqrt` fill-distance law as `auto_initial_length_scale` keeps the seed
4554    // bit-identical whenever `n ≥ num_centers` (every dimension), and only
4555    // shrinks it — never grows it — when `num_centers > n`.
4556    let resolution_points = n.max(num_centers).max(1) as f64;
4557    let spacing = max_range / resolution_points.sqrt();
4558    spacing.max(AUTO_LENGTH_SCALE_FLOOR).min(max_range)
4559}
4560
4561/// Low-rank radial-basis length-scale seed tied to the requested center spacing.
4562///
4563/// Thin-plate regression splines with `k << n` represent the surface through a
4564/// compact set of centers; seeding the kernel at the observation fill distance
4565/// (`max_range / sqrt(n)`) makes the center Gram nearly diagonal and turns the
4566/// bending penalty into an ill-scaled ridge on the radial coefficients. REML then
4567/// sees a weakly identified smoothing surface and can settle on under-recovered
4568/// spatial fits. Seed at the center fill distance instead, so neighbouring
4569/// centers interact at O(1) scale before REML tunes the smoothing parameter.
4570pub fn auto_initial_length_scale_for_low_rank_centers(
4571    data: ArrayView2<'_, f64>,
4572    feature_cols: &[usize],
4573    num_centers: usize,
4574) -> f64 {
4575    if data.nrows() == 0 || feature_cols.is_empty() {
4576        return 1.0;
4577    }
4578    let Some(max_range) = feature_columns_max_range(data, feature_cols) else {
4579        return 1.0;
4580    };
4581    let resolution_points = num_centers.max(1) as f64;
4582    let spacing = max_range / resolution_points.sqrt();
4583    spacing.max(AUTO_LENGTH_SCALE_FLOOR).min(max_range)
4584}
4585
4586/// Requested center count encoded by a [`CenterStrategy`], if it carries an
4587/// explicit count (used to make the Matérn auto length scale density-adaptive).
4588fn center_strategy_requested_count(strategy: &CenterStrategy) -> Option<usize> {
4589    match strategy {
4590        CenterStrategy::Auto(inner) => center_strategy_requested_count(inner),
4591        CenterStrategy::UserProvided(centers) => Some(centers.nrows()),
4592        CenterStrategy::EqualMass { num_centers }
4593        | CenterStrategy::EqualMassCovarRepresentative { num_centers }
4594        | CenterStrategy::FarthestPoint { num_centers }
4595        | CenterStrategy::KMeans { num_centers, .. } => Some(*num_centers),
4596        CenterStrategy::UniformGrid { .. } => None,
4597    }
4598}
4599
4600/// Walk a term and, if it is a Matern or thin-plate smooth whose length_scale
4601/// was left at the auto sentinel (`0.0`), overwrite it with
4602/// [`auto_initial_length_scale`].
4603pub fn auto_init_length_scale_in_place(data: ArrayView2<'_, f64>, term: &mut SmoothTermSpec) {
4604    auto_init_length_scale_in_basis(data, &mut term.basis);
4605}
4606
4607/// Replace the `0.0` auto-init length-scale sentinel with a data-derived value
4608/// for any Matern / thin-plate kernel reachable from this basis — including the
4609/// inner kernel of a `by=`/factor-smooth wrapper.
4610///
4611/// `by=<factor>` and the sum-to-zero factor smooth wrap a spatial kernel inside
4612/// `SmoothBasisSpec::ByVariable` / `SmoothBasisSpec::FactorSumToZero` /
4613/// `SmoothBasisSpec::BySmooth`, so the wrapper variant is what the planner sees.
4614/// Without recursing into the wrapped basis the inner ThinPlate/Matern keeps the
4615/// `0.0` sentinel (the post-`1605b3a6e` builder default), which makes the kernel
4616/// distance divide by `length_scale² = 0`, producing a non-finite design at both
4617/// fit and predict time. Recurse so the inner kernel is initialized identically
4618/// to a top-level one.
4619pub fn auto_init_length_scale_in_basis(data: ArrayView2<'_, f64>, basis: &mut SmoothBasisSpec) {
4620    match basis {
4621        SmoothBasisSpec::Matern {
4622            feature_cols, spec, ..
4623        } => {
4624            if spec.length_scale == 0.0 {
4625                // Density-adaptive seed (#1731): when the requested center count
4626                // is known, scale the auto length scale with the *center*
4627                // spacing so a richer `k` stays numerically full-rank instead of
4628                // saturating against `matern_rank_reduce_centers`. For `n ≥ k`
4629                // (the usual case) this is identical to the plain `max_range /
4630                // sqrt(n)` seed in 2-D, so small-`k` results are unchanged. The
4631                // unconstrained / non-explicit `UniformGrid` strategy falls back
4632                // to the plain seed.
4633                spec.length_scale = match center_strategy_requested_count(&spec.center_strategy) {
4634                    Some(k) => auto_initial_length_scale_for_centers(data, feature_cols, k),
4635                    None => auto_initial_length_scale(data, feature_cols),
4636                };
4637            }
4638        }
4639        SmoothBasisSpec::ThinPlate {
4640            feature_cols, spec, ..
4641        } => {
4642            if spec.length_scale == 0.0 {
4643                spec.length_scale = match center_strategy_requested_count(&spec.center_strategy) {
4644                    Some(k) => {
4645                        auto_initial_length_scale_for_low_rank_centers(data, feature_cols, k)
4646                    }
4647                    None => auto_initial_length_scale(data, feature_cols),
4648                };
4649            }
4650        }
4651        SmoothBasisSpec::ByVariable { inner, .. }
4652        | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
4653            auto_init_length_scale_in_basis(data, inner);
4654        }
4655        SmoothBasisSpec::BySmooth { smooth, .. } => {
4656            auto_init_length_scale_in_basis(data, smooth);
4657        }
4658        _ => {}
4659    }
4660}
4661
4662impl LinearFitConditioning {
4663    pub fn from_columns(design: &TermCollectionDesign, selected_cols: &[usize]) -> Self {
4664        const SCALE_EPS: f64 = 1e-12;
4665        let n = design.design.nrows();
4666        let p = design.design.ncols();
4667        let mut columns = Vec::with_capacity(selected_cols.len());
4668        if n == 0 || selected_cols.is_empty() {
4669            return Self {
4670                intercept_idx: design.intercept_range.start,
4671                columns,
4672            };
4673        }
4674        let chunk_rows = gam_linalg::utils::row_chunk_for_byte_budget(n, p);
4675        // Two-pass mean/variance so operator-backed designs don't need to
4676        // materialize the full dense matrix. Pass 1 accumulates per-column
4677        // sums; pass 2 accumulates the sum of squared deviations from the
4678        // pass-1 mean. This matches the original `Σ (x − mean)² / n` formula
4679        // without the catastrophic cancellation of `E[X²] − E[X]²`.
4680        let mut sums = vec![0.0_f64; selected_cols.len()];
4681        for start in (0..n).step_by(chunk_rows) {
4682            let end = (start + chunk_rows).min(n);
4683            let chunk = design
4684                .design
4685                .try_row_chunk(start..end)
4686                .expect("LinearFitConditioning::from_columns row chunk failed");
4687            for (k, &col_idx) in selected_cols.iter().enumerate() {
4688                let column = chunk.column(col_idx);
4689                for &v in column.iter() {
4690                    sums[k] += v;
4691                }
4692            }
4693        }
4694        let inv_n = 1.0_f64 / n as f64;
4695        let means: Vec<f64> = sums.iter().map(|&s| s * inv_n).collect();
4696        let mut sq_devs = vec![0.0_f64; selected_cols.len()];
4697        for start in (0..n).step_by(chunk_rows) {
4698            let end = (start + chunk_rows).min(n);
4699            let chunk = design
4700                .design
4701                .try_row_chunk(start..end)
4702                .expect("LinearFitConditioning::from_columns row chunk failed");
4703            for (k, &col_idx) in selected_cols.iter().enumerate() {
4704                let mean_k = means[k];
4705                let column = chunk.column(col_idx);
4706                for &v in column.iter() {
4707                    let d = v - mean_k;
4708                    sq_devs[k] += d * d;
4709                }
4710            }
4711        }
4712        for (k, &col_idx) in selected_cols.iter().enumerate() {
4713            let mean = means[k];
4714            let var = sq_devs[k] * inv_n;
4715            let (mean, scale) = if var.is_finite() && var > SCALE_EPS * SCALE_EPS {
4716                (mean, var.sqrt())
4717            } else {
4718                // Leave nearly-constant columns untouched; centering them would collapse
4719                // the design column to ~0 and change the model rather than just condition it.
4720                (0.0, 1.0)
4721            };
4722            columns.push(LinearColumnConditioning {
4723                col_idx,
4724                mean,
4725                scale,
4726            });
4727        }
4728        Self {
4729            intercept_idx: design.intercept_range.start,
4730            columns,
4731        }
4732    }
4733
4734    pub fn apply_to_design(&self, design: &Array2<f64>) -> Array2<f64> {
4735        let mut out = design.clone();
4736        for col in &self.columns {
4737            {
4738                let mut dst = out.column_mut(col.col_idx);
4739                dst -= col.mean;
4740            }
4741            if col.scale != 1.0 {
4742                out.column_mut(col.col_idx).mapv_inplace(|v| v / col.scale);
4743            }
4744        }
4745        out
4746    }
4747
4748    fn transform_matrix_columnswith_a(&self, mat: &Array2<f64>) -> Array2<f64> {
4749        let mut out = mat.clone();
4750        let intercept = self.intercept_idx;
4751        for col in &self.columns {
4752            let intercept_col = out.column(intercept).to_owned();
4753            let mut target = out.column_mut(col.col_idx);
4754            target -= &(intercept_col * col.mean);
4755            if col.scale != 1.0 {
4756                target.mapv_inplace(|v| v / col.scale);
4757            }
4758        }
4759        out
4760    }
4761
4762    fn transform_matrixrowswith_a_transpose(&self, mat: &Array2<f64>) -> Array2<f64> {
4763        let mut out = mat.clone();
4764        let intercept = self.intercept_idx;
4765        for col in &self.columns {
4766            let interceptrow = out.row(intercept).to_owned();
4767            let mut target = out.row_mut(col.col_idx);
4768            target -= &(interceptrow * col.mean);
4769            if col.scale != 1.0 {
4770                target.mapv_inplace(|v| v / col.scale);
4771            }
4772        }
4773        out
4774    }
4775
4776    /// Left-multiply `mat_internal` by `M⁻ᵀ` where `M⁻¹[intercept, j] = mean_j`
4777    /// and `M⁻¹[j, j] = scale_j` for each conditioned column. Used together
4778    /// with [`Self::right_multiply_by_m_inv`] to back-transform an internal
4779    /// penalized Hessian to the original coefficient basis.
4780    fn left_multiply_by_m_inv_transpose(&self, mat_internal: &Array2<f64>) -> Array2<f64> {
4781        let mut out = mat_internal.clone();
4782        let intercept = self.intercept_idx;
4783        let interceptrow_snapshot = mat_internal.row(intercept).to_owned();
4784        for col in &self.columns {
4785            if col.scale != 1.0 {
4786                out.row_mut(col.col_idx).mapv_inplace(|v| v * col.scale);
4787            }
4788            if col.mean != 0.0 {
4789                let mut target = out.row_mut(col.col_idx);
4790                target += &(&interceptrow_snapshot * col.mean);
4791            }
4792        }
4793        out
4794    }
4795
4796    /// Right-multiply `mat_internal` by `M⁻¹`. Mirror of
4797    /// [`Self::left_multiply_by_m_inv_transpose`] on columns.
4798    fn right_multiply_by_m_inv(&self, mat_internal: &Array2<f64>) -> Array2<f64> {
4799        let mut out = mat_internal.clone();
4800        let intercept = self.intercept_idx;
4801        let intercept_col_snapshot = mat_internal.column(intercept).to_owned();
4802        for col in &self.columns {
4803            if col.scale != 1.0 {
4804                out.column_mut(col.col_idx).mapv_inplace(|v| v * col.scale);
4805            }
4806            if col.mean != 0.0 {
4807                let mut target = out.column_mut(col.col_idx);
4808                target += &(&intercept_col_snapshot * col.mean);
4809            }
4810        }
4811        out
4812    }
4813
4814    /// Transform blockwise penalties through the conditioning.
4815    ///
4816    /// For block-local penalties whose `col_range` does not overlap with any
4817    /// conditioning column, the transform is identity (the conditioning only
4818    /// affects unpenalized linear columns). In that common case the penalty
4819    /// passes through unchanged, avoiding O(p²) materialization entirely.
4820    pub fn transform_blockwise_penalties_to_internal(
4821        &self,
4822        penalties: &[BlockwisePenalty],
4823        p: usize,
4824    ) -> Vec<crate::penalty_spec::PenaltySpec> {
4825        let conditioning_cols: std::collections::HashSet<usize> =
4826            self.columns.iter().map(|c| c.col_idx).collect();
4827        penalties
4828            .iter()
4829            .map(|bp| {
4830                let overlaps =
4831                    (bp.col_range.start..bp.col_range.end).any(|j| conditioning_cols.contains(&j));
4832                if overlaps {
4833                    // Rare: penalty block overlaps conditioning columns.
4834                    // Fall back to dense transform.
4835                    let global = bp.to_global(p);
4836                    let right = self.transform_matrix_columnswith_a(&global);
4837                    let transformed = self.transform_matrixrowswith_a_transpose(&right);
4838                    crate::penalty_spec::PenaltySpec::Dense(transformed)
4839                } else {
4840                    // Common: smooth penalty block doesn't touch linear columns.
4841                    // The conditioning is identity on this block.
4842                    crate::penalty_spec::PenaltySpec::from_blockwise(bp.clone())
4843                }
4844            })
4845            .collect()
4846    }
4847
4848    pub fn backtransform_beta(&self, beta_internal: &Array1<f64>) -> Array1<f64> {
4849        let mut beta = beta_internal.clone();
4850        let intercept = self.intercept_idx;
4851        for col in &self.columns {
4852            beta[intercept] -= beta_internal[col.col_idx] * col.mean / col.scale;
4853            beta[col.col_idx] = beta_internal[col.col_idx] / col.scale;
4854        }
4855        beta
4856    }
4857
4858    /// `H_orig = M⁻ᵀ · H_int · M⁻¹`, derived from
4859    /// `L_int(β_int) = L_orig(M · β_int)` via the chain rule.
4860    pub fn transform_penalized_hessian_to_original(&self, h_internal: &Array2<f64>) -> Array2<f64> {
4861        let right = self.right_multiply_by_m_inv(h_internal);
4862        self.left_multiply_by_m_inv_transpose(&right)
4863    }
4864
4865    pub fn internal_bounds_for(&self, col_idx: usize, min: f64, max: f64) -> (f64, f64) {
4866        if let Some(col) = self.columns.iter().find(|c| c.col_idx == col_idx) {
4867            (min * col.scale, max * col.scale)
4868        } else {
4869            (min, max)
4870        }
4871    }
4872}
4873
4874pub fn freeze_raw_spatial_metadata(metadata: BasisMetadata, raw_cols: usize) -> BasisMetadata {
4875    match metadata {
4876        BasisMetadata::ThinPlate {
4877            centers,
4878            length_scale,
4879            periodic,
4880            identifiability_transform: None,
4881            input_scales,
4882            radial_reparam,
4883        } => BasisMetadata::ThinPlate {
4884            centers,
4885            length_scale,
4886            periodic,
4887            identifiability_transform: Some(Array2::eye(raw_cols)),
4888            input_scales,
4889            radial_reparam,
4890        },
4891        BasisMetadata::Duchon {
4892            centers,
4893            length_scale,
4894            periodic,
4895            power,
4896            nullspace_order,
4897            identifiability_transform: None,
4898            input_scales,
4899            aniso_log_scales,
4900            operator_collocation_points,
4901            radial_reparam,
4902        } => BasisMetadata::Duchon {
4903            centers,
4904            length_scale,
4905            periodic,
4906            power,
4907            nullspace_order,
4908            identifiability_transform: Some(Array2::eye(raw_cols)),
4909            input_scales,
4910            aniso_log_scales,
4911            operator_collocation_points,
4912            radial_reparam,
4913        },
4914        other => other,
4915    }
4916}
4917
4918pub fn matern_operator_penalty_triplet_from_metadata(
4919    metadata: &BasisMetadata,
4920) -> Result<(Vec<Array2<f64>>, Vec<usize>, Vec<PenaltyInfo>), BasisError> {
4921    let BasisMetadata::Matern {
4922        centers,
4923        length_scale,
4924        periodic,
4925        nu,
4926        include_intercept,
4927        identifiability_transform,
4928        aniso_log_scales,
4929        input_scales,
4930        ..
4931    } = metadata
4932    else {
4933        crate::bail_invalid_basis!("Matérn operator penalties require Matérn metadata");
4934    };
4935    // The metadata records `length_scale` in *original* (un-standardized) data
4936    // coordinates, while `centers` live in the *standardized* coordinate frame
4937    // (per-axis division by `input_scales`). The realized design built the
4938    // kernel against those standardized centers using the σ_geom-compensated
4939    // effective length scale `length_scale / σ_geom`. The collocation operators
4940    // here are evaluated on the same standardized centers, so they must use the
4941    // SAME effective length scale — otherwise the penalty regularizes a
4942    // different RKHS range than the design lives in, leaving rough coefficient
4943    // directions effectively unpenalized. That mismatch is benign in 1-D
4944    // (no standardization) but produces a catastrophic out-of-sample blow-up in
4945    // d ≥ 2 where σ_geom ≠ 1 (#706).
4946    let penalty_length_scale = match input_scales.as_deref() {
4947        Some(scales) => compensate_length_scale_for_standardization(*length_scale, scales),
4948        None => *length_scale,
4949    };
4950    matern_operator_penalty_triplet_at_length_scale(
4951        centers.view(),
4952        periodic.as_deref(),
4953        identifiability_transform.as_ref(),
4954        *nu,
4955        *include_intercept,
4956        aniso_log_scales.as_deref(),
4957        penalty_length_scale,
4958    )
4959}
4960
4961/// Build the canonical Matérn operator-penalty triplet (mass / tension /
4962/// stiffness) at an explicit **effective** length scale — i.e. the
4963/// σ_geom-compensated, standardized-frame scale the design's kernel was built
4964/// against (NOT the original-coordinate `length_scale` stored in metadata).
4965///
4966/// This is the SINGLE source of truth for the Matérn penalty topology. Two
4967/// callers route through it and must therefore stay byte-for-byte consistent:
4968///   * the cold/slow design rebuild (`matern_operator_penalty_triplet_from_metadata`,
4969///     compensating the frozen metadata `length_scale`), and
4970///   * the n-free κ-optimizer re-key (`FrozenTermCollectionIncrementalRealizer::
4971///     canonical_penalties_at_psi`, compensating the trial `ψ → exp(-ψ)` scale).
4972///
4973/// Sharing the body makes the penalty BLOCK COUNT and the per-block numerics
4974/// one deterministic function of `(geometry, ν, η, ℓ_eff)`. The active-operator
4975/// gate is `m = ν + d/2`, which is independent of ℓ, so the block count is
4976/// **ψ-stable by construction**: the re-key can never produce a different number
4977/// of blocks than the frozen design (the desync that #1270 hard-errored on).
4978pub fn matern_operator_penalty_triplet_at_length_scale(
4979    centers: ArrayView2<'_, f64>,
4980    periodic: Option<&[Option<f64>]>,
4981    identifiability_transform: Option<&Array2<f64>>,
4982    nu: crate::basis::MaternNu,
4983    include_intercept: bool,
4984    aniso_log_scales: Option<&[f64]>,
4985    effective_length_scale: f64,
4986) -> Result<(Vec<Array2<f64>>, Vec<usize>, Vec<PenaltyInfo>), BasisError> {
4987    let penalty_centers = crate::basis::expand_periodic_centers(&centers.to_owned(), periodic)?;
4988    let ops = build_matern_collocation_operator_matrices(
4989        penalty_centers.view(),
4990        None,
4991        effective_length_scale,
4992        nu,
4993        include_intercept,
4994        identifiability_transform.map(|z| z.view()),
4995        aniso_log_scales,
4996    )?;
4997    // Gate the operator dials on the Matérn-ν RKHS Sobolev order m = ν + d/2:
4998    // mass (j=0) is always on, tension (j=1) is on for m > 1, stiffness (j=2)
4999    // is on for m > 2. The threshold is strict so the roughest kernel ν=1/2 in
5000    // d=1 (m=1, the exponential/OU H¹ process) sheds both higher operators —
5001    // its kernel already encodes the H¹ control, so adding an extra tension
5002    // dial over-smooths the oscillation it is meant to track (#707). The
5003    // matching gate lives at `DuchonOperatorPenaltySpec::matern_for_smoothness`.
5004    const ORDER_EPS: f64 = 1e-9;
5005    let d = penalty_centers.ncols();
5006    let m = nu.half_integer_value() + 0.5 * d as f64;
5007    let mut candidates = Vec::with_capacity(3);
5008    for (raw, source, min_order) in [
5009        (ops.d0.t().dot(&ops.d0), PenaltySource::OperatorMass, 0.0),
5010        (ops.d1.t().dot(&ops.d1), PenaltySource::OperatorTension, 1.0),
5011        (
5012            ops.d2.t().dot(&ops.d2),
5013            PenaltySource::OperatorStiffness,
5014            2.0,
5015        ),
5016    ] {
5017        if min_order > 0.0 && m <= min_order + ORDER_EPS {
5018            continue;
5019        }
5020        let sym = (&raw + &raw.t()) * 0.5;
5021        let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&sym);
5022        candidates.push(PenaltyCandidate {
5023            matrix,
5024            nullspace_dim_hint: 0,
5025            source,
5026            normalization_scale,
5027            kronecker_factors: None,
5028            op: None,
5029        });
5030    }
5031    filter_active_penalty_candidates(candidates)
5032}
5033
5034pub fn normalize_penalty_in_constrained_space(matrix: &Array2<f64>) -> (Array2<f64>, f64) {
5035    // Constrained-space normalization:
5036    //   c = ||S_con||_F,  S_tilde = S_con / c.
5037    // This is the only normalization coherent with a REML objective that is
5038    // evaluated entirely in constrained coordinates.
5039    let matrix = (matrix + &matrix.t().to_owned()) * 0.5;
5040    // Clamp noise-floor negative eigenvalues so β'Sβ is non-negative as a contract, not just in exact arithmetic.
5041    let matrix = crate::basis::project_penalty_to_psd_cone(&matrix);
5042    let c = matrix.iter().map(|v| v * v).sum::<f64>().sqrt();
5043    if c.is_finite() && c > 0.0 {
5044        (matrix.mapv(|v| v / c), c)
5045    } else {
5046        (matrix, 1.0)
5047    }
5048}
5049
5050pub fn tensor_product_design_from_sparse_marginals(
5051    marginal_sparse: &[&SparseColMat<usize, f64>],
5052) -> Result<SparseColMat<usize, f64>, BasisError> {
5053    if marginal_sparse.is_empty() {
5054        crate::bail_invalid_basis!("TensorBSpline requires at least one marginal basis");
5055    }
5056    let n = marginal_sparse[0].nrows();
5057    for (i, m) in marginal_sparse.iter().enumerate().skip(1) {
5058        if m.nrows() != n {
5059            crate::bail_dim_basis!(
5060                "tensor sparse marginal row mismatch at dim {i}: expected {n}, got {}",
5061                m.nrows()
5062            );
5063        }
5064    }
5065    let dims: Vec<usize> = marginal_sparse.iter().map(|m| m.ncols()).collect();
5066    let total_cols = dims.iter().try_fold(1usize, |acc, &q| {
5067        acc.checked_mul(q)
5068            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))
5069    })?;
5070    let mut strides = vec![1usize; dims.len()];
5071    for d in (0..dims.len().saturating_sub(1)).rev() {
5072        strides[d] = strides[d + 1]
5073            .checked_mul(dims[d + 1])
5074            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))?;
5075    }
5076
5077    use faer::sparse::SparseRowMat;
5078    let csrs: Vec<SparseRowMat<usize, f64>> = marginal_sparse
5079        .iter()
5080        .enumerate()
5081        .map(|(d, m)| {
5082            m.as_ref().to_row_major().map_err(|e| {
5083                BasisError::SparseCreation(format!(
5084                    "tensor sparse marginal {d} CSR conversion failed: {e:?}"
5085                ))
5086            })
5087        })
5088        .collect::<Result<Vec<_>, _>>()?;
5089    let row_ptrs: Vec<&[usize]> = csrs.iter().map(|c| c.symbolic().row_ptr()).collect();
5090    let col_idxs: Vec<&[usize]> = csrs.iter().map(|c| c.symbolic().col_idx()).collect();
5091    let vals: Vec<&[f64]> = csrs.iter().map(|c| c.val()).collect();
5092
5093    use rayon::prelude::*;
5094    const CHUNK: usize = 1024;
5095    let num_chunks = n.div_ceil(CHUNK);
5096    let per_chunk: Vec<Vec<Triplet<usize, usize, f64>>> = (0..num_chunks)
5097        .into_par_iter()
5098        .map(|chunk_idx| {
5099            let row_start = chunk_idx * CHUNK;
5100            let row_end = (row_start + CHUNK).min(n);
5101            let mut chunk_triplets = Vec::<Triplet<usize, usize, f64>>::new();
5102            let mut cur_cols = Vec::<usize>::with_capacity(64);
5103            let mut cur_vals = Vec::<f64>::with_capacity(64);
5104            let mut next_cols = Vec::<usize>::with_capacity(64);
5105            let mut next_vals = Vec::<f64>::with_capacity(64);
5106            for i in row_start..row_end {
5107                cur_cols.clear();
5108                cur_vals.clear();
5109                cur_cols.push(0);
5110                cur_vals.push(1.0);
5111                let mut row_is_zero = false;
5112                for d in 0..dims.len() {
5113                    let row_start_d = row_ptrs[d][i];
5114                    let row_end_d = row_ptrs[d][i + 1];
5115                    if row_start_d == row_end_d {
5116                        row_is_zero = true;
5117                        break;
5118                    }
5119                    let stride = strides[d];
5120                    next_cols.clear();
5121                    next_vals.clear();
5122                    next_cols.reserve(cur_cols.len() * (row_end_d - row_start_d));
5123                    next_vals.reserve(cur_vals.len() * (row_end_d - row_start_d));
5124                    for (&prev_col, &prev_val) in cur_cols.iter().zip(cur_vals.iter()) {
5125                        for ptr in row_start_d..row_end_d {
5126                            let cj = col_idxs[d][ptr];
5127                            let vj = vals[d][ptr];
5128                            next_cols.push(prev_col + cj * stride);
5129                            next_vals.push(prev_val * vj);
5130                        }
5131                    }
5132                    std::mem::swap(&mut cur_cols, &mut next_cols);
5133                    std::mem::swap(&mut cur_vals, &mut next_vals);
5134                }
5135                if row_is_zero {
5136                    continue;
5137                }
5138                for (&col, &val) in cur_cols.iter().zip(cur_vals.iter()) {
5139                    chunk_triplets.push(Triplet::new(i, col, val));
5140                }
5141            }
5142            chunk_triplets
5143        })
5144        .collect();
5145    let total_nnz: usize = per_chunk.iter().map(Vec::len).sum();
5146    let mut triplets = Vec::<Triplet<usize, usize, f64>>::with_capacity(total_nnz);
5147    for chunk in per_chunk {
5148        triplets.extend(chunk);
5149    }
5150    SparseColMat::try_new_from_triplets(n, total_cols, &triplets).map_err(|e| {
5151        BasisError::SparseCreation(format!(
5152            "failed to assemble sparse tensor product design: {e:?}"
5153        ))
5154    })
5155}
5156
5157pub fn dense_local_margin_to_sparse(
5158    dense: &Array2<f64>,
5159) -> Result<SparseColMat<usize, f64>, BasisError> {
5160    let expected_row_nnz = dense.ncols().min(4);
5161    let mut triplets =
5162        Vec::<Triplet<usize, usize, f64>>::with_capacity(dense.nrows() * expected_row_nnz);
5163    for ((row, col), &value) in dense.indexed_iter() {
5164        if value != 0.0 {
5165            triplets.push(Triplet::new(row, col, value));
5166        }
5167    }
5168    SparseColMat::try_new_from_triplets(dense.nrows(), dense.ncols(), &triplets).map_err(|e| {
5169        BasisError::SparseCreation(format!(
5170            "failed to convert tensor marginal design to sparse form: {e:?}"
5171        ))
5172    })
5173}
5174
/// Pair of projectors for one tensor margin, split by the margin penalty's
/// eigenvalues (see `tensor_margin_range_null_projectors`).
pub struct TensorMarginRangeNullProjectors {
    // Built from the eigenvectors whose eigenvalue exceeds the analysis tolerance.
    range: Array2<f64>,
    // Built from the remaining eigenvectors (eigenvalue at or below the tolerance).
    null: Array2<f64>,
}
5179
5180pub fn projector_from_columns(columns: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
5181    if indices.is_empty() {
5182        return Array2::<f64>::zeros((columns.nrows(), columns.nrows()));
5183    }
5184    let basis = columns.select(Axis(1), indices);
5185    basis.dot(&basis.t())
5186}
5187
5188pub fn tensor_margin_range_null_projectors(
5189    normalized_marginal_penalties: &[(Array2<f64>, f64)],
5190) -> Result<Vec<TensorMarginRangeNullProjectors>, BasisError> {
5191    normalized_marginal_penalties
5192        .iter()
5193        .enumerate()
5194        .map(|(dim, (penalty, _))| {
5195            let analysis = crate::basis::analyze_penalty_block(penalty)?;
5196            if analysis.rank == 0 {
5197                crate::bail_invalid_basis!(
5198                    "t2 separable tensor penalty margin {dim} has rank-zero penalty; \
5199                     cannot split penalized and null subspaces"
5200                );
5201            }
5202            let mut range_idx = Vec::<usize>::new();
5203            let mut null_idx = Vec::<usize>::new();
5204            for (idx, &ev) in analysis.eigenvalues.iter().enumerate() {
5205                if ev > analysis.tol {
5206                    range_idx.push(idx);
5207                } else {
5208                    null_idx.push(idx);
5209                }
5210            }
5211            Ok(TensorMarginRangeNullProjectors {
5212                range: projector_from_columns(&analysis.eigenvectors, &range_idx),
5213                null: projector_from_columns(&analysis.eigenvectors, &null_idx),
5214            })
5215        })
5216        .collect()
5217}
5218
5219pub fn build_tensor_bspline_basis(
5220    data: ArrayView2<'_, f64>,
5221    feature_cols: &[usize],
5222    spec: &TensorBSplineSpec,
5223) -> Result<BasisBuildResult, BasisError> {
5224    if feature_cols.is_empty() {
5225        crate::bail_invalid_basis!("TensorBSpline requires at least one feature column");
5226    }
5227    if feature_cols.len() != spec.marginalspecs.len() {
5228        crate::bail_dim_basis!(
5229            "TensorBSpline feature/spec mismatch: feature_cols={}, marginalspecs={}",
5230            feature_cols.len(),
5231            spec.marginalspecs.len()
5232        );
5233    }
5234    if !spec.periods.is_empty() && spec.periods.len() != feature_cols.len() {
5235        crate::bail_dim_basis!(
5236            "TensorBSpline periods length {} does not match feature count {}",
5237            spec.periods.len(),
5238            feature_cols.len()
5239        );
5240    }
5241    let p = data.ncols();
5242    for &c in feature_cols {
5243        if c >= p {
5244            crate::bail_dim_basis!(
5245                "tensor feature column {c} is out of bounds for data with {p} columns"
5246            );
5247        }
5248    }
5249
5250    let mut marginal_knots = Vec::<Array1<f64>>::with_capacity(feature_cols.len());
5251    // Per-margin cr flag (#1074): `true` when the margin is a natural cubic
5252    // regression spline, so the tensor freeze rebuilds the cr knotspec.
5253    let mut marginal_is_cr_flags = Vec::<bool>::with_capacity(feature_cols.len());
5254    let mut marginal_degrees = Vec::<usize>::with_capacity(feature_cols.len());
5255    let mut marginalnum_basis = Vec::<usize>::with_capacity(feature_cols.len());
5256    let mut marginal_penalties = Vec::<Array2<f64>>::with_capacity(feature_cols.len());
5257    let mut marginal_designs = Vec::<Array2<f64>>::with_capacity(feature_cols.len());
5258    // Per-margin effective period: either user-set via `spec.periods` or
5259    // implied by a `PeriodicUniform` marginal knotspec (which the 1D B-spline
5260    // builder realizes as a cyclic B-spline basis).
5261    // Captured here so freeze→reload round-trips both routes back to a
5262    // `PeriodicUniform` marginal knotspec; otherwise a `PeriodicUniform`
5263    // margin specified without `spec.periods` would freeze as a plain
5264    // `Provided(knots)` open spline and lose its wrap-around at predict time.
5265    let mut marginal_effective_periods = Vec::<Option<f64>>::with_capacity(feature_cols.len());
5266    // Per-marginal sparse representation, populated when the 1D builder returned
5267    // a `DesignMatrix::Sparse`. Used to assemble the Khatri-Rao tensor product
5268    // sparsely (only ∏(degree+1) nonzeros per row) instead of densifying to
5269    // shape (n, ∏ q_j) up front. Periodic B-spline margins are local-support
5270    // bases too; when the 1D builder returns them densely, we convert that
5271    // marginal back to sparse form so cylinder/torus tensor products keep the
5272    // same scale behavior as open tensor products.
5273    let mut marginal_sparse =
5274        Vec::<Option<SparseColMat<usize, f64>>>::with_capacity(feature_cols.len());
5275
5276    // Reuse the robust 1D builder to ensure the same knot validation and
5277    // marginal difference-penalty construction as standalone smooth terms.
5278    for (dim, (&col, marginalspec)) in feature_cols
5279        .iter()
5280        .zip(spec.marginalspecs.iter())
5281        .enumerate()
5282    {
5283        // Tensor basis uses raw marginal knot-product columns. Applying 1D
5284        // identifiability constraints here would change marginal penalty sizes
5285        // without changing the tensor design construction, causing dimension
5286        // mismatch. Keep marginal builders unconstrained at this stage.
5287        let mut marginal_unconstrained = marginalspec.clone();
5288        marginal_unconstrained.identifiability = BSplineIdentifiability::None;
5289        let built = build_bspline_basis_1d(data.column(col), &marginal_unconstrained)?;
5290        // A cr (`NaturalCubicRegression`) margin emits `CubicRegression1D`
5291        // metadata whose `knots` are the k value-knots; a B-spline margin emits
5292        // `BSpline1D` with the clamped knot vector. Capture either so the
5293        // tensor freeze can rebuild the exact same marginal knotspec (#1074).
5294        let (knots, marginal_is_cr) = match built.metadata {
5295            BasisMetadata::BSpline1D { knots, .. } => (knots, false),
5296            BasisMetadata::CubicRegression1D { knots, .. } => (knots, true),
5297            _ => {
5298                crate::bail_invalid_basis!(
5299                    "internal TensorBSpline error at dim {dim}: expected BSpline1D or CubicRegression1D metadata"
5300                );
5301            }
5302        };
5303        let metadata_knots = match marginalspec.knotspec {
5304            BSplineKnotSpec::PeriodicUniform {
5305                data_range,
5306                num_basis,
5307            } => Array1::linspace(data_range.0, data_range.1, num_basis),
5308            _ => knots,
5309        };
5310        marginal_knots.push(metadata_knots);
5311        marginal_is_cr_flags.push(marginal_is_cr);
5312        marginal_degrees.push(marginalspec.degree);
5313        marginalnum_basis.push(built.design.ncols());
5314        // Capture the sparse representation of this marginal (when the
5315        // 1D builder produced one) before densifying for the dense
5316        // marginal cache used by `tensor_product_design_from_marginals`
5317        // and `TensorProductDesignOperator`.
5318        let dense_marginal = built.design.to_dense();
5319        let sparse_view: Option<SparseColMat<usize, f64>> = match built.design.as_sparse() {
5320            Some(sd) => {
5321                let inner: &SparseColMat<usize, f64> = sd;
5322                Some(inner.clone())
5323            }
5324            None => match marginalspec.knotspec {
5325                BSplineKnotSpec::PeriodicUniform { .. } => {
5326                    Some(dense_local_margin_to_sparse(&dense_marginal)?)
5327                }
5328                _ => None,
5329            },
5330        };
5331        marginal_sparse.push(sparse_view);
5332        marginal_designs.push(dense_marginal);
5333        marginal_penalties.push(
5334            built
5335                .penalties
5336                .first()
5337                .ok_or_else(|| {
5338                    BasisError::InvalidInput(format!(
5339                        "internal TensorBSpline error at dim {dim}: missing marginal penalty"
5340                    ))
5341                })?
5342                .clone(),
5343        );
5344        built.nullspace_dims.first().ok_or_else(|| {
5345            BasisError::InvalidInput(format!(
5346                "internal TensorBSpline error at dim {dim}: missing marginal nullspace dim"
5347            ))
5348        })?;
5349        // A `PeriodicUniform` marginal knotspec implies the margin is
5350        // wrap-around: the 1D builder already realized it as a periodic
5351        // basis, so the tensor product inherits that periodicity. Record
5352        // the period derived from the knotspec's data range so freeze
5353        // restores `PeriodicUniform` on the marginal — otherwise the
5354        // round-trip downgrades it to `Provided(knots)` (an open spline)
5355        // and predict-time wraps disappear.
5356        let implied_period = match marginalspec.knotspec {
5357            BSplineKnotSpec::PeriodicUniform { data_range, .. } => {
5358                Some(data_range.1 - data_range.0)
5359            }
5360            _ => spec.periods.get(dim).and_then(|p| *p),
5361        };
5362        marginal_effective_periods.push(implied_period);
5363    }
5364
5365    let total_cols: usize = marginalnum_basis.iter().product();
5366    let mut dense_design = (!matches!(spec.identifiability, TensorBSplineIdentifiability::None))
5367        .then(|| tensor_product_design_from_marginals(&marginal_designs))
5368        .transpose()?;
5369    let mut candidates = Vec::<PenaltyCandidate>::with_capacity(
5370        match spec.penalty_decomposition {
5371            TensorBSplinePenaltyDecomposition::MarginalKroneckerSum => marginal_penalties.len(),
5372            TensorBSplinePenaltyDecomposition::Separable => marginal_penalties.len() * 2,
5373        } + if spec.double_penalty { 1 } else { 0 },
5374    );
5375
5376    // Tensor-product smoothing parameters are one-per-margin.  Therefore the
5377    // physical penalty attached to a margin must be normalized in that margin's
5378    // own working coordinates before it is embedded in the full tensor product.
5379    // Normalizing only the already-Kroneckered matrix would fold arbitrary
5380    // dimension-dependent identity factors into the margin's lambda and would
5381    // make anisotropic REML/LAML smoothing depend on the other margins' basis
5382    // sizes rather than on the marginal roughness operator itself.
5383    let normalized_marginal_penalties: Vec<(Array2<f64>, f64)> = marginal_penalties
5384        .iter()
5385        .map(normalize_penalty_in_constrained_space)
5386        .collect();
5387    let mut kronecker_marginal_penalties =
5388        Vec::<Array2<f64>>::with_capacity(normalized_marginal_penalties.len());
5389
5390    match spec.penalty_decomposition {
5391        TensorBSplinePenaltyDecomposition::MarginalKroneckerSum => {
5392            // Accumulate the Kronecker-sum of the per-margin penalties,
5393            // `Σ_dim S_dim`, whose null space is exactly the *joint* null space
5394            // of all marginal penalties — the tensor of marginal polynomial
5395            // null spaces. The tensor double penalty (below) shrinks only this
5396            // joint null, never the already-penalized interaction range.
5397            let mut marginal_kron_sum = Array2::<f64>::zeros((total_cols, total_cols));
5398
5399            for dim in 0..normalized_marginal_penalties.len() {
5400                let mut s_dim = Array2::<f64>::eye(1);
5401                let mut factors = Vec::<Array2<f64>>::with_capacity(marginalnum_basis.len());
5402                for (j, &qj) in marginalnum_basis.iter().enumerate() {
5403                    let factor = if j == dim {
5404                        normalized_marginal_penalties[j].0.clone()
5405                    } else {
5406                        Array2::<f64>::eye(qj)
5407                    };
5408                    factors.push(factor.clone());
5409                    s_dim = kronecker_product(&s_dim, &factor);
5410                }
5411                if dim == kronecker_marginal_penalties.len() {
5412                    kronecker_marginal_penalties.push(normalized_marginal_penalties[dim].0.clone());
5413                }
5414                marginal_kron_sum += &s_dim;
5415
5416                candidates.push(PenaltyCandidate {
5417                    matrix: s_dim,
5418                    nullspace_dim_hint: 0,
5419                    source: PenaltySource::TensorMarginal { dim },
5420                    normalization_scale: normalized_marginal_penalties[dim].1,
5421                    kronecker_factors: Some(factors),
5422                    op: None,
5423                });
5424            }
5425
5426            if spec.double_penalty
5427                && let Some(shrink) =
5428                    crate::basis::build_nullspace_shrinkage_penalty(&marginal_kron_sum)?
5429            {
5430                let (matrix, normalization_scale) =
5431                    normalize_penalty_in_constrained_space(&shrink.sym_penalty);
5432                candidates.push(PenaltyCandidate {
5433                    matrix,
5434                    nullspace_dim_hint: 0,
5435                    source: PenaltySource::TensorGlobalRidge,
5436                    normalization_scale,
5437                    kronecker_factors: None,
5438                    op: None,
5439                });
5440            }
5441        }
5442        TensorBSplinePenaltyDecomposition::Separable => {
5443            let projectors = tensor_margin_range_null_projectors(&normalized_marginal_penalties)?;
5444            let n_masks = 1usize.checked_shl(projectors.len() as u32).ok_or_else(|| {
5445                BasisError::InvalidInput(format!(
5446                    "t2 separable tensor penalty supports at most {} margins, got {}",
5447                    usize::BITS - 1,
5448                    projectors.len()
5449                ))
5450            })?;
5451            for mask in 1..n_masks {
5452                let mut matrix = Array2::<f64>::eye(1);
5453                let mut factors = Vec::<Array2<f64>>::with_capacity(projectors.len());
5454                let mut penalized_margins = Vec::<usize>::new();
5455                for (dim, projector) in projectors.iter().enumerate() {
5456                    let use_range = ((mask >> dim) & 1) == 1;
5457                    let factor = if use_range {
5458                        penalized_margins.push(dim);
5459                        projector.range.clone()
5460                    } else {
5461                        projector.null.clone()
5462                    };
5463                    matrix = kronecker_product(&matrix, &factor);
5464                    factors.push(factor);
5465                }
5466                let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&matrix);
5467                candidates.push(PenaltyCandidate {
5468                    matrix,
5469                    nullspace_dim_hint: 0,
5470                    source: PenaltySource::TensorSeparable { penalized_margins },
5471                    normalization_scale,
5472                    kronecker_factors: Some(factors),
5473                    op: None,
5474                });
5475            }
5476
5477            if spec.double_penalty {
5478                let mut matrix = Array2::<f64>::eye(1);
5479                let mut factors = Vec::<Array2<f64>>::with_capacity(projectors.len());
5480                for projector in &projectors {
5481                    matrix = kronecker_product(&matrix, &projector.null);
5482                    factors.push(projector.null.clone());
5483                }
5484                let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&matrix);
5485                candidates.push(PenaltyCandidate {
5486                    matrix,
5487                    nullspace_dim_hint: 0,
5488                    source: PenaltySource::TensorGlobalRidge,
5489                    normalization_scale,
5490                    kronecker_factors: Some(factors),
5491                    op: None,
5492                });
5493            }
5494        }
5495    }
5496
5497    let z_opt = match &spec.identifiability {
5498        TensorBSplineIdentifiability::None => None,
5499        TensorBSplineIdentifiability::SumToZero => {
5500            if total_cols < 2 {
5501                crate::bail_invalid_basis!(
5502                    "TensorBSpline requires at least 2 basis coefficients to enforce sum-to-zero identifiability"
5503                );
5504            }
5505            let dense_design_ref = dense_design.as_ref().ok_or_else(|| {
5506                BasisError::InvalidInput(
5507                    "tensor sum-to-zero identifiability requires a realized basis".to_string(),
5508                )
5509            })?;
5510            let (_, z) = apply_sum_to_zero_constraint(dense_design_ref.view(), None)?;
5511            let gauge = gam_problem::Gauge::sum_to_zero(z);
5512            Some(gauge.block_transform(0))
5513        }
5514        TensorBSplineIdentifiability::MarginalSumToZero => {
5515            // `ti(...)`: drop the marginal main effects by centering every
5516            // margin independently, then form the tensor product of the
5517            // centered margins. Concretely, each margin `j` is reparameterized
5518            // by its own sum-to-zero null basis `Z_j` (so the constant — i.e.
5519            // the marginal intercept — is removed from that axis), and the
5520            // combined reparameterization is the Kronecker product
5521            // `Z = Z₀ ⊗ Z₁ ⊗ … ⊗ Z_{d-1}`. Applying `Z` to the full-tensor
5522            // design `B = B₀ ⊗ … ⊗ B_{d-1}` yields `B Z = (B₀ Z₀) ⊗ … ⊗
5523            // (B_{d-1} Z_{d-1})`, the tensor product of the centered margins,
5524            // which by construction contains no pure main effect.
5525            if marginal_designs.len() < 2 {
5526                crate::bail_invalid_basis!(
5527                    "tensor interaction (ti) identifiability requires at least 2 margins"
5528                );
5529            }
5530            let mut z = Array2::<f64>::eye(1);
5531            for (dim, marginal) in marginal_designs.iter().enumerate() {
5532                if marginal.ncols() < 2 {
5533                    crate::bail_invalid_basis!(
5534                        "tensor interaction (ti) margin {dim} has fewer than 2 basis functions; \
5535                         cannot remove its marginal main effect"
5536                    );
5537                }
5538                let (_, z_dim) = apply_sum_to_zero_constraint(marginal.view(), None)?;
5539                let gauge_dim = gam_problem::Gauge::sum_to_zero(z_dim);
5540                let z_dim = gauge_dim.block_transform(0);
5541                z = kronecker_product(&z, &z_dim);
5542            }
5543            Some(z)
5544        }
5545        TensorBSplineIdentifiability::FrozenTransform { transform } => {
5546            if transform.nrows() != total_cols {
5547                crate::bail_dim_basis!(
5548                    "frozen tensor identifiability transform mismatch: design has {} columns but transform has {} rows",
5549                    total_cols,
5550                    transform.nrows()
5551                );
5552            }
5553            Some(transform.clone())
5554        }
5555    };
5556
5557    if let Some(z) = z_opt.as_ref() {
5558        let gauge = gam_problem::Gauge::from_block_transforms(&[z.clone()]);
5559        let dense = dense_design.as_mut().ok_or_else(|| {
5560            BasisError::InvalidInput(
5561                "tensor identifiability transform requires a realized basis".to_string(),
5562            )
5563        })?;
5564        let restricted_design = gauge.restrict_design(dense);
5565        *dense = restricted_design;
5566        candidates = candidates
5567            .into_iter()
5568            .map(|candidate| -> Result<PenaltyCandidate, BasisError> {
5569                let matrix = gauge.restrict_penalty(&candidate.matrix);
5570                // Re-normalize in the *actual* coefficient chart used by the
5571                // fit.  The tensor sum-to-zero transform is not norm-preserving
5572                // for each overlapping marginal penalty, so carrying the raw
5573                // marginal Frobenius scale into the restricted space changes the
5574                // relative amount of smoothing seen by the LAML/REML optimizer.
5575                // Keep the physical scale in metadata and give the optimizer
5576                // unit-scale constrained penalties for every tensor margin.
5577                let (matrix, c_new) = normalize_penalty_in_constrained_space(&matrix);
5578                Ok(PenaltyCandidate {
5579                    nullspace_dim_hint: candidate.nullspace_dim_hint,
5580                    matrix,
5581                    source: candidate.source,
5582                    normalization_scale: candidate.normalization_scale * c_new,
5583                    // Z^T S Z is no longer a Kronecker product of the original
5584                    // marginal factors, so the Kronecker fast path in construction.rs
5585                    // must not be taken. Clearing kronecker_factors forces the generic
5586                    // block-local eigendecomposition path, which operates on the
5587                    // transformed matrix and is correct.
5588                    kronecker_factors: None,
5589                    op: candidate.op.clone(),
5590                })
5591            })
5592            .collect::<Result<Vec<_>, _>>()?;
5593    }
5594
5595    let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
5596        filter_active_penalty_candidates_with_ops(candidates)?;
5597    let identifiability_is_none =
5598        matches!(spec.identifiability, TensorBSplineIdentifiability::None);
5599    // All marginals expose a sparse representation iff each `marginal_sparse`
5600    // slot is `Some(...)`. Currently this is true when every marginal is a
5601    // free-boundary, non-periodic 1D B-spline returned as
5602    // `DesignMatrix::Sparse` from `build_bspline_basis_1d`. Periodic B-splines
5603    // and other dense-only marginals leave a `None` and trigger the fall-back
5604    // path. Identifiability transforms (`SumToZero`, `FrozenTransform`) make
5605    // the tensor design dense in general, so we also gate on that.
5606    let all_marginals_sparse = marginal_sparse.iter().all(Option::is_some);
5607    let design = if let Some(dense_design) = dense_design {
5608        DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense_design))
5609    } else if identifiability_is_none && all_marginals_sparse {
5610        // Sparse Khatri-Rao path: assemble the (n, ∏ q_j) tensor product
5611        // directly as a SparseColMat, preserving the ∏(degree_j+1) nonzero
5612        // structure per row instead of densifying to ∏ q_j columns. This is
5613        // mathematically identical to `tensor_product_design_from_marginals`
5614        // applied to the corresponding dense marginals.
5615        let sparse_marginals: Vec<&SparseColMat<usize, f64>> = marginal_sparse
5616            .iter()
5617            .map(|m| m.as_ref().expect("all_marginals_sparse just verified"))
5618            .collect();
5619        let sparse_design = tensor_product_design_from_sparse_marginals(&sparse_marginals)?;
5620        DesignMatrix::Sparse(gam_linalg::matrix::SparseDesignMatrix::new(sparse_design))
5621    } else {
5622        let marginals: Vec<Arc<Array2<f64>>> = marginal_designs
5623            .iter()
5624            .map(|m| Arc::new(m.clone()))
5625            .collect();
5626        let op = TensorProductDesignOperator::new(marginals).map_err(|e| {
5627            BasisError::InvalidInput(format!("TensorProductDesignOperator build failed: {e}"))
5628        })?;
5629        DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(op)))
5630    };
5631
5632    Ok(BasisBuildResult {
5633        design,
5634        penalties,
5635        nullspace_dims,
5636        penaltyinfo,
5637        ops,
5638        null_eigenvectors,
5639        joint_null_rotation: None,
5640        metadata: BasisMetadata::TensorBSpline {
5641            feature_cols: feature_cols.to_vec(),
5642            knots: marginal_knots,
5643            degrees: marginal_degrees,
5644            // Prefer the per-margin effective period derived in the loop —
5645            // it captures both the explicit `spec.periods` route and the
5646            // implied period from a `PeriodicUniform` marginal knotspec.
5647            // Falling back to `spec.periods` when populated keeps any
5648            // user-supplied explicit period authoritative even if the
5649            // marginal knotspec carried no periodicity hint.
5650            periods: marginal_effective_periods,
5651            is_cr: marginal_is_cr_flags,
5652            identifiability_transform: z_opt,
5653        },
5654        kronecker_factored: if matches!(spec.identifiability, TensorBSplineIdentifiability::None)
5655            && matches!(
5656                spec.penalty_decomposition,
5657                TensorBSplinePenaltyDecomposition::MarginalKroneckerSum
5658            ) {
5659            Some(KroneckerFactoredBasis::new(
5660                marginal_designs,
5661                kronecker_marginal_penalties,
5662                marginalnum_basis.clone(),
5663                spec.double_penalty,
5664            ))
5665        } else {
5666            None
5667        },
5668    })
5669}
5670
5671pub fn tensor_product_design_from_marginals(
5672    marginal_designs: &[Array2<f64>],
5673) -> Result<Array2<f64>, BasisError> {
5674    if marginal_designs.is_empty() {
5675        crate::bail_invalid_basis!("TensorBSpline requires at least one marginal basis");
5676    }
5677    let n = marginal_designs[0].nrows();
5678    for (i, b) in marginal_designs.iter().enumerate().skip(1) {
5679        if b.nrows() != n {
5680            crate::bail_dim_basis!(
5681                "tensor marginal row mismatch at dim {i}: expected {n}, got {}",
5682                b.nrows()
5683            );
5684        }
5685    }
5686    let total_cols = marginal_designs.iter().try_fold(1usize, |acc, b| {
5687        acc.checked_mul(b.ncols())
5688            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))
5689    })?;
5690    // Tensor-product Khatri-Rao: design[i, j] = Π_d marginal_d[i, j_d]
5691    // where j is the multi-index (j_1, ..., j_D) flattened. Independent
5692    // across rows; parallelize row chunks and fill the pre-allocated
5693    // contiguous Array2 in place (no Vec-flatten-collect intermediate,
5694    // which doubled the peak memory at large-scale N).
5695    use ndarray::parallel::prelude::*;
5696    use rayon::iter::{IntoParallelIterator, ParallelIterator};
5697    let mut design = Array2::<f64>::zeros((n, total_cols));
5698    design
5699        .axis_chunks_iter_mut(ndarray::Axis(0), 1024)
5700        .into_par_iter()
5701        .enumerate()
5702        .for_each(|(chunk_idx, mut block)| {
5703            let row_offset = chunk_idx * 1024;
5704            // Scratch buffers reused across rows in this chunk.
5705            let mut cur = Vec::<f64>::with_capacity(total_cols);
5706            let mut next = Vec::<f64>::with_capacity(total_cols);
5707            for (local_i, mut out_row) in block.outer_iter_mut().enumerate() {
5708                let i = row_offset + local_i;
5709                cur.clear();
5710                cur.push(1.0);
5711                for b in marginal_designs {
5712                    let q = b.ncols();
5713                    next.clear();
5714                    next.resize(cur.len() * q, 0.0);
5715                    // Hoist the row view out of the inner `col` loop so the
5716                    // q reads per `a_idx` reuse a single contiguous slice
5717                    // instead of recomputing `b[[i, col]]` strides per cell.
5718                    let b_row = b.row(i);
5719                    let b_slice = b_row
5720                        .as_slice()
5721                        .expect("Array2 row from outer_iter is contiguous");
5722                    for (a_idx, &aval) in cur.iter().enumerate() {
5723                        let off = a_idx * q;
5724                        let dst = &mut next[off..off + q];
5725                        for col in 0..q {
5726                            dst[col] = aval * b_slice[col];
5727                        }
5728                    }
5729                    std::mem::swap(&mut cur, &mut next);
5730                }
5731                // `out_row` is a row of the contiguous C-major `design`
5732                // Array2, so it is backed by a contiguous slice. Use a
5733                // bulk slice copy instead of an element-by-element write
5734                // loop.
5735                let out_slice = out_row
5736                    .as_slice_mut()
5737                    .expect("design row is contiguous in C-major Array2");
5738                out_slice.copy_from_slice(&cur);
5739            }
5740        });
5741    Ok(design)
5742}
5743
5744pub fn build_random_effect_block(
5745    data: ArrayView2<'_, f64>,
5746    spec: &RandomEffectTermSpec,
5747) -> Result<RandomEffectBlock, BasisError> {
5748    let n = data.nrows();
5749    let p = data.ncols();
5750    if spec.feature_col >= p {
5751        crate::bail_dim_basis!(
5752            "random-effect term '{}' feature column {} out of bounds for {} columns",
5753            spec.name,
5754            spec.feature_col,
5755            p
5756        );
5757    }
5758
5759    let col = data.column(spec.feature_col);
5760    if col.iter().any(|v| !v.is_finite()) {
5761        crate::bail_invalid_basis!(
5762            "random-effect term '{}' contains non-finite group values",
5763            spec.name
5764        );
5765    }
5766
5767    let kept_levels: Vec<u64> = if let Some(levels) = spec.frozen_levels.as_ref() {
5768        if levels.is_empty() {
5769            crate::bail_invalid_basis!(
5770                "random-effect term '{}' has empty frozen_levels",
5771                spec.name
5772            );
5773        }
5774        levels.clone()
5775    } else {
5776        let mut seen = BTreeSet::<u64>::new();
5777        let mut levels = Vec::<u64>::new();
5778        for &v in col {
5779            let bits = v.to_bits();
5780            if seen.insert(bits) {
5781                levels.push(bits);
5782            }
5783        }
5784        if levels.is_empty() {
5785            crate::bail_invalid_basis!("random-effect term '{}' has no observed levels", spec.name);
5786        }
5787        let start_idx = if spec.drop_first_level && levels.len() > 1 {
5788            1usize
5789        } else {
5790            0usize
5791        };
5792        levels[start_idx..].to_vec()
5793    };
5794
5795    if kept_levels.is_empty() {
5796        crate::bail_invalid_basis!(
5797            "random-effect term '{}' drops all levels; keep at least one level",
5798            spec.name
5799        );
5800    }
5801
5802    let q = kept_levels.len();
5803    let mut level_to_col = BTreeMap::<u64, usize>::new();
5804    for (idx, &bits) in kept_levels.iter().enumerate() {
5805        if level_to_col.insert(bits, idx).is_some() {
5806            crate::bail_invalid_basis!(
5807                "random-effect term '{}' has duplicate frozen level bits {bits}",
5808                spec.name
5809            );
5810        }
5811    }
5812    let mut group_ids = Vec::with_capacity(n);
5813    for &v in col {
5814        let bits = v.to_bits();
5815        group_ids.push(level_to_col.get(&bits).copied());
5816    }
5817
5818    Ok(RandomEffectBlock {
5819        name: spec.name.clone(),
5820        group_ids,
5821        num_groups: q,
5822        kept_levels,
5823    })
5824}
5825
5826impl SmoothDesign {
5827    /// Map an unconstrained term coefficient vector to its constrained shape space.
5828    /// This is useful for nonlinear fits that optimize unconstrained parameters.
5829    pub fn map_term_coefficients(
5830        unconstrained: &Array1<f64>,
5831        shape: ShapeConstraint,
5832    ) -> Result<Array1<f64>, BasisError> {
5833        if unconstrained.is_empty() {
5834            crate::bail_invalid_basis!("unconstrained coefficient vector cannot be empty");
5835        }
5836        let mapped = match shape {
5837            ShapeConstraint::None => unconstrained.clone(),
5838            ShapeConstraint::MonotoneIncreasing => cumulative_exp(unconstrained, 1.0),
5839            ShapeConstraint::MonotoneDecreasing => cumulative_exp(unconstrained, -1.0),
5840            ShapeConstraint::Convex => second_cumulative_exp(unconstrained, 1.0),
5841            ShapeConstraint::Concave => second_cumulative_exp(unconstrained, -1.0),
5842        };
5843        Ok(mapped)
5844    }
5845}
5846
5847pub struct LocalSmoothTermBuild {
5848    pub dim: usize,
5849    pub design: DesignMatrix,
5850    pub penalties: Vec<Array2<f64>>,
5851    pub ops: Vec<Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>>,
5852    pub nullspaces: Vec<usize>,
5853    /// Per-active-penalty null-space eigenvector matrices, parallel to
5854    /// `penalties` / `ops` / `nullspaces`. `Some(U_null)` when
5855    /// `nullspaces[k] > 0`, with `U_null` orthonormal columns spanning
5856    /// `null(penalties[k])` in this smooth's local coordinate system; `None`
5857    /// when the active block is full-rank. Stage 1 plumbing; Stage 2
5858    /// consumes this to absorb the smooth's null space into the parametric
5859    /// block at `TermCollectionDesign` construction.
5860    pub null_eigenvectors: Vec<Option<Array2<f64>>>,
5861    /// Joint-null absorption rotation for this smooth. `Some(rotation)`
5862    /// records `Q = [U_range | U_null]` spanning `null(Σ_k penalties[k])`,
5863    /// the joint null across all active penalty blocks on this smooth.
5864    /// `None` means the joint penalty is full-rank (joint nullity = 0) or
5865    /// there are no penalties. Stage-2 commit A: plumbing only — populated
5866    /// by commit B, applied by commit D.
5867    pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
5868    pub penaltyinfo: Vec<PenaltyInfo>,
5869    pub pre_dropped_penaltyinfo: Vec<PenaltyInfo>,
5870    pub metadata: BasisMetadata,
5871    pub linear_constraints: Option<LinearInequalityConstraints>,
5872    pub box_reparam: bool,
5873    pub kronecker_factored: Option<KroneckerFactoredBasis>,
5874}
5875
5876#[derive(Clone)]
5877pub struct PcaScoresMemmapDesignOperator {
5878    mmap: Arc<memmap2::Mmap>,
5879    data_offset: usize,
5880    nrows: usize,
5881    ncols: usize,
5882    chunk_size: usize,
5883}
5884
5885impl PcaScoresMemmapDesignOperator {
5886    fn open(path: PathBuf, chunk_size: usize) -> Result<Self, BasisError> {
5887        let file = File::open(&path).map_err(|err| {
5888            BasisError::InvalidInput(format!(
5889                "failed to open lazy Pca .npy scores '{}': {err}",
5890                path.display()
5891            ))
5892        })?;
5893        // The .npy scores file is read-only training-cache data; this
5894        // module never mutates it. The error path below converts mmap
5895        // failure to a typed `BasisError::InvalidInput`.
5896        // SAFETY: `memmap2::Mmap::map` requires no concurrent writers; the
5897        // contract is held by this module's read-only access pattern.
5898        let mmap = unsafe {
5899            memmap2::Mmap::map(&file).map_err(|err| {
5900                BasisError::InvalidInput(format!(
5901                    "failed to memmap lazy Pca .npy scores '{}': {err}",
5902                    path.display()
5903                ))
5904            })?
5905        };
5906        let (data_offset, nrows, ncols) = parse_f64_2d_npy_header(&mmap, &path)?;
5907        let expected = data_offset
5908            .checked_add(nrows.saturating_mul(ncols).saturating_mul(8))
5909            .ok_or_else(|| {
5910                BasisError::InvalidInput(format!(
5911                    "lazy Pca .npy scores '{}' shape is too large",
5912                    path.display()
5913                ))
5914            })?;
5915        if mmap.len() < expected {
5916            crate::bail_invalid_basis!(
5917                "lazy Pca .npy scores '{}' is truncated: header expects {} bytes, file has {}",
5918                path.display(),
5919                expected,
5920                mmap.len()
5921            );
5922        }
5923        Ok(Self {
5924            mmap: Arc::new(mmap),
5925            data_offset,
5926            nrows,
5927            ncols,
5928            chunk_size: chunk_size.max(1),
5929        })
5930    }
5931
5932    fn value(&self, row: usize, col: usize) -> f64 {
5933        let offset = self.data_offset + (row * self.ncols + col) * 8;
5934        let mut bytes = [0_u8; 8];
5935        bytes.copy_from_slice(&self.mmap[offset..offset + 8]);
5936        f64::from_le_bytes(bytes)
5937    }
5938
5939    fn chunk_rows(&self) -> usize {
5940        self.chunk_size.min(self.nrows.max(1))
5941    }
5942}
5943
5944impl LinearOperator for PcaScoresMemmapDesignOperator {
5945    fn nrows(&self) -> usize {
5946        self.nrows
5947    }
5948
5949    fn ncols(&self) -> usize {
5950        self.ncols
5951    }
5952
5953    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
5954        assert_eq!(
5955            vector.len(),
5956            self.ncols,
5957            "lazy Pca apply vector length mismatch"
5958        );
5959        let mut out = Array1::<f64>::zeros(self.nrows);
5960        for start in (0..self.nrows).step_by(self.chunk_rows()) {
5961            let end = (start + self.chunk_rows()).min(self.nrows);
5962            for row in start..end {
5963                let mut acc = 0.0;
5964                for col in 0..self.ncols {
5965                    acc += self.value(row, col) * vector[col];
5966                }
5967                out[row] = acc;
5968            }
5969        }
5970        out
5971    }
5972
5973    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
5974        assert_eq!(
5975            vector.len(),
5976            self.nrows,
5977            "lazy Pca apply_transpose vector length mismatch"
5978        );
5979        let mut out = Array1::<f64>::zeros(self.ncols);
5980        for start in (0..self.nrows).step_by(self.chunk_rows()) {
5981            let end = (start + self.chunk_rows()).min(self.nrows);
5982            for row in start..end {
5983                let scale = vector[row];
5984                if scale == 0.0 {
5985                    continue;
5986                }
5987                for col in 0..self.ncols {
5988                    out[col] += scale * self.value(row, col);
5989                }
5990            }
5991        }
5992        out
5993    }
5994
5995    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
5996        if weights.len() != self.nrows {
5997            return Err(format!(
5998                "lazy Pca diag_xtw_x weight length mismatch: weights={}, nrows={}",
5999                weights.len(),
6000                self.nrows
6001            ));
6002        }
6003        let mut gram = Array2::<f64>::zeros((self.ncols, self.ncols));
6004        for start in (0..self.nrows).step_by(self.chunk_rows()) {
6005            let end = (start + self.chunk_rows()).min(self.nrows);
6006            for row in start..end {
6007                let w = weights[row];
6008                if w == 0.0 {
6009                    continue;
6010                }
6011                for a in 0..self.ncols {
6012                    let xa = self.value(row, a);
6013                    if xa == 0.0 {
6014                        continue;
6015                    }
6016                    for b in a..self.ncols {
6017                        gram[[a, b]] += w * xa * self.value(row, b);
6018                    }
6019                }
6020            }
6021        }
6022        for a in 0..self.ncols {
6023            for b in 0..a {
6024                gram[[a, b]] = gram[[b, a]];
6025            }
6026        }
6027        Ok(gram)
6028    }
6029
6030    fn apply_weighted_normal(
6031        &self,
6032        weights: &Array1<f64>,
6033        vector: &Array1<f64>,
6034        penalty: Option<&Array2<f64>>,
6035        ridge: f64,
6036    ) -> Array1<f64> {
6037        assert_eq!(
6038            weights.len(),
6039            self.nrows,
6040            "lazy Pca weighted-normal weight mismatch"
6041        );
6042        assert_eq!(
6043            vector.len(),
6044            self.ncols,
6045            "lazy Pca weighted-normal vector mismatch"
6046        );
6047        let mut out = Array1::<f64>::zeros(self.ncols);
6048        for start in (0..self.nrows).step_by(self.chunk_rows()) {
6049            let end = (start + self.chunk_rows()).min(self.nrows);
6050            for row in start..end {
6051                let w = weights[row].max(0.0);
6052                if w == 0.0 {
6053                    continue;
6054                }
6055                let mut row_dot = 0.0;
6056                for col in 0..self.ncols {
6057                    row_dot += self.value(row, col) * vector[col];
6058                }
6059                if row_dot == 0.0 {
6060                    continue;
6061                }
6062                let scaled = w * row_dot;
6063                for col in 0..self.ncols {
6064                    out[col] += scaled * self.value(row, col);
6065                }
6066            }
6067        }
6068        if let Some(pen) = penalty {
6069            out += &pen.dot(vector);
6070        }
6071        if ridge > 0.0 {
6072            out += &vector.mapv(|x| ridge * x);
6073        }
6074        out
6075    }
6076}
6077
6078impl DenseDesignOperator for PcaScoresMemmapDesignOperator {
6079    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
6080        if weights.len() != self.nrows || y.len() != self.nrows {
6081            return Err(format!(
6082                "lazy Pca compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
6083                weights.len(),
6084                y.len(),
6085                self.nrows
6086            ));
6087        }
6088        let mut out = Array1::<f64>::zeros(self.ncols);
6089        for start in (0..self.nrows).step_by(self.chunk_rows()) {
6090            let end = (start + self.chunk_rows()).min(self.nrows);
6091            for row in start..end {
6092                let scale = weights[row] * y[row];
6093                if scale == 0.0 {
6094                    continue;
6095                }
6096                for col in 0..self.ncols {
6097                    out[col] += scale * self.value(row, col);
6098                }
6099            }
6100        }
6101        Ok(out)
6102    }
6103
6104    fn row_chunk_into(
6105        &self,
6106        rows: Range<usize>,
6107        mut out: ArrayViewMut2<'_, f64>,
6108    ) -> Result<(), MatrixMaterializationError> {
6109        if rows.end > self.nrows || rows.start > rows.end {
6110            return Err(MatrixMaterializationError::MissingRowChunk {
6111                context: "lazy Pca row range out of bounds",
6112            });
6113        }
6114        if out.nrows() != rows.end - rows.start || out.ncols() != self.ncols {
6115            return Err(MatrixMaterializationError::MissingRowChunk {
6116                context: "lazy Pca row_chunk_into shape mismatch",
6117            });
6118        }
6119        for (local, row) in (rows.start..rows.end).enumerate() {
6120            for col in 0..self.ncols {
6121                out[[local, col]] = self.value(row, col);
6122            }
6123        }
6124        Ok(())
6125    }
6126
6127    fn to_dense(&self) -> Array2<f64> {
6128        let mut out = Array2::<f64>::zeros((self.nrows, self.ncols));
6129        self.row_chunk_into(0..self.nrows, out.view_mut())
6130            .expect("lazy Pca full materialization failed");
6131        out
6132    }
6133}
6134
6135pub fn parse_f64_2d_npy_header(
6136    bytes: &[u8],
6137    path: &PathBuf,
6138) -> Result<(usize, usize, usize), BasisError> {
6139    if bytes.len() < 10 || &bytes[0..6] != b"\x93NUMPY" {
6140        crate::bail_invalid_basis!("lazy Pca scores '{}' is not a .npy file", path.display());
6141    }
6142    let major = bytes[6];
6143    let header_len = match major {
6144        1 => u16::from_le_bytes([bytes[8], bytes[9]]) as usize,
6145        2 | 3 => {
6146            if bytes.len() < 12 {
6147                crate::bail_invalid_basis!(
6148                    "lazy Pca scores '{}' has a truncated .npy header",
6149                    path.display()
6150                );
6151            }
6152            u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize
6153        }
6154        other => {
6155            crate::bail_invalid_basis!(
6156                "lazy Pca scores '{}' uses unsupported .npy version {}",
6157                path.display(),
6158                other
6159            );
6160        }
6161    };
6162    let header_start = if major == 1 { 10 } else { 12 };
6163    let data_offset = header_start + header_len;
6164    if bytes.len() < data_offset {
6165        crate::bail_invalid_basis!(
6166            "lazy Pca scores '{}' has a truncated .npy header",
6167            path.display()
6168        );
6169    }
6170    let header = std::str::from_utf8(&bytes[header_start..data_offset]).map_err(|err| {
6171        BasisError::InvalidInput(format!(
6172            "lazy Pca scores '{}' has a non-UTF8 .npy header: {err}",
6173            path.display()
6174        ))
6175    })?;
6176    if !(header.contains("'descr': '<f8'")
6177        || header.contains("\"descr\": \"<f8\"")
6178        || header.contains("'descr': '|f8'")
6179        || header.contains("\"descr\": \"|f8\""))
6180    {
6181        crate::bail_invalid_basis!(
6182            "lazy Pca scores '{}' must be float64 little-endian .npy",
6183            path.display()
6184        );
6185    }
6186    if header.contains("True") {
6187        crate::bail_invalid_basis!(
6188            "lazy Pca scores '{}' must be C-contiguous, not Fortran-ordered",
6189            path.display()
6190        );
6191    }
6192    let shape_pos = header.find("shape").ok_or_else(|| {
6193        BasisError::InvalidInput(format!(
6194            "lazy Pca scores '{}' .npy header is missing shape",
6195            path.display()
6196        ))
6197    })?;
6198    let open = header[shape_pos..].find('(').ok_or_else(|| {
6199        BasisError::InvalidInput(format!(
6200            "lazy Pca scores '{}' .npy header has malformed shape",
6201            path.display()
6202        ))
6203    })? + shape_pos;
6204    let close = header[open..].find(')').ok_or_else(|| {
6205        BasisError::InvalidInput(format!(
6206            "lazy Pca scores '{}' .npy header has malformed shape",
6207            path.display()
6208        ))
6209    })? + open;
6210    let dims = header[open + 1..close]
6211        .split(',')
6212        .map(str::trim)
6213        .filter(|part| !part.is_empty())
6214        .map(|part| part.parse::<usize>())
6215        .collect::<Result<Vec<_>, _>>()
6216        .map_err(|err| {
6217            BasisError::InvalidInput(format!(
6218                "lazy Pca scores '{}' .npy shape is not integral: {err}",
6219                path.display()
6220            ))
6221        })?;
6222    if dims.len() != 2 {
6223        crate::bail_invalid_basis!(
6224            "lazy Pca scores '{}' must have shape (N, K), got {:?}",
6225            path.display(),
6226            dims
6227        );
6228    }
6229    Ok((data_offset, dims[0], dims[1]))
6230}
6231
6232pub fn pca_center_mean(x: ArrayView2<'_, f64>) -> Result<Array1<f64>, BasisError> {
6233    if x.nrows() == 0 {
6234        crate::bail_invalid_basis!("Pca basis requires at least one row to compute center mean");
6235    }
6236    let mut mean = Array1::<f64>::zeros(x.ncols());
6237    for row in x.rows() {
6238        mean += &row;
6239    }
6240    mean.mapv_inplace(|v| v / x.nrows() as f64);
6241    Ok(mean)
6242}
6243
6244pub fn build_pca_smooth_basis(
6245    data: ArrayView2<'_, f64>,
6246    feature_cols: &[usize],
6247    basis_matrix: &Array2<f64>,
6248    centered: bool,
6249    smooth_penalty: f64,
6250    center_mean: Option<&Array1<f64>>,
6251    pca_basis_path: Option<&PathBuf>,
6252    chunk_size: usize,
6253) -> Result<BasisBuildResult, BasisError> {
6254    if let Some(path) = pca_basis_path {
6255        let op = PcaScoresMemmapDesignOperator::open(path.clone(), chunk_size)?;
6256        if op.nrows != data.nrows() {
6257            crate::bail_dim_basis!(
6258                "lazy Pca scores row mismatch: .npy has {}, data has {}",
6259                op.nrows,
6260                data.nrows()
6261            );
6262        }
6263        let k = op.ncols;
6264        let mut penalty = Array2::<f64>::eye(k);
6265        penalty.mapv_inplace(|v| v * smooth_penalty);
6266        let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
6267            filter_active_penalty_candidates_with_ops(vec![PenaltyCandidate {
6268                matrix: penalty,
6269                nullspace_dim_hint: 0,
6270                source: PenaltySource::Other("PcaRidge".to_string()),
6271                normalization_scale: 1.0,
6272                kronecker_factors: None,
6273                op: None,
6274            }])?;
6275        return Ok(BasisBuildResult {
6276            design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(op))),
6277            penalties,
6278            nullspace_dims,
6279            penaltyinfo,
6280            ops,
6281            null_eigenvectors,
6282            joint_null_rotation: None,
6283            metadata: BasisMetadata::Pca {
6284                feature_cols: feature_cols.to_vec(),
6285                basis_matrix: basis_matrix.clone(),
6286                centered,
6287                smooth_penalty,
6288                center_mean: center_mean.cloned(),
6289                pca_basis_path: Some(path.clone()),
6290                chunk_size: chunk_size.max(1),
6291            },
6292            kronecker_factored: None,
6293        });
6294    }
6295    if basis_matrix.nrows() != feature_cols.len() {
6296        crate::bail_dim_basis!(
6297            "Pca basis row mismatch: basis rows={}, feature columns={}",
6298            basis_matrix.nrows(),
6299            feature_cols.len()
6300        );
6301    }
6302    let mut x = select_columns(data, feature_cols)?;
6303    let mean = if centered {
6304        match center_mean {
6305            Some(mean) => mean.clone(),
6306            None => pca_center_mean(x.view())?,
6307        }
6308    } else {
6309        Array1::<f64>::zeros(feature_cols.len())
6310    };
6311    if centered {
6312        for mut row in x.rows_mut() {
6313            row -= &mean;
6314        }
6315    }
6316    let design = fast_ab(&x, basis_matrix);
6317    let k = basis_matrix.ncols();
6318    let mut penalty = Array2::<f64>::eye(k);
6319    penalty.mapv_inplace(|v| v * smooth_penalty);
6320    let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
6321        filter_active_penalty_candidates_with_ops(vec![PenaltyCandidate {
6322            matrix: penalty,
6323            nullspace_dim_hint: 0,
6324            source: PenaltySource::Other("PcaRidge".to_string()),
6325            normalization_scale: 1.0,
6326            kronecker_factors: None,
6327            op: None,
6328        }])?;
6329    Ok(BasisBuildResult {
6330        design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(design)),
6331        penalties,
6332        nullspace_dims,
6333        penaltyinfo,
6334        ops,
6335        null_eigenvectors,
6336        joint_null_rotation: None,
6337        metadata: BasisMetadata::Pca {
6338            feature_cols: feature_cols.to_vec(),
6339            basis_matrix: basis_matrix.clone(),
6340            centered,
6341            smooth_penalty,
6342            center_mean: centered.then_some(mean),
6343            pca_basis_path: None,
6344            chunk_size: chunk_size.max(1),
6345        },
6346        kronecker_factored: None,
6347    })
6348}
6349
6350/// A factor-level `by=` wrapper owns the model-space centering of its inner
6351/// smooth: it gates the raw/structurally-constrained basis to the level rows
6352/// and then centers that gated block exactly once against the level indicator
6353/// (`build_parametric_constraint_block_for_term` in `design_construction`).
6354/// Leaving the inner B-spline's default pooled weighted-sum-to-zero active here
6355/// would impose two generically-independent constraints — the pooled column
6356/// moment `m = Σ_h m_h` and the per-level moment `m_g` — so a raw `k`-column
6357/// basis collapses to `k-2` columns per level instead of `k-1`, deleting one
6358/// genuine nonconstant spline direction *before REML runs* (#1427). The group
6359/// main effect carries only the constant, so it cannot restore that direction.
6360///
6361/// Only the *default model-space* centering is deferred. Explicit structural or
6362/// frozen transforms (`RemoveLinearTrend`, `OrthogonalToDesignColumns`,
6363/// `FrozenTransform`, `None`) are user/structural choices and are preserved
6364/// verbatim.
6365pub fn defer_inner_model_centering_to_factor_level_wrapper(basis: &mut SmoothBasisSpec) {
6366    if let SmoothBasisSpec::BSpline1D { spec, .. } = basis
6367        && matches!(
6368            spec.identifiability,
6369            BSplineIdentifiability::WeightedSumToZero { .. }
6370        )
6371    {
6372        spec.identifiability = BSplineIdentifiability::None;
6373    }
6374}
6375
6376pub fn apply_by_variable_to_local_build(
6377    mut built: LocalSmoothTermBuild,
6378    data: ArrayView2<'_, f64>,
6379    by_col: usize,
6380    by: &ByVariableSpec,
6381    term_name: &str,
6382) -> Result<LocalSmoothTermBuild, BasisError> {
6383    if by_col >= data.ncols() {
6384        crate::bail_dim_basis!(
6385            "by-variable smooth term '{term_name}' references column {by_col}, but data has {} columns",
6386            data.ncols()
6387        );
6388    }
6389    let weights = match by {
6390        ByVariableSpec::Numeric => data.column(by_col).to_owned(),
6391        ByVariableSpec::Level { value_bits, .. } => data.column(by_col).mapv(|value| {
6392            if value.to_bits() == *value_bits {
6393                1.0
6394            } else {
6395                0.0
6396            }
6397        }),
6398    };
6399    if weights.iter().any(|value| !value.is_finite()) {
6400        crate::bail_invalid_basis!(
6401            "by-variable smooth term '{term_name}' has non-finite by-column values"
6402        );
6403    }
6404
6405    let mut dense = built
6406        .design
6407        .try_to_dense_by_chunks("by-variable smooth row gating")
6408        .map_err(BasisError::InvalidInput)?;
6409    for (mut row, &weight) in dense.rows_mut().into_iter().zip(weights.iter()) {
6410        row.mapv_inplace(|value| value * weight);
6411    }
6412    built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense));
6413    built.kronecker_factored = None;
6414    Ok(built)
6415}
6416
6417/// Build the local smooth term for a `BySmooth` spec, which unifies numeric-by
6418/// and factor-by modulation into a single `SmoothTermSpec`.
6419///
6420/// For a **numeric** by-variable the inner smooth is built once and every row
6421/// is multiplied by the by-column value (identical to `ByVariable::Numeric`).
6422///
6423/// For a **factor** by-variable the inner smooth is built once and gated per
6424/// level into side-by-side column blocks, producing a `n × (L * p)` design
6425/// matrix.  The penalties are block-diagonalised (one copy of the inner penalty
6426/// per level) exactly as `build_factor_smooth` does for `bs="fs"/"sz"`.
6427pub fn build_by_smooth_local(
6428    data: ArrayView2<'_, f64>,
6429    term: &SmoothTermSpec,
6430    smooth: &SmoothBasisSpec,
6431    by_kind: &ByVarKind,
6432    workspace: &mut crate::basis::BasisWorkspace,
6433) -> Result<LocalSmoothTermBuild, BasisError> {
6434    let inner_term = SmoothTermSpec {
6435        name: term.name.clone(),
6436        basis: (*smooth).clone(),
6437        shape: term.shape,
6438        joint_null_rotation: None,
6439    };
6440    let inner = build_single_local_smooth_term(data, &inner_term, workspace)?;
6441
6442    match by_kind {
6443        ByVarKind::Numeric { feature_col } => {
6444            let inner_meta = inner.metadata.clone();
6445            let mut built = apply_by_variable_to_local_build(
6446                inner,
6447                data,
6448                *feature_col,
6449                &ByVariableSpec::Numeric,
6450                &term.name,
6451            )?;
6452            built.metadata = BasisMetadata::BySmooth {
6453                inner: Box::new(inner_meta),
6454                by_col: *feature_col,
6455                levels: None,
6456                ordered: false,
6457            };
6458            Ok(built)
6459        }
6460        ByVarKind::Factor {
6461            feature_col,
6462            frozen_levels,
6463            ordered,
6464        } => {
6465            // Collect factor levels: prefer the frozen set (replay path), else
6466            // scan the data column (first-fit path).
6467            let level_bits: Vec<u64> = if let Some(fl) = frozen_levels {
6468                fl.clone()
6469            } else {
6470                let col = data.column(*feature_col);
6471                let mut seen = BTreeSet::<u64>::new();
6472                for &v in col.iter() {
6473                    if v.is_finite() {
6474                        seen.insert(v.to_bits());
6475                    }
6476                }
6477                seen.into_iter().collect()
6478            };
6479            let n_levels = level_bits.len();
6480            if n_levels == 0 {
6481                crate::bail_invalid_basis!(
6482                    "by-factor smooth term '{}': factor column {} has no observed levels",
6483                    term.name,
6484                    feature_col
6485                );
6486            }
6487            let p = inner.dim;
6488            let q = n_levels * p;
6489            let n = data.nrows();
6490
6491            let inner_dense = inner
6492                .design
6493                .try_to_dense_by_chunks("by-factor smooth design gating")
6494                .map_err(BasisError::InvalidInput)?;
6495
6496            // Gate each level into its own p-wide column block.
6497            let mut combined = Array2::<f64>::zeros((n, q));
6498            for (lvl_idx, &bits) in level_bits.iter().enumerate() {
6499                let col_start = lvl_idx * p;
6500                for row in 0..n {
6501                    if data[[row, *feature_col]].to_bits() == bits {
6502                        combined
6503                            .slice_mut(s![row, col_start..col_start + p])
6504                            .assign(&inner_dense.row(row));
6505                    }
6506                }
6507            }
6508
6509            // Build per-level INDEPENDENT penalties (#1427): one copy of each
6510            // inner penalty per level, but each confined to that single level's
6511            // diagonal block, so every (level, inner-penalty) pair is its OWN
6512            // smoothing-parameter coordinate. `s(x, by=g)` selects the per-group
6513            // curve wiggliness independently — the design is block-diagonal and
6514            // block-separable, so a correct REML must reproduce gamfit's own
6515            // independent per-group fits. Tiling a single inner penalty across
6516            // every level (as the `bs="fs"` shared-λ random-effect construction
6517            // does) collapses all groups onto ONE λ, which cannot match uneven
6518            // per-level smoothness and degrades as data grows (under-recovery up
6519            // to ~16× at n=2000). Emit `n_levels * n_penalties` blocks instead.
6520            let inner_meta = inner.metadata.clone();
6521            let n_penalties = inner.penalties.len();
6522            let n_blocks = n_penalties.saturating_mul(n_levels);
6523            let mut penalties = Vec::<Array2<f64>>::with_capacity(n_blocks);
6524            let mut penaltyinfo = Vec::<PenaltyInfo>::with_capacity(n_blocks);
6525            let mut nullspaces = Vec::<usize>::with_capacity(n_blocks);
6526            for (pen_pos, s_inner) in inner.penalties.iter().enumerate() {
6527                for lvl in 0..n_levels {
6528                    let off = lvl * p;
6529                    let mut s_big = Array2::<f64>::zeros((q, q));
6530                    s_big
6531                        .slice_mut(s![off..off + p, off..off + p])
6532                        .assign(s_inner);
6533                    let (s_big, scale) = normalize_penalty_in_constrained_space(&s_big);
6534                    let mut info = inner.penaltyinfo[pen_pos].clone();
6535                    // Distinct original_index per (penalty, level) so each λ is a
6536                    // separate identifiable coordinate downstream.
6537                    info.original_index = pen_pos * n_levels + lvl;
6538                    info.normalization_scale *= scale;
6539                    // Each block now spans exactly ONE level → per-level nullity,
6540                    // not the tiled (× n_levels) hint of the shared construction.
6541                    info.kronecker_factors = None;
6542                    penalties.push(s_big);
6543                    penaltyinfo.push(info);
6544                    nullspaces.push(inner.nullspaces[pen_pos]);
6545                }
6546            }
6547
6548            let null_eigenvectors = vec![None; penalties.len()];
6549            let ops = vec![None; penalties.len()];
6550
6551            Ok(LocalSmoothTermBuild {
6552                dim: q,
6553                design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(combined)),
6554                penalties,
6555                ops,
6556                nullspaces,
6557                null_eigenvectors,
6558                joint_null_rotation: None,
6559                penaltyinfo,
6560                pre_dropped_penaltyinfo: inner.pre_dropped_penaltyinfo,
6561                metadata: BasisMetadata::BySmooth {
6562                    inner: Box::new(inner_meta),
6563                    by_col: *feature_col,
6564                    levels: Some(level_bits),
6565                    ordered: *ordered,
6566                },
6567                linear_constraints: None,
6568                box_reparam: false,
6569                kronecker_factored: None,
6570            })
6571        }
6572    }
6573}
6574
6575pub fn ensure_by_variable_specs_match(
6576    kind: &BySmoothKind,
6577    by: &ByVariableSpec,
6578    term_name: &str,
6579) -> Result<(), BasisError> {
6580    match (kind, by) {
6581        (BySmoothKind::Numeric, ByVariableSpec::Numeric) => Ok(()),
6582        (BySmoothKind::Level { level_bits }, ByVariableSpec::Level { value_bits, .. })
6583            if level_bits == value_bits =>
6584        {
6585            Ok(())
6586        }
6587        _ => Err(BasisError::InvalidInput(format!(
6588            "by-variable smooth term '{term_name}' has inconsistent by-variable specifications"
6589        ))),
6590    }
6591}
6592
6593/// Build a factor-smooth interaction basis (`bs="fs"`/`"sz"`/`"re"`).
6594///
6595/// A factor smooth replicates a shared marginal smooth in the continuous
6596/// covariate(s) once per level of a grouping factor, coupling all level blocks
6597/// through a *single* set of smoothing parameters (one per marginal penalty).
6598/// This is mgcv's `smooth.construct.fs.smooth.spec` realization and the
6599/// random-effect interpretation of a smooth: the per-level deviations are an
6600/// exchangeable family whose joint wiggliness/shrinkage is governed by the
6601/// shared λ, so the construction scales to many levels with a fixed parameter
6602/// count.
6603///
6604/// Flavours:
6605/// * `Fs` — full random factor-smooth. The marginal carries its wiggliness
6606///   penalty *and* a null-space ridge (double penalty), so the replicated
6607///   design is a proper full-rank random effect: each level's curve is shrunk
6608///   toward zero (intercept + linear trend included), recovering the mgcv
6609///   `bs="fs"` penalty structure `I_L ⊗ S_j` for every marginal penalty `S_j`.
6610/// * `Sz` — sum-to-zero factor smooth. Delegates to the existing
6611///   [`SmoothBasisSpec::FactorSumToZero`] construction (`L-1` deviation blocks,
6612///   coefficient-wise zero sum across levels).
6613/// * `Re` — pure random effect / random slope (`bs="re"`). A degree-1 marginal
6614///   gives the per-level `[1, x]` span; the penalty is the identity over each
6615///   level block (iid Gaussian coefficients), matching mgcv's `bs="re"` ridge.
6616///
6617/// The grouping levels are resolved once at fit time (sorted unique bit
6618/// patterns of the factor column) and frozen into the returned metadata so the
6619/// predict-time rebuild evaluates every row against its own level's block.
6620pub fn build_factor_smooth(
6621    data: ArrayView2<'_, f64>,
6622    spec: &FactorSmoothSpec,
6623    term_name: &str,
6624    workspace: &mut crate::basis::BasisWorkspace,
6625) -> Result<LocalSmoothTermBuild, BasisError> {
6626    if spec.continuous_cols.len() != 1 {
6627        crate::bail_invalid_basis!(
6628            "factor smooth term '{}' currently supports exactly one continuous covariate; found {}",
6629            term_name,
6630            spec.continuous_cols.len()
6631        );
6632    }
6633    let feature_col = spec.continuous_cols[0];
6634    let group_col = spec.group_col;
6635    if feature_col >= data.ncols() || group_col >= data.ncols() {
6636        crate::bail_dim_basis!(
6637            "factor smooth term '{}' references columns ({}, {}) out of bounds for {} columns",
6638            term_name,
6639            feature_col,
6640            group_col,
6641            data.ncols()
6642        );
6643    }
6644
6645    // `Sz` is exactly the existing sum-to-zero factor smooth: reuse it verbatim
6646    // so there is a single source of truth for the zero-sum construction.
6647    if matches!(spec.flavour, FactorSmoothFlavour::Sz) {
6648        let levels = resolve_factor_smooth_levels(data, group_col, spec, term_name)?;
6649        let inner = SmoothBasisSpec::BSpline1D {
6650            feature_col,
6651            spec: factor_smooth_marginal_for_replay(&spec.marginal),
6652        };
6653        let sz_term = SmoothTermSpec {
6654            name: term_name.to_string(),
6655            basis: SmoothBasisSpec::FactorSumToZero {
6656                inner: Box::new(inner),
6657                by_col: group_col,
6658                levels: levels.clone(),
6659                frozen_global_orthogonality: None,
6660            },
6661            shape: ShapeConstraint::None,
6662            joint_null_rotation: None,
6663        };
6664        let mut built = build_single_local_smooth_term(data, &sz_term, workspace)?;
6665        // The delegated `FactorSumToZero` build returns the BARE inner B-spline
6666        // metadata (`BasisMetadata::BSpline1D`), but the term that owns this
6667        // build carries a `SmoothBasisSpec::FactorSmooth { Sz }` spec. Two
6668        // things break if we hand that mismatched pair downstream:
6669        //   1. `freeze_smooth_basis_from_metadata` matches on (spec, metadata)
6670        //      and has no `(FactorSmooth, BSpline1D)` arm, so any refit / spatial
6671        //      re-optimization that freezes the basis aborts with a "smooth
6672        //      metadata/spec type mismatch" error.
6673        //   2. The bare B-spline metadata carries no grouping levels, so a
6674        //      predict-time rebuild cannot replay the SAME replicated design.
6675        // Re-wrap the marginal geometry as `FactorSmooth` metadata exactly as
6676        // the Fs/Re path below does, giving all three factor-smooth flavours a
6677        // single, freeze-consistent metadata shape that also pins the levels.
6678        // Since #1605 the sz marginal is ALWAYS the penalized B-spline the `fs`
6679        // sibling uses (a natural cubic regression marginal hard-enforces f''=0
6680        // at the boundary and cannot represent curved deviations — a consistency
6681        // failure). The `CubicRegression1D` arm below is therefore unreachable on
6682        // a freshly-built sz spec; it is retained only as defense / backward
6683        // compatibility for a frozen spec that still carries a cr marginal, so
6684        // the predict-time freeze restores whatever marginal class it finds.
6685        let (knots, degree, periodic, marginal_is_cr) = match &built.metadata {
6686            BasisMetadata::BSpline1D {
6687                knots,
6688                periodic,
6689                degree,
6690                ..
6691            } => (
6692                knots.clone(),
6693                degree.unwrap_or(spec.marginal.degree),
6694                *periodic,
6695                false,
6696            ),
6697            BasisMetadata::CubicRegression1D { knots, .. } => {
6698                (knots.clone(), spec.marginal.degree, None, true)
6699            }
6700            other => {
6701                crate::bail_invalid_basis!(
6702                    "sz factor smooth term '{}' produced an unexpected marginal metadata variant {:?}",
6703                    term_name,
6704                    other
6705                );
6706            }
6707        };
6708        built.metadata = BasisMetadata::FactorSmooth {
6709            continuous_cols: spec.continuous_cols.clone(),
6710            group_col,
6711            knots,
6712            degree,
6713            periodic,
6714            group_levels: levels,
6715            flavour: "sz".to_string(),
6716            marginal_is_cr,
6717        };
6718        return Ok(built);
6719    }
6720
6721    let levels = resolve_factor_smooth_levels(data, group_col, spec, term_name)?;
6722    let n_levels = levels.len();
6723    if n_levels < 2 {
6724        crate::bail_invalid_basis!(
6725            "factor smooth term '{}' requires at least two grouping levels; found {}",
6726            term_name,
6727            n_levels
6728        );
6729    }
6730
6731    // `Fs` (order ≥ 1, the default) is the random-effect flavour: it penalizes
6732    // each null-space dimension of the marginal wiggliness penalty separately
6733    // below (mgcv's `bs="fs"` construction). That replaces the marginal's single
6734    // *combined* double penalty, so disable the latter here to avoid penalizing
6735    // the null space twice (once combined, once per dimension). The explicit
6736    // `m=0` opt-out keeps the legacy combined double penalty and adds no
6737    // per-dimension penalties.
6738    let use_per_dim_null = matches!(
6739        &spec.flavour,
6740        FactorSmoothFlavour::Fs { m_null_penalty_orders }
6741            if m_null_penalty_orders.iter().copied().max().unwrap_or(0) >= 1
6742    );
6743
6744    // Build the shared marginal design + penalties from the 1-D B-spline.
6745    // `Re` forces a degree-1 marginal (linear span) and replaces the marginal
6746    // wiggliness with an identity ridge below; `Fs` keeps the user's marginal
6747    // (cubic by default) and, under the per-dimension null path, gets its null
6748    // space penalized one dimension at a time after replication.
6749    let mut marginal_spec = factor_smooth_marginal_for_replay(&spec.marginal);
6750    if use_per_dim_null {
6751        marginal_spec.double_penalty = false;
6752    }
6753    let inner_term = SmoothTermSpec {
6754        name: format!("{term_name}::marginal"),
6755        basis: SmoothBasisSpec::BSpline1D {
6756            feature_col,
6757            spec: marginal_spec,
6758        },
6759        shape: ShapeConstraint::None,
6760        joint_null_rotation: None,
6761    };
6762    let inner = build_single_local_smooth_term(data, &inner_term, workspace)?;
6763    let mut base = inner
6764        .design
6765        .try_to_dense_by_chunks("factor smooth marginal")
6766        .map_err(BasisError::InvalidInput)?;
6767    if matches!(spec.flavour, FactorSmoothFlavour::Re) {
6768        // `bs="re"` is a parametric random intercept+slope, not a B-spline
6769        // smooth evaluated through clamped knot support.  A degree-1 B-spline
6770        // with no internal knots spans the training rows, but outside the
6771        // boundary knots its basis is not the model matrix for `(1 + x | g)`;
6772        // held-out extrapolation then loses the random slope contribution.
6773        // Build the random-effect marginal directly as `[1, x - c]`, centered
6774        // at the frozen marginal domain, so fit-time and replay-time rows use
6775        // the same well-conditioned parametric columns on and off the training
6776        // interval.
6777        let center = match &inner.metadata {
6778            BasisMetadata::BSpline1D { knots, .. } if !knots.is_empty() => {
6779                0.5 * (knots[0] + knots[knots.len() - 1])
6780            }
6781            _ => 0.0,
6782        };
6783        let mut linear = Array2::<f64>::ones((data.nrows(), 2));
6784        linear
6785            .column_mut(1)
6786            .assign(&data.column(feature_col).mapv(|x| x - center));
6787        base = linear;
6788    }
6789    let n = base.nrows();
6790    let p = base.ncols();
6791    let q = p * n_levels;
6792
6793    // Block-diagonal replicated design: row i contributes its marginal row to
6794    // the column block owned by its grouping level, zeros elsewhere.
6795    let mut dense = Array2::<f64>::zeros((n, q));
6796    for i in 0..n {
6797        let bits = data[[i, group_col]].to_bits();
6798        let level_idx = levels.iter().position(|b| *b == bits).ok_or_else(|| {
6799            BasisError::InvalidInput(format!(
6800                "factor smooth term '{term_name}' saw an unseen grouping level at row {}",
6801                i + 1
6802            ))
6803        })?;
6804        let start = level_idx * p;
6805        dense
6806            .slice_mut(s![i, start..start + p])
6807            .assign(&base.row(i));
6808    }
6809
6810    // Penalties: replicate each marginal penalty into a block-diagonal
6811    // `I_L ⊗ S_j` so every level shares the same smoothing parameter λ_j (one
6812    // λ per marginal penalty), the defining feature of a factor smooth. For
6813    // `Re` the marginal penalty is replaced by one ridge per parametric
6814    // coordinate so intercept and slope variances can be learned separately.
6815    let marginal_penalties: Vec<Array2<f64>> = if matches!(spec.flavour, FactorSmoothFlavour::Re) {
6816        (0..p)
6817            .map(|j| {
6818                let mut s = Array2::<f64>::zeros((p, p));
6819                s[[j, j]] = 1.0;
6820                s
6821            })
6822            .collect()
6823    } else {
6824        inner.penalties.clone()
6825    };
6826    let marginal_penaltyinfo: Vec<PenaltyInfo> = if matches!(spec.flavour, FactorSmoothFlavour::Re)
6827    {
6828        (0..p)
6829            .map(|j| PenaltyInfo {
6830                source: PenaltySource::Primary,
6831                original_index: j,
6832                active: true,
6833                effective_rank: 1,
6834                dropped_reason: None,
6835                nullspace_dim_hint: p.saturating_sub(1),
6836                normalization_scale: 1.0,
6837                kronecker_factors: None,
6838            })
6839            .collect()
6840    } else {
6841        inner.penaltyinfo.clone()
6842    };
6843    if marginal_penalties.len() != marginal_penaltyinfo.len() {
6844        crate::bail_invalid_basis!(
6845            "internal factor-smooth penalty metadata mismatch for term '{}': penalties={}, infos={}",
6846            term_name,
6847            marginal_penalties.len(),
6848            marginal_penaltyinfo.len()
6849        );
6850    }
6851
6852    let mut penalties = Vec::<Array2<f64>>::with_capacity(marginal_penalties.len());
6853    let mut penaltyinfo = Vec::<PenaltyInfo>::with_capacity(marginal_penalties.len());
6854    for (penalty_pos, s_inner) in marginal_penalties.iter().enumerate() {
6855        let mut s_big = Array2::<f64>::zeros((q, q));
6856        for level in 0..n_levels {
6857            let start = level * p;
6858            s_big
6859                .slice_mut(s![start..start + p, start..start + p])
6860                .assign(s_inner);
6861        }
6862        let (s_big, factor_smooth_scale) = normalize_penalty_in_constrained_space(&s_big);
6863        let mut info = marginal_penaltyinfo[penalty_pos].clone();
6864        info.original_index = penalty_pos;
6865        info.normalization_scale *= factor_smooth_scale;
6866        info.nullspace_dim_hint = info.nullspace_dim_hint.saturating_mul(n_levels);
6867        info.kronecker_factors = None;
6868        penalties.push(s_big);
6869        penaltyinfo.push(info);
6870    }
6871
6872    let mut nullspaces: Vec<usize> = if matches!(spec.flavour, FactorSmoothFlavour::Re) {
6873        vec![q.saturating_sub(n_levels); p]
6874    } else {
6875        inner
6876            .nullspaces
6877            .iter()
6878            .map(|ns| ns.saturating_mul(n_levels))
6879            .collect()
6880    };
6881
6882    // `Fs` is the random-effect flavour of a smooth: the per-group curve is an
6883    // exchangeable Gaussian *function*, so EVERY coefficient — including the
6884    // {const, linear} null space of the marginal wiggliness penalty — must be
6885    // shrinkable toward zero under its own shared variance. The wiggliness
6886    // penalty `S_wiggle` shapes curvature but leaves the per-group intercept and
6887    // slope (its null space) completely UNPENALIZED. With the null space free,
6888    // each group fits its own intercept and slope with NO partial pooling, so
6889    // the held-out per-subject forecast inherits the full no-pooling variance
6890    // and curves away from the true per-group line (gam#712 real arm, gam#713;
6891    // gam#903 sleepstudy forecast ran ~74% over the lme4 BLUP bar).
6892    //
6893    // mgcv's `bs="fs"` fixes this by penalizing each null-space dimension
6894    // SEPARATELY (`smooth.construct.fs.smooth.spec` adds one rank-1 penalty per
6895    // null coordinate), each replicated block-diagonally across levels under a
6896    // single shared smoothing parameter — so REML fits a distinct
6897    // random-intercept variance and random-slope variance, the partial pooling
6898    // that makes the forecast track lme4's correlated random-effect BLUP. A
6899    // single *combined* null penalty (one λ for intercept+slope together) cannot
6900    // express the typically very different intercept and slope variances, which
6901    // is the residual forecast gap. We mirror mgcv exactly: for each orthonormal
6902    // null direction `z_k` of the marginal wiggliness penalty add
6903    // `I_L ⊗ (z_k z_kᵀ)` as its own penalty. The marginal's combined double
6904    // penalty was disabled above, so the null space is penalized once, per
6905    // dimension. With linear data REML drives the curvature λ up and degrades
6906    // `fs` to a linear random slope (edf → ≈2/group); with genuine curvature the
6907    // wiggliness λ stays small and the wiggle survives (data-adaptive, not a
6908    // cap). Gated by `m_null_penalty_orders`: order ≥ 1 (default) enables the
6909    // per-dimension null penalties; `m=0` keeps the legacy combined double
6910    // penalty and adds nothing here.
6911    if use_per_dim_null
6912        && let Some(Some(z)) = inner.null_eigenvectors.first()
6913        && z.nrows() == p
6914    {
6915        for k in 0..z.ncols() {
6916            // Rank-1 marginal penalty `z_k z_kᵀ`, replicated block-diagonally
6917            // across levels into `I_L ⊗ (z_k z_kᵀ)`. Its own λ is one shared
6918            // variance for this null component (intercept or slope) across all
6919            // groups — the random-effect structure of mgcv `fs`.
6920            let zk = z.column(k);
6921            let mut p_k = Array2::<f64>::zeros((p, p));
6922            for a in 0..p {
6923                for b in 0..p {
6924                    p_k[[a, b]] = zk[a] * zk[b];
6925                }
6926            }
6927            let mut s_null = Array2::<f64>::zeros((q, q));
6928            for level in 0..n_levels {
6929                let start = level * p;
6930                s_null
6931                    .slice_mut(s![start..start + p, start..start + p])
6932                    .assign(&p_k);
6933            }
6934            let (s_null, null_scale) = normalize_penalty_in_constrained_space(&s_null);
6935            let null_block = crate::basis::analyze_penalty_block_with_op(&s_null, None)?;
6936            if null_block.rank > 0 {
6937                let original_index = penalties.len();
6938                penalties.push(null_block.sym_penalty);
6939                nullspaces.push(null_block.nullity);
6940                penaltyinfo.push(PenaltyInfo {
6941                    source: PenaltySource::Primary,
6942                    original_index,
6943                    active: true,
6944                    effective_rank: null_block.rank,
6945                    dropped_reason: None,
6946                    nullspace_dim_hint: null_block.nullity,
6947                    normalization_scale: null_scale,
6948                    kronecker_factors: None,
6949                });
6950            }
6951        }
6952    }
6953    let null_eigenvectors = crate::basis::recompute_null_eigenvectors(&penalties)?;
6954    let joint_null_rotation = crate::basis::compute_joint_null_rotation(&penalties)?;
6955
6956    // Metadata: carry the marginal knot geometry + frozen levels so prediction
6957    // reconstructs an identical replicated design.
6958    let (knots, degree, periodic) = match &inner.metadata {
6959        BasisMetadata::BSpline1D {
6960            knots,
6961            periodic,
6962            degree,
6963            ..
6964        } => (
6965            knots.clone(),
6966            degree.unwrap_or(spec.marginal.degree),
6967            *periodic,
6968        ),
6969        other => {
6970            crate::bail_invalid_basis!(
6971                "factor smooth term '{}' produced an unexpected marginal metadata variant {:?}",
6972                term_name,
6973                other
6974            );
6975        }
6976    };
6977    let flavour_tag = match &spec.flavour {
6978        FactorSmoothFlavour::Fs { .. } => "fs",
6979        FactorSmoothFlavour::Sz => "sz",
6980        FactorSmoothFlavour::Re => "re",
6981    }
6982    .to_string();
6983    let metadata = BasisMetadata::FactorSmooth {
6984        continuous_cols: spec.continuous_cols.clone(),
6985        group_col,
6986        knots,
6987        degree,
6988        periodic,
6989        group_levels: levels,
6990        flavour: flavour_tag,
6991        // fs/re marginals are always B-spline; the cr marginal is sz-only and
6992        // handled on the dedicated Sz path above.
6993        marginal_is_cr: false,
6994    };
6995
6996    let ops = vec![None; penalties.len()];
6997    Ok(LocalSmoothTermBuild {
6998        dim: q,
6999        design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense)),
7000        penalties,
7001        ops,
7002        nullspaces,
7003        null_eigenvectors,
7004        joint_null_rotation,
7005        penaltyinfo,
7006        pre_dropped_penaltyinfo: Vec::new(),
7007        metadata,
7008        linear_constraints: None,
7009        box_reparam: false,
7010        kronecker_factored: None,
7011    })
7012}
7013
7014/// Resolve the grouping levels for a factor smooth: replay the frozen level
7015/// list when present (predict path), otherwise discover the sorted unique bit
7016/// patterns of the factor column (fit path).
7017pub fn resolve_factor_smooth_levels(
7018    data: ArrayView2<'_, f64>,
7019    group_col: usize,
7020    spec: &FactorSmoothSpec,
7021    term_name: &str,
7022) -> Result<Vec<u64>, BasisError> {
7023    if let Some(frozen) = &spec.group_frozen_levels {
7024        if frozen.is_empty() {
7025            crate::bail_invalid_basis!(
7026                "factor smooth term '{}' has an empty frozen level list",
7027                term_name
7028            );
7029        }
7030        return Ok(frozen.clone());
7031    }
7032    let mut bits: Vec<u64> = data.column(group_col).iter().map(|v| v.to_bits()).collect();
7033    bits.sort_by(|a, b| {
7034        f64::from_bits(*a)
7035            .partial_cmp(&f64::from_bits(*b))
7036            .unwrap_or(std::cmp::Ordering::Equal)
7037    });
7038    bits.dedup();
7039    Ok(bits)
7040}
7041
7042/// Marginal B-spline spec for a factor-smooth block. The marginal always builds
7043/// without an identifiability constraint (the per-level replication, not a
7044/// sum-to-zero side constraint, provides identifiability against the parametric
7045/// block). At predict time the marginal's knot geometry has already been pinned
7046/// into `marginal.knotspec` by the metadata replay, so the spec is used
7047/// verbatim aside from clearing the identifiability transform.
7048pub fn factor_smooth_marginal_for_replay(marginal: &BSplineBasisSpec) -> BSplineBasisSpec {
7049    let mut m = marginal.clone();
7050    m.identifiability = BSplineIdentifiability::None;
7051    m
7052}
7053
7054pub fn build_single_local_smooth_term(
7055    data: ArrayView2<'_, f64>,
7056    term: &SmoothTermSpec,
7057    workspace: &mut crate::basis::BasisWorkspace,
7058) -> Result<LocalSmoothTermBuild, BasisError> {
7059    if term.shape != ShapeConstraint::None && !shape_supports_basis(term) {
7060        crate::bail_invalid_basis!(
7061            "ShapeConstraint::{:?} is unsupported for term '{}'",
7062            term.shape,
7063            term.name
7064        );
7065    }
7066    if let SmoothBasisSpec::ByVariable {
7067        inner,
7068        by_col,
7069        kind,
7070        by,
7071    } = &term.basis
7072    {
7073        ensure_by_variable_specs_match(kind, by, &term.name)?;
7074        let mut inner_basis = (**inner).clone();
7075        // Factor-level `by=` owns model-space centering (it centers the gated
7076        // block against the level indicator downstream). Defer the inner
7077        // basis's default pooled centering so the level block is not
7078        // double-centered down to `k-2` columns (#1427). Numeric-by smooths are
7079        // untouched: they are not row-gated to a level and keep ordinary
7080        // intercept centering.
7081        if matches!(by, ByVariableSpec::Level { .. }) {
7082            defer_inner_model_centering_to_factor_level_wrapper(&mut inner_basis);
7083        }
7084        let inner_term = SmoothTermSpec {
7085            name: term.name.clone(),
7086            basis: inner_basis,
7087            shape: term.shape,
7088            joint_null_rotation: None,
7089        };
7090        let built = build_single_local_smooth_term(data, &inner_term, workspace)?;
7091        return apply_by_variable_to_local_build(built, data, *by_col, by, &term.name);
7092    }
7093
7094    // BySmooth: a `by=` smooth that unifies numeric or factor modulation into a
7095    // single term.  Lower it here so the downstream match does not need an arm.
7096    if let SmoothBasisSpec::BySmooth { smooth, by_kind } = &term.basis {
7097        return build_by_smooth_local(data, term, smooth, by_kind, workspace);
7098    }
7099
7100    let mut shape_axis_col: Option<usize> = None;
7101    let mut built: BasisBuildResult = match &term.basis {
7102        SmoothBasisSpec::FactorSumToZero {
7103            inner,
7104            by_col,
7105            levels,
7106            ..
7107        } => {
7108            if *by_col >= data.ncols() {
7109                crate::bail_dim_basis!(
7110                    "term '{}' by column {} out of bounds for {} columns",
7111                    term.name,
7112                    by_col,
7113                    data.ncols()
7114                );
7115            }
7116            if levels.len() < 2 {
7117                crate::bail_invalid_basis!(
7118                    "sum-to-zero factor smooth term '{}' requires at least two levels",
7119                    term.name
7120                );
7121            }
7122            if term.shape != ShapeConstraint::None {
7123                crate::bail_invalid_basis!(
7124                    "ShapeConstraint::{:?} is unsupported for sum-to-zero factor smooth term '{}'",
7125                    term.shape,
7126                    term.name
7127                );
7128            }
7129            let inner_term = SmoothTermSpec {
7130                name: format!("{}::inner", term.name),
7131                basis: (**inner).clone(),
7132                shape: ShapeConstraint::None,
7133                joint_null_rotation: None,
7134            };
7135            let mut inner_built = build_single_local_smooth_term(data, &inner_term, workspace)?;
7136            // Capture the marginal penalty's null directions BEFORE the penalty
7137            // vector is rebuilt below; the sum-to-zero null-space ridge replicates
7138            // these `z_k` into the contrast space (mgcv `bs="fs"` double-penalty).
7139            let inner_null_eigenvectors = inner_built.null_eigenvectors.clone();
7140            let base = inner_built
7141                .design
7142                .try_to_dense_by_chunks("sum-to-zero factor smooth")
7143                .map_err(BasisError::InvalidInput)?;
7144            let n = base.nrows();
7145            let p = base.ncols();
7146            let l_minus_one = levels.len() - 1;
7147            let mut dense = Array2::<f64>::zeros((n, p * l_minus_one));
7148            for i in 0..n {
7149                let bits = data[[i, *by_col]].to_bits();
7150                let level_idx = levels.iter().position(|b| *b == bits).ok_or_else(|| {
7151                    BasisError::InvalidInput(format!(
7152                        "sum-to-zero factor smooth term '{}' saw an unseen level at row {}",
7153                        term.name,
7154                        i + 1
7155                    ))
7156                })?;
7157                if level_idx < l_minus_one {
7158                    let start = level_idx * p;
7159                    dense
7160                        .slice_mut(s![i, start..start + p])
7161                        .assign(&base.row(i));
7162                } else {
7163                    for level in 0..l_minus_one {
7164                        let start = level * p;
7165                        dense
7166                            .slice_mut(s![i, start..start + p])
7167                            .assign(&base.row(i).mapv(|v| -v));
7168                    }
7169                }
7170            }
7171            let mut penalties = Vec::<Array2<f64>>::with_capacity(inner_built.penalties.len());
7172            let active_penalty_indices = inner_built
7173                .penaltyinfo
7174                .iter()
7175                .enumerate()
7176                .filter_map(|(idx, info)| info.active.then_some(idx))
7177                .collect::<Vec<_>>();
7178            if active_penalty_indices.len() != inner_built.penalties.len() {
7179                crate::bail_invalid_basis!(
7180                    "internal sz penalty metadata mismatch: activeinfos={}, penalties={}",
7181                    active_penalty_indices.len(),
7182                    inner_built.penalties.len()
7183                );
7184            }
7185            // Replicate each marginal penalty into the sum-to-zero contrast
7186            // space. With `L-1` free deviation blocks and the reference level
7187            // `d_L = -Σ_{k<L} d_k`, the marginal penalty summed over ALL `L`
7188            // levels, `Σ_{k=1}^{L} d_kᵀ S d_k`, expands to the `(I + 11ᵀ) ⊗ S`
7189            // contrast form (factor 2 on the diagonal blocks, 1 off-diagonal).
7190            //
7191            // PER-GROUP SMOOTHING PARAMETERS (#1074). mgcv's `bs="sz"` does NOT
7192            // pool that sum under one λ: `smooth.construct.sz` emits ONE penalty
7193            // matrix per factor level (here 6 separate `S`s, each with its own
7194            // smoothing parameter), so REML can shrink a low-amplitude group's
7195            // deviation curve hard while leaving a high-amplitude group nearly
7196            // unpenalized. A single shared wiggliness λ (the old construction)
7197            // forces every group to the SAME curvature budget, so a group whose
7198            // true curve is flat drags curvature into the noise of the busy
7199            // groups and vice-versa — systematic truth-recovery loss even when
7200            // the pooled total edf matches mgcv's (the observed `sz` 1.23× gap).
7201            //
7202            // We mirror mgcv exactly by splitting the per-marginal penalty
7203            // `Σ_{k=1}^{L} d_kᵀ S d_k` back into its `L` independent
7204            // rank-controlled summands BEFORE mapping to the contrast space, each
7205            // carrying its own λ:
7206            //   * level k < L (free block):  `d_kᵀ S d_k` → block-diagonal
7207            //     `(e_k e_kᵀ) ⊗ S`  (only the (k,k) block is `S`).
7208            //   * level L (reference):       `d_Lᵀ S d_L = (Σ_{j<L} d_j)ᵀ S (·)`
7209            //     → the fully-coupled `(11ᵀ) ⊗ S` block.
7210            // Summed at equal λ these `L` blocks recover the old `(I + 11ᵀ) ⊗ S`
7211            // exactly (`Σ_k e_k e_kᵀ = I`), so this is a strict generalization:
7212            // the pooled fit is still reachable, REML only GAINS the freedom to
7213            // spend curvature per group. The zero-sum reparameterization (hence
7214            // the `sz` vs `fs` identifiability) is untouched.
7215            //
7216            // `which_level ∈ 0..=l_minus_one`: `< l_minus_one` selects the single
7217            // free deviation block; `== l_minus_one` selects the reference-level
7218            // coupling block.
7219            let stz_per_group_penalty = |s_inner: &Array2<f64>, which_level: usize| -> Array2<f64> {
7220                let mut s_big = Array2::<f64>::zeros((p * l_minus_one, p * l_minus_one));
7221                if which_level < l_minus_one {
7222                    // (e_k e_kᵀ) ⊗ S: a single diagonal block.
7223                    let k = which_level;
7224                    let mut block = s_big.slice_mut(s![k * p..(k + 1) * p, k * p..(k + 1) * p]);
7225                    block.assign(s_inner);
7226                } else {
7227                    // (11ᵀ) ⊗ S: every block (diagonal and off-diagonal) is S.
7228                    for a in 0..l_minus_one {
7229                        for b in 0..l_minus_one {
7230                            let mut block =
7231                                s_big.slice_mut(s![a * p..(a + 1) * p, b * p..(b + 1) * p]);
7232                            block.assign(s_inner);
7233                        }
7234                    }
7235                }
7236                s_big
7237            };
7238            // One nullspace-dim entry per emitted penalty (must stay parallel to
7239            // `penalties`). Each per-group wiggliness block carries the marginal's
7240            // OWN nullity (a rank-`p` block touching a single level for the free
7241            // blocks; the coupling block is rank-`p` over the diagonal sum), and
7242            // the null ridges below record their own nullity.
7243            let mut nullspaces = Vec::<usize>::with_capacity(penalties.capacity());
7244            for (penalty_pos, s_inner) in inner_built.penalties.iter().enumerate() {
7245                let info_idx = active_penalty_indices[penalty_pos];
7246                let base_info = inner_built.penaltyinfo[info_idx].clone();
7247                let marginal_nullity = inner_built.nullspaces.get(penalty_pos).copied().unwrap_or(0);
7248                // Emit `L` independent per-level blocks for this marginal penalty.
7249                for which_level in 0..=l_minus_one {
7250                    let raw = stz_per_group_penalty(s_inner, which_level);
7251                    let (s_big, group_scale) = normalize_penalty_in_constrained_space(&raw);
7252                    let block = crate::basis::analyze_penalty_block_with_op(&s_big, None)?;
7253                    if block.rank == 0 {
7254                        continue;
7255                    }
7256                    if which_level == 0 {
7257                        // Reuse the marginal's own info slot for the first block so
7258                        // the existing normalization bookkeeping stays attached.
7259                        inner_built.penaltyinfo[info_idx].normalization_scale *= group_scale;
7260                        inner_built.penaltyinfo[info_idx].original_index = penalties.len();
7261                        inner_built.penaltyinfo[info_idx].effective_rank = block.rank;
7262                        inner_built.penaltyinfo[info_idx].nullspace_dim_hint = block.nullity;
7263                    } else {
7264                        let mut info = base_info.clone();
7265                        info.original_index = penalties.len();
7266                        info.normalization_scale = base_info.normalization_scale * group_scale;
7267                        info.effective_rank = block.rank;
7268                        info.nullspace_dim_hint = block.nullity;
7269                        info.kronecker_factors = None;
7270                        inner_built.penaltyinfo.push(info);
7271                    }
7272                    penalties.push(block.sym_penalty);
7273                    // The coupling block (which_level == l_minus_one) spans the
7274                    // marginal range on the diagonal-sum direction; the free
7275                    // blocks touch one level. Both leave the marginal null space
7276                    // unpenalized, recorded here so the null ridges below complete
7277                    // the double penalty.
7278                    nullspaces.push(marginal_nullity);
7279                }
7280            }
7281
7282            // Null-space ridge, mirroring the `bs="fs"` double-penalty
7283            // construction (#1605, same defect class as #700/#712/#713). The
7284            // marginal wiggliness penalty `S` shapes curvature but leaves the
7285            // {const, linear} null space of each deviation curve COMPLETELY
7286            // unpenalized. With that null space free, the single combined
7287            // wiggliness smoothing parameter cannot separate the per-group
7288            // intercept/slope variance from the curvature variance, so REML
7289            // parks the wiggliness `λ` high — over-smoothing (under-fitting) the
7290            // deviation blocks even when the truth lives in their span (the `sz`
7291            // recovery gap vs the `fs` superset). mgcv's `bs="fs"` fixes the
7292            // analogous gap by penalizing each null-space dimension SEPARATELY
7293            // under its own shared variance; we mirror that here while keeping
7294            // the zero-sum reparameterization, so the constraint (and the
7295            // identifiability of `sz` vs `fs`) is preserved. For each orthonormal
7296            // null direction `z_k` of the marginal penalty, add the rank-1
7297            // marginal penalty `z_k z_kᵀ` mapped into the SAME `(I + 11ᵀ)`
7298            // sum-to-zero contrast space, each carrying its own `λ`.
7299            if let Some(Some(z)) = inner_null_eigenvectors.first()
7300                && z.nrows() == p
7301            {
7302                for k in 0..z.ncols() {
7303                    let zk = z.column(k);
7304                    let mut p_k = Array2::<f64>::zeros((p, p));
7305                    for a in 0..p {
7306                        for b in 0..p {
7307                            p_k[[a, b]] = zk[a] * zk[b];
7308                        }
7309                    }
7310                    // Null ridges stay POOLED (the `(I + 11ᵀ) ⊗ z_k z_kᵀ` form):
7311                    // they govern the per-group intercept/slope shrinkage, which
7312                    // mgcv pools under one variance even for `sz`; only the
7313                    // curvature (wiggliness) penalty is split per group above.
7314                    let stz_pooled_null = {
7315                        let mut s_big = Array2::<f64>::zeros((p * l_minus_one, p * l_minus_one));
7316                        for a in 0..l_minus_one {
7317                            for b in 0..l_minus_one {
7318                                let factor = if a == b { 2.0 } else { 1.0 };
7319                                let mut block =
7320                                    s_big.slice_mut(s![a * p..(a + 1) * p, b * p..(b + 1) * p]);
7321                                block.assign(&p_k.mapv(|v| v * factor));
7322                            }
7323                        }
7324                        s_big
7325                    };
7326                    let (s_null, null_scale) =
7327                        normalize_penalty_in_constrained_space(&stz_pooled_null);
7328                    let null_block = crate::basis::analyze_penalty_block_with_op(&s_null, None)?;
7329                    if null_block.rank > 0 {
7330                        let original_index = penalties.len();
7331                        penalties.push(null_block.sym_penalty);
7332                        nullspaces.push(null_block.nullity);
7333                        inner_built.penaltyinfo.push(PenaltyInfo {
7334                            source: PenaltySource::Primary,
7335                            original_index,
7336                            active: true,
7337                            effective_rank: null_block.rank,
7338                            dropped_reason: None,
7339                            nullspace_dim_hint: null_block.nullity,
7340                            normalization_scale: null_scale,
7341                            kronecker_factors: None,
7342                        });
7343                    }
7344                }
7345            }
7346            inner_built.dim = p * l_minus_one;
7347            inner_built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense));
7348            inner_built.penalties = penalties;
7349            inner_built.ops = vec![None; inner_built.penalties.len()];
7350            inner_built.nullspaces = nullspaces;
7351            // Invariant: `null_eigenvectors[k]` must mirror `penalties[k]`'s
7352            // spectral null space. We just rebuilt `inner_built.penalties` from
7353            // Kronecker-like `S_big` blocks, so the previously-plumbed
7354            // `null_eigenvectors` (still parallel to the OLD per-level penalty)
7355            // is stale. Recompute from the rebuilt penalties to restore the
7356            // invariant; ditto for the joint-null absorption rotation.
7357            inner_built.null_eigenvectors =
7358                crate::basis::recompute_null_eigenvectors(&inner_built.penalties)?;
7359            inner_built.joint_null_rotation =
7360                crate::basis::compute_joint_null_rotation(&inner_built.penalties)?;
7361            inner_built.kronecker_factored = None;
7362            return Ok(inner_built);
7363        }
7364        SmoothBasisSpec::BSpline1D { feature_col, spec } => {
7365            if *feature_col >= data.ncols() {
7366                crate::bail_dim_basis!(
7367                    "term '{}' feature column {} out of bounds for {} columns",
7368                    term.name,
7369                    feature_col,
7370                    data.ncols()
7371                );
7372            }
7373            let mut spec_local = spec.clone();
7374            if term.shape != ShapeConstraint::None {
7375                // Shape-constrained B-splines are anchored by construction.
7376                // Sum-to-zero side constraints conflict with monotonic/convex cones.
7377                spec_local.identifiability = BSplineIdentifiability::None;
7378            }
7379            // Endpoint boundary conditions are structural for B-splines: the
7380            // basis builder bakes their homogeneous nullspace transform into
7381            // the design, penalties, and stored raw-basis transform.
7382            build_bspline_basis_1d(data.column(*feature_col), &spec_local)?
7383        }
7384        SmoothBasisSpec::ThinPlate {
7385            feature_cols,
7386            spec,
7387            input_scales,
7388        } => {
7389            if term.shape != ShapeConstraint::None {
7390                if feature_cols.len() != 1 {
7391                    crate::bail_invalid_basis!(
7392                        "ShapeConstraint::{:?} for term '{}' on ThinPlate basis requires exactly 1 feature axis; found {}",
7393                        term.shape,
7394                        term.name,
7395                        feature_cols.len()
7396                    );
7397                }
7398                shape_axis_col = Some(feature_cols[0]);
7399            }
7400            let mut x = select_columns(data, feature_cols)?;
7401            // Auto-standardize multivariate inputs: use stored scales (prediction)
7402            // or compute fresh ones (training). Same standardization-vs-
7403            // length-scale compensation as Matérn / hybrid Duchon: divide
7404            // the user's L by σ_geom so kernel(‖x_std − c_std‖/L_eff)
7405            // matches the original-coord kernel for uniform σ.
7406            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7407                apply_input_standardization(&mut x, s);
7408                (
7409                    Some(s.clone()),
7410                    compensate_length_scale_for_standardization(spec.length_scale, s),
7411                )
7412            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7413                apply_input_standardization(&mut x, &s);
7414                let l_eff = compensate_length_scale_for_standardization(spec.length_scale, &s);
7415                (Some(s), l_eff)
7416            } else {
7417                (None, spec.length_scale)
7418            };
7419            let mut spec_local = spec.clone();
7420            spec_local.length_scale = length_scale_eff;
7421            if matches!(
7422                spec_local.identifiability,
7423                SpatialIdentifiability::OrthogonalToParametric
7424            ) {
7425                spec_local.identifiability = SpatialIdentifiability::None;
7426            }
7427            let mut result = build_thin_plate_basis(x.view(), &spec_local).map_err(|err| {
7428                rewrite_thin_plate_knots_error(err, &term.name, feature_cols.len(), spec)
7429            })?;
7430            // Inject input scales into metadata; also restore the user's
7431            // original length_scale (not the σ_geom-compensated one) so a
7432            // metadata-driven rebuild that re-applies compensation does not
7433            // double-divide. The build may auto-promote to Duchon when
7434            // canonical TPS is infeasible (k < polynomial-nullspace size);
7435            // in that case patch the Duchon metadata variant so predict-time
7436            // round-trips through the same standardized data path.
7437            match &mut result.metadata {
7438                BasisMetadata::ThinPlate {
7439                    input_scales: ms,
7440                    length_scale,
7441                    ..
7442                } => {
7443                    *ms = scales;
7444                    *length_scale = spec.length_scale;
7445                }
7446                BasisMetadata::Duchon {
7447                    input_scales: ms,
7448                    length_scale,
7449                    ..
7450                } => {
7451                    // Auto-promotion (canonical TPS infeasible at this (d, k)).
7452                    // Since #1091 the promotion does NOT forward the incoming
7453                    // σ_geom-compensated `spec_local.length_scale` to the Duchon
7454                    // builder — it DISCARDS it and substitutes the geometric mean
7455                    // of the center pairwise distances (`promotion_length_scale`,
7456                    // the natural radial-kernel scale where κ·r ≈ O(1)). So the
7457                    // realized kernel bandwidth recorded in this metadata bears no
7458                    // fixed relation to the user's `spec.length_scale`; clobbering
7459                    // it to the user-facing value (the pre-#1091 behavior) makes
7460                    // freeze→replay re-derive `compensate(spec.length_scale, σ) ≠
7461                    // promotion_length_scale`, evaluating the kernel at the wrong
7462                    // bandwidth and corrupting the replayed design (#1091 broke the
7463                    // e7ff5ed83 freeze contract for the auto-promoted path).
7464                    //
7465                    // The freeze→replay round trip rebuilds through the Duchon arm,
7466                    // which re-applies σ_geom compensation:
7467                    //   replay_eff = compensate(metadata.length_scale, σ)
7468                    //              = metadata.length_scale / σ_geom.
7469                    // For replay_eff to reproduce the realized `promotion_length_scale`
7470                    // we must store the UN-compensated value `promotion_length_scale
7471                    // · σ_geom`. `compensate(1.0, σ) = 1/σ_geom`, so divide the
7472                    // realized scale by it to multiply back through σ_geom. With no
7473                    // standardization (`scales == None`) replay does not compensate,
7474                    // so the realized value is kept verbatim.
7475                    if let (Some(s), Some(realized)) = (scales.as_ref(), *length_scale) {
7476                        let inv_sigma_geom =
7477                            compensate_length_scale_for_standardization(1.0, s);
7478                        if inv_sigma_geom.is_finite() && inv_sigma_geom > 0.0 {
7479                            *length_scale = Some(realized / inv_sigma_geom);
7480                        }
7481                    }
7482                    *ms = scales;
7483                }
7484                _ => {}
7485            }
7486            result
7487        }
7488        SmoothBasisSpec::Sphere { feature_cols, spec } => {
7489            if term.shape != ShapeConstraint::None {
7490                crate::bail_invalid_basis!(
7491                    "ShapeConstraint::{:?} for term '{}' is not supported on spherical splines",
7492                    term.shape,
7493                    term.name
7494                );
7495            }
7496            let x = select_columns(data, feature_cols)?;
7497            build_spherical_spline_basis(x.view(), spec)?
7498        }
7499        SmoothBasisSpec::ConstantCurvature { feature_cols, spec } => {
7500            if term.shape != ShapeConstraint::None {
7501                crate::bail_invalid_basis!(
7502                    "ShapeConstraint::{:?} for term '{}' is not supported on constant-curvature smooths",
7503                    term.shape,
7504                    term.name
7505                );
7506            }
7507            // Chart coordinates are consumed verbatim: NO auto-standardization.
7508            // Rescaling axes would change the chart gauge `1 + κ‖x‖²` and
7509            // silently redefine which curvature κ refers to (the same point
7510            // cloud at a different chart scale has a different κ̂); the user's
7511            // coordinates ARE the geometry here, exactly as for the sphere
7512            // smooth's (lat, lon).
7513            let x = select_columns(data, feature_cols)?;
7514            build_constant_curvature_basis(x.view(), spec)?
7515        }
7516        SmoothBasisSpec::MeasureJet {
7517            feature_cols,
7518            spec,
7519            input_scales,
7520        } => {
7521            if term.shape != ShapeConstraint::None {
7522                crate::bail_invalid_basis!(
7523                    "ShapeConstraint::{:?} for term '{}' is not supported on measure-jet smooths",
7524                    term.shape,
7525                    term.name
7526                );
7527            }
7528            let mut x = select_columns(data, feature_cols)?;
7529            // Matern-style per-axis standardization; the realized σ vector is
7530            // persisted into the metadata for predict-time replay.
7531            //
7532            // Length-scale round-trip contract (owning statement; the freeze
7533            // and frozen-validation arms reference it): `input_scales: Some`
7534            // marks the REPLAY path — the frozen length_scale is already the
7535            // realized post-standardization value and passes through
7536            // verbatim. Fresh path: an explicit user length_scale is in
7537            // ORIGINAL coordinates and gets the σ_geom compensation; the 0.0
7538            // auto sentinel passes through (auto-derivation runs inside the
7539            // builder, post-standardization).
7540            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7541                apply_input_standardization(&mut x, s);
7542                (Some(s.clone()), spec.length_scale)
7543            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7544                apply_input_standardization(&mut x, &s);
7545                let l_eff = if spec.length_scale > 0.0 {
7546                    compensate_length_scale_for_standardization(spec.length_scale, &s)
7547                } else {
7548                    spec.length_scale
7549                };
7550                (Some(s), l_eff)
7551            } else {
7552                (None, spec.length_scale)
7553            };
7554            let mut spec_local = spec.clone();
7555            spec_local.length_scale = length_scale_eff;
7556            let mut result = build_measure_jet_basis(x.view(), &spec_local)?;
7557            if let BasisMetadata::MeasureJet {
7558                input_scales: ms, ..
7559            } = &mut result.metadata
7560            {
7561                *ms = scales;
7562            }
7563            result
7564        }
7565        SmoothBasisSpec::Matern {
7566            feature_cols,
7567            spec,
7568            input_scales,
7569        } => {
7570            if term.shape != ShapeConstraint::None {
7571                if feature_cols.len() != 1 {
7572                    crate::bail_invalid_basis!(
7573                        "ShapeConstraint::{:?} for term '{}' on Matern basis requires exactly 1 feature axis; found {}",
7574                        term.shape,
7575                        term.name,
7576                        feature_cols.len()
7577                    );
7578                }
7579                shape_axis_col = Some(feature_cols[0]);
7580            }
7581            let mut x = select_columns(data, feature_cols)?;
7582            // Auto-standardization (per-axis division by σ_a) reinterprets
7583            // the user's `length_scale` from original data coordinates
7584            // into post-standardization coordinates: for uniform σ_a = σ,
7585            // `kernel(‖x_std − c_std‖/L)` equals `kernel(‖x − c‖/(σ·L))`,
7586            // so the effective kernel range shrinks by σ. To keep
7587            // `length_scale` consistently expressed in *original* data
7588            // coordinates regardless of axis variances, we standardize
7589            // and divide L by σ_geom = (∏σ_a)^(1/d). For uniform σ this
7590            // recovers the user's kernel exactly; for anisotropic data
7591            // the resulting per-axis effective scales σ_a / σ_geom are
7592            // the standard Mahalanobis preconditioning and preserve the
7593            // geometric-mean kernel range. Storing the σ vector in
7594            // metadata.input_scales makes the same transformation
7595            // replayable at predict time.
7596            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7597                apply_input_standardization(&mut x, s);
7598                (
7599                    Some(s.clone()),
7600                    compensate_length_scale_for_standardization(spec.length_scale, s),
7601                )
7602            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7603                apply_input_standardization(&mut x, &s);
7604                let l_eff = compensate_length_scale_for_standardization(spec.length_scale, &s);
7605                (Some(s), l_eff)
7606            } else {
7607                (None, spec.length_scale)
7608            };
7609            let mut spec_local = spec.clone();
7610            spec_local.length_scale = length_scale_eff;
7611            let mut result = build_matern_basiswithworkspace(x.view(), &spec_local, workspace)?;
7612            if let BasisMetadata::Matern {
7613                input_scales,
7614                length_scale,
7615                ..
7616            } = &mut result.metadata
7617            {
7618                *input_scales = scales;
7619                *length_scale = spec.length_scale;
7620            }
7621            result
7622        }
7623        SmoothBasisSpec::Duchon {
7624            feature_cols,
7625            spec,
7626            input_scales,
7627        } => {
7628            if term.shape != ShapeConstraint::None {
7629                if feature_cols.len() != 1 {
7630                    crate::bail_invalid_basis!(
7631                        "ShapeConstraint::{:?} for term '{}' on Duchon basis requires exactly 1 feature axis; found {}",
7632                        term.shape,
7633                        term.name,
7634                        feature_cols.len()
7635                    );
7636                }
7637                shape_axis_col = Some(feature_cols[0]);
7638            }
7639            let mut x = select_columns(data, feature_cols)?;
7640            // Hybrid Duchon (length_scale=Some) is governed by the same
7641            // standardization-vs-length-scale equivalence as Matérn: the
7642            // user's `length_scale` is interpreted in original data
7643            // coordinates, but auto-standardization (per-axis division by
7644            // σ_a) reinterprets it as σ_geom · L. Pre-multiply by 1/σ_geom
7645            // so kernel(‖x_std − c_std‖/L_eff) reproduces the user's
7646            // original-coord kernel exactly for uniform σ_a, and reduces
7647            // to standard Mahalanobis preconditioning for anisotropic σ.
7648            // Pure Duchon (length_scale=None) is scale-free and needs no
7649            // compensation.
7650            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7651                apply_input_standardization(&mut x, s);
7652                (
7653                    Some(s.clone()),
7654                    compensate_optional_length_scale_for_standardization(spec.length_scale, s),
7655                )
7656            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7657                apply_input_standardization(&mut x, &s);
7658                let l_eff =
7659                    compensate_optional_length_scale_for_standardization(spec.length_scale, &s);
7660                (Some(s), l_eff)
7661            } else {
7662                (None, spec.length_scale)
7663            };
7664            let mut spec_local = spec.clone();
7665            spec_local.length_scale = length_scale_eff;
7666            // The Duchon input axis is standardized in place above (`x → x/σ`,
7667            // scale-only, no centering). A 1-D cyclic boundary `[start, end)`
7668            // declared in ORIGINAL covariate units must move into that same
7669            // standardized frame, or the modular wrap in
7670            // `build_cyclic_duchon_basis_1dwithworkspace` folds the standardized
7671            // coordinate against an original-unit period: the seam never closes
7672            // and the basis silently degrades to non-periodic (#1074:
7673            // `duchon(x, periodic=true)` predictions diverged across the wrap,
7674            // f(0) ≠ f(2π)). Rescale by the same 1/σ applied to the data so
7675            // training and predict share one periodic geometry.
7676            if let (Some(s), crate::basis::OneDimensionalBoundary::Cyclic { start, end }) =
7677                (scales.as_ref(), spec_local.boundary.clone())
7678                && s.len() == 1
7679                && s[0] > 0.0
7680            {
7681                spec_local.boundary = crate::basis::OneDimensionalBoundary::Cyclic {
7682                    start: start / s[0],
7683                    end: end / s[0],
7684                };
7685            }
7686            if matches!(
7687                spec_local.identifiability,
7688                SpatialIdentifiability::OrthogonalToParametric
7689            ) {
7690                spec_local.identifiability = SpatialIdentifiability::None;
7691            }
7692            let mut result = build_duchon_basiswithworkspace(x.view(), &spec_local, workspace)?;
7693            if let BasisMetadata::Duchon {
7694                input_scales,
7695                length_scale,
7696                ..
7697            } = &mut result.metadata
7698            {
7699                *input_scales = scales;
7700                *length_scale = spec.length_scale;
7701            }
7702            result
7703        }
7704        SmoothBasisSpec::Pca {
7705            feature_cols,
7706            basis_matrix,
7707            centered,
7708            smooth_penalty,
7709            center_mean,
7710            pca_basis_path,
7711            chunk_size,
7712        } => {
7713            if term.shape != ShapeConstraint::None {
7714                crate::bail_invalid_basis!(
7715                    "ShapeConstraint::{:?} for term '{}' is not supported on Pca basis",
7716                    term.shape,
7717                    term.name
7718                );
7719            }
7720            build_pca_smooth_basis(
7721                data,
7722                feature_cols,
7723                basis_matrix,
7724                *centered,
7725                *smooth_penalty,
7726                center_mean.as_ref(),
7727                pca_basis_path.as_ref(),
7728                *chunk_size,
7729            )?
7730        }
7731        SmoothBasisSpec::TensorBSpline { feature_cols, spec } => {
7732            build_tensor_bspline_basis(data, feature_cols, spec)?
7733        }
7734        SmoothBasisSpec::ByVariable { .. } => {
7735            crate::bail_invalid_basis!(
7736                "internal: ByVariable smooths must return before inner basis dispatch"
7737            );
7738        }
7739        SmoothBasisSpec::BySmooth { .. } => {
7740            crate::bail_invalid_basis!("internal: BySmooth smooths must be lowered to ByVariable before inner basis dispatch"
7741                    .to_string(),);
7742        }
7743        SmoothBasisSpec::FactorSmooth { spec } => {
7744            if term.shape != ShapeConstraint::None {
7745                crate::bail_invalid_basis!(
7746                    "ShapeConstraint::{:?} is unsupported for factor smooth term '{}'",
7747                    term.shape,
7748                    term.name
7749                );
7750            }
7751            return build_factor_smooth(data, spec, &term.name, workspace);
7752        }
7753    };
7754
7755    // The Matérn design ALWAYS uses the operator-collocation {mass, tension,
7756    // stiffness} penalty triplet, overriding whatever penalty
7757    // `build_matern_basis_seeded` produced for the `double_penalty` flag.
7758    //
7759    // #1074 investigated swapping this for the genuine RKHS kernel penalty
7760    // `β' K_CC β` (mgcv `bs="gp"` / fields kriging) on the theory that the
7761    // operator triplet under-smooths the rougher half-integer kernels. MSI
7762    // truth-recovery measurement REFUTED that: the kernel penalty did NOT
7763    // improve ν=3/2 recovery (`matern(x,nu=1.5)` RMSE-vs-truth stayed 0.0554)
7764    // and it REGRESSED the high-frequency-init guard — `matern(x,nu≥5/2)` on
7765    // sin(2π·8·x) collapsed (span 0.53, RMSE 0.70) because the single RKHS
7766    // norm over-smooths a high-frequency truth where the Sobolev-order operator
7767    // dials do not. The operator triplet is therefore retained as the Matérn
7768    // penalty, and the κ-optimizer re-key / ψ-derivative paths route through the
7769    // same triplet builder so the block count stays ψ-stable (#1270).
7770    if let SmoothBasisSpec::Matern { .. } = &term.basis {
7771        let (penalties, nullspace_dims, penaltyinfo) =
7772            matern_operator_penalty_triplet_from_metadata(&built.metadata)?;
7773        built.penalties = penalties;
7774        built.nullspace_dims = nullspace_dims;
7775        built.penaltyinfo = penaltyinfo;
7776    }
7777
7778    let p_local = built.design.ncols();
7779    let mut metadata = built.metadata.clone();
7780    // Extract factored Kronecker representation before consuming fields.
7781    // Invalidate it if shape transforms will be applied (they break structure).
7782    let kron_factored = if term.shape == ShapeConstraint::None {
7783        built.kronecker_factored
7784    } else {
7785        None
7786    };
7787    let mut design_t = built.design;
7788    let mut penalties_t: Vec<Array2<f64>> = built.penalties;
7789    // Ops vector parallel to `penalties_t`. Survives unchanged through the
7790    // identity path; nulled element-wise when `T^T S T` reparametrization
7791    // is applied (operator no longer bit-equivalent to the transformed
7792    // matrix); wrapped in `ScaledPenaltyOp` after Frobenius normalization.
7793    let mut ops_t: Vec<Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>> =
7794        built.ops;
7795    if matches!(
7796        spatial_identifiability_policy(term),
7797        Some(SpatialIdentifiability::OrthogonalToParametric)
7798    ) {
7799        metadata = freeze_raw_spatial_metadata(metadata, design_t.ncols());
7800    }
7801
7802    let active_penaltyinfo_t = built
7803        .penaltyinfo
7804        .iter()
7805        .filter(|info| info.active)
7806        .cloned()
7807        .collect::<Vec<_>>();
7808    let pre_dropped_penaltyinfo_t = built
7809        .penaltyinfo
7810        .iter()
7811        .filter(|info| !info.active)
7812        .cloned()
7813        .collect::<Vec<_>>();
7814    let use_box_reparam =
7815        term.shape != ShapeConstraint::None && shape_uses_box_reparameterization(&term.basis);
7816    if let Some((order, sign)) = shape_order_and_sign(term.shape)
7817        && use_box_reparam
7818    {
7819        // Order 1 (monotone): the plain first-difference cone θ_{i+1}−θ_i ≥ 0 is
7820        // the control-polygon monotonicity criterion, which is independent of
7821        // Greville-abscissa spacing (it only fixes the *sign* of consecutive
7822        // control-point gaps), so the integer-difference transform is exact.
7823        //
7824        // Order 2 (convex/concave): the plain second-difference cone is only
7825        // correct for evenly spaced Greville abscissae. gam's B-splines are
7826        // clamped (and may use quantile knots), so the abscissae are not
7827        // uniform and the geometrically-correct cone is the second *divided*
7828        // difference. Build the Greville-scaled transform so γ_{≥2} ≥ 0
7829        // certifies convexity of the function, not of the raw coefficient
7830        // index. Periodic B-splines use uniform interior knots (uniform
7831        // abscissae), where the divided differences coincide with the integer
7832        // differences up to scale, so the plain path stays exact there.
7833        let t = if order == 2 {
7834            let bspline_meta = match &metadata {
7835                BasisMetadata::BSpline1D {
7836                    knots,
7837                    degree,
7838                    periodic,
7839                    ..
7840                } if periodic.is_none() => Some((knots.clone(), degree.unwrap_or(0))),
7841                _ => None,
7842            };
7843            match bspline_meta {
7844                Some((knots, degree)) if degree >= 1 => {
7845                    let greville = crate::basis::compute_greville_abscissae(&knots, degree)?;
7846                    if greville.len() != p_local {
7847                        crate::bail_invalid_basis!(
7848                            "shape-constraint Greville abscissae count {} does not match basis dim {} for term '{}'",
7849                            greville.len(),
7850                            p_local,
7851                            term.name
7852                        );
7853                    }
7854                    convex_divided_difference_transform_matrix(&greville, sign)?
7855                }
7856                _ => cumulative_sum_transform_matrix(p_local, order, sign),
7857            }
7858        } else {
7859            cumulative_sum_transform_matrix(p_local, order, sign)
7860        };
7861        // Coefficient-side transform: wrap the design in an operator that
7862        // applies T on the coefficient side, preserving sparsity/operator
7863        // structure of the inner design.
7864        let inner_dense = match design_t {
7865            DesignMatrix::Dense(d) => d,
7866            DesignMatrix::Sparse(sp) => gam_linalg::matrix::DenseDesignMatrix::from(
7867                sp.try_to_dense_arc("shape-constrained coefficient transform")
7868                    .map_err(BasisError::InvalidInput)?,
7869            ),
7870        };
7871        let coeff_op = gam_linalg::matrix::CoefficientTransformOperator::new(inner_dense, t.clone())
7872            .map_err(|e| BasisError::InvalidInput(format!("CoefficientTransformOperator: {e}")))?;
7873        design_t = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(coeff_op)));
7874        if penalties_t.len() != active_penaltyinfo_t.len() {
7875            crate::bail_invalid_basis!(
7876                "internal box-reparam penalty/info mismatch for term '{}': penalties={}, infos={}",
7877                term.name,
7878                penalties_t.len(),
7879                active_penaltyinfo_t.len()
7880            );
7881        }
7882        // Wiggliness penalties undergo the exact congruence `S → TᵀST` (PSD
7883        // preserving). The double-penalty *nullspace shrinkage* ridge must NOT:
7884        // it is a unit-eigenvalue projector `ZZᵀ` onto null(S_wiggle) in the
7885        // β (B-spline coefficient) coordinates, and the congruence
7886        // `Tᵀ(ZZᵀ)T = (TᵀZ)(TᵀZ)ᵀ` is no longer a projector — its eigenvalues
7887        // blow up by the conditioning of the cumulative-sum `T` (cond(T) grows
7888        // with the basis dim), concentrating an enormous penalty on the leading
7889        // γ₀ "level" coordinate. REML then drives the shared λ to its ceiling
7890        // and the smooth collapses to a flat constant (#509, the over-smoothing
7891        // face). The principled fix keeps mgcv's double-penalty semantics in the
7892        // *reparametrized* space: rebuild the ridge as the unit-eigenvalue
7893        // nullspace projector of the transformed wiggliness penalty `TᵀST`, so
7894        // the double penalty shrinks exactly the unpenalized polynomial
7895        // directions of the γ-space smooth with eigenvalue 1, identical in
7896        // conditioning to the unconstrained fit.
7897        let transformed_wiggliness = penalties_t
7898            .iter()
7899            .zip(active_penaltyinfo_t.iter())
7900            .find(|(_, info)| !matches!(info.source, PenaltySource::DoublePenaltyNullspace))
7901            .map(|(s_local, _)| {
7902                let tt_s = fast_atb(&t, s_local);
7903                fast_ab(&tt_s, &t)
7904            });
7905        let mut rebuilt = Vec::with_capacity(penalties_t.len());
7906        for (s_local, info) in penalties_t.iter().zip(active_penaltyinfo_t.iter()) {
7907            if matches!(info.source, PenaltySource::DoublePenaltyNullspace) {
7908                // #1654: the double-penalty nullspace ridge under the box
7909                // reparameterization.
7910                //
7911                // For the CURVATURE constraints (order == 2, convex/concave) the
7912                // box transform `T` is the Greville-scaled second *divided*
7913                // difference map (`convex_divided_difference_transform_matrix`).
7914                // Rebuilding the ridge from scratch as the orthonormal null-space
7915                // projector of `TᵀST` (the #509-era monotone fix below) yields a
7916                // γ-space ridge `Z_γ Z_γᵀ` whose null subspace is the affine
7917                // (level γ₀ + slope γ₁) face — the SAME subspace targeted in
7918                // β-space, but measured in the γ inner product rather than the
7919                // β one. A reparameterization `β = Tγ` must leave the penalized
7920                // REML fit invariant, which requires every penalty block to
7921                // transform by the SAME congruence `S ↦ TᵀST`; the from-scratch
7922                // projector rebuild is NOT that congruence, so it silently
7923                // re-weights the level/slope shrinkage relative to the wiggliness
7924                // penalty (each block is independently Frobenius-normalized just
7925                // below, decoupling their scales). The distorted REML λ landscape
7926                // then drives the convex/concave smooth into the flat linear
7927                // corner (curvature pinned ≈ 0, EDF ≈ 1.5) for a
7928                // seed/basis-dimension–specific subset of fits, even though an
7929                // unconstrained `s(x)` on the same data recovers the convex truth
7930                // at EDF ≈ 4. Restoring the exact congruence `Tᵀ R_β T` for the
7931                // ridge keeps the box reparameterization a true invertible
7932                // change of coordinates, so the curvature-constrained fit tracks
7933                // the unconstrained smoothing instead of over-smoothing
7934                // (verified: seed-7/k-20 truth-RMSE 0.31 → 0.045).
7935                //
7936                // For MONOTONE (order == 1) `T` is the cumulative-sum transform
7937                // whose conditioning grows fast with the basis dimension; there
7938                // the congruence concentrates an enormous penalty on the leading
7939                // γ₀ level coordinate and over-smooths to a flat constant (#509),
7940                // which the from-scratch unit-eigenvalue projector rebuild was
7941                // introduced to cure. Keep that path for monotone.
7942                if order == 2 {
7943                    let tt_s = fast_atb(&t, s_local);
7944                    rebuilt.push(fast_ab(&tt_s, &t));
7945                } else {
7946                    let s_wiggle_t = transformed_wiggliness.as_ref().ok_or_else(|| {
7947                        BasisError::InvalidInput(format!(
7948                            "box-reparam term '{}' has a double-penalty ridge but no primary wiggliness penalty to derive its nullspace from",
7949                            term.name
7950                        ))
7951                    })?;
7952                    let ridge = crate::basis::build_nullspace_shrinkage_penalty(s_wiggle_t)?
7953                        .map(|shrink| shrink.sym_penalty)
7954                        .unwrap_or_else(|| Array2::<f64>::zeros((p_local, p_local)));
7955                    rebuilt.push(ridge);
7956                }
7957            } else {
7958                let tt_s = fast_atb(&t, s_local);
7959                rebuilt.push(fast_ab(&tt_s, &t));
7960            }
7961        }
7962        penalties_t = rebuilt;
7963        // T^T S T (and the rebuilt γ-space ridge) invalidate op-form
7964        // bit-equivalence; drop ops here.
7965        ops_t = vec![None; penalties_t.len()];
7966    }
7967    if penalties_t.len() != active_penaltyinfo_t.len() {
7968        crate::bail_invalid_basis!(
7969            "internal penalty metadata mismatch for term '{}': active penalties={}, active infos={}",
7970            term.name,
7971            penalties_t.len(),
7972            active_penaltyinfo_t.len()
7973        );
7974    }
7975    if ops_t.len() != penalties_t.len() {
7976        ops_t = vec![None; penalties_t.len()];
7977    }
7978    let penalty_candidates = penalties_t
7979        .into_iter()
7980        .zip(active_penaltyinfo_t.into_iter())
7981        .zip(ops_t.into_iter())
7982        .map(
7983            |((matrix, info), op_in)| -> Result<PenaltyCandidate, BasisError> {
7984                let (matrix, c_new) = normalize_penalty_in_constrained_space(&matrix);
7985                let normalization_scale = info.normalization_scale * c_new;
7986                let op_scale = 1.0 / c_new;
7987                let kronecker_scale = 1.0 / c_new;
7988                // Frobenius rescale: wrap inner op in `ScaledPenaltyOp(1/c_new)`
7989                // so `op.as_dense() == matrix` post-normalization.
7990                let scaled_op = if op_scale > 0.0 && op_scale.is_finite() {
7991                    op_in.map(|op| {
7992                        std::sync::Arc::new(crate::analytic_penalties::ScaledPenaltyOp::new(
7993                            op, op_scale,
7994                        ))
7995                            as std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>
7996                    })
7997                } else {
7998                    None
7999                };
8000                let kronecker_factors = info.kronecker_factors.map(|mut factors| {
8001                    if let Some(first) = factors.first_mut() {
8002                        first.mapv_inplace(|v| v * kronecker_scale);
8003                    }
8004                    factors
8005                });
8006                Ok(PenaltyCandidate {
8007                    nullspace_dim_hint: info.nullspace_dim_hint,
8008                    matrix,
8009                    source: info.source,
8010                    normalization_scale,
8011                    kronecker_factors,
8012                    op: scaled_op,
8013                })
8014            },
8015        )
8016        .collect::<Result<Vec<_>, _>>()?;
8017    let (penalties_t, nullspaces_t, penaltyinfo_t, null_eigenvectors_t, ops_t) =
8018        crate::basis::filter_active_penalty_candidates_with_ops(penalty_candidates)?;
8019    let shape_linear_constraints = if term.shape != ShapeConstraint::None && !use_box_reparam {
8020        let axis = shape_axis_col.ok_or_else(|| {
8021            BasisError::InvalidInput(format!(
8022                "internal shape-constraint axis missing for term '{}'",
8023                term.name
8024            ))
8025        })?;
8026        let (x_shape_eval, design_shape_eval) =
8027            build_shape_constraint_design_1d(data, term, &metadata, axis)?;
8028        build_shape_linear_constraints_1d(
8029            x_shape_eval.view(),
8030            design_shape_eval.view(),
8031            term.shape,
8032        )?
8033    } else {
8034        None
8035    };
8036    let linear_constraints_local = merge_linear_constraints_global(shape_linear_constraints, None);
8037
8038    // Joint-null absorption rotation. Fresh fit specs compute Q from the final
8039    // per-smooth penalty set (after all in-smooth reparameterizations have
8040    // already been applied). Frozen specs already carry the complete realized
8041    // coefficient chart in their `FrozenTransform`; recomputing Q there would
8042    // rotate an already-frozen chart a second time and desynchronize value
8043    // rebuilds from derivative operators.
8044    //
8045    // Kronecker-factored smooths (tensor B-splines under `TensorBSplineIdentifiability::None`)
8046    // carry their joint penalty as `Σ_d S_d` with `S_d = I ⊗ … ⊗ S_d^{1D} ⊗ … ⊗ I`.
8047    // The joint null space is the tensor of marginal nulls and is handled directly
8048    // by the REML runtime's `kronecker_penalty_system` path (see
8049    // `runtime.rs:8334-8344`). Applying a dense (p × p) Q here would densify
8050    // `X_raw = mx ⊗ my` into `X_raw · Q`, destroying the Kronecker product
8051    // structure that the runtime relies on for fast log-det/derivative
8052    // assembly — and the rotation block at the wrapper site also unconditionally
8053    // wipes `kronecker_factored`, leaving the runtime to fall back to the
8054    // dense per-block log-det. Skip the rotation for Kronecker-factored terms
8055    // so the factored representation survives end-to-end.
8056    let joint_null_rotation = match term.joint_null_rotation.clone() {
8057        Some(persisted) => Some(persisted),
8058        None if smooth_has_frozen_identifiability(term) => None,
8059        None if kron_factored.is_some() => None,
8060        None => crate::basis::compute_joint_null_rotation(&penalties_t)?,
8061    };
8062
8063    Ok(LocalSmoothTermBuild {
8064        dim: p_local,
8065        design: design_t,
8066        penalties: penalties_t,
8067        ops: ops_t,
8068        nullspaces: nullspaces_t,
8069        null_eigenvectors: null_eigenvectors_t,
8070        joint_null_rotation,
8071        penaltyinfo: penaltyinfo_t,
8072        pre_dropped_penaltyinfo: pre_dropped_penaltyinfo_t,
8073        metadata,
8074        linear_constraints: linear_constraints_local,
8075        box_reparam: use_box_reparam,
8076        kronecker_factored: kron_factored,
8077    })
8078}
8079
8080pub fn build_smooth_design(
8081    data: ArrayView2<'_, f64>,
8082    terms: &[SmoothTermSpec],
8083) -> Result<RawSmoothDesign, BasisError> {
8084    let mut ws = crate::basis::BasisWorkspace::new();
8085    build_smooth_design_withworkspace(data, terms, &mut ws)
8086}
8087
8088/// Like `build_smooth_design`, but honors the caller workspace policy while
8089/// building each planned smooth term with an independent per-term workspace.
8090///
8091/// Independent workspaces avoid shared mutable distance-cache state during the
8092/// parallel term build; the final design, penalties, and metadata are assembled
8093/// in the original smooth-term order.
8094pub fn build_smooth_design_withworkspace(
8095    data: ArrayView2<'_, f64>,
8096    terms: &[SmoothTermSpec],
8097    workspace: &mut crate::basis::BasisWorkspace,
8098) -> Result<RawSmoothDesign, BasisError> {
8099    validate_smooth_terms_finite_inputs(data, terms)?;
8100    build_smooth_design_withworkspace_unvalidated(data, terms, workspace)
8101}
8102
8103pub fn build_smooth_design_withworkspace_unvalidated(
8104    data: ArrayView2<'_, f64>,
8105    terms: &[SmoothTermSpec],
8106    workspace: &mut crate::basis::BasisWorkspace,
8107) -> Result<RawSmoothDesign, BasisError> {
8108    let mut planned_blocks = plan_joint_spatial_centers_for_term_blocks(data, &[terms.to_vec()])?;
8109    let planned_terms = planned_blocks.pop().ok_or_else(|| {
8110        BasisError::InvalidInput(
8111            "joint spatial center planner returned no smooth blocks".to_string(),
8112        )
8113    })?;
8114    let policy = workspace.policy().clone();
8115    let local_builds: Vec<LocalSmoothTermBuild> = {
8116        use rayon::iter::{IntoParallelIterator, ParallelIterator};
8117        planned_terms
8118            .into_par_iter()
8119            .map(|term| {
8120                let mut term_workspace = crate::basis::BasisWorkspace::with_policy(policy.clone());
8121                build_single_local_smooth_term(data, &term, &mut term_workspace)
8122            })
8123            .collect::<Result<Vec<_>, _>>()?
8124    };
8125
8126    let total_p: usize = local_builds.iter().map(|built| built.dim).sum();
8127
8128    let mut local_designs: Vec<DesignMatrix> = Vec::with_capacity(local_builds.len());
8129    let mut terms_out = Vec::<SmoothTerm>::with_capacity(terms.len());
8130    let mut penalties_global = Vec::<BlockwisePenalty>::new();
8131    let mut nullspace_dims_global = Vec::<usize>::new();
8132    let mut penaltyinfo_global = Vec::<PenaltyBlockInfo>::new();
8133    let mut dropped_penaltyinfo_global = Vec::<DroppedPenaltyBlockInfo>::new();
8134    let mut coefficient_lower_bounds = Array1::<f64>::from_elem(total_p, f64::NEG_INFINITY);
8135    let mut any_bounds = false;
8136    // Each linear-constraint row only touches the current term's column slice.
8137    // Track `(col_start, col_end, local_row_values)` and assemble the final
8138    // dense `Array2` in one pass, avoiding per-row `Array1::zeros(total_p)`
8139    // allocation plus a row-by-row copy at the end.
8140    let mut linear_constraintsrows: Vec<(usize, usize, Array1<f64>)> = Vec::new();
8141    let mut linear_constraints_b: Vec<f64> = Vec::new();
8142
8143    let mut col_start = 0usize;
8144    for (term, mut built) in terms.iter().zip(local_builds.into_iter()) {
8145        let p_local = built.dim;
8146        let col_end = col_start + p_local;
8147        let lb_local = if built.box_reparam {
8148            shape_lower_bounds_local(term.shape, p_local)
8149        } else {
8150            None
8151        };
8152
8153        // Stage-2 joint-null absorption rotation. Fired *before* the
8154        // penalty / design / global aggregation loops below so that every
8155        // subsequent reference to `built.penalties`, `built.design`, and
8156        // `built.ops` sees the post-rotation values.
8157        //
8158        // The math: when the smooth's joint penalty `Σ_k S_k` has a
8159        // non-trivial null space, eigh selects `Q = [U_range | U_null]`
8160        // with null columns at the tail. Setting `β_raw = Q · γ` and
8161        // applying:
8162        //     design        ← X · Q
8163        //     penalties[k]  ← Qᵀ · S_k · Q   (block-diag, zero null tail)
8164        // yields a model whose fitted γ is invariant to the rotation
8165        // (since likelihood depends only on `X · β_raw = X · Q · γ`), but
8166        // whose penalty is full-rank on the range columns. The large-scale
8167        // failing case (cert refusal in the joint-Newton inner solve)
8168        // resolves because `H_pen = H_loglik + S` becomes full rank on
8169        // the smooth's range columns.
8170        //
8171        // Rotation is suppressed when the smooth carries coordinate-wise
8172        // shape constraints (`lb_local` or `built.linear_constraints`):
8173        // those encode a cone in the original coordinate system and a
8174        // general orthogonal rotation breaks the cone geometry. Smooths
8175        // with shape constraints typically have full-rank joint penalty
8176        // (their structural shape comes from the cone, not from null
8177        // directions in the penalty), so suppression is rarely a loss.
8178        //
8179        // `applied_rotation` carries the Q that was applied (or `None`
8180        // if no rotation fired). It is persisted onto `SmoothTerm` below
8181        // so prediction-side `X_new_raw · Q` replay can reproduce the
8182        // exact rotation. Persistence through the saved-model artifact
8183        // is a follow-up — see the doc on `SmoothTerm.joint_null_rotation`.
8184        let applied_rotation: Option<crate::basis::JointNullRotation> = match (
8185            built.joint_null_rotation.take(),
8186            lb_local.is_some(),
8187            built.linear_constraints.is_some(),
8188        ) {
8189            (Some(rot), false, false) => {
8190                let q = &rot.rotation;
8191                let dense = built
8192                    .design
8193                    .try_to_dense_by_chunks("joint-null absorption rotation")
8194                    .map_err(BasisError::InvalidInput)?;
8195                let rotated = gam_linalg::faer_ndarray::fast_ab(&dense, q);
8196                built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(rotated));
8197                built.penalties = built
8198                    .penalties
8199                    .into_iter()
8200                    .map(|s_local| {
8201                        let qt_s = gam_linalg::faer_ndarray::fast_atb(q, &s_local);
8202                        gam_linalg::faer_ndarray::fast_ab(&qt_s, q)
8203                    })
8204                    .collect();
8205                built.ops = vec![None; built.penalties.len()];
8206                built.kronecker_factored = None;
8207                Some(rot)
8208            }
8209            (Some(_), _, _) => None,
8210            (None, _, _) => None,
8211        };
8212
8213        let activeinfos = built
8214            .penaltyinfo
8215            .iter()
8216            .filter(|info| info.active)
8217            .collect::<Vec<_>>();
8218        if activeinfos.len() != built.penalties.len() {
8219            crate::bail_invalid_basis!(
8220                "internal penalty info mismatch for term '{}': activeinfos={}, penalties={}",
8221                term.name,
8222                activeinfos.len(),
8223                built.penalties.len()
8224            );
8225        }
8226        for (((s_local, &ns), info), op_local) in built
8227            .penalties
8228            .iter()
8229            .zip(built.nullspaces.iter())
8230            .zip(activeinfos.into_iter())
8231            .zip(built.ops.iter())
8232        {
8233            let global_index = penalties_global.len();
8234            penalties_global.push(
8235                BlockwisePenalty::new(col_start..col_end, s_local.clone())
8236                    .with_op(op_local.clone()),
8237            );
8238            nullspace_dims_global.push(ns);
8239            let mut penalty = info.clone();
8240            penalty.nullspace_dim_hint = ns;
8241            penaltyinfo_global.push(PenaltyBlockInfo {
8242                global_index,
8243                termname: Some(term.name.clone()),
8244                penalty,
8245            });
8246        }
8247        for info in built.penaltyinfo.iter().filter(|info| !info.active) {
8248            dropped_penaltyinfo_global.push(DroppedPenaltyBlockInfo {
8249                termname: Some(term.name.clone()),
8250                penalty: info.clone(),
8251            });
8252        }
8253        for info in &built.pre_dropped_penaltyinfo {
8254            dropped_penaltyinfo_global.push(DroppedPenaltyBlockInfo {
8255                termname: Some(term.name.clone()),
8256                penalty: info.clone(),
8257            });
8258        }
8259
8260        if let Some(lin_local) = &built.linear_constraints {
8261            for r in 0..lin_local.a.nrows() {
8262                linear_constraintsrows.push((col_start, col_end, lin_local.a.row(r).to_owned()));
8263                linear_constraints_b.push(lin_local.b[r]);
8264            }
8265        }
8266        if let Some(lb_local) = &lb_local {
8267            coefficient_lower_bounds
8268                .slice_mut(s![col_start..col_end])
8269                .assign(lb_local);
8270            any_bounds = true;
8271        }
8272
8273        // Move the per-term design out of `built` rather than cloning it.
8274        local_designs.push(built.design);
8275
8276        terms_out.push(SmoothTerm {
8277            name: term.name.clone(),
8278            coeff_range: col_start..col_end,
8279            shape: term.shape,
8280            penalties_local: built.penalties,
8281            nullspace_dims: built.nullspaces,
8282            penaltyinfo_local: built.penaltyinfo,
8283            metadata: built.metadata,
8284            lower_bounds_local: lb_local,
8285            linear_constraints_local: built.linear_constraints,
8286            kronecker_factored: built.kronecker_factored.take(),
8287            joint_null_rotation: applied_rotation,
8288            unabsorbed_global_orthogonality: None,
8289        });
8290
8291        col_start = col_end;
8292    }
8293
8294    assert_eq!(
8295        penalties_global.len(),
8296        nullspace_dims_global.len(),
8297        "global smooth penalty/nullspace bookkeeping diverged"
8298    );
8299    assert_eq!(
8300        penalties_global.len(),
8301        penaltyinfo_global.len(),
8302        "global smooth penalty metadata bookkeeping diverged"
8303    );
8304
8305    Ok(RawSmoothDesign {
8306        term_designs: local_designs,
8307        penalties: penalties_global,
8308        nullspace_dims: nullspace_dims_global,
8309        penaltyinfo: penaltyinfo_global,
8310        dropped_penaltyinfo: dropped_penaltyinfo_global,
8311        terms: terms_out,
8312        coefficient_lower_bounds: if any_bounds {
8313            Some(coefficient_lower_bounds)
8314        } else {
8315            None
8316        },
8317        linear_constraints: if linear_constraintsrows.is_empty() {
8318            None
8319        } else {
8320            let mut a = Array2::<f64>::zeros((linear_constraintsrows.len(), total_p));
8321            for (i, (cs, ce, values)) in linear_constraintsrows.iter().enumerate() {
8322                a.row_mut(i).slice_mut(s![*cs..*ce]).assign(values);
8323            }
8324            Some(LinearInequalityConstraints {
8325                a,
8326                b: Array1::from_vec(linear_constraints_b),
8327            })
8328        },
8329    })
8330}