Skip to main content

gam_sae/gpu_kernels/
sae_encode_resident.rs

1//! Device-resident **exact per-row certified SAE encode** (#988).
2//!
3//! The production CPU encode is [`crate::encode::EncodeAtlas::certified_encode_row`]:
4//! for one atom and one target row `x` at fixed amplitude `z` it
5//!
6//!   1. **routes** the row to the `topk` nearest certified charts by ambient
7//!      reconstruction distance `‖BᵀΦ(t_c) − x‖²` (the *active-set routing*),
8//!   2. **warm-starts** each candidate from that chart's distilled IFT affine
9//!      predictor `t̂ = t_c + (1/z)·A₁·(x − z·m₁)`,
10//!   3. runs the **per-row latent-coordinate Newton** solve inside the
11//!      Kantorovich basin: at each iterate it forms the FULL Hessian
12//!      `H = JₘᵀJₘ + r·∂²m + ridge·I`, takes the Newton step `δ = −H⁻¹g`, and
13//!      evaluates the certificate `h = β·η·L` (`β = 1/λ_min(H)`, `η = ‖δ‖`),
14//!      first navigating into the basin (`h ≤ ½`) then refining `newton_steps`,
15//!   4. **assigns** the row to the lowest-reconstruction-error CERTIFIED
16//!      candidate (the *assignment/gate solve*), and
17//!   5. otherwise returns the nearest chart's uncertified result — the
18//!      *certificate/fallback* the exact multi-start solve owns.
19//!
20//! This module ships that whole pipeline as a **device kernel** for the
21//! `EuclideanPatch` monomial family (the atom family whose basis
22//! `Φ_α(t) = Π_axis t_axis^{α_axis}` is closed-form-evaluable on-device with
23//! exact first/second jets — see [`crate::basis::EuclideanPatchEvaluator`]).
24//! One CUDA block encodes one row; the per-row work is done serially by the
25//! block's lead thread so the accumulation order is byte-identical to the
26//! host oracle (the same `tid == 0` idiom the fused Arrow-Schur kernel in
27//! `gam_solve::gpu_kernels::arrow_schur_nvrtc` uses for its Cholesky).
28//!
29//! # Correctness without a GPU
30//!
31//! Exactly the #1017 pattern of `arrow_schur_nvrtc`:
32//!
33//! * [`emulate_certified_encode_row`] is a device-free CPU emulator that mirrors
34//!   the kernel's arithmetic and control flow line-for-line — the SAME monomial
35//!   evaluation, the SAME cyclic-Jacobi symmetric eigensolver
36//!   ([`jacobi_eigh`], the device stand-in for the host LAPACK `eigh`), the SAME
37//!   basin-warmup / refine loop, the SAME routing + assignment. It is the CPU
38//!   fallback AND the exactness oracle the kernel is pinned to.
39//! * The parity tests assert the emulator reproduces the production
40//!   [`crate::encode::EncodeAtlas::certified_encode_row`] on planted + random
41//!   rows (support/coords/amplitude/certificate within a tight tol; the only
42//!   divergence is Jacobi-vs-LAPACK eigen round-off).
43//! * On Linux the CUDA source compiles to PTX through the shared
44//!   `--fmad=false` NVRTC options ([`gam_gpu::device_cache::compile_ptx_arch`]),
45//!   matching the sibling kernels; a device, when present, runs it and the
46//!   dispatch reports [`EncodePath::Device`] honestly (the #1026/#1551 gate).
47//!
48//! # What still needs real hardware
49//!
50//! Running the PTX (a launch on a CUDA device) and confirming device==emulator
51//! to round-off requires a GPU. Everything else — the kernel source, the
52//! emulator, the parity against production, and (on a CUDA host) the NVRTC→PTX
53//! compile + PTX audit — is verified without one.
54
55use std::time::Instant;
56
57use crate::encode::{
58    AtlasConfig, AtomEncodeAtlas, KANTOROVICH_THRESHOLD, euclidean_patch_degree,
59};
60use crate::manifold::SaeManifoldAtom;
61use gam_gpu::policy::{EncodeDecisionBlocked, EncodeDeploymentDecision};
62
/// Frozen encode data for a single `EuclideanPatch` atom, flattened into plain
/// row-major buffers suitable for a device launch.
///
/// It holds exactly what the online encode reads: the monomial exponent table,
/// the decoder `B`, and the offline-certified charts. It is built from a real
/// atom plus its [`AtomEncodeAtlas`] by [`EncodeAtomDevice::from_atom_atlas`],
/// so the device path and the CPU path consume identical data.
#[derive(Debug, Clone)]
pub struct EncodeAtomDevice {
    /// Latent dimension `d` (number of coordinate axes).
    pub d: usize,
    /// Basis size `m`: how many monomials of total degree ≤ the patch degree.
    pub m: usize,
    /// Output (ambient) dimension `p`.
    pub p: usize,
    /// How many nearest charts are refined per row (`CERTIFIED_ROUTING_TOPK`).
    pub topk: usize,
    /// Upper bound on online Newton refinement steps after a certified landing.
    pub newton_steps: usize,
    /// Levenberg ridge added to every diagonal entry of the per-row Hessian.
    pub ridge: f64,
    /// Monomial exponents; `m*d` entries laid out as `exponents[col*d + axis]`.
    pub exponents: Vec<i32>,
    /// Decoder `B`; `m*p` entries laid out as `decoder[basis*p + out]`.
    pub decoder: Vec<f64>,
    /// Offline-certified charts (routing, warm-start and certificate constants).
    pub charts: Vec<EncodeChartDevice>,
}

