// llama-cpp-sys-4 0.3.1
// Low-level bindings to llama.cpp
// Documentation
#version 450

#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#include "mul_mat_vec_base.glsl"
#include "dequant_funcs.glsl"

// 1-D workgroup; its x size is supplied by specialization constant 0.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// K_PER_ITER: number of consecutive K-dimension elements one invocation consumes
// per iteration. Non-quantized A (F32/F16/BF16) works in 4-wide chunks; every
// quantized A format works in 8-wide chunks.
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
#define K_PER_ITER 4
#else
#define K_PER_ITER 8
#endif

// Module-scope (per-invocation) state, written by compute_outputs() and read by
// the iteration helpers below.
uint a_offset;  // filled by get_offsets(); passed to the A dequantize helpers
uint b_offset;  // filled by get_offsets(); added to every data_b / data_b_v4 index
uint d_offset;  // filled by get_offsets(); passed to reduce_result()
uint y_offset;  // spacing between paired B elements: 1 if QUANT_R == 1, else QUANT_K/2

// Loads the (up to) four B values that pair with one 4-wide chunk of A, for B
// column j. The values are y_offset apart, starting at K index iybs + iqs.
//
// When lastiter is true, trailing lanes whose K index is >= p.ncols are flagged
// through OOB_y / OOB_z / OOB_w. Those lanes are not read from data_b and are
// returned as 0. Lane .x is always read.
vec4 load_b(const uint j, const uint iybs, const uint iqs, const bool lastiter, out bool OOB_y, out bool OOB_z, out bool OOB_w) {
    const uint k_first = iybs + iqs;                          // K index of lane .x
    const uint src = j*p.batch_stride_b + b_offset + k_first; // data_b index of lane .x

    OOB_y = lastiter && (k_first + y_offset >= p.ncols);
    OOB_z = lastiter && (k_first + y_offset*2 >= p.ncols);
    OOB_w = lastiter && (k_first + y_offset*3 >= p.ncols);

    // The flags are nested (OOB_y implies OOB_z implies OOB_w), so each lane can
    // be fetched on its own; skipped lanes keep their zero initializer.
    vec4 vals = vec4(0);
    vals.x = float(FLOAT_TYPE(data_b[src]));
    if (!OOB_y) {
        vals.y = float(FLOAT_TYPE(data_b[src + y_offset]));
    }
    if (!OOB_z) {
        vals.z = float(FLOAT_TYPE(data_b[src + y_offset*2]));
    }
    if (!OOB_w) {
        vals.w = float(FLOAT_TYPE(data_b[src + y_offset*3]));
    }
    return vals;
}

// Accumulates one K_PER_ITER-wide chunk of each output's dot product into temp.
//
//   temp      - per-invocation partial sums, indexed [B column j][row offset n].
//   first_row - first row of A handled by this workgroup.
//   num_rows  - number of rows (starting at first_row) to accumulate; at most NUM_ROWS.
//   tid       - local invocation index; selects this invocation's chunk within the stride.
//   i         - iteration index already multiplied by K_PER_ITER (callers pass i*K_PER_ITER),
//               so the chunk starts at column i*BLOCK_SIZE + K_PER_ITER*tid.
//   lastiter  - only read by the K_PER_ITER == 4 path: when true, B elements at or past
//               p.ncols are not loaded and their products are not accumulated.
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
{
    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
        const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
        const uint iybs = col - col%QUANT_K; // y block start index

#if K_PER_ITER == 8
        // Quantized A: gather the 8 matching B values as two vec4s, bv0 and bv1.
#if QUANT_R == 2
        // With QUANT_R == 2 the paired B values sit y_offset apart, so two vec4
        // loads are interleaved into bv0/bv1. Note that we end up fetching bogus
        // elements here, but it's fine as they'll be within an accessible block.
        const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
        const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]);
        const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
        const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
#else
        // Otherwise the 8 B values are contiguous: two consecutive vec4 loads.
        const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
        const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]);
#endif
#else
        // Non-quantized A: load up to 4 B values. OOB_* report which of the
        // .y/.z/.w lanes fall at or past p.ncols when lastiter is set.
        bool OOB_y;
        bool OOB_z;
        bool OOB_w;

        const vec4 b = load_b(j, iybs, iqs, lastiter, OOB_y, OOB_z, OOB_w);
