#![allow(clippy::too_many_arguments)]
use metaltile::{bench_kernel, kernel};
// Chunked gated-delta kernel ("wy_chunk" sub-op of "gated_delta").
//
// Each threadgroup index `tgid_y` owns one dv x dk state matrix. The kernel
// loads that state, walks the sequence in chunks of `c` timesteps, writes one
// `y` row of length `dv` per processed timestep, and finally writes the
// updated state to `state_out`. All intermediate arithmetic is done in f32;
// inputs are cast from `T` on load and results cast back to `T` on store.
//
// Parameters (flat index expressions as used in the body):
//   q, k      : read at (t * hk + hk_idx) * dk + d
//   v         : read at (t * hv + hv_idx) * dv + d
//   g, beta   : read at  t * hv + hv_idx
//   state_in  : read at  tgid_y * dv * dk + (d_v * dk + d_k)
//   state_out : written at the same offsets as state_in
//   y         : written at (t * hv + hv_idx) * dv + d_v
//   dk, dv    : key / value head dimensions
//   hv, hk    : number of value / key heads; hv / hk value heads share a key head
//   c         : chunk length in timesteps
//   t_len     : sequence length in timesteps
//
// Capacity limits implied by the fixed scratch sizes below (assuming the size
// argument of the alloc calls is an element count -- TODO confirm):
//   c <= 16, c * dk <= 512, c * dv <= 512, dv * dk <= 1024.
// Nothing in the kernel checks these.
//
// NOTE(review): `b_idx` is computed but never used, and the q/k/v/g/beta/y
//   offsets contain no batch term, whereas `state_base` is derived from the
//   full `tgid_y`. If more than one batch entry is launched, every batch would
//   read the same q/k/v/g/beta and write the same `y` rows -- confirm intent.
// NOTE(review): `t_len / c` truncates; the trailing `t_len % c` timesteps are
//   not processed and their `y` rows are never written.
// NOTE(review): `hv / hk` is an integer division; presumably hv is expected to
//   be a multiple of hk -- confirm against callers.
// NOTE(review): all work is strided by 32 over `simd_lane`; this presumably
//   expects one 32-wide simdgroup per threadgroup -- confirm the launch config.
#[bench_kernel(
op="gated_delta",
subop="wy_chunk",
class=GenericEmpty,
tol=0.0,
kernel_mode=Reduction,
)]
#[kernel]
pub fn mt_gated_delta_wy_chunk<T>(
q: Tensor<T>,
k: Tensor<T>,
v: Tensor<T>,
g: Tensor<T>,
beta: Tensor<T>,
state_in: Tensor<T>,
mut state_out: Tensor<T>,
mut y: Tensor<T>,
#[constexpr] dk: u32,
#[constexpr] dv: u32,
#[constexpr] hv: u32,
#[constexpr] hk: u32,
#[constexpr] c: u32,
#[constexpr] t_len: u32,
) {
// Decompose the threadgroup index into (b_idx, hv_idx) and map the value head
// to the key head it shares. `b_idx` is not referenced again below.
let n = tgid_y; let b_idx = n / hv;
let hv_idx = n % hv;
let hv_per_hk = hv / hk;
let hk_idx = hv_idx / hv_per_hk;
let lane = simd_lane;
// Scratch buffers. tg_state holds the running state as [d_v * dk + d_k];
// new_state is a stack buffer used to stage the updated state per lane.
// NOTE(review): stack_alloc takes its element type as a string literal while
//   threadgroup_alloc takes a bare identifier -- confirm this is the DSL form.
threadgroup_alloc("tg_state", 1024u32, f32); stack_alloc("new_state", 128u32, "f32");
// Per-chunk copies of q and k, indexed [i * dk + d].
threadgroup_alloc("tg_q", 512u32, f32); threadgroup_alloc("tg_k", 512u32, f32);
// Per-chunk copy of v, indexed [i * dv + d], and of g, indexed [i].
threadgroup_alloc("tg_v", 512u32, f32); threadgroup_alloc("tg_g", 16u32, f32);
threadgroup_alloc("tg_beta", 16u32, f32);
// tg_bigG[i] = running product of g over chunk positions 0..=i.
threadgroup_alloc("tg_bigG", 16u32, f32);
// tg_kkt[i * c + j] = dot(k_i, k_j); tg_p[t * dk + d] = recurrence output below.
threadgroup_alloc("tg_kkt", 256u32, f32); threadgroup_alloc("tg_p", 512u32, f32);
// tg_uv[t * dv + d] = recurrence output below.
threadgroup_alloc("tg_uv", 512u32, f32);
// tg_qkt[t * c + j] = dot(q_t, k_j).
threadgroup_alloc("tg_qkt", 256u32, f32);
// tg_s0p[d_v * c + i] = dot(state row d_v, p_i).
threadgroup_alloc("tg_s0p", 512u32, f32); let state_base = n * dv * dk;
let total_state = dv * dk;
// Load this threadgroup's incoming state into tg_state as f32.
for ii in range(lane, total_state, 32u32) {
let v_in = load(state_in[state_base + ii]).cast::<f32>();
threadgroup_store("tg_state", ii, v_in);
}
threadgroup_barrier();
// Only whole chunks are processed (integer division).
let num_chunks = t_len / c;
for chunk_idx in range(0u32, num_chunks, 1u32) {
let chunk_start = chunk_idx * c;
// Stage the chunk's q, k, v, g and beta into scratch as f32.
for i in range(0u32, c, 1u32) {
let t_abs = chunk_start + i;
for d in range(lane, dk, 32u32) {
let qkv_off = (t_abs * hk + hk_idx) * dk + d;
threadgroup_store("tg_q", i * dk + d, load(q[qkv_off]).cast::<f32>());
threadgroup_store("tg_k", i * dk + d, load(k[qkv_off]).cast::<f32>());
}
for d in range(lane, dv, 32u32) {
let v_off = (t_abs * hv + hv_idx) * dv + d;
threadgroup_store("tg_v", i * dv + d, load(v[v_off]).cast::<f32>());
}
// One scalar g and beta per timestep; only lane 0 copies them.
if lane == 0u32 {
let gb_off = t_abs * hv + hv_idx;
threadgroup_store("tg_g", i, load(g[gb_off]).cast::<f32>());
threadgroup_store("tg_beta", i, load(beta[gb_off]).cast::<f32>());
}
}
threadgroup_barrier();
// Lane 0 computes the running product of g within the chunk:
// tg_bigG[i] = g[0] * g[1] * ... * g[i].
if lane == 0u32 {
let mut g_acc = 1.0f32;
for i in range(0u32, c, 1u32) {
g_acc = g_acc * threadgroup_load("tg_g", i);
threadgroup_store("tg_bigG", i, g_acc);
}
}
threadgroup_barrier();
// tg_kkt[i, j] = sum over d of k[i, d] * k[j, d], for all c * c pairs.
for ij in range(lane, c * c, 32u32) {
let i = ij / c;
let j = ij % c;
let mut s = 0.0f32;
for d in range(0u32, dk, 1u32) {
let ki = threadgroup_load("tg_k", i * dk + d);
let kj = threadgroup_load("tg_k", j * dk + d);
s = s + ki * kj;
}
threadgroup_store("tg_kkt", i * c + j, s);
}
threadgroup_barrier();
// Sequential recurrence over t (row t depends on rows j < t):
//   p[t, d] = k[t, d] - sum_{j < t} beta[j] * kkt[t, j] * p[j, d]
// A barrier after each t publishes row t before row t + 1 reads it.
for t in range(0u32, c, 1u32) {
for d in range(lane, dk, 32u32) {
let mut accum = threadgroup_load("tg_k", t * dk + d);
for j in range(0u32, t, 1u32) {
let beta_j = threadgroup_load("tg_beta", j);
let kkt_tj = threadgroup_load("tg_kkt", t * c + j);
let p_jd = threadgroup_load("tg_p", j * dk + d);
accum = accum - beta_j * kkt_tj * p_jd;
}
threadgroup_store("tg_p", t * dk + d, accum);
}
threadgroup_barrier();
}
// Sequential recurrence over t (row t depends on rows j < t):
//   uv[t, d] = beta[t] * v[t, d]
//            - sum_{j < t} beta[t] * (bigG[t] / bigG[j]) * kkt[t, j] * uv[j, d]
// NOTE(review): bigG[j] is a divisor here and in later stages; a zero product
//   of gates would divide by zero -- confirm g is guaranteed non-zero.
for t in range(0u32, c, 1u32) {
let beta_t = threadgroup_load("tg_beta", t);
let big_g_t = threadgroup_load("tg_bigG", t);
for d in range(lane, dv, 32u32) {
let v_td = threadgroup_load("tg_v", t * dv + d);
let mut accum = beta_t * v_td;
for j in range(0u32, t, 1u32) {
let big_g_j = threadgroup_load("tg_bigG", j);
let gamma_tj = big_g_t / big_g_j;
let kkt_tj = threadgroup_load("tg_kkt", t * c + j);
let a_tj = beta_t * gamma_tj * kkt_tj;
let uv_jd = threadgroup_load("tg_uv", j * dv + d);
accum = accum - a_tj * uv_jd;
}
threadgroup_store("tg_uv", t * dv + d, accum);
}
threadgroup_barrier();
}
// tg_qkt[t, j] = sum over d of q[t, d] * k[j, d], for all c * c pairs.
for tj in range(lane, c * c, 32u32) {
let t = tj / c;
let j = tj % c;
let mut s = 0.0f32;
for d in range(0u32, dk, 1u32) {
let qt = threadgroup_load("tg_q", t * dk + d);
let kj = threadgroup_load("tg_k", j * dk + d);
s = s + qt * kj;
}
threadgroup_store("tg_qkt", t * c + j, s);
}
threadgroup_barrier();
// tg_s0p[d_v, i] = sum over d of state[d_v, d] * p[i, d], using the state as
// it was at the start of this chunk.
for vi in range(lane, dv * c, 32u32) {
let d_v = vi / c;
let i = vi % c;
let mut acc = 0.0f32;
for d in range(0u32, dk, 1u32) {
let st = threadgroup_load("tg_state", d_v * dk + d);
let pi = threadgroup_load("tg_p", i * dk + d);
acc = acc + st * pi;
}
threadgroup_store("tg_s0p", d_v * c + i, acc);
}
threadgroup_barrier();
// Output for every (t, d_v) in the chunk:
//   y[t, d_v] = bigG[t] * (s0q - corr) + y_loc
// where
//   y_loc = sum_{j <= t} (bigG[t] / bigG[j]) * qkt[t, j] * uv[j, d_v]
//   s0q   = sum_d state[d_v, d] * q[t, d]
//   corr  = sum_{i <= t} beta[i] * qkt[t, i] * s0p[d_v, i]
for tdv in range(lane, c * dv, 32u32) {
let t = tdv / dv;
let d_v = tdv % dv;
let big_g_t = threadgroup_load("tg_bigG", t);
let mut y_loc = 0.0f32;
for j in range(0u32, t + 1u32, 1u32) {
let big_g_j = threadgroup_load("tg_bigG", j);
let gamma_tj = big_g_t / big_g_j;
let qkt_tj = threadgroup_load("tg_qkt", t * c + j);
let uv_jd = threadgroup_load("tg_uv", j * dv + d_v);
y_loc = y_loc + gamma_tj * qkt_tj * uv_jd;
}
let mut s0q = 0.0f32;
for d in range(0u32, dk, 1u32) {
let st = threadgroup_load("tg_state", d_v * dk + d);
let qt = threadgroup_load("tg_q", t * dk + d);
s0q = s0q + st * qt;
}
let mut corr = 0.0f32;
for i in range(0u32, t + 1u32, 1u32) {
let beta_i = threadgroup_load("tg_beta", i);
let qkt_ti = threadgroup_load("tg_qkt", t * c + i);
let s0p_vi = threadgroup_load("tg_s0p", d_v * c + i);
corr = corr + beta_i * qkt_ti * s0p_vi;
}
let y_pass = big_g_t * (s0q - corr);
let t_abs = chunk_start + t;
let y_off = (t_abs * hv + hv_idx) * dv + d_v;
store(y[y_off], (y_pass + y_loc).cast::<T>());
}
threadgroup_barrier();
// State update for the end of the chunk, for every (d_v, d_k):
//   new[d_v, d_k] = bigG[c-1] * (old[d_v, d_k]
//                               - sum_i beta[i] * p[i, d_k] * s0p[d_v, i])
//                 + sum_j (bigG[c-1] / bigG[j]) * uv[j, d_v] * k[j, d_k]
// Results are staged in the stack buffer (slot = this lane's visit count)
// rather than written straight to tg_state; the barrier below separates these
// tg_state reads from the overwrite that follows.
let big_g_c = threadgroup_load("tg_bigG", c - 1u32);
let mut iter_idx = 0u32;
for vd in range(lane, dv * dk, 32u32) {
let d_v = vd / dk;
let d_k = vd % dk;
let s0_old = threadgroup_load("tg_state", d_v * dk + d_k);
let mut s_corr = 0.0f32;
for i in range(0u32, c, 1u32) {
let beta_i = threadgroup_load("tg_beta", i);
let p_ik = threadgroup_load("tg_p", i * dk + d_k);
let s0p_vi = threadgroup_load("tg_s0p", d_v * c + i);
s_corr = s_corr + beta_i * p_ik * s0p_vi;
}
let s_through = big_g_c * (s0_old - s_corr);
let mut u_end = 0.0f32;
for j in range(0u32, c, 1u32) {
let big_g_j = threadgroup_load("tg_bigG", j);
let rw = big_g_c / big_g_j;
let uv_jv = threadgroup_load("tg_uv", j * dv + d_v);
let k_jd = threadgroup_load("tg_k", j * dk + d_k);
u_end = u_end + rw * uv_jv * k_jd;
}
stack_store("new_state", iter_idx, s_through + u_end);
iter_idx = iter_idx + 1u32;
}
threadgroup_barrier();
// Copy the staged values back into tg_state. The loop bounds match the staging
// loop above, so flush_idx walks the stack slots in the order they were written.
let mut flush_idx = 0u32;
for vd in range(lane, dv * dk, 32u32) {
threadgroup_store("tg_state", vd, stack_load("new_state", flush_idx));
flush_idx = flush_idx + 1u32;
}
threadgroup_barrier();
}
// Write the final state back out, cast from f32 to T.
for ii in range(lane, total_state, 32u32) {
let s = threadgroup_load("tg_state", ii);
store(state_out[state_base + ii], s.cast::<T>());
}
}