use metaltile::{bench_kernel, kernel};
#[bench_kernel(
op="sdpa",
subop="sdpa_vector",
class=SdpaVector,
h=128, // head_dim
n_kv=4096,
n_heads=32, // n_q_heads
gqa_factor=4, // 32 Q heads grouped onto 8 KV heads
batch=1,
tpg=1024, // BN × BD = 32 × 32
tol=1e-3,
metal_file="scaled_dot_product_attention.metal",
)]
#[kernel]
pub fn mt_sdpa_vector<T>(
q: Tensor<T>,
k: Tensor<T>,
v: Tensor<T>,
out: Tensor<T>,
#[constexpr] head_dim: u32,
#[constexpr] n_kv: u32,
#[constexpr] gqa_factor: u32,
#[constexpr] scale: f32,
) {
let q_head = tgid_x;
let kv_head = q_head / gqa_factor;
let sg = simd_id;
let lane = simd_lane;
let ns = n_simd;
threadgroup_alloc("tg_max", 32);
threadgroup_alloc("tg_sum", 32);
threadgroup_alloc("tg_out", 1024);
let q_off = q_head * head_dim;
let kv_base = kv_head * n_kv * head_dim;
let d0 = lane * 4u32;
let q0 = load(q[q_off + d0]).cast::<f32>() * scale;
let q1 = load(q[q_off + d0 + 1u32]).cast::<f32>() * scale;
let q2 = load(q[q_off + d0 + 2u32]).cast::<f32>() * scale;
let q3 = load(q[q_off + d0 + 3u32]).cast::<f32>() * scale;
let mut run_max = neg_infinity();
let mut run_sum = 0.0f32;
let mut o0 = 0.0f32;
let mut o1 = 0.0f32;
let mut o2 = 0.0f32;
let mut o3 = 0.0f32;
for t in range(sg, n_kv, ns) {
let kv0 = kv_base + t * head_dim + d0;
let kv1 = kv0 + 1u32;
let kv2 = kv0 + 2u32;
let kv3 = kv0 + 3u32;
let k0_raw = load(k[kv0]);
let k1_raw = load(k[kv1]);
let k2_raw = load(k[kv2]);
let k3_raw = load(k[kv3]);
let k0 = k0_raw.cast::<f32>();
let k1 = k1_raw.cast::<f32>();
let k2 = k2_raw.cast::<f32>();
let k3 = k3_raw.cast::<f32>();
let score = simd_sum(q0 * k0 + q1 * k1 + q2 * k2 + q3 * k3);
let new_max = select(score > run_max, score, run_max);
let factor = exp(run_max - new_max);
let weight = exp(score - new_max);
run_sum = run_sum * factor + weight;
run_max = new_max;
let v0_raw = load(v[kv0]);
let v1_raw = load(v[kv1]);
let v2_raw = load(v[kv2]);
let v3_raw = load(v[kv3]);
let v0 = v0_raw.cast::<f32>();
let v1 = v1_raw.cast::<f32>();
let v2 = v2_raw.cast::<f32>();
let v3 = v3_raw.cast::<f32>();
o0 = o0 * factor + weight * v0;
o1 = o1 * factor + weight * v1;
o2 = o2 * factor + weight * v2;
o3 = o3 * factor + weight * v3;
}
if lane == 0 {
threadgroup_store("tg_max", sg, run_max);
threadgroup_store("tg_sum", sg, run_sum);
}
threadgroup_barrier();
if sg == 0 {
let g_max_in = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
let g_max = simd_max(g_max_in);
let g_sum_in =
select(lane < ns, threadgroup_load("tg_sum", lane) * exp(g_max_in - g_max), 0.0f32);
let g_sum = simd_sum(g_sum_in);
if lane == 0 {
threadgroup_store("tg_max", 0, g_max);
threadgroup_store("tg_sum", 0, g_sum);
}
}
threadgroup_barrier();
let g_max = threadgroup_load("tg_max", 0);
let g_sum = threadgroup_load("tg_sum", 0);
let factor_g = exp(run_max - g_max);
let inv_sum = select(g_sum > 0.0f32, 1.0f32 / g_sum, 0.0f32);
threadgroup_store("tg_out", lane * ns + sg, o0);
threadgroup_barrier();
let red0 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
threadgroup_barrier();
threadgroup_store("tg_out", lane * ns + sg, o1);
threadgroup_barrier();
let red1 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
threadgroup_barrier();
threadgroup_store("tg_out", lane * ns + sg, o2);
threadgroup_barrier();
let red2 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
threadgroup_barrier();
threadgroup_store("tg_out", lane * ns + sg, o3);
threadgroup_barrier();
let red3 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
if lane == 0u32 {
let out_off = q_off + sg * 4u32;
store(out[out_off], red0);
store(out[out_off + 1u32], red1);
store(out[out_off + 2u32], red2);
store(out[out_off + 3u32], red3);
}
}
// Single-query scaled dot-product attention for head_dim = 64.
//
// Same scheme as the head_dim = 128 kernel, but each of the 32 lanes owns 2
// consecutive head dimensions (32 x 2 = 64). One threadgroup per query head;
// the transpose through `tg_out` assumes n_simd == 32. Layouts implied by the
// indexing: q/out [n_q_heads, head_dim], k/v [n_kv_heads, n_kv, head_dim].
#[kernel]
pub fn mt_sdpa_vector_d64<T>(
    q: Tensor<T>,
    k: Tensor<T>,
    v: Tensor<T>,
    out: Tensor<T>,
    #[constexpr] head_dim: u32,
    #[constexpr] n_kv: u32,
    #[constexpr] gqa_factor: u32,
    #[constexpr] scale: f32,
) {
    let q_head = tgid_x;
    let kv_head = q_head / gqa_factor;
    let sg = simd_id;
    let lane = simd_lane;
    let ns = n_simd;
    // Per-simdgroup softmax statistics plus a 32 x 32 transpose buffer.
    threadgroup_alloc("tg_max", 32);
    threadgroup_alloc("tg_sum", 32);
    threadgroup_alloc("tg_out", 1024);
    let q_off = q_head * head_dim;
    let kv_base = kv_head * n_kv * head_dim;
    // This lane's slice of the head dimension: [d0, d0 + 2).
    let d0 = lane * 2u32;
    let q0 = load(q[q_off + d0]).cast::<f32>() * scale;
    let q1 = load(q[q_off + d0 + 1u32]).cast::<f32>() * scale;
    let mut run_max = neg_infinity();
    let mut run_sum = 0.0f32;
    let mut o0 = 0.0f32;
    let mut o1 = 0.0f32;
    for t in range(sg, n_kv, ns) {
        let kv0 = kv_base + t * head_dim + d0;
        let kv1 = kv0 + 1u32;
        let k0 = load(k[kv0]).cast::<f32>();
        let k1 = load(k[kv1]).cast::<f32>();
        // Full q.k dot product; simd_sum makes it uniform across the simdgroup.
        let score = simd_sum(q0 * k0 + q1 * k1);
        // Online softmax: rescale the running state onto the new maximum.
        let new_max = select(score > run_max, score, run_max);
        let factor = exp(run_max - new_max);
        let weight = exp(score - new_max);
        run_sum = run_sum * factor + weight;
        run_max = new_max;
        let v0 = load(v[kv0]).cast::<f32>();
        let v1 = load(v[kv1]).cast::<f32>();
        o0 = o0 * factor + weight * v0;
        o1 = o1 * factor + weight * v1;
    }
    // Publish each simdgroup's (uniform) softmax statistics.
    if lane == 0 {
        threadgroup_store("tg_max", sg, run_max);
        threadgroup_store("tg_sum", sg, run_sum);
    }
    threadgroup_barrier();
    // Lane `i` holds simdgroup `i`'s statistics, so `factor_g` rescales the
    // partial of simdgroup `lane` -- the one read back from the transposed
    // `tg_out` below. Empty simdgroups (max = -inf) contribute nothing.
    let sg_max = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
    let g_max = simd_max(sg_max);
    let factor_g = select(sg_max > neg_infinity(), exp(sg_max - g_max), 0.0f32);
    let g_sum =
        simd_sum(select(lane < ns, threadgroup_load("tg_sum", lane) * factor_g, 0.0f32));
    let inv_sum = select(g_sum > 0.0f32, 1.0f32 / g_sum, 0.0f32);
    // Output merge: write tg_out[lane][sg], read tg_out[sg][lane] (simdgroup
    // `lane`'s partial for dimension sg * 2 + i), then sum over lanes.
    threadgroup_store("tg_out", lane * ns + sg, o0);
    threadgroup_barrier();
    let red0 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
    threadgroup_barrier();
    threadgroup_store("tg_out", lane * ns + sg, o1);
    threadgroup_barrier();
    let red1 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
    if lane == 0u32 {
        let out_off = q_off + sg * 2u32;
        store(out[out_off], red0);
        store(out[out_off + 1u32], red1);
    }
}
#[kernel]
pub fn mt_sdpa_vector_d96<T>(
q: Tensor<T>,
k: Tensor<T>,
v: Tensor<T>,
out: Tensor<T>,
#[constexpr] head_dim: u32,
#[constexpr] n_kv: u32,
#[constexpr] gqa_factor: u32,
#[constexpr] scale: f32,
) {
let q_head = tgid_x;
let kv_head = q_head / gqa_factor;
let sg = simd_id;
let lane = simd_lane;
let ns = n_simd;
threadgroup_alloc("tg_max", 32);
threadgroup_alloc("tg_sum", 32);
threadgroup_alloc("tg_out", 1024);
let q_off = q_head * head_dim;
let kv_base = kv_head * n_kv * head_dim;
let d0 = lane * 3u32;
let q0 = load(q[q_off + d0]).cast::<f32>() * scale;
let q1 = load(q[q_off + d0 + 1u32]).cast::<f32>() * scale;
let q2 = load(q[q_off + d0 + 2u32]).cast::<f32>() * scale;
let mut run_max = neg_infinity();
let mut run_sum = 0.0f32;
let mut o0 = 0.0f32;
let mut o1 = 0.0f32;
let mut o2 = 0.0f32;
for t in range(sg, n_kv, ns) {
let kv0 = kv_base + t * head_dim + d0;
let kv1 = kv0 + 1u32;
let kv2 = kv0 + 2u32;
let k0_raw = load(k[kv0]);
let k1_raw = load(k[kv1]);
let k2_raw = load(k[kv2]);
let k0 = k0_raw.cast::<f32>();
let k1 = k1_raw.cast::<f32>();
let k2 = k2_raw.cast::<f32>();
let score = simd_sum(q0 * k0 + q1 * k1 + q2 * k2);
let new_max = select(score > run_max, score, run_max);
let factor = exp(run_max - new_max);
let weight = exp(score - new_max);
run_sum = run_sum * factor + weight;
run_max = new_max;
let v0_raw = load(v[kv0]);
let v1_raw = load(v[kv1]);
let v2_raw = load(v[kv2]);
let v0 = v0_raw.cast::<f32>();
let v1 = v1_raw.cast::<f32>();
let v2 = v2_raw.cast::<f32>();
o0 = o0 * factor + weight * v0;
o1 = o1 * factor + weight * v1;
o2 = o2 * factor + weight * v2;
}
if lane == 0 {
threadgroup_store("tg_max", sg, run_max);
threadgroup_store("tg_sum", sg, run_sum);
}
threadgroup_barrier();
if sg == 0 {
let g_max_in = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
let g_max = simd_max(g_max_in);
let g_sum_in =
select(lane < ns, threadgroup_load("tg_sum", lane) * exp(g_max_in - g_max), 0.0f32);
let g_sum = simd_sum(g_sum_in);
if lane == 0 {
threadgroup_store("tg_max", 0, g_max);
threadgroup_store("tg_sum", 0, g_sum);
}
}
threadgroup_barrier();
let g_max = threadgroup_load("tg_max", 0);
let g_sum = threadgroup_load("tg_sum", 0);
let factor_g = exp(run_max - g_max);
let inv_sum = select(g_sum > 0.0f32, 1.0f32 / g_sum, 0.0f32);
threadgroup_store("tg_out", lane * ns + sg, o0);
threadgroup_barrier();
let red0 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
threadgroup_barrier();
threadgroup_store("tg_out", lane * ns + sg, o1);
threadgroup_barrier();
let red1 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
threadgroup_barrier();
threadgroup_store("tg_out", lane * ns + sg, o2);
threadgroup_barrier();
let red2 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
if lane == 0u32 {
let out_off = q_off + sg * 3u32;
store(out[out_off], red0);
store(out[out_off + 1u32], red1);
store(out[out_off + 2u32], red2);
}
}
// Single-query scaled dot-product attention for head_dim = 192.
//
// Same scheme as the head_dim = 128 kernel, but each of the 32 lanes owns 6
// consecutive head dimensions (32 x 6 = 192). One threadgroup per query head;
// the transpose through `tg_out` assumes n_simd == 32. Layouts implied by the
// indexing: q/out [n_q_heads, head_dim], k/v [n_kv_heads, n_kv, head_dim].
#[kernel]
pub fn mt_sdpa_vector_d192<T>(
    q: Tensor<T>,
    k: Tensor<T>,
    v: Tensor<T>,
    out: Tensor<T>,
    #[constexpr] head_dim: u32,
    #[constexpr] n_kv: u32,
    #[constexpr] gqa_factor: u32,
    #[constexpr] scale: f32,
) {
    let q_head = tgid_x;
    let kv_head = q_head / gqa_factor;
    let sg = simd_id;
    let lane = simd_lane;
    let ns = n_simd;
    // Per-simdgroup softmax statistics plus a 32 x 32 transpose buffer.
    threadgroup_alloc("tg_max", 32);
    threadgroup_alloc("tg_sum", 32);
    threadgroup_alloc("tg_out", 1024);
    let q_off = q_head * head_dim;
    let kv_base = kv_head * n_kv * head_dim;
    // This lane's slice of the head dimension: [d0, d0 + 6).
    let d0 = lane * 6u32;
    let q0 = load(q[q_off + d0]).cast::<f32>() * scale;
    let q1 = load(q[q_off + d0 + 1u32]).cast::<f32>() * scale;
    let q2 = load(q[q_off + d0 + 2u32]).cast::<f32>() * scale;
    let q3 = load(q[q_off + d0 + 3u32]).cast::<f32>() * scale;
    let q4 = load(q[q_off + d0 + 4u32]).cast::<f32>() * scale;
    let q5 = load(q[q_off + d0 + 5u32]).cast::<f32>() * scale;
    let mut run_max = neg_infinity();
    let mut run_sum = 0.0f32;
    let mut o0 = 0.0f32;
    let mut o1 = 0.0f32;
    let mut o2 = 0.0f32;
    let mut o3 = 0.0f32;
    let mut o4 = 0.0f32;
    let mut o5 = 0.0f32;
    for t in range(sg, n_kv, ns) {
        let kv0 = kv_base + t * head_dim + d0;
        let kv1 = kv0 + 1u32;
        let kv2 = kv0 + 2u32;
        let kv3 = kv0 + 3u32;
        let kv4 = kv0 + 4u32;
        let kv5 = kv0 + 5u32;
        let k0 = load(k[kv0]).cast::<f32>();
        let k1 = load(k[kv1]).cast::<f32>();
        let k2 = load(k[kv2]).cast::<f32>();
        let k3 = load(k[kv3]).cast::<f32>();
        let k4 = load(k[kv4]).cast::<f32>();
        let k5 = load(k[kv5]).cast::<f32>();
        // Full q.k dot product; simd_sum makes it uniform across the simdgroup.
        let score = simd_sum(q0 * k0 + q1 * k1 + q2 * k2 + q3 * k3 + q4 * k4 + q5 * k5);
        // Online softmax: rescale the running state onto the new maximum.
        let new_max = select(score > run_max, score, run_max);
        let factor = exp(run_max - new_max);
        let weight = exp(score - new_max);
        run_sum = run_sum * factor + weight;
        run_max = new_max;
        let v0 = load(v[kv0]).cast::<f32>();
        let v1 = load(v[kv1]).cast::<f32>();
        let v2 = load(v[kv2]).cast::<f32>();
        let v3 = load(v[kv3]).cast::<f32>();
        let v4 = load(v[kv4]).cast::<f32>();
        let v5 = load(v[kv5]).cast::<f32>();
        o0 = o0 * factor + weight * v0;
        o1 = o1 * factor + weight * v1;
        o2 = o2 * factor + weight * v2;
        o3 = o3 * factor + weight * v3;
        o4 = o4 * factor + weight * v4;
        o5 = o5 * factor + weight * v5;
    }
    // Publish each simdgroup's (uniform) softmax statistics.
    if lane == 0 {
        threadgroup_store("tg_max", sg, run_max);
        threadgroup_store("tg_sum", sg, run_sum);
    }
    threadgroup_barrier();
    // Lane `i` holds simdgroup `i`'s statistics, so `factor_g` rescales the
    // partial of simdgroup `lane` -- the one read back from the transposed
    // `tg_out` below. Empty simdgroups (max = -inf) contribute nothing.
    let sg_max = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
    let g_max = simd_max(sg_max);
    let factor_g = select(sg_max > neg_infinity(), exp(sg_max - g_max), 0.0f32);
    let g_sum =
        simd_sum(select(lane < ns, threadgroup_load("tg_sum", lane) * factor_g, 0.0f32));
    let inv_sum = select(g_sum > 0.0f32, 1.0f32 / g_sum, 0.0f32);
    // Output merge: write tg_out[lane][sg], read tg_out[sg][lane] (simdgroup
    // `lane`'s partial for dimension sg * 6 + i), then sum over lanes.
    threadgroup_store("tg_out", lane * ns + sg, o0);
    threadgroup_barrier();
    let red0 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
    threadgroup_barrier();
    threadgroup_store("tg_out", lane * ns + sg, o1);
    threadgroup_barrier();
    let red1 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
    threadgroup_barrier();
    threadgroup_store("tg_out", lane * ns + sg, o2);
    threadgroup_barrier();
    let red2 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
    threadgroup_barrier();
    threadgroup_store("tg_out", lane * ns + sg, o3);
    threadgroup_barrier();
    let red3 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
    threadgroup_barrier();
    threadgroup_store("tg_out", lane * ns + sg, o4);
    threadgroup_barrier();
    let red4 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
    threadgroup_barrier();
    threadgroup_store("tg_out", lane * ns + sg, o5);
    threadgroup_barrier();
    let red5 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
    if lane == 0u32 {
        let out_off = q_off + sg * 6u32;
        store(out[out_off], red0);
        store(out[out_off + 1u32], red1);
        store(out[out_off + 2u32], red2);
        store(out[out_off + 3u32], red3);
        store(out[out_off + 4u32], red4);
        store(out[out_off + 5u32], red5);
    }
}
// Single-query scaled dot-product attention for head_dim = 256.
//
// Same scheme as the head_dim = 128 kernel, but each of the 32 lanes owns 8
// consecutive head dimensions (32 x 8 = 256). One threadgroup per query head;
// the transpose through `tg_out` assumes n_simd == 32. Layouts implied by the
// indexing: q/out [n_q_heads, head_dim], k/v [n_kv_heads, n_kv, head_dim].
#[kernel]
pub fn mt_sdpa_vector_d256<T>(
    q: Tensor<T>,
    k: Tensor<T>,
    v: Tensor<T>,
    out: Tensor<T>,
    #[constexpr] head_dim: u32,
    #[constexpr] n_kv: u32,
    #[constexpr] gqa_factor: u32,
    #[constexpr] scale: f32,
) {
    let q_head = tgid_x;
    let kv_head = q_head / gqa_factor;
    let sg = simd_id;
    let lane = simd_lane;
    let ns = n_simd;
    // Per-simdgroup softmax statistics plus a 32 x 32 transpose buffer.
    threadgroup_alloc("tg_max", 32);
    threadgroup_alloc("tg_sum", 32);
    threadgroup_alloc("tg_out", 1024);
    let q_off = q_head * head_dim;
    let kv_base = kv_head * n_kv * head_dim;
    // This lane's slice of the head dimension: [d0, d0 + 8).
    let d0 = lane * 8u32;
    let q0 = load(q[q_off + d0]).cast::<f32>() * scale;
    let q1 = load(q[q_off + d0 + 1u32]).cast::<f32>() * scale;
    let q2 = load(q[q_off + d0 + 2u32]).cast::<f32>() * scale;
    let q3 = load(q[q_off + d0 + 3u32]).cast::<f32>() * scale;
    let q4 = load(q[q_off + d0 + 4u32]).cast::<f32>() * scale;
    let q5 = load(q[q_off + d0 + 5u32]).cast::<f32>() * scale;
    let q6 = load(q[q_off + d0 + 6u32]).cast::<f32>() * scale;
    let q7 = load(q[q_off + d0 + 7u32]).cast::<f32>() * scale;
    let mut run_max = neg_infinity();
    let mut run_sum = 0.0f32;
    let mut o0 = 0.0f32;
    let mut o1 = 0.0f32;
    let mut o2 = 0.0f32;
    let mut o3 = 0.0f32;
    let mut o4 = 0.0f32;
    let mut o5 = 0.0f32;
    let mut o6 = 0.0f32;
    let mut o7 = 0.0f32;
    for t in range(sg, n_kv, ns) {
        let kv0 = kv_base + t * head_dim + d0;
        let kv1 = kv0 + 1u32;
        let kv2 = kv0 + 2u32;
        let kv3 = kv0 + 3u32;
        let kv4 = kv0 + 4u32;
        let kv5 = kv0 + 5u32;
        let kv6 = kv0 + 6u32;
        let kv7 = kv0 + 7u32;
        let k0 = load(k[kv0]).cast::<f32>();
        let k1 = load(k[kv1]).cast::<f32>();
        let k2 = load(k[kv2]).cast::<f32>();
        let k3 = load(k[kv3]).cast::<f32>();
        let k4 = load(k[kv4]).cast::<f32>();
        let k5 = load(k[kv5]).cast::<f32>();
        let k6 = load(k[kv6]).cast::<f32>();
        let k7 = load(k[kv7]).cast::<f32>();
        // Full q.k dot product; simd_sum makes it uniform across the simdgroup.
        let score = simd_sum(
            q0 * k0 + q1 * k1 + q2 * k2 + q3 * k3 + q4 * k4 + q5 * k5 + q6 * k6 + q7 * k7,
        );
        // Online softmax: rescale the running state onto the new maximum.
        let new_max = select(score > run_max, score, run_max);
        let factor = exp(run_max - new_max);
        let weight = exp(score - new_max);
        run_sum = run_sum * factor + weight;
        run_max = new_max;
        let v0 = load(v[kv0]).cast::<f32>();
        let v1 = load(v[kv1]).cast::<f32>();
        let v2 = load(v[kv2]).cast::<f32>();
        let v3 = load(v[kv3]).cast::<f32>();
        let v4 = load(v[kv4]).cast::<f32>();
        let v5 = load(v[kv5]).cast::<f32>();
        let v6 = load(v[kv6]).cast::<f32>();
        let v7 = load(v[kv7]).cast::<f32>();
        o0 = o0 * factor + weight * v0;
        o1 = o1 * factor + weight * v1;
        o2 = o2 * factor + weight * v2;
        o3 = o3 * factor + weight * v3;
        o4 = o4 * factor + weight * v4;
        o5 = o5 * factor + weight * v5;
        o6 = o6 * factor + weight * v6;
        o7 = o7 * factor + weight * v7;
    }
    // Publish each simdgroup's (uniform) softmax statistics.
    if lane == 0 {
        threadgroup_store("tg_max", sg, run_max);
        threadgroup_store("tg_sum", sg, run_sum);
    }
    threadgroup_barrier();
    // Lane `i` holds simdgroup `i`'s statistics, so `factor_g` rescales the
    // partial of simdgroup `lane` -- the one read back from the transposed
    // `tg_out` below. Empty simdgroups (max = -inf) contribute nothing.
    let sg_max = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
    let g_max = simd_max(sg_max);
    let factor_g = select(sg_max > neg_infinity(), exp(sg_max - g_max), 0.0f32);
    let g_sum =
        simd_sum(select(lane < ns, threadgroup_load("tg_sum", lane) * factor_g, 0.0f32));
    let inv_sum = select(g_sum > 0.0f32, 1.0f32 / g_sum, 0.0f32);
    // Output merge: write tg_out[lane][sg], read tg_out[sg][lane] (simdgroup
    // `lane`'s partial for dimension sg * 8 + i), then sum over lanes.
    threadgroup_store("tg_out", lane * ns + sg, o0);
    threadgroup_barrier();
    let red0 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
    threadgroup_barrier();
    threadgroup_store("tg_out", lane * ns + sg, o1);
    threadgroup_barrier();
    let red1 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
    threadgroup_barrier();
    threadgroup_store("tg_out", lane * ns + sg, o2);
    threadgroup_barrier();
    let red2 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
    threadgroup_barrier();
    threadgroup_store("tg_out", lane * ns + sg, o3);
    threadgroup_barrier();
    let red3 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
    threadgroup_barrier();
    threadgroup_store("tg_out", lane * ns + sg, o4);
    threadgroup_barrier();
    let red4 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
    threadgroup_barrier();
    threadgroup_store("tg_out", lane * ns + sg, o5);
    threadgroup_barrier();
    let red5 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
    threadgroup_barrier();
    threadgroup_store("tg_out", lane * ns + sg, o6);
    threadgroup_barrier();
    let red6 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
    threadgroup_barrier();
    threadgroup_store("tg_out", lane * ns + sg, o7);
    threadgroup_barrier();
    let red7 = simd_sum(threadgroup_load("tg_out", sg * ns + lane) * factor_g) * inv_sum;
    if lane == 0u32 {
        let out_off = q_off + sg * 8u32;
        store(out[out_off], red0);
        store(out[out_off + 1u32], red1);
        store(out[out_off + 2u32], red2);
        store(out[out_off + 3u32], red3);
        store(out[out_off + 4u32], red4);
        store(out[out_off + 5u32], red5);
        store(out[out_off + 6u32], red6);
        store(out[out_off + 7u32], red7);
    }
}