#endif
        // ibi is the flattened index of column 0 of row (first_row + n) of A.
        uint ibi = first_row*p.ncols;
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            const uint ib = (ibi + col)/QUANT_K; // block index
            ibi += p.ncols;

#if K_PER_ITER == 8
            vec4 v = dequantize4(ib, iqs, a_offset);
            vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);

            // Per-block multiplier (dm.x) and additive min (dm.y). A min must be
            // applied element-wise before the dot product; without one, the
            // multiplier is factored out and applied once to the sum below.
            const vec2 dm = get_dm(ib, a_offset);
            if (dm.y != 0) { // quant has min component
                v = v * dm.x + dm.y;
                v2 = v2 * dm.x + dm.y;
            }

            // matrix multiplication
            FLOAT_TYPE rowtmp = dot(bv0, v);
            rowtmp += dot(bv1, v2);

            if (dm.y == 0)
                rowtmp *= dm.x;

            temp[j][n] += rowtmp;
#else
            // Use the widest dot product that covers only the in-range lanes.
            if (!OOB_w) {
                // All 4 lanes valid.
                const vec4 v = dequantize4(ib, iqs, a_offset);
                temp[j][n] += dot(v, b);
            } else if (!OOB_z) {
                // Lanes x, y, z valid: a pair from dequantize() plus one more element.
                const vec2 v0 = dequantize(ib, iqs, a_offset);
                const FLOAT_TYPE v1 = dequantize1(ib + 2/QUANT_R, iqs, a_offset);
                const vec3 v = vec3(v0.x, v0.y, v1);
                const vec3 b0 = vec3(b.x, b.y, b.z);
                temp[j][n] += dot(v, b0);
            } else if (!OOB_y) {
                // Lanes x, y valid.
                const vec2 v0 = dequantize(ib, iqs, a_offset);
                const vec2 b0 = vec2(b.x, b.y);
                temp[j][n] += dot(v0, b0);
            } else {
                // Only lane x valid.
                const FLOAT_TYPE v = dequantize1(ib, iqs, a_offset);
                temp[j][n] = fma(v, b.x, temp[j][n]);
            }
#endif
        }
    }
}

#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
// Fast path for non-quantized A (F32/F16/BF16). Accumulates one 4-wide chunk of
// each output's dot product into temp, using a single vec4 load of B.
//
// There are no bounds checks and no lastiter handling. Callers must only use this
// when every chunk starts 4-aligned in B and lies entirely inside the row (see
// is_aligned_nonquant in compute_outputs).
//   temp      - per-invocation partial sums, indexed [B column j][row offset n].
//   first_row - first row of A handled by this workgroup.
//   num_rows  - number of rows (starting at first_row) to accumulate; at most NUM_ROWS.
//   tid       - local invocation index; selects this invocation's chunk.
//   i         - iteration index already multiplied by K_PER_ITER.
void iter_aligned_nonquant(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i)
{
    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        // First K index of this invocation's chunk. The quant index is fixed at 0
        // and the block start is the column itself (presumably QUANT_K == 1 for
        // these types — matches iter() only under that assumption).
        const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
        const uint iqs = 0; // quant index
        const uint iybs = col; // y block start index

        // Explicit vec4() conversion, consistent with the data_b_v4 loads in iter().
        // It is a no-op when data_b_v4 already holds vec4 and avoids relying on an
        // implicit conversion otherwise.
        const vec4 b = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);

        // ibi is the flattened index of column 0 of row (first_row + n) of A.
        uint ibi = first_row*p.ncols;
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            const uint ib = (ibi + col)/QUANT_K; // block index
            ibi += p.ncols;

            const vec4 v = dequantize4_2aligned(ib, iqs, a_offset);

            // matrix multiplication
            temp[j][n] += dot(v, b);
        }
    }
}
#endif

// Computes num_rows output rows, starting at first_row, for each of the NUM_COLS
// columns of B. The per-invocation partial sums are passed to reduce_result().
//
// Each invocation walks the K dimension (p.ncols) in K_PER_ITER-wide chunks. Chunk
// `it` starts at column it*K_PER_ITER*BLOCK_SIZE + K_PER_ITER*tid. Chunks are
// issued in manually unrolled groups of 4, then 2, then one at a time. Only the
// one-at-a-time tail loop passes lastiter=true to iter(); that is what lets the
// K_PER_ITER == 4 path of iter() mask B elements past the end of the row.
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
    const uint tid = gl_LocalInvocationID.x;

    get_offsets(a_offset, b_offset, d_offset);

