use std::{cell::RefCell, ptr::NonNull};
use metaltile_codegen::msl::MslGenerator;
pub use metaltile_core::dtype::DType;
use metaltile_core::ir::{Kernel, KernelMode};
use crate::stats::BenchStats;
/// Floating-point dtypes, listed as `F32`, `F16`, `BF16`.
pub const FLOAT_DTYPES: &[DType] = &[DType::F32, DType::F16, DType::BF16];
/// String names for the floating-point dtypes, in the same order as [`FLOAT_DTYPES`].
pub const FLOAT_DTYPE_STRS: &[&str] = &["f32", "f16", "bf16"];
/// Integer dtypes: the 32-bit and 8-bit types, signed and unsigned.
pub const INTEGER_DTYPES: &[DType] = &[DType::I32, DType::U32, DType::I8, DType::U8];
/// Returns the short lowercase label for `dt`, such as `"bf16"` or `"u8"`.
///
/// Any dtype without an explicit entry below is reported as `"?"`.
pub fn dtype_label(dt: DType) -> &'static str {
    match dt {
        // Boolean.
        DType::Bool => "bool",
        // Integers, narrowest first.
        DType::U8 => "u8",
        DType::I8 => "i8",
        DType::U32 => "u32",
        DType::I32 => "i32",
        // Floating point, narrowest first.
        DType::BF16 => "bf16",
        DType::F16 => "f16",
        DType::F32 => "f32",
        _ => "?",
    }
}
/// Returns the long type-name string for `dt` (`"float32"`, `"uint8"`, `"bool_"`, ...).
///
/// Any dtype without an explicit entry falls back to `"float32"`.
///
/// NOTE(review): judging by the function name these are MLX dtype names — confirm against
/// the code that consumes them.
pub fn mlx_tname(dt: DType) -> &'static str {
    match dt {
        // Boolean.
        DType::Bool => "bool_",
        // Integers, narrowest first.
        DType::U8 => "uint8",
        DType::I8 => "int8",
        DType::U32 => "uint32",
        DType::I32 => "int32",
        // Floating point, narrowest first.
        DType::BF16 => "bfloat16",
        DType::F16 => "float16",
        DType::F32 => "float32",
        _ => "float32",
    }
}
/// Returns the size in bytes of one element of `dt`.
///
/// NOTE(review): dtypes without an explicit entry are assumed to be 4 bytes wide; verify
/// that this holds for every `DType` variant not listed here.
pub fn elem_bytes(dt: DType) -> usize {
    match dt {
        DType::I8 | DType::U8 | DType::Bool => 1,
        DType::BF16 | DType::F16 => 2,
        DType::I32 | DType::U32 | DType::F32 => 4,
        _ => 4,
    }
}
/// Returns the tolerance that [`DtypeCtx::elementwise`] stores for `dt`.
///
/// Only the floating-point dtypes get a non-zero tolerance; every other dtype gets `0.0`.
pub fn dtype_tol(dt: DType) -> f32 {
    match dt {
        // Loosest first: fewer mantissa bits, larger tolerance.
        DType::BF16 => 1.3e-1,
        DType::F16 => 1.5e-2,
        DType::F32 => 1e-4,
        _ => 0.0,
    }
}
/// Returns the tolerance that [`DtypeCtx::reduce`] stores for `dt`.
///
/// These values are looser than the ones from [`dtype_tol`]; dtypes without an explicit
/// entry share the `F32` value of `1e-3`.
pub fn dtype_tol_reduce(dt: DType) -> f32 {
    match dt {
        // Loosest first.
        DType::BF16 => 128.0,
        DType::F16 => 0.5,
        DType::F32 => 1e-3,
        _ => 1e-3,
    }
}
/// Converts an `f32` to IEEE 754 binary16 bits, rounding to nearest with ties to even.
///
/// * NaN stays NaN (the top ten payload bits are kept and the quiet bit is forced on, so the
///   result can never collapse into an infinity encoding).
/// * Magnitudes too large for binary16 become a signed infinity.
/// * Magnitudes below the smallest normal binary16 value (2^-14) are rounded to the nearest
///   subnormal, or to signed zero when they are smaller than half the smallest subnormal.
fn f32_to_f16(v: f32) -> u16 {
    let bits = v.to_bits();
    let sign = ((bits >> 31) as u16) << 15;
    let exp32 = (bits >> 23) & 0xFF;
    let mant32 = bits & 0x7F_FFFF;

    // Infinity and NaN share the all-ones exponent; only the mantissa tells them apart.
    if exp32 == 0xFF {
        if mant32 == 0 {
            return sign | 0x7C00;
        }
        return sign | 0x7E00 | (mant32 >> 13) as u16;
    }

    // Re-bias the exponent from f32 (127) to binary16 (15).
    let exp = exp32 as i32 - 127 + 15;
    if exp >= 31 {
        return sign | 0x7C00;
    }

    if exp <= 0 {
        // Below 2^-25 nothing can round up to the smallest subnormal (2^-24). This also
        // covers f32 zeros and f32 subnormals, whose re-biased exponent is -112.
        if exp < -10 {
            return sign;
        }
        // Subnormal result: make the implicit leading one explicit and shift the 24-bit
        // significand so that one unit equals 2^-24.
        let full = mant32 | 0x80_0000;
        let shift = (14 - exp) as u32; // 14..=24
        let kept = full >> shift;
        let rem = full & ((1u32 << shift) - 1);
        let half = 1u32 << (shift - 1);
        let round_up = rem > half || (rem == half && (kept & 1) == 1);
        // A carry out of the mantissa yields 0x0400, which is exactly the smallest normal.
        return sign | (kept + u32::from(round_up)) as u16;
    }

    // Normal result: drop 13 mantissa bits with round-to-nearest-even.
    let mant16 = mant32 >> 13;
    let round_bit = (mant32 >> 12) & 1;
    let sticky = mant32 & 0xFFF;
    let round_up = round_bit == 1 && (sticky != 0 || (mant16 & 1) == 1);
    let mant16 = mant16 + u32::from(round_up);
    if mant16 > 0x3FF {
        // Mantissa overflowed into the exponent; exp + 1 == 31 naturally encodes infinity.
        sign | (((exp + 1) as u16) << 10)
    } else {
        sign | ((exp as u16) << 10) | mant16 as u16
    }
}
/// Converts an `f32` to bfloat16 bits (the upper 16 bits of the `f32` encoding), rounding to
/// nearest with ties to even.
///
/// NaN inputs are truncated instead of rounded and get the quiet bit forced on: the rounding
/// add would otherwise carry out of the mantissa and turn some NaN payloads into an infinity
/// or a zero.
fn f32_to_bf16(v: f32) -> u16 {
    let x = v.to_bits();
    if v.is_nan() {
        return ((x >> 16) as u16) | 0x0040;
    }
    // Adding 0x7FFF plus the lowest kept bit implements round-half-to-even on the 16
    // discarded bits; a carry into the exponent correctly rounds up to the next binade
    // (or to infinity from the largest finite value).
    let rounded = x.wrapping_add(0x7FFF).wrapping_add((x >> 16) & 1);
    (rounded >> 16) as u16
}
/// Converts IEEE 754 binary16 bits to the `f32` with the same value.
///
/// Every binary16 value is exactly representable as an `f32`, so the conversion is lossless:
/// zeros keep their sign, subnormals are scaled by 2^-24, and infinities and NaNs keep their
/// class (NaN payload bits are moved to the top of the `f32` mantissa).
fn f16_to_f32(bits: u16) -> f32 {
    let wide = u32::from(bits);
    let sign = (wide >> 15) << 31;
    let exp5 = (wide >> 10) & 0x1f;
    let mantissa = wide & 0x3ff;

    if exp5 == 0 {
        // Zero or subnormal: the value is mantissa * 2^-24, which is exact in f32.
        // 0x3380_0000 is the f32 encoding of 2^-24.
        let magnitude = mantissa as f32 * f32::from_bits(0x3380_0000);
        return f32::from_bits(sign | magnitude.to_bits());
    }
    if exp5 == 31 {
        // Infinity (mantissa == 0) or NaN.
        return f32::from_bits(sign | 0x7f80_0000 | (mantissa << 13));
    }
    // Normal: re-bias the exponent from 15 to 127 (+112) and widen the mantissa.
    f32::from_bits(sign | ((exp5 + 112) << 23) | (mantissa << 13))
}
/// Widens bfloat16 bits to an `f32` by placing them in the upper half of the 32-bit
/// encoding; the lower 16 bits are zero.
fn bf16_to_f32(bits: u16) -> f32 {
    let widened = u32::from(bits) << 16;
    f32::from_bits(widened)
}
/// Returns a copy of `vals` in which every element has been converted to `dt` and back to
/// `f32`.
///
/// `F16` and `BF16` go through the bit-level converters in this file; the integer dtypes use
/// Rust's `as` casts in both directions. `F32` and any dtype without an explicit entry are
/// returned unchanged.
pub fn quantize_roundtrip(vals: &[f32], dt: DType) -> Vec<f32> {
    /// Applies `convert` to every element and collects the results.
    fn map_each(vals: &[f32], convert: impl Fn(f32) -> f32) -> Vec<f32> {
        vals.iter().map(|&v| convert(v)).collect()
    }

    match dt {
        DType::F16 => map_each(vals, |v| f16_to_f32(f32_to_f16(v))),
        DType::BF16 => map_each(vals, |v| bf16_to_f32(f32_to_bf16(v))),
        DType::I32 => map_each(vals, |v| v as i32 as f32),
        DType::U32 => map_each(vals, |v| v as u32 as f32),
        DType::I8 => map_each(vals, |v| v as i8 as f32),
        DType::U8 => map_each(vals, |v| v as u8 as f32),
        // F32 is already the working precision; other dtypes have no round trip defined.
        _ => vals.to_vec(),
    }
}
/// Raw, non-null pointer to a result callback.
///
/// Being a raw pointer, it does not keep the closure it points at alive.
type ResultReporterFn = NonNull<dyn FnMut(&OpResult)>;
thread_local! {
// Per-thread slot for the active reporter; `None` while no reporter is installed.
// Written by `set_result_reporter` and `ResultReporterGuard::drop`, read by `report_result`.
static RESULT_REPORTER: RefCell<Option<ResultReporterFn>> = RefCell::new(None);
}
/// Minimum cosine similarity required by [`check_equiv`].
pub const DEFAULT_MIN_COSINE_SIM: f32 = 0.999;

