//! Multiresolution residual cascade for scattered 2-3D smooths at huge n
//! (compute-first primitive #3, #1032; siblings: the 1-D scan in
//! [`crate::spline_scan`], the 2-D grid in
//! [`gam_terms::grid_spline_2d`]).
//!
//! Model. In metric-scaled coordinates `z = diag(metric)·x` the smooth is
//!   `f(z) = P(z)'γ + Σ_l Σ_j c_{l,j} · φ((z − ξ_{l,j})/δ_l)`,
//! an unpenalized linear polynomial layer `P = {1, z_1, …, z_d}` at the root
//! plus, per level `l = 0..L`, compactly supported Wendland bumps
//! `φ(r) = (1−r)₊⁴(4r+1)` (positive definite and C² on ℝ³) of support radius
//! `δ_l = OVERLAP·h_l` planted on the NEW centers of a nested net with
//! covering radius `h_l = h₀·2^{−l}`. Coefficients are a-priori independent,
//! `c_{l,j} ~ N(0, τ²·4^{−l(s−d/2)})` — the standard multilevel frame whose
//! diagonal prior norm is equivalent to the Sobolev-`s` (semi)norm on
//! quasi-uniform nested nets (Narcowich–Ward inverse estimates + Le Gia–
//! Wendland multilevel stability; `d/2 < s ≤ (d+3)/2`, the native smoothness
//! of the Wendland-(3,1) bump). The assembled claim is certified in-test
//! against a dense kernel solve on small n (#904 style), not assumed.
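//!
//! A minimal standalone sketch of the two per-level ingredients above: the
//! Wendland-(3,1) profile and the level-`l` prior variance `τ²·4^{−l(s−d/2)}`.
//! The function names (`wendland31`, `bump`, `level_prior_variance`) are
//! illustrative, not this module's API, and the support radius assumes
//! `h_l = h₀·2^{−l}` with a caller-supplied overlap factor.
//!
//! ```rust
//! /// Wendland-(3,1) profile φ(r) = (1−r)₊⁴(4r+1); zero for r ≥ 1.
//! fn wendland31(r: f64) -> f64 {
//!     if r >= 1.0 {
//!         return 0.0;
//!     }
//!     let v = 1.0 - r;
//!     v * v * v * v * (4.0 * r + 1.0)
//! }
//!
//! /// Bump value at scaled distance `dist` from a center on level `level`,
//! /// whose support radius is δ_l = overlap · h0 · 2^{−l}.
//! fn bump(dist: f64, level: i32, h0: f64, overlap: f64) -> f64 {
//!     let delta = overlap * h0 * 0.5_f64.powi(level);
//!     wendland31(dist / delta)
//! }
//!
//! /// Prior variance τ²·4^{−l(s−d/2)} of one coefficient on level `level`
//! /// (equivalently τ²/d_l with precision weight d_l = 4^{l(s−d/2)}).
//! fn level_prior_variance(tau2: f64, level: i32, s: f64, d: f64) -> f64 {
//!     tau2 * 4.0_f64.powf(-(level as f64) * (s - d / 2.0))
//! }
//! ```
//!
//! With `s − d/2 = 1`, each finer level's prior variance is a quarter of the
//! previous level's, so `d_l` grows geometrically with depth.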
//!
//! Nets. Each level's center set is a greedy hash-grid ε-net scanned in data
//! order, seeded with the previous level's net: covering radius ≤ h_l over
//! the data AND separation ≥ h_l — the same quasi-uniformity guarantees
//! farthest-point sampling gives, at O(n) per level (each point checks the
//! 3^d neighboring cells of one hash grid of cell size h_l). Nets are nested
//! (`Ξ_0 ⊂ Ξ_1 ⊂ …`); a center carries a bump only at its birth level.
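//!
//! A sketch of the greedy hash-grid ε-net described above, restricted to 2-D
//! for brevity. `greedy_net` is a hypothetical helper, not this module's net
//! builder. It keeps the coarser seed net unconditionally (that net was built
//! at radius `2h`, so it is already `h`-separated), then scans points in data
//! order and promotes a point to a new center when no center lies strictly
//! within `h`. With cell width `h`, every center within distance `h` of a
//! point sits in the 3×3 block of cells around it, so each check is O(1).
//!
//! ```rust
//! use std::collections::HashMap;
//!
//! /// Nested greedy ε-net: returns `seed` followed by the new centers.
//! fn greedy_net(points: &[[f64; 2]], seed: &[[f64; 2]], h: f64) -> Vec<[f64; 2]> {
//!     let cell = |p: &[f64; 2]| ((p[0] / h).floor() as i64, (p[1] / h).floor() as i64);
//!     let mut grid: HashMap<(i64, i64), Vec<usize>> = HashMap::new();
//!     let mut net: Vec<[f64; 2]> = Vec::new();
//!     for p in seed {
//!         grid.entry(cell(p)).or_default().push(net.len());
//!         net.push(*p);
//!     }
//!     for p in points {
//!         let (cx, cy) = cell(p);
//!         let mut covered = false;
//!         // Scan the 3×3 neighborhood for a center closer than h.
//!         'scan: for dx in -1..=1 {
//!             for dy in -1..=1 {
//!                 if let Some(bucket) = grid.get(&(cx + dx, cy + dy)) {
//!                     for &j in bucket {
//!                         let (ex, ey) = (p[0] - net[j][0], p[1] - net[j][1]);
//!                         if ex * ex + ey * ey < h * h {
//!                             covered = true;
//!                             break 'scan;
//!                         }
//!                     }
//!                 }
//!             }
//!         }
//!         if !covered {
//!             grid.entry((cx, cy)).or_default().push(net.len());
//!             net.push(*p);
//!         }
//!     }
//!     net
//! }
//! ```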
//!
//! Fit. With `W = diag(w)`, `D = diag(0 on the polynomial layer, d_l =
//! 4^{l(s−d/2)} on level-l bumps)` and `λ = σ²/τ²`, the posterior mode solves
//! `(X'WX + λD)c = X'Wy`. `X` is sparse — a row touches the O(1) bumps per
//! level whose supports cover it, O(qL) nonzeros with `q` that per-level
//! count — and is held in CSR. For moderate column counts
//! (`m ≤ DENSE_GRAM_MAX`) the normal equations are solved by dense Cholesky
//! with the EXACT log-determinant (same route as the grid sibling); beyond
//! that the solve is preconditioned CG with the two-level
//! additive-Schwarz coarse-space preconditioner `P = blockdiag(A_CC,
//! diag(A_FF))`. The multilevel Wendland frame is redundant across scales — a
//! coarse bump and the fine bumps in its support are strongly correlated — so
//! the data-fit Gram `X'WX` couples levels and a pure-diagonal preconditioner
//! leaves a conditioning that GROWS with the number of data-identified levels
//! (hence with n). The coarse space `C` (polynomial layer + the data-dominated
//! coarsest levels, see `coarse_space_cols`) is solved EXACTLY by a small dense
//! Cholesky and the penalty-dominated fine tail `F` — where `A_ll ≈ λ d_l I` is
//! already uniformly conditioned — by its Jacobi diagonal. That deflation is
//! what makes `P^{−1/2}(X'WX+λD)P^{−1/2}` uniformly conditioned, so the CG
//! iteration count is genuinely n-independent (the in-test gate asserts an
//! ADDITIVE bound across a 4× n jump, not a multiplicative one). Every CG solve
//! of `Ac = b` (`A = X'WX + λD`) reports its relative residual `‖b − Ac‖/‖b‖`:
//! a computable backward-error certificate (`c` exactly solves the system
//! whose right-hand side is perturbed by that relative amount), inherited by
//! every linear functional of the solution.
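//!
//! A minimal dense sketch of the solve-and-certify pattern above, assuming a
//! small SPD matrix held row-major and a nonzero right-hand side. For brevity
//! it preconditions with the plain Jacobi diagonal rather than the
//! coarse-space `P` this module uses, and the names (`pcg`, `matvec`, `dot`)
//! are hypothetical. The certificate is recomputed from the true residual
//! `b − Ax` of the returned iterate, not read off the CG recursion.
//!
//! ```rust
//! fn dot(u: &[f64], v: &[f64]) -> f64 {
//!     u.iter().zip(v).map(|(a, b)| a * b).sum()
//! }
//!
//! fn matvec(a: &[f64], n: usize, v: &[f64]) -> Vec<f64> {
//!     (0..n).map(|i| dot(&a[i * n..(i + 1) * n], v)).collect()
//! }
//!
//! /// Jacobi-preconditioned CG on `a x = b`.
//! /// Returns (x, iterations, ‖b − Ax‖/‖b‖).
//! fn pcg(a: &[f64], n: usize, b: &[f64], rtol: f64, max_iters: usize) -> (Vec<f64>, usize, f64) {
//!     let bnorm = dot(b, b).sqrt();
//!     let mut x = vec![0.0; n];
//!     let mut r = b.to_vec();
//!     let mut z: Vec<f64> = (0..n).map(|i| r[i] / a[i * n + i]).collect();
//!     let mut p = z.clone();
//!     let mut rz = dot(&r, &z);
//!     let mut iters = 0;
//!     while iters < max_iters && dot(&r, &r).sqrt() > rtol * bnorm {
//!         let ap = matvec(a, n, &p);
//!         let alpha = rz / dot(&p, &ap);
//!         for i in 0..n {
//!             x[i] += alpha * p[i];
//!             r[i] -= alpha * ap[i];
//!             z[i] = r[i] / a[i * n + i]; // Jacobi: z = diag(A)⁻¹ r
//!         }
//!         let rz_new = dot(&r, &z);
//!         for i in 0..n {
//!             p[i] = z[i] + (rz_new / rz) * p[i];
//!         }
//!         rz = rz_new;
//!         iters += 1;
//!     }
//!     // Backward-error certificate from the true residual of the final x.
//!     let ax = matvec(a, n, &x);
//!     let res: Vec<f64> = b.iter().zip(&ax).map(|(bi, axi)| bi - axi).collect();
//!     (x, iters, dot(&res, &res).sqrt() / bnorm)
//! }
//! ```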
//!
//! REML. λ maximizes the profiled-σ² restricted criterion
//!   `ℓ_R(λ) = −½[ log|X'WX+λD| − log|λD|₊ + (n−d−1)·log σ̂²(λ) ] + const`,
//! `log|λD|₊ = r·logλ + Σ_j log d_j` over the `r` penalized columns and
//! `σ̂² = (y'Wy − c'X'Wy)/(n−d−1)` — the same shape as the siblings, with the
//! penalty-logdet constant kept so criteria are comparable across cascade
//! depths. On the dense route `log|X'WX+λD|` is exact; on the iterative
//! route it is `log|P| + tr log(P^{−1/2}AP^{−1/2})`, where
//! `log|P| = log|A_CC| + Σ_F log A_jj` comes exactly from the preconditioner
//! and the trace is estimated by stochastic Lanczos quadrature on FIXED
//! Rademacher probes (deterministic seed, shared across every λ trial —
//! common random numbers make the criterion a smooth deterministic function
//! of λ, so the coarse-grid + golden-section search is exactly as
//! deterministic as the siblings'). The preconditioner split is the control
//! variate: the dominant λ-dependence rides the exactly computed `log|P|`
//! term and SLQ only sees the well-conditioned remainder.
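//!
//! A dense-route sketch of the criterion above and of the coarse-grid +
//! golden-section search, for a small problem whose Gram `X'WX`, right-hand
//! side `X'Wy` and penalty diagonal are already assembled. `chol_logdet`,
//! `reml` and `maximize` are hypothetical names; bracketing the best grid
//! point by its two neighbors is an assumption about how the refinement is
//! seeded, and the SLQ route is not shown.
//!
//! ```rust
//! /// Lower Cholesky factor of a row-major SPD matrix and log|A| = Σ ln L_ii².
//! fn chol_logdet(a: &[f64], n: usize) -> Option<(Vec<f64>, f64)> {
//!     let mut l = vec![0.0; n * n];
//!     let mut logdet = 0.0;
//!     for i in 0..n {
//!         for j in 0..=i {
//!             let s = a[i * n + j] - (0..j).map(|k| l[i * n + k] * l[j * n + k]).sum::<f64>();
//!             if i == j {
//!                 if s <= 0.0 {
//!                     return None;
//!                 }
//!                 l[i * n + i] = s.sqrt();
//!                 logdet += s.ln();
//!             } else {
//!                 l[i * n + j] = s / l[j * n + j];
//!             }
//!         }
//!     }
//!     Some((l, logdet))
//! }
//!
//! /// ℓ_R(log λ) up to a constant; gram = X'WX, rhs = X'Wy, pen = diag(D).
//! fn reml(gram: &[f64], rhs: &[f64], pen: &[f64], ytwy: f64, n: usize, nullity: usize, log_lambda: f64) -> f64 {
//!     let m = rhs.len();
//!     let mut a = gram.to_vec();
//!     for j in 0..m {
//!         a[j * m + j] += log_lambda.exp() * pen[j];
//!     }
//!     let (l, logdet_a) = match chol_logdet(&a, m) {
//!         Some(f) => f,
//!         None => return f64::NEG_INFINITY,
//!     };
//!     // ĉ = A⁻¹X'Wy: forward solve L u = rhs, then back solve Lᵀ c = u, in place.
//!     let mut c = vec![0.0; m];
//!     for i in 0..m {
//!         c[i] = (rhs[i] - (0..i).map(|t| l[i * m + t] * c[t]).sum::<f64>()) / l[i * m + i];
//!     }
//!     for i in (0..m).rev() {
//!         c[i] = (c[i] - (i + 1..m).map(|t| l[t * m + i] * c[t]).sum::<f64>()) / l[i * m + i];
//!     }
//!     let dof = (n - nullity) as f64;
//!     let sigma2 = (ytwy - c.iter().zip(rhs).map(|(ci, bi)| ci * bi).sum::<f64>()) / dof;
//!     // log|λD|₊ over the penalized columns only.
//!     let penalized: Vec<f64> = pen.iter().copied().filter(|&d| d > 0.0).collect();
//!     let logdet_pen = penalized.len() as f64 * log_lambda + penalized.iter().map(|d| d.ln()).sum::<f64>();
//!     -0.5 * (logdet_a - logdet_pen + dof * sigma2.ln())
//! }
//!
//! /// Coarse grid over [lo, hi] (grid ≥ 2), then golden-section refinement
//! /// of the bracket formed by the best grid point's neighbors.
//! fn maximize(f: impl Fn(f64) -> f64, lo: f64, hi: f64, grid: usize, tol: f64) -> f64 {
//!     let step = (hi - lo) / (grid - 1) as f64;
//!     let (mut best, mut best_val) = (0, f64::NEG_INFINITY);
//!     for k in 0..grid {
//!         let v = f(lo + k as f64 * step);
//!         if v > best_val {
//!             best_val = v;
//!             best = k;
//!         }
//!     }
//!     let mut a = lo + best.saturating_sub(1) as f64 * step;
//!     let mut b = lo + (best + 1).min(grid - 1) as f64 * step;
//!     let g = (5.0_f64.sqrt() - 1.0) / 2.0;
//!     let (mut c, mut d) = (b - g * (b - a), a + g * (b - a));
//!     let (mut fc, mut fd) = (f(c), f(d));
//!     while b - a > tol {
//!         if fc > fd {
//!             b = d;
//!             d = c;
//!             fd = fc;
//!             c = b - g * (b - a);
//!             fc = f(c);
//!         } else {
//!             a = c;
//!             c = d;
//!             fc = fd;
//!             d = a + g * (b - a);
//!             fd = f(d);
//!         }
//!     }
//!     0.5 * (a + b)
//! }
//! ```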
//!
//! Refinement certificate. After fitting L levels, the candidate level L+1
//! is constructed (O(n)) and the EXACT objective decrease available from
//! adding it is bounded: for the penalized objective `F(c) = ‖√W(y−Xc)‖² +
//! λc'Dc`, appending columns `X₂` with penalty `λd_{L+1}I` decreases the
//! minimum by `g'S⁻¹g`, where `g = X₂'W r̂`, `r̂ = y − X₁ĉ₁` is the current
//! residual, and `S = X₂'WX₂ + λd_{L+1}I − X₂'WX₁A₁₁⁻¹X₁'WX₂` is the Schur
//! complement of `A₁₁ = X₁'WX₁ + λD₁`. Since `A₁₁ ⪰ X₁'WX₁`, the subtracted
//! term is `⪯ X₂'W^{1/2}·proj·W^{1/2}X₂ ⪯ X₂'WX₂` (`proj` the orthogonal
//! projector onto the range of `W^{1/2}X₁`), so `S ⪰ λd_{L+1}I` and the
//! decrease is at most `‖X₂'W r̂‖²/(λ·d_{L+1})`, a computable
//! discretization certificate. The cascade refines (adds the level, refits,
//! re-selects λ) until that bound drops below `REFINE_TOL` of the penalized
//! residual, the net stops producing new centers (every point is a center),
//! or the level/center caps are reached: certified-or-fallback, the same
//! discipline as the radial-profile GL ladder.
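//!
//! A small dense check of the certificate, assuming a single candidate
//! column and the hypothetical helpers `solve`, `fit` and `gain_bound` (none
//! of them is this module's API). The idea is to refit with the column appended
//! and confirm that the exact decrease `F₁ − F₂` never exceeds
//! `‖X₂'W r̂‖²/(λ·d_new)`.
//!
//! ```rust
//! /// Solve a small SPD system by Gaussian elimination (no pivoting needed).
//! fn solve(mut a: Vec<f64>, mut b: Vec<f64>) -> Vec<f64> {
//!     let n = b.len();
//!     for k in 0..n {
//!         for i in k + 1..n {
//!             let f = a[i * n + k] / a[k * n + k];
//!             for j in k..n {
//!                 let akj = a[k * n + j];
//!                 a[i * n + j] -= f * akj;
//!             }
//!             let bk = b[k];
//!             b[i] -= f * bk;
//!         }
//!     }
//!     let mut x = vec![0.0; n];
//!     for i in (0..n).rev() {
//!         let s: f64 = (i + 1..n).map(|j| a[i * n + j] * x[j]).sum();
//!         x[i] = (b[i] - s) / a[i * n + i];
//!     }
//!     x
//! }
//!
//! /// min over c of Σ w_i (y_i − rowᵢ·c)² + λ Σ_j pen_j c_j²; returns the
//! /// minimum and the residual vector y − Xĉ.
//! fn fit(rows: &[Vec<f64>], w: &[f64], y: &[f64], pen: &[f64], lambda: f64) -> (f64, Vec<f64>) {
//!     let m = pen.len();
//!     let mut a = vec![0.0; m * m];
//!     let mut rhs = vec![0.0; m];
//!     for (i, row) in rows.iter().enumerate() {
//!         for j in 0..m {
//!             rhs[j] += w[i] * row[j] * y[i];
//!             for k in 0..m {
//!                 a[j * m + k] += w[i] * row[j] * row[k];
//!             }
//!         }
//!     }
//!     for j in 0..m {
//!         a[j * m + j] += lambda * pen[j];
//!     }
//!     let c = solve(a, rhs);
//!     let resid: Vec<f64> = rows
//!         .iter()
//!         .zip(y)
//!         .map(|(row, yi)| yi - row.iter().zip(&c).map(|(x, ci)| x * ci).sum::<f64>())
//!         .collect();
//!     let data: f64 = resid.iter().zip(w).map(|(r, wi)| wi * r * r).sum();
//!     let prior: f64 = c.iter().zip(pen).map(|(ci, d)| d * ci * ci).sum();
//!     (data + lambda * prior, resid)
//! }
//!
//! /// ‖X₂'W r̂‖² / (λ·d_new) for candidate columns `new_cols` (each of length n).
//! fn gain_bound(new_cols: &[Vec<f64>], w: &[f64], resid: &[f64], lambda: f64, d_new: f64) -> f64 {
//!     let g2: f64 = new_cols
//!         .iter()
//!         .map(|col| {
//!             let g: f64 = col.iter().zip(w).zip(resid).map(|((x, wi), r)| x * wi * r).sum();
//!             g * g
//!         })
//!         .sum();
//!     g2 / (lambda * d_new)
//! }
//! ```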
//!
//! Posterior. Coefficient covariance is `σ²(X'WX+λD)^{−1}`; pointwise
//! prediction variance routes the basis row through one (certified) solve.
//! Exact posterior samples come from perturb-and-solve: `c_s = A^{−1}(X'Wy +
//! σ(X'W^{1/2}z₁ + √λ D^{1/2}z₂))` with iid standard-normal `z₁, z₂` has
//! mean `ĉ` and covariance exactly `σ²A^{−1}` (deterministically seeded; one
//! certified solve per sample).
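//!
//! The perturb-and-solve identity can be checked in the one-coefficient
//! case, where `A = Σ w_i x_i² + λd` is a scalar and each solve is a
//! division. This is a hypothetical sketch, not the module's sampler: it uses
//! the standard SplitMix64 mixer and Box–Muller (mirroring this module's
//! `SplitMix64` helper) so the stream is deterministic. Across samples the
//! mean should approach `ĉ = Σ w_i x_i y_i / A` and the variance `σ²/A`.
//!
//! ```rust
//! fn splitmix64(state: &mut u64) -> u64 {
//!     *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
//!     let mut z = *state;
//!     z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
//!     z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
//!     z ^ (z >> 31)
//! }
//!
//! /// Standard normal via Box–Muller on two uniforms in (0, 1).
//! fn normal(state: &mut u64) -> f64 {
//!     let u1 = ((splitmix64(state) >> 11) as f64 + 0.5) / 9_007_199_254_740_992.0;
//!     let u2 = ((splitmix64(state) >> 11) as f64 + 0.5) / 9_007_199_254_740_992.0;
//!     (-2.0 * u1.ln()).sqrt() * (std::f64::consts::TAU * u2).cos()
//! }
//!
//! /// Returns (ĉ, samples), each c_s = A⁻¹(X'Wy + σ(X'W^{1/2}z₁ + √(λd) z₂)).
//! fn perturb_and_solve(x: &[f64], w: &[f64], y: &[f64], lambda: f64, d: f64, sigma: f64, n_samples: usize, seed: u64) -> (f64, Vec<f64>) {
//!     let a: f64 = x.iter().zip(w).map(|(xi, wi)| wi * xi * xi).sum::<f64>() + lambda * d;
//!     let b: f64 = x.iter().zip(w).zip(y).map(|((xi, wi), yi)| wi * xi * yi).sum();
//!     let mut state = seed;
//!     let mut samples = Vec::with_capacity(n_samples);
//!     for _ in 0..n_samples {
//!         // Data perturbation X'W^{1/2}z₁, then prior perturbation √(λd)·z₂.
//!         let mut data_noise = 0.0;
//!         for (xi, wi) in x.iter().zip(w) {
//!             data_noise += xi * wi.sqrt() * normal(&mut state);
//!         }
//!         let prior_noise = (lambda * d).sqrt() * normal(&mut state);
//!         samples.push((b + sigma * (data_noise + prior_noise)) / a);
//!     }
//!     (b / a, samples)
//! }
//! ```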
//!
//! Payoff. Build O(n·(L + 3^d)), fit O(nnz · iters) per λ trial with
//! n-independent iters — O(n log n) end to end, against the dense n×k kernel
//! Gram + O(k³) per trial that duchon/matern pay today. Gap behavior is
//! mechanical: levels wider than a gap keep support across it (polynomial +
//! coarse bumps bridge), finer levels have no data and revert to their prior
//! variance, so the posterior mean bridges instead of sagging while the
//! variance grows into the gap.

use std::collections::HashMap;
use std::sync::Arc;

use gam_terms::grid_spline_2d::{chol_solve, cholesky_logdet};

/// Bump support radius as a multiple of the level's covering radius:
/// `δ_l = OVERLAP·h_l`. Separation ≥ h_l caps the bumps covering a point at
/// a packing constant per level (O(q) row nonzeros per level).
const OVERLAP: f64 = 2.0;
/// Root covering radius as a fraction of the largest scaled axis range.
const H0_FRACTION: f64 = 0.5;
/// Levels in the initial cascade before refinement certificates run.
const INITIAL_LEVELS: usize = 3;
/// Hard cap on cascade depth (h shrinks 2^16-fold below the root).
const MAX_LEVELS: usize = 16;
/// Hard cap on total centers across all levels.
const MAX_CENTERS: usize = 200_000;
/// Refinement stops when the exact next-level gain bound falls below this
/// fraction of the penalized residual.
const REFINE_TOL: f64 = 1e-3;

/// Column count up to which the normal equations go through dense Cholesky
/// (exact logdet, no iteration); above it, PCG + SLQ. 1536² doubles ≈ 18 MB.
const DENSE_GRAM_MAX: usize = 1536;

/// Deterministic coarse-grid width and bounds for the log-λ search (same
/// scheme as the siblings), then golden-section refinement.
const LOG_LAMBDA_GRID: usize = 25;
const LOG_LAMBDA_LO: f64 = -18.0;
const LOG_LAMBDA_HI: f64 = 18.0;
const LOG_LAMBDA_TOL: f64 = 1e-6;

/// PCG convergence: relative residual ‖b − Ac‖/‖b‖ (the backward-error
/// certificate) demanded of every solve, and the iteration cap past which
/// the solve is an error rather than a silent approximation. The certification
/// suite gates the iterative route at 1e-9; asking for more burns matvecs
/// without strengthening any downstream certificate.
const CG_RTOL: f64 = 1e-9;
const CG_MAX_ITERS: usize = 4000;

/// Coarse-space additive-Schwarz preconditioner controls (issue #1032: the
/// "BPX/level-diagonal preconditioned CG, n-independent iters" spec).
///
/// The multilevel Wendland frame is redundant across scales — a coarse bump and
/// the fine bumps inside its support are strongly correlated — so the data-fit
/// Gram `X'WX` couples levels and a pure-diagonal (Jacobi) preconditioner leaves
/// a conditioning that grows with the number of *data-identified* levels, hence
/// with `n` (more rows ⇒ finer levels carry data ⇒ another collinear coarse
/// scale the diagonal can't decouple). The cure is the textbook two-level
/// additive Schwarz coarse space: solve the coarse block — the polynomial layer
/// plus every level the penalty has NOT yet made diagonally dominant — EXACTLY,
/// and precondition the remaining penalty-dominated fine levels (where
/// `A_ll ≈ λ d_l I` is already uniformly conditioned) by their Jacobi diagonal.
///
/// A level is "data-dominated" while `λ d_l < COARSE_DOMINANCE · median diag
/// (X'WX) over the level`. Because columns are laid out poly, level-0, level-1,
/// … and `d_l` increases while the per-level data weight decreases, the
/// data-dominated levels are exactly the coarsest prefix `[0, ncoarse)`, so the
/// coarse space is a contiguous column prefix and the cut is a single scan. The
/// crossover level grows only as `½ log₄(n/λ)` — `ncoarse = O(√(n/λ))` columns —
/// so the exact coarse factorization stays small against the sparse matvecs at
/// every n the primitive serves. [`COARSE_SPACE_MAX`] caps it as a safety valve
/// (past the cap the finer data-dominated levels fall back to Jacobi and the
/// iteration count rises, but the CG residual certificate still guarantees the
/// solve); [`MIN_COARSE_LEVELS`] always deflates the two coarsest scales, which
/// are near-collinear with the polynomial layer at every λ.
const COARSE_DOMINANCE: f64 = 4.0;
/// Safety ceiling on the exact-coarse column count. It must NOT bind at the n
/// the primitive serves: the n-independent iteration count rests on the coarse
/// block containing the WHOLE data-dominated prefix (`O(√(n/λ))` columns), so a
/// cap that truncates that prefix is exactly what makes the iteration count
/// climb with n (a finer data-dominated level demoted to Jacobi cannot be
/// decoupled from the coarse scales it is collinear with). At the n-scales the
/// iterative route engages (tens of thousands of rows → a ≈1.4k-column
/// prefix) this is non-binding headroom; it only triggers in the genuinely
/// degenerate case the quasi-uniformity guard is meant to catch first. The
/// realized coarse factorization runs at the actual prefix length, not the cap,
/// so the ceiling costs nothing until it fires.
const COARSE_SPACE_MAX: usize = 4096;
const MIN_COARSE_LEVELS: usize = 2;

/// Quasi-uniformity guard (issue #1032, caveat 2). The BPX n-independent CG
/// iteration bound rests on the nested ε-nets being quasi-uniform *in the
/// metric-scaled coordinates `z = diag(metric)·x` the bumps live in*. The
/// greedy net guarantees covering ≤ h and separation ≥ h in `z` by
/// construction, so the only way the BPX norm-equivalence constant blows up is
/// when the metric is so anisotropic that the metric-scaled point cloud is
/// effectively degenerate along a direction — the data collapses onto a lower
/// dimension in `z`, the root covering radius `h₀ = ½·max_a range_a` swamps the
/// collapsed axis, the level-`l` bumps overlap pathologically, and the
/// preconditioner constant (hence the iteration count) grows without an
/// n-independent bound. The realized symptom is `solve_iters` climbing toward
/// [`CG_MAX_ITERS`]; this guard detects the *cause* up front from the
/// metric-scaled per-axis spread so the auto-route can fall back to the dense
/// kernel BEFORE paying an unbounded iterative solve, rather than discovering
/// the blow-up only after `CG_MAX_ITERS` work.
///
/// Condition measure: the ratio of the largest to smallest metric-scaled
/// per-axis standard deviation (a scale-free aspect ratio of the scaled
/// cloud). Past this threshold the net is no longer quasi-uniform in every
/// direction and the BPX bound is not trustworthy. Derived, not a knob: a
/// `10³` aspect ratio means the collapsed axis carries <0.1% of the dominant
/// axis's variation, at which point its bumps span the whole cloud and the
/// multilevel hierarchy degenerates to a single ill-conditioned level.
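///
/// A sketch of the condition measure, assuming points stored as `[f64; 3]`
/// with `dim` active axes; `metric_aspect_ratio` is a hypothetical helper, not
/// this module's guard. Population standard deviations are used because the
/// `n` versus `n − 1` normalization cancels in the ratio.
///
/// ```rust
/// /// max/min per-axis standard deviation of z = diag(metric)·x; +∞ when an
/// /// axis has zero spread.
/// fn metric_aspect_ratio(points: &[[f64; 3]], dim: usize, metric: &[f64; 3]) -> f64 {
///     let n = points.len() as f64;
///     let mut sds = Vec::with_capacity(dim);
///     for a in 0..dim {
///         let mean = points.iter().map(|p| metric[a] * p[a]).sum::<f64>() / n;
///         let var = points.iter().map(|p| (metric[a] * p[a] - mean).powi(2)).sum::<f64>() / n;
///         sds.push(var.sqrt());
///     }
///     let max = sds.iter().copied().fold(0.0_f64, f64::max);
///     let min = sds.iter().copied().fold(f64::INFINITY, f64::min);
///     if min > 0.0 { max / min } else { f64::INFINITY }
/// }
/// ```
///
/// In this sketch's terms, a caller would compare the result against this
/// threshold and fall back to the dense kernel when it is exceeded.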
const QUASI_UNIFORMITY_MAX_ASPECT: f64 = 1.0e3;

/// SLQ controls: fixed Rademacher probes (shared across λ trials) and the
/// Lanczos depth per probe (full reorthogonalization; early exit on
/// breakdown).
const SLQ_PROBES: usize = 24;
const SLQ_LANCZOS_STEPS: usize = 48;

/// Deterministic seed for the SLQ probes and posterior samples.
const RNG_SEED: u64 = 0x1032_CA5C_ADE0_5EED;

/// Floor for eigenvalues/pivots before the system is declared singular.
const EIG_FLOOR: f64 = 1e-300;

// ───────────────────────────── deterministic RNG ────────────────────────────

/// SplitMix64: tiny, deterministic, full-period stream generator.
struct SplitMix64(u64);

impl SplitMix64 {
    fn new(seed: u64) -> Self {
        SplitMix64(seed)
    }

    fn next_u64(&mut self) -> u64 {
        gam_linalg::utils::splitmix64(&mut self.0)
    }

    /// Uniform in (0, 1): 53-bit mantissa, shifted off zero.
    fn next_unit(&mut self) -> f64 {
        ((self.next_u64() >> 11) as f64 + 0.5) / 9_007_199_254_740_992.0
    }

    /// Standard normal via Box–Muller.
    fn next_normal(&mut self) -> f64 {
        let u1 = self.next_unit();
        let u2 = self.next_unit();
        (-2.0 * u1.ln()).sqrt() * (std::f64::consts::TAU * u2).cos()
    }

    /// Rademacher ±1.
    fn next_sign(&mut self) -> f64 {
        if self.next_u64() & 1 == 0 { 1.0 } else { -1.0 }
    }
}

// ─────────────────────────────── hash grids ─────────────────────────────────

/// Integer cell of a point at a given cell width (coordinates are already
/// metric-scaled and shifted to be ≥ 0, so indices are small and exact).
#[inline]
fn cell_of(z: &[f64; 3], dim: usize, width: f64) -> (i32, i32, i32) {
    let mut c = [0_i32; 3];
    for a in 0..dim {
        c[a] = (z[a] / width).floor() as i32;
    }
    (c[0], c[1], c[2])
}

/// Hash grid over a point set: cell → indices. Lookup scans the 3^d
/// neighborhood, which covers every point within one cell width.
struct HashGrid {
    width: f64,
    dim: usize,
    cells: HashMap<(i32, i32, i32), Vec<u32>>,
}

impl HashGrid {
    fn new(width: f64, dim: usize) -> Self {
        HashGrid {
            width,
            dim,
            cells: HashMap::new(),
        }
    }

    fn insert(&mut self, idx: u32, z: &[f64; 3]) {
        let key = cell_of(z, self.dim, self.width);
        self.cells.entry(key).or_default().push(idx);
    }

    /// Visit every stored index in the 3^d cells around `z` (deterministic
    /// order: lexicographic cells, insertion order within a cell).
    fn for_neighbors(&self, z: &[f64; 3], mut visit: impl FnMut(u32)) {
        let (c0, c1, c2) = cell_of(z, self.dim, self.width);
        let d2 = if self.dim > 2 { 1 } else { 0 };
        let d1 = if self.dim > 1 { 1 } else { 0 };
        for i0 in -1..=1_i32 {
            for i1 in -d1..=d1 {
                for i2 in -d2..=d2 {
                    if let Some(bucket) = self.cells.get(&(c0 + i0, c1 + i1, c2 + i2)) {
                        for &idx in bucket {
                            visit(idx);
                        }
                    }
                }
            }
        }
    }
}

#[inline]
fn dist2(a: &[f64; 3], b: &[f64; 3], dim: usize) -> f64 {
    let mut s = 0.0;
    for k in 0..dim {
        let d = a[k] - b[k];
        s += d * d;
    }
    s
}

/// Wendland-(3,1) bump `(1−r)₊⁴(4r+1)`: positive definite on ℝ^d, d ≤ 3,
/// C², native space H^{(d+3)/2}.
#[inline]
fn wendland(r: f64) -> f64 {
    if r >= 1.0 {
        return 0.0;
    }
    let v = 1.0 - r;
    let v2 = v * v;
    v2 * v2 * (4.0 * r + 1.0)
}

// ───────────────────────────── design assembly ──────────────────────────────

/// One resolution level: its NEW centers (scaled coordinates), covering
/// radius, support radius, prior precision weight, and a lookup grid of cell
/// width δ_l over those centers.
struct Level {
    h: f64,
    delta: f64,
    /// Prior precision weight `d_l = 4^{l(s−d/2)}` (prior variance τ²/d_l).
    weight: f64,
    centers: Vec<[f64; 3]>,
    /// First flat column index of this level's coefficients.
    col_offset: usize,
    grid: HashGrid,
}

/// Immutable fitted-design core shared between the design handle and fits.
struct Core {
    dim: usize,
    metric: [f64; 3],
    /// Lower corner / range of the scaled bounding box (polynomial layer
    /// coordinates are `2(z − lo)/range − 1` for conditioning).
    z_lo: [f64; 3],
    z_range: [f64; 3],
    sobolev_s: f64,
    levels: Vec<Level>,
    /// Full nested net Ξ_L (scaled coords), retained so the candidate level
    /// L+1 can extend it without re-deriving coarser levels.
    net: Vec<[f64; 3]>,
    /// Total columns: `dim + 1` polynomial + all level centers.
    m: usize,
    /// CSR design rows (column-sorted within a row).
    row_ptr: Vec<usize>,
    col_idx: Vec<u32>,
    vals: Vec<f64>,
    /// Inputs retained for matvecs, residuals, and refinement.
    w: Vec<f64>,
    y: Vec<f64>,
    /// Scaled data coordinates (shifted to the box corner).
    z: Vec<[f64; 3]>,
    /// `X'Wy`, `y'Wy`, `diag(X'WX)`.
    rhs: Vec<f64>,
    ytwy: f64,
    gram_diag: Vec<f64>,
    /// Per-column prior precision weight (0 on the polynomial layer).
    pen_diag: Vec<f64>,
    /// `Σ_j log d_j` over penalized columns (the λ-free part of log|λD|₊,
    /// kept so REML criteria compare across cascade depths).
    pen_logdet_const: f64,
    /// Dense upper-triangular `X'WX` when `m ≤ DENSE_GRAM_MAX` (row-major
    /// m×m, lower mirror filled at solve time); None on the iterative route.
    dense_gram: Option<Vec<f64>>,
    /// Predict-only factored precision: the lower Cholesky factor `L` of
    /// `A = X'WX + λD` at the FIT's λ, populated only on a core rebuilt from a
    /// persisted [`ResidualCascadeState`] (where the training CSR is dropped).
    /// When present, `solve_coeff` replays the posterior-variance solve through
    /// this factor instead of the absent training design; `None` on a
    /// training-built core, which solves through `dense_gram`/PCG as usual.
    predict_chol: Option<Vec<f64>>,
}

/// Solver route a fit took for its log-determinant.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LogdetMethod {
    /// Dense Cholesky: exact.
    DenseExact,
    /// Preconditioner control variate (exact `log|P|`) + stochastic Lanczos
    /// quadrature on fixed deterministic probes.
    Slq,
}

/// Computable certificates attached to a fit.
#[derive(Clone, Copy, Debug)]
pub struct CascadeCertificate {
    /// Backward error of the coefficient solve: ‖b − Aĉ‖/‖b‖ (0 on the dense
    /// route).
    pub solve_rel_residual: f64,
    /// CG iterations of the coefficient solve (0 on the dense route); the
    /// n-independence gate watches this.
    pub solve_iters: usize,
    /// Route the log-determinant took.
    pub logdet_method: LogdetMethod,
}

/// Discretization certificate of the refinement loop: the exact upper bound
/// on the penalized-objective decrease available from one more level.
#[derive(Clone, Copy, Debug)]
pub struct RefinementCertificate {
    /// `‖X_{L+1}'W r̂‖² / (λ·d_{L+1})` at the accepted fit.
    pub next_level_gain_bound: f64,
    /// The absolute tolerance it was compared against (`REFINE_TOL·rss_pen`).
    pub tolerance: f64,
    /// True when refinement stopped because the net produced no new centers
    /// or a cap was reached rather than because the bound passed.
    pub exhausted: bool,
}

/// Multiresolution residual-cascade design: nested nets, sparse design,
/// diagonal multilevel prior — everything needed to evaluate the REML
/// criterion and solve at any λ.
pub struct ResidualCascadeDesign {
    core: Arc<Core>,
}

/// Fitted cascade with factored-by-solve posterior access.
pub struct ResidualCascadeFit {
    core: Arc<Core>,
    /// Dense-route prediction factor at the fit's λ. When present, pointwise
    /// variance uses this one Cholesky factor instead of refactoring the same
    /// precision matrix for every prediction point.
    predict_chol: Option<Vec<f64>>,
    /// Coefficients: `dim+1` polynomial entries, then level blocks.
    pub coeff: Vec<f64>,
    /// Selected (or supplied) log smoothing parameter `log λ = log σ²/τ²`.
    pub log_lambda: f64,
    /// Profiled (or supplied) observation variance σ².
    pub sigma2: f64,
    /// Restricted log-likelihood at the fit, up to λ- and data-independent
    /// additive constants (exact REML differences across λ on the dense
    /// route; SLQ-estimated on the iterative route).
    pub restricted_loglik: f64,
    /// Penalized residual quadratic `y'Wy − c'X'Wy`.
    pub rss_pen: f64,
    /// Solve/logdet certificates.
    pub certificate: CascadeCertificate,
    /// Present when the fit came from the refinement loop.
    pub refinement: Option<RefinementCertificate>,
}

/// One resolution level's geometry in a persisted snapshot: the data needed to
/// rebuild a [`Level`] (its lookup grid, bumps, and column block) without the
/// training rows. Centers are flattened center-major (`dim` consecutive floats
/// per center).
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct LevelState {
    pub h: f64,
    pub delta: f64,
    pub weight: f64,
    pub col_offset: u64,
    /// `dim·n_centers` scaled-coordinate floats, center-major.
    pub centers: Vec<f64>,
}

/// Serializable snapshot of a [`ResidualCascadeFit`] (#1032 persistence
/// prerequisite). Holds everything `predict` needs and NOTHING about the
/// training rows:
/// - MEAN: the nested geometry (`dim`/`metric`/box/`sobolev_s` + per-level
///   centers/δ/weights/col-offsets) and the root polynomial layer are all that
///   `basis_row_scaled`·`coeff` reads;
/// - VARIANCE: the factored precision `predict_chol` — the lower Cholesky factor
///   `L` of `A = X'WX + λD` at the fit's λ — which the posterior-variance solve
///   `x'A⁻¹x` replays against (the training design that originally assembled `A`
///   is dropped).
///
/// `from_state` rebuilds a predict-capable fit whose `Core` carries empty
/// training CSR and `predict_chol = Some(L)`; `solve_coeff` then routes the
/// variance solve through `L`. The reconstructed fit cannot be re-fit or
/// resampled (it has no rows), only predicted from — exactly the persistence
/// contract.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ResidualCascadeState {
    pub dim: u64,
    /// Per-axis metric scaling (length 3; trailing entries are 1 for `dim < 3`).
    pub metric: [f64; 3],
    pub z_lo: [f64; 3],
    pub z_range: [f64; 3],
    pub sobolev_s: f64,
    pub levels: Vec<LevelState>,
    /// Total column count `dim + 1 + Σ centers`.
    pub m: u64,
    /// `Σ_j log d_j` over penalized columns (kept so restored REML scalars stay
    /// comparable across cascade depths).
    pub pen_logdet_const: f64,
    /// Posterior-mode coefficients (length `m`).
    pub coeff: Vec<f64>,
    pub log_lambda: f64,
    pub sigma2: f64,
    pub restricted_loglik: f64,
    pub rss_pen: f64,
    /// Lower Cholesky factor `L` of `A = X'WX + λD` at the fit's λ, `m × m`
    /// row-major — the factored precision the variance solve replays through.
    pub predict_chol: Vec<f64>,
}

/// Forward substitution `L y = b` (lower factor, row-major) into `out`.
fn forward_sub_into(l: &[f64], p: usize, b: &[f64], out: &mut [f64]) {
    for i in 0..p {
        let mut s = b[i];
        for t in 0..i {
            s -= l[i * p + t] * out[t];
        }
        out[i] = s / l[i * p + i];
    }
}

/// Back substitution `Lᵀ z = y` (lower factor, row-major) into `out`.
fn back_sub_into(l: &[f64], p: usize, y: &[f64], out: &mut [f64]) {
    for i in (0..p).rev() {
        let mut s = y[i];
        for t in i + 1..p {
            s -= l[t * p + i] * out[t];
        }
        out[i] = s / l[i * p + i];
    }
}

/// Coarse-space additive-Schwarz preconditioner for the iterative route
/// (issue #1032). `A = X'WX + λD` is preconditioned by the symmetric positive
/// definite block-diagonal `P = blockdiag(A_CC, diag(A_FF))`, where the coarse
/// index set `C = [0, ncoarse)` is the polynomial layer plus the data-dominated
/// (coarsest) levels and `F` the penalty-dominated fine tail — see the
/// [`COARSE_DOMINANCE`]/[`COARSE_SPACE_MAX`] docs for why this delivers
/// n-independent CG iteration counts where the pure-Jacobi diagonal does not.
///
/// `solve` applies `P⁻¹` (exact coarse Cholesky solve ⊕ fine Jacobi). For the
/// SLQ log-determinant the square-root factor `R = blockdiag(L_CC, diag√A_FF)`,
/// with `P = R Rᵀ`, is exposed through `apply_r_inv`/`apply_r_inv_t` (so SLQ
/// works on the symmetrically preconditioned `R⁻¹AR⁻ᵀ`), and
/// `log|P| = log|A_CC| + Σ_F log A_jj`.
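///
/// A self-contained sketch of the same split on a small dense matrix, using a
/// hypothetical `BlockPrecond` type (not this struct). The leading `nc × nc`
/// block is Cholesky-factored, the tail keeps only its diagonal, and `log|P|`
/// is the sum of the two exact pieces.
///
/// ```rust
/// struct BlockPrecond {
///     nc: usize,
///     /// Lower Cholesky factor of A_CC, row-major nc × nc.
///     l_cc: Vec<f64>,
///     /// A_jj on the fine columns [nc, m).
///     fine_diag: Vec<f64>,
/// }
///
/// impl BlockPrecond {
///     fn new(a: &[f64], m: usize, nc: usize) -> Self {
///         let mut l = vec![0.0; nc * nc];
///         for i in 0..nc {
///             for j in 0..=i {
///                 let s = a[i * m + j] - (0..j).map(|k| l[i * nc + k] * l[j * nc + k]).sum::<f64>();
///                 l[i * nc + j] = if i == j { s.sqrt() } else { s / l[j * nc + j] };
///             }
///         }
///         let fine_diag = (nc..m).map(|j| a[j * m + j]).collect();
///         BlockPrecond { nc, l_cc: l, fine_diag }
///     }
///
///     /// log|P| = log|A_CC| + Σ_F log A_jj, with log|A_CC| = 2 Σ log L_ii.
///     fn logdet(&self) -> f64 {
///         let coarse: f64 = (0..self.nc).map(|i| 2.0 * self.l_cc[i * self.nc + i].ln()).sum();
///         coarse + self.fine_diag.iter().map(|d| d.ln()).sum::<f64>()
///     }
///
///     /// P⁻¹ v: L_CC L_CCᵀ solve on the coarse block, Jacobi on the tail.
///     fn solve(&self, v: &[f64]) -> Vec<f64> {
///         let (nc, l) = (self.nc, &self.l_cc);
///         let mut out = v.to_vec();
///         for i in 0..nc {
///             out[i] = (out[i] - (0..i).map(|t| l[i * nc + t] * out[t]).sum::<f64>()) / l[i * nc + i];
///         }
///         for i in (0..nc).rev() {
///             out[i] = (out[i] - (i + 1..nc).map(|t| l[t * nc + i] * out[t]).sum::<f64>()) / l[i * nc + i];
///         }
///         for (k, d) in self.fine_diag.iter().enumerate() {
///             out[nc + k] /= d;
///         }
///         out
///     }
/// }
/// ```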
struct Preconditioner {
    /// First fine column; coarse block is the principal `[0, ncoarse)` submatrix.
    ncoarse: usize,
    /// Lower Cholesky factor of the coarse block `A_CC` (`ncoarse × ncoarse`).
    coarse_chol: Vec<f64>,
    /// `log|A_CC|` (exact).
    coarse_logdet: f64,
    /// `1/A_jj` on the fine columns `[ncoarse, m)`.
    inv_fine: Vec<f64>,
    /// `1/√A_jj` on the fine columns (the `R⁻¹`/`R⁻ᵀ` fine scaling).
    inv_sqrt_fine: Vec<f64>,
    /// `Σ_F log A_jj` (the fine part of `log|P|`).
    fine_logdet: f64,
}

impl Preconditioner {
    /// `out = P⁻¹ r`: exact coarse solve on `[0, ncoarse)`, Jacobi on the tail.
    fn solve(&self, r: &[f64], out: &mut [f64]) {
        let nc = self.ncoarse;
        let zc = chol_solve(&self.coarse_chol, nc, &r[..nc]);
        out[..nc].copy_from_slice(&zc);
        for (k, o) in out[nc..].iter_mut().enumerate() {
            *o = r[nc + k] * self.inv_fine[k];
        }
    }

    /// `out = R⁻ᵀ v` (coarse: `L_CCᵀ` back-solve; fine: `/√A_jj`).
    fn apply_r_inv_t(&self, v: &[f64], out: &mut [f64]) {
        let nc = self.ncoarse;
        back_sub_into(&self.coarse_chol, nc, &v[..nc], &mut out[..nc]);
        for (k, o) in out[nc..].iter_mut().enumerate() {
            *o = v[nc + k] * self.inv_sqrt_fine[k];
        }
    }

    /// `out = R⁻¹ v` (coarse: `L_CC` forward-solve; fine: `/√A_jj`).
    fn apply_r_inv(&self, v: &[f64], out: &mut [f64]) {
        let nc = self.ncoarse;
        forward_sub_into(&self.coarse_chol, nc, &v[..nc], &mut out[..nc]);
        for (k, o) in out[nc..].iter_mut().enumerate() {
            *o = v[nc + k] * self.inv_sqrt_fine[k];
        }
    }

    /// `log|P| = log|A_CC| + Σ_F log A_jj`.
    fn logdet(&self) -> f64 {
        self.coarse_logdet + self.fine_logdet
    }
}

impl Core {
    /// Scale a raw point into shifted metric coordinates.
    fn scale_point(&self, x: &[f64]) -> [f64; 3] {
        let mut z = [0.0_f64; 3];
        for a in 0..self.dim {
            z[a] = self.metric[a] * x[a] - self.z_lo[a];
        }
        z
    }

    /// Sparse basis row at a scaled point: polynomial layer then every bump
    /// whose support covers it, as (column, value) pairs sorted by column.
    fn basis_row_scaled(&self, z: &[f64; 3]) -> Vec<(usize, f64)> {
        let mut row = Vec::with_capacity(self.dim + 1 + self.levels.len() * 8);
        row.push((0, 1.0));
        for a in 0..self.dim {
            row.push((a + 1, 2.0 * z[a] / self.z_range[a] - 1.0));
        }
        for level in &self.levels {
            let start = row.len();
            level.grid.for_neighbors(z, |j| {
                let c = &level.centers[j as usize];
                let r = dist2(z, c, self.dim).sqrt() / level.delta;
                let v = wendland(r);
                if v > 0.0 {
                    row.push((level.col_offset + j as usize, v));
                }
            });
            row[start..].sort_unstable_by_key(|&(col, _)| col);
        }
        row
    }

    /// `out = (X'WX + λD)·v` through the CSR rows: O(nnz).
    fn matvec(&self, lambda: f64, v: &[f64], out: &mut [f64]) {
        for (o, (&d, &x)) in out.iter_mut().zip(self.pen_diag.iter().zip(v.iter())) {
            *o = lambda * d * x;
        }
        for i in 0..self.w.len() {
            let lo = self.row_ptr[i];
            let hi = self.row_ptr[i + 1];
            let mut t = 0.0;
            for e in lo..hi {
                t += self.vals[e] * v[self.col_idx[e] as usize];
            }
            t *= self.w[i];
            for e in lo..hi {
                out[self.col_idx[e] as usize] += self.vals[e] * t;
            }
        }
    }

    /// Coarse column count of the additive-Schwarz coarse space at `λ`: the
    /// polynomial layer plus the longest prefix of data-dominated levels
    /// (`λ d_l < COARSE_DOMINANCE · median diag(X'WX) over the level`), with the
    /// two coarsest levels always deflated and the total capped at
    /// [`COARSE_SPACE_MAX`]. Because `d_l` rises while the per-level data weight
    /// falls, the data-dominated set is a contiguous prefix, so one scan from the
    /// coarsest level finds the cut. (See [`COARSE_DOMINANCE`].)
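    ///
    /// A standalone sketch of the prefix scan, assuming each level is passed as
    /// `(d_l, diag(X'WX) over its columns)` in coarse-to-fine order.
    /// `coarse_cols` is hypothetical: it mirrors the rule above rather than
    /// reproducing this method.
    ///
    /// ```rust
    /// fn coarse_cols(nullity: usize, levels: &[(f64, Vec<f64>)], lambda: f64, dominance: f64, min_levels: usize, cap: usize) -> usize {
    ///     let mut ncoarse = nullity;
    ///     let mut end = nullity;
    ///     for (li, (weight, gram_diag)) in levels.iter().enumerate() {
    ///         end += gram_diag.len();
    ///         if gram_diag.is_empty() {
    ///             continue;
    ///         }
    ///         if end > cap {
    ///             break;
    ///         }
    ///         let mut sorted = gram_diag.clone();
    ///         sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
    ///         let median = sorted[sorted.len() / 2];
    ///         // Forced-coarse levels, then the data-dominance test.
    ///         if li < min_levels || lambda * weight < dominance * median {
    ///             ncoarse = end;
    ///         } else {
    ///             break;
    ///         }
    ///     }
    ///     ncoarse
    /// }
    /// ```
    ///
    /// With `d_l = 4^l`, a flat data diagonal of 10, `λ = 1` and level sizes
    /// 2, 4, 8, 16, levels 0 to 2 are coarse (level 2: `16 < 40`) and level 3
    /// is not (`64 ≥ 40`), so 3 polynomial columns give `ncoarse = 17`.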
    fn coarse_space_cols(&self, lambda: f64) -> usize {
        let mut ncoarse = self.nullity();
        let mut buf: Vec<f64> = Vec::new();
        for (li, level) in self.levels.iter().enumerate() {
            let a = level.col_offset;
            let b = a + level.centers.len();
            if b <= a {
                continue;
            }
            if b > COARSE_SPACE_MAX {
                break;
            }
            let dominated = if li < MIN_COARSE_LEVELS {
                true
            } else {
                buf.clear();
                buf.extend_from_slice(&self.gram_diag[a..b]);
                buf.sort_unstable_by(|x, y| x.partial_cmp(y).unwrap());
                let gram_median = buf[buf.len() / 2];
                lambda * level.weight < COARSE_DOMINANCE * gram_median
            };
            if dominated {
                ncoarse = b;
            } else {
                break;
            }
        }
        // Guard: clamp to the system size (every level end `b` is already
        // `≤ m`). This does not reserve a fine column: if every level is coarse
        // the fine tail is empty, the iterative route is degenerate anyway, and
        // the dense route would have been taken.
        let ncoarse = ncoarse.min(self.m);
        // Debug-only coarse-space layout trace (#1032). Gated on the log level so
        // the per-call string build stays out of this preconditioner hot path,
        // and routed through `log` (an `eprintln!` here trips the src banned-macro
        // gate and broke the build).
        if log::log_enabled!(log::Level::Debug) {
            let mut s = String::new();
            for (li, level) in self.levels.iter().enumerate() {
                let a = level.col_offset;
                let b = a + level.centers.len();
                let mut buf: Vec<f64> = self.gram_diag[a..b].to_vec();
                buf.sort_unstable_by(|x, y| x.partial_cmp(y).unwrap());
                let med = if buf.is_empty() {
                    0.0
                } else {
                    buf[buf.len() / 2]
                };
                let coarse = b <= ncoarse;
                s.push_str(&format!(
                    " L{li}[{}c off{a} w={:.2e} λw={:.2e} med={:.2e} {}]",
                    level.centers.len(),
                    level.weight,
                    lambda * level.weight,
                    med,
                    if coarse { "C" } else { "F" }
                ));
            }
            log::debug!(
                "[1032-COARSE] λ={lambda:.3e} m={} ncoarse={ncoarse} cap={COARSE_SPACE_MAX}{s}",
                self.m
            );
        }
        ncoarse
    }

    /// Build the coarse-space additive-Schwarz preconditioner at `λ`: assemble
    /// and factor the coarse block `A_CC` from the CSR (coarse columns are the
    /// prefix `[0, ncoarse)`, and each CSR row is column-sorted, so a row's
    /// coarse entries are its leading run), then the Jacobi diagonal on the fine
    /// tail. `O(n · q_C²) + O(ncoarse³)` — paid once per `λ`, not per CG step.
    fn build_preconditioner(&self, lambda: f64) -> Result<Preconditioner, String> {
        let m = self.m;
        let nc = self.coarse_space_cols(lambda);
        let mut acc = vec![0.0_f64; nc * nc];
        for i in 0..self.w.len() {
            let lo = self.row_ptr[i];
            let hi = self.row_ptr[i + 1];
            // Leading run of coarse columns (CSR rows are column-sorted).
            let mut end = lo;
            while end < hi && (self.col_idx[end] as usize) < nc {
                end += 1;
            }
            for ea in lo..end {
                let ca = self.col_idx[ea] as usize;
                let va = self.w[i] * self.vals[ea];
                for eb in ea..end {
                    let cb = self.col_idx[eb] as usize;
                    acc[ca * nc + cb] += va * self.vals[eb];
                }
            }
        }
        for i in 0..nc {
            for j in i + 1..nc {
                acc[j * nc + i] = acc[i * nc + j];
            }
        }
        for i in 0..nc {
            acc[i * nc + i] += lambda * self.pen_diag[i];
        }
        let coarse_logdet = cholesky_logdet(&mut acc, nc)?;
        let mut inv_fine = Vec::with_capacity(m - nc);
        let mut inv_sqrt_fine = Vec::with_capacity(m - nc);
        let mut fine_logdet = 0.0;
        for j in nc..m {
            let p = self.gram_diag[j] + lambda * self.pen_diag[j];
            if !(p.is_finite() && p > EIG_FLOOR) {
                return Err(format!(
                    "residual cascade: non-positive preconditioner diagonal {p} at column {j}"
                ));
            }
            inv_fine.push(1.0 / p);
            inv_sqrt_fine.push(1.0 / p.sqrt());
            fine_logdet += p.ln();
        }
        Ok(Preconditioner {
            ncoarse: nc,
            coarse_chol: acc,
            coarse_logdet,
            inv_fine,
            inv_sqrt_fine,
            fine_logdet,
        })
    }

    /// Preconditioned CG on `(X'WX + λD)c = b` to relative residual `CG_RTOL`.
    /// Returns `(c, ‖b − Ac‖/‖b‖, iterations)`: the achieved relative residual
    /// of the CG recurrence is the backward-error certificate.
    fn pcg(
        &self,
        lambda: f64,
        b: &[f64],
        warm: Option<&[f64]>,
    ) -> Result<(Vec<f64>, f64, usize), String> {
        let m = self.m;
        let prec = self.build_preconditioner(lambda)?;
        let b_norm = b.iter().map(|v| v * v).sum::<f64>().sqrt();
        if b_norm == 0.0 {
            return Ok((vec![0.0; m], 0.0, 0));
        }
        let mut zv = vec![0.0; m];
        let mut x = match warm {
            Some(x0) => {
                if x0.len() != m {
                    return Err(format!(
                        "residual cascade: warm-start length {} != system size {m}",
                        x0.len()
                    ));
                }
                x0.to_vec()
            }
            None => {
                prec.solve(b, &mut zv);
                zv.clone()
            }
        };
        let mut r = vec![0.0; m];
        self.matvec(lambda, &x, &mut r);
        for (ri, &bi) in r.iter_mut().zip(b.iter()) {
            *ri = bi - *ri;
        }
        prec.solve(&r, &mut zv);
        let mut p_dir = zv.clone();
        let mut rz: f64 = r.iter().zip(zv.iter()).map(|(&a, &c)| a * c).sum();
        let mut ap = vec![0.0; m];
        let max_iters = CG_MAX_ITERS;
        for iter in 0..max_iters {
            let r_norm = r.iter().map(|v| v * v).sum::<f64>().sqrt();
            if r_norm <= CG_RTOL * b_norm {
                return Ok((x, r_norm / b_norm, iter));
            }
            self.matvec(lambda, &p_dir, &mut ap);
            let pap: f64 = p_dir.iter().zip(ap.iter()).map(|(&a, &c)| a * c).sum();
            if !(pap.is_finite() && pap > 0.0) {
                return Err(format!(
                    "residual cascade: CG curvature breakdown (p'Ap = {pap}) at iteration {iter}"
                ));
            }
            let alpha = rz / pap;
            for j in 0..m {
                x[j] += alpha * p_dir[j];
                r[j] -= alpha * ap[j];
            }
            prec.solve(&r, &mut zv);
            let rz_new: f64 = r.iter().zip(zv.iter()).map(|(&a, &c)| a * c).sum();
            let beta = rz_new / rz;
            rz = rz_new;
            for j in 0..m {
                p_dir[j] = zv[j] + beta * p_dir[j];
            }
        }
        Err(format!(
            "residual cascade: CG failed to reach relative residual {CG_RTOL} within \
             {CG_MAX_ITERS} iterations (the coarse-space additive-Schwarz preconditioner should \
             make this n-independent; this indicates a degenerate design)"
        ))
    }

    /// Expand the cached dense upper Gram plus `λD` into a full symmetric
    /// matrix; `None` when no dense Gram is cached (`m` past the sizing cap).
    fn dense_system(&self, lambda: f64) -> Option<Vec<f64>> {
        let gram = self.dense_gram.as_ref()?;
        let m = self.m;
        let mut a = vec![0.0; m * m];
        for i in 0..m {
            for j in i..m {
                let mut v = gram[i * m + j];
                if i == j {
                    v += lambda * self.pen_diag[i];
                }
                a[i * m + j] = v;
                a[j * m + i] = v;
            }
        }
        Some(a)
    }

    /// Exact log-determinant of `X'WX + λD` by dense Cholesky. Errors when
    /// the design is past the dense sizing cap.
    fn logdet_dense(&self, lambda: f64) -> Result<f64, String> {
        let mut a = self.dense_system(lambda).ok_or_else(|| {
            format!(
                "residual cascade: dense logdet requested past the sizing cap \
                 (m = {} > {DENSE_GRAM_MAX})",
                self.m
            )
        })?;
        cholesky_logdet(&mut a, self.m)
    }

    /// SLQ log-determinant: exact control variate `log|P|` (the coarse-space
    /// additive-Schwarz preconditioner's own log-determinant — `log|A_CC|` plus
    /// the fine Jacobi `Σ_F log A_jj`) plus stochastic Lanczos quadrature for
    /// `tr log(R⁻¹ A R⁻ᵀ)`, `P = R Rᵀ`, on fixed deterministic Rademacher probes
    /// shared across every λ (common random numbers ⇒ the REML criterion is a
    /// smooth deterministic function of λ). The same coarse deflation that makes
    /// the PCG iteration count n-independent makes `R⁻¹ A R⁻ᵀ` uniformly
    /// conditioned, so the Lanczos quadrature converges in a depth-independent
    /// number of steps too.
    fn logdet_slq(&self, lambda: f64) -> Result<f64, String> {
        let m = self.m;
        let prec = self.build_preconditioner(lambda)?;
        let logdet = prec.logdet();
        // M·v = R⁻¹ A R⁻ᵀ v (eigenvalues of P^{−1/2} A P^{−1/2}) without forming M.
        let mut scratch_in = vec![0.0; m];
        let mut scratch_out = vec![0.0; m];
        let mut vbuf = vec![0.0; m];
        let mut trace_est = 0.0;
        let steps = SLQ_LANCZOS_STEPS.min(m);
        let mut basis: Vec<Vec<f64>> = Vec::with_capacity(steps);
        for probe in 0..SLQ_PROBES {
            let mut rng =
                SplitMix64::new(RNG_SEED ^ (probe as u64).wrapping_mul(0xD134_2543_DE82_EF95));
            let mut q = vec![0.0; m];
            for qj in q.iter_mut() {
                *qj = rng.next_sign();
            }
            let z_norm2 = m as f64;
            let inv_norm = 1.0 / (m as f64).sqrt();
            for qj in q.iter_mut() {
                *qj *= inv_norm;
            }
            // Lanczos with full reorthogonalization.
            basis.clear();
            let mut alpha = Vec::with_capacity(steps);
            let mut beta: Vec<f64> = Vec::with_capacity(steps);
            let mut q_prev: Option<Vec<f64>> = None;
            for _step in 0..steps {
                // v = R⁻¹ A R⁻ᵀ q.
                prec.apply_r_inv_t(&q, &mut scratch_in);
                self.matvec(lambda, &scratch_in, &mut scratch_out);
                prec.apply_r_inv(&scratch_out, &mut vbuf);
                let mut v: Vec<f64> = vbuf.clone();
                let a: f64 = v.iter().zip(q.iter()).map(|(&x, &y)| x * y).sum();
                alpha.push(a);
                for j in 0..m {
                    v[j] -= a * q[j];
                }
                if let Some(prev) = &q_prev {
                    let b_prev = beta.last().copied().unwrap_or(0.0);
                    for j in 0..m {
                        v[j] -= b_prev * prev[j];
                    }
                }
                // Full reorthogonalization against the stored basis.
                basis.push(q.clone());
                for qb in &basis {
                    let proj: f64 = v.iter().zip(qb.iter()).map(|(&x, &y)| x * y).sum();
                    for j in 0..m {
                        v[j] -= proj * qb[j];
                    }
                }
                let b: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
                if !b.is_finite() {
                    return Err("residual cascade: Lanczos breakdown (non-finite norm)".into());
                }
                if b < 1e-13 {
                    break;
                }
                beta.push(b);
                q_prev = Some(std::mem::replace(&mut q, v));
                for qj in q.iter_mut() {
                    *qj /= b;
                }
            }
            beta.truncate(alpha.len().saturating_sub(1));
            let (theta, tau) = symmetric_tridiagonal_eigen(&alpha, &beta)?;
            let mut quad = 0.0;
            for (&t, &w0) in theta.iter().zip(tau.iter()) {
                if !(t.is_finite() && t > EIG_FLOOR) {
                    return Err(format!(
                        "residual cascade: non-positive Ritz value {t} in SLQ (system not PD)"
                    ));
                }
                quad += w0 * w0 * t.ln();
            }
            trace_est += z_norm2 * quad;
        }
        Ok(logdet + trace_est / SLQ_PROBES as f64)
    }

    /// Log-determinant through the route the sizing contract picks: exact dense
    /// Cholesky when the Gram is cached, stochastic Lanczos quadrature otherwise.
    fn logdet(&self, lambda: f64) -> Result<(f64, LogdetMethod), String> {
        if self.dense_gram.is_some() {
            Ok((self.logdet_dense(lambda)?, LogdetMethod::DenseExact))
        } else {
            Ok((self.logdet_slq(lambda)?, LogdetMethod::Slq))
        }
    }

    /// Coefficient solve at λ: replay the persisted factor when present, else
    /// dense Cholesky when the Gram is cached, else certified PCG.
    fn solve_coeff(
        &self,
        lambda: f64,
        b: &[f64],
        warm: Option<&[f64]>,
    ) -> Result<(Vec<f64>, f64, usize), String> {
        // A core rebuilt from a persisted state carries no training design, only
        // the factored precision `L` of `A = X'WX + λD` at the fit's λ. Replay
        // the solve through it (exact — predict always solves at that same λ).
        if let Some(l) = &self.predict_chol {
            return Ok((chol_solve(l, self.m, b), 0.0, 0));
        }
        if let Some(mut a) = self.dense_system(lambda) {
            cholesky_logdet(&mut a, self.m)?;
            return Ok((chol_solve(&a, self.m, b), 0.0, 0));
        }
        self.pcg(lambda, b, warm)
    }

    /// Assemble the lower Cholesky factor `L` of `A = X'WX + λD` as a dense
    /// `m × m` row-major matrix — the factored precision a persisted predict
    /// replays its posterior-variance solve through. Uses the cached dense Gram
    /// when present; otherwise scatters the CSR row outer products into the
    /// upper triangle (one O(nnz·q) pass), the same assembly `build` uses under
    /// the sizing cap, just without the cap. Factoring is O(m³) — paid once at
    /// snapshot time, not per predict.
    fn assemble_predict_factor(&self, lambda: f64) -> Result<Vec<f64>, String> {
        let m = self.m;
        let mut a = vec![0.0_f64; m * m];
        if let Some(gram) = &self.dense_gram {
            for i in 0..m {
                for j in i..m {
                    let v = gram[i * m + j];
                    a[i * m + j] = v;
                    a[j * m + i] = v;
                }
            }
        } else {
            for i in 0..self.w.len() {
                let lo = self.row_ptr[i];
                let hi = self.row_ptr[i + 1];
                for ea in lo..hi {
                    let ca = self.col_idx[ea] as usize;
                    let va = self.w[i] * self.vals[ea];
                    for eb in ea..hi {
                        let cb = self.col_idx[eb] as usize;
                        a[ca * m + cb] += va * self.vals[eb];
                    }
                }
            }
            // Mirror the upper triangle into the lower.
            for i in 0..m {
                for j in i + 1..m {
                    a[j * m + i] = a[i * m + j];
                }
            }
        }
        for (i, d) in self.pen_diag.iter().enumerate() {
            a[i * m + i] += lambda * d;
        }
        cholesky_logdet(&mut a, m)?;
        Ok(a)
    }

    /// Penalized residual quadratic at a solution: `y'Wy − c'X'Wy`.
    fn rss_pen(&self, coeff: &[f64]) -> f64 {
        let mut quad = 0.0;
        for (c, r) in coeff.iter().zip(self.rhs.iter()) {
            quad += c * r;
        }
        self.ytwy - quad
    }

    /// Number of unpenalized (polynomial) columns.
    fn nullity(&self) -> usize {
        self.dim + 1
    }

    /// Working residual `r_i = y_i − (Xc)_i`.
    fn residuals(&self, coeff: &[f64]) -> Vec<f64> {
        let n = self.y.len();
        let mut r = Vec::with_capacity(n);
        for i in 0..n {
            let mut fit = 0.0;
            for e in self.row_ptr[i]..self.row_ptr[i + 1] {
                fit += self.vals[e] * coeff[self.col_idx[e] as usize];
            }
            r.push(self.y[i] - fit);
        }
        r
    }
}
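
// Sketch (illustrative, not part of this module's API): the two-level
// preconditioner `build_preconditioner` assembles, driven by the CG loop of
// `pcg`, on a small dense SPD matrix. The hypothetical `TwoLevel` keeps a
// Cholesky factor of the leading `nc × nc` block `A_CC` plus the inverse Jacobi
// diagonal of the fine tail, so `P⁻¹` is an exact coarse solve and a diagonal
// scaling. Dense row-major storage and every name here are assumptions.
#[allow(dead_code)]
mod sketch_two_level_pcg {
    type Solve = Result<(Vec<f64>, f64, usize), String>;

    fn dot(x: &[f64], y: &[f64]) -> f64 {
        x.iter().zip(y).map(|(a, b)| a * b).sum()
    }

    pub struct TwoLevel {
        nc: usize,
        chol: Vec<f64>,     // lower Cholesky factor of A_CC, row-major nc × nc
        inv_fine: Vec<f64>, // 1 / A_jj for the fine columns j ≥ nc
    }

    impl TwoLevel {
        /// Factor the coarse prefix of the row-major `m × m` matrix `a`.
        pub fn new(a: &[f64], m: usize, nc: usize) -> Result<Self, String> {
            let mut l = vec![0.0; nc * nc];
            for j in 0..nc {
                let d = a[j * m + j] - (0..j).map(|k| l[j * nc + k] * l[j * nc + k]).sum::<f64>();
                if !(d > 0.0) {
                    return Err(format!("coarse block not positive definite at pivot {j}"));
                }
                l[j * nc + j] = d.sqrt();
                for i in j + 1..nc {
                    let s = a[i * m + j] - (0..j).map(|k| l[i * nc + k] * l[j * nc + k]).sum::<f64>();
                    l[i * nc + j] = s / l[j * nc + j];
                }
            }
            let inv_fine = (nc..m).map(|j| 1.0 / a[j * m + j]).collect();
            Ok(TwoLevel { nc, chol: l, inv_fine })
        }

        /// `z = P⁻¹ r`: Cholesky solve on the coarse prefix, Jacobi on the tail.
        pub fn solve(&self, r: &[f64], z: &mut [f64]) {
            let (nc, l) = (self.nc, &self.chol);
            for i in 0..nc {
                let s: f64 = (0..i).map(|k| l[i * nc + k] * z[k]).sum();
                z[i] = (r[i] - s) / l[i * nc + i];
            }
            for i in (0..nc).rev() {
                let s: f64 = (i + 1..nc).map(|k| l[k * nc + i] * z[k]).sum();
                z[i] = (z[i] - s) / l[i * nc + i];
            }
            for (j, inv) in self.inv_fine.iter().enumerate() {
                z[nc + j] = r[nc + j] * inv;
            }
        }
    }

    /// PCG from a zero start; returns `(x, relative residual, iterations)`.
    pub fn pcg(a: &[f64], b: &[f64], prec: &TwoLevel, rtol: f64, max_iters: usize) -> Solve {
        let m = b.len();
        let b_norm = dot(b, b).sqrt();
        if b_norm == 0.0 {
            return Ok((vec![0.0; m], 0.0, 0));
        }
        let (mut x, mut r, mut z, mut ap) = (vec![0.0; m], b.to_vec(), vec![0.0; m], vec![0.0; m]);
        prec.solve(&r, &mut z);
        let mut p = z.clone();
        let mut rz = dot(&r, &z);
        for iter in 0..max_iters {
            let r_norm = dot(&r, &r).sqrt();
            if r_norm <= rtol * b_norm {
                return Ok((x, r_norm / b_norm, iter));
            }
            for i in 0..m {
                ap[i] = dot(&a[i * m..(i + 1) * m], &p);
            }
            let pap = dot(&p, &ap);
            if !(pap.is_finite() && pap > 0.0) {
                return Err(format!("CG curvature breakdown at iteration {iter}"));
            }
            let alpha = rz / pap;
            for j in 0..m {
                x[j] += alpha * p[j];
                r[j] -= alpha * ap[j];
            }
            prec.solve(&r, &mut z);
            let rz_new = dot(&r, &z);
            let beta = rz_new / rz;
            rz = rz_new;
            for j in 0..m {
                p[j] = z[j] + beta * p[j];
            }
        }
        Err(format!("CG did not converge within {max_iters} iterations"))
    }
}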

// ──────────────────── symmetric tridiagonal eigensolver ─────────────────────

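// Sketch (illustrative): what `symmetric_tridiagonal_eigen` hands to the SLQ
// loop, worked in closed form for a 2 × 2 Lanczos matrix `T = [[a, b], [b, c]]`.
// The squared first eigenvector components `τ_k²` are Gauss quadrature weights:
// `Σ_k τ_k² f(θ_k) = e₁ᵀ f(T) e₁`, and with `f = ln` that sum is the per-probe
// `quad` term in `logdet_slq`. For a 2 × 2 `T` the rule is exact for every `f`
// (it is just the spectral decomposition). The names here are hypothetical.
#[allow(dead_code)]
mod sketch_tridiag_2x2 {
    /// Eigenvalues (ascending) and first eigenvector components of `[[a, b], [b, c]]`.
    pub fn eigen_first_row(a: f64, b: f64, c: f64) -> ([f64; 2], [f64; 2]) {
        if b == 0.0 {
            // Already diagonal: e₁ is the eigenvector belonging to `a`.
            return if a <= c { ([a, c], [1.0, 0.0]) } else { ([c, a], [0.0, 1.0]) };
        }
        let mid = 0.5 * (a + c);
        let rad = (0.5 * (a - c)).hypot(b);
        let theta = [mid - rad, mid + rad];
        // The eigenvector of θ is proportional to (b, θ − a); keep its first entry.
        let tau = theta.map(|t| b / b.hypot(t - a));
        (theta, tau)
    }

    /// Gauss quadrature estimate `Σ_k τ_k² f(θ_k)` of `e₁ᵀ f(T) e₁`.
    pub fn quadrature(theta: &[f64; 2], tau: &[f64; 2], f: impl Fn(f64) -> f64) -> f64 {
        theta.iter().zip(tau).map(|(&t, &w)| w * w * f(t)).sum()
    }
}
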
/// Eigenvalues and FIRST eigenvector components of a symmetric tridiagonal
/// matrix (diag `d`, off-diagonal `e`), by implicit-shift QL with the
/// first-row vector carried through the rotations — exactly what Lanczos
/// quadrature needs.
fn symmetric_tridiagonal_eigen(d: &[f64], e: &[f64]) -> Result<(Vec<f64>, Vec<f64>), String> {
    let n = d.len();
    if n == 0 {
        return Ok((Vec::new(), Vec::new()));
    }
    let mut diag = d.to_vec();
    let mut off = vec![0.0; n];
    off[..n - 1].copy_from_slice(&e[..n - 1]);
    let mut first = vec![0.0; n];
    first[0] = 1.0;
    for l in 0..n {
        let mut iter = 0;
        loop {
            // Find a negligible off-diagonal to split at.
            let mut msplit = n - 1;
            for mm in l..n - 1 {
                let dd = diag[mm].abs() + diag[mm + 1].abs();
                if off[mm].abs() <= f64::EPSILON * dd {
                    msplit = mm;
                    break;
                }
            }
            if msplit == l {
                break;
            }
            iter += 1;
            if iter > 60 {
                return Err("residual cascade: tridiagonal QL failed to converge".into());
            }
            let mut g = (diag[l + 1] - diag[l]) / (2.0 * off[l]);
            let mut r = g.hypot(1.0);
            g = diag[msplit] - diag[l] + off[l] / (g + r.copysign(g));
            let (mut s, mut c) = (1.0, 1.0);
            let mut p = 0.0;
            let mut broke_early = false;
            for i in (l..msplit).rev() {
                let mut f = s * off[i];
                let b = c * off[i];
                r = f.hypot(g);
                off[i + 1] = r;
                if r == 0.0 {
                    diag[i + 1] -= p;
                    off[msplit] = 0.0;
                    broke_early = true;
                    break;
                }
                s = f / r;
                c = g / r;
                g = diag[i + 1] - p;
                r = (diag[i] - g) * s + 2.0 * c * b;
                p = s * r;
                diag[i + 1] = g + p;
                g = c * r - b;
                // Carry the first-row eigenvector components.
                f = first[i + 1];
                first[i + 1] = s * first[i] + c * f;
                first[i] = c * first[i] - s * f;
            }
            if broke_early {
                continue;
            }
            diag[l] -= p;
            off[l] = g;
            off[msplit] = 0.0;
        }
    }
    Ok((diag, first))
}
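
// Sketch (illustrative): the exact split that `logdet_slq` estimates,
// `log|A| = log|P| + log|P^{-1/2} A P^{-1/2}|`, shown with plain Jacobi
// `P = diag(A)` (the cascade's `P` also carries the factored coarse block
// `A_CC`). Both terms are computed exactly here by a dense Cholesky; SLQ only
// has to estimate the second, which by Hadamard's inequality is `≤ 0` (a
// unit-diagonal SPD matrix has determinant at most 1) and small when `P`
// captures the scale of `A`. The module and function names are hypothetical.
#[allow(dead_code)]
mod sketch_logdet_split {
    /// In-place lower Cholesky of a row-major SPD matrix; returns `log|A|`.
    pub fn chol_logdet(a: &mut [f64], n: usize) -> Result<f64, String> {
        let mut logdet = 0.0;
        for j in 0..n {
            let d = a[j * n + j] - (0..j).map(|k| a[j * n + k] * a[j * n + k]).sum::<f64>();
            if !(d > 0.0) {
                return Err(format!("not positive definite at pivot {j}"));
            }
            let ljj = d.sqrt();
            a[j * n + j] = ljj;
            logdet += 2.0 * ljj.ln();
            for i in j + 1..n {
                let s = a[i * n + j] - (0..j).map(|k| a[i * n + k] * a[j * n + k]).sum::<f64>();
                a[i * n + j] = s / ljj;
            }
        }
        Ok(logdet)
    }

    /// Returns `(log|P|, log|P^{-1/2} A P^{-1/2}|)` for `P = diag(A)`.
    pub fn split(a: &[f64], n: usize) -> Result<(f64, f64), String> {
        let d: Vec<f64> = (0..n).map(|i| a[i * n + i]).collect();
        let control: f64 = d.iter().map(|v| v.ln()).sum();
        let mut scaled = vec![0.0; n * n];
        for i in 0..n {
            for j in 0..n {
                scaled[i * n + j] = a[i * n + j] / (d[i] * d[j]).sqrt();
            }
        }
        Ok((control, chol_logdet(&mut scaled, n)?))
    }
}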

// ───────────────────────────── net construction ─────────────────────────────

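// Sketch (illustrative): the greedy hash-grid ε-net scan at the heart of
// `extend_net`, for 2-D points and without the bounding-box fill or the
// `MAX_CENTERS` cap. The names are hypothetical and a `HashMap` stands in for
// the module's `HashGrid`. Cells have side `h`, so any center within `h` of a
// point lies in one of the 3 × 3 cells around it (O(3^d) lookups per point).
// A point farther than `h` from every center seen so far becomes a center, so
// the result covers the points within `h` and stays `h`-separated; seeding
// with a coarser level's net keeps the nets nested.
#[allow(dead_code)]
mod sketch_eps_net {
    use std::collections::HashMap;

    fn cell(p: [f64; 2], h: f64) -> (i64, i64) {
        ((p[0] / h).floor() as i64, (p[1] / h).floor() as i64)
    }

    /// Extend `net` (possibly empty) over `points`; returns the new centers.
    pub fn greedy_net(net: &mut Vec<[f64; 2]>, points: &[[f64; 2]], h: f64) -> Vec<[f64; 2]> {
        let mut grid: HashMap<(i64, i64), Vec<usize>> = HashMap::new();
        for (idx, c) in net.iter().enumerate() {
            grid.entry(cell(*c, h)).or_default().push(idx);
        }
        let mut new_centers = Vec::new();
        for &p in points {
            let (cx, cy) = cell(p, h);
            let mut covered = false;
            'scan: for dx in -1..=1 {
                for dy in -1..=1 {
                    for &j in grid.get(&(cx + dx, cy + dy)).into_iter().flatten() {
                        let (ex, ey) = (p[0] - net[j][0], p[1] - net[j][1]);
                        if ex * ex + ey * ey <= h * h {
                            covered = true;
                            break 'scan;
                        }
                    }
                }
            }
            if !covered {
                grid.entry((cx, cy)).or_default().push(net.len());
                net.push(p);
                new_centers.push(p);
            }
        }
        new_centers
    }
}
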
/// Extend a nested net so it covers the DOMAIN: first every data point further
/// than `h` from the net built so far (the seed plus any center already added)
/// becomes a new center, then every cell of the `h`-grid over the bounding box
/// `[0, box_hi]` whose center is not yet within `h` of the net is filled with a
/// synthetic center. Absent the `MAX_CENTERS` early exit, every data point ends
/// within `h` of the net, and since a box point lies within half a cell
/// diagonal (`h·√d/2`) of some cell center, the covering radius over the box is
/// at most `h·(1 + √d/2)`. Each new center is more than `h` from every earlier
/// one, so the net stays `h`-separated. O((n + box cells)·3^d). Returns the new
/// centers.
///
/// Covering the box, not merely the data cloud, is what the multilevel Wendland
/// norm-equivalence (Narcowich–Ward inverse estimates + Le Gia–Wendland
/// multilevel stability) actually requires: the nested centers must be
/// quasi-uniform over the domain Ω. In data-dense regions every cell is already
/// covered by a data center, so the fill is a no-op there; in a data void it
/// plants the fine centers whose coefficients carry no data and revert to the
/// prior. That is the mechanism by which the posterior mean bridges a gap
/// (coarse data-pinned bumps) while the posterior variance GROWS into it (fine
/// void bumps the data cannot pin). The synthetic centers carry (almost) no
/// data rows, so their Gram diagonal is ~0 and they land in the
/// penalty-dominated fine block, where the Jacobi preconditioner is exact; they
/// perturb neither the coarse factorization nor the n-independent iteration
/// count.
fn extend_net(
    net: &mut Vec<[f64; 3]>,
    points: &[[f64; 3]],
    dim: usize,
    h: f64,
    box_hi: &[f64; 3],
) -> Vec<[f64; 3]> {
    let mut grid = HashGrid::new(h, dim);
    for (idx, c) in net.iter().enumerate() {
        grid.insert(idx as u32, c);
    }
    let h2 = h * h;
    let mut new_centers = Vec::new();
    let try_add = |net: &mut Vec<[f64; 3]>,
                   grid: &mut HashGrid,
                   new_centers: &mut Vec<[f64; 3]>,
                   p: &[f64; 3]| {
        let mut covered = false;
        grid.for_neighbors(p, |j| {
            if !covered && dist2(p, &net[j as usize], dim) <= h2 {
                covered = true;
            }
        });
        if !covered {
            let idx = net.len() as u32;
            net.push(*p);
            grid.insert(idx, p);
            new_centers.push(*p);
        }
    };
    for p in points {
        try_add(net, &mut grid, &mut new_centers, p);
        if net.len() > MAX_CENTERS {
            return new_centers;
        }
    }
    // Fill the bounding box so the net covers the domain, not just the data.
    //
    // The box has ~`(box_hi/h)^dim` cells, so the fill cost grows like
    // `(2^l)^dim` as the covering radius `h = h₀·2^{-l}` shrinks with the
    // level `l`. At fine levels below the data spacing that is an explosion
    // (every sub-data-spacing cell of the whole domain becomes a synthetic
    // center), which is unbounded work the caller never needs: once the net
    // crosses `MAX_CENTERS` the build path errors and the auto-route
    // refinement (`next_level_gain_bound`) treats it as "stop refining". So
    // cap the fill IN the loop — stop planting synthetic centers the moment
    // the net exceeds the cap rather than materializing the entire fine-level
    // box first. Coarse levels (few cells, never near the cap) keep the full
    // quasi-uniform domain fill and the polynomial-bridge gap behavior intact.
    let mut cells = [1_i64; 3];
    for a in 0..dim {
        cells[a] = (box_hi[a] / h).ceil() as i64 + 1;
    }
    let mut c = [0.0_f64; 3];
    'fill: for i0 in 0..cells[0] {
        c[0] = (i0 as f64 + 0.5) * h;
        for i1 in 0..cells[1] {
            if dim > 1 {
                c[1] = (i1 as f64 + 0.5) * h;
            }
            for i2 in 0..cells[2] {
                if dim > 2 {
                    c[2] = (i2 as f64 + 0.5) * h;
                }
                try_add(net, &mut grid, &mut new_centers, &c);
                if net.len() > MAX_CENTERS {
                    break 'fill;
                }
            }
        }
    }
    new_centers
}
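
// Sketch (illustrative): the two per-level ingredients `build` wires into each
// `Level`, written from the module-header formulas: the Wendland-(3,1) bump
// `φ(r) = (1−r)₊⁴(4r+1)`, evaluated at `|z − ξ|/δ_l`, and the penalty weight
// `d_l = 4^{l(s−d/2)}`, the reciprocal of the prior variance factor
// `4^{−l(s−d/2)}`. All three functions are hypothetical stand-ins; the
// module's own `level_weight` helper is not shown in this excerpt.
#[allow(dead_code)]
mod sketch_level_terms {
    /// Wendland-(3,1) bump: 1 at the center, C², and zero for `r ≥ 1`.
    pub fn wendland31(r: f64) -> f64 {
        let t = (1.0 - r).max(0.0);
        t.powi(4) * (4.0 * r + 1.0)
    }

    /// Level-`l` bump centered at `xi` with support radius `delta`, evaluated at `z`.
    pub fn bump(z: &[f64], xi: &[f64], delta: f64) -> f64 {
        let dist2: f64 = z.iter().zip(xi).map(|(a, b)| (a - b) * (a - b)).sum();
        wendland31(dist2.sqrt() / delta)
    }

    /// Penalty weight `d_l = 4^{l(s − d/2)}`: 1 at level 0, growing with `l` since `s > d/2`.
    pub fn penalty_weight(level: usize, sobolev_s: f64, dim: usize) -> f64 {
        4.0_f64.powf(level as f64 * (sobolev_s - dim as f64 / 2.0))
    }
}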

impl ResidualCascadeDesign {
    /// Build the cascade design: validate, scale by the metric, grow `levels`
    /// nested nets, and assemble the sparse design plus its sufficient
    /// statistics in O(n·levels·3^d), plus the per-level box fill and, under
    /// the dense sizing cap, an O(n·q²) dense-Gram scatter.
    ///
    /// `xs` holds one slice per axis (2 or 3 of them); `y` and `w` are the
    /// responses and the positive weights of `W = diag(w)`, one per row;
    /// `metric` is the positive per-axis scaling of the learned metric;
    /// `sobolev_s` is the Sobolev order of the equivalent (semi)norm and must
    /// satisfy `d/2 < s ≤ (d+3)/2` (the Wendland-(3,1) native smoothness);
    /// `levels` is the number of resolution levels, in `1..=MAX_LEVELS`.
1248    pub fn build(
1249        xs: &[&[f64]],
1250        y: &[f64],
1251        w: &[f64],
1252        metric: &[f64],
1253        sobolev_s: f64,
1254        levels: usize,
1255    ) -> Result<Self, String> {
1256        let dim = xs.len();
1257        if !(dim == 2 || dim == 3) {
1258            return Err(format!(
1259                "residual cascade: built for scattered 2-3D smooths, got {dim} axes"
1260            ));
1261        }
1262        let n = y.len();
1263        if w.len() != n || xs.iter().any(|x| x.len() != n) {
1264            return Err(format!(
1265                "residual cascade: length mismatch (y={n}, w={}, axes={:?})",
1266                w.len(),
1267                xs.iter().map(|x| x.len()).collect::<Vec<_>>()
1268            ));
1269        }
1270        if n <= dim + 1 {
1271            return Err(format!(
1272                "residual cascade: needs more than {} rows for the profiled REML degrees of \
1273                 freedom, got {n}",
1274                dim + 1
1275            ));
1276        }
1277        if metric.len() != dim || metric.iter().any(|&s| !(s.is_finite() && s > 0.0)) {
1278            return Err(format!(
1279                "residual cascade: metric must be {dim} finite positive scales, got {metric:?}"
1280            ));
1281        }
1282        if !(sobolev_s > dim as f64 / 2.0 && sobolev_s <= (dim as f64 + 3.0) / 2.0) {
1283            return Err(format!(
1284                "residual cascade: sobolev_s must lie in (d/2, (d+3)/2] = ({}, {}] for the \
1285                 Wendland-(3,1) bump, got {sobolev_s}",
1286                dim as f64 / 2.0,
1287                (dim as f64 + 3.0) / 2.0
1288            ));
1289        }
1290        if levels == 0 || levels > MAX_LEVELS {
1291            return Err(format!(
1292                "residual cascade: levels must be in 1..={MAX_LEVELS}, got {levels}"
1293            ));
1294        }
1295        for i in 0..n {
1296            if !(y[i].is_finite() && w[i].is_finite() && w[i] > 0.0)
1297                || xs.iter().any(|x| !x[i].is_finite())
1298            {
1299                return Err(format!(
1300                    "residual cascade: non-finite or non-positive input at row {i}"
1301                ));
1302            }
1303        }
1304        // Scaled, corner-shifted coordinates.
1305        let mut z_lo = [f64::INFINITY; 3];
1306        let mut z_hi = [f64::NEG_INFINITY; 3];
1307        for a in 0..dim {
1308            for &v in xs[a] {
1309                let s = metric[a] * v;
1310                z_lo[a] = z_lo[a].min(s);
1311                z_hi[a] = z_hi[a].max(s);
1312            }
1313        }
1314        let mut z_range = [1.0_f64; 3];
1315        let mut max_range = 0.0_f64;
1316        for a in 0..dim {
1317            if !(z_hi[a] > z_lo[a]) {
1318                return Err(format!(
1319                    "residual cascade: degenerate axis {a} bounding box [{}, {}]",
1320                    z_lo[a], z_hi[a]
1321                ));
1322            }
1323            z_range[a] = z_hi[a] - z_lo[a];
1324            max_range = max_range.max(z_range[a]);
1325        }
1326        for a in dim..3 {
1327            z_lo[a] = 0.0;
1328        }
1329        let z: Vec<[f64; 3]> = (0..n)
1330            .map(|i| {
1331                let mut p = [0.0_f64; 3];
1332                for a in 0..dim {
1333                    p[a] = metric[a] * xs[a][i] - z_lo[a];
1334                }
1335                p
1336            })
1337            .collect();
1338        let mut metric3 = [1.0_f64; 3];
1339        metric3[..dim].copy_from_slice(metric);
1340
1341        let h0 = H0_FRACTION * max_range;
1342        let mut net: Vec<[f64; 3]> = Vec::new();
1343        let mut level_specs = Vec::with_capacity(levels);
1344        let mut col = dim + 1;
1345        let mut pen_logdet_const = 0.0;
1346        for l in 0..levels {
1347            let h = h0 * 0.5_f64.powi(l as i32);
1348            let new_centers = extend_net(&mut net, &z, dim, h, &z_range);
1349            if net.len() > MAX_CENTERS {
1350                return Err(format!(
1351                    "residual cascade: center cap {MAX_CENTERS} exceeded at level {l}"
1352                ));
1353            }
1354            let weight = level_weight(l, sobolev_s, dim);
1355            pen_logdet_const += new_centers.len() as f64 * weight.ln();
1356            let delta = OVERLAP * h;
1357            let mut grid = HashGrid::new(delta, dim);
1358            for (j, c) in new_centers.iter().enumerate() {
1359                grid.insert(j as u32, c);
1360            }
1361            let col_offset = col;
1362            col += new_centers.len();
1363            level_specs.push(Level {
1364                h,
1365                delta,
1366                weight,
1367                centers: new_centers,
1368                col_offset,
1369                grid,
1370            });
1371        }
1372        let m = col;
1373
1374        // CSR assembly + sufficient statistics in one pass.
1375        let mut row_ptr = Vec::with_capacity(n + 1);
        row_ptr.push(0_usize);
        let mut col_idx: Vec<u32> = Vec::new();
        let mut vals: Vec<f64> = Vec::new();
        let mut rhs = vec![0.0_f64; m];
        let mut gram_diag = vec![0.0_f64; m];
        let mut ytwy = 0.0_f64;
        let probe_core = CoreScaffold {
            dim,
            z_range,
            levels: &level_specs,
        };
        for i in 0..n {
            let row = probe_core.basis_row(&z[i]);
            for &(c, v) in &row {
                col_idx.push(c as u32);
                vals.push(v);
                rhs[c] += w[i] * y[i] * v;
                gram_diag[c] += w[i] * v * v;
            }
            ytwy += w[i] * y[i] * y[i];
            row_ptr.push(col_idx.len());
        }
        let mut pen_diag = vec![0.0_f64; m];
        for level in &level_specs {
            for j in 0..level.centers.len() {
                pen_diag[level.col_offset + j] = level.weight;
            }
        }

        // Dense Gram cache under the sizing cap: O(n·q²) scatter of each row's
        // weighted outer product into the upper triangle. Rows are sorted by
        // column, so `eb ≥ ea` always lands on or above the diagonal.
        let dense_gram = if m <= DENSE_GRAM_MAX {
            let mut gram = vec![0.0_f64; m * m];
            for i in 0..n {
                let lo = row_ptr[i];
                let hi = row_ptr[i + 1];
                for ea in lo..hi {
                    let ca = col_idx[ea] as usize;
                    let va = w[i] * vals[ea];
                    for eb in ea..hi {
                        gram[ca * m + col_idx[eb] as usize] += va * vals[eb];
                    }
                }
            }
            Some(gram)
        } else {
            None
        };

        Ok(ResidualCascadeDesign {
            core: Arc::new(Core {
                dim,
                metric: metric3,
                z_lo,
                z_range,
                sobolev_s,
                levels: level_specs,
                net,
                m,
                row_ptr,
                col_idx,
                vals,
                w: w.to_vec(),
                y: y.to_vec(),
                z,
                rhs,
                ytwy,
                gram_diag,
                pen_diag,
                pen_logdet_const,
                dense_gram,
                predict_chol: None,
            }),
        })
    }

    /// Number of resolution levels.
    pub fn num_levels(&self) -> usize {
        self.core.levels.len()
    }

    /// Aspect ratio of the metric-scaled point cloud: the ratio of the largest
    /// to smallest per-axis standard deviation of the scaled coordinates `z`.
    /// This is the metric-condition measure the quasi-uniformity guard (issue
    /// #1032, caveat 2) keys on — see [`QUASI_UNIFORMITY_MAX_ASPECT`]. A value
    /// near 1 is an isotropic (benign) cloud; a large value means the metric
    /// has collapsed the data onto a lower-dimensional sheet in `z`, breaking
    /// the BPX n-independent iteration bound. An empty cloud reports 1, and an
    /// axis with zero scaled spread reports `f64::INFINITY`.
    pub fn metric_scaled_aspect_ratio(&self) -> f64 {
        let dim = self.core.dim;
        let n = self.core.z.len();
        if dim == 0 || n == 0 {
            return 1.0;
        }
        let mut mean = [0.0_f64; 3];
        for p in &self.core.z {
            for a in 0..dim {
                mean[a] += p[a];
            }
        }
        for m in mean.iter_mut().take(dim) {
            *m /= n as f64;
        }
        let mut var = [0.0_f64; 3];
        for p in &self.core.z {
            for a in 0..dim {
                let d = p[a] - mean[a];
                var[a] += d * d;
            }
        }
        let mut sd_lo = f64::INFINITY;
        let mut sd_hi = 0.0_f64;
        for v in var.iter().take(dim) {
            let sd = (v / n as f64).sqrt();
            sd_lo = sd_lo.min(sd);
            sd_hi = sd_hi.max(sd);
        }
        if !(sd_lo > 0.0 && sd_lo.is_finite()) {
            // A collapsed axis (zero scaled spread) is maximally degenerate.
            return f64::INFINITY;
        }
        sd_hi / sd_lo
    }

    /// Quasi-uniformity certificate (issue #1032, caveat 2): `true` iff the
    /// metric-scaled cloud is isotropic enough that the BPX n-independent CG
    /// iteration bound is trustworthy. When this returns `false` the auto-route
    /// MUST fall back to the dense kernel path rather than pay an iterative
    /// solve whose iteration count is no longer n-independent — the CG residual
    /// certificate would still *catch* a mis-solve at [`CG_MAX_ITERS`], but the
    /// guard prevents the silent O(n·iters) blow-up up front.
    pub fn quasi_uniformity_certified(&self) -> bool {
        self.metric_scaled_aspect_ratio() <= QUASI_UNIFORMITY_MAX_ASPECT
    }

    /// Number of columns `ncoarse` in the additive-Schwarz coarse space at
    /// `log λ` (the polynomial layer plus the data-dominated coarsest levels). The
    /// iterative-route preconditioner solves the principal `[0, ncoarse)` block
    /// of `A = X'WX + λD` exactly and Jacobi-preconditions the fine tail; exposed
    /// so the conditioning oracle can reconstruct that block-arrow preconditioner
    /// from the public dense system and certify it is uniformly conditioned in
    /// depth. See [`COARSE_DOMINANCE`].
    pub fn coarse_space_cols(&self, log_lambda: f64) -> usize {
        self.core.coarse_space_cols(log_lambda.exp())
    }

    /// Total coefficient count (`dim + 1` polynomial + all centers).
    pub fn num_coeffs(&self) -> usize {
        self.core.m
    }

    /// Structural nonzero count of the sparse design `X` (its CSR size). Each
    /// iterative-route PCG iteration applies the operator `A = XᵀWX + λD` as two
    /// CSR products against `X`, so its per-iteration cost is `Θ(nnz(X))`; the
    /// certified sparse-solve work is therefore `solve_iters · num_nonzeros()`,
    /// the figure the residual-cascade complexity certificate compares against
    /// the dense `m³/3` factorization cost. Zero on a predict-only core rebuilt
    /// from a persisted snapshot (the training CSR is intentionally dropped).
    pub fn num_nonzeros(&self) -> usize {
        self.core.col_idx.len()
    }

    /// Total centers across all levels.
    pub fn num_centers(&self) -> usize {
        self.core.m - self.core.nullity()
    }

    /// NEW centers of one level in ORIGINAL (unscaled) coordinates. Panics if
    /// `level >= self.num_levels()`.
    pub fn centers(&self, level: usize) -> Vec<Vec<f64>> {
        let lv = &self.core.levels[level];
        lv.centers
            .iter()
            .map(|c| {
                (0..self.core.dim)
                    .map(|a| (c[a] + self.core.z_lo[a]) / self.core.metric[a])
                    .collect()
            })
            .collect()
    }

    /// Sparse basis row at a raw point, as (column, value) pairs sorted by
    /// column within each block — the exact row the fit used for training
    /// rows, exposed so oracles can assemble the dense system independently.
    pub fn basis_row(&self, x: &[f64]) -> Result<Vec<(usize, f64)>, String> {
        self.check_point(x)?;
        Ok(self.core.basis_row_scaled(&self.core.scale_point(x)))
    }

    fn check_point(&self, x: &[f64]) -> Result<(), String> {
        if x.len() != self.core.dim || x.iter().any(|v| !v.is_finite()) {
            return Err(format!(
                "residual cascade: point must be {} finite coordinates, got {x:?}",
                self.core.dim
            ));
        }
        Ok(())
    }

    /// Exact penalty quadratic `c'Dc` (unit-λ multilevel prior energy).
    pub fn penalty_value(&self, coeff: &[f64]) -> Result<f64, String> {
        if coeff.len() != self.core.m {
            return Err(format!(
                "residual cascade: coefficient length {} != {}",
                coeff.len(),
                self.core.m
            ));
        }
        Ok(coeff
            .iter()
            .zip(self.core.pen_diag.iter())
            .map(|(&c, &d)| d * c * c)
            .sum())
    }

    /// Exact dense log-determinant of `X'WX + λD` (errors past the sizing
    /// cap) — exposed for the in-test SLQ-vs-exact oracle.
    pub fn logdet_exact(&self, log_lambda: f64) -> Result<f64, String> {
        self.core.logdet_dense(log_lambda.exp())
    }

    /// SLQ log-determinant estimate on the fixed deterministic probes —
    /// exposed for the in-test SLQ-vs-exact oracle.
    pub fn logdet_slq(&self, log_lambda: f64) -> Result<f64, String> {
        self.core.logdet_slq(log_lambda.exp())
    }

    /// Profiled-σ² REML criterion at `log λ` (differences across λ are exact
    /// REML differences on the dense route; SLQ-estimated past the cap).
    pub fn criterion(&self, log_lambda: f64) -> Result<f64, String> {
        Ok(self.criterion_with_warm(log_lambda, None)?.0)
    }

    fn criterion_with_warm(
        &self,
        log_lambda: f64,
        warm: Option<&[f64]>,
    ) -> Result<(f64, Vec<f64>), String> {
        if !log_lambda.is_finite() {
            return Err(format!(
                "residual cascade: non-finite log lambda {log_lambda}"
            ));
        }
        let core = &self.core;
        let lambda = log_lambda.exp();
        let (coeff, _, _) = core.solve_coeff(lambda, &core.rhs, warm)?;
        let rss_pen = core.rss_pen(&coeff);
        if !(rss_pen > 0.0) {
            return Err(format!(
                "residual cascade: degenerate penalized residual {rss_pen}"
            ));
        }
        let (logdet, _) = core.logdet(lambda)?;
        let dof = (core.y.len() - core.nullity()) as f64;
        let r = (core.m - core.nullity()) as f64;
        let sigma2 = rss_pen / dof;
        Ok((
            -0.5 * (logdet - r * log_lambda - core.pen_logdet_const + dof * sigma2.ln()),
            coeff,
        ))
    }

    /// Fit at a FIXED `log λ`, with σ² either supplied or profiled.
    pub fn fit_at(
        &self,
        log_lambda: f64,
        sigma2: Option<f64>,
    ) -> Result<ResidualCascadeFit, String> {
        self.fit_at_with_warm(log_lambda, sigma2, None)
    }

    fn fit_at_with_warm(
        &self,
        log_lambda: f64,
        sigma2: Option<f64>,
        warm: Option<&[f64]>,
    ) -> Result<ResidualCascadeFit, String> {
        if !log_lambda.is_finite() {
            return Err(format!(
                "residual cascade: non-finite log lambda {log_lambda}"
            ));
        }
        let core = &self.core;
        let lambda = log_lambda.exp();
        let (coeff, rel_res, iters) = core.solve_coeff(lambda, &core.rhs, warm)?;
        let rss_pen = core.rss_pen(&coeff);
        let dof = (core.y.len() - core.nullity()) as f64;
        let sigma2 = match sigma2 {
            Some(s) => {
                if !(s.is_finite() && s > 0.0) {
                    return Err(format!("residual cascade: invalid sigma2 {s}"));
                }
                s
            }
            None => {
                if !(rss_pen > 0.0) {
                    return Err(format!(
                        "residual cascade: degenerate penalized residual {rss_pen}"
                    ));
                }
                rss_pen / dof
            }
        };
        let (logdet, logdet_method) = core.logdet(lambda)?;
        let r = (core.m - core.nullity()) as f64;
        // Full restricted log-likelihood at this (λ, σ²) up to λ- and σ-free
        // constants; at the profiled σ̂² the quadratic collapses to `dof`.
        let restricted_loglik = -0.5
            * (logdet - r * log_lambda - core.pen_logdet_const
                + dof * sigma2.ln()
                + rss_pen / sigma2);
        let predict_chol = if core.dense_gram.is_some() {
            Some(core.assemble_predict_factor(lambda)?)
        } else {
            None
        };
        Ok(ResidualCascadeFit {
            core: Arc::clone(&self.core),
            predict_chol,
            coeff,
            log_lambda,
            sigma2,
            restricted_loglik,
            rss_pen,
            certificate: CascadeCertificate {
                solve_rel_residual: rel_res,
                solve_iters: iters,
                logdet_method,
            },
            refinement: None,
        })
    }

    /// Fit with `log λ` selected by the profiled REML criterion:
    /// deterministic coarse grid then golden-section refinement (the SLQ
    /// probes are fixed, so the iterative-route criterion is just as
    /// deterministic — same data, same fit).
    pub fn fit_reml(&self) -> Result<ResidualCascadeFit, String> {
        let mut best_i = 0usize;
        let mut best_v = f64::NEG_INFINITY;
        let mut best_coeff = Vec::new();
        let mut warm: Option<Vec<f64>> = None;
        let step = (LOG_LAMBDA_HI - LOG_LAMBDA_LO) / (LOG_LAMBDA_GRID - 1) as f64;
        for i in 0..LOG_LAMBDA_GRID {
            let ll = LOG_LAMBDA_LO + step * i as f64;
            let (v, coeff) = self.criterion_with_warm(ll, warm.as_deref())?;
            if v > best_v {
                best_v = v;
                best_i = i;
                best_coeff = coeff.clone();
            }
            warm = Some(coeff);
        }
        let mut lo = LOG_LAMBDA_LO + step * best_i.saturating_sub(1) as f64;
        let mut hi = (LOG_LAMBDA_LO + step * (best_i + 1) as f64).min(LOG_LAMBDA_HI);
        let inv_phi = 0.618_033_988_749_894_9_f64;
        let mut x1 = hi - inv_phi * (hi - lo);
        let mut x2 = lo + inv_phi * (hi - lo);
        let (mut f1, mut c1) = self.criterion_with_warm(x1, Some(&best_coeff))?;
        let (mut f2, mut c2) = self.criterion_with_warm(x2, Some(&c1))?;
        while hi - lo > LOG_LAMBDA_TOL {
            if f1 < f2 {
                lo = x1;
                x1 = x2;
                f1 = f2;
                c1 = c2;
                x2 = lo + inv_phi * (hi - lo);
                (f2, c2) = self.criterion_with_warm(x2, Some(&c1))?;
            } else {
                hi = x2;
                x2 = x1;
                f2 = f1;
                c2 = c1;
                x1 = hi - inv_phi * (hi - lo);
                (f1, c1) = self.criterion_with_warm(x1, Some(&c2))?;
            }
        }
        let warm = if f1 >= f2 { &c1 } else { &c2 };
        self.fit_at_with_warm(0.5 * (lo + hi), None, Some(warm))
    }

    /// Exact upper bound on the penalized-objective decrease available from
    /// appending the candidate level L+1 at this fit's λ:
    /// `‖X₂'W r̂‖² / (λ·d_{L+1})`, where `X₂` holds the candidate level's bump
    /// columns and `r̂` the fit's residuals (see the module header for the
    /// Schur-complement argument). Returns `None` when no further level can be
    /// added: the depth cap `MAX_LEVELS` is reached, the extended net would
    /// exceed `MAX_CENTERS`, or the net is exhausted (no new centers, because
    /// every data point is already a center).
    pub fn next_level_gain_bound(&self, fit: &ResidualCascadeFit) -> Result<Option<f64>, String> {
        let core = &self.core;
        if !Arc::ptr_eq(core, &fit.core) {
            return Err("residual cascade: fit does not belong to this design".into());
        }
        let next_l = core.levels.len();
        if next_l >= MAX_LEVELS {
            return Ok(None);
        }
        let h = core.levels[next_l - 1].h * 0.5;
        let mut net = core.net.clone();
        let candidates = extend_net(&mut net, &core.z, core.dim, h, &core.z_range);
        if candidates.is_empty() || net.len() > MAX_CENTERS {
            return Ok(None);
        }
        let delta = OVERLAP * h;
        let mut grid = HashGrid::new(delta, core.dim);
        for (j, c) in candidates.iter().enumerate() {
            grid.insert(j as u32, c);
        }
        let r = core.residuals(&fit.coeff);
        let mut g = vec![0.0_f64; candidates.len()];
        for (i, zi) in core.z.iter().enumerate() {
            let wr = core.w[i] * r[i];
            grid.for_neighbors(zi, |j| {
                let rad = dist2(zi, &candidates[j as usize], core.dim).sqrt() / delta;
                g[j as usize] += wr * wendland(rad);
            });
        }
        let g2: f64 = g.iter().map(|v| v * v).sum();
        let d_next = level_weight(next_l, core.sobolev_s, core.dim);
        Ok(Some(g2 / (fit.log_lambda.exp() * d_next)))
    }
}
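
// Illustrative sketch (not called by the solver): the two-stage λ search that
// `fit_reml` performs, reduced to a scalar objective. A coarse grid picks the
// best node, then golden-section search maximizes inside the bracket formed by
// that node's two neighbors. The arguments `lo`, `hi`, `grid` and `tol` are
// hypothetical stand-ins for `LOG_LAMBDA_LO`, `LOG_LAMBDA_HI`,
// `LOG_LAMBDA_GRID` and `LOG_LAMBDA_TOL`; the sketch assumes the objective is
// unimodal inside the bracket and omits the warm-started coefficient solves.
#[allow(dead_code)]
mod golden_section_sketch {
    /// Maximize `f` on `[lo, hi]`: `grid` evenly spaced probes, then golden
    /// section on the bracket around the best probe until it is narrower than
    /// `tol`. Returns the bracket midpoint.
    pub fn grid_then_golden<F: Fn(f64) -> f64>(f: F, lo: f64, hi: f64, grid: usize, tol: f64) -> f64 {
        assert!(grid >= 2 && hi > lo && tol > 0.0);
        let step = (hi - lo) / (grid - 1) as f64;
        // Stage 1: best grid node (the first one wins on ties).
        let mut best_i = 0;
        let mut best_v = f64::NEG_INFINITY;
        for i in 0..grid {
            let v = f(lo + step * i as f64);
            if v > best_v {
                best_v = v;
                best_i = i;
            }
        }
        // Stage 2: bracket = neighboring nodes, clamped to [lo, hi].
        let mut a = lo + step * best_i.saturating_sub(1) as f64;
        let mut b = (lo + step * (best_i + 1) as f64).min(hi);
        let inv_phi = (5.0_f64.sqrt() - 1.0) / 2.0; // 1/φ ≈ 0.618
        let mut x1 = b - inv_phi * (b - a);
        let mut x2 = a + inv_phi * (b - a);
        let (mut f1, mut f2) = (f(x1), f(x2));
        while b - a > tol {
            if f1 < f2 {
                // The maximum lies in [x1, b]; the old x2 becomes the new x1.
                a = x1;
                x1 = x2;
                f1 = f2;
                x2 = a + inv_phi * (b - a);
                f2 = f(x2);
            } else {
                // The maximum lies in [a, x2]; the old x1 becomes the new x2.
                b = x2;
                x2 = x1;
                f2 = f1;
                x1 = b - inv_phi * (b - a);
                f1 = f(x1);
            }
        }
        0.5 * (a + b)
    }
}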

/// Prior precision weight of level `l`: `4^{l(s−d/2)}`.
fn level_weight(l: usize, sobolev_s: f64, dim: usize) -> f64 {
    (4.0_f64).powf(l as f64 * (sobolev_s - dim as f64 / 2.0))
}
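
// Illustrative sketch (not called by the solver): the level-l prior precision
// `d_l = 4^{l(s−d/2)}` that `level_weight` computes, and how it enters the
// next-level gain bound `‖g‖² / (λ·d_{L+1})` of `next_level_gain_bound`, where
// `g_j = Σ_i w_i r̂_i φ_j(z_i)` correlates the weighted residuals with each
// candidate bump. The module and function names are hypothetical.
#[allow(dead_code)]
mod level_weight_sketch {
    /// Prior precision of level `l`. Its reciprocal is the prior variance
    /// factor `4^{−l(s−d/2)}` of the level-`l` coefficients, in units of τ².
    pub fn precision(l: usize, s: f64, d: usize) -> f64 {
        4.0_f64.powf(l as f64 * (s - d as f64 / 2.0))
    }

    /// Upper bound on the penalized-objective decrease from appending level
    /// `next_l`, given the candidate correlations `g`.
    pub fn gain_bound(g: &[f64], lambda: f64, next_l: usize, s: f64, d: usize) -> f64 {
        let g2: f64 = g.iter().map(|v| v * v).sum();
        g2 / (lambda * precision(next_l, s, d))
    }
}

// With d = 2 and s = 2 (inside the Wendland window (1, 2.5]), each level is
// penalized 4× more than its parent, so a fixed residual correlation buys a
// 4× smaller bound one level deeper.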

/// Lightweight view used during assembly, before the Core exists: shares the
/// exact basis-row logic with [`Core::basis_row_scaled`] so the assembled CSR
/// and later prediction rows cannot drift apart.
struct CoreScaffold<'a> {
    dim: usize,
    z_range: [f64; 3],
    levels: &'a [Level],
}
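
// Illustrative sketch (not called by the solver): the one-pass CSR assembly
// that `build` performs with the scaffold's rows. Each sparse row is appended
// to (`row_ptr`, `col_idx`, `vals`) while `X'Wy` and `diag(X'WX)` accumulate on
// the fly; the optional dense Gram then scatters each row's weighted outer
// product into the upper triangle, which is valid only because every row is
// sorted by column. The `Csr` type and both functions are hypothetical.
#[allow(dead_code)]
mod csr_assembly_sketch {
    pub struct Csr {
        pub row_ptr: Vec<usize>,
        pub col_idx: Vec<u32>,
        pub vals: Vec<f64>,
        pub rhs: Vec<f64>,       // X'Wy
        pub gram_diag: Vec<f64>, // diag(X'WX)
    }

    /// `rows[i]` is row i as (column, value) pairs sorted by column.
    pub fn assemble(rows: &[Vec<(usize, f64)>], w: &[f64], y: &[f64], m: usize) -> Csr {
        let mut a = Csr {
            row_ptr: vec![0],
            col_idx: Vec::new(),
            vals: Vec::new(),
            rhs: vec![0.0; m],
            gram_diag: vec![0.0; m],
        };
        for (i, row) in rows.iter().enumerate() {
            for &(c, v) in row {
                a.col_idx.push(c as u32);
                a.vals.push(v);
                a.rhs[c] += w[i] * y[i] * v;
                a.gram_diag[c] += w[i] * v * v;
            }
            a.row_ptr.push(a.col_idx.len());
        }
        a
    }

    /// Row-major `m × m` upper triangle of `X'WX` (entries below the diagonal
    /// stay zero).
    pub fn dense_gram_upper(a: &Csr, w: &[f64], m: usize) -> Vec<f64> {
        let mut gram = vec![0.0; m * m];
        for i in 0..w.len() {
            let (lo, hi) = (a.row_ptr[i], a.row_ptr[i + 1]);
            for ea in lo..hi {
                let ca = a.col_idx[ea] as usize;
                let va = w[i] * a.vals[ea];
                // eb ≥ ea and column-sorted rows give col_idx[eb] ≥ ca.
                for eb in ea..hi {
                    gram[ca * m + a.col_idx[eb] as usize] += va * a.vals[eb];
                }
            }
        }
        gram
    }
}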

impl CoreScaffold<'_> {
    fn basis_row(&self, z: &[f64; 3]) -> Vec<(usize, f64)> {
        let mut row = Vec::with_capacity(self.dim + 1 + self.levels.len() * 8);
        // Columns 0..=dim: intercept, then each coordinate mapped through
        // `2·z/range − 1`.
        row.push((0, 1.0));
        for a in 0..self.dim {
            row.push((a + 1, 2.0 * z[a] / self.z_range[a] - 1.0));
        }
        for level in self.levels {
            let start = row.len();
            level.grid.for_neighbors(z, |j| {
                let c = &level.centers[j as usize];
                let r = dist2(z, c, self.dim).sqrt() / level.delta;
                let v = wendland(r);
                if v > 0.0 {
                    row.push((level.col_offset + j as usize, v));
                }
            });
            // Hash-grid visiting order is arbitrary: sort this level's block so
            // the whole row stays column-sorted (the dense Gram scatter in
            // `build` relies on it).
            row[start..].sort_unstable_by_key(|&(col, _)| col);
        }
        row
    }
}
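
// Illustrative sketch (not called by the solver): the row layout produced by
// `CoreScaffold::basis_row`. Column 0 is the intercept, columns 1..=d hold
// `2·z_a/range_a − 1`, and each level contributes the Wendland-(3,1) bumps
// `φ(r) = (1−r)₊⁴(4r+1)` whose support contains the point. A brute-force scan
// over the centers stands in for the hash-grid neighbor query; scanning in
// index order keeps each block sorted without an explicit sort. The `Level`
// type here is a hypothetical simplification of the module's own.
#[allow(dead_code)]
mod basis_row_sketch {
    pub fn wendland(r: f64) -> f64 {
        if r >= 1.0 {
            0.0
        } else {
            let t = 1.0 - r;
            t * t * t * t * (4.0 * r + 1.0)
        }
    }

    pub struct Level {
        pub centers: Vec<Vec<f64>>,
        pub delta: f64, // support radius δ_l
        pub col_offset: usize,
    }

    pub fn basis_row(z: &[f64], range: &[f64], levels: &[Level]) -> Vec<(usize, f64)> {
        let mut row = vec![(0, 1.0)];
        for (a, (&za, &ra)) in z.iter().zip(range).enumerate() {
            row.push((a + 1, 2.0 * za / ra - 1.0));
        }
        for level in levels {
            for (j, c) in level.centers.iter().enumerate() {
                let d2: f64 = z.iter().zip(c).map(|(p, q)| (p - q) * (p - q)).sum();
                let v = wendland(d2.sqrt() / level.delta);
                // Only bumps whose support covers z produce a nonzero.
                if v > 0.0 {
                    row.push((level.col_offset + j, v));
                }
            }
        }
        row
    }
}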

impl ResidualCascadeFit {
    /// Posterior `(mean, variance)` at a raw point: the sparse basis row
    /// dotted with the coefficients, and `σ̂²·x'(X'WX+λD)^{−1}x` through one
    /// solve. That solve is a triangular solve against a cached Cholesky
    /// factor when one exists (the fit's own, or the one `from_state` restores
    /// onto the core), and a certified `solve_coeff` otherwise.
    pub fn predict(&self, x: &[f64]) -> Result<(f64, f64), String> {
        let core = &self.core;
        if x.len() != core.dim || x.iter().any(|v| !v.is_finite()) {
            return Err(format!(
                "residual cascade: prediction point must be {} finite coordinates, got {x:?}",
                core.dim
            ));
        }
        let row = core.basis_row_scaled(&core.scale_point(x));
        let mut mean = 0.0;
        let mut dense_row = vec![0.0_f64; core.m];
        for &(c, v) in &row {
            mean += v * self.coeff[c];
            dense_row[c] += v;
        }
        let lambda = self.log_lambda.exp();
        // Prefer a cached factor: the fit's own, else the one a snapshot
        // restored onto the core (a predict-only core has no training CSR for
        // an iterative solve to use).
        let zsol = if let Some(l) = self.predict_chol.as_ref().or(core.predict_chol.as_ref()) {
            chol_solve(l, core.m, &dense_row)
        } else {
            core.solve_coeff(lambda, &dense_row, None)?.0
        };
        let mut quad = 0.0;
        for (a, b) in dense_row.iter().zip(zsol.iter()) {
            quad += a * b;
        }
        Ok((mean, self.sigma2 * quad))
    }

    /// EXACT posterior coefficient samples by perturb-and-solve:
    /// `c_s = A^{−1}(X'Wy + σ(X'W^{1/2}z₁ + √λ D^{1/2}z₂))` has mean ĉ and
    /// covariance exactly `σ̂²A^{−1}`. Deterministically seeded (every call
    /// restarts from the same seed, so repeated calls return the same draws);
    /// one certified solve per sample (warm-started at the mode). The draws
    /// need the training CSR, which a fit restored by `from_state` does not
    /// carry.
    pub fn sample_coefficients(&self, n_samples: usize) -> Result<Vec<Vec<f64>>, String> {
        let core = &self.core;
        let lambda = self.log_lambda.exp();
        let sigma = self.sigma2.sqrt();
        let sqrt_lambda = lambda.sqrt();
        let n = core.y.len();
        let mut rng = SplitMix64::new(RNG_SEED ^ 0xA11C_E5A_u64);
        let mut samples = Vec::with_capacity(n_samples);
        for _ in 0..n_samples {
            let mut b = core.rhs.clone();
            // σ·X'W^{1/2} z₁: one CSR pass with per-row factor σ·√w_i·z₁_i.
            for i in 0..n {
                let f = sigma * core.w[i].sqrt() * rng.next_normal();
                for e in core.row_ptr[i]..core.row_ptr[i + 1] {
                    b[core.col_idx[e] as usize] += f * core.vals[e];
                }
            }
            // σ·√λ D^{1/2} z₂ on the penalized columns.
            for (bj, &dj) in b.iter_mut().zip(core.pen_diag.iter()) {
                if dj > 0.0 {
                    *bj += sigma * sqrt_lambda * dj.sqrt() * rng.next_normal();
                }
            }
            let (c, _, _) = core.solve_coeff(lambda, &b, Some(&self.coeff))?;
            samples.push(c);
        }
        Ok(samples)
    }

    /// Number of resolution levels in the fitted cascade.
    pub fn num_levels(&self) -> usize {
        self.core.levels.len()
    }

    /// Total coefficient count.
    pub fn num_coeffs(&self) -> usize {
        self.core.m
    }

    /// Total centers across all fitted resolution levels.
    pub fn num_centers(&self) -> usize {
        self.core.m - self.core.nullity()
    }

    /// Snapshot the fit for persistence (#1032). Reuses a cached Cholesky
    /// factor `L` of `A = X'WX + λD` at the fit's λ when one exists and
    /// otherwise assembles it (O(m³), once), then copies the nested geometry
    /// + coefficients, dropping all training rows. The resulting
    /// [`ResidualCascadeState`] is predict-complete: `from_state` replays the
    /// posterior mean+variance bit-for-bit.
    pub fn to_state(&self) -> Result<ResidualCascadeState, String> {
        let core = &self.core;
        let lambda = self.log_lambda.exp();
        let predict_chol = if let Some(l) = &self.predict_chol {
            l.clone()
        } else if let Some(l) = &core.predict_chol {
            l.clone()
        } else {
            core.assemble_predict_factor(lambda)?
        };
        let dim = core.dim;
        let levels = core
            .levels
            .iter()
            .map(|level| {
                let mut centers = Vec::with_capacity(level.centers.len() * dim);
                for c in &level.centers {
                    centers.extend_from_slice(&c[..dim]);
                }
                LevelState {
                    h: level.h,
                    delta: level.delta,
                    weight: level.weight,
                    col_offset: level.col_offset as u64,
                    centers,
                }
            })
            .collect();
        Ok(ResidualCascadeState {
            dim: dim as u64,
            metric: core.metric,
            z_lo: core.z_lo,
            z_range: core.z_range,
            sobolev_s: core.sobolev_s,
            levels,
            m: core.m as u64,
            pen_logdet_const: core.pen_logdet_const,
            coeff: self.coeff.clone(),
            log_lambda: self.log_lambda,
            sigma2: self.sigma2,
            restricted_loglik: self.restricted_loglik,
            rss_pen: self.rss_pen,
            predict_chol,
        })
    }

    /// Rebuild a predict-capable fit from a snapshot (#1032). Validates shape,
    /// finiteness, the Sobolev/Wendland window, strictly-positive level weights
    /// and box ranges, the column accounting (`m = dim+1 + Σ centers`, matching
    /// `col_offset`s), positive σ², and that `predict_chol` is a valid `m × m`
    /// lower factor (positive pivots) — so a corrupt payload fails here, not in
    /// a later `predict`. The restored `Core` has empty training CSR and
    /// `predict_chol = Some(L)`; its `predict` reads only geometry (mean) and
    /// the factor (variance), replaying both exactly.
    pub fn from_state(state: &ResidualCascadeState) -> Result<Self, String> {
        let dim = state.dim as usize;
        if !(dim == 2 || dim == 3) {
            return Err(format!(
                "residual cascade state: dim must be 2 or 3, got {dim}"
            ));
        }
        if !(state.sobolev_s > dim as f64 / 2.0 && state.sobolev_s <= (dim as f64 + 3.0) / 2.0) {
            return Err(format!(
                "residual cascade state: sobolev_s {} outside the Wendland window ({}, {}]",
                state.sobolev_s,
                dim as f64 / 2.0,
                (dim as f64 + 3.0) / 2.0
            ));
        }
        for a in 0..dim {
            if !(state.metric[a].is_finite() && state.metric[a] > 0.0) {
                return Err(format!(
                    "residual cascade state: metric axis {a} must be finite positive, got {}",
                    state.metric[a]
                ));
            }
            if !(state.z_range[a].is_finite()
                && state.z_range[a] > 0.0
                && state.z_lo[a].is_finite())
            {
                return Err(format!(
                    "residual cascade state: degenerate box on axis {a} (lo={}, range={})",
                    state.z_lo[a], state.z_range[a]
                ));
            }
        }
        let m = state.m as usize;
        // Check `m` against a payload-backed length before any allocation is
        // sized by it, so a corrupt `m` is rejected as an `Err` here instead of
        // failing inside the `pen_diag` allocation below.
        if state.coeff.len() != m {
            return Err(format!(
                "residual cascade state: coeff length {} ≠ m {m}",
                state.coeff.len()
            ));
        }
        let mut metric3 = [1.0_f64; 3];
        metric3[..dim].copy_from_slice(&state.metric[..dim]);
        let mut z_lo = [0.0_f64; 3];
        let mut z_range = [1.0_f64; 3];
        z_lo[..dim].copy_from_slice(&state.z_lo[..dim]);
        z_range[..dim].copy_from_slice(&state.z_range[..dim]);

        // Rebuild the levels and their lookup grids from the flattened centers,
        // checking the column accounting matches the polynomial layer + blocks.
        let mut levels = Vec::with_capacity(state.levels.len());
        let mut net: Vec<[f64; 3]> = Vec::new();
        let mut pen_diag = vec![0.0_f64; m];
        let mut expected_offset = dim + 1;
        for (li, ls) in state.levels.iter().enumerate() {
            if !(ls.h.is_finite() && ls.h > 0.0 && ls.delta.is_finite() && ls.delta > 0.0) {
                return Err(format!(
                    "residual cascade state: level {li} has non-positive h/delta ({}, {})",
                    ls.h, ls.delta
                ));
            }
            if !(ls.weight.is_finite() && ls.weight > 0.0) {
                return Err(format!(
                    "residual cascade state: level {li} has non-positive prior weight {}",
                    ls.weight
                ));
            }
            if ls.centers.len() % dim != 0 {
                return Err(format!(
                    "residual cascade state: level {li} centers length {} not a multiple of dim {dim}",
                    ls.centers.len()
                ));
            }
            let n_centers = ls.centers.len() / dim;
            let col_offset = ls.col_offset as usize;
            if col_offset != expected_offset {
                return Err(format!(
                    "residual cascade state: level {li} col_offset {col_offset} ≠ expected {expected_offset}"
                ));
            }
            let mut grid = HashGrid::new(ls.delta, dim);
            let mut centers = Vec::with_capacity(n_centers);
            for j in 0..n_centers {
                let mut c = [0.0_f64; 3];
                for a in 0..dim {
                    let v = ls.centers[j * dim + a];
                    if !v.is_finite() {
                        return Err(format!(
                            "residual cascade state: non-finite center coordinate at level {li}, center {j}"
                        ));
                    }
                    c[a] = v;
                }
                grid.insert(j as u32, &c);
                centers.push(c);
                net.push(c);
                let col = col_offset + j;
                if col >= m {
                    return Err(format!(
                        "residual cascade state: level {li} column {col} exceeds m {m}"
                    ));
                }
                pen_diag[col] = ls.weight;
            }
            expected_offset = col_offset + n_centers;
            levels.push(Level {
                h: ls.h,
                delta: ls.delta,
                weight: ls.weight,
                centers,
                col_offset,
                grid,
            });
        }
        if expected_offset != m {
            return Err(format!(
                "residual cascade state: column accounting mismatch (dim+1+Σcenters = {expected_offset} ≠ m {m})"
            ));
        }
        if state.predict_chol.len() != m * m {
            return Err(format!(
                "residual cascade state: predict_chol must be m×m = {m}² = {}, got {}",
                m * m,
                state.predict_chol.len()
            ));
        }
        for (i, v) in state
            .coeff
            .iter()
            .chain(state.predict_chol.iter())
            .enumerate()
        {
            if !v.is_finite() {
                return Err(format!("residual cascade state: non-finite entry at {i}"));
            }
        }
        for g in 0..m {
            let piv = state.predict_chol[g * m + g];
            if !(piv.is_finite() && piv > 0.0) {
                return Err(format!(
                    "residual cascade state: non-positive Cholesky pivot {piv} at index {g}"
                ));
            }
        }
        if !(state.log_lambda.is_finite()
            && state.sigma2.is_finite()
            && state.sigma2 > 0.0
            && state.restricted_loglik.is_finite()
            && state.rss_pen.is_finite())
        {
            return Err(format!(
                "residual cascade state: invalid scalars (log_lambda={}, sigma2={}, restricted_loglik={}, rss_pen={})",
                state.log_lambda, state.sigma2, state.restricted_loglik, state.rss_pen
            ));
        }
        let core = Core {
            dim,
            metric: metric3,
            z_lo,
            z_range,
            sobolev_s: state.sobolev_s,
            levels,
            net,
            m,
            row_ptr: Vec::new(),
            col_idx: Vec::new(),
            vals: Vec::new(),
            w: Vec::new(),
            y: Vec::new(),
            z: Vec::new(),
            rhs: Vec::new(),
            ytwy: 0.0,
            gram_diag: Vec::new(),
            pen_diag,
            pen_logdet_const: state.pen_logdet_const,
            dense_gram: None,
            predict_chol: Some(state.predict_chol.clone()),
        };
        Ok(ResidualCascadeFit {
            core: Arc::new(core),
            predict_chol: None,
            coeff: state.coeff.clone(),
            log_lambda: state.log_lambda,
            sigma2: state.sigma2,
            restricted_loglik: state.restricted_loglik,
            rss_pen: state.rss_pen,
            certificate: CascadeCertificate {
                solve_rel_residual: 0.0,
                solve_iters: 0,
                logdet_method: LogdetMethod::DenseExact,
            },
            refinement: None,
        })
    }
}
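
// Illustrative sketch (not called by the solver): how a lower Cholesky factor
// `L` (row-major `m × m`, `A = L·Lᵀ` with `A = X'WX + λD`) yields the posterior
// variance `σ²·xᵀA⁻¹x` that `predict` reports, via one forward and one back
// substitution, and why `from_state` can reject a corrupt factor by its
// pivots alone. `chol_lower` and `solve_llt` are hypothetical stand-ins for
// the module's own `assemble_predict_factor` / `chol_solve`, whose internals
// are not shown here and may differ.
#[allow(dead_code)]
mod cholesky_variance_sketch {
    /// Row-major lower Cholesky factor of a symmetric matrix; `None` on a
    /// non-positive pivot, i.e. when the matrix is not positive definite.
    pub fn chol_lower(a: &[f64], m: usize) -> Option<Vec<f64>> {
        let mut l = vec![0.0; m * m];
        for i in 0..m {
            for j in 0..=i {
                let mut s = a[i * m + j];
                for k in 0..j {
                    s -= l[i * m + k] * l[j * m + k];
                }
                if i == j {
                    if !(s > 0.0) {
                        return None;
                    }
                    l[i * m + i] = s.sqrt();
                } else {
                    l[i * m + j] = s / l[j * m + j];
                }
            }
        }
        Some(l)
    }

    /// Solve `L·Lᵀ·z = b`.
    pub fn solve_llt(l: &[f64], m: usize, b: &[f64]) -> Vec<f64> {
        let mut z = b.to_vec();
        // Forward substitution: L·u = b (u overwrites z).
        for i in 0..m {
            let mut s = z[i];
            for k in 0..i {
                s -= l[i * m + k] * z[k];
            }
            z[i] = s / l[i * m + i];
        }
        // Back substitution: Lᵀ·z = u, reading Lᵀ[i][k] as L[k][i].
        for i in (0..m).rev() {
            let mut s = z[i];
            for k in i + 1..m {
                s -= l[k * m + i] * z[k];
            }
            z[i] = s / l[i * m + i];
        }
        z
    }

    /// Posterior variance `σ²·xᵀA⁻¹x` for a dense basis row `x`.
    pub fn variance(l: &[f64], m: usize, x: &[f64], sigma2: f64) -> f64 {
        let z = solve_llt(l, m, x);
        sigma2 * x.iter().zip(&z).map(|(a, b)| a * b).sum::<f64>()
    }
}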
2165
2166/// Fit the full magic-default cascade: start at [`INITIAL_LEVELS`], REML-fit,
2167/// and refine (add a level, refit, re-select λ) until the exact next-level
2168/// gain bound certifies that one more level cannot move the penalized
2169/// objective by more than [`REFINE_TOL`] of the penalized residual — or the
2170/// net/cap is exhausted. Returns the certified fit.
2171pub fn fit_residual_cascade(
2172    xs: &[&[f64]],
2173    y: &[f64],
2174    w: &[f64],
2175    metric: &[f64],
2176    sobolev_s: f64,
2177) -> Result<ResidualCascadeFit, String> {
2178    let mut levels = INITIAL_LEVELS;
2179    loop {
2180        let design = ResidualCascadeDesign::build(xs, y, w, metric, sobolev_s, levels)?;
        // Quasi-uniformity guard (issue #1032, caveat 2): if the metric has
        // collapsed the cloud onto a near-degenerate sheet in scaled
        // coordinates, the BPX iteration bound no longer holds. Refuse the
        // iterative solve up front with an explicit error so the auto-route
        // falls back to the dense kernel BEFORE paying for a CG run with no
        // iteration guarantee, rather than grinding to CG_MAX_ITERS. The
        // guard runs only on the first pass (`levels == INITIAL_LEVELS`):
        // refinement adds finer nets to the SAME scaled cloud, so the aspect
        // ratio is invariant under added levels.
        if levels == INITIAL_LEVELS && !design.quasi_uniformity_certified() {
            return Err(format!(
                "residual cascade: metric-scaled aspect ratio {:.3e} exceeds the \
                 quasi-uniformity ceiling {QUASI_UNIFORMITY_MAX_ASPECT:.0e}; the BPX \
                 iteration bound is not trustworthy on this (near-degenerate) metric — \
                 fall back to the dense kernel path",
                design.metric_scaled_aspect_ratio()
            ));
        }
        let mut fit = design.fit_reml()?;
        // The realized CG iteration count at this cascade depth is the runtime
        // tell of the BPX n-independence bound (issue #1032 caveat): a count
        // creeping toward CG_MAX_ITERS means the quasi-uniformity guard's
        // static aspect-ratio check was too lenient for this cloud. The count
        // is exposed STRUCTURALLY rather than over stderr. The iteration count
        // and backward error ride on `fit.certificate` (`solve_iters`, which is
        // 0 on the dense route and the PCG count on the iterative route, and
        // `solve_rel_residual`), so a caller that wants to watch the bound
        // reads them off the returned fit instead of scraping log lines. Only
        // the final depth's fit is returned; the certificates of intermediate
        // depths are dropped when the loop refines. (A library solve never
        // writes to stderr.)
        let gain = design.next_level_gain_bound(&fit)?;
        let tolerance = REFINE_TOL * fit.rss_pen;
        match gain {
            // No next-level bound is available: stop and flag the fit as exhausted.
            None => {
                fit.refinement = Some(RefinementCertificate {
                    next_level_gain_bound: 0.0,
                    tolerance,
                    exhausted: true,
                });
                return Ok(fit);
            }
            // Certified (bound within tolerance) or the level cap is reached;
            // `exhausted` records a stop at the cap that was not certified.
            Some(bound) if bound <= tolerance || levels >= MAX_LEVELS => {
                fit.refinement = Some(RefinementCertificate {
                    next_level_gain_bound: bound,
                    tolerance,
                    exhausted: bound > tolerance,
                });
                return Ok(fit);
            }
            // Bound exceeds tolerance and the cap allows another level: rebuild deeper.
            Some(_) => {
                levels += 1;
            }
        }
    }
}