Skip to main content

gam_sae/
encode.rs

1//! Kantorovich-certified encode atlas (issue #1010).
2//!
3//! Encoding a row `x ∈ ℝᵖ` against a FROZEN SAE dictionary is, per atom `k`,
4//! the coordinate-only Newton problem
5//!
6//! ```text
7//! min_t  f_k(t) = ½‖x − z_k · B_kᵀ Φ_k(t)‖² + prior_k(t),
8//! ```
9//!
10//! with the amplitude `z_k` and decoder block `B_k` held fixed (the encode
11//! freezes the dictionary; only the latent coordinate `t` moves). Newton on
12//! `F(t) = ∇f_k(t)` converges quadratically from a start `t₀` into the unique
13//! root in a certified ball whenever the **Newton–Kantorovich** quantity
14//!
15//! ```text
16//! h = β · η · L ≤ ½,    β = ‖F'(t₀)⁻¹‖,   η = ‖F'(t₀)⁻¹ F(t₀)‖,
17//! ```
18//!
19//! where `L` is a Lipschitz constant of `F'` (the Hessian of `f_k`) on a region
20//! containing the Newton iterates. `h` is CHECKABLE per row in `O(q³)`
21//! (`q = latent_dim`, tiny), so each fast-path encode carries its own
22//! exactness certificate.
23//!
24//! ## The closed-form Hessian-Lipschitz constant `L`
25//!
26//! Write `m(t) = z·BᵀΦ(t) ∈ ℝᵖ` (the reconstruction) and `r(t) = m(t) − x`.
27//! Then `f = ½‖r‖² + prior` and, differentiating three times,
28//!
29//! ```text
30//! ∇³f = 3·sym(J_mᵀ : ∇²m) + ⟨r, ∇³m⟩ + ∇³prior,
31//! ```
32//!
33//! so an operator-norm bound on the chart is
34//!
35//! ```text
36//! L ≤ 3·‖J_m‖·‖∇²m‖ + ‖r‖·‖∇³m‖ + L_prior,
37//! ```
38//!
39//! with `‖∂^g m‖ ≤ |z|·(Σ_m ‖B_{m,:}‖)·B_g`, where `B_g = sup_chart max_m
40//! ‖∂^g Φ_m‖` is the per-column jet sup of the basis family — closed form per
41//! family ([`BasisHessianLipschitz`]). `‖r‖` is bounded by `‖x‖ +
42//! |z|·(Σ_m‖B_{m,:}‖)·B_0`. The ARD/von-Mises prior `L_prior` is a closed-form
43//! constant from the prior strength. Every bound is conservative (an
44//! over-estimate of `L` only SHRINKS the certified radius — it can never
45//! certify a row that does not converge).
46//!
47//! ## Pipeline
48//!
49//! 1. **Offline, per atom** ([`EncodeAtlas::build`]): chart centers `t_c` on the
50//!    atom's coordinate grid (the SHAPE_BAND grid idiom), each with a certified
51//!    Newton radius `R_c` solved from the Kantorovich inequality at the
52//!    worst-case in-chart start.
//! 2. **Online, per row** ([`EncodeAtlas::certified_encode_row`]): route to the
//!    nearest charts (the top `CERTIFIED_ROUTING_TOPK` by ambient distance),
//!    start each from its distilled IFT predictor and take one or two Newton
//!    steps; the `h ≤ ½` check evaluated AT the start point is the per-row
//!    certificate.
57//! 3. **Uncertified tail**: rows whose start fails `h ≤ ½` are FLAGGED (counted
58//!    in [`EncodeResult::encode_uncertified_count`]) and must be routed by the
59//!    caller to the existing exact multi-start solve. No approximation enters
60//!    silently.
61
62use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
63
64use gam_linalg::faer_ndarray::FaerEigh;
65use crate::candidate_index::{
66    AtomFrameSketch, SaeCandidateIndex, auto_candidate_budget,
67};
68use crate::manifold::{
69    AffineCoordinateEvaluator, CylinderHarmonicEvaluator, DuchonCoordinateEvaluator,
70    EuclideanPatchEvaluator, PeriodicHarmonicEvaluator, SaeBasisEvaluator, SaeManifoldAtom,
71    SphereChartEvaluator, TorusHarmonicEvaluator,
72};
73
74use faer::Side;
75
/// Newton–Kantorovich acceptance bound: a start is certified iff `h = β·η·L ≤ ½`.
///
/// When the inequality holds, Newton iteration converges quadratically into the
/// unique root inside the certified ball. A start whose `h` is strictly above
/// this value (or non-finite) is uncertified.
pub const KANTOROVICH_THRESHOLD: f64 = 0.5;

/// Minimum batch row count at which `certified_encode_batch` /
/// `certified_encode_with_index` spread their per-row encodes across rayon.
///
/// Smaller batches stay sequential. Chart routing plus a couple of Newton steps
/// per row is too cheap to amortize the fan-out. The value is on the same order
/// as the arrow-Schur `SCHUR_MATVEC_PARALLEL_ROW_MIN` gate, so short batches
/// issued from inside an outer atom-level fan-out remain sequential.
pub(crate) const ENCODE_BATCH_PARALLEL_ROW_MIN: usize = 256;

/// Lowest frame alignment `‖Uₖᵀd‖/‖d‖ ∈ [0, 1]` that the best LSH-gathered atom
/// must reach before an index-routed encode is TRUSTED (#1026).
///
/// The in-atom Kantorovich certificate only proves Newton convergence *inside*
/// the chosen atom. It says nothing about whether the LSH gather picked the
/// globally correct atom.
///
/// Alignment is capped at 1, so a HIGH best alignment leaves no room for a
/// materially better atom that was not gathered. A LOW one signals that LSH
/// recall probably missed the true best atom.
///
/// Rows whose best gathered atom falls below this value are flagged uncertified
/// and sent to the exact full-scan fallback. This closes the silent mis-routing
/// hole in which a non-empty but wrong gather certifies against a suboptimal
/// atom. Over-flagging only costs extra exact-fallback work (correct, slower);
/// it can never certify a mis-route.
pub(crate) const CANDIDATE_ROUTING_MIN_ALIGNMENT: f64 = 0.5;

