mlx-native 0.9.0

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
#include <metal_stdlib>
using namespace metal;

/// ADR-020 iter-11h-misc-3 — elementwise sqrt + backward.
///
/// Forward:  y[i] = sqrt(x[i])
/// Backward: dx[i] = dy[i] / (2 · y[i])  (uses forward output;
///                                        avoids recomputing 1/(2√x))
///
/// Caller responsibility: x[i] >= 0 for all i (sqrt of a negative
/// float returns NaN per IEEE 754).  Backward divides by 2 · y[i], so
/// where y[i] = 0 (i.e. x[i] = 0) it yields ±inf for dy[i] != 0 and
/// NaN for dy[i] = 0 — caller should add eps if needed.

/// Elementwise square root: output[tid] = sqrt(input[tid]).
///
/// Buffers:
///   input  (buffer 0) — source values, read-only.
///   output (buffer 1) — destination, written at the same index as read.
///   params (buffer 2) — params[0] is the element count; threads whose
///                       grid position is at or past it do nothing.
///
/// Each thread handles the single element at its grid position.
kernel void sqrt_f32(
    device const float *input  [[buffer(0)]],
    device float       *output [[buffer(1)]],
    device const uint  *params [[buffer(2)]],
    uint tid [[thread_position_in_grid]]
) {
    const uint count = params[0];

    // Only in-range threads touch memory; the rest fall through.
    if (tid < count) {
        const float value = input[tid];
        output[tid] = metal::sqrt(value);
    }
}

/// Gradient of elementwise sqrt: dx[tid] = dy[tid] / (2 · y[tid]).
///
/// Buffers:
///   y      (buffer 0) — forward output (the already-computed sqrt values).
///   dy     (buffer 1) — upstream gradient.
///   dx     (buffer 2) — resulting input gradient, written per element.
///   params (buffer 3) — params[0] is the element count; threads whose
///                       grid position is at or past it do nothing.
///
/// There is no guard against y[tid] == 0; the division is performed as-is.
kernel void sqrt_backward_f32(
    device const float *y      [[buffer(0)]],   // forward output
    device const float *dy     [[buffer(1)]],
    device float       *dx     [[buffer(2)]],
    device const uint  *params [[buffer(3)]],
    uint tid [[thread_position_in_grid]]
) {
    const uint count = params[0];

    // Only in-range threads touch memory; the rest fall through.
    if (tid < count) {
        const float upstream = dy[tid];
        const float denom    = 2.0f * y[tid];
        dx[tid] = upstream / denom;
    }
}