llama-gguf 0.14.0

A high-performance Rust implementation of llama.cpp - LLM inference engine with full GGUF support
Documentation
#version 450

// Dequantize Q4_K: 144-byte blocks, 256 elements each
// Block: [f16 d, f16 dmin, u8 scales[12], u8 qs[128]]
// One workgroup processes one block; each of its 256 invocations produces
// exactly one output element.
layout(local_size_x = 256) in;

// Quantized input exposed as 32-bit words; individual bytes are extracted
// with shifts and masks (byte 0 of a word is its least-significant 8 bits).
layout(set = 0, binding = 0) readonly buffer RawData { uint raw[]; };
// Dequantized output: 256 consecutive floats per block, in block order.
layout(set = 0, binding = 1) writeonly buffer Output { float result[]; };

layout(push_constant) uniform Params {
    // Number of blocks to dequantize; workgroups whose index is at or beyond
    // this value return without writing anything.
    int num_blocks;
};

// Convert an IEEE 754 binary16 value, held in the low 16 bits of `h`, to a
// 32-bit float by rebuilding the bit pattern by hand.
//
// Handles every class of input: signed zero, subnormals (renormalised into a
// normal float), normals, and Inf/NaN (payload bits are carried over).
float half_to_float(uint h) {
    uint sign_bit = (h & 0x8000u) << 16;   // half bit 15 -> float bit 31
    uint exp5     = (h >> 10) & 0x1Fu;     // 5-bit biased exponent
    uint frac     = h & 0x3FFu;            // 10-bit fraction

    // All-ones exponent: infinity (frac == 0) or NaN (frac != 0).
    if (exp5 == 31u) {
        return uintBitsToFloat(sign_bit | 0x7F800000u | (frac << 13));
    }

    // Float exponent field. Rebias from 15 to 127, i.e. add 112.
    uint biased;
    if (exp5 != 0u) {
        biased = exp5 + 112u;
    } else {
        // Zero exponent: either signed zero or a subnormal.
        if (frac == 0u) {
            return uintBitsToFloat(sign_bit);
        }
        // Subnormal: shift the fraction left until the implicit leading one
        // reaches bit 10, lowering the exponent once per shift. Starting from
        // 113 accounts for the subnormal exponent of -14 (127 - 14).
        biased = 113u;
        do {
            frac <<= 1;
            biased--;
        } while ((frac & 0x400u) == 0u);
        frac &= 0x3FFu;                    // drop the now-implicit leading one
    }

    return uintBitsToFloat(sign_bit | (biased << 23) | (frac << 13));
}

// Fetch the single byte at `byte_offset` from the word-addressed `raw` buffer.
// Byte 0 of each 32-bit word is taken from its least-significant 8 bits.
uint read_byte(uint byte_offset) {
    uint word  = raw[byte_offset >> 2];        // containing 32-bit word
    uint shift = (byte_offset & 3u) << 3;      // 0, 8, 16 or 24
    return (word >> shift) & 0xFFu;
}

// Read an unsigned 16-bit value stored low byte first at `byte_offset`.
// The two bytes are fetched independently, so the offset may straddle a
// 32-bit word boundary.
uint read_u16(uint byte_offset) {
    return read_byte(byte_offset) | (read_byte(byte_offset + 1u) << 8);
}

// Entry point. Workgroup `b` dequantizes block `b`; invocation `e` (0..255)
// computes element `e` of that block and stores it at result[b * 256 + e].
//
// Block layout (144 bytes):
//   bytes   0..1    f16 d     - super-block scale for the 6-bit scales
//   bytes   2..3    f16 dmin  - super-block scale for the 6-bit mins
//   bytes   4..15   12 bytes packing eight 6-bit (scale, min) pairs
//   bytes  16..143  128 bytes of 4-bit quants
//
// The 256 elements form eight sub-blocks of 32. Sub-block `s` uses pair `s`.
// Each run of 64 elements shares 32 quant bytes: the first 32 elements take
// the low nibbles, the next 32 take the high nibbles.
//
// Only the one (scale, min) pair this invocation needs is decoded, so no
// per-invocation arrays are required and just 2-3 scale bytes are read.
void main() {
    uint block_idx = gl_WorkGroupID.x;
    uint elem_idx = gl_LocalInvocationID.x;
    // Workgroups past the last block have nothing to do.
    if (block_idx >= uint(num_blocks)) return;

    uint base = block_idx * 144u;

    float d = half_to_float(read_u16(base));
    float dmin = half_to_float(read_u16(base + 2u));
    uint scales_base = base + 4u;
    uint qs_base = base + 16u;

    // Sub-block of 32 this element belongs to (0..7); also the index of the
    // scale/min pair it uses.
    uint sub = elem_idx / 32u;

    // Unpack the 6-bit scale and min for this sub-block.
    uint scale_bits;
    uint min_bits;
    if (sub < 4u) {
        // Pairs 0..3: low 6 bits of scales[sub] and scales[sub + 4].
        scale_bits = read_byte(scales_base + sub) & 0x3Fu;
        min_bits = read_byte(scales_base + sub + 4u) & 0x3Fu;
    } else {
        // Pairs 4..7: low 4 bits come from the nibbles of scales[sub + 4];
        // the upper 2 bits are stored in the top 2 bits of scales[sub - 4]
        // (scale) and scales[sub] (min).
        uint packed = read_byte(scales_base + sub + 4u);
        uint scale_hi = read_byte(scales_base + sub - 4u) >> 6;
        uint min_hi = read_byte(scales_base + sub) >> 6;
        scale_bits = (packed & 0x0Fu) | (scale_hi << 4);
        min_bits = ((packed >> 4) & 0x0Fu) | (min_hi << 4);
    }

    // Quant byte: each group of 64 elements owns 32 bytes; position within
    // those bytes is the element's index within its sub-block.
    uint q_byte = read_byte(qs_base + (elem_idx / 64u) * 32u + (elem_idx % 32u));
    // Even sub-blocks use the low nibble, odd sub-blocks the high nibble.
    uint q = ((sub & 1u) == 0u) ? (q_byte & 0x0Fu) : ((q_byte >> 4) & 0x0Fu);

    float sc_val = d * float(scale_bits);
    float mn_val = dmin * float(min_bits);

    result[block_idx * 256u + elem_idx] = sc_val * float(q) - mn_val;
}