// llama-rs 0.16.0
//
// A high-performance Rust implementation of llama.cpp: an LLM inference engine with full GGUF support.
// Documentation
#include <metal_stdlib>
using namespace metal;

// Shape parameters for the matmul_f32 kernel, bound as a constant buffer
// (buffer(3)). C = A * B with A: [m, K], B: [K, n], C: [m, n].
// The kernel converts each field to uint before use, so values are expected
// to be non-negative.
// NOTE(review): the host presumably fills this struct as raw bytes — keep the
// field order and 32-bit int types in sync with the host-side definition.
struct MatmulParams {
    int m;  // rows of A / rows of C
    int K;  // cols of A / rows of B (the reduction dimension)
    int n;  // cols of B / cols of C
};

// Tiled matrix multiply: C[i,j] = sum_k A[i,k] * B[k,j]
// A: [m, K] row-major, B: [K, n] row-major, C: [m, n] row-major
//
// Each thread computes one output element C[row, col] with row = gid.y and
// col = gid.x. The K dimension is walked in TILE-wide slices: every thread
// stages one element of A and one element of B into threadgroup memory, the
// group synchronizes, and each thread then accumulates a TILE-long partial
// dot product from the staged tiles. Loads that fall outside A or B are
// replaced with 0.0f, so ragged edge tiles contribute nothing to the sum.
//
// Dispatch contract (not checked here):
//   * The threadgroup must be exactly TILE x TILE threads: tid indexes the
//     tile arrays directly, so a larger threadgroup would write past them.
//   * All threads of a threadgroup must reach both barriers, which is why
//     out-of-range threads still run the loop and only skip the final store.
//   * NOTE(review): each staged tile element is written by exactly one
//     thread, so all TILE x TILE threads of every threadgroup must exist.
//     This appears to assume the grid is padded up to a multiple of TILE
//     (e.g. dispatchThreadgroups); a non-uniform dispatch with m or n not a
//     multiple of TILE would leave parts of the edge tiles unwritten —
//     confirm against the host-side dispatch.
//
// Buffers: 0 = A, 1 = B, 2 = C (output), 3 = MatmulParams.
kernel void matmul_f32(
    device const float* a [[buffer(0)]],
    device const float* b [[buffer(1)]],
    device float* result [[buffer(2)]],
    constant MatmulParams& params [[buffer(3)]],
    uint2 gid [[thread_position_in_grid]],
    uint2 tid [[thread_position_in_threadgroup]]
) {
    // Single source of truth for the tile edge: the threadgroup arrays, the
    // tile count and the inner loop are all sized from it, so changing TILE
    // can no longer desynchronize the arrays from the indexing.
    constexpr uint TILE = 16;
    threadgroup float tileA[TILE][TILE];
    threadgroup float tileB[TILE][TILE];

    const uint m = uint(params.m);
    const uint k = uint(params.K);
    const uint n = uint(params.n);

    const uint row = gid.y;
    const uint col = gid.x;
    const uint localRow = tid.y;
    const uint localCol = tid.x;

    float acc = 0.0f;
    const uint numTiles = (k + TILE - 1) / TILE;

    for (uint t = 0; t < numTiles; t++) {
        const uint aCol = t * TILE + localCol;
        const uint bRow = t * TILE + localRow;

        // Stage A[row, aCol] and B[bRow, col]; zero-pad outside the matrices.
        tileA[localRow][localCol] = (row < m && aCol < k) ? a[row * k + aCol] : 0.0f;
        tileB[localRow][localCol] = (bRow < k && col < n) ? b[bRow * n + col] : 0.0f;

        // The whole tile must be written before any thread reads it.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (uint i = 0; i < TILE; i++) {
            acc += tileA[localRow][i] * tileB[i][localCol];
        }

        // Stop fast threads from overwriting the tile while others still read it.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (row < m && col < n) {
        result[row * n + col] = acc;
    }
}