1use std::time::Instant;
56
57use crate::encode::{
58 AtlasConfig, AtomEncodeAtlas, KANTOROVICH_THRESHOLD, euclidean_patch_degree,
59};
60use crate::manifold::SaeManifoldAtom;
61use gam_gpu::policy::{EncodeDecisionBlocked, EncodeDeploymentDecision};
62
/// Flattened description of one atom's certified encoder, laid out so the
/// CPU emulator and the CUDA kernel can both index it with plain offsets.
///
/// Instances are built by [`EncodeAtomDevice::from_atom_atlas`].
#[derive(Debug, Clone)]
pub struct EncodeAtomDevice {
    /// Latent dimension (taken from `atom.latent_dim`).
    pub d: usize,
    /// Number of basis functions (taken from `atom.basis_size()`).
    pub m: usize,
    /// Output dimension (taken from `atom.output_dim()`).
    pub p: usize,
    /// Maximum number of candidate charts kept by routing.
    pub topk: usize,
    /// Upper bound on refinement iterations after a start has certified.
    pub newton_steps: usize,
    /// Value added to every diagonal entry of the encode Hessian.
    pub ridge: f64,
    /// Monomial exponent table, `m * d` entries indexed as `col * d + axis`.
    pub exponents: Vec<i32>,
    /// Decoder coefficients, `m * p` entries indexed as `basis * p + output`.
    pub decoder: Vec<f64>,
    /// Per-chart data, in the same order as the source atlas.
    pub charts: Vec<EncodeChartDevice>,
}
89
/// Flattened per-chart data consumed by the certified encoder.
#[derive(Debug, Clone)]
pub struct EncodeChartDevice {
    /// Chart center in latent coordinates (length `d`, validated on build).
    pub center: Vec<f64>,
    /// Radius of the Euclidean ball around `center` used for the
    /// in-chart membership test.
    pub radius: f64,
    /// Charts whose value here is `<= 0` are skipped by routing.
    pub certified_radius: f64,
    /// Lipschitz factor used in the certificate `h = beta * eta * lipschitz`.
    pub lipschitz: f64,
    /// Whether `amortized_jacobian` holds a usable matrix.
    pub has_jacobian: bool,
    /// Warm-start matrix, `d * p` entries indexed as `axis * p + output`;
    /// empty when `has_jacobian` is `false`.
    pub amortized_jacobian: Vec<f64>,
    /// Reference reconstruction copied from the atlas chart; the warm start
    /// subtracts `amplitude * recon_center` from the target.
    pub recon_center: Vec<f64>,
}
108
109impl EncodeAtomDevice {
110 pub fn from_atom_atlas(
115 atom: &SaeManifoldAtom,
116 atom_atlas: &AtomEncodeAtlas,
117 config: &AtlasConfig,
118 ) -> Result<Self, String> {
119 let d = atom.latent_dim;
120 let p = atom.output_dim();
121 let m = atom.basis_size();
122 let degree = euclidean_patch_degree(d, m);
123 let exps = gam_terms::basis::monomial_exponents(d, degree);
124 if exps.len() != m {
125 return Err(format!(
126 "EncodeAtomDevice::from_atom_atlas: monomial table len {} != basis_size {m} \
127 (atom is not a EuclideanPatch degree-{degree} monomial family)",
128 exps.len()
129 ));
130 }
131 let mut exponents = vec![0_i32; m * d];
132 for (col, alpha) in exps.iter().enumerate() {
133 for axis in 0..d {
134 exponents[col * d + axis] = alpha[axis] as i32;
135 }
136 }
137 let dec = &atom.decoder_coefficients;
138 if dec.dim() != (m, p) {
139 return Err(format!(
140 "EncodeAtomDevice::from_atom_atlas: decoder dim {:?} != ({m}, {p})",
141 dec.dim()
142 ));
143 }
144 let mut decoder = vec![0.0_f64; m * p];
145 for b in 0..m {
146 for c in 0..p {
147 decoder[b * p + c] = dec[[b, c]];
148 }
149 }
150 let mut charts = Vec::with_capacity(atom_atlas.charts.len());
151 for chart in &atom_atlas.charts {
152 let center = chart.region.center.to_vec();
153 if center.len() != d {
154 return Err(format!(
155 "EncodeAtomDevice::from_atom_atlas: chart center len {} != d {d}",
156 center.len()
157 ));
158 }
159 let (has_jacobian, amortized_jacobian) = match &chart.amortized_jacobian {
160 Some(a1) => {
161 if a1.dim() != (d, p) {
162 return Err(format!(
163 "EncodeAtomDevice::from_atom_atlas: A1 dim {:?} != ({d}, {p})",
164 a1.dim()
165 ));
166 }
167 let mut flat = vec![0.0_f64; d * p];
168 for axis in 0..d {
169 for out in 0..p {
170 flat[axis * p + out] = a1[[axis, out]];
171 }
172 }
173 (true, flat)
174 }
175 None => (false, Vec::new()),
176 };
177 let recon_center = chart.recon_center.to_vec();
178 charts.push(EncodeChartDevice {
179 center,
180 radius: chart.region.radius,
181 certified_radius: chart.certified_radius,
182 lipschitz: chart.lipschitz,
183 has_jacobian,
184 amortized_jacobian,
185 recon_center,
186 });
187 }
188 Ok(Self {
189 d,
190 m,
191 p,
192 topk: crate::encode::CERTIFIED_ROUTING_TOPK,
193 newton_steps: config.newton_steps,
194 ridge: config.ridge,
195 exponents,
196 decoder,
197 charts,
198 })
199 }
200}
201
/// Per-row certificate produced by the certified encoder.
///
/// All fields are `+inf` (except possibly `lipschitz`) when the row could
/// not be certified.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DeviceRowCertificate {
    /// Reciprocal of the smallest eigenvalue of the encode Hessian.
    pub beta: f64,
    /// Euclidean norm of the Newton step.
    pub eta: f64,
    /// Lipschitz factor of the chart that produced the certificate.
    pub lipschitz: f64,
    /// Certificate value `beta * eta * lipschitz`; the row counts as
    /// certified when this is finite and `<= KANTOROVICH_THRESHOLD`.
    pub h: f64,
}
211
212impl DeviceRowCertificate {
213 #[inline]
214 #[must_use]
215 pub fn certified(&self) -> bool {
216 self.h.is_finite() && self.h <= KANTOROVICH_THRESHOLD
217 }
218 #[inline]
219 fn uncertified(lipschitz: f64) -> Self {
220 Self {
221 beta: f64::INFINITY,
222 eta: f64::INFINITY,
223 lipschitz,
224 h: f64::INFINITY,
225 }
226 }
227 #[inline]
228 fn uncertified_inf() -> Self {
229 Self {
230 beta: f64::INFINITY,
231 eta: f64::INFINITY,
232 lipschitz: f64::INFINITY,
233 h: f64::INFINITY,
234 }
235 }
236}
237
/// Result of encoding a single target row.
#[derive(Debug, Clone)]
pub struct DeviceEncodeRow {
    /// Latent coordinate (length `d`); all zeros when no chart produced one.
    pub coord: Vec<f64>,
    /// Certificate attached to `coord`.
    pub cert: DeviceRowCertificate,
}
244
/// Integer power of a float; thin wrapper over [`f64::powi`] so every call
/// site in this file uses the same primitive.
#[inline]
fn dpow(base: f64, exp: i32) -> f64 {
    f64::powi(base, exp)
}
261
/// Evaluates the monomial basis and its first and second derivatives at `t`.
///
/// Outputs (all flat, caller-allocated):
/// * `phi[col]` — product over axes of `t[axis]^exponent`.
/// * `jet[col * d + axis]` — partial derivative of basis `col` along `axis`.
/// * `hess[(col * d + a) * d + c]` — second partial derivative along `a`, `c`.
///
/// Factors with a zero exponent are skipped rather than multiplied in.
fn eval_basis(dev: &EncodeAtomDevice, t: &[f64], phi: &mut [f64], jet: &mut [f64], hess: &mut [f64]) {
    let (d, m) = (dev.d, dev.m);
    let exp = &dev.exponents;
    for col in 0..m {
        // Value: product of t[axis]^e over axes with a nonzero exponent.
        let mut value = 1.0_f64;
        for axis in 0..d {
            let e = exp[col * d + axis];
            if e != 0 {
                value *= dpow(t[axis], e);
            }
        }
        phi[col] = value;
        // First derivatives: bring down the exponent on `axis` and lower it by one.
        for axis in 0..d {
            let a_axis = exp[col * d + axis];
            let mut jval = 0.0_f64;
            if a_axis != 0 {
                jval = a_axis as f64;
                for a in 0..d {
                    let ea = if a == axis { a_axis - 1 } else { exp[col * d + a] };
                    if ea != 0 {
                        jval *= dpow(t[a], ea);
                    }
                }
            }
            jet[col * d + axis] = jval;
        }
        // Second derivatives for every ordered axis pair (a, c).
        for a in 0..d {
            for c in 0..d {
                let mut hval = 0.0_f64;
                let aa = exp[col * d + a];
                let ac = exp[col * d + c];
                // The entry can only be nonzero if the monomial depends on
                // both differentiation axes.
                let admissible = aa != 0 && (a == c || ac != 0);
                if admissible {
                    // Leading coefficient: e * (e - 1) on the diagonal,
                    // e_a * e_c off the diagonal.
                    let lead = if a == c {
                        (aa as f64) * ((aa - 1).max(0) as f64)
                    } else {
                        (aa as f64) * (ac as f64)
                    };
                    if lead != 0.0 {
                        hval = lead;
                        for axis in 0..d {
                            // Lower the exponent once per differentiation
                            // along this axis (twice when a == c), clamped at 0.
                            let mut e = exp[col * d + axis];
                            if axis == a {
                                e = (e - 1).max(0);
                            }
                            if axis == c {
                                e = (e - 1).max(0);
                            }
                            if e != 0 {
                                hval *= dpow(t[axis], e);
                            }
                        }
                    }
                }
                hess[(col * d + a) * d + c] = hval;
            }
        }
    }
}
329
330fn recon_amp1(dev: &EncodeAtomDevice, phi: &[f64], out: &mut [f64]) {
333 let (m, p) = (dev.m, dev.p);
334 for c in 0..p {
335 out[c] = 0.0;
336 }
337 for b in 0..m {
338 let pv = phi[b];
339 if pv == 0.0 {
340 continue;
341 }
342 for c in 0..p {
343 out[c] += pv * dev.decoder[b * p + c];
344 }
345 }
346}
347
/// Borrowed view of one basis evaluation (value, first and second
/// derivatives), passed as a single argument to `encode_grad_hess`.
struct EvaluatedBasis<'a> {
    /// Basis values, indexed by basis.
    phi: &'a [f64],
    /// First derivatives, indexed as `basis * d + axis`.
    jet: &'a [f64],
    /// Second derivatives, indexed as `(basis * d + a) * d + b`.
    hess: &'a [f64],
}
355
/// Computes the gradient `g` (length `d`) and Hessian `h` (`d * d`,
/// row-major) of the encode objective at the point where `be` was evaluated.
///
/// With `residual = amplitude * recon - x` and `jm` the amplitude-scaled
/// Jacobian of the reconstruction, this writes
/// `g[a] = sum_c jm[a, c] * residual[c]` and
/// `h[a, b] = sum_c jm[a, c] * jm[b, c] + curvature(a, b)`, then adds
/// `dev.ridge` to the diagonal of `h` only (the gradient gets no ridge term).
fn encode_grad_hess(
    dev: &EncodeAtomDevice,
    x: &[f64],
    amplitude: f64,
    be: &EvaluatedBasis<'_>,
    g: &mut [f64],
    h: &mut [f64],
) {
    let (phi, jet, hess) = (be.phi, be.jet, be.hess);
    let (d, m, p) = (dev.d, dev.m, dev.p);
    // Amplitude-scaled reconstruction; zero basis values are skipped.
    let mut recon = vec![0.0_f64; p];
    for b in 0..m {
        let pv = phi[b];
        if pv == 0.0 {
            continue;
        }
        for c in 0..p {
            recon[c] += amplitude * pv * dev.decoder[b * p + c];
        }
    }
    let mut residual = vec![0.0_f64; p];
    for c in 0..p {
        residual[c] = recon[c] - x[c];
    }
    // jm[axis * p + c]: derivative of output c with respect to latent axis.
    let mut jm = vec![0.0_f64; d * p];
    for axis in 0..d {
        for b in 0..m {
            let dphi = jet[b * d + axis];
            if dphi == 0.0 {
                continue;
            }
            for c in 0..p {
                jm[axis * p + c] += amplitude * dphi * dev.decoder[b * p + c];
            }
        }
    }
    for a in 0..d {
        // Gradient entry: Jacobian row dotted with the residual.
        let mut ga = 0.0;
        for c in 0..p {
            ga += jm[a * p + c] * residual[c];
        }
        g[a] = ga;
        for b in 0..d {
            // First-order (Jacobian product) part of the Hessian entry.
            let mut hab = 0.0;
            for c in 0..p {
                hab += jm[a * p + c] * jm[b * p + c];
            }
            // Second-order part: residual contracted with the basis second
            // derivatives; bases with a zero second derivative are skipped.
            let mut curv = 0.0;
            for basis in 0..m {
                let d2 = hess[(basis * d + a) * d + b];
                if d2 == 0.0 {
                    continue;
                }
                let mut dot = 0.0;
                for c in 0..p {
                    dot += residual[c] * dev.decoder[basis * p + c];
                }
                curv += amplitude * d2 * dot;
            }
            hab += curv;
            h[a * d + b] = hab;
        }
    }
    // Ridge is applied to the Hessian diagonal only.
    for a in 0..d {
        h[a * d + a] += dev.ridge;
    }
}
431
/// Eigen-decomposition of a symmetric `d x d` matrix by cyclic Jacobi
/// rotations.
///
/// * `a_in` — row-major matrix; only read.
/// * `vals` — receives the `d` eigenvalues (the final diagonal), unsorted.
/// * `vecs` — receives the eigenvectors; the vector for `vals[col]` is stored
///   at `vecs[col * d..(col + 1) * d]`.
///
/// At most 30 sweeps are performed; iteration stops early once the sum of
/// squared strictly-upper-triangular entries is `<= 1e-300`.
pub fn jacobi_eigh(a_in: &[f64], d: usize, vals: &mut [f64], vecs: &mut [f64]) {
    const MAX_SWEEPS: usize = 30;
    const OFF_DIAG_FLOOR: f64 = 1e-300;

    let mut work = a_in.to_vec();

    // Start the eigenvector accumulator at the identity.
    for col in 0..d {
        for row in 0..d {
            vecs[col * d + row] = if row == col { 1.0 } else { 0.0 };
        }
    }
    if d == 1 {
        vals[0] = work[0];
        return;
    }

    for _ in 0..MAX_SWEEPS {
        let mut off_sq = 0.0_f64;
        for row in 0..d {
            for col in (row + 1)..d {
                let entry = work[row * d + col];
                off_sq += entry * entry;
            }
        }
        if off_sq <= OFF_DIAG_FLOOR {
            break;
        }

        for i in 0..d {
            for j in (i + 1)..d {
                let pivot = work[i * d + j];
                if pivot == 0.0 {
                    continue;
                }
                let tau = (work[j * d + j] - work[i * d + i]) / (2.0 * pivot);
                let root = (1.0 + tau * tau).sqrt();
                // Smaller-magnitude root of the rotation tangent.
                let tan = if tau >= 0.0 {
                    1.0 / (tau + root)
                } else {
                    -1.0 / (-tau + root)
                };
                let cos = 1.0 / (1.0 + tan * tan).sqrt();
                let sin = tan * cos;
                let rotate = |u: f64, v: f64| (cos * u - sin * v, sin * u + cos * v);

                // Rotate columns i and j.
                for k in 0..d {
                    let (left, right) = rotate(work[k * d + i], work[k * d + j]);
                    work[k * d + i] = left;
                    work[k * d + j] = right;
                }
                // Rotate rows i and j.
                for k in 0..d {
                    let (upper, lower) = rotate(work[i * d + k], work[j * d + k]);
                    work[i * d + k] = upper;
                    work[j * d + k] = lower;
                }
                // Accumulate the same rotation into the eigenvectors.
                for k in 0..d {
                    let (first, second) = rotate(vecs[i * d + k], vecs[j * d + k]);
                    vecs[i * d + k] = first;
                    vecs[j * d + k] = second;
                }
            }
        }
    }

    for i in 0..d {
        vals[i] = work[i * d + i];
    }
}
507
508fn beta_eta_newton(h: &[f64], g: &[f64], d: usize) -> Option<(f64, f64, Vec<f64>)> {
512 let mut vals = vec![0.0_f64; d];
513 let mut vecs = vec![0.0_f64; d * d];
514 jacobi_eigh(h, d, &mut vals, &mut vecs);
515 let mut lambda_min = f64::INFINITY;
516 for &v in &vals {
517 if v < lambda_min {
518 lambda_min = v;
519 }
520 }
521 if !(lambda_min.is_finite() && lambda_min > 0.0) {
522 return None;
523 }
524 let beta = 1.0 / lambda_min;
525 let mut delta = vec![0.0_f64; d];
526 for col in 0..d {
527 let lam = vals[col];
528 if lam <= 0.0 {
529 return None;
530 }
531 let mut vg = 0.0;
533 for row in 0..d {
534 vg += vecs[col * d + row] * g[row];
535 }
536 let coeff = vg / lam;
537 for row in 0..d {
538 delta[row] -= coeff * vecs[col * d + row];
539 }
540 }
541 let mut eta = 0.0;
542 for row in 0..d {
543 eta += delta[row] * delta[row];
544 }
545 Some((beta, eta.sqrt(), delta))
546}
547
548fn row_certificate(
550 dev: &EncodeAtomDevice,
551 t: &[f64],
552 x: &[f64],
553 amplitude: f64,
554 lipschitz: f64,
555 scratch: &mut Scratch,
556) -> (DeviceRowCertificate, Vec<f64>) {
557 let d = dev.d;
558 eval_basis(dev, t, &mut scratch.phi, &mut scratch.jet, &mut scratch.hess);
559 encode_grad_hess(
560 dev,
561 x,
562 amplitude,
563 &EvaluatedBasis {
564 phi: &scratch.phi,
565 jet: &scratch.jet,
566 hess: &scratch.hess,
567 },
568 &mut scratch.g,
569 &mut scratch.h,
570 );
571 match beta_eta_newton(&scratch.h, &scratch.g, d) {
572 Some((beta, eta, delta)) => (
573 DeviceRowCertificate {
574 beta,
575 eta,
576 lipschitz,
577 h: beta * eta * lipschitz,
578 },
579 delta,
580 ),
581 None => (
582 DeviceRowCertificate::uncertified(lipschitz),
583 vec![0.0_f64; d],
584 ),
585 }
586}
587
/// Reusable work buffers for one encode row, sized by `Scratch::new`.
struct Scratch {
    /// Basis values (`m`).
    phi: Vec<f64>,
    /// Basis first derivatives (`m * d`).
    jet: Vec<f64>,
    /// Basis second derivatives (`m * d * d`).
    hess: Vec<f64>,
    /// Gradient (`d`).
    g: Vec<f64>,
    /// Hessian (`d * d`).
    h: Vec<f64>,
}
596
597impl Scratch {
598 fn new(dev: &EncodeAtomDevice) -> Self {
599 let (d, m) = (dev.d, dev.m);
600 Self {
601 phi: vec![0.0; m],
602 jet: vec![0.0; m * d],
603 hess: vec![0.0; m * d * d],
604 g: vec![0.0; d],
605 h: vec![0.0; d * d],
606 }
607 }
608}
609
/// Returns `true` when `t` lies within Euclidean distance `radius` of
/// `center` (boundary included). Compares squared quantities, so the sign of
/// `radius` is irrelevant.
#[inline]
fn in_chart(t: &[f64], center: &[f64], radius: f64) -> bool {
    let dist_sq = t
        .iter()
        .zip(&center[..t.len()])
        .fold(0.0, |acc, (coord, origin)| {
            let gap = coord - origin;
            acc + gap * gap
        });
    dist_sq <= radius * radius
}
619
620fn certify_with_basin_warmup(
626 dev: &EncodeAtomDevice,
627 mut t: Vec<f64>,
628 x: &[f64],
629 amplitude: f64,
630 chart: &EncodeChartDevice,
631 scratch: &mut Scratch,
632) -> Option<(Vec<f64>, DeviceRowCertificate)> {
633 if !in_chart(&t, &chart.center, chart.radius) {
634 return None;
635 }
636 let (mut cert, mut delta) =
637 row_certificate(dev, &t, x, amplitude, chart.lipschitz, scratch);
638 while !cert.certified() {
639 if !(cert.h.is_finite() && cert.beta.is_finite() && cert.eta.is_finite()) {
640 return None;
641 }
642 let prev_h = cert.h;
643 let mut next = t.clone();
644 for i in 0..dev.d {
645 next[i] += delta[i];
646 }
647 if !in_chart(&next, &chart.center, chart.radius) {
648 return None;
649 }
650 t = next;
651 let (nc, nd) = row_certificate(dev, &t, x, amplitude, chart.lipschitz, scratch);
652 cert = nc;
653 delta = nd;
654 if !cert.h.is_finite() || cert.h >= prev_h {
655 return None;
656 }
657 }
658 let landing = cert;
663 for _ in 0..dev.newton_steps {
664 let dnorm = delta.iter().map(|v| v * v).sum::<f64>().sqrt();
665 let tnorm = t.iter().map(|v| v * v).sum::<f64>().sqrt();
666 if dnorm <= crate::encode::NEWTON_REFINE_CONVERGED_EPS * (1.0 + tnorm) {
667 break;
668 }
669 for i in 0..dev.d {
670 t[i] += delta[i];
671 }
672 let (nc, nd) = row_certificate(dev, &t, x, amplitude, chart.lipschitz, scratch);
673 if !nc.certified() {
674 return None;
675 }
676 delta = nd;
677 }
678 Some((t, landing))
679}
680
681fn amortized_warm_start(chart: &EncodeChartDevice, x: &[f64], amplitude: f64, d: usize, p: usize) -> Option<Vec<f64>> {
685 if !chart.has_jacobian {
686 return None;
687 }
688 if !(amplitude.is_finite() && amplitude.abs() > 0.0) {
689 return None;
690 }
691 let mut t_hat = chart.center.clone();
692 for out in 0..p.min(chart.recon_center.len()) {
693 let resid = x[out] - amplitude * chart.recon_center[out];
694 for axis in 0..d {
695 t_hat[axis] += chart.amortized_jacobian[axis * p + out] * resid / amplitude;
696 }
697 }
698 Some(t_hat)
699}
700
701fn recon_error(dev: &EncodeAtomDevice, t: &[f64], x: &[f64], amplitude: f64, scratch: &mut Scratch) -> f64 {
704 eval_basis(dev, t, &mut scratch.phi, &mut scratch.jet, &mut scratch.hess);
705 let mut err2 = 0.0;
706 let p = dev.p;
707 let mut recon = vec![0.0_f64; p];
708 recon_amp1(dev, &scratch.phi, &mut recon);
709 for c in 0..p {
710 let r = x[c] - amplitude * recon[c];
711 err2 += r * r;
712 }
713 if err2.is_finite() { err2.sqrt() } else { f64::INFINITY }
714}
715
716fn nearest_charts_topk(dev: &EncodeAtomDevice, x: &[f64], scratch: &mut Scratch) -> Vec<usize> {
720 if dev.charts.is_empty() || dev.topk == 0 {
721 return Vec::new();
722 }
723 let p = dev.p;
724 let mut scored: Vec<(usize, f64)> = Vec::new();
725 let mut recon = vec![0.0_f64; p];
726 for (idx, chart) in dev.charts.iter().enumerate() {
727 if chart.certified_radius <= 0.0 {
728 continue;
729 }
730 eval_basis(dev, &chart.center, &mut scratch.phi, &mut scratch.jet, &mut scratch.hess);
731 recon_amp1(dev, &scratch.phi, &mut recon);
732 let mut dist = 0.0;
733 for c in 0..p {
734 let diff = recon[c] - x[c];
735 dist += diff * diff;
736 }
737 scored.push((idx, dist));
738 }
739 scored.sort_by(|a, b| {
740 a.1.partial_cmp(&b.1)
741 .unwrap_or(std::cmp::Ordering::Equal)
742 .then(a.0.cmp(&b.0))
743 });
744 scored.into_iter().take(dev.topk).map(|(i, _)| i).collect()
745}
746
747#[must_use]
752pub fn emulate_certified_encode_row(dev: &EncodeAtomDevice, x: &[f64], amplitude: f64) -> DeviceEncodeRow {
753 let d = dev.d;
754 let p = dev.p;
755 let mut scratch = Scratch::new(dev);
756 let candidates = nearest_charts_topk(dev, x, &mut scratch);
757 if candidates.is_empty() {
758 return DeviceEncodeRow {
759 coord: vec![0.0; d],
760 cert: DeviceRowCertificate::uncertified_inf(),
761 };
762 }
763 let mut best: Option<(Vec<f64>, DeviceRowCertificate, f64)> = None;
764 let mut nearest_fallback: Option<(Vec<f64>, DeviceRowCertificate)> = None;
765 for chart_idx in candidates {
766 let chart = &dev.charts[chart_idx];
767 let Some(t_hat) = amortized_warm_start(chart, x, amplitude, d, p) else {
768 if nearest_fallback.is_none() {
769 nearest_fallback = Some((vec![0.0; d], DeviceRowCertificate::uncertified(chart.lipschitz)));
770 }
771 continue;
772 };
773 let (coord, cert) = match certify_with_basin_warmup(dev, t_hat, x, amplitude, chart, &mut scratch) {
774 Some((c, cert)) => (c, cert),
775 None => (vec![0.0; d], DeviceRowCertificate::uncertified(chart.lipschitz)),
776 };
777 if nearest_fallback.is_none() {
778 nearest_fallback = Some((coord.clone(), cert));
779 }
780 if cert.certified() {
781 let err = recon_error(dev, &coord, x, amplitude, &mut scratch);
782 if best.as_ref().map(|(_, _, e)| err < *e).unwrap_or(true) {
783 best = Some((coord, cert, err));
784 }
785 if let Some((_, _, e)) = best.as_ref() {
790 let xnorm = x.iter().map(|v| v * v).sum::<f64>().sqrt();
791 if *e <= crate::encode::CERTIFIED_GLOBAL_MIN_RECON_FLOOR * (1.0 + xnorm) {
792 break;
793 }
794 }
795 }
796 }
797 match best {
798 Some((coord, cert, _)) => DeviceEncodeRow { coord, cert },
799 None => {
800 let (coord, cert) = nearest_fallback
801 .unwrap_or_else(|| (vec![0.0; d], DeviceRowCertificate::uncertified_inf()));
802 DeviceEncodeRow { coord, cert }
803 }
804 }
805}
806
807#[must_use]
810pub fn emulate_certified_encode_batch(
811 dev: &EncodeAtomDevice,
812 targets: &[Vec<f64>],
813 amplitudes: &[f64],
814) -> Vec<DeviceEncodeRow> {
815 targets
816 .iter()
817 .zip(amplitudes.iter())
818 .map(|(x, &)| emulate_certified_encode_row(dev, x, amp))
819 .collect()
820}
821
/// CUDA C source of the `sae_certified_encode` kernel and its device helpers.
///
/// The text references the macros `DD`, `MM`, `PP`, `TOPK`, `NEWTON`,
/// `RIDGE`, `GMIN_FLOOR` and `REFINE_EPS` without defining them;
/// `encode_kernel_source` prepends the matching `#define` lines. Only
/// `KANTOROVICH` is defined here.
pub const ENCODE_KERNEL_SOURCE: &str = r#"
#define KANTOROVICH 0.5

