Skip to main content

gam_sae/
encode.rs

1//! Kantorovich-certified encode atlas (issue #1010).
2//!
3//! Encoding a row `x ∈ ℝᵖ` against a FROZEN SAE dictionary is, per atom `k`,
4//! the coordinate-only Newton problem
5//!
6//! ```text
7//! min_t  f_k(t) = ½‖x − z_k · B_kᵀ Φ_k(t)‖² + prior_k(t),
8//! ```
9//!
10//! with the amplitude `z_k` and decoder block `B_k` held fixed (the encode
11//! freezes the dictionary; only the latent coordinate `t` moves). Newton on
12//! `F(t) = ∇f_k(t)` converges quadratically from a start `t₀` into the unique
13//! root in a certified ball whenever the **Newton–Kantorovich** quantity
14//!
15//! ```text
16//! h = β · η · L ≤ ½,    β = ‖F'(t₀)⁻¹‖,   η = ‖F'(t₀)⁻¹ F(t₀)‖,
17//! ```
18//!
19//! where `L` is a Lipschitz constant of `F'` (the Hessian of `f_k`) on a region
20//! containing the Newton iterates. `h` is CHECKABLE per row in `O(q³)`
21//! (`q = latent_dim`, tiny), so each fast-path encode carries its own
22//! exactness certificate.
23//!
24//! ## The closed-form Hessian-Lipschitz constant `L`
25//!
26//! Write `m(t) = z·BᵀΦ(t) ∈ ℝᵖ` (the reconstruction) and `r(t) = m(t) − x`.
27//! Then `f = ½‖r‖² + prior` and, differentiating three times,
28//!
29//! ```text
30//! ∇³f = 3·sym(J_mᵀ : ∇²m) + ⟨r, ∇³m⟩ + ∇³prior,
31//! ```
32//!
33//! so an operator-norm bound on the chart is
34//!
35//! ```text
36//! L ≤ 3·‖J_m‖·‖∇²m‖ + ‖r‖·‖∇³m‖ + L_prior,
37//! ```
38//!
39//! with `‖∂^g m‖ ≤ |z|·(Σ_m ‖B_{m,:}‖)·B_g`, where `B_g = sup_chart max_m
40//! ‖∂^g Φ_m‖` is the per-column jet sup of the basis family — closed form per
41//! family ([`BasisHessianLipschitz`]). `‖r‖` is bounded by `‖x‖ +
42//! |z|·(Σ_m‖B_{m,:}‖)·B_0`. The ARD/von-Mises prior `L_prior` is a closed-form
43//! constant from the prior strength. Every bound is conservative (an
44//! over-estimate of `L` only SHRINKS the certified radius — it can never
45//! certify a row that does not converge).
46//!
47//! ## Pipeline
48//!
49//! 1. **Offline, per atom** ([`EncodeAtlas::build`]): chart centers `t_c` on the
50//!    atom's coordinate grid (the SHAPE_BAND grid idiom), each with a certified
51//!    Newton radius `R_c` solved from the Kantorovich inequality at the
52//!    worst-case in-chart start.
53//! 2. **Online, per row** ([`EncodeAtlas::certified_encode_row`]): route to the
54//!    nearest chart, start from its distilled IFT predictor, take one or two
55//!    Newton steps, then the `h ≤ ½` check AT the start point is the per-row
56//!    certificate.
57//! 3. **Uncertified tail**: rows whose start fails `h ≤ ½` are FLAGGED (counted
58//!    in [`EncodeResult::encode_uncertified_count`]) and must be routed by the
59//!    caller to the existing exact multi-start solve. No approximation enters
60//!    silently.
61
62use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
63
64use crate::candidate_index::{AtomFrameSketch, SaeCandidateIndex, auto_candidate_budget};
65use crate::manifold::{
66    AffineCoordinateEvaluator, CylinderHarmonicEvaluator, DuchonCoordinateEvaluator,
67    EuclideanPatchEvaluator, PeriodicHarmonicEvaluator, SaeBasisEvaluator, SaeManifoldAtom,
68    SphereChartEvaluator, TorusHarmonicEvaluator,
69};
70use gam_linalg::faer_ndarray::FaerEigh;
71
72use faer::Side;
73
74/// The Kantorovich convergence threshold `h ≤ ½`. Below this the Newton
75/// iteration is guaranteed to converge quadratically into the unique root in
76/// the certified ball; at or above it the start is uncertified.
77pub const KANTOROVICH_THRESHOLD: f64 = 0.5;
78
79/// Row count at or above which the corpus-rate certified-encode batch
80/// (`certified_encode_batch` / `certified_encode_with_index`) fans its
81/// per-row encodes out over rayon. Below this the per-row Newton + chart
82/// routing is cheap enough that the fan-out overhead does not pay; matched to
83/// the same order as the arrow-Schur `SCHUR_MATVEC_PARALLEL_ROW_MIN` gate so
84/// short batches inside an outer atom-level fan-out stay sequential.
85pub(crate) const ENCODE_BATCH_PARALLEL_ROW_MIN: usize = 256;
86
87/// Minimum frame alignment `‖Uₖᵀd‖/‖d‖ ∈ [0,1]` the routed atom must have for an
88/// index-routed encode to be attempted at all — a FIT-QUALITY floor, NOT a
89/// routing-correctness gate (#1026, corrected by the #1777 exact-routing path).
90///
91/// Routing itself is now EXACT: the index-routed encode picks the atom via
92/// [`SaeCandidateIndex::route_exact`], which returns the GLOBAL argmax of the
93/// routing score (the universal-bound LSH fast path, else a full-scan fallback) —
94/// so there is no "missed-better-ungathered-atom" hole left for this constant to
95/// patch. What remains is a different, honest question: even the globally-best
96/// atom may align only weakly with a row (no atom in the dictionary fits it). A
97/// finite alignment below this floor means the best available atom is a poor fit,
98/// so the row is flagged and routed to the exact multi-start fallback rather than
99/// encoded against an atom it barely belongs to. (Previously this same comparison
100/// double-served as a recall proxy; that role is gone — `route_exact` guarantees
101/// recall — leaving only the fit-quality role described here.)
102pub(crate) const CANDIDATE_ROUTING_MIN_ALIGNMENT: f64 = 0.5;
103
104/// Number of nearest charts the CERTIFIED encode refines in before returning the
105/// lowest-reconstruction-error certified result. A single nearest chart is not
106/// globally sound where the decoded manifold folds near itself (both competing
107/// basins' charts reconstruct near the fold, so both rank among the nearest by
108/// ambient distance); refining the top few captures the global basin. For a
109/// unimodal atom all candidates converge to the same root, so K>1 is a no-op.
110pub(crate) const CERTIFIED_ROUTING_TOPK: usize = 4;
111
112/// Newton refinement convergence floor. Once a refinement step's length `‖δ‖`
113/// falls below this (relative to the coordinate scale `1 + ‖t‖`), the iterate has
114/// reached the certified root to f64 resolution: applying the step cannot move `t`
115/// meaningfully, and the remaining fixed-budget steps only re-accumulate round-off.
116/// Stopping there is STRICTLY more accurate than draining a fixed step budget on a
117/// well-conditioned quadratic Newton tail, and it removes that tail's per-step
118/// `evaluate` + `second_jet` cost (the dominant per-row encode work). The batched
119/// and per-row encodes share this rule, so they stay bit-identical.
120pub(crate) const NEWTON_REFINE_CONVERGED_EPS: f64 = 1.0e-12;
121
122/// Global-minimum short-circuit floor for top-K certified routing. The
123/// reconstruction error `‖x − z·m(t)‖` is bounded below by 0, so a certified
124/// candidate whose residual already sits at the ambient noise floor
125/// (`≤ this · (1 + ‖x‖)`) is provably the global optimum over the charts — no
126/// competing chart can reach a strictly lower residual. The remaining candidates'
127/// refinement is then skipped. Conservative (a genuine second basin of the same
128/// target reconstructs the SAME point, so returning the first is a valid encode).
129pub(crate) const CERTIFIED_GLOBAL_MIN_RECON_FLOOR: f64 = 1.0e-11;
130
131/// A chart region on an atom's latent coordinate: a center `t_c` plus a
132/// certified in-chart radius. Over the ball `‖t − t_c‖ ≤ radius` the jet sup
133/// bounds returned by [`BasisHessianLipschitz`] hold, so the Kantorovich
134/// constant `L` computed from them is valid for any start in the ball.
135///
136/// For radial (Duchon) families the chart also carries the minimum kernel-center
137/// distance `exclusion_r_min` (a lower bound on `‖t − c_k‖` over the chart) that
138/// bounds the otherwise-singular `1/r` radial tails (issue #1010).
139#[derive(Debug, Clone)]
140pub struct ChartRegion {
141    /// Chart center coordinate `t_c` (length = latent_dim).
142    pub center: Array1<f64>,
143    /// In-chart radius in the coordinate metric.
144    pub radius: f64,
145    /// For radial (Duchon) families: a lower bound on `‖t − c_k‖` over the
146    /// chart, across every kernel center `c_k`. `None` for non-radial families.
147    pub exclusion_r_min: Option<f64>,
148    /// For radial (Duchon) families: an upper bound on `‖t − c_k‖` over the
149    /// chart, across every kernel center `c_k`. `None` for non-radial families.
150    pub radial_r_max: Option<f64>,
151}
152
153impl ChartRegion {
154    pub fn new(center: Array1<f64>, radius: f64) -> Self {
155        Self {
156            center,
157            radius,
158            exclusion_r_min: None,
159            radial_r_max: None,
160        }
161    }
162
163    pub fn with_radial_bounds(mut self, r_min: f64, r_max: f64) -> Self {
164        self.exclusion_r_min = Some(r_min);
165        self.radial_r_max = Some(r_max);
166        self
167    }
168
169    /// A jet-sup certificate is only meaningful over a genuine region. Even
170    /// families whose bounds are manifold-global constants (the sup over any
171    /// chart equals the global sup) must refuse a malformed chart rather than
172    /// certify garbage geometry.
173    pub(crate) fn assert_valid(&self) {
174        assert!(
175            self.radius.is_finite()
176                && self.radius >= 0.0
177                && self.center.iter().all(|c| c.is_finite()),
178            "ChartRegion must have a finite center and a finite non-negative radius"
179        );
180    }
181}
182
183/// Per-column sup-norm bounds on the first three coordinate jets of a basis
184/// family `Φ(t)`, valid over a stated [`ChartRegion`] (issue #1010). These are
185/// the analytic ingredients of the Hessian-Lipschitz constant `L` — see the
186/// module docs for the assembly. `value_sup` bounds `max_m |Φ_m|`,
187/// `jacobian_sup`/`hessian_sup`/`third_sup` bound `max_m ‖∂^g Φ_m‖`.
188pub trait BasisHessianLipschitz {
189    fn value_sup(&self, chart: &ChartRegion) -> f64;
190    fn jacobian_sup(&self, chart: &ChartRegion) -> f64;
191    fn hessian_sup(&self, chart: &ChartRegion) -> f64;
192    fn third_sup(&self, chart: &ChartRegion) -> f64;
193}
194
195/// Sup over the circle of the `g`-th derivative of any single harmonic column
196/// of a `num_basis`-wide Fourier basis `[1, sin(2π h t), cos(2π h t), …]`:
197/// `(2π·H)^g` for the top harmonic `H = (num_basis − 1)/2`. The constant column
198/// contributes `0` for `g ≥ 1`, so the top harmonic dominates; the bound is
199/// global (the trig magnitudes are `≤ 1` everywhere, independent of the chart).
200pub(crate) fn harmonic_jet_sup(num_basis: usize, order: u32) -> f64 {
201    let top_harmonic = num_basis.saturating_sub(1) / 2;
202    let omega = std::f64::consts::TAU * top_harmonic as f64;
203    omega.powi(order as i32)
204}
205
206impl BasisHessianLipschitz for PeriodicHarmonicEvaluator {
207    fn value_sup(&self, chart: &ChartRegion) -> f64 {
208        chart.assert_valid();
209        1.0
210    }
211    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
212        chart.assert_valid();
213        harmonic_jet_sup(self.num_basis, 1)
214    }
215    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
216        chart.assert_valid();
217        harmonic_jet_sup(self.num_basis, 2)
218    }
219    fn third_sup(&self, chart: &ChartRegion) -> f64 {
220        chart.assert_valid();
221        harmonic_jet_sup(self.num_basis, 3)
222    }
223}
224
225impl BasisHessianLipschitz for TorusHarmonicEvaluator {
226    /// Tensor product of per-axis circle harmonics. A torus basis column is a
227    /// product of single-axis harmonics, each bounded as in the circle case.
228    /// The `g`-th coordinate jet routes `g` derivative operators across the
229    /// `latent_dim` factors (Leibniz); each routing contributes a product of
230    /// per-axis derivative magnitudes. A per-column sup is therefore bounded by
231    /// the top single-axis frequency to the `g`-th power times the number of
232    /// such routings (`latent_dim^g`, the count of operator-to-axis maps).
233    fn value_sup(&self, chart: &ChartRegion) -> f64 {
234        chart.assert_valid();
235        1.0
236    }
237    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
238        chart.assert_valid();
239        torus_jet_sup(self.num_harmonics, self.latent_dim, 1)
240    }
241    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
242        chart.assert_valid();
243        torus_jet_sup(self.num_harmonics, self.latent_dim, 2)
244    }
245    fn third_sup(&self, chart: &ChartRegion) -> f64 {
246        chart.assert_valid();
247        torus_jet_sup(self.num_harmonics, self.latent_dim, 3)
248    }
249}
250
251/// Per-column `g`-th jet sup for the torus harmonic basis: `(2π·H)^g ·
252/// latent_dim^g`, where `H = num_harmonics` is the top per-axis frequency and
253/// `latent_dim^g` over-counts the Leibniz routings of `g` operators across the
254/// product factors (a conservative bound — each routing's per-axis magnitude is
255/// `≤ (2π H)^{#ops on that axis}`, and the products telescope to `(2π H)^g`).
256pub(crate) fn torus_jet_sup(num_harmonics: usize, latent_dim: usize, order: u32) -> f64 {
257    let omega = std::f64::consts::TAU * num_harmonics as f64;
258    omega.powi(order as i32) * (latent_dim as f64).powi(order as i32)
259}
260
261impl BasisHessianLipschitz for SphereChartEvaluator {
262    /// The 7-column lat/lon chart `[1, x, y, z, xy, yz, xz]` with
263    /// `x = cos(lat)cos(lon)`, `y = cos(lat)sin(lon)`, `z = sin(lat)`. Each of
264    /// `x, y, z` is a product of two unit-frequency trig factors, so its `g`-th
265    /// coordinate jet is a sum of `2^g` products of `{sin,cos}` (each `≤ 1`):
266    /// magnitude `≤ 2^g` for `g ≥ 1`, `≤ 1` for `g = 0`. The bilinear columns
267    /// `xy, yz, xz` are products of two such coordinates; by Leibniz over the
268    /// product, their `g`-th jet is bounded by `Σ_{i=0}^{g} C(g,i)·(2^i)·(2^{g−i})
269    /// = (2+2)^g = 4^g` (using `‖∂^i u‖ ≤ 2^i`, `|u| ≤ 1`). The bilinear columns
270    /// dominate, so the per-column sup is `4^g` (`g ≥ 1`). Bounds are global
271    /// constants — the chart box `lat ∈ [-π/2, π/2]` does not enlarge them.
272    fn value_sup(&self, chart: &ChartRegion) -> f64 {
273        chart.assert_valid();
274        1.0
275    }
276    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
277        chart.assert_valid();
278        4.0
279    }
280    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
281        chart.assert_valid();
282        16.0
283    }
284    fn third_sup(&self, chart: &ChartRegion) -> f64 {
285        chart.assert_valid();
286        64.0
287    }
288}
289
290impl BasisHessianLipschitz for AffineCoordinateEvaluator {
291    /// The affine basis `[1, t₁, …, t_d]` is degree ≤ 1: its first jet has unit
292    /// columns, and all second and third jets vanish. The value sup is
293    /// `max(1, ‖t‖)` over the chart, bounded by `1 + ‖t_c‖ + radius`.
294    fn value_sup(&self, chart: &ChartRegion) -> f64 {
295        let center_norm = chart.center.dot(&chart.center).sqrt();
296        1.0 + center_norm + chart.radius
297    }
298    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
299        chart.assert_valid();
300        1.0
301    }
302    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
303        chart.assert_valid();
304        0.0
305    }
306    fn third_sup(&self, chart: &ChartRegion) -> f64 {
307        chart.assert_valid();
308        0.0
309    }
310}
311
312impl BasisHessianLipschitz for EuclideanPatchEvaluator {
313    /// Monomials of total degree ≤ `max_degree` in `t ∈ ℝ^d`. Over the ball of
314    /// radius `R` about `t_c`, each coordinate is bounded by `ρ = ‖t_c‖∞ + R`.
315    /// A monomial `t^α` with `|α| = q` has `g`-th partials bounded (crudely) by
316    /// the descending-factorial coefficient `q·(q−1)···(q−g+1) ≤ q^g` times
317    /// `ρ^{max(q−g,0)}`, and there are at most `d^g` partial routings, so the
318    /// per-column `g`-th jet sup is `≤ d^g · D^g · ρ^{max(D−g,0)}` with
319    /// `D = max_degree`. Conservative; D is small for patch evaluators.
320    fn value_sup(&self, chart: &ChartRegion) -> f64 {
321        let rho = patch_rho(chart);
322        let d = self.max_degree as i32;
323        rho.powi(d).max(1.0)
324    }
325    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
326        patch_jet_sup(self.latent_dim, self.max_degree, chart, 1)
327    }
328    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
329        patch_jet_sup(self.latent_dim, self.max_degree, chart, 2)
330    }
331    fn third_sup(&self, chart: &ChartRegion) -> f64 {
332        patch_jet_sup(self.latent_dim, self.max_degree, chart, 3)
333    }
334}
335
336impl BasisHessianLipschitz for CylinderHarmonicEvaluator {
337    /// Cylinder `S¹ × ℝ` product basis `Φ_{c,l} = c(t₀)·l(t₁)`, the circle
338    /// (periodic harmonic) factor on axis 0 crossed with the monomial line
339    /// factor on axis 1. Because the two factors depend on disjoint coordinates,
340    /// the order-`g` coordinate jet in any cell is exactly
341    /// `c^{(k₀)}(t₀)·l^{(k₁)}(t₁)` with `k₀ + k₁ = g`, so the per-column sup is
342    /// the max over the split `k₀ + k₁ = g` of the product of the two per-axis
343    /// per-order sups: the circle factor contributes `1` at order 0 and
344    /// `(2π·H)^{k₀}` at order `k₀ ≥ 1` (trig magnitudes `≤ 1`); the line factor
345    /// contributes the monomial-patch sup `D^{k₁}·ρ^{max(D−k₁,0)}` (`D = line
346    /// degree`, `ρ = ‖t_c‖∞ + radius`). Bounds are global in the periodic axis
347    /// and chart-local in the line axis.
348    fn value_sup(&self, chart: &ChartRegion) -> f64 {
349        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 0)
350    }
351    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
352        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 1)
353    }
354    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
355        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 2)
356    }
357    fn third_sup(&self, chart: &ChartRegion) -> f64 {
358        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 3)
359    }
360}
361
362/// Per-column order-`g` jet sup of the cylinder product basis: the max over
363/// `k₀ + k₁ = g` of `circle_axis_sup(k₀) · line_axis_sup(k₁)`, where the circle
364/// axis sup is `(2π·H)^{k₀}` (`1` at `k₀ = 0`) and the line axis sup is the
365/// monomial-patch bound `D^{k₁}·ρ^{max(D−k₁,0)}` (`1` at `k₁ = 0`). See the
366/// [`CylinderHarmonicEvaluator`] doc comment for the derivation.
367pub(crate) fn cylinder_jet_sup(
368    circle_harmonics: usize,
369    line_degree: usize,
370    chart: &ChartRegion,
371    order: u32,
372) -> f64 {
373    let omega = std::f64::consts::TAU * circle_harmonics as f64;
374    let big_d = line_degree as f64;
375    let rho = patch_rho(chart);
376    let mut best = 0.0_f64;
377    for k0 in 0..=order {
378        let k1 = order - k0;
379        let circle = if k0 == 0 { 1.0 } else { omega.powi(k0 as i32) };
380        let line = if k1 == 0 {
381            rho.powi(line_degree as i32).max(1.0)
382        } else {
383            let residual = line_degree.saturating_sub(k1 as usize) as i32;
384            // `.max(1.0)` as in `patch_jet_sup`: for ρ < 1 a lower-degree line
385            // monomial dominates the k1-th derivative, so the bare `ρ^residual`
386            // underestimates the line-factor sup. The value case (k1==0) already
387            // clamps; this completes it for the derivative orders.
388            big_d.powi(k1 as i32) * rho.powi(residual).max(1.0)
389        };
390        best = best.max(circle * line);
391    }
392    best
393}
394
395/// Sup-norm radius `ρ = ‖t_c‖∞ + radius` of the chart (the coordinate magnitude
396/// bound used by the monomial-patch jet bounds).
397pub(crate) fn patch_rho(chart: &ChartRegion) -> f64 {
398    let center_inf = chart
399        .center
400        .iter()
401        .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
402    center_inf + chart.radius
403}
404
405/// Per-column `g`-th jet sup for a monomial patch of max degree `D` in `d`
406/// coordinates over the chart: `d^g · D^g · ρ^{max(D−g,0)}` (see the
407/// [`EuclideanPatchEvaluator`] doc comment for the derivation).
408pub(crate) fn patch_jet_sup(
409    latent_dim: usize,
410    max_degree: usize,
411    chart: &ChartRegion,
412    order: u32,
413) -> f64 {
414    let d = latent_dim as f64;
415    let big_d = max_degree as f64;
416    let rho = patch_rho(chart);
417    let residual_degree = max_degree.saturating_sub(order as usize) as i32;
418    // `.max(1.0)`: for ρ < 1 (small charts near the origin) the g-th jet sup is NOT
419    // dominated by the max-degree monomial `t^D` (whose g-th derivative ~ ρ^{D-g}
420    // shrinks with ρ) but by a LOWER-degree monomial whose g-th derivative is a
421    // larger constant — e.g. {1,t,t²}'s jacobian sup is the linear term's constant
422    // `1`, which exceeds `2ρ` when ρ < ½. Without the clamp the bound underestimates
423    // the true sup (numerically: D=3, ρ=0.1, g=1 → formula 0.03 vs true 1.0), which
424    // would make the certificate's Lipschitz `L` too small → a FALSE certificate.
425    // `D^g · max(ρ^{D-g}, 1)` upper-bounds `max_{q∈[g,D]} (q!/(q-g)!)·ρ^{q-g}` for
426    // all ρ (the `q=g` term gives `g! ≤ D^g`, the `q=D` term gives `≤ D^g·ρ^{D-g}`).
427    d.powi(order as i32) * big_d.powi(order as i32) * rho.powi(residual_degree).max(1.0)
428}
429
430impl BasisHessianLipschitz for DuchonCoordinateEvaluator {
431    /// Radial-kernel basis `Φ_m(t) = φ(r_m)`, `r_m = ‖t − c_m‖`, plus a
432    /// polynomial nullspace block. For the cubic Duchon kernel `φ(r) = r³` the
433    /// radial derivatives are `φ' = 3r²`, `φ'' = 6r`, `φ''' = 6`. The chain rule
434    /// to coordinate jets introduces `1/r` factors through the unit radial
435    /// direction `u = (t − c)/r` and the projector `(I − uuᵀ)/r`, so over a
436    /// chart the jets are bounded by combining the radial-derivative magnitudes
437    /// at the worst-case radius with the inverse-radius tail at the chart's
438    /// EXCLUSION radius `r_min` (the closest a chart point gets to any center):
439    ///
440    /// ```text
441    /// ‖∇φ‖    ≤ |φ'|                              ≤ 3 r_max²
442    /// ‖∇²φ‖   ≤ |φ''| + |φ'|/r                    ≤ 6 r_max + 3 r_max²/r_min
443    /// ‖∇³φ‖   ≤ |φ'''| + 3|φ''|/r + 3|φ'|/r²      ≤ 6 + 18 r_max/r_min + 9 r_max²/r_min²
444    /// ```
445    ///
446    /// (the `1/r`, `1/r²` tails are bounded by `1/r_min`, `1/r_min²`). The
447    /// polynomial nullspace block is degree ≤ `order`; its jets are bounded like
448    /// the monomial patch with `D = order`. The per-column sup is the max of the
449    /// kernel and polynomial bounds. The `r³` kernel is itself `C²` (no
450    /// singularity) so these tails are conservative but finite for any
451    /// `r_min > 0`; the atlas refines charts to keep `r_min` bounded away from 0.
452    fn value_sup(&self, chart: &ChartRegion) -> f64 {
453        let r_max = chart.radial_r_max.unwrap_or(chart.radius);
454        let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 0);
455        (r_max.powi(3)).max(poly)
456    }
457    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
458        let r_max = chart.radial_r_max.unwrap_or(chart.radius);
459        let kernel = 3.0 * r_max * r_max;
460        let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 1);
461        kernel.max(poly)
462    }
463    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
464        let r_max = chart.radial_r_max.unwrap_or(chart.radius);
465        let r_min = chart
466            .exclusion_r_min
467            .unwrap_or(chart.radius)
468            .max(f64::MIN_POSITIVE);
469        let kernel = 6.0 * r_max + 3.0 * r_max * r_max / r_min;
470        let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 2);
471        kernel.max(poly)
472    }
473    fn third_sup(&self, chart: &ChartRegion) -> f64 {
474        let r_max = chart.radial_r_max.unwrap_or(chart.radius);
475        let r_min = chart
476            .exclusion_r_min
477            .unwrap_or(chart.radius)
478            .max(f64::MIN_POSITIVE);
479        let kernel = 6.0 + 18.0 * r_max / r_min + 9.0 * r_max * r_max / (r_min * r_min);
480        let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 3);
481        kernel.max(poly)
482    }
483}
484
485/// Polynomial-block degree of a Duchon nullspace order, used to bound the
486/// nullspace columns like a monomial patch.
487trait DuchonOrderDegree {
488    fn order_degree(&self) -> usize;
489}
490
491impl DuchonOrderDegree for DuchonCoordinateEvaluator {
492    fn order_degree(&self) -> usize {
493        match self.order {
494            gam_terms::basis::DuchonNullspaceOrder::Zero => 0,
495            gam_terms::basis::DuchonNullspaceOrder::Linear => 1,
496            gam_terms::basis::DuchonNullspaceOrder::Degree(d) => d,
497        }
498    }
499}
500
501/// Per-column `g`-th jet sup of the Duchon polynomial nullspace block, treated
502/// as a monomial patch of degree `order_degree`.
503pub(crate) fn duchon_poly_jet_sup(
504    latent_dim: usize,
505    order_degree: usize,
506    chart: &ChartRegion,
507    order: u32,
508) -> f64 {
509    if order_degree == 0 {
510        return if order == 0 { 1.0 } else { 0.0 };
511    }
512    patch_jet_sup(latent_dim, order_degree, chart, order)
513}
514
515/// Decoder magnitude `Σ_m ‖B_{m,:}‖₂` of an atom's frozen decoder block: the
516/// factor that converts a per-column `Φ`-jet sup `B_g` into a reconstruction
517/// jet sup `‖∂^g m‖ ≤ |z|·decoder_row_norm_sum·B_g`.
518pub(crate) fn decoder_row_norm_sum(decoder: ArrayView2<'_, f64>) -> f64 {
519    let mut acc = 0.0;
520    for row in decoder.rows() {
521        acc += row.dot(&row).sqrt();
522    }
523    acc
524}
525
526#[derive(Debug, Clone, Copy)]
527pub(crate) struct ReconstructionJetSups {
528    pub(crate) value: f64,
529    pub(crate) jacobian: f64,
530    pub(crate) hessian: f64,
531    pub(crate) third: f64,
532}
533
534pub(crate) fn pair_trig_decoder_sup(
535    sin_row: ArrayView1<'_, f64>,
536    cos_row: ArrayView1<'_, f64>,
537) -> f64 {
538    let aa = sin_row.dot(&sin_row);
539    let bb = cos_row.dot(&cos_row);
540    let ab = sin_row.dot(&cos_row);
541    let trace = aa + bb;
542    let disc = ((aa - bb) * (aa - bb) + 4.0 * ab * ab).sqrt();
543    (0.5 * (trace + disc)).sqrt()
544}
545
546pub(crate) fn periodic_reconstruction_jet_sups(
547    decoder: ArrayView2<'_, f64>,
548) -> ReconstructionJetSups {
549    let mut value = 0.0;
550    let mut jacobian = 0.0;
551    let mut hessian = 0.0;
552    let mut third = 0.0;
553    if decoder.nrows() > 0 {
554        value += decoder.row(0).dot(&decoder.row(0)).sqrt();
555    }
556    let harmonics = decoder.nrows().saturating_sub(1) / 2;
557    for h in 1..=harmonics {
558        let sin_idx = 2 * h - 1;
559        let cos_idx = 2 * h;
560        let amp = pair_trig_decoder_sup(decoder.row(sin_idx), decoder.row(cos_idx));
561        let omega = std::f64::consts::TAU * h as f64;
562        value += amp;
563        jacobian += omega * amp;
564        hessian += omega.powi(2) * amp;
565        third += omega.powi(3) * amp;
566    }
567    for row in (1 + 2 * harmonics)..decoder.nrows() {
568        let amp = decoder.row(row).dot(&decoder.row(row)).sqrt();
569        value += amp;
570        let omega = std::f64::consts::TAU * harmonics.max(1) as f64;
571        jacobian += omega * amp;
572        hessian += omega.powi(2) * amp;
573        third += omega.powi(3) * amp;
574    }
575    ReconstructionJetSups {
576        value,
577        jacobian,
578        hessian,
579        third,
580    }
581}
582
583pub(crate) fn reconstruction_jet_sups(
584    atom: &SaeManifoldAtom,
585    sups: JetSups,
586) -> ReconstructionJetSups {
587    if matches!(atom.basis_kind, crate::manifold::SaeAtomBasisKind::Periodic) {
588        periodic_reconstruction_jet_sups(atom.decoder_coefficients.view())
589    } else {
590        let decoder_norm_sum = decoder_row_norm_sum(atom.decoder_coefficients.view());
591        ReconstructionJetSups {
592            value: decoder_norm_sum * sups.value,
593            jacobian: decoder_norm_sum * sups.jacobian,
594            hessian: decoder_norm_sum * sups.hessian,
595            third: decoder_norm_sum * sups.third,
596        }
597    }
598}
599
600/// The Hessian-Lipschitz constant `L` of the per-row encode objective `f_k` on
601/// a chart, assembled in closed form from the basis jet sups and the decoder /
602/// amplitude / target magnitudes. See the module docs for the derivation:
603///
604/// ```text
605/// L ≤ 3·‖J_m‖·‖∇²m‖ + ‖r‖·‖∇³m‖ + L_prior,
606/// ‖∂^g m‖ ≤ |z|·S_B·B_g,   S_B = Σ_m ‖B_{m,:}‖,
607/// ‖r‖ ≤ ‖x‖ + |z|·S_B·B_0,
608/// ```
609///
610/// `prior_lipschitz` is the caller-supplied closed-form `L_prior` of the
611/// ARD/von-Mises coordinate prior (`0.0` if no prior is active on the encode).
612pub(crate) fn hessian_lipschitz_constant(
613    recon_sups: ReconstructionJetSups,
614    amplitude: f64,
615    target_norm: f64,
616    prior_lipschitz: f64,
617) -> f64 {
618    let z = amplitude.abs();
619    let m_jac = z * recon_sups.jacobian;
620    let m_hess = z * recon_sups.hessian;
621    let m_third = z * recon_sups.third;
622    let recon_value = z * recon_sups.value;
623    let r_norm = target_norm + recon_value;
624    3.0 * m_jac * m_hess + r_norm * m_third + prior_lipschitz
625}
626
627/// One offline-certified chart: a center, its Kantorovich constants, and the
628/// certified Newton-convergence radius `R_c` solved from `h = β·η·L ≤ ½` at the
629/// worst-case in-chart start.
630#[derive(Debug, Clone)]
631pub struct CertifiedChart {
632    pub region: ChartRegion,
633    /// Closed-form Hessian-Lipschitz constant `L` over the chart.
634    pub lipschitz: f64,
635    /// `β = ‖F'(t_c)⁻¹‖` at the chart center (worst-case in-chart start uses
636    /// the center's curvature; the radius is solved so the certificate holds for
637    /// any start in the ball).
638    pub beta_center: f64,
639    /// Certified Newton radius: starts within `radius` of `t_c` satisfy `h ≤ ½`.
640    pub certified_radius: f64,
641    /// Distilled amortized-encoder Jacobian for this chart (#1026 ladder item 3).
642    ///
643    /// The exact encode map `x ↦ t` solves `F(t; x) = J_m(t)ᵀ(m(t) − x) = 0`. By
644    /// the implicit function theorem its derivative at the converged root is
645    /// `dt/dx = −(∂_t F)⁻¹ (∂_x F) = H⁻¹ J_m` (since `∂_x F = −J_m`), so the
646    /// first-order Taylor expansion of the encode map about this chart's center
647    /// `t_c` is the closed-form AFFINE predictor
648    ///
649    /// ```text
650    /// t(x) ≈ t_c + (1/z) · A₁ · (x − z · m₁(t_c)),   A₁ = (J₁ᵀJ₁ + ridge·I)⁻¹ J₁,
651    /// ```
652    ///
653    /// with `J₁ = Bᵀ J_Φ(t_c)` and `m₁(t_c) = BᵀΦ(t_c)` the AMPLITUDE-1
654    /// reconstruction jets (the amplitude `z` factors out analytically, so the
655    /// stored Jacobian is amplitude-free). This is the DISTILLED amortized
656    /// encoder of the #1026 thread: the per-row Hessian factorization + Newton
657    /// iteration is moved OFFLINE into this `d × p` matrix, leaving a single
658    /// `O(d·p)` mat-vec online — no per-row eigendecomposition, no second-jet
659    /// evaluation. The Kantorovich certificate is still evaluated AT the
660    /// predicted start, so the amortized prediction is trusted iff `h ≤ ½` and an
661    /// uncertified row still routes to the exact multi-start solve (the encoder
662    /// approximates inference, the certificate keeps it honest — the thread's
663    /// "encoder + certificate-gated exact fallback" deployment). `None` when the
664    /// center's Gauss–Newton block is singular (no certifiable amortization).
665    pub amortized_jacobian: Option<Array2<f64>>,
666    /// Amplitude-1 chart-center reconstruction `m₁(t_c) = BᵀΦ(t_c)` (length `p`),
667    /// the anchor the amortized predictor expands the encode map around.
668    pub recon_center: Array1<f64>,
669    /// Precomputed affine-predictor CONSTANT term `base = t_c − A₁·m₁(t_c)` (length
670    /// `d`), so the online amortized encode of a row `x` at amplitude `z` is the
671    /// single mat-vec `t̂ = base + (1/z)·A₁·x` with NO per-row `A₁·m₁` recompute.
672    /// Hoisting this atom-static term offline is what lets the massive-K index-routed
673    /// fast paths run a single allocation-free pass over rows (rather than a per-atom
674    /// GEMM sub-batch that degenerates to one row per group when `K ≫ N`). `None`
675    /// exactly when `amortized_jacobian` is `None` (singular Gauss–Newton block).
676    pub amortized_base: Option<Array1<f64>>,
677}
678
679/// The per-atom encode atlas: a set of certified charts covering the atom's
680/// coordinate domain, plus the decoder/amplitude scaling needed to recompute a
681/// per-row certificate online.
682#[derive(Debug, Clone)]
683pub struct AtomEncodeAtlas {
684    pub atom_index: usize,
685    pub latent_dim: usize,
686    pub decoder_norm_sum: f64,
687    pub charts: Vec<CertifiedChart>,
688}
689
690/// Result of a certified encode over a batch of rows, carrying the honesty
691/// flag: how many rows could NOT be certified and were flagged for the exact
692/// multi-start fallback (issue #1010 — no approximation enters silently).
693#[derive(Debug, Clone)]
694pub struct EncodeResult {
695    /// Per-row encoded latent coordinates (`n_rows × latent_dim`).
696    pub coords: Array2<f64>,
697    /// Per-row certificate: `true` ⇒ the row's start satisfied `h ≤ ½` and the
698    /// 1–2 Newton steps are exact-into-the-certified-ball; `false` ⇒ flagged.
699    pub certified: Vec<bool>,
700    /// Count of rows that could not be certified. These ride the payload so the
701    /// caller routes them to the exact multi-start encode — honesty, never
702    /// silent. Equals `certified.iter().filter(|c| !**c).count()`.
703    pub encode_uncertified_count: usize,
704}
705
706impl EncodeResult {
707    pub(crate) fn from_rows(coords: Array2<f64>, certified: Vec<bool>) -> Self {
708        let encode_uncertified_count = certified.iter().filter(|c| !**c).count();
709        Self {
710            coords,
711            certified,
712            encode_uncertified_count,
713        }
714    }
715}
716
717/// Per-row Kantorovich certificate at a start `t₀` for one atom encode.
718#[derive(Debug, Clone, Copy)]
719pub struct RowCertificate {
720    pub beta: f64,
721    pub eta: f64,
722    pub lipschitz: f64,
723    /// `h = β·η·L`. The row is certified iff `h ≤ ½`.
724    pub h: f64,
725}
726
727impl RowCertificate {
728    pub fn certified(&self) -> bool {
729        self.h.is_finite() && self.h <= KANTOROVICH_THRESHOLD
730    }
731}
732
733#[derive(Debug, Clone)]
734struct CertifiedEncodeProbe {
735    coord: Array1<f64>,
736    final_cert: RowCertificate,
737}
738
739/// Canonical flat-axis polynomial degree of a cylinder `S¹ × ℝ` atom — the
740/// degree the topology-race builder ([`gam_solve::structure_harvest`]) uses
741/// for the line axis (`CylinderHarmonicEvaluator::new(_, 2)`). The encode atlas
742/// recovers the circle harmonic count from the basis width using this degree, so
743/// the two must agree.
744pub(crate) const SAE_CYLINDER_LINE_DEGREE: usize = 2;
745
746/// Build a basis-family handle for one atom from its [`SaeManifoldAtom`]. The
747/// atlas needs to evaluate the jet sups, which live on the concrete evaluator
748/// types; the atom carries the evaluator as `Arc<dyn SaeBasisEvaluator>`, so we
749/// reconstruct the family bound from the atom's basis kind + width + centers.
750pub(crate) fn family_jet_sups(
751    atom: &SaeManifoldAtom,
752    chart: &ChartRegion,
753) -> Result<JetSups, String> {
754    use crate::manifold::SaeAtomBasisKind::*;
755    let m = atom.basis_size();
756    let d = atom.latent_dim;
757    let sups = match &atom.basis_kind {
758        Periodic => {
759            let ev = PeriodicHarmonicEvaluator::new(m)?;
760            JetSups::from_family(&ev, chart)
761        }
762        Torus => {
763            // Torus basis width is `(2H+1)^d`; recover the per-axis harmonic
764            // count `H` from `axis_m = m^(1/d)` rather than a sum formula.
765            let axis_m = integer_root(m, d.max(1));
766            let num_harmonics = axis_m.saturating_sub(1) / 2;
767            let ev = TorusHarmonicEvaluator::new(d, num_harmonics.max(1))?;
768            JetSups::from_family(&ev, chart)
769        }
770        Sphere => {
771            let ev = SphereChartEvaluator;
772            JetSups::from_family(&ev, chart)
773        }
774        Cylinder => {
775            // Cylinder width is `(2H+1)·(D+1)` with the canonical flat-axis
776            // degree `D = SAE_CYLINDER_LINE_DEGREE` (the harvest convention).
777            // Recover the per-axis circle harmonic count `H` from
778            // `2H+1 = m/(D+1)`.
779            let ml = SAE_CYLINDER_LINE_DEGREE + 1;
780            if d != 2 || ml == 0 || m % ml != 0 {
781                return Err(format!(
782                    "EncodeAtlas: Cylinder atom requires latent_dim == 2 and width divisible by {ml}; got dim={d}, m={m}"
783                ));
784            }
785            let axis_mc = m / ml;
786            let h = axis_mc.saturating_sub(1) / 2;
787            let ev = CylinderHarmonicEvaluator::new(h.max(1), SAE_CYLINDER_LINE_DEGREE)?;
788            JetSups::from_family(&ev, chart)
789        }
790        Linear | EuclideanPatch | Poincare => {
791            // The patch width fixes max_degree implicitly; bound by a degree that
792            // covers the column count (conservative). Degree d-patch column count
793            // grows fast; we recover the smallest degree whose patch is ≥ m.
794            // Poincare atoms use the same tangent-coordinate polynomial decoder;
795            // their intrinsic smoothness differs in the penalty, not in Phi(t).
796            let degree = euclidean_patch_degree(d, m);
797            let ev = EuclideanPatchEvaluator::new(d, degree)?;
798            JetSups::from_family(&ev, chart)
799        }
800        Duchon => {
801            // UNSOUND — DO NOT TRUST for a certificate (F2/F3). This bound
802            // hard-codes cubic `φ(r) = r³` radial jets and a single origin center
803            // (`duchon_centers_from_atom`), but the real Duchon kernel is the
804            // polyharmonic `c·r^(2m−d)` (with log variants) over the atom's
805            // data-placed centers — so it can UNDER-estimate L and issue a FALSE
806            // certificate (the module's own warning). The atom does not expose its
807            // real order / center matrix / scaling to this crate, so no sound bound
808            // is available. `build_atom_atlas_from_centers` therefore REFUSES to
809            // certify Duchon atoms (emits uncertified charts → exact-encode
810            // fallback) and never reaches this arm; it is retained only so the
811            // family dispatch is total. If a future change threads the real
812            // order/centers here, replace this with the true `φ_{m,d}` jet bounds.
813            let centers = duchon_centers_from_atom(atom);
814            let conservative_m = m.max(1);
815            let ev = DuchonCoordinateEvaluator::new(centers, conservative_m)?;
816            JetSups::from_family(&ev, chart)
817        }
818        Precomputed(name) => {
819            return Err(format!(
820                "EncodeAtlas: precomputed basis '{name}' has no closed-form jet sup; route to exact encode"
821            ));
822        }
823    };
824    Ok(sups)
825}
826
827/// Smallest monomial-patch degree whose column count covers `m` basis columns.
828pub(crate) fn euclidean_patch_degree(latent_dim: usize, m: usize) -> usize {
829    // Column count of a degree-D patch in d vars is C(d+D, D). Grow D until it
830    // covers m; cap at m so a degenerate width still terminates.
831    let mut degree = 0usize;
832    while patch_column_count(latent_dim, degree) < m && degree < m {
833        degree += 1;
834    }
835    degree
836}
837
838/// Largest integer `a` with `a^k ≤ n` (the floor of the `k`-th root). Used to
839/// recover the per-axis harmonic width `axis_m` from a torus basis width
840/// `m = axis_m^d`.
841pub(crate) fn integer_root(n: usize, k: usize) -> usize {
842    if k == 0 {
843        return 1;
844    }
845    if k == 1 {
846        return n;
847    }
848    let mut a = 1usize;
849    loop {
850        let next = a + 1;
851        let mut pow: u128 = 1;
852        let mut overflow = false;
853        for _ in 0..k {
854            pow = pow.saturating_mul(next as u128);
855            if pow > n as u128 {
856                overflow = true;
857                break;
858            }
859        }
860        if overflow {
861            return a;
862        }
863        a = next;
864    }
865}
866
867pub(crate) fn patch_column_count(latent_dim: usize, degree: usize) -> usize {
868    // C(d + D, D)
869    let mut num = 1u128;
870    let mut den = 1u128;
871    for i in 1..=degree {
872        num *= (latent_dim + i) as u128;
873        den *= i as u128;
874    }
875    (num / den) as usize
876}
877
878/// Recover Duchon centers from an atom: when the evaluator is unavailable the
879/// atlas falls back to the atom's own latent-coordinate hull as the center set,
880/// which only affects the radial-tail bound conservatively.
881pub(crate) fn duchon_centers_from_atom(atom: &SaeManifoldAtom) -> Array2<f64> {
882    // One center at the origin in latent_dim space is a sound conservative
883    // default: the chart's own r_min / r_max bracket the true radial range.
884    Array2::<f64>::zeros((1, atom.latent_dim.max(1)))
885}
886
887/// The four per-column jet sups of a basis family over a chart.
888#[derive(Debug, Clone, Copy)]
889pub(crate) struct JetSups {
890    pub(crate) value: f64,
891    pub(crate) jacobian: f64,
892    pub(crate) hessian: f64,
893    pub(crate) third: f64,
894}
895
896impl JetSups {
897    pub(crate) fn from_family<B: BasisHessianLipschitz>(family: &B, chart: &ChartRegion) -> Self {
898        Self {
899            value: family.value_sup(chart),
900            jacobian: family.jacobian_sup(chart),
901            hessian: family.hessian_sup(chart),
902            third: family.third_sup(chart),
903        }
904    }
905}
906
907/// Evaluate one atom's encode objective gradient `F(t) = ∇f_k(t)` and the FULL
908/// Hessian `F'(t) = ∇²f_k(t)` at a single coordinate `t`, for a single target
909/// row `x` and fixed amplitude `z`. With `m(t) = z·BᵀΦ(t)`, `r = m − x`,
910/// `J_m = z·Bᵀ J_Φ`:
911///
912/// ```text
913/// g_t[a]   = J_m[a] · r                                  (= ∇f)
914/// H_tt[a,b] = J_m[a] · J_m[b] + r · ∂²m/∂t_a∂t_b         (= ∇²f, FULL Hessian)
915/// ```
916///
917/// The certificate uses the FULL Hessian rather than the Gauss-Newton block
918/// `J_mᵀ J_m`. This is the principled choice for Newton–Kantorovich: the
919/// theorem certifies convergence of Newton on `F = ∇f` to the unique nearby
920/// ROOT of `∇f`, but a root of `∇f` can be a maximum. The full Hessian is
921/// positive-definite exactly on the genuine-minimum basin, so requiring
922/// `λ_min(H) > 0` (finite `β`) is what flags a start that would otherwise let
923/// Gauss-Newton march into the wrong root (e.g. the circle antipode, a local
924/// max where `∇f = 0` but the full curvature is negative). The residual term
925/// needs the basis second jet `∂²Φ/∂t²`; an evaluator without one returns
926/// `None`, and the row is flagged (no silent Gauss-Newton fallback).
927///
928/// The Hessian returned is the TRUE `∇²f_k` — no Levenberg ridge is added
929/// (F2). The Kantorovich certificate (`row_certificate`) and its `λ_min(H) > 0`
930/// saddle gate must see the genuine field: a ridged `H + λI` certifies neither
931/// the original objective nor a consistently regularized one (a
932/// locally-constant reconstruction has `H = 0`, whose ridged `λI` would falsely
933/// certify a non-isolated, non-unique root). Ridge stays only in the
934/// UNCERTIFIED amortized predictor (`center_amortized_jacobian`/`center_beta`).
935pub(crate) fn encode_grad_hess(
936    atom: &SaeManifoldAtom,
937    evaluator: &dyn SaeBasisEvaluator,
938    t: ArrayView1<'_, f64>,
939    x: ArrayView1<'_, f64>,
940    amplitude: f64,
941) -> Result<Option<(Array1<f64>, Array2<f64>)>, String> {
942    let d = atom.latent_dim;
943    let p = atom.output_dim();
944    let m = atom.basis_size();
945    let coords = t.to_shape((1, d)).map_err(|e| e.to_string())?.to_owned();
946    let (phi, jet) = evaluator.evaluate(coords.view())?;
947    if phi.dim() != (1, m) {
948        return Err(format!(
949            "encode_grad_hess: evaluator returned phi {:?}, expected (1, {m})",
950            phi.dim()
951        ));
952    }
953    let decoder = &atom.decoder_coefficients;
954    // Reconstruction m(t) = z · Bᵀ Φ(t)  ∈ ℝᵖ.
955    let mut recon = Array1::<f64>::zeros(p);
956    for basis_col in 0..m {
957        let phi_v = phi[[0, basis_col]];
958        if phi_v == 0.0 {
959            continue;
960        }
961        for out in 0..p {
962            recon[out] += amplitude * phi_v * decoder[[basis_col, out]];
963        }
964    }
965    let residual = &recon - &x;
966    // J_m[axis] = z · Bᵀ (∂Φ/∂t_axis)  ∈ ℝᵖ.
967    let mut jm = Array2::<f64>::zeros((d, p));
968    for axis in 0..d {
969        for basis_col in 0..m {
970            let dphi = jet[[0, basis_col, axis]];
971            if dphi == 0.0 {
972                continue;
973            }
974            for out in 0..p {
975                jm[[axis, out]] += amplitude * dphi * decoder[[basis_col, out]];
976            }
977        }
978    }
979    // The full-Hessian residual term needs ∂²Φ/∂t². No second jet ⇒ no
980    // certificate (flag), never a silent Gauss-Newton substitute.
981    let second = match evaluator.second_jet_dyn(coords.view()) {
982        Some(result) => result?,
983        None => return Ok(None),
984    };
985    // Residual · decoder-row `r·B_{basis,:}` is INDEPENDENT of the (a,b) axes, yet
986    // the old code recomputed it `d²` times inside the Hessian double loop. Hoist it
987    // to one O(m·p) pass so the per-axis curvature term is a cheap O(m) dot.
988    let mut rd = vec![0.0_f64; m];
989    for (basis_col, rd_col) in rd.iter_mut().enumerate() {
990        let mut dot = 0.0;
991        for out in 0..p {
992            dot += residual[out] * decoder[[basis_col, out]];
993        }
994        *rd_col = dot;
995    }
996    // g_t[axis] = J_m[axis] · r ;  H_tt[a,b] = J_m[a]·J_m[b] + r·∂²m/∂t_a∂t_b.
997    // The full Hessian is symmetric (Gauss-Newton block + symmetric second jet), so
998    // compute the upper triangle and mirror — half the curvature work.
999    let mut g = Array1::<f64>::zeros(d);
1000    let mut h = Array2::<f64>::zeros((d, d));
1001    for a in 0..d {
1002        let ja = jm.row(a);
1003        g[a] = ja.dot(&residual);
1004        for b in a..d {
1005            // Gauss-Newton block.
1006            let mut hab = ja.dot(&jm.row(b));
1007            // Residual · second-jet curvature: r · ∂²m_{ab},
1008            // ∂²m_{ab}[out] = z · Σ_basis (∂²Φ/∂t_a∂t_b) · B[basis, out].
1009            let mut curv = 0.0;
1010            for basis_col in 0..m {
1011                let d2phi = second[[0, basis_col, a, b]];
1012                if d2phi == 0.0 {
1013                    continue;
1014                }
1015                curv += amplitude * d2phi * rd[basis_col];
1016            }
1017            hab += curv;
1018            h[[a, b]] = hab;
1019            h[[b, a]] = hab;
1020        }
1021    }
1022    // NO ridge: the certificate must use the TRUE Hessian (F2). See the doc above.
1023    Ok(Some((g, h)))
1024}
1025
1026/// Operator-norm of `H⁻¹` (i.e. `β = 1/λ_min(H)`) and the Newton step
1027/// `δ = −H⁻¹ g` with `η = ‖δ‖`, from a symmetric PSD `H` and gradient `g`.
1028/// Returns `None` when `H` is numerically singular (λ_min ≤ 0) — an
1029/// uncertifiable start.
1030pub(crate) fn beta_eta_newton(
1031    h: ArrayView2<'_, f64>,
1032    g: ArrayView1<'_, f64>,
1033) -> Result<Option<(f64, f64, Array1<f64>)>, String> {
1034    // Closed-form fast paths for the tiny latent dims that dominate SAE atoms
1035    // (`d = 1, 2`), avoiding a faer eigendecomposition + its heap allocations on
1036    // the hottest per-row Newton inner loop. `β = 1/λ_min(H)`, `δ = −H⁻¹g`, and the
1037    // `λ_min ≤ 0` gate are computed directly; the symmetric-`H` reads mirror
1038    // `eigh(Side::Lower)` (which uses the lower triangle) exactly.
1039    let d = h.nrows();
1040    if d == 1 {
1041        let h00 = h[[0, 0]];
1042        if !(h00.is_finite() && h00 > 0.0) {
1043            return Ok(None);
1044        }
1045        let delta0 = -g[0] / h00;
1046        let mut delta = Array1::<f64>::zeros(1);
1047        delta[0] = delta0;
1048        return Ok(Some((1.0 / h00, delta0.abs(), delta)));
1049    }
1050    if d == 2 {
1051        // Symmetric H = [[a, b], [b, c]] read from the lower triangle.
1052        let a = h[[0, 0]];
1053        let b = h[[1, 0]];
1054        let c = h[[1, 1]];
1055        let tr = a + c;
1056        let det = a * c - b * b;
1057        // λ_min = ½(tr − √((a−c)² + 4b²)); ≥ 0 ⇒ H PSD, > 0 ⇒ PD.
1058        let disc = ((a - c) * (a - c) + 4.0 * b * b).max(0.0).sqrt();
1059        let lambda_min = 0.5 * (tr - disc);
1060        if !(lambda_min.is_finite() && lambda_min > 0.0) {
1061            return Ok(None);
1062        }
1063        // δ = −H⁻¹g with H⁻¹ = [[c, −b], [−b, a]] / det (det = λ_min·λ_max > 0).
1064        let inv_det = 1.0 / det;
1065        let g0 = g[0];
1066        let g1 = g[1];
1067        let d0 = -(c * g0 - b * g1) * inv_det;
1068        let d1 = -(a * g1 - b * g0) * inv_det;
1069        if !(d0.is_finite() && d1.is_finite()) {
1070            return Ok(None);
1071        }
1072        let mut delta = Array1::<f64>::zeros(2);
1073        delta[0] = d0;
1074        delta[1] = d1;
1075        let eta = (d0 * d0 + d1 * d1).sqrt();
1076        return Ok(Some((1.0 / lambda_min, eta, delta)));
1077    }
1078    let (vals, vecs) = h
1079        .eigh(Side::Lower)
1080        .map_err(|e| format!("beta_eta_newton: eigh failed: {e:?}"))?;
1081    let lambda_min = vals.iter().cloned().fold(f64::INFINITY, f64::min);
1082    if !(lambda_min.is_finite() && lambda_min > 0.0) {
1083        return Ok(None);
1084    }
1085    let beta = 1.0 / lambda_min;
1086    // Newton step δ = −H⁻¹ g via the eigendecomposition: δ = −Σ_i (vᵢᵀg/λᵢ) vᵢ.
1087    let mut delta = Array1::<f64>::zeros(d);
1088    for (col, &lam) in vals.iter().enumerate() {
1089        if lam <= 0.0 {
1090            return Ok(None);
1091        }
1092        let vi = vecs.column(col);
1093        let coeff = vi.dot(&g) / lam;
1094        for row in 0..d {
1095            delta[row] -= coeff * vi[row];
1096        }
1097    }
1098    let eta = delta.dot(&delta).sqrt();
1099    Ok(Some((beta, eta, delta)))
1100}
1101
1102/// Compute the per-row Kantorovich certificate for encoding target row `x`
1103/// against atom `atom` at start coordinate `t₀`, with fixed amplitude `z` and
1104/// the chart's closed-form Lipschitz constant `lipschitz`. Returns the
1105/// certificate AND the Newton step `δ = −H⁻¹ g` so the caller can advance.
1106pub fn row_certificate(
1107    atom: &SaeManifoldAtom,
1108    evaluator: &dyn SaeBasisEvaluator,
1109    t0: ArrayView1<'_, f64>,
1110    x: ArrayView1<'_, f64>,
1111    amplitude: f64,
1112    lipschitz: f64,
1113) -> Result<(RowCertificate, Array1<f64>), String> {
1114    let uncertified = || {
1115        (
1116            RowCertificate {
1117                beta: f64::INFINITY,
1118                eta: f64::INFINITY,
1119                lipschitz,
1120                h: f64::INFINITY,
1121            },
1122            Array1::<f64>::zeros(atom.latent_dim),
1123        )
1124    };
1125    // No second jet ⇒ no full Hessian ⇒ uncertifiable (flag).
1126    let Some((g, h)) = encode_grad_hess(atom, evaluator, t0, x, amplitude)? else {
1127        return Ok(uncertified());
1128    };
1129    match beta_eta_newton(h.view(), g.view())? {
1130        Some((beta, eta, delta)) => {
1131            let cert = RowCertificate {
1132                beta,
1133                eta,
1134                lipschitz,
1135                h: beta * eta * lipschitz,
1136            };
1137            Ok((cert, delta))
1138        }
1139        // Indefinite / negative-curvature full Hessian: the start is at or past
1140        // a basin boundary (a max/saddle of f), not the minimum basin — flag.
1141        None => Ok(uncertified()),
1142    }
1143}
1144
1145fn uncertified_certificate(lipschitz: f64) -> RowCertificate {
1146    RowCertificate {
1147        beta: f64::INFINITY,
1148        eta: f64::INFINITY,
1149        lipschitz,
1150        h: f64::INFINITY,
1151    }
1152}
1153
1154fn refine_certified_start(
1155    atom: &SaeManifoldAtom,
1156    evaluator: &dyn SaeBasisEvaluator,
1157    mut t: Array1<f64>,
1158    x: ArrayView1<'_, f64>,
1159    amplitude: f64,
1160    lipschitz: f64,
1161    newton_steps: usize,
1162    initial_cert: RowCertificate,
1163    mut delta: Array1<f64>,
1164) -> Result<Option<CertifiedEncodeProbe>, String> {
1165    assert!(initial_cert.certified());
1166    let mut final_cert = initial_cert;
1167    for _ in 0..newton_steps {
1168        // Convergence early-exit: the pending Newton step is below the coordinate
1169        // ULP scale, so `t + δ == t` to f64 resolution — the certified root is
1170        // reached and the remaining fixed-budget steps would only re-accumulate
1171        // round-off. This is where the well-conditioned quadratic Newton tail's
1172        // redundant `evaluate` + `second_jet` work is eliminated.
1173        if delta.dot(&delta).sqrt() <= NEWTON_REFINE_CONVERGED_EPS * (1.0 + t.dot(&t).sqrt()) {
1174            break;
1175        }
1176        t = &t + &delta;
1177        let (cert, next_delta) =
1178            row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz)?;
1179        if !cert.certified() {
1180            return Ok(None);
1181        }
1182        final_cert = cert;
1183        delta = next_delta;
1184    }
1185    Ok(Some(CertifiedEncodeProbe {
1186        coord: t,
1187        final_cert,
1188    }))
1189}
1190
1191/// Certify an encode probe from `t_start`, navigating into the Kantorovich basin
1192/// first if needed (#1154/#1026). The Kantorovich quantity `h = β·η·L` scales with
1193/// amplitude through `L`, so at unit amplitude a positive-definite chart-center /
1194/// distilled start can sit OUTSIDE the certified ball (`h > ½`). Rather than
1195/// flagging it uncertified immediately — which made the encoder certify ZERO
1196/// held-out rows at amplitude 1.0 and fall back to the exact solve for everything —
1197/// take plain Newton steps toward the root, re-certifying at each iterate, while
1198/// the Kantorovich quantity `h = β·η·L` keeps CONTRACTING toward the ½ bound. The
1199/// certificate at the landing point is a full Kantorovich guarantee from there
1200/// (`h ≤ ½` ⇒ Newton converges to the in-ball root), so this only ever WIDENS the
1201/// certified set; it never certifies a non-convergent start.
1202///
1203/// Termination is the natural Newton stopping rule — there is no arbitrary step
1204/// budget. The warm-up stops and flags for the exact fallback when either the start
1205/// is not steppable (indefinite / non-finite Hessian — at or past a basin boundary)
1206/// or a step fails to reduce `h` (the iterate is not approaching a certifiable
1207/// in-chart root: its root lies outside this chart's valid Lipschitz region, or the
1208/// start was past the basin — empirically the rows that miss *plateau*, so more
1209/// steps cannot help; the lever there is denser charts, not more iterations). On
1210/// success the start is refined `newton_steps` further by [`refine_certified_start`].
1211/// Minimum per-step *multiplicative* decrease of the Kantorovich `h` the basin
1212/// warm-up requires to keep stepping (FIX #4). A tiny geometric floor: a
1213/// continued step must contract `h` by at least this fraction. Chosen small so
1214/// it never bites a genuinely converging row (Newton in the Kantorovich regime
1215/// is at least geometric and quadratic once `h < 1`, contracting `h` far faster
1216/// than `1/64` per step) while still forcing termination on a plateau.
1217const WARMUP_MIN_MULTIPLICATIVE_DECREASE: f64 = 1.0 / 64.0;
1218
1219/// Kantorovich-quadratic acceptance coefficient for the basin warm-up (FIX #4).
1220/// Once `h < 1` a converging Newton step contracts quadratically (`h_new ≲ κ·h²`);
1221/// accepting that path in addition to the geometric floor makes the
1222/// "genuinely-converging rows are untouched" guarantee explicit. Kept `< 1` so
1223/// the quadratic path is itself a strict contraction (for `h < 1`,
1224/// `κ·h² < κ·h < h`), which preserves the termination bound below.
1225const WARMUP_QUADRATIC_KAPPA: f64 = 0.5;
1226
1227/// Sufficient-decrease test for the basin-warmup loop (FIX #4).
1228///
1229/// Returns `true` while the warm-up should keep stepping. The previous exit rule
1230/// (`h_new >= h_prev` — strict decrease or quit) never fires for an `h`-sequence
1231/// that decreases *monotonically toward a limit above ½*: the increments fall
1232/// below one ulp long before `h` crosses the certifiable ½ bound, so a single
1233/// pathological row could spin ~1e15 full `row_certificate` solves (Hessian
1234/// build + solve) on the encode hot path. We instead require genuine
1235/// *multiplicative* progress each step, which matches Newton's actual behavior:
1236/// a healthy contraction clears the geometric floor by a wide margin (and the
1237/// quadratic path once `h < 1`), while a plateau (`h_new/h_prev → 1`) satisfies
1238/// neither branch and flags to the exact fallback.
1239///
1240/// Termination bound: the warm-up loop runs only while the row is uncertified
1241/// (`h_prev > ½`), and every continued step contracts `h` by at least the factor
1242/// `(1 − WARMUP_MIN_MULTIPLICATIVE_DECREASE)` (the quadratic branch, for `h < 1`,
1243/// contracts by `κ·h_prev < κ < 1`, i.e. even harder). Hence after `N` continued
1244/// steps `h ≤ (1 − c)^N · h₀`, and since the loop stops once `h ≤ ½` we get
1245/// `N < ln(2·h₀) / −ln(1 − c)` — a few hundred iterations at most, versus
1246/// unbounded before. No arbitrary fixed step cap is imposed; the bound is a
1247/// consequence of the contraction requirement, in keeping with the function's
1248/// no-magic-budget design.
1249fn warmup_progress_sufficient(h_new: f64, h_prev: f64) -> bool {
1250    if !(h_new.is_finite() && h_prev.is_finite()) {
1251        return false;
1252    }
1253    // Geometric floor: a real Newton contraction easily clears this.
1254    if h_new <= (1.0 - WARMUP_MIN_MULTIPLICATIVE_DECREASE) * h_prev {
1255        return true;
1256    }
1257    // Kantorovich-quadratic path (only once `h < 1`, where it is a strict
1258    // contraction): makes the no-regression guarantee for converging rows
1259    // explicit without ever admitting a non-contracting (plateau) step.
1260    h_prev < 1.0 && h_new <= WARMUP_QUADRATIC_KAPPA * h_prev * h_prev
1261}
1262
1263/// Whether the basin warm-up must REJECT the just-taken step (flag to the exact
1264/// fallback). FIX (F6): a step that just crossed into the certified region
1265/// (`h ≤ ½`) is ALWAYS accepted, even if its multiplicative decrease was below
1266/// the progress floor (e.g. `h: 0.501 → 0.499`, a legitimate but tiny cross of
1267/// the ½ bound). Only a step that is STILL uncertified AND failed to make
1268/// sufficient multiplicative progress is a plateau that must flag. The old code
1269/// ran the progress test unconditionally, so it could reject an in-hand
1270/// certificate and push the row to the exact fallback for nothing (a false
1271/// negative).
1272fn warmup_should_reject(next_certified: bool, h_new: f64, h_prev: f64) -> bool {
1273    !next_certified && !warmup_progress_sufficient(h_new, h_prev)
1274}
1275
1276fn certify_with_basin_warmup(
1277    atom: &SaeManifoldAtom,
1278    evaluator: &dyn SaeBasisEvaluator,
1279    t_start: Array1<f64>,
1280    x: ArrayView1<'_, f64>,
1281    amplitude: f64,
1282    lipschitz: f64,
1283    newton_steps: usize,
1284    chart_center: ArrayView1<'_, f64>,
1285    chart_radius: f64,
1286) -> Result<Option<CertifiedEncodeProbe>, String> {
1287    // SOUNDNESS GUARD: `lipschitz` is the chart's Hessian-Lipschitz sup, which is
1288    // only a valid bound over this chart's ball `‖t − center‖ ≤ radius` for the
1289    // chart-local families (`EuclideanPatch`/`Linear`/`Poincare` monomial patches,
1290    // `Cylinder` line axis, `Duchon` radial kernels). If a warm-up iterate leaves
1291    // that ball, `row_certificate` would compute `h = β·η·L` with an `L` that no
1292    // longer bounds the true geometry there, so `h ≤ ½` would NOT imply Kantorovich
1293    // convergence — a false certificate. (The `h`-contraction check does NOT catch
1294    // this: `h` can decrease monotonically toward an out-of-chart root the whole
1295    // way.) So we keep every certified iterate inside the chart; a row whose root is
1296    // outside this chart flags for the exact fallback — its lever is a denser grid,
1297    // not a step using an invalid `L`. Global-`L` families (periodic/torus/sphere)
1298    // route their points to charts whose centers are near the root, so the guard
1299    // rarely trips for them, and where it does the row was out-of-chart anyway.
1300    let in_chart = |t: &Array1<f64>| -> bool {
1301        // Wrap-aware chart containment. A raw Euclidean latent distance
1302        // `Σ(tᵢ − centerᵢ)²` mis-measures separation across the wrap seam of a
1303        // periodic axis: an iterate at `t = 0.99` against a chart centered at
1304        // `0.01` reads distance `0.98` instead of the true circle distance
1305        // `0.02`, so it is wrongly rejected at the start check and at every step
1306        // check — silently disabling the amortized encoder for a phase-localized
1307        // band of rows on every periodic / torus / cylinder-angle /
1308        // sphere-longitude chart and pushing that band to the multi-start
1309        // fallback. `latent_coordinate_distance` folds each axis onto its
1310        // `latent_axis_period`, so the containment ball is measured in the true
1311        // manifold geometry — exactly the geometry the chart's Lipschitz bound
1312        // `L` is valid over. This only ever moves points that are genuinely
1313        // near the center (small wrapped distance) back INTO the chart where
1314        // `L` holds, so it widens acceptance without weakening the soundness
1315        // guard (an iterate truly far from the center on the circle still fails).
1316        latent_coordinate_distance(atom, t.view(), chart_center) <= chart_radius
1317    };
1318    let mut t = t_start;
1319    // The distilled / chart-center start must itself be in-chart for its certificate
1320    // to be valid; a bad IFT prediction landing outside the chart is uncertifiable.
1321    if !in_chart(&t) {
1322        return Ok(None);
1323    }
1324    let (mut cert, mut delta) =
1325        row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz)?;
1326    while !cert.certified() {
1327        // Not steppable (indefinite / non-finite Hessian): flag.
1328        if !(cert.h.is_finite() && cert.beta.is_finite() && cert.eta.is_finite()) {
1329            return Ok(None);
1330        }
1331        let prev_h = cert.h;
1332        let next = &t + &delta;
1333        // Refuse to step where the chart's `L` is no longer valid (see guard above).
1334        if !in_chart(&next) {
1335            return Ok(None);
1336        }
1337        t = next;
1338        let (next_cert, next_delta) =
1339            row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz)?;
1340        cert = next_cert;
1341        delta = next_delta;
1342        // The warm-up only helps while h keeps *multiplicatively* contracting
1343        // toward ½. A plain strict-decrease test (`h >= prev_h`) never fires for
1344        // a sequence that decreases monotonically toward a limit above ½, so it
1345        // could spin ~1e15 `row_certificate` solves for one row; require genuine
1346        // multiplicative progress instead (bounded to a few hundred steps, see
1347        // `warmup_progress_sufficient`). Once a step fails that bar the iterate is
1348        // not converging to a certifiable in-chart root — flag for the exact
1349        // fallback (no arbitrary step budget; the bound falls out of the
1350        // contraction requirement).
1351        //
1352        // F6: only enforce the progress bar while STILL uncertified. If this step
1353        // just crossed `h ≤ ½` (e.g. 0.501 → 0.499, a decrease too small to clear
1354        // the multiplicative floor) the point is already certified — the loop
1355        // guard `while !cert.certified()` will exit and refine it. Running the
1356        // progress test unconditionally would falsely reject an in-hand certificate
1357        // and push the row to the exact fallback for nothing. See
1358        // [`warmup_should_reject`].
1359        if warmup_should_reject(cert.certified(), cert.h, prev_h) {
1360            return Ok(None);
1361        }
1362    }
1363    refine_certified_start(
1364        atom,
1365        evaluator,
1366        t,
1367        x,
1368        amplitude,
1369        lipschitz,
1370        newton_steps,
1371        cert,
1372        delta,
1373    )
1374}
1375
1376fn kantorovich_root_radius(cert: RowCertificate) -> f64 {
1377    if !cert.certified() || !(cert.eta.is_finite() && cert.eta >= 0.0) {
1378        return f64::INFINITY;
1379    }
1380    if cert.eta == 0.0 {
1381        return 0.0;
1382    }
1383    if !(cert.h.is_finite() && cert.h >= 0.0) {
1384        return f64::INFINITY;
1385    }
1386    let h = cert.h.min(KANTOROVICH_THRESHOLD);
1387    let discriminant = (1.0 - 2.0 * h).max(0.0).sqrt();
1388    let radius = 2.0 * cert.eta / (1.0 + discriminant);
1389    if radius.is_finite() {
1390        radius
1391    } else {
1392        f64::INFINITY
1393    }
1394}
1395
1396fn distilled_probe_tolerance(
1397    amortized: &CertifiedEncodeProbe,
1398    cold: &CertifiedEncodeProbe,
1399    amplitude: f64,
1400    x: ArrayView1<'_, f64>,
1401) -> f64 {
1402    let certified_radius =
1403        kantorovich_root_radius(amortized.final_cert) + kantorovich_root_radius(cold.final_cert);
1404    let coord_scale = amortized.coord.dot(&amortized.coord).sqrt()
1405        + cold.coord.dot(&cold.coord).sqrt()
1406        + x.dot(&x).sqrt()
1407        + amplitude.abs()
1408        + 1.0;
1409    certified_radius + 1024.0 * f64::EPSILON * coord_scale
1410}
1411
1412fn latent_coordinate_distance(
1413    atom: &SaeManifoldAtom,
1414    lhs: ArrayView1<'_, f64>,
1415    rhs: ArrayView1<'_, f64>,
1416) -> f64 {
1417    let mut acc = 0.0;
1418    for axis in 0..lhs.len().min(rhs.len()) {
1419        let mut diff = (lhs[axis] - rhs[axis]).abs();
1420        if let Some(period) = latent_axis_period(atom, axis) {
1421            let wrapped = diff.rem_euclid(period);
1422            diff = wrapped.min(period - wrapped);
1423        }
1424        acc += diff * diff;
1425    }
1426    acc.sqrt()
1427}
1428
1429fn latent_axis_period(atom: &SaeManifoldAtom, axis: usize) -> Option<f64> {
1430    use crate::manifold::SaeAtomBasisKind::*;
1431    match &atom.basis_kind {
1432        Periodic | Torus => Some(1.0),
1433        Cylinder if axis == 0 => Some(1.0),
1434        Sphere if axis == 1 => Some(std::f64::consts::TAU),
1435        _ => None,
1436    }
1437}
1438
/// Settings for building an [`EncodeAtlas`] and driving its online encode.
///
/// Every knob is an explicit field: the atlas reads no global state and adds
/// no CLI flags of its own.
#[derive(Clone, Copy, Debug)]
pub struct AtlasConfig {
    /// Number of offline chart centers per latent axis on the regular grid
    /// (the SHAPE_BAND grid idiom).
    pub grid_resolution: usize,
    /// Levenberg ridge floor added to the Gauss-Newton Hessian.
    pub ridge: f64,
    /// Online Newton refinement steps taken once a start is certified
    /// (1 or 2 per issue #1010).
    pub newton_steps: usize,
}

