// Enable the half-precision extension; `half` values are used in arithmetic below.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Fixed-width integer aliases layered over the OpenCL C built-in scalar types.
typedef char int8_t;
typedef uchar uint8_t;
typedef short int16_t;
typedef ushort uint16_t;
typedef int int32_t;
typedef uint uint32_t;

// Number of 4-bit quantized values held by one q4_0 block
// (the block stores QK4_0 / 2 bytes, two values per byte).
#define QK4_0 32
//------------------------------------------------------------------------------
// block_q4_0
//------------------------------------------------------------------------------
// One q4_0 quantization block: a half-precision scale followed by
// QK4_0 / 2 bytes of packed 4-bit values (two values per byte).
struct block_q4_0
{
    half d;                 // scale applied to every value of the block
    uint8_t qs[QK4_0 / 2];  // packed 4-bit quantized values
};
//------------------------------------------------------------------------------
// dequantize_q4_0_f32, dequantize_q4_0_f16
//------------------------------------------------------------------------------
// Dequantizes half of a q4_0 block (16 values) into a float16 vector.
//
//   xb  - block to decode (global memory)
//   il  - 0 selects the low nibble of each of the 16 packed bytes,
//         non-zero selects the high nibble
//   reg - receives the 16 decoded floats
//
// Each output is d * (nibble - 8). Instead of shifting the nibble down to
// bit 0, the nibble is left in place inside a 16-bit word and the scale is
// divided by the nibble's positional weight (16 for a high nibble, a further
// 256 for the upper byte of the word). The -8 bias is folded into `md`.
//
// NOTE(review): reading two packed bytes as one ushort assumes the lower
// address byte lands in the low 8 bits (little-endian) - confirm for targets.
void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) {
    // Skip the leading 2-byte `half d` to view qs[] as eight 16-bit words.
    global ushort * qs = ((global ushort *)xb + 1);

    float d1 = il ? (xb->d / 16.h) : xb->d;  // scale for the nibble in the low byte
    float d2 = d1 / 256.f;                   // scale for the nibble in the high byte
    float md = -8.h * xb->d;                 // constant offset: d * (-8)

    ushort mask0 = il ? 0x00F0 : 0x000F;     // nibble within the low byte
    ushort mask1 = mask0 << 8;               // same nibble within the high byte

    reg->s0 = d1 * (qs[0] & mask0) + md;
    reg->s1 = d2 * (qs[0] & mask1) + md;

    reg->s2 = d1 * (qs[1] & mask0) + md;
    reg->s3 = d2 * (qs[1] & mask1) + md;

    reg->s4 = d1 * (qs[2] & mask0) + md;
    reg->s5 = d2 * (qs[2] & mask1) + md;

    reg->s6 = d1 * (qs[3] & mask0) + md;
    reg->s7 = d2 * (qs[3] & mask1) + md;

    reg->s8 = d1 * (qs[4] & mask0) + md;
    reg->s9 = d2 * (qs[4] & mask1) + md;

    reg->sa = d1 * (qs[5] & mask0) + md;
    reg->sb = d2 * (qs[5] & mask1) + md;

    reg->sc = d1 * (qs[6] & mask0) + md;
    reg->sd = d2 * (qs[6] & mask1) + md;

    reg->se = d1 * (qs[7] & mask0) + md;
    reg->sf = d2 * (qs[7] & mask1) + md;
}
//------------------------------------------------------------------------------
// get_rows
//------------------------------------------------------------------------------
// Copies selected float rows from src0 into dst.
//
// The work-group id (i10, i11, i12) addresses one int in src1; that int `r`
// is the row of src0 to copy. The work-items of the group stride across the
// row so that together they copy `ne00` floats.
//
//   src0, offset0        - source data and byte offset applied to it
//   src1, offset1        - row indices and byte offset applied to them
//   dst,  offsetd        - destination and byte offset applied to it
//   ne00                 - number of floats copied per row
//   nb01, nb02, nb03     - byte strides of src0 (row, then the two outer dims)
//   ne10                 - unused by this kernel
//   nb10, nb11, nb12     - byte strides of src1
//   nb1,  nb2,  nb3      - byte strides of dst
kernel void kernel_get_rows_f32(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * dst,
        ulong offsetd,
        int ne00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne10,
        ulong nb10,
        ulong nb11,
        ulong nb12,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    // Apply the byte offsets to the buffer base pointers.
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
    dst = (global float*)((global char*)dst + offsetd);

    int i10 = get_group_id(0);
    int i11 = get_group_id(1);
    int i12 = get_group_id(2);

    // Row of src0 selected for this work-group.
    int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];

    int i02 = i11;
    int i03 = i12;

    // Row base pointers do not depend on the loop index; compute them once.
    global float * dst_row = (global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1);
    global float * src_row = (global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03);

    // Work-items of the group share the row, each advancing by the group size.
    for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
        dst_row[ind] = src_row[ind];
    }
}
// Copies selected half-precision rows from src0 into dst, widening to float.
//
// The work-group id (i10, i11, i12) addresses one int in src1; that int `r`
// is the row of src0 to copy. The work-items of the group stride across the
// row so that together they convert and store `ne00` values.
//
//   src0, offset0        - source data (read as half) and its byte offset
//   src1, offset1        - row indices and byte offset applied to them
//   dst,  offsetd        - float destination and byte offset applied to it
//   ne00                 - number of values copied per row
//   nb01, nb02, nb03     - byte strides of src0 (row, then the two outer dims)
//   ne10                 - unused by this kernel
//   nb10, nb11, nb12     - byte strides of src1
//   nb1,  nb2,  nb3      - byte strides of dst
kernel void kernel_get_rows_f16(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * dst,
        ulong offsetd,
        int ne00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne10,
        ulong nb10,
        ulong nb11,
        ulong nb12,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    // Apply the byte offsets to the buffer base pointers.
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
    dst = (global float*)((global char*)dst + offsetd);

    int i10 = get_group_id(0);
    int i11 = get_group_id(1);
    int i12 = get_group_id(2);

    // Row of src0 selected for this work-group.
    int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];

    int i02 = i11;
    int i03 = i12;

    // Row base pointers do not depend on the loop index; compute them once.
    global float * dst_row = (global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1);
    global half * src_row = (global half *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03);

    // Work-items of the group share the row, each advancing by the group size.
    for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
        // half -> float conversion happens implicitly on assignment.
        dst_row[ind] = src_row[ind];
    }
}
// Dequantizes selected q4_0 rows from src0 into float rows of dst.
//
// The work-group id (i10, i11, i12) addresses one int in src1; that int `r`
// is the row of src0 to decode. The destination row is written in float16
// chunks: chunk `ind` comes from block `ind / NL`, half `ind % NL`.
//
//   src0, offset0        - q4_0 blocks and byte offset applied to them
//   src1, offset1        - row indices and byte offset applied to them
//   dst,  offsetd        - float destination and byte offset applied to it
//   ne00                 - number of values per row
//   nb01, nb02, nb03     - byte strides of src0 (row, then the two outer dims)
//   ne10                 - unused by this kernel
//   nb10, nb11, nb12     - byte strides of src1
//   nb1,  nb2,  nb3      - byte strides of dst
//
// NOTE(review): the ne00 / 16 chunk count assumes ne00 is a multiple of the
// chunk size (and of QK4_0 for whole blocks) - confirm against the caller.
// NOTE(review): storing through a `global float16 *` assumes the destination
// row is suitably aligned for float16 - confirm against the caller.
kernel void kernel_get_rows_q4_0(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * dst,
        ulong offsetd,
        int ne00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne10,
        ulong nb10,
        ulong nb11,
        ulong nb12,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    // Apply the byte offsets to the buffer base pointers.
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
    dst = (global float*)((global char*)dst + offsetd);

    // float16 chunks produced per q4_0 block (32 values / 16 per chunk).
    const int NL = 2;

    int i10 = get_group_id(0);
    int i11 = get_group_id(1);
    int i12 = get_group_id(2);

    // Row of src0 selected for this work-group.
    int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];

    int i02 = i11;
    int i03 = i12;

    // Row base pointers do not depend on the loop index; compute them once.
    global struct block_q4_0 * src_row =
        (global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03);
    global float16 * dst_row =
        (global float16 *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1);

    // `ind` counts float16 chunks of the row; work-items advance by the group size.
    for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) {
        float16 temp;
        dequantize_q4_0_f32(src_row + ind/NL, ind%NL, &temp);
        *(dst_row + ind) = temp;
    }
}