// mlx-native 0.8.0
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon.
// Documentation
#include <metal_stdlib>
using namespace metal;

/// ADR-020 iter-11h-b — Causal depthwise 1D convolution kernels for
/// the GpuTape autograd pipeline.
///
/// Distinct from `ssm_conv.metal` which (a) fuses a SiLU activation
/// and (b) takes a state buffer for autoregressive decode.  This
/// kernel family is for TRAINING-MODE forward + backward:
///   * No state (zero-pad on the past — first K-1 outputs see
///     fewer-than-K input taps).
///   * No SiLU fusion (silu is a separate `OpKind` in GpuTape; this
///     keeps backward derivation clean and aligned with the
///     "one OpKind = one math primitive" composition principle
///     established by iter-11h-a).
///
/// ## Math
///
/// Forward (per (t, c)):
///   y[t, c] = Σ_{k=0..K-1, t+k-(K-1) >= 0} kernel_w[c, k] · x[t+k-(K-1), c]
///
/// Equivalently with `i = t+k-(K-1)`:
///   y[t, c] = Σ_{i=max(0, t-K+1)..t} kernel_w[c, i-t+(K-1)] · x[i, c]
///
/// Backward dx (per (i, c)):
///   dx[i, c] = Σ_{k=0..K-1, t=i+(K-1)-k in [0, n)} kernel_w[c, k] · dy[t, c]
///
/// Backward dw (per (c, k)):
///   dw[c, k] = Σ_{t=K-1-k..n-1} x[t+k-(K-1), c] · dy[t, c]
///            = Σ_{i=0..n-1-(K-1-k)} x[i, c] · dy[i+(K-1)-k, c]
///
/// ## Layout
///
///   x : `[n_tokens, channels]` row-major: `x[t][c]` at `t*channels + c`.
///   y : `[n_tokens, channels]` same layout as x.
///   kernel_w : `[channels, K]` row-major: `w[c][k]` at `c*K + k`.
///
/// (Matches `ssm_conv.metal`'s `kernel_w` layout convention; differs
/// in the activation/state input plane shape because we drop n_seqs
/// — this kernel handles a single sequence at a time, n_seqs=1
/// implicit.  Multi-seq batching can wrap with an outer dispatch.)

// ──────────────────────────────────────────────────────────────────
// Forward: y[t, c] = Σ_k kernel_w[c, k] · x_ext[t+k-(K-1), c]
//   x_ext zero-pads for indices < 0.
// ──────────────────────────────────────────────────────────────────
// Causal depthwise conv forward pass, one thread per output element.
//
// Grid: tid.x = token index t, tid.y = channel index c.  Threads outside
// [n_tokens, channels] return immediately, so the grid may be oversized,
// but it must cover at least n_tokens x channels threads for every output
// element to be written.
// params: [n_tokens, channels, K] as three uints.
// NOTE(review): adjacent threads along tid.x read `x` with a stride of
// `channels` elements; confirm the threadgroup shape accounts for this.
kernel void conv1d_depthwise_causal_forward_f32(
    device const float *x         [[buffer(0)]],
    device const float *kernel_w  [[buffer(1)]],
    device float       *y         [[buffer(2)]],
    device const uint  *params    [[buffer(3)]],   // [n_tokens, channels, K]
    uint2 tid [[thread_position_in_grid]]
) {
    const uint tok   = tid.x;
    const uint chan  = tid.y;
    const uint n_tok = params[0];
    const uint n_ch  = params[1];
    const uint taps  = params[2];

    // Guard clauses: this thread lies outside the output plane.
    if (tok >= n_tok) return;
    if (chan >= n_ch) return;

    // The receptive field of output token `tok` is input tokens
    // [tok-(taps-1), tok].  Negative tokens are the implicit zero padding,
    // so the walk starts at `first` = max(0, tok+1-taps).  An input token
    // `src` pairs with tap index src + (taps-1) - tok.
    // When taps == 0, first == tok+1 and the loop body never runs.
    const uint first = (tok + 1u >= taps) ? (tok + 1u - taps) : 0u;
    device const float *w_row = kernel_w + chan * taps;

    float acc = 0.0f;
    for (uint src = first; src <= tok; ++src) {
        const uint tap = src + (taps - 1u) - tok;
        acc += w_row[tap] * x[src * n_ch + chan];
    }
    y[tok * n_ch + chan] = acc;
}