impl Default for AtlasConfig {
    /// 16 centers per axis, a `1e-9` ridge, and two refinement Newton steps.
    fn default() -> Self {
        let grid_resolution = 16;
        let ridge = 1.0e-9;
        let newton_steps = 2;
        Self {
            grid_resolution,
            ridge,
            newton_steps,
        }
    }
}
1462
1463/// The encode atlas: per-atom certified charts plus the online certified-encode
1464/// driver (issue #1010).
1465#[derive(Debug, Clone)]
1466pub struct EncodeAtlas {
1467    pub atoms: Vec<AtomEncodeAtlas>,
1468    pub config: AtlasConfig,
1469}
1470
1471impl EncodeAtlas {
1472    /// Build the offline atlas over a frozen dictionary: for each atom, lay down
1473    /// chart centers on the atom's coordinate grid and certify a Newton radius
1474    /// from the Kantorovich inequality at the worst-case in-chart start.
1475    ///
1476    /// `amplitude_bound[k]` is the per-atom bound on `|z_k|` used to scale the
1477    /// reconstruction jets (the offline `L` must hold for the largest amplitude
1478    /// the encode can produce); `target_norm_bound` bounds `‖x‖` over the data.
1479    pub fn build(
1480        atoms: &[SaeManifoldAtom],
1481        amplitude_bound: &[f64],
1482        target_norm_bound: f64,
1483        config: AtlasConfig,
1484    ) -> Result<Self, String> {
1485        if amplitude_bound.len() != atoms.len() {
1486            return Err(format!(
1487                "EncodeAtlas::build: amplitude_bound length {} != atom count {}",
1488                amplitude_bound.len(),
1489                atoms.len()
1490            ));
1491        }
1492        let mut atom_atlases = Vec::with_capacity(atoms.len());
1493        for (k, atom) in atoms.iter().enumerate() {
1494            let atlas =
1495                Self::build_atom_atlas(k, atom, amplitude_bound[k], target_norm_bound, &config)?;
1496            atom_atlases.push(atlas);
1497        }
1498        Ok(Self {
1499            atoms: atom_atlases,
1500            config,
1501        })
1502    }
1503
1504    pub(crate) fn build_atom_atlas(
1505        atom_index: usize,
1506        atom: &SaeManifoldAtom,
1507        amplitude_bound: f64,
1508        target_norm_bound: f64,
1509        config: &AtlasConfig,
1510    ) -> Result<AtomEncodeAtlas, String> {
1511        let centers = chart_center_grid(atom, config.grid_resolution);
1512        // Half the inter-center spacing is the natural in-chart radius so the
1513        // charts tile the grid without gaps; refined below if the certificate
1514        // fails at that radius. One uniform radius for the regular grid.
1515        let nominal_radius = chart_nominal_radius(atom, config.grid_resolution);
1516        let radii = vec![nominal_radius; centers.nrows()];
1517        Self::build_atom_atlas_from_centers(
1518            atom_index,
1519            atom,
1520            centers.view(),
1521            &radii,
1522            amplitude_bound,
1523            target_norm_bound,
1524            config,
1525        )
1526    }
1527
1528    /// Build a per-atom atlas from EXPLICIT chart centers with a per-center
1529    /// nominal radius — the geometry-agnostic core shared by the regular-grid
1530    /// [`Self::build_atom_atlas`] and the data-driven [`Self::build_data_driven`].
1531    /// Every chart is certified identically (Kantorovich radius from the in-chart
1532    /// curvature at its center); only the center PLACEMENT and per-center radius
1533    /// differ. `radii[c]` is the nominal in-chart radius for `centers[c]`.
1534    pub(crate) fn build_atom_atlas_from_centers(
1535        atom_index: usize,
1536        atom: &SaeManifoldAtom,
1537        centers: ArrayView2<'_, f64>,
1538        radii: &[f64],
1539        amplitude_bound: f64,
1540        target_norm_bound: f64,
1541        config: &AtlasConfig,
1542    ) -> Result<AtomEncodeAtlas, String> {
1543        let d = atom.latent_dim;
1544        if centers.ncols() != d {
1545            return Err(format!(
1546                "build_atom_atlas_from_centers: centers have {} cols but atom latent_dim is {d}",
1547                centers.ncols()
1548            ));
1549        }
1550        if radii.len() != centers.nrows() {
1551            return Err(format!(
1552                "build_atom_atlas_from_centers: {} radii != {} centers",
1553                radii.len(),
1554                centers.nrows()
1555            ));
1556        }
1557        let decoder_norm_sum = decoder_row_norm_sum(atom.decoder_coefficients.view());
1558        let mut charts = Vec::with_capacity(centers.nrows());
1559        // HONEST REFUSAL for Duchon atoms (F2/F3): the closed-form Hessian-Lipschitz
1560        // bound available here (`family_jet_sups` Duchon arm) hard-codes cubic-r³
1561        // jets and a single origin center, but the real Duchon kernel is the
1562        // polyharmonic `c·r^(2m−d)` (with log variants) over data-placed centers.
1563        // That bound can UNDER-estimate L → a FALSE certificate (the module's own
1564        // warning; underestimating L is the dangerous direction). The atom does not
1565        // expose its real order/centers/scaling to this crate, so no sound bound is
1566        // available — refuse rather than fabricate. Every Duchon chart is emitted
1567        // UNCERTIFIED (`certified_radius = 0`, no amortized predictor), so routing
1568        // skips it and every Duchon row flags for the exact multi-start encode.
1569        let duchon_uncertifiable =
1570            matches!(atom.basis_kind, crate::manifold::SaeAtomBasisKind::Duchon);
1571        for c in 0..centers.nrows() {
1572            let center = centers.row(c).to_owned();
1573            let nominal_radius = radii[c];
1574            let region = chart_region(atom, center.clone(), nominal_radius);
1575            if duchon_uncertifiable {
1576                charts.push(CertifiedChart {
1577                    region,
1578                    lipschitz: f64::INFINITY,
1579                    beta_center: f64::INFINITY,
1580                    certified_radius: 0.0,
1581                    amortized_jacobian: None,
1582                    recon_center: Array1::<f64>::zeros(atom.output_dim()),
1583                    amortized_base: None,
1584                });
1585                continue;
1586            }
1587            let sups = family_jet_sups(atom, &region)?;
1588            let recon_sups = reconstruction_jet_sups(atom, sups);
1589            let lipschitz =
1590                hessian_lipschitz_constant(recon_sups, amplitude_bound, target_norm_bound, 0.0);
1591            // β at the chart center bounds the worst-case in-chart curvature
1592            // (the Gauss-Newton Hessian is continuous; the certified radius is
1593            // solved so the certificate is robust to the start within the ball).
1594            let beta_center = match center_beta(atom, &center, config.ridge) {
1595                Some(b) => b,
1596                None => {
1597                    // Degenerate center curvature: no certifiable chart here, and
1598                    // no amortized Jacobian (the same singular Gauss–Newton block).
1599                    charts.push(CertifiedChart {
1600                        region,
1601                        lipschitz,
1602                        beta_center: f64::INFINITY,
1603                        certified_radius: 0.0,
1604                        amortized_jacobian: None,
1605                        recon_center: Array1::<f64>::zeros(atom.output_dim()),
1606                        amortized_base: None,
1607                    });
1608                    continue;
1609                }
1610            };
1611            // Distill the amortized-encoder Jacobian at this center (#1026 ladder
1612            // item 3): the IFT derivative of the encode map, precomputed offline
1613            // so the online encode is one mat-vec. A finite `beta_center` (above)
1614            // means the Gauss–Newton block is non-singular, so this succeeds
1615            // alongside it; the pair travels together on the chart.
1616            let (amortized_jacobian, recon_center) =
1617                match center_amortized_jacobian(atom, &center, config.ridge) {
1618                    Some((a1, m1)) => (Some(a1), m1),
1619                    None => (None, Array1::<f64>::zeros(atom.output_dim())),
1620                };
1621            // Certified radius from h = β·η·L ≤ ½ with η ≤ R (Newton step length
1622            // is bounded by the start distance to the root, itself ≤ chart
1623            // radius at worst): R_c = ½ / (β·L), capped at the nominal radius.
1624            let certified_radius = if lipschitz > 0.0 && beta_center.is_finite() {
1625                (0.5 / (beta_center * lipschitz)).min(region.radius)
1626            } else {
1627                region.radius
1628            };
1629            // Precompute the affine-predictor constant `base = t_c − A₁·m₁` (atom-
1630            // static), so the online encode is a single `base + (1/z)·A₁·x` mat-vec.
1631            let amortized_base = amortized_jacobian
1632                .as_ref()
1633                .map(|a1| &center - &a1.dot(&recon_center));
1634            charts.push(CertifiedChart {
1635                region,
1636                lipschitz,
1637                beta_center,
1638                certified_radius,
1639                amortized_jacobian,
1640                recon_center,
1641                amortized_base,
1642            });
1643        }
1644        Ok(AtomEncodeAtlas {
1645            atom_index,
1646            latent_dim: d,
1647            decoder_norm_sum,
1648            charts,
1649        })
1650    }
1651
1652    /// Build the atlas with DATA-DRIVEN chart placement: instead of a dense
1653    /// `resolution^d` product grid (exponential in latent dim `d`, so the regular
1654    /// [`Self::build`] is forced to coarse, poorly-certified charts for `d ≥ 3`),
1655    /// place a bounded number of charts AT the data's own latent coordinates. The
1656    /// chart count is then `O(max_charts)` regardless of `d`, and every chart sits
1657    /// where data actually lands (small in-chart residual → certifies), so
1658    /// higher-dimensional atoms — which reconstruct real activations far better per
1659    /// parameter — become affordable and well-covered.
1660    ///
1661    /// `coords[k]` is atom `k`'s `n × d_k` latent coordinates (the seed coords, or
1662    /// a previous encode's output). Charts are chosen by greedy farthest-point
1663    /// sampling over those coords (deterministic, coverage-maximizing), capped at
1664    /// `max_charts`. Each chart's nominal radius is half the distance to its
1665    /// nearest neighbor center, so the charts tile the local data density. The
1666    /// per-chart Kantorovich certification is IDENTICAL to the regular grid — only
1667    /// the center placement differs.
1668    pub fn build_data_driven(
1669        atoms: &[SaeManifoldAtom],
1670        coords: &[Array2<f64>],
1671        amplitude_bound: &[f64],
1672        target_norm_bound: f64,
1673        max_charts: usize,
1674        config: AtlasConfig,
1675    ) -> Result<Self, String> {
1676        if amplitude_bound.len() != atoms.len() || coords.len() != atoms.len() {
1677            return Err(format!(
1678                "build_data_driven: amplitude_bound {} / coords {} must match atom count {}",
1679                amplitude_bound.len(),
1680                coords.len(),
1681                atoms.len()
1682            ));
1683        }
1684        let mut atom_atlases = Vec::with_capacity(atoms.len());
1685        for (k, atom) in atoms.iter().enumerate() {
1686            let (centers, radii) =
1687                data_driven_chart_centers(atom, coords[k].view(), max_charts.max(1))?;
1688            let atlas = Self::build_atom_atlas_from_centers(
1689                k,
1690                atom,
1691                centers.view(),
1692                &radii,
1693                amplitude_bound[k],
1694                target_norm_bound,
1695                &config,
1696            )?;
1697            atom_atlases.push(atlas);
1698        }
1699        Ok(Self {
1700            atoms: atom_atlases,
1701            config,
1702        })
1703    }
1704
1705    fn refine_certified_encode_start(
1706        &self,
1707        atom: &SaeManifoldAtom,
1708        evaluator: &dyn SaeBasisEvaluator,
1709        chart: &CertifiedChart,
1710        t: Array1<f64>,
1711        x: ArrayView1<'_, f64>,
1712        amplitude: f64,
1713    ) -> Result<(Array1<f64>, RowCertificate), String> {
1714        // Certify from the warm start, navigating into the Kantorovich basin first
1715        // if the unit-amplitude start has h > ½ (see `certify_with_basin_warmup`).
1716        let Some(probe) = certify_with_basin_warmup(
1717            atom,
1718            evaluator,
1719            t,
1720            x,
1721            amplitude,
1722            chart.lipschitz,
1723            self.config.newton_steps,
1724            chart.region.center.view(),
1725            chart.region.radius,
1726        )?
1727        else {
1728            return Ok((
1729                Array1::<f64>::zeros(atom.latent_dim),
1730                uncertified_certificate(chart.lipschitz),
1731            ));
1732        };
1733        // F5: pair the REFINED coordinate with the certificate evaluated AT that
1734        // refined landing point (`final_cert`), not the pre-refinement
1735        // `initial_cert`. Returning `initial_cert` describes a different (earlier)
1736        // iterate's β/η/h — its Kantorovich root-radius overstates the refined
1737        // point's distance to the root.
1738        Ok((probe.coord, probe.final_cert))
1739    }
1740
1741    /// Online certified encode of one target row `x` against one atom `k` with
1742    /// fixed amplitude `z`. Routes to the nearest chart, starts from that chart's
1743    /// distilled IFT warm start, runs `config.newton_steps` Newton steps, and
1744    /// returns the encoded coordinate with its certificate. An uncertified start
1745    /// (no chart, no distilled Jacobian, non-positive amplitude, or `h > ½`)
1746    /// flags the row for the exact multi-start caller.
1747    pub fn certified_encode_row(
1748        &self,
1749        atom: &SaeManifoldAtom,
1750        atom_index: usize,
1751        x: ArrayView1<'_, f64>,
1752        amplitude: f64,
1753    ) -> Result<(Array1<f64>, RowCertificate), String> {
1754        let atom_atlas = self
1755            .atoms
1756            .get(atom_index)
1757            .ok_or_else(|| format!("certified_encode_row: atom {atom_index} not in atlas"))?;
1758        let d = atom.latent_dim;
1759        // A missing basis evaluator means the amortized/cold predictor cannot fire
1760        // for this atom (e.g. a frozen-baseline or first-build atom that never
1761        // attached a distilled evaluator). That is exactly the "cannot certify"
1762        // state — flag the row uncertified (zeros coords, ∞ certificate) so the
1763        // upstream exact multi-start solve owns it, never a hard error that aborts
1764        // the whole criterion. Mirrors the no-chart / singular-Jacobian branches.
1765        let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
1766            return Ok((
1767                Array1::<f64>::zeros(d),
1768                RowCertificate {
1769                    beta: f64::INFINITY,
1770                    eta: f64::INFINITY,
1771                    lipschitz: f64::INFINITY,
1772                    h: f64::INFINITY,
1773                },
1774            ));
1775        };
1776
1777        // Route to the nearest chart centers by AMBIENT reconstruction distance.
1778        // A single nearest chart is NOT globally sound on self-approaching atoms:
1779        // where the decoded manifold folds near itself (two distant latent points
1780        // map near the same output), the nearest-center chart can certify into the
1781        // locally-worse basin while another chart holds the GLOBAL minimum (both
1782        // branches' charts reconstruct near the crossing, so both are near in
1783        // ambient distance). The certificate is honest about LOCAL convergence but
1784        // cannot see the better far basin. So we refine in the top-K nearest charts
1785        // and keep the lowest-reconstruction-error CERTIFIED result. For a unimodal
1786        // atom every candidate chart converges to the same root, so this is a no-op
1787        // (first-wins tie → the nearest chart), preserving the existing behavior.
1788        let candidates = nearest_charts_topk(atom_atlas, x, amplitude, CERTIFIED_ROUTING_TOPK);
1789        if candidates.is_empty() {
1790            return Ok((
1791                Array1::<f64>::zeros(d),
1792                RowCertificate {
1793                    beta: f64::INFINITY,
1794                    eta: f64::INFINITY,
1795                    lipschitz: f64::INFINITY,
1796                    h: f64::INFINITY,
1797                },
1798            ));
1799        }
1800        // Best CERTIFIED result by reconstruction error, plus the nearest chart's
1801        // result as the uncertified fallback (preserving the prior return when no
1802        // candidate certifies — the nearest chart owns the flagged row).
1803        let mut best: Option<(Array1<f64>, RowCertificate, f64)> = None;
1804        let mut nearest_fallback: Option<(Array1<f64>, RowCertificate)> = None;
1805        for chart_idx in candidates {
1806            let chart = &atom_atlas.charts[chart_idx];
1807            let Some(t) = amortized_warm_start(chart, x, amplitude) else {
1808                if nearest_fallback.is_none() {
1809                    nearest_fallback = Some((
1810                        Array1::<f64>::zeros(d),
1811                        uncertified_certificate(chart.lipschitz),
1812                    ));
1813                }
1814                continue;
1815            };
1816            let (coord, cert) = self.refine_certified_encode_start(
1817                atom,
1818                evaluator.as_ref(),
1819                chart,
1820                t,
1821                x,
1822                amplitude,
1823            )?;
1824            if nearest_fallback.is_none() {
1825                nearest_fallback = Some((coord.clone(), cert.clone()));
1826            }
1827            if cert.certified() {
1828                let err = encode_reconstruction_error(
1829                    atom,
1830                    evaluator.as_ref(),
1831                    coord.view(),
1832                    x,
1833                    amplitude,
1834                );
1835                if best.as_ref().map(|(_, _, e)| err < *e).unwrap_or(true) {
1836                    best = Some((coord, cert, err));
1837                }
1838                // Global-minimum short-circuit: reconstruction error ≥ 0, so a
1839                // certified candidate already at the ambient noise floor is provably
1840                // the global optimum over the remaining charts — stop refining them.
1841                if let Some((_, _, e)) = best.as_ref() {
1842                    if *e <= CERTIFIED_GLOBAL_MIN_RECON_FLOOR * (1.0 + x.dot(&x).sqrt()) {
1843                        break;
1844                    }
1845                }
1846            }
1847        }
1848        match best {
1849            Some((coord, cert, _)) => Ok((coord, cert)),
1850            None => Ok(nearest_fallback.unwrap_or_else(|| {
1851                (
1852                    Array1::<f64>::zeros(d),
1853                    RowCertificate {
1854                        beta: f64::INFINITY,
1855                        eta: f64::INFINITY,
1856                        lipschitz: f64::INFINITY,
1857                        h: f64::INFINITY,
1858                    },
1859                )
1860            })),
1861        }
1862    }
1863
1864    /// Amortized (distilled) encode of one target row `x` against one atom `k`
1865    /// with fixed amplitude `z` (#1026 ladder item 3).
1866    ///
1867    /// Routes to the nearest chart, then predicts the latent coordinate in CLOSED
1868    /// FORM from that chart's precomputed implicit-function-theorem Jacobian:
1869    ///
1870    /// ```text
1871    /// t̂ = t_c + (1/z) · A₁ · (x − z · m₁(t_c)),
1872    /// ```
1873    ///
1874    /// a single `O(d·p)` mat-vec — no per-row Hessian factorization or
1875    /// eigendecomposition, which is the amortization. The Kantorovich
1876    /// certificate is then evaluated AT the predicted start `t̂` with the chart's
1877    /// closed-form Lipschitz constant. A prediction is accepted only when that
1878    /// certificate holds, an independent cold chart-center probe also certifies,
1879    /// and the two refined coordinates agree within the two probes' final
1880    /// Kantorovich root-radius bounds. This keeps the distilled path honest
1881    /// without letting the exact probe reuse the distilled warm start it is
1882    /// auditing. A chart without a distilled Jacobian (singular Gauss–Newton
1883    /// block) flags the row.
1884    pub fn amortized_encode_row(
1885        &self,
1886        atom: &SaeManifoldAtom,
1887        atom_index: usize,
1888        x: ArrayView1<'_, f64>,
1889        amplitude: f64,
1890    ) -> Result<(Array1<f64>, RowCertificate), String> {
1891        let atom_atlas = self
1892            .atoms
1893            .get(atom_index)
1894            .ok_or_else(|| format!("amortized_encode_row: atom {atom_index} not in atlas"))?;
1895        let d = atom.latent_dim;
1896        let uncertified = || {
1897            (
1898                Array1::<f64>::zeros(d),
1899                RowCertificate {
1900                    beta: f64::INFINITY,
1901                    eta: f64::INFINITY,
1902                    lipschitz: f64::INFINITY,
1903                    h: f64::INFINITY,
1904                },
1905            )
1906        };
1907        // A missing basis evaluator means the distilled predictor cannot fire for
1908        // this atom — flag the row uncertified (the exact upstream solve owns it)
1909        // rather than erroring, exactly as the no-chart / singular-Jacobian /
1910        // non-positive-amplitude branches below do. Never a silent wrong encode,
1911        // never a hard abort of the criterion.
1912        let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
1913            return Ok(uncertified());
1914        };
1915        let Some((chart_idx, _)) = nearest_chart(atom_atlas, x, amplitude) else {
1916            return Ok(uncertified());
1917        };
1918        let chart = &atom_atlas.charts[chart_idx];
1919        // Closed-form predicted start t̂ = t_c + (1/z)·A₁·(x − z·m₁). `None` when
1920        // the chart's Gauss–Newton block was singular (no distilled Jacobian, so
1921        // the amortized predictor cannot fire) or the amplitude is not strictly
1922        // positive and finite (a near-inactive atom, where the amplitude-divided
1923        // map is undefined) — either way flag for the exact fallback, never a
1924        // silent wrong encode.
1925        let Some(t_hat) = amortized_warm_start(chart, x, amplitude) else {
1926            return Ok(uncertified());
1927        };
1928        // Evaluate the SAME Kantorovich certificate at the predicted start. The
1929        // amortized prediction is trusted only if this certificate holds AND an
1930        // independent cold chart-center probe certifies and agrees below the
1931        // two probes' final Kantorovich root-radius bounds. This avoids the
1932        // self-referential gate where the "exact" probe is warm-started by the
1933        // same distilled prediction it is supposed to audit.
1934        let Some(amortized_probe) = certify_with_basin_warmup(
1935            atom,
1936            evaluator.as_ref(),
1937            t_hat,
1938            x,
1939            amplitude,
1940            chart.lipschitz,
1941            self.config.newton_steps,
1942            chart.region.center.view(),
1943            chart.region.radius,
1944        )?
1945        else {
1946            return Ok((
1947                Array1::<f64>::zeros(d),
1948                uncertified_certificate(chart.lipschitz),
1949            ));
1950        };
1951
1952        let cold_start = chart.region.center.clone();
1953        let Some(cold_probe) = certify_with_basin_warmup(
1954            atom,
1955            evaluator.as_ref(),
1956            cold_start,
1957            x,
1958            amplitude,
1959            chart.lipschitz,
1960            self.config.newton_steps,
1961            chart.region.center.view(),
1962            chart.region.radius,
1963        )?
1964        else {
1965            return Ok((
1966                amortized_probe.coord,
1967                uncertified_certificate(chart.lipschitz),
1968            ));
1969        };
1970
1971        let gap =
1972            latent_coordinate_distance(atom, amortized_probe.coord.view(), cold_probe.coord.view());
1973        let tolerance = distilled_probe_tolerance(&amortized_probe, &cold_probe, amplitude, x);
1974        if !(gap.is_finite() && gap <= tolerance) {
1975            return Ok((
1976                amortized_probe.coord,
1977                uncertified_certificate(chart.lipschitz),
1978            ));
1979        }
1980        // F5: return the certificate at the refined landing coordinate
1981        // (`final_cert`), consistent with the coord actually returned and with the
1982        // `distilled_probe_tolerance` gate above (which already reads `final_cert`
1983        // via `kantorovich_root_radius`). `initial_cert` is a stale earlier iterate.
1984        Ok((amortized_probe.coord, amortized_probe.final_cert))
1985    }
1986
1987    /// Batched amortized (distilled) encode over many rows against one atom
1988    /// (#1026 ladder item 3, corpus-rate). Each row uses the closed-form
1989    /// per-chart Jacobian predictor and carries its own Kantorovich certificate;
1990    /// uncertified rows are flagged in [`EncodeResult::encode_uncertified_count`]
1991    /// for the exact multi-start fallback. Row-independent against the frozen
1992    /// dictionary, so the batch fans out over rows (deterministic row-order
1993    /// assembly, bit-identical run-to-run), staying sequential inside a rayon
1994    /// worker to avoid nested oversubscription.
1995    pub fn amortized_encode_batch(
1996        &self,
1997        atom: &SaeManifoldAtom,
1998        atom_index: usize,
1999        targets: ArrayView2<'_, f64>,
2000        amplitudes: ArrayView1<'_, f64>,
2001    ) -> Result<EncodeResult, String> {
2002        let n = targets.nrows();
2003        if amplitudes.len() != n {
2004            return Err(format!(
2005                "amortized_encode_batch: amplitudes len {} != rows {n}",
2006                amplitudes.len()
2007            ));
2008        }
2009        let d = atom.latent_dim;
2010        let encode_rows =
2011            |range: std::ops::Range<usize>| -> Result<Vec<(Array1<f64>, bool)>, String> {
2012                range
2013                    .map(|row| {
2014                        let (t, cert) = self.amortized_encode_row(
2015                            atom,
2016                            atom_index,
2017                            targets.row(row),
2018                            amplitudes[row],
2019                        )?;
2020                        Ok((t, cert.certified()))
2021                    })
2022                    .collect()
2023            };
2024        let rows: Vec<(Array1<f64>, bool)> =
2025            if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
2026                use rayon::prelude::*;
2027                const CHUNK: usize = 256;
2028                let n_chunks = n.div_ceil(CHUNK);
2029                let chunked: Vec<Vec<(Array1<f64>, bool)>> = (0..n_chunks)
2030                    .into_par_iter()
2031                    .map(|c| {
2032                        let start = c * CHUNK;
2033                        let end = (start + CHUNK).min(n);
2034                        encode_rows(start..end)
2035                    })
2036                    .collect::<Result<_, _>>()?;
2037                chunked.into_iter().flatten().collect()
2038            } else {
2039                encode_rows(0..n)?
2040            };
2041        let mut coords = Array2::<f64>::zeros((n, d));
2042        let mut certified = Vec::with_capacity(n);
2043        for (row, (t, cert)) in rows.into_iter().enumerate() {
2044            coords.row_mut(row).assign(&t);
2045            certified.push(cert);
2046        }
2047        Ok(EncodeResult::from_rows(coords, certified))
2048    }
2049
2050    /// Batched certified encode over many rows against one atom (the #988
2051    /// throughput consumer). Each row carries its own certificate; uncertified
2052    /// rows are flagged in [`EncodeResult::encode_uncertified_count`] for the
2053    /// exact multi-start fallback.
2054    pub fn certified_encode_batch(
2055        &self,
2056        atom: &SaeManifoldAtom,
2057        atom_index: usize,
2058        targets: ArrayView2<'_, f64>,
2059        amplitudes: ArrayView1<'_, f64>,
2060    ) -> Result<EncodeResult, String> {
2061        let n = targets.nrows();
2062        if amplitudes.len() != n {
2063            return Err(format!(
2064                "certified_encode_batch: amplitudes len {} != rows {n}",
2065                amplitudes.len()
2066            ));
2067        }
2068        let d = atom.latent_dim;
2069        // Per-row encode is independent against a frozen dictionary (#1010), so
2070        // the corpus-rate batch fans out over rows (#1026 amortized-encoder leg /
2071        // #977 Stage-3 corpus encode). Each row produces an owned `(t, certified)`
2072        // pair; results are assembled back in row order so the output is
2073        // bit-identical run-to-run regardless of thread scheduling. Stay
2074        // sequential inside a rayon worker (e.g. when an outer atom-level fan-out
2075        // owns the pool) to avoid nested oversubscription. The first row that
2076        // fails to encode propagates its error deterministically.
2077        let encode_rows =
2078            |range: std::ops::Range<usize>| -> Result<Vec<(Array1<f64>, bool)>, String> {
2079                range
2080                    .map(|row| {
2081                        let (t, cert) = self.certified_encode_row(
2082                            atom,
2083                            atom_index,
2084                            targets.row(row),
2085                            amplitudes[row],
2086                        )?;
2087                        Ok((t, cert.certified()))
2088                    })
2089                    .collect()
2090            };
2091        let rows: Vec<(Array1<f64>, bool)> =
2092            if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
2093                use rayon::prelude::*;
2094                const CHUNK: usize = 256;
2095                let n_chunks = n.div_ceil(CHUNK);
2096                let chunked: Vec<Vec<(Array1<f64>, bool)>> = (0..n_chunks)
2097                    .into_par_iter()
2098                    .map(|c| {
2099                        let start = c * CHUNK;
2100                        let end = (start + CHUNK).min(n);
2101                        encode_rows(start..end)
2102                    })
2103                    .collect::<Result<_, _>>()?;
2104                chunked.into_iter().flatten().collect()
2105            } else {
2106                encode_rows(0..n)?
2107            };
2108        let mut coords = Array2::<f64>::zeros((n, d));
2109        let mut certified = Vec::with_capacity(n);
2110        for (row, (t, cert)) in rows.into_iter().enumerate() {
2111            coords.row_mut(row).assign(&t);
2112            certified.push(cert);
2113        }
2114        Ok(EncodeResult::from_rows(coords, certified))
2115    }
2116
2117    /// Batched GEMM "fast" amortized encode — the traditional-encoder forward
2118    /// pass, WITH manifolds. For every row this applies the SAME closed-form
2119    /// affine predictor as [`amortized_warm_start`]
2120    /// (`t̂ = t_c + (1/z)·A₁·(x − z·m₁)`), but routed and applied as batched
2121    /// matrix products instead of a per-row loop wrapped in the Kantorovich
2122    /// certificate + basin warmup. NO per-row certificate is taken: this is the
2123    /// speed mode (the certified `*_encode_*` paths remain the accuracy mode).
2124    ///
2125    /// Cost is GEMM-bound: one `(n × p)·(p × d)` decode-distance product for
2126    /// nearest-chart routing (skipped for single-chart atoms) plus, per chart,
2127    /// one `(n_c × p)·(p × d)` predictor product — i.e. `≈ X·Wᵀ`, exactly a
2128    /// dense SAE encoder's forward map.
2129    ///
2130    /// Degenerate rows are handled exactly as `amortized_warm_start` flags them
2131    /// (returns `None` ⇒ zeroed coord here): a missing basis evaluator, a chart
2132    /// whose Gauss–Newton block was singular (`amortized_jacobian == None`), or a
2133    /// non-finite / non-positive amplitude. Those rows are zeroed (never a panic,
2134    /// never a silent wrong encode), and their indices are returned in the
2135    /// `valid` mask so the caller can route them to the exact path if desired.
2136    ///
2137    /// Returns `(coords, valid)` where `coords` is `n × d` and `valid[row]` is
2138    /// `true` iff the amortized predictor fired for that row.
2139    pub fn amortized_encode_batch_fast(
2140        &self,
2141        atom: &SaeManifoldAtom,
2142        atom_index: usize,
2143        x: ArrayView2<'_, f64>,
2144        amplitudes: ArrayView1<'_, f64>,
2145    ) -> Result<(Array2<f64>, Vec<bool>), String> {
2146        let n = x.nrows();
2147        let p = atom.output_dim();
2148        let d = atom.latent_dim;
2149        if x.ncols() != p {
2150            return Err(format!(
2151                "amortized_encode_batch_fast: x has {} cols but atom output dim is {p}",
2152                x.ncols()
2153            ));
2154        }
2155        if amplitudes.len() != n {
2156            return Err(format!(
2157                "amortized_encode_batch_fast: amplitudes len {} != rows {n}",
2158                amplitudes.len()
2159            ));
2160        }
2161        let atom_atlas = self.atoms.get(atom_index).ok_or_else(|| {
2162            format!("amortized_encode_batch_fast: atom {atom_index} not in atlas")
2163        })?;
2164        let mut coords = Array2::<f64>::zeros((n, d));
2165        let mut valid = vec![false; n];
2166
2167        // A missing basis evaluator means this atom never had a well-formed atlas
2168        // built here — treat every row as uncertified (zeroed), exactly like the
2169        // per-row `amortized_encode_row` no-evaluator branch. (The predictor below
2170        // uses only cached atlas data, so no evaluator call is made online.)
2171        if atom.basis_evaluator.is_none() {
2172            return Ok((coords, valid));
2173        }
2174
2175        // ── Routing recon-centers (cached, no online basis evaluation). ────────
2176        // Routing sends a row to the chart whose center reconstruction
2177        // `m(t_c) = BᵀΦ(t_c)` is closest in ‖·‖². Those center reconstructions are
2178        // OFFLINE-cached in `chart.recon_center` (bit-identical to re-evaluating the
2179        // basis at the fixed centers — same φ·decoder accumulation). Gather the cache
2180        // instead of calling `evaluator.evaluate` on every invocation: that per-call
2181        // chart-center evaluation was the dominant per-atom-group overhead at massive
2182        // K, where N rows scatter across many atoms into tiny groups so a fixed
2183        // per-call cost is amortized over only a handful of rows. This is what keeps
2184        // the fast index-routed encode near-flat as K grows.
2185        let valid_charts: Vec<usize> = (0..atom_atlas.charts.len())
2186            .filter(|&c| atom_atlas.charts[c].certified_radius > 0.0)
2187            .collect();
2188        if valid_charts.is_empty() {
2189            return Ok((coords, valid));
2190        }
2191        // recon_centers (C × p): the cached m(t_c) for each certifiable chart.
2192        let mut recon_centers = Array2::<f64>::zeros((valid_charts.len(), p));
2193        for (ci, &c) in valid_charts.iter().enumerate() {
2194            recon_centers
2195                .row_mut(ci)
2196                .assign(&atom_atlas.charts[c].recon_center);
2197        }
2198        // Per-chart routing key: route_idx[row] = argmin_c ‖x_row − z_row·r_c‖²
2199        // (F1: the TRUE objective at the row's amplitude `z_row`, where `r_c` is
2200        // the amplitude-1 center reconstruction). First chart wins on a tie (strict
2201        // `<`), matching `nearest_chart`. A non-finite amplitude routes to chart 0
2202        // (moot — the predictor below skips the row anyway).
2203        //
2204        // Score the DIRECT squared distance `Σ_j (z·r_cj − x_j)²`, accumulated in
2205        // the same element order as the per-row `nearest_chart`, NOT the algebraic
2206        // expansion `z²‖r_c‖² − 2z·(x·r_c)`. The two are equal in exact arithmetic,
2207        // but the expansion subtracts two O(‖x‖²) quantities and loses ~‖x‖²·ε of
2208        // precision to cancellation — enough to FLIP the argmin between two charts
2209        // whose reconstructions coincide to rounding (e.g. period-wrapped torus-seam
2210        // charts, whose latent centers differ by a full period yet reconstruct to
2211        // within 1e-14). On such a near-tie the expansion and `nearest_chart`
2212        // disagree, and the two seam charts predict coords a full period apart — so
2213        // the batched fast-encode would diverge from the per-row warm-start it must
2214        // reproduce bit-for-bit (the `fast_encode_matches_per_row_warm_start`
2215        // contract). Computing the same direct distance in the same order keeps this
2216        // path byte-identical to `nearest_chart`, so routing (and the tie-break)
2217        // agrees on every row. `recon_centers` is still the cached offline
2218        // `m₁(t_c)` — no basis re-evaluation, which was the fast path's real cost.
2219        let route_idx: Vec<usize> = if valid_charts.len() == 1 {
2220            vec![0usize; n]
2221        } else {
2222            (0..n)
2223                .map(|row| {
2224                    let z = amplitudes[row];
2225                    if !z.is_finite() {
2226                        return 0usize;
2227                    }
2228                    let x_row = x.row(row);
2229                    let mut best_c = 0usize;
2230                    let mut best_d = f64::INFINITY;
2231                    for c in 0..valid_charts.len() {
2232                        let mut dist = 0.0;
2233                        for (r, xv) in recon_centers.row(c).iter().zip(x_row.iter()) {
2234                            let diff = z * r - xv;
2235                            dist += diff * diff;
2236                        }
2237                        if dist < best_d {
2238                            best_d = dist;
2239                            best_c = c;
2240                        }
2241                    }
2242                    best_c
2243                })
2244                .collect()
2245        };
2246
2247        // ── Per-chart batched affine predictor. ───────────────────────────────
2248        // For rows routed to chart `c` with finite jacobian `A₁` (d × p) and
2249        // center reconstruction `m₁` (= `chart.recon_center`), the predictor is
2250        //   t̂ = t_c − A₁·m₁ + (1/z)·(A₁·x).
2251        // `t_c − A₁·m₁` is a per-chart constant `base`; `A₁·x` is a d-vector of
2252        // per-row dot products. Instead of gathering routed rows into a fresh
2253        // `X_c` (n_c × p) buffer and running a GEMM into a second `U` (n_c × d)
2254        // buffer — two allocations plus a full copy of the routed rows, per chart —
2255        // fuse the gather straight into the multiply: stream each source row of `x`
2256        // once (it is contiguous) and dot it against `A₁`'s rows, writing the
2257        // predicted coord directly. Zero per-chart heap traffic; the inverse
2258        // amplitude is hoisted to one reciprocal per row.
2259        //
2260        // Precompute each valid chart's `(A₁, base)` once (charts with a singular
2261        // Gauss–Newton block carry no `A₁`, so their routed rows stay
2262        // zeroed/uncertified — same as `amortized_warm_start` returning `None`).
2263        struct ChartPredictor<'a> {
2264            a1: &'a Array2<f64>,
2265            base: &'a Array1<f64>,
2266        }
2267        let predictors: Vec<Option<ChartPredictor<'_>>> = valid_charts
2268            .iter()
2269            .map(|&c| {
2270                let chart = &atom_atlas.charts[c];
2271                // `base = t_c − A₁·m₁` is precomputed offline in the atlas; reuse it
2272                // (both are `Some` together — singular G-N block ⇒ both `None`).
2273                match (
2274                    chart.amortized_jacobian.as_ref(),
2275                    chart.amortized_base.as_ref(),
2276                ) {
2277                    (Some(a1), Some(base)) => Some(ChartPredictor { a1, base }),
2278                    _ => None,
2279                }
2280            })
2281            .collect();
2282
2283        for row in 0..n {
2284            let Some(pred) = predictors[route_idx[row]].as_ref() else {
2285                continue;
2286            };
2287            let amp = amplitudes[row];
2288            if !(amp.is_finite() && amp.abs() > 0.0) {
2289                continue;
2290            }
2291            let inv_z = 1.0 / amp;
2292            let x_row = x.row(row);
2293            let mut coord_row = coords.row_mut(row);
2294            for axis in 0..d {
2295                // (A₁·x)[axis] = A₁ row `axis` (contiguous, length p) · x_row.
2296                coord_row[axis] = pred.base[axis] + pred.a1.row(axis).dot(&x_row) * inv_z;
2297            }
2298            valid[row] = true;
2299        }
2300        Ok((coords, valid))
2301    }
2302
2303    /// Fast batched FULL forward pass against one atom: encode → decode, the
2304    /// manifold analogue of a traditional SAE's `x̂ = z·D` (decoder `D`, code `z`).
2305    ///
2306    /// A traditional SAE decodes with one GEMM. The manifold SAE's reconstruction
2307    /// is `m(t̂) = z·Φ(t̂)·B` (module header) — the SAME GEMM `Φ·B`, but the code
2308    /// `Φ(t̂)` is the curved chart basis evaluated at the encoded latent coordinate
2309    /// rather than a flat one-hot. So the fast forward is exactly:
2310    ///   1. [`amortized_encode_batch_fast`] → per-row latent coords `t̂` (one
2311    ///      routing GEMM + one affine GEMM per chart — a traditional `W·x+b`);
2312    ///   2. ONE batched basis evaluation `Φ(t̂)` (the manifold-curvature step a
2313    ///      flat SAE doesn't have — `n×m`);
2314    ///   3. ONE GEMM `recon = Φ(t̂)·B` (`(n×m)·(m×p)` — a traditional decoder
2315    ///      `z·D`), then the per-row amplitude scale `z`.
2316    ///
2317    /// Rows the encoder could not certify-predict (no evaluator / singular
2318    /// Gauss–Newton block / non-finite-or-zero amplitude) are returned as a ZERO
2319    /// reconstruction and flagged `false` in the valid-mask — never a silent wrong
2320    /// decode. The reconstruction of a valid row equals, bit-for-bit up to GEMM
2321    /// reassociation, `z·(Φ(t̂_row)·B)` with `t̂` from the per-row predictor.
2322    pub fn amortized_reconstruct_batch_fast(
2323        &self,
2324        atom: &SaeManifoldAtom,
2325        atom_index: usize,
2326        x: ArrayView2<'_, f64>,
2327        amplitudes: ArrayView1<'_, f64>,
2328    ) -> Result<(Array2<f64>, Vec<bool>), String> {
2329        let n = x.nrows();
2330        let p = atom.output_dim();
2331        // Step 1: batched encode → latent coords (reuses the fast routing+affine).
2332        let (coords, valid) = self.amortized_encode_batch_fast(atom, atom_index, x, amplitudes)?;
2333        let mut recon = Array2::<f64>::zeros((n, p));
2334        // A missing evaluator means no row could encode — every row is zeroed and
2335        // already flagged `false` by the encode; nothing to decode.
2336        let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
2337            return Ok((recon, valid));
2338        };
2339        // Step 2: ONE batched basis evaluation Φ(t̂) over all rows (n × m). Invalid
2340        // rows carry coords = 0 (the chart-origin); we still evaluate them in the
2341        // batch for a single GEMM, then zero their reconstruction below — the basis
2342        // is finite at the origin so this cannot poison the valid rows' GEMM.
2343        let (phi, _jet) = evaluator
2344            .evaluate(coords.view())
2345            .map_err(|err| format!("amortized_reconstruct_batch_fast: basis eval: {err}"))?;
2346        // Step 3: ONE GEMM recon = Φ·B (n × p), then per-row amplitude scale z.
2347        // m(t̂) = z·Φ(t̂)·B, matching the module header and `fill_decoded_row`'s
2348        // `Φ·decoder` accumulation (the amplitude is applied once here).
2349        let decoded = phi.dot(&atom.decoder_coefficients); // (n × p), amplitude-1
2350        for row in 0..n {
2351            if !valid[row] {
2352                continue; // stays zeroed — uncertified, like warm_start `None`.
2353            }
2354            let z = amplitudes[row];
2355            for col in 0..p {
2356                recon[[row, col]] = z * decoded[[row, col]];
2357            }
2358        }
2359        Ok((recon, valid))
2360    }
2361
2362    /// LSH-routed certified encode (issue #1010 step 2 + 3): for each target
2363    /// row, the existing [`SaeCandidateIndex`] (#985/#994) proposes the
2364    /// best-aligned atom by frame alignment to the row direction; the row is then
2365    /// encoded against THAT atom's certified chart atlas. This is the production
2366    /// routing path. Atom selection is EXACT (#1777): [`SaeCandidateIndex::route_exact`]
2367    /// returns the global argmax of the routing score (the universal-bound LSH fast
2368    /// path, else a full-scan fallback) — never a silently-missed ungathered atom —
2369    /// and the atlas does the in-atom nearest-chart routing and the per-row
2370    /// Kantorovich certificate.
2371    ///
2372    /// `atoms[id]` must be aligned with the atlas's `atoms[id]` (same dictionary
2373    /// order the atlas was built from and the sketch/index were built over).
2374    /// A row over an empty dictionary, or whose globally-best atom aligns below the
2375    /// fit-quality floor, is flagged uncertified — it routes to the exact
2376    /// multi-start fallback, never a silent wrong encode.
2377    pub fn certified_encode_with_index<S: AtomFrameSketch + Sync>(
2378        &self,
2379        atoms: &[SaeManifoldAtom],
2380        index: &SaeCandidateIndex,
2381        sketch: &S,
2382        targets: ArrayView2<'_, f64>,
2383        amplitudes: ArrayView1<'_, f64>,
2384        latent_dim: usize,
2385    ) -> Result<EncodeResult, String> {
2386        let n = targets.nrows();
2387        if amplitudes.len() != n {
2388            return Err(format!(
2389                "certified_encode_with_index: amplitudes len {} != rows {n}",
2390                amplitudes.len()
2391            ));
2392        }
2393        let budget = auto_candidate_budget(atoms.len().max(1));
2394        // LSH-routed per-row encode is independent across rows (sublinear atom
2395        // selection + frozen-dictionary in-atom Newton), so the corpus-rate batch
2396        // fans out over rows (#1026 amortized-encoder/routing leg / #977 Stage-3).
2397        // `None` coords (no LSH candidate) carry through as a zeroed row flagged
2398        // uncertified — identical to the sequential semantics. Results assemble
2399        // back in row order (bit-identical run-to-run); the first encode error
2400        // propagates deterministically. Stay sequential inside a rayon worker to
2401        // avoid nested oversubscription.
2402        let encode_rows =
2403            |range: std::ops::Range<usize>| -> Result<Vec<Option<(Array1<f64>, bool)>>, String> {
2404                range
2405                    .map(|row| {
2406                        // EXACT routing (#1777): pick the GLOBAL argmax of the
2407                        // routing score over the whole dictionary, not merely the
2408                        // best LSH-gathered candidate. `route_exact` certifies the
2409                        // sublinear gather against the universal `[0,1]` alignment
2410                        // bound and falls back to a full scan otherwise, so the
2411                        // returned atom is guaranteed to be the globally-best — no
2412                        // silently-missed ungathered atom.
2413                        let Some(route) =
2414                            index.route_exact(sketch, targets.row(row), budget, true)
2415                        else {
2416                            // Empty dictionary: flag for the exact fallback.
2417                            return Ok(None);
2418                        };
2419                        let best_atom = route.atom;
2420                        // Fit-quality floor: even the globally-best atom may align
2421                        // only weakly with this row (no atom fits it). A finite
2422                        // alignment below the floor — or a NaN, the zero-norm
2423                        // ‖d‖ = 0 row — flags for the exact multi-start fallback
2424                        // rather than encoding against a poorly-fitting atom. This is
2425                        // a quality gate, not a routing-correctness gate; routing is
2426                        // already exact. See CANDIDATE_ROUTING_MIN_ALIGNMENT.
2427                        if !route.alignment.is_finite()
2428                            || route.alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT
2429                        {
2430                            return Ok(None);
2431                        }
2432                        let atom = atoms.get(best_atom).ok_or_else(|| {
2433                            format!(
2434                                "certified_encode_with_index: proposed atom {best_atom} out of range"
2435                            )
2436                        })?;
2437                        let (t, cert) = self.certified_encode_row(
2438                            atom,
2439                            best_atom,
2440                            targets.row(row),
2441                            amplitudes[row],
2442                        )?;
2443                        // Heterogeneous-atom dictionaries with different latent_dim
2444                        // per atom are not supported by the batched API: the caller
2445                        // declares one shared `latent_dim` for the output tensor.
2446                        // Silently zeroing the coord row while recording a
2447                        // certified=true flag would produce corrupted
2448                        // reconstructions downstream — error loudly instead.
2449                        if t.len() != latent_dim {
2450                            return Err(format!(
2451                                "certified_encode_with_index: atom {best_atom} returned t.len()={} \
2452                                 but declared latent_dim={latent_dim}; heterogeneous-dim \
2453                                 dictionaries are not supported by this batched encode path",
2454                                t.len()
2455                            ));
2456                        }
2457                        Ok(Some((t, cert.certified())))
2458                    })
2459                    .collect()
2460            };
2461        let rows: Vec<Option<(Array1<f64>, bool)>> =
2462            if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
2463                use rayon::prelude::*;
2464                const CHUNK: usize = 256;
2465                let n_chunks = n.div_ceil(CHUNK);
2466                let chunked: Vec<Vec<Option<(Array1<f64>, bool)>>> = (0..n_chunks)
2467                    .into_par_iter()
2468                    .map(|c| {
2469                        let start = c * CHUNK;
2470                        let end = (start + CHUNK).min(n);
2471                        encode_rows(start..end)
2472                    })
2473                    .collect::<Result<_, _>>()?;
2474                chunked.into_iter().flatten().collect()
2475            } else {
2476                encode_rows(0..n)?
2477            };
2478        let mut coords = Array2::<f64>::zeros((n, latent_dim));
2479        let mut certified = Vec::with_capacity(n);
2480        for (row, slot) in rows.into_iter().enumerate() {
2481            match slot {
2482                Some((t, cert)) => {
2483                    coords.row_mut(row).assign(&t);
2484                    certified.push(cert);
2485                }
2486                None => certified.push(false),
2487            }
2488        }
2489        Ok(EncodeResult::from_rows(coords, certified))
2490    }
2491
2492    /// LSH-routed AMORTIZED (distilled) encode — the production token-rate
2493    /// encoder of #1026 ladder item 3. Identical routing to
2494    /// [`Self::certified_encode_with_index`] (LSH proposes the best-aligned atom,
2495    /// the atlas routes to the in-atom nearest chart), but the in-atom encode is
2496    /// the closed-form per-chart Jacobian predictor + certificate gate of
2497    /// [`Self::amortized_encode_row`] rather than the certified Newton-refinement
2498    /// path.
2499    /// This is the deployment path: the distilled affine map produces the encode
2500    /// in one mat-vec, the Kantorovich certificate decides trust-or-fallback per
2501    /// row, and uncertified rows (the adversarial tail the thread expects to
2502    /// concentrate on rare tokens) are flagged for the exact multi-start solve —
2503    /// compute goes where the questions are. Row-independent against the frozen
2504    /// dictionary, so the batch fans out over rows with deterministic row-order
2505    /// assembly (bit-identical run-to-run).
2506    pub fn amortized_encode_with_index<S: AtomFrameSketch + Sync>(
2507        &self,
2508        atoms: &[SaeManifoldAtom],
2509        index: &SaeCandidateIndex,
2510        sketch: &S,
2511        targets: ArrayView2<'_, f64>,
2512        amplitudes: ArrayView1<'_, f64>,
2513        latent_dim: usize,
2514    ) -> Result<EncodeResult, String> {
2515        let n = targets.nrows();
2516        if amplitudes.len() != n {
2517            return Err(format!(
2518                "amortized_encode_with_index: amplitudes len {} != rows {n}",
2519                amplitudes.len()
2520            ));
2521        }
2522        let budget = auto_candidate_budget(atoms.len().max(1));
2523        let encode_rows =
2524            |range: std::ops::Range<usize>| -> Result<Vec<Option<(Array1<f64>, bool)>>, String> {
2525                range
2526                    .map(|row| {
2527                        // EXACT routing (#1777): global argmax of the routing score,
2528                        // not just the best LSH-gathered candidate (see
2529                        // certified_encode_with_index for the full rationale).
2530                        let Some(route) =
2531                            index.route_exact(sketch, targets.row(row), budget, true)
2532                        else {
2533                            return Ok(None);
2534                        };
2535                        let best_atom = route.atom;
2536                        // Fit-quality floor (not a routing-correctness gate; routing
2537                        // is exact): even the globally-best atom may fit a row poorly,
2538                        // and a NaN alignment is the zero-norm ‖d‖ = 0 row. Either way
2539                        // flag for the exact multi-start fallback. See
2540                        // CANDIDATE_ROUTING_MIN_ALIGNMENT.
2541                        if !route.alignment.is_finite()
2542                            || route.alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT
2543                        {
2544                            return Ok(None);
2545                        }
2546                        let atom = atoms.get(best_atom).ok_or_else(|| {
2547                            format!(
2548                                "amortized_encode_with_index: proposed atom {best_atom} out of range"
2549                            )
2550                        })?;
2551                        let (t, cert) = self.amortized_encode_row(
2552                            atom,
2553                            best_atom,
2554                            targets.row(row),
2555                            amplitudes[row],
2556                        )?;
2557                        if t.len() != latent_dim {
2558                            return Err(format!(
2559                                "amortized_encode_with_index: atom {best_atom} returned t.len()={} \
2560                                 but declared latent_dim={latent_dim}; heterogeneous-dim \
2561                                 dictionaries are not supported by this batched encode path",
2562                                t.len()
2563                            ));
2564                        }
2565                        Ok(Some((t, cert.certified())))
2566                    })
2567                    .collect()
2568            };
2569        let rows: Vec<Option<(Array1<f64>, bool)>> =
2570            if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
2571                use rayon::prelude::*;
2572                const CHUNK: usize = 256;
2573                let n_chunks = n.div_ceil(CHUNK);
2574                let chunked: Vec<Vec<Option<(Array1<f64>, bool)>>> = (0..n_chunks)
2575                    .into_par_iter()
2576                    .map(|c| {
2577                        let start = c * CHUNK;
2578                        let end = (start + CHUNK).min(n);
2579                        encode_rows(start..end)
2580                    })
2581                    .collect::<Result<_, _>>()?;
2582                chunked.into_iter().flatten().collect()
2583            } else {
2584                encode_rows(0..n)?
2585            };
2586        let mut coords = Array2::<f64>::zeros((n, latent_dim));
2587        let mut certified = Vec::with_capacity(n);
2588        for (row, slot) in rows.into_iter().enumerate() {
2589            match slot {
2590                Some((t, cert)) => {
2591                    coords.row_mut(row).assign(&t);
2592                    certified.push(cert);
2593                }
2594                None => certified.push(false),
2595            }
2596        }
2597        Ok(EncodeResult::from_rows(coords, certified))
2598    }
2599
2600    /// LSH-routed FAST amortized encode over the WHOLE dictionary — the
2601    /// multi-atom, corpus-rate analogue of [`Self::amortized_encode_with_index`].
2602    ///
2603    /// `amortized_encode_with_index` routes per row, then runs the per-row
2604    /// closed-form predictor + Kantorovich certificate + cold cross-check on each
2605    /// row independently. This fast variant keeps the SAME per-row EXACT routing
2606    /// (`index.route_exact` + the fit-quality floor), but replaces the per-row
2607    /// predictor with the GEMM-batched [`Self::amortized_encode_batch_fast`]:
2608    /// it GROUPS rows by their global-argmax atom and runs one batched affine-
2609    /// predictor pass per atom-group (a routing GEMM + a predictor GEMM each),
2610    /// reproducing a traditional SAE's whole-dictionary `W·x+b` throughput. No
2611    /// per-row certificate — this is the speed mode validated as accuracy-parity
2612    /// with the certified solve (`fast_forward_is_accuracy_parity_with_certified`).
2613    ///
2614    /// Returns the per-row latent coords and a valid-mask: `false` for a row with
2615    /// an empty dictionary, a sub-threshold/NaN routing alignment, or one the batched
2616    /// predictor could not fire on (no evaluator / singular Gauss–Newton block /
2617    /// non-finite-or-zero amplitude). Each row is written exactly once (disjoint
2618    /// per-atom groups), so the result is independent of group iteration order.
2619    pub fn amortized_encode_with_index_fast<S: AtomFrameSketch + Sync>(
2620        &self,
2621        atoms: &[SaeManifoldAtom],
2622        index: &SaeCandidateIndex,
2623        sketch: &S,
2624        targets: ArrayView2<'_, f64>,
2625        amplitudes: ArrayView1<'_, f64>,
2626        latent_dim: usize,
2627    ) -> Result<(Array2<f64>, Vec<bool>), String> {
2628        let n = targets.nrows();
2629        if amplitudes.len() != n {
2630            return Err(format!(
2631                "amortized_encode_with_index_fast: amplitudes len {} != rows {n}",
2632                amplitudes.len()
2633            ));
2634        }
2635        let budget = auto_candidate_budget(atoms.len().max(1));
2636        let mut coords = Array2::<f64>::zeros((n, latent_dim));
2637        let mut valid = vec![false; n];
2638        // ── Single allocation-free pass: route each row, apply the CACHED predictor.
2639        //
2640        // Routing sublinearity (massive-K, K≈32k): the certified path uses
2641        // `route_exact`, whose universal-bound LSH certificate only fires at the
2642        // alignment ceiling (≈1.0); for any real dictionary (`alignment < 1`) it
2643        // falls back to `brute_force_best_atom` — an O(K) full scan PER ROW, making
2644        // the encode O(N·K). This SPEED path takes the LSH gather's best-aligned atom
2645        // directly (`propose` scores only ~budget candidates → O(log K)); a rare miss
2646        // is caught by the fit-quality floor + the downstream certificate/exact
2647        // fallback (the documented speed/accuracy tradeoff).
2648        //
2649        // Allocation: the predictor uses only OFFLINE-cached atlas data (per-chart
2650        // `recon_center` for routing + `amortized_jacobian`/`amortized_base` for the
2651        // `t̂ = base + (1/z)·A₁·x` mat-vec), so NO per-row or per-atom heap buffer is
2652        // allocated. This replaces the old per-atom-group GEMM sub-batch — which at
2653        // `K ≫ N` degenerated to ONE row per group, so its per-group buffers (x_sub,
2654        // recon-centers, predictors) dominated and made the "fast" path allocation-
2655        // bound. Now the only per-row allocation is inside `index.propose`.
2656        for row in 0..n {
2657            let dir = targets.row(row);
2658            let proposal = index.propose(sketch, dir, budget, true);
2659            let Some(&best_atom) = proposal.proposed.first() else {
2660                continue; // nothing gathered (empty dictionary / probe-dim mismatch)
2661            };
2662            // Fit-quality floor: the best gathered atom still fits this row poorly,
2663            // or the alignment is NaN (zero-norm row) — flag for the exact fallback.
2664            let alignment = sketch.alignment(best_atom, dir);
2665            if !alignment.is_finite() || alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT {
2666                continue;
2667            }
2668            let atom = atoms.get(best_atom).ok_or_else(|| {
2669                format!("amortized_encode_with_index_fast: proposed atom {best_atom} out of range")
2670            })?;
2671            if atom.latent_dim != latent_dim {
2672                return Err(format!(
2673                    "amortized_encode_with_index_fast: atom {best_atom} latent_dim {} != declared \
2674                     {latent_dim}; heterogeneous-dim dictionaries are not supported by this path",
2675                    atom.latent_dim
2676                ));
2677            }
2678            let Some(atom_atlas) = self.atoms.get(best_atom) else {
2679                continue; // no atlas for this atom → predictor cannot fire (zeroed)
2680            };
2681            if amortized_predict_row(
2682                atom_atlas,
2683                dir,
2684                amplitudes[row],
2685                latent_dim,
2686                coords.row_mut(row),
2687            ) {
2688                valid[row] = true;
2689            }
2690        }
2691        Ok((coords, valid))
2692    }
2693
2694    /// LSH-routed FAST full forward over the WHOLE dictionary: encode → decode,
2695    /// the multi-atom analogue of [`Self::amortized_reconstruct_batch_fast`]. Same
2696    /// sublinear per-row routing + per-atom grouping as
2697    /// [`Self::amortized_encode_with_index_fast`], but each group is run through
2698    /// the batched reconstruct (`m(t̂) = z·Φ(t̂)·B`) so the result is the per-row
2699    /// reconstruction in the ambient space. Rows that do not route/predict decode
2700    /// to an exact zero reconstruction and are flagged `false`.
2701    pub fn amortized_reconstruct_with_index_fast<S: AtomFrameSketch + Sync>(
2702        &self,
2703        atoms: &[SaeManifoldAtom],
2704        index: &SaeCandidateIndex,
2705        sketch: &S,
2706        targets: ArrayView2<'_, f64>,
2707        amplitudes: ArrayView1<'_, f64>,
2708    ) -> Result<(Array2<f64>, Vec<bool>), String> {
2709        let n = targets.nrows();
2710        let p = targets.ncols();
2711        if amplitudes.len() != n {
2712            return Err(format!(
2713                "amortized_reconstruct_with_index_fast: amplitudes len {} != rows {n}",
2714                amplitudes.len()
2715            ));
2716        }
2717        let budget = auto_candidate_budget(atoms.len().max(1));
2718        // SUBLINEAR routing for the SPEED-mode full forward — mirror
2719        // `amortized_encode_with_index_fast`: take the LSH gather's best-aligned
2720        // atom (O(budget) candidates) instead of route_exact's O(K) full-scan
2721        // certification, keeping the whole fast encode→decode sublinear in K at
2722        // K=32k. The gather's best is the exact argmax on the vast majority of rows;
2723        // rare misses are caught by the fit-quality floor + downstream fallback.
2724        let mut groups: std::collections::HashMap<usize, Vec<usize>> =
2725            std::collections::HashMap::new();
2726        for row in 0..n {
2727            let dir = targets.row(row);
2728            let proposal = index.propose(sketch, dir, budget, true);
2729            let Some(&best_atom) = proposal.proposed.first() else {
2730                continue; // nothing gathered (empty dictionary / probe-dim mismatch)
2731            };
2732            let alignment = sketch.alignment(best_atom, dir);
2733            if !alignment.is_finite() || alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT {
2734                continue;
2735            }
2736            groups.entry(best_atom).or_default().push(row);
2737        }
2738
2739        let mut recon = Array2::<f64>::zeros((n, p));
2740        let mut valid = vec![false; n];
2741        for (atom_idx, rows_here) in groups {
2742            let atom = atoms.get(atom_idx).ok_or_else(|| {
2743                format!(
2744                    "amortized_reconstruct_with_index_fast: proposed atom {atom_idx} out of range"
2745                )
2746            })?;
2747            if atom.output_dim() != p {
2748                return Err(format!(
2749                    "amortized_reconstruct_with_index_fast: atom {atom_idx} output_dim {} != target \
2750                     dim {p}",
2751                    atom.output_dim()
2752                ));
2753            }
2754            let mut x_sub = Array2::<f64>::zeros((rows_here.len(), p));
2755            let mut amp_sub = Array1::<f64>::zeros(rows_here.len());
2756            for (i, &row) in rows_here.iter().enumerate() {
2757                x_sub.row_mut(i).assign(&targets.row(row));
2758                amp_sub[i] = amplitudes[row];
2759            }
2760            let (sub_recon, sub_valid) = self.amortized_reconstruct_batch_fast(
2761                atom,
2762                atom_idx,
2763                x_sub.view(),
2764                amp_sub.view(),
2765            )?;
2766            for (i, &row) in rows_here.iter().enumerate() {
2767                if sub_valid[i] {
2768                    recon.row_mut(row).assign(&sub_recon.row(i));
2769                    valid[row] = true;
2770                }
2771            }
2772        }
2773        Ok((recon, valid))
2774    }
2775}
2776
2777/// Offline `β = 1/λ_min(H_GN)` at a chart center from the Gauss-Newton block
2778/// `H_GN = J_mᵀ J_m` (residual-free). The offline `β` bounds the curvature the
2779/// online certificate sees: charts are placed where the encode lands, so the
2780/// representative residual is small and `H_GN` is the dominant, residual-free
2781/// curvature estimate. (The online per-row certificate still uses the FULL
2782/// Hessian; this is only the offline radius-sizing curvature.) Returns `None`
2783/// for a degenerate center (`λ_min ≤ 0`), which marks an uncertifiable chart.
2784pub(crate) fn center_beta(atom: &SaeManifoldAtom, center: &Array1<f64>, ridge: f64) -> Option<f64> {
2785    let evaluator = atom.basis_evaluator.as_ref()?.clone();
2786    let d = atom.latent_dim;
2787    let p = atom.output_dim();
2788    let m = atom.basis_size();
2789    let coords = center.view().to_shape((1, d)).ok()?.to_owned();
2790    let (_phi, jet) = evaluator.evaluate(coords.view()).ok()?;
2791    let decoder = &atom.decoder_coefficients;
2792    // J_m[axis] = Bᵀ (∂Φ/∂t_axis) ∈ ℝᵖ (amplitude-1; curvature scales with z²
2793    // and is absorbed conservatively by the amplitude-bounded Lipschitz term).
2794    let mut jm = Array2::<f64>::zeros((d, p));
2795    for axis in 0..d {
2796        for basis_col in 0..m {
2797            let dphi = jet[[0, basis_col, axis]];
2798            if dphi == 0.0 {
2799                continue;
2800            }
2801            for out in 0..p {
2802                jm[[axis, out]] += dphi * decoder[[basis_col, out]];
2803            }
2804        }
2805    }
2806    let mut h = Array2::<f64>::zeros((d, d));
2807    for a in 0..d {
2808        for b in 0..d {
2809            h[[a, b]] = jm.row(a).dot(&jm.row(b));
2810        }
2811        h[[a, a]] += ridge;
2812    }
2813    let (vals, _vecs) = h.eigh(Side::Lower).ok()?;
2814    let lambda_min = vals.iter().cloned().fold(f64::INFINITY, f64::min);
2815    if lambda_min.is_finite() && lambda_min > 0.0 {
2816        Some(1.0 / lambda_min)
2817    } else {
2818        None
2819    }
2820}
2821
2822/// #1154 — the amortized encoder's closed-form warm-start coordinate for one
2823/// row `x` against one chart at amplitude `z`:
2824///
2825/// ```text
2826/// t̂ = t_c + (1/z) · A₁ · (x − z · m₁(t_c)),
2827/// ```
2828///
2829/// a single `O(d·p)` mat-vec from the chart's precomputed IFT Jacobian `A₁` and
2830/// center reconstruction `m₁`. Returns `None` when the chart carries no
2831/// distilled Jacobian (singular Gauss–Newton block) or the amplitude is not
2832/// strictly positive and finite (a near-inactive atom, where the
2833/// amplitude-divided map is undefined) — in those cases the caller starts from
2834/// the chart center instead. Shared by the amortized encode (where `t̂` is the
2835/// prediction) and the exact certified encode (where `t̂` is the Newton
2836/// warm-start that then refines to stationarity, Design A).
2837pub(crate) fn amortized_warm_start(
2838    chart: &CertifiedChart,
2839    x: ArrayView1<'_, f64>,
2840    amplitude: f64,
2841) -> Option<Array1<f64>> {
2842    let a1 = chart.amortized_jacobian.as_ref()?;
2843    if !(amplitude.is_finite() && amplitude.abs() > 0.0) {
2844        return None;
2845    }
2846    let d = a1.nrows();
2847    let mut t_hat = chart.region.center.clone();
2848    for (out_idx, &m1_out) in chart.recon_center.iter().enumerate().take(a1.ncols()) {
2849        let resid = x[out_idx] - amplitude * m1_out;
2850        for axis in 0..d {
2851            t_hat[axis] += a1[[axis, out_idx]] * resid / amplitude;
2852        }
2853    }
2854    Some(t_hat)
2855}
2856
2857/// Single-row amortized predictor against ONE atom's cached atlas, writing the
2858/// encoded latent coordinate DIRECTLY into `out` (length `d`) with NO heap
2859/// allocation. Routes to the nearest certifiable chart by cached center-
2860/// reconstruction distance `‖x − m(t_c)‖²` (matching `amortized_encode_batch_fast`'s
2861/// per-chart routing), then applies that chart's precomputed affine predictor
2862/// `t̂ = base + (1/z)·A₁·x` (`base = t_c − A₁·m₁` is `chart.amortized_base`).
2863///
2864/// Returns `false` — leaving `out` at its incoming (zeroed) value — for exactly the
2865/// rows `amortized_encode_batch_fast` would flag: no certifiable chart, a nearest
2866/// chart with no distilled predictor (singular Gauss–Newton block), or an unusable
2867/// amplitude. This is the per-row core of the allocation-free massive-K fast encode.
2868pub(crate) fn amortized_predict_row(
2869    atom_atlas: &AtomEncodeAtlas,
2870    x: ArrayView1<'_, f64>,
2871    amplitude: f64,
2872    d: usize,
2873    mut out: ndarray::ArrayViewMut1<'_, f64>,
2874) -> bool {
2875    if !(amplitude.is_finite() && amplitude.abs() > 0.0) {
2876        return false;
2877    }
2878    // Nearest certifiable chart by the TRUE objective ‖x − z·recon_center‖²
2879    // (F1; `z = amplitude`, finite and non-zero by the guard above). First-wins on
2880    // ties (strict `<`) — same amplitude-scaled argmin as
2881    // `amortized_encode_batch_fast`'s route_idx.
2882    let mut best_ci: Option<usize> = None;
2883    let mut best_dist = f64::INFINITY;
2884    for (ci, chart) in atom_atlas.charts.iter().enumerate() {
2885        if chart.certified_radius <= 0.0 {
2886            continue;
2887        }
2888        let mut dist = 0.0;
2889        for (r, xv) in chart.recon_center.iter().zip(x.iter()) {
2890            let diff = amplitude * r - xv;
2891            dist += diff * diff;
2892        }
2893        if dist < best_dist {
2894            best_dist = dist;
2895            best_ci = Some(ci);
2896        }
2897    }
2898    let Some(ci) = best_ci else {
2899        return false;
2900    };
2901    let chart = &atom_atlas.charts[ci];
2902    // The nearest chart must carry a distilled predictor; otherwise flag (zeroed),
2903    // exactly as the per-chart `None` branch of `amortized_encode_batch_fast`.
2904    let (Some(a1), Some(base)) = (
2905        chart.amortized_jacobian.as_ref(),
2906        chart.amortized_base.as_ref(),
2907    ) else {
2908        return false;
2909    };
2910    let inv_z = 1.0 / amplitude;
2911    for axis in 0..d {
2912        out[axis] = base[axis] + a1.row(axis).dot(&x) * inv_z;
2913    }
2914    true
2915}
2916
2917/// The amplitude-1 distilled amortized-encoder Jacobian at a chart center
2918/// (#1026 ladder item 3). Returns `(A₁, m₁)` where `m₁ = BᵀΦ(t_c) ∈ ℝᵖ` is the
2919/// amplitude-1 center reconstruction and `A₁ = (J₁ᵀJ₁ + ridge·I)⁻¹ J₁ ∈ ℝ^{d×p}`
2920/// is the implicit-function-theorem derivative of the encode map `x ↦ t`
2921/// (Gauss–Newton block — the residual-free, dominant curvature exactly as the
2922/// offline radius-sizing `β`). With these, the online encode of a row `x` at
2923/// amplitude `z` is the closed-form affine prediction
2924/// `t = t_c + (1/z)·A₁·(x − z·m₁)` — one mat-vec, no per-row factorization.
2925/// `None` when the basis has no jet or the Gauss–Newton block is singular (no
2926/// certifiable amortization), matching `center_beta`'s gate so a chart with a
2927/// finite `β` always carries a Jacobian and vice versa.
2928pub(crate) fn center_amortized_jacobian(
2929    atom: &SaeManifoldAtom,
2930    center: &Array1<f64>,
2931    ridge: f64,
2932) -> Option<(Array2<f64>, Array1<f64>)> {
2933    let evaluator = atom.basis_evaluator.as_ref()?.clone();
2934    let d = atom.latent_dim;
2935    let p = atom.output_dim();
2936    let m = atom.basis_size();
2937    let coords = center.view().to_shape((1, d)).ok()?.to_owned();
2938    let (phi, jet) = evaluator.evaluate(coords.view()).ok()?;
2939    let decoder = &atom.decoder_coefficients;
2940    // m₁(t_c) = BᵀΦ(t_c) ∈ ℝᵖ (amplitude-1 center reconstruction).
2941    let mut recon = Array1::<f64>::zeros(p);
2942    for basis_col in 0..m {
2943        let phi_v = phi[[0, basis_col]];
2944        if phi_v == 0.0 {
2945            continue;
2946        }
2947        for out in 0..p {
2948            recon[out] += phi_v * decoder[[basis_col, out]];
2949        }
2950    }
2951    // J₁[axis] = Bᵀ (∂Φ/∂t_axis) ∈ ℝᵖ (amplitude-1; z factors out analytically).
2952    let mut jm = Array2::<f64>::zeros((d, p));
2953    for axis in 0..d {
2954        for basis_col in 0..m {
2955            let dphi = jet[[0, basis_col, axis]];
2956            if dphi == 0.0 {
2957                continue;
2958            }
2959            for out in 0..p {
2960                jm[[axis, out]] += dphi * decoder[[basis_col, out]];
2961            }
2962        }
2963    }
2964    // H_GN = J₁ J₁ᵀ + ridge·I ∈ ℝ^{d×d}.
2965    let mut h = Array2::<f64>::zeros((d, d));
2966    for a in 0..d {
2967        for b in 0..d {
2968            h[[a, b]] = jm.row(a).dot(&jm.row(b));
2969        }
2970        h[[a, a]] += ridge;
2971    }
2972    let (vals, vecs) = h.eigh(Side::Lower).ok()?;
2973    let lambda_min = vals.iter().cloned().fold(f64::INFINITY, f64::min);
2974    if !(lambda_min.is_finite() && lambda_min > 0.0) {
2975        return None;
2976    }
2977    // A₁ = H_GN⁻¹ J₁ via the eigendecomposition: H⁻¹ = Σ_i (1/λᵢ) vᵢ vᵢᵀ, so
2978    // A₁[:, out] = Σ_i (vᵢ · J₁[:, out]) / λᵢ · vᵢ. Column-by-column keeps it the
2979    // d×p Jacobian (one SPD solve reused across all p output channels).
2980    let mut a1 = Array2::<f64>::zeros((d, p));
2981    for out in 0..p {
2982        let jcol = jm.column(out);
2983        for (i, &lam) in vals.iter().enumerate() {
2984            if !(lam.is_finite() && lam > 0.0) {
2985                return None;
2986            }
2987            let vi = vecs.column(i);
2988            let coeff = vi.dot(&jcol) / lam;
2989            for row in 0..d {
2990                a1[[row, out]] += coeff * vi[row];
2991            }
2992        }
2993    }
2994    Some((a1, recon))
2995}
2996
2997/// Route a target row to the nearest chart of an atom by reconstruction
2998/// distance: the chart whose center reconstruction `m(t_c)` is closest to `x`.
2999/// Returns the chart index and the distance, or `None` when the atom has no
3000/// charts.
3001pub(crate) fn nearest_chart(
3002    atom_atlas: &AtomEncodeAtlas,
3003    x: ArrayView1<'_, f64>,
3004    amplitude: f64,
3005) -> Option<(usize, f64)> {
3006    if atom_atlas.charts.is_empty() {
3007        return None;
3008    }
3009    let mut best: Option<(usize, f64)> = None;
3010    for (idx, chart) in atom_atlas.charts.iter().enumerate() {
3011        if chart.certified_radius <= 0.0 {
3012            continue;
3013        }
3014        // F1: score the TRUE encode objective `‖x − z·m(t_c)‖²` at the row's
3015        // amplitude `z`, NOT the amplitude-1 `‖x − m(t_c)‖²`. The distilled
3016        // `chart.recon_center` is the amplitude-1 reconstruction `m₁(t_c)`; the
3017        // reconstruction actually compared against `x` is `z·m₁(t_c)`. Routing on
3018        // the amplitude-1 centers picks the wrong chart whenever `z ≠ 1` (e.g. a
3019        // small `z` makes a large-norm center reconstruct near a small `x`).
3020        // Distance accumulated in place, no temporary array.
3021        let mut dist = 0.0;
3022        for (r, xv) in chart.recon_center.iter().zip(x.iter()) {
3023            let diff = amplitude * r - xv;
3024            dist += diff * diff;
3025        }
3026        if best.map(|(_, b)| dist < b).unwrap_or(true) {
3027            best = Some((idx, dist));
3028        }
3029    }
3030    best
3031}
3032
3033/// The `k` charts whose CENTER reconstruction `m(t_c)` is nearest to `x` in
3034/// ambient ‖·‖², returned as chart indices sorted by increasing distance (ties
3035/// broken by chart index — deterministic). Only certifiable charts
3036/// (`certified_radius > 0`) are considered, exactly like [`nearest_chart`], whose
3037/// single result is `nearest_charts_topk(.., 1)[0]`. Used by the certified encode
3038/// to refine the global basin on self-approaching atoms (see
3039/// [`CERTIFIED_ROUTING_TOPK`]).
3040pub(crate) fn nearest_charts_topk(
3041    atom_atlas: &AtomEncodeAtlas,
3042    x: ArrayView1<'_, f64>,
3043    amplitude: f64,
3044    k: usize,
3045) -> Vec<usize> {
3046    if atom_atlas.charts.is_empty() || k == 0 {
3047        return Vec::new();
3048    }
3049    let mut scored: Vec<(usize, f64)> = Vec::new();
3050    for (idx, chart) in atom_atlas.charts.iter().enumerate() {
3051        if chart.certified_radius <= 0.0 {
3052            continue;
3053        }
3054        // `m₁(t_c) = BᵀΦ(t_c)` is an OFFLINE per-chart constant already distilled
3055        // into `chart.recon_center` at build time (bit-for-bit the same φ·decoder
3056        // accumulation this used to recompute). Reuse it instead of re-evaluating
3057        // the basis at a fixed center for every row — that re-eval was the encode's
3058        // dominant per-row cost (charts × rows basis evals, each allocating the φ/jet
3059        // arrays). F1: score the amplitude-scaled `‖x − z·m₁(t_c)‖²` (the true
3060        // objective), not the amplitude-1 `‖m₁(t_c) − x‖²`. Computed in place.
3061        let mut dist = 0.0;
3062        for (r, xv) in chart.recon_center.iter().zip(x.iter()) {
3063            let diff = amplitude * r - xv;
3064            dist += diff * diff;
3065        }
3066        scored.push((idx, dist));
3067    }
3068    // Sort by distance, then chart index for a deterministic, first-wins order
3069    // consistent with `nearest_chart`'s strict-`<` tie rule.
3070    scored.sort_by(|a, b| {
3071        a.1.partial_cmp(&b.1)
3072            .unwrap_or(std::cmp::Ordering::Equal)
3073            .then(a.0.cmp(&b.0))
3074    });
3075    scored.into_iter().take(k).map(|(idx, _)| idx).collect()
3076}
3077
3078/// Reconstruction error `‖x − z·m(t)‖` of an encoded coordinate `t` — the
3079/// criterion the certified encode minimizes over its candidate charts to pick the
3080/// GLOBAL basin. `m(t) = Bᵀ Φ(t)` is the amplitude-1 reconstruction; `z` is the
3081/// amplitude. A non-finite reconstruction returns `+∞` so it never wins.
3082pub(crate) fn encode_reconstruction_error(
3083    atom: &SaeManifoldAtom,
3084    evaluator: &dyn SaeBasisEvaluator,
3085    coord: ArrayView1<'_, f64>,
3086    x: ArrayView1<'_, f64>,
3087    amplitude: f64,
3088) -> f64 {
3089    let d = atom.latent_dim;
3090    let p = atom.output_dim();
3091    let m = atom.basis_size();
3092    let coords = match coord.to_shape((1, d)) {
3093        Ok(c) => c.to_owned(),
3094        Err(_) => return f64::INFINITY,
3095    };
3096    let Ok((phi, _jet)) = evaluator.evaluate(coords.view()) else {
3097        return f64::INFINITY;
3098    };
3099    let mut err2 = 0.0;
3100    for out in 0..p {
3101        let mut recon = 0.0;
3102        for basis_col in 0..m {
3103            recon += phi[[0, basis_col]] * atom.decoder_coefficients[[basis_col, out]];
3104        }
3105        let r = x[out] - amplitude * recon;
3106        err2 += r * r;
3107    }
3108    if err2.is_finite() {
3109        err2.sqrt()
3110    } else {
3111        f64::INFINITY
3112    }
3113}
3114
/// Upper bound on the number of grid chart centers laid down for one atom.
///
/// This is the SHAPE_BAND grid point cap and mirrors the constant of the same name
/// in the atom band machinery. The per-axis resolution is reduced until the product
/// grid fits under this cap, with a floor of 2 points per axis.
pub(crate) const SHAPE_BAND_MAX_POINTS: usize = 512;
3118
3119/// Lay down chart centers on an atom's coordinate grid (the SHAPE_BAND grid
3120/// idiom): a regular grid spanning the compact latent domain for periodic /
3121/// sphere / torus atoms, and a strided cover of the latent axes for unbounded
3122/// (Duchon / Euclidean) atoms.
3123///
3124/// Periodic / torus latents are fractions of one period, so the per-axis grid
3125/// spans `[0, 1)`; the sphere chart spans `lat ∈ [−π/2, π/2]`, `lon ∈ [−π, π)`.
3126/// These conventions match the basis evaluators (the fraction-of-period circle
3127/// harmonic and the lat/lon sphere chart).
3128/// Squared coordinate distance between two latent points under the atom's chart
3129/// geometry: per-axis WRAPPED distance `min(|a−b|, period−|a−b|)` on periodic
3130/// (circle) axes — period 1 to match `chart_center_grid`'s `[0,1)` torus tiling
3131/// — and plain difference on line axes. Used to place + size data-driven charts.
3132pub(crate) fn coord_dist_sq(
3133    atom: &SaeManifoldAtom,
3134    a: ArrayView1<'_, f64>,
3135    b: ArrayView1<'_, f64>,
3136) -> f64 {
3137    // Per-axis period comes from the ONE canonical source `latent_axis_period` —
3138    // the SAME convention the certified-encode `in_chart` guard uses (via
3139    // `latent_coordinate_distance`): period-1 fraction axes for periodic/torus and
3140    // the cylinder angle, period-2π on the sphere LONGITUDE (axis 1), and
3141    // NON-periodic otherwise — including the sphere LATITUDE (axis 0), which ranges
3142    // over [−π/2, π/2] and must NOT wrap. A former local `periodic_axis` closure
3143    // wrapped BOTH sphere axes at period 1, so two genuinely-far radian longitudes
3144    // (e.g. 0 and 3 rad) collapsed to distance 0 — corrupting the farthest-point
3145    // placement AND the nearest-neighbour radii in `data_driven_chart_centers` for
3146    // sphere atoms, and disagreeing with the metric the encode's `in_chart`
3147    // soundness guard enforces. Delegating keeps the two metrics identical.
3148    let mut acc = 0.0;
3149    for axis in 0..a.len().min(b.len()) {
3150        let mut d = (a[axis] - b[axis]).abs();
3151        if let Some(period) = latent_axis_period(atom, axis) {
3152            let wrapped = d.rem_euclid(period);
3153            d = wrapped.min(period - wrapped);
3154        }
3155        acc += d * d;
3156    }
3157    acc
3158}
3159
3160/// Greedy farthest-point sampling of up to `max_charts` chart centers from the
3161/// atom's latent `coords` (n × d), with each center's nominal radius set to half
3162/// the distance to its nearest neighbor center (floored, so a singleton/coincident
3163/// cluster still gets a usable ball). Deterministic: seeds from row 0, then
3164/// repeatedly adds the coord maximally far (under [`coord_dist_sq`]) from the
3165/// chosen set — coverage-maximizing and reproducible run-to-run.
3166pub(crate) fn data_driven_chart_centers(
3167    atom: &SaeManifoldAtom,
3168    coords: ArrayView2<'_, f64>,
3169    max_charts: usize,
3170) -> Result<(Array2<f64>, Vec<f64>), String> {
3171    let n = coords.nrows();
3172    let d = coords.ncols();
3173    if d != atom.latent_dim {
3174        return Err(format!(
3175            "data_driven_chart_centers: coords have {d} cols but atom latent_dim is {}",
3176            atom.latent_dim
3177        ));
3178    }
3179    if n == 0 {
3180        return Ok((Array2::<f64>::zeros((0, d)), Vec::new()));
3181    }
3182    let k = max_charts.min(n);
3183    // Farthest-point sampling: maintain each row's distance to the nearest chosen
3184    // center, add the row with the maximum such distance each step.
3185    let mut chosen: Vec<usize> = Vec::with_capacity(k);
3186    chosen.push(0);
3187    let mut nearest_sq: Vec<f64> = (0..n)
3188        .map(|r| coord_dist_sq(atom, coords.row(r), coords.row(0)))
3189        .collect();
3190    while chosen.len() < k {
3191        // Pick the row farthest from the current center set (first-wins tie).
3192        let mut best = 0usize;
3193        let mut best_d = -1.0;
3194        for r in 0..n {
3195            if nearest_sq[r] > best_d {
3196                best_d = nearest_sq[r];
3197                best = r;
3198            }
3199        }
3200        if best_d <= 0.0 {
3201            break; // all remaining rows coincide with a chosen center.
3202        }
3203        chosen.push(best);
3204        for r in 0..n {
3205            let dr = coord_dist_sq(atom, coords.row(r), coords.row(best));
3206            if dr < nearest_sq[r] {
3207                nearest_sq[r] = dr;
3208            }
3209        }
3210    }
3211    let m = chosen.len();
3212    let mut centers = Array2::<f64>::zeros((m, d));
3213    for (i, &row) in chosen.iter().enumerate() {
3214        centers.row_mut(i).assign(&coords.row(row));
3215    }
3216    // Per-center radius = half the nearest-OTHER-center distance, floored so a
3217    // coincident pair still yields a positive ball, capped at 0.5 (the largest
3218    // meaningful half-period on a unit circle).
3219    let mut radii = vec![0.0_f64; m];
3220    for i in 0..m {
3221        let mut nn = f64::INFINITY;
3222        for j in 0..m {
3223            if i == j {
3224                continue;
3225            }
3226            let dsq = coord_dist_sq(atom, centers.row(i), centers.row(j));
3227            if dsq < nn {
3228                nn = dsq;
3229            }
3230        }
3231        let r = if nn.is_finite() { 0.5 * nn.sqrt() } else { 0.5 };
3232        radii[i] = r.max(1.0e-3).min(0.5);
3233    }
3234    Ok((centers, radii))
3235}
3236
3237pub(crate) fn chart_center_grid(atom: &SaeManifoldAtom, resolution: usize) -> Array2<f64> {
3238    use crate::manifold::SaeAtomBasisKind::*;
3239    let d = atom.latent_dim;
3240    match &atom.basis_kind {
3241        Periodic | Torus => regular_product_grid(d, resolution, 0.0, 1.0, false),
3242        // Cylinder `S¹ × ℝ`: axis 0 is the periodic circle `[0, 1)` (no
3243        // endpoint, like the harmonic axes); axis 1 is the unbounded line,
3244        // covered by a strided unit box `[-0.5, 0.5]` about the origin (like the
3245        // Euclidean patch). The certified radius refines each chart; out-of-cover
3246        // line starts route to the exact fallback honestly.
3247        Cylinder if d == 2 => cylinder_chart_center_grid(resolution),
3248        Cylinder => regular_product_grid(d, resolution, -0.5, 0.5, true),
3249        Sphere if d == 2 => sphere_latlon_grid(resolution),
3250        Linear | Sphere | Duchon | EuclideanPatch | Poincare | Precomputed(_) => {
3251            // Unbounded / non-compact latents: a strided cover of a unit box
3252            // about the origin per axis. The certified radius refines each chart;
3253            // out-of-cover starts route to the exact fallback honestly.
3254            regular_product_grid(d, resolution, -0.5, 0.5, true)
3255        }
3256    }
3257}
3258
3259/// A regular `resolution`-per-axis product grid over `[lo, hi]^d`, capped at
3260/// [`SHAPE_BAND_MAX_POINTS`] total points (the per-axis resolution is reduced
3261/// until the product fits). When `include_endpoint` the last grid point sits at
3262/// `hi`; otherwise the axis is treated as periodic and stops one step short.
3263/// Per-axis resolution actually used by [`regular_product_grid`] after the
3264/// [`SHAPE_BAND_MAX_POINTS`] product cap. Chart radii must be derived from THIS
3265/// (not the raw `resolution`), otherwise for `resolution^d > SHAPE_BAND_MAX_POINTS`
3266/// the grid spacing is coarser than the radius and the charts leave gaps.
3267pub(crate) fn capped_per_axis(d: usize, resolution: usize) -> usize {
3268    let mut per_axis = resolution.max(2);
3269    while per_axis.saturating_pow(d as u32) > SHAPE_BAND_MAX_POINTS && per_axis > 2 {
3270        per_axis -= 1;
3271    }
3272    per_axis
3273}
3274
3275pub(crate) fn regular_product_grid(
3276    d: usize,
3277    resolution: usize,
3278    lo: f64,
3279    hi: f64,
3280    include_endpoint: bool,
3281) -> Array2<f64> {
3282    if d == 0 {
3283        return Array2::<f64>::zeros((1, 0));
3284    }
3285    let per_axis = capped_per_axis(d, resolution);
3286    let total = per_axis.saturating_pow(d as u32).max(1);
3287    let denom = if include_endpoint {
3288        (per_axis.max(2) - 1) as f64
3289    } else {
3290        per_axis as f64
3291    };
3292    let mut grid = Array2::<f64>::zeros((total, d));
3293    let mut idx = vec![0usize; d];
3294    for flat in 0..total {
3295        for axis in 0..d {
3296            let frac = idx[axis] as f64 / denom;
3297            grid[[flat, axis]] = lo + (hi - lo) * frac;
3298        }
3299        for axis in (0..d).rev() {
3300            idx[axis] += 1;
3301            if idx[axis] < per_axis {
3302                break;
3303            }
3304            idx[axis] = 0;
3305        }
3306    }
3307    grid
3308}
3309
3310/// Lat/lon sphere chart grid: `lat ∈ [−π/2, π/2]`, `lon ∈ [−π, π)`, matching
3311/// the [`crate::manifold::SphereChartEvaluator`] convention.
3312pub(crate) fn sphere_latlon_grid(resolution: usize) -> Array2<f64> {
3313    use std::f64::consts::PI;
3314    // Per-axis cap derived from the shared point budget rather than a hardcoded
3315    // literal (#2071): the largest r with r² ≤ SHAPE_BAND_MAX_POINTS. Equals 22
3316    // at the current budget (22²=484≤512), but now tracks the budget if it
3317    // changes instead of silently desyncing from the sibling grid arms.
3318    let r_cap = SHAPE_BAND_MAX_POINTS.isqrt();
3319    let r = resolution.max(2).min(r_cap);
3320    let mut grid = Array2::<f64>::zeros((r * r, 2));
3321    for i in 0..r {
3322        let lat = -PI / 2.0 + PI * (i as f64 + 0.5) / r as f64;
3323        for j in 0..r {
3324            let lon = -PI + 2.0 * PI * (j as f64) / r as f64;
3325            grid[[i * r + j, 0]] = lat;
3326            grid[[i * r + j, 1]] = lon;
3327        }
3328    }
3329    grid
3330}
3331
3332/// Cylinder `S¹ × ℝ` chart-center grid: axis 0 sweeps the periodic circle over
3333/// one period `[0, 1)` (no endpoint, matching the harmonic axis), axis 1 strides
3334/// a unit box `[−0.5, 0.5]` about the origin on the unbounded line (with
3335/// endpoint). Capped at [`SHAPE_BAND_MAX_POINTS`] total centers.
3336pub(crate) fn cylinder_chart_center_grid(resolution: usize) -> Array2<f64> {
3337    let mut per_axis = resolution.max(2);
3338    while per_axis * per_axis > SHAPE_BAND_MAX_POINTS && per_axis > 2 {
3339        per_axis -= 1;
3340    }
3341    let total = per_axis * per_axis;
3342    let line_denom = (per_axis.max(2) - 1) as f64;
3343    let mut grid = Array2::<f64>::zeros((total, 2));
3344    for i in 0..per_axis {
3345        // Periodic axis 0: stop one step short of the period.
3346        let circle = i as f64 / per_axis as f64;
3347        for j in 0..per_axis {
3348            // Line axis 1: include the endpoint of the unit box.
3349            let line = -0.5 + (j as f64) / line_denom;
3350            grid[[i * per_axis + j, 0]] = circle;
3351            grid[[i * per_axis + j, 1]] = line;
3352        }
3353    }
3354    grid
3355}
3356
3357/// Nominal in-chart radius: half the inter-center grid spacing, so charts tile
3358/// the domain. For compact latents this is the grid step; for unbounded latents
3359/// a unit default that the certified radius refines.
3360pub(crate) fn chart_nominal_radius(atom: &SaeManifoldAtom, resolution: usize) -> f64 {
3361    use crate::manifold::SaeAtomBasisKind::*;
3362    match &atom.basis_kind {
3363        Periodic | Torus => 0.5 / (capped_per_axis(atom.latent_dim, resolution) as f64),
3364        // Must use the SAME capped per-axis count `min(resolution, 22)` that
3365        // `sphere_latlon_grid` lays the centers on: the coarsest tiling step is
3366        // the longitude half-spacing `π/r` with `r = min(resolution, 22)`. Deriving
3367        // the radius from the RAW resolution makes it smaller than the grid spacing
3368        // once `resolution > 22`, leaving gaps between charts so rows in the gaps
3369        // spuriously route to the exact fallback (the exact hazard `capped_per_axis`
3370        // documents for the regular grid).
3371        Sphere => std::f64::consts::PI / (resolution.max(2).min(22) as f64),
3372        // Cylinder charts tile two heterogeneous axes (a `[0,1)` periodic step
3373        // and a unit-box line step); the chart radius is a single scalar, so we
3374        // take the tighter (periodic) step `0.5/res` to keep every chart valid
3375        // on both axes. The certified Kantorovich radius refines it per chart.
3376        Cylinder => 0.5 / (capped_per_axis(atom.latent_dim, resolution) as f64),
3377        Linear | Duchon | EuclideanPatch | Poincare | Precomputed(_) => {
3378            1.0 / (resolution.max(2) as f64)
3379        }
3380    }
3381}
3382
3383/// Build the [`ChartRegion`] for a center, attaching the radial r_min / r_max
3384/// bracket for Duchon atoms (the chart's distance range to the kernel centers).
3385pub(crate) fn chart_region(
3386    atom: &SaeManifoldAtom,
3387    center: Array1<f64>,
3388    radius: f64,
3389) -> ChartRegion {
3390    use crate::manifold::SaeAtomBasisKind::*;
3391    let region = ChartRegion::new(center.clone(), radius);
3392    match &atom.basis_kind {
3393        Duchon => {
3394            // r ranges over [‖t_c‖ − radius, ‖t_c‖ + radius] about the single
3395            // origin-anchored center used by the conservative radial bound.
3396            //
3397            // The lower bound must be `max(0, center_norm − radius)` — NOT floored
3398            // at `radius`. When the chart contains the kernel center
3399            // (`center_norm < radius`, true r_min = 0), flooring at `radius`
3400            // would give a finite, NON-CONSERVATIVE `r_min`, causing the
3401            // hessian_sup / third_sup formulas (which divide by r_min) to
3402            // underestimate the Lipschitz constant and potentially grant a false
3403            // Kantorovich certificate. Flooring at `f64::MIN_POSITIVE` instead
3404            // correctly drives the formulas toward ∞, producing a very large L
3405            // that will NEVER certify (rows route to the exact multi-start
3406            // fallback) — conservative and sound.
3407            let center_norm = center.dot(&center).sqrt();
3408            let r_min = (center_norm - radius).max(f64::MIN_POSITIVE);
3409            let r_max = center_norm + radius;
3410            region.with_radial_bounds(r_min, r_max)
3411        }
3412        // Cylinder has no radial kernel block (it is a harmonic × polynomial
3413        // tensor, not a Duchon radial basis), so it needs no radial r_min/r_max.
3414        Periodic | Sphere | Torus | Cylinder | Linear | EuclideanPatch | Poincare
3415        | Precomputed(_) => region,
3416    }
3417}
3418
#[cfg(test)]
mod encode_fix_tests {
    //! Regression tests for fixes in the certified-encode path.
    //!
    //! `certify_with_basin_warmup` fixes:
    //!   FIX #3 — wrap-aware chart containment (`in_chart` now measures latent
    //!            distance on the atom's periodic geometry via
    //!            [`latent_coordinate_distance`]).
    //!   FIX #4 — multiplicative sufficient-decrease bound on the warm-up loop
    //!            ([`warmup_progress_sufficient`]).
    //!
    //! Further fixes covered below:
    //!   Sphere — radian metric in `coord_dist_sq` and a `chart_nominal_radius`
    //!            that covers the capped lat/lon grid spacing.
    //!   F1     — amplitude-aware chart routing (`nearest_chart`,
    //!            `nearest_charts_topk`).
    //!   F2     — the certificate uses the TRUE (un-ridged) Hessian.
    //!   F3     — Duchon atoms are refused (uncertified charts only).
    //!   F6     — an already-certified warm-up step is never rejected by the
    //!            progress rule (`warmup_should_reject`).
    use super::*;
    use crate::manifold::SaeAtomBasisKind;
    use ndarray::{Array1, Array2, Array3};

