use std::path::PathBuf;
use rlx_ir::Tick;
use serde::{Deserialize, Serialize};
use crate::array::{Array, MlxError, device_name, eval};
use crate::ffi::{MlxMask, MlxReduce};
use crate::ops;
/// Per-device performance figures gathered by `Calibration::measure` and
/// cached on disk as JSON (one file per device, see `cache_path`).
///
/// Units are encoded in the field names (`_flops` = FLOP/s, `_gbps` = GB/s,
/// `_ns` = nanoseconds) and spelled out on each field below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Calibration {
/// Name reported by `device_name()`; `Calibration::load` only accepts a
/// cached file whose stored name matches the requested device.
pub device_name: String,
/// f32 matmul throughput in FLOP/s for a 256x768 by 768x3072 product
/// (50 products built and evaluated as one batch).
pub sgemm_large_flops: f64,
/// f32 matmul throughput in FLOP/s for an 8x64 by 64x64 product
/// (50 products built and evaluated as one batch).
pub sgemm_small_flops: f64,
/// Mean wall time in nanoseconds of building and synchronously evaluating a
/// 1x2 by 2x1 matmul, one call at a time; approximates per-call overhead.
pub roundtrip_overhead_ns: f64,
/// Effective bandwidth in GB/s of adding a scalar to 1M f32 elements,
/// counting one read and one write per element.
/// Absent from older cache files, in which case serde fills in `0.0`.
#[serde(default)]
pub memory_bw_gbps: f64,
/// Unmasked attention throughput in FLOP/s for B=1, H=4, S=128, D=64,
/// counting 4*B*H*S*S*D FLOPs per call.
/// Absent from older cache files, in which case serde fills in `0.0`.
#[serde(default)]
pub attention_flops: f64,
/// Bandwidth in GB/s of summing each row (axis 1) of a 1024x1024 f32
/// matrix, counting only the input bytes.
/// Absent from older cache files, in which case serde fills in `0.0`.
#[serde(default)]
pub reduce_gbps: f64,
}
/// Makes `s` safe to embed in a file name by replacing every character that
/// is not an ASCII letter or digit with `_` (one `_` per `char`).
fn sanitize(s: &str) -> String {
    let mut cleaned = String::with_capacity(s.len());
    for ch in s.chars() {
        let keep = ch.is_ascii_alphanumeric();
        cleaned.push(if keep { ch } else { '_' });
    }
    cleaned
}
/// Returns the calibration cache file for `device_name`:
/// `$HOME/.cache/rlx/mlx-calib-<key>.json`, where `<key>` is the sanitized
/// device name, or `default` when the name is empty. If `HOME` is unset or
/// not valid Unicode, the current directory is used as the base.
///
/// Side effect: tries to create the cache directory. Failure is ignored here
/// and shows up later as a read/write error on the returned path.
fn cache_path(device_name: &str) -> PathBuf {
    let base = match std::env::var("HOME") {
        Ok(home) => PathBuf::from(home),
        Err(_) => PathBuf::from("."),
    };
    let cache_dir = base.join(".cache").join("rlx");
    let _ = std::fs::create_dir_all(&cache_dir);
    let key = match device_name {
        "" => String::from("default"),
        other => sanitize(other),
    };
    cache_dir.join(format!("mlx-calib-{key}.json"))
}
impl Calibration {
pub fn load(name: &str) -> Option<Self> {
let raw = std::fs::read_to_string(cache_path(name)).ok()?;
let cal: Calibration = serde_json::from_str(&raw).ok()?;
if cal.device_name == name {
Some(cal)
} else {
None
}
}
pub fn save(&self) -> std::io::Result<()> {
let path = cache_path(&self.device_name);
let raw = serde_json::to_string_pretty(self).map_err(std::io::Error::other)?;
std::fs::write(path, raw)
}
pub fn measure() -> Result<Self, MlxError> {
let name = device_name();
let measure_matmul = |m: usize, k: usize, n: usize| -> Result<f64, MlxError> {
let lhs_data = (0..m * k).map(|i| (i as f32) / 257.0).collect::<Vec<_>>();
let rhs_data = (0..k * n)
.map(|i| ((i + 7) as f32) / 257.0)
.collect::<Vec<_>>();
let lhs = Array::from_f32_slice(&lhs_data, &[m, k], rlx_ir::DType::F32)?;
let rhs = Array::from_f32_slice(&rhs_data, &[k, n], rlx_ir::DType::F32)?;
let warm = ops::matmul(&lhs, &rhs)?;
eval(&[&warm])?;
const N_ITER: usize = 50;
let t0 = Tick::now();
let mut outs: Vec<Array> = Vec::with_capacity(N_ITER);
for _ in 0..N_ITER {
outs.push(ops::matmul(&lhs, &rhs)?);
}
let refs: Vec<&Array> = outs.iter().collect();
eval(&refs)?;
let total_ns = Tick::now().elapsed_ns(t0) as f64;
let total_s = total_ns / 1e9;
let flops = 2.0 * m as f64 * k as f64 * n as f64 * N_ITER as f64;
Ok(flops / total_s)
};
let sgemm_large = measure_matmul(256, 768, 3072)?;
let sgemm_small = measure_matmul(8, 64, 64)?;
let roundtrip_ns = {
let lhs = Array::from_f32_slice(&[1.0, 2.0], &[1, 2], rlx_ir::DType::F32)?;
let rhs = Array::from_f32_slice(&[3.0, 4.0], &[2, 1], rlx_ir::DType::F32)?;
const N_ITER: usize = 10;
let t0 = Tick::now();
for _ in 0..N_ITER {
let y = ops::matmul(&lhs, &rhs)?;
eval(&[&y])?;
}
Tick::now().elapsed_ns(t0) as f64 / N_ITER as f64
};
let memory_bw_gbps = {
const N: usize = 1024 * 1024;
let data: Vec<f32> = (0..N).map(|i| i as f32 * 0.001).collect();
let a = Array::from_f32_slice(&data, &[N], rlx_ir::DType::F32)?;
let zero = Array::from_f32_slice(&[0.0], &[1], rlx_ir::DType::F32)?;
let warm = ops::add(&a, &zero)?;
eval(&[&warm])?;
const N_ITER: usize = 20;
let t0 = Tick::now();
let mut outs = Vec::with_capacity(N_ITER);
for _ in 0..N_ITER {
outs.push(ops::add(&a, &zero)?);
}
let refs: Vec<&Array> = outs.iter().collect();
eval(&refs)?;
let total_ns = Tick::now().elapsed_ns(t0) as f64;
let bytes_per_iter = (N * 4 * 2) as f64;
(bytes_per_iter * N_ITER as f64) / total_ns
};
let attention_flops = {
const B: usize = 1;
const H: usize = 4;
const S: usize = 128;
const D: usize = 64;
let q_data: Vec<f32> = (0..B * H * S * D)
.map(|i| (i as f32 % 17.0) * 0.01)
.collect();
let q = Array::from_f32_slice(&q_data, &[B, H, S, D], rlx_ir::DType::F32)?;
let k = Array::from_f32_slice(&q_data, &[B, H, S, D], rlx_ir::DType::F32)?;
let v = Array::from_f32_slice(&q_data, &[B, H, S, D], rlx_ir::DType::F32)?;
let scale = 1.0 / (D as f32).sqrt();
let warm = ops::attention(&q, &k, &v, scale, MlxMask::None, None)?;
eval(&[&warm])?;
const N_ITER: usize = 20;
let t0 = Tick::now();
let mut outs = Vec::with_capacity(N_ITER);
for _ in 0..N_ITER {
outs.push(ops::attention(&q, &k, &v, scale, MlxMask::None, None)?);
}
let refs: Vec<&Array> = outs.iter().collect();
eval(&refs)?;
let total_ns = Tick::now().elapsed_ns(t0) as f64;
let total_s = total_ns / 1e9;
let flops_per_iter = 4.0 * B as f64 * H as f64 * S as f64 * S as f64 * D as f64;
(flops_per_iter * N_ITER as f64) / total_s
};
let reduce_gbps = {
const M: usize = 1024;
const N: usize = 1024;
let data: Vec<f32> = (0..M * N).map(|i| i as f32 * 0.001).collect();
let a = Array::from_f32_slice(&data, &[M, N], rlx_ir::DType::F32)?;
let warm = ops::reduce(&a, MlxReduce::Sum, &[1], false)?;
eval(&[&warm])?;
const N_ITER: usize = 20;
let t0 = Tick::now();
let mut outs = Vec::with_capacity(N_ITER);
for _ in 0..N_ITER {
outs.push(ops::reduce(&a, MlxReduce::Sum, &[1], false)?);
}
let refs: Vec<&Array> = outs.iter().collect();
eval(&refs)?;
let total_ns = Tick::now().elapsed_ns(t0) as f64;
let bytes_per_iter = (M * N * 4) as f64;
(bytes_per_iter * N_ITER as f64) / total_ns
};
Ok(Calibration {
device_name: name,
sgemm_large_flops: sgemm_large,
sgemm_small_flops: sgemm_small,
roundtrip_overhead_ns: roundtrip_ns,
memory_bw_gbps,
attention_flops,
reduce_gbps,
})
}
pub fn load_or_measure() -> Self {
let name = device_name();
if let Some(cal) = Self::load(&name) {
return cal;
}
let verbose = rlx_ir::env::var("RLX_VERBOSE")
.and_then(|v| v.parse::<u8>().ok())
.unwrap_or(0)
>= 1;
if verbose {
eprintln!("[rlx-mlx] no calibration cache for '{name}'; measuring...");
}
let cal = Self::measure().unwrap_or_else(|e| {
if verbose {
eprintln!("[rlx-mlx] calibration failed: {e}; using conservative defaults");
}
Calibration {
device_name: name.clone(),
sgemm_large_flops: 1.0e12,
sgemm_small_flops: 5.0e10,
roundtrip_overhead_ns: 200_000.0,
memory_bw_gbps: 200.0,
attention_flops: 5.0e11,
reduce_gbps: 150.0,
}
});
if verbose {
eprintln!(
"[rlx-mlx] calibrated: large={:.0} GF/s, small={:.0} GF/s, \
rt={:.0}µs, mem={:.0} GB/s, attn={:.0} GF/s, reduce={:.0} GB/s",
cal.sgemm_large_flops / 1e9,
cal.sgemm_small_flops / 1e9,
cal.roundtrip_overhead_ns / 1000.0,
cal.memory_bw_gbps,
cal.attention_flops / 1e9,
cal.reduce_gbps
);
}
let _ = cal.save();
cal
}
}