use std::hint::black_box;
use std::time::{Duration, Instant};
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use super::linalg_dispatch::ResidentDesignGram;
use super::policy::{GpuThroughputVerdict, GPU_THROUGHPUT_TARGET_ROWS_PER_SEC};
/// Dimensions of one benchmark problem: a labelled design with `n` rows and `p` columns.
#[derive(Clone, Copy, Debug)]
pub struct EncodeShape {
/// Human-readable identifier for this shape.
pub label: &'static str,
/// Number of design rows; throughput in this file is reported in these rows per second.
pub n: usize,
/// Number of design columns (also the length of the right-hand side used by the benchmark).
pub p: usize,
}
/// Built-in table of benchmark shapes.
///
/// Each label has the form `sae-<rows>-<columns>`, mirroring the entry's `n` and `p`.
pub const CANONICAL_ENCODE_SHAPES: &[EncodeShape] = &[
EncodeShape {
label: "sae-2k-2048",
n: 2_000,
p: 2_048,
},
EncodeShape {
label: "sae-4k-4096",
n: 4_000,
p: 4_096,
},
EncodeShape {
label: "sae-8k-1024",
n: 8_000,
p: 1_024,
},
];
/// Outcome of timing the resident normal-equations solve for one [`EncodeShape`].
#[derive(Clone, Copy, Debug)]
pub struct ResidentSolveThroughput {
/// The shape that was measured.
pub shape: EncodeShape,
/// `true` only when every solve succeeded and a positive rate was measured.
pub engaged: bool,
/// `shape.n` divided by the mean seconds per timed solve; `0.0` when not engaged.
pub measured_rows_per_sec: f64,
/// Result of `GpuThroughputVerdict::from_measurement` applied to `measured_rows_per_sec`.
pub verdict: GpuThroughputVerdict,
}
/// Steps a 64-bit linear congruential generator and maps the new state into `[-1, 1)`.
///
/// `state` is overwritten with the next generator state (wrapping arithmetic), and the
/// top 53 bits of that state are turned into the returned sample.
fn lcg(state: &mut u64) -> f64 {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;
    // 2^53: dividing the 53 retained bits by this yields a value in [0, 1).
    const TWO_POW_53: f64 = (1u64 << 53) as f64;

    let next = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *state = next;

    let unit_interval = (next >> 11) as f64 / TWO_POW_53;
    unit_interval * 2.0 - 1.0
}
/// Builds a deterministic `n`-by-`p` design whose entries are `lcg` samples scaled by 0.05.
///
/// The generator is seeded with `seed`, so equal arguments always give the same matrix.
fn planted_design(n: usize, p: usize, seed: u64) -> Array2<f64> {
    const AMPLITUDE: f64 = 0.05;

    let mut rng_state = seed;
    // The element index is ignored: values depend only on the generator sequence.
    let draw = |_: (usize, usize)| AMPLITUDE * lcg(&mut rng_state);
    Array2::from_shape_fn((n, p), draw)
}
/// Times repeated resident normal-equations solves for `shape` and reports rows per second.
///
/// A deterministic design, weight vector and right-hand side are generated, a
/// [`ResidentDesignGram`] handle is created for the design, one untimed solve is run, and
/// then `reps` solves with slightly shifted weights are timed. The reported rate is
/// `shape.n` divided by the mean time of the timed solves.
///
/// A not-engaged result (rate `0.0`) is returned when `shape.n`, `shape.p` or `reps` is
/// zero, when the handle cannot be created, or when any solve returns `None`.
#[must_use]
pub fn measure_resident_solve_throughput(shape: EncodeShape, reps: usize) -> ResidentSolveThroughput {
    /// Result reported whenever no measurement could be taken.
    fn disengaged(shape: EncodeShape) -> ResidentSolveThroughput {
        ResidentSolveThroughput {
            shape,
            engaged: false,
            measured_rows_per_sec: 0.0,
            verdict: GpuThroughputVerdict::from_measurement(0.0),
        }
    }

    let rows = shape.n;
    let cols = shape.p;
    if rows == 0 || cols == 0 || reps == 0 {
        return disengaged(shape);
    }

    // Deterministic problem data: design, strictly positive weights, smooth right-hand side.
    let design = planted_design(rows, cols, 0x1412_a100_dead_beef);
    let mut weight_state = 0x988_5ae_e0c0_de01u64;
    let base_weights = Array1::from_shape_fn(rows, |_| lcg(&mut weight_state).abs() + 1e-3);
    let rhs = Array1::from_shape_fn(cols, |j| ((j as f64 + 1.0) * 0.03).cos());
    let ridge = 1e-3_f64;

    let Some(handle) = ResidentDesignGram::try_new(design.view()) else {
        return disengaged(shape);
    };

    // One untimed solve before the timed loop; a failure here means nothing to measure.
    if handle
        .solve_normal_equations(base_weights.view(), rhs.view(), ridge)
        .is_none()
    {
        return disengaged(shape);
    }

    let mut elapsed = Duration::ZERO;
    for rep in 0..reps {
        // Weights are rebuilt outside the timed region and differ on every repetition.
        let shift = 1e-3 * (rep as f64);
        let weights = base_weights.mapv(|base| (base + shift).abs());

        let clock = Instant::now();
        let Some(beta) = handle.solve_normal_equations(weights.view(), rhs.view(), ridge) else {
            return disengaged(shape);
        };
        // Keep the optimizer from discarding the solve result.
        black_box(beta);
        elapsed += clock.elapsed();
    }

    let mean_secs = elapsed.as_secs_f64() / reps as f64;
    let measured_rows_per_sec = if mean_secs > 0.0 {
        rows as f64 / mean_secs
    } else {
        0.0
    };

    ResidentSolveThroughput {
        shape,
        engaged: measured_rows_per_sec > 0.0,
        measured_rows_per_sec,
        verdict: GpuThroughputVerdict::from_measurement(measured_rows_per_sec),
    }
}
/// Reference CPU solve of the ridge-regularised weighted normal equations
/// `(Xᵀ diag(w) X + ridge·I) β = rhs`.
///
/// The Gram matrix is formed from the rows of `x` scaled by `sqrt(w)`, `ridge` is added to
/// its diagonal, and the system is solved with an in-place lower Cholesky factorisation
/// followed by forward and backward substitution.
///
/// * `x` — design matrix with `n` rows and `p` columns.
/// * `w` — one weight per design row (`n` entries).
/// * `rhs` — right-hand side with `p` entries.
/// * `ridge` — value added to every diagonal entry of the Gram matrix.
///
/// Returns the solution vector `β` of length `p`.
///
/// # Panics
///
/// Panics when `w` or `rhs` has the wrong length, or when a Cholesky pivot is not strictly
/// positive (the regularised Gram matrix is not positive-definite).
#[must_use]
pub fn cpu_oracle_normal_equations_solve(
    x: ArrayView2<'_, f64>,
    w: ArrayView1<'_, f64>,
    rhs: ArrayView1<'_, f64>,
    ridge: f64,
) -> Array1<f64> {
    let (n, p) = x.dim();
    assert_eq!(w.len(), n, "w must have one entry per design row");
    assert_eq!(rhs.len(), p, "rhs must have one entry per border column");

    // Scale row i by sqrt(w_i) so the plain Gram of the result is Xᵀ diag(w) X.
    let mut weighted = x.to_owned();
    for (mut row, &weight) in weighted.outer_iter_mut().zip(w.iter()) {
        let scale = weight.sqrt();
        row.mapv_inplace(|entry| entry * scale);
    }

    let mut gram = weighted.t().dot(&weighted);
    gram.diag_mut().mapv_inplace(|g| g + ridge);

    // Column-by-column Cholesky: `lower` ends up holding L with gram = L Lᵀ.
    let mut lower = Array2::<f64>::zeros((p, p));
    for j in 0..p {
        let diag = (0..j).fold(gram[[j, j]], |acc, s| acc - lower[[j, s]] * lower[[j, s]]);
        assert!(
            diag > 0.0,
            "cpu_oracle: non-positive Cholesky pivot {diag:.3e} at index {j} — \
             the Gram is not positive-definite (need ridge>0 and w>0)"
        );
        let pivot = diag.sqrt();
        lower[[j, j]] = pivot;
        for i in (j + 1)..p {
            let partial = (0..j).fold(gram[[i, j]], |acc, s| acc - lower[[i, s]] * lower[[j, s]]);
            lower[[i, j]] = partial / pivot;
        }
    }

    // Forward substitution: L y = rhs, overwriting the buffer in place.
    let mut solution = rhs.to_owned();
    for i in 0..p {
        let reduced = (0..i).fold(solution[i], |acc, s| acc - lower[[i, s]] * solution[s]);
        solution[i] = reduced / lower[[i, i]];
    }

    // Backward substitution: Lᵀ β = y, reusing the same buffer.
    for i in (0..p).rev() {
        let reduced = ((i + 1)..p).fold(solution[i], |acc, s| acc - lower[[s, i]] * solution[s]);
        solution[i] = reduced / lower[[i, i]];
    }

    solution
}
/// Deployment throughput target in rows per second; re-exports the policy module's
/// `GPU_THROUGHPUT_TARGET_ROWS_PER_SEC` under a second name.
pub const DEPLOYMENT_TARGET_ROWS_PER_SEC: f64 = GPU_THROUGHPUT_TARGET_ROWS_PER_SEC;
/// Wall-clock throughput of one full encode pass.
#[derive(Clone, Copy, Debug)]
pub struct FullEncodeThroughput {
    /// Number of rows that were encoded.
    pub n_rows: usize,
    /// Elapsed encode time in seconds.
    pub encode_secs: f64,
    /// `n_rows / encode_secs`, or `0.0` when either quantity is zero.
    pub rows_per_sec: f64,
    /// Caller-supplied flag recording whether the device encode path was engaged.
    pub device_encode_engaged: bool,
}

