use metaltile::kernel;
use metaltile_core::ir::KernelMode;
use crate::{
bench_types::DType,
spec::{BenchDispatch, BenchSpec},
};
// Registers the kernel `$kernel_fn` with the benchmark inventory under op
// "ssm_replay" and sub-op `$sub_op`.
//
// Both the record and the replay kernels in this file go through this macro,
// so they share one spec:
// - F32/F16/BF16 element types, with tolerance 1e-3.
// - `mlx_src`/`mlx_pattern` unset and an empty shape list.
// - Generic dispatch, launched in `KernelMode::Grid3D`.
macro_rules! ssm_replay_spec {
    ($kernel_fn:ident, $sub_op:literal) => {
        inventory::submit! {
            BenchSpec {
                // Identity of the kernel being benchmarked.
                kernel_name: stringify!($kernel_fn),
                kernel_ir: $kernel_fn::kernel_ir_for,
                op: "ssm_replay",
                subop: $sub_op,
                // Launch configuration.
                kernel_mode: Some(KernelMode::Grid3D),
                dispatch: BenchDispatch::Generic,
                // Inputs and accuracy check.
                dtypes: &[DType::F32, DType::F16, DType::BF16],
                shapes: &[],
                tol: 1e-3,
                // No MLX reference source or pattern is attached.
                mlx_src: None,
                mlx_pattern: None,
            }
        }
    };
}
// Defines a "record" SSM scan kernel named `$name` and registers it through
// `ssm_replay_spec!` under sub-op `$subop`.
//
// Macro parameters (the numeric ones are `u32` literals baked into the kernel):
//   $dh      - DH: number of d indices per head (x/y hold DH values per (b, t, h))
//   $ds      - DS: length of each state row and of each B/C group vector
//   $h       - H: number of heads
//   $g       - G: number of B/C groups; head h reads group h / (H / G)
//   $n_per_t - N_PER_T: consecutive state entries owned by each thread
//   $subop   - bench sub-op label
//
// Flat layouts implied by the index arithmetic. This assumes contiguous
// row-major buffers; NB is the number of b_idx values and T = t_total.
//   x, y: [NB, T, H, DH]        dt: [NB, T, H]        mask: [NB, T]
//   b, c: [NB, T, G, DS]        a_log, d: [H]
//   state_in, state_out: [NB, H, DH, DS]
//   da_log: [NB, T, H, DS]      dbx_log: [NB, T, H, DH, DS]
//
// Besides `y` and `state_out`, the kernel logs two per-step quantities, both
// stored in precision T:
// - the decay, in `da_log`;
// - the input term, in `dbx_log`.
// A replay kernel can then rebuild the state after a prefix of steps from a
// state snapshot and these logs alone.
macro_rules! ssm_step_record {
    ($name:ident, $dh:literal, $ds:literal, $h:literal, $g:literal, $n_per_t:literal, $subop:literal) => {
        // The kernel's index math assumes these divisibility invariants. A
        // violating configuration would otherwise compile and silently
        // read/write the wrong state or B/C entries, so reject it here.
        const _: () = assert!($h % $g == 0, "ssm_step_record: H must be a multiple of G");
        const _: () = assert!(
            $ds % $n_per_t == 0,
            "ssm_step_record: DS must be a multiple of N_PER_T"
        );
        #[kernel]
        pub fn $name<T>(
            x: Tensor<T>,
            a_log: Tensor<T>,
            b: Tensor<T>,
            c: Tensor<T>,
            d: Tensor<T>,
            dt: Tensor<T>,
            state_in: Tensor<T>,
            mask: Tensor<u32>,
            mut y: Tensor<T>,
            mut state_out: Tensor<T>,
            mut da_log: Tensor<T>,
            mut dbx_log: Tensor<T>,
            #[constexpr] t_total: u32,
            #[constexpr] has_mask: u32,
        ) {
            // Grid axes: 0 = ds lane, 1 = d index, 2 = flattened (b_idx, h_idx).
            let ds_lane = program_id::<0>();
            let d_idx = program_id::<1>();
            let n = program_id::<2>();
            // h_idx = n % H, written with division only.
            let h_idx = n - (n / $h) * $h;
            let b_idx = n / $h;
            // Heads are split into G equal groups that share B and C.
            let g_idx = h_idx / ($h / $g);
            let state_base = (n * $dh + d_idx) * $ds;
            // f32 scratch for this thread's N_PER_T state entries.
            stack_alloc("state", $n_per_t, "f32");
            for i in range(0u32, $n_per_t, 1u32) {
                let v = load(state_in[state_base + $n_per_t * ds_lane + i]).cast::<f32>();
                stack_store("state", i, v);
            }
            // A = -exp(a_log[h]).
            let a_neg = 0.0f32 - exp(load(a_log[h_idx]).cast::<f32>());
            for t in range(0u32, t_total, 1u32) {
                let bt = b_idx * t_total + t;
                let bt_h = bt * $h + h_idx;
                let bt_g = bt * $g + g_idx;
                // Without a mask every step is active.
                // NOTE(review): if `select` evaluates both operands, mask[bt]
                // is read even when has_mask == 0 — confirm a large enough
                // buffer is always bound.
                let active = select(has_mask == 0u32, 1u32, load(mask[bt]));
                let dt_raw = load(dt[bt_h]).cast::<f32>();
                // Inactive steps become the identity update: decay 1 and a
                // zero input term, so the state passes through unchanged.
                let dt_eff = select(active > 0u32, dt_raw, 0.0f32);
                let d_a = select(active > 0u32, exp(a_neg * dt_raw), 1.0f32);
                // d_a depends only on (bt, h). Every d_idx thread therefore
                // stores the same value into the same slots; the writes are
                // redundant but agree.
                for i in range(0u32, $n_per_t, 1u32) {
                    store(da_log[bt_h * $ds + $n_per_t * ds_lane + i], d_a.cast::<T>());
                }
                let x_v = load(x[bt_h * $dh + d_idx]).cast::<f32>();
                let dbx_base = (bt_h * $dh + d_idx) * $ds;
                let mut y_acc = 0.0f32;
                for i in range(0u32, $n_per_t, 1u32) {
                    let s_idx = $n_per_t * ds_lane + i;
                    let b_v = load(b[bt_g * $ds + s_idx]).cast::<f32>();
                    // Input term x * dt * B, logged for replay.
                    let dbx = x_v * dt_eff * b_v;
                    store(dbx_log[dbx_base + s_idx], dbx.cast::<T>());
                    // state = dA * state + dBx. Then accumulate this thread's
                    // slice of the C · state dot product.
                    let st = d_a * stack_load("state", i) + dbx;
                    stack_store("state", i, st);
                    y_acc = y_acc + st * load(c[bt_g * $ds + s_idx]).cast::<f32>();
                }
                // NOTE(review): this yields the full C · state dot product only
                // if the DS / N_PER_T ds lanes of one (d_idx, n) make up exactly
                // one SIMD group. Confirm against the Grid3D threadgroup shape.
                let y_sum = simd_sum(y_acc);
                // A single lane writes y = C · state + D[h] * x.
                if ds_lane == 0u32 {
                    let y_d = y_sum + x_v * load(d[h_idx]).cast::<f32>();
                    store(y[bt_h * $dh + d_idx], y_d.cast::<T>());
                }
            }
            // Persist the state reached after all t_total steps.
            for i in range(0u32, $n_per_t, 1u32) {
                let st = stack_load("state", i);
                store(state_out[state_base + $n_per_t * ds_lane + i], st.cast::<T>());
            }
        }
        ssm_replay_spec!($name, $subop);
    };
}
// Defines a "replay" SSM kernel named `$name` and registers it through
// `ssm_replay_spec!` under sub-op `$subop`.
//
// Each thread starts from `state_snapshot` and re-applies the logged
// recurrence `state = da * state + dbx` for time steps 0..k_steps of its
// batch row. It then stores the result in `state_after_k`.
//
// Macro parameters (the numeric ones are `u32` literals baked into the kernel):
//   $dh      - DH: number of d indices per head
//   $ds      - DS: length of each state row
//   $h       - H: number of heads
//   $n_per_t - N_PER_T: consecutive state entries owned by each thread
//   $subop   - bench sub-op label
//
// Flat layouts implied by the index arithmetic. This assumes contiguous
// row-major buffers; NB is the number of b_idx values and T = t_total.
//   state_snapshot, state_after_k: [NB, H, DH, DS]
//   da_log: [NB, T, H, DS]      dbx_log: [NB, T, H, DH, DS]      mask: [NB, T]
macro_rules! ssm_replay {
    ($name:ident, $dh:literal, $ds:literal, $h:literal, $n_per_t:literal, $subop:literal) => {
        // The ds-lane split below assumes DS divides evenly across threads.
        // A violating configuration would otherwise skip or overrun state
        // entries, so reject it at compile time.
        const _: () = assert!(
            $ds % $n_per_t == 0,
            "ssm_replay: DS must be a multiple of N_PER_T"
        );
        #[kernel]
        pub fn $name<T>(
            state_snapshot: Tensor<T>,
            da_log: Tensor<T>,
            dbx_log: Tensor<T>,
            mask: Tensor<u32>,
            mut state_after_k: Tensor<T>,
            #[constexpr] k_steps: u32,
            #[constexpr] t_total: u32,
            #[constexpr] has_mask: u32,
        ) {
            // Grid axes: 0 = ds lane, 1 = d index, 2 = flattened (b_idx, h_idx).
            let ds_lane = program_id::<0>();
            let d_idx = program_id::<1>();
            let n = program_id::<2>();
            // h_idx = n % H, written with division only.
            let h_idx = n - (n / $h) * $h;
            let b_idx = n / $h;
            let state_base = (n * $dh + d_idx) * $ds;
            // f32 scratch for this thread's N_PER_T state entries.
            stack_alloc("state", $n_per_t, "f32");
            for i in range(0u32, $n_per_t, 1u32) {
                let v = load(state_snapshot[state_base + $n_per_t * ds_lane + i]).cast::<f32>();
                stack_store("state", i, v);
            }
            // t_total is only the per-row stride of the logs; replay stops
            // after k_steps.
            // NOTE(review): nothing here checks k_steps <= t_total. A larger
            // value would index past this row's log entries.
            for t in range(0u32, k_steps, 1u32) {
                let bt = b_idx * t_total + t;
                let bt_h = bt * $h + h_idx;
                // Without a mask every step is active.
                // NOTE(review): if `select` evaluates both operands, mask[bt]
                // is read even when has_mask == 0 — confirm a large enough
                // buffer is always bound.
                let active = select(has_mask == 0u32, 1u32, load(mask[bt]));
                let dbx_base = (bt_h * $dh + d_idx) * $ds;
                for i in range(0u32, $n_per_t, 1u32) {
                    let s_idx = $n_per_t * ds_lane + i;
                    let old = stack_load("state", i);
                    let d_a = load(da_log[bt_h * $ds + s_idx]).cast::<f32>();
                    let dbx = load(dbx_log[dbx_base + s_idx]).cast::<f32>();
                    let new_val = d_a * old + dbx;
                    // Masked-out steps keep the previous state.
                    stack_store("state", i, select(active > 0u32, new_val, old));
                }
            }
            // Persist the state reached after k_steps steps.
            for i in range(0u32, $n_per_t, 1u32) {
                let st = stack_load("state", i);
                store(state_after_k[state_base + $n_per_t * ds_lane + i], st.cast::<T>());
            }
        }
        ssm_replay_spec!($name, $subop);
    };
}
// Benchmarked configurations. Each record kernel is paired with a replay
// kernel that has the same DH / DS / H / N_PER_T. In both pairs DS / N_PER_T
// equals 32 ds lanes.
ssm_step_record! {
    ssm_step_record_d16_64_4_2,
    16u32, // DH
    64u32, // DS
    4u32,  // H
    2u32,  // G
    2u32,  // N_PER_T
    "record_d16_64_4_2"
}
ssm_replay! {
    ssm_replay_d16_64_4,
    16u32, // DH
    64u32, // DS
    4u32,  // H
    2u32,  // N_PER_T
    "replay_d16_64_4"
}
ssm_step_record! {
    ssm_step_record_d128_128_32_2,
    128u32, // DH
    128u32, // DS
    32u32,  // H
    2u32,   // G
    4u32,   // N_PER_T
    "record_d128_128_32_2"
}
ssm_replay! {
    ssm_replay_d128_128_32,
    128u32, // DH
    128u32, // DS
    32u32,  // H
    4u32,   // N_PER_T
    "replay_d128_128_32"
}