mlx-native 0.7.0

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
#include <metal_stdlib>
using namespace metal;

/// ADR-020 iter-11h-c1 — elementwise exponential.  Building block for
/// the GatedDeltaNet recurrence's `alpha = exp(-g[t])` state-decay
/// factor (mlx-lm/qwen3_5.py:GatedDeltaNet.__call__).
///
/// Forward:  y[i] = exp(x[i])
/// Backward: dx[i] = dy[i] · exp(x[i]) = dy[i] · y[i]
///
/// The backward kernel takes the forward output y as an extra input
/// (bound at buffer 0, with the upstream gradient dy at buffer 1): the
/// caller passes the forward output back in, which saves recomputing
/// exp during backward.

/// Elementwise exponential: output[i] = exp(input[i]) for i < params[0].
///
/// Buffers: 0 = input, 1 = output, 2 = params (params[0] = element count).
/// Threads whose position is >= params[0] return without touching memory,
/// so the dispatch may be rounded up past the element count.
///
/// Uses `metal::precise::exp` explicitly.  Metal's default compile options
/// enable fast math, under which an unqualified `exp` may resolve to
/// `fast::exp`, whose error bound widens with |x| and whose inf/NaN
/// handling is unspecified.  Naming the precise variant keeps the result
/// accurate regardless of how the pipeline was compiled.
kernel void exp_f32(
    device const float *input  [[buffer(0)]],
    device float       *output [[buffer(1)]],
    device const uint  *params [[buffer(2)]],
    uint tid [[thread_position_in_grid]]
) {
    const uint n = params[0];
    if (tid >= n) return;
    output[tid] = metal::precise::exp(input[tid]);
}

/// Gradient of the elementwise exponential: dx[i] = dy[i] * y[i], where
/// y is the FORWARD OUTPUT exp(x) (not the forward input x).  Reusing the
/// saved output means backward does one multiply per element instead of
/// re-evaluating exp.
///
/// Buffers: 0 = y (forward output), 1 = dy (upstream gradient),
///          2 = dx (result), 3 = params (params[0] = element count).
/// Threads whose position is >= params[0] do nothing, so the dispatch may
/// be rounded up past the element count.
kernel void exp_backward_f32(
    device const float *y      [[buffer(0)]],   // saved forward output exp(x)
    device const float *dy     [[buffer(1)]],   // upstream gradient
    device float       *dx     [[buffer(2)]],   // gradient w.r.t. the forward input
    device const uint  *params [[buffer(3)]],   // params[0] = element count
    uint tid [[thread_position_in_grid]]
) {
    const uint count = params[0];
    if (tid < count) {
        const float upstream = dy[tid];
        const float saved_exp = y[tid];
        dx[tid] = upstream * saved_exp;
    }
}