use metaltile::{bench_kernel, kernel};
// Scaled-dot-product attention decode kernel that handles TWO query vectors
// per query head in a single pass, so every K/V element that is loaded is
// reused for both queries.
//
// Parameters (as used by the index arithmetic below):
// - `q`: queries; query `i` (0 or 1) of a head starts at
//   `q_head * 2 * head_dim + i * head_dim`.
// - `k`, `v`: keys / values; position `t` of a KV head starts at
//   `kv_head * kv_stride * head_dim + t * head_dim`.
// - `out`: outputs, written with the same layout as `q`.
// - `head_dim`: row length used in every offset computation.
//   NOTE(review): each lane touches exactly 4 elements (`d0 = lane * 4`), so
//   this appears to assume head_dim == 4 * simd width (128 with 32 lanes,
//   matching `h=128` in the bench config) — confirm.
// - `n_kv`: upper bound of the KV-position loop.
// - `kv_stride`: number of rows between consecutive KV heads.
// - `heads_per_group`: query heads sharing one KV head
//   (`kv_head = q_head / heads_per_group`).
// - `scale`: multiplied into the query values before the dot products.
//
// NOTE(review): `tgid_x`, `simd_id`, `simd_lane` and `n_simd` are builtins
// supplied by the kernel DSL; presumably threadgroup index, simdgroup index
// within the threadgroup, lane index within the simdgroup, and number of
// simdgroups per threadgroup — confirm against the metaltile docs.
#[bench_kernel(
op="sdpa",
subop="sdpa_decode_batched_q2",
class=SdpaBatchedDecode,
h=128,
n_kv=4096,
n_heads=32,
gqa_factor=4,
batch_q=2,
variant=Decode,
tpg=1024,
tol=1e-3,
kernel_mode=Reduction,
)]
#[kernel]
pub fn sdpa_decode_batched_q2<T>(
q: Tensor<T>,
k: Tensor<T>,
v: Tensor<T>,
out: Tensor<T>,
#[constexpr] head_dim: u32,
#[constexpr] n_kv: u32,
#[constexpr] kv_stride: u32,
#[constexpr] heads_per_group: u32,
#[constexpr] scale: f32,
) {
// One query head per `tgid_x`; several query heads map onto one KV head.
let q_head = tgid_x;
let kv_head = q_head / heads_per_group;
let sg = simd_id;
let lane = simd_lane;
let ns = n_simd;
// Scratch buffers shared across the threadgroup. tg_max / tg_sum hold one
// slot per simdgroup (indexed by `sg`, at most 32). Each tg_out* buffer is
// indexed by `lane * (ns + 1) + sg`; 1056 == 32 * 33, which fits that
// indexing when ns <= 32 and lane < 32.
threadgroup_alloc("tg_max", 32);
threadgroup_alloc("tg_sum", 32);
threadgroup_alloc("tg_out0", 1056);
threadgroup_alloc("tg_out1", 1056);
threadgroup_alloc("tg_out2", 1056);
threadgroup_alloc("tg_out3", 1056);
// The two queries of this head are stored back to back.
let q_off_0 = q_head * 2u32 * head_dim;
let q_off_1 = q_off_0 + head_dim;
let kv_head_base = kv_head * kv_stride * head_dim;
// Each lane owns four consecutive elements of the head dimension.
let d0 = lane * 4u32;
// Pre-scale the query slices once so the inner loop needs no extra multiply.
let q0_a = load(q[q_off_0 + d0]).cast::<f32>() * scale;
let q0_b = load(q[q_off_0 + d0 + 1u32]).cast::<f32>() * scale;
let q0_c = load(q[q_off_0 + d0 + 2u32]).cast::<f32>() * scale;
let q0_d = load(q[q_off_0 + d0 + 3u32]).cast::<f32>() * scale;
let q1_a = load(q[q_off_1 + d0]).cast::<f32>() * scale;
let q1_b = load(q[q_off_1 + d0 + 1u32]).cast::<f32>() * scale;
let q1_c = load(q[q_off_1 + d0 + 2u32]).cast::<f32>() * scale;
let q1_d = load(q[q_off_1 + d0 + 3u32]).cast::<f32>() * scale;
// Per-query streaming-softmax state: running max, running sum of exp
// weights, and the weighted-V accumulator for this lane's four elements.
let mut run_max_0 = neg_infinity();
let mut run_max_1 = neg_infinity();
let mut run_sum_0 = 0.0f32;
let mut run_sum_1 = 0.0f32;
let mut o0_0 = 0.0f32;
let mut o0_1 = 0.0f32;
let mut o0_2 = 0.0f32;
let mut o0_3 = 0.0f32;
let mut o1_0 = 0.0f32;
let mut o1_1 = 0.0f32;
let mut o1_2 = 0.0f32;
let mut o1_3 = 0.0f32;
// KV positions are interleaved across simdgroups: this group visits
// sg, sg + ns, sg + 2*ns, ... (presumably range(start, end, step) — confirm).
for _t in range(sg, n_kv, ns) {
let base = kv_head_base + _t * head_dim;
let kv_idx = base + d0;
let kv0 = kv_idx;
let kv1 = kv_idx + 1u32;
let kv2 = kv_idx + 2u32;
let kv3 = kv_idx + 3u32;
// K is loaded once and reused for both queries.
let k0_raw = load(k[kv0]);
let k1_raw = load(k[kv1]);
let k2_raw = load(k[kv2]);
let k3_raw = load(k[kv3]);
let k0 = k0_raw.cast::<f32>();
let k1 = k1_raw.cast::<f32>();
let k2 = k2_raw.cast::<f32>();
let k3 = k3_raw.cast::<f32>();
// Lane-local partial dot products, then summed across the simdgroup to
// obtain the full q.k score for this position.
let partial_0 = q0_a * k0 + q0_b * k1 + q0_c * k2 + q0_d * k3;
let partial_1 = q1_a * k0 + q1_b * k1 + q1_c * k2 + q1_d * k3;
let score_0 = simd_sum(partial_0);
let score_1 = simd_sum(partial_1);
// Running max: keep the score when it exceeds the previous max.
let new_max_0 = select(score_0 > run_max_0, score_0, run_max_0);
let new_max_1 = select(score_1 > run_max_1, score_1, run_max_1);
// `factor` rescales everything accumulated under the old max;
// `weight` is this position's exp weight under the new max.
let factor_0 = exp(run_max_0 - new_max_0);
let factor_1 = exp(run_max_1 - new_max_1);
let weight_0 = exp(score_0 - new_max_0);
let weight_1 = exp(score_1 - new_max_1);
run_sum_0 = run_sum_0 * factor_0 + weight_0;
run_sum_1 = run_sum_1 * factor_1 + weight_1;
run_max_0 = new_max_0;
run_max_1 = new_max_1;
// V is likewise loaded once and shared by both queries.
let v0_raw = load(v[kv0]);
let v1_raw = load(v[kv1]);
let v2_raw = load(v[kv2]);
let v3_raw = load(v[kv3]);
let v0 = v0_raw.cast::<f32>();
let v1 = v1_raw.cast::<f32>();
let v2 = v2_raw.cast::<f32>();
let v3 = v3_raw.cast::<f32>();
o0_0 = o0_0 * factor_0 + weight_0 * v0;
o0_1 = o0_1 * factor_0 + weight_0 * v1;
o0_2 = o0_2 * factor_0 + weight_0 * v2;
o0_3 = o0_3 * factor_0 + weight_0 * v3;
o1_0 = o1_0 * factor_1 + weight_1 * v0;
o1_1 = o1_1 * factor_1 + weight_1 * v1;
o1_2 = o1_2 * factor_1 + weight_1 * v2;
o1_3 = o1_3 * factor_1 + weight_1 * v3;
}
// ---- Query 0: combine the per-simdgroup partial results ----
// Lane 0 of every simdgroup publishes that group's running max and sum.
if lane == 0 {
threadgroup_store("tg_max", sg, run_max_0);
threadgroup_store("tg_sum", sg, run_sum_0);
}
threadgroup_barrier();
// Simdgroup 0 reduces across groups: lane i reads group i's values
// (lanes >= ns contribute -inf / 0), and each group's sum is rescaled from
// its local max to the global max before being added.
if sg == 0 {
let g_max_in = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
let g_max = simd_max(g_max_in);
let g_sum_in =
select(lane < ns, threadgroup_load("tg_sum", lane) * exp(g_max_in - g_max), 0.0f32);
let g_sum = simd_sum(g_sum_in);
// Slot 0 is overwritten with the global max / sum for everyone to read.
if lane == 0 {
threadgroup_store("tg_max", 0, g_max);
threadgroup_store("tg_sum", 0, g_sum);
}
}
threadgroup_barrier();
let g_max_0 = threadgroup_load("tg_max", 0);
let g_sum_0 = threadgroup_load("tg_sum", 0);
// Converts this group's accumulator from its local max to the global max and
// normalises by the global sum; falls back to 0 when the sum is not positive
// so no division by zero (or NaN sum) reaches the output.
let rescale_0 = select(g_sum_0 > 0.0f32, exp(run_max_0 - g_max_0) / g_sum_0, 0.0f32);
// Row stride of ns + 1 gives each lane a row with one slot per simdgroup plus
// one padding slot (presumably to avoid threadgroup-memory bank conflicts —
// confirm).
let stride = ns + 1u32;
let idx = lane * stride + sg;
threadgroup_store("tg_out0", idx, o0_0 * rescale_0);
threadgroup_store("tg_out1", idx, o0_1 * rescale_0);
threadgroup_store("tg_out2", idx, o0_2 * rescale_0);
threadgroup_store("tg_out3", idx, o0_3 * rescale_0);
threadgroup_barrier();
// Simdgroup 0 sums each lane's row over all simdgroups and writes the four
// output elements owned by that lane.
if sg == 0 {
let mut so0_0 = 0.0f32;
let mut so0_1 = 0.0f32;
let mut so0_2 = 0.0f32;
let mut so0_3 = 0.0f32;
for _g in range(0u32, ns, 1u32) {
let ri = lane * stride + _g;
so0_0 = so0_0 + threadgroup_load("tg_out0", ri);
so0_1 = so0_1 + threadgroup_load("tg_out1", ri);
so0_2 = so0_2 + threadgroup_load("tg_out2", ri);
so0_3 = so0_3 + threadgroup_load("tg_out3", ri);
}
let out_off_0 = q_head * 2u32 * head_dim + d0;
store(out[out_off_0], so0_0.cast::<T>());
store(out[out_off_0 + 1u32], so0_1.cast::<T>());
store(out[out_off_0 + 2u32], so0_2.cast::<T>());
store(out[out_off_0 + 3u32], so0_3.cast::<T>());
}
// Barrier before the scratch buffers are reused for query 1.
threadgroup_barrier();
// ---- Query 1: same reduction, reusing the same scratch buffers ----
if lane == 0 {
threadgroup_store("tg_max", sg, run_max_1);
threadgroup_store("tg_sum", sg, run_sum_1);
}
threadgroup_barrier();
if sg == 0 {
let g_max_in = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
let g_max = simd_max(g_max_in);
let g_sum_in =
select(lane < ns, threadgroup_load("tg_sum", lane) * exp(g_max_in - g_max), 0.0f32);
let g_sum = simd_sum(g_sum_in);
if lane == 0 {
threadgroup_store("tg_max", 0, g_max);
threadgroup_store("tg_sum", 0, g_sum);
}
}
threadgroup_barrier();
let g_max_1 = threadgroup_load("tg_max", 0);
let g_sum_1 = threadgroup_load("tg_sum", 0);
let rescale_1 = select(g_sum_1 > 0.0f32, exp(run_max_1 - g_max_1) / g_sum_1, 0.0f32);
threadgroup_store("tg_out0", idx, o1_0 * rescale_1);
threadgroup_store("tg_out1", idx, o1_1 * rescale_1);
threadgroup_store("tg_out2", idx, o1_2 * rescale_1);
threadgroup_store("tg_out3", idx, o1_3 * rescale_1);
threadgroup_barrier();
if sg == 0 {
let mut so1_0 = 0.0f32;
let mut so1_1 = 0.0f32;
let mut so1_2 = 0.0f32;
let mut so1_3 = 0.0f32;
for _g in range(0u32, ns, 1u32) {
let ri = lane * stride + _g;
so1_0 = so1_0 + threadgroup_load("tg_out0", ri);
so1_1 = so1_1 + threadgroup_load("tg_out1", ri);
so1_2 = so1_2 + threadgroup_load("tg_out2", ri);
so1_3 = so1_3 + threadgroup_load("tg_out3", ri);
}
// Query 1's output row sits head_dim elements after query 0's.
let out_off_1 = q_head * 2u32 * head_dim + head_dim + d0;
store(out[out_off_1], so1_0.cast::<T>());
store(out[out_off_1 + 1u32], so1_1.cast::<T>());
store(out[out_off_1 + 2u32], so1_2.cast::<T>());
store(out[out_off_1 + 3u32], so1_3.cast::<T>());
}
}
// Scaled-dot-product attention decode kernel that handles FOUR query vectors
// per query head in a single pass, so every K/V element that is loaded is
// reused for all four queries.
//
// Parameters (as used by the index arithmetic below):
// - `q`: queries; query `i` (0..=3) of a head starts at
//   `q_head * 4 * head_dim + i * head_dim`.
// - `k`, `v`: keys / values; position `t` of a KV head starts at
//   `kv_head * kv_stride * head_dim + t * head_dim`.
// - `out`: outputs, written with the same layout as `q`.
// - `head_dim`: row length used in every offset computation.
//   NOTE(review): each lane touches exactly 4 elements (`d0 = lane * 4`), so
//   this appears to assume head_dim == 4 * simd width (128 with 32 lanes,
//   matching `h=128` in the bench config) — confirm.
// - `n_kv`: upper bound of the KV-position loop.
// - `kv_stride`: number of rows between consecutive KV heads.
// - `heads_per_group`: query heads sharing one KV head
//   (`kv_head = q_head / heads_per_group`).
// - `scale`: multiplied into the query values before the dot products.
//
// NOTE(review): `tgid_x`, `simd_id`, `simd_lane` and `n_simd` are builtins
// supplied by the kernel DSL; presumably threadgroup index, simdgroup index
// within the threadgroup, lane index within the simdgroup, and number of
// simdgroups per threadgroup — confirm against the metaltile docs.
#[bench_kernel(
op="sdpa",
subop="sdpa_decode_batched_q4",
class=SdpaBatchedDecode,
h=128,
n_kv=4096,
n_heads=32,
gqa_factor=4,
batch_q=4,
variant=Decode,
tpg=512,
tol=1e-3,
kernel_mode=Reduction,
)]
#[kernel]
pub fn sdpa_decode_batched_q4<T>(
q: Tensor<T>,
k: Tensor<T>,
v: Tensor<T>,
out: Tensor<T>,
#[constexpr] head_dim: u32,
#[constexpr] n_kv: u32,
#[constexpr] kv_stride: u32,
#[constexpr] heads_per_group: u32,
#[constexpr] scale: f32,
) {
// One query head per `tgid_x`; several query heads map onto one KV head.
let q_head = tgid_x;
let kv_head = q_head / heads_per_group;
let sg = simd_id;
let lane = simd_lane;
let ns = n_simd;
// Scratch buffers shared across the threadgroup. tg_max / tg_sum hold one
// slot per simdgroup (indexed by `sg`, at most 32). Each tg_out* buffer is
// indexed by `lane * (ns + 1) + sg`; 1056 == 32 * 33, which fits that
// indexing when ns <= 32 and lane < 32.
threadgroup_alloc("tg_max", 32);
threadgroup_alloc("tg_sum", 32);
threadgroup_alloc("tg_out0", 1056);
threadgroup_alloc("tg_out1", 1056);
threadgroup_alloc("tg_out2", 1056);
threadgroup_alloc("tg_out3", 1056);
// The four queries of this head are stored back to back.
let q_off_0 = q_head * 4u32 * head_dim;
let q_off_1 = q_off_0 + head_dim;
let q_off_2 = q_off_1 + head_dim;
let q_off_3 = q_off_2 + head_dim;
let kv_head_base = kv_head * kv_stride * head_dim;
// Each lane owns four consecutive elements of the head dimension.
let d0 = lane * 4u32;
// Pre-scale the query slices once so the inner loop needs no extra multiply.
let q0_a = load(q[q_off_0 + d0]).cast::<f32>() * scale;
let q0_b = load(q[q_off_0 + d0 + 1u32]).cast::<f32>() * scale;
let q0_c = load(q[q_off_0 + d0 + 2u32]).cast::<f32>() * scale;
let q0_d = load(q[q_off_0 + d0 + 3u32]).cast::<f32>() * scale;
let q1_a = load(q[q_off_1 + d0]).cast::<f32>() * scale;
let q1_b = load(q[q_off_1 + d0 + 1u32]).cast::<f32>() * scale;
let q1_c = load(q[q_off_1 + d0 + 2u32]).cast::<f32>() * scale;
let q1_d = load(q[q_off_1 + d0 + 3u32]).cast::<f32>() * scale;
let q2_a = load(q[q_off_2 + d0]).cast::<f32>() * scale;
let q2_b = load(q[q_off_2 + d0 + 1u32]).cast::<f32>() * scale;
let q2_c = load(q[q_off_2 + d0 + 2u32]).cast::<f32>() * scale;
let q2_d = load(q[q_off_2 + d0 + 3u32]).cast::<f32>() * scale;
let q3_a = load(q[q_off_3 + d0]).cast::<f32>() * scale;
let q3_b = load(q[q_off_3 + d0 + 1u32]).cast::<f32>() * scale;
let q3_c = load(q[q_off_3 + d0 + 2u32]).cast::<f32>() * scale;
let q3_d = load(q[q_off_3 + d0 + 3u32]).cast::<f32>() * scale;
// Per-query streaming-softmax state: running max, running sum of exp
// weights, and the weighted-V accumulator for this lane's four elements.
let mut run_max_0 = neg_infinity();
let mut run_max_1 = neg_infinity();
let mut run_max_2 = neg_infinity();
let mut run_max_3 = neg_infinity();
let mut run_sum_0 = 0.0f32;
let mut run_sum_1 = 0.0f32;
let mut run_sum_2 = 0.0f32;
let mut run_sum_3 = 0.0f32;
let mut o0_0 = 0.0f32;
let mut o0_1 = 0.0f32;
let mut o0_2 = 0.0f32;
let mut o0_3 = 0.0f32;
let mut o1_0 = 0.0f32;
let mut o1_1 = 0.0f32;
let mut o1_2 = 0.0f32;
let mut o1_3 = 0.0f32;
let mut o2_0 = 0.0f32;
let mut o2_1 = 0.0f32;
let mut o2_2 = 0.0f32;
let mut o2_3 = 0.0f32;
let mut o3_0 = 0.0f32;
let mut o3_1 = 0.0f32;
let mut o3_2 = 0.0f32;
let mut o3_3 = 0.0f32;
// KV positions are interleaved across simdgroups: this group visits
// sg, sg + ns, sg + 2*ns, ... (presumably range(start, end, step) — confirm).
for _t in range(sg, n_kv, ns) {
let base = kv_head_base + _t * head_dim;
let kv_idx = base + d0;
let kv0 = kv_idx;
let kv1 = kv_idx + 1u32;
let kv2 = kv_idx + 2u32;
let kv3 = kv_idx + 3u32;
// K is loaded once and reused for all four queries.
let k0_raw = load(k[kv0]);
let k1_raw = load(k[kv1]);
let k2_raw = load(k[kv2]);
let k3_raw = load(k[kv3]);
let k0 = k0_raw.cast::<f32>();
let k1 = k1_raw.cast::<f32>();
let k2 = k2_raw.cast::<f32>();
let k3 = k3_raw.cast::<f32>();
// Lane-local partial dot products, then summed across the simdgroup to
// obtain the full q.k score for this position.
let partial_0 = q0_a * k0 + q0_b * k1 + q0_c * k2 + q0_d * k3;
let partial_1 = q1_a * k0 + q1_b * k1 + q1_c * k2 + q1_d * k3;
let partial_2 = q2_a * k0 + q2_b * k1 + q2_c * k2 + q2_d * k3;
let partial_3 = q3_a * k0 + q3_b * k1 + q3_c * k2 + q3_d * k3;
let score_0 = simd_sum(partial_0);
let score_1 = simd_sum(partial_1);
let score_2 = simd_sum(partial_2);
let score_3 = simd_sum(partial_3);
// Running max: keep the score when it exceeds the previous max.
let new_max_0 = select(score_0 > run_max_0, score_0, run_max_0);
let new_max_1 = select(score_1 > run_max_1, score_1, run_max_1);
let new_max_2 = select(score_2 > run_max_2, score_2, run_max_2);
let new_max_3 = select(score_3 > run_max_3, score_3, run_max_3);
// `factor` rescales everything accumulated under the old max;
// `weight` is this position's exp weight under the new max.
let factor_0 = exp(run_max_0 - new_max_0);
let factor_1 = exp(run_max_1 - new_max_1);
let factor_2 = exp(run_max_2 - new_max_2);
let factor_3 = exp(run_max_3 - new_max_3);
let weight_0 = exp(score_0 - new_max_0);
let weight_1 = exp(score_1 - new_max_1);
let weight_2 = exp(score_2 - new_max_2);
let weight_3 = exp(score_3 - new_max_3);
run_sum_0 = run_sum_0 * factor_0 + weight_0;
run_sum_1 = run_sum_1 * factor_1 + weight_1;
run_sum_2 = run_sum_2 * factor_2 + weight_2;
run_sum_3 = run_sum_3 * factor_3 + weight_3;
run_max_0 = new_max_0;
run_max_1 = new_max_1;
run_max_2 = new_max_2;
run_max_3 = new_max_3;
// V is likewise loaded once and shared by all four queries.
let v0_raw = load(v[kv0]);
let v1_raw = load(v[kv1]);
let v2_raw = load(v[kv2]);
let v3_raw = load(v[kv3]);
let v0 = v0_raw.cast::<f32>();
let v1 = v1_raw.cast::<f32>();
let v2 = v2_raw.cast::<f32>();
let v3 = v3_raw.cast::<f32>();
o0_0 = o0_0 * factor_0 + weight_0 * v0;
o0_1 = o0_1 * factor_0 + weight_0 * v1;
o0_2 = o0_2 * factor_0 + weight_0 * v2;
o0_3 = o0_3 * factor_0 + weight_0 * v3;
o1_0 = o1_0 * factor_1 + weight_1 * v0;
o1_1 = o1_1 * factor_1 + weight_1 * v1;
o1_2 = o1_2 * factor_1 + weight_1 * v2;
o1_3 = o1_3 * factor_1 + weight_1 * v3;
o2_0 = o2_0 * factor_2 + weight_2 * v0;
o2_1 = o2_1 * factor_2 + weight_2 * v1;
o2_2 = o2_2 * factor_2 + weight_2 * v2;
o2_3 = o2_3 * factor_2 + weight_2 * v3;
o3_0 = o3_0 * factor_3 + weight_3 * v0;
o3_1 = o3_1 * factor_3 + weight_3 * v1;
o3_2 = o3_2 * factor_3 + weight_3 * v2;
o3_3 = o3_3 * factor_3 + weight_3 * v3;
}
// Row stride of ns + 1 gives each lane a row with one slot per simdgroup plus
// one padding slot (presumably to avoid threadgroup-memory bank conflicts —
// confirm). `idx` is this thread's slot in every tg_out* buffer.
let stride = ns + 1u32;
let idx = lane * stride + sg;
// ---- Query 0: combine the per-simdgroup partial results ----
// Lane 0 of every simdgroup publishes that group's running max and sum.
if lane == 0 {
threadgroup_store("tg_max", sg, run_max_0);
threadgroup_store("tg_sum", sg, run_sum_0);
}
threadgroup_barrier();
// Simdgroup 0 reduces across groups: lane i reads group i's values
// (lanes >= ns contribute -inf / 0), and each group's sum is rescaled from
// its local max to the global max before being added.
if sg == 0 {
let g_max_in = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
let g_max = simd_max(g_max_in);
let g_sum_in =
select(lane < ns, threadgroup_load("tg_sum", lane) * exp(g_max_in - g_max), 0.0f32);
let g_sum = simd_sum(g_sum_in);
// Slot 0 is overwritten with the global max / sum for everyone to read.
if lane == 0 {
threadgroup_store("tg_max", 0, g_max);
threadgroup_store("tg_sum", 0, g_sum);
}
}
threadgroup_barrier();
let g_max_0 = threadgroup_load("tg_max", 0);
let g_sum_0 = threadgroup_load("tg_sum", 0);
// Converts this group's accumulator from its local max to the global max and
// normalises by the global sum; falls back to 0 when the sum is not positive
// so no division by zero (or NaN sum) reaches the output.
let rescale_0 = select(g_sum_0 > 0.0f32, exp(run_max_0 - g_max_0) / g_sum_0, 0.0f32);
threadgroup_store("tg_out0", idx, o0_0 * rescale_0);
threadgroup_store("tg_out1", idx, o0_1 * rescale_0);
threadgroup_store("tg_out2", idx, o0_2 * rescale_0);
threadgroup_store("tg_out3", idx, o0_3 * rescale_0);
threadgroup_barrier();
// Simdgroup 0 sums each lane's row over all simdgroups and writes the four
// output elements owned by that lane.
if sg == 0 {
let mut so0_0 = 0.0f32;
let mut so0_1 = 0.0f32;
let mut so0_2 = 0.0f32;
let mut so0_3 = 0.0f32;
for _g in range(0u32, ns, 1u32) {
let ri = lane * stride + _g;
so0_0 = so0_0 + threadgroup_load("tg_out0", ri);
so0_1 = so0_1 + threadgroup_load("tg_out1", ri);
so0_2 = so0_2 + threadgroup_load("tg_out2", ri);
so0_3 = so0_3 + threadgroup_load("tg_out3", ri);
}
let out_off_0 = q_head * 4u32 * head_dim + d0;
store(out[out_off_0], so0_0.cast::<T>());
store(out[out_off_0 + 1u32], so0_1.cast::<T>());
store(out[out_off_0 + 2u32], so0_2.cast::<T>());
store(out[out_off_0 + 3u32], so0_3.cast::<T>());
}
// Barrier before the scratch buffers are reused for query 1.
threadgroup_barrier();
// ---- Query 1: same reduction, reusing the same scratch buffers ----
if lane == 0 {
threadgroup_store("tg_max", sg, run_max_1);
threadgroup_store("tg_sum", sg, run_sum_1);
}
threadgroup_barrier();
if sg == 0 {
let g_max_in = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
let g_max = simd_max(g_max_in);
let g_sum_in =
select(lane < ns, threadgroup_load("tg_sum", lane) * exp(g_max_in - g_max), 0.0f32);
let g_sum = simd_sum(g_sum_in);
if lane == 0 {
threadgroup_store("tg_max", 0, g_max);
threadgroup_store("tg_sum", 0, g_sum);
}
}
threadgroup_barrier();
let g_max_1 = threadgroup_load("tg_max", 0);
let g_sum_1 = threadgroup_load("tg_sum", 0);
let rescale_1 = select(g_sum_1 > 0.0f32, exp(run_max_1 - g_max_1) / g_sum_1, 0.0f32);
threadgroup_store("tg_out0", idx, o1_0 * rescale_1);
threadgroup_store("tg_out1", idx, o1_1 * rescale_1);
threadgroup_store("tg_out2", idx, o1_2 * rescale_1);
threadgroup_store("tg_out3", idx, o1_3 * rescale_1);
threadgroup_barrier();
if sg == 0 {
let mut so1_0 = 0.0f32;
let mut so1_1 = 0.0f32;
let mut so1_2 = 0.0f32;
let mut so1_3 = 0.0f32;
for _g in range(0u32, ns, 1u32) {
let ri = lane * stride + _g;
so1_0 = so1_0 + threadgroup_load("tg_out0", ri);
so1_1 = so1_1 + threadgroup_load("tg_out1", ri);
so1_2 = so1_2 + threadgroup_load("tg_out2", ri);
so1_3 = so1_3 + threadgroup_load("tg_out3", ri);
}
// Query 1's output row sits head_dim elements after query 0's.
let out_off_1 = q_head * 4u32 * head_dim + head_dim + d0;
store(out[out_off_1], so1_0.cast::<T>());
store(out[out_off_1 + 1u32], so1_1.cast::<T>());
store(out[out_off_1 + 2u32], so1_2.cast::<T>());
store(out[out_off_1 + 3u32], so1_3.cast::<T>());
}
// Barrier before the scratch buffers are reused for query 2.
threadgroup_barrier();
// ---- Query 2: same reduction, reusing the same scratch buffers ----
if lane == 0 {
threadgroup_store("tg_max", sg, run_max_2);
threadgroup_store("tg_sum", sg, run_sum_2);
}
threadgroup_barrier();
if sg == 0 {
let g_max_in = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
let g_max = simd_max(g_max_in);
let g_sum_in =
select(lane < ns, threadgroup_load("tg_sum", lane) * exp(g_max_in - g_max), 0.0f32);
let g_sum = simd_sum(g_sum_in);
if lane == 0 {
threadgroup_store("tg_max", 0, g_max);
threadgroup_store("tg_sum", 0, g_sum);
}
}
threadgroup_barrier();
let g_max_2 = threadgroup_load("tg_max", 0);
let g_sum_2 = threadgroup_load("tg_sum", 0);
let rescale_2 = select(g_sum_2 > 0.0f32, exp(run_max_2 - g_max_2) / g_sum_2, 0.0f32);
threadgroup_store("tg_out0", idx, o2_0 * rescale_2);
threadgroup_store("tg_out1", idx, o2_1 * rescale_2);
threadgroup_store("tg_out2", idx, o2_2 * rescale_2);
threadgroup_store("tg_out3", idx, o2_3 * rescale_2);
threadgroup_barrier();
if sg == 0 {
let mut so2_0 = 0.0f32;
let mut so2_1 = 0.0f32;
let mut so2_2 = 0.0f32;
let mut so2_3 = 0.0f32;
for _g in range(0u32, ns, 1u32) {
let ri = lane * stride + _g;
so2_0 = so2_0 + threadgroup_load("tg_out0", ri);
so2_1 = so2_1 + threadgroup_load("tg_out1", ri);
so2_2 = so2_2 + threadgroup_load("tg_out2", ri);
so2_3 = so2_3 + threadgroup_load("tg_out3", ri);
}
// Query 2's output row sits 2 * head_dim elements after query 0's.
let out_off_2 = q_head * 4u32 * head_dim + 2u32 * head_dim + d0;
store(out[out_off_2], so2_0.cast::<T>());
store(out[out_off_2 + 1u32], so2_1.cast::<T>());
store(out[out_off_2 + 2u32], so2_2.cast::<T>());
store(out[out_off_2 + 3u32], so2_3.cast::<T>());
}
// Barrier before the scratch buffers are reused for query 3.
threadgroup_barrier();
// ---- Query 3: same reduction, reusing the same scratch buffers ----
if lane == 0 {
threadgroup_store("tg_max", sg, run_max_3);
threadgroup_store("tg_sum", sg, run_sum_3);
}
threadgroup_barrier();
if sg == 0 {
let g_max_in = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
let g_max = simd_max(g_max_in);
let g_sum_in =
select(lane < ns, threadgroup_load("tg_sum", lane) * exp(g_max_in - g_max), 0.0f32);
let g_sum = simd_sum(g_sum_in);
if lane == 0 {
threadgroup_store("tg_max", 0, g_max);
threadgroup_store("tg_sum", 0, g_sum);
}
}
threadgroup_barrier();
let g_max_3 = threadgroup_load("tg_max", 0);
let g_sum_3 = threadgroup_load("tg_sum", 0);
let rescale_3 = select(g_sum_3 > 0.0f32, exp(run_max_3 - g_max_3) / g_sum_3, 0.0f32);
threadgroup_store("tg_out0", idx, o3_0 * rescale_3);
threadgroup_store("tg_out1", idx, o3_1 * rescale_3);
threadgroup_store("tg_out2", idx, o3_2 * rescale_3);
threadgroup_store("tg_out3", idx, o3_3 * rescale_3);
threadgroup_barrier();
if sg == 0 {
let mut so3_0 = 0.0f32;
let mut so3_1 = 0.0f32;
let mut so3_2 = 0.0f32;
let mut so3_3 = 0.0f32;
for _g in range(0u32, ns, 1u32) {
let ri = lane * stride + _g;
so3_0 = so3_0 + threadgroup_load("tg_out0", ri);
so3_1 = so3_1 + threadgroup_load("tg_out1", ri);
so3_2 = so3_2 + threadgroup_load("tg_out2", ri);
so3_3 = so3_3 + threadgroup_load("tg_out3", ri);
}
// Query 3's output row sits 3 * head_dim elements after query 0's.
let out_off_3 = q_head * 4u32 * head_dim + 3u32 * head_dim + d0;
store(out[out_off_3], so3_0.cast::<T>());
store(out[out_off_3 + 1u32], so3_1.cast::<T>());
store(out[out_off_3 + 2u32], so3_2.cast::<T>());
store(out[out_off_3 + 3u32], so3_3.cast::<T>());
}
}
#[bench_kernel(
op="sdpa",
subop="sdpa_decode_batched_q8",
class=SdpaBatchedDecode,
h=128,
n_kv=4096,
n_heads=32,
gqa_factor=4,
batch_q=8,
variant=Decode,
tpg=256,
tol=1e-3,
kernel_mode=Reduction,
)]
#[kernel]
pub fn sdpa_decode_batched_q8<T>(
q: Tensor<T>,
k: Tensor<T>,
v: Tensor<T>,
out: Tensor<T>,
#[constexpr] head_dim: u32,
#[constexpr] n_kv: u32,
#[constexpr] kv_stride: u32,
#[constexpr] heads_per_group: u32,
#[constexpr] scale: f32,
) {
let q_head = tgid_x;
let kv_head = q_head / heads_per_group;
let sg = simd_id;
let lane = simd_lane;
let ns = n_simd;
threadgroup_alloc("tg_max", 32);
threadgroup_alloc("tg_sum", 32);
threadgroup_alloc("tg_out0", 1056);
threadgroup_alloc("tg_out1", 1056);
threadgroup_alloc("tg_out2", 1056);
threadgroup_alloc("tg_out3", 1056);
let q_off_0 = q_head * 8u32 * head_dim;
let q_off_1 = q_off_0 + head_dim;
let q_off_2 = q_off_1 + head_dim;
let q_off_3 = q_off_2 + head_dim;
let q_off_4 = q_off_3 + head_dim;
let q_off_5 = q_off_4 + head_dim;
let q_off_6 = q_off_5 + head_dim;
let q_off_7 = q_off_6 + head_dim;
let kv_head_base = kv_head * kv_stride * head_dim;
let d0 = lane * 4u32;
let q0_a = load(q[q_off_0 + d0]).cast::<f32>() * scale;
let q0_b = load(q[q_off_0 + d0 + 1u32]).cast::<f32>() * scale;
let q0_c = load(q[q_off_0 + d0 + 2u32]).cast::<f32>() * scale;
let q0_d = load(q[q_off_0 + d0 + 3u32]).cast::<f32>() * scale;
let q1_a = load(q[q_off_1 + d0]).cast::<f32>() * scale;
let q1_b = load(q[q_off_1 + d0 + 1u32]).cast::<f32>() * scale;
let q1_c = load(q[q_off_1 + d0 + 2u32]).cast::<f32>() * scale;
let q1_d = load(q[q_off_1 + d0 + 3u32]).cast::<f32>() * scale;
let q2_a = load(q[q_off_2 + d0]).cast::<f32>() * scale;
let q2_b = load(q[q_off_2 + d0 + 1u32]).cast::<f32>() * scale;
let q2_c = load(q[q_off_2 + d0 + 2u32]).cast::<f32>() * scale;
let q2_d = load(q[q_off_2 + d0 + 3u32]).cast::<f32>() * scale;
let q3_a = load(q[q_off_3 + d0]).cast::<f32>() * scale;
let q3_b = load(q[q_off_3 + d0 + 1u32]).cast::<f32>() * scale;
let q3_c = load(q[q_off_3 + d0 + 2u32]).cast::<f32>() * scale;
let q3_d = load(q[q_off_3 + d0 + 3u32]).cast::<f32>() * scale;
let q4_a = load(q[q_off_4 + d0]).cast::<f32>() * scale;
let q4_b = load(q[q_off_4 + d0 + 1u32]).cast::<f32>() * scale;
let q4_c = load(q[q_off_4 + d0 + 2u32]).cast::<f32>() * scale;
let q4_d = load(q[q_off_4 + d0 + 3u32]).cast::<f32>() * scale;
let q5_a = load(q[q_off_5 + d0]).cast::<f32>() * scale;
let q5_b = load(q[q_off_5 + d0 + 1u32]).cast::<f32>() * scale;
let q5_c = load(q[q_off_5 + d0 + 2u32]).cast::<f32>() * scale;
let q5_d = load(q[q_off_5 + d0 + 3u32]).cast::<f32>() * scale;
let q6_a = load(q[q_off_6 + d0]).cast::<f32>() * scale;
let q6_b = load(q[q_off_6 + d0 + 1u32]).cast::<f32>() * scale;
let q6_c = load(q[q_off_6 + d0 + 2u32]).cast::<f32>() * scale;
let q6_d = load(q[q_off_6 + d0 + 3u32]).cast::<f32>() * scale;
let q7_a = load(q[q_off_7 + d0]).cast::<f32>() * scale;
let q7_b = load(q[q_off_7 + d0 + 1u32]).cast::<f32>() * scale;
let q7_c = load(q[q_off_7 + d0 + 2u32]).cast::<f32>() * scale;
let q7_d = load(q[q_off_7 + d0 + 3u32]).cast::<f32>() * scale;
let mut run_max_0 = neg_infinity();
let mut run_max_1 = neg_infinity();
let mut run_max_2 = neg_infinity();
let mut run_max_3 = neg_infinity();
let mut run_max_4 = neg_infinity();
let mut run_max_5 = neg_infinity();
let mut run_max_6 = neg_infinity();
let mut run_max_7 = neg_infinity();
let mut run_sum_0 = 0.0f32;
let mut run_sum_1 = 0.0f32;
let mut run_sum_2 = 0.0f32;
let mut run_sum_3 = 0.0f32;
let mut run_sum_4 = 0.0f32;
let mut run_sum_5 = 0.0f32;
let mut run_sum_6 = 0.0f32;
let mut run_sum_7 = 0.0f32;
let mut o0_0 = 0.0f32;
let mut o0_1 = 0.0f32;
let mut o0_2 = 0.0f32;
let mut o0_3 = 0.0f32;
let mut o1_0 = 0.0f32;
let mut o1_1 = 0.0f32;
let mut o1_2 = 0.0f32;
let mut o1_3 = 0.0f32;
let mut o2_0 = 0.0f32;
let mut o2_1 = 0.0f32;
let mut o2_2 = 0.0f32;
let mut o2_3 = 0.0f32;
let mut o3_0 = 0.0f32;
let mut o3_1 = 0.0f32;
let mut o3_2 = 0.0f32;
let mut o3_3 = 0.0f32;
let mut o4_0 = 0.0f32;
let mut o4_1 = 0.0f32;
let mut o4_2 = 0.0f32;
let mut o4_3 = 0.0f32;
let mut o5_0 = 0.0f32;
let mut o5_1 = 0.0f32;
let mut o5_2 = 0.0f32;
let mut o5_3 = 0.0f32;
let mut o6_0 = 0.0f32;
let mut o6_1 = 0.0f32;
let mut o6_2 = 0.0f32;
let mut o6_3 = 0.0f32;
let mut o7_0 = 0.0f32;
let mut o7_1 = 0.0f32;
let mut o7_2 = 0.0f32;
let mut o7_3 = 0.0f32;
for _t in range(sg, n_kv, ns) {
let base = kv_head_base + _t * head_dim;
let kv_idx = base + d0;
let kv0 = kv_idx;
let kv1 = kv_idx + 1u32;
let kv2 = kv_idx + 2u32;
let kv3 = kv_idx + 3u32;
let k0 = load(k[kv0]).cast::<f32>();
let k1 = load(k[kv1]).cast::<f32>();
let k2 = load(k[kv2]).cast::<f32>();
let k3 = load(k[kv3]).cast::<f32>();
let partial_0 = q0_a * k0 + q0_b * k1 + q0_c * k2 + q0_d * k3;
let partial_1 = q1_a * k0 + q1_b * k1 + q1_c * k2 + q1_d * k3;
let partial_2 = q2_a * k0 + q2_b * k1 + q2_c * k2 + q2_d * k3;
let partial_3 = q3_a * k0 + q3_b * k1 + q3_c * k2 + q3_d * k3;
let partial_4 = q4_a * k0 + q4_b * k1 + q4_c * k2 + q4_d * k3;
let partial_5 = q5_a * k0 + q5_b * k1 + q5_c * k2 + q5_d * k3;
let partial_6 = q6_a * k0 + q6_b * k1 + q6_c * k2 + q6_d * k3;
let partial_7 = q7_a * k0 + q7_b * k1 + q7_c * k2 + q7_d * k3;
let score_0 = simd_sum(partial_0);
let score_1 = simd_sum(partial_1);
let score_2 = simd_sum(partial_2);
let score_3 = simd_sum(partial_3);
let score_4 = simd_sum(partial_4);
let score_5 = simd_sum(partial_5);
let score_6 = simd_sum(partial_6);
let score_7 = simd_sum(partial_7);
let new_max_0 = select(score_0 > run_max_0, score_0, run_max_0);
let new_max_1 = select(score_1 > run_max_1, score_1, run_max_1);
let new_max_2 = select(score_2 > run_max_2, score_2, run_max_2);
let new_max_3 = select(score_3 > run_max_3, score_3, run_max_3);
let new_max_4 = select(score_4 > run_max_4, score_4, run_max_4);
let new_max_5 = select(score_5 > run_max_5, score_5, run_max_5);
let new_max_6 = select(score_6 > run_max_6, score_6, run_max_6);
let new_max_7 = select(score_7 > run_max_7, score_7, run_max_7);
let factor_0 = exp(run_max_0 - new_max_0);
let factor_1 = exp(run_max_1 - new_max_1);
let factor_2 = exp(run_max_2 - new_max_2);
let factor_3 = exp(run_max_3 - new_max_3);
let factor_4 = exp(run_max_4 - new_max_4);
let factor_5 = exp(run_max_5 - new_max_5);
let factor_6 = exp(run_max_6 - new_max_6);
let factor_7 = exp(run_max_7 - new_max_7);
let weight_0 = exp(score_0 - new_max_0);
let weight_1 = exp(score_1 - new_max_1);
let weight_2 = exp(score_2 - new_max_2);
let weight_3 = exp(score_3 - new_max_3);
let weight_4 = exp(score_4 - new_max_4);
let weight_5 = exp(score_5 - new_max_5);
let weight_6 = exp(score_6 - new_max_6);
let weight_7 = exp(score_7 - new_max_7);
run_sum_0 = run_sum_0 * factor_0 + weight_0;
run_sum_1 = run_sum_1 * factor_1 + weight_1;
run_sum_2 = run_sum_2 * factor_2 + weight_2;
run_sum_3 = run_sum_3 * factor_3 + weight_3;
run_sum_4 = run_sum_4 * factor_4 + weight_4;
run_sum_5 = run_sum_5 * factor_5 + weight_5;
run_sum_6 = run_sum_6 * factor_6 + weight_6;
run_sum_7 = run_sum_7 * factor_7 + weight_7;
run_max_0 = new_max_0;
run_max_1 = new_max_1;
run_max_2 = new_max_2;
run_max_3 = new_max_3;
run_max_4 = new_max_4;
run_max_5 = new_max_5;
run_max_6 = new_max_6;
run_max_7 = new_max_7;
let v0 = load(v[kv0]).cast::<f32>();
let v1 = load(v[kv1]).cast::<f32>();
let v2 = load(v[kv2]).cast::<f32>();
let v3 = load(v[kv3]).cast::<f32>();
o0_0 = o0_0 * factor_0 + weight_0 * v0;
o0_1 = o0_1 * factor_0 + weight_0 * v1;
o0_2 = o0_2 * factor_0 + weight_0 * v2;
o0_3 = o0_3 * factor_0 + weight_0 * v3;
o1_0 = o1_0 * factor_1 + weight_1 * v0;
o1_1 = o1_1 * factor_1 + weight_1 * v1;
o1_2 = o1_2 * factor_1 + weight_1 * v2;
o1_3 = o1_3 * factor_1 + weight_1 * v3;
o2_0 = o2_0 * factor_2 + weight_2 * v0;
o2_1 = o2_1 * factor_2 + weight_2 * v1;
o2_2 = o2_2 * factor_2 + weight_2 * v2;
o2_3 = o2_3 * factor_2 + weight_2 * v3;
o3_0 = o3_0 * factor_3 + weight_3 * v0;
o3_1 = o3_1 * factor_3 + weight_3 * v1;
o3_2 = o3_2 * factor_3 + weight_3 * v2;
o3_3 = o3_3 * factor_3 + weight_3 * v3;
o4_0 = o4_0 * factor_4 + weight_4 * v0;
o4_1 = o4_1 * factor_4 + weight_4 * v1;
o4_2 = o4_2 * factor_4 + weight_4 * v2;
o4_3 = o4_3 * factor_4 + weight_4 * v3;
o5_0 = o5_0 * factor_5 + weight_5 * v0;
o5_1 = o5_1 * factor_5 + weight_5 * v1;
o5_2 = o5_2 * factor_5 + weight_5 * v2;
o5_3 = o5_3 * factor_5 + weight_5 * v3;
o6_0 = o6_0 * factor_6 + weight_6 * v0;
o6_1 = o6_1 * factor_6 + weight_6 * v1;
o6_2 = o6_2 * factor_6 + weight_6 * v2;
o6_3 = o6_3 * factor_6 + weight_6 * v3;
o7_0 = o7_0 * factor_7 + weight_7 * v0;
o7_1 = o7_1 * factor_7 + weight_7 * v1;
o7_2 = o7_2 * factor_7 + weight_7 * v2;
o7_3 = o7_3 * factor_7 + weight_7 * v3;
}
let stride = ns + 1u32;
let idx = lane * stride + sg;
if lane == 0 {
threadgroup_store("tg_max", sg, run_max_0);
threadgroup_store("tg_sum", sg, run_sum_0);
}
threadgroup_barrier();
if sg == 0 {
let g_max_in = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
let g_max = simd_max(g_max_in);
let g_sum_in =
select(lane < ns, threadgroup_load("tg_sum", lane) * exp(g_max_in - g_max), 0.0f32);
let g_sum = simd_sum(g_sum_in);
if lane == 0 {
threadgroup_store("tg_max", 0, g_max);
threadgroup_store("tg_sum", 0, g_sum);
}
}
threadgroup_barrier();
let g_max_0 = threadgroup_load("tg_max", 0);
let g_sum_0 = threadgroup_load("tg_sum", 0);
let rescale_0 = select(g_sum_0 > 0.0f32, exp(run_max_0 - g_max_0) / g_sum_0, 0.0f32);
threadgroup_store("tg_out0", idx, o0_0 * rescale_0);
threadgroup_store("tg_out1", idx, o0_1 * rescale_0);
threadgroup_store("tg_out2", idx, o0_2 * rescale_0);
threadgroup_store("tg_out3", idx, o0_3 * rescale_0);
threadgroup_barrier();
if sg == 0 {
let mut so0_0 = 0.0f32;
let mut so0_1 = 0.0f32;
let mut so0_2 = 0.0f32;
let mut so0_3 = 0.0f32;
for _g in range(0u32, ns, 1u32) {
let ri = lane * stride + _g;
so0_0 = so0_0 + threadgroup_load("tg_out0", ri);
so0_1 = so0_1 + threadgroup_load("tg_out1", ri);
so0_2 = so0_2 + threadgroup_load("tg_out2", ri);
so0_3 = so0_3 + threadgroup_load("tg_out3", ri);
}
let out_off_0 = q_head * 8u32 * head_dim + d0;
store(out[out_off_0], so0_0.cast::<T>());
store(out[out_off_0 + 1u32], so0_1.cast::<T>());
store(out[out_off_0 + 2u32], so0_2.cast::<T>());
store(out[out_off_0 + 3u32], so0_3.cast::<T>());
}
threadgroup_barrier();
if lane == 0 {
threadgroup_store("tg_max", sg, run_max_1);
threadgroup_store("tg_sum", sg, run_sum_1);
}
threadgroup_barrier();
if sg == 0 {
let g_max_in = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
let g_max = simd_max(g_max_in);
let g_sum_in =
select(lane < ns, threadgroup_load("tg_sum", lane) * exp(g_max_in - g_max), 0.0f32);
let g_sum = simd_sum(g_sum_in);
if lane == 0 {
threadgroup_store("tg_max", 0, g_max);
threadgroup_store("tg_sum", 0, g_sum);
}
}
threadgroup_barrier();
let g_max_1 = threadgroup_load("tg_max", 0);
let g_sum_1 = threadgroup_load("tg_sum", 0);
let rescale_1 = select(g_sum_1 > 0.0f32, exp(run_max_1 - g_max_1) / g_sum_1, 0.0f32);
threadgroup_store("tg_out0", idx, o1_0 * rescale_1);
threadgroup_store("tg_out1", idx, o1_1 * rescale_1);
threadgroup_store("tg_out2", idx, o1_2 * rescale_1);
threadgroup_store("tg_out3", idx, o1_3 * rescale_1);
threadgroup_barrier();
if sg == 0 {
let mut so1_0 = 0.0f32;
let mut so1_1 = 0.0f32;
let mut so1_2 = 0.0f32;
let mut so1_3 = 0.0f32;
for _g in range(0u32, ns, 1u32) {
let ri = lane * stride + _g;
so1_0 = so1_0 + threadgroup_load("tg_out0", ri);
so1_1 = so1_1 + threadgroup_load("tg_out1", ri);
so1_2 = so1_2 + threadgroup_load("tg_out2", ri);
so1_3 = so1_3 + threadgroup_load("tg_out3", ri);
}
let out_off_1 = q_head * 8u32 * head_dim + head_dim + d0;
store(out[out_off_1], so1_0.cast::<T>());
store(out[out_off_1 + 1u32], so1_1.cast::<T>());
store(out[out_off_1 + 2u32], so1_2.cast::<T>());
store(out[out_off_1 + 3u32], so1_3.cast::<T>());
}
threadgroup_barrier();
if lane == 0 {
threadgroup_store("tg_max", sg, run_max_2);
threadgroup_store("tg_sum", sg, run_sum_2);
}
threadgroup_barrier();
if sg == 0 {
let g_max_in = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
let g_max = simd_max(g_max_in);
let g_sum_in =
select(lane < ns, threadgroup_load("tg_sum", lane) * exp(g_max_in - g_max), 0.0f32);
let g_sum = simd_sum(g_sum_in);
if lane == 0 {
threadgroup_store("tg_max", 0, g_max);
threadgroup_store("tg_sum", 0, g_sum);
}
}
threadgroup_barrier();
let g_max_2 = threadgroup_load("tg_max", 0);
let g_sum_2 = threadgroup_load("tg_sum", 0);
let rescale_2 = select(g_sum_2 > 0.0f32, exp(run_max_2 - g_max_2) / g_sum_2, 0.0f32);
threadgroup_store("tg_out0", idx, o2_0 * rescale_2);
threadgroup_store("tg_out1", idx, o2_1 * rescale_2);
threadgroup_store("tg_out2", idx, o2_2 * rescale_2);
threadgroup_store("tg_out3", idx, o2_3 * rescale_2);
threadgroup_barrier();
if sg == 0 {
let mut so2_0 = 0.0f32;
let mut so2_1 = 0.0f32;
let mut so2_2 = 0.0f32;
let mut so2_3 = 0.0f32;
for _g in range(0u32, ns, 1u32) {
let ri = lane * stride + _g;
so2_0 = so2_0 + threadgroup_load("tg_out0", ri);
so2_1 = so2_1 + threadgroup_load("tg_out1", ri);
so2_2 = so2_2 + threadgroup_load("tg_out2", ri);
so2_3 = so2_3 + threadgroup_load("tg_out3", ri);
}
let out_off_2 = q_head * 8u32 * head_dim + 2u32 * head_dim + d0;
store(out[out_off_2], so2_0.cast::<T>());
store(out[out_off_2 + 1u32], so2_1.cast::<T>());
store(out[out_off_2 + 2u32], so2_2.cast::<T>());
store(out[out_off_2 + 3u32], so2_3.cast::<T>());
}
threadgroup_barrier();
if lane == 0 {
threadgroup_store("tg_max", sg, run_max_3);
threadgroup_store("tg_sum", sg, run_sum_3);
}
threadgroup_barrier();
if sg == 0 {
let g_max_in = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
let g_max = simd_max(g_max_in);
let g_sum_in =
select(lane < ns, threadgroup_load("tg_sum", lane) * exp(g_max_in - g_max), 0.0f32);
let g_sum = simd_sum(g_sum_in);
if lane == 0 {
threadgroup_store("tg_max", 0, g_max);
threadgroup_store("tg_sum", 0, g_sum);
}
}
threadgroup_barrier();
let g_max_3 = threadgroup_load("tg_max", 0);
let g_sum_3 = threadgroup_load("tg_sum", 0);
let rescale_3 = select(g_sum_3 > 0.0f32, exp(run_max_3 - g_max_3) / g_sum_3, 0.0f32);
threadgroup_store("tg_out0", idx, o3_0 * rescale_3);
threadgroup_store("tg_out1", idx, o3_1 * rescale_3);
threadgroup_store("tg_out2", idx, o3_2 * rescale_3);
threadgroup_store("tg_out3", idx, o3_3 * rescale_3);
threadgroup_barrier();
if sg == 0 {
let mut so3_0 = 0.0f32;
let mut so3_1 = 0.0f32;
let mut so3_2 = 0.0f32;
let mut so3_3 = 0.0f32;
for _g in range(0u32, ns, 1u32) {
let ri = lane * stride + _g;
so3_0 = so3_0 + threadgroup_load("tg_out0", ri);
so3_1 = so3_1 + threadgroup_load("tg_out1", ri);
so3_2 = so3_2 + threadgroup_load("tg_out2", ri);
so3_3 = so3_3 + threadgroup_load("tg_out3", ri);
}
let out_off_3 = q_head * 8u32 * head_dim + 3u32 * head_dim + d0;
store(out[out_off_3], so3_0.cast::<T>());
store(out[out_off_3 + 1u32], so3_1.cast::<T>());
store(out[out_off_3 + 2u32], so3_2.cast::<T>());
store(out[out_off_3 + 3u32], so3_3.cast::<T>());
}
threadgroup_barrier();
if lane == 0 {
threadgroup_store("tg_max", sg, run_max_4);
threadgroup_store("tg_sum", sg, run_sum_4);
}
threadgroup_barrier();
if sg == 0 {
let g_max_in = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
let g_max = simd_max(g_max_in);
let g_sum_in =
select(lane < ns, threadgroup_load("tg_sum", lane) * exp(g_max_in - g_max), 0.0f32);
let g_sum = simd_sum(g_sum_in);
if lane == 0 {
threadgroup_store("tg_max", 0, g_max);
threadgroup_store("tg_sum", 0, g_sum);
}
}
threadgroup_barrier();
let g_max_4 = threadgroup_load("tg_max", 0);
let g_sum_4 = threadgroup_load("tg_sum", 0);
let rescale_4 = select(g_sum_4 > 0.0f32, exp(run_max_4 - g_max_4) / g_sum_4, 0.0f32);
threadgroup_store("tg_out0", idx, o4_0 * rescale_4);
threadgroup_store("tg_out1", idx, o4_1 * rescale_4);
threadgroup_store("tg_out2", idx, o4_2 * rescale_4);
threadgroup_store("tg_out3", idx, o4_3 * rescale_4);
threadgroup_barrier();
if sg == 0 {
let mut so4_0 = 0.0f32;
let mut so4_1 = 0.0f32;
let mut so4_2 = 0.0f32;
let mut so4_3 = 0.0f32;
for _g in range(0u32, ns, 1u32) {
let ri = lane * stride + _g;
so4_0 = so4_0 + threadgroup_load("tg_out0", ri);
so4_1 = so4_1 + threadgroup_load("tg_out1", ri);
so4_2 = so4_2 + threadgroup_load("tg_out2", ri);
so4_3 = so4_3 + threadgroup_load("tg_out3", ri);
}
let out_off_4 = q_head * 8u32 * head_dim + 4u32 * head_dim + d0;
store(out[out_off_4], so4_0.cast::<T>());
store(out[out_off_4 + 1u32], so4_1.cast::<T>());
store(out[out_off_4 + 2u32], so4_2.cast::<T>());
store(out[out_off_4 + 3u32], so4_3.cast::<T>());
}
threadgroup_barrier();
if lane == 0 {
threadgroup_store("tg_max", sg, run_max_5);
threadgroup_store("tg_sum", sg, run_sum_5);
}
threadgroup_barrier();
if sg == 0 {
let g_max_in = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
let g_max = simd_max(g_max_in);
let g_sum_in =
select(lane < ns, threadgroup_load("tg_sum", lane) * exp(g_max_in - g_max), 0.0f32);
let g_sum = simd_sum(g_sum_in);
if lane == 0 {
threadgroup_store("tg_max", 0, g_max);
threadgroup_store("tg_sum", 0, g_sum);
}
}
threadgroup_barrier();
let g_max_5 = threadgroup_load("tg_max", 0);
let g_sum_5 = threadgroup_load("tg_sum", 0);
let rescale_5 = select(g_sum_5 > 0.0f32, exp(run_max_5 - g_max_5) / g_sum_5, 0.0f32);
threadgroup_store("tg_out0", idx, o5_0 * rescale_5);
threadgroup_store("tg_out1", idx, o5_1 * rescale_5);
threadgroup_store("tg_out2", idx, o5_2 * rescale_5);
threadgroup_store("tg_out3", idx, o5_3 * rescale_5);
threadgroup_barrier();
if sg == 0 {
let mut so5_0 = 0.0f32;
let mut so5_1 = 0.0f32;
let mut so5_2 = 0.0f32;
let mut so5_3 = 0.0f32;
for _g in range(0u32, ns, 1u32) {
let ri = lane * stride + _g;
so5_0 = so5_0 + threadgroup_load("tg_out0", ri);
so5_1 = so5_1 + threadgroup_load("tg_out1", ri);
so5_2 = so5_2 + threadgroup_load("tg_out2", ri);
so5_3 = so5_3 + threadgroup_load("tg_out3", ri);
}
let out_off_5 = q_head * 8u32 * head_dim + 5u32 * head_dim + d0;
store(out[out_off_5], so5_0.cast::<T>());
store(out[out_off_5 + 1u32], so5_1.cast::<T>());
store(out[out_off_5 + 2u32], so5_2.cast::<T>());
store(out[out_off_5 + 3u32], so5_3.cast::<T>());
}
threadgroup_barrier();
if lane == 0 {
threadgroup_store("tg_max", sg, run_max_6);
threadgroup_store("tg_sum", sg, run_sum_6);
}
threadgroup_barrier();
if sg == 0 {
let g_max_in = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
let g_max = simd_max(g_max_in);
let g_sum_in =
select(lane < ns, threadgroup_load("tg_sum", lane) * exp(g_max_in - g_max), 0.0f32);
let g_sum = simd_sum(g_sum_in);
if lane == 0 {
threadgroup_store("tg_max", 0, g_max);
threadgroup_store("tg_sum", 0, g_sum);
}
}
threadgroup_barrier();
let g_max_6 = threadgroup_load("tg_max", 0);
let g_sum_6 = threadgroup_load("tg_sum", 0);
let rescale_6 = select(g_sum_6 > 0.0f32, exp(run_max_6 - g_max_6) / g_sum_6, 0.0f32);
threadgroup_store("tg_out0", idx, o6_0 * rescale_6);
threadgroup_store("tg_out1", idx, o6_1 * rescale_6);
threadgroup_store("tg_out2", idx, o6_2 * rescale_6);
threadgroup_store("tg_out3", idx, o6_3 * rescale_6);
threadgroup_barrier();
if sg == 0 {
let mut so6_0 = 0.0f32;
let mut so6_1 = 0.0f32;
let mut so6_2 = 0.0f32;
let mut so6_3 = 0.0f32;
for _g in range(0u32, ns, 1u32) {
let ri = lane * stride + _g;
so6_0 = so6_0 + threadgroup_load("tg_out0", ri);
so6_1 = so6_1 + threadgroup_load("tg_out1", ri);
so6_2 = so6_2 + threadgroup_load("tg_out2", ri);
so6_3 = so6_3 + threadgroup_load("tg_out3", ri);
}
let out_off_6 = q_head * 8u32 * head_dim + 6u32 * head_dim + d0;
store(out[out_off_6], so6_0.cast::<T>());
store(out[out_off_6 + 1u32], so6_1.cast::<T>());
store(out[out_off_6 + 2u32], so6_2.cast::<T>());
store(out[out_off_6 + 3u32], so6_3.cast::<T>());
}
threadgroup_barrier();
if lane == 0 {
threadgroup_store("tg_max", sg, run_max_7);
threadgroup_store("tg_sum", sg, run_sum_7);
}
threadgroup_barrier();
if sg == 0 {
let g_max_in = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
let g_max = simd_max(g_max_in);
let g_sum_in =
select(lane < ns, threadgroup_load("tg_sum", lane) * exp(g_max_in - g_max), 0.0f32);
let g_sum = simd_sum(g_sum_in);
if lane == 0 {
threadgroup_store("tg_max", 0, g_max);
threadgroup_store("tg_sum", 0, g_sum);
}
}
threadgroup_barrier();
let g_max_7 = threadgroup_load("tg_max", 0);
let g_sum_7 = threadgroup_load("tg_sum", 0);
let rescale_7 = select(g_sum_7 > 0.0f32, exp(run_max_7 - g_max_7) / g_sum_7, 0.0f32);
threadgroup_store("tg_out0", idx, o7_0 * rescale_7);
threadgroup_store("tg_out1", idx, o7_1 * rescale_7);
threadgroup_store("tg_out2", idx, o7_2 * rescale_7);
threadgroup_store("tg_out3", idx, o7_3 * rescale_7);
threadgroup_barrier();
if sg == 0 {
let mut so7_0 = 0.0f32;
let mut so7_1 = 0.0f32;
let mut so7_2 = 0.0f32;
let mut so7_3 = 0.0f32;
for _g in range(0u32, ns, 1u32) {
let ri = lane * stride + _g;
so7_0 = so7_0 + threadgroup_load("tg_out0", ri);
so7_1 = so7_1 + threadgroup_load("tg_out1", ri);
so7_2 = so7_2 + threadgroup_load("tg_out2", ri);
so7_3 = so7_3 + threadgroup_load("tg_out3", ri);
}
let out_off_7 = q_head * 8u32 * head_dim + 7u32 * head_dim + d0;
store(out[out_off_7], so7_0.cast::<T>());
store(out[out_off_7 + 1u32], so7_1.cast::<T>());
store(out[out_off_7 + 2u32], so7_2.cast::<T>());
store(out[out_off_7 + 3u32], so7_3.cast::<T>());
}
}
#[cfg(test)]
mod tests {
    use metaltile_codegen::msl::MslGenerator;
    use metaltile_core::ir::KernelMode;