__device__ __forceinline__ double dpow(double b, int e){
    // exponentiation-by-squaring, matching llvm.powi/f64::powi and the emulator dpow.
    if (e == 0) return 1.0;
    int n = e < 0 ? -e : e;
    double r = 1.0, base = b;
    while (n > 0){ if (n & 1) r *= base; n >>= 1; if (n) base *= base; }
    return e < 0 ? 1.0 / r : r;
}

// Monomial phi/jet/hess at t (mirror of eval_basis).
__device__ void eval_basis(const int* exps, const double* t,
                           double* phi, double* jet, double* hess){
    for (int col=0; col<MM; ++col){
        double value = 1.0;
        for (int axis=0; axis<DD; ++axis){ int e=exps[col*DD+axis]; if(e!=0) value*=dpow(t[axis],e); }
        phi[col]=value;
        for (int axis=0; axis<DD; ++axis){
            int a_axis=exps[col*DD+axis]; double jval=0.0;
            if (a_axis!=0){ jval=(double)a_axis;
                for(int a=0;a<DD;++a){ int ea=(a==axis)?a_axis-1:exps[col*DD+a]; if(ea!=0) jval*=dpow(t[a],ea); } }
            jet[col*DD+axis]=jval;
        }
        for (int a=0;a<DD;++a) for(int c=0;c<DD;++c){
            double hval=0.0; int aa=exps[col*DD+a]; int ac=exps[col*DD+c];
            int adm = (aa!=0) && (a==c || ac!=0);
            if (adm){
                double lead = (a==c) ? (double)aa*(double)((aa-1)>0?(aa-1):0)
                                     : (double)aa*(double)ac;
                if (lead!=0.0){ hval=lead;
                    for(int axis=0;axis<DD;++axis){ int e=exps[col*DD+axis];
                        if(axis==a) e=(e-1)>0?(e-1):0; if(axis==c) e=(e-1)>0?(e-1):0;
                        if(e!=0) hval*=dpow(t[axis],e); } }
            }
            hess[(col*DD+a)*DD+c]=hval;
        }
    }
}

