#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.glsl"
#include "dequant_funcs.glsl"
// Workgroup shape: the x extent is taken from specialization constant 0,
// the y and z extents are fixed at 1.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// K_PER_ITER is the number of K elements one invocation handles per call to
// iter(): 4 when A is F32/F16/BF16, 8 for every other A type.
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
#define K_PER_ITER 4
#else
#define K_PER_ITER 8
#endif

// Shader-wide offsets, written in compute_outputs():
// a_offset, b_offset and d_offset are filled in by get_offsets();
// y_offset is the spacing between successive B elements fetched per iteration.
uint a_offset;
uint b_offset;
uint d_offset;
uint y_offset;
// Fetches up to four B values for B column `j`, starting at element
// (iybs + iqs) and spaced y_offset apart.
//
// When `lastiter` is set, OOB_y / OOB_z / OOB_w report whether the 2nd / 3rd /
// 4th element index is >= p.ncols. Elements past the first flagged one are not
// read from data_b and are returned as 0. The first element (.x) is always read.
// When `lastiter` is false all three flags are false and all four values are read.
vec4 load_b(const uint j, const uint iybs, const uint iqs, const bool lastiter, out bool OOB_y, out bool OOB_z, out bool OOB_w) {
    const uint k0  = iybs + iqs;                            // index of the first element along K
    const uint src = j*p.batch_stride_b + b_offset + k0;    // matching index into data_b

    // Check if the latter elements are OOB, and don't fetch B or accumulate it.
    OOB_y = lastiter && (k0 + y_offset   >= p.ncols);
    OOB_z = lastiter && (k0 + y_offset*2 >= p.ncols);
    OOB_w = lastiter && (k0 + y_offset*3 >= p.ncols);

    const FLOAT_TYPE bx = FLOAT_TYPE(data_b[src]);
    FLOAT_TYPE by = FLOAT_TYPE(0);
    FLOAT_TYPE bz = FLOAT_TYPE(0);
    FLOAT_TYPE bw = FLOAT_TYPE(0);

    // Same cascade as the flags: test the farthest element first.
    if (!OOB_w) {
        by = FLOAT_TYPE(data_b[src + y_offset]);
        bz = FLOAT_TYPE(data_b[src + y_offset*2]);
        bw = FLOAT_TYPE(data_b[src + y_offset*3]);
    } else if (!OOB_z) {
        by = FLOAT_TYPE(data_b[src + y_offset]);
        bz = FLOAT_TYPE(data_b[src + y_offset*2]);
    } else if (!OOB_y) {
        by = FLOAT_TYPE(data_b[src + y_offset]);
    }

    return vec4(bx, by, bz, bw);
}
// Accumulates one K-slice of the matrix-vector product into `temp`.
//
// This invocation covers K_PER_ITER elements starting at
// col = i*BLOCK_SIZE + K_PER_ITER*tid, for every B column j < NUM_COLS and for
// `num_rows` rows of A beginning at `first_row`; results are added to temp[j][n].
// `lastiter` turns on the out-of-bounds handling of the K_PER_ITER == 4 path;
// the K_PER_ITER == 8 path does not use it.
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
{
    // These do not depend on the B column, so they are computed once.
    const uint col      = i*BLOCK_SIZE + K_PER_ITER*tid;
    const uint in_block = col%QUANT_K;
    const uint iqs      = in_block/QUANT_R;  // quant index
    const uint iybs     = col - in_block;    // y block start index

    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
#if K_PER_ITER == 8
        const uint b_idx = j*p.batch_stride_b + b_offset + iybs + iqs;
#if QUANT_R == 2
        // Note that we end up fetching bogus elements here, but it's fine as
        // they'll be within an accessible block.
        const vec4 bv02 = vec4(data_b_v4[b_idx / 4]);
        const vec4 bv13 = vec4(data_b_v4[(b_idx + y_offset) / 4]);
        // Interleave the two fetches element by element.
        const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
        const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
#else
        const vec4 bv0 = vec4(data_b_v4[b_idx / 4]);
        const vec4 bv1 = vec4(data_b_v4[b_idx / 4 + 1]);
#endif
#else
        bool OOB_y;
        bool OOB_z;
        bool OOB_w;
        const vec4 b = load_b(j, iybs, iqs, lastiter, OOB_y, OOB_z, OOB_w);
#endif

        // Element offset of the current row of A; advanced by p.ncols per row.
        uint row_start = first_row*p.ncols;
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            const uint ib = (row_start + col)/QUANT_K; // block index
            row_start += p.ncols;

#if K_PER_ITER == 8
            vec4 v  = dequantize4(ib, iqs, a_offset);
            vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);

            const vec2 dm = get_dm(ib, a_offset);
            const bool has_min = dm.y != 0; // quant has min component
            if (has_min) {
                v  = v  * dm.x + dm.y;
                v2 = v2 * dm.x + dm.y;
            }

            // matrix multiplication
            FLOAT_TYPE rowtmp = dot(bv0, v);
            rowtmp += dot(bv1, v2);

            // Without a min component the scale is applied once to the sum.
            if (!has_min) {
                rowtmp *= dm.x;
            }

            temp[j][n] += rowtmp;
