use std::sync::Arc;
use std::time::Instant;
use ndarray::{Array1, Array2};
use gam_gpu::device_runtime::GpuRuntime;
use gam_gpu::encode_throughput::{encode_quality_metrics, FullEncodeThroughput};
use gam_gpu::policy::EncodeDeploymentDecision;
use gam_gpu::{GpuError, GpuMode};
use gam_sae::basis::{PeriodicHarmonicEvaluator, SaeBasisEvaluator};
use gam_sae::encode::{AtlasConfig, EncodeAtlas};
use gam_sae::manifold::{SaeAtomBasisKind, SaeManifoldAtom};
/// Builds two deterministic, mutually orthogonal unit vectors of length `p`.
///
/// The first vector is a sampled cosine scaled to unit Euclidean norm. The
/// second is a sampled sine with its component along the first removed (a
/// single Gram–Schmidt step) and then scaled to unit norm.
///
/// There is no guard against a zero norm, so `p` must be large enough that
/// both sampled vectors are non-zero and linearly independent.
fn orthonormal_pair(p: usize) -> (Array1<f64>, Array1<f64>) {
    let cos_samples = Array1::from_shape_fn(p, |j| (0.3 * j as f64 + 0.1).cos());
    let sin_samples = Array1::from_shape_fn(p, |j| (0.2 * j as f64 + 0.7).sin());

    // Normalize the first direction.
    let first_norm = cos_samples.dot(&cos_samples).sqrt();
    let first = cos_samples.mapv(|x| x / first_norm);

    // Strip the part of the second direction that lies along the first.
    let overlap = sin_samples.dot(&first);
    let residual = &sin_samples - &(&first * overlap);

    // Normalize what is left.
    let residual_norm = residual.dot(&residual).sqrt();
    let second = residual.mapv(|x| x / residual_norm);

    (first, second)
}
/// Builds the planted single-atom fixture shared by the encode test.
///
/// A periodic atom named `"circle"` with latent dimension 1 is created whose
/// decoder is non-zero only in basis rows 2 and 1, which hold the two
/// orthonormal directions from [`orthonormal_pair`] (presumably the
/// first-harmonic pair — confirm against `PeriodicHarmonicEvaluator`'s column
/// order). `n` targets in `p` output dimensions are then synthesized by
/// decoding evenly spaced latent positions and scaling each row by a
/// per-row amplitude in `[0.8, 1.2]`.
///
/// Returns `(atom, atlas, targets, amplitudes, planted_t)` where `atlas` is
/// built over the single atom using the largest amplitude and the largest
/// target row norm as bounds.
///
/// # Panics
///
/// Panics if the evaluator, atom, or atlas construction fails.
fn build_fixture(
    n: usize,
    p: usize,
) -> (
    SaeManifoldAtom,
    EncodeAtlas,
    Array2<f64>,
    Array1<f64>,
    Array1<f64>,
) {
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(5).unwrap());

    // Seed basis values (and jet) on a uniform 64-point grid over [0, 1).
    let n_seed = 64usize;
    let seed_coords: Array2<f64> =
        Array2::from_shape_fn((n_seed, 1), |(i, _)| i as f64 / n_seed as f64);
    let (seed_phi, seed_jet) = evaluator.evaluate(seed_coords.view()).unwrap();
    let m = seed_phi.ncols();

    // Decoder: all zeros except basis rows 2 and 1, which carry `u` and `v`.
    let (u, v) = orthonormal_pair(p);
    let mut decoder = Array2::<f64>::zeros((m, p));
    for (c, (&u_c, &v_c)) in u.iter().zip(v.iter()).enumerate() {
        decoder[[2, c]] = u_c;
        decoder[[1, c]] = v_c;
    }

    // `evaluator.clone()` (method syntax) is kept so the `Arc` can coerce to
    // whatever handle type `with_basis_evaluator` expects.
    let atom = SaeManifoldAtom::new(
        "circle",
        SaeAtomBasisKind::Periodic,
        1,
        seed_phi,
        seed_jet,
        decoder.clone(),
        Array2::<f64>::eye(m),
    )
    .unwrap()
    .with_basis_evaluator(evaluator.clone());

    // Planted latent positions at cell midpoints of an n-cell grid on [0, 1),
    // with a smoothly varying amplitude per row.
    let planted_t: Array1<f64> = Array1::from_shape_fn(n, |i| (i as f64 + 0.5) / n as f64);
    let amplitudes: Array1<f64> =
        Array1::from_shape_fn(n, |i| 0.8 + 0.4 * ((i as f64 * 0.123).sin() * 0.5 + 0.5));
    let coords_truth = Array2::from_shape_fn((n, 1), |(i, _)| planted_t[i]);
    let (phi_truth, _) = evaluator.evaluate(coords_truth.view()).unwrap();

    // Targets are the decoded basis rows, each scaled by its own amplitude.
    let mut targets = phi_truth.dot(&decoder);
    for (i, &z) in amplitudes.iter().enumerate() {
        targets.row_mut(i).mapv_inplace(|x| x * z);
    }

    // Bounds handed to the atlas: largest amplitude and largest target row norm.
    let amplitude_bound = amplitudes
        .iter()
        .copied()
        .fold(0.0_f64, |acc, a| acc.max(a));
    let target_norm_bound = (0..n)
        .map(|i| {
            let row = targets.row(i);
            row.dot(&row).sqrt()
        })
        .fold(0.0_f64, f64::max);

    let atlas = EncodeAtlas::build(
        std::slice::from_ref(&atom),
        &[amplitude_bound],
        target_norm_bound,
        AtlasConfig {
            grid_resolution: 64,
            ridge: 1e-10,
            newton_steps: 8,
        },
    )
    .expect("encode atlas builds over the frozen dictionary");

    (atom, atlas, targets, amplitudes, planted_t)
}
/// Decodes latent `coords` through `atom` and scales each output row by the
/// matching entry of `amplitudes`.
///
/// The basis is evaluated at `coords`, multiplied by the atom's decoder
/// coefficients, and row `i` of the product is multiplied by `amplitudes[i]`.
///
/// # Panics
///
/// Panics if the atom has no basis evaluator attached, if basis evaluation
/// fails, or if `amplitudes` has fewer entries than `coords` has rows.
fn reconstruct(
    atom: &SaeManifoldAtom,
    coords: &Array2<f64>,
    amplitudes: &Array1<f64>,
) -> Array2<f64> {
    let basis = atom.basis_evaluator.as_ref().expect("atom has evaluator");
    let (phi, _) = basis.evaluate(coords.view()).unwrap();

    let mut decoded = phi.dot(&atom.decoder_coefficients);
    for i in 0..coords.nrows() {
        let scale = amplitudes[i];
        decoded.row_mut(i).mapv_inplace(|x| x * scale);
    }
    decoded
}
/// End-to-end check of the batched certified encode on the planted fixture.
///
/// The test:
/// 1. encodes every row one at a time to obtain reference coordinates and
///    certificate flags;
/// 2. runs the batched encode once untimed, then once timed, and derives a
///    rows/sec figure from the timed run;
/// 3. asserts the deployment decision derived from that (non-device)
///    measurement is undetermined;
/// 4. compares the batched result against the per-row reference and against
///    the original targets via `encode_quality_metrics`;
/// 5. asserts how `GpuRuntime::global_or_fail` behaves for `GpuMode::Required`
///    and `GpuMode::Off`.
#[test]
fn full_exact_encode_throughput_and_correctness() {
    let n = 4_096usize;
    let p = 64usize;
    let (atom, atlas, targets, amplitudes, _planted_t) = build_fixture(n, p);

    // Per-row reference encode.
    let mut coords_ref = Array2::<f64>::zeros((n, atom.latent_dim));
    let mut certified_ref = vec![false; n];
    for i in 0..n {
        let (t, cert) = atlas
            .certified_encode_row(&atom, 0, targets.row(i), amplitudes[i])
            .expect("per-row reference encode runs");
        coords_ref.row_mut(i).assign(&t);
        certified_ref[i] = cert.certified();
    }

    // Untimed warm-up pass; its result is intentionally discarded.
    atlas
        .certified_encode_batch(&atom, 0, targets.view(), amplitudes.view())
        .expect("warm batch encode runs");

    // Timed pass.
    let start = Instant::now();
    let result = atlas
        .certified_encode_batch(&atom, 0, targets.view(), amplitudes.view())
        .expect("timed batch encode runs");
    let elapsed = start.elapsed();

    // The `false` is expected to surface as `device_encode_engaged == false`,
    // which the next assertion pins.
    let throughput = FullEncodeThroughput::from_elapsed(n, elapsed, false);
    assert!(!throughput.device_encode_engaged);
    assert!(
        throughput.rows_per_sec > 0.0,
        "the full encode must produce a positive rows/sec, got {}",
        throughput.rows_per_sec
    );

    // A measurement taken without the device engaged must not settle the
    // deployment decision either way.
    let decision = EncodeDeploymentDecision::from_device_measurement(
        throughput.device_encode_engaged,
        throughput.rows_per_sec,
    );
    eprintln!("[full-encode] deployment decision (device-only tri-state): {decision:?}");
    assert!(
        decision.is_undetermined(),
        "#988/#1412: with no device-resident exact-encode kernel the full-encode deployment \
         decision must be Undetermined (BLOCKED), got {decision:?}"
    );
    assert!(
        !decision.surrogate_unneeded() && !decision.surrogate_justified(),
        "#988/#1412: a CPU full-encode measurement must neither certify the target nor refute it — \
         the surrogate decision is BLOCKED on a device measurement, got {decision:?}"
    );

    // Quality of the batched result versus the per-row reference and the targets.
    // Fix: the argument here was the undefined identifier `litudes`; the
    // fixture's `amplitudes` is what `reconstruct` needs.
    let reconstruction = reconstruct(&atom, &result.coords, &amplitudes);
    let metrics = encode_quality_metrics(
        result.coords.view(),
        &result.certified,
        coords_ref.view(),
        &certified_ref,
        reconstruction.view(),
        targets.view(),
    );
    eprintln!(
        "[full-encode] n={n} p={p} rows/sec={:.1} (device_engaged={}) | \
         certified={}/{} fallback_rate={:.3} support_agreement={:.6} \
         max_coord_err={:.3e} reconstruction_ev={:.6} max_recon_err={:.3e}",
        throughput.rows_per_sec,
        throughput.device_encode_engaged,
        metrics.certified_rows,
        n,
        metrics.fallback_rate,
        metrics.support_agreement,
        metrics.max_coord_abs_err,
        metrics.reconstruction_ev,
        metrics.max_reconstruction_abs_err,
    );
    assert_eq!(
        metrics.support_agreement, 1.0,
        "batched encode certificate flags must match the per-row reference on every row"
    );
    assert!(
        metrics.max_coord_abs_err < 1e-12,
        "batched encode coordinates must match the per-row reference to round-off; \
         max |Δcoord| = {:.3e}",
        metrics.max_coord_abs_err
    );
    assert!(
        metrics.reconstruction_ev > 0.99,
        "exact encode must reconstruct on-manifold targets (EV > 0.99); got {:.6}",
        metrics.reconstruction_ev
    );
    assert!(
        metrics.max_reconstruction_abs_err < 1e-2,
        "worst per-element reconstruction residual too large: {:.3e}",
        metrics.max_reconstruction_abs_err
    );
    assert!(
        metrics.fallback_rate < 0.5,
        "the certified encode must certify a majority of a well-conditioned circle \
         dictionary; fallback_rate = {:.3}",
        metrics.fallback_rate
    );

    // Runtime mode contract: `Required` succeeds exactly when a device is
    // available, and `Off` never yields a runtime.
    let required = GpuRuntime::global_or_fail(GpuMode::Required);
    if GpuRuntime::is_available() {
        assert!(
            required.is_ok(),
            "GpuMode::Required must succeed when a device is present"
        );
    } else {
        assert!(
            matches!(required, Err(GpuError::DriverLibraryUnavailable { .. })),
            "GpuMode::Required must fail closed when the device is absent, got {required:?}"
        );
    }
    assert!(GpuRuntime::global_or_fail(GpuMode::Off).is_err());
}