// mlx-native 0.1.3
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon.
// (Package header — kept as a comment so this Metal source compiles.)
#include <metal_stdlib>
using namespace metal;

/// Softcap kernel: tanh-based logit capping.
///
/// Computes: output = tanh(input / cap) * cap
///
/// This bounds output values to the range (-cap, +cap).
///
/// Buffer layout:
///   buffer(0): input  — float array
///   buffer(1): output — float array
///   buffer(2): params — two floats: (cap, n_elements), where n_elements is
///              the raw uint bit-pattern stored in the second float slot and
///              recovered in-kernel with as_type<uint> (not a float-valued count)
///
/// Grid: dispatched with enough threads to cover n_elements.
/// Threads beyond n_elements are no-ops (bounds check).

/// Element-wise tanh soft cap, f32: output[i] = tanh(input[i] / cap) * cap.
/// Results are bounded to (-cap, +cap). One thread per element; threads past
/// the element count do nothing. Assumes cap != 0 (host contract) — TODO confirm.
kernel void softcap_f32(
    device const float *input  [[buffer(0)]],
    device float       *output [[buffer(1)]],
    device const float *params [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    // params[1] carries the element count as reinterpreted uint bits.
    const uint count = as_type<uint>(params[1]);
    if (id < count) {
        const float cap    = params[0];
        const float scaled = input[id] / cap;
        output[id] = tanh(scaled) * cap;
    }
}

/// Element-wise tanh soft cap, f16 storage: output[i] = tanh(input[i] / cap) * cap.
/// Math is carried out in f32 and only the final result is narrowed to half,
/// avoiding half-precision tanh error. Threads past the element count are no-ops.
kernel void softcap_f16(
    device const half  *input  [[buffer(0)]],
    device half        *output [[buffer(1)]],
    device const float *params [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    // params[1] carries the element count as reinterpreted uint bits.
    const uint count = as_type<uint>(params[1]);
    if (id < count) {
        const float cap    = params[0];
        const float capped = tanh(float(input[id]) / cap) * cap;
        output[id] = half(capped);
    }
}

/// Element-wise tanh soft cap, bf16 storage: output[i] = tanh(input[i] / cap) * cap.
/// Math is carried out in f32 and only the final result is narrowed to bfloat,
/// avoiding bf16's very low mantissa precision in the transcendental.
/// NOTE(review): bfloat requires MSL 3.1+ — confirm the build target.
kernel void softcap_bf16(
    device const bfloat *input  [[buffer(0)]],
    device bfloat       *output [[buffer(1)]],
    device const float  *params [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    // params[1] carries the element count as reinterpreted uint bits.
    const uint count = as_type<uint>(params[1]);
    if (id < count) {
        const float cap    = params[0];
        const float capped = tanh(float(input[id]) / cap) * cap;
        output[id] = bfloat(capped);
    }
}