// mlx-native 0.8.1 — pure-Rust Metal GPU compute library for MLX-compatible
// inference on Apple Silicon.
//
// Elementwise SiLU forward and backward kernels (float32).
#include <metal_stdlib>
using namespace metal;

/// Elementwise SiLU (swish) forward.
///
///   silu(x) = x · sigmoid(x) = x / (1 + exp(-x))
///
/// The sigmoid is evaluated in an overflow-free form. With e = exp(-|x|),
/// which always lies in (0, 1]:
///   sigmoid(x) = 1 / (1 + e)   for x >= 0
///   sigmoid(x) = e / (1 + e)   for x <  0
/// No intermediate value becomes +inf. The direct form 1 / (1 + exp(-x))
/// overflows exp() for x below about -88, and then depends on IEEE inf
/// arithmetic, which Metal's default fast-math mode does not guarantee.
///
/// Buffer layout:
///   buffer(0): input  — float[n]
///   buffer(1): output — float[n]
///   buffer(2): params — uint[1]: n
///
/// Grid: 1D threads across n. Threads with tid >= n return without writing,
/// so the dispatched grid may be rounded up past n.
kernel void silu_f32(
    device const float *input  [[buffer(0)]],
    device float       *output [[buffer(1)]],
    device const uint  *params [[buffer(2)]],
    uint tid [[thread_position_in_grid]]
) {
    const uint n = params[0];
    if (tid >= n) return;  // tail guard for grids rounded up past n
    const float x = input[tid];
    // exp() only ever sees a non-positive argument, so e is in (0, 1] and
    // 1 + e is in (1, 2]: no overflow, no division by zero.
    const float e = metal::exp(-metal::abs(x));
    const float r = 1.0f / (1.0f + e);
    // x >= 0: sigmoid = 1/(1+e) = r;  x < 0: sigmoid = e/(1+e) = e * r.
    const float s = (x >= 0.0f) ? r : e * r;
    output[tid] = x * s;
}

/// Elementwise SiLU backward.
///
///   silu'(x) = sigmoid(x) + x · sigmoid(x) · (1 − sigmoid(x))
///            = sigmoid(x) · (1 + x · (1 − sigmoid(x)))
///   dx[i]    = dy[i] · silu'(x[i])
///
/// `x` is the FORWARD INPUT (not the forward output).
///
/// Both sigmoid(x) and 1 − sigmoid(x) are computed without overflow or
/// cancellation. With e = exp(-|x|), which always lies in (0, 1]:
///   x >= 0: sigmoid = 1 / (1 + e),  1 − sigmoid = e / (1 + e)
///   x <  0: sigmoid = e / (1 + e),  1 − sigmoid = 1 / (1 + e)
/// exp() therefore never produces +inf. Computing 1 − sigmoid directly
/// avoids the cancellation that subtracting it from 1 suffers for large x.
///
/// Buffer layout:
///   buffer(0): x      — float[n]    (forward input)
///   buffer(1): dy     — float[n]    (upstream gradient)
///   buffer(2): dx     — float[n]    (output gradient)
///   buffer(3): params — uint[1]: n
///
/// Grid: 1D threads across n. Threads with tid >= n return without writing,
/// so the dispatched grid may be rounded up past n.
kernel void silu_backward_f32(
    device const float *x      [[buffer(0)]],
    device const float *dy     [[buffer(1)]],
    device float       *dx     [[buffer(2)]],
    device const uint  *params [[buffer(3)]],
    uint tid [[thread_position_in_grid]]
) {
    const uint n = params[0];
    if (tid >= n) return;  // tail guard for grids rounded up past n
    const float xv = x[tid];
    // Non-positive exp() argument: e is in (0, 1], 1 + e is in (1, 2].
    const float e = metal::exp(-metal::abs(xv));
    const float r = 1.0f / (1.0f + e);
    const bool nonneg = (xv >= 0.0f);
    const float s  = nonneg ? r     : e * r;  // sigmoid(x)
    const float sc = nonneg ? e * r : r;      // 1 − sigmoid(x), no cancellation
    // silu'(x) = s · (1 + x · (1 − s))
    const float deriv = s * (1.0f + xv * sc);
    dx[tid] = dy[tid] * deriv;
}