mlx-native 0.6.2

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
#include <metal_stdlib>
using namespace metal;

// SSM depthwise causal 1D convolution + SiLU for Qwen3.5 Gated DeltaNet.
//
// Spec source: ADR-013 Decision 7. Formula derived from Mamba-family causal
// depthwise 1D convolution literature plus the spec block in ADR-013; no
// llama.cpp source referenced.
//
// Operation:
//   ssm_conv(x, kernel_w, state) -> (y, new_state)
//     x:        [channels, n_tokens, n_seqs]
//     kernel_w: [K, channels]           (K = 4 for Qwen3.5)
//     state:    [K-1, channels, n_seqs] (previous (K-1) conv inputs per seq)
//   extended(c, t_ext, s) =
//     state(t_ext, c, s)              if t_ext < K-1
//     x(c, t_ext - (K-1), s)          otherwise
//   y(c, t, s) = silu( sum_{k=0..K} kernel_w(k, c) * extended(c, t + k, s) )
//   new_state(i, c, s) = extended(c, n_tokens + i, s)  for i in 0..K-1
//
// Causality: sum index t+k, for k in [0, K-1], maps to extended indices
// [t, t+K-1]. At t=0 these are [0, K-1] which are the (K-1) state tokens
// and the first x token, so each output only depends on past/current inputs.
//
// Memory layout (column-major / innermost-first):
//   x[c, t, s]        at offset s * (n_tokens * channels) + t * channels + c
//   y[c, t, s]        same as x
//   state[i, c, s]    at offset s * ((K-1) * channels) + c * (K-1) + i
//   kernel_w[k, c]    at offset c * K + k
//
// Numerical stability: accumulation is performed in f32 regardless of input
// dtype. SiLU uses `x / (1 + exp(-x))`.

// -----------------------------------------------------------------
// Kernel A — convolution forward (one thread per (c, t, s))
// -----------------------------------------------------------------

kernel void ssm_conv_forward_f32(
    device const float *x           [[buffer(0)]],
    device const float *kernel_w    [[buffer(1)]],
    device const float *state       [[buffer(2)]],
    device float       *y           [[buffer(3)]],
    device const uint  *params      [[buffer(4)]],
    uint3 tid [[thread_position_in_grid]]
) {
    const uint channels  = params[0];
    const uint n_tokens  = params[1];
    const uint n_seqs    = params[2];
    const uint k_width   = params[3]; // K, typically 4

    const uint c = tid.x;
    const uint t = tid.y;
    const uint s = tid.z;
    if (c >= channels || t >= n_tokens || s >= n_seqs) {
        return;
    }

    const uint x_seq_stride = n_tokens * channels;
    const uint s_seq_stride = (k_width - 1u) * channels;

    float sum = 0.0f;
    for (uint k = 0; k < k_width; ++k) {
        const uint t_ext = t + k;
        float val;
        if (t_ext < k_width - 1u) {
            val = state[s * s_seq_stride + c * (k_width - 1u) + t_ext];
        } else {
            const uint t_in = t_ext - (k_width - 1u);
            val = x[s * x_seq_stride + t_in * channels + c];
        }
        sum += kernel_w[c * k_width + k] * val;
    }

    // SiLU: y = sum * sigmoid(sum) = sum / (1 + exp(-sum))
    const float out = sum / (1.0f + metal::exp(-sum));

    y[s * x_seq_stride + t * channels + c] = out;
}

