use rlx_ir::op::{Activation, BinaryOp, CmpOp, MaskKind, ReduceOp};
use rlx_ir::{DType, Graph, Op, Shape};
use rlx_wgpu::backend::WgpuExecutable;
/// Builds the minimal "basic" graph: a `[2, 3]` input `x` multiplied by a
/// `[3, 2]` parameter `w`, with the `[2, 2]` product as the only output.
fn build_graph() -> Graph {
    let mut graph = Graph::new("basic");
    let input = graph.input("x", Shape::new(&[2, 3], DType::F32));
    let weight = graph.param("w", Shape::new(&[3, 2], DType::F32));
    let product = graph.matmul(input, weight, Shape::new(&[2, 2], DType::F32));
    graph.set_outputs(vec![product]);
    graph
}
/// Reference row-major matrix multiply.
///
/// `x` is read as an `m x k` matrix and `w` as a `k x n` matrix; the returned
/// vector holds the `m x n` product in row-major order. Panics if either
/// slice is too short for the indices implied by the given dimensions.
fn matmul_ref(x: &[f32], w: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(m * n);
    for row in 0..m {
        for col in 0..n {
            // Accumulate from 0.0 in increasing inner-index order.
            let dot = (0..k).fold(0f32, |acc, inner| {
                acc + x[row * k + inner] * w[inner * n + col]
            });
            out.push(dot);
        }
    }
    out
}
/// Returns `true` when both slices have the same length and every pair of
/// corresponding elements differs by at most `tol` in absolute value.
///
/// A NaN difference never satisfies `<= tol`, so NaNs make the result `false`.
fn close(a: &[f32], b: &[f32], tol: f32) -> bool {
    if a.len() != b.len() {
        return false;
    }
    for (lhs, rhs) in a.iter().zip(b) {
        let within = (lhs - rhs).abs() <= tol;
        if !within {
            return false;
        }
    }
    true
}
/// Element-wise `Add` over two length-4 inputs must produce the exact sums.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn binary_add_matches_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("add");
    let lhs = graph.input("x", Shape::new(&[4], DType::F32));
    let rhs = graph.input("y", Shape::new(&[4], DType::F32));
    let sum = graph.binary(BinaryOp::Add, lhs, rhs, Shape::new(&[4], DType::F32));
    graph.set_outputs(vec![sum]);
    let mut exe = WgpuExecutable::compile(graph);
    let results = exe.run(&[
        ("x", &[1.0, 2.0, 3.0, 4.0]),
        ("y", &[10.0, 20.0, 30.0, 40.0]),
    ]);
    assert_eq!(results[0], vec![11.0, 22.0, 33.0, 44.0]);
}
/// Runs `Max`, `Min` and `Pow` over the same pair of length-4 inputs and
/// compares each result with hand-computed values within `1e-4`.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn binary_max_min_pow_match_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let cases = [
        (BinaryOp::Max, vec![3.0, 4.0, 3.0, 4.0]),
        (BinaryOp::Min, vec![1.0, 2.0, 1.0, 2.0]),
        // 1^3, 2^4, 3^1, 4^2
        (BinaryOp::Pow, vec![1.0, 16.0, 3.0, 16.0]),
    ];
    for (op, want) in cases {
        // A fresh graph per operator; only the binary op differs.
        let mut graph = Graph::new("bin");
        let lhs = graph.input("a", Shape::new(&[4], DType::F32));
        let rhs = graph.input("b", Shape::new(&[4], DType::F32));
        let result = graph.binary(op, lhs, rhs, Shape::new(&[4], DType::F32));
        graph.set_outputs(vec![result]);
        let mut exe = WgpuExecutable::compile(graph);
        let out = exe.run(&[("a", &[1.0, 2.0, 3.0, 4.0]), ("b", &[3.0, 4.0, 1.0, 2.0])]);
        assert!(
            close(&out[0], &want, 1e-4),
            "Binary({op:?}) mismatch: got {:?} want {want:?}",
            out[0]
        );
    }
}
/// Chains `Relu` then `Silu` over five values and compares the output with
/// precomputed expectations within `1e-2`.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn activations_relu_silu_match_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("act");
    let input = graph.input("x", Shape::new(&[5], DType::F32));
    let rectified = graph.activation(Activation::Relu, input, Shape::new(&[5], DType::F32));
    let gated = graph.activation(Activation::Silu, rectified, Shape::new(&[5], DType::F32));
    graph.set_outputs(vec![gated]);
    let mut exe = WgpuExecutable::compile(graph);
    let xs = vec![-2.0, -0.5, 0.0, 1.0, 3.0];
    let out = exe.run(&[("x", &xs)]);
    // Non-positive inputs are clamped to 0 by Relu, and silu(0) = 0.
    let want = vec![0.0, 0.0, 0.0, 0.7311, 2.857];
    assert!(
        close(&out[0], &want, 1e-2),
        "Relu+Silu mismatch: got {:?} want {want:?}",
        out[0]
    );
}
/// Builds `where(x > z, x, -x)` and, with `z` fed as all zeros, checks that
/// the result equals the absolute value of `x`.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn compare_then_where_implements_abs() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("cw");
    let value = graph.input("x", Shape::new(&[4], DType::F32));
    let zero = graph.input("z", Shape::new(&[4], DType::F32));
    let negated = graph.activation(Activation::Neg, value, Shape::new(&[4], DType::F32));
    // The comparison node is declared with a Bool shape.
    let is_greater = graph.add_node(
        Op::Compare(CmpOp::Gt),
        vec![value, zero],
        Shape::new(&[4], DType::Bool),
    );
    let selected = graph.add_node(
        Op::Where,
        vec![is_greater, value, negated],
        Shape::new(&[4], DType::F32),
    );
    graph.set_outputs(vec![selected]);
    let mut exe = WgpuExecutable::compile(graph);
    let results = exe.run(&[("x", &[1.0, -2.0, 3.0, -4.0]), ("z", &[0.0, 0.0, 0.0, 0.0])]);
    assert_eq!(results[0], vec![1.0, 2.0, 3.0, 4.0]);
}
/// Sum-reduces a `[2, 3]` input over axis 1 into a `[2]` output and checks
/// the two row sums exactly.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn reduce_sum_last_axis_matches_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("rsum");
    let input = graph.input("x", Shape::new(&[2, 3], DType::F32));
    let out_shape = Shape::new(&[2], DType::F32);
    let row_sums = graph.reduce(input, ReduceOp::Sum, vec![1], false, out_shape);
    graph.set_outputs(vec![row_sums]);
    let mut exe = WgpuExecutable::compile(graph);
    let results = exe.run(&[("x", &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])]);
    // 1 + 2 + 3 and 4 + 5 + 6
    assert_eq!(results[0], vec![6.0, 15.0]);
}
/// Runs `Mean`, `Max` and `Min` reductions over axis 1 of the same `[2, 3]`
/// input and compares each with hand-computed row values within `1e-5`.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn reduce_mean_max_min_match_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let cases = [
        (ReduceOp::Mean, vec![2.0, 5.0]),
        (ReduceOp::Max, vec![3.0, 6.0]),
        (ReduceOp::Min, vec![1.0, 4.0]),
    ];
    for (op, want) in cases {
        let mut graph = Graph::new("red");
        let input = graph.input("x", Shape::new(&[2, 3], DType::F32));
        let reduced = graph.reduce(input, op, vec![1], false, Shape::new(&[2], DType::F32));
        graph.set_outputs(vec![reduced]);
        let mut exe = WgpuExecutable::compile(graph);
        let r = exe.run(&[("x", &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])]);
        assert!(
            close(&r[0], &want, 1e-5),
            "Reduce({op:?}) mismatch: got {:?} want {want:?}",
            r[0]
        );
    }
}
/// Applies softmax with axis `-1` to a `[1, 3]` input and compares the result
/// with precomputed probabilities within `1e-3`.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn softmax_last_axis_matches_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("smx");
    let logits = graph.input("x", Shape::new(&[1, 3], DType::F32));
    let probs = graph.softmax(logits, -1, Shape::new(&[1, 3], DType::F32));
    graph.set_outputs(vec![probs]);
    let mut exe = WgpuExecutable::compile(graph);
    let r = exe.run(&[("x", &[1.0, 2.0, 3.0])]);
    // softmax([1, 2, 3]) rounded to four decimals
    let want = vec![0.0900, 0.2447, 0.6652];
    assert!(
        close(&r[0], &want, 1e-3),
        "softmax mismatch: got {:?} want {want:?}",
        r[0]
    );
}
/// Runs layer norm (axis `-1`, eps `1e-5`) over a `[2, 4]` input with unit
/// gain and zero bias, and compares the result with a per-row CPU reference
/// within `1e-3`.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn layer_norm_matches_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("ln");
    let input = graph.input("x", Shape::new(&[2, 4], DType::F32));
    let gain = graph.param("g", Shape::new(&[4], DType::F32));
    let shift = graph.param("b", Shape::new(&[4], DType::F32));
    let normed = graph.layer_norm(input, gain, shift, -1, 1e-5, Shape::new(&[2, 4], DType::F32));
    graph.set_outputs(vec![normed]);
    let mut exe = WgpuExecutable::compile(graph);
    exe.set_param("g", &[1.0, 1.0, 1.0, 1.0]);
    exe.set_param("b", &[0.0, 0.0, 0.0, 0.0]);
    let xs = vec![1.0, 2.0, 3.0, 4.0, 2.0, 0.0, 0.0, 0.0];
    let r = exe.run(&[("x", &xs)]);
    // Reference: (x - mean) / sqrt(var + eps) over each row of four values.
    let mut want: Vec<f32> = Vec::with_capacity(xs.len());
    for row in xs.chunks(4) {
        let mean = row.iter().sum::<f32>() / 4.0;
        let var = row.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / 4.0;
        let inv = 1.0 / (var + 1e-5).sqrt();
        want.extend(row.iter().map(|v| (v - mean) * inv));
    }
    assert!(
        close(&r[0], &want, 1e-3),
        "layer_norm mismatch: got {:?} want {want:?}",
        r[0]
    );
}
/// Runs `cumsum` along axis 0 of a length-4 input (flag argument `false`) and
/// checks the output equals the running totals including the current element.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn cumsum_inclusive_matches_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("cs");
    let input = graph.input("x", Shape::new(&[4], DType::F32));
    let running = graph.cumsum(input, 0, false, Shape::new(&[4], DType::F32));
    graph.set_outputs(vec![running]);
    let mut exe = WgpuExecutable::compile(graph);
    let results = exe.run(&[("x", &[1.0, 2.0, 3.0, 4.0])]);
    // 1, 1+2, 1+2+3, 1+2+3+4
    assert_eq!(results[0], vec![1.0, 3.0, 6.0, 10.0]);
}
/// Reshapes a `[2, 3]` input to `[3, 2]` and checks the flat data comes back
/// unchanged.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn reshape_passes_data_through() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("rs");
    let input = graph.input("x", Shape::new(&[2, 3], DType::F32));
    let reshaped = graph.reshape(input, vec![3, 2], Shape::new(&[3, 2], DType::F32));
    graph.set_outputs(vec![reshaped]);
    let mut exe = WgpuExecutable::compile(graph);
    let results = exe.run(&[("x", &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])]);
    assert_eq!(results[0], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}
