gam_sae/encode.rs
1//! Kantorovich-certified encode atlas (issue #1010).
2//!
3//! Encoding a row `x ∈ ℝᵖ` against a FROZEN SAE dictionary is, per atom `k`,
4//! the coordinate-only Newton problem
5//!
6//! ```text
7//! min_t f_k(t) = ½‖x − z_k · B_kᵀ Φ_k(t)‖² + prior_k(t),
8//! ```
9//!
10//! with the amplitude `z_k` and decoder block `B_k` held fixed (the encode
11//! freezes the dictionary; only the latent coordinate `t` moves). Newton on
12//! `F(t) = ∇f_k(t)` converges quadratically from a start `t₀` into the unique
13//! root in a certified ball whenever the **Newton–Kantorovich** quantity
14//!
15//! ```text
16//! h = β · η · L ≤ ½, β = ‖F'(t₀)⁻¹‖, η = ‖F'(t₀)⁻¹ F(t₀)‖,
17//! ```
18//!
19//! where `L` is a Lipschitz constant of `F'` (the Hessian of `f_k`) on a region
20//! containing the Newton iterates. `h` is CHECKABLE per row in `O(q³)`
21//! (`q = latent_dim`, tiny), so each fast-path encode carries its own
22//! exactness certificate.
23//!
24//! ## The closed-form Hessian-Lipschitz constant `L`
25//!
26//! Write `m(t) = z·BᵀΦ(t) ∈ ℝᵖ` (the reconstruction) and `r(t) = m(t) − x`.
27//! Then `f = ½‖r‖² + prior` and, differentiating three times,
28//!
29//! ```text
30//! ∇³f = 3·sym(J_mᵀ : ∇²m) + ⟨r, ∇³m⟩ + ∇³prior,
31//! ```
32//!
33//! so an operator-norm bound on the chart is
34//!
35//! ```text
36//! L ≤ 3·‖J_m‖·‖∇²m‖ + ‖r‖·‖∇³m‖ + L_prior,
37//! ```
38//!
39//! with `‖∂^g m‖ ≤ |z|·(Σ_m ‖B_{m,:}‖)·B_g`, where `B_g = sup_chart max_m
40//! ‖∂^g Φ_m‖` is the per-column jet sup of the basis family — closed form per
41//! family ([`BasisHessianLipschitz`]). `‖r‖` is bounded by `‖x‖ +
42//! |z|·(Σ_m‖B_{m,:}‖)·B_0`. The ARD/von-Mises prior `L_prior` is a closed-form
43//! constant from the prior strength. Every bound is conservative (an
44//! over-estimate of `L` only SHRINKS the certified radius — it can never
45//! certify a row that does not converge).
46//!
47//! ## Pipeline
48//!
49//! 1. **Offline, per atom** ([`EncodeAtlas::build`]): chart centers `t_c` on the
50//! atom's coordinate grid (the SHAPE_BAND grid idiom), each with a certified
51//! Newton radius `R_c` solved from the Kantorovich inequality at the
52//! worst-case in-chart start.
53//! 2. **Online, per row** ([`EncodeAtlas::certified_encode_row`]): route to the
54//! nearest chart, start from its distilled IFT predictor, take one or two
55//! Newton steps, then the `h ≤ ½` check AT the start point is the per-row
56//! certificate.
57//! 3. **Uncertified tail**: rows whose start fails `h ≤ ½` are FLAGGED (counted
58//! in [`EncodeResult::encode_uncertified_count`]) and must be routed by the
59//! caller to the existing exact multi-start solve. No approximation enters
60//! silently.
61
62use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
63
64use crate::candidate_index::{AtomFrameSketch, SaeCandidateIndex, auto_candidate_budget};
65use crate::manifold::{
66 AffineCoordinateEvaluator, CylinderHarmonicEvaluator, DuchonCoordinateEvaluator,
67 EuclideanPatchEvaluator, PeriodicHarmonicEvaluator, SaeBasisEvaluator, SaeManifoldAtom,
68 SphereChartEvaluator, TorusHarmonicEvaluator,
69};
70use gam_linalg::faer_ndarray::FaerEigh;
71
72use faer::Side;
73
/// The Kantorovich convergence threshold `h ≤ ½`. At or below this the Newton
/// iteration is guaranteed to converge into the unique root in the certified
/// ball (quadratically when `h < ½`); strictly above it the start is
/// uncertified. The boundary value `h = ½` counts as certified, matching the
/// `h <= KANTOROVICH_THRESHOLD` comparison in [`RowCertificate::certified`].
pub const KANTOROVICH_THRESHOLD: f64 = 0.5;
78
/// Row count at or above which the corpus-rate certified-encode batch
/// (`certified_encode_batch` / `certified_encode_with_index`) fans its
/// per-row encodes out over rayon. Below this the per-row Newton + chart
/// routing is cheap enough that the fan-out overhead does not pay; matched to
/// the same order as the arrow-Schur `SCHUR_MATVEC_PARALLEL_ROW_MIN` gate so
/// short batches inside an outer atom-level fan-out stay sequential.
pub(crate) const ENCODE_BATCH_PARALLEL_ROW_MIN: usize = 256;

/// Minimum frame alignment `‖Uₖᵀd‖/‖d‖ ∈ [0,1]` the routed atom must have for an
/// index-routed encode to be attempted at all — a FIT-QUALITY floor, NOT a
/// routing-correctness gate (#1026, corrected by the #1777 exact-routing path).
///
/// Routing itself is now EXACT: the index-routed encode picks the atom via
/// [`SaeCandidateIndex::route_exact`], which returns the GLOBAL argmax of the
/// routing score (the universal-bound LSH fast path, else a full-scan fallback) —
/// so there is no "missed-better-ungathered-atom" hole left for this constant to
/// patch. What remains is a different, honest question: even the globally-best
/// atom may align only weakly with a row (no atom in the dictionary fits it). A
/// finite alignment below this floor means the best available atom is a poor fit,
/// so the row is flagged and routed to the exact multi-start fallback rather than
/// encoded against an atom it barely belongs to. (Previously this same comparison
/// double-served as a recall proxy; that role is gone — `route_exact` guarantees
/// recall — leaving only the fit-quality role described here.)
pub(crate) const CANDIDATE_ROUTING_MIN_ALIGNMENT: f64 = 0.5;

/// Number of nearest charts the CERTIFIED encode refines in before returning the
/// lowest-reconstruction-error certified result. A single nearest chart is not
/// globally sound where the decoded manifold folds near itself (both competing
/// basins' charts reconstruct near the fold, so both rank among the nearest by
/// ambient distance); refining the top few captures the global basin. For a
/// unimodal atom all candidates converge to the same root, so K>1 is a no-op.
pub(crate) const CERTIFIED_ROUTING_TOPK: usize = 4;

/// Newton refinement convergence floor. Once a refinement step's length `‖δ‖`
/// falls below this (relative to the coordinate scale `1 + ‖t‖`), the iterate has
/// reached the certified root to f64 resolution: applying the step cannot move `t`
/// meaningfully, and the remaining fixed-budget steps only re-accumulate round-off.
/// Stopping there is STRICTLY more accurate than draining a fixed step budget on a
/// well-conditioned quadratic Newton tail, and it removes that tail's per-step
/// `evaluate` + `second_jet` cost (the dominant per-row encode work). The batched
/// and per-row encodes share this rule, so they stay bit-identical.
pub(crate) const NEWTON_REFINE_CONVERGED_EPS: f64 = 1.0e-12;

/// Global-minimum short-circuit floor for top-K certified routing. The
/// reconstruction error `‖x − z·m(t)‖` is bounded below by 0, so a certified
/// candidate whose residual already sits at the ambient noise floor
/// (`≤ this · (1 + ‖x‖)`) is treated as the global optimum over the charts — no
/// competing chart can reach a residual lower by more than that floor. The
/// remaining candidates' refinement is then skipped. Conservative (a genuine
/// second basin of the same target reconstructs the SAME point, so returning the
/// first is a valid encode).
pub(crate) const CERTIFIED_GLOBAL_MIN_RECON_FLOOR: f64 = 1.0e-11;
130
131/// A chart region on an atom's latent coordinate: a center `t_c` plus a
132/// certified in-chart radius. Over the ball `‖t − t_c‖ ≤ radius` the jet sup
133/// bounds returned by [`BasisHessianLipschitz`] hold, so the Kantorovich
134/// constant `L` computed from them is valid for any start in the ball.
135///
136/// For radial (Duchon) families the chart also carries the minimum kernel-center
137/// distance `exclusion_r_min` (a lower bound on `‖t − c_k‖` over the chart) that
138/// bounds the otherwise-singular `1/r` radial tails (issue #1010).
139#[derive(Debug, Clone)]
140pub struct ChartRegion {
141 /// Chart center coordinate `t_c` (length = latent_dim).
142 pub center: Array1<f64>,
143 /// In-chart radius in the coordinate metric.
144 pub radius: f64,
145 /// For radial (Duchon) families: a lower bound on `‖t − c_k‖` over the
146 /// chart, across every kernel center `c_k`. `None` for non-radial families.
147 pub exclusion_r_min: Option<f64>,
148 /// For radial (Duchon) families: an upper bound on `‖t − c_k‖` over the
149 /// chart, across every kernel center `c_k`. `None` for non-radial families.
150 pub radial_r_max: Option<f64>,
151}
152
153impl ChartRegion {
154 pub fn new(center: Array1<f64>, radius: f64) -> Self {
155 Self {
156 center,
157 radius,
158 exclusion_r_min: None,
159 radial_r_max: None,
160 }
161 }
162
163 pub fn with_radial_bounds(mut self, r_min: f64, r_max: f64) -> Self {
164 self.exclusion_r_min = Some(r_min);
165 self.radial_r_max = Some(r_max);
166 self
167 }
168
169 /// A jet-sup certificate is only meaningful over a genuine region. Even
170 /// families whose bounds are manifold-global constants (the sup over any
171 /// chart equals the global sup) must refuse a malformed chart rather than
172 /// certify garbage geometry.
173 pub(crate) fn assert_valid(&self) {
174 assert!(
175 self.radius.is_finite()
176 && self.radius >= 0.0
177 && self.center.iter().all(|c| c.is_finite()),
178 "ChartRegion must have a finite center and a finite non-negative radius"
179 );
180 }
181}
182
/// Per-column sup-norm bounds on the first three coordinate jets of a basis
/// family `Φ(t)`, valid over a stated [`ChartRegion`] (issue #1010). These are
/// the analytic ingredients of the Hessian-Lipschitz constant `L` — see the
/// module docs for the assembly. `value_sup` bounds `max_m |Φ_m|`,
/// `jacobian_sup`/`hessian_sup`/`third_sup` bound `max_m ‖∂^g Φ_m‖`.
///
/// Every method must return an UPPER bound: an over-estimate only shrinks the
/// certified radius, whereas an under-estimate could certify a start that does
/// not converge.
pub trait BasisHessianLipschitz {
    /// Upper bound on `max_m |Φ_m(t)|` over `chart` (jet order 0).
    fn value_sup(&self, chart: &ChartRegion) -> f64;
    /// Upper bound on `max_m ‖∂Φ_m(t)‖` over `chart` (jet order 1).
    fn jacobian_sup(&self, chart: &ChartRegion) -> f64;
    /// Upper bound on `max_m ‖∂²Φ_m(t)‖` over `chart` (jet order 2).
    fn hessian_sup(&self, chart: &ChartRegion) -> f64;
    /// Upper bound on `max_m ‖∂³Φ_m(t)‖` over `chart` (jet order 3).
    fn third_sup(&self, chart: &ChartRegion) -> f64;
}
194
/// Global sup of the `order`-th derivative of any single column of a
/// `num_basis`-wide Fourier basis `[1, sin(2π h t), cos(2π h t), …]`.
///
/// The top harmonic is `H = (num_basis − 1) / 2` (integer division, with
/// `num_basis = 0` treated as `H = 0`), and the returned bound is `(2π·H)^order`.
/// The constant column has zero derivatives, so the top harmonic dominates for
/// `order ≥ 1`; the bound does not depend on the chart because the trig
/// magnitudes are at most `1` everywhere. For `order = 0` the result is `1`.
pub(crate) fn harmonic_jet_sup(num_basis: usize, order: u32) -> f64 {
    use std::f64::consts::TAU;

    // Columns left after removing the constant one; each harmonic uses two.
    let harmonic_columns = num_basis.saturating_sub(1);
    let top = (harmonic_columns / 2) as f64;
    (TAU * top).powi(order as i32)
}
205
206impl BasisHessianLipschitz for PeriodicHarmonicEvaluator {
207 fn value_sup(&self, chart: &ChartRegion) -> f64 {
208 chart.assert_valid();
209 1.0
210 }
211 fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
212 chart.assert_valid();
213 harmonic_jet_sup(self.num_basis, 1)
214 }
215 fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
216 chart.assert_valid();
217 harmonic_jet_sup(self.num_basis, 2)
218 }
219 fn third_sup(&self, chart: &ChartRegion) -> f64 {
220 chart.assert_valid();
221 harmonic_jet_sup(self.num_basis, 3)
222 }
223}
224
225impl BasisHessianLipschitz for TorusHarmonicEvaluator {
226 /// Tensor product of per-axis circle harmonics. A torus basis column is a
227 /// product of single-axis harmonics, each bounded as in the circle case.
228 /// The `g`-th coordinate jet routes `g` derivative operators across the
229 /// `latent_dim` factors (Leibniz); each routing contributes a product of
230 /// per-axis derivative magnitudes. A per-column sup is therefore bounded by
231 /// the top single-axis frequency to the `g`-th power times the number of
232 /// such routings (`latent_dim^g`, the count of operator-to-axis maps).
233 fn value_sup(&self, chart: &ChartRegion) -> f64 {
234 chart.assert_valid();
235 1.0
236 }
237 fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
238 chart.assert_valid();
239 torus_jet_sup(self.num_harmonics, self.latent_dim, 1)
240 }
241 fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
242 chart.assert_valid();
243 torus_jet_sup(self.num_harmonics, self.latent_dim, 2)
244 }
245 fn third_sup(&self, chart: &ChartRegion) -> f64 {
246 chart.assert_valid();
247 torus_jet_sup(self.num_harmonics, self.latent_dim, 3)
248 }
249}
250
/// Per-column order-`g` jet sup of the torus harmonic basis:
/// `(2π·H)^g · latent_dim^g`, with `H = num_harmonics` the top per-axis
/// frequency.
///
/// The `latent_dim^g` factor counts every way of assigning `g` derivative
/// operators to the product's axes. Each assignment has magnitude at most
/// `(2π H)^{#ops on that axis}` per axis, and those multiply out to `(2π H)^g`,
/// so the product of the two factors is a conservative over-count.
pub(crate) fn torus_jet_sup(num_harmonics: usize, latent_dim: usize, order: u32) -> f64 {
    let exponent = order as i32;
    let frequency_factor = (std::f64::consts::TAU * num_harmonics as f64).powi(exponent);
    let routing_factor = (latent_dim as f64).powi(exponent);
    frequency_factor * routing_factor
}
260
261impl BasisHessianLipschitz for SphereChartEvaluator {
262 /// The 7-column lat/lon chart `[1, x, y, z, xy, yz, xz]` with
263 /// `x = cos(lat)cos(lon)`, `y = cos(lat)sin(lon)`, `z = sin(lat)`. Each of
264 /// `x, y, z` is a product of two unit-frequency trig factors, so its `g`-th
265 /// coordinate jet is a sum of `2^g` products of `{sin,cos}` (each `≤ 1`):
266 /// magnitude `≤ 2^g` for `g ≥ 1`, `≤ 1` for `g = 0`. The bilinear columns
267 /// `xy, yz, xz` are products of two such coordinates; by Leibniz over the
268 /// product, their `g`-th jet is bounded by `Σ_{i=0}^{g} C(g,i)·(2^i)·(2^{g−i})
269 /// = (2+2)^g = 4^g` (using `‖∂^i u‖ ≤ 2^i`, `|u| ≤ 1`). The bilinear columns
270 /// dominate, so the per-column sup is `4^g` (`g ≥ 1`). Bounds are global
271 /// constants — the chart box `lat ∈ [-π/2, π/2]` does not enlarge them.
272 fn value_sup(&self, chart: &ChartRegion) -> f64 {
273 chart.assert_valid();
274 1.0
275 }
276 fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
277 chart.assert_valid();
278 4.0
279 }
280 fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
281 chart.assert_valid();
282 16.0
283 }
284 fn third_sup(&self, chart: &ChartRegion) -> f64 {
285 chart.assert_valid();
286 64.0
287 }
288}
289
290impl BasisHessianLipschitz for AffineCoordinateEvaluator {
291 /// The affine basis `[1, t₁, …, t_d]` is degree ≤ 1: its first jet has unit
292 /// columns, and all second and third jets vanish. The value sup is
293 /// `max(1, ‖t‖)` over the chart, bounded by `1 + ‖t_c‖ + radius`.
294 fn value_sup(&self, chart: &ChartRegion) -> f64 {
295 let center_norm = chart.center.dot(&chart.center).sqrt();
296 1.0 + center_norm + chart.radius
297 }
298 fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
299 chart.assert_valid();
300 1.0
301 }
302 fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
303 chart.assert_valid();
304 0.0
305 }
306 fn third_sup(&self, chart: &ChartRegion) -> f64 {
307 chart.assert_valid();
308 0.0
309 }
310}
311
312impl BasisHessianLipschitz for EuclideanPatchEvaluator {
313 /// Monomials of total degree ≤ `max_degree` in `t ∈ ℝ^d`. Over the ball of
314 /// radius `R` about `t_c`, each coordinate is bounded by `ρ = ‖t_c‖∞ + R`.
315 /// A monomial `t^α` with `|α| = q` has `g`-th partials bounded (crudely) by
316 /// the descending-factorial coefficient `q·(q−1)···(q−g+1) ≤ q^g` times
317 /// `ρ^{max(q−g,0)}`, and there are at most `d^g` partial routings, so the
318 /// per-column `g`-th jet sup is `≤ d^g · D^g · ρ^{max(D−g,0)}` with
319 /// `D = max_degree`. Conservative; D is small for patch evaluators.
320 fn value_sup(&self, chart: &ChartRegion) -> f64 {
321 let rho = patch_rho(chart);
322 let d = self.max_degree as i32;
323 rho.powi(d).max(1.0)
324 }
325 fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
326 patch_jet_sup(self.latent_dim, self.max_degree, chart, 1)
327 }
328 fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
329 patch_jet_sup(self.latent_dim, self.max_degree, chart, 2)
330 }
331 fn third_sup(&self, chart: &ChartRegion) -> f64 {
332 patch_jet_sup(self.latent_dim, self.max_degree, chart, 3)
333 }
334}
335
336impl BasisHessianLipschitz for CylinderHarmonicEvaluator {
337 /// Cylinder `S¹ × ℝ` product basis `Φ_{c,l} = c(t₀)·l(t₁)`, the circle
338 /// (periodic harmonic) factor on axis 0 crossed with the monomial line
339 /// factor on axis 1. Because the two factors depend on disjoint coordinates,
340 /// the order-`g` coordinate jet in any cell is exactly
341 /// `c^{(k₀)}(t₀)·l^{(k₁)}(t₁)` with `k₀ + k₁ = g`, so the per-column sup is
342 /// the max over the split `k₀ + k₁ = g` of the product of the two per-axis
343 /// per-order sups: the circle factor contributes `1` at order 0 and
344 /// `(2π·H)^{k₀}` at order `k₀ ≥ 1` (trig magnitudes `≤ 1`); the line factor
345 /// contributes the monomial-patch sup `D^{k₁}·ρ^{max(D−k₁,0)}` (`D = line
346 /// degree`, `ρ = ‖t_c‖∞ + radius`). Bounds are global in the periodic axis
347 /// and chart-local in the line axis.
348 fn value_sup(&self, chart: &ChartRegion) -> f64 {
349 cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 0)
350 }
351 fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
352 cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 1)
353 }
354 fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
355 cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 2)
356 }
357 fn third_sup(&self, chart: &ChartRegion) -> f64 {
358 cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 3)
359 }
360}
361
362/// Per-column order-`g` jet sup of the cylinder product basis: the max over
363/// `k₀ + k₁ = g` of `circle_axis_sup(k₀) · line_axis_sup(k₁)`, where the circle
364/// axis sup is `(2π·H)^{k₀}` (`1` at `k₀ = 0`) and the line axis sup is the
365/// monomial-patch bound `D^{k₁}·ρ^{max(D−k₁,0)}` (`1` at `k₁ = 0`). See the
366/// [`CylinderHarmonicEvaluator`] doc comment for the derivation.
367pub(crate) fn cylinder_jet_sup(
368 circle_harmonics: usize,
369 line_degree: usize,
370 chart: &ChartRegion,
371 order: u32,
372) -> f64 {
373 let omega = std::f64::consts::TAU * circle_harmonics as f64;
374 let big_d = line_degree as f64;
375 let rho = patch_rho(chart);
376 let mut best = 0.0_f64;
377 for k0 in 0..=order {
378 let k1 = order - k0;
379 let circle = if k0 == 0 { 1.0 } else { omega.powi(k0 as i32) };
380 let line = if k1 == 0 {
381 rho.powi(line_degree as i32).max(1.0)
382 } else {
383 let residual = line_degree.saturating_sub(k1 as usize) as i32;
384 // `.max(1.0)` as in `patch_jet_sup`: for ρ < 1 a lower-degree line
385 // monomial dominates the k1-th derivative, so the bare `ρ^residual`
386 // underestimates the line-factor sup. The value case (k1==0) already
387 // clamps; this completes it for the derivative orders.
388 big_d.powi(k1 as i32) * rho.powi(residual).max(1.0)
389 };
390 best = best.max(circle * line);
391 }
392 best
393}
394
395/// Sup-norm radius `ρ = ‖t_c‖∞ + radius` of the chart (the coordinate magnitude
396/// bound used by the monomial-patch jet bounds).
397pub(crate) fn patch_rho(chart: &ChartRegion) -> f64 {
398 let center_inf = chart
399 .center
400 .iter()
401 .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
402 center_inf + chart.radius
403}
404
405/// Per-column `g`-th jet sup for a monomial patch of max degree `D` in `d`
406/// coordinates over the chart: `d^g · D^g · ρ^{max(D−g,0)}` (see the
407/// [`EuclideanPatchEvaluator`] doc comment for the derivation).
408pub(crate) fn patch_jet_sup(
409 latent_dim: usize,
410 max_degree: usize,
411 chart: &ChartRegion,
412 order: u32,
413) -> f64 {
414 let d = latent_dim as f64;
415 let big_d = max_degree as f64;
416 let rho = patch_rho(chart);
417 let residual_degree = max_degree.saturating_sub(order as usize) as i32;
418 // `.max(1.0)`: for ρ < 1 (small charts near the origin) the g-th jet sup is NOT
419 // dominated by the max-degree monomial `t^D` (whose g-th derivative ~ ρ^{D-g}
420 // shrinks with ρ) but by a LOWER-degree monomial whose g-th derivative is a
421 // larger constant — e.g. {1,t,t²}'s jacobian sup is the linear term's constant
422 // `1`, which exceeds `2ρ` when ρ < ½. Without the clamp the bound underestimates
423 // the true sup (numerically: D=3, ρ=0.1, g=1 → formula 0.03 vs true 1.0), which
424 // would make the certificate's Lipschitz `L` too small → a FALSE certificate.
425 // `D^g · max(ρ^{D-g}, 1)` upper-bounds `max_{q∈[g,D]} (q!/(q-g)!)·ρ^{q-g}` for
426 // all ρ (the `q=g` term gives `g! ≤ D^g`, the `q=D` term gives `≤ D^g·ρ^{D-g}`).
427 d.powi(order as i32) * big_d.powi(order as i32) * rho.powi(residual_degree).max(1.0)
428}
429
430impl BasisHessianLipschitz for DuchonCoordinateEvaluator {
431 /// Radial-kernel basis `Φ_m(t) = φ(r_m)`, `r_m = ‖t − c_m‖`, plus a
432 /// polynomial nullspace block. For the cubic Duchon kernel `φ(r) = r³` the
433 /// radial derivatives are `φ' = 3r²`, `φ'' = 6r`, `φ''' = 6`. The chain rule
434 /// to coordinate jets introduces `1/r` factors through the unit radial
435 /// direction `u = (t − c)/r` and the projector `(I − uuᵀ)/r`, so over a
436 /// chart the jets are bounded by combining the radial-derivative magnitudes
437 /// at the worst-case radius with the inverse-radius tail at the chart's
438 /// EXCLUSION radius `r_min` (the closest a chart point gets to any center):
439 ///
440 /// ```text
441 /// ‖∇φ‖ ≤ |φ'| ≤ 3 r_max²
442 /// ‖∇²φ‖ ≤ |φ''| + |φ'|/r ≤ 6 r_max + 3 r_max²/r_min
443 /// ‖∇³φ‖ ≤ |φ'''| + 3|φ''|/r + 3|φ'|/r² ≤ 6 + 18 r_max/r_min + 9 r_max²/r_min²
444 /// ```
445 ///
446 /// (the `1/r`, `1/r²` tails are bounded by `1/r_min`, `1/r_min²`). The
447 /// polynomial nullspace block is degree ≤ `order`; its jets are bounded like
448 /// the monomial patch with `D = order`. The per-column sup is the max of the
449 /// kernel and polynomial bounds. The `r³` kernel is itself `C²` (no
450 /// singularity) so these tails are conservative but finite for any
451 /// `r_min > 0`; the atlas refines charts to keep `r_min` bounded away from 0.
452 fn value_sup(&self, chart: &ChartRegion) -> f64 {
453 let r_max = chart.radial_r_max.unwrap_or(chart.radius);
454 let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 0);
455 (r_max.powi(3)).max(poly)
456 }
457 fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
458 let r_max = chart.radial_r_max.unwrap_or(chart.radius);
459 let kernel = 3.0 * r_max * r_max;
460 let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 1);
461 kernel.max(poly)
462 }
463 fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
464 let r_max = chart.radial_r_max.unwrap_or(chart.radius);
465 let r_min = chart
466 .exclusion_r_min
467 .unwrap_or(chart.radius)
468 .max(f64::MIN_POSITIVE);
469 let kernel = 6.0 * r_max + 3.0 * r_max * r_max / r_min;
470 let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 2);
471 kernel.max(poly)
472 }
473 fn third_sup(&self, chart: &ChartRegion) -> f64 {
474 let r_max = chart.radial_r_max.unwrap_or(chart.radius);
475 let r_min = chart
476 .exclusion_r_min
477 .unwrap_or(chart.radius)
478 .max(f64::MIN_POSITIVE);
479 let kernel = 6.0 + 18.0 * r_max / r_min + 9.0 * r_max * r_max / (r_min * r_min);
480 let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 3);
481 kernel.max(poly)
482 }
483}
484
485/// Polynomial-block degree of a Duchon nullspace order, used to bound the
486/// nullspace columns like a monomial patch.
487trait DuchonOrderDegree {
488 fn order_degree(&self) -> usize;
489}
490
491impl DuchonOrderDegree for DuchonCoordinateEvaluator {
492 fn order_degree(&self) -> usize {
493 match self.order {
494 gam_terms::basis::DuchonNullspaceOrder::Zero => 0,
495 gam_terms::basis::DuchonNullspaceOrder::Linear => 1,
496 gam_terms::basis::DuchonNullspaceOrder::Degree(d) => d,
497 }
498 }
499}
500
501/// Per-column `g`-th jet sup of the Duchon polynomial nullspace block, treated
502/// as a monomial patch of degree `order_degree`.
503pub(crate) fn duchon_poly_jet_sup(
504 latent_dim: usize,
505 order_degree: usize,
506 chart: &ChartRegion,
507 order: u32,
508) -> f64 {
509 if order_degree == 0 {
510 return if order == 0 { 1.0 } else { 0.0 };
511 }
512 patch_jet_sup(latent_dim, order_degree, chart, order)
513}
514
515/// Decoder magnitude `Σ_m ‖B_{m,:}‖₂` of an atom's frozen decoder block: the
516/// factor that converts a per-column `Φ`-jet sup `B_g` into a reconstruction
517/// jet sup `‖∂^g m‖ ≤ |z|·decoder_row_norm_sum·B_g`.
518pub(crate) fn decoder_row_norm_sum(decoder: ArrayView2<'_, f64>) -> f64 {
519 let mut acc = 0.0;
520 for row in decoder.rows() {
521 acc += row.dot(&row).sqrt();
522 }
523 acc
524}
525
/// Sup bounds on the amplitude-1 reconstruction and its first three coordinate
/// jets. `hessian_lipschitz_constant` multiplies each field by `|z|` to obtain
/// the bounds for the actual reconstruction at amplitude `z`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ReconstructionJetSups {
    /// Bound on the reconstruction value (jet order 0).
    pub(crate) value: f64,
    /// Bound on the reconstruction Jacobian (jet order 1).
    pub(crate) jacobian: f64,
    /// Bound on the reconstruction Hessian (jet order 2).
    pub(crate) hessian: f64,
    /// Bound on the reconstruction third jet (jet order 3).
    pub(crate) third: f64,
}
533
534pub(crate) fn pair_trig_decoder_sup(
535 sin_row: ArrayView1<'_, f64>,
536 cos_row: ArrayView1<'_, f64>,
537) -> f64 {
538 let aa = sin_row.dot(&sin_row);
539 let bb = cos_row.dot(&cos_row);
540 let ab = sin_row.dot(&cos_row);
541 let trace = aa + bb;
542 let disc = ((aa - bb) * (aa - bb) + 4.0 * ab * ab).sqrt();
543 (0.5 * (trace + disc)).sqrt()
544}
545
546pub(crate) fn periodic_reconstruction_jet_sups(
547 decoder: ArrayView2<'_, f64>,
548) -> ReconstructionJetSups {
549 let mut value = 0.0;
550 let mut jacobian = 0.0;
551 let mut hessian = 0.0;
552 let mut third = 0.0;
553 if decoder.nrows() > 0 {
554 value += decoder.row(0).dot(&decoder.row(0)).sqrt();
555 }
556 let harmonics = decoder.nrows().saturating_sub(1) / 2;
557 for h in 1..=harmonics {
558 let sin_idx = 2 * h - 1;
559 let cos_idx = 2 * h;
560 let amp = pair_trig_decoder_sup(decoder.row(sin_idx), decoder.row(cos_idx));
561 let omega = std::f64::consts::TAU * h as f64;
562 value += amp;
563 jacobian += omega * amp;
564 hessian += omega.powi(2) * amp;
565 third += omega.powi(3) * amp;
566 }
567 for row in (1 + 2 * harmonics)..decoder.nrows() {
568 let amp = decoder.row(row).dot(&decoder.row(row)).sqrt();
569 value += amp;
570 let omega = std::f64::consts::TAU * harmonics.max(1) as f64;
571 jacobian += omega * amp;
572 hessian += omega.powi(2) * amp;
573 third += omega.powi(3) * amp;
574 }
575 ReconstructionJetSups {
576 value,
577 jacobian,
578 hessian,
579 third,
580 }
581}
582
583pub(crate) fn reconstruction_jet_sups(
584 atom: &SaeManifoldAtom,
585 sups: JetSups,
586) -> ReconstructionJetSups {
587 if matches!(atom.basis_kind, crate::manifold::SaeAtomBasisKind::Periodic) {
588 periodic_reconstruction_jet_sups(atom.decoder_coefficients.view())
589 } else {
590 let decoder_norm_sum = decoder_row_norm_sum(atom.decoder_coefficients.view());
591 ReconstructionJetSups {
592 value: decoder_norm_sum * sups.value,
593 jacobian: decoder_norm_sum * sups.jacobian,
594 hessian: decoder_norm_sum * sups.hessian,
595 third: decoder_norm_sum * sups.third,
596 }
597 }
598}
599
600/// The Hessian-Lipschitz constant `L` of the per-row encode objective `f_k` on
601/// a chart, assembled in closed form from the basis jet sups and the decoder /
602/// amplitude / target magnitudes. See the module docs for the derivation:
603///
604/// ```text
605/// L ≤ 3·‖J_m‖·‖∇²m‖ + ‖r‖·‖∇³m‖ + L_prior,
606/// ‖∂^g m‖ ≤ |z|·S_B·B_g, S_B = Σ_m ‖B_{m,:}‖,
607/// ‖r‖ ≤ ‖x‖ + |z|·S_B·B_0,
608/// ```
609///
610/// `prior_lipschitz` is the caller-supplied closed-form `L_prior` of the
611/// ARD/von-Mises coordinate prior (`0.0` if no prior is active on the encode).
612pub(crate) fn hessian_lipschitz_constant(
613 recon_sups: ReconstructionJetSups,
614 amplitude: f64,
615 target_norm: f64,
616 prior_lipschitz: f64,
617) -> f64 {
618 let z = amplitude.abs();
619 let m_jac = z * recon_sups.jacobian;
620 let m_hess = z * recon_sups.hessian;
621 let m_third = z * recon_sups.third;
622 let recon_value = z * recon_sups.value;
623 let r_norm = target_norm + recon_value;
624 3.0 * m_jac * m_hess + r_norm * m_third + prior_lipschitz
625}
626
/// One offline-certified chart: a center, its Kantorovich constants, and the
/// certified Newton-convergence radius `R_c` solved from `h = β·η·L ≤ ½` at the
/// worst-case in-chart start.
#[derive(Debug, Clone)]
pub struct CertifiedChart {
    /// The chart region (center, radius, optional radial bounds) the constants
    /// below were computed over.
    pub region: ChartRegion,
    /// Closed-form Hessian-Lipschitz constant `L` over the chart.
    pub lipschitz: f64,
    /// `β = ‖F'(t_c)⁻¹‖` at the chart center (worst-case in-chart start uses
    /// the center's curvature; the radius is solved so the certificate holds for
    /// any start in the ball).
    pub beta_center: f64,
    /// Certified Newton radius: starts within `certified_radius` of `t_c`
    /// satisfy `h ≤ ½`.
    pub certified_radius: f64,
    /// Distilled amortized-encoder Jacobian for this chart (#1026 ladder item 3).
    ///
    /// The exact encode map `x ↦ t` solves `F(t; x) = J_m(t)ᵀ(m(t) − x) = 0`. By
    /// the implicit function theorem its derivative at the converged root is
    /// `dt/dx = −(∂_t F)⁻¹ (∂_x F) = H⁻¹ J_mᵀ` (since `∂_x F = −J_mᵀ`), so the
    /// first-order Taylor expansion of the encode map about this chart's center
    /// `t_c` is the closed-form AFFINE predictor
    ///
    /// ```text
    /// t(x) ≈ t_c + (1/z) · A₁ · (x − z · m₁(t_c)),   A₁ = (J₁ᵀJ₁ + ridge·I)⁻¹ J₁ᵀ,
    /// ```
    ///
    /// with `J₁ = Bᵀ J_Φ(t_c)` and `m₁(t_c) = BᵀΦ(t_c)` the AMPLITUDE-1
    /// reconstruction jets (the amplitude `z` factors out analytically, so the
    /// stored Jacobian is amplitude-free). This is the DISTILLED amortized
    /// encoder of the #1026 thread: the per-row Hessian factorization + Newton
    /// iteration is moved OFFLINE into this `d × p` matrix, leaving a single
    /// `O(d·p)` mat-vec online — no per-row eigendecomposition, no second-jet
    /// evaluation. The Kantorovich certificate is still evaluated AT the
    /// predicted start, so the amortized prediction is trusted iff `h ≤ ½` and an
    /// uncertified row still routes to the exact multi-start solve (the encoder
    /// approximates inference, the certificate keeps it honest — the thread's
    /// "encoder + certificate-gated exact fallback" deployment). `None` when the
    /// center's Gauss–Newton block is singular (no certifiable amortization).
    ///
    /// NOTE(review): the transposes above are written so the shapes agree with
    /// the stated `d × p` size of `A₁` (taking `J₁` as `p × d`); confirm against
    /// the atlas builder.
    pub amortized_jacobian: Option<Array2<f64>>,
    /// Amplitude-1 chart-center reconstruction `m₁(t_c) = BᵀΦ(t_c)` (length `p`),
    /// the anchor the amortized predictor expands the encode map around.
    pub recon_center: Array1<f64>,
    /// Precomputed affine-predictor CONSTANT term `base = t_c − A₁·m₁(t_c)` (length
    /// `d`), so the online amortized encode of a row `x` at amplitude `z` is the
    /// single mat-vec `t̂ = base + (1/z)·A₁·x` with NO per-row `A₁·m₁` recompute.
    /// Hoisting this atom-static term offline is what lets the massive-K index-routed
    /// fast paths run a single allocation-free pass over rows (rather than a per-atom
    /// GEMM sub-batch that degenerates to one row per group when `K ≫ N`). `None`
    /// exactly when `amortized_jacobian` is `None` (singular Gauss–Newton block).
    pub amortized_base: Option<Array1<f64>>,
}
678
/// The per-atom encode atlas: a set of certified charts covering the atom's
/// coordinate domain, plus the decoder/amplitude scaling needed to recompute a
/// per-row certificate online.
#[derive(Debug, Clone)]
pub struct AtomEncodeAtlas {
    /// Index of the atom this atlas belongs to.
    // NOTE(review): presumably an index into the frozen dictionary's atom list —
    // confirm against the atlas builder.
    pub atom_index: usize,
    /// Dimension of the atom's latent coordinate `t`.
    pub latent_dim: usize,
    /// Decoder magnitude used to scale basis jet sups into reconstruction jet
    /// sups.
    // NOTE(review): presumably the `Σ_m ‖B_{m,:}‖₂` value that
    // `decoder_row_norm_sum` computes — confirm against the atlas builder.
    pub decoder_norm_sum: f64,
    /// The certified charts of this atom.
    pub charts: Vec<CertifiedChart>,
}
689
690/// Result of a certified encode over a batch of rows, carrying the honesty
691/// flag: how many rows could NOT be certified and were flagged for the exact
692/// multi-start fallback (issue #1010 — no approximation enters silently).
693#[derive(Debug, Clone)]
694pub struct EncodeResult {
695 /// Per-row encoded latent coordinates (`n_rows × latent_dim`).
696 pub coords: Array2<f64>,
697 /// Per-row certificate: `true` ⇒ the row's start satisfied `h ≤ ½` and the
698 /// 1–2 Newton steps are exact-into-the-certified-ball; `false` ⇒ flagged.
699 pub certified: Vec<bool>,
700 /// Count of rows that could not be certified. These ride the payload so the
701 /// caller routes them to the exact multi-start encode — honesty, never
702 /// silent. Equals `certified.iter().filter(|c| !**c).count()`.
703 pub encode_uncertified_count: usize,
704}
705
706impl EncodeResult {
707 pub(crate) fn from_rows(coords: Array2<f64>, certified: Vec<bool>) -> Self {
708 let encode_uncertified_count = certified.iter().filter(|c| !**c).count();
709 Self {
710 coords,
711 certified,
712 encode_uncertified_count,
713 }
714 }
715}
716
/// Kantorovich certificate of a single row at a start `t₀`, for one atom
/// encode.
#[derive(Debug, Clone, Copy)]
pub struct RowCertificate {
    /// `β`: bound on the inverse Hessian at the start (`1/λ_min(H)`).
    pub beta: f64,
    /// `η`: length of the Newton step taken from the start.
    pub eta: f64,
    /// `L`: the Hessian-Lipschitz constant the certificate was evaluated with.
    pub lipschitz: f64,
    /// The Kantorovich quantity `h = β·η·L`; the row is certified iff
    /// `h ≤ ½`.
    pub h: f64,
}
726
727impl RowCertificate {
728 pub fn certified(&self) -> bool {
729 self.h.is_finite() && self.h <= KANTOROVICH_THRESHOLD
730 }
731}
732
733#[derive(Debug, Clone)]
734struct CertifiedEncodeProbe {
735 coord: Array1<f64>,
736 final_cert: RowCertificate,
737}
738
/// Polynomial degree of the flat (line) axis of a cylinder `S¹ × ℝ` atom.
///
/// This is the degree the topology-race builder
/// ([`gam_solve::structure_harvest`]) passes for the line axis
/// (`CylinderHarmonicEvaluator::new(_, 2)`). The encode atlas divides the
/// basis width by `degree + 1` to recover the circle harmonic count, so the
/// two values must stay in agreement.
pub(crate) const SAE_CYLINDER_LINE_DEGREE: usize = 2;
745
746/// Build a basis-family handle for one atom from its [`SaeManifoldAtom`]. The
747/// atlas needs to evaluate the jet sups, which live on the concrete evaluator
748/// types; the atom carries the evaluator as `Arc<dyn SaeBasisEvaluator>`, so we
749/// reconstruct the family bound from the atom's basis kind + width + centers.
750pub(crate) fn family_jet_sups(
751 atom: &SaeManifoldAtom,
752 chart: &ChartRegion,
753) -> Result<JetSups, String> {
754 use crate::manifold::SaeAtomBasisKind::*;
755 let m = atom.basis_size();
756 let d = atom.latent_dim;
757 let sups = match &atom.basis_kind {
758 Periodic => {
759 let ev = PeriodicHarmonicEvaluator::new(m)?;
760 JetSups::from_family(&ev, chart)
761 }
762 Torus => {
763 // Torus basis width is `(2H+1)^d`; recover the per-axis harmonic
764 // count `H` from `axis_m = m^(1/d)` rather than a sum formula.
765 let axis_m = integer_root(m, d.max(1));
766 let num_harmonics = axis_m.saturating_sub(1) / 2;
767 let ev = TorusHarmonicEvaluator::new(d, num_harmonics.max(1))?;
768 JetSups::from_family(&ev, chart)
769 }
770 Sphere => {
771 let ev = SphereChartEvaluator;
772 JetSups::from_family(&ev, chart)
773 }
774 Cylinder => {
775 // Cylinder width is `(2H+1)·(D+1)` with the canonical flat-axis
776 // degree `D = SAE_CYLINDER_LINE_DEGREE` (the harvest convention).
777 // Recover the per-axis circle harmonic count `H` from
778 // `2H+1 = m/(D+1)`.
779 let ml = SAE_CYLINDER_LINE_DEGREE + 1;
780 if d != 2 || ml == 0 || m % ml != 0 {
781 return Err(format!(
782 "EncodeAtlas: Cylinder atom requires latent_dim == 2 and width divisible by {ml}; got dim={d}, m={m}"
783 ));
784 }
785 let axis_mc = m / ml;
786 let h = axis_mc.saturating_sub(1) / 2;
787 let ev = CylinderHarmonicEvaluator::new(h.max(1), SAE_CYLINDER_LINE_DEGREE)?;
788 JetSups::from_family(&ev, chart)
789 }
790 Linear | EuclideanPatch | Poincare => {
791 // The patch width fixes max_degree implicitly; bound by a degree that
792 // covers the column count (conservative). Degree d-patch column count
793 // grows fast; we recover the smallest degree whose patch is ≥ m.
794 // Poincare atoms use the same tangent-coordinate polynomial decoder;
795 // their intrinsic smoothness differs in the penalty, not in Phi(t).
796 let degree = euclidean_patch_degree(d, m);
797 let ev = EuclideanPatchEvaluator::new(d, degree)?;
798 JetSups::from_family(&ev, chart)
799 }
800 Duchon => {
801 // UNSOUND — DO NOT TRUST for a certificate (F2/F3). This bound
802 // hard-codes cubic `φ(r) = r³` radial jets and a single origin center
803 // (`duchon_centers_from_atom`), but the real Duchon kernel is the
804 // polyharmonic `c·r^(2m−d)` (with log variants) over the atom's
805 // data-placed centers — so it can UNDER-estimate L and issue a FALSE
806 // certificate (the module's own warning). The atom does not expose its
807 // real order / center matrix / scaling to this crate, so no sound bound
808 // is available. `build_atom_atlas_from_centers` therefore REFUSES to
809 // certify Duchon atoms (emits uncertified charts → exact-encode
810 // fallback) and never reaches this arm; it is retained only so the
811 // family dispatch is total. If a future change threads the real
812 // order/centers here, replace this with the true `φ_{m,d}` jet bounds.
813 let centers = duchon_centers_from_atom(atom);
814 let conservative_m = m.max(1);
815 let ev = DuchonCoordinateEvaluator::new(centers, conservative_m)?;
816 JetSups::from_family(&ev, chart)
817 }
818 Precomputed(name) => {
819 return Err(format!(
820 "EncodeAtlas: precomputed basis '{name}' has no closed-form jet sup; route to exact encode"
821 ));
822 }
823 };
824 Ok(sups)
825}
826
/// Smallest monomial-patch degree whose column count covers `m` basis columns.
///
/// A degree-`D` patch in `d = latent_dim` variables has `C(d + D, D)` columns.
/// The degree is grown from 0 until that count reaches `m`, and is capped at
/// `m` so a degenerate width (for example `latent_dim == 0`, where the count
/// is always 1) still terminates.
///
/// The column count is carried as a running binomial,
/// `C(d + D, D) = C(d + D − 1, D − 1) · (d + D) / D`, which is exact at every
/// step and never forms the factorial-sized products that overflow `u128`
/// (e.g. a 1-D patch with 35 or more columns). It also makes the search linear
/// in the resulting degree instead of quadratic.
pub(crate) fn euclidean_patch_degree(latent_dim: usize, m: usize) -> usize {
    let target = m as u128;
    let dim = latent_dim as u128;
    let mut degree = 0usize;
    // Column count of the current degree; `C(d, 0) = 1`.
    let mut columns: u128 = 1;
    while columns < target && degree < m {
        degree += 1;
        let step = degree as u128;
        columns = match columns.checked_mul(dim + step) {
            // Exact: the product is `C(d + D, D) · D`.
            Some(scaled) => scaled / step,
            // Only reachable for astronomically large inputs; the true count
            // then exceeds any `usize` target, so the search is over.
            None => u128::MAX,
        };
    }
    degree
}
837
/// Largest integer `a` with `a^k ≤ n` (the floor of the `k`-th root). Used to
/// recover the per-axis harmonic width `axis_m` from a torus basis width
/// `m = axis_m^d`.
///
/// Conventions for the degenerate exponents: `k == 0` returns 1 and `k == 1`
/// returns `n`. For `k ≥ 2`, `n == 0` returns 0 (previously 1 was returned,
/// although `1^k > 0`).
///
/// The search is a linear scan over candidates, which is adequate for the
/// small basis widths it is used on.
pub(crate) fn integer_root(n: usize, k: usize) -> usize {
    match k {
        0 => return 1,
        1 => return n,
        _ => {}
    }
    if n == 0 {
        // 0 is the only value whose k-th power does not exceed 0.
        return 0;
    }
    let limit = n as u128;
    // `candidate^k > n`, with an early exit so the power is never pushed far
    // past `limit` (and `saturating_mul` covers the final multiply).
    let exceeds = |candidate: usize| -> bool {
        let base = candidate as u128;
        let mut power: u128 = 1;
        for _ in 0..k {
            power = power.saturating_mul(base);
            if power > limit {
                return true;
            }
        }
        false
    };
    // `n ≥ 1`, so 1 is always a valid root; climb while the next one still fits.
    let mut root = 1usize;
    while !exceeds(root + 1) {
        root += 1;
    }
    root
}
866
/// Column count `C(d + D, D)` of a degree-`D` monomial patch in
/// `d = latent_dim` variables.
///
/// The binomial is built incrementally,
/// `C(d + i, i) = C(d + i − 1, i − 1) · (d + i) / i`, so every intermediate is
/// itself a binomial coefficient (the division is exact) instead of a
/// factorial-sized numerator and denominator that overflow `u128` for
/// moderate degrees. A count that does not fit in `usize` saturates to
/// `usize::MAX` rather than being truncated.
pub(crate) fn patch_column_count(latent_dim: usize, degree: usize) -> usize {
    let ceiling = usize::MAX as u128;
    let dim = latent_dim as u128;
    let mut count: u128 = 1;
    for i in 1..=degree {
        let step = i as u128;
        count = match count.checked_mul(dim + step) {
            Some(scaled) => scaled / step,
            None => return usize::MAX,
        };
        // The count is non-decreasing in the degree, so once it leaves the
        // `usize` range the final answer saturates as well.
        if count > ceiling {
            return usize::MAX;
        }
    }
    // Lossless: `count ≤ usize::MAX` is guaranteed by the check above.
    count as usize
}
877
878/// Recover Duchon centers from an atom: when the evaluator is unavailable the
879/// atlas falls back to the atom's own latent-coordinate hull as the center set,
880/// which only affects the radial-tail bound conservatively.
881pub(crate) fn duchon_centers_from_atom(atom: &SaeManifoldAtom) -> Array2<f64> {
882 // One center at the origin in latent_dim space is a sound conservative
883 // default: the chart's own r_min / r_max bracket the true radial range.
884 Array2::<f64>::zeros((1, atom.latent_dim.max(1)))
885}
886
/// Per-column jet sups of a basis family over one chart, one value per
/// derivative order from 0 to 3.
#[derive(Debug, Clone, Copy)]
pub(crate) struct JetSups {
    /// Sup of the basis values (order 0).
    pub(crate) value: f64,
    /// Sup of the first derivatives.
    pub(crate) jacobian: f64,
    /// Sup of the second derivatives.
    pub(crate) hessian: f64,
    /// Sup of the third derivatives.
    pub(crate) third: f64,
}
895
896impl JetSups {
897 pub(crate) fn from_family<B: BasisHessianLipschitz>(family: &B, chart: &ChartRegion) -> Self {
898 Self {
899 value: family.value_sup(chart),
900 jacobian: family.jacobian_sup(chart),
901 hessian: family.hessian_sup(chart),
902 third: family.third_sup(chart),
903 }
904 }
905}
906
907/// Evaluate one atom's encode objective gradient `F(t) = ∇f_k(t)` and the FULL
908/// Hessian `F'(t) = ∇²f_k(t)` at a single coordinate `t`, for a single target
909/// row `x` and fixed amplitude `z`. With `m(t) = z·BᵀΦ(t)`, `r = m − x`,
910/// `J_m = z·Bᵀ J_Φ`:
911///
912/// ```text
913/// g_t[a] = J_m[a] · r (= ∇f)
914/// H_tt[a,b] = J_m[a] · J_m[b] + r · ∂²m/∂t_a∂t_b (= ∇²f, FULL Hessian)
915/// ```
916///
917/// The certificate uses the FULL Hessian rather than the Gauss-Newton block
918/// `J_mᵀ J_m`. This is the principled choice for Newton–Kantorovich: the
919/// theorem certifies convergence of Newton on `F = ∇f` to the unique nearby
920/// ROOT of `∇f`, but a root of `∇f` can be a maximum. The full Hessian is
921/// positive-definite exactly on the genuine-minimum basin, so requiring
922/// `λ_min(H) > 0` (finite `β`) is what flags a start that would otherwise let
923/// Gauss-Newton march into the wrong root (e.g. the circle antipode, a local
924/// max where `∇f = 0` but the full curvature is negative). The residual term
925/// needs the basis second jet `∂²Φ/∂t²`; an evaluator without one returns
926/// `None`, and the row is flagged (no silent Gauss-Newton fallback).
927///
928/// The Hessian returned is the TRUE `∇²f_k` — no Levenberg ridge is added
929/// (F2). The Kantorovich certificate (`row_certificate`) and its `λ_min(H) > 0`
930/// saddle gate must see the genuine field: a ridged `H + λI` certifies neither
931/// the original objective nor a consistently regularized one (a
932/// locally-constant reconstruction has `H = 0`, whose ridged `λI` would falsely
933/// certify a non-isolated, non-unique root). Ridge stays only in the
934/// UNCERTIFIED amortized predictor (`center_amortized_jacobian`/`center_beta`).
935pub(crate) fn encode_grad_hess(
936 atom: &SaeManifoldAtom,
937 evaluator: &dyn SaeBasisEvaluator,
938 t: ArrayView1<'_, f64>,
939 x: ArrayView1<'_, f64>,
940 amplitude: f64,
941) -> Result<Option<(Array1<f64>, Array2<f64>)>, String> {
942 let d = atom.latent_dim;
943 let p = atom.output_dim();
944 let m = atom.basis_size();
945 let coords = t.to_shape((1, d)).map_err(|e| e.to_string())?.to_owned();
946 let (phi, jet) = evaluator.evaluate(coords.view())?;
947 if phi.dim() != (1, m) {
948 return Err(format!(
949 "encode_grad_hess: evaluator returned phi {:?}, expected (1, {m})",
950 phi.dim()
951 ));
952 }
953 let decoder = &atom.decoder_coefficients;
954 // Reconstruction m(t) = z · Bᵀ Φ(t) ∈ ℝᵖ.
955 let mut recon = Array1::<f64>::zeros(p);
956 for basis_col in 0..m {
957 let phi_v = phi[[0, basis_col]];
958 if phi_v == 0.0 {
959 continue;
960 }
961 for out in 0..p {
962 recon[out] += amplitude * phi_v * decoder[[basis_col, out]];
963 }
964 }
965 let residual = &recon - &x;
966 // J_m[axis] = z · Bᵀ (∂Φ/∂t_axis) ∈ ℝᵖ.
967 let mut jm = Array2::<f64>::zeros((d, p));
968 for axis in 0..d {
969 for basis_col in 0..m {
970 let dphi = jet[[0, basis_col, axis]];
971 if dphi == 0.0 {
972 continue;
973 }
974 for out in 0..p {
975 jm[[axis, out]] += amplitude * dphi * decoder[[basis_col, out]];
976 }
977 }
978 }
979 // The full-Hessian residual term needs ∂²Φ/∂t². No second jet ⇒ no
980 // certificate (flag), never a silent Gauss-Newton substitute.
981 let second = match evaluator.second_jet_dyn(coords.view()) {
982 Some(result) => result?,
983 None => return Ok(None),
984 };
985 // Residual · decoder-row `r·B_{basis,:}` is INDEPENDENT of the (a,b) axes, yet
986 // the old code recomputed it `d²` times inside the Hessian double loop. Hoist it
987 // to one O(m·p) pass so the per-axis curvature term is a cheap O(m) dot.
988 let mut rd = vec![0.0_f64; m];
989 for (basis_col, rd_col) in rd.iter_mut().enumerate() {
990 let mut dot = 0.0;
991 for out in 0..p {
992 dot += residual[out] * decoder[[basis_col, out]];
993 }
994 *rd_col = dot;
995 }
996 // g_t[axis] = J_m[axis] · r ; H_tt[a,b] = J_m[a]·J_m[b] + r·∂²m/∂t_a∂t_b.
997 // The full Hessian is symmetric (Gauss-Newton block + symmetric second jet), so
998 // compute the upper triangle and mirror — half the curvature work.
999 let mut g = Array1::<f64>::zeros(d);
1000 let mut h = Array2::<f64>::zeros((d, d));
1001 for a in 0..d {
1002 let ja = jm.row(a);
1003 g[a] = ja.dot(&residual);
1004 for b in a..d {
1005 // Gauss-Newton block.
1006 let mut hab = ja.dot(&jm.row(b));
1007 // Residual · second-jet curvature: r · ∂²m_{ab},
1008 // ∂²m_{ab}[out] = z · Σ_basis (∂²Φ/∂t_a∂t_b) · B[basis, out].
1009 let mut curv = 0.0;
1010 for basis_col in 0..m {
1011 let d2phi = second[[0, basis_col, a, b]];
1012 if d2phi == 0.0 {
1013 continue;
1014 }
1015 curv += amplitude * d2phi * rd[basis_col];
1016 }
1017 hab += curv;
1018 h[[a, b]] = hab;
1019 h[[b, a]] = hab;
1020 }
1021 }
1022 // NO ridge: the certificate must use the TRUE Hessian (F2). See the doc above.
1023 Ok(Some((g, h)))
1024}
1025
1026/// Operator-norm of `H⁻¹` (i.e. `β = 1/λ_min(H)`) and the Newton step
1027/// `δ = −H⁻¹ g` with `η = ‖δ‖`, from a symmetric PSD `H` and gradient `g`.
1028/// Returns `None` when `H` is numerically singular (λ_min ≤ 0) — an
1029/// uncertifiable start.
1030pub(crate) fn beta_eta_newton(
1031 h: ArrayView2<'_, f64>,
1032 g: ArrayView1<'_, f64>,
1033) -> Result<Option<(f64, f64, Array1<f64>)>, String> {
1034 // Closed-form fast paths for the tiny latent dims that dominate SAE atoms
1035 // (`d = 1, 2`), avoiding a faer eigendecomposition + its heap allocations on
1036 // the hottest per-row Newton inner loop. `β = 1/λ_min(H)`, `δ = −H⁻¹g`, and the
1037 // `λ_min ≤ 0` gate are computed directly; the symmetric-`H` reads mirror
1038 // `eigh(Side::Lower)` (which uses the lower triangle) exactly.
1039 let d = h.nrows();
1040 if d == 1 {
1041 let h00 = h[[0, 0]];
1042 if !(h00.is_finite() && h00 > 0.0) {
1043 return Ok(None);
1044 }
1045 let delta0 = -g[0] / h00;
1046 let mut delta = Array1::<f64>::zeros(1);
1047 delta[0] = delta0;
1048 return Ok(Some((1.0 / h00, delta0.abs(), delta)));
1049 }
1050 if d == 2 {
1051 // Symmetric H = [[a, b], [b, c]] read from the lower triangle.
1052 let a = h[[0, 0]];
1053 let b = h[[1, 0]];
1054 let c = h[[1, 1]];
1055 let tr = a + c;
1056 let det = a * c - b * b;
1057 // λ_min = ½(tr − √((a−c)² + 4b²)); ≥ 0 ⇒ H PSD, > 0 ⇒ PD.
1058 let disc = ((a - c) * (a - c) + 4.0 * b * b).max(0.0).sqrt();
1059 let lambda_min = 0.5 * (tr - disc);
1060 if !(lambda_min.is_finite() && lambda_min > 0.0) {
1061 return Ok(None);
1062 }
1063 // δ = −H⁻¹g with H⁻¹ = [[c, −b], [−b, a]] / det (det = λ_min·λ_max > 0).
1064 let inv_det = 1.0 / det;
1065 let g0 = g[0];
1066 let g1 = g[1];
1067 let d0 = -(c * g0 - b * g1) * inv_det;
1068 let d1 = -(a * g1 - b * g0) * inv_det;
1069 if !(d0.is_finite() && d1.is_finite()) {
1070 return Ok(None);
1071 }
1072 let mut delta = Array1::<f64>::zeros(2);
1073 delta[0] = d0;
1074 delta[1] = d1;
1075 let eta = (d0 * d0 + d1 * d1).sqrt();
1076 return Ok(Some((1.0 / lambda_min, eta, delta)));
1077 }
1078 let (vals, vecs) = h
1079 .eigh(Side::Lower)
1080 .map_err(|e| format!("beta_eta_newton: eigh failed: {e:?}"))?;
1081 let lambda_min = vals.iter().cloned().fold(f64::INFINITY, f64::min);
1082 if !(lambda_min.is_finite() && lambda_min > 0.0) {
1083 return Ok(None);
1084 }
1085 let beta = 1.0 / lambda_min;
1086 // Newton step δ = −H⁻¹ g via the eigendecomposition: δ = −Σ_i (vᵢᵀg/λᵢ) vᵢ.
1087 let mut delta = Array1::<f64>::zeros(d);
1088 for (col, &lam) in vals.iter().enumerate() {
1089 if lam <= 0.0 {
1090 return Ok(None);
1091 }
1092 let vi = vecs.column(col);
1093 let coeff = vi.dot(&g) / lam;
1094 for row in 0..d {
1095 delta[row] -= coeff * vi[row];
1096 }
1097 }
1098 let eta = delta.dot(&delta).sqrt();
1099 Ok(Some((beta, eta, delta)))
1100}
1101
1102/// Compute the per-row Kantorovich certificate for encoding target row `x`
1103/// against atom `atom` at start coordinate `t₀`, with fixed amplitude `z` and
1104/// the chart's closed-form Lipschitz constant `lipschitz`. Returns the
1105/// certificate AND the Newton step `δ = −H⁻¹ g` so the caller can advance.
1106pub fn row_certificate(
1107 atom: &SaeManifoldAtom,
1108 evaluator: &dyn SaeBasisEvaluator,
1109 t0: ArrayView1<'_, f64>,
1110 x: ArrayView1<'_, f64>,
1111 amplitude: f64,
1112 lipschitz: f64,
1113) -> Result<(RowCertificate, Array1<f64>), String> {
1114 let uncertified = || {
1115 (
1116 RowCertificate {
1117 beta: f64::INFINITY,
1118 eta: f64::INFINITY,
1119 lipschitz,
1120 h: f64::INFINITY,
1121 },
1122 Array1::<f64>::zeros(atom.latent_dim),
1123 )
1124 };
1125 // No second jet ⇒ no full Hessian ⇒ uncertifiable (flag).
1126 let Some((g, h)) = encode_grad_hess(atom, evaluator, t0, x, amplitude)? else {
1127 return Ok(uncertified());
1128 };
1129 match beta_eta_newton(h.view(), g.view())? {
1130 Some((beta, eta, delta)) => {
1131 let cert = RowCertificate {
1132 beta,
1133 eta,
1134 lipschitz,
1135 h: beta * eta * lipschitz,
1136 };
1137 Ok((cert, delta))
1138 }
1139 // Indefinite / negative-curvature full Hessian: the start is at or past
1140 // a basin boundary (a max/saddle of f), not the minimum basin — flag.
1141 None => Ok(uncertified()),
1142 }
1143}
1144
1145fn uncertified_certificate(lipschitz: f64) -> RowCertificate {
1146 RowCertificate {
1147 beta: f64::INFINITY,
1148 eta: f64::INFINITY,
1149 lipschitz,
1150 h: f64::INFINITY,
1151 }
1152}
1153
1154fn refine_certified_start(
1155 atom: &SaeManifoldAtom,
1156 evaluator: &dyn SaeBasisEvaluator,
1157 mut t: Array1<f64>,
1158 x: ArrayView1<'_, f64>,
1159 amplitude: f64,
1160 lipschitz: f64,
1161 newton_steps: usize,
1162 initial_cert: RowCertificate,
1163 mut delta: Array1<f64>,
1164) -> Result<Option<CertifiedEncodeProbe>, String> {
1165 assert!(initial_cert.certified());
1166 let mut final_cert = initial_cert;
1167 for _ in 0..newton_steps {
1168 // Convergence early-exit: the pending Newton step is below the coordinate
1169 // ULP scale, so `t + δ == t` to f64 resolution — the certified root is
1170 // reached and the remaining fixed-budget steps would only re-accumulate
1171 // round-off. This is where the well-conditioned quadratic Newton tail's
1172 // redundant `evaluate` + `second_jet` work is eliminated.
1173 if delta.dot(&delta).sqrt() <= NEWTON_REFINE_CONVERGED_EPS * (1.0 + t.dot(&t).sqrt()) {
1174 break;
1175 }
1176 t = &t + δ
1177 let (cert, next_delta) =
1178 row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz)?;
1179 if !cert.certified() {
1180 return Ok(None);
1181 }
1182 final_cert = cert;
1183 delta = next_delta;
1184 }
1185 Ok(Some(CertifiedEncodeProbe {
1186 coord: t,
1187 final_cert,
1188 }))
1189}
1190
// Design note for `certify_with_basin_warmup` (defined further down). The
// paragraphs below describe that function, not the constant that follows;
// they are kept as a plain comment so rustdoc does not attach them to the
// constant.
//
// Certify an encode probe from `t_start`, navigating into the Kantorovich basin
// first if needed (#1154/#1026). The Kantorovich quantity `h = β·η·L` scales with
// amplitude through `L`, so at unit amplitude a positive-definite chart-center /
// distilled start can sit OUTSIDE the certified ball (`h > ½`). Rather than
// flagging it uncertified immediately — which made the encoder certify ZERO
// held-out rows at amplitude 1.0 and fall back to the exact solve for everything —
// take plain Newton steps toward the root, re-certifying at each iterate, while
// the Kantorovich quantity `h = β·η·L` keeps CONTRACTING toward the ½ bound. The
// certificate at the landing point is a full Kantorovich guarantee from there
// (`h ≤ ½` ⇒ Newton converges to the in-ball root), so this only ever WIDENS the
// certified set; it never certifies a non-convergent start.
//
// Termination is the natural Newton stopping rule — there is no arbitrary step
// budget. The warm-up stops and flags for the exact fallback when either the start
// is not steppable (indefinite / non-finite Hessian — at or past a basin boundary)
// or a step fails to reduce `h` (the iterate is not approaching a certifiable
// in-chart root: its root lies outside this chart's valid Lipschitz region, or the
// start was past the basin — empirically the rows that miss *plateau*, so more
// steps cannot help; the lever there is denser charts, not more iterations). On
// success the start is refined `newton_steps` further by `refine_certified_start`.

/// Minimum per-step *multiplicative* decrease of the Kantorovich `h` the basin
/// warm-up requires to keep stepping (FIX #4). A tiny geometric floor: a
/// continued step must contract `h` by at least this fraction. Chosen small so
/// it never bites a genuinely converging row (Newton in the Kantorovich regime
/// is at least geometric and quadratic once `h < 1`, contracting `h` far faster
/// than `1/64` per step) while still forcing termination on a plateau.
const WARMUP_MIN_MULTIPLICATIVE_DECREASE: f64 = 1.0 / 64.0;
1218
/// Coefficient `κ` of the Kantorovich-quadratic acceptance path in the basin
/// warm-up (FIX #4).
///
/// For `h < 1` a converging Newton step contracts quadratically
/// (`h_new ≲ κ·h²`). Accepting that path next to the geometric floor states
/// explicitly that genuinely converging rows are left untouched. `κ` is kept
/// below 1 so the quadratic path is itself a strict contraction (for `h < 1`,
/// `κ·h² < κ·h < h`), which the warm-up's termination bound relies on.
const WARMUP_QUADRATIC_KAPPA: f64 = 0.5;
1226
1227/// Sufficient-decrease test for the basin-warmup loop (FIX #4).
1228///
1229/// Returns `true` while the warm-up should keep stepping. The previous exit rule
1230/// (`h_new >= h_prev` — strict decrease or quit) never fires for an `h`-sequence
1231/// that decreases *monotonically toward a limit above ½*: the increments fall
1232/// below one ulp long before `h` crosses the certifiable ½ bound, so a single
1233/// pathological row could spin ~1e15 full `row_certificate` solves (Hessian
1234/// build + solve) on the encode hot path. We instead require genuine
1235/// *multiplicative* progress each step, which matches Newton's actual behavior:
1236/// a healthy contraction clears the geometric floor by a wide margin (and the
1237/// quadratic path once `h < 1`), while a plateau (`h_new/h_prev → 1`) satisfies
1238/// neither branch and flags to the exact fallback.
1239///
1240/// Termination bound: the warm-up loop runs only while the row is uncertified
1241/// (`h_prev > ½`), and every continued step contracts `h` by at least the factor
1242/// `(1 − WARMUP_MIN_MULTIPLICATIVE_DECREASE)` (the quadratic branch, for `h < 1`,
1243/// contracts by `κ·h_prev < κ < 1`, i.e. even harder). Hence after `N` continued
1244/// steps `h ≤ (1 − c)^N · h₀`, and since the loop stops once `h ≤ ½` we get
1245/// `N < ln(2·h₀) / −ln(1 − c)` — a few hundred iterations at most, versus
1246/// unbounded before. No arbitrary fixed step cap is imposed; the bound is a
1247/// consequence of the contraction requirement, in keeping with the function's
1248/// no-magic-budget design.
1249fn warmup_progress_sufficient(h_new: f64, h_prev: f64) -> bool {
1250 if !(h_new.is_finite() && h_prev.is_finite()) {
1251 return false;
1252 }
1253 // Geometric floor: a real Newton contraction easily clears this.
1254 if h_new <= (1.0 - WARMUP_MIN_MULTIPLICATIVE_DECREASE) * h_prev {
1255 return true;
1256 }
1257 // Kantorovich-quadratic path (only once `h < 1`, where it is a strict
1258 // contraction): makes the no-regression guarantee for converging rows
1259 // explicit without ever admitting a non-contracting (plateau) step.
1260 h_prev < 1.0 && h_new <= WARMUP_QUADRATIC_KAPPA * h_prev * h_prev
1261}
1262
1263/// Whether the basin warm-up must REJECT the just-taken step (flag to the exact
1264/// fallback). FIX (F6): a step that just crossed into the certified region
1265/// (`h ≤ ½`) is ALWAYS accepted, even if its multiplicative decrease was below
1266/// the progress floor (e.g. `h: 0.501 → 0.499`, a legitimate but tiny cross of
1267/// the ½ bound). Only a step that is STILL uncertified AND failed to make
1268/// sufficient multiplicative progress is a plateau that must flag. The old code
1269/// ran the progress test unconditionally, so it could reject an in-hand
1270/// certificate and push the row to the exact fallback for nothing (a false
1271/// negative).
1272fn warmup_should_reject(next_certified: bool, h_new: f64, h_prev: f64) -> bool {
1273 !next_certified && !warmup_progress_sufficient(h_new, h_prev)
1274}
1275
1276fn certify_with_basin_warmup(
1277 atom: &SaeManifoldAtom,
1278 evaluator: &dyn SaeBasisEvaluator,
1279 t_start: Array1<f64>,
1280 x: ArrayView1<'_, f64>,
1281 amplitude: f64,
1282 lipschitz: f64,
1283 newton_steps: usize,
1284 chart_center: ArrayView1<'_, f64>,
1285 chart_radius: f64,
1286) -> Result<Option<CertifiedEncodeProbe>, String> {
1287 // SOUNDNESS GUARD: `lipschitz` is the chart's Hessian-Lipschitz sup, which is
1288 // only a valid bound over this chart's ball `‖t − center‖ ≤ radius` for the
1289 // chart-local families (`EuclideanPatch`/`Linear`/`Poincare` monomial patches,
1290 // `Cylinder` line axis, `Duchon` radial kernels). If a warm-up iterate leaves
1291 // that ball, `row_certificate` would compute `h = β·η·L` with an `L` that no
1292 // longer bounds the true geometry there, so `h ≤ ½` would NOT imply Kantorovich
1293 // convergence — a false certificate. (The `h`-contraction check does NOT catch
1294 // this: `h` can decrease monotonically toward an out-of-chart root the whole
1295 // way.) So we keep every certified iterate inside the chart; a row whose root is
1296 // outside this chart flags for the exact fallback — its lever is a denser grid,
1297 // not a step using an invalid `L`. Global-`L` families (periodic/torus/sphere)
1298 // route their points to charts whose centers are near the root, so the guard
1299 // rarely trips for them, and where it does the row was out-of-chart anyway.
1300 let in_chart = |t: &Array1<f64>| -> bool {
1301 // Wrap-aware chart containment. A raw Euclidean latent distance
1302 // `Σ(tᵢ − centerᵢ)²` mis-measures separation across the wrap seam of a
1303 // periodic axis: an iterate at `t = 0.99` against a chart centered at
1304 // `0.01` reads distance `0.98` instead of the true circle distance
1305 // `0.02`, so it is wrongly rejected at the start check and at every step
1306 // check — silently disabling the amortized encoder for a phase-localized
1307 // band of rows on every periodic / torus / cylinder-angle /
1308 // sphere-longitude chart and pushing that band to the multi-start
1309 // fallback. `latent_coordinate_distance` folds each axis onto its
1310 // `latent_axis_period`, so the containment ball is measured in the true
1311 // manifold geometry — exactly the geometry the chart's Lipschitz bound
1312 // `L` is valid over. This only ever moves points that are genuinely
1313 // near the center (small wrapped distance) back INTO the chart where
1314 // `L` holds, so it widens acceptance without weakening the soundness
1315 // guard (an iterate truly far from the center on the circle still fails).
1316 latent_coordinate_distance(atom, t.view(), chart_center) <= chart_radius
1317 };
1318 let mut t = t_start;
1319 // The distilled / chart-center start must itself be in-chart for its certificate
1320 // to be valid; a bad IFT prediction landing outside the chart is uncertifiable.
1321 if !in_chart(&t) {
1322 return Ok(None);
1323 }
1324 let (mut cert, mut delta) =
1325 row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz)?;
1326 while !cert.certified() {
1327 // Not steppable (indefinite / non-finite Hessian): flag.
1328 if !(cert.h.is_finite() && cert.beta.is_finite() && cert.eta.is_finite()) {
1329 return Ok(None);
1330 }
1331 let prev_h = cert.h;
1332 let next = &t + δ
1333 // Refuse to step where the chart's `L` is no longer valid (see guard above).
1334 if !in_chart(&next) {
1335 return Ok(None);
1336 }
1337 t = next;
1338 let (next_cert, next_delta) =
1339 row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz)?;
1340 cert = next_cert;
1341 delta = next_delta;
1342 // The warm-up only helps while h keeps *multiplicatively* contracting
1343 // toward ½. A plain strict-decrease test (`h >= prev_h`) never fires for
1344 // a sequence that decreases monotonically toward a limit above ½, so it
1345 // could spin ~1e15 `row_certificate` solves for one row; require genuine
1346 // multiplicative progress instead (bounded to a few hundred steps, see
1347 // `warmup_progress_sufficient`). Once a step fails that bar the iterate is
1348 // not converging to a certifiable in-chart root — flag for the exact
1349 // fallback (no arbitrary step budget; the bound falls out of the
1350 // contraction requirement).
1351 //
1352 // F6: only enforce the progress bar while STILL uncertified. If this step
1353 // just crossed `h ≤ ½` (e.g. 0.501 → 0.499, a decrease too small to clear
1354 // the multiplicative floor) the point is already certified — the loop
1355 // guard `while !cert.certified()` will exit and refine it. Running the
1356 // progress test unconditionally would falsely reject an in-hand certificate
1357 // and push the row to the exact fallback for nothing. See
1358 // [`warmup_should_reject`].
1359 if warmup_should_reject(cert.certified(), cert.h, prev_h) {
1360 return Ok(None);
1361 }
1362 }
1363 refine_certified_start(
1364 atom,
1365 evaluator,
1366 t,
1367 x,
1368 amplitude,
1369 lipschitz,
1370 newton_steps,
1371 cert,
1372 delta,
1373 )
1374}
1375
1376fn kantorovich_root_radius(cert: RowCertificate) -> f64 {
1377 if !cert.certified() || !(cert.eta.is_finite() && cert.eta >= 0.0) {
1378 return f64::INFINITY;
1379 }
1380 if cert.eta == 0.0 {
1381 return 0.0;
1382 }
1383 if !(cert.h.is_finite() && cert.h >= 0.0) {
1384 return f64::INFINITY;
1385 }
1386 let h = cert.h.min(KANTOROVICH_THRESHOLD);
1387 let discriminant = (1.0 - 2.0 * h).max(0.0).sqrt();
1388 let radius = 2.0 * cert.eta / (1.0 + discriminant);
1389 if radius.is_finite() {
1390 radius
1391 } else {
1392 f64::INFINITY
1393 }
1394}
1395
1396fn distilled_probe_tolerance(
1397 amortized: &CertifiedEncodeProbe,
1398 cold: &CertifiedEncodeProbe,
1399 amplitude: f64,
1400 x: ArrayView1<'_, f64>,
1401) -> f64 {
1402 let certified_radius =
1403 kantorovich_root_radius(amortized.final_cert) + kantorovich_root_radius(cold.final_cert);
1404 let coord_scale = amortized.coord.dot(&amortized.coord).sqrt()
1405 + cold.coord.dot(&cold.coord).sqrt()
1406 + x.dot(&x).sqrt()
1407 + amplitude.abs()
1408 + 1.0;
1409 certified_radius + 1024.0 * f64::EPSILON * coord_scale
1410}
1411
1412fn latent_coordinate_distance(
1413 atom: &SaeManifoldAtom,
1414 lhs: ArrayView1<'_, f64>,
1415 rhs: ArrayView1<'_, f64>,
1416) -> f64 {
1417 let mut acc = 0.0;
1418 for axis in 0..lhs.len().min(rhs.len()) {
1419 let mut diff = (lhs[axis] - rhs[axis]).abs();
1420 if let Some(period) = latent_axis_period(atom, axis) {
1421 let wrapped = diff.rem_euclid(period);
1422 diff = wrapped.min(period - wrapped);
1423 }
1424 acc += diff * diff;
1425 }
1426 acc.sqrt()
1427}
1428
1429fn latent_axis_period(atom: &SaeManifoldAtom, axis: usize) -> Option<f64> {
1430 use crate::manifold::SaeAtomBasisKind::*;
1431 match &atom.basis_kind {
1432 Periodic | Torus => Some(1.0),
1433 Cylinder if axis == 0 => Some(1.0),
1434 Sphere if axis == 1 => Some(std::f64::consts::TAU),
1435 _ => None,
1436 }
1437}
1438
/// Configuration for [`EncodeAtlas`] construction and online encode. All fields
/// are explicit; the atlas never reads global state and adds no CLI flags.
#[derive(Debug, Clone, Copy)]
pub struct AtlasConfig {
    /// Grid resolution per latent axis for offline chart centers (the
    /// SHAPE_BAND grid idiom).
    pub grid_resolution: usize,
    /// Levenberg ridge floor. Per the contract documented on
    /// `encode_grad_hess`, it is NEVER added to the per-row certificate
    /// Hessian (the Kantorovich certificate must see the true `∇²f`); ridge is
    /// reserved for the UNCERTIFIED amortized predictor
    /// (`center_amortized_jacobian`/`center_beta`).
    pub ridge: f64,
    /// Number of online Newton refinement steps after a certified start (1 or 2
    /// per issue #1010).
    pub newton_steps: usize,
}
1452
1453impl Default for AtlasConfig {
1454 fn default() -> Self {
1455 Self {
1456 grid_resolution: 16,
1457 ridge: 1.0e-9,
1458 newton_steps: 2,
1459 }
1460 }
1461}
1462
1463/// The encode atlas: per-atom certified charts plus the online certified-encode
1464/// driver (issue #1010).
1465#[derive(Debug, Clone)]
1466pub struct EncodeAtlas {
1467 pub atoms: Vec<AtomEncodeAtlas>,
1468 pub config: AtlasConfig,
1469}
1470
1471impl EncodeAtlas {
1472 /// Build the offline atlas over a frozen dictionary: for each atom, lay down
1473 /// chart centers on the atom's coordinate grid and certify a Newton radius
1474 /// from the Kantorovich inequality at the worst-case in-chart start.
1475 ///
1476 /// `amplitude_bound[k]` is the per-atom bound on `|z_k|` used to scale the
1477 /// reconstruction jets (the offline `L` must hold for the largest amplitude
1478 /// the encode can produce); `target_norm_bound` bounds `‖x‖` over the data.
1479 pub fn build(
1480 atoms: &[SaeManifoldAtom],
1481 amplitude_bound: &[f64],
1482 target_norm_bound: f64,
1483 config: AtlasConfig,
1484 ) -> Result<Self, String> {
1485 if amplitude_bound.len() != atoms.len() {
1486 return Err(format!(
1487 "EncodeAtlas::build: amplitude_bound length {} != atom count {}",
1488 amplitude_bound.len(),
1489 atoms.len()
1490 ));
1491 }
1492 let mut atom_atlases = Vec::with_capacity(atoms.len());
1493 for (k, atom) in atoms.iter().enumerate() {
1494 let atlas =
1495 Self::build_atom_atlas(k, atom, amplitude_bound[k], target_norm_bound, &config)?;
1496 atom_atlases.push(atlas);
1497 }
1498 Ok(Self {
1499 atoms: atom_atlases,
1500 config,
1501 })
1502 }
1503
1504 pub(crate) fn build_atom_atlas(
1505 atom_index: usize,
1506 atom: &SaeManifoldAtom,
1507 amplitude_bound: f64,
1508 target_norm_bound: f64,
1509 config: &AtlasConfig,
1510 ) -> Result<AtomEncodeAtlas, String> {
1511 let centers = chart_center_grid(atom, config.grid_resolution);
1512 // Half the inter-center spacing is the natural in-chart radius so the
1513 // charts tile the grid without gaps; refined below if the certificate
1514 // fails at that radius. One uniform radius for the regular grid.
1515 let nominal_radius = chart_nominal_radius(atom, config.grid_resolution);
1516 let radii = vec![nominal_radius; centers.nrows()];
1517 Self::build_atom_atlas_from_centers(
1518 atom_index,
1519 atom,
1520 centers.view(),
1521 &radii,
1522 amplitude_bound,
1523 target_norm_bound,
1524 config,
1525 )
1526 }
1527
1528 /// Build a per-atom atlas from EXPLICIT chart centers with a per-center
1529 /// nominal radius — the geometry-agnostic core shared by the regular-grid
1530 /// [`Self::build_atom_atlas`] and the data-driven [`Self::build_data_driven`].
1531 /// Every chart is certified identically (Kantorovich radius from the in-chart
1532 /// curvature at its center); only the center PLACEMENT and per-center radius
1533 /// differ. `radii[c]` is the nominal in-chart radius for `centers[c]`.
1534 pub(crate) fn build_atom_atlas_from_centers(
1535 atom_index: usize,
1536 atom: &SaeManifoldAtom,
1537 centers: ArrayView2<'_, f64>,
1538 radii: &[f64],
1539 amplitude_bound: f64,
1540 target_norm_bound: f64,
1541 config: &AtlasConfig,
1542 ) -> Result<AtomEncodeAtlas, String> {
1543 let d = atom.latent_dim;
1544 if centers.ncols() != d {
1545 return Err(format!(
1546 "build_atom_atlas_from_centers: centers have {} cols but atom latent_dim is {d}",
1547 centers.ncols()
1548 ));
1549 }
1550 if radii.len() != centers.nrows() {
1551 return Err(format!(
1552 "build_atom_atlas_from_centers: {} radii != {} centers",
1553 radii.len(),
1554 centers.nrows()
1555 ));
1556 }
1557 let decoder_norm_sum = decoder_row_norm_sum(atom.decoder_coefficients.view());
1558 let mut charts = Vec::with_capacity(centers.nrows());
1559 // HONEST REFUSAL for Duchon atoms (F2/F3): the closed-form Hessian-Lipschitz
1560 // bound available here (`family_jet_sups` Duchon arm) hard-codes cubic-r³
1561 // jets and a single origin center, but the real Duchon kernel is the
1562 // polyharmonic `c·r^(2m−d)` (with log variants) over data-placed centers.
1563 // That bound can UNDER-estimate L → a FALSE certificate (the module's own
1564 // warning; underestimating L is the dangerous direction). The atom does not
1565 // expose its real order/centers/scaling to this crate, so no sound bound is
1566 // available — refuse rather than fabricate. Every Duchon chart is emitted
1567 // UNCERTIFIED (`certified_radius = 0`, no amortized predictor), so routing
1568 // skips it and every Duchon row flags for the exact multi-start encode.
1569 let duchon_uncertifiable =
1570 matches!(atom.basis_kind, crate::manifold::SaeAtomBasisKind::Duchon);
1571 for c in 0..centers.nrows() {
1572 let center = centers.row(c).to_owned();
1573 let nominal_radius = radii[c];
1574 let region = chart_region(atom, center.clone(), nominal_radius);
1575 if duchon_uncertifiable {
1576 charts.push(CertifiedChart {
1577 region,
1578 lipschitz: f64::INFINITY,
1579 beta_center: f64::INFINITY,
1580 certified_radius: 0.0,
1581 amortized_jacobian: None,
1582 recon_center: Array1::<f64>::zeros(atom.output_dim()),
1583 amortized_base: None,
1584 });
1585 continue;
1586 }
1587 let sups = family_jet_sups(atom, ®ion)?;
1588 let recon_sups = reconstruction_jet_sups(atom, sups);
1589 let lipschitz =
1590 hessian_lipschitz_constant(recon_sups, amplitude_bound, target_norm_bound, 0.0);
1591 // β at the chart center bounds the worst-case in-chart curvature
1592 // (the Gauss-Newton Hessian is continuous; the certified radius is
1593 // solved so the certificate is robust to the start within the ball).
1594 let beta_center = match center_beta(atom, ¢er, config.ridge) {
1595 Some(b) => b,
1596 None => {
1597 // Degenerate center curvature: no certifiable chart here, and
1598 // no amortized Jacobian (the same singular Gauss–Newton block).
1599 charts.push(CertifiedChart {
1600 region,
1601 lipschitz,
1602 beta_center: f64::INFINITY,
1603 certified_radius: 0.0,
1604 amortized_jacobian: None,
1605 recon_center: Array1::<f64>::zeros(atom.output_dim()),
1606 amortized_base: None,
1607 });
1608 continue;
1609 }
1610 };
1611 // Distill the amortized-encoder Jacobian at this center (#1026 ladder
1612 // item 3): the IFT derivative of the encode map, precomputed offline
1613 // so the online encode is one mat-vec. A finite `beta_center` (above)
1614 // means the Gauss–Newton block is non-singular, so this succeeds
1615 // alongside it; the pair travels together on the chart.
1616 let (amortized_jacobian, recon_center) =
1617 match center_amortized_jacobian(atom, ¢er, config.ridge) {
1618 Some((a1, m1)) => (Some(a1), m1),
1619 None => (None, Array1::<f64>::zeros(atom.output_dim())),
1620 };
1621 // Certified radius from h = β·η·L ≤ ½ with η ≤ R (Newton step length
1622 // is bounded by the start distance to the root, itself ≤ chart
1623 // radius at worst): R_c = ½ / (β·L), capped at the nominal radius.
1624 let certified_radius = if lipschitz > 0.0 && beta_center.is_finite() {
1625 (0.5 / (beta_center * lipschitz)).min(region.radius)
1626 } else {
1627 region.radius
1628 };
1629 // Precompute the affine-predictor constant `base = t_c − A₁·m₁` (atom-
1630 // static), so the online encode is a single `base + (1/z)·A₁·x` mat-vec.
1631 let amortized_base = amortized_jacobian
1632 .as_ref()
1633 .map(|a1| ¢er - &a1.dot(&recon_center));
1634 charts.push(CertifiedChart {
1635 region,
1636 lipschitz,
1637 beta_center,
1638 certified_radius,
1639 amortized_jacobian,
1640 recon_center,
1641 amortized_base,
1642 });
1643 }
1644 Ok(AtomEncodeAtlas {
1645 atom_index,
1646 latent_dim: d,
1647 decoder_norm_sum,
1648 charts,
1649 })
1650 }
1651
1652 /// Build the atlas with DATA-DRIVEN chart placement: instead of a dense
1653 /// `resolution^d` product grid (exponential in latent dim `d`, so the regular
1654 /// [`Self::build`] is forced to coarse, poorly-certified charts for `d ≥ 3`),
1655 /// place a bounded number of charts AT the data's own latent coordinates. The
1656 /// chart count is then `O(max_charts)` regardless of `d`, and every chart sits
1657 /// where data actually lands (small in-chart residual → certifies), so
1658 /// higher-dimensional atoms — which reconstruct real activations far better per
1659 /// parameter — become affordable and well-covered.
1660 ///
1661 /// `coords[k]` is atom `k`'s `n × d_k` latent coordinates (the seed coords, or
1662 /// a previous encode's output). Charts are chosen by greedy farthest-point
1663 /// sampling over those coords (deterministic, coverage-maximizing), capped at
1664 /// `max_charts`. Each chart's nominal radius is half the distance to its
1665 /// nearest neighbor center, so the charts tile the local data density. The
1666 /// per-chart Kantorovich certification is IDENTICAL to the regular grid — only
1667 /// the center placement differs.
1668 pub fn build_data_driven(
1669 atoms: &[SaeManifoldAtom],
1670 coords: &[Array2<f64>],
1671 amplitude_bound: &[f64],
1672 target_norm_bound: f64,
1673 max_charts: usize,
1674 config: AtlasConfig,
1675 ) -> Result<Self, String> {
1676 if amplitude_bound.len() != atoms.len() || coords.len() != atoms.len() {
1677 return Err(format!(
1678 "build_data_driven: amplitude_bound {} / coords {} must match atom count {}",
1679 amplitude_bound.len(),
1680 coords.len(),
1681 atoms.len()
1682 ));
1683 }
1684 let mut atom_atlases = Vec::with_capacity(atoms.len());
1685 for (k, atom) in atoms.iter().enumerate() {
1686 let (centers, radii) =
1687 data_driven_chart_centers(atom, coords[k].view(), max_charts.max(1))?;
1688 let atlas = Self::build_atom_atlas_from_centers(
1689 k,
1690 atom,
1691 centers.view(),
1692 &radii,
1693 amplitude_bound[k],
1694 target_norm_bound,
1695 &config,
1696 )?;
1697 atom_atlases.push(atlas);
1698 }
1699 Ok(Self {
1700 atoms: atom_atlases,
1701 config,
1702 })
1703 }
1704
1705 fn refine_certified_encode_start(
1706 &self,
1707 atom: &SaeManifoldAtom,
1708 evaluator: &dyn SaeBasisEvaluator,
1709 chart: &CertifiedChart,
1710 t: Array1<f64>,
1711 x: ArrayView1<'_, f64>,
1712 amplitude: f64,
1713 ) -> Result<(Array1<f64>, RowCertificate), String> {
1714 // Certify from the warm start, navigating into the Kantorovich basin first
1715 // if the unit-amplitude start has h > ½ (see `certify_with_basin_warmup`).
1716 let Some(probe) = certify_with_basin_warmup(
1717 atom,
1718 evaluator,
1719 t,
1720 x,
1721 amplitude,
1722 chart.lipschitz,
1723 self.config.newton_steps,
1724 chart.region.center.view(),
1725 chart.region.radius,
1726 )?
1727 else {
1728 return Ok((
1729 Array1::<f64>::zeros(atom.latent_dim),
1730 uncertified_certificate(chart.lipschitz),
1731 ));
1732 };
1733 // F5: pair the REFINED coordinate with the certificate evaluated AT that
1734 // refined landing point (`final_cert`), not the pre-refinement
1735 // `initial_cert`. Returning `initial_cert` describes a different (earlier)
1736 // iterate's β/η/h — its Kantorovich root-radius overstates the refined
1737 // point's distance to the root.
1738 Ok((probe.coord, probe.final_cert))
1739 }
1740
1741 /// Online certified encode of one target row `x` against one atom `k` with
1742 /// fixed amplitude `z`. Routes to the nearest chart, starts from that chart's
1743 /// distilled IFT warm start, runs `config.newton_steps` Newton steps, and
1744 /// returns the encoded coordinate with its certificate. An uncertified start
1745 /// (no chart, no distilled Jacobian, non-positive amplitude, or `h > ½`)
1746 /// flags the row for the exact multi-start caller.
1747 pub fn certified_encode_row(
1748 &self,
1749 atom: &SaeManifoldAtom,
1750 atom_index: usize,
1751 x: ArrayView1<'_, f64>,
1752 amplitude: f64,
1753 ) -> Result<(Array1<f64>, RowCertificate), String> {
1754 let atom_atlas = self
1755 .atoms
1756 .get(atom_index)
1757 .ok_or_else(|| format!("certified_encode_row: atom {atom_index} not in atlas"))?;
1758 let d = atom.latent_dim;
1759 // A missing basis evaluator means the amortized/cold predictor cannot fire
1760 // for this atom (e.g. a frozen-baseline or first-build atom that never
1761 // attached a distilled evaluator). That is exactly the "cannot certify"
1762 // state — flag the row uncertified (zeros coords, ∞ certificate) so the
1763 // upstream exact multi-start solve owns it, never a hard error that aborts
1764 // the whole criterion. Mirrors the no-chart / singular-Jacobian branches.
1765 let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
1766 return Ok((
1767 Array1::<f64>::zeros(d),
1768 RowCertificate {
1769 beta: f64::INFINITY,
1770 eta: f64::INFINITY,
1771 lipschitz: f64::INFINITY,
1772 h: f64::INFINITY,
1773 },
1774 ));
1775 };
1776
1777 // Route to the nearest chart centers by AMBIENT reconstruction distance.
1778 // A single nearest chart is NOT globally sound on self-approaching atoms:
1779 // where the decoded manifold folds near itself (two distant latent points
1780 // map near the same output), the nearest-center chart can certify into the
1781 // locally-worse basin while another chart holds the GLOBAL minimum (both
1782 // branches' charts reconstruct near the crossing, so both are near in
1783 // ambient distance). The certificate is honest about LOCAL convergence but
1784 // cannot see the better far basin. So we refine in the top-K nearest charts
1785 // and keep the lowest-reconstruction-error CERTIFIED result. For a unimodal
1786 // atom every candidate chart converges to the same root, so this is a no-op
1787 // (first-wins tie → the nearest chart), preserving the existing behavior.
1788 let candidates = nearest_charts_topk(atom_atlas, x, amplitude, CERTIFIED_ROUTING_TOPK);
1789 if candidates.is_empty() {
1790 return Ok((
1791 Array1::<f64>::zeros(d),
1792 RowCertificate {
1793 beta: f64::INFINITY,
1794 eta: f64::INFINITY,
1795 lipschitz: f64::INFINITY,
1796 h: f64::INFINITY,
1797 },
1798 ));
1799 }
1800 // Best CERTIFIED result by reconstruction error, plus the nearest chart's
1801 // result as the uncertified fallback (preserving the prior return when no
1802 // candidate certifies — the nearest chart owns the flagged row).
1803 let mut best: Option<(Array1<f64>, RowCertificate, f64)> = None;
1804 let mut nearest_fallback: Option<(Array1<f64>, RowCertificate)> = None;
1805 for chart_idx in candidates {
1806 let chart = &atom_atlas.charts[chart_idx];
1807 let Some(t) = amortized_warm_start(chart, x, amplitude) else {
1808 if nearest_fallback.is_none() {
1809 nearest_fallback = Some((
1810 Array1::<f64>::zeros(d),
1811 uncertified_certificate(chart.lipschitz),
1812 ));
1813 }
1814 continue;
1815 };
1816 let (coord, cert) = self.refine_certified_encode_start(
1817 atom,
1818 evaluator.as_ref(),
1819 chart,
1820 t,
1821 x,
1822 amplitude,
1823 )?;
1824 if nearest_fallback.is_none() {
1825 nearest_fallback = Some((coord.clone(), cert.clone()));
1826 }
1827 if cert.certified() {
1828 let err = encode_reconstruction_error(
1829 atom,
1830 evaluator.as_ref(),
1831 coord.view(),
1832 x,
1833 amplitude,
1834 );
1835 if best.as_ref().map(|(_, _, e)| err < *e).unwrap_or(true) {
1836 best = Some((coord, cert, err));
1837 }
1838 // Global-minimum short-circuit: reconstruction error ≥ 0, so a
1839 // certified candidate already at the ambient noise floor is provably
1840 // the global optimum over the remaining charts — stop refining them.
1841 if let Some((_, _, e)) = best.as_ref() {
1842 if *e <= CERTIFIED_GLOBAL_MIN_RECON_FLOOR * (1.0 + x.dot(&x).sqrt()) {
1843 break;
1844 }
1845 }
1846 }
1847 }
1848 match best {
1849 Some((coord, cert, _)) => Ok((coord, cert)),
1850 None => Ok(nearest_fallback.unwrap_or_else(|| {
1851 (
1852 Array1::<f64>::zeros(d),
1853 RowCertificate {
1854 beta: f64::INFINITY,
1855 eta: f64::INFINITY,
1856 lipschitz: f64::INFINITY,
1857 h: f64::INFINITY,
1858 },
1859 )
1860 })),
1861 }
1862 }
1863
1864 /// Amortized (distilled) encode of one target row `x` against one atom `k`
1865 /// with fixed amplitude `z` (#1026 ladder item 3).
1866 ///
1867 /// Routes to the nearest chart, then predicts the latent coordinate in CLOSED
1868 /// FORM from that chart's precomputed implicit-function-theorem Jacobian:
1869 ///
1870 /// ```text
1871 /// t̂ = t_c + (1/z) · A₁ · (x − z · m₁(t_c)),
1872 /// ```
1873 ///
1874 /// a single `O(d·p)` mat-vec — no per-row Hessian factorization or
1875 /// eigendecomposition, which is the amortization. The Kantorovich
1876 /// certificate is then evaluated AT the predicted start `t̂` with the chart's
1877 /// closed-form Lipschitz constant. A prediction is accepted only when that
1878 /// certificate holds, an independent cold chart-center probe also certifies,
1879 /// and the two refined coordinates agree within the two probes' final
1880 /// Kantorovich root-radius bounds. This keeps the distilled path honest
1881 /// without letting the exact probe reuse the distilled warm start it is
1882 /// auditing. A chart without a distilled Jacobian (singular Gauss–Newton
1883 /// block) flags the row.
1884 pub fn amortized_encode_row(
1885 &self,
1886 atom: &SaeManifoldAtom,
1887 atom_index: usize,
1888 x: ArrayView1<'_, f64>,
1889 amplitude: f64,
1890 ) -> Result<(Array1<f64>, RowCertificate), String> {
1891 let atom_atlas = self
1892 .atoms
1893 .get(atom_index)
1894 .ok_or_else(|| format!("amortized_encode_row: atom {atom_index} not in atlas"))?;
1895 let d = atom.latent_dim;
1896 let uncertified = || {
1897 (
1898 Array1::<f64>::zeros(d),
1899 RowCertificate {
1900 beta: f64::INFINITY,
1901 eta: f64::INFINITY,
1902 lipschitz: f64::INFINITY,
1903 h: f64::INFINITY,
1904 },
1905 )
1906 };
1907 // A missing basis evaluator means the distilled predictor cannot fire for
1908 // this atom — flag the row uncertified (the exact upstream solve owns it)
1909 // rather than erroring, exactly as the no-chart / singular-Jacobian /
1910 // non-positive-amplitude branches below do. Never a silent wrong encode,
1911 // never a hard abort of the criterion.
1912 let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
1913 return Ok(uncertified());
1914 };
1915 let Some((chart_idx, _)) = nearest_chart(atom_atlas, x, amplitude) else {
1916 return Ok(uncertified());
1917 };
1918 let chart = &atom_atlas.charts[chart_idx];
1919 // Closed-form predicted start t̂ = t_c + (1/z)·A₁·(x − z·m₁). `None` when
1920 // the chart's Gauss–Newton block was singular (no distilled Jacobian, so
1921 // the amortized predictor cannot fire) or the amplitude is not strictly
1922 // positive and finite (a near-inactive atom, where the amplitude-divided
1923 // map is undefined) — either way flag for the exact fallback, never a
1924 // silent wrong encode.
1925 let Some(t_hat) = amortized_warm_start(chart, x, amplitude) else {
1926 return Ok(uncertified());
1927 };
1928 // Evaluate the SAME Kantorovich certificate at the predicted start. The
1929 // amortized prediction is trusted only if this certificate holds AND an
1930 // independent cold chart-center probe certifies and agrees below the
1931 // two probes' final Kantorovich root-radius bounds. This avoids the
1932 // self-referential gate where the "exact" probe is warm-started by the
1933 // same distilled prediction it is supposed to audit.
1934 let Some(amortized_probe) = certify_with_basin_warmup(
1935 atom,
1936 evaluator.as_ref(),
1937 t_hat,
1938 x,
1939 amplitude,
1940 chart.lipschitz,
1941 self.config.newton_steps,
1942 chart.region.center.view(),
1943 chart.region.radius,
1944 )?
1945 else {
1946 return Ok((
1947 Array1::<f64>::zeros(d),
1948 uncertified_certificate(chart.lipschitz),
1949 ));
1950 };
1951
1952 let cold_start = chart.region.center.clone();
1953 let Some(cold_probe) = certify_with_basin_warmup(
1954 atom,
1955 evaluator.as_ref(),
1956 cold_start,
1957 x,
1958 amplitude,
1959 chart.lipschitz,
1960 self.config.newton_steps,
1961 chart.region.center.view(),
1962 chart.region.radius,
1963 )?
1964 else {
1965 return Ok((
1966 amortized_probe.coord,
1967 uncertified_certificate(chart.lipschitz),
1968 ));
1969 };
1970
1971 let gap =
1972 latent_coordinate_distance(atom, amortized_probe.coord.view(), cold_probe.coord.view());
1973 let tolerance = distilled_probe_tolerance(&amortized_probe, &cold_probe, amplitude, x);
1974 if !(gap.is_finite() && gap <= tolerance) {
1975 return Ok((
1976 amortized_probe.coord,
1977 uncertified_certificate(chart.lipschitz),
1978 ));
1979 }
1980 // F5: return the certificate at the refined landing coordinate
1981 // (`final_cert`), consistent with the coord actually returned and with the
1982 // `distilled_probe_tolerance` gate above (which already reads `final_cert`
1983 // via `kantorovich_root_radius`). `initial_cert` is a stale earlier iterate.
1984 Ok((amortized_probe.coord, amortized_probe.final_cert))
1985 }
1986
1987 /// Batched amortized (distilled) encode over many rows against one atom
1988 /// (#1026 ladder item 3, corpus-rate). Each row uses the closed-form
1989 /// per-chart Jacobian predictor and carries its own Kantorovich certificate;
1990 /// uncertified rows are flagged in [`EncodeResult::encode_uncertified_count`]
1991 /// for the exact multi-start fallback. Row-independent against the frozen
1992 /// dictionary, so the batch fans out over rows (deterministic row-order
1993 /// assembly, bit-identical run-to-run), staying sequential inside a rayon
1994 /// worker to avoid nested oversubscription.
1995 pub fn amortized_encode_batch(
1996 &self,
1997 atom: &SaeManifoldAtom,
1998 atom_index: usize,
1999 targets: ArrayView2<'_, f64>,
2000 amplitudes: ArrayView1<'_, f64>,
2001 ) -> Result<EncodeResult, String> {
2002 let n = targets.nrows();
2003 if amplitudes.len() != n {
2004 return Err(format!(
2005 "amortized_encode_batch: amplitudes len {} != rows {n}",
2006 amplitudes.len()
2007 ));
2008 }
2009 let d = atom.latent_dim;
2010 let encode_rows =
2011 |range: std::ops::Range<usize>| -> Result<Vec<(Array1<f64>, bool)>, String> {
2012 range
2013 .map(|row| {
2014 let (t, cert) = self.amortized_encode_row(
2015 atom,
2016 atom_index,
2017 targets.row(row),
2018 amplitudes[row],
2019 )?;
2020 Ok((t, cert.certified()))
2021 })
2022 .collect()
2023 };
2024 let rows: Vec<(Array1<f64>, bool)> =
2025 if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
2026 use rayon::prelude::*;
2027 const CHUNK: usize = 256;
2028 let n_chunks = n.div_ceil(CHUNK);
2029 let chunked: Vec<Vec<(Array1<f64>, bool)>> = (0..n_chunks)
2030 .into_par_iter()
2031 .map(|c| {
2032 let start = c * CHUNK;
2033 let end = (start + CHUNK).min(n);
2034 encode_rows(start..end)
2035 })
2036 .collect::<Result<_, _>>()?;
2037 chunked.into_iter().flatten().collect()
2038 } else {
2039 encode_rows(0..n)?
2040 };
2041 let mut coords = Array2::<f64>::zeros((n, d));
2042 let mut certified = Vec::with_capacity(n);
2043 for (row, (t, cert)) in rows.into_iter().enumerate() {
2044 coords.row_mut(row).assign(&t);
2045 certified.push(cert);
2046 }
2047 Ok(EncodeResult::from_rows(coords, certified))
2048 }
2049
2050 /// Batched certified encode over many rows against one atom (the #988
2051 /// throughput consumer). Each row carries its own certificate; uncertified
2052 /// rows are flagged in [`EncodeResult::encode_uncertified_count`] for the
2053 /// exact multi-start fallback.
2054 pub fn certified_encode_batch(
2055 &self,
2056 atom: &SaeManifoldAtom,
2057 atom_index: usize,
2058 targets: ArrayView2<'_, f64>,
2059 amplitudes: ArrayView1<'_, f64>,
2060 ) -> Result<EncodeResult, String> {
2061 let n = targets.nrows();
2062 if amplitudes.len() != n {
2063 return Err(format!(
2064 "certified_encode_batch: amplitudes len {} != rows {n}",
2065 amplitudes.len()
2066 ));
2067 }
2068 let d = atom.latent_dim;
2069 // Per-row encode is independent against a frozen dictionary (#1010), so
2070 // the corpus-rate batch fans out over rows (#1026 amortized-encoder leg /
2071 // #977 Stage-3 corpus encode). Each row produces an owned `(t, certified)`
2072 // pair; results are assembled back in row order so the output is
2073 // bit-identical run-to-run regardless of thread scheduling. Stay
2074 // sequential inside a rayon worker (e.g. when an outer atom-level fan-out
2075 // owns the pool) to avoid nested oversubscription. The first row that
2076 // fails to encode propagates its error deterministically.
2077 let encode_rows =
2078 |range: std::ops::Range<usize>| -> Result<Vec<(Array1<f64>, bool)>, String> {
2079 range
2080 .map(|row| {
2081 let (t, cert) = self.certified_encode_row(
2082 atom,
2083 atom_index,
2084 targets.row(row),
2085 amplitudes[row],
2086 )?;
2087 Ok((t, cert.certified()))
2088 })
2089 .collect()
2090 };
2091 let rows: Vec<(Array1<f64>, bool)> =
2092 if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
2093 use rayon::prelude::*;
2094 const CHUNK: usize = 256;
2095 let n_chunks = n.div_ceil(CHUNK);
2096 let chunked: Vec<Vec<(Array1<f64>, bool)>> = (0..n_chunks)
2097 .into_par_iter()
2098 .map(|c| {
2099 let start = c * CHUNK;
2100 let end = (start + CHUNK).min(n);
2101 encode_rows(start..end)
2102 })
2103 .collect::<Result<_, _>>()?;
2104 chunked.into_iter().flatten().collect()
2105 } else {
2106 encode_rows(0..n)?
2107 };
2108 let mut coords = Array2::<f64>::zeros((n, d));
2109 let mut certified = Vec::with_capacity(n);
2110 for (row, (t, cert)) in rows.into_iter().enumerate() {
2111 coords.row_mut(row).assign(&t);
2112 certified.push(cert);
2113 }
2114 Ok(EncodeResult::from_rows(coords, certified))
2115 }
2116
/// Batched "fast" amortized encode — the traditional-encoder forward pass,
/// WITH manifolds. For every row this applies the SAME closed-form affine
/// predictor as [`amortized_warm_start`]
/// (`t̂ = t_c + (1/z)·A₁·(x − z·m₁)`), but without the per-row Kantorovich
/// certificate + basin warmup. NO per-row certificate is taken: this is the
/// speed mode (the certified `*_encode_*` paths remain the accuracy mode).
///
/// Cost: routing is a direct squared-distance scan of each row against the
/// cached center reconstruction of every routable chart (`O(n·C·p)` for `C`
/// routable charts, skipped when there is exactly one), followed by `d` dot
/// products of length `p` per row for the predictor. No basis evaluation
/// happens online; only cached atlas data is read.
///
/// Only charts with `certified_radius > 0` take part in routing. A row is left
/// zeroed with `valid[row] == false` when:
/// * the atom has no basis evaluator (every row),
/// * no chart has a positive certified radius (every row),
/// * the chart it routes to carries no amortized Jacobian / base, or
/// * its amplitude is non-finite or exactly zero.
///
/// NOTE(review): the amplitude guard here is `amp.abs() > 0.0`, so NEGATIVE
/// finite amplitudes DO fire the predictor. Other comments in this file
/// describe the per-row predictor as rejecting non-positive amplitudes —
/// confirm which contract `amortized_warm_start` actually implements.
///
/// Returns `(coords, valid)` where `coords` is `n × d` and `valid[row]` is
/// `true` iff the amortized predictor fired for that row.
///
/// # Errors
///
/// Returns `Err` when `x` does not have `atom.output_dim()` columns, when
/// `amplitudes` does not have one entry per row, or when `atom_index` is not
/// in the atlas.
pub fn amortized_encode_batch_fast(
    &self,
    atom: &SaeManifoldAtom,
    atom_index: usize,
    x: ArrayView2<'_, f64>,
    amplitudes: ArrayView1<'_, f64>,
) -> Result<(Array2<f64>, Vec<bool>), String> {
    let n = x.nrows();
    let p = atom.output_dim();
    let d = atom.latent_dim;
    if x.ncols() != p {
        return Err(format!(
            "amortized_encode_batch_fast: x has {} cols but atom output dim is {p}",
            x.ncols()
        ));
    }
    if amplitudes.len() != n {
        return Err(format!(
            "amortized_encode_batch_fast: amplitudes len {} != rows {n}",
            amplitudes.len()
        ));
    }
    let atom_atlas = self.atoms.get(atom_index).ok_or_else(|| {
        format!("amortized_encode_batch_fast: atom {atom_index} not in atlas")
    })?;
    // Outputs start fully "not fired": zero coordinates, all-false mask.
    let mut coords = Array2::<f64>::zeros((n, d));
    let mut valid = vec![false; n];

    // A missing basis evaluator means this atom never had a well-formed atlas
    // built here — treat every row as uncertified (zeroed), exactly like the
    // per-row `amortized_encode_row` no-evaluator branch. (The predictor below
    // uses only cached atlas data, so no evaluator call is made online.)
    if atom.basis_evaluator.is_none() {
        return Ok((coords, valid));
    }

    // ── Routing recon-centers (cached, no online basis evaluation). ────────
    // Routing sends a row to the chart whose center reconstruction
    // `m(t_c) = BᵀΦ(t_c)` is closest in ‖·‖². Those center reconstructions are
    // OFFLINE-cached in `chart.recon_center` (bit-identical to re-evaluating the
    // basis at the fixed centers — same φ·decoder accumulation). Gather the cache
    // instead of calling `evaluator.evaluate` on every invocation: that per-call
    // chart-center evaluation was the dominant per-atom-group overhead at massive
    // K, where N rows scatter across many atoms into tiny groups so a fixed
    // per-call cost is amortized over only a handful of rows. This is what keeps
    // the fast index-routed encode near-flat as K grows.
    //
    // `valid_charts` holds indices into `atom_atlas.charts`; every later
    // per-chart vector (`recon_centers`, `predictors`, `route_idx` values) is
    // indexed by POSITION within `valid_charts`, not by the atlas chart index.
    let valid_charts: Vec<usize> = (0..atom_atlas.charts.len())
        .filter(|&c| atom_atlas.charts[c].certified_radius > 0.0)
        .collect();
    if valid_charts.is_empty() {
        return Ok((coords, valid));
    }
    // recon_centers (C × p): the cached m(t_c) for each certifiable chart.
    let mut recon_centers = Array2::<f64>::zeros((valid_charts.len(), p));
    for (ci, &c) in valid_charts.iter().enumerate() {
        recon_centers
            .row_mut(ci)
            .assign(&atom_atlas.charts[c].recon_center);
    }
    // Per-chart routing key: route_idx[row] = argmin_c ‖x_row − z_row·r_c‖²
    // (F1: the TRUE objective at the row's amplitude `z_row`, where `r_c` is
    // the amplitude-1 center reconstruction). First chart wins on a tie (strict
    // `<`), matching `nearest_chart`. A non-finite amplitude routes to chart 0
    // (moot — the predictor below skips the row anyway). A row whose distances
    // are all NaN also stays on chart 0, since `dist < best_d` never holds.
    //
    // Score the DIRECT squared distance `Σ_j (z·r_cj − x_j)²`, accumulated in
    // the same element order as the per-row `nearest_chart`, NOT the algebraic
    // expansion `z²‖r_c‖² − 2z·(x·r_c)`. The two are equal in exact arithmetic,
    // but the expansion subtracts two O(‖x‖²) quantities and loses ~‖x‖²·ε of
    // precision to cancellation — enough to FLIP the argmin between two charts
    // whose reconstructions coincide to rounding (e.g. period-wrapped torus-seam
    // charts, whose latent centers differ by a full period yet reconstruct to
    // within 1e-14). On such a near-tie the expansion and `nearest_chart`
    // disagree, and the two seam charts predict coords a full period apart — so
    // the batched fast-encode would diverge from the per-row warm-start it must
    // reproduce bit-for-bit (the `fast_encode_matches_per_row_warm_start`
    // contract). Computing the same direct distance in the same order keeps this
    // path byte-identical to `nearest_chart`, so routing (and the tie-break)
    // agrees on every row. `recon_centers` is still the cached offline
    // `m₁(t_c)` — no basis re-evaluation, which was the fast path's real cost.
    let route_idx: Vec<usize> = if valid_charts.len() == 1 {
        vec![0usize; n]
    } else {
        (0..n)
            .map(|row| {
                let z = amplitudes[row];
                if !z.is_finite() {
                    return 0usize;
                }
                let x_row = x.row(row);
                let mut best_c = 0usize;
                let mut best_d = f64::INFINITY;
                for c in 0..valid_charts.len() {
                    let mut dist = 0.0;
                    for (r, xv) in recon_centers.row(c).iter().zip(x_row.iter()) {
                        let diff = z * r - xv;
                        dist += diff * diff;
                    }
                    if dist < best_d {
                        best_d = dist;
                        best_c = c;
                    }
                }
                best_c
            })
            .collect()
    };

    // ── Per-chart batched affine predictor. ───────────────────────────────
    // For rows routed to chart `c` with finite jacobian `A₁` (d × p) and
    // center reconstruction `m₁` (= `chart.recon_center`), the predictor is
    //     t̂ = t_c − A₁·m₁ + (1/z)·(A₁·x).
    // `t_c − A₁·m₁` is a per-chart constant `base`; `A₁·x` is a d-vector of
    // per-row dot products. Instead of gathering routed rows into a fresh
    // `X_c` (n_c × p) buffer and running a GEMM into a second `U` (n_c × d)
    // buffer — two allocations plus a full copy of the routed rows, per chart —
    // fuse the gather straight into the multiply: stream each source row of `x`
    // once (it is contiguous) and dot it against `A₁`'s rows, writing the
    // predicted coord directly. Zero per-chart heap traffic; the inverse
    // amplitude is hoisted to one reciprocal per row.
    //
    // Precompute each valid chart's `(A₁, base)` once (charts with a singular
    // Gauss–Newton block carry no `A₁`, so their routed rows stay
    // zeroed/uncertified — same as `amortized_warm_start` returning `None`).
    struct ChartPredictor<'a> {
        // Distilled Jacobian `A₁`, borrowed from the atlas chart.
        a1: &'a Array2<f64>,
        // Precomputed affine constant `t_c − A₁·m₁`, borrowed from the chart.
        base: &'a Array1<f64>,
    }
    let predictors: Vec<Option<ChartPredictor<'_>>> = valid_charts
        .iter()
        .map(|&c| {
            let chart = &atom_atlas.charts[c];
            // `base = t_c − A₁·m₁` is precomputed offline in the atlas; reuse it
            // (both are `Some` together — singular G-N block ⇒ both `None`).
            match (
                chart.amortized_jacobian.as_ref(),
                chart.amortized_base.as_ref(),
            ) {
                (Some(a1), Some(base)) => Some(ChartPredictor { a1, base }),
                _ => None,
            }
        })
        .collect();

    for row in 0..n {
        // Routed chart has no predictor → leave the row zeroed / invalid.
        let Some(pred) = predictors[route_idx[row]].as_ref() else {
            continue;
        };
        let amp = amplitudes[row];
        // Skip non-finite and exactly-zero amplitudes (the division below
        // would be undefined); negative finite amplitudes pass this guard.
        if !(amp.is_finite() && amp.abs() > 0.0) {
            continue;
        }
        let inv_z = 1.0 / amp;
        let x_row = x.row(row);
        let mut coord_row = coords.row_mut(row);
        for axis in 0..d {
            // (A₁·x)[axis] = A₁ row `axis` (contiguous, length p) · x_row.
            coord_row[axis] = pred.base[axis] + pred.a1.row(axis).dot(&x_row) * inv_z;
        }
        valid[row] = true;
    }
    Ok((coords, valid))
}
2302
2303 /// Fast batched FULL forward pass against one atom: encode → decode, the
2304 /// manifold analogue of a traditional SAE's `x̂ = z·D` (decoder `D`, code `z`).
2305 ///
2306 /// A traditional SAE decodes with one GEMM. The manifold SAE's reconstruction
2307 /// is `m(t̂) = z·Φ(t̂)·B` (module header) — the SAME GEMM `Φ·B`, but the code
2308 /// `Φ(t̂)` is the curved chart basis evaluated at the encoded latent coordinate
2309 /// rather than a flat one-hot. So the fast forward is exactly:
2310 /// 1. [`amortized_encode_batch_fast`] → per-row latent coords `t̂` (one
2311 /// routing GEMM + one affine GEMM per chart — a traditional `W·x+b`);
2312 /// 2. ONE batched basis evaluation `Φ(t̂)` (the manifold-curvature step a
2313 /// flat SAE doesn't have — `n×m`);
2314 /// 3. ONE GEMM `recon = Φ(t̂)·B` (`(n×m)·(m×p)` — a traditional decoder
2315 /// `z·D`), then the per-row amplitude scale `z`.
2316 ///
2317 /// Rows the encoder could not certify-predict (no evaluator / singular
2318 /// Gauss–Newton block / non-finite-or-zero amplitude) are returned as a ZERO
2319 /// reconstruction and flagged `false` in the valid-mask — never a silent wrong
2320 /// decode. The reconstruction of a valid row equals, bit-for-bit up to GEMM
2321 /// reassociation, `z·(Φ(t̂_row)·B)` with `t̂` from the per-row predictor.
2322 pub fn amortized_reconstruct_batch_fast(
2323 &self,
2324 atom: &SaeManifoldAtom,
2325 atom_index: usize,
2326 x: ArrayView2<'_, f64>,
2327 amplitudes: ArrayView1<'_, f64>,
2328 ) -> Result<(Array2<f64>, Vec<bool>), String> {
2329 let n = x.nrows();
2330 let p = atom.output_dim();
2331 // Step 1: batched encode → latent coords (reuses the fast routing+affine).
2332 let (coords, valid) = self.amortized_encode_batch_fast(atom, atom_index, x, amplitudes)?;
2333 let mut recon = Array2::<f64>::zeros((n, p));
2334 // A missing evaluator means no row could encode — every row is zeroed and
2335 // already flagged `false` by the encode; nothing to decode.
2336 let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
2337 return Ok((recon, valid));
2338 };
2339 // Step 2: ONE batched basis evaluation Φ(t̂) over all rows (n × m). Invalid
2340 // rows carry coords = 0 (the chart-origin); we still evaluate them in the
2341 // batch for a single GEMM, then zero their reconstruction below — the basis
2342 // is finite at the origin so this cannot poison the valid rows' GEMM.
2343 let (phi, _jet) = evaluator
2344 .evaluate(coords.view())
2345 .map_err(|err| format!("amortized_reconstruct_batch_fast: basis eval: {err}"))?;
2346 // Step 3: ONE GEMM recon = Φ·B (n × p), then per-row amplitude scale z.
2347 // m(t̂) = z·Φ(t̂)·B, matching the module header and `fill_decoded_row`'s
2348 // `Φ·decoder` accumulation (the amplitude is applied once here).
2349 let decoded = phi.dot(&atom.decoder_coefficients); // (n × p), amplitude-1
2350 for row in 0..n {
2351 if !valid[row] {
2352 continue; // stays zeroed — uncertified, like warm_start `None`.
2353 }
2354 let z = amplitudes[row];
2355 for col in 0..p {
2356 recon[[row, col]] = z * decoded[[row, col]];
2357 }
2358 }
2359 Ok((recon, valid))
2360 }
2361
2362 /// LSH-routed certified encode (issue #1010 step 2 + 3): for each target
2363 /// row, the existing [`SaeCandidateIndex`] (#985/#994) proposes the
2364 /// best-aligned atom by frame alignment to the row direction; the row is then
2365 /// encoded against THAT atom's certified chart atlas. This is the production
2366 /// routing path. Atom selection is EXACT (#1777): [`SaeCandidateIndex::route_exact`]
2367 /// returns the global argmax of the routing score (the universal-bound LSH fast
2368 /// path, else a full-scan fallback) — never a silently-missed ungathered atom —
2369 /// and the atlas does the in-atom nearest-chart routing and the per-row
2370 /// Kantorovich certificate.
2371 ///
2372 /// `atoms[id]` must be aligned with the atlas's `atoms[id]` (same dictionary
2373 /// order the atlas was built from and the sketch/index were built over).
2374 /// A row over an empty dictionary, or whose globally-best atom aligns below the
2375 /// fit-quality floor, is flagged uncertified — it routes to the exact
2376 /// multi-start fallback, never a silent wrong encode.
2377 pub fn certified_encode_with_index<S: AtomFrameSketch + Sync>(
2378 &self,
2379 atoms: &[SaeManifoldAtom],
2380 index: &SaeCandidateIndex,
2381 sketch: &S,
2382 targets: ArrayView2<'_, f64>,
2383 amplitudes: ArrayView1<'_, f64>,
2384 latent_dim: usize,
2385 ) -> Result<EncodeResult, String> {
2386 let n = targets.nrows();
2387 if amplitudes.len() != n {
2388 return Err(format!(
2389 "certified_encode_with_index: amplitudes len {} != rows {n}",
2390 amplitudes.len()
2391 ));
2392 }
2393 let budget = auto_candidate_budget(atoms.len().max(1));
2394 // LSH-routed per-row encode is independent across rows (sublinear atom
2395 // selection + frozen-dictionary in-atom Newton), so the corpus-rate batch
2396 // fans out over rows (#1026 amortized-encoder/routing leg / #977 Stage-3).
2397 // `None` coords (no LSH candidate) carry through as a zeroed row flagged
2398 // uncertified — identical to the sequential semantics. Results assemble
2399 // back in row order (bit-identical run-to-run); the first encode error
2400 // propagates deterministically. Stay sequential inside a rayon worker to
2401 // avoid nested oversubscription.
2402 let encode_rows =
2403 |range: std::ops::Range<usize>| -> Result<Vec<Option<(Array1<f64>, bool)>>, String> {
2404 range
2405 .map(|row| {
2406 // EXACT routing (#1777): pick the GLOBAL argmax of the
2407 // routing score over the whole dictionary, not merely the
2408 // best LSH-gathered candidate. `route_exact` certifies the
2409 // sublinear gather against the universal `[0,1]` alignment
2410 // bound and falls back to a full scan otherwise, so the
2411 // returned atom is guaranteed to be the globally-best — no
2412 // silently-missed ungathered atom.
2413 let Some(route) =
2414 index.route_exact(sketch, targets.row(row), budget, true)
2415 else {
2416 // Empty dictionary: flag for the exact fallback.
2417 return Ok(None);
2418 };
2419 let best_atom = route.atom;
2420 // Fit-quality floor: even the globally-best atom may align
2421 // only weakly with this row (no atom fits it). A finite
2422 // alignment below the floor — or a NaN, the zero-norm
2423 // ‖d‖ = 0 row — flags for the exact multi-start fallback
2424 // rather than encoding against a poorly-fitting atom. This is
2425 // a quality gate, not a routing-correctness gate; routing is
2426 // already exact. See CANDIDATE_ROUTING_MIN_ALIGNMENT.
2427 if !route.alignment.is_finite()
2428 || route.alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT
2429 {
2430 return Ok(None);
2431 }
2432 let atom = atoms.get(best_atom).ok_or_else(|| {
2433 format!(
2434 "certified_encode_with_index: proposed atom {best_atom} out of range"
2435 )
2436 })?;
2437 let (t, cert) = self.certified_encode_row(
2438 atom,
2439 best_atom,
2440 targets.row(row),
2441 amplitudes[row],
2442 )?;
2443 // Heterogeneous-atom dictionaries with different latent_dim
2444 // per atom are not supported by the batched API: the caller
2445 // declares one shared `latent_dim` for the output tensor.
2446 // Silently zeroing the coord row while recording a
2447 // certified=true flag would produce corrupted
2448 // reconstructions downstream — error loudly instead.
2449 if t.len() != latent_dim {
2450 return Err(format!(
2451 "certified_encode_with_index: atom {best_atom} returned t.len()={} \
2452 but declared latent_dim={latent_dim}; heterogeneous-dim \
2453 dictionaries are not supported by this batched encode path",
2454 t.len()
2455 ));
2456 }
2457 Ok(Some((t, cert.certified())))
2458 })
2459 .collect()
2460 };
2461 let rows: Vec<Option<(Array1<f64>, bool)>> =
2462 if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
2463 use rayon::prelude::*;
2464 const CHUNK: usize = 256;
2465 let n_chunks = n.div_ceil(CHUNK);
2466 let chunked: Vec<Vec<Option<(Array1<f64>, bool)>>> = (0..n_chunks)
2467 .into_par_iter()
2468 .map(|c| {
2469 let start = c * CHUNK;
2470 let end = (start + CHUNK).min(n);
2471 encode_rows(start..end)
2472 })
2473 .collect::<Result<_, _>>()?;
2474 chunked.into_iter().flatten().collect()
2475 } else {
2476 encode_rows(0..n)?
2477 };
2478 let mut coords = Array2::<f64>::zeros((n, latent_dim));
2479 let mut certified = Vec::with_capacity(n);
2480 for (row, slot) in rows.into_iter().enumerate() {
2481 match slot {
2482 Some((t, cert)) => {
2483 coords.row_mut(row).assign(&t);
2484 certified.push(cert);
2485 }
2486 None => certified.push(false),
2487 }
2488 }
2489 Ok(EncodeResult::from_rows(coords, certified))
2490 }
2491
2492 /// LSH-routed AMORTIZED (distilled) encode — the production token-rate
2493 /// encoder of #1026 ladder item 3. Identical routing to
2494 /// [`Self::certified_encode_with_index`] (LSH proposes the best-aligned atom,
2495 /// the atlas routes to the in-atom nearest chart), but the in-atom encode is
2496 /// the closed-form per-chart Jacobian predictor + certificate gate of
2497 /// [`Self::amortized_encode_row`] rather than the certified Newton-refinement
2498 /// path.
2499 /// This is the deployment path: the distilled affine map produces the encode
2500 /// in one mat-vec, the Kantorovich certificate decides trust-or-fallback per
2501 /// row, and uncertified rows (the adversarial tail the thread expects to
2502 /// concentrate on rare tokens) are flagged for the exact multi-start solve —
2503 /// compute goes where the questions are. Row-independent against the frozen
2504 /// dictionary, so the batch fans out over rows with deterministic row-order
2505 /// assembly (bit-identical run-to-run).
2506 pub fn amortized_encode_with_index<S: AtomFrameSketch + Sync>(
2507 &self,
2508 atoms: &[SaeManifoldAtom],
2509 index: &SaeCandidateIndex,
2510 sketch: &S,
2511 targets: ArrayView2<'_, f64>,
2512 amplitudes: ArrayView1<'_, f64>,
2513 latent_dim: usize,
2514 ) -> Result<EncodeResult, String> {
2515 let n = targets.nrows();
2516 if amplitudes.len() != n {
2517 return Err(format!(
2518 "amortized_encode_with_index: amplitudes len {} != rows {n}",
2519 amplitudes.len()
2520 ));
2521 }
2522 let budget = auto_candidate_budget(atoms.len().max(1));
2523 let encode_rows =
2524 |range: std::ops::Range<usize>| -> Result<Vec<Option<(Array1<f64>, bool)>>, String> {
2525 range
2526 .map(|row| {
2527 // EXACT routing (#1777): global argmax of the routing score,
2528 // not just the best LSH-gathered candidate (see
2529 // certified_encode_with_index for the full rationale).
2530 let Some(route) =
2531 index.route_exact(sketch, targets.row(row), budget, true)
2532 else {
2533 return Ok(None);
2534 };
2535 let best_atom = route.atom;
2536 // Fit-quality floor (not a routing-correctness gate; routing
2537 // is exact): even the globally-best atom may fit a row poorly,
2538 // and a NaN alignment is the zero-norm ‖d‖ = 0 row. Either way
2539 // flag for the exact multi-start fallback. See
2540 // CANDIDATE_ROUTING_MIN_ALIGNMENT.
2541 if !route.alignment.is_finite()
2542 || route.alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT
2543 {
2544 return Ok(None);
2545 }
2546 let atom = atoms.get(best_atom).ok_or_else(|| {
2547 format!(
2548 "amortized_encode_with_index: proposed atom {best_atom} out of range"
2549 )
2550 })?;
2551 let (t, cert) = self.amortized_encode_row(
2552 atom,
2553 best_atom,
2554 targets.row(row),
2555 amplitudes[row],
2556 )?;
2557 if t.len() != latent_dim {
2558 return Err(format!(
2559 "amortized_encode_with_index: atom {best_atom} returned t.len()={} \
2560 but declared latent_dim={latent_dim}; heterogeneous-dim \
2561 dictionaries are not supported by this batched encode path",
2562 t.len()
2563 ));
2564 }
2565 Ok(Some((t, cert.certified())))
2566 })
2567 .collect()
2568 };
2569 let rows: Vec<Option<(Array1<f64>, bool)>> =
2570 if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
2571 use rayon::prelude::*;
2572 const CHUNK: usize = 256;
2573 let n_chunks = n.div_ceil(CHUNK);
2574 let chunked: Vec<Vec<Option<(Array1<f64>, bool)>>> = (0..n_chunks)
2575 .into_par_iter()
2576 .map(|c| {
2577 let start = c * CHUNK;
2578 let end = (start + CHUNK).min(n);
2579 encode_rows(start..end)
2580 })
2581 .collect::<Result<_, _>>()?;
2582 chunked.into_iter().flatten().collect()
2583 } else {
2584 encode_rows(0..n)?
2585 };
2586 let mut coords = Array2::<f64>::zeros((n, latent_dim));
2587 let mut certified = Vec::with_capacity(n);
2588 for (row, slot) in rows.into_iter().enumerate() {
2589 match slot {
2590 Some((t, cert)) => {
2591 coords.row_mut(row).assign(&t);
2592 certified.push(cert);
2593 }
2594 None => certified.push(false),
2595 }
2596 }
2597 Ok(EncodeResult::from_rows(coords, certified))
2598 }
2599
2600 /// LSH-routed FAST amortized encode over the WHOLE dictionary — the
2601 /// multi-atom, corpus-rate analogue of [`Self::amortized_encode_with_index`].
2602 ///
2603 /// `amortized_encode_with_index` routes per row, then runs the per-row
2604 /// closed-form predictor + Kantorovich certificate + cold cross-check on each
2605 /// row independently. This fast variant keeps the SAME per-row EXACT routing
2606 /// (`index.route_exact` + the fit-quality floor), but replaces the per-row
2607 /// predictor with the GEMM-batched [`Self::amortized_encode_batch_fast`]:
2608 /// it GROUPS rows by their global-argmax atom and runs one batched affine-
2609 /// predictor pass per atom-group (a routing GEMM + a predictor GEMM each),
2610 /// reproducing a traditional SAE's whole-dictionary `W·x+b` throughput. No
2611 /// per-row certificate — this is the speed mode validated as accuracy-parity
2612 /// with the certified solve (`fast_forward_is_accuracy_parity_with_certified`).
2613 ///
2614 /// Returns the per-row latent coords and a valid-mask: `false` for a row with
2615 /// an empty dictionary, a sub-threshold/NaN routing alignment, or one the batched
2616 /// predictor could not fire on (no evaluator / singular Gauss–Newton block /
2617 /// non-finite-or-zero amplitude). Each row is written exactly once (disjoint
2618 /// per-atom groups), so the result is independent of group iteration order.
2619 pub fn amortized_encode_with_index_fast<S: AtomFrameSketch + Sync>(
2620 &self,
2621 atoms: &[SaeManifoldAtom],
2622 index: &SaeCandidateIndex,
2623 sketch: &S,
2624 targets: ArrayView2<'_, f64>,
2625 amplitudes: ArrayView1<'_, f64>,
2626 latent_dim: usize,
2627 ) -> Result<(Array2<f64>, Vec<bool>), String> {
2628 let n = targets.nrows();
2629 if amplitudes.len() != n {
2630 return Err(format!(
2631 "amortized_encode_with_index_fast: amplitudes len {} != rows {n}",
2632 amplitudes.len()
2633 ));
2634 }
2635 let budget = auto_candidate_budget(atoms.len().max(1));
2636 let mut coords = Array2::<f64>::zeros((n, latent_dim));
2637 let mut valid = vec![false; n];
2638 // ── Single allocation-free pass: route each row, apply the CACHED predictor.
2639 //
2640 // Routing sublinearity (massive-K, K≈32k): the certified path uses
2641 // `route_exact`, whose universal-bound LSH certificate only fires at the
2642 // alignment ceiling (≈1.0); for any real dictionary (`alignment < 1`) it
2643 // falls back to `brute_force_best_atom` — an O(K) full scan PER ROW, making
2644 // the encode O(N·K). This SPEED path takes the LSH gather's best-aligned atom
2645 // directly (`propose` scores only ~budget candidates → O(log K)); a rare miss
2646 // is caught by the fit-quality floor + the downstream certificate/exact
2647 // fallback (the documented speed/accuracy tradeoff).
2648 //
2649 // Allocation: the predictor uses only OFFLINE-cached atlas data (per-chart
2650 // `recon_center` for routing + `amortized_jacobian`/`amortized_base` for the
2651 // `t̂ = base + (1/z)·A₁·x` mat-vec), so NO per-row or per-atom heap buffer is
2652 // allocated. This replaces the old per-atom-group GEMM sub-batch — which at
2653 // `K ≫ N` degenerated to ONE row per group, so its per-group buffers (x_sub,
2654 // recon-centers, predictors) dominated and made the "fast" path allocation-
2655 // bound. Now the only per-row allocation is inside `index.propose`.
2656 for row in 0..n {
2657 let dir = targets.row(row);
2658 let proposal = index.propose(sketch, dir, budget, true);
2659 let Some(&best_atom) = proposal.proposed.first() else {
2660 continue; // nothing gathered (empty dictionary / probe-dim mismatch)
2661 };
2662 // Fit-quality floor: the best gathered atom still fits this row poorly,
2663 // or the alignment is NaN (zero-norm row) — flag for the exact fallback.
2664 let alignment = sketch.alignment(best_atom, dir);
2665 if !alignment.is_finite() || alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT {
2666 continue;
2667 }
2668 let atom = atoms.get(best_atom).ok_or_else(|| {
2669 format!("amortized_encode_with_index_fast: proposed atom {best_atom} out of range")
2670 })?;
2671 if atom.latent_dim != latent_dim {
2672 return Err(format!(
2673 "amortized_encode_with_index_fast: atom {best_atom} latent_dim {} != declared \
2674 {latent_dim}; heterogeneous-dim dictionaries are not supported by this path",
2675 atom.latent_dim
2676 ));
2677 }
2678 let Some(atom_atlas) = self.atoms.get(best_atom) else {
2679 continue; // no atlas for this atom → predictor cannot fire (zeroed)
2680 };
2681 if amortized_predict_row(
2682 atom_atlas,
2683 dir,
2684 amplitudes[row],
2685 latent_dim,
2686 coords.row_mut(row),
2687 ) {
2688 valid[row] = true;
2689 }
2690 }
2691 Ok((coords, valid))
2692 }
2693
2694 /// LSH-routed FAST full forward over the WHOLE dictionary: encode → decode,
2695 /// the multi-atom analogue of [`Self::amortized_reconstruct_batch_fast`]. Same
2696 /// sublinear per-row routing + per-atom grouping as
2697 /// [`Self::amortized_encode_with_index_fast`], but each group is run through
2698 /// the batched reconstruct (`m(t̂) = z·Φ(t̂)·B`) so the result is the per-row
2699 /// reconstruction in the ambient space. Rows that do not route/predict decode
2700 /// to an exact zero reconstruction and are flagged `false`.
2701 pub fn amortized_reconstruct_with_index_fast<S: AtomFrameSketch + Sync>(
2702 &self,
2703 atoms: &[SaeManifoldAtom],
2704 index: &SaeCandidateIndex,
2705 sketch: &S,
2706 targets: ArrayView2<'_, f64>,
2707 amplitudes: ArrayView1<'_, f64>,
2708 ) -> Result<(Array2<f64>, Vec<bool>), String> {
2709 let n = targets.nrows();
2710 let p = targets.ncols();
2711 if amplitudes.len() != n {
2712 return Err(format!(
2713 "amortized_reconstruct_with_index_fast: amplitudes len {} != rows {n}",
2714 amplitudes.len()
2715 ));
2716 }
2717 let budget = auto_candidate_budget(atoms.len().max(1));
2718 // SUBLINEAR routing for the SPEED-mode full forward — mirror
2719 // `amortized_encode_with_index_fast`: take the LSH gather's best-aligned
2720 // atom (O(budget) candidates) instead of route_exact's O(K) full-scan
2721 // certification, keeping the whole fast encode→decode sublinear in K at
2722 // K=32k. The gather's best is the exact argmax on the vast majority of rows;
2723 // rare misses are caught by the fit-quality floor + downstream fallback.
2724 let mut groups: std::collections::HashMap<usize, Vec<usize>> =
2725 std::collections::HashMap::new();
2726 for row in 0..n {
2727 let dir = targets.row(row);
2728 let proposal = index.propose(sketch, dir, budget, true);
2729 let Some(&best_atom) = proposal.proposed.first() else {
2730 continue; // nothing gathered (empty dictionary / probe-dim mismatch)
2731 };
2732 let alignment = sketch.alignment(best_atom, dir);
2733 if !alignment.is_finite() || alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT {
2734 continue;
2735 }
2736 groups.entry(best_atom).or_default().push(row);
2737 }
2738
2739 let mut recon = Array2::<f64>::zeros((n, p));
2740 let mut valid = vec![false; n];
2741 for (atom_idx, rows_here) in groups {
2742 let atom = atoms.get(atom_idx).ok_or_else(|| {
2743 format!(
2744 "amortized_reconstruct_with_index_fast: proposed atom {atom_idx} out of range"
2745 )
2746 })?;
2747 if atom.output_dim() != p {
2748 return Err(format!(
2749 "amortized_reconstruct_with_index_fast: atom {atom_idx} output_dim {} != target \
2750 dim {p}",
2751 atom.output_dim()
2752 ));
2753 }
2754 let mut x_sub = Array2::<f64>::zeros((rows_here.len(), p));
2755 let mut amp_sub = Array1::<f64>::zeros(rows_here.len());
2756 for (i, &row) in rows_here.iter().enumerate() {
2757 x_sub.row_mut(i).assign(&targets.row(row));
2758 amp_sub[i] = amplitudes[row];
2759 }
2760 let (sub_recon, sub_valid) = self.amortized_reconstruct_batch_fast(
2761 atom,
2762 atom_idx,
2763 x_sub.view(),
2764 amp_sub.view(),
2765 )?;
2766 for (i, &row) in rows_here.iter().enumerate() {
2767 if sub_valid[i] {
2768 recon.row_mut(row).assign(&sub_recon.row(i));
2769 valid[row] = true;
2770 }
2771 }
2772 }
2773 Ok((recon, valid))
2774 }
2775}
2776
2777/// Offline `β = 1/λ_min(H_GN)` at a chart center from the Gauss-Newton block
2778/// `H_GN = J_mᵀ J_m` (residual-free). The offline `β` bounds the curvature the
2779/// online certificate sees: charts are placed where the encode lands, so the
2780/// representative residual is small and `H_GN` is the dominant, residual-free
2781/// curvature estimate. (The online per-row certificate still uses the FULL
2782/// Hessian; this is only the offline radius-sizing curvature.) Returns `None`
2783/// for a degenerate center (`λ_min ≤ 0`), which marks an uncertifiable chart.
2784pub(crate) fn center_beta(atom: &SaeManifoldAtom, center: &Array1<f64>, ridge: f64) -> Option<f64> {
2785 let evaluator = atom.basis_evaluator.as_ref()?.clone();
2786 let d = atom.latent_dim;
2787 let p = atom.output_dim();
2788 let m = atom.basis_size();
2789 let coords = center.view().to_shape((1, d)).ok()?.to_owned();
2790 let (_phi, jet) = evaluator.evaluate(coords.view()).ok()?;
2791 let decoder = &atom.decoder_coefficients;
2792 // J_m[axis] = Bᵀ (∂Φ/∂t_axis) ∈ ℝᵖ (amplitude-1; curvature scales with z²
2793 // and is absorbed conservatively by the amplitude-bounded Lipschitz term).
2794 let mut jm = Array2::<f64>::zeros((d, p));
2795 for axis in 0..d {
2796 for basis_col in 0..m {
2797 let dphi = jet[[0, basis_col, axis]];
2798 if dphi == 0.0 {
2799 continue;
2800 }
2801 for out in 0..p {
2802 jm[[axis, out]] += dphi * decoder[[basis_col, out]];
2803 }
2804 }
2805 }
2806 let mut h = Array2::<f64>::zeros((d, d));
2807 for a in 0..d {
2808 for b in 0..d {
2809 h[[a, b]] = jm.row(a).dot(&jm.row(b));
2810 }
2811 h[[a, a]] += ridge;
2812 }
2813 let (vals, _vecs) = h.eigh(Side::Lower).ok()?;
2814 let lambda_min = vals.iter().cloned().fold(f64::INFINITY, f64::min);
2815 if lambda_min.is_finite() && lambda_min > 0.0 {
2816 Some(1.0 / lambda_min)
2817 } else {
2818 None
2819 }
2820}
2821
2822/// #1154 — the amortized encoder's closed-form warm-start coordinate for one
2823/// row `x` against one chart at amplitude `z`:
2824///
2825/// ```text
2826/// t̂ = t_c + (1/z) · A₁ · (x − z · m₁(t_c)),
2827/// ```
2828///
2829/// a single `O(d·p)` mat-vec from the chart's precomputed IFT Jacobian `A₁` and
2830/// center reconstruction `m₁`. Returns `None` when the chart carries no
2831/// distilled Jacobian (singular Gauss–Newton block) or the amplitude is not
2832/// strictly positive and finite (a near-inactive atom, where the
2833/// amplitude-divided map is undefined) — in those cases the caller starts from
2834/// the chart center instead. Shared by the amortized encode (where `t̂` is the
2835/// prediction) and the exact certified encode (where `t̂` is the Newton
2836/// warm-start that then refines to stationarity, Design A).
2837pub(crate) fn amortized_warm_start(
2838 chart: &CertifiedChart,
2839 x: ArrayView1<'_, f64>,
2840 amplitude: f64,
2841) -> Option<Array1<f64>> {
2842 let a1 = chart.amortized_jacobian.as_ref()?;
2843 if !(amplitude.is_finite() && amplitude.abs() > 0.0) {
2844 return None;
2845 }
2846 let d = a1.nrows();
2847 let mut t_hat = chart.region.center.clone();
2848 for (out_idx, &m1_out) in chart.recon_center.iter().enumerate().take(a1.ncols()) {
2849 let resid = x[out_idx] - amplitude * m1_out;
2850 for axis in 0..d {
2851 t_hat[axis] += a1[[axis, out_idx]] * resid / amplitude;
2852 }
2853 }
2854 Some(t_hat)
2855}
2856
2857/// Single-row amortized predictor against ONE atom's cached atlas, writing the
2858/// encoded latent coordinate DIRECTLY into `out` (length `d`) with NO heap
2859/// allocation. Routes to the nearest certifiable chart by cached center-
2860/// reconstruction distance `‖x − m(t_c)‖²` (matching `amortized_encode_batch_fast`'s
2861/// per-chart routing), then applies that chart's precomputed affine predictor
2862/// `t̂ = base + (1/z)·A₁·x` (`base = t_c − A₁·m₁` is `chart.amortized_base`).
2863///
2864/// Returns `false` — leaving `out` at its incoming (zeroed) value — for exactly the
2865/// rows `amortized_encode_batch_fast` would flag: no certifiable chart, a nearest
2866/// chart with no distilled predictor (singular Gauss–Newton block), or an unusable
2867/// amplitude. This is the per-row core of the allocation-free massive-K fast encode.
2868pub(crate) fn amortized_predict_row(
2869 atom_atlas: &AtomEncodeAtlas,
2870 x: ArrayView1<'_, f64>,
2871 amplitude: f64,
2872 d: usize,
2873 mut out: ndarray::ArrayViewMut1<'_, f64>,
2874) -> bool {
2875 if !(amplitude.is_finite() && amplitude.abs() > 0.0) {
2876 return false;
2877 }
2878 // Nearest certifiable chart by the TRUE objective ‖x − z·recon_center‖²
2879 // (F1; `z = amplitude`, finite and non-zero by the guard above). First-wins on
2880 // ties (strict `<`) — same amplitude-scaled argmin as
2881 // `amortized_encode_batch_fast`'s route_idx.
2882 let mut best_ci: Option<usize> = None;
2883 let mut best_dist = f64::INFINITY;
2884 for (ci, chart) in atom_atlas.charts.iter().enumerate() {
2885 if chart.certified_radius <= 0.0 {
2886 continue;
2887 }
2888 let mut dist = 0.0;
2889 for (r, xv) in chart.recon_center.iter().zip(x.iter()) {
2890 let diff = amplitude * r - xv;
2891 dist += diff * diff;
2892 }
2893 if dist < best_dist {
2894 best_dist = dist;
2895 best_ci = Some(ci);
2896 }
2897 }
2898 let Some(ci) = best_ci else {
2899 return false;
2900 };
2901 let chart = &atom_atlas.charts[ci];
2902 // The nearest chart must carry a distilled predictor; otherwise flag (zeroed),
2903 // exactly as the per-chart `None` branch of `amortized_encode_batch_fast`.
2904 let (Some(a1), Some(base)) = (
2905 chart.amortized_jacobian.as_ref(),
2906 chart.amortized_base.as_ref(),
2907 ) else {
2908 return false;
2909 };
2910 let inv_z = 1.0 / amplitude;
2911 for axis in 0..d {
2912 out[axis] = base[axis] + a1.row(axis).dot(&x) * inv_z;
2913 }
2914 true
2915}
2916
2917/// The amplitude-1 distilled amortized-encoder Jacobian at a chart center
2918/// (#1026 ladder item 3). Returns `(A₁, m₁)` where `m₁ = BᵀΦ(t_c) ∈ ℝᵖ` is the
2919/// amplitude-1 center reconstruction and `A₁ = (J₁ᵀJ₁ + ridge·I)⁻¹ J₁ ∈ ℝ^{d×p}`
2920/// is the implicit-function-theorem derivative of the encode map `x ↦ t`
2921/// (Gauss–Newton block — the residual-free, dominant curvature exactly as the
2922/// offline radius-sizing `β`). With these, the online encode of a row `x` at
2923/// amplitude `z` is the closed-form affine prediction
2924/// `t = t_c + (1/z)·A₁·(x − z·m₁)` — one mat-vec, no per-row factorization.
2925/// `None` when the basis has no jet or the Gauss–Newton block is singular (no
2926/// certifiable amortization), matching `center_beta`'s gate so a chart with a
2927/// finite `β` always carries a Jacobian and vice versa.
2928pub(crate) fn center_amortized_jacobian(
2929 atom: &SaeManifoldAtom,
2930 center: &Array1<f64>,
2931 ridge: f64,
2932) -> Option<(Array2<f64>, Array1<f64>)> {
2933 let evaluator = atom.basis_evaluator.as_ref()?.clone();
2934 let d = atom.latent_dim;
2935 let p = atom.output_dim();
2936 let m = atom.basis_size();
2937 let coords = center.view().to_shape((1, d)).ok()?.to_owned();
2938 let (phi, jet) = evaluator.evaluate(coords.view()).ok()?;
2939 let decoder = &atom.decoder_coefficients;
2940 // m₁(t_c) = BᵀΦ(t_c) ∈ ℝᵖ (amplitude-1 center reconstruction).
2941 let mut recon = Array1::<f64>::zeros(p);
2942 for basis_col in 0..m {
2943 let phi_v = phi[[0, basis_col]];
2944 if phi_v == 0.0 {
2945 continue;
2946 }
2947 for out in 0..p {
2948 recon[out] += phi_v * decoder[[basis_col, out]];
2949 }
2950 }
2951 // J₁[axis] = Bᵀ (∂Φ/∂t_axis) ∈ ℝᵖ (amplitude-1; z factors out analytically).
2952 let mut jm = Array2::<f64>::zeros((d, p));
2953 for axis in 0..d {
2954 for basis_col in 0..m {
2955 let dphi = jet[[0, basis_col, axis]];
2956 if dphi == 0.0 {
2957 continue;
2958 }
2959 for out in 0..p {
2960 jm[[axis, out]] += dphi * decoder[[basis_col, out]];
2961 }
2962 }
2963 }
2964 // H_GN = J₁ J₁ᵀ + ridge·I ∈ ℝ^{d×d}.
2965 let mut h = Array2::<f64>::zeros((d, d));
2966 for a in 0..d {
2967 for b in 0..d {
2968 h[[a, b]] = jm.row(a).dot(&jm.row(b));
2969 }
2970 h[[a, a]] += ridge;
2971 }
2972 let (vals, vecs) = h.eigh(Side::Lower).ok()?;
2973 let lambda_min = vals.iter().cloned().fold(f64::INFINITY, f64::min);
2974 if !(lambda_min.is_finite() && lambda_min > 0.0) {
2975 return None;
2976 }
2977 // A₁ = H_GN⁻¹ J₁ via the eigendecomposition: H⁻¹ = Σ_i (1/λᵢ) vᵢ vᵢᵀ, so
2978 // A₁[:, out] = Σ_i (vᵢ · J₁[:, out]) / λᵢ · vᵢ. Column-by-column keeps it the
2979 // d×p Jacobian (one SPD solve reused across all p output channels).
2980 let mut a1 = Array2::<f64>::zeros((d, p));
2981 for out in 0..p {
2982 let jcol = jm.column(out);
2983 for (i, &lam) in vals.iter().enumerate() {
2984 if !(lam.is_finite() && lam > 0.0) {
2985 return None;
2986 }
2987 let vi = vecs.column(i);
2988 let coeff = vi.dot(&jcol) / lam;
2989 for row in 0..d {
2990 a1[[row, out]] += coeff * vi[row];
2991 }
2992 }
2993 }
2994 Some((a1, recon))
2995}
2996
2997/// Route a target row to the nearest chart of an atom by reconstruction
2998/// distance: the chart whose center reconstruction `m(t_c)` is closest to `x`.
2999/// Returns the chart index and the distance, or `None` when the atom has no
3000/// charts.
3001pub(crate) fn nearest_chart(
3002 atom_atlas: &AtomEncodeAtlas,
3003 x: ArrayView1<'_, f64>,
3004 amplitude: f64,
3005) -> Option<(usize, f64)> {
3006 if atom_atlas.charts.is_empty() {
3007 return None;
3008 }
3009 let mut best: Option<(usize, f64)> = None;
3010 for (idx, chart) in atom_atlas.charts.iter().enumerate() {
3011 if chart.certified_radius <= 0.0 {
3012 continue;
3013 }
3014 // F1: score the TRUE encode objective `‖x − z·m(t_c)‖²` at the row's
3015 // amplitude `z`, NOT the amplitude-1 `‖x − m(t_c)‖²`. The distilled
3016 // `chart.recon_center` is the amplitude-1 reconstruction `m₁(t_c)`; the
3017 // reconstruction actually compared against `x` is `z·m₁(t_c)`. Routing on
3018 // the amplitude-1 centers picks the wrong chart whenever `z ≠ 1` (e.g. a
3019 // small `z` makes a large-norm center reconstruct near a small `x`).
3020 // Distance accumulated in place, no temporary array.
3021 let mut dist = 0.0;
3022 for (r, xv) in chart.recon_center.iter().zip(x.iter()) {
3023 let diff = amplitude * r - xv;
3024 dist += diff * diff;
3025 }
3026 if best.map(|(_, b)| dist < b).unwrap_or(true) {
3027 best = Some((idx, dist));
3028 }
3029 }
3030 best
3031}
3032
3033/// The `k` charts whose CENTER reconstruction `m(t_c)` is nearest to `x` in
3034/// ambient ‖·‖², returned as chart indices sorted by increasing distance (ties
3035/// broken by chart index — deterministic). Only certifiable charts
3036/// (`certified_radius > 0`) are considered, exactly like [`nearest_chart`], whose
3037/// single result is `nearest_charts_topk(.., 1)[0]`. Used by the certified encode
3038/// to refine the global basin on self-approaching atoms (see
3039/// [`CERTIFIED_ROUTING_TOPK`]).
3040pub(crate) fn nearest_charts_topk(
3041 atom_atlas: &AtomEncodeAtlas,
3042 x: ArrayView1<'_, f64>,
3043 amplitude: f64,
3044 k: usize,
3045) -> Vec<usize> {
3046 if atom_atlas.charts.is_empty() || k == 0 {
3047 return Vec::new();
3048 }
3049 let mut scored: Vec<(usize, f64)> = Vec::new();
3050 for (idx, chart) in atom_atlas.charts.iter().enumerate() {
3051 if chart.certified_radius <= 0.0 {
3052 continue;
3053 }
3054 // `m₁(t_c) = BᵀΦ(t_c)` is an OFFLINE per-chart constant already distilled
3055 // into `chart.recon_center` at build time (bit-for-bit the same φ·decoder
3056 // accumulation this used to recompute). Reuse it instead of re-evaluating
3057 // the basis at a fixed center for every row — that re-eval was the encode's
3058 // dominant per-row cost (charts × rows basis evals, each allocating the φ/jet
3059 // arrays). F1: score the amplitude-scaled `‖x − z·m₁(t_c)‖²` (the true
3060 // objective), not the amplitude-1 `‖m₁(t_c) − x‖²`. Computed in place.
3061 let mut dist = 0.0;
3062 for (r, xv) in chart.recon_center.iter().zip(x.iter()) {
3063 let diff = amplitude * r - xv;
3064 dist += diff * diff;
3065 }
3066 scored.push((idx, dist));
3067 }
3068 // Sort by distance, then chart index for a deterministic, first-wins order
3069 // consistent with `nearest_chart`'s strict-`<` tie rule.
3070 scored.sort_by(|a, b| {
3071 a.1.partial_cmp(&b.1)
3072 .unwrap_or(std::cmp::Ordering::Equal)
3073 .then(a.0.cmp(&b.0))
3074 });
3075 scored.into_iter().take(k).map(|(idx, _)| idx).collect()
3076}
3077
3078/// Reconstruction error `‖x − z·m(t)‖` of an encoded coordinate `t` — the
3079/// criterion the certified encode minimizes over its candidate charts to pick the
3080/// GLOBAL basin. `m(t) = Bᵀ Φ(t)` is the amplitude-1 reconstruction; `z` is the
3081/// amplitude. A non-finite reconstruction returns `+∞` so it never wins.
3082pub(crate) fn encode_reconstruction_error(
3083 atom: &SaeManifoldAtom,
3084 evaluator: &dyn SaeBasisEvaluator,
3085 coord: ArrayView1<'_, f64>,
3086 x: ArrayView1<'_, f64>,
3087 amplitude: f64,
3088) -> f64 {
3089 let d = atom.latent_dim;
3090 let p = atom.output_dim();
3091 let m = atom.basis_size();
3092 let coords = match coord.to_shape((1, d)) {
3093 Ok(c) => c.to_owned(),
3094 Err(_) => return f64::INFINITY,
3095 };
3096 let Ok((phi, _jet)) = evaluator.evaluate(coords.view()) else {
3097 return f64::INFINITY;
3098 };
3099 let mut err2 = 0.0;
3100 for out in 0..p {
3101 let mut recon = 0.0;
3102 for basis_col in 0..m {
3103 recon += phi[[0, basis_col]] * atom.decoder_coefficients[[basis_col, out]];
3104 }
3105 let r = x[out] - amplitude * recon;
3106 err2 += r * r;
3107 }
3108 if err2.is_finite() {
3109 err2.sqrt()
3110 } else {
3111 f64::INFINITY
3112 }
3113}
3114
/// Maximum number of chart centers laid down per atom (the SHAPE_BAND grid
/// point cap; mirrors `SHAPE_BAND_MAX_POINTS` in the atom band machinery).
///
/// Within this file the budget bounds the per-axis reduction performed by
/// [`capped_per_axis`] and the 2-D cylinder grid builder, and its integer
/// square root caps the per-axis count of [`sphere_latlon_grid`].
pub(crate) const SHAPE_BAND_MAX_POINTS: usize = 512;
3118
3119/// Lay down chart centers on an atom's coordinate grid (the SHAPE_BAND grid
3120/// idiom): a regular grid spanning the compact latent domain for periodic /
3121/// sphere / torus atoms, and a strided cover of the latent axes for unbounded
3122/// (Duchon / Euclidean) atoms.
3123///
3124/// Periodic / torus latents are fractions of one period, so the per-axis grid
3125/// spans `[0, 1)`; the sphere chart spans `lat ∈ [−π/2, π/2]`, `lon ∈ [−π, π)`.
3126/// These conventions match the basis evaluators (the fraction-of-period circle
3127/// harmonic and the lat/lon sphere chart).
3128/// Squared coordinate distance between two latent points under the atom's chart
3129/// geometry: per-axis WRAPPED distance `min(|a−b|, period−|a−b|)` on periodic
3130/// (circle) axes — period 1 to match `chart_center_grid`'s `[0,1)` torus tiling
3131/// — and plain difference on line axes. Used to place + size data-driven charts.
3132pub(crate) fn coord_dist_sq(
3133 atom: &SaeManifoldAtom,
3134 a: ArrayView1<'_, f64>,
3135 b: ArrayView1<'_, f64>,
3136) -> f64 {
3137 // Per-axis period comes from the ONE canonical source `latent_axis_period` —
3138 // the SAME convention the certified-encode `in_chart` guard uses (via
3139 // `latent_coordinate_distance`): period-1 fraction axes for periodic/torus and
3140 // the cylinder angle, period-2π on the sphere LONGITUDE (axis 1), and
3141 // NON-periodic otherwise — including the sphere LATITUDE (axis 0), which ranges
3142 // over [−π/2, π/2] and must NOT wrap. A former local `periodic_axis` closure
3143 // wrapped BOTH sphere axes at period 1, so two genuinely-far radian longitudes
3144 // (e.g. 0 and 3 rad) collapsed to distance 0 — corrupting the farthest-point
3145 // placement AND the nearest-neighbour radii in `data_driven_chart_centers` for
3146 // sphere atoms, and disagreeing with the metric the encode's `in_chart`
3147 // soundness guard enforces. Delegating keeps the two metrics identical.
3148 let mut acc = 0.0;
3149 for axis in 0..a.len().min(b.len()) {
3150 let mut d = (a[axis] - b[axis]).abs();
3151 if let Some(period) = latent_axis_period(atom, axis) {
3152 let wrapped = d.rem_euclid(period);
3153 d = wrapped.min(period - wrapped);
3154 }
3155 acc += d * d;
3156 }
3157 acc
3158}
3159
3160/// Greedy farthest-point sampling of up to `max_charts` chart centers from the
3161/// atom's latent `coords` (n × d), with each center's nominal radius set to half
3162/// the distance to its nearest neighbor center (floored, so a singleton/coincident
3163/// cluster still gets a usable ball). Deterministic: seeds from row 0, then
3164/// repeatedly adds the coord maximally far (under [`coord_dist_sq`]) from the
3165/// chosen set — coverage-maximizing and reproducible run-to-run.
3166pub(crate) fn data_driven_chart_centers(
3167 atom: &SaeManifoldAtom,
3168 coords: ArrayView2<'_, f64>,
3169 max_charts: usize,
3170) -> Result<(Array2<f64>, Vec<f64>), String> {
3171 let n = coords.nrows();
3172 let d = coords.ncols();
3173 if d != atom.latent_dim {
3174 return Err(format!(
3175 "data_driven_chart_centers: coords have {d} cols but atom latent_dim is {}",
3176 atom.latent_dim
3177 ));
3178 }
3179 if n == 0 {
3180 return Ok((Array2::<f64>::zeros((0, d)), Vec::new()));
3181 }
3182 let k = max_charts.min(n);
3183 // Farthest-point sampling: maintain each row's distance to the nearest chosen
3184 // center, add the row with the maximum such distance each step.
3185 let mut chosen: Vec<usize> = Vec::with_capacity(k);
3186 chosen.push(0);
3187 let mut nearest_sq: Vec<f64> = (0..n)
3188 .map(|r| coord_dist_sq(atom, coords.row(r), coords.row(0)))
3189 .collect();
3190 while chosen.len() < k {
3191 // Pick the row farthest from the current center set (first-wins tie).
3192 let mut best = 0usize;
3193 let mut best_d = -1.0;
3194 for r in 0..n {
3195 if nearest_sq[r] > best_d {
3196 best_d = nearest_sq[r];
3197 best = r;
3198 }
3199 }
3200 if best_d <= 0.0 {
3201 break; // all remaining rows coincide with a chosen center.
3202 }
3203 chosen.push(best);
3204 for r in 0..n {
3205 let dr = coord_dist_sq(atom, coords.row(r), coords.row(best));
3206 if dr < nearest_sq[r] {
3207 nearest_sq[r] = dr;
3208 }
3209 }
3210 }
3211 let m = chosen.len();
3212 let mut centers = Array2::<f64>::zeros((m, d));
3213 for (i, &row) in chosen.iter().enumerate() {
3214 centers.row_mut(i).assign(&coords.row(row));
3215 }
3216 // Per-center radius = half the nearest-OTHER-center distance, floored so a
3217 // coincident pair still yields a positive ball, capped at 0.5 (the largest
3218 // meaningful half-period on a unit circle).
3219 let mut radii = vec![0.0_f64; m];
3220 for i in 0..m {
3221 let mut nn = f64::INFINITY;
3222 for j in 0..m {
3223 if i == j {
3224 continue;
3225 }
3226 let dsq = coord_dist_sq(atom, centers.row(i), centers.row(j));
3227 if dsq < nn {
3228 nn = dsq;
3229 }
3230 }
3231 let r = if nn.is_finite() { 0.5 * nn.sqrt() } else { 0.5 };
3232 radii[i] = r.max(1.0e-3).min(0.5);
3233 }
3234 Ok((centers, radii))
3235}
3236
3237pub(crate) fn chart_center_grid(atom: &SaeManifoldAtom, resolution: usize) -> Array2<f64> {
3238 use crate::manifold::SaeAtomBasisKind::*;
3239 let d = atom.latent_dim;
3240 match &atom.basis_kind {
3241 Periodic | Torus => regular_product_grid(d, resolution, 0.0, 1.0, false),
3242 // Cylinder `S¹ × ℝ`: axis 0 is the periodic circle `[0, 1)` (no
3243 // endpoint, like the harmonic axes); axis 1 is the unbounded line,
3244 // covered by a strided unit box `[-0.5, 0.5]` about the origin (like the
3245 // Euclidean patch). The certified radius refines each chart; out-of-cover
3246 // line starts route to the exact fallback honestly.
3247 Cylinder if d == 2 => cylinder_chart_center_grid(resolution),
3248 Cylinder => regular_product_grid(d, resolution, -0.5, 0.5, true),
3249 Sphere if d == 2 => sphere_latlon_grid(resolution),
3250 Linear | Sphere | Duchon | EuclideanPatch | Poincare | Precomputed(_) => {
3251 // Unbounded / non-compact latents: a strided cover of a unit box
3252 // about the origin per axis. The certified radius refines each chart;
3253 // out-of-cover starts route to the exact fallback honestly.
3254 regular_product_grid(d, resolution, -0.5, 0.5, true)
3255 }
3256 }
3257}
3258
3259/// A regular `resolution`-per-axis product grid over `[lo, hi]^d`, capped at
3260/// [`SHAPE_BAND_MAX_POINTS`] total points (the per-axis resolution is reduced
3261/// until the product fits). When `include_endpoint` the last grid point sits at
3262/// `hi`; otherwise the axis is treated as periodic and stops one step short.
3263/// Per-axis resolution actually used by [`regular_product_grid`] after the
3264/// [`SHAPE_BAND_MAX_POINTS`] product cap. Chart radii must be derived from THIS
3265/// (not the raw `resolution`), otherwise for `resolution^d > SHAPE_BAND_MAX_POINTS`
3266/// the grid spacing is coarser than the radius and the charts leave gaps.
3267pub(crate) fn capped_per_axis(d: usize, resolution: usize) -> usize {
3268 let mut per_axis = resolution.max(2);
3269 while per_axis.saturating_pow(d as u32) > SHAPE_BAND_MAX_POINTS && per_axis > 2 {
3270 per_axis -= 1;
3271 }
3272 per_axis
3273}
3274
3275pub(crate) fn regular_product_grid(
3276 d: usize,
3277 resolution: usize,
3278 lo: f64,
3279 hi: f64,
3280 include_endpoint: bool,
3281) -> Array2<f64> {
3282 if d == 0 {
3283 return Array2::<f64>::zeros((1, 0));
3284 }
3285 let per_axis = capped_per_axis(d, resolution);
3286 let total = per_axis.saturating_pow(d as u32).max(1);
3287 let denom = if include_endpoint {
3288 (per_axis.max(2) - 1) as f64
3289 } else {
3290 per_axis as f64
3291 };
3292 let mut grid = Array2::<f64>::zeros((total, d));
3293 let mut idx = vec![0usize; d];
3294 for flat in 0..total {
3295 for axis in 0..d {
3296 let frac = idx[axis] as f64 / denom;
3297 grid[[flat, axis]] = lo + (hi - lo) * frac;
3298 }
3299 for axis in (0..d).rev() {
3300 idx[axis] += 1;
3301 if idx[axis] < per_axis {
3302 break;
3303 }
3304 idx[axis] = 0;
3305 }
3306 }
3307 grid
3308}
3309
3310/// Lat/lon sphere chart grid: `lat ∈ [−π/2, π/2]`, `lon ∈ [−π, π)`, matching
3311/// the [`crate::manifold::SphereChartEvaluator`] convention.
3312pub(crate) fn sphere_latlon_grid(resolution: usize) -> Array2<f64> {
3313 use std::f64::consts::PI;
3314 // Per-axis cap derived from the shared point budget rather than a hardcoded
3315 // literal (#2071): the largest r with r² ≤ SHAPE_BAND_MAX_POINTS. Equals 22
3316 // at the current budget (22²=484≤512), but now tracks the budget if it
3317 // changes instead of silently desyncing from the sibling grid arms.
3318 let r_cap = SHAPE_BAND_MAX_POINTS.isqrt();
3319 let r = resolution.max(2).min(r_cap);
3320 let mut grid = Array2::<f64>::zeros((r * r, 2));
3321 for i in 0..r {
3322 let lat = -PI / 2.0 + PI * (i as f64 + 0.5) / r as f64;
3323 for j in 0..r {
3324 let lon = -PI + 2.0 * PI * (j as f64) / r as f64;
3325 grid[[i * r + j, 0]] = lat;
3326 grid[[i * r + j, 1]] = lon;
3327 }
3328 }
3329 grid
3330}
3331
3332/// Cylinder `S¹ × ℝ` chart-center grid: axis 0 sweeps the periodic circle over
3333/// one period `[0, 1)` (no endpoint, matching the harmonic axis), axis 1 strides
3334/// a unit box `[−0.5, 0.5]` about the origin on the unbounded line (with
3335/// endpoint). Capped at [`SHAPE_BAND_MAX_POINTS`] total centers.
3336pub(crate) fn cylinder_chart_center_grid(resolution: usize) -> Array2<f64> {
3337 let mut per_axis = resolution.max(2);
3338 while per_axis * per_axis > SHAPE_BAND_MAX_POINTS && per_axis > 2 {
3339 per_axis -= 1;
3340 }
3341 let total = per_axis * per_axis;
3342 let line_denom = (per_axis.max(2) - 1) as f64;
3343 let mut grid = Array2::<f64>::zeros((total, 2));
3344 for i in 0..per_axis {
3345 // Periodic axis 0: stop one step short of the period.
3346 let circle = i as f64 / per_axis as f64;
3347 for j in 0..per_axis {
3348 // Line axis 1: include the endpoint of the unit box.
3349 let line = -0.5 + (j as f64) / line_denom;
3350 grid[[i * per_axis + j, 0]] = circle;
3351 grid[[i * per_axis + j, 1]] = line;
3352 }
3353 }
3354 grid
3355}
3356
3357/// Nominal in-chart radius: half the inter-center grid spacing, so charts tile
3358/// the domain. For compact latents this is the grid step; for unbounded latents
3359/// a unit default that the certified radius refines.
3360pub(crate) fn chart_nominal_radius(atom: &SaeManifoldAtom, resolution: usize) -> f64 {
3361 use crate::manifold::SaeAtomBasisKind::*;
3362 match &atom.basis_kind {
3363 Periodic | Torus => 0.5 / (capped_per_axis(atom.latent_dim, resolution) as f64),
3364 // Must use the SAME capped per-axis count `min(resolution, 22)` that
3365 // `sphere_latlon_grid` lays the centers on: the coarsest tiling step is
3366 // the longitude half-spacing `π/r` with `r = min(resolution, 22)`. Deriving
3367 // the radius from the RAW resolution makes it smaller than the grid spacing
3368 // once `resolution > 22`, leaving gaps between charts so rows in the gaps
3369 // spuriously route to the exact fallback (the exact hazard `capped_per_axis`
3370 // documents for the regular grid).
3371 Sphere => std::f64::consts::PI / (resolution.max(2).min(22) as f64),
3372 // Cylinder charts tile two heterogeneous axes (a `[0,1)` periodic step
3373 // and a unit-box line step); the chart radius is a single scalar, so we
3374 // take the tighter (periodic) step `0.5/res` to keep every chart valid
3375 // on both axes. The certified Kantorovich radius refines it per chart.
3376 Cylinder => 0.5 / (capped_per_axis(atom.latent_dim, resolution) as f64),
3377 Linear | Duchon | EuclideanPatch | Poincare | Precomputed(_) => {
3378 1.0 / (resolution.max(2) as f64)
3379 }
3380 }
3381}
3382
3383/// Build the [`ChartRegion`] for a center, attaching the radial r_min / r_max
3384/// bracket for Duchon atoms (the chart's distance range to the kernel centers).
3385pub(crate) fn chart_region(
3386 atom: &SaeManifoldAtom,
3387 center: Array1<f64>,
3388 radius: f64,
3389) -> ChartRegion {
3390 use crate::manifold::SaeAtomBasisKind::*;
3391 let region = ChartRegion::new(center.clone(), radius);
3392 match &atom.basis_kind {
3393 Duchon => {
3394 // r ranges over [‖t_c‖ − radius, ‖t_c‖ + radius] about the single
3395 // origin-anchored center used by the conservative radial bound.
3396 //
3397 // The lower bound must be `max(0, center_norm − radius)` — NOT floored
3398 // at `radius`. When the chart contains the kernel center
3399 // (`center_norm < radius`, true r_min = 0), flooring at `radius`
3400 // would give a finite, NON-CONSERVATIVE `r_min`, causing the
3401 // hessian_sup / third_sup formulas (which divide by r_min) to
3402 // underestimate the Lipschitz constant and potentially grant a false
3403 // Kantorovich certificate. Flooring at `f64::MIN_POSITIVE` instead
3404 // correctly drives the formulas toward ∞, producing a very large L
3405 // that will NEVER certify (rows route to the exact multi-start
3406 // fallback) — conservative and sound.
3407 let center_norm = center.dot(¢er).sqrt();
3408 let r_min = (center_norm - radius).max(f64::MIN_POSITIVE);
3409 let r_max = center_norm + radius;
3410 region.with_radial_bounds(r_min, r_max)
3411 }
3412 // Cylinder has no radial kernel block (it is a harmonic × polynomial
3413 // tensor, not a Duchon radial basis), so it needs no radial r_min/r_max.
3414 Periodic | Sphere | Torus | Cylinder | Linear | EuclideanPatch | Poincare
3415 | Precomputed(_) => region,
3416 }
3417}
3418
3419#[cfg(test)]
3420mod encode_fix_tests {
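//! Unit tests for the certified-encode fixes. The first two groups cover the
//! `certify_with_basin_warmup` fixes:
//! FIX #3 — wrap-aware chart containment (`in_chart` now measures latent
//! distance on the atom's periodic geometry via
//! [`latent_coordinate_distance`]).
//! FIX #4 — multiplicative sufficient-decrease bound on the warm-up loop
//! ([`warmup_progress_sufficient`]).
//! The remaining groups cover the radian sphere metric of [`coord_dist_sq`],
//! the capped sphere [`chart_nominal_radius`], and the F1 (amplitude-aware
//! routing), F2 (un-ridged certificate Hessian), F3 (Duchon refusal) and F6
//! (certified warm-up step) regressions.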
3427 use super::*;
3428 use crate::manifold::SaeAtomBasisKind;
3429 use ndarray::{Array1, Array2, Array3};
3430
3431 /// Minimal atom carrying only a `basis_kind`; the wrap-aware metric reads no
3432 /// other field, so a tiny well-formed basis/decoder/penalty suffices.
3433 fn tiny_atom(kind: SaeAtomBasisKind, latent_dim: usize) -> SaeManifoldAtom {
3434 let m = 2usize;
3435 let phi = Array2::<f64>::eye(m);
3436 let jet = Array3::<f64>::zeros((m, m, latent_dim));
3437 let dec = Array2::<f64>::from_elem((m, 1), 0.5);
3438 let smooth = Array2::<f64>::eye(m);
3439 SaeManifoldAtom::new("tiny", kind, latent_dim, phi, jet, dec, smooth)
3440 .expect("tiny atom builds")
3441 }
3442
/// FIX #3: two points straddling the wrap seam of a periodic axis (0.99 and
/// 0.01) are 0.02 apart on the circle, so the first lies inside a 0.05 ball
/// about the second — although their raw difference is 0.98. The predicate
/// checked is the one `in_chart` evaluates:
/// `latent_coordinate_distance(atom, t, center) <= radius`.
#[test]
fn wrap_aware_containment_accepts_seam_point() {
    let radius = 0.05;
    let center = Array1::from_vec(vec![0.01f64]);
    let t = Array1::from_vec(vec![0.99f64]);
    let atom = tiny_atom(SaeAtomBasisKind::Periodic, 1);

    // Precondition: the raw (non-wrapped) separation is far outside the ball,
    // which is why the old raw-Euclidean `in_chart` rejected this iterate.
    let raw = (t[0] - center[0]).abs();
    assert!(
        raw > radius,
        "precondition: raw Euclidean distance {raw} must exceed the chart radius {radius}"
    );

    // With period 1.0 the circle distance is 0.02, inside the ball.
    let d = latent_coordinate_distance(&atom, t.view(), center.view());
    let seam_error = (d - 0.02).abs();
    assert!(
        seam_error < 1e-12,
        "wrap-aware distance across the seam must be 0.02, got {d}"
    );
    assert!(
        d <= radius,
        "wrap-aware seam point must now be in-chart (d={d} <= r={radius})"
    );
}
3474
/// Soundness guard for FIX #3: a NON-periodic axis must not wrap. The same
/// pair of coordinates keeps its full 0.98 separation on a flat patch and so
/// stays OUT of a 0.05 ball — preserving the never-issue-a-false-certificate
/// invariant for flat-patch families.
#[test]
fn wrap_aware_containment_does_not_wrap_flat_axis() {
    let center = Array1::from_vec(vec![0.01f64]);
    let t = Array1::from_vec(vec![0.99f64]);
    let atom = tiny_atom(SaeAtomBasisKind::EuclideanPatch, 1);

    let d = latent_coordinate_distance(&atom, t.view(), center.view());
    let flat_error = (d - 0.98).abs();
    assert!(
        flat_error < 1e-12,
        "a flat (non-periodic) axis must keep the full 0.98 distance, got {d}"
    );
    assert!(
        d > 0.05,
        "flat-axis point correctly stays out of a 0.05 ball"
    );
}
3494
/// FIX #4: a genuinely converging (Kantorovich-quadratic) `h`-sequence is left
/// alone — every step clears the sufficient-decrease bar, so no converging row
/// is regressed to the fallback.
#[test]
fn converging_h_sequence_is_never_flagged() {
    // Quadratic Newton contraction h_{k+1} = h_k^2 from an uncertified start,
    // followed for as long as the current value is still uncertified.
    let uncertified = std::iter::successors(Some(0.9f64), |prev| Some(prev * prev))
        .take_while(|value| *value > KANTOROVICH_THRESHOLD);
    for h in uncertified {
        let h_next = h * h;
        assert!(
            warmup_progress_sufficient(h_next, h),
            "converging step {h} -> {h_next} must be accepted (no regression)"
        );
    }
}
3511
/// FIX #4: a monotone `h`-sequence decreasing toward a limit ABOVE ½ (the
/// pathological plateau the old strict-decrease rule spun on until the
/// increments fell below one ulp) must stop after a small, BOUNDED number of
/// `warmup_progress_sufficient` steps — flag-to-fallback rather than loop.
#[test]
fn plateau_h_sequence_terminates_bounded() {
    let start = 2.0f64; // uncertified start
    let limit = 0.65f64; // plateau limit strictly above ½: never certifies
    let ratio = 0.8f64; // geometric approach; ratio -> 1 as h -> limit

    let mut h = start;
    let (mut accepted, mut iters) = (0usize, 0usize);
    loop {
        iters += 1;
        // Hard ceiling: proves boundedness. The OLD strict-decrease rule would
        // run ~1e15 steps here (h strictly decreases every step toward 0.65 and
        // only stalls below one ulp), so tripping this assert would signal a
        // regression to the unbounded behavior.
        assert!(
            iters < 10_000,
            "warm-up must terminate in a bounded number of steps, not loop"
        );
        let h_next = limit + (h - limit) * ratio; // monotone decreasing > limit
        if warmup_progress_sufficient(h_next, h) {
            accepted += 1;
            h = h_next;
        } else {
            break; // flag to the exact fallback
        }
    }

    // Real progress was made before flagging (a warm-up, not an instant bail)…
    assert!(
        accepted >= 1,
        "expected the warm-up to accept at least one contracting step first"
    );
    // …and the flag came while still strictly above the plateau limit — exactly
    // where the OLD rule would have kept spinning (h is still shrinking).
    assert!(
        h > limit && h < start,
        "flagged mid-descent (limit < h={h} < start={start}); old rule would not stop here"
    );
    // The termination bound is tiny, not astronomical.
    assert!(iters < 200, "plateau flagged in {iters} steps (bounded)");
}
3555
/// FIX #4: a non-finite `h` on either side of the step (indefinite / blown-up
/// Hessian) counts as insufficient progress — the row flags instead of being
/// accepted.
#[test]
fn non_finite_h_flags() {
    let cases = [
        (f64::NAN, 0.9),
        (f64::INFINITY, 0.9),
        (0.5, f64::NAN),
    ];
    for (h_new, h_old) in cases {
        assert!(!warmup_progress_sufficient(h_new, h_old));
    }
}
3564
/// BUG (sphere data-driven placement): `coord_dist_sq` must measure sphere
/// coordinates in RADIANS via the canonical `latent_axis_period` — longitude
/// (axis 1) wraps at 2π, latitude (axis 0) does NOT wrap — rather than at unit
/// period on both axes. The old period-1 wrap collapsed genuinely-far
/// longitudes to distance 0, corrupting `data_driven_chart_centers` for sphere
/// atoms.
#[test]
fn coord_dist_sq_sphere_uses_radian_metric_not_unit_period() {
    let atom = tiny_atom(SaeAtomBasisKind::Sphere, 2);
    let origin = Array1::from_vec(vec![0.0f64, 0.0]);

    // Longitude differs by 3.0 rad: circle distance min(3, 2π−3)=3
    // (2π−3 ≈ 3.283) ⇒ dist² = 9. The old period-1 wrap read 3 mod 1 = 0.
    let lon_shifted = Array1::from_vec(vec![0.0f64, 3.0]);
    let dsq = coord_dist_sq(&atom, origin.view(), lon_shifted.view());
    assert!(
        (dsq - 9.0).abs() < 1e-9,
        "sphere longitude dist² must be 3²=9 (2π radian period), got {dsq}"
    );

    // Latitude (axis 0) must NOT wrap: 1.2 rad apart ⇒ dist² = 1.44.
    let lat_shifted = Array1::from_vec(vec![1.2f64, 0.0]);
    let dsq_lat = coord_dist_sq(&atom, origin.view(), lat_shifted.view());
    assert!(
        (dsq_lat - 1.44).abs() < 1e-9,
        "sphere latitude must not wrap; 1.2²=1.44, got {dsq_lat}"
    );

    // Must agree with the certified-encode metric (single source of truth).
    let dc = latent_coordinate_distance(&atom, origin.view(), lon_shifted.view());
    assert!(
        (dc * dc - dsq).abs() < 1e-9,
        "coord_dist_sq must equal latent_coordinate_distance² ({} vs {dsq})",
        dc * dc
    );
}
3598
/// No-regression: periodic / torus axes still wrap at unit period.
#[test]
fn coord_dist_sq_torus_still_wraps_unit_period() {
    let atom = tiny_atom(SaeAtomBasisKind::Torus, 1);
    let low = Array1::from_vec(vec![0.02f64]);
    let high = Array1::from_vec(vec![0.98f64]);
    // Wrapped circle distance min(0.96, 0.04) = 0.04 ⇒ dist² = 0.0016.
    let dsq = coord_dist_sq(&atom, low.view(), high.view());
    let wrap_error = (dsq - 0.0016).abs();
    assert!(wrap_error < 1e-12, "torus unit-period wrap; got {dsq}");
}
3612
/// BUG (sphere chart tiling): `chart_nominal_radius` for the sphere must use
/// the SAME capped per-axis count as `sphere_latlon_grid`, so the radius covers
/// the (capped) longitude half-spacing and the charts tile without gaps once the
/// requested resolution exceeds the cap (the raw `π/resolution` would leave
/// gaps).
///
/// The expected cap is derived from the shared budget the grid itself uses
/// (`⌊√SHAPE_BAND_MAX_POINTS⌋`, 22 at the current budget of 512) instead of a
/// literal, so the test keeps checking the grid/radius agreement if the budget
/// is ever retuned.
#[test]
fn chart_nominal_radius_sphere_covers_capped_grid_spacing() {
    let atom = tiny_atom(SaeAtomBasisKind::Sphere, 2);
    let resolution = 40usize;
    // Same clamp `sphere_latlon_grid` applies to its per-axis count.
    let cap = resolution.max(2).min(SHAPE_BAND_MAX_POINTS.isqrt());
    let lon_half_spacing = std::f64::consts::PI / cap as f64;
    let r = chart_nominal_radius(&atom, resolution);
    assert!(
        r >= lon_half_spacing - 1e-12,
        "sphere radius {r} must cover the capped lon half-spacing {lon_half_spacing} \
         (no gaps for resolution>{cap}); raw π/{resolution}={} would gap",
        std::f64::consts::PI / resolution as f64
    );
}
3629
3630 // ---------------------------------------------------------------------
3631 // F1: amplitude-aware chart routing.
3632 // ---------------------------------------------------------------------
3633
3634 /// A single certifiable chart with the given amplitude-1 center reconstruction.
3635 fn chart_with_recon(recon: Vec<f64>) -> CertifiedChart {
3636 CertifiedChart {
3637 region: ChartRegion::new(Array1::zeros(1), 1.0),
3638 lipschitz: 1.0,
3639 beta_center: 1.0,
3640 certified_radius: 1.0,
3641 amortized_jacobian: None,
3642 recon_center: Array1::from(recon),
3643 amortized_base: None,
3644 }
3645 }
3646
3647 fn atlas_two_charts(m1: f64, m2: f64) -> AtomEncodeAtlas {
3648 AtomEncodeAtlas {
3649 atom_index: 0,
3650 latent_dim: 1,
3651 decoder_norm_sum: 1.0,
3652 charts: vec![chart_with_recon(vec![m1]), chart_with_recon(vec![m2])],
3653 }
3654 }
3655
/// F1 counterexample (module review): chart centers reconstruct (amplitude 1)
/// to `m₁ = 1` and `m₂ = 10`. A row `x = 1` at amplitude `z = 0.1`
/// reconstructs as `z·m`, so chart 2 is EXACT (`0.1·10 = 1 = x`) and chart 1 is
/// wrong (`0.1·1 = 0.1`, error `0.9`). Amplitude-blind routing on the
/// amplitude-1 centers picks chart 1 (`|1−1| = 0 < |10−1| = 9`); the
/// amplitude-aware fix picks chart 2.
#[test]
fn f1_routing_scores_amplitude_scaled_reconstruction() {
    let atlas = atlas_two_charts(1.0, 10.0);
    let x = Array1::from(vec![1.0]);
    // Index of the single nearest chart at amplitude `z`.
    let route = |z: f64| nearest_chart(&atlas, x.view(), z).expect("routes").0;

    assert_eq!(
        route(0.1),
        1,
        "z=0.1: must route to the m=10 chart (z·m=1 exact), not m=1"
    );

    let ranked = nearest_charts_topk(&atlas, x.view(), 0.1, 2);
    assert_eq!(ranked[0], 1, "z=0.1: nearest chart is the m=10 chart");

    // Negative amplitude: z=-0.1 makes z·m₂ = -1 (error 2) and z·m₁ = -0.1
    // (error 1.1), so chart 1 is now nearer — the sign is respected.
    assert_eq!(
        route(-0.1),
        0,
        "z=-0.1: sign-aware routing prefers the m=1 chart"
    );

    // No-regression at z=1: recovers the amplitude-1 argmin (m=1 chart exact).
    assert_eq!(
        route(1.0),
        0,
        "z=1 recovers the amplitude-1 nearest (m=1) chart"
    );
    let top_one = nearest_charts_topk(&atlas, x.view(), 1.0, 1);
    assert_eq!(top_one[0], 0);
}
3686
3687 // ---------------------------------------------------------------------
3688 // F2: the certificate uses the TRUE (un-ridged) Hessian.
3689 // ---------------------------------------------------------------------
3690
3691 /// A basis with constant `Φ`, zero Jacobian, and zero second jet: the
3692 /// reconstruction is locally CONSTANT, so the true encode Hessian is `0`
3693 /// (singular) and the coordinate is non-unique.
3694 #[derive(Debug)]
3695 struct ConstantPhi {
3696 m: usize,
3697 d: usize,
3698 }
3699 impl SaeBasisEvaluator for ConstantPhi {
3700 fn evaluate(
3701 &self,
3702 coords: ndarray::ArrayView2<'_, f64>,
3703 ) -> Result<(Array2<f64>, Array3<f64>), String> {
3704 let n = coords.nrows();
3705 Ok((
3706 Array2::ones((n, self.m)),
3707 Array3::zeros((n, self.m, self.d)),
3708 ))
3709 }
3710 fn second_jet_dyn(
3711 &self,
3712 coords: ndarray::ArrayView2<'_, f64>,
3713 ) -> Option<Result<ndarray::Array4<f64>, String>> {
3714 let n = coords.nrows();
3715 Some(Ok(ndarray::Array4::zeros((n, self.m, self.d, self.d))))
3716 }
3717 fn third_jet_dyn(
3718 &self,
3719 coords: ndarray::ArrayView2<'_, f64>,
3720 ) -> Option<Result<ndarray::Array5<f64>, String>> {
3721 if coords.ncols() != self.d {
3722 return Some(Err(format!(
3723 "ConstantPhi::third_jet_dyn: expected d = {}, got {} coords",
3724 self.d,
3725 coords.ncols()
3726 )));
3727 }
3728 None
3729 }
3730 }
3731
/// F2: a locally-constant reconstruction has a SINGULAR true Hessian (`H = 0`),
/// so the point is NOT a genuine isolated minimum and must NOT be certified.
/// The old code ridged the certified Hessian (`H → ridge·I`, PD), faking
/// `β = 1/ridge`, `η = 0`, `h = 0 ≤ ½` — a FALSE certificate. With the true
/// Hessian (`encode_grad_hess` no longer adds ridge) `beta_eta_newton` sees
/// `λ_min = 0` and refuses.
#[test]
fn f2_certificate_uses_true_hessian_refuses_singular_field() {
    let atom = tiny_atom(SaeAtomBasisKind::EuclideanPatch, 1);
    let eval = ConstantPhi {
        d: 1,
        m: atom.basis_size(),
    };
    let x = Array1::from(vec![0.5]);
    let t0 = Array1::from(vec![0.0]);

    // The certified field itself: gradient and Hessian must both be exactly 0.
    let grad_hess = encode_grad_hess(&atom, &eval, t0.view(), x.view(), 1.0);
    let (g, h) = grad_hess
        .expect("encode_grad_hess runs")
        .expect("second jet present ⇒ Some");
    let hessian_is_zero = h.iter().all(|&v| v == 0.0);
    assert!(
        hessian_is_zero,
        "the TRUE Hessian of a constant reconstruction is 0 — no ridge is added \
         to the certified field; got {h:?}"
    );
    let gradient_is_zero = g.iter().all(|&v| v == 0.0);
    assert!(gradient_is_zero, "gradient is 0 at a flat reconstruction");

    // The certificate built on that field must refuse.
    let (cert, _) = row_certificate(&atom, &eval, t0.view(), x.view(), 1.0, 1.0)
        .expect("row_certificate runs");
    assert!(
        !cert.certified(),
        "a singular true Hessian must NOT be certified (the old ridged H falsely did)"
    );
    assert!(
        !cert.beta.is_finite(),
        "β must be ∞ (uncertifiable), never the ridge-faked 1/ridge; got {}",
        cert.beta
    );
}
3773
3774 // ---------------------------------------------------------------------
3775 // F3: Duchon atoms are refused (honest refusal, never a fabricated bound).
3776 // ---------------------------------------------------------------------
3777
3778 /// F3: `build_atom_atlas_from_centers` must emit ONLY uncertified charts for a
3779 /// Duchon atom — the closed-form bound available here would fabricate cubic-r³
3780 /// jets and an origin center for a polyharmonic `c·r^(2m−d)` kernel over
3781 /// data-placed centers, risking an under-estimate of `L` (false certificate).
3782 /// The refusal routes every Duchon row to the exact multi-start encode. A
3783 /// non-Duchon control atom is NOT auto-zeroed by this guard.
3784 #[test]
3785 fn f3_duchon_atoms_are_uncertifiable() {
3786 let atom = tiny_atom(SaeAtomBasisKind::Duchon, 1);
3787 let centers = ndarray::array![[0.0_f64], [0.3], [0.7]];
3788 let radii = vec![0.1_f64, 0.1, 0.1];
3789 let atlas = EncodeAtlas::build_atom_atlas_from_centers(
3790 0,
3791 &atom,
3792 centers.view(),
3793 &radii,
3794 1.0,
3795 1.0,
3796 &AtlasConfig::default(),
3797 )
3798 .expect("duchon atlas builds (uncertified)");
3799 assert_eq!(atlas.charts.len(), 3, "one chart per center");
3800 for (i, chart) in atlas.charts.iter().enumerate() {
3801 assert_eq!(
3802 chart.certified_radius, 0.0,
3803 "duchon chart {i} must be uncertified (refused), got r={}",
3804 chart.certified_radius
3805 );
3806 assert!(
3807 chart.amortized_jacobian.is_none(),
3808 "duchon chart {i} must carry no amortized predictor"
3809 );
3810 }
3811 }
3812
3813 // ---------------------------------------------------------------------
3814 // F6: an already-certified warm-up step is never rejected by the progress rule.
3815 // ---------------------------------------------------------------------
3816
3817 /// F6: a step that just crossed into the certified region (`h ≤ ½`) is accepted
3818 /// even when its multiplicative decrease was below the progress floor
3819 /// (`0.501 → 0.499`). Only a STILL-uncertified plateau step is rejected. The
3820 /// old unconditional progress test rejected the `0.501 → 0.499` cross — a false
3821 /// negative that forced an unnecessary exact-solve fallback.
3822 #[test]
3823 fn f6_certified_step_not_rejected_by_progress_rule() {
3824 // 0.501 → 0.499: below the 1/64 multiplicative floor, and NOT the quadratic
3825 // path, so `warmup_progress_sufficient` is false…
3826 assert!(
3827 !warmup_progress_sufficient(0.499, 0.501),
3828 "precondition: this tiny decrease fails the progress bar"
3829 );
3830 // …but because the step is now CERTIFIED it must NOT be rejected (F6).
3831 assert!(
3832 !warmup_should_reject(true, 0.499, 0.501),
3833 "a certified step must be accepted regardless of the progress floor"
3834 );
3835 // A still-uncertified plateau step IS rejected (the guard still bites).
3836 assert!(
3837 warmup_should_reject(false, 0.499, 0.501),
3838 "an uncertified sub-floor step is a plateau and must flag to fallback"
3839 );
3840 // A certified step that ALSO made big progress is accepted (sanity).
3841 assert!(!warmup_should_reject(true, 0.1, 0.9));
3842 // A genuinely-contracting uncertified step is accepted (no regression).
3843 assert!(!warmup_should_reject(false, 0.4, 0.9));
3844 }
3845}