/// How many nearest charts the CERTIFIED encode refines in before returning the
/// certified result with the lowest reconstruction error.
///
/// Refining only the single nearest chart is not globally sound where the
/// decoded manifold folds back near itself. Charts from both competing basins
/// reconstruct close to the fold, so both rank among the nearest by ambient
/// distance. Refining the top few recovers the global basin. For a unimodal atom
/// every candidate converges to the same root, so `K > 1` changes nothing.
pub(crate) const CERTIFIED_ROUTING_TOPK: usize = 4;
104/// globally sound where the decoded manifold folds near itself (both competing
105/// basins' charts reconstruct near the fold, so both rank among the nearest by
106/// ambient distance); refining the top few captures the global basin. For a
107/// unimodal atom all candidates converge to the same root, so K>1 is a no-op.
108pub(crate) const CERTIFIED_ROUTING_TOPK: usize = 4;
109
110/// A chart region on an atom's latent coordinate: a center `t_c` plus a
111/// certified in-chart radius. Over the ball `‖t − t_c‖ ≤ radius` the jet sup
112/// bounds returned by [`BasisHessianLipschitz`] hold, so the Kantorovich
113/// constant `L` computed from them is valid for any start in the ball.
114///
115/// For radial (Duchon) families the chart also carries the minimum kernel-center
116/// distance `exclusion_r_min` (a lower bound on `‖t − c_k‖` over the chart) that
117/// bounds the otherwise-singular `1/r` radial tails (issue #1010).
118#[derive(Debug, Clone)]
119pub struct ChartRegion {
120    /// Chart center coordinate `t_c` (length = latent_dim).
121    pub center: Array1<f64>,
122    /// In-chart radius in the coordinate metric.
123    pub radius: f64,
124    /// For radial (Duchon) families: a lower bound on `‖t − c_k‖` over the
125    /// chart, across every kernel center `c_k`. `None` for non-radial families.
126    pub exclusion_r_min: Option<f64>,
127    /// For radial (Duchon) families: an upper bound on `‖t − c_k‖` over the
128    /// chart, across every kernel center `c_k`. `None` for non-radial families.
129    pub radial_r_max: Option<f64>,
130}
131
132impl ChartRegion {
133    pub fn new(center: Array1<f64>, radius: f64) -> Self {
134        Self {
135            center,
136            radius,
137            exclusion_r_min: None,
138            radial_r_max: None,
139        }
140    }
141
142    pub fn with_radial_bounds(mut self, r_min: f64, r_max: f64) -> Self {
143        self.exclusion_r_min = Some(r_min);
144        self.radial_r_max = Some(r_max);
145        self
146    }
147
148    /// A jet-sup certificate is only meaningful over a genuine region. Even
149    /// families whose bounds are manifold-global constants (the sup over any
150    /// chart equals the global sup) must refuse a malformed chart rather than
151    /// certify garbage geometry.
152    pub(crate) fn assert_valid(&self) {
153        assert!(
154            self.radius.is_finite()
155                && self.radius >= 0.0
156                && self.center.iter().all(|c| c.is_finite()),
157            "ChartRegion must have a finite center and a finite non-negative radius"
158        );
159    }
160}
161
162/// Per-column sup-norm bounds on the first three coordinate jets of a basis
163/// family `Φ(t)`, valid over a stated [`ChartRegion`] (issue #1010). These are
164/// the analytic ingredients of the Hessian-Lipschitz constant `L` — see the
165/// module docs for the assembly. `value_sup` bounds `max_m |Φ_m|`,
166/// `jacobian_sup`/`hessian_sup`/`third_sup` bound `max_m ‖∂^g Φ_m‖`.
167pub trait BasisHessianLipschitz {
168    fn value_sup(&self, chart: &ChartRegion) -> f64;
169    fn jacobian_sup(&self, chart: &ChartRegion) -> f64;
170    fn hessian_sup(&self, chart: &ChartRegion) -> f64;
171    fn third_sup(&self, chart: &ChartRegion) -> f64;
172}
173
174/// Sup over the circle of the `g`-th derivative of any single harmonic column
175/// of a `num_basis`-wide Fourier basis `[1, sin(2π h t), cos(2π h t), …]`:
176/// `(2π·H)^g` for the top harmonic `H = (num_basis − 1)/2`. The constant column
177/// contributes `0` for `g ≥ 1`, so the top harmonic dominates; the bound is
178/// global (the trig magnitudes are `≤ 1` everywhere, independent of the chart).
179pub(crate) fn harmonic_jet_sup(num_basis: usize, order: u32) -> f64 {
180    let top_harmonic = num_basis.saturating_sub(1) / 2;
181    let omega = std::f64::consts::TAU * top_harmonic as f64;
182    omega.powi(order as i32)
183}
184
185impl BasisHessianLipschitz for PeriodicHarmonicEvaluator {
186    fn value_sup(&self, chart: &ChartRegion) -> f64 {
187        chart.assert_valid();
188        1.0
189    }
190    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
191        chart.assert_valid();
192        harmonic_jet_sup(self.num_basis, 1)
193    }
194    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
195        chart.assert_valid();
196        harmonic_jet_sup(self.num_basis, 2)
197    }
198    fn third_sup(&self, chart: &ChartRegion) -> f64 {
199        chart.assert_valid();
200        harmonic_jet_sup(self.num_basis, 3)
201    }
202}
203
204impl BasisHessianLipschitz for TorusHarmonicEvaluator {
205    /// Tensor product of per-axis circle harmonics. A torus basis column is a
206    /// product of single-axis harmonics, each bounded as in the circle case.
207    /// The `g`-th coordinate jet routes `g` derivative operators across the
208    /// `latent_dim` factors (Leibniz); each routing contributes a product of
209    /// per-axis derivative magnitudes. A per-column sup is therefore bounded by
210    /// the top single-axis frequency to the `g`-th power times the number of
211    /// such routings (`latent_dim^g`, the count of operator-to-axis maps).
212    fn value_sup(&self, chart: &ChartRegion) -> f64 {
213        chart.assert_valid();
214        1.0
215    }
216    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
217        chart.assert_valid();
218        torus_jet_sup(self.num_harmonics, self.latent_dim, 1)
219    }
220    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
221        chart.assert_valid();
222        torus_jet_sup(self.num_harmonics, self.latent_dim, 2)
223    }
224    fn third_sup(&self, chart: &ChartRegion) -> f64 {
225        chart.assert_valid();
226        torus_jet_sup(self.num_harmonics, self.latent_dim, 3)
227    }
228}
229
230/// Per-column `g`-th jet sup for the torus harmonic basis: `(2π·H)^g ·
231/// latent_dim^g`, where `H = num_harmonics` is the top per-axis frequency and
232/// `latent_dim^g` over-counts the Leibniz routings of `g` operators across the
233/// product factors (a conservative bound — each routing's per-axis magnitude is
234/// `≤ (2π H)^{#ops on that axis}`, and the products telescope to `(2π H)^g`).
235pub(crate) fn torus_jet_sup(num_harmonics: usize, latent_dim: usize, order: u32) -> f64 {
236    let omega = std::f64::consts::TAU * num_harmonics as f64;
237    omega.powi(order as i32) * (latent_dim as f64).powi(order as i32)
238}
239
240impl BasisHessianLipschitz for SphereChartEvaluator {
241    /// The 7-column lat/lon chart `[1, x, y, z, xy, yz, xz]` with
242    /// `x = cos(lat)cos(lon)`, `y = cos(lat)sin(lon)`, `z = sin(lat)`. Each of
243    /// `x, y, z` is a product of two unit-frequency trig factors, so its `g`-th
244    /// coordinate jet is a sum of `2^g` products of `{sin,cos}` (each `≤ 1`):
245    /// magnitude `≤ 2^g` for `g ≥ 1`, `≤ 1` for `g = 0`. The bilinear columns
246    /// `xy, yz, xz` are products of two such coordinates; by Leibniz over the
247    /// product, their `g`-th jet is bounded by `Σ_{i=0}^{g} C(g,i)·(2^i)·(2^{g−i})
248    /// = (2+2)^g = 4^g` (using `‖∂^i u‖ ≤ 2^i`, `|u| ≤ 1`). The bilinear columns
249    /// dominate, so the per-column sup is `4^g` (`g ≥ 1`). Bounds are global
250    /// constants — the chart box `lat ∈ [-π/2, π/2]` does not enlarge them.
251    fn value_sup(&self, chart: &ChartRegion) -> f64 {
252        chart.assert_valid();
253        1.0
254    }
255    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
256        chart.assert_valid();
257        4.0
258    }
259    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
260        chart.assert_valid();
261        16.0
262    }
263    fn third_sup(&self, chart: &ChartRegion) -> f64 {
264        chart.assert_valid();
265        64.0
266    }
267}
268
269impl BasisHessianLipschitz for AffineCoordinateEvaluator {
270    /// The affine basis `[1, t₁, …, t_d]` is degree ≤ 1: its first jet has unit
271    /// columns, and all second and third jets vanish. The value sup is
272    /// `max(1, ‖t‖)` over the chart, bounded by `1 + ‖t_c‖ + radius`.
273    fn value_sup(&self, chart: &ChartRegion) -> f64 {
274        let center_norm = chart.center.dot(&chart.center).sqrt();
275        1.0 + center_norm + chart.radius
276    }
277    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
278        chart.assert_valid();
279        1.0
280    }
281    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
282        chart.assert_valid();
283        0.0
284    }
285    fn third_sup(&self, chart: &ChartRegion) -> f64 {
286        chart.assert_valid();
287        0.0
288    }
289}
290
291impl BasisHessianLipschitz for EuclideanPatchEvaluator {
292    /// Monomials of total degree ≤ `max_degree` in `t ∈ ℝ^d`. Over the ball of
293    /// radius `R` about `t_c`, each coordinate is bounded by `ρ = ‖t_c‖∞ + R`.
294    /// A monomial `t^α` with `|α| = q` has `g`-th partials bounded (crudely) by
295    /// the descending-factorial coefficient `q·(q−1)···(q−g+1) ≤ q^g` times
296    /// `ρ^{max(q−g,0)}`, and there are at most `d^g` partial routings, so the
297    /// per-column `g`-th jet sup is `≤ d^g · D^g · ρ^{max(D−g,0)}` with
298    /// `D = max_degree`. Conservative; D is small for patch evaluators.
299    fn value_sup(&self, chart: &ChartRegion) -> f64 {
300        let rho = patch_rho(chart);
301        let d = self.max_degree as i32;
302        rho.powi(d).max(1.0)
303    }
304    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
305        patch_jet_sup(self.latent_dim, self.max_degree, chart, 1)
306    }
307    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
308        patch_jet_sup(self.latent_dim, self.max_degree, chart, 2)
309    }
310    fn third_sup(&self, chart: &ChartRegion) -> f64 {
311        patch_jet_sup(self.latent_dim, self.max_degree, chart, 3)
312    }
313}
314
315impl BasisHessianLipschitz for CylinderHarmonicEvaluator {
316    /// Cylinder `S¹ × ℝ` product basis `Φ_{c,l} = c(t₀)·l(t₁)`, the circle
317    /// (periodic harmonic) factor on axis 0 crossed with the monomial line
318    /// factor on axis 1. Because the two factors depend on disjoint coordinates,
319    /// the order-`g` coordinate jet in any cell is exactly
320    /// `c^{(k₀)}(t₀)·l^{(k₁)}(t₁)` with `k₀ + k₁ = g`, so the per-column sup is
321    /// the max over the split `k₀ + k₁ = g` of the product of the two per-axis
322    /// per-order sups: the circle factor contributes `1` at order 0 and
323    /// `(2π·H)^{k₀}` at order `k₀ ≥ 1` (trig magnitudes `≤ 1`); the line factor
324    /// contributes the monomial-patch sup `D^{k₁}·ρ^{max(D−k₁,0)}` (`D = line
325    /// degree`, `ρ = ‖t_c‖∞ + radius`). Bounds are global in the periodic axis
326    /// and chart-local in the line axis.
327    fn value_sup(&self, chart: &ChartRegion) -> f64 {
328        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 0)
329    }
330    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
331        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 1)
332    }
333    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
334        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 2)
335    }
336    fn third_sup(&self, chart: &ChartRegion) -> f64 {
337        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 3)
338    }
339}
340
341/// Per-column order-`g` jet sup of the cylinder product basis: the max over
342/// `k₀ + k₁ = g` of `circle_axis_sup(k₀) · line_axis_sup(k₁)`, where the circle
343/// axis sup is `(2π·H)^{k₀}` (`1` at `k₀ = 0`) and the line axis sup is the
344/// monomial-patch bound `D^{k₁}·ρ^{max(D−k₁,0)}` (`1` at `k₁ = 0`). See the
345/// [`CylinderHarmonicEvaluator`] doc comment for the derivation.
346pub(crate) fn cylinder_jet_sup(
347    circle_harmonics: usize,
348    line_degree: usize,
349    chart: &ChartRegion,
350    order: u32,
351) -> f64 {
352    let omega = std::f64::consts::TAU * circle_harmonics as f64;
353    let big_d = line_degree as f64;
354    let rho = patch_rho(chart);
355    let mut best = 0.0_f64;
356    for k0 in 0..=order {
357        let k1 = order - k0;
358        let circle = if k0 == 0 { 1.0 } else { omega.powi(k0 as i32) };
359        let line = if k1 == 0 {
360            rho.powi(line_degree as i32).max(1.0)
361        } else {
362            let residual = line_degree.saturating_sub(k1 as usize) as i32;
363            // `.max(1.0)` as in `patch_jet_sup`: for ρ < 1 a lower-degree line
364            // monomial dominates the k1-th derivative, so the bare `ρ^residual`
365            // underestimates the line-factor sup. The value case (k1==0) already
366            // clamps; this completes it for the derivative orders.
367            big_d.powi(k1 as i32) * rho.powi(residual).max(1.0)
368        };
369        best = best.max(circle * line);
370    }
371    best
372}
373
374/// Sup-norm radius `ρ = ‖t_c‖∞ + radius` of the chart (the coordinate magnitude
375/// bound used by the monomial-patch jet bounds).
376pub(crate) fn patch_rho(chart: &ChartRegion) -> f64 {
377    let center_inf = chart
378        .center
379        .iter()
380        .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
381    center_inf + chart.radius
382}
383
384/// Per-column `g`-th jet sup for a monomial patch of max degree `D` in `d`
385/// coordinates over the chart: `d^g · D^g · ρ^{max(D−g,0)}` (see the
386/// [`EuclideanPatchEvaluator`] doc comment for the derivation).
387pub(crate) fn patch_jet_sup(
388    latent_dim: usize,
389    max_degree: usize,
390    chart: &ChartRegion,
391    order: u32,
392) -> f64 {
393    let d = latent_dim as f64;
394    let big_d = max_degree as f64;
395    let rho = patch_rho(chart);
396    let residual_degree = max_degree.saturating_sub(order as usize) as i32;
397    // `.max(1.0)`: for ρ < 1 (small charts near the origin) the g-th jet sup is NOT
398    // dominated by the max-degree monomial `t^D` (whose g-th derivative ~ ρ^{D-g}
399    // shrinks with ρ) but by a LOWER-degree monomial whose g-th derivative is a
400    // larger constant — e.g. {1,t,t²}'s jacobian sup is the linear term's constant
401    // `1`, which exceeds `2ρ` when ρ < ½. Without the clamp the bound underestimates
402    // the true sup (numerically: D=3, ρ=0.1, g=1 → formula 0.03 vs true 1.0), which
403    // would make the certificate's Lipschitz `L` too small → a FALSE certificate.
404    // `D^g · max(ρ^{D-g}, 1)` upper-bounds `max_{q∈[g,D]} (q!/(q-g)!)·ρ^{q-g}` for
405    // all ρ (the `q=g` term gives `g! ≤ D^g`, the `q=D` term gives `≤ D^g·ρ^{D-g}`).
406    d.powi(order as i32) * big_d.powi(order as i32) * rho.powi(residual_degree).max(1.0)
407}
408
409impl BasisHessianLipschitz for DuchonCoordinateEvaluator {
410    /// Radial-kernel basis `Φ_m(t) = φ(r_m)`, `r_m = ‖t − c_m‖`, plus a
411    /// polynomial nullspace block. For the cubic Duchon kernel `φ(r) = r³` the
412    /// radial derivatives are `φ' = 3r²`, `φ'' = 6r`, `φ''' = 6`. The chain rule
413    /// to coordinate jets introduces `1/r` factors through the unit radial
414    /// direction `u = (t − c)/r` and the projector `(I − uuᵀ)/r`, so over a
415    /// chart the jets are bounded by combining the radial-derivative magnitudes
416    /// at the worst-case radius with the inverse-radius tail at the chart's
417    /// EXCLUSION radius `r_min` (the closest a chart point gets to any center):
418    ///
419    /// ```text
420    /// ‖∇φ‖    ≤ |φ'|                              ≤ 3 r_max²
421    /// ‖∇²φ‖   ≤ |φ''| + |φ'|/r                    ≤ 6 r_max + 3 r_max²/r_min
422    /// ‖∇³φ‖   ≤ |φ'''| + 3|φ''|/r + 3|φ'|/r²      ≤ 6 + 18 r_max/r_min + 9 r_max²/r_min²
423    /// ```
424    ///
425    /// (the `1/r`, `1/r²` tails are bounded by `1/r_min`, `1/r_min²`). The
426    /// polynomial nullspace block is degree ≤ `order`; its jets are bounded like
427    /// the monomial patch with `D = order`. The per-column sup is the max of the
428    /// kernel and polynomial bounds. The `r³` kernel is itself `C²` (no
429    /// singularity) so these tails are conservative but finite for any
430    /// `r_min > 0`; the atlas refines charts to keep `r_min` bounded away from 0.
431    fn value_sup(&self, chart: &ChartRegion) -> f64 {
432        let r_max = chart.radial_r_max.unwrap_or(chart.radius);
433        let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 0);
434        (r_max.powi(3)).max(poly)
435    }
436    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
437        let r_max = chart.radial_r_max.unwrap_or(chart.radius);
438        let kernel = 3.0 * r_max * r_max;
439        let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 1);
440        kernel.max(poly)
441    }
442    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
443        let r_max = chart.radial_r_max.unwrap_or(chart.radius);
444        let r_min = chart
445            .exclusion_r_min
446            .unwrap_or(chart.radius)
447            .max(f64::MIN_POSITIVE);
448        let kernel = 6.0 * r_max + 3.0 * r_max * r_max / r_min;
449        let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 2);
450        kernel.max(poly)
451    }
452    fn third_sup(&self, chart: &ChartRegion) -> f64 {
453        let r_max = chart.radial_r_max.unwrap_or(chart.radius);
454        let r_min = chart
455            .exclusion_r_min
456            .unwrap_or(chart.radius)
457            .max(f64::MIN_POSITIVE);
458        let kernel = 6.0 + 18.0 * r_max / r_min + 9.0 * r_max * r_max / (r_min * r_min);
459        let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 3);
460        kernel.max(poly)
461    }
462}
463
464/// Polynomial-block degree of a Duchon nullspace order, used to bound the
465/// nullspace columns like a monomial patch.
466trait DuchonOrderDegree {
467    fn order_degree(&self) -> usize;
468}
469
470impl DuchonOrderDegree for DuchonCoordinateEvaluator {
471    fn order_degree(&self) -> usize {
472        match self.order {
473            gam_terms::basis::DuchonNullspaceOrder::Zero => 0,
474            gam_terms::basis::DuchonNullspaceOrder::Linear => 1,
475            gam_terms::basis::DuchonNullspaceOrder::Degree(d) => d,
476        }
477    }
478}
479
480/// Per-column `g`-th jet sup of the Duchon polynomial nullspace block, treated
481/// as a monomial patch of degree `order_degree`.
482pub(crate) fn duchon_poly_jet_sup(
483    latent_dim: usize,
484    order_degree: usize,
485    chart: &ChartRegion,
486    order: u32,
487) -> f64 {
488    if order_degree == 0 {
489        return if order == 0 { 1.0 } else { 0.0 };
490    }
491    patch_jet_sup(latent_dim, order_degree, chart, order)
492}
493
494/// Decoder magnitude `Σ_m ‖B_{m,:}‖₂` of an atom's frozen decoder block: the
495/// factor that converts a per-column `Φ`-jet sup `B_g` into a reconstruction
496/// jet sup `‖∂^g m‖ ≤ |z|·decoder_row_norm_sum·B_g`.
497pub(crate) fn decoder_row_norm_sum(decoder: ArrayView2<'_, f64>) -> f64 {
498    let mut acc = 0.0;
499    for row in decoder.rows() {
500        acc += row.dot(&row).sqrt();
501    }
502    acc
503}
504
505#[derive(Debug, Clone, Copy)]
506pub(crate) struct ReconstructionJetSups {
507    pub(crate) value: f64,
508    pub(crate) jacobian: f64,
509    pub(crate) hessian: f64,
510    pub(crate) third: f64,
511}
512
513pub(crate) fn pair_trig_decoder_sup(
514    sin_row: ArrayView1<'_, f64>,
515    cos_row: ArrayView1<'_, f64>,
516) -> f64 {
517    let aa = sin_row.dot(&sin_row);
518    let bb = cos_row.dot(&cos_row);
519    let ab = sin_row.dot(&cos_row);
520    let trace = aa + bb;
521    let disc = ((aa - bb) * (aa - bb) + 4.0 * ab * ab).sqrt();
522    (0.5 * (trace + disc)).sqrt()
523}
524
525pub(crate) fn periodic_reconstruction_jet_sups(
526    decoder: ArrayView2<'_, f64>,
527) -> ReconstructionJetSups {
528    let mut value = 0.0;
529    let mut jacobian = 0.0;
530    let mut hessian = 0.0;
531    let mut third = 0.0;
532    if decoder.nrows() > 0 {
533        value += decoder.row(0).dot(&decoder.row(0)).sqrt();
534    }
535    let harmonics = decoder.nrows().saturating_sub(1) / 2;
536    for h in 1..=harmonics {
537        let sin_idx = 2 * h - 1;
538        let cos_idx = 2 * h;
539        let amp = pair_trig_decoder_sup(decoder.row(sin_idx), decoder.row(cos_idx));
540        let omega = std::f64::consts::TAU * h as f64;
541        value += amp;
542        jacobian += omega * amp;
543        hessian += omega.powi(2) * amp;
544        third += omega.powi(3) * amp;
545    }
546    for row in (1 + 2 * harmonics)..decoder.nrows() {
547        let amp = decoder.row(row).dot(&decoder.row(row)).sqrt();
548        value += amp;
549        let omega = std::f64::consts::TAU * harmonics.max(1) as f64;
550        jacobian += omega * amp;
551        hessian += omega.powi(2) * amp;
552        third += omega.powi(3) * amp;
553    }
554    ReconstructionJetSups {
555        value,
556        jacobian,
557        hessian,
558        third,
559    }
560}
561
562pub(crate) fn reconstruction_jet_sups(
563    atom: &SaeManifoldAtom,
564    sups: JetSups,
565) -> ReconstructionJetSups {
566    if matches!(
567        atom.basis_kind,
568        crate::manifold::SaeAtomBasisKind::Periodic
569    ) {
570        periodic_reconstruction_jet_sups(atom.decoder_coefficients.view())
571    } else {
572        let decoder_norm_sum = decoder_row_norm_sum(atom.decoder_coefficients.view());
573        ReconstructionJetSups {
574            value: decoder_norm_sum * sups.value,
575            jacobian: decoder_norm_sum * sups.jacobian,
576            hessian: decoder_norm_sum * sups.hessian,
577            third: decoder_norm_sum * sups.third,
578        }
579    }
580}
581
582/// The Hessian-Lipschitz constant `L` of the per-row encode objective `f_k` on
583/// a chart, assembled in closed form from the basis jet sups and the decoder /
584/// amplitude / target magnitudes. See the module docs for the derivation:
585///
586/// ```text
587/// L ≤ 3·‖J_m‖·‖∇²m‖ + ‖r‖·‖∇³m‖ + L_prior,
588/// ‖∂^g m‖ ≤ |z|·S_B·B_g,   S_B = Σ_m ‖B_{m,:}‖,
589/// ‖r‖ ≤ ‖x‖ + |z|·S_B·B_0,
590/// ```
591///
592/// `prior_lipschitz` is the caller-supplied closed-form `L_prior` of the
593/// ARD/von-Mises coordinate prior (`0.0` if no prior is active on the encode).
594pub(crate) fn hessian_lipschitz_constant(
595    recon_sups: ReconstructionJetSups,
596    amplitude: f64,
597    target_norm: f64,
598    prior_lipschitz: f64,
599) -> f64 {
600    let z = amplitude.abs();
601    let m_jac = z * recon_sups.jacobian;
602    let m_hess = z * recon_sups.hessian;
603    let m_third = z * recon_sups.third;
604    let recon_value = z * recon_sups.value;
605    let r_norm = target_norm + recon_value;
606    3.0 * m_jac * m_hess + r_norm * m_third + prior_lipschitz
607}
608
609/// One offline-certified chart: a center, its Kantorovich constants, and the
610/// certified Newton-convergence radius `R_c` solved from `h = β·η·L ≤ ½` at the
611/// worst-case in-chart start.
612#[derive(Debug, Clone)]
613pub struct CertifiedChart {
614    pub region: ChartRegion,
615    /// Closed-form Hessian-Lipschitz constant `L` over the chart.
616    pub lipschitz: f64,
617    /// `β = ‖F'(t_c)⁻¹‖` at the chart center (worst-case in-chart start uses
618    /// the center's curvature; the radius is solved so the certificate holds for
619    /// any start in the ball).
620    pub beta_center: f64,
621    /// Certified Newton radius: starts within `radius` of `t_c` satisfy `h ≤ ½`.
622    pub certified_radius: f64,
623    /// Distilled amortized-encoder Jacobian for this chart (#1026 ladder item 3).
624    ///
625    /// The exact encode map `x ↦ t` solves `F(t; x) = J_m(t)ᵀ(m(t) − x) = 0`. By
626    /// the implicit function theorem its derivative at the converged root is
627    /// `dt/dx = −(∂_t F)⁻¹ (∂_x F) = H⁻¹ J_m` (since `∂_x F = −J_m`), so the
628    /// first-order Taylor expansion of the encode map about this chart's center
629    /// `t_c` is the closed-form AFFINE predictor
630    ///
631    /// ```text
632    /// t(x) ≈ t_c + (1/z) · A₁ · (x − z · m₁(t_c)),   A₁ = (J₁ᵀJ₁ + ridge·I)⁻¹ J₁,
633    /// ```
634    ///
635    /// with `J₁ = Bᵀ J_Φ(t_c)` and `m₁(t_c) = BᵀΦ(t_c)` the AMPLITUDE-1
636    /// reconstruction jets (the amplitude `z` factors out analytically, so the
637    /// stored Jacobian is amplitude-free). This is the DISTILLED amortized
638    /// encoder of the #1026 thread: the per-row Hessian factorization + Newton
639    /// iteration is moved OFFLINE into this `d × p` matrix, leaving a single
640    /// `O(d·p)` mat-vec online — no per-row eigendecomposition, no second-jet
641    /// evaluation. The Kantorovich certificate is still evaluated AT the
642    /// predicted start, so the amortized prediction is trusted iff `h ≤ ½` and an
643    /// uncertified row still routes to the exact multi-start solve (the encoder
644    /// approximates inference, the certificate keeps it honest — the thread's
645    /// "encoder + certificate-gated exact fallback" deployment). `None` when the
646    /// center's Gauss–Newton block is singular (no certifiable amortization).
647    pub amortized_jacobian: Option<Array2<f64>>,
648    /// Amplitude-1 chart-center reconstruction `m₁(t_c) = BᵀΦ(t_c)` (length `p`),
649    /// the anchor the amortized predictor expands the encode map around.
650    pub recon_center: Array1<f64>,
651}
652
653/// The per-atom encode atlas: a set of certified charts covering the atom's
654/// coordinate domain, plus the decoder/amplitude scaling needed to recompute a
655/// per-row certificate online.
656#[derive(Debug, Clone)]
657pub struct AtomEncodeAtlas {
658    pub atom_index: usize,
659    pub latent_dim: usize,
660    pub decoder_norm_sum: f64,
661    pub charts: Vec<CertifiedChart>,
662}
663
664/// Result of a certified encode over a batch of rows, carrying the honesty
665/// flag: how many rows could NOT be certified and were flagged for the exact
666/// multi-start fallback (issue #1010 — no approximation enters silently).
667#[derive(Debug, Clone)]
668pub struct EncodeResult {
669    /// Per-row encoded latent coordinates (`n_rows × latent_dim`).
670    pub coords: Array2<f64>,
671    /// Per-row certificate: `true` ⇒ the row's start satisfied `h ≤ ½` and the
672    /// 1–2 Newton steps are exact-into-the-certified-ball; `false` ⇒ flagged.
673    pub certified: Vec<bool>,
674    /// Count of rows that could not be certified. These ride the payload so the
675    /// caller routes them to the exact multi-start encode — honesty, never
676    /// silent. Equals `certified.iter().filter(|c| !**c).count()`.
677    pub encode_uncertified_count: usize,
678}
679
680impl EncodeResult {
681    pub(crate) fn from_rows(coords: Array2<f64>, certified: Vec<bool>) -> Self {
682        let encode_uncertified_count = certified.iter().filter(|c| !**c).count();
683        Self {
684            coords,
685            certified,
686            encode_uncertified_count,
687        }
688    }
689}
690
691/// Per-row Kantorovich certificate at a start `t₀` for one atom encode.
692#[derive(Debug, Clone, Copy)]
693pub struct RowCertificate {
694    pub beta: f64,
695    pub eta: f64,
696    pub lipschitz: f64,
697    /// `h = β·η·L`. The row is certified iff `h ≤ ½`.
698    pub h: f64,
699}
700
701impl RowCertificate {
702    pub fn certified(&self) -> bool {
703        self.h.is_finite() && self.h <= KANTOROVICH_THRESHOLD
704    }
705}
706
/// A successfully certified encode probe, as produced by
/// `refine_certified_start` (and, through it, `certify_with_basin_warmup`).
#[derive(Debug, Clone)]
struct CertifiedEncodeProbe {
    /// Latent coordinate after the refinement Newton steps (the start itself
    /// when zero refinement steps were requested).
    coord: Array1<f64>,
    /// Certificate at the start the refinement began from; always certified.
    initial_cert: RowCertificate,
    /// Certificate recomputed at `coord`; equals `initial_cert` when no
    /// refinement step ran.
    final_cert: RowCertificate,
}
713
/// Canonical flat-axis polynomial degree of a cylinder `S¹ × ℝ` atom — the
/// degree the topology-race builder ([`gam_solve::structure_harvest`]) uses
/// for the line axis (`CylinderHarmonicEvaluator::new(_, 2)`). The encode atlas
/// recovers the circle harmonic count from the basis width using this degree, so
/// the two must agree.
///
/// Concretely, a Cylinder atom of width `m` is expected to satisfy
/// `m = (2H+1)·(SAE_CYLINDER_LINE_DEGREE + 1)`; `family_jet_sups` rejects
/// Cylinder widths that are not divisible by `SAE_CYLINDER_LINE_DEGREE + 1`.
pub(crate) const SAE_CYLINDER_LINE_DEGREE: usize = 2;
720
721/// Build a basis-family handle for one atom from its [`SaeManifoldAtom`]. The
722/// atlas needs to evaluate the jet sups, which live on the concrete evaluator
723/// types; the atom carries the evaluator as `Arc<dyn SaeBasisEvaluator>`, so we
724/// reconstruct the family bound from the atom's basis kind + width + centers.
725pub(crate) fn family_jet_sups(
726    atom: &SaeManifoldAtom,
727    chart: &ChartRegion,
728) -> Result<JetSups, String> {
729    use crate::manifold::SaeAtomBasisKind::*;
730    let m = atom.basis_size();
731    let d = atom.latent_dim;
732    let sups = match &atom.basis_kind {
733        Periodic => {
734            let ev = PeriodicHarmonicEvaluator::new(m)?;
735            JetSups::from_family(&ev, chart)
736        }
737        Torus => {
738            // Torus basis width is `(2H+1)^d`; recover the per-axis harmonic
739            // count `H` from `axis_m = m^(1/d)` rather than a sum formula.
740            let axis_m = integer_root(m, d.max(1));
741            let num_harmonics = axis_m.saturating_sub(1) / 2;
742            let ev = TorusHarmonicEvaluator::new(d, num_harmonics.max(1))?;
743            JetSups::from_family(&ev, chart)
744        }
745        Sphere => {
746            let ev = SphereChartEvaluator;
747            JetSups::from_family(&ev, chart)
748        }
749        Cylinder => {
750            // Cylinder width is `(2H+1)·(D+1)` with the canonical flat-axis
751            // degree `D = SAE_CYLINDER_LINE_DEGREE` (the harvest convention).
752            // Recover the per-axis circle harmonic count `H` from
753            // `2H+1 = m/(D+1)`.
754            let ml = SAE_CYLINDER_LINE_DEGREE + 1;
755            if d != 2 || ml == 0 || m % ml != 0 {
756                return Err(format!(
757                    "EncodeAtlas: Cylinder atom requires latent_dim == 2 and width divisible by {ml}; got dim={d}, m={m}"
758                ));
759            }
760            let axis_mc = m / ml;
761            let h = axis_mc.saturating_sub(1) / 2;
762            let ev = CylinderHarmonicEvaluator::new(h.max(1), SAE_CYLINDER_LINE_DEGREE)?;
763            JetSups::from_family(&ev, chart)
764        }
765        Linear | EuclideanPatch | Poincare => {
766            // The patch width fixes max_degree implicitly; bound by a degree that
767            // covers the column count (conservative). Degree d-patch column count
768            // grows fast; we recover the smallest degree whose patch is ≥ m.
769            // Poincare atoms use the same tangent-coordinate polynomial decoder;
770            // their intrinsic smoothness differs in the penalty, not in Phi(t).
771            let degree = euclidean_patch_degree(d, m);
772            let ev = EuclideanPatchEvaluator::new(d, degree)?;
773            JetSups::from_family(&ev, chart)
774        }
775        Duchon => {
776            // The atom carries the basis kind but not the nullspace order, and
777            // the certificate needs an UPPER bound on L. The kernel-tail bound
778            // (cubic r³ coefficients vs the chart's r_min/r_max) is independent
779            // of the constructed order; the polynomial-block bound grows with the
780            // order, so we construct with a conservative order whose polynomial
781            // degree upper-bounds any nullspace the atom's basis width can hold.
782            // Constructing with `m = basis_size` maps to `Degree(basis_size − 1)`
783            // — an over-estimate that keeps the Lipschitz bound sound.
784            let centers = duchon_centers_from_atom(atom);
785            let conservative_m = m.max(1);
786            let ev = DuchonCoordinateEvaluator::new(centers, conservative_m)?;
787            JetSups::from_family(&ev, chart)
788        }
789        Precomputed(name) => {
790            return Err(format!(
791                "EncodeAtlas: precomputed basis '{name}' has no closed-form jet sup; route to exact encode"
792            ));
793        }
794    };
795    Ok(sups)
796}
797
798/// Smallest monomial-patch degree whose column count covers `m` basis columns.
799pub(crate) fn euclidean_patch_degree(latent_dim: usize, m: usize) -> usize {
800    // Column count of a degree-D patch in d vars is C(d+D, D). Grow D until it
801    // covers m; cap at m so a degenerate width still terminates.
802    let mut degree = 0usize;
803    while patch_column_count(latent_dim, degree) < m && degree < m {
804        degree += 1;
805    }
806    degree
807}
808
/// Largest integer `a` with `a^k ≤ n` (the floor of the `k`-th root). Used to
/// recover the per-axis harmonic width `axis_m` from a torus basis width
/// `m = axis_m^d`.
///
/// Edge cases:
/// * `k == 1` returns `n`.
/// * `n == 0` returns `0` for every `k ≥ 1`.
/// * The degenerate `k == 0` (where `a^0 = 1` for every `a`) returns `1`.
///
/// The scan is linear in the result, which is fine for basis widths.
pub(crate) fn integer_root(n: usize, k: usize) -> usize {
    match k {
        0 => return 1,
        1 => return n,
        _ => {}
    }
    if n == 0 {
        // 0^k = 0 ≤ 0 < 1 = 1^k, so the floor root is 0. The scan below starts
        // at 1 and would otherwise report 1.
        return 0;
    }
    let limit = n as u128;
    // `a^k > n`, computed in u128 with an early exit so the power never overflows.
    let exceeds = |a: usize| -> bool {
        let mut pow: u128 = 1;
        for _ in 0..k {
            pow = pow.saturating_mul(a as u128);
            if pow > limit {
                return true;
            }
        }
        false
    };
    let mut a = 1usize;
    while !exceeds(a + 1) {
        a += 1;
    }
    a
}
837
/// Number of monomials of total degree `≤ degree` in `latent_dim` variables,
/// i.e. the binomial coefficient `C(latent_dim + degree, degree)`.
///
/// Saturates at `usize::MAX` when the true count is not representable. Callers
/// only compare it against a basis width, for which "huge" is the right answer.
pub(crate) fn patch_column_count(latent_dim: usize, degree: usize) -> usize {
    // Build C(d + D, D) incrementally via C(d+i, i) = C(d+i-1, i-1)·(d+i)/i.
    // Every partial value is itself a binomial coefficient, so each division is
    // exact and intermediates stay near the answer. The former `num/den` form
    // multiplied out (d+D)!/d! and D! separately, which overflows u128 (and
    // panics in debug builds) for e.g. d = 1, D ≥ 34 even though the result is
    // just D + 1.
    let d = latent_dim as u128;
    let mut count: u128 = 1;
    for i in 1..=degree {
        let step = i as u128;
        let Some(grown) = count.checked_mul(d + step) else {
            return usize::MAX;
        };
        count = grown / step;
        // C(d+i, i) is non-decreasing in i, so once past usize::MAX it stays there.
        if count > usize::MAX as u128 {
            return usize::MAX;
        }
    }
    count as usize
}
848
849/// Recover Duchon centers from an atom: when the evaluator is unavailable the
850/// atlas falls back to the atom's own latent-coordinate hull as the center set,
851/// which only affects the radial-tail bound conservatively.
852pub(crate) fn duchon_centers_from_atom(atom: &SaeManifoldAtom) -> Array2<f64> {
853    // One center at the origin in latent_dim space is a sound conservative
854    // default: the chart's own r_min / r_max bracket the true radial range.
855    Array2::<f64>::zeros((1, atom.latent_dim.max(1)))
856}
857
858/// The four per-column jet sups of a basis family over a chart.
859#[derive(Debug, Clone, Copy)]
860pub(crate) struct JetSups {
861    pub(crate) value: f64,
862    pub(crate) jacobian: f64,
863    pub(crate) hessian: f64,
864    pub(crate) third: f64,
865}
866
867impl JetSups {
868    pub(crate) fn from_family<B: BasisHessianLipschitz>(family: &B, chart: &ChartRegion) -> Self {
869        Self {
870            value: family.value_sup(chart),
871            jacobian: family.jacobian_sup(chart),
872            hessian: family.hessian_sup(chart),
873            third: family.third_sup(chart),
874        }
875    }
876}
877
878/// Evaluate one atom's encode objective gradient `F(t) = ∇f_k(t)` and the FULL
879/// Hessian `F'(t) = ∇²f_k(t)` at a single coordinate `t`, for a single target
880/// row `x` and fixed amplitude `z`. With `m(t) = z·BᵀΦ(t)`, `r = m − x`,
881/// `J_m = z·Bᵀ J_Φ`:
882///
883/// ```text
884/// g_t[a]   = J_m[a] · r                                  (= ∇f)
885/// H_tt[a,b] = J_m[a] · J_m[b] + r · ∂²m/∂t_a∂t_b         (= ∇²f, FULL Hessian)
886/// ```
887///
888/// The certificate uses the FULL Hessian rather than the Gauss-Newton block
889/// `J_mᵀ J_m`. This is the principled choice for Newton–Kantorovich: the
890/// theorem certifies convergence of Newton on `F = ∇f` to the unique nearby
891/// ROOT of `∇f`, but a root of `∇f` can be a maximum. The full Hessian is
892/// positive-definite exactly on the genuine-minimum basin, so requiring
893/// `λ_min(H) > 0` (finite `β`) is what flags a start that would otherwise let
894/// Gauss-Newton march into the wrong root (e.g. the circle antipode, a local
895/// max where `∇f = 0` but the full curvature is negative). The residual term
896/// needs the basis second jet `∂²Φ/∂t²`; an evaluator without one returns
897/// `None`, and the row is flagged (no silent Gauss-Newton fallback).
898pub(crate) fn encode_grad_hess(
899    atom: &SaeManifoldAtom,
900    evaluator: &dyn SaeBasisEvaluator,
901    t: ArrayView1<'_, f64>,
902    x: ArrayView1<'_, f64>,
903    amplitude: f64,
904    ridge: f64,
905) -> Result<Option<(Array1<f64>, Array2<f64>)>, String> {
906    let d = atom.latent_dim;
907    let p = atom.output_dim();
908    let m = atom.basis_size();
909    let coords = t.to_shape((1, d)).map_err(|e| e.to_string())?.to_owned();
910    let (phi, jet) = evaluator.evaluate(coords.view())?;
911    if phi.dim() != (1, m) {
912        return Err(format!(
913            "encode_grad_hess: evaluator returned phi {:?}, expected (1, {m})",
914            phi.dim()
915        ));
916    }
917    let decoder = &atom.decoder_coefficients;
918    // Reconstruction m(t) = z · Bᵀ Φ(t)  ∈ ℝᵖ.
919    let mut recon = Array1::<f64>::zeros(p);
920    for basis_col in 0..m {
921        let phi_v = phi[[0, basis_col]];
922        if phi_v == 0.0 {
923            continue;
924        }
925        for out in 0..p {
926            recon[out] += amplitude * phi_v * decoder[[basis_col, out]];
927        }
928    }
929    let residual = &recon - &x;
930    // J_m[axis] = z · Bᵀ (∂Φ/∂t_axis)  ∈ ℝᵖ.
931    let mut jm = Array2::<f64>::zeros((d, p));
932    for axis in 0..d {
933        for basis_col in 0..m {
934            let dphi = jet[[0, basis_col, axis]];
935            if dphi == 0.0 {
936                continue;
937            }
938            for out in 0..p {
939                jm[[axis, out]] += amplitude * dphi * decoder[[basis_col, out]];
940            }
941        }
942    }
943    // The full-Hessian residual term needs ∂²Φ/∂t². No second jet ⇒ no
944    // certificate (flag), never a silent Gauss-Newton substitute.
945    let second = match evaluator.second_jet_dyn(coords.view()) {
946        Some(result) => result?,
947        None => return Ok(None),
948    };
949    // g_t[axis] = J_m[axis] · r ;  H_tt[a,b] = J_m[a]·J_m[b] + r·∂²m/∂t_a∂t_b.
950    let mut g = Array1::<f64>::zeros(d);
951    let mut h = Array2::<f64>::zeros((d, d));
952    for a in 0..d {
953        let ja = jm.row(a);
954        g[a] = ja.dot(&residual);
955        for b in 0..d {
956            // Gauss-Newton block.
957            let mut hab = ja.dot(&jm.row(b));
958            // Residual · second-jet curvature: r · ∂²m_{ab},
959            // ∂²m_{ab}[out] = z · Σ_basis (∂²Φ/∂t_a∂t_b) · B[basis, out].
960            let mut curv = 0.0;
961            for basis_col in 0..m {
962                let d2phi = second[[0, basis_col, a, b]];
963                if d2phi == 0.0 {
964                    continue;
965                }
966                let mut dot = 0.0;
967                for out in 0..p {
968                    dot += residual[out] * decoder[[basis_col, out]];
969                }
970                curv += amplitude * d2phi * dot;
971            }
972            hab += curv;
973            h[[a, b]] = hab;
974        }
975    }
976    for a in 0..d {
977        h[[a, a]] += ridge;
978    }
979    Ok(Some((g, h)))
980}
981
982/// Operator-norm of `H⁻¹` (i.e. `β = 1/λ_min(H)`) and the Newton step
983/// `δ = −H⁻¹ g` with `η = ‖δ‖`, from a symmetric PSD `H` and gradient `g`.
984/// Returns `None` when `H` is numerically singular (λ_min ≤ 0) — an
985/// uncertifiable start.
986pub(crate) fn beta_eta_newton(
987    h: ArrayView2<'_, f64>,
988    g: ArrayView1<'_, f64>,
989) -> Result<Option<(f64, f64, Array1<f64>)>, String> {
990    let (vals, vecs) = h
991        .eigh(Side::Lower)
992        .map_err(|e| format!("beta_eta_newton: eigh failed: {e:?}"))?;
993    let lambda_min = vals.iter().cloned().fold(f64::INFINITY, f64::min);
994    if !(lambda_min.is_finite() && lambda_min > 0.0) {
995        return Ok(None);
996    }
997    let beta = 1.0 / lambda_min;
998    // Newton step δ = −H⁻¹ g via the eigendecomposition: δ = −Σ_i (vᵢᵀg/λᵢ) vᵢ.
999    let d = h.nrows();
1000    let mut delta = Array1::<f64>::zeros(d);
1001    for (col, &lam) in vals.iter().enumerate() {
1002        if lam <= 0.0 {
1003            return Ok(None);
1004        }
1005        let vi = vecs.column(col);
1006        let coeff = vi.dot(&g) / lam;
1007        for row in 0..d {
1008            delta[row] -= coeff * vi[row];
1009        }
1010    }
1011    let eta = delta.dot(&delta).sqrt();
1012    Ok(Some((beta, eta, delta)))
1013}
1014
1015/// Compute the per-row Kantorovich certificate for encoding target row `x`
1016/// against atom `atom` at start coordinate `t₀`, with fixed amplitude `z` and
1017/// the chart's closed-form Lipschitz constant `lipschitz`. Returns the
1018/// certificate AND the Newton step `δ = −H⁻¹ g` so the caller can advance.
1019pub fn row_certificate(
1020    atom: &SaeManifoldAtom,
1021    evaluator: &dyn SaeBasisEvaluator,
1022    t0: ArrayView1<'_, f64>,
1023    x: ArrayView1<'_, f64>,
1024    amplitude: f64,
1025    lipschitz: f64,
1026    ridge: f64,
1027) -> Result<(RowCertificate, Array1<f64>), String> {
1028    let uncertified = || {
1029        (
1030            RowCertificate {
1031                beta: f64::INFINITY,
1032                eta: f64::INFINITY,
1033                lipschitz,
1034                h: f64::INFINITY,
1035            },
1036            Array1::<f64>::zeros(atom.latent_dim),
1037        )
1038    };
1039    // No second jet ⇒ no full Hessian ⇒ uncertifiable (flag).
1040    let Some((g, h)) = encode_grad_hess(atom, evaluator, t0, x, amplitude, ridge)? else {
1041        return Ok(uncertified());
1042    };
1043    match beta_eta_newton(h.view(), g.view())? {
1044        Some((beta, eta, delta)) => {
1045            let cert = RowCertificate {
1046                beta,
1047                eta,
1048                lipschitz,
1049                h: beta * eta * lipschitz,
1050            };
1051            Ok((cert, delta))
1052        }
1053        // Indefinite / negative-curvature full Hessian: the start is at or past
1054        // a basin boundary (a max/saddle of f), not the minimum basin — flag.
1055        None => Ok(uncertified()),
1056    }
1057}
1058
1059fn uncertified_certificate(lipschitz: f64) -> RowCertificate {
1060    RowCertificate {
1061        beta: f64::INFINITY,
1062        eta: f64::INFINITY,
1063        lipschitz,
1064        h: f64::INFINITY,
1065    }
1066}
1067
1068fn refine_certified_start(
1069    atom: &SaeManifoldAtom,
1070    evaluator: &dyn SaeBasisEvaluator,
1071    mut t: Array1<f64>,
1072    x: ArrayView1<'_, f64>,
1073    amplitude: f64,
1074    lipschitz: f64,
1075    ridge: f64,
1076    newton_steps: usize,
1077    initial_cert: RowCertificate,
1078    mut delta: Array1<f64>,
1079) -> Result<Option<CertifiedEncodeProbe>, String> {
1080    assert!(initial_cert.certified());
1081    let mut final_cert = initial_cert;
1082    for _ in 0..newton_steps {
1083        t = &t + &delta;
1084        let (cert, next_delta) =
1085            row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz, ridge)?;
1086        if !cert.certified() {
1087            return Ok(None);
1088        }
1089        final_cert = cert;
1090        delta = next_delta;
1091    }
1092    Ok(Some(CertifiedEncodeProbe {
1093        coord: t,
1094        initial_cert,
1095        final_cert,
1096    }))
1097}
1098
/// Certify an encode probe from `t_start`, navigating into the Kantorovich basin
/// first if needed (#1154/#1026). The Kantorovich quantity `h = β·η·L` scales with
/// amplitude through `L`, so at unit amplitude a positive-definite chart-center /
/// distilled start can sit OUTSIDE the certified ball (`h > ½`). Rather than
/// flagging it uncertified immediately — which made the encoder certify ZERO
/// held-out rows at amplitude 1.0 and fall back to the exact solve for everything —
/// take plain Newton steps toward the root, re-certifying at each iterate, while
/// the Kantorovich quantity `h = β·η·L` keeps CONTRACTING toward the ½ bound. The
/// certificate at the landing point is a full Kantorovich guarantee from there
/// (`h ≤ ½` ⇒ Newton converges to the in-ball root), so this only ever WIDENS the
/// certified set; it never certifies a non-convergent start.
///
/// Termination is the natural Newton stopping rule — there is no arbitrary step
/// budget. The warm-up stops and flags for the exact fallback in two cases:
/// * The start is not steppable (indefinite / non-finite Hessian — at or past a
///   basin boundary).
/// * A step fails to reduce `h`. The iterate is then not approaching a
///   certifiable in-chart root: either its root lies outside this chart's valid
///   Lipschitz region, or the start was past the basin. Empirically the rows
///   that miss *plateau*, so more steps cannot help; the lever there is denser
///   charts, not more iterations.
///
/// On success the start is refined `newton_steps` further by
/// [`refine_certified_start`].
///
/// Parameters: `chart_center` and `chart_radius` define the ball
/// `‖t − chart_center‖ ≤ chart_radius` on which `lipschitz` is trusted. The start
/// and every warm-up iterate must lie inside it.
///
/// Returns `Ok(None)` when the row is flagged for the exact fallback:
/// * the start or a proposed iterate is out of chart;
/// * a certificate is non-finite;
/// * a warm-up step does not strictly lower `h`;
/// * a refinement step loses the certificate.
///
/// `Err` only propagates errors from [`row_certificate`] /
/// [`refine_certified_start`].
fn certify_with_basin_warmup(
    atom: &SaeManifoldAtom,
    evaluator: &dyn SaeBasisEvaluator,
    t_start: Array1<f64>,
    x: ArrayView1<'_, f64>,
    amplitude: f64,
    lipschitz: f64,
    ridge: f64,
    newton_steps: usize,
    chart_center: ArrayView1<'_, f64>,
    chart_radius: f64,
) -> Result<Option<CertifiedEncodeProbe>, String> {
    // SOUNDNESS GUARD: `lipschitz` is the chart's Hessian-Lipschitz sup, which is
    // only a valid bound over this chart's ball `‖t − center‖ ≤ radius` for the
    // chart-local families (`EuclideanPatch`/`Linear`/`Poincare` monomial patches,
    // `Cylinder` line axis, `Duchon` radial kernels). If a warm-up iterate leaves
    // that ball, `row_certificate` would compute `h = β·η·L` with an `L` that no
    // longer bounds the true geometry there, so `h ≤ ½` would NOT imply Kantorovich
    // convergence — a false certificate. (The `h`-contraction check does NOT catch
    // this: `h` can decrease monotonically toward an out-of-chart root the whole
    // way.) So we keep every certified iterate inside the chart; a row whose root is
    // outside this chart flags for the exact fallback — its lever is a denser grid,
    // not a step using an invalid `L`. Global-`L` families (periodic/torus/sphere)
    // route their points to charts whose centers are near the root, so the guard
    // rarely trips for them, and where it does the row was out-of-chart anyway.
    //
    // NOTE(review): this is a plain (unwrapped) Euclidean distance. Periodic axes
    // are not measured modulo their period here, so presumably an iterate that
    // crosses a period seam reads as out-of-chart and is flagged — confirm
    // that is intended.
    let in_chart = |t: &Array1<f64>| -> bool {
        let r2: f64 = t
            .iter()
            .zip(chart_center.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        r2 <= chart_radius * chart_radius
    };
    let mut t = t_start;
    // The distilled / chart-center start must itself be in-chart for its certificate
    // to be valid; a bad IFT prediction landing outside the chart is uncertifiable.
    if !in_chart(&t) {
        return Ok(None);
    }
    // Certificate and Newton step at the start. If it already certifies, the
    // warm-up loop below does nothing and we go straight to refinement.
    let (mut cert, mut delta) =
        row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz, ridge)?;
    while !cert.certified() {
        // Not steppable (indefinite / non-finite Hessian): flag.
        if !(cert.h.is_finite() && cert.beta.is_finite() && cert.eta.is_finite()) {
            return Ok(None);
        }
        let prev_h = cert.h;
        // Proposed plain Newton iterate, checked before it is accepted.
        let next = &t + &delta;
        // Refuse to step where the chart's `L` is no longer valid (see guard above).
        if !in_chart(&next) {
            return Ok(None);
        }
        t = next;
        let (next_cert, next_delta) =
            row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz, ridge)?;
        cert = next_cert;
        delta = next_delta;
        // The warm-up only helps while h keeps contracting toward ½. Once a step
        // fails to reduce it, the iterate is not converging to a certifiable in-chart
        // root — flag for the exact fallback (no arbitrary step budget).
        if !cert.h.is_finite() || cert.h >= prev_h {
            return Ok(None);
        }
    }
    // `cert` / `delta` now describe a certified, in-chart start; refine from there.
    refine_certified_start(
        atom,
        evaluator,
        t,
        x,
        amplitude,
        lipschitz,
        ridge,
        newton_steps,
        cert,
        delta,
    )
}
1196
1197fn kantorovich_root_radius(cert: RowCertificate) -> f64 {
1198    if !cert.certified() || !(cert.eta.is_finite() && cert.eta >= 0.0) {
1199        return f64::INFINITY;
1200    }
1201    if cert.eta == 0.0 {
1202        return 0.0;
1203    }
1204    if !(cert.h.is_finite() && cert.h >= 0.0) {
1205        return f64::INFINITY;
1206    }
1207    let h = cert.h.min(KANTOROVICH_THRESHOLD);
1208    let discriminant = (1.0 - 2.0 * h).max(0.0).sqrt();
1209    let radius = 2.0 * cert.eta / (1.0 + discriminant);
1210    if radius.is_finite() {
1211        radius
1212    } else {
1213        f64::INFINITY
1214    }
1215}
1216
1217fn distilled_probe_tolerance(
1218    amortized: &CertifiedEncodeProbe,
1219    cold: &CertifiedEncodeProbe,
1220    amplitude: f64,
1221    x: ArrayView1<'_, f64>,
1222) -> f64 {
1223    let certified_radius =
1224        kantorovich_root_radius(amortized.final_cert) + kantorovich_root_radius(cold.final_cert);
1225    let coord_scale = amortized.coord.dot(&amortized.coord).sqrt()
1226        + cold.coord.dot(&cold.coord).sqrt()
1227        + x.dot(&x).sqrt()
1228        + amplitude.abs()
1229        + 1.0;
1230    certified_radius + 1024.0 * f64::EPSILON * coord_scale
1231}
1232
1233fn latent_coordinate_distance(
1234    atom: &SaeManifoldAtom,
1235    lhs: ArrayView1<'_, f64>,
1236    rhs: ArrayView1<'_, f64>,
1237) -> f64 {
1238    let mut acc = 0.0;
1239    for axis in 0..lhs.len().min(rhs.len()) {
1240        let mut diff = (lhs[axis] - rhs[axis]).abs();
1241        if let Some(period) = latent_axis_period(atom, axis) {
1242            let wrapped = diff.rem_euclid(period);
1243            diff = wrapped.min(period - wrapped);
1244        }
1245        acc += diff * diff;
1246    }
1247    acc.sqrt()
1248}
1249
1250fn latent_axis_period(atom: &SaeManifoldAtom, axis: usize) -> Option<f64> {
1251    use crate::manifold::SaeAtomBasisKind::*;
1252    match &atom.basis_kind {
1253        Periodic | Torus => Some(1.0),
1254        Cylinder if axis == 0 => Some(1.0),
1255        Sphere if axis == 1 => Some(std::f64::consts::TAU),
1256        _ => None,
1257    }
1258}
1259
/// Configuration for [`EncodeAtlas`] construction and online encode. All fields
/// are explicit; the atlas never reads global state and adds no CLI flags.
#[derive(Debug, Clone, Copy)]
pub struct AtlasConfig {
    /// Grid resolution per latent axis for offline chart centers (the
    /// SHAPE_BAND grid idiom). The regular-grid build places
    /// `grid_resolution^latent_dim` centers.
    pub grid_resolution: usize,
    /// Ridge added to the diagonal of the encode Hessian.
    ///
    /// The per-row certificate adds it to the FULL Hessian `∇²f`: the
    /// Gauss-Newton block `J_mᵀJ_m` plus the residual-curvature term. The
    /// offline build also passes it to the chart-center curvature and
    /// amortized-Jacobian computations.
    pub ridge: f64,
    /// Number of online Newton refinement steps after a certified start (1 or 2
    /// per issue #1010). Each refinement iterate must re-certify, or the row is
    /// flagged.
    pub newton_steps: usize,
}

