llama-gguf 0.14.0

A high-performance Rust implementation of llama.cpp - LLM inference engine with full GGUF support
Documentation
// Online-softmax multi-head attention with causal masking and GQA
// Q: [num_heads, seq_len, head_dim]  K: [num_kv_heads, kv_len, head_dim]
// V: [num_kv_heads, kv_len, head_dim]  Out: [num_heads, seq_len, head_dim]
// Dispatch: (num_heads, seq_len, 1) groups of (256, 1, 1)

// Attention parameters for one dispatch, bound at constant-buffer slot b0.
cbuffer Params : register(b0) {
    int num_heads;     // number of query heads (leading dimension of Q and Out)
    int num_kv_heads;  // number of K/V heads; main maps each query head to one by integer division
    int seq_len;       // number of query rows per head
    int kv_len;        // number of key/value positions per KV head
    int head_dim;      // elements per head vector; also indexes the 256-entry accum array in main
    float scale;       // multiplier applied to each Q.K dot product before the softmax
};

// Flat float buffers. main indexes q_data and out_data as
// (head * seq_len + row) * head_dim + d, and k_data and v_data as
// (kv_head * kv_len + kv_pos) * head_dim + d.
// q_data, k_data and v_data are only read by main; out_data is only written.
RWStructuredBuffer<float> q_data : register(u0);
RWStructuredBuffer<float> k_data : register(u1);
RWStructuredBuffer<float> v_data : register(u2);
RWStructuredBuffer<float> out_data : register(u3);

// Per-group scratch state for the online-softmax loop in main.

// Running weighted sum of V, one slot per head dimension. main indexes it with
// d < head_dim, so a head_dim above 256 would index past the end of this array.
groupshared float accum[256];
// One partial Q.K dot product per thread; tree-reduced so that slot 0 holds the total.
groupshared float reduction[256];
// Largest scaled score seen so far (written by thread 0 only).
groupshared float s_max_score;
// Running sum of exp(score - s_max_score) (written by thread 0 only).
groupshared float s_sum_exp;
// Softmax weight of the current key; reused after the loop to hold 1 / s_sum_exp.
groupshared float s_weight;
// Factor by which existing accum entries are rescaled when the running max increases.
groupshared float s_correction;

// Entry point: one thread group handles one (query head, query row) pair.
//   gid.x  -> query head index
//   gid.y  -> query row index within the head
//   gtid.x -> lane within the 256-thread group
// Keys are visited one at a time. For each key the group reduces the Q.K dot
// product, lane 0 updates the running max / sum (online softmax), and every
// lane then folds the matching V row into accum. The normalised result is
// written to out_data at the same offset the query row was read from.
[numthreads(256, 1, 1)]
void main(uint3 gid : SV_GroupID, uint3 gtid : SV_GroupThreadID) {
    const uint group_size = 256;
    const uint lane = gtid.x;
    const uint q_head = gid.x;
    const uint q_row = gid.y;

    const uint dim_count = (uint)head_dim;
    const uint n_keys = (uint)kv_len;
    const uint n_rows = (uint)seq_len;

    // Grouped-query attention: consecutive query heads share one KV head.
    const uint heads_per_kv = (uint)num_heads / (uint)num_kv_heads;
    const uint kv_head = q_head / heads_per_kv;

    // The query rows are aligned to the end of the key range, so row r may
    // attend to keys 0 .. (kv_len - seq_len + r) inclusive.
    const uint last_visible_key = n_keys - n_rows + q_row;

    // Q and Out share the same [head, row, dim] layout, hence the same offset.
    const uint row_offset = (q_head * n_rows + q_row) * dim_count;

    // Reset the shared accumulator and the running softmax statistics.
    for (uint d_init = lane; d_init < dim_count; d_init += group_size) {
        accum[d_init] = 0.0;
    }
    if (lane == 0) {
        s_sum_exp = 0.0;
        s_max_score = -3.402823466e+38;
    }
    GroupMemoryBarrierWithGroupSync();

    uint key_idx = 0;
    while (key_idx < n_keys && key_idx <= last_visible_key) {
        // K and V share the same [kv_head, pos, dim] layout.
        const uint kv_offset = (kv_head * n_keys + key_idx) * dim_count;

        // Each lane sums its strided slice of the Q.K dot product.
        float partial_dot = 0.0;
        for (uint d_dot = lane; d_dot < dim_count; d_dot += group_size) {
            partial_dot += q_data[row_offset + d_dot] * k_data[kv_offset + d_dot];
        }
        reduction[lane] = partial_dot;
        GroupMemoryBarrierWithGroupSync();

        // Pairwise tree reduction; the total ends up in reduction[0].
        for (uint span = group_size >> 1; span != 0; span >>= 1) {
            if (lane < span) {
                reduction[lane] += reduction[lane + span];
            }
            GroupMemoryBarrierWithGroupSync();
        }

        const float logit = reduction[0] * scale;

        // Lane 0 owns the running statistics and publishes this key's weight
        // together with the rescale factor for what was accumulated so far.
        if (lane == 0) {
            const float prev_max = s_max_score;
            float rescale = 1.0;
            if (logit > prev_max) {
                // New maximum: shrink the previous sum to the new reference.
                rescale = exp(prev_max - logit);
                s_sum_exp *= rescale;
                s_max_score = logit;
            }
            const float key_weight = exp(logit - s_max_score);
            s_correction = rescale;
            s_weight = key_weight;
            s_sum_exp += key_weight;
        }
        GroupMemoryBarrierWithGroupSync();

        const float weight_now = s_weight;
        const float rescale_now = s_correction;

        // Rescale the old accumulation and add this key's weighted V row.
        for (uint d_acc = lane; d_acc < dim_count; d_acc += group_size) {
            accum[d_acc] = accum[d_acc] * rescale_now + weight_now * v_data[kv_offset + d_acc];
        }
        // All lanes must have read s_weight / s_correction before lane 0
        // overwrites them again (next key, or the normalisation step below).
        GroupMemoryBarrierWithGroupSync();

        key_idx++;
    }

    // Reuse s_weight to broadcast the reciprocal of the softmax denominator.
    if (lane == 0) {
        float reciprocal = 0.0;
        if (s_sum_exp > 0.0) {
            reciprocal = 1.0 / s_sum_exp;
        }
        s_weight = reciprocal;
    }
    GroupMemoryBarrierWithGroupSync();

    const float norm = s_weight;
    for (uint d_out = lane; d_out < dim_count; d_out += group_size) {
        out_data[row_offset + d_out] = accum[d_out] * norm;
    }
}