#else
            // Use only as many A elements as there are in-bounds B elements.
            if (!OOB_w) {
                const vec4 v = dequantize4(ib, iqs, a_offset);
                temp[j][n] += dot(v, b);
            } else if (!OOB_z) {
                const vec2 v01 = dequantize(ib, iqs, a_offset);
                const FLOAT_TYPE v2 = dequantize1(ib + 2/QUANT_R, iqs, a_offset);
                const vec3 v = vec3(v01.x, v01.y, v2);
                temp[j][n] += dot(v, b.xyz);
            } else if (!OOB_y) {
                const vec2 v01 = dequantize(ib, iqs, a_offset);
                temp[j][n] += dot(v01, b.xy);
            } else {
                const FLOAT_TYPE v = dequantize1(ib, iqs, a_offset);
                temp[j][n] = fma(v, b.x, temp[j][n]);
            }
#endif
        }
    }
}
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
// Variant of iter() for F32/F16/BF16 A matrices that performs no bounds checks.
//
// For each B column it reads four B values with a single data_b_v4 fetch and
// four A values with dequantize4_2aligned(), then adds their dot product to
// temp[j][n]. The B index is divided by 4 to address data_b_v4, so the caller
// is expected to have checked the alignment conditions (see
// is_aligned_nonquant in compute_outputs()).
void iter_aligned_nonquant(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i)
{
    const uint col = i*BLOCK_SIZE + K_PER_ITER*tid; // first K element for this invocation
    const uint iqs = 0;                             // quant index

    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        const vec4 b = data_b_v4[(j*p.batch_stride_b + b_offset + col) / 4];

        // Element offset of the current row of A; advanced by p.ncols per row.
        uint row_start = first_row*p.ncols;
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            const uint ib = (row_start + col)/QUANT_K; // block index
            row_start += p.ncols;

            // matrix multiplication
            const vec4 v = dequantize4_2aligned(ib, iqs, a_offset);
            temp[j][n] += dot(v, b);
        }
    }
}
#endif
// Computes `num_rows` output rows starting at `first_row` for this workgroup.
//
// Each invocation walks its share of the K dimension in three phases:
// groups of 4 iterations, then groups of 2, then single iterations. Only the
// single-iteration phase of the generic path passes lastiter == true to
// iter(). The per-invocation partial sums are handed to reduce_result().
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
    const uint tid = gl_LocalInvocationID.x;

    get_offsets(a_offset, b_offset, d_offset);

    // The vec4 fast path is only used for the K_PER_ITER == 4 (F32/F16/BF16)
    // build and only when all of these quantities are multiples of 4.
    const bool is_aligned_nonquant =
        K_PER_ITER == 4 &&
        BLOCK_SIZE % 4 == 0 && p.ncols % 4 == 0 &&
        p.batch_stride_b % 4 == 0 && b_offset % 4 == 0;

    if (QUANT_R == 1) {
        y_offset = 1;
    } else {
        y_offset = QUANT_K/2;
    }

    // Per-invocation accumulators, one per (B column, A row).
    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
    [[unroll]] for (uint c = 0; c < NUM_COLS; ++c) {
        [[unroll]] for (uint r = 0; r < NUM_ROWS; ++r) {
            temp[c][r] = FLOAT_TYPE(0);
        }
    }

    // Full passes over K, plus one more if this invocation still has a
    // starting column inside the remainder.
    uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
    if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
        num_iters += 1;
    }

    // Iteration counts rounded down to a multiple of 4 and of 2.
    uint limit4 = num_iters & ~3u;
    uint limit2 = num_iters & ~1u;

