Skip to main content

gam_sae/
encode.rs

1//! Kantorovich-certified encode atlas (issue #1010).
2//!
3//! Encoding a row `x ∈ ℝᵖ` against a FROZEN SAE dictionary is, per atom `k`,
4//! the coordinate-only Newton problem
5//!
6//! ```text
7//! min_t  f_k(t) = ½‖x − z_k · B_kᵀ Φ_k(t)‖² + prior_k(t),
8//! ```
9//!
10//! with the amplitude `z_k` and decoder block `B_k` held fixed (the encode
11//! freezes the dictionary; only the latent coordinate `t` moves). Newton on
12//! `F(t) = ∇f_k(t)` converges quadratically from a start `t₀` into the unique
//! root in a certified ball whenever the **Newton–Kantorovich** quantity `h` satisfies
14//!
15//! ```text
16//! h = β · η · L ≤ ½,    β = ‖F'(t₀)⁻¹‖,   η = ‖F'(t₀)⁻¹ F(t₀)‖,
17//! ```
18//!
19//! where `L` is a Lipschitz constant of `F'` (the Hessian of `f_k`) on a region
20//! containing the Newton iterates. `h` is CHECKABLE per row in `O(q³)`
21//! (`q = latent_dim`, tiny), so each fast-path encode carries its own
22//! exactness certificate.
23//!
24//! ## The closed-form Hessian-Lipschitz constant `L`
25//!
26//! Write `m(t) = z·BᵀΦ(t) ∈ ℝᵖ` (the reconstruction) and `r(t) = m(t) − x`.
27//! Then `f = ½‖r‖² + prior` and, differentiating three times,
28//!
29//! ```text
30//! ∇³f = 3·sym(J_mᵀ : ∇²m) + ⟨r, ∇³m⟩ + ∇³prior,
31//! ```
32//!
33//! so an operator-norm bound on the chart is
34//!
35//! ```text
36//! L ≤ 3·‖J_m‖·‖∇²m‖ + ‖r‖·‖∇³m‖ + L_prior,
37//! ```
38//!
39//! with `‖∂^g m‖ ≤ |z|·(Σ_m ‖B_{m,:}‖)·B_g`, where `B_g = sup_chart max_m
40//! ‖∂^g Φ_m‖` is the per-column jet sup of the basis family — closed form per
41//! family ([`BasisHessianLipschitz`]). `‖r‖` is bounded by `‖x‖ +
42//! |z|·(Σ_m‖B_{m,:}‖)·B_0`. The ARD/von-Mises prior `L_prior` is a closed-form
43//! constant from the prior strength. Every bound is conservative (an
44//! over-estimate of `L` only SHRINKS the certified radius — it can never
45//! certify a row that does not converge).
46//!
47//! ## Pipeline
48//!
49//! 1. **Offline, per atom** ([`EncodeAtlas::build`]): chart centers `t_c` on the
50//!    atom's coordinate grid (the SHAPE_BAND grid idiom), each with a certified
51//!    Newton radius `R_c` solved from the Kantorovich inequality at the
52//!    worst-case in-chart start.
53//! 2. **Online, per row** ([`EncodeAtlas::certified_encode_row`]): route to the
54//!    nearest chart, start from its distilled IFT predictor, take one or two
55//!    Newton steps, then the `h ≤ ½` check AT the start point is the per-row
56//!    certificate.
57//! 3. **Uncertified tail**: rows whose start fails `h ≤ ½` are FLAGGED (counted
58//!    in [`EncodeResult::encode_uncertified_count`]) and must be routed by the
59//!    caller to the existing exact multi-start solve. No approximation enters
60//!    silently.
61
62use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
63
64use gam_linalg::faer_ndarray::FaerEigh;
65use crate::candidate_index::{
66    AtomFrameSketch, SaeCandidateIndex, auto_candidate_budget,
67};
68use crate::manifold::{
69    AffineCoordinateEvaluator, CylinderHarmonicEvaluator, DuchonCoordinateEvaluator,
70    EuclideanPatchEvaluator, PeriodicHarmonicEvaluator, SaeBasisEvaluator, SaeManifoldAtom,
71    SphereChartEvaluator, TorusHarmonicEvaluator,
72};
73
74use faer::Side;
75
/// The Kantorovich convergence threshold `½`.
///
/// A start whose quantity `h = β·η·L` is finite and satisfies `h ≤ ½` is
/// certified: the Newton iteration is guaranteed to converge into the unique
/// root in the certified ball (quadratically when `h < ½`). A start with
/// `h > ½`, or with a non-finite `h`, is uncertified. The comparison is
/// INCLUSIVE, matching [`RowCertificate::certified`].
pub const KANTOROVICH_THRESHOLD: f64 = 0.5;
80
/// Parallel fan-out gate for the batch encode, in rows.
///
/// Row count at or above which the corpus-rate certified-encode batch
/// (`certified_encode_batch` / `certified_encode_with_index`) fans its
/// per-row encodes out over rayon. Below this the per-row Newton + chart
/// routing is cheap enough that the fan-out overhead does not pay; matched to
/// the same order as the arrow-Schur `SCHUR_MATVEC_PARALLEL_ROW_MIN` gate so
/// short batches inside an outer atom-level fan-out stay sequential.
pub(crate) const ENCODE_BATCH_PARALLEL_ROW_MIN: usize = 256;
88
/// Trust floor on the frame alignment of an index-routed encode (#1026).
///
/// Minimum frame alignment `‖Uₖᵀd‖/‖d‖ ∈ [0,1]` the best LSH-gathered atom must
/// have for an index-routed encode to be TRUSTED. The in-atom Kantorovich
/// certificate attests Newton convergence *within* the chosen atom; it does NOT
/// attest that the LSH gather found the globally-correct atom. Because alignment
/// is capped at 1, a HIGH best-alignment leaves no room for a materially-better
/// ungathered atom (the gather is trustworthy), whereas a LOW best-alignment means
/// the true best atom was likely missed by LSH recall. Rows whose best gathered
/// atom aligns below this are flagged uncertified and routed to the exact
/// full-scan fallback — closing the silent-mis-routing hole where a non-empty but
/// wrong gather certified against a suboptimal atom. Erring toward flagging only
/// ever adds exact-fallback work (correct, slower); it never certifies a mis-route.
pub(crate) const CANDIDATE_ROUTING_MIN_ALIGNMENT: f64 = 0.5;
101
/// A chart region on an atom's latent coordinate: a center `t_c` plus a
/// certified in-chart radius. Over the ball `‖t − t_c‖ ≤ radius` the jet sup
/// bounds returned by [`BasisHessianLipschitz`] hold, so the Kantorovich
/// constant `L` computed from them is valid for any start in the ball.
///
/// For radial (Duchon) families the chart also carries the minimum kernel-center
/// distance `exclusion_r_min` (a lower bound on `‖t − c_k‖` over the chart) that
/// bounds the otherwise-singular `1/r` radial tails (issue #1010).
///
/// The fields are public and unchecked on construction; a chart is considered
/// well-formed when its center entries and radius are finite and the radius is
/// non-negative (what `assert_valid` enforces).
#[derive(Debug, Clone)]
pub struct ChartRegion {
    /// Chart center coordinate `t_c` (length = latent_dim).
    pub center: Array1<f64>,
    /// In-chart radius in the coordinate metric.
    pub radius: f64,
    /// For radial (Duchon) families: a lower bound on `‖t − c_k‖` over the
    /// chart, across every kernel center `c_k`. `None` for non-radial families.
    pub exclusion_r_min: Option<f64>,
    /// For radial (Duchon) families: an upper bound on `‖t − c_k‖` over the
    /// chart, across every kernel center `c_k`. `None` for non-radial families.
    pub radial_r_max: Option<f64>,
}
123
124impl ChartRegion {
125    pub fn new(center: Array1<f64>, radius: f64) -> Self {
126        Self {
127            center,
128            radius,
129            exclusion_r_min: None,
130            radial_r_max: None,
131        }
132    }
133
134    pub fn with_radial_bounds(mut self, r_min: f64, r_max: f64) -> Self {
135        self.exclusion_r_min = Some(r_min);
136        self.radial_r_max = Some(r_max);
137        self
138    }
139
140    /// A jet-sup certificate is only meaningful over a genuine region. Even
141    /// families whose bounds are manifold-global constants (the sup over any
142    /// chart equals the global sup) must refuse a malformed chart rather than
143    /// certify garbage geometry.
144    pub(crate) fn assert_valid(&self) {
145        assert!(
146            self.radius.is_finite()
147                && self.radius >= 0.0
148                && self.center.iter().all(|c| c.is_finite()),
149            "ChartRegion must have a finite center and a finite non-negative radius"
150        );
151    }
152}
153
/// Per-column sup-norm bounds on the first three coordinate jets of a basis
/// family `Φ(t)`, valid over a stated [`ChartRegion`] (issue #1010). These are
/// the analytic ingredients of the Hessian-Lipschitz constant `L` — see the
/// module docs for the assembly. `value_sup` bounds `max_m |Φ_m|`,
/// `jacobian_sup`/`hessian_sup`/`third_sup` bound `max_m ‖∂^g Φ_m‖`.
///
/// Every bound must be an OVER-estimate: an under-estimate would shrink `L`
/// and could certify a start that does not converge.
pub trait BasisHessianLipschitz {
    /// Upper bound on `max_m |Φ_m(t)|` over the chart (order `g = 0`).
    fn value_sup(&self, chart: &ChartRegion) -> f64;
    /// Upper bound on `max_m ‖∂Φ_m(t)‖` over the chart (order `g = 1`).
    fn jacobian_sup(&self, chart: &ChartRegion) -> f64;
    /// Upper bound on `max_m ‖∂²Φ_m(t)‖` over the chart (order `g = 2`).
    fn hessian_sup(&self, chart: &ChartRegion) -> f64;
    /// Upper bound on `max_m ‖∂³Φ_m(t)‖` over the chart (order `g = 3`).
    fn third_sup(&self, chart: &ChartRegion) -> f64;
}
165
/// Global sup of the `order`-th derivative of any single column of a
/// `num_basis`-wide Fourier basis `[1, sin(2π h t), cos(2π h t), …]`.
///
/// Returns `(2π·H)^order` with top harmonic `H = (num_basis − 1) / 2`
/// (integer division, saturating at `0` for an empty basis). The constant
/// column contributes `0` for `order ≥ 1`, so the top harmonic dominates, and
/// because trig magnitudes are `≤ 1` everywhere the bound does not depend on
/// any chart. At `order = 0` the result is `1`.
pub(crate) fn harmonic_jet_sup(num_basis: usize, order: u32) -> f64 {
    let harmonic_count = num_basis.saturating_sub(1) / 2;
    let angular_frequency = std::f64::consts::TAU * harmonic_count as f64;
    let exponent = order as i32;
    angular_frequency.powi(exponent)
}
176
177impl BasisHessianLipschitz for PeriodicHarmonicEvaluator {
178    fn value_sup(&self, chart: &ChartRegion) -> f64 {
179        chart.assert_valid();
180        1.0
181    }
182    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
183        chart.assert_valid();
184        harmonic_jet_sup(self.num_basis, 1)
185    }
186    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
187        chart.assert_valid();
188        harmonic_jet_sup(self.num_basis, 2)
189    }
190    fn third_sup(&self, chart: &ChartRegion) -> f64 {
191        chart.assert_valid();
192        harmonic_jet_sup(self.num_basis, 3)
193    }
194}
195
196impl BasisHessianLipschitz for TorusHarmonicEvaluator {
197    /// Tensor product of per-axis circle harmonics. A torus basis column is a
198    /// product of single-axis harmonics, each bounded as in the circle case.
199    /// The `g`-th coordinate jet routes `g` derivative operators across the
200    /// `latent_dim` factors (Leibniz); each routing contributes a product of
201    /// per-axis derivative magnitudes. A per-column sup is therefore bounded by
202    /// the top single-axis frequency to the `g`-th power times the number of
203    /// such routings (`latent_dim^g`, the count of operator-to-axis maps).
204    fn value_sup(&self, chart: &ChartRegion) -> f64 {
205        chart.assert_valid();
206        1.0
207    }
208    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
209        chart.assert_valid();
210        torus_jet_sup(self.num_harmonics, self.latent_dim, 1)
211    }
212    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
213        chart.assert_valid();
214        torus_jet_sup(self.num_harmonics, self.latent_dim, 2)
215    }
216    fn third_sup(&self, chart: &ChartRegion) -> f64 {
217        chart.assert_valid();
218        torus_jet_sup(self.num_harmonics, self.latent_dim, 3)
219    }
220}
221
/// Per-column `order`-th jet sup for the torus harmonic basis:
/// `(2π·H)^order · latent_dim^order`.
///
/// `H = num_harmonics` is the top per-axis frequency. The factor
/// `latent_dim^order` over-counts the Leibniz distributions of `order`
/// derivative operators across the product factors; each distribution's
/// per-axis magnitude is `≤ (2π H)^{#ops on that axis}`, and those multiply out
/// to `(2π H)^order`. Conservative by construction.
pub(crate) fn torus_jet_sup(num_harmonics: usize, latent_dim: usize, order: u32) -> f64 {
    let exponent = order as i32;
    let frequency_power = (std::f64::consts::TAU * num_harmonics as f64).powi(exponent);
    let routing_count = (latent_dim as f64).powi(exponent);
    frequency_power * routing_count
}
231
232impl BasisHessianLipschitz for SphereChartEvaluator {
233    /// The 7-column lat/lon chart `[1, x, y, z, xy, yz, xz]` with
234    /// `x = cos(lat)cos(lon)`, `y = cos(lat)sin(lon)`, `z = sin(lat)`. Each of
235    /// `x, y, z` is a product of two unit-frequency trig factors, so its `g`-th
236    /// coordinate jet is a sum of `2^g` products of `{sin,cos}` (each `≤ 1`):
237    /// magnitude `≤ 2^g` for `g ≥ 1`, `≤ 1` for `g = 0`. The bilinear columns
238    /// `xy, yz, xz` are products of two such coordinates; by Leibniz over the
239    /// product, their `g`-th jet is bounded by `Σ_{i=0}^{g} C(g,i)·(2^i)·(2^{g−i})
240    /// = (2+2)^g = 4^g` (using `‖∂^i u‖ ≤ 2^i`, `|u| ≤ 1`). The bilinear columns
241    /// dominate, so the per-column sup is `4^g` (`g ≥ 1`). Bounds are global
242    /// constants — the chart box `lat ∈ [-π/2, π/2]` does not enlarge them.
243    fn value_sup(&self, chart: &ChartRegion) -> f64 {
244        chart.assert_valid();
245        1.0
246    }
247    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
248        chart.assert_valid();
249        4.0
250    }
251    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
252        chart.assert_valid();
253        16.0
254    }
255    fn third_sup(&self, chart: &ChartRegion) -> f64 {
256        chart.assert_valid();
257        64.0
258    }
259}
260
261impl BasisHessianLipschitz for AffineCoordinateEvaluator {
262    /// The affine basis `[1, t₁, …, t_d]` is degree ≤ 1: its first jet has unit
263    /// columns, and all second and third jets vanish. The value sup is
264    /// `max(1, ‖t‖)` over the chart, bounded by `1 + ‖t_c‖ + radius`.
265    fn value_sup(&self, chart: &ChartRegion) -> f64 {
266        let center_norm = chart.center.dot(&chart.center).sqrt();
267        1.0 + center_norm + chart.radius
268    }
269    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
270        chart.assert_valid();
271        1.0
272    }
273    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
274        chart.assert_valid();
275        0.0
276    }
277    fn third_sup(&self, chart: &ChartRegion) -> f64 {
278        chart.assert_valid();
279        0.0
280    }
281}
282
283impl BasisHessianLipschitz for EuclideanPatchEvaluator {
284    /// Monomials of total degree ≤ `max_degree` in `t ∈ ℝ^d`. Over the ball of
285    /// radius `R` about `t_c`, each coordinate is bounded by `ρ = ‖t_c‖∞ + R`.
286    /// A monomial `t^α` with `|α| = q` has `g`-th partials bounded (crudely) by
287    /// the descending-factorial coefficient `q·(q−1)···(q−g+1) ≤ q^g` times
288    /// `ρ^{max(q−g,0)}`, and there are at most `d^g` partial routings, so the
289    /// per-column `g`-th jet sup is `≤ d^g · D^g · ρ^{max(D−g,0)}` with
290    /// `D = max_degree`. Conservative; D is small for patch evaluators.
291    fn value_sup(&self, chart: &ChartRegion) -> f64 {
292        let rho = patch_rho(chart);
293        let d = self.max_degree as i32;
294        rho.powi(d).max(1.0)
295    }
296    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
297        patch_jet_sup(self.latent_dim, self.max_degree, chart, 1)
298    }
299    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
300        patch_jet_sup(self.latent_dim, self.max_degree, chart, 2)
301    }
302    fn third_sup(&self, chart: &ChartRegion) -> f64 {
303        patch_jet_sup(self.latent_dim, self.max_degree, chart, 3)
304    }
305}
306
307impl BasisHessianLipschitz for CylinderHarmonicEvaluator {
308    /// Cylinder `S¹ × ℝ` product basis `Φ_{c,l} = c(t₀)·l(t₁)`, the circle
309    /// (periodic harmonic) factor on axis 0 crossed with the monomial line
310    /// factor on axis 1. Because the two factors depend on disjoint coordinates,
311    /// the order-`g` coordinate jet in any cell is exactly
312    /// `c^{(k₀)}(t₀)·l^{(k₁)}(t₁)` with `k₀ + k₁ = g`, so the per-column sup is
313    /// the max over the split `k₀ + k₁ = g` of the product of the two per-axis
314    /// per-order sups: the circle factor contributes `1` at order 0 and
315    /// `(2π·H)^{k₀}` at order `k₀ ≥ 1` (trig magnitudes `≤ 1`); the line factor
316    /// contributes the monomial-patch sup `D^{k₁}·ρ^{max(D−k₁,0)}` (`D = line
317    /// degree`, `ρ = ‖t_c‖∞ + radius`). Bounds are global in the periodic axis
318    /// and chart-local in the line axis.
319    fn value_sup(&self, chart: &ChartRegion) -> f64 {
320        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 0)
321    }
322    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
323        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 1)
324    }
325    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
326        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 2)
327    }
328    fn third_sup(&self, chart: &ChartRegion) -> f64 {
329        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 3)
330    }
331}
332
333/// Per-column order-`g` jet sup of the cylinder product basis: the max over
334/// `k₀ + k₁ = g` of `circle_axis_sup(k₀) · line_axis_sup(k₁)`, where the circle
335/// axis sup is `(2π·H)^{k₀}` (`1` at `k₀ = 0`) and the line axis sup is the
336/// monomial-patch bound `D^{k₁}·ρ^{max(D−k₁,0)}` (`1` at `k₁ = 0`). See the
337/// [`CylinderHarmonicEvaluator`] doc comment for the derivation.
338pub(crate) fn cylinder_jet_sup(
339    circle_harmonics: usize,
340    line_degree: usize,
341    chart: &ChartRegion,
342    order: u32,
343) -> f64 {
344    let omega = std::f64::consts::TAU * circle_harmonics as f64;
345    let big_d = line_degree as f64;
346    let rho = patch_rho(chart);
347    let mut best = 0.0_f64;
348    for k0 in 0..=order {
349        let k1 = order - k0;
350        let circle = if k0 == 0 { 1.0 } else { omega.powi(k0 as i32) };
351        let line = if k1 == 0 {
352            rho.powi(line_degree as i32).max(1.0)
353        } else {
354            let residual = line_degree.saturating_sub(k1 as usize) as i32;
355            // `.max(1.0)` as in `patch_jet_sup`: for ρ < 1 a lower-degree line
356            // monomial dominates the k1-th derivative, so the bare `ρ^residual`
357            // underestimates the line-factor sup. The value case (k1==0) already
358            // clamps; this completes it for the derivative orders.
359            big_d.powi(k1 as i32) * rho.powi(residual).max(1.0)
360        };
361        best = best.max(circle * line);
362    }
363    best
364}
365
366/// Sup-norm radius `ρ = ‖t_c‖∞ + radius` of the chart (the coordinate magnitude
367/// bound used by the monomial-patch jet bounds).
368pub(crate) fn patch_rho(chart: &ChartRegion) -> f64 {
369    let center_inf = chart
370        .center
371        .iter()
372        .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
373    center_inf + chart.radius
374}
375
376/// Per-column `g`-th jet sup for a monomial patch of max degree `D` in `d`
377/// coordinates over the chart: `d^g · D^g · ρ^{max(D−g,0)}` (see the
378/// [`EuclideanPatchEvaluator`] doc comment for the derivation).
379pub(crate) fn patch_jet_sup(
380    latent_dim: usize,
381    max_degree: usize,
382    chart: &ChartRegion,
383    order: u32,
384) -> f64 {
385    let d = latent_dim as f64;
386    let big_d = max_degree as f64;
387    let rho = patch_rho(chart);
388    let residual_degree = max_degree.saturating_sub(order as usize) as i32;
389    // `.max(1.0)`: for ρ < 1 (small charts near the origin) the g-th jet sup is NOT
390    // dominated by the max-degree monomial `t^D` (whose g-th derivative ~ ρ^{D-g}
391    // shrinks with ρ) but by a LOWER-degree monomial whose g-th derivative is a
392    // larger constant — e.g. {1,t,t²}'s jacobian sup is the linear term's constant
393    // `1`, which exceeds `2ρ` when ρ < ½. Without the clamp the bound underestimates
394    // the true sup (numerically: D=3, ρ=0.1, g=1 → formula 0.03 vs true 1.0), which
395    // would make the certificate's Lipschitz `L` too small → a FALSE certificate.
396    // `D^g · max(ρ^{D-g}, 1)` upper-bounds `max_{q∈[g,D]} (q!/(q-g)!)·ρ^{q-g}` for
397    // all ρ (the `q=g` term gives `g! ≤ D^g`, the `q=D` term gives `≤ D^g·ρ^{D-g}`).
398    d.powi(order as i32) * big_d.powi(order as i32) * rho.powi(residual_degree).max(1.0)
399}
400
401impl BasisHessianLipschitz for DuchonCoordinateEvaluator {
402    /// Radial-kernel basis `Φ_m(t) = φ(r_m)`, `r_m = ‖t − c_m‖`, plus a
403    /// polynomial nullspace block. For the cubic Duchon kernel `φ(r) = r³` the
404    /// radial derivatives are `φ' = 3r²`, `φ'' = 6r`, `φ''' = 6`. The chain rule
405    /// to coordinate jets introduces `1/r` factors through the unit radial
406    /// direction `u = (t − c)/r` and the projector `(I − uuᵀ)/r`, so over a
407    /// chart the jets are bounded by combining the radial-derivative magnitudes
408    /// at the worst-case radius with the inverse-radius tail at the chart's
409    /// EXCLUSION radius `r_min` (the closest a chart point gets to any center):
410    ///
411    /// ```text
412    /// ‖∇φ‖    ≤ |φ'|                              ≤ 3 r_max²
413    /// ‖∇²φ‖   ≤ |φ''| + |φ'|/r                    ≤ 6 r_max + 3 r_max²/r_min
414    /// ‖∇³φ‖   ≤ |φ'''| + 3|φ''|/r + 3|φ'|/r²      ≤ 6 + 18 r_max/r_min + 9 r_max²/r_min²
415    /// ```
416    ///
417    /// (the `1/r`, `1/r²` tails are bounded by `1/r_min`, `1/r_min²`). The
418    /// polynomial nullspace block is degree ≤ `order`; its jets are bounded like
419    /// the monomial patch with `D = order`. The per-column sup is the max of the
420    /// kernel and polynomial bounds. The `r³` kernel is itself `C²` (no
421    /// singularity) so these tails are conservative but finite for any
422    /// `r_min > 0`; the atlas refines charts to keep `r_min` bounded away from 0.
423    fn value_sup(&self, chart: &ChartRegion) -> f64 {
424        let r_max = chart.radial_r_max.unwrap_or(chart.radius);
425        let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 0);
426        (r_max.powi(3)).max(poly)
427    }
428    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
429        let r_max = chart.radial_r_max.unwrap_or(chart.radius);
430        let kernel = 3.0 * r_max * r_max;
431        let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 1);
432        kernel.max(poly)
433    }
434    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
435        let r_max = chart.radial_r_max.unwrap_or(chart.radius);
436        let r_min = chart
437            .exclusion_r_min
438            .unwrap_or(chart.radius)
439            .max(f64::MIN_POSITIVE);
440        let kernel = 6.0 * r_max + 3.0 * r_max * r_max / r_min;
441        let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 2);
442        kernel.max(poly)
443    }
444    fn third_sup(&self, chart: &ChartRegion) -> f64 {
445        let r_max = chart.radial_r_max.unwrap_or(chart.radius);
446        let r_min = chart
447            .exclusion_r_min
448            .unwrap_or(chart.radius)
449            .max(f64::MIN_POSITIVE);
450        let kernel = 6.0 + 18.0 * r_max / r_min + 9.0 * r_max * r_max / (r_min * r_min);
451        let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 3);
452        kernel.max(poly)
453    }
454}
455
/// Polynomial-block degree of a Duchon nullspace order, used to bound the
/// nullspace columns like a monomial patch.
///
/// Private extension trait: it attaches the degree lookup to the evaluator
/// type without adding an inherent method where the evaluator is defined.
trait DuchonOrderDegree {
    /// Maximum total degree of the polynomial nullspace block.
    fn order_degree(&self) -> usize;
}
461
462impl DuchonOrderDegree for DuchonCoordinateEvaluator {
463    fn order_degree(&self) -> usize {
464        match self.order {
465            gam_terms::basis::DuchonNullspaceOrder::Zero => 0,
466            gam_terms::basis::DuchonNullspaceOrder::Linear => 1,
467            gam_terms::basis::DuchonNullspaceOrder::Degree(d) => d,
468        }
469    }
470}
471
472/// Per-column `g`-th jet sup of the Duchon polynomial nullspace block, treated
473/// as a monomial patch of degree `order_degree`.
474pub(crate) fn duchon_poly_jet_sup(
475    latent_dim: usize,
476    order_degree: usize,
477    chart: &ChartRegion,
478    order: u32,
479) -> f64 {
480    if order_degree == 0 {
481        return if order == 0 { 1.0 } else { 0.0 };
482    }
483    patch_jet_sup(latent_dim, order_degree, chart, order)
484}
485
486/// Decoder magnitude `Σ_m ‖B_{m,:}‖₂` of an atom's frozen decoder block: the
487/// factor that converts a per-column `Φ`-jet sup `B_g` into a reconstruction
488/// jet sup `‖∂^g m‖ ≤ |z|·decoder_row_norm_sum·B_g`.
489pub(crate) fn decoder_row_norm_sum(decoder: ArrayView2<'_, f64>) -> f64 {
490    let mut acc = 0.0;
491    for row in decoder.rows() {
492        acc += row.dot(&row).sqrt();
493    }
494    acc
495}
496
/// Amplitude-free sup bounds on the reconstruction `BᵀΦ(t)` and its first
/// three coordinate jets over a chart. `hessian_lipschitz_constant` scales
/// each entry by `|z|` before assembling the Lipschitz constant.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ReconstructionJetSups {
    /// Bound on the reconstruction value (order 0).
    pub(crate) value: f64,
    /// Bound on the reconstruction Jacobian (order 1).
    pub(crate) jacobian: f64,
    /// Bound on the reconstruction Hessian (order 2).
    pub(crate) hessian: f64,
    /// Bound on the reconstruction third jet (order 3).
    pub(crate) third: f64,
}
504
505pub(crate) fn pair_trig_decoder_sup(
506    sin_row: ArrayView1<'_, f64>,
507    cos_row: ArrayView1<'_, f64>,
508) -> f64 {
509    let aa = sin_row.dot(&sin_row);
510    let bb = cos_row.dot(&cos_row);
511    let ab = sin_row.dot(&cos_row);
512    let trace = aa + bb;
513    let disc = ((aa - bb) * (aa - bb) + 4.0 * ab * ab).sqrt();
514    (0.5 * (trace + disc)).sqrt()
515}
516
517pub(crate) fn periodic_reconstruction_jet_sups(
518    decoder: ArrayView2<'_, f64>,
519) -> ReconstructionJetSups {
520    let mut value = 0.0;
521    let mut jacobian = 0.0;
522    let mut hessian = 0.0;
523    let mut third = 0.0;
524    if decoder.nrows() > 0 {
525        value += decoder.row(0).dot(&decoder.row(0)).sqrt();
526    }
527    let harmonics = decoder.nrows().saturating_sub(1) / 2;
528    for h in 1..=harmonics {
529        let sin_idx = 2 * h - 1;
530        let cos_idx = 2 * h;
531        let amp = pair_trig_decoder_sup(decoder.row(sin_idx), decoder.row(cos_idx));
532        let omega = std::f64::consts::TAU * h as f64;
533        value += amp;
534        jacobian += omega * amp;
535        hessian += omega.powi(2) * amp;
536        third += omega.powi(3) * amp;
537    }
538    for row in (1 + 2 * harmonics)..decoder.nrows() {
539        let amp = decoder.row(row).dot(&decoder.row(row)).sqrt();
540        value += amp;
541        let omega = std::f64::consts::TAU * harmonics.max(1) as f64;
542        jacobian += omega * amp;
543        hessian += omega.powi(2) * amp;
544        third += omega.powi(3) * amp;
545    }
546    ReconstructionJetSups {
547        value,
548        jacobian,
549        hessian,
550        third,
551    }
552}
553
554pub(crate) fn reconstruction_jet_sups(
555    atom: &SaeManifoldAtom,
556    sups: JetSups,
557) -> ReconstructionJetSups {
558    if matches!(
559        atom.basis_kind,
560        crate::manifold::SaeAtomBasisKind::Periodic
561    ) {
562        periodic_reconstruction_jet_sups(atom.decoder_coefficients.view())
563    } else {
564        let decoder_norm_sum = decoder_row_norm_sum(atom.decoder_coefficients.view());
565        ReconstructionJetSups {
566            value: decoder_norm_sum * sups.value,
567            jacobian: decoder_norm_sum * sups.jacobian,
568            hessian: decoder_norm_sum * sups.hessian,
569            third: decoder_norm_sum * sups.third,
570        }
571    }
572}
573
574/// The Hessian-Lipschitz constant `L` of the per-row encode objective `f_k` on
575/// a chart, assembled in closed form from the basis jet sups and the decoder /
576/// amplitude / target magnitudes. See the module docs for the derivation:
577///
578/// ```text
579/// L ≤ 3·‖J_m‖·‖∇²m‖ + ‖r‖·‖∇³m‖ + L_prior,
580/// ‖∂^g m‖ ≤ |z|·S_B·B_g,   S_B = Σ_m ‖B_{m,:}‖,
581/// ‖r‖ ≤ ‖x‖ + |z|·S_B·B_0,
582/// ```
583///
584/// `prior_lipschitz` is the caller-supplied closed-form `L_prior` of the
585/// ARD/von-Mises coordinate prior (`0.0` if no prior is active on the encode).
586pub(crate) fn hessian_lipschitz_constant(
587    recon_sups: ReconstructionJetSups,
588    amplitude: f64,
589    target_norm: f64,
590    prior_lipschitz: f64,
591) -> f64 {
592    let z = amplitude.abs();
593    let m_jac = z * recon_sups.jacobian;
594    let m_hess = z * recon_sups.hessian;
595    let m_third = z * recon_sups.third;
596    let recon_value = z * recon_sups.value;
597    let r_norm = target_norm + recon_value;
598    3.0 * m_jac * m_hess + r_norm * m_third + prior_lipschitz
599}
600
/// One offline-certified chart: a center, its Kantorovich constants, and the
/// certified Newton-convergence radius `R_c` solved from `h = β·η·L ≤ ½` at the
/// worst-case in-chart start.
#[derive(Debug, Clone)]
pub struct CertifiedChart {
    /// The chart's coordinate region (center, radius, optional radial bounds).
    pub region: ChartRegion,
    /// Closed-form Hessian-Lipschitz constant `L` over the chart.
    pub lipschitz: f64,
    /// `β = ‖F'(t_c)⁻¹‖` at the chart center (worst-case in-chart start uses
    /// the center's curvature; the radius is solved so the certificate holds for
    /// any start in the ball).
    pub beta_center: f64,
    /// Certified Newton radius: starts within `certified_radius` of `t_c`
    /// satisfy `h ≤ ½`.
    pub certified_radius: f64,
    /// Distilled amortized-encoder Jacobian for this chart (#1026 ladder item 3).
    ///
    /// The exact encode map `x ↦ t` solves `F(t; x) = J_m(t)ᵀ(m(t) − x) = 0`. By
    /// the implicit function theorem its derivative at the converged root is
    /// `dt/dx = −(∂_t F)⁻¹ (∂_x F) = H⁻¹ J_mᵀ` (since `∂_x F = −J_mᵀ`), so the
    /// first-order Taylor expansion of the encode map about this chart's center
    /// `t_c` is the closed-form AFFINE predictor
    ///
    /// ```text
    /// t(x) ≈ t_c + (1/z) · A₁ · (x − z · m₁(t_c)),   A₁ = (J₁ᵀJ₁ + ridge·I)⁻¹ J₁ᵀ,
    /// ```
    ///
    /// with `J₁ = Bᵀ J_Φ(t_c)` (a `p × d` matrix) and `m₁(t_c) = BᵀΦ(t_c)` the
    /// AMPLITUDE-1 reconstruction jets (the amplitude `z` factors out
    /// analytically, so the stored Jacobian is amplitude-free). This is the
    /// DISTILLED amortized encoder of the #1026 thread: the per-row Hessian
    /// factorization + Newton iteration is moved OFFLINE into this `d × p`
    /// matrix, leaving a single `O(d·p)` mat-vec online — no per-row
    /// eigendecomposition, no second-jet evaluation. The Kantorovich certificate
    /// is still evaluated AT the predicted start, so the amortized prediction is
    /// trusted iff `h ≤ ½` and an uncertified row still routes to the exact
    /// multi-start solve (the encoder approximates inference, the certificate
    /// keeps it honest — the thread's "encoder + certificate-gated exact
    /// fallback" deployment). `None` when the center's Gauss–Newton block is
    /// singular (no certifiable amortization).
    pub amortized_jacobian: Option<Array2<f64>>,
    /// Amplitude-1 chart-center reconstruction `m₁(t_c) = BᵀΦ(t_c)` (length `p`),
    /// the anchor the amortized predictor expands the encode map around.
    pub recon_center: Array1<f64>,
}
644
/// The per-atom encode atlas: a set of certified charts covering the atom's
/// coordinate domain, plus the decoder/amplitude scaling needed to recompute a
/// per-row certificate online.
#[derive(Debug, Clone)]
pub struct AtomEncodeAtlas {
    /// Index of the atom this atlas belongs to (presumably its position in
    /// the dictionary — confirm against the atlas builder).
    pub atom_index: usize,
    /// Dimension `q` of the atom's latent coordinate `t`.
    pub latent_dim: usize,
    /// Decoder magnitude of the atom; presumably `Σ_m ‖B_{m,:}‖₂` as computed
    /// by `decoder_row_norm_sum` — confirm against the atlas builder.
    pub decoder_norm_sum: f64,
    /// The certified charts of this atom.
    pub charts: Vec<CertifiedChart>,
}
655
/// Result of a certified encode over a batch of rows, carrying the honesty
/// flag: how many rows could NOT be certified and were flagged for the exact
/// multi-start fallback (issue #1010 — no approximation enters silently).
///
/// Build it with `EncodeResult::from_rows`, which derives the count from the
/// flags so the two cannot disagree.
#[derive(Debug, Clone)]
pub struct EncodeResult {
    /// Per-row encoded latent coordinates (`n_rows × latent_dim`).
    pub coords: Array2<f64>,
    /// Per-row certificate: `true` ⇒ the row's start satisfied `h ≤ ½` and the
    /// 1–2 Newton steps are exact-into-the-certified-ball; `false` ⇒ flagged.
    pub certified: Vec<bool>,
    /// Count of rows that could not be certified. These ride the payload so the
    /// caller routes them to the exact multi-start encode — honesty, never
    /// silent. Equals `certified.iter().filter(|c| !**c).count()`.
    pub encode_uncertified_count: usize,
}
671
672impl EncodeResult {
673    pub(crate) fn from_rows(coords: Array2<f64>, certified: Vec<bool>) -> Self {
674        let encode_uncertified_count = certified.iter().filter(|c| !**c).count();
675        Self {
676            coords,
677            certified,
678            encode_uncertified_count,
679        }
680    }
681}
682
/// Per-row Kantorovich certificate at a start `t₀` for one atom encode.
#[derive(Debug, Clone, Copy)]
pub struct RowCertificate {
    /// `β = ‖F'(t₀)⁻¹‖`, the inverse-Hessian norm at the start.
    pub beta: f64,
    /// `η = ‖F'(t₀)⁻¹ F(t₀)‖`, the length of the first Newton step.
    pub eta: f64,
    /// Hessian-Lipschitz constant `L` used for this row.
    pub lipschitz: f64,
    /// `h = β·η·L`. The row is certified iff `h` is finite and `h ≤ ½`.
    pub h: f64,
}
692
693impl RowCertificate {
694    pub fn certified(&self) -> bool {
695        self.h.is_finite() && self.h <= KANTOROVICH_THRESHOLD
696    }
697}
698
/// Result of a certified encode probe: the refined coordinate together with
/// the certificates at both ends of the refinement.
#[derive(Debug, Clone)]
struct CertifiedEncodeProbe {
    /// Coordinate after the Newton refinement steps.
    coord: Array1<f64>,
    /// Certificate at the certified start the refinement began from.
    initial_cert: RowCertificate,
    /// Certificate at the last refined iterate (equal to `initial_cert` when
    /// no refinement step was taken).
    final_cert: RowCertificate,
}
705
/// Canonical flat-axis polynomial degree of a cylinder `S¹ × ℝ` atom — the
/// degree the topology-race builder ([`gam_solve::structure_harvest`]) uses
/// for the line axis (`CylinderHarmonicEvaluator::new(_, 2)`). The encode atlas
/// recovers the circle harmonic count from the basis width using this degree, so
/// the two must agree.
///
/// With degree `D` the line axis contributes `D + 1` columns, so a cylinder
/// basis width must be a multiple of `SAE_CYLINDER_LINE_DEGREE + 1`.
pub(crate) const SAE_CYLINDER_LINE_DEGREE: usize = 2;
712
713/// Build a basis-family handle for one atom from its [`SaeManifoldAtom`]. The
714/// atlas needs to evaluate the jet sups, which live on the concrete evaluator
715/// types; the atom carries the evaluator as `Arc<dyn SaeBasisEvaluator>`, so we
716/// reconstruct the family bound from the atom's basis kind + width + centers.
717pub(crate) fn family_jet_sups(
718    atom: &SaeManifoldAtom,
719    chart: &ChartRegion,
720) -> Result<JetSups, String> {
721    use crate::manifold::SaeAtomBasisKind::*;
722    let m = atom.basis_size();
723    let d = atom.latent_dim;
724    let sups = match &atom.basis_kind {
725        Periodic => {
726            let ev = PeriodicHarmonicEvaluator::new(m)?;
727            JetSups::from_family(&ev, chart)
728        }
729        Torus => {
730            // Torus basis width is `(2H+1)^d`; recover the per-axis harmonic
731            // count `H` from `axis_m = m^(1/d)` rather than a sum formula.
732            let axis_m = integer_root(m, d.max(1));
733            let num_harmonics = axis_m.saturating_sub(1) / 2;
734            let ev = TorusHarmonicEvaluator::new(d, num_harmonics.max(1))?;
735            JetSups::from_family(&ev, chart)
736        }
737        Sphere => {
738            let ev = SphereChartEvaluator;
739            JetSups::from_family(&ev, chart)
740        }
741        Cylinder => {
742            // Cylinder width is `(2H+1)·(D+1)` with the canonical flat-axis
743            // degree `D = SAE_CYLINDER_LINE_DEGREE` (the harvest convention).
744            // Recover the per-axis circle harmonic count `H` from
745            // `2H+1 = m/(D+1)`.
746            let ml = SAE_CYLINDER_LINE_DEGREE + 1;
747            if d != 2 || ml == 0 || m % ml != 0 {
748                return Err(format!(
749                    "EncodeAtlas: Cylinder atom requires latent_dim == 2 and width divisible by {ml}; got dim={d}, m={m}"
750                ));
751            }
752            let axis_mc = m / ml;
753            let h = axis_mc.saturating_sub(1) / 2;
754            let ev = CylinderHarmonicEvaluator::new(h.max(1), SAE_CYLINDER_LINE_DEGREE)?;
755            JetSups::from_family(&ev, chart)
756        }
757        Linear | EuclideanPatch | Poincare => {
758            // The patch width fixes max_degree implicitly; bound by a degree that
759            // covers the column count (conservative). Degree d-patch column count
760            // grows fast; we recover the smallest degree whose patch is ≥ m.
761            // Poincare atoms use the same tangent-coordinate polynomial decoder;
762            // their intrinsic smoothness differs in the penalty, not in Phi(t).
763            let degree = euclidean_patch_degree(d, m);
764            let ev = EuclideanPatchEvaluator::new(d, degree)?;
765            JetSups::from_family(&ev, chart)
766        }
767        Duchon => {
768            // The atom carries the basis kind but not the nullspace order, and
769            // the certificate needs an UPPER bound on L. The kernel-tail bound
770            // (cubic r³ coefficients vs the chart's r_min/r_max) is independent
771            // of the constructed order; the polynomial-block bound grows with the
772            // order, so we construct with a conservative order whose polynomial
773            // degree upper-bounds any nullspace the atom's basis width can hold.
774            // Constructing with `m = basis_size` maps to `Degree(basis_size − 1)`
775            // — an over-estimate that keeps the Lipschitz bound sound.
776            let centers = duchon_centers_from_atom(atom);
777            let conservative_m = m.max(1);
778            let ev = DuchonCoordinateEvaluator::new(centers, conservative_m)?;
779            JetSups::from_family(&ev, chart)
780        }
781        Precomputed(name) => {
782            return Err(format!(
783                "EncodeAtlas: precomputed basis '{name}' has no closed-form jet sup; route to exact encode"
784            ));
785        }
786    };
787    Ok(sups)
788}
789
/// Smallest monomial-patch degree whose column count covers `m` basis columns.
///
/// A degree-`D` patch in `d = latent_dim` variables has `C(d + D, D)` columns.
/// The count is advanced with the exact recurrence
/// `C(d + D, D) = C(d + D − 1, D − 1) · (d + D) / D`, so each step costs O(1)
/// and no factorial-sized intermediate is ever formed (a from-scratch product
/// overflows `u128` already for a 1-D patch with a few dozen columns).
///
/// The degree is capped at `m` so a degenerate width (e.g. `latent_dim == 0`,
/// whose patch always has one column) still terminates. Returns `0` for
/// `m ≤ 1`.
pub(crate) fn euclidean_patch_degree(latent_dim: usize, m: usize) -> usize {
    let target = m as u128;
    let mut degree = 0usize;
    // C(d, 0) = 1.
    let mut columns: u128 = 1;
    while columns < target && degree < m {
        degree += 1;
        let factor = latent_dim as u128 + degree as u128;
        // `columns < m ≤ usize::MAX` here, so the product overflows only for
        // absurd inputs; saturating then still ends the loop (count ≥ m).
        columns = columns
            .checked_mul(factor)
            .map_or(u128::MAX, |scaled| scaled / degree as u128);
    }
    degree
}
800
/// Largest integer `a` with `a^k ≤ n` (the floor of the `k`-th root). Used to
/// recover the per-axis harmonic width `axis_m` from a torus basis width
/// `m = axis_m^d`.
///
/// Special cases: `k == 0` returns `1` (the root is undefined; `1` is the
/// neutral per-axis width), `k == 1` returns `n`, and `n < 2` returns `n`
/// itself (so `integer_root(0, k) == 0`, honouring `a^k ≤ n`).
///
/// The root is located by bisection, so the cost is logarithmic in `n` rather
/// than linear in the root.
pub(crate) fn integer_root(n: usize, k: usize) -> usize {
    if k == 0 {
        return 1;
    }
    if k == 1 || n < 2 {
        return n;
    }
    let limit = n as u128;
    // `base^k ≤ n`, evaluated without overflow. Only ever called with
    // `base ≥ 2`, so the running power at least doubles per factor and the
    // loop exits after at most ~64 factors however large `k` is.
    let power_fits = |base: usize| -> bool {
        let mut pow: u128 = 1;
        for _ in 0..k {
            pow = pow.saturating_mul(base as u128);
            if pow > limit {
                return false;
            }
        }
        true
    };
    // Invariant: `lo^k ≤ n < hi^k`. It holds initially because `1 ≤ n` and,
    // for `n ≥ 2`, `k ≥ 2`, `n^k ≥ 2n > n`.
    let (mut lo, mut hi) = (1usize, n);
    while hi - lo > 1 {
        let mid = lo + (hi - lo) / 2;
        if power_fits(mid) {
            lo = mid;
        } else {
            hi = mid;
        }
    }
    lo
}
829
/// Column count `C(d + D, D)` of a degree-`D` monomial patch in
/// `d = latent_dim` variables, saturating at `usize::MAX`.
///
/// The binomial is built with the exact recurrence
/// `C(a + i, i) = C(a + i − 1, i − 1) · (a + i) / i` over the SMALLER of
/// `d` and `D` (using `C(d + D, D) = C(d + D, d)`), so the intermediate never
/// exceeds the result times one factor. Forming the numerator and denominator
/// products separately overflows `u128` long before the quotient does.
pub(crate) fn patch_column_count(latent_dim: usize, degree: usize) -> usize {
    let ceiling = usize::MAX as u128;
    let steps = latent_dim.min(degree);
    let base = latent_dim.max(degree) as u128;
    let mut count: u128 = 1;
    for i in 1..=steps {
        let factor = base + i as u128;
        count = match count.checked_mul(factor) {
            // Exact: the product equals `i · C(base + i, i)`.
            Some(scaled) => scaled / i as u128,
            None => return usize::MAX,
        };
        // The partial binomials are non-decreasing in `i`, so once one exceeds
        // the return type's range the final value does too.
        if count > ceiling {
            return usize::MAX;
        }
    }
    count as usize
}
840
841/// Recover Duchon centers from an atom: when the evaluator is unavailable the
842/// atlas falls back to the atom's own latent-coordinate hull as the center set,
843/// which only affects the radial-tail bound conservatively.
844pub(crate) fn duchon_centers_from_atom(atom: &SaeManifoldAtom) -> Array2<f64> {
845    // One center at the origin in latent_dim space is a sound conservative
846    // default: the chart's own r_min / r_max bracket the true radial range.
847    Array2::<f64>::zeros((1, atom.latent_dim.max(1)))
848}
849
850/// The four per-column jet sups of a basis family over a chart.
851#[derive(Debug, Clone, Copy)]
852pub(crate) struct JetSups {
853    pub(crate) value: f64,
854    pub(crate) jacobian: f64,
855    pub(crate) hessian: f64,
856    pub(crate) third: f64,
857}
858
859impl JetSups {
860    pub(crate) fn from_family<B: BasisHessianLipschitz>(family: &B, chart: &ChartRegion) -> Self {
861        Self {
862            value: family.value_sup(chart),
863            jacobian: family.jacobian_sup(chart),
864            hessian: family.hessian_sup(chart),
865            third: family.third_sup(chart),
866        }
867    }
868}
869
870/// Evaluate one atom's encode objective gradient `F(t) = ∇f_k(t)` and the FULL
871/// Hessian `F'(t) = ∇²f_k(t)` at a single coordinate `t`, for a single target
872/// row `x` and fixed amplitude `z`. With `m(t) = z·BᵀΦ(t)`, `r = m − x`,
873/// `J_m = z·Bᵀ J_Φ`:
874///
875/// ```text
876/// g_t[a]   = J_m[a] · r                                  (= ∇f)
877/// H_tt[a,b] = J_m[a] · J_m[b] + r · ∂²m/∂t_a∂t_b         (= ∇²f, FULL Hessian)
878/// ```
879///
880/// The certificate uses the FULL Hessian rather than the Gauss-Newton block
881/// `J_mᵀ J_m`. This is the principled choice for Newton–Kantorovich: the
882/// theorem certifies convergence of Newton on `F = ∇f` to the unique nearby
883/// ROOT of `∇f`, but a root of `∇f` can be a maximum. The full Hessian is
884/// positive-definite exactly on the genuine-minimum basin, so requiring
885/// `λ_min(H) > 0` (finite `β`) is what flags a start that would otherwise let
886/// Gauss-Newton march into the wrong root (e.g. the circle antipode, a local
887/// max where `∇f = 0` but the full curvature is negative). The residual term
888/// needs the basis second jet `∂²Φ/∂t²`; an evaluator without one returns
889/// `None`, and the row is flagged (no silent Gauss-Newton fallback).
890pub(crate) fn encode_grad_hess(
891    atom: &SaeManifoldAtom,
892    evaluator: &dyn SaeBasisEvaluator,
893    t: ArrayView1<'_, f64>,
894    x: ArrayView1<'_, f64>,
895    amplitude: f64,
896    ridge: f64,
897) -> Result<Option<(Array1<f64>, Array2<f64>)>, String> {
898    let d = atom.latent_dim;
899    let p = atom.output_dim();
900    let m = atom.basis_size();
901    let coords = t.to_shape((1, d)).map_err(|e| e.to_string())?.to_owned();
902    let (phi, jet) = evaluator.evaluate(coords.view())?;
903    if phi.dim() != (1, m) {
904        return Err(format!(
905            "encode_grad_hess: evaluator returned phi {:?}, expected (1, {m})",
906            phi.dim()
907        ));
908    }
909    let decoder = &atom.decoder_coefficients;
910    // Reconstruction m(t) = z · Bᵀ Φ(t)  ∈ ℝᵖ.
911    let mut recon = Array1::<f64>::zeros(p);
912    for basis_col in 0..m {
913        let phi_v = phi[[0, basis_col]];
914        if phi_v == 0.0 {
915            continue;
916        }
917        for out in 0..p {
918            recon[out] += amplitude * phi_v * decoder[[basis_col, out]];
919        }
920    }
921    let residual = &recon - &x;
922    // J_m[axis] = z · Bᵀ (∂Φ/∂t_axis)  ∈ ℝᵖ.
923    let mut jm = Array2::<f64>::zeros((d, p));
924    for axis in 0..d {
925        for basis_col in 0..m {
926            let dphi = jet[[0, basis_col, axis]];
927            if dphi == 0.0 {
928                continue;
929            }
930            for out in 0..p {
931                jm[[axis, out]] += amplitude * dphi * decoder[[basis_col, out]];
932            }
933        }
934    }
935    // The full-Hessian residual term needs ∂²Φ/∂t². No second jet ⇒ no
936    // certificate (flag), never a silent Gauss-Newton substitute.
937    let second = match evaluator.second_jet_dyn(coords.view()) {
938        Some(result) => result?,
939        None => return Ok(None),
940    };
941    // g_t[axis] = J_m[axis] · r ;  H_tt[a,b] = J_m[a]·J_m[b] + r·∂²m/∂t_a∂t_b.
942    let mut g = Array1::<f64>::zeros(d);
943    let mut h = Array2::<f64>::zeros((d, d));
944    for a in 0..d {
945        let ja = jm.row(a);
946        g[a] = ja.dot(&residual);
947        for b in 0..d {
948            // Gauss-Newton block.
949            let mut hab = ja.dot(&jm.row(b));
950            // Residual · second-jet curvature: r · ∂²m_{ab},
951            // ∂²m_{ab}[out] = z · Σ_basis (∂²Φ/∂t_a∂t_b) · B[basis, out].
952            let mut curv = 0.0;
953            for basis_col in 0..m {
954                let d2phi = second[[0, basis_col, a, b]];
955                if d2phi == 0.0 {
956                    continue;
957                }
958                let mut dot = 0.0;
959                for out in 0..p {
960                    dot += residual[out] * decoder[[basis_col, out]];
961                }
962                curv += amplitude * d2phi * dot;
963            }
964            hab += curv;
965            h[[a, b]] = hab;
966        }
967    }
968    for a in 0..d {
969        h[[a, a]] += ridge;
970    }
971    Ok(Some((g, h)))
972}
973
974/// Operator-norm of `H⁻¹` (i.e. `β = 1/λ_min(H)`) and the Newton step
975/// `δ = −H⁻¹ g` with `η = ‖δ‖`, from a symmetric PSD `H` and gradient `g`.
976/// Returns `None` when `H` is numerically singular (λ_min ≤ 0) — an
977/// uncertifiable start.
978pub(crate) fn beta_eta_newton(
979    h: ArrayView2<'_, f64>,
980    g: ArrayView1<'_, f64>,
981) -> Result<Option<(f64, f64, Array1<f64>)>, String> {
982    let (vals, vecs) = h
983        .eigh(Side::Lower)
984        .map_err(|e| format!("beta_eta_newton: eigh failed: {e:?}"))?;
985    let lambda_min = vals.iter().cloned().fold(f64::INFINITY, f64::min);
986    if !(lambda_min.is_finite() && lambda_min > 0.0) {
987        return Ok(None);
988    }
989    let beta = 1.0 / lambda_min;
990    // Newton step δ = −H⁻¹ g via the eigendecomposition: δ = −Σ_i (vᵢᵀg/λᵢ) vᵢ.
991    let d = h.nrows();
992    let mut delta = Array1::<f64>::zeros(d);
993    for (col, &lam) in vals.iter().enumerate() {
994        if lam <= 0.0 {
995            return Ok(None);
996        }
997        let vi = vecs.column(col);
998        let coeff = vi.dot(&g) / lam;
999        for row in 0..d {
1000            delta[row] -= coeff * vi[row];
1001        }
1002    }
1003    let eta = delta.dot(&delta).sqrt();
1004    Ok(Some((beta, eta, delta)))
1005}
1006
1007/// Compute the per-row Kantorovich certificate for encoding target row `x`
1008/// against atom `atom` at start coordinate `t₀`, with fixed amplitude `z` and
1009/// the chart's closed-form Lipschitz constant `lipschitz`. Returns the
1010/// certificate AND the Newton step `δ = −H⁻¹ g` so the caller can advance.
1011pub fn row_certificate(
1012    atom: &SaeManifoldAtom,
1013    evaluator: &dyn SaeBasisEvaluator,
1014    t0: ArrayView1<'_, f64>,
1015    x: ArrayView1<'_, f64>,
1016    amplitude: f64,
1017    lipschitz: f64,
1018    ridge: f64,
1019) -> Result<(RowCertificate, Array1<f64>), String> {
1020    let uncertified = || {
1021        (
1022            RowCertificate {
1023                beta: f64::INFINITY,
1024                eta: f64::INFINITY,
1025                lipschitz,
1026                h: f64::INFINITY,
1027            },
1028            Array1::<f64>::zeros(atom.latent_dim),
1029        )
1030    };
1031    // No second jet ⇒ no full Hessian ⇒ uncertifiable (flag).
1032    let Some((g, h)) = encode_grad_hess(atom, evaluator, t0, x, amplitude, ridge)? else {
1033        return Ok(uncertified());
1034    };
1035    match beta_eta_newton(h.view(), g.view())? {
1036        Some((beta, eta, delta)) => {
1037            let cert = RowCertificate {
1038                beta,
1039                eta,
1040                lipschitz,
1041                h: beta * eta * lipschitz,
1042            };
1043            Ok((cert, delta))
1044        }
1045        // Indefinite / negative-curvature full Hessian: the start is at or past
1046        // a basin boundary (a max/saddle of f), not the minimum basin — flag.
1047        None => Ok(uncertified()),
1048    }
1049}
1050
1051fn uncertified_certificate(lipschitz: f64) -> RowCertificate {
1052    RowCertificate {
1053        beta: f64::INFINITY,
1054        eta: f64::INFINITY,
1055        lipschitz,
1056        h: f64::INFINITY,
1057    }
1058}
1059
1060fn refine_certified_start(
1061    atom: &SaeManifoldAtom,
1062    evaluator: &dyn SaeBasisEvaluator,
1063    mut t: Array1<f64>,
1064    x: ArrayView1<'_, f64>,
1065    amplitude: f64,
1066    lipschitz: f64,
1067    ridge: f64,
1068    newton_steps: usize,
1069    initial_cert: RowCertificate,
1070    mut delta: Array1<f64>,
1071) -> Result<Option<CertifiedEncodeProbe>, String> {
1072    assert!(initial_cert.certified());
1073    let mut final_cert = initial_cert;
1074    for _ in 0..newton_steps {
1075        t = &t + &delta;
1076        let (cert, next_delta) =
1077            row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz, ridge)?;
1078        if !cert.certified() {
1079            return Ok(None);
1080        }
1081        final_cert = cert;
1082        delta = next_delta;
1083    }
1084    Ok(Some(CertifiedEncodeProbe {
1085        coord: t,
1086        initial_cert,
1087        final_cert,
1088    }))
1089}
1090
1091/// Certify an encode probe from `t_start`, navigating into the Kantorovich basin
1092/// first if needed (#1154/#1026). The Kantorovich quantity `h = β·η·L` scales with
1093/// amplitude through `L`, so at unit amplitude a positive-definite chart-center /
1094/// distilled start can sit OUTSIDE the certified ball (`h > ½`). Rather than
1095/// flagging it uncertified immediately — which made the encoder certify ZERO
1096/// held-out rows at amplitude 1.0 and fall back to the exact solve for everything —
1097/// take plain Newton steps toward the root, re-certifying at each iterate, while
1098/// the Kantorovich quantity `h = β·η·L` keeps CONTRACTING toward the ½ bound. The
1099/// certificate at the landing point is a full Kantorovich guarantee from there
1100/// (`h ≤ ½` ⇒ Newton converges to the in-ball root), so this only ever WIDENS the
1101/// certified set; it never certifies a non-convergent start.
1102///
1103/// Termination is the natural Newton stopping rule — there is no arbitrary step
1104/// budget. The warm-up stops and flags for the exact fallback when either the start
1105/// is not steppable (indefinite / non-finite Hessian — at or past a basin boundary)
1106/// or a step fails to reduce `h` (the iterate is not approaching a certifiable
1107/// in-chart root: its root lies outside this chart's valid Lipschitz region, or the
1108/// start was past the basin — empirically the rows that miss *plateau*, so more
1109/// steps cannot help; the lever there is denser charts, not more iterations). On
1110/// success the start is refined `newton_steps` further by [`refine_certified_start`].
1111fn certify_with_basin_warmup(
1112    atom: &SaeManifoldAtom,
1113    evaluator: &dyn SaeBasisEvaluator,
1114    t_start: Array1<f64>,
1115    x: ArrayView1<'_, f64>,
1116    amplitude: f64,
1117    lipschitz: f64,
1118    ridge: f64,
1119    newton_steps: usize,
1120    chart_center: ArrayView1<'_, f64>,
1121    chart_radius: f64,
1122) -> Result<Option<CertifiedEncodeProbe>, String> {
1123    // SOUNDNESS GUARD: `lipschitz` is the chart's Hessian-Lipschitz sup, which is
1124    // only a valid bound over this chart's ball `‖t − center‖ ≤ radius` for the
1125    // chart-local families (`EuclideanPatch`/`Linear`/`Poincare` monomial patches,
1126    // `Cylinder` line axis, `Duchon` radial kernels). If a warm-up iterate leaves
1127    // that ball, `row_certificate` would compute `h = β·η·L` with an `L` that no
1128    // longer bounds the true geometry there, so `h ≤ ½` would NOT imply Kantorovich
1129    // convergence — a false certificate. (The `h`-contraction check does NOT catch
1130    // this: `h` can decrease monotonically toward an out-of-chart root the whole
1131    // way.) So we keep every certified iterate inside the chart; a row whose root is
1132    // outside this chart flags for the exact fallback — its lever is a denser grid,
1133    // not a step using an invalid `L`. Global-`L` families (periodic/torus/sphere)
1134    // route their points to charts whose centers are near the root, so the guard
1135    // rarely trips for them, and where it does the row was out-of-chart anyway.
1136    let in_chart = |t: &Array1<f64>| -> bool {
1137        let r2: f64 = t
1138            .iter()
1139            .zip(chart_center.iter())
1140            .map(|(a, b)| (a - b) * (a - b))
1141            .sum();
1142        r2 <= chart_radius * chart_radius
1143    };
1144    let mut t = t_start;
1145    // The distilled / chart-center start must itself be in-chart for its certificate
1146    // to be valid; a bad IFT prediction landing outside the chart is uncertifiable.
1147    if !in_chart(&t) {
1148        return Ok(None);
1149    }
1150    let (mut cert, mut delta) =
1151        row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz, ridge)?;
1152    while !cert.certified() {
1153        // Not steppable (indefinite / non-finite Hessian): flag.
1154        if !(cert.h.is_finite() && cert.beta.is_finite() && cert.eta.is_finite()) {
1155            return Ok(None);
1156        }
1157        let prev_h = cert.h;
1158        let next = &t + &delta;
1159        // Refuse to step where the chart's `L` is no longer valid (see guard above).
1160        if !in_chart(&next) {
1161            return Ok(None);
1162        }
1163        t = next;
1164        let (next_cert, next_delta) =
1165            row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz, ridge)?;
1166        cert = next_cert;
1167        delta = next_delta;
1168        // The warm-up only helps while h keeps contracting toward ½. Once a step
1169        // fails to reduce it, the iterate is not converging to a certifiable in-chart
1170        // root — flag for the exact fallback (no arbitrary step budget).
1171        if !cert.h.is_finite() || cert.h >= prev_h {
1172            return Ok(None);
1173        }
1174    }
1175    refine_certified_start(
1176        atom,
1177        evaluator,
1178        t,
1179        x,
1180        amplitude,
1181        lipschitz,
1182        ridge,
1183        newton_steps,
1184        cert,
1185        delta,
1186    )
1187}
1188
1189fn kantorovich_root_radius(cert: RowCertificate) -> f64 {
1190    if !cert.certified() || !(cert.eta.is_finite() && cert.eta >= 0.0) {
1191        return f64::INFINITY;
1192    }
1193    if cert.eta == 0.0 {
1194        return 0.0;
1195    }
1196    if !(cert.h.is_finite() && cert.h >= 0.0) {
1197        return f64::INFINITY;
1198    }
1199    let h = cert.h.min(KANTOROVICH_THRESHOLD);
1200    let discriminant = (1.0 - 2.0 * h).max(0.0).sqrt();
1201    let radius = 2.0 * cert.eta / (1.0 + discriminant);
1202    if radius.is_finite() {
1203        radius
1204    } else {
1205        f64::INFINITY
1206    }
1207}
1208
1209fn distilled_probe_tolerance(
1210    amortized: &CertifiedEncodeProbe,
1211    cold: &CertifiedEncodeProbe,
1212    amplitude: f64,
1213    x: ArrayView1<'_, f64>,
1214) -> f64 {
1215    let certified_radius =
1216        kantorovich_root_radius(amortized.final_cert) + kantorovich_root_radius(cold.final_cert);
1217    let coord_scale = amortized.coord.dot(&amortized.coord).sqrt()
1218        + cold.coord.dot(&cold.coord).sqrt()
1219        + x.dot(&x).sqrt()
1220        + amplitude.abs()
1221        + 1.0;
1222    certified_radius + 1024.0 * f64::EPSILON * coord_scale
1223}
1224
1225fn latent_coordinate_distance(
1226    atom: &SaeManifoldAtom,
1227    lhs: ArrayView1<'_, f64>,
1228    rhs: ArrayView1<'_, f64>,
1229) -> f64 {
1230    let mut acc = 0.0;
1231    for axis in 0..lhs.len().min(rhs.len()) {
1232        let mut diff = (lhs[axis] - rhs[axis]).abs();
1233        if let Some(period) = latent_axis_period(atom, axis) {
1234            let wrapped = diff.rem_euclid(period);
1235            diff = wrapped.min(period - wrapped);
1236        }
1237        acc += diff * diff;
1238    }
1239    acc.sqrt()
1240}
1241
1242fn latent_axis_period(atom: &SaeManifoldAtom, axis: usize) -> Option<f64> {
1243    use crate::manifold::SaeAtomBasisKind::*;
1244    match &atom.basis_kind {
1245        Periodic | Torus => Some(1.0),
1246        Cylinder if axis == 0 => Some(1.0),
1247        Sphere if axis == 1 => Some(std::f64::consts::TAU),
1248        _ => None,
1249    }
1250}
1251
/// Settings for building an [`EncodeAtlas`] and for the online encode. Every
/// knob is an explicit field: the atlas reads no global state and adds no CLI
/// flags.
#[derive(Debug, Clone, Copy)]
pub struct AtlasConfig {
    /// Number of grid points per latent axis when laying down the offline
    /// chart centers (the SHAPE_BAND grid idiom).
    pub grid_resolution: usize,
    /// Levenberg ridge added to the diagonal of the per-row encode Hessian
    /// before it is inverted.
    pub ridge: f64,
    /// How many online Newton refinement steps follow a certified start
    /// (1 or 2 per issue #1010).
    pub newton_steps: usize,
}

