use half::{bf16, f16};
use rlx_ir::op::{Activation, BinaryOp};
use rlx_ir::{DType, Graph, Op, Shape};
use rlx_runtime::{CompileOptions, Device, Session};
/// Builds the `typed_io_param_add` graph.
///
/// The graph has one F32 input `x` and one F32 parameter `b`, both of shape
/// `[4]`, and its single output is the `Add` node combining them.
fn build_param_add_graph() -> Graph {
    // Every tensor in this graph shares the same shape, so build it on demand.
    let vec4 = || Shape::new(&[4], DType::F32);

    let mut graph = Graph::new("typed_io_param_add");
    let input = graph.input("x", vec4());
    let param = graph.param("b", vec4());
    let sum = graph.binary(BinaryOp::Add, input, param, vec4());
    graph.set_outputs(vec![sum]);
    graph
}
/// Builds the `typed_io_relu` graph.
///
/// The graph takes one F32 input `x` of shape `[6]` and outputs the result of
/// a ReLU activation node applied to it (also F32, shape `[6]`).
fn build_f32_relu_graph() -> Graph {
    let mut graph = Graph::new("typed_io_relu");
    let input = graph.input("x", Shape::new(&[6], DType::F32));

    let relu_op = Op::Activation(Activation::Relu);
    let out_shape = Shape::new(&[6], DType::F32);
    let relu = graph.add_node(relu_op, vec![input], out_shape);

    graph.set_outputs(vec![relu]);
    graph
}
/// CPU: supplies parameter `b` as little-endian F16 bytes through
/// `set_param_typed`, then runs the F32 add graph and checks the output equals
/// `x + b` element-wise.
#[cfg(feature = "cpu")]
#[test]
fn cpu_set_param_typed_f16_widens_to_f32_and_runs() {
    let graph = build_param_add_graph();
    let cpu = Session::new(Device::Cpu);
    let options = CompileOptions::default();
    let mut model = cpu.compile_with(graph, &options);

    // Encode each parameter value as F16 and append its two little-endian bytes.
    let mut param_bytes: Vec<u8> = Vec::with_capacity(8);
    for value in [1.0f32, 2.0, 3.0, 4.0] {
        param_bytes.extend_from_slice(&f16::from_f32(value).to_le_bytes());
    }
    model.set_param_typed("b", &param_bytes, DType::F16);

    let input: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
    let results = model.run(&[("x", &input)]);
    assert_eq!(results[0], vec![11.0, 22.0, 33.0, 44.0]);
}
/// CPU: supplies parameter `b` as little-endian BF16 bytes through
/// `set_param_typed`, then runs the F32 add graph and checks the output equals
/// `x + b` element-wise.
#[cfg(feature = "cpu")]
#[test]
fn cpu_set_param_typed_bf16_widens_to_f32() {
    let graph = build_param_add_graph();
    let cpu = Session::new(Device::Cpu);
    let options = CompileOptions::default();
    let mut model = cpu.compile_with(graph, &options);

    // Convert to BF16 and flatten straight into the little-endian byte stream.
    let param_bytes: Vec<u8> = [1.5f32, 2.5, 3.5, 4.5]
        .iter()
        .flat_map(|&value| bf16::from_f32(value).to_le_bytes())
        .collect();
    model.set_param_typed("b", &param_bytes, DType::BF16);

    let input: Vec<f32> = vec![0.5, 0.5, 0.5, 0.5];
    let results = model.run(&[("x", &input)]);
    assert_eq!(results[0], vec![2.0, 3.0, 4.0, 5.0]);
}
/// CPU: feeds the F32 ReLU graph with F16-encoded input bytes via `run_typed`
/// and checks that the single output is tagged `DType::F32`, is 24 bytes long,
/// and decodes to `max(x, 0)` for every input element.
#[cfg(feature = "cpu")]
#[test]
fn cpu_run_typed_with_f16_input_widens_and_runs() {
    let graph = build_f32_relu_graph();
    let cpu = Session::new(Device::Cpu);
    let options = CompileOptions::default();
    let mut model = cpu.compile_with(graph, &options);

    let inputs: Vec<f32> = vec![-1.0, 0.0, 0.5, 1.0, 2.5, -2.0];
    // Narrow each value to F16 and serialise it little-endian.
    let input_bytes: Vec<u8> = inputs
        .iter()
        .flat_map(|&v| f16::from_f32(v).to_le_bytes())
        .collect();

    let results = model.run_typed(&[("x", &input_bytes, DType::F16)]);
    assert_eq!(results.len(), 1);

    let (raw, dtype) = &results[0];
    assert_eq!(
        *dtype,
        DType::F32,
        "graph output is F32; run_typed reports it as such"
    );
    assert_eq!(raw.len(), 24, "6 elems × 4 bytes per F32");

    // Reference result: ReLU applied on the host.
    let expected: Vec<f32> = inputs.iter().map(|v| v.max(0.0)).collect();

    // Decode the returned bytes as little-endian F32 words.
    let mut actual: Vec<f32> = Vec::with_capacity(inputs.len());
    for word in raw.chunks_exact(4) {
        actual.push(f32::from_le_bytes([word[0], word[1], word[2], word[3]]));
    }
    assert_eq!(actual, expected);
}
/// Metal counterpart of `cpu_run_typed_with_f16_input_widens_and_runs`.
///
/// Feeds the F32 ReLU graph with F16-encoded input bytes via `run_typed` and
/// checks that exactly one output comes back, that it is tagged `DType::F32`,
/// is 24 bytes long, and decodes to `max(x, 0)` for every input element.
#[cfg(all(feature = "metal", target_os = "macos"))]
#[test]
fn metal_run_typed_with_f16_input_widens_and_runs() {
    let g = build_f32_relu_graph();
    let session = Session::new(Device::Metal);
    let mut compiled = session.compile_with(g, &CompileOptions::default());

    let xs: Vec<f32> = vec![-1.0, 0.0, 0.5, 1.0, 2.5, -2.0];
    // Narrow each value to F16 and serialise it little-endian.
    let xs_bytes: Vec<u8> = xs
        .iter()
        .flat_map(|&v| f16::from_f32(v).to_le_bytes())
        .collect();

    let outs = compiled.run_typed(&[("x", &xs_bytes, DType::F16)]);
    // Same output-count check as the CPU variant, so a surplus output cannot
    // slip through unnoticed on this backend.
    assert_eq!(outs.len(), 1, "relu graph declares exactly one output");

    let (bytes, dt) = &outs[0];
    assert_eq!(
        *dt,
        DType::F32,
        "graph output is F32; run_typed reports it as such"
    );
    assert_eq!(bytes.len(), 24, "6 elems * 4 bytes per F32");

    // Decode the returned bytes as little-endian F32 words.
    let got: Vec<f32> = bytes
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect();
    // Reference result: ReLU applied on the host.
    let want: Vec<f32> = xs.iter().map(|v| v.max(0.0)).collect();
    assert_eq!(got, want);
}
/// CPU (ignored; see the `#[ignore]` reason): builds ReLU followed by
/// `Cast { to: F16 }`, runs it through `run_typed` with F32 input bytes, and
/// expects a 12-byte output tagged `DType::F16` whose decoded values equal
/// `f16::from_f32(max(x, 0))`.
#[cfg(feature = "cpu")]
#[test]
#[ignore = "needs Thunk::CastDtype — Op::Cast currently lowers to Thunk::Copy (rlx-cpu/src/thunk.rs:671)"]
fn cpu_run_typed_narrows_f16_output_via_cast_thunk() {
    let mut graph = Graph::new("typed_io_narrow_f16");
    let input = graph.input("x", Shape::new(&[6], DType::F32));

    let relu_op = Op::Activation(Activation::Relu);
    let relu = graph.add_node(relu_op, vec![input], Shape::new(&[6], DType::F32));

    let cast_op = Op::Cast { to: DType::F16 };
    let narrowed = graph.add_node(cast_op, vec![relu], Shape::new(&[6], DType::F16));
    graph.set_outputs(vec![narrowed]);

    let cpu = Session::new(Device::Cpu);
    let options = CompileOptions::default();
    let mut model = cpu.compile_with(graph, &options);

    let inputs: Vec<f32> = vec![-1.0, 0.0, 0.5, 1.0, 2.5, -2.0];
    // Serialise the F32 inputs as little-endian bytes.
    let mut input_bytes: Vec<u8> = Vec::with_capacity(inputs.len() * 4);
    for v in &inputs {
        input_bytes.extend_from_slice(&v.to_le_bytes());
    }

    let results = model.run_typed(&[("x", &input_bytes, DType::F32)]);
    let (raw, dtype) = &results[0];
    assert_eq!(*dtype, DType::F16, "graph output is F16");
    assert_eq!(raw.len(), 12, "6 elems * 2 bytes per F16");

    // Reference result: ReLU on the host, then narrowed to F16.
    let expected: Vec<f16> = inputs
        .iter()
        .map(|&v| f16::from_f32(v.max(0.0)))
        .collect();

    // Decode the returned bytes as little-endian F16 halves.
    let mut actual: Vec<f16> = Vec::with_capacity(inputs.len());
    for pair in raw.chunks_exact(2) {
        actual.push(f16::from_le_bytes([pair[0], pair[1]]));
    }
    assert_eq!(actual, expected);
}
/// CPU: chains `Add` then `Mul` on a `[2, 3, 4]` F32 input, with `[4]`-shaped
/// `bias` and `scale` inputs, and checks the output against a host reference
/// in which element `i` uses `bias[i % 4]` and `scale[i % 4]`.
#[cfg(feature = "cpu")]
#[test]
fn cpu_last_axis_broadcast_in_chain_matches_reference() {
    use rlx_ir::op::BinaryOp;

    let mut graph = Graph::new("cpu_lastaxis_broadcast");
    let x = graph.input("x", Shape::new(&[2, 3, 4], DType::F32));
    let bias = graph.input("bias", Shape::new(&[4], DType::F32));
    let scale = graph.input("scale", Shape::new(&[4], DType::F32));

    let out_shape = Shape::new(&[2, 3, 4], DType::F32);
    let biased = graph.binary(BinaryOp::Add, x, bias, out_shape.clone());
    let scaled = graph.binary(BinaryOp::Mul, biased, scale, out_shape);
    graph.set_outputs(vec![scaled]);

    let cpu = Session::new(Device::Cpu);
    let options = CompileOptions::default();
    let mut model = cpu.compile_with(graph, &options);

    let x_vals: Vec<f32> = (0..24).map(|i| i as f32).collect();
    let bias_vals = vec![10.0f32, 20.0, 30.0, 40.0];
    let scale_vals = vec![1.0f32, 2.0, 3.0, 4.0];
    let results = model.run(&[("x", &x_vals), ("bias", &bias_vals), ("scale", &scale_vals)]);

    // Host reference: walk the flat buffer four elements at a time so each
    // element lines up with the bias/scale entry at the same position.
    let mut expected: Vec<f32> = Vec::with_capacity(x_vals.len());
    for row in x_vals.chunks_exact(4) {
        for ((&v, &b), &s) in row.iter().zip(&bias_vals).zip(&scale_vals) {
            expected.push((v + b) * s);
        }
    }
    assert_eq!(results[0], expected);
}
/// CPU: chains `Add` then `Mul` on a `[6]` F32 input, with `[1]`-shaped `bias`
/// and `scale` inputs, and checks the output equals `(x + bias) * scale` for
/// every element.
#[cfg(feature = "cpu")]
#[test]
fn cpu_scalar_broadcast_in_chain_matches_reference() {
    use rlx_ir::op::BinaryOp;

    let mut graph = Graph::new("cpu_scalar_chain");
    let x = graph.input("x", Shape::new(&[6], DType::F32));
    let bias = graph.input("bias", Shape::new(&[1], DType::F32));
    let scale = graph.input("scale", Shape::new(&[1], DType::F32));

    let out_shape = Shape::new(&[6], DType::F32);
    let biased = graph.binary(BinaryOp::Add, x, bias, out_shape.clone());
    let scaled = graph.binary(BinaryOp::Mul, biased, scale, out_shape);
    graph.set_outputs(vec![scaled]);

    let cpu = Session::new(Device::Cpu);
    let options = CompileOptions::default();
    let mut model = cpu.compile_with(graph, &options);

    let x_vals: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let bias_val = 0.5f32;
    let scale_val = 2.0f32;
    let results = model.run(&[
        ("x", &x_vals),
        ("bias", &[bias_val]),
        ("scale", &[scale_val]),
    ]);

    // Host reference: the same scalar bias and scale applied to every element.
    let expected: Vec<f32> = x_vals
        .iter()
        .map(|&v| (v + bias_val) * scale_val)
        .collect();
    assert_eq!(results[0], expected);
}
/// Metal: supplies parameter `b` as little-endian F16 bytes through
/// `set_param_typed`, then runs the F32 add graph and checks the output equals
/// `x + b` element-wise.
#[cfg(all(feature = "metal", target_os = "macos"))]
#[test]
fn metal_set_param_typed_f16_widens_to_f32() {
    let graph = build_param_add_graph();
    let gpu = Session::new(Device::Metal);
    let options = CompileOptions::default();
    let mut model = gpu.compile_with(graph, &options);

    // Encode each parameter value as F16 and append its two little-endian bytes.
    let mut param_bytes: Vec<u8> = Vec::with_capacity(8);
    for value in [1.0f32, 2.0, 3.0, 4.0] {
        param_bytes.extend_from_slice(&f16::from_f32(value).to_le_bytes());
    }
    model.set_param_typed("b", &param_bytes, DType::F16);

    let input: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
    let results = model.run(&[("x", &input)]);
    assert_eq!(results[0], vec![11.0, 22.0, 33.0, 44.0]);
}