#version 450
// Scaled dot-product attention with online (flash-style) softmax, mirroring
// the CPU reference semantics. One invocation per (batch, head, query).
// Layouts: bhsd=1 ⇒ [B,H,S,D]; bhsd=0 ⇒ [B,S,H,D] with per-tensor row stride.
// KV heads are assumed equal to query heads (GQA expanded upstream).
// mask_kind: 0 none, 1 causal, 2 sliding-window, 3 custom (per b,k threshold),
// 4 bias (additive per b,h,q,k). q_offset = k_s - q_s, clamped to 0 when k_s <= q_s (KV-cache).
// 64 invocations per workgroup along x; main() maps each global x index to
// one (batch, head, query) triple.
layout(local_size_x = 64) in;
// Single float storage buffer at binding 0. main() reads Q, K, V and the
// optional mask/bias from it and writes the output into it; each tensor is
// located by a push-constant offset used directly as an index into data[]
// (i.e. counted in floats).
layout(std430, binding = 0) buffer Arena { float data[]; };
// Capacity of the per-invocation output accumulator in main(). main() returns
// without writing anything when pc.dh exceeds this value.
const uint MAX_DH = 256u;
// Dispatch parameters. All offsets and strides are used as indices into
// Arena.data, so they are counted in float elements.
layout(push_constant) uniform PC {
uint b; // batch count
uint nh; // head count (the same count indexes Q and K/V)
uint q_s; // query sequence length
uint k_s; // key/value sequence length
uint dh; // head dimension; must not exceed MAX_DH or main() exits early
uint q_off; // index of the first Q element in data[]
uint k_off; // index of the first K element in data[]
uint v_off; // index of the first V element in data[]
uint o_off; // index of the first output element in data[]
uint qrs; // [B,S,H,D] row strides; unused when bhsd=1
uint krs; // K stride between consecutive sequence positions ([B,S,H,D] only)
uint vrs; // V stride between consecutive sequence positions ([B,S,H,D] only)
uint bhsd; // nonzero selects [B,H,S,D] addressing, zero selects [B,S,H,D]
uint mask_kind; // 0 none, 1 causal, 2 sliding window, 3 custom, 4 additive bias
uint mask_off; // index of the mask (kind 3, [b,k]) or bias (kind 4, [b,h,q,k]) in data[]
uint window; // sliding-window reach: keys below (q_offset + q) - window are masked
float scale; // multiplier applied to the Q.K dot product
float neg; // masked-out score value
float thr; // custom-mask threshold: mask values below thr mark the key as masked
} pc;
// Computes one output row of attention for a single (batch, head, query):
//   out = softmax(scale * q.K^T + mask/bias) . V
// using a single pass over the keys with a running maximum (m), running
// normaliser (l) and running weighted sum of V rows (acc).
//
// Everything that does not depend on the key index (layout branch, row bases,
// mask row base, sliding-window lower bound) is computed once before the key
// loop. Index arithmetic is unsigned, so the regrouped products yield the same
// indices as summing the individual terms.
void main() {
    uint gid = gl_GlobalInvocationID.x;
    uint total = pc.b * pc.nh * pc.q_s;
    // Surplus invocations in the last workgroup have nothing to do. This also
    // guarantees q_s and nh are nonzero before they are used as divisors.
    if (gid >= total) { return; }
    // acc[] below only has MAX_DH lanes.
    if (pc.dh > MAX_DH) { return; }

    uint qi = gid % pc.q_s;
    uint hi = (gid / pc.q_s) % pc.nh;
    uint bi = gid / (pc.q_s * pc.nh);

    // Row addressing, resolved once per invocation.
    uint q_base;  // first element of this query row
    uint o_base;  // first element of this output row
    uint k_row0;  // first element of key row 0 for this (batch, head)
    uint v_row0;  // first element of value row 0 for this (batch, head)
    uint k_step;  // distance between consecutive key rows
    uint v_step;  // distance between consecutive value rows
    if (pc.bhsd != 0u) {
        // [B,H,S,D]: rows of one head are contiguous.
        uint bh = bi * pc.nh + hi;
        uint q_row = bh * pc.q_s * pc.dh + qi * pc.dh;
        uint kv_head = bh * pc.k_s * pc.dh;
        q_base = pc.q_off + q_row;
        o_base = pc.o_off + q_row;
        k_row0 = pc.k_off + kv_head;
        v_row0 = pc.v_off + kv_head;
        k_step = pc.dh;
        v_step = pc.dh;
    } else {
        // [B,S,H,D]: inputs use their own row strides; the output is written
        // with a dense row stride of nh * dh.
        uint q_row = bi * pc.q_s + qi;
        uint o_stride = pc.nh * pc.dh;
        uint head_col = hi * pc.dh;
        q_base = pc.q_off + q_row * pc.qrs + head_col;
        o_base = pc.o_off + q_row * o_stride + head_col;
        k_row0 = pc.k_off + bi * pc.k_s * pc.krs + head_col;
        v_row0 = pc.v_off + bi * pc.k_s * pc.vrs + head_col;
        k_step = pc.krs;
        v_step = pc.vrs;
    }

    // Absolute position of this query within the key sequence.
    uint q_offset = (pc.k_s > pc.q_s) ? (pc.k_s - pc.q_s) : 0u;
    uint abs_q = q_offset + qi;

    // Key-independent mask state.
    uint win_lo = (abs_q > pc.window) ? (abs_q - pc.window) : 0u;
    uint mask_row = 0u;  // index of the key-0 entry of the mask/bias row
    if (pc.mask_kind == 3u) {
        mask_row = pc.mask_off + bi * pc.k_s;
    } else if (pc.mask_kind == 4u) {
        mask_row = pc.mask_off + ((bi * pc.nh + hi) * pc.q_s + qi) * pc.k_s;
    }

    float acc[MAX_DH];
    for (uint d = 0u; d < pc.dh; d++) { acc[d] = 0.0; }
    // Running maximum starts at the most negative finite float, so m - m_new
    // is always a finite difference on the first update.
    float m = -3.402823466e38;
    float l = 0.0;

    for (uint ki = 0u; ki < pc.k_s; ki++) {
        // Masking.
        bool masked = false;
        float bias = 0.0;
        if (pc.mask_kind == 1u) {            // causal
            masked = (ki > abs_q);
        } else if (pc.mask_kind == 2u) {     // sliding window
            masked = (ki < win_lo) || (ki > abs_q);
        } else if (pc.mask_kind == 3u) {     // custom (per b,k)
            masked = data[mask_row + ki] < pc.thr;
        } else if (pc.mask_kind == 4u) {     // additive bias (per b,h,q,k)
            bias = data[mask_row + ki];
        }

        uint k_base = k_row0 + ki * k_step;
        uint v_base = v_row0 + ki * v_step;

        // Masked keys take the configured score and skip the dot product.
        float score;
        if (masked) {
            score = pc.neg;
        } else {
            float dot = 0.0;
            for (uint d = 0u; d < pc.dh; d++) {
                dot += data[q_base + d] * data[k_base + d];
            }
            score = dot * pc.scale + bias;
        }

        // Online softmax update: rescale what has been accumulated so far to
        // the new running maximum, then add this key's contribution.
        float m_new = max(m, score);
        float corr = exp(m - m_new);
        float p = exp(score - m_new);
        l = l * corr + p;
        for (uint d = 0u; d < pc.dh; d++) {
            acc[d] = acc[d] * corr + p * data[v_base + d];
        }
        m = m_new;
    }

    // A zero normaliser (e.g. no keys) produces a zero row instead of
    // dividing by zero.
    float inv = (l > 0.0) ? (1.0 / l) : 0.0;
    for (uint d = 0u; d < pc.dh; d++) {
        data[o_base + d] = acc[d] * inv;
    }
}