impl FullEncodeThroughput {
    /// Derives a throughput record from a row count and the elapsed wall-clock time.
    ///
    /// The rate is reported as `0.0` when `n_rows` is zero or `elapsed` is zero, so the
    /// result never contains an infinite or undefined rate.
    #[must_use]
    pub fn from_elapsed(n_rows: usize, elapsed: Duration, device_encode_engaged: bool) -> Self {
        let encode_secs = elapsed.as_secs_f64();
        let rows_per_sec = match (n_rows, encode_secs > 0.0) {
            // No rows, or no measurable time: there is no meaningful rate.
            (0, _) | (_, false) => 0.0,
            (rows, true) => rows as f64 / encode_secs,
        };

        Self {
            n_rows,
            encode_secs,
            rows_per_sec,
            device_encode_engaged,
        }
    }
}
/// Quality summary comparing an encode under test against a reference encode, as filled in
/// by `encode_quality_metrics`.
#[derive(Clone, Copy, Debug)]
pub struct EncodeQualityMetrics {
/// Number of rows that were compared.
pub n_rows: usize,
/// Number of rows whose under-test certified flag is `true`.
pub certified_rows: usize,
/// `1 - certified_rows / n_rows`; `0.0` when there are no rows.
pub fallback_rate: f64,
/// Fraction of rows whose certified flag equals the reference flag; `1.0` when there are no rows.
pub support_agreement: f64,
/// Largest absolute elementwise difference between the under-test and reference coordinates.
pub max_coord_abs_err: f64,
/// Largest absolute elementwise difference between the reconstruction and the targets.
pub max_reconstruction_abs_err: f64,
/// `1 - SS_res / SS_tot`, with `SS_tot` taken around each target column's mean.
pub reconstruction_ev: f64,
}
/// Compares an encode under test against a reference encode and its reconstruction targets.
///
/// * `coords` / `certified` — coordinates and per-row certified flags of the encode under test.
/// * `coords_ref` / `certified_ref` — the same quantities from the reference encode.
/// * `reconstruction` — reconstruction produced by the encode under test.
/// * `targets` — values the reconstruction is compared against.
///
/// Returns counts, agreement fractions, the two maximum absolute errors, and the explained
/// variance `1 - SS_res / SS_tot`, where `SS_tot` is accumulated around each target column's
/// mean. When `SS_tot` is not positive the explained variance is `1.0` for a zero residual
/// and `0.0` otherwise.
///
/// A NaN anywhere in a compared pair makes the corresponding maximum error NaN. `f64::max`
/// on its own would discard the NaN and report the error as if that element matched.
///
/// # Panics
///
/// Panics when the shapes of `coords_ref`, `reconstruction` or `targets`, or the lengths of
/// `certified` or `certified_ref`, are inconsistent with `coords`.
#[must_use]
pub fn encode_quality_metrics(
    coords: ArrayView2<'_, f64>,
    certified: &[bool],
    coords_ref: ArrayView2<'_, f64>,
    certified_ref: &[bool],
    reconstruction: ArrayView2<'_, f64>,
    targets: ArrayView2<'_, f64>,
) -> EncodeQualityMetrics {
    /// Running maximum that stays NaN once any operand is NaN.
    fn max_keeping_nan(running: f64, candidate: f64) -> f64 {
        if running.is_nan() || candidate.is_nan() {
            f64::NAN
        } else {
            running.max(candidate)
        }
    }

    let (n, d) = coords.dim();
    assert_eq!(
        coords_ref.dim(),
        (n, d),
        "encode_quality_metrics: reference coords shape {:?} != under-test {:?}",
        coords_ref.dim(),
        (n, d)
    );
    assert_eq!(certified.len(), n, "certified flags must have one entry per row");
    assert_eq!(
        certified_ref.len(),
        n,
        "reference certified flags must have one entry per row"
    );
    let (nt, p) = targets.dim();
    assert_eq!(nt, n, "targets must have one row per encoded row");
    assert_eq!(
        reconstruction.dim(),
        (n, p),
        "reconstruction shape {:?} != targets {:?}",
        reconstruction.dim(),
        (n, p)
    );

    let certified_rows = certified.iter().filter(|&&flag| flag).count();
    let fallback_rate = if n > 0 {
        1.0 - certified_rows as f64 / n as f64
    } else {
        0.0
    };

    let agree = certified
        .iter()
        .zip(certified_ref.iter())
        .filter(|(a, b)| a == b)
        .count();
    let support_agreement = if n > 0 { agree as f64 / n as f64 } else { 1.0 };

    // Both views have the same shape, so logical-order iteration pairs matching elements.
    let max_coord_abs_err = coords
        .iter()
        .zip(coords_ref.iter())
        .fold(0.0_f64, |worst, (&got, &want)| {
            max_keeping_nan(worst, (got - want).abs())
        });

    let mut max_reconstruction_abs_err = 0.0_f64;
    let mut ss_res = 0.0_f64;
    let mut ss_tot = 0.0_f64;
    for c in 0..p {
        let target_col = targets.column(c);
        let recon_col = reconstruction.column(c);

        // Sequential sum (not a pairwise reduction) so rounding matches a plain loop.
        let column_sum = target_col.iter().fold(0.0_f64, |acc, &t| acc + t);
        let mean = if n > 0 { column_sum / n as f64 } else { column_sum };

        for (&target, &recon) in target_col.iter().zip(recon_col.iter()) {
            let resid = recon - target;
            max_reconstruction_abs_err = max_keeping_nan(max_reconstruction_abs_err, resid.abs());
            ss_res += resid * resid;
            let centered = target - mean;
            ss_tot += centered * centered;
        }
    }

    let reconstruction_ev = if ss_tot > 0.0 {
        1.0 - ss_res / ss_tot
    } else if ss_res == 0.0 {
        1.0
    } else {
        0.0
    };

    EncodeQualityMetrics {
        n_rows: n,
        certified_rows,
        fallback_rate,
        support_agreement,
        max_coord_abs_err,
        max_reconstruction_abs_err,
        reconstruction_ev,
    }
}
#[cfg(test)]
mod full_encode_metric_tests {
    use super::*;
    use ndarray::array;