/// Outcome of comparing a reference output with a candidate output.
#[derive(Debug, Clone, Copy)]
pub struct EquivResult {
    /// Number of element pairs compared (the shorter of the two input lengths).
    pub n_checked: usize,
    /// Largest absolute element-wise difference; NaN if any difference was NaN.
    pub max_abs_err: f32,
    /// Cosine similarity of the compared prefixes, clamped to `[-1, 1]`; NaN when it could
    /// not be computed because the inputs contained non-finite values.
    pub cosine_sim: f32,
    /// Whether the comparison met the tolerance it was run with.
    pub passed: bool,
}

/// Thresholds a comparison has to meet in order to pass.
#[derive(Debug, Clone, Copy)]
pub struct EquivTolerance {
    /// Largest absolute element-wise difference that still passes (inclusive).
    pub max_abs_err: f32,
    /// Smallest cosine similarity that still passes (inclusive).
    pub min_cosine_sim: f32,
}

impl EquivTolerance {
    /// Creates a tolerance from its two thresholds.
    pub const fn new(max_abs_err: f32, min_cosine_sim: f32) -> Self {
        Self { max_abs_err, min_cosine_sim }
    }
}

/// Compares `mt_vals` against `ref_vals` element by element.
///
/// Only the common prefix of the two slices is inspected. The result passes when all of the
/// following hold:
///
/// * both slices have the same length,
/// * the largest absolute difference is finite and at most `tolerance.max_abs_err`,
/// * the cosine similarity is finite and at least `tolerance.min_cosine_sim`.
///
/// Cosine similarity is `1.0` when both inputs have zero norm and `0.0` when exactly one of
/// them does. A NaN difference (a NaN on either side, or two infinities of the same sign) is
/// latched into `max_abs_err`, so such inputs never pass.
pub fn check_equiv_with(
    ref_vals: &[f32],
    mt_vals: &[f32],
    tolerance: EquivTolerance,
) -> EquivResult {
    let n = ref_vals.len().min(mt_vals.len());
    let mut max_err = 0.0f32;
    let mut dot = 0.0f64;
    let mut ref_norm_sq = 0.0f64;
    let mut mt_norm_sq = 0.0f64;
    for (&r, &m) in ref_vals[..n].iter().zip(&mt_vals[..n]) {
        let err = (r - m).abs();
        // `err > max_err` is false whenever either side is NaN, so a NaN difference has to
        // be latched explicitly; once `max_err` is NaN it stays NaN for the rest of the loop.
        if err.is_nan() || err > max_err {
            max_err = err;
        }
        let r = r as f64;
        let m = m as f64;
        dot += r * m;
        ref_norm_sq += r * r;
        mt_norm_sq += m * m;
    }

    let cosine_sim = if ref_norm_sq.is_nan() || mt_norm_sq.is_nan() {
        // A NaN norm compares false against 0.0 and would otherwise be mistaken for a
        // zero-length vector below.
        f32::NAN
    } else {
        let raw: f32 = match (ref_norm_sq > 0.0, mt_norm_sq > 0.0) {
            (false, false) => 1.0,
            (false, true) | (true, false) => 0.0,
            (true, true) => {
                let denom = ref_norm_sq.sqrt() * mt_norm_sq.sqrt();
                (dot / denom) as f32
            },
        };
        raw.clamp(-1.0, 1.0)
    };

    let same_len = ref_vals.len() == mt_vals.len();
    EquivResult {
        n_checked: n,
        max_abs_err: max_err,
        cosine_sim,
        passed: same_len
            && max_err.is_finite()
            && cosine_sim.is_finite()
            && max_err <= tolerance.max_abs_err
            && cosine_sim >= tolerance.min_cosine_sim,
    }
}

