// Submodules exposed by this module, one declaration per line, in alphabetical order.
pub mod activation_bridge;
pub mod argmax_rows;
pub mod embedding_lookup;
pub mod flash_attention;
pub mod fused_add_rms_norm;
pub mod gemm;
pub mod kv_cache_append;
pub mod marlin_matmul;
pub mod paged_varlen_attn;
pub mod qk_norm_rope;
pub mod residual_add;
pub mod rms_norm;
pub mod silu_mul;
pub mod split_qkv;
pub mod transpose_head_to_token;
/// NMSE tolerance constant (`1e-7`) named for FP32.
///
/// NOTE(review): not referenced by the code visible in this file; presumably the
/// threshold callers pass to `NmseReport::within_tol` for full-precision kernels —
/// confirm against callers.
pub const NMSE_FP32_TOL: f64 = 1e-7;
/// NMSE tolerance constant (`1e-6`) named for FP16; ten times looser than
/// `NMSE_FP32_TOL`.
///
/// NOTE(review): not referenced by the code visible in this file; presumably the
/// threshold for half-precision kernels — confirm against callers.
pub const NMSE_FP16_TOL: f64 = 1e-6;
/// Normalized mean squared error of `b` measured against the reference `a`.
///
/// Computes `mean((a - b)^2) / mean(a^2)`, widening every element to `f64`
/// before any arithmetic.
///
/// Special cases:
/// * two empty slices yield `0.0`;
/// * when the reference power `mean(a^2)` is below `1e-30` (an effectively
///   all-zero reference) the un-normalized mean squared error is returned
///   instead, avoiding a division by (near) zero.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn nmse(a: &[f32], b: &[f32]) -> f64 {
    assert_eq!(a.len(), b.len(), "nmse: length mismatch");
    if a.is_empty() {
        return 0.0;
    }

    // Accumulate both sums in a single pass; each accumulator still adds its
    // terms in slice order, so the results match two independent passes.
    let mut err_sq_sum = 0.0_f64;
    let mut ref_sq_sum = 0.0_f64;
    for (&reference, &candidate) in a.iter().zip(b) {
        let reference = f64::from(reference);
        let diff = reference - f64::from(candidate);
        err_sq_sum += diff * diff;
        ref_sq_sum += reference * reference;
    }

    let count = a.len() as f64;
    let mse_err = err_sq_sum / count;
    let ref_power = ref_sq_sum / count;

    if ref_power < 1e-30 {
        mse_err
    } else {
        mse_err / ref_power
    }
}
/// Flat `f32` buffer returned by each `run_*` method of `OpUnderTest`.
pub type Output = Vec<f32>;
/// An operation that `compare_backends` can run on the CPU and on whichever
/// accelerator backends are compiled in.
///
/// The `run_metal` / `run_cuda` methods only exist in the trait when the matching
/// `cfg` is active, so implementors must provide them under the same conditions.
/// NOTE(review): how `seed` is consumed is up to the implementor — presumably it
/// seeds input generation (e.g. via `random_vec`) so all backends see the same
/// data; confirm against implementors.
pub trait OpUnderTest {
/// Name of the op; `compare_backends` copies it into `NmseReport::op`.
fn name(&self) -> &str;
/// CPU run for `seed`; `compare_backends` uses this output as the reference
/// the other backends are scored against.
fn run_cpu(&self, seed: u64) -> Output;
/// Metal counterpart of `run_cpu`; only part of the trait on macOS builds with
/// the `metal` feature enabled.
#[cfg(all(target_os = "macos", feature = "metal"))]
fn run_metal(&self, seed: u64) -> Output;
/// CUDA counterpart of `run_cpu`; only part of the trait when the `cuda`
/// feature is enabled.
#[cfg(feature = "cuda")]
fn run_cuda(&self, seed: u64) -> Output;
}
/// Result of one `compare_backends` run: the CPU reference output plus the
/// NMSE of every accelerator backend that was compiled in.
#[derive(Debug)]
pub struct NmseReport {
    /// Name of the op, as reported by `OpUnderTest::name`.
    pub op: String,
    /// Seed that was passed to every backend run.
    pub seed: u64,
    /// CPU output that served as the reference.
    pub cpu: Vec<f32>,
    /// NMSE of the Metal output against `cpu`; `None` when Metal was not compiled in.
    pub metal_nmse: Option<f64>,
    /// NMSE of the CUDA output against `cpu`; `None` when CUDA was not compiled in.
    pub cuda_nmse: Option<f64>,
}