    /// Minimal well-formed atom of the given `basis_kind` and `latent_dim`,
    /// shared by the tests below. It is built from two basis functions: a 2×2
    /// identity `phi`, an all-zero jet, a constant (0.5) 2×1 decoder and a 2×2
    /// identity smoothness matrix. Only `basis_kind` matters to the
    /// wrap-aware metric tests. NOTE(review): the roles of these arrays follow
    /// the argument order of `SaeManifoldAtom::new`; confirm against its
    /// definition.
    fn tiny_atom(kind: SaeAtomBasisKind, latent_dim: usize) -> SaeManifoldAtom {
        let m = 2usize;
        let phi = Array2::<f64>::eye(m);
        let jet = Array3::<f64>::zeros((m, m, latent_dim));
        let dec = Array2::<f64>::from_elem((m, 1), 0.5);
        let smooth = Array2::<f64>::eye(m);
        SaeManifoldAtom::new("tiny", kind, latent_dim, phi, jet, dec, smooth)
            .expect("tiny atom builds")
    }

    /// FIX #3: a point on the far side of the wrap seam of a periodic axis is now
    /// IN-CHART under the wrap-aware metric, where the old raw-Euclidean sum used
    /// by `in_chart` rejected it. This is the exact predicate `in_chart` evaluates
    /// (`latent_coordinate_distance(atom, t, center) <= radius`).
    #[test]
    fn wrap_aware_containment_accepts_seam_point() {
        let atom = tiny_atom(SaeAtomBasisKind::Periodic, 1);
        let t = Array1::from(vec![0.99f64]);
        let center = Array1::from(vec![0.01f64]);
        let radius = 0.05;

        // Precondition: the OLD raw-Euclidean latent distance is 0.98 — far
        // outside the 0.05 ball, so the old `in_chart` REJECTED this iterate.
        let raw = (t[0] - center[0]).abs();
        assert!(
            raw > radius,
            "precondition: raw Euclidean distance {raw} must exceed the chart radius {radius}"
        );

        // The wrap-aware metric sees the true circle distance 0.02 (period 1.0),
        // so the seam point is now correctly IN-CHART.
        let d = latent_coordinate_distance(&atom, t.view(), center.view());
        assert!(
            (d - 0.02).abs() < 1e-12,
            "wrap-aware distance across the seam must be 0.02, got {d}"
        );
        assert!(
            d <= radius,
            "wrap-aware seam point must now be in-chart (d={d} <= r={radius})"
        );
    }