/// Same as [`check_equiv_with`], using `max_abs_err` together with
/// [`DEFAULT_MIN_COSINE_SIM`] as the tolerance.
pub fn check_equiv(ref_vals: &[f32], mt_vals: &[f32], max_abs_err: f32) -> EquivResult {
    check_equiv_with(ref_vals, mt_vals, EquivTolerance::new(max_abs_err, DEFAULT_MIN_COSINE_SIM))
}
/// Correctness state of one result row, as derived by `OpResult::correctness_status`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CorrectnessStatus {
/// An equivalence check was recorded and it passed.
Passed { max_abs_err: f32, cosine_sim: f32 },
/// An equivalence check was recorded and it did not pass.
Failed { max_abs_err: f32, cosine_sim: f32 },
/// A performance number was recorded for the row but no equivalence check.
Unchecked,
/// Neither a performance number nor an equivalence check was recorded.
Unavailable,
}
/// Names a benchmarked operation together with the metric label of its performance numbers
/// and builds the [`OpResult`] rows for it.
#[derive(Debug, Clone, Copy)]
pub struct OpBench {
    op: &'static str,
    metric: &'static str,
}

impl OpBench {
    /// Creates a descriptor for operation `op` whose numbers are labelled with `metric`.
    pub const fn new(op: &'static str, metric: &'static str) -> Self {
        Self { op, metric }
    }