__device__ void recon_amp1(const double* dec, const double* phi, double* out){
    for(int c=0;c<PP;++c) out[c]=0.0;
    for(int b=0;b<MM;++b){ double pv=phi[b]; if(pv==0.0) continue;
        for(int c=0;c<PP;++c) out[c]+=pv*dec[b*PP+c]; }
}

// grad g[D] and full Hessian h[D*D] (+ridge). Mirror of encode_grad_hess.
__device__ void grad_hess(const double* dec, const double* t, const double* x, double amp,
                          const double* phi, const double* jet, const double* hess,
                          double* g, double* h){
    double recon[PP]; double residual[PP]; double jm[DD*PP];
    for(int c=0;c<PP;++c) recon[c]=0.0;
    for(int b=0;b<MM;++b){ double pv=phi[b]; if(pv==0.0) continue;
        for(int c=0;c<PP;++c) recon[c]+=amp*pv*dec[b*PP+c]; }
    for(int c=0;c<PP;++c) residual[c]=recon[c]-x[c];
    for(int i=0;i<DD*PP;++i) jm[i]=0.0;
    for(int axis=0;axis<DD;++axis) for(int b=0;b<MM;++b){ double dphi=jet[b*DD+axis]; if(dphi==0.0) continue;
        for(int c=0;c<PP;++c) jm[axis*PP+c]+=amp*dphi*dec[b*PP+c]; }
    for(int a=0;a<DD;++a){
        double ga=0.0; for(int c=0;c<PP;++c) ga+=jm[a*PP+c]*residual[c]; g[a]=ga;
        for(int b=0;b<DD;++b){
            double hab=0.0; for(int c=0;c<PP;++c) hab+=jm[a*PP+c]*jm[b*PP+c];
            double curv=0.0;
            for(int basis=0;basis<MM;++basis){ double d2=hess[(basis*DD+a)*DD+b]; if(d2==0.0) continue;
                double dot=0.0; for(int c=0;c<PP;++c) dot+=residual[c]*dec[basis*PP+c];
                curv+=amp*d2*dot; }
            h[a*DD+b]=hab+curv;
        }
    }
    for(int a=0;a<DD;++a) h[a*DD+a]+=RIDGE;
}

