use rlx_ir::op::{Activation, BinaryOp};
use rlx_ir::*;
use rlx_runtime::{CompileCache, Device, Session};
use std::collections::HashSet;
use std::time::Instant;
/// Builds the "toy" graph for a given sequence length.
///
/// The graph computes `matmul(silu(matmul(x, w1)), w2) + x`, where `x` is a
/// `[seq, 256]` input and `w1` / `w2` are `[256, 1024]` / `[1024, 256]`
/// parameters. Every tensor is declared as `F32`, and the residual sum is the
/// graph's only output.
fn build(seq: usize) -> Graph {
    const HIDDEN: usize = 256;
    const INTERMEDIATE: usize = 1024;

    let dtype = DType::F32;
    // Only two activation shapes occur in this graph; name them once.
    let narrow = || Shape::new(&[seq, HIDDEN], dtype);
    let wide = || Shape::new(&[seq, INTERMEDIATE], dtype);

    let mut graph = Graph::new("toy");

    // Graph input followed by the two weight matrices.
    let input = graph.input("x", narrow());
    let up_weight = graph.param("w1", Shape::new(&[HIDDEN, INTERMEDIATE], dtype));
    let down_weight = graph.param("w2", Shape::new(&[INTERMEDIATE, HIDDEN], dtype));

    // Up-projection -> SiLU -> down-projection.
    let up = graph.matmul(input, up_weight, wide());
    let gated = graph.activation(Activation::Silu, up, wide());
    let down = graph.matmul(gated, down_weight, narrow());

    // Residual connection back onto the input.
    let residual = graph.binary(BinaryOp::Add, down, input, narrow());

    graph.set_outputs(vec![residual]);
    graph
}
/// Produces `n` deterministic `f32` values for use as reproducible test data.
///
/// Element `i` is `(i mod 17) * 0.01 - 0.05`, with the modulo taken in `f32`,
/// so the sequence repeats with period 17 and stays within roughly
/// `[-0.05, 0.11]`.
fn det(n: usize) -> Vec<f32> {
    let mut values = Vec::with_capacity(n);
    for i in 0..n {
        // Position inside the 17-element repeating pattern.
        let phase = i as f32 % 17.0;
        values.push(phase * 0.01 - 0.05);
    }
    values
}
/// Times 200 runs of the toy graph over mixed sequence lengths, twice:
/// once building a fresh `Session` and compiling on every call, and once
/// going through a `CompileCache` keyed by sequence length. Prints the totals,
/// the per-call averages and the resulting speedup.
#[cfg(target_os = "macos")]
fn main() {
    // Call schedule: 8, 16, 32, 64, 8, 16, ... for 200 calls in total.
    let calls: Vec<usize> = [8usize, 16, 32, 64]
        .iter()
        .copied()
        .cycle()
        .take(200)
        .collect();
    let n_calls = calls.len() as f64;
    let ms_per_call = |total_secs: f64| total_secs * 1000.0 / n_calls;

    // Dimensions must match the ones hard-coded in `build`.
    let hidden = 256;
    let intermediate = 1024;
    let w1 = det(hidden * intermediate);
    let w2 = det(intermediate * hidden);

    // Baseline: new session, fresh compile and parameter upload on every call.
    let baseline_start = Instant::now();
    for seq in calls.iter().copied() {
        let session = Session::new(Device::Metal);
        let mut compiled = session.compile(build(seq));
        compiled.set_param("w1", &w1);
        compiled.set_param("w2", &w2);
        let x_data = det(seq * hidden);
        let _ = compiled.run(&[("x", &x_data)]);
    }
    let no_cache = baseline_start.elapsed();

    // Cached: one entry per distinct sequence length.
    // NOTE(review): the `8` is presumably the cache capacity — confirm
    // against `CompileCache::new`.
    let mut cache = CompileCache::new(Device::Metal, 8);
    let mut params_loaded: HashSet<u64> = HashSet::new();
    let cached_start = Instant::now();
    for seq in calls.iter().copied() {
        let key = seq as u64;
        let compiled = cache.get_or_compile(key, || build(seq));
        // `insert` returns true only the first time a key is seen, so the
        // weights are uploaded once per distinct sequence length.
        if params_loaded.insert(key) {
            compiled.set_param("w1", &w1);
            compiled.set_param("w2", &w2);
        }
        let x_data = det(seq * hidden);
        let _ = compiled.run(&[("x", &x_data)]);
    }
    let with_cache = cached_start.elapsed();

    println!("Calls: {} (mixed seq lengths: 8, 16, 32, 64)", calls.len());
    println!(
        " no cache : {:>8} ms total ({:.2} ms/call)",
        no_cache.as_millis(),
        ms_per_call(no_cache.as_secs_f64())
    );
    println!(
        " with cache : {:>8} ms total ({:.2} ms/call)",
        with_cache.as_millis(),
        ms_per_call(with_cache.as_secs_f64())
    );

    let speedup = no_cache.as_secs_f64() / with_cache.as_secs_f64();
    println!(
        " speedup : {speedup:.2}x (cached unique shapes: {})",
        cache.len()
    );
    println!();
    println!("(no-cache repeats compile + memory-plan + arena alloc + param");
    println!(" copy on every call. Cache hits skip all of that — the call");
    println!(" cost collapses to encode + commit + wait.)");
}
/// Fallback entry point for non-macOS targets: prints a notice to stderr and
/// returns without running the benchmark.
#[cfg(not(target_os = "macos"))]
fn main() {
    const UNSUPPORTED: &str = "compile_cache_demo requires macOS";
    eprintln!("{UNSUPPORTED}");
}