rlx-mlx 0.2.2

MLX backend for RLX — Apple's array framework via hand-rolled C++ shim, eager + lazy execution
// SPDX-License-Identifier: GPL-3.0-only
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

// ─── Batched dense LU + solve, f32, one threadgroup per system ─────
//
// Workload shape that motivates this kernel:
//   A: [B, n, n]  — B independent dense systems
//   b: [B, n]     — per-system RHS
//   x: [B, n]     — per-system solution
// with B in the thousands and n in 10..NMAX — the eda-mna Monte-Carlo
// case after the per-draw MNA system is batched along axis 0.
//
// One MLX threadgroup handles one system. Each thread owns one row.
// Capacity bound: A_local + b_local must fit in threadgroup memory
// (32 KB on M-series). At f32 that's n² + n ≤ 8192 floats ⇒ n ≤ 90.
//
// **Algorithm:** Doolittle LU **with partial pivoting**, then forward
// + back substitution. Per-step pivot: tid 0 scans column k for the
// row with maximum |A[i,k]|, swaps that row with row k, then all
// threads cooperate on the rank-1 elimination. The pivot search +
// swap is serial on tid 0 — for n ≤ NMAX this is O(n) work per step
// against O(n) threads doing real elimination work in parallel after
// the barrier, so the serial bit is bounded and not the bottleneck.
// (For larger n a SIMD-reduction argmax would matter; out of scope
// for this kernel.)
//
// MLX dispatch contract (set on the host side):
//   grid       = (B, 1, 1)
//   threadgroup = (n, 1, 1)
//
// MLX-binding contract:
//   • Source is the **body** of the kernel function. MLX wraps it in
//     `[[kernel]] void custom_kernel_<name>(...)` with parameters
//     auto-generated from input/output names + dtypes.
//   • Inputs `A`, `b` are bound as `const device float*` (row-major
//     contiguous; ensure_row_contiguous=true on the host).
//   • Output `x` is bound as `device float*`.
//   • Because the source mentions `A_shape`, MLX also auto-injects
//     `const constant int* A_shape` — we read n and B from there.
//   • Built-ins `thread_position_in_threadgroup` and
//     `threadgroup_position_in_grid` are auto-detected and added with
//     the right `[[…]]` annotations.
//   • `NMAX` is supplied via the `header` arg as a #define so the
//     threadgroup-memory sizes are fixed per-dispatch-shape.

#ifndef NMAX
#define NMAX 64
#endif

// ── Workspace ────────────────────────────────────────────────────
threadgroup float Aloc[NMAX * NMAX];
threadgroup float bloc[NMAX];
// Single-slot threadgroup scratch for the per-step pivot row index.
// tid 0 writes; all threads read after the barrier.
threadgroup uint pivot_row[1];

// Read shape from MLX-injected A_shape: A is [B, n, n] row-major.
const uint B = (uint)A_shape[0];
const uint n = (uint)A_shape[1];

const uint gid = threadgroup_position_in_grid.x;
const uint tid = thread_position_in_threadgroup.x;

if (gid >= B) return;
if (tid >= n) return;

// ── Stage 1: cooperative load A[gid], b[gid] → threadgroup ─────
// Index directly off the parameter pointers — small inputs may land
// in `constant` address space (MLX uses constant for size < 8) while
// large inputs land in `device`. Keeping access inline lets the same
// kernel compile for both layouts without auto/decltype gymnastics.
for (uint j = 0; j < n; ++j) {
    Aloc[tid * NMAX + j] = A[gid * (n * n) + tid * n + j];
}
bloc[tid] = b[gid * n + tid];
threadgroup_barrier(mem_flags::mem_threadgroup);

// ── Stage 2: Doolittle LU in place, with partial pivoting ────────
// For k = 0..n-1:
//   • tid 0 finds row p in [k, n) with max |A[p, k]|, swaps rows k↔p
//     in both Aloc and bloc.
//   • Barrier so all threads see the post-swap matrix.
//   • Threads with tid > k do the rank-1 elimination on their row.
//   • Barrier so the next k sees a consistent matrix.
for (uint k = 0; k < n; ++k) {
    if (tid == 0) {
        uint  best     = k;
        float best_abs = fabs(Aloc[k * NMAX + k]);
        for (uint i = k + 1; i < n; ++i) {
            float v = fabs(Aloc[i * NMAX + k]);
            if (v > best_abs) { best = i; best_abs = v; }
        }
        pivot_row[0] = best;
        if (best != k) {
            for (uint j = 0; j < n; ++j) {
                float tmp           = Aloc[k    * NMAX + j];
                Aloc[k    * NMAX + j] = Aloc[best * NMAX + j];
                Aloc[best * NMAX + j] = tmp;
            }
            float tmp_b = bloc[k];
            bloc[k]    = bloc[best];
            bloc[best] = tmp_b;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float pivot = Aloc[k * NMAX + k];
    if (tid > k && tid < n) {
        float factor = Aloc[tid * NMAX + k] / pivot;
        Aloc[tid * NMAX + k] = factor;
        for (uint j = k + 1; j < n; ++j) {
            Aloc[tid * NMAX + j] -= factor * Aloc[k * NMAX + j];
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
}

// ── Stage 3: forward solve L·y = b; L is unit-lower of Aloc ──────
// Sequential on i. For n ≤ NMAX a serial loop on tid==0 avoids
// barrier proliferation. Per-i inner dot could be SIMD-reduced if
// profiling shows this dominates.
if (tid == 0) {
    for (uint i = 0; i < n; ++i) {
        float sum = bloc[i];
        for (uint j = 0; j < i; ++j) {
            sum -= Aloc[i * NMAX + j] * bloc[j];
        }
        bloc[i] = sum;  // y, stored in place of b
    }

    // ── Stage 4: back-solve U·x = y; U is upper of Aloc inc. diag ─
    for (int i = (int)n - 1; i >= 0; --i) {
        float sum = bloc[i];
        for (uint j = (uint)i + 1; j < n; ++j) {
            sum -= Aloc[i * NMAX + j] * bloc[j];
        }
        bloc[i] = sum / Aloc[i * NMAX + i];  // x[i]
    }
}
threadgroup_barrier(mem_flags::mem_threadgroup);

// ── Stage 5: write x out ─────────────────────────────────────────
x[gid * n + tid] = bloc[tid];