use metaltile::{bench_kernel, kernel};
// Single-step gated delta-rule update: advances one recurrent state row by
// one token and produces the matching output element.
//
// Launch geometry, as read from the builtins in the body:
//   tgid_x -> dv_idx : row of the state / output element along the value dim
//   tgid_y -> n      : flattened (batch, value-head) index, n = b * hv + hv_idx
//   tid    -> dk_idx : thread index; thread dk_idx owns the contiguous key-dim
//                      slice [dk_idx * dk/32, (dk_idx + 1) * dk/32), and the
//                      per-thread partial sums are combined with simd_sum.
//                      NOTE(review): this assumes exactly 32 threads per
//                      threadgroup forming one SIMD group — confirm against
//                      the launcher.
//
// Math per (n, dv_idx), with s = state_in[n, dv_idx, :] (accumulated in f32):
//   s     = s * g[n]                          (decay)
//   kv    = sum_j s[j] * k[j]
//   delta = (v[n, dv_idx] - kv) * beta[n]
//   s     = s + k * delta                     -> state_out[n, dv_idx, :]
//   y[n, dv_idx] = sum_j s[j] * q[j]
//
// Flat index layouts used below:
//   q, k               : (b * hk + hk_idx) * dk + j    i.e. [batch, hk, dk]
//   v, y               : n * dv + dv_idx               i.e. [batch * hv, dv]
//   g, beta            : n                             i.e. [batch * hv]
//   state_in/state_out : (n * dv + dv_idx) * dk + j    i.e. [batch * hv, dv, dk]
// Value head hv_idx reads key head hv_idx / (hv / hk), so each consecutive
// group of hv / hk value heads shares one q/k head.
//
// Preconditions the code relies on but does not check:
//   - dk % 32 == 0; otherwise n_per_t truncates and the trailing dk % 32
//     key elements are never read or updated.
//   - dk / 32 <= 8 (i.e. dk <= 256 given the above); the scratch buffers are
//     allocated with 8 slots, and a larger dk would index past them.
//   - hk <= hv and hv % hk == 0; if hk > hv then hv / hk is 0 and the
//     head-mapping division divides by zero.
#[bench_kernel(
op="gated_delta",
subop="step",
class=GenericEmpty,
tol=0.0,
kernel_mode=Reduction,
)]
#[kernel]
pub fn mt_gated_delta_step<T>(
q: Tensor<T>, k: Tensor<T>, v: Tensor<T>, g: Tensor<T>, beta: Tensor<T>, state_in: Tensor<T>, mut state_out: Tensor<T>, mut y: Tensor<T>, #[constexpr] dk: u32,
#[constexpr] dv: u32,
#[constexpr] hv: u32,
#[constexpr] hk: u32,
) {
// Work-item coordinates (see the launch-geometry notes above).
let dv_idx = tgid_x;
let n = tgid_y;
let dk_idx = tid;
// Equivalent to n % hv: the value-head index within batch entry b.
let hv_idx = n - (n / hv) * hv;
let b = n / hv;
// NOTE(review): despite the name, this holds hv / hk, i.e. the number of
// value heads that share each key head.
let hk_per_hv = hv / hk;
let hk_idx = hv_idx / hk_per_hv;
// Number of key-dim elements handled by each of the 32 threads.
let n_per_t = dk / 32u32;
let g_val = load(g[n]).cast::<f32>();
let beta_val = load(beta[n]).cast::<f32>();
let v_val = load(v[n * dv + dv_idx]).cast::<f32>();
// Start of this head's q/k vector and of this (n, dv_idx) state row.
let qk_base = (b * hk + hk_idx) * dk;
let state_base = n * dv * dk + dv_idx * dk;
// Scratch arrays of 8 f32 slots (presumably private to each thread —
// confirm stack_alloc semantics). They cache the decayed state slice and
// the k slice so pass 2 does not reload them. Sized for n_per_t <= 8.
stack_alloc("decayed", 8u32, "f32");
stack_alloc("k_cache", 8u32, "f32");
let mut kv_mem = 0.0f32;
// Pass 1: decay this thread's slice of the state row and accumulate its
// partial dot product with k.
for i in range(0u32, n_per_t, 1u32) {
let s_idx = n_per_t * dk_idx + i;
let s_decayed = load(state_in[state_base + s_idx]).cast::<f32>() * g_val;
let k_val = load(k[qk_base + s_idx]).cast::<f32>();
stack_store("decayed", i, s_decayed);
stack_store("k_cache", i, k_val);
kv_mem = kv_mem + s_decayed * k_val;
}
// Combine the per-thread partial sums into the full (decayed state) . k.
let kv_mem_sum = simd_sum(kv_mem);
// Delta-rule correction for this output element, scaled by beta.
let delta = (v_val - kv_mem_sum) * beta_val;
let mut out = 0.0f32;
// Pass 2: apply the rank-1 update s += k * delta, write the new state, and
// accumulate the partial dot product of the new state with q.
for i in range(0u32, n_per_t, 1u32) {
let s_idx = n_per_t * dk_idx + i;
let s_decayed = stack_load("decayed", i);
let k_val = stack_load("k_cache", i);
let s_new = s_decayed + k_val * delta;
store(state_out[state_base + s_idx], s_new.cast::<T>());
let q_val = load(q[qk_base + s_idx]).cast::<f32>();
out = out + s_new * q_val;
}
let out_sum = simd_sum(out);
// A single thread stores the reduced output element.
if dk_idx == 0u32 {
store(y[n * dv + dv_idx], out_sum.cast::<T>());
}
}
// Multi-token gated delta-rule recurrence. It applies the same per-token
// update as mt_gated_delta_step for t = 0 .. t_len[0], in order. The state
// row slice is kept in scratch for the whole sequence: it is loaded from
// state_in once and written to state_out once, after the last timestep.
//
// Launch geometry, as read from the builtins in the body:
//   tgid_x -> dv_idx : row of the state / output element along the value dim
//   tgid_y -> n      : flattened (batch, value-head) index, n = b * hv + hv_idx
//   tid    -> dk_idx : thread index; thread dk_idx owns the contiguous key-dim
//                      slice [dk_idx * dk/32, (dk_idx + 1) * dk/32), and the
//                      per-thread partial sums are combined with simd_sum.
//                      NOTE(review): this assumes exactly 32 threads per
//                      threadgroup forming one SIMD group — confirm against
//                      the launcher.
//
// Math per timestep t, with T = t_len[0] and bt = b * T + t (in f32):
//   s     = s * g[bt, hv_idx]
//   kv    = sum_j s[j] * k[bt, hk_idx, j]
//   delta = (v[bt, hv_idx, dv_idx] - kv) * beta[bt, hv_idx]
//   s     = s + k[bt, hk_idx, :] * delta
//   y[bt, hv_idx, dv_idx] = sum_j s[j] * q[bt, hk_idx, j]
//
// Flat index layouts used below:
//   q, k               : (bt * hk + hk_idx) * dk + j       i.e. [batch, T, hk, dk]
//   v, y               : (bt * hv + hv_idx) * dv + dv_idx  i.e. [batch, T, hv, dv]
//   g, beta            : bt * hv + hv_idx                  i.e. [batch, T, hv]
//   state_in/state_out : (n * dv + dv_idx) * dk + j        i.e. [batch * hv, dv, dk]
// Value head hv_idx reads key head hv_idx / (hv / hk).
//
// T is read from t_len[0], so every batch entry uses the same length. When
// T == 0 the time loop does not run: y is not written, and state_out gets
// state_in converted to f32 and back to T.
//
// Preconditions the code relies on but does not check:
//   - dk % 32 == 0; otherwise the trailing dk % 32 key elements are skipped.
//   - dk / 32 <= 8 (dk <= 256); the scratch buffers are allocated with 8
//     slots, and a larger dk would index past them.
//   - hk <= hv and hv % hk == 0; if hk > hv then hv / hk is 0 and the
//     head-mapping division divides by zero.
#[bench_kernel(
op="gated_delta",
subop="chunk",
class=GenericEmpty,
tol=0.0,
kernel_mode=Reduction,
)]
#[kernel]
pub fn mt_gated_delta_chunk<T>(
q: Tensor<T>, k: Tensor<T>, v: Tensor<T>, g: Tensor<T>, beta: Tensor<T>, state_in: Tensor<T>, mut state_out: Tensor<T>, mut y: Tensor<T>, t_len: Tensor<u32>, #[constexpr] dk: u32,
#[constexpr] dv: u32,
#[constexpr] hv: u32,
#[constexpr] hk: u32,
) {
// Work-item coordinates (see the launch-geometry notes above).
let dv_idx = tgid_x;
let n = tgid_y;
let dk_idx = tid;
// Equivalent to n % hv: the value-head index within batch entry b.
let hv_idx = n - (n / hv) * hv;
let b = n / hv;
// NOTE(review): despite the name, this holds hv / hk, i.e. the number of
// value heads that share each key head.
let hk_per_hv = hv / hk;
let hk_idx = hv_idx / hk_per_hv;
// Number of key-dim elements handled by each of the 32 threads.
let n_per_t = dk / 32u32;
// Sequence length, read from device memory at run time.
let t_total = load(t_len[0]);
let state_base = n * dv * dk + dv_idx * dk;
// Scratch arrays of 8 f32 slots (presumably private to each thread —
// confirm stack_alloc semantics). "state_reg" carries this thread's state
// slice across timesteps; "k_cache" holds the current timestep's k slice.
// Sized for n_per_t <= 8.
stack_alloc("state_reg", 8u32, "f32");
stack_alloc("k_cache", 8u32, "f32");
// Load this thread's slice of the state row once, before the time loop.
for i in range(0u32, n_per_t, 1u32) {
let s_idx = n_per_t * dk_idx + i;
let val = load(state_in[state_base + s_idx]).cast::<f32>();
stack_store("state_reg", i, val);
}
// Sequential recurrence over timesteps; each step depends on the previous
// step's state.
for t in range(0u32, t_total, 1u32) {
// Flattened (batch, time) index and the per-timestep base offsets.
let bt = b * t_total + t;
let qk_base = (bt * hk + hk_idx) * dk;
let vy_base = (bt * hv + hv_idx) * dv;
let gbeta_idx = bt * hv + hv_idx;
let g_val = load(g[gbeta_idx]).cast::<f32>();
let beta_val = load(beta[gbeta_idx]).cast::<f32>();
let v_val = load(v[vy_base + dv_idx]).cast::<f32>();
let mut kv_mem = 0.0f32;
// Pass 1: decay the state slice in place and accumulate its partial dot
// product with k.
for i in range(0u32, n_per_t, 1u32) {
let s_idx = n_per_t * dk_idx + i;
let s_old = stack_load("state_reg", i);
let s_decayed = s_old * g_val;
stack_store("state_reg", i, s_decayed);
let k_val = load(k[qk_base + s_idx]).cast::<f32>();
stack_store("k_cache", i, k_val);
kv_mem = kv_mem + s_decayed * k_val;
}
// Combine the per-thread partial sums into the full (decayed state) . k.
let kv_mem_sum = simd_sum(kv_mem);
// Delta-rule correction for this output element, scaled by beta.
let delta = (v_val - kv_mem_sum) * beta_val;
let mut out = 0.0f32;
// Pass 2: apply the rank-1 update s += k * delta in place and accumulate
// the partial dot product of the new state with q. state_out is not
// written until after the time loop.
for i in range(0u32, n_per_t, 1u32) {
let s_idx = n_per_t * dk_idx + i;
let s_decayed = stack_load("state_reg", i);
let k_val = stack_load("k_cache", i);
let s_new = s_decayed + k_val * delta;
stack_store("state_reg", i, s_new);
let q_val = load(q[qk_base + s_idx]).cast::<f32>();
out = out + s_new * q_val;
}
let out_sum = simd_sum(out);
// A single thread stores the reduced output element for this timestep.
if dk_idx == 0u32 {
store(y[vy_base + dv_idx], out_sum.cast::<T>());
}
}
// Write the final state slice back once, converted to T.
for i in range(0u32, n_per_t, 1u32) {
let s_idx = n_per_t * dk_idx + i;
store(state_out[state_base + s_idx], stack_load("state_reg", i).cast::<T>());
}
}