// ferrum-kernels 0.7.7
//
// Unified compute kernels (CUDA/Metal/CPU) and model runner for Ferrum inference
// Documentation
#include <metal_stdlib>
#include <metal_simdgroup_matrix>
using namespace metal;

// Mixed-precision GEMM: C[M,N] f32 = A[M,K] f32 @ B[N,K]^T f16
//
// Same 64x32 tile structure as gemm_f32_v2, but B is loaded as half and
// upcast to float when staged into threadgroup memory. A stays float and the
// simdgroup matrix multiply runs in f32 (higher precision than an all-f16
// MAC, which would also require a separate half simdgroup path).
//
// Use when B holds big weight matrices stored as fp16 — halves the weight
// memory footprint while keeping accumulation precision.

// Problem dimensions shared by the f32-activation / f16-weight kernels in this
// file. Bound as a constant buffer at buffer(3).
struct GemmF16WParams {
    // Rows of A and C (bounds the row index in the GEMM; not read by the GEMV).
    int M;
    // Rows of B and columns of C; also the row stride of C in the GEMM.
    int N;
    // Reduction length; also the row stride (in elements) of both A and B.
    int K;
};

// Tiled GEMM: C[M,N] (f32) = A[M,K] (f32) @ B[N,K]^T (f16, widened to f32).
//
// Dispatch contract, as implied by the indexing in the body:
//   - tgpig.x selects a 32-column tile of C, tgpig.y selects a 64-row tile.
//   - The staging loops stride by 128 and each simdgroup owns 16 rows of the
//     64-row tile, so the kernel expects 128 threads per threadgroup
//     (4 simdgroups; lanes are derived as tiitg % 32).
//   - shmem must provide (64*32 + 32*32) floats = 12 KiB at threadgroup(0).
//   - Element offsets are computed in 32-bit int.
//
// Parameters:
//   A     row-major f32 input, row stride p.K.
//   B     row-major f16 weights, row stride p.K (used transposed).
//   C     row-major f32 output, row stride p.N.
//   p     problem dimensions.
//   shmem threadgroup scratch holding the staged A tile followed by the B tile.
kernel void gemm_f32a_f16w_v2(
    device const float* A        [[buffer(0)]],
    device const half*  B        [[buffer(1)]],
    device       float* C        [[buffer(2)]],
    constant GemmF16WParams& p   [[buffer(3)]],
    threadgroup float* shmem     [[threadgroup(0)]],
    uint3  tgpig [[threadgroup_position_in_grid]],
    ushort tiitg [[thread_index_in_threadgroup]],
    ushort sgitg [[simdgroup_index_in_threadgroup]])
{
    constexpr short NR0 = 64;  // M tile
    constexpr short NR1 = 32;  // N tile
    constexpr short NK  = 32;  // K tile

    // sa: staged A tile [NR0, NK]; sb: staged B^T tile [NK, NR1].
    threadgroup float* sa = shmem;
    threadgroup float* sb = shmem + NR0 * NK;

    const int r0 = tgpig.y * NR0;      // first row of C covered by this threadgroup
    const int r1 = tgpig.x * NR1;      // first column of C covered by this threadgroup
    const short sg_row = sgitg * 16;   // first tile row owned by this simdgroup

    // 2 row blocks x 4 column blocks of 8x8 accumulators = a 16x32 strip per simdgroup.
    simdgroup_float8x8 acc[8];
    for (short i = 0; i < 8; i++) {
        acc[i] = make_filled_simdgroup_matrix<float, 8>(0.0f);
    }

    for (int kk = 0; kk < p.K; kk += NK) {
        // Load A tile [NR0, NK] from f32; out-of-range elements are zero-filled
        // so they contribute nothing to the product.
        for (short i = tiitg; i < NR0 * NK; i += 128) {
            short row = i / NK;
            short col = i % NK;
            sa[row * NK + col] = (r0 + row < p.M && kk + col < p.K)
                ? A[(r0 + row) * p.K + kk + col] : 0.0f;
        }

        // Load B^T tile [NK, NR1] from f16 and upcast; zero-fill out of range.
        for (short i = tiitg; i < NK * NR1; i += 128) {
            short k_idx = i / NR1;
            short n_idx = i % NR1;
            sb[k_idx * NR1 + n_idx] = (r1 + n_idx < p.N && kk + k_idx < p.K)
                ? (float)B[(r1 + n_idx) * p.K + kk + k_idx] : 0.0f;
        }

        // Every thread reaches this barrier: staged tiles must be complete
        // before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (short k = 0; k < NK; k += 8) {
            simdgroup_float8x8 ma0, ma1;
            simdgroup_load(ma0, sa + (sg_row + 0) * NK + k, NK);
            simdgroup_load(ma1, sa + (sg_row + 8) * NK + k, NK);

            simdgroup_float8x8 mb0, mb1, mb2, mb3;
            simdgroup_load(mb0, sb + k * NR1 + 0,  NR1);
            simdgroup_load(mb1, sb + k * NR1 + 8,  NR1);
            simdgroup_load(mb2, sb + k * NR1 + 16, NR1);
            simdgroup_load(mb3, sb + k * NR1 + 24, NR1);

            simdgroup_multiply_accumulate(acc[0], ma0, mb0, acc[0]);
            simdgroup_multiply_accumulate(acc[1], ma0, mb1, acc[1]);
            simdgroup_multiply_accumulate(acc[2], ma0, mb2, acc[2]);
            simdgroup_multiply_accumulate(acc[3], ma0, mb3, acc[3]);
            simdgroup_multiply_accumulate(acc[4], ma1, mb0, acc[4]);
            simdgroup_multiply_accumulate(acc[5], ma1, mb1, acc[5]);
            simdgroup_multiply_accumulate(acc[6], ma1, mb2, acc[6]);
            simdgroup_multiply_accumulate(acc[7], ma1, mb3, acc[7]);
        }

        // Keep the next iteration's staging writes from racing with these reads.
        // After the final iteration this also guarantees nobody is still reading
        // sa when it is reused as store scratch below.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const ushort lane = tiitg % 32;

    for (short rb = 0; rb < 2; rb++) {
        for (short cb = 0; cb < 4; cb++) {
            int gr = r0 + sg_row + rb * 8;
            int gc = r1 + cb * 8;
            if (gr < p.M && gc < p.N) {
                if (gr + 8 <= p.M && gc + 8 <= p.N) {
                    // Whole 8x8 block is in range: store straight to C.
                    simdgroup_store(acc[rb * 4 + cb], C + gr * p.N + gc, (ulong)p.N);
                } else {
                    // Partial block: spill to a scratch area private to this
                    // simdgroup, then copy only the in-range elements.
                    //
                    // This branch depends on sgitg (through gr), so different
                    // simdgroups reach it a different number of times. A
                    // threadgroup_barrier here would be executed by only part
                    // of the threadgroup, which is undefined. The scratch is
                    // per-simdgroup and the branch is uniform within a
                    // simdgroup, so a simdgroup-scoped barrier is what is needed.
                    threadgroup float* stage = sa + sgitg * 64;
                    simdgroup_store(acc[rb * 4 + cb], stage, 8);
                    simdgroup_barrier(mem_flags::mem_threadgroup);
                    for (short i = lane; i < 64; i += 32) {
                        short r = i / 8, c = i % 8;
                        if (gr + r < p.M && gc + c < p.N) {
                            C[(gr + r) * p.N + gc + c] = stage[i];
                        }
                    }
                    // All lanes must finish reading before the scratch is
                    // overwritten by this simdgroup's next partial block.
                    simdgroup_barrier(mem_flags::mem_threadgroup);
                }
            }
        }
    }
}