/// A single offline-certified chart in flattened form.
#[derive(Debug, Clone)]
pub struct EncodeChartDevice {
    /// Chart center `t_c` (length `d`).
    pub center: Vec<f64>,
    /// In-chart radius: iterates must stay inside this ball around `center`.
    pub radius: f64,
    /// Certified Newton radius; the chart is routable only when it is `> 0`.
    pub certified_radius: f64,
    /// Closed-form Hessian-Lipschitz constant `L` over the chart.
    pub lipschitz: f64,
    /// Whether a distilled IFT Jacobian `A₁` is present (finite β).
    pub has_jacobian: bool,
    /// `A₁` laid out as `a1[axis*p + out]` (`d*p` entries; empty when `!has_jacobian`).
    pub amortized_jacobian: Vec<f64>,
    /// Amplitude-1 reconstruction at the center, `m₁ = BᵀΦ(t_c)` (length `p`).
    pub recon_center: Vec<f64>,
}
108
109impl EncodeAtomDevice {
110    /// Extract the device encode data from a real `EuclideanPatch` atom and its
111    /// offline atlas. Recomputes the monomial exponent table (the atom's own
112    /// basis design) so the on-device evaluation is the SAME polynomial the host
113    /// `EuclideanPatchEvaluator` evaluates.
114    pub fn from_atom_atlas(
115        atom: &SaeManifoldAtom,
116        atom_atlas: &AtomEncodeAtlas,
117        config: &AtlasConfig,
118    ) -> Result<Self, String> {
119        let d = atom.latent_dim;
120        let p = atom.output_dim();
121        let m = atom.basis_size();
122        let degree = euclidean_patch_degree(d, m);
123        let exps = gam_terms::basis::monomial_exponents(d, degree);
124        if exps.len() != m {
125            return Err(format!(
126                "EncodeAtomDevice::from_atom_atlas: monomial table len {} != basis_size {m} \
127                 (atom is not a EuclideanPatch degree-{degree} monomial family)",
128                exps.len()
129            ));
130        }
131        let mut exponents = vec![0_i32; m * d];
132        for (col, alpha) in exps.iter().enumerate() {
133            for axis in 0..d {
134                exponents[col * d + axis] = alpha[axis] as i32;
135            }
136        }
137        let dec = &atom.decoder_coefficients;
138        if dec.dim() != (m, p) {
139            return Err(format!(
140                "EncodeAtomDevice::from_atom_atlas: decoder dim {:?} != ({m}, {p})",
141                dec.dim()
142            ));
143        }
144        let mut decoder = vec![0.0_f64; m * p];
145        for b in 0..m {
146            for c in 0..p {
147                decoder[b * p + c] = dec[[b, c]];
148            }
149        }
150        let mut charts = Vec::with_capacity(atom_atlas.charts.len());
151        for chart in &atom_atlas.charts {
152            let center = chart.region.center.to_vec();
153            if center.len() != d {
154                return Err(format!(
155                    "EncodeAtomDevice::from_atom_atlas: chart center len {} != d {d}",
156                    center.len()
157                ));
158            }
159            let (has_jacobian, amortized_jacobian) = match &chart.amortized_jacobian {
160                Some(a1) => {
161                    if a1.dim() != (d, p) {
162                        return Err(format!(
163                            "EncodeAtomDevice::from_atom_atlas: A1 dim {:?} != ({d}, {p})",
164                            a1.dim()
165                        ));
166                    }
167                    let mut flat = vec![0.0_f64; d * p];
168                    for axis in 0..d {
169                        for out in 0..p {
170                            flat[axis * p + out] = a1[[axis, out]];
171                        }
172                    }
173                    (true, flat)
174                }
175                None => (false, Vec::new()),
176            };
177            let recon_center = chart.recon_center.to_vec();
178            charts.push(EncodeChartDevice {
179                center,
180                radius: chart.region.radius,
181                certified_radius: chart.certified_radius,
182                lipschitz: chart.lipschitz,
183                has_jacobian,
184                amortized_jacobian,
185                recon_center,
186            });
187        }
188        Ok(Self {
189            d,
190            m,
191            p,
192            topk: crate::encode::CERTIFIED_ROUTING_TOPK,
193            newton_steps: config.newton_steps,
194            ridge: config.ridge,
195            exponents,
196            decoder,
197            charts,
198        })
199    }
200}
201
202/// A per-row Kantorovich certificate, the device/emulator mirror of
203/// [`crate::encode::RowCertificate`]. `certified()` uses the SAME `h ≤ ½` gate.
204#[derive(Debug, Clone, Copy, PartialEq)]
205pub struct DeviceRowCertificate {
206    pub beta: f64,
207    pub eta: f64,
208    pub lipschitz: f64,
209    pub h: f64,
210}
211
212impl DeviceRowCertificate {
213    #[inline]
214    #[must_use]
215    pub fn certified(&self) -> bool {
216        self.h.is_finite() && self.h <= KANTOROVICH_THRESHOLD
217    }
218    #[inline]
219    fn uncertified(lipschitz: f64) -> Self {
220        Self {
221            beta: f64::INFINITY,
222            eta: f64::INFINITY,
223            lipschitz,
224            h: f64::INFINITY,
225        }
226    }
227    #[inline]
228    fn uncertified_inf() -> Self {
229        Self {
230            beta: f64::INFINITY,
231            eta: f64::INFINITY,
232            lipschitz: f64::INFINITY,
233            h: f64::INFINITY,
234        }
235    }
236}
237
238/// One row's encode result: the latent coordinate and its certificate.
239#[derive(Debug, Clone)]
240pub struct DeviceEncodeRow {
241    pub coord: Vec<f64>,
242    pub cert: DeviceRowCertificate,
243}
244
245// ============================================================================
246// Numeric core — the byte-faithful CPU mirror of the device kernel. Every
247// function here has a 1:1 CUDA counterpart in `ENCODE_KERNEL_SOURCE`; the CUDA
248// comments name the mirror. These are also the CPU fallback path.
249// ============================================================================
250
/// Integer power `base^exp` used for every monomial factor.
///
/// Delegates to [`f64::powi`], the same call the production monomial code
/// makes, so the emulator's `phi`/`jet`/`hess` stay bit-identical to the host
/// `EuclideanPatchEvaluator`. The CUDA `dpow` in `ENCODE_KERNEL_SOURCE`
/// hand-rolls exponentiation-by-squaring to reproduce it on the device.
#[inline]
fn dpow(base: f64, exp: i32) -> f64 {
    f64::powi(base, exp)
}
261
262/// Monomial basis value/first/second jets at one coordinate `t` (length `d`).
263/// Mirrors [`crate::basis::EuclideanPatchEvaluator::evaluate`] +
264/// [`crate::basis::EuclideanPatchEvaluator::second_jet`] (the same falling-
265/// factorial monomial derivatives), producing:
266///   `phi[col]`, `jet[col*d + axis]`, `hess[(col*d + a)*d + c]`.
267fn eval_basis(dev: &EncodeAtomDevice, t: &[f64], phi: &mut [f64], jet: &mut [f64], hess: &mut [f64]) {
268    let (d, m) = (dev.d, dev.m);
269    let exp = &dev.exponents;
270    for col in 0..m {
271        // value = Π_axis t_axis^{α_axis}
272        let mut value = 1.0_f64;
273        for axis in 0..d {
274            let e = exp[col * d + axis];
275            if e != 0 {
276                value *= dpow(t[axis], e);
277            }
278        }
279        phi[col] = value;
280        // first jet: ∂/∂t_axis = α_axis · Π_a t_a^{(a==axis? α_a-1 : α_a)}
281        for axis in 0..d {
282            let a_axis = exp[col * d + axis];
283            let mut jval = 0.0_f64;
284            if a_axis != 0 {
285                jval = a_axis as f64;
286                for a in 0..d {
287                    let ea = if a == axis { a_axis - 1 } else { exp[col * d + a] };
288                    if ea != 0 {
289                        jval *= dpow(t[a], ea);
290                    }
291                }
292            }
293            jet[col * d + axis] = jval;
294        }
295        // second jet: ∂²/∂t_a∂t_c (falling factorial), else 0.
296        for a in 0..d {
297            for c in 0..d {
298                let mut hval = 0.0_f64;
299                let aa = exp[col * d + a];
300                let ac = exp[col * d + c];
301                let admissible = aa != 0 && (a == c || ac != 0);
302                if admissible {
303                    let lead = if a == c {
304                        (aa as f64) * ((aa - 1).max(0) as f64)
305                    } else {
306                        (aa as f64) * (ac as f64)
307                    };
308                    if lead != 0.0 {
309                        hval = lead;
310                        for axis in 0..d {
311                            let mut e = exp[col * d + axis];
312                            if axis == a {
313                                e = (e - 1).max(0);
314                            }
315                            if axis == c {
316                                e = (e - 1).max(0);
317                            }
318                            if e != 0 {
319                                hval *= dpow(t[axis], e);
320                            }
321                        }
322                    }
323                }
324                hess[(col * d + a) * d + c] = hval;
325            }
326        }
327    }
328}
329
330/// Amplitude-1 reconstruction `m₁(t) = BᵀΦ(t)` from precomputed `phi`.
331/// (Routing + reconstruction-error use this; `nearest_chart` mirror.)
332fn recon_amp1(dev: &EncodeAtomDevice, phi: &[f64], out: &mut [f64]) {
333    let (m, p) = (dev.m, dev.p);
334    for c in 0..p {
335        out[c] = 0.0;
336    }
337    for b in 0..m {
338        let pv = phi[b];
339        if pv == 0.0 {
340            continue;
341        }
342        for c in 0..p {
343            out[c] += pv * dev.decoder[b * p + c];
344        }
345    }
346}
347
348/// Evaluated basis buffers at a point `t`: value `Φ`, first jet `∂Φ`, and the
349/// second jet `∂²Φ`. Bundled so [`encode_grad_hess`] takes them as one argument.
350struct EvaluatedBasis<'a> {
351    phi: &'a [f64],
352    jet: &'a [f64],
353    hess: &'a [f64],
354}
355
356/// Gradient `g` and FULL Hessian `H` (+ ridge) of the encode objective at `t`.
357/// Mirror of [`crate::encode::encode_grad_hess`]:
358///   `g[a] = Jₘ[a]·r`,  `H[a,b] = Jₘ[a]·Jₘ[b] + z·Σ ∂²Φ·(r·B) + ridge·δ_ab`,
359/// with `m = z·BᵀΦ`, `r = m − x`, `Jₘ = z·BᵀJ_Φ`. For the monomial family the
360/// second jet always exists, so this never returns "no certificate".
361fn encode_grad_hess(
362    dev: &EncodeAtomDevice,
363    x: &[f64],
364    amplitude: f64,
365    be: &EvaluatedBasis<'_>,
366    g: &mut [f64],
367    h: &mut [f64],
368) {
369    let (phi, jet, hess) = (be.phi, be.jet, be.hess);
370    let (d, m, p) = (dev.d, dev.m, dev.p);
371    // recon m(t) = z·BᵀΦ ; residual r = m − x
372    let mut recon = vec![0.0_f64; p];
373    for b in 0..m {
374        let pv = phi[b];
375        if pv == 0.0 {
376            continue;
377        }
378        for c in 0..p {
379            recon[c] += amplitude * pv * dev.decoder[b * p + c];
380        }
381    }
382    let mut residual = vec![0.0_f64; p];
383    for c in 0..p {
384        residual[c] = recon[c] - x[c];
385    }
386    // Jₘ[axis][out] = z·Bᵀ ∂Φ/∂t_axis  (row-major jm[axis*p + out])
387    let mut jm = vec![0.0_f64; d * p];
388    for axis in 0..d {
389        for b in 0..m {
390            let dphi = jet[b * d + axis];
391            if dphi == 0.0 {
392                continue;
393            }
394            for c in 0..p {
395                jm[axis * p + c] += amplitude * dphi * dev.decoder[b * p + c];
396            }
397        }
398    }
399    // g[a] = Jₘ[a]·r ; H[a,b] = Jₘ[a]·Jₘ[b] + z·Σ_b ∂²Φ·(r·B)
400    for a in 0..d {
401        let mut ga = 0.0;
402        for c in 0..p {
403            ga += jm[a * p + c] * residual[c];
404        }
405        g[a] = ga;
406        for b in 0..d {
407            let mut hab = 0.0;
408            for c in 0..p {
409                hab += jm[a * p + c] * jm[b * p + c];
410            }
411            let mut curv = 0.0;
412            for basis in 0..m {
413                let d2 = hess[(basis * d + a) * d + b];
414                if d2 == 0.0 {
415                    continue;
416                }
417                let mut dot = 0.0;
418                for c in 0..p {
419                    dot += residual[c] * dev.decoder[basis * p + c];
420                }
421                curv += amplitude * d2 * dot;
422            }
423            hab += curv;
424            h[a * d + b] = hab;
425        }
426    }
427    for a in 0..d {
428        h[a * d + a] += dev.ridge;
429    }
430}
431
/// Cyclic Jacobi eigensolver for a symmetric `d×d` matrix (row-major input,
/// intended for small `d`, `d ≤ 8`).
///
/// On return `vals[i]` holds the i-th eigenvalue, and column `i` of the
/// eigenvector matrix is stored contiguously as `vecs[i*d + row]`. This is the
/// device stand-in for the host LAPACK `eigh` used by
/// [`crate::encode::beta_eta_newton`]. Callers rebuild the Newton step from the
/// eigenbasis-independent spectral sum, so results agree with LAPACK up to
/// eigen round-off. The CUDA `jacobi_eigh` mirror applies the identical
/// rotation sequence.
///
/// The schedule is deterministic: at most 30 cyclic sweeps over the strict
/// upper triangle, stopping early once the off-diagonal sum of squares is
/// `≤ 1e-300`; exactly-zero off-diagonal entries are skipped.
pub fn jacobi_eigh(a_in: &[f64], d: usize, vals: &mut [f64], vecs: &mut [f64]) {
    // Golub & Van Loan 8.4.1: (cos, sin) of the rotation that annihilates a_pq.
    fn rotation(app: f64, aqq: f64, apq: f64) -> (f64, f64) {
        let tau = (aqq - app) / (2.0 * apq);
        let root = (1.0 + tau * tau).sqrt();
        let t = if tau >= 0.0 {
            1.0 / (tau + root)
        } else {
            -1.0 / (-tau + root)
        };
        let cos = 1.0 / (1.0 + t * t).sqrt();
        (cos, t * cos)
    }

    const MAX_SWEEPS: usize = 30;
    const OFF_DIAGONAL_FLOOR: f64 = 1e-300;

    let mut work = a_in.to_vec();
    // V = I.
    for col in 0..d {
        for row in 0..d {
            vecs[col * d + row] = if col == row { 1.0 } else { 0.0 };
        }
    }
    if d == 1 {
        vals[0] = work[0];
        return;
    }

    for _ in 0..MAX_SWEEPS {
        let off = (0..d)
            .flat_map(|r| ((r + 1)..d).map(move |c| r * d + c))
            .fold(0.0_f64, |acc, idx| acc + work[idx] * work[idx]);
        if off <= OFF_DIAGONAL_FLOOR {
            break;
        }
        for p in 0..d {
            for q in (p + 1)..d {
                let apq = work[p * d + q];
                if apq == 0.0 {
                    continue;
                }
                let (cos, sin) = rotation(work[p * d + p], work[q * d + q], apq);
                // A ← A·J (columns p, q).
                for k in 0..d {
                    let (left, right) = (work[k * d + p], work[k * d + q]);
                    work[k * d + p] = cos * left - sin * right;
                    work[k * d + q] = sin * left + cos * right;
                }
                // A ← Jᵀ·A (rows p, q).
                for k in 0..d {
                    let (top, bottom) = (work[p * d + k], work[q * d + k]);
                    work[p * d + k] = cos * top - sin * bottom;
                    work[q * d + k] = sin * top + cos * bottom;
                }
                // V ← V·J (eigenvector columns p, q).
                for k in 0..d {
                    let (vp, vq) = (vecs[p * d + k], vecs[q * d + k]);
                    vecs[p * d + k] = cos * vp - sin * vq;
                    vecs[q * d + k] = sin * vp + cos * vq;
                }
            }
        }
    }

    for (i, v) in vals[..d].iter_mut().enumerate() {
        *v = work[i * d + i];
    }
}
507
508/// `(β, η, δ)` from the full Hessian `H` and gradient `g`. Mirror of
509/// [`crate::encode::beta_eta_newton`]: `β = 1/λ_min`, `δ = −Σ_i (vᵢᵀg/λᵢ)vᵢ`,
510/// `η = ‖δ‖`; `None` when `λ_min ≤ 0` (uncertifiable start).
511fn beta_eta_newton(h: &[f64], g: &[f64], d: usize) -> Option<(f64, f64, Vec<f64>)> {
512    let mut vals = vec![0.0_f64; d];
513    let mut vecs = vec![0.0_f64; d * d];
514    jacobi_eigh(h, d, &mut vals, &mut vecs);
515    let mut lambda_min = f64::INFINITY;
516    for &v in &vals {
517        if v < lambda_min {
518            lambda_min = v;
519        }
520    }
521    if !(lambda_min.is_finite() && lambda_min > 0.0) {
522        return None;
523    }
524    let beta = 1.0 / lambda_min;
525    let mut delta = vec![0.0_f64; d];
526    for col in 0..d {
527        let lam = vals[col];
528        if lam <= 0.0 {
529            return None;
530        }
531        // vᵀg
532        let mut vg = 0.0;
533        for row in 0..d {
534            vg += vecs[col * d + row] * g[row];
535        }
536        let coeff = vg / lam;
537        for row in 0..d {
538            delta[row] -= coeff * vecs[col * d + row];
539        }
540    }
541    let mut eta = 0.0;
542    for row in 0..d {
543        eta += delta[row] * delta[row];
544    }
545    Some((beta, eta.sqrt(), delta))
546}
547
548/// Certificate + Newton step at `t`. Mirror of [`crate::encode::row_certificate`].
549fn row_certificate(
550    dev: &EncodeAtomDevice,
551    t: &[f64],
552    x: &[f64],
553    amplitude: f64,
554    lipschitz: f64,
555    scratch: &mut Scratch,
556) -> (DeviceRowCertificate, Vec<f64>) {
557    let d = dev.d;
558    eval_basis(dev, t, &mut scratch.phi, &mut scratch.jet, &mut scratch.hess);
559    encode_grad_hess(
560        dev,
561        x,
562        amplitude,
563        &EvaluatedBasis {
564            phi: &scratch.phi,
565            jet: &scratch.jet,
566            hess: &scratch.hess,
567        },
568        &mut scratch.g,
569        &mut scratch.h,
570    );
571    match beta_eta_newton(&scratch.h, &scratch.g, d) {
572        Some((beta, eta, delta)) => (
573            DeviceRowCertificate {
574                beta,
575                eta,
576                lipschitz,
577                h: beta * eta * lipschitz,
578            },
579            delta,
580        ),
581        None => (
582            DeviceRowCertificate::uncertified(lipschitz),
583            vec![0.0_f64; d],
584        ),
585    }
586}
587
588/// Per-row working buffers (register/stack arrays in the kernel).
589struct Scratch {
590    phi: Vec<f64>,
591    jet: Vec<f64>,
592    hess: Vec<f64>,
593    g: Vec<f64>,
594    h: Vec<f64>,
595}
596
597impl Scratch {
598    fn new(dev: &EncodeAtomDevice) -> Self {
599        let (d, m) = (dev.d, dev.m);
600        Self {
601            phi: vec![0.0; m],
602            jet: vec![0.0; m * d],
603            hess: vec![0.0; m * d * d],
604            g: vec![0.0; d],
605            h: vec![0.0; d * d],
606        }
607    }
608}
609
/// Whether `t` lies in the closed Euclidean ball of `radius` around `center`.
///
/// Compares the squared distance with `radius²`, so no square root is taken.
/// Only the first `t.len()` coordinates of `center` are read.
#[inline]
fn in_chart(t: &[f64], center: &[f64], radius: f64) -> bool {
    let dist2 = t.iter().enumerate().fold(0.0, |acc, (i, &ti)| {
        let off = ti - center[i];
        acc + off * off
    });
    dist2 <= radius * radius
}
619
620/// Basin-warmup + refine from `t_start`. Mirror of
621/// [`crate::encode::certify_with_basin_warmup`] composed with
622/// `refine_certified_start`: navigate into the `h ≤ ½` basin (staying in-chart,
623/// requiring `h` to contract), then take `newton_steps` refine steps that must
624/// all stay certified. Returns `(coord, landing_cert)` or `None`.
625fn certify_with_basin_warmup(
626    dev: &EncodeAtomDevice,
627    mut t: Vec<f64>,
628    x: &[f64],
629    amplitude: f64,
630    chart: &EncodeChartDevice,
631    scratch: &mut Scratch,
632) -> Option<(Vec<f64>, DeviceRowCertificate)> {
633    if !in_chart(&t, &chart.center, chart.radius) {
634        return None;
635    }
636    let (mut cert, mut delta) =
637        row_certificate(dev, &t, x, amplitude, chart.lipschitz, scratch);
638    while !cert.certified() {
639        if !(cert.h.is_finite() && cert.beta.is_finite() && cert.eta.is_finite()) {
640            return None;
641        }
642        let prev_h = cert.h;
643        let mut next = t.clone();
644        for i in 0..dev.d {
645            next[i] += delta[i];
646        }
647        if !in_chart(&next, &chart.center, chart.radius) {
648            return None;
649        }
650        t = next;
651        let (nc, nd) = row_certificate(dev, &t, x, amplitude, chart.lipschitz, scratch);
652        cert = nc;
653        delta = nd;
654        if !cert.h.is_finite() || cert.h >= prev_h {
655            return None;
656        }
657    }
658    // refine_certified_start: `newton_steps` further, must stay certified.
659    // Mirror production's convergence early-exit: once the pending Newton step is
660    // below the coordinate ULP scale the certified root is reached and further steps
661    // only re-accumulate round-off (keeps device parity with the encode.rs fold).
662    let landing = cert;
663    for _ in 0..dev.newton_steps {
664        let dnorm = delta.iter().map(|v| v * v).sum::<f64>().sqrt();
665        let tnorm = t.iter().map(|v| v * v).sum::<f64>().sqrt();
666        if dnorm <= crate::encode::NEWTON_REFINE_CONVERGED_EPS * (1.0 + tnorm) {
667            break;
668        }
669        for i in 0..dev.d {
670            t[i] += delta[i];
671        }
672        let (nc, nd) = row_certificate(dev, &t, x, amplitude, chart.lipschitz, scratch);
673        if !nc.certified() {
674            return None;
675        }
676        delta = nd;
677    }
678    Some((t, landing))
679}
680
681/// Distilled affine warm start `t̂ = t_c + (1/z)·A₁·(x − z·m₁)`. Mirror of
682/// [`crate::encode::amortized_warm_start`]. `None` when the chart has no
683/// Jacobian or the amplitude is not strictly positive & finite.
684fn amortized_warm_start(chart: &EncodeChartDevice, x: &[f64], amplitude: f64, d: usize, p: usize) -> Option<Vec<f64>> {
685    if !chart.has_jacobian {
686        return None;
687    }
688    if !(amplitude.is_finite() && amplitude.abs() > 0.0) {
689        return None;
690    }
691    let mut t_hat = chart.center.clone();
692    for out in 0..p.min(chart.recon_center.len()) {
693        let resid = x[out] - amplitude * chart.recon_center[out];
694        for axis in 0..d {
695            t_hat[axis] += chart.amortized_jacobian[axis * p + out] * resid / amplitude;
696        }
697    }
698    Some(t_hat)
699}
700
701/// Reconstruction error `‖x − z·m(t)‖`. Mirror of
702/// [`crate::encode::encode_reconstruction_error`].
703fn recon_error(dev: &EncodeAtomDevice, t: &[f64], x: &[f64], amplitude: f64, scratch: &mut Scratch) -> f64 {
704    eval_basis(dev, t, &mut scratch.phi, &mut scratch.jet, &mut scratch.hess);
705    let mut err2 = 0.0;
706    let p = dev.p;
707    let mut recon = vec![0.0_f64; p];
708    recon_amp1(dev, &scratch.phi, &mut recon);
709    for c in 0..p {
710        let r = x[c] - amplitude * recon[c];
711        err2 += r * r;
712    }
713    if err2.is_finite() { err2.sqrt() } else { f64::INFINITY }
714}
715
716/// Top-`k` charts by center reconstruction distance, sorted by (distance, index)
717/// — mirror of [`crate::encode::nearest_charts_topk`]. Only certifiable charts
718/// (`certified_radius > 0`) are considered.
719fn nearest_charts_topk(dev: &EncodeAtomDevice, x: &[f64], scratch: &mut Scratch) -> Vec<usize> {
720    if dev.charts.is_empty() || dev.topk == 0 {
721        return Vec::new();
722    }
723    let p = dev.p;
724    let mut scored: Vec<(usize, f64)> = Vec::new();
725    let mut recon = vec![0.0_f64; p];
726    for (idx, chart) in dev.charts.iter().enumerate() {
727        if chart.certified_radius <= 0.0 {
728            continue;
729        }
730        eval_basis(dev, &chart.center, &mut scratch.phi, &mut scratch.jet, &mut scratch.hess);
731        recon_amp1(dev, &scratch.phi, &mut recon);
732        let mut dist = 0.0;
733        for c in 0..p {
734            let diff = recon[c] - x[c];
735            dist += diff * diff;
736        }
737        scored.push((idx, dist));
738    }
739    scored.sort_by(|a, b| {
740        a.1.partial_cmp(&b.1)
741            .unwrap_or(std::cmp::Ordering::Equal)
742            .then(a.0.cmp(&b.0))
743    });
744    scored.into_iter().take(dev.topk).map(|(i, _)| i).collect()
745}
746
747/// The full exact per-row certified encode for one `EuclideanPatch` atom — the
748/// device-free mirror of [`crate::encode::EncodeAtlas::certified_encode_row`].
749/// This is BOTH the CPU fallback and the exactness oracle the CUDA kernel is
750/// pinned to (the kernel does exactly this, one block per row).
751#[must_use]
752pub fn emulate_certified_encode_row(dev: &EncodeAtomDevice, x: &[f64], amplitude: f64) -> DeviceEncodeRow {
753    let d = dev.d;
754    let p = dev.p;
755    let mut scratch = Scratch::new(dev);
756    let candidates = nearest_charts_topk(dev, x, &mut scratch);
757    if candidates.is_empty() {
758        return DeviceEncodeRow {
759            coord: vec![0.0; d],
760            cert: DeviceRowCertificate::uncertified_inf(),
761        };
762    }
763    let mut best: Option<(Vec<f64>, DeviceRowCertificate, f64)> = None;
764    let mut nearest_fallback: Option<(Vec<f64>, DeviceRowCertificate)> = None;
765    for chart_idx in candidates {
766        let chart = &dev.charts[chart_idx];
767        let Some(t_hat) = amortized_warm_start(chart, x, amplitude, d, p) else {
768            if nearest_fallback.is_none() {
769                nearest_fallback = Some((vec![0.0; d], DeviceRowCertificate::uncertified(chart.lipschitz)));
770            }
771            continue;
772        };
773        let (coord, cert) = match certify_with_basin_warmup(dev, t_hat, x, amplitude, chart, &mut scratch) {
774            Some((c, cert)) => (c, cert),
775            None => (vec![0.0; d], DeviceRowCertificate::uncertified(chart.lipschitz)),
776        };
777        if nearest_fallback.is_none() {
778            nearest_fallback = Some((coord.clone(), cert));
779        }
780        if cert.certified() {
781            let err = recon_error(dev, &coord, x, amplitude, &mut scratch);
782            if best.as_ref().map(|(_, _, e)| err < *e).unwrap_or(true) {
783                best = Some((coord, cert, err));
784            }
785            // Mirror production certified_encode_row's global-minimum short-circuit
786            // (encode.rs): reconstruction error ≥ 0, so a certified candidate already
787            // at the ambient noise floor is provably the global optimum — stop
788            // refining the remaining top-K charts (keeps device parity with the fold).
789            if let Some((_, _, e)) = best.as_ref() {
790                let xnorm = x.iter().map(|v| v * v).sum::<f64>().sqrt();
791                if *e <= crate::encode::CERTIFIED_GLOBAL_MIN_RECON_FLOOR * (1.0 + xnorm) {
792                    break;
793                }
794            }
795        }
796    }
797    match best {
798        Some((coord, cert, _)) => DeviceEncodeRow { coord, cert },
799        None => {
800            let (coord, cert) = nearest_fallback
801                .unwrap_or_else(|| (vec![0.0; d], DeviceRowCertificate::uncertified_inf()));
802            DeviceEncodeRow { coord, cert }
803        }
804    }
805}
806
807/// Batched device-free encode over many rows (the CPU fallback of
808/// [`sae_certified_encode_batch`]). Row-independent, so order-stable.
809#[must_use]
810pub fn emulate_certified_encode_batch(
811    dev: &EncodeAtomDevice,
812    targets: &[Vec<f64>],
813    amplitudes: &[f64],
814) -> Vec<DeviceEncodeRow> {
815    targets
816        .iter()
817        .zip(amplitudes.iter())
818        .map(|(x, &amp)| emulate_certified_encode_row(dev, x, amp))
819        .collect()
820}
821
822// ============================================================================
823// Device kernel source (NVRTC). Faithful port of the numeric core above; one
824// block per row, the block's lead thread runs the whole row's encode serially
// (order-identical to the emulator). Compile-time #defines: DD/MM/PP/TOPK/NEWTON
// plus the float constants RIDGE/GMIN_FLOOR/REFINE_EPS.
826// ============================================================================
827
/// The NVRTC source template for the `sae_certified_encode` kernel.
///
/// The template does not compile on its own. [`encode_kernel_source`]
/// prepends the `#define`s it relies on:
/// * `DD` — latent coordinate dimension;
/// * `MM` — number of basis columns;
/// * `PP` — output dimension;
/// * `TOPK`, `NEWTON`, `RIDGE`, `GMIN_FLOOR` and `REFINE_EPS`.
///
/// Prepending them keeps the compile a pure `compile_ptx` invocation, as in
/// the sibling kernels. Only `KANTOROVICH` (0.5, the certificate threshold on
/// `h`) is defined inside the template itself.
///
/// Everything is full f64 with no fast-math: the encode arithmetic mirrors the
/// CPU `EncodeAtomDevice` core term-for-term. The launch contract is one block
/// per row, with only `threadIdx.x == 0` doing work.
pub const ENCODE_KERNEL_SOURCE: &str = r#"
#define KANTOROVICH 0.5

