gam_sae/encode.rs
1//! Kantorovich-certified encode atlas (issue #1010).
2//!
3//! Encoding a row `x ∈ ℝᵖ` against a FROZEN SAE dictionary is, per atom `k`,
4//! the coordinate-only Newton problem
5//!
6//! ```text
7//! min_t f_k(t) = ½‖x − z_k · B_kᵀ Φ_k(t)‖² + prior_k(t),
8//! ```
9//!
10//! with the amplitude `z_k` and decoder block `B_k` held fixed (the encode
11//! freezes the dictionary; only the latent coordinate `t` moves). Newton on
12//! `F(t) = ∇f_k(t)` converges quadratically from a start `t₀` into the unique
13//! root in a certified ball whenever the **Newton–Kantorovich** quantity
14//!
15//! ```text
16//! h = β · η · L ≤ ½, β = ‖F'(t₀)⁻¹‖, η = ‖F'(t₀)⁻¹ F(t₀)‖,
17//! ```
18//!
19//! where `L` is a Lipschitz constant of `F'` (the Hessian of `f_k`) on a region
20//! containing the Newton iterates. `h` is CHECKABLE per row in `O(q³)`
21//! (`q = latent_dim`, tiny), so each fast-path encode carries its own
22//! exactness certificate.
23//!
24//! ## The closed-form Hessian-Lipschitz constant `L`
25//!
26//! Write `m(t) = z·BᵀΦ(t) ∈ ℝᵖ` (the reconstruction) and `r(t) = m(t) − x`.
27//! Then `f = ½‖r‖² + prior` and, differentiating three times,
28//!
29//! ```text
30//! ∇³f = 3·sym(J_mᵀ : ∇²m) + ⟨r, ∇³m⟩ + ∇³prior,
31//! ```
32//!
33//! so an operator-norm bound on the chart is
34//!
35//! ```text
36//! L ≤ 3·‖J_m‖·‖∇²m‖ + ‖r‖·‖∇³m‖ + L_prior,
37//! ```
38//!
39//! with `‖∂^g m‖ ≤ |z|·(Σ_m ‖B_{m,:}‖)·B_g`, where `B_g = sup_chart max_m
40//! ‖∂^g Φ_m‖` is the per-column jet sup of the basis family — closed form per
41//! family ([`BasisHessianLipschitz`]). `‖r‖` is bounded by `‖x‖ +
42//! |z|·(Σ_m‖B_{m,:}‖)·B_0`. The ARD/von-Mises prior `L_prior` is a closed-form
43//! constant from the prior strength. Every bound is conservative (an
44//! over-estimate of `L` only SHRINKS the certified radius — it can never
45//! certify a row that does not converge).
46//!
47//! ## Pipeline
48//!
49//! 1. **Offline, per atom** ([`EncodeAtlas::build`]): chart centers `t_c` on the
50//! atom's coordinate grid (the SHAPE_BAND grid idiom), each with a certified
51//! Newton radius `R_c` solved from the Kantorovich inequality at the
52//! worst-case in-chart start.
53//! 2. **Online, per row** ([`EncodeAtlas::certified_encode_row`]): route to the
54//! nearest chart, start from its distilled IFT predictor, take one or two
55//! Newton steps, then the `h ≤ ½` check AT the start point is the per-row
56//! certificate.
57//! 3. **Uncertified tail**: rows whose start fails `h ≤ ½` are FLAGGED (counted
58//! in [`EncodeResult::encode_uncertified_count`]) and must be routed by the
59//! caller to the existing exact multi-start solve. No approximation enters
60//! silently.
61
62use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
63
64use gam_linalg::faer_ndarray::FaerEigh;
65use crate::candidate_index::{
66 AtomFrameSketch, SaeCandidateIndex, auto_candidate_budget,
67};
68use crate::manifold::{
69 AffineCoordinateEvaluator, CylinderHarmonicEvaluator, DuchonCoordinateEvaluator,
70 EuclideanPatchEvaluator, PeriodicHarmonicEvaluator, SaeBasisEvaluator, SaeManifoldAtom,
71 SphereChartEvaluator, TorusHarmonicEvaluator,
72};
73
74use faer::Side;
75
/// The Kantorovich convergence threshold `h ≤ ½`. At or below this value the
/// Newton iteration is guaranteed to converge into the unique root in the
/// certified ball (quadratically when `h` is strictly below it); strictly above
/// it the start is uncertified. The boundary is INCLUSIVE, matching the
/// `h <= KANTOROVICH_THRESHOLD` comparison in [`RowCertificate::certified`].
pub const KANTOROVICH_THRESHOLD: f64 = 0.5;
80
81/// Row count at or above which the corpus-rate certified-encode batch
82/// (`certified_encode_batch` / `certified_encode_with_index`) fans its
83/// per-row encodes out over rayon. Below this the per-row Newton + chart
84/// routing is cheap enough that the fan-out overhead does not pay; matched to
85/// the same order as the arrow-Schur `SCHUR_MATVEC_PARALLEL_ROW_MIN` gate so
86/// short batches inside an outer atom-level fan-out stay sequential.
87pub(crate) const ENCODE_BATCH_PARALLEL_ROW_MIN: usize = 256;
88
89/// Minimum frame alignment `‖Uₖᵀd‖/‖d‖ ∈ [0,1]` the routed atom must have for an
90/// index-routed encode to be attempted at all — a FIT-QUALITY floor, NOT a
91/// routing-correctness gate (#1026, corrected by the #1777 exact-routing path).
92///
93/// Routing itself is now EXACT: the index-routed encode picks the atom via
94/// [`SaeCandidateIndex::route_exact`], which returns the GLOBAL argmax of the
95/// routing score (the universal-bound LSH fast path, else a full-scan fallback) —
96/// so there is no "missed-better-ungathered-atom" hole left for this constant to
97/// patch. What remains is a different, honest question: even the globally-best
98/// atom may align only weakly with a row (no atom in the dictionary fits it). A
99/// finite alignment below this floor means the best available atom is a poor fit,
100/// so the row is flagged and routed to the exact multi-start fallback rather than
101/// encoded against an atom it barely belongs to. (Previously this same comparison
102/// double-served as a recall proxy; that role is gone — `route_exact` guarantees
103/// recall — leaving only the fit-quality role described here.)
104pub(crate) const CANDIDATE_ROUTING_MIN_ALIGNMENT: f64 = 0.5;
105
106/// Number of nearest charts the CERTIFIED encode refines in before returning the
107/// lowest-reconstruction-error certified result. A single nearest chart is not
108/// globally sound where the decoded manifold folds near itself (both competing
109/// basins' charts reconstruct near the fold, so both rank among the nearest by
110/// ambient distance); refining the top few captures the global basin. For a
111/// unimodal atom all candidates converge to the same root, so K>1 is a no-op.
112pub(crate) const CERTIFIED_ROUTING_TOPK: usize = 4;
113
114/// Newton refinement convergence floor. Once a refinement step's length `‖δ‖`
115/// falls below this (relative to the coordinate scale `1 + ‖t‖`), the iterate has
116/// reached the certified root to f64 resolution: applying the step cannot move `t`
117/// meaningfully, and the remaining fixed-budget steps only re-accumulate round-off.
118/// Stopping there is STRICTLY more accurate than draining a fixed step budget on a
119/// well-conditioned quadratic Newton tail, and it removes that tail's per-step
120/// `evaluate` + `second_jet` cost (the dominant per-row encode work). The batched
121/// and per-row encodes share this rule, so they stay bit-identical.
122pub(crate) const NEWTON_REFINE_CONVERGED_EPS: f64 = 1.0e-12;
123
124/// Global-minimum short-circuit floor for top-K certified routing. The
125/// reconstruction error `‖x − z·m(t)‖` is bounded below by 0, so a certified
126/// candidate whose residual already sits at the ambient noise floor
127/// (`≤ this · (1 + ‖x‖)`) is provably the global optimum over the charts — no
128/// competing chart can reach a strictly lower residual. The remaining candidates'
129/// refinement is then skipped. Conservative (a genuine second basin of the same
130/// target reconstructs the SAME point, so returning the first is a valid encode).
131pub(crate) const CERTIFIED_GLOBAL_MIN_RECON_FLOOR: f64 = 1.0e-11;
132
/// A chart region on an atom's latent coordinate: a center `t_c` plus a
/// certified in-chart radius. Over the ball `‖t − t_c‖ ≤ radius` the jet sup
/// bounds returned by [`BasisHessianLipschitz`] hold, so the Kantorovich
/// constant `L` computed from them is valid for any start in the ball.
///
/// For radial (Duchon) families the chart also carries the minimum kernel-center
/// distance `exclusion_r_min` (a lower bound on `‖t − c_k‖` over the chart) that
/// bounds the otherwise-singular `1/r` radial tails (issue #1010).
///
/// [`ChartRegion::new`] leaves both radial bounds `None`;
/// [`ChartRegion::with_radial_bounds`] attaches them. A chart is considered
/// well-formed by `assert_valid` when every center entry is finite and the
/// radius is finite and non-negative.
#[derive(Debug, Clone)]
pub struct ChartRegion {
    /// Chart center coordinate `t_c` (length = latent_dim).
    pub center: Array1<f64>,
    /// In-chart radius in the coordinate metric.
    pub radius: f64,
    /// For radial (Duchon) families: a lower bound on `‖t − c_k‖` over the
    /// chart, across every kernel center `c_k`. `None` for non-radial families.
    pub exclusion_r_min: Option<f64>,
    /// For radial (Duchon) families: an upper bound on `‖t − c_k‖` over the
    /// chart, across every kernel center `c_k`. `None` for non-radial families.
    pub radial_r_max: Option<f64>,
}
154
155impl ChartRegion {
156 pub fn new(center: Array1<f64>, radius: f64) -> Self {
157 Self {
158 center,
159 radius,
160 exclusion_r_min: None,
161 radial_r_max: None,
162 }
163 }
164
165 pub fn with_radial_bounds(mut self, r_min: f64, r_max: f64) -> Self {
166 self.exclusion_r_min = Some(r_min);
167 self.radial_r_max = Some(r_max);
168 self
169 }
170
171 /// A jet-sup certificate is only meaningful over a genuine region. Even
172 /// families whose bounds are manifold-global constants (the sup over any
173 /// chart equals the global sup) must refuse a malformed chart rather than
174 /// certify garbage geometry.
175 pub(crate) fn assert_valid(&self) {
176 assert!(
177 self.radius.is_finite()
178 && self.radius >= 0.0
179 && self.center.iter().all(|c| c.is_finite()),
180 "ChartRegion must have a finite center and a finite non-negative radius"
181 );
182 }
183}
184
/// Per-column sup-norm bounds on the first three coordinate jets of a basis
/// family `Φ(t)`, valid over a stated [`ChartRegion`] (issue #1010). These are
/// the analytic ingredients of the Hessian-Lipschitz constant `L` — see the
/// module docs for the assembly. `value_sup` bounds `max_m |Φ_m|`,
/// `jacobian_sup`/`hessian_sup`/`third_sup` bound `max_m ‖∂^g Φ_m‖`.
///
/// Every method must return an UPPER bound: an over-estimate only shrinks the
/// certified radius, while an under-estimate could certify a row falsely.
pub trait BasisHessianLipschitz {
    /// Upper bound on `max_m |Φ_m(t)|` over the chart (order `g = 0`).
    fn value_sup(&self, chart: &ChartRegion) -> f64;
    /// Upper bound on `max_m ‖∂Φ_m(t)‖` over the chart (order `g = 1`).
    fn jacobian_sup(&self, chart: &ChartRegion) -> f64;
    /// Upper bound on `max_m ‖∂²Φ_m(t)‖` over the chart (order `g = 2`).
    fn hessian_sup(&self, chart: &ChartRegion) -> f64;
    /// Upper bound on `max_m ‖∂³Φ_m(t)‖` over the chart (order `g = 3`).
    fn third_sup(&self, chart: &ChartRegion) -> f64;
}
196
/// Sup over the circle of the `g`-th derivative of any single harmonic column
/// of a `num_basis`-wide Fourier basis `[1, sin(2π h t), cos(2π h t), …]`:
/// `(2π·H)^g` for the top harmonic `H`.
///
/// `H` is `num_basis / 2`, i.e. `(num_basis − 1)/2` ROUNDED UP. For the usual
/// odd width `2H + 1` this is exactly the number of sin/cos pairs. For an even
/// width the last column is an unpaired slot of the next harmonic; rounding up
/// keeps that column inside the bound (flooring would report frequency `0` for
/// a width-2 basis and zero out every derivative sup — an under-estimate, which
/// is the one direction a certificate must never err in).
///
/// The constant column contributes `0` for `g ≥ 1`, so the top harmonic
/// dominates; the bound is global (the trig magnitudes are `≤ 1` everywhere,
/// independent of the chart). Order `0` yields `1`.
pub(crate) fn harmonic_jet_sup(num_basis: usize, order: u32) -> f64 {
    // Integer division rounds toward zero, so `n / 2 == ceil((n − 1) / 2)` for
    // every `n ≥ 1`, and `0 / 2 == 0` covers the empty basis without underflow.
    let top_harmonic = num_basis / 2;
    let omega = std::f64::consts::TAU * top_harmonic as f64;
    // Jet orders used by the certificate are 0..=3, so the cast cannot wrap.
    omega.powi(order as i32)
}
207
208impl BasisHessianLipschitz for PeriodicHarmonicEvaluator {
209 fn value_sup(&self, chart: &ChartRegion) -> f64 {
210 chart.assert_valid();
211 1.0
212 }
213 fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
214 chart.assert_valid();
215 harmonic_jet_sup(self.num_basis, 1)
216 }
217 fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
218 chart.assert_valid();
219 harmonic_jet_sup(self.num_basis, 2)
220 }
221 fn third_sup(&self, chart: &ChartRegion) -> f64 {
222 chart.assert_valid();
223 harmonic_jet_sup(self.num_basis, 3)
224 }
225}
226
227impl BasisHessianLipschitz for TorusHarmonicEvaluator {
228 /// Tensor product of per-axis circle harmonics. A torus basis column is a
229 /// product of single-axis harmonics, each bounded as in the circle case.
230 /// The `g`-th coordinate jet routes `g` derivative operators across the
231 /// `latent_dim` factors (Leibniz); each routing contributes a product of
232 /// per-axis derivative magnitudes. A per-column sup is therefore bounded by
233 /// the top single-axis frequency to the `g`-th power times the number of
234 /// such routings (`latent_dim^g`, the count of operator-to-axis maps).
235 fn value_sup(&self, chart: &ChartRegion) -> f64 {
236 chart.assert_valid();
237 1.0
238 }
239 fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
240 chart.assert_valid();
241 torus_jet_sup(self.num_harmonics, self.latent_dim, 1)
242 }
243 fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
244 chart.assert_valid();
245 torus_jet_sup(self.num_harmonics, self.latent_dim, 2)
246 }
247 fn third_sup(&self, chart: &ChartRegion) -> f64 {
248 chart.assert_valid();
249 torus_jet_sup(self.num_harmonics, self.latent_dim, 3)
250 }
251}
252
/// Per-column order-`g` jet sup of the torus harmonic basis, computed as
/// `(2π·H)^g · latent_dim^g`.
///
/// `H = num_harmonics` is the highest per-axis frequency. The `latent_dim^g`
/// factor counts every way of assigning `g` derivative operators to the product
/// factors; it over-counts on purpose (conservative). Each assignment's per-axis
/// magnitude is at most `(2π H)^{#ops on that axis}`, and multiplying those over
/// the axes gives `(2π H)^g`.
pub(crate) fn torus_jet_sup(num_harmonics: usize, latent_dim: usize, order: u32) -> f64 {
    let exponent = order as i32;
    let frequency_factor = (std::f64::consts::TAU * num_harmonics as f64).powi(exponent);
    let routing_factor = (latent_dim as f64).powi(exponent);
    frequency_factor * routing_factor
}
262
263impl BasisHessianLipschitz for SphereChartEvaluator {
264 /// The 7-column lat/lon chart `[1, x, y, z, xy, yz, xz]` with
265 /// `x = cos(lat)cos(lon)`, `y = cos(lat)sin(lon)`, `z = sin(lat)`. Each of
266 /// `x, y, z` is a product of two unit-frequency trig factors, so its `g`-th
267 /// coordinate jet is a sum of `2^g` products of `{sin,cos}` (each `≤ 1`):
268 /// magnitude `≤ 2^g` for `g ≥ 1`, `≤ 1` for `g = 0`. The bilinear columns
269 /// `xy, yz, xz` are products of two such coordinates; by Leibniz over the
270 /// product, their `g`-th jet is bounded by `Σ_{i=0}^{g} C(g,i)·(2^i)·(2^{g−i})
271 /// = (2+2)^g = 4^g` (using `‖∂^i u‖ ≤ 2^i`, `|u| ≤ 1`). The bilinear columns
272 /// dominate, so the per-column sup is `4^g` (`g ≥ 1`). Bounds are global
273 /// constants — the chart box `lat ∈ [-π/2, π/2]` does not enlarge them.
274 fn value_sup(&self, chart: &ChartRegion) -> f64 {
275 chart.assert_valid();
276 1.0
277 }
278 fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
279 chart.assert_valid();
280 4.0
281 }
282 fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
283 chart.assert_valid();
284 16.0
285 }
286 fn third_sup(&self, chart: &ChartRegion) -> f64 {
287 chart.assert_valid();
288 64.0
289 }
290}
291
292impl BasisHessianLipschitz for AffineCoordinateEvaluator {
293 /// The affine basis `[1, t₁, …, t_d]` is degree ≤ 1: its first jet has unit
294 /// columns, and all second and third jets vanish. The value sup is
295 /// `max(1, ‖t‖)` over the chart, bounded by `1 + ‖t_c‖ + radius`.
296 fn value_sup(&self, chart: &ChartRegion) -> f64 {
297 let center_norm = chart.center.dot(&chart.center).sqrt();
298 1.0 + center_norm + chart.radius
299 }
300 fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
301 chart.assert_valid();
302 1.0
303 }
304 fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
305 chart.assert_valid();
306 0.0
307 }
308 fn third_sup(&self, chart: &ChartRegion) -> f64 {
309 chart.assert_valid();
310 0.0
311 }
312}
313
314impl BasisHessianLipschitz for EuclideanPatchEvaluator {
315 /// Monomials of total degree ≤ `max_degree` in `t ∈ ℝ^d`. Over the ball of
316 /// radius `R` about `t_c`, each coordinate is bounded by `ρ = ‖t_c‖∞ + R`.
317 /// A monomial `t^α` with `|α| = q` has `g`-th partials bounded (crudely) by
318 /// the descending-factorial coefficient `q·(q−1)···(q−g+1) ≤ q^g` times
319 /// `ρ^{max(q−g,0)}`, and there are at most `d^g` partial routings, so the
320 /// per-column `g`-th jet sup is `≤ d^g · D^g · ρ^{max(D−g,0)}` with
321 /// `D = max_degree`. Conservative; D is small for patch evaluators.
322 fn value_sup(&self, chart: &ChartRegion) -> f64 {
323 let rho = patch_rho(chart);
324 let d = self.max_degree as i32;
325 rho.powi(d).max(1.0)
326 }
327 fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
328 patch_jet_sup(self.latent_dim, self.max_degree, chart, 1)
329 }
330 fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
331 patch_jet_sup(self.latent_dim, self.max_degree, chart, 2)
332 }
333 fn third_sup(&self, chart: &ChartRegion) -> f64 {
334 patch_jet_sup(self.latent_dim, self.max_degree, chart, 3)
335 }
336}
337
338impl BasisHessianLipschitz for CylinderHarmonicEvaluator {
339 /// Cylinder `S¹ × ℝ` product basis `Φ_{c,l} = c(t₀)·l(t₁)`, the circle
340 /// (periodic harmonic) factor on axis 0 crossed with the monomial line
341 /// factor on axis 1. Because the two factors depend on disjoint coordinates,
342 /// the order-`g` coordinate jet in any cell is exactly
343 /// `c^{(k₀)}(t₀)·l^{(k₁)}(t₁)` with `k₀ + k₁ = g`, so the per-column sup is
344 /// the max over the split `k₀ + k₁ = g` of the product of the two per-axis
345 /// per-order sups: the circle factor contributes `1` at order 0 and
346 /// `(2π·H)^{k₀}` at order `k₀ ≥ 1` (trig magnitudes `≤ 1`); the line factor
347 /// contributes the monomial-patch sup `D^{k₁}·ρ^{max(D−k₁,0)}` (`D = line
348 /// degree`, `ρ = ‖t_c‖∞ + radius`). Bounds are global in the periodic axis
349 /// and chart-local in the line axis.
350 fn value_sup(&self, chart: &ChartRegion) -> f64 {
351 cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 0)
352 }
353 fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
354 cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 1)
355 }
356 fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
357 cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 2)
358 }
359 fn third_sup(&self, chart: &ChartRegion) -> f64 {
360 cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, 3)
361 }
362}
363
364/// Per-column order-`g` jet sup of the cylinder product basis: the max over
365/// `k₀ + k₁ = g` of `circle_axis_sup(k₀) · line_axis_sup(k₁)`, where the circle
366/// axis sup is `(2π·H)^{k₀}` (`1` at `k₀ = 0`) and the line axis sup is the
367/// monomial-patch bound `D^{k₁}·ρ^{max(D−k₁,0)}` (`1` at `k₁ = 0`). See the
368/// [`CylinderHarmonicEvaluator`] doc comment for the derivation.
369pub(crate) fn cylinder_jet_sup(
370 circle_harmonics: usize,
371 line_degree: usize,
372 chart: &ChartRegion,
373 order: u32,
374) -> f64 {
375 let omega = std::f64::consts::TAU * circle_harmonics as f64;
376 let big_d = line_degree as f64;
377 let rho = patch_rho(chart);
378 let mut best = 0.0_f64;
379 for k0 in 0..=order {
380 let k1 = order - k0;
381 let circle = if k0 == 0 { 1.0 } else { omega.powi(k0 as i32) };
382 let line = if k1 == 0 {
383 rho.powi(line_degree as i32).max(1.0)
384 } else {
385 let residual = line_degree.saturating_sub(k1 as usize) as i32;
386 // `.max(1.0)` as in `patch_jet_sup`: for ρ < 1 a lower-degree line
387 // monomial dominates the k1-th derivative, so the bare `ρ^residual`
388 // underestimates the line-factor sup. The value case (k1==0) already
389 // clamps; this completes it for the derivative orders.
390 big_d.powi(k1 as i32) * rho.powi(residual).max(1.0)
391 };
392 best = best.max(circle * line);
393 }
394 best
395}
396
397/// Sup-norm radius `ρ = ‖t_c‖∞ + radius` of the chart (the coordinate magnitude
398/// bound used by the monomial-patch jet bounds).
399pub(crate) fn patch_rho(chart: &ChartRegion) -> f64 {
400 let center_inf = chart
401 .center
402 .iter()
403 .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
404 center_inf + chart.radius
405}
406
407/// Per-column `g`-th jet sup for a monomial patch of max degree `D` in `d`
408/// coordinates over the chart: `d^g · D^g · ρ^{max(D−g,0)}` (see the
409/// [`EuclideanPatchEvaluator`] doc comment for the derivation).
410pub(crate) fn patch_jet_sup(
411 latent_dim: usize,
412 max_degree: usize,
413 chart: &ChartRegion,
414 order: u32,
415) -> f64 {
416 let d = latent_dim as f64;
417 let big_d = max_degree as f64;
418 let rho = patch_rho(chart);
419 let residual_degree = max_degree.saturating_sub(order as usize) as i32;
420 // `.max(1.0)`: for ρ < 1 (small charts near the origin) the g-th jet sup is NOT
421 // dominated by the max-degree monomial `t^D` (whose g-th derivative ~ ρ^{D-g}
422 // shrinks with ρ) but by a LOWER-degree monomial whose g-th derivative is a
423 // larger constant — e.g. {1,t,t²}'s jacobian sup is the linear term's constant
424 // `1`, which exceeds `2ρ` when ρ < ½. Without the clamp the bound underestimates
425 // the true sup (numerically: D=3, ρ=0.1, g=1 → formula 0.03 vs true 1.0), which
426 // would make the certificate's Lipschitz `L` too small → a FALSE certificate.
427 // `D^g · max(ρ^{D-g}, 1)` upper-bounds `max_{q∈[g,D]} (q!/(q-g)!)·ρ^{q-g}` for
428 // all ρ (the `q=g` term gives `g! ≤ D^g`, the `q=D` term gives `≤ D^g·ρ^{D-g}`).
429 d.powi(order as i32) * big_d.powi(order as i32) * rho.powi(residual_degree).max(1.0)
430}
431
432impl BasisHessianLipschitz for DuchonCoordinateEvaluator {
433 /// Radial-kernel basis `Φ_m(t) = φ(r_m)`, `r_m = ‖t − c_m‖`, plus a
434 /// polynomial nullspace block. For the cubic Duchon kernel `φ(r) = r³` the
435 /// radial derivatives are `φ' = 3r²`, `φ'' = 6r`, `φ''' = 6`. The chain rule
436 /// to coordinate jets introduces `1/r` factors through the unit radial
437 /// direction `u = (t − c)/r` and the projector `(I − uuᵀ)/r`, so over a
438 /// chart the jets are bounded by combining the radial-derivative magnitudes
439 /// at the worst-case radius with the inverse-radius tail at the chart's
440 /// EXCLUSION radius `r_min` (the closest a chart point gets to any center):
441 ///
442 /// ```text
443 /// ‖∇φ‖ ≤ |φ'| ≤ 3 r_max²
444 /// ‖∇²φ‖ ≤ |φ''| + |φ'|/r ≤ 6 r_max + 3 r_max²/r_min
445 /// ‖∇³φ‖ ≤ |φ'''| + 3|φ''|/r + 3|φ'|/r² ≤ 6 + 18 r_max/r_min + 9 r_max²/r_min²
446 /// ```
447 ///
448 /// (the `1/r`, `1/r²` tails are bounded by `1/r_min`, `1/r_min²`). The
449 /// polynomial nullspace block is degree ≤ `order`; its jets are bounded like
450 /// the monomial patch with `D = order`. The per-column sup is the max of the
451 /// kernel and polynomial bounds. The `r³` kernel is itself `C²` (no
452 /// singularity) so these tails are conservative but finite for any
453 /// `r_min > 0`; the atlas refines charts to keep `r_min` bounded away from 0.
454 fn value_sup(&self, chart: &ChartRegion) -> f64 {
455 let r_max = chart.radial_r_max.unwrap_or(chart.radius);
456 let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 0);
457 (r_max.powi(3)).max(poly)
458 }
459 fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
460 let r_max = chart.radial_r_max.unwrap_or(chart.radius);
461 let kernel = 3.0 * r_max * r_max;
462 let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 1);
463 kernel.max(poly)
464 }
465 fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
466 let r_max = chart.radial_r_max.unwrap_or(chart.radius);
467 let r_min = chart
468 .exclusion_r_min
469 .unwrap_or(chart.radius)
470 .max(f64::MIN_POSITIVE);
471 let kernel = 6.0 * r_max + 3.0 * r_max * r_max / r_min;
472 let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 2);
473 kernel.max(poly)
474 }
475 fn third_sup(&self, chart: &ChartRegion) -> f64 {
476 let r_max = chart.radial_r_max.unwrap_or(chart.radius);
477 let r_min = chart
478 .exclusion_r_min
479 .unwrap_or(chart.radius)
480 .max(f64::MIN_POSITIVE);
481 let kernel = 6.0 + 18.0 * r_max / r_min + 9.0 * r_max * r_max / (r_min * r_min);
482 let poly = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 3);
483 kernel.max(poly)
484 }
485}
486
/// Polynomial-block degree of a Duchon nullspace order, used to bound the
/// nullspace columns like a monomial patch.
///
/// Private extension trait: it only exists so the degree lookup can be called
/// as a method on the evaluator from the jet-sup bounds in this file.
trait DuchonOrderDegree {
    /// Returns the polynomial degree `D` handed to `duchon_poly_jet_sup`.
    fn order_degree(&self) -> usize;
}
492
493impl DuchonOrderDegree for DuchonCoordinateEvaluator {
494 fn order_degree(&self) -> usize {
495 match self.order {
496 gam_terms::basis::DuchonNullspaceOrder::Zero => 0,
497 gam_terms::basis::DuchonNullspaceOrder::Linear => 1,
498 gam_terms::basis::DuchonNullspaceOrder::Degree(d) => d,
499 }
500 }
501}
502
503/// Per-column `g`-th jet sup of the Duchon polynomial nullspace block, treated
504/// as a monomial patch of degree `order_degree`.
505pub(crate) fn duchon_poly_jet_sup(
506 latent_dim: usize,
507 order_degree: usize,
508 chart: &ChartRegion,
509 order: u32,
510) -> f64 {
511 if order_degree == 0 {
512 return if order == 0 { 1.0 } else { 0.0 };
513 }
514 patch_jet_sup(latent_dim, order_degree, chart, order)
515}
516
517/// Decoder magnitude `Σ_m ‖B_{m,:}‖₂` of an atom's frozen decoder block: the
518/// factor that converts a per-column `Φ`-jet sup `B_g` into a reconstruction
519/// jet sup `‖∂^g m‖ ≤ |z|·decoder_row_norm_sum·B_g`.
520pub(crate) fn decoder_row_norm_sum(decoder: ArrayView2<'_, f64>) -> f64 {
521 let mut acc = 0.0;
522 for row in decoder.rows() {
523 acc += row.dot(&row).sqrt();
524 }
525 acc
526}
527
/// Sup bounds on the reconstruction `m(t)` and its first three coordinate jets,
/// with the decoder already folded in but the amplitude NOT yet applied:
/// `hessian_lipschitz_constant` multiplies each field by `|z|`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ReconstructionJetSups {
    /// Bound on the amplitude-1 reconstruction magnitude (order 0).
    pub(crate) value: f64,
    /// Bound on the amplitude-1 reconstruction Jacobian (order 1).
    pub(crate) jacobian: f64,
    /// Bound on the amplitude-1 reconstruction second jet (order 2).
    pub(crate) hessian: f64,
    /// Bound on the amplitude-1 reconstruction third jet (order 3).
    pub(crate) third: f64,
}
535
536pub(crate) fn pair_trig_decoder_sup(
537 sin_row: ArrayView1<'_, f64>,
538 cos_row: ArrayView1<'_, f64>,
539) -> f64 {
540 let aa = sin_row.dot(&sin_row);
541 let bb = cos_row.dot(&cos_row);
542 let ab = sin_row.dot(&cos_row);
543 let trace = aa + bb;
544 let disc = ((aa - bb) * (aa - bb) + 4.0 * ab * ab).sqrt();
545 (0.5 * (trace + disc)).sqrt()
546}
547
548pub(crate) fn periodic_reconstruction_jet_sups(
549 decoder: ArrayView2<'_, f64>,
550) -> ReconstructionJetSups {
551 let mut value = 0.0;
552 let mut jacobian = 0.0;
553 let mut hessian = 0.0;
554 let mut third = 0.0;
555 if decoder.nrows() > 0 {
556 value += decoder.row(0).dot(&decoder.row(0)).sqrt();
557 }
558 let harmonics = decoder.nrows().saturating_sub(1) / 2;
559 for h in 1..=harmonics {
560 let sin_idx = 2 * h - 1;
561 let cos_idx = 2 * h;
562 let amp = pair_trig_decoder_sup(decoder.row(sin_idx), decoder.row(cos_idx));
563 let omega = std::f64::consts::TAU * h as f64;
564 value += amp;
565 jacobian += omega * amp;
566 hessian += omega.powi(2) * amp;
567 third += omega.powi(3) * amp;
568 }
569 for row in (1 + 2 * harmonics)..decoder.nrows() {
570 let amp = decoder.row(row).dot(&decoder.row(row)).sqrt();
571 value += amp;
572 let omega = std::f64::consts::TAU * harmonics.max(1) as f64;
573 jacobian += omega * amp;
574 hessian += omega.powi(2) * amp;
575 third += omega.powi(3) * amp;
576 }
577 ReconstructionJetSups {
578 value,
579 jacobian,
580 hessian,
581 third,
582 }
583}
584
585pub(crate) fn reconstruction_jet_sups(
586 atom: &SaeManifoldAtom,
587 sups: JetSups,
588) -> ReconstructionJetSups {
589 if matches!(
590 atom.basis_kind,
591 crate::manifold::SaeAtomBasisKind::Periodic
592 ) {
593 periodic_reconstruction_jet_sups(atom.decoder_coefficients.view())
594 } else {
595 let decoder_norm_sum = decoder_row_norm_sum(atom.decoder_coefficients.view());
596 ReconstructionJetSups {
597 value: decoder_norm_sum * sups.value,
598 jacobian: decoder_norm_sum * sups.jacobian,
599 hessian: decoder_norm_sum * sups.hessian,
600 third: decoder_norm_sum * sups.third,
601 }
602 }
603}
604
605/// The Hessian-Lipschitz constant `L` of the per-row encode objective `f_k` on
606/// a chart, assembled in closed form from the basis jet sups and the decoder /
607/// amplitude / target magnitudes. See the module docs for the derivation:
608///
609/// ```text
610/// L ≤ 3·‖J_m‖·‖∇²m‖ + ‖r‖·‖∇³m‖ + L_prior,
611/// ‖∂^g m‖ ≤ |z|·S_B·B_g, S_B = Σ_m ‖B_{m,:}‖,
612/// ‖r‖ ≤ ‖x‖ + |z|·S_B·B_0,
613/// ```
614///
615/// `prior_lipschitz` is the caller-supplied closed-form `L_prior` of the
616/// ARD/von-Mises coordinate prior (`0.0` if no prior is active on the encode).
617pub(crate) fn hessian_lipschitz_constant(
618 recon_sups: ReconstructionJetSups,
619 amplitude: f64,
620 target_norm: f64,
621 prior_lipschitz: f64,
622) -> f64 {
623 let z = amplitude.abs();
624 let m_jac = z * recon_sups.jacobian;
625 let m_hess = z * recon_sups.hessian;
626 let m_third = z * recon_sups.third;
627 let recon_value = z * recon_sups.value;
628 let r_norm = target_norm + recon_value;
629 3.0 * m_jac * m_hess + r_norm * m_third + prior_lipschitz
630}
631
/// One offline-certified chart: a center, its Kantorovich constants, and the
/// certified Newton-convergence radius `R_c` solved from `h = β·η·L ≤ ½` at the
/// worst-case in-chart start.
#[derive(Debug, Clone)]
pub struct CertifiedChart {
    /// The chart's coordinate region: center `t_c`, in-chart radius, and (for
    /// radial families) the kernel-center distance bounds.
    pub region: ChartRegion,
    /// Closed-form Hessian-Lipschitz constant `L` over the chart.
    pub lipschitz: f64,
    /// `β = ‖F'(t_c)⁻¹‖` at the chart center (worst-case in-chart start uses
    /// the center's curvature; the radius is solved so the certificate holds for
    /// any start in the ball).
    pub beta_center: f64,
    /// Certified Newton radius: starts within `radius` of `t_c` satisfy `h ≤ ½`.
    pub certified_radius: f64,
    /// Distilled amortized-encoder Jacobian for this chart (#1026 ladder item 3).
    ///
    /// The exact encode map `x ↦ t` solves `F(t; x) = J_m(t)ᵀ(m(t) − x) = 0`. By
    /// the implicit function theorem its derivative at the converged root is
    /// `dt/dx = −(∂_t F)⁻¹ (∂_x F) = H⁻¹ J_m` (since `∂_x F = −J_m`), so the
    /// first-order Taylor expansion of the encode map about this chart's center
    /// `t_c` is the closed-form AFFINE predictor
    ///
    /// ```text
    /// t(x) ≈ t_c + (1/z) · A₁ · (x − z · m₁(t_c)),   A₁ = (J₁ᵀJ₁ + ridge·I)⁻¹ J₁,
    /// ```
    ///
    /// with `J₁ = Bᵀ J_Φ(t_c)` and `m₁(t_c) = BᵀΦ(t_c)` the AMPLITUDE-1
    /// reconstruction jets (the amplitude `z` factors out analytically, so the
    /// stored Jacobian is amplitude-free). This is the DISTILLED amortized
    /// encoder of the #1026 thread: the per-row Hessian factorization + Newton
    /// iteration is moved OFFLINE into this `d × p` matrix, leaving a single
    /// `O(d·p)` mat-vec online — no per-row eigendecomposition, no second-jet
    /// evaluation. The Kantorovich certificate is still evaluated AT the
    /// predicted start, so the amortized prediction is trusted iff `h ≤ ½` and an
    /// uncertified row still routes to the exact multi-start solve (the encoder
    /// approximates inference, the certificate keeps it honest — the thread's
    /// "encoder + certificate-gated exact fallback" deployment). `None` when the
    /// center's Gauss–Newton block is singular (no certifiable amortization).
    pub amortized_jacobian: Option<Array2<f64>>,
    /// Amplitude-1 chart-center reconstruction `m₁(t_c) = BᵀΦ(t_c)` (length `p`),
    /// the anchor the amortized predictor expands the encode map around.
    pub recon_center: Array1<f64>,
    /// Precomputed affine-predictor CONSTANT term `base = t_c − A₁·m₁(t_c)` (length
    /// `d`), so the online amortized encode of a row `x` at amplitude `z` is the
    /// single mat-vec `t̂ = base + (1/z)·A₁·x` with NO per-row `A₁·m₁` recompute.
    /// This is the affine predictor with the `z` cancelled in its constant part:
    /// `t_c + (1/z)·A₁·(x − z·m₁) = (t_c − A₁·m₁) + (1/z)·A₁·x`.
    /// Hoisting this atom-static term offline is what lets the massive-K index-routed
    /// fast paths run a single allocation-free pass over rows (rather than a per-atom
    /// GEMM sub-batch that degenerates to one row per group when `K ≫ N`). `None`
    /// exactly when `amortized_jacobian` is `None` (singular Gauss–Newton block).
    pub amortized_base: Option<Array1<f64>>,
}
683
/// The per-atom encode atlas: a set of certified charts covering the atom's
/// coordinate domain, plus the decoder/amplitude scaling needed to recompute a
/// per-row certificate online.
#[derive(Debug, Clone)]
pub struct AtomEncodeAtlas {
    /// Index of the atom this atlas belongs to.
    /// NOTE(review): presumably the atom's position in the SAE dictionary —
    /// confirm against the builder.
    pub atom_index: usize,
    /// Dimension of the atom's latent coordinate `t`.
    pub latent_dim: usize,
    /// Decoder scaling for the online certificate.
    /// NOTE(review): by its name this is `Σ_m ‖B_{m,:}‖` as computed by
    /// `decoder_row_norm_sum` — confirm where the atlas is built.
    pub decoder_norm_sum: f64,
    /// The certified charts of this atom.
    pub charts: Vec<CertifiedChart>,
}
694
/// Result of a certified encode over a batch of rows, carrying the honesty
/// flag: how many rows could NOT be certified and were flagged for the exact
/// multi-start fallback (issue #1010 — no approximation enters silently).
///
/// The fields are public, so the count invariant below is only guaranteed for
/// values built through `EncodeResult::from_rows`, which derives the count from
/// `certified`.
#[derive(Debug, Clone)]
pub struct EncodeResult {
    /// Per-row encoded latent coordinates (`n_rows × latent_dim`).
    pub coords: Array2<f64>,
    /// Per-row certificate: `true` ⇒ the row's start satisfied `h ≤ ½` and the
    /// 1–2 Newton steps are exact-into-the-certified-ball; `false` ⇒ flagged.
    pub certified: Vec<bool>,
    /// Count of rows that could not be certified. These ride the payload so the
    /// caller routes them to the exact multi-start encode — honesty, never
    /// silent. Equals `certified.iter().filter(|c| !**c).count()`.
    pub encode_uncertified_count: usize,
}
710
711impl EncodeResult {
712 pub(crate) fn from_rows(coords: Array2<f64>, certified: Vec<bool>) -> Self {
713 let encode_uncertified_count = certified.iter().filter(|c| !**c).count();
714 Self {
715 coords,
716 certified,
717 encode_uncertified_count,
718 }
719 }
720}
721
/// Kantorovich certificate of one row at a start `t₀`, for a single atom
/// encode.
#[derive(Debug, Clone, Copy)]
pub struct RowCertificate {
    /// `β`: operator norm of the inverse Hessian at the start.
    pub beta: f64,
    /// `η`: length of the Newton step at the start.
    pub eta: f64,
    /// `L`: the Hessian-Lipschitz constant used for this certificate.
    pub lipschitz: f64,
    /// `h = β·η·L`. The row is certified iff `h ≤ ½`.
    pub h: f64,
}
731
732impl RowCertificate {
733 pub fn certified(&self) -> bool {
734 self.h.is_finite() && self.h <= KANTOROVICH_THRESHOLD
735 }
736}
737
738#[derive(Debug, Clone)]
739struct CertifiedEncodeProbe {
740 coord: Array1<f64>,
741 initial_cert: RowCertificate,
742 final_cert: RowCertificate,
743}
744
/// Canonical polynomial degree of the flat (line) axis of a cylinder
/// `S¹ × ℝ` atom. This is the degree the topology-race builder
/// ([`gam_solve::structure_harvest`]) uses for the line axis
/// (`CylinderHarmonicEvaluator::new(_, 2)`). The encode atlas derives the
/// circle harmonic count from the basis width using this degree, so the two
/// values must stay in agreement.
pub(crate) const SAE_CYLINDER_LINE_DEGREE: usize = 2;
751
752/// Build a basis-family handle for one atom from its [`SaeManifoldAtom`]. The
753/// atlas needs to evaluate the jet sups, which live on the concrete evaluator
754/// types; the atom carries the evaluator as `Arc<dyn SaeBasisEvaluator>`, so we
755/// reconstruct the family bound from the atom's basis kind + width + centers.
756pub(crate) fn family_jet_sups(
757 atom: &SaeManifoldAtom,
758 chart: &ChartRegion,
759) -> Result<JetSups, String> {
760 use crate::manifold::SaeAtomBasisKind::*;
761 let m = atom.basis_size();
762 let d = atom.latent_dim;
763 let sups = match &atom.basis_kind {
764 Periodic => {
765 let ev = PeriodicHarmonicEvaluator::new(m)?;
766 JetSups::from_family(&ev, chart)
767 }
768 Torus => {
769 // Torus basis width is `(2H+1)^d`; recover the per-axis harmonic
770 // count `H` from `axis_m = m^(1/d)` rather than a sum formula.
771 let axis_m = integer_root(m, d.max(1));
772 let num_harmonics = axis_m.saturating_sub(1) / 2;
773 let ev = TorusHarmonicEvaluator::new(d, num_harmonics.max(1))?;
774 JetSups::from_family(&ev, chart)
775 }
776 Sphere => {
777 let ev = SphereChartEvaluator;
778 JetSups::from_family(&ev, chart)
779 }
780 Cylinder => {
781 // Cylinder width is `(2H+1)·(D+1)` with the canonical flat-axis
782 // degree `D = SAE_CYLINDER_LINE_DEGREE` (the harvest convention).
783 // Recover the per-axis circle harmonic count `H` from
784 // `2H+1 = m/(D+1)`.
785 let ml = SAE_CYLINDER_LINE_DEGREE + 1;
786 if d != 2 || ml == 0 || m % ml != 0 {
787 return Err(format!(
788 "EncodeAtlas: Cylinder atom requires latent_dim == 2 and width divisible by {ml}; got dim={d}, m={m}"
789 ));
790 }
791 let axis_mc = m / ml;
792 let h = axis_mc.saturating_sub(1) / 2;
793 let ev = CylinderHarmonicEvaluator::new(h.max(1), SAE_CYLINDER_LINE_DEGREE)?;
794 JetSups::from_family(&ev, chart)
795 }
796 Linear | EuclideanPatch | Poincare => {
797 // The patch width fixes max_degree implicitly; bound by a degree that
798 // covers the column count (conservative). Degree d-patch column count
799 // grows fast; we recover the smallest degree whose patch is ≥ m.
800 // Poincare atoms use the same tangent-coordinate polynomial decoder;
801 // their intrinsic smoothness differs in the penalty, not in Phi(t).
802 let degree = euclidean_patch_degree(d, m);
803 let ev = EuclideanPatchEvaluator::new(d, degree)?;
804 JetSups::from_family(&ev, chart)
805 }
806 Duchon => {
807 // The atom carries the basis kind but not the nullspace order, and
808 // the certificate needs an UPPER bound on L. The kernel-tail bound
809 // (cubic r³ coefficients vs the chart's r_min/r_max) is independent
810 // of the constructed order; the polynomial-block bound grows with the
811 // order, so we construct with a conservative order whose polynomial
812 // degree upper-bounds any nullspace the atom's basis width can hold.
813 // Constructing with `m = basis_size` maps to `Degree(basis_size − 1)`
814 // — an over-estimate that keeps the Lipschitz bound sound.
815 let centers = duchon_centers_from_atom(atom);
816 let conservative_m = m.max(1);
817 let ev = DuchonCoordinateEvaluator::new(centers, conservative_m)?;
818 JetSups::from_family(&ev, chart)
819 }
820 Precomputed(name) => {
821 return Err(format!(
822 "EncodeAtlas: precomputed basis '{name}' has no closed-form jet sup; route to exact encode"
823 ));
824 }
825 };
826 Ok(sups)
827}
828
829/// Smallest monomial-patch degree whose column count covers `m` basis columns.
830pub(crate) fn euclidean_patch_degree(latent_dim: usize, m: usize) -> usize {
831 // Column count of a degree-D patch in d vars is C(d+D, D). Grow D until it
832 // covers m; cap at m so a degenerate width still terminates.
833 let mut degree = 0usize;
834 while patch_column_count(latent_dim, degree) < m && degree < m {
835 degree += 1;
836 }
837 degree
838}
839
/// Largest integer `a` with `a^k ≤ n`, i.e. the floor of the `k`-th root of
/// `n`. Used to recover the per-axis harmonic width `axis_m` from a torus
/// basis width `m = axis_m^d`.
///
/// Special cases:
/// * `k == 0` returns `1` by convention (every `a^0` equals 1).
/// * `k == 1` returns `n`.
/// * `n == 0` with `k ≥ 1` returns `0`, since `0^k = 0 ≤ 0` while `1^k > 0`.
pub(crate) fn integer_root(n: usize, k: usize) -> usize {
    if k == 0 {
        return 1;
    }
    if k == 1 || n == 0 {
        return n;
    }
    let target = n as u128;
    // From here `n ≥ 1`, so `root = 1` always satisfies `root^k ≤ n`.
    let mut root = 1usize;
    loop {
        let candidate = (root + 1) as u128;
        // Compute `candidate^k` incrementally and stop as soon as it passes
        // `n`. The saturating multiply keeps the comparison valid even when
        // the true power does not fit in `u128`.
        let mut power: u128 = 1;
        let mut exceeds_target = false;
        for _ in 0..k {
            power = power.saturating_mul(candidate);
            if power > target {
                exceeds_target = true;
                break;
            }
        }
        if exceeds_target {
            return root;
        }
        root += 1;
    }
}
868
/// Column count `C(d + D, D)` of a degree-`D` monomial patch in `d` variables
/// (`d = latent_dim`, `D = degree`).
///
/// The binomial is built with the exact recurrence
/// `C(d + i, i) = C(d + i − 1, i − 1) · (d + i) / i`, so the running value is
/// always itself a binomial coefficient. Forming the numerator and denominator
/// as two separate products would overflow `u128` around `D ≈ 34` even when
/// the result is tiny. If the count does not fit in `usize`, the result
/// saturates at `usize::MAX`.
pub(crate) fn patch_column_count(latent_dim: usize, degree: usize) -> usize {
    let mut count: u128 = 1;
    for i in 1..=degree {
        let factor = latent_dim as u128 + i as u128;
        count = match count.checked_mul(factor) {
            // Exact: `count · (d + i)` is divisible by `i` at every step.
            Some(scaled) => scaled / i as u128,
            // The product left `u128`, so the true count exceeds `usize::MAX`.
            None => return usize::MAX,
        };
    }
    if count > usize::MAX as u128 {
        usize::MAX
    } else {
        count as usize
    }
}
879
880/// Recover Duchon centers from an atom: when the evaluator is unavailable the
881/// atlas falls back to the atom's own latent-coordinate hull as the center set,
882/// which only affects the radial-tail bound conservatively.
883pub(crate) fn duchon_centers_from_atom(atom: &SaeManifoldAtom) -> Array2<f64> {
884 // One center at the origin in latent_dim space is a sound conservative
885 // default: the chart's own r_min / r_max bracket the true radial range.
886 Array2::<f64>::zeros((1, atom.latent_dim.max(1)))
887}
888
/// Per-column jet sups of a basis family over one chart: the value and the
/// first three derivative orders.
#[derive(Debug, Clone, Copy)]
pub(crate) struct JetSups {
    /// Sup of the basis value (order 0).
    pub(crate) value: f64,
    /// Sup of the basis Jacobian (order 1).
    pub(crate) jacobian: f64,
    /// Sup of the basis Hessian (order 2).
    pub(crate) hessian: f64,
    /// Sup of the basis third derivative (order 3).
    pub(crate) third: f64,
}
897
898impl JetSups {
899 pub(crate) fn from_family<B: BasisHessianLipschitz>(family: &B, chart: &ChartRegion) -> Self {
900 Self {
901 value: family.value_sup(chart),
902 jacobian: family.jacobian_sup(chart),
903 hessian: family.hessian_sup(chart),
904 third: family.third_sup(chart),
905 }
906 }
907}
908
909/// Evaluate one atom's encode objective gradient `F(t) = ∇f_k(t)` and the FULL
910/// Hessian `F'(t) = ∇²f_k(t)` at a single coordinate `t`, for a single target
911/// row `x` and fixed amplitude `z`. With `m(t) = z·BᵀΦ(t)`, `r = m − x`,
912/// `J_m = z·Bᵀ J_Φ`:
913///
914/// ```text
915/// g_t[a] = J_m[a] · r (= ∇f)
916/// H_tt[a,b] = J_m[a] · J_m[b] + r · ∂²m/∂t_a∂t_b (= ∇²f, FULL Hessian)
917/// ```
918///
919/// The certificate uses the FULL Hessian rather than the Gauss-Newton block
920/// `J_mᵀ J_m`. This is the principled choice for Newton–Kantorovich: the
921/// theorem certifies convergence of Newton on `F = ∇f` to the unique nearby
922/// ROOT of `∇f`, but a root of `∇f` can be a maximum. The full Hessian is
923/// positive-definite exactly on the genuine-minimum basin, so requiring
924/// `λ_min(H) > 0` (finite `β`) is what flags a start that would otherwise let
925/// Gauss-Newton march into the wrong root (e.g. the circle antipode, a local
926/// max where `∇f = 0` but the full curvature is negative). The residual term
927/// needs the basis second jet `∂²Φ/∂t²`; an evaluator without one returns
928/// `None`, and the row is flagged (no silent Gauss-Newton fallback).
929pub(crate) fn encode_grad_hess(
930 atom: &SaeManifoldAtom,
931 evaluator: &dyn SaeBasisEvaluator,
932 t: ArrayView1<'_, f64>,
933 x: ArrayView1<'_, f64>,
934 amplitude: f64,
935 ridge: f64,
936) -> Result<Option<(Array1<f64>, Array2<f64>)>, String> {
937 let d = atom.latent_dim;
938 let p = atom.output_dim();
939 let m = atom.basis_size();
940 let coords = t.to_shape((1, d)).map_err(|e| e.to_string())?.to_owned();
941 let (phi, jet) = evaluator.evaluate(coords.view())?;
942 if phi.dim() != (1, m) {
943 return Err(format!(
944 "encode_grad_hess: evaluator returned phi {:?}, expected (1, {m})",
945 phi.dim()
946 ));
947 }
948 let decoder = &atom.decoder_coefficients;
949 // Reconstruction m(t) = z · Bᵀ Φ(t) ∈ ℝᵖ.
950 let mut recon = Array1::<f64>::zeros(p);
951 for basis_col in 0..m {
952 let phi_v = phi[[0, basis_col]];
953 if phi_v == 0.0 {
954 continue;
955 }
956 for out in 0..p {
957 recon[out] += amplitude * phi_v * decoder[[basis_col, out]];
958 }
959 }
960 let residual = &recon - &x;
961 // J_m[axis] = z · Bᵀ (∂Φ/∂t_axis) ∈ ℝᵖ.
962 let mut jm = Array2::<f64>::zeros((d, p));
963 for axis in 0..d {
964 for basis_col in 0..m {
965 let dphi = jet[[0, basis_col, axis]];
966 if dphi == 0.0 {
967 continue;
968 }
969 for out in 0..p {
970 jm[[axis, out]] += amplitude * dphi * decoder[[basis_col, out]];
971 }
972 }
973 }
974 // The full-Hessian residual term needs ∂²Φ/∂t². No second jet ⇒ no
975 // certificate (flag), never a silent Gauss-Newton substitute.
976 let second = match evaluator.second_jet_dyn(coords.view()) {
977 Some(result) => result?,
978 None => return Ok(None),
979 };
980 // Residual · decoder-row `r·B_{basis,:}` is INDEPENDENT of the (a,b) axes, yet
981 // the old code recomputed it `d²` times inside the Hessian double loop. Hoist it
982 // to one O(m·p) pass so the per-axis curvature term is a cheap O(m) dot.
983 let mut rd = vec![0.0_f64; m];
984 for (basis_col, rd_col) in rd.iter_mut().enumerate() {
985 let mut dot = 0.0;
986 for out in 0..p {
987 dot += residual[out] * decoder[[basis_col, out]];
988 }
989 *rd_col = dot;
990 }
991 // g_t[axis] = J_m[axis] · r ; H_tt[a,b] = J_m[a]·J_m[b] + r·∂²m/∂t_a∂t_b.
992 // The full Hessian is symmetric (Gauss-Newton block + symmetric second jet), so
993 // compute the upper triangle and mirror — half the curvature work.
994 let mut g = Array1::<f64>::zeros(d);
995 let mut h = Array2::<f64>::zeros((d, d));
996 for a in 0..d {
997 let ja = jm.row(a);
998 g[a] = ja.dot(&residual);
999 for b in a..d {
1000 // Gauss-Newton block.
1001 let mut hab = ja.dot(&jm.row(b));
1002 // Residual · second-jet curvature: r · ∂²m_{ab},
1003 // ∂²m_{ab}[out] = z · Σ_basis (∂²Φ/∂t_a∂t_b) · B[basis, out].
1004 let mut curv = 0.0;
1005 for basis_col in 0..m {
1006 let d2phi = second[[0, basis_col, a, b]];
1007 if d2phi == 0.0 {
1008 continue;
1009 }
1010 curv += amplitude * d2phi * rd[basis_col];
1011 }
1012 hab += curv;
1013 h[[a, b]] = hab;
1014 h[[b, a]] = hab;
1015 }
1016 }
1017 for a in 0..d {
1018 h[[a, a]] += ridge;
1019 }
1020 Ok(Some((g, h)))
1021}
1022
1023/// Operator-norm of `H⁻¹` (i.e. `β = 1/λ_min(H)`) and the Newton step
1024/// `δ = −H⁻¹ g` with `η = ‖δ‖`, from a symmetric PSD `H` and gradient `g`.
1025/// Returns `None` when `H` is numerically singular (λ_min ≤ 0) — an
1026/// uncertifiable start.
1027pub(crate) fn beta_eta_newton(
1028 h: ArrayView2<'_, f64>,
1029 g: ArrayView1<'_, f64>,
1030) -> Result<Option<(f64, f64, Array1<f64>)>, String> {
1031 // Closed-form fast paths for the tiny latent dims that dominate SAE atoms
1032 // (`d = 1, 2`), avoiding a faer eigendecomposition + its heap allocations on
1033 // the hottest per-row Newton inner loop. `β = 1/λ_min(H)`, `δ = −H⁻¹g`, and the
1034 // `λ_min ≤ 0` gate are computed directly; the symmetric-`H` reads mirror
1035 // `eigh(Side::Lower)` (which uses the lower triangle) exactly.
1036 let d = h.nrows();
1037 if d == 1 {
1038 let h00 = h[[0, 0]];
1039 if !(h00.is_finite() && h00 > 0.0) {
1040 return Ok(None);
1041 }
1042 let delta0 = -g[0] / h00;
1043 let mut delta = Array1::<f64>::zeros(1);
1044 delta[0] = delta0;
1045 return Ok(Some((1.0 / h00, delta0.abs(), delta)));
1046 }
1047 if d == 2 {
1048 // Symmetric H = [[a, b], [b, c]] read from the lower triangle.
1049 let a = h[[0, 0]];
1050 let b = h[[1, 0]];
1051 let c = h[[1, 1]];
1052 let tr = a + c;
1053 let det = a * c - b * b;
1054 // λ_min = ½(tr − √((a−c)² + 4b²)); ≥ 0 ⇒ H PSD, > 0 ⇒ PD.
1055 let disc = ((a - c) * (a - c) + 4.0 * b * b).max(0.0).sqrt();
1056 let lambda_min = 0.5 * (tr - disc);
1057 if !(lambda_min.is_finite() && lambda_min > 0.0) {
1058 return Ok(None);
1059 }
1060 // δ = −H⁻¹g with H⁻¹ = [[c, −b], [−b, a]] / det (det = λ_min·λ_max > 0).
1061 let inv_det = 1.0 / det;
1062 let g0 = g[0];
1063 let g1 = g[1];
1064 let d0 = -(c * g0 - b * g1) * inv_det;
1065 let d1 = -(a * g1 - b * g0) * inv_det;
1066 if !(d0.is_finite() && d1.is_finite()) {
1067 return Ok(None);
1068 }
1069 let mut delta = Array1::<f64>::zeros(2);
1070 delta[0] = d0;
1071 delta[1] = d1;
1072 let eta = (d0 * d0 + d1 * d1).sqrt();
1073 return Ok(Some((1.0 / lambda_min, eta, delta)));
1074 }
1075 let (vals, vecs) = h
1076 .eigh(Side::Lower)
1077 .map_err(|e| format!("beta_eta_newton: eigh failed: {e:?}"))?;
1078 let lambda_min = vals.iter().cloned().fold(f64::INFINITY, f64::min);
1079 if !(lambda_min.is_finite() && lambda_min > 0.0) {
1080 return Ok(None);
1081 }
1082 let beta = 1.0 / lambda_min;
1083 // Newton step δ = −H⁻¹ g via the eigendecomposition: δ = −Σ_i (vᵢᵀg/λᵢ) vᵢ.
1084 let mut delta = Array1::<f64>::zeros(d);
1085 for (col, &lam) in vals.iter().enumerate() {
1086 if lam <= 0.0 {
1087 return Ok(None);
1088 }
1089 let vi = vecs.column(col);
1090 let coeff = vi.dot(&g) / lam;
1091 for row in 0..d {
1092 delta[row] -= coeff * vi[row];
1093 }
1094 }
1095 let eta = delta.dot(&delta).sqrt();
1096 Ok(Some((beta, eta, delta)))
1097}
1098
1099/// Compute the per-row Kantorovich certificate for encoding target row `x`
1100/// against atom `atom` at start coordinate `t₀`, with fixed amplitude `z` and
1101/// the chart's closed-form Lipschitz constant `lipschitz`. Returns the
1102/// certificate AND the Newton step `δ = −H⁻¹ g` so the caller can advance.
1103pub fn row_certificate(
1104 atom: &SaeManifoldAtom,
1105 evaluator: &dyn SaeBasisEvaluator,
1106 t0: ArrayView1<'_, f64>,
1107 x: ArrayView1<'_, f64>,
1108 amplitude: f64,
1109 lipschitz: f64,
1110 ridge: f64,
1111) -> Result<(RowCertificate, Array1<f64>), String> {
1112 let uncertified = || {
1113 (
1114 RowCertificate {
1115 beta: f64::INFINITY,
1116 eta: f64::INFINITY,
1117 lipschitz,
1118 h: f64::INFINITY,
1119 },
1120 Array1::<f64>::zeros(atom.latent_dim),
1121 )
1122 };
1123 // No second jet ⇒ no full Hessian ⇒ uncertifiable (flag).
1124 let Some((g, h)) = encode_grad_hess(atom, evaluator, t0, x, amplitude, ridge)? else {
1125 return Ok(uncertified());
1126 };
1127 match beta_eta_newton(h.view(), g.view())? {
1128 Some((beta, eta, delta)) => {
1129 let cert = RowCertificate {
1130 beta,
1131 eta,
1132 lipschitz,
1133 h: beta * eta * lipschitz,
1134 };
1135 Ok((cert, delta))
1136 }
1137 // Indefinite / negative-curvature full Hessian: the start is at or past
1138 // a basin boundary (a max/saddle of f), not the minimum basin — flag.
1139 None => Ok(uncertified()),
1140 }
1141}
1142
1143fn uncertified_certificate(lipschitz: f64) -> RowCertificate {
1144 RowCertificate {
1145 beta: f64::INFINITY,
1146 eta: f64::INFINITY,
1147 lipschitz,
1148 h: f64::INFINITY,
1149 }
1150}
1151
1152fn refine_certified_start(
1153 atom: &SaeManifoldAtom,
1154 evaluator: &dyn SaeBasisEvaluator,
1155 mut t: Array1<f64>,
1156 x: ArrayView1<'_, f64>,
1157 amplitude: f64,
1158 lipschitz: f64,
1159 ridge: f64,
1160 newton_steps: usize,
1161 initial_cert: RowCertificate,
1162 mut delta: Array1<f64>,
1163) -> Result<Option<CertifiedEncodeProbe>, String> {
1164 assert!(initial_cert.certified());
1165 let mut final_cert = initial_cert;
1166 for _ in 0..newton_steps {
1167 // Convergence early-exit: the pending Newton step is below the coordinate
1168 // ULP scale, so `t + δ == t` to f64 resolution — the certified root is
1169 // reached and the remaining fixed-budget steps would only re-accumulate
1170 // round-off. This is where the well-conditioned quadratic Newton tail's
1171 // redundant `evaluate` + `second_jet` work is eliminated.
1172 if delta.dot(&delta).sqrt()
1173 <= NEWTON_REFINE_CONVERGED_EPS * (1.0 + t.dot(&t).sqrt())
1174 {
1175 break;
1176 }
1177 t = &t + δ
1178 let (cert, next_delta) =
1179 row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz, ridge)?;
1180 if !cert.certified() {
1181 return Ok(None);
1182 }
1183 final_cert = cert;
1184 delta = next_delta;
1185 }
1186 Ok(Some(CertifiedEncodeProbe {
1187 coord: t,
1188 initial_cert,
1189 final_cert,
1190 }))
1191}
1192
// Design note for `certify_with_basin_warmup` (defined further down in this
// file). It is written as a plain comment so that rustdoc does not attach it
// to the constant that follows.
//
// Certify an encode probe from `t_start`, navigating into the Kantorovich basin
// first if needed (#1154/#1026). The Kantorovich quantity `h = β·η·L` scales with
// amplitude through `L`, so at unit amplitude a positive-definite chart-center /
// distilled start can sit OUTSIDE the certified ball (`h > ½`). Rather than
// flagging it uncertified immediately — which made the encoder certify ZERO
// held-out rows at amplitude 1.0 and fall back to the exact solve for everything —
// take plain Newton steps toward the root, re-certifying at each iterate, while
// the Kantorovich quantity `h = β·η·L` keeps CONTRACTING toward the ½ bound. The
// certificate at the landing point is a full Kantorovich guarantee from there
// (`h ≤ ½` ⇒ Newton converges to the in-ball root), so this only ever WIDENS the
// certified set; it never certifies a non-convergent start.
//
// Termination is the natural Newton stopping rule — there is no arbitrary step
// budget. The warm-up stops and flags for the exact fallback when either the start
// is not steppable (indefinite / non-finite Hessian — at or past a basin boundary)
// or a step fails to reduce `h` (the iterate is not approaching a certifiable
// in-chart root: its root lies outside this chart's valid Lipschitz region, or the
// start was past the basin — empirically the rows that miss *plateau*, so more
// steps cannot help; the lever there is denser charts, not more iterations). On
// success the start is refined `newton_steps` further by `refine_certified_start`.

/// Minimum per-step *multiplicative* decrease of the Kantorovich `h` the basin
/// warm-up requires to keep stepping (FIX #4). A tiny geometric floor: a
/// continued step must contract `h` by at least this fraction. Chosen small so
/// it never bites a genuinely converging row (Newton in the Kantorovich regime
/// is at least geometric and quadratic once `h < 1`, contracting `h` far faster
/// than `1/64` per step) while still forcing termination on a plateau.
const WARMUP_MIN_MULTIPLICATIVE_DECREASE: f64 = 1.0 / 64.0;
1220
/// Acceptance coefficient `κ` of the Kantorovich-quadratic path in the basin
/// warm-up (FIX #4). Once `h < 1`, a converging Newton step contracts
/// quadratically (`h_new ≲ κ·h²`). Accepting that path alongside the geometric
/// floor makes the "genuinely-converging rows are untouched" guarantee
/// explicit. The value is kept `< 1` so that the quadratic path is itself a
/// strict contraction (for `h < 1`, `κ·h² < κ·h < h`), which preserves the
/// warm-up's termination bound.
const WARMUP_QUADRATIC_KAPPA: f64 = 0.5;
1228
1229/// Sufficient-decrease test for the basin-warmup loop (FIX #4).
1230///
1231/// Returns `true` while the warm-up should keep stepping. The previous exit rule
1232/// (`h_new >= h_prev` — strict decrease or quit) never fires for an `h`-sequence
1233/// that decreases *monotonically toward a limit above ½*: the increments fall
1234/// below one ulp long before `h` crosses the certifiable ½ bound, so a single
1235/// pathological row could spin ~1e15 full `row_certificate` solves (Hessian
1236/// build + solve) on the encode hot path. We instead require genuine
1237/// *multiplicative* progress each step, which matches Newton's actual behavior:
1238/// a healthy contraction clears the geometric floor by a wide margin (and the
1239/// quadratic path once `h < 1`), while a plateau (`h_new/h_prev → 1`) satisfies
1240/// neither branch and flags to the exact fallback.
1241///
1242/// Termination bound: the warm-up loop runs only while the row is uncertified
1243/// (`h_prev > ½`), and every continued step contracts `h` by at least the factor
1244/// `(1 − WARMUP_MIN_MULTIPLICATIVE_DECREASE)` (the quadratic branch, for `h < 1`,
1245/// contracts by `κ·h_prev < κ < 1`, i.e. even harder). Hence after `N` continued
1246/// steps `h ≤ (1 − c)^N · h₀`, and since the loop stops once `h ≤ ½` we get
1247/// `N < ln(2·h₀) / −ln(1 − c)` — a few hundred iterations at most, versus
1248/// unbounded before. No arbitrary fixed step cap is imposed; the bound is a
1249/// consequence of the contraction requirement, in keeping with the function's
1250/// no-magic-budget design.
1251fn warmup_progress_sufficient(h_new: f64, h_prev: f64) -> bool {
1252 if !(h_new.is_finite() && h_prev.is_finite()) {
1253 return false;
1254 }
1255 // Geometric floor: a real Newton contraction easily clears this.
1256 if h_new <= (1.0 - WARMUP_MIN_MULTIPLICATIVE_DECREASE) * h_prev {
1257 return true;
1258 }
1259 // Kantorovich-quadratic path (only once `h < 1`, where it is a strict
1260 // contraction): makes the no-regression guarantee for converging rows
1261 // explicit without ever admitting a non-contracting (plateau) step.
1262 h_prev < 1.0 && h_new <= WARMUP_QUADRATIC_KAPPA * h_prev * h_prev
1263}
1264
1265fn certify_with_basin_warmup(
1266 atom: &SaeManifoldAtom,
1267 evaluator: &dyn SaeBasisEvaluator,
1268 t_start: Array1<f64>,
1269 x: ArrayView1<'_, f64>,
1270 amplitude: f64,
1271 lipschitz: f64,
1272 ridge: f64,
1273 newton_steps: usize,
1274 chart_center: ArrayView1<'_, f64>,
1275 chart_radius: f64,
1276) -> Result<Option<CertifiedEncodeProbe>, String> {
1277 // SOUNDNESS GUARD: `lipschitz` is the chart's Hessian-Lipschitz sup, which is
1278 // only a valid bound over this chart's ball `‖t − center‖ ≤ radius` for the
1279 // chart-local families (`EuclideanPatch`/`Linear`/`Poincare` monomial patches,
1280 // `Cylinder` line axis, `Duchon` radial kernels). If a warm-up iterate leaves
1281 // that ball, `row_certificate` would compute `h = β·η·L` with an `L` that no
1282 // longer bounds the true geometry there, so `h ≤ ½` would NOT imply Kantorovich
1283 // convergence — a false certificate. (The `h`-contraction check does NOT catch
1284 // this: `h` can decrease monotonically toward an out-of-chart root the whole
1285 // way.) So we keep every certified iterate inside the chart; a row whose root is
1286 // outside this chart flags for the exact fallback — its lever is a denser grid,
1287 // not a step using an invalid `L`. Global-`L` families (periodic/torus/sphere)
1288 // route their points to charts whose centers are near the root, so the guard
1289 // rarely trips for them, and where it does the row was out-of-chart anyway.
1290 let in_chart = |t: &Array1<f64>| -> bool {
1291 // Wrap-aware chart containment. A raw Euclidean latent distance
1292 // `Σ(tᵢ − centerᵢ)²` mis-measures separation across the wrap seam of a
1293 // periodic axis: an iterate at `t = 0.99` against a chart centered at
1294 // `0.01` reads distance `0.98` instead of the true circle distance
1295 // `0.02`, so it is wrongly rejected at the start check and at every step
1296 // check — silently disabling the amortized encoder for a phase-localized
1297 // band of rows on every periodic / torus / cylinder-angle /
1298 // sphere-longitude chart and pushing that band to the multi-start
1299 // fallback. `latent_coordinate_distance` folds each axis onto its
1300 // `latent_axis_period`, so the containment ball is measured in the true
1301 // manifold geometry — exactly the geometry the chart's Lipschitz bound
1302 // `L` is valid over. This only ever moves points that are genuinely
1303 // near the center (small wrapped distance) back INTO the chart where
1304 // `L` holds, so it widens acceptance without weakening the soundness
1305 // guard (an iterate truly far from the center on the circle still fails).
1306 latent_coordinate_distance(atom, t.view(), chart_center) <= chart_radius
1307 };
1308 let mut t = t_start;
1309 // The distilled / chart-center start must itself be in-chart for its certificate
1310 // to be valid; a bad IFT prediction landing outside the chart is uncertifiable.
1311 if !in_chart(&t) {
1312 return Ok(None);
1313 }
1314 let (mut cert, mut delta) =
1315 row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz, ridge)?;
1316 while !cert.certified() {
1317 // Not steppable (indefinite / non-finite Hessian): flag.
1318 if !(cert.h.is_finite() && cert.beta.is_finite() && cert.eta.is_finite()) {
1319 return Ok(None);
1320 }
1321 let prev_h = cert.h;
1322 let next = &t + δ
1323 // Refuse to step where the chart's `L` is no longer valid (see guard above).
1324 if !in_chart(&next) {
1325 return Ok(None);
1326 }
1327 t = next;
1328 let (next_cert, next_delta) =
1329 row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz, ridge)?;
1330 cert = next_cert;
1331 delta = next_delta;
1332 // The warm-up only helps while h keeps *multiplicatively* contracting
1333 // toward ½. A plain strict-decrease test (`h >= prev_h`) never fires for
1334 // a sequence that decreases monotonically toward a limit above ½, so it
1335 // could spin ~1e15 `row_certificate` solves for one row; require genuine
1336 // multiplicative progress instead (bounded to a few hundred steps, see
1337 // `warmup_progress_sufficient`). Once a step fails that bar the iterate is
1338 // not converging to a certifiable in-chart root — flag for the exact
1339 // fallback (no arbitrary step budget; the bound falls out of the
1340 // contraction requirement).
1341 if !warmup_progress_sufficient(cert.h, prev_h) {
1342 return Ok(None);
1343 }
1344 }
1345 refine_certified_start(
1346 atom,
1347 evaluator,
1348 t,
1349 x,
1350 amplitude,
1351 lipschitz,
1352 ridge,
1353 newton_steps,
1354 cert,
1355 delta,
1356 )
1357}
1358
1359fn kantorovich_root_radius(cert: RowCertificate) -> f64 {
1360 if !cert.certified() || !(cert.eta.is_finite() && cert.eta >= 0.0) {
1361 return f64::INFINITY;
1362 }
1363 if cert.eta == 0.0 {
1364 return 0.0;
1365 }
1366 if !(cert.h.is_finite() && cert.h >= 0.0) {
1367 return f64::INFINITY;
1368 }
1369 let h = cert.h.min(KANTOROVICH_THRESHOLD);
1370 let discriminant = (1.0 - 2.0 * h).max(0.0).sqrt();
1371 let radius = 2.0 * cert.eta / (1.0 + discriminant);
1372 if radius.is_finite() {
1373 radius
1374 } else {
1375 f64::INFINITY
1376 }
1377}
1378
1379fn distilled_probe_tolerance(
1380 amortized: &CertifiedEncodeProbe,
1381 cold: &CertifiedEncodeProbe,
1382 amplitude: f64,
1383 x: ArrayView1<'_, f64>,
1384) -> f64 {
1385 let certified_radius =
1386 kantorovich_root_radius(amortized.final_cert) + kantorovich_root_radius(cold.final_cert);
1387 let coord_scale = amortized.coord.dot(&amortized.coord).sqrt()
1388 + cold.coord.dot(&cold.coord).sqrt()
1389 + x.dot(&x).sqrt()
1390 + amplitude.abs()
1391 + 1.0;
1392 certified_radius + 1024.0 * f64::EPSILON * coord_scale
1393}
1394
1395fn latent_coordinate_distance(
1396 atom: &SaeManifoldAtom,
1397 lhs: ArrayView1<'_, f64>,
1398 rhs: ArrayView1<'_, f64>,
1399) -> f64 {
1400 let mut acc = 0.0;
1401 for axis in 0..lhs.len().min(rhs.len()) {
1402 let mut diff = (lhs[axis] - rhs[axis]).abs();
1403 if let Some(period) = latent_axis_period(atom, axis) {
1404 let wrapped = diff.rem_euclid(period);
1405 diff = wrapped.min(period - wrapped);
1406 }
1407 acc += diff * diff;
1408 }
1409 acc.sqrt()
1410}
1411
1412fn latent_axis_period(atom: &SaeManifoldAtom, axis: usize) -> Option<f64> {
1413 use crate::manifold::SaeAtomBasisKind::*;
1414 match &atom.basis_kind {
1415 Periodic | Torus => Some(1.0),
1416 Cylinder if axis == 0 => Some(1.0),
1417 Sphere if axis == 1 => Some(std::f64::consts::TAU),
1418 _ => None,
1419 }
1420}
1421
/// Settings for building an [`EncodeAtlas`] and for the online encode. Every
/// field is explicit; the atlas never reads global state and adds no CLI
/// flags.
#[derive(Debug, Clone, Copy)]
pub struct AtlasConfig {
    /// Number of offline chart centers per latent axis (the SHAPE_BAND grid
    /// idiom).
    pub grid_resolution: usize,
    /// Levenberg ridge floor added to the diagonal of the per-row Hessian
    /// before it is inverted.
    pub ridge: f64,
    /// Number of online Newton refinement steps taken after a certified start
    /// (1 or 2 per issue #1010).
    pub newton_steps: usize,
}

impl Default for AtlasConfig {
    /// Defaults: a 16-point grid per axis, ridge `1e-9`, two refinement steps.
    fn default() -> Self {
        AtlasConfig {
            grid_resolution: 16,
            ridge: 1.0e-9,
            newton_steps: 2,
        }
    }
}
1445
1446/// The encode atlas: per-atom certified charts plus the online certified-encode
1447/// driver (issue #1010).
1448#[derive(Debug, Clone)]
1449pub struct EncodeAtlas {
1450 pub atoms: Vec<AtomEncodeAtlas>,
1451 pub config: AtlasConfig,
1452}
1453
1454impl EncodeAtlas {
1455 /// Build the offline atlas over a frozen dictionary: for each atom, lay down
1456 /// chart centers on the atom's coordinate grid and certify a Newton radius
1457 /// from the Kantorovich inequality at the worst-case in-chart start.
1458 ///
1459 /// `amplitude_bound[k]` is the per-atom bound on `|z_k|` used to scale the
1460 /// reconstruction jets (the offline `L` must hold for the largest amplitude
1461 /// the encode can produce); `target_norm_bound` bounds `‖x‖` over the data.
1462 pub fn build(
1463 atoms: &[SaeManifoldAtom],
1464 amplitude_bound: &[f64],
1465 target_norm_bound: f64,
1466 config: AtlasConfig,
1467 ) -> Result<Self, String> {
1468 if amplitude_bound.len() != atoms.len() {
1469 return Err(format!(
1470 "EncodeAtlas::build: amplitude_bound length {} != atom count {}",
1471 amplitude_bound.len(),
1472 atoms.len()
1473 ));
1474 }
1475 let mut atom_atlases = Vec::with_capacity(atoms.len());
1476 for (k, atom) in atoms.iter().enumerate() {
1477 let atlas =
1478 Self::build_atom_atlas(k, atom, amplitude_bound[k], target_norm_bound, &config)?;
1479 atom_atlases.push(atlas);
1480 }
1481 Ok(Self {
1482 atoms: atom_atlases,
1483 config,
1484 })
1485 }
1486
1487 pub(crate) fn build_atom_atlas(
1488 atom_index: usize,
1489 atom: &SaeManifoldAtom,
1490 amplitude_bound: f64,
1491 target_norm_bound: f64,
1492 config: &AtlasConfig,
1493 ) -> Result<AtomEncodeAtlas, String> {
1494 let centers = chart_center_grid(atom, config.grid_resolution);
1495 // Half the inter-center spacing is the natural in-chart radius so the
1496 // charts tile the grid without gaps; refined below if the certificate
1497 // fails at that radius. One uniform radius for the regular grid.
1498 let nominal_radius = chart_nominal_radius(atom, config.grid_resolution);
1499 let radii = vec![nominal_radius; centers.nrows()];
1500 Self::build_atom_atlas_from_centers(
1501 atom_index,
1502 atom,
1503 centers.view(),
1504 &radii,
1505 amplitude_bound,
1506 target_norm_bound,
1507 config,
1508 )
1509 }
1510
1511 /// Build a per-atom atlas from EXPLICIT chart centers with a per-center
1512 /// nominal radius — the geometry-agnostic core shared by the regular-grid
1513 /// [`Self::build_atom_atlas`] and the data-driven [`Self::build_data_driven`].
1514 /// Every chart is certified identically (Kantorovich radius from the in-chart
1515 /// curvature at its center); only the center PLACEMENT and per-center radius
1516 /// differ. `radii[c]` is the nominal in-chart radius for `centers[c]`.
1517 pub(crate) fn build_atom_atlas_from_centers(
1518 atom_index: usize,
1519 atom: &SaeManifoldAtom,
1520 centers: ArrayView2<'_, f64>,
1521 radii: &[f64],
1522 amplitude_bound: f64,
1523 target_norm_bound: f64,
1524 config: &AtlasConfig,
1525 ) -> Result<AtomEncodeAtlas, String> {
1526 let d = atom.latent_dim;
1527 if centers.ncols() != d {
1528 return Err(format!(
1529 "build_atom_atlas_from_centers: centers have {} cols but atom latent_dim is {d}",
1530 centers.ncols()
1531 ));
1532 }
1533 if radii.len() != centers.nrows() {
1534 return Err(format!(
1535 "build_atom_atlas_from_centers: {} radii != {} centers",
1536 radii.len(),
1537 centers.nrows()
1538 ));
1539 }
1540 let decoder_norm_sum = decoder_row_norm_sum(atom.decoder_coefficients.view());
1541 let mut charts = Vec::with_capacity(centers.nrows());
1542 for c in 0..centers.nrows() {
1543 let center = centers.row(c).to_owned();
1544 let nominal_radius = radii[c];
1545 let region = chart_region(atom, center.clone(), nominal_radius);
1546 let sups = family_jet_sups(atom, ®ion)?;
1547 let recon_sups = reconstruction_jet_sups(atom, sups);
1548 let lipschitz =
1549 hessian_lipschitz_constant(recon_sups, amplitude_bound, target_norm_bound, 0.0);
1550 // β at the chart center bounds the worst-case in-chart curvature
1551 // (the Gauss-Newton Hessian is continuous; the certified radius is
1552 // solved so the certificate is robust to the start within the ball).
1553 let beta_center = match center_beta(atom, ¢er, config.ridge) {
1554 Some(b) => b,
1555 None => {
1556 // Degenerate center curvature: no certifiable chart here, and
1557 // no amortized Jacobian (the same singular Gauss–Newton block).
1558 charts.push(CertifiedChart {
1559 region,
1560 lipschitz,
1561 beta_center: f64::INFINITY,
1562 certified_radius: 0.0,
1563 amortized_jacobian: None,
1564 recon_center: Array1::<f64>::zeros(atom.output_dim()),
1565 amortized_base: None,
1566 });
1567 continue;
1568 }
1569 };
1570 // Distill the amortized-encoder Jacobian at this center (#1026 ladder
1571 // item 3): the IFT derivative of the encode map, precomputed offline
1572 // so the online encode is one mat-vec. A finite `beta_center` (above)
1573 // means the Gauss–Newton block is non-singular, so this succeeds
1574 // alongside it; the pair travels together on the chart.
1575 let (amortized_jacobian, recon_center) =
1576 match center_amortized_jacobian(atom, ¢er, config.ridge) {
1577 Some((a1, m1)) => (Some(a1), m1),
1578 None => (None, Array1::<f64>::zeros(atom.output_dim())),
1579 };
1580 // Certified radius from h = β·η·L ≤ ½ with η ≤ R (Newton step length
1581 // is bounded by the start distance to the root, itself ≤ chart
1582 // radius at worst): R_c = ½ / (β·L), capped at the nominal radius.
1583 let certified_radius = if lipschitz > 0.0 && beta_center.is_finite() {
1584 (0.5 / (beta_center * lipschitz)).min(region.radius)
1585 } else {
1586 region.radius
1587 };
1588 // Precompute the affine-predictor constant `base = t_c − A₁·m₁` (atom-
1589 // static), so the online encode is a single `base + (1/z)·A₁·x` mat-vec.
1590 let amortized_base = amortized_jacobian
1591 .as_ref()
1592 .map(|a1| ¢er - &a1.dot(&recon_center));
1593 charts.push(CertifiedChart {
1594 region,
1595 lipschitz,
1596 beta_center,
1597 certified_radius,
1598 amortized_jacobian,
1599 recon_center,
1600 amortized_base,
1601 });
1602 }
1603 Ok(AtomEncodeAtlas {
1604 atom_index,
1605 latent_dim: d,
1606 decoder_norm_sum,
1607 charts,
1608 })
1609 }
1610
1611 /// Build the atlas with DATA-DRIVEN chart placement: instead of a dense
1612 /// `resolution^d` product grid (exponential in latent dim `d`, so the regular
1613 /// [`Self::build`] is forced to coarse, poorly-certified charts for `d ≥ 3`),
1614 /// place a bounded number of charts AT the data's own latent coordinates. The
1615 /// chart count is then `O(max_charts)` regardless of `d`, and every chart sits
1616 /// where data actually lands (small in-chart residual → certifies), so
1617 /// higher-dimensional atoms — which reconstruct real activations far better per
1618 /// parameter — become affordable and well-covered.
1619 ///
1620 /// `coords[k]` is atom `k`'s `n × d_k` latent coordinates (the seed coords, or
1621 /// a previous encode's output). Charts are chosen by greedy farthest-point
1622 /// sampling over those coords (deterministic, coverage-maximizing), capped at
1623 /// `max_charts`. Each chart's nominal radius is half the distance to its
1624 /// nearest neighbor center, so the charts tile the local data density. The
1625 /// per-chart Kantorovich certification is IDENTICAL to the regular grid — only
1626 /// the center placement differs.
1627 pub fn build_data_driven(
1628 atoms: &[SaeManifoldAtom],
1629 coords: &[Array2<f64>],
1630 amplitude_bound: &[f64],
1631 target_norm_bound: f64,
1632 max_charts: usize,
1633 config: AtlasConfig,
1634 ) -> Result<Self, String> {
1635 if amplitude_bound.len() != atoms.len() || coords.len() != atoms.len() {
1636 return Err(format!(
1637 "build_data_driven: amplitude_bound {} / coords {} must match atom count {}",
1638 amplitude_bound.len(),
1639 coords.len(),
1640 atoms.len()
1641 ));
1642 }
1643 let mut atom_atlases = Vec::with_capacity(atoms.len());
1644 for (k, atom) in atoms.iter().enumerate() {
1645 let (centers, radii) =
1646 data_driven_chart_centers(atom, coords[k].view(), max_charts.max(1))?;
1647 let atlas = Self::build_atom_atlas_from_centers(
1648 k,
1649 atom,
1650 centers.view(),
1651 &radii,
1652 amplitude_bound[k],
1653 target_norm_bound,
1654 &config,
1655 )?;
1656 atom_atlases.push(atlas);
1657 }
1658 Ok(Self {
1659 atoms: atom_atlases,
1660 config,
1661 })
1662 }
1663
1664 fn refine_certified_encode_start(
1665 &self,
1666 atom: &SaeManifoldAtom,
1667 evaluator: &dyn SaeBasisEvaluator,
1668 chart: &CertifiedChart,
1669 t: Array1<f64>,
1670 x: ArrayView1<'_, f64>,
1671 amplitude: f64,
1672 ) -> Result<(Array1<f64>, RowCertificate), String> {
1673 // Certify from the warm start, navigating into the Kantorovich basin first
1674 // if the unit-amplitude start has h > ½ (see `certify_with_basin_warmup`).
1675 let Some(probe) = certify_with_basin_warmup(
1676 atom,
1677 evaluator,
1678 t,
1679 x,
1680 amplitude,
1681 chart.lipschitz,
1682 self.config.ridge,
1683 self.config.newton_steps,
1684 chart.region.center.view(),
1685 chart.region.radius,
1686 )?
1687 else {
1688 return Ok((
1689 Array1::<f64>::zeros(atom.latent_dim),
1690 uncertified_certificate(chart.lipschitz),
1691 ));
1692 };
1693 Ok((probe.coord, probe.initial_cert))
1694 }
1695
1696 /// Online certified encode of one target row `x` against one atom `k` with
1697 /// fixed amplitude `z`. Routes to the nearest chart, starts from that chart's
1698 /// distilled IFT warm start, runs `config.newton_steps` Newton steps, and
1699 /// returns the encoded coordinate with its certificate. An uncertified start
1700 /// (no chart, no distilled Jacobian, non-positive amplitude, or `h > ½`)
1701 /// flags the row for the exact multi-start caller.
1702 pub fn certified_encode_row(
1703 &self,
1704 atom: &SaeManifoldAtom,
1705 atom_index: usize,
1706 x: ArrayView1<'_, f64>,
1707 amplitude: f64,
1708 ) -> Result<(Array1<f64>, RowCertificate), String> {
1709 let atom_atlas = self
1710 .atoms
1711 .get(atom_index)
1712 .ok_or_else(|| format!("certified_encode_row: atom {atom_index} not in atlas"))?;
1713 let d = atom.latent_dim;
1714 // A missing basis evaluator means the amortized/cold predictor cannot fire
1715 // for this atom (e.g. a frozen-baseline or first-build atom that never
1716 // attached a distilled evaluator). That is exactly the "cannot certify"
1717 // state — flag the row uncertified (zeros coords, ∞ certificate) so the
1718 // upstream exact multi-start solve owns it, never a hard error that aborts
1719 // the whole criterion. Mirrors the no-chart / singular-Jacobian branches.
1720 let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
1721 return Ok((
1722 Array1::<f64>::zeros(d),
1723 RowCertificate {
1724 beta: f64::INFINITY,
1725 eta: f64::INFINITY,
1726 lipschitz: f64::INFINITY,
1727 h: f64::INFINITY,
1728 },
1729 ));
1730 };
1731
1732 // Route to the nearest chart centers by AMBIENT reconstruction distance.
1733 // A single nearest chart is NOT globally sound on self-approaching atoms:
1734 // where the decoded manifold folds near itself (two distant latent points
1735 // map near the same output), the nearest-center chart can certify into the
1736 // locally-worse basin while another chart holds the GLOBAL minimum (both
1737 // branches' charts reconstruct near the crossing, so both are near in
1738 // ambient distance). The certificate is honest about LOCAL convergence but
1739 // cannot see the better far basin. So we refine in the top-K nearest charts
1740 // and keep the lowest-reconstruction-error CERTIFIED result. For a unimodal
1741 // atom every candidate chart converges to the same root, so this is a no-op
1742 // (first-wins tie → the nearest chart), preserving the existing behavior.
1743 let candidates =
1744 nearest_charts_topk(atom_atlas, x, CERTIFIED_ROUTING_TOPK);
1745 if candidates.is_empty() {
1746 return Ok((
1747 Array1::<f64>::zeros(d),
1748 RowCertificate {
1749 beta: f64::INFINITY,
1750 eta: f64::INFINITY,
1751 lipschitz: f64::INFINITY,
1752 h: f64::INFINITY,
1753 },
1754 ));
1755 }
1756 // Best CERTIFIED result by reconstruction error, plus the nearest chart's
1757 // result as the uncertified fallback (preserving the prior return when no
1758 // candidate certifies — the nearest chart owns the flagged row).
1759 let mut best: Option<(Array1<f64>, RowCertificate, f64)> = None;
1760 let mut nearest_fallback: Option<(Array1<f64>, RowCertificate)> = None;
1761 for chart_idx in candidates {
1762 let chart = &atom_atlas.charts[chart_idx];
1763 let Some(t) = amortized_warm_start(chart, x, amplitude) else {
1764 if nearest_fallback.is_none() {
1765 nearest_fallback =
1766 Some((Array1::<f64>::zeros(d), uncertified_certificate(chart.lipschitz)));
1767 }
1768 continue;
1769 };
1770 let (coord, cert) = self.refine_certified_encode_start(
1771 atom,
1772 evaluator.as_ref(),
1773 chart,
1774 t,
1775 x,
1776 amplitude,
1777 )?;
1778 if nearest_fallback.is_none() {
1779 nearest_fallback = Some((coord.clone(), cert.clone()));
1780 }
1781 if cert.certified() {
1782 let err =
1783 encode_reconstruction_error(atom, evaluator.as_ref(), coord.view(), x, amplitude);
1784 if best.as_ref().map(|(_, _, e)| err < *e).unwrap_or(true) {
1785 best = Some((coord, cert, err));
1786 }
1787 // Global-minimum short-circuit: reconstruction error ≥ 0, so a
1788 // certified candidate already at the ambient noise floor is provably
1789 // the global optimum over the remaining charts — stop refining them.
1790 if let Some((_, _, e)) = best.as_ref() {
1791 if *e <= CERTIFIED_GLOBAL_MIN_RECON_FLOOR * (1.0 + x.dot(&x).sqrt()) {
1792 break;
1793 }
1794 }
1795 }
1796 }
1797 match best {
1798 Some((coord, cert, _)) => Ok((coord, cert)),
1799 None => Ok(nearest_fallback.unwrap_or_else(|| {
1800 (
1801 Array1::<f64>::zeros(d),
1802 RowCertificate {
1803 beta: f64::INFINITY,
1804 eta: f64::INFINITY,
1805 lipschitz: f64::INFINITY,
1806 h: f64::INFINITY,
1807 },
1808 )
1809 })),
1810 }
1811 }
1812
1813 /// Amortized (distilled) encode of one target row `x` against one atom `k`
1814 /// with fixed amplitude `z` (#1026 ladder item 3).
1815 ///
1816 /// Routes to the nearest chart, then predicts the latent coordinate in CLOSED
1817 /// FORM from that chart's precomputed implicit-function-theorem Jacobian:
1818 ///
1819 /// ```text
1820 /// t̂ = t_c + (1/z) · A₁ · (x − z · m₁(t_c)),
1821 /// ```
1822 ///
1823 /// a single `O(d·p)` mat-vec — no per-row Hessian factorization or
1824 /// eigendecomposition, which is the amortization. The Kantorovich
1825 /// certificate is then evaluated AT the predicted start `t̂` with the chart's
1826 /// closed-form Lipschitz constant. A prediction is accepted only when that
1827 /// certificate holds, an independent cold chart-center probe also certifies,
1828 /// and the two refined coordinates agree within the two probes' final
1829 /// Kantorovich root-radius bounds. This keeps the distilled path honest
1830 /// without letting the exact probe reuse the distilled warm start it is
1831 /// auditing. A chart without a distilled Jacobian (singular Gauss–Newton
1832 /// block) flags the row.
1833 pub fn amortized_encode_row(
1834 &self,
1835 atom: &SaeManifoldAtom,
1836 atom_index: usize,
1837 x: ArrayView1<'_, f64>,
1838 amplitude: f64,
1839 ) -> Result<(Array1<f64>, RowCertificate), String> {
1840 let atom_atlas = self
1841 .atoms
1842 .get(atom_index)
1843 .ok_or_else(|| format!("amortized_encode_row: atom {atom_index} not in atlas"))?;
1844 let d = atom.latent_dim;
1845 let uncertified = || {
1846 (
1847 Array1::<f64>::zeros(d),
1848 RowCertificate {
1849 beta: f64::INFINITY,
1850 eta: f64::INFINITY,
1851 lipschitz: f64::INFINITY,
1852 h: f64::INFINITY,
1853 },
1854 )
1855 };
1856 // A missing basis evaluator means the distilled predictor cannot fire for
1857 // this atom — flag the row uncertified (the exact upstream solve owns it)
1858 // rather than erroring, exactly as the no-chart / singular-Jacobian /
1859 // non-positive-amplitude branches below do. Never a silent wrong encode,
1860 // never a hard abort of the criterion.
1861 let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
1862 return Ok(uncertified());
1863 };
1864 let Some((chart_idx, _)) = nearest_chart(atom_atlas, x) else {
1865 return Ok(uncertified());
1866 };
1867 let chart = &atom_atlas.charts[chart_idx];
1868 // Closed-form predicted start t̂ = t_c + (1/z)·A₁·(x − z·m₁). `None` when
1869 // the chart's Gauss–Newton block was singular (no distilled Jacobian, so
1870 // the amortized predictor cannot fire) or the amplitude is not strictly
1871 // positive and finite (a near-inactive atom, where the amplitude-divided
1872 // map is undefined) — either way flag for the exact fallback, never a
1873 // silent wrong encode.
1874 let Some(t_hat) = amortized_warm_start(chart, x, amplitude) else {
1875 return Ok(uncertified());
1876 };
1877 // Evaluate the SAME Kantorovich certificate at the predicted start. The
1878 // amortized prediction is trusted only if this certificate holds AND an
1879 // independent cold chart-center probe certifies and agrees below the
1880 // two probes' final Kantorovich root-radius bounds. This avoids the
1881 // self-referential gate where the "exact" probe is warm-started by the
1882 // same distilled prediction it is supposed to audit.
1883 let Some(amortized_probe) = certify_with_basin_warmup(
1884 atom,
1885 evaluator.as_ref(),
1886 t_hat,
1887 x,
1888 amplitude,
1889 chart.lipschitz,
1890 self.config.ridge,
1891 self.config.newton_steps,
1892 chart.region.center.view(),
1893 chart.region.radius,
1894 )?
1895 else {
1896 return Ok((
1897 Array1::<f64>::zeros(d),
1898 uncertified_certificate(chart.lipschitz),
1899 ));
1900 };
1901
1902 let cold_start = chart.region.center.clone();
1903 let Some(cold_probe) = certify_with_basin_warmup(
1904 atom,
1905 evaluator.as_ref(),
1906 cold_start,
1907 x,
1908 amplitude,
1909 chart.lipschitz,
1910 self.config.ridge,
1911 self.config.newton_steps,
1912 chart.region.center.view(),
1913 chart.region.radius,
1914 )?
1915 else {
1916 return Ok((
1917 amortized_probe.coord,
1918 uncertified_certificate(chart.lipschitz),
1919 ));
1920 };
1921
1922 let gap =
1923 latent_coordinate_distance(atom, amortized_probe.coord.view(), cold_probe.coord.view());
1924 let tolerance = distilled_probe_tolerance(&amortized_probe, &cold_probe, amplitude, x);
1925 if !(gap.is_finite() && gap <= tolerance) {
1926 return Ok((
1927 amortized_probe.coord,
1928 uncertified_certificate(chart.lipschitz),
1929 ));
1930 }
1931 Ok((amortized_probe.coord, amortized_probe.initial_cert))
1932 }
1933
1934 /// Batched amortized (distilled) encode over many rows against one atom
1935 /// (#1026 ladder item 3, corpus-rate). Each row uses the closed-form
1936 /// per-chart Jacobian predictor and carries its own Kantorovich certificate;
1937 /// uncertified rows are flagged in [`EncodeResult::encode_uncertified_count`]
1938 /// for the exact multi-start fallback. Row-independent against the frozen
1939 /// dictionary, so the batch fans out over rows (deterministic row-order
1940 /// assembly, bit-identical run-to-run), staying sequential inside a rayon
1941 /// worker to avoid nested oversubscription.
1942 pub fn amortized_encode_batch(
1943 &self,
1944 atom: &SaeManifoldAtom,
1945 atom_index: usize,
1946 targets: ArrayView2<'_, f64>,
1947 amplitudes: ArrayView1<'_, f64>,
1948 ) -> Result<EncodeResult, String> {
1949 let n = targets.nrows();
1950 if amplitudes.len() != n {
1951 return Err(format!(
1952 "amortized_encode_batch: amplitudes len {} != rows {n}",
1953 amplitudes.len()
1954 ));
1955 }
1956 let d = atom.latent_dim;
1957 let encode_rows =
1958 |range: std::ops::Range<usize>| -> Result<Vec<(Array1<f64>, bool)>, String> {
1959 range
1960 .map(|row| {
1961 let (t, cert) = self.amortized_encode_row(
1962 atom,
1963 atom_index,
1964 targets.row(row),
1965 amplitudes[row],
1966 )?;
1967 Ok((t, cert.certified()))
1968 })
1969 .collect()
1970 };
1971 let rows: Vec<(Array1<f64>, bool)> =
1972 if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
1973 use rayon::prelude::*;
1974 const CHUNK: usize = 256;
1975 let n_chunks = n.div_ceil(CHUNK);
1976 let chunked: Vec<Vec<(Array1<f64>, bool)>> = (0..n_chunks)
1977 .into_par_iter()
1978 .map(|c| {
1979 let start = c * CHUNK;
1980 let end = (start + CHUNK).min(n);
1981 encode_rows(start..end)
1982 })
1983 .collect::<Result<_, _>>()?;
1984 chunked.into_iter().flatten().collect()
1985 } else {
1986 encode_rows(0..n)?
1987 };
1988 let mut coords = Array2::<f64>::zeros((n, d));
1989 let mut certified = Vec::with_capacity(n);
1990 for (row, (t, cert)) in rows.into_iter().enumerate() {
1991 coords.row_mut(row).assign(&t);
1992 certified.push(cert);
1993 }
1994 Ok(EncodeResult::from_rows(coords, certified))
1995 }
1996
1997 /// Batched certified encode over many rows against one atom (the #988
1998 /// throughput consumer). Each row carries its own certificate; uncertified
1999 /// rows are flagged in [`EncodeResult::encode_uncertified_count`] for the
2000 /// exact multi-start fallback.
2001 pub fn certified_encode_batch(
2002 &self,
2003 atom: &SaeManifoldAtom,
2004 atom_index: usize,
2005 targets: ArrayView2<'_, f64>,
2006 amplitudes: ArrayView1<'_, f64>,
2007 ) -> Result<EncodeResult, String> {
2008 let n = targets.nrows();
2009 if amplitudes.len() != n {
2010 return Err(format!(
2011 "certified_encode_batch: amplitudes len {} != rows {n}",
2012 amplitudes.len()
2013 ));
2014 }
2015 let d = atom.latent_dim;
2016 // Per-row encode is independent against a frozen dictionary (#1010), so
2017 // the corpus-rate batch fans out over rows (#1026 amortized-encoder leg /
2018 // #977 Stage-3 corpus encode). Each row produces an owned `(t, certified)`
2019 // pair; results are assembled back in row order so the output is
2020 // bit-identical run-to-run regardless of thread scheduling. Stay
2021 // sequential inside a rayon worker (e.g. when an outer atom-level fan-out
2022 // owns the pool) to avoid nested oversubscription. The first row that
2023 // fails to encode propagates its error deterministically.
2024 let encode_rows =
2025 |range: std::ops::Range<usize>| -> Result<Vec<(Array1<f64>, bool)>, String> {
2026 range
2027 .map(|row| {
2028 let (t, cert) = self.certified_encode_row(
2029 atom,
2030 atom_index,
2031 targets.row(row),
2032 amplitudes[row],
2033 )?;
2034 Ok((t, cert.certified()))
2035 })
2036 .collect()
2037 };
2038 let rows: Vec<(Array1<f64>, bool)> =
2039 if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
2040 use rayon::prelude::*;
2041 const CHUNK: usize = 256;
2042 let n_chunks = n.div_ceil(CHUNK);
2043 let chunked: Vec<Vec<(Array1<f64>, bool)>> = (0..n_chunks)
2044 .into_par_iter()
2045 .map(|c| {
2046 let start = c * CHUNK;
2047 let end = (start + CHUNK).min(n);
2048 encode_rows(start..end)
2049 })
2050 .collect::<Result<_, _>>()?;
2051 chunked.into_iter().flatten().collect()
2052 } else {
2053 encode_rows(0..n)?
2054 };
2055 let mut coords = Array2::<f64>::zeros((n, d));
2056 let mut certified = Vec::with_capacity(n);
2057 for (row, (t, cert)) in rows.into_iter().enumerate() {
2058 coords.row_mut(row).assign(&t);
2059 certified.push(cert);
2060 }
2061 Ok(EncodeResult::from_rows(coords, certified))
2062 }
2063
2064 /// Batched GEMM "fast" amortized encode — the traditional-encoder forward
2065 /// pass, WITH manifolds. For every row this applies the SAME closed-form
2066 /// affine predictor as [`amortized_warm_start`]
2067 /// (`t̂ = t_c + (1/z)·A₁·(x − z·m₁)`), but routed and applied as batched
2068 /// matrix products instead of a per-row loop wrapped in the Kantorovich
2069 /// certificate + basin warmup. NO per-row certificate is taken: this is the
2070 /// speed mode (the certified `*_encode_*` paths remain the accuracy mode).
2071 ///
2072 /// Cost is GEMM-bound: one `(n × p)·(p × d)` decode-distance product for
2073 /// nearest-chart routing (skipped for single-chart atoms) plus, per chart,
2074 /// one `(n_c × p)·(p × d)` predictor product — i.e. `≈ X·Wᵀ`, exactly a
2075 /// dense SAE encoder's forward map.
2076 ///
2077 /// Degenerate rows are handled exactly as `amortized_warm_start` flags them
2078 /// (returns `None` ⇒ zeroed coord here): a missing basis evaluator, a chart
2079 /// whose Gauss–Newton block was singular (`amortized_jacobian == None`), or a
2080 /// non-finite / non-positive amplitude. Those rows are zeroed (never a panic,
2081 /// never a silent wrong encode), and their indices are returned in the
2082 /// `valid` mask so the caller can route them to the exact path if desired.
2083 ///
2084 /// Returns `(coords, valid)` where `coords` is `n × d` and `valid[row]` is
2085 /// `true` iff the amortized predictor fired for that row.
2086 pub fn amortized_encode_batch_fast(
2087 &self,
2088 atom: &SaeManifoldAtom,
2089 atom_index: usize,
2090 x: ArrayView2<'_, f64>,
2091 amplitudes: ArrayView1<'_, f64>,
2092 ) -> Result<(Array2<f64>, Vec<bool>), String> {
2093 let n = x.nrows();
2094 let p = atom.output_dim();
2095 let d = atom.latent_dim;
2096 if x.ncols() != p {
2097 return Err(format!(
2098 "amortized_encode_batch_fast: x has {} cols but atom output dim is {p}",
2099 x.ncols()
2100 ));
2101 }
2102 if amplitudes.len() != n {
2103 return Err(format!(
2104 "amortized_encode_batch_fast: amplitudes len {} != rows {n}",
2105 amplitudes.len()
2106 ));
2107 }
2108 let atom_atlas = self.atoms.get(atom_index).ok_or_else(|| {
2109 format!("amortized_encode_batch_fast: atom {atom_index} not in atlas")
2110 })?;
2111 let mut coords = Array2::<f64>::zeros((n, d));
2112 let mut valid = vec![false; n];
2113
2114 // A missing basis evaluator means this atom never had a well-formed atlas
2115 // built here — treat every row as uncertified (zeroed), exactly like the
2116 // per-row `amortized_encode_row` no-evaluator branch. (The predictor below
2117 // uses only cached atlas data, so no evaluator call is made online.)
2118 if atom.basis_evaluator.is_none() {
2119 return Ok((coords, valid));
2120 }
2121
2122 // ── Routing recon-centers (cached, no online basis evaluation). ────────
2123 // Routing sends a row to the chart whose center reconstruction
2124 // `m(t_c) = BᵀΦ(t_c)` is closest in ‖·‖². Those center reconstructions are
2125 // OFFLINE-cached in `chart.recon_center` (bit-identical to re-evaluating the
2126 // basis at the fixed centers — same φ·decoder accumulation). Gather the cache
2127 // instead of calling `evaluator.evaluate` on every invocation: that per-call
2128 // chart-center evaluation was the dominant per-atom-group overhead at massive
2129 // K, where N rows scatter across many atoms into tiny groups so a fixed
2130 // per-call cost is amortized over only a handful of rows. This is what keeps
2131 // the fast index-routed encode near-flat as K grows.
2132 let valid_charts: Vec<usize> = (0..atom_atlas.charts.len())
2133 .filter(|&c| atom_atlas.charts[c].certified_radius > 0.0)
2134 .collect();
2135 if valid_charts.is_empty() {
2136 return Ok((coords, valid));
2137 }
2138 // recon_centers (C × p): the cached m(t_c) for each certifiable chart.
2139 let mut recon_centers = Array2::<f64>::zeros((valid_charts.len(), p));
2140 for (ci, &c) in valid_charts.iter().enumerate() {
2141 recon_centers
2142 .row_mut(ci)
2143 .assign(&atom_atlas.charts[c].recon_center);
2144 }
2145 // Per-chart routing key: route_idx[row] = argmin_c ‖x_row − recon_c‖².
2146 // ‖x − r‖² = ‖x‖² − 2 x·r + ‖r‖²; the ‖x‖² term is row-constant so the
2147 // argmin uses S = X·recon_centersᵀ and the per-chart ‖r‖². First chart
2148 // wins on a tie (strict `<`), matching `nearest_chart`.
2149 let route_idx: Vec<usize> = if valid_charts.len() == 1 {
2150 vec![0usize; n]
2151 } else {
2152 let s = x.dot(&recon_centers.t()); // (n × C)
2153 let r_sq: Vec<f64> = (0..valid_charts.len())
2154 .map(|c| recon_centers.row(c).dot(&recon_centers.row(c)))
2155 .collect();
2156 (0..n)
2157 .map(|row| {
2158 let mut best_c = 0usize;
2159 let mut best_d = f64::INFINITY;
2160 for c in 0..valid_charts.len() {
2161 let dist = r_sq[c] - 2.0 * s[[row, c]];
2162 if dist < best_d {
2163 best_d = dist;
2164 best_c = c;
2165 }
2166 }
2167 best_c
2168 })
2169 .collect()
2170 };
2171
2172 // ── Per-chart batched affine predictor. ───────────────────────────────
2173 // For rows routed to chart `c` with finite jacobian `A₁` (d × p) and
2174 // center reconstruction `m₁` (= `chart.recon_center`), the predictor is
2175 // t̂ = t_c − A₁·m₁ + (1/z)·(A₁·x).
2176 // `t_c − A₁·m₁` is a per-chart constant `base`; `A₁·x` is a d-vector of
2177 // per-row dot products. Instead of gathering routed rows into a fresh
2178 // `X_c` (n_c × p) buffer and running a GEMM into a second `U` (n_c × d)
2179 // buffer — two allocations plus a full copy of the routed rows, per chart —
2180 // fuse the gather straight into the multiply: stream each source row of `x`
2181 // once (it is contiguous) and dot it against `A₁`'s rows, writing the
2182 // predicted coord directly. Zero per-chart heap traffic; the inverse
2183 // amplitude is hoisted to one reciprocal per row.
2184 //
2185 // Precompute each valid chart's `(A₁, base)` once (charts with a singular
2186 // Gauss–Newton block carry no `A₁`, so their routed rows stay
2187 // zeroed/uncertified — same as `amortized_warm_start` returning `None`).
2188 struct ChartPredictor<'a> {
2189 a1: &'a Array2<f64>,
2190 base: &'a Array1<f64>,
2191 }
2192 let predictors: Vec<Option<ChartPredictor<'_>>> = valid_charts
2193 .iter()
2194 .map(|&c| {
2195 let chart = &atom_atlas.charts[c];
2196 // `base = t_c − A₁·m₁` is precomputed offline in the atlas; reuse it
2197 // (both are `Some` together — singular G-N block ⇒ both `None`).
2198 match (chart.amortized_jacobian.as_ref(), chart.amortized_base.as_ref()) {
2199 (Some(a1), Some(base)) => Some(ChartPredictor { a1, base }),
2200 _ => None,
2201 }
2202 })
2203 .collect();
2204
2205 for row in 0..n {
2206 let Some(pred) = predictors[route_idx[row]].as_ref() else {
2207 continue;
2208 };
2209 let amp = amplitudes[row];
2210 if !(amp.is_finite() && amp.abs() > 0.0) {
2211 continue;
2212 }
2213 let inv_z = 1.0 / amp;
2214 let x_row = x.row(row);
2215 let mut coord_row = coords.row_mut(row);
2216 for axis in 0..d {
2217 // (A₁·x)[axis] = A₁ row `axis` (contiguous, length p) · x_row.
2218 coord_row[axis] = pred.base[axis] + pred.a1.row(axis).dot(&x_row) * inv_z;
2219 }
2220 valid[row] = true;
2221 }
2222 Ok((coords, valid))
2223 }
2224
2225 /// Fast batched FULL forward pass against one atom: encode → decode, the
2226 /// manifold analogue of a traditional SAE's `x̂ = z·D` (decoder `D`, code `z`).
2227 ///
2228 /// A traditional SAE decodes with one GEMM. The manifold SAE's reconstruction
2229 /// is `m(t̂) = z·Φ(t̂)·B` (module header) — the SAME GEMM `Φ·B`, but the code
2230 /// `Φ(t̂)` is the curved chart basis evaluated at the encoded latent coordinate
2231 /// rather than a flat one-hot. So the fast forward is exactly:
2232 /// 1. [`amortized_encode_batch_fast`] → per-row latent coords `t̂` (one
2233 /// routing GEMM + one affine GEMM per chart — a traditional `W·x+b`);
2234 /// 2. ONE batched basis evaluation `Φ(t̂)` (the manifold-curvature step a
2235 /// flat SAE doesn't have — `n×m`);
2236 /// 3. ONE GEMM `recon = Φ(t̂)·B` (`(n×m)·(m×p)` — a traditional decoder
2237 /// `z·D`), then the per-row amplitude scale `z`.
2238 ///
2239 /// Rows the encoder could not certify-predict (no evaluator / singular
2240 /// Gauss–Newton block / non-finite-or-zero amplitude) are returned as a ZERO
2241 /// reconstruction and flagged `false` in the valid-mask — never a silent wrong
2242 /// decode. The reconstruction of a valid row equals, bit-for-bit up to GEMM
2243 /// reassociation, `z·(Φ(t̂_row)·B)` with `t̂` from the per-row predictor.
2244 pub fn amortized_reconstruct_batch_fast(
2245 &self,
2246 atom: &SaeManifoldAtom,
2247 atom_index: usize,
2248 x: ArrayView2<'_, f64>,
2249 amplitudes: ArrayView1<'_, f64>,
2250 ) -> Result<(Array2<f64>, Vec<bool>), String> {
2251 let n = x.nrows();
2252 let p = atom.output_dim();
2253 // Step 1: batched encode → latent coords (reuses the fast routing+affine).
2254 let (coords, valid) = self.amortized_encode_batch_fast(atom, atom_index, x, amplitudes)?;
2255 let mut recon = Array2::<f64>::zeros((n, p));
2256 // A missing evaluator means no row could encode — every row is zeroed and
2257 // already flagged `false` by the encode; nothing to decode.
2258 let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
2259 return Ok((recon, valid));
2260 };
2261 // Step 2: ONE batched basis evaluation Φ(t̂) over all rows (n × m). Invalid
2262 // rows carry coords = 0 (the chart-origin); we still evaluate them in the
2263 // batch for a single GEMM, then zero their reconstruction below — the basis
2264 // is finite at the origin so this cannot poison the valid rows' GEMM.
2265 let (phi, _jet) = evaluator
2266 .evaluate(coords.view())
2267 .map_err(|err| format!("amortized_reconstruct_batch_fast: basis eval: {err}"))?;
2268 // Step 3: ONE GEMM recon = Φ·B (n × p), then per-row amplitude scale z.
2269 // m(t̂) = z·Φ(t̂)·B, matching the module header and `fill_decoded_row`'s
2270 // `Φ·decoder` accumulation (the amplitude is applied once here).
2271 let decoded = phi.dot(&atom.decoder_coefficients); // (n × p), amplitude-1
2272 for row in 0..n {
2273 if !valid[row] {
2274 continue; // stays zeroed — uncertified, like warm_start `None`.
2275 }
2276 let z = amplitudes[row];
2277 for col in 0..p {
2278 recon[[row, col]] = z * decoded[[row, col]];
2279 }
2280 }
2281 Ok((recon, valid))
2282 }
2283
2284 /// LSH-routed certified encode (issue #1010 step 2 + 3): for each target
2285 /// row, the existing [`SaeCandidateIndex`] (#985/#994) proposes the
2286 /// best-aligned atom by frame alignment to the row direction; the row is then
2287 /// encoded against THAT atom's certified chart atlas. This is the production
2288 /// routing path. Atom selection is EXACT (#1777): [`SaeCandidateIndex::route_exact`]
2289 /// returns the global argmax of the routing score (the universal-bound LSH fast
2290 /// path, else a full-scan fallback) — never a silently-missed ungathered atom —
2291 /// and the atlas does the in-atom nearest-chart routing and the per-row
2292 /// Kantorovich certificate.
2293 ///
2294 /// `atoms[id]` must be aligned with the atlas's `atoms[id]` (same dictionary
2295 /// order the atlas was built from and the sketch/index were built over).
2296 /// A row over an empty dictionary, or whose globally-best atom aligns below the
2297 /// fit-quality floor, is flagged uncertified — it routes to the exact
2298 /// multi-start fallback, never a silent wrong encode.
2299 pub fn certified_encode_with_index<S: AtomFrameSketch + Sync>(
2300 &self,
2301 atoms: &[SaeManifoldAtom],
2302 index: &SaeCandidateIndex,
2303 sketch: &S,
2304 targets: ArrayView2<'_, f64>,
2305 amplitudes: ArrayView1<'_, f64>,
2306 latent_dim: usize,
2307 ) -> Result<EncodeResult, String> {
2308 let n = targets.nrows();
2309 if amplitudes.len() != n {
2310 return Err(format!(
2311 "certified_encode_with_index: amplitudes len {} != rows {n}",
2312 amplitudes.len()
2313 ));
2314 }
2315 let budget = auto_candidate_budget(atoms.len().max(1));
2316 // LSH-routed per-row encode is independent across rows (sublinear atom
2317 // selection + frozen-dictionary in-atom Newton), so the corpus-rate batch
2318 // fans out over rows (#1026 amortized-encoder/routing leg / #977 Stage-3).
2319 // `None` coords (no LSH candidate) carry through as a zeroed row flagged
2320 // uncertified — identical to the sequential semantics. Results assemble
2321 // back in row order (bit-identical run-to-run); the first encode error
2322 // propagates deterministically. Stay sequential inside a rayon worker to
2323 // avoid nested oversubscription.
2324 let encode_rows =
2325 |range: std::ops::Range<usize>| -> Result<Vec<Option<(Array1<f64>, bool)>>, String> {
2326 range
2327 .map(|row| {
2328 // EXACT routing (#1777): pick the GLOBAL argmax of the
2329 // routing score over the whole dictionary, not merely the
2330 // best LSH-gathered candidate. `route_exact` certifies the
2331 // sublinear gather against the universal `[0,1]` alignment
2332 // bound and falls back to a full scan otherwise, so the
2333 // returned atom is guaranteed to be the globally-best — no
2334 // silently-missed ungathered atom.
2335 let Some(route) =
2336 index.route_exact(sketch, targets.row(row), budget, true)
2337 else {
2338 // Empty dictionary: flag for the exact fallback.
2339 return Ok(None);
2340 };
2341 let best_atom = route.atom;
2342 // Fit-quality floor: even the globally-best atom may align
2343 // only weakly with this row (no atom fits it). A finite
2344 // alignment below the floor — or a NaN, the zero-norm
2345 // ‖d‖ = 0 row — flags for the exact multi-start fallback
2346 // rather than encoding against a poorly-fitting atom. This is
2347 // a quality gate, not a routing-correctness gate; routing is
2348 // already exact. See CANDIDATE_ROUTING_MIN_ALIGNMENT.
2349 if !route.alignment.is_finite()
2350 || route.alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT
2351 {
2352 return Ok(None);
2353 }
2354 let atom = atoms.get(best_atom).ok_or_else(|| {
2355 format!(
2356 "certified_encode_with_index: proposed atom {best_atom} out of range"
2357 )
2358 })?;
2359 let (t, cert) = self.certified_encode_row(
2360 atom,
2361 best_atom,
2362 targets.row(row),
2363 amplitudes[row],
2364 )?;
2365 // Heterogeneous-atom dictionaries with different latent_dim
2366 // per atom are not supported by the batched API: the caller
2367 // declares one shared `latent_dim` for the output tensor.
2368 // Silently zeroing the coord row while recording a
2369 // certified=true flag would produce corrupted
2370 // reconstructions downstream — error loudly instead.
2371 if t.len() != latent_dim {
2372 return Err(format!(
2373 "certified_encode_with_index: atom {best_atom} returned t.len()={} \
2374 but declared latent_dim={latent_dim}; heterogeneous-dim \
2375 dictionaries are not supported by this batched encode path",
2376 t.len()
2377 ));
2378 }
2379 Ok(Some((t, cert.certified())))
2380 })
2381 .collect()
2382 };
2383 let rows: Vec<Option<(Array1<f64>, bool)>> =
2384 if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
2385 use rayon::prelude::*;
2386 const CHUNK: usize = 256;
2387 let n_chunks = n.div_ceil(CHUNK);
2388 let chunked: Vec<Vec<Option<(Array1<f64>, bool)>>> = (0..n_chunks)
2389 .into_par_iter()
2390 .map(|c| {
2391 let start = c * CHUNK;
2392 let end = (start + CHUNK).min(n);
2393 encode_rows(start..end)
2394 })
2395 .collect::<Result<_, _>>()?;
2396 chunked.into_iter().flatten().collect()
2397 } else {
2398 encode_rows(0..n)?
2399 };
2400 let mut coords = Array2::<f64>::zeros((n, latent_dim));
2401 let mut certified = Vec::with_capacity(n);
2402 for (row, slot) in rows.into_iter().enumerate() {
2403 match slot {
2404 Some((t, cert)) => {
2405 coords.row_mut(row).assign(&t);
2406 certified.push(cert);
2407 }
2408 None => certified.push(false),
2409 }
2410 }
2411 Ok(EncodeResult::from_rows(coords, certified))
2412 }
2413
2414 /// LSH-routed AMORTIZED (distilled) encode — the production token-rate
2415 /// encoder of #1026 ladder item 3. Identical routing to
2416 /// [`Self::certified_encode_with_index`] (LSH proposes the best-aligned atom,
2417 /// the atlas routes to the in-atom nearest chart), but the in-atom encode is
2418 /// the closed-form per-chart Jacobian predictor + certificate gate of
2419 /// [`Self::amortized_encode_row`] rather than the certified Newton-refinement
2420 /// path.
2421 /// This is the deployment path: the distilled affine map produces the encode
2422 /// in one mat-vec, the Kantorovich certificate decides trust-or-fallback per
2423 /// row, and uncertified rows (the adversarial tail the thread expects to
2424 /// concentrate on rare tokens) are flagged for the exact multi-start solve —
2425 /// compute goes where the questions are. Row-independent against the frozen
2426 /// dictionary, so the batch fans out over rows with deterministic row-order
2427 /// assembly (bit-identical run-to-run).
2428 pub fn amortized_encode_with_index<S: AtomFrameSketch + Sync>(
2429 &self,
2430 atoms: &[SaeManifoldAtom],
2431 index: &SaeCandidateIndex,
2432 sketch: &S,
2433 targets: ArrayView2<'_, f64>,
2434 amplitudes: ArrayView1<'_, f64>,
2435 latent_dim: usize,
2436 ) -> Result<EncodeResult, String> {
2437 let n = targets.nrows();
2438 if amplitudes.len() != n {
2439 return Err(format!(
2440 "amortized_encode_with_index: amplitudes len {} != rows {n}",
2441 amplitudes.len()
2442 ));
2443 }
2444 let budget = auto_candidate_budget(atoms.len().max(1));
2445 let encode_rows =
2446 |range: std::ops::Range<usize>| -> Result<Vec<Option<(Array1<f64>, bool)>>, String> {
2447 range
2448 .map(|row| {
2449 // EXACT routing (#1777): global argmax of the routing score,
2450 // not just the best LSH-gathered candidate (see
2451 // certified_encode_with_index for the full rationale).
2452 let Some(route) =
2453 index.route_exact(sketch, targets.row(row), budget, true)
2454 else {
2455 return Ok(None);
2456 };
2457 let best_atom = route.atom;
2458 // Fit-quality floor (not a routing-correctness gate; routing
2459 // is exact): even the globally-best atom may fit a row poorly,
2460 // and a NaN alignment is the zero-norm ‖d‖ = 0 row. Either way
2461 // flag for the exact multi-start fallback. See
2462 // CANDIDATE_ROUTING_MIN_ALIGNMENT.
2463 if !route.alignment.is_finite()
2464 || route.alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT
2465 {
2466 return Ok(None);
2467 }
2468 let atom = atoms.get(best_atom).ok_or_else(|| {
2469 format!(
2470 "amortized_encode_with_index: proposed atom {best_atom} out of range"
2471 )
2472 })?;
2473 let (t, cert) = self.amortized_encode_row(
2474 atom,
2475 best_atom,
2476 targets.row(row),
2477 amplitudes[row],
2478 )?;
2479 if t.len() != latent_dim {
2480 return Err(format!(
2481 "amortized_encode_with_index: atom {best_atom} returned t.len()={} \
2482 but declared latent_dim={latent_dim}; heterogeneous-dim \
2483 dictionaries are not supported by this batched encode path",
2484 t.len()
2485 ));
2486 }
2487 Ok(Some((t, cert.certified())))
2488 })
2489 .collect()
2490 };
2491 let rows: Vec<Option<(Array1<f64>, bool)>> =
2492 if n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none() {
2493 use rayon::prelude::*;
2494 const CHUNK: usize = 256;
2495 let n_chunks = n.div_ceil(CHUNK);
2496 let chunked: Vec<Vec<Option<(Array1<f64>, bool)>>> = (0..n_chunks)
2497 .into_par_iter()
2498 .map(|c| {
2499 let start = c * CHUNK;
2500 let end = (start + CHUNK).min(n);
2501 encode_rows(start..end)
2502 })
2503 .collect::<Result<_, _>>()?;
2504 chunked.into_iter().flatten().collect()
2505 } else {
2506 encode_rows(0..n)?
2507 };
2508 let mut coords = Array2::<f64>::zeros((n, latent_dim));
2509 let mut certified = Vec::with_capacity(n);
2510 for (row, slot) in rows.into_iter().enumerate() {
2511 match slot {
2512 Some((t, cert)) => {
2513 coords.row_mut(row).assign(&t);
2514 certified.push(cert);
2515 }
2516 None => certified.push(false),
2517 }
2518 }
2519 Ok(EncodeResult::from_rows(coords, certified))
2520 }
2521
2522 /// LSH-routed FAST amortized encode over the WHOLE dictionary — the
2523 /// multi-atom, corpus-rate analogue of [`Self::amortized_encode_with_index`].
2524 ///
2525 /// `amortized_encode_with_index` routes per row, then runs the per-row
2526 /// closed-form predictor + Kantorovich certificate + cold cross-check on each
2527 /// row independently. This fast variant keeps the SAME per-row EXACT routing
2528 /// (`index.route_exact` + the fit-quality floor), but replaces the per-row
2529 /// predictor with the GEMM-batched [`Self::amortized_encode_batch_fast`]:
2530 /// it GROUPS rows by their global-argmax atom and runs one batched affine-
2531 /// predictor pass per atom-group (a routing GEMM + a predictor GEMM each),
2532 /// reproducing a traditional SAE's whole-dictionary `W·x+b` throughput. No
2533 /// per-row certificate — this is the speed mode validated as accuracy-parity
2534 /// with the certified solve (`fast_forward_is_accuracy_parity_with_certified`).
2535 ///
2536 /// Returns the per-row latent coords and a valid-mask: `false` for a row with
2537 /// an empty dictionary, a sub-threshold/NaN routing alignment, or one the batched
2538 /// predictor could not fire on (no evaluator / singular Gauss–Newton block /
2539 /// non-finite-or-zero amplitude). Each row is written exactly once (disjoint
2540 /// per-atom groups), so the result is independent of group iteration order.
2541 pub fn amortized_encode_with_index_fast<S: AtomFrameSketch + Sync>(
2542 &self,
2543 atoms: &[SaeManifoldAtom],
2544 index: &SaeCandidateIndex,
2545 sketch: &S,
2546 targets: ArrayView2<'_, f64>,
2547 amplitudes: ArrayView1<'_, f64>,
2548 latent_dim: usize,
2549 ) -> Result<(Array2<f64>, Vec<bool>), String> {
2550 let n = targets.nrows();
2551 if amplitudes.len() != n {
2552 return Err(format!(
2553 "amortized_encode_with_index_fast: amplitudes len {} != rows {n}",
2554 amplitudes.len()
2555 ));
2556 }
2557 let budget = auto_candidate_budget(atoms.len().max(1));
2558 let mut coords = Array2::<f64>::zeros((n, latent_dim));
2559 let mut valid = vec![false; n];
2560 // ── Single allocation-free pass: route each row, apply the CACHED predictor.
2561 //
2562 // Routing sublinearity (massive-K, K≈32k): the certified path uses
2563 // `route_exact`, whose universal-bound LSH certificate only fires at the
2564 // alignment ceiling (≈1.0); for any real dictionary (`alignment < 1`) it
2565 // falls back to `brute_force_best_atom` — an O(K) full scan PER ROW, making
2566 // the encode O(N·K). This SPEED path takes the LSH gather's best-aligned atom
2567 // directly (`propose` scores only ~budget candidates → O(log K)); a rare miss
2568 // is caught by the fit-quality floor + the downstream certificate/exact
2569 // fallback (the documented speed/accuracy tradeoff).
2570 //
2571 // Allocation: the predictor uses only OFFLINE-cached atlas data (per-chart
2572 // `recon_center` for routing + `amortized_jacobian`/`amortized_base` for the
2573 // `t̂ = base + (1/z)·A₁·x` mat-vec), so NO per-row or per-atom heap buffer is
2574 // allocated. This replaces the old per-atom-group GEMM sub-batch — which at
2575 // `K ≫ N` degenerated to ONE row per group, so its per-group buffers (x_sub,
2576 // recon-centers, predictors) dominated and made the "fast" path allocation-
2577 // bound. Now the only per-row allocation is inside `index.propose`.
2578 for row in 0..n {
2579 let dir = targets.row(row);
2580 let proposal = index.propose(sketch, dir, budget, true);
2581 let Some(&best_atom) = proposal.proposed.first() else {
2582 continue; // nothing gathered (empty dictionary / probe-dim mismatch)
2583 };
2584 // Fit-quality floor: the best gathered atom still fits this row poorly,
2585 // or the alignment is NaN (zero-norm row) — flag for the exact fallback.
2586 let alignment = sketch.alignment(best_atom, dir);
2587 if !alignment.is_finite() || alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT {
2588 continue;
2589 }
2590 let atom = atoms.get(best_atom).ok_or_else(|| {
2591 format!("amortized_encode_with_index_fast: proposed atom {best_atom} out of range")
2592 })?;
2593 if atom.latent_dim != latent_dim {
2594 return Err(format!(
2595 "amortized_encode_with_index_fast: atom {best_atom} latent_dim {} != declared \
2596 {latent_dim}; heterogeneous-dim dictionaries are not supported by this path",
2597 atom.latent_dim
2598 ));
2599 }
2600 let Some(atom_atlas) = self.atoms.get(best_atom) else {
2601 continue; // no atlas for this atom → predictor cannot fire (zeroed)
2602 };
2603 if amortized_predict_row(atom_atlas, dir, amplitudes[row], latent_dim, coords.row_mut(row))
2604 {
2605 valid[row] = true;
2606 }
2607 }
2608 Ok((coords, valid))
2609 }
2610
2611 /// LSH-routed FAST full forward over the WHOLE dictionary: encode → decode,
2612 /// the multi-atom analogue of [`Self::amortized_reconstruct_batch_fast`]. Same
2613 /// sublinear per-row routing + per-atom grouping as
2614 /// [`Self::amortized_encode_with_index_fast`], but each group is run through
2615 /// the batched reconstruct (`m(t̂) = z·Φ(t̂)·B`) so the result is the per-row
2616 /// reconstruction in the ambient space. Rows that do not route/predict decode
2617 /// to an exact zero reconstruction and are flagged `false`.
2618 pub fn amortized_reconstruct_with_index_fast<S: AtomFrameSketch + Sync>(
2619 &self,
2620 atoms: &[SaeManifoldAtom],
2621 index: &SaeCandidateIndex,
2622 sketch: &S,
2623 targets: ArrayView2<'_, f64>,
2624 amplitudes: ArrayView1<'_, f64>,
2625 ) -> Result<(Array2<f64>, Vec<bool>), String> {
2626 let n = targets.nrows();
2627 let p = targets.ncols();
2628 if amplitudes.len() != n {
2629 return Err(format!(
2630 "amortized_reconstruct_with_index_fast: amplitudes len {} != rows {n}",
2631 amplitudes.len()
2632 ));
2633 }
2634 let budget = auto_candidate_budget(atoms.len().max(1));
2635 // SUBLINEAR routing for the SPEED-mode full forward — mirror
2636 // `amortized_encode_with_index_fast`: take the LSH gather's best-aligned
2637 // atom (O(budget) candidates) instead of route_exact's O(K) full-scan
2638 // certification, keeping the whole fast encode→decode sublinear in K at
2639 // K=32k. The gather's best is the exact argmax on the vast majority of rows;
2640 // rare misses are caught by the fit-quality floor + downstream fallback.
2641 let mut groups: std::collections::HashMap<usize, Vec<usize>> =
2642 std::collections::HashMap::new();
2643 for row in 0..n {
2644 let dir = targets.row(row);
2645 let proposal = index.propose(sketch, dir, budget, true);
2646 let Some(&best_atom) = proposal.proposed.first() else {
2647 continue; // nothing gathered (empty dictionary / probe-dim mismatch)
2648 };
2649 let alignment = sketch.alignment(best_atom, dir);
2650 if !alignment.is_finite() || alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT {
2651 continue;
2652 }
2653 groups.entry(best_atom).or_default().push(row);
2654 }
2655
2656 let mut recon = Array2::<f64>::zeros((n, p));
2657 let mut valid = vec![false; n];
2658 for (atom_idx, rows_here) in groups {
2659 let atom = atoms.get(atom_idx).ok_or_else(|| {
2660 format!(
2661 "amortized_reconstruct_with_index_fast: proposed atom {atom_idx} out of range"
2662 )
2663 })?;
2664 if atom.output_dim() != p {
2665 return Err(format!(
2666 "amortized_reconstruct_with_index_fast: atom {atom_idx} output_dim {} != target \
2667 dim {p}",
2668 atom.output_dim()
2669 ));
2670 }
2671 let mut x_sub = Array2::<f64>::zeros((rows_here.len(), p));
2672 let mut amp_sub = Array1::<f64>::zeros(rows_here.len());
2673 for (i, &row) in rows_here.iter().enumerate() {
2674 x_sub.row_mut(i).assign(&targets.row(row));
2675 amp_sub[i] = amplitudes[row];
2676 }
2677 let (sub_recon, sub_valid) = self.amortized_reconstruct_batch_fast(
2678 atom,
2679 atom_idx,
2680 x_sub.view(),
2681 amp_sub.view(),
2682 )?;
2683 for (i, &row) in rows_here.iter().enumerate() {
2684 if sub_valid[i] {
2685 recon.row_mut(row).assign(&sub_recon.row(i));
2686 valid[row] = true;
2687 }
2688 }
2689 }
2690 Ok((recon, valid))
2691 }
2692}
2693
2694/// Offline `β = 1/λ_min(H_GN)` at a chart center from the Gauss-Newton block
2695/// `H_GN = J_mᵀ J_m` (residual-free). The offline `β` bounds the curvature the
2696/// online certificate sees: charts are placed where the encode lands, so the
2697/// representative residual is small and `H_GN` is the dominant, residual-free
2698/// curvature estimate. (The online per-row certificate still uses the FULL
2699/// Hessian; this is only the offline radius-sizing curvature.) Returns `None`
2700/// for a degenerate center (`λ_min ≤ 0`), which marks an uncertifiable chart.
2701pub(crate) fn center_beta(atom: &SaeManifoldAtom, center: &Array1<f64>, ridge: f64) -> Option<f64> {
2702 let evaluator = atom.basis_evaluator.as_ref()?.clone();
2703 let d = atom.latent_dim;
2704 let p = atom.output_dim();
2705 let m = atom.basis_size();
2706 let coords = center.view().to_shape((1, d)).ok()?.to_owned();
2707 let (_phi, jet) = evaluator.evaluate(coords.view()).ok()?;
2708 let decoder = &atom.decoder_coefficients;
2709 // J_m[axis] = Bᵀ (∂Φ/∂t_axis) ∈ ℝᵖ (amplitude-1; curvature scales with z²
2710 // and is absorbed conservatively by the amplitude-bounded Lipschitz term).
2711 let mut jm = Array2::<f64>::zeros((d, p));
2712 for axis in 0..d {
2713 for basis_col in 0..m {
2714 let dphi = jet[[0, basis_col, axis]];
2715 if dphi == 0.0 {
2716 continue;
2717 }
2718 for out in 0..p {
2719 jm[[axis, out]] += dphi * decoder[[basis_col, out]];
2720 }
2721 }
2722 }
2723 let mut h = Array2::<f64>::zeros((d, d));
2724 for a in 0..d {
2725 for b in 0..d {
2726 h[[a, b]] = jm.row(a).dot(&jm.row(b));
2727 }
2728 h[[a, a]] += ridge;
2729 }
2730 let (vals, _vecs) = h.eigh(Side::Lower).ok()?;
2731 let lambda_min = vals.iter().cloned().fold(f64::INFINITY, f64::min);
2732 if lambda_min.is_finite() && lambda_min > 0.0 {
2733 Some(1.0 / lambda_min)
2734 } else {
2735 None
2736 }
2737}
2738
2739/// #1154 — the amortized encoder's closed-form warm-start coordinate for one
2740/// row `x` against one chart at amplitude `z`:
2741///
2742/// ```text
2743/// t̂ = t_c + (1/z) · A₁ · (x − z · m₁(t_c)),
2744/// ```
2745///
2746/// a single `O(d·p)` mat-vec from the chart's precomputed IFT Jacobian `A₁` and
2747/// center reconstruction `m₁`. Returns `None` when the chart carries no
2748/// distilled Jacobian (singular Gauss–Newton block) or the amplitude is not
2749/// strictly positive and finite (a near-inactive atom, where the
2750/// amplitude-divided map is undefined) — in those cases the caller starts from
2751/// the chart center instead. Shared by the amortized encode (where `t̂` is the
2752/// prediction) and the exact certified encode (where `t̂` is the Newton
2753/// warm-start that then refines to stationarity, Design A).
2754pub(crate) fn amortized_warm_start(
2755 chart: &CertifiedChart,
2756 x: ArrayView1<'_, f64>,
2757 amplitude: f64,
2758) -> Option<Array1<f64>> {
2759 let a1 = chart.amortized_jacobian.as_ref()?;
2760 if !(amplitude.is_finite() && amplitude.abs() > 0.0) {
2761 return None;
2762 }
2763 let d = a1.nrows();
2764 let mut t_hat = chart.region.center.clone();
2765 for (out_idx, &m1_out) in chart.recon_center.iter().enumerate().take(a1.ncols()) {
2766 let resid = x[out_idx] - amplitude * m1_out;
2767 for axis in 0..d {
2768 t_hat[axis] += a1[[axis, out_idx]] * resid / amplitude;
2769 }
2770 }
2771 Some(t_hat)
2772}
2773
2774/// Single-row amortized predictor against ONE atom's cached atlas, writing the
2775/// encoded latent coordinate DIRECTLY into `out` (length `d`) with NO heap
2776/// allocation. Routes to the nearest certifiable chart by cached center-
2777/// reconstruction distance `‖x − m(t_c)‖²` (matching `amortized_encode_batch_fast`'s
2778/// per-chart routing), then applies that chart's precomputed affine predictor
2779/// `t̂ = base + (1/z)·A₁·x` (`base = t_c − A₁·m₁` is `chart.amortized_base`).
2780///
2781/// Returns `false` — leaving `out` at its incoming (zeroed) value — for exactly the
2782/// rows `amortized_encode_batch_fast` would flag: no certifiable chart, a nearest
2783/// chart with no distilled predictor (singular Gauss–Newton block), or an unusable
2784/// amplitude. This is the per-row core of the allocation-free massive-K fast encode.
2785pub(crate) fn amortized_predict_row(
2786 atom_atlas: &AtomEncodeAtlas,
2787 x: ArrayView1<'_, f64>,
2788 amplitude: f64,
2789 d: usize,
2790 mut out: ndarray::ArrayViewMut1<'_, f64>,
2791) -> bool {
2792 if !(amplitude.is_finite() && amplitude.abs() > 0.0) {
2793 return false;
2794 }
2795 // Nearest certifiable chart by ‖x − recon_center‖² (first-wins on ties, strict
2796 // `<` — same argmin as `amortized_encode_batch_fast`'s route_idx; the ‖x‖² term
2797 // it drops is row-constant and does not change the argmin).
2798 let mut best_ci: Option<usize> = None;
2799 let mut best_dist = f64::INFINITY;
2800 for (ci, chart) in atom_atlas.charts.iter().enumerate() {
2801 if chart.certified_radius <= 0.0 {
2802 continue;
2803 }
2804 let mut dist = 0.0;
2805 for (r, xv) in chart.recon_center.iter().zip(x.iter()) {
2806 let diff = r - xv;
2807 dist += diff * diff;
2808 }
2809 if dist < best_dist {
2810 best_dist = dist;
2811 best_ci = Some(ci);
2812 }
2813 }
2814 let Some(ci) = best_ci else {
2815 return false;
2816 };
2817 let chart = &atom_atlas.charts[ci];
2818 // The nearest chart must carry a distilled predictor; otherwise flag (zeroed),
2819 // exactly as the per-chart `None` branch of `amortized_encode_batch_fast`.
2820 let (Some(a1), Some(base)) = (
2821 chart.amortized_jacobian.as_ref(),
2822 chart.amortized_base.as_ref(),
2823 ) else {
2824 return false;
2825 };
2826 let inv_z = 1.0 / amplitude;
2827 for axis in 0..d {
2828 out[axis] = base[axis] + a1.row(axis).dot(&x) * inv_z;
2829 }
2830 true
2831}
2832
2833/// The amplitude-1 distilled amortized-encoder Jacobian at a chart center
2834/// (#1026 ladder item 3). Returns `(A₁, m₁)` where `m₁ = BᵀΦ(t_c) ∈ ℝᵖ` is the
2835/// amplitude-1 center reconstruction and `A₁ = (J₁ᵀJ₁ + ridge·I)⁻¹ J₁ ∈ ℝ^{d×p}`
2836/// is the implicit-function-theorem derivative of the encode map `x ↦ t`
2837/// (Gauss–Newton block — the residual-free, dominant curvature exactly as the
2838/// offline radius-sizing `β`). With these, the online encode of a row `x` at
2839/// amplitude `z` is the closed-form affine prediction
2840/// `t = t_c + (1/z)·A₁·(x − z·m₁)` — one mat-vec, no per-row factorization.
2841/// `None` when the basis has no jet or the Gauss–Newton block is singular (no
2842/// certifiable amortization), matching `center_beta`'s gate so a chart with a
2843/// finite `β` always carries a Jacobian and vice versa.
2844pub(crate) fn center_amortized_jacobian(
2845 atom: &SaeManifoldAtom,
2846 center: &Array1<f64>,
2847 ridge: f64,
2848) -> Option<(Array2<f64>, Array1<f64>)> {
2849 let evaluator = atom.basis_evaluator.as_ref()?.clone();
2850 let d = atom.latent_dim;
2851 let p = atom.output_dim();
2852 let m = atom.basis_size();
2853 let coords = center.view().to_shape((1, d)).ok()?.to_owned();
2854 let (phi, jet) = evaluator.evaluate(coords.view()).ok()?;
2855 let decoder = &atom.decoder_coefficients;
2856 // m₁(t_c) = BᵀΦ(t_c) ∈ ℝᵖ (amplitude-1 center reconstruction).
2857 let mut recon = Array1::<f64>::zeros(p);
2858 for basis_col in 0..m {
2859 let phi_v = phi[[0, basis_col]];
2860 if phi_v == 0.0 {
2861 continue;
2862 }
2863 for out in 0..p {
2864 recon[out] += phi_v * decoder[[basis_col, out]];
2865 }
2866 }
2867 // J₁[axis] = Bᵀ (∂Φ/∂t_axis) ∈ ℝᵖ (amplitude-1; z factors out analytically).
2868 let mut jm = Array2::<f64>::zeros((d, p));
2869 for axis in 0..d {
2870 for basis_col in 0..m {
2871 let dphi = jet[[0, basis_col, axis]];
2872 if dphi == 0.0 {
2873 continue;
2874 }
2875 for out in 0..p {
2876 jm[[axis, out]] += dphi * decoder[[basis_col, out]];
2877 }
2878 }
2879 }
2880 // H_GN = J₁ J₁ᵀ + ridge·I ∈ ℝ^{d×d}.
2881 let mut h = Array2::<f64>::zeros((d, d));
2882 for a in 0..d {
2883 for b in 0..d {
2884 h[[a, b]] = jm.row(a).dot(&jm.row(b));
2885 }
2886 h[[a, a]] += ridge;
2887 }
2888 let (vals, vecs) = h.eigh(Side::Lower).ok()?;
2889 let lambda_min = vals.iter().cloned().fold(f64::INFINITY, f64::min);
2890 if !(lambda_min.is_finite() && lambda_min > 0.0) {
2891 return None;
2892 }
2893 // A₁ = H_GN⁻¹ J₁ via the eigendecomposition: H⁻¹ = Σ_i (1/λᵢ) vᵢ vᵢᵀ, so
2894 // A₁[:, out] = Σ_i (vᵢ · J₁[:, out]) / λᵢ · vᵢ. Column-by-column keeps it the
2895 // d×p Jacobian (one SPD solve reused across all p output channels).
2896 let mut a1 = Array2::<f64>::zeros((d, p));
2897 for out in 0..p {
2898 let jcol = jm.column(out);
2899 for (i, &lam) in vals.iter().enumerate() {
2900 if !(lam.is_finite() && lam > 0.0) {
2901 return None;
2902 }
2903 let vi = vecs.column(i);
2904 let coeff = vi.dot(&jcol) / lam;
2905 for row in 0..d {
2906 a1[[row, out]] += coeff * vi[row];
2907 }
2908 }
2909 }
2910 Some((a1, recon))
2911}
2912
2913/// Route a target row to the nearest chart of an atom by reconstruction
2914/// distance: the chart whose center reconstruction `m(t_c)` is closest to `x`.
2915/// Returns the chart index and the distance, or `None` when the atom has no
2916/// charts.
2917pub(crate) fn nearest_chart(
2918 atom_atlas: &AtomEncodeAtlas,
2919 x: ArrayView1<'_, f64>,
2920) -> Option<(usize, f64)> {
2921 if atom_atlas.charts.is_empty() {
2922 return None;
2923 }
2924 let mut best: Option<(usize, f64)> = None;
2925 for (idx, chart) in atom_atlas.charts.iter().enumerate() {
2926 if chart.certified_radius <= 0.0 {
2927 continue;
2928 }
2929 // Reuse the offline-distilled `m(t_c) = B^T Phi(t_c)` (`chart.recon_center`)
2930 // instead of re-evaluating the basis at a fixed center per row — see
2931 // `nearest_charts_topk`. Distance accumulated in place, no temporary array.
2932 let mut dist = 0.0;
2933 for (r, xv) in chart.recon_center.iter().zip(x.iter()) {
2934 let diff = r - xv;
2935 dist += diff * diff;
2936 }
2937 if best.map(|(_, b)| dist < b).unwrap_or(true) {
2938 best = Some((idx, dist));
2939 }
2940 }
2941 best
2942}
2943
2944/// The `k` charts whose CENTER reconstruction `m(t_c)` is nearest to `x` in
2945/// ambient ‖·‖², returned as chart indices sorted by increasing distance (ties
2946/// broken by chart index — deterministic). Only certifiable charts
2947/// (`certified_radius > 0`) are considered, exactly like [`nearest_chart`], whose
2948/// single result is `nearest_charts_topk(.., 1)[0]`. Used by the certified encode
2949/// to refine the global basin on self-approaching atoms (see
2950/// [`CERTIFIED_ROUTING_TOPK`]).
2951pub(crate) fn nearest_charts_topk(
2952 atom_atlas: &AtomEncodeAtlas,
2953 x: ArrayView1<'_, f64>,
2954 k: usize,
2955) -> Vec<usize> {
2956 if atom_atlas.charts.is_empty() || k == 0 {
2957 return Vec::new();
2958 }
2959 let mut scored: Vec<(usize, f64)> = Vec::new();
2960 for (idx, chart) in atom_atlas.charts.iter().enumerate() {
2961 if chart.certified_radius <= 0.0 {
2962 continue;
2963 }
2964 // `m(t_c) = BᵀΦ(t_c)` is an OFFLINE per-chart constant already distilled
2965 // into `chart.recon_center` at build time (bit-for-bit the same φ·decoder
2966 // accumulation this used to recompute). Reuse it instead of re-evaluating
2967 // the basis at a fixed center for every row — that re-eval was the encode's
2968 // dominant per-row cost (charts × rows basis evals, each allocating the φ/jet
2969 // arrays). `‖m(t_c) − x‖²` computed in place (no temporary diff array).
2970 let mut dist = 0.0;
2971 for (r, xv) in chart.recon_center.iter().zip(x.iter()) {
2972 let diff = r - xv;
2973 dist += diff * diff;
2974 }
2975 scored.push((idx, dist));
2976 }
2977 // Sort by distance, then chart index for a deterministic, first-wins order
2978 // consistent with `nearest_chart`'s strict-`<` tie rule.
2979 scored.sort_by(|a, b| {
2980 a.1.partial_cmp(&b.1)
2981 .unwrap_or(std::cmp::Ordering::Equal)
2982 .then(a.0.cmp(&b.0))
2983 });
2984 scored.into_iter().take(k).map(|(idx, _)| idx).collect()
2985}
2986
2987/// Reconstruction error `‖x − z·m(t)‖` of an encoded coordinate `t` — the
2988/// criterion the certified encode minimizes over its candidate charts to pick the
2989/// GLOBAL basin. `m(t) = Bᵀ Φ(t)` is the amplitude-1 reconstruction; `z` is the
2990/// amplitude. A non-finite reconstruction returns `+∞` so it never wins.
2991pub(crate) fn encode_reconstruction_error(
2992 atom: &SaeManifoldAtom,
2993 evaluator: &dyn SaeBasisEvaluator,
2994 coord: ArrayView1<'_, f64>,
2995 x: ArrayView1<'_, f64>,
2996 amplitude: f64,
2997) -> f64 {
2998 let d = atom.latent_dim;
2999 let p = atom.output_dim();
3000 let m = atom.basis_size();
3001 let coords = match coord.to_shape((1, d)) {
3002 Ok(c) => c.to_owned(),
3003 Err(_) => return f64::INFINITY,
3004 };
3005 let Ok((phi, _jet)) = evaluator.evaluate(coords.view()) else {
3006 return f64::INFINITY;
3007 };
3008 let mut err2 = 0.0;
3009 for out in 0..p {
3010 let mut recon = 0.0;
3011 for basis_col in 0..m {
3012 recon += phi[[0, basis_col]] * atom.decoder_coefficients[[basis_col, out]];
3013 }
3014 let r = x[out] - amplitude * recon;
3015 err2 += r * r;
3016 }
3017 if err2.is_finite() { err2.sqrt() } else { f64::INFINITY }
3018}
3019
/// Maximum number of chart centers laid down per atom (the SHAPE_BAND grid
/// point cap; mirrors `SHAPE_BAND_MAX_POINTS` in the atom band machinery).
///
/// NOTE(review): the mirrored constant and the code that consumes this cap are
/// not visible in this part of the file — confirm the two values are kept in
/// sync when either changes.
pub(crate) const SHAPE_BAND_MAX_POINTS: usize = 512;
3023
3024/// Lay down chart centers on an atom's coordinate grid (the SHAPE_BAND grid
3025/// idiom): a regular grid spanning the compact latent domain for periodic /
3026/// sphere / torus atoms, and a strided cover of the latent axes for unbounded
3027/// (Duchon / Euclidean) atoms.
3028///
3029/// Periodic / torus latents are fractions of one period, so the per-axis grid
3030/// spans `[0, 1)`; the sphere chart spans `lat ∈ [−π/2, π/2]`, `lon ∈ [−π, π)`.
3031/// These conventions match the basis evaluators (the fraction-of-period circle
3032/// harmonic and the lat/lon sphere chart).
3033/// Squared coordinate distance between two latent points under the atom's chart
3034/// geometry: per-axis WRAPPED distance `min(|a−b|, period−|a−b|)` on periodic
3035/// (circle) axes — period 1 to match `chart_center_grid`'s `[0,1)` torus tiling
3036/// — and plain difference on line axes. Used to place + size data-driven charts.
3037pub(crate) fn coord_dist_sq(atom: &SaeManifoldAtom, a: ArrayView1<'_, f64>, b: ArrayView1<'_, f64>) -> f64 {
3038 // Per-axis period comes from the ONE canonical source `latent_axis_period` —
3039 // the SAME convention the certified-encode `in_chart` guard uses (via
3040 // `latent_coordinate_distance`): period-1 fraction axes for periodic/torus and
3041 // the cylinder angle, period-2π on the sphere LONGITUDE (axis 1), and
3042 // NON-periodic otherwise — including the sphere LATITUDE (axis 0), which ranges
3043 // over [−π/2, π/2] and must NOT wrap. A former local `periodic_axis` closure
3044 // wrapped BOTH sphere axes at period 1, so two genuinely-far radian longitudes
3045 // (e.g. 0 and 3 rad) collapsed to distance 0 — corrupting the farthest-point
3046 // placement AND the nearest-neighbour radii in `data_driven_chart_centers` for
3047 // sphere atoms, and disagreeing with the metric the encode's `in_chart`
3048 // soundness guard enforces. Delegating keeps the two metrics identical.
3049 let mut acc = 0.0;
3050 for axis in 0..a.len().min(b.len()) {
3051 let mut d = (a[axis] - b[axis]).abs();
3052 if let Some(period) = latent_axis_period(atom, axis) {
3053 let wrapped = d.rem_euclid(period);
3054 d = wrapped.min(period - wrapped);
3055 }
3056 acc += d * d;
3057 }
3058 acc
3059}
3060
3061/// Greedy farthest-point sampling of up to `max_charts` chart centers from the
3062/// atom's latent `coords` (n × d), with each center's nominal radius set to half
3063/// the distance to its nearest neighbor center (floored, so a singleton/coincident
3064/// cluster still gets a usable ball). Deterministic: seeds from row 0, then
3065/// repeatedly adds the coord maximally far (under [`coord_dist_sq`]) from the
3066/// chosen set — coverage-maximizing and reproducible run-to-run.
3067pub(crate) fn data_driven_chart_centers(
3068 atom: &SaeManifoldAtom,
3069 coords: ArrayView2<'_, f64>,
3070 max_charts: usize,
3071) -> Result<(Array2<f64>, Vec<f64>), String> {
3072 let n = coords.nrows();
3073 let d = coords.ncols();
3074 if d != atom.latent_dim {
3075 return Err(format!(
3076 "data_driven_chart_centers: coords have {d} cols but atom latent_dim is {}",
3077 atom.latent_dim
3078 ));
3079 }
3080 if n == 0 {
3081 return Ok((Array2::<f64>::zeros((0, d)), Vec::new()));
3082 }
3083 let k = max_charts.min(n);
3084 // Farthest-point sampling: maintain each row's distance to the nearest chosen
3085 // center, add the row with the maximum such distance each step.
3086 let mut chosen: Vec<usize> = Vec::with_capacity(k);
3087 chosen.push(0);
3088 let mut nearest_sq: Vec<f64> = (0..n)
3089 .map(|r| coord_dist_sq(atom, coords.row(r), coords.row(0)))
3090 .collect();
3091 while chosen.len() < k {
3092 // Pick the row farthest from the current center set (first-wins tie).
3093 let mut best = 0usize;
3094 let mut best_d = -1.0;
3095 for r in 0..n {
3096 if nearest_sq[r] > best_d {
3097 best_d = nearest_sq[r];
3098 best = r;
3099 }
3100 }
3101 if best_d <= 0.0 {
3102 break; // all remaining rows coincide with a chosen center.
3103 }
3104 chosen.push(best);
3105 for r in 0..n {
3106 let dr = coord_dist_sq(atom, coords.row(r), coords.row(best));
3107 if dr < nearest_sq[r] {
3108 nearest_sq[r] = dr;
3109 }
3110 }
3111 }
3112 let m = chosen.len();
3113 let mut centers = Array2::<f64>::zeros((m, d));
3114 for (i, &row) in chosen.iter().enumerate() {
3115 centers.row_mut(i).assign(&coords.row(row));
3116 }
3117 // Per-center radius = half the nearest-OTHER-center distance, floored so a
3118 // coincident pair still yields a positive ball, capped at 0.5 (the largest
3119 // meaningful half-period on a unit circle).
3120 let mut radii = vec![0.0_f64; m];
3121 for i in 0..m {
3122 let mut nn = f64::INFINITY;
3123 for j in 0..m {
3124 if i == j {
3125 continue;
3126 }
3127 let dsq = coord_dist_sq(atom, centers.row(i), centers.row(j));
3128 if dsq < nn {
3129 nn = dsq;
3130 }
3131 }
3132 let r = if nn.is_finite() { 0.5 * nn.sqrt() } else { 0.5 };
3133 radii[i] = r.max(1.0e-3).min(0.5);
3134 }
3135 Ok((centers, radii))
3136}
3137
3138pub(crate) fn chart_center_grid(atom: &SaeManifoldAtom, resolution: usize) -> Array2<f64> {
3139 use crate::manifold::SaeAtomBasisKind::*;
3140 let d = atom.latent_dim;
3141 match &atom.basis_kind {
3142 Periodic | Torus => regular_product_grid(d, resolution, 0.0, 1.0, false),
3143 // Cylinder `S¹ × ℝ`: axis 0 is the periodic circle `[0, 1)` (no
3144 // endpoint, like the harmonic axes); axis 1 is the unbounded line,
3145 // covered by a strided unit box `[-0.5, 0.5]` about the origin (like the
3146 // Euclidean patch). The certified radius refines each chart; out-of-cover
3147 // line starts route to the exact fallback honestly.
3148 Cylinder if d == 2 => cylinder_chart_center_grid(resolution),
3149 Cylinder => regular_product_grid(d, resolution, -0.5, 0.5, true),
3150 Sphere if d == 2 => sphere_latlon_grid(resolution),
3151 Linear | Sphere | Duchon | EuclideanPatch | Poincare | Precomputed(_) => {
3152 // Unbounded / non-compact latents: a strided cover of a unit box
3153 // about the origin per axis. The certified radius refines each chart;
3154 // out-of-cover starts route to the exact fallback honestly.
3155 regular_product_grid(d, resolution, -0.5, 0.5, true)
3156 }
3157 }
3158}
3159
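// NOTE: the paragraph below describes `regular_product_grid` (defined right
// after `capped_per_axis`), not the function that immediately follows. It is
// kept as a plain comment so rustdoc does not attach it to `capped_per_axis`.
//
// A regular `resolution`-per-axis product grid over `[lo, hi]^d`, capped at
// [`SHAPE_BAND_MAX_POINTS`] total points (the per-axis resolution is reduced
// until the product fits). When `include_endpoint` the last grid point sits at
// `hi`; otherwise the axis is treated as periodic and stops one step short.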
3164/// Per-axis resolution actually used by [`regular_product_grid`] after the
3165/// [`SHAPE_BAND_MAX_POINTS`] product cap. Chart radii must be derived from THIS
3166/// (not the raw `resolution`), otherwise for `resolution^d > SHAPE_BAND_MAX_POINTS`
3167/// the grid spacing is coarser than the radius and the charts leave gaps.
3168pub(crate) fn capped_per_axis(d: usize, resolution: usize) -> usize {
3169 let mut per_axis = resolution.max(2);
3170 while per_axis.saturating_pow(d as u32) > SHAPE_BAND_MAX_POINTS && per_axis > 2 {
3171 per_axis -= 1;
3172 }
3173 per_axis
3174}
3175
3176pub(crate) fn regular_product_grid(
3177 d: usize,
3178 resolution: usize,
3179 lo: f64,
3180 hi: f64,
3181 include_endpoint: bool,
3182) -> Array2<f64> {
3183 if d == 0 {
3184 return Array2::<f64>::zeros((1, 0));
3185 }
3186 let per_axis = capped_per_axis(d, resolution);
3187 let total = per_axis.saturating_pow(d as u32).max(1);
3188 let denom = if include_endpoint {
3189 (per_axis.max(2) - 1) as f64
3190 } else {
3191 per_axis as f64
3192 };
3193 let mut grid = Array2::<f64>::zeros((total, d));
3194 let mut idx = vec![0usize; d];
3195 for flat in 0..total {
3196 for axis in 0..d {
3197 let frac = idx[axis] as f64 / denom;
3198 grid[[flat, axis]] = lo + (hi - lo) * frac;
3199 }
3200 for axis in (0..d).rev() {
3201 idx[axis] += 1;
3202 if idx[axis] < per_axis {
3203 break;
3204 }
3205 idx[axis] = 0;
3206 }
3207 }
3208 grid
3209}
3210
3211/// Lat/lon sphere chart grid: `lat ∈ [−π/2, π/2]`, `lon ∈ [−π, π)`, matching
3212/// the [`crate::manifold::SphereChartEvaluator`] convention.
3213pub(crate) fn sphere_latlon_grid(resolution: usize) -> Array2<f64> {
3214 use std::f64::consts::PI;
3215 let r = resolution.max(2).min(22); // 22² = 484 ≤ SHAPE_BAND_MAX_POINTS.
3216 let mut grid = Array2::<f64>::zeros((r * r, 2));
3217 for i in 0..r {
3218 let lat = -PI / 2.0 + PI * (i as f64 + 0.5) / r as f64;
3219 for j in 0..r {
3220 let lon = -PI + 2.0 * PI * (j as f64) / r as f64;
3221 grid[[i * r + j, 0]] = lat;
3222 grid[[i * r + j, 1]] = lon;
3223 }
3224 }
3225 grid
3226}
3227
3228/// Cylinder `S¹ × ℝ` chart-center grid: axis 0 sweeps the periodic circle over
3229/// one period `[0, 1)` (no endpoint, matching the harmonic axis), axis 1 strides
3230/// a unit box `[−0.5, 0.5]` about the origin on the unbounded line (with
3231/// endpoint). Capped at [`SHAPE_BAND_MAX_POINTS`] total centers.
3232pub(crate) fn cylinder_chart_center_grid(resolution: usize) -> Array2<f64> {
3233 let mut per_axis = resolution.max(2);
3234 while per_axis * per_axis > SHAPE_BAND_MAX_POINTS && per_axis > 2 {
3235 per_axis -= 1;
3236 }
3237 let total = per_axis * per_axis;
3238 let line_denom = (per_axis.max(2) - 1) as f64;
3239 let mut grid = Array2::<f64>::zeros((total, 2));
3240 for i in 0..per_axis {
3241 // Periodic axis 0: stop one step short of the period.
3242 let circle = i as f64 / per_axis as f64;
3243 for j in 0..per_axis {
3244 // Line axis 1: include the endpoint of the unit box.
3245 let line = -0.5 + (j as f64) / line_denom;
3246 grid[[i * per_axis + j, 0]] = circle;
3247 grid[[i * per_axis + j, 1]] = line;
3248 }
3249 }
3250 grid
3251}
3252
3253/// Nominal in-chart radius: half the inter-center grid spacing, so charts tile
3254/// the domain. For compact latents this is the grid step; for unbounded latents
3255/// a unit default that the certified radius refines.
3256pub(crate) fn chart_nominal_radius(atom: &SaeManifoldAtom, resolution: usize) -> f64 {
3257 use crate::manifold::SaeAtomBasisKind::*;
3258 match &atom.basis_kind {
3259 Periodic | Torus => 0.5 / (capped_per_axis(atom.latent_dim, resolution) as f64),
3260 // Must use the SAME capped per-axis count `min(resolution, 22)` that
3261 // `sphere_latlon_grid` lays the centers on: the coarsest tiling step is
3262 // the longitude half-spacing `π/r` with `r = min(resolution, 22)`. Deriving
3263 // the radius from the RAW resolution makes it smaller than the grid spacing
3264 // once `resolution > 22`, leaving gaps between charts so rows in the gaps
3265 // spuriously route to the exact fallback (the exact hazard `capped_per_axis`
3266 // documents for the regular grid).
3267 Sphere => std::f64::consts::PI / (resolution.max(2).min(22) as f64),
3268 // Cylinder charts tile two heterogeneous axes (a `[0,1)` periodic step
3269 // and a unit-box line step); the chart radius is a single scalar, so we
3270 // take the tighter (periodic) step `0.5/res` to keep every chart valid
3271 // on both axes. The certified Kantorovich radius refines it per chart.
3272 Cylinder => 0.5 / (capped_per_axis(atom.latent_dim, resolution) as f64),
3273 Linear | Duchon | EuclideanPatch | Poincare | Precomputed(_) => {
3274 1.0 / (resolution.max(2) as f64)
3275 }
3276 }
3277}
3278
3279/// Build the [`ChartRegion`] for a center, attaching the radial r_min / r_max
3280/// bracket for Duchon atoms (the chart's distance range to the kernel centers).
3281pub(crate) fn chart_region(
3282 atom: &SaeManifoldAtom,
3283 center: Array1<f64>,
3284 radius: f64,
3285) -> ChartRegion {
3286 use crate::manifold::SaeAtomBasisKind::*;
3287 let region = ChartRegion::new(center.clone(), radius);
3288 match &atom.basis_kind {
3289 Duchon => {
3290 // r ranges over [‖t_c‖ − radius, ‖t_c‖ + radius] about the single
3291 // origin-anchored center used by the conservative radial bound.
3292 //
3293 // The lower bound must be `max(0, center_norm − radius)` — NOT floored
3294 // at `radius`. When the chart contains the kernel center
3295 // (`center_norm < radius`, true r_min = 0), flooring at `radius`
3296 // would give a finite, NON-CONSERVATIVE `r_min`, causing the
3297 // hessian_sup / third_sup formulas (which divide by r_min) to
3298 // underestimate the Lipschitz constant and potentially grant a false
3299 // Kantorovich certificate. Flooring at `f64::MIN_POSITIVE` instead
3300 // correctly drives the formulas toward ∞, producing a very large L
3301 // that will NEVER certify (rows route to the exact multi-start
3302 // fallback) — conservative and sound.
3303 let center_norm = center.dot(¢er).sqrt();
3304 let r_min = (center_norm - radius).max(f64::MIN_POSITIVE);
3305 let r_max = center_norm + radius;
3306 region.with_radial_bounds(r_min, r_max)
3307 }
3308 // Cylinder has no radial kernel block (it is a harmonic × polynomial
3309 // tensor, not a Duchon radial basis), so it needs no radial r_min/r_max.
3310 Periodic | Sphere | Torus | Cylinder | Linear | EuclideanPatch | Poincare
3311 | Precomputed(_) => region,
3312 }
3313}
3314
#[cfg(test)]
mod encode_fix_tests {
    //! Regression tests for the certified-encode helpers:
    //!   * FIX #3 — wrap-aware chart containment: `in_chart` measures latent
    //!     distance on the atom's periodic geometry via
    //!     [`latent_coordinate_distance`].
    //!   * FIX #4 — multiplicative sufficient-decrease bound on the warm-up
    //!     loop ([`warmup_progress_sufficient`]).
    //!   * Chart geometry — [`coord_dist_sq`] uses the canonical per-axis
    //!     period, and [`chart_nominal_radius`] follows the capped sphere grid.
    use super::*;
    use crate::manifold::SaeAtomBasisKind;
    use ndarray::{Array1, Array2, Array3};

    /// Smallest well-formed atom of the given kind. The metric under test only
    /// reads `basis_kind`, so a 2-function identity basis, a zero jet, a
    /// constant decoder and an identity penalty are enough.
    fn tiny_atom(kind: SaeAtomBasisKind, latent_dim: usize) -> SaeManifoldAtom {
        let basis_size = 2usize;
        SaeManifoldAtom::new(
            "tiny",
            kind,
            latent_dim,
            Array2::<f64>::eye(basis_size),
            Array3::<f64>::zeros((basis_size, basis_size, latent_dim)),
            Array2::<f64>::from_elem((basis_size, 1), 0.5),
            Array2::<f64>::eye(basis_size),
        )
        .expect("tiny atom builds")
    }

    /// FIX #3: a point just across the wrap seam of a periodic axis counts as
    /// IN-CHART under the wrap-aware metric, whereas the old raw-Euclidean sum
    /// rejected it. The predicate checked is the one `in_chart` evaluates:
    /// `latent_coordinate_distance(atom, t, center) <= radius`.
    #[test]
    fn wrap_aware_containment_accepts_seam_point() {
        let periodic = tiny_atom(SaeAtomBasisKind::Periodic, 1);
        let seam_point = Array1::from(vec![0.99f64]);
        let chart_center = Array1::from(vec![0.01f64]);
        let radius = 0.05;

        // Precondition: the raw difference is 0.98, far outside the 0.05 ball,
        // so the pre-fix containment test would have rejected this iterate.
        let raw = (seam_point[0] - chart_center[0]).abs();
        assert!(
            raw > radius,
            "precondition: raw Euclidean distance {raw} must exceed the chart radius {radius}"
        );

        // On the unit-period circle the true separation is 0.02.
        let d = latent_coordinate_distance(&periodic, seam_point.view(), chart_center.view());
        assert!(
            (d - 0.02).abs() < 1e-12,
            "wrap-aware distance across the seam must be 0.02, got {d}"
        );
        assert!(
            d <= radius,
            "wrap-aware seam point must now be in-chart (d={d} <= r={radius})"
        );
    }

    /// Soundness guard for FIX #3: a NON-periodic axis must not wrap. The same
    /// coordinates keep their full Euclidean separation and stay OUT of the
    /// small ball, preserving the never-issue-a-false-certificate invariant for
    /// flat-patch families.
    #[test]
    fn wrap_aware_containment_does_not_wrap_flat_axis() {
        let flat = tiny_atom(SaeAtomBasisKind::EuclideanPatch, 1);
        let far_point = Array1::from(vec![0.99f64]);
        let chart_center = Array1::from(vec![0.01f64]);
        let d = latent_coordinate_distance(&flat, far_point.view(), chart_center.view());
        assert!(
            (d - 0.98).abs() < 1e-12,
            "a flat (non-periodic) axis must keep the full 0.98 distance, got {d}"
        );
        assert!(d > 0.05, "flat-axis point correctly stays out of a 0.05 ball");
    }

    /// FIX #4: a genuinely converging (Kantorovich-quadratic) `h`-sequence is
    /// left alone — every step clears the sufficient-decrease bar, so no
    /// converging row is pushed to the fallback.
    #[test]
    fn converging_h_sequence_is_never_flagged() {
        // Quadratic Newton contraction h_{k+1} = h_k^2 from an uncertified start.
        let mut h = 0.9f64;
        while h > KANTOROVICH_THRESHOLD {
            let h_next = h * h;
            assert!(
                warmup_progress_sufficient(h_next, h),
                "converging step {h} -> {h_next} must be accepted (no regression)"
            );
            h = h_next;
        }
    }

    /// FIX #4: a monotone `h`-sequence decreasing toward a limit ABOVE ½ (the
    /// plateau the old strict-decrease rule kept spinning on until increments
    /// fell below one ulp) is flagged to the fallback after a small, bounded
    /// number of `warmup_progress_sufficient` steps.
    #[test]
    fn plateau_h_sequence_terminates_bounded() {
        let limit = 0.65f64; // plateau strictly above ½: can never certify
        let ratio = 0.8f64; // geometric approach toward `limit`
        let mut h = 2.0f64; // uncertified start
        let start = h;
        let mut accepted = 0usize;
        let mut iters = 0usize;
        loop {
            iters += 1;
            // Hard ceiling proving boundedness. Under the OLD strict-decrease
            // rule `h` keeps shrinking toward 0.65 for ~1e15 steps, so hitting
            // this assert signals a regression to the unbounded behaviour.
            assert!(
                iters < 10_000,
                "warm-up must terminate in a bounded number of steps, not loop"
            );
            let h_next = limit + (h - limit) * ratio; // monotone, stays > limit
            if warmup_progress_sufficient(h_next, h) {
                accepted += 1;
                h = h_next;
            } else {
                break; // flag to the exact fallback
            }
        }
        // Real progress was made before flagging (a warm-up, not an instant bail)…
        assert!(
            accepted >= 1,
            "expected the warm-up to accept at least one contracting step first"
        );
        // …and the flag came while still strictly above the plateau limit,
        // exactly where the OLD rule would have carried on.
        assert!(
            h > limit && h < start,
            "flagged mid-descent (limit < h={h} < start={start}); old rule would not stop here"
        );
        // The step count is tiny, not astronomical.
        assert!(iters < 200, "plateau flagged in {iters} steps (bounded)");
    }

    /// FIX #4: a non-finite `h` (indefinite / blown-up Hessian) counts as
    /// insufficient progress — the row is flagged rather than accepted.
    #[test]
    fn non_finite_h_flags() {
        assert!(!warmup_progress_sufficient(f64::NAN, 0.9));
        assert!(!warmup_progress_sufficient(f64::INFINITY, 0.9));
        assert!(!warmup_progress_sufficient(0.5, f64::NAN));
    }

    /// BUG (sphere data-driven placement): `coord_dist_sq` must measure sphere
    /// coordinates in RADIANS through the canonical `latent_axis_period` —
    /// longitude (axis 1) wraps at 2π, latitude (axis 0) does not wrap — rather
    /// than at unit period on both axes. The old period-1 wrap collapsed
    /// far-apart longitudes to distance 0 and corrupted
    /// `data_driven_chart_centers` for sphere atoms.
    #[test]
    fn coord_dist_sq_sphere_uses_radian_metric_not_unit_period() {
        let sphere = tiny_atom(SaeAtomBasisKind::Sphere, 2);

        // Longitudes 3.0 rad apart: circle distance min(3, 2π−3) = 3
        // (2π−3 ≈ 3.283) ⇒ dist² = 9. The old period-1 wrap read 3 mod 1 = 0.
        let origin = Array1::from(vec![0.0f64, 0.0]);
        let lon_shifted = Array1::from(vec![0.0f64, 3.0]);
        let dsq = coord_dist_sq(&sphere, origin.view(), lon_shifted.view());
        assert!(
            (dsq - 9.0).abs() < 1e-9,
            "sphere longitude dist² must be 3²=9 (2π radian period), got {dsq}"
        );

        // Latitudes 1.2 rad apart must NOT wrap ⇒ dist² = 1.44.
        let lat_base = Array1::from(vec![0.0f64, 0.0]);
        let lat_shifted = Array1::from(vec![1.2f64, 0.0]);
        let dsq_lat = coord_dist_sq(&sphere, lat_base.view(), lat_shifted.view());
        assert!(
            (dsq_lat - 1.44).abs() < 1e-9,
            "sphere latitude must not wrap; 1.2²=1.44, got {dsq_lat}"
        );

        // Single source of truth: agree with the certified-encode metric.
        let certified = latent_coordinate_distance(&sphere, origin.view(), lon_shifted.view());
        assert!(
            (certified * certified - dsq).abs() < 1e-9,
            "coord_dist_sq must equal latent_coordinate_distance² ({} vs {dsq})",
            certified * certified
        );
    }

    /// No-regression: periodic / torus axes still wrap at unit period.
    #[test]
    fn coord_dist_sq_torus_still_wraps_unit_period() {
        let torus = tiny_atom(SaeAtomBasisKind::Torus, 1);
        let near_zero = Array1::from(vec![0.02f64]);
        let near_one = Array1::from(vec![0.98f64]);
        // Wrapped circle distance min(0.96, 0.04) = 0.04 ⇒ dist² = 0.0016.
        let dsq = coord_dist_sq(&torus, near_zero.view(), near_one.view());
        assert!((dsq - 0.0016).abs() < 1e-12, "torus unit-period wrap; got {dsq}");
    }

    /// BUG (sphere chart tiling): for the sphere, `chart_nominal_radius` must
    /// use the SAME capped per-axis count `min(resolution, 22)` as
    /// `sphere_latlon_grid`, so the radius covers the capped longitude
    /// half-spacing π/22 and the charts tile without gaps when resolution > 22
    /// (the raw π/40 would leave gaps).
    #[test]
    fn chart_nominal_radius_sphere_covers_capped_grid_spacing() {
        let sphere = tiny_atom(SaeAtomBasisKind::Sphere, 2);
        let lon_half_spacing = std::f64::consts::PI / 22.0; // r capped at 22
        let r = chart_nominal_radius(&sphere, 40);
        assert!(
            r >= lon_half_spacing - 1e-12,
            "sphere radius {r} must cover the capped lon half-spacing {lon_half_spacing} \
             (no gaps for resolution>22); raw π/40={} would gap",
            std::f64::consts::PI / 40.0
        );
    }
}