// ferrum-kernels 0.7.7
//
// Unified compute kernels (CUDA/Metal/CPU) and model runner for Ferrum inference.
// Q6_K MoE indirect-dispatch GEMV — adapted from q6_k_gemv.metal
// (which is itself a port of llama.cpp's `kernel_mul_mv_q6_K_f32_impl`).
//
// One Metal dispatch handles ALL `n_selected` (token, expert) pairs for
// decode m=1, replacing the per-expert down_proj gemv loop. See the
// matching q4_k_moe_id_gemv.metal for the broader rationale.
//
// Inputs:
//   src0 : [num_experts, N, K/256] Q6_K block bytes, contiguous, with
//          stride `nb02` (= N * K/256 * 210 bytes) between experts.
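//   src1 : activations; [K] when `src1_stride == 0` (one vector shared by
//          every slot) or [n_selected, K] when `src1_stride == K`.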
//   ids  : [n_selected] selected expert IDs (i32).
//   dst  : [n_selected, N] output rows.
//
// Grid: (ceil(N/4), 1, n_selected), where 4 = N_SG * N_R0 output rows per
// threadgroup. Threadgroup: (32, 2, 1), i.e. N_SG = 2 simdgroups of 32 lanes.

#include <metal_stdlib>
using namespace metal;

// Weights per Q6_K super-block.
#define QK_K 256
// Output rows accumulated by each simdgroup.
#define N_R0 2
// Simdgroups per threadgroup; must agree with the threadgroup y-extent (2).
#define N_SG 2
// `for` loop carrying a full-unroll hint for the (clang-based) Metal compiler.
#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x)

// One Q6_K super-block: QK_K weights, each stored as a 6-bit quant split
// into 4 low bits (ql) and 2 high bits (qh), with one signed 8-bit scale per
// 16 weights and a half-precision super-block scale.
// The layout must match llama.cpp/ggml's block_q6_K byte-for-byte (210 bytes
// for QK_K == 256): the kernel walks src0 with raw byte strides (nb01/nb02).
struct block_q6_K {
    uchar  ql[QK_K / 2];      // low 4 bits of each quant, two quants per byte
    uchar  qh[QK_K / 4];      // high 2 bits of each quant, four quants per byte
    int8_t scales[QK_K / 16]; // per-16-weight signed scales
    half   d;                 // super-block scale
};

// Guard against padding or field-type changes silently breaking the byte
// strides used by the kernel.
static_assert(sizeof(block_q6_K) == sizeof(half) + QK_K / 16 + 3 * QK_K / 4,
              "block_q6_K must be unpadded (210 bytes for QK_K == 256)");

// Per-dispatch parameters, bound at buffer(4) in the `constant` address space.
// NOTE(review): field order and types must match the host-side struct that
// fills this buffer (not visible here) — keep both in sync.
struct GemvQ6KMoeParams {
    int N;            // out_features per expert
    int K;            // in_features (multiple of 256)
    int nb01;         // src0 row stride per expert in BYTES = (K/256) * 210
    int nb02;         // src0 expert stride in BYTES = N * (K/256) * 210
    int n_selected;   // number of (token, expert) slots; grid z-extent
    int src1_stride;  // 0 for broadcast (gate/up), K for per-slot (down)
};

// Q6_K x f32 GEMV for MoE decode. For every slot s < n_selected it computes
//   dst[s, r] = sum_k dequant(W_{ids[s]})[r, k] * y_s[k],   r in [0, N),
// where y_s = src1 + s * src1_stride.
//
// Work split: threadgroup z = slot; threadgroup x = a block of N_SG * N_R0
// rows; each simdgroup accumulates N_R0 rows. Lane partials are combined
// with simd_sum and lane 0 writes the result.
//
// Preconditions (not checked here): K is a multiple of QK_K (any tail is
// ignored), nb01/nb02 describe src0's byte layout, and every ids[s] is a
// valid expert index.
kernel void gemv_q6kw_moe_id_f32(
    device const block_q6_K * src0  [[buffer(0)]],
    device const float      * src1  [[buffer(1)]],
    device const int        * ids   [[buffer(2)]],   // [n_selected]
    device       float      * dst   [[buffer(3)]],   // [n_selected, N]
    constant GemvQ6KMoeParams & p   [[buffer(4)]],
    uint3  tgpig [[threadgroup_position_in_grid]],
    ushort tiisg [[thread_index_in_simdgroup]],
    ushort sgitg [[simdgroup_index_in_threadgroup]])
{
    // Uniform across the threadgroup, so this early exit never leaves a
    // partially populated simdgroup inside simd_sum below.
    const int slot = tgpig.z;
    if (slot >= p.n_selected) return;

    const int expert_id = ids[slot];

    // Select the 2-bit high-part fields of a qh byte (bits 0-1, 2-3, 4-5, 6-7).
    constexpr uint8_t kmask1 = 0x03;
    constexpr uint8_t kmask2 = 0x0C;
    constexpr uint8_t kmask3 = 0x30;
    constexpr uint8_t kmask4 = 0xC0;

    const int nb = p.K / QK_K;                       // super-blocks per row
    const int r0 = tgpig.x;
    const int first_row = (r0 * N_SG + sgitg) * N_R0;
    // Uniform per simdgroup (depends only on tgpig.x and sgitg).
    if (first_row >= p.N) return;

    // Rows actually owned by this simdgroup. Fewer than N_R0 only for the
    // last simdgroup when N % N_R0 != 0; visiting the missing rows would read
    // past this expert's slab (and past the end of src0 for the last expert).
    const int nrows = min(N_R0, p.N - first_row);

    // src0 base: pick the expert slab, then the simdgroup's first row.
    // The byte offset is computed in 64 bits because expert_id * nb02
    // overflows a 32-bit int once the stacked expert weights exceed 2 GiB.
    const ulong x_offset = (ulong)expert_id * (ulong)p.nb02
                         + (ulong)first_row * (ulong)p.nb01;
    device const block_q6_K * x = (device const block_q6_K *)(
        (device const char *)src0 + x_offset
    );
    // Per-slot activation base: src1_stride == 0 broadcasts one vector to
    // every slot; src1_stride == K gives each slot its own row.
    device const float * yy = src1 + slot * p.src1_stride;

    float sumf[N_R0] = { 0.f };
    float yl[16];

    // Lane layout: tiisg = 2 * tid + ix. ix picks even/odd super-blocks;
    // tid (0..15) picks 16 of the 256 weights in each block: the 128-weight
    // half ip, and the 4-wide column group il within it.
    const short tid = tiisg / 2;
    const short ix  = tiisg % 2;
    const short ip  = tid / 8;
    const short il  = tid % 8;
    const short l0  = 4 * il;
    const short is  = 8 * ip + l0 / 16;

    const short y_offset   = 128 * ip + l0;
    const short q_offset_l =  64 * ip + l0;
    const short q_offset_h =  32 * ip + l0;

    for (int i = ix; i < nb; i += 2) {
        device const uchar  * q1 = x[i].ql + q_offset_l;
        device const uchar  * q2 = q1 + 32;
        device const uchar  * qh = x[i].qh + q_offset_h;
        device const int8_t * sc = x[i].scales + is;
        device const half   * dh = &x[i].d;

        device const float * y = yy + i * QK_K + y_offset;

        // Activations are identical for every row, so load them once per block.
        FOR_UNROLL (short l = 0; l < 4; ++l) {
            yl[4*l + 0] = y[l +  0];
            yl[4*l + 1] = y[l + 32];
            yl[4*l + 2] = y[l + 64];
            yl[4*l + 3] = y[l + 96];
        }

        // Fully unrolled so sumf[row] stays in registers; the break is
        // uniform per simdgroup (see nrows above).
        FOR_UNROLL (short row = 0; row < N_R0; ++row) {
            if (row >= nrows) break;

            const float dall = dh[0];
            float4 sums = {0.f, 0.f, 0.f, 0.f};

            // Rebuild each 6-bit quant from 4 low bits (ql) and 2 high bits
            // (qh), then re-centre it from [0, 63] to [-32, 31].
            FOR_UNROLL (short l = 0; l < 4; ++l) {
                sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0x0F) | ((qh[l] & kmask1) << 4)) - 32);
                sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0x0F) | ((qh[l] & kmask2) << 2)) - 32);
                sums[2] += yl[4*l + 2] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
                sums[3] += yl[4*l + 3] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
            }

            sumf[row] += dall * (
                sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]
            );

            // Advance to the same block of the next row (nb01 bytes further);
            // dh points at halfs, hence the /2.
            q1 += p.nb01;
            q2 += p.nb01;
            qh += p.nb01;
            sc += p.nb01;
            dh += p.nb01 / 2;
        }
    }

    device float * dst_slot = dst + slot * p.N;
    for (int row = 0; row < nrows; ++row) {
        const float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_slot[first_row + row] = sum_all;
        }
    }
}