__device__ __forceinline__ double dpow(double b, int e){
  // exponentiation-by-squaring, matching llvm.powi/f64::powi and the emulator dpow.
  if (e == 0) return 1.0;
  int n = e < 0 ? -e : e;
  double r = 1.0, base = b;
  while (n > 0){ if (n & 1) r *= base; n >>= 1; if (n) base *= base; }
  return e < 0 ? 1.0 / r : r;
}

// Monomial phi/jet/hess at t (mirror of eval_basis).
__device__ void eval_basis(const int* exps, const double* t,
                           double* phi, double* jet, double* hess){
  for (int col=0; col<MM; ++col){
    double value = 1.0;
    for (int axis=0; axis<DD; ++axis){ int e=exps[col*DD+axis]; if(e!=0) value*=dpow(t[axis],e); }
    phi[col]=value;
    for (int axis=0; axis<DD; ++axis){
      int a_axis=exps[col*DD+axis]; double jval=0.0;
      if (a_axis!=0){ jval=(double)a_axis;
        for(int a=0;a<DD;++a){ int ea=(a==axis)?a_axis-1:exps[col*DD+a]; if(ea!=0) jval*=dpow(t[a],ea); } }
      jet[col*DD+axis]=jval;
    }
    for (int a=0;a<DD;++a) for(int c=0;c<DD;++c){
      double hval=0.0; int aa=exps[col*DD+a]; int ac=exps[col*DD+c];
      int adm = (aa!=0) && (a==c || ac!=0);
      if (adm){
        double lead = (a==c) ? (double)aa*(double)((aa-1)>0?(aa-1):0)
                             : (double)aa*(double)ac;
        if (lead!=0.0){ hval=lead;
          for(int axis=0;axis<DD;++axis){ int e=exps[col*DD+axis];
            if(axis==a) e=(e-1)>0?(e-1):0; if(axis==c) e=(e-1)>0?(e-1):0;
            if(e!=0) hval*=dpow(t[axis],e); } }
      }
      hess[(col*DD+a)*DD+c]=hval;
    }
  }
}

