Skip to main content

gam_terms/smooth/
term_specs.rs

1use coefficient_transforms::{
2    convex_divided_difference_transform_matrix, cumulative_exp, cumulative_sum_transform_matrix,
3    second_cumulative_exp,
4};
5
6pub use error::SmoothError;
7
8use input_standardization::{
9    apply_input_standardization, compensate_length_scale_for_standardization,
10    compensate_optional_length_scale_for_standardization, compute_spatial_input_scales,
11};
12
13use shape_constraints::{
14    build_shape_constraint_design_1d, build_shape_linear_constraints_1d,
15    merge_linear_constraints_global, shape_lower_bounds_local, shape_order_and_sign,
16    shape_supports_basis, shape_uses_box_reparameterization,
17};
18
19pub fn describe_thin_plate_center_request(strategy: &CenterStrategy) -> String {
20    match strategy {
21        CenterStrategy::Auto(inner) => describe_thin_plate_center_request(inner),
22        CenterStrategy::UserProvided(centers) => format!("{} centers", centers.nrows()),
23        CenterStrategy::EqualMass { num_centers }
24        | CenterStrategy::EqualMassCovarRepresentative { num_centers }
25        | CenterStrategy::FarthestPoint { num_centers }
26        | CenterStrategy::KMeans { num_centers, .. } => format!("{num_centers} centers"),
27        CenterStrategy::UniformGrid { points_per_dim } => {
28            format!("uniform grid with {points_per_dim} points per dimension")
29        }
30    }
31}
32
33pub fn rewrite_thin_plate_knots_error(
34    err: BasisError,
35    termname: &str,
36    feature_count: usize,
37    spec: &ThinPlateBasisSpec,
38) -> BasisError {
39    match err {
40        // Polynomial-nullspace shortfall reported directly by the kernel
41        // builder ("thin-plate spline requires at least N centers to span ...").
42        BasisError::InvalidInput(msg)
43            if msg.contains("thin-plate spline requires at least")
44                && (msg.contains("centers to span") || msg.contains("knots to span")) =>
45        {
46            let min_centers = crate::basis::thin_plate_polynomial_basis_dimension(feature_count);
47            let requested = describe_thin_plate_center_request(&spec.center_strategy);
48            BasisError::InvalidInput(format!(
49                "joint TPS term '{termname}' over {feature_count} covariates with {requested} is invalid; minimum centers is {min_centers}"
50            ))
51        }
52        // Insufficient-rows shortfall raised by `select_thin_plate_knots` when
53        // the requested center count exceeds the available row count. Rewrite
54        // it in term language so the diagnostic points at the smooth term and
55        // the polynomial-nullspace minimum the user needs to satisfy.
56        BasisError::InvalidInput(msg)
57            if msg.starts_with("requested ") && msg.contains(" knots but only ") =>
58        {
59            let min_centers = crate::basis::thin_plate_polynomial_basis_dimension(feature_count);
60            let requested = describe_thin_plate_center_request(&spec.center_strategy);
61            BasisError::InvalidInput(format!(
62                "joint TPS term '{termname}' over {feature_count} covariates with {requested} is invalid; minimum centers is {min_centers}"
63            ))
64        }
65        other => other,
66    }
67}
68
/// Shape restriction that can be attached to a smooth term.
///
/// Parsed from user-facing strings by [`parse_shape_constraint`] and rendered
/// back to the canonical formula-DSL spelling by [`ShapeConstraint::dsl_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ShapeConstraint {
    /// No shape restriction; spelled `"none"` (the empty string also parses
    /// to this variant).
    None,
    /// Spelled `"monotone_increasing"` in the formula DSL.
    MonotoneIncreasing,
    /// Spelled `"monotone_decreasing"` in the formula DSL.
    MonotoneDecreasing,
    /// Spelled `"convex"` in the formula DSL.
    Convex,
    /// Spelled `"concave"` in the formula DSL.
    Concave,
}
77
78/// Parse a shape-constraint string into a [`ShapeConstraint`].
79///
80/// This is the single source of truth shared by the formula DSL
81/// (`s(x, shape=...)`) and the `smooths={...}` override path
82/// (`Smooth.shape_constraint`). The accepted spellings cover the canonical
83/// Python `ShapeConstraintLiteral` strings exactly
84/// (`"none"` / `"monotone_increasing"` / `"monotone_decreasing"` /
85/// `"convex"` / `"concave"`) plus a few common aliases. Hyphens and case are
86/// normalized, so `"Monotone-Increasing"` and `"mono_inc"` both resolve to
87/// [`ShapeConstraint::MonotoneIncreasing`].
88pub fn parse_shape_constraint(raw: &str) -> Result<ShapeConstraint, String> {
89    let normalized = raw.trim().to_ascii_lowercase().replace('-', "_");
90    match normalized.as_str() {
91        "" | "none" => Ok(ShapeConstraint::None),
92        "monotone_increasing" | "monotonic_increasing" | "increasing" | "mono_inc" | "mpi" => {
93            Ok(ShapeConstraint::MonotoneIncreasing)
94        }
95        "monotone_decreasing" | "monotonic_decreasing" | "decreasing" | "mono_dec" | "mpd" => {
96            Ok(ShapeConstraint::MonotoneDecreasing)
97        }
98        "convex" | "cvx" => Ok(ShapeConstraint::Convex),
99        "concave" | "ccv" => Ok(ShapeConstraint::Concave),
100        other => Err(format!(
101            "unknown shape constraint {other:?}; expected one of \
102             \"none\", \"monotone_increasing\", \"monotone_decreasing\", \
103             \"convex\", \"concave\""
104        )),
105    }
106}
107
108impl ShapeConstraint {
109    /// Canonical formula-DSL spelling, i.e. the text emitted into
110    /// `s(x, shape=...)`. Round-trips through [`parse_shape_constraint`].
111    pub fn dsl_str(&self) -> &'static str {
112        match self {
113            ShapeConstraint::None => "none",
114            ShapeConstraint::MonotoneIncreasing => "monotone_increasing",
115            ShapeConstraint::MonotoneDecreasing => "monotone_decreasing",
116            ShapeConstraint::Convex => "convex",
117            ShapeConstraint::Concave => "concave",
118        }
119    }
120}
121
/// Smooth-term head keywords recognised by the formula DSL. A `shape=` option
/// may be attached to any term whose head is one of these.
///
/// [`apply_shape_constraints_to_formula`] scans for these keywords at a word
/// boundary, followed by optional whitespace and an opening parenthesis, so
/// a short keyword such as `s` only matches when it is the whole identifier
/// before the `(`.
pub const SMOOTH_HEAD_KEYWORDS: [&str; 11] = [
    "s",
    "smooth",
    "te",
    "tensor",
    "thinplate",
    "tps",
    "duchon",
    "matern",
    "sphere",
    "bs",
    "bspline",
];
137
138/// Rewrite smooth-term calls in `formula` so each named smooth carries a
139/// `shape=<kind>` option understood by the formula DSL.
140///
141/// `constraints` pairs the smooth-term text as it appears in the formula
142/// (e.g. `"s(x)"` or `"s(x, type=duchon, centers=8)"`) with a shape-constraint
143/// spelling accepted by [`parse_shape_constraint`]; comparison is exact after
144/// whitespace removal. A `"none"` constraint is a no-op. Referencing a term not
145/// present in the formula is an error.
146///
147/// This is the single source of truth for the `gamfit.fit(..., constraints=…)`
148/// rewrite — the Python wrapper only marshals the mapping across the FFI and
149/// holds no formula-parsing or alias-normalization logic of its own.
150pub fn apply_shape_constraints_to_formula(
151    formula: &str,
152    constraints: &[(String, String)],
153) -> Result<String, String> {
154    use std::collections::{BTreeMap, BTreeSet};
155
156    if constraints.is_empty() {
157        return Ok(formula.to_string());
158    }
159    let strip_ws = |s: &str| -> String { s.chars().filter(|c| !c.is_whitespace()).collect() };
160
161    // Whitespace-stripped term text -> canonical shape spelling.
162    let mut wanted: BTreeMap<String, &'static str> = BTreeMap::new();
163    // Whitespace-stripped term text -> original key (for error labels).
164    let mut originals: BTreeMap<String, String> = BTreeMap::new();
165    for (key, kind_raw) in constraints {
166        let kind = parse_shape_constraint(kind_raw)?;
167        let nk = strip_ws(key);
168        originals.entry(nk.clone()).or_insert_with(|| key.clone());
169        if kind != ShapeConstraint::None {
170            wanted.insert(nk, kind.dsl_str());
171        }
172    }
173    if wanted.is_empty() {
174        return Ok(formula.to_string());
175    }
176
177    let chars: Vec<char> = formula.chars().collect();
178    let n = chars.len();
179    let is_ident = |c: char| c.is_ascii_alphanumeric() || c == '_';
180
181    let mut out = String::with_capacity(formula.len() + 32);
182    let mut matched: BTreeSet<String> = BTreeSet::new();
183    let mut i = 0usize;
184    while i < n {
185        // Locate the next smooth-term head (`<keyword> \s* (`) at or after `i`,
186        // respecting word boundaries so `abs(` never matches the `s(` head.
187        let mut head: Option<(usize, usize)> = None; // (head_start, paren_index)
188        let mut p = i;
189        while p < n {
190            let boundary = p == 0 || !is_ident(chars[p - 1]);
191            if boundary {
192                for kw in SMOOTH_HEAD_KEYWORDS.iter() {
193                    let klen = kw.chars().count();
194                    if p + klen > n || chars[p..p + klen].iter().collect::<String>() != **kw {
195                        continue;
196                    }
197                    let mut q = p + klen;
198                    while q < n && chars[q].is_whitespace() {
199                        q += 1;
200                    }
201                    if q < n && chars[q] == '(' {
202                        head = Some((p, q));
203                        break;
204                    }
205                }
206            }
207            if head.is_some() {
208                break;
209            }
210            p += 1;
211        }
212        let (head_start, paren_open) = match head {
213            Some(h) => h,
214            None => {
215                out.extend(chars[i..].iter());
216                break;
217            }
218        };
219        out.extend(chars[i..head_start].iter());
220
221        // Find the matching close paren, honoring nesting and string literals.
222        let body_start = paren_open + 1;
223        let mut depth = 1i32;
224        let mut j = body_start;
225        let mut in_str: Option<char> = None;
226        let mut closed = false;
227        while j < n {
228            let ch = chars[j];
229            if let Some(quote) = in_str {
230                if ch == quote {
231                    in_str = None;
232                }
233            } else if ch == '\'' || ch == '"' {
234                in_str = Some(ch);
235            } else if ch == '(' {
236                depth += 1;
237            } else if ch == ')' {
238                depth -= 1;
239                if depth == 0 {
240                    closed = true;
241                    break;
242                }
243            }
244            j += 1;
245        }
246
247        if !closed {
248            // Unbalanced — emit the remainder verbatim; the DSL parser will
249            // produce the canonical error.
250            out.extend(chars[head_start..].iter());
251            break;
252        }
253
254        let term_text: String = chars[head_start..=j].iter().collect();
255
256        let key_norm = strip_ws(&term_text);
257
258        match wanted.get(&key_norm) {
259            None => out.extend(chars[head_start..=j].iter()),
260            Some(kind) => {
261                let head_paren: String = chars[head_start..body_start].iter().collect();
262                let inside: String = chars[body_start..j].iter().collect();
263                let inside = inside.trim();
264                if inside.is_empty() {
265                    out.push_str(&format!("{head_paren}shape={kind})"));
266                } else {
267                    out.push_str(&format!("{head_paren}{inside}, shape={kind})"));
268                }
269                matched.insert(key_norm);
270            }
271        }
272
273        i = j + 1;
274    }
275
276    let mut missing: Vec<String> = wanted
277        .keys()
278        .filter(|k| !matched.contains(*k))
279        .map(|k| originals.get(k).cloned().unwrap_or_else(|| k.clone()))
280        .collect();
281
282    if !missing.is_empty() {
283        missing.sort();
284        return Err(format!(
285            "shape constraints referenced smooth term(s) not found in formula: {}",
286            missing.join(", ")
287        ));
288    }
289
290    Ok(out)
291}
292
/// Compact discriminator for how a `by=` variable gates a smooth; stored as
/// the `kind` field of [`SmoothBasisSpec::ByVariable`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BySmoothKind {
    /// The `by` variable is numeric.
    Numeric,
    /// The `by` variable selects a single categorical level, identified by a
    /// stored bit pattern.
    // NOTE(review): `level_bits` presumably holds the bit pattern of the
    // encoded level value — confirm against the code that builds this.
    Level { level_bits: u64 },
}
298
/// Serializable specification of a single smooth term's basis.
///
/// Each variant names one basis construction together with the feature
/// column(s) it reads and the construction-specific spec. Wrapper variants
/// (`ByVariable`, `FactorSumToZero`, `BySmooth`) box an inner
/// `SmoothBasisSpec`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SmoothBasisSpec {
    /// Row-gated wrapper used for mgcv-style ``by=`` smooths.
    ///
    /// ``ByNumeric`` multiplies the inner smooth by a numeric column.
    /// ``ByLevel`` keeps the inner smooth active only for rows whose encoded
    /// categorical value has the stored bit pattern.  Unordered factor-by
    /// smooths are represented as one independent ``ByLevel`` term per level.
    ///
    /// `kind` preserves the compact structural discriminator, while `by`
    /// carries the full row-gating spec used to build the local design.
    ByVariable {
        inner: Box<SmoothBasisSpec>,
        by_col: usize,
        kind: BySmoothKind,
        by: ByVariableSpec,
    },
    /// Sum-to-zero factor smooth (`bs="sz"`): with L levels, estimate L-1
    /// deviation coefficient blocks and use the final level as the negative
    /// sum of the others, enforcing coefficient-wise zero sums across levels.
    FactorSumToZero {
        inner: Box<SmoothBasisSpec>,
        by_col: usize,
        levels: Vec<u64>,
        /// Global-orthogonality column map `Z` captured at fit time when this
        /// term overlapped an owner smooth (`s(x) + s(g, x, bs=sz)`, #978):
        /// the hierarchical-ownership pass residualized this term's realized
        /// design as `X ← X·Z`, shrinking its coefficient block. `Z` depends
        /// on the *training-row* owner designs, so prediction cannot rederive
        /// it — it must be persisted and replayed
        /// (`apply_global_smooth_identifiability` consumes it verbatim).
        /// Chart convention: `Z` lives in the post-restack, post-joint-null-Q
        /// coordinates — the raw `sz` rebuild reapplies `Q` deterministically
        /// (#700), then `Z` applies on top. `None` for non-overlapping terms.
        #[serde(default)]
        frozen_global_orthogonality: Option<Array2<f64>>,
    },
    /// One-dimensional B-spline smooth over a single feature column.
    BSpline1D {
        feature_col: usize,
        spec: BSplineBasisSpec,
    },
    /// A smooth modulated by a `by=` variable. Numeric `by` scales one inner
    /// smooth; factor `by` replicates the inner smooth by level.
    BySmooth {
        smooth: Box<SmoothBasisSpec>,
        by_kind: ByVarKind,
    },
    /// Factor-smooth interaction families (`bs="fs"`, `bs="sz"`) and
    /// random slopes (`bs="re"`).
    FactorSmooth { spec: FactorSmoothSpec },
    /// Thin-plate spline smooth over one or more feature columns.
    ThinPlate {
        feature_cols: Vec<usize>,
        spec: ThinPlateBasisSpec,
        /// Per-column standard deviations used to standardize input dimensions
        /// before kernel evaluation when d > 1. `None` means no standardization
        /// (either d == 1 or explicitly disabled).
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
    },
    /// Spherical-spline smooth over the listed feature columns.
    Sphere {
        feature_cols: Vec<usize>,
        spec: SphericalSplineBasisSpec,
    },
    /// Constant-curvature (`M_κ`) geodesic-kernel smooth over κ-stereographic
    /// chart coordinates (#944): one construction interpolating
    /// S^d → ℝ^d → H^d through the spec's fixed κ. The Wahba S² smooth is the
    /// structural template; the geometry comes from
    /// `geometry::constant_curvature::ConstantCurvature`.
    ConstantCurvature {
        feature_cols: Vec<usize>,
        spec: ConstantCurvatureBasisSpec,
    },
    /// Matérn-kernel smooth over the listed feature columns.
    Matern {
        feature_cols: Vec<usize>,
        spec: MaternBasisSpec,
        // NOTE(review): presumably the same per-column standardization scales
        // documented on `ThinPlate::input_scales` — confirm.
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
    },
    /// Measure-jet spline smooth: multiscale local-jet-residual energy of the
    /// empirical measure (centers as μ-quadrature, masses as μ-weights — no
    /// graph, mesh, or neighbor set inside the statistical object). The
    /// feature columns are ambient coordinates of data concentrated near an
    /// unknown low-dimensional, possibly stratified set.
    MeasureJet {
        feature_cols: Vec<usize>,
        spec: MeasureJetBasisSpec,
        // NOTE(review): presumably the same per-column standardization scales
        // documented on `ThinPlate::input_scales` — confirm.
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
    },
    /// Duchon-spline smooth over the listed feature columns.
    Duchon {
        feature_cols: Vec<usize>,
        spec: DuchonBasisSpec,
        // NOTE(review): presumably the same per-column standardization scales
        // documented on `ThinPlate::input_scales` — confirm.
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
    },
    /// Smooth over a PCA projection of the listed feature columns.
    ///
    /// `smooth_penalty` and `chunk_size` fall back to
    /// `default_pca_smooth_penalty` / `default_pca_chunk_size` when absent
    /// from serialized data; `center_mean` and `pca_basis_path` default to
    /// `None`.
    Pca {
        feature_cols: Vec<usize>,
        basis_matrix: Array2<f64>,
        centered: bool,
        #[serde(default = "default_pca_smooth_penalty")]
        smooth_penalty: f64,
        #[serde(default)]
        center_mean: Option<Array1<f64>>,
        #[serde(default)]
        pca_basis_path: Option<PathBuf>,
        #[serde(default = "default_pca_chunk_size")]
        chunk_size: usize,
    },
    /// Tensor-product smooth built from 1D B-spline marginals.
    ///
    /// This is the `te()`-style construction used when axes have different units/scales
    /// (for example, space x time) and isotropic radial kernels are not appropriate.
    TensorBSpline {
        feature_cols: Vec<usize>,
        spec: TensorBSplineSpec,
    },
}
416
417impl SmoothBasisSpec {
418    /// Conservative lower bound on the number of sample rows needed for this
419    /// smooth basis to have a well-posed REML fit.
420    ///
421    /// Each basis kind answers the question for itself, so the workflow does
422    /// not have to know how many columns a B-spline, tensor product, PCA
423    /// projection, or spatial kernel emits. The contract is a *lower bound*:
424    /// returning too small a number is permitted (the inner solver will catch
425    /// any genuine n-vs-rank failure that slips past); returning too large a
426    /// number is a regression because it rejects legitimate fits.
427    ///
428    /// Rationale: B-spline / tensor / PCA bases have a closed-form column
429    /// count, so we use the exact dimension. Radial bases (TPS, Matern,
430    /// Duchon, Sphere) and factor smooths choose their column count from the
431    /// data (`heuristic_centers`, `unique_count`); we fall back to a small
432    /// constant floor because a fit on fewer than five rows cannot stabilise
433    /// any radial smooth regardless of the configured kernel scale.
434    pub fn min_sample_rows(&self) -> usize {
435        // Floor used for data-driven bases whose column count is not known
436        // from the spec alone. Five rows is the minimum at which the inner
437        // pivot/QR + REML smoothing-parameter search has any chance of being
438        // well-posed for a non-parametric smooth.
439        const RADIAL_FLOOR: usize = 5;
440
441        match self {
442            Self::ByVariable { inner, .. } => inner.min_sample_rows(),
443            Self::FactorSumToZero { inner, levels, .. } => {
444                // L-1 independent deviation blocks each carrying the inner
445                // basis dimension. Skip the levels-multiplier if it doesn't
446                // bring more rows; we want the *lower bound* not the rank.
447                let inner_min = inner.min_sample_rows();
448                let lvls = levels.len().saturating_sub(1).max(1);
449                inner_min.saturating_mul(lvls)
450            }
451            Self::BSpline1D { spec, .. } => bspline_basis_min_rows(spec),
452            Self::BySmooth { smooth, .. } => smooth.min_sample_rows(),
453            Self::FactorSmooth { spec } => {
454                // Replicates the marginal once per level; without a known
455                // level count we conservatively require at least the marginal
456                // basis dimension.
457                bspline_basis_min_rows(&spec.marginal)
458            }
459            Self::ThinPlate { .. }
460            | Self::Sphere { .. }
461            | Self::ConstantCurvature { .. }
462            | Self::Matern { .. }
463            | Self::MeasureJet { .. }
464            | Self::Duchon { .. } => RADIAL_FLOOR,
465            Self::Pca { basis_matrix, .. } => basis_matrix.ncols().max(1),
466            Self::TensorBSpline { spec, .. } => {
467                // A `te(...)` smooth is *penalized*: each margin carries a
468                // difference (wiggliness) penalty and the tensor inherits a
469                // Kronecker-sum penalty `S = Σ_i I ⊗ … ⊗ S_i ⊗ … ⊗ I`. The raw
470                // column count is the *product* of the per-marginal column
471                // counts, but that product is the lower bound for an
472                // *unpenalized* tensor regression — it is the number of rows you
473                // would need to identify every interaction column with no
474                // regularization. The penalty regularizes all of those
475                // interaction directions; only the combined penalty *null space*
476                // (the tensor product of the per-margin polynomial trends, a
477                // handful of columns) must be identified by the data, and the
478                // smoothing-parameter search shrinks the rest. The effective
479                // degrees of freedom of the fitted `te()` are therefore a small
480                // fraction of the column product, which is exactly why mgcv
481                // fits a default `te(x, y)` on a couple hundred rows.
482                //
483                // The honest *penalized* lower bound is the **sum** of the
484                // per-marginal column counts, not their product: a row floor of
485                // `Σ_i k_i` still guarantees enough data to identify each
486                // margin's additive main-effect (the largest sub-block the
487                // penalty cannot shrink to zero), while no longer conflating
488                // unpenalized column-count identifiability with penalized
489                // well-posedness. This accepts moderate-`n` penalized tensors
490                // (e.g. a 20×20 default basis on n=200) yet still rejects a
491                // genuinely undersized fit where `n < Σ_i k_i` and even the
492                // additive part is rank-deficient.
493                //
494                // Binary / low-cardinality margins (#724): gam will accept a
495                // `te(x, badh)` whose `badh ∈ {0, 1}` margin nominally requests
496                // more basis columns than `badh` has unique values, where mgcv
497                // refuses the unpenalized term as ill-posed ("badh has
498                // insufficient unique values to support k knots"). This is
499                // correct-by-design, *not* a degenerate fit: the marginal
500                // wiggliness penalty on the `badh` axis has a null space that is
501                // exactly its identifiable trend (the two cell means of a binary
502                // covariate), and the Kronecker-sum penalty shrinks every tensor
503                // column outside that null space toward zero. The resulting fit
504                // is the well-posed "per-level `x` smooth + binary main effect"
505                // that mgcv reaches only after manually collapsing the basis —
506                // gam reaches it automatically because the penalty, not the raw
507                // column count, sets the effective rank. A genuinely
508                // rank-deficient design (penalty null space wider than the data
509                // can support) is still caught downstream by the inner pivoted
510                // factorization, which owns the exact n-vs-rank decision; this
511                // pre-fit gate only refuses the grossly-undersized formula.
512                let mut total: usize = 0;
513                for marginal in &spec.marginalspecs {
514                    let m = bspline_basis_min_rows(marginal);
515                    total = total.saturating_add(m.max(1));
516                }
517                total.max(RADIAL_FLOOR)
518            }
519        }
520    }
521
522    /// Stable structural discriminant for warm-start cache keying (#869).
523    ///
524    /// Two smooths that produce different bases / penalty structures must map
525    /// to different strings here so they cannot collide on the persistent
526    /// warm-start `cache_key` (which is otherwise blind to topology: it hashes
527    /// only the raw input column count, so e.g. `sphere` vs `torus` vs
528    /// `euclidean` candidates fit on the *same* data would otherwise share one
529    /// key and cross-contaminate each other's β/ρ seed). The string is the
530    /// topology identity, not the fitted coefficients, so same-topology refits
531    /// (the screen→full-refit cascade) still hit the same key and reuse work.
532    pub fn structural_kind(&self) -> &'static str {
533        match self {
534            Self::ByVariable { .. } => "by_variable",
535            Self::FactorSumToZero { .. } => "factor_sum_to_zero",
536            Self::BSpline1D { .. } => "bspline_1d",
537            Self::BySmooth { .. } => "by_smooth",
538            Self::FactorSmooth { .. } => "factor_smooth",
539            Self::ThinPlate { .. } => "thin_plate",
540            Self::Sphere { .. } => "sphere",
541            Self::ConstantCurvature { .. } => "constant_curvature",
542            Self::Matern { .. } => "matern",
543            Self::MeasureJet { .. } => "measurejet",
544            Self::Duchon { .. } => "duchon",
545            Self::Pca { .. } => "pca",
546            Self::TensorBSpline { .. } => "tensor_bspline",
547        }
548    }
549
550    /// True for a tensor-product smooth that is only *marginally* centered
551    /// (`ti(...)`, [`TensorBSplineIdentifiability::MarginalSumToZero`]): its
552    /// per-margin sum-to-zero reparameterization `(B_xZ_x)⊗(B_zZ_z)` has ALREADY
553    /// removed each axis's main effect analytically (mgcv-identical), so its
554    /// main-effect removal is complete and it must take NO additional
555    /// owner-residualization block. Residualizing it a second time against the
556    /// realized main-effect designs is a grid-fragile no-op on an exact tensor
557    /// grid but eats genuine pure-interaction curvature off-grid (#1470).
558    pub fn is_marginally_centered_tensor(&self) -> bool {
559        matches!(
560            self,
561            Self::TensorBSpline { spec, .. }
562                if matches!(spec.identifiability, TensorBSplineIdentifiability::MarginalSumToZero)
563        )
564    }
565
566    /// A sum-to-zero factor smooth (`bs="sz"`) has ALREADY removed the
567    /// cross-group main effect analytically, in coefficient space, via its
568    /// `Σ_g d_g(x) ≡ 0` reparameterization (`L-1` deviation blocks with the
569    /// reference level the negative sum of the others) — exactly mgcv's `sz`
570    /// construction, which is self-identifiable against an overlapping `s(x)`
571    /// with no further constraint. Residualizing it a SECOND time against the
572    /// realized B-spline span of the explicit `s(x)` smooth is redundant in
573    /// exact arithmetic (the common-to-all-groups component is zero by
574    /// construction) and actively HARMFUL on finite data: each deviation block
575    /// is a B-spline in `x` whose realized columns share `s(x)`'s span, so the
576    /// joint residualization collapses the full `L·k`-column deviation design to
577    /// `L·k − rank(s(x))` columns and eats the within-group curvature `s(x)`
578    /// cannot represent. REML then rails the deviation smoothing parameter and
579    /// the factor smooth under-recovers (#1605). This is the exact analogue of
580    /// the marginally-centered tensor (`ti`) exemption (#1470), so such a term
581    /// takes NO owner-residualization block.
582    pub fn is_sum_to_zero_factor_smooth(&self) -> bool {
583        matches!(
584            self,
585            Self::FactorSumToZero { .. }
586                | Self::FactorSmooth {
587                    spec: FactorSmoothSpec {
588                        flavour: FactorSmoothFlavour::Sz,
589                        ..
590                    }
591                }
592        )
593    }
594
595    /// Feature columns this basis consumes, used alongside [`structural_kind`]
596    /// to disambiguate two same-kind smooths on different axes. Wrapper
597    /// variants delegate to their inner basis.
598    pub fn structural_feature_cols(&self) -> Vec<usize> {
599        match self {
600            Self::ByVariable { inner, .. } | Self::FactorSumToZero { inner, .. } => {
601                inner.structural_feature_cols()
602            }
603            Self::BySmooth { smooth, .. } => smooth.structural_feature_cols(),
604            Self::FactorSmooth { .. } => Vec::new(),
605            Self::BSpline1D { feature_col, .. } => vec![*feature_col],
606            Self::ThinPlate { feature_cols, .. }
607            | Self::Sphere { feature_cols, .. }
608            | Self::ConstantCurvature { feature_cols, .. }
609            | Self::Matern { feature_cols, .. }
610            | Self::MeasureJet { feature_cols, .. }
611            | Self::Duchon { feature_cols, .. }
612            | Self::Pca { feature_cols, .. }
613            | Self::TensorBSpline { feature_cols, .. } => feature_cols.clone(),
614        }
615    }
616}
617
618/// Lower bound on the number of sample rows a 1D B-spline smooth needs for a
619/// well-posed *penalized* REML fit. Used as the per-smooth row floor in
620/// [`SmoothBasisSpec::min_sample_rows`].
621///
622/// For a *singly*-penalized smooth the floor is the full column count: the
623/// wiggliness penalty leaves the order-`m` polynomial trend unpenalized, and
624/// gam's original gate conservatively required enough rows for the whole basis.
625/// That conservative floor is kept here unchanged.
626///
627/// A *double*-penalized smooth (mgcv `select=TRUE`) is different: it adds a
628/// second penalty on the wiggliness penalty's null space, so even the
629/// polynomial trend is shrinkable toward zero and *nothing* in the basis
630/// requires unpenalized identification by the data — exactly the reasoning the
631/// `TensorBSpline` arm of [`SmoothBasisSpec::min_sample_rows`] already applies
632/// to a penalized tensor. Its honest floor is therefore a small stabilization
633/// constant, not the column count. This is what lets mgcv (and now gam) fit
634/// several `select=TRUE` smooths on a dataset whose row count is below the
635/// summed basis width (e.g. the n≈30 `wine_gamair` fold, 5 `ps` smooths,
636/// p≈51): the penalties, not the data, set the effective rank. The bounded
637/// outer REML loop still terminates, and the genuine n-vs-rank decision is
638/// owned downstream by the inner pivoted factorization. Without this, gam
639/// rejected the fit outright (or, before the gate existed, the outer REML loop
640/// wandered the flat overparameterized surface until the benchmark wall budget
641/// killed it — #1089).
642pub fn bspline_basis_min_rows(spec: &crate::basis::BSplineBasisSpec) -> usize {
643    use crate::basis::BSplineKnotSpec;
644    let columns = match &spec.knotspec {
645        BSplineKnotSpec::Generate {
646            num_internal_knots, ..
647        } => *num_internal_knots + spec.degree + 1,
648        BSplineKnotSpec::Automatic {
649            num_internal_knots: Some(k),
650            ..
651        } => *k + spec.degree + 1,
652        BSplineKnotSpec::Automatic {
653            num_internal_knots: None,
654            ..
655        } => {
656            // Knot count is data-derived (`default_internal_knot_count_for_data`).
657            // A minimal cubic basis is `degree + 2` columns; below that the
658            // basis cannot represent a non-parametric smooth.
659            spec.degree + 2
660        }
661        BSplineKnotSpec::Provided(knots) => knots.len().saturating_sub(spec.degree + 1).max(1),
662        // cr basis dimension equals the knot count (no degree offset).
663        BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
664        BSplineKnotSpec::PeriodicUniform { num_basis, .. } => *num_basis,
665    };
666    let columns = columns.max(spec.degree + 2);
667
668    if spec.double_penalty {
669        // Fully shrinkable basis: only a small stabilization floor must be
670        // identified by the data, capped by the actual column count.
671        const DOUBLE_PENALTY_FLOOR: usize = 2;
672        DOUBLE_PENALTY_FLOOR.min(columns).max(1)
673    } else {
674        columns
675    }
676}
677
/// Describes a smooth's `by` variable: either numeric, or pinned to one level.
/// NOTE(review): semantics inferred from the variant names; the consumer of
/// this enum is outside this chunk — confirm there.
678#[derive(Debug, Clone, Serialize, Deserialize)]
679pub enum ByVariableSpec {
    /// Numeric `by` variable; carries no payload.
680    Numeric,
    /// A single level, stored as a raw `u64` plus a `label` string.
    /// NOTE(review): `value_bits` is presumably the `f64::to_bits` pattern of
    /// the level value (as `LinearTermSpec::categorical_levels` uses) — confirm.
681    Level { value_bits: u64, label: String },
682}
683
684
/// Kind of `by` variable attached to a smooth, together with the data column
/// it is read from.
/// NOTE(review): field meanings below are inferred from names; the builder
/// that consumes this enum is outside this chunk.
685#[derive(Debug, Clone, Serialize, Deserialize)]
686pub enum ByVarKind {
    /// Numeric `by` variable.
687    Numeric {
        /// Index of the data column holding the variable.
688        feature_col: usize,
689    },
    /// Factor `by` variable.
690    Factor {
        /// Index of the data column holding the factor.
691        feature_col: usize,
        /// Whether the factor is ordered (presumably mgcv's ordered-factor
        /// `by` semantics — confirm against the caller).
692        ordered: bool,
        /// Optional level set as raw `u64` values; presumably captured at fit
        /// time as `f64` bit patterns — confirm.
693        frozen_levels: Option<Vec<u64>>,
694    },
695}
696
/// Specification of a factor-smooth term: a B-spline marginal over the
/// continuous column(s) combined with a grouping column, in one of the
/// [`FactorSmoothFlavour`] variants.
/// NOTE(review): summary inferred from field names and types; the basis
/// builder is outside this chunk.
697#[derive(Debug, Clone, Serialize, Deserialize)]
698pub struct FactorSmoothSpec {
    /// Indices of the continuous data columns.
699    pub continuous_cols: Vec<usize>,
    /// Index of the grouping (factor) data column.
700    pub group_col: usize,
    /// B-spline basis spec used for the marginal smooth.
701    pub marginal: BSplineBasisSpec,
    /// Which factor-smooth variant this term is.
702    pub flavour: FactorSmoothFlavour,
    /// Optional group-level set as raw `u64` values; presumably frozen at fit
    /// time as `f64` bit patterns — confirm.
703    pub group_frozen_levels: Option<Vec<u64>>,
704    /// Fit-time global-orthogonality chart `Z` for this term (`s(x) + fs(x, g)`
705    /// overlap residualization, #978), in the post-joint-null-`Q` coordinates
706    /// (the raw rebuild recomputes any `Q` itself; `fs` penalties are
707    /// typically full-rank so `Q` is absent). Training-row dependent, hence
708    /// persisted; replayed verbatim by `apply_global_smooth_identifiability`.
709    #[serde(default)]
710    pub frozen_global_orthogonality: Option<Array2<f64>>,
711}
712
/// Variant of a factor smooth.
/// NOTE(review): the names presumably mirror mgcv's `bs = "fs"`, `"sz"` and
/// `"re"` bases — confirm against the formula parser.
713#[derive(Debug, Clone, Serialize, Deserialize)]
714pub enum FactorSmoothFlavour {
    /// `fs` flavour, carrying a list of penalty orders (presumably for the
    /// marginal null-space penalties — confirm).
715    Fs { m_null_penalty_orders: Vec<usize> },
    /// `sz` flavour.
716    Sz,
    /// `re` flavour.
717    Re,
718}
719
/// Specification of a tensor-product B-spline smooth: one marginal basis spec
/// per margin, plus identifiability and penalty-decomposition choices.
720#[derive(Debug, Default, Clone, Serialize, Deserialize)]
721pub struct TensorBSplineSpec {
    /// One B-spline basis spec per margin.
722    pub marginalspecs: Vec<BSplineBasisSpec>,
    /// Optional period per entry; defaults to empty when absent from a
    /// serialized payload. NOTE(review): presumably parallel to
    /// `marginalspecs`, with `None` meaning non-periodic — confirm.
723    #[serde(default)]
724    pub periods: Vec<Option<f64>>,
    /// Double-penalty flag (presumably adds a null-space penalty, as for 1-D
    /// B-splines — confirm).
725    pub double_penalty: bool,
    /// Identifiability constraint; serde default is the enum's `#[default]`
    /// variant (`SumToZero`).
726    #[serde(default)]
727    pub identifiability: TensorBSplineIdentifiability,
    /// Penalty construction; serde default is `MarginalKroneckerSum`.
728    #[serde(default)]
729    pub penalty_decomposition: TensorBSplinePenaltyDecomposition,
730}
731
/// Identifiability treatment applied to a tensor-product B-spline smooth.
/// `SumToZero` is the `Default`.
732#[derive(Debug, Default, Clone, Serialize, Deserialize)]
733pub enum TensorBSplineIdentifiability {
    /// No identifiability constraint.
734    None,
    /// Default variant. NOTE(review): presumably a single sum-to-zero
    /// constraint on the full tensor basis (contrast `MarginalSumToZero`) —
    /// confirm against the builder.
735    #[default]
736    SumToZero,
737    /// mgcv `ti(...)` semantics: a *tensor interaction* smooth that excludes the
738    /// marginal main effects. A sum-to-zero constraint is applied to **each
739    /// marginal basis independently** before forming the tensor product, so the
740    /// resulting column space contains no function of a single variable alone —
741    /// only the pure interaction survives. The realized identifiability
742    /// transform is the Kronecker product `Z = Z₀ ⊗ Z₁ ⊗ … ⊗ Z_{d-1}` of the
743    /// per-margin sum-to-zero null-space bases, which is exactly the
744    /// reparameterization that turns the full-tensor design into the tensor
745    /// product of the centered margins.
746    MarginalSumToZero,
    /// A stored transform matrix. NOTE(review): presumably the fit-time
    /// identifiability transform, replayed at prediction — confirm.
747    FrozenTransform {
748        transform: Array2<f64>,
749    },
750}
751
/// Selects how the penalties of a tensor-product B-spline smooth are built:
/// mgcv `te`-style overlapping marginal penalties (the default) or mgcv
/// `t2`-style disjoint subspace penalties.
752#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
753pub enum TensorBSplinePenaltyDecomposition {
754    /// mgcv `te(...)`: one overlapping Kronecker-product penalty per margin,
755    /// `S_j` embedded against identities in the other tensor factors.
756    #[default]
757    MarginalKroneckerSum,
758    /// mgcv `t2(...)`: split every marginal coefficient space into penalized
759    /// range and penalty-null subspaces, then emit one disjoint tensor-subspace
760    /// penalty for every non-empty penalized/null combination.
761    Separable,
762}
763
/// Serializable specification of one smooth term: its name, basis, shape
/// constraint, and the optional fit-time joint-null rotation needed to replay
/// the fitted coefficient system at prediction.
764#[derive(Debug, Clone, Serialize, Deserialize)]
765pub struct SmoothTermSpec {
    /// Term name (used in validation error messages).
766    pub name: String,
    /// Basis specification for the term.
767    pub basis: SmoothBasisSpec,
    /// Shape constraint attached to the term.
768    pub shape: ShapeConstraint,
769    /// Joint-null absorption rotation captured at fit time. `Some(Q)` means
770    /// the fitted coefficient vector lives in `γ`-coordinates with
771    /// `β_raw = Q · γ`; prediction must rotate the raw-basis design via
772    /// `X_new = X_new_raw · Q` to match. `None` means either the smooth had
773    /// no joint null space (penalty already full-rank) or rotation was
774    /// suppressed (smooth carries shape constraints whose cone geometry
775    /// would not survive an arbitrary orthogonal rotation). Persisted so
776    /// `save → load → predict` is bit-equivalent to in-memory prediction.
777    #[serde(default)]
778    pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
779}
780
/// A realized smooth term: its coefficient range, term-local penalties and
/// constraints, basis metadata, and the transforms applied at construction
/// time that prediction must replay.
781#[derive(Debug, Clone)]
782pub struct SmoothTerm {
    /// Term name; included in the rotation-replay dimension error.
783    pub name: String,
    /// Coefficient index range owned by this term. Its length is the
    /// `p_local` that `wald_unpenalized_dim` passes to `joint_unpenalized_dim`.
784    pub coeff_range: Range<usize>,
    /// Shape constraint attached to the term.
785    pub shape: ShapeConstraint,
    /// Term-local penalty matrices, in the rotated coordinates when
    /// `joint_null_rotation` is `Some` (see that field).
786    pub penalties_local: Vec<Array2<f64>>,
    /// Per-penalty null-space dimensions. These must not be summed to get the
    /// joint null-space dimension; use `wald_unpenalized_dim` for that.
787    pub nullspace_dims: Vec<usize>,
    /// Penalty descriptors. NOTE(review): presumably one per entry of
    /// `penalties_local` — confirm.
788    pub penaltyinfo_local: Vec<PenaltyInfo>,
    /// Basis metadata for this term.
789    pub metadata: BasisMetadata,
790    /// Optional term-local lower bounds for constrained coefficients.
791    /// `-inf` means unconstrained.
792    pub lower_bounds_local: Option<Array1<f64>>,
793    /// Optional term-local inequality constraints in local coefficient coordinates.
794    /// `A_local * beta_local >= b_local`.
795    pub linear_constraints_local: Option<LinearInequalityConstraints>,
796    /// Optional factored tensor-product representation preserved for operator-backed
797    /// assembly in the main design builder.
798    pub kronecker_factored: Option<KroneckerFactoredBasis>,
799    /// Joint-null absorption rotation. `Some(Q)` records the orthonormal
800    /// `(p_local × p_local)` matrix that was applied to this term's design
801    /// and per-block penalties at construction time:
802    /// `term_design ← X_raw · Q`, `penalties_local[k] ← Qᵀ · S_raw · Q`.
803    /// The smooth's coefficient vector therefore lives in the rotated
804    /// (`γ`) coordinate system, with `β_raw = Q · γ` recovering the raw
805    /// pre-rotation parameterization. `None` means either no joint null
806    /// space (penalty already full-rank) or rotation was suppressed —
807    /// suppression fires when the smooth carries shape constraints
808    /// (lower bounds or local linear inequalities) that would lose their
809    /// cone geometry under a general orthogonal rotation.
810    ///
811    /// Prediction-side replay: callers building a new-data design `X_new_raw`
812    /// from the *raw* basis must call [`SmoothTerm::apply_rotation_to_predict`]
813    /// (or equivalent) to obtain `X_new = X_new_raw · Q` matching this
814    /// term's coefficient system.
815    ///
816    /// Persistence replay: `freeze_term_collection_from_design` copies this
817    /// rotation into `SmoothTermSpec`, which is serialized with fitted-model
818    /// payloads and reused by the predict-time basis builder. Saved models
819    /// therefore replay the same `X_new_raw · Q` transform as in-memory
820    /// prediction.
821    pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
822    /// Global-orthogonality transform that `apply_global_smooth_identifiability`
823    /// applied to this term's design but could NOT embed into `metadata`
824    /// (factor-smooth kinds: `sz` metadata is per-marginal, `fs` metadata has
825    /// no transform slot — #978). `freeze_term_collection_from_design` copies
826    /// it onto the term's basis spec (`frozen_global_orthogonality`) so the
827    /// predict-side rebuild replays it instead of emitting the unresidualized
828    /// (wider) design that the fitted coefficients no longer match.
829    /// Chart convention is per kind: post-`Q` `Z` for `sz` (the raw rebuild
830    /// reapplies `Q` itself, #700), full `Q·Z` chart for `fs`.
831    pub unabsorbed_global_orthogonality: Option<Array2<f64>>,
832}
833
834impl SmoothTerm {
835    /// Apply the joint-null absorption rotation to a raw new-data design
836    /// matrix, returning `X_new_raw · Q` when this term was rotated at
837    /// fit time, or `X_new_raw` unchanged when no rotation was applied.
838    ///
839    /// Callers in the prediction path: after building the smooth's basis
840    /// at new data via the *raw* basis builder (the same builder used at
841    /// fit time, applied to `x_new` instead of the training rows), call
842    /// this method on the resulting matrix before forming `X · β`. The
843    /// fitted `β` lives in `γ`-coordinates if Q was applied; multiplying
844    /// the un-rotated `X_new_raw` by `β` would give a wrong η.
845    ///
846    /// Returns an error if the raw design's column count does not match
847    /// the rotation's `p_local`. The width invariant must hold: the raw
848    /// basis builder MUST emit the same `p_local` columns that the
849    /// fit-time builder did, and the rotation is `(p_local × p_local)`.
850    pub fn apply_rotation_to_predict(
851        &self,
852        x_new_raw: Array2<f64>,
853    ) -> Result<Array2<f64>, BasisError> {
854        let Some(rot) = self.joint_null_rotation.as_ref() else {
855            return Ok(x_new_raw);
856        };
857        let p_local = rot.rotation.nrows();
858        if x_new_raw.ncols() != p_local {
859            crate::bail_dim_basis!(
860                "joint-null rotation replay for term '{}': raw design has {} columns, \
861                 rotation expects {} (the raw basis builder must emit the same column \
862                 count as at fit time)",
863                self.name,
864                x_new_raw.ncols(),
865                p_local,
866            );
867        }
868        Ok(gam_linalg::faer_ndarray::fast_ab(
869            &x_new_raw,
870            &rot.rotation,
871        ))
872    }
873
874    /// Dimension of the **joint** null space of this term's active penalties:
875    /// the coefficient directions penalized by *no* penalty. The smooth-component
876    /// Wald test ([`crate::inference::smooth_test::wood_smooth_test`]) treats this
877    /// many leading coefficients as genuine unpenalized fixed effects and tests
878    /// them at full rank; the remainder is the penalized sub-block tested with a
879    /// rank-`≈EDF` truncated pseudo-inverse.
880    ///
881    /// Because every penalty block `S_k` is positive semi-definite,
882    /// `vᵀ(Σ_k S_k)v = Σ_k vᵀ S_k v = 0` iff `S_k v = 0` for *every* `k`; the
883    /// joint null space is therefore exactly `null(Σ_k S_k)`, of dimension
884    /// `p_local − rank(Σ_k S_k)`. This is the **intersection** of the per-penalty
885    /// null spaces, not their sum.
886    ///
887    /// Summing the per-penalty `nullspace_dims` instead (the historical defect
888    /// behind #1360) *unions* the null spaces and badly over-counts: a
889    /// double-penalty smooth carries a bending penalty (null space = its
890    /// polynomial part) plus a complementary null-space ridge (which penalizes
891    /// exactly that polynomial part), so the two null spaces are disjoint and the
892    /// joint null space is empty — yet the per-penalty dims sum to nearly
893    /// `p_local`. Feeding that inflated count to the Wald test makes it test
894    /// almost the whole shrunk block at full rank, manufacturing overwhelming
895    /// "significance" for a term the fit drove to ~0 EDF.
896    pub fn wald_unpenalized_dim(&self) -> usize {
897        joint_unpenalized_dim(
898            self.coeff_range.len(),
899            &self.penalties_local,
900            &self.nullspace_dims,
901        )
902    }
903}
904
905/// Numeric core of [`SmoothTerm::wald_unpenalized_dim`]: the dimension of the
906/// joint null space `∩_k null(S_k) = null(Σ_k S_k)` of a term's local penalty
907/// blocks, with a conservative fallback when a penalty is not materialized as a
908/// full `p_local × p_local` matrix (e.g. a Kronecker tensor factor).
909pub fn joint_unpenalized_dim(
910    p_local: usize,
911    penalties_local: &[Array2<f64>],
912    nullspace_dims: &[usize],
913) -> usize {
914    use gam_linalg::faer_ndarray::FaerEigh;
915    if p_local == 0 {
916        return 0;
917    }
918    if penalties_local.is_empty() {
919        // No penalty ⇒ a wholly unpenalized (fixed-effect) block.
920        return p_local;
921    }
922    // Sum the penalties that are materialized as full `p_local × p_local`
923    // blocks (the common smooth case). The covariance block the Wald test
924    // slices lives in this same coefficient basis (post joint-null rotation),
925    // so the rank is computed in the right metric.
926    let mut s_total = Array2::<f64>::zeros((p_local, p_local));
927    let mut materialized = 0usize;
928    for s in penalties_local {
929        if s.nrows() == p_local && s.ncols() == p_local {
930            s_total += s;
931            materialized += 1;
932        }
933    }
934    if materialized == penalties_local.len() {
935        let symmetric = {
936            let transpose = s_total.t().to_owned();
937            (&s_total + &transpose) * 0.5
938        };
939        if let Ok((evals, _)) = symmetric.eigh(faer::Side::Lower) {
940            let max_abs = evals.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
941            if max_abs == 0.0 {
942                // All penalties identically zero ⇒ unpenalized block.
943                return p_local;
944            }
945            let tol = max_abs * (p_local as f64) * 1e-12;
946            let rank = evals.iter().filter(|&&v| v > tol).count();
947            return p_local.saturating_sub(rank);
948        }
949    }
950    // Conservative fallback when a penalty is not a materialized full block
951    // (e.g. a Kronecker tensor factor): with ≥2 active penalties the joint
952    // null space is almost always empty (the only over-rejecting direction);
953    // with a single penalty it is exactly that penalty's own null space.
954    if penalties_local.len() >= 2 {
955        0
956    } else {
957        nullspace_dims
958            .iter()
959            .copied()
960            .min()
961            .unwrap_or(0)
962            .min(p_local)
963    }
964}
965
/// Descriptor of one penalty block, tagged with an index and an optional
/// owning term name.
966#[derive(Debug, Clone, Serialize, Deserialize)]
967pub struct PenaltyBlockInfo {
    /// Index of this penalty. NOTE(review): presumably its position in the
    /// global penalty list — confirm against the assembler.
968    pub global_index: usize,
    /// Name of the term the penalty belongs to, when known.
969    pub termname: Option<String>,
    /// The penalty's own descriptor.
970    pub penalty: PenaltyInfo,
971}
972
/// Descriptor of a penalty block that was dropped; unlike
/// [`PenaltyBlockInfo`] it carries no `global_index`.
/// NOTE(review): the drop criteria are decided outside this chunk.
973#[derive(Debug, Clone, Serialize, Deserialize)]
974pub struct DroppedPenaltyBlockInfo {
    /// Name of the term the penalty belonged to, when known.
975    pub termname: Option<String>,
    /// The dropped penalty's descriptor.
976    pub penalty: PenaltyInfo,
977}
978
/// Assembled smooth block: per-term designs, penalties, penalty descriptors,
/// the realized terms, and optional block-level coefficient constraints.
/// Field-for-field identical to [`RawSmoothDesign`], from which it converts.
979#[derive(Debug, Clone)]
980pub struct SmoothDesign {
    /// Term design matrices. `total_smooth_cols` sums their column counts and
    /// `nrows` reads the row count of the first one.
981    pub term_designs: Vec<DesignMatrix>,
982    /// Per-term block-local penalties.  Each `col_range` is relative to the
983    /// smooth block (i.e. indexing into the concatenation of `term_designs`).
984    pub penalties: Vec<BlockwisePenalty>,
    /// Null-space dimensions. NOTE(review): presumably one per entry of
    /// `penalties` — confirm.
985    pub nullspace_dims: Vec<usize>,
    /// Descriptors of the retained penalty blocks.
986    pub penaltyinfo: Vec<PenaltyBlockInfo>,
    /// Descriptors of penalty blocks that were dropped.
987    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    /// The realized smooth terms.
988    pub terms: Vec<SmoothTerm>,
989    /// Optional smooth-block lower bounds in smooth coefficient coordinates.
990    /// Length equals `total_smooth_cols()` when present.
991    pub coefficient_lower_bounds: Option<Array1<f64>>,
992    /// Optional smooth-block inequality constraints:
993    /// `A_smooth * beta_smooth >= b`.
994    pub linear_constraints: Option<LinearInequalityConstraints>,
995}
996
997impl SmoothDesign {
998    pub fn total_smooth_cols(&self) -> usize {
999        self.term_designs.iter().map(DesignMatrix::ncols).sum()
1000    }
1001    pub fn nrows(&self) -> usize {
1002        self.term_designs.first().map_or(0, DesignMatrix::nrows)
1003    }
1004}
1005
/// Smooth block with the same fields as [`SmoothDesign`]; converts into it
/// field-for-field through the `From<RawSmoothDesign>` impl.
/// NOTE(review): "raw" presumably means "before global identifiability
/// post-processing" — confirm against the builder.
1006#[derive(Debug, Clone)]
1007pub struct RawSmoothDesign {
    /// Term design matrices. `total_smooth_cols` sums their column counts and
    /// `nrows` reads the row count of the first one.
1008    pub term_designs: Vec<DesignMatrix>,
1009    /// Per-term block-local penalties.  Each `col_range` is relative to the
1010    /// smooth block (i.e. indexing into the concatenation of `term_designs`).
1011    pub penalties: Vec<BlockwisePenalty>,
    /// Null-space dimensions. NOTE(review): presumably one per entry of
    /// `penalties` — confirm.
1012    pub nullspace_dims: Vec<usize>,
    /// Descriptors of the retained penalty blocks.
1013    pub penaltyinfo: Vec<PenaltyBlockInfo>,
    /// Descriptors of penalty blocks that were dropped.
1014    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    /// The realized smooth terms.
1015    pub terms: Vec<SmoothTerm>,
    /// Optional smooth-block coefficient lower bounds.
1016    pub coefficient_lower_bounds: Option<Array1<f64>>,
    /// Optional smooth-block linear inequality constraints.
1017    pub linear_constraints: Option<LinearInequalityConstraints>,
1018}
1019
1020impl RawSmoothDesign {
1021    pub fn total_smooth_cols(&self) -> usize {
1022        self.term_designs.iter().map(DesignMatrix::ncols).sum()
1023    }
1024    pub fn nrows(&self) -> usize {
1025        self.term_designs.first().map_or(0, DesignMatrix::nrows)
1026    }
1027}
1028
1029impl From<RawSmoothDesign> for SmoothDesign {
1030    fn from(value: RawSmoothDesign) -> Self {
1031        Self {
1032            term_designs: value.term_designs,
1033            penalties: value.penalties,
1034            nullspace_dims: value.nullspace_dims,
1035            penaltyinfo: value.penaltyinfo,
1036            dropped_penaltyinfo: value.dropped_penaltyinfo,
1037            terms: value.terms,
1038            coefficient_lower_bounds: value.coefficient_lower_bounds,
1039            linear_constraints: value.linear_constraints,
1040        }
1041    }
1042}
1043
/// Prior placed on a bounded linear coefficient (see
/// [`LinearCoefficientGeometry::Bounded`]). Defaults to `None`.
1044#[derive(Debug, Default, Clone, Serialize, Deserialize)]
1045pub enum BoundedCoefficientPriorSpec {
    /// No prior (the default).
1046    #[default]
1047    None,
    /// Uniform prior.
1048    Uniform,
    /// Beta prior with parameters `a` and `b`.
    /// `TermCollectionSpec::validate_frozen` rejects non-finite values and
    /// values below `1.0`.
1049    Beta {
1050        a: f64,
1051        b: f64,
1052    },
1053}
1054
/// Geometry of a linear term's coefficient: free, or confined to an interval.
/// Defaults to `Unconstrained`.
1055#[derive(Debug, Clone, Serialize, Deserialize, Default)]
1056pub enum LinearCoefficientGeometry {
    /// No interval restriction (the default).
1057    #[default]
1058    Unconstrained,
    /// Coefficient restricted to `[min, max]`.
    /// `TermCollectionSpec::validate_frozen` requires both bounds to be finite
    /// with `min < max`.
1059    Bounded {
1060        min: f64,
1061        max: f64,
        /// Prior over the interval; `BoundedCoefficientPriorSpec::None` when
        /// absent from a serialized payload.
1062        #[serde(default)]
1063        prior: BoundedCoefficientPriorSpec,
1064    },
1065}
1066
/// Specification of one parametric (linear) term: the feature column(s) whose
/// product forms its design column, optional categorical-level gates, an
/// optional ridge, and optional coefficient constraints.
1067#[derive(Debug, Clone, Serialize, Deserialize)]
1068pub struct LinearTermSpec {
    /// Term name; used in error messages.
1069    pub name: String,
1070    /// Primary feature column index. For Wilkinson-Rogers `:` interaction
1071    /// terms (`a:b[:c...]`) this is the first column in `feature_cols`; the
1072    /// realized design column is the elementwise product across every entry
1073    /// of `feature_cols`. Plain (non-interaction) linear terms set
1074    /// `feature_cols == vec![feature_col]`.
1075    pub feature_col: usize,
1076    /// Full list of columns whose elementwise product yields this term's
1077    /// design column. `len() >= 1`; `len() == 1` is a plain linear effect.
1078    #[serde(default)]
1079    pub feature_cols: Vec<usize>,
1080    /// Categorical-level gates for a factor-aware `:` interaction.
1081    ///
1082    /// Each `(col, level_bits)` multiplies the realized design column by the
1083    /// indicator `1[data[row, col].to_bits() == level_bits]`. This is how a
1084    /// `factor:x` (or `factor:factor`) interaction is expanded: `build_termspec`
1085    /// emits one `LinearTermSpec` per surviving cell of the categorical
1086    /// operand(s) (treatment-coded, first level dropped per factor), each
1087    /// carrying the numeric operands in `feature_cols` and the cell's level
1088    /// gate(s) here. Empty for a plain numeric `:` interaction or main effect,
1089    /// in which case the realized column is exactly the numeric product.
1090    #[serde(default)]
1091    pub categorical_levels: Vec<(usize, u64)>,
1092    /// Optional ridge (`S = I`, REML-selected `λ`) on this linear coefficient.
1093    /// A parametric linear term carries no wiggliness, so it is **unpenalized by
1094    /// default** — gam reports the MLE, matching mgcv/glm/survreg/VGAM (which
1095    /// penalize parametric terms only under an explicit `paraPen`). Set `true`
1096    /// to opt into an explicit shrinkage ridge (a zero-mean Gaussian prior
1097    /// `β ~ N(0, λ⁻¹)`); doing so adds one outer REML smoothing coordinate.
1098    #[serde(default = "default_linear_term_double_penalty")]
1099    pub double_penalty: bool,
    /// Interval geometry for the coefficient; `Unconstrained` when absent
    /// from a serialized payload.
1100    #[serde(default)]
1101    pub coefficient_geometry: LinearCoefficientGeometry,
    /// Optional lower coefficient constraint. `validate_frozen` rejects a
    /// non-finite value, and rejects `min > max` when both are set.
1102    #[serde(default)]
1103    pub coefficient_min: Option<f64>,
    /// Optional upper coefficient constraint; validated like
    /// `coefficient_min`.
1104    #[serde(default)]
1105    pub coefficient_max: Option<f64>,
1106}
1107
1108impl LinearTermSpec {
1109    /// Return the effective list of feature columns. Backfills from
1110    /// `feature_col` for legacy specs that predate the multi-column field.
1111    pub fn effective_feature_cols(&self) -> Vec<usize> {
1112        if self.feature_cols.is_empty() {
1113            vec![self.feature_col]
1114        } else {
1115            self.feature_cols.clone()
1116        }
1117    }
1118
1119    /// True when this term is a Wilkinson-Rogers `:` interaction (multi-col).
1120    pub fn is_interaction(&self) -> bool {
1121        self.feature_cols.len() > 1 || !self.categorical_levels.is_empty()
1122    }
1123
1124    /// Realize this linear term's `(n,)` design column from `data`.
1125    ///
1126    /// The column is the elementwise product of every numeric feature column
1127    /// (`effective_feature_cols`) gated by the categorical-level indicators in
1128    /// `categorical_levels`: each `(col, level_bits)` multiplies the running
1129    /// column by `1[data[row, col].to_bits() == level_bits]`. A plain numeric
1130    /// term (no `categorical_levels`) reduces to the bare product, matching the
1131    /// historical behaviour. A pure categorical interaction (empty
1132    /// `feature_cols`, non-empty `categorical_levels`) reduces to the cell
1133    /// indicator. Bounds are validated here; the returned column has length
1134    /// `data.nrows()`.
1135    pub fn realized_design_column(&self, data: ArrayView2<'_, f64>) -> Result<Array1<f64>, String> {
1136        let n = data.nrows();
1137        let p = data.ncols();
1138        let bounds = |col: usize| -> Result<(), String> {
1139            if col >= p {
1140                Err(format!(
1141                    "linear term '{}' feature column {} out of bounds for {} columns",
1142                    self.name, col, p
1143                ))
1144            } else {
1145                Ok(())
1146            }
1147        };
1148
1149        // Numeric operands. When `categorical_levels` is set we treat
1150        // `feature_cols` as the (possibly empty) numeric operand list and start
1151        // from a column of ones; otherwise we preserve the legacy backfill from
1152        // `feature_col` so a plain term with no `feature_cols` still resolves.
1153        let mut column = if self.categorical_levels.is_empty() {
1154            let cols = self.effective_feature_cols();
1155            for &c in &cols {
1156                bounds(c)?;
1157            }
1158            let mut acc = data.column(cols[0]).to_owned();
1159            for &c in cols.iter().skip(1) {
1160                acc *= &data.column(c);
1161            }
1162            acc
1163        } else {
1164            let mut acc = Array1::<f64>::ones(n);
1165            for &c in &self.feature_cols {
1166                bounds(c)?;
1167                acc *= &data.column(c);
1168            }
1169            acc
1170        };
1171
1172        for &(col, level_bits) in &self.categorical_levels {
1173            bounds(col)?;
1174            let gate = data.column(col);
1175            for (out, &v) in column.iter_mut().zip(gate.iter()) {
1176                if v.to_bits() != level_bits {
1177                    *out = 0.0;
1178                }
1179            }
1180        }
1181
1182        Ok(column)
1183    }
1184}
1185
/// Serde default for `LinearTermSpec::double_penalty`: `false`, i.e. linear
/// terms carry no ridge unless one is requested explicitly.
pub const fn default_linear_term_double_penalty() -> bool {
    // A single linear coefficient has no roughness for a smoothing penalty to
    // control. The historical `S = I` ridge with REML-selected `λ` shrank
    // every linear coefficient off the MLE and injected a spurious outer
    // smoothing coordinate (#749). Mature tools (mgcv/glm/survreg/VGAM) leave
    // parametric terms unpenalized; gam matches that and reports the MLE. An
    // explicit `double_penalty = true` still opts a term into a ridge.
    const PENALIZED_BY_DEFAULT: bool = false;
    PENALIZED_BY_DEFAULT
}
1196
/// Default PCA smooth penalty value, `1.0`.
pub const fn default_pca_smooth_penalty() -> f64 {
    const DEFAULT_PENALTY: f64 = 1.0;
    DEFAULT_PENALTY
}
1200
/// Default PCA chunk size, `4096`.
pub const fn default_pca_chunk_size() -> usize {
    const DEFAULT_CHUNK: usize = 4096;
    DEFAULT_CHUNK
}
1204
1205/// Random-effects term specification.
1206///
1207/// The selected feature column is interpreted as a categorical grouping variable.
1208/// The term contributes a one-hot dummy block with an identity penalty on group
1209/// coefficients, equivalent to i.i.d. Gaussian random effects.
1210#[derive(Debug, Clone, Serialize, Deserialize)]
1211pub struct RandomEffectTermSpec {
    /// Term name.
1212    pub name: String,
    /// Index of the data column holding the grouping variable.
1213    pub feature_col: usize,
1214    /// If true, drop the lexicographically first group level to use treatment coding.
1215    /// If false, keep all levels (full one-hot block, still identifiable under ridge).
1216    pub drop_first_level: bool,
1217    /// If true, add a ridge penalty and estimate this block as a random effect.
1218    /// If false, leave the one-hot/treatment-coded block unpenalized so it is a
1219    /// fixed categorical main effect.  The default preserves older saved models.
1220    #[serde(default = "default_random_effect_penalized")]
1221    pub penalized: bool,
1222    /// Optional fixed kept-level set (sorted by f64 bit pattern) captured at fit time.
1223    /// When present, prediction uses exactly these columns to avoid design drift.
1224    #[serde(default)]
1225    pub frozen_levels: Option<Vec<u64>>,
1226}
1227
/// Serde default for `RandomEffectTermSpec::penalized`: `true`, so a saved
/// model that lacks the field keeps its ridge-penalized random-effect block.
pub fn default_random_effect_penalized() -> bool {
    const PENALIZED_BY_DEFAULT: bool = true;
    PENALIZED_BY_DEFAULT
}
1231
1232pub fn validate_measure_jet_positive_vec_len(
1233    label: &str,
1234    term_name: &str,
1235    field: &str,
1236    values: &[f64],
1237    expected: usize,
1238) -> Result<(), String> {
1239    if values.len() != expected {
1240        return Err(SmoothError::invalid_config(format!(
1241            "{label} term '{term_name}' frozen MeasureJet {field} has length {}, expected {expected}",
1242            values.len()
1243        ))
1244        .into());
1245    }
1246    if values
1247        .iter()
1248        .any(|value| !(value.is_finite() && *value > 0.0))
1249    {
1250        return Err(SmoothError::invalid_config(format!(
1251            "{label} term '{term_name}' frozen MeasureJet {field} values must be positive and finite"
1252        ))
1253        .into());
1254    }
1255    Ok(())
1256}
1257
/// Serializable collection of every term spec in a model: linear terms,
/// random-effect terms and smooth terms.
1258#[derive(Debug, Clone, Serialize, Deserialize)]
1259pub struct TermCollectionSpec {
    /// Parametric (linear) term specs.
1260    pub linear_terms: Vec<LinearTermSpec>,
    /// Random-effect term specs.
1261    pub random_effect_terms: Vec<RandomEffectTermSpec>,
    /// Smooth term specs.
1262    pub smooth_terms: Vec<SmoothTermSpec>,
1263}
1264
1265pub fn validate_smooth_basis_frozen(
1266    basis: &SmoothBasisSpec,
1267    label: &str,
1268    term_name: &str,
1269) -> Result<(), String> {
1270    match basis {
1271        SmoothBasisSpec::ByVariable { inner, .. }
1272        | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
1273            validate_smooth_basis_frozen(inner, label, term_name)
1274        }
1275        SmoothBasisSpec::BSpline1D { spec, .. } => {
1276            if !matches!(
1277                spec.knotspec,
1278                BSplineKnotSpec::Provided(_)
1279                    | BSplineKnotSpec::PeriodicUniform { .. }
1280                    | BSplineKnotSpec::NaturalCubicRegression { .. }
1281            ) {
1282                return Err(format!(
1283                    "{label} term '{term_name}' is not frozen: BSpline knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression"
1284                ));
1285            }
1286            Ok(())
1287        }
1288        SmoothBasisSpec::ThinPlate { spec, .. } => {
1289            if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1290                return Err(format!(
1291                    "{label} term '{term_name}' is not frozen: ThinPlate centers must be UserProvided"
1292                ));
1293            }
1294            if matches!(
1295                spec.identifiability,
1296                SpatialIdentifiability::OrthogonalToParametric
1297            ) {
1298                return Err(format!(
1299                    "{label} term '{term_name}' is not frozen: ThinPlate identifiability must be FrozenTransform or None"
1300                ));
1301            }
1302            Ok(())
1303        }
1304        _ => Ok(()),
1305    }
1306}
1307
1308impl TermCollectionSpec {
1309    /// Write this collection's topology identity into a warm-start cache
1310    /// fingerprint (#869).
1311    ///
1312    /// The persistent warm-start `cache_key` hashes only family + raw input
1313    /// dimensions, so two fits on the same data that differ *only* in their
1314    /// smooth topology (the `s(..., type=AUTO)` candidate enumeration: sphere
1315    /// vs torus vs euclidean vs duchon) collide on one key and seed each other
1316    /// with geometrically incompatible β/ρ. Folding the per-term structural
1317    /// kind + feature columns + linear/random-effect counts into the shape hash
1318    /// gives each candidate its own key, so the screen→full-refit reuse of one
1319    /// candidate is preserved while cross-candidate contamination is removed.
1320    /// Only the structural identity is hashed (not fitted coefficients or
1321    /// frozen knot values), so a refit of the *same* topology still hits.
1322    pub fn write_structural_shape_hash(&self, h: &mut gam_runtime::warm_start::Fingerprinter) {
1323        h.write_str("term-collection");
1324        h.write_usize(self.linear_terms.len());
1325        for linear in &self.linear_terms {
1326            h.write_str(&linear.name);
1327        }
1328        h.write_usize(self.random_effect_terms.len());
1329        h.write_usize(self.smooth_terms.len());
1330        for smooth in &self.smooth_terms {
1331            h.write_str(&smooth.name);
1332            h.write_str(smooth.basis.structural_kind());
1333            for col in smooth.basis.structural_feature_cols() {
1334                h.write_usize(col);
1335            }
1336        }
1337    }
1338
1339    /// Validate that a term collection spec represents a fully frozen model
1340    /// (i.e. all knots/centers are pre-computed, identifiability transforms are
1341    /// baked in, and random-effect levels are fixed).
1342    pub fn validate_frozen(&self, label: &str) -> Result<(), String> {
1343        for linear in &self.linear_terms {
1344            if let (Some(min), Some(max)) = (linear.coefficient_min, linear.coefficient_max)
1345                && (!min.is_finite() || !max.is_finite() || min > max)
1346            {
1347                return Err(SmoothError::invalid_config(format!(
1348                    "{label} linear term '{}' has invalid coefficient constraint [{min}, {max}]",
1349                    linear.name
1350                ))
1351                .into());
1352            }
1353            if let Some(min) = linear.coefficient_min
1354                && !min.is_finite()
1355            {
1356                return Err(SmoothError::invalid_config(format!(
1357                    "{label} linear term '{}' has non-finite coefficient minimum {min}",
1358                    linear.name
1359                ))
1360                .into());
1361            }
1362            if let Some(max) = linear.coefficient_max
1363                && !max.is_finite()
1364            {
1365                return Err(SmoothError::invalid_config(format!(
1366                    "{label} linear term '{}' has non-finite coefficient maximum {max}",
1367                    linear.name
1368                ))
1369                .into());
1370            }
1371            if let LinearCoefficientGeometry::Bounded { min, max, prior } =
1372                &linear.coefficient_geometry
1373            {
1374                if !min.is_finite() || !max.is_finite() || min >= max {
1375                    return Err(SmoothError::invalid_config(format!(
1376                        "{label} bounded term '{}' has invalid bounds [{min}, {max}]",
1377                        linear.name
1378                    ))
1379                    .into());
1380                }
1381                match prior {
1382                    BoundedCoefficientPriorSpec::None | BoundedCoefficientPriorSpec::Uniform => {}
1383                    BoundedCoefficientPriorSpec::Beta { a, b } => {
1384                        if !a.is_finite() || !b.is_finite() || *a < 1.0 || *b < 1.0 {
1385                            return Err(SmoothError::invalid_config(format!(
1386                                "{label} bounded term '{}' has invalid Beta prior ({a}, {b})",
1387                                linear.name
1388                            ))
1389                            .into());
1390                        }
1391                    }
1392                }
1393            }
1394        }
1395        for st in &self.smooth_terms {
1396            match &st.basis {
1397                SmoothBasisSpec::ByVariable { inner, .. } => {
1398                    validate_smooth_basis_frozen(inner, label, &st.name)?;
1399                    let nested = SmoothTermSpec {
1400                        name: st.name.clone(),
1401                        basis: (**inner).clone(),
1402                        shape: st.shape,
1403                        joint_null_rotation: None,
1404                    };
1405                    TermCollectionSpec {
1406                        linear_terms: Vec::new(),
1407                        random_effect_terms: Vec::new(),
1408                        smooth_terms: vec![nested],
1409                    }
1410                    .validate_frozen(label)?;
1411                }
1412                SmoothBasisSpec::FactorSumToZero { inner, levels, .. } => {
1413                    if levels.len() < 2 {
1414                        return Err(format!(
1415                            "{label} term '{}' has invalid frozen sz levels",
1416                            st.name
1417                        ));
1418                    }
1419                    validate_smooth_basis_frozen(inner, label, &st.name)?;
1420                }
1421                SmoothBasisSpec::BSpline1D { spec, .. } => {
1422                    if !matches!(
1423                        spec.knotspec,
1424                        BSplineKnotSpec::Provided(_)
1425                            | BSplineKnotSpec::PeriodicUniform { .. }
1426                            | BSplineKnotSpec::NaturalCubicRegression { .. }
1427                    ) {
1428                        return Err(SmoothError::invalid_config(format!(
1429                            "{label} term '{}' is not frozen: BSpline knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression",
1430                            st.name
1431                        ))
1432                        .into());
1433                    }
1434                }
1435                SmoothBasisSpec::ThinPlate { spec, .. } => {
1436                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1437                        return Err(SmoothError::invalid_config(format!(
1438                            "{label} term '{}' is not frozen: ThinPlate centers must be UserProvided",
1439                            st.name
1440                        ))
1441                        .into());
1442                    }
1443                    if matches!(
1444                        spec.identifiability,
1445                        SpatialIdentifiability::OrthogonalToParametric
1446                    ) {
1447                        return Err(SmoothError::invalid_config(format!(
1448                            "{label} term '{}' is not frozen: ThinPlate identifiability must be FrozenTransform or None",
1449                            st.name
1450                        ))
1451                        .into());
1452                    }
1453                }
1454                SmoothBasisSpec::Sphere { spec, .. } => {
1455                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1456                        return Err(SmoothError::invalid_config(format!(
1457                            "{label} term '{}' is not frozen: Sphere centers must be UserProvided",
1458                            st.name
1459                        ))
1460                        .into());
1461                    }
1462                    if matches!(spec.method, crate::basis::SphereMethod::Harmonic)
1463                        && spec.max_degree.is_none_or(|d| d == 0)
1464                    {
1465                        return Err(format!(
1466                            "{label} term '{}' is not frozen: sphere max_degree must be positive",
1467                            st.name
1468                        ));
1469                    }
1470                }
1471                SmoothBasisSpec::ConstantCurvature { spec, .. } => {
1472                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1473                        return Err(SmoothError::invalid_config(format!(
1474                            "{label} term '{}' is not frozen: ConstantCurvature centers must be UserProvided",
1475                            st.name
1476                        ))
1477                        .into());
1478                    }
1479                    if !(spec.length_scale.is_finite() && spec.length_scale > 0.0) {
1480                        return Err(SmoothError::invalid_config(format!(
1481                            "{label} term '{}' is not frozen: ConstantCurvature length_scale must be the realized positive value",
1482                            st.name
1483                        ))
1484                        .into());
1485                    }
1486                }
1487                SmoothBasisSpec::MeasureJet { spec, .. } => {
1488                    let centers = match &spec.center_strategy {
1489                        CenterStrategy::UserProvided(centers) => centers,
1490                        _ => {
1491                            return Err(SmoothError::invalid_config(format!(
1492                                "{label} term '{}' is not frozen: MeasureJet centers must be UserProvided",
1493                                st.name
1494                            ))
1495                            .into());
1496                        }
1497                    };
1498                    if centers.nrows() == 0 {
1499                        return Err(SmoothError::invalid_config(format!(
1500                            "{label} term '{}' is not frozen: MeasureJet centers are empty",
1501                            st.name
1502                        ))
1503                        .into());
1504                    }
1505                    if !(spec.length_scale.is_finite() && spec.length_scale > 0.0) {
1506                        return Err(SmoothError::invalid_config(format!(
1507                            "{label} term '{}' is not frozen: MeasureJet length_scale must be the realized positive value",
1508                            st.name
1509                        ))
1510                        .into());
1511                    }
1512                    // Exact replay needs the fit-data penalty quadrature and
1513                    // normalization payload (`BasisMetadata::MeasureJet`).
1514                    let frozen = spec.frozen_quadrature.as_ref().ok_or_else(|| {
1515                        SmoothError::invalid_config(format!(
1516                            "{label} term '{}' is not frozen: MeasureJet frozen_quadrature payload is missing",
1517                            st.name
1518                        ))
1519                    })?;
1520                    if frozen.masses.len() != centers.nrows() {
1521                        return Err(SmoothError::invalid_config(format!(
1522                            "{label} term '{}' frozen MeasureJet has {} masses for {} centers",
1523                            st.name,
1524                            frozen.masses.len(),
1525                            centers.nrows()
1526                        ))
1527                        .into());
1528                    }
1529                    let total_mass = frozen.masses.sum();
1530                    if frozen
1531                        .masses
1532                        .iter()
1533                        .any(|mass| !(mass.is_finite() && *mass >= 0.0))
1534                        || !(total_mass.is_finite() && total_mass > 0.0)
1535                    {
1536                        return Err(SmoothError::invalid_config(format!(
1537                            "{label} term '{}' frozen MeasureJet masses must be finite, nonnegative, and have positive total mass",
1538                            st.name
1539                        ))
1540                        .into());
1541                    }
1542                    let n_levels = frozen.eps_band.len();
1543                    if n_levels == 0
1544                        || frozen
1545                            .eps_band
1546                            .iter()
1547                            .any(|eps| !(eps.is_finite() && *eps > 0.0))
1548                    {
1549                        return Err(SmoothError::invalid_config(format!(
1550                            "{label} term '{}' frozen MeasureJet eps_band must be nonempty, finite, and positive",
1551                            st.name
1552                        ))
1553                        .into());
1554                    }
1555                    for (idx, pair) in frozen.eps_band.windows(2).enumerate() {
1556                        if pair[1] <= pair[0] {
1557                            return Err(SmoothError::invalid_config(format!(
1558                                "{label} term '{}' frozen MeasureJet eps_band is not strictly ascending at {idx}: {} then {}",
1559                                st.name,
1560                                pair[0],
1561                                pair[1]
1562                            ))
1563                            .into());
1564                        }
1565                    }
1566                    validate_measure_jet_positive_vec_len(
1567                        label,
1568                        &st.name,
1569                        "support_means",
1570                        &frozen.support_means,
1571                        n_levels,
1572                    )?;
1573                    // Mode predicate MUST match the builder's
1574                    // (`measure_jet_multiscale_mode`): per-level/multiscale is the
1575                    // explicit `spec.multiscale` opt-in (#1116). In single-scale
1576                    // mode the builder emits a single FUSED penalty (empty
1577                    // per-level scales + `fused_penalty_normalization_scale:
1578                    // Some`); only the multiscale opt-in carries `n_levels`
1579                    // per-level scales.
1580                    let per_level = crate::basis::measure_jet_multiscale_mode(spec);
1581                    if per_level {
1582                        validate_measure_jet_positive_vec_len(
1583                            label,
1584                            &st.name,
1585                            "penalty_normalization_scales",
1586                            &frozen.penalty_normalization_scales,
1587                            n_levels,
1588                        )?;
1589                        validate_measure_jet_positive_vec_len(
1590                            label,
1591                            &st.name,
1592                            "raw_penalty_normalization_scales",
1593                            &frozen.raw_penalty_normalization_scales,
1594                            n_levels,
1595                        )?;
1596                        if frozen.fused_penalty_normalization_scale.is_some() {
1597                            return Err(SmoothError::invalid_config(format!(
1598                                "{label} term '{}' per-level MeasureJet must not carry a fused penalty normalization scale",
1599                                st.name
1600                            ))
1601                            .into());
1602                        }
1603                    } else {
1604                        if !frozen.penalty_normalization_scales.is_empty()
1605                            || !frozen.raw_penalty_normalization_scales.is_empty()
1606                        {
1607                            return Err(SmoothError::invalid_config(format!(
1608                                "{label} term '{}' fused MeasureJet must not carry per-level penalty normalization scales",
1609                                st.name
1610                            ))
1611                            .into());
1612                        }
1613                        match frozen.fused_penalty_normalization_scale {
1614                            Some(scale) if scale.is_finite() && scale > 0.0 => {}
1615                            Some(scale) => {
1616                                return Err(SmoothError::invalid_config(format!(
1617                                    "{label} term '{}' fused MeasureJet penalty normalization scale must be positive and finite, got {scale}",
1618                                    st.name
1619                                ))
1620                                .into());
1621                            }
1622                            None => {
1623                                return Err(SmoothError::invalid_config(format!(
1624                                    "{label} term '{}' fused MeasureJet is missing its penalty normalization scale",
1625                                    st.name
1626                                ))
1627                                .into());
1628                            }
1629                        }
1630                    }
1631                }
1632                SmoothBasisSpec::Matern { spec, .. } => {
1633                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1634                        return Err(SmoothError::invalid_config(format!(
1635                            "{label} term '{}' is not frozen: Matern centers must be UserProvided",
1636                            st.name
1637                        ))
1638                        .into());
1639                    }
1640                }
1641                SmoothBasisSpec::Duchon { spec, .. } => {
1642                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1643                        return Err(SmoothError::invalid_config(format!(
1644                            "{label} term '{}' is not frozen: Duchon centers must be UserProvided",
1645                            st.name
1646                        ))
1647                        .into());
1648                    }
1649                    if matches!(
1650                        spec.identifiability,
1651                        SpatialIdentifiability::OrthogonalToParametric
1652                    ) {
1653                        return Err(SmoothError::invalid_config(format!(
1654                            "{label} term '{}' is not frozen: Duchon identifiability must be FrozenTransform or None",
1655                            st.name
1656                        ))
1657                        .into());
1658                    }
1659                }
1660                SmoothBasisSpec::Pca {
1661                    centered,
1662                    center_mean,
1663                    pca_basis_path,
1664                    ..
1665                } => {
1666                    if *centered && center_mean.is_none() && pca_basis_path.is_none() {
1667                        return Err(SmoothError::invalid_config(format!(
1668                            "{label} term '{}' is not frozen: centered Pca missing center_mean",
1669                            st.name
1670                        ))
1671                        .into());
1672                    }
1673                }
1674                SmoothBasisSpec::BySmooth { smooth, by_kind } => {
1675                    if let SmoothBasisSpec::BySmooth { .. } = smooth.as_ref() {
1676                        return Err(format!("{label} term '{}' has nested by-smooths", st.name));
1677                    }
1678                    match by_kind {
1679                        ByVarKind::Numeric { .. } => {}
1680                        ByVarKind::Factor { frozen_levels, .. } if frozen_levels.is_none() => {
1681                            return Err(format!(
1682                                "{label} term '{}' is not frozen: by-factor levels missing",
1683                                st.name
1684                            ));
1685                        }
1686                        ByVarKind::Factor { .. } => {}
1687                    }
1688                    let nested = TermCollectionSpec {
1689                        linear_terms: vec![],
1690                        random_effect_terms: vec![],
1691                        smooth_terms: vec![SmoothTermSpec {
1692                            name: st.name.clone(),
1693                            basis: (**smooth).clone(),
1694                            shape: st.shape,
1695                            joint_null_rotation: None,
1696                        }],
1697                    };
1698                    nested.validate_frozen(label)?;
1699                }
1700                SmoothBasisSpec::FactorSmooth { spec } => {
1701                    if spec.group_frozen_levels.is_none() {
1702                        return Err(format!(
1703                            "{label} term '{}' is not frozen: factor-smooth levels missing",
1704                            st.name
1705                        ));
1706                    }
1707                    if !matches!(
1708                        spec.marginal.knotspec,
1709                        BSplineKnotSpec::Provided(_)
1710                            | BSplineKnotSpec::PeriodicUniform { .. }
1711                            // mgcv's `bs="sz"` default marginal is a cubic
1712                            // regression spline (#1074), and the freeze step
1713                            // restores it as a `NaturalCubicRegression` knotspec
1714                            // carrying its `k` value-knots (spatial_optimization.rs
1715                            // `marginal_is_cr` branch) — the SAME treatment the
1716                            // tensor margin already gets in the arm below. Without
1717                            // this variant a frozen `sz` factor smooth fails its own
1718                            // predict-time freeze check ("factor-smooth marginal
1719                            // knots missing") even though its knots are fully
1720                            // materialized; the validation simply was not updated
1721                            // when the cr marginal landed.
1722                            | BSplineKnotSpec::NaturalCubicRegression { .. }
1723                    ) {
1724                        return Err(format!(
1725                            "{label} term '{}' is not frozen: factor-smooth marginal knots missing",
1726                            st.name
1727                        ));
1728                    }
1729                }
1730                SmoothBasisSpec::TensorBSpline { spec, .. } => {
1731                    for (dim, marginal) in spec.marginalspecs.iter().enumerate() {
1732                        if !matches!(
1733                            marginal.knotspec,
1734                            BSplineKnotSpec::Provided(_)
1735                                | BSplineKnotSpec::PeriodicUniform { .. }
1736                                | BSplineKnotSpec::NaturalCubicRegression { .. }
1737                        ) {
1738                            return Err(SmoothError::invalid_config(format!(
1739                                "{label} term '{}' dim {} is not frozen: tensor marginal knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression",
1740                                st.name, dim
1741                            ))
1742                            .into());
1743                        }
1744                    }
1745                    if matches!(
1746                        spec.identifiability,
1747                        TensorBSplineIdentifiability::SumToZero
1748                            | TensorBSplineIdentifiability::MarginalSumToZero
1749                    ) {
1750                        return Err(SmoothError::invalid_config(format!(
1751                            "{label} term '{}' is not frozen: tensor identifiability must be FrozenTransform or None",
1752                            st.name
1753                        ))
1754                        .into());
1755                    }
1756                }
1757            }
1758        }
1759
1760        for rt in &self.random_effect_terms {
1761            if rt.frozen_levels.is_none() {
1762                return Err(SmoothError::invalid_config(format!(
1763                    "{label} random-effect term '{}' is not frozen: missing frozen_levels",
1764                    rt.name
1765                ))
1766                .into());
1767            }
1768        }
1769
1770        Ok(())
1771    }
1772
1773    /// Re-resolve every stored feature-column index through `remap`, returning a
1774    /// spec that addresses a different column layout.
1775    ///
1776    /// A frozen `TermCollectionSpec` stores feature columns as *absolute indices
1777    /// into the training table*. To replay it on a fresh dataset whose columns
1778    /// sit at different positions — the common case at prediction time, where
1779    /// the response column is unknown and may be absent entirely — every index
1780    /// must be re-resolved against the new layout. `remap` receives each
1781    /// training-table index and returns its position in the runtime table;
1782    /// callers typically implement it as "look the name up in the training
1783    /// headers, then resolve that name against the prediction dataset".
1784    ///
1785    /// This is the single authority on *which* fields carry a column index
1786    /// across every basis variant (linear, random-effect, the `by=` column of
1787    /// `ByVariable`/`FactorSumToZero`/`BySmooth`, the continuous and group
1788    /// columns of a `FactorSmooth`, and the multi-axis `feature_cols` of every
1789    /// spatial/tensor basis), so a predict-time realignment cannot silently miss
1790    /// one and dereference a stale training index.
1791    pub fn remap_feature_columns<E, F>(&self, mut remap: F) -> Result<TermCollectionSpec, E>
1792    where
1793        F: FnMut(usize) -> Result<usize, E>,
1794    {
1795        let mut out = self.clone();
1796        for lt in &mut out.linear_terms {
1797            lt.feature_col = remap(lt.feature_col)?;
1798            // Also remap the full interaction-factor list. The design builder
1799            // (`build_term_collection_design_inner`) materializes the column from
1800            // `effective_feature_cols()` — which returns `feature_cols` whenever
1801            // it is non-empty (i.e. essentially always, including a plain linear
1802            // term where `feature_cols == [feature_col]`). Remapping only the
1803            // singular `feature_col` left these at their saved *training* indices
1804            // at predict time, so a parametric `Surv(...) ~ x` (and any `:`
1805            // interaction) bailed with "feature column N out of bounds" once the
1806            // response/time columns shift the runtime layout (issue #898).
1807            for fc in lt.feature_cols.iter_mut() {
1808                *fc = remap(*fc)?;
1809            }
1810            // A factor-aware `:` interaction also gates on categorical columns;
1811            // those indices live in the same training-time layout and must be
1812            // realigned to the runtime table alongside the numeric operands, or
1813            // the predict-time level indicator would dereference a stale column.
1814            for (col, _bits) in lt.categorical_levels.iter_mut() {
1815                *col = remap(*col)?;
1816            }
1817        }
1818        for rt in &mut out.random_effect_terms {
1819            rt.feature_col = remap(rt.feature_col)?;
1820        }
1821        for st in &mut out.smooth_terms {
1822            remap_smooth_basis_feature_columns(&mut st.basis, &mut remap)?;
1823        }
1824        Ok(out)
1825    }
1826}
1827
1828/// Walk a `SmoothBasisSpec` tree, re-resolving every column index through
1829/// `remap`. Shared by all predict-time column realignment (see
1830/// [`TermCollectionSpec::remap_feature_columns`]); kept exhaustive so a newly
1831/// added index-bearing variant fails to compile until it is handled here.
1832pub fn remap_smooth_basis_feature_columns<E, F>(
1833    basis: &mut SmoothBasisSpec,
1834    remap: &mut F,
1835) -> Result<(), E>
1836where
1837    F: FnMut(usize) -> Result<usize, E>,
1838{
1839    match basis {
1840        SmoothBasisSpec::ByVariable { inner, by_col, .. }
1841        | SmoothBasisSpec::FactorSumToZero { inner, by_col, .. } => {
1842            *by_col = remap(*by_col)?;
1843            remap_smooth_basis_feature_columns(inner, remap)?;
1844        }
1845        SmoothBasisSpec::BSpline1D { feature_col, .. } => {
1846            *feature_col = remap(*feature_col)?;
1847        }
1848        SmoothBasisSpec::BySmooth { smooth, by_kind } => {
1849            let by_feature_col = match by_kind {
1850                ByVarKind::Numeric { feature_col } | ByVarKind::Factor { feature_col, .. } => {
1851                    feature_col
1852                }
1853            };
1854            *by_feature_col = remap(*by_feature_col)?;
1855            remap_smooth_basis_feature_columns(smooth, remap)?;
1856        }
1857        SmoothBasisSpec::FactorSmooth { spec } => {
1858            for fc in spec.continuous_cols.iter_mut() {
1859                *fc = remap(*fc)?;
1860            }
1861            spec.group_col = remap(spec.group_col)?;
1862        }
1863        SmoothBasisSpec::ThinPlate { feature_cols, .. }
1864        | SmoothBasisSpec::Sphere { feature_cols, .. }
1865        | SmoothBasisSpec::ConstantCurvature { feature_cols, .. }
1866        | SmoothBasisSpec::Matern { feature_cols, .. }
1867        | SmoothBasisSpec::MeasureJet { feature_cols, .. }
1868        | SmoothBasisSpec::Duchon { feature_cols, .. }
1869        | SmoothBasisSpec::Pca { feature_cols, .. }
1870        | SmoothBasisSpec::TensorBSpline { feature_cols, .. } => {
1871            for fc in feature_cols.iter_mut() {
1872                *fc = remap(*fc)?;
1873            }
1874        }
1875    }
1876    Ok(())
1877}
1878
/// Structural description of a penalty block, stored in
/// `BlockwisePenalty::structure_hint` alongside the dense local matrix.
#[derive(Debug, Clone)]
pub enum PenaltyStructureHint {
    /// Set by `BlockwisePenalty::ridge`: the local matrix has this scale on
    /// its diagonal and zeros elsewhere.
    Ridge(f64),
    /// Set by `BlockwisePenalty::kronecker`: the factor matrices supplied by
    /// the caller together with the local matrix.
    /// NOTE(review): presumably the local matrix is the Kronecker product of
    /// these factors — nothing in the constructor checks this; confirm.
    Kronecker(Vec<Array2<f64>>),
}
1884
/// A penalty matrix stored at its natural block size together with the
/// column range it occupies in the global coefficient vector.
///
/// Instead of embedding every penalty into a full `p_total × p_total` dense
/// matrix filled with zeros, we keep the compact local matrix and reconstruct
/// the global view only when a downstream consumer explicitly requires it.
///
/// `Debug` is implemented by hand rather than derived: it prints `local` as
/// its `rows×cols` dimensions and `op` as the result of its `dim()` call
/// instead of dumping their contents.
#[derive(Clone)]
pub struct BlockwisePenalty {
    /// Column range in the global coefficient vector that this penalty covers.
    pub col_range: Range<usize>,
    /// The local penalty matrix — dimensions `block_p × block_p` where
    /// `block_p = col_range.len()`.
    pub local: Array2<f64>,
    /// Optional nonzero centering vector for this coefficient block.
    /// The `new`, `ridge`, and `kronecker` constructors all initialize it to
    /// `CoefficientPriorMean::Zero`; `with_prior_mean` replaces it.
    pub prior_mean: gam_problem::CoefficientPriorMean,
    /// Optional structural hint so downstream spectral/logdet code can stay
    /// block-local or factorized without reverse-engineering the matrix.
    pub structure_hint: Option<PenaltyStructureHint>,
    /// Optional operator-form handle bit-equivalent to `local`. Populated when
    /// the originating closed-form factory emitted an op-form penalty so exact
    /// operator algebra can use matvec instead of materializing the dense
    /// `block_p × block_p` Gram. `None` for ordinary dense penalties.
    pub op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
}
1909
1910impl std::fmt::Debug for BlockwisePenalty {
1911    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1912        f.debug_struct("BlockwisePenalty")
1913            .field("col_range", &self.col_range)
1914            .field(
1915                "local",
1916                &format_args!("{}×{}", self.local.nrows(), self.local.ncols()),
1917            )
1918            .field("prior_mean", &self.prior_mean)
1919            .field("structure_hint", &self.structure_hint)
1920            .field("op", &self.op.as_ref().map(|o| o.dim()))
1921            .finish()
1922    }
1923}
1924
1925impl BlockwisePenalty {
1926    /// Create a new blockwise penalty.
1927    pub fn new(col_range: Range<usize>, local: Array2<f64>) -> Self {
1928        assert_eq!(col_range.len(), local.nrows());
1929        assert_eq!(col_range.len(), local.ncols());
1930        Self {
1931            col_range,
1932            local,
1933            prior_mean: gam_problem::CoefficientPriorMean::Zero,
1934            structure_hint: None,
1935            op: None,
1936        }
1937    }
1938
1939    pub fn with_prior_mean(
1940        mut self,
1941        prior_mean: gam_problem::CoefficientPriorMean,
1942    ) -> Self {
1943        self.prior_mean = prior_mean;
1944        self
1945    }
1946
1947    /// Attach an op-form penalty handle bit-equivalent to `local`.
1948    pub fn with_op(
1949        mut self,
1950        op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
1951    ) -> Self {
1952        self.op = op;
1953        self
1954    }
1955
1956    pub fn ridge(col_range: Range<usize>, scale: f64) -> Self {
1957        let block_size = col_range.len();
1958        let mut local = Array2::<f64>::zeros((block_size, block_size));
1959        for i in 0..block_size {
1960            local[[i, i]] = scale;
1961        }
1962        Self {
1963            col_range,
1964            local,
1965            prior_mean: gam_problem::CoefficientPriorMean::Zero,
1966            structure_hint: Some(PenaltyStructureHint::Ridge(scale)),
1967            op: None,
1968        }
1969    }
1970
1971    pub fn kronecker(
1972        col_range: Range<usize>,
1973        local: Array2<f64>,
1974        factors: Vec<Array2<f64>>,
1975    ) -> Self {
1976        assert_eq!(col_range.len(), local.nrows());
1977        assert_eq!(col_range.len(), local.ncols());
1978        Self {
1979            col_range,
1980            local,
1981            prior_mean: gam_problem::CoefficientPriorMean::Zero,
1982            structure_hint: Some(PenaltyStructureHint::Kronecker(factors)),
1983            op: None,
1984        }
1985    }
1986
1987    /// Expand this blockwise penalty into a full `p_total × p_total` dense
1988    /// matrix (mostly zeros). Use sparingly — the whole point of blockwise
1989    /// storage is to avoid this allocation.
1990    pub fn to_global(&self, p_total: usize) -> Array2<f64> {
1991        let mut g = Array2::<f64>::zeros((p_total, p_total));
1992        let r = &self.col_range;
1993        assert!(
1994            r.end <= p_total && self.local.nrows() == r.len() && self.local.ncols() == r.len(),
1995            "BlockwisePenalty::to_global shape invariant violated: \
1996             col_range={}..{}, local={}x{}, p_total={}",
1997            r.start,
1998            r.end,
1999            self.local.nrows(),
2000            self.local.ncols(),
2001            p_total,
2002        );
2003        g.slice_mut(s![r.start..r.end, r.start..r.end])
2004            .assign(&self.local);
2005        g
2006    }
2007
2008    /// Convert into a blockwise [`gam_problem::PenaltyMatrix`] without
2009    /// expanding to full dimensions.
2010    pub fn to_penalty_matrix(
2011        &self,
2012        total_dim: usize,
2013    ) -> gam_problem::PenaltyMatrix {
2014        gam_problem::PenaltyMatrix::Blockwise {
2015            local: self.local.clone(),
2016            col_range: self.col_range.clone(),
2017            total_dim,
2018        }
2019    }
2020
2021    /// The block size of this penalty.
2022    #[inline]
2023    pub fn block_size(&self) -> usize {
2024        self.col_range.len()
2025    }
2026}
2027
2028/// Compute `Σ_k λ_k S_k` directly from blockwise penalties, accumulating
2029/// into a pre-allocated `p_total × p_total` output without ever materializing
2030/// individual global matrices.
2031pub fn weighted_blockwise_penalty_sum(
2032    penalties: &[BlockwisePenalty],
2033    lambdas: &[f64],
2034    p_total: usize,
2035) -> Array2<f64> {
2036    assert_eq!(penalties.len(), lambdas.len());
2037    // Smoothing parameters λ_k must be non-negative and finite. A negative
2038    // λ would flip the sign of the corresponding block S_k, turning the
2039    // total penalty matrix indefinite and silently corrupting every
2040    // downstream Cholesky / PIRLS / REML / pseudo-logdet computation that
2041    // assumes S_λ ⪰ 0. Catch this at the boundary rather than after it
2042    // has propagated.
2043    for (idx, &lam) in lambdas.iter().enumerate() {
2044        assert!(
2045            lam.is_finite() && lam >= 0.0,
2046            "weighted_blockwise_penalty_sum: lambdas[{idx}] = {lam} is invalid (must be finite and non-negative; negative smoothing parameters violate S_λ ⪰ 0)",
2047        );
2048    }
2049    // Block column ranges must also fit inside the declared total parameter
2050    // dimension; an out-of-bounds slice would otherwise panic from ndarray
2051    // with a far less informative message.
2052    for (idx, bp) in penalties.iter().enumerate() {
2053        let r = &bp.col_range;
2054        assert!(
2055            r.end <= p_total,
2056            "weighted_blockwise_penalty_sum: penalties[{idx}] col_range {:?} exceeds p_total = {p_total}",
2057            r,
2058        );
2059    }
2060    let mut out = Array2::<f64>::zeros((p_total, p_total));
2061    for (bp, &lam) in penalties.iter().zip(lambdas.iter()) {
2062        let r = &bp.col_range;
2063        let mut slice = out.slice_mut(s![r.start..r.end, r.start..r.end]);
2064        slice.scaled_add(lam, &bp.local);
2065    }
2066    out
2067}
2068
2069// ---------------------------------------------------------------------------
2070// KroneckerPenaltySystem — factored tensor-product penalty representation
2071// ---------------------------------------------------------------------------
2072
/// Factored representation of tensor-product penalties with precomputed
/// marginal eigensystems for O(∏q_j) logdet and penalty operations.
///
/// Instances are normally built through `KroneckerPenaltySystem::new`, which
/// computes `marginal_eigensystems` from `marginal_penalties`. All fields are
/// public, so code that fills them by hand must keep the three per-margin
/// collections the same length and in the same margin order.
#[derive(Debug, Clone)]
pub struct KroneckerPenaltySystem {
    /// Marginal penalty matrices: `marginal_penalties[k]` is `(q_k, q_k)`.
    pub marginal_penalties: Vec<Array2<f64>>,
    /// Precomputed eigensystems: `(eigenvalues, eigenvectors)` per marginal.
    pub marginal_eigensystems: Vec<(Array1<f64>, Array2<f64>)>,
    /// Marginal basis dimensions.
    pub marginal_dims: Vec<usize>,
    /// Whether a global ridge (double) penalty is present. When set, the
    /// system carries one extra smoothing parameter after the per-margin ones
    /// (see `num_penalties`).
    pub has_double_penalty: bool,
}
2086
2087impl KroneckerPenaltySystem {
2088    pub fn new(
2089        marginal_penalties: Vec<Array2<f64>>,
2090        marginal_dims: Vec<usize>,
2091        has_double_penalty: bool,
2092    ) -> Result<Self, BasisError> {
2093        if marginal_penalties.len() != marginal_dims.len() {
2094            crate::bail_dim_basis!(
2095                "KroneckerPenaltySystem: {} penalties vs {} dims",
2096                marginal_penalties.len(),
2097                marginal_dims.len()
2098            );
2099        }
2100        let eigensystems =
2101            kronecker_marginal_eigensystems(&marginal_penalties, "KroneckerPenaltySystem")
2102                .map_err(|e| BasisError::InvalidInput(e.to_string()))?;
2103        Ok(Self {
2104            marginal_penalties,
2105            marginal_eigensystems: eigensystems,
2106            marginal_dims,
2107            has_double_penalty,
2108        })
2109    }
2110
2111    pub fn p_total(&self) -> usize {
2112        self.marginal_dims.iter().copied().product()
2113    }
2114
2115    pub fn ndim(&self) -> usize {
2116        self.marginal_dims.len()
2117    }
2118
2119    pub fn num_penalties(&self) -> usize {
2120        self.marginal_dims.len() + if self.has_double_penalty { 1 } else { 0 }
2121    }
2122
2123    /// Compute `log|S|₊` and its first/second derivatives w.r.t. `ρ_k = log(λ_k)`.
2124    ///
2125    /// Iterates over the ∏q_j multi-index grid. Cost: O(d · ∏q_j), no O(p²) storage.
2126    pub fn logdet_and_derivatives(
2127        &self,
2128        lambdas: &[f64],
2129        ridge: f64,
2130    ) -> (f64, Array1<f64>, Array2<f64>) {
2131        let n_pen = self.num_penalties();
2132        assert_eq!(lambdas.len(), n_pen, "lambda count mismatch");
2133        let marginal_evals: Vec<_> = self
2134            .marginal_eigensystems
2135            .iter()
2136            .map(|(evals, _)| evals.view())
2137            .collect();
2138        kronecker_logdet_and_derivatives(
2139            &marginal_evals,
2140            &self.marginal_dims,
2141            lambdas,
2142            self.has_double_penalty,
2143            ridge,
2144        )
2145    }
2146
2147    pub fn logdet_rank_and_derivatives(
2148        &self,
2149        lambdas: &[f64],
2150        ridge: f64,
2151    ) -> (f64, usize, Array1<f64>, Array2<f64>) {
2152        let n_pen = self.num_penalties();
2153        assert_eq!(lambdas.len(), n_pen, "lambda count mismatch");
2154        let d = self.marginal_dims.len();
2155        let mut logdet = 0.0;
2156        let mut rank = 0usize;
2157        let mut grad = Array1::<f64>::zeros(n_pen);
2158        let mut hess = Array2::<f64>::zeros((n_pen, n_pen));
2159        // Positivity floor for a penalized eigenvalue `σ`: below this the mode
2160        // is treated as an unpenalized (null-space) direction and excluded from
2161        // both the rank count and the pseudo-log-determinant.
2162        const EIGENVALUE_POSITIVITY_FLOOR: f64 = 1e-12;
2163        // Floor on the *structural* eigenvalue sum (λ-independent) used to decide
2164        // whether a mode lives in the penalty range space and so should receive
2165        // the stabilizing ridge; a structurally-null mode gets no ridge.
2166        const STRUCTURAL_ZERO_FLOOR: f64 = 1e-12;
2167        let mut multi_idx = vec![0usize; d];
2168        loop {
2169            let mut sigma = 0.0;
2170            let mut structural_sigma = 0.0;
2171            for k in 0..d {
2172                let marginal_eigenvalue = self.marginal_eigensystems[k].0[multi_idx[k]];
2173                structural_sigma += marginal_eigenvalue;
2174                sigma += lambdas[k] * marginal_eigenvalue;
2175            }
2176            let joint_null = structural_sigma <= STRUCTURAL_ZERO_FLOOR;
2177            if self.has_double_penalty && joint_null {
2178                sigma += lambdas[d];
2179            }
2180            if structural_sigma > STRUCTURAL_ZERO_FLOOR {
2181                sigma += ridge;
2182            }
2183
2184            if sigma > EIGENVALUE_POSITIVITY_FLOOR {
2185                rank += 1;
2186                logdet += sigma.ln();
2187                let inv_sigma = 1.0 / sigma;
2188                let inv_sigma2 = inv_sigma * inv_sigma;
2189                for k in 0..n_pen {
2190                    let ck = if k < d {
2191                        lambdas[k] * self.marginal_eigensystems[k].0[multi_idx[k]]
2192                    } else if joint_null {
2193                        lambdas[d]
2194                    } else {
2195                        0.0
2196                    };
2197                    grad[k] += ck * inv_sigma;
2198                    hess[[k, k]] += ck * inv_sigma - ck * ck * inv_sigma2;
2199                    for l in (k + 1)..n_pen {
2200                        let cl = if l < d {
2201                            lambdas[l] * self.marginal_eigensystems[l].0[multi_idx[l]]
2202                        } else if joint_null {
2203                            lambdas[d]
2204                        } else {
2205                            0.0
2206                        };
2207                        let off = -ck * cl * inv_sigma2;
2208                        hess[[k, l]] += off;
2209                        hess[[l, k]] += off;
2210                    }
2211                }
2212            }
2213
2214            let mut carry = true;
2215            for dim in (0..d).rev() {
2216                if carry {
2217                    multi_idx[dim] += 1;
2218                    if multi_idx[dim] < self.marginal_dims[dim] {
2219                        carry = false;
2220                    } else {
2221                        multi_idx[dim] = 0;
2222                    }
2223                }
2224            }
2225            if carry {
2226                break;
2227            }
2228        }
2229        (logdet, rank, grad, hess)
2230    }
2231}
2232
// Regression tests for `joint_unpenalized_dim`: the joint unpenalized
// dimension is the dimension of the INTERSECTION of the per-penalty null
// spaces, not the sum of the per-penalty null-space dimensions.
#[cfg(test)]
mod joint_unpenalized_dim_tests {
    use super::joint_unpenalized_dim;
    use ndarray::{Array2, array};

    #[test]
    fn no_penalty_is_fully_unpenalized() {
        // With no penalties at all, every one of the 4 coordinates is
        // unpenalized.
        assert_eq!(joint_unpenalized_dim(4, &[], &[]), 4);
    }

    #[test]
    fn single_penalty_returns_its_own_null_space() {
        // A 3×3 penalty that penalizes only the last coordinate ⇒ 2-dim null
        // space (the first two coordinates are unpenalized).
        let s = array![[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 5.0]];
        assert_eq!(joint_unpenalized_dim(3, std::slice::from_ref(&s), &[2]), 2);
    }

    #[test]
    fn complementary_double_penalty_has_empty_joint_null_space() {
        // The #1360 case in miniature: a "bending" penalty that leaves the
        // first coordinate null (a 1-dim null space in this toy example),
        // plus a complementary "null-space ridge" that penalizes exactly that
        // coordinate. Per-penalty null dims are {1, 2} and sum to 3 (≈ p),
        // but the INTERSECTION is empty: every coordinate is penalized by
        // someone, so the joint unpenalized dim is 0.
        let bending = array![[0.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]];
        let ridge = array![[2.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]];
        assert_eq!(joint_unpenalized_dim(3, &[bending, ridge], &[1, 2]), 0);
    }

    #[test]
    fn partial_overlap_keeps_shared_null_direction() {
        // Two penalties that BOTH leave coordinate 0 unpenalized ⇒ the shared
        // null direction survives the intersection (joint unpenalized dim 1),
        // even though naively summing the per-penalty dims would give 4.
        let a = array![[0.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 0.0]];
        let b = array![[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]];
        assert_eq!(joint_unpenalized_dim(3, &[a, b], &[2, 2]), 1);
    }

    #[test]
    fn non_materialized_penalty_falls_back_conservatively() {
        // A penalty whose stored block is not p_local × p_local (e.g. a
        // Kronecker tensor factor). With ≥2 penalties the conservative joint
        // dim is 0 (never over-rejecting).
        let full: Array2<f64> = array![[0.0, 0.0], [0.0, 1.0]];
        let factor: Array2<f64> = array![[1.0]]; // wrong shape for p_local=2
        assert_eq!(
            joint_unpenalized_dim(2, &[full, factor.clone()], &[1, 0]),
            0
        );
        // With a single non-materialized penalty, fall back to its own null dim.
        assert_eq!(joint_unpenalized_dim(4, std::slice::from_ref(&factor), &[2]), 2);
    }
}
2289
#[cfg(test)]
mod kronecker_penalty_system_tests {
    use super::KroneckerPenaltySystem;
    use ndarray::array;

    /// Two 2×2 margins, each with one null and one penalized direction, plus a
    /// double penalty. Only the single jointly-null mode may see the
    /// double-penalty parameter, so its log-derivative is exactly 1 and its
    /// second derivative vanishes.
    #[test]
    fn double_penalty_rank_derivatives_use_only_joint_null_space() {
        let first_margin = array![[0.0, 0.0], [0.0, 2.0]];
        let second_margin = array![[0.0, 0.0], [0.0, 3.0]];
        let system = KroneckerPenaltySystem::new(
            vec![first_margin, second_margin],
            vec![2usize, 2usize],
            true,
        )
        .unwrap();
        let lambdas = [5.0, 7.0, 11.0];

        let (logdet, rank, grad, hess) = system.logdet_rank_and_derivatives(&lambdas, 0.0);

        // Penalized eigenvalues per grid point: the joint-null mode gets the
        // double penalty (11); the others get 7·3, 5·2 and 5·2 + 7·3.
        let expected_logdet: f64 = [11.0_f64, 21.0, 10.0, 31.0]
            .iter()
            .map(|sigma| sigma.ln())
            .sum();
        assert_eq!(rank, 4);
        assert!((logdet - expected_logdet).abs() <= 1e-12);
        let double_penalty_grad = grad[2];
        assert!(
            (double_penalty_grad - 1.0).abs() <= 1e-12,
            "double-penalty rank derivative must count only the joint null mode, got {}",
            double_penalty_grad
        );
        assert!(hess[[2, 2]].abs() <= 1e-12);
    }
}
2318
/// Assembled design for a collection of model terms: the design matrix, its
/// blockwise penalties with their bookkeeping, optional coefficient
/// constraints, and the column ranges of the individual term groups.
#[derive(Clone, Debug)]
pub struct TermCollectionDesign {
    /// The full design matrix.
    ///
    /// Prefer a true sparse matrix when every block is sparse-compatible.
    /// If the collection already contains intrinsically sparse blocks, preserve
    /// that storage and let PIRLS decide later whether the penalized system is
    /// sparse-native eligible. Purely dense materialized blocks still fall back
    /// to the lazy block operator when sparse storage would just re-encode a
    /// dense matrix.
    pub design: DesignMatrix,
    /// Blockwise penalties, one per penalty block; `num_penalties` reports
    /// their count.
    pub penalties: Vec<BlockwisePenalty>,
    /// Null-space dimensions — presumably one entry per penalty block, in the
    /// same order as `penalties`; TODO confirm against the assembly code.
    pub nullspace_dims: Vec<usize>,
    /// Recorded penalty metadata assembled alongside `penalties`; its
    /// `penalty.source` tags are what `leading_penalty_blocks_before_smooth`
    /// inspects.
    pub penaltyinfo: Vec<PenaltyBlockInfo>,
    /// Metadata for penalty blocks that were dropped — presumably during
    /// assembly; TODO confirm where and why blocks are dropped.
    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    /// Optional global coefficient lower bounds for constrained fitting.
    /// Length equals `design.ncols()` when present. Unconstrained entries are `-inf`.
    pub coefficient_lower_bounds: Option<Array1<f64>>,
    /// Optional global inequality constraints:
    /// `A * beta >= b`.
    pub linear_constraints: Option<LinearInequalityConstraints>,
    /// Column range of the intercept — presumably within the global
    /// coefficient layout of `design`; TODO confirm.
    pub intercept_range: Range<usize>,
    /// Named column ranges of the linear terms (name, range).
    pub linear_ranges: Vec<(String, Range<usize>)>,
    /// Named column ranges of the random-effect terms (name, range). A range
    /// may be unpenalized or empty and then contributes no penalty block.
    pub random_effect_ranges: Vec<(String, Range<usize>)>,
    /// Named level lists of the random-effect terms — presumably the encoded
    /// factor levels behind each range; TODO confirm.
    pub random_effect_levels: Vec<(String, Vec<u64>)>,
    /// Design information for the smooth terms; `smooth.terms` is what
    /// `kronecker_penalty_system` examines.
    pub smooth: SmoothDesign,
}
2346
2347impl TermCollectionDesign {
2348    /// Number of global penalty blocks that precede the smooth-term penalty
2349    /// blocks in the flat smoothing-parameter / EDF-trace layout.
2350    ///
2351    /// The prefix is not the same as `random_effect_ranges.len()`: unpenalized
2352    /// (or empty) random-effect ranges contribute coefficient columns but no
2353    /// penalty block. The authoritative layout is the recorded `penaltyinfo`
2354    /// sequence assembled alongside `penalties`.
2355    pub fn leading_penalty_blocks_before_smooth(&self) -> usize {
2356        self.penaltyinfo
2357            .iter()
2358            .take_while(|info| {
2359                matches!(
2360                    &info.penalty.source,
2361                    crate::basis::PenaltySource::Other(source)
2362                        if source == "LinearTermRidge"
2363                            || source.starts_with("RandomEffectRidge(")
2364                )
2365            })
2366            .count()
2367    }
2368
2369    /// Convert blockwise penalties to `PenaltyMatrix::Blockwise` without
2370    /// expanding to `p_total × p_total`. This is the preferred path for
2371    /// family modules that accept `Vec<PenaltyMatrix>`.
2372    pub fn penalties_as_penalty_matrix(&self) -> Vec<gam_problem::PenaltyMatrix> {
2373        let p = self.design.ncols();
2374        self.penalties
2375            .iter()
2376            .map(|bp| bp.to_penalty_matrix(p))
2377            .collect()
2378    }
2379
2380    /// Number of penalty blocks.
2381    #[inline]
2382    pub fn num_penalties(&self) -> usize {
2383        self.penalties.len()
2384    }
2385
2386    /// Resolve coefficient groups against this design's global coefficient
2387    /// layout and append their penalties after the existing term penalties.
2388    pub fn realize_coefficient_groups(
2389        &self,
2390        groups: &[CoefficientGroupSpec],
2391        base_prior: &gam_spec::RhoPrior,
2392    ) -> Result<RealizedCoefficientGroups, BasisError> {
2393        realize_coefficient_groups(self, groups, base_prior)
2394    }
2395
2396    /// Extract a `KroneckerPenaltySystem` when the model's *only* smooth term is
2397    /// a single Kronecker-factored tensor.
2398    ///
2399    /// This is a deliberate single-tensor fast path, not a partial feature: any
2400    /// other shape — zero Kronecker terms, several of them, or a tensor mixed
2401    /// with non-tensor smooth terms — is served correctly by the standard
2402    /// block-separable assembly, so this returns `None` and the caller falls
2403    /// back to it. The two former conditions (`len != 1` and "a non-Kronecker
2404    /// smooth term exists") are jointly equivalent to "the sole smooth term is
2405    /// Kronecker", which the slice pattern below expresses directly in one pass.
2406    pub fn kronecker_penalty_system(&self) -> Option<KroneckerPenaltySystem> {
2407        let [only_term] = self.smooth.terms.as_slice() else {
2408            return None;
2409        };
2410        let kron = only_term.kronecker_factored.as_ref()?;
2411        // A genuine tensor product needs at least two margins, and the marginal
2412        // design / penalty / dim collections must agree in length. A degenerate
2413        // (single-margin) or internally inconsistent factored basis cannot feed
2414        // the Kronecker fast path, so fall back to the standard assembly rather
2415        // than construct a malformed `KroneckerPenaltySystem` from it.
2416        if kron.marginal_dims.len() < 2
2417            || kron.marginal_penalties.len() != kron.marginal_dims.len()
2418            || kron.marginal_designs.len() != kron.marginal_dims.len()
2419        {
2420            return None;
2421        }
2422        KroneckerPenaltySystem::new(
2423            kron.marginal_penalties.clone(),
2424            kron.marginal_dims.clone(),
2425            kron.has_double_penalty,
2426        )
2427        .ok()
2428    }
2429}
2430
2431// `FittedTermCollection`, `SpatialLengthScaleOptimizationTiming`, and
2432// `FittedTermCollectionWithSpec` were relocated with the GAM fit-orchestration
2433// drivers to `gam-models` (`crate::fit_orchestration::drivers`) — they hold a
2434// `gam_solve::UnifiedFitResult` and are consumed only by those drivers (#1521).
2435
/// Configuration bundle for a latent-coordinate smooth term.
///
/// NOTE(review): the field notes below are read off names and types only;
/// confirm them against the code that consumes this struct.
#[derive(Clone)]
pub struct StandardLatentCoordConfig {
    /// Shared (reference-counted) latent coordinate values.
    pub values: std::sync::Arc<crate::latent::LatentCoordValues>,
    /// Index of a smooth term — presumably the term these coordinates feed;
    /// TODO confirm.
    pub term_index: gam_problem::types::SmoothTermIdx,
    /// Feature column indices — presumably columns of the input data matrix;
    /// TODO confirm.
    pub feature_cols: Vec<usize>,
    /// Manifold associated with the latent coordinates.
    pub manifold: crate::latent::LatentManifold,
    /// Flag — presumably set when the manifold was chosen automatically
    /// rather than by the user; TODO confirm.
    pub manifold_auto: bool,
    /// Registry of latent retractions.
    pub retraction_registry: gam_problem::LatentRetractionRegistry,
    /// Optional shared registry of analytic penalties.
    pub analytic_penalties: Option<std::sync::Arc<crate::AnalyticPenaltyRegistry>>,
}
2446
/// Per-term record of an adaptive spatial map; serializable.
///
/// NOTE(review): the field notes below are read off names and types only;
/// confirm them against the code that fills this struct.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AdaptiveSpatialMap {
    /// Name of the term this map belongs to.
    pub termname: String,
    /// Feature column indices — presumably columns of the input data matrix;
    /// TODO confirm.
    pub feature_cols: Vec<usize>,
    /// Collocation points — presumably one point per row; TODO confirm.
    pub collocation_points: Array2<f64>,
    /// Presumably inverse weights for the magnitude penalty, one per
    /// collocation point; TODO confirm.
    pub inv_magweight: Array1<f64>,
    /// Presumably inverse weights for the gradient penalty; TODO confirm.
    pub invgradweight: Array1<f64>,
    /// Presumably inverse weights for the Laplacian penalty; TODO confirm.
    pub inv_lapweight: Array1<f64>,
}
2456
/// Diagnostics reported for adaptive regularization; serializable.
///
/// NOTE(review): the field notes below are read off names and types only;
/// confirm them against the code that fills this struct.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AdaptiveRegularizationDiagnostics {
    /// Epsilon value with subscript `0` — presumably tied to the magnitude
    /// term; TODO confirm.
    pub epsilon_0: f64,
    /// Epsilon value with subscript `g` — presumably tied to the gradient
    /// term; TODO confirm.
    pub epsilon_g: f64,
    /// Epsilon value with subscript `c` — meaning not evident from here;
    /// TODO confirm.
    pub epsilon_c: f64,
    /// Count of outer iterations over the epsilon values.
    pub epsilon_outer_iterations: usize,
    /// Count of MM iterations — presumably majorize–minimize; TODO confirm.
    pub mm_iterations: usize,
    /// Whether the procedure reported convergence.
    pub converged: bool,
    /// Adaptive spatial maps, presumably one per adaptive term; TODO confirm.
    pub maps: Vec<AdaptiveSpatialMap>,
}
2467
/// Conditioning record for a single linear column. All fields are private to
/// this module.
///
/// NOTE(review): presumably the column is centered by `mean` and divided by
/// `scale` before fitting — confirm against the code that applies it.
#[derive(Debug, Clone)]
pub struct LinearColumnConditioning {
    /// Index of the conditioned column.
    col_idx: usize,
    /// Mean associated with the column.
    mean: f64,
    /// Scale associated with the column.
    scale: f64,
}
2474
/// Conditioning applied to the linear part of a fit: the intercept position
/// plus one record per conditioned column. The derived `Default` gives
/// intercept index 0 and no conditioned columns.
#[derive(Debug, Clone, Default)]
pub struct LinearFitConditioning {
    /// Index of the intercept column — presumably in the global coefficient
    /// layout; TODO confirm.
    pub intercept_idx: usize,
    /// Per-column conditioning records.
    pub columns: Vec<LinearColumnConditioning>,
}
2480
/// Derivative information of one spatial term with respect to a single ψ
/// coordinate: first and second ψ-derivatives of the term's local design
/// block and local penalty components, plus optional anisotropy and
/// implicit-operator data.
///
/// NOTE(review): field notes marked "presumably" are read off names and types
/// only; confirm them against the builders that populate this struct.
#[derive(Clone)]
pub struct SpatialPsiDerivative {
    // These are derivatives with respect to psi = log(kappa), not log(length_scale).
    /// Index of a penalty block — presumably the primary one this ψ acts on;
    /// TODO confirm how it relates to `penalty_indices`.
    pub penalty_index: usize,
    /// Indices of penalty blocks — presumably all blocks affected by this ψ,
    /// aligned with the `s_psi_*_components_local` vectors; TODO confirm.
    pub penalty_indices: Vec<usize>,
    /// Column range — presumably the term's coefficients in the global
    /// layout; TODO confirm.
    pub global_range: Range<usize>,
    /// Presumably the total number of coefficients in the global layout;
    /// TODO confirm.
    pub total_p: usize,
    /// First ψ-derivative of the local design block. May be zero-sized when
    /// `implicit_operator` is present.
    pub x_psi_local: Array2<f64>,
    /// First ψ-derivatives of the local penalty components.
    pub s_psi_components_local: Vec<Array2<f64>>,
    /// Second ψ-derivative of the local design block. May be zero-sized when
    /// `implicit_operator` is present.
    pub x_psi_psi_local: Array2<f64>,
    /// Second ψ-derivatives of the local penalty components.
    pub s_psi_psi_components_local: Vec<Array2<f64>>,
    /// Identifier of the anisotropy group this entry belongs to, if any.
    pub aniso_group_id: Option<usize>,
    /// Pre-computed cross-derivative design matrices for other axes
    /// in the same aniso group: Vec of (axis_offset_in_group, matrix).
    pub aniso_cross_designs: Option<Vec<(usize, Array2<f64>)>>,
    /// On-demand cross-penalty second derivatives ∂²S_m/∂ψ_a∂ψ_b for axes in
    /// the same anisotropy group. The input is the other axis offset in the
    /// group, and the output is one local penalty matrix per active penalty.
    pub aniso_cross_penalty_provider: Option<
        std::sync::Arc<
            dyn Fn(usize) -> Result<Vec<Array2<f64>>, EstimationError> + Send + Sync + 'static,
        >,
    >,
    /// Optional implicit design-derivative operator (shared across all axes
    /// in the same aniso group). When present, `x_psi_local` and
    /// `x_psi_psi_local` may be zero-sized, and design-derivative matvecs
    /// should go through this operator using `implicit_axis` as the axis index.
    pub implicit_operator: Option<std::sync::Arc<crate::basis::ImplicitDesignPsiDerivative>>,
    /// Which axis in the implicit operator this entry corresponds to.
    pub implicit_axis: usize,
}
2512
/// Flat vector of ψ coordinates for a list of spatial terms, together with
/// the number of slots each term occupies.
///
/// Invariant: `values.len()` equals the sum of `dims_per_term`. The
/// `new_with_dims` constructor asserts it.
#[derive(Debug, Clone)]
pub struct SpatialLogKappaCoords {
    /// Flattened ψ values. For isotropic terms, one entry per term.
    /// For anisotropic terms, d entries per term (one ψ_a per axis).
    pub values: Array1<f64>,
    /// Dimensionality of each term: 1 for isotropic, d for anisotropic.
    pub dims_per_term: Vec<usize>,
}
2521
/// Which end of the ψ bound the shared `aniso_bounds_from_data` helper is
/// computing. The lower end uses `-max_length_scale.ln()` as the pure-Duchon
/// fallback and the `.0` element of `spatial_term_psi_bounds`; the upper end
/// uses `-min_length_scale.ln()` and `.1`. Everything else is identical.
#[derive(Clone, Copy)]
pub enum AnisoBoundEnd {
    /// Lower ψ bound: fallback `-max_length_scale.ln()`, tuple element `.0`.
    Lower,
    /// Upper ψ bound: fallback `-min_length_scale.ln()`, tuple element `.1`.
    Upper,
}
2531
2532impl SpatialLogKappaCoords {
2533    /// Construct from an explicit dims layout plus values.
2534    pub fn new_with_dims(values: Array1<f64>, dims_per_term: Vec<usize>) -> Self {
2535        assert_eq!(
2536            values.len(),
2537            dims_per_term.iter().sum::<usize>(),
2538            "SpatialLogKappaCoords: values length {} != sum of dims_per_term {}",
2539            values.len(),
2540            dims_per_term.iter().sum::<usize>(),
2541        );
2542        Self {
2543            values,
2544            dims_per_term,
2545        }
2546    }
2547
2548    /// Isotropic initialization (backward-compatible path).
2549    pub fn from_length_scales(
2550        spec: &TermCollectionSpec,
2551        term_indices: &[usize],
2552        options: &SpatialLengthScaleOptimizationOptions,
2553    ) -> Self {
2554        let mut out = Array1::<f64>::zeros(term_indices.len());
2555        for (slot, &term_idx) in term_indices.iter().enumerate() {
2556            // Constant-curvature: the single ψ slot is the raw signed κ, seeded
2557            // from the spec (default κ = 0). The −ln(length_scale) convention is
2558            // log-κ semantics and must not touch the raw-κ coordinate; the κ
2559            // window projection happens later via `clamp_to_bounds`. Mirrors the
2560            // aniso constructor's κ branch.
2561            if let Some(cc) = constant_curvature_term_spec(spec, term_idx) {
2562                out[slot] = cc.kappa;
2563                continue;
2564            }
2565            let length_scale = get_spatial_length_scale(spec, term_idx)
2566                .unwrap_or(options.min_length_scale)
2567                .clamp(options.min_length_scale, options.max_length_scale);
2568            out[slot] = -length_scale.ln();
2569        }
2570        Self {
2571            values: out,
2572            dims_per_term: vec![1; term_indices.len()],
2573        }
2574    }
2575
2576    /// Anisotropic-aware initialization.
2577    ///
2578    /// Initialization strategy (per math team recommendation): standardize the
2579    /// knot cloud axiswise, then run the existing isotropic κ initializer in
2580    /// the standardized space. This reuses the trusted isotropic initializer
2581    /// and gives initial η_a = −ln(σ_a) + mean(ln(σ_a)), which satisfies
2582    /// Ση_a = 0 by construction.
2583    ///
2584    /// For each term, checks whether it has `aniso_log_scales` set on its basis spec.
2585    /// - If isotropic (no aniso_log_scales, or 1-D): 1 entry = −ln(length_scale).
2586    /// - If anisotropic with a scalar length scale: d entries, one ψ_a per axis.
2587    ///   Initialized as ψ_a = −ln(length_scale) + η_a  where η_a are the existing
2588    ///   aniso_log_scales (which sum to zero). Multi-dimensional terms without
2589    ///   explicit anisotropy stay scalar here so the seed dimensionality matches
2590    ///   `spatial_dims_per_term`.
2591    /// - If pure Duchon anisotropic: d - 1 free entries store the leading η_a
2592    ///   values directly; the final axis is reconstructed to keep Ση_a = 0.
2593    pub fn from_length_scales_aniso(
2594        spec: &TermCollectionSpec,
2595        term_indices: &[usize],
2596        options: &SpatialLengthScaleOptimizationOptions,
2597    ) -> Self {
2598        let mut vals = Vec::new();
2599        let mut dims = Vec::new();
2600        for &term_idx in term_indices {
2601            // Measure-jet: dial coordinates seeded directly from the term's
2602            // realized (α, τ[, s]); the −ln(length_scale) convention below is
2603            // κ-semantics and never applies to dials.
2604            if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
2605                let seed = measure_jet_psi_seed(mj);
2606                dims.push(seed.len());
2607                vals.extend(seed);
2608                continue;
2609            }
2610            // Constant-curvature: one signed κ slot seeded from the spec's κ
2611            // (clamped feasible). The −ln(length_scale) convention below is
2612            // log-κ semantics and must not touch the raw-κ coordinate. Bounds
2613            // are unavailable here (no data view), so this is the raw spec κ;
2614            // `reseed_from_data` / `clamp_to_bounds` later project it feasible.
2615            if let Some(cc) = constant_curvature_term_spec(spec, term_idx) {
2616                vals.push(cc.kappa);
2617                dims.push(1);
2618                continue;
2619            }
2620            let length_scale = get_spatial_length_scale(spec, term_idx)
2621                .unwrap_or(options.min_length_scale)
2622                .clamp(options.min_length_scale, options.max_length_scale);
2623            let psi_bar = -length_scale.ln(); // global scale = −ln(length_scale)
2624
2625            if spatial_term_uses_per_axis_psi(spec, term_idx) {
2626                // Per-axis anisotropy is enrolled in the joint outer vector:
2627                // ψ_a = ψ̄ + η_a, one slot per axis. The hyper_dirs builder
2628                // produces matching per-axis derivatives in
2629                // `try_build_spatial_term_log_kappa_aniso_derivativeinfos`.
2630                let d = get_spatial_feature_dim(spec, term_idx).unwrap_or(1);
2631                let eta_raw = get_spatial_aniso_log_scales(spec, term_idx)
2632                    .expect("predicate guarantees aniso_log_scales is Some");
2633                let eta = center_aniso_log_scales(&eta_raw);
2634                for &eta_a in &eta {
2635                    vals.push(psi_bar + eta_a);
2636                }
2637                dims.push(d);
2638            } else {
2639                // Isotropic enrollment — either a 1-D term, a multi-D term
2640                // without explicit anisotropy, or a basis (e.g. Duchon) whose
2641                // η is a fixed geometry parameter rather than a REML hyper
2642                // axis. Exactly one ψ̄ slot, matching the single
2643                // `SpatialPsiDerivative` produced by
2644                // `try_build_spatial_term_log_kappa_derivativeinfo`.
2645                vals.push(psi_bar);
2646                dims.push(1);
2647            }
2648        }
2649        Self {
2650            values: Array1::from_vec(vals),
2651            dims_per_term: dims,
2652        }
2653    }
2654
2655    /// Isotropic lower bounds derived from per-term data geometry.
2656    /// Each entry gets the ψ_lo bound returned by `spatial_term_psi_bounds`
2657    /// for the corresponding term, intersected with the options window.
2658    pub fn lower_bounds_from_data(
2659        data: ArrayView2<'_, f64>,
2660        spec: &TermCollectionSpec,
2661        term_indices: &[usize],
2662        options: &SpatialLengthScaleOptimizationOptions,
2663    ) -> Self {
2664        let mut values = Array1::<f64>::zeros(term_indices.len());
2665        for (slot, &term_idx) in term_indices.iter().enumerate() {
2666            values[slot] = spatial_term_psi_bounds(data, spec, term_idx, options).0;
2667        }
2668        Self {
2669            values,
2670            dims_per_term: vec![1; term_indices.len()],
2671        }
2672    }
2673
2674    /// Isotropic upper bounds derived from per-term data geometry.
2675    pub fn upper_bounds_from_data(
2676        data: ArrayView2<'_, f64>,
2677        spec: &TermCollectionSpec,
2678        term_indices: &[usize],
2679        options: &SpatialLengthScaleOptimizationOptions,
2680    ) -> Self {
2681        let mut values = Array1::<f64>::zeros(term_indices.len());
2682        for (slot, &term_idx) in term_indices.iter().enumerate() {
2683            values[slot] = spatial_term_psi_bounds(data, spec, term_idx, options).1;
2684        }
2685        Self {
2686            values,
2687            dims_per_term: vec![1; term_indices.len()],
2688        }
2689    }
2690
2691    /// Anisotropic-aware lower bounds derived from per-term data geometry.
2692    /// For hybrid anisotropic terms the scalar ψ_lo bound applies to the
2693    /// mean `ψ̄`, not directly to every raw axis coordinate `ψ_a = ψ̄ + η_a`.
2694    /// Shift each axis by the current centered `η_a` so projecting/clamping
2695    /// the seed moves only the global scale direction and does not silently
2696    /// shrink anisotropy that is already consistent with the current
2697    /// `length_scale`.
2698    ///
2699    /// Pure Duchon anisotropy is structurally different: its stored
2700    /// coordinates are (d-1) free η_a values representing log axis-scale
2701    /// ratios, NOT log-κ. For those terms the κ-range geometry bound is
2702    /// over-restrictive (η_a = ±5 is normal, but that corresponds to 7+
2703    /// orders of magnitude in κ-space and would be rejected by the data
2704    /// window). Fall back to the options window `[-ln(max_ls), -ln(min_ls)]`
2705    /// for those coordinates — that's the same bound the pre-data-geometry
2706    /// code used, which is calibrated to allow legitimate anisotropy.
2707    pub fn lower_bounds_aniso_from_data(
2708        data: ArrayView2<'_, f64>,
2709        spec: &TermCollectionSpec,
2710        term_indices: &[usize],
2711        dims_per_term: &[usize],
2712        options: &SpatialLengthScaleOptimizationOptions,
2713    ) -> Self {
2714        Self::aniso_bounds_from_data(
2715            data,
2716            spec,
2717            term_indices,
2718            dims_per_term,
2719            options,
2720            AnisoBoundEnd::Lower,
2721        )
2722    }
2723
2724    /// Anisotropic-aware upper bounds derived from per-term data geometry.
2725    /// See `lower_bounds_aniso_from_data` for the hybrid-aniso offsetting and
2726    /// pure-Duchon dispatch rationale.
2727    pub fn upper_bounds_aniso_from_data(
2728        data: ArrayView2<'_, f64>,
2729        spec: &TermCollectionSpec,
2730        term_indices: &[usize],
2731        dims_per_term: &[usize],
2732        options: &SpatialLengthScaleOptimizationOptions,
2733    ) -> Self {
2734        Self::aniso_bounds_from_data(
2735            data,
2736            spec,
2737            term_indices,
2738            dims_per_term,
2739            options,
2740            AnisoBoundEnd::Upper,
2741        )
2742    }
2743
2744    /// Shared implementation for the lower/upper aniso bounds. The bound end
2745    /// only changes which options scale (`max_length_scale` vs
2746    /// `min_length_scale`) becomes the pure-Duchon fallback bound and which
2747    /// element of the `(lo, hi)` data-geometry tuple is consumed; the
2748    /// per-term cursor walk and aniso-offset handling are identical.
2749    fn aniso_bounds_from_data(
2750        data: ArrayView2<'_, f64>,
2751        spec: &TermCollectionSpec,
2752        term_indices: &[usize],
2753        dims_per_term: &[usize],
2754        options: &SpatialLengthScaleOptimizationOptions,
2755        end: AnisoBoundEnd,
2756    ) -> Self {
2757        assert_eq!(term_indices.len(), dims_per_term.len());
2758        let total: usize = dims_per_term.iter().sum();
2759        let mut values = Array1::<f64>::zeros(total);
2760        let mut cursor = 0;
2761        for (slot, &term_idx) in term_indices.iter().enumerate() {
2762            let d = dims_per_term[slot];
2763            // Measure-jet: per-coordinate dial boxes, never κ-window geometry
2764            // (which would reject legitimate dial values outright).
2765            if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
2766                let bounds = measure_jet_psi_bound_values(mj, matches!(end, AnisoBoundEnd::Upper));
2767                for (offset, bound) in bounds.into_iter().enumerate() {
2768                    if offset < d {
2769                        values[cursor + offset] = bound;
2770                    }
2771                }
2772                cursor += d;
2773                continue;
2774            }
2775            // Constant-curvature: the single signed-κ box from the data chart
2776            // window (symmetric about κ = 0), never a κ = log-scale window.
2777            if constant_curvature_term_spec(spec, term_idx).is_some() {
2778                let (lo, hi) = constant_curvature_kappa_bounds(data, spec, term_idx);
2779                if d >= 1 {
2780                    values[cursor] = match end {
2781                        AnisoBoundEnd::Lower => lo,
2782                        AnisoBoundEnd::Upper => hi,
2783                    };
2784                }
2785                cursor += d;
2786                continue;
2787            }
2788            let psi_bound = {
2789                let (lo, hi) = spatial_term_psi_bounds(data, spec, term_idx, options);
2790                match end {
2791                    AnisoBoundEnd::Lower => lo,
2792                    AnisoBoundEnd::Upper => hi,
2793                }
2794            };
2795            let axis_offsets = if d <= 1 {
2796                vec![0.0; d]
2797            } else {
2798                get_spatial_aniso_log_scales(spec, term_idx)
2799                    .filter(|eta| eta.len() == d)
2800                    .map(|eta| center_aniso_log_scales(&eta))
2801                    .unwrap_or_else(|| vec![0.0; d])
2802            };
2803            for offset in 0..d {
2804                values[cursor + offset] = psi_bound + axis_offsets[offset];
2805            }
2806            cursor += d;
2807        }
2808        Self {
2809            values,
2810            dims_per_term: dims_per_term.to_vec(),
2811        }
2812    }
2813
2814    /// Rewrite any ψ entries whose originating term lacks an explicit
2815    /// `length_scale` so they sit at the midpoint of the per-term data-derived
2816    /// ψ window. Used so the outer optimizer starts inside the physically
2817    /// meaningful region instead of at an arbitrary `options.max_length_scale`
2818    /// derived seed. For terms with an explicit length_scale, the user's
2819    /// choice is respected. Anisotropy offsets η_a (those stored by
2820    /// `from_length_scales_aniso`) are preserved: we re-center around the new
2821    /// ψ̄, keeping Ση_a = 0.
2822    pub fn reseed_from_data(
2823        mut self,
2824        data: ArrayView2<'_, f64>,
2825        spec: &TermCollectionSpec,
2826        term_indices: &[usize],
2827        options: &SpatialLengthScaleOptimizationOptions,
2828    ) -> Self {
2829        assert_eq!(term_indices.len(), self.dims_per_term.len());
2830        let mut cursor = 0;
2831        for (slot, &term_idx) in term_indices.iter().enumerate() {
2832            let d = self.dims_per_term[slot];
2833            // Measure-jet dials are seeded from the realized spec and must
2834            // not be recentered into a κ data window.
2835            if measure_jet_term_spec(spec, term_idx).is_some() {
2836                cursor += d;
2837                continue;
2838            }
2839            // Constant-curvature κ is seeded from the spec (the user's curvature
2840            // hint, default κ = 0); `clamp_to_bounds` projects it feasible. It
2841            // is not a log-scale, so the log-κ recenter below never applies.
2842            if constant_curvature_term_spec(spec, term_idx).is_some() {
2843                cursor += d;
2844                continue;
2845            }
2846            let Some(psi_bar_new) = spatial_term_psi_seed(data, spec, term_idx, options) else {
2847                cursor += d;
2848                continue;
2849            };
2850            if d == 0 {
2851                continue;
2852            }
2853            let current: Vec<f64> = self.values.slice(s![cursor..cursor + d]).to_vec();
2854            let psi_bar_old = current.iter().sum::<f64>() / d as f64;
2855            for (offset, &old_value) in current.iter().enumerate() {
2856                self.values[cursor + offset] = psi_bar_new + (old_value - psi_bar_old);
2857            }
2858            cursor += d;
2859        }
2860        self
2861    }
2862
2863    /// Project ψ values into `[lower, upper]` element-wise. Used after
2864    /// `from_length_scales*` + `reseed_from_data` when a user-supplied
2865    /// `spec.length_scale` falls outside the data-derived ψ window set by
2866    /// `{lower,upper}_bounds*_from_data`. BFGS requires theta0 ∈ [lower,
2867    /// upper]; projecting is the unique closest feasible seed. The user's
2868    /// length_scale was always a hint for the outer optimizer (the optimizer
2869    /// is authoritative for κ), not a hard constraint — so clipping preserves
2870    /// their intent as far as the geometry allows. Emits `log::info!` when
2871    /// any coordinate moves, so the outside-window case is diagnostically
2872    /// visible (not silent).
2873    pub fn clamp_to_bounds(
2874        mut self,
2875        lower: &SpatialLogKappaCoords,
2876        upper: &SpatialLogKappaCoords,
2877    ) -> Self {
2878        assert_eq!(self.values.len(), lower.values.len());
2879        assert_eq!(self.values.len(), upper.values.len());
2880        let mut n_projected = 0usize;
2881        let mut worst_delta = 0.0_f64;
2882        for idx in 0..self.values.len() {
2883            let lo = lower.values[idx];
2884            let hi = upper.values[idx];
2885            if !(lo.is_finite() && hi.is_finite()) {
2886                continue;
2887            }
2888            let v = self.values[idx];
2889            if v < lo {
2890                worst_delta = worst_delta.max(lo - v);
2891                self.values[idx] = lo;
2892                n_projected += 1;
2893            } else if v > hi {
2894                worst_delta = worst_delta.max(v - hi);
2895                self.values[idx] = hi;
2896                n_projected += 1;
2897            }
2898        }
2899        if n_projected > 0 {
2900            log::info!(
2901                "[spatial-kappa] projected {n_projected}/{} ψ seed coords into data-derived bounds \
2902                 (worst excess={worst_delta:.3} log units); user length_scale falls outside \
2903                 [{KERNEL_RANGE_MIN_DIAMETER_FRACTION}/r_max, {KERNEL_RANGE_MAX_SPACING_MULTIPLE}/r_min] geometry window",
2904                self.values.len()
2905            );
2906        }
2907        self
2908    }
2909
2910    /// Reconstruct from theta tail with known dimensionality layout.
2911    pub fn from_theta_tail_with_dims(
2912        theta: &Array1<f64>,
2913        start: usize,
2914        dims_per_term: Vec<usize>,
2915    ) -> Self {
2916        let total: usize = dims_per_term.iter().sum();
2917        Self {
2918            values: theta.slice(s![start..start + total]).to_owned(),
2919            dims_per_term,
2920        }
2921    }
2922
2923    /// Total number of ψ values in the flat array (= sum of dims_per_term).
2924    pub fn len(&self) -> usize {
2925        self.values.len()
2926    }
2927
2928    /// Dimensionality layout: how many ψ values each term contributes.
2929    pub fn dims_per_term(&self) -> &[usize] {
2930        &self.dims_per_term
2931    }
2932
2933    /// Get the offset into the flat array for logical term i.
2934    fn term_offset(&self, term_idx: usize) -> usize {
2935        self.dims_per_term[..term_idx].iter().sum()
2936    }
2937
2938    /// Get the slice of ψ values for logical term i.
2939    pub fn term_slice(&self, term_idx: usize) -> &[f64] {
2940        let offset = self.term_offset(term_idx);
2941        let d = self.dims_per_term[term_idx];
2942        &self.values.as_slice().unwrap()[offset..offset + d]
2943    }
2944
2945    pub fn as_array(&self) -> &Array1<f64> {
2946        &self.values
2947    }
2948
2949    /// #1464: overwrite the single ψ value of a scalar (1-D) logical term by its
2950    /// position `slot` in this coords vector (the same ordering as the
2951    /// `term_indices` slice the constructors were built from). Used to inject the
2952    /// fixed-κ sign-basin seed into a constant-curvature term's raw-κ slot before
2953    /// the joint solve. No-op (returns `false`) when the slot is not scalar.
2954    pub fn set_scalar_slot(&mut self, slot: usize, value: f64) -> bool {
2955        if slot >= self.dims_per_term.len() || self.dims_per_term[slot] != 1 {
2956            return false;
2957        }
2958        let offset = self.term_offset(slot);
2959        self.values[offset] = value;
2960        true
2961    }
2962
2963    /// Split at a logical-term boundary. `mid` is the number of terms in the
2964    /// first half (not a flat-array index).
2965    pub fn split_at(&self, mid: usize) -> (Self, Self) {
2966        let flat_mid: usize = self.dims_per_term[..mid].iter().sum();
2967        (
2968            Self {
2969                values: self.values.slice(s![0..flat_mid]).to_owned(),
2970                dims_per_term: self.dims_per_term[..mid].to_vec(),
2971            },
2972            Self {
2973                values: self.values.slice(s![flat_mid..]).to_owned(),
2974                dims_per_term: self.dims_per_term[mid..].to_vec(),
2975            },
2976        )
2977    }
2978
2979    /// Apply optimized ψ values back to the spec.
2980    ///
2981    /// For isotropic terms (dims=1): sets scalar length_scale = exp(−ψ).
2982    /// For anisotropic terms (dims=d): hybrid/isotropic families set
2983    /// length_scale = exp(−ψ̄) with centered η_a = ψ_a − ψ̄, while pure Duchon
2984    /// writes only centered η_a and leaves length_scale = None.
2985    pub fn apply_tospec(
2986        &self,
2987        spec: &TermCollectionSpec,
2988        term_indices: &[usize],
2989    ) -> Result<TermCollectionSpec, EstimationError> {
2990        if term_indices.len() != self.dims_per_term.len() {
2991            crate::bail_invalid_estim!(
2992                "SpatialLogKappaCoords::apply_tospec: term count mismatch: \
2993                 term_indices={} dims_per_term={}",
2994                term_indices.len(),
2995                self.dims_per_term.len()
2996            );
2997        }
2998        let mut updated = spec.clone();
2999        for (slot, &term_idx) in term_indices.iter().enumerate() {
3000            let psi = self.term_slice(slot);
3001            let d = self.dims_per_term[slot];
3002            // Measure-jet: write the dial coordinates straight back; the
3003            // κ-translation below would misread them as log-scales.
3004            if measure_jet_term_spec(&updated, term_idx).is_some() {
3005                set_measure_jet_psi_dials(&mut updated, term_idx, psi)?;
3006                continue;
3007            }
3008            // Constant-curvature: write the optimized signed κ straight back;
3009            // the −exp(ψ) length-scale translation below is log-κ semantics and
3010            // would misread the raw curvature.
3011            if constant_curvature_term_spec(&updated, term_idx).is_some() {
3012                set_constant_curvature_kappa(&mut updated, term_idx, psi)?;
3013                continue;
3014            }
3015            let (next_length_scale, next_aniso) = spatial_term_psi_to_length_scale_and_aniso(psi);
3016            if (d == 1 || next_length_scale.is_some())
3017                && let Some(length_scale) = next_length_scale
3018            {
3019                set_spatial_length_scale(&mut updated, term_idx, length_scale)?;
3020            }
3021            if let Some(eta) = next_aniso {
3022                set_spatial_aniso_log_scales(&mut updated, term_idx, eta)?;
3023            }
3024        }
3025        Ok(updated)
3026    }
3027}
3028
/// Subtract the mean from a set of anisotropy log-scales so they sum to zero.
///
/// Inputs with zero or one entry are returned unchanged. Any centered value
/// whose magnitude is at most `1e-15` is snapped to exactly `0.0`, which
/// removes round-off residue from already-centered input.
pub fn center_aniso_log_scales(eta: &[f64]) -> Vec<f64> {
    const SNAP_TOLERANCE: f64 = 1e-15;

    let n = eta.len();
    if n < 2 {
        return eta.to_vec();
    }
    let mean = eta.iter().sum::<f64>() / n as f64;

    let mut centered = Vec::with_capacity(n);
    for &value in eta {
        let delta = value - mean;
        // `<=` keeps NaN inputs as NaN (the comparison is false for NaN).
        centered.push(if delta.abs() <= SNAP_TOLERANCE {
            0.0
        } else {
            delta
        });
    }
    centered
}
3045
3046/// Whether a spatial term contributes per-axis ψ entries to the outer joint
3047/// hyperparameter vector.
3048pub fn spatial_term_uses_per_axis_psi(resolvedspec: &TermCollectionSpec, term_idx: usize) -> bool {
3049    if let Some(mj) = measure_jet_term_spec(resolvedspec, term_idx) {
3050        return measure_jet_enrolls_psi(mj);
3051    }
3052    let Some(d) = get_spatial_feature_dim(resolvedspec, term_idx) else {
3053        return false;
3054    };
3055    if d <= 1 {
3056        return false;
3057    }
3058    let Some(eta) = get_spatial_aniso_log_scales(resolvedspec, term_idx) else {
3059        return false;
3060    };
3061    if eta.len() != d {
3062        return false;
3063    }
3064    !matches!(
3065        resolvedspec.smooth_terms.get(term_idx).map(|term| &term.basis),
3066        Some(SmoothBasisSpec::Duchon { .. })
3067    )
3068}
3069
3070pub fn set_spatial_length_scale(
3071    spec: &mut TermCollectionSpec,
3072    term_idx: usize,
3073    length_scale: f64,
3074) -> Result<(), EstimationError> {
3075    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3076        crate::bail_invalid_estim!("spatial length-scale term index {term_idx} out of range");
3077    };
3078    match &mut term.basis {
3079        SmoothBasisSpec::ThinPlate { spec, .. } => {
3080            spec.length_scale = length_scale;
3081            Ok(())
3082        }
3083        SmoothBasisSpec::Matern { spec, .. } => {
3084            spec.length_scale = length_scale;
3085            Ok(())
3086        }
3087        SmoothBasisSpec::Duchon { spec, .. } => {
3088            spec.length_scale = Some(length_scale);
3089            Ok(())
3090        }
3091        _ => Err(EstimationError::InvalidInput(format!(
3092            "term '{}' does not expose a spatial length scale",
3093            term.name
3094        ))),
3095    }
3096}
3097
3098pub fn get_spatial_length_scale(spec: &TermCollectionSpec, term_idx: usize) -> Option<f64> {
3099    spec.smooth_terms
3100        .get(term_idx)
3101        .and_then(|term| match &term.basis {
3102            SmoothBasisSpec::ThinPlate { spec, .. } => Some(spec.length_scale),
3103            SmoothBasisSpec::Matern { spec, .. } => Some(spec.length_scale),
3104            SmoothBasisSpec::Duchon { spec, .. } => spec.length_scale,
3105            _ => None,
3106        })
3107}
3108
3109pub fn spatial_term_supports_hyper_optimization(spec: &TermCollectionSpec, term_idx: usize) -> bool {
3110    // Ordinary penalized thin-plate regression splines do not have an
3111    // identifiable kernel scale once REML is already learning the smoothing
3112    // penalty. Treat the resolved length scale as fixed geometry; enrolling a
3113    // scalar TPS kappa axis creates the flat ρ/κ valleys reported in #718,
3114    // #721, #731, and #732.
3115    if let Some(term) = spec.smooth_terms.get(term_idx)
3116        && let SmoothBasisSpec::ThinPlate { .. } = &term.basis
3117    {
3118        return false;
3119    }
3120
3121    // Duchon anisotropy η is a FIXED, geometry-derived basis parameter, NOT a
3122    // REML hyper axis: the metric is estimated once from the knot-cloud spread
3123    // (`auto_seed_aniso_contrasts`, applied on every Duchon basis build) and the
3124    // Hilbert-scale λ's carry all learned smoothness. So a pure Duchon (no κ)
3125    // contributes no outer optimization axis even when `scale_dims` is on —
3126    // "standardize the geometry, then learn the smoothness." Only an explicit
3127    // kernel length scale κ (the Matérn / hybrid path) is optimized here.
3128    //
3129    // ISOTROPIC Matérn: the *default* `matern(x1, x2)` is isotropic
3130    // (`scale_dims=false` → `aniso_log_scales = None`). It contributes exactly
3131    // ONE κ optimization axis — its scalar log-κ. The shared GAMLSS /
3132    // location-scale exact-joint ψ engine and the spatial-κ joint outer solver
3133    // both require an isotropic Matérn block to expose this single isotropic κ
3134    // axis (#822/#851); without it the per-block ψ-derivative lists are empty
3135    // and the joint-ψ hooks degenerate to `None`. The isotropic κ is the lone
3136    // kernel hyper axis here, mirroring the per-axis ψ ARD that the anisotropic
3137    // path exposes (just collapsed to one dimension).
3138    //
3139    // ANISOTROPIC Matérn (`scale_dims=true` → `aniso_log_scales = Some`) keeps
3140    // its per-axis kernel-η ARD: the d-dimensional ψ search is the *point* of
3141    // the anisotropic request ("Matérn keeps its kernel-η ARD").
3142    //
3143    // Either way a Matérn term always enrolls a κ/ψ axis (1 isotropic, or d
3144    // anisotropic), so `spatial_dims_per_term` reports the correct count.
3145    if let Some(term) = spec.smooth_terms.get(term_idx)
3146        && let SmoothBasisSpec::Matern { .. } = &term.basis
3147    {
3148        return true;
3149    }
3150
3151    // Measure-jet geometry dials are outer ψ coordinates; enrollment is
3152    // owned by `measure_jet_enrolls_psi`.
3153    if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
3154        return measure_jet_enrolls_psi(mj);
3155    }
3156
3157    // Constant-curvature smooths always enroll their single signed curvature κ
3158    // as an outer ψ-coordinate (#944 stage 3): κ̂ is the headline estimand, so
3159    // unlike a fixed-ℓ kernel it is fitted by default, not gated on a
3160    // user-supplied scale. The coordinate is raw κ (interior κ = 0), and its
3161    // exact design/penalty κ-derivatives come from
3162    // `build_constant_curvature_basis_kappa_derivatives`.
3163    if constant_curvature_term_spec(spec, term_idx).is_some() {
3164        return true;
3165    }
3166
3167    get_spatial_length_scale(spec, term_idx).is_some()
3168}
3169
3170/// The measure-jet term's spec, when `term_idx` is a measure-jet smooth.
3171/// Single accessor for every dial-plumbing dispatch below.
3172pub fn measure_jet_term_spec(
3173    spec: &TermCollectionSpec,
3174    term_idx: usize,
3175) -> Option<&crate::basis::MeasureJetBasisSpec> {
3176    spec.smooth_terms
3177        .get(term_idx)
3178        .and_then(|term| match &term.basis {
3179            SmoothBasisSpec::MeasureJet { spec, .. } => Some(spec),
3180            _ => None,
3181        })
3182}
3183
3184/// Single source for measure-jet outer-ψ enrollment: the lnτ dial is
3185/// undefined in the τ = 0 pseudo-inverse oracle mode (see
3186/// `build_measure_jet_basis_psi_derivatives`), so only a positive ridge
3187/// enrolls the dial group. `spatial_term_supports_hyper_optimization` and
3188/// `spatial_term_uses_per_axis_psi` both defer here so the θ-layout
3189/// sources cannot disagree.
3190pub fn measure_jet_enrolls_psi(mj: &crate::basis::MeasureJetBasisSpec) -> bool {
3191    // Two independent enrollment sources (#1116), both explicit:
3192    //   * the design-moving representer length-scale ℓ (`learn_length_scale`),
3193    //     available in every mode when the spec opts in;
3194    //   * the multiscale penalty dials (s, α, lnτ): the per-scale spectral
3195    //     split's (α, lnτ) ride the explicit `multiscale` opt-in, and the lnτ
3196    //     channel additionally needs a positive ridge (τ = 0 is the
3197    //     pseudo-inverse oracle mode where lnτ is undefined).
3198    // A term enrolls if EITHER source is active.
3199    measure_jet_learns_length_scale(mj)
3200        || (mj.tau0 > 0.0 && crate::basis::measure_jet_multiscale_mode(mj))
3201}
3202
3203/// Whether the design-moving ℓ dial is enrolled for this term. ℓ is fixed by
3204/// default and learnable in every mode only when `learn_length_scale = true`.
3205pub fn measure_jet_learns_length_scale(mj: &crate::basis::MeasureJetBasisSpec) -> bool {
3206    mj.learn_length_scale
3207}
3208
3209pub fn freeze_measure_jet_length_scale_learning(spec: &mut TermCollectionSpec) -> usize {
3210    let mut frozen = 0;
3211    for term in spec.smooth_terms.iter_mut() {
3212        if let SmoothBasisSpec::MeasureJet { spec: mj, .. } = &mut term.basis
3213            && mj.learn_length_scale
3214        {
3215            mj.learn_length_scale = false;
3216            frozen += 1;
3217        }
3218    }
3219    frozen
3220}
3221
/// Measure-jet ψ dial boxes. The dials are NOT log-kernel-scales, so the
/// κ-window machinery never applies: `α` spans density-weighted (0) through
/// past-Coifman–Lafon (>1) normalization, and `lnτ` covers the ridge from
/// numerically-exact-projection to heavy noise-floor damping. (The energy
/// order `s` is the pinned explicit value or absorbed by the REML-learned
/// per-scale amplitudes — see `measure_jet_penalty_psi_dim` — so it carries no
/// dial box.)
///
/// This constant is the `α` box, stored as `(lower, upper)`:
/// `measure_jet_psi_bound_values` reads `.0` for the lower end and `.1` for
/// the upper end.
pub const MEASURE_JET_PSI_ALPHA_BOUNDS: (f64, f64) = (-1.0, 3.0);
3230
/// `(lower, upper)` box for the measure-jet `lnτ` ridge dial. The literals are
/// `ln(1e-8)` and `ln(1e2)`, i.e. τ ∈ [1e-8, 1e2].
pub const MEASURE_JET_PSI_LN_TAU_BOUNDS: (f64, f64) = (-18.420680743952367, 4.605170185988092);
3232
/// Log-ℓ box for the design-moving representer length-scale dial (#1116). An
/// ABSOLUTE window in the data coordinate scale (ln of ℓ ∈ [1e-3, 1e2]) used
/// only when the spec explicitly enrolls the learned representer range. Absolute
/// (not seed-relative) so the bound producer needs no data view, matching the
/// other dial boxes. `ln(1e-3) = -6.9077…`, `ln(1e2) = 4.6051…`.
///
/// Stored as `(lower, upper)`; `measure_jet_psi_bound_values` emits the
/// selected end as the first coordinate when the ℓ dial is enrolled.
pub const MEASURE_JET_PSI_LN_LENGTH_SCALE_BOUNDS: (f64, f64) = (-6.907755278982137, 4.605170185988092);
3239
3240/// Number of multiscale PENALTY dials (excluding the design-moving ℓ):
3241/// multiscale (per-scale spectral) mode carries (α, lnτ) = 2 — the order is
3242/// either the pinned explicit `s` or absorbed by the REML-learned per-scale
3243/// amplitudes, so it is NOT a dial; single-scale (the default) carries none.
3244/// MUST agree with the penalty-coordinate layout of
3245/// `build_measure_jet_basis_psi_derivatives` (its `per_level` branch always
3246/// emits exactly the (α, lnτ) coordinate pair).
3247pub fn measure_jet_penalty_psi_dim(mj: &crate::basis::MeasureJetBasisSpec) -> usize {
3248    if crate::basis::measure_jet_multiscale_mode(mj) {
3249        2
3250    } else {
3251        0
3252    }
3253}
3254
3255/// ψ dimension of a measure-jet term. The design-moving ℓ dial (when enrolled)
3256/// is coordinate 0; the multiscale penalty dials follow. MUST agree with the
3257/// coordinate layout of `build_measure_jet_basis_psi_derivatives` (ℓ first).
3258pub fn measure_jet_psi_dim(mj: &crate::basis::MeasureJetBasisSpec) -> usize {
3259    usize::from(measure_jet_learns_length_scale(mj)) + measure_jet_penalty_psi_dim(mj)
3260}
3261
3262/// Seed ψ from the term's realized dials, in producer coordinate order: ℓ first
3263/// (when enrolled), then the multiscale penalty dials. The ℓ seed is the
3264/// realized representer range `ln(length_scale)` (the resolved spec carries the
3265/// concrete auto value after the design build/freeze).
3266pub fn measure_jet_psi_seed(mj: &crate::basis::MeasureJetBasisSpec) -> Vec<f64> {
3267    let mut seed = Vec::with_capacity(measure_jet_psi_dim(mj));
3268    if measure_jet_learns_length_scale(mj) {
3269        // length_scale > 0 after resolution; the 0.0 sentinel (pre-resolution)
3270        // falls back to the centre of the log-ℓ box so the optimizer still
3271        // starts feasible and the first data-aware reseed corrects it.
3272        let ell = if mj.length_scale > 0.0 {
3273            mj.length_scale
3274        } else {
3275            1.0
3276        };
3277        seed.push(ell.ln());
3278    }
3279    if measure_jet_penalty_psi_dim(mj) > 0 {
3280        // Multiscale penalty dials, producer order: (α, lnτ).
3281        let ln_tau = mj.tau0.max(f64::MIN_POSITIVE).ln();
3282        seed.extend_from_slice(&[mj.alpha, ln_tau]);
3283    }
3284    seed
3285}
3286
3287/// One end of the per-coordinate dial boxes, in producer coordinate order
3288/// (ℓ first when enrolled, then the multiscale penalty dials).
3289pub fn measure_jet_psi_bound_values(mj: &crate::basis::MeasureJetBasisSpec, upper: bool) -> Vec<f64> {
3290    let pick = |b: (f64, f64)| if upper { b.1 } else { b.0 };
3291    let mut bounds = Vec::with_capacity(measure_jet_psi_dim(mj));
3292    if measure_jet_learns_length_scale(mj) {
3293        bounds.push(pick(MEASURE_JET_PSI_LN_LENGTH_SCALE_BOUNDS));
3294    }
3295    if measure_jet_penalty_psi_dim(mj) > 0 {
3296        // Multiscale penalty dials, producer order: (α, lnτ).
3297        bounds.push(pick(MEASURE_JET_PSI_ALPHA_BOUNDS));
3298        bounds.push(pick(MEASURE_JET_PSI_LN_TAU_BOUNDS));
3299    }
3300    bounds
3301}
3302
3303/// Write optimized ψ dials back into a measure-jet spec. Returns `true` when
3304/// any dial actually moved. The geometry (centers, masses, band, ℓ, z) is
3305/// ψ-FIXED by contract — only the dials change, so frozen-quadrature
3306/// rebuilds reproduce the identical penalty layout at the new dials.
3307pub fn apply_measure_jet_psi(
3308    mj: &mut crate::basis::MeasureJetBasisSpec,
3309    psi: &[f64],
3310) -> Result<bool, EstimationError> {
3311    if psi.len() != measure_jet_psi_dim(mj) {
3312        crate::bail_invalid_estim!(
3313            "measure-jet ψ write-back dimension mismatch: got {} values for a {}-dial term",
3314            psi.len(),
3315            measure_jet_psi_dim(mj)
3316        );
3317    }
3318    let mut changed = false;
3319    // Coordinate 0 (when enrolled) is the design-moving ln(ℓ); the multiscale
3320    // penalty dials follow. Same order as `measure_jet_psi_seed` and the
3321    // producer (`build_measure_jet_basis_psi_derivatives`).
3322    let mut cursor = 0usize;
3323    if measure_jet_learns_length_scale(mj) {
3324        let next_ell = psi[cursor].exp();
3325        cursor += 1;
3326        if !(next_ell.is_finite() && next_ell > 0.0) {
3327            crate::bail_invalid_estim!(
3328                "measure-jet ψ write-back produced a non-finite/non-positive length_scale (ℓ={next_ell})"
3329            );
3330        }
3331        if next_ell != mj.length_scale {
3332            mj.length_scale = next_ell;
3333            changed = true;
3334        }
3335    }
3336    if measure_jet_penalty_psi_dim(mj) > 0 {
3337        // Multiscale penalty dials, producer order: (α, lnτ). The order `s` is
3338        // not a dial (pinned explicit or absorbed by the per-scale amplitudes).
3339        let next_alpha = psi[cursor];
3340        let next_tau = psi[cursor + 1].exp();
3341        if !(next_alpha.is_finite() && next_tau.is_finite() && next_tau > 0.0) {
3342            crate::bail_invalid_estim!(
3343                "measure-jet ψ write-back produced non-finite dials (alpha={next_alpha}, tau={next_tau})"
3344            );
3345        }
3346        if next_alpha != mj.alpha {
3347            mj.alpha = next_alpha;
3348            changed = true;
3349        }
3350        if next_tau != mj.tau0 {
3351            mj.tau0 = next_tau;
3352            changed = true;
3353        }
3354    }
3355    Ok(changed)
3356}
3357
3358/// Collection-level measure-jet dial write-back (the `apply_tospec` /
3359/// realizer-side entry). Returns whether anything moved.
3360pub fn set_measure_jet_psi_dials(
3361    spec: &mut TermCollectionSpec,
3362    term_idx: usize,
3363    psi: &[f64],
3364) -> Result<bool, EstimationError> {
3365    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3366        crate::bail_invalid_estim!("measure-jet ψ write-back: term index {term_idx} out of range");
3367    };
3368    set_single_term_measure_jet_psi_dials(term, psi)
3369}
3370
3371/// Single-term dial write-back: the shared match+apply core, also used
3372/// directly on the cached per-trial build spec (whose caller has already
3373/// change-checked at the collection level and rebuilds regardless of the
3374/// moved flag).
3375pub fn set_single_term_measure_jet_psi_dials(
3376    term: &mut SmoothTermSpec,
3377    psi: &[f64],
3378) -> Result<bool, EstimationError> {
3379    let SmoothBasisSpec::MeasureJet { spec: mj, .. } = &mut term.basis else {
3380        crate::bail_invalid_estim!("measure-jet ψ write-back targeted a non-measure-jet term");
3381    };
3382    apply_measure_jet_psi(mj, psi)
3383}
3384
3385/// The constant-curvature smooth's spec, when `term_idx` is one. Single
3386/// accessor for every κ-ψ dispatch below, mirroring `measure_jet_term_spec`.
3387pub fn constant_curvature_term_spec(
3388    spec: &TermCollectionSpec,
3389    term_idx: usize,
3390) -> Option<&crate::basis::ConstantCurvatureBasisSpec> {
3391    spec.smooth_terms
3392        .get(term_idx)
3393        .and_then(|term| match &term.basis {
3394            SmoothBasisSpec::ConstantCurvature { spec, .. } => Some(spec),
3395            _ => None,
3396        })
3397}
3398
/// Fraction of the chart-edge curvature `1/R²` that |κ| may reach, where `R²`
/// is the data's maximum squared chart radius.
///
/// The κ-stereographic chart is valid for `1 + κ‖x‖² > 0`. At `|κ| = 1/R²`
/// the gauge `1 + κ‖x‖²` reaches the chart edge for the farthest data point,
/// so the optimizer is boxed to this safe fraction of that scale on both
/// sides. The window is centred on κ = 0 (flat), an interior point of the
/// `S^d ← ℝ^d → H^d` family — exactly the reachability the raw-κ (not log-κ)
/// coordinate exists to preserve.
pub const CONSTANT_CURVATURE_KAPPA_CHART_FRACTION: f64 = 0.5;
3407
/// Smallest squared chart radius used when scaling the κ window.
///
/// Acts as a floor so that a degenerate (near-origin) point cloud still
/// produces a finite, usable κ bracket instead of an unbounded one.
pub const CONSTANT_CURVATURE_MIN_CHART_RADIUS2: f64 = 1e-8;
3412
3413/// `(κ_min, κ_max)` outer-optimization window for a constant-curvature term,
3414/// derived from the data's maximum squared chart radius `R²` so the κ-jets
3415/// never leave the κ-stereographic chart. Symmetric about κ = 0:
3416/// `±CONSTANT_CURVATURE_KAPPA_CHART_FRACTION / R²`.
3417pub fn constant_curvature_kappa_bounds(
3418    data: ArrayView2<'_, f64>,
3419    spec: &TermCollectionSpec,
3420    term_idx: usize,
3421) -> (f64, f64) {
3422    let feature_cols = match spec.smooth_terms.get(term_idx).map(|t| &t.basis) {
3423        Some(SmoothBasisSpec::ConstantCurvature { feature_cols, .. }) => feature_cols,
3424        _ => return (-1.0, 1.0),
3425    };
3426    let mut max_r2 = CONSTANT_CURVATURE_MIN_CHART_RADIUS2;
3427    for row in data.outer_iter() {
3428        let mut r2 = 0.0_f64;
3429        for &c in feature_cols.iter() {
3430            if let Some(&v) = row.get(c)
3431                && v.is_finite()
3432            {
3433                r2 += v * v;
3434            }
3435        }
3436        if r2 > max_r2 {
3437            max_r2 = r2;
3438        }
3439    }
3440    let half = CONSTANT_CURVATURE_KAPPA_CHART_FRACTION / max_r2;
3441    (-half, half)
3442}
3443
3444/// Write the optimized κ back into a constant-curvature term spec. Returns
3445/// `true` when κ moved. Centers, ℓ, and the constraint transform `z` are
3446/// κ-FIXED by the basis κ-contract, so only `kappa` changes.
3447pub fn set_constant_curvature_kappa(
3448    spec: &mut TermCollectionSpec,
3449    term_idx: usize,
3450    psi: &[f64],
3451) -> Result<bool, EstimationError> {
3452    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3453        crate::bail_invalid_estim!(
3454            "constant-curvature κ write-back: term index {term_idx} out of range"
3455        );
3456    };
3457    set_single_term_constant_curvature_kappa(term, psi)
3458}
3459
3460/// Single-term κ write-back: the shared validate+apply core, also used directly
3461/// on the cached per-trial build spec in the incremental realizer (whose caller
3462/// has already change-checked at the collection level and rebuilds regardless
3463/// of the moved flag). Mirrors [`set_single_term_measure_jet_psi_dials`].
3464pub fn set_single_term_constant_curvature_kappa(
3465    term: &mut SmoothTermSpec,
3466    psi: &[f64],
3467) -> Result<bool, EstimationError> {
3468    if psi.len() != 1 {
3469        crate::bail_invalid_estim!(
3470            "constant-curvature κ write-back expects exactly one value, got {}",
3471            psi.len()
3472        );
3473    }
3474    let next_kappa = psi[0];
3475    if !next_kappa.is_finite() {
3476        crate::bail_invalid_estim!(
3477            "constant-curvature κ write-back produced a non-finite κ = {next_kappa}"
3478        );
3479    }
3480    let SmoothBasisSpec::ConstantCurvature { spec: cc, .. } = &mut term.basis else {
3481        crate::bail_invalid_estim!(
3482            "constant-curvature κ write-back targeted a non-constant-curvature term"
3483        );
3484    };
3485    if cc.kappa != next_kappa {
3486        cc.kappa = next_kappa;
3487        Ok(true)
3488    } else {
3489        Ok(false)
3490    }
3491}
3492
3493/// Returns `true` when a spatial term has NO outer optimization axes — i.e.
3494/// the user provided an explicit `length_scale` and the term does not enroll
3495/// REML-side per-axis ψ contrasts, so both the scalar κ and any fixed geometry
3496/// anisotropy are anchored.
3497///
3498/// This is the per-term predicate that distinguishes "fixed kernel scale"
3499/// from "optimize the kernel scale" within the family entry points that
3500/// want to honor an explicit user-supplied scale (e.g. Bernoulli
3501/// marginal-slope, where the joint-spatial outer solver otherwise spends
3502/// ~80 iters stalled on the user's chosen ρ at high gradient).
3503pub fn spatial_term_has_locked_kappa(spec: &TermCollectionSpec, term_idx: usize) -> bool {
3504    get_spatial_length_scale(spec, term_idx).is_some()
3505        && !spatial_term_uses_per_axis_psi(spec, term_idx)
3506}
3507
3508pub fn all_spatial_terms_kappa_fixed(spec: &TermCollectionSpec) -> bool {
3509    spec.smooth_terms.iter().enumerate().all(|(idx, _)| {
3510        !spatial_term_supports_hyper_optimization(spec, idx)
3511            || spatial_term_has_locked_kappa(spec, idx)
3512    })
3513}
3514
3515pub fn spatial_identifiability_policy(termspec: &SmoothTermSpec) -> Option<&SpatialIdentifiability> {
3516    match &termspec.basis {
3517        SmoothBasisSpec::ThinPlate { spec, .. } => Some(&spec.identifiability),
3518        SmoothBasisSpec::Duchon { spec, .. } => Some(&spec.identifiability),
3519        _ => None,
3520    }
3521}
3522
/// Standard deviation of the wide, weakly-informative symmetric `Normal`
/// prior that a well-determined fit places on the `DoublePenaltyNullspace`
/// selection coordinate of a relaxable double-penalty smooth.
pub const NULLSPACE_WELLDET_DEGENERACY_RHO_SD: f64 = 15.0;
3527
3528/// True iff `prior` is the well-determined double-penalty null-space
3529/// degeneracy prior placed on a `DoublePenaltyNullspace` selection coordinate.
3530pub fn is_nullspace_degeneracy_prior(prior: &gam_spec::RhoPrior) -> bool {
3531    matches!(
3532        prior,
3533        gam_spec::RhoPrior::Normal { mean, sd }
3534            if *mean == 0.0 && *sd == NULLSPACE_WELLDET_DEGENERACY_RHO_SD
3535    )
3536}
3537
/// Numerator of the lower edge of the data-derived κ window used by
/// [`spatial_term_psi_bounds`]: `κ_lo = KERNEL_RANGE_MIN_DIAMETER_FRACTION / r_max`.
///
/// `r_max` is the maximum pairwise distance of the term's resolved centers
/// (post-fit) or of the standardized feature data columns (pre-fit). Together
/// with [`KERNEL_RANGE_MAX_SPACING_MULTIPLE`] this gives the per-term
/// ψ = log κ window, the same safe operating range documented in
/// [`crate::basis::build_matern_basis`] / [`crate::basis::build_duchon_basis`]:
///   κ ∈ [2 / r_max, 1e2 / r_min]
///
/// Since κ = 1 / length_scale, flooring κ at `2 / r_max` caps the length scale
/// at half the maximum pairwise distance, i.e. at the diameter scale of the
/// point cloud.
pub const KERNEL_RANGE_MIN_DIAMETER_FRACTION: f64 = 2.0;
3550
/// Numerator of the upper edge of the data-derived κ window used by
/// [`spatial_term_psi_bounds`]: `κ_hi = KERNEL_RANGE_MAX_SPACING_MULTIPLE / r_min`,
/// with `r_min` the minimum pairwise distance.
///
/// κ is capped at `100 / r_min` to keep the basis geometry well-conditioned.
pub const KERNEL_RANGE_MAX_SPACING_MULTIPLE: f64 = 1e2;
3556
3557
3558/// Returns ψ-space bounds (ψ_lo = ln(κ_lo), ψ_hi = ln(κ_hi)).
3559///
3560/// When geometry is unavailable (e.g., fewer than 2 distinct points), falls
3561/// back to the scalar `options.min_length_scale` / `options.max_length_scale`
3562/// window so the outer optimizer never sees NaN bounds.
3563///
3564/// The returned window is intersected with the options window so user-set
3565/// `min_length_scale` / `max_length_scale` remain hard limits.
3566pub fn spatial_term_psi_bounds(
3567    data: ArrayView2<'_, f64>,
3568    spec: &TermCollectionSpec,
3569    term_idx: usize,
3570    options: &SpatialLengthScaleOptimizationOptions,
3571) -> (f64, f64) {
3572    let fallback = (
3573        -options.max_length_scale.ln(),
3574        -options.min_length_scale.ln(),
3575    );
3576    // Constant-curvature: the ψ coordinate is the raw signed κ, so its window is
3577    // the chart-feasible κ bracket, NOT a log-ℓ window. Mirrors the aniso bounds
3578    // path's `constant_curvature_kappa_bounds` branch so the isotropic
3579    // (non-aniso) seed clamp projects κ into the right interval.
3580    if constant_curvature_term_spec(spec, term_idx).is_some() {
3581        return constant_curvature_kappa_bounds(data, spec, term_idx);
3582    }
3583    let Some(term) = spec.smooth_terms.get(term_idx) else {
3584        return fallback;
3585    };
3586    // Prefer resolved centers (post-fit) since they live in the same standardized
3587    // space the kernel actually sees. Centers are capped at `default_num_centers`
3588    // (<=2000), so exact pairwise bounds are cheap (<4M ops). If centers are
3589    // not yet UserProvided, fall back to the standardized feature data columns
3590    // with the capped-sample path (O(K²·d), K=1024) — the sample is
3591    // conservative for κ bounds (see `pairwise_distance_bounds_sampled`
3592    // docs): it never excludes a feasible κ the exact method would include.
3593    //
3594    // Under anisotropy the kernel metric is y-space (y_a = exp(η_a) x_a),
3595    // so r_min/r_max must be y-space distances. This matters only when the
3596    // spec already carries calibrated η_a at setup time (e.g., warm-start
3597    // or refit paths); for fresh optimization η_a starts at 0 and y = x.
3598    let aniso = get_spatial_aniso_log_scales(spec, term_idx);
3599    let r_bounds = match spatial_term_center_strategy(term) {
3600        Some(CenterStrategy::UserProvided(centers)) if centers.nrows() >= 2 => {
3601            match aniso.as_deref() {
3602                Some(eta) if eta.len() == centers.ncols() => {
3603                    let y = points_in_aniso_y_space(centers.view(), eta);
3604                    pairwise_distance_bounds(y.view())
3605                }
3606                _ => pairwise_distance_bounds(centers.view()),
3607            }
3608        }
3609        _ => standardized_spatial_term_data(data, term)
3610            .ok()
3611            .and_then(|x| match aniso.as_deref() {
3612                Some(eta) if eta.len() == x.ncols() => {
3613                    let y = points_in_aniso_y_space(x.view(), eta);
3614                    pairwise_distance_bounds_sampled(y.view())
3615                }
3616                _ => pairwise_distance_bounds_sampled(x.view()),
3617            }),
3618    };
3619    let Some((r_min, r_max)) = r_bounds else {
3620        return fallback;
3621    };
3622    // Length scales substantially larger than the data diameter make radial
3623    // TPS/Matern columns nearly collinear with their polynomial nullspace.
3624    // The nullspace already carries constant/linear low-frequency structure,
3625    // so cap the kernel range at the diameter scale instead of letting the
3626    // optimizer enter a numerically degenerate basis geometry.
3627    let psi_lo_data = (KERNEL_RANGE_MIN_DIAMETER_FRACTION / r_max).ln();
3628    let psi_hi_data = (KERNEL_RANGE_MAX_SPACING_MULTIPLE / r_min).ln();
3629    // #1074: the Matérn-specific length-scale ceiling that used to live here was
3630    // deleted. It was masking, not fixing, the real defect: a hard upper bound on
3631    // the kernel range that pinned the κ-optimizer short rather than letting the
3632    // optimizer find the REML optimum. Matérn now shares the same generic geometry
3633    // window as Duchon / TPS (`KERNEL_RANGE_MIN_DIAMETER_FRACTION / r_max` floor,
3634    // `KERNEL_RANGE_MAX_SPACING_MULTIPLE / r_min` ceiling); the #1357 fully-flat
3635    // collapse corner is guarded by the EDF-collapse guard in
3636    // `spatial_optimization.rs`, which acts on the realized fit, not on a clamp.
3637    // Intersect with the options window so min/max_length_scale remain hard caps.
3638    let psi_lo = psi_lo_data.max(fallback.0);
3639    let psi_hi = psi_hi_data.min(fallback.1);
3640    if psi_lo >= psi_hi {
3641        // Degenerate intersection — fall back to the options window to keep the
3642        // outer optimizer from collapsing to a point.
3643        return fallback;
3644    }
3645    (psi_lo, psi_hi)
3646}
3647
3648/// Data-derived ψ seed for a spatial term when the user has not set an
3649/// explicit length_scale on its basis spec. Uses the geometric mean of the
3650/// data-informed kappa range (i.e., the midpoint of the ψ window).
3651pub fn spatial_term_psi_seed(
3652    data: ArrayView2<'_, f64>,
3653    spec: &TermCollectionSpec,
3654    term_idx: usize,
3655    options: &SpatialLengthScaleOptimizationOptions,
3656) -> Option<f64> {
3657    if get_spatial_length_scale(spec, term_idx).is_some() {
3658        return None; // user/spec-provided length_scale wins
3659    }
3660    let (psi_lo, psi_hi) = spatial_term_psi_bounds(data, spec, term_idx, options);
3661    Some(0.5 * (psi_lo + psi_hi))
3662}
3663
/// Splits a ψ vector into a global length scale and centered per-axis
/// contrasts.
///
/// * Zero or one entry: the length scale is `exp(-ψ)` (ψ defaults to `0.0`
///   for an empty slice) and there are no contrasts.
/// * Two or more entries: the length scale is `exp(-ψ̄)` with `ψ̄` the mean,
///   and the contrasts are `ψ_a − ψ̄`.
pub fn spatial_term_psi_to_length_scale_and_aniso(psi: &[f64]) -> (Option<f64>, Option<Vec<f64>>) {
    let axes = psi.len();
    if axes > 1 {
        let mean = psi.iter().sum::<f64>() / axes as f64;
        let contrasts: Vec<f64> = psi.iter().map(|&value| value - mean).collect();
        return (Some((-mean).exp()), Some(contrasts));
    }
    let scalar = match psi.first() {
        Some(&value) => value,
        None => 0.0,
    };
    (Some((-scalar).exp()), None)
}
3675
3676/// Get the `aniso_log_scales` from a spatial term, if present.
3677pub fn get_spatial_aniso_log_scales(
3678    spec: &TermCollectionSpec,
3679    term_idx: usize,
3680) -> Option<Vec<f64>> {
3681    spec.smooth_terms
3682        .get(term_idx)
3683        .and_then(|term| match &term.basis {
3684            SmoothBasisSpec::Matern { spec, .. } => spec.aniso_log_scales.clone(),
3685            SmoothBasisSpec::Duchon { spec, .. } => spec.aniso_log_scales.clone(),
3686            _ => None,
3687        })
3688}
3689
3690/// Per-axis response-structure score for anisotropy seeding.
3691///
3692/// For each spatial axis `a`, sort the response `y` by the axis coordinate
3693/// `x_a` and measure the total squared successive variation of the sorted
3694/// response, `tv_a = Σ_i (y_{σ(i+1)} − y_{σ(i)})²` where `σ` orders rows by
3695/// `x_a`. An axis that carries real (possibly nonlinear) signal makes `y` vary
3696/// SMOOTHLY when the rows are walked in that axis's order, so `tv_a` is SMALL;
3697/// a pure-nuisance axis leaves `y` looking unordered, so `tv_a` is LARGE.
3698///
3699/// This deliberately does NOT use a linear correlation `corr(x_a, y)`: for an
3700/// odd, symmetric signal such as `sin(2·x1)` over a symmetric domain the linear
3701/// correlation is ~0 on the *signal* axis, which would misdirect the seed. The
3702/// total-variation-of-sorted-response score captures nonlinear association.
3703///
3704/// Returns `score_a = −½·ln(tv_a + ε)` (larger ⇒ more signal on axis `a`),
3705/// centered to sum to zero, or `None` when the data is degenerate (too few
3706/// rows, non-finite, or all axes equally (un)structured). The caller adds a
3707/// BOUNDED multiple of this to the geometry seed — it is a conservative nudge,
3708/// never a hard override.
3709pub fn response_aware_axis_contrasts(
3710    x: ndarray::ArrayView2<'_, f64>,
3711    y: ndarray::ArrayView1<'_, f64>,
3712) -> Option<Vec<f64>> {
3713    let n = x.nrows();
3714    let d = x.ncols();
3715    if d <= 1 || n < 4 || y.len() != n {
3716        return None;
3717    }
3718    if x.iter().any(|v| !v.is_finite()) || y.iter().any(|v| !v.is_finite()) {
3719        return None;
3720    }
3721    let mut scores = Vec::with_capacity(d);
3722    for a in 0..d {
3723        let mut order: Vec<usize> = (0..n).collect();
3724        let col = x.column(a);
3725        order.sort_by(|&i, &j| {
3726            col[i]
3727                .partial_cmp(&col[j])
3728                .unwrap_or(std::cmp::Ordering::Equal)
3729        });
3730        let mut tv = 0.0_f64;
3731        for w in order.windows(2) {
3732            let diff = y[w[1]] - y[w[0]];
3733            tv += diff * diff;
3734        }
3735        // ε guards against ln(0) on a perfectly flat / constant response.
3736        scores.push(-0.5 * (tv + 1e-12).ln());
3737    }
3738    if scores.iter().any(|v| !v.is_finite()) {
3739        return None;
3740    }
3741    let mean = scores.iter().sum::<f64>() / d as f64;
3742    let centered: Vec<f64> = scores.iter().map(|&s| s - mean).collect();
3743    // If every axis is equally structured the centered scores are ~0 and the
3744    // nudge is a no-op — return None so the geometry seed is used unchanged.
3745    if centered.iter().all(|&v| v.abs() < 1e-9) {
3746        return None;
3747    }
3748    Some(centered)
3749}
3750
3751/// Conservative, response-aware anisotropy seed nudge applied before the κ outer
3752/// loop. For each anisotropic spatial term it adds a BOUNDED multiple of the
3753/// per-axis response-structure contrast (`response_aware_axis_contrasts`) on top
3754/// of the existing geometry seed, so the optimizer starts in the correct basin
3755/// instead of at a response-blind near-symmetric point (the #1376 under-recovery
3756/// where a signal axis and a nuisance axis with equal coordinate spread seed to
3757/// ~[0,0]). The nudge is clamped to keep this a perturbation, never a hard
3758/// override, so shared aniso Matérn/Duchon fits cannot be destabilized by it.
3759pub fn apply_response_aware_anisotropy_seed(
3760    data: ArrayView2<'_, f64>,
3761    y: ndarray::ArrayView1<'_, f64>,
3762    spec: &mut TermCollectionSpec,
3763    spatial_terms: &[usize],
3764) {
3765    // Bound on the per-axis contrast nudge (in η units). One LN_2 ≈ 0.69 halves
3766    // the effective per-axis length scale; capping at LN_2 keeps the seed within
3767    // one optimizer log-step of the geometry seed while still breaking the
3768    // symmetric-seed trap.
3769    const MAX_NUDGE: f64 = std::f64::consts::LN_2;
3770    for &term_idx in spatial_terms {
3771        let Some(current_eta) = get_spatial_aniso_log_scales(spec, term_idx) else {
3772            continue;
3773        };
3774        let d = current_eta.len();
3775        if d <= 1 {
3776            continue;
3777        }
3778        let Some(term) = spec.smooth_terms.get(term_idx) else {
3779            continue;
3780        };
3781        let feature_cols = term.basis.structural_feature_cols();
3782        if feature_cols.len() != d {
3783            continue;
3784        }
3785        let Ok(x) = select_columns(data, &feature_cols) else {
3786            continue;
3787        };
3788        let Some(contrast) = response_aware_axis_contrasts(x.view(), y) else {
3789            continue;
3790        };
3791        let nudged: Vec<f64> = current_eta
3792            .iter()
3793            .zip(contrast.iter())
3794            .map(|(&eta_a, &c_a)| eta_a + c_a.clamp(-MAX_NUDGE, MAX_NUDGE))
3795            .collect();
3796        // `set_spatial_aniso_log_scales` re-centers to Σ η = 0. A term that does
3797        // not support aniso scales is silently skipped (the seed is optional).
3798        if let Err(err) = set_spatial_aniso_log_scales(spec, term_idx, nudged) {
3799            log::debug!(
3800                "[spatial-kappa] response-aware anisotropy seed skipped for term {term_idx}: {err}"
3801            );
3802        }
3803    }
3804}
3805
3806/// Get the number of feature columns (spatial dimensionality) for a spatial term.
3807pub fn get_spatial_feature_dim(spec: &TermCollectionSpec, term_idx: usize) -> Option<usize> {
3808    spec.smooth_terms
3809        .get(term_idx)
3810        .and_then(|term| match &term.basis {
3811            SmoothBasisSpec::ThinPlate { feature_cols, .. } => Some(feature_cols.len()),
3812            SmoothBasisSpec::Matern { feature_cols, .. } => Some(feature_cols.len()),
3813            SmoothBasisSpec::Duchon { feature_cols, .. } => Some(feature_cols.len()),
3814            _ => None,
3815        })
3816}
3817
3818/// Log the learned per-axis spatial anisotropy for all spatial terms that
3819/// have `aniso_log_scales` set after optimization.
3820///
3821/// For scalar-scale families this reports eta, effective per-axis length
3822/// scales, and per-axis kappa values. For pure Duchon it reports the centered
3823/// eta contrasts only.
3824pub fn log_spatial_aniso_scales(spec: &TermCollectionSpec) {
3825    for (term_idx, term) in spec.smooth_terms.iter().enumerate() {
3826        let (aniso, length_scale) = match &term.basis {
3827            SmoothBasisSpec::Matern { spec, .. } => {
3828                (spec.aniso_log_scales.as_ref(), Some(spec.length_scale))
3829            }
3830            SmoothBasisSpec::Duchon { spec, .. } => {
3831                (spec.aniso_log_scales.as_ref(), spec.length_scale)
3832            }
3833            _ => (None, None),
3834        };
3835        let Some(eta) = aniso else { continue };
3836        if eta.is_empty() {
3837            continue;
3838        }
3839        let mut lines = match length_scale {
3840            Some(ls) => format!(
3841                "[spatial-kappa] term {} (\"{}\"): anisotropic length scales optimized (global length_scale={:.4})",
3842                term_idx, term.name, ls
3843            ),
3844            None => format!(
3845                "[spatial-kappa] term {} (\"{}\"): pure Duchon shape anisotropy optimized",
3846                term_idx, term.name
3847            ),
3848        };
3849        for (a, &eta_a) in eta.iter().enumerate() {
3850            if let Some(ls) = length_scale {
3851                let length_a = ls * (-eta_a).exp();
3852                let kappa_a = (1.0 / ls) * eta_a.exp();
3853                lines.push_str(&format!(
3854                    "\n  axis {}: eta={:+.4}, length={:.4}, kappa={:.4}",
3855                    a, eta_a, length_a, kappa_a
3856                ));
3857            } else {
3858                lines.push_str(&format!("\n  axis {}: eta={:+.4}", a, eta_a));
3859            }
3860        }
3861        log::info!("{}", lines);
3862    }
3863}
3864
3865/// Set `aniso_log_scales` on a spatial term's basis spec.
3866pub fn set_spatial_aniso_log_scales(
3867    spec: &mut TermCollectionSpec,
3868    term_idx: usize,
3869    eta: Vec<f64>,
3870) -> Result<(), EstimationError> {
3871    let eta = center_aniso_log_scales(&eta);
3872    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3873        crate::bail_invalid_estim!("spatial aniso_log_scales term index {term_idx} out of range");
3874    };
3875    match &mut term.basis {
3876        SmoothBasisSpec::Matern { spec, .. } => {
3877            spec.aniso_log_scales = Some(eta);
3878            Ok(())
3879        }
3880        SmoothBasisSpec::Duchon { spec, .. } => {
3881            spec.aniso_log_scales = Some(eta);
3882            Ok(())
3883        }
3884        _ => Err(EstimationError::InvalidInput(format!(
3885            "term '{}' does not support aniso_log_scales",
3886            term.name
3887        ))),
3888    }
3889}
3890
3891/// Sync knot-cloud-derived anisotropy contrasts from basis metadata back into
3892/// the mutable spec so the optimizer starts from the correct eta values.
3893///
3894/// Call this after building the smooth design but before initializing the
3895/// optimizer's psi coordinates. For each spatial term whose metadata contains
3896/// computed `aniso_log_scales`, this writes them into the spec.
3897pub fn sync_aniso_contrasts_from_metadata(
3898    spec: &mut TermCollectionSpec,
3899    design: &SmoothDesign,
3900) {
3901    for (term_idx, term) in design.terms.iter().enumerate() {
3902        let meta_aniso = match &term.metadata {
3903            BasisMetadata::Matern {
3904                aniso_log_scales, ..
3905            } => aniso_log_scales.clone(),
3906            BasisMetadata::Duchon {
3907                aniso_log_scales, ..
3908            } => aniso_log_scales.clone(),
3909            _ => None,
3910        };
3911        if let Some(eta) = meta_aniso
3912            && eta.len() > 1
3913        {
3914            set_spatial_aniso_log_scales(spec, term_idx, eta).ok();
3915        }
3916    }
3917}
3918
/// Settings for the outer-loop search over the spatial kernel scale
/// κ (= 1 / length_scale).
#[derive(Debug, Clone)]
pub struct SpatialLengthScaleOptimizationOptions {
    /// Turns on outer-loop optimization over spatial κ for the supported
    /// radial-kernel smooths: ThinPlate, Matérn, and Duchon terms.
    pub enabled: bool,
    /// Cap on the number of outer iterations in the exact joint [rho, psi]
    /// solve.
    pub max_outer_iter: usize,
    /// Relative improvement below which the outer solve terminates.
    pub rel_tol: f64,
    /// Initial perturbation of log(length_scale) used when constructing seeds.
    pub log_step: f64,
    /// Smallest length_scale the κ search may use.
    pub min_length_scale: f64,
    /// Largest length_scale the κ search may use.
    pub max_length_scale: f64,
    /// Threshold for the automatic geometry initializer on large spatial fits.
    ///
    /// Once n exceeds twice this value, the fitter draws a spatially
    /// stratified subsample whose only job is to seed κ/anisotropy geometry:
    /// centers are resolved, axis contrasts are initialized from center/data
    /// spread, and one or two cheap ψ reseeding updates are applied. PIRLS,
    /// REML, ARC, BFGS, and every other recursive optimizer are never run on
    /// the pilot.
    ///
    /// Final coefficients, smoothing parameters, and spatial geometry are
    /// always optimized on the full dataset.
    ///
    /// A value of 0 skips the pilot geometry initializer.
    pub pilot_subsample_threshold: usize,
    /// Optional wall-clock budget, in seconds, for the whole outer smoothing
    /// search (gam#979).
    ///
    /// When a family arms the global deadline from this value, an outer search
    /// that cannot certify convergence (survival marginal-slope's
    /// monotonicity-pinned constrained joint-Newton) returns its best-so-far
    /// iterate (or a catchable error) within the budget instead of hanging.
    /// `None` keeps the legacy unbounded behavior; the survival marginal-slope
    /// path applies a generous default in that case.
    pub outer_wall_clock_budget_secs: Option<f64>,
}
3957
3958impl Default for SpatialLengthScaleOptimizationOptions {
3959    fn default() -> Self {
3960        Self {
3961            enabled: true,
3962            max_outer_iter: 80,
3963            rel_tol: 1e-4,
3964            log_step: std::f64::consts::LN_2,
3965            min_length_scale: 1e-3,
3966            max_length_scale: 1e3,
3967            pilot_subsample_threshold: 10_000,
3968            outer_wall_clock_budget_secs: None,
3969        }
3970    }
3971}
3972
3973impl SpatialLengthScaleOptimizationOptions {
3974    /// Validate the struct's invariants. Callers that construct these options
3975    /// from external input (CLI, config, Python API) should call this before
3976    /// passing the options into the fitter. Returns `Err` with a descriptive
3977    /// message when an invariant is violated; the fitter then panics or
3978    /// returns `EstimationError` at its own boundary.
3979    ///
3980    /// Invariants:
3981    ///   * `min_length_scale > 0`, finite
3982    ///   * `max_length_scale > 0`, finite
3983    ///   * `min_length_scale < max_length_scale`
3984    ///   * `rel_tol > 0`, finite
3985    ///   * `log_step > 0`, finite
3986    ///
3987    /// These invariants are what the downstream κ-bound and ψ-window code
3988    /// assumes (`-log(max_ls)` must be finite, `(min,max)` must not be
3989    /// inverted, etc.). Without validation, invalid options produce silent
3990    /// NaN-propagation inside the outer optimizer.
3991    pub fn validate(&self) -> Result<(), String> {
3992        if !self.min_length_scale.is_finite() || self.min_length_scale <= 0.0 {
3993            return Err(SmoothError::invalid_config(format!(
3994                "SpatialLengthScaleOptimizationOptions::min_length_scale must be > 0 and finite, got {}",
3995                self.min_length_scale
3996            ))
3997            .into());
3998        }
3999        if !self.max_length_scale.is_finite() || self.max_length_scale <= 0.0 {
4000            return Err(SmoothError::invalid_config(format!(
4001                "SpatialLengthScaleOptimizationOptions::max_length_scale must be > 0 and finite, got {}",
4002                self.max_length_scale
4003            ))
4004            .into());
4005        }
4006        if self.min_length_scale >= self.max_length_scale {
4007            return Err(SmoothError::invalid_config(format!(
4008                "SpatialLengthScaleOptimizationOptions requires min_length_scale < max_length_scale, got min={} max={}",
4009                self.min_length_scale, self.max_length_scale
4010            ))
4011            .into());
4012        }
4013        if !self.rel_tol.is_finite() || self.rel_tol <= 0.0 {
4014            return Err(SmoothError::invalid_config(format!(
4015                "SpatialLengthScaleOptimizationOptions::rel_tol must be > 0 and finite, got {}",
4016                self.rel_tol
4017            ))
4018            .into());
4019        }
4020        if !self.log_step.is_finite() || self.log_step <= 0.0 {
4021            return Err(SmoothError::invalid_config(format!(
4022                "SpatialLengthScaleOptimizationOptions::log_step must be > 0 and finite, got {}",
4023                self.log_step
4024            ))
4025            .into());
4026        }
4027        Ok(())
4028    }
4029}
4030
/// A named random-effect block described by per-observation group labels.
#[derive(Debug, Clone)]
pub struct RandomEffectBlock {
    pub name: String,
    /// One label per observation (O(n) storage): `group_ids[i]` is the column
    /// index in `[0, num_groups)`, or `None` when the observation's level is
    /// not in the kept set.
    pub group_ids: Vec<Option<usize>>,
    pub num_groups: usize,
    pub kept_levels: Vec<u64>,
}
4040
/// Magnitude threshold for sparse assembly: an entry counts as non-zero only
/// when its absolute value is strictly greater than this.
pub const BLOCK_SPARSE_ZERO_EPS: f64 = 1e-12;
4042
/// Largest fraction of non-zero cells (`nnz / (nrows · ncols)`) at which a
/// design with no intrinsically sparse block is still assembled as sparse.
pub const BLOCK_SPARSE_MAX_DENSITY: f64 = 0.2;
4044
4045pub fn blocks_have_intrinsic_sparse_structure(blocks: &[DesignBlock]) -> bool {
4046    blocks
4047        .iter()
4048        .any(|block| matches!(block, DesignBlock::Sparse(_) | DesignBlock::RandomEffect(_)))
4049}
4050
4051pub fn sparse_compatible_block_nnz(block: &DesignBlock) -> Option<usize> {
4052    match block {
4053        DesignBlock::Intercept(n) => Some(*n),
4054        DesignBlock::RandomEffect(op) => {
4055            Some(op.group_ids.iter().filter(|gid| gid.is_some()).count())
4056        }
4057        DesignBlock::Sparse(sparse) => Some(sparse.val().len()),
4058        DesignBlock::Dense(dense) => dense.as_dense_ref().map(|matrix| {
4059            matrix
4060                .iter()
4061                .filter(|&&value| value.abs() > BLOCK_SPARSE_ZERO_EPS)
4062                .count()
4063        }),
4064    }
4065}
4066
4067pub fn try_build_sparse_design_from_blocks(
4068    blocks: &[DesignBlock],
4069) -> Result<Option<DesignMatrix>, BasisError> {
4070    if blocks.is_empty() {
4071        return Ok(None);
4072    }
4073    let nrows = blocks[0].nrows();
4074    let ncols: usize = blocks.iter().map(DesignBlock::ncols).sum();
4075    if nrows == 0 || ncols == 0 || ncols <= 32 {
4076        return Ok(None);
4077    }
4078
4079    let preserve_sparse_storage = blocks_have_intrinsic_sparse_structure(blocks);
4080    let sparse_nnz_limit = if preserve_sparse_storage {
4081        usize::MAX
4082    } else {
4083        let total_cells = nrows.saturating_mul(ncols);
4084        ((total_cells as f64) * BLOCK_SPARSE_MAX_DENSITY).floor() as usize
4085    };
4086    let mut nnz = 0usize;
4087    for block in blocks {
4088        let block_nnz = if let Some(block_nnz) = sparse_compatible_block_nnz(block) {
4089            block_nnz
4090        } else {
4091            return Ok(None);
4092        };
4093        nnz = nnz.saturating_add(block_nnz);
4094        if nnz > sparse_nnz_limit {
4095            return Ok(None);
4096        }
4097    }
4098
4099    let mut triplets = Vec::<Triplet<usize, usize, f64>>::with_capacity(nnz);
4100    let mut col_offset = 0usize;
4101    for block in blocks {
4102        match block {
4103            DesignBlock::Intercept(n) => {
4104                for row in 0..*n {
4105                    triplets.push(Triplet::new(row, col_offset, 1.0));
4106                }
4107            }
4108            DesignBlock::RandomEffect(op) => {
4109                for (row, group_id) in op.group_ids.iter().enumerate() {
4110                    if let Some(group) = group_id {
4111                        triplets.push(Triplet::new(row, col_offset + group, 1.0));
4112                    }
4113                }
4114            }
4115            DesignBlock::Sparse(sparse) => {
4116                let (symbolic, values) = sparse.parts();
4117                let col_ptr = symbolic.col_ptr();
4118                let row_idx = symbolic.row_idx();
4119                for col in 0..sparse.ncols() {
4120                    for idx in col_ptr[col]..col_ptr[col + 1] {
4121                        let value = values[idx];
4122                        if value.abs() > BLOCK_SPARSE_ZERO_EPS {
4123                            triplets.push(Triplet::new(row_idx[idx], col_offset + col, value));
4124                        }
4125                    }
4126                }
4127            }
4128            DesignBlock::Dense(dense) => {
4129                let matrix = dense.as_dense_ref().ok_or_else(|| {
4130                    BasisError::InvalidInput(
4131                        "sparse-compatible block assembly requires materialized dense blocks"
4132                            .to_string(),
4133                    )
4134                })?;
4135                for row in 0..matrix.nrows() {
4136                    for col in 0..matrix.ncols() {
4137                        let value = matrix[[row, col]];
4138                        if value.abs() > BLOCK_SPARSE_ZERO_EPS {
4139                            triplets.push(Triplet::new(row, col_offset + col, value));
4140                        }
4141                    }
4142                }
4143            }
4144        }
4145        col_offset += block.ncols();
4146    }
4147
4148    let sparse = SparseColMat::try_new_from_triplets(nrows, ncols, &triplets).map_err(|_| {
4149        BasisError::SparseCreation("failed to assemble sparse term-collection design".to_string())
4150    })?;
4151    Ok(Some(DesignMatrix::Sparse(
4152        gam_linalg::matrix::SparseDesignMatrix::new(sparse),
4153    )))
4154}
4155
4156pub fn assemble_term_collection_design_matrix(
4157    blocks: Vec<DesignBlock>,
4158) -> Result<DesignMatrix, BasisError> {
4159    if let Some(sparse) = try_build_sparse_design_from_blocks(&blocks)? {
4160        return Ok(sparse);
4161    }
4162    let block_op = BlockDesignOperator::new(blocks).map_err(|e| {
4163        BasisError::InvalidInput(format!("failed to build block design operator: {e}"))
4164    })?;
4165    Ok(DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
4166        Arc::new(block_op),
4167    )))
4168}
4169
4170pub fn select_columns(data: ArrayView2<'_, f64>, cols: &[usize]) -> Result<Array2<f64>, BasisError> {
4171    let n = data.nrows();
4172    let p = data.ncols();
4173    for &c in cols {
4174        if c >= p {
4175            crate::bail_dim_basis!("feature column {c} is out of bounds for data with {p} columns");
4176        }
4177    }
4178    let mut out = Array2::<f64>::zeros((n, cols.len()));
4179    for (j, &c) in cols.iter().enumerate() {
4180        out.column_mut(j).assign(&data.column(c));
4181    }
4182    Ok(out)
4183}
4184
/// Short display label for a non-finite float: `"NaN"`, `"+Inf"` or `"-Inf"`.
///
/// Intended for values already known to be non-finite. A finite input is
/// labelled purely by its sign bit.
pub fn nonfinite_value_label(value: f64) -> &'static str {
    match value {
        v if v.is_nan() => "NaN",
        v if v.is_sign_positive() => "+Inf",
        _ => "-Inf",
    }
}
4194
4195pub fn validate_term_feature_column_finite(
4196    data: ArrayView2<'_, f64>,
4197    term_kind: &str,
4198    term_name: &str,
4199    feature_col: usize,
4200) -> Result<(), BasisError> {
4201    let p = data.ncols();
4202    if feature_col >= p {
4203        crate::bail_dim_basis!(
4204            "{term_kind} term '{term_name}' feature column {feature_col} out of bounds for {p} columns"
4205        );
4206    }
4207    for (row, &value) in data.column(feature_col).iter().enumerate() {
4208        if !value.is_finite() {
4209            crate::bail_invalid_basis!(
4210                "{term_kind} term '{term_name}' feature column {feature_col} row {row} contains non-finite value {}",
4211                nonfinite_value_label(value)
4212            );
4213        }
4214    }
4215    Ok(())
4216}
4217
4218pub fn validate_smooth_terms_finite_inputs(
4219    data: ArrayView2<'_, f64>,
4220    terms: &[SmoothTermSpec],
4221) -> Result<(), BasisError> {
4222    for term in terms {
4223        for feature_col in smooth_term_feature_cols(term) {
4224            validate_term_feature_column_finite(data, "smooth", &term.name, feature_col)?;
4225        }
4226    }
4227    Ok(())
4228}
4229
4230pub fn validate_term_collection_finite_inputs(
4231    data: ArrayView2<'_, f64>,
4232    spec: &TermCollectionSpec,
4233) -> Result<(), BasisError> {
4234    for term in &spec.linear_terms {
4235        validate_term_feature_column_finite(data, "linear", &term.name, term.feature_col)?;
4236    }
4237    for term in &spec.random_effect_terms {
4238        validate_term_feature_column_finite(data, "random-effect", &term.name, term.feature_col)?;
4239    }
4240    validate_smooth_terms_finite_inputs(data, &spec.smooth_terms)
4241}
4242
/// Ordered grouping key built by [`spatial_term_group_key`]: spatial smooth
/// terms that produce equal keys are candidates for sharing one set of centers
/// in [`plan_joint_spatial_centers_for_term_blocks`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct JointSpatialCenterGroupKey {
    /// Feature columns the term is built on (copied from the basis spec).
    feature_cols: Vec<usize>,
    /// Kind of the term's center strategy, as reported by `center_strategy_kind`.
    strategy_kind: CenterStrategyKind,
    /// Strategy-specific extra parameter: `max_iter` for k-means,
    /// `points_per_dim` for a uniform grid (also when wrapped in one `Auto`
    /// layer), `0` otherwise.
    strategy_aux: usize,
    /// Center count requested by the strategy (`center_strategy_num_centers`).
    requested_num_centers: usize,
    /// Bit patterns (`f64::to_bits`) of the explicit input scales, or `None`
    /// when the term carries no input scales. Bits are used so the key can be
    /// `Eq`/`Ord`.
    input_scale_bits: Option<Vec<u64>>,
}
4251
4252pub fn spatial_term_min_center_count(term: &SmoothTermSpec) -> usize {
4253    match &term.basis {
4254        SmoothBasisSpec::ThinPlate { feature_cols, .. } => feature_cols.len() + 1,
4255        SmoothBasisSpec::Duchon {
4256            feature_cols, spec, ..
4257        } => match spec.nullspace_order {
4258            crate::basis::DuchonNullspaceOrder::Zero => 1,
4259            crate::basis::DuchonNullspaceOrder::Linear => feature_cols.len() + 1,
4260            crate::basis::DuchonNullspaceOrder::Degree(degree) => {
4261                crate::basis::duchon_nullspace_dimension(feature_cols.len(), degree)
4262            }
4263        },
4264        SmoothBasisSpec::Matern { .. } => 1,
4265        _ => 1,
4266    }
4267}
4268
4269pub fn spatial_term_group_key(term: &SmoothTermSpec) -> Option<JointSpatialCenterGroupKey> {
4270    let (feature_cols, strategy, input_scales) = match &term.basis {
4271        SmoothBasisSpec::ThinPlate {
4272            feature_cols,
4273            spec,
4274            input_scales,
4275        } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4276        SmoothBasisSpec::Matern {
4277            feature_cols,
4278            spec,
4279            input_scales,
4280        } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4281        SmoothBasisSpec::Duchon {
4282            feature_cols,
4283            spec,
4284            input_scales,
4285        } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4286        _ => return None,
4287    };
4288    let strategy_kind = center_strategy_kind(strategy);
4289    let strategy_aux = match strategy {
4290        CenterStrategy::Auto(inner) => match inner.as_ref() {
4291            CenterStrategy::KMeans { max_iter, .. } => *max_iter,
4292            CenterStrategy::UniformGrid { points_per_dim } => *points_per_dim,
4293            _ => 0,
4294        },
4295        CenterStrategy::KMeans { max_iter, .. } => *max_iter,
4296        CenterStrategy::UniformGrid { points_per_dim } => *points_per_dim,
4297        _ => 0,
4298    };
4299    Some(JointSpatialCenterGroupKey {
4300        feature_cols: feature_cols.clone(),
4301        strategy_kind,
4302        strategy_aux,
4303        requested_num_centers: center_strategy_num_centers(strategy)?,
4304        input_scale_bits: input_scales
4305            .map(|values| values.iter().map(|value| value.to_bits()).collect()),
4306    })
4307}
4308
4309pub fn spatial_term_center_strategy(term: &SmoothTermSpec) -> Option<&CenterStrategy> {
4310    match &term.basis {
4311        SmoothBasisSpec::ThinPlate { spec, .. } => Some(&spec.center_strategy),
4312        SmoothBasisSpec::Matern { spec, .. } => Some(&spec.center_strategy),
4313        SmoothBasisSpec::Duchon { spec, .. } => Some(&spec.center_strategy),
4314        _ => None,
4315    }
4316}
4317
4318pub fn set_spatial_term_centers(
4319    term: &mut SmoothTermSpec,
4320    centers: Array2<f64>,
4321) -> Result<(), BasisError> {
4322    match &mut term.basis {
4323        SmoothBasisSpec::ThinPlate { spec, .. } => {
4324            spec.center_strategy = CenterStrategy::UserProvided(centers);
4325            Ok(())
4326        }
4327        SmoothBasisSpec::Matern { spec, .. } => {
4328            spec.center_strategy = CenterStrategy::UserProvided(centers);
4329            Ok(())
4330        }
4331        SmoothBasisSpec::Duchon { spec, .. } => {
4332            spec.center_strategy = CenterStrategy::UserProvided(centers);
4333            Ok(())
4334        }
4335        _ => Err(BasisError::InvalidInput(format!(
4336            "term '{}' does not support spatial center planning",
4337            term.name
4338        ))),
4339    }
4340}
4341
4342pub fn standardized_spatial_term_data(
4343    data: ArrayView2<'_, f64>,
4344    term: &SmoothTermSpec,
4345) -> Result<Array2<f64>, BasisError> {
4346    let (feature_cols, input_scales) = match &term.basis {
4347        SmoothBasisSpec::ThinPlate {
4348            feature_cols,
4349            input_scales,
4350            ..
4351        }
4352        | SmoothBasisSpec::Matern {
4353            feature_cols,
4354            input_scales,
4355            ..
4356        }
4357        | SmoothBasisSpec::Duchon {
4358            feature_cols,
4359            input_scales,
4360            ..
4361        } => (feature_cols, input_scales.as_ref()),
4362        _ => {
4363            crate::bail_invalid_basis!("term '{}' is not a spatial smooth", term.name);
4364        }
4365    };
4366    let mut x = select_columns(data, feature_cols)?;
4367    if let Some(scales) = input_scales {
4368        apply_input_standardization(&mut x, scales);
4369    } else if let Some(scales) = compute_spatial_input_scales(x.view()) {
4370        apply_input_standardization(&mut x, &scales);
4371    }
4372    Ok(x)
4373}
4374
4375pub fn plan_joint_spatial_centers_for_term_blocks(
4376    data: ArrayView2<'_, f64>,
4377    term_blocks: &[Vec<SmoothTermSpec>],
4378) -> Result<Vec<Vec<SmoothTermSpec>>, BasisError> {
4379    let mut planned_blocks = term_blocks.to_vec();
4380    let n = data.nrows();
4381    let mut groups: BTreeMap<JointSpatialCenterGroupKey, Vec<(usize, usize)>> = BTreeMap::new();
4382
4383    for (block_idx, terms) in planned_blocks.iter().enumerate() {
4384        for (term_idx, term) in terms.iter().enumerate() {
4385            let Some(strategy) = spatial_term_center_strategy(term) else {
4386                continue;
4387            };
4388            if !center_strategy_is_auto(strategy) {
4389                continue;
4390            }
4391            let Some(group_key) = spatial_term_group_key(term) else {
4392                continue;
4393            };
4394            if !matches!(
4395                group_key.strategy_kind,
4396                CenterStrategyKind::EqualMass
4397                    | CenterStrategyKind::EqualMassCovarRepresentative
4398                    | CenterStrategyKind::FarthestPoint
4399                    | CenterStrategyKind::KMeans
4400            ) {
4401                continue;
4402            }
4403            if center_strategy_num_centers(strategy).is_none() {
4404                continue;
4405            }
4406            groups
4407                .entry(group_key)
4408                .or_default()
4409                .push((block_idx, term_idx));
4410        }
4411    }
4412
4413    for (group_key, members) in groups {
4414        if members.len() < 2 {
4415            continue;
4416        }
4417        let min_required = members
4418            .iter()
4419            .map(|&(block_idx, term_idx)| {
4420                spatial_term_min_center_count(&planned_blocks[block_idx][term_idx])
4421            })
4422            .max()
4423            .unwrap_or(1);
4424        let joint_centers = group_key
4425            .requested_num_centers
4426            .max(min_required)
4427            .min(n.max(1));
4428        let (first_block_idx, first_term_idx) = members[0];
4429        let prototype = &planned_blocks[first_block_idx][first_term_idx];
4430        let standardized = standardized_spatial_term_data(data, prototype)?;
4431        let strategy = spatial_term_center_strategy(prototype).ok_or_else(|| {
4432            BasisError::InvalidInput(format!(
4433                "term '{}' lost its spatial center strategy during joint planning",
4434                prototype.name
4435            ))
4436        })?;
4437        let joint_strategy = center_strategy_with_num_centers(strategy, joint_centers)?;
4438        let shared_centers = select_centers_by_strategy(standardized.view(), &joint_strategy)?;
4439        log::info!(
4440            "sharing {} spatial centers across {} smooth terms over columns {:?} (requested {} centers)",
4441            shared_centers.nrows(),
4442            members.len(),
4443            group_key.feature_cols,
4444            group_key.requested_num_centers,
4445        );
4446        for (block_idx, term_idx) in members {
4447            set_spatial_term_centers(
4448                &mut planned_blocks[block_idx][term_idx],
4449                shared_centers.clone(),
4450            )?;
4451        }
4452    }
4453
4454    // Sentinel auto-init: Matern and thin-plate builders write length_scale =
4455    // 0.0 when the user didn't pass `length_scale=...`. Replace those with a
4456    // data-driven initialization here so REML starts in a regime where it can
4457    // escape; the hard-coded 1.0 default was a basin from which ν ≥ 5/2 Matern
4458    // could not recover on high-frequency truths, silently collapsing the fit
4459    // to a near-constant prediction.
4460    for block in planned_blocks.iter_mut() {
4461        for term in block.iter_mut() {
4462            auto_init_length_scale_in_place(data, term);
4463        }
4464    }
4465
4466    Ok(planned_blocks)
4467}
4468
/// Tiny positive floor for the auto length scale, guarding against a zero
/// kernel range when every feature column is (near-)constant.
///
/// The seed functions apply this floor first and then cap the result at the
/// widest column range, so a seed smaller than the floor is still possible
/// when that range itself is below the floor.
const AUTO_LENGTH_SCALE_FLOOR: f64 = 1e-6;
4472
4473/// Widest per-axis range of the selected feature columns. Returns `None` when
4474/// every selected column is constant / non-finite (no usable spatial scale).
4475fn feature_columns_max_range(data: ArrayView2<'_, f64>, feature_cols: &[usize]) -> Option<f64> {
4476    let mut max_range = 0.0_f64;
4477    for &c in feature_cols {
4478        if c >= data.ncols() {
4479            continue;
4480        }
4481        let col = data.column(c);
4482        let mut lo = f64::INFINITY;
4483        let mut hi = f64::NEG_INFINITY;
4484        for &v in col.iter() {
4485            if v.is_finite() {
4486                if v < lo {
4487                    lo = v;
4488                }
4489                if v > hi {
4490                    hi = v;
4491                }
4492            }
4493        }
4494        if hi > lo {
4495            let r = hi - lo;
4496            if r > max_range {
4497                max_range = r;
4498            }
4499        }
4500    }
4501    if max_range.is_finite() && max_range > 0.0 {
4502        Some(max_range)
4503    } else {
4504        None
4505    }
4506}
4507
4508/// Compute a data-driven initial length scale from the per-axis range of the
4509/// feature columns. The heuristic `max_range / sqrt(n)` puts the kernel on
4510/// the wiggly side of REML's basin so the optimizer can grow it back if the
4511/// signal is smooth, but is small enough that high-frequency truths remain
4512/// reachable for smoother kernels (ν ≥ 5/2). Clamped to a tiny positive
4513/// floor so degenerate constant-input columns can't produce 0.
4514pub fn auto_initial_length_scale(data: ArrayView2<'_, f64>, feature_cols: &[usize]) -> f64 {
4515    let n = data.nrows();
4516    if n == 0 || feature_cols.is_empty() {
4517        return 1.0;
4518    }
4519    let Some(max_range) = feature_columns_max_range(data, feature_cols) else {
4520        return 1.0;
4521    };
4522    let init = max_range / (n as f64).sqrt();
4523    init.max(AUTO_LENGTH_SCALE_FLOOR).min(max_range)
4524}
4525
4526/// Density-adaptive auto length scale for a kernel basis with `num_centers`
4527/// requested centers (#1731).
4528///
4529/// The plain [`auto_initial_length_scale`] seed `max_range / sqrt(n)` is the
4530/// fill distance of the *n data points*; it is independent of the requested
4531/// center count `k`. For a radial kernel at a FIXED length scale, packing more
4532/// centers into the same cloud makes neighbouring basis functions overlap and
4533/// go numerically collinear, so the realized basis saturates in rank (the
4534/// `matern_rank_reduce_centers` cap) and a richer `k` becomes a no-op — or even
4535/// shrinks the basis. The kernel stays well-conditioned only while the length
4536/// scale tracks the *center* spacing, not the data spacing.
4537///
4538/// We seed the length scale at the fill distance of `max(n, k)` points,
4539/// `max_range / sqrt(max(n, k))`. When `n ≥ k` (the usual case) this is exactly
4540/// the existing `max_range / sqrt(n)` seed, so every current result and small-`k`
4541/// basis size is preserved bit-for-bit (in every covariate dimension). When
4542/// `k > n` (a dense center request on a small cloud, the regime where an
4543/// `n`-sized seed sits above the center spacing and over-smooths the centers
4544/// into collinearity) the seed shrinks with `k` to the center spacing, keeping
4545/// the requested centers numerically independent. This is the Matérn analogue of
4546/// the Duchon-promotion "length_scale from center spacing" rule
4547/// (`hybrid_duchon_promotion_length_scale`).
4548pub fn auto_initial_length_scale_for_centers(
4549    data: ArrayView2<'_, f64>,
4550    feature_cols: &[usize],
4551    num_centers: usize,
4552) -> f64 {
4553    let n = data.nrows();
4554    if n == 0 || feature_cols.is_empty() {
4555        return 1.0;
4556    }
4557    let Some(max_range) = feature_columns_max_range(data, feature_cols) else {
4558        return 1.0;
4559    };
4560    // Resolution density: at least the data points, but no coarser than the
4561    // center spacing once more centers than data are requested. Using the same
4562    // `sqrt` fill-distance law as `auto_initial_length_scale` keeps the seed
4563    // bit-identical whenever `n ≥ num_centers` (every dimension), and only
4564    // shrinks it — never grows it — when `num_centers > n`.
4565    let resolution_points = n.max(num_centers).max(1) as f64;
4566    let spacing = max_range / resolution_points.sqrt();
4567    spacing.max(AUTO_LENGTH_SCALE_FLOOR).min(max_range)
4568}
4569
4570/// Requested center count encoded by a [`CenterStrategy`], if it carries an
4571/// explicit count (used to make the Matérn auto length scale density-adaptive).
4572fn center_strategy_requested_count(strategy: &CenterStrategy) -> Option<usize> {
4573    match strategy {
4574        CenterStrategy::Auto(inner) => center_strategy_requested_count(inner),
4575        CenterStrategy::UserProvided(centers) => Some(centers.nrows()),
4576        CenterStrategy::EqualMass { num_centers }
4577        | CenterStrategy::EqualMassCovarRepresentative { num_centers }
4578        | CenterStrategy::FarthestPoint { num_centers }
4579        | CenterStrategy::KMeans { num_centers, .. } => Some(*num_centers),
4580        CenterStrategy::UniformGrid { .. } => None,
4581    }
4582}
4583
/// Replace the auto length-scale sentinel (`0.0`) of a term, in place, with a
/// data-derived value.
///
/// Thin wrapper that forwards the term's basis to
/// [`auto_init_length_scale_in_basis`], which handles Matern and thin-plate
/// kernels (Matern uses the density-adaptive seed when a center count is
/// known) as well as kernels nested inside `by=`/factor-smooth wrappers.
/// Terms whose length scale is already non-zero are left unchanged.
pub fn auto_init_length_scale_in_place(data: ArrayView2<'_, f64>, term: &mut SmoothTermSpec) {
    auto_init_length_scale_in_basis(data, &mut term.basis);
}
4590
4591/// Replace the `0.0` auto-init length-scale sentinel with a data-derived value
4592/// for any Matern / thin-plate kernel reachable from this basis — including the
4593/// inner kernel of a `by=`/factor-smooth wrapper.
4594///
4595/// `by=<factor>` and the sum-to-zero factor smooth wrap a spatial kernel inside
4596/// `SmoothBasisSpec::ByVariable` / `SmoothBasisSpec::FactorSumToZero` /
4597/// `SmoothBasisSpec::BySmooth`, so the wrapper variant is what the planner sees.
4598/// Without recursing into the wrapped basis the inner ThinPlate/Matern keeps the
4599/// `0.0` sentinel (the post-`1605b3a6e` builder default), which makes the kernel
4600/// distance divide by `length_scale² = 0`, producing a non-finite design at both
4601/// fit and predict time. Recurse so the inner kernel is initialized identically
4602/// to a top-level one.
4603pub fn auto_init_length_scale_in_basis(data: ArrayView2<'_, f64>, basis: &mut SmoothBasisSpec) {
4604    match basis {
4605        SmoothBasisSpec::Matern {
4606            feature_cols, spec, ..
4607        } => {
4608            if spec.length_scale == 0.0 {
4609                // Density-adaptive seed (#1731): when the requested center count
4610                // is known, scale the auto length scale with the *center*
4611                // spacing so a richer `k` stays numerically full-rank instead of
4612                // saturating against `matern_rank_reduce_centers`. For `n ≥ k`
4613                // (the usual case) this is identical to the plain `max_range /
4614                // sqrt(n)` seed in 2-D, so small-`k` results are unchanged. The
4615                // unconstrained / non-explicit `UniformGrid` strategy falls back
4616                // to the plain seed.
4617                spec.length_scale = match center_strategy_requested_count(&spec.center_strategy) {
4618                    Some(k) => auto_initial_length_scale_for_centers(data, feature_cols, k),
4619                    None => auto_initial_length_scale(data, feature_cols),
4620                };
4621            }
4622        }
4623        SmoothBasisSpec::ThinPlate {
4624            feature_cols, spec, ..
4625        } => {
4626            if spec.length_scale == 0.0 {
4627                spec.length_scale = auto_initial_length_scale(data, feature_cols);
4628            }
4629        }
4630        SmoothBasisSpec::ByVariable { inner, .. }
4631        | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
4632            auto_init_length_scale_in_basis(data, inner);
4633        }
4634        SmoothBasisSpec::BySmooth { smooth, .. } => {
4635            auto_init_length_scale_in_basis(data, smooth);
4636        }
4637        _ => {}
4638    }
4639}
4640
4641impl LinearFitConditioning {
4642    pub fn from_columns(design: &TermCollectionDesign, selected_cols: &[usize]) -> Self {
4643        const SCALE_EPS: f64 = 1e-12;
4644        let n = design.design.nrows();
4645        let p = design.design.ncols();
4646        let mut columns = Vec::with_capacity(selected_cols.len());
4647        if n == 0 || selected_cols.is_empty() {
4648            return Self {
4649                intercept_idx: design.intercept_range.start,
4650                columns,
4651            };
4652        }
4653        let chunk_rows = gam_linalg::utils::row_chunk_for_byte_budget(n, p);
4654        // Two-pass mean/variance so operator-backed designs don't need to
4655        // materialize the full dense matrix. Pass 1 accumulates per-column
4656        // sums; pass 2 accumulates the sum of squared deviations from the
4657        // pass-1 mean. This matches the original `Σ (x − mean)² / n` formula
4658        // without the catastrophic cancellation of `E[X²] − E[X]²`.
4659        let mut sums = vec![0.0_f64; selected_cols.len()];
4660        for start in (0..n).step_by(chunk_rows) {
4661            let end = (start + chunk_rows).min(n);
4662            let chunk = design
4663                .design
4664                .try_row_chunk(start..end)
4665                .expect("LinearFitConditioning::from_columns row chunk failed");
4666            for (k, &col_idx) in selected_cols.iter().enumerate() {
4667                let column = chunk.column(col_idx);
4668                for &v in column.iter() {
4669                    sums[k] += v;
4670                }
4671            }
4672        }
4673        let inv_n = 1.0_f64 / n as f64;
4674        let means: Vec<f64> = sums.iter().map(|&s| s * inv_n).collect();
4675        let mut sq_devs = vec![0.0_f64; selected_cols.len()];
4676        for start in (0..n).step_by(chunk_rows) {
4677            let end = (start + chunk_rows).min(n);
4678            let chunk = design
4679                .design
4680                .try_row_chunk(start..end)
4681                .expect("LinearFitConditioning::from_columns row chunk failed");
4682            for (k, &col_idx) in selected_cols.iter().enumerate() {
4683                let mean_k = means[k];
4684                let column = chunk.column(col_idx);
4685                for &v in column.iter() {
4686                    let d = v - mean_k;
4687                    sq_devs[k] += d * d;
4688                }
4689            }
4690        }
4691        for (k, &col_idx) in selected_cols.iter().enumerate() {
4692            let mean = means[k];
4693            let var = sq_devs[k] * inv_n;
4694            let (mean, scale) = if var.is_finite() && var > SCALE_EPS * SCALE_EPS {
4695                (mean, var.sqrt())
4696            } else {
4697                // Leave nearly-constant columns untouched; centering them would collapse
4698                // the design column to ~0 and change the model rather than just condition it.
4699                (0.0, 1.0)
4700            };
4701            columns.push(LinearColumnConditioning {
4702                col_idx,
4703                mean,
4704                scale,
4705            });
4706        }
4707        Self {
4708            intercept_idx: design.intercept_range.start,
4709            columns,
4710        }
4711    }
4712
4713    pub fn apply_to_design(&self, design: &Array2<f64>) -> Array2<f64> {
4714        let mut out = design.clone();
4715        for col in &self.columns {
4716            {
4717                let mut dst = out.column_mut(col.col_idx);
4718                dst -= col.mean;
4719            }
4720            if col.scale != 1.0 {
4721                out.column_mut(col.col_idx).mapv_inplace(|v| v / col.scale);
4722            }
4723        }
4724        out
4725    }
4726
4727    fn transform_matrix_columnswith_a(&self, mat: &Array2<f64>) -> Array2<f64> {
4728        let mut out = mat.clone();
4729        let intercept = self.intercept_idx;
4730        for col in &self.columns {
4731            let intercept_col = out.column(intercept).to_owned();
4732            let mut target = out.column_mut(col.col_idx);
4733            target -= &(intercept_col * col.mean);
4734            if col.scale != 1.0 {
4735                target.mapv_inplace(|v| v / col.scale);
4736            }
4737        }
4738        out
4739    }
4740
4741    fn transform_matrixrowswith_a_transpose(&self, mat: &Array2<f64>) -> Array2<f64> {
4742        let mut out = mat.clone();
4743        let intercept = self.intercept_idx;
4744        for col in &self.columns {
4745            let interceptrow = out.row(intercept).to_owned();
4746            let mut target = out.row_mut(col.col_idx);
4747            target -= &(interceptrow * col.mean);
4748            if col.scale != 1.0 {
4749                target.mapv_inplace(|v| v / col.scale);
4750            }
4751        }
4752        out
4753    }
4754
4755    /// Left-multiply `mat_internal` by `M⁻ᵀ` where `M⁻¹[intercept, j] = mean_j`
4756    /// and `M⁻¹[j, j] = scale_j` for each conditioned column. Used together
4757    /// with [`Self::right_multiply_by_m_inv`] to back-transform an internal
4758    /// penalized Hessian to the original coefficient basis.
4759    fn left_multiply_by_m_inv_transpose(&self, mat_internal: &Array2<f64>) -> Array2<f64> {
4760        let mut out = mat_internal.clone();
4761        let intercept = self.intercept_idx;
4762        let interceptrow_snapshot = mat_internal.row(intercept).to_owned();
4763        for col in &self.columns {
4764            if col.scale != 1.0 {
4765                out.row_mut(col.col_idx).mapv_inplace(|v| v * col.scale);
4766            }
4767            if col.mean != 0.0 {
4768                let mut target = out.row_mut(col.col_idx);
4769                target += &(&interceptrow_snapshot * col.mean);
4770            }
4771        }
4772        out
4773    }
4774
4775    /// Right-multiply `mat_internal` by `M⁻¹`. Mirror of
4776    /// [`Self::left_multiply_by_m_inv_transpose`] on columns.
4777    fn right_multiply_by_m_inv(&self, mat_internal: &Array2<f64>) -> Array2<f64> {
4778        let mut out = mat_internal.clone();
4779        let intercept = self.intercept_idx;
4780        let intercept_col_snapshot = mat_internal.column(intercept).to_owned();
4781        for col in &self.columns {
4782            if col.scale != 1.0 {
4783                out.column_mut(col.col_idx).mapv_inplace(|v| v * col.scale);
4784            }
4785            if col.mean != 0.0 {
4786                let mut target = out.column_mut(col.col_idx);
4787                target += &(&intercept_col_snapshot * col.mean);
4788            }
4789        }
4790        out
4791    }
4792
4793    /// Transform blockwise penalties through the conditioning.
4794    ///
4795    /// For block-local penalties whose `col_range` does not overlap with any
4796    /// conditioning column, the transform is identity (the conditioning only
4797    /// affects unpenalized linear columns). In that common case the penalty
4798    /// passes through unchanged, avoiding O(p²) materialization entirely.
4799    pub fn transform_blockwise_penalties_to_internal(
4800        &self,
4801        penalties: &[BlockwisePenalty],
4802        p: usize,
4803    ) -> Vec<crate::penalty_spec::PenaltySpec> {
4804        let conditioning_cols: std::collections::HashSet<usize> =
4805            self.columns.iter().map(|c| c.col_idx).collect();
4806        penalties
4807            .iter()
4808            .map(|bp| {
4809                let overlaps =
4810                    (bp.col_range.start..bp.col_range.end).any(|j| conditioning_cols.contains(&j));
4811                if overlaps {
4812                    // Rare: penalty block overlaps conditioning columns.
4813                    // Fall back to dense transform.
4814                    let global = bp.to_global(p);
4815                    let right = self.transform_matrix_columnswith_a(&global);
4816                    let transformed = self.transform_matrixrowswith_a_transpose(&right);
4817                    crate::penalty_spec::PenaltySpec::Dense(transformed)
4818                } else {
4819                    // Common: smooth penalty block doesn't touch linear columns.
4820                    // The conditioning is identity on this block.
4821                    crate::penalty_spec::PenaltySpec::from_blockwise(bp.clone())
4822                }
4823            })
4824            .collect()
4825    }
4826
4827    pub fn backtransform_beta(&self, beta_internal: &Array1<f64>) -> Array1<f64> {
4828        let mut beta = beta_internal.clone();
4829        let intercept = self.intercept_idx;
4830        for col in &self.columns {
4831            beta[intercept] -= beta_internal[col.col_idx] * col.mean / col.scale;
4832            beta[col.col_idx] = beta_internal[col.col_idx] / col.scale;
4833        }
4834        beta
4835    }
4836
4837    /// `H_orig = M⁻ᵀ · H_int · M⁻¹`, derived from
4838    /// `L_int(β_int) = L_orig(M · β_int)` via the chain rule.
4839    pub fn transform_penalized_hessian_to_original(&self, h_internal: &Array2<f64>) -> Array2<f64> {
4840        let right = self.right_multiply_by_m_inv(h_internal);
4841        self.left_multiply_by_m_inv_transpose(&right)
4842    }
4843
4844    pub fn internal_bounds_for(&self, col_idx: usize, min: f64, max: f64) -> (f64, f64) {
4845        if let Some(col) = self.columns.iter().find(|c| c.col_idx == col_idx) {
4846            (min * col.scale, max * col.scale)
4847        } else {
4848            (min, max)
4849        }
4850    }
4851}
4852
4853pub fn freeze_raw_spatial_metadata(metadata: BasisMetadata, raw_cols: usize) -> BasisMetadata {
4854    match metadata {
4855        BasisMetadata::ThinPlate {
4856            centers,
4857            length_scale,
4858            periodic,
4859            identifiability_transform: None,
4860            input_scales,
4861            radial_reparam,
4862        } => BasisMetadata::ThinPlate {
4863            centers,
4864            length_scale,
4865            periodic,
4866            identifiability_transform: Some(Array2::eye(raw_cols)),
4867            input_scales,
4868            radial_reparam,
4869        },
4870        BasisMetadata::Duchon {
4871            centers,
4872            length_scale,
4873            periodic,
4874            power,
4875            nullspace_order,
4876            identifiability_transform: None,
4877            input_scales,
4878            aniso_log_scales,
4879            operator_collocation_points,
4880            radial_reparam,
4881        } => BasisMetadata::Duchon {
4882            centers,
4883            length_scale,
4884            periodic,
4885            power,
4886            nullspace_order,
4887            identifiability_transform: Some(Array2::eye(raw_cols)),
4888            input_scales,
4889            aniso_log_scales,
4890            operator_collocation_points,
4891            radial_reparam,
4892        },
4893        other => other,
4894    }
4895}
4896
4897pub fn matern_operator_penalty_triplet_from_metadata(
4898    metadata: &BasisMetadata,
4899) -> Result<(Vec<Array2<f64>>, Vec<usize>, Vec<PenaltyInfo>), BasisError> {
4900    let BasisMetadata::Matern {
4901        centers,
4902        length_scale,
4903        periodic,
4904        nu,
4905        include_intercept,
4906        identifiability_transform,
4907        aniso_log_scales,
4908        input_scales,
4909        ..
4910    } = metadata
4911    else {
4912        crate::bail_invalid_basis!("Matérn operator penalties require Matérn metadata");
4913    };
4914    // The metadata records `length_scale` in *original* (un-standardized) data
4915    // coordinates, while `centers` live in the *standardized* coordinate frame
4916    // (per-axis division by `input_scales`). The realized design built the
4917    // kernel against those standardized centers using the σ_geom-compensated
4918    // effective length scale `length_scale / σ_geom`. The collocation operators
4919    // here are evaluated on the same standardized centers, so they must use the
4920    // SAME effective length scale — otherwise the penalty regularizes a
4921    // different RKHS range than the design lives in, leaving rough coefficient
4922    // directions effectively unpenalized. That mismatch is benign in 1-D
4923    // (no standardization) but produces a catastrophic out-of-sample blow-up in
4924    // d ≥ 2 where σ_geom ≠ 1 (#706).
4925    let penalty_length_scale = match input_scales.as_deref() {
4926        Some(scales) => compensate_length_scale_for_standardization(*length_scale, scales),
4927        None => *length_scale,
4928    };
4929    matern_operator_penalty_triplet_at_length_scale(
4930        centers.view(),
4931        periodic.as_deref(),
4932        identifiability_transform.as_ref(),
4933        *nu,
4934        *include_intercept,
4935        aniso_log_scales.as_deref(),
4936        penalty_length_scale,
4937    )
4938}
4939
4940/// Build the canonical Matérn operator-penalty triplet (mass / tension /
4941/// stiffness) at an explicit **effective** length scale — i.e. the
4942/// σ_geom-compensated, standardized-frame scale the design's kernel was built
4943/// against (NOT the original-coordinate `length_scale` stored in metadata).
4944///
4945/// This is the SINGLE source of truth for the Matérn penalty topology. Two
4946/// callers route through it and must therefore stay byte-for-byte consistent:
4947///   * the cold/slow design rebuild (`matern_operator_penalty_triplet_from_metadata`,
4948///     compensating the frozen metadata `length_scale`), and
4949///   * the n-free κ-optimizer re-key (`FrozenTermCollectionIncrementalRealizer::
4950///     canonical_penalties_at_psi`, compensating the trial `ψ → exp(-ψ)` scale).
4951///
4952/// Sharing the body makes the penalty BLOCK COUNT and the per-block numerics
4953/// one deterministic function of `(geometry, ν, η, ℓ_eff)`. The active-operator
4954/// gate is `m = ν + d/2`, which is independent of ℓ, so the block count is
4955/// **ψ-stable by construction**: the re-key can never produce a different number
4956/// of blocks than the frozen design (the desync that #1270 hard-errored on).
4957pub fn matern_operator_penalty_triplet_at_length_scale(
4958    centers: ArrayView2<'_, f64>,
4959    periodic: Option<&[Option<f64>]>,
4960    identifiability_transform: Option<&Array2<f64>>,
4961    nu: crate::basis::MaternNu,
4962    include_intercept: bool,
4963    aniso_log_scales: Option<&[f64]>,
4964    effective_length_scale: f64,
4965) -> Result<(Vec<Array2<f64>>, Vec<usize>, Vec<PenaltyInfo>), BasisError> {
4966    let penalty_centers = crate::basis::expand_periodic_centers(&centers.to_owned(), periodic)?;
4967    let ops = build_matern_collocation_operator_matrices(
4968        penalty_centers.view(),
4969        None,
4970        effective_length_scale,
4971        nu,
4972        include_intercept,
4973        identifiability_transform.map(|z| z.view()),
4974        aniso_log_scales,
4975    )?;
4976    // Gate the operator dials on the Matérn-ν RKHS Sobolev order m = ν + d/2:
4977    // mass (j=0) is always on, tension (j=1) is on for m > 1, stiffness (j=2)
4978    // is on for m > 2. The threshold is strict so the roughest kernel ν=1/2 in
4979    // d=1 (m=1, the exponential/OU H¹ process) sheds both higher operators —
4980    // its kernel already encodes the H¹ control, so adding an extra tension
4981    // dial over-smooths the oscillation it is meant to track (#707). The
4982    // matching gate lives at `DuchonOperatorPenaltySpec::matern_for_smoothness`.
4983    const ORDER_EPS: f64 = 1e-9;
4984    let d = penalty_centers.ncols();
4985    let m = nu.half_integer_value() + 0.5 * d as f64;
4986    let mut candidates = Vec::with_capacity(3);
4987    for (raw, source, min_order) in [
4988        (ops.d0.t().dot(&ops.d0), PenaltySource::OperatorMass, 0.0),
4989        (ops.d1.t().dot(&ops.d1), PenaltySource::OperatorTension, 1.0),
4990        (
4991            ops.d2.t().dot(&ops.d2),
4992            PenaltySource::OperatorStiffness,
4993            2.0,
4994        ),
4995    ] {
4996        if min_order > 0.0 && m <= min_order + ORDER_EPS {
4997            continue;
4998        }
4999        let sym = (&raw + &raw.t()) * 0.5;
5000        let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&sym);
5001        candidates.push(PenaltyCandidate {
5002            matrix,
5003            nullspace_dim_hint: 0,
5004            source,
5005            normalization_scale,
5006            kronecker_factors: None,
5007            op: None,
5008        });
5009    }
5010    filter_active_penalty_candidates(candidates)
5011}
5012
5013pub fn normalize_penalty_in_constrained_space(matrix: &Array2<f64>) -> (Array2<f64>, f64) {
5014    // Constrained-space normalization:
5015    //   c = ||S_con||_F,  S_tilde = S_con / c.
5016    // This is the only normalization coherent with a REML objective that is
5017    // evaluated entirely in constrained coordinates.
5018    let matrix = (matrix + &matrix.t().to_owned()) * 0.5;
5019    // Clamp noise-floor negative eigenvalues so β'Sβ is non-negative as a contract, not just in exact arithmetic.
5020    let matrix = crate::basis::project_penalty_to_psd_cone(&matrix);
5021    let c = matrix.iter().map(|v| v * v).sum::<f64>().sqrt();
5022    if c.is_finite() && c > 0.0 {
5023        (matrix.mapv(|v| v / c), c)
5024    } else {
5025        (matrix, 1.0)
5026    }
5027}
5028
5029pub fn tensor_product_design_from_sparse_marginals(
5030    marginal_sparse: &[&SparseColMat<usize, f64>],
5031) -> Result<SparseColMat<usize, f64>, BasisError> {
5032    if marginal_sparse.is_empty() {
5033        crate::bail_invalid_basis!("TensorBSpline requires at least one marginal basis");
5034    }
5035    let n = marginal_sparse[0].nrows();
5036    for (i, m) in marginal_sparse.iter().enumerate().skip(1) {
5037        if m.nrows() != n {
5038            crate::bail_dim_basis!(
5039                "tensor sparse marginal row mismatch at dim {i}: expected {n}, got {}",
5040                m.nrows()
5041            );
5042        }
5043    }
5044    let dims: Vec<usize> = marginal_sparse.iter().map(|m| m.ncols()).collect();
5045    let total_cols = dims.iter().try_fold(1usize, |acc, &q| {
5046        acc.checked_mul(q)
5047            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))
5048    })?;
5049    let mut strides = vec![1usize; dims.len()];
5050    for d in (0..dims.len().saturating_sub(1)).rev() {
5051        strides[d] = strides[d + 1]
5052            .checked_mul(dims[d + 1])
5053            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))?;
5054    }
5055
5056    use faer::sparse::SparseRowMat;
5057    let csrs: Vec<SparseRowMat<usize, f64>> = marginal_sparse
5058        .iter()
5059        .enumerate()
5060        .map(|(d, m)| {
5061            m.as_ref().to_row_major().map_err(|e| {
5062                BasisError::SparseCreation(format!(
5063                    "tensor sparse marginal {d} CSR conversion failed: {e:?}"
5064                ))
5065            })
5066        })
5067        .collect::<Result<Vec<_>, _>>()?;
5068    let row_ptrs: Vec<&[usize]> = csrs.iter().map(|c| c.symbolic().row_ptr()).collect();
5069    let col_idxs: Vec<&[usize]> = csrs.iter().map(|c| c.symbolic().col_idx()).collect();
5070    let vals: Vec<&[f64]> = csrs.iter().map(|c| c.val()).collect();
5071
5072    use rayon::prelude::*;
5073    const CHUNK: usize = 1024;
5074    let num_chunks = n.div_ceil(CHUNK);
5075    let per_chunk: Vec<Vec<Triplet<usize, usize, f64>>> = (0..num_chunks)
5076        .into_par_iter()
5077        .map(|chunk_idx| {
5078            let row_start = chunk_idx * CHUNK;
5079            let row_end = (row_start + CHUNK).min(n);
5080            let mut chunk_triplets = Vec::<Triplet<usize, usize, f64>>::new();
5081            let mut cur_cols = Vec::<usize>::with_capacity(64);
5082            let mut cur_vals = Vec::<f64>::with_capacity(64);
5083            let mut next_cols = Vec::<usize>::with_capacity(64);
5084            let mut next_vals = Vec::<f64>::with_capacity(64);
5085            for i in row_start..row_end {
5086                cur_cols.clear();
5087                cur_vals.clear();
5088                cur_cols.push(0);
5089                cur_vals.push(1.0);
5090                let mut row_is_zero = false;
5091                for d in 0..dims.len() {
5092                    let row_start_d = row_ptrs[d][i];
5093                    let row_end_d = row_ptrs[d][i + 1];
5094                    if row_start_d == row_end_d {
5095                        row_is_zero = true;
5096                        break;
5097                    }
5098                    let stride = strides[d];
5099                    next_cols.clear();
5100                    next_vals.clear();
5101                    next_cols.reserve(cur_cols.len() * (row_end_d - row_start_d));
5102                    next_vals.reserve(cur_vals.len() * (row_end_d - row_start_d));
5103                    for (&prev_col, &prev_val) in cur_cols.iter().zip(cur_vals.iter()) {
5104                        for ptr in row_start_d..row_end_d {
5105                            let cj = col_idxs[d][ptr];
5106                            let vj = vals[d][ptr];
5107                            next_cols.push(prev_col + cj * stride);
5108                            next_vals.push(prev_val * vj);
5109                        }
5110                    }
5111                    std::mem::swap(&mut cur_cols, &mut next_cols);
5112                    std::mem::swap(&mut cur_vals, &mut next_vals);
5113                }
5114                if row_is_zero {
5115                    continue;
5116                }
5117                for (&col, &val) in cur_cols.iter().zip(cur_vals.iter()) {
5118                    chunk_triplets.push(Triplet::new(i, col, val));
5119                }
5120            }
5121            chunk_triplets
5122        })
5123        .collect();
5124    let total_nnz: usize = per_chunk.iter().map(Vec::len).sum();
5125    let mut triplets = Vec::<Triplet<usize, usize, f64>>::with_capacity(total_nnz);
5126    for chunk in per_chunk {
5127        triplets.extend(chunk);
5128    }
5129    SparseColMat::try_new_from_triplets(n, total_cols, &triplets).map_err(|e| {
5130        BasisError::SparseCreation(format!(
5131            "failed to assemble sparse tensor product design: {e:?}"
5132        ))
5133    })
5134}
5135
5136pub fn dense_local_margin_to_sparse(
5137    dense: &Array2<f64>,
5138) -> Result<SparseColMat<usize, f64>, BasisError> {
5139    let expected_row_nnz = dense.ncols().min(4);
5140    let mut triplets =
5141        Vec::<Triplet<usize, usize, f64>>::with_capacity(dense.nrows() * expected_row_nnz);
5142    for ((row, col), &value) in dense.indexed_iter() {
5143        if value != 0.0 {
5144            triplets.push(Triplet::new(row, col, value));
5145        }
5146    }
5147    SparseColMat::try_new_from_triplets(dense.nrows(), dense.ncols(), &triplets).map_err(|e| {
5148        BasisError::SparseCreation(format!(
5149            "failed to convert tensor marginal design to sparse form: {e:?}"
5150        ))
5151    })
5152}
5153
/// Pair of square matrices describing one tensor margin's penalty split.
///
/// Produced by `tensor_margin_range_null_projectors`, which builds each matrix
/// as `V Vᵀ` from a subset of the penalty's eigenvector columns.
/// NOTE(review): these are orthogonal projectors only if the eigenvectors
/// returned by the penalty analysis are orthonormal — confirm.
pub struct TensorMarginRangeNullProjectors {
    /// Built from the eigenvector columns whose eigenvalue exceeds the
    /// analysis tolerance (the penalized subspace).
    range: Array2<f64>,
    /// Built from the remaining eigenvector columns (the penalty null space).
    null: Array2<f64>,
}
5158
5159pub fn projector_from_columns(columns: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
5160    if indices.is_empty() {
5161        return Array2::<f64>::zeros((columns.nrows(), columns.nrows()));
5162    }
5163    let basis = columns.select(Axis(1), indices);
5164    basis.dot(&basis.t())
5165}
5166
5167pub fn tensor_margin_range_null_projectors(
5168    normalized_marginal_penalties: &[(Array2<f64>, f64)],
5169) -> Result<Vec<TensorMarginRangeNullProjectors>, BasisError> {
5170    normalized_marginal_penalties
5171        .iter()
5172        .enumerate()
5173        .map(|(dim, (penalty, _))| {
5174            let analysis = crate::basis::analyze_penalty_block(penalty)?;
5175            if analysis.rank == 0 {
5176                crate::bail_invalid_basis!(
5177                    "t2 separable tensor penalty margin {dim} has rank-zero penalty; \
5178                     cannot split penalized and null subspaces"
5179                );
5180            }
5181            let mut range_idx = Vec::<usize>::new();
5182            let mut null_idx = Vec::<usize>::new();
5183            for (idx, &ev) in analysis.eigenvalues.iter().enumerate() {
5184                if ev > analysis.tol {
5185                    range_idx.push(idx);
5186                } else {
5187                    null_idx.push(idx);
5188                }
5189            }
5190            Ok(TensorMarginRangeNullProjectors {
5191                range: projector_from_columns(&analysis.eigenvectors, &range_idx),
5192                null: projector_from_columns(&analysis.eigenvectors, &null_idx),
5193            })
5194        })
5195        .collect()
5196}
5197
5198pub fn build_tensor_bspline_basis(
5199    data: ArrayView2<'_, f64>,
5200    feature_cols: &[usize],
5201    spec: &TensorBSplineSpec,
5202) -> Result<BasisBuildResult, BasisError> {
5203    if feature_cols.is_empty() {
5204        crate::bail_invalid_basis!("TensorBSpline requires at least one feature column");
5205    }
5206    if feature_cols.len() != spec.marginalspecs.len() {
5207        crate::bail_dim_basis!(
5208            "TensorBSpline feature/spec mismatch: feature_cols={}, marginalspecs={}",
5209            feature_cols.len(),
5210            spec.marginalspecs.len()
5211        );
5212    }
5213    if !spec.periods.is_empty() && spec.periods.len() != feature_cols.len() {
5214        crate::bail_dim_basis!(
5215            "TensorBSpline periods length {} does not match feature count {}",
5216            spec.periods.len(),
5217            feature_cols.len()
5218        );
5219    }
5220    let p = data.ncols();
5221    for &c in feature_cols {
5222        if c >= p {
5223            crate::bail_dim_basis!(
5224                "tensor feature column {c} is out of bounds for data with {p} columns"
5225            );
5226        }
5227    }
5228
5229    let mut marginal_knots = Vec::<Array1<f64>>::with_capacity(feature_cols.len());
5230    // Per-margin cr flag (#1074): `true` when the margin is a natural cubic
5231    // regression spline, so the tensor freeze rebuilds the cr knotspec.
5232    let mut marginal_is_cr_flags = Vec::<bool>::with_capacity(feature_cols.len());
5233    let mut marginal_degrees = Vec::<usize>::with_capacity(feature_cols.len());
5234    let mut marginalnum_basis = Vec::<usize>::with_capacity(feature_cols.len());
5235    let mut marginal_penalties = Vec::<Array2<f64>>::with_capacity(feature_cols.len());
5236    let mut marginal_designs = Vec::<Array2<f64>>::with_capacity(feature_cols.len());
5237    // Per-margin effective period: either user-set via `spec.periods` or
5238    // implied by a `PeriodicUniform` marginal knotspec (which the 1D B-spline
5239    // builder realizes as a cyclic B-spline basis).
5240    // Captured here so freeze→reload round-trips both routes back to a
5241    // `PeriodicUniform` marginal knotspec; otherwise a `PeriodicUniform`
5242    // margin specified without `spec.periods` would freeze as a plain
5243    // `Provided(knots)` open spline and lose its wrap-around at predict time.
5244    let mut marginal_effective_periods = Vec::<Option<f64>>::with_capacity(feature_cols.len());
5245    // Per-marginal sparse representation, populated when the 1D builder returned
5246    // a `DesignMatrix::Sparse`. Used to assemble the Khatri-Rao tensor product
5247    // sparsely (only ∏(degree+1) nonzeros per row) instead of densifying to
5248    // shape (n, ∏ q_j) up front. Periodic B-spline margins are local-support
5249    // bases too; when the 1D builder returns them densely, we convert that
5250    // marginal back to sparse form so cylinder/torus tensor products keep the
5251    // same scale behavior as open tensor products.
5252    let mut marginal_sparse =
5253        Vec::<Option<SparseColMat<usize, f64>>>::with_capacity(feature_cols.len());
5254
5255    // Reuse the robust 1D builder to ensure the same knot validation and
5256    // marginal difference-penalty construction as standalone smooth terms.
5257    for (dim, (&col, marginalspec)) in feature_cols
5258        .iter()
5259        .zip(spec.marginalspecs.iter())
5260        .enumerate()
5261    {
5262        // Tensor basis uses raw marginal knot-product columns. Applying 1D
5263        // identifiability constraints here would change marginal penalty sizes
5264        // without changing the tensor design construction, causing dimension
5265        // mismatch. Keep marginal builders unconstrained at this stage.
5266        let mut marginal_unconstrained = marginalspec.clone();
5267        marginal_unconstrained.identifiability = BSplineIdentifiability::None;
5268        let built = build_bspline_basis_1d(data.column(col), &marginal_unconstrained)?;
5269        // A cr (`NaturalCubicRegression`) margin emits `CubicRegression1D`
5270        // metadata whose `knots` are the k value-knots; a B-spline margin emits
5271        // `BSpline1D` with the clamped knot vector. Capture either so the
5272        // tensor freeze can rebuild the exact same marginal knotspec (#1074).
5273        let (knots, marginal_is_cr) = match built.metadata {
5274            BasisMetadata::BSpline1D { knots, .. } => (knots, false),
5275            BasisMetadata::CubicRegression1D { knots, .. } => (knots, true),
5276            _ => {
5277                crate::bail_invalid_basis!(
5278                    "internal TensorBSpline error at dim {dim}: expected BSpline1D or CubicRegression1D metadata"
5279                );
5280            }
5281        };
5282        let metadata_knots = match marginalspec.knotspec {
5283            BSplineKnotSpec::PeriodicUniform {
5284                data_range,
5285                num_basis,
5286            } => Array1::linspace(data_range.0, data_range.1, num_basis),
5287            _ => knots,
5288        };
5289        marginal_knots.push(metadata_knots);
5290        marginal_is_cr_flags.push(marginal_is_cr);
5291        marginal_degrees.push(marginalspec.degree);
5292        marginalnum_basis.push(built.design.ncols());
5293        // Capture the sparse representation of this marginal (when the
5294        // 1D builder produced one) before densifying for the dense
5295        // marginal cache used by `tensor_product_design_from_marginals`
5296        // and `TensorProductDesignOperator`.
5297        let dense_marginal = built.design.to_dense();
5298        let sparse_view: Option<SparseColMat<usize, f64>> = match built.design.as_sparse() {
5299            Some(sd) => {
5300                let inner: &SparseColMat<usize, f64> = sd;
5301                Some(inner.clone())
5302            }
5303            None => match marginalspec.knotspec {
5304                BSplineKnotSpec::PeriodicUniform { .. } => {
5305                    Some(dense_local_margin_to_sparse(&dense_marginal)?)
5306                }
5307                _ => None,
5308            },
5309        };
5310        marginal_sparse.push(sparse_view);
5311        marginal_designs.push(dense_marginal);
5312        marginal_penalties.push(
5313            built
5314                .penalties
5315                .first()
5316                .ok_or_else(|| {
5317                    BasisError::InvalidInput(format!(
5318                        "internal TensorBSpline error at dim {dim}: missing marginal penalty"
5319                    ))
5320                })?
5321                .clone(),
5322        );
5323        built.nullspace_dims.first().ok_or_else(|| {
5324            BasisError::InvalidInput(format!(
5325                "internal TensorBSpline error at dim {dim}: missing marginal nullspace dim"
5326            ))
5327        })?;
5328        // A `PeriodicUniform` marginal knotspec implies the margin is
5329        // wrap-around: the 1D builder already realized it as a periodic
5330        // basis, so the tensor product inherits that periodicity. Record
5331        // the period derived from the knotspec's data range so freeze
5332        // restores `PeriodicUniform` on the marginal — otherwise the
5333        // round-trip downgrades it to `Provided(knots)` (an open spline)
5334        // and predict-time wraps disappear.
5335        let implied_period = match marginalspec.knotspec {
5336            BSplineKnotSpec::PeriodicUniform { data_range, .. } => {
5337                Some(data_range.1 - data_range.0)
5338            }
5339            _ => spec.periods.get(dim).and_then(|p| *p),
5340        };
5341        marginal_effective_periods.push(implied_period);
5342    }
5343
5344    let total_cols: usize = marginalnum_basis.iter().product();
5345    let mut dense_design = (!matches!(spec.identifiability, TensorBSplineIdentifiability::None))
5346        .then(|| tensor_product_design_from_marginals(&marginal_designs))
5347        .transpose()?;
5348    let mut candidates = Vec::<PenaltyCandidate>::with_capacity(
5349        match spec.penalty_decomposition {
5350            TensorBSplinePenaltyDecomposition::MarginalKroneckerSum => marginal_penalties.len(),
5351            TensorBSplinePenaltyDecomposition::Separable => marginal_penalties.len() * 2,
5352        } + if spec.double_penalty { 1 } else { 0 },
5353    );
5354
5355    // Tensor-product smoothing parameters are one-per-margin.  Therefore the
5356    // physical penalty attached to a margin must be normalized in that margin's
5357    // own working coordinates before it is embedded in the full tensor product.
5358    // Normalizing only the already-Kroneckered matrix would fold arbitrary
5359    // dimension-dependent identity factors into the margin's lambda and would
5360    // make anisotropic REML/LAML smoothing depend on the other margins' basis
5361    // sizes rather than on the marginal roughness operator itself.
5362    let normalized_marginal_penalties: Vec<(Array2<f64>, f64)> = marginal_penalties
5363        .iter()
5364        .map(normalize_penalty_in_constrained_space)
5365        .collect();
5366    let mut kronecker_marginal_penalties =
5367        Vec::<Array2<f64>>::with_capacity(normalized_marginal_penalties.len());
5368
5369    match spec.penalty_decomposition {
5370        TensorBSplinePenaltyDecomposition::MarginalKroneckerSum => {
5371            // Accumulate the Kronecker-sum of the per-margin penalties,
5372            // `Σ_dim S_dim`, whose null space is exactly the *joint* null space
5373            // of all marginal penalties — the tensor of marginal polynomial
5374            // null spaces. The tensor double penalty (below) shrinks only this
5375            // joint null, never the already-penalized interaction range.
5376            let mut marginal_kron_sum = Array2::<f64>::zeros((total_cols, total_cols));
5377
5378            for dim in 0..normalized_marginal_penalties.len() {
5379                let mut s_dim = Array2::<f64>::eye(1);
5380                let mut factors = Vec::<Array2<f64>>::with_capacity(marginalnum_basis.len());
5381                for (j, &qj) in marginalnum_basis.iter().enumerate() {
5382                    let factor = if j == dim {
5383                        normalized_marginal_penalties[j].0.clone()
5384                    } else {
5385                        Array2::<f64>::eye(qj)
5386                    };
5387                    factors.push(factor.clone());
5388                    s_dim = kronecker_product(&s_dim, &factor);
5389                }
5390                if dim == kronecker_marginal_penalties.len() {
5391                    kronecker_marginal_penalties.push(normalized_marginal_penalties[dim].0.clone());
5392                }
5393                marginal_kron_sum += &s_dim;
5394
5395                candidates.push(PenaltyCandidate {
5396                    matrix: s_dim,
5397                    nullspace_dim_hint: 0,
5398                    source: PenaltySource::TensorMarginal { dim },
5399                    normalization_scale: normalized_marginal_penalties[dim].1,
5400                    kronecker_factors: Some(factors),
5401                    op: None,
5402                });
5403            }
5404
5405            if spec.double_penalty
5406                && let Some(shrink) =
5407                    crate::basis::build_nullspace_shrinkage_penalty(&marginal_kron_sum)?
5408            {
5409                let (matrix, normalization_scale) =
5410                    normalize_penalty_in_constrained_space(&shrink.sym_penalty);
5411                candidates.push(PenaltyCandidate {
5412                    matrix,
5413                    nullspace_dim_hint: 0,
5414                    source: PenaltySource::TensorGlobalRidge,
5415                    normalization_scale,
5416                    kronecker_factors: None,
5417                    op: None,
5418                });
5419            }
5420        }
5421        TensorBSplinePenaltyDecomposition::Separable => {
5422            let projectors = tensor_margin_range_null_projectors(&normalized_marginal_penalties)?;
5423            let n_masks = 1usize.checked_shl(projectors.len() as u32).ok_or_else(|| {
5424                BasisError::InvalidInput(format!(
5425                    "t2 separable tensor penalty supports at most {} margins, got {}",
5426                    usize::BITS - 1,
5427                    projectors.len()
5428                ))
5429            })?;
5430            for mask in 1..n_masks {
5431                let mut matrix = Array2::<f64>::eye(1);
5432                let mut factors = Vec::<Array2<f64>>::with_capacity(projectors.len());
5433                let mut penalized_margins = Vec::<usize>::new();
5434                for (dim, projector) in projectors.iter().enumerate() {
5435                    let use_range = ((mask >> dim) & 1) == 1;
5436                    let factor = if use_range {
5437                        penalized_margins.push(dim);
5438                        projector.range.clone()
5439                    } else {
5440                        projector.null.clone()
5441                    };
5442                    matrix = kronecker_product(&matrix, &factor);
5443                    factors.push(factor);
5444                }
5445                let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&matrix);
5446                candidates.push(PenaltyCandidate {
5447                    matrix,
5448                    nullspace_dim_hint: 0,
5449                    source: PenaltySource::TensorSeparable { penalized_margins },
5450                    normalization_scale,
5451                    kronecker_factors: Some(factors),
5452                    op: None,
5453                });
5454            }
5455
5456            if spec.double_penalty {
5457                let mut matrix = Array2::<f64>::eye(1);
5458                let mut factors = Vec::<Array2<f64>>::with_capacity(projectors.len());
5459                for projector in &projectors {
5460                    matrix = kronecker_product(&matrix, &projector.null);
5461                    factors.push(projector.null.clone());
5462                }
5463                let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&matrix);
5464                candidates.push(PenaltyCandidate {
5465                    matrix,
5466                    nullspace_dim_hint: 0,
5467                    source: PenaltySource::TensorGlobalRidge,
5468                    normalization_scale,
5469                    kronecker_factors: Some(factors),
5470                    op: None,
5471                });
5472            }
5473        }
5474    }
5475
5476    let z_opt = match &spec.identifiability {
5477        TensorBSplineIdentifiability::None => None,
5478        TensorBSplineIdentifiability::SumToZero => {
5479            if total_cols < 2 {
5480                crate::bail_invalid_basis!(
5481                    "TensorBSpline requires at least 2 basis coefficients to enforce sum-to-zero identifiability"
5482                );
5483            }
5484            let dense_design_ref = dense_design.as_ref().ok_or_else(|| {
5485                BasisError::InvalidInput(
5486                    "tensor sum-to-zero identifiability requires a realized basis".to_string(),
5487                )
5488            })?;
5489            let (_, z) = apply_sum_to_zero_constraint(dense_design_ref.view(), None)?;
5490            let gauge = gam_problem::Gauge::sum_to_zero(z);
5491            Some(gauge.block_transform(0))
5492        }
5493        TensorBSplineIdentifiability::MarginalSumToZero => {
5494            // `ti(...)`: drop the marginal main effects by centering every
5495            // margin independently, then form the tensor product of the
5496            // centered margins. Concretely, each margin `j` is reparameterized
5497            // by its own sum-to-zero null basis `Z_j` (so the constant — i.e.
5498            // the marginal intercept — is removed from that axis), and the
5499            // combined reparameterization is the Kronecker product
5500            // `Z = Z₀ ⊗ Z₁ ⊗ … ⊗ Z_{d-1}`. Applying `Z` to the full-tensor
5501            // design `B = B₀ ⊗ … ⊗ B_{d-1}` yields `B Z = (B₀ Z₀) ⊗ … ⊗
5502            // (B_{d-1} Z_{d-1})`, the tensor product of the centered margins,
5503            // which by construction contains no pure main effect.
5504            if marginal_designs.len() < 2 {
5505                crate::bail_invalid_basis!(
5506                    "tensor interaction (ti) identifiability requires at least 2 margins"
5507                );
5508            }
5509            let mut z = Array2::<f64>::eye(1);
5510            for (dim, marginal) in marginal_designs.iter().enumerate() {
5511                if marginal.ncols() < 2 {
5512                    crate::bail_invalid_basis!(
5513                        "tensor interaction (ti) margin {dim} has fewer than 2 basis functions; \
5514                         cannot remove its marginal main effect"
5515                    );
5516                }
5517                let (_, z_dim) = apply_sum_to_zero_constraint(marginal.view(), None)?;
5518                let gauge_dim = gam_problem::Gauge::sum_to_zero(z_dim);
5519                let z_dim = gauge_dim.block_transform(0);
5520                z = kronecker_product(&z, &z_dim);
5521            }
5522            Some(z)
5523        }
5524        TensorBSplineIdentifiability::FrozenTransform { transform } => {
5525            if transform.nrows() != total_cols {
5526                crate::bail_dim_basis!(
5527                    "frozen tensor identifiability transform mismatch: design has {} columns but transform has {} rows",
5528                    total_cols,
5529                    transform.nrows()
5530                );
5531            }
5532            Some(transform.clone())
5533        }
5534    };
5535
5536    if let Some(z) = z_opt.as_ref() {
5537        let gauge = gam_problem::Gauge::from_block_transforms(&[z.clone()]);
5538        let dense = dense_design.as_mut().ok_or_else(|| {
5539            BasisError::InvalidInput(
5540                "tensor identifiability transform requires a realized basis".to_string(),
5541            )
5542        })?;
5543        let restricted_design = gauge.restrict_design(dense);
5544        *dense = restricted_design;
5545        candidates = candidates
5546            .into_iter()
5547            .map(|candidate| -> Result<PenaltyCandidate, BasisError> {
5548                let matrix = gauge.restrict_penalty(&candidate.matrix);
5549                // Re-normalize in the *actual* coefficient chart used by the
5550                // fit.  The tensor sum-to-zero transform is not norm-preserving
5551                // for each overlapping marginal penalty, so carrying the raw
5552                // marginal Frobenius scale into the restricted space changes the
5553                // relative amount of smoothing seen by the LAML/REML optimizer.
5554                // Keep the physical scale in metadata and give the optimizer
5555                // unit-scale constrained penalties for every tensor margin.
5556                let (matrix, c_new) = normalize_penalty_in_constrained_space(&matrix);
5557                Ok(PenaltyCandidate {
5558                    nullspace_dim_hint: candidate.nullspace_dim_hint,
5559                    matrix,
5560                    source: candidate.source,
5561                    normalization_scale: candidate.normalization_scale * c_new,
5562                    // Z^T S Z is no longer a Kronecker product of the original
5563                    // marginal factors, so the Kronecker fast path in construction.rs
5564                    // must not be taken. Clearing kronecker_factors forces the generic
5565                    // block-local eigendecomposition path, which operates on the
5566                    // transformed matrix and is correct.
5567                    kronecker_factors: None,
5568                    op: candidate.op.clone(),
5569                })
5570            })
5571            .collect::<Result<Vec<_>, _>>()?;
5572    }
5573
5574    let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
5575        filter_active_penalty_candidates_with_ops(candidates)?;
5576    let identifiability_is_none =
5577        matches!(spec.identifiability, TensorBSplineIdentifiability::None);
5578    // All marginals expose a sparse representation iff each `marginal_sparse`
5579    // slot is `Some(...)`. Currently this is true when every marginal is a
5580    // free-boundary, non-periodic 1D B-spline returned as
5581    // `DesignMatrix::Sparse` from `build_bspline_basis_1d`. Periodic B-splines
5582    // and other dense-only marginals leave a `None` and trigger the fall-back
5583    // path. Identifiability transforms (`SumToZero`, `FrozenTransform`) make
5584    // the tensor design dense in general, so we also gate on that.
5585    let all_marginals_sparse = marginal_sparse.iter().all(Option::is_some);
5586    let design = if let Some(dense_design) = dense_design {
5587        DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense_design))
5588    } else if identifiability_is_none && all_marginals_sparse {
5589        // Sparse Khatri-Rao path: assemble the (n, ∏ q_j) tensor product
5590        // directly as a SparseColMat, preserving the ∏(degree_j+1) nonzero
5591        // structure per row instead of densifying to ∏ q_j columns. This is
5592        // mathematically identical to `tensor_product_design_from_marginals`
5593        // applied to the corresponding dense marginals.
5594        let sparse_marginals: Vec<&SparseColMat<usize, f64>> = marginal_sparse
5595            .iter()
5596            .map(|m| m.as_ref().expect("all_marginals_sparse just verified"))
5597            .collect();
5598        let sparse_design = tensor_product_design_from_sparse_marginals(&sparse_marginals)?;
5599        DesignMatrix::Sparse(gam_linalg::matrix::SparseDesignMatrix::new(sparse_design))
5600    } else {
5601        let marginals: Vec<Arc<Array2<f64>>> = marginal_designs
5602            .iter()
5603            .map(|m| Arc::new(m.clone()))
5604            .collect();
5605        let op = TensorProductDesignOperator::new(marginals).map_err(|e| {
5606            BasisError::InvalidInput(format!("TensorProductDesignOperator build failed: {e}"))
5607        })?;
5608        DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(op)))
5609    };
5610
5611    Ok(BasisBuildResult {
5612        design,
5613        penalties,
5614        nullspace_dims,
5615        penaltyinfo,
5616        ops,
5617        null_eigenvectors,
5618        joint_null_rotation: None,
5619        metadata: BasisMetadata::TensorBSpline {
5620            feature_cols: feature_cols.to_vec(),
5621            knots: marginal_knots,
5622            degrees: marginal_degrees,
5623            // Prefer the per-margin effective period derived in the loop —
5624            // it captures both the explicit `spec.periods` route and the
5625            // implied period from a `PeriodicUniform` marginal knotspec.
5626            // Falling back to `spec.periods` when populated keeps any
5627            // user-supplied explicit period authoritative even if the
5628            // marginal knotspec carried no periodicity hint.
5629            periods: marginal_effective_periods,
5630            is_cr: marginal_is_cr_flags,
5631            identifiability_transform: z_opt,
5632        },
5633        kronecker_factored: if matches!(spec.identifiability, TensorBSplineIdentifiability::None)
5634            && matches!(
5635                spec.penalty_decomposition,
5636                TensorBSplinePenaltyDecomposition::MarginalKroneckerSum
5637            ) {
5638            Some(KroneckerFactoredBasis::new(
5639                marginal_designs,
5640                kronecker_marginal_penalties,
5641                marginalnum_basis.clone(),
5642                spec.double_penalty,
5643            ))
5644        } else {
5645            None
5646        },
5647    })
5648}
5649
5650pub fn tensor_product_design_from_marginals(
5651    marginal_designs: &[Array2<f64>],
5652) -> Result<Array2<f64>, BasisError> {
5653    if marginal_designs.is_empty() {
5654        crate::bail_invalid_basis!("TensorBSpline requires at least one marginal basis");
5655    }
5656    let n = marginal_designs[0].nrows();
5657    for (i, b) in marginal_designs.iter().enumerate().skip(1) {
5658        if b.nrows() != n {
5659            crate::bail_dim_basis!(
5660                "tensor marginal row mismatch at dim {i}: expected {n}, got {}",
5661                b.nrows()
5662            );
5663        }
5664    }
5665    let total_cols = marginal_designs.iter().try_fold(1usize, |acc, b| {
5666        acc.checked_mul(b.ncols())
5667            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))
5668    })?;
5669    // Tensor-product Khatri-Rao: design[i, j] = Π_d marginal_d[i, j_d]
5670    // where j is the multi-index (j_1, ..., j_D) flattened. Independent
5671    // across rows; parallelize row chunks and fill the pre-allocated
5672    // contiguous Array2 in place (no Vec-flatten-collect intermediate,
5673    // which doubled the peak memory at large-scale N).
5674    use ndarray::parallel::prelude::*;
5675    use rayon::iter::{IntoParallelIterator, ParallelIterator};
5676    let mut design = Array2::<f64>::zeros((n, total_cols));
5677    design
5678        .axis_chunks_iter_mut(ndarray::Axis(0), 1024)
5679        .into_par_iter()
5680        .enumerate()
5681        .for_each(|(chunk_idx, mut block)| {
5682            let row_offset = chunk_idx * 1024;
5683            // Scratch buffers reused across rows in this chunk.
5684            let mut cur = Vec::<f64>::with_capacity(total_cols);
5685            let mut next = Vec::<f64>::with_capacity(total_cols);
5686            for (local_i, mut out_row) in block.outer_iter_mut().enumerate() {
5687                let i = row_offset + local_i;
5688                cur.clear();
5689                cur.push(1.0);
5690                for b in marginal_designs {
5691                    let q = b.ncols();
5692                    next.clear();
5693                    next.resize(cur.len() * q, 0.0);
5694                    // Hoist the row view out of the inner `col` loop so the
5695                    // q reads per `a_idx` reuse a single contiguous slice
5696                    // instead of recomputing `b[[i, col]]` strides per cell.
5697                    let b_row = b.row(i);
5698                    let b_slice = b_row
5699                        .as_slice()
5700                        .expect("Array2 row from outer_iter is contiguous");
5701                    for (a_idx, &aval) in cur.iter().enumerate() {
5702                        let off = a_idx * q;
5703                        let dst = &mut next[off..off + q];
5704                        for col in 0..q {
5705                            dst[col] = aval * b_slice[col];
5706                        }
5707                    }
5708                    std::mem::swap(&mut cur, &mut next);
5709                }
5710                // `out_row` is a row of the contiguous C-major `design`
5711                // Array2, so it is backed by a contiguous slice. Use a
5712                // bulk slice copy instead of an element-by-element write
5713                // loop.
5714                let out_slice = out_row
5715                    .as_slice_mut()
5716                    .expect("design row is contiguous in C-major Array2");
5717                out_slice.copy_from_slice(&cur);
5718            }
5719        });
5720    Ok(design)
5721}
5722
5723pub fn build_random_effect_block(
5724    data: ArrayView2<'_, f64>,
5725    spec: &RandomEffectTermSpec,
5726) -> Result<RandomEffectBlock, BasisError> {
5727    let n = data.nrows();
5728    let p = data.ncols();
5729    if spec.feature_col >= p {
5730        crate::bail_dim_basis!(
5731            "random-effect term '{}' feature column {} out of bounds for {} columns",
5732            spec.name,
5733            spec.feature_col,
5734            p
5735        );
5736    }
5737
5738    let col = data.column(spec.feature_col);
5739    if col.iter().any(|v| !v.is_finite()) {
5740        crate::bail_invalid_basis!(
5741            "random-effect term '{}' contains non-finite group values",
5742            spec.name
5743        );
5744    }
5745
5746    let kept_levels: Vec<u64> = if let Some(levels) = spec.frozen_levels.as_ref() {
5747        if levels.is_empty() {
5748            crate::bail_invalid_basis!(
5749                "random-effect term '{}' has empty frozen_levels",
5750                spec.name
5751            );
5752        }
5753        levels.clone()
5754    } else {
5755        let mut seen = BTreeSet::<u64>::new();
5756        let mut levels = Vec::<u64>::new();
5757        for &v in col {
5758            let bits = v.to_bits();
5759            if seen.insert(bits) {
5760                levels.push(bits);
5761            }
5762        }
5763        if levels.is_empty() {
5764            crate::bail_invalid_basis!("random-effect term '{}' has no observed levels", spec.name);
5765        }
5766        let start_idx = if spec.drop_first_level && levels.len() > 1 {
5767            1usize
5768        } else {
5769            0usize
5770        };
5771        levels[start_idx..].to_vec()
5772    };
5773
5774    if kept_levels.is_empty() {
5775        crate::bail_invalid_basis!(
5776            "random-effect term '{}' drops all levels; keep at least one level",
5777            spec.name
5778        );
5779    }
5780
5781    let q = kept_levels.len();
5782    let mut level_to_col = BTreeMap::<u64, usize>::new();
5783    for (idx, &bits) in kept_levels.iter().enumerate() {
5784        if level_to_col.insert(bits, idx).is_some() {
5785            crate::bail_invalid_basis!(
5786                "random-effect term '{}' has duplicate frozen level bits {bits}",
5787                spec.name
5788            );
5789        }
5790    }
5791    let mut group_ids = Vec::with_capacity(n);
5792    for &v in col {
5793        let bits = v.to_bits();
5794        group_ids.push(level_to_col.get(&bits).copied());
5795    }
5796
5797    Ok(RandomEffectBlock {
5798        name: spec.name.clone(),
5799        group_ids,
5800        num_groups: q,
5801        kept_levels,
5802    })
5803}
5804
5805impl SmoothDesign {
5806    /// Map an unconstrained term coefficient vector to its constrained shape space.
5807    /// This is useful for nonlinear fits that optimize unconstrained parameters.
5808    pub fn map_term_coefficients(
5809        unconstrained: &Array1<f64>,
5810        shape: ShapeConstraint,
5811    ) -> Result<Array1<f64>, BasisError> {
5812        if unconstrained.is_empty() {
5813            crate::bail_invalid_basis!("unconstrained coefficient vector cannot be empty");
5814        }
5815        let mapped = match shape {
5816            ShapeConstraint::None => unconstrained.clone(),
5817            ShapeConstraint::MonotoneIncreasing => cumulative_exp(unconstrained, 1.0),
5818            ShapeConstraint::MonotoneDecreasing => cumulative_exp(unconstrained, -1.0),
5819            ShapeConstraint::Convex => second_cumulative_exp(unconstrained, 1.0),
5820            ShapeConstraint::Concave => second_cumulative_exp(unconstrained, -1.0),
5821        };
5822        Ok(mapped)
5823    }
5824}
5825
/// Build products for one smooth term: its design matrix, penalty blocks and
/// the bookkeeping that accompanies them.
///
/// NOTE(review): the role is inferred from the type name and the field
/// documentation below; the code that constructs and consumes this struct is
/// outside this part of the file.
pub struct LocalSmoothTermBuild {
    /// NOTE(review): presumably the number of coefficient columns of this
    /// term — confirm against the constructor.
    pub dim: usize,
    /// Design matrix of this term.
    pub design: DesignMatrix,
    /// Active penalty matrices. `ops`, `nullspaces` and `null_eigenvectors`
    /// are parallel to this list.
    pub penalties: Vec<Array2<f64>>,
    /// Optional operator form of each active penalty, parallel to `penalties`.
    pub ops: Vec<Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>>,
    /// Null-space dimension of each active penalty, parallel to `penalties`.
    pub nullspaces: Vec<usize>,
    /// Per-active-penalty null-space eigenvector matrices, parallel to
    /// `penalties` / `ops` / `nullspaces`. `Some(U_null)` when
    /// `nullspaces[k] > 0`, with `U_null` orthonormal columns spanning
    /// `null(penalties[k])` in this smooth's local coordinate system; `None`
    /// when the active block is full-rank. Stage 1 plumbing; Stage 2
    /// consumes this to absorb the smooth's null space into the parametric
    /// block at `TermCollectionDesign` construction.
    pub null_eigenvectors: Vec<Option<Array2<f64>>>,
    /// Joint-null absorption rotation for this smooth. `Some(rotation)`
    /// records `Q = [U_range | U_null]` spanning `null(Σ_k penalties[k])`,
    /// the joint null across all active penalty blocks on this smooth.
    /// `None` means the joint penalty is full-rank (joint nullity = 0) or
    /// there are no penalties. Stage-2 commit A: plumbing only — populated
    /// by commit B, applied by commit D.
    pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
    /// Descriptive records for the penalties.
    /// NOTE(review): presumably one entry per active penalty — confirm.
    pub penaltyinfo: Vec<PenaltyInfo>,
    /// NOTE(review): judging by the name, records for penalties that were
    /// dropped before the active set was formed — confirm against the builder.
    pub pre_dropped_penaltyinfo: Vec<PenaltyInfo>,
    /// Basis metadata describing how this term's basis was constructed.
    pub metadata: BasisMetadata,
    /// Linear inequality constraints on the coefficients, when present.
    pub linear_constraints: Option<LinearInequalityConstraints>,
    /// NOTE(review): presumably flags that the term uses a box
    /// reparameterization for its shape constraint — confirm.
    pub box_reparam: bool,
    /// Kronecker-factored representation of the basis, when one is available.
    pub kronecker_factored: Option<KroneckerFactoredBasis>,
}
5854
/// Design operator backed by a memory-mapped `.npy` file of PCA scores.
///
/// Elements are decoded on demand from the mapping as little-endian `f64`
/// values in row-major order, so the matrix is never held in memory as a
/// whole unless it is explicitly materialized.
#[derive(Clone)]
pub struct PcaScoresMemmapDesignOperator {
    /// Shared read-only mapping of the whole `.npy` file (header included).
    mmap: Arc<memmap2::Mmap>,
    /// Byte offset of the first matrix element, i.e. the end of the `.npy` header.
    data_offset: usize,
    /// Number of rows, taken from the header's shape.
    nrows: usize,
    /// Number of columns, taken from the header's shape.
    ncols: usize,
    /// Requested rows per processing chunk; `open` stores at least 1.
    chunk_size: usize,
}
5863
5864impl PcaScoresMemmapDesignOperator {
5865    fn open(path: PathBuf, chunk_size: usize) -> Result<Self, BasisError> {
5866        let file = File::open(&path).map_err(|err| {
5867            BasisError::InvalidInput(format!(
5868                "failed to open lazy Pca .npy scores '{}': {err}",
5869                path.display()
5870            ))
5871        })?;
5872        // The .npy scores file is read-only training-cache data; this
5873        // module never mutates it. The error path below converts mmap
5874        // failure to a typed `BasisError::InvalidInput`.
5875        // SAFETY: `memmap2::Mmap::map` requires no concurrent writers; the
5876        // contract is held by this module's read-only access pattern.
5877        let mmap = unsafe {
5878            memmap2::Mmap::map(&file).map_err(|err| {
5879                BasisError::InvalidInput(format!(
5880                    "failed to memmap lazy Pca .npy scores '{}': {err}",
5881                    path.display()
5882                ))
5883            })?
5884        };
5885        let (data_offset, nrows, ncols) = parse_f64_2d_npy_header(&mmap, &path)?;
5886        let expected = data_offset
5887            .checked_add(nrows.saturating_mul(ncols).saturating_mul(8))
5888            .ok_or_else(|| {
5889                BasisError::InvalidInput(format!(
5890                    "lazy Pca .npy scores '{}' shape is too large",
5891                    path.display()
5892                ))
5893            })?;
5894        if mmap.len() < expected {
5895            crate::bail_invalid_basis!(
5896                "lazy Pca .npy scores '{}' is truncated: header expects {} bytes, file has {}",
5897                path.display(),
5898                expected,
5899                mmap.len()
5900            );
5901        }
5902        Ok(Self {
5903            mmap: Arc::new(mmap),
5904            data_offset,
5905            nrows,
5906            ncols,
5907            chunk_size: chunk_size.max(1),
5908        })
5909    }
5910
5911    fn value(&self, row: usize, col: usize) -> f64 {
5912        let offset = self.data_offset + (row * self.ncols + col) * 8;
5913        let mut bytes = [0_u8; 8];
5914        bytes.copy_from_slice(&self.mmap[offset..offset + 8]);
5915        f64::from_le_bytes(bytes)
5916    }
5917
5918    fn chunk_rows(&self) -> usize {
5919        self.chunk_size.min(self.nrows.max(1))
5920    }
5921}
5922
5923impl LinearOperator for PcaScoresMemmapDesignOperator {
5924    fn nrows(&self) -> usize {
5925        self.nrows
5926    }
5927
5928    fn ncols(&self) -> usize {
5929        self.ncols
5930    }
5931
5932    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
5933        assert_eq!(
5934            vector.len(),
5935            self.ncols,
5936            "lazy Pca apply vector length mismatch"
5937        );
5938        let mut out = Array1::<f64>::zeros(self.nrows);
5939        for start in (0..self.nrows).step_by(self.chunk_rows()) {
5940            let end = (start + self.chunk_rows()).min(self.nrows);
5941            for row in start..end {
5942                let mut acc = 0.0;
5943                for col in 0..self.ncols {
5944                    acc += self.value(row, col) * vector[col];
5945                }
5946                out[row] = acc;
5947            }
5948        }
5949        out
5950    }
5951
5952    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
5953        assert_eq!(
5954            vector.len(),
5955            self.nrows,
5956            "lazy Pca apply_transpose vector length mismatch"
5957        );
5958        let mut out = Array1::<f64>::zeros(self.ncols);
5959        for start in (0..self.nrows).step_by(self.chunk_rows()) {
5960            let end = (start + self.chunk_rows()).min(self.nrows);
5961            for row in start..end {
5962                let scale = vector[row];
5963                if scale == 0.0 {
5964                    continue;
5965                }
5966                for col in 0..self.ncols {
5967                    out[col] += scale * self.value(row, col);
5968                }
5969            }
5970        }
5971        out
5972    }
5973
5974    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
5975        if weights.len() != self.nrows {
5976            return Err(format!(
5977                "lazy Pca diag_xtw_x weight length mismatch: weights={}, nrows={}",
5978                weights.len(),
5979                self.nrows
5980            ));
5981        }
5982        let mut gram = Array2::<f64>::zeros((self.ncols, self.ncols));
5983        for start in (0..self.nrows).step_by(self.chunk_rows()) {
5984            let end = (start + self.chunk_rows()).min(self.nrows);
5985            for row in start..end {
5986                let w = weights[row];
5987                if w == 0.0 {
5988                    continue;
5989                }
5990                for a in 0..self.ncols {
5991                    let xa = self.value(row, a);
5992                    if xa == 0.0 {
5993                        continue;
5994                    }
5995                    for b in a..self.ncols {
5996                        gram[[a, b]] += w * xa * self.value(row, b);
5997                    }
5998                }
5999            }
6000        }
6001        for a in 0..self.ncols {
6002            for b in 0..a {
6003                gram[[a, b]] = gram[[b, a]];
6004            }
6005        }
6006        Ok(gram)
6007    }
6008
6009    fn apply_weighted_normal(
6010        &self,
6011        weights: &Array1<f64>,
6012        vector: &Array1<f64>,
6013        penalty: Option<&Array2<f64>>,
6014        ridge: f64,
6015    ) -> Array1<f64> {
6016        assert_eq!(
6017            weights.len(),
6018            self.nrows,
6019            "lazy Pca weighted-normal weight mismatch"
6020        );
6021        assert_eq!(
6022            vector.len(),
6023            self.ncols,
6024            "lazy Pca weighted-normal vector mismatch"
6025        );
6026        let mut out = Array1::<f64>::zeros(self.ncols);
6027        for start in (0..self.nrows).step_by(self.chunk_rows()) {
6028            let end = (start + self.chunk_rows()).min(self.nrows);
6029            for row in start..end {
6030                let w = weights[row].max(0.0);
6031                if w == 0.0 {
6032                    continue;
6033                }
6034                let mut row_dot = 0.0;
6035                for col in 0..self.ncols {
6036                    row_dot += self.value(row, col) * vector[col];
6037                }
6038                if row_dot == 0.0 {
6039                    continue;
6040                }
6041                let scaled = w * row_dot;
6042                for col in 0..self.ncols {
6043                    out[col] += scaled * self.value(row, col);
6044                }
6045            }
6046        }
6047        if let Some(pen) = penalty {
6048            out += &pen.dot(vector);
6049        }
6050        if ridge > 0.0 {
6051            out += &vector.mapv(|x| ridge * x);
6052        }
6053        out
6054    }
6055}
6056
6057impl DenseDesignOperator for PcaScoresMemmapDesignOperator {
6058    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
6059        if weights.len() != self.nrows || y.len() != self.nrows {
6060            return Err(format!(
6061                "lazy Pca compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
6062                weights.len(),
6063                y.len(),
6064                self.nrows
6065            ));
6066        }
6067        let mut out = Array1::<f64>::zeros(self.ncols);
6068        for start in (0..self.nrows).step_by(self.chunk_rows()) {
6069            let end = (start + self.chunk_rows()).min(self.nrows);
6070            for row in start..end {
6071                let scale = weights[row] * y[row];
6072                if scale == 0.0 {
6073                    continue;
6074                }
6075                for col in 0..self.ncols {
6076                    out[col] += scale * self.value(row, col);
6077                }
6078            }
6079        }
6080        Ok(out)
6081    }
6082
6083    fn row_chunk_into(
6084        &self,
6085        rows: Range<usize>,
6086        mut out: ArrayViewMut2<'_, f64>,
6087    ) -> Result<(), MatrixMaterializationError> {
6088        if rows.end > self.nrows || rows.start > rows.end {
6089            return Err(MatrixMaterializationError::MissingRowChunk {
6090                context: "lazy Pca row range out of bounds",
6091            });
6092        }
6093        if out.nrows() != rows.end - rows.start || out.ncols() != self.ncols {
6094            return Err(MatrixMaterializationError::MissingRowChunk {
6095                context: "lazy Pca row_chunk_into shape mismatch",
6096            });
6097        }
6098        for (local, row) in (rows.start..rows.end).enumerate() {
6099            for col in 0..self.ncols {
6100                out[[local, col]] = self.value(row, col);
6101            }
6102        }
6103        Ok(())
6104    }
6105
6106    fn to_dense(&self) -> Array2<f64> {
6107        let mut out = Array2::<f64>::zeros((self.nrows, self.ncols));
6108        self.row_chunk_into(0..self.nrows, out.view_mut())
6109            .expect("lazy Pca full materialization failed");
6110        out
6111    }
6112}
6113
6114pub fn parse_f64_2d_npy_header(
6115    bytes: &[u8],
6116    path: &PathBuf,
6117) -> Result<(usize, usize, usize), BasisError> {
6118    if bytes.len() < 10 || &bytes[0..6] != b"\x93NUMPY" {
6119        crate::bail_invalid_basis!("lazy Pca scores '{}' is not a .npy file", path.display());
6120    }
6121    let major = bytes[6];
6122    let header_len = match major {
6123        1 => u16::from_le_bytes([bytes[8], bytes[9]]) as usize,
6124        2 | 3 => {
6125            if bytes.len() < 12 {
6126                crate::bail_invalid_basis!(
6127                    "lazy Pca scores '{}' has a truncated .npy header",
6128                    path.display()
6129                );
6130            }
6131            u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize
6132        }
6133        other => {
6134            crate::bail_invalid_basis!(
6135                "lazy Pca scores '{}' uses unsupported .npy version {}",
6136                path.display(),
6137                other
6138            );
6139        }
6140    };
6141    let header_start = if major == 1 { 10 } else { 12 };
6142    let data_offset = header_start + header_len;
6143    if bytes.len() < data_offset {
6144        crate::bail_invalid_basis!(
6145            "lazy Pca scores '{}' has a truncated .npy header",
6146            path.display()
6147        );
6148    }
6149    let header = std::str::from_utf8(&bytes[header_start..data_offset]).map_err(|err| {
6150        BasisError::InvalidInput(format!(
6151            "lazy Pca scores '{}' has a non-UTF8 .npy header: {err}",
6152            path.display()
6153        ))
6154    })?;
6155    if !(header.contains("'descr': '<f8'")
6156        || header.contains("\"descr\": \"<f8\"")
6157        || header.contains("'descr': '|f8'")
6158        || header.contains("\"descr\": \"|f8\""))
6159    {
6160        crate::bail_invalid_basis!(
6161            "lazy Pca scores '{}' must be float64 little-endian .npy",
6162            path.display()
6163        );
6164    }
6165    if header.contains("True") {
6166        crate::bail_invalid_basis!(
6167            "lazy Pca scores '{}' must be C-contiguous, not Fortran-ordered",
6168            path.display()
6169        );
6170    }
6171    let shape_pos = header.find("shape").ok_or_else(|| {
6172        BasisError::InvalidInput(format!(
6173            "lazy Pca scores '{}' .npy header is missing shape",
6174            path.display()
6175        ))
6176    })?;
6177    let open = header[shape_pos..].find('(').ok_or_else(|| {
6178        BasisError::InvalidInput(format!(
6179            "lazy Pca scores '{}' .npy header has malformed shape",
6180            path.display()
6181        ))
6182    })? + shape_pos;
6183    let close = header[open..].find(')').ok_or_else(|| {
6184        BasisError::InvalidInput(format!(
6185            "lazy Pca scores '{}' .npy header has malformed shape",
6186            path.display()
6187        ))
6188    })? + open;
6189    let dims = header[open + 1..close]
6190        .split(',')
6191        .map(str::trim)
6192        .filter(|part| !part.is_empty())
6193        .map(|part| part.parse::<usize>())
6194        .collect::<Result<Vec<_>, _>>()
6195        .map_err(|err| {
6196            BasisError::InvalidInput(format!(
6197                "lazy Pca scores '{}' .npy shape is not integral: {err}",
6198                path.display()
6199            ))
6200        })?;
6201    if dims.len() != 2 {
6202        crate::bail_invalid_basis!(
6203            "lazy Pca scores '{}' must have shape (N, K), got {:?}",
6204            path.display(),
6205            dims
6206        );
6207    }
6208    Ok((data_offset, dims[0], dims[1]))
6209}
6210
6211pub fn pca_center_mean(x: ArrayView2<'_, f64>) -> Result<Array1<f64>, BasisError> {
6212    if x.nrows() == 0 {
6213        crate::bail_invalid_basis!("Pca basis requires at least one row to compute center mean");
6214    }
6215    let mut mean = Array1::<f64>::zeros(x.ncols());
6216    for row in x.rows() {
6217        mean += &row;
6218    }
6219    mean.mapv_inplace(|v| v / x.nrows() as f64);
6220    Ok(mean)
6221}
6222
6223pub fn build_pca_smooth_basis(
6224    data: ArrayView2<'_, f64>,
6225    feature_cols: &[usize],
6226    basis_matrix: &Array2<f64>,
6227    centered: bool,
6228    smooth_penalty: f64,
6229    center_mean: Option<&Array1<f64>>,
6230    pca_basis_path: Option<&PathBuf>,
6231    chunk_size: usize,
6232) -> Result<BasisBuildResult, BasisError> {
6233    if let Some(path) = pca_basis_path {
6234        let op = PcaScoresMemmapDesignOperator::open(path.clone(), chunk_size)?;
6235        if op.nrows != data.nrows() {
6236            crate::bail_dim_basis!(
6237                "lazy Pca scores row mismatch: .npy has {}, data has {}",
6238                op.nrows,
6239                data.nrows()
6240            );
6241        }
6242        let k = op.ncols;
6243        let mut penalty = Array2::<f64>::eye(k);
6244        penalty.mapv_inplace(|v| v * smooth_penalty);
6245        let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
6246            filter_active_penalty_candidates_with_ops(vec![PenaltyCandidate {
6247                matrix: penalty,
6248                nullspace_dim_hint: 0,
6249                source: PenaltySource::Other("PcaRidge".to_string()),
6250                normalization_scale: 1.0,
6251                kronecker_factors: None,
6252                op: None,
6253            }])?;
6254        return Ok(BasisBuildResult {
6255            design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(op))),
6256            penalties,
6257            nullspace_dims,
6258            penaltyinfo,
6259            ops,
6260            null_eigenvectors,
6261            joint_null_rotation: None,
6262            metadata: BasisMetadata::Pca {
6263                feature_cols: feature_cols.to_vec(),
6264                basis_matrix: basis_matrix.clone(),
6265                centered,
6266                smooth_penalty,
6267                center_mean: center_mean.cloned(),
6268                pca_basis_path: Some(path.clone()),
6269                chunk_size: chunk_size.max(1),
6270            },
6271            kronecker_factored: None,
6272        });
6273    }
6274    if basis_matrix.nrows() != feature_cols.len() {
6275        crate::bail_dim_basis!(
6276            "Pca basis row mismatch: basis rows={}, feature columns={}",
6277            basis_matrix.nrows(),
6278            feature_cols.len()
6279        );
6280    }
6281    let mut x = select_columns(data, feature_cols)?;
6282    let mean = if centered {
6283        match center_mean {
6284            Some(mean) => mean.clone(),
6285            None => pca_center_mean(x.view())?,
6286        }
6287    } else {
6288        Array1::<f64>::zeros(feature_cols.len())
6289    };
6290    if centered {
6291        for mut row in x.rows_mut() {
6292            row -= &mean;
6293        }
6294    }
6295    let design = fast_ab(&x, basis_matrix);
6296    let k = basis_matrix.ncols();
6297    let mut penalty = Array2::<f64>::eye(k);
6298    penalty.mapv_inplace(|v| v * smooth_penalty);
6299    let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
6300        filter_active_penalty_candidates_with_ops(vec![PenaltyCandidate {
6301            matrix: penalty,
6302            nullspace_dim_hint: 0,
6303            source: PenaltySource::Other("PcaRidge".to_string()),
6304            normalization_scale: 1.0,
6305            kronecker_factors: None,
6306            op: None,
6307        }])?;
6308    Ok(BasisBuildResult {
6309        design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(design)),
6310        penalties,
6311        nullspace_dims,
6312        penaltyinfo,
6313        ops,
6314        null_eigenvectors,
6315        joint_null_rotation: None,
6316        metadata: BasisMetadata::Pca {
6317            feature_cols: feature_cols.to_vec(),
6318            basis_matrix: basis_matrix.clone(),
6319            centered,
6320            smooth_penalty,
6321            center_mean: centered.then_some(mean),
6322            pca_basis_path: None,
6323            chunk_size: chunk_size.max(1),
6324        },
6325        kronecker_factored: None,
6326    })
6327}
6328
6329/// A factor-level `by=` wrapper owns the model-space centering of its inner
6330/// smooth: it gates the raw/structurally-constrained basis to the level rows
6331/// and then centers that gated block exactly once against the level indicator
6332/// (`build_parametric_constraint_block_for_term` in `design_construction`).
6333/// Leaving the inner B-spline's default pooled weighted-sum-to-zero active here
6334/// would impose two generically-independent constraints — the pooled column
6335/// moment `m = Σ_h m_h` and the per-level moment `m_g` — so a raw `k`-column
6336/// basis collapses to `k-2` columns per level instead of `k-1`, deleting one
6337/// genuine nonconstant spline direction *before REML runs* (#1427). The group
6338/// main effect carries only the constant, so it cannot restore that direction.
6339///
6340/// Only the *default model-space* centering is deferred. Explicit structural or
6341/// frozen transforms (`RemoveLinearTrend`, `OrthogonalToDesignColumns`,
6342/// `FrozenTransform`, `None`) are user/structural choices and are preserved
6343/// verbatim.
6344pub fn defer_inner_model_centering_to_factor_level_wrapper(basis: &mut SmoothBasisSpec) {
6345    if let SmoothBasisSpec::BSpline1D { spec, .. } = basis
6346        && matches!(
6347            spec.identifiability,
6348            BSplineIdentifiability::WeightedSumToZero { .. }
6349        )
6350    {
6351        spec.identifiability = BSplineIdentifiability::None;
6352    }
6353}
6354
6355pub fn apply_by_variable_to_local_build(
6356    mut built: LocalSmoothTermBuild,
6357    data: ArrayView2<'_, f64>,
6358    by_col: usize,
6359    by: &ByVariableSpec,
6360    term_name: &str,
6361) -> Result<LocalSmoothTermBuild, BasisError> {
6362    if by_col >= data.ncols() {
6363        crate::bail_dim_basis!(
6364            "by-variable smooth term '{term_name}' references column {by_col}, but data has {} columns",
6365            data.ncols()
6366        );
6367    }
6368    let weights = match by {
6369        ByVariableSpec::Numeric => data.column(by_col).to_owned(),
6370        ByVariableSpec::Level { value_bits, .. } => data.column(by_col).mapv(|value| {
6371            if value.to_bits() == *value_bits {
6372                1.0
6373            } else {
6374                0.0
6375            }
6376        }),
6377    };
6378    if weights.iter().any(|value| !value.is_finite()) {
6379        crate::bail_invalid_basis!(
6380            "by-variable smooth term '{term_name}' has non-finite by-column values"
6381        );
6382    }
6383
6384    let mut dense = built
6385        .design
6386        .try_to_dense_by_chunks("by-variable smooth row gating")
6387        .map_err(BasisError::InvalidInput)?;
6388    for (mut row, &weight) in dense.rows_mut().into_iter().zip(weights.iter()) {
6389        row.mapv_inplace(|value| value * weight);
6390    }
6391    built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense));
6392    built.kronecker_factored = None;
6393    Ok(built)
6394}
6395
6396/// Build the local smooth term for a `BySmooth` spec, which unifies numeric-by
6397/// and factor-by modulation into a single `SmoothTermSpec`.
6398///
6399/// For a **numeric** by-variable the inner smooth is built once and every row
6400/// is multiplied by the by-column value (identical to `ByVariable::Numeric`).
6401///
6402/// For a **factor** by-variable the inner smooth is built once and gated per
6403/// level into side-by-side column blocks, producing a `n × (L * p)` design
6404/// matrix.  The penalties are block-diagonalised (one copy of the inner penalty
6405/// per level) exactly as `build_factor_smooth` does for `bs="fs"/"sz"`.
6406pub fn build_by_smooth_local(
6407    data: ArrayView2<'_, f64>,
6408    term: &SmoothTermSpec,
6409    smooth: &SmoothBasisSpec,
6410    by_kind: &ByVarKind,
6411    workspace: &mut crate::basis::BasisWorkspace,
6412) -> Result<LocalSmoothTermBuild, BasisError> {
6413    let inner_term = SmoothTermSpec {
6414        name: term.name.clone(),
6415        basis: (*smooth).clone(),
6416        shape: term.shape,
6417        joint_null_rotation: None,
6418    };
6419    let inner = build_single_local_smooth_term(data, &inner_term, workspace)?;
6420
6421    match by_kind {
6422        ByVarKind::Numeric { feature_col } => {
6423            let inner_meta = inner.metadata.clone();
6424            let mut built = apply_by_variable_to_local_build(
6425                inner,
6426                data,
6427                *feature_col,
6428                &ByVariableSpec::Numeric,
6429                &term.name,
6430            )?;
6431            built.metadata = BasisMetadata::BySmooth {
6432                inner: Box::new(inner_meta),
6433                by_col: *feature_col,
6434                levels: None,
6435                ordered: false,
6436            };
6437            Ok(built)
6438        }
6439        ByVarKind::Factor {
6440            feature_col,
6441            frozen_levels,
6442            ordered,
6443        } => {
6444            // Collect factor levels: prefer the frozen set (replay path), else
6445            // scan the data column (first-fit path).
6446            let level_bits: Vec<u64> = if let Some(fl) = frozen_levels {
6447                fl.clone()
6448            } else {
6449                let col = data.column(*feature_col);
6450                let mut seen = BTreeSet::<u64>::new();
6451                for &v in col.iter() {
6452                    if v.is_finite() {
6453                        seen.insert(v.to_bits());
6454                    }
6455                }
6456                seen.into_iter().collect()
6457            };
6458            let n_levels = level_bits.len();
6459            if n_levels == 0 {
6460                crate::bail_invalid_basis!(
6461                    "by-factor smooth term '{}': factor column {} has no observed levels",
6462                    term.name,
6463                    feature_col
6464                );
6465            }
6466            let p = inner.dim;
6467            let q = n_levels * p;
6468            let n = data.nrows();
6469
6470            let inner_dense = inner
6471                .design
6472                .try_to_dense_by_chunks("by-factor smooth design gating")
6473                .map_err(BasisError::InvalidInput)?;
6474
6475            // Gate each level into its own p-wide column block.
6476            let mut combined = Array2::<f64>::zeros((n, q));
6477            for (lvl_idx, &bits) in level_bits.iter().enumerate() {
6478                let col_start = lvl_idx * p;
6479                for row in 0..n {
6480                    if data[[row, *feature_col]].to_bits() == bits {
6481                        combined
6482                            .slice_mut(s![row, col_start..col_start + p])
6483                            .assign(&inner_dense.row(row));
6484                    }
6485                }
6486            }
6487
6488            // Build per-level INDEPENDENT penalties (#1427): one copy of each
6489            // inner penalty per level, but each confined to that single level's
6490            // diagonal block, so every (level, inner-penalty) pair is its OWN
6491            // smoothing-parameter coordinate. `s(x, by=g)` selects the per-group
6492            // curve wiggliness independently — the design is block-diagonal and
6493            // block-separable, so a correct REML must reproduce gamfit's own
6494            // independent per-group fits. Tiling a single inner penalty across
6495            // every level (as the `bs="fs"` shared-λ random-effect construction
6496            // does) collapses all groups onto ONE λ, which cannot match uneven
6497            // per-level smoothness and degrades as data grows (under-recovery up
6498            // to ~16× at n=2000). Emit `n_levels * n_penalties` blocks instead.
6499            let inner_meta = inner.metadata.clone();
6500            let n_penalties = inner.penalties.len();
6501            let n_blocks = n_penalties.saturating_mul(n_levels);
6502            let mut penalties = Vec::<Array2<f64>>::with_capacity(n_blocks);
6503            let mut penaltyinfo = Vec::<PenaltyInfo>::with_capacity(n_blocks);
6504            let mut nullspaces = Vec::<usize>::with_capacity(n_blocks);
6505            for (pen_pos, s_inner) in inner.penalties.iter().enumerate() {
6506                for lvl in 0..n_levels {
6507                    let off = lvl * p;
6508                    let mut s_big = Array2::<f64>::zeros((q, q));
6509                    s_big
6510                        .slice_mut(s![off..off + p, off..off + p])
6511                        .assign(s_inner);
6512                    let (s_big, scale) = normalize_penalty_in_constrained_space(&s_big);
6513                    let mut info = inner.penaltyinfo[pen_pos].clone();
6514                    // Distinct original_index per (penalty, level) so each λ is a
6515                    // separate identifiable coordinate downstream.
6516                    info.original_index = pen_pos * n_levels + lvl;
6517                    info.normalization_scale *= scale;
6518                    // Each block now spans exactly ONE level → per-level nullity,
6519                    // not the tiled (× n_levels) hint of the shared construction.
6520                    info.kronecker_factors = None;
6521                    penalties.push(s_big);
6522                    penaltyinfo.push(info);
6523                    nullspaces.push(inner.nullspaces[pen_pos]);
6524                }
6525            }
6526
6527            let null_eigenvectors = vec![None; penalties.len()];
6528            let ops = vec![None; penalties.len()];
6529
6530            Ok(LocalSmoothTermBuild {
6531                dim: q,
6532                design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(combined)),
6533                penalties,
6534                ops,
6535                nullspaces,
6536                null_eigenvectors,
6537                joint_null_rotation: None,
6538                penaltyinfo,
6539                pre_dropped_penaltyinfo: inner.pre_dropped_penaltyinfo,
6540                metadata: BasisMetadata::BySmooth {
6541                    inner: Box::new(inner_meta),
6542                    by_col: *feature_col,
6543                    levels: Some(level_bits),
6544                    ordered: *ordered,
6545                },
6546                linear_constraints: None,
6547                box_reparam: false,
6548                kronecker_factored: None,
6549            })
6550        }
6551    }
6552}
6553
6554pub fn ensure_by_variable_specs_match(
6555    kind: &BySmoothKind,
6556    by: &ByVariableSpec,
6557    term_name: &str,
6558) -> Result<(), BasisError> {
6559    match (kind, by) {
6560        (BySmoothKind::Numeric, ByVariableSpec::Numeric) => Ok(()),
6561        (BySmoothKind::Level { level_bits }, ByVariableSpec::Level { value_bits, .. })
6562            if level_bits == value_bits =>
6563        {
6564            Ok(())
6565        }
6566        _ => Err(BasisError::InvalidInput(format!(
6567            "by-variable smooth term '{term_name}' has inconsistent by-variable specifications"
6568        ))),
6569    }
6570}
6571
6572/// Build a factor-smooth interaction basis (`bs="fs"`/`"sz"`/`"re"`).
6573///
6574/// A factor smooth replicates a shared marginal smooth in the continuous
6575/// covariate(s) once per level of a grouping factor, coupling all level blocks
6576/// through a *single* set of smoothing parameters (one per marginal penalty).
6577/// This is mgcv's `smooth.construct.fs.smooth.spec` realization and the
6578/// random-effect interpretation of a smooth: the per-level deviations are an
6579/// exchangeable family whose joint wiggliness/shrinkage is governed by the
6580/// shared λ, so the construction scales to many levels with a fixed parameter
6581/// count.
6582///
6583/// Flavours:
6584/// * `Fs` — full random factor-smooth. The marginal carries its wiggliness
6585///   penalty *and* a null-space ridge (double penalty), so the replicated
6586///   design is a proper full-rank random effect: each level's curve is shrunk
6587///   toward zero (intercept + linear trend included), recovering the mgcv
6588///   `bs="fs"` penalty structure `I_L ⊗ S_j` for every marginal penalty `S_j`.
6589/// * `Sz` — sum-to-zero factor smooth. Delegates to the existing
6590///   [`SmoothBasisSpec::FactorSumToZero`] construction (`L-1` deviation blocks,
6591///   coefficient-wise zero sum across levels).
6592/// * `Re` — pure random effect / random slope (`bs="re"`). A degree-1 marginal
6593///   gives the per-level `[1, x]` span; the penalty is the identity over each
6594///   level block (iid Gaussian coefficients), matching mgcv's `bs="re"` ridge.
6595///
6596/// The grouping levels are resolved once at fit time (sorted unique bit
6597/// patterns of the factor column) and frozen into the returned metadata so the
6598/// predict-time rebuild evaluates every row against its own level's block.
6599pub fn build_factor_smooth(
6600    data: ArrayView2<'_, f64>,
6601    spec: &FactorSmoothSpec,
6602    term_name: &str,
6603    workspace: &mut crate::basis::BasisWorkspace,
6604) -> Result<LocalSmoothTermBuild, BasisError> {
6605    if spec.continuous_cols.len() != 1 {
6606        crate::bail_invalid_basis!(
6607            "factor smooth term '{}' currently supports exactly one continuous covariate; found {}",
6608            term_name,
6609            spec.continuous_cols.len()
6610        );
6611    }
6612    let feature_col = spec.continuous_cols[0];
6613    let group_col = spec.group_col;
6614    if feature_col >= data.ncols() || group_col >= data.ncols() {
6615        crate::bail_dim_basis!(
6616            "factor smooth term '{}' references columns ({}, {}) out of bounds for {} columns",
6617            term_name,
6618            feature_col,
6619            group_col,
6620            data.ncols()
6621        );
6622    }
6623
6624    // `Sz` is exactly the existing sum-to-zero factor smooth: reuse it verbatim
6625    // so there is a single source of truth for the zero-sum construction.
6626    if matches!(spec.flavour, FactorSmoothFlavour::Sz) {
6627        let levels = resolve_factor_smooth_levels(data, group_col, spec, term_name)?;
6628        let inner = SmoothBasisSpec::BSpline1D {
6629            feature_col,
6630            spec: factor_smooth_marginal_for_replay(&spec.marginal),
6631        };
6632        let sz_term = SmoothTermSpec {
6633            name: term_name.to_string(),
6634            basis: SmoothBasisSpec::FactorSumToZero {
6635                inner: Box::new(inner),
6636                by_col: group_col,
6637                levels: levels.clone(),
6638                frozen_global_orthogonality: None,
6639            },
6640            shape: ShapeConstraint::None,
6641            joint_null_rotation: None,
6642        };
6643        let mut built = build_single_local_smooth_term(data, &sz_term, workspace)?;
6644        // The delegated `FactorSumToZero` build returns the BARE inner B-spline
6645        // metadata (`BasisMetadata::BSpline1D`), but the term that owns this
6646        // build carries a `SmoothBasisSpec::FactorSmooth { Sz }` spec. Two
6647        // things break if we hand that mismatched pair downstream:
6648        //   1. `freeze_smooth_basis_from_metadata` matches on (spec, metadata)
6649        //      and has no `(FactorSmooth, BSpline1D)` arm, so any refit / spatial
6650        //      re-optimization that freezes the basis aborts with a "smooth
6651        //      metadata/spec type mismatch" error.
6652        //   2. The bare B-spline metadata carries no grouping levels, so a
6653        //      predict-time rebuild cannot replay the SAME replicated design.
6654        // Re-wrap the marginal geometry as `FactorSmooth` metadata exactly as
6655        // the Fs/Re path below does, giving all three factor-smooth flavours a
6656        // single, freeze-consistent metadata shape that also pins the levels.
6657        // Since #1605 the sz marginal is ALWAYS the penalized B-spline the `fs`
6658        // sibling uses (a natural cubic regression marginal hard-enforces f''=0
6659        // at the boundary and cannot represent curved deviations — a consistency
6660        // failure). The `CubicRegression1D` arm below is therefore unreachable on
6661        // a freshly-built sz spec; it is retained only as defense / backward
6662        // compatibility for a frozen spec that still carries a cr marginal, so
6663        // the predict-time freeze restores whatever marginal class it finds.
6664        let (knots, degree, periodic, marginal_is_cr) = match &built.metadata {
6665            BasisMetadata::BSpline1D {
6666                knots,
6667                periodic,
6668                degree,
6669                ..
6670            } => (
6671                knots.clone(),
6672                degree.unwrap_or(spec.marginal.degree),
6673                *periodic,
6674                false,
6675            ),
6676            BasisMetadata::CubicRegression1D { knots, .. } => {
6677                (knots.clone(), spec.marginal.degree, None, true)
6678            }
6679            other => {
6680                crate::bail_invalid_basis!(
6681                    "sz factor smooth term '{}' produced an unexpected marginal metadata variant {:?}",
6682                    term_name,
6683                    other
6684                );
6685            }
6686        };
6687        built.metadata = BasisMetadata::FactorSmooth {
6688            continuous_cols: spec.continuous_cols.clone(),
6689            group_col,
6690            knots,
6691            degree,
6692            periodic,
6693            group_levels: levels,
6694            flavour: "sz".to_string(),
6695            marginal_is_cr,
6696        };
6697        return Ok(built);
6698    }
6699
6700    let levels = resolve_factor_smooth_levels(data, group_col, spec, term_name)?;
6701    let n_levels = levels.len();
6702    if n_levels < 2 {
6703        crate::bail_invalid_basis!(
6704            "factor smooth term '{}' requires at least two grouping levels; found {}",
6705            term_name,
6706            n_levels
6707        );
6708    }
6709
6710    // `Fs` (order ≥ 1, the default) is the random-effect flavour: it penalizes
6711    // each null-space dimension of the marginal wiggliness penalty separately
6712    // below (mgcv's `bs="fs"` construction). That replaces the marginal's single
6713    // *combined* double penalty, so disable the latter here to avoid penalizing
6714    // the null space twice (once combined, once per dimension). The explicit
6715    // `m=0` opt-out keeps the legacy combined double penalty and adds no
6716    // per-dimension penalties.
6717    let use_per_dim_null = matches!(
6718        &spec.flavour,
6719        FactorSmoothFlavour::Fs { m_null_penalty_orders }
6720            if m_null_penalty_orders.iter().copied().max().unwrap_or(0) >= 1
6721    );
6722
6723    // Build the shared marginal design + penalties from the 1-D B-spline.
6724    // `Re` forces a degree-1 marginal (linear span) and replaces the marginal
6725    // wiggliness with an identity ridge below; `Fs` keeps the user's marginal
6726    // (cubic by default) and, under the per-dimension null path, gets its null
6727    // space penalized one dimension at a time after replication.
6728    let mut marginal_spec = factor_smooth_marginal_for_replay(&spec.marginal);
6729    if use_per_dim_null {
6730        marginal_spec.double_penalty = false;
6731    }
6732    let inner_term = SmoothTermSpec {
6733        name: format!("{term_name}::marginal"),
6734        basis: SmoothBasisSpec::BSpline1D {
6735            feature_col,
6736            spec: marginal_spec,
6737        },
6738        shape: ShapeConstraint::None,
6739        joint_null_rotation: None,
6740    };
6741    let inner = build_single_local_smooth_term(data, &inner_term, workspace)?;
6742    let mut base = inner
6743        .design
6744        .try_to_dense_by_chunks("factor smooth marginal")
6745        .map_err(BasisError::InvalidInput)?;
6746    if matches!(spec.flavour, FactorSmoothFlavour::Re) {
6747        // `bs="re"` is a parametric random intercept+slope, not a B-spline
6748        // smooth evaluated through clamped knot support.  A degree-1 B-spline
6749        // with no internal knots spans the training rows, but outside the
6750        // boundary knots its basis is not the model matrix for `(1 + x | g)`;
6751        // held-out extrapolation then loses the random slope contribution.
6752        // Build the random-effect marginal directly as `[1, x - c]`, centered
6753        // at the frozen marginal domain, so fit-time and replay-time rows use
6754        // the same well-conditioned parametric columns on and off the training
6755        // interval.
6756        let center = match &inner.metadata {
6757            BasisMetadata::BSpline1D { knots, .. } if !knots.is_empty() => {
6758                0.5 * (knots[0] + knots[knots.len() - 1])
6759            }
6760            _ => 0.0,
6761        };
6762        let mut linear = Array2::<f64>::ones((data.nrows(), 2));
6763        linear
6764            .column_mut(1)
6765            .assign(&data.column(feature_col).mapv(|x| x - center));
6766        base = linear;
6767    }
6768    let n = base.nrows();
6769    let p = base.ncols();
6770    let q = p * n_levels;
6771
6772    // Block-diagonal replicated design: row i contributes its marginal row to
6773    // the column block owned by its grouping level, zeros elsewhere.
6774    let mut dense = Array2::<f64>::zeros((n, q));
6775    for i in 0..n {
6776        let bits = data[[i, group_col]].to_bits();
6777        let level_idx = levels.iter().position(|b| *b == bits).ok_or_else(|| {
6778            BasisError::InvalidInput(format!(
6779                "factor smooth term '{term_name}' saw an unseen grouping level at row {}",
6780                i + 1
6781            ))
6782        })?;
6783        let start = level_idx * p;
6784        dense
6785            .slice_mut(s![i, start..start + p])
6786            .assign(&base.row(i));
6787    }
6788
6789    // Penalties: replicate each marginal penalty into a block-diagonal
6790    // `I_L ⊗ S_j` so every level shares the same smoothing parameter λ_j (one
6791    // λ per marginal penalty), the defining feature of a factor smooth. For
6792    // `Re` the marginal penalty is replaced by one ridge per parametric
6793    // coordinate so intercept and slope variances can be learned separately.
6794    let marginal_penalties: Vec<Array2<f64>> = if matches!(spec.flavour, FactorSmoothFlavour::Re) {
6795        (0..p)
6796            .map(|j| {
6797                let mut s = Array2::<f64>::zeros((p, p));
6798                s[[j, j]] = 1.0;
6799                s
6800            })
6801            .collect()
6802    } else {
6803        inner.penalties.clone()
6804    };
6805    let marginal_penaltyinfo: Vec<PenaltyInfo> = if matches!(spec.flavour, FactorSmoothFlavour::Re)
6806    {
6807        (0..p)
6808            .map(|j| PenaltyInfo {
6809                source: PenaltySource::Primary,
6810                original_index: j,
6811                active: true,
6812                effective_rank: 1,
6813                dropped_reason: None,
6814                nullspace_dim_hint: p.saturating_sub(1),
6815                normalization_scale: 1.0,
6816                kronecker_factors: None,
6817            })
6818            .collect()
6819    } else {
6820        inner.penaltyinfo.clone()
6821    };
6822    if marginal_penalties.len() != marginal_penaltyinfo.len() {
6823        crate::bail_invalid_basis!(
6824            "internal factor-smooth penalty metadata mismatch for term '{}': penalties={}, infos={}",
6825            term_name,
6826            marginal_penalties.len(),
6827            marginal_penaltyinfo.len()
6828        );
6829    }
6830
6831    let mut penalties = Vec::<Array2<f64>>::with_capacity(marginal_penalties.len());
6832    let mut penaltyinfo = Vec::<PenaltyInfo>::with_capacity(marginal_penalties.len());
6833    for (penalty_pos, s_inner) in marginal_penalties.iter().enumerate() {
6834        let mut s_big = Array2::<f64>::zeros((q, q));
6835        for level in 0..n_levels {
6836            let start = level * p;
6837            s_big
6838                .slice_mut(s![start..start + p, start..start + p])
6839                .assign(s_inner);
6840        }
6841        let (s_big, factor_smooth_scale) = normalize_penalty_in_constrained_space(&s_big);
6842        let mut info = marginal_penaltyinfo[penalty_pos].clone();
6843        info.original_index = penalty_pos;
6844        info.normalization_scale *= factor_smooth_scale;
6845        info.nullspace_dim_hint = info.nullspace_dim_hint.saturating_mul(n_levels);
6846        info.kronecker_factors = None;
6847        penalties.push(s_big);
6848        penaltyinfo.push(info);
6849    }
6850
6851    let mut nullspaces: Vec<usize> = if matches!(spec.flavour, FactorSmoothFlavour::Re) {
6852        vec![q.saturating_sub(n_levels); p]
6853    } else {
6854        inner
6855            .nullspaces
6856            .iter()
6857            .map(|ns| ns.saturating_mul(n_levels))
6858            .collect()
6859    };
6860
6861    // `Fs` is the random-effect flavour of a smooth: the per-group curve is an
6862    // exchangeable Gaussian *function*, so EVERY coefficient — including the
6863    // {const, linear} null space of the marginal wiggliness penalty — must be
6864    // shrinkable toward zero under its own shared variance. The wiggliness
6865    // penalty `S_wiggle` shapes curvature but leaves the per-group intercept and
6866    // slope (its null space) completely UNPENALIZED. With the null space free,
6867    // each group fits its own intercept and slope with NO partial pooling, so
6868    // the held-out per-subject forecast inherits the full no-pooling variance
6869    // and curves away from the true per-group line (gam#712 real arm, gam#713;
6870    // gam#903 sleepstudy forecast ran ~74% over the lme4 BLUP bar).
6871    //
6872    // mgcv's `bs="fs"` fixes this by penalizing each null-space dimension
6873    // SEPARATELY (`smooth.construct.fs.smooth.spec` adds one rank-1 penalty per
6874    // null coordinate), each replicated block-diagonally across levels under a
6875    // single shared smoothing parameter — so REML fits a distinct
6876    // random-intercept variance and random-slope variance, the partial pooling
6877    // that makes the forecast track lme4's correlated random-effect BLUP. A
6878    // single *combined* null penalty (one λ for intercept+slope together) cannot
6879    // express the typically very different intercept and slope variances, which
6880    // is the residual forecast gap. We mirror mgcv exactly: for each orthonormal
6881    // null direction `z_k` of the marginal wiggliness penalty add
6882    // `I_L ⊗ (z_k z_kᵀ)` as its own penalty. The marginal's combined double
6883    // penalty was disabled above, so the null space is penalized once, per
6884    // dimension. With linear data REML drives the curvature λ up and degrades
6885    // `fs` to a linear random slope (edf → ≈2/group); with genuine curvature the
6886    // wiggliness λ stays small and the wiggle survives (data-adaptive, not a
6887    // cap). Gated by `m_null_penalty_orders`: order ≥ 1 (default) enables the
6888    // per-dimension null penalties; `m=0` keeps the legacy combined double
6889    // penalty and adds nothing here.
6890    if use_per_dim_null
6891        && let Some(Some(z)) = inner.null_eigenvectors.first()
6892        && z.nrows() == p
6893    {
6894        for k in 0..z.ncols() {
6895            // Rank-1 marginal penalty `z_k z_kᵀ`, replicated block-diagonally
6896            // across levels into `I_L ⊗ (z_k z_kᵀ)`. Its own λ is one shared
6897            // variance for this null component (intercept or slope) across all
6898            // groups — the random-effect structure of mgcv `fs`.
6899            let zk = z.column(k);
6900            let mut p_k = Array2::<f64>::zeros((p, p));
6901            for a in 0..p {
6902                for b in 0..p {
6903                    p_k[[a, b]] = zk[a] * zk[b];
6904                }
6905            }
6906            let mut s_null = Array2::<f64>::zeros((q, q));
6907            for level in 0..n_levels {
6908                let start = level * p;
6909                s_null
6910                    .slice_mut(s![start..start + p, start..start + p])
6911                    .assign(&p_k);
6912            }
6913            let (s_null, null_scale) = normalize_penalty_in_constrained_space(&s_null);
6914            let null_block = crate::basis::analyze_penalty_block_with_op(&s_null, None)?;
6915            if null_block.rank > 0 {
6916                let original_index = penalties.len();
6917                penalties.push(null_block.sym_penalty);
6918                nullspaces.push(null_block.nullity);
6919                penaltyinfo.push(PenaltyInfo {
6920                    source: PenaltySource::Primary,
6921                    original_index,
6922                    active: true,
6923                    effective_rank: null_block.rank,
6924                    dropped_reason: None,
6925                    nullspace_dim_hint: null_block.nullity,
6926                    normalization_scale: null_scale,
6927                    kronecker_factors: None,
6928                });
6929            }
6930        }
6931    }
6932    let null_eigenvectors = crate::basis::recompute_null_eigenvectors(&penalties)?;
6933    let joint_null_rotation = crate::basis::compute_joint_null_rotation(&penalties)?;
6934
6935    // Metadata: carry the marginal knot geometry + frozen levels so prediction
6936    // reconstructs an identical replicated design.
6937    let (knots, degree, periodic) = match &inner.metadata {
6938        BasisMetadata::BSpline1D {
6939            knots,
6940            periodic,
6941            degree,
6942            ..
6943        } => (
6944            knots.clone(),
6945            degree.unwrap_or(spec.marginal.degree),
6946            *periodic,
6947        ),
6948        other => {
6949            crate::bail_invalid_basis!(
6950                "factor smooth term '{}' produced an unexpected marginal metadata variant {:?}",
6951                term_name,
6952                other
6953            );
6954        }
6955    };
6956    let flavour_tag = match &spec.flavour {
6957        FactorSmoothFlavour::Fs { .. } => "fs",
6958        FactorSmoothFlavour::Sz => "sz",
6959        FactorSmoothFlavour::Re => "re",
6960    }
6961    .to_string();
6962    let metadata = BasisMetadata::FactorSmooth {
6963        continuous_cols: spec.continuous_cols.clone(),
6964        group_col,
6965        knots,
6966        degree,
6967        periodic,
6968        group_levels: levels,
6969        flavour: flavour_tag,
6970        // fs/re marginals are always B-spline; the cr marginal is sz-only and
6971        // handled on the dedicated Sz path above.
6972        marginal_is_cr: false,
6973    };
6974
6975    let ops = vec![None; penalties.len()];
6976    Ok(LocalSmoothTermBuild {
6977        dim: q,
6978        design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense)),
6979        penalties,
6980        ops,
6981        nullspaces,
6982        null_eigenvectors,
6983        joint_null_rotation,
6984        penaltyinfo,
6985        pre_dropped_penaltyinfo: Vec::new(),
6986        metadata,
6987        linear_constraints: None,
6988        box_reparam: false,
6989        kronecker_factored: None,
6990    })
6991}
6992
6993/// Resolve the grouping levels for a factor smooth: replay the frozen level
6994/// list when present (predict path), otherwise discover the sorted unique bit
6995/// patterns of the factor column (fit path).
6996pub fn resolve_factor_smooth_levels(
6997    data: ArrayView2<'_, f64>,
6998    group_col: usize,
6999    spec: &FactorSmoothSpec,
7000    term_name: &str,
7001) -> Result<Vec<u64>, BasisError> {
7002    if let Some(frozen) = &spec.group_frozen_levels {
7003        if frozen.is_empty() {
7004            crate::bail_invalid_basis!(
7005                "factor smooth term '{}' has an empty frozen level list",
7006                term_name
7007            );
7008        }
7009        return Ok(frozen.clone());
7010    }
7011    let mut bits: Vec<u64> = data.column(group_col).iter().map(|v| v.to_bits()).collect();
7012    bits.sort_by(|a, b| {
7013        f64::from_bits(*a)
7014            .partial_cmp(&f64::from_bits(*b))
7015            .unwrap_or(std::cmp::Ordering::Equal)
7016    });
7017    bits.dedup();
7018    Ok(bits)
7019}
7020
7021/// Marginal B-spline spec for a factor-smooth block. The marginal always builds
7022/// without an identifiability constraint (the per-level replication, not a
7023/// sum-to-zero side constraint, provides identifiability against the parametric
7024/// block). At predict time the marginal's knot geometry has already been pinned
7025/// into `marginal.knotspec` by the metadata replay, so the spec is used
7026/// verbatim aside from clearing the identifiability transform.
7027pub fn factor_smooth_marginal_for_replay(marginal: &BSplineBasisSpec) -> BSplineBasisSpec {
7028    let mut m = marginal.clone();
7029    m.identifiability = BSplineIdentifiability::None;
7030    m
7031}
7032
7033pub fn build_single_local_smooth_term(
7034    data: ArrayView2<'_, f64>,
7035    term: &SmoothTermSpec,
7036    workspace: &mut crate::basis::BasisWorkspace,
7037) -> Result<LocalSmoothTermBuild, BasisError> {
7038    if term.shape != ShapeConstraint::None && !shape_supports_basis(term) {
7039        crate::bail_invalid_basis!(
7040            "ShapeConstraint::{:?} is unsupported for term '{}'",
7041            term.shape,
7042            term.name
7043        );
7044    }
7045    if let SmoothBasisSpec::ByVariable {
7046        inner,
7047        by_col,
7048        kind,
7049        by,
7050    } = &term.basis
7051    {
7052        ensure_by_variable_specs_match(kind, by, &term.name)?;
7053        let mut inner_basis = (**inner).clone();
7054        // Factor-level `by=` owns model-space centering (it centers the gated
7055        // block against the level indicator downstream). Defer the inner
7056        // basis's default pooled centering so the level block is not
7057        // double-centered down to `k-2` columns (#1427). Numeric-by smooths are
7058        // untouched: they are not row-gated to a level and keep ordinary
7059        // intercept centering.
7060        if matches!(by, ByVariableSpec::Level { .. }) {
7061            defer_inner_model_centering_to_factor_level_wrapper(&mut inner_basis);
7062        }
7063        let inner_term = SmoothTermSpec {
7064            name: term.name.clone(),
7065            basis: inner_basis,
7066            shape: term.shape,
7067            joint_null_rotation: None,
7068        };
7069        let built = build_single_local_smooth_term(data, &inner_term, workspace)?;
7070        return apply_by_variable_to_local_build(built, data, *by_col, by, &term.name);
7071    }
7072
7073    // BySmooth: a `by=` smooth that unifies numeric or factor modulation into a
7074    // single term.  Lower it here so the downstream match does not need an arm.
7075    if let SmoothBasisSpec::BySmooth { smooth, by_kind } = &term.basis {
7076        return build_by_smooth_local(data, term, smooth, by_kind, workspace);
7077    }
7078
7079    let mut shape_axis_col: Option<usize> = None;
7080    let mut built: BasisBuildResult = match &term.basis {
7081        SmoothBasisSpec::FactorSumToZero {
7082            inner,
7083            by_col,
7084            levels,
7085            ..
7086        } => {
7087            if *by_col >= data.ncols() {
7088                crate::bail_dim_basis!(
7089                    "term '{}' by column {} out of bounds for {} columns",
7090                    term.name,
7091                    by_col,
7092                    data.ncols()
7093                );
7094            }
7095            if levels.len() < 2 {
7096                crate::bail_invalid_basis!(
7097                    "sum-to-zero factor smooth term '{}' requires at least two levels",
7098                    term.name
7099                );
7100            }
7101            if term.shape != ShapeConstraint::None {
7102                crate::bail_invalid_basis!(
7103                    "ShapeConstraint::{:?} is unsupported for sum-to-zero factor smooth term '{}'",
7104                    term.shape,
7105                    term.name
7106                );
7107            }
7108            let inner_term = SmoothTermSpec {
7109                name: format!("{}::inner", term.name),
7110                basis: (**inner).clone(),
7111                shape: ShapeConstraint::None,
7112                joint_null_rotation: None,
7113            };
7114            let mut inner_built = build_single_local_smooth_term(data, &inner_term, workspace)?;
7115            // Capture the marginal penalty's null directions BEFORE the penalty
7116            // vector is rebuilt below; the sum-to-zero null-space ridge replicates
7117            // these `z_k` into the contrast space (mgcv `bs="fs"` double-penalty).
7118            let inner_null_eigenvectors = inner_built.null_eigenvectors.clone();
7119            let base = inner_built
7120                .design
7121                .try_to_dense_by_chunks("sum-to-zero factor smooth")
7122                .map_err(BasisError::InvalidInput)?;
7123            let n = base.nrows();
7124            let p = base.ncols();
7125            let l_minus_one = levels.len() - 1;
7126            let mut dense = Array2::<f64>::zeros((n, p * l_minus_one));
7127            for i in 0..n {
7128                let bits = data[[i, *by_col]].to_bits();
7129                let level_idx = levels.iter().position(|b| *b == bits).ok_or_else(|| {
7130                    BasisError::InvalidInput(format!(
7131                        "sum-to-zero factor smooth term '{}' saw an unseen level at row {}",
7132                        term.name,
7133                        i + 1
7134                    ))
7135                })?;
7136                if level_idx < l_minus_one {
7137                    let start = level_idx * p;
7138                    dense
7139                        .slice_mut(s![i, start..start + p])
7140                        .assign(&base.row(i));
7141                } else {
7142                    for level in 0..l_minus_one {
7143                        let start = level * p;
7144                        dense
7145                            .slice_mut(s![i, start..start + p])
7146                            .assign(&base.row(i).mapv(|v| -v));
7147                    }
7148                }
7149            }
7150            let mut penalties = Vec::<Array2<f64>>::with_capacity(inner_built.penalties.len());
7151            let active_penalty_indices = inner_built
7152                .penaltyinfo
7153                .iter()
7154                .enumerate()
7155                .filter_map(|(idx, info)| info.active.then_some(idx))
7156                .collect::<Vec<_>>();
7157            if active_penalty_indices.len() != inner_built.penalties.len() {
7158                crate::bail_invalid_basis!(
7159                    "internal sz penalty metadata mismatch: activeinfos={}, penalties={}",
7160                    active_penalty_indices.len(),
7161                    inner_built.penalties.len()
7162                );
7163            }
7164            // Replicate each marginal penalty into the sum-to-zero contrast
7165            // space. With `L-1` free deviation blocks and the reference level
7166            // `d_L = -Σ_{k<L} d_k`, the marginal penalty summed over ALL `L`
7167            // levels, `Σ_{k=1}^{L} d_kᵀ S d_k`, expands to the `(I + 11ᵀ) ⊗ S`
7168            // contrast form (factor 2 on the diagonal blocks, 1 off-diagonal).
7169            //
7170            // PER-GROUP SMOOTHING PARAMETERS (#1074). mgcv's `bs="sz"` does NOT
7171            // pool that sum under one λ: `smooth.construct.sz` emits ONE penalty
7172            // matrix per factor level (here 6 separate `S`s, each with its own
7173            // smoothing parameter), so REML can shrink a low-amplitude group's
7174            // deviation curve hard while leaving a high-amplitude group nearly
7175            // unpenalized. A single shared wiggliness λ (the old construction)
7176            // forces every group to the SAME curvature budget, so a group whose
7177            // true curve is flat drags curvature into the noise of the busy
7178            // groups and vice-versa — systematic truth-recovery loss even when
7179            // the pooled total edf matches mgcv's (the observed `sz` 1.23× gap).
7180            //
7181            // We mirror mgcv exactly by splitting the per-marginal penalty
7182            // `Σ_{k=1}^{L} d_kᵀ S d_k` back into its `L` independent
7183            // rank-controlled summands BEFORE mapping to the contrast space, each
7184            // carrying its own λ:
7185            //   * level k < L (free block):  `d_kᵀ S d_k` → block-diagonal
7186            //     `(e_k e_kᵀ) ⊗ S`  (only the (k,k) block is `S`).
7187            //   * level L (reference):       `d_Lᵀ S d_L = (Σ_{j<L} d_j)ᵀ S (·)`
7188            //     → the fully-coupled `(11ᵀ) ⊗ S` block.
7189            // Summed at equal λ these `L` blocks recover the old `(I + 11ᵀ) ⊗ S`
7190            // exactly (`Σ_k e_k e_kᵀ = I`), so this is a strict generalization:
7191            // the pooled fit is still reachable, REML only GAINS the freedom to
7192            // spend curvature per group. The zero-sum reparameterization (hence
7193            // the `sz` vs `fs` identifiability) is untouched.
7194            //
7195            // `which_level ∈ 0..=l_minus_one`: `< l_minus_one` selects the single
7196            // free deviation block; `== l_minus_one` selects the reference-level
7197            // coupling block.
7198            let stz_per_group_penalty = |s_inner: &Array2<f64>, which_level: usize| -> Array2<f64> {
7199                let mut s_big = Array2::<f64>::zeros((p * l_minus_one, p * l_minus_one));
7200                if which_level < l_minus_one {
7201                    // (e_k e_kᵀ) ⊗ S: a single diagonal block.
7202                    let k = which_level;
7203                    let mut block = s_big.slice_mut(s![k * p..(k + 1) * p, k * p..(k + 1) * p]);
7204                    block.assign(s_inner);
7205                } else {
7206                    // (11ᵀ) ⊗ S: every block (diagonal and off-diagonal) is S.
7207                    for a in 0..l_minus_one {
7208                        for b in 0..l_minus_one {
7209                            let mut block =
7210                                s_big.slice_mut(s![a * p..(a + 1) * p, b * p..(b + 1) * p]);
7211                            block.assign(s_inner);
7212                        }
7213                    }
7214                }
7215                s_big
7216            };
7217            // One nullspace-dim entry per emitted penalty (must stay parallel to
7218            // `penalties`). Each per-group wiggliness block carries the marginal's
7219            // OWN nullity (a rank-`p` block touching a single level for the free
7220            // blocks; the coupling block is rank-`p` over the diagonal sum), and
7221            // the null ridges below record their own nullity.
7222            let mut nullspaces = Vec::<usize>::with_capacity(penalties.capacity());
7223            for (penalty_pos, s_inner) in inner_built.penalties.iter().enumerate() {
7224                let info_idx = active_penalty_indices[penalty_pos];
7225                let base_info = inner_built.penaltyinfo[info_idx].clone();
7226                let marginal_nullity = inner_built.nullspaces.get(penalty_pos).copied().unwrap_or(0);
7227                // Emit `L` independent per-level blocks for this marginal penalty.
7228                for which_level in 0..=l_minus_one {
7229                    let raw = stz_per_group_penalty(s_inner, which_level);
7230                    let (s_big, group_scale) = normalize_penalty_in_constrained_space(&raw);
7231                    let block = crate::basis::analyze_penalty_block_with_op(&s_big, None)?;
7232                    if block.rank == 0 {
7233                        continue;
7234                    }
7235                    if which_level == 0 {
7236                        // Reuse the marginal's own info slot for the first block so
7237                        // the existing normalization bookkeeping stays attached.
7238                        inner_built.penaltyinfo[info_idx].normalization_scale *= group_scale;
7239                        inner_built.penaltyinfo[info_idx].original_index = penalties.len();
7240                        inner_built.penaltyinfo[info_idx].effective_rank = block.rank;
7241                        inner_built.penaltyinfo[info_idx].nullspace_dim_hint = block.nullity;
7242                    } else {
7243                        let mut info = base_info.clone();
7244                        info.original_index = penalties.len();
7245                        info.normalization_scale = base_info.normalization_scale * group_scale;
7246                        info.effective_rank = block.rank;
7247                        info.nullspace_dim_hint = block.nullity;
7248                        info.kronecker_factors = None;
7249                        inner_built.penaltyinfo.push(info);
7250                    }
7251                    penalties.push(block.sym_penalty);
7252                    // The coupling block (which_level == l_minus_one) spans the
7253                    // marginal range on the diagonal-sum direction; the free
7254                    // blocks touch one level. Both leave the marginal null space
7255                    // unpenalized, recorded here so the null ridges below complete
7256                    // the double penalty.
7257                    nullspaces.push(marginal_nullity);
7258                }
7259            }
7260
7261            // Null-space ridge, mirroring the `bs="fs"` double-penalty
7262            // construction (#1605, same defect class as #700/#712/#713). The
7263            // marginal wiggliness penalty `S` shapes curvature but leaves the
7264            // {const, linear} null space of each deviation curve COMPLETELY
7265            // unpenalized. With that null space free, the single combined
7266            // wiggliness smoothing parameter cannot separate the per-group
7267            // intercept/slope variance from the curvature variance, so REML
7268            // parks the wiggliness `λ` high — over-smoothing (under-fitting) the
7269            // deviation blocks even when the truth lives in their span (the `sz`
7270            // recovery gap vs the `fs` superset). mgcv's `bs="fs"` fixes the
7271            // analogous gap by penalizing each null-space dimension SEPARATELY
7272            // under its own shared variance; we mirror that here while keeping
7273            // the zero-sum reparameterization, so the constraint (and the
7274            // identifiability of `sz` vs `fs`) is preserved. For each orthonormal
7275            // null direction `z_k` of the marginal penalty, add the rank-1
7276            // marginal penalty `z_k z_kᵀ` mapped into the SAME `(I + 11ᵀ)`
7277            // sum-to-zero contrast space, each carrying its own `λ`.
7278            if let Some(Some(z)) = inner_null_eigenvectors.first()
7279                && z.nrows() == p
7280            {
7281                for k in 0..z.ncols() {
7282                    let zk = z.column(k);
7283                    let mut p_k = Array2::<f64>::zeros((p, p));
7284                    for a in 0..p {
7285                        for b in 0..p {
7286                            p_k[[a, b]] = zk[a] * zk[b];
7287                        }
7288                    }
7289                    // Null ridges stay POOLED (the `(I + 11ᵀ) ⊗ z_k z_kᵀ` form):
7290                    // they govern the per-group intercept/slope shrinkage, which
7291                    // mgcv pools under one variance even for `sz`; only the
7292                    // curvature (wiggliness) penalty is split per group above.
7293                    let stz_pooled_null = {
7294                        let mut s_big = Array2::<f64>::zeros((p * l_minus_one, p * l_minus_one));
7295                        for a in 0..l_minus_one {
7296                            for b in 0..l_minus_one {
7297                                let factor = if a == b { 2.0 } else { 1.0 };
7298                                let mut block =
7299                                    s_big.slice_mut(s![a * p..(a + 1) * p, b * p..(b + 1) * p]);
7300                                block.assign(&p_k.mapv(|v| v * factor));
7301                            }
7302                        }
7303                        s_big
7304                    };
7305                    let (s_null, null_scale) =
7306                        normalize_penalty_in_constrained_space(&stz_pooled_null);
7307                    let null_block = crate::basis::analyze_penalty_block_with_op(&s_null, None)?;
7308                    if null_block.rank > 0 {
7309                        let original_index = penalties.len();
7310                        penalties.push(null_block.sym_penalty);
7311                        nullspaces.push(null_block.nullity);
7312                        inner_built.penaltyinfo.push(PenaltyInfo {
7313                            source: PenaltySource::Primary,
7314                            original_index,
7315                            active: true,
7316                            effective_rank: null_block.rank,
7317                            dropped_reason: None,
7318                            nullspace_dim_hint: null_block.nullity,
7319                            normalization_scale: null_scale,
7320                            kronecker_factors: None,
7321                        });
7322                    }
7323                }
7324            }
7325            inner_built.dim = p * l_minus_one;
7326            inner_built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense));
7327            inner_built.penalties = penalties;
7328            inner_built.ops = vec![None; inner_built.penalties.len()];
7329            inner_built.nullspaces = nullspaces;
7330            // Invariant: `null_eigenvectors[k]` must mirror `penalties[k]`'s
7331            // spectral null space. We just rebuilt `inner_built.penalties` from
7332            // Kronecker-like `S_big` blocks, so the previously-plumbed
7333            // `null_eigenvectors` (still parallel to the OLD per-level penalty)
7334            // is stale. Recompute from the rebuilt penalties to restore the
7335            // invariant; ditto for the joint-null absorption rotation.
7336            inner_built.null_eigenvectors =
7337                crate::basis::recompute_null_eigenvectors(&inner_built.penalties)?;
7338            inner_built.joint_null_rotation =
7339                crate::basis::compute_joint_null_rotation(&inner_built.penalties)?;
7340            inner_built.kronecker_factored = None;
7341            return Ok(inner_built);
7342        }
7343        SmoothBasisSpec::BSpline1D { feature_col, spec } => {
7344            if *feature_col >= data.ncols() {
7345                crate::bail_dim_basis!(
7346                    "term '{}' feature column {} out of bounds for {} columns",
7347                    term.name,
7348                    feature_col,
7349                    data.ncols()
7350                );
7351            }
7352            let mut spec_local = spec.clone();
7353            if term.shape != ShapeConstraint::None {
7354                // Shape-constrained B-splines are anchored by construction.
7355                // Sum-to-zero side constraints conflict with monotonic/convex cones.
7356                spec_local.identifiability = BSplineIdentifiability::None;
7357            }
7358            // Endpoint boundary conditions are structural for B-splines: the
7359            // basis builder bakes their homogeneous nullspace transform into
7360            // the design, penalties, and stored raw-basis transform.
7361            build_bspline_basis_1d(data.column(*feature_col), &spec_local)?
7362        }
7363        SmoothBasisSpec::ThinPlate {
7364            feature_cols,
7365            spec,
7366            input_scales,
7367        } => {
7368            if term.shape != ShapeConstraint::None {
7369                if feature_cols.len() != 1 {
7370                    crate::bail_invalid_basis!(
7371                        "ShapeConstraint::{:?} for term '{}' on ThinPlate basis requires exactly 1 feature axis; found {}",
7372                        term.shape,
7373                        term.name,
7374                        feature_cols.len()
7375                    );
7376                }
7377                shape_axis_col = Some(feature_cols[0]);
7378            }
7379            let mut x = select_columns(data, feature_cols)?;
7380            // Auto-standardize multivariate inputs: use stored scales (prediction)
7381            // or compute fresh ones (training). Same standardization-vs-
7382            // length-scale compensation as Matérn / hybrid Duchon: divide
7383            // the user's L by σ_geom so kernel(‖x_std − c_std‖/L_eff)
7384            // matches the original-coord kernel for uniform σ.
7385            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7386                apply_input_standardization(&mut x, s);
7387                (
7388                    Some(s.clone()),
7389                    compensate_length_scale_for_standardization(spec.length_scale, s),
7390                )
7391            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7392                apply_input_standardization(&mut x, &s);
7393                let l_eff = compensate_length_scale_for_standardization(spec.length_scale, &s);
7394                (Some(s), l_eff)
7395            } else {
7396                (None, spec.length_scale)
7397            };
7398            let mut spec_local = spec.clone();
7399            spec_local.length_scale = length_scale_eff;
7400            if matches!(
7401                spec_local.identifiability,
7402                SpatialIdentifiability::OrthogonalToParametric
7403            ) {
7404                spec_local.identifiability = SpatialIdentifiability::None;
7405            }
7406            let mut result = build_thin_plate_basis(x.view(), &spec_local).map_err(|err| {
7407                rewrite_thin_plate_knots_error(err, &term.name, feature_cols.len(), spec)
7408            })?;
7409            // Inject input scales into metadata; also restore the user's
7410            // original length_scale (not the σ_geom-compensated one) so a
7411            // metadata-driven rebuild that re-applies compensation does not
7412            // double-divide. The build may auto-promote to Duchon when
7413            // canonical TPS is infeasible (k < polynomial-nullspace size);
7414            // in that case patch the Duchon metadata variant so predict-time
7415            // round-trips through the same standardized data path.
7416            match &mut result.metadata {
7417                BasisMetadata::ThinPlate {
7418                    input_scales: ms,
7419                    length_scale,
7420                    ..
7421                } => {
7422                    *ms = scales;
7423                    *length_scale = spec.length_scale;
7424                }
7425                BasisMetadata::Duchon {
7426                    input_scales: ms,
7427                    length_scale,
7428                    ..
7429                } => {
7430                    // Auto-promotion (canonical TPS infeasible at this (d, k)).
7431                    // Since #1091 the promotion does NOT forward the incoming
7432                    // σ_geom-compensated `spec_local.length_scale` to the Duchon
7433                    // builder — it DISCARDS it and substitutes the geometric mean
7434                    // of the center pairwise distances (`promotion_length_scale`,
7435                    // the natural radial-kernel scale where κ·r ≈ O(1)). So the
7436                    // realized kernel bandwidth recorded in this metadata bears no
7437                    // fixed relation to the user's `spec.length_scale`; clobbering
7438                    // it to the user-facing value (the pre-#1091 behavior) makes
7439                    // freeze→replay re-derive `compensate(spec.length_scale, σ) ≠
7440                    // promotion_length_scale`, evaluating the kernel at the wrong
7441                    // bandwidth and corrupting the replayed design (#1091 broke the
7442                    // e7ff5ed83 freeze contract for the auto-promoted path).
7443                    //
7444                    // The freeze→replay round trip rebuilds through the Duchon arm,
7445                    // which re-applies σ_geom compensation:
7446                    //   replay_eff = compensate(metadata.length_scale, σ)
7447                    //              = metadata.length_scale / σ_geom.
7448                    // For replay_eff to reproduce the realized `promotion_length_scale`
7449                    // we must store the UN-compensated value `promotion_length_scale
7450                    // · σ_geom`. `compensate(1.0, σ) = 1/σ_geom`, so divide the
7451                    // realized scale by it to multiply back through σ_geom. With no
7452                    // standardization (`scales == None`) replay does not compensate,
7453                    // so the realized value is kept verbatim.
7454                    if let (Some(s), Some(realized)) = (scales.as_ref(), *length_scale) {
7455                        let inv_sigma_geom =
7456                            compensate_length_scale_for_standardization(1.0, s);
7457                        if inv_sigma_geom.is_finite() && inv_sigma_geom > 0.0 {
7458                            *length_scale = Some(realized / inv_sigma_geom);
7459                        }
7460                    }
7461                    *ms = scales;
7462                }
7463                _ => {}
7464            }
7465            result
7466        }
7467        SmoothBasisSpec::Sphere { feature_cols, spec } => {
7468            if term.shape != ShapeConstraint::None {
7469                crate::bail_invalid_basis!(
7470                    "ShapeConstraint::{:?} for term '{}' is not supported on spherical splines",
7471                    term.shape,
7472                    term.name
7473                );
7474            }
7475            let x = select_columns(data, feature_cols)?;
7476            build_spherical_spline_basis(x.view(), spec)?
7477        }
7478        SmoothBasisSpec::ConstantCurvature { feature_cols, spec } => {
7479            if term.shape != ShapeConstraint::None {
7480                crate::bail_invalid_basis!(
7481                    "ShapeConstraint::{:?} for term '{}' is not supported on constant-curvature smooths",
7482                    term.shape,
7483                    term.name
7484                );
7485            }
7486            // Chart coordinates are consumed verbatim: NO auto-standardization.
7487            // Rescaling axes would change the chart gauge `1 + κ‖x‖²` and
7488            // silently redefine which curvature κ refers to (the same point
7489            // cloud at a different chart scale has a different κ̂); the user's
7490            // coordinates ARE the geometry here, exactly as for the sphere
7491            // smooth's (lat, lon).
7492            let x = select_columns(data, feature_cols)?;
7493            build_constant_curvature_basis(x.view(), spec)?
7494        }
7495        SmoothBasisSpec::MeasureJet {
7496            feature_cols,
7497            spec,
7498            input_scales,
7499        } => {
7500            if term.shape != ShapeConstraint::None {
7501                crate::bail_invalid_basis!(
7502                    "ShapeConstraint::{:?} for term '{}' is not supported on measure-jet smooths",
7503                    term.shape,
7504                    term.name
7505                );
7506            }
7507            let mut x = select_columns(data, feature_cols)?;
7508            // Matern-style per-axis standardization; the realized σ vector is
7509            // persisted into the metadata for predict-time replay.
7510            //
7511            // Length-scale round-trip contract (owning statement; the freeze
7512            // and frozen-validation arms reference it): `input_scales: Some`
7513            // marks the REPLAY path — the frozen length_scale is already the
7514            // realized post-standardization value and passes through
7515            // verbatim. Fresh path: an explicit user length_scale is in
7516            // ORIGINAL coordinates and gets the σ_geom compensation; the 0.0
7517            // auto sentinel passes through (auto-derivation runs inside the
7518            // builder, post-standardization).
7519            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7520                apply_input_standardization(&mut x, s);
7521                (Some(s.clone()), spec.length_scale)
7522            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7523                apply_input_standardization(&mut x, &s);
7524                let l_eff = if spec.length_scale > 0.0 {
7525                    compensate_length_scale_for_standardization(spec.length_scale, &s)
7526                } else {
7527                    spec.length_scale
7528                };
7529                (Some(s), l_eff)
7530            } else {
7531                (None, spec.length_scale)
7532            };
7533            let mut spec_local = spec.clone();
7534            spec_local.length_scale = length_scale_eff;
7535            let mut result = build_measure_jet_basis(x.view(), &spec_local)?;
7536            if let BasisMetadata::MeasureJet {
7537                input_scales: ms, ..
7538            } = &mut result.metadata
7539            {
7540                *ms = scales;
7541            }
7542            result
7543        }
7544        SmoothBasisSpec::Matern {
7545            feature_cols,
7546            spec,
7547            input_scales,
7548        } => {
7549            if term.shape != ShapeConstraint::None {
7550                if feature_cols.len() != 1 {
7551                    crate::bail_invalid_basis!(
7552                        "ShapeConstraint::{:?} for term '{}' on Matern basis requires exactly 1 feature axis; found {}",
7553                        term.shape,
7554                        term.name,
7555                        feature_cols.len()
7556                    );
7557                }
7558                shape_axis_col = Some(feature_cols[0]);
7559            }
7560            let mut x = select_columns(data, feature_cols)?;
7561            // Auto-standardization (per-axis division by σ_a) reinterprets
7562            // the user's `length_scale` from original data coordinates
7563            // into post-standardization coordinates: for uniform σ_a = σ,
7564            // `kernel(‖x_std − c_std‖/L)` equals `kernel(‖x − c‖/(σ·L))`,
7565            // so the effective kernel range shrinks by σ. To keep
7566            // `length_scale` consistently expressed in *original* data
7567            // coordinates regardless of axis variances, we standardize
7568            // and divide L by σ_geom = (∏σ_a)^(1/d). For uniform σ this
7569            // recovers the user's kernel exactly; for anisotropic data
7570            // the resulting per-axis effective scales σ_a / σ_geom are
7571            // the standard Mahalanobis preconditioning and preserve the
7572            // geometric-mean kernel range. Storing the σ vector in
7573            // metadata.input_scales makes the same transformation
7574            // replayable at predict time.
7575            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7576                apply_input_standardization(&mut x, s);
7577                (
7578                    Some(s.clone()),
7579                    compensate_length_scale_for_standardization(spec.length_scale, s),
7580                )
7581            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7582                apply_input_standardization(&mut x, &s);
7583                let l_eff = compensate_length_scale_for_standardization(spec.length_scale, &s);
7584                (Some(s), l_eff)
7585            } else {
7586                (None, spec.length_scale)
7587            };
7588            let mut spec_local = spec.clone();
7589            spec_local.length_scale = length_scale_eff;
7590            let mut result = build_matern_basiswithworkspace(x.view(), &spec_local, workspace)?;
7591            if let BasisMetadata::Matern {
7592                input_scales,
7593                length_scale,
7594                ..
7595            } = &mut result.metadata
7596            {
7597                *input_scales = scales;
7598                *length_scale = spec.length_scale;
7599            }
7600            result
7601        }
7602        SmoothBasisSpec::Duchon {
7603            feature_cols,
7604            spec,
7605            input_scales,
7606        } => {
7607            if term.shape != ShapeConstraint::None {
7608                if feature_cols.len() != 1 {
7609                    crate::bail_invalid_basis!(
7610                        "ShapeConstraint::{:?} for term '{}' on Duchon basis requires exactly 1 feature axis; found {}",
7611                        term.shape,
7612                        term.name,
7613                        feature_cols.len()
7614                    );
7615                }
7616                shape_axis_col = Some(feature_cols[0]);
7617            }
7618            let mut x = select_columns(data, feature_cols)?;
7619            // Hybrid Duchon (length_scale=Some) is governed by the same
7620            // standardization-vs-length-scale equivalence as Matérn: the
7621            // user's `length_scale` is interpreted in original data
7622            // coordinates, but auto-standardization (per-axis division by
7623            // σ_a) reinterprets it as σ_geom · L. Pre-multiply by 1/σ_geom
7624            // so kernel(‖x_std − c_std‖/L_eff) reproduces the user's
7625            // original-coord kernel exactly for uniform σ_a, and reduces
7626            // to standard Mahalanobis preconditioning for anisotropic σ.
7627            // Pure Duchon (length_scale=None) is scale-free and needs no
7628            // compensation.
7629            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7630                apply_input_standardization(&mut x, s);
7631                (
7632                    Some(s.clone()),
7633                    compensate_optional_length_scale_for_standardization(spec.length_scale, s),
7634                )
7635            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7636                apply_input_standardization(&mut x, &s);
7637                let l_eff =
7638                    compensate_optional_length_scale_for_standardization(spec.length_scale, &s);
7639                (Some(s), l_eff)
7640            } else {
7641                (None, spec.length_scale)
7642            };
7643            let mut spec_local = spec.clone();
7644            spec_local.length_scale = length_scale_eff;
7645            // The Duchon input axis is standardized in place above (`x → x/σ`,
7646            // scale-only, no centering). A 1-D cyclic boundary `[start, end)`
7647            // declared in ORIGINAL covariate units must move into that same
7648            // standardized frame, or the modular wrap in
7649            // `build_cyclic_duchon_basis_1dwithworkspace` folds the standardized
7650            // coordinate against an original-unit period: the seam never closes
7651            // and the basis silently degrades to non-periodic (#1074:
7652            // `duchon(x, periodic=true)` predictions diverged across the wrap,
7653            // f(0) ≠ f(2π)). Rescale by the same 1/σ applied to the data so
7654            // training and predict share one periodic geometry.
7655            if let (Some(s), crate::basis::OneDimensionalBoundary::Cyclic { start, end }) =
7656                (scales.as_ref(), spec_local.boundary.clone())
7657                && s.len() == 1
7658                && s[0] > 0.0
7659            {
7660                spec_local.boundary = crate::basis::OneDimensionalBoundary::Cyclic {
7661                    start: start / s[0],
7662                    end: end / s[0],
7663                };
7664            }
7665            if matches!(
7666                spec_local.identifiability,
7667                SpatialIdentifiability::OrthogonalToParametric
7668            ) {
7669                spec_local.identifiability = SpatialIdentifiability::None;
7670            }
7671            let mut result = build_duchon_basiswithworkspace(x.view(), &spec_local, workspace)?;
7672            if let BasisMetadata::Duchon {
7673                input_scales,
7674                length_scale,
7675                ..
7676            } = &mut result.metadata
7677            {
7678                *input_scales = scales;
7679                *length_scale = spec.length_scale;
7680            }
7681            result
7682        }
7683        SmoothBasisSpec::Pca {
7684            feature_cols,
7685            basis_matrix,
7686            centered,
7687            smooth_penalty,
7688            center_mean,
7689            pca_basis_path,
7690            chunk_size,
7691        } => {
7692            if term.shape != ShapeConstraint::None {
7693                crate::bail_invalid_basis!(
7694                    "ShapeConstraint::{:?} for term '{}' is not supported on Pca basis",
7695                    term.shape,
7696                    term.name
7697                );
7698            }
7699            build_pca_smooth_basis(
7700                data,
7701                feature_cols,
7702                basis_matrix,
7703                *centered,
7704                *smooth_penalty,
7705                center_mean.as_ref(),
7706                pca_basis_path.as_ref(),
7707                *chunk_size,
7708            )?
7709        }
7710        SmoothBasisSpec::TensorBSpline { feature_cols, spec } => {
7711            build_tensor_bspline_basis(data, feature_cols, spec)?
7712        }
7713        SmoothBasisSpec::ByVariable { .. } => {
7714            crate::bail_invalid_basis!(
7715                "internal: ByVariable smooths must return before inner basis dispatch"
7716            );
7717        }
7718        SmoothBasisSpec::BySmooth { .. } => {
7719            crate::bail_invalid_basis!("internal: BySmooth smooths must be lowered to ByVariable before inner basis dispatch"
7720                    .to_string(),);
7721        }
7722        SmoothBasisSpec::FactorSmooth { spec } => {
7723            if term.shape != ShapeConstraint::None {
7724                crate::bail_invalid_basis!(
7725                    "ShapeConstraint::{:?} is unsupported for factor smooth term '{}'",
7726                    term.shape,
7727                    term.name
7728                );
7729            }
7730            return build_factor_smooth(data, spec, &term.name, workspace);
7731        }
7732    };
7733
7734    // The Matérn design ALWAYS uses the operator-collocation {mass, tension,
7735    // stiffness} penalty triplet, overriding whatever penalty
7736    // `build_matern_basis_seeded` produced for the `double_penalty` flag.
7737    //
7738    // #1074 investigated swapping this for the genuine RKHS kernel penalty
7739    // `β' K_CC β` (mgcv `bs="gp"` / fields kriging) on the theory that the
7740    // operator triplet under-smooths the rougher half-integer kernels. MSI
7741    // truth-recovery measurement REFUTED that: the kernel penalty did NOT
7742    // improve ν=3/2 recovery (`matern(x,nu=1.5)` RMSE-vs-truth stayed 0.0554)
7743    // and it REGRESSED the high-frequency-init guard — `matern(x,nu≥5/2)` on
7744    // sin(2π·8·x) collapsed (span 0.53, RMSE 0.70) because the single RKHS
7745    // norm over-smooths a high-frequency truth where the Sobolev-order operator
7746    // dials do not. The operator triplet is therefore retained as the Matérn
7747    // penalty, and the κ-optimizer re-key / ψ-derivative paths route through the
7748    // same triplet builder so the block count stays ψ-stable (#1270).
7749    if let SmoothBasisSpec::Matern { .. } = &term.basis {
7750        let (penalties, nullspace_dims, penaltyinfo) =
7751            matern_operator_penalty_triplet_from_metadata(&built.metadata)?;
7752        built.penalties = penalties;
7753        built.nullspace_dims = nullspace_dims;
7754        built.penaltyinfo = penaltyinfo;
7755    }
7756
7757    let p_local = built.design.ncols();
7758    let mut metadata = built.metadata.clone();
7759    // Extract factored Kronecker representation before consuming fields.
7760    // Invalidate it if shape transforms will be applied (they break structure).
7761    let kron_factored = if term.shape == ShapeConstraint::None {
7762        built.kronecker_factored
7763    } else {
7764        None
7765    };
7766    let mut design_t = built.design;
7767    let mut penalties_t: Vec<Array2<f64>> = built.penalties;
7768    // Ops vector parallel to `penalties_t`. Survives unchanged through the
7769    // identity path; nulled element-wise when `T^T S T` reparametrization
7770    // is applied (operator no longer bit-equivalent to the transformed
7771    // matrix); wrapped in `ScaledPenaltyOp` after Frobenius normalization.
7772    let mut ops_t: Vec<Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>> =
7773        built.ops;
7774    if matches!(
7775        spatial_identifiability_policy(term),
7776        Some(SpatialIdentifiability::OrthogonalToParametric)
7777    ) {
7778        metadata = freeze_raw_spatial_metadata(metadata, design_t.ncols());
7779    }
7780
7781    let active_penaltyinfo_t = built
7782        .penaltyinfo
7783        .iter()
7784        .filter(|info| info.active)
7785        .cloned()
7786        .collect::<Vec<_>>();
7787    let pre_dropped_penaltyinfo_t = built
7788        .penaltyinfo
7789        .iter()
7790        .filter(|info| !info.active)
7791        .cloned()
7792        .collect::<Vec<_>>();
7793    let use_box_reparam =
7794        term.shape != ShapeConstraint::None && shape_uses_box_reparameterization(&term.basis);
7795    if let Some((order, sign)) = shape_order_and_sign(term.shape)
7796        && use_box_reparam
7797    {
7798        // Order 1 (monotone): the plain first-difference cone θ_{i+1}−θ_i ≥ 0 is
7799        // the control-polygon monotonicity criterion, which is independent of
7800        // Greville-abscissa spacing (it only fixes the *sign* of consecutive
7801        // control-point gaps), so the integer-difference transform is exact.
7802        //
7803        // Order 2 (convex/concave): the plain second-difference cone is only
7804        // correct for evenly spaced Greville abscissae. gam's B-splines are
7805        // clamped (and may use quantile knots), so the abscissae are not
7806        // uniform and the geometrically-correct cone is the second *divided*
7807        // difference. Build the Greville-scaled transform so γ_{≥2} ≥ 0
7808        // certifies convexity of the function, not of the raw coefficient
7809        // index. Periodic B-splines use uniform interior knots (uniform
7810        // abscissae), where the divided differences coincide with the integer
7811        // differences up to scale, so the plain path stays exact there.
7812        let t = if order == 2 {
7813            let bspline_meta = match &metadata {
7814                BasisMetadata::BSpline1D {
7815                    knots,
7816                    degree,
7817                    periodic,
7818                    ..
7819                } if periodic.is_none() => Some((knots.clone(), degree.unwrap_or(0))),
7820                _ => None,
7821            };
7822            match bspline_meta {
7823                Some((knots, degree)) if degree >= 1 => {
7824                    let greville = crate::basis::compute_greville_abscissae(&knots, degree)?;
7825                    if greville.len() != p_local {
7826                        crate::bail_invalid_basis!(
7827                            "shape-constraint Greville abscissae count {} does not match basis dim {} for term '{}'",
7828                            greville.len(),
7829                            p_local,
7830                            term.name
7831                        );
7832                    }
7833                    convex_divided_difference_transform_matrix(&greville, sign)?
7834                }
7835                _ => cumulative_sum_transform_matrix(p_local, order, sign),
7836            }
7837        } else {
7838            cumulative_sum_transform_matrix(p_local, order, sign)
7839        };
7840        // Coefficient-side transform: wrap the design in an operator that
7841        // applies T on the coefficient side, preserving sparsity/operator
7842        // structure of the inner design.
7843        let inner_dense = match design_t {
7844            DesignMatrix::Dense(d) => d,
7845            DesignMatrix::Sparse(sp) => gam_linalg::matrix::DenseDesignMatrix::from(
7846                sp.try_to_dense_arc("shape-constrained coefficient transform")
7847                    .map_err(BasisError::InvalidInput)?,
7848            ),
7849        };
7850        let coeff_op = gam_linalg::matrix::CoefficientTransformOperator::new(inner_dense, t.clone())
7851            .map_err(|e| BasisError::InvalidInput(format!("CoefficientTransformOperator: {e}")))?;
7852        design_t = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(coeff_op)));
7853        if penalties_t.len() != active_penaltyinfo_t.len() {
7854            crate::bail_invalid_basis!(
7855                "internal box-reparam penalty/info mismatch for term '{}': penalties={}, infos={}",
7856                term.name,
7857                penalties_t.len(),
7858                active_penaltyinfo_t.len()
7859            );
7860        }
7861        // Wiggliness penalties undergo the exact congruence `S → TᵀST` (PSD
7862        // preserving). The double-penalty *nullspace shrinkage* ridge must NOT:
7863        // it is a unit-eigenvalue projector `ZZᵀ` onto null(S_wiggle) in the
7864        // β (B-spline coefficient) coordinates, and the congruence
7865        // `Tᵀ(ZZᵀ)T = (TᵀZ)(TᵀZ)ᵀ` is no longer a projector — its eigenvalues
7866        // blow up by the conditioning of the cumulative-sum `T` (cond(T) grows
7867        // with the basis dim), concentrating an enormous penalty on the leading
7868        // γ₀ "level" coordinate. REML then drives the shared λ to its ceiling
7869        // and the smooth collapses to a flat constant (#509, the over-smoothing
7870        // face). The principled fix keeps mgcv's double-penalty semantics in the
7871        // *reparametrized* space: rebuild the ridge as the unit-eigenvalue
7872        // nullspace projector of the transformed wiggliness penalty `TᵀST`, so
7873        // the double penalty shrinks exactly the unpenalized polynomial
7874        // directions of the γ-space smooth with eigenvalue 1, identical in
7875        // conditioning to the unconstrained fit.
7876        let transformed_wiggliness = penalties_t
7877            .iter()
7878            .zip(active_penaltyinfo_t.iter())
7879            .find(|(_, info)| !matches!(info.source, PenaltySource::DoublePenaltyNullspace))
7880            .map(|(s_local, _)| {
7881                let tt_s = fast_atb(&t, s_local);
7882                fast_ab(&tt_s, &t)
7883            });
7884        let mut rebuilt = Vec::with_capacity(penalties_t.len());
7885        for (s_local, info) in penalties_t.iter().zip(active_penaltyinfo_t.iter()) {
7886            if matches!(info.source, PenaltySource::DoublePenaltyNullspace) {
7887                // #1654: the double-penalty nullspace ridge under the box
7888                // reparameterization.
7889                //
7890                // For the CURVATURE constraints (order == 2, convex/concave) the
7891                // box transform `T` is the Greville-scaled second *divided*
7892                // difference map (`convex_divided_difference_transform_matrix`).
7893                // Rebuilding the ridge from scratch as the orthonormal null-space
7894                // projector of `TᵀST` (the #509-era monotone fix below) yields a
7895                // γ-space ridge `Z_γ Z_γᵀ` whose null subspace is the affine
7896                // (level γ₀ + slope γ₁) face — the SAME subspace targeted in
7897                // β-space, but measured in the γ inner product rather than the
7898                // β one. A reparameterization `β = Tγ` must leave the penalized
7899                // REML fit invariant, which requires every penalty block to
7900                // transform by the SAME congruence `S ↦ TᵀST`; the from-scratch
7901                // projector rebuild is NOT that congruence, so it silently
7902                // re-weights the level/slope shrinkage relative to the wiggliness
7903                // penalty (each block is independently Frobenius-normalized just
7904                // below, decoupling their scales). The distorted REML λ landscape
7905                // then drives the convex/concave smooth into the flat linear
7906                // corner (curvature pinned ≈ 0, EDF ≈ 1.5) for a
7907                // seed/basis-dimension–specific subset of fits, even though an
7908                // unconstrained `s(x)` on the same data recovers the convex truth
7909                // at EDF ≈ 4. Restoring the exact congruence `Tᵀ R_β T` for the
7910                // ridge keeps the box reparameterization a true invertible
7911                // change of coordinates, so the curvature-constrained fit tracks
7912                // the unconstrained smoothing instead of over-smoothing
7913                // (verified: seed-7/k-20 truth-RMSE 0.31 → 0.045).
7914                //
7915                // For MONOTONE (order == 1) `T` is the cumulative-sum transform
7916                // whose conditioning grows fast with the basis dimension; there
7917                // the congruence concentrates an enormous penalty on the leading
7918                // γ₀ level coordinate and over-smooths to a flat constant (#509),
7919                // which the from-scratch unit-eigenvalue projector rebuild was
7920                // introduced to cure. Keep that path for monotone.
7921                if order == 2 {
7922                    let tt_s = fast_atb(&t, s_local);
7923                    rebuilt.push(fast_ab(&tt_s, &t));
7924                } else {
7925                    let s_wiggle_t = transformed_wiggliness.as_ref().ok_or_else(|| {
7926                        BasisError::InvalidInput(format!(
7927                            "box-reparam term '{}' has a double-penalty ridge but no primary wiggliness penalty to derive its nullspace from",
7928                            term.name
7929                        ))
7930                    })?;
7931                    let ridge = crate::basis::build_nullspace_shrinkage_penalty(s_wiggle_t)?
7932                        .map(|shrink| shrink.sym_penalty)
7933                        .unwrap_or_else(|| Array2::<f64>::zeros((p_local, p_local)));
7934                    rebuilt.push(ridge);
7935                }
7936            } else {
7937                let tt_s = fast_atb(&t, s_local);
7938                rebuilt.push(fast_ab(&tt_s, &t));
7939            }
7940        }
7941        penalties_t = rebuilt;
7942        // T^T S T (and the rebuilt γ-space ridge) invalidate op-form
7943        // bit-equivalence; drop ops here.
7944        ops_t = vec![None; penalties_t.len()];
7945    }
7946    if penalties_t.len() != active_penaltyinfo_t.len() {
7947        crate::bail_invalid_basis!(
7948            "internal penalty metadata mismatch for term '{}': active penalties={}, active infos={}",
7949            term.name,
7950            penalties_t.len(),
7951            active_penaltyinfo_t.len()
7952        );
7953    }
7954    if ops_t.len() != penalties_t.len() {
7955        ops_t = vec![None; penalties_t.len()];
7956    }
7957    let penalty_candidates = penalties_t
7958        .into_iter()
7959        .zip(active_penaltyinfo_t.into_iter())
7960        .zip(ops_t.into_iter())
7961        .map(
7962            |((matrix, info), op_in)| -> Result<PenaltyCandidate, BasisError> {
7963                let (matrix, c_new) = normalize_penalty_in_constrained_space(&matrix);
7964                let normalization_scale = info.normalization_scale * c_new;
7965                let op_scale = 1.0 / c_new;
7966                let kronecker_scale = 1.0 / c_new;
7967                // Frobenius rescale: wrap inner op in `ScaledPenaltyOp(1/c_new)`
7968                // so `op.as_dense() == matrix` post-normalization.
7969                let scaled_op = if op_scale > 0.0 && op_scale.is_finite() {
7970                    op_in.map(|op| {
7971                        std::sync::Arc::new(crate::analytic_penalties::ScaledPenaltyOp::new(
7972                            op, op_scale,
7973                        ))
7974                            as std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>
7975                    })
7976                } else {
7977                    None
7978                };
7979                let kronecker_factors = info.kronecker_factors.map(|mut factors| {
7980                    if let Some(first) = factors.first_mut() {
7981                        first.mapv_inplace(|v| v * kronecker_scale);
7982                    }
7983                    factors
7984                });
7985                Ok(PenaltyCandidate {
7986                    nullspace_dim_hint: info.nullspace_dim_hint,
7987                    matrix,
7988                    source: info.source,
7989                    normalization_scale,
7990                    kronecker_factors,
7991                    op: scaled_op,
7992                })
7993            },
7994        )
7995        .collect::<Result<Vec<_>, _>>()?;
7996    let (penalties_t, nullspaces_t, penaltyinfo_t, null_eigenvectors_t, ops_t) =
7997        crate::basis::filter_active_penalty_candidates_with_ops(penalty_candidates)?;
7998    let shape_linear_constraints = if term.shape != ShapeConstraint::None && !use_box_reparam {
7999        let axis = shape_axis_col.ok_or_else(|| {
8000            BasisError::InvalidInput(format!(
8001                "internal shape-constraint axis missing for term '{}'",
8002                term.name
8003            ))
8004        })?;
8005        let (x_shape_eval, design_shape_eval) =
8006            build_shape_constraint_design_1d(data, term, &metadata, axis)?;
8007        build_shape_linear_constraints_1d(
8008            x_shape_eval.view(),
8009            design_shape_eval.view(),
8010            term.shape,
8011        )?
8012    } else {
8013        None
8014    };
8015    let linear_constraints_local = merge_linear_constraints_global(shape_linear_constraints, None);
8016
8017    // Joint-null absorption rotation. Fresh fit specs compute Q from the final
8018    // per-smooth penalty set (after all in-smooth reparameterizations have
8019    // already been applied). Frozen specs already carry the complete realized
8020    // coefficient chart in their `FrozenTransform`; recomputing Q there would
8021    // rotate an already-frozen chart a second time and desynchronize value
8022    // rebuilds from derivative operators.
8023    //
8024    // Kronecker-factored smooths (tensor B-splines under `TensorBSplineIdentifiability::None`)
8025    // carry their joint penalty as `Σ_d S_d` with `S_d = I ⊗ … ⊗ S_d^{1D} ⊗ … ⊗ I`.
8026    // The joint null space is the tensor of marginal nulls and is handled directly
8027    // by the REML runtime's `kronecker_penalty_system` path (see
8028    // `runtime.rs:8334-8344`). Applying a dense (p × p) Q here would densify
8029    // `X_raw = mx ⊗ my` into `X_raw · Q`, destroying the Kronecker product
8030    // structure that the runtime relies on for fast log-det/derivative
8031    // assembly — and the rotation block at the wrapper site also unconditionally
8032    // wipes `kronecker_factored`, leaving the runtime to fall back to the
8033    // dense per-block log-det. Skip the rotation for Kronecker-factored terms
8034    // so the factored representation survives end-to-end.
8035    let joint_null_rotation = match term.joint_null_rotation.clone() {
8036        Some(persisted) => Some(persisted),
8037        None if smooth_has_frozen_identifiability(term) => None,
8038        None if kron_factored.is_some() => None,
8039        None => crate::basis::compute_joint_null_rotation(&penalties_t)?,
8040    };
8041
8042    Ok(LocalSmoothTermBuild {
8043        dim: p_local,
8044        design: design_t,
8045        penalties: penalties_t,
8046        ops: ops_t,
8047        nullspaces: nullspaces_t,
8048        null_eigenvectors: null_eigenvectors_t,
8049        joint_null_rotation,
8050        penaltyinfo: penaltyinfo_t,
8051        pre_dropped_penaltyinfo: pre_dropped_penaltyinfo_t,
8052        metadata,
8053        linear_constraints: linear_constraints_local,
8054        box_reparam: use_box_reparam,
8055        kronecker_factored: kron_factored,
8056    })
8057}
8058
8059pub fn build_smooth_design(
8060    data: ArrayView2<'_, f64>,
8061    terms: &[SmoothTermSpec],
8062) -> Result<RawSmoothDesign, BasisError> {
8063    let mut ws = crate::basis::BasisWorkspace::new();
8064    build_smooth_design_withworkspace(data, terms, &mut ws)
8065}
8066
8067/// Like `build_smooth_design`, but honors the caller workspace policy while
8068/// building each planned smooth term with an independent per-term workspace.
8069///
8070/// Independent workspaces avoid shared mutable distance-cache state during the
8071/// parallel term build; the final design, penalties, and metadata are assembled
8072/// in the original smooth-term order.
8073pub fn build_smooth_design_withworkspace(
8074    data: ArrayView2<'_, f64>,
8075    terms: &[SmoothTermSpec],
8076    workspace: &mut crate::basis::BasisWorkspace,
8077) -> Result<RawSmoothDesign, BasisError> {
8078    validate_smooth_terms_finite_inputs(data, terms)?;
8079    build_smooth_design_withworkspace_unvalidated(data, terms, workspace)
8080}
8081
8082pub fn build_smooth_design_withworkspace_unvalidated(
8083    data: ArrayView2<'_, f64>,
8084    terms: &[SmoothTermSpec],
8085    workspace: &mut crate::basis::BasisWorkspace,
8086) -> Result<RawSmoothDesign, BasisError> {
8087    let mut planned_blocks = plan_joint_spatial_centers_for_term_blocks(data, &[terms.to_vec()])?;
8088    let planned_terms = planned_blocks.pop().ok_or_else(|| {
8089        BasisError::InvalidInput(
8090            "joint spatial center planner returned no smooth blocks".to_string(),
8091        )
8092    })?;
8093    let policy = workspace.policy().clone();
8094    let local_builds: Vec<LocalSmoothTermBuild> = {
8095        use rayon::iter::{IntoParallelIterator, ParallelIterator};
8096        planned_terms
8097            .into_par_iter()
8098            .map(|term| {
8099                let mut term_workspace = crate::basis::BasisWorkspace::with_policy(policy.clone());
8100                build_single_local_smooth_term(data, &term, &mut term_workspace)
8101            })
8102            .collect::<Result<Vec<_>, _>>()?
8103    };
8104
8105    let total_p: usize = local_builds.iter().map(|built| built.dim).sum();
8106
8107    let mut local_designs: Vec<DesignMatrix> = Vec::with_capacity(local_builds.len());
8108    let mut terms_out = Vec::<SmoothTerm>::with_capacity(terms.len());
8109    let mut penalties_global = Vec::<BlockwisePenalty>::new();
8110    let mut nullspace_dims_global = Vec::<usize>::new();
8111    let mut penaltyinfo_global = Vec::<PenaltyBlockInfo>::new();
8112    let mut dropped_penaltyinfo_global = Vec::<DroppedPenaltyBlockInfo>::new();
8113    let mut coefficient_lower_bounds = Array1::<f64>::from_elem(total_p, f64::NEG_INFINITY);
8114    let mut any_bounds = false;
8115    // Each linear-constraint row only touches the current term's column slice.
8116    // Track `(col_start, col_end, local_row_values)` and assemble the final
8117    // dense `Array2` in one pass, avoiding per-row `Array1::zeros(total_p)`
8118    // allocation plus a row-by-row copy at the end.
8119    let mut linear_constraintsrows: Vec<(usize, usize, Array1<f64>)> = Vec::new();
8120    let mut linear_constraints_b: Vec<f64> = Vec::new();
8121
8122    let mut col_start = 0usize;
8123    for (term, mut built) in terms.iter().zip(local_builds.into_iter()) {
8124        let p_local = built.dim;
8125        let col_end = col_start + p_local;
8126        let lb_local = if built.box_reparam {
8127            shape_lower_bounds_local(term.shape, p_local)
8128        } else {
8129            None
8130        };
8131
8132        // Stage-2 joint-null absorption rotation. Fired *before* the
8133        // penalty / design / global aggregation loops below so that every
8134        // subsequent reference to `built.penalties`, `built.design`, and
8135        // `built.ops` sees the post-rotation values.
8136        //
8137        // The math: when the smooth's joint penalty `Σ_k S_k` has a
8138        // non-trivial null space, eigh selects `Q = [U_range | U_null]`
8139        // with null columns at the tail. Setting `β_raw = Q · γ` and
8140        // applying:
8141        //     design        ← X · Q
8142        //     penalties[k]  ← Qᵀ · S_k · Q   (block-diag, zero null tail)
8143        // yields a model whose fitted γ is invariant to the rotation
8144        // (since likelihood depends only on `X · β_raw = X · Q · γ`), but
8145        // whose penalty is full-rank on the range columns. The large-scale
8146        // failing case (cert refusal in the joint-Newton inner solve)
8147        // resolves because `H_pen = H_loglik + S` becomes full rank on
8148        // the smooth's range columns.
8149        //
8150        // Rotation is suppressed when the smooth carries coordinate-wise
8151        // shape constraints (`lb_local` or `built.linear_constraints`):
8152        // those encode a cone in the original coordinate system and a
8153        // general orthogonal rotation breaks the cone geometry. Smooths
8154        // with shape constraints typically have full-rank joint penalty
8155        // (their structural shape comes from the cone, not from null
8156        // directions in the penalty), so suppression is rarely a loss.
8157        //
8158        // `applied_rotation` carries the Q that was applied (or `None`
8159        // if no rotation fired). It is persisted onto `SmoothTerm` below
8160        // so prediction-side `X_new_raw · Q` replay can reproduce the
8161        // exact rotation. Persistence through the saved-model artifact
8162        // is a follow-up — see the doc on `SmoothTerm.joint_null_rotation`.
8163        let applied_rotation: Option<crate::basis::JointNullRotation> = match (
8164            built.joint_null_rotation.take(),
8165            lb_local.is_some(),
8166            built.linear_constraints.is_some(),
8167        ) {
8168            (Some(rot), false, false) => {
8169                let q = &rot.rotation;
8170                let dense = built
8171                    .design
8172                    .try_to_dense_by_chunks("joint-null absorption rotation")
8173                    .map_err(BasisError::InvalidInput)?;
8174                let rotated = gam_linalg::faer_ndarray::fast_ab(&dense, q);
8175                built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(rotated));
8176                built.penalties = built
8177                    .penalties
8178                    .into_iter()
8179                    .map(|s_local| {
8180                        let qt_s = gam_linalg::faer_ndarray::fast_atb(q, &s_local);
8181                        gam_linalg::faer_ndarray::fast_ab(&qt_s, q)
8182                    })
8183                    .collect();
8184                built.ops = vec![None; built.penalties.len()];
8185                built.kronecker_factored = None;
8186                Some(rot)
8187            }
8188            (Some(_), _, _) => None,
8189            (None, _, _) => None,
8190        };
8191
8192        let activeinfos = built
8193            .penaltyinfo
8194            .iter()
8195            .filter(|info| info.active)
8196            .collect::<Vec<_>>();
8197        if activeinfos.len() != built.penalties.len() {
8198            crate::bail_invalid_basis!(
8199                "internal penalty info mismatch for term '{}': activeinfos={}, penalties={}",
8200                term.name,
8201                activeinfos.len(),
8202                built.penalties.len()
8203            );
8204        }
8205        for (((s_local, &ns), info), op_local) in built
8206            .penalties
8207            .iter()
8208            .zip(built.nullspaces.iter())
8209            .zip(activeinfos.into_iter())
8210            .zip(built.ops.iter())
8211        {
8212            let global_index = penalties_global.len();
8213            penalties_global.push(
8214                BlockwisePenalty::new(col_start..col_end, s_local.clone())
8215                    .with_op(op_local.clone()),
8216            );
8217            nullspace_dims_global.push(ns);
8218            let mut penalty = info.clone();
8219            penalty.nullspace_dim_hint = ns;
8220            penaltyinfo_global.push(PenaltyBlockInfo {
8221                global_index,
8222                termname: Some(term.name.clone()),
8223                penalty,
8224            });
8225        }
8226        for info in built.penaltyinfo.iter().filter(|info| !info.active) {
8227            dropped_penaltyinfo_global.push(DroppedPenaltyBlockInfo {
8228                termname: Some(term.name.clone()),
8229                penalty: info.clone(),
8230            });
8231        }
8232        for info in &built.pre_dropped_penaltyinfo {
8233            dropped_penaltyinfo_global.push(DroppedPenaltyBlockInfo {
8234                termname: Some(term.name.clone()),
8235                penalty: info.clone(),
8236            });
8237        }
8238
8239        if let Some(lin_local) = &built.linear_constraints {
8240            for r in 0..lin_local.a.nrows() {
8241                linear_constraintsrows.push((col_start, col_end, lin_local.a.row(r).to_owned()));
8242                linear_constraints_b.push(lin_local.b[r]);
8243            }
8244        }
8245        if let Some(lb_local) = &lb_local {
8246            coefficient_lower_bounds
8247                .slice_mut(s![col_start..col_end])
8248                .assign(lb_local);
8249            any_bounds = true;
8250        }
8251
8252        // Move the per-term design out of `built` rather than cloning it.
8253        local_designs.push(built.design);
8254
8255        terms_out.push(SmoothTerm {
8256            name: term.name.clone(),
8257            coeff_range: col_start..col_end,
8258            shape: term.shape,
8259            penalties_local: built.penalties,
8260            nullspace_dims: built.nullspaces,
8261            penaltyinfo_local: built.penaltyinfo,
8262            metadata: built.metadata,
8263            lower_bounds_local: lb_local,
8264            linear_constraints_local: built.linear_constraints,
8265            kronecker_factored: built.kronecker_factored.take(),
8266            joint_null_rotation: applied_rotation,
8267            unabsorbed_global_orthogonality: None,
8268        });
8269
8270        col_start = col_end;
8271    }
8272
8273    assert_eq!(
8274        penalties_global.len(),
8275        nullspace_dims_global.len(),
8276        "global smooth penalty/nullspace bookkeeping diverged"
8277    );
8278    assert_eq!(
8279        penalties_global.len(),
8280        penaltyinfo_global.len(),
8281        "global smooth penalty metadata bookkeeping diverged"
8282    );
8283
8284    Ok(RawSmoothDesign {
8285        term_designs: local_designs,
8286        penalties: penalties_global,
8287        nullspace_dims: nullspace_dims_global,
8288        penaltyinfo: penaltyinfo_global,
8289        dropped_penaltyinfo: dropped_penaltyinfo_global,
8290        terms: terms_out,
8291        coefficient_lower_bounds: if any_bounds {
8292            Some(coefficient_lower_bounds)
8293        } else {
8294            None
8295        },
8296        linear_constraints: if linear_constraintsrows.is_empty() {
8297            None
8298        } else {
8299            let mut a = Array2::<f64>::zeros((linear_constraintsrows.len(), total_p));
8300            for (i, (cs, ce, values)) in linear_constraintsrows.iter().enumerate() {
8301                a.row_mut(i).slice_mut(s![*cs..*ce]).assign(values);
8302            }
8303            Some(LinearInequalityConstraints {
8304                a,
8305                b: Array1::from_vec(linear_constraints_b),
8306            })
8307        },
8308    })
8309}