use std::hint::black_box;
use std::time::{Duration, Instant};
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use super::linalg_dispatch::ResidentDesignGram;
use super::policy::{GpuThroughputVerdict, GPU_THROUGHPUT_TARGET_ROWS_PER_SEC};
/// A named problem size for the resident-solve throughput benchmark.
///
/// `measure_resident_solve_throughput` builds an `n × p` design matrix from
/// this shape, so `n` is the row count and `p` the column count.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct EncodeShape {
    /// Human-readable identifier used in reports.
    pub label: &'static str,
    /// Number of design rows.
    pub n: usize,
    /// Number of design columns.
    pub p: usize,
}
/// Fixed benchmark sizes.
///
/// Each `label` spells out its own size: the row count `n` in thousands
/// (`2k` = 2 000) followed by the column count `p`.
pub const CANONICAL_ENCODE_SHAPES: &[EncodeShape] = &[
    EncodeShape {
        label: "sae-2k-2048",
        n: 2_000,
        p: 2_048,
    },
    EncodeShape {
        label: "sae-4k-4096",
        n: 4_000,
        p: 4_096,
    },
    EncodeShape {
        label: "sae-8k-1024",
        n: 8_000,
        p: 1_024,
    },
];
/// Outcome of one resident-solve throughput measurement, as produced by
/// `measure_resident_solve_throughput`.
#[derive(Clone, Copy, Debug)]
pub struct ResidentSolveThroughput {
    /// The problem size that was measured.
    pub shape: EncodeShape,
    /// `false` when the resident path was skipped or failed, or when the measured
    /// rate is not positive; `true` otherwise.
    pub engaged: bool,
    /// `n` divided by the mean wall-clock seconds per timed solve; `0.0` when not
    /// engaged.
    pub measured_rows_per_sec: f64,
    /// `GpuThroughputVerdict::from_measurement` applied to `measured_rows_per_sec`.
    pub verdict: GpuThroughputVerdict,
}
/// Advances a 64-bit linear congruential generator and returns a sample in `[-1, 1)`.
///
/// The generator uses Knuth's MMIX multiplier and increment. The top 53 bits of
/// the new state form `u ∈ [0, 1)`, and the function returns `2·u − 1`.
fn lcg(state: &mut u64) -> f64 {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;
    const TWO_POW_53: f64 = (1u64 << 53) as f64;

    let next = (*state).wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *state = next;
    let unit = (next >> 11) as f64 / TWO_POW_53;
    2.0 * unit - 1.0
}
/// Deterministic `n × p` test design seeded by `seed`.
///
/// Every entry is a successive `lcg` draw scaled by `0.05`, so all values lie in
/// `[-0.05, 0.05)`. The same `(n, p, seed)` always produces the same matrix.
fn planted_design(n: usize, p: usize, seed: u64) -> Array2<f64> {
    let mut rng_state = seed;
    let mut draw = || 0.05 * lcg(&mut rng_state);
    Array2::from_shape_fn((n, p), |_| draw())
}
#[must_use]
pub fn measure_resident_solve_throughput(shape: EncodeShape, reps: usize) -> ResidentSolveThroughput {
let EncodeShape { n, p, .. } = shape;
let not_engaged = |shape| ResidentSolveThroughput {
shape,
engaged: false,
measured_rows_per_sec: 0.0,
verdict: GpuThroughputVerdict::from_measurement(0.0),
};
if n == 0 || p == 0 || reps == 0 {
return not_engaged(shape);
}
let x = planted_design(n, p, 0x1412_a100_dead_beef);
let w = {
let mut s = 0x988_5ae_e0c0_de01u64;
Array1::from_shape_fn(n, |_| lcg(&mut s).abs() + 1e-3)
};
let rhs = Array1::from_shape_fn(p, |j| ((j as f64 + 1.0) * 0.03).cos());
let ridge = 1e-3_f64;
let handle = match ResidentDesignGram::try_new(x.view()) {
Some(h) => h,
None => return not_engaged(shape),
};
if handle.solve_normal_equations(w.view(), rhs.view(), ridge).is_none() {
return not_engaged(shape);
}
let mut total = Duration::ZERO;
for r in 0..reps {
let wr = Array1::from_shape_fn(n, |i| (w[i] + 1e-3 * (r as f64)).abs());
let start = Instant::now();
match handle.solve_normal_equations(wr.view(), rhs.view(), ridge) {
Some(beta) => {
black_box(beta);
}
None => return not_engaged(shape),
}
total += start.elapsed();
}
let secs = total.as_secs_f64() / reps as f64;
let measured_rows_per_sec = if secs > 0.0 { n as f64 / secs } else { 0.0 };
ResidentSolveThroughput {
shape,
engaged: measured_rows_per_sec > 0.0,
measured_rows_per_sec,
verdict: GpuThroughputVerdict::from_measurement(measured_rows_per_sec),
}
}
#[must_use]
pub fn cpu_oracle_normal_equations_solve(
x: ArrayView2<'_, f64>,
w: ArrayView1<'_, f64>,
rhs: ArrayView1<'_, f64>,
ridge: f64,
) -> Array1<f64> {
let (n, p) = x.dim();
assert_eq!(w.len(), n, "w must have one entry per design row");
assert_eq!(rhs.len(), p, "rhs must have one entry per border column");
let mut xw = x.to_owned();
for i in 0..n {
let sw = w[i].sqrt();
for a in 0..p {
xw[[i, a]] *= sw;
}
}
let mut gram = xw.t().dot(&xw);
for j in 0..p {
gram[[j, j]] += ridge;
}
let mut l = Array2::<f64>::zeros((p, p));
for j in 0..p {
let mut diag = gram[[j, j]];
for s in 0..j {
diag -= l[[j, s]] * l[[j, s]];
}
assert!(
diag > 0.0,
"cpu_oracle: non-positive Cholesky pivot {diag:.3e} at index {j} — \
the Gram is not positive-definite (need ridge>0 and w>0)"
);
let ljj = diag.sqrt();
l[[j, j]] = ljj;
for i in (j + 1)..p {
let mut off = gram[[i, j]];
for s in 0..j {
off -= l[[i, s]] * l[[j, s]];
}
l[[i, j]] = off / ljj;
}
}
let mut y = rhs.to_owned();
for i in 0..p {
let mut acc = y[i];
for s in 0..i {
acc -= l[[i, s]] * y[s];
}
y[i] = acc / l[[i, i]];
}
let mut beta = y;
for i in (0..p).rev() {
let mut acc = beta[i];
for s in (i + 1)..p {
acc -= l[[s, i]] * beta[s];
}
beta[i] = acc / l[[i, i]];
}
beta
}
/// Deployment-facing alias for `super::policy::GPU_THROUGHPUT_TARGET_ROWS_PER_SEC`.
/// Judging by its name, this is the throughput target in rows per second
/// (confirm against the policy module).
pub const DEPLOYMENT_TARGET_ROWS_PER_SEC: f64 = GPU_THROUGHPUT_TARGET_ROWS_PER_SEC;