    /// Name of the operation.
    pub const fn op(&self) -> &'static str {
        self.op
    }

    /// Builds and reports a row without a sub-operation label or timing statistics.
    ///
    /// # Panics
    ///
    /// Panics when `mt_perf` is `Some` but `equiv` is `None`.
    pub fn result(
        &self,
        shape: impl Into<String>,
        ref_perf: Option<f64>,
        mt_perf: Option<f64>,
        equiv: Option<EquivResult>,
    ) -> OpResult {
        self.result_sub_timed(None::<&str>, shape, ref_perf, mt_perf, equiv, None, None)
    }

    /// Builds and reports a row with an optional sub-operation label and no timing
    /// statistics.
    ///
    /// # Panics
    ///
    /// Panics when `mt_perf` is `Some` but `equiv` is `None`.
    pub fn result_sub(
        &self,
        subop: Option<impl Into<String>>,
        shape: impl Into<String>,
        ref_perf: Option<f64>,
        mt_perf: Option<f64>,
        equiv: Option<EquivResult>,
    ) -> OpResult {
        self.result_sub_timed(subop, shape, ref_perf, mt_perf, equiv, None, None)
    }

    /// Builds a row from all of its parts, hands it to the reporter installed on the current
    /// thread (if any), and returns it.
    ///
    /// # Panics
    ///
    /// Panics when `mt_perf` is `Some` but `equiv` is `None`: a row with a measured
    /// performance number must carry a correctness check.
    #[allow(clippy::too_many_arguments)]
    pub fn result_sub_timed(
        &self,
        subop: Option<impl Into<String>>,
        shape: impl Into<String>,
        ref_perf: Option<f64>,
        mt_perf: Option<f64>,
        equiv: Option<EquivResult>,
        mt_timing: Option<BenchStats>,
        ref_timing: Option<BenchStats>,
    ) -> OpResult {
        let shape: String = shape.into();
        if let (Some(_), None) = (&mt_perf, &equiv) {
            panic!("implemented benchmark '{}' [{}] is missing correctness", self.op, shape);
        }
        let row = OpResult {
            op: self.op,
            subop: subop.map(Into::into),
            shape,
            metric: self.metric,
            ref_perf,
            mt_perf,
            equiv,
            mt_timing,
            ref_timing,
        };
        report_result(&row);
        row
    }

    /// Builds and reports a row for an implemented kernel; both the performance number and
    /// the correctness check are mandatory.
    pub fn implemented(
        &self,
        shape: impl Into<String>,
        ref_perf: Option<f64>,
        mt_perf: f64,
        equiv: EquivResult,
    ) -> OpResult {
        self.result(shape, ref_perf, Some(mt_perf), Some(equiv))
    }

    /// Builds and reports a row that carries at most a reference number: no measured
    /// performance and no correctness check.
    pub fn nyi(&self, shape: impl Into<String>, ref_perf: Option<f64>) -> OpResult {
        self.result(shape, ref_perf, None, None)
    }
}
/// One benchmark result row: which operation and shape were measured, the reference and
/// measured performance numbers, and the outcome of the correctness check.
pub struct OpResult {
    op: &'static str,
    subop: Option<String>,
    shape: String,
    metric: &'static str,
    ref_perf: Option<f64>,
    mt_perf: Option<f64>,
    equiv: Option<EquivResult>,
    /// Timing statistics for the measured implementation, when recorded.
    pub mt_timing: Option<BenchStats>,
    /// Timing statistics for the reference implementation, when recorded.
    pub ref_timing: Option<BenchStats>,
}

