use metaltile::{bench_kernel, kernel};
// Gated RMS-norm kernel. For every processed element of a row it computes
//
//     out = y * rsqrt(sum(y^2) / n + eps) * w * silu(z)
//
// where silu(z) = z / (1 + exp(-z)).
//
// Parameters:
//   y       - f32 input that is normalized; addressed as `row * n + column`.
//   z       - gate input, addressed like `y`; passed through SiLU.
//   w       - weight, addressed by column only (the same for every row).
//   out     - destination, addressed like `y`; results are cast back to `T`.
//   eps_buf - element 0 supplies the epsilon added before the `rsqrt`.
//   n       - row length (constexpr).
//
// Work layout: `program_id::<0>()` selects the row and each `tid` owns four
// consecutive columns starting at `tid * 4`.
//
// NOTE(review): `tid` and `reduce_sum` are not defined in this file; presumably
// they are the thread index within the group and a group-wide sum supplied by
// the `kernel` macro - confirm.
// NOTE(review): a lane participates only when all four of its columns are
// below `n`, so when `n` is not a multiple of 4 the trailing `n % 4` columns
// are neither accumulated nor written. There is also no loop over columns, so
// the launch presumably provides at least `n / 4` threads per row - confirm.
#[bench_kernel(
    op = "gated_rmsnorm",
    subop = "gated_rmsnorm",
    class = GenericEmpty,
    tol = 1e-4,
    kernel_mode = Reduction,
)]
#[kernel]
pub fn ffai_gated_rmsnorm<T>(
    y: Tensor<f32>,
    z: Tensor<T>,
    w: Tensor<T>,
    out: Tensor<T>,
    eps_buf: Tensor<f32>,
    #[constexpr] n: u32,
) {
    let row_idx = program_id::<0>();
    let row_start = row_idx * n;
    let first_col = tid * 4u32;
    let lane_active = first_col + 3u32 < n;

    // Inactive lanes are redirected to column 0 so they never index with their
    // own out-of-range column; whatever they read is zeroed out of the sum below.
    let read_col = select(lane_active, first_col, 0u32);
    let read_at = row_start + read_col;
    let y0 = load(y[read_at]).cast::<f32>();
    let y1 = load(y[read_at + 1u32]).cast::<f32>();
    let y2 = load(y[read_at + 2u32]).cast::<f32>();
    let y3 = load(y[read_at + 3u32]).cast::<f32>();

    // Per-lane sum of squares, then the reduction. This stays outside the
    // `if` so that every lane, active or not, takes part in `reduce_sum`.
    let lane_sq = y0 * y0 + y1 * y1 + y2 * y2 + y3 * y3;
    let lane_ssq = select(lane_active, lane_sq, 0.0f32);
    let row_ssq = reduce_sum(lane_ssq);

    // Reciprocal of the root-mean-square: 1 / sqrt(mean(y^2) + eps).
    let eps = load(eps_buf[0]);
    let inv_rms = rsqrt(row_ssq / n + eps);

    if lane_active {
        let write_at = row_start + first_col;

        // Gate values for this lane's four columns.
        let z0 = load(z[write_at]).cast::<f32>();
        let z1 = load(z[write_at + 1u32]).cast::<f32>();
        let z2 = load(z[write_at + 2u32]).cast::<f32>();
        let z3 = load(z[write_at + 3u32]).cast::<f32>();

        // SiLU: z * sigmoid(z), written as z / (1 + exp(-z)).
        let gate0 = z0 / (1.0f32 + exp(0.0f32 - z0));
        let gate1 = z1 / (1.0f32 + exp(0.0f32 - z1));
        let gate2 = z2 / (1.0f32 + exp(0.0f32 - z2));
        let gate3 = z3 / (1.0f32 + exp(0.0f32 - z3));

        // Weights are indexed by column only, not by row.
        let w0 = load(w[first_col]).cast::<f32>();
        let w1 = load(w[first_col + 1u32]).cast::<f32>();
        let w2 = load(w[first_col + 2u32]).cast::<f32>();
        let w3 = load(w[first_col + 3u32]).cast::<f32>();

        // Normalize first, then apply weight and gate (same multiplication
        // order as `y * inv_rms * w * gate`).
        let norm0 = y0 * inv_rms;
        let norm1 = y1 * inv_rms;
        let norm2 = y2 * inv_rms;
        let norm3 = y3 * inv_rms;

        let res0 = norm0 * w0 * gate0;
        let res1 = norm1 * w1 * gate1;
        let res2 = norm2 * w2 * gate2;
        let res3 = norm3 * w3 * gate3;

        store(out[write_at], res0.cast::<T>());
        store(out[write_at + 1u32], res1.cast::<T>());
        store(out[write_at + 2u32], res2.cast::<T>());
        store(out[write_at + 3u32], res3.cast::<T>());
    }
}