#![cfg(feature = "cpu")]
use rlx_ir::{DType, Graph, Shape};
use rlx_runtime::{Device, Session};
/// Builds a graph named `"ssm"` containing a single `selective_scan` node.
///
/// Inputs are registered in the order `x`, `delta`, `a`, `b`, `c`, all `F32`:
/// `x` and `delta` use shape `[b, s, h]`, `a` uses `[h, n]`, and `b` / `c` use
/// `[b, s, n]`. The scan node is the graph's only output.
fn build_ssm_graph(b: usize, s: usize, h: usize, n: usize) -> Graph {
    let mut graph = Graph::new("ssm");

    let activation_shape = Shape::new(&[b, s, h], DType::F32);
    let projection_shape = Shape::new(&[b, s, n], DType::F32);

    // Registration order is kept as x, delta, a, b, c.
    let x = graph.input("x", activation_shape.clone());
    let delta = graph.input("delta", activation_shape.clone());
    let a = graph.input("a", Shape::new(&[h, n], DType::F32));
    let b_in = graph.input("b", projection_shape.clone());
    let c_in = graph.input("c", projection_shape);

    // NOTE(review): the trailing `n` and `[b, s, h]` shape are presumably the
    // state size and the output shape — confirm against `Graph::selective_scan`.
    let scan = graph.selective_scan(x, delta, a, b_in, c_in, n, activation_shape);
    graph.set_outputs(vec![scan]);
    graph
}
/// Checks the CPU `selective_scan` result against a scalar reference recurrence.
///
/// The reference keeps an `[h, n]` state that is zeroed at the start of every
/// batch element. For each timestep and channel it updates
/// `state = exp(delta * a) * state + delta * b * x` and emits
/// `y = sum over n of c * state`. Every output element must match within
/// `1e-5` absolute or `1e-5` relative error.
#[test]
fn cpu_selective_scan_native_matches_recurrence() {
    let (batch, steps, channels, state_dim) = (1, 4, 2, 3);
    let activation_len = batch * steps * channels;
    let projection_len = batch * steps * state_dim;

    // Deterministic linear ramps: element `i` is `start + step * i`.
    let ramp = |len: usize, start: f32, step: f32| -> Vec<f32> {
        (0..len).map(|i| start + step * (i as f32)).collect()
    };
    let xs = ramp(activation_len, 0.1, 0.05);
    let delta = ramp(activation_len, 0.1, 0.02);
    let a_data = ramp(channels * state_dim, -0.5, 0.1);
    let b_data = ramp(projection_len, 0.1, 0.03);
    let c_data = ramp(projection_len, 0.2, 0.04);

    // Run the op through the CPU session.
    let graph = build_ssm_graph(batch, steps, channels, state_dim);
    let session = Session::new(Device::Cpu);
    let mut native = session.compile(graph);
    let native_out = native.run(&[
        ("x", &xs),
        ("delta", &delta),
        ("a", &a_data),
        ("b", &b_data),
        ("c", &c_data),
    ]);

    // Scalar reference implementation of the recurrence.
    let mut want = vec![0f32; activation_len];
    let mut state = vec![0f32; channels * state_dim];
    for bi in 0..batch {
        // Each batch element starts from an all-zero state.
        state.fill(0.0);
        for si in 0..steps {
            let row = bi * steps + si;
            let b_row = &b_data[row * state_dim..(row + 1) * state_dim];
            let c_row = &c_data[row * state_dim..(row + 1) * state_dim];
            for ci in 0..channels {
                let flat = row * channels + ci;
                let (d, xv) = (delta[flat], xs[flat]);
                let a_row = &a_data[ci * state_dim..(ci + 1) * state_dim];
                let state_row = &mut state[ci * state_dim..(ci + 1) * state_dim];

                let mut acc = 0.0f32;
                for (((st, &a_v), &b_v), &c_v) in
                    state_row.iter_mut().zip(a_row).zip(b_row).zip(c_row)
                {
                    let decay = (d * a_v).exp();
                    *st = decay * *st + d * b_v * xv;
                    acc += c_v * *st;
                }
                want[flat] = acc;
            }
        }
    }

    // Compare the first (only) graph output element-wise.
    let got = &native_out[0];
    assert_eq!(
        got.len(),
        want.len(),
        "SelectiveScan output length mismatch: got {} want {}",
        got.len(),
        want.len()
    );
    for (i, (g, w)) in got.iter().zip(want.iter()).enumerate() {
        let abs_err = (g - w).abs();
        // The denominator is floored at 1e-6 so near-zero references do not blow up.
        let rel_err = abs_err / (w.abs().max(1e-6));
        let within_abs = abs_err < 1e-5;
        let within_rel = rel_err < 1e-5;
        assert!(
            within_abs || within_rel,
            "SelectiveScan parity diverges at idx {i}: native {g} vs scalar reference {w} (abs {abs_err:e}, rel {rel_err:e})"
        );
    }
}