#version 450
// Online-softmax multi-head attention with causal masking and grouped-query attention (GQA).
//
// Tensor layouts (row-major, fp32, densely packed):
//   Q:   [num_heads,    seq_len, head_dim]
//   K:   [num_kv_heads, kv_len,  head_dim]
//   V:   [num_kv_heads, kv_len,  head_dim]
//   Out: [num_heads,    seq_len, head_dim]
// Dispatch: (num_heads, seq_len, 1) workgroups -- one workgroup per (head, query row).
// Causal mask: the seq_len query rows are treated as the LAST seq_len positions of the
// kv sequence, i.e. query row s may attend to kv positions <= kv_len - seq_len + s.
// Preconditions: num_heads is a multiple of num_kv_heads; head_dim <= MAX_HEAD_DIM;
// local_size_x is a power of two (required by the tree reduction in main).
layout(local_size_x = 256) in;

layout(set = 0, binding = 0) readonly buffer QBuf { float q_data[]; };
layout(set = 0, binding = 1) readonly buffer KBuf { float k_data[]; };
layout(set = 0, binding = 2) readonly buffer VBuf { float v_data[]; };
layout(set = 0, binding = 3) writeonly buffer OutBuf { float out_data[]; };

layout(push_constant) uniform Params {
    int num_heads;     // query heads
    int num_kv_heads;  // K/V heads shared by groups of num_heads / num_kv_heads query heads
    int seq_len;       // query rows per head
    int kv_len;        // key/value positions per kv head
    int head_dim;      // elements per Q/K/V/Out row; must be <= MAX_HEAD_DIM
    float scale;       // multiplier on raw Q.K dot products (presumably 1/sqrt(head_dim) -- set by host)
};

// Capacity of the per-row output accumulator held in shared memory.
const uint MAX_HEAD_DIM = 256u;

// Running (unnormalized) weighted sum of V rows for this workgroup's output row.
shared float accum[MAX_HEAD_DIM];
// Scratch for the block-wide dot-product reduction: one slot per invocation. Sized from
// the workgroup size so changing local_size_x cannot make reduction[tid] go out of bounds.
shared float reduction[gl_WorkGroupSize.x];
// Online-softmax state, written only by invocation 0 and broadcast via barrier():
shared float s_max_score;   // running max of the scaled scores seen so far
shared float s_sum_exp;     // running sum of exp(score - s_max_score)
shared float s_weight;      // exp(score - max) for the current key; reused as 1/sum at the end
shared float s_correction;  // exp(old_max - new_max): rescale factor for prior accumulation
// Computes one output row: head = gl_WorkGroupID.x, query row s = gl_WorkGroupID.y.
//
// Keys are visited one position at a time. The whole workgroup computes Q.K with a
// strided partial dot product plus a shared-memory tree reduction. The score is then folded
// into a running ("online") softmax, so the score row is never materialized:
//   m   = running max score
//   l   = running sum of exp(score_j - m)
//   acc = running sum of exp(score_j - m) * V_j
// When a new max appears, l and acc are rescaled by exp(m_old - m_new). The final row is acc / l.
// A row that sees no keys (l == 0) is written as zeros.
void main() {
    uint head = gl_WorkGroupID.x;
    uint s = gl_WorkGroupID.y;
    uint tid = gl_LocalInvocationID.x;
    uint nt = gl_WorkGroupSize.x;

    // accum is a fixed-size shared array; a larger (or negative, which wraps in uint) head_dim
    // would index past its end. head_dim is a push constant, so this branch is uniform across
    // the workgroup and returning here cannot leave other invocations waiting at a barrier().
    // NOTE: in that case this output row is left unwritten.
    if (uint(head_dim) > uint(accum.length())) {
        return;
    }

    // GQA: each run of (num_heads / num_kv_heads) consecutive query heads shares one K/V head.
    // Assumes num_heads is a non-zero multiple of num_kv_heads.
    uint kv_head = head / (uint(num_heads) / uint(num_kv_heads));

    // Absolute position of this query row inside the kv sequence. It is computed signed on purpose:
    // when kv_len < seq_len, early rows have a negative position and see no keys.
    // (Unsigned arithmetic wrapped to a huge value and silently disabled the causal mask.)
    int q_abs_pos = kv_len - seq_len + int(s);
    // Visible keys are positions [0, min(q_abs_pos, kv_len - 1)]; the range may be empty.
    uint num_keys = uint(max(min(q_abs_pos + 1, kv_len), 0));

    for (uint d = tid; d < uint(head_dim); d += nt) {
        accum[d] = 0.0;
    }
    if (tid == 0u) {
        // -FLT_MAX rather than -inf: if a score were -inf, exp(score - max) with max == -inf
        // would be exp(NaN). With a finite seed the first correction is exp(-huge) == 0.
        s_max_score = -3.402823466e+38;
        s_sum_exp = 0.0;
    }
    barrier();

    uint q_base = head * uint(seq_len) * uint(head_dim) + s * uint(head_dim);
    // K and V share the same layout, so one base offset serves both (loop-invariant part hoisted).
    uint kv_head_base = kv_head * uint(kv_len) * uint(head_dim);

    for (uint kv_pos = 0u; kv_pos < num_keys; kv_pos++) {
        uint kv_base = kv_head_base + kv_pos * uint(head_dim);

        // Strided partial dot product Q[s] . K[kv_pos]; invocations past head_dim contribute 0.
        float local_dot = 0.0;
        for (uint d = tid; d < uint(head_dim); d += nt) {
            local_dot += q_data[q_base + d] * k_data[kv_base + d];
        }
        reduction[tid] = local_dot;
        barrier();
        // Tree reduction over the workgroup; requires nt to be a power of two.
        for (uint stride = nt / 2u; stride > 0u; stride >>= 1) {
            if (tid < stride) {
                reduction[tid] += reduction[tid + stride];
            }
            barrier();
        }

        // A single invocation advances the online-softmax state and publishes the factors.
        if (tid == 0u) {
            float score = reduction[0] * scale;
            float old_max = s_max_score;
            if (score > old_max) {
                s_correction = exp(old_max - score);
                s_sum_exp *= s_correction;
                s_max_score = score;
            } else {
                s_correction = 1.0;
            }
            s_weight = exp(score - s_max_score);
            s_sum_exp += s_weight;
        }
        barrier();

        float w = s_weight;
        float c = s_correction;
        for (uint d = tid; d < uint(head_dim); d += nt) {
            accum[d] = accum[d] * c + w * v_data[kv_base + d];
        }
        // Ensures all invocations have consumed s_weight/s_correction/reduction before
        // the next iteration overwrites them.
        barrier();
    }

    // Reuse s_weight to broadcast 1/sum; guard the empty (no visible keys) row.
    if (tid == 0u) {
        s_weight = (s_sum_exp > 0.0) ? 1.0 / s_sum_exp : 0.0;
    }
    barrier();

    float inv_sum = s_weight;
    uint out_base = head * uint(seq_len) * uint(head_dim) + s * uint(head_dim);
    for (uint d = tid; d < uint(head_dim); d += nt) {
        out_data[out_base + d] = accum[d] * inv_sum;
    }
}