Skip to main content

gam_terms/smooth/
term_specs.rs

1use coefficient_transforms::{
2    convex_divided_difference_transform_matrix, cumulative_exp, cumulative_sum_transform_matrix,
3    second_cumulative_exp,
4};
5
6pub use error::SmoothError;
7
8use input_standardization::{
9    apply_input_standardization, compensate_length_scale_for_standardization,
10    compensate_optional_length_scale_for_standardization, compute_spatial_input_scales,
11};
12
13use shape_constraints::{
14    build_shape_constraint_design_1d, build_shape_linear_constraints_1d,
15    merge_linear_constraints_global, shape_lower_bounds_local, shape_order_and_sign,
16    shape_supports_basis, shape_uses_box_reparameterization,
17};
18
19pub fn describe_thin_plate_center_request(strategy: &CenterStrategy) -> String {
20    match strategy {
21        CenterStrategy::Auto(inner) => describe_thin_plate_center_request(inner),
22        CenterStrategy::UserProvided(centers) => format!("{} centers", centers.nrows()),
23        CenterStrategy::EqualMass { num_centers }
24        | CenterStrategy::EqualMassCovarRepresentative { num_centers }
25        | CenterStrategy::FarthestPoint { num_centers }
26        | CenterStrategy::KMeans { num_centers, .. } => format!("{num_centers} centers"),
27        CenterStrategy::UniformGrid { points_per_dim } => {
28            format!("uniform grid with {points_per_dim} points per dimension")
29        }
30    }
31}
32
33pub fn rewrite_thin_plate_knots_error(
34    err: BasisError,
35    termname: &str,
36    feature_count: usize,
37    spec: &ThinPlateBasisSpec,
38) -> BasisError {
39    match err {
40        // Polynomial-nullspace shortfall reported directly by the kernel
41        // builder ("thin-plate spline requires at least N centers to span ...").
42        BasisError::InvalidInput(msg)
43            if msg.contains("thin-plate spline requires at least")
44                && (msg.contains("centers to span") || msg.contains("knots to span")) =>
45        {
46            let min_centers = crate::basis::thin_plate_polynomial_basis_dimension(feature_count);
47            let requested = describe_thin_plate_center_request(&spec.center_strategy);
48            BasisError::InvalidInput(format!(
49                "joint TPS term '{termname}' over {feature_count} covariates with {requested} is invalid; minimum centers is {min_centers}"
50            ))
51        }
52        // Insufficient-rows shortfall raised by `select_thin_plate_knots` when
53        // the requested center count exceeds the available row count. Rewrite
54        // it in term language so the diagnostic points at the smooth term and
55        // the polynomial-nullspace minimum the user needs to satisfy.
56        BasisError::InvalidInput(msg)
57            if msg.starts_with("requested ") && msg.contains(" knots but only ") =>
58        {
59            let min_centers = crate::basis::thin_plate_polynomial_basis_dimension(feature_count);
60            let requested = describe_thin_plate_center_request(&spec.center_strategy);
61            BasisError::InvalidInput(format!(
62                "joint TPS term '{termname}' over {feature_count} covariates with {requested} is invalid; minimum centers is {min_centers}"
63            ))
64        }
65        other => other,
66    }
67}
68
/// Shape restriction that can be attached to a smooth term.
///
/// The textual spellings are defined by [`parse_shape_constraint`] (input)
/// and [`ShapeConstraint::dsl_str`] (canonical output).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ShapeConstraint {
    /// No shape restriction; spelled `"none"` (the empty string also parses
    /// to this variant).
    None,
    /// Spelled `"monotone_increasing"`.
    MonotoneIncreasing,
    /// Spelled `"monotone_decreasing"`.
    MonotoneDecreasing,
    /// Spelled `"convex"`.
    Convex,
    /// Spelled `"concave"`.
    Concave,
}
77
78/// Parse a shape-constraint string into a [`ShapeConstraint`].
79///
80/// This is the single source of truth shared by the formula DSL
81/// (`s(x, shape=...)`) and the `smooths={...}` override path
82/// (`Smooth.shape_constraint`). The accepted spellings cover the canonical
83/// Python `ShapeConstraintLiteral` strings exactly
84/// (`"none"` / `"monotone_increasing"` / `"monotone_decreasing"` /
85/// `"convex"` / `"concave"`) plus a few common aliases. Hyphens and case are
86/// normalized, so `"Monotone-Increasing"` and `"mono_inc"` both resolve to
87/// [`ShapeConstraint::MonotoneIncreasing`].
88pub fn parse_shape_constraint(raw: &str) -> Result<ShapeConstraint, String> {
89    let normalized = raw.trim().to_ascii_lowercase().replace('-', "_");
90    match normalized.as_str() {
91        "" | "none" => Ok(ShapeConstraint::None),
92        "monotone_increasing" | "monotonic_increasing" | "increasing" | "mono_inc" | "mpi" => {
93            Ok(ShapeConstraint::MonotoneIncreasing)
94        }
95        "monotone_decreasing" | "monotonic_decreasing" | "decreasing" | "mono_dec" | "mpd" => {
96            Ok(ShapeConstraint::MonotoneDecreasing)
97        }
98        "convex" | "cvx" => Ok(ShapeConstraint::Convex),
99        "concave" | "ccv" => Ok(ShapeConstraint::Concave),
100        other => Err(format!(
101            "unknown shape constraint {other:?}; expected one of \
102             \"none\", \"monotone_increasing\", \"monotone_decreasing\", \
103             \"convex\", \"concave\""
104        )),
105    }
106}
107
108impl ShapeConstraint {
109    /// Canonical formula-DSL spelling, i.e. the text emitted into
110    /// `s(x, shape=...)`. Round-trips through [`parse_shape_constraint`].
111    pub fn dsl_str(&self) -> &'static str {
112        match self {
113            ShapeConstraint::None => "none",
114            ShapeConstraint::MonotoneIncreasing => "monotone_increasing",
115            ShapeConstraint::MonotoneDecreasing => "monotone_decreasing",
116            ShapeConstraint::Convex => "convex",
117            ShapeConstraint::Concave => "concave",
118        }
119    }
120}
121
/// Smooth-term head keywords recognised by the formula DSL. A `shape=` option
/// may be attached to any term whose head is one of these.
///
/// [`apply_shape_constraints_to_formula`] scans for `<keyword> \s* (` at a
/// word boundary, so a keyword only matches when it is the complete
/// identifier before the opening parenthesis (`s` does not match inside
/// `smooth(` or `abs(`).
///
/// NOTE(review): `ti(...)` is mentioned elsewhere in this file as a tensor
/// term head but is not listed here — confirm whether shape constraints on
/// `ti` terms are intentionally unsupported.
pub const SMOOTH_HEAD_KEYWORDS: [&str; 11] = [
    "s",
    "smooth",
    "te",
    "tensor",
    "thinplate",
    "tps",
    "duchon",
    "matern",
    "sphere",
    "bs",
    "bspline",
];
137
138/// Rewrite smooth-term calls in `formula` so each named smooth carries a
139/// `shape=<kind>` option understood by the formula DSL.
140///
141/// `constraints` pairs the smooth-term text as it appears in the formula
142/// (e.g. `"s(x)"` or `"s(x, type=duchon, centers=8)"`) with a shape-constraint
143/// spelling accepted by [`parse_shape_constraint`]; comparison is exact after
144/// whitespace removal. A `"none"` constraint is a no-op. Referencing a term not
145/// present in the formula is an error.
146///
147/// This is the single source of truth for the `gamfit.fit(..., constraints=…)`
148/// rewrite — the Python wrapper only marshals the mapping across the FFI and
149/// holds no formula-parsing or alias-normalization logic of its own.
150pub fn apply_shape_constraints_to_formula(
151    formula: &str,
152    constraints: &[(String, String)],
153) -> Result<String, String> {
154    use std::collections::{BTreeMap, BTreeSet};
155
156    if constraints.is_empty() {
157        return Ok(formula.to_string());
158    }
159    let strip_ws = |s: &str| -> String { s.chars().filter(|c| !c.is_whitespace()).collect() };
160
161    // Whitespace-stripped term text -> canonical shape spelling.
162    let mut wanted: BTreeMap<String, &'static str> = BTreeMap::new();
163    // Whitespace-stripped term text -> original key (for error labels).
164    let mut originals: BTreeMap<String, String> = BTreeMap::new();
165    for (key, kind_raw) in constraints {
166        let kind = parse_shape_constraint(kind_raw)?;
167        let nk = strip_ws(key);
168        originals.entry(nk.clone()).or_insert_with(|| key.clone());
169        if kind != ShapeConstraint::None {
170            wanted.insert(nk, kind.dsl_str());
171        }
172    }
173    if wanted.is_empty() {
174        return Ok(formula.to_string());
175    }
176
177    let chars: Vec<char> = formula.chars().collect();
178    let n = chars.len();
179    let is_ident = |c: char| c.is_ascii_alphanumeric() || c == '_';
180
181    let mut out = String::with_capacity(formula.len() + 32);
182    let mut matched: BTreeSet<String> = BTreeSet::new();
183    let mut i = 0usize;
184    while i < n {
185        // Locate the next smooth-term head (`<keyword> \s* (`) at or after `i`,
186        // respecting word boundaries so `abs(` never matches the `s(` head.
187        let mut head: Option<(usize, usize)> = None; // (head_start, paren_index)
188        let mut p = i;
189        while p < n {
190            let boundary = p == 0 || !is_ident(chars[p - 1]);
191            if boundary {
192                for kw in SMOOTH_HEAD_KEYWORDS.iter() {
193                    let klen = kw.chars().count();
194                    if p + klen > n || chars[p..p + klen].iter().collect::<String>() != **kw {
195                        continue;
196                    }
197                    let mut q = p + klen;
198                    while q < n && chars[q].is_whitespace() {
199                        q += 1;
200                    }
201                    if q < n && chars[q] == '(' {
202                        head = Some((p, q));
203                        break;
204                    }
205                }
206            }
207            if head.is_some() {
208                break;
209            }
210            p += 1;
211        }
212        let (head_start, paren_open) = match head {
213            Some(h) => h,
214            None => {
215                out.extend(chars[i..].iter());
216                break;
217            }
218        };
219        out.extend(chars[i..head_start].iter());
220
221        // Find the matching close paren, honoring nesting and string literals.
222        let body_start = paren_open + 1;
223        let mut depth = 1i32;
224        let mut j = body_start;
225        let mut in_str: Option<char> = None;
226        let mut closed = false;
227        while j < n {
228            let ch = chars[j];
229            if let Some(quote) = in_str {
230                if ch == quote {
231                    in_str = None;
232                }
233            } else if ch == '\'' || ch == '"' {
234                in_str = Some(ch);
235            } else if ch == '(' {
236                depth += 1;
237            } else if ch == ')' {
238                depth -= 1;
239                if depth == 0 {
240                    closed = true;
241                    break;
242                }
243            }
244            j += 1;
245        }
246
247        if !closed {
248            // Unbalanced — emit the remainder verbatim; the DSL parser will
249            // produce the canonical error.
250            out.extend(chars[head_start..].iter());
251            break;
252        }
253
254        let term_text: String = chars[head_start..=j].iter().collect();
255
256        let key_norm = strip_ws(&term_text);
257
258        match wanted.get(&key_norm) {
259            None => out.extend(chars[head_start..=j].iter()),
260            Some(kind) => {
261                let head_paren: String = chars[head_start..body_start].iter().collect();
262                let inside: String = chars[body_start..j].iter().collect();
263                let inside = inside.trim();
264                if inside.is_empty() {
265                    out.push_str(&format!("{head_paren}shape={kind})"));
266                } else {
267                    out.push_str(&format!("{head_paren}{inside}, shape={kind})"));
268                }
269                matched.insert(key_norm);
270            }
271        }
272
273        i = j + 1;
274    }
275
276    let mut missing: Vec<String> = wanted
277        .keys()
278        .filter(|k| !matched.contains(*k))
279        .map(|k| originals.get(k).cloned().unwrap_or_else(|| k.clone()))
280        .collect();
281
282    if !missing.is_empty() {
283        missing.sort();
284        return Err(format!(
285            "shape constraints referenced smooth term(s) not found in formula: {}",
286            missing.join(", ")
287        ));
288    }
289
290    Ok(out)
291}
292
/// Compact structural discriminator for a `by=` smooth; stored as the `kind`
/// field of [`SmoothBasisSpec::ByVariable`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BySmoothKind {
    /// The `by` variable is numeric.
    Numeric,
    /// The `by` variable selects one categorical level.
    /// NOTE(review): `level_bits` is presumably the bit pattern of the
    /// encoded level value — confirm against the design builder.
    Level { level_bits: u64 },
}
298
/// Basis specification for one smooth term. Each variant carries the
/// configuration of a single basis family, or wraps another spec.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SmoothBasisSpec {
    /// Row-gated wrapper used for mgcv-style `by=` smooths.
    ///
    /// `ByNumeric` multiplies the inner smooth by a numeric column.
    /// `ByLevel` keeps the inner smooth active only for rows whose encoded
    /// categorical value has the stored bit pattern.  Unordered factor-by
    /// smooths are represented as one independent `ByLevel` term per level.
    ///
    /// `kind` preserves the compact structural discriminator, while `by`
    /// carries the full row-gating spec used to build the local design.
    ByVariable {
        inner: Box<SmoothBasisSpec>,
        by_col: usize,
        kind: BySmoothKind,
        by: ByVariableSpec,
    },
    /// Sum-to-zero factor smooth (`bs="sz"`): with L levels, estimate L-1
    /// deviation coefficient blocks and use the final level as the negative
    /// sum of the others, enforcing coefficient-wise zero sums across levels.
    FactorSumToZero {
        inner: Box<SmoothBasisSpec>,
        by_col: usize,
        levels: Vec<u64>,
        /// Global-orthogonality column map `Z` captured at fit time when this
        /// term overlapped an owner smooth (`s(x) + s(g, x, bs=sz)`, #978):
        /// the hierarchical-ownership pass residualized this term's realized
        /// design as `X ← X·Z`, shrinking its coefficient block. `Z` depends
        /// on the *training-row* owner designs, so prediction cannot rederive
        /// it — it must be persisted and replayed
        /// (`apply_global_smooth_identifiability` consumes it verbatim).
        /// Chart convention: `Z` lives in the post-restack, post-joint-null-Q
        /// coordinates — the raw `sz` rebuild reapplies `Q` deterministically
        /// (#700), then `Z` applies on top. `None` for non-overlapping terms.
        #[serde(default)]
        frozen_global_orthogonality: Option<Array2<f64>>,
    },
    /// One-dimensional B-spline smooth over the single column `feature_col`.
    BSpline1D {
        feature_col: usize,
        spec: BSplineBasisSpec,
    },
    /// A smooth modulated by a `by=` variable. Numeric `by` scales one inner
    /// smooth; factor `by` replicates the inner smooth by level.
    BySmooth {
        smooth: Box<SmoothBasisSpec>,
        by_kind: ByVarKind,
    },
    /// Factor-smooth interaction families (`bs="fs"`, `bs="sz"`) and
    /// random slopes (`bs="re"`).
    FactorSmooth { spec: FactorSmoothSpec },
    /// Thin-plate spline smooth over the columns in `feature_cols`.
    ThinPlate {
        feature_cols: Vec<usize>,
        spec: ThinPlateBasisSpec,
        /// Per-column standard deviations used to standardize input dimensions
        /// before kernel evaluation when d > 1. `None` means no standardization
        /// (either d == 1 or explicitly disabled).
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
    },
    /// Spherical-spline smooth over the columns in `feature_cols`.
    Sphere {
        feature_cols: Vec<usize>,
        spec: SphericalSplineBasisSpec,
    },
    /// Constant-curvature (`M_κ`) geodesic-kernel smooth over κ-stereographic
    /// chart coordinates (#944): one construction interpolating
    /// S^d → ℝ^d → H^d through the spec's fixed κ. The Wahba S² smooth is the
    /// structural template; the geometry comes from
    /// `geometry::constant_curvature::ConstantCurvature`.
    ConstantCurvature {
        feature_cols: Vec<usize>,
        spec: ConstantCurvatureBasisSpec,
    },
    /// Matérn-kernel smooth over the columns in `feature_cols`.
    Matern {
        feature_cols: Vec<usize>,
        spec: MaternBasisSpec,
        /// NOTE(review): presumably the same per-column standardization
        /// scales documented on `ThinPlate::input_scales` — confirm.
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
    },
    /// Measure-jet spline smooth: multiscale local-jet-residual energy of the
    /// empirical measure (centers as μ-quadrature, masses as μ-weights — no
    /// graph, mesh, or neighbor set inside the statistical object). The
    /// feature columns are ambient coordinates of data concentrated near an
    /// unknown low-dimensional, possibly stratified set.
    MeasureJet {
        feature_cols: Vec<usize>,
        spec: MeasureJetBasisSpec,
        /// NOTE(review): presumably the same per-column standardization
        /// scales documented on `ThinPlate::input_scales` — confirm.
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
    },
    /// Duchon-spline smooth over the columns in `feature_cols`.
    Duchon {
        feature_cols: Vec<usize>,
        spec: DuchonBasisSpec,
        /// NOTE(review): presumably the same per-column standardization
        /// scales documented on `ThinPlate::input_scales` — confirm.
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
    },
    /// PCA-projection smooth: the columns of `basis_matrix` define the basis
    /// (its column count is what `min_sample_rows` uses as the row floor).
    Pca {
        feature_cols: Vec<usize>,
        basis_matrix: Array2<f64>,
        centered: bool,
        #[serde(default = "default_pca_smooth_penalty")]
        smooth_penalty: f64,
        #[serde(default)]
        center_mean: Option<Array1<f64>>,
        #[serde(default)]
        pca_basis_path: Option<PathBuf>,
        #[serde(default = "default_pca_chunk_size")]
        chunk_size: usize,
    },
    /// Tensor-product smooth built from 1D B-spline marginals.
    ///
    /// This is the `te()`-style construction used when axes have different units/scales
    /// (for example, space x time) and isotropic radial kernels are not appropriate.
    TensorBSpline {
        feature_cols: Vec<usize>,
        spec: TensorBSplineSpec,
    },
}
416
417impl SmoothBasisSpec {
418    /// Conservative lower bound on the number of sample rows needed for this
419    /// smooth basis to have a well-posed REML fit.
420    ///
421    /// Each basis kind answers the question for itself, so the workflow does
422    /// not have to know how many columns a B-spline, tensor product, PCA
423    /// projection, or spatial kernel emits. The contract is a *lower bound*:
424    /// returning too small a number is permitted (the inner solver will catch
425    /// any genuine n-vs-rank failure that slips past); returning too large a
426    /// number is a regression because it rejects legitimate fits.
427    ///
428    /// Rationale: B-spline / tensor / PCA bases have a closed-form column
429    /// count, so we use the exact dimension. Radial bases (TPS, Matern,
430    /// Duchon, Sphere) and factor smooths choose their column count from the
431    /// data (`heuristic_centers`, `unique_count`); we fall back to a small
432    /// constant floor because a fit on fewer than five rows cannot stabilise
433    /// any radial smooth regardless of the configured kernel scale.
434    pub fn min_sample_rows(&self) -> usize {
435        // Floor used for data-driven bases whose column count is not known
436        // from the spec alone. Five rows is the minimum at which the inner
437        // pivot/QR + REML smoothing-parameter search has any chance of being
438        // well-posed for a non-parametric smooth.
439        const RADIAL_FLOOR: usize = 5;
440
441        match self {
442            Self::ByVariable { inner, .. } => inner.min_sample_rows(),
443            Self::FactorSumToZero { inner, levels, .. } => {
444                // L-1 independent deviation blocks each carrying the inner
445                // basis dimension. Skip the levels-multiplier if it doesn't
446                // bring more rows; we want the *lower bound* not the rank.
447                let inner_min = inner.min_sample_rows();
448                let lvls = levels.len().saturating_sub(1).max(1);
449                inner_min.saturating_mul(lvls)
450            }
451            Self::BSpline1D { spec, .. } => bspline_basis_min_rows(spec),
452            Self::BySmooth { smooth, .. } => smooth.min_sample_rows(),
453            Self::FactorSmooth { spec } => {
454                // Replicates the marginal once per level; without a known
455                // level count we conservatively require at least the marginal
456                // basis dimension.
457                bspline_basis_min_rows(&spec.marginal)
458            }
459            Self::ThinPlate { .. }
460            | Self::Sphere { .. }
461            | Self::ConstantCurvature { .. }
462            | Self::Matern { .. }
463            | Self::MeasureJet { .. }
464            | Self::Duchon { .. } => RADIAL_FLOOR,
465            Self::Pca { basis_matrix, .. } => basis_matrix.ncols().max(1),
466            Self::TensorBSpline { spec, .. } => {
467                // A `te(...)` smooth is *penalized*: each margin carries a
468                // difference (wiggliness) penalty and the tensor inherits a
469                // Kronecker-sum penalty `S = Σ_i I ⊗ … ⊗ S_i ⊗ … ⊗ I`. The raw
470                // column count is the *product* of the per-marginal column
471                // counts, but that product is the lower bound for an
472                // *unpenalized* tensor regression — it is the number of rows you
473                // would need to identify every interaction column with no
474                // regularization. The penalty regularizes all of those
475                // interaction directions; only the combined penalty *null space*
476                // (the tensor product of the per-margin polynomial trends, a
477                // handful of columns) must be identified by the data, and the
478                // smoothing-parameter search shrinks the rest. The effective
479                // degrees of freedom of the fitted `te()` are therefore a small
480                // fraction of the column product, which is exactly why mgcv
481                // fits a default `te(x, y)` on a couple hundred rows.
482                //
483                // The honest *penalized* lower bound is the **sum** of the
484                // per-marginal column counts, not their product: a row floor of
485                // `Σ_i k_i` still guarantees enough data to identify each
486                // margin's additive main-effect (the largest sub-block the
487                // penalty cannot shrink to zero), while no longer conflating
488                // unpenalized column-count identifiability with penalized
489                // well-posedness. This accepts moderate-`n` penalized tensors
490                // (e.g. a 20×20 default basis on n=200) yet still rejects a
491                // genuinely undersized fit where `n < Σ_i k_i` and even the
492                // additive part is rank-deficient.
493                //
494                // Binary / low-cardinality margins (#724): gam will accept a
495                // `te(x, badh)` whose `badh ∈ {0, 1}` margin nominally requests
496                // more basis columns than `badh` has unique values, where mgcv
497                // refuses the unpenalized term as ill-posed ("badh has
498                // insufficient unique values to support k knots"). This is
499                // correct-by-design, *not* a degenerate fit: the marginal
500                // wiggliness penalty on the `badh` axis has a null space that is
501                // exactly its identifiable trend (the two cell means of a binary
502                // covariate), and the Kronecker-sum penalty shrinks every tensor
503                // column outside that null space toward zero. The resulting fit
504                // is the well-posed "per-level `x` smooth + binary main effect"
505                // that mgcv reaches only after manually collapsing the basis —
506                // gam reaches it automatically because the penalty, not the raw
507                // column count, sets the effective rank. A genuinely
508                // rank-deficient design (penalty null space wider than the data
509                // can support) is still caught downstream by the inner pivoted
510                // factorization, which owns the exact n-vs-rank decision; this
511                // pre-fit gate only refuses the grossly-undersized formula.
512                let mut total: usize = 0;
513                for marginal in &spec.marginalspecs {
514                    let m = bspline_basis_min_rows(marginal);
515                    total = total.saturating_add(m.max(1));
516                }
517                total.max(RADIAL_FLOOR)
518            }
519        }
520    }
521
522    /// Stable structural discriminant for warm-start cache keying (#869).
523    ///
524    /// Two smooths that produce different bases / penalty structures must map
525    /// to different strings here so they cannot collide on the persistent
526    /// warm-start `cache_key` (which is otherwise blind to topology: it hashes
527    /// only the raw input column count, so e.g. `sphere` vs `torus` vs
528    /// `euclidean` candidates fit on the *same* data would otherwise share one
529    /// key and cross-contaminate each other's β/ρ seed). The string is the
530    /// topology identity, not the fitted coefficients, so same-topology refits
531    /// (the screen→full-refit cascade) still hit the same key and reuse work.
532    pub fn structural_kind(&self) -> &'static str {
533        match self {
534            Self::ByVariable { .. } => "by_variable",
535            Self::FactorSumToZero { .. } => "factor_sum_to_zero",
536            Self::BSpline1D { .. } => "bspline_1d",
537            Self::BySmooth { .. } => "by_smooth",
538            Self::FactorSmooth { .. } => "factor_smooth",
539            Self::ThinPlate { .. } => "thin_plate",
540            Self::Sphere { .. } => "sphere",
541            Self::ConstantCurvature { .. } => "constant_curvature",
542            Self::Matern { .. } => "matern",
543            Self::MeasureJet { .. } => "measurejet",
544            Self::Duchon { .. } => "duchon",
545            Self::Pca { .. } => "pca",
546            Self::TensorBSpline { .. } => "tensor_bspline",
547        }
548    }
549
550    /// True for a tensor-product smooth that is only *marginally* centered
551    /// (`ti(...)`, [`TensorBSplineIdentifiability::MarginalSumToZero`]): its
552    /// per-margin sum-to-zero reparameterization `(B_xZ_x)⊗(B_zZ_z)` has ALREADY
553    /// removed each axis's main effect analytically (mgcv-identical), so its
554    /// main-effect removal is complete and it must take NO additional
555    /// owner-residualization block. Residualizing it a second time against the
556    /// realized main-effect designs is a grid-fragile no-op on an exact tensor
557    /// grid but eats genuine pure-interaction curvature off-grid (#1470).
558    pub fn is_marginally_centered_tensor(&self) -> bool {
559        matches!(
560            self,
561            Self::TensorBSpline { spec, .. }
562                if matches!(spec.identifiability, TensorBSplineIdentifiability::MarginalSumToZero)
563        )
564    }
565
566    /// A sum-to-zero factor smooth (`bs="sz"`) has ALREADY removed the
567    /// cross-group main effect analytically, in coefficient space, via its
568    /// `Σ_g d_g(x) ≡ 0` reparameterization (`L-1` deviation blocks with the
569    /// reference level the negative sum of the others) — exactly mgcv's `sz`
570    /// construction, which is self-identifiable against an overlapping `s(x)`
571    /// with no further constraint. Residualizing it a SECOND time against the
572    /// realized B-spline span of the explicit `s(x)` smooth is redundant in
573    /// exact arithmetic (the common-to-all-groups component is zero by
574    /// construction) and actively HARMFUL on finite data: each deviation block
575    /// is a B-spline in `x` whose realized columns share `s(x)`'s span, so the
576    /// joint residualization collapses the full `L·k`-column deviation design to
577    /// `L·k − rank(s(x))` columns and eats the within-group curvature `s(x)`
578    /// cannot represent. REML then rails the deviation smoothing parameter and
579    /// the factor smooth under-recovers (#1605). This is the exact analogue of
580    /// the marginally-centered tensor (`ti`) exemption (#1470), so such a term
581    /// takes NO owner-residualization block.
582    pub fn is_sum_to_zero_factor_smooth(&self) -> bool {
583        matches!(
584            self,
585            Self::FactorSumToZero { .. }
586                | Self::FactorSmooth {
587                    spec: FactorSmoothSpec {
588                        flavour: FactorSmoothFlavour::Sz,
589                        ..
590                    }
591                }
592        )
593    }
594
595    /// Feature columns this basis consumes, used alongside [`structural_kind`]
596    /// to disambiguate two same-kind smooths on different axes. Wrapper
597    /// variants delegate to their inner basis.
598    pub fn structural_feature_cols(&self) -> Vec<usize> {
599        match self {
600            Self::ByVariable { inner, .. } | Self::FactorSumToZero { inner, .. } => {
601                inner.structural_feature_cols()
602            }
603            Self::BySmooth { smooth, .. } => smooth.structural_feature_cols(),
604            Self::FactorSmooth { .. } => Vec::new(),
605            Self::BSpline1D { feature_col, .. } => vec![*feature_col],
606            Self::ThinPlate { feature_cols, .. }
607            | Self::Sphere { feature_cols, .. }
608            | Self::ConstantCurvature { feature_cols, .. }
609            | Self::Matern { feature_cols, .. }
610            | Self::MeasureJet { feature_cols, .. }
611            | Self::Duchon { feature_cols, .. }
612            | Self::Pca { feature_cols, .. }
613            | Self::TensorBSpline { feature_cols, .. } => feature_cols.clone(),
614        }
615    }
616}
617
618/// Lower bound on the number of sample rows a 1D B-spline smooth needs for a
619/// well-posed *penalized* REML fit. Used as the per-smooth row floor in
620/// [`SmoothBasisSpec::min_sample_rows`].
621///
622/// For a *singly*-penalized smooth the floor is the full column count: the
623/// wiggliness penalty leaves the order-`m` polynomial trend unpenalized, and
624/// gam's original gate conservatively required enough rows for the whole basis.
625/// That conservative floor is kept here unchanged.
626///
627/// A *double*-penalized smooth (mgcv `select=TRUE`) is different: it adds a
628/// second penalty on the wiggliness penalty's null space, so even the
629/// polynomial trend is shrinkable toward zero and *nothing* in the basis
630/// requires unpenalized identification by the data — exactly the reasoning the
631/// `TensorBSpline` arm of [`SmoothBasisSpec::min_sample_rows`] already applies
632/// to a penalized tensor. Its honest floor is therefore a small stabilization
633/// constant, not the column count. This is what lets mgcv (and now gam) fit
634/// several `select=TRUE` smooths on a dataset whose row count is below the
635/// summed basis width (e.g. the n≈30 `wine_gamair` fold, 5 `ps` smooths,
636/// p≈51): the penalties, not the data, set the effective rank. The bounded
637/// outer REML loop still terminates, and the genuine n-vs-rank decision is
638/// owned downstream by the inner pivoted factorization. Without this, gam
639/// rejected the fit outright (or, before the gate existed, the outer REML loop
640/// wandered the flat overparameterized surface until the benchmark wall budget
641/// killed it — #1089).
642pub fn bspline_basis_min_rows(spec: &crate::basis::BSplineBasisSpec) -> usize {
643    use crate::basis::BSplineKnotSpec;
644    let columns = match &spec.knotspec {
645        BSplineKnotSpec::Generate {
646            num_internal_knots, ..
647        } => *num_internal_knots + spec.degree + 1,
648        BSplineKnotSpec::Automatic {
649            num_internal_knots: Some(k),
650            ..
651        } => *k + spec.degree + 1,
652        BSplineKnotSpec::Automatic {
653            num_internal_knots: None,
654            ..
655        } => {
656            // Knot count is data-derived (`default_internal_knot_count_for_data`).
657            // A minimal cubic basis is `degree + 2` columns; below that the
658            // basis cannot represent a non-parametric smooth.
659            spec.degree + 2
660        }
661        BSplineKnotSpec::Provided(knots) => knots.len().saturating_sub(spec.degree + 1).max(1),
662        // cr basis dimension equals the knot count (no degree offset).
663        BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
664        BSplineKnotSpec::PeriodicUniform { num_basis, .. } => *num_basis,
665    };
666    let columns = columns.max(spec.degree + 2);
667
668    if spec.double_penalty {
669        // Fully shrinkable basis: only a small stabilization floor must be
670        // identified by the data, capped by the actual column count.
671        const DOUBLE_PENALTY_FLOOR: usize = 2;
672        DOUBLE_PENALTY_FLOOR.min(columns).max(1)
673    } else {
674        columns
675    }
676}
677
678#[derive(Debug, Clone, Serialize, Deserialize)]
679pub enum ByVariableSpec {
680    Numeric,
681    Level { value_bits: u64, label: String },
682}
683
684
685#[derive(Debug, Clone, Serialize, Deserialize)]
686pub enum ByVarKind {
687    Numeric {
688        feature_col: usize,
689    },
690    Factor {
691        feature_col: usize,
692        ordered: bool,
693        frozen_levels: Option<Vec<u64>>,
694    },
695}
696
697#[derive(Debug, Clone, Serialize, Deserialize)]
698pub struct FactorSmoothSpec {
699    pub continuous_cols: Vec<usize>,
700    pub group_col: usize,
701    pub marginal: BSplineBasisSpec,
702    pub flavour: FactorSmoothFlavour,
703    pub group_frozen_levels: Option<Vec<u64>>,
704    /// Fit-time global-orthogonality chart `Z` for this term (`s(x) + fs(x, g)`
705    /// overlap residualization, #978), in the post-joint-null-`Q` coordinates
706    /// (the raw rebuild recomputes any `Q` itself; `fs` penalties are
707    /// typically full-rank so `Q` is absent). Training-row dependent, hence
708    /// persisted; replayed verbatim by `apply_global_smooth_identifiability`.
709    #[serde(default)]
710    pub frozen_global_orthogonality: Option<Array2<f64>>,
711}
712
713#[derive(Debug, Clone, Serialize, Deserialize)]
714pub enum FactorSmoothFlavour {
715    Fs { m_null_penalty_orders: Vec<usize> },
716    Sz,
717    Re,
718}
719
720#[derive(Debug, Default, Clone, Serialize, Deserialize)]
721pub struct TensorBSplineSpec {
722    pub marginalspecs: Vec<BSplineBasisSpec>,
723    #[serde(default)]
724    pub periods: Vec<Option<f64>>,
725    pub double_penalty: bool,
726    #[serde(default)]
727    pub identifiability: TensorBSplineIdentifiability,
728    #[serde(default)]
729    pub penalty_decomposition: TensorBSplinePenaltyDecomposition,
730}
731
732#[derive(Debug, Default, Clone, Serialize, Deserialize)]
733pub enum TensorBSplineIdentifiability {
734    None,
735    #[default]
736    SumToZero,
737    /// mgcv `ti(...)` semantics: a *tensor interaction* smooth that excludes the
738    /// marginal main effects. A sum-to-zero constraint is applied to **each
739    /// marginal basis independently** before forming the tensor product, so the
740    /// resulting column space contains no function of a single variable alone —
741    /// only the pure interaction survives. The realized identifiability
742    /// transform is the Kronecker product `Z = Z₀ ⊗ Z₁ ⊗ … ⊗ Z_{d-1}` of the
743    /// per-margin sum-to-zero null-space bases, which is exactly the
744    /// reparameterization that turns the full-tensor design into the tensor
745    /// product of the centered margins.
746    MarginalSumToZero,
747    FrozenTransform {
748        transform: Array2<f64>,
749    },
750}
751
752#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
753pub enum TensorBSplinePenaltyDecomposition {
754    /// mgcv `te(...)`: one overlapping Kronecker-product penalty per margin,
755    /// `S_j` embedded against identities in the other tensor factors.
756    #[default]
757    MarginalKroneckerSum,
758    /// mgcv `t2(...)`: split every marginal coefficient space into penalized
759    /// range and penalty-null subspaces, then emit one disjoint tensor-subspace
760    /// penalty for every non-empty penalized/null combination.
761    Separable,
762}
763
764#[derive(Debug, Clone, Serialize, Deserialize)]
765pub struct SmoothTermSpec {
766    pub name: String,
767    pub basis: SmoothBasisSpec,
768    pub shape: ShapeConstraint,
769    /// Joint-null absorption rotation captured at fit time. `Some(Q)` means
770    /// the fitted coefficient vector lives in `γ`-coordinates with
771    /// `β_raw = Q · γ`; prediction must rotate the raw-basis design via
772    /// `X_new = X_new_raw · Q` to match. `None` means either the smooth had
773    /// no joint null space (penalty already full-rank) or rotation was
774    /// suppressed (smooth carries shape constraints whose cone geometry
775    /// would not survive an arbitrary orthogonal rotation). Persisted so
776    /// `save → load → predict` is bit-equivalent to in-memory prediction.
777    #[serde(default)]
778    pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
779}
780
781#[derive(Debug, Clone)]
782pub struct SmoothTerm {
783    pub name: String,
784    pub coeff_range: Range<usize>,
785    pub shape: ShapeConstraint,
786    pub penalties_local: Vec<Array2<f64>>,
787    pub nullspace_dims: Vec<usize>,
788    pub penaltyinfo_local: Vec<PenaltyInfo>,
789    pub metadata: BasisMetadata,
790    /// Optional term-local lower bounds for constrained coefficients.
791    /// `-inf` means unconstrained.
792    pub lower_bounds_local: Option<Array1<f64>>,
793    /// Optional term-local inequality constraints in local coefficient coordinates.
794    /// `A_local * beta_local >= b_local`.
795    pub linear_constraints_local: Option<LinearInequalityConstraints>,
796    /// Optional factored tensor-product representation preserved for operator-backed
797    /// assembly in the main design builder.
798    pub kronecker_factored: Option<KroneckerFactoredBasis>,
799    /// Joint-null absorption rotation. `Some(Q)` records the orthonormal
800    /// `(p_local × p_local)` matrix that was applied to this term's design
801    /// and per-block penalties at construction time:
802    /// `term_design ← X_raw · Q`, `penalties_local[k] ← Qᵀ · S_raw · Q`.
803    /// The smooth's coefficient vector therefore lives in the rotated
804    /// (`γ`) coordinate system, with `β_raw = Q · γ` recovering the raw
805    /// pre-rotation parameterization. `None` means either no joint null
806    /// space (penalty already full-rank) or rotation was suppressed —
807    /// suppression fires when the smooth carries shape constraints
808    /// (lower bounds or local linear inequalities) that would lose their
809    /// cone geometry under a general orthogonal rotation.
810    ///
811    /// Prediction-side replay: callers building a new-data design `X_new_raw`
812    /// from the *raw* basis must call [`SmoothTerm::apply_rotation_to_predict`]
813    /// (or equivalent) to obtain `X_new = X_new_raw · Q` matching this
814    /// term's coefficient system.
815    ///
816    /// Persistence replay: `freeze_term_collection_from_design` copies this
817    /// rotation into `SmoothTermSpec`, which is serialized with fitted-model
818    /// payloads and reused by the predict-time basis builder. Saved models
819    /// therefore replay the same `X_new_raw · Q` transform as in-memory
820    /// prediction.
821    pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
822    /// Global-orthogonality transform that `apply_global_smooth_identifiability`
823    /// applied to this term's design but could NOT embed into `metadata`
824    /// (factor-smooth kinds: `sz` metadata is per-marginal, `fs` metadata has
825    /// no transform slot — #978). `freeze_term_collection_from_design` copies
826    /// it onto the term's basis spec (`frozen_global_orthogonality`) so the
827    /// predict-side rebuild replays it instead of emitting the unresidualized
828    /// (wider) design that the fitted coefficients no longer match.
829    /// Chart convention is per kind: post-`Q` `Z` for `sz` (the raw rebuild
830    /// reapplies `Q` itself, #700), full `Q·Z` chart for `fs`.
831    pub unabsorbed_global_orthogonality: Option<Array2<f64>>,
832}
833
834impl SmoothTerm {
835    /// Apply the joint-null absorption rotation to a raw new-data design
836    /// matrix, returning `X_new_raw · Q` when this term was rotated at
837    /// fit time, or `X_new_raw` unchanged when no rotation was applied.
838    ///
839    /// Callers in the prediction path: after building the smooth's basis
840    /// at new data via the *raw* basis builder (the same builder used at
841    /// fit time, applied to `x_new` instead of the training rows), call
842    /// this method on the resulting matrix before forming `X · β`. The
843    /// fitted `β` lives in `γ`-coordinates if Q was applied; multiplying
844    /// the un-rotated `X_new_raw` by `β` would give a wrong η.
845    ///
846    /// Returns an error if the raw design's column count does not match
847    /// the rotation's `p_local`. The width invariant must hold: the raw
848    /// basis builder MUST emit the same `p_local` columns that the
849    /// fit-time builder did, and the rotation is `(p_local × p_local)`.
850    pub fn apply_rotation_to_predict(
851        &self,
852        x_new_raw: Array2<f64>,
853    ) -> Result<Array2<f64>, BasisError> {
854        let Some(rot) = self.joint_null_rotation.as_ref() else {
855            return Ok(x_new_raw);
856        };
857        let p_local = rot.rotation.nrows();
858        if x_new_raw.ncols() != p_local {
859            crate::bail_dim_basis!(
860                "joint-null rotation replay for term '{}': raw design has {} columns, \
861                 rotation expects {} (the raw basis builder must emit the same column \
862                 count as at fit time)",
863                self.name,
864                x_new_raw.ncols(),
865                p_local,
866            );
867        }
868        Ok(gam_linalg::faer_ndarray::fast_ab(
869            &x_new_raw,
870            &rot.rotation,
871        ))
872    }
873
874    /// Dimension of the **joint** null space of this term's active penalties:
875    /// the coefficient directions penalized by *no* penalty. The smooth-component
876    /// Wald test ([`crate::inference::smooth_test::wood_smooth_test`]) treats this
877    /// many leading coefficients as genuine unpenalized fixed effects and tests
878    /// them at full rank; the remainder is the penalized sub-block tested with a
879    /// rank-`≈EDF` truncated pseudo-inverse.
880    ///
881    /// Because every penalty block `S_k` is positive semi-definite,
882    /// `vᵀ(Σ_k S_k)v = Σ_k vᵀ S_k v = 0` iff `S_k v = 0` for *every* `k`; the
883    /// joint null space is therefore exactly `null(Σ_k S_k)`, of dimension
884    /// `p_local − rank(Σ_k S_k)`. This is the **intersection** of the per-penalty
885    /// null spaces, not their sum.
886    ///
887    /// Summing the per-penalty `nullspace_dims` instead (the historical defect
888    /// behind #1360) *unions* the null spaces and badly over-counts: a
889    /// double-penalty smooth carries a bending penalty (null space = its
890    /// polynomial part) plus a complementary null-space ridge (which penalizes
891    /// exactly that polynomial part), so the two null spaces are disjoint and the
892    /// joint null space is empty — yet the per-penalty dims sum to nearly
893    /// `p_local`. Feeding that inflated count to the Wald test makes it test
894    /// almost the whole shrunk block at full rank, manufacturing overwhelming
895    /// "significance" for a term the fit drove to ~0 EDF.
896    pub fn wald_unpenalized_dim(&self) -> usize {
897        joint_unpenalized_dim(
898            self.coeff_range.len(),
899            &self.penalties_local,
900            &self.nullspace_dims,
901        )
902    }
903}
904
905/// Numeric core of [`SmoothTerm::wald_unpenalized_dim`]: the dimension of the
906/// joint null space `∩_k null(S_k) = null(Σ_k S_k)` of a term's local penalty
907/// blocks, with a conservative fallback when a penalty is not materialized as a
908/// full `p_local × p_local` matrix (e.g. a Kronecker tensor factor).
909pub fn joint_unpenalized_dim(
910    p_local: usize,
911    penalties_local: &[Array2<f64>],
912    nullspace_dims: &[usize],
913) -> usize {
914    use gam_linalg::faer_ndarray::FaerEigh;
915    if p_local == 0 {
916        return 0;
917    }
918    if penalties_local.is_empty() {
919        // No penalty ⇒ a wholly unpenalized (fixed-effect) block.
920        return p_local;
921    }
922    // Sum the penalties that are materialized as full `p_local × p_local`
923    // blocks (the common smooth case). The covariance block the Wald test
924    // slices lives in this same coefficient basis (post joint-null rotation),
925    // so the rank is computed in the right metric.
926    let mut s_total = Array2::<f64>::zeros((p_local, p_local));
927    let mut materialized = 0usize;
928    for s in penalties_local {
929        if s.nrows() == p_local && s.ncols() == p_local {
930            s_total += s;
931            materialized += 1;
932        }
933    }
934    if materialized == penalties_local.len() {
935        let symmetric = {
936            let transpose = s_total.t().to_owned();
937            (&s_total + &transpose) * 0.5
938        };
939        if let Ok((evals, _)) = symmetric.eigh(faer::Side::Lower) {
940            let max_abs = evals.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
941            if max_abs == 0.0 {
942                // All penalties identically zero ⇒ unpenalized block.
943                return p_local;
944            }
945            let tol = max_abs * (p_local as f64) * 1e-12;
946            let rank = evals.iter().filter(|&&v| v > tol).count();
947            return p_local.saturating_sub(rank);
948        }
949    }
950    // Conservative fallback when a penalty is not a materialized full block
951    // (e.g. a Kronecker tensor factor): with ≥2 active penalties the joint
952    // null space is almost always empty (the only over-rejecting direction);
953    // with a single penalty it is exactly that penalty's own null space.
954    if penalties_local.len() >= 2 {
955        0
956    } else {
957        nullspace_dims
958            .iter()
959            .copied()
960            .min()
961            .unwrap_or(0)
962            .min(p_local)
963    }
964}
965
966#[derive(Debug, Clone, Serialize, Deserialize)]
967pub struct PenaltyBlockInfo {
968    pub global_index: usize,
969    pub termname: Option<String>,
970    pub penalty: PenaltyInfo,
971}
972
973#[derive(Debug, Clone, Serialize, Deserialize)]
974pub struct DroppedPenaltyBlockInfo {
975    pub termname: Option<String>,
976    pub penalty: PenaltyInfo,
977}
978
979#[derive(Debug, Clone)]
980pub struct SmoothDesign {
981    pub term_designs: Vec<DesignMatrix>,
982    /// Per-term block-local penalties.  Each `col_range` is relative to the
983    /// smooth block (i.e. indexing into the concatenation of `term_designs`).
984    pub penalties: Vec<BlockwisePenalty>,
985    pub nullspace_dims: Vec<usize>,
986    pub penaltyinfo: Vec<PenaltyBlockInfo>,
987    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
988    pub terms: Vec<SmoothTerm>,
989    /// Optional smooth-block lower bounds in smooth coefficient coordinates.
990    /// Length equals `total_smooth_cols()` when present.
991    pub coefficient_lower_bounds: Option<Array1<f64>>,
992    /// Optional smooth-block inequality constraints:
993    /// `A_smooth * beta_smooth >= b`.
994    pub linear_constraints: Option<LinearInequalityConstraints>,
995}
996
997impl SmoothDesign {
998    pub fn total_smooth_cols(&self) -> usize {
999        self.term_designs.iter().map(DesignMatrix::ncols).sum()
1000    }
1001    pub fn nrows(&self) -> usize {
1002        self.term_designs.first().map_or(0, DesignMatrix::nrows)
1003    }
1004}
1005
1006#[derive(Debug, Clone)]
1007pub struct RawSmoothDesign {
1008    pub term_designs: Vec<DesignMatrix>,
1009    /// Per-term block-local penalties.  Each `col_range` is relative to the
1010    /// smooth block (i.e. indexing into the concatenation of `term_designs`).
1011    pub penalties: Vec<BlockwisePenalty>,
1012    pub nullspace_dims: Vec<usize>,
1013    pub penaltyinfo: Vec<PenaltyBlockInfo>,
1014    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
1015    pub terms: Vec<SmoothTerm>,
1016    pub coefficient_lower_bounds: Option<Array1<f64>>,
1017    pub linear_constraints: Option<LinearInequalityConstraints>,
1018}
1019
1020impl RawSmoothDesign {
1021    pub fn total_smooth_cols(&self) -> usize {
1022        self.term_designs.iter().map(DesignMatrix::ncols).sum()
1023    }
1024    pub fn nrows(&self) -> usize {
1025        self.term_designs.first().map_or(0, DesignMatrix::nrows)
1026    }
1027}
1028
1029impl From<RawSmoothDesign> for SmoothDesign {
1030    fn from(value: RawSmoothDesign) -> Self {
1031        Self {
1032            term_designs: value.term_designs,
1033            penalties: value.penalties,
1034            nullspace_dims: value.nullspace_dims,
1035            penaltyinfo: value.penaltyinfo,
1036            dropped_penaltyinfo: value.dropped_penaltyinfo,
1037            terms: value.terms,
1038            coefficient_lower_bounds: value.coefficient_lower_bounds,
1039            linear_constraints: value.linear_constraints,
1040        }
1041    }
1042}
1043
1044#[derive(Debug, Default, Clone, Serialize, Deserialize)]
1045pub enum BoundedCoefficientPriorSpec {
1046    #[default]
1047    None,
1048    Uniform,
1049    Beta {
1050        a: f64,
1051        b: f64,
1052    },
1053}
1054
1055#[derive(Debug, Clone, Serialize, Deserialize, Default)]
1056pub enum LinearCoefficientGeometry {
1057    #[default]
1058    Unconstrained,
1059    Bounded {
1060        min: f64,
1061        max: f64,
1062        #[serde(default)]
1063        prior: BoundedCoefficientPriorSpec,
1064    },
1065}
1066
1067#[derive(Debug, Clone, Serialize, Deserialize)]
1068pub struct LinearTermSpec {
1069    pub name: String,
1070    /// Primary feature column index. For Wilkinson-Rogers `:` interaction
1071    /// terms (`a:b[:c...]`) this is the first column in `feature_cols`; the
1072    /// realized design column is the elementwise product across every entry
1073    /// of `feature_cols`. Plain (non-interaction) linear terms set
1074    /// `feature_cols == vec![feature_col]`.
1075    pub feature_col: usize,
1076    /// Full list of columns whose elementwise product yields this term's
1077    /// design column. `len() >= 1`; `len() == 1` is a plain linear effect.
1078    #[serde(default)]
1079    pub feature_cols: Vec<usize>,
1080    /// Categorical-level gates for a factor-aware `:` interaction.
1081    ///
1082    /// Each `(col, level_bits)` multiplies the realized design column by the
1083    /// indicator `1[data[row, col].to_bits() == level_bits]`. This is how a
1084    /// `factor:x` (or `factor:factor`) interaction is expanded: `build_termspec`
1085    /// emits one `LinearTermSpec` per surviving cell of the categorical
1086    /// operand(s) (treatment-coded, first level dropped per factor), each
1087    /// carrying the numeric operands in `feature_cols` and the cell's level
1088    /// gate(s) here. Empty for a plain numeric `:` interaction or main effect,
1089    /// in which case the realized column is exactly the numeric product.
1090    #[serde(default)]
1091    pub categorical_levels: Vec<(usize, u64)>,
1092    /// Optional ridge (`S = I`, REML-selected `λ`) on this linear coefficient.
1093    /// A parametric linear term carries no wiggliness, so it is **unpenalized by
1094    /// default** — gam reports the MLE, matching mgcv/glm/survreg/VGAM (which
1095    /// penalize parametric terms only under an explicit `paraPen`). Set `true`
1096    /// to opt into an explicit shrinkage ridge (a zero-mean Gaussian prior
1097    /// `β ~ N(0, λ⁻¹)`); doing so adds one outer REML smoothing coordinate.
1098    #[serde(default = "default_linear_term_double_penalty")]
1099    pub double_penalty: bool,
1100    #[serde(default)]
1101    pub coefficient_geometry: LinearCoefficientGeometry,
1102    #[serde(default)]
1103    pub coefficient_min: Option<f64>,
1104    #[serde(default)]
1105    pub coefficient_max: Option<f64>,
1106}
1107
1108impl LinearTermSpec {
1109    /// Return the effective list of feature columns. Backfills from
1110    /// `feature_col` for legacy specs that predate the multi-column field.
1111    pub fn effective_feature_cols(&self) -> Vec<usize> {
1112        if self.feature_cols.is_empty() {
1113            vec![self.feature_col]
1114        } else {
1115            self.feature_cols.clone()
1116        }
1117    }
1118
1119    /// True when this term is a Wilkinson-Rogers `:` interaction (multi-col).
1120    pub fn is_interaction(&self) -> bool {
1121        self.feature_cols.len() > 1 || !self.categorical_levels.is_empty()
1122    }
1123
1124    /// Realize this linear term's `(n,)` design column from `data`.
1125    ///
1126    /// The column is the elementwise product of every numeric feature column
1127    /// (`effective_feature_cols`) gated by the categorical-level indicators in
1128    /// `categorical_levels`: each `(col, level_bits)` multiplies the running
1129    /// column by `1[data[row, col].to_bits() == level_bits]`. A plain numeric
1130    /// term (no `categorical_levels`) reduces to the bare product, matching the
1131    /// historical behaviour. A pure categorical interaction (empty
1132    /// `feature_cols`, non-empty `categorical_levels`) reduces to the cell
1133    /// indicator. Bounds are validated here; the returned column has length
1134    /// `data.nrows()`.
1135    pub fn realized_design_column(&self, data: ArrayView2<'_, f64>) -> Result<Array1<f64>, String> {
1136        let n = data.nrows();
1137        let p = data.ncols();
1138        let bounds = |col: usize| -> Result<(), String> {
1139            if col >= p {
1140                Err(format!(
1141                    "linear term '{}' feature column {} out of bounds for {} columns",
1142                    self.name, col, p
1143                ))
1144            } else {
1145                Ok(())
1146            }
1147        };
1148
1149        // Numeric operands. When `categorical_levels` is set we treat
1150        // `feature_cols` as the (possibly empty) numeric operand list and start
1151        // from a column of ones; otherwise we preserve the legacy backfill from
1152        // `feature_col` so a plain term with no `feature_cols` still resolves.
1153        let mut column = if self.categorical_levels.is_empty() {
1154            let cols = self.effective_feature_cols();
1155            for &c in &cols {
1156                bounds(c)?;
1157            }
1158            let mut acc = data.column(cols[0]).to_owned();
1159            for &c in cols.iter().skip(1) {
1160                acc *= &data.column(c);
1161            }
1162            acc
1163        } else {
1164            let mut acc = Array1::<f64>::ones(n);
1165            for &c in &self.feature_cols {
1166                bounds(c)?;
1167                acc *= &data.column(c);
1168            }
1169            acc
1170        };
1171
1172        for &(col, level_bits) in &self.categorical_levels {
1173            bounds(col)?;
1174            let gate = data.column(col);
1175            for (out, &v) in column.iter_mut().zip(gate.iter()) {
1176                if v.to_bits() != level_bits {
1177                    *out = 0.0;
1178                }
1179            }
1180        }
1181
1182        Ok(column)
1183    }
1184}
1185
/// Serde default for `LinearTermSpec::double_penalty`: `false`.
///
/// Parametric/linear terms are left unpenalized unless asked otherwise. A
/// single linear coefficient has no roughness for a smoothing penalty to act
/// on, and the historical `S = I` ridge with REML-selected `λ` pulled every
/// linear coefficient away from the MLE while adding a spurious outer
/// smoothing coordinate (#749). Mature tools (mgcv/glm/survreg/VGAM) do not
/// penalize parametric terms; gam now does the same and reports the MLE.
/// A term can still opt into a ridge with an explicit `double_penalty = true`.
pub const fn default_linear_term_double_penalty() -> bool {
    const UNPENALIZED: bool = false;
    UNPENALIZED
}
1196
/// Default PCA smooth penalty: `1.0`.
///
/// NOTE(review): the consumer of this default is outside this section of the
/// file; presumably it is referenced from a `#[serde(default = "...")]`
/// attribute — confirm.
pub const fn default_pca_smooth_penalty() -> f64 {
    1.0_f64
}
1200
/// Default PCA chunk size: `4096`.
///
/// NOTE(review): the unit of the chunk (presumably rows) and the consumer of
/// this default are outside this section of the file — confirm.
pub const fn default_pca_chunk_size() -> usize {
    4_096
}
1204
1205/// Random-effects term specification.
1206///
1207/// The selected feature column is interpreted as a categorical grouping variable.
1208/// The term contributes a one-hot dummy block with an identity penalty on group
1209/// coefficients, equivalent to i.i.d. Gaussian random effects.
1210#[derive(Debug, Clone, Serialize, Deserialize)]
1211pub struct RandomEffectTermSpec {
1212    pub name: String,
1213    pub feature_col: usize,
1214    /// If true, drop the lexicographically first group level to use treatment coding.
1215    /// If false, keep all levels (full one-hot block, still identifiable under ridge).
1216    pub drop_first_level: bool,
1217    /// If true, add a ridge penalty and estimate this block as a random effect.
1218    /// If false, leave the one-hot/treatment-coded block unpenalized so it is a
1219    /// fixed categorical main effect.  The default preserves older saved models.
1220    #[serde(default = "default_random_effect_penalized")]
1221    pub penalized: bool,
1222    /// Optional fixed kept-level set (sorted by f64 bit pattern) captured at fit time.
1223    /// When present, prediction uses exactly these columns to avoid design drift.
1224    #[serde(default)]
1225    pub frozen_levels: Option<Vec<u64>>,
1226}
1227
/// Serde default for `RandomEffectTermSpec::penalized`: `true`.
///
/// A payload that omits the field deserializes as a penalized (random-effect)
/// block, which is what older saved models expect.
///
/// Declared `const fn` like the other default providers in this file
/// (`default_linear_term_double_penalty`, `default_pca_smooth_penalty`,
/// `default_pca_chunk_size`), so it can also be evaluated in const contexts.
pub const fn default_random_effect_penalized() -> bool {
    true
}
1231
1232pub fn validate_measure_jet_positive_vec_len(
1233    label: &str,
1234    term_name: &str,
1235    field: &str,
1236    values: &[f64],
1237    expected: usize,
1238) -> Result<(), String> {
1239    if values.len() != expected {
1240        return Err(SmoothError::invalid_config(format!(
1241            "{label} term '{term_name}' frozen MeasureJet {field} has length {}, expected {expected}",
1242            values.len()
1243        ))
1244        .into());
1245    }
1246    if values
1247        .iter()
1248        .any(|value| !(value.is_finite() && *value > 0.0))
1249    {
1250        return Err(SmoothError::invalid_config(format!(
1251            "{label} term '{term_name}' frozen MeasureJet {field} values must be positive and finite"
1252        ))
1253        .into());
1254    }
1255    Ok(())
1256}
1257
1258#[derive(Debug, Clone, Serialize, Deserialize)]
1259pub struct TermCollectionSpec {
1260    pub linear_terms: Vec<LinearTermSpec>,
1261    pub random_effect_terms: Vec<RandomEffectTermSpec>,
1262    pub smooth_terms: Vec<SmoothTermSpec>,
1263}
1264
1265pub fn validate_smooth_basis_frozen(
1266    basis: &SmoothBasisSpec,
1267    label: &str,
1268    term_name: &str,
1269) -> Result<(), String> {
1270    match basis {
1271        SmoothBasisSpec::ByVariable { inner, .. }
1272        | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
1273            validate_smooth_basis_frozen(inner, label, term_name)
1274        }
1275        SmoothBasisSpec::BSpline1D { spec, .. } => {
1276            if !matches!(
1277                spec.knotspec,
1278                BSplineKnotSpec::Provided(_)
1279                    | BSplineKnotSpec::PeriodicUniform { .. }
1280                    | BSplineKnotSpec::NaturalCubicRegression { .. }
1281            ) {
1282                return Err(format!(
1283                    "{label} term '{term_name}' is not frozen: BSpline knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression"
1284                ));
1285            }
1286            Ok(())
1287        }
1288        SmoothBasisSpec::ThinPlate { spec, .. } => {
1289            if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1290                return Err(format!(
1291                    "{label} term '{term_name}' is not frozen: ThinPlate centers must be UserProvided"
1292                ));
1293            }
1294            if matches!(
1295                spec.identifiability,
1296                SpatialIdentifiability::OrthogonalToParametric
1297            ) {
1298                return Err(format!(
1299                    "{label} term '{term_name}' is not frozen: ThinPlate identifiability must be FrozenTransform or None"
1300                ));
1301            }
1302            Ok(())
1303        }
1304        _ => Ok(()),
1305    }
1306}
1307
1308impl TermCollectionSpec {
1309    /// Write this collection's topology identity into a warm-start cache
1310    /// fingerprint (#869).
1311    ///
1312    /// The persistent warm-start `cache_key` hashes only family + raw input
1313    /// dimensions, so two fits on the same data that differ *only* in their
1314    /// smooth topology (the `s(..., type=AUTO)` candidate enumeration: sphere
1315    /// vs torus vs euclidean vs duchon) collide on one key and seed each other
1316    /// with geometrically incompatible β/ρ. Folding the per-term structural
1317    /// kind + feature columns + linear/random-effect counts into the shape hash
1318    /// gives each candidate its own key, so the screen→full-refit reuse of one
1319    /// candidate is preserved while cross-candidate contamination is removed.
1320    /// Only the structural identity is hashed (not fitted coefficients or
1321    /// frozen knot values), so a refit of the *same* topology still hits.
1322    pub fn write_structural_shape_hash(&self, h: &mut gam_runtime::warm_start::Fingerprinter) {
1323        h.write_str("term-collection");
1324        h.write_usize(self.linear_terms.len());
1325        for linear in &self.linear_terms {
1326            h.write_str(&linear.name);
1327        }
1328        h.write_usize(self.random_effect_terms.len());
1329        h.write_usize(self.smooth_terms.len());
1330        for smooth in &self.smooth_terms {
1331            h.write_str(&smooth.name);
1332            h.write_str(smooth.basis.structural_kind());
1333            for col in smooth.basis.structural_feature_cols() {
1334                h.write_usize(col);
1335            }
1336        }
1337    }
1338
1339    /// Validate that a term collection spec represents a fully frozen model
1340    /// (i.e. all knots/centers are pre-computed, identifiability transforms are
1341    /// baked in, and random-effect levels are fixed).
1342    pub fn validate_frozen(&self, label: &str) -> Result<(), String> {
1343        for linear in &self.linear_terms {
1344            if let (Some(min), Some(max)) = (linear.coefficient_min, linear.coefficient_max)
1345                && (!min.is_finite() || !max.is_finite() || min > max)
1346            {
1347                return Err(SmoothError::invalid_config(format!(
1348                    "{label} linear term '{}' has invalid coefficient constraint [{min}, {max}]",
1349                    linear.name
1350                ))
1351                .into());
1352            }
1353            if let Some(min) = linear.coefficient_min
1354                && !min.is_finite()
1355            {
1356                return Err(SmoothError::invalid_config(format!(
1357                    "{label} linear term '{}' has non-finite coefficient minimum {min}",
1358                    linear.name
1359                ))
1360                .into());
1361            }
1362            if let Some(max) = linear.coefficient_max
1363                && !max.is_finite()
1364            {
1365                return Err(SmoothError::invalid_config(format!(
1366                    "{label} linear term '{}' has non-finite coefficient maximum {max}",
1367                    linear.name
1368                ))
1369                .into());
1370            }
1371            if let LinearCoefficientGeometry::Bounded { min, max, prior } =
1372                &linear.coefficient_geometry
1373            {
1374                if !min.is_finite() || !max.is_finite() || min >= max {
1375                    return Err(SmoothError::invalid_config(format!(
1376                        "{label} bounded term '{}' has invalid bounds [{min}, {max}]",
1377                        linear.name
1378                    ))
1379                    .into());
1380                }
1381                match prior {
1382                    BoundedCoefficientPriorSpec::None | BoundedCoefficientPriorSpec::Uniform => {}
1383                    BoundedCoefficientPriorSpec::Beta { a, b } => {
1384                        if !a.is_finite() || !b.is_finite() || *a < 1.0 || *b < 1.0 {
1385                            return Err(SmoothError::invalid_config(format!(
1386                                "{label} bounded term '{}' has invalid Beta prior ({a}, {b})",
1387                                linear.name
1388                            ))
1389                            .into());
1390                        }
1391                    }
1392                }
1393            }
1394        }
1395        for st in &self.smooth_terms {
1396            match &st.basis {
1397                SmoothBasisSpec::ByVariable { inner, .. } => {
1398                    validate_smooth_basis_frozen(inner, label, &st.name)?;
1399                    let nested = SmoothTermSpec {
1400                        name: st.name.clone(),
1401                        basis: (**inner).clone(),
1402                        shape: st.shape,
1403                        joint_null_rotation: None,
1404                    };
1405                    TermCollectionSpec {
1406                        linear_terms: Vec::new(),
1407                        random_effect_terms: Vec::new(),
1408                        smooth_terms: vec![nested],
1409                    }
1410                    .validate_frozen(label)?;
1411                }
1412                SmoothBasisSpec::FactorSumToZero { inner, levels, .. } => {
1413                    if levels.len() < 2 {
1414                        return Err(format!(
1415                            "{label} term '{}' has invalid frozen sz levels",
1416                            st.name
1417                        ));
1418                    }
1419                    validate_smooth_basis_frozen(inner, label, &st.name)?;
1420                }
1421                SmoothBasisSpec::BSpline1D { spec, .. } => {
1422                    if !matches!(
1423                        spec.knotspec,
1424                        BSplineKnotSpec::Provided(_)
1425                            | BSplineKnotSpec::PeriodicUniform { .. }
1426                            | BSplineKnotSpec::NaturalCubicRegression { .. }
1427                    ) {
1428                        return Err(SmoothError::invalid_config(format!(
1429                            "{label} term '{}' is not frozen: BSpline knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression",
1430                            st.name
1431                        ))
1432                        .into());
1433                    }
1434                }
1435                SmoothBasisSpec::ThinPlate { spec, .. } => {
1436                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1437                        return Err(SmoothError::invalid_config(format!(
1438                            "{label} term '{}' is not frozen: ThinPlate centers must be UserProvided",
1439                            st.name
1440                        ))
1441                        .into());
1442                    }
1443                    if matches!(
1444                        spec.identifiability,
1445                        SpatialIdentifiability::OrthogonalToParametric
1446                    ) {
1447                        return Err(SmoothError::invalid_config(format!(
1448                            "{label} term '{}' is not frozen: ThinPlate identifiability must be FrozenTransform or None",
1449                            st.name
1450                        ))
1451                        .into());
1452                    }
1453                }
1454                SmoothBasisSpec::Sphere { spec, .. } => {
1455                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1456                        return Err(SmoothError::invalid_config(format!(
1457                            "{label} term '{}' is not frozen: Sphere centers must be UserProvided",
1458                            st.name
1459                        ))
1460                        .into());
1461                    }
1462                    if matches!(spec.method, crate::basis::SphereMethod::Harmonic)
1463                        && spec.max_degree.is_none_or(|d| d == 0)
1464                    {
1465                        return Err(format!(
1466                            "{label} term '{}' is not frozen: sphere max_degree must be positive",
1467                            st.name
1468                        ));
1469                    }
1470                }
1471                SmoothBasisSpec::ConstantCurvature { spec, .. } => {
1472                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1473                        return Err(SmoothError::invalid_config(format!(
1474                            "{label} term '{}' is not frozen: ConstantCurvature centers must be UserProvided",
1475                            st.name
1476                        ))
1477                        .into());
1478                    }
1479                    if !(spec.length_scale.is_finite() && spec.length_scale > 0.0) {
1480                        return Err(SmoothError::invalid_config(format!(
1481                            "{label} term '{}' is not frozen: ConstantCurvature length_scale must be the realized positive value",
1482                            st.name
1483                        ))
1484                        .into());
1485                    }
1486                }
1487                SmoothBasisSpec::MeasureJet { spec, .. } => {
1488                    let centers = match &spec.center_strategy {
1489                        CenterStrategy::UserProvided(centers) => centers,
1490                        _ => {
1491                            return Err(SmoothError::invalid_config(format!(
1492                                "{label} term '{}' is not frozen: MeasureJet centers must be UserProvided",
1493                                st.name
1494                            ))
1495                            .into());
1496                        }
1497                    };
1498                    if centers.nrows() == 0 {
1499                        return Err(SmoothError::invalid_config(format!(
1500                            "{label} term '{}' is not frozen: MeasureJet centers are empty",
1501                            st.name
1502                        ))
1503                        .into());
1504                    }
1505                    if !(spec.length_scale.is_finite() && spec.length_scale > 0.0) {
1506                        return Err(SmoothError::invalid_config(format!(
1507                            "{label} term '{}' is not frozen: MeasureJet length_scale must be the realized positive value",
1508                            st.name
1509                        ))
1510                        .into());
1511                    }
1512                    // Exact replay needs the fit-data penalty quadrature and
1513                    // normalization payload (`BasisMetadata::MeasureJet`).
1514                    let frozen = spec.frozen_quadrature.as_ref().ok_or_else(|| {
1515                        SmoothError::invalid_config(format!(
1516                            "{label} term '{}' is not frozen: MeasureJet frozen_quadrature payload is missing",
1517                            st.name
1518                        ))
1519                    })?;
1520                    if frozen.masses.len() != centers.nrows() {
1521                        return Err(SmoothError::invalid_config(format!(
1522                            "{label} term '{}' frozen MeasureJet has {} masses for {} centers",
1523                            st.name,
1524                            frozen.masses.len(),
1525                            centers.nrows()
1526                        ))
1527                        .into());
1528                    }
1529                    let total_mass = frozen.masses.sum();
1530                    if frozen
1531                        .masses
1532                        .iter()
1533                        .any(|mass| !(mass.is_finite() && *mass >= 0.0))
1534                        || !(total_mass.is_finite() && total_mass > 0.0)
1535                    {
1536                        return Err(SmoothError::invalid_config(format!(
1537                            "{label} term '{}' frozen MeasureJet masses must be finite, nonnegative, and have positive total mass",
1538                            st.name
1539                        ))
1540                        .into());
1541                    }
1542                    let n_levels = frozen.eps_band.len();
1543                    if n_levels == 0
1544                        || frozen
1545                            .eps_band
1546                            .iter()
1547                            .any(|eps| !(eps.is_finite() && *eps > 0.0))
1548                    {
1549                        return Err(SmoothError::invalid_config(format!(
1550                            "{label} term '{}' frozen MeasureJet eps_band must be nonempty, finite, and positive",
1551                            st.name
1552                        ))
1553                        .into());
1554                    }
1555                    for (idx, pair) in frozen.eps_band.windows(2).enumerate() {
1556                        if pair[1] <= pair[0] {
1557                            return Err(SmoothError::invalid_config(format!(
1558                                "{label} term '{}' frozen MeasureJet eps_band is not strictly ascending at {idx}: {} then {}",
1559                                st.name,
1560                                pair[0],
1561                                pair[1]
1562                            ))
1563                            .into());
1564                        }
1565                    }
1566                    validate_measure_jet_positive_vec_len(
1567                        label,
1568                        &st.name,
1569                        "support_means",
1570                        &frozen.support_means,
1571                        n_levels,
1572                    )?;
1573                    // Mode predicate MUST match the builder's
1574                    // (`measure_jet_multiscale_mode`): per-level/multiscale is the
1575                    // explicit `spec.multiscale` opt-in (#1116). In single-scale
1576                    // mode the builder emits a single FUSED penalty (empty
1577                    // per-level scales + `fused_penalty_normalization_scale:
1578                    // Some`); only the multiscale opt-in carries `n_levels`
1579                    // per-level scales.
1580                    let per_level = crate::basis::measure_jet_multiscale_mode(spec);
1581                    if per_level {
1582                        validate_measure_jet_positive_vec_len(
1583                            label,
1584                            &st.name,
1585                            "penalty_normalization_scales",
1586                            &frozen.penalty_normalization_scales,
1587                            n_levels,
1588                        )?;
1589                        validate_measure_jet_positive_vec_len(
1590                            label,
1591                            &st.name,
1592                            "raw_penalty_normalization_scales",
1593                            &frozen.raw_penalty_normalization_scales,
1594                            n_levels,
1595                        )?;
1596                        if frozen.fused_penalty_normalization_scale.is_some() {
1597                            return Err(SmoothError::invalid_config(format!(
1598                                "{label} term '{}' per-level MeasureJet must not carry a fused penalty normalization scale",
1599                                st.name
1600                            ))
1601                            .into());
1602                        }
1603                    } else {
1604                        if !frozen.penalty_normalization_scales.is_empty()
1605                            || !frozen.raw_penalty_normalization_scales.is_empty()
1606                        {
1607                            return Err(SmoothError::invalid_config(format!(
1608                                "{label} term '{}' fused MeasureJet must not carry per-level penalty normalization scales",
1609                                st.name
1610                            ))
1611                            .into());
1612                        }
1613                        match frozen.fused_penalty_normalization_scale {
1614                            Some(scale) if scale.is_finite() && scale > 0.0 => {}
1615                            Some(scale) => {
1616                                return Err(SmoothError::invalid_config(format!(
1617                                    "{label} term '{}' fused MeasureJet penalty normalization scale must be positive and finite, got {scale}",
1618                                    st.name
1619                                ))
1620                                .into());
1621                            }
1622                            None => {
1623                                return Err(SmoothError::invalid_config(format!(
1624                                    "{label} term '{}' fused MeasureJet is missing its penalty normalization scale",
1625                                    st.name
1626                                ))
1627                                .into());
1628                            }
1629                        }
1630                    }
1631                }
1632                SmoothBasisSpec::Matern { spec, .. } => {
1633                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1634                        return Err(SmoothError::invalid_config(format!(
1635                            "{label} term '{}' is not frozen: Matern centers must be UserProvided",
1636                            st.name
1637                        ))
1638                        .into());
1639                    }
1640                }
1641                SmoothBasisSpec::Duchon { spec, .. } => {
1642                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1643                        return Err(SmoothError::invalid_config(format!(
1644                            "{label} term '{}' is not frozen: Duchon centers must be UserProvided",
1645                            st.name
1646                        ))
1647                        .into());
1648                    }
1649                    if matches!(
1650                        spec.identifiability,
1651                        SpatialIdentifiability::OrthogonalToParametric
1652                    ) {
1653                        return Err(SmoothError::invalid_config(format!(
1654                            "{label} term '{}' is not frozen: Duchon identifiability must be FrozenTransform or None",
1655                            st.name
1656                        ))
1657                        .into());
1658                    }
1659                }
1660                SmoothBasisSpec::Pca {
1661                    centered,
1662                    center_mean,
1663                    pca_basis_path,
1664                    ..
1665                } => {
1666                    if *centered && center_mean.is_none() && pca_basis_path.is_none() {
1667                        return Err(SmoothError::invalid_config(format!(
1668                            "{label} term '{}' is not frozen: centered Pca missing center_mean",
1669                            st.name
1670                        ))
1671                        .into());
1672                    }
1673                }
1674                SmoothBasisSpec::BySmooth { smooth, by_kind } => {
1675                    if let SmoothBasisSpec::BySmooth { .. } = smooth.as_ref() {
1676                        return Err(format!("{label} term '{}' has nested by-smooths", st.name));
1677                    }
1678                    match by_kind {
1679                        ByVarKind::Numeric { .. } => {}
1680                        ByVarKind::Factor { frozen_levels, .. } if frozen_levels.is_none() => {
1681                            return Err(format!(
1682                                "{label} term '{}' is not frozen: by-factor levels missing",
1683                                st.name
1684                            ));
1685                        }
1686                        ByVarKind::Factor { .. } => {}
1687                    }
1688                    let nested = TermCollectionSpec {
1689                        linear_terms: vec![],
1690                        random_effect_terms: vec![],
1691                        smooth_terms: vec![SmoothTermSpec {
1692                            name: st.name.clone(),
1693                            basis: (**smooth).clone(),
1694                            shape: st.shape,
1695                            joint_null_rotation: None,
1696                        }],
1697                    };
1698                    nested.validate_frozen(label)?;
1699                }
1700                SmoothBasisSpec::FactorSmooth { spec } => {
1701                    if spec.group_frozen_levels.is_none() {
1702                        return Err(format!(
1703                            "{label} term '{}' is not frozen: factor-smooth levels missing",
1704                            st.name
1705                        ));
1706                    }
1707                    if !matches!(
1708                        spec.marginal.knotspec,
1709                        BSplineKnotSpec::Provided(_)
1710                            | BSplineKnotSpec::PeriodicUniform { .. }
1711                            // mgcv's `bs="sz"` default marginal is a cubic
1712                            // regression spline (#1074), and the freeze step
1713                            // restores it as a `NaturalCubicRegression` knotspec
1714                            // carrying its `k` value-knots (spatial_optimization.rs
1715                            // `marginal_is_cr` branch) — the SAME treatment the
1716                            // tensor margin already gets in the arm below. Without
1717                            // this variant a frozen `sz` factor smooth fails its own
1718                            // predict-time freeze check ("factor-smooth marginal
1719                            // knots missing") even though its knots are fully
1720                            // materialized; the validation simply was not updated
1721                            // when the cr marginal landed.
1722                            | BSplineKnotSpec::NaturalCubicRegression { .. }
1723                    ) {
1724                        return Err(format!(
1725                            "{label} term '{}' is not frozen: factor-smooth marginal knots missing",
1726                            st.name
1727                        ));
1728                    }
1729                }
1730                SmoothBasisSpec::TensorBSpline { spec, .. } => {
1731                    for (dim, marginal) in spec.marginalspecs.iter().enumerate() {
1732                        if !matches!(
1733                            marginal.knotspec,
1734                            BSplineKnotSpec::Provided(_)
1735                                | BSplineKnotSpec::PeriodicUniform { .. }
1736                                | BSplineKnotSpec::NaturalCubicRegression { .. }
1737                        ) {
1738                            return Err(SmoothError::invalid_config(format!(
1739                                "{label} term '{}' dim {} is not frozen: tensor marginal knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression",
1740                                st.name, dim
1741                            ))
1742                            .into());
1743                        }
1744                    }
1745                    if matches!(
1746                        spec.identifiability,
1747                        TensorBSplineIdentifiability::SumToZero
1748                            | TensorBSplineIdentifiability::MarginalSumToZero
1749                    ) {
1750                        return Err(SmoothError::invalid_config(format!(
1751                            "{label} term '{}' is not frozen: tensor identifiability must be FrozenTransform or None",
1752                            st.name
1753                        ))
1754                        .into());
1755                    }
1756                }
1757            }
1758        }
1759
1760        for rt in &self.random_effect_terms {
1761            if rt.frozen_levels.is_none() {
1762                return Err(SmoothError::invalid_config(format!(
1763                    "{label} random-effect term '{}' is not frozen: missing frozen_levels",
1764                    rt.name
1765                ))
1766                .into());
1767            }
1768        }
1769
1770        Ok(())
1771    }
1772
1773    /// Re-resolve every stored feature-column index through `remap`, returning a
1774    /// spec that addresses a different column layout.
1775    ///
1776    /// A frozen `TermCollectionSpec` stores feature columns as *absolute indices
1777    /// into the training table*. To replay it on a fresh dataset whose columns
1778    /// sit at different positions — the common case at prediction time, where
1779    /// the response column is unknown and may be absent entirely — every index
1780    /// must be re-resolved against the new layout. `remap` receives each
1781    /// training-table index and returns its position in the runtime table;
1782    /// callers typically implement it as "look the name up in the training
1783    /// headers, then resolve that name against the prediction dataset".
1784    ///
1785    /// This is the single authority on *which* fields carry a column index
1786    /// across every basis variant (linear, random-effect, the `by=` column of
1787    /// `ByVariable`/`FactorSumToZero`/`BySmooth`, the continuous and group
1788    /// columns of a `FactorSmooth`, and the multi-axis `feature_cols` of every
1789    /// spatial/tensor basis), so a predict-time realignment cannot silently miss
1790    /// one and dereference a stale training index.
1791    pub fn remap_feature_columns<E, F>(&self, mut remap: F) -> Result<TermCollectionSpec, E>
1792    where
1793        F: FnMut(usize) -> Result<usize, E>,
1794    {
1795        let mut out = self.clone();
1796        for lt in &mut out.linear_terms {
1797            lt.feature_col = remap(lt.feature_col)?;
1798            // Also remap the full interaction-factor list. The design builder
1799            // (`build_term_collection_design_inner`) materializes the column from
1800            // `effective_feature_cols()` — which returns `feature_cols` whenever
1801            // it is non-empty (i.e. essentially always, including a plain linear
1802            // term where `feature_cols == [feature_col]`). Remapping only the
1803            // singular `feature_col` left these at their saved *training* indices
1804            // at predict time, so a parametric `Surv(...) ~ x` (and any `:`
1805            // interaction) bailed with "feature column N out of bounds" once the
1806            // response/time columns shift the runtime layout (issue #898).
1807            for fc in lt.feature_cols.iter_mut() {
1808                *fc = remap(*fc)?;
1809            }
1810            // A factor-aware `:` interaction also gates on categorical columns;
1811            // those indices live in the same training-time layout and must be
1812            // realigned to the runtime table alongside the numeric operands, or
1813            // the predict-time level indicator would dereference a stale column.
1814            for (col, _bits) in lt.categorical_levels.iter_mut() {
1815                *col = remap(*col)?;
1816            }
1817        }
1818        for rt in &mut out.random_effect_terms {
1819            rt.feature_col = remap(rt.feature_col)?;
1820        }
1821        for st in &mut out.smooth_terms {
1822            remap_smooth_basis_feature_columns(&mut st.basis, &mut remap)?;
1823        }
1824        Ok(out)
1825    }
1826}
1827
1828/// Walk a `SmoothBasisSpec` tree, re-resolving every column index through
1829/// `remap`. Shared by all predict-time column realignment (see
1830/// [`TermCollectionSpec::remap_feature_columns`]); kept exhaustive so a newly
1831/// added index-bearing variant fails to compile until it is handled here.
1832pub fn remap_smooth_basis_feature_columns<E, F>(
1833    basis: &mut SmoothBasisSpec,
1834    remap: &mut F,
1835) -> Result<(), E>
1836where
1837    F: FnMut(usize) -> Result<usize, E>,
1838{
1839    match basis {
1840        SmoothBasisSpec::ByVariable { inner, by_col, .. }
1841        | SmoothBasisSpec::FactorSumToZero { inner, by_col, .. } => {
1842            *by_col = remap(*by_col)?;
1843            remap_smooth_basis_feature_columns(inner, remap)?;
1844        }
1845        SmoothBasisSpec::BSpline1D { feature_col, .. } => {
1846            *feature_col = remap(*feature_col)?;
1847        }
1848        SmoothBasisSpec::BySmooth { smooth, by_kind } => {
1849            let by_feature_col = match by_kind {
1850                ByVarKind::Numeric { feature_col } | ByVarKind::Factor { feature_col, .. } => {
1851                    feature_col
1852                }
1853            };
1854            *by_feature_col = remap(*by_feature_col)?;
1855            remap_smooth_basis_feature_columns(smooth, remap)?;
1856        }
1857        SmoothBasisSpec::FactorSmooth { spec } => {
1858            for fc in spec.continuous_cols.iter_mut() {
1859                *fc = remap(*fc)?;
1860            }
1861            spec.group_col = remap(spec.group_col)?;
1862        }
1863        SmoothBasisSpec::ThinPlate { feature_cols, .. }
1864        | SmoothBasisSpec::Sphere { feature_cols, .. }
1865        | SmoothBasisSpec::ConstantCurvature { feature_cols, .. }
1866        | SmoothBasisSpec::Matern { feature_cols, .. }
1867        | SmoothBasisSpec::MeasureJet { feature_cols, .. }
1868        | SmoothBasisSpec::Duchon { feature_cols, .. }
1869        | SmoothBasisSpec::Pca { feature_cols, .. }
1870        | SmoothBasisSpec::TensorBSpline { feature_cols, .. } => {
1871            for fc in feature_cols.iter_mut() {
1872                *fc = remap(*fc)?;
1873            }
1874        }
1875    }
1876    Ok(())
1877}
1878
1879#[derive(Debug, Clone)]
1880pub enum PenaltyStructureHint {
1881    Ridge(f64),
1882    Kronecker(Vec<Array2<f64>>),
1883}
1884
1885/// A penalty matrix stored at its natural block size together with the
1886/// column range it occupies in the global coefficient vector.
1887///
1888/// Instead of embedding every penalty into a full `p_total × p_total` dense
1889/// matrix filled with zeros, we keep the compact local matrix and reconstruct
1890/// the global view only when a downstream consumer explicitly requires it.
1891#[derive(Clone)]
1892pub struct BlockwisePenalty {
1893    /// Column range in the global coefficient vector that this penalty covers.
1894    pub col_range: Range<usize>,
1895    /// The local penalty matrix — dimensions `block_p × block_p` where
1896    /// `block_p = col_range.len()`.
1897    pub local: Array2<f64>,
1898    /// Optional nonzero centering vector for this coefficient block.
1899    pub prior_mean: gam_problem::CoefficientPriorMean,
1900    /// Optional structural hint so downstream spectral/logdet code can stay
1901    /// block-local or factorized without reverse-engineering the matrix.
1902    pub structure_hint: Option<PenaltyStructureHint>,
1903    /// Optional operator-form handle bit-equivalent to `local`. Populated when
1904    /// the originating closed-form factory emitted an op-form penalty so exact
1905    /// operator algebra can use matvec instead of materializing the dense
1906    /// `block_p × block_p` Gram. `None` for ordinary dense penalties.
1907    pub op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
1908}
1909
1910impl std::fmt::Debug for BlockwisePenalty {
1911    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1912        f.debug_struct("BlockwisePenalty")
1913            .field("col_range", &self.col_range)
1914            .field(
1915                "local",
1916                &format_args!("{}×{}", self.local.nrows(), self.local.ncols()),
1917            )
1918            .field("prior_mean", &self.prior_mean)
1919            .field("structure_hint", &self.structure_hint)
1920            .field("op", &self.op.as_ref().map(|o| o.dim()))
1921            .finish()
1922    }
1923}
1924
1925impl BlockwisePenalty {
1926    /// Create a new blockwise penalty.
1927    pub fn new(col_range: Range<usize>, local: Array2<f64>) -> Self {
1928        assert_eq!(col_range.len(), local.nrows());
1929        assert_eq!(col_range.len(), local.ncols());
1930        Self {
1931            col_range,
1932            local,
1933            prior_mean: gam_problem::CoefficientPriorMean::Zero,
1934            structure_hint: None,
1935            op: None,
1936        }
1937    }
1938
1939    pub fn with_prior_mean(
1940        mut self,
1941        prior_mean: gam_problem::CoefficientPriorMean,
1942    ) -> Self {
1943        self.prior_mean = prior_mean;
1944        self
1945    }
1946
1947    /// Attach an op-form penalty handle bit-equivalent to `local`.
1948    pub fn with_op(
1949        mut self,
1950        op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
1951    ) -> Self {
1952        self.op = op;
1953        self
1954    }
1955
1956    pub fn ridge(col_range: Range<usize>, scale: f64) -> Self {
1957        let block_size = col_range.len();
1958        let mut local = Array2::<f64>::zeros((block_size, block_size));
1959        for i in 0..block_size {
1960            local[[i, i]] = scale;
1961        }
1962        Self {
1963            col_range,
1964            local,
1965            prior_mean: gam_problem::CoefficientPriorMean::Zero,
1966            structure_hint: Some(PenaltyStructureHint::Ridge(scale)),
1967            op: None,
1968        }
1969    }
1970
1971    pub fn kronecker(
1972        col_range: Range<usize>,
1973        local: Array2<f64>,
1974        factors: Vec<Array2<f64>>,
1975    ) -> Self {
1976        assert_eq!(col_range.len(), local.nrows());
1977        assert_eq!(col_range.len(), local.ncols());
1978        Self {
1979            col_range,
1980            local,
1981            prior_mean: gam_problem::CoefficientPriorMean::Zero,
1982            structure_hint: Some(PenaltyStructureHint::Kronecker(factors)),
1983            op: None,
1984        }
1985    }
1986
1987    /// Expand this blockwise penalty into a full `p_total × p_total` dense
1988    /// matrix (mostly zeros). Use sparingly — the whole point of blockwise
1989    /// storage is to avoid this allocation.
1990    pub fn to_global(&self, p_total: usize) -> Array2<f64> {
1991        let mut g = Array2::<f64>::zeros((p_total, p_total));
1992        let r = &self.col_range;
1993        assert!(
1994            r.end <= p_total && self.local.nrows() == r.len() && self.local.ncols() == r.len(),
1995            "BlockwisePenalty::to_global shape invariant violated: \
1996             col_range={}..{}, local={}x{}, p_total={}",
1997            r.start,
1998            r.end,
1999            self.local.nrows(),
2000            self.local.ncols(),
2001            p_total,
2002        );
2003        g.slice_mut(s![r.start..r.end, r.start..r.end])
2004            .assign(&self.local);
2005        g
2006    }
2007
2008    /// Convert into a blockwise [`gam_problem::PenaltyMatrix`] without
2009    /// expanding to full dimensions.
2010    pub fn to_penalty_matrix(
2011        &self,
2012        total_dim: usize,
2013    ) -> gam_problem::PenaltyMatrix {
2014        gam_problem::PenaltyMatrix::Blockwise {
2015            local: self.local.clone(),
2016            col_range: self.col_range.clone(),
2017            total_dim,
2018        }
2019    }
2020
2021    /// The block size of this penalty.
2022    #[inline]
2023    pub fn block_size(&self) -> usize {
2024        self.col_range.len()
2025    }
2026}
2027
2028/// Compute `Σ_k λ_k S_k` directly from blockwise penalties, accumulating
2029/// into a pre-allocated `p_total × p_total` output without ever materializing
2030/// individual global matrices.
2031pub fn weighted_blockwise_penalty_sum(
2032    penalties: &[BlockwisePenalty],
2033    lambdas: &[f64],
2034    p_total: usize,
2035) -> Array2<f64> {
2036    assert_eq!(penalties.len(), lambdas.len());
2037    // Smoothing parameters λ_k must be non-negative and finite. A negative
2038    // λ would flip the sign of the corresponding block S_k, turning the
2039    // total penalty matrix indefinite and silently corrupting every
2040    // downstream Cholesky / PIRLS / REML / pseudo-logdet computation that
2041    // assumes S_λ ⪰ 0. Catch this at the boundary rather than after it
2042    // has propagated.
2043    for (idx, &lam) in lambdas.iter().enumerate() {
2044        assert!(
2045            lam.is_finite() && lam >= 0.0,
2046            "weighted_blockwise_penalty_sum: lambdas[{idx}] = {lam} is invalid (must be finite and non-negative; negative smoothing parameters violate S_λ ⪰ 0)",
2047        );
2048    }
2049    // Block column ranges must also fit inside the declared total parameter
2050    // dimension; an out-of-bounds slice would otherwise panic from ndarray
2051    // with a far less informative message.
2052    for (idx, bp) in penalties.iter().enumerate() {
2053        let r = &bp.col_range;
2054        assert!(
2055            r.end <= p_total,
2056            "weighted_blockwise_penalty_sum: penalties[{idx}] col_range {:?} exceeds p_total = {p_total}",
2057            r,
2058        );
2059    }
2060    let mut out = Array2::<f64>::zeros((p_total, p_total));
2061    for (bp, &lam) in penalties.iter().zip(lambdas.iter()) {
2062        let r = &bp.col_range;
2063        let mut slice = out.slice_mut(s![r.start..r.end, r.start..r.end]);
2064        slice.scaled_add(lam, &bp.local);
2065    }
2066    out
2067}
2068
2069// ---------------------------------------------------------------------------
2070// KroneckerPenaltySystem — factored tensor-product penalty representation
2071// ---------------------------------------------------------------------------
2072
2073/// Factored representation of tensor-product penalties with precomputed
2074/// marginal eigensystems for O(∏q_j) logdet and penalty operations.
///
/// The three vectors are parallel: entry `k` of each describes margin `k`.
/// `KroneckerPenaltySystem::new` only verifies that `marginal_penalties` and
/// `marginal_dims` have the same length; because every field is public,
/// keeping them mutually consistent after construction is up to the caller.
2075#[derive(Debug, Clone)]
2076pub struct KroneckerPenaltySystem {
2077    /// Marginal penalty matrices: `marginal_penalties[k]` is `(q_k, q_k)`.
2078    pub marginal_penalties: Vec<Array2<f64>>,
2079    /// Precomputed eigensystems: `(eigenvalues, eigenvectors)` per marginal.
    /// The log-determinant routines read only the eigenvalue half.
2080    pub marginal_eigensystems: Vec<(Array1<f64>, Array2<f64>)>,
2081    /// Marginal basis dimensions.
    /// Their product is the total coefficient count reported by `p_total`.
2082    pub marginal_dims: Vec<usize>,
2083    /// Whether a global ridge (double) penalty is present.
    /// When set, `num_penalties` counts one extra smoothing parameter after
    /// the per-margin ones.
2084    pub has_double_penalty: bool,
2085}
2086
2087impl KroneckerPenaltySystem {
2088    pub fn new(
2089        marginal_penalties: Vec<Array2<f64>>,
2090        marginal_dims: Vec<usize>,
2091        has_double_penalty: bool,
2092    ) -> Result<Self, BasisError> {
2093        if marginal_penalties.len() != marginal_dims.len() {
2094            crate::bail_dim_basis!(
2095                "KroneckerPenaltySystem: {} penalties vs {} dims",
2096                marginal_penalties.len(),
2097                marginal_dims.len()
2098            );
2099        }
2100        let eigensystems =
2101            kronecker_marginal_eigensystems(&marginal_penalties, "KroneckerPenaltySystem")
2102                .map_err(|e| BasisError::InvalidInput(e.to_string()))?;
2103        Ok(Self {
2104            marginal_penalties,
2105            marginal_eigensystems: eigensystems,
2106            marginal_dims,
2107            has_double_penalty,
2108        })
2109    }
2110
2111    pub fn p_total(&self) -> usize {
2112        self.marginal_dims.iter().copied().product()
2113    }
2114
2115    pub fn ndim(&self) -> usize {
2116        self.marginal_dims.len()
2117    }
2118
2119    pub fn num_penalties(&self) -> usize {
2120        self.marginal_dims.len() + if self.has_double_penalty { 1 } else { 0 }
2121    }
2122
2123    /// Compute `log|S|₊` and its first/second derivatives w.r.t. `ρ_k = log(λ_k)`.
2124    ///
2125    /// Iterates over the ∏q_j multi-index grid. Cost: O(d · ∏q_j), no O(p²) storage.
2126    pub fn logdet_and_derivatives(
2127        &self,
2128        lambdas: &[f64],
2129        ridge: f64,
2130    ) -> (f64, Array1<f64>, Array2<f64>) {
2131        let n_pen = self.num_penalties();
2132        assert_eq!(lambdas.len(), n_pen, "lambda count mismatch");
2133        let marginal_evals: Vec<_> = self
2134            .marginal_eigensystems
2135            .iter()
2136            .map(|(evals, _)| evals.view())
2137            .collect();
2138        kronecker_logdet_and_derivatives(
2139            &marginal_evals,
2140            &self.marginal_dims,
2141            lambdas,
2142            self.has_double_penalty,
2143            ridge,
2144        )
2145    }
2146
2147    pub fn logdet_rank_and_derivatives(
2148        &self,
2149        lambdas: &[f64],
2150        ridge: f64,
2151    ) -> (f64, usize, Array1<f64>, Array2<f64>) {
2152        let n_pen = self.num_penalties();
2153        assert_eq!(lambdas.len(), n_pen, "lambda count mismatch");
2154        let d = self.marginal_dims.len();
2155        let mut logdet = 0.0;
2156        let mut rank = 0usize;
2157        let mut grad = Array1::<f64>::zeros(n_pen);
2158        let mut hess = Array2::<f64>::zeros((n_pen, n_pen));
2159        // Positivity floor for a penalized eigenvalue `σ`: below this the mode
2160        // is treated as an unpenalized (null-space) direction and excluded from
2161        // both the rank count and the pseudo-log-determinant.
2162        const EIGENVALUE_POSITIVITY_FLOOR: f64 = 1e-12;
2163        // Floor on the *structural* eigenvalue sum (λ-independent) used to decide
2164        // whether a mode lives in the penalty range space and so should receive
2165        // the stabilizing ridge; a structurally-null mode gets no ridge.
2166        const STRUCTURAL_ZERO_FLOOR: f64 = 1e-12;
2167        let mut multi_idx = vec![0usize; d];
2168        loop {
2169            let mut sigma = 0.0;
2170            let mut structural_sigma = 0.0;
2171            for k in 0..d {
2172                let marginal_eigenvalue = self.marginal_eigensystems[k].0[multi_idx[k]];
2173                structural_sigma += marginal_eigenvalue;
2174                sigma += lambdas[k] * marginal_eigenvalue;
2175            }
2176            let joint_null = structural_sigma <= STRUCTURAL_ZERO_FLOOR;
2177            if self.has_double_penalty && joint_null {
2178                sigma += lambdas[d];
2179            }
2180            if structural_sigma > STRUCTURAL_ZERO_FLOOR {
2181                sigma += ridge;
2182            }
2183
2184            if sigma > EIGENVALUE_POSITIVITY_FLOOR {
2185                rank += 1;
2186                logdet += sigma.ln();
2187                let inv_sigma = 1.0 / sigma;
2188                let inv_sigma2 = inv_sigma * inv_sigma;
2189                for k in 0..n_pen {
2190                    let ck = if k < d {
2191                        lambdas[k] * self.marginal_eigensystems[k].0[multi_idx[k]]
2192                    } else if joint_null {
2193                        lambdas[d]
2194                    } else {
2195                        0.0
2196                    };
2197                    grad[k] += ck * inv_sigma;
2198                    hess[[k, k]] += ck * inv_sigma - ck * ck * inv_sigma2;
2199                    for l in (k + 1)..n_pen {
2200                        let cl = if l < d {
2201                            lambdas[l] * self.marginal_eigensystems[l].0[multi_idx[l]]
2202                        } else if joint_null {
2203                            lambdas[d]
2204                        } else {
2205                            0.0
2206                        };
2207                        let off = -ck * cl * inv_sigma2;
2208                        hess[[k, l]] += off;
2209                        hess[[l, k]] += off;
2210                    }
2211                }
2212            }
2213
2214            let mut carry = true;
2215            for dim in (0..d).rev() {
2216                if carry {
2217                    multi_idx[dim] += 1;
2218                    if multi_idx[dim] < self.marginal_dims[dim] {
2219                        carry = false;
2220                    } else {
2221                        multi_idx[dim] = 0;
2222                    }
2223                }
2224            }
2225            if carry {
2226                break;
2227            }
2228        }
2229        (logdet, rank, grad, hess)
2230    }
2231}
2232
#[cfg(test)]
mod joint_unpenalized_dim_tests {
    use super::joint_unpenalized_dim;
    use ndarray::{Array2, array};

    /// Diagonal penalty: `diagonal` on the main diagonal, zero elsewhere.
    fn diag_penalty(diagonal: &[f64]) -> Array2<f64> {
        Array2::from_diag(&ndarray::arr1(diagonal))
    }

    #[test]
    fn no_penalty_is_fully_unpenalized() {
        // With no penalties at all, every one of the 4 coordinates is free.
        let unpenalized = joint_unpenalized_dim(4, &[], &[]);
        assert_eq!(unpenalized, 4);
    }

    #[test]
    fn single_penalty_returns_its_own_null_space() {
        // Only the last of three coordinates is penalized, so the null space
        // is spanned by the first two coordinates.
        let penalty = diag_penalty(&[0.0, 0.0, 5.0]);
        let unpenalized = joint_unpenalized_dim(3, std::slice::from_ref(&penalty), &[2]);
        assert_eq!(unpenalized, 2);
    }

    #[test]
    fn complementary_double_penalty_has_empty_joint_null_space() {
        // The #1360 case in miniature. The "bending" penalty leaves only the
        // first coordinate unpenalized (a 1-dim null space here), and the
        // complementary "null-space ridge" penalizes exactly that coordinate.
        // The per-penalty null dims {1, 2} sum to 3 (≈ p), yet their
        // INTERSECTION is empty: every coordinate is penalized by at least one
        // penalty, so the joint unpenalized dim is 0.
        let bending = diag_penalty(&[0.0, 4.0, 4.0]);
        let ridge = diag_penalty(&[2.0, 0.0, 0.0]);
        let unpenalized = joint_unpenalized_dim(3, &[bending, ridge], &[1, 2]);
        assert_eq!(unpenalized, 0);
    }

    #[test]
    fn partial_overlap_keeps_shared_null_direction() {
        // Both penalties leave coordinate 0 alone, so that direction survives
        // the intersection (joint unpenalized dim 1) even though adding up the
        // per-penalty dims naively would give 4.
        let first = diag_penalty(&[0.0, 3.0, 0.0]);
        let second = diag_penalty(&[0.0, 0.0, 3.0]);
        let unpenalized = joint_unpenalized_dim(3, &[first, second], &[2, 2]);
        assert_eq!(unpenalized, 1);
    }

    #[test]
    fn non_materialized_penalty_falls_back_conservatively() {
        // One stored block is not p_local × p_local (e.g. a Kronecker tensor
        // factor). With two or more penalties the conservative answer is a
        // joint dim of 0, which never over-rejects.
        let full: Array2<f64> = array![[0.0, 0.0], [0.0, 1.0]];
        let factor: Array2<f64> = array![[1.0]]; // wrong shape for p_local=2
        let mixed = [full, factor.clone()];
        assert_eq!(joint_unpenalized_dim(2, &mixed, &[1, 0]), 0);
        // A lone non-materialized penalty falls back to its own null dim.
        let lone = std::slice::from_ref(&factor);
        assert_eq!(joint_unpenalized_dim(4, lone, &[2]), 2);
    }
}
2289
#[cfg(test)]
mod kronecker_penalty_system_tests {
    use super::KroneckerPenaltySystem;
    use ndarray::array;

    #[test]
    fn double_penalty_rank_derivatives_use_only_joint_null_space() {
        // Two 2×2 margins with eigenvalues {0, 2} and {0, 3}, plus a double
        // penalty. With λ = (5, 7, 11) and no ridge the four modes are:
        //   (0, 0) -> 11          (jointly null: double penalty only)
        //   (0, 3) -> 7·3 = 21
        //   (2, 0) -> 5·2 = 10
        //   (2, 3) -> 10 + 21 = 31
        let first_margin = array![[0.0, 0.0], [0.0, 2.0]];
        let second_margin = array![[0.0, 0.0], [0.0, 3.0]];
        let system = KroneckerPenaltySystem::new(
            vec![first_margin, second_margin],
            vec![2usize, 2usize],
            true,
        )
        .unwrap();
        let lambdas = [5.0, 7.0, 11.0];

        let (logdet, rank, grad, hess) = system.logdet_rank_and_derivatives(&lambdas, 0.0);

        let expected_logdet: f64 = [11.0_f64, 21.0, 10.0, 31.0]
            .iter()
            .map(|sigma| sigma.ln())
            .sum();
        assert_eq!(rank, 4);
        assert!((logdet - expected_logdet).abs() <= 1e-12);
        // Only the jointly null mode carries the double penalty, where it is
        // the whole of σ: gradient 11/11 = 1, curvature 1 − 1 = 0.
        let double_penalty_grad = grad[2];
        assert!(
            (double_penalty_grad - 1.0).abs() <= 1e-12,
            "double-penalty rank derivative must count only the joint null mode, got {}",
            double_penalty_grad
        );
        assert!(hess[[2, 2]].abs() <= 1e-12);
    }
}
2318
/// Assembled design for a collection of model terms: the design matrix, its
/// blockwise penalties, optional coefficient constraints, and the column
/// ranges that locate each kind of term.
2319#[derive(Clone, Debug)]
2320pub struct TermCollectionDesign {
2321    /// The full design matrix.
2322    ///
2323    /// Prefer a true sparse matrix when every block is sparse-compatible.
2324    /// If the collection already contains intrinsically sparse blocks, preserve
2325    /// that storage and let PIRLS decide later whether the penalized system is
2326    /// sparse-native eligible. Purely dense materialized blocks still fall back
2327    /// to the lazy block operator when sparse storage would just re-encode a
2328    /// dense matrix.
2329    pub design: DesignMatrix,
    /// Blockwise penalties. `penalties_as_penalty_matrix` converts each one
    /// using `design.ncols()` as the total dimension.
2330    pub penalties: Vec<BlockwisePenalty>,
    /// NOTE(review): presumably one null-space dimension per entry of
    /// `penalties`; confirm against the assembly code.
2331    pub nullspace_dims: Vec<usize>,
    /// NOTE(review): presumably metadata parallel to `penalties`; confirm.
2332    pub penaltyinfo: Vec<PenaltyBlockInfo>,
    /// NOTE(review): presumably metadata for penalty blocks that were dropped
    /// during assembly; confirm.
2333    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
2334    /// Optional global coefficient lower bounds for constrained fitting.
2335    /// Length equals `design.ncols()` when present. Unconstrained entries are `-inf`.
2336    pub coefficient_lower_bounds: Option<Array1<f64>>,
2337    /// Optional global inequality constraints:
2338    /// `A * beta >= b`.
2339    pub linear_constraints: Option<LinearInequalityConstraints>,
    /// NOTE(review): presumably the design columns holding the intercept.
2340    pub intercept_range: Range<usize>,
    /// NOTE(review): presumably `(term name, design columns)` per linear term.
2341    pub linear_ranges: Vec<(String, Range<usize>)>,
    /// NOTE(review): presumably `(term name, design columns)` per
    /// random-effect term.
2342    pub random_effect_ranges: Vec<(String, Range<usize>)>,
    /// NOTE(review): presumably `(term name, level identifiers)` per
    /// random-effect term; confirm the meaning of the `u64` values.
2343    pub random_effect_levels: Vec<(String, Vec<u64>)>,
    /// Smooth-term design. `kronecker_penalty_system` reads `smooth.terms`.
2344    pub smooth: SmoothDesign,
2345}
2346
2347impl TermCollectionDesign {
2348    /// Convert blockwise penalties to `PenaltyMatrix::Blockwise` without
2349    /// expanding to `p_total × p_total`. This is the preferred path for
2350    /// family modules that accept `Vec<PenaltyMatrix>`.
2351    pub fn penalties_as_penalty_matrix(&self) -> Vec<gam_problem::PenaltyMatrix> {
2352        let p = self.design.ncols();
2353        self.penalties
2354            .iter()
2355            .map(|bp| bp.to_penalty_matrix(p))
2356            .collect()
2357    }
2358
2359    /// Number of penalty blocks.
2360    #[inline]
2361    pub fn num_penalties(&self) -> usize {
2362        self.penalties.len()
2363    }
2364
2365    /// Resolve coefficient groups against this design's global coefficient
2366    /// layout and append their penalties after the existing term penalties.
2367    pub fn realize_coefficient_groups(
2368        &self,
2369        groups: &[CoefficientGroupSpec],
2370        base_prior: &gam_spec::RhoPrior,
2371    ) -> Result<RealizedCoefficientGroups, BasisError> {
2372        realize_coefficient_groups(self, groups, base_prior)
2373    }
2374
2375    /// Extract a `KroneckerPenaltySystem` when the model's *only* smooth term is
2376    /// a single Kronecker-factored tensor.
2377    ///
2378    /// This is a deliberate single-tensor fast path, not a partial feature: any
2379    /// other shape — zero Kronecker terms, several of them, or a tensor mixed
2380    /// with non-tensor smooth terms — is served correctly by the standard
2381    /// block-separable assembly, so this returns `None` and the caller falls
2382    /// back to it. The two former conditions (`len != 1` and "a non-Kronecker
2383    /// smooth term exists") are jointly equivalent to "the sole smooth term is
2384    /// Kronecker", which the slice pattern below expresses directly in one pass.
2385    pub fn kronecker_penalty_system(&self) -> Option<KroneckerPenaltySystem> {
2386        let [only_term] = self.smooth.terms.as_slice() else {
2387            return None;
2388        };
2389        let kron = only_term.kronecker_factored.as_ref()?;
2390        // A genuine tensor product needs at least two margins, and the marginal
2391        // design / penalty / dim collections must agree in length. A degenerate
2392        // (single-margin) or internally inconsistent factored basis cannot feed
2393        // the Kronecker fast path, so fall back to the standard assembly rather
2394        // than construct a malformed `KroneckerPenaltySystem` from it.
2395        if kron.marginal_dims.len() < 2
2396            || kron.marginal_penalties.len() != kron.marginal_dims.len()
2397            || kron.marginal_designs.len() != kron.marginal_dims.len()
2398        {
2399            return None;
2400        }
2401        KroneckerPenaltySystem::new(
2402            kron.marginal_penalties.clone(),
2403            kron.marginal_dims.clone(),
2404            kron.has_double_penalty,
2405        )
2406        .ok()
2407    }
2408}
2409
2410// `FittedTermCollection`, `SpatialLengthScaleOptimizationTiming`, and
2411// `FittedTermCollectionWithSpec` were relocated with the GAM fit-orchestration
2412// drivers to `gam-models` (`crate::fit_orchestration::drivers`) — they hold a
2413// `gam_solve::UnifiedFitResult` and are consumed only by those drivers (#1521).
2414
/// Configuration bundle tying latent-coordinate values to one smooth term.
///
/// NOTE(review): none of the code that consumes this struct is visible here,
/// so the field notes below are inferred from names and types only. Confirm
/// them against the consumer before relying on them.
2415#[derive(Clone)]
2416pub struct StandardLatentCoordConfig {
    /// Shared (`Arc`) handle to the latent coordinate values.
2417    pub values: std::sync::Arc<crate::latent::LatentCoordValues>,
    /// Index of the smooth term this configuration refers to.
2418    pub term_index: gam_problem::types::SmoothTermIdx,
    /// Presumably the data columns feeding the term. TODO confirm.
2419    pub feature_cols: Vec<usize>,
    /// Manifold associated with the latent coordinates.
2420    pub manifold: crate::latent::LatentManifold,
    /// Presumably true when `manifold` was chosen automatically. TODO confirm.
2421    pub manifold_auto: bool,
    /// Registry of latent retractions.
2422    pub retraction_registry: gam_problem::LatentRetractionRegistry,
    /// Optional shared registry of analytic penalties.
2423    pub analytic_penalties: Option<std::sync::Arc<crate::AnalyticPenaltyRegistry>>,
2424}
2425
/// Serializable per-term record of adaptive spatial weights.
///
/// NOTE(review): the producer is not visible here; the field notes are
/// inferred from names and types and should be confirmed.
2426#[derive(Clone, Debug, Serialize, Deserialize)]
2427pub struct AdaptiveSpatialMap {
    /// Name of the term this map belongs to.
2428    pub termname: String,
    /// Presumably the data columns feeding the term. TODO confirm.
2429    pub feature_cols: Vec<usize>,
    /// Collocation points, presumably one per row. TODO confirm layout.
2430    pub collocation_points: Array2<f64>,
    /// Presumably inverse weights for the magnitude penalty. TODO confirm.
2431    pub inv_magweight: Array1<f64>,
    /// Presumably inverse weights for the gradient penalty. TODO confirm.
2432    pub invgradweight: Array1<f64>,
    /// Presumably inverse weights for the Laplacian penalty. TODO confirm.
2433    pub inv_lapweight: Array1<f64>,
2434}
2435
/// Serializable diagnostics of an adaptive-regularization run, including the
/// per-term maps in `maps`.
///
/// NOTE(review): the producer is not visible here; the field notes are
/// inferred from names and types and should be confirmed.
2436#[derive(Clone, Debug, Serialize, Deserialize)]
2437pub struct AdaptiveRegularizationDiagnostics {
    /// Epsilon value labelled `0`. Its role is not visible here; possibly
    /// paired with the magnitude weights. TODO confirm.
2438    pub epsilon_0: f64,
    /// Epsilon value labelled `g`; possibly paired with the gradient weights.
    /// TODO confirm.
2439    pub epsilon_g: f64,
    /// Epsilon value labelled `c`. TODO confirm meaning.
2440    pub epsilon_c: f64,
    /// Presumably the number of outer iterations spent on the epsilons.
2441    pub epsilon_outer_iterations: usize,
    /// Presumably the number of MM iterations performed. TODO confirm.
2442    pub mm_iterations: usize,
    /// Convergence flag reported by the run.
2443    pub converged: bool,
    /// Adaptive spatial maps, presumably one per adapted term.
2444    pub maps: Vec<AdaptiveSpatialMap>,
2445}
2446
/// Conditioning record for a single linear design column. All fields are
/// private to this module.
///
/// NOTE(review): the code applying this record is not visible here;
/// presumably the column is centred by `mean` and divided by `scale`. Confirm.
2447#[derive(Debug, Clone)]
2448pub struct LinearColumnConditioning {
    /// Index of the conditioned column.
2449    col_idx: usize,
    /// Presumably the centring offset for the column. TODO confirm.
2450    mean: f64,
    /// Presumably the scaling factor for the column. TODO confirm.
2451    scale: f64,
2452}
2453
/// Conditioning applied to the linear part of a fit: the intercept position
/// plus one record per conditioned column. The derived `Default` gives
/// `intercept_idx == 0` and no columns.
2454#[derive(Debug, Clone, Default)]
2455pub struct LinearFitConditioning {
    /// Index of the intercept column.
    /// NOTE(review): presumably a design-matrix column index; confirm.
2456    pub intercept_idx: usize,
    /// Per-column conditioning records.
2457    pub columns: Vec<LinearColumnConditioning>,
2458}
2459
/// Derivative information of one spatial term with respect to its ψ
/// hyper-coordinate: local design and penalty derivatives plus the indices
/// and ranges that place them in the global problem.
///
/// NOTE(review): notes on fields that had no documentation are inferred from
/// names and types; confirm them against the builder of this struct.
2460#[derive(Clone)]
2461pub struct SpatialPsiDerivative {
2462    // These are derivatives with respect to psi = log(kappa), not log(length_scale).
    /// Presumably the primary penalty index for this term. TODO confirm how it
    /// relates to `penalty_indices`.
2463    pub penalty_index: usize,
    /// Presumably the indices of all penalties that depend on ψ. TODO confirm.
2464    pub penalty_indices: Vec<usize>,
    /// Presumably the term's coefficient range in the global vector.
2465    pub global_range: Range<usize>,
    /// Presumably the total number of coefficients. TODO confirm.
2466    pub total_p: usize,
    /// Local first ψ-derivative of the design. May be zero-sized when
    /// `implicit_operator` is present.
2467    pub x_psi_local: Array2<f64>,
    /// Presumably the local first ψ-derivatives of the penalties, one matrix
    /// per penalty component. TODO confirm ordering.
2468    pub s_psi_components_local: Vec<Array2<f64>>,
    /// Local second ψ-derivative of the design. May be zero-sized when
    /// `implicit_operator` is present.
2469    pub x_psi_psi_local: Array2<f64>,
    /// Presumably the local second ψ-derivatives of the penalties, one matrix
    /// per penalty component. TODO confirm ordering.
2470    pub s_psi_psi_components_local: Vec<Array2<f64>>,
    /// Presumably the anisotropy group this entry belongs to, `None` when it
    /// belongs to no group. TODO confirm.
2471    pub aniso_group_id: Option<usize>,
2472    /// Pre-computed cross-derivative design matrices for other axes
2473    /// in the same aniso group: Vec of (axis_offset_in_group, matrix).
2474    pub aniso_cross_designs: Option<Vec<(usize, Array2<f64>)>>,
2475    /// On-demand cross-penalty second derivatives ∂²S_m/∂ψ_a∂ψ_b for axes in
2476    /// the same anisotropy group. The input is the other axis offset in the
2477    /// group, and the output is one local penalty matrix per active penalty.
    /// The provider is fallible and returns `EstimationError` on failure.
2478    pub aniso_cross_penalty_provider: Option<
2479        std::sync::Arc<
2480            dyn Fn(usize) -> Result<Vec<Array2<f64>>, EstimationError> + Send + Sync + 'static,
2481        >,
2482    >,
2483    /// Optional implicit design-derivative operator (shared across all axes
2484    /// in the same aniso group). When present, `x_psi_local` and
2485    /// `x_psi_psi_local` may be zero-sized, and design-derivative matvecs
2486    /// should go through this operator using `implicit_axis` as the axis index.
2487    pub implicit_operator: Option<std::sync::Arc<crate::basis::ImplicitDesignPsiDerivative>>,
2488    /// Which axis in the implicit operator this entry corresponds to.
2489    pub implicit_axis: usize,
2490}
2491
/// Flat vector of per-term spatial hyper-coordinates together with the number
/// of slots each term occupies.
///
/// `new_with_dims` asserts that `values.len()` equals the sum of
/// `dims_per_term`. The other constructors in this impl build the two fields
/// directly. Not every slot is a log-κ value: the seeding constructors store
/// the raw signed κ for constant-curvature terms and dial coordinates for
/// measure-jet terms.
2492#[derive(Debug, Clone)]
2493pub struct SpatialLogKappaCoords {
2494    /// Flattened ψ values. For isotropic terms, one entry per term.
2495    /// For anisotropic terms, d entries per term (one ψ_a per axis).
2496    pub values: Array1<f64>,
2497    /// Dimensionality of each term: 1 for isotropic, d for anisotropic.
    /// Entries follow the same term order as the slots in `values`.
2498    pub dims_per_term: Vec<usize>,
2499}
2500
2501/// Which end of the ψ bound the shared `aniso_bounds_from_data` helper is
2502/// computing. The lower end uses `-max_length_scale.ln()` as the pure-Duchon
2503/// fallback and the `.0` element of `spatial_term_psi_bounds`; the upper end
2504/// uses `-min_length_scale.ln()` and `.1`. Everything else is identical.
2505#[derive(Clone, Copy)]
2506pub enum AnisoBoundEnd {
    /// Select the lower bound; passed by `lower_bounds_aniso_from_data`.
2507    Lower,
    /// Select the upper bound; passed by `upper_bounds_aniso_from_data`.
2508    Upper,
2509}
2510
2511impl SpatialLogKappaCoords {
2512    /// Construct from an explicit dims layout plus values.
2513    pub fn new_with_dims(values: Array1<f64>, dims_per_term: Vec<usize>) -> Self {
2514        assert_eq!(
2515            values.len(),
2516            dims_per_term.iter().sum::<usize>(),
2517            "SpatialLogKappaCoords: values length {} != sum of dims_per_term {}",
2518            values.len(),
2519            dims_per_term.iter().sum::<usize>(),
2520        );
2521        Self {
2522            values,
2523            dims_per_term,
2524        }
2525    }
2526
2527    /// Isotropic initialization (backward-compatible path).
2528    pub fn from_length_scales(
2529        spec: &TermCollectionSpec,
2530        term_indices: &[usize],
2531        options: &SpatialLengthScaleOptimizationOptions,
2532    ) -> Self {
2533        let mut out = Array1::<f64>::zeros(term_indices.len());
2534        for (slot, &term_idx) in term_indices.iter().enumerate() {
2535            // Constant-curvature: the single ψ slot is the raw signed κ, seeded
2536            // from the spec (default κ = 0). The −ln(length_scale) convention is
2537            // log-κ semantics and must not touch the raw-κ coordinate; the κ
2538            // window projection happens later via `clamp_to_bounds`. Mirrors the
2539            // aniso constructor's κ branch.
2540            if let Some(cc) = constant_curvature_term_spec(spec, term_idx) {
2541                out[slot] = cc.kappa;
2542                continue;
2543            }
2544            let length_scale = get_spatial_length_scale(spec, term_idx)
2545                .unwrap_or(options.min_length_scale)
2546                .clamp(options.min_length_scale, options.max_length_scale);
2547            out[slot] = -length_scale.ln();
2548        }
2549        Self {
2550            values: out,
2551            dims_per_term: vec![1; term_indices.len()],
2552        }
2553    }
2554
2555    /// Anisotropic-aware initialization.
2556    ///
2557    /// Initialization strategy (per math team recommendation): standardize the
2558    /// knot cloud axiswise, then run the existing isotropic κ initializer in
2559    /// the standardized space. This reuses the trusted isotropic initializer
2560    /// and gives initial η_a = −ln(σ_a) + mean(ln(σ_a)), which satisfies
2561    /// Ση_a = 0 by construction.
2562    ///
2563    /// For each term, checks whether it has `aniso_log_scales` set on its basis spec.
2564    /// - If isotropic (no aniso_log_scales, or 1-D): 1 entry = −ln(length_scale).
2565    /// - If anisotropic with a scalar length scale: d entries, one ψ_a per axis.
2566    ///   Initialized as ψ_a = −ln(length_scale) + η_a  where η_a are the existing
2567    ///   aniso_log_scales (which sum to zero). Multi-dimensional terms without
2568    ///   explicit anisotropy stay scalar here so the seed dimensionality matches
2569    ///   `spatial_dims_per_term`.
2570    /// - If pure Duchon anisotropic: d - 1 free entries store the leading η_a
2571    ///   values directly; the final axis is reconstructed to keep Ση_a = 0.
2572    pub fn from_length_scales_aniso(
2573        spec: &TermCollectionSpec,
2574        term_indices: &[usize],
2575        options: &SpatialLengthScaleOptimizationOptions,
2576    ) -> Self {
2577        let mut vals = Vec::new();
2578        let mut dims = Vec::new();
2579        for &term_idx in term_indices {
2580            // Measure-jet: dial coordinates seeded directly from the term's
2581            // realized (α, τ[, s]); the −ln(length_scale) convention below is
2582            // κ-semantics and never applies to dials.
2583            if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
2584                let seed = measure_jet_psi_seed(mj);
2585                dims.push(seed.len());
2586                vals.extend(seed);
2587                continue;
2588            }
2589            // Constant-curvature: one signed κ slot seeded from the spec's κ
2590            // (clamped feasible). The −ln(length_scale) convention below is
2591            // log-κ semantics and must not touch the raw-κ coordinate. Bounds
2592            // are unavailable here (no data view), so this is the raw spec κ;
2593            // `reseed_from_data` / `clamp_to_bounds` later project it feasible.
2594            if let Some(cc) = constant_curvature_term_spec(spec, term_idx) {
2595                vals.push(cc.kappa);
2596                dims.push(1);
2597                continue;
2598            }
2599            let length_scale = get_spatial_length_scale(spec, term_idx)
2600                .unwrap_or(options.min_length_scale)
2601                .clamp(options.min_length_scale, options.max_length_scale);
2602            let psi_bar = -length_scale.ln(); // global scale = −ln(length_scale)
2603
2604            if spatial_term_uses_per_axis_psi(spec, term_idx) {
2605                // Per-axis anisotropy is enrolled in the joint outer vector:
2606                // ψ_a = ψ̄ + η_a, one slot per axis. The hyper_dirs builder
2607                // produces matching per-axis derivatives in
2608                // `try_build_spatial_term_log_kappa_aniso_derivativeinfos`.
2609                let d = get_spatial_feature_dim(spec, term_idx).unwrap_or(1);
2610                let eta_raw = get_spatial_aniso_log_scales(spec, term_idx)
2611                    .expect("predicate guarantees aniso_log_scales is Some");
2612                let eta = center_aniso_log_scales(&eta_raw);
2613                for &eta_a in &eta {
2614                    vals.push(psi_bar + eta_a);
2615                }
2616                dims.push(d);
2617            } else {
2618                // Isotropic enrollment — either a 1-D term, a multi-D term
2619                // without explicit anisotropy, or a basis (e.g. Duchon) whose
2620                // η is a fixed geometry parameter rather than a REML hyper
2621                // axis. Exactly one ψ̄ slot, matching the single
2622                // `SpatialPsiDerivative` produced by
2623                // `try_build_spatial_term_log_kappa_derivativeinfo`.
2624                vals.push(psi_bar);
2625                dims.push(1);
2626            }
2627        }
2628        Self {
2629            values: Array1::from_vec(vals),
2630            dims_per_term: dims,
2631        }
2632    }
2633
2634    /// Isotropic lower bounds derived from per-term data geometry.
2635    /// Each entry gets the ψ_lo bound returned by `spatial_term_psi_bounds`
2636    /// for the corresponding term, intersected with the options window.
2637    pub fn lower_bounds_from_data(
2638        data: ArrayView2<'_, f64>,
2639        spec: &TermCollectionSpec,
2640        term_indices: &[usize],
2641        options: &SpatialLengthScaleOptimizationOptions,
2642    ) -> Self {
2643        let mut values = Array1::<f64>::zeros(term_indices.len());
2644        for (slot, &term_idx) in term_indices.iter().enumerate() {
2645            values[slot] = spatial_term_psi_bounds(data, spec, term_idx, options).0;
2646        }
2647        Self {
2648            values,
2649            dims_per_term: vec![1; term_indices.len()],
2650        }
2651    }
2652
2653    /// Isotropic upper bounds derived from per-term data geometry.
2654    pub fn upper_bounds_from_data(
2655        data: ArrayView2<'_, f64>,
2656        spec: &TermCollectionSpec,
2657        term_indices: &[usize],
2658        options: &SpatialLengthScaleOptimizationOptions,
2659    ) -> Self {
2660        let mut values = Array1::<f64>::zeros(term_indices.len());
2661        for (slot, &term_idx) in term_indices.iter().enumerate() {
2662            values[slot] = spatial_term_psi_bounds(data, spec, term_idx, options).1;
2663        }
2664        Self {
2665            values,
2666            dims_per_term: vec![1; term_indices.len()],
2667        }
2668    }
2669
2670    /// Anisotropic-aware lower bounds derived from per-term data geometry.
2671    /// For hybrid anisotropic terms the scalar ψ_lo bound applies to the
2672    /// mean `ψ̄`, not directly to every raw axis coordinate `ψ_a = ψ̄ + η_a`.
2673    /// Shift each axis by the current centered `η_a` so projecting/clamping
2674    /// the seed moves only the global scale direction and does not silently
2675    /// shrink anisotropy that is already consistent with the current
2676    /// `length_scale`.
2677    ///
2678    /// Pure Duchon anisotropy is structurally different: its stored
2679    /// coordinates are (d-1) free η_a values representing log axis-scale
2680    /// ratios, NOT log-κ. For those terms the κ-range geometry bound is
2681    /// over-restrictive (η_a = ±5 is normal, but that corresponds to 7+
2682    /// orders of magnitude in κ-space and would be rejected by the data
2683    /// window). Fall back to the options window `[-ln(max_ls), -ln(min_ls)]`
2684    /// for those coordinates — that's the same bound the pre-data-geometry
2685    /// code used, which is calibrated to allow legitimate anisotropy.
2686    pub fn lower_bounds_aniso_from_data(
2687        data: ArrayView2<'_, f64>,
2688        spec: &TermCollectionSpec,
2689        term_indices: &[usize],
2690        dims_per_term: &[usize],
2691        options: &SpatialLengthScaleOptimizationOptions,
2692    ) -> Self {
2693        Self::aniso_bounds_from_data(
2694            data,
2695            spec,
2696            term_indices,
2697            dims_per_term,
2698            options,
2699            AnisoBoundEnd::Lower,
2700        )
2701    }
2702
2703    /// Anisotropic-aware upper bounds derived from per-term data geometry.
2704    /// See `lower_bounds_aniso_from_data` for the hybrid-aniso offsetting and
2705    /// pure-Duchon dispatch rationale.
2706    pub fn upper_bounds_aniso_from_data(
2707        data: ArrayView2<'_, f64>,
2708        spec: &TermCollectionSpec,
2709        term_indices: &[usize],
2710        dims_per_term: &[usize],
2711        options: &SpatialLengthScaleOptimizationOptions,
2712    ) -> Self {
2713        Self::aniso_bounds_from_data(
2714            data,
2715            spec,
2716            term_indices,
2717            dims_per_term,
2718            options,
2719            AnisoBoundEnd::Upper,
2720        )
2721    }
2722
2723    /// Shared implementation for the lower/upper aniso bounds. The bound end
2724    /// only changes which options scale (`max_length_scale` vs
2725    /// `min_length_scale`) becomes the pure-Duchon fallback bound and which
2726    /// element of the `(lo, hi)` data-geometry tuple is consumed; the
2727    /// per-term cursor walk and aniso-offset handling are identical.
2728    fn aniso_bounds_from_data(
2729        data: ArrayView2<'_, f64>,
2730        spec: &TermCollectionSpec,
2731        term_indices: &[usize],
2732        dims_per_term: &[usize],
2733        options: &SpatialLengthScaleOptimizationOptions,
2734        end: AnisoBoundEnd,
2735    ) -> Self {
2736        assert_eq!(term_indices.len(), dims_per_term.len());
2737        let total: usize = dims_per_term.iter().sum();
2738        let mut values = Array1::<f64>::zeros(total);
2739        let mut cursor = 0;
2740        for (slot, &term_idx) in term_indices.iter().enumerate() {
2741            let d = dims_per_term[slot];
2742            // Measure-jet: per-coordinate dial boxes, never κ-window geometry
2743            // (which would reject legitimate dial values outright).
2744            if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
2745                let bounds = measure_jet_psi_bound_values(mj, matches!(end, AnisoBoundEnd::Upper));
2746                for (offset, bound) in bounds.into_iter().enumerate() {
2747                    if offset < d {
2748                        values[cursor + offset] = bound;
2749                    }
2750                }
2751                cursor += d;
2752                continue;
2753            }
2754            // Constant-curvature: the single signed-κ box from the data chart
2755            // window (symmetric about κ = 0), never a κ = log-scale window.
2756            if constant_curvature_term_spec(spec, term_idx).is_some() {
2757                let (lo, hi) = constant_curvature_kappa_bounds(data, spec, term_idx);
2758                if d >= 1 {
2759                    values[cursor] = match end {
2760                        AnisoBoundEnd::Lower => lo,
2761                        AnisoBoundEnd::Upper => hi,
2762                    };
2763                }
2764                cursor += d;
2765                continue;
2766            }
2767            let psi_bound = {
2768                let (lo, hi) = spatial_term_psi_bounds(data, spec, term_idx, options);
2769                match end {
2770                    AnisoBoundEnd::Lower => lo,
2771                    AnisoBoundEnd::Upper => hi,
2772                }
2773            };
2774            let axis_offsets = if d <= 1 {
2775                vec![0.0; d]
2776            } else {
2777                get_spatial_aniso_log_scales(spec, term_idx)
2778                    .filter(|eta| eta.len() == d)
2779                    .map(|eta| center_aniso_log_scales(&eta))
2780                    .unwrap_or_else(|| vec![0.0; d])
2781            };
2782            for offset in 0..d {
2783                values[cursor + offset] = psi_bound + axis_offsets[offset];
2784            }
2785            cursor += d;
2786        }
2787        Self {
2788            values,
2789            dims_per_term: dims_per_term.to_vec(),
2790        }
2791    }
2792
2793    /// Rewrite any ψ entries whose originating term lacks an explicit
2794    /// `length_scale` so they sit at the midpoint of the per-term data-derived
2795    /// ψ window. Used so the outer optimizer starts inside the physically
2796    /// meaningful region instead of at an arbitrary `options.max_length_scale`
2797    /// derived seed. For terms with an explicit length_scale, the user's
2798    /// choice is respected. Anisotropy offsets η_a (those stored by
2799    /// `from_length_scales_aniso`) are preserved: we re-center around the new
2800    /// ψ̄, keeping Ση_a = 0.
2801    pub fn reseed_from_data(
2802        mut self,
2803        data: ArrayView2<'_, f64>,
2804        spec: &TermCollectionSpec,
2805        term_indices: &[usize],
2806        options: &SpatialLengthScaleOptimizationOptions,
2807    ) -> Self {
2808        assert_eq!(term_indices.len(), self.dims_per_term.len());
2809        let mut cursor = 0;
2810        for (slot, &term_idx) in term_indices.iter().enumerate() {
2811            let d = self.dims_per_term[slot];
2812            // Measure-jet dials are seeded from the realized spec and must
2813            // not be recentered into a κ data window.
2814            if measure_jet_term_spec(spec, term_idx).is_some() {
2815                cursor += d;
2816                continue;
2817            }
2818            // Constant-curvature κ is seeded from the spec (the user's curvature
2819            // hint, default κ = 0); `clamp_to_bounds` projects it feasible. It
2820            // is not a log-scale, so the log-κ recenter below never applies.
2821            if constant_curvature_term_spec(spec, term_idx).is_some() {
2822                cursor += d;
2823                continue;
2824            }
2825            let Some(psi_bar_new) = spatial_term_psi_seed(data, spec, term_idx, options) else {
2826                cursor += d;
2827                continue;
2828            };
2829            if d == 0 {
2830                continue;
2831            }
2832            let current: Vec<f64> = self.values.slice(s![cursor..cursor + d]).to_vec();
2833            let psi_bar_old = current.iter().sum::<f64>() / d as f64;
2834            for (offset, &old_value) in current.iter().enumerate() {
2835                self.values[cursor + offset] = psi_bar_new + (old_value - psi_bar_old);
2836            }
2837            cursor += d;
2838        }
2839        self
2840    }
2841
2842    /// Project ψ values into `[lower, upper]` element-wise. Used after
2843    /// `from_length_scales*` + `reseed_from_data` when a user-supplied
2844    /// `spec.length_scale` falls outside the data-derived ψ window set by
2845    /// `{lower,upper}_bounds*_from_data`. BFGS requires theta0 ∈ [lower,
2846    /// upper]; projecting is the unique closest feasible seed. The user's
2847    /// length_scale was always a hint for the outer optimizer (the optimizer
2848    /// is authoritative for κ), not a hard constraint — so clipping preserves
2849    /// their intent as far as the geometry allows. Emits `log::info!` when
2850    /// any coordinate moves, so the outside-window case is diagnostically
2851    /// visible (not silent).
2852    pub fn clamp_to_bounds(
2853        mut self,
2854        lower: &SpatialLogKappaCoords,
2855        upper: &SpatialLogKappaCoords,
2856    ) -> Self {
2857        assert_eq!(self.values.len(), lower.values.len());
2858        assert_eq!(self.values.len(), upper.values.len());
2859        let mut n_projected = 0usize;
2860        let mut worst_delta = 0.0_f64;
2861        for idx in 0..self.values.len() {
2862            let lo = lower.values[idx];
2863            let hi = upper.values[idx];
2864            if !(lo.is_finite() && hi.is_finite()) {
2865                continue;
2866            }
2867            let v = self.values[idx];
2868            if v < lo {
2869                worst_delta = worst_delta.max(lo - v);
2870                self.values[idx] = lo;
2871                n_projected += 1;
2872            } else if v > hi {
2873                worst_delta = worst_delta.max(v - hi);
2874                self.values[idx] = hi;
2875                n_projected += 1;
2876            }
2877        }
2878        if n_projected > 0 {
2879            log::info!(
2880                "[spatial-kappa] projected {n_projected}/{} ψ seed coords into data-derived bounds \
2881                 (worst excess={worst_delta:.3} log units); user length_scale falls outside \
2882                 [{KERNEL_RANGE_MIN_DIAMETER_FRACTION}/r_max, {KERNEL_RANGE_MAX_SPACING_MULTIPLE}/r_min] geometry window",
2883                self.values.len()
2884            );
2885        }
2886        self
2887    }
2888
2889    /// Reconstruct from theta tail with known dimensionality layout.
2890    pub fn from_theta_tail_with_dims(
2891        theta: &Array1<f64>,
2892        start: usize,
2893        dims_per_term: Vec<usize>,
2894    ) -> Self {
2895        let total: usize = dims_per_term.iter().sum();
2896        Self {
2897            values: theta.slice(s![start..start + total]).to_owned(),
2898            dims_per_term,
2899        }
2900    }
2901
    /// Total number of ψ values in the flat array.
    ///
    /// The constructors visible in this impl size the flat array as the sum
    /// of `dims_per_term`, so this is also the combined ψ dimension of all
    /// terms.
    pub fn len(&self) -> usize {
        self.values.len()
    }
2906
    /// Dimensionality layout: how many ψ values each term contributes, in
    /// the same order as the terms' blocks appear in the flat array.
    pub fn dims_per_term(&self) -> &[usize] {
        &self.dims_per_term
    }
2911
    /// Get the offset into the flat array for logical term `term_idx`, i.e.
    /// the sum of the dimensions of all preceding terms.
    ///
    /// `term_idx` is a position in `dims_per_term`. Panics (slice range) if
    /// it exceeds the number of terms; `term_idx == number of terms` is
    /// allowed and yields the total length.
    fn term_offset(&self, term_idx: usize) -> usize {
        self.dims_per_term[..term_idx].iter().sum()
    }
2916
2917    /// Get the slice of ψ values for logical term i.
2918    pub fn term_slice(&self, term_idx: usize) -> &[f64] {
2919        let offset = self.term_offset(term_idx);
2920        let d = self.dims_per_term[term_idx];
2921        &self.values.as_slice().unwrap()[offset..offset + d]
2922    }
2923
    /// Borrow the flat ψ array (all terms' blocks concatenated in term order).
    pub fn as_array(&self) -> &Array1<f64> {
        &self.values
    }
2927
2928    /// #1464: overwrite the single ψ value of a scalar (1-D) logical term by its
2929    /// position `slot` in this coords vector (the same ordering as the
2930    /// `term_indices` slice the constructors were built from). Used to inject the
2931    /// fixed-κ sign-basin seed into a constant-curvature term's raw-κ slot before
2932    /// the joint solve. No-op (returns `false`) when the slot is not scalar.
2933    pub fn set_scalar_slot(&mut self, slot: usize, value: f64) -> bool {
2934        if slot >= self.dims_per_term.len() || self.dims_per_term[slot] != 1 {
2935            return false;
2936        }
2937        let offset = self.term_offset(slot);
2938        self.values[offset] = value;
2939        true
2940    }
2941
2942    /// Split at a logical-term boundary. `mid` is the number of terms in the
2943    /// first half (not a flat-array index).
2944    pub fn split_at(&self, mid: usize) -> (Self, Self) {
2945        let flat_mid: usize = self.dims_per_term[..mid].iter().sum();
2946        (
2947            Self {
2948                values: self.values.slice(s![0..flat_mid]).to_owned(),
2949                dims_per_term: self.dims_per_term[..mid].to_vec(),
2950            },
2951            Self {
2952                values: self.values.slice(s![flat_mid..]).to_owned(),
2953                dims_per_term: self.dims_per_term[mid..].to_vec(),
2954            },
2955        )
2956    }
2957
2958    /// Apply optimized ψ values back to the spec.
2959    ///
2960    /// For isotropic terms (dims=1): sets scalar length_scale = exp(−ψ).
2961    /// For anisotropic terms (dims=d): hybrid/isotropic families set
2962    /// length_scale = exp(−ψ̄) with centered η_a = ψ_a − ψ̄, while pure Duchon
2963    /// writes only centered η_a and leaves length_scale = None.
2964    pub fn apply_tospec(
2965        &self,
2966        spec: &TermCollectionSpec,
2967        term_indices: &[usize],
2968    ) -> Result<TermCollectionSpec, EstimationError> {
2969        if term_indices.len() != self.dims_per_term.len() {
2970            crate::bail_invalid_estim!(
2971                "SpatialLogKappaCoords::apply_tospec: term count mismatch: \
2972                 term_indices={} dims_per_term={}",
2973                term_indices.len(),
2974                self.dims_per_term.len()
2975            );
2976        }
2977        let mut updated = spec.clone();
2978        for (slot, &term_idx) in term_indices.iter().enumerate() {
2979            let psi = self.term_slice(slot);
2980            let d = self.dims_per_term[slot];
2981            // Measure-jet: write the dial coordinates straight back; the
2982            // κ-translation below would misread them as log-scales.
2983            if measure_jet_term_spec(&updated, term_idx).is_some() {
2984                set_measure_jet_psi_dials(&mut updated, term_idx, psi)?;
2985                continue;
2986            }
2987            // Constant-curvature: write the optimized signed κ straight back;
2988            // the −exp(ψ) length-scale translation below is log-κ semantics and
2989            // would misread the raw curvature.
2990            if constant_curvature_term_spec(&updated, term_idx).is_some() {
2991                set_constant_curvature_kappa(&mut updated, term_idx, psi)?;
2992                continue;
2993            }
2994            let (next_length_scale, next_aniso) = spatial_term_psi_to_length_scale_and_aniso(psi);
2995            if (d == 1 || next_length_scale.is_some())
2996                && let Some(length_scale) = next_length_scale
2997            {
2998                set_spatial_length_scale(&mut updated, term_idx, length_scale)?;
2999            }
3000            if let Some(eta) = next_aniso {
3001                set_spatial_aniso_log_scales(&mut updated, term_idx, eta)?;
3002            }
3003        }
3004        Ok(updated)
3005    }
3006}
3007
/// Subtract the mean from a set of per-axis log-scales so they sum to zero.
///
/// Inputs with zero or one element are returned unchanged. Centered values
/// whose magnitude is at most `1e-15` are snapped to exactly `0.0`, which
/// removes floating-point residue from the mean subtraction.
pub fn center_aniso_log_scales(eta: &[f64]) -> Vec<f64> {
    let n = eta.len();
    if n <= 1 {
        return eta.to_vec();
    }
    let mean = eta.iter().sum::<f64>() / n as f64;
    let mut centered = Vec::with_capacity(n);
    for &raw in eta {
        let delta = raw - mean;
        // Keep the `<=` form: a NaN delta must fall through unchanged.
        let snapped = if delta.abs() <= 1e-15 { 0.0 } else { delta };
        centered.push(snapped);
    }
    centered
}
3024
3025/// Whether a spatial term contributes per-axis ψ entries to the outer joint
3026/// hyperparameter vector.
3027pub fn spatial_term_uses_per_axis_psi(resolvedspec: &TermCollectionSpec, term_idx: usize) -> bool {
3028    if let Some(mj) = measure_jet_term_spec(resolvedspec, term_idx) {
3029        return measure_jet_enrolls_psi(mj);
3030    }
3031    let Some(d) = get_spatial_feature_dim(resolvedspec, term_idx) else {
3032        return false;
3033    };
3034    if d <= 1 {
3035        return false;
3036    }
3037    let Some(eta) = get_spatial_aniso_log_scales(resolvedspec, term_idx) else {
3038        return false;
3039    };
3040    if eta.len() != d {
3041        return false;
3042    }
3043    !matches!(
3044        resolvedspec.smooth_terms.get(term_idx).map(|term| &term.basis),
3045        Some(SmoothBasisSpec::Duchon { .. })
3046    )
3047}
3048
3049pub fn set_spatial_length_scale(
3050    spec: &mut TermCollectionSpec,
3051    term_idx: usize,
3052    length_scale: f64,
3053) -> Result<(), EstimationError> {
3054    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3055        crate::bail_invalid_estim!("spatial length-scale term index {term_idx} out of range");
3056    };
3057    match &mut term.basis {
3058        SmoothBasisSpec::ThinPlate { spec, .. } => {
3059            spec.length_scale = length_scale;
3060            Ok(())
3061        }
3062        SmoothBasisSpec::Matern { spec, .. } => {
3063            spec.length_scale = length_scale;
3064            Ok(())
3065        }
3066        SmoothBasisSpec::Duchon { spec, .. } => {
3067            spec.length_scale = Some(length_scale);
3068            Ok(())
3069        }
3070        _ => Err(EstimationError::InvalidInput(format!(
3071            "term '{}' does not expose a spatial length scale",
3072            term.name
3073        ))),
3074    }
3075}
3076
3077pub fn get_spatial_length_scale(spec: &TermCollectionSpec, term_idx: usize) -> Option<f64> {
3078    spec.smooth_terms
3079        .get(term_idx)
3080        .and_then(|term| match &term.basis {
3081            SmoothBasisSpec::ThinPlate { spec, .. } => Some(spec.length_scale),
3082            SmoothBasisSpec::Matern { spec, .. } => Some(spec.length_scale),
3083            SmoothBasisSpec::Duchon { spec, .. } => spec.length_scale,
3084            _ => None,
3085        })
3086}
3087
3088pub fn spatial_term_supports_hyper_optimization(spec: &TermCollectionSpec, term_idx: usize) -> bool {
3089    // Ordinary penalized thin-plate regression splines do not have an
3090    // identifiable kernel scale once REML is already learning the smoothing
3091    // penalty. Treat the resolved length scale as fixed geometry; enrolling a
3092    // scalar TPS kappa axis creates the flat ρ/κ valleys reported in #718,
3093    // #721, #731, and #732.
3094    if let Some(term) = spec.smooth_terms.get(term_idx)
3095        && let SmoothBasisSpec::ThinPlate { .. } = &term.basis
3096    {
3097        return false;
3098    }
3099
3100    // Duchon anisotropy η is a FIXED, geometry-derived basis parameter, NOT a
3101    // REML hyper axis: the metric is estimated once from the knot-cloud spread
3102    // (`auto_seed_aniso_contrasts`, applied on every Duchon basis build) and the
3103    // Hilbert-scale λ's carry all learned smoothness. So a pure Duchon (no κ)
3104    // contributes no outer optimization axis even when `scale_dims` is on —
3105    // "standardize the geometry, then learn the smoothness." Only an explicit
3106    // kernel length scale κ (the Matérn / hybrid path) is optimized here.
3107    //
3108    // ISOTROPIC Matérn: the *default* `matern(x1, x2)` is isotropic
3109    // (`scale_dims=false` → `aniso_log_scales = None`). It contributes exactly
3110    // ONE κ optimization axis — its scalar log-κ. The shared GAMLSS /
3111    // location-scale exact-joint ψ engine and the spatial-κ joint outer solver
3112    // both require an isotropic Matérn block to expose this single isotropic κ
3113    // axis (#822/#851); without it the per-block ψ-derivative lists are empty
3114    // and the joint-ψ hooks degenerate to `None`. The isotropic κ is the lone
3115    // kernel hyper axis here, mirroring the per-axis ψ ARD that the anisotropic
3116    // path exposes (just collapsed to one dimension).
3117    //
3118    // ANISOTROPIC Matérn (`scale_dims=true` → `aniso_log_scales = Some`) keeps
3119    // its per-axis kernel-η ARD: the d-dimensional ψ search is the *point* of
3120    // the anisotropic request ("Matérn keeps its kernel-η ARD").
3121    //
3122    // Either way a Matérn term always enrolls a κ/ψ axis (1 isotropic, or d
3123    // anisotropic), so `spatial_dims_per_term` reports the correct count.
3124    if let Some(term) = spec.smooth_terms.get(term_idx)
3125        && let SmoothBasisSpec::Matern { .. } = &term.basis
3126    {
3127        return true;
3128    }
3129
3130    // Measure-jet geometry dials are outer ψ coordinates; enrollment is
3131    // owned by `measure_jet_enrolls_psi`.
3132    if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
3133        return measure_jet_enrolls_psi(mj);
3134    }
3135
3136    // Constant-curvature smooths always enroll their single signed curvature κ
3137    // as an outer ψ-coordinate (#944 stage 3): κ̂ is the headline estimand, so
3138    // unlike a fixed-ℓ kernel it is fitted by default, not gated on a
3139    // user-supplied scale. The coordinate is raw κ (interior κ = 0), and its
3140    // exact design/penalty κ-derivatives come from
3141    // `build_constant_curvature_basis_kappa_derivatives`.
3142    if constant_curvature_term_spec(spec, term_idx).is_some() {
3143        return true;
3144    }
3145
3146    get_spatial_length_scale(spec, term_idx).is_some()
3147}
3148
3149/// The measure-jet term's spec, when `term_idx` is a measure-jet smooth.
3150/// Single accessor for every dial-plumbing dispatch below.
3151pub fn measure_jet_term_spec(
3152    spec: &TermCollectionSpec,
3153    term_idx: usize,
3154) -> Option<&crate::basis::MeasureJetBasisSpec> {
3155    spec.smooth_terms
3156        .get(term_idx)
3157        .and_then(|term| match &term.basis {
3158            SmoothBasisSpec::MeasureJet { spec, .. } => Some(spec),
3159            _ => None,
3160        })
3161}
3162
3163/// Single source for measure-jet outer-ψ enrollment: the lnτ dial is
3164/// undefined in the τ = 0 pseudo-inverse oracle mode (see
3165/// `build_measure_jet_basis_psi_derivatives`), so only a positive ridge
3166/// enrolls the dial group. `spatial_term_supports_hyper_optimization` and
3167/// `spatial_term_uses_per_axis_psi` both defer here so the θ-layout
3168/// sources cannot disagree.
3169pub fn measure_jet_enrolls_psi(mj: &crate::basis::MeasureJetBasisSpec) -> bool {
3170    // Two independent enrollment sources (#1116), both explicit:
3171    //   * the design-moving representer length-scale ℓ (`learn_length_scale`),
3172    //     available in every mode when the spec opts in;
3173    //   * the multiscale penalty dials (s, α, lnτ): the per-scale spectral
3174    //     split's (α, lnτ) ride the explicit `multiscale` opt-in, and the lnτ
3175    //     channel additionally needs a positive ridge (τ = 0 is the
3176    //     pseudo-inverse oracle mode where lnτ is undefined).
3177    // A term enrolls if EITHER source is active.
3178    measure_jet_learns_length_scale(mj)
3179        || (mj.tau0 > 0.0 && crate::basis::measure_jet_multiscale_mode(mj))
3180}
3181
/// Whether the design-moving ℓ dial is enrolled for this term. ℓ is fixed by
/// default and learnable in every mode only when `learn_length_scale = true`.
///
/// This simply reports the spec's `learn_length_scale` flag; the dial
/// dimension, seed, bound, and write-back helpers all branch on it.
pub fn measure_jet_learns_length_scale(mj: &crate::basis::MeasureJetBasisSpec) -> bool {
    mj.learn_length_scale
}
3187
3188pub fn freeze_measure_jet_length_scale_learning(spec: &mut TermCollectionSpec) -> usize {
3189    let mut frozen = 0;
3190    for term in spec.smooth_terms.iter_mut() {
3191        if let SmoothBasisSpec::MeasureJet { spec: mj, .. } = &mut term.basis
3192            && mj.learn_length_scale
3193        {
3194            mj.learn_length_scale = false;
3195            frozen += 1;
3196        }
3197    }
3198    frozen
3199}
3200
/// Measure-jet ψ dial boxes. The dials are NOT log-kernel-scales, so the
/// κ-window machinery never applies: `α` spans density-weighted (0) through
/// past-Coifman–Lafon (>1) normalization, and `lnτ` covers the ridge from
/// numerically-exact-projection to heavy noise-floor damping. (The energy
/// order `s` is the pinned explicit value or absorbed by the REML-learned
/// per-scale amplitudes — see `measure_jet_penalty_psi_dim` — so it carries no
/// dial box.)
///
/// This constant is the `α` box, stored as `(lower, upper)`;
/// `measure_jet_psi_bound_values` reads `.0` for the lower end and `.1` for
/// the upper end.
pub const MEASURE_JET_PSI_ALPHA_BOUNDS: (f64, f64) = (-1.0, 3.0);
3209
/// `lnτ` dial box, stored as `(lower, upper)`. The literals are
/// `ln(1e-8) = -18.4206…` and `ln(1e2) = 4.6051…`, i.e. τ ∈ [1e-8, 1e2].
pub const MEASURE_JET_PSI_LN_TAU_BOUNDS: (f64, f64) = (-18.420680743952367, 4.605170185988092);
3211
/// Log-ℓ box for the design-moving representer length-scale dial (#1116). An
/// ABSOLUTE window in the data coordinate scale (ln of ℓ ∈ [1e-3, 1e2]) used
/// only when the spec explicitly enrolls the learned representer range. Absolute
/// (not seed-relative) so the bound producer needs no data view, matching the
/// other dial boxes. `ln(1e-3) = -6.9077…`, `ln(1e2) = 4.6051…`.
///
/// Stored as `(lower, upper)`, like the other dial boxes.
pub const MEASURE_JET_PSI_LN_LENGTH_SCALE_BOUNDS: (f64, f64) = (-6.907755278982137, 4.605170185988092);
3218
3219/// Number of multiscale PENALTY dials (excluding the design-moving ℓ):
3220/// multiscale (per-scale spectral) mode carries (α, lnτ) = 2 — the order is
3221/// either the pinned explicit `s` or absorbed by the REML-learned per-scale
3222/// amplitudes, so it is NOT a dial; single-scale (the default) carries none.
3223/// MUST agree with the penalty-coordinate layout of
3224/// `build_measure_jet_basis_psi_derivatives` (its `per_level` branch always
3225/// emits exactly the (α, lnτ) coordinate pair).
3226pub fn measure_jet_penalty_psi_dim(mj: &crate::basis::MeasureJetBasisSpec) -> usize {
3227    if crate::basis::measure_jet_multiscale_mode(mj) {
3228        2
3229    } else {
3230        0
3231    }
3232}
3233
3234/// ψ dimension of a measure-jet term. The design-moving ℓ dial (when enrolled)
3235/// is coordinate 0; the multiscale penalty dials follow. MUST agree with the
3236/// coordinate layout of `build_measure_jet_basis_psi_derivatives` (ℓ first).
3237pub fn measure_jet_psi_dim(mj: &crate::basis::MeasureJetBasisSpec) -> usize {
3238    usize::from(measure_jet_learns_length_scale(mj)) + measure_jet_penalty_psi_dim(mj)
3239}
3240
3241/// Seed ψ from the term's realized dials, in producer coordinate order: ℓ first
3242/// (when enrolled), then the multiscale penalty dials. The ℓ seed is the
3243/// realized representer range `ln(length_scale)` (the resolved spec carries the
3244/// concrete auto value after the design build/freeze).
3245pub fn measure_jet_psi_seed(mj: &crate::basis::MeasureJetBasisSpec) -> Vec<f64> {
3246    let mut seed = Vec::with_capacity(measure_jet_psi_dim(mj));
3247    if measure_jet_learns_length_scale(mj) {
3248        // length_scale > 0 after resolution; the 0.0 sentinel (pre-resolution)
3249        // falls back to the centre of the log-ℓ box so the optimizer still
3250        // starts feasible and the first data-aware reseed corrects it.
3251        let ell = if mj.length_scale > 0.0 {
3252            mj.length_scale
3253        } else {
3254            1.0
3255        };
3256        seed.push(ell.ln());
3257    }
3258    if measure_jet_penalty_psi_dim(mj) > 0 {
3259        // Multiscale penalty dials, producer order: (α, lnτ).
3260        let ln_tau = mj.tau0.max(f64::MIN_POSITIVE).ln();
3261        seed.extend_from_slice(&[mj.alpha, ln_tau]);
3262    }
3263    seed
3264}
3265
3266/// One end of the per-coordinate dial boxes, in producer coordinate order
3267/// (ℓ first when enrolled, then the multiscale penalty dials).
3268pub fn measure_jet_psi_bound_values(mj: &crate::basis::MeasureJetBasisSpec, upper: bool) -> Vec<f64> {
3269    let pick = |b: (f64, f64)| if upper { b.1 } else { b.0 };
3270    let mut bounds = Vec::with_capacity(measure_jet_psi_dim(mj));
3271    if measure_jet_learns_length_scale(mj) {
3272        bounds.push(pick(MEASURE_JET_PSI_LN_LENGTH_SCALE_BOUNDS));
3273    }
3274    if measure_jet_penalty_psi_dim(mj) > 0 {
3275        // Multiscale penalty dials, producer order: (α, lnτ).
3276        bounds.push(pick(MEASURE_JET_PSI_ALPHA_BOUNDS));
3277        bounds.push(pick(MEASURE_JET_PSI_LN_TAU_BOUNDS));
3278    }
3279    bounds
3280}
3281
3282/// Write optimized ψ dials back into a measure-jet spec. Returns `true` when
3283/// any dial actually moved. The geometry (centers, masses, band, ℓ, z) is
3284/// ψ-FIXED by contract — only the dials change, so frozen-quadrature
3285/// rebuilds reproduce the identical penalty layout at the new dials.
3286pub fn apply_measure_jet_psi(
3287    mj: &mut crate::basis::MeasureJetBasisSpec,
3288    psi: &[f64],
3289) -> Result<bool, EstimationError> {
3290    if psi.len() != measure_jet_psi_dim(mj) {
3291        crate::bail_invalid_estim!(
3292            "measure-jet ψ write-back dimension mismatch: got {} values for a {}-dial term",
3293            psi.len(),
3294            measure_jet_psi_dim(mj)
3295        );
3296    }
3297    let mut changed = false;
3298    // Coordinate 0 (when enrolled) is the design-moving ln(ℓ); the multiscale
3299    // penalty dials follow. Same order as `measure_jet_psi_seed` and the
3300    // producer (`build_measure_jet_basis_psi_derivatives`).
3301    let mut cursor = 0usize;
3302    if measure_jet_learns_length_scale(mj) {
3303        let next_ell = psi[cursor].exp();
3304        cursor += 1;
3305        if !(next_ell.is_finite() && next_ell > 0.0) {
3306            crate::bail_invalid_estim!(
3307                "measure-jet ψ write-back produced a non-finite/non-positive length_scale (ℓ={next_ell})"
3308            );
3309        }
3310        if next_ell != mj.length_scale {
3311            mj.length_scale = next_ell;
3312            changed = true;
3313        }
3314    }
3315    if measure_jet_penalty_psi_dim(mj) > 0 {
3316        // Multiscale penalty dials, producer order: (α, lnτ). The order `s` is
3317        // not a dial (pinned explicit or absorbed by the per-scale amplitudes).
3318        let next_alpha = psi[cursor];
3319        let next_tau = psi[cursor + 1].exp();
3320        if !(next_alpha.is_finite() && next_tau.is_finite() && next_tau > 0.0) {
3321            crate::bail_invalid_estim!(
3322                "measure-jet ψ write-back produced non-finite dials (alpha={next_alpha}, tau={next_tau})"
3323            );
3324        }
3325        if next_alpha != mj.alpha {
3326            mj.alpha = next_alpha;
3327            changed = true;
3328        }
3329        if next_tau != mj.tau0 {
3330            mj.tau0 = next_tau;
3331            changed = true;
3332        }
3333    }
3334    Ok(changed)
3335}
3336
3337/// Collection-level measure-jet dial write-back (the `apply_tospec` /
3338/// realizer-side entry). Returns whether anything moved.
3339pub fn set_measure_jet_psi_dials(
3340    spec: &mut TermCollectionSpec,
3341    term_idx: usize,
3342    psi: &[f64],
3343) -> Result<bool, EstimationError> {
3344    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3345        crate::bail_invalid_estim!("measure-jet ψ write-back: term index {term_idx} out of range");
3346    };
3347    set_single_term_measure_jet_psi_dials(term, psi)
3348}
3349
3350/// Single-term dial write-back: the shared match+apply core, also used
3351/// directly on the cached per-trial build spec (whose caller has already
3352/// change-checked at the collection level and rebuilds regardless of the
3353/// moved flag).
3354pub fn set_single_term_measure_jet_psi_dials(
3355    term: &mut SmoothTermSpec,
3356    psi: &[f64],
3357) -> Result<bool, EstimationError> {
3358    let SmoothBasisSpec::MeasureJet { spec: mj, .. } = &mut term.basis else {
3359        crate::bail_invalid_estim!("measure-jet ψ write-back targeted a non-measure-jet term");
3360    };
3361    apply_measure_jet_psi(mj, psi)
3362}
3363
3364/// The constant-curvature smooth's spec, when `term_idx` is one. Single
3365/// accessor for every κ-ψ dispatch below, mirroring `measure_jet_term_spec`.
3366pub fn constant_curvature_term_spec(
3367    spec: &TermCollectionSpec,
3368    term_idx: usize,
3369) -> Option<&crate::basis::ConstantCurvatureBasisSpec> {
3370    spec.smooth_terms
3371        .get(term_idx)
3372        .and_then(|term| match &term.basis {
3373            SmoothBasisSpec::ConstantCurvature { spec, .. } => Some(spec),
3374            _ => None,
3375        })
3376}
3377
/// Hard positive cap on |κ| relative to the data's inverse squared chart
/// radius. The κ-stereographic chart is valid for `1 + κ‖x‖² > 0`; at
/// `|κ| = 1/R²` (R² = max squared chart radius) the gauge `1 + κ‖x‖²` reaches
/// the chart edge for the farthest data point, so the optimizer is boxed to a
/// safe fraction of that scale on both sides. κ = 0 (flat) is the centre of
/// the window, an interior point of the `S^d ← ℝ^d → H^d` family — exactly the
/// reachability the raw-κ (not log-κ) coordinate exists to preserve.
///
/// `constant_curvature_kappa_bounds` uses it as the numerator of the window
/// half-width: `±CONSTANT_CURVATURE_KAPPA_CHART_FRACTION / R²`.
pub const CONSTANT_CURVATURE_KAPPA_CHART_FRACTION: f64 = 0.5;
3386
/// Floor on the data's squared chart radius used to scale the κ window, so a
/// degenerate (near-origin) point cloud still yields a finite, usable bracket
/// rather than an unbounded one.
///
/// `constant_curvature_kappa_bounds` starts its running maximum of `R²` at
/// this value, so with the current constants the half-width never exceeds
/// `0.5 / 1e-8 = 5e7`.
pub const CONSTANT_CURVATURE_MIN_CHART_RADIUS2: f64 = 1e-8;
3391
3392/// `(κ_min, κ_max)` outer-optimization window for a constant-curvature term,
3393/// derived from the data's maximum squared chart radius `R²` so the κ-jets
3394/// never leave the κ-stereographic chart. Symmetric about κ = 0:
3395/// `±CONSTANT_CURVATURE_KAPPA_CHART_FRACTION / R²`.
3396pub fn constant_curvature_kappa_bounds(
3397    data: ArrayView2<'_, f64>,
3398    spec: &TermCollectionSpec,
3399    term_idx: usize,
3400) -> (f64, f64) {
3401    let feature_cols = match spec.smooth_terms.get(term_idx).map(|t| &t.basis) {
3402        Some(SmoothBasisSpec::ConstantCurvature { feature_cols, .. }) => feature_cols,
3403        _ => return (-1.0, 1.0),
3404    };
3405    let mut max_r2 = CONSTANT_CURVATURE_MIN_CHART_RADIUS2;
3406    for row in data.outer_iter() {
3407        let mut r2 = 0.0_f64;
3408        for &c in feature_cols.iter() {
3409            if let Some(&v) = row.get(c)
3410                && v.is_finite()
3411            {
3412                r2 += v * v;
3413            }
3414        }
3415        if r2 > max_r2 {
3416            max_r2 = r2;
3417        }
3418    }
3419    let half = CONSTANT_CURVATURE_KAPPA_CHART_FRACTION / max_r2;
3420    (-half, half)
3421}
3422
3423/// Write the optimized κ back into a constant-curvature term spec. Returns
3424/// `true` when κ moved. Centers, ℓ, and the constraint transform `z` are
3425/// κ-FIXED by the basis κ-contract, so only `kappa` changes.
3426pub fn set_constant_curvature_kappa(
3427    spec: &mut TermCollectionSpec,
3428    term_idx: usize,
3429    psi: &[f64],
3430) -> Result<bool, EstimationError> {
3431    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3432        crate::bail_invalid_estim!(
3433            "constant-curvature κ write-back: term index {term_idx} out of range"
3434        );
3435    };
3436    set_single_term_constant_curvature_kappa(term, psi)
3437}
3438
3439/// Single-term κ write-back: the shared validate+apply core, also used directly
3440/// on the cached per-trial build spec in the incremental realizer (whose caller
3441/// has already change-checked at the collection level and rebuilds regardless
3442/// of the moved flag). Mirrors [`set_single_term_measure_jet_psi_dials`].
3443pub fn set_single_term_constant_curvature_kappa(
3444    term: &mut SmoothTermSpec,
3445    psi: &[f64],
3446) -> Result<bool, EstimationError> {
3447    if psi.len() != 1 {
3448        crate::bail_invalid_estim!(
3449            "constant-curvature κ write-back expects exactly one value, got {}",
3450            psi.len()
3451        );
3452    }
3453    let next_kappa = psi[0];
3454    if !next_kappa.is_finite() {
3455        crate::bail_invalid_estim!(
3456            "constant-curvature κ write-back produced a non-finite κ = {next_kappa}"
3457        );
3458    }
3459    let SmoothBasisSpec::ConstantCurvature { spec: cc, .. } = &mut term.basis else {
3460        crate::bail_invalid_estim!(
3461            "constant-curvature κ write-back targeted a non-constant-curvature term"
3462        );
3463    };
3464    if cc.kappa != next_kappa {
3465        cc.kappa = next_kappa;
3466        Ok(true)
3467    } else {
3468        Ok(false)
3469    }
3470}
3471
3472/// Returns `true` when a spatial term has NO outer optimization axes — i.e.
3473/// the user provided an explicit `length_scale` and the term does not enroll
3474/// REML-side per-axis ψ contrasts, so both the scalar κ and any fixed geometry
3475/// anisotropy are anchored.
3476///
3477/// This is the per-term predicate that distinguishes "fixed kernel scale"
3478/// from "optimize the kernel scale" within the family entry points that
3479/// want to honor an explicit user-supplied scale (e.g. Bernoulli
3480/// marginal-slope, where the joint-spatial outer solver otherwise spends
3481/// ~80 iters stalled on the user's chosen ρ at high gradient).
3482pub fn spatial_term_has_locked_kappa(spec: &TermCollectionSpec, term_idx: usize) -> bool {
3483    get_spatial_length_scale(spec, term_idx).is_some()
3484        && !spatial_term_uses_per_axis_psi(spec, term_idx)
3485}
3486
3487pub fn all_spatial_terms_kappa_fixed(spec: &TermCollectionSpec) -> bool {
3488    spec.smooth_terms.iter().enumerate().all(|(idx, _)| {
3489        !spatial_term_supports_hyper_optimization(spec, idx)
3490            || spatial_term_has_locked_kappa(spec, idx)
3491    })
3492}
3493
3494pub fn spatial_identifiability_policy(termspec: &SmoothTermSpec) -> Option<&SpatialIdentifiability> {
3495    match &termspec.basis {
3496        SmoothBasisSpec::ThinPlate { spec, .. } => Some(&spec.identifiability),
3497        SmoothBasisSpec::Duchon { spec, .. } => Some(&spec.identifiability),
3498        _ => None,
3499    }
3500}
3501
/// Standard deviation of the degeneracy prior on a null-space selection ρ.
///
/// When the fit is well-determined, a relaxable double-penalty smooth gets a
/// wide, weakly-informative, symmetric `Normal` prior on its
/// `DoublePenaltyNullspace` selection coordinate; this is that prior's sd.
pub const NULLSPACE_WELLDET_DEGENERACY_RHO_SD: f64 = 15.0;
3506
3507/// True iff `prior` is the well-determined double-penalty null-space
3508/// degeneracy prior placed on a `DoublePenaltyNullspace` selection coordinate.
3509pub fn is_nullspace_degeneracy_prior(prior: &gam_spec::RhoPrior) -> bool {
3510    matches!(
3511        prior,
3512        gam_spec::RhoPrior::Normal { mean, sd }
3513            if *mean == 0.0 && *sd == NULLSPACE_WELLDET_DEGENERACY_RHO_SD
3514    )
3515}
3516
/// Numerator of the data-derived lower edge of a spatial term's κ window.
///
/// `spatial_term_psi_bounds` uses
/// `κ_lo = KERNEL_RANGE_MIN_DIAMETER_FRACTION / r_max` (so `ψ_lo = ln κ_lo`),
/// where `r_max` is the largest pairwise distance among the term's resolved
/// centers or, before centers are resolved, its standardized feature columns.
/// Since κ = 1 / length_scale, this caps the length scale at `r_max / 2`.
/// Larger kernel ranges exceed the data-diameter scale, where the radial
/// columns become nearly collinear with the polynomial nullspace (see the
/// notes inside `spatial_term_psi_bounds`).
pub const KERNEL_RANGE_MIN_DIAMETER_FRACTION: f64 = 2.0;
3529
/// Numerator of the data-derived upper edge of a spatial term's κ window.
///
/// `spatial_term_psi_bounds` uses
/// `κ_hi = KERNEL_RANGE_MAX_SPACING_MULTIPLE / r_min` (so `ψ_hi = ln κ_hi`),
/// where `r_min` is the smallest pairwise distance among the term's resolved
/// centers or, before centers are resolved, its standardized feature columns.
/// Since κ = 1 / length_scale, this floors the length scale at `r_min / 100`.
///
/// NOTE(review): the factor 100 is presumably chosen to keep the basis
/// geometry well-conditioned at very short kernel ranges — confirm against
/// the basis builders' documentation.
pub const KERNEL_RANGE_MAX_SPACING_MULTIPLE: f64 = 1e2;
3535
3536
3537/// Returns ψ-space bounds (ψ_lo = ln(κ_lo), ψ_hi = ln(κ_hi)).
3538///
3539/// When geometry is unavailable (e.g., fewer than 2 distinct points), falls
3540/// back to the scalar `options.min_length_scale` / `options.max_length_scale`
3541/// window so the outer optimizer never sees NaN bounds.
3542///
3543/// The returned window is intersected with the options window so user-set
3544/// `min_length_scale` / `max_length_scale` remain hard limits.
3545pub fn spatial_term_psi_bounds(
3546    data: ArrayView2<'_, f64>,
3547    spec: &TermCollectionSpec,
3548    term_idx: usize,
3549    options: &SpatialLengthScaleOptimizationOptions,
3550) -> (f64, f64) {
3551    let fallback = (
3552        -options.max_length_scale.ln(),
3553        -options.min_length_scale.ln(),
3554    );
3555    // Constant-curvature: the ψ coordinate is the raw signed κ, so its window is
3556    // the chart-feasible κ bracket, NOT a log-ℓ window. Mirrors the aniso bounds
3557    // path's `constant_curvature_kappa_bounds` branch so the isotropic
3558    // (non-aniso) seed clamp projects κ into the right interval.
3559    if constant_curvature_term_spec(spec, term_idx).is_some() {
3560        return constant_curvature_kappa_bounds(data, spec, term_idx);
3561    }
3562    let Some(term) = spec.smooth_terms.get(term_idx) else {
3563        return fallback;
3564    };
3565    // Prefer resolved centers (post-fit) since they live in the same standardized
3566    // space the kernel actually sees. Centers are capped at `default_num_centers`
3567    // (<=2000), so exact pairwise bounds are cheap (<4M ops). If centers are
3568    // not yet UserProvided, fall back to the standardized feature data columns
3569    // with the capped-sample path (O(K²·d), K=1024) — the sample is
3570    // conservative for κ bounds (see `pairwise_distance_bounds_sampled`
3571    // docs): it never excludes a feasible κ the exact method would include.
3572    //
3573    // Under anisotropy the kernel metric is y-space (y_a = exp(η_a) x_a),
3574    // so r_min/r_max must be y-space distances. This matters only when the
3575    // spec already carries calibrated η_a at setup time (e.g., warm-start
3576    // or refit paths); for fresh optimization η_a starts at 0 and y = x.
3577    let aniso = get_spatial_aniso_log_scales(spec, term_idx);
3578    let r_bounds = match spatial_term_center_strategy(term) {
3579        Some(CenterStrategy::UserProvided(centers)) if centers.nrows() >= 2 => {
3580            match aniso.as_deref() {
3581                Some(eta) if eta.len() == centers.ncols() => {
3582                    let y = points_in_aniso_y_space(centers.view(), eta);
3583                    pairwise_distance_bounds(y.view())
3584                }
3585                _ => pairwise_distance_bounds(centers.view()),
3586            }
3587        }
3588        _ => standardized_spatial_term_data(data, term)
3589            .ok()
3590            .and_then(|x| match aniso.as_deref() {
3591                Some(eta) if eta.len() == x.ncols() => {
3592                    let y = points_in_aniso_y_space(x.view(), eta);
3593                    pairwise_distance_bounds_sampled(y.view())
3594                }
3595                _ => pairwise_distance_bounds_sampled(x.view()),
3596            }),
3597    };
3598    let Some((r_min, r_max)) = r_bounds else {
3599        return fallback;
3600    };
3601    // Length scales substantially larger than the data diameter make radial
3602    // TPS/Matern columns nearly collinear with their polynomial nullspace.
3603    // The nullspace already carries constant/linear low-frequency structure,
3604    // so cap the kernel range at the diameter scale instead of letting the
3605    // optimizer enter a numerically degenerate basis geometry.
3606    let psi_lo_data = (KERNEL_RANGE_MIN_DIAMETER_FRACTION / r_max).ln();
3607    let psi_hi_data = (KERNEL_RANGE_MAX_SPACING_MULTIPLE / r_min).ln();
3608    // #1074: the Matérn-specific length-scale ceiling that used to live here was
3609    // deleted. It was masking, not fixing, the real defect: a hard upper bound on
3610    // the kernel range that pinned the κ-optimizer short rather than letting the
3611    // optimizer find the REML optimum. Matérn now shares the same generic geometry
3612    // window as Duchon / TPS (`KERNEL_RANGE_MIN_DIAMETER_FRACTION / r_max` floor,
3613    // `KERNEL_RANGE_MAX_SPACING_MULTIPLE / r_min` ceiling); the #1357 fully-flat
3614    // collapse corner is guarded by the EDF-collapse guard in
3615    // `spatial_optimization.rs`, which acts on the realized fit, not on a clamp.
3616    // Intersect with the options window so min/max_length_scale remain hard caps.
3617    let psi_lo = psi_lo_data.max(fallback.0);
3618    let psi_hi = psi_hi_data.min(fallback.1);
3619    if psi_lo >= psi_hi {
3620        // Degenerate intersection — fall back to the options window to keep the
3621        // outer optimizer from collapsing to a point.
3622        return fallback;
3623    }
3624    (psi_lo, psi_hi)
3625}
3626
3627/// Data-derived ψ seed for a spatial term when the user has not set an
3628/// explicit length_scale on its basis spec. Uses the geometric mean of the
3629/// data-informed kappa range (i.e., the midpoint of the ψ window).
3630pub fn spatial_term_psi_seed(
3631    data: ArrayView2<'_, f64>,
3632    spec: &TermCollectionSpec,
3633    term_idx: usize,
3634    options: &SpatialLengthScaleOptimizationOptions,
3635) -> Option<f64> {
3636    if get_spatial_length_scale(spec, term_idx).is_some() {
3637        return None; // user/spec-provided length_scale wins
3638    }
3639    let (psi_lo, psi_hi) = spatial_term_psi_bounds(data, spec, term_idx, options);
3640    Some(0.5 * (psi_lo + psi_hi))
3641}
3642
/// Splits a ψ vector into a global length scale and centered per-axis contrasts.
///
/// * Zero or one entry: the length scale is `exp(-ψ)` (with ψ taken as `0.0`
///   for an empty slice) and there are no contrasts.
/// * Two or more entries: the length scale is `exp(-ψ̄)` with `ψ̄` the mean of
///   the entries, and the contrasts are `ψ_a − ψ̄` (they sum to zero).
pub fn spatial_term_psi_to_length_scale_and_aniso(psi: &[f64]) -> (Option<f64>, Option<Vec<f64>>) {
    let [_, _, ..] = psi else {
        // Scalar (or empty) ψ: isotropic, nothing to center.
        let scalar_psi = psi.first().copied().unwrap_or(0.0);
        return (Some((-scalar_psi).exp()), None);
    };
    let mean_psi = psi.iter().sum::<f64>() / psi.len() as f64;
    let contrasts: Vec<f64> = psi.iter().map(|&axis_psi| axis_psi - mean_psi).collect();
    (Some((-mean_psi).exp()), Some(contrasts))
}
3654
3655/// Get the `aniso_log_scales` from a spatial term, if present.
3656pub fn get_spatial_aniso_log_scales(
3657    spec: &TermCollectionSpec,
3658    term_idx: usize,
3659) -> Option<Vec<f64>> {
3660    spec.smooth_terms
3661        .get(term_idx)
3662        .and_then(|term| match &term.basis {
3663            SmoothBasisSpec::Matern { spec, .. } => spec.aniso_log_scales.clone(),
3664            SmoothBasisSpec::Duchon { spec, .. } => spec.aniso_log_scales.clone(),
3665            _ => None,
3666        })
3667}
3668
3669/// Per-axis response-structure score for anisotropy seeding.
3670///
3671/// For each spatial axis `a`, sort the response `y` by the axis coordinate
3672/// `x_a` and measure the total squared successive variation of the sorted
3673/// response, `tv_a = Σ_i (y_{σ(i+1)} − y_{σ(i)})²` where `σ` orders rows by
3674/// `x_a`. An axis that carries real (possibly nonlinear) signal makes `y` vary
3675/// SMOOTHLY when the rows are walked in that axis's order, so `tv_a` is SMALL;
3676/// a pure-nuisance axis leaves `y` looking unordered, so `tv_a` is LARGE.
3677///
3678/// This deliberately does NOT use a linear correlation `corr(x_a, y)`: for an
3679/// odd, symmetric signal such as `sin(2·x1)` over a symmetric domain the linear
3680/// correlation is ~0 on the *signal* axis, which would misdirect the seed. The
3681/// total-variation-of-sorted-response score captures nonlinear association.
3682///
3683/// Returns `score_a = −½·ln(tv_a + ε)` (larger ⇒ more signal on axis `a`),
3684/// centered to sum to zero, or `None` when the data is degenerate (too few
3685/// rows, non-finite, or all axes equally (un)structured). The caller adds a
3686/// BOUNDED multiple of this to the geometry seed — it is a conservative nudge,
3687/// never a hard override.
3688pub fn response_aware_axis_contrasts(
3689    x: ndarray::ArrayView2<'_, f64>,
3690    y: ndarray::ArrayView1<'_, f64>,
3691) -> Option<Vec<f64>> {
3692    let n = x.nrows();
3693    let d = x.ncols();
3694    if d <= 1 || n < 4 || y.len() != n {
3695        return None;
3696    }
3697    if x.iter().any(|v| !v.is_finite()) || y.iter().any(|v| !v.is_finite()) {
3698        return None;
3699    }
3700    let mut scores = Vec::with_capacity(d);
3701    for a in 0..d {
3702        let mut order: Vec<usize> = (0..n).collect();
3703        let col = x.column(a);
3704        order.sort_by(|&i, &j| {
3705            col[i]
3706                .partial_cmp(&col[j])
3707                .unwrap_or(std::cmp::Ordering::Equal)
3708        });
3709        let mut tv = 0.0_f64;
3710        for w in order.windows(2) {
3711            let diff = y[w[1]] - y[w[0]];
3712            tv += diff * diff;
3713        }
3714        // ε guards against ln(0) on a perfectly flat / constant response.
3715        scores.push(-0.5 * (tv + 1e-12).ln());
3716    }
3717    if scores.iter().any(|v| !v.is_finite()) {
3718        return None;
3719    }
3720    let mean = scores.iter().sum::<f64>() / d as f64;
3721    let centered: Vec<f64> = scores.iter().map(|&s| s - mean).collect();
3722    // If every axis is equally structured the centered scores are ~0 and the
3723    // nudge is a no-op — return None so the geometry seed is used unchanged.
3724    if centered.iter().all(|&v| v.abs() < 1e-9) {
3725        return None;
3726    }
3727    Some(centered)
3728}
3729
3730/// Conservative, response-aware anisotropy seed nudge applied before the κ outer
3731/// loop. For each anisotropic spatial term it adds a BOUNDED multiple of the
3732/// per-axis response-structure contrast (`response_aware_axis_contrasts`) on top
3733/// of the existing geometry seed, so the optimizer starts in the correct basin
3734/// instead of at a response-blind near-symmetric point (the #1376 under-recovery
3735/// where a signal axis and a nuisance axis with equal coordinate spread seed to
3736/// ~[0,0]). The nudge is clamped to keep this a perturbation, never a hard
3737/// override, so shared aniso Matérn/Duchon fits cannot be destabilized by it.
3738pub fn apply_response_aware_anisotropy_seed(
3739    data: ArrayView2<'_, f64>,
3740    y: ndarray::ArrayView1<'_, f64>,
3741    spec: &mut TermCollectionSpec,
3742    spatial_terms: &[usize],
3743) {
3744    // Bound on the per-axis contrast nudge (in η units). One LN_2 ≈ 0.69 halves
3745    // the effective per-axis length scale; capping at LN_2 keeps the seed within
3746    // one optimizer log-step of the geometry seed while still breaking the
3747    // symmetric-seed trap.
3748    const MAX_NUDGE: f64 = std::f64::consts::LN_2;
3749    for &term_idx in spatial_terms {
3750        let Some(current_eta) = get_spatial_aniso_log_scales(spec, term_idx) else {
3751            continue;
3752        };
3753        let d = current_eta.len();
3754        if d <= 1 {
3755            continue;
3756        }
3757        let Some(term) = spec.smooth_terms.get(term_idx) else {
3758            continue;
3759        };
3760        let feature_cols = term.basis.structural_feature_cols();
3761        if feature_cols.len() != d {
3762            continue;
3763        }
3764        let Ok(x) = select_columns(data, &feature_cols) else {
3765            continue;
3766        };
3767        let Some(contrast) = response_aware_axis_contrasts(x.view(), y) else {
3768            continue;
3769        };
3770        let nudged: Vec<f64> = current_eta
3771            .iter()
3772            .zip(contrast.iter())
3773            .map(|(&eta_a, &c_a)| eta_a + c_a.clamp(-MAX_NUDGE, MAX_NUDGE))
3774            .collect();
3775        // `set_spatial_aniso_log_scales` re-centers to Σ η = 0. A term that does
3776        // not support aniso scales is silently skipped (the seed is optional).
3777        if let Err(err) = set_spatial_aniso_log_scales(spec, term_idx, nudged) {
3778            log::debug!(
3779                "[spatial-kappa] response-aware anisotropy seed skipped for term {term_idx}: {err}"
3780            );
3781        }
3782    }
3783}
3784
3785/// Get the number of feature columns (spatial dimensionality) for a spatial term.
3786pub fn get_spatial_feature_dim(spec: &TermCollectionSpec, term_idx: usize) -> Option<usize> {
3787    spec.smooth_terms
3788        .get(term_idx)
3789        .and_then(|term| match &term.basis {
3790            SmoothBasisSpec::ThinPlate { feature_cols, .. } => Some(feature_cols.len()),
3791            SmoothBasisSpec::Matern { feature_cols, .. } => Some(feature_cols.len()),
3792            SmoothBasisSpec::Duchon { feature_cols, .. } => Some(feature_cols.len()),
3793            _ => None,
3794        })
3795}
3796
3797/// Log the learned per-axis spatial anisotropy for all spatial terms that
3798/// have `aniso_log_scales` set after optimization.
3799///
3800/// For scalar-scale families this reports eta, effective per-axis length
3801/// scales, and per-axis kappa values. For pure Duchon it reports the centered
3802/// eta contrasts only.
3803pub fn log_spatial_aniso_scales(spec: &TermCollectionSpec) {
3804    for (term_idx, term) in spec.smooth_terms.iter().enumerate() {
3805        let (aniso, length_scale) = match &term.basis {
3806            SmoothBasisSpec::Matern { spec, .. } => {
3807                (spec.aniso_log_scales.as_ref(), Some(spec.length_scale))
3808            }
3809            SmoothBasisSpec::Duchon { spec, .. } => {
3810                (spec.aniso_log_scales.as_ref(), spec.length_scale)
3811            }
3812            _ => (None, None),
3813        };
3814        let Some(eta) = aniso else { continue };
3815        if eta.is_empty() {
3816            continue;
3817        }
3818        let mut lines = match length_scale {
3819            Some(ls) => format!(
3820                "[spatial-kappa] term {} (\"{}\"): anisotropic length scales optimized (global length_scale={:.4})",
3821                term_idx, term.name, ls
3822            ),
3823            None => format!(
3824                "[spatial-kappa] term {} (\"{}\"): pure Duchon shape anisotropy optimized",
3825                term_idx, term.name
3826            ),
3827        };
3828        for (a, &eta_a) in eta.iter().enumerate() {
3829            if let Some(ls) = length_scale {
3830                let length_a = ls * (-eta_a).exp();
3831                let kappa_a = (1.0 / ls) * eta_a.exp();
3832                lines.push_str(&format!(
3833                    "\n  axis {}: eta={:+.4}, length={:.4}, kappa={:.4}",
3834                    a, eta_a, length_a, kappa_a
3835                ));
3836            } else {
3837                lines.push_str(&format!("\n  axis {}: eta={:+.4}", a, eta_a));
3838            }
3839        }
3840        log::info!("{}", lines);
3841    }
3842}
3843
3844/// Set `aniso_log_scales` on a spatial term's basis spec.
3845pub fn set_spatial_aniso_log_scales(
3846    spec: &mut TermCollectionSpec,
3847    term_idx: usize,
3848    eta: Vec<f64>,
3849) -> Result<(), EstimationError> {
3850    let eta = center_aniso_log_scales(&eta);
3851    let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3852        crate::bail_invalid_estim!("spatial aniso_log_scales term index {term_idx} out of range");
3853    };
3854    match &mut term.basis {
3855        SmoothBasisSpec::Matern { spec, .. } => {
3856            spec.aniso_log_scales = Some(eta);
3857            Ok(())
3858        }
3859        SmoothBasisSpec::Duchon { spec, .. } => {
3860            spec.aniso_log_scales = Some(eta);
3861            Ok(())
3862        }
3863        _ => Err(EstimationError::InvalidInput(format!(
3864            "term '{}' does not support aniso_log_scales",
3865            term.name
3866        ))),
3867    }
3868}
3869
3870/// Sync knot-cloud-derived anisotropy contrasts from basis metadata back into
3871/// the mutable spec so the optimizer starts from the correct eta values.
3872///
3873/// Call this after building the smooth design but before initializing the
3874/// optimizer's psi coordinates. For each spatial term whose metadata contains
3875/// computed `aniso_log_scales`, this writes them into the spec.
3876pub fn sync_aniso_contrasts_from_metadata(
3877    spec: &mut TermCollectionSpec,
3878    design: &SmoothDesign,
3879) {
3880    for (term_idx, term) in design.terms.iter().enumerate() {
3881        let meta_aniso = match &term.metadata {
3882            BasisMetadata::Matern {
3883                aniso_log_scales, ..
3884            } => aniso_log_scales.clone(),
3885            BasisMetadata::Duchon {
3886                aniso_log_scales, ..
3887            } => aniso_log_scales.clone(),
3888            _ => None,
3889        };
3890        if let Some(eta) = meta_aniso
3891            && eta.len() > 1
3892        {
3893            set_spatial_aniso_log_scales(spec, term_idx, eta).ok();
3894        }
3895    }
3896}
3897
/// Settings for the outer-loop search over spatial κ (= 1 / length_scale).
#[derive(Debug, Clone)]
pub struct SpatialLengthScaleOptimizationOptions {
    /// Turns the outer-loop optimization over spatial κ on or off for the
    /// supported radial-kernel smooths: ThinPlate, Matérn, and Duchon terms.
    pub enabled: bool,
    /// Upper limit on outer iterations of the exact joint [rho, psi] solve.
    pub max_outer_iter: usize,
    /// Relative-improvement threshold below which the outer solve terminates.
    pub rel_tol: f64,
    /// Initial perturbation of log(length_scale) used when building seeds.
    pub log_step: f64,
    /// Smallest length_scale the κ search is allowed to use.
    pub min_length_scale: f64,
    /// Largest length_scale the κ search is allowed to use.
    pub max_length_scale: f64,
    /// Row-count threshold that controls the automatic geometry initializer
    /// for large-scale spatial fits.
    ///
    /// Once n exceeds twice this value the fitter seeds κ/anisotropy geometry
    /// from a spatially stratified subsample, and uses the subsample for
    /// nothing else. Centers are resolved, axis contrasts are initialized from
    /// center/data spread, and one or two cheap ψ reseeding updates are
    /// applied. PIRLS, REML, ARC, BFGS, or any recursive optimizer is never
    /// run on the pilot.
    ///
    /// Final coefficients, smoothing parameters, and spatial geometry are
    /// always optimized on the full dataset.
    ///
    /// A value of 0 skips the pilot geometry initializer.
    pub pilot_subsample_threshold: usize,
    /// Optional wall-clock budget, in seconds, for the whole outer smoothing
    /// search (gam#979).
    ///
    /// When a family arms the global deadline from this value, an outer search
    /// that cannot certify convergence (survival marginal-slope's
    /// monotonicity-pinned constrained joint-Newton) returns its best-so-far
    /// iterate, or a catchable error, within the budget instead of hanging.
    /// `None` keeps the legacy unbounded behavior, except that the survival
    /// marginal-slope path then applies a generous default of its own.
    pub outer_wall_clock_budget_secs: Option<f64>,
}
3936
3937impl Default for SpatialLengthScaleOptimizationOptions {
3938    fn default() -> Self {
3939        Self {
3940            enabled: true,
3941            max_outer_iter: 80,
3942            rel_tol: 1e-4,
3943            log_step: std::f64::consts::LN_2,
3944            min_length_scale: 1e-3,
3945            max_length_scale: 1e3,
3946            pilot_subsample_threshold: 10_000,
3947            outer_wall_clock_budget_secs: None,
3948        }
3949    }
3950}
3951
3952impl SpatialLengthScaleOptimizationOptions {
3953    /// Validate the struct's invariants. Callers that construct these options
3954    /// from external input (CLI, config, Python API) should call this before
3955    /// passing the options into the fitter. Returns `Err` with a descriptive
3956    /// message when an invariant is violated; the fitter then panics or
3957    /// returns `EstimationError` at its own boundary.
3958    ///
3959    /// Invariants:
3960    ///   * `min_length_scale > 0`, finite
3961    ///   * `max_length_scale > 0`, finite
3962    ///   * `min_length_scale < max_length_scale`
3963    ///   * `rel_tol > 0`, finite
3964    ///   * `log_step > 0`, finite
3965    ///
3966    /// These invariants are what the downstream κ-bound and ψ-window code
3967    /// assumes (`-log(max_ls)` must be finite, `(min,max)` must not be
3968    /// inverted, etc.). Without validation, invalid options produce silent
3969    /// NaN-propagation inside the outer optimizer.
3970    pub fn validate(&self) -> Result<(), String> {
3971        if !self.min_length_scale.is_finite() || self.min_length_scale <= 0.0 {
3972            return Err(SmoothError::invalid_config(format!(
3973                "SpatialLengthScaleOptimizationOptions::min_length_scale must be > 0 and finite, got {}",
3974                self.min_length_scale
3975            ))
3976            .into());
3977        }
3978        if !self.max_length_scale.is_finite() || self.max_length_scale <= 0.0 {
3979            return Err(SmoothError::invalid_config(format!(
3980                "SpatialLengthScaleOptimizationOptions::max_length_scale must be > 0 and finite, got {}",
3981                self.max_length_scale
3982            ))
3983            .into());
3984        }
3985        if self.min_length_scale >= self.max_length_scale {
3986            return Err(SmoothError::invalid_config(format!(
3987                "SpatialLengthScaleOptimizationOptions requires min_length_scale < max_length_scale, got min={} max={}",
3988                self.min_length_scale, self.max_length_scale
3989            ))
3990            .into());
3991        }
3992        if !self.rel_tol.is_finite() || self.rel_tol <= 0.0 {
3993            return Err(SmoothError::invalid_config(format!(
3994                "SpatialLengthScaleOptimizationOptions::rel_tol must be > 0 and finite, got {}",
3995                self.rel_tol
3996            ))
3997            .into());
3998        }
3999        if !self.log_step.is_finite() || self.log_step <= 0.0 {
4000            return Err(SmoothError::invalid_config(format!(
4001                "SpatialLengthScaleOptimizationOptions::log_step must be > 0 and finite, got {}",
4002                self.log_step
4003            ))
4004            .into());
4005        }
4006        Ok(())
4007    }
4008}
4009
/// Group-label representation of a random-effect design block.
#[derive(Debug, Clone)]
pub struct RandomEffectBlock {
    /// Name of the block.
    pub name: String,
    /// One entry per observation (O(n) storage): `group_ids[i]` is the column
    /// index in `[0, num_groups)` for observation `i`, or `None` when that
    /// observation's level is not in the kept set.
    pub group_ids: Vec<Option<usize>>,
    /// Number of groups; upper bound (exclusive) of the column indices above.
    pub num_groups: usize,
    /// The levels that were kept.
    pub kept_levels: Vec<u64>,
}
4019
/// Magnitude threshold for treating a design entry as zero during sparse
/// block assembly: only entries with `|value| > BLOCK_SPARSE_ZERO_EPS` are
/// counted as non-zeros of a dense block and emitted as triplets.
pub const BLOCK_SPARSE_ZERO_EPS: f64 = 1e-12;
4021
/// Largest fill ratio (non-zeros / total cells) at which a design made only
/// of dense-style blocks is still assembled as a sparse matrix. Designs that
/// contain an intrinsically sparse block are exempt from this limit.
pub const BLOCK_SPARSE_MAX_DENSITY: f64 = 0.20;
4023
4024pub fn blocks_have_intrinsic_sparse_structure(blocks: &[DesignBlock]) -> bool {
4025    blocks
4026        .iter()
4027        .any(|block| matches!(block, DesignBlock::Sparse(_) | DesignBlock::RandomEffect(_)))
4028}
4029
4030pub fn sparse_compatible_block_nnz(block: &DesignBlock) -> Option<usize> {
4031    match block {
4032        DesignBlock::Intercept(n) => Some(*n),
4033        DesignBlock::RandomEffect(op) => {
4034            Some(op.group_ids.iter().filter(|gid| gid.is_some()).count())
4035        }
4036        DesignBlock::Sparse(sparse) => Some(sparse.val().len()),
4037        DesignBlock::Dense(dense) => dense.as_dense_ref().map(|matrix| {
4038            matrix
4039                .iter()
4040                .filter(|&&value| value.abs() > BLOCK_SPARSE_ZERO_EPS)
4041                .count()
4042        }),
4043    }
4044}
4045
4046pub fn try_build_sparse_design_from_blocks(
4047    blocks: &[DesignBlock],
4048) -> Result<Option<DesignMatrix>, BasisError> {
4049    if blocks.is_empty() {
4050        return Ok(None);
4051    }
4052    let nrows = blocks[0].nrows();
4053    let ncols: usize = blocks.iter().map(DesignBlock::ncols).sum();
4054    if nrows == 0 || ncols == 0 || ncols <= 32 {
4055        return Ok(None);
4056    }
4057
4058    let preserve_sparse_storage = blocks_have_intrinsic_sparse_structure(blocks);
4059    let sparse_nnz_limit = if preserve_sparse_storage {
4060        usize::MAX
4061    } else {
4062        let total_cells = nrows.saturating_mul(ncols);
4063        ((total_cells as f64) * BLOCK_SPARSE_MAX_DENSITY).floor() as usize
4064    };
4065    let mut nnz = 0usize;
4066    for block in blocks {
4067        let block_nnz = if let Some(block_nnz) = sparse_compatible_block_nnz(block) {
4068            block_nnz
4069        } else {
4070            return Ok(None);
4071        };
4072        nnz = nnz.saturating_add(block_nnz);
4073        if nnz > sparse_nnz_limit {
4074            return Ok(None);
4075        }
4076    }
4077
4078    let mut triplets = Vec::<Triplet<usize, usize, f64>>::with_capacity(nnz);
4079    let mut col_offset = 0usize;
4080    for block in blocks {
4081        match block {
4082            DesignBlock::Intercept(n) => {
4083                for row in 0..*n {
4084                    triplets.push(Triplet::new(row, col_offset, 1.0));
4085                }
4086            }
4087            DesignBlock::RandomEffect(op) => {
4088                for (row, group_id) in op.group_ids.iter().enumerate() {
4089                    if let Some(group) = group_id {
4090                        triplets.push(Triplet::new(row, col_offset + group, 1.0));
4091                    }
4092                }
4093            }
4094            DesignBlock::Sparse(sparse) => {
4095                let (symbolic, values) = sparse.parts();
4096                let col_ptr = symbolic.col_ptr();
4097                let row_idx = symbolic.row_idx();
4098                for col in 0..sparse.ncols() {
4099                    for idx in col_ptr[col]..col_ptr[col + 1] {
4100                        let value = values[idx];
4101                        if value.abs() > BLOCK_SPARSE_ZERO_EPS {
4102                            triplets.push(Triplet::new(row_idx[idx], col_offset + col, value));
4103                        }
4104                    }
4105                }
4106            }
4107            DesignBlock::Dense(dense) => {
4108                let matrix = dense.as_dense_ref().ok_or_else(|| {
4109                    BasisError::InvalidInput(
4110                        "sparse-compatible block assembly requires materialized dense blocks"
4111                            .to_string(),
4112                    )
4113                })?;
4114                for row in 0..matrix.nrows() {
4115                    for col in 0..matrix.ncols() {
4116                        let value = matrix[[row, col]];
4117                        if value.abs() > BLOCK_SPARSE_ZERO_EPS {
4118                            triplets.push(Triplet::new(row, col_offset + col, value));
4119                        }
4120                    }
4121                }
4122            }
4123        }
4124        col_offset += block.ncols();
4125    }
4126
4127    let sparse = SparseColMat::try_new_from_triplets(nrows, ncols, &triplets).map_err(|_| {
4128        BasisError::SparseCreation("failed to assemble sparse term-collection design".to_string())
4129    })?;
4130    Ok(Some(DesignMatrix::Sparse(
4131        gam_linalg::matrix::SparseDesignMatrix::new(sparse),
4132    )))
4133}
4134
4135pub fn assemble_term_collection_design_matrix(
4136    blocks: Vec<DesignBlock>,
4137) -> Result<DesignMatrix, BasisError> {
4138    if let Some(sparse) = try_build_sparse_design_from_blocks(&blocks)? {
4139        return Ok(sparse);
4140    }
4141    let block_op = BlockDesignOperator::new(blocks).map_err(|e| {
4142        BasisError::InvalidInput(format!("failed to build block design operator: {e}"))
4143    })?;
4144    Ok(DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
4145        Arc::new(block_op),
4146    )))
4147}
4148
4149pub fn select_columns(data: ArrayView2<'_, f64>, cols: &[usize]) -> Result<Array2<f64>, BasisError> {
4150    let n = data.nrows();
4151    let p = data.ncols();
4152    for &c in cols {
4153        if c >= p {
4154            crate::bail_dim_basis!("feature column {c} is out of bounds for data with {p} columns");
4155        }
4156    }
4157    let mut out = Array2::<f64>::zeros((n, cols.len()));
4158    for (j, &c) in cols.iter().enumerate() {
4159        out.column_mut(j).assign(&data.column(c));
4160    }
4161    Ok(out)
4162}
4163
/// Short label for a non-finite value, used in diagnostics: `"NaN"` for any
/// NaN, otherwise `"-Inf"` when the sign bit is set and `"+Inf"` when it is not.
///
/// The function does not check that `value` is actually non-finite; callers
/// are expected to have established that already.
pub fn nonfinite_value_label(value: f64) -> &'static str {
    match (value.is_nan(), value.is_sign_negative()) {
        (true, _) => "NaN",
        (false, true) => "-Inf",
        (false, false) => "+Inf",
    }
}
4173
4174pub fn validate_term_feature_column_finite(
4175    data: ArrayView2<'_, f64>,
4176    term_kind: &str,
4177    term_name: &str,
4178    feature_col: usize,
4179) -> Result<(), BasisError> {
4180    let p = data.ncols();
4181    if feature_col >= p {
4182        crate::bail_dim_basis!(
4183            "{term_kind} term '{term_name}' feature column {feature_col} out of bounds for {p} columns"
4184        );
4185    }
4186    for (row, &value) in data.column(feature_col).iter().enumerate() {
4187        if !value.is_finite() {
4188            crate::bail_invalid_basis!(
4189                "{term_kind} term '{term_name}' feature column {feature_col} row {row} contains non-finite value {}",
4190                nonfinite_value_label(value)
4191            );
4192        }
4193    }
4194    Ok(())
4195}
4196
4197pub fn validate_smooth_terms_finite_inputs(
4198    data: ArrayView2<'_, f64>,
4199    terms: &[SmoothTermSpec],
4200) -> Result<(), BasisError> {
4201    for term in terms {
4202        for feature_col in smooth_term_feature_cols(term) {
4203            validate_term_feature_column_finite(data, "smooth", &term.name, feature_col)?;
4204        }
4205    }
4206    Ok(())
4207}
4208
4209pub fn validate_term_collection_finite_inputs(
4210    data: ArrayView2<'_, f64>,
4211    spec: &TermCollectionSpec,
4212) -> Result<(), BasisError> {
4213    for term in &spec.linear_terms {
4214        validate_term_feature_column_finite(data, "linear", &term.name, term.feature_col)?;
4215    }
4216    for term in &spec.random_effect_terms {
4217        validate_term_feature_column_finite(data, "random-effect", &term.name, term.feature_col)?;
4218    }
4219    validate_smooth_terms_finite_inputs(data, &spec.smooth_terms)
4220}
4221
/// Key under which spatial smooth terms are grouped when planning shared
/// centers: terms that produce equal keys are candidates for one jointly
/// selected center set.
///
/// The key is totally ordered (`Ord`) so it can index a `BTreeMap`. It is built
/// by `spatial_term_group_key`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct JointSpatialCenterGroupKey {
    /// Feature columns the term's basis is defined over.
    feature_cols: Vec<usize>,
    /// Kind of the term's center-selection strategy.
    strategy_kind: CenterStrategyKind,
    /// Strategy-specific extra discriminator: `max_iter` for k-means,
    /// `points_per_dim` for a uniform grid, `0` otherwise.
    strategy_aux: usize,
    /// Number of centers the term's strategy asked for.
    requested_num_centers: usize,
    /// Bit patterns (`f64::to_bits`) of the term's explicit input scales, if
    /// any; stored as bits so the key can derive `Eq`/`Ord`.
    input_scale_bits: Option<Vec<u64>>,
}
4230
4231pub fn spatial_term_min_center_count(term: &SmoothTermSpec) -> usize {
4232    match &term.basis {
4233        SmoothBasisSpec::ThinPlate { feature_cols, .. } => feature_cols.len() + 1,
4234        SmoothBasisSpec::Duchon {
4235            feature_cols, spec, ..
4236        } => match spec.nullspace_order {
4237            crate::basis::DuchonNullspaceOrder::Zero => 1,
4238            crate::basis::DuchonNullspaceOrder::Linear => feature_cols.len() + 1,
4239            crate::basis::DuchonNullspaceOrder::Degree(degree) => {
4240                crate::basis::duchon_nullspace_dimension(feature_cols.len(), degree)
4241            }
4242        },
4243        SmoothBasisSpec::Matern { .. } => 1,
4244        _ => 1,
4245    }
4246}
4247
4248pub fn spatial_term_group_key(term: &SmoothTermSpec) -> Option<JointSpatialCenterGroupKey> {
4249    let (feature_cols, strategy, input_scales) = match &term.basis {
4250        SmoothBasisSpec::ThinPlate {
4251            feature_cols,
4252            spec,
4253            input_scales,
4254        } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4255        SmoothBasisSpec::Matern {
4256            feature_cols,
4257            spec,
4258            input_scales,
4259        } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4260        SmoothBasisSpec::Duchon {
4261            feature_cols,
4262            spec,
4263            input_scales,
4264        } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4265        _ => return None,
4266    };
4267    let strategy_kind = center_strategy_kind(strategy);
4268    let strategy_aux = match strategy {
4269        CenterStrategy::Auto(inner) => match inner.as_ref() {
4270            CenterStrategy::KMeans { max_iter, .. } => *max_iter,
4271            CenterStrategy::UniformGrid { points_per_dim } => *points_per_dim,
4272            _ => 0,
4273        },
4274        CenterStrategy::KMeans { max_iter, .. } => *max_iter,
4275        CenterStrategy::UniformGrid { points_per_dim } => *points_per_dim,
4276        _ => 0,
4277    };
4278    Some(JointSpatialCenterGroupKey {
4279        feature_cols: feature_cols.clone(),
4280        strategy_kind,
4281        strategy_aux,
4282        requested_num_centers: center_strategy_num_centers(strategy)?,
4283        input_scale_bits: input_scales
4284            .map(|values| values.iter().map(|value| value.to_bits()).collect()),
4285    })
4286}
4287
4288pub fn spatial_term_center_strategy(term: &SmoothTermSpec) -> Option<&CenterStrategy> {
4289    match &term.basis {
4290        SmoothBasisSpec::ThinPlate { spec, .. } => Some(&spec.center_strategy),
4291        SmoothBasisSpec::Matern { spec, .. } => Some(&spec.center_strategy),
4292        SmoothBasisSpec::Duchon { spec, .. } => Some(&spec.center_strategy),
4293        _ => None,
4294    }
4295}
4296
4297pub fn set_spatial_term_centers(
4298    term: &mut SmoothTermSpec,
4299    centers: Array2<f64>,
4300) -> Result<(), BasisError> {
4301    match &mut term.basis {
4302        SmoothBasisSpec::ThinPlate { spec, .. } => {
4303            spec.center_strategy = CenterStrategy::UserProvided(centers);
4304            Ok(())
4305        }
4306        SmoothBasisSpec::Matern { spec, .. } => {
4307            spec.center_strategy = CenterStrategy::UserProvided(centers);
4308            Ok(())
4309        }
4310        SmoothBasisSpec::Duchon { spec, .. } => {
4311            spec.center_strategy = CenterStrategy::UserProvided(centers);
4312            Ok(())
4313        }
4314        _ => Err(BasisError::InvalidInput(format!(
4315            "term '{}' does not support spatial center planning",
4316            term.name
4317        ))),
4318    }
4319}
4320
4321pub fn standardized_spatial_term_data(
4322    data: ArrayView2<'_, f64>,
4323    term: &SmoothTermSpec,
4324) -> Result<Array2<f64>, BasisError> {
4325    let (feature_cols, input_scales) = match &term.basis {
4326        SmoothBasisSpec::ThinPlate {
4327            feature_cols,
4328            input_scales,
4329            ..
4330        }
4331        | SmoothBasisSpec::Matern {
4332            feature_cols,
4333            input_scales,
4334            ..
4335        }
4336        | SmoothBasisSpec::Duchon {
4337            feature_cols,
4338            input_scales,
4339            ..
4340        } => (feature_cols, input_scales.as_ref()),
4341        _ => {
4342            crate::bail_invalid_basis!("term '{}' is not a spatial smooth", term.name);
4343        }
4344    };
4345    let mut x = select_columns(data, feature_cols)?;
4346    if let Some(scales) = input_scales {
4347        apply_input_standardization(&mut x, scales);
4348    } else if let Some(scales) = compute_spatial_input_scales(x.view()) {
4349        apply_input_standardization(&mut x, &scales);
4350    }
4351    Ok(x)
4352}
4353
4354pub fn plan_joint_spatial_centers_for_term_blocks(
4355    data: ArrayView2<'_, f64>,
4356    term_blocks: &[Vec<SmoothTermSpec>],
4357) -> Result<Vec<Vec<SmoothTermSpec>>, BasisError> {
4358    let mut planned_blocks = term_blocks.to_vec();
4359    let n = data.nrows();
4360    let mut groups: BTreeMap<JointSpatialCenterGroupKey, Vec<(usize, usize)>> = BTreeMap::new();
4361
4362    for (block_idx, terms) in planned_blocks.iter().enumerate() {
4363        for (term_idx, term) in terms.iter().enumerate() {
4364            let Some(strategy) = spatial_term_center_strategy(term) else {
4365                continue;
4366            };
4367            if !center_strategy_is_auto(strategy) {
4368                continue;
4369            }
4370            let Some(group_key) = spatial_term_group_key(term) else {
4371                continue;
4372            };
4373            if !matches!(
4374                group_key.strategy_kind,
4375                CenterStrategyKind::EqualMass
4376                    | CenterStrategyKind::EqualMassCovarRepresentative
4377                    | CenterStrategyKind::FarthestPoint
4378                    | CenterStrategyKind::KMeans
4379            ) {
4380                continue;
4381            }
4382            if center_strategy_num_centers(strategy).is_none() {
4383                continue;
4384            }
4385            groups
4386                .entry(group_key)
4387                .or_default()
4388                .push((block_idx, term_idx));
4389        }
4390    }
4391
4392    for (group_key, members) in groups {
4393        if members.len() < 2 {
4394            continue;
4395        }
4396        let min_required = members
4397            .iter()
4398            .map(|&(block_idx, term_idx)| {
4399                spatial_term_min_center_count(&planned_blocks[block_idx][term_idx])
4400            })
4401            .max()
4402            .unwrap_or(1);
4403        let joint_centers = group_key
4404            .requested_num_centers
4405            .max(min_required)
4406            .min(n.max(1));
4407        let (first_block_idx, first_term_idx) = members[0];
4408        let prototype = &planned_blocks[first_block_idx][first_term_idx];
4409        let standardized = standardized_spatial_term_data(data, prototype)?;
4410        let strategy = spatial_term_center_strategy(prototype).ok_or_else(|| {
4411            BasisError::InvalidInput(format!(
4412                "term '{}' lost its spatial center strategy during joint planning",
4413                prototype.name
4414            ))
4415        })?;
4416        let joint_strategy = center_strategy_with_num_centers(strategy, joint_centers)?;
4417        let shared_centers = select_centers_by_strategy(standardized.view(), &joint_strategy)?;
4418        log::info!(
4419            "sharing {} spatial centers across {} smooth terms over columns {:?} (requested {} centers)",
4420            shared_centers.nrows(),
4421            members.len(),
4422            group_key.feature_cols,
4423            group_key.requested_num_centers,
4424        );
4425        for (block_idx, term_idx) in members {
4426            set_spatial_term_centers(
4427                &mut planned_blocks[block_idx][term_idx],
4428                shared_centers.clone(),
4429            )?;
4430        }
4431    }
4432
4433    // Sentinel auto-init: Matern and thin-plate builders write length_scale =
4434    // 0.0 when the user didn't pass `length_scale=...`. Replace those with a
4435    // data-driven initialization here so REML starts in a regime where it can
4436    // escape; the hard-coded 1.0 default was a basin from which ν ≥ 5/2 Matern
4437    // could not recover on high-frequency truths, silently collapsing the fit
4438    // to a near-constant prediction.
4439    for block in planned_blocks.iter_mut() {
4440        for term in block.iter_mut() {
4441            auto_init_length_scale_in_place(data, term);
4442        }
4443    }
4444
4445    Ok(planned_blocks)
4446}
4447
4448/// Compute a data-driven initial length scale from the per-axis range of the
4449/// feature columns. The heuristic `max_range / sqrt(n)` puts the kernel on
4450/// the wiggly side of REML's basin so the optimizer can grow it back if the
4451/// signal is smooth, but is small enough that high-frequency truths remain
4452/// reachable for smoother kernels (ν ≥ 5/2). Clamped to a tiny positive
4453/// floor so degenerate constant-input columns can't produce 0.
4454pub fn auto_initial_length_scale(data: ArrayView2<'_, f64>, feature_cols: &[usize]) -> f64 {
4455    /// Tiny positive floor for the auto length scale, guarding against a zero
4456    /// kernel range when every feature column is (near-)constant.
4457    const LENGTH_SCALE_FLOOR: f64 = 1e-6;
4458    let n = data.nrows();
4459    if n == 0 || feature_cols.is_empty() {
4460        return 1.0;
4461    }
4462    let mut max_range = 0.0_f64;
4463    for &c in feature_cols {
4464        if c >= data.ncols() {
4465            continue;
4466        }
4467        let col = data.column(c);
4468        let mut lo = f64::INFINITY;
4469        let mut hi = f64::NEG_INFINITY;
4470        for &v in col.iter() {
4471            if v.is_finite() {
4472                if v < lo {
4473                    lo = v;
4474                }
4475                if v > hi {
4476                    hi = v;
4477                }
4478            }
4479        }
4480        if hi > lo {
4481            let r = hi - lo;
4482            if r > max_range {
4483                max_range = r;
4484            }
4485        }
4486    }
4487    if !max_range.is_finite() || max_range <= 0.0 {
4488        return 1.0;
4489    }
4490    let init = max_range / (n as f64).sqrt();
4491    init.max(LENGTH_SCALE_FLOOR).min(max_range)
4492}
4493
4494/// Walk a term and, if it is a Matern or thin-plate smooth whose length_scale
4495/// was left at the auto sentinel (`0.0`), overwrite it with
4496/// [`auto_initial_length_scale`].
4497pub fn auto_init_length_scale_in_place(data: ArrayView2<'_, f64>, term: &mut SmoothTermSpec) {
4498    auto_init_length_scale_in_basis(data, &mut term.basis);
4499}
4500
4501/// Replace the `0.0` auto-init length-scale sentinel with a data-derived value
4502/// for any Matern / thin-plate kernel reachable from this basis — including the
4503/// inner kernel of a `by=`/factor-smooth wrapper.
4504///
4505/// `by=<factor>` and the sum-to-zero factor smooth wrap a spatial kernel inside
4506/// `SmoothBasisSpec::ByVariable` / `SmoothBasisSpec::FactorSumToZero` /
4507/// `SmoothBasisSpec::BySmooth`, so the wrapper variant is what the planner sees.
4508/// Without recursing into the wrapped basis the inner ThinPlate/Matern keeps the
4509/// `0.0` sentinel (the post-`1605b3a6e` builder default), which makes the kernel
4510/// distance divide by `length_scale² = 0`, producing a non-finite design at both
4511/// fit and predict time. Recurse so the inner kernel is initialized identically
4512/// to a top-level one.
4513pub fn auto_init_length_scale_in_basis(data: ArrayView2<'_, f64>, basis: &mut SmoothBasisSpec) {
4514    match basis {
4515        SmoothBasisSpec::Matern {
4516            feature_cols, spec, ..
4517        } => {
4518            if spec.length_scale == 0.0 {
4519                spec.length_scale = auto_initial_length_scale(data, feature_cols);
4520            }
4521        }
4522        SmoothBasisSpec::ThinPlate {
4523            feature_cols, spec, ..
4524        } => {
4525            if spec.length_scale == 0.0 {
4526                spec.length_scale = auto_initial_length_scale(data, feature_cols);
4527            }
4528        }
4529        SmoothBasisSpec::ByVariable { inner, .. }
4530        | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
4531            auto_init_length_scale_in_basis(data, inner);
4532        }
4533        SmoothBasisSpec::BySmooth { smooth, .. } => {
4534            auto_init_length_scale_in_basis(data, smooth);
4535        }
4536        _ => {}
4537    }
4538}
4539
/// Centering / scaling of selected design columns, plus the matching
/// transforms for penalties, coefficients, Hessians, and coefficient bounds.
///
/// Each conditioned column `j` is replaced by `(x_j − mean_j) / scale_j`.
/// Internal and original coefficients are related by
/// `β_orig[j] = β_int[j] / scale_j` and
/// `β_orig[intercept] = β_int[intercept] − Σ_j β_int[j] · mean_j / scale_j`
/// (see [`Self::backtransform_beta`]).
impl LinearFitConditioning {
    /// Derive the conditioning for `selected_cols` of `design`.
    ///
    /// The mean and the population standard deviation (`sqrt(Σ (x − mean)² / n)`)
    /// of every selected column are accumulated over row chunks. Columns whose
    /// variance is non-finite or not above `SCALE_EPS²` get the identity
    /// conditioning `(mean, scale) = (0, 1)`. With no rows or no selected
    /// columns the result has an empty column list.
    ///
    /// # Panics
    ///
    /// Panics if the design fails to produce a requested row chunk.
    pub fn from_columns(design: &TermCollectionDesign, selected_cols: &[usize]) -> Self {
        const SCALE_EPS: f64 = 1e-12;
        let n = design.design.nrows();
        let p = design.design.ncols();
        let mut columns = Vec::with_capacity(selected_cols.len());
        if n == 0 || selected_cols.is_empty() {
            return Self {
                intercept_idx: design.intercept_range.start,
                columns,
            };
        }
        let chunk_rows = gam_linalg::utils::row_chunk_for_byte_budget(n, p);
        // Two-pass mean/variance so operator-backed designs don't need to
        // materialize the full dense matrix. Pass 1 accumulates per-column
        // sums; pass 2 accumulates the sum of squared deviations from the
        // pass-1 mean. This matches the original `Σ (x − mean)² / n` formula
        // without the catastrophic cancellation of `E[X²] − E[X]²`.
        let mut sums = vec![0.0_f64; selected_cols.len()];
        for start in (0..n).step_by(chunk_rows) {
            let end = (start + chunk_rows).min(n);
            let chunk = design
                .design
                .try_row_chunk(start..end)
                .expect("LinearFitConditioning::from_columns row chunk failed");
            for (k, &col_idx) in selected_cols.iter().enumerate() {
                let column = chunk.column(col_idx);
                for &v in column.iter() {
                    sums[k] += v;
                }
            }
        }
        let inv_n = 1.0_f64 / n as f64;
        let means: Vec<f64> = sums.iter().map(|&s| s * inv_n).collect();
        let mut sq_devs = vec![0.0_f64; selected_cols.len()];
        for start in (0..n).step_by(chunk_rows) {
            let end = (start + chunk_rows).min(n);
            let chunk = design
                .design
                .try_row_chunk(start..end)
                .expect("LinearFitConditioning::from_columns row chunk failed");
            for (k, &col_idx) in selected_cols.iter().enumerate() {
                let mean_k = means[k];
                let column = chunk.column(col_idx);
                for &v in column.iter() {
                    let d = v - mean_k;
                    sq_devs[k] += d * d;
                }
            }
        }
        for (k, &col_idx) in selected_cols.iter().enumerate() {
            let mean = means[k];
            let var = sq_devs[k] * inv_n;
            // Comparing the variance to SCALE_EPS² is the same as comparing the
            // standard deviation to SCALE_EPS.
            let (mean, scale) = if var.is_finite() && var > SCALE_EPS * SCALE_EPS {
                (mean, var.sqrt())
            } else {
                // Leave nearly-constant columns untouched; centering them would collapse
                // the design column to ~0 and change the model rather than just condition it.
                (0.0, 1.0)
            };
            columns.push(LinearColumnConditioning {
                col_idx,
                mean,
                scale,
            });
        }
        Self {
            intercept_idx: design.intercept_range.start,
            columns,
        }
    }

    /// Return a copy of `design` in which every conditioned column has its
    /// mean subtracted and is then divided by its scale.
    pub fn apply_to_design(&self, design: &Array2<f64>) -> Array2<f64> {
        let mut out = design.clone();
        for col in &self.columns {
            {
                let mut dst = out.column_mut(col.col_idx);
                dst -= col.mean;
            }
            if col.scale != 1.0 {
                out.column_mut(col.col_idx).mapv_inplace(|v| v / col.scale);
            }
        }
        out
    }

    /// Column-side half of the penalty transform: for each conditioned column
    /// `j`, `col_j ← (col_j − mean_j · col_intercept) / scale_j`.
    fn transform_matrix_columnswith_a(&self, mat: &Array2<f64>) -> Array2<f64> {
        let mut out = mat.clone();
        let intercept = self.intercept_idx;
        for col in &self.columns {
            let intercept_col = out.column(intercept).to_owned();
            let mut target = out.column_mut(col.col_idx);
            target -= &(intercept_col * col.mean);
            if col.scale != 1.0 {
                target.mapv_inplace(|v| v / col.scale);
            }
        }
        out
    }

    /// Row-side half of the penalty transform: for each conditioned column
    /// `j`, `row_j ← (row_j − mean_j · row_intercept) / scale_j`.
    fn transform_matrixrowswith_a_transpose(&self, mat: &Array2<f64>) -> Array2<f64> {
        let mut out = mat.clone();
        let intercept = self.intercept_idx;
        for col in &self.columns {
            let interceptrow = out.row(intercept).to_owned();
            let mut target = out.row_mut(col.col_idx);
            target -= &(interceptrow * col.mean);
            if col.scale != 1.0 {
                target.mapv_inplace(|v| v / col.scale);
            }
        }
        out
    }

    /// Left-multiply `mat_internal` by `M⁻ᵀ` where `M⁻¹[intercept, j] = mean_j`
    /// and `M⁻¹[j, j] = scale_j` for each conditioned column. Used together
    /// with [`Self::right_multiply_by_m_inv`] to back-transform an internal
    /// penalized Hessian to the original coefficient basis.
    fn left_multiply_by_m_inv_transpose(&self, mat_internal: &Array2<f64>) -> Array2<f64> {
        let mut out = mat_internal.clone();
        let intercept = self.intercept_idx;
        // The intercept row is read from the untouched input, not from `out`.
        let interceptrow_snapshot = mat_internal.row(intercept).to_owned();
        for col in &self.columns {
            if col.scale != 1.0 {
                out.row_mut(col.col_idx).mapv_inplace(|v| v * col.scale);
            }
            if col.mean != 0.0 {
                let mut target = out.row_mut(col.col_idx);
                target += &(&interceptrow_snapshot * col.mean);
            }
        }
        out
    }

    /// Right-multiply `mat_internal` by `M⁻¹`. Mirror of
    /// [`Self::left_multiply_by_m_inv_transpose`] on columns.
    fn right_multiply_by_m_inv(&self, mat_internal: &Array2<f64>) -> Array2<f64> {
        let mut out = mat_internal.clone();
        let intercept = self.intercept_idx;
        // The intercept column is read from the untouched input, not from `out`.
        let intercept_col_snapshot = mat_internal.column(intercept).to_owned();
        for col in &self.columns {
            if col.scale != 1.0 {
                out.column_mut(col.col_idx).mapv_inplace(|v| v * col.scale);
            }
            if col.mean != 0.0 {
                let mut target = out.column_mut(col.col_idx);
                target += &(&intercept_col_snapshot * col.mean);
            }
        }
        out
    }

    /// Transform blockwise penalties through the conditioning.
    ///
    /// For block-local penalties whose `col_range` does not overlap with any
    /// conditioning column, the transform is identity (the conditioning only
    /// affects unpenalized linear columns). In that common case the penalty
    /// passes through unchanged, avoiding O(p²) materialization entirely.
    ///
    /// `p` is the size the penalty is expanded to when the dense fallback is
    /// taken.
    pub fn transform_blockwise_penalties_to_internal(
        &self,
        penalties: &[BlockwisePenalty],
        p: usize,
    ) -> Vec<crate::penalty_spec::PenaltySpec> {
        let conditioning_cols: std::collections::HashSet<usize> =
            self.columns.iter().map(|c| c.col_idx).collect();
        penalties
            .iter()
            .map(|bp| {
                let overlaps =
                    (bp.col_range.start..bp.col_range.end).any(|j| conditioning_cols.contains(&j));
                if overlaps {
                    // Rare: penalty block overlaps conditioning columns.
                    // Fall back to dense transform.
                    let global = bp.to_global(p);
                    let right = self.transform_matrix_columnswith_a(&global);
                    let transformed = self.transform_matrixrowswith_a_transpose(&right);
                    crate::penalty_spec::PenaltySpec::Dense(transformed)
                } else {
                    // Common: smooth penalty block doesn't touch linear columns.
                    // The conditioning is identity on this block.
                    crate::penalty_spec::PenaltySpec::from_blockwise(bp.clone())
                }
            })
            .collect()
    }

    /// Map internal (conditioned) coefficients back to the original basis:
    /// each conditioned coefficient is divided by its scale, and the intercept
    /// absorbs `−β_int[j] · mean_j / scale_j` for every conditioned column.
    pub fn backtransform_beta(&self, beta_internal: &Array1<f64>) -> Array1<f64> {
        let mut beta = beta_internal.clone();
        let intercept = self.intercept_idx;
        for col in &self.columns {
            beta[intercept] -= beta_internal[col.col_idx] * col.mean / col.scale;
            beta[col.col_idx] = beta_internal[col.col_idx] / col.scale;
        }
        beta
    }

    /// `H_orig = M⁻ᵀ · H_int · M⁻¹`, derived from
    /// `L_int(β_int) = L_orig(M · β_int)` via the chain rule.
    pub fn transform_penalized_hessian_to_original(&self, h_internal: &Array2<f64>) -> Array2<f64> {
        let right = self.right_multiply_by_m_inv(h_internal);
        self.left_multiply_by_m_inv_transpose(&right)
    }

    /// Convert original-scale coefficient bounds for column `col_idx` to the
    /// internal scale by multiplying both by the column's scale; columns that
    /// are not conditioned keep their bounds unchanged.
    pub fn internal_bounds_for(&self, col_idx: usize, min: f64, max: f64) -> (f64, f64) {
        if let Some(col) = self.columns.iter().find(|c| c.col_idx == col_idx) {
            (min * col.scale, max * col.scale)
        } else {
            (min, max)
        }
    }
}
4751
4752pub fn freeze_raw_spatial_metadata(metadata: BasisMetadata, raw_cols: usize) -> BasisMetadata {
4753    match metadata {
4754        BasisMetadata::ThinPlate {
4755            centers,
4756            length_scale,
4757            periodic,
4758            identifiability_transform: None,
4759            input_scales,
4760            radial_reparam,
4761        } => BasisMetadata::ThinPlate {
4762            centers,
4763            length_scale,
4764            periodic,
4765            identifiability_transform: Some(Array2::eye(raw_cols)),
4766            input_scales,
4767            radial_reparam,
4768        },
4769        BasisMetadata::Duchon {
4770            centers,
4771            length_scale,
4772            periodic,
4773            power,
4774            nullspace_order,
4775            identifiability_transform: None,
4776            input_scales,
4777            aniso_log_scales,
4778            operator_collocation_points,
4779            radial_reparam,
4780        } => BasisMetadata::Duchon {
4781            centers,
4782            length_scale,
4783            periodic,
4784            power,
4785            nullspace_order,
4786            identifiability_transform: Some(Array2::eye(raw_cols)),
4787            input_scales,
4788            aniso_log_scales,
4789            operator_collocation_points,
4790            radial_reparam,
4791        },
4792        other => other,
4793    }
4794}
4795
4796pub fn matern_operator_penalty_triplet_from_metadata(
4797    metadata: &BasisMetadata,
4798) -> Result<(Vec<Array2<f64>>, Vec<usize>, Vec<PenaltyInfo>), BasisError> {
4799    let BasisMetadata::Matern {
4800        centers,
4801        length_scale,
4802        periodic,
4803        nu,
4804        include_intercept,
4805        identifiability_transform,
4806        aniso_log_scales,
4807        input_scales,
4808        ..
4809    } = metadata
4810    else {
4811        crate::bail_invalid_basis!("Matérn operator penalties require Matérn metadata");
4812    };
4813    // The metadata records `length_scale` in *original* (un-standardized) data
4814    // coordinates, while `centers` live in the *standardized* coordinate frame
4815    // (per-axis division by `input_scales`). The realized design built the
4816    // kernel against those standardized centers using the σ_geom-compensated
4817    // effective length scale `length_scale / σ_geom`. The collocation operators
4818    // here are evaluated on the same standardized centers, so they must use the
4819    // SAME effective length scale — otherwise the penalty regularizes a
4820    // different RKHS range than the design lives in, leaving rough coefficient
4821    // directions effectively unpenalized. That mismatch is benign in 1-D
4822    // (no standardization) but produces a catastrophic out-of-sample blow-up in
4823    // d ≥ 2 where σ_geom ≠ 1 (#706).
4824    let penalty_length_scale = match input_scales.as_deref() {
4825        Some(scales) => compensate_length_scale_for_standardization(*length_scale, scales),
4826        None => *length_scale,
4827    };
4828    matern_operator_penalty_triplet_at_length_scale(
4829        centers.view(),
4830        periodic.as_deref(),
4831        identifiability_transform.as_ref(),
4832        *nu,
4833        *include_intercept,
4834        aniso_log_scales.as_deref(),
4835        penalty_length_scale,
4836    )
4837}
4838
4839/// Build the canonical Matérn operator-penalty triplet (mass / tension /
4840/// stiffness) at an explicit **effective** length scale — i.e. the
4841/// σ_geom-compensated, standardized-frame scale the design's kernel was built
4842/// against (NOT the original-coordinate `length_scale` stored in metadata).
4843///
4844/// This is the SINGLE source of truth for the Matérn penalty topology. Two
4845/// callers route through it and must therefore stay byte-for-byte consistent:
4846///   * the cold/slow design rebuild (`matern_operator_penalty_triplet_from_metadata`,
4847///     compensating the frozen metadata `length_scale`), and
4848///   * the n-free κ-optimizer re-key (`FrozenTermCollectionIncrementalRealizer::
4849///     canonical_penalties_at_psi`, compensating the trial `ψ → exp(-ψ)` scale).
4850///
4851/// Sharing the body makes the penalty BLOCK COUNT and the per-block numerics
4852/// one deterministic function of `(geometry, ν, η, ℓ_eff)`. The active-operator
4853/// gate is `m = ν + d/2`, which is independent of ℓ, so the block count is
4854/// **ψ-stable by construction**: the re-key can never produce a different number
4855/// of blocks than the frozen design (the desync that #1270 hard-errored on).
4856pub fn matern_operator_penalty_triplet_at_length_scale(
4857    centers: ArrayView2<'_, f64>,
4858    periodic: Option<&[Option<f64>]>,
4859    identifiability_transform: Option<&Array2<f64>>,
4860    nu: crate::basis::MaternNu,
4861    include_intercept: bool,
4862    aniso_log_scales: Option<&[f64]>,
4863    effective_length_scale: f64,
4864) -> Result<(Vec<Array2<f64>>, Vec<usize>, Vec<PenaltyInfo>), BasisError> {
4865    let penalty_centers = crate::basis::expand_periodic_centers(&centers.to_owned(), periodic)?;
4866    let ops = build_matern_collocation_operator_matrices(
4867        penalty_centers.view(),
4868        None,
4869        effective_length_scale,
4870        nu,
4871        include_intercept,
4872        identifiability_transform.map(|z| z.view()),
4873        aniso_log_scales,
4874    )?;
4875    // Gate the operator dials on the Matérn-ν RKHS Sobolev order m = ν + d/2:
4876    // mass (j=0) is always on, tension (j=1) is on for m > 1, stiffness (j=2)
4877    // is on for m > 2. The threshold is strict so the roughest kernel ν=1/2 in
4878    // d=1 (m=1, the exponential/OU H¹ process) sheds both higher operators —
4879    // its kernel already encodes the H¹ control, so adding an extra tension
4880    // dial over-smooths the oscillation it is meant to track (#707). The
4881    // matching gate lives at `DuchonOperatorPenaltySpec::matern_for_smoothness`.
4882    const ORDER_EPS: f64 = 1e-9;
4883    let d = penalty_centers.ncols();
4884    let m = nu.half_integer_value() + 0.5 * d as f64;
4885    let mut candidates = Vec::with_capacity(3);
4886    for (raw, source, min_order) in [
4887        (ops.d0.t().dot(&ops.d0), PenaltySource::OperatorMass, 0.0),
4888        (ops.d1.t().dot(&ops.d1), PenaltySource::OperatorTension, 1.0),
4889        (
4890            ops.d2.t().dot(&ops.d2),
4891            PenaltySource::OperatorStiffness,
4892            2.0,
4893        ),
4894    ] {
4895        if min_order > 0.0 && m <= min_order + ORDER_EPS {
4896            continue;
4897        }
4898        let sym = (&raw + &raw.t()) * 0.5;
4899        let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&sym);
4900        candidates.push(PenaltyCandidate {
4901            matrix,
4902            nullspace_dim_hint: 0,
4903            source,
4904            normalization_scale,
4905            kronecker_factors: None,
4906            op: None,
4907        });
4908    }
4909    filter_active_penalty_candidates(candidates)
4910}
4911
4912pub fn normalize_penalty_in_constrained_space(matrix: &Array2<f64>) -> (Array2<f64>, f64) {
4913    // Constrained-space normalization:
4914    //   c = ||S_con||_F,  S_tilde = S_con / c.
4915    // This is the only normalization coherent with a REML objective that is
4916    // evaluated entirely in constrained coordinates.
4917    let matrix = (matrix + &matrix.t().to_owned()) * 0.5;
4918    // Clamp noise-floor negative eigenvalues so β'Sβ is non-negative as a contract, not just in exact arithmetic.
4919    let matrix = crate::basis::project_penalty_to_psd_cone(&matrix);
4920    let c = matrix.iter().map(|v| v * v).sum::<f64>().sqrt();
4921    if c.is_finite() && c > 0.0 {
4922        (matrix.mapv(|v| v / c), c)
4923    } else {
4924        (matrix, 1.0)
4925    }
4926}
4927
4928pub fn tensor_product_design_from_sparse_marginals(
4929    marginal_sparse: &[&SparseColMat<usize, f64>],
4930) -> Result<SparseColMat<usize, f64>, BasisError> {
4931    if marginal_sparse.is_empty() {
4932        crate::bail_invalid_basis!("TensorBSpline requires at least one marginal basis");
4933    }
4934    let n = marginal_sparse[0].nrows();
4935    for (i, m) in marginal_sparse.iter().enumerate().skip(1) {
4936        if m.nrows() != n {
4937            crate::bail_dim_basis!(
4938                "tensor sparse marginal row mismatch at dim {i}: expected {n}, got {}",
4939                m.nrows()
4940            );
4941        }
4942    }
4943    let dims: Vec<usize> = marginal_sparse.iter().map(|m| m.ncols()).collect();
4944    let total_cols = dims.iter().try_fold(1usize, |acc, &q| {
4945        acc.checked_mul(q)
4946            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))
4947    })?;
4948    let mut strides = vec![1usize; dims.len()];
4949    for d in (0..dims.len().saturating_sub(1)).rev() {
4950        strides[d] = strides[d + 1]
4951            .checked_mul(dims[d + 1])
4952            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))?;
4953    }
4954
4955    use faer::sparse::SparseRowMat;
4956    let csrs: Vec<SparseRowMat<usize, f64>> = marginal_sparse
4957        .iter()
4958        .enumerate()
4959        .map(|(d, m)| {
4960            m.as_ref().to_row_major().map_err(|e| {
4961                BasisError::SparseCreation(format!(
4962                    "tensor sparse marginal {d} CSR conversion failed: {e:?}"
4963                ))
4964            })
4965        })
4966        .collect::<Result<Vec<_>, _>>()?;
4967    let row_ptrs: Vec<&[usize]> = csrs.iter().map(|c| c.symbolic().row_ptr()).collect();
4968    let col_idxs: Vec<&[usize]> = csrs.iter().map(|c| c.symbolic().col_idx()).collect();
4969    let vals: Vec<&[f64]> = csrs.iter().map(|c| c.val()).collect();
4970
4971    use rayon::prelude::*;
4972    const CHUNK: usize = 1024;
4973    let num_chunks = n.div_ceil(CHUNK);
4974    let per_chunk: Vec<Vec<Triplet<usize, usize, f64>>> = (0..num_chunks)
4975        .into_par_iter()
4976        .map(|chunk_idx| {
4977            let row_start = chunk_idx * CHUNK;
4978            let row_end = (row_start + CHUNK).min(n);
4979            let mut chunk_triplets = Vec::<Triplet<usize, usize, f64>>::new();
4980            let mut cur_cols = Vec::<usize>::with_capacity(64);
4981            let mut cur_vals = Vec::<f64>::with_capacity(64);
4982            let mut next_cols = Vec::<usize>::with_capacity(64);
4983            let mut next_vals = Vec::<f64>::with_capacity(64);
4984            for i in row_start..row_end {
4985                cur_cols.clear();
4986                cur_vals.clear();
4987                cur_cols.push(0);
4988                cur_vals.push(1.0);
4989                let mut row_is_zero = false;
4990                for d in 0..dims.len() {
4991                    let row_start_d = row_ptrs[d][i];
4992                    let row_end_d = row_ptrs[d][i + 1];
4993                    if row_start_d == row_end_d {
4994                        row_is_zero = true;
4995                        break;
4996                    }
4997                    let stride = strides[d];
4998                    next_cols.clear();
4999                    next_vals.clear();
5000                    next_cols.reserve(cur_cols.len() * (row_end_d - row_start_d));
5001                    next_vals.reserve(cur_vals.len() * (row_end_d - row_start_d));
5002                    for (&prev_col, &prev_val) in cur_cols.iter().zip(cur_vals.iter()) {
5003                        for ptr in row_start_d..row_end_d {
5004                            let cj = col_idxs[d][ptr];
5005                            let vj = vals[d][ptr];
5006                            next_cols.push(prev_col + cj * stride);
5007                            next_vals.push(prev_val * vj);
5008                        }
5009                    }
5010                    std::mem::swap(&mut cur_cols, &mut next_cols);
5011                    std::mem::swap(&mut cur_vals, &mut next_vals);
5012                }
5013                if row_is_zero {
5014                    continue;
5015                }
5016                for (&col, &val) in cur_cols.iter().zip(cur_vals.iter()) {
5017                    chunk_triplets.push(Triplet::new(i, col, val));
5018                }
5019            }
5020            chunk_triplets
5021        })
5022        .collect();
5023    let total_nnz: usize = per_chunk.iter().map(Vec::len).sum();
5024    let mut triplets = Vec::<Triplet<usize, usize, f64>>::with_capacity(total_nnz);
5025    for chunk in per_chunk {
5026        triplets.extend(chunk);
5027    }
5028    SparseColMat::try_new_from_triplets(n, total_cols, &triplets).map_err(|e| {
5029        BasisError::SparseCreation(format!(
5030            "failed to assemble sparse tensor product design: {e:?}"
5031        ))
5032    })
5033}
5034
5035pub fn dense_local_margin_to_sparse(
5036    dense: &Array2<f64>,
5037) -> Result<SparseColMat<usize, f64>, BasisError> {
5038    let expected_row_nnz = dense.ncols().min(4);
5039    let mut triplets =
5040        Vec::<Triplet<usize, usize, f64>>::with_capacity(dense.nrows() * expected_row_nnz);
5041    for ((row, col), &value) in dense.indexed_iter() {
5042        if value != 0.0 {
5043            triplets.push(Triplet::new(row, col, value));
5044        }
5045    }
5046    SparseColMat::try_new_from_triplets(dense.nrows(), dense.ncols(), &triplets).map_err(|e| {
5047        BasisError::SparseCreation(format!(
5048            "failed to convert tensor marginal design to sparse form: {e:?}"
5049        ))
5050    })
5051}
5052
5053pub struct TensorMarginRangeNullProjectors {
5054    range: Array2<f64>,
5055    null: Array2<f64>,
5056}
5057
5058pub fn projector_from_columns(columns: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
5059    if indices.is_empty() {
5060        return Array2::<f64>::zeros((columns.nrows(), columns.nrows()));
5061    }
5062    let basis = columns.select(Axis(1), indices);
5063    basis.dot(&basis.t())
5064}
5065
5066pub fn tensor_margin_range_null_projectors(
5067    normalized_marginal_penalties: &[(Array2<f64>, f64)],
5068) -> Result<Vec<TensorMarginRangeNullProjectors>, BasisError> {
5069    normalized_marginal_penalties
5070        .iter()
5071        .enumerate()
5072        .map(|(dim, (penalty, _))| {
5073            let analysis = crate::basis::analyze_penalty_block(penalty)?;
5074            if analysis.rank == 0 {
5075                crate::bail_invalid_basis!(
5076                    "t2 separable tensor penalty margin {dim} has rank-zero penalty; \
5077                     cannot split penalized and null subspaces"
5078                );
5079            }
5080            let mut range_idx = Vec::<usize>::new();
5081            let mut null_idx = Vec::<usize>::new();
5082            for (idx, &ev) in analysis.eigenvalues.iter().enumerate() {
5083                if ev > analysis.tol {
5084                    range_idx.push(idx);
5085                } else {
5086                    null_idx.push(idx);
5087                }
5088            }
5089            Ok(TensorMarginRangeNullProjectors {
5090                range: projector_from_columns(&analysis.eigenvectors, &range_idx),
5091                null: projector_from_columns(&analysis.eigenvectors, &null_idx),
5092            })
5093        })
5094        .collect()
5095}
5096
5097pub fn build_tensor_bspline_basis(
5098    data: ArrayView2<'_, f64>,
5099    feature_cols: &[usize],
5100    spec: &TensorBSplineSpec,
5101) -> Result<BasisBuildResult, BasisError> {
5102    if feature_cols.is_empty() {
5103        crate::bail_invalid_basis!("TensorBSpline requires at least one feature column");
5104    }
5105    if feature_cols.len() != spec.marginalspecs.len() {
5106        crate::bail_dim_basis!(
5107            "TensorBSpline feature/spec mismatch: feature_cols={}, marginalspecs={}",
5108            feature_cols.len(),
5109            spec.marginalspecs.len()
5110        );
5111    }
5112    if !spec.periods.is_empty() && spec.periods.len() != feature_cols.len() {
5113        crate::bail_dim_basis!(
5114            "TensorBSpline periods length {} does not match feature count {}",
5115            spec.periods.len(),
5116            feature_cols.len()
5117        );
5118    }
5119    let p = data.ncols();
5120    for &c in feature_cols {
5121        if c >= p {
5122            crate::bail_dim_basis!(
5123                "tensor feature column {c} is out of bounds for data with {p} columns"
5124            );
5125        }
5126    }
5127
5128    let mut marginal_knots = Vec::<Array1<f64>>::with_capacity(feature_cols.len());
5129    // Per-margin cr flag (#1074): `true` when the margin is a natural cubic
5130    // regression spline, so the tensor freeze rebuilds the cr knotspec.
5131    let mut marginal_is_cr_flags = Vec::<bool>::with_capacity(feature_cols.len());
5132    let mut marginal_degrees = Vec::<usize>::with_capacity(feature_cols.len());
5133    let mut marginalnum_basis = Vec::<usize>::with_capacity(feature_cols.len());
5134    let mut marginal_penalties = Vec::<Array2<f64>>::with_capacity(feature_cols.len());
5135    let mut marginal_designs = Vec::<Array2<f64>>::with_capacity(feature_cols.len());
5136    // Per-margin effective period: either user-set via `spec.periods` or
5137    // implied by a `PeriodicUniform` marginal knotspec (which the 1D B-spline
5138    // builder realizes as a cyclic B-spline basis).
5139    // Captured here so freeze→reload round-trips both routes back to a
5140    // `PeriodicUniform` marginal knotspec; otherwise a `PeriodicUniform`
5141    // margin specified without `spec.periods` would freeze as a plain
5142    // `Provided(knots)` open spline and lose its wrap-around at predict time.
5143    let mut marginal_effective_periods = Vec::<Option<f64>>::with_capacity(feature_cols.len());
5144    // Per-marginal sparse representation, populated when the 1D builder returned
5145    // a `DesignMatrix::Sparse`. Used to assemble the Khatri-Rao tensor product
5146    // sparsely (only ∏(degree+1) nonzeros per row) instead of densifying to
5147    // shape (n, ∏ q_j) up front. Periodic B-spline margins are local-support
5148    // bases too; when the 1D builder returns them densely, we convert that
5149    // marginal back to sparse form so cylinder/torus tensor products keep the
5150    // same scale behavior as open tensor products.
5151    let mut marginal_sparse =
5152        Vec::<Option<SparseColMat<usize, f64>>>::with_capacity(feature_cols.len());
5153
5154    // Reuse the robust 1D builder to ensure the same knot validation and
5155    // marginal difference-penalty construction as standalone smooth terms.
5156    for (dim, (&col, marginalspec)) in feature_cols
5157        .iter()
5158        .zip(spec.marginalspecs.iter())
5159        .enumerate()
5160    {
5161        // Tensor basis uses raw marginal knot-product columns. Applying 1D
5162        // identifiability constraints here would change marginal penalty sizes
5163        // without changing the tensor design construction, causing dimension
5164        // mismatch. Keep marginal builders unconstrained at this stage.
5165        let mut marginal_unconstrained = marginalspec.clone();
5166        marginal_unconstrained.identifiability = BSplineIdentifiability::None;
5167        let built = build_bspline_basis_1d(data.column(col), &marginal_unconstrained)?;
5168        // A cr (`NaturalCubicRegression`) margin emits `CubicRegression1D`
5169        // metadata whose `knots` are the k value-knots; a B-spline margin emits
5170        // `BSpline1D` with the clamped knot vector. Capture either so the
5171        // tensor freeze can rebuild the exact same marginal knotspec (#1074).
5172        let (knots, marginal_is_cr) = match built.metadata {
5173            BasisMetadata::BSpline1D { knots, .. } => (knots, false),
5174            BasisMetadata::CubicRegression1D { knots, .. } => (knots, true),
5175            _ => {
5176                crate::bail_invalid_basis!(
5177                    "internal TensorBSpline error at dim {dim}: expected BSpline1D or CubicRegression1D metadata"
5178                );
5179            }
5180        };
5181        let metadata_knots = match marginalspec.knotspec {
5182            BSplineKnotSpec::PeriodicUniform {
5183                data_range,
5184                num_basis,
5185            } => Array1::linspace(data_range.0, data_range.1, num_basis),
5186            _ => knots,
5187        };
5188        marginal_knots.push(metadata_knots);
5189        marginal_is_cr_flags.push(marginal_is_cr);
5190        marginal_degrees.push(marginalspec.degree);
5191        marginalnum_basis.push(built.design.ncols());
5192        // Capture the sparse representation of this marginal (when the
5193        // 1D builder produced one) before densifying for the dense
5194        // marginal cache used by `tensor_product_design_from_marginals`
5195        // and `TensorProductDesignOperator`.
5196        let dense_marginal = built.design.to_dense();
5197        let sparse_view: Option<SparseColMat<usize, f64>> = match built.design.as_sparse() {
5198            Some(sd) => {
5199                let inner: &SparseColMat<usize, f64> = sd;
5200                Some(inner.clone())
5201            }
5202            None => match marginalspec.knotspec {
5203                BSplineKnotSpec::PeriodicUniform { .. } => {
5204                    Some(dense_local_margin_to_sparse(&dense_marginal)?)
5205                }
5206                _ => None,
5207            },
5208        };
5209        marginal_sparse.push(sparse_view);
5210        marginal_designs.push(dense_marginal);
5211        marginal_penalties.push(
5212            built
5213                .penalties
5214                .first()
5215                .ok_or_else(|| {
5216                    BasisError::InvalidInput(format!(
5217                        "internal TensorBSpline error at dim {dim}: missing marginal penalty"
5218                    ))
5219                })?
5220                .clone(),
5221        );
5222        built.nullspace_dims.first().ok_or_else(|| {
5223            BasisError::InvalidInput(format!(
5224                "internal TensorBSpline error at dim {dim}: missing marginal nullspace dim"
5225            ))
5226        })?;
5227        // A `PeriodicUniform` marginal knotspec implies the margin is
5228        // wrap-around: the 1D builder already realized it as a periodic
5229        // basis, so the tensor product inherits that periodicity. Record
5230        // the period derived from the knotspec's data range so freeze
5231        // restores `PeriodicUniform` on the marginal — otherwise the
5232        // round-trip downgrades it to `Provided(knots)` (an open spline)
5233        // and predict-time wraps disappear.
5234        let implied_period = match marginalspec.knotspec {
5235            BSplineKnotSpec::PeriodicUniform { data_range, .. } => {
5236                Some(data_range.1 - data_range.0)
5237            }
5238            _ => spec.periods.get(dim).and_then(|p| *p),
5239        };
5240        marginal_effective_periods.push(implied_period);
5241    }
5242
5243    let total_cols: usize = marginalnum_basis.iter().product();
5244    let mut dense_design = (!matches!(spec.identifiability, TensorBSplineIdentifiability::None))
5245        .then(|| tensor_product_design_from_marginals(&marginal_designs))
5246        .transpose()?;
5247    let mut candidates = Vec::<PenaltyCandidate>::with_capacity(
5248        match spec.penalty_decomposition {
5249            TensorBSplinePenaltyDecomposition::MarginalKroneckerSum => marginal_penalties.len(),
5250            TensorBSplinePenaltyDecomposition::Separable => marginal_penalties.len() * 2,
5251        } + if spec.double_penalty { 1 } else { 0 },
5252    );
5253
5254    // Tensor-product smoothing parameters are one-per-margin.  Therefore the
5255    // physical penalty attached to a margin must be normalized in that margin's
5256    // own working coordinates before it is embedded in the full tensor product.
5257    // Normalizing only the already-Kroneckered matrix would fold arbitrary
5258    // dimension-dependent identity factors into the margin's lambda and would
5259    // make anisotropic REML/LAML smoothing depend on the other margins' basis
5260    // sizes rather than on the marginal roughness operator itself.
5261    let normalized_marginal_penalties: Vec<(Array2<f64>, f64)> = marginal_penalties
5262        .iter()
5263        .map(normalize_penalty_in_constrained_space)
5264        .collect();
5265    let mut kronecker_marginal_penalties =
5266        Vec::<Array2<f64>>::with_capacity(normalized_marginal_penalties.len());
5267
5268    match spec.penalty_decomposition {
5269        TensorBSplinePenaltyDecomposition::MarginalKroneckerSum => {
5270            // Accumulate the Kronecker-sum of the per-margin penalties,
5271            // `Σ_dim S_dim`, whose null space is exactly the *joint* null space
5272            // of all marginal penalties — the tensor of marginal polynomial
5273            // null spaces. The tensor double penalty (below) shrinks only this
5274            // joint null, never the already-penalized interaction range.
5275            let mut marginal_kron_sum = Array2::<f64>::zeros((total_cols, total_cols));
5276
5277            for dim in 0..normalized_marginal_penalties.len() {
5278                let mut s_dim = Array2::<f64>::eye(1);
5279                let mut factors = Vec::<Array2<f64>>::with_capacity(marginalnum_basis.len());
5280                for (j, &qj) in marginalnum_basis.iter().enumerate() {
5281                    let factor = if j == dim {
5282                        normalized_marginal_penalties[j].0.clone()
5283                    } else {
5284                        Array2::<f64>::eye(qj)
5285                    };
5286                    factors.push(factor.clone());
5287                    s_dim = kronecker_product(&s_dim, &factor);
5288                }
5289                if dim == kronecker_marginal_penalties.len() {
5290                    kronecker_marginal_penalties.push(normalized_marginal_penalties[dim].0.clone());
5291                }
5292                marginal_kron_sum += &s_dim;
5293
5294                candidates.push(PenaltyCandidate {
5295                    matrix: s_dim,
5296                    nullspace_dim_hint: 0,
5297                    source: PenaltySource::TensorMarginal { dim },
5298                    normalization_scale: normalized_marginal_penalties[dim].1,
5299                    kronecker_factors: Some(factors),
5300                    op: None,
5301                });
5302            }
5303
5304            if spec.double_penalty
5305                && let Some(shrink) =
5306                    crate::basis::build_nullspace_shrinkage_penalty(&marginal_kron_sum)?
5307            {
5308                let (matrix, normalization_scale) =
5309                    normalize_penalty_in_constrained_space(&shrink.sym_penalty);
5310                candidates.push(PenaltyCandidate {
5311                    matrix,
5312                    nullspace_dim_hint: 0,
5313                    source: PenaltySource::TensorGlobalRidge,
5314                    normalization_scale,
5315                    kronecker_factors: None,
5316                    op: None,
5317                });
5318            }
5319        }
5320        TensorBSplinePenaltyDecomposition::Separable => {
5321            let projectors = tensor_margin_range_null_projectors(&normalized_marginal_penalties)?;
5322            let n_masks = 1usize.checked_shl(projectors.len() as u32).ok_or_else(|| {
5323                BasisError::InvalidInput(format!(
5324                    "t2 separable tensor penalty supports at most {} margins, got {}",
5325                    usize::BITS - 1,
5326                    projectors.len()
5327                ))
5328            })?;
5329            for mask in 1..n_masks {
5330                let mut matrix = Array2::<f64>::eye(1);
5331                let mut factors = Vec::<Array2<f64>>::with_capacity(projectors.len());
5332                let mut penalized_margins = Vec::<usize>::new();
5333                for (dim, projector) in projectors.iter().enumerate() {
5334                    let use_range = ((mask >> dim) & 1) == 1;
5335                    let factor = if use_range {
5336                        penalized_margins.push(dim);
5337                        projector.range.clone()
5338                    } else {
5339                        projector.null.clone()
5340                    };
5341                    matrix = kronecker_product(&matrix, &factor);
5342                    factors.push(factor);
5343                }
5344                let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&matrix);
5345                candidates.push(PenaltyCandidate {
5346                    matrix,
5347                    nullspace_dim_hint: 0,
5348                    source: PenaltySource::TensorSeparable { penalized_margins },
5349                    normalization_scale,
5350                    kronecker_factors: Some(factors),
5351                    op: None,
5352                });
5353            }
5354
5355            if spec.double_penalty {
5356                let mut matrix = Array2::<f64>::eye(1);
5357                let mut factors = Vec::<Array2<f64>>::with_capacity(projectors.len());
5358                for projector in &projectors {
5359                    matrix = kronecker_product(&matrix, &projector.null);
5360                    factors.push(projector.null.clone());
5361                }
5362                let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&matrix);
5363                candidates.push(PenaltyCandidate {
5364                    matrix,
5365                    nullspace_dim_hint: 0,
5366                    source: PenaltySource::TensorGlobalRidge,
5367                    normalization_scale,
5368                    kronecker_factors: Some(factors),
5369                    op: None,
5370                });
5371            }
5372        }
5373    }
5374
5375    let z_opt = match &spec.identifiability {
5376        TensorBSplineIdentifiability::None => None,
5377        TensorBSplineIdentifiability::SumToZero => {
5378            if total_cols < 2 {
5379                crate::bail_invalid_basis!(
5380                    "TensorBSpline requires at least 2 basis coefficients to enforce sum-to-zero identifiability"
5381                );
5382            }
5383            let dense_design_ref = dense_design.as_ref().ok_or_else(|| {
5384                BasisError::InvalidInput(
5385                    "tensor sum-to-zero identifiability requires a realized basis".to_string(),
5386                )
5387            })?;
5388            let (_, z) = apply_sum_to_zero_constraint(dense_design_ref.view(), None)?;
5389            let gauge = gam_problem::Gauge::sum_to_zero(z);
5390            Some(gauge.block_transform(0))
5391        }
5392        TensorBSplineIdentifiability::MarginalSumToZero => {
5393            // `ti(...)`: drop the marginal main effects by centering every
5394            // margin independently, then form the tensor product of the
5395            // centered margins. Concretely, each margin `j` is reparameterized
5396            // by its own sum-to-zero null basis `Z_j` (so the constant — i.e.
5397            // the marginal intercept — is removed from that axis), and the
5398            // combined reparameterization is the Kronecker product
5399            // `Z = Z₀ ⊗ Z₁ ⊗ … ⊗ Z_{d-1}`. Applying `Z` to the full-tensor
5400            // design `B = B₀ ⊗ … ⊗ B_{d-1}` yields `B Z = (B₀ Z₀) ⊗ … ⊗
5401            // (B_{d-1} Z_{d-1})`, the tensor product of the centered margins,
5402            // which by construction contains no pure main effect.
5403            if marginal_designs.len() < 2 {
5404                crate::bail_invalid_basis!(
5405                    "tensor interaction (ti) identifiability requires at least 2 margins"
5406                );
5407            }
5408            let mut z = Array2::<f64>::eye(1);
5409            for (dim, marginal) in marginal_designs.iter().enumerate() {
5410                if marginal.ncols() < 2 {
5411                    crate::bail_invalid_basis!(
5412                        "tensor interaction (ti) margin {dim} has fewer than 2 basis functions; \
5413                         cannot remove its marginal main effect"
5414                    );
5415                }
5416                let (_, z_dim) = apply_sum_to_zero_constraint(marginal.view(), None)?;
5417                let gauge_dim = gam_problem::Gauge::sum_to_zero(z_dim);
5418                let z_dim = gauge_dim.block_transform(0);
5419                z = kronecker_product(&z, &z_dim);
5420            }
5421            Some(z)
5422        }
5423        TensorBSplineIdentifiability::FrozenTransform { transform } => {
5424            if transform.nrows() != total_cols {
5425                crate::bail_dim_basis!(
5426                    "frozen tensor identifiability transform mismatch: design has {} columns but transform has {} rows",
5427                    total_cols,
5428                    transform.nrows()
5429                );
5430            }
5431            Some(transform.clone())
5432        }
5433    };
5434
5435    if let Some(z) = z_opt.as_ref() {
5436        let gauge = gam_problem::Gauge::from_block_transforms(&[z.clone()]);
5437        let dense = dense_design.as_mut().ok_or_else(|| {
5438            BasisError::InvalidInput(
5439                "tensor identifiability transform requires a realized basis".to_string(),
5440            )
5441        })?;
5442        let restricted_design = gauge.restrict_design(dense);
5443        *dense = restricted_design;
5444        candidates = candidates
5445            .into_iter()
5446            .map(|candidate| -> Result<PenaltyCandidate, BasisError> {
5447                let matrix = gauge.restrict_penalty(&candidate.matrix);
5448                // Re-normalize in the *actual* coefficient chart used by the
5449                // fit.  The tensor sum-to-zero transform is not norm-preserving
5450                // for each overlapping marginal penalty, so carrying the raw
5451                // marginal Frobenius scale into the restricted space changes the
5452                // relative amount of smoothing seen by the LAML/REML optimizer.
5453                // Keep the physical scale in metadata and give the optimizer
5454                // unit-scale constrained penalties for every tensor margin.
5455                let (matrix, c_new) = normalize_penalty_in_constrained_space(&matrix);
5456                Ok(PenaltyCandidate {
5457                    nullspace_dim_hint: candidate.nullspace_dim_hint,
5458                    matrix,
5459                    source: candidate.source,
5460                    normalization_scale: candidate.normalization_scale * c_new,
5461                    // Z^T S Z is no longer a Kronecker product of the original
5462                    // marginal factors, so the Kronecker fast path in construction.rs
5463                    // must not be taken. Clearing kronecker_factors forces the generic
5464                    // block-local eigendecomposition path, which operates on the
5465                    // transformed matrix and is correct.
5466                    kronecker_factors: None,
5467                    op: candidate.op.clone(),
5468                })
5469            })
5470            .collect::<Result<Vec<_>, _>>()?;
5471    }
5472
5473    let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
5474        filter_active_penalty_candidates_with_ops(candidates)?;
5475    let identifiability_is_none =
5476        matches!(spec.identifiability, TensorBSplineIdentifiability::None);
5477    // All marginals expose a sparse representation iff each `marginal_sparse`
5478    // slot is `Some(...)`. Currently this is true when every marginal is a
5479    // free-boundary, non-periodic 1D B-spline returned as
5480    // `DesignMatrix::Sparse` from `build_bspline_basis_1d`. Periodic B-splines
5481    // and other dense-only marginals leave a `None` and trigger the fall-back
5482    // path. Identifiability transforms (`SumToZero`, `FrozenTransform`) make
5483    // the tensor design dense in general, so we also gate on that.
5484    let all_marginals_sparse = marginal_sparse.iter().all(Option::is_some);
5485    let design = if let Some(dense_design) = dense_design {
5486        DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense_design))
5487    } else if identifiability_is_none && all_marginals_sparse {
5488        // Sparse Khatri-Rao path: assemble the (n, ∏ q_j) tensor product
5489        // directly as a SparseColMat, preserving the ∏(degree_j+1) nonzero
5490        // structure per row instead of densifying to ∏ q_j columns. This is
5491        // mathematically identical to `tensor_product_design_from_marginals`
5492        // applied to the corresponding dense marginals.
5493        let sparse_marginals: Vec<&SparseColMat<usize, f64>> = marginal_sparse
5494            .iter()
5495            .map(|m| m.as_ref().expect("all_marginals_sparse just verified"))
5496            .collect();
5497        let sparse_design = tensor_product_design_from_sparse_marginals(&sparse_marginals)?;
5498        DesignMatrix::Sparse(gam_linalg::matrix::SparseDesignMatrix::new(sparse_design))
5499    } else {
5500        let marginals: Vec<Arc<Array2<f64>>> = marginal_designs
5501            .iter()
5502            .map(|m| Arc::new(m.clone()))
5503            .collect();
5504        let op = TensorProductDesignOperator::new(marginals).map_err(|e| {
5505            BasisError::InvalidInput(format!("TensorProductDesignOperator build failed: {e}"))
5506        })?;
5507        DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(op)))
5508    };
5509
5510    Ok(BasisBuildResult {
5511        design,
5512        penalties,
5513        nullspace_dims,
5514        penaltyinfo,
5515        ops,
5516        null_eigenvectors,
5517        joint_null_rotation: None,
5518        metadata: BasisMetadata::TensorBSpline {
5519            feature_cols: feature_cols.to_vec(),
5520            knots: marginal_knots,
5521            degrees: marginal_degrees,
5522            // Prefer the per-margin effective period derived in the loop —
5523            // it captures both the explicit `spec.periods` route and the
5524            // implied period from a `PeriodicUniform` marginal knotspec.
5525            // Falling back to `spec.periods` when populated keeps any
5526            // user-supplied explicit period authoritative even if the
5527            // marginal knotspec carried no periodicity hint.
5528            periods: marginal_effective_periods,
5529            is_cr: marginal_is_cr_flags,
5530            identifiability_transform: z_opt,
5531        },
5532        kronecker_factored: if matches!(spec.identifiability, TensorBSplineIdentifiability::None)
5533            && matches!(
5534                spec.penalty_decomposition,
5535                TensorBSplinePenaltyDecomposition::MarginalKroneckerSum
5536            ) {
5537            Some(KroneckerFactoredBasis::new(
5538                marginal_designs,
5539                kronecker_marginal_penalties,
5540                marginalnum_basis.clone(),
5541                spec.double_penalty,
5542            ))
5543        } else {
5544            None
5545        },
5546    })
5547}
5548
5549pub fn tensor_product_design_from_marginals(
5550    marginal_designs: &[Array2<f64>],
5551) -> Result<Array2<f64>, BasisError> {
5552    if marginal_designs.is_empty() {
5553        crate::bail_invalid_basis!("TensorBSpline requires at least one marginal basis");
5554    }
5555    let n = marginal_designs[0].nrows();
5556    for (i, b) in marginal_designs.iter().enumerate().skip(1) {
5557        if b.nrows() != n {
5558            crate::bail_dim_basis!(
5559                "tensor marginal row mismatch at dim {i}: expected {n}, got {}",
5560                b.nrows()
5561            );
5562        }
5563    }
5564    let total_cols = marginal_designs.iter().try_fold(1usize, |acc, b| {
5565        acc.checked_mul(b.ncols())
5566            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))
5567    })?;
5568    // Tensor-product Khatri-Rao: design[i, j] = Π_d marginal_d[i, j_d]
5569    // where j is the multi-index (j_1, ..., j_D) flattened. Independent
5570    // across rows; parallelize row chunks and fill the pre-allocated
5571    // contiguous Array2 in place (no Vec-flatten-collect intermediate,
5572    // which doubled the peak memory at large-scale N).
5573    use ndarray::parallel::prelude::*;
5574    use rayon::iter::{IntoParallelIterator, ParallelIterator};
5575    let mut design = Array2::<f64>::zeros((n, total_cols));
5576    design
5577        .axis_chunks_iter_mut(ndarray::Axis(0), 1024)
5578        .into_par_iter()
5579        .enumerate()
5580        .for_each(|(chunk_idx, mut block)| {
5581            let row_offset = chunk_idx * 1024;
5582            // Scratch buffers reused across rows in this chunk.
5583            let mut cur = Vec::<f64>::with_capacity(total_cols);
5584            let mut next = Vec::<f64>::with_capacity(total_cols);
5585            for (local_i, mut out_row) in block.outer_iter_mut().enumerate() {
5586                let i = row_offset + local_i;
5587                cur.clear();
5588                cur.push(1.0);
5589                for b in marginal_designs {
5590                    let q = b.ncols();
5591                    next.clear();
5592                    next.resize(cur.len() * q, 0.0);
5593                    // Hoist the row view out of the inner `col` loop so the
5594                    // q reads per `a_idx` reuse a single contiguous slice
5595                    // instead of recomputing `b[[i, col]]` strides per cell.
5596                    let b_row = b.row(i);
5597                    let b_slice = b_row
5598                        .as_slice()
5599                        .expect("Array2 row from outer_iter is contiguous");
5600                    for (a_idx, &aval) in cur.iter().enumerate() {
5601                        let off = a_idx * q;
5602                        let dst = &mut next[off..off + q];
5603                        for col in 0..q {
5604                            dst[col] = aval * b_slice[col];
5605                        }
5606                    }
5607                    std::mem::swap(&mut cur, &mut next);
5608                }
5609                // `out_row` is a row of the contiguous C-major `design`
5610                // Array2, so it is backed by a contiguous slice. Use a
5611                // bulk slice copy instead of an element-by-element write
5612                // loop.
5613                let out_slice = out_row
5614                    .as_slice_mut()
5615                    .expect("design row is contiguous in C-major Array2");
5616                out_slice.copy_from_slice(&cur);
5617            }
5618        });
5619    Ok(design)
5620}
5621
5622pub fn build_random_effect_block(
5623    data: ArrayView2<'_, f64>,
5624    spec: &RandomEffectTermSpec,
5625) -> Result<RandomEffectBlock, BasisError> {
5626    let n = data.nrows();
5627    let p = data.ncols();
5628    if spec.feature_col >= p {
5629        crate::bail_dim_basis!(
5630            "random-effect term '{}' feature column {} out of bounds for {} columns",
5631            spec.name,
5632            spec.feature_col,
5633            p
5634        );
5635    }
5636
5637    let col = data.column(spec.feature_col);
5638    if col.iter().any(|v| !v.is_finite()) {
5639        crate::bail_invalid_basis!(
5640            "random-effect term '{}' contains non-finite group values",
5641            spec.name
5642        );
5643    }
5644
5645    let kept_levels: Vec<u64> = if let Some(levels) = spec.frozen_levels.as_ref() {
5646        if levels.is_empty() {
5647            crate::bail_invalid_basis!(
5648                "random-effect term '{}' has empty frozen_levels",
5649                spec.name
5650            );
5651        }
5652        levels.clone()
5653    } else {
5654        let mut levels_set = BTreeSet::<u64>::new();
5655        for &v in col {
5656            levels_set.insert(v.to_bits());
5657        }
5658        if levels_set.is_empty() {
5659            crate::bail_invalid_basis!("random-effect term '{}' has no observed levels", spec.name);
5660        }
5661        let levels: Vec<u64> = levels_set.into_iter().collect();
5662        let start_idx = if spec.drop_first_level && levels.len() > 1 {
5663            1usize
5664        } else {
5665            0usize
5666        };
5667        levels[start_idx..].to_vec()
5668    };
5669
5670    if kept_levels.is_empty() {
5671        crate::bail_invalid_basis!(
5672            "random-effect term '{}' drops all levels; keep at least one level",
5673            spec.name
5674        );
5675    }
5676
5677    let q = kept_levels.len();
5678    let mut level_to_col = BTreeMap::<u64, usize>::new();
5679    for (idx, &bits) in kept_levels.iter().enumerate() {
5680        if level_to_col.insert(bits, idx).is_some() {
5681            crate::bail_invalid_basis!(
5682                "random-effect term '{}' has duplicate frozen level bits {bits}",
5683                spec.name
5684            );
5685        }
5686    }
5687    let mut group_ids = Vec::with_capacity(n);
5688    for &v in col {
5689        let bits = v.to_bits();
5690        group_ids.push(level_to_col.get(&bits).copied());
5691    }
5692
5693    Ok(RandomEffectBlock {
5694        name: spec.name.clone(),
5695        group_ids,
5696        num_groups: q,
5697        kept_levels,
5698    })
5699}
5700
5701impl SmoothDesign {
5702    /// Map an unconstrained term coefficient vector to its constrained shape space.
5703    /// This is useful for nonlinear fits that optimize unconstrained parameters.
5704    pub fn map_term_coefficients(
5705        unconstrained: &Array1<f64>,
5706        shape: ShapeConstraint,
5707    ) -> Result<Array1<f64>, BasisError> {
5708        if unconstrained.is_empty() {
5709            crate::bail_invalid_basis!("unconstrained coefficient vector cannot be empty");
5710        }
5711        let mapped = match shape {
5712            ShapeConstraint::None => unconstrained.clone(),
5713            ShapeConstraint::MonotoneIncreasing => cumulative_exp(unconstrained, 1.0),
5714            ShapeConstraint::MonotoneDecreasing => cumulative_exp(unconstrained, -1.0),
5715            ShapeConstraint::Convex => second_cumulative_exp(unconstrained, 1.0),
5716            ShapeConstraint::Concave => second_cumulative_exp(unconstrained, -1.0),
5717        };
5718        Ok(mapped)
5719    }
5720}
5721
/// Build product for a single smooth term, expressed in that smooth's own
/// (local) coefficient coordinates.
///
/// `penalties`, `ops`, `nullspaces` and `null_eigenvectors` are parallel
/// vectors: entry `k` of each refers to the same active penalty block.
pub struct LocalSmoothTermBuild {
    /// NOTE(review): presumably the number of coefficient columns of this
    /// term — confirm against the code that constructs this struct.
    pub dim: usize,
    /// Design matrix of the term. `apply_by_variable_to_local_build` replaces
    /// it with a dense, row-scaled copy.
    pub design: DesignMatrix,
    /// Active penalty matrices of the term.
    pub penalties: Vec<Array2<f64>>,
    /// Optional `PenaltyOp` handle for each active penalty (`None` when a
    /// penalty has no operator form attached).
    pub ops: Vec<Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>>,
    /// Null-space dimension recorded for each active penalty;
    /// `nullspaces[k]` belongs to `penalties[k]`.
    pub nullspaces: Vec<usize>,
    /// Per-active-penalty null-space eigenvector matrices, parallel to
    /// `penalties` / `ops` / `nullspaces`. `Some(U_null)` when
    /// `nullspaces[k] > 0`, with `U_null` orthonormal columns spanning
    /// `null(penalties[k])` in this smooth's local coordinate system; `None`
    /// when the active block is full-rank. Stage 1 plumbing; Stage 2
    /// consumes this to absorb the smooth's null space into the parametric
    /// block at `TermCollectionDesign` construction.
    pub null_eigenvectors: Vec<Option<Array2<f64>>>,
    /// Joint-null absorption rotation for this smooth. `Some(rotation)`
    /// records `Q = [U_range | U_null]` spanning `null(Σ_k penalties[k])`,
    /// the joint null across all active penalty blocks on this smooth.
    /// `None` means the joint penalty is full-rank (joint nullity = 0) or
    /// there are no penalties. Stage-2 commit A: plumbing only — populated
    /// by commit B, applied by commit D.
    pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
    /// Descriptive records for the term's penalties.
    /// NOTE(review): presumably one entry per active penalty — confirm.
    pub penaltyinfo: Vec<PenaltyInfo>,
    /// NOTE(review): presumably the penalty records as they were before
    /// inactive penalties were dropped — confirm against the producer.
    pub pre_dropped_penaltyinfo: Vec<PenaltyInfo>,
    /// Basis metadata returned by the basis builder for this term.
    pub metadata: BasisMetadata,
    /// Optional linear inequality constraints attached to the term.
    /// NOTE(review): presumably constraints on the term's coefficients
    /// arising from shape constraints — confirm.
    pub linear_constraints: Option<LinearInequalityConstraints>,
    /// NOTE(review): presumably set when the term uses the box
    /// reparameterization for shape constraints (cf.
    /// `shape_uses_box_reparameterization`) — confirm.
    pub box_reparam: bool,
    /// Optional Kronecker-factored representation of the basis.
    /// `apply_by_variable_to_local_build` clears it when it row-scales the
    /// design.
    pub kronecker_factored: Option<KroneckerFactoredBasis>,
}
5750
/// Design operator backed by a memory-mapped `.npy` file of PCA scores.
///
/// Entries are decoded on demand from the mapped bytes as little-endian
/// `f64` values in row-major order, so the full score matrix is never held
/// as an owned array unless `to_dense` is called. Cloning shares the
/// underlying mapping through the `Arc`.
#[derive(Clone)]
pub struct PcaScoresMemmapDesignOperator {
    /// Read-only memory map of the whole `.npy` file (header and payload).
    mmap: Arc<memmap2::Mmap>,
    /// Byte offset within `mmap` at which the array payload starts, as
    /// returned by `parse_f64_2d_npy_header`.
    data_offset: usize,
    /// Number of rows declared by the `.npy` header shape.
    nrows: usize,
    /// Number of columns declared by the `.npy` header shape.
    ncols: usize,
    /// Requested number of rows per processing chunk; `open` clamps it to
    /// at least 1.
    chunk_size: usize,
}
5759
5760impl PcaScoresMemmapDesignOperator {
5761    fn open(path: PathBuf, chunk_size: usize) -> Result<Self, BasisError> {
5762        let file = File::open(&path).map_err(|err| {
5763            BasisError::InvalidInput(format!(
5764                "failed to open lazy Pca .npy scores '{}': {err}",
5765                path.display()
5766            ))
5767        })?;
5768        // The .npy scores file is read-only training-cache data; this
5769        // module never mutates it. The error path below converts mmap
5770        // failure to a typed `BasisError::InvalidInput`.
5771        // SAFETY: `memmap2::Mmap::map` requires no concurrent writers; the
5772        // contract is held by this module's read-only access pattern.
5773        let mmap = unsafe {
5774            memmap2::Mmap::map(&file).map_err(|err| {
5775                BasisError::InvalidInput(format!(
5776                    "failed to memmap lazy Pca .npy scores '{}': {err}",
5777                    path.display()
5778                ))
5779            })?
5780        };
5781        let (data_offset, nrows, ncols) = parse_f64_2d_npy_header(&mmap, &path)?;
5782        let expected = data_offset
5783            .checked_add(nrows.saturating_mul(ncols).saturating_mul(8))
5784            .ok_or_else(|| {
5785                BasisError::InvalidInput(format!(
5786                    "lazy Pca .npy scores '{}' shape is too large",
5787                    path.display()
5788                ))
5789            })?;
5790        if mmap.len() < expected {
5791            crate::bail_invalid_basis!(
5792                "lazy Pca .npy scores '{}' is truncated: header expects {} bytes, file has {}",
5793                path.display(),
5794                expected,
5795                mmap.len()
5796            );
5797        }
5798        Ok(Self {
5799            mmap: Arc::new(mmap),
5800            data_offset,
5801            nrows,
5802            ncols,
5803            chunk_size: chunk_size.max(1),
5804        })
5805    }
5806
5807    fn value(&self, row: usize, col: usize) -> f64 {
5808        let offset = self.data_offset + (row * self.ncols + col) * 8;
5809        let mut bytes = [0_u8; 8];
5810        bytes.copy_from_slice(&self.mmap[offset..offset + 8]);
5811        f64::from_le_bytes(bytes)
5812    }
5813
5814    fn chunk_rows(&self) -> usize {
5815        self.chunk_size.min(self.nrows.max(1))
5816    }
5817}
5818
5819impl LinearOperator for PcaScoresMemmapDesignOperator {
5820    fn nrows(&self) -> usize {
5821        self.nrows
5822    }
5823
5824    fn ncols(&self) -> usize {
5825        self.ncols
5826    }
5827
5828    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
5829        assert_eq!(
5830            vector.len(),
5831            self.ncols,
5832            "lazy Pca apply vector length mismatch"
5833        );
5834        let mut out = Array1::<f64>::zeros(self.nrows);
5835        for start in (0..self.nrows).step_by(self.chunk_rows()) {
5836            let end = (start + self.chunk_rows()).min(self.nrows);
5837            for row in start..end {
5838                let mut acc = 0.0;
5839                for col in 0..self.ncols {
5840                    acc += self.value(row, col) * vector[col];
5841                }
5842                out[row] = acc;
5843            }
5844        }
5845        out
5846    }
5847
5848    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
5849        assert_eq!(
5850            vector.len(),
5851            self.nrows,
5852            "lazy Pca apply_transpose vector length mismatch"
5853        );
5854        let mut out = Array1::<f64>::zeros(self.ncols);
5855        for start in (0..self.nrows).step_by(self.chunk_rows()) {
5856            let end = (start + self.chunk_rows()).min(self.nrows);
5857            for row in start..end {
5858                let scale = vector[row];
5859                if scale == 0.0 {
5860                    continue;
5861                }
5862                for col in 0..self.ncols {
5863                    out[col] += scale * self.value(row, col);
5864                }
5865            }
5866        }
5867        out
5868    }
5869
5870    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
5871        if weights.len() != self.nrows {
5872            return Err(format!(
5873                "lazy Pca diag_xtw_x weight length mismatch: weights={}, nrows={}",
5874                weights.len(),
5875                self.nrows
5876            ));
5877        }
5878        let mut gram = Array2::<f64>::zeros((self.ncols, self.ncols));
5879        for start in (0..self.nrows).step_by(self.chunk_rows()) {
5880            let end = (start + self.chunk_rows()).min(self.nrows);
5881            for row in start..end {
5882                let w = weights[row];
5883                if w == 0.0 {
5884                    continue;
5885                }
5886                for a in 0..self.ncols {
5887                    let xa = self.value(row, a);
5888                    if xa == 0.0 {
5889                        continue;
5890                    }
5891                    for b in a..self.ncols {
5892                        gram[[a, b]] += w * xa * self.value(row, b);
5893                    }
5894                }
5895            }
5896        }
5897        for a in 0..self.ncols {
5898            for b in 0..a {
5899                gram[[a, b]] = gram[[b, a]];
5900            }
5901        }
5902        Ok(gram)
5903    }
5904
5905    fn apply_weighted_normal(
5906        &self,
5907        weights: &Array1<f64>,
5908        vector: &Array1<f64>,
5909        penalty: Option<&Array2<f64>>,
5910        ridge: f64,
5911    ) -> Array1<f64> {
5912        assert_eq!(
5913            weights.len(),
5914            self.nrows,
5915            "lazy Pca weighted-normal weight mismatch"
5916        );
5917        assert_eq!(
5918            vector.len(),
5919            self.ncols,
5920            "lazy Pca weighted-normal vector mismatch"
5921        );
5922        let mut out = Array1::<f64>::zeros(self.ncols);
5923        for start in (0..self.nrows).step_by(self.chunk_rows()) {
5924            let end = (start + self.chunk_rows()).min(self.nrows);
5925            for row in start..end {
5926                let w = weights[row].max(0.0);
5927                if w == 0.0 {
5928                    continue;
5929                }
5930                let mut row_dot = 0.0;
5931                for col in 0..self.ncols {
5932                    row_dot += self.value(row, col) * vector[col];
5933                }
5934                if row_dot == 0.0 {
5935                    continue;
5936                }
5937                let scaled = w * row_dot;
5938                for col in 0..self.ncols {
5939                    out[col] += scaled * self.value(row, col);
5940                }
5941            }
5942        }
5943        if let Some(pen) = penalty {
5944            out += &pen.dot(vector);
5945        }
5946        if ridge > 0.0 {
5947            out += &vector.mapv(|x| ridge * x);
5948        }
5949        out
5950    }
5951}
5952
5953impl DenseDesignOperator for PcaScoresMemmapDesignOperator {
5954    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
5955        if weights.len() != self.nrows || y.len() != self.nrows {
5956            return Err(format!(
5957                "lazy Pca compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
5958                weights.len(),
5959                y.len(),
5960                self.nrows
5961            ));
5962        }
5963        let mut out = Array1::<f64>::zeros(self.ncols);
5964        for start in (0..self.nrows).step_by(self.chunk_rows()) {
5965            let end = (start + self.chunk_rows()).min(self.nrows);
5966            for row in start..end {
5967                let scale = weights[row] * y[row];
5968                if scale == 0.0 {
5969                    continue;
5970                }
5971                for col in 0..self.ncols {
5972                    out[col] += scale * self.value(row, col);
5973                }
5974            }
5975        }
5976        Ok(out)
5977    }
5978
5979    fn row_chunk_into(
5980        &self,
5981        rows: Range<usize>,
5982        mut out: ArrayViewMut2<'_, f64>,
5983    ) -> Result<(), MatrixMaterializationError> {
5984        if rows.end > self.nrows || rows.start > rows.end {
5985            return Err(MatrixMaterializationError::MissingRowChunk {
5986                context: "lazy Pca row range out of bounds",
5987            });
5988        }
5989        if out.nrows() != rows.end - rows.start || out.ncols() != self.ncols {
5990            return Err(MatrixMaterializationError::MissingRowChunk {
5991                context: "lazy Pca row_chunk_into shape mismatch",
5992            });
5993        }
5994        for (local, row) in (rows.start..rows.end).enumerate() {
5995            for col in 0..self.ncols {
5996                out[[local, col]] = self.value(row, col);
5997            }
5998        }
5999        Ok(())
6000    }
6001
6002    fn to_dense(&self) -> Array2<f64> {
6003        let mut out = Array2::<f64>::zeros((self.nrows, self.ncols));
6004        self.row_chunk_into(0..self.nrows, out.view_mut())
6005            .expect("lazy Pca full materialization failed");
6006        out
6007    }
6008}
6009
6010pub fn parse_f64_2d_npy_header(
6011    bytes: &[u8],
6012    path: &PathBuf,
6013) -> Result<(usize, usize, usize), BasisError> {
6014    if bytes.len() < 10 || &bytes[0..6] != b"\x93NUMPY" {
6015        crate::bail_invalid_basis!("lazy Pca scores '{}' is not a .npy file", path.display());
6016    }
6017    let major = bytes[6];
6018    let header_len = match major {
6019        1 => u16::from_le_bytes([bytes[8], bytes[9]]) as usize,
6020        2 | 3 => {
6021            if bytes.len() < 12 {
6022                crate::bail_invalid_basis!(
6023                    "lazy Pca scores '{}' has a truncated .npy header",
6024                    path.display()
6025                );
6026            }
6027            u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize
6028        }
6029        other => {
6030            crate::bail_invalid_basis!(
6031                "lazy Pca scores '{}' uses unsupported .npy version {}",
6032                path.display(),
6033                other
6034            );
6035        }
6036    };
6037    let header_start = if major == 1 { 10 } else { 12 };
6038    let data_offset = header_start + header_len;
6039    if bytes.len() < data_offset {
6040        crate::bail_invalid_basis!(
6041            "lazy Pca scores '{}' has a truncated .npy header",
6042            path.display()
6043        );
6044    }
6045    let header = std::str::from_utf8(&bytes[header_start..data_offset]).map_err(|err| {
6046        BasisError::InvalidInput(format!(
6047            "lazy Pca scores '{}' has a non-UTF8 .npy header: {err}",
6048            path.display()
6049        ))
6050    })?;
6051    if !(header.contains("'descr': '<f8'")
6052        || header.contains("\"descr\": \"<f8\"")
6053        || header.contains("'descr': '|f8'")
6054        || header.contains("\"descr\": \"|f8\""))
6055    {
6056        crate::bail_invalid_basis!(
6057            "lazy Pca scores '{}' must be float64 little-endian .npy",
6058            path.display()
6059        );
6060    }
6061    if header.contains("True") {
6062        crate::bail_invalid_basis!(
6063            "lazy Pca scores '{}' must be C-contiguous, not Fortran-ordered",
6064            path.display()
6065        );
6066    }
6067    let shape_pos = header.find("shape").ok_or_else(|| {
6068        BasisError::InvalidInput(format!(
6069            "lazy Pca scores '{}' .npy header is missing shape",
6070            path.display()
6071        ))
6072    })?;
6073    let open = header[shape_pos..].find('(').ok_or_else(|| {
6074        BasisError::InvalidInput(format!(
6075            "lazy Pca scores '{}' .npy header has malformed shape",
6076            path.display()
6077        ))
6078    })? + shape_pos;
6079    let close = header[open..].find(')').ok_or_else(|| {
6080        BasisError::InvalidInput(format!(
6081            "lazy Pca scores '{}' .npy header has malformed shape",
6082            path.display()
6083        ))
6084    })? + open;
6085    let dims = header[open + 1..close]
6086        .split(',')
6087        .map(str::trim)
6088        .filter(|part| !part.is_empty())
6089        .map(|part| part.parse::<usize>())
6090        .collect::<Result<Vec<_>, _>>()
6091        .map_err(|err| {
6092            BasisError::InvalidInput(format!(
6093                "lazy Pca scores '{}' .npy shape is not integral: {err}",
6094                path.display()
6095            ))
6096        })?;
6097    if dims.len() != 2 {
6098        crate::bail_invalid_basis!(
6099            "lazy Pca scores '{}' must have shape (N, K), got {:?}",
6100            path.display(),
6101            dims
6102        );
6103    }
6104    Ok((data_offset, dims[0], dims[1]))
6105}
6106
6107pub fn pca_center_mean(x: ArrayView2<'_, f64>) -> Result<Array1<f64>, BasisError> {
6108    if x.nrows() == 0 {
6109        crate::bail_invalid_basis!("Pca basis requires at least one row to compute center mean");
6110    }
6111    let mut mean = Array1::<f64>::zeros(x.ncols());
6112    for row in x.rows() {
6113        mean += &row;
6114    }
6115    mean.mapv_inplace(|v| v / x.nrows() as f64);
6116    Ok(mean)
6117}
6118
6119pub fn build_pca_smooth_basis(
6120    data: ArrayView2<'_, f64>,
6121    feature_cols: &[usize],
6122    basis_matrix: &Array2<f64>,
6123    centered: bool,
6124    smooth_penalty: f64,
6125    center_mean: Option<&Array1<f64>>,
6126    pca_basis_path: Option<&PathBuf>,
6127    chunk_size: usize,
6128) -> Result<BasisBuildResult, BasisError> {
6129    if let Some(path) = pca_basis_path {
6130        let op = PcaScoresMemmapDesignOperator::open(path.clone(), chunk_size)?;
6131        if op.nrows != data.nrows() {
6132            crate::bail_dim_basis!(
6133                "lazy Pca scores row mismatch: .npy has {}, data has {}",
6134                op.nrows,
6135                data.nrows()
6136            );
6137        }
6138        let k = op.ncols;
6139        let mut penalty = Array2::<f64>::eye(k);
6140        penalty.mapv_inplace(|v| v * smooth_penalty);
6141        let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
6142            filter_active_penalty_candidates_with_ops(vec![PenaltyCandidate {
6143                matrix: penalty,
6144                nullspace_dim_hint: 0,
6145                source: PenaltySource::Other("PcaRidge".to_string()),
6146                normalization_scale: 1.0,
6147                kronecker_factors: None,
6148                op: None,
6149            }])?;
6150        return Ok(BasisBuildResult {
6151            design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(op))),
6152            penalties,
6153            nullspace_dims,
6154            penaltyinfo,
6155            ops,
6156            null_eigenvectors,
6157            joint_null_rotation: None,
6158            metadata: BasisMetadata::Pca {
6159                feature_cols: feature_cols.to_vec(),
6160                basis_matrix: basis_matrix.clone(),
6161                centered,
6162                smooth_penalty,
6163                center_mean: center_mean.cloned(),
6164                pca_basis_path: Some(path.clone()),
6165                chunk_size: chunk_size.max(1),
6166            },
6167            kronecker_factored: None,
6168        });
6169    }
6170    if basis_matrix.nrows() != feature_cols.len() {
6171        crate::bail_dim_basis!(
6172            "Pca basis row mismatch: basis rows={}, feature columns={}",
6173            basis_matrix.nrows(),
6174            feature_cols.len()
6175        );
6176    }
6177    let mut x = select_columns(data, feature_cols)?;
6178    let mean = if centered {
6179        match center_mean {
6180            Some(mean) => mean.clone(),
6181            None => pca_center_mean(x.view())?,
6182        }
6183    } else {
6184        Array1::<f64>::zeros(feature_cols.len())
6185    };
6186    if centered {
6187        for mut row in x.rows_mut() {
6188            row -= &mean;
6189        }
6190    }
6191    let design = fast_ab(&x, basis_matrix);
6192    let k = basis_matrix.ncols();
6193    let mut penalty = Array2::<f64>::eye(k);
6194    penalty.mapv_inplace(|v| v * smooth_penalty);
6195    let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
6196        filter_active_penalty_candidates_with_ops(vec![PenaltyCandidate {
6197            matrix: penalty,
6198            nullspace_dim_hint: 0,
6199            source: PenaltySource::Other("PcaRidge".to_string()),
6200            normalization_scale: 1.0,
6201            kronecker_factors: None,
6202            op: None,
6203        }])?;
6204    Ok(BasisBuildResult {
6205        design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(design)),
6206        penalties,
6207        nullspace_dims,
6208        penaltyinfo,
6209        ops,
6210        null_eigenvectors,
6211        joint_null_rotation: None,
6212        metadata: BasisMetadata::Pca {
6213            feature_cols: feature_cols.to_vec(),
6214            basis_matrix: basis_matrix.clone(),
6215            centered,
6216            smooth_penalty,
6217            center_mean: centered.then_some(mean),
6218            pca_basis_path: None,
6219            chunk_size: chunk_size.max(1),
6220        },
6221        kronecker_factored: None,
6222    })
6223}
6224
6225/// A factor-level `by=` wrapper owns the model-space centering of its inner
6226/// smooth: it gates the raw/structurally-constrained basis to the level rows
6227/// and then centers that gated block exactly once against the level indicator
6228/// (`build_parametric_constraint_block_for_term` in `design_construction`).
6229/// Leaving the inner B-spline's default pooled weighted-sum-to-zero active here
6230/// would impose two generically-independent constraints — the pooled column
6231/// moment `m = Σ_h m_h` and the per-level moment `m_g` — so a raw `k`-column
6232/// basis collapses to `k-2` columns per level instead of `k-1`, deleting one
6233/// genuine nonconstant spline direction *before REML runs* (#1427). The group
6234/// main effect carries only the constant, so it cannot restore that direction.
6235///
6236/// Only the *default model-space* centering is deferred. Explicit structural or
6237/// frozen transforms (`RemoveLinearTrend`, `OrthogonalToDesignColumns`,
6238/// `FrozenTransform`, `None`) are user/structural choices and are preserved
6239/// verbatim.
6240pub fn defer_inner_model_centering_to_factor_level_wrapper(basis: &mut SmoothBasisSpec) {
6241    if let SmoothBasisSpec::BSpline1D { spec, .. } = basis
6242        && matches!(
6243            spec.identifiability,
6244            BSplineIdentifiability::WeightedSumToZero { .. }
6245        )
6246    {
6247        spec.identifiability = BSplineIdentifiability::None;
6248    }
6249}
6250
6251pub fn apply_by_variable_to_local_build(
6252    mut built: LocalSmoothTermBuild,
6253    data: ArrayView2<'_, f64>,
6254    by_col: usize,
6255    by: &ByVariableSpec,
6256    term_name: &str,
6257) -> Result<LocalSmoothTermBuild, BasisError> {
6258    if by_col >= data.ncols() {
6259        crate::bail_dim_basis!(
6260            "by-variable smooth term '{term_name}' references column {by_col}, but data has {} columns",
6261            data.ncols()
6262        );
6263    }
6264    let weights = match by {
6265        ByVariableSpec::Numeric => data.column(by_col).to_owned(),
6266        ByVariableSpec::Level { value_bits, .. } => data.column(by_col).mapv(|value| {
6267            if value.to_bits() == *value_bits {
6268                1.0
6269            } else {
6270                0.0
6271            }
6272        }),
6273    };
6274    if weights.iter().any(|value| !value.is_finite()) {
6275        crate::bail_invalid_basis!(
6276            "by-variable smooth term '{term_name}' has non-finite by-column values"
6277        );
6278    }
6279
6280    let mut dense = built
6281        .design
6282        .try_to_dense_by_chunks("by-variable smooth row gating")
6283        .map_err(BasisError::InvalidInput)?;
6284    for (mut row, &weight) in dense.rows_mut().into_iter().zip(weights.iter()) {
6285        row.mapv_inplace(|value| value * weight);
6286    }
6287    built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense));
6288    built.kronecker_factored = None;
6289    Ok(built)
6290}
6291
6292/// Build the local smooth term for a `BySmooth` spec, which unifies numeric-by
6293/// and factor-by modulation into a single `SmoothTermSpec`.
6294///
6295/// For a **numeric** by-variable the inner smooth is built once and every row
6296/// is multiplied by the by-column value (identical to `ByVariable::Numeric`).
6297///
6298/// For a **factor** by-variable the inner smooth is built once and gated per
6299/// level into side-by-side column blocks, producing a `n × (L * p)` design
6300/// matrix.  The penalties are block-diagonalised (one copy of the inner penalty
6301/// per level) exactly as `build_factor_smooth` does for `bs="fs"/"sz"`.
6302pub fn build_by_smooth_local(
6303    data: ArrayView2<'_, f64>,
6304    term: &SmoothTermSpec,
6305    smooth: &SmoothBasisSpec,
6306    by_kind: &ByVarKind,
6307    workspace: &mut crate::basis::BasisWorkspace,
6308) -> Result<LocalSmoothTermBuild, BasisError> {
6309    let inner_term = SmoothTermSpec {
6310        name: term.name.clone(),
6311        basis: (*smooth).clone(),
6312        shape: term.shape,
6313        joint_null_rotation: None,
6314    };
6315    let inner = build_single_local_smooth_term(data, &inner_term, workspace)?;
6316
6317    match by_kind {
6318        ByVarKind::Numeric { feature_col } => {
6319            let inner_meta = inner.metadata.clone();
6320            let mut built = apply_by_variable_to_local_build(
6321                inner,
6322                data,
6323                *feature_col,
6324                &ByVariableSpec::Numeric,
6325                &term.name,
6326            )?;
6327            built.metadata = BasisMetadata::BySmooth {
6328                inner: Box::new(inner_meta),
6329                by_col: *feature_col,
6330                levels: None,
6331                ordered: false,
6332            };
6333            Ok(built)
6334        }
6335        ByVarKind::Factor {
6336            feature_col,
6337            frozen_levels,
6338            ordered,
6339        } => {
6340            // Collect factor levels: prefer the frozen set (replay path), else
6341            // scan the data column (first-fit path).
6342            let level_bits: Vec<u64> = if let Some(fl) = frozen_levels {
6343                fl.clone()
6344            } else {
6345                let col = data.column(*feature_col);
6346                let mut seen = BTreeSet::<u64>::new();
6347                for &v in col.iter() {
6348                    if v.is_finite() {
6349                        seen.insert(v.to_bits());
6350                    }
6351                }
6352                seen.into_iter().collect()
6353            };
6354            let n_levels = level_bits.len();
6355            if n_levels == 0 {
6356                crate::bail_invalid_basis!(
6357                    "by-factor smooth term '{}': factor column {} has no observed levels",
6358                    term.name,
6359                    feature_col
6360                );
6361            }
6362            let p = inner.dim;
6363            let q = n_levels * p;
6364            let n = data.nrows();
6365
6366            let inner_dense = inner
6367                .design
6368                .try_to_dense_by_chunks("by-factor smooth design gating")
6369                .map_err(BasisError::InvalidInput)?;
6370
6371            // Gate each level into its own p-wide column block.
6372            let mut combined = Array2::<f64>::zeros((n, q));
6373            for (lvl_idx, &bits) in level_bits.iter().enumerate() {
6374                let col_start = lvl_idx * p;
6375                for row in 0..n {
6376                    if data[[row, *feature_col]].to_bits() == bits {
6377                        combined
6378                            .slice_mut(s![row, col_start..col_start + p])
6379                            .assign(&inner_dense.row(row));
6380                    }
6381                }
6382            }
6383
6384            // Build per-level INDEPENDENT penalties (#1427): one copy of each
6385            // inner penalty per level, but each confined to that single level's
6386            // diagonal block, so every (level, inner-penalty) pair is its OWN
6387            // smoothing-parameter coordinate. `s(x, by=g)` selects the per-group
6388            // curve wiggliness independently — the design is block-diagonal and
6389            // block-separable, so a correct REML must reproduce gamfit's own
6390            // independent per-group fits. Tiling a single inner penalty across
6391            // every level (as the `bs="fs"` shared-λ random-effect construction
6392            // does) collapses all groups onto ONE λ, which cannot match uneven
6393            // per-level smoothness and degrades as data grows (under-recovery up
6394            // to ~16× at n=2000). Emit `n_levels * n_penalties` blocks instead.
6395            let inner_meta = inner.metadata.clone();
6396            let n_penalties = inner.penalties.len();
6397            let n_blocks = n_penalties.saturating_mul(n_levels);
6398            let mut penalties = Vec::<Array2<f64>>::with_capacity(n_blocks);
6399            let mut penaltyinfo = Vec::<PenaltyInfo>::with_capacity(n_blocks);
6400            let mut nullspaces = Vec::<usize>::with_capacity(n_blocks);
6401            for (pen_pos, s_inner) in inner.penalties.iter().enumerate() {
6402                for lvl in 0..n_levels {
6403                    let off = lvl * p;
6404                    let mut s_big = Array2::<f64>::zeros((q, q));
6405                    s_big
6406                        .slice_mut(s![off..off + p, off..off + p])
6407                        .assign(s_inner);
6408                    let (s_big, scale) = normalize_penalty_in_constrained_space(&s_big);
6409                    let mut info = inner.penaltyinfo[pen_pos].clone();
6410                    // Distinct original_index per (penalty, level) so each λ is a
6411                    // separate identifiable coordinate downstream.
6412                    info.original_index = pen_pos * n_levels + lvl;
6413                    info.normalization_scale *= scale;
6414                    // Each block now spans exactly ONE level → per-level nullity,
6415                    // not the tiled (× n_levels) hint of the shared construction.
6416                    info.kronecker_factors = None;
6417                    penalties.push(s_big);
6418                    penaltyinfo.push(info);
6419                    nullspaces.push(inner.nullspaces[pen_pos]);
6420                }
6421            }
6422
6423            let null_eigenvectors = vec![None; penalties.len()];
6424            let ops = vec![None; penalties.len()];
6425
6426            Ok(LocalSmoothTermBuild {
6427                dim: q,
6428                design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(combined)),
6429                penalties,
6430                ops,
6431                nullspaces,
6432                null_eigenvectors,
6433                joint_null_rotation: None,
6434                penaltyinfo,
6435                pre_dropped_penaltyinfo: inner.pre_dropped_penaltyinfo,
6436                metadata: BasisMetadata::BySmooth {
6437                    inner: Box::new(inner_meta),
6438                    by_col: *feature_col,
6439                    levels: Some(level_bits),
6440                    ordered: *ordered,
6441                },
6442                linear_constraints: None,
6443                box_reparam: false,
6444                kronecker_factored: None,
6445            })
6446        }
6447    }
6448}
6449
6450pub fn ensure_by_variable_specs_match(
6451    kind: &BySmoothKind,
6452    by: &ByVariableSpec,
6453    term_name: &str,
6454) -> Result<(), BasisError> {
6455    match (kind, by) {
6456        (BySmoothKind::Numeric, ByVariableSpec::Numeric) => Ok(()),
6457        (BySmoothKind::Level { level_bits }, ByVariableSpec::Level { value_bits, .. })
6458            if level_bits == value_bits =>
6459        {
6460            Ok(())
6461        }
6462        _ => Err(BasisError::InvalidInput(format!(
6463            "by-variable smooth term '{term_name}' has inconsistent by-variable specifications"
6464        ))),
6465    }
6466}
6467
6468/// Build a factor-smooth interaction basis (`bs="fs"`/`"sz"`/`"re"`).
6469///
6470/// A factor smooth replicates a shared marginal smooth in the continuous
6471/// covariate(s) once per level of a grouping factor, coupling all level blocks
6472/// through a *single* set of smoothing parameters (one per marginal penalty).
6473/// This is mgcv's `smooth.construct.fs.smooth.spec` realization and the
6474/// random-effect interpretation of a smooth: the per-level deviations are an
6475/// exchangeable family whose joint wiggliness/shrinkage is governed by the
6476/// shared λ, so the construction scales to many levels with a fixed parameter
6477/// count.
6478///
6479/// Flavours:
6480/// * `Fs` — full random factor-smooth. The marginal carries its wiggliness
6481///   penalty *and* a null-space ridge (double penalty), so the replicated
6482///   design is a proper full-rank random effect: each level's curve is shrunk
6483///   toward zero (intercept + linear trend included), recovering the mgcv
6484///   `bs="fs"` penalty structure `I_L ⊗ S_j` for every marginal penalty `S_j`.
6485/// * `Sz` — sum-to-zero factor smooth. Delegates to the existing
6486///   [`SmoothBasisSpec::FactorSumToZero`] construction (`L-1` deviation blocks,
6487///   coefficient-wise zero sum across levels).
6488/// * `Re` — pure random effect / random slope (`bs="re"`). A degree-1 marginal
6489///   gives the per-level `[1, x]` span; the penalty is the identity over each
6490///   level block (iid Gaussian coefficients), matching mgcv's `bs="re"` ridge.
6491///
6492/// The grouping levels are resolved once at fit time (sorted unique bit
6493/// patterns of the factor column) and frozen into the returned metadata so the
6494/// predict-time rebuild evaluates every row against its own level's block.
6495pub fn build_factor_smooth(
6496    data: ArrayView2<'_, f64>,
6497    spec: &FactorSmoothSpec,
6498    term_name: &str,
6499    workspace: &mut crate::basis::BasisWorkspace,
6500) -> Result<LocalSmoothTermBuild, BasisError> {
6501    if spec.continuous_cols.len() != 1 {
6502        crate::bail_invalid_basis!(
6503            "factor smooth term '{}' currently supports exactly one continuous covariate; found {}",
6504            term_name,
6505            spec.continuous_cols.len()
6506        );
6507    }
6508    let feature_col = spec.continuous_cols[0];
6509    let group_col = spec.group_col;
6510    if feature_col >= data.ncols() || group_col >= data.ncols() {
6511        crate::bail_dim_basis!(
6512            "factor smooth term '{}' references columns ({}, {}) out of bounds for {} columns",
6513            term_name,
6514            feature_col,
6515            group_col,
6516            data.ncols()
6517        );
6518    }
6519
6520    // `Sz` is exactly the existing sum-to-zero factor smooth: reuse it verbatim
6521    // so there is a single source of truth for the zero-sum construction.
6522    if matches!(spec.flavour, FactorSmoothFlavour::Sz) {
6523        let levels = resolve_factor_smooth_levels(data, group_col, spec, term_name)?;
6524        let inner = SmoothBasisSpec::BSpline1D {
6525            feature_col,
6526            spec: factor_smooth_marginal_for_replay(&spec.marginal),
6527        };
6528        let sz_term = SmoothTermSpec {
6529            name: term_name.to_string(),
6530            basis: SmoothBasisSpec::FactorSumToZero {
6531                inner: Box::new(inner),
6532                by_col: group_col,
6533                levels: levels.clone(),
6534                frozen_global_orthogonality: None,
6535            },
6536            shape: ShapeConstraint::None,
6537            joint_null_rotation: None,
6538        };
6539        let mut built = build_single_local_smooth_term(data, &sz_term, workspace)?;
6540        // The delegated `FactorSumToZero` build returns the BARE inner B-spline
6541        // metadata (`BasisMetadata::BSpline1D`), but the term that owns this
6542        // build carries a `SmoothBasisSpec::FactorSmooth { Sz }` spec. Two
6543        // things break if we hand that mismatched pair downstream:
6544        //   1. `freeze_smooth_basis_from_metadata` matches on (spec, metadata)
6545        //      and has no `(FactorSmooth, BSpline1D)` arm, so any refit / spatial
6546        //      re-optimization that freezes the basis aborts with a "smooth
6547        //      metadata/spec type mismatch" error.
6548        //   2. The bare B-spline metadata carries no grouping levels, so a
6549        //      predict-time rebuild cannot replay the SAME replicated design.
6550        // Re-wrap the marginal geometry as `FactorSmooth` metadata exactly as
6551        // the Fs/Re path below does, giving all three factor-smooth flavours a
6552        // single, freeze-consistent metadata shape that also pins the levels.
6553        // Since #1605 the sz marginal is ALWAYS the penalized B-spline the `fs`
6554        // sibling uses (a natural cubic regression marginal hard-enforces f''=0
6555        // at the boundary and cannot represent curved deviations — a consistency
6556        // failure). The `CubicRegression1D` arm below is therefore unreachable on
6557        // a freshly-built sz spec; it is retained only as defense / backward
6558        // compatibility for a frozen spec that still carries a cr marginal, so
6559        // the predict-time freeze restores whatever marginal class it finds.
6560        let (knots, degree, periodic, marginal_is_cr) = match &built.metadata {
6561            BasisMetadata::BSpline1D {
6562                knots,
6563                periodic,
6564                degree,
6565                ..
6566            } => (
6567                knots.clone(),
6568                degree.unwrap_or(spec.marginal.degree),
6569                *periodic,
6570                false,
6571            ),
6572            BasisMetadata::CubicRegression1D { knots, .. } => {
6573                (knots.clone(), spec.marginal.degree, None, true)
6574            }
6575            other => {
6576                crate::bail_invalid_basis!(
6577                    "sz factor smooth term '{}' produced an unexpected marginal metadata variant {:?}",
6578                    term_name,
6579                    other
6580                );
6581            }
6582        };
6583        built.metadata = BasisMetadata::FactorSmooth {
6584            continuous_cols: spec.continuous_cols.clone(),
6585            group_col,
6586            knots,
6587            degree,
6588            periodic,
6589            group_levels: levels,
6590            flavour: "sz".to_string(),
6591            marginal_is_cr,
6592        };
6593        return Ok(built);
6594    }
6595
6596    let levels = resolve_factor_smooth_levels(data, group_col, spec, term_name)?;
6597    let n_levels = levels.len();
6598    if n_levels < 2 {
6599        crate::bail_invalid_basis!(
6600            "factor smooth term '{}' requires at least two grouping levels; found {}",
6601            term_name,
6602            n_levels
6603        );
6604    }
6605
6606    // `Fs` (order ≥ 1, the default) is the random-effect flavour: it penalizes
6607    // each null-space dimension of the marginal wiggliness penalty separately
6608    // below (mgcv's `bs="fs"` construction). That replaces the marginal's single
6609    // *combined* double penalty, so disable the latter here to avoid penalizing
6610    // the null space twice (once combined, once per dimension). The explicit
6611    // `m=0` opt-out keeps the legacy combined double penalty and adds no
6612    // per-dimension penalties.
6613    let use_per_dim_null = matches!(
6614        &spec.flavour,
6615        FactorSmoothFlavour::Fs { m_null_penalty_orders }
6616            if m_null_penalty_orders.iter().copied().max().unwrap_or(0) >= 1
6617    );
6618
6619    // Build the shared marginal design + penalties from the 1-D B-spline.
6620    // `Re` forces a degree-1 marginal (linear span) and replaces the marginal
6621    // wiggliness with an identity ridge below; `Fs` keeps the user's marginal
6622    // (cubic by default) and, under the per-dimension null path, gets its null
6623    // space penalized one dimension at a time after replication.
6624    let mut marginal_spec = factor_smooth_marginal_for_replay(&spec.marginal);
6625    if use_per_dim_null {
6626        marginal_spec.double_penalty = false;
6627    }
6628    let inner_term = SmoothTermSpec {
6629        name: format!("{term_name}::marginal"),
6630        basis: SmoothBasisSpec::BSpline1D {
6631            feature_col,
6632            spec: marginal_spec,
6633        },
6634        shape: ShapeConstraint::None,
6635        joint_null_rotation: None,
6636    };
6637    let inner = build_single_local_smooth_term(data, &inner_term, workspace)?;
6638    let base = inner
6639        .design
6640        .try_to_dense_by_chunks("factor smooth marginal")
6641        .map_err(BasisError::InvalidInput)?;
6642    let n = base.nrows();
6643    let p = base.ncols();
6644    let q = p * n_levels;
6645
6646    // Block-diagonal replicated design: row i contributes its marginal row to
6647    // the column block owned by its grouping level, zeros elsewhere.
6648    let mut dense = Array2::<f64>::zeros((n, q));
6649    for i in 0..n {
6650        let bits = data[[i, group_col]].to_bits();
6651        let level_idx = levels.iter().position(|b| *b == bits).ok_or_else(|| {
6652            BasisError::InvalidInput(format!(
6653                "factor smooth term '{term_name}' saw an unseen grouping level at row {}",
6654                i + 1
6655            ))
6656        })?;
6657        let start = level_idx * p;
6658        dense
6659            .slice_mut(s![i, start..start + p])
6660            .assign(&base.row(i));
6661    }
6662
6663    // Penalties: replicate each marginal penalty into a block-diagonal
6664    // `I_L ⊗ S_j` so every level shares the same smoothing parameter λ_j (one
6665    // λ per marginal penalty), the defining feature of a factor smooth. For
6666    // `Re` the marginal penalty is replaced by an identity ridge so each
6667    // per-level coefficient is an iid Gaussian random effect.
6668    let marginal_penalties: Vec<Array2<f64>> = if matches!(spec.flavour, FactorSmoothFlavour::Re) {
6669        vec![Array2::<f64>::eye(p)]
6670    } else {
6671        inner.penalties.clone()
6672    };
6673    let marginal_penaltyinfo: Vec<PenaltyInfo> = if matches!(spec.flavour, FactorSmoothFlavour::Re)
6674    {
6675        vec![PenaltyInfo {
6676            source: PenaltySource::Primary,
6677            original_index: 0,
6678            active: true,
6679            effective_rank: p,
6680            dropped_reason: None,
6681            nullspace_dim_hint: 0,
6682            normalization_scale: 1.0,
6683            kronecker_factors: None,
6684        }]
6685    } else {
6686        inner.penaltyinfo.clone()
6687    };
6688    if marginal_penalties.len() != marginal_penaltyinfo.len() {
6689        crate::bail_invalid_basis!(
6690            "internal factor-smooth penalty metadata mismatch for term '{}': penalties={}, infos={}",
6691            term_name,
6692            marginal_penalties.len(),
6693            marginal_penaltyinfo.len()
6694        );
6695    }
6696
6697    let mut penalties = Vec::<Array2<f64>>::with_capacity(marginal_penalties.len());
6698    let mut penaltyinfo = Vec::<PenaltyInfo>::with_capacity(marginal_penalties.len());
6699    for (penalty_pos, s_inner) in marginal_penalties.iter().enumerate() {
6700        let mut s_big = Array2::<f64>::zeros((q, q));
6701        for level in 0..n_levels {
6702            let start = level * p;
6703            s_big
6704                .slice_mut(s![start..start + p, start..start + p])
6705                .assign(s_inner);
6706        }
6707        let (s_big, factor_smooth_scale) = normalize_penalty_in_constrained_space(&s_big);
6708        let mut info = marginal_penaltyinfo[penalty_pos].clone();
6709        info.original_index = penalty_pos;
6710        info.normalization_scale *= factor_smooth_scale;
6711        info.nullspace_dim_hint = info.nullspace_dim_hint.saturating_mul(n_levels);
6712        info.kronecker_factors = None;
6713        penalties.push(s_big);
6714        penaltyinfo.push(info);
6715    }
6716
6717    let mut nullspaces: Vec<usize> = if matches!(spec.flavour, FactorSmoothFlavour::Re) {
6718        vec![0]
6719    } else {
6720        inner
6721            .nullspaces
6722            .iter()
6723            .map(|ns| ns.saturating_mul(n_levels))
6724            .collect()
6725    };
6726
6727    // `Fs` is the random-effect flavour of a smooth: the per-group curve is an
6728    // exchangeable Gaussian *function*, so EVERY coefficient — including the
6729    // {const, linear} null space of the marginal wiggliness penalty — must be
6730    // shrinkable toward zero under its own shared variance. The wiggliness
6731    // penalty `S_wiggle` shapes curvature but leaves the per-group intercept and
6732    // slope (its null space) completely UNPENALIZED. With the null space free,
6733    // each group fits its own intercept and slope with NO partial pooling, so
6734    // the held-out per-subject forecast inherits the full no-pooling variance
6735    // and curves away from the true per-group line (gam#712 real arm, gam#713;
6736    // gam#903 sleepstudy forecast ran ~74% over the lme4 BLUP bar).
6737    //
6738    // mgcv's `bs="fs"` fixes this by penalizing each null-space dimension
6739    // SEPARATELY (`smooth.construct.fs.smooth.spec` adds one rank-1 penalty per
6740    // null coordinate), each replicated block-diagonally across levels under a
6741    // single shared smoothing parameter — so REML fits a distinct
6742    // random-intercept variance and random-slope variance, the partial pooling
6743    // that makes the forecast track lme4's correlated random-effect BLUP. A
6744    // single *combined* null penalty (one λ for intercept+slope together) cannot
6745    // express the typically very different intercept and slope variances, which
6746    // is the residual forecast gap. We mirror mgcv exactly: for each orthonormal
6747    // null direction `z_k` of the marginal wiggliness penalty add
6748    // `I_L ⊗ (z_k z_kᵀ)` as its own penalty. The marginal's combined double
6749    // penalty was disabled above, so the null space is penalized once, per
6750    // dimension. With linear data REML drives the curvature λ up and degrades
6751    // `fs` to a linear random slope (edf → ≈2/group); with genuine curvature the
6752    // wiggliness λ stays small and the wiggle survives (data-adaptive, not a
6753    // cap). Gated by `m_null_penalty_orders`: order ≥ 1 (default) enables the
6754    // per-dimension null penalties; `m=0` keeps the legacy combined double
6755    // penalty and adds nothing here.
6756    if use_per_dim_null
6757        && let Some(Some(z)) = inner.null_eigenvectors.first()
6758        && z.nrows() == p
6759    {
6760        for k in 0..z.ncols() {
6761            // Rank-1 marginal penalty `z_k z_kᵀ`, replicated block-diagonally
6762            // across levels into `I_L ⊗ (z_k z_kᵀ)`. Its own λ is one shared
6763            // variance for this null component (intercept or slope) across all
6764            // groups — the random-effect structure of mgcv `fs`.
6765            let zk = z.column(k);
6766            let mut p_k = Array2::<f64>::zeros((p, p));
6767            for a in 0..p {
6768                for b in 0..p {
6769                    p_k[[a, b]] = zk[a] * zk[b];
6770                }
6771            }
6772            let mut s_null = Array2::<f64>::zeros((q, q));
6773            for level in 0..n_levels {
6774                let start = level * p;
6775                s_null
6776                    .slice_mut(s![start..start + p, start..start + p])
6777                    .assign(&p_k);
6778            }
6779            let (s_null, null_scale) = normalize_penalty_in_constrained_space(&s_null);
6780            let null_block = crate::basis::analyze_penalty_block_with_op(&s_null, None)?;
6781            if null_block.rank > 0 {
6782                let original_index = penalties.len();
6783                penalties.push(null_block.sym_penalty);
6784                nullspaces.push(null_block.nullity);
6785                penaltyinfo.push(PenaltyInfo {
6786                    source: PenaltySource::Primary,
6787                    original_index,
6788                    active: true,
6789                    effective_rank: null_block.rank,
6790                    dropped_reason: None,
6791                    nullspace_dim_hint: null_block.nullity,
6792                    normalization_scale: null_scale,
6793                    kronecker_factors: None,
6794                });
6795            }
6796        }
6797    }
6798    let null_eigenvectors = crate::basis::recompute_null_eigenvectors(&penalties)?;
6799    let joint_null_rotation = crate::basis::compute_joint_null_rotation(&penalties)?;
6800
6801    // Metadata: carry the marginal knot geometry + frozen levels so prediction
6802    // reconstructs an identical replicated design.
6803    let (knots, degree, periodic) = match &inner.metadata {
6804        BasisMetadata::BSpline1D {
6805            knots,
6806            periodic,
6807            degree,
6808            ..
6809        } => (
6810            knots.clone(),
6811            degree.unwrap_or(spec.marginal.degree),
6812            *periodic,
6813        ),
6814        other => {
6815            crate::bail_invalid_basis!(
6816                "factor smooth term '{}' produced an unexpected marginal metadata variant {:?}",
6817                term_name,
6818                other
6819            );
6820        }
6821    };
6822    let flavour_tag = match &spec.flavour {
6823        FactorSmoothFlavour::Fs { .. } => "fs",
6824        FactorSmoothFlavour::Sz => "sz",
6825        FactorSmoothFlavour::Re => "re",
6826    }
6827    .to_string();
6828    let metadata = BasisMetadata::FactorSmooth {
6829        continuous_cols: spec.continuous_cols.clone(),
6830        group_col,
6831        knots,
6832        degree,
6833        periodic,
6834        group_levels: levels,
6835        flavour: flavour_tag,
6836        // fs/re marginals are always B-spline; the cr marginal is sz-only and
6837        // handled on the dedicated Sz path above.
6838        marginal_is_cr: false,
6839    };
6840
6841    let ops = vec![None; penalties.len()];
6842    Ok(LocalSmoothTermBuild {
6843        dim: q,
6844        design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense)),
6845        penalties,
6846        ops,
6847        nullspaces,
6848        null_eigenvectors,
6849        joint_null_rotation,
6850        penaltyinfo,
6851        pre_dropped_penaltyinfo: Vec::new(),
6852        metadata,
6853        linear_constraints: None,
6854        box_reparam: false,
6855        kronecker_factored: None,
6856    })
6857}
6858
6859/// Resolve the grouping levels for a factor smooth: replay the frozen level
6860/// list when present (predict path), otherwise discover the sorted unique bit
6861/// patterns of the factor column (fit path).
6862pub fn resolve_factor_smooth_levels(
6863    data: ArrayView2<'_, f64>,
6864    group_col: usize,
6865    spec: &FactorSmoothSpec,
6866    term_name: &str,
6867) -> Result<Vec<u64>, BasisError> {
6868    if let Some(frozen) = &spec.group_frozen_levels {
6869        if frozen.is_empty() {
6870            crate::bail_invalid_basis!(
6871                "factor smooth term '{}' has an empty frozen level list",
6872                term_name
6873            );
6874        }
6875        return Ok(frozen.clone());
6876    }
6877    let mut bits: Vec<u64> = data.column(group_col).iter().map(|v| v.to_bits()).collect();
6878    bits.sort_by(|a, b| {
6879        f64::from_bits(*a)
6880            .partial_cmp(&f64::from_bits(*b))
6881            .unwrap_or(std::cmp::Ordering::Equal)
6882    });
6883    bits.dedup();
6884    Ok(bits)
6885}
6886
6887/// Marginal B-spline spec for a factor-smooth block. The marginal always builds
6888/// without an identifiability constraint (the per-level replication, not a
6889/// sum-to-zero side constraint, provides identifiability against the parametric
6890/// block). At predict time the marginal's knot geometry has already been pinned
6891/// into `marginal.knotspec` by the metadata replay, so the spec is used
6892/// verbatim aside from clearing the identifiability transform.
6893pub fn factor_smooth_marginal_for_replay(marginal: &BSplineBasisSpec) -> BSplineBasisSpec {
6894    let mut m = marginal.clone();
6895    m.identifiability = BSplineIdentifiability::None;
6896    m
6897}
6898
6899pub fn build_single_local_smooth_term(
6900    data: ArrayView2<'_, f64>,
6901    term: &SmoothTermSpec,
6902    workspace: &mut crate::basis::BasisWorkspace,
6903) -> Result<LocalSmoothTermBuild, BasisError> {
6904    if term.shape != ShapeConstraint::None && !shape_supports_basis(term) {
6905        crate::bail_invalid_basis!(
6906            "ShapeConstraint::{:?} is unsupported for term '{}'",
6907            term.shape,
6908            term.name
6909        );
6910    }
6911    if let SmoothBasisSpec::ByVariable {
6912        inner,
6913        by_col,
6914        kind,
6915        by,
6916    } = &term.basis
6917    {
6918        ensure_by_variable_specs_match(kind, by, &term.name)?;
6919        let mut inner_basis = (**inner).clone();
6920        // Factor-level `by=` owns model-space centering (it centers the gated
6921        // block against the level indicator downstream). Defer the inner
6922        // basis's default pooled centering so the level block is not
6923        // double-centered down to `k-2` columns (#1427). Numeric-by smooths are
6924        // untouched: they are not row-gated to a level and keep ordinary
6925        // intercept centering.
6926        if matches!(by, ByVariableSpec::Level { .. }) {
6927            defer_inner_model_centering_to_factor_level_wrapper(&mut inner_basis);
6928        }
6929        let inner_term = SmoothTermSpec {
6930            name: term.name.clone(),
6931            basis: inner_basis,
6932            shape: term.shape,
6933            joint_null_rotation: None,
6934        };
6935        let built = build_single_local_smooth_term(data, &inner_term, workspace)?;
6936        return apply_by_variable_to_local_build(built, data, *by_col, by, &term.name);
6937    }
6938
6939    // BySmooth: a `by=` smooth that unifies numeric or factor modulation into a
6940    // single term.  Lower it here so the downstream match does not need an arm.
6941    if let SmoothBasisSpec::BySmooth { smooth, by_kind } = &term.basis {
6942        return build_by_smooth_local(data, term, smooth, by_kind, workspace);
6943    }
6944
6945    let mut shape_axis_col: Option<usize> = None;
6946    let mut built: BasisBuildResult = match &term.basis {
6947        SmoothBasisSpec::FactorSumToZero {
6948            inner,
6949            by_col,
6950            levels,
6951            ..
6952        } => {
6953            if *by_col >= data.ncols() {
6954                crate::bail_dim_basis!(
6955                    "term '{}' by column {} out of bounds for {} columns",
6956                    term.name,
6957                    by_col,
6958                    data.ncols()
6959                );
6960            }
6961            if levels.len() < 2 {
6962                crate::bail_invalid_basis!(
6963                    "sum-to-zero factor smooth term '{}' requires at least two levels",
6964                    term.name
6965                );
6966            }
6967            if term.shape != ShapeConstraint::None {
6968                crate::bail_invalid_basis!(
6969                    "ShapeConstraint::{:?} is unsupported for sum-to-zero factor smooth term '{}'",
6970                    term.shape,
6971                    term.name
6972                );
6973            }
6974            let inner_term = SmoothTermSpec {
6975                name: format!("{}::inner", term.name),
6976                basis: (**inner).clone(),
6977                shape: ShapeConstraint::None,
6978                joint_null_rotation: None,
6979            };
6980            let mut inner_built = build_single_local_smooth_term(data, &inner_term, workspace)?;
6981            // Capture the marginal penalty's null directions BEFORE the penalty
6982            // vector is rebuilt below; the sum-to-zero null-space ridge replicates
6983            // these `z_k` into the contrast space (mgcv `bs="fs"` double-penalty).
6984            let inner_null_eigenvectors = inner_built.null_eigenvectors.clone();
6985            let base = inner_built
6986                .design
6987                .try_to_dense_by_chunks("sum-to-zero factor smooth")
6988                .map_err(BasisError::InvalidInput)?;
6989            let n = base.nrows();
6990            let p = base.ncols();
6991            let l_minus_one = levels.len() - 1;
6992            let mut dense = Array2::<f64>::zeros((n, p * l_minus_one));
6993            for i in 0..n {
6994                let bits = data[[i, *by_col]].to_bits();
6995                let level_idx = levels.iter().position(|b| *b == bits).ok_or_else(|| {
6996                    BasisError::InvalidInput(format!(
6997                        "sum-to-zero factor smooth term '{}' saw an unseen level at row {}",
6998                        term.name,
6999                        i + 1
7000                    ))
7001                })?;
7002                if level_idx < l_minus_one {
7003                    let start = level_idx * p;
7004                    dense
7005                        .slice_mut(s![i, start..start + p])
7006                        .assign(&base.row(i));
7007                } else {
7008                    for level in 0..l_minus_one {
7009                        let start = level * p;
7010                        dense
7011                            .slice_mut(s![i, start..start + p])
7012                            .assign(&base.row(i).mapv(|v| -v));
7013                    }
7014                }
7015            }
7016            let mut penalties = Vec::<Array2<f64>>::with_capacity(inner_built.penalties.len());
7017            let active_penalty_indices = inner_built
7018                .penaltyinfo
7019                .iter()
7020                .enumerate()
7021                .filter_map(|(idx, info)| info.active.then_some(idx))
7022                .collect::<Vec<_>>();
7023            if active_penalty_indices.len() != inner_built.penalties.len() {
7024                crate::bail_invalid_basis!(
7025                    "internal sz penalty metadata mismatch: activeinfos={}, penalties={}",
7026                    active_penalty_indices.len(),
7027                    inner_built.penalties.len()
7028                );
7029            }
7030            // Replicate each marginal penalty into the sum-to-zero contrast
7031            // space. With `L-1` free deviation blocks and the reference level
7032            // `d_L = -Σ_{k<L} d_k`, the marginal penalty summed over ALL `L`
7033            // levels, `Σ_{k=1}^{L} d_kᵀ S d_k`, expands to the `(I + 11ᵀ) ⊗ S`
7034            // contrast form (factor 2 on the diagonal blocks, 1 off-diagonal) —
7035            // the exact penalty the zero-sum reparameterization induces.
7036            let stz_contrast_penalty = |s_inner: &Array2<f64>| -> Array2<f64> {
7037                let mut s_big = Array2::<f64>::zeros((p * l_minus_one, p * l_minus_one));
7038                for a in 0..l_minus_one {
7039                    for b in 0..l_minus_one {
7040                        let factor = if a == b { 2.0 } else { 1.0 };
7041                        let mut block = s_big.slice_mut(s![a * p..(a + 1) * p, b * p..(b + 1) * p]);
7042                        block.assign(&s_inner.mapv(|v| v * factor));
7043                    }
7044                }
7045                s_big
7046            };
7047            // One nullspace-dim entry per emitted penalty (must stay parallel to
7048            // `penalties`); the wiggliness blocks inherit the marginal nullity
7049            // scaled by `l_minus_one`, the null ridges record their own nullity.
7050            let mut nullspaces = Vec::<usize>::with_capacity(penalties.capacity());
7051            for (penalty_pos, s_inner) in inner_built.penalties.iter().enumerate() {
7052                let (s_big, factor_smooth_scale) =
7053                    normalize_penalty_in_constrained_space(&stz_contrast_penalty(s_inner));
7054                let info_idx = active_penalty_indices[penalty_pos];
7055                inner_built.penaltyinfo[info_idx].normalization_scale *= factor_smooth_scale;
7056                penalties.push(s_big);
7057                nullspaces.push(
7058                    inner_built
7059                        .nullspaces
7060                        .get(penalty_pos)
7061                        .copied()
7062                        .unwrap_or(0)
7063                        .saturating_mul(l_minus_one),
7064                );
7065            }
7066
7067            // Null-space ridge, mirroring the `bs="fs"` double-penalty
7068            // construction (#1605, same defect class as #700/#712/#713). The
7069            // marginal wiggliness penalty `S` shapes curvature but leaves the
7070            // {const, linear} null space of each deviation curve COMPLETELY
7071            // unpenalized. With that null space free, the single combined
7072            // wiggliness smoothing parameter cannot separate the per-group
7073            // intercept/slope variance from the curvature variance, so REML
7074            // parks the wiggliness `λ` high — over-smoothing (under-fitting) the
7075            // deviation blocks even when the truth lives in their span (the `sz`
7076            // recovery gap vs the `fs` superset). mgcv's `bs="fs"` fixes the
7077            // analogous gap by penalizing each null-space dimension SEPARATELY
7078            // under its own shared variance; we mirror that here while keeping
7079            // the zero-sum reparameterization, so the constraint (and the
7080            // identifiability of `sz` vs `fs`) is preserved. For each orthonormal
7081            // null direction `z_k` of the marginal penalty, add the rank-1
7082            // marginal penalty `z_k z_kᵀ` mapped into the SAME `(I + 11ᵀ)`
7083            // sum-to-zero contrast space, each carrying its own `λ`.
7084            if let Some(Some(z)) = inner_null_eigenvectors.first()
7085                && z.nrows() == p
7086            {
7087                for k in 0..z.ncols() {
7088                    let zk = z.column(k);
7089                    let mut p_k = Array2::<f64>::zeros((p, p));
7090                    for a in 0..p {
7091                        for b in 0..p {
7092                            p_k[[a, b]] = zk[a] * zk[b];
7093                        }
7094                    }
7095                    let (s_null, null_scale) =
7096                        normalize_penalty_in_constrained_space(&stz_contrast_penalty(&p_k));
7097                    let null_block = crate::basis::analyze_penalty_block_with_op(&s_null, None)?;
7098                    if null_block.rank > 0 {
7099                        let original_index = penalties.len();
7100                        penalties.push(null_block.sym_penalty);
7101                        nullspaces.push(null_block.nullity);
7102                        inner_built.penaltyinfo.push(PenaltyInfo {
7103                            source: PenaltySource::Primary,
7104                            original_index,
7105                            active: true,
7106                            effective_rank: null_block.rank,
7107                            dropped_reason: None,
7108                            nullspace_dim_hint: null_block.nullity,
7109                            normalization_scale: null_scale,
7110                            kronecker_factors: None,
7111                        });
7112                    }
7113                }
7114            }
7115            inner_built.dim = p * l_minus_one;
7116            inner_built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense));
7117            inner_built.penalties = penalties;
7118            inner_built.ops = vec![None; inner_built.penalties.len()];
7119            inner_built.nullspaces = nullspaces;
7120            // Invariant: `null_eigenvectors[k]` must mirror `penalties[k]`'s
7121            // spectral null space. We just rebuilt `inner_built.penalties` from
7122            // Kronecker-like `S_big` blocks, so the previously-plumbed
7123            // `null_eigenvectors` (still parallel to the OLD per-level penalty)
7124            // is stale. Recompute from the rebuilt penalties to restore the
7125            // invariant; ditto for the joint-null absorption rotation.
7126            inner_built.null_eigenvectors =
7127                crate::basis::recompute_null_eigenvectors(&inner_built.penalties)?;
7128            inner_built.joint_null_rotation =
7129                crate::basis::compute_joint_null_rotation(&inner_built.penalties)?;
7130            inner_built.kronecker_factored = None;
7131            return Ok(inner_built);
7132        }
7133        SmoothBasisSpec::BSpline1D { feature_col, spec } => {
7134            if *feature_col >= data.ncols() {
7135                crate::bail_dim_basis!(
7136                    "term '{}' feature column {} out of bounds for {} columns",
7137                    term.name,
7138                    feature_col,
7139                    data.ncols()
7140                );
7141            }
7142            let mut spec_local = spec.clone();
7143            if term.shape != ShapeConstraint::None {
7144                // Shape-constrained B-splines are anchored by construction.
7145                // Sum-to-zero side constraints conflict with monotonic/convex cones.
7146                spec_local.identifiability = BSplineIdentifiability::None;
7147            }
7148            // Endpoint boundary conditions are structural for B-splines: the
7149            // basis builder bakes their homogeneous nullspace transform into
7150            // the design, penalties, and stored raw-basis transform.
7151            build_bspline_basis_1d(data.column(*feature_col), &spec_local)?
7152        }
7153        SmoothBasisSpec::ThinPlate {
7154            feature_cols,
7155            spec,
7156            input_scales,
7157        } => {
7158            if term.shape != ShapeConstraint::None {
7159                if feature_cols.len() != 1 {
7160                    crate::bail_invalid_basis!(
7161                        "ShapeConstraint::{:?} for term '{}' on ThinPlate basis requires exactly 1 feature axis; found {}",
7162                        term.shape,
7163                        term.name,
7164                        feature_cols.len()
7165                    );
7166                }
7167                shape_axis_col = Some(feature_cols[0]);
7168            }
7169            let mut x = select_columns(data, feature_cols)?;
7170            // Auto-standardize multivariate inputs: use stored scales (prediction)
7171            // or compute fresh ones (training). Same standardization-vs-
7172            // length-scale compensation as Matérn / hybrid Duchon: divide
7173            // the user's L by σ_geom so kernel(‖x_std − c_std‖/L_eff)
7174            // matches the original-coord kernel for uniform σ.
7175            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7176                apply_input_standardization(&mut x, s);
7177                (
7178                    Some(s.clone()),
7179                    compensate_length_scale_for_standardization(spec.length_scale, s),
7180                )
7181            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7182                apply_input_standardization(&mut x, &s);
7183                let l_eff = compensate_length_scale_for_standardization(spec.length_scale, &s);
7184                (Some(s), l_eff)
7185            } else {
7186                (None, spec.length_scale)
7187            };
7188            let mut spec_local = spec.clone();
7189            spec_local.length_scale = length_scale_eff;
7190            if matches!(
7191                spec_local.identifiability,
7192                SpatialIdentifiability::OrthogonalToParametric
7193            ) {
7194                spec_local.identifiability = SpatialIdentifiability::None;
7195            }
7196            let mut result = build_thin_plate_basis(x.view(), &spec_local).map_err(|err| {
7197                rewrite_thin_plate_knots_error(err, &term.name, feature_cols.len(), spec)
7198            })?;
7199            // Inject input scales into metadata; also restore the user's
7200            // original length_scale (not the σ_geom-compensated one) so a
7201            // metadata-driven rebuild that re-applies compensation does not
7202            // double-divide. The build may auto-promote to Duchon when
7203            // canonical TPS is infeasible (k < polynomial-nullspace size);
7204            // in that case patch the Duchon metadata variant so predict-time
7205            // round-trips through the same standardized data path.
7206            match &mut result.metadata {
7207                BasisMetadata::ThinPlate {
7208                    input_scales: ms,
7209                    length_scale,
7210                    ..
7211                } => {
7212                    *ms = scales;
7213                    *length_scale = spec.length_scale;
7214                }
7215                BasisMetadata::Duchon {
7216                    input_scales: ms,
7217                    length_scale,
7218                    ..
7219                } => {
7220                    *ms = scales;
7221                    // The ThinPlate auto-promotion path delegates to
7222                    // `build_duchon_basis` with `Some(spec_local.length_scale)`,
7223                    // which is the σ_geom-compensated value. The metadata
7224                    // therefore records the compensated kernel range, but the
7225                    // freeze→replay round trip plugs that value back into a
7226                    // user-facing `DuchonBasisSpec.length_scale` whose builder
7227                    // applies the σ_geom compensation a second time. Restore
7228                    // the user-facing scale here so replay re-compensates
7229                    // exactly once and reproduces the realized fit-time basis.
7230                    *length_scale = Some(spec.length_scale);
7231                }
7232                _ => {}
7233            }
7234            result
7235        }
7236        SmoothBasisSpec::Sphere { feature_cols, spec } => {
7237            if term.shape != ShapeConstraint::None {
7238                crate::bail_invalid_basis!(
7239                    "ShapeConstraint::{:?} for term '{}' is not supported on spherical splines",
7240                    term.shape,
7241                    term.name
7242                );
7243            }
7244            let x = select_columns(data, feature_cols)?;
7245            build_spherical_spline_basis(x.view(), spec)?
7246        }
7247        SmoothBasisSpec::ConstantCurvature { feature_cols, spec } => {
7248            if term.shape != ShapeConstraint::None {
7249                crate::bail_invalid_basis!(
7250                    "ShapeConstraint::{:?} for term '{}' is not supported on constant-curvature smooths",
7251                    term.shape,
7252                    term.name
7253                );
7254            }
7255            // Chart coordinates are consumed verbatim: NO auto-standardization.
7256            // Rescaling axes would change the chart gauge `1 + κ‖x‖²` and
7257            // silently redefine which curvature κ refers to (the same point
7258            // cloud at a different chart scale has a different κ̂); the user's
7259            // coordinates ARE the geometry here, exactly as for the sphere
7260            // smooth's (lat, lon).
7261            let x = select_columns(data, feature_cols)?;
7262            build_constant_curvature_basis(x.view(), spec)?
7263        }
7264        SmoothBasisSpec::MeasureJet {
7265            feature_cols,
7266            spec,
7267            input_scales,
7268        } => {
7269            if term.shape != ShapeConstraint::None {
7270                crate::bail_invalid_basis!(
7271                    "ShapeConstraint::{:?} for term '{}' is not supported on measure-jet smooths",
7272                    term.shape,
7273                    term.name
7274                );
7275            }
7276            let mut x = select_columns(data, feature_cols)?;
7277            // Matern-style per-axis standardization; the realized σ vector is
7278            // persisted into the metadata for predict-time replay.
7279            //
7280            // Length-scale round-trip contract (owning statement; the freeze
7281            // and frozen-validation arms reference it): `input_scales: Some`
7282            // marks the REPLAY path — the frozen length_scale is already the
7283            // realized post-standardization value and passes through
7284            // verbatim. Fresh path: an explicit user length_scale is in
7285            // ORIGINAL coordinates and gets the σ_geom compensation; the 0.0
7286            // auto sentinel passes through (auto-derivation runs inside the
7287            // builder, post-standardization).
7288            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7289                apply_input_standardization(&mut x, s);
7290                (Some(s.clone()), spec.length_scale)
7291            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7292                apply_input_standardization(&mut x, &s);
7293                let l_eff = if spec.length_scale > 0.0 {
7294                    compensate_length_scale_for_standardization(spec.length_scale, &s)
7295                } else {
7296                    spec.length_scale
7297                };
7298                (Some(s), l_eff)
7299            } else {
7300                (None, spec.length_scale)
7301            };
7302            let mut spec_local = spec.clone();
7303            spec_local.length_scale = length_scale_eff;
7304            let mut result = build_measure_jet_basis(x.view(), &spec_local)?;
7305            if let BasisMetadata::MeasureJet {
7306                input_scales: ms, ..
7307            } = &mut result.metadata
7308            {
7309                *ms = scales;
7310            }
7311            result
7312        }
7313        SmoothBasisSpec::Matern {
7314            feature_cols,
7315            spec,
7316            input_scales,
7317        } => {
7318            if term.shape != ShapeConstraint::None {
7319                if feature_cols.len() != 1 {
7320                    crate::bail_invalid_basis!(
7321                        "ShapeConstraint::{:?} for term '{}' on Matern basis requires exactly 1 feature axis; found {}",
7322                        term.shape,
7323                        term.name,
7324                        feature_cols.len()
7325                    );
7326                }
7327                shape_axis_col = Some(feature_cols[0]);
7328            }
7329            let mut x = select_columns(data, feature_cols)?;
7330            // Auto-standardization (per-axis division by σ_a) reinterprets
7331            // the user's `length_scale` from original data coordinates
7332            // into post-standardization coordinates: for uniform σ_a = σ,
7333            // `kernel(‖x_std − c_std‖/L)` equals `kernel(‖x − c‖/(σ·L))`,
7334            // so the effective kernel range shrinks by σ. To keep
7335            // `length_scale` consistently expressed in *original* data
7336            // coordinates regardless of axis variances, we standardize
7337            // and divide L by σ_geom = (∏σ_a)^(1/d). For uniform σ this
7338            // recovers the user's kernel exactly; for anisotropic data
7339            // the resulting per-axis effective scales σ_a / σ_geom are
7340            // the standard Mahalanobis preconditioning and preserve the
7341            // geometric-mean kernel range. Storing the σ vector in
7342            // metadata.input_scales makes the same transformation
7343            // replayable at predict time.
7344            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7345                apply_input_standardization(&mut x, s);
7346                (
7347                    Some(s.clone()),
7348                    compensate_length_scale_for_standardization(spec.length_scale, s),
7349                )
7350            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7351                apply_input_standardization(&mut x, &s);
7352                let l_eff = compensate_length_scale_for_standardization(spec.length_scale, &s);
7353                (Some(s), l_eff)
7354            } else {
7355                (None, spec.length_scale)
7356            };
7357            let mut spec_local = spec.clone();
7358            spec_local.length_scale = length_scale_eff;
7359            let mut result = build_matern_basiswithworkspace(x.view(), &spec_local, workspace)?;
7360            if let BasisMetadata::Matern {
7361                input_scales,
7362                length_scale,
7363                ..
7364            } = &mut result.metadata
7365            {
7366                *input_scales = scales;
7367                *length_scale = spec.length_scale;
7368            }
7369            result
7370        }
7371        SmoothBasisSpec::Duchon {
7372            feature_cols,
7373            spec,
7374            input_scales,
7375        } => {
7376            if term.shape != ShapeConstraint::None {
7377                if feature_cols.len() != 1 {
7378                    crate::bail_invalid_basis!(
7379                        "ShapeConstraint::{:?} for term '{}' on Duchon basis requires exactly 1 feature axis; found {}",
7380                        term.shape,
7381                        term.name,
7382                        feature_cols.len()
7383                    );
7384                }
7385                shape_axis_col = Some(feature_cols[0]);
7386            }
7387            let mut x = select_columns(data, feature_cols)?;
7388            // Hybrid Duchon (length_scale=Some) is governed by the same
7389            // standardization-vs-length-scale equivalence as Matérn: the
7390            // user's `length_scale` is interpreted in original data
7391            // coordinates, but auto-standardization (per-axis division by
7392            // σ_a) reinterprets it as σ_geom · L. Pre-multiply by 1/σ_geom
7393            // so kernel(‖x_std − c_std‖/L_eff) reproduces the user's
7394            // original-coord kernel exactly for uniform σ_a, and reduces
7395            // to standard Mahalanobis preconditioning for anisotropic σ.
7396            // Pure Duchon (length_scale=None) is scale-free and needs no
7397            // compensation.
7398            let (scales, length_scale_eff) = if let Some(s) = input_scales {
7399                apply_input_standardization(&mut x, s);
7400                (
7401                    Some(s.clone()),
7402                    compensate_optional_length_scale_for_standardization(spec.length_scale, s),
7403                )
7404            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7405                apply_input_standardization(&mut x, &s);
7406                let l_eff =
7407                    compensate_optional_length_scale_for_standardization(spec.length_scale, &s);
7408                (Some(s), l_eff)
7409            } else {
7410                (None, spec.length_scale)
7411            };
7412            let mut spec_local = spec.clone();
7413            spec_local.length_scale = length_scale_eff;
7414            if matches!(
7415                spec_local.identifiability,
7416                SpatialIdentifiability::OrthogonalToParametric
7417            ) {
7418                spec_local.identifiability = SpatialIdentifiability::None;
7419            }
7420            let mut result = build_duchon_basiswithworkspace(x.view(), &spec_local, workspace)?;
7421            if let BasisMetadata::Duchon {
7422                input_scales,
7423                length_scale,
7424                ..
7425            } = &mut result.metadata
7426            {
7427                *input_scales = scales;
7428                *length_scale = spec.length_scale;
7429            }
7430            result
7431        }
7432        SmoothBasisSpec::Pca {
7433            feature_cols,
7434            basis_matrix,
7435            centered,
7436            smooth_penalty,
7437            center_mean,
7438            pca_basis_path,
7439            chunk_size,
7440        } => {
7441            if term.shape != ShapeConstraint::None {
7442                crate::bail_invalid_basis!(
7443                    "ShapeConstraint::{:?} for term '{}' is not supported on Pca basis",
7444                    term.shape,
7445                    term.name
7446                );
7447            }
7448            build_pca_smooth_basis(
7449                data,
7450                feature_cols,
7451                basis_matrix,
7452                *centered,
7453                *smooth_penalty,
7454                center_mean.as_ref(),
7455                pca_basis_path.as_ref(),
7456                *chunk_size,
7457            )?
7458        }
7459        SmoothBasisSpec::TensorBSpline { feature_cols, spec } => {
7460            build_tensor_bspline_basis(data, feature_cols, spec)?
7461        }
7462        SmoothBasisSpec::ByVariable { .. } => {
7463            crate::bail_invalid_basis!(
7464                "internal: ByVariable smooths must return before inner basis dispatch"
7465            );
7466        }
7467        SmoothBasisSpec::BySmooth { .. } => {
7468            crate::bail_invalid_basis!("internal: BySmooth smooths must be lowered to ByVariable before inner basis dispatch"
7469                    .to_string(),);
7470        }
7471        SmoothBasisSpec::FactorSmooth { spec } => {
7472            if term.shape != ShapeConstraint::None {
7473                crate::bail_invalid_basis!(
7474                    "ShapeConstraint::{:?} is unsupported for factor smooth term '{}'",
7475                    term.shape,
7476                    term.name
7477                );
7478            }
7479            return build_factor_smooth(data, spec, &term.name, workspace);
7480        }
7481    };
7482
7483    // The Matérn design ALWAYS uses the operator-collocation {mass, tension,
7484    // stiffness} penalty triplet, overriding whatever penalty
7485    // `build_matern_basis_seeded` produced for the `double_penalty` flag.
7486    //
7487    // #1074 investigated swapping this for the genuine RKHS kernel penalty
7488    // `β' K_CC β` (mgcv `bs="gp"` / fields kriging) on the theory that the
7489    // operator triplet under-smooths the rougher half-integer kernels. MSI
7490    // truth-recovery measurement REFUTED that: the kernel penalty did NOT
7491    // improve ν=3/2 recovery (`matern(x,nu=1.5)` RMSE-vs-truth stayed 0.0554)
7492    // and it REGRESSED the high-frequency-init guard — `matern(x,nu≥5/2)` on
7493    // sin(2π·8·x) collapsed (span 0.53, RMSE 0.70) because the single RKHS
7494    // norm over-smooths a high-frequency truth where the Sobolev-order operator
7495    // dials do not. The operator triplet is therefore retained as the Matérn
7496    // penalty, and the κ-optimizer re-key / ψ-derivative paths route through the
7497    // same triplet builder so the block count stays ψ-stable (#1270).
7498    if let SmoothBasisSpec::Matern { .. } = &term.basis {
7499        let (penalties, nullspace_dims, penaltyinfo) =
7500            matern_operator_penalty_triplet_from_metadata(&built.metadata)?;
7501        built.penalties = penalties;
7502        built.nullspace_dims = nullspace_dims;
7503        built.penaltyinfo = penaltyinfo;
7504    }
7505
7506    let p_local = built.design.ncols();
7507    let mut metadata = built.metadata.clone();
7508    // Extract factored Kronecker representation before consuming fields.
7509    // Invalidate it if shape transforms will be applied (they break structure).
7510    let kron_factored = if term.shape == ShapeConstraint::None {
7511        built.kronecker_factored
7512    } else {
7513        None
7514    };
7515    let mut design_t = built.design;
7516    let mut penalties_t: Vec<Array2<f64>> = built.penalties;
7517    // Ops vector parallel to `penalties_t`. Survives unchanged through the
7518    // identity path; nulled element-wise when `T^T S T` reparametrization
7519    // is applied (operator no longer bit-equivalent to the transformed
7520    // matrix); wrapped in `ScaledPenaltyOp` after Frobenius normalization.
7521    let mut ops_t: Vec<Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>> =
7522        built.ops;
7523    if matches!(
7524        spatial_identifiability_policy(term),
7525        Some(SpatialIdentifiability::OrthogonalToParametric)
7526    ) {
7527        metadata = freeze_raw_spatial_metadata(metadata, design_t.ncols());
7528    }
7529
7530    let active_penaltyinfo_t = built
7531        .penaltyinfo
7532        .iter()
7533        .filter(|info| info.active)
7534        .cloned()
7535        .collect::<Vec<_>>();
7536    let pre_dropped_penaltyinfo_t = built
7537        .penaltyinfo
7538        .iter()
7539        .filter(|info| !info.active)
7540        .cloned()
7541        .collect::<Vec<_>>();
7542    let use_box_reparam =
7543        term.shape != ShapeConstraint::None && shape_uses_box_reparameterization(&term.basis);
7544    if let Some((order, sign)) = shape_order_and_sign(term.shape)
7545        && use_box_reparam
7546    {
7547        // Order 1 (monotone): the plain first-difference cone θ_{i+1}−θ_i ≥ 0 is
7548        // the control-polygon monotonicity criterion, which is independent of
7549        // Greville-abscissa spacing (it only fixes the *sign* of consecutive
7550        // control-point gaps), so the integer-difference transform is exact.
7551        //
7552        // Order 2 (convex/concave): the plain second-difference cone is only
7553        // correct for evenly spaced Greville abscissae. gam's B-splines are
7554        // clamped (and may use quantile knots), so the abscissae are not
7555        // uniform and the geometrically-correct cone is the second *divided*
7556        // difference. Build the Greville-scaled transform so γ_{≥2} ≥ 0
7557        // certifies convexity of the function, not of the raw coefficient
7558        // index. Periodic B-splines use uniform interior knots (uniform
7559        // abscissae), where the divided differences coincide with the integer
7560        // differences up to scale, so the plain path stays exact there.
7561        let t = if order == 2 {
7562            let bspline_meta = match &metadata {
7563                BasisMetadata::BSpline1D {
7564                    knots,
7565                    degree,
7566                    periodic,
7567                    ..
7568                } if periodic.is_none() => Some((knots.clone(), degree.unwrap_or(0))),
7569                _ => None,
7570            };
7571            match bspline_meta {
7572                Some((knots, degree)) if degree >= 1 => {
7573                    let greville = crate::basis::compute_greville_abscissae(&knots, degree)?;
7574                    if greville.len() != p_local {
7575                        crate::bail_invalid_basis!(
7576                            "shape-constraint Greville abscissae count {} does not match basis dim {} for term '{}'",
7577                            greville.len(),
7578                            p_local,
7579                            term.name
7580                        );
7581                    }
7582                    convex_divided_difference_transform_matrix(&greville, sign)?
7583                }
7584                _ => cumulative_sum_transform_matrix(p_local, order, sign),
7585            }
7586        } else {
7587            cumulative_sum_transform_matrix(p_local, order, sign)
7588        };
7589        // Coefficient-side transform: wrap the design in an operator that
7590        // applies T on the coefficient side, preserving sparsity/operator
7591        // structure of the inner design.
7592        let inner_dense = match design_t {
7593            DesignMatrix::Dense(d) => d,
7594            DesignMatrix::Sparse(sp) => gam_linalg::matrix::DenseDesignMatrix::from(
7595                sp.try_to_dense_arc("shape-constrained coefficient transform")
7596                    .map_err(BasisError::InvalidInput)?,
7597            ),
7598        };
7599        let coeff_op = gam_linalg::matrix::CoefficientTransformOperator::new(inner_dense, t.clone())
7600            .map_err(|e| BasisError::InvalidInput(format!("CoefficientTransformOperator: {e}")))?;
7601        design_t = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(coeff_op)));
7602        if penalties_t.len() != active_penaltyinfo_t.len() {
7603            crate::bail_invalid_basis!(
7604                "internal box-reparam penalty/info mismatch for term '{}': penalties={}, infos={}",
7605                term.name,
7606                penalties_t.len(),
7607                active_penaltyinfo_t.len()
7608            );
7609        }
7610        // Wiggliness penalties undergo the exact congruence `S → TᵀST` (PSD
7611        // preserving). The double-penalty *nullspace shrinkage* ridge must NOT:
7612        // it is a unit-eigenvalue projector `ZZᵀ` onto null(S_wiggle) in the
7613        // β (B-spline coefficient) coordinates, and the congruence
7614        // `Tᵀ(ZZᵀ)T = (TᵀZ)(TᵀZ)ᵀ` is no longer a projector — its eigenvalues
7615        // blow up by the conditioning of the cumulative-sum `T` (cond(T) grows
7616        // with the basis dim), concentrating an enormous penalty on the leading
7617        // γ₀ "level" coordinate. REML then drives the shared λ to its ceiling
7618        // and the smooth collapses to a flat constant (#509, the over-smoothing
7619        // face). The principled fix keeps mgcv's double-penalty semantics in the
7620        // *reparametrized* space: rebuild the ridge as the unit-eigenvalue
7621        // nullspace projector of the transformed wiggliness penalty `TᵀST`, so
7622        // the double penalty shrinks exactly the unpenalized polynomial
7623        // directions of the γ-space smooth with eigenvalue 1, identical in
7624        // conditioning to the unconstrained fit.
7625        let transformed_wiggliness = penalties_t
7626            .iter()
7627            .zip(active_penaltyinfo_t.iter())
7628            .find(|(_, info)| !matches!(info.source, PenaltySource::DoublePenaltyNullspace))
7629            .map(|(s_local, _)| {
7630                let tt_s = fast_atb(&t, s_local);
7631                fast_ab(&tt_s, &t)
7632            });
7633        let mut rebuilt = Vec::with_capacity(penalties_t.len());
7634        for (s_local, info) in penalties_t.iter().zip(active_penaltyinfo_t.iter()) {
7635            if matches!(info.source, PenaltySource::DoublePenaltyNullspace) {
7636                let s_wiggle_t = transformed_wiggliness.as_ref().ok_or_else(|| {
7637                    BasisError::InvalidInput(format!(
7638                        "box-reparam term '{}' has a double-penalty ridge but no primary wiggliness penalty to derive its nullspace from",
7639                        term.name
7640                    ))
7641                })?;
7642                let ridge = crate::basis::build_nullspace_shrinkage_penalty(s_wiggle_t)?
7643                    .map(|shrink| shrink.sym_penalty)
7644                    .unwrap_or_else(|| Array2::<f64>::zeros((p_local, p_local)));
7645                rebuilt.push(ridge);
7646            } else {
7647                let tt_s = fast_atb(&t, s_local);
7648                rebuilt.push(fast_ab(&tt_s, &t));
7649            }
7650        }
7651        penalties_t = rebuilt;
7652        // T^T S T (and the rebuilt γ-space ridge) invalidate op-form
7653        // bit-equivalence; drop ops here.
7654        ops_t = vec![None; penalties_t.len()];
7655    }
7656    if penalties_t.len() != active_penaltyinfo_t.len() {
7657        crate::bail_invalid_basis!(
7658            "internal penalty metadata mismatch for term '{}': active penalties={}, active infos={}",
7659            term.name,
7660            penalties_t.len(),
7661            active_penaltyinfo_t.len()
7662        );
7663    }
7664    if ops_t.len() != penalties_t.len() {
7665        ops_t = vec![None; penalties_t.len()];
7666    }
7667    let penalty_candidates = penalties_t
7668        .into_iter()
7669        .zip(active_penaltyinfo_t.into_iter())
7670        .zip(ops_t.into_iter())
7671        .map(
7672            |((matrix, info), op_in)| -> Result<PenaltyCandidate, BasisError> {
7673                let (matrix, c_new) = normalize_penalty_in_constrained_space(&matrix);
7674                let normalization_scale = info.normalization_scale * c_new;
7675                let op_scale = 1.0 / c_new;
7676                let kronecker_scale = 1.0 / c_new;
7677                // Frobenius rescale: wrap inner op in `ScaledPenaltyOp(1/c_new)`
7678                // so `op.as_dense() == matrix` post-normalization.
7679                let scaled_op = if op_scale > 0.0 && op_scale.is_finite() {
7680                    op_in.map(|op| {
7681                        std::sync::Arc::new(crate::analytic_penalties::ScaledPenaltyOp::new(
7682                            op, op_scale,
7683                        ))
7684                            as std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>
7685                    })
7686                } else {
7687                    None
7688                };
7689                let kronecker_factors = info.kronecker_factors.map(|mut factors| {
7690                    if let Some(first) = factors.first_mut() {
7691                        first.mapv_inplace(|v| v * kronecker_scale);
7692                    }
7693                    factors
7694                });
7695                Ok(PenaltyCandidate {
7696                    nullspace_dim_hint: info.nullspace_dim_hint,
7697                    matrix,
7698                    source: info.source,
7699                    normalization_scale,
7700                    kronecker_factors,
7701                    op: scaled_op,
7702                })
7703            },
7704        )
7705        .collect::<Result<Vec<_>, _>>()?;
7706    let (penalties_t, nullspaces_t, penaltyinfo_t, null_eigenvectors_t, ops_t) =
7707        crate::basis::filter_active_penalty_candidates_with_ops(penalty_candidates)?;
7708    let shape_linear_constraints = if term.shape != ShapeConstraint::None && !use_box_reparam {
7709        let axis = shape_axis_col.ok_or_else(|| {
7710            BasisError::InvalidInput(format!(
7711                "internal shape-constraint axis missing for term '{}'",
7712                term.name
7713            ))
7714        })?;
7715        let (x_shape_eval, design_shape_eval) =
7716            build_shape_constraint_design_1d(data, term, &metadata, axis)?;
7717        build_shape_linear_constraints_1d(
7718            x_shape_eval.view(),
7719            design_shape_eval.view(),
7720            term.shape,
7721        )?
7722    } else {
7723        None
7724    };
7725    let linear_constraints_local = merge_linear_constraints_global(shape_linear_constraints, None);
7726
7727    // Joint-null absorption rotation. Fresh fit specs compute Q from the final
7728    // per-smooth penalty set (after all in-smooth reparameterizations have
7729    // already been applied). Frozen specs already carry the complete realized
7730    // coefficient chart in their `FrozenTransform`; recomputing Q there would
7731    // rotate an already-frozen chart a second time and desynchronize value
7732    // rebuilds from derivative operators.
7733    //
7734    // Kronecker-factored smooths (tensor B-splines under `TensorBSplineIdentifiability::None`)
7735    // carry their joint penalty as `Σ_d S_d` with `S_d = I ⊗ … ⊗ S_d^{1D} ⊗ … ⊗ I`.
7736    // The joint null space is the tensor of marginal nulls and is handled directly
7737    // by the REML runtime's `kronecker_penalty_system` path (see
7738    // `runtime.rs:8334-8344`). Applying a dense (p × p) Q here would densify
7739    // `X_raw = mx ⊗ my` into `X_raw · Q`, destroying the Kronecker product
7740    // structure that the runtime relies on for fast log-det/derivative
7741    // assembly — and the rotation block at the wrapper site also unconditionally
7742    // wipes `kronecker_factored`, leaving the runtime to fall back to the
7743    // dense per-block log-det. Skip the rotation for Kronecker-factored terms
7744    // so the factored representation survives end-to-end.
7745    let joint_null_rotation = match term.joint_null_rotation.clone() {
7746        Some(persisted) => Some(persisted),
7747        None if smooth_has_frozen_identifiability(term) => None,
7748        None if kron_factored.is_some() => None,
7749        None => crate::basis::compute_joint_null_rotation(&penalties_t)?,
7750    };
7751
7752    Ok(LocalSmoothTermBuild {
7753        dim: p_local,
7754        design: design_t,
7755        penalties: penalties_t,
7756        ops: ops_t,
7757        nullspaces: nullspaces_t,
7758        null_eigenvectors: null_eigenvectors_t,
7759        joint_null_rotation,
7760        penaltyinfo: penaltyinfo_t,
7761        pre_dropped_penaltyinfo: pre_dropped_penaltyinfo_t,
7762        metadata,
7763        linear_constraints: linear_constraints_local,
7764        box_reparam: use_box_reparam,
7765        kronecker_factored: kron_factored,
7766    })
7767}
7768
7769pub fn build_smooth_design(
7770    data: ArrayView2<'_, f64>,
7771    terms: &[SmoothTermSpec],
7772) -> Result<RawSmoothDesign, BasisError> {
7773    let mut ws = crate::basis::BasisWorkspace::new();
7774    build_smooth_design_withworkspace(data, terms, &mut ws)
7775}
7776
7777/// Like `build_smooth_design`, but honors the caller workspace policy while
7778/// building each planned smooth term with an independent per-term workspace.
7779///
7780/// Independent workspaces avoid shared mutable distance-cache state during the
7781/// parallel term build; the final design, penalties, and metadata are assembled
7782/// in the original smooth-term order.
7783pub fn build_smooth_design_withworkspace(
7784    data: ArrayView2<'_, f64>,
7785    terms: &[SmoothTermSpec],
7786    workspace: &mut crate::basis::BasisWorkspace,
7787) -> Result<RawSmoothDesign, BasisError> {
7788    validate_smooth_terms_finite_inputs(data, terms)?;
7789    build_smooth_design_withworkspace_unvalidated(data, terms, workspace)
7790}
7791
7792pub fn build_smooth_design_withworkspace_unvalidated(
7793    data: ArrayView2<'_, f64>,
7794    terms: &[SmoothTermSpec],
7795    workspace: &mut crate::basis::BasisWorkspace,
7796) -> Result<RawSmoothDesign, BasisError> {
7797    let mut planned_blocks = plan_joint_spatial_centers_for_term_blocks(data, &[terms.to_vec()])?;
7798    let planned_terms = planned_blocks.pop().ok_or_else(|| {
7799        BasisError::InvalidInput(
7800            "joint spatial center planner returned no smooth blocks".to_string(),
7801        )
7802    })?;
7803    let policy = workspace.policy().clone();
7804    let local_builds: Vec<LocalSmoothTermBuild> = {
7805        use rayon::iter::{IntoParallelIterator, ParallelIterator};
7806        planned_terms
7807            .into_par_iter()
7808            .map(|term| {
7809                let mut term_workspace = crate::basis::BasisWorkspace::with_policy(policy.clone());
7810                build_single_local_smooth_term(data, &term, &mut term_workspace)
7811            })
7812            .collect::<Result<Vec<_>, _>>()?
7813    };
7814
7815    let total_p: usize = local_builds.iter().map(|built| built.dim).sum();
7816
7817    let mut local_designs: Vec<DesignMatrix> = Vec::with_capacity(local_builds.len());
7818    let mut terms_out = Vec::<SmoothTerm>::with_capacity(terms.len());
7819    let mut penalties_global = Vec::<BlockwisePenalty>::new();
7820    let mut nullspace_dims_global = Vec::<usize>::new();
7821    let mut penaltyinfo_global = Vec::<PenaltyBlockInfo>::new();
7822    let mut dropped_penaltyinfo_global = Vec::<DroppedPenaltyBlockInfo>::new();
7823    let mut coefficient_lower_bounds = Array1::<f64>::from_elem(total_p, f64::NEG_INFINITY);
7824    let mut any_bounds = false;
7825    // Each linear-constraint row only touches the current term's column slice.
7826    // Track `(col_start, col_end, local_row_values)` and assemble the final
7827    // dense `Array2` in one pass, avoiding per-row `Array1::zeros(total_p)`
7828    // allocation plus a row-by-row copy at the end.
7829    let mut linear_constraintsrows: Vec<(usize, usize, Array1<f64>)> = Vec::new();
7830    let mut linear_constraints_b: Vec<f64> = Vec::new();
7831
7832    let mut col_start = 0usize;
7833    for (term, mut built) in terms.iter().zip(local_builds.into_iter()) {
7834        let p_local = built.dim;
7835        let col_end = col_start + p_local;
7836        let lb_local = if built.box_reparam {
7837            shape_lower_bounds_local(term.shape, p_local)
7838        } else {
7839            None
7840        };
7841
7842        // Stage-2 joint-null absorption rotation. Fired *before* the
7843        // penalty / design / global aggregation loops below so that every
7844        // subsequent reference to `built.penalties`, `built.design`, and
7845        // `built.ops` sees the post-rotation values.
7846        //
7847        // The math: when the smooth's joint penalty `Σ_k S_k` has a
7848        // non-trivial null space, eigh selects `Q = [U_range | U_null]`
7849        // with null columns at the tail. Setting `β_raw = Q · γ` and
7850        // applying:
7851        //     design        ← X · Q
7852        //     penalties[k]  ← Qᵀ · S_k · Q   (block-diag, zero null tail)
7853        // yields a model whose fitted γ is invariant to the rotation
7854        // (since likelihood depends only on `X · β_raw = X · Q · γ`), but
7855        // whose penalty is full-rank on the range columns. The large-scale
7856        // failing case (cert refusal in the joint-Newton inner solve)
7857        // resolves because `H_pen = H_loglik + S` becomes full rank on
7858        // the smooth's range columns.
7859        //
7860        // Rotation is suppressed when the smooth carries coordinate-wise
7861        // shape constraints (`lb_local` or `built.linear_constraints`):
7862        // those encode a cone in the original coordinate system and a
7863        // general orthogonal rotation breaks the cone geometry. Smooths
7864        // with shape constraints typically have full-rank joint penalty
7865        // (their structural shape comes from the cone, not from null
7866        // directions in the penalty), so suppression is rarely a loss.
7867        //
7868        // `applied_rotation` carries the Q that was applied (or `None`
7869        // if no rotation fired). It is persisted onto `SmoothTerm` below
7870        // so prediction-side `X_new_raw · Q` replay can reproduce the
7871        // exact rotation. Persistence through the saved-model artifact
7872        // is a follow-up — see the doc on `SmoothTerm.joint_null_rotation`.
7873        let applied_rotation: Option<crate::basis::JointNullRotation> = match (
7874            built.joint_null_rotation.take(),
7875            lb_local.is_some(),
7876            built.linear_constraints.is_some(),
7877        ) {
7878            (Some(rot), false, false) => {
7879                let q = &rot.rotation;
7880                let dense = built
7881                    .design
7882                    .try_to_dense_by_chunks("joint-null absorption rotation")
7883                    .map_err(BasisError::InvalidInput)?;
7884                let rotated = gam_linalg::faer_ndarray::fast_ab(&dense, q);
7885                built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(rotated));
7886                built.penalties = built
7887                    .penalties
7888                    .into_iter()
7889                    .map(|s_local| {
7890                        let qt_s = gam_linalg::faer_ndarray::fast_atb(q, &s_local);
7891                        gam_linalg::faer_ndarray::fast_ab(&qt_s, q)
7892                    })
7893                    .collect();
7894                built.ops = vec![None; built.penalties.len()];
7895                built.kronecker_factored = None;
7896                Some(rot)
7897            }
7898            (Some(_), _, _) => None,
7899            (None, _, _) => None,
7900        };
7901
7902        let activeinfos = built
7903            .penaltyinfo
7904            .iter()
7905            .filter(|info| info.active)
7906            .collect::<Vec<_>>();
7907        if activeinfos.len() != built.penalties.len() {
7908            crate::bail_invalid_basis!(
7909                "internal penalty info mismatch for term '{}': activeinfos={}, penalties={}",
7910                term.name,
7911                activeinfos.len(),
7912                built.penalties.len()
7913            );
7914        }
7915        for (((s_local, &ns), info), op_local) in built
7916            .penalties
7917            .iter()
7918            .zip(built.nullspaces.iter())
7919            .zip(activeinfos.into_iter())
7920            .zip(built.ops.iter())
7921        {
7922            let global_index = penalties_global.len();
7923            penalties_global.push(
7924                BlockwisePenalty::new(col_start..col_end, s_local.clone())
7925                    .with_op(op_local.clone()),
7926            );
7927            nullspace_dims_global.push(ns);
7928            let mut penalty = info.clone();
7929            penalty.nullspace_dim_hint = ns;
7930            penaltyinfo_global.push(PenaltyBlockInfo {
7931                global_index,
7932                termname: Some(term.name.clone()),
7933                penalty,
7934            });
7935        }
7936        for info in built.penaltyinfo.iter().filter(|info| !info.active) {
7937            dropped_penaltyinfo_global.push(DroppedPenaltyBlockInfo {
7938                termname: Some(term.name.clone()),
7939                penalty: info.clone(),
7940            });
7941        }
7942        for info in &built.pre_dropped_penaltyinfo {
7943            dropped_penaltyinfo_global.push(DroppedPenaltyBlockInfo {
7944                termname: Some(term.name.clone()),
7945                penalty: info.clone(),
7946            });
7947        }
7948
7949        if let Some(lin_local) = &built.linear_constraints {
7950            for r in 0..lin_local.a.nrows() {
7951                linear_constraintsrows.push((col_start, col_end, lin_local.a.row(r).to_owned()));
7952                linear_constraints_b.push(lin_local.b[r]);
7953            }
7954        }
7955        if let Some(lb_local) = &lb_local {
7956            coefficient_lower_bounds
7957                .slice_mut(s![col_start..col_end])
7958                .assign(lb_local);
7959            any_bounds = true;
7960        }
7961
7962        // Move the per-term design out of `built` rather than cloning it.
7963        local_designs.push(built.design);
7964
7965        terms_out.push(SmoothTerm {
7966            name: term.name.clone(),
7967            coeff_range: col_start..col_end,
7968            shape: term.shape,
7969            penalties_local: built.penalties,
7970            nullspace_dims: built.nullspaces,
7971            penaltyinfo_local: built.penaltyinfo,
7972            metadata: built.metadata,
7973            lower_bounds_local: lb_local,
7974            linear_constraints_local: built.linear_constraints,
7975            kronecker_factored: built.kronecker_factored.take(),
7976            joint_null_rotation: applied_rotation,
7977            unabsorbed_global_orthogonality: None,
7978        });
7979
7980        col_start = col_end;
7981    }
7982
7983    assert_eq!(
7984        penalties_global.len(),
7985        nullspace_dims_global.len(),
7986        "global smooth penalty/nullspace bookkeeping diverged"
7987    );
7988    assert_eq!(
7989        penalties_global.len(),
7990        penaltyinfo_global.len(),
7991        "global smooth penalty metadata bookkeeping diverged"
7992    );
7993
7994    Ok(RawSmoothDesign {
7995        term_designs: local_designs,
7996        penalties: penalties_global,
7997        nullspace_dims: nullspace_dims_global,
7998        penaltyinfo: penaltyinfo_global,
7999        dropped_penaltyinfo: dropped_penaltyinfo_global,
8000        terms: terms_out,
8001        coefficient_lower_bounds: if any_bounds {
8002            Some(coefficient_lower_bounds)
8003        } else {
8004            None
8005        },
8006        linear_constraints: if linear_constraintsrows.is_empty() {
8007            None
8008        } else {
8009            let mut a = Array2::<f64>::zeros((linear_constraintsrows.len(), total_p));
8010            for (i, (cs, ce, values)) in linear_constraintsrows.iter().enumerate() {
8011                a.row_mut(i).slice_mut(s![*cs..*ce]).assign(values);
8012            }
8013            Some(LinearInequalityConstraints {
8014                a,
8015                b: Array1::from_vec(linear_constraints_b),
8016            })
8017        },
8018    })
8019}