// Cyclic Jacobi eigensolver (mirror of jacobi_eigh); vecs columns: vecs[col*D+row].
__device__ void jacobi_eigh(const double* a_in, double* vals, double* vecs){
    double a[DD*DD];
    for(int i=0;i<DD*DD;++i) a[i]=a_in[i];
    for(int r=0;r<DD;++r) for(int c=0;c<DD;++c) vecs[c*DD+r]=(r==c)?1.0:0.0;
    if (DD==1){ vals[0]=a[0]; return; }
    for(int sweep=0;sweep<30;++sweep){
        double off=0.0;
        for(int r=0;r<DD;++r) for(int c=r+1;c<DD;++c) off+=a[r*DD+c]*a[r*DD+c];
        if (off<=1e-300) break;
        for(int p=0;p<DD;++p) for(int q=p+1;q<DD;++q){
            double apq=a[p*DD+q]; if(apq==0.0) continue;
            double app=a[p*DD+p]; double aqq=a[q*DD+q];
            double tau=(aqq-app)/(2.0*apq);
            double t = (tau>=0.0) ? 1.0/(tau+sqrt(1.0+tau*tau)) : -1.0/(-tau+sqrt(1.0+tau*tau));
            double cph=1.0/sqrt(1.0+t*t); double sph=t*cph;
            for(int k=0;k<DD;++k){ double akp=a[k*DD+p]; double akq=a[k*DD+q];
                a[k*DD+p]=cph*akp-sph*akq; a[k*DD+q]=sph*akp+cph*akq; }
            for(int k=0;k<DD;++k){ double apk=a[p*DD+k]; double aqk=a[q*DD+k];
                a[p*DD+k]=cph*apk-sph*aqk; a[q*DD+k]=sph*apk+cph*aqk; }
            for(int k=0;k<DD;++k){ double vkp=vecs[p*DD+k]; double vkq=vecs[q*DD+k];
                vecs[p*DD+k]=cph*vkp-sph*vkq; vecs[q*DD+k]=sph*vkp+cph*vkq; }
        }
    }
    for(int i=0;i<DD;++i) vals[i]=a[i*DD+i];
}

// beta/eta/delta; returns 1 on success (lambda_min>0), 0 otherwise.
__device__ int beta_eta_newton(const double* h, const double* g,
                               double* beta, double* eta, double* delta){
    double vals[DD]; double vecs[DD*DD];
    jacobi_eigh(h, vals, vecs);
    double lmin=1.0/0.0; // +inf
    for(int i=0;i<DD;++i) if(vals[i]<lmin) lmin=vals[i];
    if (!(isfinite(lmin) && lmin>0.0)) return 0;
    *beta=1.0/lmin;
    for(int i=0;i<DD;++i) delta[i]=0.0;
    for(int col=0;col<DD;++col){ double lam=vals[col]; if(lam<=0.0) return 0;
        double vg=0.0; for(int row=0;row<DD;++row) vg+=vecs[col*DD+row]*g[row];
        double coeff=vg/lam; for(int row=0;row<DD;++row) delta[row]-=coeff*vecs[col*DD+row]; }
    double e2=0.0; for(int i=0;i<DD;++i) e2+=delta[i]*delta[i]; *eta=sqrt(e2);
    return 1;
}

// row_certificate: writes h_out (=beta*eta*L or +inf) and delta; returns certified 0/1 mask via h.
__device__ void row_certificate(const int* exps, const double* dec,
                                const double* t, const double* x, double amp, double L,
                                double* h_out, double* beta_out, double* eta_out, double* delta){
    double phi[MM]; double jet[MM*DD]; double hess[MM*DD*DD]; double g[DD]; double H[DD*DD];
    eval_basis(exps, t, phi, jet, hess);
    grad_hess(dec, t, x, amp, phi, jet, hess, g, H);
    double beta, eta;
    if (beta_eta_newton(H, g, &beta, &eta, delta)){
        *beta_out=beta; *eta_out=eta; *h_out=beta*eta*L;
    } else {
        *beta_out=1.0/0.0; *eta_out=1.0/0.0; *h_out=1.0/0.0;
        for(int i=0;i<DD;++i) delta[i]=0.0;
    }
}

__device__ int in_chart(const double* t, const double* center, double radius){
    double r2=0.0; for(int i=0;i<DD;++i){ double d=t[i]-center[i]; r2+=d*d; }
    return r2 <= radius*radius;
}

