#version 450
#extension GL_EXT_control_flow_attributes : enable
#ifdef COOPMAT2
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_KHR_memory_scope_semantics : enable
#endif
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_KHR_memory_scope_semantics : enable
#endif
#include "types.glsl"
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
// i.e. the leftmost listed dimension has the smallest stride and the rightmost
// the largest. A_TYPE / B_TYPE / D_TYPE are not defined in this file; they are
// expected to be provided by the build or by the included header.
// Binding 0 (read-only): convolution kernel weights.
layout(binding = 0) readonly buffer A {
A_TYPE knl_data[];
}; // src0 - kernel: [KW, KH, KD, IC*OC]
// Binding 1 (read-only): input activations.
layout(binding = 1) readonly buffer B {
B_TYPE src_data[];
}; // src1 - input: [IW, IH, ID, IC*N] -- channel_first format
// Binding 2 (write-only): convolution result.
layout(binding = 2) writeonly buffer D {
D_TYPE dst_data[];
}; // dst - result: [OW, OH, OD, OC*N]
// Per-dispatch parameters supplied by the host as push constants (instance name: p).
layout(push_constant) uniform parameter {
// I/O channels, batch size
uint32_t OC;
uint32_t IC;
uint32_t N;
// Tensor spatial sizes: input, output
uint32_t IW;
uint32_t IH;
uint32_t ID;
uint32_t OW;
uint32_t OH;
uint32_t OD;
// Strides in elements
// nb01/nb02/nb03: kernel strides for the KH, KD and combined (OC, IC) indices
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
// nb11/nb12/nb13: input strides for the IH, ID and combined (N, IC) indices
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
// nb1/nb2/nb3: output strides for the OH, OD and combined (N, OC) indices
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
// fastdiv helper values
// (multiplier, shift) pairs consumed by fastdiv() in split_npq(); presumably
// precomputed for the divisors OW, OW*OH and OW*OH*OD respectively -- confirm
// against the host code.
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
uint32_t OWOHODmp; uint32_t OWOHODL;
}
p;
// 1D workgroup; its size in x comes from specialization constant 0.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Blocktile sizes
// BS_K x BS_NPQ is the output tile computed by one workgroup; BS_CRS is the
// depth of the reduction slice staged in shared memory per iteration.
layout(constant_id = 1) const uint BS_K = 128;
layout(constant_id = 2) const uint BS_CRS = 16;
layout(constant_id = 3) const uint BS_NPQ = 128;
// Thread-tile sizes
// TS_K: number of output rows (K direction) each thread accumulates in the scalar path
layout(constant_id = 4) const uint TS_K = 8;
// SHMEM_PAD: extra elements appended to each row of the Ash / Bsh shared arrays
layout(constant_id = 5) const uint SHMEM_PAD = 4;
// Stride, padding, dilation
// suffix 0 applies to the width axis, 1 to height, 2 to depth
layout(constant_id = 6) const uint s0 = 1;
layout(constant_id = 7) const uint s1 = 1;
layout(constant_id = 8) const uint s2 = 1;
layout(constant_id = 9) const uint p0 = 0;
layout(constant_id = 10) const uint p1 = 0;
layout(constant_id = 11) const uint p2 = 0;
layout(constant_id = 12) const uint d0 = 1;
layout(constant_id = 13) const uint d1 = 1;
layout(constant_id = 14) const uint d2 = 1;
// Kernel spatial sizes
layout(constant_id = 15) const uint KW = 1;
layout(constant_id = 16) const uint KH = 1;
layout(constant_id = 17) const uint KD = 1;
// when set, skip bounds checks and address clamps (K/CRS/NPQ are tile-aligned)
layout(constant_id = 18) const uint aligned = 0;
// stage cm2 result through shmem (Csh) for coalesced stores. cm1 always does this.
// When 0 in the COOPMAT2 build, results are written per element instead and
// Csh is declared with a single element.
layout(constant_id = 19) const uint csh_store = 0;
#ifdef COOPMAT
// cm1 subgroup tile: each subgroup computes a WM x WN region as a grid of
// TM x TN x TK fragments. Requires WM%TM == WN%TN == BS_K%WM == BS_NPQ%WN ==
// BS_CRS%TK == 0, and WG_SIZE == (BS_K/WM) * (BS_NPQ/WN) * subgroup_size.
layout(constant_id = 20) const uint WM = 32;
layout(constant_id = 21) const uint WN = 32;
// Cooperative-matrix fragment shape: A is TM x TK, B is TK x TN, accumulator is TM x TN.
const uint TM = 16;
const uint TN = 16;
const uint TK = 16;
// Number of fragments per subgroup tile along the K (row) and NPQ (column) directions.
const uint cms_per_row = WM / TM;
const uint cms_per_col = WN / TN;
// Number of subgroup tiles per workgroup tile along K and NPQ.
const uint warps_M = BS_K / WM;
const uint warps_N = BS_NPQ / WN;
#endif
// without padding, ID_idx/IH_idx/IW_idx are in bounds by construction
// (true only when all three paddings are zero; main() uses it to skip the
// spatial bounds test and, for aligned shapes, the address clamp on input loads)
const bool dhw_in_bounds = (p0 == 0) && (p1 == 0) && (p2 == 0);
// Index of this invocation inside its (1D) workgroup.
uint32_t tid = gl_LocalInvocationID.x;
// Number of invocations per workgroup (local_size_x, set by specialization constant 0).
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
// Number of blocks of block_size elements needed to cover work_size elements
// (ceiling division). Computed as quotient plus a remainder test so that the
// intermediate value cannot wrap around for large work_size, unlike
// (block_size + work_size - 1) / block_size.
// block_size must be non-zero.
uint splitWork(uint work_size, uint block_size) {
    uint full_blocks = work_size / block_size;
    uint has_tail = (work_size % block_size != 0u) ? 1u : 0u;
    return full_blocks + has_tail;
}
// Dimensions of the matrix product that implements the convolution:
// a K x CRS kernel matrix times a CRS x NPQ input matrix.
// K: number of output channels (rows of the result).
uint32_t K = p.OC;
// CRS: reduction length = input channels times kernel volume.
uint32_t CRS = p.IC * KD * KH * KW;
// NPQ: number of output positions over all batches (columns of the result).
uint32_t NPQ = p.N * p.OD * p.OH * p.OW;
// Number of blocktiles per input
// (how many BS_CRS-sized slices are needed to cover the CRS dimension)
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
// Element type of the shared-memory tiles: half precision for the cooperative
// matrix builds, single precision for the scalar build.
#if defined(COOPMAT2) || defined(COOPMAT)
#define SHMEM_TYPE float16_t
#else
#define SHMEM_TYPE float
#endif
// Row strides (in elements) of the shared tiles; each row is padded by SHMEM_PAD.
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
const uint32_t Ash_len = BS_K * Ash_stride;
const uint32_t Bsh_len = BS_CRS * Bsh_stride;
// Ash holds the current kernel slice, Bsh the matching slice of gathered input values.
shared SHMEM_TYPE Ash[Ash_len]; // K x CRS
shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ
#if defined(COOPMAT2) || defined(COOPMAT)
// stage matC through shmem so global stores are row-major (NPQ-contiguous)
const uint32_t Csh_stride = BS_NPQ;
#ifdef COOPMAT
const uint32_t Csh_len = BS_K * Csh_stride;
#else
// COOPMAT2: the staging tile is only needed when csh_store is enabled;
// otherwise a single-element array is declared.
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1;
#endif
shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ
#endif
// Threadtile sizes
// TS_NPQ: columns per thread, chosen so that WG_SIZE threads of TS_K x TS_NPQ
// elements cover the BS_K x BS_NPQ blocktile (scalar path).
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
// Number of threadtiles per blocktile
// (along the NPQ direction)
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
/*
Compute
KxCRS @ CRSxNPQ = K x NPQ
K=OC
C=IC
D,R,S=KD,KH,KW
Z,P,Q=OD,OH,OW
*/
// Blocktile coordinates of this workgroup. The NPQ tile index is split over
// the y and z dispatch dimensions with 512 tiles per z step (presumably to
// stay within a dispatch-size limit -- confirm against the host code).
uint32_t B_idx_K = gl_WorkGroupID.x;
uint32_t B_idx_NPQ = gl_WorkGroupID.y + gl_WorkGroupID.z * 512;
// Threadtile coordinates of this invocation inside the blocktile (scalar path).
uint32_t T_y = tid / NT_NPQ;
uint32_t T_x = tid % NT_NPQ;
// Row / column this invocation handles when filling Ash, and the number of
// Ash rows the whole workgroup fills per pass.
// Note: the fill loop advances by ArpWg, so WG_SIZE must be at least BS_CRS.
uint32_t Ar = tid / BS_CRS;
uint32_t Ac = tid % BS_CRS;
const uint32_t ArpWg = WG_SIZE / BS_CRS;
// Same mapping for filling Bsh.
// Note: the fill loop advances by BrpWg, so WG_SIZE must be at least BS_NPQ.
uint32_t Br = tid / BS_NPQ;
uint32_t Bc = tid % BS_NPQ;
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
// Division by a fixed divisor using a multiply-high, an add and a shift.
// mp and L are the multiplier and shift associated with the divisor
// (see init_fastdiv_values in ggml-vulkan.cpp).
// NOTE(review): the add is 32-bit; presumably callers keep n small enough
// that hi + n cannot wrap -- confirm against the host-side helper.
uint fastdiv(uint n, uint mp, uint L) {
    uint hi;
    uint lo;
    // hi:lo = 64-bit product n * mp; only the high word is used
    umulExtended(n, mp, hi, lo);
    const uint biased = hi + n;
    return biased >> L;
}
// Decompose a flat reduction index into (input channel, kernel depth,
// kernel row, kernel column), with the kernel column varying fastest.
void split_crs(uint32_t crs_idx, out uint32_t ic, out uint32_t kd, out uint32_t kh, out uint32_t kw) {
    const uint32_t plane = KH * KW;     // elements per kernel depth slice
    const uint32_t volume = KD * plane; // elements per input channel

    ic = crs_idx / volume;
    const uint32_t in_volume = crs_idx - ic * volume;

    kd = in_volume / plane;
    const uint32_t in_plane = in_volume - kd * plane;

    kh = in_plane / KW;
    kw = in_plane - kh * KW;
}
// Decompose a flat output-position index into (batch, output depth,
// output row, output column), with the output column varying fastest.
// The three divisions go through fastdiv() with the push-constant
// (multiplier, shift) pairs.
void split_npq(uint32_t npq_idx, out uint32_t n, out uint32_t od, out uint32_t oh, out uint32_t ow) {
    const uint32_t plane = p.OW * p.OH; // output elements per depth slice

    n = fastdiv(npq_idx, p.OWOHODmp, p.OWOHODL);
    const uint32_t in_batch = npq_idx - n * p.OD * plane;

    od = fastdiv(in_batch, p.OWOHmp, p.OWOHL);
    const uint32_t in_plane = in_batch - od * plane;

    oh = fastdiv(in_plane, p.OWmp, p.OWL);
    ow = in_plane - oh * p.OW;
}
#ifdef COOPMAT2
#define ACC_TYPE float16_t
// Per-element callback handed to coopMatPerElementNV in main(): writes one
// accumulator element straight to dst_data.
//   r, c  - row / column of the element inside this workgroup's blocktile
//   elem  - accumulator value
// Returns elem unchanged. Elements outside K x NPQ are not written unless
// the shader was specialized with aligned != 0.
ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
{
    const uint32_t K_idx = B_idx_K * BS_K + r;
    const uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c;

    const bool writable = (aligned != 0) || (K_idx < K && NPQ_idx < NPQ);
    if (writable) {
        uint32_t N_idx, OD_idx, OH_idx, OW_idx;
        split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);

        const uint32_t row_term = OH_idx * p.nb1;
        const uint32_t depth_term = OD_idx * p.nb2;
        const uint32_t chan_term = (N_idx * p.OC + K_idx) * p.nb3;
        dst_data[OW_idx + row_term + depth_term + chan_term] = D_TYPE(elem);
    }
    return elem;
}
#endif
// Entry point. Each workgroup produces one BS_K x BS_NPQ tile of the K x NPQ
// result (K x CRS kernel matrix times CRS x NPQ gathered-input matrix). The
// CRS dimension is walked in BS_CRS-sized slices that are staged in the shared
// arrays Ash / Bsh. Three build variants exist:
//   COOPMAT2 - workgroup-scope cooperative matrices
//   COOPMAT  - subgroup-scope cooperative matrix fragments
//   neither  - scalar path with a TS_K x TS_NPQ register tile per invocation
void main() {
// Nothing to do when the whole NPQ tile lies past the end of the output. The
// test depends only on the workgroup ID and push constants, so every
// invocation of the workgroup takes the same branch before any barrier().
if (B_idx_NPQ * BS_NPQ >= NPQ) {
return;
}
// Zero-initialize the accumulator(s) for this tile.
#ifdef COOPMAT2
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
#elif defined(COOPMAT)
// One TM x TN accumulator fragment per cell of the subgroup's fragment grid,
// indexed as cm_col * cms_per_row + cm_row.
coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
sums[i] = coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0);
}
// Position of this subgroup's WM x WN region inside the blocktile.
const uint warp_r = gl_SubgroupID / warps_N;
const uint warp_c = gl_SubgroupID % warps_N;
#else
float regC[TS_K][TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = 0.0;
}
}
#endif
/* Advance block in CRS dim */
[[dont_unroll]] for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
// The Ash column handled by this invocation is fixed (Ac), so its kernel
// coordinates are decoded once per CRS slice.
uint32_t CRS_idx_a = B_idx_CRS * BS_CRS + Ac;
uint32_t IC_idx_a;
uint32_t KD_idx_a;
uint32_t KH_idx_a;
uint32_t KW_idx_a;
split_crs(CRS_idx_a, IC_idx_a, KD_idx_a, KH_idx_a, KW_idx_a);
/* Load kernel to A_block: (BS_K x BS_CRS)*/
// UNROLL is not defined in this file; it is expected to be supplied by the build.
UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
uint32_t B_ly = r_offset + Ar;
uint32_t B_lx = Ac;
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + KD_idx_a * p.nb02 + (K_idx * p.IC + IC_idx_a) * p.nb03;
// Clamp the address so the load itself stays inside the buffer; the value is
// replaced by zero just below when the element is logically out of range.
// (assumes the kernel buffer holds K * CRS elements)
if (aligned == 0) {
knl_idx = min(knl_idx, K * CRS - 1);
}
float val = knl_data[knl_idx];
if (aligned == 0 && (K_idx >= K || CRS_idx_a >= CRS)) {
val = 0.0;
}
Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
}
/* Load input to B_block: (BS_CRS x BS_NPQ) */
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
uint32_t B_ly = r_offset + Br; /* Row index of B block */
uint32_t B_lx = Bc;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t CRS_idx_b = B_idx_CRS * BS_CRS + B_ly;
uint32_t IC_idx_b;
uint32_t KD_idx_b;
uint32_t KH_idx_b;
uint32_t KW_idx_b;
split_crs(CRS_idx_b, IC_idx_b, KD_idx_b, KH_idx_b, KW_idx_b);
// Input coordinate = output coordinate * stride + kernel coordinate * dilation - padding.
// The arithmetic is unsigned: a coordinate that would be negative wraps to a
// large value and is rejected by the >= comparisons below.
uint32_t ID_idx = OD_idx * s2 + KD_idx_b * d2 - p2;
uint32_t IH_idx = OH_idx * s1 + KH_idx_b * d1 - p1;
uint32_t IW_idx = OW_idx * s0 + KW_idx_b * d0 - p0;
uint32_t src_idx = IW_idx + IH_idx * p.nb11 + ID_idx * p.nb12 + (N_idx * p.IC + IC_idx_b) * p.nb13;
// skip clamp when address can't go OOB
// (i.e. only when the shape is tile-aligned and there is no padding)
if (aligned == 0 || !dhw_in_bounds) {
src_idx = min(src_idx, p.IC * p.N * p.IW * p.IH * p.ID - 1);
}
float val = src_data[src_idx];
// Decide whether the element just loaded is logically outside the input and
// must contribute zero.
bool oob = false;
if (aligned == 0 && (CRS_idx_b >= CRS || NPQ_idx >= NPQ)) {
oob = true;
}
// also catches lower-bound underflow (idx wraps to 0x80000000+)
if (!dhw_in_bounds && (ID_idx >= p.ID || IH_idx >= p.IH || IW_idx >= p.IW)) {
oob = true;
}
if (oob) {
val = 0.0;
}
Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
}
// Wait until every invocation has finished writing Ash / Bsh before reading them.
barrier();
#ifdef COOPMAT2
// Load both shared tiles as workgroup-scope matrices and accumulate A * B into matC.
coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;
coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;
coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
matC = coopMatMulAdd(matA, matB, matC);
#elif defined(COOPMAT)
// each subgroup multiplies its grid of fragments per TK-sized CRS chunk
[[unroll]] for (uint k_step = 0; k_step < BS_CRS / TK; k_step++) {
// A fragments for this k_step are loaded once and reused for every column fragment.
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a[cms_per_row];
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
const uint a_off = (warp_r * WM + cm_row * TM) * Ash_stride + k_step * TK;
coopMatLoad(cache_a[cm_row], Ash, a_off, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
}
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
const uint b_off = k_step * TK * Bsh_stride + warp_c * WN + cm_col * TN;
coopMatLoad(cache_b, Bsh, b_off, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a[cm_row], cache_b, sums[cm_col * cms_per_row + cm_row]);
}
}
}
#else
// Scalar path: for each CRS element of the slice, accumulate the outer product
// of a TS_K column of Ash and a TS_NPQ row of Bsh into the register tile.
// NOTE(review): the guard compares the tile-local row offset (T_y * TS_K) with
// the global K, so it only skips work when K is smaller than the tile height --
// confirm whether the global row (B_idx_K * BS_K + T_y * TS_K) was intended.
if (T_y * TS_K < K) {
UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
float regA[TS_K];
float regB[TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
}
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
}
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
}
}
}
}
#endif
// Make sure all reads of Ash / Bsh are done before the next iteration overwrites them.
barrier();
}
/* Save C* */
#if defined(COOPMAT2) || defined(COOPMAT)
// stage matC into Csh, then write to dst with coalesced NPQ-contiguous stores
#ifdef COOPMAT
const bool use_staged_store = true;
#else
const bool use_staged_store = (csh_store != 0);
#endif
if (use_staged_store) {
#ifdef COOPMAT
// cm1: each subgroup stores its fragment grid into its Csh slot
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const uint csh_off = (warp_r * WM + cm_row * TM) * Csh_stride + warp_c * WN + cm_col * TN;
coopMatStore(sums[cm_col * cms_per_row + cm_row], Csh, csh_off, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
}
#else
coopMatStore(matC, Csh, 0, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
#endif
// Csh must be completely written before any invocation reads it back.
barrier();
// cooperative shmem->global: WG threads spread across BS_NPQ (the
// contiguous direction of dst), each iter covers store_rows_per_iter K-rows
// Note: store_rows_per_iter is a divisor below, so WG_SIZE must be at least BS_NPQ.
const uint32_t store_rows_per_iter = WG_SIZE / BS_NPQ;
const uint32_t store_iters = BS_K / store_rows_per_iter;
const uint32_t k_thread_offset = tid / BS_NPQ;
const uint32_t npq_thread = tid % BS_NPQ;
[[unroll]] for (uint32_t i = 0; i < store_iters; i++) {
uint32_t k_local = i * store_rows_per_iter + k_thread_offset;
uint32_t K_idx = B_idx_K * BS_K + k_local;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + npq_thread;
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
// Elements outside K x NPQ are skipped unless the shape is tile-aligned.
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
dst_data[dst_idx] = D_TYPE(Csh[k_local * Csh_stride + npq_thread]);
}
}
}
#ifdef COOPMAT2
// cm2 without staging: write each accumulator element directly via perElemOpStore.
else {
coopMatPerElementNV(matC, matC, perElemOpStore);
}
#endif
#else
// Scalar path: each invocation writes its own TS_K x TS_NPQ register tile.
if (T_y * TS_K < K) {
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
// Elements outside K x NPQ are skipped unless the shape is tile-aligned.
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
dst_data[dst_idx] = D_TYPE(regC[T_ly][T_lx]);
}
}
}
}
#endif
}