    /// Soundness guard for FIX #3: on a NON-periodic axis the metric must NOT
    /// wrap — the same coordinates stay at their full Euclidean separation and
    /// (correctly) remain OUT of the small ball. This preserves the
    /// never-issue-a-false-certificate invariant for flat-patch families.
    #[test]
    fn wrap_aware_containment_does_not_wrap_flat_axis() {
        let atom = tiny_atom(SaeAtomBasisKind::EuclideanPatch, 1);
        let t = Array1::from(vec![0.99f64]);
        let center = Array1::from(vec![0.01f64]);
        let d = latent_coordinate_distance(&atom, t.view(), center.view());
        assert!(
            (d - 0.98).abs() < 1e-12,
            "a flat (non-periodic) axis must keep the full 0.98 distance, got {d}"
        );
        assert!(
            d > 0.05,
            "flat-axis point correctly stays out of a 0.05 ball"
        );
    }

    /// FIX #4: a genuinely converging (Kantorovich-quadratic) `h`-sequence is
    /// untouched — every step clears the sufficient-decrease bar, so no
    /// converging row is regressed to the fallback.
    #[test]
    fn converging_h_sequence_is_never_flagged() {
        // Quadratic Newton contraction h_{k+1} = h_k^2 from an uncertified start.
        let mut h = 0.9f64;
        while h > KANTOROVICH_THRESHOLD {
            let h_next = h * h;
            assert!(
                warmup_progress_sufficient(h_next, h),
                "converging step {h} -> {h_next} must be accepted (no regression)"
            );
            h = h_next;
        }
    }

