#version 450
// Fused scaled-dot-product attention (flash-attention style, online softmax) for one query row.
// out[bh, qi, :] = sum_j softmax_j( scale * dot(Q[bh,qi], K[bh,j]) (+ causal mask) ) * V[bh, j, :]
// One invocation computes one full output row (head_dim D values), streaming over the Lk keys with
// the numerically-stable running-max / running-sum recurrence -- so the [Lq, Lk] score matrix is
// NEVER materialized (the memory win over eager bmm + softmax + bmm). Q/K/V are contiguous
// [BH, L, D] f32. `causal` masks key j > (qi + key_len - q_len) (the standard aligned causal rule,
// so it is correct for both prefill, q_len==k_len, and decode, q_len==1).
// 1-D workgroups of 64 invocations. main() uses gl_GlobalInvocationID.x as the flat row index
// (b * lq + qi) and discards invocations past bh * lq. The host presumably dispatches
// ceil(bh * lq / 64) workgroups in x -- confirm at the call site.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
// All four tensors are contiguous row-major f32: element [s, i, t] lives at (s * L + i) * D + t.
layout(set = 0, binding = 0) readonly buffer Q { float q[]; }; // [BH, Lq, D]
layout(set = 0, binding = 1) readonly buffer K { float k[]; }; // [BH, Lk, D]
layout(set = 0, binding = 2) readonly buffer V { float v[]; }; // [BH, Lk, D]
layout(set = 0, binding = 3) writeonly buffer O { float o[]; }; // [BH, Lq, D]
// Push constants: six 4-byte scalars. The host-side struct must use this exact member order.
layout(push_constant) uniform Pc {
uint bh; // number of (batch*head) slices
uint lq; // query length
uint lk; // key length
uint d; // head dim
float scale; // 1/sqrt(d) (or any caller scale)
uint causal; // 0/1
};
const uint DMAX = 256u; // element count of main()'s per-invocation acc[] array (Qwen3 head_dim is 128)
// Computes one attention output row per invocation:
//   o[b, qi, :] = sum_j softmax_j(scale * dot(q[b, qi, :], k[b, j, :])) * v[b, j, :]
// where gid = gl_GlobalInvocationID.x = b * lq + qi. Invocations with gid >= bh * lq return
// immediately. No shared memory or barriers are used, so the early return is safe.
// Scores are streamed with the online-softmax recurrence (running max m, running denominator l),
// so the [lq, lk] score matrix is never stored. A row with no visible keys writes zeros.
void main() {
    uint gid = gl_GlobalInvocationID.x;
    uint total = bh * lq;
    if (gid >= total) {
        return;
    }
    uint b = gid / lq;          // which bh slice
    uint qi = gid - b * lq;     // query index within the slice (kept for the causal bound)
    uint row_base = gid * d;    // == (b * lq + qi) * d: start of this row in both Q and O
    uint kv_slice = b * lk * d; // start of this slice in both K and V

    // Exclusive upper bound on visible keys. Aligned causal rule: key j is visible iff
    // j <= qi + (lk - lq). This is evaluated in signed arithmetic. When lk < lq, the early query
    // rows must see no keys at all. The unsigned form (lk - lq) would wrap around and, after
    // clamping to lk, expose every key to exactly those rows.
    uint key_end = lk;
    if (causal != 0u) {
        int lim = int(qi) + int(lk) - int(lq) + 1;
        key_end = uint(clamp(lim, 0, int(lk)));
    }

    // acc[] holds at most DMAX output columns, so the head dim is covered in column tiles of
    // DMAX. For d <= DMAX there is exactly one tile, the same work as a single pass. For larger d,
    // each extra tile re-runs the score/softmax recurrence, which is redundant but deterministic,
    // so every tile sees identical m and l. Scores always use the full d-length dot product, and
    // acc[] is never indexed out of bounds.
    float acc[DMAX];
    for (uint c0 = 0u; c0 < d; c0 += DMAX) {
        uint cw = min(d - c0, DMAX); // number of columns in this tile
        for (uint t = 0u; t < cw; t++) { acc[t] = 0.0; }
        float m = -3.402823466e38; // running max of the scores seen so far (starts at -FLT_MAX)
        float l = 0.0;             // running softmax denominator, relative to m
        for (uint j = 0u; j < key_end; j++) {
            uint kv_row = kv_slice + j * d; // row j of this slice in K and V
            float s = 0.0;
            for (uint t = 0u; t < d; t++) {
                s += q[row_base + t] * k[kv_row + t];
            }
            s *= scale;
            // Online-softmax update: when a new max appears, rescale the previous accumulator
            // and denominator by exp(m_old - m_new) so every term stays relative to m_new.
            float m_new = max(m, s);
            float corr = exp(m - m_new);
            float p = exp(s - m_new);
            for (uint t = 0u; t < cw; t++) {
                acc[t] = acc[t] * corr + p * v[kv_row + c0 + t];
            }
            l = l * corr + p;
            m = m_new;
        }
        // l is 0 only if no key was visited (key_end == 0); write zeros rather than divide by 0.
        float inv = (l > 0.0) ? (1.0 / l) : 0.0;
        for (uint t = 0u; t < cw; t++) {
            o[row_base + c0 + t] = acc[t] * inv;
        }
    }
}