/// Micro-benchmark comparing CPU and Metal execution of individual graph ops.
///
/// Each case builds a small graph (one op, or a few for SwiGLU), compiles it
/// once per device, performs `warmup` untimed `run()` calls and then `iters`
/// timed ones, and prints the median and p95 latency (µs) per device together
/// with the CPU-median / Metal-median ratio.
///
/// The number of timed iterations comes from `RLX_BENCH_ITERS` (default 50,
/// clamped to at least 1 so the percentile lookups always have a sample).
#[cfg(all(feature = "metal", target_os = "macos"))]
fn main() {
    use rlx_ir::infer::GraphExt;
    use rlx_ir::op::{Activation, BinaryOp, ReduceOp};
    use rlx_ir::*;
    use rlx_runtime::{Device, Session};
    use std::time::Instant;

    let warmup = 5usize;
    // Clamp to >= 1: with zero iterations the sample vector is empty and the
    // median/p95 indexing in `time_one` would panic (and `len() - 1` underflow).
    let iters: usize = rlx_ir::env::var("RLX_BENCH_ITERS")
        .and_then(|v| v.parse::<usize>().ok())
        .unwrap_or(50)
        .max(1);

    /// Compiles `build()` on `device`, binds `params`, then returns the
    /// `(median, p95)` latency in microseconds of `compiled.run(inputs)`
    /// measured over `iters` timed runs that follow `warmup` untimed runs.
    ///
    /// # Panics
    /// Panics if `iters == 0` (there would be no samples to summarise).
    fn time_one(
        device: Device,
        build: impl Fn() -> Graph,
        params: &[(&str, Vec<f32>)],
        inputs: &[(&str, Vec<f32>)],
        warmup: usize,
        iters: usize,
    ) -> (f64, f64) {
        assert!(iters > 0, "time_one needs at least one timed iteration");
        let session = Session::new(device);
        let mut compiled = session.compile(build());
        for (n, d) in params {
            compiled.set_param(n, d);
        }
        let cpu_inputs: Vec<(&str, &[f32])> =
            inputs.iter().map(|(n, v)| (*n, v.as_slice())).collect();
        for _ in 0..warmup {
            let _ = compiled.run(&cpu_inputs);
        }
        let mut samples_us: Vec<f64> = Vec::with_capacity(iters);
        for _ in 0..iters {
            let t0 = Instant::now();
            let _ = compiled.run(&cpu_inputs);
            samples_us.push(t0.elapsed().as_secs_f64() * 1e6);
        }
        // `total_cmp` is a total order, so unlike `partial_cmp(..).unwrap()`
        // it cannot panic if a NaN ever slips in.
        samples_us.sort_unstable_by(f64::total_cmp);
        let last = samples_us.len() - 1;
        let median = samples_us[samples_us.len() / 2];
        let p95 = samples_us[(samples_us.len() * 95 / 100).min(last)];
        (median, p95)
    }

    let f = DType::F32;
    // Deterministic pseudo-random fill so every run benchmarks identical data.
    let det = |n: usize, mul: usize, modulo: usize, offset: f32| -> Vec<f32> {
        (0..n)
            .map(|i| ((i * mul + 7) % modulo) as f32 / modulo as f32 + offset)
            .collect()
    };

    println!(
        "{:<40} {:>14} {:>14} {:>14} {:>14} {:>10}",
        "Op (shape)", "CPU med (µs)", "CPU p95", "Metal med", "Metal p95", "speedup"
    );
    println!("{:-<116}", "");

    // Times one case on CPU and on Metal and prints a table row.
    let bench = |name: String,
                 build: Box<dyn Fn() -> Graph>,
                 params: Vec<(&str, Vec<f32>)>,
                 inputs: Vec<(&str, Vec<f32>)>| {
        let (cpu_med, cpu_p95) = time_one(Device::Cpu, &build, &params, &inputs, warmup, iters);
        let (m_med, m_p95) = time_one(Device::Metal, &build, &params, &inputs, warmup, iters);
        let speedup = cpu_med / m_med;
        println!(
            "{name:<40} {cpu_med:>14.2} {cpu_p95:>14.2} {m_med:>14.2} {m_p95:>14.2} {speedup:>9.2}x"
        );
    };

    bench(
        "MatMul (60, 768) × (768, 2304)".into(),
        Box::new(|| {
            let mut g = Graph::new("mm");
            let x = g.input("x", Shape::new(&[60, 768], f));
            let w = g.param("w", Shape::new(&[768, 2304], f));
            let y = g.matmul(x, w, Shape::new(&[60, 2304], f));
            g.set_outputs(vec![y]);
            g
        }),
        vec![("w", det(768 * 2304, 17, 31, 0.001))],
        vec![("x", det(60 * 768, 13, 23, 0.001))],
    );
    bench(
        "MatMul (12, 3072) × (3072, 768)".into(),
        Box::new(|| {
            let mut g = Graph::new("mm");
            let x = g.input("x", Shape::new(&[12, 3072], f));
            let w = g.param("w", Shape::new(&[3072, 768], f));
            let y = g.matmul(x, w, Shape::new(&[12, 768], f));
            g.set_outputs(vec![y]);
            g
        }),
        vec![("w", det(3072 * 768, 17, 31, 0.001))],
        vec![("x", det(12 * 3072, 13, 23, 0.001))],
    );

    let elem_shape = [60usize, 768];
    let elem_n = elem_shape[0] * elem_shape[1];
    for (act_name, act) in &[
        ("gelu", Activation::Gelu),
        ("silu", Activation::Silu),
        ("relu", Activation::Relu),
    ] {
        let act = *act;
        bench(
            format!("Activation::{act_name} (60, 768)"),
            Box::new(move || {
                let mut g = Graph::new("act");
                let x = g.input("x", Shape::new(&[60, 768], f));
                let y = g.activation(act, x, Shape::new(&[60, 768], f));
                g.set_outputs(vec![y]);
                g
            }),
            vec![],
            vec![("x", det(elem_n, 13, 23, 0.1))],
        );
    }

    bench(
        "Binary::Add (60, 768)".into(),
        Box::new(|| {
            let mut g = Graph::new("bin");
            let x = g.input("x", Shape::new(&[60, 768], f));
            let y = g.input("y", Shape::new(&[60, 768], f));
            let z = g.binary(BinaryOp::Add, x, y, Shape::new(&[60, 768], f));
            g.set_outputs(vec![z]);
            g
        }),
        vec![],
        vec![
            ("x", det(elem_n, 7, 11, 0.1)),
            ("y", det(elem_n, 13, 19, 0.1)),
        ],
    );
    bench(
        "LayerNorm (60, 768)".into(),
        Box::new(|| {
            let mut g = Graph::new("ln");
            let x = g.input("x", Shape::new(&[60, 768], f));
            let gamma = g.param("g", Shape::new(&[768], f));
            let beta = g.param("b", Shape::new(&[768], f));
            let y = g.ln(x, gamma, beta, 1e-5);
            g.set_outputs(vec![y]);
            g
        }),
        vec![("g", vec![1.0; 768]), ("b", vec![0.0; 768])],
        vec![("x", det(elem_n, 13, 23, 0.5))],
    );
    bench(
        "RmsNorm (60, 768)".into(),
        Box::new(|| {
            let mut g = Graph::new("rms");
            let x = g.input("x", Shape::new(&[60, 768], f));
            let gamma = g.param("g", Shape::new(&[768], f));
            let beta = g.param("b", Shape::new(&[768], f));
            let y = g.rms_norm(x, gamma, beta, 1e-5);
            g.set_outputs(vec![y]);
            g
        }),
        vec![("g", vec![1.0; 768]), ("b", vec![0.0; 768])],
        vec![("x", det(elem_n, 13, 23, 0.5))],
    );
    bench(
        "Reduce::Sum (60, 768) → (60,)".into(),
        Box::new(|| {
            let mut g = Graph::new("red");
            let x = g.input("x", Shape::new(&[60, 768], f));
            let y = g.add_node(
                Op::Reduce {
                    op: ReduceOp::Sum,
                    axes: vec![1],
                    keep_dim: false,
                },
                vec![x],
                Shape::new(&[60], f),
            );
            g.set_outputs(vec![y]);
            g
        }),
        vec![],
        vec![("x", det(elem_n, 7, 17, 0.1))],
    );
    bench(
        "Softmax (60, 768)".into(),
        Box::new(|| {
            let mut g = Graph::new("sm");
            let x = g.input("x", Shape::new(&[60, 768], f));
            let y = g.add_node(Op::Softmax { axis: -1 }, vec![x], Shape::new(&[60, 768], f));
            g.set_outputs(vec![y]);
            g
        }),
        vec![],
        vec![("x", det(elem_n, 7, 17, 0.1))],
    );
    bench(
        "Transpose (768, 60) → (60, 768)".into(),
        Box::new(|| {
            let mut g = Graph::new("tr");
            let x = g.input("x", Shape::new(&[768, 60], f));
            let y = g.add_node(
                Op::Transpose { perm: vec![1, 0] },
                vec![x],
                Shape::new(&[60, 768], f),
            );
            g.set_outputs(vec![y]);
            g
        }),
        vec![],
        vec![("x", det(elem_n, 7, 17, 0.1))],
    );
    bench(
        "Gather axis=0 (vocab=30k, idx=60)".into(),
        Box::new(|| {
            let mut g = Graph::new("g0");
            let table = g.param("t", Shape::new(&[30000, 768], f));
            let idx = g.input("idx", Shape::new(&[60], f));
            let y = g.gather_(table, idx, 0);
            g.set_outputs(vec![y]);
            g
        }),
        vec![("t", det(30000 * 768, 7, 31, 0.001))],
        // Indices are passed as f32 values, matching the F32 `idx` input shape.
        vec![("idx", (0..60).map(|i| (i * 17 % 30000) as f32).collect())],
    );
    bench(
        "Concat last-axis (60, 768) ⊕ (60, 768)".into(),
        Box::new(|| {
            let mut g = Graph::new("ct");
            let a = g.param("a", Shape::new(&[60, 768], f));
            let b = g.param("b", Shape::new(&[60, 768], f));
            let y = g.add_node(
                Op::Concat { axis: 1 },
                vec![a, b],
                Shape::new(&[60, 1536], f),
            );
            g.set_outputs(vec![y]);
            g
        }),
        vec![
            ("a", det(elem_n, 7, 17, 0.1)),
            ("b", det(elem_n, 11, 19, 0.1)),
        ],
        vec![],
    );
    bench(
        "Attention SDPA (b=1, s=15, h=768)".into(),
        Box::new(|| {
            let mut g = Graph::new("attn");
            let nh = 12;
            let dh = 64;
            let h = nh * dh;
            let q = g.input("q", Shape::new(&[1, 15, h], f));
            let k = g.input("k", Shape::new(&[1, 15, h], f));
            let v = g.input("v", Shape::new(&[1, 15, h], f));
            let mask = g.input("mask", Shape::new(&[1, 15], f));
            let y = g.attention_(q, k, v, mask, nh, dh);
            g.set_outputs(vec![y]);
            g
        }),
        vec![],
        vec![
            ("q", det(15 * 768, 5, 13, 0.1)),
            ("k", det(15 * 768, 7, 17, 0.1)),
            ("v", det(15 * 768, 11, 19, 0.1)),
            ("mask", vec![1.0; 15]),
        ],
    );
    bench(
        "SwiGLU (60, 768) → (60, 2048)".into(),
        Box::new(|| {
            let m = 60usize;
            let k = 768usize;
            let n = 2048usize;
            let mut g = Graph::new("swiglu");
            let x = g.input("x", Shape::new(&[m, k], f));
            let w_up = g.param("w_up", Shape::new(&[k, n], f));
            let w_gate = g.param("w_gate", Shape::new(&[k, n], f));
            let up_mm = g.matmul(x, w_up, Shape::new(&[m, n], f));
            let gate_mm = g.matmul(x, w_gate, Shape::new(&[m, n], f));
            let gate = g.activation(Activation::Silu, gate_mm, Shape::new(&[m, n], f));
            let y = g.binary(BinaryOp::Mul, up_mm, gate, Shape::new(&[m, n], f));
            g.set_outputs(vec![y]);
            g
        }),
        vec![
            ("w_up", det(768 * 2048, 11, 31, 0.001)),
            ("w_gate", det(768 * 2048, 13, 29, 0.001)),
        ],
        vec![("x", det(60 * 768, 7, 23, 0.1))],
    );

    println!();
    println!("(all timings include the run() overhead — graph dispatch, encoder");
    println!(" setup on Metal, output read-back. Use RLX_BENCH_ITERS=N to vary.)");
}
/// Stand-in entry point used when the Metal backend is unavailable: it only
/// writes a hint to stderr describing the required build configuration.
#[cfg(not(all(feature = "metal", target_os = "macos")))]
fn main() {
    const HINT: &str = "op_bench requires --features metal on macOS";
    eprintln!("{HINT}");
}