    /// FIX #4: a monotone `h`-sequence decreasing toward a limit ABOVE ½ (the
    /// pathological plateau that the old strict-decrease rule spun on until the
    /// increments fell below one ulp) now terminates in a BOUNDED, small number
    /// of `warmup_progress_sufficient` steps — flag-to-fallback rather than loop.
    #[test]
    fn plateau_h_sequence_terminates_bounded() {
        let limit = 0.65f64; // plateau limit strictly above ½: never certifies
        let ratio = 0.8f64; // geometric approach; ratio -> 1 as h -> limit
        let mut h = 2.0f64; // uncertified start
        let start = h;
        let mut accepted = 0usize;
        let mut iters = 0usize;
        loop {
            iters += 1;
            // Hard ceiling: proves boundedness. The OLD strict-decrease rule
            // would run ~1e15 steps here (h strictly decreases every step toward
            // 0.65 and only stalls below one ulp), so tripping this assert would
            // signal a regression to the unbounded behavior.
            assert!(
                iters < 10_000,
                "warm-up must terminate in a bounded number of steps, not loop"
            );
            let h_next = limit + (h - limit) * ratio; // monotone decreasing > limit
            if !warmup_progress_sufficient(h_next, h) {
                break; // flag to the exact fallback
            }
            accepted += 1;
            h = h_next;
        }
        // It made real progress before flagging (a warm-up, not an instant bail)…
        assert!(
            accepted >= 1,
            "expected the warm-up to accept at least one contracting step first"
        );
        // …and it flagged while still strictly above the plateau limit — exactly
        // where the OLD rule would have kept spinning (h is still shrinking).
        assert!(
            h > limit && h < start,
            "flagged mid-descent (limit < h={h} < start={start}); old rule would not stop here"
        );
        // Termination bound was tiny, not astronomical.
        assert!(iters < 200, "plateau flagged in {iters} steps (bounded)");
    }