kernel void ssm_conv_forward_bf16(
    device const bfloat *x          [[buffer(0)]],
    device const bfloat *kernel_w   [[buffer(1)]],
    device const bfloat *state      [[buffer(2)]],
    device bfloat       *y          [[buffer(3)]],
    device const uint   *params     [[buffer(4)]],
    uint3 tid [[thread_position_in_grid]]
) {
    const uint channels = params[0];
    const uint n_tokens = params[1];
    const uint n_seqs   = params[2];
    const uint k_width  = params[3];

    const uint c = tid.x;
    const uint t = tid.y;
    const uint s = tid.z;
    if (c >= channels || t >= n_tokens || s >= n_seqs) {
        return;
    }

    const uint x_seq_stride = n_tokens * channels;
    const uint s_seq_stride = (k_width - 1u) * channels;

    float sum = 0.0f;
    for (uint k = 0; k < k_width; ++k) {
        const uint t_ext = t + k;
        float val;
        if (t_ext < k_width - 1u) {
            val = float(state[s * s_seq_stride + c * (k_width - 1u) + t_ext]);
        } else {
            const uint t_in = t_ext - (k_width - 1u);
            val = float(x[s * x_seq_stride + t_in * channels + c]);
        }
        sum += float(kernel_w[c * k_width + k]) * val;
    }

    const float out = sum / (1.0f + metal::exp(-sum));

    y[s * x_seq_stride + t * channels + c] = bfloat(out);
}

// -----------------------------------------------------------------
// Kernel B — state update (one thread per (i, c, s))
//
// new_state(i, c, s) = extended(c, n_tokens + i, s)
//                    = state(n_tokens + i, c, s)           if n_tokens + i < K - 1
//                    = x(c, n_tokens + i - (K - 1), s)      otherwise
//
// Written in two passes because in the n_tokens + i < K - 1 branch we read
// from `state` which aliases the output buffer. Caller must provide a
// separate `new_state` buffer to avoid the aliasing race when n_tokens < K-1.
// -----------------------------------------------------------------

kernel void ssm_conv_state_update_f32(
    device const float *x          [[buffer(0)]],
    device const float *old_state  [[buffer(1)]],
    device float       *new_state  [[buffer(2)]],
    device const uint  *params     [[buffer(3)]],
    uint3 tid [[thread_position_in_grid]]
) {
    const uint channels = params[0];
    const uint n_tokens = params[1];
    const uint n_seqs   = params[2];
    const uint k_width  = params[3];
    const uint k_minus1 = k_width - 1u;

    const uint i = tid.x; // 0..k-1
    const uint c = tid.y;
    const uint s = tid.z;
    if (i >= k_minus1 || c >= channels || s >= n_seqs) {
        return;
    }

    const uint x_seq_stride = n_tokens * channels;
    const uint s_seq_stride = k_minus1 * channels;
    const uint t_ext = n_tokens + i;

    float val;
    if (t_ext < k_minus1) {
        val = old_state[s * s_seq_stride + c * k_minus1 + t_ext];
    } else {
        const uint t_in = t_ext - k_minus1;
        val = x[s * x_seq_stride + t_in * channels + c];
    }
    new_state[s * s_seq_stride + c * k_minus1 + i] = val;
}

kernel void ssm_conv_state_update_bf16(
    device const bfloat *x          [[buffer(0)]],
    device const bfloat *old_state  [[buffer(1)]],
    device bfloat       *new_state  [[buffer(2)]],
    device const uint   *params     [[buffer(3)]],
    uint3 tid [[thread_position_in_grid]]
) {
    const uint channels = params[0];
    const uint n_tokens = params[1];
    const uint n_seqs   = params[2];
    const uint k_width  = params[3];
    const uint k_minus1 = k_width - 1u;

    const uint i = tid.x;
    const uint c = tid.y;
    const uint s = tid.z;
    if (i >= k_minus1 || c >= channels || s >= n_seqs) {
        return;
    }

    const uint x_seq_stride = n_tokens * channels;
    const uint s_seq_stride = k_minus1 * channels;
    const uint t_ext = n_tokens + i;

    bfloat val;
    if (t_ext < k_minus1) {
        val = old_state[s * s_seq_stride + c * k_minus1 + t_ext];
    } else {
        const uint t_in = t_ext - k_minus1;
        val = x[s * x_seq_stride + t_in * channels + c];
    }
    new_state[s * s_seq_stride + c * k_minus1 + i] = val;
}