ferrum-kernels 0.7.7

Unified compute kernels (CUDA/Metal/CPU) and model runner for Ferrum inference
Documentation
// Shared CUDA reduction primitives.
// Include this header instead of duplicating warp/block reduce in every kernel.

#pragma once

#include <cuda_fp16.h>

// ── Warp-level reductions (32 threads) ─────────────────────────────────

__inline__ __device__ float warp_reduce_sum(float val) {
    for (int offset = 16; offset > 0; offset >>= 1)
        val += __shfl_xor_sync(0xffffffff, val, offset);
    return val;
}

__inline__ __device__ float warp_reduce_max(float val) {
    for (int offset = 16; offset > 0; offset >>= 1)
        val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, offset));
    return val;
}

// ── Block-level reductions (up to 256 threads = 8 warps) ───────────────

__inline__ __device__ float block_reduce_sum(float val) {
    // Sized for up to 32 warps = 1024 threads (max CUDA block size).
    // Was `shared[8]` — caused OOB write for blocks > 256 threads, which
    // silently corrupted shared memory on older GPUs but trips
    // CUDA_ERROR_ILLEGAL_ADDRESS on Blackwell's stricter memory model.
    __shared__ float shared[32];
    int lane = threadIdx.x % 32;
    int wid  = threadIdx.x / 32;
    val = warp_reduce_sum(val);
    if (lane == 0) shared[wid] = val;
    __syncthreads();
    val = (threadIdx.x < blockDim.x / 32) ? shared[lane] : 0.0f;
    if (wid == 0) val = warp_reduce_sum(val);
    return val;
}

__inline__ __device__ float block_reduce_max(float val) {
    // Sized for up to 32 warps = 1024 threads (max CUDA block size).
    // Was `shared[8]` — caused OOB write for blocks > 256 threads, which
    // silently corrupted shared memory on older GPUs but trips
    // CUDA_ERROR_ILLEGAL_ADDRESS on Blackwell's stricter memory model.
    __shared__ float shared[32];
    int lane = threadIdx.x % 32;
    int wid  = threadIdx.x / 32;
    val = warp_reduce_max(val);
    if (lane == 0) shared[wid] = val;
    __syncthreads();
    val = (threadIdx.x < blockDim.x / 32) ? shared[lane] : -1e20f;
    if (wid == 0) val = warp_reduce_max(val);
    return val;
}