// llama-cpp-sys-4 0.2.52
//
// Low-level bindings to llama.cpp
// Documentation
// Half-precision types and arithmetic (half, half4/8/16, convert_half8, as_half*).
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
// Sub-group built-ins; sub_group_barrier is used to share shared_b within a sub-group.
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
// Qualcomm (Adreno) vendor extensions.
// NOTE(review): the uniform/constant load extensions are not referenced by name in
// this file; presumably they let the compiler optimize sub-group-uniform loads —
// confirm against the Adreno OpenCL documentation.
#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
// Presumably provides the non-standard 32-wide vector types such as float32
// (used for the reg_c accumulator) — confirm.
#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable

// Tile sizes used by kernel_gemm_moe_mxfp4_f32_ns:
//   TILESIZE_K - K elements per sub-block; each 32-element MXFP4 block is
//                processed as two sub-blocks (the K loop advances by 2 * TILESIZE_K).
//   TILESIZE_M - weight rows per m-tile (row = block_id_m * TILESIZE_M).
//   TILESIZE_N - src1 columns per n-tile (col = block_id_n * TILESIZE_N).
#define TILESIZE_K 16
#define TILESIZE_M 64
#define TILESIZE_N 32


// Expand the four E2M1 (MXFP4) nibbles of one 16-bit word into IEEE binary16
// bit patterns. Component i of the result comes from bits [4i, 4i+3] of `w`.
//
// Per nibble: bit 3 = sign, bits 2..1 = exponent, bit 0 = mantissa.
//  * The 3 magnitude bits are moved to binary16 bits 11..9 (two low exponent
//    bits plus the top mantissa bit), and 0x3800 is added to rebase the
//    exponent, giving 2^(e-1) * (1 + m/2) for the normal codes 0b010..0b111.
//  * Magnitude 0b000 gets no bias, so it stays (signed) zero.
//  * Magnitude 0b001 (the subnormal 0.5) keeps only the bias: 0x3800 == 0.5.
// The sign, bias and magnitude fields never overlap, so the final sum acts as
// a bitwise OR.
static inline ushort4 mxfp4x4_to_fp16_bits(ushort w) {
    ushort4 mag;
    mag.s0 = (w << 9) & 0x0E00;
    mag.s1 = (w << 5) & 0x0E00;
    mag.s2 = (w << 1) & 0x0E00;
    mag.s3 = (w >> 3) & 0x0E00;

    // Exponent rebase, applied to every non-zero magnitude.
    ushort4 bias;
    bias.s0 = mag.s0 ? 0x3800 : 0x0;
    bias.s1 = mag.s1 ? 0x3800 : 0x0;
    bias.s2 = mag.s2 ? 0x3800 : 0x0;
    bias.s3 = mag.s3 ? 0x3800 : 0x0;

    // Subnormal code: drop the magnitude so only the bias (0.5) remains.
    mag.s0 = (mag.s0 == 0x0200) ? 0x0 : mag.s0;
    mag.s1 = (mag.s1 == 0x0200) ? 0x0 : mag.s1;
    mag.s2 = (mag.s2 == 0x0200) ? 0x0 : mag.s2;
    mag.s3 = (mag.s3 == 0x0200) ? 0x0 : mag.s3;

    // Move each nibble's bit 3 to the binary16 sign bit (bit 15).
    ushort4 sign;
    sign.s0 = (w << 12) & 0x8000;
    sign.s1 = (w << 8) & 0x8000;
    sign.s2 = (w << 4) & 0x8000;
    sign.s3 = w & 0x8000;

    return sign + bias + mag;
}

// Dequantize 8 packed MXFP4 (E2M1) values to half, without applying the block
// scale. fp4x8.s0 supplies outputs 0-3 and fp4x8.s1 supplies outputs 4-7; each
// word is consumed from its least-significant nibble upward.
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
    const ushort4 first_word  = mxfp4x4_to_fp16_bits(fp4x8.s0);
    const ushort4 second_word = mxfp4x4_to_fp16_bits(fp4x8.s1);
    return as_half8((ushort8)(first_word, second_word));
}


// dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset)
//
// Accumulates a 16(K) x 16(N) slab of the tile product into c_reg.
//   a_reg     - half16 holding this work-item's 16 dequantized K values.
//   b_lm      - half4 array; element lm_offset + 32*g + n holds K values
//               4g..4g+3 of column n, for g = 0..3 and n = 0..15.
//   c_reg     - float16 lvalue accumulator (e.g. reg_c.lo), updated in place.
//   lm_offset - first column of the slab inside b_lm (0 or 16 in this file).
//
// Products over K 0-7 are summed in half precision, widened to float and
// added to c_reg; the same is then done for K 8-15. Keeping each fp16 partial
// sum to 8 terms limits rounding error before promotion to float.
//
// Expands to a single statement (do { } while (0)) with its own scratch
// accumulator, so it is safe in unbraced if/else and does not depend on a
// caller-scope variable; every argument is parenthesized.
#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \
    do { \
        half16 dotx16_acc; \
        dotx16_acc.s0 = dot((a_reg).s0123, (b_lm)[(lm_offset) + 0]); \
        dotx16_acc.s1 = dot((a_reg).s0123, (b_lm)[(lm_offset) + 1]); \
        dotx16_acc.s2 = dot((a_reg).s0123, (b_lm)[(lm_offset) + 2]); \
        dotx16_acc.s3 = dot((a_reg).s0123, (b_lm)[(lm_offset) + 3]); \
        dotx16_acc.s4 = dot((a_reg).s0123, (b_lm)[(lm_offset) + 4]); \
        dotx16_acc.s5 = dot((a_reg).s0123, (b_lm)[(lm_offset) + 5]); \
        dotx16_acc.s6 = dot((a_reg).s0123, (b_lm)[(lm_offset) + 6]); \
        dotx16_acc.s7 = dot((a_reg).s0123, (b_lm)[(lm_offset) + 7]); \
        dotx16_acc.s8 = dot((a_reg).s0123, (b_lm)[(lm_offset) + 8]); \
        dotx16_acc.s9 = dot((a_reg).s0123, (b_lm)[(lm_offset) + 9]); \
        dotx16_acc.sa = dot((a_reg).s0123, (b_lm)[(lm_offset) + 10]); \
        dotx16_acc.sb = dot((a_reg).s0123, (b_lm)[(lm_offset) + 11]); \
        dotx16_acc.sc = dot((a_reg).s0123, (b_lm)[(lm_offset) + 12]); \
        dotx16_acc.sd = dot((a_reg).s0123, (b_lm)[(lm_offset) + 13]); \
        dotx16_acc.se = dot((a_reg).s0123, (b_lm)[(lm_offset) + 14]); \
        dotx16_acc.sf = dot((a_reg).s0123, (b_lm)[(lm_offset) + 15]); \
        dotx16_acc.s0 += dot((a_reg).s4567, (b_lm)[(lm_offset) + 32]); \
        dotx16_acc.s1 += dot((a_reg).s4567, (b_lm)[(lm_offset) + 33]); \
        dotx16_acc.s2 += dot((a_reg).s4567, (b_lm)[(lm_offset) + 34]); \
        dotx16_acc.s3 += dot((a_reg).s4567, (b_lm)[(lm_offset) + 35]); \
        dotx16_acc.s4 += dot((a_reg).s4567, (b_lm)[(lm_offset) + 36]); \
        dotx16_acc.s5 += dot((a_reg).s4567, (b_lm)[(lm_offset) + 37]); \
        dotx16_acc.s6 += dot((a_reg).s4567, (b_lm)[(lm_offset) + 38]); \
        dotx16_acc.s7 += dot((a_reg).s4567, (b_lm)[(lm_offset) + 39]); \
        dotx16_acc.s8 += dot((a_reg).s4567, (b_lm)[(lm_offset) + 40]); \
        dotx16_acc.s9 += dot((a_reg).s4567, (b_lm)[(lm_offset) + 41]); \
        dotx16_acc.sa += dot((a_reg).s4567, (b_lm)[(lm_offset) + 42]); \
        dotx16_acc.sb += dot((a_reg).s4567, (b_lm)[(lm_offset) + 43]); \
        dotx16_acc.sc += dot((a_reg).s4567, (b_lm)[(lm_offset) + 44]); \
        dotx16_acc.sd += dot((a_reg).s4567, (b_lm)[(lm_offset) + 45]); \
        dotx16_acc.se += dot((a_reg).s4567, (b_lm)[(lm_offset) + 46]); \
        dotx16_acc.sf += dot((a_reg).s4567, (b_lm)[(lm_offset) + 47]); \
        (c_reg).lo += convert_float8(dotx16_acc.lo); \
        (c_reg).hi += convert_float8(dotx16_acc.hi); \
        dotx16_acc.s0 = dot((a_reg).s89ab, (b_lm)[(lm_offset) + 64]); \
        dotx16_acc.s1 = dot((a_reg).s89ab, (b_lm)[(lm_offset) + 65]); \
        dotx16_acc.s2 = dot((a_reg).s89ab, (b_lm)[(lm_offset) + 66]); \
        dotx16_acc.s3 = dot((a_reg).s89ab, (b_lm)[(lm_offset) + 67]); \
        dotx16_acc.s4 = dot((a_reg).s89ab, (b_lm)[(lm_offset) + 68]); \
        dotx16_acc.s5 = dot((a_reg).s89ab, (b_lm)[(lm_offset) + 69]); \
        dotx16_acc.s6 = dot((a_reg).s89ab, (b_lm)[(lm_offset) + 70]); \
        dotx16_acc.s7 = dot((a_reg).s89ab, (b_lm)[(lm_offset) + 71]); \
        dotx16_acc.s8 = dot((a_reg).s89ab, (b_lm)[(lm_offset) + 72]); \
        dotx16_acc.s9 = dot((a_reg).s89ab, (b_lm)[(lm_offset) + 73]); \
        dotx16_acc.sa = dot((a_reg).s89ab, (b_lm)[(lm_offset) + 74]); \
        dotx16_acc.sb = dot((a_reg).s89ab, (b_lm)[(lm_offset) + 75]); \
        dotx16_acc.sc = dot((a_reg).s89ab, (b_lm)[(lm_offset) + 76]); \
        dotx16_acc.sd = dot((a_reg).s89ab, (b_lm)[(lm_offset) + 77]); \
        dotx16_acc.se = dot((a_reg).s89ab, (b_lm)[(lm_offset) + 78]); \
        dotx16_acc.sf = dot((a_reg).s89ab, (b_lm)[(lm_offset) + 79]); \
        dotx16_acc.s0 += dot((a_reg).scdef, (b_lm)[(lm_offset) + 96]); \
        dotx16_acc.s1 += dot((a_reg).scdef, (b_lm)[(lm_offset) + 97]); \
        dotx16_acc.s2 += dot((a_reg).scdef, (b_lm)[(lm_offset) + 98]); \
        dotx16_acc.s3 += dot((a_reg).scdef, (b_lm)[(lm_offset) + 99]); \
        dotx16_acc.s4 += dot((a_reg).scdef, (b_lm)[(lm_offset) + 100]); \
        dotx16_acc.s5 += dot((a_reg).scdef, (b_lm)[(lm_offset) + 101]); \
        dotx16_acc.s6 += dot((a_reg).scdef, (b_lm)[(lm_offset) + 102]); \
        dotx16_acc.s7 += dot((a_reg).scdef, (b_lm)[(lm_offset) + 103]); \
        dotx16_acc.s8 += dot((a_reg).scdef, (b_lm)[(lm_offset) + 104]); \
        dotx16_acc.s9 += dot((a_reg).scdef, (b_lm)[(lm_offset) + 105]); \
        dotx16_acc.sa += dot((a_reg).scdef, (b_lm)[(lm_offset) + 106]); \
        dotx16_acc.sb += dot((a_reg).scdef, (b_lm)[(lm_offset) + 107]); \
        dotx16_acc.sc += dot((a_reg).scdef, (b_lm)[(lm_offset) + 108]); \
        dotx16_acc.sd += dot((a_reg).scdef, (b_lm)[(lm_offset) + 109]); \
        dotx16_acc.se += dot((a_reg).scdef, (b_lm)[(lm_offset) + 110]); \
        dotx16_acc.sf += dot((a_reg).scdef, (b_lm)[(lm_offset) + 111]); \
        (c_reg).lo += convert_float8(dotx16_acc.lo); \
        (c_reg).hi += convert_float8(dotx16_acc.hi); \
    } while (0)