    use super::{sdpa_decode_batched_q2, sdpa_decode_batched_q4};
    use crate::bench_types::DType;

    // Element types every codegen smoke test is run against.
    const FLOAT_DTYPES: [DType; 3] = [DType::F32, DType::F16, DType::BF16];

    // Substrings the generated MSL must contain for the cross-simdgroup
    // reduction checks to pass.
    const REDUCTION_PRIMITIVES: [&str; 6] = [
        "simd_group",
        "simd_lane",
        "threadgroup_barrier",
        "mem_threadgroup",
        "simd_sum",
        "simd_max",
    ];

    /// Generates MSL for the batch-of-2 kernel with the IR forced into
    /// `KernelMode::Reduction`.
    fn msl_for(dt: DType) -> String {
        let mut ir = sdpa_decode_batched_q2::kernel_ir_for(dt);
        ir.mode = KernelMode::Reduction;
        MslGenerator::default().generate(&ir).expect("sdpa_decode_batched_q2 codegen succeeds")
    }

    /// Generates MSL for the batch-of-4 kernel with the IR forced into
    /// `KernelMode::Reduction`.
    fn msl_for_q4(dt: DType) -> String {
        let mut ir = sdpa_decode_batched_q4::kernel_ir_for(dt);
        ir.mode = KernelMode::Reduction;
        MslGenerator::default().generate(&ir).expect("sdpa_decode_batched_q4 codegen succeeds")
    }

