#[cfg(feature = "cpu")]
fn main() {
    use rlx_ir::*;
    use rlx_runtime::{Device, Precision, PrecisionPolicy, Session};

    // Reference GELU (tanh approximation), evaluated in f64 and used only to build
    // the expected outputs. Against exact (erf) GELU it differs by well under 1e-3 on
    // the inputs used here.
    fn gelu_ref(x: f64) -> f64 {
        let k = (2.0 / std::f64::consts::PI).sqrt();
        0.5 * x * (1.0 + (k * (x + 0.044_715 * x * x * x)).tanh())
    }

    // Absolute tolerance per output element. It is loose enough for F16 rounding under
    // AutoMixed and for the usual GELU variants (erf, tanh, or sigmoid approximations
    // differ by < 1e-2 on these inputs). It is still tight enough to catch a wrong bias
    // or a wrong input row.
    const TOL: f64 = 2e-2;

    // Graph: out = GELU(x[2,4] @ w[4,3] + b[3]), a single output of shape [2,3].
    let build = || {
        let mut g = Graph::new("matmul_bias_gelu");
        let x = g.input("x", Shape::new(&[2, 4], DType::F32));
        let w = g.param("w", Shape::new(&[4, 3], DType::F32));
        let b = g.param("b", Shape::new(&[3], DType::F32));
        let mm = g.matmul(x, w, Shape::new(&[2, 3], DType::F32));
        let bias = g.binary(op::BinaryOp::Add, mm, b, Shape::new(&[2, 3], DType::F32));
        let out = g.activation(op::Activation::Gelu, bias, Shape::new(&[2, 3], DType::F32));
        g.set_outputs(vec![out]);
        g
    };

    // w (4x3, row-major): the top 3x3 block is the identity and the last row is zero.
    let w_data = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0];
    let b_data = vec![0.5, -0.5, 0.0];
    // x (2x4, row-major): row 0 = e0, row 1 = e1, so x @ w selects rows 0 and 1 of w.
    let x_data = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];

    // Pre-activation x @ w + b = [[1.5, -0.5, 0.0], [0.5, 0.5, 0.0]] (row-major),
    // followed by GELU. Every element is checked. A broken bias broadcast or a
    // mis-fed second input row would leave out[0][0] unchanged but shift these values.
    let expected: Vec<f64> = [1.5, -0.5, 0.0, 0.5, 0.5, 0.0]
        .iter()
        .map(|&v| gelu_ref(v))
        .collect();

    println!("policy output");
    println!("------------ ------------------------------");
    for (name, policy) in [
        ("AlwaysF32", PrecisionPolicy::AlwaysF32),
        ("AutoMixed", PrecisionPolicy::AutoMixed),
    ] {
        let session = Session::new_with_precision(Device::Cpu, Precision::F32).with_policy(policy);
        let mut compiled = session.compile(build());
        compiled.set_param("w", &w_data);
        compiled.set_param("b", &b_data);
        let out = compiled.run(&[("x", &x_data)]);
        let got = &out[0];
        println!("{:<13} {:?}", name, got);

        assert_eq!(
            got.len(),
            expected.len(),
            "{name}: expected {} output elements, got {}",
            expected.len(),
            got.len()
        );
        for (i, &want) in expected.iter().enumerate() {
            let value = f64::from(got[i]);
            assert!(
                (value - want).abs() < TOL,
                "{name}: out[{i}]={value} differs from reference {want:.4} by more than {TOL}"
            );
        }
    }
    println!("\n✓ set_param + input feed + output read survive F32→F16 rewrite");
}
#[cfg(not(feature = "cpu"))]
fn main() {
    // Fallback entry point: this example needs the CPU backend, so without the
    // feature it only prints a hint on stderr.
    let hint = "requires --features cpu";
    eprintln!("{hint}");
}