    /// FIX #4: non-finite `h` (indefinite / blown-up Hessian) is treated as
    /// insufficient progress — the row flags rather than being accepted.
    #[test]
    fn non_finite_h_flags() {
        assert!(!warmup_progress_sufficient(f64::NAN, 0.9));
        assert!(!warmup_progress_sufficient(f64::INFINITY, 0.9));
        assert!(!warmup_progress_sufficient(0.5, f64::NAN));
    }

    /// BUG (sphere data-driven placement): `coord_dist_sq` must measure sphere
    /// coordinates in RADIANS via the canonical `latent_axis_period` — longitude
    /// (axis 1) wraps at 2π, latitude (axis 0) does NOT wrap — not at unit period
    /// on both axes. The old period-1 wrap collapsed genuinely-far longitudes to
    /// distance 0, corrupting `data_driven_chart_centers` for sphere atoms.
    #[test]
    fn coord_dist_sq_sphere_uses_radian_metric_not_unit_period() {
        let atom = tiny_atom(SaeAtomBasisKind::Sphere, 2);
        // Longitude differs by 3.0 rad: circle distance min(3, 2π−3)=3
        // (2π−3 ≈ 3.283) ⇒ dist² = 9. The old period-1 wrap read 3 mod 1 = 0.
        let a = Array1::from(vec![0.0f64, 0.0]);
        let b = Array1::from(vec![0.0f64, 3.0]);
        let dsq = coord_dist_sq(&atom, a.view(), b.view());
        assert!(
            (dsq - 9.0).abs() < 1e-9,
            "sphere longitude dist² must be 3²=9 (2π radian period), got {dsq}"
        );
        // Latitude (axis 0) must NOT wrap: 1.2 rad apart ⇒ dist² = 1.44.
        let c = Array1::from(vec![0.0f64, 0.0]);
        let e = Array1::from(vec![1.2f64, 0.0]);
        let dsq_lat = coord_dist_sq(&atom, c.view(), e.view());
        assert!(
            (dsq_lat - 1.44).abs() < 1e-9,
            "sphere latitude must not wrap; 1.2²=1.44, got {dsq_lat}"
        );
        // Must agree with the certified-encode metric (single source of truth).
        let dc = latent_coordinate_distance(&atom, a.view(), b.view());
        assert!(
            (dc * dc - dsq).abs() < 1e-9,
            "coord_dist_sq must equal latent_coordinate_distance² ({} vs {dsq})",
            dc * dc
        );
    }

