llama-gguf 0.14.0

A high-performance Rust implementation of llama.cpp: an LLM inference engine with full GGUF support
Documentation
#version 450

// Dequantize Q6_K: 210-byte blocks, 256 elements each.
// Block layout (byte offsets within one block):
//   [  0, 128) u8  ql[128]     low 4 bits of each 6-bit quant (two nibbles per byte)
//   [128, 192) u8  qh[64]      high 2 bits of each quant (four 2-bit fields per byte)
//   [192, 208) i8  scales[16]  signed scales, each covering 16 consecutive elements
//   [208, 210) f16 d           block-wide scale, stored little-endian
// Each output element is d * scales[s] * (q - 32), where q is in [0, 63].
// One workgroup handles one block; each of its 256 invocations emits one element.
layout(local_size_x = 256) in;

// Packed Q6_K blocks, addressed by byte offset through read_byte(). Because a
// block is 210 bytes, block starts are not 4-byte aligned. The host presumably
// pads this buffer to a multiple of 4 bytes — TODO confirm.
layout(set = 0, binding = 0) readonly buffer RawData { uint raw[]; };
// Dequantized output: 256 floats per block, written block-major.
layout(set = 0, binding = 1) writeonly buffer Output { float result[]; };

layout(push_constant) uniform Params {
    int num_blocks;  // number of Q6_K blocks in RawData; surplus workgroups return early
};

// Convert an IEEE 754 binary16 value, held in the low 16 bits of `h`, to a
// binary32 float. Handles signed zero, subnormals (renormalized into a normal
// float), infinities, and NaNs (the 10-bit payload is kept in the top of the
// 23-bit float mantissa).
float half_to_float(uint h) {
    uint sign_bits = (h & 0x8000u) << 16;
    uint exp16     = (h >> 10) & 0x1Fu;
    uint frac16    = h & 0x3FFu;

    // All-ones exponent: infinity (frac16 == 0) or NaN (frac16 != 0).
    if (exp16 == 31u) {
        return uintBitsToFloat(sign_bits | 0x7F800000u | (frac16 << 13));
    }

    if (exp16 == 0u) {
        // +0.0 or -0.0.
        if (frac16 == 0u) {
            return uintBitsToFloat(sign_bits);
        }
        // Subnormal half (value = frac16 * 2^-24). Shift the highest set bit
        // up to bit 10, where the implicit leading 1 sits, then drop it. The
        // float exponent is (1 - 15 + 127) minus the shift count.
        uint shift = uint(10 - findMSB(frac16));
        uint exp32 = 113u - shift;
        uint frac  = (frac16 << shift) & 0x3FFu;
        return uintBitsToFloat(sign_bits | (exp32 << 23) | (frac << 13));
    }

    // Normal number: rebias the exponent from 15 to 127.
    return uintBitsToFloat(sign_bits | ((exp16 + 112u) << 23) | (frac16 << 13));
}

// Fetch the unsigned byte at `byte_offset` from the word-addressed `raw`
// buffer. Bytes are taken little-endian within each 32-bit word.
uint read_byte(uint byte_offset) {
    uint word  = raw[byte_offset >> 2];
    uint shift = (byte_offset & 3u) << 3;
    return (word >> shift) & 0xFFu;
}

// Fetch the byte at `byte_offset` and reinterpret it as a two's-complement
// int8, giving a result in [-128, 127]. bitfieldExtract on a signed int
// sign-extends the extracted 8-bit field.
int read_i8(uint byte_offset) {
    return bitfieldExtract(int(read_byte(byte_offset)), 0, 8);
}

// Read a little-endian unsigned 16-bit value starting at `byte_offset`.
// The two bytes may straddle a 32-bit word boundary, so each byte is
// fetched separately.
uint read_u16(uint byte_offset) {
    uint lo = read_byte(byte_offset);
    uint hi = read_byte(byte_offset + 1u);
    return (hi << 8) | lo;
}

// Dequantize one Q6_K block per workgroup, one element per invocation.
//
// Block index: the workgroup grid is flattened as x + y * gl_NumWorkGroups.x.
// The host can keep dispatching (num_blocks, 1, 1); that case is unchanged
// because gl_WorkGroupID.y == 0. When num_blocks exceeds
// maxComputeWorkGroupCount[0] (Vulkan only guarantees >= 65535), the host can
// instead dispatch a 2D grid with at least num_blocks workgroups. Surplus
// workgroups exit at the num_blocks guard.
//
// Element mapping, with elem = half * 128 + quarter * 32 + l
// (half in {0,1}, quarter in 0..3, l in 0..31):
//   ql byte  = ql[half*64 + (quarter & 1)*32 + l]; quarters 2,3 use the high nibble
//   qh field = (qh[half*32 + l] >> (2*quarter)) & 3
//   scale    = scales[half*8 + l/16 + 2*quarter]
//   result   = d * scale * (q - 32),  q = ql nibble | (qh field << 4)
void main() {
    uint block_idx = gl_WorkGroupID.x + gl_WorkGroupID.y * gl_NumWorkGroups.x;
    uint elem_idx  = gl_LocalInvocationID.x;
    if (block_idx >= uint(num_blocks)) return;

    uint base    = block_idx * 210u;
    uint ql_base = base;
    uint qh_base = base + 128u;
    uint sc_base = base + 192u;
    float d = half_to_float(read_u16(base + 208u));

    uint half_idx = elem_idx >> 7;         // which 128-element half (0 or 1)
    uint quarter  = (elem_idx >> 5) & 3u;  // which 32-element quarter of that half
    uint l        = elem_idx & 31u;        // position within the quarter

    // Branch-free selection, so every invocation runs the same instructions.
    // Odd quarters read the second 32 bytes of this half's ql run, and
    // quarters 2 and 3 take the high nibble.
    uint ql_byte = read_byte(ql_base + half_idx * 64u + (quarter & 1u) * 32u + l);
    uint ql_val  = (ql_byte >> ((quarter >> 1) * 4u)) & 0x0Fu;

    // One qh byte packs the high 2 bits for all four quarters at position l.
    uint qh_val  = (read_byte(qh_base + half_idx * 32u + l) >> (quarter * 2u)) & 0x03u;

    // Each scale covers 16 consecutive elements.
    int scale = read_i8(sc_base + half_idx * 8u + (l >> 4) + quarter * 2u);

    int q = int(ql_val | (qh_val << 4)) - 32;  // 6-bit quant, re-centred to [-32, 31]
    result[block_idx * 256u + elem_idx] = d * float(scale) * float(q);
}