#version 450
#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : enable
#if USE_SUBGROUP_CLUSTERED
#extension GL_KHR_shader_subgroup_clustered : enable
#endif
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif
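// FLOAT_TYPE is normally injected by the build pipeline (e.g. float or
// float16_t); assume that here, but default to float so the file also
// compiles standalone.
#ifndef FLOAT_TYPE
#define FLOAT_TYPE float
#endif
// Per (sequence, head) this shader advances an S_V x S_V recurrent state S one
// token at a time (a gated delta-rule update; see the loop in main()):
//   S~[c][i] = exp(g[i]) * S[c][i]            (g scalar when KDA == 0)
//   kv[c]    = sum_i S~[c][i] * k[i]
//   S'[c][i] = S~[c][i] + beta * (v[c] - kv[c]) * k[i]
//   out[c]   = scale * sum_i S'[c][i] * q[i]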
// Caller guarantees valid spec constants: LANES_PER_COLUMN divides SUBGROUP_SIZE,
// S_V % COLS_PER_WG == 0, and S_V % LANES_PER_COLUMN == 0, so no bounds checking
// is needed.
layout(constant_id = 0) const uint S_V = 128;
layout(constant_id = 1) const uint KDA = 0;
layout(constant_id = 2) const uint SUBGROUP_SIZE = 32;
layout(constant_id = 3) const uint LANES_PER_COLUMN = 32;
const uint COLS_PER_WG = SUBGROUP_SIZE / LANES_PER_COLUMN;
const uint ROWS_PER_LANE = S_V / LANES_PER_COLUMN;
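// local_size_x is specialized to SUBGROUP_SIZE (constant_id = 2), so each
// workgroup is exactly one subgroup. A subgroup covers COLS_PER_WG state
// columns; within a column, each lane owns ROWS_PER_LANE elements strided by
// LANES_PER_COLUMN.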
layout(local_size_x_id = 2, local_size_y = 1, local_size_z = 1) in;
layout(push_constant) uniform Parameters {
uint H;               // number of heads
uint n_tokens;
uint n_seqs;
uint s_off;           // element offset of the state output within data_dst
uint sq1, sq2, sq3;   // Q/K element strides: per-head, per-token, per-sequence
uint sv1, sv2, sv3;   // V element strides: per-head, per-token, per-sequence
uint sb1, sb2, sb3;   // gate/beta element strides: per-head, per-token, per-sequence
uint neq1, rq3;       // Q head extent (head_id wraps modulo neq1); sequences per Q slice (seq_id / rq3)
float scale;
};
layout(binding = 0) readonly buffer QBuf { FLOAT_TYPE data_q[]; };
layout(binding = 1) readonly buffer KBuf { FLOAT_TYPE data_k[]; };
layout(binding = 2) readonly buffer VBuf { FLOAT_TYPE data_v[]; };
layout(binding = 3) readonly buffer GBuf { FLOAT_TYPE data_g[]; };         // log-space gates, exponentiated below
layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; };   // per-token update strength
layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; }; // initial recurrent state
layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; };              // attention output, plus updated state at s_off
#if !USE_SUBGROUP_ADD && !USE_SUBGROUP_CLUSTERED
shared FLOAT_TYPE temp[SUBGROUP_SIZE];
// Tree-reduce in shared memory across each group of LANES_PER_COLUMN lanes,
// used when neither subgroup reduction path is available.
FLOAT_TYPE reduce_add_shmem(FLOAT_TYPE partial) {
const uint lane = gl_SubgroupInvocationID;
temp[lane] = partial;
barrier();
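// XOR butterfly: since s < LANES_PER_COLUMN, lane ^ s never leaves this lane's
// group, so each group of LANES_PER_COLUMN lanes reduces independently. Each
// iteration reads, barriers, accumulates, then barriers again so no lane reads
// a slot that another lane is still writing.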
[[unroll]] for (uint s = LANES_PER_COLUMN / 2u; s > 0; s >>= 1u) {
FLOAT_TYPE other = temp[lane ^ s];
barrier();
temp[lane] += other;
barrier();
}
const FLOAT_TYPE result = temp[lane];
barrier();
return result;
}
#endif
// clusterSize for subgroupClusteredAdd must be a compile-time constant; branch on spec constant
FLOAT_TYPE reduce_partial(FLOAT_TYPE partial) {
switch (LANES_PER_COLUMN) {
case 1u:
return partial;
#if USE_SUBGROUP_CLUSTERED
// Workaround for GLSL requiring a literal constant for the cluster size.
// The branches should all fold away.
case 2u:
return subgroupClusteredAdd(partial, 2u);
case 4u:
return subgroupClusteredAdd(partial, 4u);
case 8u:
return subgroupClusteredAdd(partial, 8u);
case 16u:
return subgroupClusteredAdd(partial, 16u);
case 32u:
return subgroupClusteredAdd(partial, 32u);
case 64u:
return subgroupClusteredAdd(partial, 64u);
#endif
default:
#if USE_SUBGROUP_ADD
// subgroupAdd reduces across the whole subgroup, so this path assumes
// LANES_PER_COLUMN == SUBGROUP_SIZE (a single column per subgroup).
return subgroupAdd(partial);
#else
return reduce_add_shmem(partial);
#endif
}
}
void main() {
const uint head_id = gl_WorkGroupID.x;
const uint seq_id = gl_WorkGroupID.y;
const uint lane = gl_SubgroupInvocationID % LANES_PER_COLUMN;
const uint col = gl_WorkGroupID.z * COLS_PER_WG + (gl_SubgroupInvocationID / LANES_PER_COLUMN);
const uint iq1 = head_id % neq1;
const uint iq3 = seq_id / rq3;
const uint state_size = S_V * S_V;
const uint state_base = (seq_id * H + head_id) * state_size;
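// Cache this lane's shard of the state in registers: elements lane,
// lane + LANES_PER_COLUMN, ... of state column `col`.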
FLOAT_TYPE s_shard[ROWS_PER_LANE];
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]);
}
uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
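// The recurrence is sequential in t: each step consumes the state produced by
// the previous one, so tokens cannot be processed in parallel here.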
for (uint t = 0; t < n_tokens; t++) {
const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
const uint k_off = q_off;
const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;
const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1;
const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]);
FLOAT_TYPE k_reg[ROWS_PER_LANE];
FLOAT_TYPE q_reg[ROWS_PER_LANE];
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
const uint i = r * LANES_PER_COLUMN + lane;
k_reg[r] = FLOAT_TYPE(data_k[k_off + i]);
q_reg[r] = FLOAT_TYPE(data_q[q_off + i]);
}
FLOAT_TYPE g_exp[ROWS_PER_LANE];
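// KDA == 0 uses a single scalar gate per (token, head); otherwise the gate is
// a per-channel vector of length S_V, indexed like k.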
if (KDA == 0) {
const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off]));
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
g_exp[r] = g_val;
}
} else {
const uint g_base = gb_off * S_V;
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
const uint i = r * LANES_PER_COLUMN + lane;
g_exp[r] = exp(FLOAT_TYPE(data_g[g_base + i]));
}
}
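// Delta-rule core: kv_col is the gated state's prediction of v for this
// column; delta_col is the beta-scaled correction toward the observed v.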
const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
FLOAT_TYPE kv_shard = FLOAT_TYPE(0.0);
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
kv_shard += g_exp[r] * s_shard[r] * k_reg[r];
}
FLOAT_TYPE kv_col = reduce_partial(kv_shard);
FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
FLOAT_TYPE attn_partial = FLOAT_TYPE(0.0);
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
s_shard[r] = g_exp[r] * s_shard[r] + k_reg[r] * delta_col;
attn_partial += s_shard[r] * q_reg[r];
}
FLOAT_TYPE attn_col = reduce_partial(attn_partial);
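// Every reduction variant leaves the full sum in all lanes of the column
// group; lane 0 writes it once.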
if (lane == 0) {
data_dst[attn_off + col] = attn_col * scale;
}
attn_off += S_V * H; // advance one token: output layout is [seq][token][head][S_V]
}
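// Persist the updated state into data_dst at s_off so the next dispatch can
// resume the recurrence.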
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r];
}
}