ferrum-kernels 0.7.7

Unified compute kernels (CUDA/Metal/CPU) and model runner for Ferrum inference
Documentation
// Q6_K GEMV — adapted from llama.cpp's kernel_mul_mv_q6_K_f32_impl
// (ggml/src/ggml-metal/ggml-metal.metal, MIT licensed).
//
// Q6_K block layout (256 weights per super-block, 210 bytes = 6.5625 bits/weight):
//   uint8_t  ql[128];        // lower 4 bits of each weight
//   uint8_t  qh[64];          // upper 2 bits (packed 4 weights per byte)
//   int8_t   scales[16];      // 16 sub-block scales (one per 16 weights)
//   half     d;               // super-block scale
//
// Each weight reconstructs as `int8(low4 | high2_shifted) - 32`, scaled by
// `d * scales[sub]`.
//
// Threadgroup: (32, N_SG, 1) — same N_R0=2, N_SG=2 layout as q4_k_gemv_v2.

#include <metal_stdlib>
using namespace metal;

// Number of weights in one Q6_K super-block.
#define QK_K  256
// Output rows accumulated by each simdgroup (first_row advances in steps of N_R0).
#define N_R0  2
// Simdgroups per threadgroup; must equal the threadgroup's y extent, i.e. the
// kernel expects threadgroups shaped (32, N_SG, 1).
// NOTE(review): the host-side dispatch presumably hard-codes the same N_R0/N_SG
// values; confirm before changing either one.
#define N_SG  2
// Prefixes a `for` header with clang's full-unroll pragma. Full unrolling needs a
// compile-time trip count, so only use it on fixed-length loops.
#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x)

// One Q6_K super-block: QK_K (256) quantized weights packed into
// 128 + 64 + 16 + 2 = 210 bytes.
// Device code indexes arrays of this struct directly, so sizeof(block_q6_K)
// must stay 210. Do not reorder, resize or pad the fields.
// NOTE(review): presumably this must byte-match the ggml/GGUF Q6_K layout
// produced by the weight loader; confirm against the host side.
struct block_q6_K {
    uchar  ql[QK_K / 2];     // 128 bytes: low 4 bits of each weight, two weights per byte
    uchar  qh[QK_K / 4];     // 64 bytes: high 2 bits of each weight, four weights per byte
    int8_t scales[QK_K / 16]; // 16 signed sub-block scales, one per 16 weights
    half   d;                 // fp16 super-block scale applied on top of scales[]
};

// Uniform parameters for gemv_f32a_q6kw_v2, bound as `constant` at buffer(3).
// NOTE(review): presumably mirrored by a host-side struct; keep field order and
// 32-bit int types in sync with it.
struct GemvQ6KParams {
    int N;        // out_features: number of weight rows and of dst elements
    int K;        // in_features (multiple of 256); a trailing partial super-block is not processed
    int nb01;     // src0 row stride in BYTES = (K/256)*210; must be even (alignment of block_q6_K)
};

