/// Shows that `AutoMixedPrecision` works the same way whether a graph is built
/// ahead-of-time (explicit `Graph` builder) or just-in-time (`trace`).
///
/// The same `gelu(x @ w + b)` graph is built both ways. First, the AOT graph is
/// dumped before and after running the `AutoMixedPrecision` pass directly.
/// Then every combination below is compiled and run with fixed parameters and
/// input, and the first output of each run is printed:
/// - build mode: AOT, JIT
/// - device: CPU, Metal
/// - precision policy: AlwaysF32, AutoMixed
#[cfg(all(feature = "metal", target_os = "macos"))]
fn main() {
    use rlx_ir::*;
    use rlx_runtime::{Device, Precision, PrecisionPolicy, Session};

    // Fixed parameters and input shared by every run.
    // NOTE(review): assuming row-major layout, `w` (4x3) passes the first three
    // input features through unchanged and has a zero last row, and `x` (2x4)
    // holds two one-hot rows. Confirm against the runtime's layout convention.
    let w_data = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0];
    let b_data = vec![0.5, -0.5, 0.0];
    let x_data = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];

    // AOT: build the graph node by node, stating every output shape explicitly.
    let aot_graph = || {
        let mut g = Graph::new("aot");
        let x = g.input("x", Shape::new(&[2, 4], DType::F32));
        let w = g.param("w", Shape::new(&[4, 3], DType::F32));
        let b = g.param("b", Shape::new(&[3], DType::F32));
        let mm = g.matmul(x, w, Shape::new(&[2, 3], DType::F32));
        let bias = g.binary(op::BinaryOp::Add, mm, b, Shape::new(&[2, 3], DType::F32));
        let out = g.activation(op::Activation::Gelu, bias, Shape::new(&[2, 3], DType::F32));
        g.set_outputs(vec![out]);
        g
    };

    // JIT: trace the same computation with operator syntax. Only the
    // input/param shapes are given here; no output shapes are passed.
    let jit_graph = || {
        use rlx_runtime::trace::trace;
        trace("jit", |t| {
            let x = t.input("x", &[2, 4], DType::F32);
            let w = t.param("w", &[4, 3], DType::F32);
            let b = t.param("b", &[3], DType::F32);
            let mm = t.matmul(x, w);
            let bias = mm + b;
            let out = bias.gelu();
            vec![out]
        })
    };

    // Run the rewrite directly on the AOT graph, outside `Session`, and show
    // how the node dtypes change.
    {
        use rlx_opt::pass::Pass;

        // Both dumps share this closure so the two listings are guaranteed
        // to use the same format.
        let dump_nodes = |g: &Graph| {
            for n in g.nodes() {
                println!(
                    " [{}] {} {:?} → {:?}",
                    n.id,
                    n.op,
                    n.inputs,
                    n.shape.dtype()
                );
            }
        };

        let g = aot_graph();
        println!("Original graph:");
        dump_nodes(&g);

        let pass = rlx_opt::AutoMixedPrecision::new(PrecisionPolicy::AutoMixed);
        let g2 = pass.run(g);
        println!("\nAfter AutoMixedPrecision:");
        dump_nodes(&g2);
        println!(" outputs: {:?}", g2.outputs);
        println!();
    }

    println!("Comparing AOT vs JIT × precision policies × devices:\n");
    // Every header cell is a plain string, so use `{}`. With `{:?}`, the last
    // column would print as `"output"` with literal quote marks.
    println!(
        "{:<6} {:<6} {:<12} {}",
        "mode", "device", "policy", "output"
    );
    println!("{}", "-".repeat(80));

    // Both builders are non-capturing closures, so they coerce to `fn`
    // pointers and can share one array type.
    let modes: [(&str, fn() -> Graph); 2] = [("AOT", aot_graph), ("JIT", jit_graph)];
    for &(mode, build_graph) in &modes {
        for &(dev_name, dev) in &[("CPU", Device::Cpu), ("Metal", Device::Metal)] {
            for (policy_name, policy) in [
                ("F32", PrecisionPolicy::AlwaysF32),
                ("AutoMixed", PrecisionPolicy::AutoMixed),
            ] {
                // Every session starts from F32. The policy decides whether
                // the mixed-precision rewrite is applied during compilation.
                let session = Session::new_with_precision(dev, Precision::F32).with_policy(policy);
                let mut compiled = session.compile(build_graph());
                compiled.set_param("w", &w_data);
                compiled.set_param("b", &b_data);
                let out = compiled.run(&[("x", &x_data)]);
                println!(
                    "{:<6} {:<6} {:<12} {:?}",
                    mode, dev_name, policy_name, out[0]
                );
            }
        }
    }

    println!("\nAll modes use the same Session + PrecisionPolicy API — the");
    println!("AutoMixedPrecision pass runs as a graph rewrite regardless of");
    println!("how the graph was built (AOT, JIT, or proc-macro AOT).");
}
/// Fallback entry point for builds that lack the `metal` feature or do not
/// target macOS.
///
/// Instead of running the example, it writes the build requirement to stderr.
#[cfg(not(all(feature = "metal", target_os = "macos")))]
fn main() {
    let requirement = "requires --features metal on macOS";
    eprintln!("{}", requirement);
}