# rlx-mlx
Apple MLX backend for RLX — vendored MLX via a hand-rolled C++ shim,
eager + lazy + compiled execution.
## Modes
- **Lazy** *(default)* — build the entire MLX graph in `run()`, then
call `mlx::core::eval` once on all outputs. Lets MLX's optimizer
schedule the whole DAG, equivalent in spirit to the `mps_graph`
path in rlx-metal.
- **Eager** — eval after every op. Slower; useful for debugging
because failures surface at the offending op rather than at the
final eval.
- **Compiled** — `mlx::compile`-built persistent function for repeated
shapes; trace-cache amortizes re-runs.
Mode is set per-compile via `MlxExecutable::compile_with_mode`, or
globally via `RLX_MLX_MODE=eager|lazy|compiled` (default lazy).
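
As a rough illustration of the environment override, the sketch below maps an `RLX_MLX_MODE` value onto a mode enum. The `parse_mode` and `mode_from_env` helpers are hypothetical names, and the case-insensitive matching and the fallback to lazy on unrecognized values are assumptions; the crate may well reject unknown values instead.

```rust
/// Execution modes named in this README (local mirror for illustration).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MlxMode {
    Eager,
    Lazy,
    Compiled,
}

/// Hypothetical helper: map an `RLX_MLX_MODE` value to a mode.
/// Assumption: unset or unrecognized values fall back to `Lazy`.
pub fn parse_mode(raw: Option<&str>) -> MlxMode {
    match raw.map(|s| s.trim().to_ascii_lowercase()).as_deref() {
        Some("eager") => MlxMode::Eager,
        Some("compiled") => MlxMode::Compiled,
        _ => MlxMode::Lazy,
    }
}

/// Hypothetical helper: read the mode from the process environment.
pub fn mode_from_env() -> MlxMode {
    parse_mode(std::env::var("RLX_MLX_MODE").ok().as_deref())
}
```

Keeping the string parsing separate from the environment read makes the fallback policy easy to unit-test without touching process state.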
## What's here
- [`rlx-mlx-sys`](../rlx-mlx-sys) — vendored MLX (`vendor/mlx`), CMake
build, and `cpp/rlx_mlx_shim.{h,cpp}` C ABI over `mlx::core::*`.
- `src/` — re-exports `rlx_mlx_sys::ffi`; RAII wrappers and lowering:
  - `src/array.rs` — RAII `Array` wrapper, `MlxError`, top-level `eval`.
  - `src/ops.rs` — typed wrappers: matmul / add / mul / sub / div /
    softmax / gelu / silu / cast / layer_norm.
  - `src/lower.rs` — walks `rlx_ir::Graph` in topo order, building MLX
    arrays for each node. Rebuilds the graph fresh each `run()` (see the
    comment in lower.rs for why).
  - `src/backend.rs` — `MlxExecutable` (set_param / run / handles).
- **FFT** — native `mlx::fft::fft` via `rlx_mlx_shim` for `Op::Fft`;
graph helpers (`rfft`, `irfft`, …) lower through the same path.
- Tier-1 / Tier-2 / Tier-3 backward op parity with `rlx-cpu` for
reverse-mode autodiff (relu, activation, softmax cross-entropy, layer
norm, conv2d, max-pool, fake-quantize).
## Install
Native MLX lives in **`rlx-mlx-sys`** (submodule + `build.rs`). After cloning, fetch the submodule:
```sh
git submodule update --init rlx-mlx-sys/vendor/mlx
```
Then add the dependency to `Cargo.toml`:
```toml
[dependencies]
rlx = { version = "0.2", features = ["mlx"] }
# or directly:
rlx-mlx = "0.2"
rlx-mlx-sys = "0.2"
```
The first build compiles MLX from source — minutes, not seconds.
## Build / test
```sh
cargo build -p rlx-mlx --release
cargo test -p rlx-mlx --release
```
Through `rlx-runtime`:
```sh
cargo build -p rlx-runtime --features mlx --release
```
## Status
Mature on Apple Silicon (M1 / M2 / M3 / M4). On Intel Macs MLX falls
back to its CPU path; supported but rarely the right choice.
## Gotchas
- **Op coverage.** The first cut handled MatMul, Binary (Add/Mul/Sub/Div),
Activation (Gelu/Silu), Cast, Softmax, and LayerNorm. Lowering now covers
matmul, all binary / activation / cast / reduce / softmax / layer-norm /
RMS-norm ops, fused attention (SDPA via
`fast::scaled_dot_product_attention`), pool composition, dot-general,
and selective-scan unroll; the backend also ships a calibrated cost
model and async commit + sync. Any other op returns
`MlxError("unsupported op …")` from `lower::lower_and_run`. Adding an
op takes five pieces: a declaration in `cpp/rlx_mlx_shim.h`, the
matching implementation in `cpp/rlx_mlx_shim.cpp`, an `extern "C"`
declaration in `ffi.rs`, a wrapper in `ops.rs`, and a match arm in
`lower.rs`.
- **Fresh-graph-per-run.** Every `run()` rebuilds the MLX graph from
scratch. MLX's own trace cache amortizes this, but if you need lower
per-run latency, the next step is `mlx::compile`-style placeholder
bindings (track the input/param NodeIds → MLX placeholder handles,
reuse the compiled graph across runs).
- **F32 I/O default.** Inputs and params come in as `&[f32]`, and outputs
come out as `Vec<f32>`. The shim casts to and from MLX's per-array dtype
internally, so AutoMixedPrecision still does the right thing inside
the graph. The runtime trait now exposes
`set_param_typed(name, &[u8], dtype)` and
`run_typed(inputs: &[(&str, &[u8], DType)]) -> Vec<(Vec<u8>, DType)>`.
The default impls handle F32 only; the MLX backend overrides them with
the zero-widen path through `Array::from_bytes` / `Array::to_bytes`. CPU
and Metal inherit the F32 default and panic on non-F32 typed inputs
(overriding it for those backends is a future PR).
- **Constants must be F32.** Non-F32 `Op::Constant` payloads error in
lower.rs — the constant byte format is little-endian f32. Add F16/I32
constant decoding when a model needs it.
- **Async pipeline:** `commit_no_wait` schedules the lowered graph via
`mlx::core::async_eval` and stashes the output handles; `sync_pending`
calls `mlx::core::synchronize` and drops them. `run()` always calls
`sync_pending()` first, so an explicit `run()` after a commit is safe.
There is no per-stream isolation yet: `synchronize()` drains every MLX
stream.
- **KV-cache pattern:** if an output slot's name is `out{i}` and a
handle of the same name is bound, `run()` syncs the f32 result back
into the handle so the next iteration picks it up as input.
- **`run_slots` arena:** the slot path keeps a synthetic `Vec<u8>`
arena owned by the executable. Outputs are copied into it after each
`run_slots` call so callers can read results via
`arena_ptr().add(offset)` without per-output `Vec<f32>` allocations.
Cheaper than `run()` when outputs are tiny and per-call bookkeeping
dominates the cost.
- **Attention `SlidingWindow` mask:** synthesized host-side as an
additive `[seq_q, seq_k]` mask (0 where allowed, -inf elsewhere),
then passed through `fast::scaled_dot_product_attention` with
`mode="array"`. MLX has no native sliding-window mode.
- **Sample:** temperature scaling, then a `top_k` filter, a `top_p`
(nucleus) filter, and `mlx::random::categorical`. `top_k` uses
`mc::topk` for the threshold. `top_p` sorts the probabilities in
descending order (via `sort` + negate), takes an exclusive cumsum of
the sorted probs, keeps the entries whose cumsum is below `top_p` (the
nucleus), picks the smallest probability in that nucleus as the
threshold, and applies it back to the original logits via
`where(p >= threshold, logits, -∞)`.
- **Persistent compiled graph (`MlxMode::Compiled`):** the executable
builds a `CompiledFn` lazily on first `run()`. Internally a Rust
callback walks the IR via `lower::lower_with_env`; the shim wraps it
as `std::function`, hands it to `mc::compile`, and stores the
returned function. Subsequent calls replay the optimized trace.
- **Calibration + cost model:** `calibrate::Calibration::load_or_measure()`
measures sgemm GF/s at one large + one small shape plus a tiny-graph
round-trip overhead, **plus** memory bandwidth (large contiguous
copy), attention throughput (1×4×128×64 SDPA), and reduce throughput
(1024×1024 sum-along-last-axis). Caches at
`~/.cache/rlx/mlx-calib-<sanitized-device-name>.json` and feeds
`rlx_runtime::cost::MlxCostModel` so `pick_best_device` can rank MLX
honestly.
- **Pool composition:** `Op::Pool` is lowered by composing
`slice_strided` over the kernel grid plus a reduction.
Supports 1D / 2D / 3D inputs (channels-first layout) and all five
reduction kinds (max/min/sum/mean/prod). Padding uses a constant chosen
per reduction: -∞ for max-pool, +∞ for min-pool, 1.0 for prod, and 0
for sum and mean.
- **DotGeneral lowering:** the canonical 2D pattern (no batch dims,
contract `lhs[1]` × `rhs[0]`) reduces to a plain `MatMul`, matching
what the optimizer's `LowerDotGeneral` pass would have produced.
Non-canonical patterns (batched, alternative contracting axes) error
with a clear diagnostic — same coverage as the optimizer pass.
- **FusedTransformerLayer composition:** the full BERT-style post-norm
block (attention → residual+LN → FFN → residual+LN) composed from
primitives. Honors all four mask kinds via the underlying SDPA path.
- **`Op::If` / `Op::While`** are now lowered. We adopt a positional
binding convention between the sub-graph's `Op::Input` nodes (in
topo order) and the parent's captures (`inputs[1..]` for `If`,
`inputs[..]` for `While`). Sub-graph `Op::Param` nodes are looked up by
name in the parent's param maps, and sub-graph `Op::Constant` nodes are
inlined. `Op::If` evaluates both branches and combines them via
`mc::where`. `Op::While` requires `max_iterations` and unrolls; an
active-mask gate via `where(active && cond, body_out, carried)`
freezes loop-carried values once the condition becomes false. Single-
output `While` only — multi-output convention isn't defined in the
IR. Compile mode (`MlxMode::Compiled`) doesn't yet recurse through
sub-graph leaves; `If`/`While` inside a compiled trace will fail
with a missing-param diagnostic. Use `Lazy`/`Eager` for control flow.
- **SelectiveScan composition:** `Op::SelectiveScan` (Mamba SSM step)
is lowered by unrolling the time loop into `seq` op chains, one per
time step. At each step `t` we slice δ/x/B/C, broadcast against A, update the
running state via `exp(δA) * state + δ*B*x`, and accumulate
`sum_n(C * state)` as the output. Per-call cost amortizes through
`mlx::compile`'s trace cache. Acceptable for static-shape graphs
(which all our graphs are); for very long sequences a custom Metal
kernel via `fast::metal_kernel` would beat this on raw throughput.
- **Native ElementwiseRegion lowering (PLAN L2):** `Op::ElementwiseRegion`
is lowered in `lower.rs` by composing `ops::*` per `ChainStep`
(Activation/Cast/Binary/Compare) directly into MLX's lazy trace.
Each step is resolved positionally — `ChainOperand::Input(i)` reads
`node.inputs[i]` and `ChainOperand::Step(i)` reads the array
produced by chain step `i`. Because the whole chain becomes a sub-DAG
inside MLX's trace, `mlx::compile` and the lazy evaluator get to
fuse it into a single kernel — no decomposer round-trip and no
extra Op nodes for the executor to walk. The runtime backend now
runs `MarkElementwiseRegions` (instead of `UnfuseElementwiseRegions`)
ahead of MLX compilation so chains are collapsed before lowering.
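
The additive mask from the **Attention `SlidingWindow` mask** item can be sketched on the host as follows. The function name is hypothetical, and the window rule is an assumption (causal, with query and key positions aligned at index 0, so query `q` sees keys `k` with `k <= q` and `q - k < window`); the crate's actual window semantics may differ.

```rust
/// Build an additive attention mask of shape [seq_q, seq_k], row-major.
/// 0.0 marks an allowed (query, key) pair; -inf marks a blocked one.
/// Assumption: causal window, where query q may see key k only if
/// k <= q and q - k < window.
pub fn sliding_window_mask(seq_q: usize, seq_k: usize, window: usize) -> Vec<f32> {
    // Start fully blocked, then open the allowed band.
    let mut mask = vec![f32::NEG_INFINITY; seq_q * seq_k];
    for q in 0..seq_q {
        for k in 0..seq_k {
            if k <= q && q - k < window {
                mask[q * seq_k + k] = 0.0;
            }
        }
    }
    mask
}
```

Because the mask is added to the attention scores before softmax, a `-inf` entry drives the corresponding weight to zero while `0.0` leaves the score unchanged.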
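
The `top_p` threshold from the **Sample** item is easier to follow on plain slices. This is a host-side sketch only, not the MLX lowering: `top_p_filter` is a hypothetical name, `probs` is assumed to be the softmax of `logits`, and `top_p` is assumed to be greater than zero.

```rust
/// Host-side sketch of the top-p (nucleus) threshold step.
/// `probs` are softmax probabilities; `logits` are the matching raw logits.
pub fn top_p_filter(logits: &[f32], probs: &[f32], top_p: f32) -> Vec<f32> {
    assert_eq!(logits.len(), probs.len());

    // Sort probabilities in descending order.
    let mut sorted: Vec<f32> = probs.to_vec();
    sorted.sort_by(|a, b| b.total_cmp(a));

    // Exclusive cumsum: `running` holds the sum of all earlier entries.
    // An entry belongs to the nucleus while that sum is below top_p.
    let mut running = 0.0f32;
    let mut threshold = f32::INFINITY;
    for &p in &sorted {
        if running < top_p {
            threshold = p; // smallest probability in the nucleus so far
        }
        running += p;
    }

    // Keep logits whose probability reaches the threshold; mask the rest.
    logits
        .iter()
        .zip(probs)
        .map(|(&l, &p)| if p >= threshold { l } else { f32::NEG_INFINITY })
        .collect()
}
```

Using the exclusive cumsum means the highest-probability token is always kept, since its running sum is zero.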
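
The active-mask gate from the **`Op::If` / `Op::While`** item can be modelled with scalars. In this sketch the `cond` and `body` closures stand in for the lowered sub-graphs, `unrolled_while` is a hypothetical name, and the update `active = active && cond` is an assumption about how the mask is carried between iterations.

```rust
/// Scalar sketch of an unrolled `While` with an active-mask gate.
/// `cond` and `body` are hypothetical stand-ins for the lowered sub-graphs.
pub fn unrolled_while(
    init: f32,
    max_iterations: usize,
    cond: impl Fn(f32) -> bool,
    body: impl Fn(f32) -> f32,
) -> f32 {
    let mut carried = init;
    let mut active = true;
    for _ in 0..max_iterations {
        // Both are always computed, as they would be in an unrolled trace.
        let c = cond(carried);
        let body_out = body(carried);
        // where(active && cond, body_out, carried)
        let take = active && c;
        carried = if take { body_out } else { carried };
        // Assumption: once the gate closes it stays closed.
        active = take;
    }
    carried
}
```

The loop always runs `max_iterations` steps, which is why the bound is mandatory: the trace has a fixed length, and the gate only decides whether each step's result is kept.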
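
The recurrence in the **SelectiveScan composition** item can be written as a small reference loop. This sketch is simplified to a single channel with a state of size `n`, and it assumes a zero initial state; the function name and argument layout are hypothetical, and the real lowering works on batched MLX arrays instead.

```rust
/// Single-channel reference for the selective-scan recurrence:
///   state = exp(delta * A) * state + delta * B * x
///   y[t]  = sum_n(C * state)
/// Shapes: delta and x are [seq]; a is [n]; b and c are [seq][n].
/// Assumption: the state starts at zero.
pub fn selective_scan_1ch(
    delta: &[f32],
    x: &[f32],
    a: &[f32],
    b: &[Vec<f32>],
    c: &[Vec<f32>],
) -> Vec<f32> {
    let n = a.len();
    let mut state = vec![0.0f32; n];
    let mut y = Vec::with_capacity(x.len());
    for t in 0..x.len() {
        let mut acc = 0.0f32;
        for i in 0..n {
            // Decay the running state, then add the new input contribution.
            state[i] = (delta[t] * a[i]).exp() * state[i] + delta[t] * b[t][i] * x[t];
            // Project the state onto the output.
            acc += c[t][i] * state[i];
        }
        y.push(acc);
    }
    y
}
```

Each iteration of the outer loop corresponds to one unrolled op chain in the lowering, so the trace length grows linearly with the sequence length.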
## License
GPL-3.0-only.