// mlx-native 0.7.1
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon.
// (Documentation source listing.)
#include <metal_stdlib>
using namespace metal;

/// ADR-020 iter-11h-c2 — vector outer product `Y = lhs ⊗ rhs`.
/// Building block for `gated_delta_update`'s state-update term
/// `state += outer(delta, k)` (mlx-lm/gated_delta.py:_gated_delta_step_ops).
///
/// Distinct from matmul: matmul kernel has a 32-element floor on each
/// dim (M, N, K all ≥ 32 for backward dW dispatch); outer products
/// have inner-dim = 1, so they fall below the floor.  This kernel
/// family handles the 1-D × 1-D = 2-D case directly.
///
/// Layout:
///   lhs : `[N]` row-major f32
///   rhs : `[M]` row-major f32
///   y   : `[N, M]` row-major f32:  y[i, j] = lhs[i] · rhs[j]
///
/// Math:
///   Forward: y[i, j] = lhs[i] · rhs[j]
///   Backward dlhs[i] = Σ_j dy[i, j] · rhs[j]   (row sum of dy * rhs[None, :])
///   Backward drhs[j] = Σ_i dy[i, j] · lhs[i]   (col sum of dy * lhs[:, None])

// ──────────────────────────────────────────────────────────────────
// Forward
// ──────────────────────────────────────────────────────────────────
/// Forward outer product: writes `y[i, j] = lhs[i] * rhs[j]`.
///
/// One thread produces one output element.  `tid.x` is the row index `i`
/// (into `lhs`) and `tid.y` is the column index `j` (into `rhs`).  Threads
/// positioned outside the `N × M` extent return without touching memory,
/// so the dispatched grid may be larger than the logical output.
///
/// Buffers:
///   0  lhs    — read at index `i`, `i < N`
///   1  rhs    — read at index `j`, `j < M`
///   2  y      — written at flat offset `i * M + j`
///   3  params — `params[0] = N`, `params[1] = M`
kernel void outer_product_f32(
    device const float *lhs    [[buffer(0)]],
    device const float *rhs    [[buffer(1)]],
    device float       *y      [[buffer(2)]],
    device const uint  *params [[buffer(3)]],   // [N, M]
    uint2 tid [[thread_position_in_grid]]
) {
    const uint N = params[0];
    const uint M = params[1];
    const uint i = tid.x;
    const uint j = tid.y;
    if (i >= N || j >= M) return;

    // Form the flat offset in 64 bits: in `uint` arithmetic `i * M + j`
    // wraps around once N * M exceeds 2^32 and would alias an earlier
    // element instead of addressing the intended one.
    const ulong idx = ulong(i) * ulong(M) + ulong(j);
    y[idx] = lhs[i] * rhs[j];
}

// ──────────────────────────────────────────────────────────────────
// Backward dlhs: dlhs[i] = Σ_j dy[i, j] · rhs[j]
// ──────────────────────────────────────────────────────────────────
/// Backward pass w.r.t. `lhs`: `dlhs[i] = Σ_j dy[i, j] * rhs[j]`.
///
/// One thread per row `i` (`tid` is the row index).  The thread walks the
/// `M` entries of its `dy` row in increasing `j` and accumulates the
/// products in a single f32 register before storing the result.  Threads
/// with `tid >= N` return without touching memory.  When `M == 0` the loop
/// body never runs and `dlhs[i]` is written as `0.0f`.
///
/// Buffers:
///   0  dy     — read at flat offsets `i * M + j`
///   1  rhs    — read at indices `[0, M)`
///   2  dlhs   — written at index `i`
///   3  params — `params[0] = N`, `params[1] = M`
kernel void outer_product_backward_lhs_f32(
    device const float *dy     [[buffer(0)]],
    device const float *rhs    [[buffer(1)]],
    device float       *dlhs   [[buffer(2)]],
    device const uint  *params [[buffer(3)]],   // [N, M]
    uint tid [[thread_position_in_grid]]
) {
    const uint N = params[0];
    const uint M = params[1];
    const uint i = tid;
    if (i >= N) return;

    // Row base is computed once, in 64 bits: the 32-bit product `i * M`
    // wraps when N * M exceeds 2^32, which would make the thread read a
    // different row of `dy` than the one it owns.
    device const float *dy_row = dy + ulong(i) * ulong(M);

    float acc = 0.0f;
    for (uint j = 0; j < M; ++j) {
        acc += dy_row[j] * rhs[j];
    }
    dlhs[i] = acc;
}

// ──────────────────────────────────────────────────────────────────
// Backward drhs: drhs[j] = Σ_i dy[i, j] · lhs[i]
// ──────────────────────────────────────────────────────────────────
/// Backward pass w.r.t. `rhs`: `drhs[j] = Σ_i dy[i, j] * lhs[i]`.
///
/// One thread per column `j` (`tid` is the column index).  The thread
/// steps down its column of `dy` in increasing `i`, i.e. in strides of `M`
/// elements, accumulating the products in a single f32 register before
/// storing the result.  Threads with `tid >= M` return without touching
/// memory.  When `N == 0` the loop body never runs and `drhs[j]` is written
/// as `0.0f`.
///
/// Buffers:
///   0  dy     — read at flat offsets `i * M + j`
///   1  lhs    — read at indices `[0, N)`
///   2  drhs   — written at index `j`
///   3  params — `params[0] = N`, `params[1] = M`
kernel void outer_product_backward_rhs_f32(
    device const float *dy     [[buffer(0)]],
    device const float *lhs    [[buffer(1)]],
    device float       *drhs   [[buffer(2)]],
    device const uint  *params [[buffer(3)]],   // [N, M]
    uint tid [[thread_position_in_grid]]
) {
    const uint N = params[0];
    const uint M = params[1];
    const uint j = tid;
    if (j >= M) return;

    // Running 64-bit offset of dy[i, j]; starts at row 0 and advances by one
    // row (M elements) per iteration.  Keeping it in `ulong` avoids the
    // 32-bit wrap of `i * M + j` once N * M exceeds 2^32.
    ulong offset = ulong(j);
    const ulong row_stride = ulong(M);

    float acc = 0.0f;
    for (uint i = 0; i < N; ++i) {
        acc += dy[offset] * lhs[i];
        offset += row_stride;
    }
    drhs[j] = acc;
}