__device__ void recon_amp1(const double* dec, const double* phi, double* out){
  for(int c=0;c<PP;++c) out[c]=0.0;
  for(int b=0;b<MM;++b){ double pv=phi[b]; if(pv==0.0) continue;
    for(int c=0;c<PP;++c) out[c]+=pv*dec[b*PP+c]; }
}

// grad g[D] and full Hessian h[D*D] (+ridge). Mirror of encode_grad_hess.
__device__ void grad_hess(const double* dec, const double* t, const double* x, double amp,
                          const double* phi, const double* jet, const double* hess,
                          double* g, double* h){
  double recon[PP]; double residual[PP]; double jm[DD*PP];
  for(int c=0;c<PP;++c) recon[c]=0.0;
  for(int b=0;b<MM;++b){ double pv=phi[b]; if(pv==0.0) continue;
    for(int c=0;c<PP;++c) recon[c]+=amp*pv*dec[b*PP+c]; }
  for(int c=0;c<PP;++c) residual[c]=recon[c]-x[c];
  for(int i=0;i<DD*PP;++i) jm[i]=0.0;
  for(int axis=0;axis<DD;++axis) for(int b=0;b<MM;++b){ double dphi=jet[b*DD+axis]; if(dphi==0.0) continue;
    for(int c=0;c<PP;++c) jm[axis*PP+c]+=amp*dphi*dec[b*PP+c]; }
  for(int a=0;a<DD;++a){
    double ga=0.0; for(int c=0;c<PP;++c) ga+=jm[a*PP+c]*residual[c]; g[a]=ga;
    for(int b=0;b<DD;++b){
      double hab=0.0; for(int c=0;c<PP;++c) hab+=jm[a*PP+c]*jm[b*PP+c];
      double curv=0.0;
      for(int basis=0;basis<MM;++basis){ double d2=hess[(basis*DD+a)*DD+b]; if(d2==0.0) continue;
        double dot=0.0; for(int c=0;c<PP;++c) dot+=residual[c]*dec[basis*PP+c];
        curv+=amp*d2*dot; }
      h[a*DD+b]=hab+curv;
    }
  }
  for(int a=0;a<DD;++a) h[a*DD+a]+=RIDGE;
}

// Cyclic Jacobi eigensolver (mirror of jacobi_eigh); vecs columns: vecs[col*D+row].
__device__ void jacobi_eigh(const double* a_in, double* vals, double* vecs){
  double a[DD*DD];
  for(int i=0;i<DD*DD;++i) a[i]=a_in[i];
  for(int r=0;r<DD;++r) for(int c=0;c<DD;++c) vecs[c*DD+r]=(r==c)?1.0:0.0;
  if (DD==1){ vals[0]=a[0]; return; }
  for(int sweep=0;sweep<30;++sweep){
    double off=0.0;
    for(int r=0;r<DD;++r) for(int c=r+1;c<DD;++c) off+=a[r*DD+c]*a[r*DD+c];
    if (off<=1e-300) break;
    for(int p=0;p<DD;++p) for(int q=p+1;q<DD;++q){
      double apq=a[p*DD+q]; if(apq==0.0) continue;
      double app=a[p*DD+p]; double aqq=a[q*DD+q];
      double tau=(aqq-app)/(2.0*apq);
      double t = (tau>=0.0) ? 1.0/(tau+sqrt(1.0+tau*tau)) : -1.0/(-tau+sqrt(1.0+tau*tau));
      double cph=1.0/sqrt(1.0+t*t); double sph=t*cph;
      for(int k=0;k<DD;++k){ double akp=a[k*DD+p]; double akq=a[k*DD+q];
        a[k*DD+p]=cph*akp-sph*akq; a[k*DD+q]=sph*akp+cph*akq; }
      for(int k=0;k<DD;++k){ double apk=a[p*DD+k]; double aqk=a[q*DD+k];
        a[p*DD+k]=cph*apk-sph*aqk; a[q*DD+k]=sph*apk+cph*aqk; }
      for(int k=0;k<DD;++k){ double vkp=vecs[p*DD+k]; double vkq=vecs[q*DD+k];
        vecs[p*DD+k]=cph*vkp-sph*vkq; vecs[q*DD+k]=sph*vkp+cph*vkq; }
    }
  }
  for(int i=0;i<DD;++i) vals[i]=a[i*DD+i];
}

// beta/eta/delta; returns 1 on success (lambda_min>0), 0 otherwise.
__device__ int beta_eta_newton(const double* h, const double* g,
                               double* beta, double* eta, double* delta){
  double vals[DD]; double vecs[DD*DD];
  jacobi_eigh(h, vals, vecs);
  double lmin=1.0/0.0; // +inf
  for(int i=0;i<DD;++i) if(vals[i]<lmin) lmin=vals[i];
  if (!(isfinite(lmin) && lmin>0.0)) return 0;
  *beta=1.0/lmin;
  for(int i=0;i<DD;++i) delta[i]=0.0;
  for(int col=0;col<DD;++col){ double lam=vals[col]; if(lam<=0.0) return 0;
    double vg=0.0; for(int row=0;row<DD;++row) vg+=vecs[col*DD+row]*g[row];
    double coeff=vg/lam; for(int row=0;row<DD;++row) delta[row]-=coeff*vecs[col*DD+row]; }
  double e2=0.0; for(int i=0;i<DD;++i) e2+=delta[i]*delta[i]; *eta=sqrt(e2);
  return 1;
}

