llama-gguf 0.14.0

A high-performance Rust implementation of llama.cpp — an LLM inference engine with full GGUF support
Documentation
#version 450

// Online-softmax multi-head attention with causal masking and GQA support
// Q: [num_heads, seq_len, head_dim]  K: [num_kv_heads, kv_len, head_dim]
// V: [num_kv_heads, kv_len, head_dim]  Out: [num_heads, seq_len, head_dim]
// Dispatch: (num_heads, seq_len, 1) workgroups

// One 256-thread workgroup computes one (head, query-position) output row.
layout(local_size_x = 256) in;

// Flattened row-major tensors per the shape comments above.
layout(set = 0, binding = 0) readonly buffer QBuf { float q_data[]; };
layout(set = 0, binding = 1) readonly buffer KBuf { float k_data[]; };
layout(set = 0, binding = 2) readonly buffer VBuf { float v_data[]; };
layout(set = 0, binding = 3) writeonly buffer OutBuf { float out_data[]; };

layout(push_constant) uniform Params {
    int num_heads;    // number of query heads
    int num_kv_heads; // number of key/value heads; GQA assumes it divides num_heads
    int seq_len;      // number of query positions in this dispatch
    int kv_len;       // total KV length; assumed >= seq_len (see q_abs_pos in main)
    int head_dim;     // channels per head; must be <= 256 (size of accum/reduction below)
    float scale;      // score scaling factor — presumably 1/sqrt(head_dim); set by host
};

// Workgroup-shared storage.
shared float accum[256];     // output accumulator, one slot per head_dim channel
shared float reduction[256]; // scratch for the per-key dot-product tree reduction
shared float s_max_score;    // running max score for online softmax (written by tid 0)
shared float s_sum_exp;      // running sum of exp(score - max) (written by tid 0)
shared float s_weight;       // broadcast: weight of current key; reused at the end for 1/sum
shared float s_correction;   // broadcast: rescale factor applied when the max increases

void main() {
    // This workgroup produces Out[head, s, :] — one attention output row.
    uint head = gl_WorkGroupID.x;      // query head index (dispatch x = num_heads)
    uint s = gl_WorkGroupID.y;         // query position (dispatch y = seq_len)
    uint tid = gl_LocalInvocationID.x;
    uint nt = gl_WorkGroupSize.x;      // 256; the tree reduction below relies on nt being a power of two

    // GQA: map this query head to its shared KV head (assumes num_heads % num_kv_heads == 0).
    uint kv_head = head / (uint(num_heads) / uint(num_kv_heads));
    // Absolute position of this query within the KV sequence, used for the causal mask.
    // NOTE(review): unsigned subtraction underflows if kv_len < seq_len — assumes
    // kv_len >= seq_len; confirm at the dispatch site.
    uint q_abs_pos = uint(kv_len) - uint(seq_len) + s;

    // Zero the shared output accumulator (strided so threads cover all head_dim slots).
    for (uint d = tid; d < uint(head_dim); d += nt) {
        accum[d] = 0.0;
    }
    if (tid == 0) {
        s_max_score = -3.402823466e+38; // -FLT_MAX: initial running max
        s_sum_exp = 0.0;                // running softmax denominator
    }
    barrier(); // accum/s_max_score/s_sum_exp initialized before use

    uint q_base = head * uint(seq_len) * uint(head_dim) + s * uint(head_dim);

    // Single pass over keys 0..q_abs_pos (causal mask), maintaining an online softmax:
    // running max, rescaled exp-sum, and a rescaled weighted sum of V rows.
    for (uint kv_pos = 0; kv_pos <= q_abs_pos && kv_pos < uint(kv_len); kv_pos++) {
        // Each thread accumulates a partial Q·K dot product over a strided channel subset.
        float local_dot = 0.0;
        uint k_base = kv_head * uint(kv_len) * uint(head_dim) + kv_pos * uint(head_dim);
        for (uint d = tid; d < uint(head_dim); d += nt) {
            local_dot += q_data[q_base + d] * k_data[k_base + d];
        }

        // Tree-reduce the nt partial sums into reduction[0].
        reduction[tid] = local_dot;
        barrier();
        for (uint stride = nt / 2; stride > 0; stride >>= 1) {
            if (tid < stride) {
                reduction[tid] += reduction[tid + stride];
            }
            barrier(); // uniform: executed by all threads each halving step
        }

        // Scaled attention score for this (query, key) pair; the final barrier of the
        // reduction loop makes reduction[0] safe for all threads to read here.
        float score = reduction[0] * scale;

        // Online-softmax bookkeeping, done by a single thread to keep the shared
        // state consistent, then broadcast via s_weight / s_correction.
        if (tid == 0) {
            float old_max = s_max_score;
            if (score > old_max) {
                // New maximum: rescale the existing exp-sum (and, via s_correction,
                // the accumulator below) into the new max's frame of reference.
                s_correction = exp(old_max - score);
                s_sum_exp *= s_correction;
                s_max_score = score;
            } else {
                s_correction = 1.0; // max unchanged; no rescale needed
            }
            s_weight = exp(score - s_max_score); // 1.0 when score is the new max
            s_sum_exp += s_weight;
        }
        barrier(); // publish s_weight / s_correction to all threads

        float w = s_weight;
        float c = s_correction;

        // accum = accum * c + w * V[kv_head, kv_pos, :]. Each thread owns the same
        // strided channel subset every iteration, so accum writes never conflict.
        uint v_base = kv_head * uint(kv_len) * uint(head_dim) + kv_pos * uint(head_dim);
        for (uint d = tid; d < uint(head_dim); d += nt) {
            accum[d] = accum[d] * c + w * v_data[v_base + d];
        }
        // Required: all threads must be done reading s_weight/s_correction (and
        // reduction[0]) before the next iteration overwrites them.
        barrier();
    }

    // Reuse s_weight to broadcast the normalization factor 1/sum_exp.
    // Guard against a zero denominator (degenerate/fully-underflowed case).
    if (tid == 0) {
        s_weight = (s_sum_exp > 0.0) ? 1.0 / s_sum_exp : 0.0;
    }
    barrier();

    // Normalize and write the output row.
    float inv_sum = s_weight;
    uint out_base = head * uint(seq_len) * uint(head_dim) + s * uint(head_dim);
    for (uint d = tid; d < uint(head_dim); d += nt) {
        out_data[out_base + d] = accum[d] * inv_sum;
    }
}