/// Transposes a `[2, 3]` input with permutation `[1, 0]` and checks the
/// resulting `[3, 2]` data.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn transpose_2x3_to_3x2_matches_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("tr");
    let input = graph.input("x", Shape::new(&[2, 3], DType::F32));
    let swap_axes = Op::Transpose { perm: vec![1, 0] };
    let transposed = graph.add_node(swap_axes, vec![input], Shape::new(&[3, 2], DType::F32));
    graph.set_outputs(vec![transposed]);
    let mut exe = WgpuExecutable::compile(graph);
    let results = exe.run(&[("x", &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])]);
    // Rows [1, 2, 3] and [4, 5, 6] become columns.
    assert_eq!(results[0], vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
}
/// Transposes a `[1, 2, 2, 2]` input with permutation `[0, 2, 1, 3]` (swapping
/// the two middle axes) and checks the reordered data.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn transpose_bhsd_layout_swap_matches_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("tr4");
    let input = graph.input("x", Shape::new(&[1, 2, 2, 2], DType::F32));
    let swap_middle = Op::Transpose {
        perm: vec![0, 2, 1, 3],
    };
    let swapped = graph.add_node(
        swap_middle,
        vec![input],
        Shape::new(&[1, 2, 2, 2], DType::F32),
    );
    graph.set_outputs(vec![swapped]);
    let mut exe = WgpuExecutable::compile(graph);
    // Input values 1.0 through 8.0 in flat order.
    let xs: Vec<f32> = (1u8..=8).map(f32::from).collect();
    let results = exe.run(&[("x", &xs)]);
    let want = vec![1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0];
    assert_eq!(results[0], want);
}
/// Narrows a `[1, 1, 4]` input along axis 2 (start 1, length 2) and checks
/// that the two middle elements are returned.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn narrow_axis2_slice_matches_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("nrw");
    let input = graph.input("x", Shape::new(&[1, 1, 4], DType::F32));
    let middle_two = Op::Narrow {
        axis: 2,
        start: 1,
        len: 2,
    };
    let sliced = graph.add_node(
        middle_two,
        vec![input],
        Shape::new(&[1, 1, 2], DType::F32),
    );
    graph.set_outputs(vec![sliced]);
    let mut exe = WgpuExecutable::compile(graph);
    let results = exe.run(&[("x", &[10.0, 20.0, 30.0, 40.0])]);
    assert_eq!(results[0], vec![20.0, 30.0]);
}
/// Concatenates a `[2, 2]` and a `[2, 3]` input into `[2, 5]` and checks the
/// interleaved rows. The axis is passed as `1`, which is the last axis of
/// these rank-2 inputs.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn concat_axis_minus_one_matches_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("cat");
    let left = graph.input("a", Shape::new(&[2, 2], DType::F32));
    let right = graph.input("b", Shape::new(&[2, 3], DType::F32));
    let joined = graph.concat(vec![left, right], 1, Shape::new(&[2, 5], DType::F32));
    graph.set_outputs(vec![joined]);
    let mut exe = WgpuExecutable::compile(graph);
    let results = exe.run(&[
        ("a", &[1.0, 2.0, 3.0, 4.0]),
        ("b", &[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]),
    ]);
    // Each output row is a row of `a` followed by the matching row of `b`.
    let want = vec![1.0, 2.0, 10.0, 20.0, 30.0, 3.0, 4.0, 40.0, 50.0, 60.0];
    assert_eq!(results[0], want);
}
/// Gathers rows 2 and 0 (axis 0) from a `[3, 2]` parameter table and checks
/// the returned rows. The index input is declared as `DType::F32` and fed
/// float-valued indices.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn gather_embedding_matches_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("gat");
    let table = graph.param("t", Shape::new(&[3, 2], DType::F32));
    let indices = graph.input("i", Shape::new(&[2], DType::F32));
    let rows = graph.gather(table, indices, 0, Shape::new(&[2, 2], DType::F32));
    graph.set_outputs(vec![rows]);
    let mut exe = WgpuExecutable::compile(graph);
    // Table rows: [10, 11], [20, 21], [30, 31].
    exe.set_param("t", &[10.0, 11.0, 20.0, 21.0, 30.0, 31.0]);
    let results = exe.run(&[("i", &[2.0, 0.0])]);
    assert_eq!(results[0], vec![30.0, 31.0, 10.0, 11.0]);
}
/// CPU reference for unmasked multi-head self-attention over a packed QKV buffer.
///
/// `packed` is indexed as `[b, s, 3 * nh * dh]`: each sequence row stores all
/// Q heads, then all K heads, then all V heads. Scores are scaled by
/// `1 / sqrt(dh)` and normalised with a single-pass running-max softmax.
/// The result is indexed as `[b, s, nh * dh]`.
fn cpu_attention_packed_qkv(packed: &[f32], b: usize, s: usize, nh: usize, dh: usize) -> Vec<f32> {
    let hidden = nh * dh;
    let row_stride = 3 * hidden;
    let scale = 1.0f32 / (dh as f32).sqrt();
    let mut out = vec![0.0f32; b * s * hidden];
    for batch in 0..b {
        for head in 0..nh {
            // Offset of this head's Q slice at a given sequence position;
            // the K and V slices sit `hidden` and `2 * hidden` further on.
            let q_start = |pos: usize| batch * s * row_stride + pos * row_stride + head * dh;
            for q_pos in 0..s {
                let q_off = q_start(q_pos);
                let query = &packed[q_off..q_off + dh];
                let mut running_max = f32::NEG_INFINITY;
                let mut denom = 0.0f32;
                let mut acc = vec![0.0f32; dh];
                for k_pos in 0..s {
                    let k_off = q_start(k_pos) + hidden;
                    let v_off = q_start(k_pos) + 2 * hidden;
                    let key = &packed[k_off..k_off + dh];
                    let value = &packed[v_off..v_off + dh];
                    let dot = query
                        .iter()
                        .zip(key)
                        .fold(0.0f32, |sum, (qv, kv)| sum + qv * kv);
                    let score = dot * scale;
                    // Rescale the previous partial sums to the new running max.
                    let new_max = running_max.max(score);
                    let rescale = (running_max - new_max).exp();
                    let weight = (score - new_max).exp();
                    denom = rescale * denom + weight;
                    for (slot, v) in acc.iter_mut().zip(value) {
                        *slot = rescale * *slot + weight * v;
                    }
                    running_max = new_max;
                }
                let inv_denom = 1.0 / denom;
                let o_off = batch * s * hidden + q_pos * hidden + head * dh;
                for (dst, a) in out[o_off..o_off + dh].iter_mut().zip(&acc) {
                    *dst = a * inv_denom;
                }
            }
        }
    }
    out
}
/// CPU reference for unmasked multi-head attention over separate Q, K, V buffers.
///
/// `q`, `k` and `v` are each indexed as `[b, s, nh, dh]`. Scores are scaled by
/// `1 / sqrt(dh)` and normalised with a single-pass running-max softmax.
/// The result uses the same `[b, s, nh, dh]` indexing.
fn cpu_attention_bshd(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    b: usize,
    s: usize,
    nh: usize,
    dh: usize,
) -> Vec<f32> {
    let hidden = nh * dh;
    let scale = 1.0f32 / (dh as f32).sqrt();
    let mut out = vec![0.0f32; b * s * hidden];
    // Flat offset of one head's `dh` values at a given batch and position.
    let head_offset =
        |batch: usize, pos: usize, head: usize| batch * s * hidden + pos * hidden + head * dh;
    for batch in 0..b {
        for q_pos in 0..s {
            for head in 0..nh {
                let q_base = head_offset(batch, q_pos, head);
                let query = &q[q_base..q_base + dh];
                let mut running_max = f32::NEG_INFINITY;
                let mut denom = 0.0f32;
                let mut acc = vec![0.0f32; dh];
                for k_pos in 0..s {
                    // K and V share the same offset for a given position and head.
                    let kv_base = head_offset(batch, k_pos, head);
                    let key = &k[kv_base..kv_base + dh];
                    let dot = query
                        .iter()
                        .zip(key)
                        .fold(0.0f32, |sum, (qv, kv)| sum + qv * kv);
                    let score = dot * scale;
                    // Rescale the previous partial sums to the new running max.
                    let new_max = running_max.max(score);
                    let rescale = (running_max - new_max).exp();
                    let weight = (score - new_max).exp();
                    denom = rescale * denom + weight;
                    let value = &v[kv_base..kv_base + dh];
                    for (slot, val) in acc.iter_mut().zip(value) {
                        *slot = rescale * *slot + weight * val;
                    }
                    running_max = new_max;
                }
                let inv_denom = 1.0 / denom;
                for (dst, a) in out[q_base..q_base + dh].iter_mut().zip(&acc) {
                    *dst = a * inv_denom;
                }
            }
        }
    }
    out
}
/// Compares a `[191, 200] x [200, 600]` matmul on the wgpu backend with the
/// CPU session; the largest absolute difference must stay below `1e-3`.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn matmul_eeg_qkv_shape_matches_cpu() {
    if !rlx_wgpu::is_available() {
        return;
    }
    use rlx::prelude::*;
    let (m, k, n) = (191usize, 200usize, 600usize);
    let mut g = Graph::new("mm");
    let a = g.input("a", Shape::new(&[m, k], DType::F32));
    let b = g.param("b", Shape::new(&[k, n], DType::F32));
    let y = g.matmul(a, b, Shape::new(&[m, n], DType::F32));
    g.set_outputs(vec![y]);
    // Deterministic, bounded test data.
    let a_v: Vec<f32> = (0..m * k).map(|i| (i as f32 * 0.03).sin()).collect();
    let b_v: Vec<f32> = (0..k * n).map(|i| (i as f32 * 0.02).cos() * 0.1).collect();
    let cpu = Session::new(Device::Cpu);
    let mut cc = cpu.compile(g.clone());
    cc.set_param("b", &b_v);
    let want = cc.run(&[("a", &a_v)]).into_iter().next().unwrap();
    let mut exe = WgpuExecutable::compile(g);
    exe.set_param("b", &b_v);
    let got = exe.run(&[("a", &a_v)]).into_iter().next().unwrap();
    // `zip` stops at the shorter input, so a truncated or empty GPU result
    // would otherwise compare as a perfect match.
    assert_eq!(want.len(), got.len(), "QKV matmul output length mismatch");
    // `f32::max` drops NaN operands; keep NaN sticky so a NaN difference
    // makes the `err < 1e-3` check fail instead of being ignored.
    let err = want
        .iter()
        .zip(got.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, |worst: f32, d: f32| {
            if d.is_nan() || d > worst {
                d
            } else {
                worst
            }
        });
    assert!(err < 1e-3, "QKV matmul max_abs={err:.3e}");
}
/// Compares a batched `[1, 191, 200] x [200, 600]` matmul on the wgpu backend
/// with the CPU session; the largest absolute difference must stay below `1e-3`.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn eeg_qkv_matmul_batched_matches_cpu() {
    if !rlx_wgpu::is_available() {
        return;
    }
    use rlx::prelude::*;
    let (b, s, k, n) = (1, 191, 200, 600);
    let mut g = Graph::new("mm3");
    let a = g.input("a", Shape::new(&[b, s, k], DType::F32));
    let w = g.param("w", Shape::new(&[k, n], DType::F32));
    let y = g.matmul(a, w, Shape::new(&[b, s, n], DType::F32));
    g.set_outputs(vec![y]);
    // Deterministic, bounded test data.
    let a_v: Vec<f32> = (0..b * s * k).map(|i| (i as f32 * 0.05).sin()).collect();
    let w_v: Vec<f32> = (0..k * n).map(|i| (i as f32 * 0.01).cos() * 0.1).collect();
    let cpu = Session::new(Device::Cpu);
    let mut cc = cpu.compile(g.clone());
    cc.set_param("w", &w_v);
    let want = cc.run(&[("a", &a_v)]).into_iter().next().unwrap();
    let mut exe = WgpuExecutable::compile(g);
    exe.set_param("w", &w_v);
    let got = exe.run(&[("a", &a_v)]).into_iter().next().unwrap();
    // `zip` stops at the shorter input, so a truncated or empty GPU result
    // would otherwise compare as a perfect match.
    assert_eq!(
        want.len(),
        got.len(),
        "batched QKV matmul output length mismatch"
    );
    // `f32::max` drops NaN operands; keep NaN sticky so a NaN difference
    // makes the `err < 1e-3` check fail instead of being ignored.
    let err = want
        .iter()
        .zip(got.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, |worst: f32, d: f32| {
            if d.is_nan() || d > worst {
                d
            } else {
                worst
            }
        });
    assert!(
        err < 1e-3,
        "batched QKV matmul max_abs={err:.3e} idx26862 cpu={} gpu={}",
        want[26862],
        got[26862]
    );
}
/// Runs fused matmul+bias producing `[b, s, 3 * hd]`, reshapes it to
/// `[b, s, 3, nh, dh]`, narrows index 0 of axis 2 (the Q slice), and compares
/// the wgpu output with the CPU session within a max-abs error of `1e-3`.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn eeg_qkv_fmb_and_narrows_match_cpu() {
    if !rlx_wgpu::is_available() {
        return;
    }
    use rlx::prelude::*;
    let (b, s, nh, dh) = (1, 191, 8, 25);
    let hd = nh * dh;
    let f = DType::F32;
    let mut g = Graph::new("fmb_narrow");
    let x = g.input("x", Shape::new(&[b, s, hd], f));
    let w = g.param("w", Shape::new(&[hd, 3 * hd], f));
    let bias = g.param("b", Shape::new(&[3 * hd], f));
    let qkv = g.add_node(
        Op::FusedMatMulBiasAct { activation: None },
        vec![x, w, bias],
        Shape::new(&[b, s, 3 * hd], f),
    );
    let qkv4 = g.reshape_(qkv, vec![b as i64, s as i64, 3, nh as i64, dh as i64]);
    let q0 = g.add_node(
        Op::Narrow {
            axis: 2,
            start: 0,
            len: 1,
        },
        vec![qkv4],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    g.set_outputs(vec![q0]);
    // Deterministic, bounded test data.
    let n = b * s * hd;
    let x_v: Vec<f32> = (0..n).map(|i| (i as f32 * 0.05).sin()).collect();
    let w_v: Vec<f32> = (0..(hd * 3 * hd))
        .map(|i| (i as f32 * 0.01).cos() * 0.1)
        .collect();
    let b_v: Vec<f32> = (0..(3 * hd)).map(|i| i as f32 * 0.001).collect();
    let cpu = Session::new(Device::Cpu);
    let mut cc = cpu.compile(g.clone());
    cc.set_param("w", &w_v);
    cc.set_param("b", &b_v);
    let want = cc.run(&[("x", &x_v)]).into_iter().next().unwrap();
    let mut exe = WgpuExecutable::compile(g);
    exe.set_param("w", &w_v);
    exe.set_param("b", &b_v);
    let got = exe.run(&[("x", &x_v)]).into_iter().next().unwrap();
    // `zip` stops at the shorter input, so a truncated or empty GPU result
    // would otherwise compare as a perfect match.
    assert_eq!(want.len(), got.len(), "FMB+narrow Q output length mismatch");
    // `f32::max` drops NaN operands; keep NaN sticky so a NaN difference
    // makes the `err < 1e-3` check fail instead of being ignored.
    let err = want
        .iter()
        .zip(got.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, |worst: f32, d: f32| {
            if d.is_nan() || d > worst {
                d
            } else {
                worst
            }
        });
    assert!(err < 1e-3, "FMB+narrow Q max_abs={err:.3e}");
}
/// Builds fused matmul+bias, reshapes to `[b, s, 3, nh, dh]`, narrows axis 2
/// into three slices and reshapes each to `[b, s, nh, dh]`. The three outputs
/// (Q, K, V) from wgpu are each compared with the CPU session within a
/// max-abs error of `1e-3`.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn eeg_qkv_tensors_gpu_match_cpu() {
    if !rlx_wgpu::is_available() {
        return;
    }
    use rlx::prelude::*;
    let (b, s, nh, dh) = (1, 191, 8, 25);
    let hd = nh * dh;
    let f = DType::F32;
    // Deterministic, bounded test data.
    let n = b * s * hd;
    let x_v: Vec<f32> = (0..n).map(|i| (i as f32 * 0.05).sin()).collect();
    let w_v: Vec<f32> = (0..(hd * 3 * hd))
        .map(|i| (i as f32 * 0.01).cos() * 0.1)
        .collect();
    let b_v: Vec<f32> = (0..(3 * hd)).map(|i| i as f32 * 0.001).collect();
    let mut g = Graph::new("qkv_tensors");
    let x = g.input("x", Shape::new(&[b, s, hd], f));
    let w = g.param("w", Shape::new(&[hd, 3 * hd], f));
    let bias = g.param("b", Shape::new(&[3 * hd], f));
    let qkv = g.add_node(
        Op::FusedMatMulBiasAct { activation: None },
        vec![x, w, bias],
        Shape::new(&[b, s, 3 * hd], f),
    );
    let qkv4 = g.reshape_(qkv, vec![b as i64, s as i64, 3, nh as i64, dh as i64]);
    // The narrow nodes are added in K, V, Q order, as in the original test.
    let k0 = g.add_node(
        Op::Narrow {
            axis: 2,
            start: 1,
            len: 1,
        },
        vec![qkv4],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let v0 = g.add_node(
        Op::Narrow {
            axis: 2,
            start: 2,
            len: 1,
        },
        vec![qkv4],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let q0 = g.add_node(
        Op::Narrow {
            axis: 2,
            start: 0,
            len: 1,
        },
        vec![qkv4],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let q = g.reshape_(q0, vec![b as i64, s as i64, nh as i64, dh as i64]);
    let k = g.reshape_(k0, vec![b as i64, s as i64, nh as i64, dh as i64]);
    let v = g.reshape_(v0, vec![b as i64, s as i64, nh as i64, dh as i64]);
    g.set_outputs(vec![q, k, v]);
    let cpu = Session::new(Device::Cpu);
    let mut cc = cpu.compile(g.clone());
    cc.set_param("w", &w_v);
    cc.set_param("b", &b_v);
    let want = cc.run(&[("x", &x_v)]);
    let mut exe = WgpuExecutable::compile(g);
    exe.set_param("w", &w_v);
    exe.set_param("b", &b_v);
    let got = exe.run(&[("x", &x_v)]);
    assert_eq!(want.len(), got.len(), "QKV output count mismatch");
    for (name, cpu_t, gpu_t) in [
        ("Q", &want[0], &got[0]),
        ("K", &want[1], &got[1]),
        ("V", &want[2], &got[2]),
    ] {
        // `zip` stops at the shorter input, so a truncated or empty GPU
        // tensor would otherwise compare as a perfect match.
        assert_eq!(
            cpu_t.len(),
            gpu_t.len(),
            "GPU {name} tensor length mismatch"
        );
        // `f32::max` drops NaN operands; keep NaN sticky so a NaN difference
        // makes the `err < 1e-3` check fail instead of being ignored.
        let err = cpu_t
            .iter()
            .zip(gpu_t.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, |worst: f32, d: f32| {
                if d.is_nan() || d > worst {
                    d
                } else {
                    worst
                }
            });
        assert!(
            err < 1e-3,
            "GPU {name} tensor max_abs={err:.3e} idx26862 cpu={} gpu={}",
            cpu_t[26862],
            gpu_t[26862]
        );
    }
}
/// Checks the standalone wgpu attention op against `cpu_attention_bshd` twice:
/// first on Q/K/V computed by the CPU session, then on Q/K/V computed by the
/// wgpu backend from the same QKV graph. Both comparisons require a max-abs
/// error below `1e-3`.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn eeg_attention_with_cpu_qkv_matches_cpu() {
    if !rlx_wgpu::is_available() {
        return;
    }
    use rlx::prelude::*;
    let (b, s, nh, dh) = (1, 191, 8, 25);
    let hd = nh * dh;
    let f = DType::F32;
    // Deterministic, bounded test data.
    let n = b * s * hd;
    let x_v: Vec<f32> = (0..n).map(|i| (i as f32 * 0.05).sin()).collect();
    let w_v: Vec<f32> = (0..(hd * 3 * hd))
        .map(|i| (i as f32 * 0.01).cos() * 0.1)
        .collect();
    let b_v: Vec<f32> = (0..(3 * hd)).map(|i| i as f32 * 0.001).collect();
    // Graph 1: fused matmul+bias -> reshape -> three narrows -> Q, K, V outputs.
    let mut g_qkv = Graph::new("qkv");
    let x = g_qkv.input("x", Shape::new(&[b, s, hd], f));
    let w = g_qkv.param("w", Shape::new(&[hd, 3 * hd], f));
    let bias = g_qkv.param("b", Shape::new(&[3 * hd], f));
    let qkv = g_qkv.add_node(
        Op::FusedMatMulBiasAct { activation: None },
        vec![x, w, bias],
        Shape::new(&[b, s, 3 * hd], f),
    );
    let qkv4 = g_qkv.reshape_(qkv, vec![b as i64, s as i64, 3, nh as i64, dh as i64]);
    let q0 = g_qkv.add_node(
        Op::Narrow {
            axis: 2,
            start: 0,
            len: 1,
        },
        vec![qkv4],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let k0 = g_qkv.add_node(
        Op::Narrow {
            axis: 2,
            start: 1,
            len: 1,
        },
        vec![qkv4],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let v0 = g_qkv.add_node(
        Op::Narrow {
            axis: 2,
            start: 2,
            len: 1,
        },
        vec![qkv4],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let q = g_qkv.reshape_(q0, vec![b as i64, s as i64, nh as i64, dh as i64]);
    let k = g_qkv.reshape_(k0, vec![b as i64, s as i64, nh as i64, dh as i64]);
    let v = g_qkv.reshape_(v0, vec![b as i64, s as i64, nh as i64, dh as i64]);
    g_qkv.set_outputs(vec![q, k, v]);
    let cpu = Session::new(Device::Cpu);
    let mut cc = cpu.compile(g_qkv.clone());
    cc.set_param("w", &w_v);
    cc.set_param("b", &b_v);
    let qkv_out = cc.run(&[("x", &x_v)]);
    let (q_v, k_v, v_v) = (&qkv_out[0], &qkv_out[1], &qkv_out[2]);
    let want = cpu_attention_bshd(q_v, k_v, v_v, b, s, nh, dh);
    // Graph 2: attention only, with Q, K, V as direct inputs.
    let mut g_attn = Graph::new("attn");
    let qi = g_attn.input("q", Shape::new(&[b, s, nh, dh], f));
    let ki = g_attn.input("k", Shape::new(&[b, s, nh, dh], f));
    let vi = g_attn.input("v", Shape::new(&[b, s, nh, dh], f));
    let out = g_attn.add_node(
        Op::Attention {
            num_heads: nh,
            head_dim: dh,
            mask_kind: MaskKind::None,
            score_scale: None,
            attn_logit_softcap: None,
        },
        vec![qi, ki, vi],
        Shape::new(&[b, s, nh, dh], f),
    );
    g_attn.set_outputs(vec![out]);
    let mut exe = WgpuExecutable::compile(g_attn);
    let got = exe
        .run(&[("q", q_v), ("k", k_v), ("v", v_v)])
        .into_iter()
        .next()
        .unwrap();
    // `zip` stops at the shorter input, so a truncated or empty GPU result
    // would otherwise compare as a perfect match.
    assert_eq!(
        want.len(),
        got.len(),
        "attention on CPU QKV output length mismatch"
    );
    // `f32::max` drops NaN operands; keep NaN sticky so a NaN difference
    // makes the `err < 1e-3` check fail instead of being ignored.
    let err = want
        .iter()
        .zip(got.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, |worst: f32, d: f32| {
            if d.is_nan() || d > worst {
                d
            } else {
                worst
            }
        });
    assert!(
        err < 1e-3,
        "attention on CPU QKV max_abs={err:.3e} idx26862 cpu={} gpu={}",
        want[26862],
        got[26862]
    );
    // Second pass: feed GPU-produced Q, K, V into the same attention executable.
    let mut exe_qkv = WgpuExecutable::compile(g_qkv.clone());
    exe_qkv.set_param("w", &w_v);
    exe_qkv.set_param("b", &b_v);
    let gpu_qkv = exe_qkv.run(&[("x", &x_v)]);
    let ref_on_gpu_qkv = cpu_attention_bshd(&gpu_qkv[0], &gpu_qkv[1], &gpu_qkv[2], b, s, nh, dh);
    let got2 = exe
        .run(&[("q", &gpu_qkv[0]), ("k", &gpu_qkv[1]), ("v", &gpu_qkv[2])])
        .into_iter()
        .next()
        .unwrap();
    assert_eq!(
        ref_on_gpu_qkv.len(),
        got2.len(),
        "wgpu attn on GPU QKV output length mismatch"
    );
    // Same NaN-propagating maximum as above.
    let err2 = ref_on_gpu_qkv
        .iter()
        .zip(got2.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, |worst: f32, d: f32| {
            if d.is_nan() || d > worst {
                d
            } else {
                worst
            }
        });
    assert!(
        err2 < 1e-3,
        "wgpu attn on GPU QKV vs cpu ref max_abs={err2:.3e} idx26862 ref={} gpu={}",
        ref_on_gpu_qkv[26862],
        got2[26862]
    );
}
/// Compiles two wgpu graphs that both output Q: one with only the QKV
/// projection chain, and one that additionally contains an attention node
/// consuming Q, K, V (not listed as an output). The two Q outputs must agree
/// within a max-abs error of `1e-6`.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn attention_in_graph_does_not_change_qkv_activations() {
    if !rlx_wgpu::is_available() {
        return;
    }
    use rlx::prelude::*;
    let (b, s, nh, dh) = (1, 191, 8, 25);
    let hd = nh * dh;
    let f = DType::F32;
    // Deterministic, bounded test data.
    let n = b * s * hd;
    let x_v: Vec<f32> = (0..n).map(|i| (i as f32 * 0.05).sin()).collect();
    let w_v: Vec<f32> = (0..(hd * 3 * hd))
        .map(|i| (i as f32 * 0.01).cos() * 0.1)
        .collect();
    let b_v: Vec<f32> = (0..(3 * hd)).map(|i| i as f32 * 0.001).collect();
    // Graph without attention: only the Q slice is produced.
    let mut g0 = Graph::new("qkv_only");
    let x = g0.input("x", Shape::new(&[b, s, hd], f));
    let w = g0.param("w", Shape::new(&[hd, 3 * hd], f));
    let bias = g0.param("b", Shape::new(&[3 * hd], f));
    let qkv = g0.add_node(
        Op::FusedMatMulBiasAct { activation: None },
        vec![x, w, bias],
        Shape::new(&[b, s, 3 * hd], f),
    );
    let qkv4 = g0.reshape_(qkv, vec![b as i64, s as i64, 3, nh as i64, dh as i64]);
    let q0 = g0.add_node(
        Op::Narrow {
            axis: 2,
            start: 0,
            len: 1,
        },
        vec![qkv4],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let q_only = g0.reshape_(q0, vec![b as i64, s as i64, nh as i64, dh as i64]);
    g0.set_outputs(vec![q_only]);
    // Graph with attention: same chain plus K, V and an attention node.
    let mut g1 = Graph::new("qkv_plus_attn");
    let x1 = g1.input("x", Shape::new(&[b, s, hd], f));
    let w1 = g1.param("w", Shape::new(&[hd, 3 * hd], f));
    let bias1 = g1.param("b", Shape::new(&[3 * hd], f));
    let qkv1 = g1.add_node(
        Op::FusedMatMulBiasAct { activation: None },
        vec![x1, w1, bias1],
        Shape::new(&[b, s, 3 * hd], f),
    );
    let qkv41 = g1.reshape_(qkv1, vec![b as i64, s as i64, 3, nh as i64, dh as i64]);
    let q01 = g1.add_node(
        Op::Narrow {
            axis: 2,
            start: 0,
            len: 1,
        },
        vec![qkv41],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let k01 = g1.add_node(
        Op::Narrow {
            axis: 2,
            start: 1,
            len: 1,
        },
        vec![qkv41],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let v01 = g1.add_node(
        Op::Narrow {
            axis: 2,
            start: 2,
            len: 1,
        },
        vec![qkv41],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let q1 = g1.reshape_(q01, vec![b as i64, s as i64, nh as i64, dh as i64]);
    let k1 = g1.reshape_(k01, vec![b as i64, s as i64, nh as i64, dh as i64]);
    let v1 = g1.reshape_(v01, vec![b as i64, s as i64, nh as i64, dh as i64]);
    // The attention node is added to the graph but only Q is an output.
    let _attn = g1.add_node(
        Op::Attention {
            num_heads: nh,
            head_dim: dh,
            mask_kind: MaskKind::None,
            score_scale: None,
            attn_logit_softcap: None,
        },
        vec![q1, k1, v1],
        Shape::new(&[b, s, nh, dh], f),
    );
    g1.set_outputs(vec![q1]);
    let mut exe0 = WgpuExecutable::compile(g0);
    exe0.set_param("w", &w_v);
    exe0.set_param("b", &b_v);
    let q_a = exe0.run(&[("x", &x_v)]).into_iter().next().unwrap();
    let mut exe1 = WgpuExecutable::compile(g1);
    exe1.set_param("w", &w_v);
    exe1.set_param("b", &b_v);
    let q_b = exe1.run(&[("x", &x_v)]).into_iter().next().unwrap();
    // `zip` stops at the shorter input, so a truncated or empty result
    // would otherwise compare as a perfect match.
    assert_eq!(q_a.len(), q_b.len(), "Q output length mismatch");
    // `f32::max` drops NaN operands; keep NaN sticky so a NaN difference
    // makes the `err < 1e-6` check fail instead of being ignored.
    let err = q_a
        .iter()
        .zip(q_b.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, |worst: f32, d: f32| {
            if d.is_nan() || d > worst {
                d
            } else {
                worst
            }
        });
    assert!(
        err < 1e-6,
        "Q differs when attention is in graph: max_abs={err:.3e} idx26862 a={} b={}",
        q_a[26862],
        q_b[26862]
    );
}
/// Runs fused matmul+bias followed by a reshape to `[b, s, 3, nh, dh]` and
/// compares the wgpu output with the CPU session within a max-abs error of
/// `1e-3`.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn eeg_qkv4_after_fmb_matches_cpu() {
    if !rlx_wgpu::is_available() {
        return;
    }
    use rlx::prelude::*;
    let (b, s, nh, dh) = (1, 191, 8, 25);
    let hd = nh * dh;
    let f = DType::F32;
    let mut g = Graph::new("qkv4");
    let x = g.input("x", Shape::new(&[b, s, hd], f));
    let w = g.param("w", Shape::new(&[hd, 3 * hd], f));
    let bias = g.param("b", Shape::new(&[3 * hd], f));
    let qkv = g.add_node(
        Op::FusedMatMulBiasAct { activation: None },
        vec![x, w, bias],
        Shape::new(&[b, s, 3 * hd], f),
    );
    let qkv4 = g.reshape_(qkv, vec![b as i64, s as i64, 3, nh as i64, dh as i64]);
    g.set_outputs(vec![qkv4]);
    // Deterministic, bounded test data.
    let n = b * s * hd;
    let x_v: Vec<f32> = (0..n).map(|i| (i as f32 * 0.05).sin()).collect();
    let w_v: Vec<f32> = (0..(hd * 3 * hd))
        .map(|i| (i as f32 * 0.01).cos() * 0.1)
        .collect();
    let b_v: Vec<f32> = (0..(3 * hd)).map(|i| i as f32 * 0.001).collect();
    let cpu = Session::new(Device::Cpu);
    let mut cc = cpu.compile(g.clone());
    cc.set_param("w", &w_v);
    cc.set_param("b", &b_v);
    let want = cc.run(&[("x", &x_v)]).into_iter().next().unwrap();
    let mut exe = WgpuExecutable::compile(g);
    exe.set_param("w", &w_v);
    exe.set_param("b", &b_v);
    let got = exe.run(&[("x", &x_v)]).into_iter().next().unwrap();
    // `zip` stops at the shorter input, so a truncated or empty GPU result
    // would otherwise compare as a perfect match.
    assert_eq!(want.len(), got.len(), "qkv4 after FMB output length mismatch");
    // `f32::max` drops NaN operands; keep NaN sticky so a NaN difference
    // makes the `err < 1e-3` check fail instead of being ignored.
    let err = want
        .iter()
        .zip(got.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, |worst: f32, d: f32| {
            if d.is_nan() || d > worst {
                d
            } else {
                worst
            }
        });
    assert!(err < 1e-3, "qkv4 after FMB max_abs={err:.3e}");
}
/// Builds the QKV projection + attention chain, passes it through
/// `rlx_wgpu::unfuse::unfuse`, and checks that
/// `rlx_ir::detect_packed_bshd_qkv_attention` recognises the attention inputs:
/// the reported width equals `nh * dh`, the parent node has a rank-5 shape,
/// and every reported narrow node is an `Op::Narrow`.
///
/// This test does not run anything on the GPU, so it has no availability guard.
#[test]
fn detect_packed_bshd_on_eeg_chain_graph() {
    use rlx::prelude::*;
    let (b, s, nh, dh) = (1, 191, 8, 25);
    let hd = nh * dh;
    let f = DType::F32;
    let mut g = Graph::new("qkv_chain");
    let x = g.input("x", Shape::new(&[b, s, hd], f));
    let w = g.param("w", Shape::new(&[hd, 3 * hd], f));
    let bias = g.param("b", Shape::new(&[3 * hd], f));
    let qkv = g.add_node(
        Op::FusedMatMulBiasAct { activation: None },
        vec![x, w, bias],
        Shape::new(&[b, s, 3 * hd], f),
    );
    let qkv4 = g.reshape_(qkv, vec![b as i64, s as i64, 3, nh as i64, dh as i64]);
    let q0 = g.add_node(
        Op::Narrow {
            axis: 2,
            start: 0,
            len: 1,
        },
        vec![qkv4],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let k0 = g.add_node(
        Op::Narrow {
            axis: 2,
            start: 1,
            len: 1,
        },
        vec![qkv4],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let v0 = g.add_node(
        Op::Narrow {
            axis: 2,
            start: 2,
            len: 1,
        },
        vec![qkv4],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let q = g.reshape_(q0, vec![b as i64, s as i64, nh as i64, dh as i64]);
    let k = g.reshape_(k0, vec![b as i64, s as i64, nh as i64, dh as i64]);
    let v = g.reshape_(v0, vec![b as i64, s as i64, nh as i64, dh as i64]);
    let out = g.add_node(
        Op::Attention {
            num_heads: nh,
            head_dim: dh,
            mask_kind: MaskKind::None,
            score_scale: None,
            attn_logit_softcap: None,
        },
        vec![q, k, v],
        Shape::new(&[b, s, nh, dh], f),
    );
    g.set_outputs(vec![out]);
    let g = rlx_wgpu::unfuse::unfuse(g);
    // Locate the attention node in the rewritten graph; its inputs may no
    // longer be the node ids created above.
    let attn = g
        .nodes()
        .iter()
        .find(|n| matches!(n.op, Op::Attention { .. }))
        .expect("unfused graph should still contain an Attention node");
    let (parent, hw, narrows) = rlx_ir::detect_packed_bshd_qkv_attention(
        &g,
        attn.inputs[0],
        attn.inputs[1],
        attn.inputs[2],
    )
    .expect("packed BSHD QKV pattern should be detected");
    assert_eq!(hw, nh * dh);
    assert_eq!(g.node(parent).shape.dims().len(), 5);
    assert!(
        narrows
            .iter()
            .all(|&n| matches!(g.node(n).op, Op::Narrow { .. }))
    );
}
/// Compiles the QKV projection + attention chain for wgpu and inspects the
/// executable's test hooks: the attention Q sequence stride must be `3 * hd`,
/// the K and V offsets must each sit `hd` elements after the previous one, and
/// the Q offset must equal the arena offset of the rank-5 `qkv4` parent node.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn wgpu_chain_attention_uses_packed_qkv_stride() {
    if !rlx_wgpu::is_available() {
        return;
    }
    use rlx::prelude::*;
    let (b, s, nh, dh) = (1, 191, 8, 25);
    let hd = nh * dh;
    let f = DType::F32;
    // Target dims for the per-tensor reshapes after narrowing.
    let bshd: Vec<i64> = vec![b as i64, s as i64, nh as i64, dh as i64];
    let mut graph = Graph::new("qkv_chain");
    let input = graph.input("x", Shape::new(&[b, s, hd], f));
    let weight = graph.param("w", Shape::new(&[hd, 3 * hd], f));
    let bias = graph.param("b", Shape::new(&[3 * hd], f));
    let projected = graph.add_node(
        Op::FusedMatMulBiasAct { activation: None },
        vec![input, weight, bias],
        Shape::new(&[b, s, 3 * hd], f),
    );
    let qkv4 = graph.reshape_(projected, vec![b as i64, s as i64, 3, nh as i64, dh as i64]);
    let q_slice = graph.add_node(
        Op::Narrow {
            axis: 2,
            start: 0,
            len: 1,
        },
        vec![qkv4],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let k_slice = graph.add_node(
        Op::Narrow {
            axis: 2,
            start: 1,
            len: 1,
        },
        vec![qkv4],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let v_slice = graph.add_node(
        Op::Narrow {
            axis: 2,
            start: 2,
            len: 1,
        },
        vec![qkv4],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let q = graph.reshape_(q_slice, bshd.clone());
    let k = graph.reshape_(k_slice, bshd.clone());
    let v = graph.reshape_(v_slice, bshd);
    let attention = graph.add_node(
        Op::Attention {
            num_heads: nh,
            head_dim: dh,
            mask_kind: MaskKind::None,
            score_scale: None,
            attn_logit_softcap: None,
        },
        vec![q, k, v],
        Shape::new(&[b, s, nh, dh], f),
    );
    graph.set_outputs(vec![attention]);
    let exe = WgpuExecutable::compile(graph);
    let qs = exe.test_attn_q_seq_stride().expect("attention step");
    assert_eq!(
        qs,
        (hd * 3) as u32,
        "expected packed seq stride 3*H*D=600, got {qs}"
    );
    let (qo, ko, vo, _) = exe.test_attn_offsets_and_stride().unwrap();
    let parent_g = exe.test_arena_offset_elems(qkv4);
    let q_g = exe.test_arena_offset_elems(q);
    // K follows Q and V follows K, each by one hidden width.
    assert_eq!(ko - qo, hd as u32);
    assert_eq!(vo - ko, hd as u32);
    assert_eq!(
        qo, parent_g,
        "q_off={qo} should equal parent qkv4 global off={parent_g}, q global={q_g}"
    );
}
/// Produces a packed `[b, s, 3, nh, dh]` QKV tensor with the CPU session,
/// computes the expected attention with `cpu_attention_packed_qkv`, then feeds
/// the same packed tensor into a wgpu graph that narrows, reshapes and runs
/// attention. The results must agree within a max-abs error of `1e-3`.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn wgpu_packed_attn_matches_strided_cpu_ref() {
    if !rlx_wgpu::is_available() {
        return;
    }
    use rlx::prelude::*;
    let (b, s, nh, dh) = (1, 191, 8, 25);
    let hd = nh * dh;
    let f = DType::F32;
    // Deterministic, bounded test data.
    let n = b * s * hd;
    let x_v: Vec<f32> = (0..n).map(|i| (i as f32 * 0.05).sin()).collect();
    let w_v: Vec<f32> = (0..(hd * 3 * hd))
        .map(|i| (i as f32 * 0.01).cos() * 0.1)
        .collect();
    let b_v: Vec<f32> = (0..(3 * hd)).map(|i| i as f32 * 0.001).collect();
    // Graph 1 (CPU): produce the packed QKV tensor.
    let mut g_pack = Graph::new("pack");
    let x = g_pack.input("x", Shape::new(&[b, s, hd], f));
    let w = g_pack.param("w", Shape::new(&[hd, 3 * hd], f));
    let bias = g_pack.param("b", Shape::new(&[3 * hd], f));
    let qkv = g_pack.add_node(
        Op::FusedMatMulBiasAct { activation: None },
        vec![x, w, bias],
        Shape::new(&[b, s, 3 * hd], f),
    );
    let qkv4 = g_pack.reshape_(qkv, vec![b as i64, s as i64, 3, nh as i64, dh as i64]);
    g_pack.set_outputs(vec![qkv4]);
    let cpu = Session::new(Device::Cpu);
    let mut cc = cpu.compile(g_pack.clone());
    cc.set_param("w", &w_v);
    cc.set_param("b", &b_v);
    let packed = cc.run(&[("x", &x_v)]).into_iter().next().unwrap();
    let want = cpu_attention_packed_qkv(&packed, b, s, nh, dh);
    // Graph 2 (wgpu): packed input -> narrows -> reshapes -> attention.
    let mut g_attn = Graph::new("attn_pack");
    let pin = g_attn.input("p", Shape::new(&[b, s, 3, nh, dh], f));
    let q0 = g_attn.add_node(
        Op::Narrow {
            axis: 2,
            start: 0,
            len: 1,
        },
        vec![pin],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let k0 = g_attn.add_node(
        Op::Narrow {
            axis: 2,
            start: 1,
            len: 1,
        },
        vec![pin],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let v0 = g_attn.add_node(
        Op::Narrow {
            axis: 2,
            start: 2,
            len: 1,
        },
        vec![pin],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let q = g_attn.reshape_(q0, vec![b as i64, s as i64, nh as i64, dh as i64]);
    let k = g_attn.reshape_(k0, vec![b as i64, s as i64, nh as i64, dh as i64]);
    let v = g_attn.reshape_(v0, vec![b as i64, s as i64, nh as i64, dh as i64]);
    let out = g_attn.add_node(
        Op::Attention {
            num_heads: nh,
            head_dim: dh,
            mask_kind: MaskKind::None,
            score_scale: None,
            attn_logit_softcap: None,
        },
        vec![q, k, v],
        Shape::new(&[b, s, nh, dh], f),
    );
    g_attn.set_outputs(vec![out]);
    let mut exe = WgpuExecutable::compile(g_attn);
    let got = exe.run(&[("p", &packed)]).into_iter().next().unwrap();
    // `zip` stops at the shorter input, so a truncated or empty GPU result
    // would otherwise compare as a perfect match.
    assert_eq!(
        want.len(),
        got.len(),
        "packed strided attn output length mismatch"
    );
    // `f32::max` drops NaN operands; keep NaN sticky so a NaN difference
    // makes the `err < 1e-3` check fail instead of being ignored.
    let err = want
        .iter()
        .zip(got.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, |worst: f32, d: f32| {
            if d.is_nan() || d > worst {
                d
            } else {
                worst
            }
        });
    assert!(
        err < 1e-3,
        "packed strided attn max_abs={err:.3e} idx26862 want={} got={}",
        want[26862],
        got[26862]
    );
}
/// Runs the full chain (fused matmul+bias, reshape to `[b, s, 3, nh, dh]`,
/// three narrows, reshapes to `[b, s, nh, dh]`, attention) on both the CPU
/// session and the wgpu backend and requires a max-abs error below `1e-3`.
///
/// Returns early without asserting when wgpu is not available.
#[test]
fn encoder_qkv_attention_chain_matches_cpu() {
    if !rlx_wgpu::is_available() {
        return;
    }
    use rlx::prelude::*;
    let (b, s, nh, dh) = (1, 191, 8, 25);
    let hd = nh * dh;
    let f = DType::F32;
    let mut g = Graph::new("qkv_chain");
    let x = g.input("x", Shape::new(&[b, s, hd], f));
    let w = g.param("w", Shape::new(&[hd, 3 * hd], f));
    let bias = g.param("b", Shape::new(&[3 * hd], f));
    let qkv = g.add_node(
        Op::FusedMatMulBiasAct { activation: None },
        vec![x, w, bias],
        Shape::new(&[b, s, 3 * hd], f),
    );
    let qkv4 = g.reshape_(qkv, vec![b as i64, s as i64, 3, nh as i64, dh as i64]);
    let q0 = g.add_node(
        Op::Narrow {
            axis: 2,
            start: 0,
            len: 1,
        },
        vec![qkv4],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let k0 = g.add_node(
        Op::Narrow {
            axis: 2,
            start: 1,
            len: 1,
        },
        vec![qkv4],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let v0 = g.add_node(
        Op::Narrow {
            axis: 2,
            start: 2,
            len: 1,
        },
        vec![qkv4],
        Shape::new(&[b, s, 1, nh, dh], f),
    );
    let q = g.reshape_(q0, vec![b as i64, s as i64, nh as i64, dh as i64]);
    let k = g.reshape_(k0, vec![b as i64, s as i64, nh as i64, dh as i64]);
    let v = g.reshape_(v0, vec![b as i64, s as i64, nh as i64, dh as i64]);
    let out = g.add_node(
        Op::Attention {
            num_heads: nh,
            head_dim: dh,
            mask_kind: MaskKind::None,
            score_scale: None,
            attn_logit_softcap: None,
        },
        vec![q, k, v],
        Shape::new(&[b, s, nh, dh], f),
    );
    g.set_outputs(vec![out]);
    // Deterministic, bounded test data.
    let n = b * s * hd;
    let x_v: Vec<f32> = (0..n).map(|i| (i as f32 * 0.05).sin()).collect();
    let w_v: Vec<f32> = (0..(hd * 3 * hd))
        .map(|i| (i as f32 * 0.01).cos() * 0.1)
        .collect();
    let b_v: Vec<f32> = (0..(3 * hd)).map(|i| i as f32 * 0.001).collect();
    let cpu_sess = Session::new(Device::Cpu);
    let mut cpu_c = cpu_sess.compile(g.clone());
    cpu_c.set_param("w", &w_v);
    cpu_c.set_param("b", &b_v);
    let cpu_out = cpu_c.run(&[("x", &x_v)]).into_iter().next().unwrap();
    let mut exe = WgpuExecutable::compile(g);
    exe.set_param("w", &w_v);
    exe.set_param("b", &b_v);
    let gpu_out = exe.run(&[("x", &x_v)]).into_iter().next().unwrap();
    // `zip` stops at the shorter input, so a truncated or empty GPU result
    // would otherwise compare as a perfect match.
    assert_eq!(
        cpu_out.len(),
        gpu_out.len(),
        "QKV+attention chain output length mismatch"
    );
    // `f32::max` drops NaN operands; keep NaN sticky so a NaN difference
    // makes the `err < 1e-3` check fail instead of being ignored.
    let err = cpu_out
        .iter()
        .zip(gpu_out.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, |worst: f32, d: f32| {
            if d.is_nan() || d > worst {
                d
            } else {
                worst
            }
        });
    assert!(
        err < 1e-3,
        "QKV+attention chain max_abs={err:.3e} idx26862 cpu={} gpu={}",
        cpu_out[26862],
        gpu_out[26862]
    );
}
/// Compares unmasked attention on `[1, 191, 8, 25]` inputs with the
/// `cpu_attention_bshd` reference, requiring a max absolute error below `1e-3`.
#[test]
fn attention_bshd_eeg_shape_matches_cpu() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let (b, s, nh, dh) = (1, 191, 8, 25);
    let n = b * s * nh * dh;
    // Deterministic pseudo-signals so Q, K and V are all distinct.
    let q: Vec<f32> = (0..n).map(|i| (i as f32 * 0.07).sin() * 0.5).collect();
    let k: Vec<f32> = (0..n).map(|i| (i as f32 * 0.11).cos() * 0.3).collect();
    let v: Vec<f32> = (0..n).map(|i| (i as f32 * 0.03) % 1.0 - 0.5).collect();
    let want = cpu_attention_bshd(&q, &k, &v, b, s, nh, dh);
    let mut g = Graph::new("bshd_eeg");
    let qi = g.input("q", Shape::new(&[b, s, nh, dh], DType::F32));
    let ki = g.input("k", Shape::new(&[b, s, nh, dh], DType::F32));
    let vi = g.input("v", Shape::new(&[b, s, nh, dh], DType::F32));
    let o = g.add_node(
        Op::Attention {
            num_heads: nh,
            head_dim: dh,
            mask_kind: MaskKind::None,
            score_scale: None,
            attn_logit_softcap: None,
        },
        vec![qi, ki, vi],
        Shape::new(&[b, s, nh, dh], DType::F32),
    );
    g.set_outputs(vec![o]);
    let mut exe = WgpuExecutable::compile(g);
    let got = exe
        .run(&[("q", &q), ("k", &k), ("v", &v)])
        .into_iter()
        .next()
        .unwrap();
    // `zip` stops at the shorter side; without this a short GPU output would
    // be compared only over its prefix.
    assert_eq!(
        want.len(),
        got.len(),
        "BSHD attention output length mismatch"
    );
    // `f32::max` drops NaN operands, so a NaN difference must be propagated
    // by hand or the comparison below would pass on NaN output.
    let err = want
        .iter()
        .zip(got.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(
            0.0f32,
            |worst, d| if d.is_nan() || d > worst { d } else { worst },
        );
    assert!(
        err < 1e-3,
        "BSHD attention [1,191,8,25] max_abs={err:.3e} (idx 26862: cpu={} gpu={})",
        want[26862],
        got[26862]
    );
}
/// Runs a minimal unmasked attention node and compares it to precomputed values.
///
/// Q and K are 2x2 identity data and V holds the rows (10, 20) and (30, 40);
/// the expected output is a softmax-weighted blend of those two rows.
#[test]
fn attention_no_mask_matches_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    // All three inputs and the output share one shape.
    let tensor_shape = Shape::new(&[1, 1, 2, 2], DType::F32);
    let mut graph = Graph::new("attn");
    let q_in = graph.input("q", tensor_shape.clone());
    let k_in = graph.input("k", tensor_shape.clone());
    let v_in = graph.input("v", tensor_shape.clone());
    let attn_op = Op::Attention {
        num_heads: 1,
        head_dim: 2,
        mask_kind: MaskKind::None,
        score_scale: None,
        attn_logit_softcap: None,
    };
    let attended = graph.add_node(attn_op, vec![q_in, k_in, v_in], tensor_shape);
    graph.set_outputs(vec![attended]);
    let mut gpu = WgpuExecutable::compile(graph);

    let qd = vec![1.0, 0.0, 0.0, 1.0];
    let kd = vec![1.0, 0.0, 0.0, 1.0];
    let vd = vec![10.0, 20.0, 30.0, 40.0];
    let outputs = gpu.run(&[("q", &qd), ("k", &kd), ("v", &vd)]);

    let want = vec![16.605, 26.605, 23.395, 33.395];
    assert!(
        close(&outputs[0], &want, 5e-3),
        "attention mismatch: got {:?} want {want:?}",
        outputs[0]
    );
}
/// With `cos = 1` and `sin = 0` the rope node must return its input unchanged.
#[test]
fn rope_identity_passes_through() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let data_shape = Shape::new(&[1, 1, 1, 4], DType::F32);
    let table_shape = Shape::new(&[1, 2], DType::F32);

    let mut graph = Graph::new("rope");
    let x_in = graph.input("x", data_shape.clone());
    let cos_in = graph.input("cos", table_shape.clone());
    let sin_in = graph.input("sin", table_shape);
    let rope_op = Op::Rope {
        head_dim: 4,
        n_rot: 4,
    };
    let rotated = graph.add_node(rope_op, vec![x_in, cos_in, sin_in], data_shape);
    graph.set_outputs(vec![rotated]);

    let mut gpu = WgpuExecutable::compile(graph);
    let outputs = gpu.run(&[
        ("x", &[1.0, 2.0, 3.0, 4.0]),
        ("cos", &[1.0, 1.0]),
        ("sin", &[0.0, 0.0]),
    ]);
    assert_eq!(outputs[0], vec![1.0, 2.0, 3.0, 4.0]);
}
/// With `cos = 0` and `sin = 1` the rope node must map (1, 2, 3, 4) to
/// (-3, -4, 1, 2).
#[test]
fn rope_90_degree_rotation_matches_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let data_shape = Shape::new(&[1, 1, 1, 4], DType::F32);
    let table_shape = Shape::new(&[1, 2], DType::F32);

    let mut graph = Graph::new("rope90");
    let x_in = graph.input("x", data_shape.clone());
    let cos_in = graph.input("cos", table_shape.clone());
    let sin_in = graph.input("sin", table_shape);
    let rope_op = Op::Rope {
        head_dim: 4,
        n_rot: 4,
    };
    let rotated = graph.add_node(rope_op, vec![x_in, cos_in, sin_in], data_shape);
    graph.set_outputs(vec![rotated]);

    let mut gpu = WgpuExecutable::compile(graph);
    let outputs = gpu.run(&[
        ("x", &[1.0, 2.0, 3.0, 4.0]),
        ("cos", &[0.0, 0.0]),
        ("sin", &[1.0, 1.0]),
    ]);
    let want = vec![-3.0, -4.0, 1.0, 2.0];
    assert!(
        close(&outputs[0], &want, 1e-5),
        "rope90 mismatch: got {:?} want {want:?}",
        outputs[0]
    );
}
/// Expanding a `[1, 3]` row to `[2, 3]` must repeat the row twice.
#[test]
fn expand_broadcast_replicates_values() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("expand");
    let row = graph.input("x", Shape::new(&[1, 3], DType::F32));
    let expand_op = Op::Expand {
        target_shape: vec![2, 3],
    };
    let tiled = graph.add_node(expand_op, vec![row], Shape::new(&[2, 3], DType::F32));
    graph.set_outputs(vec![tiled]);

    let mut gpu = WgpuExecutable::compile(graph);
    let outputs = gpu.run(&[("x", &[1.0, 2.0, 3.0])]);
    assert_eq!(outputs[0], vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
}
/// `DotGeneral` contracting lhs axis 1 with rhs axis 0 (no batch axes) must
/// produce the ordinary `[2, 3] x [3, 2]` matrix product.
#[test]
fn dot_general_canonical_matches_matmul() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("dg");
    let lhs = graph.input("x", Shape::new(&[2, 3], DType::F32));
    let rhs = graph.param("w", Shape::new(&[3, 2], DType::F32));
    let contraction = Op::DotGeneral {
        lhs_contracting: vec![1],
        rhs_contracting: vec![0],
        lhs_batch: vec![],
        rhs_batch: vec![],
    };
    let product = graph.add_node(contraction, vec![lhs, rhs], Shape::new(&[2, 2], DType::F32));
    graph.set_outputs(vec![product]);

    let mut gpu = WgpuExecutable::compile(graph);
    gpu.set_param("w", &[1.0, 0.0, 0.0, 1.0, 0.5, 0.5]);
    let outputs = gpu.run(&[("x", &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])]);
    // Row 0: (1 + 1.5, 2 + 1.5); row 1: (4 + 3, 5 + 3).
    assert!(close(&outputs[0], &[2.5, 3.5, 7.0, 8.0], 1e-5));
}
/// Sampling from logits with one overwhelmingly large entry must return that
/// entry's index (2).
#[test]
fn sample_argmax_picks_dominant_logit() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("samp");
    let logits_in = graph.input("l", Shape::new(&[1, 5], DType::F32));
    let sampler = Op::Sample {
        top_k: 0,
        top_p: 1.0,
        temperature: 1.0,
        seed: 0,
    };
    let token = graph.add_node(sampler, vec![logits_in], Shape::new(&[1], DType::F32));
    graph.set_outputs(vec![token]);

    let mut gpu = WgpuExecutable::compile(graph);
    let outputs = gpu.run(&[("l", &[0.0, 0.0, 100.0, 0.0, 0.0])]);
    // The sampled id comes back as a float; compare it as an integer.
    let picked = outputs[0][0] as i32;
    assert_eq!(picked, 2);
}
/// 2x2 max pooling with stride 2 over a 4x4 ramp (1..=16) must keep the
/// bottom-right element of each window.
#[test]
fn pool_2x2_max_stride_2_matches_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("pool");
    let image = graph.input("x", Shape::new(&[1, 1, 4, 4], DType::F32));
    let pool_op = Op::Pool {
        kind: ReduceOp::Max,
        kernel_size: vec![2, 2],
        stride: vec![2, 2],
        padding: vec![0, 0],
    };
    let pooled = graph.add_node(pool_op, vec![image], Shape::new(&[1, 1, 2, 2], DType::F32));
    graph.set_outputs(vec![pooled]);

    let mut gpu = WgpuExecutable::compile(graph);
    let ramp: Vec<f32> = (1..=16).map(|i| i as f32).collect();
    let outputs = gpu.run(&[("x", &ramp)]);
    assert_eq!(outputs[0], vec![6.0, 8.0, 14.0, 16.0]);
}
/// A 1x1 convolution whose single weight is 1.0 must reproduce its input.
#[test]
fn conv2d_1x1_identity_matches_input() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("conv");
    let image = graph.input("x", Shape::new(&[1, 1, 2, 2], DType::F32));
    let kernel = graph.param("w", Shape::new(&[1, 1, 1, 1], DType::F32));
    let conv_op = Op::Conv {
        kernel_size: vec![1, 1],
        stride: vec![1, 1],
        padding: vec![0, 0],
        dilation: vec![1, 1],
        groups: 1,
    };
    let convolved = graph.add_node(
        conv_op,
        vec![image, kernel],
        Shape::new(&[1, 1, 2, 2], DType::F32),
    );
    graph.set_outputs(vec![convolved]);

    let mut gpu = WgpuExecutable::compile(graph);
    gpu.set_param("w", &[1.0]);
    let outputs = gpu.run(&[("x", &[1.0, 2.0, 3.0, 4.0])]);
    assert_eq!(outputs[0], vec![1.0, 2.0, 3.0, 4.0]);
}
/// 1-D max pooling (kernel 2, stride 2) over (1, 3, 2, 4) must yield (3, 4).
#[test]
fn pool1d_max_matches_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("pool1d");
    let signal = graph.input("x", Shape::new(&[1, 1, 4], DType::F32));
    let pool_op = Op::Pool {
        kind: ReduceOp::Max,
        kernel_size: vec![2],
        stride: vec![2],
        padding: vec![0],
    };
    let pooled = graph.add_node(pool_op, vec![signal], Shape::new(&[1, 1, 2], DType::F32));
    graph.set_outputs(vec![pooled]);

    let mut gpu = WgpuExecutable::compile(graph);
    let outputs = gpu.run(&[("x", &[1.0, 3.0, 2.0, 4.0])]);
    assert_eq!(outputs[0], vec![3.0, 4.0]);
}
/// 3-D max pooling with a 2x2x2 window over a 2x2x2 ramp (1..=8) must return
/// the single maximum, 8.
#[test]
fn pool3d_max_matches_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("pool3d");
    let volume = graph.input("x", Shape::new(&[1, 1, 2, 2, 2], DType::F32));
    let pool_op = Op::Pool {
        kind: ReduceOp::Max,
        kernel_size: vec![2, 2, 2],
        stride: vec![1, 1, 1],
        padding: vec![0, 0, 0],
    };
    let pooled = graph.add_node(
        pool_op,
        vec![volume],
        Shape::new(&[1, 1, 1, 1, 1], DType::F32),
    );
    graph.set_outputs(vec![pooled]);

    let mut gpu = WgpuExecutable::compile(graph);
    let ramp: Vec<f32> = (1..=8).map(|i| i as f32).collect();
    let outputs = gpu.run(&[("x", &ramp)]);
    assert_eq!(outputs[0], vec![8.0]);
}
/// A 1-D convolution with kernel (1, -1) over the ramp (1, 2, 3, 4) must
/// produce -1 at every output position.
#[test]
fn conv1d_simple_matches_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("conv1d");
    let signal = graph.input("x", Shape::new(&[1, 1, 4], DType::F32));
    let kernel = graph.param("w", Shape::new(&[1, 1, 2], DType::F32));
    let conv_op = Op::Conv {
        kernel_size: vec![2],
        stride: vec![1],
        padding: vec![0],
        dilation: vec![1],
        groups: 1,
    };
    let convolved = graph.add_node(
        conv_op,
        vec![signal, kernel],
        Shape::new(&[1, 1, 3], DType::F32),
    );
    graph.set_outputs(vec![convolved]);

    let mut gpu = WgpuExecutable::compile(graph);
    gpu.set_param("w", &[1.0, -1.0]);
    let outputs = gpu.run(&[("x", &[1.0, 2.0, 3.0, 4.0])]);
    assert_eq!(outputs[0], vec![-1.0, -1.0, -1.0]);
}
/// A 1x1x1 convolution whose single weight is 1.0 must reproduce its input.
#[test]
fn conv3d_1x1x1_identity_matches_input() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("conv3d");
    let volume = graph.input("x", Shape::new(&[1, 1, 2, 2, 2], DType::F32));
    let kernel = graph.param("w", Shape::new(&[1, 1, 1, 1, 1], DType::F32));
    let conv_op = Op::Conv {
        kernel_size: vec![1, 1, 1],
        stride: vec![1, 1, 1],
        padding: vec![0, 0, 0],
        dilation: vec![1, 1, 1],
        groups: 1,
    };
    let convolved = graph.add_node(
        conv_op,
        vec![volume, kernel],
        Shape::new(&[1, 1, 2, 2, 2], DType::F32),
    );
    graph.set_outputs(vec![convolved]);

    let mut gpu = WgpuExecutable::compile(graph);
    gpu.set_param("w", &[1.0]);
    let ramp: Vec<f32> = (1..=8).map(|i| i as f32).collect();
    let outputs = gpu.run(&[("x", &ramp)]);
    assert_eq!(outputs[0], ramp);
}
/// The fused matmul + bias + ReLU node must equal `matmul_ref`, followed by a
/// per-column bias add and a clamp at zero.
#[test]
fn fused_matmul_bias_act_matches_unfused_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("fmb");
    let x_in = graph.input("x", Shape::new(&[2, 3], DType::F32));
    let weight = graph.param("w", Shape::new(&[3, 2], DType::F32));
    let bias = graph.param("b", Shape::new(&[2], DType::F32));
    let fused = graph.fused_matmul_bias_act(
        x_in,
        weight,
        bias,
        Some(Activation::Relu),
        Shape::new(&[2, 2], DType::F32),
    );
    graph.set_outputs(vec![fused]);
    let mut gpu = WgpuExecutable::compile(graph);

    let xv = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let wv = vec![0.1, 0.2, 0.3, 0.4, -0.5, 0.6];
    let bv = vec![-2.0, 0.5];
    gpu.set_param("w", &wv);
    gpu.set_param("b", &bv);
    let outputs = gpu.run(&[("x", &xv)]);

    // Unfused reference: the two bias entries repeat across the flat output.
    let product = matmul_ref(&xv, &wv, 2, 3, 2);
    let want: Vec<f32> = product
        .iter()
        .zip(bv.iter().cycle())
        .map(|(&acc, &offset)| (acc + offset).max(0.0))
        .collect();
    assert!(
        close(&outputs[0], &want, 1e-4),
        "FMB mismatch: got {:?} want {want:?}",
        outputs[0]
    );
}
/// The fused residual + layer-norm node (unit gain, zero shift) must match a
/// row-wise reference: add the residual, subtract the mean, divide by
/// `sqrt(var + 1e-5)`.
#[test]
fn fused_residual_ln_matches_unfused_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("frln");
    let x_in = graph.input("x", Shape::new(&[2, 4], DType::F32));
    let res_in = graph.input("r", Shape::new(&[2, 4], DType::F32));
    let gain = graph.param("g", Shape::new(&[4], DType::F32));
    let shift = graph.param("b", Shape::new(&[4], DType::F32));
    let normed = graph.fused_residual_ln(
        x_in,
        res_in,
        None,
        gain,
        shift,
        1e-5,
        Shape::new(&[2, 4], DType::F32),
    );
    graph.set_outputs(vec![normed]);

    let mut gpu = WgpuExecutable::compile(graph);
    gpu.set_param("g", &[1.0, 1.0, 1.0, 1.0]);
    gpu.set_param("b", &[0.0, 0.0, 0.0, 0.0]);
    let xv = vec![1.0, 2.0, 3.0, 4.0, 0.0, 1.0, 2.0, 3.0];
    let rv = vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];
    let out = gpu.run(&[("x", &xv), ("r", &rv)]);

    // Reference, one row of four values at a time.
    let mut want = vec![0f32; 8];
    for (row, dst) in want.chunks_mut(4).enumerate() {
        let base = row * 4;
        let summed: Vec<f32> = (base..base + 4).map(|idx| xv[idx] + rv[idx]).collect();
        let mean = summed.iter().sum::<f32>() / 4.0;
        let var = summed.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / 4.0;
        let inv = 1.0 / (var + 1e-5).sqrt();
        for (slot, &value) in dst.iter_mut().zip(&summed) {
            *slot = (value - mean) * inv;
        }
    }
    assert!(
        close(&out[0], &want, 1e-3),
        "FusedResidualLN mismatch: got {:?} want {want:?}",
        out[0]
    );
}
/// The fused residual + RMS-norm node (unit gain, zero shift) must match a
/// row-wise reference: add the residual, then scale by
/// `1 / sqrt(mean(v^2) + 1e-5)`.
#[test]
fn fused_residual_rms_norm_matches_unfused_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("frrms");
    let x_in = graph.input("x", Shape::new(&[2, 4], DType::F32));
    let res_in = graph.input("r", Shape::new(&[2, 4], DType::F32));
    let gain = graph.param("g", Shape::new(&[4], DType::F32));
    let shift = graph.param("b", Shape::new(&[4], DType::F32));
    let normed = graph.fused_residual_rms_norm(
        x_in,
        res_in,
        None,
        gain,
        shift,
        1e-5,
        Shape::new(&[2, 4], DType::F32),
    );
    graph.set_outputs(vec![normed]);

    let mut gpu = WgpuExecutable::compile(graph);
    gpu.set_param("g", &[1.0, 1.0, 1.0, 1.0]);
    gpu.set_param("b", &[0.0, 0.0, 0.0, 0.0]);
    let xv = vec![1.0, 2.0, 3.0, 4.0, 0.0, 1.0, 2.0, 3.0];
    let rv = vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];
    let out = gpu.run(&[("x", &xv), ("r", &rv)]);

    // Reference, one row of four values at a time.
    let mut want = vec![0f32; 8];
    for (row, dst) in want.chunks_mut(4).enumerate() {
        let base = row * 4;
        let summed: Vec<f32> = (base..base + 4).map(|idx| xv[idx] + rv[idx]).collect();
        let mean_sq = summed.iter().map(|v| v * v).sum::<f32>() / 4.0;
        let inv_rms = 1.0 / (mean_sq + 1e-5).sqrt();
        for (slot, &value) in dst.iter_mut().zip(&summed) {
            *slot = value * inv_rms;
        }
    }
    assert!(
        close(&out[0], &want, 1e-3),
        "FusedResidualRmsNorm mismatch: got {:?} want {want:?}",
        out[0]
    );
}
/// `FusedSwiGLU` with `gate_first: false` must compute, per row,
/// `first_half * silu(second_half)`.
#[test]
fn fused_swiglu_matches_unfused_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let mut graph = Graph::new("swg");
    let x_in = graph.input("x", Shape::new(&[2, 4], DType::F32));
    let swiglu = Op::FusedSwiGLU {
        cast_to: None,
        gate_first: false,
    };
    let gated = graph.add_node(swiglu, vec![x_in], Shape::new(&[2, 2], DType::F32));
    graph.set_outputs(vec![gated]);

    let mut gpu = WgpuExecutable::compile(graph);
    let xv: Vec<f32> = vec![1.0, 2.0, 0.5, 1.5, 3.0, 4.0, 1.0, 2.0];
    let outputs = gpu.run(&[("x", &xv)]);

    let silu = |z: f32| z / (1.0 + (-z).exp());
    // (value, gate) pairs taken from the two rows of `xv`.
    let pairs: [(f32, f32); 4] = [(1.0, 0.5), (2.0, 1.5), (3.0, 1.0), (4.0, 2.0)];
    let want: Vec<f32> = pairs.iter().map(|&(value, gate)| value * silu(gate)).collect();
    assert!(
        close(&outputs[0], &want, 1e-4),
        "FusedSwiGLU mismatch: got {:?} want {want:?}",
        outputs[0]
    );
}
/// The fused LoRA matmul must equal `x·w + scale · (x·a)·b`, computed here
/// with three calls to `matmul_ref`.
#[test]
fn lora_matmul_matches_unfused_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    // m x k input, k x n base weight, rank-`rank` adapters.
    let (m, k, n, rank) = (2, 3, 2, 2);
    let scale = 0.5f32;

    let mut graph = Graph::new("lora");
    let x_in = graph.input("x", Shape::new(&[m, k], DType::F32));
    let base_w = graph.param("w", Shape::new(&[k, n], DType::F32));
    let down = graph.param("a", Shape::new(&[k, rank], DType::F32));
    let up = graph.param("b", Shape::new(&[rank, n], DType::F32));
    let fused = graph.lora_matmul(
        x_in,
        base_w,
        down,
        up,
        scale,
        Shape::new(&[m, n], DType::F32),
    );
    graph.set_outputs(vec![fused]);
    let mut gpu = WgpuExecutable::compile(graph);

    let xv = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let wv = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
    let av = vec![0.1, 0.0, 0.0, 0.1, 0.1, 0.1];
    let bv = vec![1.0, 0.0, 0.0, 1.0];
    gpu.set_param("w", &wv);
    gpu.set_param("a", &av);
    gpu.set_param("b", &bv);
    let r_out = gpu.run(&[("x", &xv)]);

    let base = matmul_ref(&xv, &wv, m, k, n);
    let projected = matmul_ref(&xv, &av, m, k, rank);
    let delta = matmul_ref(&projected, &bv, m, rank, n);
    let want: Vec<f32> = base
        .iter()
        .zip(&delta)
        .map(|(&main, &adapter)| main + scale * adapter)
        .collect();
    assert!(
        close(&r_out[0], &want, 1e-4),
        "LoRA mismatch: got {:?} want {want:?}",
        r_out[0]
    );
}
/// Runs GELU over a `[1, 25, 190, 8]` tensor on the CPU and GPU sessions and
/// requires the outputs to agree to `1e-5`.
#[test]
fn gelu_eeg_tensor_matches_cpu() {
    if !rlx_wgpu::is_available() {
        return;
    }
    use rlx::prelude::*;
    let n = 25 * 190 * 8;
    // Inputs sweep a sine wave across [-2, 2].
    let x: Vec<f32> = (0..n).map(|i| (i as f32 * 0.017).sin() * 2.0).collect();
    let mut g = Graph::new("gelu");
    let xi = g.input("x", Shape::new(&[1, 25, 190, 8], DType::F32));
    let y = g.activation(
        Activation::Gelu,
        xi,
        Shape::new(&[1, 25, 190, 8], DType::F32),
    );
    g.set_outputs(vec![y]);
    let cpu = Session::new(Device::Cpu);
    let want = cpu
        .compile(g.clone())
        .run(&[("x", &x)])
        .into_iter()
        .next()
        .unwrap();
    let gpu = Session::new(Device::Gpu);
    let got = gpu.compile(g).run(&[("x", &x)]).into_iter().next().unwrap();
    // `zip` stops at the shorter side; a short GPU output must not pass.
    assert_eq!(want.len(), got.len(), "GELU output length mismatch");
    // `f32::max` drops NaN operands, so NaN differences are propagated by hand;
    // a NaN `err` then fails the `<` comparison below.
    let err = want
        .iter()
        .zip(got.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(
            0.0f32,
            |worst, d| if d.is_nan() || d > worst { d } else { worst },
        );
    assert!(
        err < 1e-5,
        "GELU [1,25,190,8] max_abs={err:.3e} (cpu[0]={} gpu[0]={})",
        want[0],
        got[0]
    );
}
/// GELU at large-magnitude inputs must stay finite, be close to 0 at -25 and
/// close to 25 at +25.
#[test]
fn gelu_finite_for_large_inputs() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let vec_shape = Shape::new(&[6], DType::F32);
    let mut graph = Graph::new("gelu-large");
    let x_in = graph.input("x", vec_shape.clone());
    let activated = graph.activation(Activation::Gelu, x_in, vec_shape);
    graph.set_outputs(vec![activated]);

    let mut gpu = WgpuExecutable::compile(graph);
    let xs = vec![-25.0_f32, -17.0, -5.0, 5.0, 17.0, 25.0];
    let r = gpu.run(&[("x", &xs)]);

    let nan_count = r[0].iter().filter(|v| v.is_nan()).count();
    let inf_count = r[0].iter().filter(|v| v.is_infinite()).count();
    assert_eq!(nan_count, 0, "GELU produced NaN: r={:?}", r[0]);
    assert_eq!(inf_count, 0, "GELU produced Inf: r={:?}", r[0]);

    // First and last entries correspond to the inputs -25 and +25.
    let at_neg = r[0][0];
    assert!(at_neg.abs() < 1e-3, "gelu(-25) should ≈ 0, got {}", r[0][0]);
    let at_pos = r[0][5];
    assert!(
        (at_pos - 25.0).abs() < 1e-2,
        "gelu(25) should ≈ 25, got {}",
        r[0][5]
    );
}
/// Attention over rank-3 `[b, s, h*d]` inputs with a custom `[b, s]` mask of
/// zeros must not emit NaN or Inf values.
#[test]
fn attention_rank3_with_2d_mask_produces_finite_output() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let (b, s, h, d) = (1, 3, 2, 4);
    let inner = h * d;
    let packed_shape = Shape::new(&[b, s, inner], DType::F32);

    let mut graph = Graph::new("attn-rank3");
    let q_in = graph.input("q", packed_shape.clone());
    let k_in = graph.input("k", packed_shape.clone());
    let v_in = graph.input("v", packed_shape.clone());
    let mask_in = graph.input("m", Shape::new(&[b, s], DType::F32));
    let attn_op = Op::Attention {
        num_heads: h,
        head_dim: d,
        mask_kind: MaskKind::Custom,
        score_scale: None,
        attn_logit_softcap: None,
    };
    let attended = graph.add_node(attn_op, vec![q_in, k_in, v_in, mask_in], packed_shape);
    graph.set_outputs(vec![attended]);
    let mut gpu = WgpuExecutable::compile(graph);

    // Q, K and V all carry the same small ramp; the mask is all zeros.
    let qv: Vec<f32> = (0..b * s * inner).map(|i| (i as f32) * 0.01).collect();
    let kv = qv.clone();
    let vv = qv.clone();
    let mv = vec![0.0; b * s];
    let r = gpu.run(&[("q", &qv), ("k", &kv), ("v", &vv), ("m", &mv)]);

    let nans = r[0].iter().filter(|x| x.is_nan()).count();
    let infs = r[0].iter().filter(|x| x.is_infinite()).count();
    assert_eq!(
        nans,
        0,
        "rank-3 attention produced {nans} NaN values; \
         first 8 = {:?}",
        &r[0][..8.min(r[0].len())]
    );
    assert_eq!(infs, 0, "rank-3 attention produced {infs} Inf values");
}
/// End-to-end check of `FusedAttentionBlock` with hand-built weights.
///
/// The QKV weight has an identity in columns `0..inner` and
/// `2*inner..3*inner` and zeros in between; the output weight is an identity.
/// With an all-zero mask both rows are expected to come out as (3, 4, 5, 6),
/// the mean of the two hidden rows.
#[test]
fn fused_attention_block_end_to_end() {
    if !rlx_wgpu::is_available() {
        return;
    }
    let (b, s, h, d) = (1, 2, 2, 2);
    let inner = h * d;
    let hidden_shape = Shape::new(&[b, s, inner], DType::F32);

    let mut graph = Graph::new("fab");
    let hidden = graph.input("h", hidden_shape.clone());
    let qkv_node = graph.param("qkv_w", Shape::new(&[inner, 3 * inner], DType::F32));
    let out_node = graph.param("out_w", Shape::new(&[inner, inner], DType::F32));
    let mask = graph.input("mask", Shape::new(&[b, h, s, s], DType::F32));
    let block_op = Op::FusedAttentionBlock {
        num_heads: h,
        head_dim: d,
        has_bias: false,
        has_rope: false,
    };
    let block = graph.add_node(
        block_op,
        vec![hidden, qkv_node, out_node, mask],
        hidden_shape,
    );
    graph.set_outputs(vec![block]);
    let mut gpu = WgpuExecutable::compile(graph);

    // Row-major weights: identity in the first and third column groups of the
    // QKV matrix, identity for the output projection.
    let qkv_stride = 3 * inner;
    let mut qkv_weights = vec![0f32; inner * qkv_stride];
    let mut out_weights = vec![0f32; inner * inner];
    for i in 0..inner {
        qkv_weights[i * qkv_stride + i] = 1.0;
        qkv_weights[i * qkv_stride + 2 * inner + i] = 1.0;
        out_weights[i * inner + i] = 1.0;
    }
    gpu.set_param("qkv_w", &qkv_weights);
    gpu.set_param("out_w", &out_weights);

    let hidden_v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let mask_v = vec![0.0; b * h * s * s];
    let outputs = gpu.run(&[("h", &hidden_v), ("mask", &mask_v)]);
    let want = vec![3.0, 4.0, 5.0, 6.0, 3.0, 4.0, 5.0, 6.0];
    assert!(
        close(&outputs[0], &want, 1e-3),
        "FAB mismatch: got {:?} want {want:?}",
        outputs[0]
    );
}
// Runs `Op::SelectiveScan` on a tiny configuration (b=1, s=2, h=2, n=2) and
// compares the GPU output with an inline scalar reference of the recurrence
//   state = exp(dt * a) * state + dt * b * x
//   y     = sum over n of c * state
// to within 1e-4.
#[test]
fn selective_scan_minimum_config_matches_cpu_reference() {
// Bail out when `rlx_wgpu::is_available()` reports false.
if !rlx_wgpu::is_available() {
return;
}
let mut g = Graph::new("ssm");
// b, s, h, n are used below as the dimensions of the inputs and of `a`.
let b = 1;
let s = 2;
let h = 2;
let n = 2;
let x = g.input("x", Shape::new(&[b, s, h], DType::F32));
let dt = g.input("dt", Shape::new(&[b, s, h], DType::F32));
let a = g.param("a", Shape::new(&[h, n], DType::F32));
let bb = g.input("b", Shape::new(&[b, s, n], DType::F32));
let cc = g.input("c", Shape::new(&[b, s, n], DType::F32));
let y = g.add_node(
Op::SelectiveScan { state_size: n },
vec![x, dt, a, bb, cc],
Shape::new(&[b, s, h], DType::F32),
);
g.set_outputs(vec![y]);
let mut exe = WgpuExecutable::compile(g);
// Every entry of `a` is -1, so each step decays the state by exp(-dt).
let av = vec![-1.0; h * n];
exe.set_param("a", &av);
let xv = vec![1.0, 1.0, 1.0, 1.0]; let dtv = vec![1.0, 1.0, 1.0, 1.0];
let bv = vec![1.0, 0.0, 0.0, 1.0]; let cv = vec![1.0, 1.0, 1.0, 1.0];
let r = exe.run(&[("x", &xv), ("dt", &dtv), ("b", &bv), ("c", &cv)]);
// Scalar reference: `state` is indexed [channel, n] and reset per batch item.
let mut want = vec![0f32; b * s * h];
let mut state = vec![0f32; h * n];
for bi in 0..b {
for v in state.iter_mut() {
*v = 0.0;
}
for si in 0..s {
for ci in 0..h {
let d = dtv[bi * s * h + si * h + ci];
let xv_ = xv[bi * s * h + si * h + ci];
let mut acc = 0.0;
for ni in 0..n {
// Decay factor for this channel/state entry at the current step.
let da = (d * av[ci * n + ni]).exp();
state[ci * n + ni] =
da * state[ci * n + ni] + d * bv[bi * s * n + si * n + ni] * xv_;
// The output accumulates the updated state weighted by `c`.
acc += cv[bi * s * n + si * n + ni] * state[ci * n + ni];
}
want[bi * s * h + si * h + ci] = acc;
}
}
}
assert!(
close(&r[0], &want, 1e-4),
"SelectiveScan mismatch: got {:?} want {want:?}",
r[0]
);
}
// Runs `Op::GatedDeltaNet` (state_size = n, carry_state = false) on a small
// configuration and compares the GPU output with an inline scalar reference
// to within 1e-4. Per step and head the reference:
//   1. scales the n x n state by exp(g),
//   2. computes sk[j] = sum_i state[i][j] * k[i],
//   3. forms delta[j] = (v[j] - sk[j]) * beta,
//   4. adds the outer product k[i] * delta[j] to the state,
//   5. emits out[j] = (sum_i state[i][j] * q[i]) / sqrt(n).
#[test]
fn gated_delta_net_matches_cpu_reference() {
// Bail out when `rlx_wgpu::is_available()` reports false.
if !rlx_wgpu::is_available() {
return;
}
let (b, s, h, n) = (1, 4, 2, 3);
let mut g = Graph::new("gdn");
// q/k/v use the [b, s, h, n] shape; g and beta use [b, s, h].
let bshn = Shape::new(&[b, s, h, n], DType::F32);
let bsh = Shape::new(&[b, s, h], DType::F32);
let q = g.input("q", bshn.clone());
let k = g.input("k", bshn.clone());
let v = g.input("v", bshn.clone());
let g_in = g.input("g", bsh.clone());
let beta = g.input("beta", bsh);
let y = g.add_node(
Op::GatedDeltaNet {
state_size: n,
carry_state: false,
},
vec![q, k, v, g_in, beta],
bshn,
);
g.set_outputs(vec![y]);
let mut exe = WgpuExecutable::compile(g);
let nqkv = b * s * h * n;
let ngb = b * s * h;
// Deterministic linear ramps; `g` is negative so exp(g) < 1.
let q_data: Vec<f32> = (0..nqkv).map(|i| 0.05 + 0.03 * (i as f32)).collect();
let k_data: Vec<f32> = (0..nqkv).map(|i| 0.10 + 0.02 * (i as f32)).collect();
let v_data: Vec<f32> = (0..nqkv).map(|i| 0.30 + 0.05 * (i as f32)).collect();
let g_data: Vec<f32> = (0..ngb).map(|i| -0.20 - 0.01 * (i as f32)).collect();
let beta_data: Vec<f32> = (0..ngb).map(|i| 0.40 + 0.02 * (i as f32)).collect();
let r = exe.run(&[
("q", &q_data),
("k", &k_data),
("v", &v_data),
("g", &g_data),
("beta", &beta_data),
]);
// Output scaling applied by the reference: 1 / sqrt(n).
let scale = 1.0f32 / (n as f32).sqrt();
let mut want = vec![0f32; nqkv];
// One n x n state matrix per head, plus an n-element scratch vector.
let mut state = vec![0f32; h * n * n];
let mut sk = vec![0f32; n];
for bi in 0..b {
// The state is cleared at the start of every batch item.
for st in state.iter_mut() {
*st = 0.0;
}
for ti in 0..s {
let step_qkv = bi * s * h * n + ti * h * n;
let step_gb = bi * s * h + ti * h;
for hi in 0..h {
let q_row = &q_data[step_qkv + hi * n..step_qkv + (hi + 1) * n];
let k_row = &k_data[step_qkv + hi * n..step_qkv + (hi + 1) * n];
let v_row = &v_data[step_qkv + hi * n..step_qkv + (hi + 1) * n];
let g_t = g_data[step_gb + hi];
let beta_t = beta_data[step_gb + hi];
let s_base = hi * n * n;
let s_mat = &mut state[s_base..s_base + n * n];
// Step 1: decay the whole state matrix by exp(g).
let g_exp = g_t.exp();
for v in s_mat.iter_mut() {
*v *= g_exp;
}
// Step 2: sk[j] = sum_i state[i][j] * k[i].
for j in 0..n {
let mut acc = 0.0f32;
for i in 0..n {
acc += s_mat[i * n + j] * k_row[i];
}
sk[j] = acc;
}
// Step 3: `sk` is overwritten in place with (v - sk) * beta.
for j in 0..n {
sk[j] = (v_row[j] - sk[j]) * beta_t;
}
// Step 4: rank-one update of the state with k and the delta.
for i in 0..n {
for j in 0..n {
s_mat[i * n + j] += k_row[i] * sk[j];
}
}
// Step 5: read out through q using the already-updated state.
let out_row = &mut want[step_qkv + hi * n..step_qkv + (hi + 1) * n];
for j in 0..n {
let mut acc = 0.0f32;
for i in 0..n {
acc += s_mat[i * n + j] * q_row[i];
}
out_row[j] = acc * scale;
}
}
}
}
assert!(
close(&r[0], &want, 1e-4),
"GatedDeltaNet mismatch: got {:?} want {want:?}",
r[0]
);
}
/// Int8 block-quantized matmul (block size 2, zero points all 0) must equal
/// dequantizing the weights on the host and calling `matmul_ref`.
#[test]
fn dequant_matmul_int8_symmetric_matches_dequant_then_matmul() {
    if !rlx_wgpu::is_available() {
        return;
    }
    use rlx_ir::QuantScheme;
    let (m, k, n) = (2usize, 4usize, 3usize);
    let block_size: u32 = 2;
    let n_blocks = (k as u32).div_ceil(block_size);
    let per_block_shape = Shape::new(&[n_blocks as usize, n], DType::F32);

    let mut graph = Graph::new("dq");
    let x_in = graph.input("x", Shape::new(&[m, k], DType::F32));
    let quantized = graph.param("wq", Shape::new(&[k, n], DType::I8));
    let scale_node = graph.param("sc", per_block_shape.clone());
    let zero_node = graph.param("zp", per_block_shape);
    let product = graph.dequant_matmul(
        x_in,
        quantized,
        scale_node,
        zero_node,
        QuantScheme::Int8Block { block_size },
        Shape::new(&[m, n], DType::F32),
    );
    graph.set_outputs(vec![product]);
    let mut gpu = WgpuExecutable::compile(graph);

    // Row-major k x n signed weights, uploaded as raw bytes.
    let w_i8: Vec<i8> = vec![1, 2, 3, -1, 0, 4, 5, -2, 1, 2, 3, -1];
    let w_bytes: Vec<u8> = w_i8.iter().map(|&q| q as u8).collect();
    gpu.set_param_bytes("wq", &w_bytes);
    // One scale per (block, column).
    let scales = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
    let zps = vec![0.0; (n_blocks as usize) * n];
    gpu.set_param("sc", &scales);
    gpu.set_param("zp", &zps);
    let xv = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let outputs = gpu.run(&[("x", &xv)]);

    // Host-side dequantization: weight * scale of its (row block, column).
    let w_dq: Vec<f32> = w_i8
        .iter()
        .enumerate()
        .map(|(idx, &q)| {
            let (row, col) = (idx / n, idx % n);
            let block = row / (block_size as usize);
            q as f32 * scales[block * n + col]
        })
        .collect();
    let want = matmul_ref(&xv, &w_dq, m, k, n);
    assert!(
        close(&outputs[0], &want, 1e-3),
        "DequantMatMul mismatch: got {:?} want {want:?}",
        outputs[0]
    );
}
/// Int4 block-quantized matmul (block size 2, zero points all 0) must equal
/// dequantizing the weights on the host and calling `matmul_ref`.
///
/// Weights are packed two per byte: the even-indexed value in the low nibble
/// and the odd-indexed value in the high nibble.
#[test]
fn dequant_matmul_int4_symmetric_matches_dequant_then_matmul() {
    if !rlx_wgpu::is_available() {
        return;
    }
    use rlx_ir::QuantScheme;
    let (m, k, n) = (2usize, 4usize, 4usize);
    let block_size: u32 = 2;
    let n_blocks = (k as u32).div_ceil(block_size);
    let per_block_shape = Shape::new(&[n_blocks as usize, n], DType::F32);

    let mut graph = Graph::new("dq4");
    let x_in = graph.input("x", Shape::new(&[m, k], DType::F32));
    let quantized = graph.param("wq", Shape::new(&[k, n], DType::I8));
    let scale_node = graph.param("sc", per_block_shape.clone());
    let zero_node = graph.param("zp", per_block_shape);
    let product = graph.dequant_matmul(
        x_in,
        quantized,
        scale_node,
        zero_node,
        QuantScheme::Int4Block { block_size },
        Shape::new(&[m, n], DType::F32),
    );
    graph.set_outputs(vec![product]);
    let mut gpu = WgpuExecutable::compile(graph);

    let w_i4: Vec<i8> = vec![1, 2, -3, 4, -1, 0, 5, -6, 3, -2, 1, 7, -4, 6, -5, 2];
    // Keep only the low four bits of each value, then pair them up.
    let mut packed = vec![0u8; (k * n) / 2];
    for (slot, pair) in packed.iter_mut().zip(w_i4.chunks(2)) {
        let lo = (pair[0] as i32 & 0xf) as u8;
        let hi = (pair[1] as i32 & 0xf) as u8;
        *slot = lo | (hi << 4);
    }
    gpu.set_param_bytes("wq", &packed);
    let scales = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
    let zps = vec![0.0; (n_blocks as usize) * n];
    gpu.set_param("sc", &scales);
    gpu.set_param("zp", &zps);
    let xv = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let outputs = gpu.run(&[("x", &xv)]);

    // Host-side dequantization: weight * scale of its (row block, column).
    let w_dq: Vec<f32> = w_i4
        .iter()
        .enumerate()
        .map(|(idx, &q)| {
            let (row, col) = (idx / n, idx % n);
            let block = row / (block_size as usize);
            q as f32 * scales[block * n + col]
        })
        .collect();
    let want = matmul_ref(&xv, &w_dq, m, k, n);
    assert!(
        close(&outputs[0], &want, 1e-3),
        "DequantMatMul Int4 mismatch: got {:?} want {want:?}",
        outputs[0]
    );
}
/// NVFP4 dequant-matmul over a single group must equal decoding each 4-bit
/// code with `fp4_e2m1_to_f32`, multiplying by the decoded per-column scale
/// and the global scale, and calling `matmul_ref`.
#[test]
fn dequant_matmul_nvfp4_matches_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }
    use rlx_ir::{NVFP4_GROUP_SIZE, fp4_e2m1_to_f32, fp8_e4m3_scale_to_f32};
    let m = 1usize;
    let k = NVFP4_GROUP_SIZE;
    let n = 4usize;
    let n_scale = k.div_ceil(NVFP4_GROUP_SIZE) * n;
    let packed_len = (k * n).div_ceil(2);

    let mut graph = Graph::new("nvfp4");
    let x_in = graph.input("x", Shape::new(&[m, k], DType::F32));
    let quantized = graph.param("wq", Shape::new(&[packed_len], DType::U8));
    let scale_node = graph.param("sc", Shape::new(&[n_scale], DType::U8));
    let global_node = graph.param("gs", Shape::new(&[1], DType::F32));
    let product = graph.dequant_matmul_nvfp4(
        x_in,
        quantized,
        scale_node,
        global_node,
        Shape::new(&[m, n], DType::F32),
    );
    graph.set_outputs(vec![product]);
    let mut gpu = WgpuExecutable::compile(graph);

    // Codes alternate 2, 4, 2, 4, ... and are packed low nibble first.
    let codes: Vec<u8> = (0..k * n)
        .map(|i| if i % 2 == 0 { 2u8 } else { 4u8 })
        .collect();
    let mut packed = vec![0u8; (k * n).div_ceil(2)];
    for (slot, pair) in packed.iter_mut().zip(codes.chunks(2)) {
        *slot = pair[0] | (pair[1] << 4);
    }
    gpu.set_param_bytes("wq", &packed);
    let scale_bytes = vec![0x38u8; n_scale];
    gpu.set_param_bytes("sc", &scale_bytes);
    gpu.set_param("gs", &[2.0f32]);
    let xv: Vec<f32> = (1..=k).map(|v| v as f32).collect();
    let outputs = gpu.run(&[("x", &xv)]);

    // Host reference; with a single group the scale index is just the column.
    let gs_val = 2.0f32;
    let w_dq: Vec<f32> = codes
        .iter()
        .enumerate()
        .map(|(idx, &nib)| {
            let col = idx % n;
            let scale = fp8_e4m3_scale_to_f32(scale_bytes[col]);
            fp4_e2m1_to_f32(nib) * scale * gs_val
        })
        .collect();
    let want = matmul_ref(&xv, &w_dq, m, k, n);
    assert!(
        close(&outputs[0], &want, 1e-3),
        "DequantMatMul NVFP4 mismatch: got {:?} want {want:?}",
        outputs[0]
    );
}
/// Decodes one FP8 byte laid out as 1 sign bit, 4 exponent bits (bias 7) and
/// 3 mantissa bits into an `f32`.
///
/// Exponent 0 is decoded as a subnormal (`mantissa / 8 * 2^-6`). The single
/// pattern with exponent 15 and mantissa 7 decodes to zero rather than NaN.
fn e4m3_to_f32(byte: u8) -> f32 {
    let negative = (byte >> 7) & 1 != 0;
    let exponent = (byte >> 3) & 0xf;
    let mantissa = byte & 0x7;
    let magnitude = match (exponent, mantissa) {
        // Subnormal: no implicit leading one, fixed 2^-6 scale.
        (0, frac) => (frac as f32 / 8.0) * (-6f32).exp2(),
        // Reserved pattern, flattened to zero.
        (15, 7) => 0.0,
        // Normal: implicit leading one and biased exponent.
        (exp, frac) => {
            let significand = 1.0 + frac as f32 / 8.0;
            significand * ((exp as i32 - 7) as f32).exp2()
        }
    };
    if negative { -magnitude } else { magnitude }
}
/// Decodes one FP8 byte laid out as 1 sign bit, 5 exponent bits (bias 15) and
/// 2 mantissa bits into an `f32`.
///
/// Exponent 0 is decoded as a subnormal (`mantissa / 4 * 2^-14`). Every
/// pattern with exponent 31 decodes to zero rather than infinity or NaN.
fn e5m2_to_f32(byte: u8) -> f32 {
    let exponent = (byte >> 2) & 0x1f;
    let mantissa = byte & 0x3;
    let magnitude = match exponent {
        // Subnormal: no implicit leading one, fixed 2^-14 scale.
        0 => (mantissa as f32 / 4.0) * (-14f32).exp2(),
        // Reserved exponent, flattened to zero.
        31 => 0.0,
        // Normal: implicit leading one and biased exponent.
        exp => {
            let significand = 1.0 + mantissa as f32 / 4.0;
            significand * ((exp as i32 - 15) as f32).exp2()
        }
    };
    let sign_bit = (byte >> 7) & 1;
    if sign_bit == 0 { magnitude } else { -magnitude }
}
/// `QuantScheme::Fp8E4m3` dequant-matmul must equal decoding every weight byte
/// with `e4m3_to_f32` on the host and calling `matmul_ref`.
#[test]
fn dequant_matmul_fp8_e4m3_matches_decode_then_matmul() {
    if !rlx_wgpu::is_available() {
        return;
    }
    use rlx_ir::QuantScheme;
    let (m, k, n) = (2usize, 4usize, 3usize);
    let aux_shape = Shape::new(&[1, n], DType::F32);

    let mut graph = Graph::new("dq-e4m3");
    let x_in = graph.input("x", Shape::new(&[m, k], DType::F32));
    let quantized = graph.param("wq", Shape::new(&[k, n], DType::U8));
    let scale_node = graph.param("sc", aux_shape.clone());
    let zero_node = graph.param("zp", aux_shape);
    let product = graph.dequant_matmul(
        x_in,
        quantized,
        scale_node,
        zero_node,
        QuantScheme::Fp8E4m3,
        Shape::new(&[m, n], DType::F32),
    );
    graph.set_outputs(vec![product]);
    let mut gpu = WgpuExecutable::compile(graph);

    // Raw FP8 bytes, row-major k x n; the host reference ignores `sc` and `zp`.
    let w_bytes: Vec<u8> = vec![
        0x38, 0x40, 0x48, 0x01, 0xC0, 0x44, 0xB8, 0x00, 0x70, 0x30, 0x21, 0x68,
    ];
    gpu.set_param_bytes("wq", &w_bytes);
    gpu.set_param("sc", &vec![0.0; n]);
    gpu.set_param("zp", &vec![0.0; n]);
    let xv = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let outputs = gpu.run(&[("x", &xv)]);

    let mut w_dq: Vec<f32> = Vec::with_capacity(w_bytes.len());
    for &raw in &w_bytes {
        w_dq.push(e4m3_to_f32(raw));
    }
    let want = matmul_ref(&xv, &w_dq, m, k, n);
    assert!(
        close(&outputs[0], &want, 1e-3),
        "FP8 E4M3 mismatch: got {:?} want {want:?}",
        outputs[0]
    );
}
/// `QuantScheme::Fp8E5m2` dequant-matmul must equal decoding every weight byte
/// with `e5m2_to_f32` on the host and calling `matmul_ref`.
#[test]
fn dequant_matmul_fp8_e5m2_matches_decode_then_matmul() {
    if !rlx_wgpu::is_available() {
        return;
    }
    use rlx_ir::QuantScheme;
    let (m, k, n) = (2usize, 4usize, 3usize);
    let aux_shape = Shape::new(&[1, n], DType::F32);

    let mut graph = Graph::new("dq-e5m2");
    let x_in = graph.input("x", Shape::new(&[m, k], DType::F32));
    let quantized = graph.param("wq", Shape::new(&[k, n], DType::U8));
    let scale_node = graph.param("sc", aux_shape.clone());
    let zero_node = graph.param("zp", aux_shape);
    let product = graph.dequant_matmul(
        x_in,
        quantized,
        scale_node,
        zero_node,
        QuantScheme::Fp8E5m2,
        Shape::new(&[m, n], DType::F32),
    );
    graph.set_outputs(vec![product]);
    let mut gpu = WgpuExecutable::compile(graph);

    // Raw FP8 bytes, row-major k x n; the host reference ignores `sc` and `zp`.
    let w_bytes: Vec<u8> = vec![
        0x3C, 0x40, 0x44, 0xBC, 0x00, 0x4C, 0x3C, 0xC4, 0x3C, 0x40, 0x44, 0x40,
    ];
    gpu.set_param_bytes("wq", &w_bytes);
    gpu.set_param("sc", &vec![0.0; n]);
    gpu.set_param("zp", &vec![0.0; n]);
    let xv = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let outputs = gpu.run(&[("x", &xv)]);

    let mut w_dq: Vec<f32> = Vec::with_capacity(w_bytes.len());
    for &raw in &w_bytes {
        w_dq.push(e5m2_to_f32(raw));
    }
    let want = matmul_ref(&xv, &w_dq, m, k, n);
    assert!(
        close(&outputs[0], &want, 1e-3),
        "FP8 E5M2 mismatch: got {:?} want {want:?}",
        outputs[0]
    );
}
/// A graph whose input and output have a dynamic leading dimension is compiled
/// without explicit bindings and then run with 12 values; the output must be
/// the input doubled.
#[test]
fn dynamic_shape_auto_infers_at_run_time() {
    if !rlx_wgpu::is_available() {
        return;
    }
    use rlx_ir::shape::Dim;
    // Shared by the input and the final product: [Dynamic(0), 4].
    let dyn_shape = rlx_ir::Shape::from_dims(&[Dim::Dynamic(0), Dim::Static(4)], DType::F32);

    let mut graph = Graph::new("dyn-auto");
    let x_in = graph.input("x", dyn_shape.clone());
    let two_bytes = 2.0_f32.to_le_bytes().to_vec();
    let scalar = graph.add_node(
        Op::Constant { data: two_bytes },
        vec![],
        rlx_ir::Shape::from_dims(&[Dim::Static(1), Dim::Static(1)], DType::F32),
    );
    let broadcast = graph.add_node(
        Op::Expand {
            target_shape: vec![3, 4],
        },
        vec![scalar],
        Shape::new(&[3, 4], DType::F32),
    );
    let doubled = graph.binary(BinaryOp::Mul, x_in, broadcast, dyn_shape);
    graph.set_outputs(vec![doubled]);

    let mut gpu = WgpuExecutable::compile(graph);
    let xv: Vec<f32> = (1..=12).map(|i| i as f32).collect();
    let outputs = gpu.run(&[("x", &xv)]);
    let want: Vec<f32> = xv.iter().map(|v| v * 2.0).collect();
    assert!(
        close(&outputs[0], &want, 1e-4),
        "auto-infer DimBinding mismatch: got {:?} want {want:?}",
        outputs[0]
    );
}
/// A graph with a dynamic leading input dimension is compiled with an explicit
/// `DimBinding` (dimension 0 set to 3); the output must be the input doubled.
#[test]
fn dynamic_shape_resolves_via_compile_with_bindings() {
    if !rlx_wgpu::is_available() {
        return;
    }
    use rlx_ir::shape::{Dim, DimBinding};
    let static_3x4 = Shape::new(&[3, 4], DType::F32);

    let mut graph = Graph::new("dyn");
    let dyn_shape = rlx_ir::Shape::from_dims(&[Dim::Dynamic(0), Dim::Static(4)], DType::F32);
    let x_in = graph.input("x", dyn_shape.clone());
    let two_bytes = 2.0_f32.to_le_bytes().to_vec();
    let scalar = graph.add_node(
        Op::Constant { data: two_bytes },
        vec![],
        Shape::new(&[1, 1], DType::F32),
    );
    let broadcast = graph.add_node(
        Op::Expand {
            target_shape: vec![3, 4],
        },
        vec![scalar],
        static_3x4.clone(),
    );
    let doubled = graph.binary(BinaryOp::Mul, x_in, broadcast, static_3x4);
    graph.set_outputs(vec![doubled]);

    // Bind dynamic dimension 0 to 3 before compiling.
    let mut bindings = DimBinding::new();
    bindings.set(0, 3);
    let mut gpu = WgpuExecutable::compile_with_bindings(graph, &bindings);

    let xv: Vec<f32> = (1..=12).map(|i| i as f32).collect();
    let outputs = gpu.run(&[("x", &xv)]);
    let want: Vec<f32> = xv.iter().map(|v| v * 2.0).collect();
    assert!(
        close(&outputs[0], &want, 1e-4),
        "DimBinding mismatch: got {:?} want {want:?}",
        outputs[0]
    );
}
/// `Op::If` with a three-element predicate must pick, per element, the "then"
/// result (`x + 1`) where the predicate is 1 and the "else" result (`x * 2`)
/// where it is 0.
#[test]
fn op_if_picks_branch_per_predicate() {
    if !rlx_wgpu::is_available() {
        return;
    }
    // Builds a branch graph computing `x <op> broadcast(constant)` over three
    // f32 elements.
    fn branch(name: &'static str, op: BinaryOp, constant: f32) -> Graph {
        let mut sub = Graph::new(name);
        let x = sub.input("x", Shape::new(&[3], DType::F32));
        let scalar = sub.add_node(
            Op::Constant {
                data: constant.to_le_bytes().to_vec(),
            },
            vec![],
            Shape::new(&[1], DType::F32),
        );
        let spread = sub.add_node(
            Op::Expand {
                target_shape: vec![3],
            },
            vec![scalar],
            Shape::new(&[3], DType::F32),
        );
        let y = sub.binary(op, x, spread, Shape::new(&[3], DType::F32));
        sub.set_outputs(vec![y]);
        sub
    }
    let then_branch = branch("then", BinaryOp::Add, 1.0_f32);
    let else_branch = branch("else", BinaryOp::Mul, 2.0_f32);

    let mut graph = Graph::new("ifx");
    let pred = graph.input("pred", Shape::new(&[3], DType::Bool));
    let x_in = graph.input("x", Shape::new(&[3], DType::F32));
    let selected = graph.add_node(
        Op::If {
            then_branch: Box::new(then_branch),
            else_branch: Box::new(else_branch),
        },
        vec![pred, x_in],
        Shape::new(&[3], DType::F32),
    );
    graph.set_outputs(vec![selected]);
    let mut gpu = WgpuExecutable::compile(graph);

    let xs = vec![1.0f32, 2.0, 3.0];
    // The predicate is fed as floats: 1.0 selects "then", 0.0 selects "else".
    let pv = vec![1.0f32, 0.0, 1.0];
    let outputs = gpu.run(&[("x", &xs), ("pred", &pv)]);
    let want = vec![1.0 + 1.0, 2.0 * 2.0, 3.0 + 1.0];
    assert!(
        close(&outputs[0], &want, 1e-4),
        "Op::If mismatch: got {:?} want {want:?}",
        outputs[0]
    );
}
/// Runs `Op::While` with a body that doubles `x` and a condition `x < 16`,
/// capped at 6 iterations; starting from 1.0 the test expects 16.0.
#[test]
fn op_while_unrolls_until_cond_false() {
    if !rlx_wgpu::is_available() {
        return;
    }

    let scalar_f32 = || Shape::new(&[1], DType::F32);

    // Loop body: x -> x * 2.
    let body = {
        let mut sub = Graph::new("body");
        let x = sub.input("x", scalar_f32());
        let factor = sub.add_node(
            Op::Constant {
                data: 2.0_f32.to_le_bytes().to_vec(),
            },
            vec![],
            scalar_f32(),
        );
        let doubled = sub.binary(BinaryOp::Mul, x, factor, scalar_f32());
        sub.set_outputs(vec![doubled]);
        sub
    };

    // Loop condition: x < 16.
    let cond = {
        let mut sub = Graph::new("cond");
        let x = sub.input("x", scalar_f32());
        let limit = sub.add_node(
            Op::Constant {
                data: 16.0_f32.to_le_bytes().to_vec(),
            },
            vec![],
            scalar_f32(),
        );
        let keep_going = sub.add_node(
            Op::Compare(CmpOp::Lt),
            vec![x, limit],
            Shape::new(&[1], DType::Bool),
        );
        sub.set_outputs(vec![keep_going]);
        sub
    };

    let mut g = Graph::new("loopy");
    let start = g.input("x", scalar_f32());
    let looped = g.add_node(
        Op::While {
            cond: Box::new(cond),
            body: Box::new(body),
            max_iterations: Some(6),
        },
        vec![start],
        scalar_f32(),
    );
    g.set_outputs(vec![looped]);

    let mut exe = WgpuExecutable::compile(g);
    let r = exe.run(&[("x", &[1.0f32])]);
    let matches = close(&r[0], &[16.0], 1e-4);
    assert!(matches, "Op::While mismatch: got {:?}, expected 16", r[0]);
}
/// Checks a batched `Op::DotGeneral` ([2,2,3] x [2,3,2], batch axis 0) against
/// `matmul_ref` applied to each batch slice independently.
#[test]
fn dot_general_batched_matches_per_batch_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }

    let mut g = Graph::new("dg-batched");
    let lhs = g.input("l", Shape::new(&[2, 2, 3], DType::F32));
    let rhs = g.input("r", Shape::new(&[2, 3, 2], DType::F32));
    let product = g.add_node(
        Op::DotGeneral {
            lhs_contracting: vec![2],
            rhs_contracting: vec![1],
            lhs_batch: vec![0],
            rhs_batch: vec![0],
        },
        vec![lhs, rhs],
        Shape::new(&[2, 2, 2], DType::F32),
    );
    g.set_outputs(vec![product]);

    let mut exe = WgpuExecutable::compile(g);
    let lv: Vec<f32> = (1..=12).map(|i| i as f32).collect();
    let rv: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
    let out = exe.run(&[("l", &lv), ("r", &rv)]);

    // Each batch owns 6 input elements per operand and 4 output elements.
    let want: Vec<f32> = (0..2)
        .flat_map(|bi| {
            let span = bi * 6..(bi + 1) * 6;
            matmul_ref(&lv[span.clone()], &rv[span], 2, 3, 2)
        })
        .collect();

    let matches = close(&out[0], &want, 1e-4);
    assert!(
        matches,
        "batched DotGeneral mismatch: got {:?} want {want:?}",
        out[0]
    );
}
/// Checks `Op::DotGeneral` contracting axis 0 of both operands, i.e. the
/// left operand is stored as [k, m] and the expected result is `l^T * r`.
#[test]
fn dot_general_lhs_transposed_matches_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }

    let mut g = Graph::new("dg-lhs-t");
    let (m, k, n) = (2, 3, 2);
    let lhs = g.input("l", Shape::new(&[k, m], DType::F32));
    let rhs = g.input("r", Shape::new(&[k, n], DType::F32));
    let product = g.add_node(
        Op::DotGeneral {
            lhs_contracting: vec![0],
            rhs_contracting: vec![0],
            lhs_batch: vec![],
            rhs_batch: vec![],
        },
        vec![lhs, rhs],
        Shape::new(&[m, n], DType::F32),
    );
    g.set_outputs(vec![product]);

    let mut exe = WgpuExecutable::compile(g);
    // l rows: [1,2], [3,4], [5,6]; r rows: [1,0], [0,1], [1,1].
    let lv = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let rv = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
    let out = exe.run(&[("l", &lv), ("r", &rv)]);

    // l^T * r = [[1+5, 3+5], [2+6, 4+6]].
    let want = vec![6.0, 8.0, 8.0, 10.0];
    let matches = close(&out[0], &want, 1e-4);
    assert!(
        matches,
        "DotGeneral lhs.T mismatch: got {:?} want {want:?}",
        out[0]
    );
}
/// With `top_k: 1` the test expects `Op::Sample` to return the index of the
/// largest logit in each row.
#[test]
fn sample_top_k_one_collapses_to_argmax() {
    if !rlx_wgpu::is_available() {
        return;
    }

    let mut g = Graph::new("samp-k1");
    let logits = g.input("x", Shape::new(&[2, 4], DType::F32));
    let picked = g.add_node(
        Op::Sample {
            top_k: 1,
            top_p: 1.0,
            temperature: 1.0,
            seed: 42,
        },
        vec![logits],
        Shape::new(&[2], DType::F32),
    );
    g.set_outputs(vec![picked]);

    let mut exe = WgpuExecutable::compile(g);
    // Row 0 peaks at index 1, row 1 peaks at index 0.
    let xs = vec![1.0, 5.0, 2.0, 3.0, 9.0, 0.0, 0.0, 0.0];
    let out = exe.run(&[("x", &xs)]);
    assert_eq!(out[0], vec![1.0, 0.0]);
}
/// CPU reference for the 20-round Threefry-2x32 block function.
///
/// `c_in` is the two-word counter and `k_in` the two-word key; the return
/// value is the two-word output block. All arithmetic wraps modulo 2^32.
fn threefry2x32_20_ref(c_in: [u32; 2], k_in: [u32; 2]) -> [u32; 2] {
    // Per-round rotation amounts, reused cyclically every 8 rounds.
    const ROTATIONS: [u32; 8] = [13, 15, 26, 6, 17, 29, 16, 24];
    // Constant XOR-ed into the derived third key word.
    const KEY_PARITY: u32 = 0x1BD11BDA;

    // Extended key schedule: the two key words plus their parity word.
    let schedule = [k_in[0], k_in[1], k_in[0] ^ k_in[1] ^ KEY_PARITY];

    let mut lo = c_in[0].wrapping_add(schedule[0]);
    let mut hi = c_in[1].wrapping_add(schedule[1]);

    for (round, &rot) in ROTATIONS.iter().cycle().take(20).enumerate() {
        // Mix step: add, rotate, xor.
        lo = lo.wrapping_add(hi);
        hi = hi.rotate_left(rot) ^ lo;

        // After every fourth round, inject two schedule words and the
        // 1-based injection counter.
        if round % 4 == 3 {
            let step = round / 4 + 1;
            lo = lo.wrapping_add(schedule[step % 3]);
            hi = hi
                .wrapping_add(schedule[(step + 1) % 3])
                .wrapping_add(step as u32);
        }
    }

    [lo, hi]
}
/// Sanity check on the CPU reference: 64 consecutive counters under one key
/// should land in at least 6 of 8 equal-width buckets of the unit interval.
#[test]
fn threefry_reference_distributes_uniformly() {
    let mut buckets = [0u32; 8];
    for row in 0..64u32 {
        let [word, _] = threefry2x32_20_ref([row, 0], [0xC0FFEE, 0]);
        // Map the first output word to [0, 1) by dividing by 2^32.
        let unit = f64::from(word) / 4294967296.0;
        let slot = ((unit * 8.0) as usize).min(7);
        buckets[slot] += 1;
    }

    let hit = buckets.iter().filter(|&&count| count > 0).count();
    assert!(
        hit >= 6,
        "Reference Threefry only hit {hit}/8 buckets — buckets={buckets:?}"
    );
}
/// Runs `Op::Sample` twice on all-zero logits with a fixed seed: both runs
/// must agree, and across 64 rows at least 6 of the 8 token indices must be
/// picked at least once.
#[test]
fn sample_threefry_seed_is_deterministic_and_distributes() {
    if !rlx_wgpu::is_available() {
        return;
    }

    let n = 8;
    let batch = 64;

    let mut g = Graph::new("samp-threefry");
    let logits = g.input("x", Shape::new(&[batch, n], DType::F32));
    let picked = g.add_node(
        Op::Sample {
            top_k: 0,
            top_p: 0.999,
            temperature: 1.0,
            seed: 0xC0FFEE,
        },
        vec![logits],
        Shape::new(&[batch], DType::F32),
    );
    g.set_outputs(vec![picked]);

    let mut exe = WgpuExecutable::compile(g);
    // Uniform logits: every token has equal weight.
    let xs: Vec<f32> = vec![0.0; batch * n];
    let r1 = exe.run(&[("x", &xs)]);
    let r2 = exe.run(&[("x", &xs)]);
    assert_eq!(
        r1[0], r2[0],
        "Threefry should be deterministic for same seed"
    );

    // Histogram of sampled token indices from the first run.
    let mut hit = vec![0u32; n];
    r1[0].iter().for_each(|&token| hit[token as usize] += 1);

    let covered = hit.iter().filter(|&&count| count > 0).count();
    assert!(
        covered >= 6,
        "Threefry-driven Sample only hit {covered}/{n} tokens; \
         per-token counts={hit:?}"
    );
}
/// Gives token 3 a logit of 5.0 (all others 0.0) in each of 64 rows and
/// requires that at least 90% of the rows sample token 3.
#[test]
fn sample_gumbel_max_concentrates_on_dominant_logit() {
    if !rlx_wgpu::is_available() {
        return;
    }

    let n = 8;
    let batch = 64;

    let mut g = Graph::new("samp-gumbel");
    let logits = g.input("x", Shape::new(&[batch, n], DType::F32));
    let picked = g.add_node(
        Op::Sample {
            top_k: 0,
            top_p: 0.999,
            temperature: 1.0,
            seed: 99,
        },
        vec![logits],
        Shape::new(&[batch], DType::F32),
    );
    g.set_outputs(vec![picked]);

    let mut exe = WgpuExecutable::compile(g);

    // One row of `n` logits per batch entry; column 3 dominates.
    let mut xs = vec![0.0f32; batch * n];
    for row in xs.chunks_exact_mut(n) {
        row[3] = 5.0;
    }

    let r = exe.run(&[("x", &xs)]);
    let three_picks = r[0].iter().filter(|&&token| token == 3.0).count();
    let threshold = batch * 90 / 100;
    assert!(
        three_picks >= threshold,
        "Gumbel-max should land on the dominant token most of the time; \
         only {three_picks}/{batch} hit token 3, picks={:?}",
        r[0]
    );
}
/// With a very small `top_p` (0.001) the test expects `Op::Sample` to return
/// the index of the dominant logit in each row.
#[test]
fn sample_top_p_zero_collapses_to_argmax() {
    if !rlx_wgpu::is_available() {
        return;
    }

    let mut g = Graph::new("samp-p0");
    let logits = g.input("x", Shape::new(&[2, 4], DType::F32));
    let picked = g.add_node(
        Op::Sample {
            top_k: 0,
            top_p: 0.001,
            temperature: 1.0,
            seed: 7,
        },
        vec![logits],
        Shape::new(&[2], DType::F32),
    );
    g.set_outputs(vec![picked]);

    let mut exe = WgpuExecutable::compile(g);
    // Row 0 peaks at index 0, row 1 peaks at index 3.
    let xs = vec![10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 10.0];
    let out = exe.run(&[("x", &xs)]);
    assert_eq!(out[0], vec![0.0, 3.0]);
}
/// Two-token causal attention: the expected output for token 0 is exactly
/// V row 0, and for token 1 a softmax-weighted blend of both V rows using
/// scores scaled by 1/sqrt(2).
#[test]
fn attention_causal_mask_zeros_future_tokens() {
    if !rlx_wgpu::is_available() {
        return;
    }

    let qkv_shape = || Shape::new(&[1, 1, 2, 2], DType::F32);

    let mut g = Graph::new("attn-causal");
    let q = g.input("q", qkv_shape());
    let k = g.input("k", qkv_shape());
    let v = g.input("v", qkv_shape());
    let attended = g.add_node(
        Op::Attention {
            num_heads: 1,
            head_dim: 2,
            mask_kind: MaskKind::Causal,
            score_scale: None,
            attn_logit_softcap: None,
        },
        vec![q, k, v],
        qkv_shape(),
    );
    g.set_outputs(vec![attended]);

    let mut exe = WgpuExecutable::compile(g);
    let qv = vec![1.0, 0.0, 0.0, 1.0];
    let kv = vec![1.0, 0.0, 0.0, 1.0];
    let vv = vec![1.0, 2.0, 3.0, 4.0];
    let r = exe.run(&[("q", &qv), ("k", &kv), ("v", &vv)]);

    // Token 1 scores: 0 against key 0 and `s` against key 1; softmax is
    // evaluated after subtracting the maximum score `s`.
    let s = 1.0 / 2.0_f32.sqrt();
    let e0 = (0.0_f32 - s).exp();
    let e1 = 1.0_f32;
    let z = e0 + e1;
    let w0 = e0 / z;
    let w1 = e1 / z;
    let want = vec![1.0, 2.0, w0 * 1.0 + w1 * 3.0, w0 * 2.0 + w1 * 4.0];

    let matches = close(&r[0], &want, 1e-3);
    assert!(
        matches,
        "Causal attention mismatch: got {:?} want {want:?}",
        r[0]
    );
}
/// With `MaskKind::SlidingWindow(0)` the test expects the attention output
/// to equal V unchanged.
#[test]
fn attention_sliding_window_limits_lookback() {
    if !rlx_wgpu::is_available() {
        return;
    }

    let qkv_shape = || Shape::new(&[1, 1, 2, 2], DType::F32);

    let mut g = Graph::new("attn-sw");
    let q = g.input("q", qkv_shape());
    let k = g.input("k", qkv_shape());
    let v = g.input("v", qkv_shape());
    let attended = g.add_node(
        Op::Attention {
            num_heads: 1,
            head_dim: 2,
            mask_kind: MaskKind::SlidingWindow(0),
            score_scale: None,
            attn_logit_softcap: None,
        },
        vec![q, k, v],
        qkv_shape(),
    );
    g.set_outputs(vec![attended]);

    let mut exe = WgpuExecutable::compile(g);
    let qv = vec![1.0, 0.0, 0.0, 1.0];
    let kv = vec![1.0, 0.0, 0.0, 1.0];
    let vv = vec![1.0, 2.0, 3.0, 4.0];
    let r = exe.run(&[("q", &qv), ("k", &kv), ("v", &vv)]);

    let matches = close(&r[0], &vv, 1e-3);
    assert!(
        matches,
        "SlidingWindow attention mismatch: got {:?} want {vv:?}",
        r[0]
    );
}
/// `Op::GroupedMatMul` with two experts: expert 0 holds the identity and
/// expert 1 holds 2 * identity. Row 0 is routed to expert 0 and row 1 to
/// expert 1, so the expected output is [3, 4] and [10, 12].
#[test]
fn grouped_matmul_routes_per_token_to_expert() {
    if !rlx_wgpu::is_available() {
        return;
    }

    let mut g = Graph::new("gmm");
    let (m, k, n, ne) = (2, 2, 2, 2);
    let tokens = g.input("x", Shape::new(&[m, k], DType::F32));
    let experts = g.param("w", Shape::new(&[ne, k, n], DType::F32));
    let routing = g.input("idx", Shape::new(&[m], DType::F32));
    let routed = g.add_node(
        Op::GroupedMatMul,
        vec![tokens, experts, routing],
        Shape::new(&[m, n], DType::F32),
    );
    g.set_outputs(vec![routed]);

    let mut exe = WgpuExecutable::compile(g);

    // Expert 0: identity; expert 1: 2 * identity (both 2x2, row-major).
    let wv = vec![1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0];
    exe.set_param("w", &wv);

    let xv = vec![3.0, 4.0, 5.0, 6.0];
    let idxv = vec![0.0, 1.0];
    let r = exe.run(&[("x", &xv), ("idx", &idxv)]);

    let matches = close(&r[0], &[3.0, 4.0, 10.0, 12.0], 1e-4);
    assert!(matches, "GroupedMatMul mismatch: got {:?}", r[0]);
}
/// `Op::TopK { k: 3 }` over two rows of five values; the expected output is
/// the indices of the three largest values per row, in descending value order.
#[test]
fn topk_picks_largest_three_indices() {
    if !rlx_wgpu::is_available() {
        return;
    }

    let mut g = Graph::new("topk");
    let scores = g.input("x", Shape::new(&[2, 5], DType::F32));
    let top = g.add_node(
        Op::TopK { k: 3 },
        vec![scores],
        Shape::new(&[2, 3], DType::F32),
    );
    g.set_outputs(vec![top]);

    let mut exe = WgpuExecutable::compile(g);
    // Row 0: 5 > 4 > 3 at indices 0, 2, 4. Row 1: 9 > 8 > 7 at indices 1, 4, 3.
    let xv = vec![5.0, 1.0, 4.0, 2.0, 3.0, 0.5, 9.0, 0.1, 7.0, 8.0];
    let out = exe.run(&[("x", &xv)]);
    assert_eq!(out[0], vec![0.0, 2.0, 4.0, 1.0, 4.0, 3.0]);
}
/// Checks `matmul` on [2,2,3] x [2,3,2] operands against `matmul_ref`
/// applied to each batch independently.
#[test]
fn batched_matmul_3d_by_3d_matches_per_batch_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }

    let mut g = Graph::new("bmm3");
    let lhs = g.input("l", Shape::new(&[2, 2, 3], DType::F32));
    let rhs = g.input("r", Shape::new(&[2, 3, 2], DType::F32));
    let product = g.matmul(lhs, rhs, Shape::new(&[2, 2, 2], DType::F32));
    g.set_outputs(vec![product]);

    let mut exe = WgpuExecutable::compile(g);
    let lv: Vec<f32> = (1..=12).map(|i| i as f32).collect();
    let rv: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
    let out = exe.run(&[("l", &lv), ("r", &rv)]);

    // Six elements per operand per batch; concatenate the per-batch products.
    let want: Vec<f32> = lv
        .chunks(6)
        .zip(rv.chunks(6))
        .flat_map(|(l_slice, r_slice)| matmul_ref(l_slice, r_slice, 2, 3, 2))
        .collect();

    let matches = close(&out[0], &want, 1e-4);
    assert!(
        matches,
        "batched 3D@3D mismatch: got {:?} want {want:?}",
        out[0]
    );
}
/// Checks `matmul` of a [2,2,3] input with a shared [3,2] weight against
/// `matmul_ref` on the input flattened to 4 rows.
#[test]
fn batched_matmul_3d_by_2d_matches_per_row_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }

    let mut g = Graph::new("bmm");
    let input = g.input("x", Shape::new(&[2, 2, 3], DType::F32));
    let weight = g.param("w", Shape::new(&[3, 2], DType::F32));
    let product = g.matmul(input, weight, Shape::new(&[2, 2, 2], DType::F32));
    g.set_outputs(vec![product]);

    let mut exe = WgpuExecutable::compile(g);
    let xv: Vec<f32> = (1..=12).map(|i| i as f32).collect();
    let wv = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
    exe.set_param("w", &wv);
    let out = exe.run(&[("x", &xv)]);

    // The two batches of two rows are treated as four independent rows.
    let want = matmul_ref(&xv, &wv, 4, 3, 2);
    let matches = close(&out[0], &want, 1e-4);
    assert!(
        matches,
        "batched matmul mismatch: got {:?} want {want:?}",
        out[0]
    );
}
/// `Op::ScatterAdd` of four 2-element update rows into a 3x2 destination;
/// rows 0 and 2 both target destination row 1 and must be summed.
#[test]
fn scatter_add_accumulates_into_destination() {
    if !rlx_wgpu::is_available() {
        return;
    }

    let mut g = Graph::new("sa");
    let updates = g.input("upd", Shape::new(&[4, 2], DType::F32));
    let targets = g.input("idx", Shape::new(&[4], DType::F32));
    let scattered = g.add_node(
        Op::ScatterAdd,
        vec![updates, targets],
        Shape::new(&[3, 2], DType::F32),
    );
    g.set_outputs(vec![scattered]);

    let mut exe = WgpuExecutable::compile(g);
    let updv = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let idxv = vec![1.0, 0.0, 1.0, 2.0];
    let r = exe.run(&[("upd", &updv), ("idx", &idxv)]);

    let want = vec![
        // Destination row 0 receives update row 1.
        3.0,
        4.0,
        // Destination row 1 receives update rows 0 and 2.
        1.0 + 5.0,
        2.0 + 6.0,
        // Destination row 2 receives update row 3.
        7.0,
        8.0,
    ];
    let matches = close(&r[0], &want, 1e-4);
    assert!(
        matches,
        "ScatterAdd mismatch: got {:?} want {want:?}",
        r[0]
    );
}
/// Compiles `g`, uploads `params`, runs with `inputs`, and compares the first
/// output against `want`.
///
/// Returns `(max_abs_diff, has_nan)`. The difference is `f32::INFINITY` when
/// the output contains a NaN or its length differs from `want`; a run that
/// yields no outputs is treated as an empty first output.
fn wgpu_run_and_check(
    g: Graph,
    inputs: &[(&str, &[f32])],
    params: &[(&str, &[f32])],
    want: &[f32],
) -> (f32, bool) {
    let mut exe = WgpuExecutable::compile(g);
    for (name, data) in params {
        exe.set_param(name, data);
    }

    let got = exe.run(inputs).into_iter().next().unwrap_or_default();
    let has_nan = got.iter().any(|v| v.is_nan());

    // Not comparable element-wise: report an infinite difference.
    if has_nan || got.len() != want.len() {
        return (f32::INFINITY, has_nan);
    }

    let mut diff = 0f32;
    for (a, b) in got.iter().zip(want) {
        diff = f32::max(diff, (a - b).abs());
    }
    (diff, has_nan)
}
/// Bisect step: a lone `Op::Gather` on axis 0 of an 8x4 table with ids
/// 0, 2 and 5; the expected output is table rows 0, 2 and 5.
#[test]
fn bisect_wgpu_gather_only() {
    if !rlx_wgpu::is_available() {
        return;
    }

    let f = DType::F32;
    let mut g = Graph::new("gather_only");
    let ids = g.input("ids", Shape::new(&[1, 3], f));
    let table = g.param("emb", Shape::new(&[8, 4], f));
    let gathered = g.add_node(
        Op::Gather { axis: 0 },
        vec![table, ids],
        Shape::new(&[1, 3, 4], f),
    );
    g.set_outputs(vec![gathered]);

    let ids_v = vec![0.0f32, 2.0, 5.0];
    // Table entry (row, col) holds row * 4 + col.
    let table_v: Vec<f32> = (0..32).map(|i| i as f32).collect();
    let want: Vec<f32> = vec![
        0.0, 1.0, 2.0, 3.0, // row 0
        8.0, 9.0, 10.0, 11.0, // row 2
        20.0, 21.0, 22.0, 23.0, // row 5
    ];

    let inputs = [("ids", ids_v.as_slice())];
    let params = [("emb", table_v.as_slice())];
    let (diff, has_nan) = wgpu_run_and_check(g, &inputs, &params, &want);
    eprintln!("[bisect:gather] diff={diff:e} has_nan={has_nan}");
    assert!(!has_nan, "gather produced NaN");
    assert!(diff < 1e-5, "gather diff {diff:e}");
}
/// Bisect step: `Op::Gather` feeding `Op::LayerNorm` (unit gamma, zero beta).
/// Only checks that the output contains no NaN.
#[test]
fn bisect_wgpu_gather_then_layernorm() {
    if !rlx_wgpu::is_available() {
        return;
    }

    let f = DType::F32;
    let mut g = Graph::new("gather_ln");
    let ids = g.input("ids", Shape::new(&[1, 3], f));
    let table = g.param("emb", Shape::new(&[8, 4], f));
    let gamma = g.param("gamma", Shape::new(&[4], f));
    let beta = g.param("beta", Shape::new(&[4], f));
    let gathered = g.add_node(
        Op::Gather { axis: 0 },
        vec![table, ids],
        Shape::new(&[1, 3, 4], f),
    );
    let normed = g.add_node(
        Op::LayerNorm {
            axis: -1,
            eps: 1e-5,
        },
        vec![gathered, gamma, beta],
        Shape::new(&[1, 3, 4], f),
    );
    g.set_outputs(vec![normed]);

    let ids_v = vec![0.0f32, 2.0, 5.0];
    let table_v: Vec<f32> = (0..32).map(|i| i as f32).collect();
    let gamma_v = vec![1.0f32; 4];
    let beta_v = vec![0.0f32; 4];

    let mut exe = WgpuExecutable::compile(g);
    for (name, data) in [("emb", &table_v), ("gamma", &gamma_v), ("beta", &beta_v)] {
        exe.set_param(name, data);
    }

    let out = exe.run(&[("ids", &ids_v)]).into_iter().next().unwrap();
    let has_nan = out.iter().any(|v| v.is_nan());
    let preview = &out[..out.len().min(4)];
    eprintln!(
        "[bisect:gather+ln] first={:?} has_nan={has_nan}",
        preview
    );
    assert!(!has_nan, "gather+layernorm produced NaN");
}
/// Bisect step: MatMul -> bias Add -> Narrow (first `h` columns of a fused
/// QKV-style projection). Only checks that the output contains no NaN.
#[test]
fn bisect_wgpu_matmul_bias_narrow() {
    use rlx_ir::op::BinaryOp;

    if !rlx_wgpu::is_available() {
        return;
    }

    let f = DType::F32;
    let h = 8;
    let projected = || Shape::new(&[1, 3, 3 * h], f);

    let mut g = Graph::new("mm_bias_narrow");
    let x = g.input("x", Shape::new(&[1, 3, h], f));
    let w = g.param("w", Shape::new(&[h, 3 * h], f));
    let b = g.param("b", Shape::new(&[3 * h], f));
    let mm = g.add_node(Op::MatMul, vec![x, w], projected());
    let qkv = g.binary(BinaryOp::Add, mm, b, projected());
    let q = g.add_node(
        Op::Narrow {
            axis: 2,
            start: 0,
            len: h,
        },
        vec![qkv],
        Shape::new(&[1, 3, h], f),
    );
    g.set_outputs(vec![q]);

    // Deterministic synthetic input, weight and bias values.
    let x_v: Vec<f32> = (0..(3 * h)).map(|i| i as f32 * 0.1).collect();
    let w_v: Vec<f32> = (0..(h * 3 * h)).map(|i| (i % 7) as f32 * 0.05).collect();
    let b_v: Vec<f32> = (0..(3 * h)).map(|i| i as f32 * 0.01).collect();

    let mut exe = WgpuExecutable::compile(g);
    exe.set_param("w", &w_v);
    exe.set_param("b", &b_v);

    let out = exe.run(&[("x", &x_v)]).into_iter().next().unwrap();
    let has_nan = out.iter().any(|v| v.is_nan());
    let preview = &out[..out.len().min(4)];
    eprintln!(
        "[bisect:mm+bias+narrow] first={:?} has_nan={has_nan}",
        preview
    );
    assert!(!has_nan, "mm+bias+narrow produced NaN");
}
/// Bisect step: fused QKV projection (MatMul + bias), three Narrow slices and
/// a custom-masked `Op::Attention`. Only checks that the output has no NaN.
#[test]
fn bisect_wgpu_attention_with_qkv_chain() {
    use rlx_ir::op::BinaryOp;

    if !rlx_wgpu::is_available() {
        return;
    }

    let f = DType::F32;
    let (b, s, nh, dh) = (1, 3, 2, 4);
    let h = nh * dh;
    let hidden = || Shape::new(&[b, s, h], f);
    let fused = || Shape::new(&[b, s, 3 * h], f);

    let mut g = Graph::new("attn_chain");
    let x = g.input("x", hidden());
    let mask = g.input("mask", Shape::new(&[b, s], f));
    let w = g.param("w", Shape::new(&[h, 3 * h], f));
    let bias = g.param("b", Shape::new(&[3 * h], f));
    let mm = g.add_node(Op::MatMul, vec![x, w], fused());
    let qkv = g.binary(BinaryOp::Add, mm, bias, fused());

    // Q, K and V are consecutive `h`-wide slices of the fused projection.
    let [q, k, v] = [0, h, 2 * h].map(|start| {
        g.add_node(
            Op::Narrow {
                axis: 2,
                start,
                len: h,
            },
            vec![qkv],
            hidden(),
        )
    });

    let attn = g.add_node(
        Op::Attention {
            num_heads: nh,
            head_dim: dh,
            mask_kind: MaskKind::Custom,
            score_scale: None,
            attn_logit_softcap: None,
        },
        vec![q, k, v, mask],
        hidden(),
    );
    g.set_outputs(vec![attn]);

    // Deterministic synthetic data; the mask is all ones.
    let x_v: Vec<f32> = (0..(b * s * h)).map(|i| ((i % 11) as f32) * 0.1).collect();
    let w_v: Vec<f32> = (0..(h * 3 * h)).map(|i| ((i % 7) as f32) * 0.05).collect();
    let b_v: Vec<f32> = (0..(3 * h)).map(|i| (i as f32) * 0.01).collect();
    let mask_v = vec![1.0f32; b * s];

    let mut exe = WgpuExecutable::compile(g);
    exe.set_param("w", &w_v);
    exe.set_param("b", &b_v);

    let outputs = exe.run(&[("x", &x_v), ("mask", &mask_v)]);
    let out = outputs.into_iter().next().unwrap();
    let has_nan = out.iter().any(|v| v.is_nan());
    let preview = &out[..out.len().min(4)];
    eprintln!(
        "[bisect:attn-chain] len={} first={:?} has_nan={has_nan}",
        out.len(),
        preview
    );
    assert!(!has_nan, "attention chain produced NaN");
}
/// Bisect step: a lone `Op::FusedResidualLN` (no bias, unit gamma, zero beta).
/// Only checks that the output contains no NaN.
#[test]
fn bisect_wgpu_fused_residual_ln() {
    if !rlx_wgpu::is_available() {
        return;
    }

    let f = DType::F32;
    let activations = || Shape::new(&[1, 3, 4], f);

    let mut g = Graph::new("frln");
    let x = g.input("x", activations());
    let res = g.input("res", activations());
    let gamma = g.param("gamma", Shape::new(&[4], f));
    let beta = g.param("beta", Shape::new(&[4], f));
    let fused = g.add_node(
        Op::FusedResidualLN {
            has_bias: false,
            eps: 1e-5,
        },
        vec![x, res, gamma, beta],
        activations(),
    );
    g.set_outputs(vec![fused]);

    let x_v: Vec<f32> = (0..12).map(|i| i as f32 * 0.1).collect();
    let res_v: Vec<f32> = (0..12).map(|i| i as f32 * 0.2).collect();
    let gamma_v = vec![1.0f32; 4];
    let beta_v = vec![0.0f32; 4];

    let mut exe = WgpuExecutable::compile(g);
    exe.set_param("gamma", &gamma_v);
    exe.set_param("beta", &beta_v);

    let outputs = exe.run(&[("x", &x_v), ("res", &res_v)]);
    let out = outputs.into_iter().next().unwrap();
    let has_nan = out.iter().any(|v| v.is_nan());
    let preview = &out[..out.len().min(4)];
    eprintln!(
        "[bisect:fused_residual_ln] first={:?} has_nan={has_nan}",
        preview
    );
    assert!(!has_nan, "FusedResidualLN produced NaN");
}
/// Bisect step: one complete BERT-style encoder layer at tiny dimensions
/// (hidden size 8): embedding gather, fused QKV projection, custom-masked
/// attention, output projection, residual LayerNorm, GELU feed-forward and a
/// second residual LayerNorm. Only checks that the output has no NaN.
#[test]
fn bisect_wgpu_full_bert_layer() {
    use rlx_ir::op::{Activation, BinaryOp};

    if !rlx_wgpu::is_available() {
        return;
    }

    let f = DType::F32;
    let (b, s, nh, dh) = (1, 3, 2, 4);
    let h = nh * dh;
    let intermediate = h * 4;

    // Shapes that recur throughout the layer.
    let hidden = || Shape::new(&[b, s, h], f);
    let fused = || Shape::new(&[b, s, 3 * h], f);
    let wide = || Shape::new(&[b, s, intermediate], f);

    let mut g = Graph::new("bert_layer");
    let ids = g.input("ids", Shape::new(&[b, s], f));
    let mask = g.input("mask", Shape::new(&[b, s], f));

    // Embedding lookup.
    let emb = g.param("emb", Shape::new(&[16, h], f));
    let h0 = g.add_node(Op::Gather { axis: 0 }, vec![emb, ids], hidden());

    // Fused QKV projection followed by three `h`-wide slices.
    let qkv_w = g.param("qkv_w", Shape::new(&[h, 3 * h], f));
    let qkv_b = g.param("qkv_b", Shape::new(&[3 * h], f));
    let qkv_mm = g.add_node(Op::MatMul, vec![h0, qkv_w], fused());
    let qkv = g.binary(BinaryOp::Add, qkv_mm, qkv_b, fused());
    let [q, k, v] = [0, h, 2 * h].map(|start| {
        g.add_node(
            Op::Narrow {
                axis: 2,
                start,
                len: h,
            },
            vec![qkv],
            hidden(),
        )
    });

    // Self-attention with the caller-supplied mask.
    let attn = g.add_node(
        Op::Attention {
            num_heads: nh,
            head_dim: dh,
            mask_kind: MaskKind::Custom,
            score_scale: None,
            attn_logit_softcap: None,
        },
        vec![q, k, v, mask],
        hidden(),
    );

    // Attention output projection.
    let out_w = g.param("out_w", Shape::new(&[h, h], f));
    let out_b = g.param("out_b", Shape::new(&[h], f));
    let attn_out_mm = g.add_node(Op::MatMul, vec![attn, out_w], hidden());
    let attn_out = g.binary(BinaryOp::Add, attn_out_mm, out_b, hidden());

    // First residual + LayerNorm (residual is the embedding output).
    let ln1_g = g.param("ln1_g", Shape::new(&[h], f));
    let ln1_b = g.param("ln1_b", Shape::new(&[h], f));
    let res1 = g.add_node(
        Op::FusedResidualLN {
            has_bias: false,
            eps: 1e-5,
        },
        vec![attn_out, h0, ln1_g, ln1_b],
        hidden(),
    );

    // Feed-forward: expand, GELU, contract.
    let ffn1_w = g.param("ffn1_w", Shape::new(&[h, intermediate], f));
    let ffn1_b = g.param("ffn1_b", Shape::new(&[intermediate], f));
    let ffn1_mm = g.add_node(Op::MatMul, vec![res1, ffn1_w], wide());
    let ffn1_bias = g.binary(BinaryOp::Add, ffn1_mm, ffn1_b, wide());
    let ffn1_gelu = g.add_node(Op::Activation(Activation::Gelu), vec![ffn1_bias], wide());
    let ffn2_w = g.param("ffn2_w", Shape::new(&[intermediate, h], f));
    let ffn2_b = g.param("ffn2_b", Shape::new(&[h], f));
    let ffn2_mm = g.add_node(Op::MatMul, vec![ffn1_gelu, ffn2_w], hidden());
    let ffn2_out = g.binary(BinaryOp::Add, ffn2_mm, ffn2_b, hidden());

    // Second residual + LayerNorm.
    let ln2_g = g.param("ln2_g", Shape::new(&[h], f));
    let ln2_b = g.param("ln2_b", Shape::new(&[h], f));
    let res2 = g.add_node(
        Op::FusedResidualLN {
            has_bias: false,
            eps: 1e-5,
        },
        vec![ffn2_out, res1, ln2_g, ln2_b],
        hidden(),
    );
    g.set_outputs(vec![res2]);

    // Deterministic synthetic inputs and parameters.
    let ids_v = vec![1.0f32, 2.0, 3.0];
    let mask_v = vec![1.0f32; b * s];
    let emb_v: Vec<f32> = (0..(16 * h)).map(|i| (i as f32) * 0.01).collect();
    let qkv_w_v: Vec<f32> = (0..(h * 3 * h)).map(|i| ((i % 7) as f32) * 0.05).collect();
    let qkv_b_v: Vec<f32> = (0..(3 * h)).map(|i| (i as f32) * 0.001).collect();
    let out_w_v: Vec<f32> = (0..(h * h)).map(|i| ((i % 5) as f32) * 0.05).collect();
    let out_b_v: Vec<f32> = (0..h).map(|i| (i as f32) * 0.001).collect();
    let ln1_g_v = vec![1.0f32; h];
    let ln1_b_v = vec![0.0f32; h];
    let ffn1_w_v: Vec<f32> = (0..(h * intermediate))
        .map(|i| ((i % 9) as f32) * 0.02)
        .collect();
    let ffn1_b_v: Vec<f32> = (0..intermediate).map(|i| (i as f32) * 0.001).collect();
    let ffn2_w_v: Vec<f32> = (0..(intermediate * h))
        .map(|i| ((i % 11) as f32) * 0.02)
        .collect();
    let ffn2_b_v: Vec<f32> = (0..h).map(|i| (i as f32) * 0.001).collect();
    let ln2_g_v = vec![1.0f32; h];
    let ln2_b_v = vec![0.0f32; h];

    let mut exe = WgpuExecutable::compile(g);
    for (name, data) in [
        ("emb", &emb_v),
        ("qkv_w", &qkv_w_v),
        ("qkv_b", &qkv_b_v),
        ("out_w", &out_w_v),
        ("out_b", &out_b_v),
        ("ln1_g", &ln1_g_v),
        ("ln1_b", &ln1_b_v),
        ("ffn1_w", &ffn1_w_v),
        ("ffn1_b", &ffn1_b_v),
        ("ffn2_w", &ffn2_w_v),
        ("ffn2_b", &ffn2_b_v),
        ("ln2_g", &ln2_g_v),
        ("ln2_b", &ln2_b_v),
    ] {
        exe.set_param(name, data);
    }

    let outputs = exe.run(&[("ids", &ids_v), ("mask", &mask_v)]);
    let out = outputs.into_iter().next().unwrap();
    let nan_count = out.iter().filter(|v| v.is_nan()).count();
    let preview = &out[..out.len().min(4)];
    eprintln!(
        "[bisect:full_bert_layer] len={} nan_count={}/{} first={:?}",
        out.len(),
        nan_count,
        out.len(),
        preview
    );
    assert_eq!(
        nan_count, 0,
        "full BERT layer produced {nan_count} NaN values"
    );
}
/// Bisect step: the same single BERT-style encoder layer at a larger size
/// (12 heads of 32, hidden size 384, sequence length 6) with small centred
/// synthetic weights. Only checks that the output has no NaN.
#[test]
fn bisect_wgpu_full_bert_realistic_dim() {
    use rlx_ir::op::{Activation, BinaryOp};

    if !rlx_wgpu::is_available() {
        return;
    }

    let f = DType::F32;
    let (b, s, nh, dh) = (1, 6, 12, 32);
    let h = nh * dh;
    let intermediate = h * 4;

    // Shapes that recur throughout the layer.
    let hidden = || Shape::new(&[b, s, h], f);
    let fused = || Shape::new(&[b, s, 3 * h], f);
    let wide = || Shape::new(&[b, s, intermediate], f);

    let mut g = Graph::new("bert_real");
    let ids = g.input("ids", Shape::new(&[b, s], f));
    let mask = g.input("mask", Shape::new(&[b, s], f));

    // Embedding lookup.
    let emb = g.param("emb", Shape::new(&[100, h], f));
    let h0 = g.add_node(Op::Gather { axis: 0 }, vec![emb, ids], hidden());

    // Fused QKV projection followed by three `h`-wide slices.
    let qkv_w = g.param("qkv_w", Shape::new(&[h, 3 * h], f));
    let qkv_b = g.param("qkv_b", Shape::new(&[3 * h], f));
    let qkv_mm = g.add_node(Op::MatMul, vec![h0, qkv_w], fused());
    let qkv = g.binary(BinaryOp::Add, qkv_mm, qkv_b, fused());
    let [q, k, v] = [0, h, 2 * h].map(|start| {
        g.add_node(
            Op::Narrow {
                axis: 2,
                start,
                len: h,
            },
            vec![qkv],
            hidden(),
        )
    });

    // Self-attention with the caller-supplied mask.
    let attn = g.add_node(
        Op::Attention {
            num_heads: nh,
            head_dim: dh,
            mask_kind: MaskKind::Custom,
            score_scale: None,
            attn_logit_softcap: None,
        },
        vec![q, k, v, mask],
        hidden(),
    );

    // Attention output projection.
    let out_w = g.param("out_w", Shape::new(&[h, h], f));
    let out_b = g.param("out_b", Shape::new(&[h], f));
    let attn_out_mm = g.add_node(Op::MatMul, vec![attn, out_w], hidden());
    let attn_out = g.binary(BinaryOp::Add, attn_out_mm, out_b, hidden());

    // First residual + LayerNorm (residual is the embedding output).
    let ln1_g = g.param("ln1_g", Shape::new(&[h], f));
    let ln1_b = g.param("ln1_b", Shape::new(&[h], f));
    let res1 = g.add_node(
        Op::FusedResidualLN {
            has_bias: false,
            eps: 1e-5,
        },
        vec![attn_out, h0, ln1_g, ln1_b],
        hidden(),
    );

    // Feed-forward: expand, GELU, contract.
    let ffn1_w = g.param("ffn1_w", Shape::new(&[h, intermediate], f));
    let ffn1_b = g.param("ffn1_b", Shape::new(&[intermediate], f));
    let ffn1_mm = g.add_node(Op::MatMul, vec![res1, ffn1_w], wide());
    let ffn1_bias = g.binary(BinaryOp::Add, ffn1_mm, ffn1_b, wide());
    let ffn1_gelu = g.add_node(Op::Activation(Activation::Gelu), vec![ffn1_bias], wide());
    let ffn2_w = g.param("ffn2_w", Shape::new(&[intermediate, h], f));
    let ffn2_b = g.param("ffn2_b", Shape::new(&[h], f));
    let ffn2_mm = g.add_node(Op::MatMul, vec![ffn1_gelu, ffn2_w], hidden());
    let ffn2_out = g.binary(BinaryOp::Add, ffn2_mm, ffn2_b, hidden());

    // Second residual + LayerNorm.
    let ln2_g = g.param("ln2_g", Shape::new(&[h], f));
    let ln2_b = g.param("ln2_b", Shape::new(&[h], f));
    let res2 = g.add_node(
        Op::FusedResidualLN {
            has_bias: false,
            eps: 1e-5,
        },
        vec![ffn2_out, res1, ln2_g, ln2_b],
        hidden(),
    );
    g.set_outputs(vec![res2]);

    let ids_v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let mask_v = vec![1.0f32; b * s];

    // Repeating ramp over -15..=15, multiplied by `scale`.
    let small = |n: usize, scale: f32| -> Vec<f32> {
        (0..n).map(|i| ((i % 31) as f32 - 15.0) * scale).collect()
    };
    let emb_v = small(100 * h, 0.01);
    let qkv_w_v = small(h * 3 * h, 0.01);
    let qkv_b_v = small(3 * h, 0.001);
    let out_w_v = small(h * h, 0.01);
    let out_b_v = small(h, 0.001);
    let ln1_g_v = vec![1.0f32; h];
    let ln1_b_v = vec![0.0f32; h];
    let ffn1_w_v = small(h * intermediate, 0.01);
    let ffn1_b_v = small(intermediate, 0.001);
    let ffn2_w_v = small(intermediate * h, 0.01);
    let ffn2_b_v = small(h, 0.001);
    let ln2_g_v = vec![1.0f32; h];
    let ln2_b_v = vec![0.0f32; h];

    let mut exe = WgpuExecutable::compile(g);
    for (name, data) in [
        ("emb", &emb_v),
        ("qkv_w", &qkv_w_v),
        ("qkv_b", &qkv_b_v),
        ("out_w", &out_w_v),
        ("out_b", &out_b_v),
        ("ln1_g", &ln1_g_v),
        ("ln1_b", &ln1_b_v),
        ("ffn1_w", &ffn1_w_v),
        ("ffn1_b", &ffn1_b_v),
        ("ffn2_w", &ffn2_w_v),
        ("ffn2_b", &ffn2_b_v),
        ("ln2_g", &ln2_g_v),
        ("ln2_b", &ln2_b_v),
    ] {
        exe.set_param(name, data);
    }

    let outputs = exe.run(&[("ids", &ids_v), ("mask", &mask_v)]);
    let out = outputs.into_iter().next().unwrap();
    let nan_count = out.iter().filter(|v| v.is_nan()).count();
    let preview = &out[..out.len().min(4)];
    eprintln!(
        "[bisect:bert_realistic h=384] len={} nan={}/{} first={:?}",
        out.len(),
        nan_count,
        out.len(),
        preview
    );
    assert_eq!(
        nan_count,
        0,
        "full BERT layer at realistic dim produced {nan_count}/{} NaN",
        out.len()
    );
}
/// Bisect step: two stacked BERT-style encoder layers (hidden size 32,
/// 4 heads of 8) fed by one embedding gather, with per-layer parameters named
/// by a `_{layer}` suffix. Only checks that the output has no NaN.
#[test]
fn bisect_wgpu_full_bert_layer_stack() {
    use rlx_ir::op::{Activation, BinaryOp};

    if !rlx_wgpu::is_available() {
        return;
    }

    let f = DType::F32;
    let mut g = Graph::new("bert_real_2layer");
    let (b, s, h, nh, dh, n_layers, vocab) = (1, 4, 32, 4, 8, 2, 100);
    let intermediate = h * 4;

    // Shapes that recur throughout every layer.
    let hidden = || Shape::new(&[b, s, h], f);
    let fused = || Shape::new(&[b, s, 3 * h], f);
    let wide = || Shape::new(&[b, s, intermediate], f);

    let ids = g.input("ids", Shape::new(&[b, s], f));
    let mask = g.input("mask", Shape::new(&[b, s], f));
    let emb = g.param("emb", Shape::new(&[vocab, h], f));

    // `h_id` tracks the running hidden state; it starts as the embeddings.
    let mut h_id = g.add_node(Op::Gather { axis: 0 }, vec![emb, ids], hidden());

    for l in 0..n_layers {
        // Fused QKV projection followed by three `h`-wide slices.
        let qkv_w = g.param(format!("qkv_w_{l}"), Shape::new(&[h, 3 * h], f));
        let qkv_b = g.param(format!("qkv_b_{l}"), Shape::new(&[3 * h], f));
        let qkv_mm = g.add_node(Op::MatMul, vec![h_id, qkv_w], fused());
        let qkv = g.binary(BinaryOp::Add, qkv_mm, qkv_b, fused());
        let [q, k, v] = [0, h, 2 * h].map(|start| {
            g.add_node(
                Op::Narrow {
                    axis: 2,
                    start,
                    len: h,
                },
                vec![qkv],
                hidden(),
            )
        });

        // Self-attention with the caller-supplied mask.
        let attn = g.add_node(
            Op::Attention {
                num_heads: nh,
                head_dim: dh,
                mask_kind: MaskKind::Custom,
                score_scale: None,
                attn_logit_softcap: None,
            },
            vec![q, k, v, mask],
            hidden(),
        );

        // Attention output projection.
        let out_w = g.param(format!("out_w_{l}"), Shape::new(&[h, h], f));
        let out_b = g.param(format!("out_b_{l}"), Shape::new(&[h], f));
        let attn_mm = g.add_node(Op::MatMul, vec![attn, out_w], hidden());
        let attn_out = g.binary(BinaryOp::Add, attn_mm, out_b, hidden());

        // First residual + LayerNorm (residual is the layer input).
        let ln1_g = g.param(format!("ln1_g_{l}"), Shape::new(&[h], f));
        let ln1_b = g.param(format!("ln1_b_{l}"), Shape::new(&[h], f));
        let res1 = g.add_node(
            Op::FusedResidualLN {
                has_bias: false,
                eps: 1e-5,
            },
            vec![attn_out, h_id, ln1_g, ln1_b],
            hidden(),
        );

        // Feed-forward: expand, GELU, contract.
        let ffn1_w = g.param(format!("ffn1_w_{l}"), Shape::new(&[h, intermediate], f));
        let ffn1_b = g.param(format!("ffn1_b_{l}"), Shape::new(&[intermediate], f));
        let ffn1_mm = g.add_node(Op::MatMul, vec![res1, ffn1_w], wide());
        let ffn1_bias = g.binary(BinaryOp::Add, ffn1_mm, ffn1_b, wide());
        let ffn1_gelu = g.add_node(Op::Activation(Activation::Gelu), vec![ffn1_bias], wide());
        let ffn2_w = g.param(format!("ffn2_w_{l}"), Shape::new(&[intermediate, h], f));
        let ffn2_b = g.param(format!("ffn2_b_{l}"), Shape::new(&[h], f));
        let ffn2_mm = g.add_node(Op::MatMul, vec![ffn1_gelu, ffn2_w], hidden());
        let ffn2_out = g.binary(BinaryOp::Add, ffn2_mm, ffn2_b, hidden());

        // Second residual + LayerNorm becomes the next layer's input.
        let ln2_g = g.param(format!("ln2_g_{l}"), Shape::new(&[h], f));
        let ln2_b = g.param(format!("ln2_b_{l}"), Shape::new(&[h], f));
        h_id = g.add_node(
            Op::FusedResidualLN {
                has_bias: false,
                eps: 1e-5,
            },
            vec![ffn2_out, res1, ln2_g, ln2_b],
            hidden(),
        );
    }
    g.set_outputs(vec![h_id]);

    // Repeating ramp over -15..=15, multiplied by `scale`.
    let small = |n: usize, scale: f32| -> Vec<f32> {
        (0..n).map(|i| ((i % 31) as f32 - 15.0) * scale).collect()
    };

    let mut exe = WgpuExecutable::compile(g);
    exe.set_param("emb", &small(vocab * h, 0.01));
    for l in 0..n_layers {
        // LayerNorm scale and shift are identity: gamma = 1, beta = 0.
        let ones = vec![1.0f32; h];
        let zeros = vec![0.0f32; h];

        exe.set_param(&format!("qkv_w_{l}"), &small(h * 3 * h, 0.01));
        exe.set_param(&format!("qkv_b_{l}"), &small(3 * h, 0.001));
        exe.set_param(&format!("out_w_{l}"), &small(h * h, 0.01));
        exe.set_param(&format!("out_b_{l}"), &small(h, 0.001));
        exe.set_param(&format!("ln1_g_{l}"), &ones);
        exe.set_param(&format!("ln1_b_{l}"), &zeros);
        exe.set_param(&format!("ffn1_w_{l}"), &small(h * intermediate, 0.01));
        exe.set_param(&format!("ffn1_b_{l}"), &small(intermediate, 0.001));
        exe.set_param(&format!("ffn2_w_{l}"), &small(intermediate * h, 0.01));
        exe.set_param(&format!("ffn2_b_{l}"), &small(h, 0.001));
        exe.set_param(&format!("ln2_g_{l}"), &ones);
        exe.set_param(&format!("ln2_b_{l}"), &zeros);
    }

    let ids_v = vec![1.0f32, 2.0, 3.0, 4.0];
    let mask_v = vec![1.0f32; b * s];
    let outputs = exe.run(&[("ids", &ids_v), ("mask", &mask_v)]);
    let out = outputs.into_iter().next().unwrap();
    let nan_count = out.iter().filter(|v| v.is_nan()).count();
    let preview = &out[..out.len().min(4)];
    eprintln!(
        "[bisect:bert_2layer] len={} nan={}/{} first={:?}",
        out.len(),
        nan_count,
        out.len(),
        preview
    );
    assert_eq!(
        nan_count,
        0,
        "2-layer BERT produced {nan_count}/{} NaN",
        out.len()
    );
}
/// Bisect probe for the BERT input-preparation stage on the wgpu backend.
///
/// Builds a graph that gathers rows from the word, position and token-type
/// embedding tables, adds the three gathered tensors together and applies
/// `LayerNorm` to the sum. The test then asserts that the first output
/// contains no NaN values. Returns early (passing vacuously) when
/// `rlx_wgpu::is_available()` reports false.
#[test]
fn bisect_wgpu_bert_input_prep() {
    use rlx_ir::op::BinaryOp;

    /// Deterministic ramp: element `i` is `((i % 31) - 15) * scale`.
    fn ramp(n: usize, scale: f32) -> Vec<f32> {
        let mut values = Vec::with_capacity(n);
        for i in 0..n {
            values.push(((i % 31) as f32 - 15.0) * scale);
        }
        values
    }

    if !rlx_wgpu::is_available() {
        return;
    }

    let f = DType::F32;
    let b = 1;
    let s = 4;
    let h = 16;
    let vocab = 100;
    // The two shapes used throughout: id tensors and hidden-state tensors.
    let ids_shape = || Shape::new(&[b, s], f);
    let hidden_shape = || Shape::new(&[b, s, h], f);

    let mut g = Graph::new("bert_input_prep");

    // Graph inputs and parameters, declared in the same order as before so
    // node numbering is unchanged.
    let ids = g.input("ids", ids_shape());
    let pos_ids = g.input("pos_ids", ids_shape());
    let tt_ids = g.input("tt_ids", ids_shape());
    let word_emb = g.param("word_emb", Shape::new(&[vocab, h], f));
    let pos_emb = g.param("pos_emb", Shape::new(&[vocab, h], f));
    let tt_emb = g.param("tt_emb", Shape::new(&[2, h], f));
    let ln_g = g.param("ln_g", Shape::new(&[h], f));
    let ln_b = g.param("ln_b", Shape::new(&[h], f));

    // One gather per embedding table.
    let word_out = g.add_node(Op::Gather { axis: 0 }, vec![word_emb, ids], hidden_shape());
    let pos_out = g.add_node(Op::Gather { axis: 0 }, vec![pos_emb, pos_ids], hidden_shape());
    let tt_out = g.add_node(Op::Gather { axis: 0 }, vec![tt_emb, tt_ids], hidden_shape());

    // (word + pos) + token-type, then LayerNorm over the last axis.
    let wp = g.binary(BinaryOp::Add, word_out, pos_out, hidden_shape());
    let sum = g.binary(BinaryOp::Add, wp, tt_out, hidden_shape());
    let ln = g.add_node(
        Op::LayerNorm {
            axis: -1,
            eps: 1e-5,
        },
        vec![sum, ln_g, ln_b],
        hidden_shape(),
    );
    g.set_outputs(vec![ln]);

    let mut exe = WgpuExecutable::compile(g);
    // Embedding tables get the ramp pattern; LayerNorm gets unit gain and
    // zero bias.
    let param_values = [
        ("word_emb", ramp(vocab * h, 0.01)),
        ("pos_emb", ramp(vocab * h, 0.01)),
        ("tt_emb", ramp(2 * h, 0.01)),
        ("ln_g", vec![1.0f32; h]),
        ("ln_b", vec![0.0f32; h]),
    ];
    for (name, data) in param_values {
        exe.set_param(name, &data);
    }

    let token_ids = vec![1.0f32, 2.0, 3.0, 4.0];
    let position_ids = vec![0.0f32, 1.0, 2.0, 3.0];
    let type_ids = vec![0.0f32; 4];
    let out = exe
        .run(&[
            ("ids", &token_ids),
            ("pos_ids", &position_ids),
            ("tt_ids", &type_ids),
        ])
        .into_iter()
        .next()
        .unwrap();

    let total = out.len();
    let nan_count = out.iter().copied().filter(|v| v.is_nan()).count();
    let preview = &out[..total.min(4)];
    eprintln!(
        "[bisect:bert_input_prep] len={} nan={}/{} first={:?}",
        total, nan_count, total, preview
    );
    assert_eq!(
        nan_count,
        0,
        "BERT input prep produced {nan_count}/{} NaN",
        total
    );
}
/// Runs a one-step `ElementwiseRegion` (ReLU applied to region input 0) over
/// eight values on the wgpu backend and compares the result against
/// `max(v, 0.0)` computed on the host, within an absolute tolerance of 1e-5.
/// Returns early when `rlx_wgpu::is_available()` reports false.
#[test]
fn region_relu_matches_atomic() {
    use rlx_ir::op::{ChainOperand, ChainStep};

    if !rlx_wgpu::is_available() {
        return;
    }

    // Input data and its host-side ReLU reference.
    let xs = vec![-1.0f32, 0.0, 0.5, 1.0, -2.0, 3.0, -0.5, 2.5];
    let want: Vec<f32> = xs.iter().map(|&v| f32::max(v, 0.0)).collect();

    let vec_shape = || Shape::new(&[8], DType::F32);
    let mut graph = Graph::new("relu_region");
    let x_in = graph.input("x", vec_shape());
    let relu_step = ChainStep::Activation(Activation::Relu, ChainOperand::Input(0));
    let fused = graph.add_node(
        Op::ElementwiseRegion {
            chain: vec![relu_step],
            num_inputs: 1,
            scalar_input_mask: 0,
            input_modulus: [0u32; 16],
            prologue: rlx_ir::RegionPrologue::None,
            prologue_input: 0,
        },
        vec![x_in],
        vec_shape(),
    );
    graph.set_outputs(vec![fused]);

    let mut exe = WgpuExecutable::compile(graph);
    let got_reg = exe.run(&[("x", &xs)]).into_iter().next().unwrap();
    assert!(
        close(&got_reg, &want, 1e-5),
        "region mismatch: got {got_reg:?} want {want:?}"
    );
}
/// Multiplies a 2x3 input by a 3x2 parameter on the wgpu backend, using the
/// graph from `build_graph`, and checks the single output against
/// `matmul_ref` within an absolute tolerance of 1e-4. Logs a notice and
/// returns early when `rlx_wgpu::is_available()` reports false.
#[test]
fn matmul_2x3x2_matches_cpu_reference() {
    if !rlx_wgpu::is_available() {
        eprintln!("rlx-wgpu: no compatible adapter; skipping test");
        return;
    }

    // Row-major operands: x is 2x3, w is 3x2.
    let x: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let w: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
    let want = matmul_ref(&x, &w, 2, 3, 2);

    let mut exe = WgpuExecutable::compile(build_graph());
    exe.set_param("w", &w);
    let outs = exe.run(&[("x", &x)]);

    assert_eq!(outs.len(), 1);
    let got = &outs[0];
    assert!(
        close(got, &want, 1e-4),
        "matmul mismatch: got {:?} want {want:?}",
        got
    );
}
/// Feeds a deterministic synthetic spectrum of shape
/// `[batch * n_segments, n_fft * 2]` through the graph's `welch_peaks` op on
/// the wgpu backend and compares the first output against
/// `rlx_ir::audio::welch_peaks_block_f32` within an absolute tolerance of
/// 1e-4. On mismatch the failure message reports the largest absolute
/// element-wise difference. Returns early when `rlx_wgpu::is_available()`
/// reports false.
#[test]
fn welch_peaks_gpu_matches_cpu_reference() {
    if !rlx_wgpu::is_available() {
        return;
    }

    let (batch, n_fft, n_segments, k): (usize, usize, usize, usize) = (8, 256, 2, 16);
    let seg_batch = batch * n_segments;
    let row_len = n_fft * 2;

    // Smooth, non-trivial test signal; the value depends only on the flat index.
    let spectrum: Vec<f32> = (0..seg_batch * row_len)
        .map(|i| {
            let t = i as f32;
            (t * 0.013).sin() * 0.5 + 0.01 * t.cos()
        })
        .collect();

    // GPU path.
    let mut g = Graph::new("welch_peaks");
    let spec_in = g.input("spec", Shape::new(&[seg_batch, row_len], DType::F32));
    let peaks = g.welch_peaks(spec_in, k, n_segments);
    g.set_outputs(vec![peaks]);
    let mut exe = WgpuExecutable::compile(g);
    let gpu_out = exe.run(&[("spec", &spectrum)]).remove(0);

    // CPU reference written into a `batch * k * 2` element buffer.
    let mut ref_out = vec![0f32; batch * k * 2];
    rlx_ir::audio::welch_peaks_block_f32(&spectrum, batch, n_fft, n_segments, k, &mut ref_out);

    assert!(
        close(&gpu_out, &ref_out, 1e-4),
        "welch_peaks_gpu mismatch max={:.3e}",
        gpu_out
            .iter()
            .zip(&ref_out)
            .fold(0.0f32, |worst, (a, b)| worst.max((a - b).abs()))
    );
}