Skip to main content

gam_terms/smooth/
term_specs.rs

1use coefficient_transforms::{
2    convex_divided_difference_transform_matrix, cumulative_exp, cumulative_sum_transform_matrix,
3    second_cumulative_exp,
4};
5
6pub use error::SmoothError;
7
8use input_standardization::{
9    apply_input_standardization, compensate_length_scale_for_standardization,
10    compensate_optional_length_scale_for_standardization, compute_spatial_input_scales,
11};
12
13use shape_constraints::{
14    build_shape_constraint_design_1d, build_shape_linear_constraints_1d,
15    merge_linear_constraints_global, shape_lower_bounds_local, shape_order_and_sign,
16    shape_supports_basis, shape_uses_box_reparameterization,
17};
18
19pub fn describe_thin_plate_center_request(strategy: &CenterStrategy) -> String {
20    match strategy {
21        CenterStrategy::Auto(inner) => describe_thin_plate_center_request(inner),
22        CenterStrategy::UserProvided(centers) => format!("{} centers", centers.nrows()),
23        CenterStrategy::EqualMass { num_centers }
24        | CenterStrategy::EqualMassCovarRepresentative { num_centers }
25        | CenterStrategy::FarthestPoint { num_centers }
26        | CenterStrategy::KMeans { num_centers, .. } => format!("{num_centers} centers"),
27        CenterStrategy::UniformGrid { points_per_dim } => {
28            format!("uniform grid with {points_per_dim} points per dimension")
29        }
30    }
31}
32
33pub fn rewrite_thin_plate_knots_error(
34    err: BasisError,
35    termname: &str,
36    feature_count: usize,
37    spec: &ThinPlateBasisSpec,
38) -> BasisError {
39    match err {
40        // Polynomial-nullspace shortfall reported directly by the kernel
41        // builder ("thin-plate spline requires at least N centers to span ...").
42        BasisError::InvalidInput(msg)
43            if msg.contains("thin-plate spline requires at least")
44                && (msg.contains("centers to span") || msg.contains("knots to span")) =>
45        {
46            let min_centers = crate::basis::thin_plate_polynomial_basis_dimension(feature_count);
47            let requested = describe_thin_plate_center_request(&spec.center_strategy);
48            BasisError::InvalidInput(format!(
49                "joint TPS term '{termname}' over {feature_count} covariates with {requested} is invalid; minimum centers is {min_centers}"
50            ))
51        }
52        // Insufficient-rows shortfall raised by `select_thin_plate_knots` when
53        // the requested center count exceeds the available row count. Rewrite
54        // it in term language so the diagnostic points at the smooth term and
55        // the polynomial-nullspace minimum the user needs to satisfy.
56        BasisError::InvalidInput(msg)
57            if msg.starts_with("requested ") && msg.contains(" knots but only ") =>
58        {
59            let min_centers = crate::basis::thin_plate_polynomial_basis_dimension(feature_count);
60            let requested = describe_thin_plate_center_request(&spec.center_strategy);
61            BasisError::InvalidInput(format!(
62                "joint TPS term '{termname}' over {feature_count} covariates with {requested} is invalid; minimum centers is {min_centers}"
63            ))
64        }
65        other => other,
66    }
67}
68
/// Shape restriction that can be attached to a smooth term.
///
/// Each variant has one canonical spelling, returned by
/// [`ShapeConstraint::dsl_str`] and accepted (together with aliases) by
/// [`parse_shape_constraint`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ShapeConstraint {
    /// No shape restriction; canonical spelling `"none"`.
    None,
    /// Canonical spelling `"monotone_increasing"`.
    MonotoneIncreasing,
    /// Canonical spelling `"monotone_decreasing"`.
    MonotoneDecreasing,
    /// Canonical spelling `"convex"`.
    Convex,
    /// Canonical spelling `"concave"`.
    Concave,
}
77
78/// Parse a shape-constraint string into a [`ShapeConstraint`].
79///
80/// This is the single source of truth shared by the formula DSL
81/// (`s(x, shape=...)`) and the `smooths={...}` override path
82/// (`Smooth.shape_constraint`). The accepted spellings cover the canonical
83/// Python `ShapeConstraintLiteral` strings exactly
84/// (`"none"` / `"monotone_increasing"` / `"monotone_decreasing"` /
85/// `"convex"` / `"concave"`) plus a few common aliases. Hyphens and case are
86/// normalized, so `"Monotone-Increasing"` and `"mono_inc"` both resolve to
87/// [`ShapeConstraint::MonotoneIncreasing`].
88pub fn parse_shape_constraint(raw: &str) -> Result<ShapeConstraint, String> {
89    let normalized = raw.trim().to_ascii_lowercase().replace('-', "_");
90    match normalized.as_str() {
91        "" | "none" => Ok(ShapeConstraint::None),
92        "monotone_increasing" | "monotonic_increasing" | "increasing" | "mono_inc" | "mpi" => {
93            Ok(ShapeConstraint::MonotoneIncreasing)
94        }
95        "monotone_decreasing" | "monotonic_decreasing" | "decreasing" | "mono_dec" | "mpd" => {
96            Ok(ShapeConstraint::MonotoneDecreasing)
97        }
98        "convex" | "cvx" => Ok(ShapeConstraint::Convex),
99        "concave" | "ccv" => Ok(ShapeConstraint::Concave),
100        other => Err(format!(
101            "unknown shape constraint {other:?}; expected one of \
102             \"none\", \"monotone_increasing\", \"monotone_decreasing\", \
103             \"convex\", \"concave\""
104        )),
105    }
106}
107
108impl ShapeConstraint {
109    /// Canonical formula-DSL spelling, i.e. the text emitted into
110    /// `s(x, shape=...)`. Round-trips through [`parse_shape_constraint`].
111    pub fn dsl_str(&self) -> &'static str {
112        match self {
113            ShapeConstraint::None => "none",
114            ShapeConstraint::MonotoneIncreasing => "monotone_increasing",
115            ShapeConstraint::MonotoneDecreasing => "monotone_decreasing",
116            ShapeConstraint::Convex => "convex",
117            ShapeConstraint::Concave => "concave",
118        }
119    }
120}
121
/// Function-call heads that the formula DSL treats as smooth terms.
///
/// Any term whose head is one of these keywords may carry a `shape=` option.
pub const SMOOTH_HEAD_KEYWORDS: [&str; 11] = [
    "s", "smooth", "te", "tensor", "thinplate", "tps", "duchon", "matern", "sphere", "bs",
    "bspline",
];
137
138/// Rewrite smooth-term calls in `formula` so each named smooth carries a
139/// `shape=<kind>` option understood by the formula DSL.
140///
141/// `constraints` pairs the smooth-term text as it appears in the formula
142/// (e.g. `"s(x)"` or `"s(x, type=duchon, centers=8)"`) with a shape-constraint
143/// spelling accepted by [`parse_shape_constraint`]; comparison is exact after
144/// whitespace removal. A `"none"` constraint is a no-op. Referencing a term not
145/// present in the formula is an error.
146///
147/// This is the single source of truth for the `gamfit.fit(..., constraints=…)`
148/// rewrite — the Python wrapper only marshals the mapping across the FFI and
149/// holds no formula-parsing or alias-normalization logic of its own.
150pub fn apply_shape_constraints_to_formula(
151    formula: &str,
152    constraints: &[(String, String)],
153) -> Result<String, String> {
154    use std::collections::{BTreeMap, BTreeSet};
155
156    if constraints.is_empty() {
157        return Ok(formula.to_string());
158    }
159    let strip_ws = |s: &str| -> String { s.chars().filter(|c| !c.is_whitespace()).collect() };
160
161    // Whitespace-stripped term text -> canonical shape spelling.
162    let mut wanted: BTreeMap<String, &'static str> = BTreeMap::new();
163    // Whitespace-stripped term text -> original key (for error labels).
164    let mut originals: BTreeMap<String, String> = BTreeMap::new();
165    for (key, kind_raw) in constraints {
166        let kind = parse_shape_constraint(kind_raw)?;
167        let nk = strip_ws(key);
168        originals.entry(nk.clone()).or_insert_with(|| key.clone());
169        if kind != ShapeConstraint::None {
170            wanted.insert(nk, kind.dsl_str());
171        }
172    }
173    if wanted.is_empty() {
174        return Ok(formula.to_string());
175    }
176
177    let chars: Vec<char> = formula.chars().collect();
178    let n = chars.len();
179    let is_ident = |c: char| c.is_ascii_alphanumeric() || c == '_';
180
181    let mut out = String::with_capacity(formula.len() + 32);
182    let mut matched: BTreeSet<String> = BTreeSet::new();
183    let mut i = 0usize;
184    while i < n {
185        // Locate the next smooth-term head (`<keyword> \s* (`) at or after `i`,
186        // respecting word boundaries so `abs(` never matches the `s(` head.
187        let mut head: Option<(usize, usize)> = None; // (head_start, paren_index)
188        let mut p = i;
189        while p < n {
190            let boundary = p == 0 || !is_ident(chars[p - 1]);
191            if boundary {
192                for kw in SMOOTH_HEAD_KEYWORDS.iter() {
193                    let klen = kw.chars().count();
194                    if p + klen > n || chars[p..p + klen].iter().collect::<String>() != **kw {
195                        continue;
196                    }
197                    let mut q = p + klen;
198                    while q < n && chars[q].is_whitespace() {
199                        q += 1;
200                    }
201                    if q < n && chars[q] == '(' {
202                        head = Some((p, q));
203                        break;
204                    }
205                }
206            }
207            if head.is_some() {
208                break;
209            }
210            p += 1;
211        }
212        let (head_start, paren_open) = match head {
213            Some(h) => h,
214            None => {
215                out.extend(chars[i..].iter());
216                break;
217            }
218        };
219        out.extend(chars[i..head_start].iter());
220
221        // Find the matching close paren, honoring nesting and string literals.
222        let body_start = paren_open + 1;
223        let mut depth = 1i32;
224        let mut j = body_start;
225        let mut in_str: Option<char> = None;
226        let mut closed = false;
227        while j < n {
228            let ch = chars[j];
229            if let Some(quote) = in_str {
230                if ch == quote {
231                    in_str = None;
232                }
233            } else if ch == '\'' || ch == '"' {
234                in_str = Some(ch);
235            } else if ch == '(' {
236                depth += 1;
237            } else if ch == ')' {
238                depth -= 1;
239                if depth == 0 {
240                    closed = true;
241                    break;
242                }
243            }
244            j += 1;
245        }
246
247        if !closed {
248            // Unbalanced — emit the remainder verbatim; the DSL parser will
249            // produce the canonical error.
250            out.extend(chars[head_start..].iter());
251            break;
252        }
253
254        let term_text: String = chars[head_start..=j].iter().collect();
255
256        let key_norm = strip_ws(&term_text);
257
258        match wanted.get(&key_norm) {
259            None => out.extend(chars[head_start..=j].iter()),
260            Some(kind) => {
261                let head_paren: String = chars[head_start..body_start].iter().collect();
262                let inside: String = chars[body_start..j].iter().collect();
263                let inside = inside.trim();
264                if inside.is_empty() {
265                    out.push_str(&format!("{head_paren}shape={kind})"));
266                } else {
267                    out.push_str(&format!("{head_paren}{inside}, shape={kind})"));
268                }
269                matched.insert(key_norm);
270            }
271        }
272
273        i = j + 1;
274    }
275
276    let mut missing: Vec<String> = wanted
277        .keys()
278        .filter(|k| !matched.contains(*k))
279        .map(|k| originals.get(k).cloned().unwrap_or_else(|| k.clone()))
280        .collect();
281
282    if !missing.is_empty() {
283        missing.sort();
284        return Err(format!(
285            "shape constraints referenced smooth term(s) not found in formula: {}",
286            missing.join(", ")
287        ));
288    }
289
290    Ok(out)
291}
292
/// Compact discriminator for the kind of `by=` gating stored in
/// `SmoothBasisSpec::ByVariable::kind`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BySmoothKind {
    /// The `by` variable is a numeric column.
    Numeric,
    /// The `by` variable selects a single categorical level, identified by the
    /// stored bit pattern of its encoded value.
    Level { level_bits: u64 },
}
298
/// Serializable description of the basis behind one smooth term.
///
/// Each variant names a basis construction and carries the feature column(s)
/// it reads together with the construction-specific spec. Wrapper variants
/// (`ByVariable`, `FactorSumToZero`, `BySmooth`) box an inner
/// `SmoothBasisSpec`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SmoothBasisSpec {
    /// Row-gated wrapper used for mgcv-style `by=` smooths.
    ///
    /// `ByNumeric` multiplies the inner smooth by a numeric column.
    /// `ByLevel` keeps the inner smooth active only for rows whose encoded
    /// categorical value has the stored bit pattern.  Unordered factor-by
    /// smooths are represented as one independent `ByLevel` term per level.
    ///
    /// `kind` preserves the compact structural discriminator, while `by`
    /// carries the full row-gating spec used to build the local design.
    ByVariable {
        inner: Box<SmoothBasisSpec>,
        by_col: usize,
        kind: BySmoothKind,
        by: ByVariableSpec,
    },
    /// Sum-to-zero factor smooth (`bs="sz"`): with L levels, estimate L-1
    /// deviation coefficient blocks and use the final level as the negative
    /// sum of the others, enforcing coefficient-wise zero sums across levels.
    FactorSumToZero {
        inner: Box<SmoothBasisSpec>,
        by_col: usize,
        levels: Vec<u64>,
        /// Global-orthogonality column map `Z` captured at fit time when this
        /// term overlapped an owner smooth (`s(x) + s(g, x, bs=sz)`, #978):
        /// the hierarchical-ownership pass residualized this term's realized
        /// design as `X ← X·Z`, shrinking its coefficient block. `Z` depends
        /// on the *training-row* owner designs, so prediction cannot rederive
        /// it — it must be persisted and replayed
        /// (`apply_global_smooth_identifiability` consumes it verbatim).
        /// Chart convention: `Z` lives in the post-restack, post-joint-null-Q
        /// coordinates — the raw `sz` rebuild reapplies `Q` deterministically
        /// (#700), then `Z` applies on top. `None` for non-overlapping terms.
        #[serde(default)]
        frozen_global_orthogonality: Option<Array2<f64>>,
    },
    /// B-spline smooth over a single feature column.
    BSpline1D {
        feature_col: usize,
        spec: BSplineBasisSpec,
    },
    /// A smooth modulated by a `by=` variable. Numeric `by` scales one inner
    /// smooth; factor `by` replicates the inner smooth by level.
    BySmooth {
        smooth: Box<SmoothBasisSpec>,
        by_kind: ByVarKind,
    },
    /// Factor-smooth interaction families (`bs="fs"`, `bs="sz"`) and
    /// random slopes (`bs="re"`).
    FactorSmooth { spec: FactorSmoothSpec },
    /// Thin-plate spline smooth over `feature_cols`.
    ThinPlate {
        feature_cols: Vec<usize>,
        spec: ThinPlateBasisSpec,
        /// Per-column standard deviations used to standardize input dimensions
        /// before kernel evaluation when d > 1. `None` means no standardization
        /// (either d == 1 or explicitly disabled).
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
    },
    /// Spherical-spline smooth over `feature_cols`.
    Sphere {
        feature_cols: Vec<usize>,
        spec: SphericalSplineBasisSpec,
    },
    /// Constant-curvature (`M_κ`) geodesic-kernel smooth over κ-stereographic
    /// chart coordinates (#944): one construction interpolating
    /// S^d → ℝ^d → H^d through the spec's fixed κ. The Wahba S² smooth is the
    /// structural template; the geometry comes from
    /// `geometry::constant_curvature::ConstantCurvature`.
    ConstantCurvature {
        feature_cols: Vec<usize>,
        spec: ConstantCurvatureBasisSpec,
    },
    /// Matérn-kernel smooth over `feature_cols`.
    Matern {
        feature_cols: Vec<usize>,
        spec: MaternBasisSpec,
        // NOTE(review): presumably the same per-column standardization scales
        // as `ThinPlate::input_scales` — confirm against the builder.
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
    },
    /// Measure-jet spline smooth: multiscale local-jet-residual energy of the
    /// empirical measure (centers as μ-quadrature, masses as μ-weights — no
    /// graph, mesh, or neighbor set inside the statistical object). The
    /// feature columns are ambient coordinates of data concentrated near an
    /// unknown low-dimensional, possibly stratified set.
    MeasureJet {
        feature_cols: Vec<usize>,
        spec: MeasureJetBasisSpec,
        // NOTE(review): presumably the same per-column standardization scales
        // as `ThinPlate::input_scales` — confirm against the builder.
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
    },
    /// Duchon-spline smooth over `feature_cols`.
    Duchon {
        feature_cols: Vec<usize>,
        spec: DuchonBasisSpec,
        // NOTE(review): presumably the same per-column standardization scales
        // as `ThinPlate::input_scales` — confirm against the builder.
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
    },
    /// Smooth built on a PCA basis matrix over `feature_cols`.
    ///
    /// Fields missing from serialized data fall back to serde defaults:
    /// `smooth_penalty` and `chunk_size` to the named default functions,
    /// `center_mean` and `pca_basis_path` to `None`.
    Pca {
        feature_cols: Vec<usize>,
        basis_matrix: Array2<f64>,
        centered: bool,
        #[serde(default = "default_pca_smooth_penalty")]
        smooth_penalty: f64,
        #[serde(default)]
        center_mean: Option<Array1<f64>>,
        #[serde(default)]
        pca_basis_path: Option<PathBuf>,
        #[serde(default = "default_pca_chunk_size")]
        chunk_size: usize,
    },
    /// Tensor-product smooth built from 1D B-spline marginals.
    ///
    /// This is the `te()`-style construction used when axes have different units/scales
    /// (for example, space x time) and isotropic radial kernels are not appropriate.
    TensorBSpline {
        feature_cols: Vec<usize>,
        spec: TensorBSplineSpec,
    },
}
416
417impl SmoothBasisSpec {
418    /// Conservative lower bound on the number of sample rows needed for this
419    /// smooth basis to have a well-posed REML fit.
420    ///
421    /// Each basis kind answers the question for itself, so the workflow does
422    /// not have to know how many columns a B-spline, tensor product, PCA
423    /// projection, or spatial kernel emits. The contract is a *lower bound*:
424    /// returning too small a number is permitted (the inner solver will catch
425    /// any genuine n-vs-rank failure that slips past); returning too large a
426    /// number is a regression because it rejects legitimate fits.
427    ///
428    /// Rationale: B-spline / tensor / PCA bases have a closed-form column
429    /// count, so we use the exact dimension. Radial bases (TPS, Matern,
430    /// Duchon, Sphere) and factor smooths choose their column count from the
431    /// data (`heuristic_centers`, `unique_count`); we fall back to a small
432    /// constant floor because a fit on fewer than five rows cannot stabilise
433    /// any radial smooth regardless of the configured kernel scale.
434    pub fn min_sample_rows(&self) -> usize {
435        // Floor used for data-driven bases whose column count is not known
436        // from the spec alone. Five rows is the minimum at which the inner
437        // pivot/QR + REML smoothing-parameter search has any chance of being
438        // well-posed for a non-parametric smooth.
439        const RADIAL_FLOOR: usize = 5;
440
441        match self {
442            Self::ByVariable { inner, .. } => inner.min_sample_rows(),
443            Self::FactorSumToZero { inner, levels, .. } => {
444                // L-1 independent deviation blocks each carrying the inner
445                // basis dimension. Skip the levels-multiplier if it doesn't
446                // bring more rows; we want the *lower bound* not the rank.
447                let inner_min = inner.min_sample_rows();
448                let lvls = levels.len().saturating_sub(1).max(1);
449                inner_min.saturating_mul(lvls)
450            }
451            Self::BSpline1D { spec, .. } => bspline_basis_min_rows(spec),
452            Self::BySmooth { smooth, .. } => smooth.min_sample_rows(),
453            Self::FactorSmooth { spec } => {
454                // Replicates the marginal once per level; without a known
455                // level count we conservatively require at least the marginal
456                // basis dimension.
457                bspline_basis_min_rows(&spec.marginal)
458            }
459            Self::ThinPlate { .. }
460            | Self::Sphere { .. }
461            | Self::ConstantCurvature { .. }
462            | Self::Matern { .. }
463            | Self::MeasureJet { .. }
464            | Self::Duchon { .. } => RADIAL_FLOOR,
465            Self::Pca { basis_matrix, .. } => basis_matrix.ncols().max(1),
466            Self::TensorBSpline { spec, .. } => {
467                // A `te(...)` smooth is *penalized*: each margin carries a
468                // difference (wiggliness) penalty and the tensor inherits a
469                // Kronecker-sum penalty `S = Σ_i I ⊗ … ⊗ S_i ⊗ … ⊗ I`. The raw
470                // column count is the *product* of the per-marginal column
471                // counts, but that product is the lower bound for an
472                // *unpenalized* tensor regression — it is the number of rows you
473                // would need to identify every interaction column with no
474                // regularization. The penalty regularizes all of those
475                // interaction directions; only the combined penalty *null space*
476                // (the tensor product of the per-margin polynomial trends, a
477                // handful of columns) must be identified by the data, and the
478                // smoothing-parameter search shrinks the rest. The effective
479                // degrees of freedom of the fitted `te()` are therefore a small
480                // fraction of the column product, which is exactly why mgcv
481                // fits a default `te(x, y)` on a couple hundred rows.
482                //
483                // The honest *penalized* lower bound is the **sum** of the
484                // per-marginal column counts, not their product: a row floor of
485                // `Σ_i k_i` still guarantees enough data to identify each
486                // margin's additive main-effect (the largest sub-block the
487                // penalty cannot shrink to zero), while no longer conflating
488                // unpenalized column-count identifiability with penalized
489                // well-posedness. This accepts moderate-`n` penalized tensors
490                // (e.g. a 20×20 default basis on n=200) yet still rejects a
491                // genuinely undersized fit where `n < Σ_i k_i` and even the
492                // additive part is rank-deficient.
493                //
494                // Binary / low-cardinality margins (#724): gam will accept a
495                // `te(x, badh)` whose `badh ∈ {0, 1}` margin nominally requests
496                // more basis columns than `badh` has unique values, where mgcv
497                // refuses the unpenalized term as ill-posed ("badh has
498                // insufficient unique values to support k knots"). This is
499                // correct-by-design, *not* a degenerate fit: the marginal
500                // wiggliness penalty on the `badh` axis has a null space that is
501                // exactly its identifiable trend (the two cell means of a binary
502                // covariate), and the Kronecker-sum penalty shrinks every tensor
503                // column outside that null space toward zero. The resulting fit
504                // is the well-posed "per-level `x` smooth + binary main effect"
505                // that mgcv reaches only after manually collapsing the basis —
506                // gam reaches it automatically because the penalty, not the raw
507                // column count, sets the effective rank. A genuinely
508                // rank-deficient design (penalty null space wider than the data
509                // can support) is still caught downstream by the inner pivoted
510                // factorization, which owns the exact n-vs-rank decision; this
511                // pre-fit gate only refuses the grossly-undersized formula.
512                let mut total: usize = 0;
513                for marginal in &spec.marginalspecs {
514                    let m = bspline_basis_min_rows(marginal);
515                    total = total.saturating_add(m.max(1));
516                }
517                total.max(RADIAL_FLOOR)
518            }
519        }
520    }
521
522    /// Stable structural discriminant for warm-start cache keying (#869).
523    ///
524    /// Two smooths that produce different bases / penalty structures must map
525    /// to different strings here so they cannot collide on the persistent
526    /// warm-start `cache_key` (which is otherwise blind to topology: it hashes
527    /// only the raw input column count, so e.g. `sphere` vs `torus` vs
528    /// `euclidean` candidates fit on the *same* data would otherwise share one
529    /// key and cross-contaminate each other's β/ρ seed). The string is the
530    /// topology identity, not the fitted coefficients, so same-topology refits
531    /// (the screen→full-refit cascade) still hit the same key and reuse work.
532    pub fn structural_kind(&self) -> &'static str {
533        match self {
534            Self::ByVariable { .. } => "by_variable",
535            Self::FactorSumToZero { .. } => "factor_sum_to_zero",
536            Self::BSpline1D { .. } => "bspline_1d",
537            Self::BySmooth { .. } => "by_smooth",
538            Self::FactorSmooth { .. } => "factor_smooth",
539            Self::ThinPlate { .. } => "thin_plate",
540            Self::Sphere { .. } => "sphere",
541            Self::ConstantCurvature { .. } => "constant_curvature",
542            Self::Matern { .. } => "matern",
543            Self::MeasureJet { .. } => "measurejet",
544            Self::Duchon { .. } => "duchon",
545            Self::Pca { .. } => "pca",
546            Self::TensorBSpline { .. } => "tensor_bspline",
547        }
548    }
549
550    /// True for a tensor-product smooth that is only *marginally* centered
551    /// (`ti(...)`, [`TensorBSplineIdentifiability::MarginalSumToZero`]): its
552    /// per-margin sum-to-zero reparameterization `(B_xZ_x)⊗(B_zZ_z)` has ALREADY
553    /// removed each axis's main effect analytically (mgcv-identical), so its
554    /// main-effect removal is complete and it must take NO additional
555    /// owner-residualization block. Residualizing it a second time against the
556    /// realized main-effect designs is a grid-fragile no-op on an exact tensor
557    /// grid but eats genuine pure-interaction curvature off-grid (#1470).
558    pub fn is_marginally_centered_tensor(&self) -> bool {
559        matches!(
560            self,
561            Self::TensorBSpline { spec, .. }
562                if matches!(spec.identifiability, TensorBSplineIdentifiability::MarginalSumToZero)
563        )
564    }
565
566    /// A sum-to-zero factor smooth (`bs="sz"`) has ALREADY removed the
567    /// cross-group main effect analytically, in coefficient space, via its
568    /// `Σ_g d_g(x) ≡ 0` reparameterization (`L-1` deviation blocks with the
569    /// reference level the negative sum of the others) — exactly mgcv's `sz`
570    /// construction, which is self-identifiable against an overlapping `s(x)`
571    /// with no further constraint. Residualizing it a SECOND time against the
572    /// realized B-spline span of the explicit `s(x)` smooth is redundant in
573    /// exact arithmetic (the common-to-all-groups component is zero by
574    /// construction) and actively HARMFUL on finite data: each deviation block
575    /// is a B-spline in `x` whose realized columns share `s(x)`'s span, so the
576    /// joint residualization collapses the full `L·k`-column deviation design to
577    /// `L·k − rank(s(x))` columns and eats the within-group curvature `s(x)`
578    /// cannot represent. REML then rails the deviation smoothing parameter and
579    /// the factor smooth under-recovers (#1605). This is the exact analogue of
580    /// the marginally-centered tensor (`ti`) exemption (#1470), so such a term
581    /// takes NO owner-residualization block.
582    pub fn is_sum_to_zero_factor_smooth(&self) -> bool {
583        matches!(
584            self,
585            Self::FactorSumToZero { .. }
586                | Self::FactorSmooth {
587                    spec: FactorSmoothSpec {
588                        flavour: FactorSmoothFlavour::Sz,
589                        ..
590                    }
591                }
592        )
593    }
594
595    /// Feature columns this basis consumes, used alongside [`structural_kind`]
596    /// to disambiguate two same-kind smooths on different axes. Wrapper
597    /// variants delegate to their inner basis.
598    pub fn structural_feature_cols(&self) -> Vec<usize> {
599        match self {
600            Self::ByVariable { inner, .. } | Self::FactorSumToZero { inner, .. } => {
601                inner.structural_feature_cols()
602            }
603            Self::BySmooth { smooth, .. } => smooth.structural_feature_cols(),
604            Self::FactorSmooth { .. } => Vec::new(),
605            Self::BSpline1D { feature_col, .. } => vec![*feature_col],
606            Self::ThinPlate { feature_cols, .. }
607            | Self::Sphere { feature_cols, .. }
608            | Self::ConstantCurvature { feature_cols, .. }
609            | Self::Matern { feature_cols, .. }
610            | Self::MeasureJet { feature_cols, .. }
611            | Self::Duchon { feature_cols, .. }
612            | Self::Pca { feature_cols, .. }
613            | Self::TensorBSpline { feature_cols, .. } => feature_cols.clone(),
614        }
615    }
616}
617
618/// Lower bound on the number of sample rows a 1D B-spline smooth needs for a
619/// well-posed *penalized* REML fit. Used as the per-smooth row floor in
620/// [`SmoothBasisSpec::min_sample_rows`].
621///
622/// For a *singly*-penalized smooth the floor is the full column count: the
623/// wiggliness penalty leaves the order-`m` polynomial trend unpenalized, and
624/// gam's original gate conservatively required enough rows for the whole basis.
625/// That conservative floor is kept here unchanged.
626///
627/// A *double*-penalized smooth (mgcv `select=TRUE`) is different: it adds a
628/// second penalty on the wiggliness penalty's null space, so even the
629/// polynomial trend is shrinkable toward zero and *nothing* in the basis
630/// requires unpenalized identification by the data — exactly the reasoning the
631/// `TensorBSpline` arm of [`SmoothBasisSpec::min_sample_rows`] already applies
632/// to a penalized tensor. Its honest floor is therefore a small stabilization
633/// constant, not the column count. This is what lets mgcv (and now gam) fit
634/// several `select=TRUE` smooths on a dataset whose row count is below the
635/// summed basis width (e.g. the n≈30 `wine_gamair` fold, 5 `ps` smooths,
636/// p≈51): the penalties, not the data, set the effective rank. The bounded
637/// outer REML loop still terminates, and the genuine n-vs-rank decision is
638/// owned downstream by the inner pivoted factorization. Without this, gam
639/// rejected the fit outright (or, before the gate existed, the outer REML loop
640/// wandered the flat overparameterized surface until the benchmark wall budget
641/// killed it — #1089).
642pub fn bspline_basis_min_rows(spec: &crate::basis::BSplineBasisSpec) -> usize {
643    use crate::basis::BSplineKnotSpec;
644    let columns = match &spec.knotspec {
645        BSplineKnotSpec::Generate {
646            num_internal_knots, ..
647        } => *num_internal_knots + spec.degree + 1,
648        BSplineKnotSpec::Automatic {
649            num_internal_knots: Some(k),
650            ..
651        } => *k + spec.degree + 1,
652        BSplineKnotSpec::Automatic {
653            num_internal_knots: None,
654            ..
655        } => {
656            // Knot count is data-derived (`default_internal_knot_count_for_data`).
657            // A minimal cubic basis is `degree + 2` columns; below that the
658            // basis cannot represent a non-parametric smooth.
659            spec.degree + 2
660        }
661        BSplineKnotSpec::Provided(knots) => knots.len().saturating_sub(spec.degree + 1).max(1),
662        // cr basis dimension equals the knot count (no degree offset).
663        BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
664        BSplineKnotSpec::PeriodicUniform { num_basis, .. } => *num_basis,
665    };
666    let columns = columns.max(spec.degree + 2);
667
668    if spec.double_penalty {
669        // Fully shrinkable basis: only a small stabilization floor must be
670        // identified by the data, capped by the actual column count.
671        const DOUBLE_PENALTY_FLOOR: usize = 2;
672        DOUBLE_PENALTY_FLOOR.min(columns).max(1)
673    } else {
674        columns
675    }
676}
677
/// Serializable description of a by-variable: either a numeric one or a
/// single factor level.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ByVariableSpec {
    /// Numeric by-variable (carries no further data).
    Numeric,
    /// One level of a by-variable, identified by `value_bits` and a
    /// human-readable `label`.
    /// NOTE(review): `value_bits` is presumably the `f64::to_bits` pattern of
    /// the level value — confirm against the code that constructs this.
    Level { value_bits: u64, label: String },
}
683
684
/// Kind of by-variable attached to a term, together with the data column it
/// is read from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ByVarKind {
    /// Numeric by-variable read from column `feature_col`.
    Numeric {
        feature_col: usize,
    },
    /// Factor by-variable read from column `feature_col`.
    /// NOTE(review): `frozen_levels` presumably holds the fit-time level set
    /// as `f64` bit patterns (`None` before freezing), and `ordered` marks an
    /// ordered factor — confirm against the builder that consumes this.
    Factor {
        feature_col: usize,
        ordered: bool,
        frozen_levels: Option<Vec<u64>>,
    },
}
696
/// Specification of a factor smooth: a B-spline marginal over the continuous
/// column(s) combined with a grouping column, in one of the
/// [`FactorSmoothFlavour`] variants.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FactorSmoothSpec {
    /// Indices of the continuous covariate column(s).
    pub continuous_cols: Vec<usize>,
    /// Index of the grouping (factor) column.
    pub group_col: usize,
    /// B-spline basis specification used for the marginal smooth.
    pub marginal: BSplineBasisSpec,
    /// Which factor-smooth flavour this term uses.
    pub flavour: FactorSmoothFlavour,
    /// NOTE(review): presumably the group levels captured at fit time as
    /// `f64` bit patterns, `None` until frozen — confirm against the builder.
    pub group_frozen_levels: Option<Vec<u64>>,
    /// Fit-time global-orthogonality chart `Z` for this term (`s(x) + fs(x, g)`
    /// overlap residualization, #978), in the post-joint-null-`Q` coordinates
    /// (the raw rebuild recomputes any `Q` itself; `fs` penalties are
    /// typically full-rank so `Q` is absent). Training-row dependent, hence
    /// persisted; replayed verbatim by `apply_global_smooth_identifiability`.
    #[serde(default)]
    pub frozen_global_orthogonality: Option<Array2<f64>>,
}
712
/// Flavour of a factor smooth.
/// NOTE(review): the variant names look like the mgcv basis codes `fs`, `sz`
/// and `re` — confirm the exact semantics against the basis builder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FactorSmoothFlavour {
    /// `fs` flavour, carrying a list of penalty orders.
    /// NOTE(review): presumably the null-space penalty orders of the
    /// marginal — confirm.
    Fs { m_null_penalty_orders: Vec<usize> },
    /// `sz` flavour (no extra parameters).
    Sz,
    /// `re` flavour (no extra parameters).
    Re,
}
719
/// Specification of a tensor-product B-spline smooth built from one B-spline
/// marginal per dimension.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct TensorBSplineSpec {
    /// One B-spline basis specification per tensor margin.
    pub marginalspecs: Vec<BSplineBasisSpec>,
    /// Optional period per margin; defaults to empty when absent from a
    /// serialized payload.
    /// NOTE(review): `Some(p)` presumably marks a periodic margin with period
    /// `p` — confirm against the tensor builder.
    #[serde(default)]
    pub periods: Vec<Option<f64>>,
    /// Whether the double-penalty option is requested for this tensor smooth.
    pub double_penalty: bool,
    /// Identifiability treatment; defaults to
    /// [`TensorBSplineIdentifiability::SumToZero`].
    #[serde(default)]
    pub identifiability: TensorBSplineIdentifiability,
    /// How marginal penalties are combined; defaults to
    /// [`TensorBSplinePenaltyDecomposition::MarginalKroneckerSum`].
    #[serde(default)]
    pub penalty_decomposition: TensorBSplinePenaltyDecomposition,
}
731
/// Identifiability treatment applied to a tensor-product B-spline smooth.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub enum TensorBSplineIdentifiability {
    /// No identifiability constraint.
    None,
    /// Sum-to-zero constraint; this is the default variant.
    /// NOTE(review): presumably applied to the full tensor design rather than
    /// per margin (contrast with `MarginalSumToZero`) — confirm.
    #[default]
    SumToZero,
    /// mgcv `ti(...)` semantics: a *tensor interaction* smooth that excludes the
    /// marginal main effects. A sum-to-zero constraint is applied to **each
    /// marginal basis independently** before forming the tensor product, so the
    /// resulting column space contains no function of a single variable alone —
    /// only the pure interaction survives. The realized identifiability
    /// transform is the Kronecker product `Z = Z₀ ⊗ Z₁ ⊗ … ⊗ Z_{d-1}` of the
    /// per-margin sum-to-zero null-space bases, which is exactly the
    /// reparameterization that turns the full-tensor design into the tensor
    /// product of the centered margins.
    MarginalSumToZero,
    /// A stored transform matrix.
    /// NOTE(review): presumably the fit-time identifiability transform,
    /// replayed verbatim at prediction time — confirm against the freeze path.
    FrozenTransform {
        transform: Array2<f64>,
    },
}
751
/// How the marginal penalties of a tensor-product B-spline smooth are
/// assembled into penalties on the tensor coefficients.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TensorBSplinePenaltyDecomposition {
    /// mgcv `te(...)`: one overlapping Kronecker-product penalty per margin,
    /// `S_j` embedded against identities in the other tensor factors.
    /// This is the default variant.
    #[default]
    MarginalKroneckerSum,
    /// mgcv `t2(...)`: split every marginal coefficient space into penalized
    /// range and penalty-null subspaces, then emit one disjoint tensor-subspace
    /// penalty for every non-empty penalized/null combination.
    Separable,
}
763
/// Serializable specification of one smooth term: its name, basis, shape
/// constraint, and the optional joint-null rotation captured at fit time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmoothTermSpec {
    /// Term name, used in diagnostics and error messages.
    pub name: String,
    /// Basis specification for this smooth.
    pub basis: SmoothBasisSpec,
    /// Shape constraint attached to this smooth.
    pub shape: ShapeConstraint,
    /// Joint-null absorption rotation captured at fit time. `Some(Q)` means
    /// the fitted coefficient vector lives in `γ`-coordinates with
    /// `β_raw = Q · γ`; prediction must rotate the raw-basis design via
    /// `X_new = X_new_raw · Q` to match. `None` means either the smooth had
    /// no joint null space (penalty already full-rank) or rotation was
    /// suppressed (smooth carries shape constraints whose cone geometry
    /// would not survive an arbitrary orthogonal rotation). Persisted so
    /// `save → load → predict` is bit-equivalent to in-memory prediction.
    #[serde(default)]
    pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
}
780
/// A realized smooth term: its coefficient range, local penalties, basis
/// metadata, optional constraints, and the transforms applied at fit time.
#[derive(Debug, Clone)]
pub struct SmoothTerm {
    /// Term name, used in diagnostics and error messages.
    pub name: String,
    /// Coefficient range of this term; its length is the term's local
    /// coefficient count (`p_local`).
    /// NOTE(review): presumably indexes into the smooth-block coefficient
    /// vector — confirm against the design assembler.
    pub coeff_range: Range<usize>,
    /// Shape constraint attached to this smooth.
    pub shape: ShapeConstraint,
    /// Penalty matrices in this term's local coefficient coordinates.
    pub penalties_local: Vec<Array2<f64>>,
    /// Per-penalty null-space dimensions (parallel to the penalties).
    pub nullspace_dims: Vec<usize>,
    /// Per-penalty descriptive metadata.
    pub penaltyinfo_local: Vec<PenaltyInfo>,
    /// Basis metadata recorded when the term was built.
    pub metadata: BasisMetadata,
    /// Optional term-local lower bounds for constrained coefficients.
    /// `-inf` means unconstrained.
    pub lower_bounds_local: Option<Array1<f64>>,
    /// Optional term-local inequality constraints in local coefficient coordinates.
    /// `A_local * beta_local >= b_local`.
    pub linear_constraints_local: Option<LinearInequalityConstraints>,
    /// Optional factored tensor-product representation preserved for operator-backed
    /// assembly in the main design builder.
    pub kronecker_factored: Option<KroneckerFactoredBasis>,
    /// Joint-null absorption rotation. `Some(Q)` records the orthonormal
    /// `(p_local × p_local)` matrix that was applied to this term's design
    /// and per-block penalties at construction time:
    /// `term_design ← X_raw · Q`, `penalties_local[k] ← Qᵀ · S_raw · Q`.
    /// The smooth's coefficient vector therefore lives in the rotated
    /// (`γ`) coordinate system, with `β_raw = Q · γ` recovering the raw
    /// pre-rotation parameterization. `None` means either no joint null
    /// space (penalty already full-rank) or rotation was suppressed —
    /// suppression fires when the smooth carries shape constraints
    /// (lower bounds or local linear inequalities) that would lose their
    /// cone geometry under a general orthogonal rotation.
    ///
    /// Prediction-side replay: callers building a new-data design `X_new_raw`
    /// from the *raw* basis must call [`SmoothTerm::apply_rotation_to_predict`]
    /// (or equivalent) to obtain `X_new = X_new_raw · Q` matching this
    /// term's coefficient system.
    ///
    /// Persistence replay: `freeze_term_collection_from_design` copies this
    /// rotation into `SmoothTermSpec`, which is serialized with fitted-model
    /// payloads and reused by the predict-time basis builder. Saved models
    /// therefore replay the same `X_new_raw · Q` transform as in-memory
    /// prediction.
    pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
    /// Global-orthogonality transform that `apply_global_smooth_identifiability`
    /// applied to this term's design but could NOT embed into `metadata`
    /// (factor-smooth kinds: `sz` metadata is per-marginal, `fs` metadata has
    /// no transform slot — #978). `freeze_term_collection_from_design` copies
    /// it onto the term's basis spec (`frozen_global_orthogonality`) so the
    /// predict-side rebuild replays it instead of emitting the unresidualized
    /// (wider) design that the fitted coefficients no longer match.
    /// Chart convention is per kind: post-`Q` `Z` for `sz` (the raw rebuild
    /// reapplies `Q` itself, #700), full `Q·Z` chart for `fs`.
    pub unabsorbed_global_orthogonality: Option<Array2<f64>>,
}
833
834impl SmoothTerm {
835    /// Apply the joint-null absorption rotation to a raw new-data design
836    /// matrix, returning `X_new_raw · Q` when this term was rotated at
837    /// fit time, or `X_new_raw` unchanged when no rotation was applied.
838    ///
839    /// Callers in the prediction path: after building the smooth's basis
840    /// at new data via the *raw* basis builder (the same builder used at
841    /// fit time, applied to `x_new` instead of the training rows), call
842    /// this method on the resulting matrix before forming `X · β`. The
843    /// fitted `β` lives in `γ`-coordinates if Q was applied; multiplying
844    /// the un-rotated `X_new_raw` by `β` would give a wrong η.
845    ///
846    /// Returns an error if the raw design's column count does not match
847    /// the rotation's `p_local`. The width invariant must hold: the raw
848    /// basis builder MUST emit the same `p_local` columns that the
849    /// fit-time builder did, and the rotation is `(p_local × p_local)`.
850    pub fn apply_rotation_to_predict(
851        &self,
852        x_new_raw: Array2<f64>,
853    ) -> Result<Array2<f64>, BasisError> {
854        let Some(rot) = self.joint_null_rotation.as_ref() else {
855            return Ok(x_new_raw);
856        };
857        let p_local = rot.rotation.nrows();
858        if x_new_raw.ncols() != p_local {
859            crate::bail_dim_basis!(
860                "joint-null rotation replay for term '{}': raw design has {} columns, \
861                 rotation expects {} (the raw basis builder must emit the same column \
862                 count as at fit time)",
863                self.name,
864                x_new_raw.ncols(),
865                p_local,
866            );
867        }
868        Ok(gam_linalg::faer_ndarray::fast_ab(
869            &x_new_raw,
870            &rot.rotation,
871        ))
872    }
873
874    /// Dimension of the **joint** null space of this term's active penalties:
875    /// the coefficient directions penalized by *no* penalty. The smooth-component
876    /// Wald test ([`crate::inference::smooth_test::wood_smooth_test`]) treats this
877    /// many leading coefficients as genuine unpenalized fixed effects and tests
878    /// them at full rank; the remainder is the penalized sub-block tested with a
879    /// rank-`≈EDF` truncated pseudo-inverse.
880    ///
881    /// Because every penalty block `S_k` is positive semi-definite,
882    /// `vᵀ(Σ_k S_k)v = Σ_k vᵀ S_k v = 0` iff `S_k v = 0` for *every* `k`; the
883    /// joint null space is therefore exactly `null(Σ_k S_k)`, of dimension
884    /// `p_local − rank(Σ_k S_k)`. This is the **intersection** of the per-penalty
885    /// null spaces, not their sum.
886    ///
887    /// Summing the per-penalty `nullspace_dims` instead (the historical defect
888    /// behind #1360) *unions* the null spaces and badly over-counts: a
889    /// double-penalty smooth carries a bending penalty (null space = its
890    /// polynomial part) plus a complementary null-space ridge (which penalizes
891    /// exactly that polynomial part), so the two null spaces are disjoint and the
892    /// joint null space is empty — yet the per-penalty dims sum to nearly
893    /// `p_local`. Feeding that inflated count to the Wald test makes it test
894    /// almost the whole shrunk block at full rank, manufacturing overwhelming
895    /// "significance" for a term the fit drove to ~0 EDF.
896    pub fn wald_unpenalized_dim(&self) -> usize {
897        joint_unpenalized_dim(
898            self.coeff_range.len(),
899            &self.penalties_local,
900            &self.nullspace_dims,
901        )
902    }
903}
904
905/// Numeric core of [`SmoothTerm::wald_unpenalized_dim`]: the dimension of the
906/// joint null space `∩_k null(S_k) = null(Σ_k S_k)` of a term's local penalty
907/// blocks, with a conservative fallback when a penalty is not materialized as a
908/// full `p_local × p_local` matrix (e.g. a Kronecker tensor factor).
909pub fn joint_unpenalized_dim(
910    p_local: usize,
911    penalties_local: &[Array2<f64>],
912    nullspace_dims: &[usize],
913) -> usize {
914    use gam_linalg::faer_ndarray::FaerEigh;
915    if p_local == 0 {
916        return 0;
917    }
918    if penalties_local.is_empty() {
919        // No penalty ⇒ a wholly unpenalized (fixed-effect) block.
920        return p_local;
921    }
922    // Sum the penalties that are materialized as full `p_local × p_local`
923    // blocks (the common smooth case). The covariance block the Wald test
924    // slices lives in this same coefficient basis (post joint-null rotation),
925    // so the rank is computed in the right metric.
926    let mut s_total = Array2::<f64>::zeros((p_local, p_local));
927    let mut materialized = 0usize;
928    for s in penalties_local {
929        if s.nrows() == p_local && s.ncols() == p_local {
930            s_total += s;
931            materialized += 1;
932        }
933    }
934    if materialized == penalties_local.len() {
935        let symmetric = {
936            let transpose = s_total.t().to_owned();
937            (&s_total + &transpose) * 0.5
938        };
939        if let Ok((evals, _)) = symmetric.eigh(faer::Side::Lower) {
940            let max_abs = evals.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
941            if max_abs == 0.0 {
942                // All penalties identically zero ⇒ unpenalized block.
943                return p_local;
944            }
945            let tol = max_abs * (p_local as f64) * 1e-12;
946            let rank = evals.iter().filter(|&&v| v > tol).count();
947            return p_local.saturating_sub(rank);
948        }
949    }
950    // Conservative fallback when a penalty is not a materialized full block
951    // (e.g. a Kronecker tensor factor): with ≥2 active penalties the joint
952    // null space is almost always empty (the only over-rejecting direction);
953    // with a single penalty it is exactly that penalty's own null space.
954    if penalties_local.len() >= 2 {
955        0
956    } else {
957        nullspace_dims
958            .iter()
959            .copied()
960            .min()
961            .unwrap_or(0)
962            .min(p_local)
963    }
964}
965
/// Metadata for one penalty block that is part of the assembled design.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PenaltyBlockInfo {
    /// Index of this penalty.
    /// NOTE(review): presumably its position in the global penalty list —
    /// confirm against the assembler.
    pub global_index: usize,
    /// Name of the owning term, when there is one.
    pub termname: Option<String>,
    /// Descriptive metadata of the penalty itself.
    pub penalty: PenaltyInfo,
}
972
/// Metadata for a penalty block that was dropped; unlike
/// [`PenaltyBlockInfo`] it carries no global index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DroppedPenaltyBlockInfo {
    /// Name of the owning term, when there is one.
    pub termname: Option<String>,
    /// Descriptive metadata of the dropped penalty.
    pub penalty: PenaltyInfo,
}
978
/// Assembled smooth block: per-term designs, penalties, penalty metadata,
/// the realized terms, and optional block-level constraints.
#[derive(Debug, Clone)]
pub struct SmoothDesign {
    /// One design matrix per term; their concatenation forms the smooth block.
    pub term_designs: Vec<DesignMatrix>,
    /// Per-term block-local penalties.  Each `col_range` is relative to the
    /// smooth block (i.e. indexing into the concatenation of `term_designs`).
    pub penalties: Vec<BlockwisePenalty>,
    /// Null-space dimensions of the penalties.
    /// NOTE(review): presumably parallel to `penalties` — confirm.
    pub nullspace_dims: Vec<usize>,
    /// Metadata for the penalties that were kept.
    pub penaltyinfo: Vec<PenaltyBlockInfo>,
    /// Metadata for the penalties that were dropped.
    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    /// The realized smooth terms.
    pub terms: Vec<SmoothTerm>,
    /// Optional smooth-block lower bounds in smooth coefficient coordinates.
    /// Length equals `total_smooth_cols()` when present.
    pub coefficient_lower_bounds: Option<Array1<f64>>,
    /// Optional smooth-block inequality constraints:
    /// `A_smooth * beta_smooth >= b`.
    pub linear_constraints: Option<LinearInequalityConstraints>,
}
996
997impl SmoothDesign {
998    pub fn total_smooth_cols(&self) -> usize {
999        self.term_designs.iter().map(DesignMatrix::ncols).sum()
1000    }
1001    pub fn nrows(&self) -> usize {
1002        self.term_designs.first().map_or(0, DesignMatrix::nrows)
1003    }
1004}
1005
/// Smooth block with the same fields as [`SmoothDesign`]; convertible into it
/// field-for-field via `From`.
/// NOTE(review): "raw" presumably means before global identifiability /
/// post-processing — confirm against the builder that produces it.
#[derive(Debug, Clone)]
pub struct RawSmoothDesign {
    /// One design matrix per term; their concatenation forms the smooth block.
    pub term_designs: Vec<DesignMatrix>,
    /// Per-term block-local penalties.  Each `col_range` is relative to the
    /// smooth block (i.e. indexing into the concatenation of `term_designs`).
    pub penalties: Vec<BlockwisePenalty>,
    /// Null-space dimensions of the penalties.
    /// NOTE(review): presumably parallel to `penalties` — confirm.
    pub nullspace_dims: Vec<usize>,
    /// Metadata for the penalties that were kept.
    pub penaltyinfo: Vec<PenaltyBlockInfo>,
    /// Metadata for the penalties that were dropped.
    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    /// The realized smooth terms.
    pub terms: Vec<SmoothTerm>,
    /// Optional lower bounds on the smooth-block coefficients.
    pub coefficient_lower_bounds: Option<Array1<f64>>,
    /// Optional inequality constraints on the smooth-block coefficients.
    pub linear_constraints: Option<LinearInequalityConstraints>,
}
1019
1020impl RawSmoothDesign {
1021    pub fn total_smooth_cols(&self) -> usize {
1022        self.term_designs.iter().map(DesignMatrix::ncols).sum()
1023    }
1024    pub fn nrows(&self) -> usize {
1025        self.term_designs.first().map_or(0, DesignMatrix::nrows)
1026    }
1027}
1028
1029impl From<RawSmoothDesign> for SmoothDesign {
1030    fn from(value: RawSmoothDesign) -> Self {
1031        Self {
1032            term_designs: value.term_designs,
1033            penalties: value.penalties,
1034            nullspace_dims: value.nullspace_dims,
1035            penaltyinfo: value.penaltyinfo,
1036            dropped_penaltyinfo: value.dropped_penaltyinfo,
1037            terms: value.terms,
1038            coefficient_lower_bounds: value.coefficient_lower_bounds,
1039            linear_constraints: value.linear_constraints,
1040        }
1041    }
1042}
1043
/// Prior attached to a bounded linear coefficient
/// (see [`LinearCoefficientGeometry::Bounded`]).
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub enum BoundedCoefficientPriorSpec {
    /// No prior; this is the default variant.
    #[default]
    None,
    /// Uniform prior.
    Uniform,
    /// Beta prior with parameters `a` and `b`. `validate_frozen` rejects
    /// non-finite parameters and parameters below `1.0`.
    Beta {
        a: f64,
        b: f64,
    },
}
1054
/// Geometry of a linear coefficient: free, or confined to an interval.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum LinearCoefficientGeometry {
    /// No bound on the coefficient; this is the default variant.
    #[default]
    Unconstrained,
    /// Coefficient bounded by `min` and `max`, with an optional prior.
    /// `validate_frozen` requires both bounds finite and `min < max`.
    Bounded {
        min: f64,
        max: f64,
        #[serde(default)]
        prior: BoundedCoefficientPriorSpec,
    },
}
1066
/// Serializable specification of one parametric (linear) term, including
/// `:` interactions and factor-level gates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearTermSpec {
    /// Term name, used in diagnostics and error messages.
    pub name: String,
    /// Primary feature column index. For Wilkinson-Rogers `:` interaction
    /// terms (`a:b[:c...]`) this is the first column in `feature_cols`; the
    /// realized design column is the elementwise product across every entry
    /// of `feature_cols`. Plain (non-interaction) linear terms set
    /// `feature_cols == vec![feature_col]`.
    pub feature_col: usize,
    /// Full list of columns whose elementwise product yields this term's
    /// design column. `len() >= 1`; `len() == 1` is a plain linear effect.
    #[serde(default)]
    pub feature_cols: Vec<usize>,
    /// Categorical-level gates for a factor-aware `:` interaction.
    ///
    /// Each `(col, level_bits)` multiplies the realized design column by the
    /// indicator `1[data[row, col].to_bits() == level_bits]`. This is how a
    /// `factor:x` (or `factor:factor`) interaction is expanded: `build_termspec`
    /// emits one `LinearTermSpec` per surviving cell of the categorical
    /// operand(s) (treatment-coded, first level dropped per factor), each
    /// carrying the numeric operands in `feature_cols` and the cell's level
    /// gate(s) here. Empty for a plain numeric `:` interaction or main effect,
    /// in which case the realized column is exactly the numeric product.
    #[serde(default)]
    pub categorical_levels: Vec<(usize, u64)>,
    /// Optional ridge (`S = I`, REML-selected `λ`) on this linear coefficient.
    /// A parametric linear term carries no wiggliness, so it is **unpenalized by
    /// default** — gam reports the MLE, matching mgcv/glm/survreg/VGAM (which
    /// penalize parametric terms only under an explicit `paraPen`). Set `true`
    /// to opt into an explicit shrinkage ridge (a zero-mean Gaussian prior
    /// `β ~ N(0, λ⁻¹)`); doing so adds one outer REML smoothing coordinate.
    #[serde(default = "default_linear_term_double_penalty")]
    pub double_penalty: bool,
    /// Coefficient geometry; defaults to
    /// [`LinearCoefficientGeometry::Unconstrained`].
    #[serde(default)]
    pub coefficient_geometry: LinearCoefficientGeometry,
    /// Optional lower constraint on the coefficient. `validate_frozen`
    /// rejects a non-finite value, and rejects `min > max` when both ends
    /// are set.
    #[serde(default)]
    pub coefficient_min: Option<f64>,
    /// Optional upper constraint on the coefficient. `validate_frozen`
    /// rejects a non-finite value, and rejects `min > max` when both ends
    /// are set.
    #[serde(default)]
    pub coefficient_max: Option<f64>,
}
1107
1108impl LinearTermSpec {
1109    /// Return the effective list of feature columns. Backfills from
1110    /// `feature_col` for legacy specs that predate the multi-column field.
1111    pub fn effective_feature_cols(&self) -> Vec<usize> {
1112        if self.feature_cols.is_empty() {
1113            vec![self.feature_col]
1114        } else {
1115            self.feature_cols.clone()
1116        }
1117    }
1118
1119    /// True when this term is a Wilkinson-Rogers `:` interaction (multi-col).
1120    pub fn is_interaction(&self) -> bool {
1121        self.feature_cols.len() > 1 || !self.categorical_levels.is_empty()
1122    }
1123
1124    /// Realize this linear term's `(n,)` design column from `data`.
1125    ///
1126    /// The column is the elementwise product of every numeric feature column
1127    /// (`effective_feature_cols`) gated by the categorical-level indicators in
1128    /// `categorical_levels`: each `(col, level_bits)` multiplies the running
1129    /// column by `1[data[row, col].to_bits() == level_bits]`. A plain numeric
1130    /// term (no `categorical_levels`) reduces to the bare product, matching the
1131    /// historical behaviour. A pure categorical interaction (empty
1132    /// `feature_cols`, non-empty `categorical_levels`) reduces to the cell
1133    /// indicator. Bounds are validated here; the returned column has length
1134    /// `data.nrows()`.
1135    pub fn realized_design_column(&self, data: ArrayView2<'_, f64>) -> Result<Array1<f64>, String> {
1136        let n = data.nrows();
1137        let p = data.ncols();
1138        let bounds = |col: usize| -> Result<(), String> {
1139            if col >= p {
1140                Err(format!(
1141                    "linear term '{}' feature column {} out of bounds for {} columns",
1142                    self.name, col, p
1143                ))
1144            } else {
1145                Ok(())
1146            }
1147        };
1148
1149        // Numeric operands. When `categorical_levels` is set we treat
1150        // `feature_cols` as the (possibly empty) numeric operand list and start
1151        // from a column of ones; otherwise we preserve the legacy backfill from
1152        // `feature_col` so a plain term with no `feature_cols` still resolves.
1153        let mut column = if self.categorical_levels.is_empty() {
1154            let cols = self.effective_feature_cols();
1155            for &c in &cols {
1156                bounds(c)?;
1157            }
1158            let mut acc = data.column(cols[0]).to_owned();
1159            for &c in cols.iter().skip(1) {
1160                acc *= &data.column(c);
1161            }
1162            acc
1163        } else {
1164            let mut acc = Array1::<f64>::ones(n);
1165            for &c in &self.feature_cols {
1166                bounds(c)?;
1167                acc *= &data.column(c);
1168            }
1169            acc
1170        };
1171
1172        for &(col, level_bits) in &self.categorical_levels {
1173            bounds(col)?;
1174            let gate = data.column(col);
1175            for (out, &v) in column.iter_mut().zip(gate.iter()) {
1176                if v.to_bits() != level_bits {
1177                    *out = 0.0;
1178                }
1179            }
1180        }
1181
1182        Ok(column)
1183    }
1184}
1185
/// Serde default for `LinearTermSpec::double_penalty`: `false`.
///
/// Parametric/linear terms are unpenalized by default. A single linear
/// coefficient has no roughness for a smoothing penalty to control, so the
/// historical `S = I`, REML-selected `λ` shrank every linear coefficient off
/// the MLE and injected a spurious outer smoothing coordinate (#749). Mature
/// tools (mgcv/glm/survreg/VGAM) leave parametric terms unpenalized; gam now
/// matches that and reports the MLE. An explicit `double_penalty = true`
/// still opts a term into a ridge.
pub const fn default_linear_term_double_penalty() -> bool {
    const RIDGE_BY_DEFAULT: bool = false;
    RIDGE_BY_DEFAULT
}
1196
/// Default PCA smooth penalty value: `1.0`.
/// NOTE(review): presumably referenced as a serde default by a PCA smooth
/// spec defined elsewhere — confirm against its field attribute.
pub const fn default_pca_smooth_penalty() -> f64 {
    const PCA_SMOOTH_PENALTY: f64 = 1.0;
    PCA_SMOOTH_PENALTY
}
1200
/// Default PCA chunk size: `4096`.
/// NOTE(review): presumably referenced as a serde default by a PCA smooth
/// spec defined elsewhere — confirm against its field attribute.
pub const fn default_pca_chunk_size() -> usize {
    const PCA_CHUNK_SIZE: usize = 4096;
    PCA_CHUNK_SIZE
}
1204
/// Random-effects term specification.
///
/// The selected feature column is interpreted as a categorical grouping variable.
/// The term contributes a one-hot dummy block with an identity penalty on group
/// coefficients, equivalent to i.i.d. Gaussian random effects.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomEffectTermSpec {
    /// Term name, used in diagnostics and error messages.
    pub name: String,
    /// Index of the data column holding the categorical grouping variable.
    pub feature_col: usize,
    /// If true, drop the lexicographically first group level to use treatment coding.
    /// If false, keep all levels (full one-hot block, still identifiable under ridge).
    pub drop_first_level: bool,
    /// If true, add a ridge penalty and estimate this block as a random effect.
    /// If false, leave the one-hot/treatment-coded block unpenalized so it is a
    /// fixed categorical main effect.  The default preserves older saved models.
    #[serde(default = "default_random_effect_penalized")]
    pub penalized: bool,
    /// Optional fixed kept-level set (sorted by f64 bit pattern) captured at fit time.
    /// When present, prediction uses exactly these columns to avoid design drift.
    #[serde(default)]
    pub frozen_levels: Option<Vec<u64>>,
}
1227
/// Serde default for `RandomEffectTermSpec::penalized`: `true`.
///
/// A payload that omits the `penalized` field deserializes as a penalized
/// (ridge, random-effect) block, which preserves the meaning of older saved
/// models.
///
/// Declared `const fn` to match the sibling default providers
/// (`default_linear_term_double_penalty`, `default_pca_smooth_penalty`,
/// `default_pca_chunk_size`), so the value is also usable in const contexts.
/// Existing callers, including the `#[serde(default = "...")]` path, are
/// unaffected.
pub const fn default_random_effect_penalized() -> bool {
    true
}
1231
1232pub fn validate_measure_jet_positive_vec_len(
1233    label: &str,
1234    term_name: &str,
1235    field: &str,
1236    values: &[f64],
1237    expected: usize,
1238) -> Result<(), String> {
1239    if values.len() != expected {
1240        return Err(SmoothError::invalid_config(format!(
1241            "{label} term '{term_name}' frozen MeasureJet {field} has length {}, expected {expected}",
1242            values.len()
1243        ))
1244        .into());
1245    }
1246    if values
1247        .iter()
1248        .any(|value| !(value.is_finite() && *value > 0.0))
1249    {
1250        return Err(SmoothError::invalid_config(format!(
1251            "{label} term '{term_name}' frozen MeasureJet {field} values must be positive and finite"
1252        ))
1253        .into());
1254    }
1255    Ok(())
1256}
1257
/// Serializable collection of every term specification in a model: linear,
/// random-effect, and smooth terms.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TermCollectionSpec {
    /// Parametric (linear) term specifications.
    pub linear_terms: Vec<LinearTermSpec>,
    /// Random-effect term specifications.
    pub random_effect_terms: Vec<RandomEffectTermSpec>,
    /// Smooth term specifications.
    pub smooth_terms: Vec<SmoothTermSpec>,
}
1264
1265pub fn validate_smooth_basis_frozen(
1266    basis: &SmoothBasisSpec,
1267    label: &str,
1268    term_name: &str,
1269) -> Result<(), String> {
1270    match basis {
1271        SmoothBasisSpec::ByVariable { inner, .. }
1272        | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
1273            validate_smooth_basis_frozen(inner, label, term_name)
1274        }
1275        SmoothBasisSpec::BSpline1D { spec, .. } => {
1276            if !matches!(
1277                spec.knotspec,
1278                BSplineKnotSpec::Provided(_)
1279                    | BSplineKnotSpec::PeriodicUniform { .. }
1280                    | BSplineKnotSpec::NaturalCubicRegression { .. }
1281            ) {
1282                return Err(format!(
1283                    "{label} term '{term_name}' is not frozen: BSpline knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression"
1284                ));
1285            }
1286            Ok(())
1287        }
1288        SmoothBasisSpec::ThinPlate { spec, .. } => {
1289            if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1290                return Err(format!(
1291                    "{label} term '{term_name}' is not frozen: ThinPlate centers must be UserProvided"
1292                ));
1293            }
1294            if matches!(
1295                spec.identifiability,
1296                SpatialIdentifiability::OrthogonalToParametric
1297            ) {
1298                return Err(format!(
1299                    "{label} term '{term_name}' is not frozen: ThinPlate identifiability must be FrozenTransform or None"
1300                ));
1301            }
1302            Ok(())
1303        }
1304        _ => Ok(()),
1305    }
1306}
1307
1308impl TermCollectionSpec {
1309    /// Write this collection's topology identity into a warm-start cache
1310    /// fingerprint (#869).
1311    ///
1312    /// The persistent warm-start `cache_key` hashes only family + raw input
1313    /// dimensions, so two fits on the same data that differ *only* in their
1314    /// smooth topology (the `s(..., type=AUTO)` candidate enumeration: sphere
1315    /// vs torus vs euclidean vs duchon) collide on one key and seed each other
1316    /// with geometrically incompatible β/ρ. Folding the per-term structural
1317    /// kind + feature columns + linear/random-effect counts into the shape hash
1318    /// gives each candidate its own key, so the screen→full-refit reuse of one
1319    /// candidate is preserved while cross-candidate contamination is removed.
1320    /// Only the structural identity is hashed (not fitted coefficients or
1321    /// frozen knot values), so a refit of the *same* topology still hits.
1322    pub fn write_structural_shape_hash(&self, h: &mut gam_runtime::warm_start::Fingerprinter) {
1323        h.write_str("term-collection");
1324        h.write_usize(self.linear_terms.len());
1325        for linear in &self.linear_terms {
1326            h.write_str(&linear.name);
1327        }
1328        h.write_usize(self.random_effect_terms.len());
1329        h.write_usize(self.smooth_terms.len());
1330        for smooth in &self.smooth_terms {
1331            h.write_str(&smooth.name);
1332            h.write_str(smooth.basis.structural_kind());
1333            for col in smooth.basis.structural_feature_cols() {
1334                h.write_usize(col);
1335            }
1336        }
1337    }
1338
1339    /// Validate that a term collection spec represents a fully frozen model
1340    /// (i.e. all knots/centers are pre-computed, identifiability transforms are
1341    /// baked in, and random-effect levels are fixed).
1342    pub fn validate_frozen(&self, label: &str) -> Result<(), String> {
1343        for linear in &self.linear_terms {
1344            if let (Some(min), Some(max)) = (linear.coefficient_min, linear.coefficient_max)
1345                && (!min.is_finite() || !max.is_finite() || min > max)
1346            {
1347                return Err(SmoothError::invalid_config(format!(
1348                    "{label} linear term '{}' has invalid coefficient constraint [{min}, {max}]",
1349                    linear.name
1350                ))
1351                .into());
1352            }
1353            if let Some(min) = linear.coefficient_min
1354                && !min.is_finite()
1355            {
1356                return Err(SmoothError::invalid_config(format!(
1357                    "{label} linear term '{}' has non-finite coefficient minimum {min}",
1358                    linear.name
1359                ))
1360                .into());
1361            }
1362            if let Some(max) = linear.coefficient_max
1363                && !max.is_finite()
1364            {
1365                return Err(SmoothError::invalid_config(format!(
1366                    "{label} linear term '{}' has non-finite coefficient maximum {max}",
1367                    linear.name
1368                ))
1369                .into());
1370            }
1371            if let LinearCoefficientGeometry::Bounded { min, max, prior } =
1372                &linear.coefficient_geometry
1373            {
1374                if !min.is_finite() || !max.is_finite() || min >= max {
1375                    return Err(SmoothError::invalid_config(format!(
1376                        "{label} bounded term '{}' has invalid bounds [{min}, {max}]",
1377                        linear.name
1378                    ))
1379                    .into());
1380                }
1381                match prior {
1382                    BoundedCoefficientPriorSpec::None | BoundedCoefficientPriorSpec::Uniform => {}
1383                    BoundedCoefficientPriorSpec::Beta { a, b } => {
1384                        if !a.is_finite() || !b.is_finite() || *a < 1.0 || *b < 1.0 {
1385                            return Err(SmoothError::invalid_config(format!(
1386                                "{label} bounded term '{}' has invalid Beta prior ({a}, {b})",
1387                                linear.name
1388                            ))
1389                            .into());
1390                        }
1391                    }
1392                }
1393            }
1394        }
1395        for st in &self.smooth_terms {
1396            match &st.basis {
1397                SmoothBasisSpec::ByVariable { inner, .. } => {
1398                    validate_smooth_basis_frozen(inner, label, &st.name)?;
1399                    let nested = SmoothTermSpec {
1400                        name: st.name.clone(),
1401                        basis: (**inner).clone(),
1402                        shape: st.shape,
1403                        joint_null_rotation: None,
1404                    };
1405                    TermCollectionSpec {
1406                        linear_terms: Vec::new(),
1407                        random_effect_terms: Vec::new(),
1408                        smooth_terms: vec![nested],
1409                    }
1410                    .validate_frozen(label)?;
1411                }
1412                SmoothBasisSpec::FactorSumToZero { inner, levels, .. } => {
1413                    if levels.len() < 2 {
1414                        return Err(format!(
1415                            "{label} term '{}' has invalid frozen sz levels",
1416                            st.name
1417                        ));
1418                    }
1419                    validate_smooth_basis_frozen(inner, label, &st.name)?;
1420                }
1421                SmoothBasisSpec::BSpline1D { spec, .. } => {
1422                    if !matches!(
1423                        spec.knotspec,
1424                        BSplineKnotSpec::Provided(_)
1425                            | BSplineKnotSpec::PeriodicUniform { .. }
1426                            | BSplineKnotSpec::NaturalCubicRegression { .. }
1427                    ) {
1428                        return Err(SmoothError::invalid_config(format!(
1429                            "{label} term '{}' is not frozen: BSpline knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression",
1430                            st.name
1431                        ))
1432                        .into());
1433                    }
1434                }
1435                SmoothBasisSpec::ThinPlate { spec, .. } => {
1436                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1437                        return Err(SmoothError::invalid_config(format!(
1438                            "{label} term '{}' is not frozen: ThinPlate centers must be UserProvided",
1439                            st.name
1440                        ))
1441                        .into());
1442                    }
1443                    if matches!(
1444                        spec.identifiability,
1445                        SpatialIdentifiability::OrthogonalToParametric
1446                    ) {
1447                        return Err(SmoothError::invalid_config(format!(
1448                            "{label} term '{}' is not frozen: ThinPlate identifiability must be FrozenTransform or None",
1449                            st.name
1450                        ))
1451                        .into());
1452                    }
1453                }
1454                SmoothBasisSpec::Sphere { spec, .. } => {
1455                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1456                        return Err(SmoothError::invalid_config(format!(
1457                            "{label} term '{}' is not frozen: Sphere centers must be UserProvided",
1458                            st.name
1459                        ))
1460                        .into());
1461                    }
1462                    if matches!(spec.method, crate::basis::SphereMethod::Harmonic)
1463                        && spec.max_degree.is_none_or(|d| d == 0)
1464                    {
1465                        return Err(format!(
1466                            "{label} term '{}' is not frozen: sphere max_degree must be positive",
1467                            st.name
1468                        ));
1469                    }
1470                }
1471                SmoothBasisSpec::ConstantCurvature { spec, .. } => {
1472                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1473                        return Err(SmoothError::invalid_config(format!(
1474                            "{label} term '{}' is not frozen: ConstantCurvature centers must be UserProvided",
1475                            st.name
1476                        ))
1477                        .into());
1478                    }
1479                    if !(spec.length_scale.is_finite() && spec.length_scale > 0.0) {
1480                        return Err(SmoothError::invalid_config(format!(
1481                            "{label} term '{}' is not frozen: ConstantCurvature length_scale must be the realized positive value",
1482                            st.name
1483                        ))
1484                        .into());
1485                    }
1486                }
1487                SmoothBasisSpec::MeasureJet { spec, .. } => {
1488                    let centers = match &spec.center_strategy {
1489                        CenterStrategy::UserProvided(centers) => centers,
1490                        _ => {
1491                            return Err(SmoothError::invalid_config(format!(
1492                                "{label} term '{}' is not frozen: MeasureJet centers must be UserProvided",
1493                                st.name
1494                            ))
1495                            .into());
1496                        }
1497                    };
1498                    if centers.nrows() == 0 {
1499                        return Err(SmoothError::invalid_config(format!(
1500                            "{label} term '{}' is not frozen: MeasureJet centers are empty",
1501                            st.name
1502                        ))
1503                        .into());
1504                    }
1505                    if !(spec.length_scale.is_finite() && spec.length_scale > 0.0) {
1506                        return Err(SmoothError::invalid_config(format!(
1507                            "{label} term '{}' is not frozen: MeasureJet length_scale must be the realized positive value",
1508                            st.name
1509                        ))
1510                        .into());
1511                    }
1512                    // Exact replay needs the fit-data penalty quadrature and
1513                    // normalization payload (`BasisMetadata::MeasureJet`).
1514                    let frozen = spec.frozen_quadrature.as_ref().ok_or_else(|| {
1515                        SmoothError::invalid_config(format!(
1516                            "{label} term '{}' is not frozen: MeasureJet frozen_quadrature payload is missing",
1517                            st.name
1518                        ))
1519                    })?;
1520                    if frozen.masses.len() != centers.nrows() {
1521                        return Err(SmoothError::invalid_config(format!(
1522                            "{label} term '{}' frozen MeasureJet has {} masses for {} centers",
1523                            st.name,
1524                            frozen.masses.len(),
1525                            centers.nrows()
1526                        ))
1527                        .into());
1528                    }
1529                    let total_mass = frozen.masses.sum();
1530                    if frozen
1531                        .masses
1532                        .iter()
1533                        .any(|mass| !(mass.is_finite() && *mass >= 0.0))
1534                        || !(total_mass.is_finite() && total_mass > 0.0)
1535                    {
1536                        return Err(SmoothError::invalid_config(format!(
1537                            "{label} term '{}' frozen MeasureJet masses must be finite, nonnegative, and have positive total mass",
1538                            st.name
1539                        ))
1540                        .into());
1541                    }
1542                    let n_levels = frozen.eps_band.len();
1543                    if n_levels == 0
1544                        || frozen
1545                            .eps_band
1546                            .iter()
1547                            .any(|eps| !(eps.is_finite() && *eps > 0.0))
1548                    {
1549                        return Err(SmoothError::invalid_config(format!(
1550                            "{label} term '{}' frozen MeasureJet eps_band must be nonempty, finite, and positive",
1551                            st.name
1552                        ))
1553                        .into());
1554                    }
1555                    for (idx, pair) in frozen.eps_band.windows(2).enumerate() {
1556                        if pair[1] <= pair[0] {
1557                            return Err(SmoothError::invalid_config(format!(
1558                                "{label} term '{}' frozen MeasureJet eps_band is not strictly ascending at {idx}: {} then {}",
1559                                st.name,
1560                                pair[0],
1561                                pair[1]
1562                            ))
1563                            .into());
1564                        }
1565                    }
1566                    validate_measure_jet_positive_vec_len(
1567                        label,
1568                        &st.name,
1569                        "support_means",
1570                        &frozen.support_means,
1571                        n_levels,
1572                    )?;
1573                    // Mode predicate MUST match the builder's
1574                    // (`measure_jet_multiscale_mode`): per-level/multiscale is the
1575                    // explicit `spec.multiscale` opt-in (#1116). In single-scale
1576                    // mode the builder emits a single FUSED penalty (empty
1577                    // per-level scales + `fused_penalty_normalization_scale:
1578                    // Some`); only the multiscale opt-in carries `n_levels`
1579                    // per-level scales.
1580                    let per_level = crate::basis::measure_jet_multiscale_mode(spec);
1581                    if per_level {
1582                        validate_measure_jet_positive_vec_len(
1583                            label,
1584                            &st.name,
1585                            "penalty_normalization_scales",
1586                            &frozen.penalty_normalization_scales,
1587                            n_levels,
1588                        )?;
1589                        validate_measure_jet_positive_vec_len(
1590                            label,
1591                            &st.name,
1592                            "raw_penalty_normalization_scales",
1593                            &frozen.raw_penalty_normalization_scales,
1594                            n_levels,
1595                        )?;
1596                        if frozen.fused_penalty_normalization_scale.is_some() {
1597                            return Err(SmoothError::invalid_config(format!(
1598                                "{label} term '{}' per-level MeasureJet must not carry a fused penalty normalization scale",
1599                                st.name
1600                            ))
1601                            .into());
1602                        }
1603                    } else {
1604                        if !frozen.penalty_normalization_scales.is_empty()
1605                            || !frozen.raw_penalty_normalization_scales.is_empty()
1606                        {
1607                            return Err(SmoothError::invalid_config(format!(
1608                                "{label} term '{}' fused MeasureJet must not carry per-level penalty normalization scales",
1609                                st.name
1610                            ))
1611                            .into());
1612                        }
1613                        match frozen.fused_penalty_normalization_scale {
1614                            Some(scale) if scale.is_finite() && scale > 0.0 => {}
1615                            Some(scale) => {
1616                                return Err(SmoothError::invalid_config(format!(
1617                                    "{label} term '{}' fused MeasureJet penalty normalization scale must be positive and finite, got {scale}",
1618                                    st.name
1619                                ))
1620                                .into());
1621                            }
1622                            None => {
1623                                return Err(SmoothError::invalid_config(format!(
1624                                    "{label} term '{}' fused MeasureJet is missing its penalty normalization scale",
1625                                    st.name
1626                                ))
1627                                .into());
1628                            }
1629                        }
1630                    }
1631                }
1632                SmoothBasisSpec::Matern { spec, .. } => {
1633                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1634                        return Err(SmoothError::invalid_config(format!(
1635                            "{label} term '{}' is not frozen: Matern centers must be UserProvided",
1636                            st.name
1637                        ))
1638                        .into());
1639                    }
1640                }
1641                SmoothBasisSpec::Duchon { spec, .. } => {
1642                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1643                        return Err(SmoothError::invalid_config(format!(
1644                            "{label} term '{}' is not frozen: Duchon centers must be UserProvided",
1645                            st.name
1646                        ))
1647                        .into());
1648                    }
1649                    if matches!(
1650                        spec.identifiability,
1651                        SpatialIdentifiability::OrthogonalToParametric
1652                    ) {
1653                        return Err(SmoothError::invalid_config(format!(
1654                            "{label} term '{}' is not frozen: Duchon identifiability must be FrozenTransform or None",
1655                            st.name
1656                        ))
1657                        .into());
1658                    }
1659                }
1660                SmoothBasisSpec::Pca {
1661                    centered,
1662                    center_mean,
1663                    pca_basis_path,
1664                    ..
1665                } => {
1666                    if *centered && center_mean.is_none() && pca_basis_path.is_none() {
1667                        return Err(SmoothError::invalid_config(format!(
1668                            "{label} term '{}' is not frozen: centered Pca missing center_mean",
1669                            st.name
1670                        ))
1671                        .into());
1672                    }
1673                }
1674                SmoothBasisSpec::BySmooth { smooth, by_kind } => {
1675                    if let SmoothBasisSpec::BySmooth { .. } = smooth.as_ref() {
1676                        return Err(format!("{label} term '{}' has nested by-smooths", st.name));
1677                    }
1678                    match by_kind {
1679                        ByVarKind::Numeric { .. } => {}
1680                        ByVarKind::Factor { frozen_levels, .. } if frozen_levels.is_none() => {
1681                            return Err(format!(
1682                                "{label} term '{}' is not frozen: by-factor levels missing",
1683                                st.name
1684                            ));
1685                        }
1686                        ByVarKind::Factor { .. } => {}
1687                    }
1688                    let nested = TermCollectionSpec {
1689                        linear_terms: vec![],
1690                        random_effect_terms: vec![],
1691                        smooth_terms: vec![SmoothTermSpec {
1692                            name: st.name.clone(),
1693                            basis: (**smooth).clone(),
1694                            shape: st.shape,
1695                            joint_null_rotation: None,
1696                        }],
1697                    };
1698                    nested.validate_frozen(label)?;
1699                }
1700                SmoothBasisSpec::FactorSmooth { spec } => {
1701                    if spec.group_frozen_levels.is_none() {
1702                        return Err(format!(
1703                            "{label} term '{}' is not frozen: factor-smooth levels missing",
1704                            st.name
1705                        ));
1706                    }
1707                    if !matches!(
1708                        spec.marginal.knotspec,
1709                        BSplineKnotSpec::Provided(_)
1710                            | BSplineKnotSpec::PeriodicUniform { .. }
1711                            // mgcv's `bs="sz"` default marginal is a cubic
1712                            // regression spline (#1074), and the freeze step
1713                            // restores it as a `NaturalCubicRegression` knotspec
1714                            // carrying its `k` value-knots (spatial_optimization.rs
1715                            // `marginal_is_cr` branch) — the SAME treatment the
1716                            // tensor margin already gets in the arm below. Without
1717                            // this variant a frozen `sz` factor smooth fails its own
1718                            // predict-time freeze check ("factor-smooth marginal
1719                            // knots missing") even though its knots are fully
1720                            // materialized; the validation simply was not updated
1721                            // when the cr marginal landed.
1722                            | BSplineKnotSpec::NaturalCubicRegression { .. }
1723                    ) {
1724                        return Err(format!(
1725                            "{label} term '{}' is not frozen: factor-smooth marginal knots missing",
1726                            st.name
1727                        ));
1728                    }
1729                }
1730                SmoothBasisSpec::TensorBSpline { spec, .. } => {
1731                    for (dim, marginal) in spec.marginalspecs.iter().enumerate() {
1732                        if !matches!(
1733                            marginal.knotspec,
1734                            BSplineKnotSpec::Provided(_)
1735                                | BSplineKnotSpec::PeriodicUniform { .. }
1736                                | BSplineKnotSpec::NaturalCubicRegression { .. }
1737                        ) {
1738                            return Err(SmoothError::invalid_config(format!(
1739                                "{label} term '{}' dim {} is not frozen: tensor marginal knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression",
1740                                st.name, dim
1741                            ))
1742                            .into());
1743                        }
1744                    }
1745                    if matches!(
1746                        spec.identifiability,
1747                        TensorBSplineIdentifiability::SumToZero
1748                            | TensorBSplineIdentifiability::MarginalSumToZero
1749                    ) {
1750                        return Err(SmoothError::invalid_config(format!(
1751                            "{label} term '{}' is not frozen: tensor identifiability must be FrozenTransform or None",
1752                            st.name
1753                        ))
1754                        .into());
1755                    }
1756                }
1757            }
1758        }
1759
1760        for rt in &self.random_effect_terms {
1761            if rt.frozen_levels.is_none() {
1762                return Err(SmoothError::invalid_config(format!(
1763                    "{label} random-effect term '{}' is not frozen: missing frozen_levels",
1764                    rt.name
1765                ))
1766                .into());
1767            }
1768        }
1769
1770        Ok(())
1771    }
1772
1773    /// Re-resolve every stored feature-column index through `remap`, returning a
1774    /// spec that addresses a different column layout.
1775    ///
1776    /// A frozen `TermCollectionSpec` stores feature columns as *absolute indices
1777    /// into the training table*. To replay it on a fresh dataset whose columns
1778    /// sit at different positions — the common case at prediction time, where
1779    /// the response column is unknown and may be absent entirely — every index
1780    /// must be re-resolved against the new layout. `remap` receives each
1781    /// training-table index and returns its position in the runtime table;
1782    /// callers typically implement it as "look the name up in the training
1783    /// headers, then resolve that name against the prediction dataset".
1784    ///
1785    /// This is the single authority on *which* fields carry a column index
1786    /// across every basis variant (linear, random-effect, the `by=` column of
1787    /// `ByVariable`/`FactorSumToZero`/`BySmooth`, the continuous and group
1788    /// columns of a `FactorSmooth`, and the multi-axis `feature_cols` of every
1789    /// spatial/tensor basis), so a predict-time realignment cannot silently miss
1790    /// one and dereference a stale training index.
1791    pub fn remap_feature_columns<E, F>(&self, mut remap: F) -> Result<TermCollectionSpec, E>
1792    where
1793        F: FnMut(usize) -> Result<usize, E>,
1794    {
1795        let mut out = self.clone();
1796        for lt in &mut out.linear_terms {
1797            lt.feature_col = remap(lt.feature_col)?;
1798            // Also remap the full interaction-factor list. The design builder
1799            // (`build_term_collection_design_inner`) materializes the column from
1800            // `effective_feature_cols()` — which returns `feature_cols` whenever
1801            // it is non-empty (i.e. essentially always, including a plain linear
1802            // term where `feature_cols == [feature_col]`). Remapping only the
1803            // singular `feature_col` left these at their saved *training* indices
1804            // at predict time, so a parametric `Surv(...) ~ x` (and any `:`
1805            // interaction) bailed with "feature column N out of bounds" once the
1806            // response/time columns shift the runtime layout (issue #898).
1807            for fc in lt.feature_cols.iter_mut() {
1808                *fc = remap(*fc)?;
1809            }
1810            // A factor-aware `:` interaction also gates on categorical columns;
1811            // those indices live in the same training-time layout and must be
1812            // realigned to the runtime table alongside the numeric operands, or
1813            // the predict-time level indicator would dereference a stale column.
1814            for (col, _bits) in lt.categorical_levels.iter_mut() {
1815                *col = remap(*col)?;
1816            }
1817        }
1818        for rt in &mut out.random_effect_terms {
1819            rt.feature_col = remap(rt.feature_col)?;
1820        }
1821        for st in &mut out.smooth_terms {
1822            remap_smooth_basis_feature_columns(&mut st.basis, &mut remap)?;
1823        }
1824        Ok(out)
1825    }
1826}
1827
1828/// Walk a `SmoothBasisSpec` tree, re-resolving every column index through
1829/// `remap`. Shared by all predict-time column realignment (see
1830/// [`TermCollectionSpec::remap_feature_columns`]); kept exhaustive so a newly
1831/// added index-bearing variant fails to compile until it is handled here.
1832pub fn remap_smooth_basis_feature_columns<E, F>(
1833    basis: &mut SmoothBasisSpec,
1834    remap: &mut F,
1835) -> Result<(), E>
1836where
1837    F: FnMut(usize) -> Result<usize, E>,
1838{
1839    match basis {
1840        SmoothBasisSpec::ByVariable { inner, by_col, .. }
1841        | SmoothBasisSpec::FactorSumToZero { inner, by_col, .. } => {
1842            *by_col = remap(*by_col)?;
1843            remap_smooth_basis_feature_columns(inner, remap)?;
1844        }
1845        SmoothBasisSpec::BSpline1D { feature_col, .. } => {
1846            *feature_col = remap(*feature_col)?;
1847        }
1848        SmoothBasisSpec::BySmooth { smooth, by_kind } => {
1849            let by_feature_col = match by_kind {
1850                ByVarKind::Numeric { feature_col } | ByVarKind::Factor { feature_col, .. } => {
1851                    feature_col
1852                }
1853            };
1854            *by_feature_col = remap(*by_feature_col)?;
1855            remap_smooth_basis_feature_columns(smooth, remap)?;
1856        }
1857        SmoothBasisSpec::FactorSmooth { spec } => {
1858            for fc in spec.continuous_cols.iter_mut() {
1859                *fc = remap(*fc)?;
1860            }
1861            spec.group_col = remap(spec.group_col)?;
1862        }
1863        SmoothBasisSpec::ThinPlate { feature_cols, .. }
1864        | SmoothBasisSpec::Sphere { feature_cols, .. }
1865        | SmoothBasisSpec::ConstantCurvature { feature_cols, .. }
1866        | SmoothBasisSpec::Matern { feature_cols, .. }
1867        | SmoothBasisSpec::MeasureJet { feature_cols, .. }
1868        | SmoothBasisSpec::Duchon { feature_cols, .. }
1869        | SmoothBasisSpec::Pca { feature_cols, .. }
1870        | SmoothBasisSpec::TensorBSpline { feature_cols, .. } => {
1871            for fc in feature_cols.iter_mut() {
1872                *fc = remap(*fc)?;
1873            }
1874        }
1875    }
1876    Ok(())
1877}
1878
1879#[derive(Debug, Clone)]
1880pub enum PenaltyStructureHint {
1881    Ridge(f64),
1882    Kronecker(Vec<Array2<f64>>),
1883}
1884
1885/// A penalty matrix stored at its natural block size together with the
1886/// column range it occupies in the global coefficient vector.
1887///
1888/// Instead of embedding every penalty into a full `p_total × p_total` dense
1889/// matrix filled with zeros, we keep the compact local matrix and reconstruct
1890/// the global view only when a downstream consumer explicitly requires it.
1891#[derive(Clone)]
1892pub struct BlockwisePenalty {
1893    /// Column range in the global coefficient vector that this penalty covers.
1894    pub col_range: Range<usize>,
1895    /// The local penalty matrix — dimensions `block_p × block_p` where
1896    /// `block_p = col_range.len()`.
1897    pub local: Array2<f64>,
1898    /// Optional nonzero centering vector for this coefficient block.
1899    pub prior_mean: gam_problem::CoefficientPriorMean,
1900    /// Optional structural hint so downstream spectral/logdet code can stay
1901    /// block-local or factorized without reverse-engineering the matrix.
1902    pub structure_hint: Option<PenaltyStructureHint>,
1903    /// Optional operator-form handle bit-equivalent to `local`. Populated when
1904    /// the originating closed-form factory emitted an op-form penalty so exact
1905    /// operator algebra can use matvec instead of materializing the dense
1906    /// `block_p × block_p` Gram. `None` for ordinary dense penalties.
1907    pub op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
1908}
1909
1910impl std::fmt::Debug for BlockwisePenalty {
1911    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1912        f.debug_struct("BlockwisePenalty")
1913            .field("col_range", &self.col_range)
1914            .field(
1915                "local",
1916                &format_args!("{}×{}", self.local.nrows(), self.local.ncols()),
1917            )
1918            .field("prior_mean", &self.prior_mean)
1919            .field("structure_hint", &self.structure_hint)
1920            .field("op", &self.op.as_ref().map(|o| o.dim()))
1921            .finish()
1922    }
1923}
1924
1925impl BlockwisePenalty {
1926    /// Create a new blockwise penalty.
1927    pub fn new(col_range: Range<usize>, local: Array2<f64>) -> Self {
1928        assert_eq!(col_range.len(), local.nrows());
1929        assert_eq!(col_range.len(), local.ncols());
1930        Self {
1931            col_range,
1932            local,
1933            prior_mean: gam_problem::CoefficientPriorMean::Zero,
1934            structure_hint: None,
1935            op: None,
1936        }
1937    }
1938
1939    pub fn with_prior_mean(
1940        mut self,
1941        prior_mean: gam_problem::CoefficientPriorMean,
1942    ) -> Self {
1943        self.prior_mean = prior_mean;
1944        self
1945    }
1946
1947    /// Attach an op-form penalty handle bit-equivalent to `local`.
1948    pub fn with_op(
1949        mut self,
1950        op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
1951    ) -> Self {
1952        self.op = op;
1953        self
1954    }
1955
1956    pub fn ridge(col_range: Range<usize>, scale: f64) -> Self {
1957        let block_size = col_range.len();
1958        let mut local = Array2::<f64>::zeros((block_size, block_size));
1959        for i in 0..block_size {
1960            local[[i, i]] = scale;
1961        }
1962        Self {
1963            col_range,
1964            local,
1965            prior_mean: gam_problem::CoefficientPriorMean::Zero,
1966            structure_hint: Some(PenaltyStructureHint::Ridge(scale)),
1967            op: None,
1968        }
1969    }
1970
1971    pub fn kronecker(
1972        col_range: Range<usize>,
1973        local: Array2<f64>,
1974        factors: Vec<Array2<f64>>,
1975    ) -> Self {
1976        assert_eq!(col_range.len(), local.nrows());
1977        assert_eq!(col_range.len(), local.ncols());
1978        Self {
1979            col_range,
1980            local,
1981            prior_mean: gam_problem::CoefficientPriorMean::Zero,
1982            structure_hint: Some(PenaltyStructureHint::Kronecker(factors)),
1983            op: None,
1984        }
1985    }
1986
1987    /// Expand this blockwise penalty into a full `p_total × p_total` dense
1988    /// matrix (mostly zeros). Use sparingly — the whole point of blockwise
1989    /// storage is to avoid this allocation.
1990    pub fn to_global(&self, p_total: usize) -> Array2<f64> {
1991        let mut g = Array2::<f64>::zeros((p_total, p_total));
1992        let r = &self.col_range;
1993        assert!(
1994            r.end <= p_total && self.local.nrows() == r.len() && self.local.ncols() == r.len(),
1995            "BlockwisePenalty::to_global shape invariant violated: \
1996             col_range={}..{}, local={}x{}, p_total={}",
1997            r.start,
1998            r.end,
1999            self.local.nrows(),
2000            self.local.ncols(),
2001            p_total,
2002        );
2003        g.slice_mut(s![r.start..r.end, r.start..r.end])
2004            .assign(&self.local);
2005        g
2006    }
2007
2008    /// Convert into a blockwise [`gam_problem::PenaltyMatrix`] without
2009    /// expanding to full dimensions.
2010    pub fn to_penalty_matrix(
2011        &self,
2012        total_dim: usize,
2013    ) -> gam_problem::PenaltyMatrix {
2014        gam_problem::PenaltyMatrix::Blockwise {
2015            local: self.local.clone(),
2016            col_range: self.col_range.clone(),
2017            total_dim,
2018        }
2019    }
2020
2021    /// The block size of this penalty.
2022    #[inline]
2023    pub fn block_size(&self) -> usize {
2024        self.col_range.len()
2025    }
2026}
2027
2028/// Compute `Σ_k λ_k S_k` directly from blockwise penalties, accumulating
2029/// into a pre-allocated `p_total × p_total` output without ever materializing
2030/// individual global matrices.
2031pub fn weighted_blockwise_penalty_sum(
2032    penalties: &[BlockwisePenalty],
2033    lambdas: &[f64],
2034    p_total: usize,
2035) -> Array2<f64> {
2036    assert_eq!(penalties.len(), lambdas.len());
2037    // Smoothing parameters λ_k must be non-negative and finite. A negative
2038    // λ would flip the sign of the corresponding block S_k, turning the
2039    // total penalty matrix indefinite and silently corrupting every
2040    // downstream Cholesky / PIRLS / REML / pseudo-logdet computation that
2041    // assumes S_λ ⪰ 0. Catch this at the boundary rather than after it
2042    // has propagated.
2043    for (idx, &lam) in lambdas.iter().enumerate() {
2044        assert!(
2045            lam.is_finite() && lam >= 0.0,
2046            "weighted_blockwise_penalty_sum: lambdas[{idx}] = {lam} is invalid (must be finite and non-negative; negative smoothing parameters violate S_λ ⪰ 0)",
2047        );
2048    }
2049    // Block column ranges must also fit inside the declared total parameter
2050    // dimension; an out-of-bounds slice would otherwise panic from ndarray
2051    // with a far less informative message.
2052    for (idx, bp) in penalties.iter().enumerate() {
2053        let r = &bp.col_range;
2054        assert!(
2055            r.end <= p_total,
2056            "weighted_blockwise_penalty_sum: penalties[{idx}] col_range {:?} exceeds p_total = {p_total}",
2057            r,
2058        );
2059    }
2060    let mut out = Array2::<f64>::zeros((p_total, p_total));
2061    for (bp, &lam) in penalties.iter().zip(lambdas.iter()) {
2062        let r = &bp.col_range;
2063        let mut slice = out.slice_mut(s![r.start..r.end, r.start..r.end]);
2064        slice.scaled_add(lam, &bp.local);
2065    }
2066    out
2067}
2068
2069// ---------------------------------------------------------------------------
2070// KroneckerPenaltySystem — factored tensor-product penalty representation
2071// ---------------------------------------------------------------------------
2072
2073/// Factored representation of tensor-product penalties with precomputed
2074/// marginal eigensystems for O(∏q_j) logdet and penalty operations.
2075#[derive(Debug, Clone)]
2076pub struct KroneckerPenaltySystem {
2077    /// Marginal penalty matrices: `marginal_penalties[k]` is `(q_k, q_k)`.
2078    pub marginal_penalties: Vec<Array2<f64>>,
2079    /// Precomputed eigensystems: `(eigenvalues, eigenvectors)` per marginal.
2080    pub marginal_eigensystems: Vec<(Array1<f64>, Array2<f64>)>,
2081    /// Marginal basis dimensions.
2082    pub marginal_dims: Vec<usize>,
2083    /// Whether a global ridge (double) penalty is present.
2084    pub has_double_penalty: bool,
2085}
2086
2087impl KroneckerPenaltySystem {
2088    pub fn new(
2089        marginal_penalties: Vec<Array2<f64>>,
2090        marginal_dims: Vec<usize>,
2091        has_double_penalty: bool,
2092    ) -> Result<Self, BasisError> {
2093        if marginal_penalties.len() != marginal_dims.len() {
2094            crate::bail_dim_basis!(
2095                "KroneckerPenaltySystem: {} penalties vs {} dims",
2096                marginal_penalties.len(),
2097                marginal_dims.len()
2098            );
2099        }
2100        let eigensystems =
2101            kronecker_marginal_eigensystems(&marginal_penalties, "KroneckerPenaltySystem")
2102                .map_err(|e| BasisError::InvalidInput(e.to_string()))?;
2103        Ok(Self {
2104            marginal_penalties,
2105            marginal_eigensystems: eigensystems,
2106            marginal_dims,
2107            has_double_penalty,
2108        })
2109    }
2110
2111    pub fn p_total(&self) -> usize {
2112        self.marginal_dims.iter().copied().product()
2113    }
2114
2115    pub fn ndim(&self) -> usize {
2116        self.marginal_dims.len()
2117    }
2118
2119    pub fn num_penalties(&self) -> usize {
2120        self.marginal_dims.len() + if self.has_double_penalty { 1 } else { 0 }
2121    }
2122
2123    /// Compute `log|S|₊` and its first/second derivatives w.r.t. `ρ_k = log(λ_k)`.
2124    ///
2125    /// Iterates over the ∏q_j multi-index grid. Cost: O(d · ∏q_j), no O(p²) storage.
2126    pub fn logdet_and_derivatives(
2127        &self,
2128        lambdas: &[f64],
2129        ridge: f64,
2130    ) -> (f64, Array1<f64>, Array2<f64>) {
2131        let n_pen = self.num_penalties();
2132        assert_eq!(lambdas.len(), n_pen, "lambda count mismatch");
2133        let marginal_evals: Vec<_> = self
2134            .marginal_eigensystems
2135            .iter()
2136            .map(|(evals, _)| evals.view())
2137            .collect();
2138        kronecker_logdet_and_derivatives(
2139            &marginal_evals,
2140            &self.marginal_dims,
2141            lambdas,
2142            self.has_double_penalty,
2143            ridge,
2144        )
2145    }
2146
2147    pub fn logdet_rank_and_derivatives(
2148        &self,
2149        lambdas: &[f64],
2150        ridge: f64,
2151    ) -> (f64, usize, Array1<f64>, Array2<f64>) {
2152        let n_pen = self.num_penalties();
2153        assert_eq!(lambdas.len(), n_pen, "lambda count mismatch");
2154        let d = self.marginal_dims.len();
2155        let mut logdet = 0.0;
2156        let mut rank = 0usize;
2157        let mut grad = Array1::<f64>::zeros(n_pen);
2158        let mut hess = Array2::<f64>::zeros((n_pen, n_pen));
2159        // Positivity floor for a penalized eigenvalue `σ`: below this the mode
2160        // is treated as an unpenalized (null-space) direction and excluded from
2161        // both the rank count and the pseudo-log-determinant.
2162        const EIGENVALUE_POSITIVITY_FLOOR: f64 = 1e-12;
2163        // Floor on the *structural* eigenvalue sum (λ-independent) used to decide
2164        // whether a mode lives in the penalty range space and so should receive
2165        // the stabilizing ridge; a structurally-null mode gets no ridge.
2166        const STRUCTURAL_ZERO_FLOOR: f64 = 1e-12;
2167        let mut multi_idx = vec![0usize; d];
2168        loop {
2169            let mut sigma = 0.0;
2170            let mut structural_sigma = 0.0;
2171            for k in 0..d {
2172                let marginal_eigenvalue = self.marginal_eigensystems[k].0[multi_idx[k]];
2173                structural_sigma += marginal_eigenvalue;
2174                sigma += lambdas[k] * marginal_eigenvalue;
2175            }
2176            let joint_null = structural_sigma <= STRUCTURAL_ZERO_FLOOR;
2177            if self.has_double_penalty && joint_null {
2178                sigma += lambdas[d];
2179            }
2180            if structural_sigma > STRUCTURAL_ZERO_FLOOR {
2181                sigma += ridge;
2182            }
2183
2184            if sigma > EIGENVALUE_POSITIVITY_FLOOR {
2185                rank += 1;
2186                logdet += sigma.ln();
2187                let inv_sigma = 1.0 / sigma;
2188                let inv_sigma2 = inv_sigma * inv_sigma;
2189                for k in 0..n_pen {
2190                    let ck = if k < d {
2191                        lambdas[k] * self.marginal_eigensystems[k].0[multi_idx[k]]
2192                    } else if joint_null {
2193                        lambdas[d]
2194                    } else {
2195                        0.0
2196                    };
2197                    grad[k] += ck * inv_sigma;
2198                    hess[[k, k]] += ck * inv_sigma - ck * ck * inv_sigma2;
2199                    for l in (k + 1)..n_pen {
2200                        let cl = if l < d {
2201                            lambdas[l] * self.marginal_eigensystems[l].0[multi_idx[l]]
2202                        } else if joint_null {
2203                            lambdas[d]
2204                        } else {
2205                            0.0
2206                        };
2207                        let off = -ck * cl * inv_sigma2;
2208                        hess[[k, l]] += off;
2209                        hess[[l, k]] += off;
2210                    }
2211                }
2212            }
2213
2214            let mut carry = true;
2215            for dim in (0..d).rev() {
2216                if carry {
2217                    multi_idx[dim] += 1;
2218                    if multi_idx[dim] < self.marginal_dims[dim] {
2219                        carry = false;
2220                    } else {
2221                        multi_idx[dim] = 0;
2222                    }
2223                }
2224            }
2225            if carry {
2226                break;
2227            }
2228        }
2229        (logdet, rank, grad, hess)
2230    }
2231}
2232
#[cfg(test)]
mod joint_unpenalized_dim_tests {
    use super::joint_unpenalized_dim;
    use ndarray::{Array2, array};

    #[test]
    fn no_penalty_is_fully_unpenalized() {
        // With no penalties at all, every one of the 4 coordinates is free.
        assert_eq!(joint_unpenalized_dim(4, &[], &[]), 4);
    }

    #[test]
    fn single_penalty_returns_its_own_null_space() {
        // Only the last of three coordinates is penalized, so the null space
        // is spanned by the first two coordinates (dimension 2).
        let last_only = array![[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 5.0]];
        assert_eq!(
            joint_unpenalized_dim(3, std::slice::from_ref(&last_only), &[2]),
            2
        );
    }

    #[test]
    fn complementary_double_penalty_has_empty_joint_null_space() {
        // The #1360 case in miniature. The "bending" penalty leaves only
        // coordinate 0 unpenalized (1-dim null space), and the complementary
        // "null-space ridge" penalizes exactly that coordinate (2-dim null
        // space). The per-penalty null dims {1, 2} sum to 3 (≈ p), yet the
        // INTERSECTION is empty: every coordinate is penalized by someone,
        // so the joint unpenalized dim is 0.
        let bending = array![[0.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]];
        let ridge = array![[2.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]];
        assert_eq!(joint_unpenalized_dim(3, &[bending, ridge], &[1, 2]), 0);
    }

    #[test]
    fn partial_overlap_keeps_shared_null_direction() {
        // Both penalties leave coordinate 0 untouched, so that direction
        // survives the intersection (joint unpenalized dim 1) even though
        // naively adding the per-penalty null dims would give 4.
        let mid_only = array![[0.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 0.0]];
        let last_only = array![[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]];
        assert_eq!(
            joint_unpenalized_dim(3, &[mid_only, last_only], &[2, 2]),
            1
        );
    }

    #[test]
    fn non_materialized_penalty_falls_back_conservatively() {
        // One stored block is not p_local × p_local (e.g. a Kronecker tensor
        // factor). With two or more penalties the conservative answer is a
        // joint dim of 0, which never over-rejects.
        let materialized: Array2<f64> = array![[0.0, 0.0], [0.0, 1.0]];
        let factor: Array2<f64> = array![[1.0]]; // wrong shape for p_local=2
        assert_eq!(
            joint_unpenalized_dim(2, &[materialized, factor.clone()], &[1, 0]),
            0
        );
        // A lone non-materialized penalty falls back to its own null dim.
        assert_eq!(
            joint_unpenalized_dim(4, std::slice::from_ref(&factor), &[2]),
            2
        );
    }
}
2289
#[cfg(test)]
mod kronecker_penalty_system_tests {
    use super::KroneckerPenaltySystem;
    use ndarray::array;

    #[test]
    fn double_penalty_rank_derivatives_use_only_joint_null_space() {
        // Two diagonal margins with eigenvalues {0, 2} and {0, 3}.
        let margins = vec![
            array![[0.0, 0.0], [0.0, 2.0]],
            array![[0.0, 0.0], [0.0, 3.0]],
        ];
        let system = KroneckerPenaltySystem::new(margins, vec![2usize, 2usize], true).unwrap();
        let lambdas = vec![5.0, 7.0, 11.0];

        let (logdet, rank, grad, hess) = system.logdet_rank_and_derivatives(&lambdas, 0.0);

        // Mode eigenvalues over the 2×2 grid: the jointly-null mode carries
        // only the double penalty (11); the others are 7·3, 5·2 and 5·2 + 7·3.
        let expected_diag = [11.0_f64, 21.0, 10.0, 31.0];
        let expected_logdet: f64 = expected_diag.iter().map(|v| v.ln()).sum();
        assert_eq!(rank, 4);
        assert!((logdet - expected_logdet).abs() <= 1e-12);
        assert!(
            (grad[2] - 1.0).abs() <= 1e-12,
            "double-penalty rank derivative must count only the joint null mode, got {}",
            grad[2]
        );
        assert!(hess[[2, 2]].abs() <= 1e-12);
    }
}
2318
2319#[derive(Clone, Debug)]
2320pub struct TermCollectionDesign {
2321    /// The full design matrix.
2322    ///
2323    /// Prefer a true sparse matrix when every block is sparse-compatible.
2324    /// If the collection already contains intrinsically sparse blocks, preserve
2325    /// that storage and let PIRLS decide later whether the penalized system is
2326    /// sparse-native eligible. Purely dense materialized blocks still fall back
2327    /// to the lazy block operator when sparse storage would just re-encode a
2328    /// dense matrix.
2329    pub design: DesignMatrix,
2330    pub penalties: Vec<BlockwisePenalty>,
2331    pub nullspace_dims: Vec<usize>,
2332    pub penaltyinfo: Vec<PenaltyBlockInfo>,
2333    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
2334    /// Optional global coefficient lower bounds for constrained fitting.
2335    /// Length equals `design.ncols()` when present. Unconstrained entries are `-inf`.
2336    pub coefficient_lower_bounds: Option<Array1<f64>>,
2337    /// Optional global inequality constraints:
2338    /// `A * beta >= b`.
2339    pub linear_constraints: Option<LinearInequalityConstraints>,
2340    pub intercept_range: Range<usize>,
2341    pub linear_ranges: Vec<(String, Range<usize>)>,
2342    pub random_effect_ranges: Vec<(String, Range<usize>)>,
2343    pub random_effect_levels: Vec<(String, Vec<u64>)>,
2344    pub smooth: SmoothDesign,
2345}
2346
2347impl TermCollectionDesign {
2348    /// Convert blockwise penalties to `PenaltyMatrix::Blockwise` without
2349    /// expanding to `p_total × p_total`. This is the preferred path for
2350    /// family modules that accept `Vec<PenaltyMatrix>`.
2351    pub fn penalties_as_penalty_matrix(&self) -> Vec<gam_problem::PenaltyMatrix> {
2352        let p = self.design.ncols();
2353        self.penalties
2354            .iter()
2355            .map(|bp| bp.to_penalty_matrix(p))
2356            .collect()
2357    }
2358
2359    /// Number of penalty blocks.
2360    #[inline]
2361    pub fn num_penalties(&self) -> usize {
2362        self.penalties.len()
2363    }
2364
2365    /// Resolve coefficient groups against this design's global coefficient
2366    /// layout and append their penalties after the existing term penalties.
2367    pub fn realize_coefficient_groups(
2368        &self,
2369        groups: &[CoefficientGroupSpec],
2370        base_prior: &gam_spec::RhoPrior,
2371    ) -> Result<RealizedCoefficientGroups, BasisError> {
2372        realize_coefficient_groups(self, groups, base_prior)
2373    }
2374
2375    /// Extract a `KroneckerPenaltySystem` when the model's *only* smooth term is
2376    /// a single Kronecker-factored tensor.
2377    ///
2378    /// This is a deliberate single-tensor fast path, not a partial feature: any
2379    /// other shape — zero Kronecker terms, several of them, or a tensor mixed
2380    /// with non-tensor smooth terms — is served correctly by the standard
2381    /// block-separable assembly, so this returns `None` and the caller falls
2382    /// back to it. The two former conditions (`len != 1` and "a non-Kronecker
2383    /// smooth term exists") are jointly equivalent to "the sole smooth term is
2384    /// Kronecker", which the slice pattern below expresses directly in one pass.
2385    pub fn kronecker_penalty_system(&self) -> Option<KroneckerPenaltySystem> {
2386        let [only_term] = self.smooth.terms.as_slice() else {
2387            return None;
2388        };
2389        let kron = only_term.kronecker_factored.as_ref()?;
2390        // A genuine tensor product needs at least two margins, and the marginal
2391        // design / penalty / dim collections must agree in length. A degenerate
2392        // (single-margin) or internally inconsistent factored basis cannot feed
2393        // the Kronecker fast path, so fall back to the standard assembly rather
2394        // than construct a malformed `KroneckerPenaltySystem` from it.
2395        if kron.marginal_dims.len() < 2
2396            || kron.marginal_penalties.len() != kron.marginal_dims.len()
2397            || kron.marginal_designs.len() != kron.marginal_dims.len()
2398        {
2399            return None;
2400        }
2401        KroneckerPenaltySystem::new(
2402            kron.marginal_penalties.clone(),
2403            kron.marginal_dims.clone(),
2404            kron.has_double_penalty,
2405        )
2406        .ok()
2407    }
2408}
2409
2410// `FittedTermCollection`, `SpatialLengthScaleOptimizationTiming`, and
2411// `FittedTermCollectionWithSpec` were relocated with the GAM fit-orchestration
2412// drivers to `gam-models` (`crate::fit_orchestration::drivers`) — they hold a
2413// `gam_solve::UnifiedFitResult` and are consumed only by those drivers (#1521).
2414
2415#[derive(Clone)]
2416pub struct StandardLatentCoordConfig {
2417    pub values: std::sync::Arc<crate::latent::LatentCoordValues>,
2418    pub term_index: gam_problem::types::SmoothTermIdx,
2419    pub feature_cols: Vec<usize>,
2420    pub manifold: crate::latent::LatentManifold,
2421    pub manifold_auto: bool,
2422    pub retraction_registry: gam_problem::LatentRetractionRegistry,
2423    pub analytic_penalties: Option<std::sync::Arc<crate::AnalyticPenaltyRegistry>>,
2424}
2425
2426#[derive(Clone, Debug, Serialize, Deserialize)]
2427pub struct AdaptiveSpatialMap {
2428    pub termname: String,
2429    pub feature_cols: Vec<usize>,
2430    pub collocation_points: Array2<f64>,
2431    pub inv_magweight: Array1<f64>,
2432    pub invgradweight: Array1<f64>,
2433    pub inv_lapweight: Array1<f64>,
2434}
2435
2436#[derive(Clone, Debug, Serialize, Deserialize)]
2437pub struct AdaptiveRegularizationDiagnostics {
2438    pub epsilon_0: f64,
2439    pub epsilon_g: f64,
2440    pub epsilon_c: f64,
2441    pub epsilon_outer_iterations: usize,
2442    pub mm_iterations: usize,
2443    pub converged: bool,
2444    pub maps: Vec<AdaptiveSpatialMap>,
2445}
2446
/// Conditioning record for one linear column: its column index together with
/// a mean and a scale. All fields are private to this module.
// NOTE(review): presumably the column is centered by `mean` and divided by
// `scale` during fitting — confirm against the code that consumes this.
#[derive(Debug, Clone)]
pub struct LinearColumnConditioning {
    col_idx: usize,
    mean: f64,
    scale: f64,
}
2453
2454#[derive(Debug, Clone, Default)]
2455pub struct LinearFitConditioning {
2456    pub intercept_idx: usize,
2457    pub columns: Vec<LinearColumnConditioning>,
2458}
2459
2460#[derive(Clone)]
2461pub struct SpatialPsiDerivative {
2462    // These are derivatives with respect to psi = log(kappa), not log(length_scale).
2463    pub penalty_index: usize,
2464    pub penalty_indices: Vec<usize>,
2465    pub global_range: Range<usize>,
2466    pub total_p: usize,
2467    pub x_psi_local: Array2<f64>,
2468    pub s_psi_components_local: Vec<Array2<f64>>,
2469    pub x_psi_psi_local: Array2<f64>,
2470    pub s_psi_psi_components_local: Vec<Array2<f64>>,
2471    pub aniso_group_id: Option<usize>,
2472    /// Pre-computed cross-derivative design matrices for other axes
2473    /// in the same aniso group: Vec of (axis_offset_in_group, matrix).
2474    pub aniso_cross_designs: Option<Vec<(usize, Array2<f64>)>>,
2475    /// On-demand cross-penalty second derivatives ∂²S_m/∂ψ_a∂ψ_b for axes in
2476    /// the same anisotropy group. The input is the other axis offset in the
2477    /// group, and the output is one local penalty matrix per active penalty.
2478    pub aniso_cross_penalty_provider: Option<
2479        std::sync::Arc<
2480            dyn Fn(usize) -> Result<Vec<Array2<f64>>, EstimationError> + Send + Sync + 'static,
2481        >,
2482    >,
2483    /// Optional implicit design-derivative operator (shared across all axes
2484    /// in the same aniso group). When present, `x_psi_local` and
2485    /// `x_psi_psi_local` may be zero-sized, and design-derivative matvecs
2486    /// should go through this operator using `implicit_axis` as the axis index.
2487    pub implicit_operator: Option<std::sync::Arc<crate::basis::ImplicitDesignPsiDerivative>>,
2488    /// Which axis in the implicit operator this entry corresponds to.
2489    pub implicit_axis: usize,
2490}
2491
2492#[derive(Debug, Clone)]
2493pub struct SpatialLogKappaCoords {
2494    /// Flattened ψ values. For isotropic terms, one entry per term.
2495    /// For anisotropic terms, d entries per term (one ψ_a per axis).
2496    pub values: Array1<f64>,
2497    /// Dimensionality of each term: 1 for isotropic, d for anisotropic.
2498    pub dims_per_term: Vec<usize>,
2499}
2500
/// Selects which end of the ψ bound the shared `aniso_bounds_from_data`
/// helper computes.
///
/// For the lower end the pure-Duchon fallback is `-max_length_scale.ln()` and
/// the `.0` element of `spatial_term_psi_bounds` is used; for the upper end
/// the fallback is `-min_length_scale.ln()` and the `.1` element is used.
/// Nothing else differs between the two.
#[derive(Clone, Copy)]
pub enum AnisoBoundEnd {
    /// Compute the lower ψ bound.
    Lower,
    /// Compute the upper ψ bound.
    Upper,
}
2510
2511impl SpatialLogKappaCoords {
2512    /// Construct from an explicit dims layout plus values.
2513    pub fn new_with_dims(values: Array1<f64>, dims_per_term: Vec<usize>) -> Self {
2514        assert_eq!(
2515            values.len(),
2516            dims_per_term.iter().sum::<usize>(),
2517            "SpatialLogKappaCoords: values length {} != sum of dims_per_term {}",
2518            values.len(),
2519            dims_per_term.iter().sum::<usize>(),
2520        );
2521        Self {
2522            values,
2523            dims_per_term,
2524        }
2525    }
2526
2527    /// Isotropic initialization (backward-compatible path).
2528    pub fn from_length_scales(
2529        spec: &TermCollectionSpec,
2530        term_indices: &[usize],
2531        options: &SpatialLengthScaleOptimizationOptions,
2532    ) -> Self {
2533        let mut out = Array1::<f64>::zeros(term_indices.len());
2534        for (slot, &term_idx) in term_indices.iter().enumerate() {
2535            // Constant-curvature: the single ψ slot is the raw signed κ, seeded
2536            // from the spec (default κ = 0). The −ln(length_scale) convention is
2537            // log-κ semantics and must not touch the raw-κ coordinate; the κ
2538            // window projection happens later via `clamp_to_bounds`. Mirrors the
2539            // aniso constructor's κ branch.
2540            if let Some(cc) = constant_curvature_term_spec(spec, term_idx) {
2541                out[slot] = cc.kappa;
2542                continue;
2543            }
2544            let length_scale = get_spatial_length_scale(spec, term_idx)
2545                .unwrap_or(options.min_length_scale)
2546                .clamp(options.min_length_scale, options.max_length_scale);
2547            out[slot] = -length_scale.ln();
2548        }
2549        Self {
2550            values: out,
2551            dims_per_term: vec![1; term_indices.len()],
2552        }
2553    }
2554
2555    /// Anisotropic-aware initialization.
2556    ///
2557    /// Initialization strategy (per math team recommendation): standardize the
2558    /// knot cloud axiswise, then run the existing isotropic κ initializer in
2559    /// the standardized space. This reuses the trusted isotropic initializer
2560    /// and gives initial η_a = −ln(σ_a) + mean(ln(σ_a)), which satisfies
2561    /// Ση_a = 0 by construction.
2562    ///
2563    /// For each term, checks whether it has `aniso_log_scales` set on its basis spec.
2564    /// - If isotropic (no aniso_log_scales, or 1-D): 1 entry = −ln(length_scale).
2565    /// - If anisotropic with a scalar length scale: d entries, one ψ_a per axis.
2566    ///   Initialized as ψ_a = −ln(length_scale) + η_a  where η_a are the existing
2567    ///   aniso_log_scales (which sum to zero). Multi-dimensional terms without
2568    ///   explicit anisotropy stay scalar here so the seed dimensionality matches
2569    ///   `spatial_dims_per_term`.
2570    /// - If pure Duchon anisotropic: d - 1 free entries store the leading η_a
2571    ///   values directly; the final axis is reconstructed to keep Ση_a = 0.
2572    pub fn from_length_scales_aniso(
2573        spec: &TermCollectionSpec,
2574        term_indices: &[usize],
2575        options: &SpatialLengthScaleOptimizationOptions,
2576    ) -> Self {
2577        let mut vals = Vec::new();
2578        let mut dims = Vec::new();
2579        for &term_idx in term_indices {
2580            // Measure-jet: dial coordinates seeded directly from the term's
2581            // realized (α, τ[, s]); the −ln(length_scale) convention below is
2582            // κ-semantics and never applies to dials.
2583            if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
2584                let seed = measure_jet_psi_seed(mj);
2585                dims.push(seed.len());
2586                vals.extend(seed);
2587                continue;
2588            }
2589            // Constant-curvature: one signed κ slot seeded from the spec's κ
2590            // (clamped feasible). The −ln(length_scale) convention below is
2591            // log-κ semantics and must not touch the raw-κ coordinate. Bounds
2592            // are unavailable here (no data view), so this is the raw spec κ;
2593            // `reseed_from_data` / `clamp_to_bounds` later project it feasible.
2594            if let Some(cc) = constant_curvature_term_spec(spec, term_idx) {
2595                vals.push(cc.kappa);
2596                dims.push(1);
2597                continue;
2598            }
2599            let length_scale = get_spatial_length_scale(spec, term_idx)
2600                .unwrap_or(options.min_length_scale)
2601                .clamp(options.min_length_scale, options.max_length_scale);
2602            let psi_bar = -length_scale.ln(); // global scale = −ln(length_scale)
2603
2604            if spatial_term_uses_per_axis_psi(spec, term_idx) {
2605                // Per-axis anisotropy is enrolled in the joint outer vector:
2606                // ψ_a = ψ̄ + η_a, one slot per axis. The hyper_dirs builder
2607                // produces matching per-axis derivatives in
2608                // `try_build_spatial_term_log_kappa_aniso_derivativeinfos`.
2609                let d = get_spatial_feature_dim(spec, term_idx).unwrap_or(1);
2610                let eta_raw = get_spatial_aniso_log_scales(spec, term_idx)
2611                    .expect("predicate guarantees aniso_log_scales is Some");
2612                let eta = center_aniso_log_scales(&eta_raw);
2613                for &eta_a in &eta {
2614                    vals.push(psi_bar + eta_a);
2615                }
2616                dims.push(d);
2617            } else {
2618                // Isotropic enrollment — either a 1-D term, a multi-D term
2619                // without explicit anisotropy, or a basis (e.g. Duchon) whose
2620                // η is a fixed geometry parameter rather than a REML hyper
2621                // axis. Exactly one ψ̄ slot, matching the single
2622                // `SpatialPsiDerivative` produced by
2623                // `try_build_spatial_term_log_kappa_derivativeinfo`.
2624                vals.push(psi_bar);
2625                dims.push(1);
2626            }
2627        }
2628        Self {
2629            values: Array1::from_vec(vals),
2630            dims_per_term: dims,
2631        }
2632    }
2633
2634    /// Isotropic lower bounds derived from per-term data geometry.
2635    /// Each entry gets the ψ_lo bound returned by `spatial_term_psi_bounds`
2636    /// for the corresponding term, intersected with the options window.
2637    pub fn lower_bounds_from_data(
2638        data: ArrayView2<'_, f64>,
2639        spec: &TermCollectionSpec,
2640        term_indices: &[usize],
2641        options: &SpatialLengthScaleOptimizationOptions,
2642    ) -> Self {
2643        let mut values = Array1::<f64>::zeros(term_indices.len());
2644        for (slot, &term_idx) in term_indices.iter().enumerate() {
2645            values[slot] = spatial_term_psi_bounds(data, spec, term_idx, options).0;
2646        }
2647        Self {
2648            values,
2649            dims_per_term: vec![1; term_indices.len()],
2650        }
2651    }
2652
2653    /// Isotropic upper bounds derived from per-term data geometry.
2654    pub fn upper_bounds_from_data(
2655        data: ArrayView2<'_, f64>,
2656        spec: &TermCollectionSpec,
2657        term_indices: &[usize],
2658        options: &SpatialLengthScaleOptimizationOptions,
2659    ) -> Self {
2660        let mut values = Array1::<f64>::zeros(term_indices.len());
2661        for (slot, &term_idx) in term_indices.iter().enumerate() {
2662            values[slot] = spatial_term_psi_bounds(data, spec, term_idx, options).1;
2663        }
2664        Self {
2665            values,
2666            dims_per_term: vec![1; term_indices.len()],
2667        }
2668    }
2669
2670    /// Anisotropic-aware lower bounds derived from per-term data geometry.
2671    /// For hybrid anisotropic terms the scalar ψ_lo bound applies to the
2672    /// mean `ψ̄`, not directly to every raw axis coordinate `ψ_a = ψ̄ + η_a`.
2673    /// Shift each axis by the current centered `η_a` so projecting/clamping
2674    /// the seed moves only the global scale direction and does not silently
2675    /// shrink anisotropy that is already consistent with the current
2676    /// `length_scale`.
2677    ///
2678    /// Pure Duchon anisotropy is structurally different: its stored
2679    /// coordinates are (d-1) free η_a values representing log axis-scale
2680    /// ratios, NOT log-κ. For those terms the κ-range geometry bound is
2681    /// over-restrictive (η_a = ±5 is normal, but that corresponds to 7+
2682    /// orders of magnitude in κ-space and would be rejected by the data
2683    /// window). Fall back to the options window `[-ln(max_ls), -ln(min_ls)]`
2684    /// for those coordinates — that's the same bound the pre-data-geometry
2685    /// code used, which is calibrated to allow legitimate anisotropy.
2686    pub fn lower_bounds_aniso_from_data(
2687        data: ArrayView2<'_, f64>,
2688        spec: &TermCollectionSpec,
2689        term_indices: &[usize],
2690        dims_per_term: &[usize],
2691        options: &SpatialLengthScaleOptimizationOptions,
2692    ) -> Self {
2693        Self::aniso_bounds_from_data(
2694            data,
2695            spec,
2696            term_indices,
2697            dims_per_term,
2698            options,
2699            AnisoBoundEnd::Lower,
2700        )
2701    }
2702
2703    /// Anisotropic-aware upper bounds derived from per-term data geometry.
2704    /// See `lower_bounds_aniso_from_data` for the hybrid-aniso offsetting and
2705    /// pure-Duchon dispatch rationale.
2706    pub fn upper_bounds_aniso_from_data(
2707        data: ArrayView2<'_, f64>,
2708        spec: &TermCollectionSpec,
2709        term_indices: &[usize],
2710        dims_per_term: &[usize],
2711        options: &SpatialLengthScaleOptimizationOptions,
2712    ) -> Self {
2713        Self::aniso_bounds_from_data(
2714            data,
2715            spec,
2716            term_indices,
2717            dims_per_term,
2718            options,
2719            AnisoBoundEnd::Upper,
2720        )
2721    }
2722
2723    /// Shared implementation for the lower/upper aniso bounds. The bound end
2724    /// only changes which options scale (`max_length_scale` vs
2725    /// `min_length_scale`) becomes the pure-Duchon fallback bound and which
2726    /// element of the `(lo, hi)` data-geometry tuple is consumed; the
2727    /// per-term cursor walk and aniso-offset handling are identical.
2728    fn aniso_bounds_from_data(
2729        data: ArrayView2<'_, f64>,
2730        spec: &TermCollectionSpec,
2731        term_indices: &[usize],
2732        dims_per_term: &[usize],
2733        options: &SpatialLengthScaleOptimizationOptions,
2734        end: AnisoBoundEnd,
2735    ) -> Self {
2736        assert_eq!(term_indices.len(), dims_per_term.len());
2737        let total: usize = dims_per_term.iter().sum();
2738        let mut values = Array1::<f64>::zeros(total);
2739        let mut cursor = 0;
2740        for (slot, &term_idx) in term_indices.iter().enumerate() {
2741            let d = dims_per_term[slot];
2742            // Measure-jet: per-coordinate dial boxes, never κ-window geometry
2743            // (which would reject legitimate dial values outright).
2744            if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
2745                let bounds = measure_jet_psi_bound_values(mj, matches!(end, AnisoBoundEnd::Upper));
2746                for (offset, bound) in bounds.into_iter().enumerate() {
2747                    if offset < d {
2748                        values[cursor + offset] = bound;
2749                    }
2750                }
2751                cursor += d;
2752                continue;
2753            }
2754            // Constant-curvature: the single signed-κ box from the data chart
2755            // window (symmetric about κ = 0), never a κ = log-scale window.
2756            if constant_curvature_term_spec(spec, term_idx).is_some() {
2757                let (lo, hi) = constant_curvature_kappa_bounds(data, spec, term_idx);
2758                if d >= 1 {
2759                    values[cursor] = match end {
2760                        AnisoBoundEnd::Lower => lo,
2761                        AnisoBoundEnd::Upper => hi,
2762                    };
2763                }
2764                cursor += d;
2765                continue;
2766            }
2767            let psi_bound = {
2768                let (lo, hi) = spatial_term_psi_bounds(data, spec, term_idx, options);
2769                match end {
2770                    AnisoBoundEnd::Lower => lo,
2771                    AnisoBoundEnd::Upper => hi,
2772                }
2773            };
2774            let axis_offsets = if d <= 1 {
2775                vec![0.0; d]
2776            } else {
2777                get_spatial_aniso_log_scales(spec, term_idx)
2778                    .filter(|eta| eta.len() == d)
2779                    .map(|eta| center_aniso_log_scales(&eta))
2780                    .unwrap_or_else(|| vec![0.0; d])
2781            };
2782            for offset in 0..d {
2783                values[cursor + offset] = psi_bound + axis_offsets[offset];
2784            }
2785            cursor += d;
2786        }
2787        Self {
2788            values,
2789            dims_per_term: dims_per_term.to_vec(),
2790        }
2791    }
2792
2793    /// Rewrite any ψ entries whose originating term lacks an explicit
2794    /// `length_scale` so they sit at the midpoint of the per-term data-derived
2795    /// ψ window. Used so the outer optimizer starts inside the physically
2796    /// meaningful region instead of at an arbitrary `options.max_length_scale`
2797    /// derived seed. For terms with an explicit length_scale, the user's
2798    /// choice is respected. Anisotropy offsets η_a (those stored by
2799    /// `from_length_scales_aniso`) are preserved: we re-center around the new
2800    /// ψ̄, keeping Ση_a = 0.
2801    pub fn reseed_from_data(
2802        mut self,
2803        data: ArrayView2<'_, f64>,
2804        spec: &TermCollectionSpec,
2805        term_indices: &[usize],
2806        options: &SpatialLengthScaleOptimizationOptions,
2807    ) -> Self {
2808        assert_eq!(term_indices.len(), self.dims_per_term.len());
2809        let mut cursor = 0;
2810        for (slot, &term_idx) in term_indices.iter().enumerate() {
2811            let d = self.dims_per_term[slot];
2812            // Measure-jet dials are seeded from the realized spec and must
2813            // not be recentered into a κ data window.
2814            if measure_jet_term_spec(spec, term_idx).is_some() {
2815                cursor += d;
2816                continue;
2817            }
2818            // Constant-curvature κ is seeded from the spec (the user's curvature
2819            // hint, default κ = 0); `clamp_to_bounds` projects it feasible. It
2820            // is not a log-scale, so the log-κ recenter below never applies.
2821            if constant_curvature_term_spec(spec, term_idx).is_some() {
2822                cursor += d;
2823                continue;
2824            }
2825            let Some(psi_bar_new) = spatial_term_psi_seed(data, spec, term_idx, options) else {
2826                cursor += d;
2827                continue;
2828            };
2829            if d == 0 {
2830                continue;
2831            }
2832            let current: Vec<f64> = self.values.slice(s![cursor..cursor + d]).to_vec();
2833            let psi_bar_old = current.iter().sum::<f64>() / d as f64;
2834            for (offset, &old_value) in current.iter().enumerate() {
2835                self.values[cursor + offset] = psi_bar_new + (old_value - psi_bar_old);
2836            }
2837            cursor += d;
2838        }
2839        self
2840    }
2841
2842    /// Project ψ values into `[lower, upper]` element-wise. Used after
2843    /// `from_length_scales*` + `reseed_from_data` when a user-supplied
2844    /// `spec.length_scale` falls outside the data-derived ψ window set by
2845    /// `{lower,upper}_bounds*_from_data`. BFGS requires theta0 ∈ [lower,
2846    /// upper]; projecting is the unique closest feasible seed. The user's
2847    /// length_scale was always a hint for the outer optimizer (the optimizer
2848    /// is authoritative for κ), not a hard constraint — so clipping preserves
2849    /// their intent as far as the geometry allows. Emits `log::info!` when
2850    /// any coordinate moves, so the outside-window case is diagnostically
2851    /// visible (not silent).
2852    pub fn clamp_to_bounds(
2853        mut self,
2854        lower: &SpatialLogKappaCoords,
2855        upper: &SpatialLogKappaCoords,
2856    ) -> Self {
2857        assert_eq!(self.values.len(), lower.values.len());
2858        assert_eq!(self.values.len(), upper.values.len());
2859        let mut n_projected = 0usize;
2860        let mut worst_delta = 0.0_f64;
2861        for idx in 0..self.values.len() {
2862            let lo = lower.values[idx];
2863            let hi = upper.values[idx];
2864            if !(lo.is_finite() && hi.is_finite()) {
2865                continue;
2866            }
2867            let v = self.values[idx];
2868            if v < lo {
2869                worst_delta = worst_delta.max(lo - v);
2870                self.values[idx] = lo;
2871                n_projected += 1;
2872            } else if v > hi {
2873                worst_delta = worst_delta.max(v - hi);
2874                self.values[idx] = hi;
2875                n_projected += 1;
2876            }
2877        }
2878        if n_projected > 0 {
2879            log::info!(
2880                "[spatial-kappa] projected {n_projected}/{} ψ seed coords into data-derived bounds \
2881                 (worst excess={worst_delta:.3} log units); user length_scale falls outside \
2882                 [{KERNEL_RANGE_MIN_DIAMETER_FRACTION}/r_max, {KERNEL_RANGE_MAX_SPACING_MULTIPLE}/r_min] geometry window",
2883                self.values.len()
2884            );
2885        }
2886        self
2887    }
2888
2889    /// Reconstruct from theta tail with known dimensionality layout.
2890    pub fn from_theta_tail_with_dims(
2891        theta: &Array1<f64>,
2892        start: usize,
2893        dims_per_term: Vec<usize>,
2894    ) -> Self {
2895        let total: usize = dims_per_term.iter().sum();
2896        Self {
2897            values: theta.slice(s![start..start + total]).to_owned(),
2898            dims_per_term,
2899        }
2900    }
2901
2902    /// Total number of ψ values in the flat array (= sum of dims_per_term).
2903    pub fn len(&self) -> usize {
2904        self.values.len()
2905    }
2906
2907    /// Dimensionality layout: how many ψ values each term contributes.
2908    pub fn dims_per_term(&self) -> &[usize] {
2909        &self.dims_per_term
2910    }
2911
2912    /// Get the offset into the flat array for logical term i.
2913    fn term_offset(&self, term_idx: usize) -> usize {
2914        self.dims_per_term[..term_idx].iter().sum()
2915    }
2916
2917    /// Get the slice of ψ values for logical term i.
2918    pub fn term_slice(&self, term_idx: usize) -> &[f64] {
2919        let offset = self.term_offset(term_idx);
2920        let d = self.dims_per_term[term_idx];
2921        &self.values.as_slice().unwrap()[offset..offset + d]
2922    }
2923
2924    pub fn as_array(&self) -> &Array1<f64> {
2925        &self.values
2926    }
2927
2928    /// #1464: overwrite the single ψ value of a scalar (1-D) logical term by its
2929    /// position `slot` in this coords vector (the same ordering as the
2930    /// `term_indices` slice the constructors were built from). Used to inject the
2931    /// fixed-κ sign-basin seed into a constant-curvature term's raw-κ slot before
2932    /// the joint solve. No-op (returns `false`) when the slot is not scalar.
2933    pub fn set_scalar_slot(&mut self, slot: usize, value: f64) -> bool {
2934        if slot >= self.dims_per_term.len() || self.dims_per_term[slot] != 1 {
2935            return false;
2936        }
2937        let offset = self.term_offset(slot);
2938        self.values[offset] = value;
2939        true
2940    }
2941
2942    /// Split at a logical-term boundary. `mid` is the number of terms in the
2943    /// first half (not a flat-array index).
2944    pub fn split_at(&self, mid: usize) -> (Self, Self) {
2945        let flat_mid: usize = self.dims_per_term[..mid].iter().sum();
2946        (
2947            Self {
2948                values: self.values.slice(s![0..flat_mid]).to_owned(),
2949                dims_per_term: self.dims_per_term[..mid].to_vec(),
2950            },
2951            Self {
2952                values: self.values.slice(s![flat_mid..]).to_owned(),
2953                dims_per_term: self.dims_per_term[mid..].to_vec(),
2954            },
2955        )
2956    }
2957
2958    /// Apply optimized ψ values back to the spec.
2959    ///
2960    /// For isotropic terms (dims=1): sets scalar length_scale = exp(−ψ).
2961    /// For anisotropic terms (dims=d): hybrid/isotropic families set
2962    /// length_scale = exp(−ψ̄) with centered η_a = ψ_a − ψ̄, while pure Duchon
2963    /// writes only centered η_a and leaves length_scale = None.
2964    pub fn apply_tospec(
2965        &self,
2966        spec: &TermCollectionSpec,
2967        term_indices: &[usize],
2968    ) -> Result<TermCollectionSpec, EstimationError> {
2969        if term_indices.len() != self.dims_per_term.len() {
2970            crate::bail_invalid_estim!(
2971                "SpatialLogKappaCoords::apply_tospec: term count mismatch: \
2972                 term_indices={} dims_per_term={}",
2973                term_indices.len(),
2974                self.dims_per_term.len()
2975            );
2976        }
2977        let mut updated = spec.clone();
2978        for (slot, &term_idx) in term_indices.iter().enumerate() {
2979            let psi = self.term_slice(slot);
2980            let d = self.dims_per_term[slot];
2981            // Measure-jet: write the dial coordinates straight back; the
2982            // κ-translation below would misread them as log-scales.
2983            if measure_jet_term_spec(&updated, term_idx).is_some() {
2984                set_measure_jet_psi_dials(&mut updated, term_idx, psi)?;
2985                continue;
2986            }
2987            // Constant-curvature: write the optimized signed κ straight back;
2988            // the −exp(ψ) length-scale translation below is log-κ semantics and
2989            // would misread the raw curvature.
2990            if constant_curvature_term_spec(&updated, term_idx).is_some() {
2991                set_constant_curvature_kappa(&mut updated, term_idx, psi)?;
2992                continue;
2993            }
2994            let (next_length_scale, next_aniso) = spatial_term_psi_to_length_scale_and_aniso(psi);
2995            if (d == 1 || next_length_scale.is_some())
2996                && let Some(length_scale) = next_length_scale
2997            {
2998                set_spatial_length_scale(&mut updated, term_idx, length_scale)?;
2999            }
3000            if let Some(eta) = next_aniso {
3001                set_spatial_aniso_log_scales(&mut updated, term_idx, eta)?;
3002            }
3003        }
3004        Ok(updated)
3005    }
3006}
3007
/// Subtract the arithmetic mean from a set of anisotropy log-scales.
///
/// Inputs with zero or one element are returned unchanged. For longer inputs
/// each entry becomes `v - mean`; results whose magnitude is at most `1e-15`
/// are snapped to exactly `0.0` so round-off does not masquerade as
/// anisotropy. Non-finite inputs propagate (a NaN mean yields NaN entries).
pub fn center_aniso_log_scales(eta: &[f64]) -> Vec<f64> {
    let n = eta.len();
    if n < 2 {
        return eta.to_vec();
    }
    let mean = eta.iter().sum::<f64>() / n as f64;
    let mut centered = Vec::with_capacity(n);
    for &value in eta {
        let delta = value - mean;
        // `<=` (not `>` on the other branch) so that NaN falls through as NaN.
        centered.push(if delta.abs() <= 1e-15 { 0.0 } else { delta });
    }
    centered
}
3024
3025/// Whether a spatial term contributes per-axis ψ entries to the outer joint
3026/// hyperparameter vector.
3027pub fn spatial_term_uses_per_axis_psi(resolvedspec: &TermCollectionSpec, term_idx: usize) -> bool {
3028    if let Some(mj) = measure_jet_term_spec(resolvedspec, term_idx) {
3029        return measure_jet_enrolls_psi(mj);
3030    }
3031    let Some(d) = get_spatial_feature_dim(resolvedspec, term_idx) else {
3032        return false;
3033    };
3034    if d <= 1 {
3035        return false;
3036    }
3037    let Some(eta) = get_spatial_aniso_log_scales(resolvedspec, term_idx) else {
3038        return false;
3039    };
3040    if eta.len() != d {
3041        return false;
3042    }
3043    !matches!(
3044        resolvedspec.smooth_terms.get(term_idx).map(|term| &term.basis),
3045        Some(SmoothBasisSpec::Duchon { .. })
3046    )
3047}
3048
3049pub fn set_spatial_length_scale(
3050    spec: &mut TermCollectionSpec,
3051    term_idx: usize,
3052    length_scale: f64,
3053) -> Result<(), EstimationError> {
3054    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3055        crate::bail_invalid_estim!("spatial length-scale term index {term_idx} out of range");
3056    };
3057    match &mut term.basis {
3058        SmoothBasisSpec::ThinPlate { spec, .. } => {
3059            spec.length_scale = length_scale;
3060            Ok(())
3061        }
3062        SmoothBasisSpec::Matern { spec, .. } => {
3063            spec.length_scale = length_scale;
3064            Ok(())
3065        }
3066        SmoothBasisSpec::Duchon { spec, .. } => {
3067            spec.length_scale = Some(length_scale);
3068            Ok(())
3069        }
3070        _ => Err(EstimationError::InvalidInput(format!(
3071            "term '{}' does not expose a spatial length scale",
3072            term.name
3073        ))),
3074    }
3075}
3076
3077pub fn get_spatial_length_scale(spec: &TermCollectionSpec, term_idx: usize) -> Option<f64> {
3078    spec.smooth_terms
3079        .get(term_idx)
3080        .and_then(|term| match &term.basis {
3081            SmoothBasisSpec::ThinPlate { spec, .. } => Some(spec.length_scale),
3082            SmoothBasisSpec::Matern { spec, .. } => Some(spec.length_scale),
3083            SmoothBasisSpec::Duchon { spec, .. } => spec.length_scale,
3084            _ => None,
3085        })
3086}
3087
3088pub fn spatial_term_supports_hyper_optimization(spec: &TermCollectionSpec, term_idx: usize) -> bool {
3089    // Ordinary penalized thin-plate regression splines do not have an
3090    // identifiable kernel scale once REML is already learning the smoothing
3091    // penalty. Treat the resolved length scale as fixed geometry; enrolling a
3092    // scalar TPS kappa axis creates the flat ρ/κ valleys reported in #718,
3093    // #721, #731, and #732.
3094    if let Some(term) = spec.smooth_terms.get(term_idx)
3095        && let SmoothBasisSpec::ThinPlate { .. } = &term.basis
3096    {
3097        return false;
3098    }
3099
3100    // Duchon anisotropy η is a FIXED, geometry-derived basis parameter, NOT a
3101    // REML hyper axis: the metric is estimated once from the knot-cloud spread
3102    // (`auto_seed_aniso_contrasts`, applied on every Duchon basis build) and the
3103    // Hilbert-scale λ's carry all learned smoothness. So a pure Duchon (no κ)
3104    // contributes no outer optimization axis even when `scale_dims` is on —
3105    // "standardize the geometry, then learn the smoothness." Only an explicit
3106    // kernel length scale κ (the Matérn / hybrid path) is optimized here.
3107    //
3108    // ISOTROPIC Matérn: the *default* `matern(x1, x2)` is isotropic
3109    // (`scale_dims=false` → `aniso_log_scales = None`). It contributes exactly
3110    // ONE κ optimization axis — its scalar log-κ. The shared GAMLSS /
3111    // location-scale exact-joint ψ engine and the spatial-κ joint outer solver
3112    // both require an isotropic Matérn block to expose this single isotropic κ
3113    // axis (#822/#851); without it the per-block ψ-derivative lists are empty
3114    // and the joint-ψ hooks degenerate to `None`. The isotropic κ is the lone
3115    // kernel hyper axis here, mirroring the per-axis ψ ARD that the anisotropic
3116    // path exposes (just collapsed to one dimension).
3117    //
3118    // ANISOTROPIC Matérn (`scale_dims=true` → `aniso_log_scales = Some`) keeps
3119    // its per-axis kernel-η ARD: the d-dimensional ψ search is the *point* of
3120    // the anisotropic request ("Matérn keeps its kernel-η ARD").
3121    //
3122    // Either way a Matérn term always enrolls a κ/ψ axis (1 isotropic, or d
3123    // anisotropic), so `spatial_dims_per_term` reports the correct count.
3124    if let Some(term) = spec.smooth_terms.get(term_idx)
3125        && let SmoothBasisSpec::Matern { .. } = &term.basis
3126    {
3127        return true;
3128    }
3129
3130    // Measure-jet geometry dials are outer ψ coordinates; enrollment is
3131    // owned by `measure_jet_enrolls_psi`.
3132    if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
3133        return measure_jet_enrolls_psi(mj);
3134    }
3135
3136    // Constant-curvature smooths always enroll their single signed curvature κ
3137    // as an outer ψ-coordinate (#944 stage 3): κ̂ is the headline estimand, so
3138    // unlike a fixed-ℓ kernel it is fitted by default, not gated on a
3139    // user-supplied scale. The coordinate is raw κ (interior κ = 0), and its
3140    // exact design/penalty κ-derivatives come from
3141    // `build_constant_curvature_basis_kappa_derivatives`.
3142    if constant_curvature_term_spec(spec, term_idx).is_some() {
3143        return true;
3144    }
3145
3146    get_spatial_length_scale(spec, term_idx).is_some()
3147}
3148
3149/// The measure-jet term's spec, when `term_idx` is a measure-jet smooth.
3150/// Single accessor for every dial-plumbing dispatch below.
3151pub fn measure_jet_term_spec(
3152    spec: &TermCollectionSpec,
3153    term_idx: usize,
3154) -> Option<&crate::basis::MeasureJetBasisSpec> {
3155    spec.smooth_terms
3156        .get(term_idx)
3157        .and_then(|term| match &term.basis {
3158            SmoothBasisSpec::MeasureJet { spec, .. } => Some(spec),
3159            _ => None,
3160        })
3161}
3162
3163/// Single source for measure-jet outer-ψ enrollment: the lnτ dial is
3164/// undefined in the τ = 0 pseudo-inverse oracle mode (see
3165/// `build_measure_jet_basis_psi_derivatives`), so only a positive ridge
3166/// enrolls the dial group. `spatial_term_supports_hyper_optimization` and
3167/// `spatial_term_uses_per_axis_psi` both defer here so the θ-layout
3168/// sources cannot disagree.
3169pub fn measure_jet_enrolls_psi(mj: &crate::basis::MeasureJetBasisSpec) -> bool {
3170    // Two independent enrollment sources (#1116), both explicit:
3171    //   * the design-moving representer length-scale ℓ (`learn_length_scale`),
3172    //     available in every mode when the spec opts in;
3173    //   * the multiscale penalty dials (s, α, lnτ): the per-scale spectral
3174    //     split's (α, lnτ) ride the explicit `multiscale` opt-in, and the lnτ
3175    //     channel additionally needs a positive ridge (τ = 0 is the
3176    //     pseudo-inverse oracle mode where lnτ is undefined).
3177    // A term enrolls if EITHER source is active.
3178    measure_jet_learns_length_scale(mj)
3179        || (mj.tau0 > 0.0 && crate::basis::measure_jet_multiscale_mode(mj))
3180}
3181
3182/// Whether the design-moving ℓ dial is enrolled for this term. ℓ is fixed by
3183/// default and learnable in every mode only when `learn_length_scale = true`.
3184pub fn measure_jet_learns_length_scale(mj: &crate::basis::MeasureJetBasisSpec) -> bool {
3185    mj.learn_length_scale
3186}
3187
3188pub fn freeze_measure_jet_length_scale_learning(spec: &mut TermCollectionSpec) -> usize {
3189    let mut frozen = 0;
3190    for term in spec.smooth_terms.iter_mut() {
3191        if let SmoothBasisSpec::MeasureJet { spec: mj, .. } = &mut term.basis
3192            && mj.learn_length_scale
3193        {
3194            mj.learn_length_scale = false;
3195            frozen += 1;
3196        }
3197    }
3198    frozen
3199}
3200
/// `(lower, upper)` box for the measure-jet `α` dial.
///
/// Measure-jet ψ dials are NOT log-kernel-scales, so the κ-window machinery
/// never applies to them; each dial gets a fixed box instead. `α` spans
/// density-weighted (0) through past-Coifman–Lafon (>1) normalization; note
/// that the box itself is `[-1, 3]`, so it also admits values below 0. The
/// companion `lnτ` box (`MEASURE_JET_PSI_LN_TAU_BOUNDS`) covers the ridge from
/// numerically-exact-projection to heavy noise-floor damping. (The energy
/// order `s` is the pinned explicit value or absorbed by the REML-learned
/// per-scale amplitudes — see `measure_jet_penalty_psi_dim` — so it carries no
/// dial box.)
pub const MEASURE_JET_PSI_ALPHA_BOUNDS: (f64, f64) = (-1.0, 3.0);
3209
/// `(lower, upper)` box for the measure-jet `lnτ` ridge dial:
/// `ln(1e-8) = -18.4206…` up to `ln(1e2) = 4.6051…`.
pub const MEASURE_JET_PSI_LN_TAU_BOUNDS: (f64, f64) = (-18.420680743952367, 4.605170185988092);
3211
/// `(lower, upper)` log-ℓ box for the design-moving representer length-scale
/// dial (#1116).
///
/// An ABSOLUTE window in the data coordinate scale (ln of ℓ ∈ [1e-3, 1e2]),
/// used only when the spec explicitly enrolls the learned representer range.
/// It is absolute rather than seed-relative so that the bound producer needs
/// no data view, matching the other dial boxes.
/// `ln(1e-3) = -6.9077…`, `ln(1e2) = 4.6051…`.
pub const MEASURE_JET_PSI_LN_LENGTH_SCALE_BOUNDS: (f64, f64) = (-6.907755278982137, 4.605170185988092);
3218
3219/// Number of multiscale PENALTY dials (excluding the design-moving ℓ):
3220/// multiscale (per-scale spectral) mode carries (α, lnτ) = 2 — the order is
3221/// either the pinned explicit `s` or absorbed by the REML-learned per-scale
3222/// amplitudes, so it is NOT a dial; single-scale (the default) carries none.
3223/// MUST agree with the penalty-coordinate layout of
3224/// `build_measure_jet_basis_psi_derivatives` (its `per_level` branch always
3225/// emits exactly the (α, lnτ) coordinate pair).
3226pub fn measure_jet_penalty_psi_dim(mj: &crate::basis::MeasureJetBasisSpec) -> usize {
3227    if crate::basis::measure_jet_multiscale_mode(mj) {
3228        2
3229    } else {
3230        0
3231    }
3232}
3233
3234/// ψ dimension of a measure-jet term. The design-moving ℓ dial (when enrolled)
3235/// is coordinate 0; the multiscale penalty dials follow. MUST agree with the
3236/// coordinate layout of `build_measure_jet_basis_psi_derivatives` (ℓ first).
3237pub fn measure_jet_psi_dim(mj: &crate::basis::MeasureJetBasisSpec) -> usize {
3238    usize::from(measure_jet_learns_length_scale(mj)) + measure_jet_penalty_psi_dim(mj)
3239}
3240
3241/// Seed ψ from the term's realized dials, in producer coordinate order: ℓ first
3242/// (when enrolled), then the multiscale penalty dials. The ℓ seed is the
3243/// realized representer range `ln(length_scale)` (the resolved spec carries the
3244/// concrete auto value after the design build/freeze).
3245pub fn measure_jet_psi_seed(mj: &crate::basis::MeasureJetBasisSpec) -> Vec<f64> {
3246    let mut seed = Vec::with_capacity(measure_jet_psi_dim(mj));
3247    if measure_jet_learns_length_scale(mj) {
3248        // length_scale > 0 after resolution; the 0.0 sentinel (pre-resolution)
3249        // falls back to the centre of the log-ℓ box so the optimizer still
3250        // starts feasible and the first data-aware reseed corrects it.
3251        let ell = if mj.length_scale > 0.0 {
3252            mj.length_scale
3253        } else {
3254            1.0
3255        };
3256        seed.push(ell.ln());
3257    }
3258    if measure_jet_penalty_psi_dim(mj) > 0 {
3259        // Multiscale penalty dials, producer order: (α, lnτ).
3260        let ln_tau = mj.tau0.max(f64::MIN_POSITIVE).ln();
3261        seed.extend_from_slice(&[mj.alpha, ln_tau]);
3262    }
3263    seed
3264}
3265
3266/// One end of the per-coordinate dial boxes, in producer coordinate order
3267/// (ℓ first when enrolled, then the multiscale penalty dials).
3268pub fn measure_jet_psi_bound_values(mj: &crate::basis::MeasureJetBasisSpec, upper: bool) -> Vec<f64> {
3269    let pick = |b: (f64, f64)| if upper { b.1 } else { b.0 };
3270    let mut bounds = Vec::with_capacity(measure_jet_psi_dim(mj));
3271    if measure_jet_learns_length_scale(mj) {
3272        bounds.push(pick(MEASURE_JET_PSI_LN_LENGTH_SCALE_BOUNDS));
3273    }
3274    if measure_jet_penalty_psi_dim(mj) > 0 {
3275        // Multiscale penalty dials, producer order: (α, lnτ).
3276        bounds.push(pick(MEASURE_JET_PSI_ALPHA_BOUNDS));
3277        bounds.push(pick(MEASURE_JET_PSI_LN_TAU_BOUNDS));
3278    }
3279    bounds
3280}
3281
3282/// Write optimized ψ dials back into a measure-jet spec. Returns `true` when
3283/// any dial actually moved. The geometry (centers, masses, band, ℓ, z) is
3284/// ψ-FIXED by contract — only the dials change, so frozen-quadrature
3285/// rebuilds reproduce the identical penalty layout at the new dials.
3286pub fn apply_measure_jet_psi(
3287    mj: &mut crate::basis::MeasureJetBasisSpec,
3288    psi: &[f64],
3289) -> Result<bool, EstimationError> {
3290    if psi.len() != measure_jet_psi_dim(mj) {
3291        crate::bail_invalid_estim!(
3292            "measure-jet ψ write-back dimension mismatch: got {} values for a {}-dial term",
3293            psi.len(),
3294            measure_jet_psi_dim(mj)
3295        );
3296    }
3297    let mut changed = false;
3298    // Coordinate 0 (when enrolled) is the design-moving ln(ℓ); the multiscale
3299    // penalty dials follow. Same order as `measure_jet_psi_seed` and the
3300    // producer (`build_measure_jet_basis_psi_derivatives`).
3301    let mut cursor = 0usize;
3302    if measure_jet_learns_length_scale(mj) {
3303        let next_ell = psi[cursor].exp();
3304        cursor += 1;
3305        if !(next_ell.is_finite() && next_ell > 0.0) {
3306            crate::bail_invalid_estim!(
3307                "measure-jet ψ write-back produced a non-finite/non-positive length_scale (ℓ={next_ell})"
3308            );
3309        }
3310        if next_ell != mj.length_scale {
3311            mj.length_scale = next_ell;
3312            changed = true;
3313        }
3314    }
3315    if measure_jet_penalty_psi_dim(mj) > 0 {
3316        // Multiscale penalty dials, producer order: (α, lnτ). The order `s` is
3317        // not a dial (pinned explicit or absorbed by the per-scale amplitudes).
3318        let next_alpha = psi[cursor];
3319        let next_tau = psi[cursor + 1].exp();
3320        if !(next_alpha.is_finite() && next_tau.is_finite() && next_tau > 0.0) {
3321            crate::bail_invalid_estim!(
3322                "measure-jet ψ write-back produced non-finite dials (alpha={next_alpha}, tau={next_tau})"
3323            );
3324        }
3325        if next_alpha != mj.alpha {
3326            mj.alpha = next_alpha;
3327            changed = true;
3328        }
3329        if next_tau != mj.tau0 {
3330            mj.tau0 = next_tau;
3331            changed = true;
3332        }
3333    }
3334    Ok(changed)
3335}
3336
3337/// Collection-level measure-jet dial write-back (the `apply_tospec` /
3338/// realizer-side entry). Returns whether anything moved.
3339pub fn set_measure_jet_psi_dials(
3340    spec: &mut TermCollectionSpec,
3341    term_idx: usize,
3342    psi: &[f64],
3343) -> Result<bool, EstimationError> {
3344    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3345        crate::bail_invalid_estim!("measure-jet ψ write-back: term index {term_idx} out of range");
3346    };
3347    set_single_term_measure_jet_psi_dials(term, psi)
3348}
3349
3350/// Single-term dial write-back: the shared match+apply core, also used
3351/// directly on the cached per-trial build spec (whose caller has already
3352/// change-checked at the collection level and rebuilds regardless of the
3353/// moved flag).
3354pub fn set_single_term_measure_jet_psi_dials(
3355    term: &mut SmoothTermSpec,
3356    psi: &[f64],
3357) -> Result<bool, EstimationError> {
3358    let SmoothBasisSpec::MeasureJet { spec: mj, .. } = &mut term.basis else {
3359        crate::bail_invalid_estim!("measure-jet ψ write-back targeted a non-measure-jet term");
3360    };
3361    apply_measure_jet_psi(mj, psi)
3362}
3363
3364/// The constant-curvature smooth's spec, when `term_idx` is one. Single
3365/// accessor for every κ-ψ dispatch below, mirroring `measure_jet_term_spec`.
3366pub fn constant_curvature_term_spec(
3367    spec: &TermCollectionSpec,
3368    term_idx: usize,
3369) -> Option<&crate::basis::ConstantCurvatureBasisSpec> {
3370    spec.smooth_terms
3371        .get(term_idx)
3372        .and_then(|term| match &term.basis {
3373            SmoothBasisSpec::ConstantCurvature { spec, .. } => Some(spec),
3374            _ => None,
3375        })
3376}
3377
/// Fraction of the data's inverse squared chart radius used as the hard cap
/// on |κ|.
///
/// The κ-stereographic chart is valid for `1 + κ‖x‖² > 0`; at `|κ| = 1/R²`
/// (R² = max squared chart radius) the gauge `1 + κ‖x‖²` reaches the chart
/// edge for the farthest data point, so the optimizer is boxed to a safe
/// fraction of that scale on both sides (see
/// `constant_curvature_kappa_bounds`, which uses `±fraction / R²`). κ = 0
/// (flat) is the centre of the window, an interior point of the
/// `S^d ← ℝ^d → H^d` family — exactly the reachability the raw-κ (not log-κ)
/// coordinate exists to preserve.
pub const CONSTANT_CURVATURE_KAPPA_CHART_FRACTION: f64 = 0.5;
3386
/// Floor on the data's squared chart radius used to scale the κ window.
///
/// `constant_curvature_kappa_bounds` starts its running maximum at this
/// value, so a degenerate (near-origin) point cloud still yields a finite,
/// usable bracket rather than an unbounded one.
pub const CONSTANT_CURVATURE_MIN_CHART_RADIUS2: f64 = 1e-8;
3391
3392/// `(κ_min, κ_max)` outer-optimization window for a constant-curvature term,
3393/// derived from the data's maximum squared chart radius `R²` so the κ-jets
3394/// never leave the κ-stereographic chart. Symmetric about κ = 0:
3395/// `±CONSTANT_CURVATURE_KAPPA_CHART_FRACTION / R²`.
3396pub fn constant_curvature_kappa_bounds(
3397    data: ArrayView2<'_, f64>,
3398    spec: &TermCollectionSpec,
3399    term_idx: usize,
3400) -> (f64, f64) {
3401    let feature_cols = match spec.smooth_terms.get(term_idx).map(|t| &t.basis) {
3402        Some(SmoothBasisSpec::ConstantCurvature { feature_cols, .. }) => feature_cols,
3403        _ => return (-1.0, 1.0),
3404    };
3405    let mut max_r2 = CONSTANT_CURVATURE_MIN_CHART_RADIUS2;
3406    for row in data.outer_iter() {
3407        let mut r2 = 0.0_f64;
3408        for &c in feature_cols.iter() {
3409            if let Some(&v) = row.get(c)
3410                && v.is_finite()
3411            {
3412                r2 += v * v;
3413            }
3414        }
3415        if r2 > max_r2 {
3416            max_r2 = r2;
3417        }
3418    }
3419    let half = CONSTANT_CURVATURE_KAPPA_CHART_FRACTION / max_r2;
3420    (-half, half)
3421}
3422
3423/// Write the optimized κ back into a constant-curvature term spec. Returns
3424/// `true` when κ moved. Centers, ℓ, and the constraint transform `z` are
3425/// κ-FIXED by the basis κ-contract, so only `kappa` changes.
3426pub fn set_constant_curvature_kappa(
3427    spec: &mut TermCollectionSpec,
3428    term_idx: usize,
3429    psi: &[f64],
3430) -> Result<bool, EstimationError> {
3431    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3432        crate::bail_invalid_estim!(
3433            "constant-curvature κ write-back: term index {term_idx} out of range"
3434        );
3435    };
3436    set_single_term_constant_curvature_kappa(term, psi)
3437}
3438
3439/// Single-term κ write-back: the shared validate+apply core, also used directly
3440/// on the cached per-trial build spec in the incremental realizer (whose caller
3441/// has already change-checked at the collection level and rebuilds regardless
3442/// of the moved flag). Mirrors [`set_single_term_measure_jet_psi_dials`].
3443pub fn set_single_term_constant_curvature_kappa(
3444    term: &mut SmoothTermSpec,
3445    psi: &[f64],
3446) -> Result<bool, EstimationError> {
3447    if psi.len() != 1 {
3448        crate::bail_invalid_estim!(
3449            "constant-curvature κ write-back expects exactly one value, got {}",
3450            psi.len()
3451        );
3452    }
3453    let next_kappa = psi[0];
3454    if !next_kappa.is_finite() {
3455        crate::bail_invalid_estim!(
3456            "constant-curvature κ write-back produced a non-finite κ = {next_kappa}"
3457        );
3458    }
3459    let SmoothBasisSpec::ConstantCurvature { spec: cc, .. } = &mut term.basis else {
3460        crate::bail_invalid_estim!(
3461            "constant-curvature κ write-back targeted a non-constant-curvature term"
3462        );
3463    };
3464    if cc.kappa != next_kappa {
3465        cc.kappa = next_kappa;
3466        Ok(true)
3467    } else {
3468        Ok(false)
3469    }
3470}
3471
3472/// Returns `true` when a spatial term has NO outer optimization axes — i.e.
3473/// the user provided an explicit `length_scale` and the term does not enroll
3474/// REML-side per-axis ψ contrasts, so both the scalar κ and any fixed geometry
3475/// anisotropy are anchored.
3476///
3477/// This is the per-term predicate that distinguishes "fixed kernel scale"
3478/// from "optimize the kernel scale" within the family entry points that
3479/// want to honor an explicit user-supplied scale (e.g. Bernoulli
3480/// marginal-slope, where the joint-spatial outer solver otherwise spends
3481/// ~80 iters stalled on the user's chosen ρ at high gradient).
3482pub fn spatial_term_has_locked_kappa(spec: &TermCollectionSpec, term_idx: usize) -> bool {
3483    get_spatial_length_scale(spec, term_idx).is_some()
3484        && !spatial_term_uses_per_axis_psi(spec, term_idx)
3485}
3486
3487pub fn all_spatial_terms_kappa_fixed(spec: &TermCollectionSpec) -> bool {
3488    spec.smooth_terms.iter().enumerate().all(|(idx, _)| {
3489        !spatial_term_supports_hyper_optimization(spec, idx)
3490            || spatial_term_has_locked_kappa(spec, idx)
3491    })
3492}
3493
3494pub fn spatial_identifiability_policy(termspec: &SmoothTermSpec) -> Option<&SpatialIdentifiability> {
3495    match &termspec.basis {
3496        SmoothBasisSpec::ThinPlate { spec, .. } => Some(&spec.identifiability),
3497        SmoothBasisSpec::Duchon { spec, .. } => Some(&spec.identifiability),
3498        _ => None,
3499    }
3500}
3501
/// Standard deviation of the wide, weakly-informative symmetric `Normal` prior
/// placed on a relaxable double-penalty smooth's `DoublePenaltyNullspace`
/// selection coordinate when the fit is well-determined.
///
/// `is_nullspace_degeneracy_prior` recognizes that prior by exact equality
/// with this value (together with a zero mean), so changing the constant also
/// changes which priors are classified as the degeneracy prior.
pub const NULLSPACE_WELLDET_DEGENERACY_RHO_SD: f64 = 15.0;
3506
3507/// True iff `prior` is the well-determined double-penalty null-space
3508/// degeneracy prior placed on a `DoublePenaltyNullspace` selection coordinate.
3509pub fn is_nullspace_degeneracy_prior(prior: &gam_spec::RhoPrior) -> bool {
3510    matches!(
3511        prior,
3512        gam_spec::RhoPrior::Normal { mean, sd }
3513            if *mean == 0.0 && *sd == NULLSPACE_WELLDET_DEGENERACY_RHO_SD
3514    )
3515}
3516
// Per-term data-derived ψ = log κ bounds.
//
// The window uses the same safe operating range documented in
// `crate::basis::build_matern_basis` / `crate::basis::build_duchon_basis`:
//   κ ∈ [2 / r_max, 1e2 / r_min]
// where (r_min, r_max) are pairwise-distance extrema of the term's resolved
// centers (post-fit) or the standardized feature data columns (pre-fit).
// The two constants below are the numerators of that window; the bounds
// themselves are computed by `spatial_term_psi_bounds`.

/// Numerator of the lower κ edge of the data-derived kernel-range window:
/// `κ_lo = KERNEL_RANGE_MIN_DIAMETER_FRACTION / r_max`, with `r_max` the
/// maximum pairwise distance. Since κ = 1 / length_scale, this caps the
/// length scale at `r_max / 2`. The comment inside `spatial_term_psi_bounds`
/// gives the rationale: length scales substantially larger than the data
/// diameter make the radial columns nearly collinear with the polynomial
/// nullspace.
pub const KERNEL_RANGE_MIN_DIAMETER_FRACTION: f64 = 2.0;

/// Numerator of the upper κ edge of the data-derived kernel-range window:
/// `κ_hi = KERNEL_RANGE_MAX_SPACING_MULTIPLE / r_min`, with `r_min` the
/// minimum pairwise distance. Since κ = 1 / length_scale, this floors the
/// length scale at `r_min / 100`.
///
/// NOTE(review): the previous wording said that beyond `100/r_min` the radial
/// columns go nearly collinear with the polynomial nullspace. The comment in
/// `spatial_term_psi_bounds` attributes that collinearity to the opposite
/// (large length-scale) edge — confirm the intended rationale for this cap.
pub const KERNEL_RANGE_MAX_SPACING_MULTIPLE: f64 = 1e2;
3535
3536
3537/// Returns ψ-space bounds (ψ_lo = ln(κ_lo), ψ_hi = ln(κ_hi)).
3538///
3539/// When geometry is unavailable (e.g., fewer than 2 distinct points), falls
3540/// back to the scalar `options.min_length_scale` / `options.max_length_scale`
3541/// window so the outer optimizer never sees NaN bounds.
3542///
3543/// The returned window is intersected with the options window so user-set
3544/// `min_length_scale` / `max_length_scale` remain hard limits.
3545pub fn spatial_term_psi_bounds(
3546    data: ArrayView2<'_, f64>,
3547    spec: &TermCollectionSpec,
3548    term_idx: usize,
3549    options: &SpatialLengthScaleOptimizationOptions,
3550) -> (f64, f64) {
3551    let fallback = (
3552        -options.max_length_scale.ln(),
3553        -options.min_length_scale.ln(),
3554    );
3555    // Constant-curvature: the ψ coordinate is the raw signed κ, so its window is
3556    // the chart-feasible κ bracket, NOT a log-ℓ window. Mirrors the aniso bounds
3557    // path's `constant_curvature_kappa_bounds` branch so the isotropic
3558    // (non-aniso) seed clamp projects κ into the right interval.
3559    if constant_curvature_term_spec(spec, term_idx).is_some() {
3560        return constant_curvature_kappa_bounds(data, spec, term_idx);
3561    }
3562    let Some(term) = spec.smooth_terms.get(term_idx) else {
3563        return fallback;
3564    };
3565    // Prefer resolved centers (post-fit) since they live in the same standardized
3566    // space the kernel actually sees. Centers are capped at `default_num_centers`
3567    // (<=2000), so exact pairwise bounds are cheap (<4M ops). If centers are
3568    // not yet UserProvided, fall back to the standardized feature data columns
3569    // with the capped-sample path (O(K²·d), K=1024) — the sample is
3570    // conservative for κ bounds (see `pairwise_distance_bounds_sampled`
3571    // docs): it never excludes a feasible κ the exact method would include.
3572    //
3573    // Under anisotropy the kernel metric is y-space (y_a = exp(η_a) x_a),
3574    // so r_min/r_max must be y-space distances. This matters only when the
3575    // spec already carries calibrated η_a at setup time (e.g., warm-start
3576    // or refit paths); for fresh optimization η_a starts at 0 and y = x.
3577    let aniso = get_spatial_aniso_log_scales(spec, term_idx);
3578    let r_bounds = match spatial_term_center_strategy(term) {
3579        Some(CenterStrategy::UserProvided(centers)) if centers.nrows() >= 2 => {
3580            match aniso.as_deref() {
3581                Some(eta) if eta.len() == centers.ncols() => {
3582                    let y = points_in_aniso_y_space(centers.view(), eta);
3583                    pairwise_distance_bounds(y.view())
3584                }
3585                _ => pairwise_distance_bounds(centers.view()),
3586            }
3587        }
3588        _ => standardized_spatial_term_data(data, term)
3589            .ok()
3590            .and_then(|x| match aniso.as_deref() {
3591                Some(eta) if eta.len() == x.ncols() => {
3592                    let y = points_in_aniso_y_space(x.view(), eta);
3593                    pairwise_distance_bounds_sampled(y.view())
3594                }
3595                _ => pairwise_distance_bounds_sampled(x.view()),
3596            }),
3597    };
3598    let Some((r_min, r_max)) = r_bounds else {
3599        return fallback;
3600    };
3601    // Length scales substantially larger than the data diameter make radial
3602    // TPS/Matern columns nearly collinear with their polynomial nullspace.
3603    // The nullspace already carries constant/linear low-frequency structure,
3604    // so cap the kernel range at the diameter scale instead of letting the
3605    // optimizer enter a numerically degenerate basis geometry.
3606    let psi_lo_data = (KERNEL_RANGE_MIN_DIAMETER_FRACTION / r_max).ln();
3607    let psi_hi_data = (KERNEL_RANGE_MAX_SPACING_MULTIPLE / r_min).ln();
3608    // #1074: the Matérn-specific length-scale ceiling that used to live here was
3609    // deleted. It was masking, not fixing, the real defect: a hard upper bound on
3610    // the kernel range that pinned the κ-optimizer short rather than letting the
3611    // optimizer find the REML optimum. Matérn now shares the same generic geometry
3612    // window as Duchon / TPS (`KERNEL_RANGE_MIN_DIAMETER_FRACTION / r_max` floor,
3613    // `KERNEL_RANGE_MAX_SPACING_MULTIPLE / r_min` ceiling); the #1357 fully-flat
3614    // collapse corner is guarded by the EDF-collapse guard in
3615    // `spatial_optimization.rs`, which acts on the realized fit, not on a clamp.
3616    // Intersect with the options window so min/max_length_scale remain hard caps.
3617    let psi_lo = psi_lo_data.max(fallback.0);
3618    let psi_hi = psi_hi_data.min(fallback.1);
3619    if psi_lo >= psi_hi {
3620        // Degenerate intersection — fall back to the options window to keep the
3621        // outer optimizer from collapsing to a point.
3622        return fallback;
3623    }
3624    (psi_lo, psi_hi)
3625}
3626
3627/// Data-derived ψ seed for a spatial term when the user has not set an
3628/// explicit length_scale on its basis spec. Uses the geometric mean of the
3629/// data-informed kappa range (i.e., the midpoint of the ψ window).
3630pub fn spatial_term_psi_seed(
3631    data: ArrayView2<'_, f64>,
3632    spec: &TermCollectionSpec,
3633    term_idx: usize,
3634    options: &SpatialLengthScaleOptimizationOptions,
3635) -> Option<f64> {
3636    if get_spatial_length_scale(spec, term_idx).is_some() {
3637        return None; // user/spec-provided length_scale wins
3638    }
3639    let (psi_lo, psi_hi) = spatial_term_psi_bounds(data, spec, term_idx, options);
3640    Some(0.5 * (psi_lo + psi_hi))
3641}
3642
/// Splits a ψ vector into a scalar length scale and centered per-axis
/// contrasts.
///
/// * Zero or one entry: the length scale is `exp(-ψ)` (ψ taken as `0.0` for an
///   empty slice) and no anisotropy is returned.
/// * Two or more entries: the length scale is `exp(-ψ̄)` with ψ̄ the mean of
///   the entries, and the anisotropy is `ψ_a − ψ̄` for every axis.
pub fn spatial_term_psi_to_length_scale_and_aniso(psi: &[f64]) -> (Option<f64>, Option<Vec<f64>>) {
    if let [_, _, ..] = psi {
        let count = psi.len() as f64;
        let psi_bar = psi.iter().sum::<f64>() / count;
        let contrasts: Vec<f64> = psi.iter().map(|&value| value - psi_bar).collect();
        return (Some((-psi_bar).exp()), Some(contrasts));
    }
    let scalar_psi = match psi {
        [only] => *only,
        _ => 0.0,
    };
    (Some((-scalar_psi).exp()), None)
}
3654
3655/// Get the `aniso_log_scales` from a spatial term, if present.
3656pub fn get_spatial_aniso_log_scales(
3657    spec: &TermCollectionSpec,
3658    term_idx: usize,
3659) -> Option<Vec<f64>> {
3660    spec.smooth_terms
3661        .get(term_idx)
3662        .and_then(|term| match &term.basis {
3663            SmoothBasisSpec::Matern { spec, .. } => spec.aniso_log_scales.clone(),
3664            SmoothBasisSpec::Duchon { spec, .. } => spec.aniso_log_scales.clone(),
3665            _ => None,
3666        })
3667}
3668
3669/// Per-axis response-structure score for anisotropy seeding.
3670///
3671/// For each spatial axis `a`, sort the response `y` by the axis coordinate
3672/// `x_a` and measure the total squared successive variation of the sorted
3673/// response, `tv_a = Σ_i (y_{σ(i+1)} − y_{σ(i)})²` where `σ` orders rows by
3674/// `x_a`. An axis that carries real (possibly nonlinear) signal makes `y` vary
3675/// SMOOTHLY when the rows are walked in that axis's order, so `tv_a` is SMALL;
3676/// a pure-nuisance axis leaves `y` looking unordered, so `tv_a` is LARGE.
3677///
3678/// This deliberately does NOT use a linear correlation `corr(x_a, y)`: for an
3679/// odd, symmetric signal such as `sin(2·x1)` over a symmetric domain the linear
3680/// correlation is ~0 on the *signal* axis, which would misdirect the seed. The
3681/// total-variation-of-sorted-response score captures nonlinear association.
3682///
3683/// Returns `score_a = −½·ln(tv_a + ε)` (larger ⇒ more signal on axis `a`),
3684/// centered to sum to zero, or `None` when the data is degenerate (too few
3685/// rows, non-finite, or all axes equally (un)structured). The caller adds a
3686/// BOUNDED multiple of this to the geometry seed — it is a conservative nudge,
3687/// never a hard override.
3688pub fn response_aware_axis_contrasts(
3689    x: ndarray::ArrayView2<'_, f64>,
3690    y: ndarray::ArrayView1<'_, f64>,
3691) -> Option<Vec<f64>> {
3692    let n = x.nrows();
3693    let d = x.ncols();
3694    if d <= 1 || n < 4 || y.len() != n {
3695        return None;
3696    }
3697    if x.iter().any(|v| !v.is_finite()) || y.iter().any(|v| !v.is_finite()) {
3698        return None;
3699    }
3700    let mut scores = Vec::with_capacity(d);
3701    for a in 0..d {
3702        let mut order: Vec<usize> = (0..n).collect();
3703        let col = x.column(a);
3704        order.sort_by(|&i, &j| {
3705            col[i]
3706                .partial_cmp(&col[j])
3707                .unwrap_or(std::cmp::Ordering::Equal)
3708        });
3709        let mut tv = 0.0_f64;
3710        for w in order.windows(2) {
3711            let diff = y[w[1]] - y[w[0]];
3712            tv += diff * diff;
3713        }
3714        // ε guards against ln(0) on a perfectly flat / constant response.
3715        scores.push(-0.5 * (tv + 1e-12).ln());
3716    }
3717    if scores.iter().any(|v| !v.is_finite()) {
3718        return None;
3719    }
3720    let mean = scores.iter().sum::<f64>() / d as f64;
3721    let centered: Vec<f64> = scores.iter().map(|&s| s - mean).collect();
3722    // If every axis is equally structured the centered scores are ~0 and the
3723    // nudge is a no-op — return None so the geometry seed is used unchanged.
3724    if centered.iter().all(|&v| v.abs() < 1e-9) {
3725        return None;
3726    }
3727    Some(centered)
3728}
3729
3730/// Conservative, response-aware anisotropy seed nudge applied before the κ outer
3731/// loop. For each anisotropic spatial term it adds a BOUNDED multiple of the
3732/// per-axis response-structure contrast (`response_aware_axis_contrasts`) on top
3733/// of the existing geometry seed, so the optimizer starts in the correct basin
3734/// instead of at a response-blind near-symmetric point (the #1376 under-recovery
3735/// where a signal axis and a nuisance axis with equal coordinate spread seed to
3736/// ~[0,0]). The nudge is clamped to keep this a perturbation, never a hard
3737/// override, so shared aniso Matérn/Duchon fits cannot be destabilized by it.
3738pub fn apply_response_aware_anisotropy_seed(
3739    data: ArrayView2<'_, f64>,
3740    y: ndarray::ArrayView1<'_, f64>,
3741    spec: &mut TermCollectionSpec,
3742    spatial_terms: &[usize],
3743) {
3744    // Bound on the per-axis contrast nudge (in η units). One LN_2 ≈ 0.69 halves
3745    // the effective per-axis length scale; capping at LN_2 keeps the seed within
3746    // one optimizer log-step of the geometry seed while still breaking the
3747    // symmetric-seed trap.
3748    const MAX_NUDGE: f64 = std::f64::consts::LN_2;
3749    for &term_idx in spatial_terms {
3750        let Some(current_eta) = get_spatial_aniso_log_scales(spec, term_idx) else {
3751            continue;
3752        };
3753        let d = current_eta.len();
3754        if d <= 1 {
3755            continue;
3756        }
3757        let Some(term) = spec.smooth_terms.get(term_idx) else {
3758            continue;
3759        };
3760        let feature_cols = term.basis.structural_feature_cols();
3761        if feature_cols.len() != d {
3762            continue;
3763        }
3764        let Ok(x) = select_columns(data, &feature_cols) else {
3765            continue;
3766        };
3767        let Some(contrast) = response_aware_axis_contrasts(x.view(), y) else {
3768            continue;
3769        };
3770        let nudged: Vec<f64> = current_eta
3771            .iter()
3772            .zip(contrast.iter())
3773            .map(|(&eta_a, &c_a)| eta_a + c_a.clamp(-MAX_NUDGE, MAX_NUDGE))
3774            .collect();
3775        // `set_spatial_aniso_log_scales` re-centers to Σ η = 0. A term that does
3776        // not support aniso scales is silently skipped (the seed is optional).
3777        if let Err(err) = set_spatial_aniso_log_scales(spec, term_idx, nudged) {
3778            log::debug!(
3779                "[spatial-kappa] response-aware anisotropy seed skipped for term {term_idx}: {err}"
3780            );
3781        }
3782    }
3783}
3784
3785/// Get the number of feature columns (spatial dimensionality) for a spatial term.
3786pub fn get_spatial_feature_dim(spec: &TermCollectionSpec, term_idx: usize) -> Option<usize> {
3787    spec.smooth_terms
3788        .get(term_idx)
3789        .and_then(|term| match &term.basis {
3790            SmoothBasisSpec::ThinPlate { feature_cols, .. } => Some(feature_cols.len()),
3791            SmoothBasisSpec::Matern { feature_cols, .. } => Some(feature_cols.len()),
3792            SmoothBasisSpec::Duchon { feature_cols, .. } => Some(feature_cols.len()),
3793            _ => None,
3794        })
3795}
3796
3797/// Log the learned per-axis spatial anisotropy for all spatial terms that
3798/// have `aniso_log_scales` set after optimization.
3799///
3800/// For scalar-scale families this reports eta, effective per-axis length
3801/// scales, and per-axis kappa values. For pure Duchon it reports the centered
3802/// eta contrasts only.
3803pub fn log_spatial_aniso_scales(spec: &TermCollectionSpec) {
3804    for (term_idx, term) in spec.smooth_terms.iter().enumerate() {
3805        let (aniso, length_scale) = match &term.basis {
3806            SmoothBasisSpec::Matern { spec, .. } => {
3807                (spec.aniso_log_scales.as_ref(), Some(spec.length_scale))
3808            }
3809            SmoothBasisSpec::Duchon { spec, .. } => {
3810                (spec.aniso_log_scales.as_ref(), spec.length_scale)
3811            }
3812            _ => (None, None),
3813        };
3814        let Some(eta) = aniso else { continue };
3815        if eta.is_empty() {
3816            continue;
3817        }
3818        let mut lines = match length_scale {
3819            Some(ls) => format!(
3820                "[spatial-kappa] term {} (\"{}\"): anisotropic length scales optimized (global length_scale={:.4})",
3821                term_idx, term.name, ls
3822            ),
3823            None => format!(
3824                "[spatial-kappa] term {} (\"{}\"): pure Duchon shape anisotropy optimized",
3825                term_idx, term.name
3826            ),
3827        };
3828        for (a, &eta_a) in eta.iter().enumerate() {
3829            if let Some(ls) = length_scale {
3830                let length_a = ls * (-eta_a).exp();
3831                let kappa_a = (1.0 / ls) * eta_a.exp();
3832                lines.push_str(&format!(
3833                    "\n  axis {}: eta={:+.4}, length={:.4}, kappa={:.4}",
3834                    a, eta_a, length_a, kappa_a
3835                ));
3836            } else {
3837                lines.push_str(&format!("\n  axis {}: eta={:+.4}", a, eta_a));
3838            }
3839        }
3840        log::info!("{}", lines);
3841    }
3842}
3843
3844/// Set `aniso_log_scales` on a spatial term's basis spec.
3845pub fn set_spatial_aniso_log_scales(
3846    spec: &mut TermCollectionSpec,
3847    term_idx: usize,
3848    eta: Vec<f64>,
3849) -> Result<(), EstimationError> {
3850    let eta = center_aniso_log_scales(&eta);
3851    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3852        crate::bail_invalid_estim!("spatial aniso_log_scales term index {term_idx} out of range");
3853    };
3854    match &mut term.basis {
3855        SmoothBasisSpec::Matern { spec, .. } => {
3856            spec.aniso_log_scales = Some(eta);
3857            Ok(())
3858        }
3859        SmoothBasisSpec::Duchon { spec, .. } => {
3860            spec.aniso_log_scales = Some(eta);
3861            Ok(())
3862        }
3863        _ => Err(EstimationError::InvalidInput(format!(
3864            "term '{}' does not support aniso_log_scales",
3865            term.name
3866        ))),
3867    }
3868}
3869
3870/// Sync knot-cloud-derived anisotropy contrasts from basis metadata back into
3871/// the mutable spec so the optimizer starts from the correct eta values.
3872///
3873/// Call this after building the smooth design but before initializing the
3874/// optimizer's psi coordinates. For each spatial term whose metadata contains
3875/// computed `aniso_log_scales`, this writes them into the spec.
3876pub fn sync_aniso_contrasts_from_metadata(
3877    spec: &mut TermCollectionSpec,
3878    design: &SmoothDesign,
3879) {
3880    for (term_idx, term) in design.terms.iter().enumerate() {
3881        let meta_aniso = match &term.metadata {
3882            BasisMetadata::Matern {
3883                aniso_log_scales, ..
3884            } => aniso_log_scales.clone(),
3885            BasisMetadata::Duchon {
3886                aniso_log_scales, ..
3887            } => aniso_log_scales.clone(),
3888            _ => None,
3889        };
3890        if let Some(eta) = meta_aniso
3891            && eta.len() > 1
3892        {
3893            set_spatial_aniso_log_scales(spec, term_idx, eta).ok();
3894        }
3895    }
3896}
3897
/// Options for the outer-loop optimization of spatial kernel length scales
/// (κ = 1 / length_scale).
///
/// The length-scale limits also define the scalar ψ window
/// `(-ln(max_length_scale), -ln(min_length_scale))` used by
/// `spatial_term_psi_bounds`. Options built from external input should be
/// checked with `validate()` before use.
#[derive(Debug, Clone)]
pub struct SpatialLengthScaleOptimizationOptions {
    /// Enable outer-loop optimization over spatial κ (= 1 / length_scale)
    /// for supported radial-kernel smooths.
    /// This applies to ThinPlate, Matérn, and Duchon terms.
    pub enabled: bool,
    /// Maximum number of outer iterations in the exact joint [rho, psi] solve.
    pub max_outer_iter: usize,
    /// Relative improvement threshold for terminating the outer solve.
    pub rel_tol: f64,
    /// Initial log(length_scale) perturbation used for seed construction.
    pub log_step: f64,
    /// Minimum allowed length_scale during κ search.
    pub min_length_scale: f64,
    /// Maximum allowed length_scale during κ search.
    pub max_length_scale: f64,
    /// Automatic geometry-initializer threshold for large-scale spatial fits.
    ///
    /// When n exceeds twice this value, the fitter uses a spatially stratified
    /// subsample only to seed κ/anisotropy geometry: centers are resolved,
    /// axis contrasts are initialized from center/data spread, and one or two
    /// cheap ψ reseeding updates are applied. It never runs PIRLS, REML, ARC,
    /// BFGS, or any recursive optimizer on the pilot.
    ///
    /// The final coefficients, smoothing parameters, and spatial geometry are
    /// always optimized on the full dataset.
    ///
    /// Set to 0 to skip the pilot geometry initializer.
    pub pilot_subsample_threshold: usize,
    /// Optional wall-clock budget (seconds) for the whole outer smoothing search
    /// (gam#979). When a family arms the global deadline from this, an outer
    /// search that cannot certify convergence (survival marginal-slope's
    /// monotonicity-pinned constrained joint-Newton) returns its best-so-far
    /// iterate (or a catchable error) within the budget instead of hanging.
    /// `None` keeps the legacy unbounded behavior; the survival marginal-slope
    /// path applies a generous default when this is `None`.
    pub outer_wall_clock_budget_secs: Option<f64>,
}
3936
impl Default for SpatialLengthScaleOptimizationOptions {
    /// Default configuration: optimization enabled, at most 80 outer
    /// iterations, relative tolerance `1e-4`, a log-step of `ln 2`, a
    /// length-scale window of `[1e-3, 1e3]`, a pilot subsample threshold of
    /// 10 000, and no wall-clock budget.
    fn default() -> Self {
        Self {
            enabled: true,
            max_outer_iter: 80,
            rel_tol: 1e-4,
            // One log-step of ln 2 corresponds to a factor of two in length scale.
            log_step: std::f64::consts::LN_2,
            // These satisfy the `min_length_scale < max_length_scale` invariant
            // that `validate()` enforces.
            min_length_scale: 1e-3,
            max_length_scale: 1e3,
            pilot_subsample_threshold: 10_000,
            // `None` leaves the outer search without a wall-clock budget.
            outer_wall_clock_budget_secs: None,
        }
    }
}
3951
3952impl SpatialLengthScaleOptimizationOptions {
3953    /// Validate the struct's invariants. Callers that construct these options
3954    /// from external input (CLI, config, Python API) should call this before
3955    /// passing the options into the fitter. Returns `Err` with a descriptive
3956    /// message when an invariant is violated; the fitter then panics or
3957    /// returns `EstimationError` at its own boundary.
3958    ///
3959    /// Invariants:
3960    ///   * `min_length_scale > 0`, finite
3961    ///   * `max_length_scale > 0`, finite
3962    ///   * `min_length_scale < max_length_scale`
3963    ///   * `rel_tol > 0`, finite
3964    ///   * `log_step > 0`, finite
3965    ///
3966    /// These invariants are what the downstream κ-bound and ψ-window code
3967    /// assumes (`-log(max_ls)` must be finite, `(min,max)` must not be
3968    /// inverted, etc.). Without validation, invalid options produce silent
3969    /// NaN-propagation inside the outer optimizer.
3970    pub fn validate(&self) -> Result<(), String> {
3971        if !self.min_length_scale.is_finite() || self.min_length_scale <= 0.0 {
3972            return Err(SmoothError::invalid_config(format!(
3973                "SpatialLengthScaleOptimizationOptions::min_length_scale must be > 0 and finite, got {}",
3974                self.min_length_scale
3975            ))
3976            .into());
3977        }
3978        if !self.max_length_scale.is_finite() || self.max_length_scale <= 0.0 {
3979            return Err(SmoothError::invalid_config(format!(
3980                "SpatialLengthScaleOptimizationOptions::max_length_scale must be > 0 and finite, got {}",
3981                self.max_length_scale
3982            ))
3983            .into());
3984        }
3985        if self.min_length_scale >= self.max_length_scale {
3986            return Err(SmoothError::invalid_config(format!(
3987                "SpatialLengthScaleOptimizationOptions requires min_length_scale < max_length_scale, got min={} max={}",
3988                self.min_length_scale, self.max_length_scale
3989            ))
3990            .into());
3991        }
3992        if !self.rel_tol.is_finite() || self.rel_tol <= 0.0 {
3993            return Err(SmoothError::invalid_config(format!(
3994                "SpatialLengthScaleOptimizationOptions::rel_tol must be > 0 and finite, got {}",
3995                self.rel_tol
3996            ))
3997            .into());
3998        }
3999        if !self.log_step.is_finite() || self.log_step <= 0.0 {
4000            return Err(SmoothError::invalid_config(format!(
4001                "SpatialLengthScaleOptimizationOptions::log_step must be > 0 and finite, got {}",
4002                self.log_step
4003            ))
4004            .into());
4005        }
4006        Ok(())
4007    }
4008}
4009
/// Compact description of a random-effect design block: one group label per
/// observation instead of a materialized indicator matrix.
#[derive(Debug, Clone)]
pub struct RandomEffectBlock {
    /// Name of the block (presumably the random-effect term's name — confirm
    /// against the constructor).
    pub name: String,
    /// O(n) group-label vector: group_ids[i] = column index in [0, num_groups).
    /// `None` if the observation's level is not in the kept set.
    pub group_ids: Vec<Option<usize>>,
    /// Number of groups, i.e. the exclusive upper bound of the column indices
    /// stored in `group_ids`.
    pub num_groups: usize,
    /// Levels that were kept (presumably one identifier per group column, in
    /// column order — TODO confirm).
    pub kept_levels: Vec<u64>,
}
4019
/// Magnitude threshold for treating a design entry as structurally zero:
/// `sparse_compatible_block_nnz` counts, and the sparse assembly emits, only
/// entries whose absolute value exceeds this.
pub const BLOCK_SPARSE_ZERO_EPS: f64 = 1e-12;

/// Maximum fill fraction (nonzeros / total cells) at which
/// `try_build_sparse_design_from_blocks` still assembles a sparse design when
/// none of the blocks is intrinsically sparse. The limit is not applied when a
/// sparse or random-effect block is present.
pub const BLOCK_SPARSE_MAX_DENSITY: f64 = 0.20;
4023
4024pub fn blocks_have_intrinsic_sparse_structure(blocks: &[DesignBlock]) -> bool {
4025    blocks
4026        .iter()
4027        .any(|block| matches!(block, DesignBlock::Sparse(_) | DesignBlock::RandomEffect(_)))
4028}
4029
4030pub fn sparse_compatible_block_nnz(block: &DesignBlock) -> Option<usize> {
4031    match block {
4032        DesignBlock::Intercept(n) => Some(*n),
4033        DesignBlock::RandomEffect(op) => {
4034            Some(op.group_ids.iter().filter(|gid| gid.is_some()).count())
4035        }
4036        DesignBlock::Sparse(sparse) => Some(sparse.val().len()),
4037        DesignBlock::Dense(dense) => dense.as_dense_ref().map(|matrix| {
4038            matrix
4039                .iter()
4040                .filter(|&&value| value.abs() > BLOCK_SPARSE_ZERO_EPS)
4041                .count()
4042        }),
4043    }
4044}
4045
4046pub fn try_build_sparse_design_from_blocks(
4047    blocks: &[DesignBlock],
4048) -> Result<Option<DesignMatrix>, BasisError> {
4049    if blocks.is_empty() {
4050        return Ok(None);
4051    }
4052    let nrows = blocks[0].nrows();
4053    let ncols: usize = blocks.iter().map(DesignBlock::ncols).sum();
4054    if nrows == 0 || ncols == 0 || ncols <= 32 {
4055        return Ok(None);
4056    }
4057
4058    let preserve_sparse_storage = blocks_have_intrinsic_sparse_structure(blocks);
4059    let sparse_nnz_limit = if preserve_sparse_storage {
4060        usize::MAX
4061    } else {
4062        let total_cells = nrows.saturating_mul(ncols);
4063        ((total_cells as f64) * BLOCK_SPARSE_MAX_DENSITY).floor() as usize
4064    };
4065    let mut nnz = 0usize;
4066    for block in blocks {
4067        let block_nnz = if let Some(block_nnz) = sparse_compatible_block_nnz(block) {
4068            block_nnz
4069        } else {
4070            return Ok(None);
4071        };
4072        nnz = nnz.saturating_add(block_nnz);
4073        if nnz > sparse_nnz_limit {
4074            return Ok(None);
4075        }
4076    }
4077
4078    let mut triplets = Vec::<Triplet<usize, usize, f64>>::with_capacity(nnz);
4079    let mut col_offset = 0usize;
4080    for block in blocks {
4081        match block {
4082            DesignBlock::Intercept(n) => {
4083                for row in 0..*n {
4084                    triplets.push(Triplet::new(row, col_offset, 1.0));
4085                }
4086            }
4087            DesignBlock::RandomEffect(op) => {
4088                for (row, group_id) in op.group_ids.iter().enumerate() {
4089                    if let Some(group) = group_id {
4090                        triplets.push(Triplet::new(row, col_offset + group, 1.0));
4091                    }
4092                }
4093            }
4094            DesignBlock::Sparse(sparse) => {
4095                let (symbolic, values) = sparse.parts();
4096                let col_ptr = symbolic.col_ptr();
4097                let row_idx = symbolic.row_idx();
4098                for col in 0..sparse.ncols() {
4099                    for idx in col_ptr[col]..col_ptr[col + 1] {
4100                        let value = values[idx];
4101                        if value.abs() > BLOCK_SPARSE_ZERO_EPS {
4102                            triplets.push(Triplet::new(row_idx[idx], col_offset + col, value));
4103                        }
4104                    }
4105                }
4106            }
4107            DesignBlock::Dense(dense) => {
4108                let matrix = dense.as_dense_ref().ok_or_else(|| {
4109                    BasisError::InvalidInput(
4110                        "sparse-compatible block assembly requires materialized dense blocks"
4111                            .to_string(),
4112                    )
4113                })?;
4114                for row in 0..matrix.nrows() {
4115                    for col in 0..matrix.ncols() {
4116                        let value = matrix[[row, col]];
4117                        if value.abs() > BLOCK_SPARSE_ZERO_EPS {
4118                            triplets.push(Triplet::new(row, col_offset + col, value));
4119                        }
4120                    }
4121                }
4122            }
4123        }
4124        col_offset += block.ncols();
4125    }
4126
4127    let sparse = SparseColMat::try_new_from_triplets(nrows, ncols, &triplets).map_err(|_| {
4128        BasisError::SparseCreation("failed to assemble sparse term-collection design".to_string())
4129    })?;
4130    Ok(Some(DesignMatrix::Sparse(
4131        gam_linalg::matrix::SparseDesignMatrix::new(sparse),
4132    )))
4133}
4134
4135pub fn assemble_term_collection_design_matrix(
4136    blocks: Vec<DesignBlock>,
4137) -> Result<DesignMatrix, BasisError> {
4138    if let Some(sparse) = try_build_sparse_design_from_blocks(&blocks)? {
4139        return Ok(sparse);
4140    }
4141    let block_op = BlockDesignOperator::new(blocks).map_err(|e| {
4142        BasisError::InvalidInput(format!("failed to build block design operator: {e}"))
4143    })?;
4144    Ok(DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
4145        Arc::new(block_op),
4146    )))
4147}
4148
4149pub fn select_columns(data: ArrayView2<'_, f64>, cols: &[usize]) -> Result<Array2<f64>, BasisError> {
4150    let n = data.nrows();
4151    let p = data.ncols();
4152    for &c in cols {
4153        if c >= p {
4154            crate::bail_dim_basis!("feature column {c} is out of bounds for data with {p} columns");
4155        }
4156    }
4157    let mut out = Array2::<f64>::zeros((n, cols.len()));
4158    for (j, &c) in cols.iter().enumerate() {
4159        out.column_mut(j).assign(&data.column(c));
4160    }
4161    Ok(out)
4162}
4163
/// Human-readable label for a non-finite `f64`, used in validation messages.
///
/// Returns `"NaN"` for NaN, `"+Inf"` for positive infinity and `"-Inf"` for
/// negative infinity. A finite argument is outside the intended domain; it is
/// labelled `"finite"` rather than being reported as an infinity of matching
/// sign.
pub fn nonfinite_value_label(value: f64) -> &'static str {
    if value.is_nan() {
        "NaN"
    } else if value == f64::INFINITY {
        "+Inf"
    } else if value == f64::NEG_INFINITY {
        "-Inf"
    } else {
        // Finite input: never claim an infinity that is not there.
        "finite"
    }
}
4173
4174pub fn validate_term_feature_column_finite(
4175    data: ArrayView2<'_, f64>,
4176    term_kind: &str,
4177    term_name: &str,
4178    feature_col: usize,
4179) -> Result<(), BasisError> {
4180    let p = data.ncols();
4181    if feature_col >= p {
4182        crate::bail_dim_basis!(
4183            "{term_kind} term '{term_name}' feature column {feature_col} out of bounds for {p} columns"
4184        );
4185    }
4186    for (row, &value) in data.column(feature_col).iter().enumerate() {
4187        if !value.is_finite() {
4188            crate::bail_invalid_basis!(
4189                "{term_kind} term '{term_name}' feature column {feature_col} row {row} contains non-finite value {}",
4190                nonfinite_value_label(value)
4191            );
4192        }
4193    }
4194    Ok(())
4195}
4196
4197pub fn validate_smooth_terms_finite_inputs(
4198    data: ArrayView2<'_, f64>,
4199    terms: &[SmoothTermSpec],
4200) -> Result<(), BasisError> {
4201    for term in terms {
4202        for feature_col in smooth_term_feature_cols(term) {
4203            validate_term_feature_column_finite(data, "smooth", &term.name, feature_col)?;
4204        }
4205    }
4206    Ok(())
4207}
4208
4209pub fn validate_term_collection_finite_inputs(
4210    data: ArrayView2<'_, f64>,
4211    spec: &TermCollectionSpec,
4212) -> Result<(), BasisError> {
4213    for term in &spec.linear_terms {
4214        validate_term_feature_column_finite(data, "linear", &term.name, term.feature_col)?;
4215    }
4216    for term in &spec.random_effect_terms {
4217        validate_term_feature_column_finite(data, "random-effect", &term.name, term.feature_col)?;
4218    }
4219    validate_smooth_terms_finite_inputs(data, &spec.smooth_terms)
4220}
4221
/// Grouping key for spatial smooths that may share one jointly planned set of
/// centers.
///
/// Built by `spatial_term_group_key` and used as the `BTreeMap` key in
/// `plan_joint_spatial_centers_for_term_blocks`: terms whose keys compare
/// equal end up in the same group. The derived `Ord` compares fields in
/// declaration order.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct JointSpatialCenterGroupKey {
    /// Data columns the smooth is built over.
    feature_cols: Vec<usize>,
    /// Kind of center strategy, as reported by `center_strategy_kind`.
    strategy_kind: CenterStrategyKind,
    /// Strategy-specific extra parameter: `max_iter` for k-means,
    /// `points_per_dim` for a uniform grid, `0` for every other strategy.
    strategy_aux: usize,
    /// Center count requested by the strategy (`center_strategy_num_centers`).
    requested_num_centers: usize,
    /// Bit patterns (`f64::to_bits`) of the term's explicit input scales, or
    /// `None` when the term carries none. Stored as bits so the key can
    /// derive `Eq` and `Ord`.
    input_scale_bits: Option<Vec<u64>>,
}
4230
4231pub fn spatial_term_min_center_count(term: &SmoothTermSpec) -> usize {
4232    match &term.basis {
4233        SmoothBasisSpec::ThinPlate { feature_cols, .. } => feature_cols.len() + 1,
4234        SmoothBasisSpec::Duchon {
4235            feature_cols, spec, ..
4236        } => match spec.nullspace_order {
4237            crate::basis::DuchonNullspaceOrder::Zero => 1,
4238            crate::basis::DuchonNullspaceOrder::Linear => feature_cols.len() + 1,
4239            crate::basis::DuchonNullspaceOrder::Degree(degree) => {
4240                crate::basis::duchon_nullspace_dimension(feature_cols.len(), degree)
4241            }
4242        },
4243        SmoothBasisSpec::Matern { .. } => 1,
4244        _ => 1,
4245    }
4246}
4247
4248pub fn spatial_term_group_key(term: &SmoothTermSpec) -> Option<JointSpatialCenterGroupKey> {
4249    let (feature_cols, strategy, input_scales) = match &term.basis {
4250        SmoothBasisSpec::ThinPlate {
4251            feature_cols,
4252            spec,
4253            input_scales,
4254        } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4255        SmoothBasisSpec::Matern {
4256            feature_cols,
4257            spec,
4258            input_scales,
4259        } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4260        SmoothBasisSpec::Duchon {
4261            feature_cols,
4262            spec,
4263            input_scales,
4264        } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4265        _ => return None,
4266    };
4267    let strategy_kind = center_strategy_kind(strategy);
4268    let strategy_aux = match strategy {
4269        CenterStrategy::Auto(inner) => match inner.as_ref() {
4270            CenterStrategy::KMeans { max_iter, .. } => *max_iter,
4271            CenterStrategy::UniformGrid { points_per_dim } => *points_per_dim,
4272            _ => 0,
4273        },
4274        CenterStrategy::KMeans { max_iter, .. } => *max_iter,
4275        CenterStrategy::UniformGrid { points_per_dim } => *points_per_dim,
4276        _ => 0,
4277    };
4278    Some(JointSpatialCenterGroupKey {
4279        feature_cols: feature_cols.clone(),
4280        strategy_kind,
4281        strategy_aux,
4282        requested_num_centers: center_strategy_num_centers(strategy)?,
4283        input_scale_bits: input_scales
4284            .map(|values| values.iter().map(|value| value.to_bits()).collect()),
4285    })
4286}
4287
4288pub fn spatial_term_center_strategy(term: &SmoothTermSpec) -> Option<&CenterStrategy> {
4289    match &term.basis {
4290        SmoothBasisSpec::ThinPlate { spec, .. } => Some(&spec.center_strategy),
4291        SmoothBasisSpec::Matern { spec, .. } => Some(&spec.center_strategy),
4292        SmoothBasisSpec::Duchon { spec, .. } => Some(&spec.center_strategy),
4293        _ => None,
4294    }
4295}
4296
4297pub fn set_spatial_term_centers(
4298    term: &mut SmoothTermSpec,
4299    centers: Array2<f64>,
4300) -> Result<(), BasisError> {
4301    match &mut term.basis {
4302        SmoothBasisSpec::ThinPlate { spec, .. } => {
4303            spec.center_strategy = CenterStrategy::UserProvided(centers);
4304            Ok(())
4305        }
4306        SmoothBasisSpec::Matern { spec, .. } => {
4307            spec.center_strategy = CenterStrategy::UserProvided(centers);
4308            Ok(())
4309        }
4310        SmoothBasisSpec::Duchon { spec, .. } => {
4311            spec.center_strategy = CenterStrategy::UserProvided(centers);
4312            Ok(())
4313        }
4314        _ => Err(BasisError::InvalidInput(format!(
4315            "term '{}' does not support spatial center planning",
4316            term.name
4317        ))),
4318    }
4319}
4320
4321pub fn standardized_spatial_term_data(
4322    data: ArrayView2<'_, f64>,
4323    term: &SmoothTermSpec,
4324) -> Result<Array2<f64>, BasisError> {
4325    let (feature_cols, input_scales) = match &term.basis {
4326        SmoothBasisSpec::ThinPlate {
4327            feature_cols,
4328            input_scales,
4329            ..
4330        }
4331        | SmoothBasisSpec::Matern {
4332            feature_cols,
4333            input_scales,
4334            ..
4335        }
4336        | SmoothBasisSpec::Duchon {
4337            feature_cols,
4338            input_scales,
4339            ..
4340        } => (feature_cols, input_scales.as_ref()),
4341        _ => {
4342            crate::bail_invalid_basis!("term '{}' is not a spatial smooth", term.name);
4343        }
4344    };
4345    let mut x = select_columns(data, feature_cols)?;
4346    if let Some(scales) = input_scales {
4347        apply_input_standardization(&mut x, scales);
4348    } else if let Some(scales) = compute_spatial_input_scales(x.view()) {
4349        apply_input_standardization(&mut x, &scales);
4350    }
4351    Ok(x)
4352}
4353
4354pub fn plan_joint_spatial_centers_for_term_blocks(
4355    data: ArrayView2<'_, f64>,
4356    term_blocks: &[Vec<SmoothTermSpec>],
4357) -> Result<Vec<Vec<SmoothTermSpec>>, BasisError> {
4358    let mut planned_blocks = term_blocks.to_vec();
4359    let n = data.nrows();
4360    let mut groups: BTreeMap<JointSpatialCenterGroupKey, Vec<(usize, usize)>> = BTreeMap::new();
4361
4362    for (block_idx, terms) in planned_blocks.iter().enumerate() {
4363        for (term_idx, term) in terms.iter().enumerate() {
4364            let Some(strategy) = spatial_term_center_strategy(term) else {
4365                continue;
4366            };
4367            if !center_strategy_is_auto(strategy) {
4368                continue;
4369            }
4370            let Some(group_key) = spatial_term_group_key(term) else {
4371                continue;
4372            };
4373            if !matches!(
4374                group_key.strategy_kind,
4375                CenterStrategyKind::EqualMass
4376                    | CenterStrategyKind::EqualMassCovarRepresentative
4377                    | CenterStrategyKind::FarthestPoint
4378                    | CenterStrategyKind::KMeans
4379            ) {
4380                continue;
4381            }
4382            if center_strategy_num_centers(strategy).is_none() {
4383                continue;
4384            }
4385            groups
4386                .entry(group_key)
4387                .or_default()
4388                .push((block_idx, term_idx));
4389        }
4390    }
4391
4392    for (group_key, members) in groups {
4393        if members.len() < 2 {
4394            continue;
4395        }
4396        let min_required = members
4397            .iter()
4398            .map(|&(block_idx, term_idx)| {
4399                spatial_term_min_center_count(&planned_blocks[block_idx][term_idx])
4400            })
4401            .max()
4402            .unwrap_or(1);
4403        let joint_centers = group_key
4404            .requested_num_centers
4405            .max(min_required)
4406            .min(n.max(1));
4407        let (first_block_idx, first_term_idx) = members[0];
4408        let prototype = &planned_blocks[first_block_idx][first_term_idx];
4409        let standardized = standardized_spatial_term_data(data, prototype)?;
4410        let strategy = spatial_term_center_strategy(prototype).ok_or_else(|| {
4411            BasisError::InvalidInput(format!(
4412                "term '{}' lost its spatial center strategy during joint planning",
4413                prototype.name
4414            ))
4415        })?;
4416        let joint_strategy = center_strategy_with_num_centers(strategy, joint_centers)?;
4417        let shared_centers = select_centers_by_strategy(standardized.view(), &joint_strategy)?;
4418        log::info!(
4419            "sharing {} spatial centers across {} smooth terms over columns {:?} (requested {} centers)",
4420            shared_centers.nrows(),
4421            members.len(),
4422            group_key.feature_cols,
4423            group_key.requested_num_centers,
4424        );
4425        for (block_idx, term_idx) in members {
4426            set_spatial_term_centers(
4427                &mut planned_blocks[block_idx][term_idx],
4428                shared_centers.clone(),
4429            )?;
4430        }
4431    }
4432
4433    // Sentinel auto-init: Matern and thin-plate builders write length_scale =
4434    // 0.0 when the user didn't pass `length_scale=...`. Replace those with a
4435    // data-driven initialization here so REML starts in a regime where it can
4436    // escape; the hard-coded 1.0 default was a basin from which ν ≥ 5/2 Matern
4437    // could not recover on high-frequency truths, silently collapsing the fit
4438    // to a near-constant prediction.
4439    for block in planned_blocks.iter_mut() {
4440        for term in block.iter_mut() {
4441            auto_init_length_scale_in_place(data, term);
4442        }
4443    }
4444
4445    Ok(planned_blocks)
4446}
4447
/// Lower clamp applied to the auto length-scale seed in
/// [`auto_initial_length_scale`] and [`auto_initial_length_scale_for_centers`].
///
/// Both functions cap the seed at the widest feature range *after* applying
/// this floor, so a range narrower than the floor still yields a (positive)
/// seed below it.
const AUTO_LENGTH_SCALE_FLOOR: f64 = 1e-6;
4451
4452/// Widest per-axis range of the selected feature columns. Returns `None` when
4453/// every selected column is constant / non-finite (no usable spatial scale).
4454fn feature_columns_max_range(data: ArrayView2<'_, f64>, feature_cols: &[usize]) -> Option<f64> {
4455    let mut max_range = 0.0_f64;
4456    for &c in feature_cols {
4457        if c >= data.ncols() {
4458            continue;
4459        }
4460        let col = data.column(c);
4461        let mut lo = f64::INFINITY;
4462        let mut hi = f64::NEG_INFINITY;
4463        for &v in col.iter() {
4464            if v.is_finite() {
4465                if v < lo {
4466                    lo = v;
4467                }
4468                if v > hi {
4469                    hi = v;
4470                }
4471            }
4472        }
4473        if hi > lo {
4474            let r = hi - lo;
4475            if r > max_range {
4476                max_range = r;
4477            }
4478        }
4479    }
4480    if max_range.is_finite() && max_range > 0.0 {
4481        Some(max_range)
4482    } else {
4483        None
4484    }
4485}
4486
4487/// Compute a data-driven initial length scale from the per-axis range of the
4488/// feature columns. The heuristic `max_range / sqrt(n)` puts the kernel on
4489/// the wiggly side of REML's basin so the optimizer can grow it back if the
4490/// signal is smooth, but is small enough that high-frequency truths remain
4491/// reachable for smoother kernels (ν ≥ 5/2). Clamped to a tiny positive
4492/// floor so degenerate constant-input columns can't produce 0.
4493pub fn auto_initial_length_scale(data: ArrayView2<'_, f64>, feature_cols: &[usize]) -> f64 {
4494    let n = data.nrows();
4495    if n == 0 || feature_cols.is_empty() {
4496        return 1.0;
4497    }
4498    let Some(max_range) = feature_columns_max_range(data, feature_cols) else {
4499        return 1.0;
4500    };
4501    let init = max_range / (n as f64).sqrt();
4502    init.max(AUTO_LENGTH_SCALE_FLOOR).min(max_range)
4503}
4504
4505/// Density-adaptive auto length scale for a kernel basis with `num_centers`
4506/// requested centers (#1731).
4507///
4508/// The plain [`auto_initial_length_scale`] seed `max_range / sqrt(n)` is the
4509/// fill distance of the *n data points*; it is independent of the requested
4510/// center count `k`. For a radial kernel at a FIXED length scale, packing more
4511/// centers into the same cloud makes neighbouring basis functions overlap and
4512/// go numerically collinear, so the realized basis saturates in rank (the
4513/// `matern_rank_reduce_centers` cap) and a richer `k` becomes a no-op — or even
4514/// shrinks the basis. The kernel stays well-conditioned only while the length
4515/// scale tracks the *center* spacing, not the data spacing.
4516///
4517/// We seed the length scale at the fill distance of `max(n, k)` points,
4518/// `max_range / sqrt(max(n, k))`. When `n ≥ k` (the usual case) this is exactly
4519/// the existing `max_range / sqrt(n)` seed, so every current result and small-`k`
4520/// basis size is preserved bit-for-bit (in every covariate dimension). When
4521/// `k > n` (a dense center request on a small cloud, the regime where an
4522/// `n`-sized seed sits above the center spacing and over-smooths the centers
4523/// into collinearity) the seed shrinks with `k` to the center spacing, keeping
4524/// the requested centers numerically independent. This is the Matérn analogue of
4525/// the Duchon-promotion "length_scale from center spacing" rule
4526/// (`hybrid_duchon_promotion_length_scale`).
4527pub fn auto_initial_length_scale_for_centers(
4528    data: ArrayView2<'_, f64>,
4529    feature_cols: &[usize],
4530    num_centers: usize,
4531) -> f64 {
4532    let n = data.nrows();
4533    if n == 0 || feature_cols.is_empty() {
4534        return 1.0;
4535    }
4536    let Some(max_range) = feature_columns_max_range(data, feature_cols) else {
4537        return 1.0;
4538    };
4539    // Resolution density: at least the data points, but no coarser than the
4540    // center spacing once more centers than data are requested. Using the same
4541    // `sqrt` fill-distance law as `auto_initial_length_scale` keeps the seed
4542    // bit-identical whenever `n ≥ num_centers` (every dimension), and only
4543    // shrinks it — never grows it — when `num_centers > n`.
4544    let resolution_points = n.max(num_centers).max(1) as f64;
4545    let spacing = max_range / resolution_points.sqrt();
4546    spacing.max(AUTO_LENGTH_SCALE_FLOOR).min(max_range)
4547}
4548
4549/// Requested center count encoded by a [`CenterStrategy`], if it carries an
4550/// explicit count (used to make the Matérn auto length scale density-adaptive).
4551fn center_strategy_requested_count(strategy: &CenterStrategy) -> Option<usize> {
4552    match strategy {
4553        CenterStrategy::Auto(inner) => center_strategy_requested_count(inner),
4554        CenterStrategy::UserProvided(centers) => Some(centers.nrows()),
4555        CenterStrategy::EqualMass { num_centers }
4556        | CenterStrategy::EqualMassCovarRepresentative { num_centers }
4557        | CenterStrategy::FarthestPoint { num_centers }
4558        | CenterStrategy::KMeans { num_centers, .. } => Some(*num_centers),
4559        CenterStrategy::UniformGrid { .. } => None,
4560    }
4561}
4562
/// Replace any auto-sentinel (`0.0`) length scale reachable from the term's
/// basis with a data-driven seed.
///
/// Thin wrapper that forwards `term.basis` to
/// [`auto_init_length_scale_in_basis`]; that function decides which kernels
/// are initialized and which seed each one receives.
pub fn auto_init_length_scale_in_place(data: ArrayView2<'_, f64>, term: &mut SmoothTermSpec) {
    auto_init_length_scale_in_basis(data, &mut term.basis);
}
4569
4570/// Replace the `0.0` auto-init length-scale sentinel with a data-derived value
4571/// for any Matern / thin-plate kernel reachable from this basis — including the
4572/// inner kernel of a `by=`/factor-smooth wrapper.
4573///
4574/// `by=<factor>` and the sum-to-zero factor smooth wrap a spatial kernel inside
4575/// `SmoothBasisSpec::ByVariable` / `SmoothBasisSpec::FactorSumToZero` /
4576/// `SmoothBasisSpec::BySmooth`, so the wrapper variant is what the planner sees.
4577/// Without recursing into the wrapped basis the inner ThinPlate/Matern keeps the
4578/// `0.0` sentinel (the post-`1605b3a6e` builder default), which makes the kernel
4579/// distance divide by `length_scale² = 0`, producing a non-finite design at both
4580/// fit and predict time. Recurse so the inner kernel is initialized identically
4581/// to a top-level one.
4582pub fn auto_init_length_scale_in_basis(data: ArrayView2<'_, f64>, basis: &mut SmoothBasisSpec) {
4583    match basis {
4584        SmoothBasisSpec::Matern {
4585            feature_cols, spec, ..
4586        } => {
4587            if spec.length_scale == 0.0 {
4588                // Density-adaptive seed (#1731): when the requested center count
4589                // is known, scale the auto length scale with the *center*
4590                // spacing so a richer `k` stays numerically full-rank instead of
4591                // saturating against `matern_rank_reduce_centers`. For `n ≥ k`
4592                // (the usual case) this is identical to the plain `max_range /
4593                // sqrt(n)` seed in 2-D, so small-`k` results are unchanged. The
4594                // unconstrained / non-explicit `UniformGrid` strategy falls back
4595                // to the plain seed.
4596                spec.length_scale = match center_strategy_requested_count(&spec.center_strategy) {
4597                    Some(k) => auto_initial_length_scale_for_centers(data, feature_cols, k),
4598                    None => auto_initial_length_scale(data, feature_cols),
4599                };
4600            }
4601        }
4602        SmoothBasisSpec::ThinPlate {
4603            feature_cols, spec, ..
4604        } => {
4605            if spec.length_scale == 0.0 {
4606                spec.length_scale = auto_initial_length_scale(data, feature_cols);
4607            }
4608        }
4609        SmoothBasisSpec::ByVariable { inner, .. }
4610        | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
4611            auto_init_length_scale_in_basis(data, inner);
4612        }
4613        SmoothBasisSpec::BySmooth { smooth, .. } => {
4614            auto_init_length_scale_in_basis(data, smooth);
4615        }
4616        _ => {}
4617    }
4618}
4619
impl LinearFitConditioning {
    /// Derive a centering/scaling record for each column in `selected_cols`
    /// of `design`.
    ///
    /// Each column's mean and population standard deviation (sum of squared
    /// deviations divided by `n`) are accumulated over row chunks of the
    /// design. Columns whose variance is non-finite or not above
    /// `SCALE_EPS²` are recorded with `mean = 0` and `scale = 1`, i.e. left
    /// unconditioned. The intercept index is taken from
    /// `design.intercept_range.start`.
    ///
    /// With zero rows or no selected columns the result has an empty column
    /// list.
    ///
    /// # Panics
    ///
    /// Panics if `design.design.try_row_chunk` fails for any chunk.
    pub fn from_columns(design: &TermCollectionDesign, selected_cols: &[usize]) -> Self {
        const SCALE_EPS: f64 = 1e-12;
        let n = design.design.nrows();
        let p = design.design.ncols();
        let mut columns = Vec::with_capacity(selected_cols.len());
        if n == 0 || selected_cols.is_empty() {
            return Self {
                intercept_idx: design.intercept_range.start,
                columns,
            };
        }
        let chunk_rows = gam_linalg::utils::row_chunk_for_byte_budget(n, p);
        // Two-pass mean/variance so operator-backed designs don't need to
        // materialize the full dense matrix. Pass 1 accumulates per-column
        // sums; pass 2 accumulates the sum of squared deviations from the
        // pass-1 mean. This matches the original `Σ (x − mean)² / n` formula
        // without the catastrophic cancellation of `E[X²] − E[X]²`.
        let mut sums = vec![0.0_f64; selected_cols.len()];
        for start in (0..n).step_by(chunk_rows) {
            let end = (start + chunk_rows).min(n);
            let chunk = design
                .design
                .try_row_chunk(start..end)
                .expect("LinearFitConditioning::from_columns row chunk failed");
            for (k, &col_idx) in selected_cols.iter().enumerate() {
                let column = chunk.column(col_idx);
                for &v in column.iter() {
                    sums[k] += v;
                }
            }
        }
        let inv_n = 1.0_f64 / n as f64;
        let means: Vec<f64> = sums.iter().map(|&s| s * inv_n).collect();
        let mut sq_devs = vec![0.0_f64; selected_cols.len()];
        for start in (0..n).step_by(chunk_rows) {
            let end = (start + chunk_rows).min(n);
            let chunk = design
                .design
                .try_row_chunk(start..end)
                .expect("LinearFitConditioning::from_columns row chunk failed");
            for (k, &col_idx) in selected_cols.iter().enumerate() {
                let mean_k = means[k];
                let column = chunk.column(col_idx);
                for &v in column.iter() {
                    let d = v - mean_k;
                    sq_devs[k] += d * d;
                }
            }
        }
        for (k, &col_idx) in selected_cols.iter().enumerate() {
            let mean = means[k];
            let var = sq_devs[k] * inv_n;
            let (mean, scale) = if var.is_finite() && var > SCALE_EPS * SCALE_EPS {
                (mean, var.sqrt())
            } else {
                // Leave nearly-constant columns untouched; centering them would collapse
                // the design column to ~0 and change the model rather than just condition it.
                (0.0, 1.0)
            };
            columns.push(LinearColumnConditioning {
                col_idx,
                mean,
                scale,
            });
        }
        Self {
            intercept_idx: design.intercept_range.start,
            columns,
        }
    }

    /// Return a copy of `design` in which every conditioned column has its
    /// recorded mean subtracted and is then divided by its recorded scale.
    /// All other columns are copied unchanged.
    pub fn apply_to_design(&self, design: &Array2<f64>) -> Array2<f64> {
        let mut out = design.clone();
        for col in &self.columns {
            {
                let mut dst = out.column_mut(col.col_idx);
                dst -= col.mean;
            }
            // Division by exactly 1.0 is skipped as a no-op.
            if col.scale != 1.0 {
                out.column_mut(col.col_idx).mapv_inplace(|v| v / col.scale);
            }
        }
        out
    }

    /// Column-wise conditioning transform of `mat`: for every conditioned
    /// column `j`, replace it with
    /// `(mat[:, j] − mean_j · mat[:, intercept]) / scale_j`.
    ///
    /// Used as the right-hand factor when moving a penalty matrix into the
    /// internal (conditioned) coefficient basis.
    fn transform_matrix_columnswith_a(&self, mat: &Array2<f64>) -> Array2<f64> {
        let mut out = mat.clone();
        let intercept = self.intercept_idx;
        for col in &self.columns {
            // Copy the intercept column so it can be read while the target
            // column of the same matrix is mutably borrowed.
            let intercept_col = out.column(intercept).to_owned();
            let mut target = out.column_mut(col.col_idx);
            target -= &(intercept_col * col.mean);
            if col.scale != 1.0 {
                target.mapv_inplace(|v| v / col.scale);
            }
        }
        out
    }

    /// Row-wise counterpart of [`Self::transform_matrix_columnswith_a`]: for
    /// every conditioned column `j`, replace row `j` with
    /// `(mat[j, :] − mean_j · mat[intercept, :]) / scale_j`.
    fn transform_matrixrowswith_a_transpose(&self, mat: &Array2<f64>) -> Array2<f64> {
        let mut out = mat.clone();
        let intercept = self.intercept_idx;
        for col in &self.columns {
            // Copy the intercept row so it can be read while the target row
            // of the same matrix is mutably borrowed.
            let interceptrow = out.row(intercept).to_owned();
            let mut target = out.row_mut(col.col_idx);
            target -= &(interceptrow * col.mean);
            if col.scale != 1.0 {
                target.mapv_inplace(|v| v / col.scale);
            }
        }
        out
    }

    /// Left-multiply `mat_internal` by `M⁻ᵀ` where `M⁻¹[intercept, j] = mean_j`
    /// and `M⁻¹[j, j] = scale_j` for each conditioned column. Used together
    /// with [`Self::right_multiply_by_m_inv`] to back-transform an internal
    /// penalized Hessian to the original coefficient basis.
    fn left_multiply_by_m_inv_transpose(&self, mat_internal: &Array2<f64>) -> Array2<f64> {
        let mut out = mat_internal.clone();
        let intercept = self.intercept_idx;
        // Taken from the input, so every row update uses the untransformed
        // intercept row.
        let interceptrow_snapshot = mat_internal.row(intercept).to_owned();
        for col in &self.columns {
            if col.scale != 1.0 {
                out.row_mut(col.col_idx).mapv_inplace(|v| v * col.scale);
            }
            if col.mean != 0.0 {
                let mut target = out.row_mut(col.col_idx);
                target += &(&interceptrow_snapshot * col.mean);
            }
        }
        out
    }

    /// Right-multiply `mat_internal` by `M⁻¹`. Mirror of
    /// [`Self::left_multiply_by_m_inv_transpose`] on columns.
    fn right_multiply_by_m_inv(&self, mat_internal: &Array2<f64>) -> Array2<f64> {
        let mut out = mat_internal.clone();
        let intercept = self.intercept_idx;
        // Taken from the input, so every column update uses the untransformed
        // intercept column.
        let intercept_col_snapshot = mat_internal.column(intercept).to_owned();
        for col in &self.columns {
            if col.scale != 1.0 {
                out.column_mut(col.col_idx).mapv_inplace(|v| v * col.scale);
            }
            if col.mean != 0.0 {
                let mut target = out.column_mut(col.col_idx);
                target += &(&intercept_col_snapshot * col.mean);
            }
        }
        out
    }

    /// Transform blockwise penalties through the conditioning.
    ///
    /// For block-local penalties whose `col_range` does not overlap with any
    /// conditioning column, the transform is identity (the conditioning only
    /// affects unpenalized linear columns). In that common case the penalty
    /// passes through unchanged, avoiding O(p²) materialization entirely.
    ///
    /// `p` is the total coefficient count used to expand an overlapping
    /// block to a global `p × p` matrix before transforming it. The result
    /// has one `PenaltySpec` per input penalty, in the same order.
    pub fn transform_blockwise_penalties_to_internal(
        &self,
        penalties: &[BlockwisePenalty],
        p: usize,
    ) -> Vec<crate::penalty_spec::PenaltySpec> {
        let conditioning_cols: std::collections::HashSet<usize> =
            self.columns.iter().map(|c| c.col_idx).collect();
        penalties
            .iter()
            .map(|bp| {
                let overlaps =
                    (bp.col_range.start..bp.col_range.end).any(|j| conditioning_cols.contains(&j));
                if overlaps {
                    // Rare: penalty block overlaps conditioning columns.
                    // Fall back to dense transform.
                    let global = bp.to_global(p);
                    let right = self.transform_matrix_columnswith_a(&global);
                    let transformed = self.transform_matrixrowswith_a_transpose(&right);
                    crate::penalty_spec::PenaltySpec::Dense(transformed)
                } else {
                    // Common: smooth penalty block doesn't touch linear columns.
                    // The conditioning is identity on this block.
                    crate::penalty_spec::PenaltySpec::from_blockwise(bp.clone())
                }
            })
            .collect()
    }

    /// Map coefficients fitted on the conditioned design back to the
    /// original column scale.
    ///
    /// For every conditioned column `j` the slope becomes
    /// `beta_internal[j] / scale_j`, and the intercept is reduced by
    /// `beta_internal[j] · mean_j / scale_j`. Other coefficients are copied
    /// unchanged.
    pub fn backtransform_beta(&self, beta_internal: &Array1<f64>) -> Array1<f64> {
        let mut beta = beta_internal.clone();
        let intercept = self.intercept_idx;
        for col in &self.columns {
            beta[intercept] -= beta_internal[col.col_idx] * col.mean / col.scale;
            beta[col.col_idx] = beta_internal[col.col_idx] / col.scale;
        }
        beta
    }

    /// `H_orig = M⁻ᵀ · H_int · M⁻¹`, derived from
    /// `L_int(β_int) = L_orig(M · β_int)` via the chain rule.
    pub fn transform_penalized_hessian_to_original(&self, h_internal: &Array2<f64>) -> Array2<f64> {
        let right = self.right_multiply_by_m_inv(h_internal);
        self.left_multiply_by_m_inv_transpose(&right)
    }

    /// Convert coefficient bounds `(min, max)` given on the original scale of
    /// column `col_idx` to the internal scale by multiplying both by the
    /// column's recorded scale. Bounds for a column that is not conditioned
    /// are returned unchanged.
    pub fn internal_bounds_for(&self, col_idx: usize, min: f64, max: f64) -> (f64, f64) {
        if let Some(col) = self.columns.iter().find(|c| c.col_idx == col_idx) {
            (min * col.scale, max * col.scale)
        } else {
            (min, max)
        }
    }
}
4831
4832pub fn freeze_raw_spatial_metadata(metadata: BasisMetadata, raw_cols: usize) -> BasisMetadata {
4833    match metadata {
4834        BasisMetadata::ThinPlate {
4835            centers,
4836            length_scale,
4837            periodic,
4838            identifiability_transform: None,
4839            input_scales,
4840            radial_reparam,
4841        } => BasisMetadata::ThinPlate {
4842            centers,
4843            length_scale,
4844            periodic,
4845            identifiability_transform: Some(Array2::eye(raw_cols)),
4846            input_scales,
4847            radial_reparam,
4848        },
4849        BasisMetadata::Duchon {
4850            centers,
4851            length_scale,
4852            periodic,
4853            power,
4854            nullspace_order,
4855            identifiability_transform: None,
4856            input_scales,
4857            aniso_log_scales,
4858            operator_collocation_points,
4859            radial_reparam,
4860        } => BasisMetadata::Duchon {
4861            centers,
4862            length_scale,
4863            periodic,
4864            power,
4865            nullspace_order,
4866            identifiability_transform: Some(Array2::eye(raw_cols)),
4867            input_scales,
4868            aniso_log_scales,
4869            operator_collocation_points,
4870            radial_reparam,
4871        },
4872        other => other,
4873    }
4874}
4875
4876pub fn matern_operator_penalty_triplet_from_metadata(
4877    metadata: &BasisMetadata,
4878) -> Result<(Vec<Array2<f64>>, Vec<usize>, Vec<PenaltyInfo>), BasisError> {
4879    let BasisMetadata::Matern {
4880        centers,
4881        length_scale,
4882        periodic,
4883        nu,
4884        include_intercept,
4885        identifiability_transform,
4886        aniso_log_scales,
4887        input_scales,
4888        ..
4889    } = metadata
4890    else {
4891        crate::bail_invalid_basis!("Matérn operator penalties require Matérn metadata");
4892    };
4893    // The metadata records `length_scale` in *original* (un-standardized) data
4894    // coordinates, while `centers` live in the *standardized* coordinate frame
4895    // (per-axis division by `input_scales`). The realized design built the
4896    // kernel against those standardized centers using the σ_geom-compensated
4897    // effective length scale `length_scale / σ_geom`. The collocation operators
4898    // here are evaluated on the same standardized centers, so they must use the
4899    // SAME effective length scale — otherwise the penalty regularizes a
4900    // different RKHS range than the design lives in, leaving rough coefficient
4901    // directions effectively unpenalized. That mismatch is benign in 1-D
4902    // (no standardization) but produces a catastrophic out-of-sample blow-up in
4903    // d ≥ 2 where σ_geom ≠ 1 (#706).
4904    let penalty_length_scale = match input_scales.as_deref() {
4905        Some(scales) => compensate_length_scale_for_standardization(*length_scale, scales),
4906        None => *length_scale,
4907    };
4908    matern_operator_penalty_triplet_at_length_scale(
4909        centers.view(),
4910        periodic.as_deref(),
4911        identifiability_transform.as_ref(),
4912        *nu,
4913        *include_intercept,
4914        aniso_log_scales.as_deref(),
4915        penalty_length_scale,
4916    )
4917}
4918
4919/// Build the canonical Matérn operator-penalty triplet (mass / tension /
4920/// stiffness) at an explicit **effective** length scale — i.e. the
4921/// σ_geom-compensated, standardized-frame scale the design's kernel was built
4922/// against (NOT the original-coordinate `length_scale` stored in metadata).
4923///
4924/// This is the SINGLE source of truth for the Matérn penalty topology. Two
4925/// callers route through it and must therefore stay byte-for-byte consistent:
4926///   * the cold/slow design rebuild (`matern_operator_penalty_triplet_from_metadata`,
4927///     compensating the frozen metadata `length_scale`), and
4928///   * the n-free κ-optimizer re-key (`FrozenTermCollectionIncrementalRealizer::
4929///     canonical_penalties_at_psi`, compensating the trial `ψ → exp(-ψ)` scale).
4930///
4931/// Sharing the body makes the penalty BLOCK COUNT and the per-block numerics
4932/// one deterministic function of `(geometry, ν, η, ℓ_eff)`. The active-operator
4933/// gate is `m = ν + d/2`, which is independent of ℓ, so the block count is
4934/// **ψ-stable by construction**: the re-key can never produce a different number
4935/// of blocks than the frozen design (the desync that #1270 hard-errored on).
4936pub fn matern_operator_penalty_triplet_at_length_scale(
4937    centers: ArrayView2<'_, f64>,
4938    periodic: Option<&[Option<f64>]>,
4939    identifiability_transform: Option<&Array2<f64>>,
4940    nu: crate::basis::MaternNu,
4941    include_intercept: bool,
4942    aniso_log_scales: Option<&[f64]>,
4943    effective_length_scale: f64,
4944) -> Result<(Vec<Array2<f64>>, Vec<usize>, Vec<PenaltyInfo>), BasisError> {
4945    let penalty_centers = crate::basis::expand_periodic_centers(&centers.to_owned(), periodic)?;
4946    let ops = build_matern_collocation_operator_matrices(
4947        penalty_centers.view(),
4948        None,
4949        effective_length_scale,
4950        nu,
4951        include_intercept,
4952        identifiability_transform.map(|z| z.view()),
4953        aniso_log_scales,
4954    )?;
4955    // Gate the operator dials on the Matérn-ν RKHS Sobolev order m = ν + d/2:
4956    // mass (j=0) is always on, tension (j=1) is on for m > 1, stiffness (j=2)
4957    // is on for m > 2. The threshold is strict so the roughest kernel ν=1/2 in
4958    // d=1 (m=1, the exponential/OU H¹ process) sheds both higher operators —
4959    // its kernel already encodes the H¹ control, so adding an extra tension
4960    // dial over-smooths the oscillation it is meant to track (#707). The
4961    // matching gate lives at `DuchonOperatorPenaltySpec::matern_for_smoothness`.
4962    const ORDER_EPS: f64 = 1e-9;
4963    let d = penalty_centers.ncols();
4964    let m = nu.half_integer_value() + 0.5 * d as f64;
4965    let mut candidates = Vec::with_capacity(3);
4966    for (raw, source, min_order) in [
4967        (ops.d0.t().dot(&ops.d0), PenaltySource::OperatorMass, 0.0),
4968        (ops.d1.t().dot(&ops.d1), PenaltySource::OperatorTension, 1.0),
4969        (
4970            ops.d2.t().dot(&ops.d2),
4971            PenaltySource::OperatorStiffness,
4972            2.0,
4973        ),
4974    ] {
4975        if min_order > 0.0 && m <= min_order + ORDER_EPS {
4976            continue;
4977        }
4978        let sym = (&raw + &raw.t()) * 0.5;
4979        let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&sym);
4980        candidates.push(PenaltyCandidate {
4981            matrix,
4982            nullspace_dim_hint: 0,
4983            source,
4984            normalization_scale,
4985            kronecker_factors: None,
4986            op: None,
4987        });
4988    }
4989    filter_active_penalty_candidates(candidates)
4990}
4991
4992pub fn normalize_penalty_in_constrained_space(matrix: &Array2<f64>) -> (Array2<f64>, f64) {
4993    // Constrained-space normalization:
4994    //   c = ||S_con||_F,  S_tilde = S_con / c.
4995    // This is the only normalization coherent with a REML objective that is
4996    // evaluated entirely in constrained coordinates.
4997    let matrix = (matrix + &matrix.t().to_owned()) * 0.5;
4998    // Clamp noise-floor negative eigenvalues so β'Sβ is non-negative as a contract, not just in exact arithmetic.
4999    let matrix = crate::basis::project_penalty_to_psd_cone(&matrix);
5000    let c = matrix.iter().map(|v| v * v).sum::<f64>().sqrt();
5001    if c.is_finite() && c > 0.0 {
5002        (matrix.mapv(|v| v / c), c)
5003    } else {
5004        (matrix, 1.0)
5005    }
5006}
5007
5008pub fn tensor_product_design_from_sparse_marginals(
5009    marginal_sparse: &[&SparseColMat<usize, f64>],
5010) -> Result<SparseColMat<usize, f64>, BasisError> {
5011    if marginal_sparse.is_empty() {
5012        crate::bail_invalid_basis!("TensorBSpline requires at least one marginal basis");
5013    }
5014    let n = marginal_sparse[0].nrows();
5015    for (i, m) in marginal_sparse.iter().enumerate().skip(1) {
5016        if m.nrows() != n {
5017            crate::bail_dim_basis!(
5018                "tensor sparse marginal row mismatch at dim {i}: expected {n}, got {}",
5019                m.nrows()
5020            );
5021        }
5022    }
5023    let dims: Vec<usize> = marginal_sparse.iter().map(|m| m.ncols()).collect();
5024    let total_cols = dims.iter().try_fold(1usize, |acc, &q| {
5025        acc.checked_mul(q)
5026            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))
5027    })?;
5028    let mut strides = vec![1usize; dims.len()];
5029    for d in (0..dims.len().saturating_sub(1)).rev() {
5030        strides[d] = strides[d + 1]
5031            .checked_mul(dims[d + 1])
5032            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))?;
5033    }
5034
5035    use faer::sparse::SparseRowMat;
5036    let csrs: Vec<SparseRowMat<usize, f64>> = marginal_sparse
5037        .iter()
5038        .enumerate()
5039        .map(|(d, m)| {
5040            m.as_ref().to_row_major().map_err(|e| {
5041                BasisError::SparseCreation(format!(
5042                    "tensor sparse marginal {d} CSR conversion failed: {e:?}"
5043                ))
5044            })
5045        })
5046        .collect::<Result<Vec<_>, _>>()?;
5047    let row_ptrs: Vec<&[usize]> = csrs.iter().map(|c| c.symbolic().row_ptr()).collect();
5048    let col_idxs: Vec<&[usize]> = csrs.iter().map(|c| c.symbolic().col_idx()).collect();
5049    let vals: Vec<&[f64]> = csrs.iter().map(|c| c.val()).collect();
5050
5051    use rayon::prelude::*;
5052    const CHUNK: usize = 1024;
5053    let num_chunks = n.div_ceil(CHUNK);
5054    let per_chunk: Vec<Vec<Triplet<usize, usize, f64>>> = (0..num_chunks)
5055        .into_par_iter()
5056        .map(|chunk_idx| {
5057            let row_start = chunk_idx * CHUNK;
5058            let row_end = (row_start + CHUNK).min(n);
5059            let mut chunk_triplets = Vec::<Triplet<usize, usize, f64>>::new();
5060            let mut cur_cols = Vec::<usize>::with_capacity(64);
5061            let mut cur_vals = Vec::<f64>::with_capacity(64);
5062            let mut next_cols = Vec::<usize>::with_capacity(64);
5063            let mut next_vals = Vec::<f64>::with_capacity(64);
5064            for i in row_start..row_end {
5065                cur_cols.clear();
5066                cur_vals.clear();
5067                cur_cols.push(0);
5068                cur_vals.push(1.0);
5069                let mut row_is_zero = false;
5070                for d in 0..dims.len() {
5071                    let row_start_d = row_ptrs[d][i];
5072                    let row_end_d = row_ptrs[d][i + 1];
5073                    if row_start_d == row_end_d {
5074                        row_is_zero = true;
5075                        break;
5076                    }
5077                    let stride = strides[d];
5078                    next_cols.clear();
5079                    next_vals.clear();
5080                    next_cols.reserve(cur_cols.len() * (row_end_d - row_start_d));
5081                    next_vals.reserve(cur_vals.len() * (row_end_d - row_start_d));
5082                    for (&prev_col, &prev_val) in cur_cols.iter().zip(cur_vals.iter()) {
5083                        for ptr in row_start_d..row_end_d {
5084                            let cj = col_idxs[d][ptr];
5085                            let vj = vals[d][ptr];
5086                            next_cols.push(prev_col + cj * stride);
5087                            next_vals.push(prev_val * vj);
5088                        }
5089                    }
5090                    std::mem::swap(&mut cur_cols, &mut next_cols);
5091                    std::mem::swap(&mut cur_vals, &mut next_vals);
5092                }
5093                if row_is_zero {
5094                    continue;
5095                }
5096                for (&col, &val) in cur_cols.iter().zip(cur_vals.iter()) {
5097                    chunk_triplets.push(Triplet::new(i, col, val));
5098                }
5099            }
5100            chunk_triplets
5101        })
5102        .collect();
5103    let total_nnz: usize = per_chunk.iter().map(Vec::len).sum();
5104    let mut triplets = Vec::<Triplet<usize, usize, f64>>::with_capacity(total_nnz);
5105    for chunk in per_chunk {
5106        triplets.extend(chunk);
5107    }
5108    SparseColMat::try_new_from_triplets(n, total_cols, &triplets).map_err(|e| {
5109        BasisError::SparseCreation(format!(
5110            "failed to assemble sparse tensor product design: {e:?}"
5111        ))
5112    })
5113}
5114
5115pub fn dense_local_margin_to_sparse(
5116    dense: &Array2<f64>,
5117) -> Result<SparseColMat<usize, f64>, BasisError> {
5118    let expected_row_nnz = dense.ncols().min(4);
5119    let mut triplets =
5120        Vec::<Triplet<usize, usize, f64>>::with_capacity(dense.nrows() * expected_row_nnz);
5121    for ((row, col), &value) in dense.indexed_iter() {
5122        if value != 0.0 {
5123            triplets.push(Triplet::new(row, col, value));
5124        }
5125    }
5126    SparseColMat::try_new_from_triplets(dense.nrows(), dense.ncols(), &triplets).map_err(|e| {
5127        BasisError::SparseCreation(format!(
5128            "failed to convert tensor marginal design to sparse form: {e:?}"
5129        ))
5130    })
5131}
5132
/// Pair of matrices describing one tensor margin's penalty split, as produced
/// by `tensor_margin_range_null_projectors`.
pub struct TensorMarginRangeNullProjectors {
    // `B · Bᵀ` over the eigenvector columns whose eigenvalue exceeds the
    // analysis tolerance (the penalized directions).
    range: Array2<f64>,
    // `B · Bᵀ` over the remaining eigenvector columns (eigenvalue at or below
    // the tolerance); all zeros when there are none.
    null: Array2<f64>,
}
5137
5138pub fn projector_from_columns(columns: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
5139    if indices.is_empty() {
5140        return Array2::<f64>::zeros((columns.nrows(), columns.nrows()));
5141    }
5142    let basis = columns.select(Axis(1), indices);
5143    basis.dot(&basis.t())
5144}
5145
5146pub fn tensor_margin_range_null_projectors(
5147    normalized_marginal_penalties: &[(Array2<f64>, f64)],
5148) -> Result<Vec<TensorMarginRangeNullProjectors>, BasisError> {
5149    normalized_marginal_penalties
5150        .iter()
5151        .enumerate()
5152        .map(|(dim, (penalty, _))| {
5153            let analysis = crate::basis::analyze_penalty_block(penalty)?;
5154            if analysis.rank == 0 {
5155                crate::bail_invalid_basis!(
5156                    "t2 separable tensor penalty margin {dim} has rank-zero penalty; \
5157                     cannot split penalized and null subspaces"
5158                );
5159            }
5160            let mut range_idx = Vec::<usize>::new();
5161            let mut null_idx = Vec::<usize>::new();
5162            for (idx, &ev) in analysis.eigenvalues.iter().enumerate() {
5163                if ev > analysis.tol {
5164                    range_idx.push(idx);
5165                } else {
5166                    null_idx.push(idx);
5167                }
5168            }
5169            Ok(TensorMarginRangeNullProjectors {
5170                range: projector_from_columns(&analysis.eigenvectors, &range_idx),
5171                null: projector_from_columns(&analysis.eigenvectors, &null_idx),
5172            })
5173        })
5174        .collect()
5175}
5176
5177pub fn build_tensor_bspline_basis(
5178    data: ArrayView2<'_, f64>,
5179    feature_cols: &[usize],
5180    spec: &TensorBSplineSpec,
5181) -> Result<BasisBuildResult, BasisError> {
5182    if feature_cols.is_empty() {
5183        crate::bail_invalid_basis!("TensorBSpline requires at least one feature column");
5184    }
5185    if feature_cols.len() != spec.marginalspecs.len() {
5186        crate::bail_dim_basis!(
5187            "TensorBSpline feature/spec mismatch: feature_cols={}, marginalspecs={}",
5188            feature_cols.len(),
5189            spec.marginalspecs.len()
5190        );
5191    }
5192    if !spec.periods.is_empty() && spec.periods.len() != feature_cols.len() {
5193        crate::bail_dim_basis!(
5194            "TensorBSpline periods length {} does not match feature count {}",
5195            spec.periods.len(),
5196            feature_cols.len()
5197        );
5198    }
5199    let p = data.ncols();
5200    for &c in feature_cols {
5201        if c >= p {
5202            crate::bail_dim_basis!(
5203                "tensor feature column {c} is out of bounds for data with {p} columns"
5204            );
5205        }
5206    }
5207
5208    let mut marginal_knots = Vec::<Array1<f64>>::with_capacity(feature_cols.len());
5209    // Per-margin cr flag (#1074): `true` when the margin is a natural cubic
5210    // regression spline, so the tensor freeze rebuilds the cr knotspec.
5211    let mut marginal_is_cr_flags = Vec::<bool>::with_capacity(feature_cols.len());
5212    let mut marginal_degrees = Vec::<usize>::with_capacity(feature_cols.len());
5213    let mut marginalnum_basis = Vec::<usize>::with_capacity(feature_cols.len());
5214    let mut marginal_penalties = Vec::<Array2<f64>>::with_capacity(feature_cols.len());
5215    let mut marginal_designs = Vec::<Array2<f64>>::with_capacity(feature_cols.len());
5216    // Per-margin effective period: either user-set via `spec.periods` or
5217    // implied by a `PeriodicUniform` marginal knotspec (which the 1D B-spline
5218    // builder realizes as a cyclic B-spline basis).
5219    // Captured here so freeze→reload round-trips both routes back to a
5220    // `PeriodicUniform` marginal knotspec; otherwise a `PeriodicUniform`
5221    // margin specified without `spec.periods` would freeze as a plain
5222    // `Provided(knots)` open spline and lose its wrap-around at predict time.
5223    let mut marginal_effective_periods = Vec::<Option<f64>>::with_capacity(feature_cols.len());
5224    // Per-marginal sparse representation, populated when the 1D builder returned
5225    // a `DesignMatrix::Sparse`. Used to assemble the Khatri-Rao tensor product
5226    // sparsely (only ∏(degree+1) nonzeros per row) instead of densifying to
5227    // shape (n, ∏ q_j) up front. Periodic B-spline margins are local-support
5228    // bases too; when the 1D builder returns them densely, we convert that
5229    // marginal back to sparse form so cylinder/torus tensor products keep the
5230    // same scale behavior as open tensor products.
5231    let mut marginal_sparse =
5232        Vec::<Option<SparseColMat<usize, f64>>>::with_capacity(feature_cols.len());
5233
5234    // Reuse the robust 1D builder to ensure the same knot validation and
5235    // marginal difference-penalty construction as standalone smooth terms.
5236    for (dim, (&col, marginalspec)) in feature_cols
5237        .iter()
5238        .zip(spec.marginalspecs.iter())
5239        .enumerate()
5240    {
5241        // Tensor basis uses raw marginal knot-product columns. Applying 1D
5242        // identifiability constraints here would change marginal penalty sizes
5243        // without changing the tensor design construction, causing dimension
5244        // mismatch. Keep marginal builders unconstrained at this stage.
5245        let mut marginal_unconstrained = marginalspec.clone();
5246        marginal_unconstrained.identifiability = BSplineIdentifiability::None;
5247        let built = build_bspline_basis_1d(data.column(col), &marginal_unconstrained)?;
5248        // A cr (`NaturalCubicRegression`) margin emits `CubicRegression1D`
5249        // metadata whose `knots` are the k value-knots; a B-spline margin emits
5250        // `BSpline1D` with the clamped knot vector. Capture either so the
5251        // tensor freeze can rebuild the exact same marginal knotspec (#1074).
5252        let (knots, marginal_is_cr) = match built.metadata {
5253            BasisMetadata::BSpline1D { knots, .. } => (knots, false),
5254            BasisMetadata::CubicRegression1D { knots, .. } => (knots, true),
5255            _ => {
5256                crate::bail_invalid_basis!(
5257                    "internal TensorBSpline error at dim {dim}: expected BSpline1D or CubicRegression1D metadata"
5258                );
5259            }
5260        };
5261        let metadata_knots = match marginalspec.knotspec {
5262            BSplineKnotSpec::PeriodicUniform {
5263                data_range,
5264                num_basis,
5265            } => Array1::linspace(data_range.0, data_range.1, num_basis),
5266            _ => knots,
5267        };
5268        marginal_knots.push(metadata_knots);
5269        marginal_is_cr_flags.push(marginal_is_cr);
5270        marginal_degrees.push(marginalspec.degree);
5271        marginalnum_basis.push(built.design.ncols());
5272        // Capture the sparse representation of this marginal (when the
5273        // 1D builder produced one) before densifying for the dense
5274        // marginal cache used by `tensor_product_design_from_marginals`
5275        // and `TensorProductDesignOperator`.
5276        let dense_marginal = built.design.to_dense();
5277        let sparse_view: Option<SparseColMat<usize, f64>> = match built.design.as_sparse() {
5278            Some(sd) => {
5279                let inner: &SparseColMat<usize, f64> = sd;
5280                Some(inner.clone())
5281            }
5282            None => match marginalspec.knotspec {
5283                BSplineKnotSpec::PeriodicUniform { .. } => {
5284                    Some(dense_local_margin_to_sparse(&dense_marginal)?)
5285                }
5286                _ => None,
5287            },
5288        };
5289        marginal_sparse.push(sparse_view);
5290        marginal_designs.push(dense_marginal);
5291        marginal_penalties.push(
5292            built
5293                .penalties
5294                .first()
5295                .ok_or_else(|| {
5296                    BasisError::InvalidInput(format!(
5297                        "internal TensorBSpline error at dim {dim}: missing marginal penalty"
5298                    ))
5299                })?
5300                .clone(),
5301        );
5302        built.nullspace_dims.first().ok_or_else(|| {
5303            BasisError::InvalidInput(format!(
5304                "internal TensorBSpline error at dim {dim}: missing marginal nullspace dim"
5305            ))
5306        })?;
5307        // A `PeriodicUniform` marginal knotspec implies the margin is
5308        // wrap-around: the 1D builder already realized it as a periodic
5309        // basis, so the tensor product inherits that periodicity. Record
5310        // the period derived from the knotspec's data range so freeze
5311        // restores `PeriodicUniform` on the marginal — otherwise the
5312        // round-trip downgrades it to `Provided(knots)` (an open spline)
5313        // and predict-time wraps disappear.
5314        let implied_period = match marginalspec.knotspec {
5315            BSplineKnotSpec::PeriodicUniform { data_range, .. } => {
5316                Some(data_range.1 - data_range.0)
5317            }
5318            _ => spec.periods.get(dim).and_then(|p| *p),
5319        };
5320        marginal_effective_periods.push(implied_period);
5321    }
5322
5323    let total_cols: usize = marginalnum_basis.iter().product();
5324    let mut dense_design = (!matches!(spec.identifiability, TensorBSplineIdentifiability::None))
5325        .then(|| tensor_product_design_from_marginals(&marginal_designs))
5326        .transpose()?;
5327    let mut candidates = Vec::<PenaltyCandidate>::with_capacity(
5328        match spec.penalty_decomposition {
5329            TensorBSplinePenaltyDecomposition::MarginalKroneckerSum => marginal_penalties.len(),
5330            TensorBSplinePenaltyDecomposition::Separable => marginal_penalties.len() * 2,
5331        } + if spec.double_penalty { 1 } else { 0 },
5332    );
5333
5334    // Tensor-product smoothing parameters are one-per-margin.  Therefore the
5335    // physical penalty attached to a margin must be normalized in that margin's
5336    // own working coordinates before it is embedded in the full tensor product.
5337    // Normalizing only the already-Kroneckered matrix would fold arbitrary
5338    // dimension-dependent identity factors into the margin's lambda and would
5339    // make anisotropic REML/LAML smoothing depend on the other margins' basis
5340    // sizes rather than on the marginal roughness operator itself.
5341    let normalized_marginal_penalties: Vec<(Array2<f64>, f64)> = marginal_penalties
5342        .iter()
5343        .map(normalize_penalty_in_constrained_space)
5344        .collect();
5345    let mut kronecker_marginal_penalties =
5346        Vec::<Array2<f64>>::with_capacity(normalized_marginal_penalties.len());
5347
5348    match spec.penalty_decomposition {
5349        TensorBSplinePenaltyDecomposition::MarginalKroneckerSum => {
5350            // Accumulate the Kronecker-sum of the per-margin penalties,
5351            // `Σ_dim S_dim`, whose null space is exactly the *joint* null space
5352            // of all marginal penalties — the tensor of marginal polynomial
5353            // null spaces. The tensor double penalty (below) shrinks only this
5354            // joint null, never the already-penalized interaction range.
5355            let mut marginal_kron_sum = Array2::<f64>::zeros((total_cols, total_cols));
5356
5357            for dim in 0..normalized_marginal_penalties.len() {
5358                let mut s_dim = Array2::<f64>::eye(1);
5359                let mut factors = Vec::<Array2<f64>>::with_capacity(marginalnum_basis.len());
5360                for (j, &qj) in marginalnum_basis.iter().enumerate() {
5361                    let factor = if j == dim {
5362                        normalized_marginal_penalties[j].0.clone()
5363                    } else {
5364                        Array2::<f64>::eye(qj)
5365                    };
5366                    factors.push(factor.clone());
5367                    s_dim = kronecker_product(&s_dim, &factor);
5368                }
5369                if dim == kronecker_marginal_penalties.len() {
5370                    kronecker_marginal_penalties.push(normalized_marginal_penalties[dim].0.clone());
5371                }
5372                marginal_kron_sum += &s_dim;
5373
5374                candidates.push(PenaltyCandidate {
5375                    matrix: s_dim,
5376                    nullspace_dim_hint: 0,
5377                    source: PenaltySource::TensorMarginal { dim },
5378                    normalization_scale: normalized_marginal_penalties[dim].1,
5379                    kronecker_factors: Some(factors),
5380                    op: None,
5381                });
5382            }
5383
5384            if spec.double_penalty
5385                && let Some(shrink) =
5386                    crate::basis::build_nullspace_shrinkage_penalty(&marginal_kron_sum)?
5387            {
5388                let (matrix, normalization_scale) =
5389                    normalize_penalty_in_constrained_space(&shrink.sym_penalty);
5390                candidates.push(PenaltyCandidate {
5391                    matrix,
5392                    nullspace_dim_hint: 0,
5393                    source: PenaltySource::TensorGlobalRidge,
5394                    normalization_scale,
5395                    kronecker_factors: None,
5396                    op: None,
5397                });
5398            }
5399        }
5400        TensorBSplinePenaltyDecomposition::Separable => {
5401            let projectors = tensor_margin_range_null_projectors(&normalized_marginal_penalties)?;
5402            let n_masks = 1usize.checked_shl(projectors.len() as u32).ok_or_else(|| {
5403                BasisError::InvalidInput(format!(
5404                    "t2 separable tensor penalty supports at most {} margins, got {}",
5405                    usize::BITS - 1,
5406                    projectors.len()
5407                ))
5408            })?;
5409            for mask in 1..n_masks {
5410                let mut matrix = Array2::<f64>::eye(1);
5411                let mut factors = Vec::<Array2<f64>>::with_capacity(projectors.len());
5412                let mut penalized_margins = Vec::<usize>::new();
5413                for (dim, projector) in projectors.iter().enumerate() {
5414                    let use_range = ((mask >> dim) & 1) == 1;
5415                    let factor = if use_range {
5416                        penalized_margins.push(dim);
5417                        projector.range.clone()
5418                    } else {
5419                        projector.null.clone()
5420                    };
5421                    matrix = kronecker_product(&matrix, &factor);
5422                    factors.push(factor);
5423                }
5424                let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&matrix);
5425                candidates.push(PenaltyCandidate {
5426                    matrix,
5427                    nullspace_dim_hint: 0,
5428                    source: PenaltySource::TensorSeparable { penalized_margins },
5429                    normalization_scale,
5430                    kronecker_factors: Some(factors),
5431                    op: None,
5432                });
5433            }
5434
5435            if spec.double_penalty {
5436                let mut matrix = Array2::<f64>::eye(1);
5437                let mut factors = Vec::<Array2<f64>>::with_capacity(projectors.len());
5438                for projector in &projectors {
5439                    matrix = kronecker_product(&matrix, &projector.null);
5440                    factors.push(projector.null.clone());
5441                }
5442                let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&matrix);
5443                candidates.push(PenaltyCandidate {
5444                    matrix,
5445                    nullspace_dim_hint: 0,
5446                    source: PenaltySource::TensorGlobalRidge,
5447                    normalization_scale,
5448                    kronecker_factors: Some(factors),
5449                    op: None,
5450                });
5451            }
5452        }
5453    }
5454
5455    let z_opt = match &spec.identifiability {
5456        TensorBSplineIdentifiability::None => None,
5457        TensorBSplineIdentifiability::SumToZero => {
5458            if total_cols < 2 {
5459                crate::bail_invalid_basis!(
5460                    "TensorBSpline requires at least 2 basis coefficients to enforce sum-to-zero identifiability"
5461                );
5462            }
5463            let dense_design_ref = dense_design.as_ref().ok_or_else(|| {
5464                BasisError::InvalidInput(
5465                    "tensor sum-to-zero identifiability requires a realized basis".to_string(),
5466                )
5467            })?;
5468            let (_, z) = apply_sum_to_zero_constraint(dense_design_ref.view(), None)?;
5469            let gauge = gam_problem::Gauge::sum_to_zero(z);
5470            Some(gauge.block_transform(0))
5471        }
5472        TensorBSplineIdentifiability::MarginalSumToZero => {
5473            // `ti(...)`: drop the marginal main effects by centering every
5474            // margin independently, then form the tensor product of the
5475            // centered margins. Concretely, each margin `j` is reparameterized
5476            // by its own sum-to-zero null basis `Z_j` (so the constant — i.e.
5477            // the marginal intercept — is removed from that axis), and the
5478            // combined reparameterization is the Kronecker product
5479            // `Z = Z₀ ⊗ Z₁ ⊗ … ⊗ Z_{d-1}`. Applying `Z` to the full-tensor
5480            // design `B = B₀ ⊗ … ⊗ B_{d-1}` yields `B Z = (B₀ Z₀) ⊗ … ⊗
5481            // (B_{d-1} Z_{d-1})`, the tensor product of the centered margins,
5482            // which by construction contains no pure main effect.
5483            if marginal_designs.len() < 2 {
5484                crate::bail_invalid_basis!(
5485                    "tensor interaction (ti) identifiability requires at least 2 margins"
5486                );
5487            }
5488            let mut z = Array2::<f64>::eye(1);
5489            for (dim, marginal) in marginal_designs.iter().enumerate() {
5490                if marginal.ncols() < 2 {
5491                    crate::bail_invalid_basis!(
5492                        "tensor interaction (ti) margin {dim} has fewer than 2 basis functions; \
5493                         cannot remove its marginal main effect"
5494                    );
5495                }
5496                let (_, z_dim) = apply_sum_to_zero_constraint(marginal.view(), None)?;
5497                let gauge_dim = gam_problem::Gauge::sum_to_zero(z_dim);
5498                let z_dim = gauge_dim.block_transform(0);
5499                z = kronecker_product(&z, &z_dim);
5500            }
5501            Some(z)
5502        }
5503        TensorBSplineIdentifiability::FrozenTransform { transform } => {
5504            if transform.nrows() != total_cols {
5505                crate::bail_dim_basis!(
5506                    "frozen tensor identifiability transform mismatch: design has {} columns but transform has {} rows",
5507                    total_cols,
5508                    transform.nrows()
5509                );
5510            }
5511            Some(transform.clone())
5512        }
5513    };
5514
5515    if let Some(z) = z_opt.as_ref() {
5516        let gauge = gam_problem::Gauge::from_block_transforms(&[z.clone()]);
5517        let dense = dense_design.as_mut().ok_or_else(|| {
5518            BasisError::InvalidInput(
5519                "tensor identifiability transform requires a realized basis".to_string(),
5520            )
5521        })?;
5522        let restricted_design = gauge.restrict_design(dense);
5523        *dense = restricted_design;
5524        candidates = candidates
5525            .into_iter()
5526            .map(|candidate| -> Result<PenaltyCandidate, BasisError> {
5527                let matrix = gauge.restrict_penalty(&candidate.matrix);
5528                // Re-normalize in the *actual* coefficient chart used by the
5529                // fit.  The tensor sum-to-zero transform is not norm-preserving
5530                // for each overlapping marginal penalty, so carrying the raw
5531                // marginal Frobenius scale into the restricted space changes the
5532                // relative amount of smoothing seen by the LAML/REML optimizer.
5533                // Keep the physical scale in metadata and give the optimizer
5534                // unit-scale constrained penalties for every tensor margin.
5535                let (matrix, c_new) = normalize_penalty_in_constrained_space(&matrix);
5536                Ok(PenaltyCandidate {
5537                    nullspace_dim_hint: candidate.nullspace_dim_hint,
5538                    matrix,
5539                    source: candidate.source,
5540                    normalization_scale: candidate.normalization_scale * c_new,
5541                    // Z^T S Z is no longer a Kronecker product of the original
5542                    // marginal factors, so the Kronecker fast path in construction.rs
5543                    // must not be taken. Clearing kronecker_factors forces the generic
5544                    // block-local eigendecomposition path, which operates on the
5545                    // transformed matrix and is correct.
5546                    kronecker_factors: None,
5547                    op: candidate.op.clone(),
5548                })
5549            })
5550            .collect::<Result<Vec<_>, _>>()?;
5551    }
5552
5553    let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
5554        filter_active_penalty_candidates_with_ops(candidates)?;
5555    let identifiability_is_none =
5556        matches!(spec.identifiability, TensorBSplineIdentifiability::None);
5557    // All marginals expose a sparse representation iff each `marginal_sparse`
5558    // slot is `Some(...)`. Currently this is true when every marginal is a
5559    // free-boundary, non-periodic 1D B-spline returned as
5560    // `DesignMatrix::Sparse` from `build_bspline_basis_1d`. Periodic B-splines
5561    // and other dense-only marginals leave a `None` and trigger the fall-back
5562    // path. Identifiability transforms (`SumToZero`, `FrozenTransform`) make
5563    // the tensor design dense in general, so we also gate on that.
5564    let all_marginals_sparse = marginal_sparse.iter().all(Option::is_some);
5565    let design = if let Some(dense_design) = dense_design {
5566        DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense_design))
5567    } else if identifiability_is_none && all_marginals_sparse {
5568        // Sparse Khatri-Rao path: assemble the (n, ∏ q_j) tensor product
5569        // directly as a SparseColMat, preserving the ∏(degree_j+1) nonzero
5570        // structure per row instead of densifying to ∏ q_j columns. This is
5571        // mathematically identical to `tensor_product_design_from_marginals`
5572        // applied to the corresponding dense marginals.
5573        let sparse_marginals: Vec<&SparseColMat<usize, f64>> = marginal_sparse
5574            .iter()
5575            .map(|m| m.as_ref().expect("all_marginals_sparse just verified"))
5576            .collect();
5577        let sparse_design = tensor_product_design_from_sparse_marginals(&sparse_marginals)?;
5578        DesignMatrix::Sparse(gam_linalg::matrix::SparseDesignMatrix::new(sparse_design))
5579    } else {
5580        let marginals: Vec<Arc<Array2<f64>>> = marginal_designs
5581            .iter()
5582            .map(|m| Arc::new(m.clone()))
5583            .collect();
5584        let op = TensorProductDesignOperator::new(marginals).map_err(|e| {
5585            BasisError::InvalidInput(format!("TensorProductDesignOperator build failed: {e}"))
5586        })?;
5587        DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(op)))
5588    };
5589
5590    Ok(BasisBuildResult {
5591        design,
5592        penalties,
5593        nullspace_dims,
5594        penaltyinfo,
5595        ops,
5596        null_eigenvectors,
5597        joint_null_rotation: None,
5598        metadata: BasisMetadata::TensorBSpline {
5599            feature_cols: feature_cols.to_vec(),
5600            knots: marginal_knots,
5601            degrees: marginal_degrees,
5602            // Prefer the per-margin effective period derived in the loop —
5603            // it captures both the explicit `spec.periods` route and the
5604            // implied period from a `PeriodicUniform` marginal knotspec.
5605            // Falling back to `spec.periods` when populated keeps any
5606            // user-supplied explicit period authoritative even if the
5607            // marginal knotspec carried no periodicity hint.
5608            periods: marginal_effective_periods,
5609            is_cr: marginal_is_cr_flags,
5610            identifiability_transform: z_opt,
5611        },
5612        kronecker_factored: if matches!(spec.identifiability, TensorBSplineIdentifiability::None)
5613            && matches!(
5614                spec.penalty_decomposition,
5615                TensorBSplinePenaltyDecomposition::MarginalKroneckerSum
5616            ) {
5617            Some(KroneckerFactoredBasis::new(
5618                marginal_designs,
5619                kronecker_marginal_penalties,
5620                marginalnum_basis.clone(),
5621                spec.double_penalty,
5622            ))
5623        } else {
5624            None
5625        },
5626    })
5627}
5628
5629pub fn tensor_product_design_from_marginals(
5630    marginal_designs: &[Array2<f64>],
5631) -> Result<Array2<f64>, BasisError> {
5632    if marginal_designs.is_empty() {
5633        crate::bail_invalid_basis!("TensorBSpline requires at least one marginal basis");
5634    }
5635    let n = marginal_designs[0].nrows();
5636    for (i, b) in marginal_designs.iter().enumerate().skip(1) {
5637        if b.nrows() != n {
5638            crate::bail_dim_basis!(
5639                "tensor marginal row mismatch at dim {i}: expected {n}, got {}",
5640                b.nrows()
5641            );
5642        }
5643    }
5644    let total_cols = marginal_designs.iter().try_fold(1usize, |acc, b| {
5645        acc.checked_mul(b.ncols())
5646            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))
5647    })?;
5648    // Tensor-product Khatri-Rao: design[i, j] = Π_d marginal_d[i, j_d]
5649    // where j is the multi-index (j_1, ..., j_D) flattened. Independent
5650    // across rows; parallelize row chunks and fill the pre-allocated
5651    // contiguous Array2 in place (no Vec-flatten-collect intermediate,
5652    // which doubled the peak memory at large-scale N).
5653    use ndarray::parallel::prelude::*;
5654    use rayon::iter::{IntoParallelIterator, ParallelIterator};
5655    let mut design = Array2::<f64>::zeros((n, total_cols));
5656    design
5657        .axis_chunks_iter_mut(ndarray::Axis(0), 1024)
5658        .into_par_iter()
5659        .enumerate()
5660        .for_each(|(chunk_idx, mut block)| {
5661            let row_offset = chunk_idx * 1024;
5662            // Scratch buffers reused across rows in this chunk.
5663            let mut cur = Vec::<f64>::with_capacity(total_cols);
5664            let mut next = Vec::<f64>::with_capacity(total_cols);
5665            for (local_i, mut out_row) in block.outer_iter_mut().enumerate() {
5666                let i = row_offset + local_i;
5667                cur.clear();
5668                cur.push(1.0);
5669                for b in marginal_designs {
5670                    let q = b.ncols();
5671                    next.clear();
5672                    next.resize(cur.len() * q, 0.0);
5673                    // Hoist the row view out of the inner `col` loop so the
5674                    // q reads per `a_idx` reuse a single contiguous slice
5675                    // instead of recomputing `b[[i, col]]` strides per cell.
5676                    let b_row = b.row(i);
5677                    let b_slice = b_row
5678                        .as_slice()
5679                        .expect("Array2 row from outer_iter is contiguous");
5680                    for (a_idx, &aval) in cur.iter().enumerate() {
5681                        let off = a_idx * q;
5682                        let dst = &mut next[off..off + q];
5683                        for col in 0..q {
5684                            dst[col] = aval * b_slice[col];
5685                        }
5686                    }
5687                    std::mem::swap(&mut cur, &mut next);
5688                }
5689                // `out_row` is a row of the contiguous C-major `design`
5690                // Array2, so it is backed by a contiguous slice. Use a
5691                // bulk slice copy instead of an element-by-element write
5692                // loop.
5693                let out_slice = out_row
5694                    .as_slice_mut()
5695                    .expect("design row is contiguous in C-major Array2");
5696                out_slice.copy_from_slice(&cur);
5697            }
5698        });
5699    Ok(design)
5700}
5701
5702pub fn build_random_effect_block(
5703    data: ArrayView2<'_, f64>,
5704    spec: &RandomEffectTermSpec,
5705) -> Result<RandomEffectBlock, BasisError> {
5706    let n = data.nrows();
5707    let p = data.ncols();
5708    if spec.feature_col >= p {
5709        crate::bail_dim_basis!(
5710            "random-effect term '{}' feature column {} out of bounds for {} columns",
5711            spec.name,
5712            spec.feature_col,
5713            p
5714        );
5715    }
5716
5717    let col = data.column(spec.feature_col);
5718    if col.iter().any(|v| !v.is_finite()) {
5719        crate::bail_invalid_basis!(
5720            "random-effect term '{}' contains non-finite group values",
5721            spec.name
5722        );
5723    }
5724
5725    let kept_levels: Vec<u64> = if let Some(levels) = spec.frozen_levels.as_ref() {
5726        if levels.is_empty() {
5727            crate::bail_invalid_basis!(
5728                "random-effect term '{}' has empty frozen_levels",
5729                spec.name
5730            );
5731        }
5732        levels.clone()
5733    } else {
5734        let mut levels_set = BTreeSet::<u64>::new();
5735        for &v in col {
5736            levels_set.insert(v.to_bits());
5737        }
5738        if levels_set.is_empty() {
5739            crate::bail_invalid_basis!("random-effect term '{}' has no observed levels", spec.name);
5740        }
5741        let levels: Vec<u64> = levels_set.into_iter().collect();
5742        let start_idx = if spec.drop_first_level && levels.len() > 1 {
5743            1usize
5744        } else {
5745            0usize
5746        };
5747        levels[start_idx..].to_vec()
5748    };
5749
5750    if kept_levels.is_empty() {
5751        crate::bail_invalid_basis!(
5752            "random-effect term '{}' drops all levels; keep at least one level",
5753            spec.name
5754        );
5755    }
5756
5757    let q = kept_levels.len();
5758    let mut level_to_col = BTreeMap::<u64, usize>::new();
5759    for (idx, &bits) in kept_levels.iter().enumerate() {
5760        if level_to_col.insert(bits, idx).is_some() {
5761            crate::bail_invalid_basis!(
5762                "random-effect term '{}' has duplicate frozen level bits {bits}",
5763                spec.name
5764            );
5765        }
5766    }
5767    let mut group_ids = Vec::with_capacity(n);
5768    for &v in col {
5769        let bits = v.to_bits();
5770        group_ids.push(level_to_col.get(&bits).copied());
5771    }
5772
5773    Ok(RandomEffectBlock {
5774        name: spec.name.clone(),
5775        group_ids,
5776        num_groups: q,
5777        kept_levels,
5778    })
5779}
5780
5781impl SmoothDesign {
5782    /// Map an unconstrained term coefficient vector to its constrained shape space.
5783    /// This is useful for nonlinear fits that optimize unconstrained parameters.
5784    pub fn map_term_coefficients(
5785        unconstrained: &Array1<f64>,
5786        shape: ShapeConstraint,
5787    ) -> Result<Array1<f64>, BasisError> {
5788        if unconstrained.is_empty() {
5789            crate::bail_invalid_basis!("unconstrained coefficient vector cannot be empty");
5790        }
5791        let mapped = match shape {
5792            ShapeConstraint::None => unconstrained.clone(),
5793            ShapeConstraint::MonotoneIncreasing => cumulative_exp(unconstrained, 1.0),
5794            ShapeConstraint::MonotoneDecreasing => cumulative_exp(unconstrained, -1.0),
5795            ShapeConstraint::Convex => second_cumulative_exp(unconstrained, 1.0),
5796            ShapeConstraint::Concave => second_cumulative_exp(unconstrained, -1.0),
5797        };
5798        Ok(mapped)
5799    }
5800}
5801
/// Build products of a single smooth term, bundled together: its design,
/// its active penalty blocks with their per-penalty bookkeeping, and the
/// basis metadata.
///
/// NOTE(review): the code that fills and consumes this struct is outside this
/// excerpt; field notes marked "presumably" are assumptions to confirm.
pub struct LocalSmoothTermBuild {
    /// Presumably the number of coefficient columns of this smooth — TODO confirm.
    pub dim: usize,
    /// Design matrix of this smooth.
    pub design: DesignMatrix,
    /// Active penalty matrices; `ops`, `nullspaces` and `null_eigenvectors`
    /// are indexed in parallel with this vector.
    pub penalties: Vec<Array2<f64>>,
    /// Optional penalty operator for each active penalty, parallel to
    /// `penalties`.
    pub ops: Vec<Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>>,
    /// Null-space dimension of each active penalty, parallel to `penalties`.
    pub nullspaces: Vec<usize>,
    /// Per-active-penalty null-space eigenvector matrices, parallel to
    /// `penalties` / `ops` / `nullspaces`. `Some(U_null)` when
    /// `nullspaces[k] > 0`, with `U_null` orthonormal columns spanning
    /// `null(penalties[k])` in this smooth's local coordinate system; `None`
    /// when the active block is full-rank. Stage 1 plumbing; Stage 2
    /// consumes this to absorb the smooth's null space into the parametric
    /// block at `TermCollectionDesign` construction.
    pub null_eigenvectors: Vec<Option<Array2<f64>>>,
    /// Joint-null absorption rotation for this smooth. `Some(rotation)`
    /// records `Q = [U_range | U_null]` spanning `null(Σ_k penalties[k])`,
    /// the joint null across all active penalty blocks on this smooth.
    /// `None` means the joint penalty is full-rank (joint nullity = 0) or
    /// there are no penalties. Stage-2 commit A: plumbing only — populated
    /// by commit B, applied by commit D.
    pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
    /// Descriptive records for the penalties; presumably one per active
    /// penalty — TODO confirm.
    pub penaltyinfo: Vec<PenaltyInfo>,
    /// Presumably descriptive records for penalty candidates that were
    /// dropped before the active set was formed — TODO confirm.
    pub pre_dropped_penaltyinfo: Vec<PenaltyInfo>,
    /// Basis metadata describing how this smooth was built.
    pub metadata: BasisMetadata,
    /// Optional linear inequality constraints attached to this smooth.
    pub linear_constraints: Option<LinearInequalityConstraints>,
    /// Presumably whether the term uses the box reparameterization for shape
    /// constraints — TODO confirm against the builder.
    pub box_reparam: bool,
    /// Kronecker-factored representation of the basis, when one is available.
    pub kronecker_factored: Option<KroneckerFactoredBasis>,
}
5830
/// Design operator whose entries are read on demand from a memory-mapped
/// `.npy` file of Pca scores.
///
/// Entries are decoded from the mapped bytes as little-endian `f64` values
/// laid out row by row, starting at `data_offset`.
#[derive(Clone)]
pub struct PcaScoresMemmapDesignOperator {
    /// Shared read-only memory map of the whole `.npy` file.
    mmap: Arc<memmap2::Mmap>,
    /// Byte offset of the first matrix entry, as reported by the `.npy`
    /// header parser.
    data_offset: usize,
    /// Number of rows declared in the `.npy` header.
    nrows: usize,
    /// Number of columns declared in the `.npy` header.
    ncols: usize,
    /// Requested number of rows per processing chunk; always at least 1.
    chunk_size: usize,
}
5839
5840impl PcaScoresMemmapDesignOperator {
5841    fn open(path: PathBuf, chunk_size: usize) -> Result<Self, BasisError> {
5842        let file = File::open(&path).map_err(|err| {
5843            BasisError::InvalidInput(format!(
5844                "failed to open lazy Pca .npy scores '{}': {err}",
5845                path.display()
5846            ))
5847        })?;
5848        // The .npy scores file is read-only training-cache data; this
5849        // module never mutates it. The error path below converts mmap
5850        // failure to a typed `BasisError::InvalidInput`.
5851        // SAFETY: `memmap2::Mmap::map` requires no concurrent writers; the
5852        // contract is held by this module's read-only access pattern.
5853        let mmap = unsafe {
5854            memmap2::Mmap::map(&file).map_err(|err| {
5855                BasisError::InvalidInput(format!(
5856                    "failed to memmap lazy Pca .npy scores '{}': {err}",
5857                    path.display()
5858                ))
5859            })?
5860        };
5861        let (data_offset, nrows, ncols) = parse_f64_2d_npy_header(&mmap, &path)?;
5862        let expected = data_offset
5863            .checked_add(nrows.saturating_mul(ncols).saturating_mul(8))
5864            .ok_or_else(|| {
5865                BasisError::InvalidInput(format!(
5866                    "lazy Pca .npy scores '{}' shape is too large",
5867                    path.display()
5868                ))
5869            })?;
5870        if mmap.len() < expected {
5871            crate::bail_invalid_basis!(
5872                "lazy Pca .npy scores '{}' is truncated: header expects {} bytes, file has {}",
5873                path.display(),
5874                expected,
5875                mmap.len()
5876            );
5877        }
5878        Ok(Self {
5879            mmap: Arc::new(mmap),
5880            data_offset,
5881            nrows,
5882            ncols,
5883            chunk_size: chunk_size.max(1),
5884        })
5885    }
5886
5887    fn value(&self, row: usize, col: usize) -> f64 {
5888        let offset = self.data_offset + (row * self.ncols + col) * 8;
5889        let mut bytes = [0_u8; 8];
5890        bytes.copy_from_slice(&self.mmap[offset..offset + 8]);
5891        f64::from_le_bytes(bytes)
5892    }
5893
5894    fn chunk_rows(&self) -> usize {
5895        self.chunk_size.min(self.nrows.max(1))
5896    }
5897}
5898
5899impl LinearOperator for PcaScoresMemmapDesignOperator {
5900    fn nrows(&self) -> usize {
5901        self.nrows
5902    }
5903
5904    fn ncols(&self) -> usize {
5905        self.ncols
5906    }
5907
5908    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
5909        assert_eq!(
5910            vector.len(),
5911            self.ncols,
5912            "lazy Pca apply vector length mismatch"
5913        );
5914        let mut out = Array1::<f64>::zeros(self.nrows);
5915        for start in (0..self.nrows).step_by(self.chunk_rows()) {
5916            let end = (start + self.chunk_rows()).min(self.nrows);
5917            for row in start..end {
5918                let mut acc = 0.0;
5919                for col in 0..self.ncols {
5920                    acc += self.value(row, col) * vector[col];
5921                }
5922                out[row] = acc;
5923            }
5924        }
5925        out
5926    }
5927
5928    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
5929        assert_eq!(
5930            vector.len(),
5931            self.nrows,
5932            "lazy Pca apply_transpose vector length mismatch"
5933        );
5934        let mut out = Array1::<f64>::zeros(self.ncols);
5935        for start in (0..self.nrows).step_by(self.chunk_rows()) {
5936            let end = (start + self.chunk_rows()).min(self.nrows);
5937            for row in start..end {
5938                let scale = vector[row];
5939                if scale == 0.0 {
5940                    continue;
5941                }
5942                for col in 0..self.ncols {
5943                    out[col] += scale * self.value(row, col);
5944                }
5945            }
5946        }
5947        out
5948    }
5949
5950    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
5951        if weights.len() != self.nrows {
5952            return Err(format!(
5953                "lazy Pca diag_xtw_x weight length mismatch: weights={}, nrows={}",
5954                weights.len(),
5955                self.nrows
5956            ));
5957        }
5958        let mut gram = Array2::<f64>::zeros((self.ncols, self.ncols));
5959        for start in (0..self.nrows).step_by(self.chunk_rows()) {
5960            let end = (start + self.chunk_rows()).min(self.nrows);
5961            for row in start..end {
5962                let w = weights[row];
5963                if w == 0.0 {
5964                    continue;
5965                }
5966                for a in 0..self.ncols {
5967                    let xa = self.value(row, a);
5968                    if xa == 0.0 {
5969                        continue;
5970                    }
5971                    for b in a..self.ncols {
5972                        gram[[a, b]] += w * xa * self.value(row, b);
5973                    }
5974                }
5975            }
5976        }
5977        for a in 0..self.ncols {
5978            for b in 0..a {
5979                gram[[a, b]] = gram[[b, a]];
5980            }
5981        }
5982        Ok(gram)
5983    }
5984
5985    fn apply_weighted_normal(
5986        &self,
5987        weights: &Array1<f64>,
5988        vector: &Array1<f64>,
5989        penalty: Option<&Array2<f64>>,
5990        ridge: f64,
5991    ) -> Array1<f64> {
5992        assert_eq!(
5993            weights.len(),
5994            self.nrows,
5995            "lazy Pca weighted-normal weight mismatch"
5996        );
5997        assert_eq!(
5998            vector.len(),
5999            self.ncols,
6000            "lazy Pca weighted-normal vector mismatch"
6001        );
6002        let mut out = Array1::<f64>::zeros(self.ncols);
6003        for start in (0..self.nrows).step_by(self.chunk_rows()) {
6004            let end = (start + self.chunk_rows()).min(self.nrows);
6005            for row in start..end {
6006                let w = weights[row].max(0.0);
6007                if w == 0.0 {
6008                    continue;
6009                }
6010                let mut row_dot = 0.0;
6011                for col in 0..self.ncols {
6012                    row_dot += self.value(row, col) * vector[col];
6013                }
6014                if row_dot == 0.0 {
6015                    continue;
6016                }
6017                let scaled = w * row_dot;
6018                for col in 0..self.ncols {
6019                    out[col] += scaled * self.value(row, col);
6020                }
6021            }
6022        }
6023        if let Some(pen) = penalty {
6024            out += &pen.dot(vector);
6025        }
6026        if ridge > 0.0 {
6027            out += &vector.mapv(|x| ridge * x);
6028        }
6029        out
6030    }
6031}
6032
6033impl DenseDesignOperator for PcaScoresMemmapDesignOperator {
6034    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
6035        if weights.len() != self.nrows || y.len() != self.nrows {
6036            return Err(format!(
6037                "lazy Pca compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
6038                weights.len(),
6039                y.len(),
6040                self.nrows
6041            ));
6042        }
6043        let mut out = Array1::<f64>::zeros(self.ncols);
6044        for start in (0..self.nrows).step_by(self.chunk_rows()) {
6045            let end = (start + self.chunk_rows()).min(self.nrows);
6046            for row in start..end {
6047                let scale = weights[row] * y[row];
6048                if scale == 0.0 {
6049                    continue;
6050                }
6051                for col in 0..self.ncols {
6052                    out[col] += scale * self.value(row, col);
6053                }
6054            }
6055        }
6056        Ok(out)
6057    }
6058
6059    fn row_chunk_into(
6060        &self,
6061        rows: Range<usize>,
6062        mut out: ArrayViewMut2<'_, f64>,
6063    ) -> Result<(), MatrixMaterializationError> {
6064        if rows.end > self.nrows || rows.start > rows.end {
6065            return Err(MatrixMaterializationError::MissingRowChunk {
6066                context: "lazy Pca row range out of bounds",
6067            });
6068        }
6069        if out.nrows() != rows.end - rows.start || out.ncols() != self.ncols {
6070            return Err(MatrixMaterializationError::MissingRowChunk {
6071                context: "lazy Pca row_chunk_into shape mismatch",
6072            });
6073        }
6074        for (local, row) in (rows.start..rows.end).enumerate() {
6075            for col in 0..self.ncols {
6076                out[[local, col]] = self.value(row, col);
6077            }
6078        }
6079        Ok(())
6080    }
6081
6082    fn to_dense(&self) -> Array2<f64> {
6083        let mut out = Array2::<f64>::zeros((self.nrows, self.ncols));
6084        self.row_chunk_into(0..self.nrows, out.view_mut())
6085            .expect("lazy Pca full materialization failed");
6086        out
6087    }
6088}
6089
6090pub fn parse_f64_2d_npy_header(
6091    bytes: &[u8],
6092    path: &PathBuf,
6093) -> Result<(usize, usize, usize), BasisError> {
6094    if bytes.len() < 10 || &bytes[0..6] != b"\x93NUMPY" {
6095        crate::bail_invalid_basis!("lazy Pca scores '{}' is not a .npy file", path.display());
6096    }
6097    let major = bytes[6];
6098    let header_len = match major {
6099        1 => u16::from_le_bytes([bytes[8], bytes[9]]) as usize,
6100        2 | 3 => {
6101            if bytes.len() < 12 {
6102                crate::bail_invalid_basis!(
6103                    "lazy Pca scores '{}' has a truncated .npy header",
6104                    path.display()
6105                );
6106            }
6107            u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize
6108        }
6109        other => {
6110            crate::bail_invalid_basis!(
6111                "lazy Pca scores '{}' uses unsupported .npy version {}",
6112                path.display(),
6113                other
6114            );
6115        }
6116    };
6117    let header_start = if major == 1 { 10 } else { 12 };
6118    let data_offset = header_start + header_len;
6119    if bytes.len() < data_offset {
6120        crate::bail_invalid_basis!(
6121            "lazy Pca scores '{}' has a truncated .npy header",
6122            path.display()
6123        );
6124    }
6125    let header = std::str::from_utf8(&bytes[header_start..data_offset]).map_err(|err| {
6126        BasisError::InvalidInput(format!(
6127            "lazy Pca scores '{}' has a non-UTF8 .npy header: {err}",
6128            path.display()
6129        ))
6130    })?;
6131    if !(header.contains("'descr': '<f8'")
6132        || header.contains("\"descr\": \"<f8\"")
6133        || header.contains("'descr': '|f8'")
6134        || header.contains("\"descr\": \"|f8\""))
6135    {
6136        crate::bail_invalid_basis!(
6137            "lazy Pca scores '{}' must be float64 little-endian .npy",
6138            path.display()
6139        );
6140    }
6141    if header.contains("True") {
6142        crate::bail_invalid_basis!(
6143            "lazy Pca scores '{}' must be C-contiguous, not Fortran-ordered",
6144            path.display()
6145        );
6146    }
6147    let shape_pos = header.find("shape").ok_or_else(|| {
6148        BasisError::InvalidInput(format!(
6149            "lazy Pca scores '{}' .npy header is missing shape",
6150            path.display()
6151        ))
6152    })?;
6153    let open = header[shape_pos..].find('(').ok_or_else(|| {
6154        BasisError::InvalidInput(format!(
6155            "lazy Pca scores '{}' .npy header has malformed shape",
6156            path.display()
6157        ))
6158    })? + shape_pos;
6159    let close = header[open..].find(')').ok_or_else(|| {
6160        BasisError::InvalidInput(format!(
6161            "lazy Pca scores '{}' .npy header has malformed shape",
6162            path.display()
6163        ))
6164    })? + open;
6165    let dims = header[open + 1..close]
6166        .split(',')
6167        .map(str::trim)
6168        .filter(|part| !part.is_empty())
6169        .map(|part| part.parse::<usize>())
6170        .collect::<Result<Vec<_>, _>>()
6171        .map_err(|err| {
6172            BasisError::InvalidInput(format!(
6173                "lazy Pca scores '{}' .npy shape is not integral: {err}",
6174                path.display()
6175            ))
6176        })?;
6177    if dims.len() != 2 {
6178        crate::bail_invalid_basis!(
6179            "lazy Pca scores '{}' must have shape (N, K), got {:?}",
6180            path.display(),
6181            dims
6182        );
6183    }
6184    Ok((data_offset, dims[0], dims[1]))
6185}
6186
6187pub fn pca_center_mean(x: ArrayView2<'_, f64>) -> Result<Array1<f64>, BasisError> {
6188    if x.nrows() == 0 {
6189        crate::bail_invalid_basis!("Pca basis requires at least one row to compute center mean");
6190    }
6191    let mut mean = Array1::<f64>::zeros(x.ncols());
6192    for row in x.rows() {
6193        mean += &row;
6194    }
6195    mean.mapv_inplace(|v| v / x.nrows() as f64);
6196    Ok(mean)
6197}
6198
6199pub fn build_pca_smooth_basis(
6200    data: ArrayView2<'_, f64>,
6201    feature_cols: &[usize],
6202    basis_matrix: &Array2<f64>,
6203    centered: bool,
6204    smooth_penalty: f64,
6205    center_mean: Option<&Array1<f64>>,
6206    pca_basis_path: Option<&PathBuf>,
6207    chunk_size: usize,
6208) -> Result<BasisBuildResult, BasisError> {
6209    if let Some(path) = pca_basis_path {
6210        let op = PcaScoresMemmapDesignOperator::open(path.clone(), chunk_size)?;
6211        if op.nrows != data.nrows() {
6212            crate::bail_dim_basis!(
6213                "lazy Pca scores row mismatch: .npy has {}, data has {}",
6214                op.nrows,
6215                data.nrows()
6216            );
6217        }
6218        let k = op.ncols;
6219        let mut penalty = Array2::<f64>::eye(k);
6220        penalty.mapv_inplace(|v| v * smooth_penalty);
6221        let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
6222            filter_active_penalty_candidates_with_ops(vec![PenaltyCandidate {
6223                matrix: penalty,
6224                nullspace_dim_hint: 0,
6225                source: PenaltySource::Other("PcaRidge".to_string()),
6226                normalization_scale: 1.0,
6227                kronecker_factors: None,
6228                op: None,
6229            }])?;
6230        return Ok(BasisBuildResult {
6231            design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(op))),
6232            penalties,
6233            nullspace_dims,
6234            penaltyinfo,
6235            ops,
6236            null_eigenvectors,
6237            joint_null_rotation: None,
6238            metadata: BasisMetadata::Pca {
6239                feature_cols: feature_cols.to_vec(),
6240                basis_matrix: basis_matrix.clone(),
6241                centered,
6242                smooth_penalty,
6243                center_mean: center_mean.cloned(),
6244                pca_basis_path: Some(path.clone()),
6245                chunk_size: chunk_size.max(1),
6246            },
6247            kronecker_factored: None,
6248        });
6249    }
6250    if basis_matrix.nrows() != feature_cols.len() {
6251        crate::bail_dim_basis!(
6252            "Pca basis row mismatch: basis rows={}, feature columns={}",
6253            basis_matrix.nrows(),
6254            feature_cols.len()
6255        );
6256    }
6257    let mut x = select_columns(data, feature_cols)?;
6258    let mean = if centered {
6259        match center_mean {
6260            Some(mean) => mean.clone(),
6261            None => pca_center_mean(x.view())?,
6262        }
6263    } else {
6264        Array1::<f64>::zeros(feature_cols.len())
6265    };
6266    if centered {
6267        for mut row in x.rows_mut() {
6268            row -= &mean;
6269        }
6270    }
6271    let design = fast_ab(&x, basis_matrix);
6272    let k = basis_matrix.ncols();
6273    let mut penalty = Array2::<f64>::eye(k);
6274    penalty.mapv_inplace(|v| v * smooth_penalty);
6275    let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
6276        filter_active_penalty_candidates_with_ops(vec![PenaltyCandidate {
6277            matrix: penalty,
6278            nullspace_dim_hint: 0,
6279            source: PenaltySource::Other("PcaRidge".to_string()),
6280            normalization_scale: 1.0,
6281            kronecker_factors: None,
6282            op: None,
6283        }])?;
6284    Ok(BasisBuildResult {
6285        design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(design)),
6286        penalties,
6287        nullspace_dims,
6288        penaltyinfo,
6289        ops,
6290        null_eigenvectors,
6291        joint_null_rotation: None,
6292        metadata: BasisMetadata::Pca {
6293            feature_cols: feature_cols.to_vec(),
6294            basis_matrix: basis_matrix.clone(),
6295            centered,
6296            smooth_penalty,
6297            center_mean: centered.then_some(mean),
6298            pca_basis_path: None,
6299            chunk_size: chunk_size.max(1),
6300        },
6301        kronecker_factored: None,
6302    })
6303}
6304
6305/// A factor-level `by=` wrapper owns the model-space centering of its inner
6306/// smooth: it gates the raw/structurally-constrained basis to the level rows
6307/// and then centers that gated block exactly once against the level indicator
6308/// (`build_parametric_constraint_block_for_term` in `design_construction`).
6309/// Leaving the inner B-spline's default pooled weighted-sum-to-zero active here
6310/// would impose two generically-independent constraints — the pooled column
6311/// moment `m = Σ_h m_h` and the per-level moment `m_g` — so a raw `k`-column
6312/// basis collapses to `k-2` columns per level instead of `k-1`, deleting one
6313/// genuine nonconstant spline direction *before REML runs* (#1427). The group
6314/// main effect carries only the constant, so it cannot restore that direction.
6315///
6316/// Only the *default model-space* centering is deferred. Explicit structural or
6317/// frozen transforms (`RemoveLinearTrend`, `OrthogonalToDesignColumns`,
6318/// `FrozenTransform`, `None`) are user/structural choices and are preserved
6319/// verbatim.
6320pub fn defer_inner_model_centering_to_factor_level_wrapper(basis: &mut SmoothBasisSpec) {
6321    if let SmoothBasisSpec::BSpline1D { spec, .. } = basis
6322        && matches!(
6323            spec.identifiability,
6324            BSplineIdentifiability::WeightedSumToZero { .. }
6325        )
6326    {
6327        spec.identifiability = BSplineIdentifiability::None;
6328    }
6329}
6330
6331pub fn apply_by_variable_to_local_build(
6332    mut built: LocalSmoothTermBuild,
6333    data: ArrayView2<'_, f64>,
6334    by_col: usize,
6335    by: &ByVariableSpec,
6336    term_name: &str,
6337) -> Result<LocalSmoothTermBuild, BasisError> {
6338    if by_col >= data.ncols() {
6339        crate::bail_dim_basis!(
6340            "by-variable smooth term '{term_name}' references column {by_col}, but data has {} columns",
6341            data.ncols()
6342        );
6343    }
6344    let weights = match by {
6345        ByVariableSpec::Numeric => data.column(by_col).to_owned(),
6346        ByVariableSpec::Level { value_bits, .. } => data.column(by_col).mapv(|value| {
6347            if value.to_bits() == *value_bits {
6348                1.0
6349            } else {
6350                0.0
6351            }
6352        }),
6353    };
6354    if weights.iter().any(|value| !value.is_finite()) {
6355        crate::bail_invalid_basis!(
6356            "by-variable smooth term '{term_name}' has non-finite by-column values"
6357        );
6358    }
6359
6360    let mut dense = built
6361        .design
6362        .try_to_dense_by_chunks("by-variable smooth row gating")
6363        .map_err(BasisError::InvalidInput)?;
6364    for (mut row, &weight) in dense.rows_mut().into_iter().zip(weights.iter()) {
6365        row.mapv_inplace(|value| value * weight);
6366    }
6367    built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense));
6368    built.kronecker_factored = None;
6369    Ok(built)
6370}
6371
6372/// Build the local smooth term for a `BySmooth` spec, which unifies numeric-by
6373/// and factor-by modulation into a single `SmoothTermSpec`.
6374///
6375/// For a **numeric** by-variable the inner smooth is built once and every row
6376/// is multiplied by the by-column value (identical to `ByVariable::Numeric`).
6377///
6378/// For a **factor** by-variable the inner smooth is built once and gated per
6379/// level into side-by-side column blocks, producing a `n × (L * p)` design
6380/// matrix.  The penalties are block-diagonalised (one copy of the inner penalty
6381/// per level) exactly as `build_factor_smooth` does for `bs="fs"/"sz"`.
6382pub fn build_by_smooth_local(
6383    data: ArrayView2<'_, f64>,
6384    term: &SmoothTermSpec,
6385    smooth: &SmoothBasisSpec,
6386    by_kind: &ByVarKind,
6387    workspace: &mut crate::basis::BasisWorkspace,
6388) -> Result<LocalSmoothTermBuild, BasisError> {
6389    let inner_term = SmoothTermSpec {
6390        name: term.name.clone(),
6391        basis: (*smooth).clone(),
6392        shape: term.shape,
6393        joint_null_rotation: None,
6394    };
6395    let inner = build_single_local_smooth_term(data, &inner_term, workspace)?;
6396
6397    match by_kind {
6398        ByVarKind::Numeric { feature_col } => {
6399            let inner_meta = inner.metadata.clone();
6400            let mut built = apply_by_variable_to_local_build(
6401                inner,
6402                data,
6403                *feature_col,
6404                &ByVariableSpec::Numeric,
6405                &term.name,
6406            )?;
6407            built.metadata = BasisMetadata::BySmooth {
6408                inner: Box::new(inner_meta),
6409                by_col: *feature_col,
6410                levels: None,
6411                ordered: false,
6412            };
6413            Ok(built)
6414        }
6415        ByVarKind::Factor {
6416            feature_col,
6417            frozen_levels,
6418            ordered,
6419        } => {
6420            // Collect factor levels: prefer the frozen set (replay path), else
6421            // scan the data column (first-fit path).
6422            let level_bits: Vec<u64> = if let Some(fl) = frozen_levels {
6423                fl.clone()
6424            } else {
6425                let col = data.column(*feature_col);
6426                let mut seen = BTreeSet::<u64>::new();
6427                for &v in col.iter() {
6428                    if v.is_finite() {
6429                        seen.insert(v.to_bits());
6430                    }
6431                }
6432                seen.into_iter().collect()
6433            };
6434            let n_levels = level_bits.len();
6435            if n_levels == 0 {
6436                crate::bail_invalid_basis!(
6437                    "by-factor smooth term '{}': factor column {} has no observed levels",
6438                    term.name,
6439                    feature_col
6440                );
6441            }
6442            let p = inner.dim;
6443            let q = n_levels * p;
6444            let n = data.nrows();
6445
6446            let inner_dense = inner
6447                .design
6448                .try_to_dense_by_chunks("by-factor smooth design gating")
6449                .map_err(BasisError::InvalidInput)?;
6450
6451            // Gate each level into its own p-wide column block.
6452            let mut combined = Array2::<f64>::zeros((n, q));
6453            for (lvl_idx, &bits) in level_bits.iter().enumerate() {
6454                let col_start = lvl_idx * p;
6455                for row in 0..n {
6456                    if data[[row, *feature_col]].to_bits() == bits {
6457                        combined
6458                            .slice_mut(s![row, col_start..col_start + p])
6459                            .assign(&inner_dense.row(row));
6460                    }
6461                }
6462            }
6463
6464            // Build per-level INDEPENDENT penalties (#1427): one copy of each
6465            // inner penalty per level, but each confined to that single level's
6466            // diagonal block, so every (level, inner-penalty) pair is its OWN
6467            // smoothing-parameter coordinate. `s(x, by=g)` selects the per-group
6468            // curve wiggliness independently — the design is block-diagonal and
6469            // block-separable, so a correct REML must reproduce gamfit's own
6470            // independent per-group fits. Tiling a single inner penalty across
6471            // every level (as the `bs="fs"` shared-λ random-effect construction
6472            // does) collapses all groups onto ONE λ, which cannot match uneven
6473            // per-level smoothness and degrades as data grows (under-recovery up
6474            // to ~16× at n=2000). Emit `n_levels * n_penalties` blocks instead.
6475            let inner_meta = inner.metadata.clone();
6476            let n_penalties = inner.penalties.len();
6477            let n_blocks = n_penalties.saturating_mul(n_levels);
6478            let mut penalties = Vec::<Array2<f64>>::with_capacity(n_blocks);
6479            let mut penaltyinfo = Vec::<PenaltyInfo>::with_capacity(n_blocks);
6480            let mut nullspaces = Vec::<usize>::with_capacity(n_blocks);
6481            for (pen_pos, s_inner) in inner.penalties.iter().enumerate() {
6482                for lvl in 0..n_levels {
6483                    let off = lvl * p;
6484                    let mut s_big = Array2::<f64>::zeros((q, q));
6485                    s_big
6486                        .slice_mut(s![off..off + p, off..off + p])
6487                        .assign(s_inner);
6488                    let (s_big, scale) = normalize_penalty_in_constrained_space(&s_big);
6489                    let mut info = inner.penaltyinfo[pen_pos].clone();
6490                    // Distinct original_index per (penalty, level) so each λ is a
6491                    // separate identifiable coordinate downstream.
6492                    info.original_index = pen_pos * n_levels + lvl;
6493                    info.normalization_scale *= scale;
6494                    // Each block now spans exactly ONE level → per-level nullity,
6495                    // not the tiled (× n_levels) hint of the shared construction.
6496                    info.kronecker_factors = None;
6497                    penalties.push(s_big);
6498                    penaltyinfo.push(info);
6499                    nullspaces.push(inner.nullspaces[pen_pos]);
6500                }
6501            }
6502
6503            let null_eigenvectors = vec![None; penalties.len()];
6504            let ops = vec![None; penalties.len()];
6505
6506            Ok(LocalSmoothTermBuild {
6507                dim: q,
6508                design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(combined)),
6509                penalties,
6510                ops,
6511                nullspaces,
6512                null_eigenvectors,
6513                joint_null_rotation: None,
6514                penaltyinfo,
6515                pre_dropped_penaltyinfo: inner.pre_dropped_penaltyinfo,
6516                metadata: BasisMetadata::BySmooth {
6517                    inner: Box::new(inner_meta),
6518                    by_col: *feature_col,
6519                    levels: Some(level_bits),
6520                    ordered: *ordered,
6521                },
6522                linear_constraints: None,
6523                box_reparam: false,
6524                kronecker_factored: None,
6525            })
6526        }
6527    }
6528}
6529
6530pub fn ensure_by_variable_specs_match(
6531    kind: &BySmoothKind,
6532    by: &ByVariableSpec,
6533    term_name: &str,
6534) -> Result<(), BasisError> {
6535    match (kind, by) {
6536        (BySmoothKind::Numeric, ByVariableSpec::Numeric) => Ok(()),
6537        (BySmoothKind::Level { level_bits }, ByVariableSpec::Level { value_bits, .. })
6538            if level_bits == value_bits =>
6539        {
6540            Ok(())
6541        }
6542        _ => Err(BasisError::InvalidInput(format!(
6543            "by-variable smooth term '{term_name}' has inconsistent by-variable specifications"
6544        ))),
6545    }
6546}
6547
6548/// Build a factor-smooth interaction basis (`bs="fs"`/`"sz"`/`"re"`).
6549///
6550/// A factor smooth replicates a shared marginal smooth in the continuous
6551/// covariate(s) once per level of a grouping factor, coupling all level blocks
6552/// through a *single* set of smoothing parameters (one per marginal penalty).
6553/// This is mgcv's `smooth.construct.fs.smooth.spec` realization and the
6554/// random-effect interpretation of a smooth: the per-level deviations are an
6555/// exchangeable family whose joint wiggliness/shrinkage is governed by the
6556/// shared λ, so the construction scales to many levels with a fixed parameter
6557/// count.
6558///
6559/// Flavours:
6560/// * `Fs` — full random factor-smooth. The marginal carries its wiggliness
6561///   penalty *and* a null-space ridge (double penalty), so the replicated
6562///   design is a proper full-rank random effect: each level's curve is shrunk
6563///   toward zero (intercept + linear trend included), recovering the mgcv
6564///   `bs="fs"` penalty structure `I_L ⊗ S_j` for every marginal penalty `S_j`.
6565/// * `Sz` — sum-to-zero factor smooth. Delegates to the existing
6566///   [`SmoothBasisSpec::FactorSumToZero`] construction (`L-1` deviation blocks,
6567///   coefficient-wise zero sum across levels).
6568/// * `Re` — pure random effect / random slope (`bs="re"`). A degree-1 marginal
6569///   gives the per-level `[1, x]` span; the penalty is the identity over each
6570///   level block (iid Gaussian coefficients), matching mgcv's `bs="re"` ridge.
6571///
6572/// The grouping levels are resolved once at fit time (sorted unique bit
6573/// patterns of the factor column) and frozen into the returned metadata so the
6574/// predict-time rebuild evaluates every row against its own level's block.
6575pub fn build_factor_smooth(
6576    data: ArrayView2<'_, f64>,
6577    spec: &FactorSmoothSpec,
6578    term_name: &str,
6579    workspace: &mut crate::basis::BasisWorkspace,
6580) -> Result<LocalSmoothTermBuild, BasisError> {
6581    if spec.continuous_cols.len() != 1 {
6582        crate::bail_invalid_basis!(
6583            "factor smooth term '{}' currently supports exactly one continuous covariate; found {}",
6584            term_name,
6585            spec.continuous_cols.len()
6586        );
6587    }
6588    let feature_col = spec.continuous_cols[0];
6589    let group_col = spec.group_col;
6590    if feature_col >= data.ncols() || group_col >= data.ncols() {
6591        crate::bail_dim_basis!(
6592            "factor smooth term '{}' references columns ({}, {}) out of bounds for {} columns",
6593            term_name,
6594            feature_col,
6595            group_col,
6596            data.ncols()
6597        );
6598    }
6599
6600    // `Sz` is exactly the existing sum-to-zero factor smooth: reuse it verbatim
6601    // so there is a single source of truth for the zero-sum construction.
6602    if matches!(spec.flavour, FactorSmoothFlavour::Sz) {
6603        let levels = resolve_factor_smooth_levels(data, group_col, spec, term_name)?;
6604        let inner = SmoothBasisSpec::BSpline1D {
6605            feature_col,
6606            spec: factor_smooth_marginal_for_replay(&spec.marginal),
6607        };
6608        let sz_term = SmoothTermSpec {
6609            name: term_name.to_string(),
6610            basis: SmoothBasisSpec::FactorSumToZero {
6611                inner: Box::new(inner),
6612                by_col: group_col,
6613                levels: levels.clone(),
6614                frozen_global_orthogonality: None,
6615            },
6616            shape: ShapeConstraint::None,
6617            joint_null_rotation: None,
6618        };
6619        let mut built = build_single_local_smooth_term(data, &sz_term, workspace)?;
6620        // The delegated `FactorSumToZero` build returns the BARE inner B-spline
6621        // metadata (`BasisMetadata::BSpline1D`), but the term that owns this
6622        // build carries a `SmoothBasisSpec::FactorSmooth { Sz }` spec. Two
6623        // things break if we hand that mismatched pair downstream:
6624        //   1. `freeze_smooth_basis_from_metadata` matches on (spec, metadata)
6625        //      and has no `(FactorSmooth, BSpline1D)` arm, so any refit / spatial
6626        //      re-optimization that freezes the basis aborts with a "smooth
6627        //      metadata/spec type mismatch" error.
6628        //   2. The bare B-spline metadata carries no grouping levels, so a
6629        //      predict-time rebuild cannot replay the SAME replicated design.
6630        // Re-wrap the marginal geometry as `FactorSmooth` metadata exactly as
6631        // the Fs/Re path below does, giving all three factor-smooth flavours a
6632        // single, freeze-consistent metadata shape that also pins the levels.
6633        // Since #1605 the sz marginal is ALWAYS the penalized B-spline the `fs`
6634        // sibling uses (a natural cubic regression marginal hard-enforces f''=0
6635        // at the boundary and cannot represent curved deviations — a consistency
6636        // failure). The `CubicRegression1D` arm below is therefore unreachable on
6637        // a freshly-built sz spec; it is retained only as defense / backward
6638        // compatibility for a frozen spec that still carries a cr marginal, so
6639        // the predict-time freeze restores whatever marginal class it finds.
6640        let (knots, degree, periodic, marginal_is_cr) = match &built.metadata {
6641            BasisMetadata::BSpline1D {
6642                knots,
6643                periodic,
6644                degree,
6645                ..
6646            } => (
6647                knots.clone(),
6648                degree.unwrap_or(spec.marginal.degree),
6649                *periodic,
6650                false,
6651            ),
6652            BasisMetadata::CubicRegression1D { knots, .. } => {
6653                (knots.clone(), spec.marginal.degree, None, true)
6654            }
6655            other => {
6656                crate::bail_invalid_basis!(
6657                    "sz factor smooth term '{}' produced an unexpected marginal metadata variant {:?}",
6658                    term_name,
6659                    other
6660                );
6661            }
6662        };
6663        built.metadata = BasisMetadata::FactorSmooth {
6664            continuous_cols: spec.continuous_cols.clone(),
6665            group_col,
6666            knots,
6667            degree,
6668            periodic,
6669            group_levels: levels,
6670            flavour: "sz".to_string(),
6671            marginal_is_cr,
6672        };
6673        return Ok(built);
6674    }
6675
6676    let levels = resolve_factor_smooth_levels(data, group_col, spec, term_name)?;
6677    let n_levels = levels.len();
6678    if n_levels < 2 {
6679        crate::bail_invalid_basis!(
6680            "factor smooth term '{}' requires at least two grouping levels; found {}",
6681            term_name,
6682            n_levels
6683        );
6684    }
6685
6686    // `Fs` (order ≥ 1, the default) is the random-effect flavour: it penalizes
6687    // each null-space dimension of the marginal wiggliness penalty separately
6688    // below (mgcv's `bs="fs"` construction). That replaces the marginal's single
6689    // *combined* double penalty, so disable the latter here to avoid penalizing
6690    // the null space twice (once combined, once per dimension). The explicit
6691    // `m=0` opt-out keeps the legacy combined double penalty and adds no
6692    // per-dimension penalties.
6693    let use_per_dim_null = matches!(
6694        &spec.flavour,
6695        FactorSmoothFlavour::Fs { m_null_penalty_orders }
6696            if m_null_penalty_orders.iter().copied().max().unwrap_or(0) >= 1
6697    );
6698
6699    // Build the shared marginal design + penalties from the 1-D B-spline.
6700    // `Re` forces a degree-1 marginal (linear span) and replaces the marginal
6701    // wiggliness with an identity ridge below; `Fs` keeps the user's marginal
6702    // (cubic by default) and, under the per-dimension null path, gets its null
6703    // space penalized one dimension at a time after replication.
6704    let mut marginal_spec = factor_smooth_marginal_for_replay(&spec.marginal);
6705    if use_per_dim_null {
6706        marginal_spec.double_penalty = false;
6707    }
6708    let inner_term = SmoothTermSpec {
6709        name: format!("{term_name}::marginal"),
6710        basis: SmoothBasisSpec::BSpline1D {
6711            feature_col,
6712            spec: marginal_spec,
6713        },
6714        shape: ShapeConstraint::None,
6715        joint_null_rotation: None,
6716    };
6717    let inner = build_single_local_smooth_term(data, &inner_term, workspace)?;
6718    let base = inner
6719        .design
6720        .try_to_dense_by_chunks("factor smooth marginal")
6721        .map_err(BasisError::InvalidInput)?;
6722    let n = base.nrows();
6723    let p = base.ncols();
6724    let q = p * n_levels;
6725
6726    // Block-diagonal replicated design: row i contributes its marginal row to
6727    // the column block owned by its grouping level, zeros elsewhere.
6728    let mut dense = Array2::<f64>::zeros((n, q));
6729    for i in 0..n {
6730        let bits = data[[i, group_col]].to_bits();
6731        let level_idx = levels.iter().position(|b| *b == bits).ok_or_else(|| {
6732            BasisError::InvalidInput(format!(
6733                "factor smooth term '{term_name}' saw an unseen grouping level at row {}",
6734                i + 1
6735            ))
6736        })?;
6737        let start = level_idx * p;
6738        dense
6739            .slice_mut(s![i, start..start + p])
6740            .assign(&base.row(i));
6741    }
6742
6743    // Penalties: replicate each marginal penalty into a block-diagonal
6744    // `I_L ⊗ S_j` so every level shares the same smoothing parameter λ_j (one
6745    // λ per marginal penalty), the defining feature of a factor smooth. For
6746    // `Re` the marginal penalty is replaced by an identity ridge so each
6747    // per-level coefficient is an iid Gaussian random effect.
6748    let marginal_penalties: Vec<Array2<f64>> = if matches!(spec.flavour, FactorSmoothFlavour::Re) {
6749        vec![Array2::<f64>::eye(p)]
6750    } else {
6751        inner.penalties.clone()
6752    };
6753    let marginal_penaltyinfo: Vec<PenaltyInfo> = if matches!(spec.flavour, FactorSmoothFlavour::Re)
6754    {
6755        vec![PenaltyInfo {
6756            source: PenaltySource::Primary,
6757            original_index: 0,
6758            active: true,
6759            effective_rank: p,
6760            dropped_reason: None,
6761            nullspace_dim_hint: 0,
6762            normalization_scale: 1.0,
6763            kronecker_factors: None,
6764        }]
6765    } else {
6766        inner.penaltyinfo.clone()
6767    };
6768    if marginal_penalties.len() != marginal_penaltyinfo.len() {
6769        crate::bail_invalid_basis!(
6770            "internal factor-smooth penalty metadata mismatch for term '{}': penalties={}, infos={}",
6771            term_name,
6772            marginal_penalties.len(),
6773            marginal_penaltyinfo.len()
6774        );
6775    }
6776
6777    let mut penalties = Vec::<Array2<f64>>::with_capacity(marginal_penalties.len());
6778    let mut penaltyinfo = Vec::<PenaltyInfo>::with_capacity(marginal_penalties.len());
6779    for (penalty_pos, s_inner) in marginal_penalties.iter().enumerate() {
6780        let mut s_big = Array2::<f64>::zeros((q, q));
6781        for level in 0..n_levels {
6782            let start = level * p;
6783            s_big
6784                .slice_mut(s![start..start + p, start..start + p])
6785                .assign(s_inner);
6786        }
6787        let (s_big, factor_smooth_scale) = normalize_penalty_in_constrained_space(&s_big);
6788        let mut info = marginal_penaltyinfo[penalty_pos].clone();
6789        info.original_index = penalty_pos;
6790        info.normalization_scale *= factor_smooth_scale;
6791        info.nullspace_dim_hint = info.nullspace_dim_hint.saturating_mul(n_levels);
6792        info.kronecker_factors = None;
6793        penalties.push(s_big);
6794        penaltyinfo.push(info);
6795    }
6796
6797    let mut nullspaces: Vec<usize> = if matches!(spec.flavour, FactorSmoothFlavour::Re) {
6798        vec![0]
6799    } else {
6800        inner
6801            .nullspaces
6802            .iter()
6803            .map(|ns| ns.saturating_mul(n_levels))
6804            .collect()
6805    };
6806
6807    // `Fs` is the random-effect flavour of a smooth: the per-group curve is an
6808    // exchangeable Gaussian *function*, so EVERY coefficient — including the
6809    // {const, linear} null space of the marginal wiggliness penalty — must be
6810    // shrinkable toward zero under its own shared variance. The wiggliness
6811    // penalty `S_wiggle` shapes curvature but leaves the per-group intercept and
6812    // slope (its null space) completely UNPENALIZED. With the null space free,
6813    // each group fits its own intercept and slope with NO partial pooling, so
6814    // the held-out per-subject forecast inherits the full no-pooling variance
6815    // and curves away from the true per-group line (gam#712 real arm, gam#713;
6816    // gam#903 sleepstudy forecast ran ~74% over the lme4 BLUP bar).
6817    //
6818    // mgcv's `bs="fs"` fixes this by penalizing each null-space dimension
6819    // SEPARATELY (`smooth.construct.fs.smooth.spec` adds one rank-1 penalty per
6820    // null coordinate), each replicated block-diagonally across levels under a
6821    // single shared smoothing parameter — so REML fits a distinct
6822    // random-intercept variance and random-slope variance, the partial pooling
6823    // that makes the forecast track lme4's correlated random-effect BLUP. A
6824    // single *combined* null penalty (one λ for intercept+slope together) cannot
6825    // express the typically very different intercept and slope variances, which
6826    // is the residual forecast gap. We mirror mgcv exactly: for each orthonormal
6827    // null direction `z_k` of the marginal wiggliness penalty add
6828    // `I_L ⊗ (z_k z_kᵀ)` as its own penalty. The marginal's combined double
6829    // penalty was disabled above, so the null space is penalized once, per
6830    // dimension. With linear data REML drives the curvature λ up and degrades
6831    // `fs` to a linear random slope (edf → ≈2/group); with genuine curvature the
6832    // wiggliness λ stays small and the wiggle survives (data-adaptive, not a
6833    // cap). Gated by `m_null_penalty_orders`: order ≥ 1 (default) enables the
6834    // per-dimension null penalties; `m=0` keeps the legacy combined double
6835    // penalty and adds nothing here.
6836    if use_per_dim_null
6837        && let Some(Some(z)) = inner.null_eigenvectors.first()
6838        && z.nrows() == p
6839    {
6840        for k in 0..z.ncols() {
6841            // Rank-1 marginal penalty `z_k z_kᵀ`, replicated block-diagonally
6842            // across levels into `I_L ⊗ (z_k z_kᵀ)`. Its own λ is one shared
6843            // variance for this null component (intercept or slope) across all
6844            // groups — the random-effect structure of mgcv `fs`.
6845            let zk = z.column(k);
6846            let mut p_k = Array2::<f64>::zeros((p, p));
6847            for a in 0..p {
6848                for b in 0..p {
6849                    p_k[[a, b]] = zk[a] * zk[b];
6850                }
6851            }
6852            let mut s_null = Array2::<f64>::zeros((q, q));
6853            for level in 0..n_levels {
6854                let start = level * p;
6855                s_null
6856                    .slice_mut(s![start..start + p, start..start + p])
6857                    .assign(&p_k);
6858            }
6859            let (s_null, null_scale) = normalize_penalty_in_constrained_space(&s_null);
6860            let null_block = crate::basis::analyze_penalty_block_with_op(&s_null, None)?;
6861            if null_block.rank > 0 {
6862                let original_index = penalties.len();
6863                penalties.push(null_block.sym_penalty);
6864                nullspaces.push(null_block.nullity);
6865                penaltyinfo.push(PenaltyInfo {
6866                    source: PenaltySource::Primary,
6867                    original_index,
6868                    active: true,
6869                    effective_rank: null_block.rank,
6870                    dropped_reason: None,
6871                    nullspace_dim_hint: null_block.nullity,
6872                    normalization_scale: null_scale,
6873                    kronecker_factors: None,
6874                });
6875            }
6876        }
6877    }
6878    let null_eigenvectors = crate::basis::recompute_null_eigenvectors(&penalties)?;
6879    let joint_null_rotation = crate::basis::compute_joint_null_rotation(&penalties)?;
6880
6881    // Metadata: carry the marginal knot geometry + frozen levels so prediction
6882    // reconstructs an identical replicated design.
6883    let (knots, degree, periodic) = match &inner.metadata {
6884        BasisMetadata::BSpline1D {
6885            knots,
6886            periodic,
6887            degree,
6888            ..
6889        } => (
6890            knots.clone(),
6891            degree.unwrap_or(spec.marginal.degree),
6892            *periodic,
6893        ),
6894        other => {
6895            crate::bail_invalid_basis!(
6896                "factor smooth term '{}' produced an unexpected marginal metadata variant {:?}",
6897                term_name,
6898                other
6899            );
6900        }
6901    };
6902    let flavour_tag = match &spec.flavour {
6903        FactorSmoothFlavour::Fs { .. } => "fs",
6904        FactorSmoothFlavour::Sz => "sz",
6905        FactorSmoothFlavour::Re => "re",
6906    }
6907    .to_string();
6908    let metadata = BasisMetadata::FactorSmooth {
6909        continuous_cols: spec.continuous_cols.clone(),
6910        group_col,
6911        knots,
6912        degree,
6913        periodic,
6914        group_levels: levels,
6915        flavour: flavour_tag,
6916        // fs/re marginals are always B-spline; the cr marginal is sz-only and
6917        // handled on the dedicated Sz path above.
6918        marginal_is_cr: false,
6919    };
6920
6921    let ops = vec![None; penalties.len()];
6922    Ok(LocalSmoothTermBuild {
6923        dim: q,
6924        design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense)),
6925        penalties,
6926        ops,
6927        nullspaces,
6928        null_eigenvectors,
6929        joint_null_rotation,
6930        penaltyinfo,
6931        pre_dropped_penaltyinfo: Vec::new(),
6932        metadata,
6933        linear_constraints: None,
6934        box_reparam: false,
6935        kronecker_factored: None,
6936    })
6937}
6938
6939/// Resolve the grouping levels for a factor smooth: replay the frozen level
6940/// list when present (predict path), otherwise discover the sorted unique bit
6941/// patterns of the factor column (fit path).
6942pub fn resolve_factor_smooth_levels(
6943    data: ArrayView2<'_, f64>,
6944    group_col: usize,
6945    spec: &FactorSmoothSpec,
6946    term_name: &str,
6947) -> Result<Vec<u64>, BasisError> {
6948    if let Some(frozen) = &spec.group_frozen_levels {
6949        if frozen.is_empty() {
6950            crate::bail_invalid_basis!(
6951                "factor smooth term '{}' has an empty frozen level list",
6952                term_name
6953            );
6954        }
6955        return Ok(frozen.clone());
6956    }
6957    let mut bits: Vec<u64> = data.column(group_col).iter().map(|v| v.to_bits()).collect();
6958    bits.sort_by(|a, b| {
6959        f64::from_bits(*a)
6960            .partial_cmp(&f64::from_bits(*b))
6961            .unwrap_or(std::cmp::Ordering::Equal)
6962    });
6963    bits.dedup();
6964    Ok(bits)
6965}
6966
6967/// Marginal B-spline spec for a factor-smooth block. The marginal always builds
6968/// without an identifiability constraint (the per-level replication, not a
6969/// sum-to-zero side constraint, provides identifiability against the parametric
6970/// block). At predict time the marginal's knot geometry has already been pinned
6971/// into `marginal.knotspec` by the metadata replay, so the spec is used
6972/// verbatim aside from clearing the identifiability transform.
6973pub fn factor_smooth_marginal_for_replay(marginal: &BSplineBasisSpec) -> BSplineBasisSpec {
6974    let mut m = marginal.clone();
6975    m.identifiability = BSplineIdentifiability::None;
6976    m
6977}
6978
6979pub fn build_single_local_smooth_term(
6980    data: ArrayView2<'_, f64>,
6981    term: &SmoothTermSpec,
6982    workspace: &mut crate::basis::BasisWorkspace,
6983) -> Result<LocalSmoothTermBuild, BasisError> {
6984    if term.shape != ShapeConstraint::None && !shape_supports_basis(term) {
6985        crate::bail_invalid_basis!(
6986            "ShapeConstraint::{:?} is unsupported for term '{}'",
6987            term.shape,
6988            term.name
6989        );
6990    }
6991    if let SmoothBasisSpec::ByVariable {
6992        inner,
6993        by_col,
6994        kind,
6995        by,
6996    } = &term.basis
6997    {
6998        ensure_by_variable_specs_match(kind, by, &term.name)?;
6999        let mut inner_basis = (**inner).clone();
7000        // Factor-level `by=` owns model-space centering (it centers the gated
7001        // block against the level indicator downstream). Defer the inner
7002        // basis's default pooled centering so the level block is not
7003        // double-centered down to `k-2` columns (#1427). Numeric-by smooths are
7004        // untouched: they are not row-gated to a level and keep ordinary
7005        // intercept centering.
7006        if matches!(by, ByVariableSpec::Level { .. }) {
7007            defer_inner_model_centering_to_factor_level_wrapper(&mut inner_basis);
7008        }
7009        let inner_term = SmoothTermSpec {
7010            name: term.name.clone(),
7011            basis: inner_basis,
7012            shape: term.shape,
7013            joint_null_rotation: None,
7014        };
7015        let built = build_single_local_smooth_term(data, &inner_term, workspace)?;
7016        return apply_by_variable_to_local_build(built, data, *by_col, by, &term.name);
7017    }
7018
7019    // BySmooth: a `by=` smooth that unifies numeric or factor modulation into a
7020    // single term.  Lower it here so the downstream match does not need an arm.
7021    if let SmoothBasisSpec::BySmooth { smooth, by_kind } = &term.basis {
7022        return build_by_smooth_local(data, term, smooth, by_kind, workspace);
7023    }
7024
7025    let mut shape_axis_col: Option<usize> = None;
7026    let mut built: BasisBuildResult = match &term.basis {
7027        SmoothBasisSpec::FactorSumToZero {
7028            inner,
7029            by_col,
7030            levels,
7031            ..
7032        } => {
7033            if *by_col >= data.ncols() {
7034                crate::bail_dim_basis!(
7035                    "term '{}' by column {} out of bounds for {} columns",
7036                    term.name,
7037                    by_col,
7038                    data.ncols()
7039                );
7040            }
7041            if levels.len() < 2 {
7042                crate::bail_invalid_basis!(
7043                    "sum-to-zero factor smooth term '{}' requires at least two levels",
7044                    term.name
7045                );
7046            }
7047            if term.shape != ShapeConstraint::None {
7048                crate::bail_invalid_basis!(
7049                    "ShapeConstraint::{:?} is unsupported for sum-to-zero factor smooth term '{}'",
7050                    term.shape,
7051                    term.name
7052                );
7053            }
7054            let inner_term = SmoothTermSpec {
7055                name: format!("{}::inner", term.name),
7056                basis: (**inner).clone(),
7057                shape: ShapeConstraint::None,
7058                joint_null_rotation: None,
7059            };
7060            let mut inner_built = build_single_local_smooth_term(data, &inner_term, workspace)?;
7061            // Capture the marginal penalty's null directions BEFORE the penalty
7062            // vector is rebuilt below; the sum-to-zero null-space ridge replicates
7063            // these `z_k` into the contrast space (mgcv `bs="fs"` double-penalty).
7064            let inner_null_eigenvectors = inner_built.null_eigenvectors.clone();
7065            let base = inner_built
7066                .design
7067                .try_to_dense_by_chunks("sum-to-zero factor smooth")
7068                .map_err(BasisError::InvalidInput)?;
7069            let n = base.nrows();
7070            let p = base.ncols();
7071            let l_minus_one = levels.len() - 1;
7072            let mut dense = Array2::<f64>::zeros((n, p * l_minus_one));
7073            for i in 0..n {
7074                let bits = data[[i, *by_col]].to_bits();
7075                let level_idx = levels.iter().position(|b| *b == bits).ok_or_else(|| {
7076                    BasisError::InvalidInput(format!(
7077                        "sum-to-zero factor smooth term '{}' saw an unseen level at row {}",
7078                        term.name,
7079                        i + 1
7080                    ))
7081                })?;
7082                if level_idx < l_minus_one {
7083                    let start = level_idx * p;
7084                    dense
7085                        .slice_mut(s![i, start..start + p])
7086                        .assign(&base.row(i));
7087                } else {
7088                    for level in 0..l_minus_one {
7089                        let start = level * p;
7090                        dense
7091                            .slice_mut(s![i, start..start + p])
7092                            .assign(&base.row(i).mapv(|v| -v));
7093                    }
7094                }
7095            }
7096            let mut penalties = Vec::<Array2<f64>>::with_capacity(inner_built.penalties.len());
7097            let active_penalty_indices = inner_built
7098                .penaltyinfo
7099                .iter()
7100                .enumerate()
7101                .filter_map(|(idx, info)| info.active.then_some(idx))
7102                .collect::<Vec<_>>();
7103            if active_penalty_indices.len() != inner_built.penalties.len() {
7104                crate::bail_invalid_basis!(
7105                    "internal sz penalty metadata mismatch: activeinfos={}, penalties={}",
7106                    active_penalty_indices.len(),
7107                    inner_built.penalties.len()
7108                );
7109            }
7110            // Replicate each marginal penalty into the sum-to-zero contrast
7111            // space. With `L-1` free deviation blocks and the reference level
7112            // `d_L = -Σ_{k<L} d_k`, the marginal penalty summed over ALL `L`
7113            // levels, `Σ_{k=1}^{L} d_kᵀ S d_k`, expands to the `(I + 11ᵀ) ⊗ S`
7114            // contrast form (factor 2 on the diagonal blocks, 1 off-diagonal).
7115            //
7116            // PER-GROUP SMOOTHING PARAMETERS (#1074). mgcv's `bs="sz"` does NOT
7117            // pool that sum under one λ: `smooth.construct.sz` emits ONE penalty
7118            // matrix per factor level (here 6 separate `S`s, each with its own
7119            // smoothing parameter), so REML can shrink a low-amplitude group's
7120            // deviation curve hard while leaving a high-amplitude group nearly
7121            // unpenalized. A single shared wiggliness λ (the old construction)
7122            // forces every group to the SAME curvature budget, so a group whose
7123            // true curve is flat drags curvature into the noise of the busy
7124            // groups and vice-versa — systematic truth-recovery loss even when
7125            // the pooled total edf matches mgcv's (the observed `sz` 1.23× gap).
7126            //
7127            // We mirror mgcv exactly by splitting the per-marginal penalty
7128            // `Σ_{k=1}^{L} d_kᵀ S d_k` back into its `L` independent
7129            // rank-controlled summands BEFORE mapping to the contrast space, each
7130            // carrying its own λ:
7131            //   * level k < L (free block):  `d_kᵀ S d_k` → block-diagonal
7132            //     `(e_k e_kᵀ) ⊗ S`  (only the (k,k) block is `S`).
7133            //   * level L (reference):       `d_Lᵀ S d_L = (Σ_{j<L} d_j)ᵀ S (·)`
7134            //     → the fully-coupled `(11ᵀ) ⊗ S` block.
7135            // Summed at equal λ these `L` blocks recover the old `(I + 11ᵀ) ⊗ S`
7136            // exactly (`Σ_k e_k e_kᵀ = I`), so this is a strict generalization:
7137            // the pooled fit is still reachable, REML only GAINS the freedom to
7138            // spend curvature per group. The zero-sum reparameterization (hence
7139            // the `sz` vs `fs` identifiability) is untouched.
7140            //
7141            // `which_level ∈ 0..=l_minus_one`: `< l_minus_one` selects the single
7142            // free deviation block; `== l_minus_one` selects the reference-level
7143            // coupling block.
7144            let stz_per_group_penalty = |s_inner: &Array2<f64>, which_level: usize| -> Array2<f64> {
7145                let mut s_big = Array2::<f64>::zeros((p * l_minus_one, p * l_minus_one));
7146                if which_level < l_minus_one {
7147                    // (e_k e_kᵀ) ⊗ S: a single diagonal block.
7148                    let k = which_level;
7149                    let mut block = s_big.slice_mut(s![k * p..(k + 1) * p, k * p..(k + 1) * p]);
7150                    block.assign(s_inner);
7151                } else {
7152                    // (11ᵀ) ⊗ S: every block (diagonal and off-diagonal) is S.
7153                    for a in 0..l_minus_one {
7154                        for b in 0..l_minus_one {
7155                            let mut block =
7156                                s_big.slice_mut(s![a * p..(a + 1) * p, b * p..(b + 1) * p]);
7157                            block.assign(s_inner);
7158                        }
7159                    }
7160                }
7161                s_big
7162            };
7163            // One nullspace-dim entry per emitted penalty (must stay parallel to
7164            // `penalties`). Each per-group wiggliness block carries the marginal's
7165            // OWN nullity (a rank-`p` block touching a single level for the free
7166            // blocks; the coupling block is rank-`p` over the diagonal sum), and
7167            // the null ridges below record their own nullity.
7168            let mut nullspaces = Vec::<usize>::with_capacity(penalties.capacity());
7169            for (penalty_pos, s_inner) in inner_built.penalties.iter().enumerate() {
7170                let info_idx = active_penalty_indices[penalty_pos];
7171                let base_info = inner_built.penaltyinfo[info_idx].clone();
7172                let marginal_nullity = inner_built.nullspaces.get(penalty_pos).copied().unwrap_or(0);
7173                // Emit `L` independent per-level blocks for this marginal penalty.
7174                for which_level in 0..=l_minus_one {
7175                    let raw = stz_per_group_penalty(s_inner, which_level);
7176                    let (s_big, group_scale) = normalize_penalty_in_constrained_space(&raw);
7177                    let block = crate::basis::analyze_penalty_block_with_op(&s_big, None)?;
7178                    if block.rank == 0 {
7179                        continue;
7180                    }
7181                    if which_level == 0 {
7182                        // Reuse the marginal's own info slot for the first block so
7183                        // the existing normalization bookkeeping stays attached.
7184                        inner_built.penaltyinfo[info_idx].normalization_scale *= group_scale;
7185                        inner_built.penaltyinfo[info_idx].original_index = penalties.len();
7186                        inner_built.penaltyinfo[info_idx].effective_rank = block.rank;
7187                        inner_built.penaltyinfo[info_idx].nullspace_dim_hint = block.nullity;
7188                    } else {
7189                        let mut info = base_info.clone();
7190                        info.original_index = penalties.len();
7191                        info.normalization_scale = base_info.normalization_scale * group_scale;
7192                        info.effective_rank = block.rank;
7193                        info.nullspace_dim_hint = block.nullity;
7194                        info.kronecker_factors = None;
7195                        inner_built.penaltyinfo.push(info);
7196                    }
7197                    penalties.push(block.sym_penalty);
7198                    // The coupling block (which_level == l_minus_one) spans the
7199                    // marginal range on the diagonal-sum direction; the free
7200                    // blocks touch one level. Both leave the marginal null space
7201                    // unpenalized, recorded here so the null ridges below complete
7202                    // the double penalty.
7203                    nullspaces.push(marginal_nullity);
7204                }
7205            }
7206
7207            // Null-space ridge, mirroring the `bs="fs"` double-penalty
7208            // construction (#1605, same defect class as #700/#712/#713). The
7209            // marginal wiggliness penalty `S` shapes curvature but leaves the
7210            // {const, linear} null space of each deviation curve COMPLETELY
7211            // unpenalized. With that null space free, the single combined
7212            // wiggliness smoothing parameter cannot separate the per-group
7213            // intercept/slope variance from the curvature variance, so REML
7214            // parks the wiggliness `λ` high — over-smoothing (under-fitting) the
7215            // deviation blocks even when the truth lives in their span (the `sz`
7216            // recovery gap vs the `fs` superset). mgcv's `bs="fs"` fixes the
7217            // analogous gap by penalizing each null-space dimension SEPARATELY
7218            // under its own shared variance; we mirror that here while keeping
7219            // the zero-sum reparameterization, so the constraint (and the
7220            // identifiability of `sz` vs `fs`) is preserved. For each orthonormal
7221            // null direction `z_k` of the marginal penalty, add the rank-1
7222            // marginal penalty `z_k z_kᵀ` mapped into the SAME `(I + 11ᵀ)`
7223            // sum-to-zero contrast space, each carrying its own `λ`.
7224            if let Some(Some(z)) = inner_null_eigenvectors.first()
7225                && z.nrows() == p
7226            {
7227                for k in 0..z.ncols() {
7228                    let zk = z.column(k);
7229                    let mut p_k = Array2::<f64>::zeros((p, p));
7230                    for a in 0..p {
7231                        for b in 0..p {
7232                            p_k[[a, b]] = zk[a] * zk[b];
7233                        }
7234                    }
7235                    // Null ridges stay POOLED (the `(I + 11ᵀ) ⊗ z_k z_kᵀ` form):
7236                    // they govern the per-group intercept/slope shrinkage, which
7237                    // mgcv pools under one variance even for `sz`; only the
7238                    // curvature (wiggliness) penalty is split per group above.
7239                    let stz_pooled_null = {
7240                        let mut s_big = Array2::<f64>::zeros((p * l_minus_one, p * l_minus_one));
7241                        for a in 0..l_minus_one {
7242                            for b in 0..l_minus_one {
7243                                let factor = if a == b { 2.0 } else { 1.0 };
7244                                let mut block =
7245                                    s_big.slice_mut(s![a * p..(a + 1) * p, b * p..(b + 1) * p]);
7246                                block.assign(&p_k.mapv(|v| v * factor));
7247                            }
7248                        }
7249                        s_big
7250                    };
7251                    let (s_null, null_scale) =
7252                        normalize_penalty_in_constrained_space(&stz_pooled_null);
7253                    let null_block = crate::basis::analyze_penalty_block_with_op(&s_null, None)?;
7254                    if null_block.rank > 0 {
7255                        let original_index = penalties.len();
7256                        penalties.push(null_block.sym_penalty);
7257                        nullspaces.push(null_block.nullity);
7258                        inner_built.penaltyinfo.push(PenaltyInfo {
7259                            source: PenaltySource::Primary,
7260                            original_index,
7261                            active: true,
7262                            effective_rank: null_block.rank,
7263                            dropped_reason: None,
7264                            nullspace_dim_hint: null_block.nullity,
7265                            normalization_scale: null_scale,
7266                            kronecker_factors: None,
7267                        });
7268                    }
7269                }
7270            }
7271            inner_built.dim = p * l_minus_one;
7272            inner_built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense));
7273            inner_built.penalties = penalties;
7274            inner_built.ops = vec![None; inner_built.penalties.len()];
7275            inner_built.nullspaces = nullspaces;
7276            // Invariant: `null_eigenvectors[k]` must mirror `penalties[k]`'s
7277            // spectral null space. We just rebuilt `inner_built.penalties` from
7278            // Kronecker-like `S_big` blocks, so the previously-plumbed
7279            // `null_eigenvectors` (still parallel to the OLD per-level penalty)
7280            // is stale. Recompute from the rebuilt penalties to restore the
7281            // invariant; ditto for the joint-null absorption rotation.
7282            inner_built.null_eigenvectors =
7283                crate::basis::recompute_null_eigenvectors(&inner_built.penalties)?;
7284            inner_built.joint_null_rotation =
7285                crate::basis::compute_joint_null_rotation(&inner_built.penalties)?;
7286            inner_built.kronecker_factored = None;
7287            return Ok(inner_built);
7288        }
7289        SmoothBasisSpec::BSpline1D { feature_col, spec } => {
7290            if *feature_col >= data.ncols() {
7291                crate::bail_dim_basis!(
7292                    "term '{}' feature column {} out of bounds for {} columns",
7293                    term.name,
7294                    feature_col,
7295                    data.ncols()
7296                );
7297            }
7298            let mut spec_local = spec.clone();
7299            if term.shape != ShapeConstraint::None {
7300                // Shape-constrained B-splines are anchored by construction.
7301                // Sum-to-zero side constraints conflict with monotonic/convex cones.
7302                spec_local.identifiability = BSplineIdentifiability::None;
7303            }
7304            // Endpoint boundary conditions are structural for B-splines: the
7305            // basis builder bakes their homogeneous nullspace transform into
7306            // the design, penalties, and stored raw-basis transform.
7307            build_bspline_basis_1d(data.column(*feature_col), &spec_local)?
7308        }
7309        SmoothBasisSpec::ThinPlate {
7310            feature_cols,
7311            spec,
7312            input_scales,
7313        } => {
7314            if term.shape != ShapeConstraint::None {
7315                if feature_cols.len() != 1 {
7316                    crate::bail_invalid_basis!(
7317                        "ShapeConstraint::{:?} for term '{}' on ThinPlate basis requires exactly 1 feature axis; found {}",
7318                        term.shape,
7319                        term.name,
7320                        feature_cols.len()
7321                    );
7322                }
7323                shape_axis_col = Some(feature_cols[0]);
7324            }
7325            let mut x = select_columns(data, feature_cols)?;
7326            // Auto-standardize multivariate inputs: use stored scales (prediction)
7327            // or compute fresh ones (training). Same standardization-vs-
7328            // length-scale compensation as Matérn / hybrid Duchon: divide
7329            // the user's L by σ_geom so kernel(‖x_std − c_std‖/L_eff)
7330            // matches the original-coord kernel for uniform σ.
7331            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7332                apply_input_standardization(&mut x, s);
7333                (
7334                    Some(s.clone()),
7335                    compensate_length_scale_for_standardization(spec.length_scale, s),
7336                )
7337            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7338                apply_input_standardization(&mut x, &s);
7339                let l_eff = compensate_length_scale_for_standardization(spec.length_scale, &s);
7340                (Some(s), l_eff)
7341            } else {
7342                (None, spec.length_scale)
7343            };
7344            let mut spec_local = spec.clone();
7345            spec_local.length_scale = length_scale_eff;
7346            if matches!(
7347                spec_local.identifiability,
7348                SpatialIdentifiability::OrthogonalToParametric
7349            ) {
7350                spec_local.identifiability = SpatialIdentifiability::None;
7351            }
7352            let mut result = build_thin_plate_basis(x.view(), &spec_local).map_err(|err| {
7353                rewrite_thin_plate_knots_error(err, &term.name, feature_cols.len(), spec)
7354            })?;
7355            // Inject input scales into metadata; also restore the user's
7356            // original length_scale (not the σ_geom-compensated one) so a
7357            // metadata-driven rebuild that re-applies compensation does not
7358            // double-divide. The build may auto-promote to Duchon when
7359            // canonical TPS is infeasible (k < polynomial-nullspace size);
7360            // in that case patch the Duchon metadata variant so predict-time
7361            // round-trips through the same standardized data path.
7362            match &mut result.metadata {
7363                BasisMetadata::ThinPlate {
7364                    input_scales: ms,
7365                    length_scale,
7366                    ..
7367                } => {
7368                    *ms = scales;
7369                    *length_scale = spec.length_scale;
7370                }
7371                BasisMetadata::Duchon {
7372                    input_scales: ms,
7373                    length_scale,
7374                    ..
7375                } => {
7376                    // Auto-promotion (canonical TPS infeasible at this (d, k)).
7377                    // Since #1091 the promotion does NOT forward the incoming
7378                    // σ_geom-compensated `spec_local.length_scale` to the Duchon
7379                    // builder — it DISCARDS it and substitutes the geometric mean
7380                    // of the center pairwise distances (`promotion_length_scale`,
7381                    // the natural radial-kernel scale where κ·r ≈ O(1)). So the
7382                    // realized kernel bandwidth recorded in this metadata bears no
7383                    // fixed relation to the user's `spec.length_scale`; clobbering
7384                    // it to the user-facing value (the pre-#1091 behavior) makes
7385                    // freeze→replay re-derive `compensate(spec.length_scale, σ) ≠
7386                    // promotion_length_scale`, evaluating the kernel at the wrong
7387                    // bandwidth and corrupting the replayed design (#1091 broke the
7388                    // e7ff5ed83 freeze contract for the auto-promoted path).
7389                    //
7390                    // The freeze→replay round trip rebuilds through the Duchon arm,
7391                    // which re-applies σ_geom compensation:
7392                    //   replay_eff = compensate(metadata.length_scale, σ)
7393                    //              = metadata.length_scale / σ_geom.
7394                    // For replay_eff to reproduce the realized `promotion_length_scale`
7395                    // we must store the UN-compensated value `promotion_length_scale
7396                    // · σ_geom`. `compensate(1.0, σ) = 1/σ_geom`, so divide the
7397                    // realized scale by it to multiply back through σ_geom. With no
7398                    // standardization (`scales == None`) replay does not compensate,
7399                    // so the realized value is kept verbatim.
7400                    if let (Some(s), Some(realized)) = (scales.as_ref(), *length_scale) {
7401                        let inv_sigma_geom =
7402                            compensate_length_scale_for_standardization(1.0, s);
7403                        if inv_sigma_geom.is_finite() && inv_sigma_geom > 0.0 {
7404                            *length_scale = Some(realized / inv_sigma_geom);
7405                        }
7406                    }
7407                    *ms = scales;
7408                }
7409                _ => {}
7410            }
7411            result
7412        }
7413        SmoothBasisSpec::Sphere { feature_cols, spec } => {
7414            if term.shape != ShapeConstraint::None {
7415                crate::bail_invalid_basis!(
7416                    "ShapeConstraint::{:?} for term '{}' is not supported on spherical splines",
7417                    term.shape,
7418                    term.name
7419                );
7420            }
7421            let x = select_columns(data, feature_cols)?;
7422            build_spherical_spline_basis(x.view(), spec)?
7423        }
7424        SmoothBasisSpec::ConstantCurvature { feature_cols, spec } => {
7425            if term.shape != ShapeConstraint::None {
7426                crate::bail_invalid_basis!(
7427                    "ShapeConstraint::{:?} for term '{}' is not supported on constant-curvature smooths",
7428                    term.shape,
7429                    term.name
7430                );
7431            }
7432            // Chart coordinates are consumed verbatim: NO auto-standardization.
7433            // Rescaling axes would change the chart gauge `1 + κ‖x‖²` and
7434            // silently redefine which curvature κ refers to (the same point
7435            // cloud at a different chart scale has a different κ̂); the user's
7436            // coordinates ARE the geometry here, exactly as for the sphere
7437            // smooth's (lat, lon).
7438            let x = select_columns(data, feature_cols)?;
7439            build_constant_curvature_basis(x.view(), spec)?
7440        }
7441        SmoothBasisSpec::MeasureJet {
7442            feature_cols,
7443            spec,
7444            input_scales,
7445        } => {
7446            if term.shape != ShapeConstraint::None {
7447                crate::bail_invalid_basis!(
7448                    "ShapeConstraint::{:?} for term '{}' is not supported on measure-jet smooths",
7449                    term.shape,
7450                    term.name
7451                );
7452            }
7453            let mut x = select_columns(data, feature_cols)?;
7454            // Matern-style per-axis standardization; the realized σ vector is
7455            // persisted into the metadata for predict-time replay.
7456            //
7457            // Length-scale round-trip contract (owning statement; the freeze
7458            // and frozen-validation arms reference it): `input_scales: Some`
7459            // marks the REPLAY path — the frozen length_scale is already the
7460            // realized post-standardization value and passes through
7461            // verbatim. Fresh path: an explicit user length_scale is in
7462            // ORIGINAL coordinates and gets the σ_geom compensation; the 0.0
7463            // auto sentinel passes through (auto-derivation runs inside the
7464            // builder, post-standardization).
7465            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7466                apply_input_standardization(&mut x, s);
7467                (Some(s.clone()), spec.length_scale)
7468            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7469                apply_input_standardization(&mut x, &s);
7470                let l_eff = if spec.length_scale > 0.0 {
7471                    compensate_length_scale_for_standardization(spec.length_scale, &s)
7472                } else {
7473                    spec.length_scale
7474                };
7475                (Some(s), l_eff)
7476            } else {
7477                (None, spec.length_scale)
7478            };
7479            let mut spec_local = spec.clone();
7480            spec_local.length_scale = length_scale_eff;
7481            let mut result = build_measure_jet_basis(x.view(), &spec_local)?;
7482            if let BasisMetadata::MeasureJet {
7483                input_scales: ms, ..
7484            } = &mut result.metadata
7485            {
7486                *ms = scales;
7487            }
7488            result
7489        }
7490        SmoothBasisSpec::Matern {
7491            feature_cols,
7492            spec,
7493            input_scales,
7494        } => {
7495            if term.shape != ShapeConstraint::None {
7496                if feature_cols.len() != 1 {
7497                    crate::bail_invalid_basis!(
7498                        "ShapeConstraint::{:?} for term '{}' on Matern basis requires exactly 1 feature axis; found {}",
7499                        term.shape,
7500                        term.name,
7501                        feature_cols.len()
7502                    );
7503                }
7504                shape_axis_col = Some(feature_cols[0]);
7505            }
7506            let mut x = select_columns(data, feature_cols)?;
7507            // Auto-standardization (per-axis division by σ_a) reinterprets
7508            // the user's `length_scale` from original data coordinates
7509            // into post-standardization coordinates: for uniform σ_a = σ,
7510            // `kernel(‖x_std − c_std‖/L)` equals `kernel(‖x − c‖/(σ·L))`,
7511            // so the effective kernel range shrinks by σ. To keep
7512            // `length_scale` consistently expressed in *original* data
7513            // coordinates regardless of axis variances, we standardize
7514            // and divide L by σ_geom = (∏σ_a)^(1/d). For uniform σ this
7515            // recovers the user's kernel exactly; for anisotropic data
7516            // the resulting per-axis effective scales σ_a / σ_geom are
7517            // the standard Mahalanobis preconditioning and preserve the
7518            // geometric-mean kernel range. Storing the σ vector in
7519            // metadata.input_scales makes the same transformation
7520            // replayable at predict time.
7521            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7522                apply_input_standardization(&mut x, s);
7523                (
7524                    Some(s.clone()),
7525                    compensate_length_scale_for_standardization(spec.length_scale, s),
7526                )
7527            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7528                apply_input_standardization(&mut x, &s);
7529                let l_eff = compensate_length_scale_for_standardization(spec.length_scale, &s);
7530                (Some(s), l_eff)
7531            } else {
7532                (None, spec.length_scale)
7533            };
7534            let mut spec_local = spec.clone();
7535            spec_local.length_scale = length_scale_eff;
7536            let mut result = build_matern_basiswithworkspace(x.view(), &spec_local, workspace)?;
7537            if let BasisMetadata::Matern {
7538                input_scales,
7539                length_scale,
7540                ..
7541            } = &mut result.metadata
7542            {
7543                *input_scales = scales;
7544                *length_scale = spec.length_scale;
7545            }
7546            result
7547        }
7548        SmoothBasisSpec::Duchon {
7549            feature_cols,
7550            spec,
7551            input_scales,
7552        } => {
7553            if term.shape != ShapeConstraint::None {
7554                if feature_cols.len() != 1 {
7555                    crate::bail_invalid_basis!(
7556                        "ShapeConstraint::{:?} for term '{}' on Duchon basis requires exactly 1 feature axis; found {}",
7557                        term.shape,
7558                        term.name,
7559                        feature_cols.len()
7560                    );
7561                }
7562                shape_axis_col = Some(feature_cols[0]);
7563            }
7564            let mut x = select_columns(data, feature_cols)?;
7565            // Hybrid Duchon (length_scale=Some) is governed by the same
7566            // standardization-vs-length-scale equivalence as Matérn: the
7567            // user's `length_scale` is interpreted in original data
7568            // coordinates, but auto-standardization (per-axis division by
7569            // σ_a) reinterprets it as σ_geom · L. Pre-multiply by 1/σ_geom
7570            // so kernel(‖x_std − c_std‖/L_eff) reproduces the user's
7571            // original-coord kernel exactly for uniform σ_a, and reduces
7572            // to standard Mahalanobis preconditioning for anisotropic σ.
7573            // Pure Duchon (length_scale=None) is scale-free and needs no
7574            // compensation.
7575            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7576                apply_input_standardization(&mut x, s);
7577                (
7578                    Some(s.clone()),
7579                    compensate_optional_length_scale_for_standardization(spec.length_scale, s),
7580                )
7581            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7582                apply_input_standardization(&mut x, &s);
7583                let l_eff =
7584                    compensate_optional_length_scale_for_standardization(spec.length_scale, &s);
7585                (Some(s), l_eff)
7586            } else {
7587                (None, spec.length_scale)
7588            };
7589            let mut spec_local = spec.clone();
7590            spec_local.length_scale = length_scale_eff;
7591            // The Duchon input axis is standardized in place above (`x → x/σ`,
7592            // scale-only, no centering). A 1-D cyclic boundary `[start, end)`
7593            // declared in ORIGINAL covariate units must move into that same
7594            // standardized frame, or the modular wrap in
7595            // `build_cyclic_duchon_basis_1dwithworkspace` folds the standardized
7596            // coordinate against an original-unit period: the seam never closes
7597            // and the basis silently degrades to non-periodic (#1074:
7598            // `duchon(x, periodic=true)` predictions diverged across the wrap,
7599            // f(0) ≠ f(2π)). Rescale by the same 1/σ applied to the data so
7600            // training and predict share one periodic geometry.
7601            if let (Some(s), crate::basis::OneDimensionalBoundary::Cyclic { start, end }) =
7602                (scales.as_ref(), spec_local.boundary.clone())
7603                && s.len() == 1
7604                && s[0] > 0.0
7605            {
7606                spec_local.boundary = crate::basis::OneDimensionalBoundary::Cyclic {
7607                    start: start / s[0],
7608                    end: end / s[0],
7609                };
7610            }
7611            if matches!(
7612                spec_local.identifiability,
7613                SpatialIdentifiability::OrthogonalToParametric
7614            ) {
7615                spec_local.identifiability = SpatialIdentifiability::None;
7616            }
7617            let mut result = build_duchon_basiswithworkspace(x.view(), &spec_local, workspace)?;
7618            if let BasisMetadata::Duchon {
7619                input_scales,
7620                length_scale,
7621                ..
7622            } = &mut result.metadata
7623            {
7624                *input_scales = scales;
7625                *length_scale = spec.length_scale;
7626            }
7627            result
7628        }
7629        SmoothBasisSpec::Pca {
7630            feature_cols,
7631            basis_matrix,
7632            centered,
7633            smooth_penalty,
7634            center_mean,
7635            pca_basis_path,
7636            chunk_size,
7637        } => {
7638            if term.shape != ShapeConstraint::None {
7639                crate::bail_invalid_basis!(
7640                    "ShapeConstraint::{:?} for term '{}' is not supported on Pca basis",
7641                    term.shape,
7642                    term.name
7643                );
7644            }
7645            build_pca_smooth_basis(
7646                data,
7647                feature_cols,
7648                basis_matrix,
7649                *centered,
7650                *smooth_penalty,
7651                center_mean.as_ref(),
7652                pca_basis_path.as_ref(),
7653                *chunk_size,
7654            )?
7655        }
7656        SmoothBasisSpec::TensorBSpline { feature_cols, spec } => {
7657            build_tensor_bspline_basis(data, feature_cols, spec)?
7658        }
7659        SmoothBasisSpec::ByVariable { .. } => {
7660            crate::bail_invalid_basis!(
7661                "internal: ByVariable smooths must return before inner basis dispatch"
7662            );
7663        }
7664        SmoothBasisSpec::BySmooth { .. } => {
7665            crate::bail_invalid_basis!("internal: BySmooth smooths must be lowered to ByVariable before inner basis dispatch"
7666                    .to_string(),);
7667        }
7668        SmoothBasisSpec::FactorSmooth { spec } => {
7669            if term.shape != ShapeConstraint::None {
7670                crate::bail_invalid_basis!(
7671                    "ShapeConstraint::{:?} is unsupported for factor smooth term '{}'",
7672                    term.shape,
7673                    term.name
7674                );
7675            }
7676            return build_factor_smooth(data, spec, &term.name, workspace);
7677        }
7678    };
7679
7680    // The Matérn design ALWAYS uses the operator-collocation {mass, tension,
7681    // stiffness} penalty triplet, overriding whatever penalty
7682    // `build_matern_basis_seeded` produced for the `double_penalty` flag.
7683    //
7684    // #1074 investigated swapping this for the genuine RKHS kernel penalty
7685    // `β' K_CC β` (mgcv `bs="gp"` / fields kriging) on the theory that the
7686    // operator triplet under-smooths the rougher half-integer kernels. MSI
7687    // truth-recovery measurement REFUTED that: the kernel penalty did NOT
7688    // improve ν=3/2 recovery (`matern(x,nu=1.5)` RMSE-vs-truth stayed 0.0554)
7689    // and it REGRESSED the high-frequency-init guard — `matern(x,nu≥5/2)` on
7690    // sin(2π·8·x) collapsed (span 0.53, RMSE 0.70) because the single RKHS
7691    // norm over-smooths a high-frequency truth where the Sobolev-order operator
7692    // dials do not. The operator triplet is therefore retained as the Matérn
7693    // penalty, and the κ-optimizer re-key / ψ-derivative paths route through the
7694    // same triplet builder so the block count stays ψ-stable (#1270).
7695    if let SmoothBasisSpec::Matern { .. } = &term.basis {
7696        let (penalties, nullspace_dims, penaltyinfo) =
7697            matern_operator_penalty_triplet_from_metadata(&built.metadata)?;
7698        built.penalties = penalties;
7699        built.nullspace_dims = nullspace_dims;
7700        built.penaltyinfo = penaltyinfo;
7701    }
7702
7703    let p_local = built.design.ncols();
7704    let mut metadata = built.metadata.clone();
7705    // Extract factored Kronecker representation before consuming fields.
7706    // Invalidate it if shape transforms will be applied (they break structure).
7707    let kron_factored = if term.shape == ShapeConstraint::None {
7708        built.kronecker_factored
7709    } else {
7710        None
7711    };
7712    let mut design_t = built.design;
7713    let mut penalties_t: Vec<Array2<f64>> = built.penalties;
7714    // Ops vector parallel to `penalties_t`. Survives unchanged through the
7715    // identity path; nulled element-wise when `T^T S T` reparametrization
7716    // is applied (operator no longer bit-equivalent to the transformed
7717    // matrix); wrapped in `ScaledPenaltyOp` after Frobenius normalization.
7718    let mut ops_t: Vec<Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>> =
7719        built.ops;
7720    if matches!(
7721        spatial_identifiability_policy(term),
7722        Some(SpatialIdentifiability::OrthogonalToParametric)
7723    ) {
7724        metadata = freeze_raw_spatial_metadata(metadata, design_t.ncols());
7725    }
7726
7727    let active_penaltyinfo_t = built
7728        .penaltyinfo
7729        .iter()
7730        .filter(|info| info.active)
7731        .cloned()
7732        .collect::<Vec<_>>();
7733    let pre_dropped_penaltyinfo_t = built
7734        .penaltyinfo
7735        .iter()
7736        .filter(|info| !info.active)
7737        .cloned()
7738        .collect::<Vec<_>>();
7739    let use_box_reparam =
7740        term.shape != ShapeConstraint::None && shape_uses_box_reparameterization(&term.basis);
7741    if let Some((order, sign)) = shape_order_and_sign(term.shape)
7742        && use_box_reparam
7743    {
7744        // Order 1 (monotone): the plain first-difference cone θ_{i+1}−θ_i ≥ 0 is
7745        // the control-polygon monotonicity criterion, which is independent of
7746        // Greville-abscissa spacing (it only fixes the *sign* of consecutive
7747        // control-point gaps), so the integer-difference transform is exact.
7748        //
7749        // Order 2 (convex/concave): the plain second-difference cone is only
7750        // correct for evenly spaced Greville abscissae. gam's B-splines are
7751        // clamped (and may use quantile knots), so the abscissae are not
7752        // uniform and the geometrically-correct cone is the second *divided*
7753        // difference. Build the Greville-scaled transform so γ_{≥2} ≥ 0
7754        // certifies convexity of the function, not of the raw coefficient
7755        // index. Periodic B-splines use uniform interior knots (uniform
7756        // abscissae), where the divided differences coincide with the integer
7757        // differences up to scale, so the plain path stays exact there.
7758        let t = if order == 2 {
7759            let bspline_meta = match &metadata {
7760                BasisMetadata::BSpline1D {
7761                    knots,
7762                    degree,
7763                    periodic,
7764                    ..
7765                } if periodic.is_none() => Some((knots.clone(), degree.unwrap_or(0))),
7766                _ => None,
7767            };
7768            match bspline_meta {
7769                Some((knots, degree)) if degree >= 1 => {
7770                    let greville = crate::basis::compute_greville_abscissae(&knots, degree)?;
7771                    if greville.len() != p_local {
7772                        crate::bail_invalid_basis!(
7773                            "shape-constraint Greville abscissae count {} does not match basis dim {} for term '{}'",
7774                            greville.len(),
7775                            p_local,
7776                            term.name
7777                        );
7778                    }
7779                    convex_divided_difference_transform_matrix(&greville, sign)?
7780                }
7781                _ => cumulative_sum_transform_matrix(p_local, order, sign),
7782            }
7783        } else {
7784            cumulative_sum_transform_matrix(p_local, order, sign)
7785        };
7786        // Coefficient-side transform: wrap the design in an operator that
7787        // applies T on the coefficient side, preserving sparsity/operator
7788        // structure of the inner design.
7789        let inner_dense = match design_t {
7790            DesignMatrix::Dense(d) => d,
7791            DesignMatrix::Sparse(sp) => gam_linalg::matrix::DenseDesignMatrix::from(
7792                sp.try_to_dense_arc("shape-constrained coefficient transform")
7793                    .map_err(BasisError::InvalidInput)?,
7794            ),
7795        };
7796        let coeff_op = gam_linalg::matrix::CoefficientTransformOperator::new(inner_dense, t.clone())
7797            .map_err(|e| BasisError::InvalidInput(format!("CoefficientTransformOperator: {e}")))?;
7798        design_t = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(coeff_op)));
7799        if penalties_t.len() != active_penaltyinfo_t.len() {
7800            crate::bail_invalid_basis!(
7801                "internal box-reparam penalty/info mismatch for term '{}': penalties={}, infos={}",
7802                term.name,
7803                penalties_t.len(),
7804                active_penaltyinfo_t.len()
7805            );
7806        }
7807        // Wiggliness penalties undergo the exact congruence `S → TᵀST` (PSD
7808        // preserving). The double-penalty *nullspace shrinkage* ridge must NOT:
7809        // it is a unit-eigenvalue projector `ZZᵀ` onto null(S_wiggle) in the
7810        // β (B-spline coefficient) coordinates, and the congruence
7811        // `Tᵀ(ZZᵀ)T = (TᵀZ)(TᵀZ)ᵀ` is no longer a projector — its eigenvalues
7812        // blow up by the conditioning of the cumulative-sum `T` (cond(T) grows
7813        // with the basis dim), concentrating an enormous penalty on the leading
7814        // γ₀ "level" coordinate. REML then drives the shared λ to its ceiling
7815        // and the smooth collapses to a flat constant (#509, the over-smoothing
7816        // face). The principled fix keeps mgcv's double-penalty semantics in the
7817        // *reparametrized* space: rebuild the ridge as the unit-eigenvalue
7818        // nullspace projector of the transformed wiggliness penalty `TᵀST`, so
7819        // the double penalty shrinks exactly the unpenalized polynomial
7820        // directions of the γ-space smooth with eigenvalue 1, identical in
7821        // conditioning to the unconstrained fit.
7822        let transformed_wiggliness = penalties_t
7823            .iter()
7824            .zip(active_penaltyinfo_t.iter())
7825            .find(|(_, info)| !matches!(info.source, PenaltySource::DoublePenaltyNullspace))
7826            .map(|(s_local, _)| {
7827                let tt_s = fast_atb(&t, s_local);
7828                fast_ab(&tt_s, &t)
7829            });
7830        let mut rebuilt = Vec::with_capacity(penalties_t.len());
7831        for (s_local, info) in penalties_t.iter().zip(active_penaltyinfo_t.iter()) {
7832            if matches!(info.source, PenaltySource::DoublePenaltyNullspace) {
7833                // #1654: the double-penalty nullspace ridge under the box
7834                // reparameterization.
7835                //
7836                // For the CURVATURE constraints (order == 2, convex/concave) the
7837                // box transform `T` is the Greville-scaled second *divided*
7838                // difference map (`convex_divided_difference_transform_matrix`).
7839                // Rebuilding the ridge from scratch as the orthonormal null-space
7840                // projector of `TᵀST` (the #509-era monotone fix below) yields a
7841                // γ-space ridge `Z_γ Z_γᵀ` whose null subspace is the affine
7842                // (level γ₀ + slope γ₁) face — the SAME subspace targeted in
7843                // β-space, but measured in the γ inner product rather than the
7844                // β one. A reparameterization `β = Tγ` must leave the penalized
7845                // REML fit invariant, which requires every penalty block to
7846                // transform by the SAME congruence `S ↦ TᵀST`; the from-scratch
7847                // projector rebuild is NOT that congruence, so it silently
7848                // re-weights the level/slope shrinkage relative to the wiggliness
7849                // penalty (each block is independently Frobenius-normalized just
7850                // below, decoupling their scales). The distorted REML λ landscape
7851                // then drives the convex/concave smooth into the flat linear
7852                // corner (curvature pinned ≈ 0, EDF ≈ 1.5) for a
7853                // seed/basis-dimension–specific subset of fits, even though an
7854                // unconstrained `s(x)` on the same data recovers the convex truth
7855                // at EDF ≈ 4. Restoring the exact congruence `Tᵀ R_β T` for the
7856                // ridge keeps the box reparameterization a true invertible
7857                // change of coordinates, so the curvature-constrained fit tracks
7858                // the unconstrained smoothing instead of over-smoothing
7859                // (verified: seed-7/k-20 truth-RMSE 0.31 → 0.045).
7860                //
7861                // For MONOTONE (order == 1) `T` is the cumulative-sum transform
7862                // whose conditioning grows fast with the basis dimension; there
7863                // the congruence concentrates an enormous penalty on the leading
7864                // γ₀ level coordinate and over-smooths to a flat constant (#509),
7865                // which the from-scratch unit-eigenvalue projector rebuild was
7866                // introduced to cure. Keep that path for monotone.
7867                if order == 2 {
7868                    let tt_s = fast_atb(&t, s_local);
7869                    rebuilt.push(fast_ab(&tt_s, &t));
7870                } else {
7871                    let s_wiggle_t = transformed_wiggliness.as_ref().ok_or_else(|| {
7872                        BasisError::InvalidInput(format!(
7873                            "box-reparam term '{}' has a double-penalty ridge but no primary wiggliness penalty to derive its nullspace from",
7874                            term.name
7875                        ))
7876                    })?;
7877                    let ridge = crate::basis::build_nullspace_shrinkage_penalty(s_wiggle_t)?
7878                        .map(|shrink| shrink.sym_penalty)
7879                        .unwrap_or_else(|| Array2::<f64>::zeros((p_local, p_local)));
7880                    rebuilt.push(ridge);
7881                }
7882            } else {
7883                let tt_s = fast_atb(&t, s_local);
7884                rebuilt.push(fast_ab(&tt_s, &t));
7885            }
7886        }
7887        penalties_t = rebuilt;
7888        // T^T S T (and the rebuilt γ-space ridge) invalidate op-form
7889        // bit-equivalence; drop ops here.
7890        ops_t = vec![None; penalties_t.len()];
7891    }
7892    if penalties_t.len() != active_penaltyinfo_t.len() {
7893        crate::bail_invalid_basis!(
7894            "internal penalty metadata mismatch for term '{}': active penalties={}, active infos={}",
7895            term.name,
7896            penalties_t.len(),
7897            active_penaltyinfo_t.len()
7898        );
7899    }
7900    if ops_t.len() != penalties_t.len() {
7901        ops_t = vec![None; penalties_t.len()];
7902    }
7903    let penalty_candidates = penalties_t
7904        .into_iter()
7905        .zip(active_penaltyinfo_t.into_iter())
7906        .zip(ops_t.into_iter())
7907        .map(
7908            |((matrix, info), op_in)| -> Result<PenaltyCandidate, BasisError> {
7909                let (matrix, c_new) = normalize_penalty_in_constrained_space(&matrix);
7910                let normalization_scale = info.normalization_scale * c_new;
7911                let op_scale = 1.0 / c_new;
7912                let kronecker_scale = 1.0 / c_new;
7913                // Frobenius rescale: wrap inner op in `ScaledPenaltyOp(1/c_new)`
7914                // so `op.as_dense() == matrix` post-normalization.
7915                let scaled_op = if op_scale > 0.0 && op_scale.is_finite() {
7916                    op_in.map(|op| {
7917                        std::sync::Arc::new(crate::analytic_penalties::ScaledPenaltyOp::new(
7918                            op, op_scale,
7919                        ))
7920                            as std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>
7921                    })
7922                } else {
7923                    None
7924                };
7925                let kronecker_factors = info.kronecker_factors.map(|mut factors| {
7926                    if let Some(first) = factors.first_mut() {
7927                        first.mapv_inplace(|v| v * kronecker_scale);
7928                    }
7929                    factors
7930                });
7931                Ok(PenaltyCandidate {
7932                    nullspace_dim_hint: info.nullspace_dim_hint,
7933                    matrix,
7934                    source: info.source,
7935                    normalization_scale,
7936                    kronecker_factors,
7937                    op: scaled_op,
7938                })
7939            },
7940        )
7941        .collect::<Result<Vec<_>, _>>()?;
7942    let (penalties_t, nullspaces_t, penaltyinfo_t, null_eigenvectors_t, ops_t) =
7943        crate::basis::filter_active_penalty_candidates_with_ops(penalty_candidates)?;
7944    let shape_linear_constraints = if term.shape != ShapeConstraint::None && !use_box_reparam {
7945        let axis = shape_axis_col.ok_or_else(|| {
7946            BasisError::InvalidInput(format!(
7947                "internal shape-constraint axis missing for term '{}'",
7948                term.name
7949            ))
7950        })?;
7951        let (x_shape_eval, design_shape_eval) =
7952            build_shape_constraint_design_1d(data, term, &metadata, axis)?;
7953        build_shape_linear_constraints_1d(
7954            x_shape_eval.view(),
7955            design_shape_eval.view(),
7956            term.shape,
7957        )?
7958    } else {
7959        None
7960    };
7961    let linear_constraints_local = merge_linear_constraints_global(shape_linear_constraints, None);
7962
7963    // Joint-null absorption rotation. Fresh fit specs compute Q from the final
7964    // per-smooth penalty set (after all in-smooth reparameterizations have
7965    // already been applied). Frozen specs already carry the complete realized
7966    // coefficient chart in their `FrozenTransform`; recomputing Q there would
7967    // rotate an already-frozen chart a second time and desynchronize value
7968    // rebuilds from derivative operators.
7969    //
7970    // Kronecker-factored smooths (tensor B-splines under `TensorBSplineIdentifiability::None`)
7971    // carry their joint penalty as `Σ_d S_d` with `S_d = I ⊗ … ⊗ S_d^{1D} ⊗ … ⊗ I`.
7972    // The joint null space is the tensor of marginal nulls and is handled directly
7973    // by the REML runtime's `kronecker_penalty_system` path (see
7974    // `runtime.rs:8334-8344`). Applying a dense (p × p) Q here would densify
7975    // `X_raw = mx ⊗ my` into `X_raw · Q`, destroying the Kronecker product
7976    // structure that the runtime relies on for fast log-det/derivative
7977    // assembly — and the rotation block at the wrapper site also unconditionally
7978    // wipes `kronecker_factored`, leaving the runtime to fall back to the
7979    // dense per-block log-det. Skip the rotation for Kronecker-factored terms
7980    // so the factored representation survives end-to-end.
7981    let joint_null_rotation = match term.joint_null_rotation.clone() {
7982        Some(persisted) => Some(persisted),
7983        None if smooth_has_frozen_identifiability(term) => None,
7984        None if kron_factored.is_some() => None,
7985        None => crate::basis::compute_joint_null_rotation(&penalties_t)?,
7986    };
7987
7988    Ok(LocalSmoothTermBuild {
7989        dim: p_local,
7990        design: design_t,
7991        penalties: penalties_t,
7992        ops: ops_t,
7993        nullspaces: nullspaces_t,
7994        null_eigenvectors: null_eigenvectors_t,
7995        joint_null_rotation,
7996        penaltyinfo: penaltyinfo_t,
7997        pre_dropped_penaltyinfo: pre_dropped_penaltyinfo_t,
7998        metadata,
7999        linear_constraints: linear_constraints_local,
8000        box_reparam: use_box_reparam,
8001        kronecker_factored: kron_factored,
8002    })
8003}
8004
8005pub fn build_smooth_design(
8006    data: ArrayView2<'_, f64>,
8007    terms: &[SmoothTermSpec],
8008) -> Result<RawSmoothDesign, BasisError> {
8009    let mut ws = crate::basis::BasisWorkspace::new();
8010    build_smooth_design_withworkspace(data, terms, &mut ws)
8011}
8012
8013/// Like `build_smooth_design`, but honors the caller workspace policy while
8014/// building each planned smooth term with an independent per-term workspace.
8015///
8016/// Independent workspaces avoid shared mutable distance-cache state during the
8017/// parallel term build; the final design, penalties, and metadata are assembled
8018/// in the original smooth-term order.
8019pub fn build_smooth_design_withworkspace(
8020    data: ArrayView2<'_, f64>,
8021    terms: &[SmoothTermSpec],
8022    workspace: &mut crate::basis::BasisWorkspace,
8023) -> Result<RawSmoothDesign, BasisError> {
8024    validate_smooth_terms_finite_inputs(data, terms)?;
8025    build_smooth_design_withworkspace_unvalidated(data, terms, workspace)
8026}
8027
8028pub fn build_smooth_design_withworkspace_unvalidated(
8029    data: ArrayView2<'_, f64>,
8030    terms: &[SmoothTermSpec],
8031    workspace: &mut crate::basis::BasisWorkspace,
8032) -> Result<RawSmoothDesign, BasisError> {
8033    let mut planned_blocks = plan_joint_spatial_centers_for_term_blocks(data, &[terms.to_vec()])?;
8034    let planned_terms = planned_blocks.pop().ok_or_else(|| {
8035        BasisError::InvalidInput(
8036            "joint spatial center planner returned no smooth blocks".to_string(),
8037        )
8038    })?;
8039    let policy = workspace.policy().clone();
8040    let local_builds: Vec<LocalSmoothTermBuild> = {
8041        use rayon::iter::{IntoParallelIterator, ParallelIterator};
8042        planned_terms
8043            .into_par_iter()
8044            .map(|term| {
8045                let mut term_workspace = crate::basis::BasisWorkspace::with_policy(policy.clone());
8046                build_single_local_smooth_term(data, &term, &mut term_workspace)
8047            })
8048            .collect::<Result<Vec<_>, _>>()?
8049    };
8050
8051    let total_p: usize = local_builds.iter().map(|built| built.dim).sum();
8052
8053    let mut local_designs: Vec<DesignMatrix> = Vec::with_capacity(local_builds.len());
8054    let mut terms_out = Vec::<SmoothTerm>::with_capacity(terms.len());
8055    let mut penalties_global = Vec::<BlockwisePenalty>::new();
8056    let mut nullspace_dims_global = Vec::<usize>::new();
8057    let mut penaltyinfo_global = Vec::<PenaltyBlockInfo>::new();
8058    let mut dropped_penaltyinfo_global = Vec::<DroppedPenaltyBlockInfo>::new();
8059    let mut coefficient_lower_bounds = Array1::<f64>::from_elem(total_p, f64::NEG_INFINITY);
8060    let mut any_bounds = false;
8061    // Each linear-constraint row only touches the current term's column slice.
8062    // Track `(col_start, col_end, local_row_values)` and assemble the final
8063    // dense `Array2` in one pass, avoiding per-row `Array1::zeros(total_p)`
8064    // allocation plus a row-by-row copy at the end.
8065    let mut linear_constraintsrows: Vec<(usize, usize, Array1<f64>)> = Vec::new();
8066    let mut linear_constraints_b: Vec<f64> = Vec::new();
8067
8068    let mut col_start = 0usize;
8069    for (term, mut built) in terms.iter().zip(local_builds.into_iter()) {
8070        let p_local = built.dim;
8071        let col_end = col_start + p_local;
8072        let lb_local = if built.box_reparam {
8073            shape_lower_bounds_local(term.shape, p_local)
8074        } else {
8075            None
8076        };
8077
8078        // Stage-2 joint-null absorption rotation. Fired *before* the
8079        // penalty / design / global aggregation loops below so that every
8080        // subsequent reference to `built.penalties`, `built.design`, and
8081        // `built.ops` sees the post-rotation values.
8082        //
8083        // The math: when the smooth's joint penalty `Σ_k S_k` has a
8084        // non-trivial null space, eigh selects `Q = [U_range | U_null]`
8085        // with null columns at the tail. Setting `β_raw = Q · γ` and
8086        // applying:
8087        //     design        ← X · Q
8088        //     penalties[k]  ← Qᵀ · S_k · Q   (block-diag, zero null tail)
8089        // yields a model whose fitted predictor is invariant to the rotation
8090        // (since likelihood depends only on `X · β_raw = X · Q · γ`), but
8091        // whose penalty is full-rank on the range columns. The large-scale
8092        // failing case (cert refusal in the joint-Newton inner solve)
8093        // resolves because `H_pen = H_loglik + S` becomes full rank on
8094        // the smooth's range columns.
8095        //
8096        // Rotation is suppressed when the smooth carries coordinate-wise
8097        // shape constraints (`lb_local` or `built.linear_constraints`):
8098        // those encode a cone in the original coordinate system and a
8099        // general orthogonal rotation breaks the cone geometry. Smooths
8100        // with shape constraints typically have full-rank joint penalty
8101        // (their structural shape comes from the cone, not from null
8102        // directions in the penalty), so suppression is rarely a loss.
8103        //
8104        // `applied_rotation` carries the Q that was applied (or `None`
8105        // if no rotation fired). It is persisted onto `SmoothTerm` below
8106        // so prediction-side `X_new_raw · Q` replay can reproduce the
8107        // exact rotation. Persistence through the saved-model artifact
8108        // is a follow-up — see the doc on `SmoothTerm.joint_null_rotation`.
8109        let applied_rotation: Option<crate::basis::JointNullRotation> = match (
8110            built.joint_null_rotation.take(),
8111            lb_local.is_some(),
8112            built.linear_constraints.is_some(),
8113        ) {
8114            (Some(rot), false, false) => {
8115                let q = &rot.rotation;
8116                let dense = built
8117                    .design
8118                    .try_to_dense_by_chunks("joint-null absorption rotation")
8119                    .map_err(BasisError::InvalidInput)?;
8120                let rotated = gam_linalg::faer_ndarray::fast_ab(&dense, q);
8121                built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(rotated));
8122                built.penalties = built
8123                    .penalties
8124                    .into_iter()
8125                    .map(|s_local| {
8126                        let qt_s = gam_linalg::faer_ndarray::fast_atb(q, &s_local);
8127                        gam_linalg::faer_ndarray::fast_ab(&qt_s, q)
8128                    })
8129                    .collect();
8130                built.ops = vec![None; built.penalties.len()];
8131                built.kronecker_factored = None;
8132                Some(rot)
8133            }
8134            (Some(_), _, _) => None,
8135            (None, _, _) => None,
8136        };
8137
8138        let activeinfos = built
8139            .penaltyinfo
8140            .iter()
8141            .filter(|info| info.active)
8142            .collect::<Vec<_>>();
8143        if activeinfos.len() != built.penalties.len() {
8144            crate::bail_invalid_basis!(
8145                "internal penalty info mismatch for term '{}': activeinfos={}, penalties={}",
8146                term.name,
8147                activeinfos.len(),
8148                built.penalties.len()
8149            );
8150        }
8151        for (((s_local, &ns), info), op_local) in built
8152            .penalties
8153            .iter()
8154            .zip(built.nullspaces.iter())
8155            .zip(activeinfos.into_iter())
8156            .zip(built.ops.iter())
8157        {
8158            let global_index = penalties_global.len();
8159            penalties_global.push(
8160                BlockwisePenalty::new(col_start..col_end, s_local.clone())
8161                    .with_op(op_local.clone()),
8162            );
8163            nullspace_dims_global.push(ns);
8164            let mut penalty = info.clone();
8165            penalty.nullspace_dim_hint = ns;
8166            penaltyinfo_global.push(PenaltyBlockInfo {
8167                global_index,
8168                termname: Some(term.name.clone()),
8169                penalty,
8170            });
8171        }
8172        for info in built.penaltyinfo.iter().filter(|info| !info.active) {
8173            dropped_penaltyinfo_global.push(DroppedPenaltyBlockInfo {
8174                termname: Some(term.name.clone()),
8175                penalty: info.clone(),
8176            });
8177        }
8178        for info in &built.pre_dropped_penaltyinfo {
8179            dropped_penaltyinfo_global.push(DroppedPenaltyBlockInfo {
8180                termname: Some(term.name.clone()),
8181                penalty: info.clone(),
8182            });
8183        }
8184
8185        if let Some(lin_local) = &built.linear_constraints {
8186            for r in 0..lin_local.a.nrows() {
8187                linear_constraintsrows.push((col_start, col_end, lin_local.a.row(r).to_owned()));
8188                linear_constraints_b.push(lin_local.b[r]);
8189            }
8190        }
8191        if let Some(lb_local) = &lb_local {
8192            coefficient_lower_bounds
8193                .slice_mut(s![col_start..col_end])
8194                .assign(lb_local);
8195            any_bounds = true;
8196        }
8197
8198        // Move the per-term design out of `built` rather than cloning it.
8199        local_designs.push(built.design);
8200
8201        terms_out.push(SmoothTerm {
8202            name: term.name.clone(),
8203            coeff_range: col_start..col_end,
8204            shape: term.shape,
8205            penalties_local: built.penalties,
8206            nullspace_dims: built.nullspaces,
8207            penaltyinfo_local: built.penaltyinfo,
8208            metadata: built.metadata,
8209            lower_bounds_local: lb_local,
8210            linear_constraints_local: built.linear_constraints,
8211            kronecker_factored: built.kronecker_factored.take(),
8212            joint_null_rotation: applied_rotation,
8213            unabsorbed_global_orthogonality: None,
8214        });
8215
8216        col_start = col_end;
8217    }
8218
8219    assert_eq!(
8220        penalties_global.len(),
8221        nullspace_dims_global.len(),
8222        "global smooth penalty/nullspace bookkeeping diverged"
8223    );
8224    assert_eq!(
8225        penalties_global.len(),
8226        penaltyinfo_global.len(),
8227        "global smooth penalty metadata bookkeeping diverged"
8228    );
8229
8230    Ok(RawSmoothDesign {
8231        term_designs: local_designs,
8232        penalties: penalties_global,
8233        nullspace_dims: nullspace_dims_global,
8234        penaltyinfo: penaltyinfo_global,
8235        dropped_penaltyinfo: dropped_penaltyinfo_global,
8236        terms: terms_out,
8237        coefficient_lower_bounds: if any_bounds {
8238            Some(coefficient_lower_bounds)
8239        } else {
8240            None
8241        },
8242        linear_constraints: if linear_constraintsrows.is_empty() {
8243            None
8244        } else {
8245            let mut a = Array2::<f64>::zeros((linear_constraintsrows.len(), total_p));
8246            for (i, (cs, ce, values)) in linear_constraintsrows.iter().enumerate() {
8247                a.row_mut(i).slice_mut(s![*cs..*ce]).assign(values);
8248            }
8249            Some(LinearInequalityConstraints {
8250                a,
8251                b: Array1::from_vec(linear_constraints_b),
8252            })
8253        },
8254    })
8255}