impl OpResult {
    /// Name of the operation.
    pub fn op(&self) -> &'static str {
        self.op
    }

    /// Sub-operation label, if the row has one.
    pub fn subop(&self) -> Option<&str> {
        self.subop.as_deref()
    }

    /// Operation name for display: `"op (subop)"` when a sub-operation is set, otherwise
    /// just the operation name.
    pub fn op_display(&self) -> String {
        if let Some(sub) = &self.subop {
            format!("{} ({})", self.op, sub)
        } else {
            self.op.to_string()
        }
    }

    /// Shape description of the row.
    pub fn shape(&self) -> &str {
        &self.shape
    }

    /// Metric label of the performance numbers.
    pub fn metric(&self) -> &'static str {
        self.metric
    }

    /// Reference performance number, if recorded.
    pub fn ref_perf(&self) -> Option<f64> {
        self.ref_perf
    }

    /// Measured performance number, if recorded.
    pub fn mt_perf(&self) -> Option<f64> {
        self.mt_perf
    }

    /// Outcome of the correctness check, if one was recorded.
    pub fn equiv(&self) -> Option<&EquivResult> {
        self.equiv.as_ref()
    }

    /// Measured performance as a percentage of the reference performance.
    ///
    /// Returns `None` when either number is missing or the reference is not greater than
    /// zero.
    pub fn pct(&self) -> Option<f64> {
        let reference = self.ref_perf?;
        let measured = self.mt_perf?;
        if reference > 0.0 {
            Some(measured / reference * 100.0)
        } else {
            None
        }
    }

    /// Classifies the row by its correctness check and, when there is none, by whether a
    /// measured performance number exists.
    pub fn correctness_status(&self) -> CorrectnessStatus {
        if let Some(check) = &self.equiv {
            let max_abs_err = check.max_abs_err;
            let cosine_sim = check.cosine_sim;
            return if check.passed {
                CorrectnessStatus::Passed { max_abs_err, cosine_sim }
            } else {
                CorrectnessStatus::Failed { max_abs_err, cosine_sim }
            };
        }
        if self.mt_perf.is_some() {
            CorrectnessStatus::Unchecked
        } else {
            CorrectnessStatus::Unavailable
        }
    }

    /// Renders the correctness status as a short table cell.
    ///
    /// Passing rows show a check mark, followed by the error when it is at least `1e-5`.
    /// Failing rows show a cross and the error, plus the cosine similarity when it is below
    /// `0.999`.
    pub fn correctness_cell(&self) -> String {
        match self.correctness_status() {
            CorrectnessStatus::Unavailable => "—".into(),
            CorrectnessStatus::Unchecked => "! missing-check".into(),
            CorrectnessStatus::Passed { max_abs_err, .. } if max_abs_err < 1e-5 => "✓".into(),
            CorrectnessStatus::Passed { max_abs_err, .. } => format!("✓ {max_abs_err:.2e}"),
            // NOTE(review): 0.999 equals DEFAULT_MIN_COSINE_SIM but is spelled out separately.
            CorrectnessStatus::Failed { max_abs_err, cosine_sim } if cosine_sim < 0.999 => {
                format!("✗ {max_abs_err:.2e} cos={cosine_sim:.3}")
            },
            CorrectnessStatus::Failed { max_abs_err, .. } => format!("✗ {max_abs_err:.2e}"),
        }
    }

    /// `true` when the row has a measured performance number but no correctness check.
    pub fn is_unchecked(&self) -> bool {
        self.equiv.is_none() && self.mt_perf.is_some()
    }
}
/// Checks that no row in `results` has a measured performance number without a correctness
/// check.
///
/// # Errors
///
/// Returns a message listing every offending row as `op [shape]`, separated by commas.
pub fn validate_results(results: &[OpResult]) -> Result<(), String> {
    let mut missing = Vec::new();
    for row in results {
        if row.is_unchecked() {
            missing.push(format!("{} [{}]", row.op(), row.shape()));
        }
    }
    if missing.is_empty() {
        return Ok(());
    }
    Err(format!("implemented benchmarks missing correctness checks: {}", missing.join(", ")))
}
/// Passes `result` to the reporter installed on the current thread, if there is one.
fn report_result(result: &OpResult) {
    // Copy the pointer out of the slot first so the `RefCell` borrow has ended before the
    // callback runs. Holding it across the call would make any callback that installs or
    // removes a reporter panic with an "already borrowed" error.
    let installed = RESULT_REPORTER.with(|slot| *slot.borrow());
    if let Some(mut reporter) = installed {
        // SAFETY: relies on the contract of `set_result_reporter`: the closure behind the
        // pointer must stay alive, and must not be accessed through any other path, for as
        // long as it is installed. NOTE(review): the type system does not enforce this.
        unsafe {
            reporter.as_mut()(result);
        }
    }
}
/// Guard returned by [`set_result_reporter`].
///
/// Dropping it writes the reporter that was installed before the matching
/// `set_result_reporter` call (or `None`) back into the current thread's slot. Because that
/// happens on drop, the guard has to be bound to a variable for the reporter to stay
/// installed.
#[must_use = "the reporter is uninstalled as soon as the guard is dropped"]
pub struct ResultReporterGuard {
    previous: Option<ResultReporterFn>,
}