impl NmseReport {
    /// Returns `true` when every recorded backend NMSE is strictly below `tol`.
    ///
    /// Backends without a score (`None`) are ignored, so a report with no
    /// backend scores at all passes. A `NaN` score never passes, because
    /// `NaN < tol` is false.
    pub fn within_tol(&self, tol: f64) -> bool {
        let passes = |score: Option<f64>| match score {
            Some(value) => value < tol,
            None => true,
        };
        passes(self.metal_nmse) && passes(self.cuda_nmse)
    }
}
/// Runs `op` on the CPU and on every compiled-in accelerator backend with the
/// same `seed`, scoring each accelerator output against the CPU output.
///
/// The CPU run happens first and is the reference; Metal and CUDA are then
/// scored via `run_metal_nmse` / `run_cuda_nmse`, which yield `None` when the
/// backend is not compiled in. The returned report owns the CPU output.
pub fn compare_backends(op: &dyn OpUnderTest, seed: u64) -> NmseReport {
    let reference = op.run_cpu(seed);
    let metal_nmse = run_metal_nmse(op, &reference, seed);
    let cuda_nmse = run_cuda_nmse(op, &reference, seed);
    let op_name = String::from(op.name());

    NmseReport {
        op: op_name,
        seed,
        cpu: reference,
        metal_nmse,
        cuda_nmse,
    }
}
/// Runs `op` on Metal with `seed` and returns the NMSE of its output against
/// the CPU reference `cpu`.
///
/// Compiled only on macOS with the `metal` feature enabled.
#[cfg(all(target_os = "macos", feature = "metal"))]
fn run_metal_nmse(op: &dyn OpUnderTest, cpu: &[f32], seed: u64) -> Option<f64> {
    let metal_out = op.run_metal(seed);
    Some(nmse(cpu, &metal_out))
}

/// Fallback used when Metal is not compiled in: nothing is run and no score
/// is reported.
#[cfg(not(all(target_os = "macos", feature = "metal")))]
fn run_metal_nmse(_op: &dyn OpUnderTest, _cpu: &[f32], _seed: u64) -> Option<f64> {
    None
}
/// Runs `op` on CUDA with `seed` and returns the NMSE of its output against
/// the CPU reference `cpu`.
///
/// Compiled only when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
fn run_cuda_nmse(op: &dyn OpUnderTest, cpu: &[f32], seed: u64) -> Option<f64> {
    let cuda_out = op.run_cuda(seed);
    Some(nmse(cpu, &cuda_out))
}

/// Fallback used when CUDA is not compiled in: nothing is run and no score
/// is reported.
#[cfg(not(feature = "cuda"))]
fn run_cuda_nmse(_op: &dyn OpUnderTest, _cpu: &[f32], _seed: u64) -> Option<f64> {
    None
}
/// Generates `n` pseudo-random `f32` values drawn from the half-open range
/// `lo..hi`, using a `StdRng` seeded with `seed`.
///
/// The same `(n, lo, hi, seed)` always yields the same vector within one
/// build, which is what lets every backend be fed identical inputs.
/// No value is drawn when `n` is zero.
pub fn random_vec(n: usize, lo: f32, hi: f32, seed: u64) -> Vec<f32> {
    use rand::{Rng, SeedableRng};

    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    // One draw per element, in order, so the sequence matches a plain loop.
    std::iter::repeat_with(|| rng.random_range(lo..hi))
        .take(n)
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A signal compared against itself has (effectively) zero error.
    #[test]
    fn nmse_identical_is_zero() {
        let signal = [1.0_f32, 2.0, 3.0];
        let err = nmse(&signal, &signal);
        assert!(err < 1e-30);
    }

    /// Scaling the reference by 1.01 gives a relative error of 0.01, so the
    /// NMSE should be about 0.01^2 = 1e-4.
    #[test]
    fn nmse_scaled_b_proportional() {
        let reference = [1.0_f32, 2.0, 3.0, 4.0];
        let scaled: Vec<f32> = reference.iter().map(|v| v * 1.01).collect();
        let err = nmse(&reference, &scaled);
        assert!((err - 1e-4).abs() < 1e-5);
    }

    /// With an all-zero reference, `nmse` returns the plain MSE (0.1^2 = 0.01)
    /// instead of dividing by zero.
    #[test]
    fn nmse_zero_reference_falls_back() {
        let zeros = [0.0_f32; 3];
        let candidate = [0.1_f32; 3];
        let err = nmse(&zeros, &candidate);
        assert!((err - 0.01).abs() < 1e-9);
    }

    /// The same seed must reproduce the same vector.
    #[test]
    fn random_vec_determinism() {
        let first = random_vec(100, -1.0, 1.0, 42);
        let second = random_vec(100, -1.0, 1.0, 42);
        assert_eq!(first, second);
    }
}