ferrum-kernels 0.7.7

Unified compute kernels (CUDA/Metal/CPU) and model runner for Ferrum inference
Documentation
#include <metal_stdlib>
using namespace metal;

// ── Softmax Last Dim ────────────────────────────────────────────────────
// In-place softmax over the last dimension of a 2D [rows, cols] tensor.
// Each threadgroup handles one row. Uses simd_sum for parallel reduction.
//
// Matches candle's softmax_last_dim CPU implementation:
//   1. Find max (simd reduction)
//   2. Subtract max and exp
//   3. Sum exp values (simd reduction)
//   4. Divide by sum

struct SoftmaxParams {
    int rows;
    int cols;
};

kernel void softmax_last_dim_f32(
    device float* data          [[buffer(0)]],   // [rows, cols] in-place
    constant SoftmaxParams& p   [[buffer(1)]],
    uint  tgpig [[threadgroup_position_in_grid]],  // row index
    uint  tiisg [[thread_index_in_simdgroup]],
    uint  sgitg [[simdgroup_index_in_threadgroup]],
    uint  ntg   [[threads_per_threadgroup]],
    threadgroup float* shmem [[threadgroup(0)]])
{
    const int row = tgpig;
    if (row >= p.rows) return;

    device float* x = data + row * p.cols;
    const int cols = p.cols;

    // 1. Find max (parallel reduction)
    float local_max = -INFINITY;
    for (int i = tiisg; i < cols; i += 32) {
        local_max = max(local_max, x[i]);
    }
    local_max = simd_max(local_max);

    // Cross-simdgroup reduction if needed
    if (ntg > 32) {
        if (sgitg == 0) shmem[tiisg] = -INFINITY;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tiisg == 0) shmem[sgitg] = local_max;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        local_max = shmem[tiisg < (ntg / 32) ? tiisg : 0];
        local_max = simd_max(local_max);
    }

    // 2. Exp and sum
    float local_sum = 0.0f;
    for (int i = tiisg; i < cols; i += 32) {
        float e = exp(x[i] - local_max);
        x[i] = e;
        local_sum += e;
    }
    local_sum = simd_sum(local_sum);

    if (ntg > 32) {
        if (sgitg == 0) shmem[tiisg] = 0.0f;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tiisg == 0) shmem[sgitg] = local_sum;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        local_sum = shmem[tiisg < (ntg / 32) ? tiisg : 0];
        local_sum = simd_sum(local_sum);
    }

    // 3. Normalize
    float inv_sum = 1.0f / local_sum;
    for (int i = tiisg; i < cols; i += 32) {
        x[i] *= inv_sum;
    }
}

// Non-in-place version: reads from input, writes to output
kernel void softmax_last_dim_f32_out(
    device const float* input   [[buffer(0)]],
    device       float* output  [[buffer(1)]],
    constant SoftmaxParams& p   [[buffer(2)]],
    uint  tgpig [[threadgroup_position_in_grid]],
    uint  tiisg [[thread_index_in_simdgroup]],
    uint  sgitg [[simdgroup_index_in_threadgroup]],
    uint  ntg   [[threads_per_threadgroup]],
    threadgroup float* shmem [[threadgroup(0)]])
{
    const int row = tgpig;
    if (row >= p.rows) return;

    device const float* x = input + row * p.cols;
    device float* y = output + row * p.cols;
    const int cols = p.cols;

    // 1. Find max
    float local_max = -INFINITY;
    for (int i = tiisg; i < cols; i += 32) {
        local_max = max(local_max, x[i]);
    }
    local_max = simd_max(local_max);

    if (ntg > 32) {
        if (sgitg == 0) shmem[tiisg] = -INFINITY;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tiisg == 0) shmem[sgitg] = local_max;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        local_max = shmem[tiisg < (ntg / 32) ? tiisg : 0];
        local_max = simd_max(local_max);
    }

    // 2. Exp and sum
    float local_sum = 0.0f;
    for (int i = tiisg; i < cols; i += 32) {
        float e = exp(x[i] - local_max);
        y[i] = e;
        local_sum += e;
    }
    local_sum = simd_sum(local_sum);

    if (ntg > 32) {
        if (sgitg == 0) shmem[tiisg] = 0.0f;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tiisg == 0) shmem[sgitg] = local_sum;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        local_sum = shmem[tiisg < (ntg / 32) ? tiisg : 0];
        local_sum = simd_sum(local_sum);
    }

    // 3. Normalize
    float inv_sum = 1.0f / local_sum;
    for (int i = tiisg; i < cols; i += 32) {
        y[i] *= inv_sum;
    }
}