#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK_K 256
#define K_SCALE_SIZE 12
// 5-bit K-quantized super-block: QK_K = 256 weights stored as 4-bit low
// nibbles (qs) plus one high bit each (qh), with packed 6-bit per-subblock
// scales/mins and two half-precision super-block factors d / dmin.
// Fix: the mangled original was missing every member-separating semicolon
// and the typedef terminator, so it did not compile.
typedef struct {
    half  d;                      // super-block scale applied to the 6-bit scales
    half  dmin;                   // super-block scale applied to the 6-bit mins
    uchar scales[K_SCALE_SIZE];   // 12 bytes of packed 6-bit scales and mins
    uchar qh[QK_K/8];             // high (5th) bit of each of the 256 quants
    uchar qs[QK_K/2];             // low 4 bits, two quants per byte
} block_q5_K;
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 4
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 16
#elif defined(ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif
#define BLOCK_STRIDE (N_SIMDWIDTH/8)
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
// Matrix-vector multiply: src0 holds q5_K-quantized rows (blocks of
// QK_K = 256 weights), src1 holds f32 activations. Each subgroup computes
// N_DST consecutive output rows; its lanes are split into groups of 8 that
// walk the blocks of a row with stride BLOCK_STRIDE.
//
// Fixes vs. the mangled original:
//  - restored all stripped statement semicolons;
//  - reconstructed the four truncated for-loop headers from the array
//    sizes and pointer strides used in the bodies;
//  - offset_src0/offset_src1 are ulong (they are sums of ulong byte
//    strides nb01..nb13; the original `int` truncates on large tensors).
kernel void kernel_mul_mv_q5_K_f32(
        global char * src0,
        int offset0,
        global char * src1,
        int offset1,
        global char * dst,
        int offsetd,
        int ne00,
        int ne01,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne12,
        ulong nb11,
        ulong nb12,
        ulong nb13,
        int ne0,
        int ne1,
        int r2,
        int r3
) {
    src0 = src0 + offset0;
    src1 = src1 + offset1;
    dst  = dst  + offsetd;

    // Masks used to unpack the packed 6-bit scales/mins (see sc16[] below).
    ushort kmask1 = 0x3f3f;
    ushort kmask2 = 0x0f0f;
    ushort kmask3 = 0xc0c0;

    int ix = get_sub_group_local_id()/8;  // which block this 8-lane group starts on
    int it = get_sub_group_local_id()%8;  // lane index within the 8-lane group
    int iq = it/4;                        // 0..1: which 128-value half of the block
    int ir = it%4;                        // 0..3: 8-value slice within that half

    int nb = ne00/QK_K;                   // q5_K blocks per row

    int r0 = get_group_id(0);
    int r1 = get_group_id(1);
    int im = get_group_id(2);

    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;

    int i12 = im%ne12;
    int i13 = im/ne12;

    // Byte offsets into src0/src1; kept in ulong to match the ulong strides.
    ulong offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
    ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;

    global block_q5_K * x = (global block_q5_K *) (src0 + offset_src0);
    global float      * y = (global float      *) (src1 + offset_src1);

    float yl[16];              // activations for the low half of the block
    float yh[16];              // activations for the high half of the block
    float sumf[N_DST] = {0.f}; // per-row partial dot products
    float all_sum;

    global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;

    // Masks selecting this lane's high (5th) quant bits inside qh bytes.
    uchar u1_lo = (uchar)(1 << (2*iq));
    uchar u2_lo = (uchar)(2 << (2*iq));
    uchar u1_hi = (uchar)(1 << (2*iq + 4));
    uchar u2_hi = (uchar)(2 << (2*iq + 4));

    ushort sc16[4];
    uchar * sc8 = (uchar *)sc16;

    for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
        // Load 4 slices of 8 activations each and accumulate their sums
        // (needed for the -dmin * min correction term below).
        float4 sumy = {0.f, 0.f, 0.f, 0.f};
        for (int i = 0; i < 8; i++) {
            yl[i+0] = y4[i+  0]; sumy.s0 += yl[i+0];
            yl[i+8] = y4[i+ 32]; sumy.s1 += yl[i+8];
            yh[i+0] = y4[i+128]; sumy.s2 += yh[i+0];
            yh[i+8] = y4[i+160]; sumy.s3 += yh[i+8];
        }

        global ushort * sc = (global ushort *)x[ib].scales + iq;
        global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir;
        global uchar  * qh = x[ib].qh + 8 * ir;
        global half   * dh = &x[ib].d;

        for (int row = 0; row < N_DST; row++) {
            // Unpack the four 6-bit scale/min pairs relevant to this half-block.
            sc16[0] = sc[0] & kmask1;
            sc16[1] = sc[2] & kmask1;
            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);

            global ushort * q2 = q1 + 32;

            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
            // Each ushort q*[i/2] holds four packed nibbles; the 5th bit comes
            // from qh. The 16.f*256.f / 16.f*4096.f weights compensate for the
            // nibble's position within the ushort, undone by the 1/256 and
            // 1/16 factors when the accumulators are combined below.
            for (int i = 0; i < 8; i += 2) {
                acc1.s0 += yl[i+0] * ((q1[i/2] & 0x000F) + (qh[i+0] & u1_lo ? 16.f        : 0.f));
                acc1.s1 += yl[i+1] * ((q1[i/2] & 0x0F00) + (qh[i+1] & u1_lo ? 16.f*256.f  : 0.f));
                acc1.s2 += yl[i+8] * ((q1[i/2] & 0x00F0) + (qh[i+0] & u2_lo ? 16.f*16.f   : 0.f));
                acc1.s3 += yl[i+9] * ((q1[i/2] & 0xF000) + (qh[i+1] & u2_lo ? 16.f*4096.f : 0.f));
                acc2.s0 += yh[i+0] * ((q2[i/2] & 0x000F) + (qh[i+0] & u1_hi ? 16.f        : 0.f));
                acc2.s1 += yh[i+1] * ((q2[i/2] & 0x0F00) + (qh[i+1] & u1_hi ? 16.f*256.f  : 0.f));
                acc2.s2 += yh[i+8] * ((q2[i/2] & 0x00F0) + (qh[i+0] & u2_hi ? 16.f*16.f   : 0.f));
                acc2.s3 += yh[i+9] * ((q2[i/2] & 0xF000) + (qh[i+1] & u2_hi ? 16.f*4096.f : 0.f));
            }

            float dall = dh[0];   // super-block scale (block_q5_K.d)
            float dmin = dh[1];   // super-block min scale (block_q5_K.dmin)
            sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] +
                                 (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f +
                                 (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] +
                                 (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) -
                         dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]);

            // Advance each pointer by one row (nb01 bytes: /2 for ushort/half).
            q1 += nb01/2;
            sc += nb01/2;
            dh += nb01/2;
            qh += nb01;
        }

        y4 += BLOCK_STRIDE * QK_K;
    }

    global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0;

    // Reduce the partial sums across the subgroup; lane 0 writes each row,
    // guarding against rows past the end of the matrix.
    for (int row = 0; row < N_DST; row++) {
        all_sum = sub_group_reduce_add(sumf[row]);
        if (first_row + row < ne01) {
            if (get_sub_group_local_id() == 0) {
                dst_f32[first_row + row] = all_sum;
            }
        }
    }
}