impl Default for AtlasConfig {
    /// A 16-point grid per axis, a `1e-9` ridge, and two refinement steps.
    fn default() -> Self {
        Self {
            grid_resolution: 16,
            ridge: 1.0e-9,
            newton_steps: 2,
        }
    }
}
1283
/// The encode atlas: per-atom certified charts plus the online certified-encode
/// driver (issue #1010).
#[derive(Debug, Clone)]
pub struct EncodeAtlas {
    /// One atlas per dictionary atom. `build` stores them in the same order as
    /// the atoms it was given.
    pub atoms: Vec<AtomEncodeAtlas>,
    /// Configuration the atlas was built with (grid resolution, ridge, Newton
    /// refinement steps).
    pub config: AtlasConfig,
}
1291
1292impl EncodeAtlas {
1293    /// Build the offline atlas over a frozen dictionary: for each atom, lay down
1294    /// chart centers on the atom's coordinate grid and certify a Newton radius
1295    /// from the Kantorovich inequality at the worst-case in-chart start.
1296    ///
1297    /// `amplitude_bound[k]` is the per-atom bound on `|z_k|` used to scale the
1298    /// reconstruction jets (the offline `L` must hold for the largest amplitude
1299    /// the encode can produce); `target_norm_bound` bounds `‖x‖` over the data.
1300    pub fn build(
1301        atoms: &[SaeManifoldAtom],
1302        amplitude_bound: &[f64],
1303        target_norm_bound: f64,
1304        config: AtlasConfig,
1305    ) -> Result<Self, String> {
1306        if amplitude_bound.len() != atoms.len() {
1307            return Err(format!(
1308                "EncodeAtlas::build: amplitude_bound length {} != atom count {}",
1309                amplitude_bound.len(),
1310                atoms.len()
1311            ));
1312        }
1313        let mut atom_atlases = Vec::with_capacity(atoms.len());
1314        for (k, atom) in atoms.iter().enumerate() {
1315            let atlas =
1316                Self::build_atom_atlas(k, atom, amplitude_bound[k], target_norm_bound, &config)?;
1317            atom_atlases.push(atlas);
1318        }
1319        Ok(Self {
1320            atoms: atom_atlases,
1321            config,
1322        })
1323    }
1324
1325    pub(crate) fn build_atom_atlas(
1326        atom_index: usize,
1327        atom: &SaeManifoldAtom,
1328        amplitude_bound: f64,
1329        target_norm_bound: f64,
1330        config: &AtlasConfig,
1331    ) -> Result<AtomEncodeAtlas, String> {
1332        let centers = chart_center_grid(atom, config.grid_resolution);
1333        // Half the inter-center spacing is the natural in-chart radius so the
1334        // charts tile the grid without gaps; refined below if the certificate
1335        // fails at that radius. One uniform radius for the regular grid.
1336        let nominal_radius = chart_nominal_radius(atom, config.grid_resolution);
1337        let radii = vec![nominal_radius; centers.nrows()];
1338        Self::build_atom_atlas_from_centers(
1339            atom_index,
1340            atom,
1341            centers.view(),
1342            &radii,
1343            amplitude_bound,
1344            target_norm_bound,
1345            config,
1346        )
1347    }
1348
1349    /// Build a per-atom atlas from EXPLICIT chart centers with a per-center
1350    /// nominal radius — the geometry-agnostic core shared by the regular-grid
1351    /// [`Self::build_atom_atlas`] and the data-driven [`Self::build_data_driven`].
1352    /// Every chart is certified identically (Kantorovich radius from the in-chart
1353    /// curvature at its center); only the center PLACEMENT and per-center radius
1354    /// differ. `radii[c]` is the nominal in-chart radius for `centers[c]`.
1355    pub(crate) fn build_atom_atlas_from_centers(
1356        atom_index: usize,
1357        atom: &SaeManifoldAtom,
1358        centers: ArrayView2<'_, f64>,
1359        radii: &[f64],
1360        amplitude_bound: f64,
1361        target_norm_bound: f64,
1362        config: &AtlasConfig,
1363    ) -> Result<AtomEncodeAtlas, String> {
1364        let d = atom.latent_dim;
1365        if centers.ncols() != d {
1366            return Err(format!(
1367                "build_atom_atlas_from_centers: centers have {} cols but atom latent_dim is {d}",
1368                centers.ncols()
1369            ));
1370        }
1371        if radii.len() != centers.nrows() {
1372            return Err(format!(
1373                "build_atom_atlas_from_centers: {} radii != {} centers",
1374                radii.len(),
1375                centers.nrows()
1376            ));
1377        }
1378        let decoder_norm_sum = decoder_row_norm_sum(atom.decoder_coefficients.view());
1379        let mut charts = Vec::with_capacity(centers.nrows());
1380        for c in 0..centers.nrows() {
1381            let center = centers.row(c).to_owned();
1382            let nominal_radius = radii[c];
1383            let region = chart_region(atom, center.clone(), nominal_radius);
1384            let sups = family_jet_sups(atom, &region)?;
1385            let recon_sups = reconstruction_jet_sups(atom, sups);
1386            let lipschitz =
1387                hessian_lipschitz_constant(recon_sups, amplitude_bound, target_norm_bound, 0.0);
1388            // β at the chart center bounds the worst-case in-chart curvature
1389            // (the Gauss-Newton Hessian is continuous; the certified radius is
1390            // solved so the certificate is robust to the start within the ball).
1391            let beta_center = match center_beta(atom, &center, config.ridge) {
1392                Some(b) => b,
1393                None => {
1394                    // Degenerate center curvature: no certifiable chart here, and
1395                    // no amortized Jacobian (the same singular Gauss–Newton block).
1396                    charts.push(CertifiedChart {
1397                        region,
1398                        lipschitz,
1399                        beta_center: f64::INFINITY,
1400                        certified_radius: 0.0,
1401                        amortized_jacobian: None,
1402                        recon_center: Array1::<f64>::zeros(atom.output_dim()),
1403                    });
1404                    continue;
1405                }
1406            };
1407            // Distill the amortized-encoder Jacobian at this center (#1026 ladder
1408            // item 3): the IFT derivative of the encode map, precomputed offline
1409            // so the online encode is one mat-vec. A finite `beta_center` (above)
1410            // means the Gauss–Newton block is non-singular, so this succeeds
1411            // alongside it; the pair travels together on the chart.
1412            let (amortized_jacobian, recon_center) =
1413                match center_amortized_jacobian(atom, &center, config.ridge) {
1414                    Some((a1, m1)) => (Some(a1), m1),
1415                    None => (None, Array1::<f64>::zeros(atom.output_dim())),
1416                };
1417            // Certified radius from h = β·η·L ≤ ½ with η ≤ R (Newton step length
1418            // is bounded by the start distance to the root, itself ≤ chart
1419            // radius at worst): R_c = ½ / (β·L), capped at the nominal radius.
1420            let certified_radius = if lipschitz > 0.0 && beta_center.is_finite() {
1421                (0.5 / (beta_center * lipschitz)).min(region.radius)
1422            } else {
1423                region.radius
1424            };
1425            charts.push(CertifiedChart {
1426                region,
1427                lipschitz,
1428                beta_center,
1429                certified_radius,
1430                amortized_jacobian,
1431                recon_center,
1432            });
1433        }
1434        Ok(AtomEncodeAtlas {
1435            atom_index,
1436            latent_dim: d,
1437            decoder_norm_sum,
1438            charts,
1439        })
1440    }
1441
1442    /// Build the atlas with DATA-DRIVEN chart placement: instead of a dense
1443    /// `resolution^d` product grid (exponential in latent dim `d`, so the regular
1444    /// [`Self::build`] is forced to coarse, poorly-certified charts for `d ≥ 3`),
1445    /// place a bounded number of charts AT the data's own latent coordinates. The
1446    /// chart count is then `O(max_charts)` regardless of `d`, and every chart sits
1447    /// where data actually lands (small in-chart residual → certifies), so
1448    /// higher-dimensional atoms — which reconstruct real activations far better per
1449    /// parameter — become affordable and well-covered.
1450    ///
1451    /// `coords[k]` is atom `k`'s `n × d_k` latent coordinates (the seed coords, or
1452    /// a previous encode's output). Charts are chosen by greedy farthest-point
1453    /// sampling over those coords (deterministic, coverage-maximizing), capped at
1454    /// `max_charts`. Each chart's nominal radius is half the distance to its
1455    /// nearest neighbor center, so the charts tile the local data density. The
1456    /// per-chart Kantorovich certification is IDENTICAL to the regular grid — only
1457    /// the center placement differs.
1458    pub fn build_data_driven(
1459        atoms: &[SaeManifoldAtom],
1460        coords: &[Array2<f64>],
1461        amplitude_bound: &[f64],
1462        target_norm_bound: f64,
1463        max_charts: usize,
1464        config: AtlasConfig,
1465    ) -> Result<Self, String> {
1466        if amplitude_bound.len() != atoms.len() || coords.len() != atoms.len() {
1467            return Err(format!(
1468                "build_data_driven: amplitude_bound {} / coords {} must match atom count {}",
1469                amplitude_bound.len(),
1470                coords.len(),
1471                atoms.len()
1472            ));
1473        }
1474        let mut atom_atlases = Vec::with_capacity(atoms.len());
1475        for (k, atom) in atoms.iter().enumerate() {
1476            let (centers, radii) =
1477                data_driven_chart_centers(atom, coords[k].view(), max_charts.max(1))?;
1478            let atlas = Self::build_atom_atlas_from_centers(
1479                k,
1480                atom,
1481                centers.view(),
1482                &radii,
1483                amplitude_bound[k],
1484                target_norm_bound,
1485                &config,
1486            )?;
1487            atom_atlases.push(atlas);
1488        }
1489        Ok(Self {
1490            atoms: atom_atlases,
1491            config,
1492        })
1493    }
1494
1495    fn refine_certified_encode_start(
1496        &self,
1497        atom: &SaeManifoldAtom,
1498        evaluator: &dyn SaeBasisEvaluator,
1499        chart: &CertifiedChart,
1500        t: Array1<f64>,
1501        x: ArrayView1<'_, f64>,
1502        amplitude: f64,
1503    ) -> Result<(Array1<f64>, RowCertificate), String> {
1504        // Certify from the warm start, navigating into the Kantorovich basin first
1505        // if the unit-amplitude start has h > ½ (see `certify_with_basin_warmup`).
1506        let Some(probe) = certify_with_basin_warmup(
1507            atom,
1508            evaluator,
1509            t,
1510            x,
1511            amplitude,
1512            chart.lipschitz,
1513            self.config.ridge,
1514            self.config.newton_steps,
1515            chart.region.center.view(),
1516            chart.region.radius,
1517        )?
1518        else {
1519            return Ok((
1520                Array1::<f64>::zeros(atom.latent_dim),
1521                uncertified_certificate(chart.lipschitz),
1522            ));
1523        };
1524        Ok((probe.coord, probe.initial_cert))
1525    }
1526
1527    /// Online certified encode of one target row `x` against one atom `k` with
1528    /// fixed amplitude `z`. Routes to the nearest chart, starts from that chart's
1529    /// distilled IFT warm start, runs `config.newton_steps` Newton steps, and
1530    /// returns the encoded coordinate with its certificate. An uncertified start
1531    /// (no chart, no distilled Jacobian, non-positive amplitude, or `h > ½`)
1532    /// flags the row for the exact multi-start caller.
1533    pub fn certified_encode_row(
1534        &self,
1535        atom: &SaeManifoldAtom,
1536        atom_index: usize,
1537        x: ArrayView1<'_, f64>,
1538        amplitude: f64,
1539    ) -> Result<(Array1<f64>, RowCertificate), String> {
1540        let atom_atlas = self
1541            .atoms
1542            .get(atom_index)
1543            .ok_or_else(|| format!("certified_encode_row: atom {atom_index} not in atlas"))?;
1544        let d = atom.latent_dim;
1545        // A missing basis evaluator means the amortized/cold predictor cannot fire
1546        // for this atom (e.g. a frozen-baseline or first-build atom that never
1547        // attached a distilled evaluator). That is exactly the "cannot certify"
1548        // state — flag the row uncertified (zeros coords, ∞ certificate) so the
1549        // upstream exact multi-start solve owns it, never a hard error that aborts
1550        // the whole criterion. Mirrors the no-chart / singular-Jacobian branches.
1551        let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
1552            return Ok((
1553                Array1::<f64>::zeros(d),
1554                RowCertificate {
1555                    beta: f64::INFINITY,
1556                    eta: f64::INFINITY,
1557                    lipschitz: f64::INFINITY,
1558                    h: f64::INFINITY,
1559                },
1560            ));
1561        };
1562
1563        // Route to the nearest chart centers by AMBIENT reconstruction distance.
1564        // A single nearest chart is NOT globally sound on self-approaching atoms:
1565        // where the decoded manifold folds near itself (two distant latent points
1566        // map near the same output), the nearest-center chart can certify into the
1567        // locally-worse basin while another chart holds the GLOBAL minimum (both
1568        // branches' charts reconstruct near the crossing, so both are near in
1569        // ambient distance). The certificate is honest about LOCAL convergence but
1570        // cannot see the better far basin. So we refine in the top-K nearest charts
1571        // and keep the lowest-reconstruction-error CERTIFIED result. For a unimodal
1572        // atom every candidate chart converges to the same root, so this is a no-op
1573        // (first-wins tie → the nearest chart), preserving the existing behavior.
1574        let candidates =
1575            nearest_charts_topk(atom_atlas, x, atom, evaluator.as_ref(), CERTIFIED_ROUTING_TOPK);
1576        if candidates.is_empty() {
1577            return Ok((
1578                Array1::<f64>::zeros(d),
1579                RowCertificate {
1580                    beta: f64::INFINITY,
1581                    eta: f64::INFINITY,
1582                    lipschitz: f64::INFINITY,
1583                    h: f64::INFINITY,
1584                },
1585            ));
1586        }
1587        // Best CERTIFIED result by reconstruction error, plus the nearest chart's
1588        // result as the uncertified fallback (preserving the prior return when no
1589        // candidate certifies — the nearest chart owns the flagged row).
1590        let mut best: Option<(Array1<f64>, RowCertificate, f64)> = None;
1591        let mut nearest_fallback: Option<(Array1<f64>, RowCertificate)> = None;
1592        for chart_idx in candidates {
1593            let chart = &atom_atlas.charts[chart_idx];
1594            let Some(t) = amortized_warm_start(chart, x, amplitude) else {
1595                if nearest_fallback.is_none() {
1596                    nearest_fallback =
1597                        Some((Array1::<f64>::zeros(d), uncertified_certificate(chart.lipschitz)));
1598                }
1599                continue;
1600            };
1601            let (coord, cert) = self.refine_certified_encode_start(
1602                atom,
1603                evaluator.as_ref(),
1604                chart,
1605                t,
1606                x,
1607                amplitude,
1608            )?;
1609            if nearest_fallback.is_none() {
1610                nearest_fallback = Some((coord.clone(), cert.clone()));
1611            }
1612            if cert.certified() {
1613                let err =
1614                    encode_reconstruction_error(atom, evaluator.as_ref(), coord.view(), x, amplitude);
1615                if best.as_ref().map(|(_, _, e)| err < *e).unwrap_or(true) {
1616                    best = Some((coord, cert, err));
1617                }
1618            }
1619        }
1620        match best {
1621            Some((coord, cert, _)) => Ok((coord, cert)),
1622            None => Ok(nearest_fallback.unwrap_or_else(|| {
1623                (
1624                    Array1::<f64>::zeros(d),
1625                    RowCertificate {
1626                        beta: f64::INFINITY,
1627                        eta: f64::INFINITY,
1628                        lipschitz: f64::INFINITY,
1629                        h: f64::INFINITY,
1630                    },
1631                )
1632            })),
1633        }
1634    }
1635
1636    /// Amortized (distilled) encode of one target row `x` against one atom `k`
1637    /// with fixed amplitude `z` (#1026 ladder item 3).
1638    ///
1639    /// Routes to the nearest chart, then predicts the latent coordinate in CLOSED
1640    /// FORM from that chart's precomputed implicit-function-theorem Jacobian:
1641    ///
1642    /// ```text
1643    /// t̂ = t_c + (1/z) · A₁ · (x − z · m₁(t_c)),
1644    /// ```
1645    ///
1646    /// a single `O(d·p)` mat-vec — no per-row Hessian factorization or
1647    /// eigendecomposition, which is the amortization. The Kantorovich
1648    /// certificate is then evaluated AT the predicted start `t̂` with the chart's
1649    /// closed-form Lipschitz constant. A prediction is accepted only when that
1650    /// certificate holds, an independent cold chart-center probe also certifies,
1651    /// and the two refined coordinates agree within the two probes' final
1652    /// Kantorovich root-radius bounds. This keeps the distilled path honest
1653    /// without letting the exact probe reuse the distilled warm start it is
1654    /// auditing. A chart without a distilled Jacobian (singular Gauss–Newton
1655    /// block) flags the row.
1656    pub fn amortized_encode_row(
1657        &self,
1658        atom: &SaeManifoldAtom,
1659        atom_index: usize,
1660        x: ArrayView1<'_, f64>,
1661        amplitude: f64,
1662    ) -> Result<(Array1<f64>, RowCertificate), String> {
1663        let atom_atlas = self
1664            .atoms
1665            .get(atom_index)
1666            .ok_or_else(|| format!("amortized_encode_row: atom {atom_index} not in atlas"))?;
1667        let d = atom.latent_dim;
1668        let uncertified = || {
1669            (
1670                Array1::<f64>::zeros(d),
1671                RowCertificate {
1672                    beta: f64::INFINITY,
1673                    eta: f64::INFINITY,
1674                    lipschitz: f64::INFINITY,
1675                    h: f64::INFINITY,
1676                },
1677            )
1678        };
1679        // A missing basis evaluator means the distilled predictor cannot fire for
1680        // this atom — flag the row uncertified (the exact upstream solve owns it)
1681        // rather than erroring, exactly as the no-chart / singular-Jacobian /
1682        // non-positive-amplitude branches below do. Never a silent wrong encode,
1683        // never a hard abort of the criterion.
1684        let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
1685            return Ok(uncertified());
1686        };
1687        let Some((chart_idx, _)) = nearest_chart(atom_atlas, x, atom, evaluator.as_ref()) else {
1688            return Ok(uncertified());
1689        };
1690        let chart = &atom_atlas.charts[chart_idx];
1691        // Closed-form predicted start t̂ = t_c + (1/z)·A₁·(x − z·m₁). `None` when
1692        // the chart's Gauss–Newton block was singular (no distilled Jacobian, so
1693        // the amortized predictor cannot fire) or the amplitude is not strictly
1694        // positive and finite (a near-inactive atom, where the amplitude-divided
1695        // map is undefined) — either way flag for the exact fallback, never a
1696        // silent wrong encode.
1697        let Some(t_hat) = amortized_warm_start(chart, x, amplitude) else {
1698            return Ok(uncertified());
1699        };
1700        // Evaluate the SAME Kantorovich certificate at the predicted start. The
1701        // amortized prediction is trusted only if this certificate holds AND an
1702        // independent cold chart-center probe certifies and agrees below the
1703        // two probes' final Kantorovich root-radius bounds. This avoids the
1704        // self-referential gate where the "exact" probe is warm-started by the
1705        // same distilled prediction it is supposed to audit.
1706        let Some(amortized_probe) = certify_with_basin_warmup(
1707            atom,
1708            evaluator.as_ref(),
1709            t_hat,
1710            x,
1711            amplitude,
1712            chart.lipschitz,
1713            self.config.ridge,
1714            self.config.newton_steps,
1715            chart.region.center.view(),
1716            chart.region.radius,
1717        )?
1718        else {
1719            return Ok((
1720                Array1::<f64>::zeros(d),
1721                uncertified_certificate(chart.lipschitz),
1722            ));
1723        };
1724
1725        let cold_start = chart.region.center.clone();
1726        let Some(cold_probe) = certify_with_basin_warmup(
1727            atom,
1728            evaluator.as_ref(),
1729            cold_start,
1730            x,
1731            amplitude,
1732            chart.lipschitz,
1733            self.config.ridge,
1734            self.config.newton_steps,
1735            chart.region.center.view(),
1736            chart.region.radius,
1737        )?
1738        else {
1739            return Ok((
1740                amortized_probe.coord,
1741                uncertified_certificate(chart.lipschitz),
1742            ));
1743        };
1744
1745        let gap =
1746            latent_coordinate_distance(atom, amortized_probe.coord.view(), cold_probe.coord.view());
1747        let tolerance = distilled_probe_tolerance(&amortized_probe, &cold_probe, amplitude, x);
1748        if !(gap.is_finite() && gap <= tolerance) {
1749            return Ok((
1750                amortized_probe.coord,
1751                uncertified_certificate(chart.lipschitz),
1752            ));
1753        }
1754        Ok((amortized_probe.coord, amortized_probe.initial_cert))
1755    }
1756
1757    /// Batched amortized (distilled) encode over many rows against one atom
1758    /// (#1026 ladder item 3, corpus-rate). Each row uses the closed-form
1759    /// per-chart Jacobian predictor and carries its own Kantorovich certificate;
1760    /// uncertified rows are flagged in [`EncodeResult::encode_uncertified_count`]
1761    /// for the exact multi-start fallback. Row-independent against the frozen
1762    /// dictionary, so the batch fans out over rows (deterministic row-order
1763    /// assembly, bit-identical run-to-run), staying sequential inside a rayon
1764    /// worker to avoid nested oversubscription.
1765    pub fn amortized_encode_batch(
1766        &self,
1767        atom: &SaeManifoldAtom,
1768        atom_index: usize,
1769        targets: ArrayView2<'_, f64>,
1770        amplitudes: ArrayView1<'_, f64>,
1771    ) -> Result<EncodeResult, String> {
1772        let n = targets.nrows();
1773        if amplitudes.len() != n {
1774            return Err(format!(
1775                "amortized_encode_batch: amplitudes len {} != rows {n}",
1776                amplitudes.len()
1777            ));
1778        }
1779        let d = atom.latent_dim;
1780        let encode_rows =
1781            |range: std::ops::Range<usize>| -> Result<Vec<(Array1<f64>, bool)>, String> {
1782                range
1783                    .map(|row| {
1784                        let (t, cert) = self.amortized_encode_row(
1785                            atom,
1786                            atom_index,
1787                            targets.row(row),
1788                            amplitudes[row],
1789                        )?;
1790                        Ok((t, cert.certified()))
1791                    })
1792                    .collect()
1793            };
1794        let rows: Vec<(Array1<f64>, bool)> =
1795            if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
1796                use rayon::prelude::*;
1797                const CHUNK: usize = 256;
1798                let n_chunks = n.div_ceil(CHUNK);
1799                let chunked: Vec<Vec<(Array1<f64>, bool)>> = (0..n_chunks)
1800                    .into_par_iter()
1801                    .map(|c| {
1802                        let start = c * CHUNK;
1803                        let end = (start + CHUNK).min(n);
1804                        encode_rows(start..end)
1805                    })
1806                    .collect::<Result<_, _>>()?;
1807                chunked.into_iter().flatten().collect()
1808            } else {
1809                encode_rows(0..n)?
1810            };
1811        let mut coords = Array2::<f64>::zeros((n, d));
1812        let mut certified = Vec::with_capacity(n);
1813        for (row, (t, cert)) in rows.into_iter().enumerate() {
1814            coords.row_mut(row).assign(&t);
1815            certified.push(cert);
1816        }
1817        Ok(EncodeResult::from_rows(coords, certified))
1818    }
1819
1820    /// Batched certified encode over many rows against one atom (the #988
1821    /// throughput consumer). Each row carries its own certificate; uncertified
1822    /// rows are flagged in [`EncodeResult::encode_uncertified_count`] for the
1823    /// exact multi-start fallback.
1824    pub fn certified_encode_batch(
1825        &self,
1826        atom: &SaeManifoldAtom,
1827        atom_index: usize,
1828        targets: ArrayView2<'_, f64>,
1829        amplitudes: ArrayView1<'_, f64>,
1830    ) -> Result<EncodeResult, String> {
1831        let n = targets.nrows();
1832        if amplitudes.len() != n {
1833            return Err(format!(
1834                "certified_encode_batch: amplitudes len {} != rows {n}",
1835                amplitudes.len()
1836            ));
1837        }
1838        let d = atom.latent_dim;
1839        // Per-row encode is independent against a frozen dictionary (#1010), so
1840        // the corpus-rate batch fans out over rows (#1026 amortized-encoder leg /
1841        // #977 Stage-3 corpus encode). Each row produces an owned `(t, certified)`
1842        // pair; results are assembled back in row order so the output is
1843        // bit-identical run-to-run regardless of thread scheduling. Stay
1844        // sequential inside a rayon worker (e.g. when an outer atom-level fan-out
1845        // owns the pool) to avoid nested oversubscription. The first row that
1846        // fails to encode propagates its error deterministically.
1847        let encode_rows =
1848            |range: std::ops::Range<usize>| -> Result<Vec<(Array1<f64>, bool)>, String> {
1849                range
1850                    .map(|row| {
1851                        let (t, cert) = self.certified_encode_row(
1852                            atom,
1853                            atom_index,
1854                            targets.row(row),
1855                            amplitudes[row],
1856                        )?;
1857                        Ok((t, cert.certified()))
1858                    })
1859                    .collect()
1860            };
1861        let rows: Vec<(Array1<f64>, bool)> =
1862            if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
1863                use rayon::prelude::*;
1864                const CHUNK: usize = 256;
1865                let n_chunks = n.div_ceil(CHUNK);
1866                let chunked: Vec<Vec<(Array1<f64>, bool)>> = (0..n_chunks)
1867                    .into_par_iter()
1868                    .map(|c| {
1869                        let start = c * CHUNK;
1870                        let end = (start + CHUNK).min(n);
1871                        encode_rows(start..end)
1872                    })
1873                    .collect::<Result<_, _>>()?;
1874                chunked.into_iter().flatten().collect()
1875            } else {
1876                encode_rows(0..n)?
1877            };
1878        let mut coords = Array2::<f64>::zeros((n, d));
1879        let mut certified = Vec::with_capacity(n);
1880        for (row, (t, cert)) in rows.into_iter().enumerate() {
1881            coords.row_mut(row).assign(&t);
1882            certified.push(cert);
1883        }
1884        Ok(EncodeResult::from_rows(coords, certified))
1885    }
1886
1887    /// Batched GEMM "fast" amortized encode — the traditional-encoder forward
1888    /// pass, WITH manifolds. For every row this applies the SAME closed-form
1889    /// affine predictor as [`amortized_warm_start`]
1890    /// (`t̂ = t_c + (1/z)·A₁·(x − z·m₁)`), but routed and applied as batched
1891    /// matrix products instead of a per-row loop wrapped in the Kantorovich
1892    /// certificate + basin warmup. NO per-row certificate is taken: this is the
1893    /// speed mode (the certified `*_encode_*` paths remain the accuracy mode).
1894    ///
1895    /// Cost is GEMM-bound: one `(n × p)·(p × d)` decode-distance product for
1896    /// nearest-chart routing (skipped for single-chart atoms) plus, per chart,
1897    /// one `(n_c × p)·(p × d)` predictor product — i.e. `≈ X·Wᵀ`, exactly a
1898    /// dense SAE encoder's forward map.
1899    ///
1900    /// Degenerate rows are handled exactly as `amortized_warm_start` flags them
1901    /// (returns `None` ⇒ zeroed coord here): a missing basis evaluator, a chart
1902    /// whose Gauss–Newton block was singular (`amortized_jacobian == None`), or a
1903    /// non-finite / non-positive amplitude. Those rows are zeroed (never a panic,
1904    /// never a silent wrong encode), and their indices are returned in the
1905    /// `valid` mask so the caller can route them to the exact path if desired.
1906    ///
1907    /// Returns `(coords, valid)` where `coords` is `n × d` and `valid[row]` is
1908    /// `true` iff the amortized predictor fired for that row.
1909    pub fn amortized_encode_batch_fast(
1910        &self,
1911        atom: &SaeManifoldAtom,
1912        atom_index: usize,
1913        x: ArrayView2<'_, f64>,
1914        amplitudes: ArrayView1<'_, f64>,
1915    ) -> Result<(Array2<f64>, Vec<bool>), String> {
1916        let n = x.nrows();
1917        let p = atom.output_dim();
1918        let d = atom.latent_dim;
1919        if x.ncols() != p {
1920            return Err(format!(
1921                "amortized_encode_batch_fast: x has {} cols but atom output dim is {p}",
1922                x.ncols()
1923            ));
1924        }
1925        if amplitudes.len() != n {
1926            return Err(format!(
1927                "amortized_encode_batch_fast: amplitudes len {} != rows {n}",
1928                amplitudes.len()
1929            ));
1930        }
1931        let atom_atlas = self.atoms.get(atom_index).ok_or_else(|| {
1932            format!("amortized_encode_batch_fast: atom {atom_index} not in atlas")
1933        })?;
1934        let mut coords = Array2::<f64>::zeros((n, d));
1935        let mut valid = vec![false; n];
1936
1937        // A missing basis evaluator means the distilled predictor cannot fire for
1938        // this atom — every row is uncertified (zeroed), exactly like the per-row
1939        // `amortized_encode_row` no-evaluator branch.
1940        let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
1941            return Ok((coords, valid));
1942        };
1943
1944        // ── Routing recon-centers (one evaluation per chart, batched). ────────
1945        // `nearest_chart` routes a row to the chart whose center reconstruction
1946        // `m(t_c) = BᵀΦ(t_c)` is closest in ‖·‖², skipping charts with
1947        // `certified_radius <= 0`. Evaluate every candidate center ONCE here
1948        // (chart count ≪ n) and GEMM the recon, reproducing `nearest_chart`'s
1949        // per-chart recon bit-for-bit (same φ·decoder accumulation order).
1950        let valid_charts: Vec<usize> = (0..atom_atlas.charts.len())
1951            .filter(|&c| atom_atlas.charts[c].certified_radius > 0.0)
1952            .collect();
1953        if valid_charts.is_empty() {
1954            return Ok((coords, valid));
1955        }
1956        // Stack candidate centers (C × d) and evaluate the basis in one call.
1957        let mut centers = Array2::<f64>::zeros((valid_charts.len(), d));
1958        for (ci, &c) in valid_charts.iter().enumerate() {
1959            centers
1960                .row_mut(ci)
1961                .assign(&atom_atlas.charts[c].region.center);
1962        }
1963        let (phi_centers, _jet) = evaluator
1964            .evaluate(centers.view())
1965            .map_err(|err| format!("amortized_encode_batch_fast: center eval: {err}"))?;
1966        // recon_centers = Φ_centers · decoder  (C × p), the routing targets.
1967        let recon_centers = phi_centers.dot(&atom.decoder_coefficients);
1968        // Per-chart routing key: route_idx[row] = argmin_c ‖x_row − recon_c‖².
1969        // ‖x − r‖² = ‖x‖² − 2 x·r + ‖r‖²; the ‖x‖² term is row-constant so the
1970        // argmin uses S = X·recon_centersᵀ and the per-chart ‖r‖². First chart
1971        // wins on a tie (strict `<`), matching `nearest_chart`.
1972        let route_idx: Vec<usize> = if valid_charts.len() == 1 {
1973            vec![0usize; n]
1974        } else {
1975            let s = x.dot(&recon_centers.t()); // (n × C)
1976            let r_sq: Vec<f64> = (0..valid_charts.len())
1977                .map(|c| recon_centers.row(c).dot(&recon_centers.row(c)))
1978                .collect();
1979            (0..n)
1980                .map(|row| {
1981                    let mut best_c = 0usize;
1982                    let mut best_d = f64::INFINITY;
1983                    for c in 0..valid_charts.len() {
1984                        let dist = r_sq[c] - 2.0 * s[[row, c]];
1985                        if dist < best_d {
1986                            best_d = dist;
1987                            best_c = c;
1988                        }
1989                    }
1990                    best_c
1991                })
1992                .collect()
1993        };
1994
1995        // ── Per-chart batched affine predictor. ───────────────────────────────
1996        // For rows routed to chart `c` with finite jacobian `A₁` (d × p) and
1997        // center reconstruction `m₁` (= `chart.recon_center`), the predictor is
1998        //   t̂ = t_c − A₁·m₁ + (1/z)·(A₁·x).
1999        // The `A₁·x` term is one `(n_c × p)·(p × d)` GEMM; the `1/z` is a per-row
2000        // scalar; `t_c − A₁·m₁` is a per-chart constant. This reproduces
2001        // `amortized_warm_start` up to FP reassociation of the per-output sum.
2002        for (ci, &c) in valid_charts.iter().enumerate() {
2003            let chart = &atom_atlas.charts[c];
2004            let Some(a1) = chart.amortized_jacobian.as_ref() else {
2005                // Singular Gauss–Newton block: predictor cannot fire — the rows
2006                // routed here stay zeroed/uncertified (same as warm_start `None`).
2007                continue;
2008            };
2009            // Gather rows routed to this chart with a usable amplitude.
2010            let rows_here: Vec<usize> = (0..n)
2011                .filter(|&row| {
2012                    route_idx[row] == ci
2013                        && amplitudes[row].is_finite()
2014                        && amplitudes[row].abs() > 0.0
2015                })
2016                .collect();
2017            if rows_here.is_empty() {
2018                continue;
2019            }
2020            // X_c (n_c × p).
2021            let mut x_c = Array2::<f64>::zeros((rows_here.len(), p));
2022            for (i, &row) in rows_here.iter().enumerate() {
2023                x_c.row_mut(i).assign(&x.row(row));
2024            }
2025            // U = X_c · A₁ᵀ  (n_c × d) — the GEMM.
2026            let u = x_c.dot(&a1.t());
2027            // Per-chart constant base = t_c − A₁·m₁ (length d).
2028            let m1 = &chart.recon_center;
2029            let a1_m1 = a1.dot(m1); // (d)
2030            let base = &chart.region.center - &a1_m1; // (d)
2031            for (i, &row) in rows_here.iter().enumerate() {
2032                let inv_z = 1.0 / amplitudes[row];
2033                for axis in 0..d {
2034                    coords[[row, axis]] = base[axis] + u[[i, axis]] * inv_z;
2035                }
2036                valid[row] = true;
2037            }
2038        }
2039        Ok((coords, valid))
2040    }
2041
2042    /// Fast batched FULL forward pass against one atom: encode → decode, the
2043    /// manifold analogue of a traditional SAE's `x̂ = z·D` (decoder `D`, code `z`).
2044    ///
2045    /// A traditional SAE decodes with one GEMM. The manifold SAE's reconstruction
2046    /// is `m(t̂) = z·Φ(t̂)·B` (module header) — the SAME GEMM `Φ·B`, but the code
2047    /// `Φ(t̂)` is the curved chart basis evaluated at the encoded latent coordinate
2048    /// rather than a flat one-hot. So the fast forward is exactly:
2049    ///   1. [`amortized_encode_batch_fast`] → per-row latent coords `t̂` (one
2050    ///      routing GEMM + one affine GEMM per chart — a traditional `W·x+b`);
2051    ///   2. ONE batched basis evaluation `Φ(t̂)` (the manifold-curvature step a
2052    ///      flat SAE doesn't have — `n×m`);
2053    ///   3. ONE GEMM `recon = Φ(t̂)·B` (`(n×m)·(m×p)` — a traditional decoder
2054    ///      `z·D`), then the per-row amplitude scale `z`.
2055    ///
2056    /// Rows the encoder could not certify-predict (no evaluator / singular
2057    /// Gauss–Newton block / non-finite-or-zero amplitude) are returned as a ZERO
2058    /// reconstruction and flagged `false` in the valid-mask — never a silent wrong
2059    /// decode. The reconstruction of a valid row equals, bit-for-bit up to GEMM
2060    /// reassociation, `z·(Φ(t̂_row)·B)` with `t̂` from the per-row predictor.
2061    pub fn amortized_reconstruct_batch_fast(
2062        &self,
2063        atom: &SaeManifoldAtom,
2064        atom_index: usize,
2065        x: ArrayView2<'_, f64>,
2066        amplitudes: ArrayView1<'_, f64>,
2067    ) -> Result<(Array2<f64>, Vec<bool>), String> {
2068        let n = x.nrows();
2069        let p = atom.output_dim();
2070        // Step 1: batched encode → latent coords (reuses the fast routing+affine).
2071        let (coords, valid) = self.amortized_encode_batch_fast(atom, atom_index, x, amplitudes)?;
2072        let mut recon = Array2::<f64>::zeros((n, p));
2073        // A missing evaluator means no row could encode — every row is zeroed and
2074        // already flagged `false` by the encode; nothing to decode.
2075        let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
2076            return Ok((recon, valid));
2077        };
2078        // Step 2: ONE batched basis evaluation Φ(t̂) over all rows (n × m). Invalid
2079        // rows carry coords = 0 (the chart-origin); we still evaluate them in the
2080        // batch for a single GEMM, then zero their reconstruction below — the basis
2081        // is finite at the origin so this cannot poison the valid rows' GEMM.
2082        let (phi, _jet) = evaluator
2083            .evaluate(coords.view())
2084            .map_err(|err| format!("amortized_reconstruct_batch_fast: basis eval: {err}"))?;
2085        // Step 3: ONE GEMM recon = Φ·B (n × p), then per-row amplitude scale z.
2086        // m(t̂) = z·Φ(t̂)·B, matching the module header and `fill_decoded_row`'s
2087        // `Φ·decoder` accumulation (the amplitude is applied once here).
2088        let decoded = phi.dot(&atom.decoder_coefficients); // (n × p), amplitude-1
2089        for row in 0..n {
2090            if !valid[row] {
2091                continue; // stays zeroed — uncertified, like warm_start `None`.
2092            }
2093            let z = amplitudes[row];
2094            for col in 0..p {
2095                recon[[row, col]] = z * decoded[[row, col]];
2096            }
2097        }
2098        Ok((recon, valid))
2099    }
2100
2101    /// LSH-routed certified encode (issue #1010 step 2 + 3): for each target
2102    /// row, the existing [`SaeCandidateIndex`] (#985/#994) proposes the
2103    /// best-aligned atom by frame alignment to the row direction; the row is then
2104    /// encoded against THAT atom's certified chart atlas. This is the production
2105    /// routing path — the LSH does sublinear atom selection, the atlas does the
2106    /// in-atom nearest-chart routing and the per-row Kantorovich certificate.
2107    ///
2108    /// `atoms[id]` must be aligned with the atlas's `atoms[id]` (same dictionary
2109    /// order the atlas was built from and the sketch/index were built over).
2110    /// A row with no LSH proposal (empty bucket) is flagged uncertified — it
2111    /// routes to the exact multi-start fallback, never a silent wrong encode.
2112    pub fn certified_encode_with_index<S: AtomFrameSketch + Sync>(
2113        &self,
2114        atoms: &[SaeManifoldAtom],
2115        index: &SaeCandidateIndex,
2116        sketch: &S,
2117        targets: ArrayView2<'_, f64>,
2118        amplitudes: ArrayView1<'_, f64>,
2119        latent_dim: usize,
2120    ) -> Result<EncodeResult, String> {
2121        let n = targets.nrows();
2122        if amplitudes.len() != n {
2123            return Err(format!(
2124                "certified_encode_with_index: amplitudes len {} != rows {n}",
2125                amplitudes.len()
2126            ));
2127        }
2128        let budget = auto_candidate_budget(atoms.len().max(1));
2129        // LSH-routed per-row encode is independent across rows (sublinear atom
2130        // selection + frozen-dictionary in-atom Newton), so the corpus-rate batch
2131        // fans out over rows (#1026 amortized-encoder/routing leg / #977 Stage-3).
2132        // `None` coords (no LSH candidate) carry through as a zeroed row flagged
2133        // uncertified — identical to the sequential semantics. Results assemble
2134        // back in row order (bit-identical run-to-run); the first encode error
2135        // propagates deterministically. Stay sequential inside a rayon worker to
2136        // avoid nested oversubscription.
2137        let encode_rows =
2138            |range: std::ops::Range<usize>| -> Result<Vec<Option<(Array1<f64>, bool)>>, String> {
2139                range
2140                    .map(|row| {
2141                        // The row direction is the (unit-tolerant) target; the LSH
2142                        // ranks atoms by how much of that direction lies in each
2143                        // atom's column space. `propose` returns the top-`budget`
2144                        // atom ids by exact frame alignment.
2145                        let proposal = index.propose(sketch, targets.row(row), budget, true);
2146                        let Some(&best_atom) = proposal.proposed.first() else {
2147                            // No LSH candidate: flag for the exact fallback.
2148                            return Ok(None);
2149                        };
2150                        // Routing-confidence gate (#1026): the in-atom certificate
2151                        // attests convergence WITHIN best_atom, not that best_atom is
2152                        // the globally-correct atom. A non-empty but low-alignment
2153                        // gather likely missed the true best atom (LSH recall < 1) —
2154                        // flag it for the exact fallback rather than silently
2155                        // certifying a mis-route. See CANDIDATE_ROUTING_MIN_ALIGNMENT.
2156                        // A NaN alignment — a zero-norm target row, ‖d‖ = 0 — must
2157                        // also flag for the exact fallback, not slip through `<`.
2158                        let routing_alignment = sketch.alignment(best_atom, targets.row(row));
2159                        if !routing_alignment.is_finite()
2160                            || routing_alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT
2161                        {
2162                            return Ok(None);
2163                        }
2164                        let atom = atoms.get(best_atom).ok_or_else(|| {
2165                            format!(
2166                                "certified_encode_with_index: proposed atom {best_atom} out of range"
2167                            )
2168                        })?;
2169                        let (t, cert) = self.certified_encode_row(
2170                            atom,
2171                            best_atom,
2172                            targets.row(row),
2173                            amplitudes[row],
2174                        )?;
2175                        // Heterogeneous-atom dictionaries with different latent_dim
2176                        // per atom are not supported by the batched API: the caller
2177                        // declares one shared `latent_dim` for the output tensor.
2178                        // Silently zeroing the coord row while recording a
2179                        // certified=true flag would produce corrupted
2180                        // reconstructions downstream — error loudly instead.
2181                        if t.len() != latent_dim {
2182                            return Err(format!(
2183                                "certified_encode_with_index: atom {best_atom} returned t.len()={} \
2184                                 but declared latent_dim={latent_dim}; heterogeneous-dim \
2185                                 dictionaries are not supported by this batched encode path",
2186                                t.len()
2187                            ));
2188                        }
2189                        Ok(Some((t, cert.certified())))
2190                    })
2191                    .collect()
2192            };
2193        let rows: Vec<Option<(Array1<f64>, bool)>> =
2194            if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
2195                use rayon::prelude::*;
2196                const CHUNK: usize = 256;
2197                let n_chunks = n.div_ceil(CHUNK);
2198                let chunked: Vec<Vec<Option<(Array1<f64>, bool)>>> = (0..n_chunks)
2199                    .into_par_iter()
2200                    .map(|c| {
2201                        let start = c * CHUNK;
2202                        let end = (start + CHUNK).min(n);
2203                        encode_rows(start..end)
2204                    })
2205                    .collect::<Result<_, _>>()?;
2206                chunked.into_iter().flatten().collect()
2207            } else {
2208                encode_rows(0..n)?
2209            };
2210        let mut coords = Array2::<f64>::zeros((n, latent_dim));
2211        let mut certified = Vec::with_capacity(n);
2212        for (row, slot) in rows.into_iter().enumerate() {
2213            match slot {
2214                Some((t, cert)) => {
2215                    coords.row_mut(row).assign(&t);
2216                    certified.push(cert);
2217                }
2218                None => certified.push(false),
2219            }
2220        }
2221        Ok(EncodeResult::from_rows(coords, certified))
2222    }
2223
2224    /// LSH-routed AMORTIZED (distilled) encode — the production token-rate
2225    /// encoder of #1026 ladder item 3. Identical routing to
2226    /// [`Self::certified_encode_with_index`] (LSH proposes the best-aligned atom,
2227    /// the atlas routes to the in-atom nearest chart), but the in-atom encode is
2228    /// the closed-form per-chart Jacobian predictor + certificate gate of
2229    /// [`Self::amortized_encode_row`] rather than the certified Newton-refinement
2230    /// path.
2231    /// This is the deployment path: the distilled affine map produces the encode
2232    /// in one mat-vec, the Kantorovich certificate decides trust-or-fallback per
2233    /// row, and uncertified rows (the adversarial tail the thread expects to
2234    /// concentrate on rare tokens) are flagged for the exact multi-start solve —
2235    /// compute goes where the questions are. Row-independent against the frozen
2236    /// dictionary, so the batch fans out over rows with deterministic row-order
2237    /// assembly (bit-identical run-to-run).
2238    pub fn amortized_encode_with_index<S: AtomFrameSketch + Sync>(
2239        &self,
2240        atoms: &[SaeManifoldAtom],
2241        index: &SaeCandidateIndex,
2242        sketch: &S,
2243        targets: ArrayView2<'_, f64>,
2244        amplitudes: ArrayView1<'_, f64>,
2245        latent_dim: usize,
2246    ) -> Result<EncodeResult, String> {
2247        let n = targets.nrows();
2248        if amplitudes.len() != n {
2249            return Err(format!(
2250                "amortized_encode_with_index: amplitudes len {} != rows {n}",
2251                amplitudes.len()
2252            ));
2253        }
2254        let budget = auto_candidate_budget(atoms.len().max(1));
2255        let encode_rows =
2256            |range: std::ops::Range<usize>| -> Result<Vec<Option<(Array1<f64>, bool)>>, String> {
2257                range
2258                    .map(|row| {
2259                        let proposal = index.propose(sketch, targets.row(row), budget, true);
2260                        let Some(&best_atom) = proposal.proposed.first() else {
2261                            return Ok(None);
2262                        };
2263                        // Routing-confidence gate (#1026): flag low-alignment gathers
2264                        // for the exact fallback so a missed-true-best LSH route is
2265                        // never silently certified. See CANDIDATE_ROUTING_MIN_ALIGNMENT
2266                        // and certified_encode_with_index for the full rationale.
2267                        // A NaN alignment — a zero-norm target row, ‖d‖ = 0 — must
2268                        // also flag for the exact fallback, not slip through `<`.
2269                        let routing_alignment = sketch.alignment(best_atom, targets.row(row));
2270                        if !routing_alignment.is_finite()
2271                            || routing_alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT
2272                        {
2273                            return Ok(None);
2274                        }
2275                        let atom = atoms.get(best_atom).ok_or_else(|| {
2276                            format!(
2277                                "amortized_encode_with_index: proposed atom {best_atom} out of range"
2278                            )
2279                        })?;
2280                        let (t, cert) = self.amortized_encode_row(
2281                            atom,
2282                            best_atom,
2283                            targets.row(row),
2284                            amplitudes[row],
2285                        )?;
2286                        if t.len() != latent_dim {
2287                            return Err(format!(
2288                                "amortized_encode_with_index: atom {best_atom} returned t.len()={} \
2289                                 but declared latent_dim={latent_dim}; heterogeneous-dim \
2290                                 dictionaries are not supported by this batched encode path",
2291                                t.len()
2292                            ));
2293                        }
2294                        Ok(Some((t, cert.certified())))
2295                    })
2296                    .collect()
2297            };
2298        let rows: Vec<Option<(Array1<f64>, bool)>> =
2299            if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
2300                use rayon::prelude::*;
2301                const CHUNK: usize = 256;
2302                let n_chunks = n.div_ceil(CHUNK);
2303                let chunked: Vec<Vec<Option<(Array1<f64>, bool)>>> = (0..n_chunks)
2304                    .into_par_iter()
2305                    .map(|c| {
2306                        let start = c * CHUNK;
2307                        let end = (start + CHUNK).min(n);
2308                        encode_rows(start..end)
2309                    })
2310                    .collect::<Result<_, _>>()?;
2311                chunked.into_iter().flatten().collect()
2312            } else {
2313                encode_rows(0..n)?
2314            };
2315        let mut coords = Array2::<f64>::zeros((n, latent_dim));
2316        let mut certified = Vec::with_capacity(n);
2317        for (row, slot) in rows.into_iter().enumerate() {
2318            match slot {
2319                Some((t, cert)) => {
2320                    coords.row_mut(row).assign(&t);
2321                    certified.push(cert);
2322                }
2323                None => certified.push(false),
2324            }
2325        }
2326        Ok(EncodeResult::from_rows(coords, certified))
2327    }
2328
2329    /// LSH-routed FAST amortized encode over the WHOLE dictionary — the
2330    /// multi-atom, corpus-rate analogue of [`Self::amortized_encode_with_index`].
2331    ///
2332    /// `amortized_encode_with_index` routes per row, then runs the per-row
2333    /// closed-form predictor + Kantorovich certificate + cold cross-check on each
2334    /// row independently. This fast variant keeps the SAME sublinear per-row LSH
2335    /// routing (cheap — `index.propose` + the alignment gate), but replaces the
2336    /// per-row predictor with the GEMM-batched [`Self::amortized_encode_batch_fast`]:
2337    /// it GROUPS rows by their proposed atom and runs one batched affine-predictor
2338    /// pass per atom-group (a routing GEMM + a predictor GEMM each), reproducing a
2339    /// traditional SAE's whole-dictionary `W·x+b` throughput. No per-row
2340    /// certificate — this is the speed mode validated as accuracy-parity with the
2341    /// certified solve (`fast_forward_is_accuracy_parity_with_certified`).
2342    ///
2343    /// Returns the per-row latent coords and a valid-mask: `false` for a row with
2344    /// no LSH proposal, a sub-threshold/NaN routing alignment, or one the batched
2345    /// predictor could not fire on (no evaluator / singular Gauss–Newton block /
2346    /// non-finite-or-zero amplitude). Each row is written exactly once (disjoint
2347    /// per-atom groups), so the result is independent of group iteration order.
2348    pub fn amortized_encode_with_index_fast<S: AtomFrameSketch + Sync>(
2349        &self,
2350        atoms: &[SaeManifoldAtom],
2351        index: &SaeCandidateIndex,
2352        sketch: &S,
2353        targets: ArrayView2<'_, f64>,
2354        amplitudes: ArrayView1<'_, f64>,
2355        latent_dim: usize,
2356    ) -> Result<(Array2<f64>, Vec<bool>), String> {
2357        let n = targets.nrows();
2358        if amplitudes.len() != n {
2359            return Err(format!(
2360                "amortized_encode_with_index_fast: amplitudes len {} != rows {n}",
2361                amplitudes.len()
2362            ));
2363        }
2364        let budget = auto_candidate_budget(atoms.len().max(1));
2365        // ── Per-row LSH routing (sublinear), grouped by proposed atom. ──────────
2366        let mut groups: std::collections::HashMap<usize, Vec<usize>> =
2367            std::collections::HashMap::new();
2368        for row in 0..n {
2369            let proposal = index.propose(sketch, targets.row(row), budget, true);
2370            let Some(&best_atom) = proposal.proposed.first() else {
2371                continue;
2372            };
2373            // Same routing-confidence gate as the per-row path: a low-alignment or
2374            // NaN (zero-norm row) gather flags the row for the exact fallback, so a
2375            // missed-true-best LSH route is never silently encoded.
2376            let routing_alignment = sketch.alignment(best_atom, targets.row(row));
2377            if !routing_alignment.is_finite()
2378                || routing_alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT
2379            {
2380                continue;
2381            }
2382            groups.entry(best_atom).or_default().push(row);
2383        }
2384
2385        let mut coords = Array2::<f64>::zeros((n, latent_dim));
2386        let mut valid = vec![false; n];
2387        // ── Per-atom batched predictor over each group's rows. ──────────────────
2388        for (atom_idx, rows_here) in groups {
2389            let atom = atoms.get(atom_idx).ok_or_else(|| {
2390                format!("amortized_encode_with_index_fast: proposed atom {atom_idx} out of range")
2391            })?;
2392            if atom.latent_dim != latent_dim {
2393                return Err(format!(
2394                    "amortized_encode_with_index_fast: atom {atom_idx} latent_dim {} != declared \
2395                     {latent_dim}; heterogeneous-dim dictionaries are not supported by this path",
2396                    atom.latent_dim
2397                ));
2398            }
2399            // Gather this group's target rows and amplitudes (contiguous sub-batch).
2400            let p = atom.output_dim();
2401            let mut x_sub = Array2::<f64>::zeros((rows_here.len(), p));
2402            let mut amp_sub = Array1::<f64>::zeros(rows_here.len());
2403            for (i, &row) in rows_here.iter().enumerate() {
2404                x_sub.row_mut(i).assign(&targets.row(row));
2405                amp_sub[i] = amplitudes[row];
2406            }
2407            let (sub_coords, sub_valid) =
2408                self.amortized_encode_batch_fast(atom, atom_idx, x_sub.view(), amp_sub.view())?;
2409            for (i, &row) in rows_here.iter().enumerate() {
2410                if sub_valid[i] {
2411                    coords.row_mut(row).assign(&sub_coords.row(i));
2412                    valid[row] = true;
2413                }
2414            }
2415        }
2416        Ok((coords, valid))
2417    }
2418
2419    /// LSH-routed FAST full forward over the WHOLE dictionary: encode → decode,
2420    /// the multi-atom analogue of [`Self::amortized_reconstruct_batch_fast`]. Same
2421    /// sublinear per-row routing + per-atom grouping as
2422    /// [`Self::amortized_encode_with_index_fast`], but each group is run through
2423    /// the batched reconstruct (`m(t̂) = z·Φ(t̂)·B`) so the result is the per-row
2424    /// reconstruction in the ambient space. Rows that do not route/predict decode
2425    /// to an exact zero reconstruction and are flagged `false`.
2426    pub fn amortized_reconstruct_with_index_fast<S: AtomFrameSketch + Sync>(
2427        &self,
2428        atoms: &[SaeManifoldAtom],
2429        index: &SaeCandidateIndex,
2430        sketch: &S,
2431        targets: ArrayView2<'_, f64>,
2432        amplitudes: ArrayView1<'_, f64>,
2433    ) -> Result<(Array2<f64>, Vec<bool>), String> {
2434        let n = targets.nrows();
2435        let p = targets.ncols();
2436        if amplitudes.len() != n {
2437            return Err(format!(
2438                "amortized_reconstruct_with_index_fast: amplitudes len {} != rows {n}",
2439                amplitudes.len()
2440            ));
2441        }
2442        let budget = auto_candidate_budget(atoms.len().max(1));
2443        let mut groups: std::collections::HashMap<usize, Vec<usize>> =
2444            std::collections::HashMap::new();
2445        for row in 0..n {
2446            let proposal = index.propose(sketch, targets.row(row), budget, true);
2447            let Some(&best_atom) = proposal.proposed.first() else {
2448                continue;
2449            };
2450            let routing_alignment = sketch.alignment(best_atom, targets.row(row));
2451            if !routing_alignment.is_finite()
2452                || routing_alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT
2453            {
2454                continue;
2455            }
2456            groups.entry(best_atom).or_default().push(row);
2457        }
2458
2459        let mut recon = Array2::<f64>::zeros((n, p));
2460        let mut valid = vec![false; n];
2461        for (atom_idx, rows_here) in groups {
2462            let atom = atoms.get(atom_idx).ok_or_else(|| {
2463                format!(
2464                    "amortized_reconstruct_with_index_fast: proposed atom {atom_idx} out of range"
2465                )
2466            })?;
2467            if atom.output_dim() != p {
2468                return Err(format!(
2469                    "amortized_reconstruct_with_index_fast: atom {atom_idx} output_dim {} != target \
2470                     dim {p}",
2471                    atom.output_dim()
2472                ));
2473            }
2474            let mut x_sub = Array2::<f64>::zeros((rows_here.len(), p));
2475            let mut amp_sub = Array1::<f64>::zeros(rows_here.len());
2476            for (i, &row) in rows_here.iter().enumerate() {
2477                x_sub.row_mut(i).assign(&targets.row(row));
2478                amp_sub[i] = amplitudes[row];
2479            }
2480            let (sub_recon, sub_valid) = self.amortized_reconstruct_batch_fast(
2481                atom,
2482                atom_idx,
2483                x_sub.view(),
2484                amp_sub.view(),
2485            )?;
2486            for (i, &row) in rows_here.iter().enumerate() {
2487                if sub_valid[i] {
2488                    recon.row_mut(row).assign(&sub_recon.row(i));
2489                    valid[row] = true;
2490                }
2491            }
2492        }
2493        Ok((recon, valid))
2494    }
2495}
2496
2497/// Offline `β = 1/λ_min(H_GN)` at a chart center from the Gauss-Newton block
2498/// `H_GN = J_mᵀ J_m` (residual-free). The offline `β` bounds the curvature the
2499/// online certificate sees: charts are placed where the encode lands, so the
2500/// representative residual is small and `H_GN` is the dominant, residual-free
2501/// curvature estimate. (The online per-row certificate still uses the FULL
2502/// Hessian; this is only the offline radius-sizing curvature.) Returns `None`
2503/// for a degenerate center (`λ_min ≤ 0`), which marks an uncertifiable chart.
2504pub(crate) fn center_beta(atom: &SaeManifoldAtom, center: &Array1<f64>, ridge: f64) -> Option<f64> {
2505    let evaluator = atom.basis_evaluator.as_ref()?.clone();
2506    let d = atom.latent_dim;
2507    let p = atom.output_dim();
2508    let m = atom.basis_size();
2509    let coords = center.view().to_shape((1, d)).ok()?.to_owned();
2510    let (_phi, jet) = evaluator.evaluate(coords.view()).ok()?;
2511    let decoder = &atom.decoder_coefficients;
2512    // J_m[axis] = Bᵀ (∂Φ/∂t_axis) ∈ ℝᵖ (amplitude-1; curvature scales with z²
2513    // and is absorbed conservatively by the amplitude-bounded Lipschitz term).
2514    let mut jm = Array2::<f64>::zeros((d, p));
2515    for axis in 0..d {
2516        for basis_col in 0..m {
2517            let dphi = jet[[0, basis_col, axis]];
2518            if dphi == 0.0 {
2519                continue;
2520            }
2521            for out in 0..p {
2522                jm[[axis, out]] += dphi * decoder[[basis_col, out]];
2523            }
2524        }
2525    }
2526    let mut h = Array2::<f64>::zeros((d, d));
2527    for a in 0..d {
2528        for b in 0..d {
2529            h[[a, b]] = jm.row(a).dot(&jm.row(b));
2530        }
2531        h[[a, a]] += ridge;
2532    }
2533    let (vals, _vecs) = h.eigh(Side::Lower).ok()?;
2534    let lambda_min = vals.iter().cloned().fold(f64::INFINITY, f64::min);
2535    if lambda_min.is_finite() && lambda_min > 0.0 {
2536        Some(1.0 / lambda_min)
2537    } else {
2538        None
2539    }
2540}
2541
2542/// #1154 — the amortized encoder's closed-form warm-start coordinate for one
2543/// row `x` against one chart at amplitude `z`:
2544///
2545/// ```text
2546/// t̂ = t_c + (1/z) · A₁ · (x − z · m₁(t_c)),
2547/// ```
2548///
2549/// a single `O(d·p)` mat-vec from the chart's precomputed IFT Jacobian `A₁` and
2550/// center reconstruction `m₁`. Returns `None` when the chart carries no
2551/// distilled Jacobian (singular Gauss–Newton block) or the amplitude is not
2552/// strictly positive and finite (a near-inactive atom, where the
2553/// amplitude-divided map is undefined) — in those cases the caller starts from
2554/// the chart center instead. Shared by the amortized encode (where `t̂` is the
2555/// prediction) and the exact certified encode (where `t̂` is the Newton
2556/// warm-start that then refines to stationarity, Design A).
2557pub(crate) fn amortized_warm_start(
2558    chart: &CertifiedChart,
2559    x: ArrayView1<'_, f64>,
2560    amplitude: f64,
2561) -> Option<Array1<f64>> {
2562    let a1 = chart.amortized_jacobian.as_ref()?;
2563    if !(amplitude.is_finite() && amplitude.abs() > 0.0) {
2564        return None;
2565    }
2566    let d = a1.nrows();
2567    let mut t_hat = chart.region.center.clone();
2568    for (out_idx, &m1_out) in chart.recon_center.iter().enumerate().take(a1.ncols()) {
2569        let resid = x[out_idx] - amplitude * m1_out;
2570        for axis in 0..d {
2571            t_hat[axis] += a1[[axis, out_idx]] * resid / amplitude;
2572        }
2573    }
2574    Some(t_hat)
2575}
2576
2577/// The amplitude-1 distilled amortized-encoder Jacobian at a chart center
2578/// (#1026 ladder item 3). Returns `(A₁, m₁)` where `m₁ = BᵀΦ(t_c) ∈ ℝᵖ` is the
2579/// amplitude-1 center reconstruction and `A₁ = (J₁ᵀJ₁ + ridge·I)⁻¹ J₁ ∈ ℝ^{d×p}`
2580/// is the implicit-function-theorem derivative of the encode map `x ↦ t`
2581/// (Gauss–Newton block — the residual-free, dominant curvature exactly as the
2582/// offline radius-sizing `β`). With these, the online encode of a row `x` at
2583/// amplitude `z` is the closed-form affine prediction
2584/// `t = t_c + (1/z)·A₁·(x − z·m₁)` — one mat-vec, no per-row factorization.
2585/// `None` when the basis has no jet or the Gauss–Newton block is singular (no
2586/// certifiable amortization), matching `center_beta`'s gate so a chart with a
2587/// finite `β` always carries a Jacobian and vice versa.
2588pub(crate) fn center_amortized_jacobian(
2589    atom: &SaeManifoldAtom,
2590    center: &Array1<f64>,
2591    ridge: f64,
2592) -> Option<(Array2<f64>, Array1<f64>)> {
2593    let evaluator = atom.basis_evaluator.as_ref()?.clone();
2594    let d = atom.latent_dim;
2595    let p = atom.output_dim();
2596    let m = atom.basis_size();
2597    let coords = center.view().to_shape((1, d)).ok()?.to_owned();
2598    let (phi, jet) = evaluator.evaluate(coords.view()).ok()?;
2599    let decoder = &atom.decoder_coefficients;
2600    // m₁(t_c) = BᵀΦ(t_c) ∈ ℝᵖ (amplitude-1 center reconstruction).
2601    let mut recon = Array1::<f64>::zeros(p);
2602    for basis_col in 0..m {
2603        let phi_v = phi[[0, basis_col]];
2604        if phi_v == 0.0 {
2605            continue;
2606        }
2607        for out in 0..p {
2608            recon[out] += phi_v * decoder[[basis_col, out]];
2609        }
2610    }
2611    // J₁[axis] = Bᵀ (∂Φ/∂t_axis) ∈ ℝᵖ (amplitude-1; z factors out analytically).
2612    let mut jm = Array2::<f64>::zeros((d, p));
2613    for axis in 0..d {
2614        for basis_col in 0..m {
2615            let dphi = jet[[0, basis_col, axis]];
2616            if dphi == 0.0 {
2617                continue;
2618            }
2619            for out in 0..p {
2620                jm[[axis, out]] += dphi * decoder[[basis_col, out]];
2621            }
2622        }
2623    }
2624    // H_GN = J₁ J₁ᵀ + ridge·I ∈ ℝ^{d×d}.
2625    let mut h = Array2::<f64>::zeros((d, d));
2626    for a in 0..d {
2627        for b in 0..d {
2628            h[[a, b]] = jm.row(a).dot(&jm.row(b));
2629        }
2630        h[[a, a]] += ridge;
2631    }
2632    let (vals, vecs) = h.eigh(Side::Lower).ok()?;
2633    let lambda_min = vals.iter().cloned().fold(f64::INFINITY, f64::min);
2634    if !(lambda_min.is_finite() && lambda_min > 0.0) {
2635        return None;
2636    }
2637    // A₁ = H_GN⁻¹ J₁ via the eigendecomposition: H⁻¹ = Σ_i (1/λᵢ) vᵢ vᵢᵀ, so
2638    // A₁[:, out] = Σ_i (vᵢ · J₁[:, out]) / λᵢ · vᵢ. Column-by-column keeps it the
2639    // d×p Jacobian (one SPD solve reused across all p output channels).
2640    let mut a1 = Array2::<f64>::zeros((d, p));
2641    for out in 0..p {
2642        let jcol = jm.column(out);
2643        for (i, &lam) in vals.iter().enumerate() {
2644            if !(lam.is_finite() && lam > 0.0) {
2645                return None;
2646            }
2647            let vi = vecs.column(i);
2648            let coeff = vi.dot(&jcol) / lam;
2649            for row in 0..d {
2650                a1[[row, out]] += coeff * vi[row];
2651            }
2652        }
2653    }
2654    Some((a1, recon))
2655}
2656
2657/// Route a target row to the nearest chart of an atom by reconstruction
2658/// distance: the chart whose center reconstruction `m(t_c)` is closest to `x`.
2659/// Returns the chart index and the distance, or `None` when the atom has no
2660/// charts.
2661pub(crate) fn nearest_chart(
2662    atom_atlas: &AtomEncodeAtlas,
2663    x: ArrayView1<'_, f64>,
2664    atom: &SaeManifoldAtom,
2665    evaluator: &dyn SaeBasisEvaluator,
2666) -> Option<(usize, f64)> {
2667    if atom_atlas.charts.is_empty() {
2668        return None;
2669    }
2670    let d = atom.latent_dim;
2671    let p = atom.output_dim();
2672    let m = atom.basis_size();
2673    let mut best: Option<(usize, f64)> = None;
2674    for (idx, chart) in atom_atlas.charts.iter().enumerate() {
2675        if chart.certified_radius <= 0.0 {
2676            continue;
2677        }
2678        let coords = match chart.region.center.view().to_shape((1, d)) {
2679            Ok(c) => c.to_owned(),
2680            Err(_) => continue,
2681        };
2682        let Ok((phi, _jet)) = evaluator.evaluate(coords.view()) else {
2683            continue;
2684        };
2685        // m(t_c) = Bᵀ Φ(t_c) (amplitude-1; routing is scale-tolerant).
2686        let mut recon = Array1::<f64>::zeros(p);
2687        for basis_col in 0..m {
2688            let phi_v = phi[[0, basis_col]];
2689            if phi_v == 0.0 {
2690                continue;
2691            }
2692            for out in 0..p {
2693                recon[out] += phi_v * atom.decoder_coefficients[[basis_col, out]];
2694            }
2695        }
2696        let diff = &recon - &x;
2697        let dist = diff.dot(&diff);
2698        if best.map(|(_, b)| dist < b).unwrap_or(true) {
2699            best = Some((idx, dist));
2700        }
2701    }
2702    best
2703}
2704
2705/// The `k` charts whose CENTER reconstruction `m(t_c)` is nearest to `x` in
2706/// ambient ‖·‖², returned as chart indices sorted by increasing distance (ties
2707/// broken by chart index — deterministic). Only certifiable charts
2708/// (`certified_radius > 0`) are considered, exactly like [`nearest_chart`], whose
2709/// single result is `nearest_charts_topk(.., 1)[0]`. Used by the certified encode
2710/// to refine the global basin on self-approaching atoms (see
2711/// [`CERTIFIED_ROUTING_TOPK`]).
2712pub(crate) fn nearest_charts_topk(
2713    atom_atlas: &AtomEncodeAtlas,
2714    x: ArrayView1<'_, f64>,
2715    atom: &SaeManifoldAtom,
2716    evaluator: &dyn SaeBasisEvaluator,
2717    k: usize,
2718) -> Vec<usize> {
2719    if atom_atlas.charts.is_empty() || k == 0 {
2720        return Vec::new();
2721    }
2722    let d = atom.latent_dim;
2723    let p = atom.output_dim();
2724    let m = atom.basis_size();
2725    let mut scored: Vec<(usize, f64)> = Vec::new();
2726    for (idx, chart) in atom_atlas.charts.iter().enumerate() {
2727        if chart.certified_radius <= 0.0 {
2728            continue;
2729        }
2730        let coords = match chart.region.center.view().to_shape((1, d)) {
2731            Ok(c) => c.to_owned(),
2732            Err(_) => continue,
2733        };
2734        let Ok((phi, _jet)) = evaluator.evaluate(coords.view()) else {
2735            continue;
2736        };
2737        let mut recon = Array1::<f64>::zeros(p);
2738        for basis_col in 0..m {
2739            let phi_v = phi[[0, basis_col]];
2740            if phi_v == 0.0 {
2741                continue;
2742            }
2743            for out in 0..p {
2744                recon[out] += phi_v * atom.decoder_coefficients[[basis_col, out]];
2745            }
2746        }
2747        let diff = &recon - &x;
2748        scored.push((idx, diff.dot(&diff)));
2749    }
2750    // Sort by distance, then chart index for a deterministic, first-wins order
2751    // consistent with `nearest_chart`'s strict-`<` tie rule.
2752    scored.sort_by(|a, b| {
2753        a.1.partial_cmp(&b.1)
2754            .unwrap_or(std::cmp::Ordering::Equal)
2755            .then(a.0.cmp(&b.0))
2756    });
2757    scored.into_iter().take(k).map(|(idx, _)| idx).collect()
2758}
2759
2760/// Reconstruction error `‖x − z·m(t)‖` of an encoded coordinate `t` — the
2761/// criterion the certified encode minimizes over its candidate charts to pick the
2762/// GLOBAL basin. `m(t) = Bᵀ Φ(t)` is the amplitude-1 reconstruction; `z` is the
2763/// amplitude. A non-finite reconstruction returns `+∞` so it never wins.
2764pub(crate) fn encode_reconstruction_error(
2765    atom: &SaeManifoldAtom,
2766    evaluator: &dyn SaeBasisEvaluator,
2767    coord: ArrayView1<'_, f64>,
2768    x: ArrayView1<'_, f64>,
2769    amplitude: f64,
2770) -> f64 {
2771    let d = atom.latent_dim;
2772    let p = atom.output_dim();
2773    let m = atom.basis_size();
2774    let coords = match coord.to_shape((1, d)) {
2775        Ok(c) => c.to_owned(),
2776        Err(_) => return f64::INFINITY,
2777    };
2778    let Ok((phi, _jet)) = evaluator.evaluate(coords.view()) else {
2779        return f64::INFINITY;
2780    };
2781    let mut err2 = 0.0;
2782    for out in 0..p {
2783        let mut recon = 0.0;
2784        for basis_col in 0..m {
2785            recon += phi[[0, basis_col]] * atom.decoder_coefficients[[basis_col, out]];
2786        }
2787        let r = x[out] - amplitude * recon;
2788        err2 += r * r;
2789    }
2790    if err2.is_finite() { err2.sqrt() } else { f64::INFINITY }
2791}
2792
/// Maximum number of chart centers laid down per atom (the SHAPE_BAND grid
/// point cap). The grid builders in this module shrink their per-axis
/// resolution toward it; e.g. the 2-D lat/lon sphere grid clamps at 22 points
/// per axis (22² = 484 ≤ 512).
/// NOTE(review): described as mirroring `SHAPE_BAND_MAX_POINTS` in the atom band
/// machinery — keep the two values in sync.
pub(crate) const SHAPE_BAND_MAX_POINTS: usize = 512;
2796
2797/// Lay down chart centers on an atom's coordinate grid (the SHAPE_BAND grid
2798/// idiom): a regular grid spanning the compact latent domain for periodic /
2799/// sphere / torus atoms, and a strided cover of the latent axes for unbounded
2800/// (Duchon / Euclidean) atoms.
2801///
2802/// Periodic / torus latents are fractions of one period, so the per-axis grid
2803/// spans `[0, 1)`; the sphere chart spans `lat ∈ [−π/2, π/2]`, `lon ∈ [−π, π)`.
2804/// These conventions match the basis evaluators (the fraction-of-period circle
2805/// harmonic and the lat/lon sphere chart).
2806/// Squared coordinate distance between two latent points under the atom's chart
2807/// geometry: per-axis WRAPPED distance `min(|a−b|, period−|a−b|)` on periodic
2808/// (circle) axes — period 1 to match `chart_center_grid`'s `[0,1)` torus tiling
2809/// — and plain difference on line axes. Used to place + size data-driven charts.
2810pub(crate) fn coord_dist_sq(atom: &SaeManifoldAtom, a: ArrayView1<'_, f64>, b: ArrayView1<'_, f64>) -> f64 {
2811    use crate::manifold::SaeAtomBasisKind::*;
2812    let periodic_axis = |axis: usize| -> bool {
2813        match &atom.basis_kind {
2814            Periodic | Torus | Sphere => true,
2815            // Cylinder S¹×ℝ: only axis 0 is the circle.
2816            Cylinder => axis == 0,
2817            Linear | Duchon | EuclideanPatch | Poincare | Precomputed(_) => false,
2818        }
2819    };
2820    let mut acc = 0.0;
2821    for axis in 0..a.len() {
2822        let mut d = (a[axis] - b[axis]).abs();
2823        if periodic_axis(axis) {
2824            // Wrap onto the circle of unit period.
2825            d -= d.floor(); // fractional part in [0,1)
2826            d = d.min(1.0 - d);
2827        }
2828        acc += d * d;
2829    }
2830    acc
2831}
2832
2833/// Greedy farthest-point sampling of up to `max_charts` chart centers from the
2834/// atom's latent `coords` (n × d), with each center's nominal radius set to half
2835/// the distance to its nearest neighbor center (floored, so a singleton/coincident
2836/// cluster still gets a usable ball). Deterministic: seeds from row 0, then
2837/// repeatedly adds the coord maximally far (under [`coord_dist_sq`]) from the
2838/// chosen set — coverage-maximizing and reproducible run-to-run.
2839pub(crate) fn data_driven_chart_centers(
2840    atom: &SaeManifoldAtom,
2841    coords: ArrayView2<'_, f64>,
2842    max_charts: usize,
2843) -> Result<(Array2<f64>, Vec<f64>), String> {
2844    let n = coords.nrows();
2845    let d = coords.ncols();
2846    if d != atom.latent_dim {
2847        return Err(format!(
2848            "data_driven_chart_centers: coords have {d} cols but atom latent_dim is {}",
2849            atom.latent_dim
2850        ));
2851    }
2852    if n == 0 {
2853        return Ok((Array2::<f64>::zeros((0, d)), Vec::new()));
2854    }
2855    let k = max_charts.min(n);
2856    // Farthest-point sampling: maintain each row's distance to the nearest chosen
2857    // center, add the row with the maximum such distance each step.
2858    let mut chosen: Vec<usize> = Vec::with_capacity(k);
2859    chosen.push(0);
2860    let mut nearest_sq: Vec<f64> = (0..n)
2861        .map(|r| coord_dist_sq(atom, coords.row(r), coords.row(0)))
2862        .collect();
2863    while chosen.len() < k {
2864        // Pick the row farthest from the current center set (first-wins tie).
2865        let mut best = 0usize;
2866        let mut best_d = -1.0;
2867        for r in 0..n {
2868            if nearest_sq[r] > best_d {
2869                best_d = nearest_sq[r];
2870                best = r;
2871            }
2872        }
2873        if best_d <= 0.0 {
2874            break; // all remaining rows coincide with a chosen center.
2875        }
2876        chosen.push(best);
2877        for r in 0..n {
2878            let dr = coord_dist_sq(atom, coords.row(r), coords.row(best));
2879            if dr < nearest_sq[r] {
2880                nearest_sq[r] = dr;
2881            }
2882        }
2883    }
2884    let m = chosen.len();
2885    let mut centers = Array2::<f64>::zeros((m, d));
2886    for (i, &row) in chosen.iter().enumerate() {
2887        centers.row_mut(i).assign(&coords.row(row));
2888    }
2889    // Per-center radius = half the nearest-OTHER-center distance, floored so a
2890    // coincident pair still yields a positive ball, capped at 0.5 (the largest
2891    // meaningful half-period on a unit circle).
2892    let mut radii = vec![0.0_f64; m];
2893    for i in 0..m {
2894        let mut nn = f64::INFINITY;
2895        for j in 0..m {
2896            if i == j {
2897                continue;
2898            }
2899            let dsq = coord_dist_sq(atom, centers.row(i), centers.row(j));
2900            if dsq < nn {
2901                nn = dsq;
2902            }
2903        }
2904        let r = if nn.is_finite() { 0.5 * nn.sqrt() } else { 0.5 };
2905        radii[i] = r.max(1.0e-3).min(0.5);
2906    }
2907    Ok((centers, radii))
2908}
2909
2910pub(crate) fn chart_center_grid(atom: &SaeManifoldAtom, resolution: usize) -> Array2<f64> {
2911    use crate::manifold::SaeAtomBasisKind::*;
2912    let d = atom.latent_dim;
2913    match &atom.basis_kind {
2914        Periodic | Torus => regular_product_grid(d, resolution, 0.0, 1.0, false),
2915        // Cylinder `S¹ × ℝ`: axis 0 is the periodic circle `[0, 1)` (no
2916        // endpoint, like the harmonic axes); axis 1 is the unbounded line,
2917        // covered by a strided unit box `[-0.5, 0.5]` about the origin (like the
2918        // Euclidean patch). The certified radius refines each chart; out-of-cover
2919        // line starts route to the exact fallback honestly.
2920        Cylinder if d == 2 => cylinder_chart_center_grid(resolution),
2921        Cylinder => regular_product_grid(d, resolution, -0.5, 0.5, true),
2922        Sphere if d == 2 => sphere_latlon_grid(resolution),
2923        Linear | Sphere | Duchon | EuclideanPatch | Poincare | Precomputed(_) => {
2924            // Unbounded / non-compact latents: a strided cover of a unit box
2925            // about the origin per axis. The certified radius refines each chart;
2926            // out-of-cover starts route to the exact fallback honestly.
2927            regular_product_grid(d, resolution, -0.5, 0.5, true)
2928        }
2929    }
2930}
2931
2932/// A regular `resolution`-per-axis product grid over `[lo, hi]^d`, capped at
2933/// [`SHAPE_BAND_MAX_POINTS`] total points (the per-axis resolution is reduced
2934/// until the product fits). When `include_endpoint` the last grid point sits at
2935/// `hi`; otherwise the axis is treated as periodic and stops one step short.
2936pub(crate) fn regular_product_grid(
2937    d: usize,
2938    resolution: usize,
2939    lo: f64,
2940    hi: f64,
2941    include_endpoint: bool,
2942) -> Array2<f64> {
2943    if d == 0 {
2944        return Array2::<f64>::zeros((1, 0));
2945    }
2946    let mut per_axis = resolution.max(2);
2947    while per_axis.saturating_pow(d as u32) > SHAPE_BAND_MAX_POINTS && per_axis > 2 {
2948        per_axis -= 1;
2949    }
2950    let total = per_axis.saturating_pow(d as u32).max(1);
2951    let denom = if include_endpoint {
2952        (per_axis.max(2) - 1) as f64
2953    } else {
2954        per_axis as f64
2955    };
2956    let mut grid = Array2::<f64>::zeros((total, d));
2957    let mut idx = vec![0usize; d];
2958    for flat in 0..total {
2959        for axis in 0..d {
2960            let frac = idx[axis] as f64 / denom;
2961            grid[[flat, axis]] = lo + (hi - lo) * frac;
2962        }
2963        for axis in (0..d).rev() {
2964            idx[axis] += 1;
2965            if idx[axis] < per_axis {
2966                break;
2967            }
2968            idx[axis] = 0;
2969        }
2970    }
2971    grid
2972}
2973
2974/// Lat/lon sphere chart grid: `lat ∈ [−π/2, π/2]`, `lon ∈ [−π, π)`, matching
2975/// the [`crate::manifold::SphereChartEvaluator`] convention.
2976pub(crate) fn sphere_latlon_grid(resolution: usize) -> Array2<f64> {
2977    use std::f64::consts::PI;
2978    let r = resolution.max(2).min(22); // 22² = 484 ≤ SHAPE_BAND_MAX_POINTS.
2979    let mut grid = Array2::<f64>::zeros((r * r, 2));
2980    for i in 0..r {
2981        let lat = -PI / 2.0 + PI * (i as f64 + 0.5) / r as f64;
2982        for j in 0..r {
2983            let lon = -PI + 2.0 * PI * (j as f64) / r as f64;
2984            grid[[i * r + j, 0]] = lat;
2985            grid[[i * r + j, 1]] = lon;
2986        }
2987    }
2988    grid
2989}
2990
2991/// Cylinder `S¹ × ℝ` chart-center grid: axis 0 sweeps the periodic circle over
2992/// one period `[0, 1)` (no endpoint, matching the harmonic axis), axis 1 strides
2993/// a unit box `[−0.5, 0.5]` about the origin on the unbounded line (with
2994/// endpoint). Capped at [`SHAPE_BAND_MAX_POINTS`] total centers.
2995pub(crate) fn cylinder_chart_center_grid(resolution: usize) -> Array2<f64> {
2996    let mut per_axis = resolution.max(2);
2997    while per_axis * per_axis > SHAPE_BAND_MAX_POINTS && per_axis > 2 {
2998        per_axis -= 1;
2999    }
3000    let total = per_axis * per_axis;
3001    let line_denom = (per_axis.max(2) - 1) as f64;
3002    let mut grid = Array2::<f64>::zeros((total, 2));
3003    for i in 0..per_axis {
3004        // Periodic axis 0: stop one step short of the period.
3005        let circle = i as f64 / per_axis as f64;
3006        for j in 0..per_axis {
3007            // Line axis 1: include the endpoint of the unit box.
3008            let line = -0.5 + (j as f64) / line_denom;
3009            grid[[i * per_axis + j, 0]] = circle;
3010            grid[[i * per_axis + j, 1]] = line;
3011        }
3012    }
3013    grid
3014}
3015
3016/// Nominal in-chart radius: half the inter-center grid spacing, so charts tile
3017/// the domain. For compact latents this is the grid step; for unbounded latents
3018/// a unit default that the certified radius refines.
3019pub(crate) fn chart_nominal_radius(atom: &SaeManifoldAtom, resolution: usize) -> f64 {
3020    use crate::manifold::SaeAtomBasisKind::*;
3021    match &atom.basis_kind {
3022        Periodic | Torus => 0.5 / (resolution.max(2) as f64),
3023        Sphere => std::f64::consts::PI / (resolution.max(2) as f64),
3024        // Cylinder charts tile two heterogeneous axes (a `[0,1)` periodic step
3025        // and a unit-box line step); the chart radius is a single scalar, so we
3026        // take the tighter (periodic) step `0.5/res` to keep every chart valid
3027        // on both axes. The certified Kantorovich radius refines it per chart.
3028        Cylinder => 0.5 / (resolution.max(2) as f64),
3029        Linear | Duchon | EuclideanPatch | Poincare | Precomputed(_) => {
3030            1.0 / (resolution.max(2) as f64)
3031        }
3032    }
3033}
3034
3035/// Build the [`ChartRegion`] for a center, attaching the radial r_min / r_max
3036/// bracket for Duchon atoms (the chart's distance range to the kernel centers).
3037pub(crate) fn chart_region(
3038    atom: &SaeManifoldAtom,
3039    center: Array1<f64>,
3040    radius: f64,
3041) -> ChartRegion {
3042    use crate::manifold::SaeAtomBasisKind::*;
3043    let region = ChartRegion::new(center.clone(), radius);
3044    match &atom.basis_kind {
3045        Duchon => {
3046            // r ranges over [‖t_c‖ − radius, ‖t_c‖ + radius] about the single
3047            // origin-anchored center used by the conservative radial bound.
3048            //
3049            // The lower bound must be `max(0, center_norm − radius)` — NOT floored
3050            // at `radius`. When the chart contains the kernel center
3051            // (`center_norm < radius`, true r_min = 0), flooring at `radius`
3052            // would give a finite, NON-CONSERVATIVE `r_min`, causing the
3053            // hessian_sup / third_sup formulas (which divide by r_min) to
3054            // underestimate the Lipschitz constant and potentially grant a false
3055            // Kantorovich certificate. Flooring at `f64::MIN_POSITIVE` instead
3056            // correctly drives the formulas toward ∞, producing a very large L
3057            // that will NEVER certify (rows route to the exact multi-start
3058            // fallback) — conservative and sound.
3059            let center_norm = center.dot(&center).sqrt();
3060            let r_min = (center_norm - radius).max(f64::MIN_POSITIVE);
3061            let r_max = center_norm + radius;
3062            region.with_radial_bounds(r_min, r_max)
3063        }
3064        // Cylinder has no radial kernel block (it is a harmonic × polynomial
3065        // tensor, not a Duchon radial basis), so it needs no radial r_min/r_max.
3066        Periodic | Sphere | Torus | Cylinder | Linear | EuclideanPatch | Poincare
3067        | Precomputed(_) => region,
3068    }
3069}