#[cfg(all(feature = "cpu", feature = "metal", target_os = "macos"))]
fn main() {
use rlx_ir::*;
use rlx_runtime::{Device, Session};
let vocab = 32;
let max_seq = 16;
let tts = 2;
let b = 4;
let s = 8;
let h = 384;
let build = || {
let mut g = Graph::new("embed");
let ids = g.input("ids", Shape::new(&[b, s], DType::F32));
let pos = g.input("pos", Shape::new(&[b, s], DType::F32));
let tt = g.input("tt", Shape::new(&[b, s], DType::F32));
let we = g.param("we", Shape::new(&[vocab, h], DType::F32));
let pe = g.param("pe", Shape::new(&[max_seq, h], DType::F32));
let te = g.param("te", Shape::new(&[tts, h], DType::F32));
let lng = g.param("lng", Shape::new(&[h], DType::F32));
let lnb = g.param("lnb", Shape::new(&[h], DType::F32));
let we_o = g.gather_(we, ids, 0);
let pe_o = g.gather_(pe, pos, 0);
let te_o = g.gather_(te, tt, 0);
let s1 = g.add(we_o, pe_o);
let s2 = g.add(s1, te_o);
g.set_outputs(vec![s2]);
let _ = (lng, lnb);
g
};
let we_data: Vec<f32> = (0..vocab * h)
.map(|i| ((i as f32) * 0.001).sin() * 0.1)
.collect();
let pe_data: Vec<f32> = (0..max_seq * h)
.map(|i| ((i as f32) * 0.002).cos() * 0.1)
.collect();
let te_data: Vec<f32> = (0..tts * h)
.map(|i| ((i as f32) * 0.003).sin() * 0.1)
.collect();
let ids_data: Vec<f32> = (0..b * s).map(|i| ((i * 7 + 13) % vocab) as f32).collect();
let pos_data: Vec<f32> = (0..b).flat_map(|_| (0..s).map(|j| j as f32)).collect();
let tt_data: Vec<f32> = (0..b * s).map(|i| (i % 2) as f32).collect();
let lng_data = vec![1.0f32; h];
let lnb_data = vec![0.0f32; h];
let run_with = |use_mpsg: bool, dev: Device| -> Vec<f32> {
if use_mpsg {
rlx_ir::env::set("RLX_USE_MPSGRAPH", "1");
} else {
unsafe {
rlx_ir::env::unset("RLX_USE_MPSGRAPH");
}
}
let session = Session::new(dev);
let mut compiled = session.compile(build());
compiled.set_param("we", &we_data);
compiled.set_param("pe", &pe_data);
compiled.set_param("te", &te_data);
compiled.set_param("lng", &lng_data);
compiled.set_param("lnb", &lnb_data);
let outs = compiled.run(&[("ids", &ids_data), ("pos", &pos_data), ("tt", &tt_data)]);
outs.into_iter().next().unwrap_or_default()
};
let cpu = run_with(false, Device::Cpu);
let metal_thunk = run_with(false, Device::Metal);
let metal_mpsg = run_with(true, Device::Metal);
println!("CPU thunk[..6]: {:?}", &cpu[..6]);
println!("Metal thunk[..6]: {:?}", &metal_thunk[..6]);
println!("Metal MPSG[..6]: {:?}", &metal_mpsg[..6]);
let diff = |a: &[f32], b: &[f32]| -> f32 {
a.iter()
.zip(b)
.map(|(x, y)| (x - y).abs())
.fold(0f32, f32::max)
};
println!(
"\nmax_err CPU vs Metal-thunk: {:.3e}",
diff(&cpu, &metal_thunk)
);
println!(
"max_err CPU vs Metal-MPSG: {:.3e}",
diff(&cpu, &metal_mpsg)
);
}
#[cfg(not(all(feature = "cpu", feature = "metal", target_os = "macos")))]
fn main() {
eprintln!("requires cpu + metal on macOS");
}