llama-gguf 0.14.0

A high-performance Rust implementation of llama.cpp - LLM inference engine with full GGUF support
Documentation
#version 450

// Rotary Position Embedding (RoPE) for single-position inference.
// Rotates the query and key buffers in place for one `position` value.
// q/k: [num_heads, head_dim], flat float arrays addressed as
// head * head_dim + dim. Pairs are interleaved (2i, 2i+1) by default, or
// split halves (i, i + rope_dims/2) when use_neox != 0.
// gl_WorkGroupID.x selects the head; gl_LocalInvocationID.x is used to pick
// dimension pairs within that head.
layout(local_size_x = 64) in;

// Query vectors; read and overwritten in place by main().
layout(set = 0, binding = 0) buffer QueryData { float q[]; };
// Key vectors; read and overwritten in place by main().
layout(set = 0, binding = 1) buffer KeyData { float k[]; };

// Per-dispatch parameters supplied by the host as push constants.
layout(push_constant) uniform Params {
    // Number of heads in q; workgroups whose head index is >= this leave q untouched.
    int num_q_heads;
    // Number of heads in k; workgroups whose head index is >= this leave k untouched.
    int num_k_heads;
    // Element stride between consecutive heads in both q and k.
    int head_dim;
    // Token position; multiplied by each pair's frequency to obtain the rotation angle.
    int position;
    // Base of the frequency schedule: pair i uses freq_base^(2i / rope_dims) as divisor.
    float freq_base;
    // Scale applied to every frequency (numerator of the frequency expression).
    float freq_scale;
    // Pair layout selector; any non-zero value is treated as NeoX.
    int use_neox; // 0 = normal, 1 = NeoX style
    // Only the first rope_dims elements of each head are rotated.
    int rope_dims; // number of dimensions to rotate (may be < head_dim for partial RoPE)
};

// Applies rotary position embedding in place to q and k for a single position.
//
// Work mapping:
//   - gl_WorkGroupID.x selects the head. A head index beyond num_q_heads
//     (resp. num_k_heads) leaves q (resp. k) untouched, so one dispatch sized
//     for the larger head count covers both buffers.
//   - Invocations stride over the rope_dims / 2 dimension pairs of that head in
//     steps of the workgroup size, so every pair is rotated even when
//     rope_dims / 2 exceeds local_size_x.
//
// For pair p the angle is position * freq_scale / freq_base^(2p / rope_dims),
// and the pair (x0, x1) becomes (x0*cos - x1*sin, x0*sin + x1*cos).
// Elements at index >= rope_dims within a head are not modified.
void main() {
    uint head_idx = gl_WorkGroupID.x;

    // Fewer than two rotary dimensions means there is no pair to rotate. Testing
    // the signed value first also keeps a negative rope_dims from being
    // reinterpreted as a huge unsigned pair count.
    if (rope_dims < 2) return;
    uint half_rope = uint(rope_dims) / 2u;

    // Both conditions are uniform across the workgroup.
    bool rotate_q = head_idx < uint(num_q_heads);
    bool rotate_k = head_idx < uint(num_k_heads);
    if (!rotate_q && !rotate_k) return;

    // q and k share the same per-head stride, hence the same base offset.
    uint base = head_idx * uint(head_dim);

    // Each pair touches two elements no other pair touches (in either layout),
    // so the strided loop needs no synchronisation between invocations.
    for (uint pair_idx = gl_LocalInvocationID.x;
         pair_idx < half_rope;
         pair_idx += gl_WorkGroupSize.x) {
        // Compute frequency and rotation for this dimension pair
        float freq = freq_scale / pow(freq_base, float(2u * pair_idx) / float(rope_dims));
        float angle = float(position) * freq;
        float cos_val = cos(angle);
        float sin_val = sin(angle);

        uint i0, i1;
        if (use_neox != 0) {
            // NeoX: first half + second half (within rope_dims region)
            i0 = base + pair_idx;
            i1 = i0 + half_rope;
        } else {
            // Normal: consecutive pairs
            i0 = base + 2u * pair_idx;
            i1 = i0 + 1u;
        }

        // Apply RoPE to query (if within q head count)
        if (rotate_q) {
            float q0 = q[i0];
            float q1 = q[i1];
            q[i0] = q0 * cos_val - q1 * sin_val;
            q[i1] = q0 * sin_val + q1 * cos_val;
        }

        // Apply RoPE to key (if within k head count)
        if (rotate_k) {
            float k0 = k[i0];
            float k1 = k[i1];
            k[i0] = k0 * cos_val - k1 * sin_val;
            k[i1] = k0 * sin_val + k1 * cos_val;
        }
    }
}