// llama-cpp-sys-4 0.2.45
//
// Low Level Bindings to llama.cpp
// Documentation
#version 450

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require

#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_buffer_reference : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_vote : enable
#extension GL_EXT_null_initializer : enable

#include "types.glsl"
#include "dequant_funcs_cm2.glsl"
#include "flash_attn_base.glsl"

// Input storage buffers, declared as raw byte arrays. main() hands them to
// coopMatLoadTensorNV together with offsets and tensor layouts:
//   binding 0: data_q - read for the Q tile
//   binding 1: data_k - read for each K block
//   binding 2: data_v - read for each V block
//   binding 3: data_m - read for the mask tile when MASK_ENABLE is set
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
layout (binding = 2) readonly buffer V {uint8_t data_v[];};
layout (binding = 3) readonly buffer M {uint8_t data_m[];};

// Binary combiner handed to coopMatReduceNV: yields the larger of its two
// ACC_TYPE operands.
ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
    const ACC_TYPE larger = max(x, y);
    return larger;
}

// float16_t flavour of the max combiner; main() uses it to reduce a mask tile
// to its maximum value.
float16_t maxReduceFp16(const in float16_t x, const in float16_t y) {
    const float16_t larger = max(x, y);
    return larger;
}

// "Reduction" that always keeps its first operand and ignores the second.
// main() passes it to coopMatReduceNV to copy per-row values into a matrix
// with a different column count (the "smear/reduce" resize).
ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
    const ACC_TYPE kept = x;
    return kept;
}

// Per-element callback: keeps 'elem' when (row, col) lies inside the valid
// numRows x numCols region, otherwise substitutes 'replace'.
ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) {
    const bool insideValidRegion = (row < numRows) && (col < numCols);
    return insideValidRegion ? elem : replace;
}

// Per-element callback computing e^elem. The row/col coordinates are part of
// the callback signature but are not used.
ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem)
{
    const ACC_TYPE result = exp(elem);
    return result;
}

// Per-element callback returning the larger of elem0 and elem1. The row/col
// coordinates are part of the callback signature but are not used.
ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1)
{
    const ACC_TYPE larger = max(elem0, elem1);
    return larger;
}

// DECODEFUNC is appended to the coopMatLoadTensorNV argument lists for K and V.
// It expands to nothing when BLOCK_SIZE is 1 or less, and to an extra
// ", DEQUANTFUNC" argument otherwise.
#if BLOCK_SIZE <= 1
#define DECODEFUNC
#else
#define DECODEFUNC , DEQUANTFUNC
#endif

// Per-element store callback for grouped query attention.
// Row r of the matrix maps to index (iq2 + r) along Q's dimension 2; only the
// first N rows and the first HSV columns are written to data_o. The element is
// always returned unchanged.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    // Padding rows/columns: nothing to store.
    if (r >= N || c >= HSV) {
        return elem;
    }

    const uint32_t dst = o_offset + ((iq2 + r) * HSV + c);
    data_o[dst] = D_TYPE(elem);
    return elem;
}

// Per-element store callback for the intermediate O values in the non-GQA
// split_k path. Matrix rows are tokens (not heads): row r of this tile is
// global row i * Br + r. The fourth parameter is ignored. The element is
// always returned unchanged.
D_TYPE perElemOpNonGqaSplitKStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t unused, const in uint32_t iq2, const in uint32_t N) {
    const uint32_t token_row = i * Br + r;

    // Rows past N and columns past HSV are padding: nothing to store.
    if (token_row >= N || c >= HSV) {
        return elem;
    }

    const uint32_t slice_base = HSV * p.ne1
        * (split_k_index + p.k_num * (token_row + p.ne2 * iq3));
    const uint32_t dst = slice_base + iq2 * HSV + c;
    data_o[dst] = D_TYPE(elem);
    return elem;
}

// Per-element store callback for the per-row L / M values in the non-GQA
// split_k path. Only column 0 of each valid row (global row i * Br + r < N)
// is written. The values go after the HSV * ne1 * k_num * ne2 * ne3 elements
// that precede them in data_o; lm_base is an extra offset supplied by the
// caller (main() passes 0 for L and p.ne1 for M). The element is always
// returned unchanged.
ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t lm_base, const in uint32_t iq2, const in uint32_t N) {
    const uint32_t token_row = i * Br + r;

    // Only the first column of an in-range row carries the value to store.
    if (token_row >= N || c != 0) {
        return elem;
    }

    const uint32_t lm_region = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3
        + p.ne1 * 2 * (split_k_index + p.k_num * (token_row + p.ne2 * iq3));
    const uint32_t dst = lm_region + lm_base + iq2;
    data_o[dst] = D_TYPE(elem);
    return elem;
}