    /// Renders the kernel once per float dtype and checks that the output is
    /// non-blank and contains the expected `kernel void …` declaration.
    fn assert_kernel_declared_for_float_dtypes(render: fn(DType) -> String, declaration: &str) {
        for dt in FLOAT_DTYPES {
            let src = render(dt);
            if src.trim().is_empty() {
                panic!("MSL for {dt:?} should not be empty");
            }
            if !src.contains(declaration) {
                panic!("MSL for {dt:?} should declare the kernel function:\n{src}");
            }
        }
    }

    /// Returns the first entry of `REDUCTION_PRIMITIVES` absent from `msl`,
    /// or `None` when all of them are present.
    fn first_missing_primitive(msl: &str) -> Option<&'static str> {
        REDUCTION_PRIMITIVES.iter().copied().find(|tok| !msl.contains(*tok))
    }

    /// Tallies writes to `out` in the generated source.
    ///
    /// Returns `(scalar, vector, effective)`: occurrences of `out[`,
    /// occurrences of `float*)out +`, and the combined total in which each
    /// vector occurrence is weighted as four writes.
    fn tally_out_writes(msl: &str) -> (usize, usize, usize) {
        let scalar = msl.matches("out[").count();
        let vector = msl.matches("float*)out +").count();
        (scalar, vector, scalar + vector * 4)
    }

    // ---- batch-of-2 kernel ------------------------------------------------

    #[test]
    fn codegen_produces_nonempty_msl_for_all_float_dtypes() {
        assert_kernel_declared_for_float_dtypes(msl_for, "kernel void sdpa_decode_batched_q2");
    }

    #[test]
    fn codegen_uses_threadgroup_reduction_primitives() {
        let src = msl_for(DType::F32);
        if let Some(tok) = first_missing_primitive(&src) {
            panic!("MSL missing `{tok}`:\n{src}");
        }
    }

    #[test]
    fn codegen_emits_two_phase_reduction() {
        let src = msl_for(DType::F32);
        let (scalar, vector, effective) = tally_out_writes(&src);
        // Fail when fewer than eight effective writes to `out` were emitted.
        if effective < 8 {
            panic!(
                "Expected ≥8 effective out-writes (4 quartiles × 2 Q positions); \
                 got {effective} ({scalar} scalar + {vector} vectorized×4):\n{src}",
            );
        }
    }

    #[test]
    fn codegen_loads_both_q_vectors() {
        let src = msl_for(DType::F32);
        // Either spelling of the first Q row's base offset is accepted.
        let has_q0_base = ["q_off_0", "q[q_head * 2"].iter().any(|needle| src.contains(*needle));
        assert!(has_q0_base, "Expected Q[0] base offset in MSL:\n{src}");
    }

    // ---- batch-of-4 kernel ------------------------------------------------

    #[test]
    fn q4_codegen_produces_nonempty_msl_for_all_float_dtypes() {
        assert_kernel_declared_for_float_dtypes(msl_for_q4, "kernel void sdpa_decode_batched_q4");
    }

    #[test]
    fn q4_codegen_emits_four_phase_reduction() {
        let src = msl_for_q4(DType::F32);
        let (scalar, vector, effective) = tally_out_writes(&src);
        // Fail when fewer than sixteen effective writes to `out` were emitted.
        if effective < 16 {
            panic!(
                "Expected ≥16 effective out-writes (4 quartiles × 4 Q positions); \
                 got {effective} ({scalar} scalar + {vector} vectorized×4)",
            );
        }
    }

    #[test]
    fn q4_codegen_uses_threadgroup_reduction_primitives() {
        let src = msl_for_q4(DType::F32);
        if let Some(tok) = first_missing_primitive(&src) {
            panic!("MSL missing `{tok}`");
        }
    }

    #[test]
    fn q4_codegen_loads_all_four_q_vectors() {
        let src = msl_for_q4(DType::F32);
        // Either spelling of the first Q row's base offset is accepted.
        let has_q0_base = ["q_off_0", "q[q_head * 4"].iter().any(|needle| src.contains(*needle));
        assert!(has_q0_base, "Expected Q[0] base offset (q_head * 4 * head_dim) in MSL");

        // Each of the four streams must leave its own rescale factor in the source.
        let per_stream = ["rescale_0", "rescale_1", "rescale_2", "rescale_3"];
        if let Some(stream) = per_stream.iter().find(|name| !src.contains(**name)) {
            panic!("MSL missing per-stream `{stream}`");
        }
    }
}