use metaltile::{bench_kernel, kernel};
// Fused single-step kernel: RMS-normalises the q/k slices taken from
// `conv_out`, derives a decay gate and a beta gate from the raw gate inputs,
// decays one row of the recurrent state, applies a delta-rule correction along
// k, writes the updated row to `state_out`, and writes the dot product of the
// updated row with the normalised q to `y`.
//
// Buffers (layouts as implied by the indexing below):
//   conv_out       per `b`: [hk*dk q values][hk*dk k values][hv*dv v values]
//   a_log, dt_bias indexed by `hv_idx`
//   a_raw, b_raw   indexed by `n`
//   q_norm_weight, k_norm_weight  indexed by `hk_idx * dk + element`
//   state_in/out   indexed by `n * dv * dk + dv_idx * dk + element`
//   y              indexed by `n * dv + dv_idx`
//
// All arithmetic is done in f32; values are cast from `T` on load and back to
// `T` on store.
//
// NOTE(review): `tgid_x`, `tgid_y`, `tid`, `load`, `store`, `stack_*`,
// `simd_sum`, `rsqrt`, `exp`, `log` and `range` are not declared in this file;
// presumably they are supplied by the `#[kernel]` DSL — confirm against the
// metaltile documentation.
#[bench_kernel(
op="gated_delta",
subop="prep_step",
class=GenericEmpty,
tol=0.0,
kernel_mode=Reduction,
)]
#[kernel]
pub fn mt_gated_delta_prep_step<T>(
conv_out: Tensor<T>, a_log: Tensor<T>, dt_bias: Tensor<T>, a_raw: Tensor<T>, b_raw: Tensor<T>, q_norm_weight: Tensor<T>, k_norm_weight: Tensor<T>, state_in: Tensor<T>, mut state_out: Tensor<T>, mut y: Tensor<T>, #[constexpr] dk: u32,
#[constexpr] dv: u32,
#[constexpr] hv: u32,
#[constexpr] hk: u32,
) {
// Work decomposition: x selects the dv row, y selects the flattened (b, hv)
// index, and the thread index selects a contiguous chunk of the dk axis.
// Presumably tgid_* are threadgroup ids and tid is the thread index within
// the group — TODO confirm.
let dv_idx = tgid_x;
let n = tgid_y;
let dk_idx = tid;
// Equivalent to `n % hv` for unsigned integer division.
let hv_idx = n - (n / hv) * hv;
let b = n / hv;
// hv / hk: how many hv-indexed heads map onto each hk-indexed head (integer
// division). NOTE(review): the name reads the other way round.
let hk_per_hv = hv / hk;
let hk_idx = hv_idx / hk_per_hv;
// Number of dk elements handled by each thread. The divisor 32 presumably
// matches the SIMD-group width — TODO confirm. The scratch arrays below are
// allocated with 8 slots, so this must not exceed 8.
let n_per_t = dk / 32u32;
// Offsets into conv_out for this b: q block, then k block, then v block.
let stride_b = 2u32 * hk * dk + hv * dv;
let conv_base = b * stride_b;
let q_off = conv_base + hk_idx * dk;
let k_off = conv_base + hk * dk + hk_idx * dk;
let v_off = conv_base + 2u32 * hk * dk + hv_idx * dv;
// Added inside the rsqrt below to keep the RMS-norm denominator non-zero.
let eps = 0.000001f32;
let dk_f = dk.cast::<f32>();
// Scratch arrays of 8 f32 each (presumably per-thread — TODO confirm) for the
// raw q/k values and their norm weights.
stack_alloc("q_raw", 8u32, "f32");
stack_alloc("k_raw", 8u32, "f32");
stack_alloc("q_w", 8u32, "f32");
stack_alloc("k_w", 8u32, "f32");
let mut q_ssq = 0.0f32;
let mut k_ssq = 0.0f32;
// Pass 1: load this thread's q/k chunk, accumulate partial sums of squares,
// and cache the raw values and the matching norm weights.
for i in range(0u32, n_per_t, 1u32) {
let s_idx = n_per_t * dk_idx + i;
let qv = load(conv_out[q_off + s_idx]).cast::<f32>();
let kv = load(conv_out[k_off + s_idx]).cast::<f32>();
stack_store("q_raw", i, qv);
stack_store("k_raw", i, kv);
q_ssq = q_ssq + qv * qv;
k_ssq = k_ssq + kv * kv;
let qw = load(q_norm_weight[hk_idx * dk + s_idx]).cast::<f32>();
let kw = load(k_norm_weight[hk_idx * dk + s_idx]).cast::<f32>();
stack_store("q_w", i, qw);
stack_store("k_w", i, kw);
}
// Combine the partial sums (presumably across the SIMD group — TODO confirm)
// and form the reciprocal RMS: 1 / sqrt(mean(x^2) + eps).
let q_ssq_sum = simd_sum(q_ssq);
let k_ssq_sum = simd_sum(k_ssq);
let q_inv = rsqrt(q_ssq_sum / dk_f + eps);
let k_inv = rsqrt(k_ssq_sum / dk_f + eps);
let a_log_val = load(a_log[hv_idx]).cast::<f32>();
let dt_bias_val = load(dt_bias[hv_idx]).cast::<f32>();
let a_raw_val = load(a_raw[n]).cast::<f32>();
let b_raw_val = load(b_raw[n]).cast::<f32>();
let pre_softplus = a_raw_val + dt_bias_val;
// softplus(x) = ln(1 + e^x).
// NOTE(review): exp() overflows f32 for large positive inputs (roughly > 88),
// which would make dt_val infinite — confirm the expected input range.
let dt_val = log(exp(pre_softplus) + 1.0f32);
// Decay gate: exp(-exp(a_log) * dt).
let g_val = exp(0.0f32 - exp(a_log_val) * dt_val);
// Beta gate: logistic sigmoid of b_raw.
let beta_val = 1.0f32 / (1.0f32 + exp(0.0f32 - b_raw_val));
let v_val = load(conv_out[v_off + dv_idx]).cast::<f32>();
// Start of the dk-long state row for this (n, dv_idx).
let state_base = n * dv * dk + dv_idx * dk;
stack_alloc("decayed", 8u32, "f32");
stack_alloc("k_cache", 8u32, "f32");
let mut kv_mem = 0.0f32;
// Pass 2: decay the state row, normalise k, and accumulate the partial dot
// product of the decayed row with the normalised k. Both are cached for pass 3.
for i in range(0u32, n_per_t, 1u32) {
let s_idx = n_per_t * dk_idx + i;
let s_decayed = load(state_in[state_base + s_idx]).cast::<f32>() * g_val;
let k_normed = stack_load("k_raw", i) * k_inv * stack_load("k_w", i);
stack_store("decayed", i, s_decayed);
stack_store("k_cache", i, k_normed);
kv_mem = kv_mem + s_decayed * k_normed;
}
let kv_mem_sum = simd_sum(kv_mem);
// Delta-rule correction: difference between v and the reduced dot product,
// scaled by beta.
let delta = (v_val - kv_mem_sum) * beta_val;
let mut out_acc = 0.0f32;
// Pass 3: apply the correction along k, store the new state row, and
// accumulate the partial dot product of the new row with the normalised q.
for i in range(0u32, n_per_t, 1u32) {
let s_idx = n_per_t * dk_idx + i;
let s_decayed = stack_load("decayed", i);
let k_normed = stack_load("k_cache", i);
let s_new = s_decayed + k_normed * delta;
store(state_out[state_base + s_idx], s_new.cast::<T>());
let q_normed = stack_load("q_raw", i) * q_inv * stack_load("q_w", i);
out_acc = out_acc + s_new * q_normed;
}
let out_sum = simd_sum(out_acc);
// Only the thread with dk_idx == 0 writes the reduced output element.
if dk_idx == 0u32 {
store(y[n * dv + dv_idx], out_sum.cast::<T>());
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bench_types::DType;
    use metaltile_core::ir::KernelMode;

    /// Builds the kernel IR for `DType::F32`, sets its mode to
    /// `KernelMode::Reduction`, runs the MSL generator on it and prints the
    /// generated source between BEGIN/END markers.
    ///
    /// Panics with the message "codegen" if generation returns an error.
    #[test]
    fn dump() {
        use metaltile_codegen::msl::MslGenerator;

        // Set the mode explicitly before handing the IR to the generator.
        let mut kernel_ir = mt_gated_delta_prep_step::kernel_ir_for(DType::F32);
        kernel_ir.mode = KernelMode::Reduction;

        let generated = MslGenerator::default()
            .generate(&kernel_ir)
            .expect("codegen");

        println!("===== BEGIN MSL =====\n{}\n===== END MSL =====", generated);
    }
}