// gam_terms/grid_spline_2d.rs

//! Streaming scatter-add 2-D smoother: K×K tensor-product cubic B-splines
//! with the EXACT anisotropic biharmonic penalty and REML-selected λ.
//!
//! Basis. Each axis carries K equal-width cells over the data's bounding box
//! `[lo, hi]` with uniform extended knots `t_j = lo + (j−3)·h`, `h = (hi−lo)/K`,
//! giving `m = K+3` cubic B-splines per axis; the tensor product has
//! `p = (K+3)²` coefficients. A point in cell `i` activates exactly the four
//! splines `i..=i+3` per axis, hence exactly 4×4 = 16 tensor basis entries per
//! data row.
//!
//! Streaming normal equations. ONE pass over the rows `(x1, x2, y_·, w)`
//! scatter-adds `X'WX` and `X'Wy_d` (any number of response dimensions share
//! the design, the penalty, and one REML λ — the multi-output "one surface
//! smoothness" contract of the ANOVA pair component): O(n·(16² + 16·D)) work,
//! no n×p design is ever materialized. Two tensor bases overlap only when
//! both per-axis indices differ by ≤ 3, so under the row-major coefficient
//! index `g = j1·(K+3) + j2` both `X'WX` and the penalty `S` are banded with
//! half-bandwidth `3(K+3)+3`; they are stored as upper bands — O(K³) numbers.
//!
//! Penalty. The FULL anisotropic biharmonic form for the diagonal metric
//! `A = diag(a1, a2)`,
//!   `J(f) = ∫∫ a1²·f_{x1x1}² + 2·a1·a2·f_{x1x2}² + a2²·f_{x2x2}²  dx1 dx2`,
//! INCLUDING the mixed `f_{x1x2}` term (the axis-wise P-spline difference
//! shortcut drops it), assembled per knot cell by 4-point Gauss–Legendre per
//! axis. Exactness degree arithmetic: on a knot cell every basis function is
//! a single cubic polynomial per axis, so each entry of `S` is a sum over
//! cells of integrands that factorize per axis as one of value·value
//! (degree 3+3 = 6), deriv·deriv (2+2 = 4) or 2nd-deriv·2nd-deriv (1+1 = 2);
//! every channel pairs a low-degree factor on one axis with at worst the
//! degree-6 value·value factor on the other. 4-point Gauss–Legendre is exact
//! through degree 2·4−1 = 7 ≥ 6, so the assembled `S` is the EXACT integral,
//! not a quadrature approximation.
//!
//! Solve and selection. For a trial λ the bands are expanded and
//! `(X'WX + λS)c = X'Wy` is solved by dense Cholesky with the exact
//! log-determinant read off the factor. `p ≤ (32+3)² = 1225`, so the O(p³)
//! factorization costs ≲ 1 Gflop per trial and the retained factor is ≈ 12 MB
//! (1225² `f64`s); K is capped at 32 to keep that sizing contract honest.
//! λ maximizes the profiled-σ² restricted (REML) criterion
//!   `ℓ_R(λ) = −½[ log|X'WX+λS| − r·log λ + (n−3)·log σ̂²(λ) ] + const`,
//! where `r = p−3` is the penalty rank — the null space of `J` is
//! span{1, x1, x2} (the mixed term penalizes `x1·x2`, whose cross derivative
//! is 1 ≠ 0, so it is NOT in the null space), `σ̂²(λ) = (y'Wy − c'X'Wy)/(n−3)`
//! is the profiled scale, and the λ-free additive constants (`log|S|₊` on the
//! row space of S, `Σ log w`, 2π factors) are dropped: differences across λ
//! are exact REML criterion differences. Selection is the same deterministic
//! coarse-grid + golden-section scheme as `spline_scan` — no RNG, same data ⇒
//! same fit.
//!
//! Prediction. `predict(x1, x2)` builds the 16-entry basis row; the mean is
//! its dot with `c` and the variance is the Bayesian posterior
//! `σ̂²·x'(X'WX+λS)⁻¹x` through the retained Cholesky factor. Outside the
//! bounding box the boundary cell's cubic polynomial extends naturally (the
//! cell index clamps, the local coordinate does not).
55
/// Dimension of the penalty null space: span{1, x1, x2}. The mixed
/// `2·a1·a2·f_{x1x2}²` term excludes `x1·x2` (its cross derivative is 1).
/// Subtracted from the row count for the profiled-σ² degrees of freedom and
/// from the coefficient count for the penalty rank `r`.
const PENALTY_NULLITY: usize = 3;

/// Number of equally spaced points in the deterministic coarse grid of the
/// log-λ search (both interval endpoints are grid points).
const LOG_LAMBDA_GRID: usize = 25;
/// Lower end of the search interval for log λ (natural log), generous on
/// both sides.
const LOG_LAMBDA_LO: f64 = -18.0;
/// Upper end of the search interval for log λ (natural log).
const LOG_LAMBDA_HI: f64 = 18.0;
/// Golden-section refinement tolerance on log λ: refinement stops once the
/// bracket is no wider than this.
const LOG_LAMBDA_TOL: f64 = 1e-7;
/// Cholesky pivot floor below which the penalized system is declared singular.
/// A squared pivot must be finite and strictly greater than this value.
const PIVOT_FLOOR: f64 = 1e-300;
/// Dense-Cholesky sizing contract documented in the module header: the
/// largest accepted number of cells per axis.
const MAX_CELLS_PER_AXIS: usize = 32;
71
/// 4-point Gauss–Legendre nodes and weights on [−1, 1]. Exact through degree
/// 2·4−1 = 7, which dominates the degree-6 worst per-axis factor of the
/// penalty integrands (see the module header for the degree arithmetic).
/// Nodes are listed in increasing order and are symmetric about 0.
const GL4_NODES: [f64; 4] = [
    -0.861_136_311_594_052_6,
    -0.339_981_043_584_856_26,
    0.339_981_043_584_856_26,
    0.861_136_311_594_052_6,
];
/// Weights paired index-by-index with `GL4_NODES`; they sum to 2, the length
/// of the reference interval [−1, 1].
const GL4_WEIGHTS: [f64; 4] = [
    0.347_854_845_137_453_85,
    0.652_145_154_862_546_2,
    0.652_145_154_862_546_2,
    0.347_854_845_137_453_85,
];
87
/// Values of the four uniform cubic B-spline pieces that are non-zero on a
/// single knot cell, evaluated at the cell-local coordinate `u`.
///
/// Slot `m` of the result belongs to basis `cell + m`: slot 0 is the spline
/// whose support ENDS in this cell (`(1−u)³/6`) and slot 3 the one whose
/// support STARTS here (`u³/6`). For `u ∈ [0, 1]` the four slots form a
/// partition of unity (they sum to 1).
#[inline]
fn bspline_value(u: f64) -> [f64; 4] {
    let one_minus_u = 1.0 - u;
    let ending = one_minus_u * one_minus_u * one_minus_u / 6.0;
    let inner_left = (3.0 * u * u * u - 6.0 * u * u + 4.0) / 6.0;
    let inner_right = (-3.0 * u * u * u + 3.0 * u * u + 3.0 * u + 1.0) / 6.0;
    let starting = u * u * u / 6.0;
    [ending, inner_left, inner_right, starting]
}
102
/// First derivative with respect to `u` of each slot of `bspline_value`.
/// The result is in local units: divide by the cell width `h` to obtain
/// d/dx. The four slots sum to 0.
#[inline]
fn bspline_d1(u: f64) -> [f64; 4] {
    let one_minus_u = 1.0 - u;
    let ending = -0.5 * one_minus_u * one_minus_u;
    let inner_left = 0.5 * (3.0 * u * u - 4.0 * u);
    let inner_right = 0.5 * (-3.0 * u * u + 2.0 * u + 1.0);
    let starting = 0.5 * u * u;
    [ending, inner_left, inner_right, starting]
}
114
/// Second derivative with respect to `u` of each slot of `bspline_value`.
/// The result is in local units: divide by `h²` to obtain d²/dx². Every slot
/// is linear in `u` — the degree-1 factor of the quadrature-exactness
/// argument — and the four slots sum to 0.
#[inline]
fn bspline_d2(u: f64) -> [f64; 4] {
    let three_u = 3.0 * u;
    [1.0 - u, three_u - 2.0, 1.0 - three_u, u]
}
121
/// One uniform B-spline axis over `[lo, lo + cells·h]`.
#[derive(Clone, Copy, Debug)]
struct Axis {
    /// Lower edge of the first cell.
    lo: f64,
    /// Width of every cell.
    h: f64,
    /// Number of cells.
    cells: usize,
}

