ferrum-kernels 0.7.1

Unified compute kernels (CUDA/Metal/CPU) and model runner for Ferrum inference
Documentation
// MoE router: per-row softmax + top-K + optional renormalize.
//
// Replaces the host-side `B::sync(ctx) + B::to_vec(router_logits) +
// crate::moe::router::route(...)` sequence used in the existing MoE
// path. Each per-layer call previously paid one full Metal pipeline
// drain (~1 ms on M1 Max) plus a host softmax/sort. Doing it on the
// GPU removes the sync entirely (the kernel writes ids/weights into
// device buffers that the next mul_mm_id pass reads directly).
//
// Algorithm: one threadgroup per token row.
//   1. Stable softmax (max-subtract → exp → sum-reduce → divide).
//   2. Repeated argmax to extract top-K — k ≤ 32 expected, so a
//      simple loop that masks each picked entry to -INFINITY suffices
//      (vs a partial sort, which would over-engineer for tiny K).
//   3. Optional renormalize: if `norm_topk_prob`, divide selected
//      weights by their sum so they total 1.0.
//
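// Threadgroup: 32 threads. Each thread covers `ceil(num_experts/32)`
// logits during the softmax and argmax reductions. Threadgroup memory
// holds the post-softmax probability vector for the row, followed by
// the top-K scratch (`top_k` weights, `top_k` indices, and one slot
// for the renormalisation sum) — `num_experts + 2 * top_k + 1` 4-byte
// elements in total.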

#include <metal_stdlib>
using namespace metal;

// Parameters for `moe_router_topk_softmax_f32`, bound at buffer(3).
struct RouterParams {
    // Number of logits per row; also the row stride used to index `logits`.
    int num_experts;
    // Number of entries selected per row; also the row stride used to
    // index `out_ids` / `out_weights`.
    int top_k;
    // Non-zero: the selected weights are multiplied by 1 / (their sum),
    // with the sum clamped from below before the division.
    int norm_topk_prob;  // 0 / 1
};

// Fused MoE router for one token row: stable softmax over the row's
// logits, top-K selection by repeated argmax (smallest index wins on
// ties), and optional renormalisation of the K selected weights.
//
// Launch contract implied by the indexing below:
//   * one threadgroup per row (`row = tgpig.x`);
//   * the threadgroup is a single 32-lane simdgroup — every loop strides
//     by 32 and every reduction is a simdgroup reduction;
//   * threadgroup memory at index 0 provides at least
//     `num_experts + 2 * top_k + 1` 4-byte elements.
//
// Outputs, per row: `out_ids[row, k]` / `out_weights[row, k]` for every
// k in [0, top_k). If fewer than `top_k` selectable entries exist
// (`top_k > num_experts`, or the probabilities are NaN), the surplus
// slots are written as id -1 with weight 0.
kernel void moe_router_topk_softmax_f32(
    device const float * logits        [[buffer(0)]],   // [batch, num_experts]
    device       int   * out_ids       [[buffer(1)]],   // [batch, top_k]
    device       float * out_weights   [[buffer(2)]],   // [batch, top_k]
    constant RouterParams & p          [[buffer(3)]],
    threadgroup float  * shmem         [[threadgroup(0)]],
    uint3  tgpig [[threadgroup_position_in_grid]],
    ushort tiisg [[thread_index_in_simdgroup]])
{
    const int SIMD_WIDTH = 32;
    const int lane  = tiisg;
    const int row   = tgpig.x;
    const int n_exp = p.num_experts;
    const int top_k = p.top_k;

    device const float * row_logits = logits + row * n_exp;

    // Threadgroup memory layout:
    //   probs[n_exp] | sel_weights[top_k] | sel_idxs[top_k] | renorm_slot[1]
    threadgroup float * probs        = shmem;
    threadgroup float * sel_weights  = (threadgroup float *)(probs + n_exp);
    threadgroup int   * sel_idxs     = (threadgroup int   *)(sel_weights + top_k);
    threadgroup float * renorm_slot  = (threadgroup float *)(sel_idxs + top_k);

    // Zero the running sum of selected weights before the barrier below,
    // so the later all-lane read is ordered even when top_k == 0.
    if (lane == 0) {
        renorm_slot[0] = 0.0f;
    }

    // Cooperative load of the logits into threadgroup memory, tracking
    // the per-lane maximum along the way.
    float thread_max = -INFINITY;
    for (int i = lane; i < n_exp; i += SIMD_WIDTH) {
        const float v = row_logits[i];
        probs[i] = v;
        thread_max = max(thread_max, v);
    }
    const float row_max = simd_max(thread_max);

    // exp(logit - max) with a per-lane partial sum. Each lane only
    // revisits the entries it wrote itself, so no barrier is needed yet.
    float thread_sum = 0.0f;
    for (int i = lane; i < n_exp; i += SIMD_WIDTH) {
        const float e = exp(probs[i] - row_max);
        probs[i] = e;
        thread_sum += e;
    }
    const float row_sum = simd_sum(thread_sum);
    const float inv_sum = 1.0f / row_sum;

    // Normalise to probabilities.
    for (int i = lane; i < n_exp; i += SIMD_WIDTH) {
        probs[i] *= inv_sum;
    }

    // From here on lanes read entries written by other lanes.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Repeated argmax: one cooperative scan per pick. The picked entry
    // is overwritten with -INFINITY so the next pass sees the runner-up.
    //
    // Tie-breaking ("smaller index wins"): within a lane the strict `>`
    // keeps the first (smallest) index; across lanes, every lane whose
    // local best equals the global best offers its index to a simd_min,
    // all other lanes offer INT_MAX.
    for (int k = 0; k < top_k; k++) {
        float thread_best = -INFINITY;
        int   thread_idx  = -1;
        for (int i = lane; i < n_exp; i += SIMD_WIDTH) {
            const float v = probs[i];
            if (v > thread_best) {
                thread_best = v;
                thread_idx = i;
            }
        }
        const float best = simd_max(thread_best);

        const int my_idx_for_min = (thread_best == best) ? thread_idx : 0x7fffffff;
        const int win_idx = simd_min(my_idx_for_min);

        // No lane found a candidate (every remaining entry is masked or
        // NaN): win_idx is -1. Indexing probs with it would write outside
        // the threadgroup buffer, and `best` would be -INFINITY, so emit
        // an empty slot instead.
        const bool have_pick = (win_idx >= 0) && (win_idx < n_exp);

        if (lane == 0) {
            if (have_pick) {
                sel_weights[k] = best;
                sel_idxs[k] = win_idx;
                renorm_slot[0] += best;
                // Mask the picked entry from future passes.
                probs[win_idx] = -INFINITY;
            } else {
                sel_weights[k] = 0.0f;
                sel_idxs[k] = -1;
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Optional renormalise so the K weights sum to 1.0. The sum is
    // clamped from below to avoid a division by zero.
    float scale = 1.0f;
    if (p.norm_topk_prob != 0) {
        scale = 1.0f / max(renorm_slot[0], 6.103515625e-5f);
    }

    // Lane-strided write-out so every slot is emitted even when
    // top_k exceeds the 32 lanes of the threadgroup.
    for (int k = lane; k < top_k; k += SIMD_WIDTH) {
        out_ids[row * top_k + k] = sel_idxs[k];
        out_weights[row * top_k + k] = sel_weights[k] * scale;
    }
}

// ── compute_ids_tpe: bucket selected pairs by expert ────────────────────
//
// Replaces the host-side `compute_ids_tpe` that ran inside
// `moe_forward_batched_prefill_impl`. After `moe_router_topk_softmax_f32`
// emits `selected_ids[batch, top_k]`, this kernel groups those `(token,
// slot)` pairs by their target expert, producing:
//
//   tpe[e]                   = number of pairs assigned to expert `e`
//   ids[e * row_stride + s]  = global pair index (= token*top_k + slot)
//                              for the s-th pair routed to expert `e`
//
// `row_stride` is the *worst-case* `batch * top_k` (the consumer GEMM
// kernel still uses this stride for ids indexing — only the GRID size
// is tightened by reducing `max(tpe[e])` and emitting indirect-dispatch
// args, leaving the kernel's `r1 >= tpe[e]` early-exit untouched).
//
// Algorithm: single 1-D threadgroup, `THREADS` threads.
//   Phase 1  zero tpe.
//   Phase 2  walk selected_ids, atomic_fetch_add slot claim, write ids.
//   Phase 3  max-reduce tpe[e] across threads → max_per_expert.
//   Phase 4  thread 0 writes (grid_x, grid_y, grid_z) to two indirect
//            args buffers — one for gate/up (M=expert_inter), one for
//            down (M=hidden). grid_x is shared because all three GEMMs
//            walk the same per-expert pair list.

// Compile-time configuration for `moe_compute_ids_tpe_f32`.
//
// THREADS is the stride of every per-thread loop in that kernel and NSG
// is the size of its per-simdgroup scratch array, so the kernel is
// written for a launch of exactly THREADS threads (NSG simdgroups of 32).
// NR1 and NR0 are the divisors used to turn `max(tpe[e])` and the GEMM
// `M` dimensions into the X and Y threadgroup counts of the indirect
// dispatch arguments.
constant int FERRUM_IDS_TPE_THREADS = 256;
constant int FERRUM_IDS_TPE_NSG     = 8;        // = THREADS / 32
constant int FERRUM_GEMM_NR0        = 64;       // matches gemm_q4kw_moe_id_f32
constant int FERRUM_GEMM_NR1        = 32;       // matches gemm_q4kw_moe_id_f32

// Parameters for `moe_compute_ids_tpe_f32`, bound at buffer(5).
struct ComputeIdsTpeParams {
    // Number of experts: length of `tpe`, upper bound for valid entries
    // of `selected_ids`, and the Z dimension of both indirect grids.
    int num_experts;
    int row_stride;     // = batch * top_k (worst-case ids row stride)
    int total_pairs;    // = batch * top_k
    // The two M values below are only used to derive the Y dimension of
    // the respective indirect grid (ceil-divided by FERRUM_GEMM_NR0).
    int m_gate_up;      // M for gate / up GEMM (= expert_intermediate_size)
    int m_down;         // M for down GEMM (= hidden_size)
};

// Buckets the routed (token, slot) pairs by expert and emits the
// indirect-dispatch grids for the follow-up GEMMs.
//
// Written for a single threadgroup of FERRUM_IDS_TPE_THREADS threads:
// every loop strides by that constant and the per-simdgroup scratch has
// FERRUM_IDS_TPE_NSG entries.
//
//   selected_ids  in   expert id per pair; ids outside [0, num_experts)
//                      are skipped
//   tpe           out  per-expert pair count (zeroed here first)
//   ids           out  ids[e * row_stride + s] = index of the s-th pair
//                      claimed for expert e (claim order is decided by
//                      the atomic increment)
//   gate_up_args  out  (grid_x, ceil(m_gate_up / NR0), num_experts)
//   down_args     out  (grid_x, ceil(m_down    / NR0), num_experts)
kernel void moe_compute_ids_tpe_f32(
    device const int      * selected_ids [[buffer(0)]],   // [total_pairs] i32
    device atomic_int     * tpe          [[buffer(1)]],   // [num_experts] i32
    device       int      * ids          [[buffer(2)]],   // [num_experts * row_stride] i32
    device       uint     * gate_up_args [[buffer(3)]],   // [3] u32 indirect args
    device       uint     * down_args    [[buffer(4)]],   // [3] u32 indirect args
    constant ComputeIdsTpeParams & p     [[buffer(5)]],
    uint tid [[thread_position_in_threadgroup]],
    ushort tiisg [[thread_index_in_simdgroup]],
    ushort sgitg [[simdgroup_index_in_threadgroup]])
{
    const int first     = int(tid);
    const int n_experts = p.num_experts;
    const int n_pairs   = p.total_pairs;
    const int stride    = p.row_stride;

    // Phase 1 — clear the per-expert counters.
    for (int expert = first; expert < n_experts; expert += FERRUM_IDS_TPE_THREADS) {
        atomic_store_explicit(tpe + expert, 0, memory_order_relaxed);
    }
    threadgroup_barrier(mem_flags::mem_device);

    // Phase 2 — each thread walks a strided slice of the pairs, claims
    // the next free slot of the target expert and records the pair there.
    for (int pair = first; pair < n_pairs; pair += FERRUM_IDS_TPE_THREADS) {
        const int expert = selected_ids[pair];
        if (expert < 0 || expert >= n_experts) {
            continue;   // not a valid expert id
        }
        const int slot = atomic_fetch_add_explicit(tpe + expert, 1, memory_order_relaxed);
        ids[expert * stride + slot] = pair;
    }
    threadgroup_barrier(mem_flags::mem_device);

    // Phase 3 — largest per-expert count: per thread, then per simdgroup.
    int thread_peak = 0;
    for (int expert = first; expert < n_experts; expert += FERRUM_IDS_TPE_THREADS) {
        thread_peak = max(thread_peak, atomic_load_explicit(tpe + expert, memory_order_relaxed));
    }
    const int simd_peak = simd_max(thread_peak);

    threadgroup int simd_peaks[FERRUM_IDS_TPE_NSG];
    if (tiisg == 0) {
        simd_peaks[sgitg] = simd_peak;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Phase 4 — thread 0 alone finishes the reduction and publishes the
    // grids. No barrier follows, so the other threads can leave here.
    if (tid != 0) {
        return;
    }

    int peak = 0;
    for (int s = 0; s < FERRUM_IDS_TPE_NSG; ++s) {
        peak = max(peak, simd_peaks[s]);
    }

    // X covers the busiest expert's pair list in tiles of NR1. It is held
    // at a minimum of 1 so that a run in which no pair was counted still
    // yields a non-empty grid.
    const uint tiles_x = uint((peak + FERRUM_GEMM_NR1 - 1) / FERRUM_GEMM_NR1);
    const uint grid_x  = max(tiles_x, 1u);
    const uint grid_z  = uint(n_experts);

    gate_up_args[0] = grid_x;
    gate_up_args[1] = uint((p.m_gate_up + FERRUM_GEMM_NR0 - 1) / FERRUM_GEMM_NR0);
    gate_up_args[2] = grid_z;

    down_args[0] = grid_x;
    down_args[1] = uint((p.m_down + FERRUM_GEMM_NR0 - 1) / FERRUM_GEMM_NR0);
    down_args[2] = grid_z;
}