Skip to main content

gam_sae/
encode.rs

1//! Kantorovich-certified encode atlas (issue #1010).
2//!
3//! Encoding a row `x ∈ ℝᵖ` against a FROZEN SAE dictionary is, per atom `k`,
4//! the coordinate-only Newton problem
5//!
6//! ```text
7//! min_t  f_k(t) = ½‖x − z_k · B_kᵀ Φ_k(t)‖² + prior_k(t),
8//! ```
9//!
10//! with the amplitude `z_k` and decoder block `B_k` held fixed (the encode
11//! freezes the dictionary; only the latent coordinate `t` moves). Newton on
12//! `F(t) = ∇f_k(t)` converges quadratically from a start `t₀` into the unique
13//! root in a certified ball whenever the **Newton–Kantorovich** quantity
14//!
15//! ```text
16//! h = β · η · L ≤ ½,    β = ‖F'(t₀)⁻¹‖,   η = ‖F'(t₀)⁻¹ F(t₀)‖,
17//! ```
18//!
19//! where `L` is a Lipschitz constant of `F'` (the Hessian of `f_k`) on a region
20//! containing the Newton iterates. `h` is CHECKABLE per row in `O(q³)`
21//! (`q = latent_dim`, tiny), so each fast-path encode carries its own
22//! exactness certificate.
23//!
24//! ## The closed-form Hessian-Lipschitz constant `L`
25//!
26//! Write `m(t) = z·BᵀΦ(t) ∈ ℝᵖ` (the reconstruction) and `r(t) = m(t) − x`.
27//! Then `f = ½‖r‖² + prior` and, differentiating three times,
28//!
29//! ```text
30//! ∇³f = 3·sym(J_mᵀ : ∇²m) + ⟨r, ∇³m⟩ + ∇³prior,
31//! ```
32//!
33//! so an operator-norm bound on the chart is
34//!
35//! ```text
36//! L ≤ 3·‖J_m‖·‖∇²m‖ + ‖r‖·‖∇³m‖ + L_prior,
37//! ```
38//!
39//! with `‖∂^g m‖ ≤ |z|·(Σ_m ‖B_{m,:}‖)·B_g`, where `B_g = sup_chart max_m
40//! ‖∂^g Φ_m‖` is the per-column jet sup of the basis family — closed form per
41//! family ([`BasisHessianLipschitz`]). `‖r‖` is bounded by `‖x‖ +
42//! |z|·(Σ_m‖B_{m,:}‖)·B_0`. The ARD/von-Mises prior `L_prior` is a closed-form
43//! constant from the prior strength. Every bound is conservative (an
44//! over-estimate of `L` only SHRINKS the certified radius — it can never
45//! certify a row that does not converge).
46//!
47//! ## Pipeline
48//!
49//! 1. **Offline, per atom** ([`EncodeAtlas::build`]): chart centers `t_c` on the
50//!    atom's coordinate grid (the SHAPE_BAND grid idiom), each with a certified
51//!    Newton radius `R_c` solved from the Kantorovich inequality at the
52//!    worst-case in-chart start.
53//! 2. **Online, per row** ([`EncodeAtlas::certified_encode_row`]): route to the
54//!    nearest chart, start from its distilled IFT predictor, take one or two
55//!    Newton steps, then the `h ≤ ½` check AT the start point is the per-row
56//!    certificate.
57//! 3. **Uncertified tail**: rows whose start fails `h ≤ ½` are FLAGGED (counted
58//!    in [`EncodeResult::encode_uncertified_count`]) and must be routed by the
59//!    caller to the existing exact multi-start solve. No approximation enters
60//!    silently.
61
62use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
63
64use gam_linalg::faer_ndarray::FaerEigh;
65use crate::candidate_index::{
66    AtomFrameSketch, SaeCandidateIndex, auto_candidate_budget,
67};
68use crate::manifold::{
69    AffineCoordinateEvaluator, CylinderHarmonicEvaluator, DuchonCoordinateEvaluator,
70    EuclideanPatchEvaluator, PeriodicHarmonicEvaluator, SaeBasisEvaluator, SaeManifoldAtom,
71    SphereChartEvaluator, TorusHarmonicEvaluator,
72};
73
74use faer::Side;
75
/// Upper limit on the Newton–Kantorovich quantity `h = β·η·L` for a start to
/// count as certified. A start with `h ≤ ½` is guaranteed quadratic Newton
/// convergence into the unique root of the certified ball. Any larger `h` leaves
/// the start uncertified.
pub const KANTOROVICH_THRESHOLD: f64 = 0.5;
80
/// Minimum batch size (in rows) for the certified-encode batch entry points
/// (`certified_encode_batch` / `certified_encode_with_index`) to fan rows out
/// over rayon.
///
/// Smaller batches run sequentially because per-row chart routing plus Newton
/// refinement is too cheap to amortize the fan-out overhead. The value is kept
/// at the same order of magnitude as the arrow-Schur
/// `SCHUR_MATVEC_PARALLEL_ROW_MIN` gate, so short batches nested inside an
/// outer atom-level fan-out do not spawn a second level of parallelism.
pub(crate) const ENCODE_BATCH_PARALLEL_ROW_MIN: usize = 256;
88
/// Fit-quality floor on the frame alignment `‖Uₖᵀd‖/‖d‖ ∈ [0,1]` of the routed
/// atom. An index-routed encode is attempted only when the alignment reaches
/// this value (#1026, corrected by the #1777 exact-routing path).
///
/// This is NOT a routing-correctness gate. Routing is exact:
/// [`SaeCandidateIndex::route_exact`] returns the GLOBAL argmax of the routing
/// score, via the universal-bound LSH fast path or otherwise a full-scan
/// fallback. No better atom can be missed, so recall needs no patching here.
///
/// The remaining question is whether even the globally-best atom fits the row.
/// A finite alignment below this floor means no dictionary atom describes the
/// row well. Such a row is flagged and sent to the exact multi-start fallback
/// instead of being encoded against an atom it barely belongs to.
///
/// Historically this comparison also acted as a recall proxy. `route_exact`
/// now guarantees recall, so only the fit-quality role remains.
pub(crate) const CANDIDATE_ROUTING_MIN_ALIGNMENT: f64 = 0.5;
105
/// How many of the nearest charts the CERTIFIED encode refines in. The result
/// is the certified candidate with the lowest reconstruction error.
///
/// Using only the single nearest chart is unsound where the decoded manifold
/// folds back near itself. In that case charts from both competing basins
/// reconstruct close to the fold and rank among the nearest by ambient
/// distance. Refining a handful of them captures the global basin. For a
/// unimodal atom every candidate converges to the same root, so `K > 1` leaves
/// the result unchanged.
pub(crate) const CERTIFIED_ROUTING_TOPK: usize = 4;
113
/// Relative step-length floor that ends Newton refinement early. Refinement
/// stops once a step satisfies `‖δ‖ < EPS · (1 + ‖t‖)`.
///
/// At that point the iterate sits on the certified root to f64 resolution.
/// Further steps cannot move `t` meaningfully and only re-accumulate round-off.
/// Stopping early is therefore strictly more accurate than spending a fixed
/// step budget on a well-conditioned quadratic tail. It also skips that tail's
/// `evaluate` + `second_jet` calls, the dominant per-row encode cost.
///
/// The batched and per-row encodes apply the same rule, which keeps them
/// bit-identical.
pub(crate) const NEWTON_REFINE_CONVERGED_EPS: f64 = 1e-12;
123
/// Noise-floor short-circuit for top-K certified routing.
///
/// The reconstruction error `‖x − z·m(t)‖` is bounded below by 0. A certified
/// candidate whose residual is already `≤ this · (1 + ‖x‖)` is therefore optimal
/// to within that floor: no other chart can improve the residual by more than
/// `this · (1 + ‖x‖)`. Refinement of the remaining candidates is skipped once
/// such a candidate is found.
///
/// This is NOT a proof of strict global optimality. Another chart could still
/// reach a residual below the floor. The difference is at the ambient noise
/// level, though. A genuine second basin for the same target reconstructs the
/// same point, so returning the first such candidate is a valid encode.
pub(crate) const CERTIFIED_GLOBAL_MIN_RECON_FLOOR: f64 = 1.0e-11;
132
133/// A chart region on an atom's latent coordinate: a center `t_c` plus a
134/// certified in-chart radius. Over the ball `‖t − t_c‖ ≤ radius` the jet sup
135/// bounds returned by [`BasisHessianLipschitz`] hold, so the Kantorovich
136/// constant `L` computed from them is valid for any start in the ball.
137///
138/// For radial (Duchon) families the chart also carries the minimum kernel-center
139/// distance `exclusion_r_min` (a lower bound on `‖t − c_k‖` over the chart) that
140/// bounds the otherwise-singular `1/r` radial tails (issue #1010).
141#[derive(Debug, Clone)]
142pub struct ChartRegion {
143    /// Chart center coordinate `t_c` (length = latent_dim).
144    pub center: Array1<f64>,
145    /// In-chart radius in the coordinate metric.
146    pub radius: f64,
147    /// For radial (Duchon) families: a lower bound on `‖t − c_k‖` over the
148    /// chart, across every kernel center `c_k`. `None` for non-radial families.
149    pub exclusion_r_min: Option<f64>,
150    /// For radial (Duchon) families: an upper bound on `‖t − c_k‖` over the
151    /// chart, across every kernel center `c_k`. `None` for non-radial families.
152    pub radial_r_max: Option<f64>,
153}
154
155impl ChartRegion {
156    pub fn new(center: Array1<f64>, radius: f64) -> Self {
157        Self {
158            center,
159            radius,
160            exclusion_r_min: None,
161            radial_r_max: None,
162        }
163    }
164
165    pub fn with_radial_bounds(mut self, r_min: f64, r_max: f64) -> Self {
166        self.exclusion_r_min = Some(r_min);
167        self.radial_r_max = Some(r_max);
168        self
169    }
170
171    /// A jet-sup certificate is only meaningful over a genuine region. Even
172    /// families whose bounds are manifold-global constants (the sup over any
173    /// chart equals the global sup) must refuse a malformed chart rather than
174    /// certify garbage geometry.
175    pub(crate) fn assert_valid(&self) {
176        assert!(
177            self.radius.is_finite()
178                && self.radius >= 0.0
179                && self.center.iter().all(|c| c.is_finite()),
180            "ChartRegion must have a finite center and a finite non-negative radius"
181        );
182    }
183}
184
185/// Per-column sup-norm bounds on the first three coordinate jets of a basis
186/// family `Φ(t)`, valid over a stated [`ChartRegion`] (issue #1010). These are
187/// the analytic ingredients of the Hessian-Lipschitz constant `L` — see the
188/// module docs for the assembly. `value_sup` bounds `max_m |Φ_m|`,
189/// `jacobian_sup`/`hessian_sup`/`third_sup` bound `max_m ‖∂^g Φ_m‖`.
190pub trait BasisHessianLipschitz {
191    fn value_sup(&self, chart: &ChartRegion) -> f64;
192    fn jacobian_sup(&self, chart: &ChartRegion) -> f64;
193    fn hessian_sup(&self, chart: &ChartRegion) -> f64;
194    fn third_sup(&self, chart: &ChartRegion) -> f64;
195}
196
/// Sup over the circle of the `g`-th derivative of any single column of a
/// `num_basis`-wide Fourier basis `[1, sin(2π·1·t), cos(2π·1·t), sin(2π·2·t), …]`.
///
/// The result is `(2π·H)^g`, where `H = ⌊num_basis / 2⌋` is the top harmonic.
/// Column `j ≥ 1` carries harmonic `⌈j / 2⌉`, so the last column carries
/// `⌈(num_basis − 1) / 2⌉ = ⌊num_basis / 2⌋`. For odd widths this equals
/// `(num_basis − 1) / 2`. For even widths the trailing unpaired `sin` column is
/// one harmonic higher. The older `(num_basis − 1) / 2` formula missed that
/// column and returned, e.g., a Jacobian sup of `0` for `[1, sin 2πt]`.
///
/// The constant column contributes `0` for `g ≥ 1`, so the top harmonic
/// dominates. The trig magnitudes are `≤ 1` everywhere, so the bound is global
/// and independent of the chart. `order = 0` yields `1`.
pub(crate) fn harmonic_jet_sup(num_basis: usize, order: u32) -> f64 {
    let top_harmonic = num_basis / 2;
    let omega = std::f64::consts::TAU * top_harmonic as f64;
    omega.powi(order as i32)
}
207
208impl BasisHessianLipschitz for PeriodicHarmonicEvaluator {
209    fn value_sup(&self, chart: &ChartRegion) -> f64 {
210        chart.assert_valid();
211        1.0
212    }
213    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
214        chart.assert_valid();
215        harmonic_jet_sup(self.num_basis, 1)
216    }
217    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
218        chart.assert_valid();
219        harmonic_jet_sup(self.num_basis, 2)
220    }
221    fn third_sup(&self, chart: &ChartRegion) -> f64 {
222        chart.assert_valid();
223        harmonic_jet_sup(self.num_basis, 3)
224    }
225}
226
227impl BasisHessianLipschitz for TorusHarmonicEvaluator {
228    /// Tensor product of per-axis circle harmonics. A torus basis column is a
229    /// product of single-axis harmonics, each bounded as in the circle case.
230    /// The `g`-th coordinate jet routes `g` derivative operators across the
231    /// `latent_dim` factors (Leibniz); each routing contributes a product of
232    /// per-axis derivative magnitudes. A per-column sup is therefore bounded by
233    /// the top single-axis frequency to the `g`-th power times the number of
234    /// such routings (`latent_dim^g`, the count of operator-to-axis maps).
235    fn value_sup(&self, chart: &ChartRegion) -> f64 {
236        chart.assert_valid();
237        1.0
238    }
239    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
240        chart.assert_valid();
241        torus_jet_sup(self.num_harmonics, self.latent_dim, 1)
242    }
243    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
244        chart.assert_valid();
245        torus_jet_sup(self.num_harmonics, self.latent_dim, 2)
246    }
247    fn third_sup(&self, chart: &ChartRegion) -> f64 {
248        chart.assert_valid();
249        torus_jet_sup(self.num_harmonics, self.latent_dim, 3)
250    }
251}
252
/// Per-column order-`g` jet sup of the torus harmonic basis:
/// `(2π·H)^g · latent_dim^g`, with `H = num_harmonics` the top per-axis
/// frequency.
///
/// The `latent_dim^g` factor over-counts the Leibniz distributions of `g`
/// derivative operators across the product factors. The magnitude of each
/// distribution is `∏_axis (2πH)^{#ops on axis} = (2πH)^g`, so the bound is
/// conservative.
pub(crate) fn torus_jet_sup(num_harmonics: usize, latent_dim: usize, order: u32) -> f64 {
    let exponent = order as i32;
    let frequency = std::f64::consts::TAU * num_harmonics as f64;
    let routings = latent_dim as f64;
    frequency.powi(exponent) * routings.powi(exponent)
}
262
263impl BasisHessianLipschitz for SphereChartEvaluator {
264    /// The 7-column lat/lon chart `[1, x, y, z, xy, yz, xz]` with
265    /// `x = cos(lat)cos(lon)`, `y = cos(lat)sin(lon)`, `z = sin(lat)`. Each of
266    /// `x, y, z` is a product of two unit-frequency trig factors, so its `g`-th
267    /// coordinate jet is a sum of `2^g` products of `{sin,cos}` (each `≤ 1`):
268    /// magnitude `≤ 2^g` for `g ≥ 1`, `≤ 1` for `g = 0`. The bilinear columns
269    /// `xy, yz, xz` are products of two such coordinates; by Leibniz over the
270    /// product, their `g`-th jet is bounded by `Σ_{i=0}^{g} C(g,i)·(2^i)·(2^{g−i})
271    /// = (2+2)^g = 4^g` (using `‖∂^i u‖ ≤ 2^i`, `|u| ≤ 1`). The bilinear columns
272    /// dominate, so the per-column sup is `4^g` (`g ≥ 1`). Bounds are global
273    /// constants — the chart box `lat ∈ [-π/2, π/2]` does not enlarge them.
274    fn value_sup(&self, chart: &ChartRegion) -> f64 {
275        chart.assert_valid();
276        1.0
277    }
278    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
279        chart.assert_valid();
280        4.0
281    }
282    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
283        chart.assert_valid();
284        16.0
285    }
286    fn third_sup(&self, chart: &ChartRegion) -> f64 {
287        chart.assert_valid();
288        64.0
289    }
290}
291
292impl BasisHessianLipschitz for AffineCoordinateEvaluator {
293    /// The affine basis `[1, t₁, …, t_d]` is degree ≤ 1: its first jet has unit
294    /// columns, and all second and third jets vanish. The value sup is
295    /// `max(1, ‖t‖)` over the chart, bounded by `1 + ‖t_c‖ + radius`.
296    fn value_sup(&self, chart: &ChartRegion) -> f64 {
297        let center_norm = chart.center.dot(&chart.center).sqrt();
298        1.0 + center_norm + chart.radius
299    }
300    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
301        chart.assert_valid();
302        1.0
303    }
304    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
305        chart.assert_valid();
306        0.0
307    }
308    fn third_sup(&self, chart: &ChartRegion) -> f64 {
309        chart.assert_valid();
310        0.0
311    }
312}
313
314impl BasisHessianLipschitz for EuclideanPatchEvaluator {
315    /// Monomials of total degree ≤ `max_degree` in `t ∈ ℝ^d`. Over the ball of
316    /// radius `R` about `t_c`, each coordinate is bounded by `ρ = ‖t_c‖∞ + R`.
317    /// A monomial `t^α` with `|α| = q` has `g`-th partials bounded (crudely) by
318    /// the descending-factorial coefficient `q·(q−1)···(q−g+1) ≤ q^g` times
319    /// `ρ^{max(q−g,0)}`, and there are at most `d^g` partial routings, so the
320    /// per-column `g`-th jet sup is `≤ d^g · D^g · ρ^{max(D−g,0)}` with
321    /// `D = max_degree`. Conservative; D is small for patch evaluators.
322    fn value_sup(&self, chart: &ChartRegion) -> f64 {
323        let rho = patch_rho(chart);
324        let d = self.max_degree as i32;
325        rho.powi(d).max(1.0)
326    }
327    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
328        patch_jet_sup(self.latent_dim, self.max_degree, chart, 1)
329    }
330    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
331        patch_jet_sup(self.latent_dim, self.max_degree, chart, 2)
332    }
333    fn third_sup(&self, chart: &ChartRegion) -> f64 {
334        patch_jet_sup(self.latent_dim, self.max_degree, chart, 3)
335    }
336}
337
338impl BasisHessianLipschitz for CylinderHarmonicEvaluator {
339    /// Cylinder `S¹ × ℝ` product basis `Φ_{c,l} = c(t₀)·l(t₁)`, the circle
340    /// (periodic harmonic) factor on axis 0 crossed with the monomial line
341    /// factor on axis 1. Because the two factors depend on disjoint coordinates,
342    /// the order-`g` coordinate jet in any cell is exactly
343    /// `c^{(k₀)}(t₀)·l^{(k₁)}(t₁)` with `k₀ + k₁ = g`, so the per-column sup is
344    /// the max over the split `k₀ + k₁ = g` of the product of the two per-axis
345    /// per-order sups: the circle factor contributes `1` at order 0 and
346    /// `(2π·H)^{k₀}` at order `k₀ ≥ 1` (trig magnitudes `≤ 1`); the line factor
347    /// contributes the monomial-patch sup `D^{k₁}·ρ^{max(D−k₁,0)}` (`D = line
348    /// degree`, `ρ = ‖t_c‖∞ + radius`). Bounds are global in the periodic axis
349    /// and chart-local in the line axis.
350    fn value_sup(&self, chart: &ChartRegion) -> f64 {
351        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 0)
352    }
353    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
354        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 1)
355    }
356    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
357        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 2)
358    }
359    fn third_sup(&self, chart: &ChartRegion) -> f64 {
360        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 3)
361    }
362}
363
364/// Per-column order-`g` jet sup of the cylinder product basis: the max over
365/// `k₀ + k₁ = g` of `circle_axis_sup(k₀) · line_axis_sup(k₁)`, where the circle
366/// axis sup is `(2π·H)^{k₀}` (`1` at `k₀ = 0`) and the line axis sup is the
367/// monomial-patch bound `D^{k₁}·ρ^{max(D−k₁,0)}` (`1` at `k₁ = 0`). See the
368/// [`CylinderHarmonicEvaluator`] doc comment for the derivation.
369pub(crate) fn cylinder_jet_sup(
370    circle_harmonics: usize,
371    line_degree: usize,
372    chart: &ChartRegion,
373    order: u32,
374) -> f64 {
375    let omega = std::f64::consts::TAU * circle_harmonics as f64;
376    let big_d = line_degree as f64;
377    let rho = patch_rho(chart);
378    let mut best = 0.0_f64;
379    for k0 in 0..=order {
380        let k1 = order - k0;
381        let circle = if k0 == 0 { 1.0 } else { omega.powi(k0 as i32) };
382        let line = if k1 == 0 {
383            rho.powi(line_degree as i32).max(1.0)
384        } else {
385            let residual = line_degree.saturating_sub(k1 as usize) as i32;
386            // `.max(1.0)` as in `patch_jet_sup`: for ρ < 1 a lower-degree line
387            // monomial dominates the k1-th derivative, so the bare `ρ^residual`
388            // underestimates the line-factor sup. The value case (k1==0) already
389            // clamps; this completes it for the derivative orders.
390            big_d.powi(k1 as i32) * rho.powi(residual).max(1.0)
391        };
392        best = best.max(circle * line);
393    }
394    best
395}
396
397/// Sup-norm radius `ρ = ‖t_c‖∞ + radius` of the chart (the coordinate magnitude
398/// bound used by the monomial-patch jet bounds).
399pub(crate) fn patch_rho(chart: &ChartRegion) -> f64 {
400    let center_inf = chart
401        .center
402        .iter()
403        .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
404    center_inf + chart.radius
405}
406
407/// Per-column `g`-th jet sup for a monomial patch of max degree `D` in `d`
408/// coordinates over the chart: `d^g · D^g · ρ^{max(D−g,0)}` (see the
409/// [`EuclideanPatchEvaluator`] doc comment for the derivation).
410pub(crate) fn patch_jet_sup(
411    latent_dim: usize,
412    max_degree: usize,
413    chart: &ChartRegion,
414    order: u32,
415) -> f64 {
416    let d = latent_dim as f64;
417    let big_d = max_degree as f64;
418    let rho = patch_rho(chart);
419    let residual_degree = max_degree.saturating_sub(order as usize) as i32;
420    // `.max(1.0)`: for ρ < 1 (small charts near the origin) the g-th jet sup is NOT
421    // dominated by the max-degree monomial `t^D` (whose g-th derivative ~ ρ^{D-g}
422    // shrinks with ρ) but by a LOWER-degree monomial whose g-th derivative is a
423    // larger constant — e.g. {1,t,t²}'s jacobian sup is the linear term's constant
424    // `1`, which exceeds `2ρ` when ρ < ½. Without the clamp the bound underestimates
425    // the true sup (numerically: D=3, ρ=0.1, g=1 → formula 0.03 vs true 1.0), which
426    // would make the certificate's Lipschitz `L` too small → a FALSE certificate.
427    // `D^g · max(ρ^{D-g}, 1)` upper-bounds `max_{q∈[g,D]} (q!/(q-g)!)·ρ^{q-g}` for
428    // all ρ (the `q=g` term gives `g! ≤ D^g`, the `q=D` term gives `≤ D^g·ρ^{D-g}`).
429    d.powi(order as i32) * big_d.powi(order as i32) * rho.powi(residual_degree).max(1.0)
430}
431
432impl BasisHessianLipschitz for DuchonCoordinateEvaluator {
433    /// Radial-kernel basis `Φ_m(t) = φ(r_m)`, `r_m = ‖t − c_m‖`, plus a
434    /// polynomial nullspace block. For the cubic Duchon kernel `φ(r) = r³` the
435    /// radial derivatives are `φ' = 3r²`, `φ'' = 6r`, `φ''' = 6`. The chain rule
436    /// to coordinate jets introduces `1/r` factors through the unit radial
437    /// direction `u = (t − c)/r` and the projector `(I − uuᵀ)/r`, so over a
438    /// chart the jets are bounded by combining the radial-derivative magnitudes
439    /// at the worst-case radius with the inverse-radius tail at the chart's
440    /// EXCLUSION radius `r_min` (the closest a chart point gets to any center):
441    ///
442    /// ```text
443    /// ‖∇φ‖    ≤ |φ'|                              ≤ 3 r_max²
444    /// ‖∇²φ‖   ≤ |φ''| + |φ'|/r                    ≤ 6 r_max + 3 r_max²/r_min
445    /// ‖∇³φ‖   ≤ |φ'''| + 3|φ''|/r + 3|φ'|/r²      ≤ 6 + 18 r_max/r_min + 9 r_max²/r_min²
446    /// ```
447    ///
448    /// (the `1/r`, `1/r²` tails are bounded by `1/r_min`, `1/r_min²`). The
449    /// polynomial nullspace block is degree ≤ `order`; its jets are bounded like
450    /// the monomial patch with `D = order`. The per-column sup is the max of the
451    /// kernel and polynomial bounds. The `r³` kernel is itself `C²` (no
452    /// singularity) so these tails are conservative but finite for any
453    /// `r_min > 0`; the atlas refines charts to keep `r_min` bounded away from 0.
454    fn value_sup(&self, chart: &ChartRegion) -> f64 {
455        let r_max = chart.radial_r_max.unwrap_or(chart.radius);
456        let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 0);
457        (r_max.powi(3)).max(poly)
458    }
459    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
460        let r_max = chart.radial_r_max.unwrap_or(chart.radius);
461        let kernel = 3.0 * r_max * r_max;
462        let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 1);
463        kernel.max(poly)
464    }
465    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
466        let r_max = chart.radial_r_max.unwrap_or(chart.radius);
467        let r_min = chart
468            .exclusion_r_min
469            .unwrap_or(chart.radius)
470            .max(f64::MIN_POSITIVE);
471        let kernel = 6.0 * r_max + 3.0 * r_max * r_max / r_min;
472        let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 2);
473        kernel.max(poly)
474    }
475    fn third_sup(&self, chart: &ChartRegion) -> f64 {
476        let r_max = chart.radial_r_max.unwrap_or(chart.radius);
477        let r_min = chart
478            .exclusion_r_min
479            .unwrap_or(chart.radius)
480            .max(f64::MIN_POSITIVE);
481        let kernel = 6.0 + 18.0 * r_max / r_min + 9.0 * r_max * r_max / (r_min * r_min);
482        let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 3);
483        kernel.max(poly)
484    }
485}
486
487/// Polynomial-block degree of a Duchon nullspace order, used to bound the
488/// nullspace columns like a monomial patch.
489trait DuchonOrderDegree {
490    fn order_degree(&self) -> usize;
491}
492
493impl DuchonOrderDegree for DuchonCoordinateEvaluator {
494    fn order_degree(&self) -> usize {
495        match self.order {
496            gam_terms::basis::DuchonNullspaceOrder::Zero => 0,
497            gam_terms::basis::DuchonNullspaceOrder::Linear => 1,
498            gam_terms::basis::DuchonNullspaceOrder::Degree(d) => d,
499        }
500    }
501}
502
503/// Per-column `g`-th jet sup of the Duchon polynomial nullspace block, treated
504/// as a monomial patch of degree `order_degree`.
505pub(crate) fn duchon_poly_jet_sup(
506    latent_dim: usize,
507    order_degree: usize,
508    chart: &ChartRegion,
509    order: u32,
510) -> f64 {
511    if order_degree == 0 {
512        return if order == 0 { 1.0 } else { 0.0 };
513    }
514    patch_jet_sup(latent_dim, order_degree, chart, order)
515}
516
517/// Decoder magnitude `Σ_m ‖B_{m,:}‖₂` of an atom's frozen decoder block: the
518/// factor that converts a per-column `Φ`-jet sup `B_g` into a reconstruction
519/// jet sup `‖∂^g m‖ ≤ |z|·decoder_row_norm_sum·B_g`.
520pub(crate) fn decoder_row_norm_sum(decoder: ArrayView2<'_, f64>) -> f64 {
521    let mut acc = 0.0;
522    for row in decoder.rows() {
523        acc += row.dot(&row).sqrt();
524    }
525    acc
526}
527
/// Amplitude-free sup bounds on the reconstruction `m₁(t) = BᵀΦ(t)` and its
/// first three coordinate jets over a chart. Multiplying by `|z|` yields the
/// bounds for `m = z·m₁`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ReconstructionJetSups {
    /// Bound on `‖m₁(t)‖`.
    pub(crate) value: f64,
    /// Bound on `‖∂m₁(t)‖`.
    pub(crate) jacobian: f64,
    /// Bound on `‖∂²m₁(t)‖`.
    pub(crate) hessian: f64,
    /// Bound on `‖∂³m₁(t)‖`.
    pub(crate) third: f64,
}
535
536pub(crate) fn pair_trig_decoder_sup(
537    sin_row: ArrayView1<'_, f64>,
538    cos_row: ArrayView1<'_, f64>,
539) -> f64 {
540    let aa = sin_row.dot(&sin_row);
541    let bb = cos_row.dot(&cos_row);
542    let ab = sin_row.dot(&cos_row);
543    let trace = aa + bb;
544    let disc = ((aa - bb) * (aa - bb) + 4.0 * ab * ab).sqrt();
545    (0.5 * (trace + disc)).sqrt()
546}
547
548pub(crate) fn periodic_reconstruction_jet_sups(
549    decoder: ArrayView2<'_, f64>,
550) -> ReconstructionJetSups {
551    let mut value = 0.0;
552    let mut jacobian = 0.0;
553    let mut hessian = 0.0;
554    let mut third = 0.0;
555    if decoder.nrows() > 0 {
556        value += decoder.row(0).dot(&decoder.row(0)).sqrt();
557    }
558    let harmonics = decoder.nrows().saturating_sub(1) / 2;
559    for h in 1..=harmonics {
560        let sin_idx = 2 * h - 1;
561        let cos_idx = 2 * h;
562        let amp = pair_trig_decoder_sup(decoder.row(sin_idx), decoder.row(cos_idx));
563        let omega = std::f64::consts::TAU * h as f64;
564        value += amp;
565        jacobian += omega * amp;
566        hessian += omega.powi(2) * amp;
567        third += omega.powi(3) * amp;
568    }
569    for row in (1 + 2 * harmonics)..decoder.nrows() {
570        let amp = decoder.row(row).dot(&decoder.row(row)).sqrt();
571        value += amp;
572        let omega = std::f64::consts::TAU * harmonics.max(1) as f64;
573        jacobian += omega * amp;
574        hessian += omega.powi(2) * amp;
575        third += omega.powi(3) * amp;
576    }
577    ReconstructionJetSups {
578        value,
579        jacobian,
580        hessian,
581        third,
582    }
583}
584
585pub(crate) fn reconstruction_jet_sups(
586    atom: &SaeManifoldAtom,
587    sups: JetSups,
588) -> ReconstructionJetSups {
589    if matches!(
590        atom.basis_kind,
591        crate::manifold::SaeAtomBasisKind::Periodic
592    ) {
593        periodic_reconstruction_jet_sups(atom.decoder_coefficients.view())
594    } else {
595        let decoder_norm_sum = decoder_row_norm_sum(atom.decoder_coefficients.view());
596        ReconstructionJetSups {
597            value: decoder_norm_sum * sups.value,
598            jacobian: decoder_norm_sum * sups.jacobian,
599            hessian: decoder_norm_sum * sups.hessian,
600            third: decoder_norm_sum * sups.third,
601        }
602    }
603}
604
605/// The Hessian-Lipschitz constant `L` of the per-row encode objective `f_k` on
606/// a chart, assembled in closed form from the basis jet sups and the decoder /
607/// amplitude / target magnitudes. See the module docs for the derivation:
608///
609/// ```text
610/// L ≤ 3·‖J_m‖·‖∇²m‖ + ‖r‖·‖∇³m‖ + L_prior,
611/// ‖∂^g m‖ ≤ |z|·S_B·B_g,   S_B = Σ_m ‖B_{m,:}‖,
612/// ‖r‖ ≤ ‖x‖ + |z|·S_B·B_0,
613/// ```
614///
615/// `prior_lipschitz` is the caller-supplied closed-form `L_prior` of the
616/// ARD/von-Mises coordinate prior (`0.0` if no prior is active on the encode).
617pub(crate) fn hessian_lipschitz_constant(
618    recon_sups: ReconstructionJetSups,
619    amplitude: f64,
620    target_norm: f64,
621    prior_lipschitz: f64,
622) -> f64 {
623    let z = amplitude.abs();
624    let m_jac = z * recon_sups.jacobian;
625    let m_hess = z * recon_sups.hessian;
626    let m_third = z * recon_sups.third;
627    let recon_value = z * recon_sups.value;
628    let r_norm = target_norm + recon_value;
629    3.0 * m_jac * m_hess + r_norm * m_third + prior_lipschitz
630}
631
632/// One offline-certified chart: a center, its Kantorovich constants, and the
633/// certified Newton-convergence radius `R_c` solved from `h = β·η·L ≤ ½` at the
634/// worst-case in-chart start.
635#[derive(Debug, Clone)]
636pub struct CertifiedChart {
637    pub region: ChartRegion,
638    /// Closed-form Hessian-Lipschitz constant `L` over the chart.
639    pub lipschitz: f64,
640    /// `β = ‖F'(t_c)⁻¹‖` at the chart center (worst-case in-chart start uses
641    /// the center's curvature; the radius is solved so the certificate holds for
642    /// any start in the ball).
643    pub beta_center: f64,
644    /// Certified Newton radius: starts within `radius` of `t_c` satisfy `h ≤ ½`.
645    pub certified_radius: f64,
646    /// Distilled amortized-encoder Jacobian for this chart (#1026 ladder item 3).
647    ///
648    /// The exact encode map `x ↦ t` solves `F(t; x) = J_m(t)ᵀ(m(t) − x) = 0`. By
649    /// the implicit function theorem its derivative at the converged root is
650    /// `dt/dx = −(∂_t F)⁻¹ (∂_x F) = H⁻¹ J_m` (since `∂_x F = −J_m`), so the
651    /// first-order Taylor expansion of the encode map about this chart's center
652    /// `t_c` is the closed-form AFFINE predictor
653    ///
654    /// ```text
655    /// t(x) ≈ t_c + (1/z) · A₁ · (x − z · m₁(t_c)),   A₁ = (J₁ᵀJ₁ + ridge·I)⁻¹ J₁,
656    /// ```
657    ///
658    /// with `J₁ = Bᵀ J_Φ(t_c)` and `m₁(t_c) = BᵀΦ(t_c)` the AMPLITUDE-1
659    /// reconstruction jets (the amplitude `z` factors out analytically, so the
660    /// stored Jacobian is amplitude-free). This is the DISTILLED amortized
661    /// encoder of the #1026 thread: the per-row Hessian factorization + Newton
662    /// iteration is moved OFFLINE into this `d × p` matrix, leaving a single
663    /// `O(d·p)` mat-vec online — no per-row eigendecomposition, no second-jet
664    /// evaluation. The Kantorovich certificate is still evaluated AT the
665    /// predicted start, so the amortized prediction is trusted iff `h ≤ ½` and an
666    /// uncertified row still routes to the exact multi-start solve (the encoder
667    /// approximates inference, the certificate keeps it honest — the thread's
668    /// "encoder + certificate-gated exact fallback" deployment). `None` when the
669    /// center's Gauss–Newton block is singular (no certifiable amortization).
670    pub amortized_jacobian: Option<Array2<f64>>,
671    /// Amplitude-1 chart-center reconstruction `m₁(t_c) = BᵀΦ(t_c)` (length `p`),
672    /// the anchor the amortized predictor expands the encode map around.
673    pub recon_center: Array1<f64>,
674    /// Precomputed affine-predictor CONSTANT term `base = t_c − A₁·m₁(t_c)` (length
675    /// `d`), so the online amortized encode of a row `x` at amplitude `z` is the
676    /// single mat-vec `t̂ = base + (1/z)·A₁·x` with NO per-row `A₁·m₁` recompute.
677    /// Hoisting this atom-static term offline is what lets the massive-K index-routed
678    /// fast paths run a single allocation-free pass over rows (rather than a per-atom
679    /// GEMM sub-batch that degenerates to one row per group when `K ≫ N`). `None`
680    /// exactly when `amortized_jacobian` is `None` (singular Gauss–Newton block).
681    pub amortized_base: Option<Array1<f64>>,
682}
683
/// The per-atom encode atlas: a set of certified charts covering the atom's
/// coordinate domain, plus the decoder/amplitude scaling needed to recompute a
/// per-row certificate online.
#[derive(Debug, Clone)]
pub struct AtomEncodeAtlas {
    /// Position of this atom in the dictionary slice the atlas was built from.
    pub atom_index: usize,
    /// Latent coordinate dimension `d` of the atom (`t ∈ ℝᵈ`).
    pub latent_dim: usize,
    /// Decoder row-norm sum; presumably `Σ_m ‖B_{m,:}‖` from the closed-form
    /// `L` bound in the module docs — TODO confirm against the atlas builder.
    pub decoder_norm_sum: f64,
    /// Certified charts tiling the atom's coordinate domain.
    pub charts: Vec<CertifiedChart>,
}
694
/// Result of a certified encode over a batch of rows, carrying the honesty
/// flag: how many rows could NOT be certified and were flagged for the exact
/// multi-start fallback (issue #1010 — no approximation enters silently).
#[derive(Debug, Clone)]
pub struct EncodeResult {
    /// Per-row encoded latent coordinates (`n_rows × latent_dim`).
    pub coords: Array2<f64>,
    /// Per-row certificate: `true` ⇒ the row's start satisfied `h ≤ ½` and the
    /// 1–2 Newton steps are exact-into-the-certified-ball; `false` ⇒ flagged.
    /// One entry per row of `coords`.
    pub certified: Vec<bool>,
    /// Count of rows that could not be certified. These ride the payload so the
    /// caller routes them to the exact multi-start encode — honesty, never
    /// silent. Equals `certified.iter().filter(|c| !**c).count()`.
    pub encode_uncertified_count: usize,
}
710
711impl EncodeResult {
712    pub(crate) fn from_rows(coords: Array2<f64>, certified: Vec<bool>) -> Self {
713        let encode_uncertified_count = certified.iter().filter(|c| !**c).count();
714        Self {
715            coords,
716            certified,
717            encode_uncertified_count,
718        }
719    }
720}
721
/// Per-row Kantorovich certificate at a start `t₀` for one atom encode.
#[derive(Debug, Clone, Copy)]
pub struct RowCertificate {
    /// `β = ‖F'(t₀)⁻¹‖ = 1/λ_min(H)`; `+∞` when the Hessian is not positive
    /// definite or unavailable.
    pub beta: f64,
    /// `η = ‖F'(t₀)⁻¹ F(t₀)‖`, the Euclidean length of the Newton step; `+∞`
    /// for an uncertifiable start.
    pub eta: f64,
    /// Hessian-Lipschitz constant `L` supplied by the caller (the chart's bound).
    pub lipschitz: f64,
    /// `h = β·η·L`. The row is certified iff `h ≤ ½`.
    pub h: f64,
}
731
732impl RowCertificate {
733    pub fn certified(&self) -> bool {
734        self.h.is_finite() && self.h <= KANTOROVICH_THRESHOLD
735    }
736}
737
/// Outcome of a successful certified encode from one start: the refined
/// coordinate plus the certificates at the start and at the end of refinement.
#[derive(Debug, Clone)]
struct CertifiedEncodeProbe {
    /// Coordinate reached after the refinement Newton steps.
    coord: Array1<f64>,
    /// Certificate that admitted the start into refinement (`h ≤ ½`).
    initial_cert: RowCertificate,
    /// Certificate evaluated at `coord` (equals `initial_cert` if no step was taken).
    final_cert: RowCertificate,
}
744
/// Canonical flat-axis polynomial degree of a cylinder `S¹ × ℝ` atom — the
/// degree the topology-race builder ([`gam_solve::structure_harvest`]) uses
/// for the line axis (`CylinderHarmonicEvaluator::new(_, 2)`). The encode atlas
/// recovers the circle harmonic count from the basis width using this degree, so
/// the two must agree. A cylinder basis width is `(2H+1)·(D+1)`, with `D` this
/// constant.
pub(crate) const SAE_CYLINDER_LINE_DEGREE: usize = 2;
751
752/// Build a basis-family handle for one atom from its [`SaeManifoldAtom`]. The
753/// atlas needs to evaluate the jet sups, which live on the concrete evaluator
754/// types; the atom carries the evaluator as `Arc<dyn SaeBasisEvaluator>`, so we
755/// reconstruct the family bound from the atom's basis kind + width + centers.
756pub(crate) fn family_jet_sups(
757    atom: &SaeManifoldAtom,
758    chart: &ChartRegion,
759) -> Result<JetSups, String> {
760    use crate::manifold::SaeAtomBasisKind::*;
761    let m = atom.basis_size();
762    let d = atom.latent_dim;
763    let sups = match &atom.basis_kind {
764        Periodic => {
765            let ev = PeriodicHarmonicEvaluator::new(m)?;
766            JetSups::from_family(&ev, chart)
767        }
768        Torus => {
769            // Torus basis width is `(2H+1)^d`; recover the per-axis harmonic
770            // count `H` from `axis_m = m^(1/d)` rather than a sum formula.
771            let axis_m = integer_root(m, d.max(1));
772            let num_harmonics = axis_m.saturating_sub(1) / 2;
773            let ev = TorusHarmonicEvaluator::new(d, num_harmonics.max(1))?;
774            JetSups::from_family(&ev, chart)
775        }
776        Sphere => {
777            let ev = SphereChartEvaluator;
778            JetSups::from_family(&ev, chart)
779        }
780        Cylinder => {
781            // Cylinder width is `(2H+1)·(D+1)` with the canonical flat-axis
782            // degree `D = SAE_CYLINDER_LINE_DEGREE` (the harvest convention).
783            // Recover the per-axis circle harmonic count `H` from
784            // `2H+1 = m/(D+1)`.
785            let ml = SAE_CYLINDER_LINE_DEGREE + 1;
786            if d != 2 || ml == 0 || m % ml != 0 {
787                return Err(format!(
788                    "EncodeAtlas: Cylinder atom requires latent_dim == 2 and width divisible by {ml}; got dim={d}, m={m}"
789                ));
790            }
791            let axis_mc = m / ml;
792            let h = axis_mc.saturating_sub(1) / 2;
793            let ev = CylinderHarmonicEvaluator::new(h.max(1), SAE_CYLINDER_LINE_DEGREE)?;
794            JetSups::from_family(&ev, chart)
795        }
796        Linear | EuclideanPatch | Poincare => {
797            // The patch width fixes max_degree implicitly; bound by a degree that
798            // covers the column count (conservative). Degree d-patch column count
799            // grows fast; we recover the smallest degree whose patch is ≥ m.
800            // Poincare atoms use the same tangent-coordinate polynomial decoder;
801            // their intrinsic smoothness differs in the penalty, not in Phi(t).
802            let degree = euclidean_patch_degree(d, m);
803            let ev = EuclideanPatchEvaluator::new(d, degree)?;
804            JetSups::from_family(&ev, chart)
805        }
806        Duchon => {
807            // The atom carries the basis kind but not the nullspace order, and
808            // the certificate needs an UPPER bound on L. The kernel-tail bound
809            // (cubic r³ coefficients vs the chart's r_min/r_max) is independent
810            // of the constructed order; the polynomial-block bound grows with the
811            // order, so we construct with a conservative order whose polynomial
812            // degree upper-bounds any nullspace the atom's basis width can hold.
813            // Constructing with `m = basis_size` maps to `Degree(basis_size − 1)`
814            // — an over-estimate that keeps the Lipschitz bound sound.
815            let centers = duchon_centers_from_atom(atom);
816            let conservative_m = m.max(1);
817            let ev = DuchonCoordinateEvaluator::new(centers, conservative_m)?;
818            JetSups::from_family(&ev, chart)
819        }
820        Precomputed(name) => {
821            return Err(format!(
822                "EncodeAtlas: precomputed basis '{name}' has no closed-form jet sup; route to exact encode"
823            ));
824        }
825    };
826    Ok(sups)
827}
828
829/// Smallest monomial-patch degree whose column count covers `m` basis columns.
830pub(crate) fn euclidean_patch_degree(latent_dim: usize, m: usize) -> usize {
831    // Column count of a degree-D patch in d vars is C(d+D, D). Grow D until it
832    // covers m; cap at m so a degenerate width still terminates.
833    let mut degree = 0usize;
834    while patch_column_count(latent_dim, degree) < m && degree < m {
835        degree += 1;
836    }
837    degree
838}
839
/// Largest integer `a` with `a^k ≤ n` (the floor of the `k`-th root). Used to
/// recover the per-axis harmonic width `axis_m` from a torus basis width
/// `m = axis_m^d`.
///
/// Edge cases: `k == 0` returns 1 (every `a` satisfies `a^0 = 1`, so there is
/// no meaningful root). `k == 1` returns `n`. `n ∈ {0, 1}` returns `n`, since
/// `0^k = 0` and `1^k = 1`.
///
/// Runs a binary search over `[1, n]`, costing `O(k · log n)` multiplications
/// rather than a linear scan. Powers are accumulated in saturating `u128`, so
/// no intermediate can overflow.
pub(crate) fn integer_root(n: usize, k: usize) -> usize {
    if k == 0 {
        return 1;
    }
    if k == 1 || n <= 1 {
        return n;
    }
    // `fits(a)` ⇔ `a^k ≤ n`, exiting early once the partial power exceeds `n`.
    let fits = |a: usize| -> bool {
        let mut pow: u128 = 1;
        for _ in 0..k {
            pow = pow.saturating_mul(a as u128);
            if pow > n as u128 {
                return false;
            }
        }
        true
    };
    // Invariant: `fits(lo)` holds and the answer lies in `[lo, hi]`.
    let (mut lo, mut hi) = (1usize, n);
    while lo < hi {
        // Upper midpoint so the loop always makes progress when `lo = mid`.
        let mid = lo + (hi - lo + 1) / 2;
        if fits(mid) {
            lo = mid;
        } else {
            hi = mid - 1;
        }
    }
    lo
}
868
/// Column count of a degree-`degree` monomial patch in `latent_dim` variables:
/// the binomial coefficient `C(latent_dim + degree, degree)`.
///
/// The coefficient is built incrementally via
/// `C(d+i, i) = C(d+i−1, i−1) · (d+i) / i`, where every division is exact.
/// The running value is therefore always a binomial coefficient and never
/// exceeds the result. Forming the full numerator `(d+1)⋯(d+D)` and
/// denominator `D!` separately would overflow `u128` already for `d = 1`,
/// `D ≈ 34`: a panic in debug builds and a wrong count in release.
///
/// A count that does not fit in `usize` saturates to `usize::MAX`. Callers
/// compare the count against a basis width, so saturation preserves the
/// "patch covers `m` columns" answer.
pub(crate) fn patch_column_count(latent_dim: usize, degree: usize) -> usize {
    let d = latent_dim as u128;
    let cap = usize::MAX as u128;
    let mut count: u128 = 1;
    for i in 1..=degree as u128 {
        // On overflow the exact product exceeds 2^128. Dividing by `i < 2^64`
        // leaves `C(d+i, i) > usize::MAX`, so saturating is correct.
        let Some(scaled) = count.checked_mul(d + i) else {
            return usize::MAX;
        };
        count = scaled / i;
        // C(d+i, i) is non-decreasing in `i`, so once it exceeds `usize` it stays there.
        if count > cap {
            return usize::MAX;
        }
    }
    count as usize
}
879
/// Placeholder Duchon center set for the jet-sup bound: a single center at the
/// origin of the atom's latent space (`1 × max(latent_dim, 1)` zeros).
///
/// The atom does not expose its Duchon centers here, so this does NOT recover
/// them. The atom argument is read only for `latent_dim`.
/// NOTE(review): soundness relies on the chart's own `r_min` / `r_max`
/// bracketing the true radial range (per the inline comment) — confirm that
/// the Duchon `BasisHessianLipschitz` bound uses the chart radii rather than
/// these centers.
pub(crate) fn duchon_centers_from_atom(atom: &SaeManifoldAtom) -> Array2<f64> {
    // One center at the origin in latent_dim space is a sound conservative
    // default: the chart's own r_min / r_max bracket the true radial range.
    Array2::<f64>::zeros((1, atom.latent_dim.max(1)))
}
888
/// The four per-column jet sups of a basis family over a chart:
/// `B_g = sup_chart max_m ‖∂^g Φ_m‖` for `g = 0, 1, 2, 3`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct JetSups {
    /// `B_0`: sup of the basis values (from `value_sup`).
    pub(crate) value: f64,
    /// `B_1`: sup of the first derivatives (from `jacobian_sup`).
    pub(crate) jacobian: f64,
    /// `B_2`: sup of the second derivatives (from `hessian_sup`).
    pub(crate) hessian: f64,
    /// `B_3`: sup of the third derivatives (from `third_sup`).
    pub(crate) third: f64,
}
897
898impl JetSups {
899    pub(crate) fn from_family<B: BasisHessianLipschitz>(family: &B, chart: &ChartRegion) -> Self {
900        Self {
901            value: family.value_sup(chart),
902            jacobian: family.jacobian_sup(chart),
903            hessian: family.hessian_sup(chart),
904            third: family.third_sup(chart),
905        }
906    }
907}
908
/// Evaluate one atom's encode objective gradient `F(t) = ∇f_k(t)` and the FULL
/// Hessian `F'(t) = ∇²f_k(t)` at a single coordinate `t`, for a single target
/// row `x` and fixed amplitude `z`. With `m(t) = z·BᵀΦ(t)`, `r = m − x`,
/// `J_m = z·Bᵀ J_Φ`:
///
/// ```text
/// g_t[a]   = J_m[a] · r                                  (= ∇f)
/// H_tt[a,b] = J_m[a] · J_m[b] + r · ∂²m/∂t_a∂t_b         (= ∇²f, FULL Hessian)
/// ```
///
/// The certificate uses the FULL Hessian rather than the Gauss-Newton block
/// `J_mᵀ J_m`. This is the principled choice for Newton–Kantorovich: the
/// theorem certifies convergence of Newton on `F = ∇f` to the unique nearby
/// ROOT of `∇f`, but a root of `∇f` can be a maximum. The full Hessian is
/// positive-definite exactly on the genuine-minimum basin, so requiring
/// `λ_min(H) > 0` (finite `β`) is what flags a start that would otherwise let
/// Gauss-Newton march into the wrong root (e.g. the circle antipode, a local
/// max where `∇f = 0` but the full curvature is negative). The residual term
/// needs the basis second jet `∂²Φ/∂t²`; an evaluator without one returns
/// `None`, and the row is flagged (no silent Gauss-Newton fallback).
///
/// `ridge` is added to every diagonal entry of the returned FULL Hessian.
///
/// Returns `Ok(Some((g, H)))` with `g` of length `d = latent_dim` and a
/// symmetric `d × d` `H`. Returns `Ok(None)` when the evaluator has no second jet.
///
/// # Errors
/// - `t` cannot be reshaped to `(1, d)`;
/// - the evaluator's `evaluate` or `second_jet_dyn` fails;
/// - `phi` is not `(1, basis_size)`.
///
/// NOTE(review): only `phi`'s shape is validated. The first jet is indexed as
/// `(1, m, d)`, the second jet as `(1, m, d, d)`, and the decoder as `(m, p)`.
/// A mis-shaped array panics on indexing instead of returning `Err`.
pub(crate) fn encode_grad_hess(
    atom: &SaeManifoldAtom,
    evaluator: &dyn SaeBasisEvaluator,
    t: ArrayView1<'_, f64>,
    x: ArrayView1<'_, f64>,
    amplitude: f64,
    ridge: f64,
) -> Result<Option<(Array1<f64>, Array2<f64>)>, String> {
    let d = atom.latent_dim;
    let p = atom.output_dim();
    let m = atom.basis_size();
    let coords = t.to_shape((1, d)).map_err(|e| e.to_string())?.to_owned();
    let (phi, jet) = evaluator.evaluate(coords.view())?;
    if phi.dim() != (1, m) {
        return Err(format!(
            "encode_grad_hess: evaluator returned phi {:?}, expected (1, {m})",
            phi.dim()
        ));
    }
    let decoder = &atom.decoder_coefficients;
    // Reconstruction m(t) = z · Bᵀ Φ(t)  ∈ ℝᵖ.
    let mut recon = Array1::<f64>::zeros(p);
    for basis_col in 0..m {
        let phi_v = phi[[0, basis_col]];
        if phi_v == 0.0 {
            continue;
        }
        for out in 0..p {
            recon[out] += amplitude * phi_v * decoder[[basis_col, out]];
        }
    }
    let residual = &recon - &x;
    // J_m[axis] = z · Bᵀ (∂Φ/∂t_axis)  ∈ ℝᵖ.
    let mut jm = Array2::<f64>::zeros((d, p));
    for axis in 0..d {
        for basis_col in 0..m {
            let dphi = jet[[0, basis_col, axis]];
            if dphi == 0.0 {
                continue;
            }
            for out in 0..p {
                jm[[axis, out]] += amplitude * dphi * decoder[[basis_col, out]];
            }
        }
    }
    // The full-Hessian residual term needs ∂²Φ/∂t². No second jet ⇒ no
    // certificate (flag), never a silent Gauss-Newton substitute.
    let second = match evaluator.second_jet_dyn(coords.view()) {
        Some(result) => result?,
        None => return Ok(None),
    };
    // The residual·decoder-row product `r·B_{basis,:}` does not depend on the
    // (a,b) axes. Compute it once in an O(m·p) pass so each per-axis curvature
    // term below is a cheap O(m) dot.
    let mut rd = vec![0.0_f64; m];
    for (basis_col, rd_col) in rd.iter_mut().enumerate() {
        let mut dot = 0.0;
        for out in 0..p {
            dot += residual[out] * decoder[[basis_col, out]];
        }
        *rd_col = dot;
    }
    // g_t[axis] = J_m[axis] · r ;  H_tt[a,b] = J_m[a]·J_m[b] + r·∂²m/∂t_a∂t_b.
    // The full Hessian is symmetric (Gauss-Newton block + symmetric second jet), so
    // compute the upper triangle and mirror — half the curvature work.
    let mut g = Array1::<f64>::zeros(d);
    let mut h = Array2::<f64>::zeros((d, d));
    for a in 0..d {
        let ja = jm.row(a);
        g[a] = ja.dot(&residual);
        for b in a..d {
            // Gauss-Newton block.
            let mut hab = ja.dot(&jm.row(b));
            // Residual · second-jet curvature: r · ∂²m_{ab},
            // ∂²m_{ab}[out] = z · Σ_basis (∂²Φ/∂t_a∂t_b) · B[basis, out].
            let mut curv = 0.0;
            for basis_col in 0..m {
                let d2phi = second[[0, basis_col, a, b]];
                if d2phi == 0.0 {
                    continue;
                }
                curv += amplitude * d2phi * rd[basis_col];
            }
            hab += curv;
            h[[a, b]] = hab;
            h[[b, a]] = hab;
        }
    }
    // Levenberg shift on the FULL Hessian's diagonal.
    for a in 0..d {
        h[[a, a]] += ridge;
    }
    Ok(Some((g, h)))
}
1022
/// Operator norm of `H⁻¹` (i.e. `β = 1/λ_min(H)`) and the Newton step
/// `δ = −H⁻¹ g` with `η = ‖δ‖₂`, for a symmetric `H` and gradient `g`.
///
/// `H` is the full encode Hessian and may be indefinite. Only its lower
/// triangle is read. Returns `Ok(None)` when `H` is not numerically positive
/// definite (`λ_min ≤ 0` or non-finite), which marks an uncertifiable start.
/// Otherwise returns `Ok(Some((β, η, δ)))`.
///
/// # Errors
/// Only when the general-dimension eigendecomposition (`d ∉ {1, 2}`) fails.
///
/// NOTE(review): only the `d = 2` path rejects a non-finite `δ`. On the
/// `d = 1` and general paths, a non-finite `g` produces a non-finite `η`,
/// which is later rejected by `RowCertificate::certified` (non-finite `h`).
pub(crate) fn beta_eta_newton(
    h: ArrayView2<'_, f64>,
    g: ArrayView1<'_, f64>,
) -> Result<Option<(f64, f64, Array1<f64>)>, String> {
    // Closed-form fast paths for the tiny latent dims that dominate SAE atoms
    // (`d = 1, 2`), avoiding a faer eigendecomposition + its heap allocations on
    // the hottest per-row Newton inner loop. `β = 1/λ_min(H)`, `δ = −H⁻¹g`, and the
    // `λ_min ≤ 0` gate are computed directly; the symmetric-`H` reads mirror
    // `eigh(Side::Lower)` (which uses the lower triangle) exactly.
    let d = h.nrows();
    if d == 1 {
        let h00 = h[[0, 0]];
        if !(h00.is_finite() && h00 > 0.0) {
            return Ok(None);
        }
        let delta0 = -g[0] / h00;
        let mut delta = Array1::<f64>::zeros(1);
        delta[0] = delta0;
        return Ok(Some((1.0 / h00, delta0.abs(), delta)));
    }
    if d == 2 {
        // Symmetric H = [[a, b], [b, c]] read from the lower triangle.
        let a = h[[0, 0]];
        let b = h[[1, 0]];
        let c = h[[1, 1]];
        let tr = a + c;
        let det = a * c - b * b;
        // λ_min = ½(tr − √((a−c)² + 4b²)); ≥ 0 ⇒ H PSD, > 0 ⇒ PD.
        let disc = ((a - c) * (a - c) + 4.0 * b * b).max(0.0).sqrt();
        let lambda_min = 0.5 * (tr - disc);
        if !(lambda_min.is_finite() && lambda_min > 0.0) {
            return Ok(None);
        }
        // δ = −H⁻¹g with H⁻¹ = [[c, −b], [−b, a]] / det (det = λ_min·λ_max > 0).
        let inv_det = 1.0 / det;
        let g0 = g[0];
        let g1 = g[1];
        let d0 = -(c * g0 - b * g1) * inv_det;
        let d1 = -(a * g1 - b * g0) * inv_det;
        if !(d0.is_finite() && d1.is_finite()) {
            return Ok(None);
        }
        let mut delta = Array1::<f64>::zeros(2);
        delta[0] = d0;
        delta[1] = d1;
        let eta = (d0 * d0 + d1 * d1).sqrt();
        return Ok(Some((1.0 / lambda_min, eta, delta)));
    }
    let (vals, vecs) = h
        .eigh(Side::Lower)
        .map_err(|e| format!("beta_eta_newton: eigh failed: {e:?}"))?;
    // An empty spectrum (d = 0) folds to +∞ and is rejected below.
    let lambda_min = vals.iter().cloned().fold(f64::INFINITY, f64::min);
    if !(lambda_min.is_finite() && lambda_min > 0.0) {
        return Ok(None);
    }
    let beta = 1.0 / lambda_min;
    // Newton step δ = −H⁻¹ g via the eigendecomposition: δ = −Σ_i (vᵢᵀg/λᵢ) vᵢ.
    let mut delta = Array1::<f64>::zeros(d);
    for (col, &lam) in vals.iter().enumerate() {
        // Defensive: already implied by `lambda_min > 0` above.
        if lam <= 0.0 {
            return Ok(None);
        }
        let vi = vecs.column(col);
        let coeff = vi.dot(&g) / lam;
        for row in 0..d {
            delta[row] -= coeff * vi[row];
        }
    }
    let eta = delta.dot(&delta).sqrt();
    Ok(Some((beta, eta, delta)))
}
1098
1099/// Compute the per-row Kantorovich certificate for encoding target row `x`
1100/// against atom `atom` at start coordinate `t₀`, with fixed amplitude `z` and
1101/// the chart's closed-form Lipschitz constant `lipschitz`. Returns the
1102/// certificate AND the Newton step `δ = −H⁻¹ g` so the caller can advance.
1103pub fn row_certificate(
1104    atom: &SaeManifoldAtom,
1105    evaluator: &dyn SaeBasisEvaluator,
1106    t0: ArrayView1<'_, f64>,
1107    x: ArrayView1<'_, f64>,
1108    amplitude: f64,
1109    lipschitz: f64,
1110    ridge: f64,
1111) -> Result<(RowCertificate, Array1<f64>), String> {
1112    let uncertified = || {
1113        (
1114            RowCertificate {
1115                beta: f64::INFINITY,
1116                eta: f64::INFINITY,
1117                lipschitz,
1118                h: f64::INFINITY,
1119            },
1120            Array1::<f64>::zeros(atom.latent_dim),
1121        )
1122    };
1123    // No second jet ⇒ no full Hessian ⇒ uncertifiable (flag).
1124    let Some((g, h)) = encode_grad_hess(atom, evaluator, t0, x, amplitude, ridge)? else {
1125        return Ok(uncertified());
1126    };
1127    match beta_eta_newton(h.view(), g.view())? {
1128        Some((beta, eta, delta)) => {
1129            let cert = RowCertificate {
1130                beta,
1131                eta,
1132                lipschitz,
1133                h: beta * eta * lipschitz,
1134            };
1135            Ok((cert, delta))
1136        }
1137        // Indefinite / negative-curvature full Hessian: the start is at or past
1138        // a basin boundary (a max/saddle of f), not the minimum basin — flag.
1139        None => Ok(uncertified()),
1140    }
1141}
1142
1143fn uncertified_certificate(lipschitz: f64) -> RowCertificate {
1144    RowCertificate {
1145        beta: f64::INFINITY,
1146        eta: f64::INFINITY,
1147        lipschitz,
1148        h: f64::INFINITY,
1149    }
1150}
1151
1152fn refine_certified_start(
1153    atom: &SaeManifoldAtom,
1154    evaluator: &dyn SaeBasisEvaluator,
1155    mut t: Array1<f64>,
1156    x: ArrayView1<'_, f64>,
1157    amplitude: f64,
1158    lipschitz: f64,
1159    ridge: f64,
1160    newton_steps: usize,
1161    initial_cert: RowCertificate,
1162    mut delta: Array1<f64>,
1163) -> Result<Option<CertifiedEncodeProbe>, String> {
1164    assert!(initial_cert.certified());
1165    let mut final_cert = initial_cert;
1166    for _ in 0..newton_steps {
1167        // Convergence early-exit: the pending Newton step is below the coordinate
1168        // ULP scale, so `t + δ == t` to f64 resolution — the certified root is
1169        // reached and the remaining fixed-budget steps would only re-accumulate
1170        // round-off. This is where the well-conditioned quadratic Newton tail's
1171        // redundant `evaluate` + `second_jet` work is eliminated.
1172        if delta.dot(&delta).sqrt()
1173            <= NEWTON_REFINE_CONVERGED_EPS * (1.0 + t.dot(&t).sqrt())
1174        {
1175            break;
1176        }
1177        t = &t + &delta;
1178        let (cert, next_delta) =
1179            row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz, ridge)?;
1180        if !cert.certified() {
1181            return Ok(None);
1182        }
1183        final_cert = cert;
1184        delta = next_delta;
1185    }
1186    Ok(Some(CertifiedEncodeProbe {
1187        coord: t,
1188        initial_cert,
1189        final_cert,
1190    }))
1191}
1192
/// Certify an encode probe from `t_start`, navigating into the Kantorovich basin
/// first if needed (#1154/#1026). The Kantorovich quantity `h = β·η·L` scales with
/// amplitude through `L`, so at unit amplitude a positive-definite chart-center /
/// distilled start can sit OUTSIDE the certified ball (`h > ½`). Rather than
/// flagging it uncertified immediately — which made the encoder certify ZERO
/// held-out rows at amplitude 1.0 and fall back to the exact solve for everything —
/// take plain Newton steps toward the root, re-certifying at each iterate, while
/// the Kantorovich quantity `h = β·η·L` keeps CONTRACTING toward the ½ bound. The
/// certificate at the landing point is a full Kantorovich guarantee from there
/// (`h ≤ ½` ⇒ Newton converges to the in-ball root), so this only ever WIDENS the
/// certified set; it never certifies a non-convergent start.
///
/// Termination is the natural Newton stopping rule — there is no arbitrary step
/// budget. The warm-up stops and flags for the exact fallback when either the start
/// is not steppable (indefinite / non-finite Hessian — at or past a basin boundary)
/// or a step fails to reduce `h` (the iterate is not approaching a certifiable
/// in-chart root: its root lies outside this chart's valid Lipschitz region, or the
/// start was past the basin — empirically the rows that miss *plateau*, so more
/// steps cannot help; the lever there is denser charts, not more iterations). On
/// success the start is refined `newton_steps` further by [`refine_certified_start`].
///
/// Returns `Ok(None)` for a flagged row (start or a warm-up step outside the
/// chart ball, a non-steppable iterate, a non-contracting `h`, or a failed
/// certificate during refinement). Returns `Ok(Some(probe))` otherwise.
/// Errors propagate from `row_certificate`.
fn certify_with_basin_warmup(
    atom: &SaeManifoldAtom,
    evaluator: &dyn SaeBasisEvaluator,
    t_start: Array1<f64>,
    x: ArrayView1<'_, f64>,
    amplitude: f64,
    lipschitz: f64,
    ridge: f64,
    newton_steps: usize,
    chart_center: ArrayView1<'_, f64>,
    chart_radius: f64,
) -> Result<Option<CertifiedEncodeProbe>, String> {
    // SOUNDNESS GUARD: `lipschitz` is the chart's Hessian-Lipschitz sup, which is
    // only a valid bound over this chart's ball `‖t − center‖ ≤ radius` for the
    // chart-local families (`EuclideanPatch`/`Linear`/`Poincare` monomial patches,
    // `Cylinder` line axis, `Duchon` radial kernels). If a warm-up iterate leaves
    // that ball, `row_certificate` would compute `h = β·η·L` with an `L` that no
    // longer bounds the true geometry there, so `h ≤ ½` would NOT imply Kantorovich
    // convergence — a false certificate. (The `h`-contraction check does NOT catch
    // this: `h` can decrease monotonically toward an out-of-chart root the whole
    // way.) So we keep every certified iterate inside the chart; a row whose root is
    // outside this chart flags for the exact fallback — its lever is a denser grid,
    // not a step using an invalid `L`. Global-`L` families (periodic/torus/sphere)
    // route their points to charts whose centers are near the root, so the guard
    // rarely trips for them, and where it does the row was out-of-chart anyway.
    let in_chart = |t: &Array1<f64>| -> bool {
        let r2: f64 = t
            .iter()
            .zip(chart_center.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        r2 <= chart_radius * chart_radius
    };
    let mut t = t_start;
    // The distilled / chart-center start must itself be in-chart for its certificate
    // to be valid; a bad IFT prediction landing outside the chart is uncertifiable.
    if !in_chart(&t) {
        return Ok(None);
    }
    let (mut cert, mut delta) =
        row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz, ridge)?;
    while !cert.certified() {
        // Not steppable (indefinite / non-finite Hessian): flag.
        if !(cert.h.is_finite() && cert.beta.is_finite() && cert.eta.is_finite()) {
            return Ok(None);
        }
        let prev_h = cert.h;
        let next = &t + &delta;
        // Refuse to step where the chart's `L` is no longer valid (see guard above).
        if !in_chart(&next) {
            return Ok(None);
        }
        t = next;
        let (next_cert, next_delta) =
            row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz, ridge)?;
        cert = next_cert;
        delta = next_delta;
        // The warm-up only helps while h keeps contracting toward ½. Once a step
        // fails to reduce it, the iterate is not converging to a certifiable in-chart
        // root — flag for the exact fallback (no arbitrary step budget).
        if !cert.h.is_finite() || cert.h >= prev_h {
            return Ok(None);
        }
    }
    // NOTE(review): the in-chart guard checks only the certified ITERATE. The
    // Kantorovich root ball around it (radius `2η/(1+√(1−2h))`) is not checked
    // against the chart ball. The refinement steps below are also not in-chart
    // guarded. If the guard's soundness argument is meant to hold end to end,
    // consider requiring `‖t − center‖ + r* ≤ chart_radius` here — confirm.
    refine_certified_start(
        atom,
        evaluator,
        t,
        x,
        amplitude,
        lipschitz,
        ridge,
        newton_steps,
        cert,
        delta,
    )
}
1290
1291fn kantorovich_root_radius(cert: RowCertificate) -> f64 {
1292    if !cert.certified() || !(cert.eta.is_finite() && cert.eta >= 0.0) {
1293        return f64::INFINITY;
1294    }
1295    if cert.eta == 0.0 {
1296        return 0.0;
1297    }
1298    if !(cert.h.is_finite() && cert.h >= 0.0) {
1299        return f64::INFINITY;
1300    }
1301    let h = cert.h.min(KANTOROVICH_THRESHOLD);
1302    let discriminant = (1.0 - 2.0 * h).max(0.0).sqrt();
1303    let radius = 2.0 * cert.eta / (1.0 + discriminant);
1304    if radius.is_finite() {
1305        radius
1306    } else {
1307        f64::INFINITY
1308    }
1309}
1310
1311fn distilled_probe_tolerance(
1312    amortized: &CertifiedEncodeProbe,
1313    cold: &CertifiedEncodeProbe,
1314    amplitude: f64,
1315    x: ArrayView1<'_, f64>,
1316) -> f64 {
1317    let certified_radius =
1318        kantorovich_root_radius(amortized.final_cert) + kantorovich_root_radius(cold.final_cert);
1319    let coord_scale = amortized.coord.dot(&amortized.coord).sqrt()
1320        + cold.coord.dot(&cold.coord).sqrt()
1321        + x.dot(&x).sqrt()
1322        + amplitude.abs()
1323        + 1.0;
1324    certified_radius + 1024.0 * f64::EPSILON * coord_scale
1325}
1326
1327fn latent_coordinate_distance(
1328    atom: &SaeManifoldAtom,
1329    lhs: ArrayView1<'_, f64>,
1330    rhs: ArrayView1<'_, f64>,
1331) -> f64 {
1332    let mut acc = 0.0;
1333    for axis in 0..lhs.len().min(rhs.len()) {
1334        let mut diff = (lhs[axis] - rhs[axis]).abs();
1335        if let Some(period) = latent_axis_period(atom, axis) {
1336            let wrapped = diff.rem_euclid(period);
1337            diff = wrapped.min(period - wrapped);
1338        }
1339        acc += diff * diff;
1340    }
1341    acc.sqrt()
1342}
1343
1344fn latent_axis_period(atom: &SaeManifoldAtom, axis: usize) -> Option<f64> {
1345    use crate::manifold::SaeAtomBasisKind::*;
1346    match &atom.basis_kind {
1347        Periodic | Torus => Some(1.0),
1348        Cylinder if axis == 0 => Some(1.0),
1349        Sphere if axis == 1 => Some(std::f64::consts::TAU),
1350        _ => None,
1351    }
1352}
1353
/// Configuration for [`EncodeAtlas`] construction and online encode. All fields
/// are explicit; the atlas never reads global state and adds no CLI flags.
#[derive(Debug, Clone, Copy)]
pub struct AtlasConfig {
    /// Grid resolution per latent axis for offline chart centers (the
    /// SHAPE_BAND grid idiom).
    pub grid_resolution: usize,
    /// Levenberg ridge added to every diagonal entry of the per-row FULL
    /// encode Hessian (Gauss-Newton block plus residual curvature), not only
    /// the Gauss-Newton block.
    pub ridge: f64,
    /// Maximum number of online Newton refinement steps after a certified start
    /// (1 or 2 per issue #1010). Refinement stops earlier once the step falls
    /// below round-off scale.
    pub newton_steps: usize,
}
1367
1368impl Default for AtlasConfig {
1369    fn default() -> Self {
1370        Self {
1371            grid_resolution: 16,
1372            ridge: 1.0e-9,
1373            newton_steps: 2,
1374        }
1375    }
1376}
1377
/// The encode atlas: per-atom certified charts plus the online certified-encode
/// driver (issue #1010).
#[derive(Debug, Clone)]
pub struct EncodeAtlas {
    /// One atlas per dictionary atom, in atom order.
    pub atoms: Vec<AtomEncodeAtlas>,
    /// Configuration the atlas was built with; also governs online encode.
    pub config: AtlasConfig,
}
1385
1386impl EncodeAtlas {
1387    /// Build the offline atlas over a frozen dictionary: for each atom, lay down
1388    /// chart centers on the atom's coordinate grid and certify a Newton radius
1389    /// from the Kantorovich inequality at the worst-case in-chart start.
1390    ///
1391    /// `amplitude_bound[k]` is the per-atom bound on `|z_k|` used to scale the
1392    /// reconstruction jets (the offline `L` must hold for the largest amplitude
1393    /// the encode can produce); `target_norm_bound` bounds `‖x‖` over the data.
1394    pub fn build(
1395        atoms: &[SaeManifoldAtom],
1396        amplitude_bound: &[f64],
1397        target_norm_bound: f64,
1398        config: AtlasConfig,
1399    ) -> Result<Self, String> {
1400        if amplitude_bound.len() != atoms.len() {
1401            return Err(format!(
1402                "EncodeAtlas::build: amplitude_bound length {} != atom count {}",
1403                amplitude_bound.len(),
1404                atoms.len()
1405            ));
1406        }
1407        let mut atom_atlases = Vec::with_capacity(atoms.len());
1408        for (k, atom) in atoms.iter().enumerate() {
1409            let atlas =
1410                Self::build_atom_atlas(k, atom, amplitude_bound[k], target_norm_bound, &config)?;
1411            atom_atlases.push(atlas);
1412        }
1413        Ok(Self {
1414            atoms: atom_atlases,
1415            config,
1416        })
1417    }
1418
1419    pub(crate) fn build_atom_atlas(
1420        atom_index: usize,
1421        atom: &SaeManifoldAtom,
1422        amplitude_bound: f64,
1423        target_norm_bound: f64,
1424        config: &AtlasConfig,
1425    ) -> Result<AtomEncodeAtlas, String> {
1426        let centers = chart_center_grid(atom, config.grid_resolution);
1427        // Half the inter-center spacing is the natural in-chart radius so the
1428        // charts tile the grid without gaps; refined below if the certificate
1429        // fails at that radius. One uniform radius for the regular grid.
1430        let nominal_radius = chart_nominal_radius(atom, config.grid_resolution);
1431        let radii = vec![nominal_radius; centers.nrows()];
1432        Self::build_atom_atlas_from_centers(
1433            atom_index,
1434            atom,
1435            centers.view(),
1436            &radii,
1437            amplitude_bound,
1438            target_norm_bound,
1439            config,
1440        )
1441    }
1442
1443    /// Build a per-atom atlas from EXPLICIT chart centers with a per-center
1444    /// nominal radius — the geometry-agnostic core shared by the regular-grid
1445    /// [`Self::build_atom_atlas`] and the data-driven [`Self::build_data_driven`].
1446    /// Every chart is certified identically (Kantorovich radius from the in-chart
1447    /// curvature at its center); only the center PLACEMENT and per-center radius
1448    /// differ. `radii[c]` is the nominal in-chart radius for `centers[c]`.
1449    pub(crate) fn build_atom_atlas_from_centers(
1450        atom_index: usize,
1451        atom: &SaeManifoldAtom,
1452        centers: ArrayView2<'_, f64>,
1453        radii: &[f64],
1454        amplitude_bound: f64,
1455        target_norm_bound: f64,
1456        config: &AtlasConfig,
1457    ) -> Result<AtomEncodeAtlas, String> {
1458        let d = atom.latent_dim;
1459        if centers.ncols() != d {
1460            return Err(format!(
1461                "build_atom_atlas_from_centers: centers have {} cols but atom latent_dim is {d}",
1462                centers.ncols()
1463            ));
1464        }
1465        if radii.len() != centers.nrows() {
1466            return Err(format!(
1467                "build_atom_atlas_from_centers: {} radii != {} centers",
1468                radii.len(),
1469                centers.nrows()
1470            ));
1471        }
1472        let decoder_norm_sum = decoder_row_norm_sum(atom.decoder_coefficients.view());
1473        let mut charts = Vec::with_capacity(centers.nrows());
1474        for c in 0..centers.nrows() {
1475            let center = centers.row(c).to_owned();
1476            let nominal_radius = radii[c];
1477            let region = chart_region(atom, center.clone(), nominal_radius);
1478            let sups = family_jet_sups(atom, &region)?;
1479            let recon_sups = reconstruction_jet_sups(atom, sups);
1480            let lipschitz =
1481                hessian_lipschitz_constant(recon_sups, amplitude_bound, target_norm_bound, 0.0);
1482            // β at the chart center bounds the worst-case in-chart curvature
1483            // (the Gauss-Newton Hessian is continuous; the certified radius is
1484            // solved so the certificate is robust to the start within the ball).
1485            let beta_center = match center_beta(atom, &center, config.ridge) {
1486                Some(b) => b,
1487                None => {
1488                    // Degenerate center curvature: no certifiable chart here, and
1489                    // no amortized Jacobian (the same singular Gauss–Newton block).
1490                    charts.push(CertifiedChart {
1491                        region,
1492                        lipschitz,
1493                        beta_center: f64::INFINITY,
1494                        certified_radius: 0.0,
1495                        amortized_jacobian: None,
1496                        recon_center: Array1::<f64>::zeros(atom.output_dim()),
1497                        amortized_base: None,
1498                    });
1499                    continue;
1500                }
1501            };
1502            // Distill the amortized-encoder Jacobian at this center (#1026 ladder
1503            // item 3): the IFT derivative of the encode map, precomputed offline
1504            // so the online encode is one mat-vec. A finite `beta_center` (above)
1505            // means the Gauss–Newton block is non-singular, so this succeeds
1506            // alongside it; the pair travels together on the chart.
1507            let (amortized_jacobian, recon_center) =
1508                match center_amortized_jacobian(atom, &center, config.ridge) {
1509                    Some((a1, m1)) => (Some(a1), m1),
1510                    None => (None, Array1::<f64>::zeros(atom.output_dim())),
1511                };
1512            // Certified radius from h = β·η·L ≤ ½ with η ≤ R (Newton step length
1513            // is bounded by the start distance to the root, itself ≤ chart
1514            // radius at worst): R_c = ½ / (β·L), capped at the nominal radius.
1515            let certified_radius = if lipschitz > 0.0 && beta_center.is_finite() {
1516                (0.5 / (beta_center * lipschitz)).min(region.radius)
1517            } else {
1518                region.radius
1519            };
1520            // Precompute the affine-predictor constant `base = t_c − A₁·m₁` (atom-
1521            // static), so the online encode is a single `base + (1/z)·A₁·x` mat-vec.
1522            let amortized_base = amortized_jacobian
1523                .as_ref()
1524                .map(|a1| &center - &a1.dot(&recon_center));
1525            charts.push(CertifiedChart {
1526                region,
1527                lipschitz,
1528                beta_center,
1529                certified_radius,
1530                amortized_jacobian,
1531                recon_center,
1532                amortized_base,
1533            });
1534        }
1535        Ok(AtomEncodeAtlas {
1536            atom_index,
1537            latent_dim: d,
1538            decoder_norm_sum,
1539            charts,
1540        })
1541    }
1542
1543    /// Build the atlas with DATA-DRIVEN chart placement: instead of a dense
1544    /// `resolution^d` product grid (exponential in latent dim `d`, so the regular
1545    /// [`Self::build`] is forced to coarse, poorly-certified charts for `d ≥ 3`),
1546    /// place a bounded number of charts AT the data's own latent coordinates. The
1547    /// chart count is then `O(max_charts)` regardless of `d`, and every chart sits
1548    /// where data actually lands (small in-chart residual → certifies), so
1549    /// higher-dimensional atoms — which reconstruct real activations far better per
1550    /// parameter — become affordable and well-covered.
1551    ///
1552    /// `coords[k]` is atom `k`'s `n × d_k` latent coordinates (the seed coords, or
1553    /// a previous encode's output). Charts are chosen by greedy farthest-point
1554    /// sampling over those coords (deterministic, coverage-maximizing), capped at
1555    /// `max_charts`. Each chart's nominal radius is half the distance to its
1556    /// nearest neighbor center, so the charts tile the local data density. The
1557    /// per-chart Kantorovich certification is IDENTICAL to the regular grid — only
1558    /// the center placement differs.
1559    pub fn build_data_driven(
1560        atoms: &[SaeManifoldAtom],
1561        coords: &[Array2<f64>],
1562        amplitude_bound: &[f64],
1563        target_norm_bound: f64,
1564        max_charts: usize,
1565        config: AtlasConfig,
1566    ) -> Result<Self, String> {
1567        if amplitude_bound.len() != atoms.len() || coords.len() != atoms.len() {
1568            return Err(format!(
1569                "build_data_driven: amplitude_bound {} / coords {} must match atom count {}",
1570                amplitude_bound.len(),
1571                coords.len(),
1572                atoms.len()
1573            ));
1574        }
1575        let mut atom_atlases = Vec::with_capacity(atoms.len());
1576        for (k, atom) in atoms.iter().enumerate() {
1577            let (centers, radii) =
1578                data_driven_chart_centers(atom, coords[k].view(), max_charts.max(1))?;
1579            let atlas = Self::build_atom_atlas_from_centers(
1580                k,
1581                atom,
1582                centers.view(),
1583                &radii,
1584                amplitude_bound[k],
1585                target_norm_bound,
1586                &config,
1587            )?;
1588            atom_atlases.push(atlas);
1589        }
1590        Ok(Self {
1591            atoms: atom_atlases,
1592            config,
1593        })
1594    }
1595
1596    fn refine_certified_encode_start(
1597        &self,
1598        atom: &SaeManifoldAtom,
1599        evaluator: &dyn SaeBasisEvaluator,
1600        chart: &CertifiedChart,
1601        t: Array1<f64>,
1602        x: ArrayView1<'_, f64>,
1603        amplitude: f64,
1604    ) -> Result<(Array1<f64>, RowCertificate), String> {
1605        // Certify from the warm start, navigating into the Kantorovich basin first
1606        // if the unit-amplitude start has h > ½ (see `certify_with_basin_warmup`).
1607        let Some(probe) = certify_with_basin_warmup(
1608            atom,
1609            evaluator,
1610            t,
1611            x,
1612            amplitude,
1613            chart.lipschitz,
1614            self.config.ridge,
1615            self.config.newton_steps,
1616            chart.region.center.view(),
1617            chart.region.radius,
1618        )?
1619        else {
1620            return Ok((
1621                Array1::<f64>::zeros(atom.latent_dim),
1622                uncertified_certificate(chart.lipschitz),
1623            ));
1624        };
1625        Ok((probe.coord, probe.initial_cert))
1626    }
1627
1628    /// Online certified encode of one target row `x` against one atom `k` with
1629    /// fixed amplitude `z`. Routes to the nearest chart, starts from that chart's
1630    /// distilled IFT warm start, runs `config.newton_steps` Newton steps, and
1631    /// returns the encoded coordinate with its certificate. An uncertified start
1632    /// (no chart, no distilled Jacobian, non-positive amplitude, or `h > ½`)
1633    /// flags the row for the exact multi-start caller.
1634    pub fn certified_encode_row(
1635        &self,
1636        atom: &SaeManifoldAtom,
1637        atom_index: usize,
1638        x: ArrayView1<'_, f64>,
1639        amplitude: f64,
1640    ) -> Result<(Array1<f64>, RowCertificate), String> {
1641        let atom_atlas = self
1642            .atoms
1643            .get(atom_index)
1644            .ok_or_else(|| format!("certified_encode_row: atom {atom_index} not in atlas"))?;
1645        let d = atom.latent_dim;
1646        // A missing basis evaluator means the amortized/cold predictor cannot fire
1647        // for this atom (e.g. a frozen-baseline or first-build atom that never
1648        // attached a distilled evaluator). That is exactly the "cannot certify"
1649        // state — flag the row uncertified (zeros coords, ∞ certificate) so the
1650        // upstream exact multi-start solve owns it, never a hard error that aborts
1651        // the whole criterion. Mirrors the no-chart / singular-Jacobian branches.
1652        let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
1653            return Ok((
1654                Array1::<f64>::zeros(d),
1655                RowCertificate {
1656                    beta: f64::INFINITY,
1657                    eta: f64::INFINITY,
1658                    lipschitz: f64::INFINITY,
1659                    h: f64::INFINITY,
1660                },
1661            ));
1662        };
1663
1664        // Route to the nearest chart centers by AMBIENT reconstruction distance.
1665        // A single nearest chart is NOT globally sound on self-approaching atoms:
1666        // where the decoded manifold folds near itself (two distant latent points
1667        // map near the same output), the nearest-center chart can certify into the
1668        // locally-worse basin while another chart holds the GLOBAL minimum (both
1669        // branches' charts reconstruct near the crossing, so both are near in
1670        // ambient distance). The certificate is honest about LOCAL convergence but
1671        // cannot see the better far basin. So we refine in the top-K nearest charts
1672        // and keep the lowest-reconstruction-error CERTIFIED result. For a unimodal
1673        // atom every candidate chart converges to the same root, so this is a no-op
1674        // (first-wins tie → the nearest chart), preserving the existing behavior.
1675        let candidates =
1676            nearest_charts_topk(atom_atlas, x, CERTIFIED_ROUTING_TOPK);
1677        if candidates.is_empty() {
1678            return Ok((
1679                Array1::<f64>::zeros(d),
1680                RowCertificate {
1681                    beta: f64::INFINITY,
1682                    eta: f64::INFINITY,
1683                    lipschitz: f64::INFINITY,
1684                    h: f64::INFINITY,
1685                },
1686            ));
1687        }
1688        // Best CERTIFIED result by reconstruction error, plus the nearest chart's
1689        // result as the uncertified fallback (preserving the prior return when no
1690        // candidate certifies — the nearest chart owns the flagged row).
1691        let mut best: Option<(Array1<f64>, RowCertificate, f64)> = None;
1692        let mut nearest_fallback: Option<(Array1<f64>, RowCertificate)> = None;
1693        for chart_idx in candidates {
1694            let chart = &atom_atlas.charts[chart_idx];
1695            let Some(t) = amortized_warm_start(chart, x, amplitude) else {
1696                if nearest_fallback.is_none() {
1697                    nearest_fallback =
1698                        Some((Array1::<f64>::zeros(d), uncertified_certificate(chart.lipschitz)));
1699                }
1700                continue;
1701            };
1702            let (coord, cert) = self.refine_certified_encode_start(
1703                atom,
1704                evaluator.as_ref(),
1705                chart,
1706                t,
1707                x,
1708                amplitude,
1709            )?;
1710            if nearest_fallback.is_none() {
1711                nearest_fallback = Some((coord.clone(), cert.clone()));
1712            }
1713            if cert.certified() {
1714                let err =
1715                    encode_reconstruction_error(atom, evaluator.as_ref(), coord.view(), x, amplitude);
1716                if best.as_ref().map(|(_, _, e)| err < *e).unwrap_or(true) {
1717                    best = Some((coord, cert, err));
1718                }
1719                // Global-minimum short-circuit: reconstruction error ≥ 0, so a
1720                // certified candidate already at the ambient noise floor is provably
1721                // the global optimum over the remaining charts — stop refining them.
1722                if let Some((_, _, e)) = best.as_ref() {
1723                    if *e <= CERTIFIED_GLOBAL_MIN_RECON_FLOOR * (1.0 + x.dot(&x).sqrt()) {
1724                        break;
1725                    }
1726                }
1727            }
1728        }
1729        match best {
1730            Some((coord, cert, _)) => Ok((coord, cert)),
1731            None => Ok(nearest_fallback.unwrap_or_else(|| {
1732                (
1733                    Array1::<f64>::zeros(d),
1734                    RowCertificate {
1735                        beta: f64::INFINITY,
1736                        eta: f64::INFINITY,
1737                        lipschitz: f64::INFINITY,
1738                        h: f64::INFINITY,
1739                    },
1740                )
1741            })),
1742        }
1743    }
1744
1745    /// Amortized (distilled) encode of one target row `x` against one atom `k`
1746    /// with fixed amplitude `z` (#1026 ladder item 3).
1747    ///
1748    /// Routes to the nearest chart, then predicts the latent coordinate in CLOSED
1749    /// FORM from that chart's precomputed implicit-function-theorem Jacobian:
1750    ///
1751    /// ```text
1752    /// t̂ = t_c + (1/z) · A₁ · (x − z · m₁(t_c)),
1753    /// ```
1754    ///
1755    /// a single `O(d·p)` mat-vec — no per-row Hessian factorization or
1756    /// eigendecomposition, which is the amortization. The Kantorovich
1757    /// certificate is then evaluated AT the predicted start `t̂` with the chart's
1758    /// closed-form Lipschitz constant. A prediction is accepted only when that
1759    /// certificate holds, an independent cold chart-center probe also certifies,
1760    /// and the two refined coordinates agree within the two probes' final
1761    /// Kantorovich root-radius bounds. This keeps the distilled path honest
1762    /// without letting the exact probe reuse the distilled warm start it is
1763    /// auditing. A chart without a distilled Jacobian (singular Gauss–Newton
1764    /// block) flags the row.
1765    pub fn amortized_encode_row(
1766        &self,
1767        atom: &SaeManifoldAtom,
1768        atom_index: usize,
1769        x: ArrayView1<'_, f64>,
1770        amplitude: f64,
1771    ) -> Result<(Array1<f64>, RowCertificate), String> {
1772        let atom_atlas = self
1773            .atoms
1774            .get(atom_index)
1775            .ok_or_else(|| format!("amortized_encode_row: atom {atom_index} not in atlas"))?;
1776        let d = atom.latent_dim;
1777        let uncertified = || {
1778            (
1779                Array1::<f64>::zeros(d),
1780                RowCertificate {
1781                    beta: f64::INFINITY,
1782                    eta: f64::INFINITY,
1783                    lipschitz: f64::INFINITY,
1784                    h: f64::INFINITY,
1785                },
1786            )
1787        };
1788        // A missing basis evaluator means the distilled predictor cannot fire for
1789        // this atom — flag the row uncertified (the exact upstream solve owns it)
1790        // rather than erroring, exactly as the no-chart / singular-Jacobian /
1791        // non-positive-amplitude branches below do. Never a silent wrong encode,
1792        // never a hard abort of the criterion.
1793        let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
1794            return Ok(uncertified());
1795        };
1796        let Some((chart_idx, _)) = nearest_chart(atom_atlas, x) else {
1797            return Ok(uncertified());
1798        };
1799        let chart = &atom_atlas.charts[chart_idx];
1800        // Closed-form predicted start t̂ = t_c + (1/z)·A₁·(x − z·m₁). `None` when
1801        // the chart's Gauss–Newton block was singular (no distilled Jacobian, so
1802        // the amortized predictor cannot fire) or the amplitude is not strictly
1803        // positive and finite (a near-inactive atom, where the amplitude-divided
1804        // map is undefined) — either way flag for the exact fallback, never a
1805        // silent wrong encode.
1806        let Some(t_hat) = amortized_warm_start(chart, x, amplitude) else {
1807            return Ok(uncertified());
1808        };
1809        // Evaluate the SAME Kantorovich certificate at the predicted start. The
1810        // amortized prediction is trusted only if this certificate holds AND an
1811        // independent cold chart-center probe certifies and agrees below the
1812        // two probes' final Kantorovich root-radius bounds. This avoids the
1813        // self-referential gate where the "exact" probe is warm-started by the
1814        // same distilled prediction it is supposed to audit.
1815        let Some(amortized_probe) = certify_with_basin_warmup(
1816            atom,
1817            evaluator.as_ref(),
1818            t_hat,
1819            x,
1820            amplitude,
1821            chart.lipschitz,
1822            self.config.ridge,
1823            self.config.newton_steps,
1824            chart.region.center.view(),
1825            chart.region.radius,
1826        )?
1827        else {
1828            return Ok((
1829                Array1::<f64>::zeros(d),
1830                uncertified_certificate(chart.lipschitz),
1831            ));
1832        };
1833
1834        let cold_start = chart.region.center.clone();
1835        let Some(cold_probe) = certify_with_basin_warmup(
1836            atom,
1837            evaluator.as_ref(),
1838            cold_start,
1839            x,
1840            amplitude,
1841            chart.lipschitz,
1842            self.config.ridge,
1843            self.config.newton_steps,
1844            chart.region.center.view(),
1845            chart.region.radius,
1846        )?
1847        else {
1848            return Ok((
1849                amortized_probe.coord,
1850                uncertified_certificate(chart.lipschitz),
1851            ));
1852        };
1853
1854        let gap =
1855            latent_coordinate_distance(atom, amortized_probe.coord.view(), cold_probe.coord.view());
1856        let tolerance = distilled_probe_tolerance(&amortized_probe, &cold_probe, amplitude, x);
1857        if !(gap.is_finite() && gap <= tolerance) {
1858            return Ok((
1859                amortized_probe.coord,
1860                uncertified_certificate(chart.lipschitz),
1861            ));
1862        }
1863        Ok((amortized_probe.coord, amortized_probe.initial_cert))
1864    }
1865
1866    /// Batched amortized (distilled) encode over many rows against one atom
1867    /// (#1026 ladder item 3, corpus-rate). Each row uses the closed-form
1868    /// per-chart Jacobian predictor and carries its own Kantorovich certificate;
1869    /// uncertified rows are flagged in [`EncodeResult::encode_uncertified_count`]
1870    /// for the exact multi-start fallback. Row-independent against the frozen
1871    /// dictionary, so the batch fans out over rows (deterministic row-order
1872    /// assembly, bit-identical run-to-run), staying sequential inside a rayon
1873    /// worker to avoid nested oversubscription.
1874    pub fn amortized_encode_batch(
1875        &self,
1876        atom: &SaeManifoldAtom,
1877        atom_index: usize,
1878        targets: ArrayView2<'_, f64>,
1879        amplitudes: ArrayView1<'_, f64>,
1880    ) -> Result<EncodeResult, String> {
1881        let n = targets.nrows();
1882        if amplitudes.len() != n {
1883            return Err(format!(
1884                "amortized_encode_batch: amplitudes len {} != rows {n}",
1885                amplitudes.len()
1886            ));
1887        }
1888        let d = atom.latent_dim;
1889        let encode_rows =
1890            |range: std::ops::Range<usize>| -> Result<Vec<(Array1<f64>, bool)>, String> {
1891                range
1892                    .map(|row| {
1893                        let (t, cert) = self.amortized_encode_row(
1894                            atom,
1895                            atom_index,
1896                            targets.row(row),
1897                            amplitudes[row],
1898                        )?;
1899                        Ok((t, cert.certified()))
1900                    })
1901                    .collect()
1902            };
1903        let rows: Vec<(Array1<f64>, bool)> =
1904            if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
1905                use rayon::prelude::*;
1906                const CHUNK: usize = 256;
1907                let n_chunks = n.div_ceil(CHUNK);
1908                let chunked: Vec<Vec<(Array1<f64>, bool)>> = (0..n_chunks)
1909                    .into_par_iter()
1910                    .map(|c| {
1911                        let start = c * CHUNK;
1912                        let end = (start + CHUNK).min(n);
1913                        encode_rows(start..end)
1914                    })
1915                    .collect::<Result<_, _>>()?;
1916                chunked.into_iter().flatten().collect()
1917            } else {
1918                encode_rows(0..n)?
1919            };
1920        let mut coords = Array2::<f64>::zeros((n, d));
1921        let mut certified = Vec::with_capacity(n);
1922        for (row, (t, cert)) in rows.into_iter().enumerate() {
1923            coords.row_mut(row).assign(&t);
1924            certified.push(cert);
1925        }
1926        Ok(EncodeResult::from_rows(coords, certified))
1927    }
1928
1929    /// Batched certified encode over many rows against one atom (the #988
1930    /// throughput consumer). Each row carries its own certificate; uncertified
1931    /// rows are flagged in [`EncodeResult::encode_uncertified_count`] for the
1932    /// exact multi-start fallback.
1933    pub fn certified_encode_batch(
1934        &self,
1935        atom: &SaeManifoldAtom,
1936        atom_index: usize,
1937        targets: ArrayView2<'_, f64>,
1938        amplitudes: ArrayView1<'_, f64>,
1939    ) -> Result<EncodeResult, String> {
1940        let n = targets.nrows();
1941        if amplitudes.len() != n {
1942            return Err(format!(
1943                "certified_encode_batch: amplitudes len {} != rows {n}",
1944                amplitudes.len()
1945            ));
1946        }
1947        let d = atom.latent_dim;
1948        // Per-row encode is independent against a frozen dictionary (#1010), so
1949        // the corpus-rate batch fans out over rows (#1026 amortized-encoder leg /
1950        // #977 Stage-3 corpus encode). Each row produces an owned `(t, certified)`
1951        // pair; results are assembled back in row order so the output is
1952        // bit-identical run-to-run regardless of thread scheduling. Stay
1953        // sequential inside a rayon worker (e.g. when an outer atom-level fan-out
1954        // owns the pool) to avoid nested oversubscription. The first row that
1955        // fails to encode propagates its error deterministically.
1956        let encode_rows =
1957            |range: std::ops::Range<usize>| -> Result<Vec<(Array1<f64>, bool)>, String> {
1958                range
1959                    .map(|row| {
1960                        let (t, cert) = self.certified_encode_row(
1961                            atom,
1962                            atom_index,
1963                            targets.row(row),
1964                            amplitudes[row],
1965                        )?;
1966                        Ok((t, cert.certified()))
1967                    })
1968                    .collect()
1969            };
1970        let rows: Vec<(Array1<f64>, bool)> =
1971            if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
1972                use rayon::prelude::*;
1973                const CHUNK: usize = 256;
1974                let n_chunks = n.div_ceil(CHUNK);
1975                let chunked: Vec<Vec<(Array1<f64>, bool)>> = (0..n_chunks)
1976                    .into_par_iter()
1977                    .map(|c| {
1978                        let start = c * CHUNK;
1979                        let end = (start + CHUNK).min(n);
1980                        encode_rows(start..end)
1981                    })
1982                    .collect::<Result<_, _>>()?;
1983                chunked.into_iter().flatten().collect()
1984            } else {
1985                encode_rows(0..n)?
1986            };
1987        let mut coords = Array2::<f64>::zeros((n, d));
1988        let mut certified = Vec::with_capacity(n);
1989        for (row, (t, cert)) in rows.into_iter().enumerate() {
1990            coords.row_mut(row).assign(&t);
1991            certified.push(cert);
1992        }
1993        Ok(EncodeResult::from_rows(coords, certified))
1994    }
1995
1996    /// Batched GEMM "fast" amortized encode — the traditional-encoder forward
1997    /// pass, WITH manifolds. For every row this applies the SAME closed-form
1998    /// affine predictor as [`amortized_warm_start`]
1999    /// (`t̂ = t_c + (1/z)·A₁·(x − z·m₁)`), but routed and applied as batched
2000    /// matrix products instead of a per-row loop wrapped in the Kantorovich
2001    /// certificate + basin warmup. NO per-row certificate is taken: this is the
2002    /// speed mode (the certified `*_encode_*` paths remain the accuracy mode).
2003    ///
2004    /// Cost is GEMM-bound: one `(n × p)·(p × d)` decode-distance product for
2005    /// nearest-chart routing (skipped for single-chart atoms) plus, per chart,
2006    /// one `(n_c × p)·(p × d)` predictor product — i.e. `≈ X·Wᵀ`, exactly a
2007    /// dense SAE encoder's forward map.
2008    ///
2009    /// Degenerate rows are handled exactly as `amortized_warm_start` flags them
2010    /// (returns `None` ⇒ zeroed coord here): a missing basis evaluator, a chart
2011    /// whose Gauss–Newton block was singular (`amortized_jacobian == None`), or a
2012    /// non-finite / non-positive amplitude. Those rows are zeroed (never a panic,
2013    /// never a silent wrong encode), and their indices are returned in the
2014    /// `valid` mask so the caller can route them to the exact path if desired.
2015    ///
2016    /// Returns `(coords, valid)` where `coords` is `n × d` and `valid[row]` is
2017    /// `true` iff the amortized predictor fired for that row.
2018    pub fn amortized_encode_batch_fast(
2019        &self,
2020        atom: &SaeManifoldAtom,
2021        atom_index: usize,
2022        x: ArrayView2<'_, f64>,
2023        amplitudes: ArrayView1<'_, f64>,
2024    ) -> Result<(Array2<f64>, Vec<bool>), String> {
2025        let n = x.nrows();
2026        let p = atom.output_dim();
2027        let d = atom.latent_dim;
2028        if x.ncols() != p {
2029            return Err(format!(
2030                "amortized_encode_batch_fast: x has {} cols but atom output dim is {p}",
2031                x.ncols()
2032            ));
2033        }
2034        if amplitudes.len() != n {
2035            return Err(format!(
2036                "amortized_encode_batch_fast: amplitudes len {} != rows {n}",
2037                amplitudes.len()
2038            ));
2039        }
2040        let atom_atlas = self.atoms.get(atom_index).ok_or_else(|| {
2041            format!("amortized_encode_batch_fast: atom {atom_index} not in atlas")
2042        })?;
2043        let mut coords = Array2::<f64>::zeros((n, d));
2044        let mut valid = vec![false; n];
2045
2046        // A missing basis evaluator means this atom never had a well-formed atlas
2047        // built here — treat every row as uncertified (zeroed), exactly like the
2048        // per-row `amortized_encode_row` no-evaluator branch. (The predictor below
2049        // uses only cached atlas data, so no evaluator call is made online.)
2050        if atom.basis_evaluator.is_none() {
2051            return Ok((coords, valid));
2052        }
2053
2054        // ── Routing recon-centers (cached, no online basis evaluation). ────────
2055        // Routing sends a row to the chart whose center reconstruction
2056        // `m(t_c) = BᵀΦ(t_c)` is closest in ‖·‖². Those center reconstructions are
2057        // OFFLINE-cached in `chart.recon_center` (bit-identical to re-evaluating the
2058        // basis at the fixed centers — same φ·decoder accumulation). Gather the cache
2059        // instead of calling `evaluator.evaluate` on every invocation: that per-call
2060        // chart-center evaluation was the dominant per-atom-group overhead at massive
2061        // K, where N rows scatter across many atoms into tiny groups so a fixed
2062        // per-call cost is amortized over only a handful of rows. This is what keeps
2063        // the fast index-routed encode near-flat as K grows.
2064        let valid_charts: Vec<usize> = (0..atom_atlas.charts.len())
2065            .filter(|&c| atom_atlas.charts[c].certified_radius > 0.0)
2066            .collect();
2067        if valid_charts.is_empty() {
2068            return Ok((coords, valid));
2069        }
2070        // recon_centers (C × p): the cached m(t_c) for each certifiable chart.
2071        let mut recon_centers = Array2::<f64>::zeros((valid_charts.len(), p));
2072        for (ci, &c) in valid_charts.iter().enumerate() {
2073            recon_centers
2074                .row_mut(ci)
2075                .assign(&atom_atlas.charts[c].recon_center);
2076        }
2077        // Per-chart routing key: route_idx[row] = argmin_c ‖x_row − recon_c‖².
2078        // ‖x − r‖² = ‖x‖² − 2 x·r + ‖r‖²; the ‖x‖² term is row-constant so the
2079        // argmin uses S = X·recon_centersᵀ and the per-chart ‖r‖². First chart
2080        // wins on a tie (strict `<`), matching `nearest_chart`.
2081        let route_idx: Vec<usize> = if valid_charts.len() == 1 {
2082            vec![0usize; n]
2083        } else {
2084            let s = x.dot(&recon_centers.t()); // (n × C)
2085            let r_sq: Vec<f64> = (0..valid_charts.len())
2086                .map(|c| recon_centers.row(c).dot(&recon_centers.row(c)))
2087                .collect();
2088            (0..n)
2089                .map(|row| {
2090                    let mut best_c = 0usize;
2091                    let mut best_d = f64::INFINITY;
2092                    for c in 0..valid_charts.len() {
2093                        let dist = r_sq[c] - 2.0 * s[[row, c]];
2094                        if dist < best_d {
2095                            best_d = dist;
2096                            best_c = c;
2097                        }
2098                    }
2099                    best_c
2100                })
2101                .collect()
2102        };
2103
2104        // ── Per-chart batched affine predictor. ───────────────────────────────
2105        // For rows routed to chart `c` with finite jacobian `A₁` (d × p) and
2106        // center reconstruction `m₁` (= `chart.recon_center`), the predictor is
2107        //   t̂ = t_c − A₁·m₁ + (1/z)·(A₁·x).
2108        // `t_c − A₁·m₁` is a per-chart constant `base`; `A₁·x` is a d-vector of
2109        // per-row dot products. Instead of gathering routed rows into a fresh
2110        // `X_c` (n_c × p) buffer and running a GEMM into a second `U` (n_c × d)
2111        // buffer — two allocations plus a full copy of the routed rows, per chart —
2112        // fuse the gather straight into the multiply: stream each source row of `x`
2113        // once (it is contiguous) and dot it against `A₁`'s rows, writing the
2114        // predicted coord directly. Zero per-chart heap traffic; the inverse
2115        // amplitude is hoisted to one reciprocal per row.
2116        //
2117        // Precompute each valid chart's `(A₁, base)` once (charts with a singular
2118        // Gauss–Newton block carry no `A₁`, so their routed rows stay
2119        // zeroed/uncertified — same as `amortized_warm_start` returning `None`).
2120        struct ChartPredictor<'a> {
2121            a1: &'a Array2<f64>,
2122            base: &'a Array1<f64>,
2123        }
2124        let predictors: Vec<Option<ChartPredictor<'_>>> = valid_charts
2125            .iter()
2126            .map(|&c| {
2127                let chart = &atom_atlas.charts[c];
2128                // `base = t_c − A₁·m₁` is precomputed offline in the atlas; reuse it
2129                // (both are `Some` together — singular G-N block ⇒ both `None`).
2130                match (chart.amortized_jacobian.as_ref(), chart.amortized_base.as_ref()) {
2131                    (Some(a1), Some(base)) => Some(ChartPredictor { a1, base }),
2132                    _ => None,
2133                }
2134            })
2135            .collect();
2136
2137        for row in 0..n {
2138            let Some(pred) = predictors[route_idx[row]].as_ref() else {
2139                continue;
2140            };
2141            let amp = amplitudes[row];
2142            if !(amp.is_finite() && amp.abs() > 0.0) {
2143                continue;
2144            }
2145            let inv_z = 1.0 / amp;
2146            let x_row = x.row(row);
2147            let mut coord_row = coords.row_mut(row);
2148            for axis in 0..d {
2149                // (A₁·x)[axis] = A₁ row `axis` (contiguous, length p) · x_row.
2150                coord_row[axis] = pred.base[axis] + pred.a1.row(axis).dot(&x_row) * inv_z;
2151            }
2152            valid[row] = true;
2153        }
2154        Ok((coords, valid))
2155    }
2156
2157    /// Fast batched FULL forward pass against one atom: encode → decode, the
2158    /// manifold analogue of a traditional SAE's `x̂ = z·D` (decoder `D`, code `z`).
2159    ///
2160    /// A traditional SAE decodes with one GEMM. The manifold SAE's reconstruction
2161    /// is `m(t̂) = z·Φ(t̂)·B` (module header) — the SAME GEMM `Φ·B`, but the code
2162    /// `Φ(t̂)` is the curved chart basis evaluated at the encoded latent coordinate
2163    /// rather than a flat one-hot. So the fast forward is exactly:
2164    ///   1. [`amortized_encode_batch_fast`] → per-row latent coords `t̂` (one
2165    ///      routing GEMM + one affine GEMM per chart — a traditional `W·x+b`);
2166    ///   2. ONE batched basis evaluation `Φ(t̂)` (the manifold-curvature step a
2167    ///      flat SAE doesn't have — `n×m`);
2168    ///   3. ONE GEMM `recon = Φ(t̂)·B` (`(n×m)·(m×p)` — a traditional decoder
2169    ///      `z·D`), then the per-row amplitude scale `z`.
2170    ///
2171    /// Rows the encoder could not certify-predict (no evaluator / singular
2172    /// Gauss–Newton block / non-finite-or-zero amplitude) are returned as a ZERO
2173    /// reconstruction and flagged `false` in the valid-mask — never a silent wrong
2174    /// decode. The reconstruction of a valid row equals, bit-for-bit up to GEMM
2175    /// reassociation, `z·(Φ(t̂_row)·B)` with `t̂` from the per-row predictor.
2176    pub fn amortized_reconstruct_batch_fast(
2177        &self,
2178        atom: &SaeManifoldAtom,
2179        atom_index: usize,
2180        x: ArrayView2<'_, f64>,
2181        amplitudes: ArrayView1<'_, f64>,
2182    ) -> Result<(Array2<f64>, Vec<bool>), String> {
2183        let n = x.nrows();
2184        let p = atom.output_dim();
2185        // Step 1: batched encode → latent coords (reuses the fast routing+affine).
2186        let (coords, valid) = self.amortized_encode_batch_fast(atom, atom_index, x, amplitudes)?;
2187        let mut recon = Array2::<f64>::zeros((n, p));
2188        // A missing evaluator means no row could encode — every row is zeroed and
2189        // already flagged `false` by the encode; nothing to decode.
2190        let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
2191            return Ok((recon, valid));
2192        };
2193        // Step 2: ONE batched basis evaluation Φ(t̂) over all rows (n × m). Invalid
2194        // rows carry coords = 0 (the chart-origin); we still evaluate them in the
2195        // batch for a single GEMM, then zero their reconstruction below — the basis
2196        // is finite at the origin so this cannot poison the valid rows' GEMM.
2197        let (phi, _jet) = evaluator
2198            .evaluate(coords.view())
2199            .map_err(|err| format!("amortized_reconstruct_batch_fast: basis eval: {err}"))?;
2200        // Step 3: ONE GEMM recon = Φ·B (n × p), then per-row amplitude scale z.
2201        // m(t̂) = z·Φ(t̂)·B, matching the module header and `fill_decoded_row`'s
2202        // `Φ·decoder` accumulation (the amplitude is applied once here).
2203        let decoded = phi.dot(&atom.decoder_coefficients); // (n × p), amplitude-1
2204        for row in 0..n {
2205            if !valid[row] {
2206                continue; // stays zeroed — uncertified, like warm_start `None`.
2207            }
2208            let z = amplitudes[row];
2209            for col in 0..p {
2210                recon[[row, col]] = z * decoded[[row, col]];
2211            }
2212        }
2213        Ok((recon, valid))
2214    }
2215
2216    /// LSH-routed certified encode (issue #1010 step 2 + 3): for each target
2217    /// row, the existing [`SaeCandidateIndex`] (#985/#994) proposes the
2218    /// best-aligned atom by frame alignment to the row direction; the row is then
2219    /// encoded against THAT atom's certified chart atlas. This is the production
2220    /// routing path. Atom selection is EXACT (#1777): [`SaeCandidateIndex::route_exact`]
2221    /// returns the global argmax of the routing score (the universal-bound LSH fast
2222    /// path, else a full-scan fallback) — never a silently-missed ungathered atom —
2223    /// and the atlas does the in-atom nearest-chart routing and the per-row
2224    /// Kantorovich certificate.
2225    ///
2226    /// `atoms[id]` must be aligned with the atlas's `atoms[id]` (same dictionary
2227    /// order the atlas was built from and the sketch/index were built over).
2228    /// A row over an empty dictionary, or whose globally-best atom aligns below the
2229    /// fit-quality floor, is flagged uncertified — it routes to the exact
2230    /// multi-start fallback, never a silent wrong encode.
2231    pub fn certified_encode_with_index<S: AtomFrameSketch + Sync>(
2232        &self,
2233        atoms: &[SaeManifoldAtom],
2234        index: &SaeCandidateIndex,
2235        sketch: &S,
2236        targets: ArrayView2<'_, f64>,
2237        amplitudes: ArrayView1<'_, f64>,
2238        latent_dim: usize,
2239    ) -> Result<EncodeResult, String> {
2240        let n = targets.nrows();
2241        if amplitudes.len() != n {
2242            return Err(format!(
2243                "certified_encode_with_index: amplitudes len {} != rows {n}",
2244                amplitudes.len()
2245            ));
2246        }
2247        let budget = auto_candidate_budget(atoms.len().max(1));
2248        // LSH-routed per-row encode is independent across rows (sublinear atom
2249        // selection + frozen-dictionary in-atom Newton), so the corpus-rate batch
2250        // fans out over rows (#1026 amortized-encoder/routing leg / #977 Stage-3).
2251        // `None` coords (no LSH candidate) carry through as a zeroed row flagged
2252        // uncertified — identical to the sequential semantics. Results assemble
2253        // back in row order (bit-identical run-to-run); the first encode error
2254        // propagates deterministically. Stay sequential inside a rayon worker to
2255        // avoid nested oversubscription.
2256        let encode_rows =
2257            |range: std::ops::Range<usize>| -> Result<Vec<Option<(Array1<f64>, bool)>>, String> {
2258                range
2259                    .map(|row| {
2260                        // EXACT routing (#1777): pick the GLOBAL argmax of the
2261                        // routing score over the whole dictionary, not merely the
2262                        // best LSH-gathered candidate. `route_exact` certifies the
2263                        // sublinear gather against the universal `[0,1]` alignment
2264                        // bound and falls back to a full scan otherwise, so the
2265                        // returned atom is guaranteed to be the globally-best — no
2266                        // silently-missed ungathered atom.
2267                        let Some(route) =
2268                            index.route_exact(sketch, targets.row(row), budget, true)
2269                        else {
2270                            // Empty dictionary: flag for the exact fallback.
2271                            return Ok(None);
2272                        };
2273                        let best_atom = route.atom;
2274                        // Fit-quality floor: even the globally-best atom may align
2275                        // only weakly with this row (no atom fits it). A finite
2276                        // alignment below the floor — or a NaN, the zero-norm
2277                        // ‖d‖ = 0 row — flags for the exact multi-start fallback
2278                        // rather than encoding against a poorly-fitting atom. This is
2279                        // a quality gate, not a routing-correctness gate; routing is
2280                        // already exact. See CANDIDATE_ROUTING_MIN_ALIGNMENT.
2281                        if !route.alignment.is_finite()
2282                            || route.alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT
2283                        {
2284                            return Ok(None);
2285                        }
2286                        let atom = atoms.get(best_atom).ok_or_else(|| {
2287                            format!(
2288                                "certified_encode_with_index: proposed atom {best_atom} out of range"
2289                            )
2290                        })?;
2291                        let (t, cert) = self.certified_encode_row(
2292                            atom,
2293                            best_atom,
2294                            targets.row(row),
2295                            amplitudes[row],
2296                        )?;
2297                        // Heterogeneous-atom dictionaries with different latent_dim
2298                        // per atom are not supported by the batched API: the caller
2299                        // declares one shared `latent_dim` for the output tensor.
2300                        // Silently zeroing the coord row while recording a
2301                        // certified=true flag would produce corrupted
2302                        // reconstructions downstream — error loudly instead.
2303                        if t.len() != latent_dim {
2304                            return Err(format!(
2305                                "certified_encode_with_index: atom {best_atom} returned t.len()={} \
2306                                 but declared latent_dim={latent_dim}; heterogeneous-dim \
2307                                 dictionaries are not supported by this batched encode path",
2308                                t.len()
2309                            ));
2310                        }
2311                        Ok(Some((t, cert.certified())))
2312                    })
2313                    .collect()
2314            };
2315        let rows: Vec<Option<(Array1<f64>, bool)>> =
2316            if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
2317                use rayon::prelude::*;
2318                const CHUNK: usize = 256;
2319                let n_chunks = n.div_ceil(CHUNK);
2320                let chunked: Vec<Vec<Option<(Array1<f64>, bool)>>> = (0..n_chunks)
2321                    .into_par_iter()
2322                    .map(|c| {
2323                        let start = c * CHUNK;
2324                        let end = (start + CHUNK).min(n);
2325                        encode_rows(start..end)
2326                    })
2327                    .collect::<Result<_, _>>()?;
2328                chunked.into_iter().flatten().collect()
2329            } else {
2330                encode_rows(0..n)?
2331            };
2332        let mut coords = Array2::<f64>::zeros((n, latent_dim));
2333        let mut certified = Vec::with_capacity(n);
2334        for (row, slot) in rows.into_iter().enumerate() {
2335            match slot {
2336                Some((t, cert)) => {
2337                    coords.row_mut(row).assign(&t);
2338                    certified.push(cert);
2339                }
2340                None => certified.push(false),
2341            }
2342        }
2343        Ok(EncodeResult::from_rows(coords, certified))
2344    }
2345
2346    /// LSH-routed AMORTIZED (distilled) encode — the production token-rate
2347    /// encoder of #1026 ladder item 3. Identical routing to
2348    /// [`Self::certified_encode_with_index`] (LSH proposes the best-aligned atom,
2349    /// the atlas routes to the in-atom nearest chart), but the in-atom encode is
2350    /// the closed-form per-chart Jacobian predictor + certificate gate of
2351    /// [`Self::amortized_encode_row`] rather than the certified Newton-refinement
2352    /// path.
2353    /// This is the deployment path: the distilled affine map produces the encode
2354    /// in one mat-vec, the Kantorovich certificate decides trust-or-fallback per
2355    /// row, and uncertified rows (the adversarial tail the thread expects to
2356    /// concentrate on rare tokens) are flagged for the exact multi-start solve —
2357    /// compute goes where the questions are. Row-independent against the frozen
2358    /// dictionary, so the batch fans out over rows with deterministic row-order
2359    /// assembly (bit-identical run-to-run).
2360    pub fn amortized_encode_with_index<S: AtomFrameSketch + Sync>(
2361        &self,
2362        atoms: &[SaeManifoldAtom],
2363        index: &SaeCandidateIndex,
2364        sketch: &S,
2365        targets: ArrayView2<'_, f64>,
2366        amplitudes: ArrayView1<'_, f64>,
2367        latent_dim: usize,
2368    ) -> Result<EncodeResult, String> {
2369        let n = targets.nrows();
2370        if amplitudes.len() != n {
2371            return Err(format!(
2372                "amortized_encode_with_index: amplitudes len {} != rows {n}",
2373                amplitudes.len()
2374            ));
2375        }
2376        let budget = auto_candidate_budget(atoms.len().max(1));
2377        let encode_rows =
2378            |range: std::ops::Range<usize>| -> Result<Vec<Option<(Array1<f64>, bool)>>, String> {
2379                range
2380                    .map(|row| {
2381                        // EXACT routing (#1777): global argmax of the routing score,
2382                        // not just the best LSH-gathered candidate (see
2383                        // certified_encode_with_index for the full rationale).
2384                        let Some(route) =
2385                            index.route_exact(sketch, targets.row(row), budget, true)
2386                        else {
2387                            return Ok(None);
2388                        };
2389                        let best_atom = route.atom;
2390                        // Fit-quality floor (not a routing-correctness gate; routing
2391                        // is exact): even the globally-best atom may fit a row poorly,
2392                        // and a NaN alignment is the zero-norm ‖d‖ = 0 row. Either way
2393                        // flag for the exact multi-start fallback. See
2394                        // CANDIDATE_ROUTING_MIN_ALIGNMENT.
2395                        if !route.alignment.is_finite()
2396                            || route.alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT
2397                        {
2398                            return Ok(None);
2399                        }
2400                        let atom = atoms.get(best_atom).ok_or_else(|| {
2401                            format!(
2402                                "amortized_encode_with_index: proposed atom {best_atom} out of range"
2403                            )
2404                        })?;
2405                        let (t, cert) = self.amortized_encode_row(
2406                            atom,
2407                            best_atom,
2408                            targets.row(row),
2409                            amplitudes[row],
2410                        )?;
2411                        if t.len() != latent_dim {
2412                            return Err(format!(
2413                                "amortized_encode_with_index: atom {best_atom} returned t.len()={} \
2414                                 but declared latent_dim={latent_dim}; heterogeneous-dim \
2415                                 dictionaries are not supported by this batched encode path",
2416                                t.len()
2417                            ));
2418                        }
2419                        Ok(Some((t, cert.certified())))
2420                    })
2421                    .collect()
2422            };
2423        let rows: Vec<Option<(Array1<f64>, bool)>> =
2424            if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
2425                use rayon::prelude::*;
2426                const CHUNK: usize = 256;
2427                let n_chunks = n.div_ceil(CHUNK);
2428                let chunked: Vec<Vec<Option<(Array1<f64>, bool)>>> = (0..n_chunks)
2429                    .into_par_iter()
2430                    .map(|c| {
2431                        let start = c * CHUNK;
2432                        let end = (start + CHUNK).min(n);
2433                        encode_rows(start..end)
2434                    })
2435                    .collect::<Result<_, _>>()?;
2436                chunked.into_iter().flatten().collect()
2437            } else {
2438                encode_rows(0..n)?
2439            };
2440        let mut coords = Array2::<f64>::zeros((n, latent_dim));
2441        let mut certified = Vec::with_capacity(n);
2442        for (row, slot) in rows.into_iter().enumerate() {
2443            match slot {
2444                Some((t, cert)) => {
2445                    coords.row_mut(row).assign(&t);
2446                    certified.push(cert);
2447                }
2448                None => certified.push(false),
2449            }
2450        }
2451        Ok(EncodeResult::from_rows(coords, certified))
2452    }
2453
2454    /// LSH-routed FAST amortized encode over the WHOLE dictionary — the
2455    /// multi-atom, corpus-rate analogue of [`Self::amortized_encode_with_index`].
2456    ///
2457    /// `amortized_encode_with_index` routes per row, then runs the per-row
2458    /// closed-form predictor + Kantorovich certificate + cold cross-check on each
2459    /// row independently. This fast variant keeps the SAME per-row EXACT routing
2460    /// (`index.route_exact` + the fit-quality floor), but replaces the per-row
2461    /// predictor with the GEMM-batched [`Self::amortized_encode_batch_fast`]:
2462    /// it GROUPS rows by their global-argmax atom and runs one batched affine-
2463    /// predictor pass per atom-group (a routing GEMM + a predictor GEMM each),
2464    /// reproducing a traditional SAE's whole-dictionary `W·x+b` throughput. No
2465    /// per-row certificate — this is the speed mode validated as accuracy-parity
2466    /// with the certified solve (`fast_forward_is_accuracy_parity_with_certified`).
2467    ///
2468    /// Returns the per-row latent coords and a valid-mask: `false` for a row with
2469    /// an empty dictionary, a sub-threshold/NaN routing alignment, or one the batched
2470    /// predictor could not fire on (no evaluator / singular Gauss–Newton block /
2471    /// non-finite-or-zero amplitude). Each row is written exactly once (disjoint
2472    /// per-atom groups), so the result is independent of group iteration order.
2473    pub fn amortized_encode_with_index_fast<S: AtomFrameSketch + Sync>(
2474        &self,
2475        atoms: &[SaeManifoldAtom],
2476        index: &SaeCandidateIndex,
2477        sketch: &S,
2478        targets: ArrayView2<'_, f64>,
2479        amplitudes: ArrayView1<'_, f64>,
2480        latent_dim: usize,
2481    ) -> Result<(Array2<f64>, Vec<bool>), String> {
2482        let n = targets.nrows();
2483        if amplitudes.len() != n {
2484            return Err(format!(
2485                "amortized_encode_with_index_fast: amplitudes len {} != rows {n}",
2486                amplitudes.len()
2487            ));
2488        }
2489        let budget = auto_candidate_budget(atoms.len().max(1));
2490        let mut coords = Array2::<f64>::zeros((n, latent_dim));
2491        let mut valid = vec![false; n];
2492        // ── Single allocation-free pass: route each row, apply the CACHED predictor.
2493        //
2494        // Routing sublinearity (massive-K, K≈32k): the certified path uses
2495        // `route_exact`, whose universal-bound LSH certificate only fires at the
2496        // alignment ceiling (≈1.0); for any real dictionary (`alignment < 1`) it
2497        // falls back to `brute_force_best_atom` — an O(K) full scan PER ROW, making
2498        // the encode O(N·K). This SPEED path takes the LSH gather's best-aligned atom
2499        // directly (`propose` scores only ~budget candidates → O(log K)); a rare miss
2500        // is caught by the fit-quality floor + the downstream certificate/exact
2501        // fallback (the documented speed/accuracy tradeoff).
2502        //
2503        // Allocation: the predictor uses only OFFLINE-cached atlas data (per-chart
2504        // `recon_center` for routing + `amortized_jacobian`/`amortized_base` for the
2505        // `t̂ = base + (1/z)·A₁·x` mat-vec), so NO per-row or per-atom heap buffer is
2506        // allocated. This replaces the old per-atom-group GEMM sub-batch — which at
2507        // `K ≫ N` degenerated to ONE row per group, so its per-group buffers (x_sub,
2508        // recon-centers, predictors) dominated and made the "fast" path allocation-
2509        // bound. Now the only per-row allocation is inside `index.propose`.
2510        for row in 0..n {
2511            let dir = targets.row(row);
2512            let proposal = index.propose(sketch, dir, budget, true);
2513            let Some(&best_atom) = proposal.proposed.first() else {
2514                continue; // nothing gathered (empty dictionary / probe-dim mismatch)
2515            };
2516            // Fit-quality floor: the best gathered atom still fits this row poorly,
2517            // or the alignment is NaN (zero-norm row) — flag for the exact fallback.
2518            let alignment = sketch.alignment(best_atom, dir);
2519            if !alignment.is_finite() || alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT {
2520                continue;
2521            }
2522            let atom = atoms.get(best_atom).ok_or_else(|| {
2523                format!("amortized_encode_with_index_fast: proposed atom {best_atom} out of range")
2524            })?;
2525            if atom.latent_dim != latent_dim {
2526                return Err(format!(
2527                    "amortized_encode_with_index_fast: atom {best_atom} latent_dim {} != declared \
2528                     {latent_dim}; heterogeneous-dim dictionaries are not supported by this path",
2529                    atom.latent_dim
2530                ));
2531            }
2532            let Some(atom_atlas) = self.atoms.get(best_atom) else {
2533                continue; // no atlas for this atom → predictor cannot fire (zeroed)
2534            };
2535            if amortized_predict_row(atom_atlas, dir, amplitudes[row], latent_dim, coords.row_mut(row))
2536            {
2537                valid[row] = true;
2538            }
2539        }
2540        Ok((coords, valid))
2541    }
2542
2543    /// LSH-routed FAST full forward over the WHOLE dictionary: encode → decode,
2544    /// the multi-atom analogue of [`Self::amortized_reconstruct_batch_fast`]. Same
2545    /// sublinear per-row routing + per-atom grouping as
2546    /// [`Self::amortized_encode_with_index_fast`], but each group is run through
2547    /// the batched reconstruct (`m(t̂) = z·Φ(t̂)·B`) so the result is the per-row
2548    /// reconstruction in the ambient space. Rows that do not route/predict decode
2549    /// to an exact zero reconstruction and are flagged `false`.
2550    pub fn amortized_reconstruct_with_index_fast<S: AtomFrameSketch + Sync>(
2551        &self,
2552        atoms: &[SaeManifoldAtom],
2553        index: &SaeCandidateIndex,
2554        sketch: &S,
2555        targets: ArrayView2<'_, f64>,
2556        amplitudes: ArrayView1<'_, f64>,
2557    ) -> Result<(Array2<f64>, Vec<bool>), String> {
2558        let n = targets.nrows();
2559        let p = targets.ncols();
2560        if amplitudes.len() != n {
2561            return Err(format!(
2562                "amortized_reconstruct_with_index_fast: amplitudes len {} != rows {n}",
2563                amplitudes.len()
2564            ));
2565        }
2566        let budget = auto_candidate_budget(atoms.len().max(1));
2567        // SUBLINEAR routing for the SPEED-mode full forward — mirror
2568        // `amortized_encode_with_index_fast`: take the LSH gather's best-aligned
2569        // atom (O(budget) candidates) instead of route_exact's O(K) full-scan
2570        // certification, keeping the whole fast encode→decode sublinear in K at
2571        // K=32k. The gather's best is the exact argmax on the vast majority of rows;
2572        // rare misses are caught by the fit-quality floor + downstream fallback.
2573        let mut groups: std::collections::HashMap<usize, Vec<usize>> =
2574            std::collections::HashMap::new();
2575        for row in 0..n {
2576            let dir = targets.row(row);
2577            let proposal = index.propose(sketch, dir, budget, true);
2578            let Some(&best_atom) = proposal.proposed.first() else {
2579                continue; // nothing gathered (empty dictionary / probe-dim mismatch)
2580            };
2581            let alignment = sketch.alignment(best_atom, dir);
2582            if !alignment.is_finite() || alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT {
2583                continue;
2584            }
2585            groups.entry(best_atom).or_default().push(row);
2586        }
2587
2588        let mut recon = Array2::<f64>::zeros((n, p));
2589        let mut valid = vec![false; n];
2590        for (atom_idx, rows_here) in groups {
2591            let atom = atoms.get(atom_idx).ok_or_else(|| {
2592                format!(
2593                    "amortized_reconstruct_with_index_fast: proposed atom {atom_idx} out of range"
2594                )
2595            })?;
2596            if atom.output_dim() != p {
2597                return Err(format!(
2598                    "amortized_reconstruct_with_index_fast: atom {atom_idx} output_dim {} != target \
2599                     dim {p}",
2600                    atom.output_dim()
2601                ));
2602            }
2603            let mut x_sub = Array2::<f64>::zeros((rows_here.len(), p));
2604            let mut amp_sub = Array1::<f64>::zeros(rows_here.len());
2605            for (i, &row) in rows_here.iter().enumerate() {
2606                x_sub.row_mut(i).assign(&targets.row(row));
2607                amp_sub[i] = amplitudes[row];
2608            }
2609            let (sub_recon, sub_valid) = self.amortized_reconstruct_batch_fast(
2610                atom,
2611                atom_idx,
2612                x_sub.view(),
2613                amp_sub.view(),
2614            )?;
2615            for (i, &row) in rows_here.iter().enumerate() {
2616                if sub_valid[i] {
2617                    recon.row_mut(row).assign(&sub_recon.row(i));
2618                    valid[row] = true;
2619                }
2620            }
2621        }
2622        Ok((recon, valid))
2623    }
2624}
2625
2626/// Offline `β = 1/λ_min(H_GN)` at a chart center from the Gauss-Newton block
2627/// `H_GN = J_mᵀ J_m` (residual-free). The offline `β` bounds the curvature the
2628/// online certificate sees: charts are placed where the encode lands, so the
2629/// representative residual is small and `H_GN` is the dominant, residual-free
2630/// curvature estimate. (The online per-row certificate still uses the FULL
2631/// Hessian; this is only the offline radius-sizing curvature.) Returns `None`
2632/// for a degenerate center (`λ_min ≤ 0`), which marks an uncertifiable chart.
2633pub(crate) fn center_beta(atom: &SaeManifoldAtom, center: &Array1<f64>, ridge: f64) -> Option<f64> {
2634    let evaluator = atom.basis_evaluator.as_ref()?.clone();
2635    let d = atom.latent_dim;
2636    let p = atom.output_dim();
2637    let m = atom.basis_size();
2638    let coords = center.view().to_shape((1, d)).ok()?.to_owned();
2639    let (_phi, jet) = evaluator.evaluate(coords.view()).ok()?;
2640    let decoder = &atom.decoder_coefficients;
2641    // J_m[axis] = Bᵀ (∂Φ/∂t_axis) ∈ ℝᵖ (amplitude-1; curvature scales with z²
2642    // and is absorbed conservatively by the amplitude-bounded Lipschitz term).
2643    let mut jm = Array2::<f64>::zeros((d, p));
2644    for axis in 0..d {
2645        for basis_col in 0..m {
2646            let dphi = jet[[0, basis_col, axis]];
2647            if dphi == 0.0 {
2648                continue;
2649            }
2650            for out in 0..p {
2651                jm[[axis, out]] += dphi * decoder[[basis_col, out]];
2652            }
2653        }
2654    }
2655    let mut h = Array2::<f64>::zeros((d, d));
2656    for a in 0..d {
2657        for b in 0..d {
2658            h[[a, b]] = jm.row(a).dot(&jm.row(b));
2659        }
2660        h[[a, a]] += ridge;
2661    }
2662    let (vals, _vecs) = h.eigh(Side::Lower).ok()?;
2663    let lambda_min = vals.iter().cloned().fold(f64::INFINITY, f64::min);
2664    if lambda_min.is_finite() && lambda_min > 0.0 {
2665        Some(1.0 / lambda_min)
2666    } else {
2667        None
2668    }
2669}
2670
2671/// #1154 — the amortized encoder's closed-form warm-start coordinate for one
2672/// row `x` against one chart at amplitude `z`:
2673///
2674/// ```text
2675/// t̂ = t_c + (1/z) · A₁ · (x − z · m₁(t_c)),
2676/// ```
2677///
2678/// a single `O(d·p)` mat-vec from the chart's precomputed IFT Jacobian `A₁` and
2679/// center reconstruction `m₁`. Returns `None` when the chart carries no
2680/// distilled Jacobian (singular Gauss–Newton block) or the amplitude is not
2681/// strictly positive and finite (a near-inactive atom, where the
2682/// amplitude-divided map is undefined) — in those cases the caller starts from
2683/// the chart center instead. Shared by the amortized encode (where `t̂` is the
2684/// prediction) and the exact certified encode (where `t̂` is the Newton
2685/// warm-start that then refines to stationarity, Design A).
2686pub(crate) fn amortized_warm_start(
2687    chart: &CertifiedChart,
2688    x: ArrayView1<'_, f64>,
2689    amplitude: f64,
2690) -> Option<Array1<f64>> {
2691    let a1 = chart.amortized_jacobian.as_ref()?;
2692    if !(amplitude.is_finite() && amplitude.abs() > 0.0) {
2693        return None;
2694    }
2695    let d = a1.nrows();
2696    let mut t_hat = chart.region.center.clone();
2697    for (out_idx, &m1_out) in chart.recon_center.iter().enumerate().take(a1.ncols()) {
2698        let resid = x[out_idx] - amplitude * m1_out;
2699        for axis in 0..d {
2700            t_hat[axis] += a1[[axis, out_idx]] * resid / amplitude;
2701        }
2702    }
2703    Some(t_hat)
2704}
2705
2706/// Single-row amortized predictor against ONE atom's cached atlas, writing the
2707/// encoded latent coordinate DIRECTLY into `out` (length `d`) with NO heap
2708/// allocation. Routes to the nearest certifiable chart by cached center-
2709/// reconstruction distance `‖x − m(t_c)‖²` (matching `amortized_encode_batch_fast`'s
2710/// per-chart routing), then applies that chart's precomputed affine predictor
2711/// `t̂ = base + (1/z)·A₁·x` (`base = t_c − A₁·m₁` is `chart.amortized_base`).
2712///
2713/// Returns `false` — leaving `out` at its incoming (zeroed) value — for exactly the
2714/// rows `amortized_encode_batch_fast` would flag: no certifiable chart, a nearest
2715/// chart with no distilled predictor (singular Gauss–Newton block), or an unusable
2716/// amplitude. This is the per-row core of the allocation-free massive-K fast encode.
2717pub(crate) fn amortized_predict_row(
2718    atom_atlas: &AtomEncodeAtlas,
2719    x: ArrayView1<'_, f64>,
2720    amplitude: f64,
2721    d: usize,
2722    mut out: ndarray::ArrayViewMut1<'_, f64>,
2723) -> bool {
2724    if !(amplitude.is_finite() && amplitude.abs() > 0.0) {
2725        return false;
2726    }
2727    // Nearest certifiable chart by ‖x − recon_center‖² (first-wins on ties, strict
2728    // `<` — same argmin as `amortized_encode_batch_fast`'s route_idx; the ‖x‖² term
2729    // it drops is row-constant and does not change the argmin).
2730    let mut best_ci: Option<usize> = None;
2731    let mut best_dist = f64::INFINITY;
2732    for (ci, chart) in atom_atlas.charts.iter().enumerate() {
2733        if chart.certified_radius <= 0.0 {
2734            continue;
2735        }
2736        let mut dist = 0.0;
2737        for (r, xv) in chart.recon_center.iter().zip(x.iter()) {
2738            let diff = r - xv;
2739            dist += diff * diff;
2740        }
2741        if dist < best_dist {
2742            best_dist = dist;
2743            best_ci = Some(ci);
2744        }
2745    }
2746    let Some(ci) = best_ci else {
2747        return false;
2748    };
2749    let chart = &atom_atlas.charts[ci];
2750    // The nearest chart must carry a distilled predictor; otherwise flag (zeroed),
2751    // exactly as the per-chart `None` branch of `amortized_encode_batch_fast`.
2752    let (Some(a1), Some(base)) = (
2753        chart.amortized_jacobian.as_ref(),
2754        chart.amortized_base.as_ref(),
2755    ) else {
2756        return false;
2757    };
2758    let inv_z = 1.0 / amplitude;
2759    for axis in 0..d {
2760        out[axis] = base[axis] + a1.row(axis).dot(&x) * inv_z;
2761    }
2762    true
2763}
2764
2765/// The amplitude-1 distilled amortized-encoder Jacobian at a chart center
2766/// (#1026 ladder item 3). Returns `(A₁, m₁)` where `m₁ = BᵀΦ(t_c) ∈ ℝᵖ` is the
2767/// amplitude-1 center reconstruction and `A₁ = (J₁ᵀJ₁ + ridge·I)⁻¹ J₁ ∈ ℝ^{d×p}`
2768/// is the implicit-function-theorem derivative of the encode map `x ↦ t`
2769/// (Gauss–Newton block — the residual-free, dominant curvature exactly as the
2770/// offline radius-sizing `β`). With these, the online encode of a row `x` at
2771/// amplitude `z` is the closed-form affine prediction
2772/// `t = t_c + (1/z)·A₁·(x − z·m₁)` — one mat-vec, no per-row factorization.
2773/// `None` when the basis has no jet or the Gauss–Newton block is singular (no
2774/// certifiable amortization), matching `center_beta`'s gate so a chart with a
2775/// finite `β` always carries a Jacobian and vice versa.
2776pub(crate) fn center_amortized_jacobian(
2777    atom: &SaeManifoldAtom,
2778    center: &Array1<f64>,
2779    ridge: f64,
2780) -> Option<(Array2<f64>, Array1<f64>)> {
2781    let evaluator = atom.basis_evaluator.as_ref()?.clone();
2782    let d = atom.latent_dim;
2783    let p = atom.output_dim();
2784    let m = atom.basis_size();
2785    let coords = center.view().to_shape((1, d)).ok()?.to_owned();
2786    let (phi, jet) = evaluator.evaluate(coords.view()).ok()?;
2787    let decoder = &atom.decoder_coefficients;
2788    // m₁(t_c) = BᵀΦ(t_c) ∈ ℝᵖ (amplitude-1 center reconstruction).
2789    let mut recon = Array1::<f64>::zeros(p);
2790    for basis_col in 0..m {
2791        let phi_v = phi[[0, basis_col]];
2792        if phi_v == 0.0 {
2793            continue;
2794        }
2795        for out in 0..p {
2796            recon[out] += phi_v * decoder[[basis_col, out]];
2797        }
2798    }
2799    // J₁[axis] = Bᵀ (∂Φ/∂t_axis) ∈ ℝᵖ (amplitude-1; z factors out analytically).
2800    let mut jm = Array2::<f64>::zeros((d, p));
2801    for axis in 0..d {
2802        for basis_col in 0..m {
2803            let dphi = jet[[0, basis_col, axis]];
2804            if dphi == 0.0 {
2805                continue;
2806            }
2807            for out in 0..p {
2808                jm[[axis, out]] += dphi * decoder[[basis_col, out]];
2809            }
2810        }
2811    }
2812    // H_GN = J₁ J₁ᵀ + ridge·I ∈ ℝ^{d×d}.
2813    let mut h = Array2::<f64>::zeros((d, d));
2814    for a in 0..d {
2815        for b in 0..d {
2816            h[[a, b]] = jm.row(a).dot(&jm.row(b));
2817        }
2818        h[[a, a]] += ridge;
2819    }
2820    let (vals, vecs) = h.eigh(Side::Lower).ok()?;
2821    let lambda_min = vals.iter().cloned().fold(f64::INFINITY, f64::min);
2822    if !(lambda_min.is_finite() && lambda_min > 0.0) {
2823        return None;
2824    }
2825    // A₁ = H_GN⁻¹ J₁ via the eigendecomposition: H⁻¹ = Σ_i (1/λᵢ) vᵢ vᵢᵀ, so
2826    // A₁[:, out] = Σ_i (vᵢ · J₁[:, out]) / λᵢ · vᵢ. Column-by-column keeps it the
2827    // d×p Jacobian (one SPD solve reused across all p output channels).
2828    let mut a1 = Array2::<f64>::zeros((d, p));
2829    for out in 0..p {
2830        let jcol = jm.column(out);
2831        for (i, &lam) in vals.iter().enumerate() {
2832            if !(lam.is_finite() && lam > 0.0) {
2833                return None;
2834            }
2835            let vi = vecs.column(i);
2836            let coeff = vi.dot(&jcol) / lam;
2837            for row in 0..d {
2838                a1[[row, out]] += coeff * vi[row];
2839            }
2840        }
2841    }
2842    Some((a1, recon))
2843}
2844
2845/// Route a target row to the nearest chart of an atom by reconstruction
2846/// distance: the chart whose center reconstruction `m(t_c)` is closest to `x`.
2847/// Returns the chart index and the distance, or `None` when the atom has no
2848/// charts.
2849pub(crate) fn nearest_chart(
2850    atom_atlas: &AtomEncodeAtlas,
2851    x: ArrayView1<'_, f64>,
2852) -> Option<(usize, f64)> {
2853    if atom_atlas.charts.is_empty() {
2854        return None;
2855    }
2856    let mut best: Option<(usize, f64)> = None;
2857    for (idx, chart) in atom_atlas.charts.iter().enumerate() {
2858        if chart.certified_radius <= 0.0 {
2859            continue;
2860        }
2861        // Reuse the offline-distilled `m(t_c) = B^T Phi(t_c)` (`chart.recon_center`)
2862        // instead of re-evaluating the basis at a fixed center per row — see
2863        // `nearest_charts_topk`. Distance accumulated in place, no temporary array.
2864        let mut dist = 0.0;
2865        for (r, xv) in chart.recon_center.iter().zip(x.iter()) {
2866            let diff = r - xv;
2867            dist += diff * diff;
2868        }
2869        if best.map(|(_, b)| dist < b).unwrap_or(true) {
2870            best = Some((idx, dist));
2871        }
2872    }
2873    best
2874}
2875
2876/// The `k` charts whose CENTER reconstruction `m(t_c)` is nearest to `x` in
2877/// ambient ‖·‖², returned as chart indices sorted by increasing distance (ties
2878/// broken by chart index — deterministic). Only certifiable charts
2879/// (`certified_radius > 0`) are considered, exactly like [`nearest_chart`], whose
2880/// single result is `nearest_charts_topk(.., 1)[0]`. Used by the certified encode
2881/// to refine the global basin on self-approaching atoms (see
2882/// [`CERTIFIED_ROUTING_TOPK`]).
2883pub(crate) fn nearest_charts_topk(
2884    atom_atlas: &AtomEncodeAtlas,
2885    x: ArrayView1<'_, f64>,
2886    k: usize,
2887) -> Vec<usize> {
2888    if atom_atlas.charts.is_empty() || k == 0 {
2889        return Vec::new();
2890    }
2891    let mut scored: Vec<(usize, f64)> = Vec::new();
2892    for (idx, chart) in atom_atlas.charts.iter().enumerate() {
2893        if chart.certified_radius <= 0.0 {
2894            continue;
2895        }
2896        // `m(t_c) = BᵀΦ(t_c)` is an OFFLINE per-chart constant already distilled
2897        // into `chart.recon_center` at build time (bit-for-bit the same φ·decoder
2898        // accumulation this used to recompute). Reuse it instead of re-evaluating
2899        // the basis at a fixed center for every row — that re-eval was the encode's
2900        // dominant per-row cost (charts × rows basis evals, each allocating the φ/jet
2901        // arrays). `‖m(t_c) − x‖²` computed in place (no temporary diff array).
2902        let mut dist = 0.0;
2903        for (r, xv) in chart.recon_center.iter().zip(x.iter()) {
2904            let diff = r - xv;
2905            dist += diff * diff;
2906        }
2907        scored.push((idx, dist));
2908    }
2909    // Sort by distance, then chart index for a deterministic, first-wins order
2910    // consistent with `nearest_chart`'s strict-`<` tie rule.
2911    scored.sort_by(|a, b| {
2912        a.1.partial_cmp(&b.1)
2913            .unwrap_or(std::cmp::Ordering::Equal)
2914            .then(a.0.cmp(&b.0))
2915    });
2916    scored.into_iter().take(k).map(|(idx, _)| idx).collect()
2917}
2918
2919/// Reconstruction error `‖x − z·m(t)‖` of an encoded coordinate `t` — the
2920/// criterion the certified encode minimizes over its candidate charts to pick the
2921/// GLOBAL basin. `m(t) = Bᵀ Φ(t)` is the amplitude-1 reconstruction; `z` is the
2922/// amplitude. A non-finite reconstruction returns `+∞` so it never wins.
2923pub(crate) fn encode_reconstruction_error(
2924    atom: &SaeManifoldAtom,
2925    evaluator: &dyn SaeBasisEvaluator,
2926    coord: ArrayView1<'_, f64>,
2927    x: ArrayView1<'_, f64>,
2928    amplitude: f64,
2929) -> f64 {
2930    let d = atom.latent_dim;
2931    let p = atom.output_dim();
2932    let m = atom.basis_size();
2933    let coords = match coord.to_shape((1, d)) {
2934        Ok(c) => c.to_owned(),
2935        Err(_) => return f64::INFINITY,
2936    };
2937    let Ok((phi, _jet)) = evaluator.evaluate(coords.view()) else {
2938        return f64::INFINITY;
2939    };
2940    let mut err2 = 0.0;
2941    for out in 0..p {
2942        let mut recon = 0.0;
2943        for basis_col in 0..m {
2944            recon += phi[[0, basis_col]] * atom.decoder_coefficients[[basis_col, out]];
2945        }
2946        let r = x[out] - amplitude * recon;
2947        err2 += r * r;
2948    }
2949    if err2.is_finite() { err2.sqrt() } else { f64::INFINITY }
2950}
2951
/// Maximum number of chart centers laid down per atom (the SHAPE_BAND grid
/// point cap; mirrors `SHAPE_BAND_MAX_POINTS` in the atom band machinery).
/// In this file it is the total-point budget for `regular_product_grid` and
/// `cylinder_chart_center_grid`. `sphere_latlon_grid` hard-codes the matching
/// 22-per-axis cap (22² = 484 ≤ 512).
pub(crate) const SHAPE_BAND_MAX_POINTS: usize = 512;
2955
2956/// Lay down chart centers on an atom's coordinate grid (the SHAPE_BAND grid
2957/// idiom): a regular grid spanning the compact latent domain for periodic /
2958/// sphere / torus atoms, and a strided cover of the latent axes for unbounded
2959/// (Duchon / Euclidean) atoms.
2960///
2961/// Periodic / torus latents are fractions of one period, so the per-axis grid
2962/// spans `[0, 1)`; the sphere chart spans `lat ∈ [−π/2, π/2]`, `lon ∈ [−π, π)`.
2963/// These conventions match the basis evaluators (the fraction-of-period circle
2964/// harmonic and the lat/lon sphere chart).
2965/// Squared coordinate distance between two latent points under the atom's chart
2966/// geometry: per-axis WRAPPED distance `min(|a−b|, period−|a−b|)` on periodic
2967/// (circle) axes — period 1 to match `chart_center_grid`'s `[0,1)` torus tiling
2968/// — and plain difference on line axes. Used to place + size data-driven charts.
2969pub(crate) fn coord_dist_sq(atom: &SaeManifoldAtom, a: ArrayView1<'_, f64>, b: ArrayView1<'_, f64>) -> f64 {
2970    use crate::manifold::SaeAtomBasisKind::*;
2971    let periodic_axis = |axis: usize| -> bool {
2972        match &atom.basis_kind {
2973            Periodic | Torus | Sphere => true,
2974            // Cylinder S¹×ℝ: only axis 0 is the circle.
2975            Cylinder => axis == 0,
2976            Linear | Duchon | EuclideanPatch | Poincare | Precomputed(_) => false,
2977        }
2978    };
2979    let mut acc = 0.0;
2980    for axis in 0..a.len() {
2981        let mut d = (a[axis] - b[axis]).abs();
2982        if periodic_axis(axis) {
2983            // Wrap onto the circle of unit period.
2984            d -= d.floor(); // fractional part in [0,1)
2985            d = d.min(1.0 - d);
2986        }
2987        acc += d * d;
2988    }
2989    acc
2990}
2991
2992/// Greedy farthest-point sampling of up to `max_charts` chart centers from the
2993/// atom's latent `coords` (n × d), with each center's nominal radius set to half
2994/// the distance to its nearest neighbor center (floored, so a singleton/coincident
2995/// cluster still gets a usable ball). Deterministic: seeds from row 0, then
2996/// repeatedly adds the coord maximally far (under [`coord_dist_sq`]) from the
2997/// chosen set — coverage-maximizing and reproducible run-to-run.
2998pub(crate) fn data_driven_chart_centers(
2999    atom: &SaeManifoldAtom,
3000    coords: ArrayView2<'_, f64>,
3001    max_charts: usize,
3002) -> Result<(Array2<f64>, Vec<f64>), String> {
3003    let n = coords.nrows();
3004    let d = coords.ncols();
3005    if d != atom.latent_dim {
3006        return Err(format!(
3007            "data_driven_chart_centers: coords have {d} cols but atom latent_dim is {}",
3008            atom.latent_dim
3009        ));
3010    }
3011    if n == 0 {
3012        return Ok((Array2::<f64>::zeros((0, d)), Vec::new()));
3013    }
3014    let k = max_charts.min(n);
3015    // Farthest-point sampling: maintain each row's distance to the nearest chosen
3016    // center, add the row with the maximum such distance each step.
3017    let mut chosen: Vec<usize> = Vec::with_capacity(k);
3018    chosen.push(0);
3019    let mut nearest_sq: Vec<f64> = (0..n)
3020        .map(|r| coord_dist_sq(atom, coords.row(r), coords.row(0)))
3021        .collect();
3022    while chosen.len() < k {
3023        // Pick the row farthest from the current center set (first-wins tie).
3024        let mut best = 0usize;
3025        let mut best_d = -1.0;
3026        for r in 0..n {
3027            if nearest_sq[r] > best_d {
3028                best_d = nearest_sq[r];
3029                best = r;
3030            }
3031        }
3032        if best_d <= 0.0 {
3033            break; // all remaining rows coincide with a chosen center.
3034        }
3035        chosen.push(best);
3036        for r in 0..n {
3037            let dr = coord_dist_sq(atom, coords.row(r), coords.row(best));
3038            if dr < nearest_sq[r] {
3039                nearest_sq[r] = dr;
3040            }
3041        }
3042    }
3043    let m = chosen.len();
3044    let mut centers = Array2::<f64>::zeros((m, d));
3045    for (i, &row) in chosen.iter().enumerate() {
3046        centers.row_mut(i).assign(&coords.row(row));
3047    }
3048    // Per-center radius = half the nearest-OTHER-center distance, floored so a
3049    // coincident pair still yields a positive ball, capped at 0.5 (the largest
3050    // meaningful half-period on a unit circle).
3051    let mut radii = vec![0.0_f64; m];
3052    for i in 0..m {
3053        let mut nn = f64::INFINITY;
3054        for j in 0..m {
3055            if i == j {
3056                continue;
3057            }
3058            let dsq = coord_dist_sq(atom, centers.row(i), centers.row(j));
3059            if dsq < nn {
3060                nn = dsq;
3061            }
3062        }
3063        let r = if nn.is_finite() { 0.5 * nn.sqrt() } else { 0.5 };
3064        radii[i] = r.max(1.0e-3).min(0.5);
3065    }
3066    Ok((centers, radii))
3067}
3068
3069pub(crate) fn chart_center_grid(atom: &SaeManifoldAtom, resolution: usize) -> Array2<f64> {
3070    use crate::manifold::SaeAtomBasisKind::*;
3071    let d = atom.latent_dim;
3072    match &atom.basis_kind {
3073        Periodic | Torus => regular_product_grid(d, resolution, 0.0, 1.0, false),
3074        // Cylinder `S¹ × ℝ`: axis 0 is the periodic circle `[0, 1)` (no
3075        // endpoint, like the harmonic axes); axis 1 is the unbounded line,
3076        // covered by a strided unit box `[-0.5, 0.5]` about the origin (like the
3077        // Euclidean patch). The certified radius refines each chart; out-of-cover
3078        // line starts route to the exact fallback honestly.
3079        Cylinder if d == 2 => cylinder_chart_center_grid(resolution),
3080        Cylinder => regular_product_grid(d, resolution, -0.5, 0.5, true),
3081        Sphere if d == 2 => sphere_latlon_grid(resolution),
3082        Linear | Sphere | Duchon | EuclideanPatch | Poincare | Precomputed(_) => {
3083            // Unbounded / non-compact latents: a strided cover of a unit box
3084            // about the origin per axis. The certified radius refines each chart;
3085            // out-of-cover starts route to the exact fallback honestly.
3086            regular_product_grid(d, resolution, -0.5, 0.5, true)
3087        }
3088    }
3089}
3090
3091/// A regular `resolution`-per-axis product grid over `[lo, hi]^d`, capped at
3092/// [`SHAPE_BAND_MAX_POINTS`] total points (the per-axis resolution is reduced
3093/// until the product fits). When `include_endpoint` the last grid point sits at
3094/// `hi`; otherwise the axis is treated as periodic and stops one step short.
3095pub(crate) fn regular_product_grid(
3096    d: usize,
3097    resolution: usize,
3098    lo: f64,
3099    hi: f64,
3100    include_endpoint: bool,
3101) -> Array2<f64> {
3102    if d == 0 {
3103        return Array2::<f64>::zeros((1, 0));
3104    }
3105    let mut per_axis = resolution.max(2);
3106    while per_axis.saturating_pow(d as u32) > SHAPE_BAND_MAX_POINTS && per_axis > 2 {
3107        per_axis -= 1;
3108    }
3109    let total = per_axis.saturating_pow(d as u32).max(1);
3110    let denom = if include_endpoint {
3111        (per_axis.max(2) - 1) as f64
3112    } else {
3113        per_axis as f64
3114    };
3115    let mut grid = Array2::<f64>::zeros((total, d));
3116    let mut idx = vec![0usize; d];
3117    for flat in 0..total {
3118        for axis in 0..d {
3119            let frac = idx[axis] as f64 / denom;
3120            grid[[flat, axis]] = lo + (hi - lo) * frac;
3121        }
3122        for axis in (0..d).rev() {
3123            idx[axis] += 1;
3124            if idx[axis] < per_axis {
3125                break;
3126            }
3127            idx[axis] = 0;
3128        }
3129    }
3130    grid
3131}
3132
3133/// Lat/lon sphere chart grid: `lat ∈ [−π/2, π/2]`, `lon ∈ [−π, π)`, matching
3134/// the [`crate::manifold::SphereChartEvaluator`] convention.
3135pub(crate) fn sphere_latlon_grid(resolution: usize) -> Array2<f64> {
3136    use std::f64::consts::PI;
3137    let r = resolution.max(2).min(22); // 22² = 484 ≤ SHAPE_BAND_MAX_POINTS.
3138    let mut grid = Array2::<f64>::zeros((r * r, 2));
3139    for i in 0..r {
3140        let lat = -PI / 2.0 + PI * (i as f64 + 0.5) / r as f64;
3141        for j in 0..r {
3142            let lon = -PI + 2.0 * PI * (j as f64) / r as f64;
3143            grid[[i * r + j, 0]] = lat;
3144            grid[[i * r + j, 1]] = lon;
3145        }
3146    }
3147    grid
3148}
3149
3150/// Cylinder `S¹ × ℝ` chart-center grid: axis 0 sweeps the periodic circle over
3151/// one period `[0, 1)` (no endpoint, matching the harmonic axis), axis 1 strides
3152/// a unit box `[−0.5, 0.5]` about the origin on the unbounded line (with
3153/// endpoint). Capped at [`SHAPE_BAND_MAX_POINTS`] total centers.
3154pub(crate) fn cylinder_chart_center_grid(resolution: usize) -> Array2<f64> {
3155    let mut per_axis = resolution.max(2);
3156    while per_axis * per_axis > SHAPE_BAND_MAX_POINTS && per_axis > 2 {
3157        per_axis -= 1;
3158    }
3159    let total = per_axis * per_axis;
3160    let line_denom = (per_axis.max(2) - 1) as f64;
3161    let mut grid = Array2::<f64>::zeros((total, 2));
3162    for i in 0..per_axis {
3163        // Periodic axis 0: stop one step short of the period.
3164        let circle = i as f64 / per_axis as f64;
3165        for j in 0..per_axis {
3166            // Line axis 1: include the endpoint of the unit box.
3167            let line = -0.5 + (j as f64) / line_denom;
3168            grid[[i * per_axis + j, 0]] = circle;
3169            grid[[i * per_axis + j, 1]] = line;
3170        }
3171    }
3172    grid
3173}
3174
3175/// Nominal in-chart radius: half the inter-center grid spacing, so charts tile
3176/// the domain. For compact latents this is the grid step; for unbounded latents
3177/// a unit default that the certified radius refines.
3178pub(crate) fn chart_nominal_radius(atom: &SaeManifoldAtom, resolution: usize) -> f64 {
3179    use crate::manifold::SaeAtomBasisKind::*;
3180    match &atom.basis_kind {
3181        Periodic | Torus => 0.5 / (resolution.max(2) as f64),
3182        Sphere => std::f64::consts::PI / (resolution.max(2) as f64),
3183        // Cylinder charts tile two heterogeneous axes (a `[0,1)` periodic step
3184        // and a unit-box line step); the chart radius is a single scalar, so we
3185        // take the tighter (periodic) step `0.5/res` to keep every chart valid
3186        // on both axes. The certified Kantorovich radius refines it per chart.
3187        Cylinder => 0.5 / (resolution.max(2) as f64),
3188        Linear | Duchon | EuclideanPatch | Poincare | Precomputed(_) => {
3189            1.0 / (resolution.max(2) as f64)
3190        }
3191    }
3192}
3193
3194/// Build the [`ChartRegion`] for a center, attaching the radial r_min / r_max
3195/// bracket for Duchon atoms (the chart's distance range to the kernel centers).
3196pub(crate) fn chart_region(
3197    atom: &SaeManifoldAtom,
3198    center: Array1<f64>,
3199    radius: f64,
3200) -> ChartRegion {
3201    use crate::manifold::SaeAtomBasisKind::*;
3202    let region = ChartRegion::new(center.clone(), radius);
3203    match &atom.basis_kind {
3204        Duchon => {
3205            // r ranges over [‖t_c‖ − radius, ‖t_c‖ + radius] about the single
3206            // origin-anchored center used by the conservative radial bound.
3207            //
3208            // The lower bound must be `max(0, center_norm − radius)` — NOT floored
3209            // at `radius`. When the chart contains the kernel center
3210            // (`center_norm < radius`, true r_min = 0), flooring at `radius`
3211            // would give a finite, NON-CONSERVATIVE `r_min`, causing the
3212            // hessian_sup / third_sup formulas (which divide by r_min) to
3213            // underestimate the Lipschitz constant and potentially grant a false
3214            // Kantorovich certificate. Flooring at `f64::MIN_POSITIVE` instead
3215            // correctly drives the formulas toward ∞, producing a very large L
3216            // that will NEVER certify (rows route to the exact multi-start
3217            // fallback) — conservative and sound.
3218            let center_norm = center.dot(&center).sqrt();
3219            let r_min = (center_norm - radius).max(f64::MIN_POSITIVE);
3220            let r_max = center_norm + radius;
3221            region.with_radial_bounds(r_min, r_max)
3222        }
3223        // Cylinder has no radial kernel block (it is a harmonic × polynomial
3224        // tensor, not a Duchon radial basis), so it needs no radial r_min/r_max.
3225        Periodic | Sphere | Torus | Cylinder | Linear | EuclideanPatch | Poincare
3226        | Precomputed(_) => region,
3227    }
3228}