impl Axis {
    /// Cell index and local coordinate of `x`.
    ///
    /// Inside the box the local coordinate `u` lies in `[0, 1]`. Outside, the
    /// cell index clamps to the first or last cell while `u` leaves `[0, 1]`,
    /// which extends the boundary cell's cubic polynomial (deterministic
    /// extrapolation, no special casing).
    ///
    /// A degenerate `cells == 0` axis clamps to cell 0 instead of underflowing
    /// `cells - 1` (a panic in overflow-checked builds, a wrapped index
    /// otherwise); such an axis can reach this method through callers that
    /// forward an unvalidated cell count.
    #[inline]
    fn locate(&self, x: f64) -> (usize, f64) {
        let t = (x - self.lo) / self.h;
        let last_cell = self.cells.saturating_sub(1);
        // The float→usize cast saturates, and `max(0.0)` maps negative (and
        // NaN) positions to cell 0 before the upper clamp is applied.
        let cell = (t.floor().max(0.0) as usize).min(last_cell);
        (cell, t - cell as f64)
    }
}
141
142/// The four active cubic B-spline values of one uniform axis `(lo, h, cells)`
143/// at `x`: `(first basis index, values)`, where `values[i]` weights basis
144/// `first + i` of the `cells + 3` axis splines. Outside `[lo, lo + cells·h]`
145/// the boundary cell's cubic polynomial extends — the single convention
146/// shared by fitting, prediction, and every consumer-rebuilt basis row.
147pub fn axis_basis_at(lo: f64, h: f64, cells: usize, x: f64) -> (usize, [f64; 4]) {
148    let (cell, u) = Axis { lo, h, cells }.locate(x);
149    (cell, bspline_value(u))
150}
151
152/// The 16 active tensor-basis entries `(flat index, value)` at `(x1, x2)`.
153/// Flat indices are strictly increasing across the returned arrays.
154#[inline]
155fn basis_row(axes: &[Axis; 2], m_axis: usize, x1: f64, x2: f64) -> ([usize; 16], [f64; 16]) {
156    let (c1, u1) = axes[0].locate(x1);
157    let (c2, u2) = axes[1].locate(x2);
158    let b1 = bspline_value(u1);
159    let b2 = bspline_value(u2);
160    let mut idx = [0usize; 16];
161    let mut val = [0f64; 16];
162    for i in 0..4 {
163        for j in 0..4 {
164            idx[4 * i + j] = (c1 + i) * m_axis + (c2 + j);
165            val[4 * i + j] = b1[i] * b2[j];
166        }
167    }
168    (idx, val)
169}
170
171/// Dense lower-Cholesky in place (row-major `p×p`); returns the exact
172/// `log det` (twice the log of the pivot products). The strict upper triangle
173/// is zeroed so the buffer is exactly `L` afterwards.
174pub fn cholesky_logdet(a: &mut [f64], p: usize) -> Result<f64, String> {
175    let mut logdet = 0.0;
176    for j in 0..p {
177        let mut s = a[j * p + j];
178        for t in 0..j {
179            s -= a[j * p + t] * a[j * p + t];
180        }
181        if !(s.is_finite() && s > PIVOT_FLOOR) {
182            return Err(format!(
183                "grid spline 2d: penalized system not positive definite at pivot {j} (value {s})"
184            ));
185        }
186        let l = s.sqrt();
187        a[j * p + j] = l;
188        logdet += 2.0 * l.ln();
189        for i in j + 1..p {
190            let mut s2 = a[i * p + j];
191            for t in 0..j {
192                s2 -= a[i * p + t] * a[j * p + t];
193            }
194            a[i * p + j] = s2 / l;
195        }
196    }
197    for i in 0..p {
198        for j in i + 1..p {
199            a[i * p + j] = 0.0;
200        }
201    }
202    Ok(logdet)
203}
204
/// Solve `L Lᵀ x = b` given the stored lower factor `l` (row-major `p×p`).
///
/// Runs a forward substitution with `L` followed by a backward substitution
/// with `Lᵀ`, both in place on a copy of `b`, and returns that copy.
pub fn chol_solve(l: &[f64], p: usize, b: &[f64]) -> Vec<f64> {
    let mut x = b.to_vec();
    // Forward pass: L z = b, rows top to bottom.
    for row in 0..p {
        let l_row = &l[row * p..row * p + row + 1];
        let reduced = l_row[..row]
            .iter()
            .zip(&x[..row])
            .fold(x[row], |acc, (&lv, &zv)| acc - lv * zv);
        x[row] = reduced / l_row[row];
    }
    // Backward pass: Lᵀ x = z, rows bottom to top (column `row` of L).
    for row in (0..p).rev() {
        let reduced = (row + 1..p).fold(x[row], |acc, t| acc - l[t * p + row] * x[t]);
        x[row] = reduced / l[row * p + row];
    }
    x
}
224
/// Banded sufficient statistics of one streaming pass plus the exact penalty:
/// everything needed to evaluate the REML criterion and solve at any λ.
pub struct GridSpline2dDesign {
    /// Per-axis uniform grid (lower edge, cell width, cell count) spanning
    /// the bounding box of the streamed covariates.
    axes: [Axis; 2],
    /// Basis count per axis, `K + 3`.
    m_axis: usize,
    /// Total coefficients, `(K + 3)²`.
    p: usize,
    /// Upper half-bandwidth `3·(K+3) + 3` of both banded matrices.
    band_half: usize,
    /// Upper band of `X'WX`: entry `(g, g+d)` at `g·(band_half+1) + d`.
    gram_band: Vec<f64>,
    /// Upper band of the exact anisotropic biharmonic penalty `S`, in the
    /// same layout as `gram_band`.
    pen_band: Vec<f64>,
    /// `X'Wy_d`, one length-`p` vector per response dimension. The design
    /// (gram and penalty bands) is shared across dimensions; only these
    /// right-hand sides and the response cross-moments are per-dimension.
    rhs: Vec<Vec<f64>>,
    /// Response cross-moments `y_d'W y_e` (`D × D` row-major), for the
    /// profiled-σ² residual quadratics and the residual cross-covariance.
    cross_moments: Vec<f64>,
    /// Number of data rows streamed into the statistics above.
    n_obs: usize,
}
248
/// Internal solve product at one λ (all response dimensions share the factor).
struct Solved {
    /// Dense row-major `p×p` lower Cholesky factor of `X'WX + λS`.
    chol: Vec<f64>,
    /// `log|X'WX + λS|`, read off the factor.
    logdet: f64,
    /// Per dimension: the solution `c` of `(X'WX + λS)c = X'Wy`.
    coeffs: Vec<Vec<f64>>,
    /// Per dimension: penalized residual quadratic `y'Wy − c'X'Wy` =
    /// `‖√W(y − Xc)‖² + λ c'Sc` at the minimizer.
    rss_pen: Vec<f64>,
}
258
259impl GridSpline2dDesign {
260    /// Single-response entry: see [`Self::build_multi`].
261    pub fn build(
262        x1: &[f64],
263        x2: &[f64],
264        y: &[f64],
265        w: &[f64],
266        k: usize,
267        metric: [f64; 2],
268    ) -> Result<Self, String> {
269        Self::build_multi(x1, x2, &[y], w, k, metric)
270    }
271
272    /// One streaming pass over the rows plus the exact per-cell quadrature
273    /// assembly of the penalty. `k` is the number of cells per axis;
274    /// `metric = [a1, a2]` is the diagonal anisotropy of the biharmonic form.
275    /// `responses` holds one length-`n` response per dimension; the design,
276    /// penalty, and the REML-shared λ are common to all dimensions (one
277    /// surface smoothness), only the right-hand sides differ.
278    pub fn build_multi(
279        x1: &[f64],
280        x2: &[f64],
281        responses: &[&[f64]],
282        w: &[f64],
283        k: usize,
284        metric: [f64; 2],
285    ) -> Result<Self, String> {
286        let n = x1.len();
287        if responses.is_empty() {
288            return Err("grid spline 2d: no response dimensions supplied".to_string());
289        }
290        if x2.len() != n || w.len() != n {
291            return Err(format!(
292                "grid spline 2d: length mismatch x1={n}, x2={}, w={}",
293                x2.len(),
294                w.len()
295            ));
296        }
297        for (d, y) in responses.iter().enumerate() {
298            if y.len() != n {
299                return Err(format!(
300                    "grid spline 2d: response dimension {d} has length {} != {n}",
301                    y.len()
302                ));
303            }
304        }
305        if n <= PENALTY_NULLITY {
306            return Err(format!(
307                "grid spline 2d: needs more than {PENALTY_NULLITY} rows for the profiled REML \
308                 degrees of freedom, got {n}"
309            ));
310        }
311        if k == 0 || k > MAX_CELLS_PER_AXIS {
312            return Err(format!(
313                "grid spline 2d: k must be in 1..={MAX_CELLS_PER_AXIS} (dense Cholesky on \
314                 (k+3)² coefficients — see module sizing contract), got {k}"
315            ));
316        }
317        if !(metric[0].is_finite() && metric[0] > 0.0 && metric[1].is_finite() && metric[1] > 0.0) {
318            return Err(format!(
319                "grid spline 2d: metric diagonal must be finite and positive, got [{}, {}]",
320                metric[0], metric[1]
321            ));
322        }
323        for i in 0..n {
324            if !(x1[i].is_finite() && x2[i].is_finite()) || !(w[i] > 0.0) || !w[i].is_finite() {
325                return Err(format!(
326                    "grid spline 2d: non-finite or non-positive input at row {i} \
327                     (x1={}, x2={}, w={})",
328                    x1[i], x2[i], w[i]
329                ));
330            }
331            for (d, y) in responses.iter().enumerate() {
332                if !y[i].is_finite() {
333                    return Err(format!(
334                        "grid spline 2d: non-finite response at row {i}, dimension {d} ({})",
335                        y[i]
336                    ));
337                }
338            }
339        }
340        let mut axes = [Axis {
341            lo: 0.0,
342            h: 1.0,
343            cells: k,
344        }; 2];
345        for (axis, xs) in axes.iter_mut().zip([x1, x2]) {
346            let mut lo = f64::INFINITY;
347            let mut hi = f64::NEG_INFINITY;
348            for &v in xs {
349                lo = lo.min(v);
350                hi = hi.max(v);
351            }
352            if !(hi > lo) {
353                return Err(format!(
354                    "grid spline 2d: degenerate axis bounding box [{lo}, {hi}]"
355                ));
356            }
357            axis.lo = lo;
358            axis.h = (hi - lo) / k as f64;
359        }
360        let m_axis = k + 3;
361        let p = m_axis * m_axis;
362        let band_half = 3 * m_axis + 3;
363        let stride = band_half + 1;
364        let n_dims = responses.len();
365        let mut gram_band = vec![0.0_f64; p * stride];
366        let mut rhs = vec![vec![0.0_f64; p]; n_dims];
367        let mut cross_moments = vec![0.0_f64; n_dims * n_dims];
368
369        // ── ONE streaming pass: scatter-add X'WX (upper band) and X'Wy_d ──
370        // Each row touches exactly 16 basis entries with strictly increasing
371        // flat indices, so the in-row pair loop (a ≤ b) lands directly in the
372        // upper band: O(n·(16² + 16·D)) total work.
373        for i in 0..n {
374            let (idx, val) = basis_row(&axes, m_axis, x1[i], x2[i]);
375            let wi = w[i];
376            for (d, y) in responses.iter().enumerate() {
377                let wy = wi * y[i];
378                for e in 0..16 {
379                    rhs[d][idx[e]] += wy * val[e];
380                }
381                for (e, ye) in responses.iter().enumerate().skip(d) {
382                    cross_moments[d * n_dims + e] += wy * ye[i];
383                }
384            }
385            for a in 0..16 {
386                let base = idx[a] * stride - idx[a];
387                let wa = wi * val[a];
388                for b in a..16 {
389                    gram_band[base + idx[b]] += wa * val[b];
390                }
391            }
392        }
393        for d in 0..n_dims {
394            for e in 0..d {
395                cross_moments[d * n_dims + e] = cross_moments[e * n_dims + d];
396            }
397        }
398
399        // ── Exact penalty assembly: 4-pt Gauss–Legendre per axis per cell ──
400        // Per-axis quadrature tables (cell-independent on a uniform grid):
401        // values, d/dx (scaled 1/h), d²/dx² (scaled 1/h²) at each GL node.
402        let mut tab = [[[[0.0_f64; 4]; 4]; 3]; 2]; // [axis][channel][node][basis offset]
403        for ax in 0..2 {
404            let h = axes[ax].h;
405            for q in 0..4 {
406                let u = 0.5 * (1.0 + GL4_NODES[q]);
407                let v0 = bspline_value(u);
408                let v1 = bspline_d1(u);
409                let v2 = bspline_d2(u);
410                for e in 0..4 {
411                    tab[ax][0][q][e] = v0[e];
412                    tab[ax][1][q][e] = v1[e] / h;
413                    tab[ax][2][q][e] = v2[e] / (h * h);
414                }
415            }
416        }
417        // Channel scales: J = ∫ a1²·f11² + 2·a1·a2·f12² + a2²·f22².
418        let s11 = metric[0] * metric[0];
419        let s12 = 2.0 * metric[0] * metric[1];
420        let s22 = metric[1] * metric[1];
421        let cell_area_jac = 0.25 * axes[0].h * axes[1].h; // d(x1,x2)/d(ξ1,ξ2) on [−1,1]²
422        let mut pen_band = vec![0.0_f64; p * stride];
423        let mut r11 = [0.0_f64; 16];
424        let mut r12 = [0.0_f64; 16];
425        let mut r22 = [0.0_f64; 16];
426        let mut idx = [0usize; 16];
427        for c1 in 0..k {
428            for c2 in 0..k {
429                for i in 0..4 {
430                    for j in 0..4 {
431                        idx[4 * i + j] = (c1 + i) * m_axis + (c2 + j);
432                    }
433                }
434                for q1 in 0..4 {
435                    for q2 in 0..4 {
436                        let wq = cell_area_jac * GL4_WEIGHTS[q1] * GL4_WEIGHTS[q2];
437                        for i in 0..4 {
438                            for j in 0..4 {
439                                let e = 4 * i + j;
440                                r11[e] = tab[0][2][q1][i] * tab[1][0][q2][j];
441                                r12[e] = tab[0][1][q1][i] * tab[1][1][q2][j];
442                                r22[e] = tab[0][0][q1][i] * tab[1][2][q2][j];
443                            }
444                        }
445                        for a in 0..16 {
446                            let base = idx[a] * stride - idx[a];
447                            let (pa11, pa12, pa22) =
448                                (wq * s11 * r11[a], wq * s12 * r12[a], wq * s22 * r22[a]);
449                            for b in a..16 {
450                                pen_band[base + idx[b]] +=
451                                    pa11 * r11[b] + pa12 * r12[b] + pa22 * r22[b];
452                            }
453                        }
454                    }
455                }
456            }
457        }
458
459        Ok(GridSpline2dDesign {
460            axes,
461            m_axis,
462            p,
463            band_half,
464            gram_band,
465            pen_band,
466            rhs,
467            cross_moments,
468            n_obs: n,
469        })
470    }
471
    /// Number of cells per axis (the caller-supplied K). Read from axis 0.
    pub fn num_cells(&self) -> usize {
        self.axes[0].cells
    }

    /// Basis functions per axis, `K + 3`.
    pub fn basis_per_axis(&self) -> usize {
        self.m_axis
    }

    /// Total coefficient count `(K + 3)²`.
    pub fn num_coeffs(&self) -> usize {
        self.p
    }

    /// Lower corner of the data bounding box per axis, `[axis 0, axis 1]`.
    pub fn lower_corner(&self) -> [f64; 2] {
        [self.axes[0].lo, self.axes[1].lo]
    }

    /// Knot-cell width per axis, `[axis 0, axis 1]`.
    pub fn cell_widths(&self) -> [f64; 2] {
        [self.axes[0].h, self.axes[1].h]
    }

    /// Number of data rows the design was streamed from.
    pub fn num_rows(&self) -> usize {
        self.n_obs
    }

    /// Number of response dimensions sharing the design (one right-hand side
    /// is stored per dimension).
    pub fn num_responses(&self) -> usize {
        self.rhs.len()
    }
506
507    /// The four active cubic B-spline values of one AXIS at `x`: returns
508    /// `(j0, values)` where `values[i]` weights basis `j0 + i` of that axis
509    /// (`0..K+3`). The tensor flat index of `(j1, j2)` is `j1·(K+3) + j2` —
510    /// row-major, axis 0 major. Outside the bounding box the boundary cell's
511    /// cubic polynomial extends (same convention as fitting and prediction).
512    pub fn axis_basis(&self, axis: usize, x: f64) -> Result<(usize, [f64; 4]), String> {
513        if axis > 1 {
514            return Err(format!("grid spline 2d: axis {axis} out of range"));
515        }
516        if !x.is_finite() {
517            return Err(format!("grid spline 2d: non-finite axis-{axis} point {x}"));
518        }
519        let ax = self.axes[axis];
520        Ok(axis_basis_at(ax.lo, ax.h, ax.cells, x))
521    }
522
523    /// Exact penalty quadratic form `J(f) = c'Sc` of a coefficient vector —
524    /// the assembled anisotropic biharmonic energy of the spline it encodes.
525    pub fn penalty_value(&self, coeff: &[f64]) -> Result<f64, String> {
526        if coeff.len() != self.p {
527            return Err(format!(
528                "grid spline 2d: coefficient length {} != {}",
529                coeff.len(),
530                self.p
531            ));
532        }
533        let stride = self.band_half + 1;
534        let mut j = 0.0;
535        for g in 0..self.p {
536            let dmax = self.band_half.min(self.p - 1 - g);
537            j += self.pen_band[g * stride] * coeff[g] * coeff[g];
538            for d in 1..=dmax {
539                j += 2.0 * self.pen_band[g * stride + d] * coeff[g] * coeff[g + d];
540            }
541        }
542        Ok(j)
543    }
544
545    /// Expand `X'WX + λS` from the bands to a dense symmetric matrix.
546    fn dense_system(&self, lambda: f64) -> Vec<f64> {
547        let p = self.p;
548        let stride = self.band_half + 1;
549        let mut a = vec![0.0_f64; p * p];
550        for g in 0..p {
551            let dmax = self.band_half.min(p - 1 - g);
552            for d in 0..=dmax {
553                let v = self.gram_band[g * stride + d] + lambda * self.pen_band[g * stride + d];
554                a[g * p + g + d] = v;
555                a[(g + d) * p + g] = v;
556            }
557        }
558        a
559    }
560
561    fn solve_at(&self, log_lambda: f64) -> Result<Solved, String> {
562        if !log_lambda.is_finite() {
563            return Err(format!(
564                "grid spline 2d: non-finite log lambda {log_lambda}"
565            ));
566        }
567        let mut a = self.dense_system(log_lambda.exp());
568        let logdet = cholesky_logdet(&mut a, self.p)?;
569        let n_dims = self.rhs.len();
570        let mut coeffs = Vec::with_capacity(n_dims);
571        let mut rss_pen = Vec::with_capacity(n_dims);
572        for (d, rhs) in self.rhs.iter().enumerate() {
573            let coeff = chol_solve(&a, self.p, rhs);
574            let mut quad = 0.0;
575            for g in 0..self.p {
576                quad += rhs[g] * coeff[g];
577            }
578            rss_pen.push(self.cross_moments[d * n_dims + d] - quad);
579            coeffs.push(coeff);
580        }
581        Ok(Solved {
582            chol: a,
583            logdet,
584            coeffs,
585            rss_pen,
586        })
587    }
588
589    /// Profiled-σ² REML criterion at `log λ`, pooled across the response
590    /// dimensions sharing the design and λ, up to λ- and data-independent
591    /// additive constants (differences across λ are exact REML differences):
592    ///   `−½ Σ_d [ log|X'WX+λS| − r·log λ + (n−3)·log σ̂²_d(λ) ]`.
593    fn criterion(&self, log_lambda: f64) -> Result<f64, String> {
594        let solved = self.solve_at(log_lambda)?;
595        let dof = (self.n_obs - PENALTY_NULLITY) as f64;
596        let r = (self.p - PENALTY_NULLITY) as f64;
597        let shared = solved.logdet - r * log_lambda;
598        let mut v = 0.0;
599        for &rss in &solved.rss_pen {
600            if !(rss > 0.0) {
601                return Err(format!(
602                    "grid spline 2d: degenerate penalized residual {rss}"
603                ));
604            }
605            v += shared + dof * (rss / dof).ln();
606        }
607        Ok(-0.5 * v)
608    }
609
610    /// Fit at a FIXED `log λ`, with σ² either supplied (applied to every
611    /// response dimension) or profiled per dimension.
612    pub fn fit_at(&self, log_lambda: f64, sigma2: Option<f64>) -> Result<GridSpline2dFit, String> {
613        let solved = self.solve_at(log_lambda)?;
614        let dof = (self.n_obs - PENALTY_NULLITY) as f64;
615        let mut sigma2_dims = Vec::with_capacity(solved.rss_pen.len());
616        for &rss in &solved.rss_pen {
617            match sigma2 {
618                Some(s) => {
619                    if !(s.is_finite() && s > 0.0) {
620                        return Err(format!("grid spline 2d: invalid sigma2 {s}"));
621                    }
622                    sigma2_dims.push(s);
623                }
624                None => {
625                    if !(rss > 0.0) {
626                        return Err(format!(
627                            "grid spline 2d: degenerate penalized residual {rss}"
628                        ));
629                    }
630                    sigma2_dims.push(rss / dof);
631                }
632            }
633        }
634        // Full restricted log-likelihood at this (λ, σ²) up to λ- and σ-free
635        // constants, pooled across dimensions: at the profiled σ̂²_d the
636        // quadratic collapses to the λ-free constant `dof` per dimension,
637        // matching `criterion` up to that constant.
638        let r = (self.p - PENALTY_NULLITY) as f64;
639        let mut restricted_loglik = 0.0;
640        for (d, &rss) in solved.rss_pen.iter().enumerate() {
641            restricted_loglik -= 0.5
642                * (solved.logdet - r * log_lambda
643                    + dof * sigma2_dims[d].ln()
644                    + rss / sigma2_dims[d]);
645        }
646        Ok(GridSpline2dFit {
647            coeffs: solved.coeffs,
648            log_lambda,
649            sigma2: sigma2_dims,
650            restricted_loglik,
651            chol: solved.chol,
652            axes: self.axes,
653            m_axis: self.m_axis,
654        })
655    }
656
657    /// Fit with `log λ` selected by the profiled REML criterion: deterministic
658    /// coarse grid then golden-section refinement (no RNG — same data, same fit).
659    pub fn fit_reml(&self) -> Result<GridSpline2dFit, String> {
660        let mut best_i = 0usize;
661        let mut best_v = f64::NEG_INFINITY;
662        let step = (LOG_LAMBDA_HI - LOG_LAMBDA_LO) / (LOG_LAMBDA_GRID - 1) as f64;
663        for i in 0..LOG_LAMBDA_GRID {
664            let ll = LOG_LAMBDA_LO + step * i as f64;
665            let v = self.criterion(ll)?;
666            if v > best_v {
667                best_v = v;
668                best_i = i;
669            }
670        }
671        let mut lo = LOG_LAMBDA_LO + step * best_i.saturating_sub(1) as f64;
672        let mut hi = (LOG_LAMBDA_LO + step * (best_i + 1) as f64).min(LOG_LAMBDA_HI);
673        // Golden-section maximization on [lo, hi].
674        let inv_phi = 0.618_033_988_749_894_9_f64;
675        let mut x1 = hi - inv_phi * (hi - lo);
676        let mut x2 = lo + inv_phi * (hi - lo);
677        let mut f1 = self.criterion(x1)?;
678        let mut f2 = self.criterion(x2)?;
679        while hi - lo > LOG_LAMBDA_TOL {
680            if f1 < f2 {
681                lo = x1;
682                x1 = x2;
683                f1 = f2;
684                x2 = lo + inv_phi * (hi - lo);
685                f2 = self.criterion(x2)?;
686            } else {
687                hi = x2;
688                x2 = x1;
689                f2 = f1;
690                x1 = hi - inv_phi * (hi - lo);
691                f1 = self.criterion(x1)?;
692            }
693        }
694        self.fit_at(0.5 * (lo + hi), None)
695    }
696
697    /// `a'(X'WX)b` through the retained upper band (exact, O(p·bandwidth)).
698    fn gram_quadratic(&self, a: &[f64], b: &[f64]) -> f64 {
699        let stride = self.band_half + 1;
700        let mut q = 0.0;
701        for g in 0..self.p {
702            let dmax = self.band_half.min(self.p - 1 - g);
703            q += self.gram_band[g * stride] * a[g] * b[g];
704            for d in 1..=dmax {
705                q += self.gram_band[g * stride + d] * (a[g] * b[g + d] + a[g + d] * b[g]);
706            }
707        }
708        q
709    }
710
711    /// Posterior summary of a fit FROM THIS DESIGN, in the exact algebra of
712    /// the solved system (no approximation):
713    /// - `unit_covariance = (X'WX + λS)⁻¹` (scale-free Bayesian posterior
714    ///   covariance of the row-major coefficient vec, shared by dimensions);
715    /// - `edf = tr[(X'WX + λS)⁻¹ X'WX]` (the smoother's effective degrees of
716    ///   freedom at the fitted λ);
717    /// - `residual_cross_cov[d,e] = r_d'W r_e / (n − edf)` assembled from the
718    ///   streamed sufficient statistics
719    ///   (`y_d'Wy_e − c_d'X'Wy_e − c_e'X'Wy_d + c_d'X'WX c_e`).
720    pub fn posterior(&self, fit: &GridSpline2dFit) -> Result<GridSpline2dPosterior, String> {
721        let p = self.p;
722        let n_dims = self.rhs.len();
723        if fit.coeffs.len() != n_dims || fit.coeffs.iter().any(|c| c.len() != p) {
724            return Err(format!(
725                "grid spline 2d: posterior asked for a fit with {} dimensions of length {}, \
726                 design has {n_dims} of {p}",
727                fit.coeffs.len(),
728                fit.coeffs.first().map_or(0, Vec::len)
729            ));
730        }
731        // H⁻¹ column by column through the retained factor (symmetric, O(p³)).
732        let mut unit_covariance = vec![0.0_f64; p * p];
733        let mut e_g = vec![0.0_f64; p];
734        for g in 0..p {
735            e_g[g] = 1.0;
736            let col = chol_solve(&fit.chol, p, &e_g);
737            e_g[g] = 0.0;
738            for (r, &v) in col.iter().enumerate() {
739                unit_covariance[r * p + g] = v;
740            }
741        }
742        // edf = tr(H⁻¹ X'WX) via the gram band (diagonal once, off-band twice).
743        let stride = self.band_half + 1;
744        let mut edf = 0.0;
745        for g in 0..p {
746            let dmax = self.band_half.min(p - 1 - g);
747            edf += self.gram_band[g * stride] * unit_covariance[g * p + g];
748            for d in 1..=dmax {
749                edf += 2.0 * self.gram_band[g * stride + d] * unit_covariance[g * p + g + d];
750            }
751        }
752        let residual_df = self.n_obs as f64 - edf;
753        if !(residual_df >= 1.0) {
754            return Err(format!(
755                "grid spline 2d: too few rows for a scale estimate \
756                 (n = {}, edf = {edf:.2}; need n − edf ≥ 1)",
757                self.n_obs
758            ));
759        }
760        let mut residual_cross_cov = vec![0.0_f64; n_dims * n_dims];
761        for d in 0..n_dims {
762            for e in d..n_dims {
763                let mut cd_rhse = 0.0;
764                let mut ce_rhsd = 0.0;
765                for g in 0..p {
766                    cd_rhse += fit.coeffs[d][g] * self.rhs[e][g];
767                    ce_rhsd += fit.coeffs[e][g] * self.rhs[d][g];
768                }
769                let quad = self.gram_quadratic(&fit.coeffs[d], &fit.coeffs[e]);
770                let v =
771                    (self.cross_moments[d * n_dims + e] - cd_rhse - ce_rhsd + quad) / residual_df;
772                residual_cross_cov[d * n_dims + e] = v;
773                residual_cross_cov[e * n_dims + d] = v;
774            }
775        }
776        Ok(GridSpline2dPosterior {
777            unit_covariance,
778            edf,
779            residual_df,
780            residual_cross_cov,
781        })
782    }
783}
784
/// Exact posterior summary of a [`GridSpline2dFit`] (see
/// [`GridSpline2dDesign::posterior`]): the bridge from the streaming engine
/// to covariance-consuming clients (the ANOVA pair-component carve).
///
/// Plain owned data, so it derives `Debug` (for diagnostics) and `Clone`
/// (so clients can keep their own copy).
#[derive(Clone, Debug)]
pub struct GridSpline2dPosterior {
    /// `(X'WX + λS)⁻¹`, `p × p` row-major — scale-free posterior covariance
    /// of the row-major coefficient vec, shared by all response dimensions.
    pub unit_covariance: Vec<f64>,
    /// `tr[(X'WX + λS)⁻¹ X'WX]`.
    pub edf: f64,
    /// `n − edf`.
    pub residual_df: f64,
    /// `D × D` row-major residual cross-covariance at `n − edf`.
    pub residual_cross_cov: Vec<f64>,
}
799
/// Fitted penalized tensor-product smoother with its factored covariance.
///
/// The public fields are the fit results; the private fields (`chol`, `axes`,
/// `m_axis`) are what `predict` reads to evaluate the basis row and the
/// variance term, and are exposed for persistence through `to_state`.
pub struct GridSpline2dFit {
    /// Per response dimension: coefficients in row-major flat order
    /// `g = j1·(K+3) + j2`.
    pub coeffs: Vec<Vec<f64>>,
    /// Selected (or supplied) log smoothing parameter, shared by all
    /// response dimensions.
    pub log_lambda: f64,
    /// Per response dimension: profiled (or supplied) observation variance σ².
    pub sigma2: Vec<f64>,
    /// Pooled restricted log-likelihood at the optimum, up to λ- and
    /// data-independent additive constants (exact REML differences across λ).
    pub restricted_loglik: f64,
    /// Lower Cholesky factor of `X'WX + λS` — the factored posterior precision
    /// (unit-σ² scale) used for prediction variances, shared by all dimensions.
    chol: Vec<f64>,
    /// Per-axis grid geometry (`lo`, `h`, `cells`) handed to `basis_row` by
    /// `predict`.
    axes: [Axis; 2],
    /// Basis count per axis, `K + 3`; `from_state` rejects snapshots that
    /// violate this relation.
    m_axis: usize,
}
819
/// Serializable snapshot of a [`GridSpline2dFit`] (#1031 persistence
/// prerequisite). The grid is deliberately NOT a formula fast path — it is an
/// ANOVA pair component (#975 carve) — so there is no `FitResult` variant; this
/// state is what the carve's persistence payload serializes and what
/// `from_state` replays for an exact predict.
///
/// Predict needs the MEAN (`coeffs` + the 16-entry tensor basis row, which is a
/// pure function of `axes`/`m_axis`) and the VARIANCE
/// (`σ²·x'(X'WX+λS)⁻¹x` through the retained Cholesky factor `chol`). All of
/// that — and nothing about the training rows — lives on the fit already, so the
/// state is a verbatim snapshot: no design CSR, no re-factor on load.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct GridSpline2dState {
    /// Per response dimension: row-major coefficients `g = j1·(K+3) + j2`.
    pub coeffs: Vec<Vec<f64>>,
    /// Selected (or supplied) log smoothing parameter, shared by all response
    /// dimensions; `from_state` requires it to be finite.
    pub log_lambda: f64,
    /// Per response dimension: profiled (or supplied) observation variance σ².
    pub sigma2: Vec<f64>,
    /// Pooled restricted log-likelihood copied from the fit; `from_state`
    /// requires it to be finite.
    pub restricted_loglik: f64,
    /// Lower Cholesky factor of `X'WX + λS` (unit-σ² scale), `p × p` row-major —
    /// the factored posterior precision the variance term solves against.
    pub chol: Vec<f64>,
    /// Per axis lower corner of the basis bounding box.
    pub axis_lo: [f64; 2],
    /// Per axis cell width `h = (hi − lo)/K`.
    pub axis_h: [f64; 2],
    /// Per axis cell count `K`.
    pub axis_cells: [u64; 2],
    /// Basis count per axis, `K + 3` (so `p = m_axis²`).
    pub m_axis: u64,
}
851
852impl GridSpline2dFit {
853    /// Snapshot the fit for persistence (#1031). Verbatim — every field
854    /// `predict` reads is copied; the training design is not retained on the fit
855    /// and is not needed for replay.
856    pub fn to_state(&self) -> GridSpline2dState {
857        GridSpline2dState {
858            coeffs: self.coeffs.clone(),
859            log_lambda: self.log_lambda,
860            sigma2: self.sigma2.clone(),
861            restricted_loglik: self.restricted_loglik,
862            chol: self.chol.clone(),
863            axis_lo: [self.axes[0].lo, self.axes[1].lo],
864            axis_h: [self.axes[0].h, self.axes[1].h],
865            axis_cells: [self.axes[0].cells as u64, self.axes[1].cells as u64],
866            m_axis: self.m_axis as u64,
867        }
868    }
869
870    /// Rebuild a predict-capable fit from a snapshot (#1031). Validates shape,
871    /// finiteness, positive cell widths/counts, positive σ², and that the basis
872    /// arithmetic is self-consistent (`m_axis = K + 3`, `chol` is `p × p`,
873    /// `coeffs`/`sigma2` agree on `D`), so a corrupt payload fails here rather
874    /// than inside a later `predict`. The restored fit replays the posterior
875    /// mean+variance bit-for-bit: `predict` reads only the snapshotted fields.
876    pub fn from_state(state: &GridSpline2dState) -> Result<Self, String> {
877        let m_axis = state.m_axis as usize;
878        let p = m_axis * m_axis;
879        for a in 0..2 {
880            let cells = state.axis_cells[a] as usize;
881            if cells == 0 {
882                return Err(format!(
883                    "grid spline 2d state: axis {a} must have at least one cell"
884                ));
885            }
886            if m_axis != cells + 3 {
887                return Err(format!(
888                    "grid spline 2d state: m_axis {m_axis} must equal K+3 = {} for axis {a}",
889                    cells + 3
890                ));
891            }
892            if !(state.axis_lo[a].is_finite()
893                && state.axis_h[a].is_finite()
894                && state.axis_h[a] > 0.0)
895            {
896                return Err(format!(
897                    "grid spline 2d state: axis {a} must have finite lo and positive h, got lo={}, h={}",
898                    state.axis_lo[a], state.axis_h[a]
899                ));
900            }
901        }
902        if state.chol.len() != p * p {
903            return Err(format!(
904                "grid spline 2d state: chol must be p×p = {p}² = {}, got {}",
905                p * p,
906                state.chol.len()
907            ));
908        }
909        let d = state.coeffs.len();
910        if d == 0 || state.sigma2.len() != d {
911            return Err(format!(
912                "grid spline 2d state: need ≥1 response dimension with matching σ² (coeffs D={d}, sigma2 D={})",
913                state.sigma2.len()
914            ));
915        }
916        for (dim, c) in state.coeffs.iter().enumerate() {
917            if c.len() != p {
918                return Err(format!(
919                    "grid spline 2d state: response dimension {dim} has {} coeffs, expected p = {p}",
920                    c.len()
921                ));
922            }
923        }
924        for (dim, &s2) in state.sigma2.iter().enumerate() {
925            if !(s2.is_finite() && s2 > 0.0) {
926                return Err(format!(
927                    "grid spline 2d state: response dimension {dim} has non-positive σ² = {s2}"
928                ));
929            }
930        }
931        for (i, v) in state
932            .chol
933            .iter()
934            .chain(state.coeffs.iter().flatten())
935            .enumerate()
936        {
937            if !v.is_finite() {
938                return Err(format!("grid spline 2d state: non-finite entry at {i}"));
939            }
940        }
941        // The diagonal of a lower Cholesky factor is strictly positive; a
942        // zero/negative pivot means the persisted factor is not a valid
943        // precision factor and `chol_solve` would divide by it.
944        for g in 0..p {
945            let piv = state.chol[g * p + g];
946            if !(piv.is_finite() && piv > 0.0) {
947                return Err(format!(
948                    "grid spline 2d state: non-positive Cholesky pivot {piv} at index {g}"
949                ));
950            }
951        }
952        if !(state.log_lambda.is_finite() && state.restricted_loglik.is_finite()) {
953            return Err(format!(
954                "grid spline 2d state: invalid scalars (log_lambda={}, restricted_loglik={})",
955                state.log_lambda, state.restricted_loglik
956            ));
957        }
958        let axes = [
959            Axis {
960                lo: state.axis_lo[0],
961                h: state.axis_h[0],
962                cells: state.axis_cells[0] as usize,
963            },
964            Axis {
965                lo: state.axis_lo[1],
966                h: state.axis_h[1],
967                cells: state.axis_cells[1] as usize,
968            },
969        ];
970        Ok(GridSpline2dFit {
971            coeffs: state.coeffs.clone(),
972            log_lambda: state.log_lambda,
973            sigma2: state.sigma2.clone(),
974            restricted_loglik: state.restricted_loglik,
975            chol: state.chol.clone(),
976            axes,
977            m_axis,
978        })
979    }
980
981    /// Posterior `(mean, variance)` of response dimension `dim` at an
982    /// arbitrary point: the 16-entry basis row dotted with the coefficients,
983    /// and `σ̂²_dim·x'(X'WX+λS)⁻¹x` through the retained Cholesky factor.
984    /// Outside the bounding box the boundary cell's cubic polynomial extends.
985    pub fn predict(&self, dim: usize, x1: f64, x2: f64) -> Result<(f64, f64), String> {
986        if dim >= self.coeffs.len() {
987            return Err(format!(
988                "grid spline 2d: response dimension {dim} out of range (D = {})",
989                self.coeffs.len()
990            ));
991        }
992        if !(x1.is_finite() && x2.is_finite()) {
993            return Err(format!(
994                "grid spline 2d: non-finite prediction point ({x1}, {x2})"
995            ));
996        }
997        let (idx, val) = basis_row(&self.axes, self.m_axis, x1, x2);
998        let p = self.coeffs[dim].len();
999        let mut mean = 0.0;
1000        let mut row = vec![0.0_f64; p];
1001        for e in 0..16 {
1002            mean += val[e] * self.coeffs[dim][idx[e]];
1003            row[idx[e]] += val[e];
1004        }
1005        let z = chol_solve(&self.chol, p, &row);
1006        let mut quad = 0.0;
1007        for g in 0..p {
1008            quad += row[g] * z[g];
1009        }
1010        Ok((mean, self.sigma2[dim] * quad))
1011    }
1012}
1013
1014/// Build the streaming design and fit with REML-selected λ.
1015pub fn fit_grid_spline_2d(
1016    x1: &[f64],
1017    x2: &[f64],
1018    y: &[f64],
1019    w: &[f64],
1020    k: usize,
1021    metric: [f64; 2],
1022) -> Result<GridSpline2dFit, String> {
1023    GridSpline2dDesign::build(x1, x2, y, w, k, metric)?.fit_reml()
1024}
1025
1026/// Build the streaming design and fit at a FIXED `log λ` (σ² supplied or profiled).
1027pub fn fit_grid_spline_2d_at(
1028    x1: &[f64],
1029    x2: &[f64],
1030    y: &[f64],
1031    w: &[f64],
1032    k: usize,
1033    metric: [f64; 2],
1034    log_lambda: f64,
1035    sigma2: Option<f64>,
1036) -> Result<GridSpline2dFit, String> {
1037    GridSpline2dDesign::build(x1, x2, y, w, k, metric)?.fit_at(log_lambda, sigma2)
1038}
1039
1040#[cfg(test)]
1041mod tests {
1042    use super::*;
1043
1044    /// State → JSON → from_state replays the posterior mean+variance bit-for-bit
1045    /// at held-out points (the grid carries no training CSR, so the snapshot is
1046    /// the whole predict-capable object). This is the #1031 persistence
1047    /// prerequisite the ANOVA carve consumes.
1048    #[test]
1049    fn grid_spline_2d_state_roundtrip_reproduces_predict() {
1050        let k = 8usize;
1051        // A smooth multi-output surface on a scattered grid of points.
1052        let mut x1 = Vec::new();
1053        let mut x2 = Vec::new();
1054        let mut y0 = Vec::new();
1055        let mut y1 = Vec::new();
1056        for i in 0..24 {
1057            for j in 0..24 {
1058                let a = i as f64 / 23.0;
1059                let b = j as f64 / 23.0;
1060                x1.push(a);
1061                x2.push(b);
1062                y0.push((2.5 * a).sin() * (1.7 * b).cos() + 0.3 * a * b);
1063                y1.push(a * a - 0.5 * b + 0.2 * (3.0 * a * b).cos());
1064            }
1065        }
1066        let n = x1.len();
1067        let w = vec![1.0_f64; n];
1068        let ys: Vec<&[f64]> = vec![&y0, &y1];
1069        let fit = GridSpline2dDesign::build_multi(&x1, &x2, &ys, &w, k, [1.0, 1.0])
1070            .expect("design")
1071            .fit_reml()
1072            .expect("fit");
1073
1074        let json = serde_json::to_string(&fit.to_state()).expect("serialize");
1075        let state: GridSpline2dState = serde_json::from_str(&json).expect("deserialize");
1076        let restored = GridSpline2dFit::from_state(&state).expect("restore");
1077
1078        // Held-out points, including one outside the box to exercise the
1079        // boundary-cell polynomial extension.
1080        let probes = [
1081            (0.13, 0.77),
1082            (0.41, 0.05),
1083            (0.66, 0.92),
1084            (0.99, 0.31),
1085            (1.20, -0.10),
1086        ];
1087        for dim in 0..2 {
1088            for &(p1, p2) in &probes {
1089                let (m0, v0) = fit.predict(dim, p1, p2).expect("orig predict");
1090                let (m1, v1) = restored.predict(dim, p1, p2).expect("restored predict");
1091                assert!(
1092                    (m0 - m1).abs() <= 1e-12 * (1.0 + m0.abs()),
1093                    "mean drift dim={dim} at ({p1},{p2}): {m0} vs {m1}"
1094                );
1095                assert!(
1096                    (v0 - v1).abs() <= 1e-12 * (1.0 + v0.abs()),
1097                    "variance drift dim={dim} at ({p1},{p2}): {v0} vs {v1}"
1098                );
1099            }
1100        }
1101        assert!((fit.log_lambda - restored.log_lambda).abs() <= 0.0);
1102        assert!((fit.restricted_loglik - restored.restricted_loglik).abs() <= 0.0);
1103    }
1104
1105    /// Corrupt snapshots fail loudly in `from_state`, not inside a later predict.
1106    #[test]
1107    fn grid_spline_2d_state_rejects_corruption() {
1108        let k = 6usize;
1109        // A dense grid with n > p = (k+3)² so the fit is well-posed: this test
1110        // exercises `from_state` corruption rejection, not the small-n regime,
1111        // so the fit must succeed first (n=18 ≪ p=81 left the penalized design
1112        // rank-deficient and `fit_grid_spline_2d` refused before any assertion).
1113        let side = 12usize;
1114        let mut x1 = Vec::new();
1115        let mut x2 = Vec::new();
1116        for i in 0..side {
1117            for j in 0..side {
1118                x1.push(i as f64 / (side - 1) as f64);
1119                x2.push(j as f64 / (side - 1) as f64);
1120            }
1121        }
1122        let n = x1.len();
1123        // The response must carry genuine curvature: a purely affine `a + b`
1124        // lies entirely in the penalty NULL SPACE (the spline reproduces it
1125        // exactly at any λ), so the penalized residual is identically zero and
1126        // `fit_grid_spline_2d` correctly refuses with "degenerate penalized
1127        // residual 0" — there is no variance to estimate. Add a smooth
1128        // non-null-space (curved) component so the penalized fit leaves a
1129        // positive residual and the REML criterion is well-posed; this test is
1130        // about `from_state` corruption rejection, which needs a successful fit
1131        // first.
1132        let y: Vec<f64> = x1
1133            .iter()
1134            .zip(&x2)
1135            .map(|(&a, &b)| a + b + (3.0 * a).sin() * (2.5 * b).cos())
1136            .collect();
1137        let w = vec![1.0_f64; n];
1138        let fit = fit_grid_spline_2d(&x1, &x2, &y, &w, k, [1.0, 1.0]).expect("fit");
1139
1140        let good = fit.to_state();
1141        let mut bad = good.clone();
1142        bad.chol.pop();
1143        assert!(
1144            GridSpline2dFit::from_state(&bad).is_err(),
1145            "chol length mismatch must error"
1146        );
1147
1148        let mut bad = good.clone();
1149        bad.sigma2[0] = -1.0;
1150        assert!(
1151            GridSpline2dFit::from_state(&bad).is_err(),
1152            "non-positive σ² must error"
1153        );
1154
1155        let mut bad = good.clone();
1156        bad.m_axis += 1;
1157        assert!(
1158            GridSpline2dFit::from_state(&bad).is_err(),
1159            "m_axis ≠ K+3 must error"
1160        );
1161
1162        let mut bad = good.clone();
1163        bad.axis_h[0] = 0.0;
1164        assert!(
1165            GridSpline2dFit::from_state(&bad).is_err(),
1166            "non-positive cell width must error"
1167        );
1168
1169        let mut bad = good;
1170        bad.chol[0] = 0.0;
1171        assert!(
1172            GridSpline2dFit::from_state(&bad).is_err(),
1173            "zero Cholesky pivot must error"
1174        );
1175    }
1176}