impl Default for AtlasConfig {
    /// A 16-point grid per axis, a `1e-9` ridge, and two refinement steps.
    fn default() -> Self {
        let grid_resolution = 16;
        let ridge = 1.0e-9;
        let newton_steps = 2;
        Self {
            grid_resolution,
            ridge,
            newton_steps,
        }
    }
}
1275
/// The encode atlas: per-atom certified charts plus the online certified-encode
/// driver (issue #1010).
#[derive(Debug, Clone)]
pub struct EncodeAtlas {
    /// One atlas per dictionary atom, in the order of the `atoms` slice given
    /// to `build`; the online encode looks an atom up by that index.
    pub atoms: Vec<AtomEncodeAtlas>,
    /// The configuration the atlas was built with; its `ridge` and
    /// `newton_steps` are also read by the online encode.
    pub config: AtlasConfig,
}
1283
1284impl EncodeAtlas {
1285    /// Build the offline atlas over a frozen dictionary: for each atom, lay down
1286    /// chart centers on the atom's coordinate grid and certify a Newton radius
1287    /// from the Kantorovich inequality at the worst-case in-chart start.
1288    ///
1289    /// `amplitude_bound[k]` is the per-atom bound on `|z_k|` used to scale the
1290    /// reconstruction jets (the offline `L` must hold for the largest amplitude
1291    /// the encode can produce); `target_norm_bound` bounds `‖x‖` over the data.
1292    pub fn build(
1293        atoms: &[SaeManifoldAtom],
1294        amplitude_bound: &[f64],
1295        target_norm_bound: f64,
1296        config: AtlasConfig,
1297    ) -> Result<Self, String> {
1298        if amplitude_bound.len() != atoms.len() {
1299            return Err(format!(
1300                "EncodeAtlas::build: amplitude_bound length {} != atom count {}",
1301                amplitude_bound.len(),
1302                atoms.len()
1303            ));
1304        }
1305        let mut atom_atlases = Vec::with_capacity(atoms.len());
1306        for (k, atom) in atoms.iter().enumerate() {
1307            let atlas =
1308                Self::build_atom_atlas(k, atom, amplitude_bound[k], target_norm_bound, &config)?;
1309            atom_atlases.push(atlas);
1310        }
1311        Ok(Self {
1312            atoms: atom_atlases,
1313            config,
1314        })
1315    }
1316
1317    pub(crate) fn build_atom_atlas(
1318        atom_index: usize,
1319        atom: &SaeManifoldAtom,
1320        amplitude_bound: f64,
1321        target_norm_bound: f64,
1322        config: &AtlasConfig,
1323    ) -> Result<AtomEncodeAtlas, String> {
1324        let d = atom.latent_dim;
1325        let decoder_norm_sum = decoder_row_norm_sum(atom.decoder_coefficients.view());
1326        let centers = chart_center_grid(atom, config.grid_resolution);
1327        // Half the inter-center spacing is the natural in-chart radius so the
1328        // charts tile the grid without gaps; refined below if the certificate
1329        // fails at that radius.
1330        let nominal_radius = chart_nominal_radius(atom, config.grid_resolution);
1331        let mut charts = Vec::with_capacity(centers.nrows());
1332        for c in 0..centers.nrows() {
1333            let center = centers.row(c).to_owned();
1334            let region = chart_region(atom, center.clone(), nominal_radius);
1335            let sups = family_jet_sups(atom, &region)?;
1336            let recon_sups = reconstruction_jet_sups(atom, sups);
1337            let lipschitz =
1338                hessian_lipschitz_constant(recon_sups, amplitude_bound, target_norm_bound, 0.0);
1339            // β at the chart center bounds the worst-case in-chart curvature
1340            // (the Gauss-Newton Hessian is continuous; the certified radius is
1341            // solved so the certificate is robust to the start within the ball).
1342            let beta_center = match center_beta(atom, &center, config.ridge) {
1343                Some(b) => b,
1344                None => {
1345                    // Degenerate center curvature: no certifiable chart here, and
1346                    // no amortized Jacobian (the same singular Gauss–Newton block).
1347                    charts.push(CertifiedChart {
1348                        region,
1349                        lipschitz,
1350                        beta_center: f64::INFINITY,
1351                        certified_radius: 0.0,
1352                        amortized_jacobian: None,
1353                        recon_center: Array1::<f64>::zeros(atom.output_dim()),
1354                    });
1355                    continue;
1356                }
1357            };
1358            // Distill the amortized-encoder Jacobian at this center (#1026 ladder
1359            // item 3): the IFT derivative of the encode map, precomputed offline
1360            // so the online encode is one mat-vec. A finite `beta_center` (above)
1361            // means the Gauss–Newton block is non-singular, so this succeeds
1362            // alongside it; the pair travels together on the chart.
1363            let (amortized_jacobian, recon_center) =
1364                match center_amortized_jacobian(atom, &center, config.ridge) {
1365                    Some((a1, m1)) => (Some(a1), m1),
1366                    None => (None, Array1::<f64>::zeros(atom.output_dim())),
1367                };
1368            // Certified radius from h = β·η·L ≤ ½ with η ≤ R (Newton step length
1369            // is bounded by the start distance to the root, itself ≤ chart
1370            // radius at worst): R_c = ½ / (β·L), capped at the nominal radius.
1371            let certified_radius = if lipschitz > 0.0 && beta_center.is_finite() {
1372                (0.5 / (beta_center * lipschitz)).min(region.radius)
1373            } else {
1374                region.radius
1375            };
1376            charts.push(CertifiedChart {
1377                region,
1378                lipschitz,
1379                beta_center,
1380                certified_radius,
1381                amortized_jacobian,
1382                recon_center,
1383            });
1384        }
1385        Ok(AtomEncodeAtlas {
1386            atom_index,
1387            latent_dim: d,
1388            decoder_norm_sum,
1389            charts,
1390        })
1391    }
1392
1393    fn refine_certified_encode_start(
1394        &self,
1395        atom: &SaeManifoldAtom,
1396        evaluator: &dyn SaeBasisEvaluator,
1397        chart: &CertifiedChart,
1398        t: Array1<f64>,
1399        x: ArrayView1<'_, f64>,
1400        amplitude: f64,
1401    ) -> Result<(Array1<f64>, RowCertificate), String> {
1402        // Certify from the warm start, navigating into the Kantorovich basin first
1403        // if the unit-amplitude start has h > ½ (see `certify_with_basin_warmup`).
1404        let Some(probe) = certify_with_basin_warmup(
1405            atom,
1406            evaluator,
1407            t,
1408            x,
1409            amplitude,
1410            chart.lipschitz,
1411            self.config.ridge,
1412            self.config.newton_steps,
1413            chart.region.center.view(),
1414            chart.region.radius,
1415        )?
1416        else {
1417            return Ok((
1418                Array1::<f64>::zeros(atom.latent_dim),
1419                uncertified_certificate(chart.lipschitz),
1420            ));
1421        };
1422        Ok((probe.coord, probe.initial_cert))
1423    }
1424
1425    /// Online certified encode of one target row `x` against one atom `k` with
1426    /// fixed amplitude `z`. Routes to the nearest chart, starts from that chart's
1427    /// distilled IFT warm start, runs `config.newton_steps` Newton steps, and
1428    /// returns the encoded coordinate with its certificate. An uncertified start
1429    /// (no chart, no distilled Jacobian, non-positive amplitude, or `h > ½`)
1430    /// flags the row for the exact multi-start caller.
1431    pub fn certified_encode_row(
1432        &self,
1433        atom: &SaeManifoldAtom,
1434        atom_index: usize,
1435        x: ArrayView1<'_, f64>,
1436        amplitude: f64,
1437    ) -> Result<(Array1<f64>, RowCertificate), String> {
1438        let atom_atlas = self
1439            .atoms
1440            .get(atom_index)
1441            .ok_or_else(|| format!("certified_encode_row: atom {atom_index} not in atlas"))?;
1442        let d = atom.latent_dim;
1443        // A missing basis evaluator means the amortized/cold predictor cannot fire
1444        // for this atom (e.g. a frozen-baseline or first-build atom that never
1445        // attached a distilled evaluator). That is exactly the "cannot certify"
1446        // state — flag the row uncertified (zeros coords, ∞ certificate) so the
1447        // upstream exact multi-start solve owns it, never a hard error that aborts
1448        // the whole criterion. Mirrors the no-chart / singular-Jacobian branches.
1449        let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
1450            return Ok((
1451                Array1::<f64>::zeros(d),
1452                RowCertificate {
1453                    beta: f64::INFINITY,
1454                    eta: f64::INFINITY,
1455                    lipschitz: f64::INFINITY,
1456                    h: f64::INFINITY,
1457                },
1458            ));
1459        };
1460
1461        // Route to the nearest chart center (ambient routing happens upstream via
1462        // sae_candidate_index; here we pick the in-atom chart). Without a chart we
1463        // cannot certify — flag immediately.
1464        let Some((chart_idx, _)) = nearest_chart(atom_atlas, x, atom, evaluator.as_ref()) else {
1465            return Ok((
1466                Array1::<f64>::zeros(d),
1467                RowCertificate {
1468                    beta: f64::INFINITY,
1469                    eta: f64::INFINITY,
1470                    lipschitz: f64::INFINITY,
1471                    h: f64::INFINITY,
1472                },
1473            ));
1474        };
1475        let chart = &atom_atlas.charts[chart_idx];
1476        let Some(t) = amortized_warm_start(chart, x, amplitude) else {
1477            return Ok((
1478                Array1::<f64>::zeros(d),
1479                uncertified_certificate(chart.lipschitz),
1480            ));
1481        };
1482        self.refine_certified_encode_start(atom, evaluator.as_ref(), chart, t, x, amplitude)
1483    }
1484
1485    /// Amortized (distilled) encode of one target row `x` against one atom `k`
1486    /// with fixed amplitude `z` (#1026 ladder item 3).
1487    ///
1488    /// Routes to the nearest chart, then predicts the latent coordinate in CLOSED
1489    /// FORM from that chart's precomputed implicit-function-theorem Jacobian:
1490    ///
1491    /// ```text
1492    /// t̂ = t_c + (1/z) · A₁ · (x − z · m₁(t_c)),
1493    /// ```
1494    ///
1495    /// a single `O(d·p)` mat-vec — no per-row Hessian factorization or
1496    /// eigendecomposition, which is the amortization. The Kantorovich
1497    /// certificate is then evaluated AT the predicted start `t̂` with the chart's
1498    /// closed-form Lipschitz constant. A prediction is accepted only when that
1499    /// certificate holds, an independent cold chart-center probe also certifies,
1500    /// and the two refined coordinates agree within the two probes' final
1501    /// Kantorovich root-radius bounds. This keeps the distilled path honest
1502    /// without letting the exact probe reuse the distilled warm start it is
1503    /// auditing. A chart without a distilled Jacobian (singular Gauss–Newton
1504    /// block) flags the row.
1505    pub fn amortized_encode_row(
1506        &self,
1507        atom: &SaeManifoldAtom,
1508        atom_index: usize,
1509        x: ArrayView1<'_, f64>,
1510        amplitude: f64,
1511    ) -> Result<(Array1<f64>, RowCertificate), String> {
1512        let atom_atlas = self
1513            .atoms
1514            .get(atom_index)
1515            .ok_or_else(|| format!("amortized_encode_row: atom {atom_index} not in atlas"))?;
1516        let d = atom.latent_dim;
1517        let uncertified = || {
1518            (
1519                Array1::<f64>::zeros(d),
1520                RowCertificate {
1521                    beta: f64::INFINITY,
1522                    eta: f64::INFINITY,
1523                    lipschitz: f64::INFINITY,
1524                    h: f64::INFINITY,
1525                },
1526            )
1527        };
1528        // A missing basis evaluator means the distilled predictor cannot fire for
1529        // this atom — flag the row uncertified (the exact upstream solve owns it)
1530        // rather than erroring, exactly as the no-chart / singular-Jacobian /
1531        // non-positive-amplitude branches below do. Never a silent wrong encode,
1532        // never a hard abort of the criterion.
1533        let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
1534            return Ok(uncertified());
1535        };
1536        let Some((chart_idx, _)) = nearest_chart(atom_atlas, x, atom, evaluator.as_ref()) else {
1537            return Ok(uncertified());
1538        };
1539        let chart = &atom_atlas.charts[chart_idx];
1540        // Closed-form predicted start t̂ = t_c + (1/z)·A₁·(x − z·m₁). `None` when
1541        // the chart's Gauss–Newton block was singular (no distilled Jacobian, so
1542        // the amortized predictor cannot fire) or the amplitude is not strictly
1543        // positive and finite (a near-inactive atom, where the amplitude-divided
1544        // map is undefined) — either way flag for the exact fallback, never a
1545        // silent wrong encode.
1546        let Some(t_hat) = amortized_warm_start(chart, x, amplitude) else {
1547            return Ok(uncertified());
1548        };
1549        // Evaluate the SAME Kantorovich certificate at the predicted start. The
1550        // amortized prediction is trusted only if this certificate holds AND an
1551        // independent cold chart-center probe certifies and agrees below the
1552        // two probes' final Kantorovich root-radius bounds. This avoids the
1553        // self-referential gate where the "exact" probe is warm-started by the
1554        // same distilled prediction it is supposed to audit.
1555        let Some(amortized_probe) = certify_with_basin_warmup(
1556            atom,
1557            evaluator.as_ref(),
1558            t_hat,
1559            x,
1560            amplitude,
1561            chart.lipschitz,
1562            self.config.ridge,
1563            self.config.newton_steps,
1564            chart.region.center.view(),
1565            chart.region.radius,
1566        )?
1567        else {
1568            return Ok((
1569                Array1::<f64>::zeros(d),
1570                uncertified_certificate(chart.lipschitz),
1571            ));
1572        };
1573
1574        let cold_start = chart.region.center.clone();
1575        let Some(cold_probe) = certify_with_basin_warmup(
1576            atom,
1577            evaluator.as_ref(),
1578            cold_start,
1579            x,
1580            amplitude,
1581            chart.lipschitz,
1582            self.config.ridge,
1583            self.config.newton_steps,
1584            chart.region.center.view(),
1585            chart.region.radius,
1586        )?
1587        else {
1588            return Ok((
1589                amortized_probe.coord,
1590                uncertified_certificate(chart.lipschitz),
1591            ));
1592        };
1593
1594        let gap =
1595            latent_coordinate_distance(atom, amortized_probe.coord.view(), cold_probe.coord.view());
1596        let tolerance = distilled_probe_tolerance(&amortized_probe, &cold_probe, amplitude, x);
1597        if !(gap.is_finite() && gap <= tolerance) {
1598            return Ok((
1599                amortized_probe.coord,
1600                uncertified_certificate(chart.lipschitz),
1601            ));
1602        }
1603        Ok((amortized_probe.coord, amortized_probe.initial_cert))
1604    }
1605
1606    /// Batched amortized (distilled) encode over many rows against one atom
1607    /// (#1026 ladder item 3, corpus-rate). Each row uses the closed-form
1608    /// per-chart Jacobian predictor and carries its own Kantorovich certificate;
1609    /// uncertified rows are flagged in [`EncodeResult::encode_uncertified_count`]
1610    /// for the exact multi-start fallback. Row-independent against the frozen
1611    /// dictionary, so the batch fans out over rows (deterministic row-order
1612    /// assembly, bit-identical run-to-run), staying sequential inside a rayon
1613    /// worker to avoid nested oversubscription.
1614    pub fn amortized_encode_batch(
1615        &self,
1616        atom: &SaeManifoldAtom,
1617        atom_index: usize,
1618        targets: ArrayView2<'_, f64>,
1619        amplitudes: ArrayView1<'_, f64>,
1620    ) -> Result<EncodeResult, String> {
1621        let n = targets.nrows();
1622        if amplitudes.len() != n {
1623            return Err(format!(
1624                "amortized_encode_batch: amplitudes len {} != rows {n}",
1625                amplitudes.len()
1626            ));
1627        }
1628        let d = atom.latent_dim;
1629        let encode_rows =
1630            |range: std::ops::Range<usize>| -> Result<Vec<(Array1<f64>, bool)>, String> {
1631                range
1632                    .map(|row| {
1633                        let (t, cert) = self.amortized_encode_row(
1634                            atom,
1635                            atom_index,
1636                            targets.row(row),
1637                            amplitudes[row],
1638                        )?;
1639                        Ok((t, cert.certified()))
1640                    })
1641                    .collect()
1642            };
1643        let rows: Vec<(Array1<f64>, bool)> =
1644            if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
1645                use rayon::prelude::*;
1646                const CHUNK: usize = 256;
1647                let n_chunks = n.div_ceil(CHUNK);
1648                let chunked: Vec<Vec<(Array1<f64>, bool)>> = (0..n_chunks)
1649                    .into_par_iter()
1650                    .map(|c| {
1651                        let start = c * CHUNK;
1652                        let end = (start + CHUNK).min(n);
1653                        encode_rows(start..end)
1654                    })
1655                    .collect::<Result<_, _>>()?;
1656                chunked.into_iter().flatten().collect()
1657            } else {
1658                encode_rows(0..n)?
1659            };
1660        let mut coords = Array2::<f64>::zeros((n, d));
1661        let mut certified = Vec::with_capacity(n);
1662        for (row, (t, cert)) in rows.into_iter().enumerate() {
1663            coords.row_mut(row).assign(&t);
1664            certified.push(cert);
1665        }
1666        Ok(EncodeResult::from_rows(coords, certified))
1667    }
1668
1669    /// Batched certified encode over many rows against one atom (the #988
1670    /// throughput consumer). Each row carries its own certificate; uncertified
1671    /// rows are flagged in [`EncodeResult::encode_uncertified_count`] for the
1672    /// exact multi-start fallback.
1673    pub fn certified_encode_batch(
1674        &self,
1675        atom: &SaeManifoldAtom,
1676        atom_index: usize,
1677        targets: ArrayView2<'_, f64>,
1678        amplitudes: ArrayView1<'_, f64>,
1679    ) -> Result<EncodeResult, String> {
1680        let n = targets.nrows();
1681        if amplitudes.len() != n {
1682            return Err(format!(
1683                "certified_encode_batch: amplitudes len {} != rows {n}",
1684                amplitudes.len()
1685            ));
1686        }
1687        let d = atom.latent_dim;
1688        // Per-row encode is independent against a frozen dictionary (#1010), so
1689        // the corpus-rate batch fans out over rows (#1026 amortized-encoder leg /
1690        // #977 Stage-3 corpus encode). Each row produces an owned `(t, certified)`
1691        // pair; results are assembled back in row order so the output is
1692        // bit-identical run-to-run regardless of thread scheduling. Stay
1693        // sequential inside a rayon worker (e.g. when an outer atom-level fan-out
1694        // owns the pool) to avoid nested oversubscription. The first row that
1695        // fails to encode propagates its error deterministically.
1696        let encode_rows =
1697            |range: std::ops::Range<usize>| -> Result<Vec<(Array1<f64>, bool)>, String> {
1698                range
1699                    .map(|row| {
1700                        let (t, cert) = self.certified_encode_row(
1701                            atom,
1702                            atom_index,
1703                            targets.row(row),
1704                            amplitudes[row],
1705                        )?;
1706                        Ok((t, cert.certified()))
1707                    })
1708                    .collect()
1709            };
1710        let rows: Vec<(Array1<f64>, bool)> =
1711            if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
1712                use rayon::prelude::*;
1713                const CHUNK: usize = 256;
1714                let n_chunks = n.div_ceil(CHUNK);
1715                let chunked: Vec<Vec<(Array1<f64>, bool)>> = (0..n_chunks)
1716                    .into_par_iter()
1717                    .map(|c| {
1718                        let start = c * CHUNK;
1719                        let end = (start + CHUNK).min(n);
1720                        encode_rows(start..end)
1721                    })
1722                    .collect::<Result<_, _>>()?;
1723                chunked.into_iter().flatten().collect()
1724            } else {
1725                encode_rows(0..n)?
1726            };
1727        let mut coords = Array2::<f64>::zeros((n, d));
1728        let mut certified = Vec::with_capacity(n);
1729        for (row, (t, cert)) in rows.into_iter().enumerate() {
1730            coords.row_mut(row).assign(&t);
1731            certified.push(cert);
1732        }
1733        Ok(EncodeResult::from_rows(coords, certified))
1734    }
1735
1736    /// Batched GEMM "fast" amortized encode — the traditional-encoder forward
1737    /// pass, WITH manifolds. For every row this applies the SAME closed-form
1738    /// affine predictor as [`amortized_warm_start`]
1739    /// (`t̂ = t_c + (1/z)·A₁·(x − z·m₁)`), but routed and applied as batched
1740    /// matrix products instead of a per-row loop wrapped in the Kantorovich
1741    /// certificate + basin warmup. NO per-row certificate is taken: this is the
1742    /// speed mode (the certified `*_encode_*` paths remain the accuracy mode).
1743    ///
1744    /// Cost is GEMM-bound: one `(n × p)·(p × d)` decode-distance product for
1745    /// nearest-chart routing (skipped for single-chart atoms) plus, per chart,
1746    /// one `(n_c × p)·(p × d)` predictor product — i.e. `≈ X·Wᵀ`, exactly a
1747    /// dense SAE encoder's forward map.
1748    ///
1749    /// Degenerate rows are handled exactly as `amortized_warm_start` flags them
1750    /// (returns `None` ⇒ zeroed coord here): a missing basis evaluator, a chart
1751    /// whose Gauss–Newton block was singular (`amortized_jacobian == None`), or a
1752    /// non-finite / non-positive amplitude. Those rows are zeroed (never a panic,
1753    /// never a silent wrong encode), and their indices are returned in the
1754    /// `valid` mask so the caller can route them to the exact path if desired.
1755    ///
1756    /// Returns `(coords, valid)` where `coords` is `n × d` and `valid[row]` is
1757    /// `true` iff the amortized predictor fired for that row.
1758    pub fn amortized_encode_batch_fast(
1759        &self,
1760        atom: &SaeManifoldAtom,
1761        atom_index: usize,
1762        x: ArrayView2<'_, f64>,
1763        amplitudes: ArrayView1<'_, f64>,
1764    ) -> Result<(Array2<f64>, Vec<bool>), String> {
1765        let n = x.nrows();
1766        let p = atom.output_dim();
1767        let d = atom.latent_dim;
1768        if x.ncols() != p {
1769            return Err(format!(
1770                "amortized_encode_batch_fast: x has {} cols but atom output dim is {p}",
1771                x.ncols()
1772            ));
1773        }
1774        if amplitudes.len() != n {
1775            return Err(format!(
1776                "amortized_encode_batch_fast: amplitudes len {} != rows {n}",
1777                amplitudes.len()
1778            ));
1779        }
1780        let atom_atlas = self.atoms.get(atom_index).ok_or_else(|| {
1781            format!("amortized_encode_batch_fast: atom {atom_index} not in atlas")
1782        })?;
1783        let mut coords = Array2::<f64>::zeros((n, d));
1784        let mut valid = vec![false; n];
1785
1786        // A missing basis evaluator means the distilled predictor cannot fire for
1787        // this atom — every row is uncertified (zeroed), exactly like the per-row
1788        // `amortized_encode_row` no-evaluator branch.
1789        let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
1790            return Ok((coords, valid));
1791        };
1792
1793        // ── Routing recon-centers (one evaluation per chart, batched). ────────
1794        // `nearest_chart` routes a row to the chart whose center reconstruction
1795        // `m(t_c) = BᵀΦ(t_c)` is closest in ‖·‖², skipping charts with
1796        // `certified_radius <= 0`. Evaluate every candidate center ONCE here
1797        // (chart count ≪ n) and GEMM the recon, reproducing `nearest_chart`'s
1798        // per-chart recon bit-for-bit (same φ·decoder accumulation order).
1799        let valid_charts: Vec<usize> = (0..atom_atlas.charts.len())
1800            .filter(|&c| atom_atlas.charts[c].certified_radius > 0.0)
1801            .collect();
1802        if valid_charts.is_empty() {
1803            return Ok((coords, valid));
1804        }
1805        // Stack candidate centers (C × d) and evaluate the basis in one call.
1806        let mut centers = Array2::<f64>::zeros((valid_charts.len(), d));
1807        for (ci, &c) in valid_charts.iter().enumerate() {
1808            centers
1809                .row_mut(ci)
1810                .assign(&atom_atlas.charts[c].region.center);
1811        }
1812        let (phi_centers, _jet) = evaluator
1813            .evaluate(centers.view())
1814            .map_err(|err| format!("amortized_encode_batch_fast: center eval: {err}"))?;
1815        // recon_centers = Φ_centers · decoder  (C × p), the routing targets.
1816        let recon_centers = phi_centers.dot(&atom.decoder_coefficients);
1817        // Per-chart routing key: route_idx[row] = argmin_c ‖x_row − recon_c‖².
1818        // ‖x − r‖² = ‖x‖² − 2 x·r + ‖r‖²; the ‖x‖² term is row-constant so the
1819        // argmin uses S = X·recon_centersᵀ and the per-chart ‖r‖². First chart
1820        // wins on a tie (strict `<`), matching `nearest_chart`.
1821        let route_idx: Vec<usize> = if valid_charts.len() == 1 {
1822            vec![0usize; n]
1823        } else {
1824            let s = x.dot(&recon_centers.t()); // (n × C)
1825            let r_sq: Vec<f64> = (0..valid_charts.len())
1826                .map(|c| recon_centers.row(c).dot(&recon_centers.row(c)))
1827                .collect();
1828            (0..n)
1829                .map(|row| {
1830                    let mut best_c = 0usize;
1831                    let mut best_d = f64::INFINITY;
1832                    for c in 0..valid_charts.len() {
1833                        let dist = r_sq[c] - 2.0 * s[[row, c]];
1834                        if dist < best_d {
1835                            best_d = dist;
1836                            best_c = c;
1837                        }
1838                    }
1839                    best_c
1840                })
1841                .collect()
1842        };
1843
1844        // ── Per-chart batched affine predictor. ───────────────────────────────
1845        // For rows routed to chart `c` with finite jacobian `A₁` (d × p) and
1846        // center reconstruction `m₁` (= `chart.recon_center`), the predictor is
1847        //   t̂ = t_c − A₁·m₁ + (1/z)·(A₁·x).
1848        // The `A₁·x` term is one `(n_c × p)·(p × d)` GEMM; the `1/z` is a per-row
1849        // scalar; `t_c − A₁·m₁` is a per-chart constant. This reproduces
1850        // `amortized_warm_start` up to FP reassociation of the per-output sum.
1851        for (ci, &c) in valid_charts.iter().enumerate() {
1852            let chart = &atom_atlas.charts[c];
1853            let Some(a1) = chart.amortized_jacobian.as_ref() else {
1854                // Singular Gauss–Newton block: predictor cannot fire — the rows
1855                // routed here stay zeroed/uncertified (same as warm_start `None`).
1856                continue;
1857            };
1858            // Gather rows routed to this chart with a usable amplitude.
1859            let rows_here: Vec<usize> = (0..n)
1860                .filter(|&row| {
1861                    route_idx[row] == ci
1862                        && amplitudes[row].is_finite()
1863                        && amplitudes[row].abs() > 0.0
1864                })
1865                .collect();
1866            if rows_here.is_empty() {
1867                continue;
1868            }
1869            // X_c (n_c × p).
1870            let mut x_c = Array2::<f64>::zeros((rows_here.len(), p));
1871            for (i, &row) in rows_here.iter().enumerate() {
1872                x_c.row_mut(i).assign(&x.row(row));
1873            }
1874            // U = X_c · A₁ᵀ  (n_c × d) — the GEMM.
1875            let u = x_c.dot(&a1.t());
1876            // Per-chart constant base = t_c − A₁·m₁ (length d).
1877            let m1 = &chart.recon_center;
1878            let a1_m1 = a1.dot(m1); // (d)
1879            let base = &chart.region.center - &a1_m1; // (d)
1880            for (i, &row) in rows_here.iter().enumerate() {
1881                let inv_z = 1.0 / amplitudes[row];
1882                for axis in 0..d {
1883                    coords[[row, axis]] = base[axis] + u[[i, axis]] * inv_z;
1884                }
1885                valid[row] = true;
1886            }
1887        }
1888        Ok((coords, valid))
1889    }
1890
1891    /// Fast batched FULL forward pass against one atom: encode → decode, the
1892    /// manifold analogue of a traditional SAE's `x̂ = z·D` (decoder `D`, code `z`).
1893    ///
1894    /// A traditional SAE decodes with one GEMM. The manifold SAE's reconstruction
1895    /// is `m(t̂) = z·Φ(t̂)·B` (module header) — the SAME GEMM `Φ·B`, but the code
1896    /// `Φ(t̂)` is the curved chart basis evaluated at the encoded latent coordinate
1897    /// rather than a flat one-hot. So the fast forward is exactly:
1898    ///   1. [`amortized_encode_batch_fast`] → per-row latent coords `t̂` (one
1899    ///      routing GEMM + one affine GEMM per chart — a traditional `W·x+b`);
1900    ///   2. ONE batched basis evaluation `Φ(t̂)` (the manifold-curvature step a
1901    ///      flat SAE doesn't have — `n×m`);
1902    ///   3. ONE GEMM `recon = Φ(t̂)·B` (`(n×m)·(m×p)` — a traditional decoder
1903    ///      `z·D`), then the per-row amplitude scale `z`.
1904    ///
1905    /// Rows the encoder could not certify-predict (no evaluator / singular
1906    /// Gauss–Newton block / non-finite-or-zero amplitude) are returned as a ZERO
1907    /// reconstruction and flagged `false` in the valid-mask — never a silent wrong
1908    /// decode. The reconstruction of a valid row equals, bit-for-bit up to GEMM
1909    /// reassociation, `z·(Φ(t̂_row)·B)` with `t̂` from the per-row predictor.
1910    pub fn amortized_reconstruct_batch_fast(
1911        &self,
1912        atom: &SaeManifoldAtom,
1913        atom_index: usize,
1914        x: ArrayView2<'_, f64>,
1915        amplitudes: ArrayView1<'_, f64>,
1916    ) -> Result<(Array2<f64>, Vec<bool>), String> {
1917        let n = x.nrows();
1918        let p = atom.output_dim();
1919        // Step 1: batched encode → latent coords (reuses the fast routing+affine).
1920        let (coords, valid) = self.amortized_encode_batch_fast(atom, atom_index, x, amplitudes)?;
1921        let mut recon = Array2::<f64>::zeros((n, p));
1922        // A missing evaluator means no row could encode — every row is zeroed and
1923        // already flagged `false` by the encode; nothing to decode.
1924        let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
1925            return Ok((recon, valid));
1926        };
1927        // Step 2: ONE batched basis evaluation Φ(t̂) over all rows (n × m). Invalid
1928        // rows carry coords = 0 (the chart-origin); we still evaluate them in the
1929        // batch for a single GEMM, then zero their reconstruction below — the basis
1930        // is finite at the origin so this cannot poison the valid rows' GEMM.
1931        let (phi, _jet) = evaluator
1932            .evaluate(coords.view())
1933            .map_err(|err| format!("amortized_reconstruct_batch_fast: basis eval: {err}"))?;
1934        // Step 3: ONE GEMM recon = Φ·B (n × p), then per-row amplitude scale z.
1935        // m(t̂) = z·Φ(t̂)·B, matching the module header and `fill_decoded_row`'s
1936        // `Φ·decoder` accumulation (the amplitude is applied once here).
1937        let decoded = phi.dot(&atom.decoder_coefficients); // (n × p), amplitude-1
1938        for row in 0..n {
1939            if !valid[row] {
1940                continue; // stays zeroed — uncertified, like warm_start `None`.
1941            }
1942            let z = amplitudes[row];
1943            for col in 0..p {
1944                recon[[row, col]] = z * decoded[[row, col]];
1945            }
1946        }
1947        Ok((recon, valid))
1948    }
1949
1950    /// LSH-routed certified encode (issue #1010 step 2 + 3): for each target
1951    /// row, the existing [`SaeCandidateIndex`] (#985/#994) proposes the
1952    /// best-aligned atom by frame alignment to the row direction; the row is then
1953    /// encoded against THAT atom's certified chart atlas. This is the production
1954    /// routing path — the LSH does sublinear atom selection, the atlas does the
1955    /// in-atom nearest-chart routing and the per-row Kantorovich certificate.
1956    ///
1957    /// `atoms[id]` must be aligned with the atlas's `atoms[id]` (same dictionary
1958    /// order the atlas was built from and the sketch/index were built over).
1959    /// A row with no LSH proposal (empty bucket) is flagged uncertified — it
1960    /// routes to the exact multi-start fallback, never a silent wrong encode.
1961    pub fn certified_encode_with_index<S: AtomFrameSketch + Sync>(
1962        &self,
1963        atoms: &[SaeManifoldAtom],
1964        index: &SaeCandidateIndex,
1965        sketch: &S,
1966        targets: ArrayView2<'_, f64>,
1967        amplitudes: ArrayView1<'_, f64>,
1968        latent_dim: usize,
1969    ) -> Result<EncodeResult, String> {
1970        let n = targets.nrows();
1971        if amplitudes.len() != n {
1972            return Err(format!(
1973                "certified_encode_with_index: amplitudes len {} != rows {n}",
1974                amplitudes.len()
1975            ));
1976        }
1977        let budget = auto_candidate_budget(atoms.len().max(1));
1978        // LSH-routed per-row encode is independent across rows (sublinear atom
1979        // selection + frozen-dictionary in-atom Newton), so the corpus-rate batch
1980        // fans out over rows (#1026 amortized-encoder/routing leg / #977 Stage-3).
1981        // `None` coords (no LSH candidate) carry through as a zeroed row flagged
1982        // uncertified — identical to the sequential semantics. Results assemble
1983        // back in row order (bit-identical run-to-run); the first encode error
1984        // propagates deterministically. Stay sequential inside a rayon worker to
1985        // avoid nested oversubscription.
1986        let encode_rows =
1987            |range: std::ops::Range<usize>| -> Result<Vec<Option<(Array1<f64>, bool)>>, String> {
1988                range
1989                    .map(|row| {
1990                        // The row direction is the (unit-tolerant) target; the LSH
1991                        // ranks atoms by how much of that direction lies in each
1992                        // atom's column space. `propose` returns the top-`budget`
1993                        // atom ids by exact frame alignment.
1994                        let proposal = index.propose(sketch, targets.row(row), budget, true);
1995                        let Some(&best_atom) = proposal.proposed.first() else {
1996                            // No LSH candidate: flag for the exact fallback.
1997                            return Ok(None);
1998                        };
1999                        // Routing-confidence gate (#1026): the in-atom certificate
2000                        // attests convergence WITHIN best_atom, not that best_atom is
2001                        // the globally-correct atom. A non-empty but low-alignment
2002                        // gather likely missed the true best atom (LSH recall < 1) —
2003                        // flag it for the exact fallback rather than silently
2004                        // certifying a mis-route. See CANDIDATE_ROUTING_MIN_ALIGNMENT.
2005                        // A NaN alignment — a zero-norm target row, ‖d‖ = 0 — must
2006                        // also flag for the exact fallback, not slip through `<`.
2007                        let routing_alignment = sketch.alignment(best_atom, targets.row(row));
2008                        if !routing_alignment.is_finite()
2009                            || routing_alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT
2010                        {
2011                            return Ok(None);
2012                        }
2013                        let atom = atoms.get(best_atom).ok_or_else(|| {
2014                            format!(
2015                                "certified_encode_with_index: proposed atom {best_atom} out of range"
2016                            )
2017                        })?;
2018                        let (t, cert) = self.certified_encode_row(
2019                            atom,
2020                            best_atom,
2021                            targets.row(row),
2022                            amplitudes[row],
2023                        )?;
2024                        // Heterogeneous-atom dictionaries with different latent_dim
2025                        // per atom are not supported by the batched API: the caller
2026                        // declares one shared `latent_dim` for the output tensor.
2027                        // Silently zeroing the coord row while recording a
2028                        // certified=true flag would produce corrupted
2029                        // reconstructions downstream — error loudly instead.
2030                        if t.len() != latent_dim {
2031                            return Err(format!(
2032                                "certified_encode_with_index: atom {best_atom} returned t.len()={} \
2033                                 but declared latent_dim={latent_dim}; heterogeneous-dim \
2034                                 dictionaries are not supported by this batched encode path",
2035                                t.len()
2036                            ));
2037                        }
2038                        Ok(Some((t, cert.certified())))
2039                    })
2040                    .collect()
2041            };
2042        let rows: Vec<Option<(Array1<f64>, bool)>> =
2043            if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
2044                use rayon::prelude::*;
2045                const CHUNK: usize = 256;
2046                let n_chunks = n.div_ceil(CHUNK);
2047                let chunked: Vec<Vec<Option<(Array1<f64>, bool)>>> = (0..n_chunks)
2048                    .into_par_iter()
2049                    .map(|c| {
2050                        let start = c * CHUNK;
2051                        let end = (start + CHUNK).min(n);
2052                        encode_rows(start..end)
2053                    })
2054                    .collect::<Result<_, _>>()?;
2055                chunked.into_iter().flatten().collect()
2056            } else {
2057                encode_rows(0..n)?
2058            };
2059        let mut coords = Array2::<f64>::zeros((n, latent_dim));
2060        let mut certified = Vec::with_capacity(n);
2061        for (row, slot) in rows.into_iter().enumerate() {
2062            match slot {
2063                Some((t, cert)) => {
2064                    coords.row_mut(row).assign(&t);
2065                    certified.push(cert);
2066                }
2067                None => certified.push(false),
2068            }
2069        }
2070        Ok(EncodeResult::from_rows(coords, certified))
2071    }
2072
2073    /// LSH-routed AMORTIZED (distilled) encode — the production token-rate
2074    /// encoder of #1026 ladder item 3. Identical routing to
2075    /// [`Self::certified_encode_with_index`] (LSH proposes the best-aligned atom,
2076    /// the atlas routes to the in-atom nearest chart), but the in-atom encode is
2077    /// the closed-form per-chart Jacobian predictor + certificate gate of
2078    /// [`Self::amortized_encode_row`] rather than the certified Newton-refinement
2079    /// path.
2080    /// This is the deployment path: the distilled affine map produces the encode
2081    /// in one mat-vec, the Kantorovich certificate decides trust-or-fallback per
2082    /// row, and uncertified rows (the adversarial tail the thread expects to
2083    /// concentrate on rare tokens) are flagged for the exact multi-start solve —
2084    /// compute goes where the questions are. Row-independent against the frozen
2085    /// dictionary, so the batch fans out over rows with deterministic row-order
2086    /// assembly (bit-identical run-to-run).
2087    pub fn amortized_encode_with_index<S: AtomFrameSketch + Sync>(
2088        &self,
2089        atoms: &[SaeManifoldAtom],
2090        index: &SaeCandidateIndex,
2091        sketch: &S,
2092        targets: ArrayView2<'_, f64>,
2093        amplitudes: ArrayView1<'_, f64>,
2094        latent_dim: usize,
2095    ) -> Result<EncodeResult, String> {
2096        let n = targets.nrows();
2097        if amplitudes.len() != n {
2098            return Err(format!(
2099                "amortized_encode_with_index: amplitudes len {} != rows {n}",
2100                amplitudes.len()
2101            ));
2102        }
2103        let budget = auto_candidate_budget(atoms.len().max(1));
2104        let encode_rows =
2105            |range: std::ops::Range<usize>| -> Result<Vec<Option<(Array1<f64>, bool)>>, String> {
2106                range
2107                    .map(|row| {
2108                        let proposal = index.propose(sketch, targets.row(row), budget, true);
2109                        let Some(&best_atom) = proposal.proposed.first() else {
2110                            return Ok(None);
2111                        };
2112                        // Routing-confidence gate (#1026): flag low-alignment gathers
2113                        // for the exact fallback so a missed-true-best LSH route is
2114                        // never silently certified. See CANDIDATE_ROUTING_MIN_ALIGNMENT
2115                        // and certified_encode_with_index for the full rationale.
2116                        // A NaN alignment — a zero-norm target row, ‖d‖ = 0 — must
2117                        // also flag for the exact fallback, not slip through `<`.
2118                        let routing_alignment = sketch.alignment(best_atom, targets.row(row));
2119                        if !routing_alignment.is_finite()
2120                            || routing_alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT
2121                        {
2122                            return Ok(None);
2123                        }
2124                        let atom = atoms.get(best_atom).ok_or_else(|| {
2125                            format!(
2126                                "amortized_encode_with_index: proposed atom {best_atom} out of range"
2127                            )
2128                        })?;
2129                        let (t, cert) = self.amortized_encode_row(
2130                            atom,
2131                            best_atom,
2132                            targets.row(row),
2133                            amplitudes[row],
2134                        )?;
2135                        if t.len() != latent_dim {
2136                            return Err(format!(
2137                                "amortized_encode_with_index: atom {best_atom} returned t.len()={} \
2138                                 but declared latent_dim={latent_dim}; heterogeneous-dim \
2139                                 dictionaries are not supported by this batched encode path",
2140                                t.len()
2141                            ));
2142                        }
2143                        Ok(Some((t, cert.certified())))
2144                    })
2145                    .collect()
2146            };
2147        let rows: Vec<Option<(Array1<f64>, bool)>> =
2148            if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
2149                use rayon::prelude::*;
2150                const CHUNK: usize = 256;
2151                let n_chunks = n.div_ceil(CHUNK);
2152                let chunked: Vec<Vec<Option<(Array1<f64>, bool)>>> = (0..n_chunks)
2153                    .into_par_iter()
2154                    .map(|c| {
2155                        let start = c * CHUNK;
2156                        let end = (start + CHUNK).min(n);
2157                        encode_rows(start..end)
2158                    })
2159                    .collect::<Result<_, _>>()?;
2160                chunked.into_iter().flatten().collect()
2161            } else {
2162                encode_rows(0..n)?
2163            };
2164        let mut coords = Array2::<f64>::zeros((n, latent_dim));
2165        let mut certified = Vec::with_capacity(n);
2166        for (row, slot) in rows.into_iter().enumerate() {
2167            match slot {
2168                Some((t, cert)) => {
2169                    coords.row_mut(row).assign(&t);
2170                    certified.push(cert);
2171                }
2172                None => certified.push(false),
2173            }
2174        }
2175        Ok(EncodeResult::from_rows(coords, certified))
2176    }
2177
2178    /// LSH-routed FAST amortized encode over the WHOLE dictionary — the
2179    /// multi-atom, corpus-rate analogue of [`Self::amortized_encode_with_index`].
2180    ///
2181    /// `amortized_encode_with_index` routes per row, then runs the per-row
2182    /// closed-form predictor + Kantorovich certificate + cold cross-check on each
2183    /// row independently. This fast variant keeps the SAME sublinear per-row LSH
2184    /// routing (cheap — `index.propose` + the alignment gate), but replaces the
2185    /// per-row predictor with the GEMM-batched [`Self::amortized_encode_batch_fast`]:
2186    /// it GROUPS rows by their proposed atom and runs one batched affine-predictor
2187    /// pass per atom-group (a routing GEMM + a predictor GEMM each), reproducing a
2188    /// traditional SAE's whole-dictionary `W·x+b` throughput. No per-row
2189    /// certificate — this is the speed mode validated as accuracy-parity with the
2190    /// certified solve (`fast_forward_is_accuracy_parity_with_certified`).
2191    ///
2192    /// Returns the per-row latent coords and a valid-mask: `false` for a row with
2193    /// no LSH proposal, a sub-threshold/NaN routing alignment, or one the batched
2194    /// predictor could not fire on (no evaluator / singular Gauss–Newton block /
2195    /// non-finite-or-zero amplitude). Each row is written exactly once (disjoint
2196    /// per-atom groups), so the result is independent of group iteration order.
2197    pub fn amortized_encode_with_index_fast<S: AtomFrameSketch + Sync>(
2198        &self,
2199        atoms: &[SaeManifoldAtom],
2200        index: &SaeCandidateIndex,
2201        sketch: &S,
2202        targets: ArrayView2<'_, f64>,
2203        amplitudes: ArrayView1<'_, f64>,
2204        latent_dim: usize,
2205    ) -> Result<(Array2<f64>, Vec<bool>), String> {
2206        let n = targets.nrows();
2207        if amplitudes.len() != n {
2208            return Err(format!(
2209                "amortized_encode_with_index_fast: amplitudes len {} != rows {n}",
2210                amplitudes.len()
2211            ));
2212        }
2213        let budget = auto_candidate_budget(atoms.len().max(1));
2214        // ── Per-row LSH routing (sublinear), grouped by proposed atom. ──────────
2215        let mut groups: std::collections::HashMap<usize, Vec<usize>> =
2216            std::collections::HashMap::new();
2217        for row in 0..n {
2218            let proposal = index.propose(sketch, targets.row(row), budget, true);
2219            let Some(&best_atom) = proposal.proposed.first() else {
2220                continue;
2221            };
2222            // Same routing-confidence gate as the per-row path: a low-alignment or
2223            // NaN (zero-norm row) gather flags the row for the exact fallback, so a
2224            // missed-true-best LSH route is never silently encoded.
2225            let routing_alignment = sketch.alignment(best_atom, targets.row(row));
2226            if !routing_alignment.is_finite()
2227                || routing_alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT
2228            {
2229                continue;
2230            }
2231            groups.entry(best_atom).or_default().push(row);
2232        }
2233
2234        let mut coords = Array2::<f64>::zeros((n, latent_dim));
2235        let mut valid = vec![false; n];
2236        // ── Per-atom batched predictor over each group's rows. ──────────────────
2237        for (atom_idx, rows_here) in groups {
2238            let atom = atoms.get(atom_idx).ok_or_else(|| {
2239                format!("amortized_encode_with_index_fast: proposed atom {atom_idx} out of range")
2240            })?;
2241            if atom.latent_dim != latent_dim {
2242                return Err(format!(
2243                    "amortized_encode_with_index_fast: atom {atom_idx} latent_dim {} != declared \
2244                     {latent_dim}; heterogeneous-dim dictionaries are not supported by this path",
2245                    atom.latent_dim
2246                ));
2247            }
2248            // Gather this group's target rows and amplitudes (contiguous sub-batch).
2249            let p = atom.output_dim();
2250            let mut x_sub = Array2::<f64>::zeros((rows_here.len(), p));
2251            let mut amp_sub = Array1::<f64>::zeros(rows_here.len());
2252            for (i, &row) in rows_here.iter().enumerate() {
2253                x_sub.row_mut(i).assign(&targets.row(row));
2254                amp_sub[i] = amplitudes[row];
2255            }
2256            let (sub_coords, sub_valid) =
2257                self.amortized_encode_batch_fast(atom, atom_idx, x_sub.view(), amp_sub.view())?;
2258            for (i, &row) in rows_here.iter().enumerate() {
2259                if sub_valid[i] {
2260                    coords.row_mut(row).assign(&sub_coords.row(i));
2261                    valid[row] = true;
2262                }
2263            }
2264        }
2265        Ok((coords, valid))
2266    }
2267
2268    /// LSH-routed FAST full forward over the WHOLE dictionary: encode → decode,
2269    /// the multi-atom analogue of [`Self::amortized_reconstruct_batch_fast`]. Same
2270    /// sublinear per-row routing + per-atom grouping as
2271    /// [`Self::amortized_encode_with_index_fast`], but each group is run through
2272    /// the batched reconstruct (`m(t̂) = z·Φ(t̂)·B`) so the result is the per-row
2273    /// reconstruction in the ambient space. Rows that do not route/predict decode
2274    /// to an exact zero reconstruction and are flagged `false`.
2275    pub fn amortized_reconstruct_with_index_fast<S: AtomFrameSketch + Sync>(
2276        &self,
2277        atoms: &[SaeManifoldAtom],
2278        index: &SaeCandidateIndex,
2279        sketch: &S,
2280        targets: ArrayView2<'_, f64>,
2281        amplitudes: ArrayView1<'_, f64>,
2282    ) -> Result<(Array2<f64>, Vec<bool>), String> {
2283        let n = targets.nrows();
2284        let p = targets.ncols();
2285        if amplitudes.len() != n {
2286            return Err(format!(
2287                "amortized_reconstruct_with_index_fast: amplitudes len {} != rows {n}",
2288                amplitudes.len()
2289            ));
2290        }
2291        let budget = auto_candidate_budget(atoms.len().max(1));
2292        let mut groups: std::collections::HashMap<usize, Vec<usize>> =
2293            std::collections::HashMap::new();
2294        for row in 0..n {
2295            let proposal = index.propose(sketch, targets.row(row), budget, true);
2296            let Some(&best_atom) = proposal.proposed.first() else {
2297                continue;
2298            };
2299            let routing_alignment = sketch.alignment(best_atom, targets.row(row));
2300            if !routing_alignment.is_finite()
2301                || routing_alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT
2302            {
2303                continue;
2304            }
2305            groups.entry(best_atom).or_default().push(row);
2306        }
2307
2308        let mut recon = Array2::<f64>::zeros((n, p));
2309        let mut valid = vec![false; n];
2310        for (atom_idx, rows_here) in groups {
2311            let atom = atoms.get(atom_idx).ok_or_else(|| {
2312                format!(
2313                    "amortized_reconstruct_with_index_fast: proposed atom {atom_idx} out of range"
2314                )
2315            })?;
2316            if atom.output_dim() != p {
2317                return Err(format!(
2318                    "amortized_reconstruct_with_index_fast: atom {atom_idx} output_dim {} != target \
2319                     dim {p}",
2320                    atom.output_dim()
2321                ));
2322            }
2323            let mut x_sub = Array2::<f64>::zeros((rows_here.len(), p));
2324            let mut amp_sub = Array1::<f64>::zeros(rows_here.len());
2325            for (i, &row) in rows_here.iter().enumerate() {
2326                x_sub.row_mut(i).assign(&targets.row(row));
2327                amp_sub[i] = amplitudes[row];
2328            }
2329            let (sub_recon, sub_valid) = self.amortized_reconstruct_batch_fast(
2330                atom,
2331                atom_idx,
2332                x_sub.view(),
2333                amp_sub.view(),
2334            )?;
2335            for (i, &row) in rows_here.iter().enumerate() {
2336                if sub_valid[i] {
2337                    recon.row_mut(row).assign(&sub_recon.row(i));
2338                    valid[row] = true;
2339                }
2340            }
2341        }
2342        Ok((recon, valid))
2343    }
2344}
2345
2346/// Offline `β = 1/λ_min(H_GN)` at a chart center from the Gauss-Newton block
2347/// `H_GN = J_mᵀ J_m` (residual-free). The offline `β` bounds the curvature the
2348/// online certificate sees: charts are placed where the encode lands, so the
2349/// representative residual is small and `H_GN` is the dominant, residual-free
2350/// curvature estimate. (The online per-row certificate still uses the FULL
2351/// Hessian; this is only the offline radius-sizing curvature.) Returns `None`
2352/// for a degenerate center (`λ_min ≤ 0`), which marks an uncertifiable chart.
2353pub(crate) fn center_beta(atom: &SaeManifoldAtom, center: &Array1<f64>, ridge: f64) -> Option<f64> {
2354    let evaluator = atom.basis_evaluator.as_ref()?.clone();
2355    let d = atom.latent_dim;
2356    let p = atom.output_dim();
2357    let m = atom.basis_size();
2358    let coords = center.view().to_shape((1, d)).ok()?.to_owned();
2359    let (_phi, jet) = evaluator.evaluate(coords.view()).ok()?;
2360    let decoder = &atom.decoder_coefficients;
2361    // J_m[axis] = Bᵀ (∂Φ/∂t_axis) ∈ ℝᵖ (amplitude-1; curvature scales with z²
2362    // and is absorbed conservatively by the amplitude-bounded Lipschitz term).
2363    let mut jm = Array2::<f64>::zeros((d, p));
2364    for axis in 0..d {
2365        for basis_col in 0..m {
2366            let dphi = jet[[0, basis_col, axis]];
2367            if dphi == 0.0 {
2368                continue;
2369            }
2370            for out in 0..p {
2371                jm[[axis, out]] += dphi * decoder[[basis_col, out]];
2372            }
2373        }
2374    }
2375    let mut h = Array2::<f64>::zeros((d, d));
2376    for a in 0..d {
2377        for b in 0..d {
2378            h[[a, b]] = jm.row(a).dot(&jm.row(b));
2379        }
2380        h[[a, a]] += ridge;
2381    }
2382    let (vals, _vecs) = h.eigh(Side::Lower).ok()?;
2383    let lambda_min = vals.iter().cloned().fold(f64::INFINITY, f64::min);
2384    if lambda_min.is_finite() && lambda_min > 0.0 {
2385        Some(1.0 / lambda_min)
2386    } else {
2387        None
2388    }
2389}
2390
2391/// #1154 — the amortized encoder's closed-form warm-start coordinate for one
2392/// row `x` against one chart at amplitude `z`:
2393///
2394/// ```text
2395/// t̂ = t_c + (1/z) · A₁ · (x − z · m₁(t_c)),
2396/// ```
2397///
2398/// a single `O(d·p)` mat-vec from the chart's precomputed IFT Jacobian `A₁` and
2399/// center reconstruction `m₁`. Returns `None` when the chart carries no
2400/// distilled Jacobian (singular Gauss–Newton block) or the amplitude is not
2401/// strictly positive and finite (a near-inactive atom, where the
2402/// amplitude-divided map is undefined) — in those cases the caller starts from
2403/// the chart center instead. Shared by the amortized encode (where `t̂` is the
2404/// prediction) and the exact certified encode (where `t̂` is the Newton
2405/// warm-start that then refines to stationarity, Design A).
2406pub(crate) fn amortized_warm_start(
2407    chart: &CertifiedChart,
2408    x: ArrayView1<'_, f64>,
2409    amplitude: f64,
2410) -> Option<Array1<f64>> {
2411    let a1 = chart.amortized_jacobian.as_ref()?;
2412    if !(amplitude.is_finite() && amplitude.abs() > 0.0) {
2413        return None;
2414    }
2415    let d = a1.nrows();
2416    let mut t_hat = chart.region.center.clone();
2417    for (out_idx, &m1_out) in chart.recon_center.iter().enumerate().take(a1.ncols()) {
2418        let resid = x[out_idx] - amplitude * m1_out;
2419        for axis in 0..d {
2420            t_hat[axis] += a1[[axis, out_idx]] * resid / amplitude;
2421        }
2422    }
2423    Some(t_hat)
2424}
2425
2426/// The amplitude-1 distilled amortized-encoder Jacobian at a chart center
2427/// (#1026 ladder item 3). Returns `(A₁, m₁)` where `m₁ = BᵀΦ(t_c) ∈ ℝᵖ` is the
2428/// amplitude-1 center reconstruction and `A₁ = (J₁ᵀJ₁ + ridge·I)⁻¹ J₁ ∈ ℝ^{d×p}`
2429/// is the implicit-function-theorem derivative of the encode map `x ↦ t`
2430/// (Gauss–Newton block — the residual-free, dominant curvature exactly as the
2431/// offline radius-sizing `β`). With these, the online encode of a row `x` at
2432/// amplitude `z` is the closed-form affine prediction
2433/// `t = t_c + (1/z)·A₁·(x − z·m₁)` — one mat-vec, no per-row factorization.
2434/// `None` when the basis has no jet or the Gauss–Newton block is singular (no
2435/// certifiable amortization), matching `center_beta`'s gate so a chart with a
2436/// finite `β` always carries a Jacobian and vice versa.
2437pub(crate) fn center_amortized_jacobian(
2438    atom: &SaeManifoldAtom,
2439    center: &Array1<f64>,
2440    ridge: f64,
2441) -> Option<(Array2<f64>, Array1<f64>)> {
2442    let evaluator = atom.basis_evaluator.as_ref()?.clone();
2443    let d = atom.latent_dim;
2444    let p = atom.output_dim();
2445    let m = atom.basis_size();
2446    let coords = center.view().to_shape((1, d)).ok()?.to_owned();
2447    let (phi, jet) = evaluator.evaluate(coords.view()).ok()?;
2448    let decoder = &atom.decoder_coefficients;
2449    // m₁(t_c) = BᵀΦ(t_c) ∈ ℝᵖ (amplitude-1 center reconstruction).
2450    let mut recon = Array1::<f64>::zeros(p);
2451    for basis_col in 0..m {
2452        let phi_v = phi[[0, basis_col]];
2453        if phi_v == 0.0 {
2454            continue;
2455        }
2456        for out in 0..p {
2457            recon[out] += phi_v * decoder[[basis_col, out]];
2458        }
2459    }
2460    // J₁[axis] = Bᵀ (∂Φ/∂t_axis) ∈ ℝᵖ (amplitude-1; z factors out analytically).
2461    let mut jm = Array2::<f64>::zeros((d, p));
2462    for axis in 0..d {
2463        for basis_col in 0..m {
2464            let dphi = jet[[0, basis_col, axis]];
2465            if dphi == 0.0 {
2466                continue;
2467            }
2468            for out in 0..p {
2469                jm[[axis, out]] += dphi * decoder[[basis_col, out]];
2470            }
2471        }
2472    }
2473    // H_GN = J₁ J₁ᵀ + ridge·I ∈ ℝ^{d×d}.
2474    let mut h = Array2::<f64>::zeros((d, d));
2475    for a in 0..d {
2476        for b in 0..d {
2477            h[[a, b]] = jm.row(a).dot(&jm.row(b));
2478        }
2479        h[[a, a]] += ridge;
2480    }
2481    let (vals, vecs) = h.eigh(Side::Lower).ok()?;
2482    let lambda_min = vals.iter().cloned().fold(f64::INFINITY, f64::min);
2483    if !(lambda_min.is_finite() && lambda_min > 0.0) {
2484        return None;
2485    }
2486    // A₁ = H_GN⁻¹ J₁ via the eigendecomposition: H⁻¹ = Σ_i (1/λᵢ) vᵢ vᵢᵀ, so
2487    // A₁[:, out] = Σ_i (vᵢ · J₁[:, out]) / λᵢ · vᵢ. Column-by-column keeps it the
2488    // d×p Jacobian (one SPD solve reused across all p output channels).
2489    let mut a1 = Array2::<f64>::zeros((d, p));
2490    for out in 0..p {
2491        let jcol = jm.column(out);
2492        for (i, &lam) in vals.iter().enumerate() {
2493            if !(lam.is_finite() && lam > 0.0) {
2494                return None;
2495            }
2496            let vi = vecs.column(i);
2497            let coeff = vi.dot(&jcol) / lam;
2498            for row in 0..d {
2499                a1[[row, out]] += coeff * vi[row];
2500            }
2501        }
2502    }
2503    Some((a1, recon))
2504}
2505
2506/// Route a target row to the nearest chart of an atom by reconstruction
2507/// distance: the chart whose center reconstruction `m(t_c)` is closest to `x`.
2508/// Returns the chart index and the distance, or `None` when the atom has no
2509/// charts.
2510pub(crate) fn nearest_chart(
2511    atom_atlas: &AtomEncodeAtlas,
2512    x: ArrayView1<'_, f64>,
2513    atom: &SaeManifoldAtom,
2514    evaluator: &dyn SaeBasisEvaluator,
2515) -> Option<(usize, f64)> {
2516    if atom_atlas.charts.is_empty() {
2517        return None;
2518    }
2519    let d = atom.latent_dim;
2520    let p = atom.output_dim();
2521    let m = atom.basis_size();
2522    let mut best: Option<(usize, f64)> = None;
2523    for (idx, chart) in atom_atlas.charts.iter().enumerate() {
2524        if chart.certified_radius <= 0.0 {
2525            continue;
2526        }
2527        let coords = match chart.region.center.view().to_shape((1, d)) {
2528            Ok(c) => c.to_owned(),
2529            Err(_) => continue,
2530        };
2531        let Ok((phi, _jet)) = evaluator.evaluate(coords.view()) else {
2532            continue;
2533        };
2534        // m(t_c) = Bᵀ Φ(t_c) (amplitude-1; routing is scale-tolerant).
2535        let mut recon = Array1::<f64>::zeros(p);
2536        for basis_col in 0..m {
2537            let phi_v = phi[[0, basis_col]];
2538            if phi_v == 0.0 {
2539                continue;
2540            }
2541            for out in 0..p {
2542                recon[out] += phi_v * atom.decoder_coefficients[[basis_col, out]];
2543            }
2544        }
2545        let diff = &recon - &x;
2546        let dist = diff.dot(&diff);
2547        if best.map(|(_, b)| dist < b).unwrap_or(true) {
2548            best = Some((idx, dist));
2549        }
2550    }
2551    best
2552}
2553
/// Maximum number of chart centers laid down per atom (the SHAPE_BAND grid
/// point cap; mirrors `SHAPE_BAND_MAX_POINTS` in the atom band machinery).
///
/// [`regular_product_grid`] and [`cylinder_chart_center_grid`] shrink their
/// per-axis resolution against this cap; [`sphere_latlon_grid`] hard-codes the
/// matching per-axis limit of 22 (22² = 484 ≤ 512).
pub(crate) const SHAPE_BAND_MAX_POINTS: usize = 512;
2557
2558/// Lay down chart centers on an atom's coordinate grid (the SHAPE_BAND grid
2559/// idiom): a regular grid spanning the compact latent domain for periodic /
2560/// sphere / torus atoms, and a strided cover of the latent axes for unbounded
2561/// (Duchon / Euclidean) atoms.
2562///
2563/// Periodic / torus latents are fractions of one period, so the per-axis grid
2564/// spans `[0, 1)`; the sphere chart spans `lat ∈ [−π/2, π/2]`, `lon ∈ [−π, π)`.
2565/// These conventions match the basis evaluators (the fraction-of-period circle
2566/// harmonic and the lat/lon sphere chart).
2567pub(crate) fn chart_center_grid(atom: &SaeManifoldAtom, resolution: usize) -> Array2<f64> {
2568    use crate::manifold::SaeAtomBasisKind::*;
2569    let d = atom.latent_dim;
2570    match &atom.basis_kind {
2571        Periodic | Torus => regular_product_grid(d, resolution, 0.0, 1.0, false),
2572        // Cylinder `S¹ × ℝ`: axis 0 is the periodic circle `[0, 1)` (no
2573        // endpoint, like the harmonic axes); axis 1 is the unbounded line,
2574        // covered by a strided unit box `[-0.5, 0.5]` about the origin (like the
2575        // Euclidean patch). The certified radius refines each chart; out-of-cover
2576        // line starts route to the exact fallback honestly.
2577        Cylinder if d == 2 => cylinder_chart_center_grid(resolution),
2578        Cylinder => regular_product_grid(d, resolution, -0.5, 0.5, true),
2579        Sphere if d == 2 => sphere_latlon_grid(resolution),
2580        Linear | Sphere | Duchon | EuclideanPatch | Poincare | Precomputed(_) => {
2581            // Unbounded / non-compact latents: a strided cover of a unit box
2582            // about the origin per axis. The certified radius refines each chart;
2583            // out-of-cover starts route to the exact fallback honestly.
2584            regular_product_grid(d, resolution, -0.5, 0.5, true)
2585        }
2586    }
2587}
2588
2589/// A regular `resolution`-per-axis product grid over `[lo, hi]^d`, capped at
2590/// [`SHAPE_BAND_MAX_POINTS`] total points (the per-axis resolution is reduced
2591/// until the product fits). When `include_endpoint` the last grid point sits at
2592/// `hi`; otherwise the axis is treated as periodic and stops one step short.
2593pub(crate) fn regular_product_grid(
2594    d: usize,
2595    resolution: usize,
2596    lo: f64,
2597    hi: f64,
2598    include_endpoint: bool,
2599) -> Array2<f64> {
2600    if d == 0 {
2601        return Array2::<f64>::zeros((1, 0));
2602    }
2603    let mut per_axis = resolution.max(2);
2604    while per_axis.saturating_pow(d as u32) > SHAPE_BAND_MAX_POINTS && per_axis > 2 {
2605        per_axis -= 1;
2606    }
2607    let total = per_axis.saturating_pow(d as u32).max(1);
2608    let denom = if include_endpoint {
2609        (per_axis.max(2) - 1) as f64
2610    } else {
2611        per_axis as f64
2612    };
2613    let mut grid = Array2::<f64>::zeros((total, d));
2614    let mut idx = vec![0usize; d];
2615    for flat in 0..total {
2616        for axis in 0..d {
2617            let frac = idx[axis] as f64 / denom;
2618            grid[[flat, axis]] = lo + (hi - lo) * frac;
2619        }
2620        for axis in (0..d).rev() {
2621            idx[axis] += 1;
2622            if idx[axis] < per_axis {
2623                break;
2624            }
2625            idx[axis] = 0;
2626        }
2627    }
2628    grid
2629}
2630
2631/// Lat/lon sphere chart grid: `lat ∈ [−π/2, π/2]`, `lon ∈ [−π, π)`, matching
2632/// the [`crate::manifold::SphereChartEvaluator`] convention.
2633pub(crate) fn sphere_latlon_grid(resolution: usize) -> Array2<f64> {
2634    use std::f64::consts::PI;
2635    let r = resolution.max(2).min(22); // 22² = 484 ≤ SHAPE_BAND_MAX_POINTS.
2636    let mut grid = Array2::<f64>::zeros((r * r, 2));
2637    for i in 0..r {
2638        let lat = -PI / 2.0 + PI * (i as f64 + 0.5) / r as f64;
2639        for j in 0..r {
2640            let lon = -PI + 2.0 * PI * (j as f64) / r as f64;
2641            grid[[i * r + j, 0]] = lat;
2642            grid[[i * r + j, 1]] = lon;
2643        }
2644    }
2645    grid
2646}
2647
2648/// Cylinder `S¹ × ℝ` chart-center grid: axis 0 sweeps the periodic circle over
2649/// one period `[0, 1)` (no endpoint, matching the harmonic axis), axis 1 strides
2650/// a unit box `[−0.5, 0.5]` about the origin on the unbounded line (with
2651/// endpoint). Capped at [`SHAPE_BAND_MAX_POINTS`] total centers.
2652pub(crate) fn cylinder_chart_center_grid(resolution: usize) -> Array2<f64> {
2653    let mut per_axis = resolution.max(2);
2654    while per_axis * per_axis > SHAPE_BAND_MAX_POINTS && per_axis > 2 {
2655        per_axis -= 1;
2656    }
2657    let total = per_axis * per_axis;
2658    let line_denom = (per_axis.max(2) - 1) as f64;
2659    let mut grid = Array2::<f64>::zeros((total, 2));
2660    for i in 0..per_axis {
2661        // Periodic axis 0: stop one step short of the period.
2662        let circle = i as f64 / per_axis as f64;
2663        for j in 0..per_axis {
2664            // Line axis 1: include the endpoint of the unit box.
2665            let line = -0.5 + (j as f64) / line_denom;
2666            grid[[i * per_axis + j, 0]] = circle;
2667            grid[[i * per_axis + j, 1]] = line;
2668        }
2669    }
2670    grid
2671}
2672
2673/// Nominal in-chart radius: half the inter-center grid spacing, so charts tile
2674/// the domain. For compact latents this is the grid step; for unbounded latents
2675/// a unit default that the certified radius refines.
2676pub(crate) fn chart_nominal_radius(atom: &SaeManifoldAtom, resolution: usize) -> f64 {
2677    use crate::manifold::SaeAtomBasisKind::*;
2678    match &atom.basis_kind {
2679        Periodic | Torus => 0.5 / (resolution.max(2) as f64),
2680        Sphere => std::f64::consts::PI / (resolution.max(2) as f64),
2681        // Cylinder charts tile two heterogeneous axes (a `[0,1)` periodic step
2682        // and a unit-box line step); the chart radius is a single scalar, so we
2683        // take the tighter (periodic) step `0.5/res` to keep every chart valid
2684        // on both axes. The certified Kantorovich radius refines it per chart.
2685        Cylinder => 0.5 / (resolution.max(2) as f64),
2686        Linear | Duchon | EuclideanPatch | Poincare | Precomputed(_) => {
2687            1.0 / (resolution.max(2) as f64)
2688        }
2689    }
2690}
2691
2692/// Build the [`ChartRegion`] for a center, attaching the radial r_min / r_max
2693/// bracket for Duchon atoms (the chart's distance range to the kernel centers).
2694pub(crate) fn chart_region(
2695    atom: &SaeManifoldAtom,
2696    center: Array1<f64>,
2697    radius: f64,
2698) -> ChartRegion {
2699    use crate::manifold::SaeAtomBasisKind::*;
2700    let region = ChartRegion::new(center.clone(), radius);
2701    match &atom.basis_kind {
2702        Duchon => {
2703            // r ranges over [‖t_c‖ − radius, ‖t_c‖ + radius] about the single
2704            // origin-anchored center used by the conservative radial bound.
2705            //
2706            // The lower bound must be `max(0, center_norm − radius)` — NOT floored
2707            // at `radius`. When the chart contains the kernel center
2708            // (`center_norm < radius`, true r_min = 0), flooring at `radius`
2709            // would give a finite, NON-CONSERVATIVE `r_min`, causing the
2710            // hessian_sup / third_sup formulas (which divide by r_min) to
2711            // underestimate the Lipschitz constant and potentially grant a false
2712            // Kantorovich certificate. Flooring at `f64::MIN_POSITIVE` instead
2713            // correctly drives the formulas toward ∞, producing a very large L
2714            // that will NEVER certify (rows route to the exact multi-start
2715            // fallback) — conservative and sound.
2716            let center_norm = center.dot(&center).sqrt();
2717            let r_min = (center_norm - radius).max(f64::MIN_POSITIVE);
2718            let r_max = center_norm + radius;
2719            region.with_radial_bounds(r_min, r_max)
2720        }
2721        // Cylinder has no radial kernel block (it is a harmonic × polynomial
2722        // tensor, not a Duchon radial basis), so it needs no radial r_min/r_max.
2723        Periodic | Sphere | Torus | Cylinder | Linear | EuclideanPatch | Poincare
2724        | Precomputed(_) => region,
2725    }
2726}