// certify_with_basin_warmup + refine. Returns 1 with coord/landing_h on success.
__device__ int certify_basin(const int* exps, const double* dec,
                             const double* t_start, const double* x, double amp,
                             const double* center, double radius, double L,
                             double* coord_out, double* landing_h){
    double t[DD]; for(int i=0;i<DD;++i) t[i]=t_start[i];
    if(!in_chart(t, center, radius)) return 0;
    double h, beta, eta; double delta[DD];
    row_certificate(exps, dec, t, x, amp, L, &h, &beta, &eta, delta);
    while(!(isfinite(h) && h<=KANTOROVICH)){
        if(!(isfinite(h) && isfinite(beta) && isfinite(eta))) return 0;
        double prev_h=h;
        double next[DD]; for(int i=0;i<DD;++i) next[i]=t[i]+delta[i];
        if(!in_chart(next, center, radius)) return 0;
        for(int i=0;i<DD;++i) t[i]=next[i];
        row_certificate(exps, dec, t, x, amp, L, &h, &beta, &eta, delta);
        if(!(isfinite(h)) || h>=prev_h) return 0;
    }
    double landing = h;
    for(int s=0;s<NEWTON;++s){
        // convergence early-exit (mirror production refine_certified_start).
        double dnorm=0.0, tnorm=0.0;
        for(int i=0;i<DD;++i){ dnorm+=delta[i]*delta[i]; tnorm+=t[i]*t[i]; }
        if(sqrt(dnorm) <= REFINE_EPS*(1.0+sqrt(tnorm))) break;
        for(int i=0;i<DD;++i) t[i]+=delta[i];
        row_certificate(exps, dec, t, x, amp, L, &h, &beta, &eta, delta);
        if(!(isfinite(h) && h<=KANTOROVICH)) return 0;
    }
    for(int i=0;i<DD;++i) coord_out[i]=t[i];
    *landing_h=landing;
    return 1;
}