// ──────────────────────────────────────────────────────────────────
// Backward dx: dx[i, c] = Σ_{k where t=i+(K-1)-k is valid} kernel_w[c, k] · dy[t, c]
//   t = i + (K-1) - k in [0, n_tokens) ⇒
//     k in [max(0, i+(K-1) - (n_tokens-1)), min(K-1, i+(K-1))]
//   Since i+(K-1) - (n_tokens-1) = i+K - n_tokens:
//     k_min = max(0, i+K - n_tokens)   (drops terms where t >= n_tokens)
//     k_max = min(K-1, i+(K-1)) = K-1  (i >= 0 ⇒ i+(K-1) >= K-1)
// ──────────────────────────────────────────────────────────────────
// Gradient of the causal depthwise conv with respect to its input x.
// One thread per dx element.
//
// Grid: tid.x = input token index i, tid.y = channel index c.  Threads
// outside [n_tokens, channels] return immediately; the grid must cover at
// least n_tokens x channels threads.
// params: [n_tokens, channels, K] as three uints.
// NOTE(review): adjacent threads along tid.x read `dy` with a stride of
// `channels` elements; confirm the threadgroup shape accounts for this.
kernel void conv1d_depthwise_causal_backward_dx_f32(
    device const float *dy        [[buffer(0)]],
    device const float *kernel_w  [[buffer(1)]],
    device float       *dx        [[buffer(2)]],
    device const uint  *params    [[buffer(3)]],   // [n_tokens, channels, K]
    uint2 tid [[thread_position_in_grid]]
) {
    const uint row   = tid.x;
    const uint chan  = tid.y;
    const uint n_tok = params[0];
    const uint n_ch  = params[1];
    const uint taps  = params[2];

    if (row >= n_tok) return;
    if (chan >= n_ch) return;

    // Input token `row` feeds output tokens row .. row+(taps-1).  Only those
    // below n_tok exist, so the number of contributing taps is
    // min(taps, n_tok - row).  The contributing taps are the top `live`
    // indices [taps-live, taps).  n_tok - row >= 1 here thanks to the guard.
    const uint live = min(taps, n_tok - row);
    device const float *w_row = kernel_w + chan * taps;

    float acc = 0.0f;
    for (uint tap = taps - live; tap < taps; ++tap) {
        // Output token that tap `tap` maps this input token onto.
        const uint out_tok = row + (taps - 1u) - tap;
        acc += w_row[tap] * dy[out_tok * n_ch + chan];
    }
    dx[row * n_ch + chan] = acc;
}

// ──────────────────────────────────────────────────────────────────
// Backward dw: dw[c, k] = Σ_{t=K-1-k..n-1} x[t+k-(K-1), c] · dy[t, c]
//   Substituting i = t+k-(K-1):
//   dw[c, k] = Σ_{i=0..n-1-(K-1-k)} x[i, c] · dy[i+(K-1)-k, c]
//                   = Σ_{i=0..n-1-(K-1-k)} x[i, c] · dy[t, c]   where t = i+K-1-k
//
// Range guard: we need t < n_tokens.  For each (c, k), iterate i from
// 0 upward; stop when t = i+K-1-k >= n_tokens.
// ──────────────────────────────────────────────────────────────────
// Gradient of the causal depthwise conv with respect to kernel_w.
// One thread per dw element; each thread serially reduces over tokens.
//
// Grid: tid.x = channel index c, tid.y = tap index k.  Threads outside
// [channels, K] return immediately; the grid must cover at least
// channels x K threads.
// params: [n_tokens, channels, K] as three uints.
// dw uses the same `[channels, K]` layout as kernel_w (c*K + k).
kernel void conv1d_depthwise_causal_backward_dw_f32(
    device const float *x         [[buffer(0)]],
    device const float *dy        [[buffer(1)]],
    device float       *dw        [[buffer(2)]],
    device const uint  *params    [[buffer(3)]],   // [n_tokens, channels, K]
    uint2 tid [[thread_position_in_grid]]
) {
    const uint chan  = tid.x;
    const uint tap   = tid.y;
    const uint n_tok = params[0];
    const uint n_ch  = params[1];
    const uint taps  = params[2];

    if (chan >= n_ch) return;
    if (tap >= taps) return;

    // Tap `tap` connects input token (out_tok - lag) to output token
    // out_tok, where lag = (taps-1) - tap.  The first output token that has
    // an in-range input is out_tok = lag.  If n_tok <= lag, no pair exists,
    // the loop is skipped, and dw receives 0.
    const uint lag = (taps - 1u) - tap;

    float acc = 0.0f;
    for (uint out_tok = lag; out_tok < n_tok; ++out_tok) {
        const uint in_tok = out_tok - lag;
        acc += x[in_tok * n_ch + chan] * dy[out_tok * n_ch + chan];
    }
    dw[chan * taps + tap] = acc;
}