    /// 8000 rows in 100 ms is 80k rows/s; a zero duration must report 0 rather than infinity.
    #[test]
    fn throughput_is_rows_over_seconds_and_guards_degenerate_time() {
        let timed = FullEncodeThroughput::from_elapsed(8_000, Duration::from_millis(100), false);
        assert_eq!(timed.n_rows, 8_000);
        assert!(!timed.device_encode_engaged);
        assert!(
            (timed.rows_per_sec - 80_000.0).abs() < 1.0,
            "got {}",
            timed.rows_per_sec
        );

        let instantaneous = FullEncodeThroughput::from_elapsed(8_000, Duration::ZERO, false);
        assert_eq!(instantaneous.rows_per_sec, 0.0);
    }

    /// Feeding the same data as both under-test and reference gives perfect scores.
    #[test]
    fn perfect_match_scores_full_agreement_and_unit_ev() {
        let coords = array![[0.10], [0.40]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];
        let all_certified = [true, true];

        let metrics = encode_quality_metrics(
            coords.view(),
            &all_certified,
            coords.view(),
            &all_certified,
            targets.view(),
            targets.view(),
        );

        assert_eq!(metrics.n_rows, 2);
        assert_eq!(metrics.certified_rows, 2);
        assert_eq!(metrics.fallback_rate, 0.0);
        assert_eq!(metrics.support_agreement, 1.0);
        assert_eq!(metrics.max_coord_abs_err, 0.0);
        assert_eq!(metrics.max_reconstruction_abs_err, 0.0);
        assert!((metrics.reconstruction_ev - 1.0).abs() < 1e-12);
    }

    /// Every metric must move when the corresponding input diverges from the reference.
    #[test]
    fn divergence_is_surfaced_in_every_axis() {
        // Second row's coordinate differs from the reference by 0.10.
        let coords = array![[0.10], [0.40]];
        let coords_ref = array![[0.10], [0.50]];
        // One reconstructed entry is off by 0.25.
        let targets = array![[1.0, 0.0], [0.0, 1.0]];
        let recon = array![[1.0, 0.0], [0.0, 0.75]];
        // The second row is certified in the reference only.
        let certified = [true, false];
        let certified_ref = [true, true];

        let metrics = encode_quality_metrics(
            coords.view(),
            &certified,
            coords_ref.view(),
            &certified_ref,
            recon.view(),
            targets.view(),
        );

        assert_eq!(metrics.certified_rows, 1);
        assert!((metrics.fallback_rate - 0.5).abs() < 1e-12);
        assert!((metrics.support_agreement - 0.5).abs() < 1e-12);
        assert!((metrics.max_coord_abs_err - 0.10).abs() < 1e-12);
        assert!((metrics.max_reconstruction_abs_err - 0.25).abs() < 1e-12);
        assert!(metrics.reconstruction_ev < 1.0);
    }
}