// row_certificate: writes h_out (=beta*eta*L or +inf) and delta; returns certified 0/1 mask via h.
__device__ void row_certificate(const int* exps, const double* dec,
                                const double* t, const double* x, double amp, double L,
                                double* h_out, double* beta_out, double* eta_out, double* delta){
  double phi[MM]; double jet[MM*DD]; double hess[MM*DD*DD]; double g[DD]; double H[DD*DD];
  eval_basis(exps, t, phi, jet, hess);
  grad_hess(dec, t, x, amp, phi, jet, hess, g, H);
  double beta, eta;
  if (beta_eta_newton(H, g, &beta, &eta, delta)){
    *beta_out=beta; *eta_out=eta; *h_out=beta*eta*L;
  } else {
    *beta_out=1.0/0.0; *eta_out=1.0/0.0; *h_out=1.0/0.0;
    for(int i=0;i<DD;++i) delta[i]=0.0;
  }
}

__device__ int in_chart(const double* t, const double* center, double radius){
  double r2=0.0; for(int i=0;i<DD;++i){ double d=t[i]-center[i]; r2+=d*d; }
  return r2 <= radius*radius;
}

// certify_with_basin_warmup + refine. Returns 1 with coord/landing_h on success.
__device__ int certify_basin(const int* exps, const double* dec,
                             const double* t_start, const double* x, double amp,
                             const double* center, double radius, double L,
                             double* coord_out, double* landing_h){
  double t[DD]; for(int i=0;i<DD;++i) t[i]=t_start[i];
  if(!in_chart(t, center, radius)) return 0;
  double h, beta, eta; double delta[DD];
  row_certificate(exps, dec, t, x, amp, L, &h, &beta, &eta, delta);
  while(!(isfinite(h) && h<=KANTOROVICH)){
    if(!(isfinite(h) && isfinite(beta) && isfinite(eta))) return 0;
    double prev_h=h;
    double next[DD]; for(int i=0;i<DD;++i) next[i]=t[i]+delta[i];
    if(!in_chart(next, center, radius)) return 0;
    for(int i=0;i<DD;++i) t[i]=next[i];
    row_certificate(exps, dec, t, x, amp, L, &h, &beta, &eta, delta);
    if(!(isfinite(h)) || h>=prev_h) return 0;
  }
  double landing = h;
  for(int s=0;s<NEWTON;++s){
    // convergence early-exit (mirror production refine_certified_start).
    double dnorm=0.0, tnorm=0.0;
    for(int i=0;i<DD;++i){ dnorm+=delta[i]*delta[i]; tnorm+=t[i]*t[i]; }
    if(sqrt(dnorm) <= REFINE_EPS*(1.0+sqrt(tnorm))) break;
    for(int i=0;i<DD;++i) t[i]+=delta[i];
    row_certificate(exps, dec, t, x, amp, L, &h, &beta, &eta, delta);
    if(!(isfinite(h) && h<=KANTOROVICH)) return 0;
  }
  for(int i=0;i<DD;++i) coord_out[i]=t[i];
  *landing_h=landing;
  return 1;
}