// Shader entry point. Computes attention for the Br-row tile of Q that starts at
// row i * Br, streaming over K/V in Bc-row blocks j in [start_j, end_j) while
// maintaining a running row max (M), a running row sum (L) and an unnormalized
// output accumulator (O). Afterwards it either stores the partial O/L/M values
// (p.k_num > 1, split_k) or scales O by 1/L and stores the final result.
// NOTE(review): i, N, KV, start_j, end_j, iq2/iq3, ik2/ik3, iv2/iv3, gqa_iq1,
// split_k_index, the *_stride variables, p, data_o and data_mask_opt are not
// declared in this file; presumably they come from flash_attn_base.glsl and are
// set up by init_indices() - confirm there.
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

    init_indices();

    // Q always uses constant clamping; K and V use the shader variant's Clamp mode.
    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
    tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp);

    // View that swaps the two dimensions; used below to load K as K^T.
    tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);

#if BLOCK_SIZE > 1
    // K and V are stored in blocks of BLOCK_SIZE elements along the inner dimension.
    tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
    tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif

    tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
    tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
    tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV);

    // hint to the compiler that strides are aligned for the aligned variant of the shader
    if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
    {
        // Clearing the low three bits only makes the multiple-of-8 alignment explicit.
        q_stride &= ~7;
#if BLOCK_SIZE == 1
        k_stride &= ~7;
        v_stride &= ~7;
#endif
        m_stride &= ~7;
    }
    tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
    tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
    tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);

    // Q is loaded in Q_TYPE and then converted to a float16 A-operand matrix.
    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
    coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;

    uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03;
    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));

    // Apply the softmax scale to Q once, up front, rather than to every S tile.
    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
    Qf16 *= float16_t(p.scale);

    // Unnormalized output accumulator, kept in float16.
    coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);

    // L: running row sum, M: running row max. Both are Br x Bc matrices.
    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;

    // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
    const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);

    L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
#if defined(ACC_TYPE_MAX)
    M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-ACC_TYPE_MAX / ACC_TYPE(2));
#else
    M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