// One block per row. Charts are stored flattened; the block's lead thread runs
// the full route -> warm-start -> certify -> assign pipeline serially.
extern "C" __global__ void sae_certified_encode(
    const int* __restrict__ exps, // MM*DD
    const double* __restrict__ dec, // MM*PP
    const double* __restrict__ centers, // n_charts*DD
    const double* __restrict__ radii, // n_charts
    const double* __restrict__ cert_radii, // n_charts
    const double* __restrict__ lips, // n_charts
    const int* __restrict__ has_jac, // n_charts
    const double* __restrict__ a1, // n_charts*DD*PP
    const double* __restrict__ recon_c, // n_charts*PP
    int n_charts,
    const double* __restrict__ targets, // n*PP
    const double* __restrict__ amps, // n
    int n,
    double* __restrict__ coords_out, // n*DD
    double* __restrict__ h_out, // n (certificate h; >0.5 or inf = uncertified)
    int* __restrict__ certified_out) // n (1/0)
{
    int row = blockIdx.x;
    if (row >= n) return;
    if (threadIdx.x != 0) return;
    const double* x = targets + (size_t)row*PP;
    double amp = amps[row];

    // ---- routing: top-TOPK certifiable charts by center recon distance. ----
    int cand[TOPK]; double cand_d[TOPK]; int ncand=0;
    {
        double phi[MM]; double jet[MM*DD]; double hess[MM*DD*DD]; double recon[PP];
        for(int idx=0; idx<n_charts; ++idx){
            if (cert_radii[idx] <= 0.0) continue;
            eval_basis(exps, centers + (size_t)idx*DD, phi, jet, hess);
            recon_amp1(dec, phi, recon);
            double dist=0.0; for(int c=0;c<PP;++c){ double df=recon[c]-x[c]; dist+=df*df; }
            // insert into the sorted top-TOPK by (dist, idx).
            int pos=ncand;
            while(pos>0 && (cand_d[pos-1]>dist)){ if(pos<TOPK){cand_d[pos]=cand_d[pos-1]; cand[pos]=cand[pos-1];} pos--; }
            if(pos<TOPK){ cand_d[pos]=dist; cand[pos]=idx; if(ncand<TOPK) ncand++; }
        }
    }
    // defaults: uncertified.
    for(int i=0;i<DD;++i) coords_out[(size_t)row*DD+i]=0.0;
    h_out[row]=1.0/0.0; certified_out[row]=0;
    if(ncand==0) return;

    int have_fallback=0; double fb_coord[DD]; double fb_h; int fb_cert;
    int have_best=0; double best_coord[DD]; double best_h; double best_err=1.0/0.0;

    for(int ci=0; ci<ncand; ++ci){
        int idx=cand[ci];
        const double* center = centers + (size_t)idx*DD;
        double radius=radii[idx]; double L=lips[idx];
        // amortized_warm_start.
        int ok_ws = has_jac[idx] && isfinite(amp) && (amp!=0.0);
        double t_hat[DD]; int produced=0; double coord[DD]; double landing_h; int cert=0;
        if(ok_ws){
            const double* A1 = a1 + (size_t)idx*DD*PP;
            const double* m1 = recon_c + (size_t)idx*PP;
            for(int i=0;i<DD;++i) t_hat[i]=center[i];
            for(int out=0; out<PP; ++out){ double resid=x[out]-amp*m1[out];
                for(int axis=0;axis<DD;++axis) t_hat[axis]+=A1[axis*PP+out]*resid/amp; }
            if(certify_basin(exps, dec, t_hat, x, amp, center, radius, L, coord, &landing_h)){
                produced=1; cert=(isfinite(landing_h) && landing_h<=KANTOROVICH);
            } else { produced=1; for(int i=0;i<DD;++i) coord[i]=0.0; landing_h=1.0/0.0; cert=0; }
        }
        if(!ok_ws){
            // warm start declined: fallback candidate = zeros, uncertified.
            if(!have_fallback){ have_fallback=1; for(int i=0;i<DD;++i) fb_coord[i]=0.0; fb_h=1.0/0.0; fb_cert=0; }
            continue;
        }
        if(!have_fallback){ have_fallback=1; for(int i=0;i<DD;++i) fb_coord[i]=coord[i]; fb_h=landing_h; fb_cert=cert; }
        if(cert){
            // reconstruction error at coord.
            double phi[MM]; double jet[MM*DD]; double hess[MM*DD*DD]; double recon[PP];
            eval_basis(exps, coord, phi, jet, hess); recon_amp1(dec, phi, recon);
            double e2=0.0; for(int c=0;c<PP;++c){ double r=x[c]-amp*recon[c]; e2+=r*r; }
            double err = isfinite(e2)? sqrt(e2) : 1.0/0.0;
            if(!have_best || err<best_err){ have_best=1; best_err=err; best_h=landing_h; for(int i=0;i<DD;++i) best_coord[i]=coord[i]; }
            // global-min short-circuit (mirror production certified_encode_row).
            double xnorm2=0.0; for(int c=0;c<PP;++c) xnorm2+=x[c]*x[c];
            if(best_err <= GMIN_FLOOR*(1.0+sqrt(xnorm2))) break;
        }
        (void)produced;
    }
    if(have_best){
        for(int i=0;i<DD;++i) coords_out[(size_t)row*DD+i]=best_coord[i];
        h_out[row]=best_h; certified_out[row]=1;
    } else if(have_fallback){
        for(int i=0;i<DD;++i) coords_out[(size_t)row*DD+i]=fb_coord[i];
        h_out[row]=fb_h; certified_out[row]=fb_cert;
    }
}
"#;
1097
1098#[cfg(target_os = "linux")]
1102#[must_use]
1103pub fn encode_kernel_source(dev: &EncodeAtomDevice) -> String {
1104 format!(
1105 "#define DD {}\n#define MM {}\n#define PP {}\n#define TOPK {}\n#define NEWTON {}\n\
1106 #define RIDGE ({:e})\n#define GMIN_FLOOR ({:e})\n#define REFINE_EPS ({:e})\n\
1107 {ENCODE_KERNEL_SOURCE}",
1108 dev.d,
1109 dev.m,
1110 dev.p,
1111 dev.topk,
1112 dev.newton_steps,
1113 dev.ridge,
1114 crate::encode::CERTIFIED_GLOBAL_MIN_RECON_FLOOR,
1115 crate::encode::NEWTON_REFINE_CONVERGED_EPS
1116 )
1117}
1118
/// Which implementation produced a batch of encode rows.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EncodePath {
    /// The rows came from the CUDA device path.
    Device,
    /// The rows came from the CPU emulator.
    Cpu,
}
1129
/// Minimum batch size (in rows) at which `sae_certified_encode_batch`
/// attempts the device path; smaller batches always run on the CPU.
pub const DEVICE_ROW_THRESHOLD: usize = 4_096;
1132
1133#[must_use]
1139pub fn sae_certified_encode_batch(
1140 dev: &EncodeAtomDevice,
1141 targets: &[Vec<f64>],
1142 amplitudes: &[f64],
1143) -> (Vec<DeviceEncodeRow>, EncodePath) {
1144 #[cfg(target_os = "linux")]
1145 {
1146 if targets.len() >= DEVICE_ROW_THRESHOLD {
1147 if let Ok(out) = device::sae_certified_encode_device(dev, targets, amplitudes) {
1148 return (out, EncodePath::Device);
1149 }
1150 }
1152 }
1153 (
1154 emulate_certified_encode_batch(dev, targets, amplitudes),
1155 EncodePath::Cpu,
1156 )
1157}
1158
/// Outcome of one timed encode run (see `measure_device_encode_throughput`).
#[derive(Debug, Clone, Copy)]
pub struct DeviceEncodeThroughput {
    /// Number of target rows in the measured batch.
    pub n_rows: usize,
    /// Wall-clock seconds spent in the timed encode call.
    pub encode_secs: f64,
    /// `n_rows / encode_secs`, or `0.0` when either quantity is zero.
    pub rows_per_sec: f64,
    /// Implementation that served the timed call.
    pub path: EncodePath,
    /// Deployment decision derived from the measurement; blocked with
    /// `NoDevice` when the timed call ran on the CPU.
    pub decision: EncodeDeploymentDecision,
}
1191
1192impl DeviceEncodeThroughput {
1193 #[must_use]
1196 pub fn device_engaged(&self) -> bool {
1197 matches!(self.path, EncodePath::Device)
1198 }
1199}
1200
1201#[must_use]
1218pub fn measure_device_encode_throughput(
1219 dev: &EncodeAtomDevice,
1220 targets: &[Vec<f64>],
1221 amplitudes: &[f64],
1222) -> DeviceEncodeThroughput {
1223 let n = targets.len();
1224 drop(sae_certified_encode_batch(dev, targets, amplitudes));
1227 let start = Instant::now();
1228 let (_out, path) = sae_certified_encode_batch(dev, targets, amplitudes);
1229 let elapsed = start.elapsed();
1230 let encode_secs = elapsed.as_secs_f64();
1231 let rows_per_sec = if n > 0 && encode_secs > 0.0 {
1232 n as f64 / encode_secs
1233 } else {
1234 0.0
1235 };
1236 let engaged = matches!(path, EncodePath::Device);
1237 let decision = if engaged {
1241 EncodeDeploymentDecision::from_device_measurement(true, rows_per_sec)
1242 } else {
1243 EncodeDeploymentDecision::blocked(EncodeDecisionBlocked::NoDevice)
1244 };
1245 DeviceEncodeThroughput {
1246 n_rows: n,
1247 encode_secs,
1248 rows_per_sec,
1249 path,
1250 decision,
1251 }
1252}
1253
1254#[cfg(target_os = "linux")]
1255mod device {
1256 use super::{
1257 DeviceEncodeRow, DeviceRowCertificate, EncodeAtomDevice, encode_kernel_source,
1258 };
1259 use gam_gpu::gpu_error::{GpuError, GpuResultExt};
1260 use std::collections::HashMap;
1261 use std::sync::{Arc, Mutex, OnceLock};
1262
1263 use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
1264
/// Process-wide CUDA state shared by every device encode call.
struct Backend {
    /// CUDA context obtained from the backend probe; used to load modules.
    ctx: Arc<CudaContext>,
    /// CUDA stream obtained from the backend probe.
    stream: Arc<CudaStream>,
    /// Compiled modules, keyed by the parameter string built in `module_for`.
    modules: Mutex<HashMap<String, Arc<CudaModule>>>,
}
1270
1271 fn backend() -> Result<&'static Backend, GpuError> {
1272 static BACKEND: OnceLock<Result<Backend, GpuError>> = OnceLock::new();
1273 BACKEND
1274 .get_or_init(|| {
1275 let parts = gam_gpu::backend_probe::probe_cuda_backend("sae_encode")?;
1276 Ok(Backend {
1277 ctx: parts.ctx,
1278 stream: parts.stream,
1279 modules: Mutex::new(HashMap::new()),
1280 })
1281 })
1282 .as_ref()
1283 .map_err(GpuError::clone)
1284 }
1285
1286 fn module_for(b: &Backend, dev: &EncodeAtomDevice) -> Result<Arc<CudaModule>, GpuError> {
1287 let key = format!(
1288 "{}-{}-{}-{}-{}-{:e}",
1289 dev.d, dev.m, dev.p, dev.topk, dev.newton_steps, dev.ridge
1290 );
1291 if let Ok(guard) = b.modules.lock() {
1292 if let Some(m) = guard.get(&key) {
1293 return Ok(m.clone());
1294 }
1295 }
1296 let src = encode_kernel_source(dev);
1297 let ptx = gam_gpu::device_cache::compile_ptx_arch(&src)
1298 .gpu_ctx_with(|err| format!("sae_encode NVRTC compile ({key}): {err}"))?;
1299 let module = b.ctx.load_module(ptx).gpu_ctx("sae_encode module load")?;
1300 if let Ok(mut guard) = b.modules.lock() {
1301 guard.entry(key).or_insert_with(|| module.clone());
1302 }
1303 Ok(module)
1304 }
1305
1306 pub(super) fn sae_certified_encode_device(
1309 dev: &EncodeAtomDevice,
1310 targets: &[Vec<f64>],
1311 amplitudes: &[f64],
1312 ) -> Result<Vec<DeviceEncodeRow>, GpuError> {
1313 let n = targets.len();
1314 let (d, p) = (dev.d, dev.p);
1315 if n == 0 {
1316 return Ok(Vec::new());
1317 }
1318 let b = backend()?;
1319 let module = module_for(b, dev)?;
1320 let func = module
1321 .load_function("sae_certified_encode")
1322 .gpu_ctx("sae_encode load_function")?;
1323 let stream = b.stream.clone();
1324 let n_charts = dev.charts.len();
1325
1326 let mut centers = vec![0.0_f64; n_charts * d];
1328 let mut radii = vec![0.0_f64; n_charts];
1329 let mut cert_radii = vec![0.0_f64; n_charts];
1330 let mut lips = vec![0.0_f64; n_charts];
1331 let mut has_jac = vec![0_i32; n_charts];
1332 let mut a1 = vec![0.0_f64; n_charts * d * p];
1333 let mut recon_c = vec![0.0_f64; n_charts * p];
1334 for (i, ch) in dev.charts.iter().enumerate() {
1335 centers[i * d..(i + 1) * d].copy_from_slice(&ch.center);
1336 radii[i] = ch.radius;
1337 cert_radii[i] = ch.certified_radius;
1338 lips[i] = ch.lipschitz;
1339 has_jac[i] = i32::from(ch.has_jacobian);
1340 if ch.has_jacobian {
1341 a1[i * d * p..(i + 1) * d * p].copy_from_slice(&ch.amortized_jacobian);
1342 }
1343 recon_c[i * p..(i + 1) * p].copy_from_slice(&ch.recon_center);
1344 }
1345 let mut tgt = vec![0.0_f64; n * p];
1346 for (i, x) in targets.iter().enumerate() {
1347 tgt[i * p..(i + 1) * p].copy_from_slice(x);
1348 }
1349
1350 let exps_dev = stream.clone_htod(&dev.exponents).gpu_ctx("sae_encode htod exps")?;
1351 let dec_dev = stream.clone_htod(&dev.decoder).gpu_ctx("sae_encode htod dec")?;
1352 let centers_dev = stream.clone_htod(¢ers).gpu_ctx("sae_encode htod centers")?;
1353 let radii_dev = stream.clone_htod(&radii).gpu_ctx("sae_encode htod radii")?;
1354 let cert_dev = stream.clone_htod(&cert_radii).gpu_ctx("sae_encode htod cert_radii")?;
1355 let lips_dev = stream.clone_htod(&lips).gpu_ctx("sae_encode htod lips")?;
1356 let hasj_dev = stream.clone_htod(&has_jac).gpu_ctx("sae_encode htod has_jac")?;
1357 let a1_dev = stream.clone_htod(&a1).gpu_ctx("sae_encode htod a1")?;
1358 let reconc_dev = stream.clone_htod(&recon_c).gpu_ctx("sae_encode htod recon_c")?;
1359 let tgt_dev = stream.clone_htod(&tgt).gpu_ctx("sae_encode htod targets")?;
1360 let amps_dev = stream.clone_htod(&litudes.to_vec()).gpu_ctx("sae_encode htod amps")?;
1361 let mut coords_dev = stream.alloc_zeros::<f64>(n * d).gpu_ctx("sae_encode alloc coords")?;
1362 let mut h_dev = stream.alloc_zeros::<f64>(n).gpu_ctx("sae_encode alloc h")?;
1363 let mut cert_out_dev = stream.alloc_zeros::<i32>(n).gpu_ctx("sae_encode alloc certified")?;
1364
1365 let n_i32 = i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sae_encode n overflow"))?;
1366 let ncharts_i32 =
1367 i32::try_from(n_charts).map_err(|_| gam_gpu::gpu_err!("sae_encode n_charts overflow"))?;
1368 let cfg = LaunchConfig {
1369 grid_dim: (n_i32 as u32, 1, 1),
1370 block_dim: (32, 1, 1),
1371 shared_mem_bytes: 0,
1372 };
1373 let mut builder = stream.launch_builder(&func);
1374 builder
1375 .arg(&exps_dev)
1376 .arg(&dec_dev)
1377 .arg(¢ers_dev)
1378 .arg(&radii_dev)
1379 .arg(&cert_dev)
1380 .arg(&lips_dev)
1381 .arg(&hasj_dev)
1382 .arg(&a1_dev)
1383 .arg(&reconc_dev)
1384 .arg(&ncharts_i32)
1385 .arg(&tgt_dev)
1386 .arg(&s_dev)
1387 .arg(&n_i32)
1388 .arg(&mut coords_dev)
1389 .arg(&mut h_dev)
1390 .arg(&mut cert_out_dev);
1391 unsafe { builder.launch(cfg) }.gpu_ctx("sae_encode kernel launch")?;
1395
1396 let mut coords = vec![0.0_f64; n * d];
1397 let mut h = vec![0.0_f64; n];
1398 let mut cert = vec![0_i32; n];
1399 stream.memcpy_dtoh(&coords_dev, &mut coords).gpu_ctx("sae_encode dtoh coords")?;
1400 stream.memcpy_dtoh(&h_dev, &mut h).gpu_ctx("sae_encode dtoh h")?;
1401 stream.memcpy_dtoh(&cert_out_dev, &mut cert).gpu_ctx("sae_encode dtoh certified")?;
1402 stream.synchronize().gpu_ctx("sae_encode synchronize")?;
1403
1404 let mut out = Vec::with_capacity(n);
1405 for row in 0..n {
1406 let coord = coords[row * d..(row + 1) * d].to_vec();
1407 let hv = h[row];
1408 out.push(DeviceEncodeRow {
1409 coord,
1410 cert: DeviceRowCertificate {
1411 beta: f64::NAN,
1413 eta: f64::NAN,
1414 lipschitz: f64::NAN,
1415 h: hv,
1416 },
1417 });
1418 }
1419 for (row, o) in out.iter_mut().enumerate() {
1422 if cert[row] == 0 {
1423 o.cert.h = f64::INFINITY;
1424 }
1425 }
1426 Ok(out)
1427 }
1428}
1429
1430#[cfg(test)]
1431mod tests {
1432 use super::*;
1433 use crate::basis::{EuclideanPatchEvaluator, SaeBasisEvaluator};
1434 use crate::encode::{AtlasConfig, EncodeAtlas};
1435 use crate::manifold::{SaeAtomBasisKind, SaeManifoldAtom};
1436 use ndarray::{Array1, Array2};
1437 use std::sync::Arc;
1438
1439 fn build_atom_and_atlas(
1443 d: usize,
1444 deg: usize,
1445 p: usize,
1446 config: AtlasConfig,
1447 ) -> (SaeManifoldAtom, EncodeAtlas) {
1448 let evaluator = Arc::new(EuclideanPatchEvaluator::new(d, deg).unwrap());
1449 let n_seed = 12usize;
1452 let coords = Array2::from_shape_fn((n_seed, d), |(r, c)| {
1453 0.15 * ((r as f64 + 1.0) * (c as f64 + 2.0) * 0.37).sin()
1454 });
1455 let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
1456 let m = phi.ncols();
1457 let decoder = Array2::from_shape_fn((m, p), |(bidx, c)| {
1459 (1.0 / (1.0 + bidx as f64)) * (((bidx as f64 + 1.0) * (c as f64 + 1.0)) * 0.3).cos()
1460 });
1461 let atom = SaeManifoldAtom::new(
1462 "euclid",
1463 SaeAtomBasisKind::EuclideanPatch,
1464 d,
1465 phi,
1466 jet,
1467 decoder,
1468 Array2::<f64>::eye(m),
1469 )
1470 .unwrap()
1471 .with_basis_second_jet(evaluator);
1472 let atlas = EncodeAtlas::build(&[atom.clone()], &[2.0], 8.0, config).unwrap();
1474 (atom, atlas)
1475 }
1476
1477 fn assert_parity(
1481 atom: &SaeManifoldAtom,
1482 atlas: &EncodeAtlas,
1483 dev: &EncodeAtomDevice,
1484 rows: &[Vec<f64>],
1485 amps: &[f64],
1486 ) -> (usize, usize, f64, f64) {
1487 let mut certified = 0usize;
1488 let mut max_coord = 0.0_f64;
1489 let mut max_h = 0.0_f64;
1490 for (x, &) in rows.iter().zip(amps.iter()) {
1491 let xv = Array1::from(x.clone());
1492 let (coord_p, cert_p) = atlas
1493 .certified_encode_row(atom, 0, xv.view(), amp)
1494 .expect("production encode runs");
1495 let emu = emulate_certified_encode_row(dev, x, amp);
1496 assert_eq!(
1497 cert_p.certified(),
1498 emu.cert.certified(),
1499 "certificate flag mismatch (prod h={}, emu h={})",
1500 cert_p.h,
1501 emu.cert.h
1502 );
1503 if cert_p.certified() {
1504 certified += 1;
1505 for axis in 0..dev.d {
1506 max_coord = max_coord.max((coord_p[axis] - emu.coord[axis]).abs());
1507 }
1508 max_h = max_h.max((cert_p.h - emu.cert.h).abs());
1509 }
1510 }
1511 (certified, rows.len(), max_coord, max_h)
1512 }
1513
/// Emulator/production parity on a 1-D degree-2 atom with 4 output channels.
///
/// Half of the rows are planted: decoded from the basis evaluated at a
/// coordinate in `[-0.4, 0.4]`. The other half are sinusoid rows that are
/// not built from the decoder.
#[test]
fn emulator_matches_production_certified_encode_1d_quadratic() {
    let (d, deg, p) = (1usize, 2usize, 4usize);
    let config = AtlasConfig::default();
    let (atom, atlas) = build_atom_and_atlas(d, deg, p, config);
    let atom_atlas = &atlas.atoms[0];
    let dev = EncodeAtomDevice::from_atom_atlas(&atom, atom_atlas, &config).unwrap();
    let mut rows: Vec<Vec<f64>> = Vec::new();
    let mut amps: Vec<f64> = Vec::new();
    let evaluator = EuclideanPatchEvaluator::new(d, deg).unwrap();
    // Planted rows: x = amp * phi(tc) * decoder for 24 evenly spaced tc.
    for k in 0..24 {
        let tc = -0.4 + 0.8 * (k as f64) / 23.0;
        let (phi, _) = evaluator
            .evaluate(Array2::from_shape_fn((1, d), |_| tc).view())
            .unwrap();
        let amp = 1.0;
        let mut x = vec![0.0; p];
        for c in 0..p {
            let mut r = 0.0;
            for b in 0..dev.m {
                r += phi[[0, b]] * dev.decoder[b * p + c];
            }
            x[c] = amp * r;
        }
        rows.push(x);
        amps.push(amp);
    }
    // Sinusoid rows with varying amplitudes, not derived from the decoder.
    for k in 0..24 {
        let x = (0..p)
            .map(|c| 0.5 * (((k * 7 + c * 3) as f64) * 0.21).sin())
            .collect();
        rows.push(x);
        amps.push(0.7 + 0.3 * ((k as f64) * 0.11).cos());
    }
    let (cert, total, max_coord, max_h) = assert_parity(&atom, &atlas, &dev, &rows, &amps);
    eprintln!(
        "1D quadratic: certified {cert}/{total}, max coord diff {max_coord:.3e}, max h diff {max_h:.3e}"
    );
    assert!(cert > 0, "planted rows must certify through the encode");
    assert!(max_coord <= 1e-7, "coord parity {max_coord:.3e} > 1e-7");
    assert!(max_h <= 1e-7, "certificate h parity {max_h:.3e} > 1e-7");
}
1559
/// Emulator/production parity on a 2-D degree-2 atom with 5 output channels.
///
/// Uses an atlas with `grid_resolution: 6`. Thirty rows are planted on a
/// 6 x 5 coordinate grid inside `[-0.3, 0.3]^2`; twenty more are cosine rows
/// that are not built from the decoder.
#[test]
fn emulator_matches_production_certified_encode_2d_quadratic() {
    let (d, deg, p) = (2usize, 2usize, 5usize);
    let config = AtlasConfig {
        grid_resolution: 6,
        ..AtlasConfig::default()
    };
    let (atom, atlas) = build_atom_and_atlas(d, deg, p, config);
    let atom_atlas = &atlas.atoms[0];
    let dev = EncodeAtomDevice::from_atom_atlas(&atom, atom_atlas, &config).unwrap();
    let evaluator = EuclideanPatchEvaluator::new(d, deg).unwrap();
    let mut rows: Vec<Vec<f64>> = Vec::new();
    let mut amps: Vec<f64> = Vec::new();
    // Planted rows: x = amp * phi(t0, t1) * decoder on a coordinate grid.
    for k in 0..30 {
        let t0 = -0.3 + 0.6 * ((k % 6) as f64) / 5.0;
        let t1 = -0.3 + 0.6 * ((k / 6) as f64) / 5.0;
        let coord = Array2::from_shape_fn((1, d), |(_, c)| if c == 0 { t0 } else { t1 });
        let (phi, _) = evaluator.evaluate(coord.view()).unwrap();
        let amp = 1.0;
        let mut x = vec![0.0; p];
        for c in 0..p {
            let mut r = 0.0;
            for b in 0..dev.m {
                r += phi[[0, b]] * dev.decoder[b * p + c];
            }
            x[c] = amp * r;
        }
        rows.push(x);
        amps.push(amp);
    }
    // Cosine rows, not derived from the decoder.
    for k in 0..20 {
        let x = (0..p)
            .map(|c| 0.4 * (((k * 5 + c * 2) as f64) * 0.17).cos())
            .collect();
        rows.push(x);
        amps.push(1.0);
    }
    let (cert, total, max_coord, max_h) = assert_parity(&atom, &atlas, &dev, &rows, &amps);
    eprintln!(
        "2D quadratic: certified {cert}/{total}, max coord diff {max_coord:.3e}, max h diff {max_h:.3e}"
    );
    assert!(cert > 0, "planted 2D rows must certify");
    assert!(max_coord <= 1e-6, "coord parity {max_coord:.3e} > 1e-6");
    assert!(max_h <= 1e-6, "certificate h parity {max_h:.3e} > 1e-6");
}
1605
/// The batch entry point must agree with the single-row emulator and with
/// the production encoder on every row's certified flag.
///
/// A 40-row batch is also asserted to report `EncodePath::Cpu`.
#[test]
fn emulator_matches_production_batch() {
    let (d, deg, p) = (1usize, 3usize, 3usize);
    let config = AtlasConfig::default();
    let (atom, atlas) = build_atom_and_atlas(d, deg, p, config);
    let dev = EncodeAtomDevice::from_atom_atlas(&atom, &atlas.atoms[0], &config).unwrap();
    let n = 40usize;
    let rows: Vec<Vec<f64>> = (0..n)
        .map(|k| (0..p).map(|c| 0.3 * (((k + c) as f64) * 0.19).sin()).collect())
        .collect();
    let amps: Vec<f64> = (0..n).map(|_| 1.0).collect();
    let (batch, path) = sae_certified_encode_batch(&dev, &rows, &amps);
    assert_eq!(path, EncodePath::Cpu, "small batch stays on CPU");
    for (k, r) in batch.iter().enumerate() {
        // Batch result versus the single-row emulator.
        let single = emulate_certified_encode_row(&dev, &rows[k], amps[k]);
        assert_eq!(r.cert.certified(), single.cert.certified());
        // Batch result versus the production encoder.
        let xv = Array1::from(rows[k].clone());
        let (_, cert_p) = atlas
            .certified_encode_row(&atom, 0, xv.view(), amps[k])
            .unwrap();
        assert_eq!(
            cert_p.certified(),
            r.cert.certified(),
            "batch row {k} certificate flag disagrees with production"
        );
    }
}
1634
/// The throughput measurement must produce a positive rate, and its
/// deployment decision must match the path that actually ran.
///
/// - Device path engaged: the decision is determined, and is "surrogate
///   unneeded" or "surrogate justified" depending on whether the measured
///   rate reaches `GPU_THROUGHPUT_TARGET_ROWS_PER_SEC`.
/// - Otherwise: the decision stays undetermined.
#[test]
fn device_encode_throughput_gates_surrogate_on_measurement() {
    let (d, deg, p) = (1usize, 2usize, 4usize);
    let config = AtlasConfig::default();
    let (atom, atlas) = build_atom_and_atlas(d, deg, p, config);
    let dev = EncodeAtomDevice::from_atom_atlas(&atom, &atlas.atoms[0], &config).unwrap();

    // Batch larger than DEVICE_ROW_THRESHOLD (presumably the batch
    // dispatcher's CPU/device cutoff — confirm against its definition).
    let n = DEVICE_ROW_THRESHOLD + 64;
    let evaluator = EuclideanPatchEvaluator::new(d, deg).unwrap();
    let mut rows: Vec<Vec<f64>> = Vec::with_capacity(n);
    let mut amps: Vec<f64> = Vec::with_capacity(n);
    for k in 0..n {
        if k % 2 == 0 {
            // Even rows are planted: decoded from the basis at `tc`.
            let tc = -0.4 + 0.8 * ((k % 24) as f64) / 23.0;
            let (phi, _) = evaluator
                .evaluate(Array2::from_shape_fn((1, d), |_| tc).view())
                .unwrap();
            let x = (0..p)
                .map(|c| {
                    (0..dev.m)
                        .map(|b| phi[[0, b]] * dev.decoder[b * p + c])
                        .sum::<f64>()
                })
                .collect();
            rows.push(x);
            amps.push(1.0);
        } else {
            // Odd rows are sinusoid rows, not derived from the decoder.
            let x = (0..p)
                .map(|c| 0.5 * (((k * 7 + c * 3) as f64) * 0.021).sin())
                .collect();
            rows.push(x);
            amps.push(1.0);
        }
    }

    let tput = measure_device_encode_throughput(&dev, &rows, &amps);
    eprintln!(
        "[device-encode #988] n={} rows/sec={:.1} path={:?} decision={:?}",
        tput.n_rows, tput.rows_per_sec, tput.path, tput.decision
    );

    assert!(
        tput.rows_per_sec > 0.0,
        "the exact encode benchmark must produce a positive rows/sec, got {}",
        tput.rows_per_sec
    );
    assert_eq!(tput.device_engaged(), matches!(tput.path, EncodePath::Device));

    let (batch, _) = sae_certified_encode_batch(&dev, &rows, &amps);
    let certified = batch.iter().filter(|r| r.cert.certified()).count();
    // The condition only requires one certified row; the message says the same.
    assert!(
        certified > 0,
        "the exact encode must certify at least one planted row; certified={certified}/{n}"
    );

    if tput.device_engaged() {
        assert!(
            !tput.decision.is_undetermined(),
            "an engaged device measurement must decide Met/Unmet, got {:?}",
            tput.decision
        );
        let target = gam_gpu::policy::GPU_THROUGHPUT_TARGET_ROWS_PER_SEC;
        if tput.rows_per_sec >= target {
            assert!(
                tput.decision.surrogate_unneeded(),
                "device rate {:.1} >= target {target} must mark the surrogate unneeded",
                tput.rows_per_sec
            );
        } else {
            assert!(
                tput.decision.surrogate_justified(),
                "device rate {:.1} < target {target} must justify the surrogate",
                tput.rows_per_sec
            );
        }
    } else {
        assert!(
            tput.decision.is_undetermined(),
            "a CPU-emulator exact encode must leave the surrogate decision Undetermined, got {:?}",
            tput.decision
        );
        assert!(!tput.decision.surrogate_unneeded());
        assert!(!tput.decision.surrogate_justified());
    }
}
1743
/// `jacobi_eigh` on the symmetric matrix `[[4, 1], [1, 3]]` must reproduce
/// the matrix from its eigenpairs and return the closed-form eigenvalues
/// `(7 ± sqrt(5)) / 2`.
#[test]
fn jacobi_eigh_matches_reference_2x2() {
    let a = [4.0, 1.0, 1.0, 3.0];
    let mut vals = [0.0; 2];
    let mut vecs = [0.0; 4];
    jacobi_eigh(&a, 2, &mut vals, &mut vecs);

    // Rebuild each entry as sum_k vals[k] * vecs[k][r] * vecs[k][c]
    // (eigenvector k is stored in row k of `vecs`).
    for r in 0..2 {
        for c in 0..2 {
            let rebuilt = (0..2).fold(0.0, |acc, k| {
                acc + vals[k] * vecs[k * 2 + r] * vecs[k * 2 + c]
            });
            assert!(
                (rebuilt - a[r * 2 + c]).abs() < 1e-12,
                "eig reconstruct {r},{c}"
            );
        }
    }

    // Compare the ascending eigenvalues with the closed form.
    let mut ascending = vals.to_vec();
    ascending.sort_by(|x, y| x.partial_cmp(y).unwrap());
    let root5 = 5.0_f64.sqrt();
    let expected = [(7.0 - root5) / 2.0, (7.0 + root5) / 2.0];
    for (got, want) in ascending.iter().zip(expected) {
        assert!((got - want).abs() < 1e-12);
    }
}
1767
/// The generated kernel source must carry the `DD`/`MM`/`PP` defines for
/// the device shape, name the encode entry point, and compile to PTX that
/// exports that entry point with a target architecture.
#[cfg(target_os = "linux")]
#[test]
fn encode_kernel_source_substitutes_macros_and_compiles() {
    let config = AtlasConfig::default();
    let (atom, atlas) = build_atom_and_atlas(1, 2, 4, config);
    let dev = EncodeAtomDevice::from_atom_atlas(&atom, &atlas.atoms[0], &config).unwrap();

    let src = encode_kernel_source(&dev);
    for (macro_name, value) in [("DD", dev.d), ("MM", dev.m), ("PP", dev.p)] {
        let define = format!("#define {macro_name} {value}");
        assert!(src.contains(&define));
    }
    assert!(src.contains("sae_certified_encode"));

    let ptx_text = gam_gpu::device_cache::compile_ptx_arch(&src)
        .expect("sae_encode kernel compiles to PTX via NVRTC")
        .to_src();
    assert!(
        ptx_text.contains(".visible .entry sae_certified_encode"),
        "PTX must export the encode entry"
    );
    assert!(
        ptx_text.contains(".target sm_"),
        "PTX must carry a target arch"
    );
}
1788
/// When a GPU runtime is available, the device kernel must agree with the
/// CPU emulator: identical certified flags on every row, and coordinates
/// within `1e-9` on certified rows. Without a GPU runtime only the emulator
/// runs and nothing is compared.
#[cfg(target_os = "linux")]
#[test]
fn device_matches_emulator_when_available() {
    let (d, deg, p) = (1usize, 2usize, 4usize);
    let config = AtlasConfig::default();
    let (atom, atlas) = build_atom_and_atlas(d, deg, p, config);
    let dev = EncodeAtomDevice::from_atom_atlas(&atom, &atlas.atoms[0], &config).unwrap();
    let n = DEVICE_ROW_THRESHOLD + 64;
    let rows: Vec<Vec<f64>> = (0..n)
        .map(|k| (0..p).map(|c| 0.3 * (((k + c) as f64) * 0.019).sin()).collect())
        .collect();
    let amps = vec![1.0; n];
    let cpu = emulate_certified_encode_batch(&dev, &rows, &amps);
    if gam_gpu::device_runtime::GpuRuntime::global().is_some() {
        let devout = device::sae_certified_encode_device(&dev, &rows, &amps)
            .expect("admitted GPU runtime must run the sae_encode kernel");
        // `zip` would silently ignore a short device result; require equal lengths.
        assert_eq!(cpu.len(), devout.len(), "device row count");
        let mut max_coord = 0.0_f64;
        for (a, b) in cpu.iter().zip(devout.iter()) {
            assert_eq!(a.cert.certified(), b.cert.certified(), "device certified flag");
            if a.cert.certified() {
                for axis in 0..dev.d {
                    max_coord = max_coord.max((a.coord[axis] - b.coord[axis]).abs());
                }
            }
        }
        assert!(max_coord <= 1e-9, "device vs emulator coord diff {max_coord:.3e} > 1e-9");
    }
}
1817}
1818