layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t HSK = 32;
layout (constant_id = 4) const uint32_t HSV = 32;
layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;
layout (constant_id = 7) const uint32_t row_split = 1;
layout (constant_id = 8) const uint32_t SubGroupSize = 32;
layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0;
layout (constant_id = 10) const uint32_t Flags = 0;
layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0;
// ggml_type enumerant for K/V
layout (constant_id = 12) const uint32_t FaTypeK = 0;
layout (constant_id = 13) const uint32_t FaTypeV = 0;
// sizeof(decode buffer): quants -> ggml block size; F32 -> 16 (decodeBufF32 vec4).
layout (constant_id = 14) const uint32_t FaBlockBytesK = 2;
layout (constant_id = 15) const uint32_t FaBlockBytesV = 2;
const bool USE_MASK_OPT = (Flags & 1) != 0;
const bool MASK_ENABLE = (Flags & 2) != 0;
const bool LOGIT_SOFTCAP = (Flags & 4) != 0;
const bool OLD_AMD_WINDOWS = (Flags & 8) != 0;
// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
const uint32_t HSK_pad = (HSK + 15) & ~15;
const uint32_t HSV_pad = (HSV + 15) & ~15;
const bool KV_bounds_check = Clamp != 0;
layout (push_constant) uniform parameter {
uint32_t N;
uint32_t KV;
uint32_t ne1;
uint32_t ne2;
uint32_t ne3;
uint32_t neq2;
uint32_t neq3;
uint32_t nek2;
uint32_t nek3;
uint32_t nev2;
uint32_t nev3;
uint32_t nem1;
uint32_t nem2;
uint32_t nem3;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
float scale;
float max_bias;
float logit_softcap;
uint32_t mask_n_head_log2;
float m0;
float m1;
uint32_t gqa_ratio;
uint32_t split_kv;
uint32_t k_num;
} p;
#define SINK_ENABLE_BIT (1<<24)
#define N_LOG2_MASK 0xFFFF
layout (binding = 4) readonly buffer S {float data_s[];};
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
layout (binding = 5) writeonly buffer OV4 {D_TYPEV4 data_ov4[];};
layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};
#define MASK_OPT_ALL_NEG_INF 1
#define MASK_OPT_ALL_ZERO 2
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
// FaTypeK / FaTypeV spec constant values. These mirror enum ggml_type so the
// host can pass the type directly. Keep in sync with ggml.h.
#define FA_TYPE_F32 0u
#define FA_TYPE_F16 1u
#define FA_TYPE_Q4_0 2u
#define FA_TYPE_Q4_1 3u
#define FA_TYPE_Q5_0 6u
#define FA_TYPE_Q5_1 7u
#define FA_TYPE_Q8_0 8u
#define FA_TYPE_Q1_0 41u
// Number of matrix elements per buffer block, derived from the K/V type spec
// constant. F32 is treated as a vec4 "block" of 4 floats. F16 uses block size 1
// and bypasses the dequant path entirely. Quants follow their ggml block sizes.
uint fa_block_elems(uint ty) {
switch (ty) {
case FA_TYPE_F32: return 4u;
case FA_TYPE_F16: return 1u;
case FA_TYPE_Q4_0: return uint(QUANT_K_Q4_0);
case FA_TYPE_Q4_1: return uint(QUANT_K_Q4_1);
case FA_TYPE_Q5_0: return uint(QUANT_K_Q5_0);
case FA_TYPE_Q5_1: return uint(QUANT_K_Q5_1);
case FA_TYPE_Q8_0: return uint(QUANT_K_Q8_0);
case FA_TYPE_Q1_0: return uint(QUANT_K_Q1_0); // cm2-only, harmless elsewhere
default: return 1u;
}
}
// QUANT_R_MMQ for FA-eligible K types. Q4_*/Q5_* store two nibbles per byte
// (R==2); Q8_0 stores one byte per element (R==1). Used to derive the number
// of int32s per 32-element block on the MMQ K path: ints_per_block == 8 / R.
uint fa_quant_r_mmq(uint ty) {
switch (ty) {
case FA_TYPE_Q4_0: return uint(QUANT_R_Q4_0);
case FA_TYPE_Q4_1: return uint(QUANT_R_Q4_1);
case FA_TYPE_Q5_0: return uint(QUANT_R_Q5_0);
case FA_TYPE_Q5_1: return uint(QUANT_R_Q5_1);
case FA_TYPE_Q8_0: return uint(QUANT_R_Q8_0);
default: return 1u;
}
}
// These can't be `const` globals because GLSL forbids function calls in global
// const initializers, even when the spec constants would let the driver fold
// them. Macros expand at the use site and fold after specialization.
#define BLOCK_SIZE_K fa_block_elems(FaTypeK)
#define BLOCK_SIZE_V fa_block_elems(FaTypeV)
// F16 reads f16 elements directly from the binding; everything else routes
// through dequantize4 / the MMQ helpers to unpack from the packed block layout.
#define USE_DECODE_K (FaTypeK != FA_TYPE_F16)
#define USE_DECODE_V (FaTypeV != FA_TYPE_F16)
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
// Store column zero. This is used to save per-row m and L values for split_k.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c == 0) {
uint32_t offset = iq2 + r;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}
// Load the slope matrix, indexed by Q's dimension 2.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
// Load the sink value, indexed by Q's dimension 2.
ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
return ACC_TYPE(data_s[h]);
}
uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
gqa_iq1, iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
q_stride, k_stride, v_stride, m_stride;
void init_indices()
{
N = p.N;
KV = p.KV;
if (p.k_num > 1) {
if (p.gqa_ratio > 1) {
i = 0;
// batch and split_k share gl_WorkGroupID.x
gqa_iq1 = gl_WorkGroupID.x / p.k_num;
split_k_index = gl_WorkGroupID.x % p.k_num;
} else {
gqa_iq1 = 0;
split_k_index = gl_WorkGroupID.x % p.k_num;
i = gl_WorkGroupID.x / p.k_num;
}
} else if (p.gqa_ratio > 1) {
i = 0;
gqa_iq1 = gl_WorkGroupID.x;
split_k_index = 0;
} else {
i = gl_WorkGroupID.x;
gqa_iq1 = 0;
split_k_index = 0;
}
Tr = CEIL_DIV(N, Br);
start_j = split_k_index * p.split_kv / Bc;
end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
iq2 = gl_WorkGroupID.y * p.gqa_ratio;
iq3 = gl_WorkGroupID.z;
// broadcast factors
rk2 = p.neq2/p.nek2;
rk3 = p.neq3/p.nek3;
rv2 = p.neq2/p.nev2;
rv3 = p.neq3/p.nev3;
// k indices
ik3 = iq3 / rk3;
ik2 = iq2 / rk2;
// v indices
iv3 = iq3 / rv3;
iv2 = iq2 / rv2;
// nb?1 are already divided by the type size and are in units of elements.
// When using grouped query attention, Q is indexed by iq2, so the stride
// should be nb02 (which is in bytes).
q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
k_stride = p.nb11;
v_stride = p.nb21;
// When using grouped query attention, all rows use the same mask (stride 0).
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
// that prevents the compiler from folding the "&" through the select
// and breaking the alignment detection.
m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
}
// Bias applied to softmax to stay in fp16 range.
// Based on ggml-cuda issue https://github.com/ggml-org/llama.cpp/issues/18606
const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f;
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
void gqaStore(const in uint32_t r, const in uint32_t c, const in FLOAT_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * HSV / 4 + c;
data_ov4[o_offset + offset] = D_TYPEV4(elems);
}