ferrum-kernels 0.7.7

Unified compute kernels (CUDA/Metal/CPU) and model runner for Ferrum inference
Documentation
#include <metal_stdlib>
#include <metal_simdgroup_matrix>
using namespace metal;

// ── High-performance f32 GEMM: C[M,N] = A[M,K] @ B[N,K]^T ─────────────
// 64x32 output tiles per threadgroup, 4 simdgroups (128 threads).
// Each simdgroup computes 16 rows x 32 cols via 2x4 = 8 accumulators of 8x8.
// The K dimension is processed in tiles of 32.
// Threadgroup memory (bound at threadgroup(0)): sa[64][32] for the A tile
// followed by sb[32][32] for the B^T tile — 3072 floats (12288 bytes) total.

// Problem dimensions for gemm_f32_v2, read from constant buffer(3).
// Shapes: C is M x N, A is M x K, B is N x K (row-major f32).
// NOTE(review): presumably mirrored by a host-side struct — keep the field
// order (M, N, K) and the three-int layout in sync with it.
struct GemmParams {
    int M, N, K;  // rows of C/A; columns of C (rows of B); shared inner dim
};

// gemm_f32_v2 — C[M,N] = A[M,K] @ B[N,K]^T for row-major f32 matrices.
//
// Launch contract (implied by the indexing in this kernel):
//   * Threadgroup: exactly 128 threads = 4 simdgroups. The cooperative tile
//     loads stride by 128, and each simdgroup owns 16 rows of the 64-row tile.
//   * Grid: tgpig.x selects a 32-column tile of C (N dimension), tgpig.y a
//     64-row tile of C (M dimension); i.e. ceil(N/32) x ceil(M/64) groups.
//   * threadgroup(0): at least 64*32 + 32*32 = 3072 floats (12288 bytes).
// Ragged M, N and K are handled by zero-filling out-of-range tile elements and
// by a bounds-checked, staged store for partial 8x8 output blocks.
// NOTE(review): the lane index is derived as tiitg % 32, which assumes a
// simdgroup width of 32 — confirm for the targeted GPUs.
kernel void gemm_f32_v2(
    device const float* A        [[buffer(0)]],
    device const float* B        [[buffer(1)]],
    device       float* C        [[buffer(2)]],
    constant GemmParams& p       [[buffer(3)]],
    threadgroup float* shmem     [[threadgroup(0)]],
    uint3  tgpig [[threadgroup_position_in_grid]],
    ushort tiitg [[thread_index_in_threadgroup]],
    ushort sgitg [[simdgroup_index_in_threadgroup]])
{
    constexpr short NR0      = 64;   // M tile
    constexpr short NR1      = 32;   // N tile
    constexpr short NK       = 32;   // K tile
    constexpr short NTHREADS = 128;  // threads per threadgroup (4 simdgroups)

    // sa: A tile [NR0, NK], sb: B^T tile [NK, NR1] (B stored transposed)
    threadgroup float* sa = shmem;              // 64 * 32 = 2048 floats
    threadgroup float* sb = shmem + NR0 * NK;   // 32 * 32 = 1024 floats

    const int r0 = tgpig.y * NR0;  // first row of C (and A) for this group
    const int r1 = tgpig.x * NR1;  // first column of C (row of B) for this group

    // Per-tile base pointers computed in 64-bit arithmetic: r0 * K (and
    // r1 * K) can exceed INT_MAX for large operands. Offsets inside a tile
    // (< 64 rows) stay in 32-bit ints.
    device const float* A_tile = A + (ulong)r0 * (ulong)p.K;
    device const float* B_tile = B + (ulong)r1 * (ulong)p.K;

    // 4 simdgroups, each handles 16 rows of the 64-row tile
    const short sg_row = sgitg * 16;

    // 8 accumulators: 2 row-blocks x 4 col-blocks of 8x8
    simdgroup_float8x8 acc[8];
    for (short i = 0; i < 8; i++) {
        acc[i] = make_filled_simdgroup_matrix<float, 8>(0.0f);
    }

    for (int kk = 0; kk < p.K; kk += NK) {
        // ── Load A tile [NR0, NK]: rows [r0, r0+64), cols [kk, kk+32) ──
        // Out-of-range elements are zero so they do not perturb the product.
        for (short i = tiitg; i < NR0 * NK; i += NTHREADS) {
            const short row = i / NK;
            const short col = i % NK;
            sa[row * NK + col] = (r0 + row < p.M && kk + col < p.K)
                ? A_tile[row * p.K + kk + col] : 0.0f;
        }

        // ── Load B^T tile [NK, NR1] ──
        // B[N, K] row-major; sb[k * NR1 + n] = B[r1 + n, kk + k].
        // NOTE(review): n_idx varies fastest across threads, so neighbouring
        // threads read B addresses p.K floats apart (non-contiguous). Swapping
        // the index order would make these reads contiguous but makes the
        // threadgroup writes strided — profile before changing.
        for (short i = tiitg; i < NK * NR1; i += NTHREADS) {
            const short k_idx = i / NR1;  // row in B^T = K dimension
            const short n_idx = i % NR1;  // col in B^T = N dimension
            sb[k_idx * NR1 + n_idx] = (r1 + n_idx < p.N && kk + k_idx < p.K)
                ? B_tile[n_idx * p.K + kk + k_idx] : 0.0f;
        }

        // Every thread reaches this (p.K is uniform): tiles are fully written.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Multiply: this simdgroup's 16x32 sub-tile ──
        for (short k = 0; k < NK; k += 8) {
            // Two 8x8 A blocks: rows [sg_row, sg_row+8) and [sg_row+8, sg_row+16)
            simdgroup_float8x8 ma0, ma1;
            simdgroup_load(ma0, sa + (sg_row + 0) * NK + k, NK);  // stride = NK
            simdgroup_load(ma1, sa + (sg_row + 8) * NK + k, NK);

            // Four 8x8 B^T blocks: cols [0,8), [8,16), [16,24), [24,32)
            simdgroup_float8x8 mb0, mb1, mb2, mb3;
            simdgroup_load(mb0, sb + k * NR1 + 0,  NR1);  // stride = NR1
            simdgroup_load(mb1, sb + k * NR1 + 8,  NR1);
            simdgroup_load(mb2, sb + k * NR1 + 16, NR1);
            simdgroup_load(mb3, sb + k * NR1 + 24, NR1);

            // acc[rb*4 + cb] += A[rows rb] * B^T[cols cb]
            simdgroup_multiply_accumulate(acc[0], ma0, mb0, acc[0]);
            simdgroup_multiply_accumulate(acc[1], ma0, mb1, acc[1]);
            simdgroup_multiply_accumulate(acc[2], ma0, mb2, acc[2]);
            simdgroup_multiply_accumulate(acc[3], ma0, mb3, acc[3]);
            simdgroup_multiply_accumulate(acc[4], ma1, mb0, acc[4]);
            simdgroup_multiply_accumulate(acc[5], ma1, mb1, acc[5]);
            simdgroup_multiply_accumulate(acc[6], ma1, mb2, acc[6]);
            simdgroup_multiply_accumulate(acc[7], ma1, mb3, acc[7]);
        }

        // Nobody may overwrite sa/sb for the next K tile while others still read.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Store results to device memory ──
    // Each simdgroup writes its 2x4 = 8 accumulator blocks (8x8 each).
    // gr/gc depend only on sgitg, rb, cb and the group position, so every
    // branch below is uniform within a simdgroup but may differ between
    // simdgroups.
    const ushort lane = tiitg % 32;
    for (short rb = 0; rb < 2; rb++) {
        for (short cb = 0; cb < 4; cb++) {
            const int gr = r0 + sg_row + rb * 8;
            const int gc = r1 + cb * 8;
            if (gr < p.M && gc < p.N) {
                // 64-bit offset: gr * N can exceed INT_MAX for large outputs.
                device float* dst = C + (ulong)gr * (ulong)p.N + (ulong)gc;
                if (gr + 8 <= p.M && gc + 8 <= p.N) {
                    // Full 8x8 block — direct store, row stride = N.
                    simdgroup_store(acc[rb * 4 + cb], dst, (ulong)p.N);
                } else {
                    // Partial block — stage through this simdgroup's private
                    // 64-float slice of sa (the K loop's final barrier
                    // guarantees sa is no longer being read as an A tile).
                    threadgroup float* stage = sa + sgitg * 64;
                    simdgroup_store(acc[rb * 4 + cb], stage, 8);
                    // Simdgroup-scoped barrier: other simdgroups may be on
                    // the direct-store path and never get here, so a
                    // threadgroup_barrier in this branch would be divergent.
                    simdgroup_barrier(mem_flags::mem_threadgroup);
                    for (short i = lane; i < 64; i += 32) {
                        const short r = i / 8, c = i % 8;
                        if (gr + r < p.M && gc + c < p.N) {
                            dst[r * p.N + c] = stage[i];
                        }
                    }
                    // All lanes must finish reading before `stage` is reused.
                    simdgroup_barrier(mem_flags::mem_threadgroup);
                }
            }
        }
    }
}