use metaltile::kernel;
use metaltile_core::ir::KernelMode;
use crate::{
bench_types::DType,
spec::{BenchDispatch, BenchSpec},
};
/// Registers one `gated_delta_replay` benchmark entry for the generated kernel `$kernel_fn`.
///
/// `$sub` becomes the entry's `subop` label. The remaining fields are shared by every kernel in
/// this file:
/// - Grid3D launch mode with generic dispatch.
/// - An F32/F16/BF16 dtype sweep with tolerance `1e-3`.
/// - No explicit shapes and no MLX source/pattern.
macro_rules! gdr_spec {
    ($kernel_fn:ident, $sub:literal) => {
        inventory::submit! {
            BenchSpec {
                // Identity of the benchmark and the kernel it measures.
                op: "gated_delta_replay",
                subop: $sub,
                kernel_name: stringify!($kernel_fn),
                kernel_ir: $kernel_fn::kernel_ir_for,
                // How the kernel is launched.
                kernel_mode: Some(KernelMode::Grid3D),
                dispatch: BenchDispatch::Generic,
                // Numeric sweep and accuracy threshold.
                dtypes: &[DType::F32, DType::F16, DType::BF16],
                tol: 1e-3,
                shapes: &[],
                // No MLX reference implementation attached.
                mlx_src: None,
                mlx_pattern: None,
            }
        }
    };
}
/// Generates a gated-delta "record" step kernel and registers it as a benchmark.
///
/// Macro arguments, in order: kernel name, `Dk`, `Dv`, `Hk`, `Hv`, `n_per_t` (state elements
/// held per lane) and the bench `subop` label.
///
/// Each kernel invocation is identified by three program ids:
/// - axis 0: lane;
/// - axis 1: Dv column;
/// - axis 2: `batch * Hv + head`.
///
/// For each of `t_val` steps, with state row `S` (length `Dk`), the kernel computes:
/// - `S = g * S`
/// - `kv = S . k`
/// - `delta = (v - kv) * beta`
/// - `S = S + k * delta`
/// - `y = S . q`
///
/// Buffers are indexed (row-major) as:
/// - `q`, `k`: `[B, T, Hk, Dk]`
/// - `v`, `y`, `delta_log`: `[B, T, Hv, Dv]`
/// - `g`, `beta`: `[B, T, Hv]`
/// - `mask`: `[B, T]`
/// - `state_in`, `state_out`: `[B*Hv, Dv, Dk]`
///
/// `delta_log` uses the same layout that `state_replay!` reads back.
///
/// When `has_mask != 0` and `mask[b, t] == 0`, the step is skipped: the state is untouched and
/// `y` / `delta_log` are not written for that step.
macro_rules! gated_delta_record {
    ($name:ident, $dk:literal, $dv:literal, $hk:literal, $hv:literal, $n_per_t:literal, $subop:literal) => {
        #[kernel]
        pub fn $name<T>(
            q: Tensor<T>,
            k: Tensor<T>,
            v: Tensor<T>,
            g: Tensor<T>,
            beta: Tensor<T>,
            state_in: Tensor<T>,
            mask: Tensor<u32>,
            mut y: Tensor<T>,
            mut state_out: Tensor<T>,
            mut delta_log: Tensor<T>,
            #[constexpr] t_val: u32,
            #[constexpr] has_mask: u32,
        ) {
            let lane = program_id::<0>();
            let dv_idx = program_id::<1>();
            let n = program_id::<2>();
            let b_idx = n / $hv;
            let hv_idx = n - b_idx * $hv;
            // Each run of Hv/Hk value heads shares one key/query head (assumes Hv % Hk == 0).
            let hk_idx = hv_idx / ($hv / $hk);
            let i_state_base = (n * $dv + dv_idx) * $dk;
            // Each lane owns `n_per_t` consecutive Dk entries of this state row.
            // NOTE(review): `simd_sum` below assumes n_per_t * (lanes per SIMD group) == Dk
            // (e.g. 6 * 32 == 192) -- confirm for new instantiations.
            stack_alloc("state", $n_per_t, "f32");
            for i in range(0u32, $n_per_t, 1u32) {
                // Named `s0` rather than `v` so it never shadows the `v` tensor parameter,
                // which is read later in the time loop.
                let s0 = load(state_in[i_state_base + $n_per_t * lane + i]).cast::<f32>();
                stack_store("state", i, s0);
            }
            for t in range(0u32, t_val, 1u32) {
                // Flattened (batch, time) row shared by every per-step index below.
                let row = b_idx * t_val + t;
                // NOTE(review): if `select` evaluates both operands, `mask` is read even when
                // has_mask == 0, so callers may need to pass a buffer of >= B*T entries anyway.
                let m = select(has_mask == 0u32, 1u32, load(mask[row]));
                if m > 0u32 {
                    let qk_base = row * $hk * $dk + hk_idx * $dk;
                    let v_base = row * $hv * $dv + hv_idx * $dv;
                    let gb_idx = row * $hv + hv_idx;
                    let g_val = load(g[gb_idx]).cast::<f32>();
                    let beta_val = load(beta[gb_idx]).cast::<f32>();
                    // Decay the state and accumulate this lane's share of S . k.
                    let mut kv_mem = 0.0f32;
                    for i in range(0u32, $n_per_t, 1u32) {
                        let s_idx = $n_per_t * lane + i;
                        let st = stack_load("state", i) * g_val;
                        stack_store("state", i, st);
                        kv_mem = kv_mem + st * load(k[qk_base + s_idx]).cast::<f32>();
                    }
                    let kv = simd_sum(kv_mem);
                    let delta = (load(v[v_base + dv_idx]).cast::<f32>() - kv) * beta_val;
                    // One lane logs the correction so the state can be replayed later.
                    if lane == 0u32 {
                        store(delta_log[v_base + dv_idx], delta.cast::<T>());
                    }
                    // Rank-1 state update, then this lane's share of the output S . q.
                    let mut out_acc = 0.0f32;
                    for i in range(0u32, $n_per_t, 1u32) {
                        let s_idx = $n_per_t * lane + i;
                        let st =
                            stack_load("state", i) + load(k[qk_base + s_idx]).cast::<f32>() * delta;
                        stack_store("state", i, st);
                        out_acc = out_acc + st * load(q[qk_base + s_idx]).cast::<f32>();
                    }
                    let out_red = simd_sum(out_acc);
                    if lane == 0u32 {
                        store(y[v_base + dv_idx], out_red.cast::<T>());
                    }
                }
            }
            // Write back this lane's slice of the final state.
            for i in range(0u32, $n_per_t, 1u32) {
                let st = stack_load("state", i);
                store(state_out[i_state_base + $n_per_t * lane + i], st.cast::<T>());
            }
        }
        gdr_spec!($name, $subop);
    };
}
/// Generates a state-replay kernel and registers it as a benchmark.
///
/// Macro arguments, in order: kernel name, `Dk`, `Dv`, `Hv`, `n_per_t` (state elements per lane)
/// and the bench `subop` label.
///
/// Starting from `state_in` (indexed `[B*Hv, Dv, Dk]`), the kernel walks `t_log` logged steps.
/// Each step that is both `t < accepted` and unmasked applies `S = S * g + k * delta`, where:
/// - `delta` comes from `delta_log` (`[B, T, Hv, Dv]`);
/// - `k` comes from `k_log` (`[B, T, Hv, Dk]`, i.e. one row per *value* head);
/// - `g` comes from `g_log` (`[B, T, Hv]`).
///
/// Steps that fail either condition leave the state unchanged. The result is written to
/// `state_out`.
macro_rules! state_replay {
    ($name:ident, $dk:literal, $dv:literal, $hv:literal, $n_per_t:literal, $subop:literal) => {
        #[kernel]
        pub fn $name<T>(
            delta_log: Tensor<T>,
            k_log: Tensor<T>,
            g_log: Tensor<T>,
            state_in: Tensor<T>,
            mask: Tensor<u32>,
            mut state_out: Tensor<T>,
            #[constexpr] t_log: u32,
            #[constexpr] accepted: u32,
            #[constexpr] has_mask: u32,
        ) {
            // Program ids: axis 0 = lane, axis 1 = Dv column, axis 2 = batch * Hv + head.
            let lane = program_id::<0>();
            let col = program_id::<1>();
            let head_row = program_id::<2>();
            let batch = head_row / $hv;
            let head = head_row - batch * $hv;
            let base = (head_row * $dv + col) * $dk;
            // First Dk entry owned by this lane; the lane covers `n_per_t` entries from here.
            let first = $n_per_t * lane;
            stack_alloc("state", $n_per_t, "f32");
            for j in range(0u32, $n_per_t, 1u32) {
                let init = load(state_in[base + first + j]).cast::<f32>();
                stack_store("state", j, init);
            }
            for t in range(0u32, t_log, 1u32) {
                let row = batch * t_log + t;
                let d_val = load(delta_log[row * $hv * $dv + head * $dv + col]).cast::<f32>();
                let g_val = load(g_log[row * $hv + head]).cast::<f32>();
                let k_row = row * $hv * $dk + head * $dk;
                // A step counts only if it was accepted and (when masking) not masked out.
                let mask_v = select(has_mask == 0u32, 1u32, load(mask[row]));
                let do_step = select(t < accepted, mask_v, 0u32);
                for j in range(0u32, $n_per_t, 1u32) {
                    let prev = stack_load("state", j);
                    let k_val = load(k_log[k_row + first + j]).cast::<f32>();
                    let updated = prev * g_val + k_val * d_val;
                    stack_store("state", j, select(do_step > 0u32, updated, prev));
                }
            }
            for j in range(0u32, $n_per_t, 1u32) {
                let fin = stack_load("state", j);
                store(state_out[base + first + j], fin.cast::<T>());
            }
        }
        gdr_spec!($name, $subop);
    };
}
// Dk = 192, Dv = 128, Hk = Hv = 4, with 6 state elements per lane.
// NOTE(review): n_per_t * 32 == Dk here, presumably one 32-wide SIMD group per Dk row -- confirm.
gated_delta_record!(
    gated_delta_step_record_d192_128_4_4,
    192u32, // Dk
    128u32, // Dv
    4u32,   // Hk
    4u32,   // Hv
    6u32,   // n_per_t
    "record_d192_128_4_4"
);
state_replay!(
    state_replay_d192_128_4_4,
    192u32, // Dk
    128u32, // Dv
    4u32,   // Hv
    6u32,   // n_per_t
    "replay_d192_128_4_4"
);
// Dk = 64, Dv = 32, Hk = Hv = 2, with 2 state elements per lane.
// NOTE(review): n_per_t * 32 == Dk here, presumably one 32-wide SIMD group per Dk row -- confirm.
gated_delta_record!(
    gated_delta_step_record_d64_32_2_2,
    64u32, // Dk
    32u32, // Dv
    2u32,  // Hk
    2u32,  // Hv
    2u32,  // n_per_t
    "record_d64_32_2_2"
);
state_replay!(
    state_replay_d64_32_2_2,
    64u32, // Dk
    32u32, // Dv
    2u32,  // Hv
    2u32,  // n_per_t
    "replay_d64_32_2_2"
);