// Convert an E8M0 scale byte (unsigned power of two, bias 127) to half.
//
// x encodes 2^(x - 127). In binary16 that is:
//   x >= 143        -> above the fp16 range (max 65504): +inf
//   x in [113, 142] -> normal number, biased exponent field x - 112
//   x in [103, 112] -> subnormal, mantissa 1 << (x - 103)  (2^-24 .. 2^-15)
//   x <= 102        -> at most half the smallest subnormal: rounds to +0
// x == 0xFF (NaN in the OCP MX spec) falls in the overflow range and gives +inf,
// as e8m0_to_fp32 does.
//
// The previous version computed (ushort)x - 112, which wraps for x < 112 to a
// value with bits 5-7 set, so every small scale became +inf instead of 0 or a
// subnormal (and x == 112 became 0 instead of 2^-15).
static inline half e8m0_to_fp16(uchar x) {
    ushort bits;
    if (x >= 143) {
        bits = 0x7C00;
    } else if (x >= 113) {
        bits = (ushort)((x - 112) << 10);
    } else if (x >= 103) {
        bits = (ushort)(1 << (x - 103));
    } else {
        bits = 0;
    }
    return as_half(bits);
}

// Convert an E8M0 scale byte (unsigned power of two, bias 127) to float.
// The byte is used directly as the binary32 exponent field, so x encodes
// 2^(x - 127). x == 0 cannot be written that way (exponent field 0 is the
// subnormal range), so it is built as the subnormal 0x00400000 == 2^-127.
// x == 0xFF lands on the all-ones exponent with a zero mantissa, i.e. +inf.
static inline float e8m0_to_fp32(uchar x) {
    if (x == 0) {
        return as_float(0x00400000);
    }
    return as_float((uint) x << 23);
}