#if K_PER_ITER == 4
    // If the K dimension is not a multiple of 4, we need lastiter==true on the
    // last iteration so OOB is computed correctly. Skip some unrolling to make
    // that happen.
    const bool ragged_k = (p.ncols & 3) != 0;
    if (ragged_k && limit4 == num_iters && limit4 > 0) {
        limit4 -= 4u;
    }
    if (ragged_k && limit2 == num_iters && limit2 > 0) {
        limit2 -= 2u;
    }
#endif

    uint i = 0;

#if K_PER_ITER == 4
    if (is_aligned_nonquant) {
        // Manually partially unroll the loop: groups of 4.
        while (i < limit4) {
            [[unroll]] for (uint k = 0; k < 4u; ++k) {
                iter_aligned_nonquant(temp, first_row, num_rows, tid, (i + k)*K_PER_ITER);
            }
            i += 4u;
        }
        // Groups of 2.
        while (i < limit2) {
            [[unroll]] for (uint k = 0; k < 2u; ++k) {
                iter_aligned_nonquant(temp, first_row, num_rows, tid, (i + k)*K_PER_ITER);
            }
            i += 2u;
        }
        // Remaining single iterations.
        while (i < num_iters) {
            iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER);
            i += 1u;
        }
    } else
#endif
    {
        // Manually partially unroll the loop: groups of 4.
        while (i < limit4) {
            [[unroll]] for (uint k = 0; k < 4u; ++k) {
                iter(temp, first_row, num_rows, tid, (i + k)*K_PER_ITER, false);
            }
            i += 4u;
        }
        // Groups of 2.
        while (i < limit2) {
            [[unroll]] for (uint k = 0; k < 2u; ++k) {
                iter(temp, first_row, num_rows, tid, (i + k)*K_PER_ITER, false);
            }
            i += 2u;
        }
        // Remaining single iterations run with lastiter == true.
        while (i < num_iters) {
            iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);
            i += 1u;
        }
    }

    reduce_result(temp, d_offset, first_row, num_rows, tid);
}
// Entry point. Each workgroup handles up to NUM_ROWS consecutive rows starting
// at a first_row derived from its x and z workgroup IDs; rows at or beyond
// p.stride_d are not computed.
void main() {
    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

#ifdef NEEDS_INIT_IQ_SHMEM
    // Runs for every invocation, including those that return early below.
    init_iq_shmem(gl_WorkGroupSize);
#endif

    // do NUM_ROWS at a time, unless there aren't enough remaining rows
    const bool full_tile = first_row + NUM_ROWS <= p.stride_d;
    if (full_tile) {
        compute_outputs(first_row, NUM_ROWS);
        return;
    }

    // Nothing left for this workgroup.
    if (first_row >= p.stride_d) {
        return;
    }

    // Partial tile: only the rows that remain.
    compute_outputs(first_row, min(NUM_ROWS, p.stride_d - first_row));
}