    /// No-regression: periodic / torus axes still wrap at unit period.
    #[test]
    fn coord_dist_sq_torus_still_wraps_unit_period() {
        let atom = tiny_atom(SaeAtomBasisKind::Torus, 1);
        let a = Array1::from(vec![0.02f64]);
        let b = Array1::from(vec![0.98f64]);
        // Wrapped circle distance min(0.96, 0.04) = 0.04 ⇒ dist² = 0.0016.
        let dsq = coord_dist_sq(&atom, a.view(), b.view());
        assert!(
            (dsq - 0.0016).abs() < 1e-12,
            "torus unit-period wrap; got {dsq}"
        );
    }

    /// BUG (sphere chart tiling): `chart_nominal_radius` for the sphere must use
    /// the SAME capped per-axis count `min(resolution, 22)` as `sphere_latlon_grid`,
    /// so the radius covers the (capped) longitude half-spacing π/22 and the charts
    /// tile without gaps for resolution > 22 (raw π/40 would leave gaps).
    #[test]
    fn chart_nominal_radius_sphere_covers_capped_grid_spacing() {
        let atom = tiny_atom(SaeAtomBasisKind::Sphere, 2);
        let lon_half_spacing = std::f64::consts::PI / 22.0; // r capped at 22
        let r = chart_nominal_radius(&atom, 40);
        assert!(
            r >= lon_half_spacing - 1e-12,
            "sphere radius {r} must cover the capped lon half-spacing {lon_half_spacing} \
             (no gaps for resolution>22); raw π/40={} would gap",
            std::f64::consts::PI / 40.0
        );
    }