// One block per row. Charts are stored flattened; the block's lead thread runs
// the full route -> warm-start -> certify -> assign pipeline serially.
extern "C" __global__ void sae_certified_encode(
    const int*    __restrict__ exps,           // MM*DD
    const double* __restrict__ dec,            // MM*PP
    const double* __restrict__ centers,        // n_charts*DD
    const double* __restrict__ radii,          // n_charts
    const double* __restrict__ cert_radii,     // n_charts
    const double* __restrict__ lips,           // n_charts
    const int*    __restrict__ has_jac,        // n_charts
    const double* __restrict__ a1,             // n_charts*DD*PP
    const double* __restrict__ recon_c,        // n_charts*PP
    int n_charts,
    const double* __restrict__ targets,        // n*PP
    const double* __restrict__ amps,           // n
    int n,
    double* __restrict__ coords_out,           // n*DD
    double* __restrict__ h_out,                // n   (certificate h; >0.5 or inf = uncertified)
    int*    __restrict__ certified_out)        // n   (1/0)
{
  int row = blockIdx.x;
  if (row >= n) return;
  if (threadIdx.x != 0) return;
  const double* x = targets + (size_t)row*PP;
  double amp = amps[row];

  // ---- routing: top-TOPK certifiable charts by center recon distance. ----
  int cand[TOPK]; double cand_d[TOPK]; int ncand=0;
  {
    double phi[MM]; double jet[MM*DD]; double hess[MM*DD*DD]; double recon[PP];
    for(int idx=0; idx<n_charts; ++idx){
      if (cert_radii[idx] <= 0.0) continue;
      eval_basis(exps, centers + (size_t)idx*DD, phi, jet, hess);
      recon_amp1(dec, phi, recon);
      double dist=0.0; for(int c=0;c<PP;++c){ double df=recon[c]-x[c]; dist+=df*df; }
      // insert into the sorted top-TOPK by (dist, idx).
      int pos=ncand;
      while(pos>0 && (cand_d[pos-1]>dist)){ if(pos<TOPK){cand_d[pos]=cand_d[pos-1]; cand[pos]=cand[pos-1];} pos--; }
      if(pos<TOPK){ cand_d[pos]=dist; cand[pos]=idx; if(ncand<TOPK) ncand++; }
    }
  }
  // defaults: uncertified.
  for(int i=0;i<DD;++i) coords_out[(size_t)row*DD+i]=0.0;
  h_out[row]=1.0/0.0; certified_out[row]=0;
  if(ncand==0) return;

  int have_fallback=0; double fb_coord[DD]; double fb_h; int fb_cert;
  int have_best=0; double best_coord[DD]; double best_h; double best_err=1.0/0.0;

  for(int ci=0; ci<ncand; ++ci){
    int idx=cand[ci];
    const double* center = centers + (size_t)idx*DD;
    double radius=radii[idx]; double L=lips[idx];
    // amortized_warm_start.
    int ok_ws = has_jac[idx] && isfinite(amp) && (amp!=0.0);
    double t_hat[DD]; int produced=0; double coord[DD]; double landing_h; int cert=0;
    if(ok_ws){
      const double* A1 = a1 + (size_t)idx*DD*PP;
      const double* m1 = recon_c + (size_t)idx*PP;
      for(int i=0;i<DD;++i) t_hat[i]=center[i];
      for(int out=0; out<PP; ++out){ double resid=x[out]-amp*m1[out];
        for(int axis=0;axis<DD;++axis) t_hat[axis]+=A1[axis*PP+out]*resid/amp; }
      if(certify_basin(exps, dec, t_hat, x, amp, center, radius, L, coord, &landing_h)){
        produced=1; cert=(isfinite(landing_h) && landing_h<=KANTOROVICH);
      } else { produced=1; for(int i=0;i<DD;++i) coord[i]=0.0; landing_h=1.0/0.0; cert=0; }
    }
    if(!ok_ws){
      // warm start declined: fallback candidate = zeros, uncertified.
      if(!have_fallback){ have_fallback=1; for(int i=0;i<DD;++i) fb_coord[i]=0.0; fb_h=1.0/0.0; fb_cert=0; }
      continue;
    }
    if(!have_fallback){ have_fallback=1; for(int i=0;i<DD;++i) fb_coord[i]=coord[i]; fb_h=landing_h; fb_cert=cert; }
    if(cert){
      // reconstruction error at coord.
      double phi[MM]; double jet[MM*DD]; double hess[MM*DD*DD]; double recon[PP];
      eval_basis(exps, coord, phi, jet, hess); recon_amp1(dec, phi, recon);
      double e2=0.0; for(int c=0;c<PP;++c){ double r=x[c]-amp*recon[c]; e2+=r*r; }
      double err = isfinite(e2)? sqrt(e2) : 1.0/0.0;
      if(!have_best || err<best_err){ have_best=1; best_err=err; best_h=landing_h; for(int i=0;i<DD;++i) best_coord[i]=coord[i]; }
      // global-min short-circuit (mirror production certified_encode_row).
      double xnorm2=0.0; for(int c=0;c<PP;++c) xnorm2+=x[c]*x[c];
      if(best_err <= GMIN_FLOOR*(1.0+sqrt(xnorm2))) break;
    }
    (void)produced;
  }
  if(have_best){
    for(int i=0;i<DD;++i) coords_out[(size_t)row*DD+i]=best_coord[i];
    h_out[row]=best_h; certified_out[row]=1;
  } else if(have_fallback){
    for(int i=0;i<DD;++i) coords_out[(size_t)row*DD+i]=fb_coord[i];
    h_out[row]=fb_h; certified_out[row]=fb_cert;
  }
}
"#;
1097
1098/// Build the full NVRTC source for one `(d, m, p, topk, newton, ridge)`
1099/// instantiation, prepending the `#define`s so the compile is a pure
1100/// `compile_ptx_arch` matching `sae_rowjet` / `arrow_schur_nvrtc`.
1101#[cfg(target_os = "linux")]
1102#[must_use]
1103pub fn encode_kernel_source(dev: &EncodeAtomDevice) -> String {
1104    format!(
1105        "#define DD {}\n#define MM {}\n#define PP {}\n#define TOPK {}\n#define NEWTON {}\n\
1106         #define RIDGE ({:e})\n#define GMIN_FLOOR ({:e})\n#define REFINE_EPS ({:e})\n\
1107         {ENCODE_KERNEL_SOURCE}",
1108        dev.d,
1109        dev.m,
1110        dev.p,
1111        dev.topk,
1112        dev.newton_steps,
1113        dev.ridge,
1114        crate::encode::CERTIFIED_GLOBAL_MIN_RECON_FLOOR,
1115        crate::encode::NEWTON_REFINE_CONVERGED_EPS
1116    )
1117}
1118
/// Records which engine actually produced a batch encode — the #1026/#1551
/// honesty flag. Callers can assert that the device engaged instead of having
/// a silent CPU fallback pass unnoticed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EncodePath {
    /// The NVRTC `sae_certified_encode` kernel was compiled and executed on a
    /// CUDA device.
    Device,
    /// The host `EncodeAtomDevice` emulator did the work. This happens on
    /// non-Linux hosts, without a CUDA runtime, or when the batch is below the
    /// launch break-even.
    Cpu,
}
1129
/// Smallest batch (in rows) for which `sae_certified_encode_batch` attempts the
/// device launch; smaller batches go straight to the CPU emulator because the
/// launch's fixed cost would dominate.
pub const DEVICE_ROW_THRESHOLD: usize = 1 << 12;
1132
1133/// Batched certified encode over many rows, on the GPU when a CUDA device is
1134/// admitted and the batch amortises the launch, else on the CPU emulator. The
1135/// returned [`EncodePath`] reports which path ran honestly (`device_encode_engaged`).
1136/// Both paths run the SAME numeric core (Jacobi eigensolve, monomial jets), so
1137/// when the device runs its result matches the CPU oracle to eigen round-off.
1138#[must_use]
1139pub fn sae_certified_encode_batch(
1140    dev: &EncodeAtomDevice,
1141    targets: &[Vec<f64>],
1142    amplitudes: &[f64],
1143) -> (Vec<DeviceEncodeRow>, EncodePath) {
1144    #[cfg(target_os = "linux")]
1145    {
1146        if targets.len() >= DEVICE_ROW_THRESHOLD {
1147            if let Ok(out) = device::sae_certified_encode_device(dev, targets, amplitudes) {
1148                return (out, EncodePath::Device);
1149            }
1150            // Fall through to CPU on any device error (accelerator, not oracle).
1151        }
1152    }
1153    (
1154        emulate_certified_encode_batch(dev, targets, amplitudes),
1155        EncodePath::Cpu,
1156    )
1157}
1158
1159/// Measured throughput of the device-resident **exact per-row certified encode**
1160/// ([`sae_certified_encode_batch`]) — the literal "batched exact per-row GPU
1161/// encode" of #988, timed end to end (routing + amortized warm start + basin
1162/// Newton + Kantorovich certificate + lowest-error assignment/fallback), NOT a
1163/// component solve like [`gam_gpu::encode_throughput::measure_resident_solve_throughput`]
1164/// (which times only the resident normal-equations inner cell).
1165///
1166/// The point of this struct is [`Self::decision`]: the #988 surrogate question
1167/// ("is the exact encode fast enough at 10⁹ rows, or must we distill a certified
1168/// amortized surrogate?") is answered by *this* measurement and only this one.
1169/// The decision is keyed on [`EncodeDeploymentDecision::from_device_measurement`]
1170/// with `engaged = (path == EncodePath::Device)`, so it inherits that type's
1171/// anti-green-wash contract: a CPU-emulator run (`path == Cpu`) can NEVER declare
1172/// the surrogate unneeded — it is honestly [`EncodeDeploymentDecision::Undetermined`]
1173/// (blocked on hardware), no matter how fast the CPU rate is. Only a real device
1174/// launch of the exact-encode kernel can move the decision to `Met`/`Unmet`.
1175#[derive(Debug, Clone, Copy)]
1176pub struct DeviceEncodeThroughput {
1177    /// Rows encoded in the timed batch.
1178    pub n_rows: usize,
1179    /// Wall-clock seconds for the full exact encode of the batch.
1180    pub encode_secs: f64,
1181    /// `n_rows / encode_secs` (`0.0` for a degenerate / non-positive time).
1182    pub rows_per_sec: f64,
1183    /// Which path actually ran the encode — the #1026/#1551 honesty flag.
1184    pub path: EncodePath,
1185    /// The #988 surrogate decision keyed on THIS exact-encode measurement.
1186    /// `Met`/`Unmet` only when `path == EncodePath::Device`; a CPU-emulator run
1187    /// is `Undetermined { NoDeviceEncodeKernel-adjacent }` — a fast CPU number is
1188    /// never a device pass.
1189    pub decision: EncodeDeploymentDecision,
1190}
1191
1192impl DeviceEncodeThroughput {
1193    /// `true` iff the exact-encode kernel actually ran on a CUDA device — the
1194    /// only state in which [`Self::decision`] is a real `Met`/`Unmet`.
1195    #[must_use]
1196    pub fn device_engaged(&self) -> bool {
1197        matches!(self.path, EncodePath::Device)
1198    }
1199}
1200
1201/// Benchmark the device-resident exact per-row encode over a batch and gate the
1202/// #988 certified-surrogate decision on the measured throughput.
1203///
1204/// Runs [`sae_certified_encode_batch`] once to warm allocations/compile/module
1205/// caches, then once more under a wall-clock timer, and reports the measured
1206/// rows/sec together with the honest [`EncodePath`] and the derived
1207/// [`EncodeDeploymentDecision`]:
1208///
1209/// * On a CUDA host with `targets.len() >= DEVICE_ROW_THRESHOLD` the batch runs
1210///   on the device (`path == Device`), the measurement is a genuine device rate,
1211///   and the decision is `Met` (≥ 100k rows/sec/GPU ⇒ ship the exact encode, no
1212///   surrogate) or `Unmet` (surrogate justified) by the number.
1213/// * On a CPU-only host (or below the launch threshold) the emulator runs
1214///   (`path == Cpu`); the rate is real but it is NOT a device measurement, so the
1215///   decision is `Undetermined` — the surrogate stays neither justified nor
1216///   refuted. This is the honest "needs GPU hardware" outcome.
1217#[must_use]
1218pub fn measure_device_encode_throughput(
1219    dev: &EncodeAtomDevice,
1220    targets: &[Vec<f64>],
1221    amplitudes: &[f64],
1222) -> DeviceEncodeThroughput {
1223    let n = targets.len();
1224    // Warm run (device module load / PTX cache / first-touch allocations) is not
1225    // timed, mirroring the resident-solve and full-path benchmarks.
1226    drop(sae_certified_encode_batch(dev, targets, amplitudes));
1227    let start = Instant::now();
1228    let (_out, path) = sae_certified_encode_batch(dev, targets, amplitudes);
1229    let elapsed = start.elapsed();
1230    let encode_secs = elapsed.as_secs_f64();
1231    let rows_per_sec = if n > 0 && encode_secs > 0.0 {
1232        n as f64 / encode_secs
1233    } else {
1234        0.0
1235    };
1236    let engaged = matches!(path, EncodePath::Device);
1237    // Key the surrogate decision on the measurement. When the device did not run
1238    // the exact encode, report the honest blocked reason rather than a
1239    // `DeviceNotEngaged` false-routing (the kernel exists, but no device ran it).
1240    let decision = if engaged {
1241        EncodeDeploymentDecision::from_device_measurement(true, rows_per_sec)
1242    } else {
1243        EncodeDeploymentDecision::blocked(EncodeDecisionBlocked::NoDevice)
1244    };
1245    DeviceEncodeThroughput {
1246        n_rows: n,
1247        encode_secs,
1248        rows_per_sec,
1249        path,
1250        decision,
1251    }
1252}
1253
1254#[cfg(target_os = "linux")]
1255mod device {
1256    use super::{
1257        DeviceEncodeRow, DeviceRowCertificate, EncodeAtomDevice, encode_kernel_source,
1258    };
1259    use gam_gpu::gpu_error::{GpuError, GpuResultExt};
1260    use std::collections::HashMap;
1261    use std::sync::{Arc, Mutex, OnceLock};
1262
1263    use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
1264
1265    struct Backend {
1266        ctx: Arc<CudaContext>,
1267        stream: Arc<CudaStream>,
1268        modules: Mutex<HashMap<String, Arc<CudaModule>>>,
1269    }
1270
1271    fn backend() -> Result<&'static Backend, GpuError> {
1272        static BACKEND: OnceLock<Result<Backend, GpuError>> = OnceLock::new();
1273        BACKEND
1274            .get_or_init(|| {
1275                let parts = gam_gpu::backend_probe::probe_cuda_backend("sae_encode")?;
1276                Ok(Backend {
1277                    ctx: parts.ctx,
1278                    stream: parts.stream,
1279                    modules: Mutex::new(HashMap::new()),
1280                })
1281            })
1282            .as_ref()
1283            .map_err(GpuError::clone)
1284    }
1285
1286    fn module_for(b: &Backend, dev: &EncodeAtomDevice) -> Result<Arc<CudaModule>, GpuError> {
1287        let key = format!(
1288            "{}-{}-{}-{}-{}-{:e}",
1289            dev.d, dev.m, dev.p, dev.topk, dev.newton_steps, dev.ridge
1290        );
1291        if let Ok(guard) = b.modules.lock() {
1292            if let Some(m) = guard.get(&key) {
1293                return Ok(m.clone());
1294            }
1295        }
1296        let src = encode_kernel_source(dev);
1297        let ptx = gam_gpu::device_cache::compile_ptx_arch(&src)
1298            .gpu_ctx_with(|err| format!("sae_encode NVRTC compile ({key}): {err}"))?;
1299        let module = b.ctx.load_module(ptx).gpu_ctx("sae_encode module load")?;
1300        if let Ok(mut guard) = b.modules.lock() {
1301            guard.entry(key).or_insert_with(|| module.clone());
1302        }
1303        Ok(module)
1304    }
1305
1306    /// Device implementation: flatten the atom + charts + rows, launch one block
1307    /// per row, download coords/certificate.
1308    pub(super) fn sae_certified_encode_device(
1309        dev: &EncodeAtomDevice,
1310        targets: &[Vec<f64>],
1311        amplitudes: &[f64],
1312    ) -> Result<Vec<DeviceEncodeRow>, GpuError> {
1313        let n = targets.len();
1314        let (d, p) = (dev.d, dev.p);
1315        if n == 0 {
1316            return Ok(Vec::new());
1317        }
1318        let b = backend()?;
1319        let module = module_for(b, dev)?;
1320        let func = module
1321            .load_function("sae_certified_encode")
1322            .gpu_ctx("sae_encode load_function")?;
1323        let stream = b.stream.clone();
1324        let n_charts = dev.charts.len();
1325
1326        // Flatten charts.
1327        let mut centers = vec![0.0_f64; n_charts * d];
1328        let mut radii = vec![0.0_f64; n_charts];
1329        let mut cert_radii = vec![0.0_f64; n_charts];
1330        let mut lips = vec![0.0_f64; n_charts];
1331        let mut has_jac = vec![0_i32; n_charts];
1332        let mut a1 = vec![0.0_f64; n_charts * d * p];
1333        let mut recon_c = vec![0.0_f64; n_charts * p];
1334        for (i, ch) in dev.charts.iter().enumerate() {
1335            centers[i * d..(i + 1) * d].copy_from_slice(&ch.center);
1336            radii[i] = ch.radius;
1337            cert_radii[i] = ch.certified_radius;
1338            lips[i] = ch.lipschitz;
1339            has_jac[i] = i32::from(ch.has_jacobian);
1340            if ch.has_jacobian {
1341                a1[i * d * p..(i + 1) * d * p].copy_from_slice(&ch.amortized_jacobian);
1342            }
1343            recon_c[i * p..(i + 1) * p].copy_from_slice(&ch.recon_center);
1344        }
1345        let mut tgt = vec![0.0_f64; n * p];
1346        for (i, x) in targets.iter().enumerate() {
1347            tgt[i * p..(i + 1) * p].copy_from_slice(x);
1348        }
1349
1350        let exps_dev = stream.clone_htod(&dev.exponents).gpu_ctx("sae_encode htod exps")?;
1351        let dec_dev = stream.clone_htod(&dev.decoder).gpu_ctx("sae_encode htod dec")?;
1352        let centers_dev = stream.clone_htod(&centers).gpu_ctx("sae_encode htod centers")?;
1353        let radii_dev = stream.clone_htod(&radii).gpu_ctx("sae_encode htod radii")?;
1354        let cert_dev = stream.clone_htod(&cert_radii).gpu_ctx("sae_encode htod cert_radii")?;
1355        let lips_dev = stream.clone_htod(&lips).gpu_ctx("sae_encode htod lips")?;
1356        let hasj_dev = stream.clone_htod(&has_jac).gpu_ctx("sae_encode htod has_jac")?;
1357        let a1_dev = stream.clone_htod(&a1).gpu_ctx("sae_encode htod a1")?;
1358        let reconc_dev = stream.clone_htod(&recon_c).gpu_ctx("sae_encode htod recon_c")?;
1359        let tgt_dev = stream.clone_htod(&tgt).gpu_ctx("sae_encode htod targets")?;
1360        let amps_dev = stream.clone_htod(&amplitudes.to_vec()).gpu_ctx("sae_encode htod amps")?;
1361        let mut coords_dev = stream.alloc_zeros::<f64>(n * d).gpu_ctx("sae_encode alloc coords")?;
1362        let mut h_dev = stream.alloc_zeros::<f64>(n).gpu_ctx("sae_encode alloc h")?;
1363        let mut cert_out_dev = stream.alloc_zeros::<i32>(n).gpu_ctx("sae_encode alloc certified")?;
1364
1365        let n_i32 = i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sae_encode n overflow"))?;
1366        let ncharts_i32 =
1367            i32::try_from(n_charts).map_err(|_| gam_gpu::gpu_err!("sae_encode n_charts overflow"))?;
1368        let cfg = LaunchConfig {
1369            grid_dim: (n_i32 as u32, 1, 1),
1370            block_dim: (32, 1, 1),
1371            shared_mem_bytes: 0,
1372        };
1373        let mut builder = stream.launch_builder(&func);
1374        builder
1375            .arg(&exps_dev)
1376            .arg(&dec_dev)
1377            .arg(&centers_dev)
1378            .arg(&radii_dev)
1379            .arg(&cert_dev)
1380            .arg(&lips_dev)
1381            .arg(&hasj_dev)
1382            .arg(&a1_dev)
1383            .arg(&reconc_dev)
1384            .arg(&ncharts_i32)
1385            .arg(&tgt_dev)
1386            .arg(&amps_dev)
1387            .arg(&n_i32)
1388            .arg(&mut coords_dev)
1389            .arg(&mut h_dev)
1390            .arg(&mut cert_out_dev);
1391        // SAFETY: grid/block validated; all pointers are cudarc-checked allocations
1392        // on this stream; the kernel reads within the flattened inputs and writes
1393        // coords[0..n*d], h[0..n], certified[0..n].
1394        unsafe { builder.launch(cfg) }.gpu_ctx("sae_encode kernel launch")?;
1395
1396        let mut coords = vec![0.0_f64; n * d];
1397        let mut h = vec![0.0_f64; n];
1398        let mut cert = vec![0_i32; n];
1399        stream.memcpy_dtoh(&coords_dev, &mut coords).gpu_ctx("sae_encode dtoh coords")?;
1400        stream.memcpy_dtoh(&h_dev, &mut h).gpu_ctx("sae_encode dtoh h")?;
1401        stream.memcpy_dtoh(&cert_out_dev, &mut cert).gpu_ctx("sae_encode dtoh certified")?;
1402        stream.synchronize().gpu_ctx("sae_encode synchronize")?;
1403
1404        let mut out = Vec::with_capacity(n);
1405        for row in 0..n {
1406            let coord = coords[row * d..(row + 1) * d].to_vec();
1407            let hv = h[row];
1408            out.push(DeviceEncodeRow {
1409                coord,
1410                cert: DeviceRowCertificate {
1411                    // beta/eta not transported; the h + certified flag is the contract.
1412                    beta: f64::NAN,
1413                    eta: f64::NAN,
1414                    lipschitz: f64::NAN,
1415                    h: hv,
1416                },
1417            });
1418        }
1419        // Reconcile the certified flag from the device (authoritative) — the h
1420        // value alone can be +inf for an uncertified fallback.
1421        for (row, o) in out.iter_mut().enumerate() {
1422            if cert[row] == 0 {
1423                o.cert.h = f64::INFINITY;
1424            }
1425        }
1426        Ok(out)
1427    }
1428}
1429
#[cfg(test)]
mod tests {
    use super::*;
    use crate::basis::{EuclideanPatchEvaluator, SaeBasisEvaluator};
    use crate::encode::{AtlasConfig, EncodeAtlas};
    use crate::manifold::{SaeAtomBasisKind, SaeManifoldAtom};
    use ndarray::{Array1, Array2};
    use std::sync::Arc;