#if K_PER_ITER == 4
    // iter_aligned_nonquant() reads B through data_b_v4 (index / 4) and has no
    // bounds checks. It is only valid when every chunk starts on a 4-element
    // boundary in B and the row length is a multiple of 4 (no partial chunk).
    const bool is_aligned_nonquant =
        p.batch_stride_b % 4 == 0 && b_offset % 4 == 0 &&
        p.ncols % 4 == 0 && BLOCK_SIZE % 4 == 0;
#endif

    y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;

    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];

    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
            temp[j][i] = FLOAT_TYPE(0);
        }
    }

    // Chunks this invocation processes: one per full K_PER_ITER*BLOCK_SIZE stride,
    // plus one more if its chunk in the final partial stride still starts in range.
    uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
    if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
        num_iters++;
    }

#if K_PER_ITER == 4
    // If the K dimension is not a multiple of K_PER_ITER, this invocation's last
    // chunk may extend past p.ncols. It must then run in the tail loop
    // (lastiter == true) so iter() computes the OOB lanes correctly.
    const bool k_is_ragged = (p.ncols & (K_PER_ITER - 1)) != 0;
#endif

    uint i = 0;

    // ---- Groups of 4 ----
    uint unroll_count = 4;
    uint unrolled_iters = num_iters & ~(unroll_count - 1);
#if K_PER_ITER == 4
    // Hold back one group if it would otherwise consume the final chunk.
    if (k_is_ragged && unrolled_iters == num_iters && unrolled_iters > 0) {
        unrolled_iters -= unroll_count;
    }
    if (is_aligned_nonquant) {
        while (i < unrolled_iters) {
            // Manually partially unroll the loop
            [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
                iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER);
                i++;
            }
        }
    } else
#endif
    {
        while (i < unrolled_iters) {
            // Manually partially unroll the loop
            [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
                iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
                i++;
            }
        }
    }

    // ---- Groups of 2 ----
    unroll_count = 2;
    unrolled_iters = num_iters & ~(unroll_count - 1);
#if K_PER_ITER == 4
    if (k_is_ragged && unrolled_iters == num_iters && unrolled_iters > 0) {
        unrolled_iters -= unroll_count;
    }
    if (is_aligned_nonquant) {
        while (i < unrolled_iters) {
            // Manually partially unroll the loop
            [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
                iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER);
                i++;
            }
        }
    } else
#endif
    {
        while (i < unrolled_iters) {
            // Manually partially unroll the loop
            [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
                iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
                i++;
            }
        }
    }

    // ---- Remaining chunks, one at a time ----
#if K_PER_ITER == 4
    if (is_aligned_nonquant) {
        // p.ncols % 4 == 0 and every chunk starts on a multiple of 4 here, so no
        // chunk can cross the end of the row; no OOB handling is needed.
        while (i < num_iters) {
            iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER);
            i++;
        }
    } else
#endif
    {
        while (i < num_iters) {
            iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);
            i++;
        }
    }

    reduce_result(temp, d_offset, first_row, num_rows, tid);
}

// Entry point. The x and z workgroup IDs are flattened into a tile index. Each
// workgroup computes one tile of NUM_ROWS consecutive output rows; p.stride_d is
// used here as the total number of output rows.
void main() {
    const uint tile = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
    const uint row_begin = NUM_ROWS * tile;

#ifdef NEEDS_INIT_IQ_SHMEM
    // Runs before any early return below. Presumably it contains workgroup
    // barriers every invocation must reach — NOTE(review): confirm.
    init_iq_shmem(gl_WorkGroupSize);
#endif

    // Common case: the whole tile is in range. This stays a separate call that
    // passes NUM_ROWS directly, presumably so compute_outputs' row loops see a
    // fixed count.
    if (row_begin + NUM_ROWS <= p.stride_d) {
        compute_outputs(row_begin, NUM_ROWS);
        return;
    }

    // Tile starts at or past the last row: nothing to do.
    if (row_begin >= p.stride_d) {
        return;
    }

    // Partial tile at the end of the matrix.
    compute_outputs(row_begin, min(NUM_ROWS, p.stride_d - row_begin));
}