llama-rs 0.17.0

A high-performance Rust implementation of llama.cpp - LLM inference engine with full GGUF support
Documentation
#include <metal_stdlib>
using namespace metal;

// Compute partial sum of squares per threadgroup
kernel void rms_norm_sum_f32(
    device const float* x [[buffer(0)]],
    device float* partial_sum [[buffer(1)]],
    constant int& n [[buffer(2)]],
    uint idx [[thread_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint tg_id [[threadgroup_position_in_grid]],
    uint tg_size [[threads_per_threadgroup]]
) {
    threadgroup float sdata[256];

    float val = (idx < uint(n)) ? x[idx] : 0.0f;
    sdata[tid] = val * val;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint s = tg_size / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (tid == 0) {
        partial_sum[tg_id] = sdata[0];
    }
}