// GEMV with f16 weights: C[n] = sum over k of A[k] * B[n, k], for n in [0, N).
// A is f32; each B element is widened to f32 before the multiply, and the sum
// is accumulated in f32.
// The threadgroup at tgpig.x produces output element n = tgpig.x. Threads walk
// K with a stride of 32, partial sums are combined with simd_sum, and the
// thread with index 0 writes the result.
// Used when M=1 (decode) with f16-weight B.
// NOTE(review): the stride of 32 presumably matches a single-simdgroup
// (32-thread) threadgroup — confirm against the host dispatch.
kernel void gemv_f32a_f16w(
    device const float* A        [[buffer(0)]],
    device const half*  B        [[buffer(1)]],
    device       float* C        [[buffer(2)]],
    constant GemmF16WParams& p   [[buffer(3)]],
    uint3  tgpig [[threadgroup_position_in_grid]],
    ushort tiitg [[thread_index_in_threadgroup]])
{
    const int out_idx = tgpig.x;
    // Threadgroups past the end of the output have nothing to do.
    if (out_idx >= p.N) return;

    const int depth = p.K;
    // Row of B that pairs with this output element.
    device const half* weights = B + out_idx * depth;

    // Per-thread partial dot product over k = tiitg, tiitg + 32, ...
    float partial = 0.0f;
    int k = tiitg;
    while (k < depth) {
        partial += A[k] * (float)weights[k];
        k += 32;
    }

    // Combine the partial sums across the simdgroup.
    const float total = simd_sum(partial);
    if (tiitg == 0) {
        C[out_idx] = total;
    }
}