impl Drop for ResultReporterGuard {
    fn drop(&mut self) {
        RESULT_REPORTER.with(|slot| {
            *slot.borrow_mut() = self.previous.take();
        });
    }
}
/// Installs `reporter` as the current thread's result callback and returns a guard that
/// restores the previously installed reporter (or none) when dropped.
///
/// While installed, the callback receives every row passed to `report_result` on this
/// thread.
///
/// NOTE(review): the returned guard carries no lifetime, so the borrow checker does not
/// prevent `reporter` from being dropped or used elsewhere while its pointer is still
/// installed. Callers must keep the closure alive and untouched until the guard is dropped.
pub fn set_result_reporter(reporter: &mut dyn FnMut(&OpResult)) -> ResultReporterGuard {
// The transmute appears to exist only to erase the borrow's lifetime so the pointer can be
// stored in the thread-local slot — TODO confirm. Soundness then depends on the caller
// contract described above.
let reporter: NonNull<dyn FnMut(&OpResult)> =
unsafe { std::mem::transmute(NonNull::from(reporter)) };
// Swap the new pointer in and remember the old one for the guard to restore.
let previous = RESULT_REPORTER.with(|slot| (*slot.borrow_mut()).replace(reporter));
ResultReporterGuard { previous }
}
/// Builds a kernel with `make_ir` and returns the MSL source generated for it.
///
/// On a generator error the error is printed to stderr, prefixed with `[label]`, and an
/// empty string is returned.
pub fn generate_elementwise_msl<F>(make_ir: F, label: &str) -> String
where
    F: Fn() -> Kernel,
{
    match MslGenerator::default().generate(&make_ir()) {
        Ok(msl) => msl,
        Err(e) => {
            eprintln!("[{label}]: {e}");
            String::new()
        },
    }
}
/// Builds a kernel with `make_ir`, switches it to `KernelMode::Reduction`, and returns the
/// MSL source generated for it.
///
/// On a generator error the error is printed to stderr, prefixed with `[label]`, and an
/// empty string is returned.
pub fn generate_reduction_msl<F>(make_ir: F, label: &str) -> String
where
    F: Fn() -> Kernel,
{
    let mut kernel = make_ir();
    kernel.mode = KernelMode::Reduction;
    match MslGenerator::default().generate(&kernel) {
        Ok(msl) => msl,
        Err(e) => {
            eprintln!("[{label}]: {e}");
            String::new()
        },
    }
}
/// Per-dtype values bundled for a benchmark run.
pub struct DtypeCtx {
    /// The dtype itself.
    pub dt: DType,
    /// Long type name, from [`mlx_tname`].
    pub tn: &'static str,
    /// Short label, from [`dtype_label`].
    pub label: &'static str,
    /// Element size in bytes, from [`elem_bytes`].
    pub eb: usize,
    /// Tolerance, from [`dtype_tol`] or [`dtype_tol_reduce`] depending on the constructor.
    pub tol: f32,
}