    /// Build a degree-`deg`, `d`-D `EuclideanPatch` atom with a deterministic
    /// decoder into `p` outputs, plus a matching `EncodeAtlas`. The atom carries
    /// the closed-form second jet, exactly the production certified-encode setup.
    fn build_atom_and_atlas(
        d: usize,
        deg: usize,
        p: usize,
        config: AtlasConfig,
    ) -> (SaeManifoldAtom, EncodeAtlas) {
        let evaluator = Arc::new(EuclideanPatchEvaluator::new(d, deg).unwrap());
        // Seed rows over a small coordinate grid (only used for the atom's stored
        // basis_values; the encode recomputes jets from the evaluator).
        let n_seed = 12usize;
        let coords = Array2::from_shape_fn((n_seed, d), |(r, c)| {
            0.15 * ((r as f64 + 1.0) * (c as f64 + 2.0) * 0.37).sin()
        });
        let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
        let m = phi.ncols();
        // Deterministic decoder B (m x p): a smooth, well-conditioned map.
        let decoder = Array2::from_shape_fn((m, p), |(bidx, c)| {
            (1.0 / (1.0 + bidx as f64)) * (((bidx as f64 + 1.0) * (c as f64 + 1.0)) * 0.3).cos()
        });
        let atom = SaeManifoldAtom::new(
            "euclid",
            SaeAtomBasisKind::EuclideanPatch,
            d,
            phi,
            jet,
            decoder,
            Array2::<f64>::eye(m),
        )
        .unwrap()
        .with_basis_second_jet(evaluator);
        // Amplitude / target-norm bounds generous enough to certify.
        let atlas = EncodeAtlas::build(&[atom.clone()], &[2.0], 8.0, config).unwrap();
        (atom, atlas)
    }

    /// Build a planted on-manifold target row `x = amp · BᵀΦ(coord)` for the
    /// atom described by `dev`.
    ///
    /// `coord` holds one latent coordinate per axis (its length is the patch
    /// dimension). The decoder is read row-major as `dev.decoder[b * dev.p + c]`.
    /// The per-output sum runs over basis index `b = 0..dev.m` in order, the same
    /// accumulation order the tests previously spelled out inline.
    fn planted_row(
        evaluator: &EuclideanPatchEvaluator,
        dev: &EncodeAtomDevice,
        coord: &[f64],
        amp: f64,
    ) -> Vec<f64> {
        let coords = Array2::from_shape_fn((1, coord.len()), |(_, c)| coord[c]);
        let (phi, _) = evaluator
            .evaluate(coords.view())
            .expect("EuclideanPatch basis evaluates at a planted coordinate");
        (0..dev.p)
            .map(|c| {
                let mut r = 0.0;
                for b in 0..dev.m {
                    r += phi[[0, b]] * dev.decoder[b * dev.p + c];
                }
                amp * r
            })
            .collect()
    }

    /// Compare the emulator against the production `certified_encode_row` on a
    /// set of rows.
    ///
    /// Asserts that `rows` and `amps` have equal length, and that every row's
    /// certificate flag matches between production and emulator. For certified
    /// rows it records the max absolute coord and `h` differences.
    ///
    /// Returns `(certified, total, max_coord_diff, max_h_diff)`. The tolerance
    /// check on the diffs is left to the caller, since the Jacobi-vs-LAPACK eigen
    /// tolerance depends on the problem size.
    fn assert_parity(
        atom: &SaeManifoldAtom,
        atlas: &EncodeAtlas,
        dev: &EncodeAtomDevice,
        rows: &[Vec<f64>],
        amps: &[f64],
    ) -> (usize, usize, f64, f64) {
        // `zip` would silently drop unmatched rows and weaken the parity check.
        assert_eq!(rows.len(), amps.len(), "assert_parity needs one amplitude per row");
        let mut certified = 0usize;
        let mut max_coord = 0.0_f64;
        let mut max_h = 0.0_f64;
        for (x, &amp) in rows.iter().zip(amps.iter()) {
            let xv = Array1::from(x.clone());
            let (coord_p, cert_p) = atlas
                .certified_encode_row(atom, 0, xv.view(), amp)
                .expect("production encode runs");
            let emu = emulate_certified_encode_row(dev, x, amp);
            assert_eq!(
                cert_p.certified(),
                emu.cert.certified(),
                "certificate flag mismatch (prod h={}, emu h={})",
                cert_p.h,
                emu.cert.h
            );
            if cert_p.certified() {
                certified += 1;
                for axis in 0..dev.d {
                    max_coord = max_coord.max((coord_p[axis] - emu.coord[axis]).abs());
                }
                max_h = max_h.max((cert_p.h - emu.cert.h).abs());
            }
        }
        (certified, rows.len(), max_coord, max_h)
    }

    #[test]
    fn emulator_matches_production_certified_encode_1d_quadratic() {
        let (d, deg, p) = (1usize, 2usize, 4usize);
        let config = AtlasConfig::default();
        let (atom, atlas) = build_atom_and_atlas(d, deg, p, config);
        let atom_atlas = &atlas.atoms[0];
        let dev = EncodeAtomDevice::from_atom_atlas(&atom, atom_atlas, &config).unwrap();
        let evaluator = EuclideanPatchEvaluator::new(d, deg).unwrap();
        let mut rows: Vec<Vec<f64>> = Vec::new();
        let mut amps: Vec<f64> = Vec::new();
        // Planted rows: exact reconstructions at known coords (on-manifold), so
        // the encode has a genuine certified basin.
        for k in 0..24 {
            let tc = -0.4 + 0.8 * (k as f64) / 23.0;
            let amp = 1.0;
            rows.push(planted_row(&evaluator, &dev, &[tc], amp));
            amps.push(amp);
        }
        // Random (off-manifold) rows exercise the fallback / uncertified paths.
        for k in 0..24 {
            let x = (0..p)
                .map(|c| 0.5 * (((k * 7 + c * 3) as f64) * 0.21).sin())
                .collect();
            rows.push(x);
            amps.push(0.7 + 0.3 * ((k as f64) * 0.11).cos());
        }
        let (cert, total, max_coord, max_h) = assert_parity(&atom, &atlas, &dev, &rows, &amps);
        eprintln!(
            "1D quadratic: certified {cert}/{total}, max coord diff {max_coord:.3e}, max h diff {max_h:.3e}"
        );
        assert!(cert > 0, "planted rows must certify through the encode");
        assert!(max_coord <= 1e-7, "coord parity {max_coord:.3e} > 1e-7");
        assert!(max_h <= 1e-7, "certificate h parity {max_h:.3e} > 1e-7");
    }