#endif

    // Multiplier applied to the mask before it is added to S; stays 1.0 unless ALiBi is active.
    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);

    // ALiBi
    if (p.max_bias > 0.0f) {
        coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
    }

    // Each uint32 of data_mask_opt holds 2-bit flags for 16 consecutive Bc-wide blocks.
    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
    // mo_offset will point to the tile starting at row i*Br and col 0
    uint32_t mo_offset = mo_stride * i;

    uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/;
    if (p.nem2 != 1 || p.nem3 != 1) {
        // The mask has more than one slice in dims 2/3: select the slice for this
        // iq2/iq3 (wrapping with modulo) in both the mask and the mask_opt data.
        m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
    }

    // Cached mask_opt word and the index it was read from (~0 = nothing cached yet).
    uint32_t mask_opt = 0;
    uint32_t mask_opt_idx = ~0;

    [[dont_unroll]]
    for (uint32_t j = start_j; j < end_j; ++j) {

        // Mask tile for this block; stays all-zero when masking is disabled or
        // the block is flagged as all-zero.
        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
        if (MASK_ENABLE) {

            // Reload the mask_opt word only when j moves into a new group of 16 blocks.
            if (USE_MASK_OPT && mask_opt_idx != j / 16) {
                mask_opt_idx = j / 16;
                mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
            }
            uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
            if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
                // skip this block
                continue;
            }
            // Only load if the block is not all zeros
            if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
                // Without GQA, a final partial tile of mask rows needs explicit clamping.
                bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;

                if (nem1_bounds_check) {
                    // Constant-clamp layout: reads outside p.nem1 x KV yield the clamp value.
                    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
                    tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
                    tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
                    tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t

                    coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;

                    coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
                    // skip the block if the mask is entirely -inf
                    coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
                    if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
                        continue;
                    }
                } else {
                    tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
                    // Don't clamp against nem1 when GQA is enabled
                    uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
                    tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
                    tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);

                    coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;

                    coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
                    // skip the block if the mask is entirely -inf
                    coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
                    if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
                        continue;
                    }
                }
            }
        }

        // S = Qf16 * K^T for this block (Br x Bc).
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);

        coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;

        uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
        // Rows j*Bc .. j*Bc+Bc-1 of K, loaded through the transposing view.
        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
        S = coopMatMulAdd(Qf16, K_T, S);

        if (LOGIT_SOFTCAP) {
            // S = logit_softcap * tanh(S), element by element.
            [[unroll]]
            for (int k = 0; k < S.length(); ++k) {
                S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
            }
        }

        if (MASK_ENABLE) {
            // Add the mask, scaled per element by slopeMat.
            S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
        }

        // Clear padding elements to -inf, so they don't contribute to rowmax
        if (Clamp != 0 &&
            ((j + 1) * Bc > KV ||
             (i + 1) * Br > N)) {

            // R x C is the valid region of this tile; the rest is padding.
            uint R = ((i + 1) * Br >  N) ?  (N % Br) : Br;
            uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;

            coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C);
        }

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;

        coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);

        // NOTE(review): FATTN_KQ_MAX_OFFSET is defined outside this file; its value
        // and purpose cannot be confirmed from here.
        rowmax += coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(FATTN_KQ_MAX_OFFSET);

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M;

        // M = max(rowmax, Mold)
        // P = e^(S - M)
        // eM = e^(Mold - M)
        coopMatPerElementNV(M, rowmax, Max, Mold);
        coopMatPerElementNV(P, S - M, Exp);
        coopMatPerElementNV(eM, Mold - M, Exp);

        // Clear padding elements to 0, so they don't contribute to rowsum
        if (Clamp != 0 &&
            ((j + 1) * Bc > KV ||
             (i + 1) * Br > N)) {

            uint R = ((i + 1) * Br >  N) ?  (N % Br) : Br;
            uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;

            coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C);
        }

        // float16 A-operand copy of P, used for both the rowsum and the P*V products.
        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);

        // compute rowsum by multiplying by matrix of all ones.
        coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);

        rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
        rowsum = coopMatMulAdd(P_A, One, rowsum);

        coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
        uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
        coopMatLoadTensorNV(V,  data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);

        // Rescale the previous sum to the new max, then add this block's contribution.
        L = eM*L + rowsum;

        // This is the "diagonal" matrix in the paper, but since we do componentwise
        // multiply rather than matrix multiply it has the diagonal element smeared
        // across the row
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> eMdiag;

        // resize eM by using smear/reduce
        coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);

        // O = eMdiag * O + P * V
        O *= coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(eMdiag);
        O = coopMatMulAdd(P_A, V, O);
    }

    // If there is split_k, then the split_k resolve shader does the final
    // division by L. Store the intermediate O value and per-row m and L values.
    if (p.k_num > 1) {
        coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);

        if (p.gqa_ratio > 1) {
            // note: O and Q have swapped coord 1,2.
            uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
            coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);

            // L and M are stored after all the O values; M follows L at an extra offset of p.ne1.
            o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
            coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
            coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
        } else {
            // Non-GQA: the store callbacks compute their own offsets from the global row.
            coopMatPerElementNV(O_D, O_D, perElemOpNonGqaSplitKStore, 0u, iq2, N);
            coopMatPerElementNV(L, L, perElemOpNonGqaSplitKStoreCol0, 0u, iq2, N);
            coopMatPerElementNV(M, M, perElemOpNonGqaSplitKStoreCol0, p.ne1, iq2, N);
        }
        return;
    }

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Ldiag;

    // resize L by using smear/reduce
    coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);

    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
        // S is filled entirely by perElemOpGetSink (defined outside this file).
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> S;
        coopMatPerElementNV(S, S, perElemOpGetSink, iq2);

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Mr;

        // resize M by using smear/reduce
        coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);

        // O, Ldiag, Mr (and S) are all Br x HSV_pad accumulator matrices, so the code
        // relies on element index i addressing the same location in each of them.
        // Note: this loop index shadows the tile index `i` used elsewhere in main().
        [[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) {
            ACC_TYPE sink = S[i];

            // ms rescales the existing O/L values, vs is the sink's own contribution to L.
            ACC_TYPE ms = ACC_TYPE(1.0f);
            ACC_TYPE vs = ACC_TYPE(1.0f);

            if (sink > Mr[i]) {
                // The sink becomes the new max: scale down what was accumulated so far.
                ms = exp(Mr[i] - sink);

                O[i] *= float16_t(ms);
            } else {
                // The max is unchanged: the sink contributes e^(sink - max) to the sum.
                vs = exp(sink - Mr[i]);
            }

            Ldiag[i] = Ldiag[i]*ms + vs;
        }
    }

    // Invert L element-wise; a zero sum maps to 0 instead of dividing by zero.
    [[unroll]]
    for (int k = 0; k < Ldiag.length(); ++k) {
        Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]);
    }

    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);

    // Final normalization: O_D = O / L.
    O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(Ldiag)*O_D;

#if defined(ACC_TYPE_MAX)
    // Keep the result within [-ACC_TYPE_MAX, ACC_TYPE_MAX].
    [[unroll]] for (uint i = 0; i < O_D.length(); ++i) { O_D[i] = clamp(O_D[i], D_TYPE(-ACC_TYPE_MAX), D_TYPE(ACC_TYPE_MAX)); }
#endif

    uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;

    if (p.gqa_ratio > 1) {
        // GQA: rows are written individually by the per-element store callback.
        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
    } else {
        // Non-GQA: store the tile through a 3D (p.ne2, p.ne1, HSV) layout.
        tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
        tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV);

        // permute dimensions
        tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);

        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute);
    }
}