impl DtypeCtx {
    /// Fills in the dtype-derived fields and stores `tol` as given.
    fn with_tol(dt: DType, tol: f32) -> Self {
        Self { dt, tn: mlx_tname(dt), label: dtype_label(dt), eb: elem_bytes(dt), tol }
    }

    /// Context for reduction kernels; the tolerance comes from [`dtype_tol_reduce`].
    pub fn reduce(dt: DType) -> Self {
        Self::with_tol(dt, dtype_tol_reduce(dt))
    }

    /// Context for element-wise kernels; the tolerance comes from [`dtype_tol`].
    pub fn elementwise(dt: DType) -> Self {
        Self::with_tol(dt, dtype_tol(dt))
    }
}
/// Expands to a `#[cfg(test)] mod tests` for a benchmark module.
///
/// `msl_fn` names a function that is called with each dtype in `FLOAT_DTYPES`; its result
/// must be non-empty after trimming whitespace.
///
/// NOTE(review): the `kernel_name` expression is accepted but not used anywhere in the
/// expansion, and the generated `kernels_compile` test has an empty body.
#[macro_export]
macro_rules! bench_tests {
(msl_fn: $msl_fn:ident, kernel_name: $name:expr) => {
#[cfg(test)]
mod tests {
use super::*;
// Fails when the generator function yields empty or whitespace-only output for any
// floating-point dtype.
#[test]
fn msl_generates_for_all_dtypes() {
for &dt in $crate::bench_types::FLOAT_DTYPES {
let msl = $msl_fn(dt);
assert!(!msl.trim().is_empty(), "MSL empty for {dt:?}");
}
}
// Only built on macOS; currently a placeholder with no assertions.
#[cfg(target_os = "macos")]
#[test]
fn kernels_compile() {
}
}
};
}
#[cfg(test)]
mod tests {
    use super::{CorrectnessStatus, EquivResult, OpBench, OpResult, check_equiv, validate_results};