    #[test]
    fn emulator_matches_production_certified_encode_2d_quadratic() {
        let (d, deg, p) = (2usize, 2usize, 5usize);
        let config = AtlasConfig {
            grid_resolution: 6,
            ..AtlasConfig::default()
        };
        let (atom, atlas) = build_atom_and_atlas(d, deg, p, config);
        let atom_atlas = &atlas.atoms[0];
        let dev = EncodeAtomDevice::from_atom_atlas(&atom, atom_atlas, &config).unwrap();
        let evaluator = EuclideanPatchEvaluator::new(d, deg).unwrap();
        let mut rows: Vec<Vec<f64>> = Vec::new();
        let mut amps: Vec<f64> = Vec::new();
        // Planted rows on a 6 x 5 coordinate grid over [-0.3, 0.3]².
        for k in 0..30 {
            let t0 = -0.3 + 0.6 * ((k % 6) as f64) / 5.0;
            let t1 = -0.3 + 0.6 * ((k / 6) as f64) / 5.0;
            let amp = 1.0;
            rows.push(planted_row(&evaluator, &dev, &[t0, t1], amp));
            amps.push(amp);
        }
        // Off-manifold rows.
        for k in 0..20 {
            let x = (0..p)
                .map(|c| 0.4 * (((k * 5 + c * 2) as f64) * 0.17).cos())
                .collect();
            rows.push(x);
            amps.push(1.0);
        }
        let (cert, total, max_coord, max_h) = assert_parity(&atom, &atlas, &dev, &rows, &amps);
        eprintln!(
            "2D quadratic: certified {cert}/{total}, max coord diff {max_coord:.3e}, max h diff {max_h:.3e}"
        );
        assert!(cert > 0, "planted 2D rows must certify");
        assert!(max_coord <= 1e-6, "coord parity {max_coord:.3e} > 1e-6");
        assert!(max_h <= 1e-6, "certificate h parity {max_h:.3e} > 1e-6");
    }

    #[test]
    fn emulator_matches_production_batch() {
        let (d, deg, p) = (1usize, 3usize, 3usize);
        let config = AtlasConfig::default();
        let (atom, atlas) = build_atom_and_atlas(d, deg, p, config);
        let dev = EncodeAtomDevice::from_atom_atlas(&atom, &atlas.atoms[0], &config).unwrap();
        let n = 40usize;
        let rows: Vec<Vec<f64>> = (0..n)
            .map(|k| (0..p).map(|c| 0.3 * (((k + c) as f64) * 0.19).sin()).collect())
            .collect();
        let amps: Vec<f64> = vec![1.0; n];
        let (batch, path) = sae_certified_encode_batch(&dev, &rows, &amps);
        assert_eq!(path, EncodePath::Cpu, "small batch stays on CPU");
        // Without this, a short/empty batch would make the loop below vacuous.
        assert_eq!(batch.len(), n, "batch encode must return one result per row");
        // Batch flag == per-row emulate flag, and per-row == production flag.
        for (k, r) in batch.iter().enumerate() {
            let single = emulate_certified_encode_row(&dev, &rows[k], amps[k]);
            assert_eq!(r.cert.certified(), single.cert.certified());
            let xv = Array1::from(rows[k].clone());
            let (_, cert_p) = atlas
                .certified_encode_row(&atom, 0, xv.view(), amps[k])
                .unwrap();
            assert_eq!(
                cert_p.certified(),
                r.cert.certified(),
                "batch row {k} certificate flag disagrees with production"
            );
        }
    }

    /// #988 core: benchmark the batched EXACT per-row encode and gate the
    /// certified-surrogate decision on the MEASURED throughput of the actual
    /// device-resident encode kernel — not the host component solve, and not a
    /// hardcoded target. This is the "benchmark first; surrogate only on
    /// benchmark evidence" order-of-work wired end to end onto
    /// [`sae_certified_encode_batch`] (the literal batched exact per-row GPU
    /// encode) via [`measure_device_encode_throughput`].
    #[test]
    fn device_encode_throughput_gates_surrogate_on_measurement() {
        let (d, deg, p) = (1usize, 2usize, 4usize);
        let config = AtlasConfig::default();
        let (atom, atlas) = build_atom_and_atlas(d, deg, p, config);
        let dev = EncodeAtomDevice::from_atom_atlas(&atom, &atlas.atoms[0], &config).unwrap();

        // A batch large enough that a CUDA host would take the device path
        // (>= DEVICE_ROW_THRESHOLD), alternating planted on-manifold rows (even
        // k, expected to certify so the benchmark is non-vacuous) with
        // off-manifold rows (odd k, exercising the fallback).
        let n = DEVICE_ROW_THRESHOLD + 64;
        let evaluator = EuclideanPatchEvaluator::new(d, deg).unwrap();
        let mut rows: Vec<Vec<f64>> = Vec::with_capacity(n);
        let mut amps: Vec<f64> = Vec::with_capacity(n);
        for k in 0..n {
            if k % 2 == 0 {
                // Planted: exact amplitude-1 reconstruction at a known coordinate.
                let tc = -0.4 + 0.8 * ((k % 24) as f64) / 23.0;
                rows.push(planted_row(&evaluator, &dev, &[tc], 1.0));
            } else {
                let x = (0..p)
                    .map(|c| 0.5 * (((k * 7 + c * 3) as f64) * 0.021).sin())
                    .collect();
                rows.push(x);
            }
            amps.push(1.0);
        }

        // The benchmark: time the exact encode and derive the surrogate decision.
        let tput = measure_device_encode_throughput(&dev, &rows, &amps);
        eprintln!(
            "[device-encode #988] n={} rows/sec={:.1} path={:?} decision={:?}",
            tput.n_rows, tput.rows_per_sec, tput.path, tput.decision
        );

        // It must be a REAL measurement (positive rate), and the engagement flag
        // must be consistent with the path that ran.
        assert!(
            tput.rows_per_sec > 0.0,
            "the exact encode benchmark must produce a positive rows/sec, got {}",
            tput.rows_per_sec
        );
        assert_eq!(tput.device_engaged(), matches!(tput.path, EncodePath::Device));

        // The benchmark must be non-vacuous: at least one row must certify
        // through the exact encode, which shows the routing, basin Newton and
        // certificate actually ran rather than passing trivially.
        let (batch, _) = sae_certified_encode_batch(&dev, &rows, &amps);
        let certified = batch.iter().filter(|r| r.cert.certified()).count();
        assert!(
            certified > 0,
            "the exact encode certified no rows (planted on-manifold rows should certify); certified={certified}/{n}"
        );

        if tput.device_engaged() {
            // Only reachable on a CUDA host: the decision is a REAL Met/Unmet
            // keyed on the measured device throughput vs the 100k rows/sec target.
            assert!(
                !tput.decision.is_undetermined(),
                "an engaged device measurement must decide Met/Unmet, got {:?}",
                tput.decision
            );
            let target = gam_gpu::policy::GPU_THROUGHPUT_TARGET_ROWS_PER_SEC;
            if tput.rows_per_sec >= target {
                assert!(
                    tput.decision.surrogate_unneeded(),
                    "device rate {:.1} >= target {target} must mark the surrogate unneeded",
                    tput.rows_per_sec
                );
            } else {
                assert!(
                    tput.decision.surrogate_justified(),
                    "device rate {:.1} < target {target} must justify the surrogate",
                    tput.rows_per_sec
                );
            }
        } else {
            // CPU-only host: the rate is honest but it is NOT a device
            // measurement. The surrogate decision is BLOCKED on hardware —
            // a fast CPU number can never declare the surrogate unneeded (the
            // #1412 anti-green-wash property carried to the exact device encode).
            assert!(
                tput.decision.is_undetermined(),
                "a CPU-emulator exact encode must leave the surrogate decision Undetermined, got {:?}",
                tput.decision
            );
            assert!(!tput.decision.surrogate_unneeded());
            assert!(!tput.decision.surrogate_justified());
        }
    }

    #[test]
    fn jacobi_eigh_matches_reference_2x2() {
        // Symmetric 2x2 spectral check: reconstruct A from V diag(vals) Vᵀ.
        let a = [4.0, 1.0, 1.0, 3.0];
        let mut vals = [0.0; 2];
        let mut vecs = [0.0; 4];
        jacobi_eigh(&a, 2, &mut vals, &mut vecs);
        // A_reconstructed[r][c] = Σ_k vals[k] v_k[r] v_k[c] (eigenvector k is
        // stored contiguously at vecs[k * 2 ..]).
        for r in 0..2 {
            for c in 0..2 {
                let mut acc = 0.0;
                for k in 0..2 {
                    acc += vals[k] * vecs[k * 2 + r] * vecs[k * 2 + c];
                }
                assert!((acc - a[r * 2 + c]).abs() < 1e-12, "eig reconstruct {r},{c}");
            }
        }
        // Eigenvalues of [[4,1],[1,3]] are (7±√5)/2.
        let mut vs = vals;
        vs.sort_by(f64::total_cmp);
        assert!((vs[0] - (7.0 - 5.0_f64.sqrt()) / 2.0).abs() < 1e-12);
        assert!((vs[1] - (7.0 + 5.0_f64.sqrt()) / 2.0).abs() < 1e-12);
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn encode_kernel_source_substitutes_macros_and_compiles() {
        let (d, deg, p) = (1usize, 2usize, 4usize);
        let config = AtlasConfig::default();
        let (atom, atlas) = build_atom_and_atlas(d, deg, p, config);
        let dev = EncodeAtomDevice::from_atom_atlas(&atom, &atlas.atoms[0], &config).unwrap();
        let src = encode_kernel_source(&dev);
        assert!(src.contains(&format!("#define DD {}", dev.d)));
        assert!(src.contains(&format!("#define MM {}", dev.m)));
        assert!(src.contains(&format!("#define PP {}", dev.p)));
        assert!(src.contains("sae_certified_encode"));
        // NVRTC host-compile to PTX (no device needed) — the #1017 anchor.
        let ptx = gam_gpu::device_cache::compile_ptx_arch(&src)
            .expect("sae_encode kernel compiles to PTX via NVRTC");
        let text = ptx.to_src();
        assert!(
            text.contains(".visible .entry sae_certified_encode"),
            "PTX must export the encode entry"
        );
        assert!(text.contains(".target sm_"), "PTX must carry a target arch");
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn device_matches_emulator_when_available() {
        let (d, deg, p) = (1usize, 2usize, 4usize);
        let config = AtlasConfig::default();
        let (atom, atlas) = build_atom_and_atlas(d, deg, p, config);
        let dev = EncodeAtomDevice::from_atom_atlas(&atom, &atlas.atoms[0], &config).unwrap();
        let n = DEVICE_ROW_THRESHOLD + 64;
        let rows: Vec<Vec<f64>> = (0..n)
            .map(|k| (0..p).map(|c| 0.3 * (((k + c) as f64) * 0.019).sin()).collect())
            .collect();
        let amps = vec![1.0; n];
        let cpu = emulate_certified_encode_batch(&dev, &rows, &amps);
        if gam_gpu::device_runtime::GpuRuntime::global().is_some() {
            let devout = device::sae_certified_encode_device(&dev, &rows, &amps)
                .expect("admitted GPU runtime must run the sae_encode kernel");
            // `zip` below would silently ignore a short device result.
            assert_eq!(
                devout.len(),
                cpu.len(),
                "device encode must return one result per emulated row"
            );
            let mut max_coord = 0.0_f64;
            for (a, b) in cpu.iter().zip(devout.iter()) {
                assert_eq!(a.cert.certified(), b.cert.certified(), "device certified flag");
                if a.cert.certified() {
                    for axis in 0..dev.d {
                        max_coord = max_coord.max((a.coord[axis] - b.coord[axis]).abs());
                    }
                }
            }
            assert!(max_coord <= 1e-9, "device vs emulator coord diff {max_coord:.3e} > 1e-9");
        }
    }
}
1818