    // ---------------------------------------------------------------------
    // F1: amplitude-aware chart routing.
    // ---------------------------------------------------------------------

    /// A single certifiable chart with the given amplitude-1 center reconstruction.
    fn chart_with_recon(recon: Vec<f64>) -> CertifiedChart {
        CertifiedChart {
            region: ChartRegion::new(Array1::zeros(1), 1.0),
            lipschitz: 1.0,
            beta_center: 1.0,
            certified_radius: 1.0,
            amortized_jacobian: None,
            recon_center: Array1::from(recon),
            amortized_base: None,
        }
    }

    /// A one-dimensional, two-chart atlas whose charts reconstruct (at
    /// amplitude 1) to `[m1]` and `[m2]` respectively.
    fn atlas_two_charts(m1: f64, m2: f64) -> AtomEncodeAtlas {
        AtomEncodeAtlas {
            atom_index: 0,
            latent_dim: 1,
            decoder_norm_sum: 1.0,
            charts: vec![chart_with_recon(vec![m1]), chart_with_recon(vec![m2])],
        }
    }

    /// F1 counterexample (module review): chart centers reconstruct (amplitude 1)
    /// to `m₁ = 1` and `m₂ = 10`. A row `x = 1` at amplitude `z = 0.1`
    /// reconstructs as `z·m`, so chart 2 is EXACT (`0.1·10 = 1 = x`) and chart 1 is
    /// wrong (`0.1·1 = 0.1`, error `0.9`). Amplitude-blind routing on the
    /// amplitude-1 centers picks chart 1 (`|1−1| = 0 < |10−1| = 9`); the
    /// amplitude-aware fix picks chart 2.
    #[test]
    fn f1_routing_scores_amplitude_scaled_reconstruction() {
        let atlas = atlas_two_charts(1.0, 10.0);
        let x = Array1::from(vec![1.0]);

        let (idx, _) = nearest_chart(&atlas, x.view(), 0.1).expect("routes");
        assert_eq!(
            idx, 1,
            "z=0.1: must route to the m=10 chart (z·m=1 exact), not m=1"
        );

        let ranked = nearest_charts_topk(&atlas, x.view(), 0.1, 2);
        assert_eq!(ranked[0], 1, "z=0.1: nearest chart is the m=10 chart");

        // Negative amplitude: z=-0.1 makes z·m₂ = -1 (error 2) and z·m₁ = -0.1
        // (error 1.1), so chart 1 is now nearer — the sign is respected.
        let (idxn, _) = nearest_chart(&atlas, x.view(), -0.1).expect("routes");
        assert_eq!(idxn, 0, "z=-0.1: sign-aware routing prefers the m=1 chart");

        // No-regression at z=1: recovers the amplitude-1 argmin (m=1 chart exact).
        let (idx1, _) = nearest_chart(&atlas, x.view(), 1.0).expect("routes");
        assert_eq!(idx1, 0, "z=1 recovers the amplitude-1 nearest (m=1) chart");
        assert_eq!(nearest_charts_topk(&atlas, x.view(), 1.0, 1)[0], 0);
    }

    // ---------------------------------------------------------------------
    // F2: the certificate uses the TRUE (un-ridged) Hessian.
    // ---------------------------------------------------------------------

    /// A basis with constant `Φ`, zero Jacobian, and zero second jet: the
    /// reconstruction is locally CONSTANT, so the true encode Hessian is `0`
    /// (singular) and the coordinate is non-unique.
    #[derive(Debug)]
    struct ConstantPhi {
        m: usize,
        d: usize,
    }
    impl SaeBasisEvaluator for ConstantPhi {
        fn evaluate(
            &self,
            coords: ndarray::ArrayView2<'_, f64>,
        ) -> Result<(Array2<f64>, Array3<f64>), String> {
            let n = coords.nrows();
            Ok((
                Array2::ones((n, self.m)),
                Array3::zeros((n, self.m, self.d)),
            ))
        }
        fn second_jet_dyn(
            &self,
            coords: ndarray::ArrayView2<'_, f64>,
        ) -> Option<Result<ndarray::Array4<f64>, String>> {
            let n = coords.nrows();
            Some(Ok(ndarray::Array4::zeros((n, self.m, self.d, self.d))))
        }
        // Reports a dimension mismatch as an error; otherwise declines to
        // provide a third jet (`None`).
        fn third_jet_dyn(
            &self,
            coords: ndarray::ArrayView2<'_, f64>,
        ) -> Option<Result<ndarray::Array5<f64>, String>> {
            if coords.ncols() != self.d {
                return Some(Err(format!(
                    "ConstantPhi::third_jet_dyn: expected d = {}, got {} coords",
                    self.d,
                    coords.ncols()
                )));
            }
            None
        }
    }

    /// F2: a locally-constant reconstruction has a SINGULAR true Hessian (`H = 0`),
    /// so the point is NOT a genuine isolated minimum and must NOT be certified.
    /// The old code ridged the certified Hessian (`H → ridge·I`, PD), faking
    /// `β = 1/ridge`, `η = 0`, `h = 0 ≤ ½` — a FALSE certificate. With the true
    /// Hessian (`encode_grad_hess` no longer adds ridge) `beta_eta_newton` sees
    /// `λ_min = 0` and refuses.
    #[test]
    fn f2_certificate_uses_true_hessian_refuses_singular_field() {
        let atom = tiny_atom(SaeAtomBasisKind::EuclideanPatch, 1);
        let eval = ConstantPhi {
            m: atom.basis_size(),
            d: 1,
        };
        let t0 = Array1::from(vec![0.0]);
        let x = Array1::from(vec![0.5]);

        let (g, h) = encode_grad_hess(&atom, &eval, t0.view(), x.view(), 1.0)
            .expect("encode_grad_hess runs")
            .expect("second jet present ⇒ Some");
        assert!(
            h.iter().all(|&v| v == 0.0),
            "the TRUE Hessian of a constant reconstruction is 0 — no ridge is added \
             to the certified field; got {h:?}"
        );
        assert!(
            g.iter().all(|&v| v == 0.0),
            "gradient is 0 at a flat reconstruction"
        );

        let (cert, _) = row_certificate(&atom, &eval, t0.view(), x.view(), 1.0, 1.0)
            .expect("row_certificate runs");
        assert!(
            !cert.certified(),
            "a singular true Hessian must NOT be certified (the old ridged H falsely did)"
        );
        assert!(
            !cert.beta.is_finite(),
            "β must be ∞ (uncertifiable), never the ridge-faked 1/ridge; got {}",
            cert.beta
        );
    }

    // ---------------------------------------------------------------------
    // F3: Duchon atoms are refused (honest refusal, never a fabricated bound).
    // ---------------------------------------------------------------------

    /// F3: `build_atom_atlas_from_centers` must emit ONLY uncertified charts for a
    /// Duchon atom — the closed-form bound available here would fabricate cubic-r³
    /// jets and an origin center for a polyharmonic `c·r^(2m−d)` kernel over
    /// data-placed centers, risking an under-estimate of `L` (false certificate).
    /// The refusal routes every Duchon row to the exact multi-start encode.
    /// NOTE(review): this test builds only a Duchon atom; no non-Duchon control
    /// atom is exercised here.
    #[test]
    fn f3_duchon_atoms_are_uncertifiable() {
        let atom = tiny_atom(SaeAtomBasisKind::Duchon, 1);
        let centers = ndarray::array![[0.0_f64], [0.3], [0.7]];
        let radii = vec![0.1_f64, 0.1, 0.1];
        let atlas = EncodeAtlas::build_atom_atlas_from_centers(
            0,
            &atom,
            centers.view(),
            &radii,
            1.0,
            1.0,
            &AtlasConfig::default(),
        )
        .expect("duchon atlas builds (uncertified)");
        assert_eq!(atlas.charts.len(), 3, "one chart per center");
        for (i, chart) in atlas.charts.iter().enumerate() {
            assert_eq!(
                chart.certified_radius, 0.0,
                "duchon chart {i} must be uncertified (refused), got r={}",
                chart.certified_radius
            );
            assert!(
                chart.amortized_jacobian.is_none(),
                "duchon chart {i} must carry no amortized predictor"
            );
        }
    }

    // ---------------------------------------------------------------------
    // F6: an already-certified warm-up step is never rejected by the progress rule.
    // ---------------------------------------------------------------------

    /// F6: a step that just crossed into the certified region (`h ≤ ½`) is accepted
    /// even when its multiplicative decrease was below the progress floor
    /// (`0.501 → 0.499`). Only a STILL-uncertified plateau step is rejected. The
    /// old unconditional progress test rejected the `0.501 → 0.499` cross — a false
    /// negative that forced an unnecessary exact-solve fallback.
    #[test]
    fn f6_certified_step_not_rejected_by_progress_rule() {
        // 0.501 → 0.499: below the 1/64 multiplicative floor, and NOT the quadratic
        // path, so `warmup_progress_sufficient` is false…
        assert!(
            !warmup_progress_sufficient(0.499, 0.501),
            "precondition: this tiny decrease fails the progress bar"
        );
        // …but because the step is now CERTIFIED it must NOT be rejected (F6).
        assert!(
            !warmup_should_reject(true, 0.499, 0.501),
            "a certified step must be accepted regardless of the progress floor"
        );
        // A still-uncertified plateau step IS rejected (the guard still bites).
        assert!(
            warmup_should_reject(false, 0.499, 0.501),
            "an uncertified sub-floor step is a plateau and must flag to fallback"
        );
        // A certified step that ALSO made big progress is accepted (sanity).
        assert!(!warmup_should_reject(true, 0.1, 0.9));
        // A genuinely-contracting uncertified step is accepted (no regression).
        assert!(!warmup_should_reject(false, 0.4, 0.9));
    }
}