// Mixture-of-experts GEMM: MXFP4 weights (src0) times f32 activations (src1).
//
// Each work-group computes a TILESIZE_M x TILESIZE_N (64 x 32) output tile:
//   * get_global_id(1) selects the m-tile (64 consecutive weight rows),
//   * get_global_id(2) selects the n-tile (32 src1 columns routed to expert
//     src2_emap[block_id_n]),
//   * get_local_id(0) selects one weight row of the m-tile; the work-item keeps
//     that row's 32 results in reg_c.
//
// Arguments (layouts as implied by the indexing below):
//   src0_q      packed 4-bit weights read as uint, 8 K values per texel; for the
//               selected expert, (row r, K group g of 8) is at q_expert_base + g*ne01 + r.
//   src0_d      E8M0 scales, one per row per 32 K values, at s_expert_base + (k/32)*ne01 + r.
//   src1        f32 activations read as float4; tile column c starts at (col + c) * ne00.
//   src2        TILESIZE_N output indices per n-tile; 0xFFFFFFFF marks a padding slot.
//   src2_emap   expert id of each n-tile.
//   dst         f32 output; (slot n, row m) is written at src2[...] * ne01 + m.
//   total_tiles total_tiles[0] is the number of valid n-tiles.
//   ne00, ne01  reduction length K and rows per expert of src0.
//
// NOTE(review): the cooperative shared_b fill assumes a dim-0 work-group size of
// TILESIZE_M (64) running as one sub-group (shared_b is guarded only by
// sub_group_barrier), ne01 a multiple of TILESIZE_M, and ne00 a multiple of 32.
// These are presumably enforced by the host launch code; confirm.
__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair
kernel void kernel_gemm_moe_mxfp4_f32_ns(
        __read_only  image1d_buffer_t src0_q,
        __global     uchar *          src0_d,
        __read_only  image1d_buffer_t src1,
        __global     uint *           src2,
        __global     ushort *         src2_emap,
        __write_only image1d_buffer_t dst,
        __global     int *            total_tiles,
        uint ne00,
        uint ne01
) {
    uint block_id_m = get_global_id(1); // m_tile
    uint block_id_n = get_global_id(2); // n_tile
    uint sub_block_id_m = get_local_id(0); // row within the m-tile

    const uint row = block_id_m * TILESIZE_M;
    const uint col = block_id_n * TILESIZE_N;

    // Boundary check. Rows are addressed as row + local id everywhere below
    // (equal to get_global_id(0) + row when the dim-0 NDRange is one work-group).
    // Work-items that return here skip the barriers and their shared_b stores,
    // so the condition must be uniform across the work-group (see note above).
    if (((row + sub_block_id_m) >= ne01) || (block_id_n >= total_tiles[0])) {
        return;
    }

    __private half16 reg_a;
    __private float32 reg_c = (float32)(0);
    __local half4 shared_b[128];

    const ushort expert_id = src2_emap[block_id_n];

    // Per-expert base offsets into src0_q (8 values per texel) and src0_d (one
    // scale per 32 values), hoisted out of the K loop. Dividing the per-expert
    // element count before multiplying by expert_id leaves 8x / 32x more
    // headroom before 32-bit wrap-around than (expert_id * ne00 * ne01) >> 3 / >> 5.
    // Both forms agree when ne00 is a multiple of 32.
    const uint q_expert_base = expert_id * ((ne00 * ne01) >> 3);
    const uint s_expert_base = expert_id * ((ne00 * ne01) >> 5);

    // Each of the 64 work-items loads two float4 chunks of the 16(K) x 32(N) B
    // sub-block: K values 4*(id & 3)..+3 of column (id >> 2) and of column
    // (id >> 2) + 16. In shared_b, entry 32*g + n holds K values 4g..4g+3 of
    // column n, which is the layout dotx16_reduce8 reads.
    uint2 b_global_offset;
    b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00;
    b_global_offset.y = b_global_offset.x + (16 * ne00);
    uint2 b_local_offset;
    b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2);
    b_local_offset.y = b_local_offset.x + 16;

    // Loop along K, one 32-element MXFP4 block per iteration, processed as two
    // TILESIZE_K sub-blocks that share the block's scale.
    for (uint step = 0; step < ne00; step += TILESIZE_K * 2) {
        // ---- First sub-block: K in [step, step + TILESIZE_K) ----
        uint q_sub_offset = row + ((ne01 * step) >> 3) + q_expert_base;
        uint s_sub_offset = row + ((ne01 * step) >> 5) + s_expert_base;
        uint b_sub_offset = col * ne00 + step;

        // Load scale for current mxfp4 block (same row indexing as src0_q)
        float s = e8m0_to_fp32(src0_d[s_sub_offset + sub_block_id_m]);

        // Load 16 fp4 (64 bits) of this work-item's row in transposed layout
        uint2 mxfp4x16;
        mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
        mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;

        // Load this work-item's 8 floats of the 16x32 B sub-block and convert to half
        float8 bx8_f32;
        bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
        bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
        half8 bx8_f16 = convert_half8(bx8_f32);

        // shared_b still holds the previous sub-block, which other lanes may be
        // reading: wait for them before overwriting it (write-after-read hazard).
        sub_group_barrier(CLK_LOCAL_MEM_FENCE);
        shared_b[b_local_offset.x] = bx8_f16.lo;
        shared_b[b_local_offset.y] = bx8_f16.hi;

        // Dequantization
        reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s;
        reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s;

        // Make every lane's shared_b stores visible before they are read.
        sub_group_barrier(CLK_LOCAL_MEM_FENCE);

        // 16x32 fp16 dot products; partial sums over 8 K values are kept in half
        // and then accumulated into the float reg_c for better precision.
        // Scratch accumulator named `acc`, kept in scope because dotx16_reduce8
        // may refer to it.
        half16 acc;
        dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
        dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);

        // ---- Second sub-block: K in [step + TILESIZE_K, step + 2*TILESIZE_K), same scale ----
        uint half_step = step + TILESIZE_K;
        q_sub_offset = row + ((ne01 * half_step) >> 3) + q_expert_base;
        b_sub_offset = col * ne00 + half_step;

        // Load next 16 fp4 (64 bits) in transposed layout
        mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
        mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;

        // Load this work-item's 8 floats of the next B sub-block and convert to half
        bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
        bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
        bx8_f16 = convert_half8(bx8_f32);

        // Wait until all lanes have finished reading the first sub-block.
        sub_group_barrier(CLK_LOCAL_MEM_FENCE);
        shared_b[b_local_offset.x] = bx8_f16.lo;
        shared_b[b_local_offset.y] = bx8_f16.hi;

        // Dequantization
        reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s;
        reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s;

        sub_group_barrier(CLK_LOCAL_MEM_FENCE);

        // Same 8-element fp16 reduction as the first sub-block
        dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
        dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
    }

    // Load the post-routing output indices of this tile and share them in LM.
    // Padding slots (0xFFFFFFFF) are redirected to slot 0's output row; their
    // results are overwritten by the final slot-0 store below.
    // NOTE(review): if slot 0 itself were padding, the index would be derived from
    // 0xFFFFFFFF; presumably every valid tile has a real slot 0 — confirm.
    __local uint out_idx[TILESIZE_N];

    if (get_local_id(0) < TILESIZE_N) {
        uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)];
        if (idx == 0xFFFFFFFF) {
            idx = src2[block_id_n * TILESIZE_N + 0];
        }
        out_idx[get_local_id(0)] = idx * ne01;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Scatter results back to their original positions in the output
    uint m_offset = row + get_local_id(0);

    write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1));
    write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2));
    write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3));
    write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4));
    write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5));
    write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6));
    write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7));
    write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8));
    write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9));
    write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa));
    write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb));
    write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc));
    write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd));
    write_imagef(dst, out_idx[14] + m_offset, (reg_c.se));
    write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf));
    write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg));
    write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh));
    write_imagef(dst, out_idx[18] + m_offset, (reg_c.si));
    write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj));
    write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk));
    write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl));
    write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm));
    write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn));
    write_imagef(dst, out_idx[24] + m_offset, (reg_c.so));
    write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp));
    write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq));
    write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr));
    write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss));
    write_imagef(dst, out_idx[29] + m_offset, (reg_c.st));
    write_imagef(dst, out_idx[30] + m_offset, (reg_c.su));
    write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv));

    // Slot 0 is written last so it overrides any padding slot that was
    // redirected to the same location above.
    // NOTE(review): dst is an image; under OpenCL 2.0+ image access ordering is
    // governed by CLK_IMAGE_MEM_FENCE rather than CLK_GLOBAL_MEM_FENCE. The
    // override relies on one work-item's writes to the same texel landing in
    // program order — confirm on the target driver.
    barrier(CLK_GLOBAL_MEM_FENCE);
    write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0));
}