#![cfg(any(target_os = "macos", target_os = "ios"))]
#![allow(clippy::excessive_precision)]
use rlx_coreml::CoremlExecutable;
use rlx_ir::op::Activation;
use rlx_ir::{DType, Graph, Shape};
/// Asserts that `a` and `b` have the same length and agree element-wise
/// within the absolute tolerance `tol`.
///
/// `a` is reported as "got" and `b` as "want" in the failure message.
///
/// # Panics
///
/// Panics when the lengths differ, when the largest absolute difference
/// exceeds `tol`, or when any difference is NaN (for example a NaN in either
/// slice, or infinities of opposite sign).
fn approx(a: &[f32], b: &[f32], tol: f32) {
    assert_eq!(
        a.len(),
        b.len(),
        "length mismatch: {} vs {}",
        a.len(),
        b.len()
    );
    let mx = a
        .iter()
        .zip(b)
        // Equal values (including equal infinities) differ by exactly zero;
        // computing `inf - inf` would otherwise produce a spurious NaN.
        .map(|(&x, &y)| if x == y { 0.0 } else { (x - y).abs() })
        // `f32::max` drops NaN operands, which would let a NaN result pass
        // unnoticed, so NaN is propagated explicitly instead.
        .fold(0.0f32, |acc, diff| {
            if acc.is_nan() || diff.is_nan() {
                f32::NAN
            } else {
                acc.max(diff)
            }
        });
    // A NaN `mx` fails this comparison, which is the intended outcome.
    assert!(
        mx <= tol,
        "max abs diff {mx} > {tol}\n got {a:?}\n want {b:?}"
    );
}
/// Builds a graph named `"unary"` that applies `act` to a 1-D `F32` input
/// called `"x"` (same length as `x`), compiles it with
/// `CoremlExecutable::compile`, runs it on `x`, and returns the first output.
///
/// Panics with `"run"` if execution reports an error.
fn run_unary(act: Activation, x: &[f32]) -> Vec<f32> {
    let len = x.len();

    let mut graph = Graph::new("unary");
    let input_node = graph.input("x", Shape::new(&[len], DType::F32));
    let output_node = graph.activation(act, input_node, Shape::new(&[len], DType::F32));
    graph.set_outputs(vec![output_node]);

    let mut exe = CoremlExecutable::compile(graph);
    let first_output = exe.run(&[("x", x)]).expect("run").remove(0);
    first_output
}
/// Compares `Activation::Gelu` with the erf-based formula
/// `0.5 * x * (1 + erf(x / sqrt(2)))`, using `libm_erf` as the erf stand-in.
#[test]
fn gelu_exact() {
    use std::f32::consts::SQRT_2;

    let inputs = [-2.0f32, -0.5, 0.0, 0.5, 2.0];
    let actual = run_unary(Activation::Gelu, &inputs);

    let mut expected = Vec::with_capacity(inputs.len());
    for &v in &inputs {
        let erf_term = 1.0 + libm_erf(v / SQRT_2);
        expected.push(v * 0.5 * erf_term);
    }

    approx(&actual, &expected, 1e-3);
}
/// Polynomial approximation of the error function for `f32`.
///
/// Evaluates a degree-5 polynomial in `t = 1 / (1 + 0.3275911 * |x|)`,
/// multiplies it by `exp(-x^2)`, and subtracts the product from 1. The result
/// is negated for negative `x`, which makes the function odd. The constants
/// appear to be those of the Abramowitz & Stegun 7.1.26 approximation.
fn libm_erf(x: f32) -> f32 {
    // Coefficients ordered from the highest power of `t` down, for Horner's rule.
    const COEFFS: [f32; 5] = [
        1.061405429,
        -1.453152027,
        1.421413741,
        -0.284496736,
        0.254829592,
    ];

    let t = 1.0 / (1.0 + 0.3275911 * x.abs());
    let horner = COEFFS[1..].iter().fold(COEFFS[0], |acc, &c| acc * t + c);
    let magnitude = 1.0 - horner * t * (-x * x).exp();

    if x < 0.0 {
        -magnitude
    } else {
        magnitude
    }
}
/// Checks `Activation::Sigmoid` against `1 / (1 + e^-x)` and
/// `Activation::Silu` against `x / (1 + e^-x)` on the same inputs.
#[test]
fn silu_and_sigmoid() {
    let inputs = [-2.0f32, -0.5, 0.0, 0.5, 2.0];
    // Both reference formulas share the denominator `1 + e^-v`.
    let denom = |v: f32| 1.0 + (-v).exp();

    let sigmoid_out = run_unary(Activation::Sigmoid, &inputs);
    let silu_out = run_unary(Activation::Silu, &inputs);

    let sigmoid_ref: Vec<f32> = inputs.iter().map(|&v| 1.0 / denom(v)).collect();
    let silu_ref: Vec<f32> = inputs.iter().map(|&v| v / denom(v)).collect();

    approx(&sigmoid_out, &sigmoid_ref, 1e-4);
    approx(&silu_out, &silu_ref, 1e-4);
}
/// Runs `Neg`, `Exp` and `Sqrt` on three-element inputs and compares each
/// result with hand-written expected values.
#[test]
fn neg_exp_sqrt() {
    use std::f32::consts::E;

    // (activation, input, expected output, absolute tolerance)
    let cases = [
        (
            Activation::Neg,
            [1.0f32, -2.0, 3.0],
            [-1.0f32, 2.0, -3.0],
            1e-5f32,
        ),
        (Activation::Exp, [0.0, 1.0, 2.0], [1.0, E, E.powi(2)], 1e-3),
        (Activation::Sqrt, [1.0, 4.0, 9.0], [1.0, 2.0, 3.0], 1e-5),
    ];

    for (act, input, want, tol) in cases {
        approx(&run_unary(act, &input), &want, tol);
    }
}
/// Applies softmax along axis `-1` of a `[2, 3]` input and checks each
/// three-element row against `softmax_ref`.
#[test]
fn softmax_last_axis() {
    let input = [1.0f32, 2.0, 3.0, 1.0, 1.0, 1.0];

    let mut graph = Graph::new("softmax");
    let src = graph.input("x", Shape::new(&[2, 3], DType::F32));
    let probs = graph.softmax(src, -1, Shape::new(&[2, 3], DType::F32));
    graph.set_outputs(vec![probs]);

    let mut exe = CoremlExecutable::compile(graph);
    let out = exe.run(&[("x", &input)]).expect("run").remove(0);

    let first_row = softmax_ref(&input[0..3]);
    let second_row = softmax_ref(&input[3..6]);
    approx(&out[0..3], &first_row, 1e-4);
    approx(&out[3..6], &second_row, 1e-4);
}
/// Reference softmax over a single slice.
///
/// Subtracts the running maximum (seeded with `f32::MIN`) from every element
/// before exponentiating, then divides each exponential by their sum.
fn softmax_ref(x: &[f32]) -> Vec<f32> {
    let peak = x.iter().copied().fold(f32::MIN, |acc, v| acc.max(v));

    let mut probs: Vec<f32> = x.iter().map(|&v| (v - peak).exp()).collect();
    let total: f32 = probs.iter().sum();
    for p in &mut probs {
        *p /= total;
    }
    probs
}
/// Transposes a `[2, 3]` input with `perm = [1, 0]` and expects the `[3, 2]`
/// result in row-major order.
#[test]
fn transpose_2d() {
    let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let expected = [1.0f32, 4.0, 2.0, 5.0, 3.0, 6.0];

    let mut graph = Graph::new("transpose");
    let src = graph.input("x", Shape::new(&[2, 3], DType::F32));
    let transpose_op = rlx_ir::Op::Transpose { perm: vec![1, 0] };
    let transposed = graph.append_node(
        transpose_op,
        vec![src],
        Shape::new(&[3, 2], DType::F32),
        None,
    );
    graph.set_outputs(vec![transposed]);

    let mut exe = CoremlExecutable::compile(graph);
    let out = exe.run(&[("x", &input)]).expect("run").remove(0);

    approx(&out, &expected, 1e-5);
}
/// Reshapes a `[2, 3]` input to `[6]` and expects the element order to be
/// unchanged.
#[test]
fn reshape_flatten() {
    let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];

    let mut graph = Graph::new("reshape");
    let src = graph.input("x", Shape::new(&[2, 3], DType::F32));
    let flat = graph.reshape(src, vec![6], Shape::new(&[6], DType::F32));
    graph.set_outputs(vec![flat]);

    let mut exe = CoremlExecutable::compile(graph);
    let out = exe.run(&[("x", &data)]).expect("run").remove(0);

    // Flattening must leave the values exactly as they were fed in.
    approx(&out, &data, 1e-5);
}
/// Runs layer norm over axis `-1` of a `[2, 4]` input with parameters `"g"`
/// and `"b"`, and compares each four-element row with `layer_norm_ref`.
///
/// NOTE(review): gamma is all ones and beta all zeros, so the reference can
/// omit them; the scale/shift path is only exercised with identity values.
#[test]
fn layer_norm_affine() {
    let input = [1.0f32, 2.0, 3.0, 4.0, -1.0, 0.0, 1.0, 2.0];
    let scale = [1.0f32, 1.0, 1.0, 1.0];
    let shift = [0.0f32, 0.0, 0.0, 0.0];

    let mut graph = Graph::new("layernorm");
    let src = graph.input("x", Shape::new(&[2, 4], DType::F32));
    let gamma = graph.param("g", Shape::new(&[4], DType::F32));
    let beta = graph.param("b", Shape::new(&[4], DType::F32));
    let normed = graph.layer_norm(src, gamma, beta, -1, 1e-5, Shape::new(&[2, 4], DType::F32));
    graph.set_outputs(vec![normed]);

    let mut exe = CoremlExecutable::compile(graph);
    exe.set_param("g", &scale);
    exe.set_param("b", &shift);
    let out = exe.run(&[("x", &input)]).expect("run").remove(0);

    // Normalise each row of four independently, then flatten.
    let expected: Vec<f32> = input
        .chunks(4)
        .flat_map(|row| layer_norm_ref(row, 1e-5))
        .collect();
    approx(&out, &expected, 1e-3);
}
/// Reference layer normalisation of one row, without scale or shift.
///
/// Centres `x` on its mean and divides by `sqrt(variance + eps)`, where the
/// variance is the mean of squared deviations (divides by `n`, not `n - 1`).
fn layer_norm_ref(x: &[f32], eps: f32) -> Vec<f32> {
    let count = x.len() as f32;
    let mean = x.iter().sum::<f32>() / count;

    // Deviations from the mean are needed for both the variance and the output.
    let centered: Vec<f32> = x.iter().map(|&v| v - mean).collect();
    let variance = centered.iter().map(|d| d.powi(2)).sum::<f32>() / count;
    let inv_std = 1.0 / (variance + eps).sqrt();

    centered.into_iter().map(|d| d * inv_std).collect()
}
/// Appends an `Op::RmsNorm` node (axis `-1`, eps `1e-6`) over a `[2, 4]`
/// input with a non-trivial gamma and compares each row with `rms_norm_ref`.
#[test]
fn rms_norm_affine() {
    let input = [1.0f32, 2.0, 3.0, 4.0, -1.0, 0.5, 1.5, 2.5];
    let scale = [1.0f32, 0.5, 2.0, 1.0];

    let mut graph = Graph::new("rmsnorm");
    let src = graph.input("x", Shape::new(&[2, 4], DType::F32));
    let gamma = graph.param("g", Shape::new(&[4], DType::F32));
    let rms_op = rlx_ir::Op::RmsNorm {
        axis: -1,
        eps: 1e-6,
    };
    let normed = graph.append_node(
        rms_op,
        vec![src, gamma],
        Shape::new(&[2, 4], DType::F32),
        None,
    );
    graph.set_outputs(vec![normed]);

    let mut exe = CoremlExecutable::compile(graph);
    exe.set_param("g", &scale);
    let out = exe.run(&[("x", &input)]).expect("run").remove(0);

    // Normalise each row of four independently, then flatten.
    let expected: Vec<f32> = input
        .chunks(4)
        .flat_map(|row| rms_norm_ref(row, &scale, 1e-6))
        .collect();
    approx(&out, &expected, 1e-3);
}
/// Reference RMS normalisation of one row.
///
/// Divides each element by `sqrt(mean(x^2) + eps)` and multiplies by the
/// matching entry of `gamma`. The output is as long as the shorter of `x`
/// and `gamma`.
fn rms_norm_ref(x: &[f32], gamma: &[f32], eps: f32) -> Vec<f32> {
    let count = x.len() as f32;
    let mean_square = x.iter().map(|&v| v * v).sum::<f32>() / count;
    let inv_rms = 1.0 / (mean_square + eps).sqrt();

    let mut out = Vec::with_capacity(x.len().min(gamma.len()));
    for (&v, &g) in x.iter().zip(gamma) {
        out.push(v * inv_rms * g);
    }
    out
}
/// Sum-reduces a `[2, 3]` input over axis 1 into shape `[2]` and expects the
/// two row sums.
#[test]
fn reduce_sum_axis() {
    let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    // 1 + 2 + 3 and 4 + 5 + 6.
    let row_sums = [6.0f32, 15.0];

    let mut graph = Graph::new("reduce");
    let src = graph.input("x", Shape::new(&[2, 3], DType::F32));
    let reduced = graph.reduce(
        src,
        rlx_ir::op::ReduceOp::Sum,
        vec![1],
        false,
        Shape::new(&[2], DType::F32),
    );
    graph.set_outputs(vec![reduced]);

    let mut exe = CoremlExecutable::compile(graph);
    let out = exe.run(&[("x", &input)]).expect("run").remove(0);

    approx(&out, &row_sums, 1e-4);
}
/// Concatenates two `[1, 3]` inputs along axis 0 and expects the first
/// input's values followed by the second's.
#[test]
fn concat_axis0() {
    let first = [1.0f32, 2.0, 3.0];
    let second = [4.0f32, 5.0, 6.0];

    let mut graph = Graph::new("concat");
    let lhs = graph.input("a", Shape::new(&[1, 3], DType::F32));
    let rhs = graph.input("b", Shape::new(&[1, 3], DType::F32));
    let joined = graph.concat(vec![lhs, rhs], 0, Shape::new(&[2, 3], DType::F32));
    graph.set_outputs(vec![joined]);

    let mut exe = CoremlExecutable::compile(graph);
    let out = exe
        .run(&[("a", &first), ("b", &second)])
        .expect("run")
        .remove(0);

    let expected = [first, second].concat();
    approx(&out, &expected, 1e-5);
}
/// Appends an `Op::Narrow` node (axis 1, start 1, len 2) over a `[2, 4]`
/// input and expects the two middle columns of each row.
#[test]
fn narrow_slice() {
    let input = [0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
    let expected = [1.0f32, 2.0, 5.0, 6.0];

    let mut graph = Graph::new("narrow");
    let src = graph.input("x", Shape::new(&[2, 4], DType::F32));
    let narrow_op = rlx_ir::Op::Narrow {
        axis: 1,
        start: 1,
        len: 2,
    };
    let window = graph.append_node(
        narrow_op,
        vec![src],
        Shape::new(&[2, 2], DType::F32),
        None,
    );
    graph.set_outputs(vec![window]);

    let mut exe = CoremlExecutable::compile(graph);
    let out = exe.run(&[("x", &input)]).expect("run").remove(0);

    approx(&out, &expected, 1e-5);
}