// llama-gguf 0.14.0
//
// A high-performance Rust implementation of llama.cpp - LLM inference engine with full GGUF support
// (crate documentation header retained as a comment so the shader compiles)
#version 450

// Online-softmax cached attention for single-token generation with GQA
// Q: [num_heads, 1, head_dim]  K_cache: [num_kv_heads, max_seq_len, head_dim]
// V_cache: [num_kv_heads, max_seq_len, head_dim]  Out: [num_heads, 1, head_dim]
// kv_len = valid positions in cache, max_seq_len = stride between heads
// Dispatch: (num_heads, 1, 1) workgroups

// One workgroup per query head; 256 threads cooperate on dot products
// and the weighted V accumulation for that head.
layout(local_size_x = 256) in;

// Storage buffers: Q / K-cache / V-cache inputs, attention output.
// All tensors are flat float arrays with the layouts documented above.
layout(set = 0, binding = 0) readonly buffer QBuf { float q_data[]; };
layout(set = 0, binding = 1) readonly buffer KBuf { float k_data[]; };
layout(set = 0, binding = 2) readonly buffer VBuf { float v_data[]; };
layout(set = 0, binding = 3) writeonly buffer OutBuf { float out_data[]; };

// Per-dispatch scalar parameters (uniform across the workgroup).
layout(push_constant) uniform Params {
    int num_heads;     // number of query heads (== gl_NumWorkGroups.x)
    int num_kv_heads;  // number of KV heads; num_heads must be a multiple of this (GQA)
    int kv_len;        // valid cached positions to attend over
    int max_seq_len;   // row stride between KV heads in the cache
    int head_dim;      // per-head embedding size; assumes head_dim <= 256 (size of accum[]) — TODO confirm host enforces this
    float scale;       // score scale, typically 1/sqrt(head_dim) — set by host
    float softcap; // 0.0 = disabled, >0 = cap * tanh(score / cap)
};

// Workgroup-shared state for the online softmax:
shared float accum[256];      // running weighted sum of V rows (un-normalized output), indexed by dim d
shared float reduction[256];  // scratch for the per-position Q.K tree reduction
shared float s_max_score;     // running max score seen so far (for numerical stability)
shared float s_sum_exp;       // running sum of exp(score - max), rescaled when max changes
shared float s_weight;        // exp weight of the current position; reused as 1/sum at the end
shared float s_correction;    // rescale factor applied to accum when a new max is found
void main() {
    // One workgroup handles one query head; tid strides over head dims / reduction slots.
    uint head = gl_WorkGroupID.x;
    uint tid = gl_LocalInvocationID.x;
    uint nt = gl_WorkGroupSize.x;

    // GQA: consecutive query heads share a KV head (integer division maps the group).
    uint kv_head = head / (uint(num_heads) / uint(num_kv_heads));

    // Zero the output accumulator and reset the online-softmax state.
    // NOTE: accum[] has 256 slots, so this kernel assumes head_dim <= 256.
    for (uint d = tid; d < uint(head_dim); d += nt) {
        accum[d] = 0.0;
    }
    if (tid == 0) {
        s_max_score = -3.402823466e+38; // -FLT_MAX: identity for running max
        s_sum_exp = 0.0;
    }
    barrier();

    uint q_base = head * uint(head_dim);

    // Stream over cached positions one at a time, maintaining a numerically
    // stable softmax incrementally (rescale previous state when the max grows).
    for (uint kv_pos = 0; kv_pos < uint(kv_len); kv_pos++) {
        // --- Q . K[kv_pos]: each thread sums a strided slice of the dot product ---
        float local_dot = 0.0;
        uint k_base = kv_head * uint(max_seq_len) * uint(head_dim) + kv_pos * uint(head_dim);
        for (uint d = tid; d < uint(head_dim); d += nt) {
            local_dot += q_data[q_base + d] * k_data[k_base + d];
        }

        // Tree-reduce the partial dots into reduction[0].
        // Requires nt to be a power of two (256 here); barrier is in uniform
        // control flow so every invocation reaches it each pass.
        reduction[tid] = local_dot;
        barrier();
        for (uint stride = nt / 2; stride > 0; stride >>= 1) {
            if (tid < stride) {
                reduction[tid] += reduction[tid + stride];
            }
            barrier();
        }

        // All threads read the reduced dot (visible after the final barrier above).
        float score = reduction[0] * scale;
        if (softcap > 0.0) {
            // Gemma-style logit soft-capping: smoothly bounds |score| by softcap.
            score = softcap * tanh(score / softcap);
        }

        // --- Online softmax update (single thread owns the shared scalars) ---
        // If this score exceeds the running max, rescale the existing sum (and,
        // via s_correction, the accumulated output) into the new max's frame.
        if (tid == 0) {
            float old_max = s_max_score;
            if (score > old_max) {
                s_correction = exp(old_max - score); // 0 on first iter (exp(-inf))
                s_sum_exp *= s_correction;
                s_max_score = score;
            } else {
                s_correction = 1.0;
            }
            s_weight = exp(score - s_max_score);
            s_sum_exp += s_weight;
        }
        barrier(); // publish s_weight / s_correction to all threads

        float w = s_weight;
        float c = s_correction;

        // --- Fold V[kv_pos] into the accumulator: accum = accum*c + w*V ---
        uint v_base = kv_head * uint(max_seq_len) * uint(head_dim) + kv_pos * uint(head_dim);
        for (uint d = tid; d < uint(head_dim); d += nt) {
            accum[d] = accum[d] * c + w * v_data[v_base + d];
        }
        // Ensure all accum writes (and the w/c reads above) complete before
        // the next iteration's tid-0 update overwrites the shared scalars.
        barrier();
    }

    // Normalize by the softmax denominator; kv_len == 0 leaves sum at 0,
    // in which case the output is zeroed instead of dividing by zero.
    if (tid == 0) {
        s_weight = (s_sum_exp > 0.0) ? 1.0 / s_sum_exp : 0.0;
    }
    barrier();

    float inv_sum = s_weight;
    uint out_base = head * uint(head_dim);
    for (uint d = tid; d < uint(head_dim); d += nt) {
        out_data[out_base + d] = accum[d] * inv_sum;
    }
}