// Q6_K x f32 matrix-vector product:
//   dst[r] = sum_k W[r][k] * src1[k]   for r in [0, p.N),
// where W is stored row-major as Q6_K super-blocks (row stride p.nb01 bytes).
//
// Dispatch contract:
//   - Threadgroup shape (32, N_SG, 1): N_SG simdgroups of 32 lanes each. The
//     lane mapping below assumes 32-wide simdgroups.
//   - Simdgroup (tgpig.x, sgitg) owns up to N_R0 consecutive rows starting at
//     first_row = (tgpig.x * N_SG + sgitg) * N_R0, so one threadgroup covers
//     N_SG * N_R0 rows.
//     NOTE(review): the host presumably launches ceil(N / (N_SG * N_R0))
//     threadgroups along x; confirm against the caller.
//   - No threadgroup memory and no barriers. The only cross-lane operation is
//     the final simd_sum.
//
// Preconditions (not checked here):
//   - p.K is a multiple of QK_K; a trailing partial super-block is ignored.
//   - p.nb01 is the byte distance between consecutive weight rows and is a
//     multiple of alignof(block_q6_K) (2 bytes); normally (K/256)*210.
//
// Lane mapping (tiisg = 0..31):
//   ix  = tiisg % 2: lanes with the same ix walk super-blocks ix, ix+2, ix+4, ...
//   tid = tiisg / 2: the 16 lanes sharing an ix split one super-block.
//         ip = tid / 8 selects the 128-weight half.
//         il = tid % 8 selects a 4-column slice l0 = 4*il.
//         The lane covers columns 128*ip + l0 + {0..3} + {0, 32, 64, 96},
//         i.e. 16 weights.
kernel void gemv_f32a_q6kw_v2(
    device const block_q6_K * src0 [[buffer(0)]],   // [N, K/256] super-blocks
    device const float      * src1 [[buffer(1)]],   // [K] activations
    device       float      * dst  [[buffer(2)]],   // [N] output
    constant GemvQ6KParams  & p    [[buffer(3)]],
    uint3  tgpig [[threadgroup_position_in_grid]],
    ushort tiisg [[thread_index_in_simdgroup]],
    ushort sgitg [[simdgroup_index_in_threadgroup]])
{
    // Select the 2-bit high part of each of the four weights packed into one qh byte.
    constexpr uint8_t kmask1 = 0x03;
    constexpr uint8_t kmask2 = 0x0C;
    constexpr uint8_t kmask3 = 0x30;
    constexpr uint8_t kmask4 = 0xC0;

    const int nb = p.K / QK_K;                      // super-blocks per row
    const int r0 = tgpig.x;
    const int first_row = (r0 * N_SG + sgitg) * N_R0;
    // first_row is uniform across the simdgroup, so this early return cannot
    // leave part of a simdgroup behind at the simd_sum below.
    if (first_row >= p.N) return;

    // Rows this simdgroup really owns. When N is not a multiple of N_R0, the
    // last simdgroup has fewer than N_R0 rows. Bounding the row loop by `nr`
    // stops it from reading weight rows at or beyond N (past the end of src0).
    const int nr = min(N_R0, p.N - first_row);

    // First weight row of this simdgroup. Every row is addressed through the
    // same byte stride p.nb01, using 64-bit math so first_row * nb01 cannot
    // overflow a 32-bit int for large N and K.
    device const char * row_base =
        (device const char *) src0 + (ulong) first_row * (ulong) p.nb01;

    float sumf[N_R0] = { 0.f };
    float yl[16];   // this lane's 16 activations for the current super-block

    const short tid = tiisg / 2;          // 0..15
    const short ix  = tiisg % 2;          // 0 or 1
    const short ip  = tid / 8;            // 0 or 1: which 128-weight half
    const short il  = tid % 8;            // 0..7
    const short l0  = 4 * il;             // 0, 4, ..., 28
    const short is  = 8 * ip + l0 / 16;   // first 16-weight sub-block scale used

    const short y_offset   = 128 * ip + l0;
    const short q_offset_l =  64 * ip + l0;
    const short q_offset_h =  32 * ip + l0;

    for (int i = ix; i < nb; i += 2) {
        // All rows share the activations, so load them once per super-block.
        device const float * y = src1 + i * QK_K + y_offset;

        FOR_UNROLL (short l = 0; l < 4; ++l) {
            yl[4*l + 0] = y[l +  0];
            yl[4*l + 1] = y[l + 32];
            yl[4*l + 2] = y[l + 64];
            yl[4*l + 3] = y[l + 96];
        }

        for (short row = 0; row < nr; ++row) {
            device const block_q6_K * xb =
                (device const block_q6_K *) (row_base + (ulong) row * (ulong) p.nb01) + i;

            device const uchar  * q1 = xb->ql + q_offset_l;   // nibbles for cols +0 (low) / +64 (high)
            device const uchar  * q2 = q1 + 32;               // nibbles for cols +32 (low) / +96 (high)
            device const uchar  * qh = xb->qh + q_offset_h;   // high 2-bit parts, four per byte
            device const int8_t * sc = xb->scales + is;

            float4 sums = {0.f, 0.f, 0.f, 0.f};

            FOR_UNROLL (short l = 0; l < 4; ++l) {
                sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0x0F) | ((qh[l] & kmask1) << 4)) - 32);
                sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0x0F) | ((qh[l] & kmask2) << 2)) - 32);
                sums[2] += yl[4*l + 2] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
                sums[3] += yl[4*l + 3] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
            }

            // Column groups 32 apart fall in 16-weight sub-blocks two apart,
            // hence scales sc[0], sc[2], sc[4], sc[6].
            sumf[row] += xb->d * (
                sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]
            );
        }
    }

    // Reduce each row's partial sums across the simdgroup; lane 0 writes the result.
    for (int row = 0; row < nr; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst[first_row + row] = sum_all;
        }
    }
}