    /// Builds a row through the public `OpBench::result` path.
    fn sample_result(mt_perf: Option<f64>, equiv: Option<EquivResult>) -> OpResult {
        let bench = OpBench::new("sample", "GB/s");
        bench.result("shape", Some(1.0), mt_perf, equiv)
    }

    /// Builds a row with a measured number but no correctness check. The struct is filled in
    /// directly because `OpBench::result` panics for this combination.
    fn unchecked_row() -> OpResult {
        OpResult {
            op: "sample",
            subop: None,
            shape: "shape".into(),
            metric: "GB/s",
            ref_perf: Some(1.0),
            mt_perf: Some(2.0),
            equiv: None,
            mt_timing: None,
            ref_timing: None,
        }
    }

    /// Builds a hand-written check outcome over 16 elements.
    fn equiv_of(max_abs_err: f32, cosine_sim: f32, passed: bool) -> EquivResult {
        EquivResult { n_checked: 16, max_abs_err, cosine_sim, passed }
    }

    #[test]
    fn correctness_status_distinguishes_unchecked_from_unavailable() {
        let unchecked = unchecked_row();
        let unavailable = sample_result(None, None);

        assert_eq!(unchecked.correctness_status(), CorrectnessStatus::Unchecked);
        assert_eq!(unchecked.correctness_cell(), "! missing-check");
        assert!(unchecked.is_unchecked());

        assert_eq!(unavailable.correctness_status(), CorrectnessStatus::Unavailable);
        assert_eq!(unavailable.correctness_cell(), "—");
    }

    #[test]
    fn check_equiv_reports_cosine_similarity() {
        let reference = [1.0, 2.0, 3.0];
        let candidate = [1.0, 2.0, 3.001];
        let equiv = check_equiv(&reference, &candidate, 1e-2);

        assert_eq!(equiv.n_checked, 3);
        assert!(equiv.passed);
        assert!(equiv.cosine_sim > 0.999_999);
        assert!(equiv.max_abs_err > 0.0);
    }

    #[test]
    fn correctness_status_formats_checked_results() {
        let passed = sample_result(Some(2.0), Some(equiv_of(0.0, 1.0, true)));
        let failed = sample_result(Some(2.0), Some(equiv_of(1.5, 0.5, false)));

        assert_eq!(
            passed.correctness_status(),
            CorrectnessStatus::Passed { max_abs_err: 0.0, cosine_sim: 1.0 }
        );
        assert_eq!(passed.correctness_cell(), "✓");

        assert_eq!(
            failed.correctness_status(),
            CorrectnessStatus::Failed { max_abs_err: 1.5, cosine_sim: 0.5 }
        );
        assert_eq!(failed.correctness_cell(), "✗ 1.50e0 cos=0.500");
    }

    #[test]
    #[should_panic(expected = "missing correctness")]
    fn op_bench_rejects_implemented_row_without_correctness() {
        let _ = OpBench::new("sample", "GB/s").result("shape", Some(1.0), Some(2.0), None);
    }

    #[test]
    fn validation_reports_unchecked_rows() {
        let rows = [unchecked_row()];
        let err = validate_results(&rows).expect_err("unchecked rows should fail");
        assert!(err.contains("sample [shape]"));
    }
}