rlx-mlx 0.2.0

MLX backend for RLX — Apple's array framework via hand-rolled C++ shim, eager + lazy execution
docs.rs failed to build rlx-mlx-0.2.0
Please check the build logs for more information.
See Builds for ideas on how to fix a failed build, or Metadata for how to configure docs.rs builds.
If you believe this is docs.rs' fault, open an issue.

rlx-mlx

Apple MLX backend for RLX — vendored MLX via a hand-rolled C++ shim, eager + lazy + compiled execution.

Modes

  • Lazy (default) — build the entire MLX graph in run(), then call mlx::core::eval once on all outputs. Lets MLX's optimizer schedule the whole DAG, equivalent in spirit to the mps_graph path in rlx-metal.
  • Eager — eval after every op. Slower; useful for debugging because failures surface at the offending op rather than at the final eval.
  • Compiledmlx::compile-built persistent function for repeated shapes; trace-cache amortizes re-runs.

Mode is set per-compile via MlxExecutable::compile_with_mode, or globally via RLX_MLX_MODE=eager|lazy|compiled (default lazy).

What's here

  • rlx-mlx-sys — vendored MLX (vendor/mlx), CMake build, and cpp/rlx_mlx_shim.{h,cpp} C ABI over mlx::core::*.
  • src/ — re-exports rlx_mlx_sys::ffi; RAII wrappers and lowering:
  • src/array.rs — RAII Array wrapper, MlxError, top-level eval.
  • src/ops.rs — typed wrappers: matmul / add / mul / sub / div / softmax / gelu / silu / cast / layer_norm.
  • src/lower.rs — walks rlx_ir::Graph in topo order, building MLX arrays for each node. Rebuilds the graph fresh each run() (see the comment in lower.rs for why).
  • src/backend.rsMlxExecutable (set_param / run / handles).
  • Tier-1 / Tier-2 / Tier-3 backward op parity with rlx-cpu for reverse-mode autodiff (relu, activation, softmax cross-entropy, layer norm, conv2d, max-pool, fake-quantize).

Install

Native MLX lives in rlx-mlx-sys (submodule + build.rs). After clone:

git submodule update --init rlx-mlx-sys/vendor/mlx
[dependencies]
rlx = { version = "0.2", features = ["mlx"] }
# or directly:
rlx-mlx = "0.2"
rlx-mlx-sys = "0.2"

The first build compiles MLX from source — minutes, not seconds.

Build / test

cargo build -p rlx-mlx --release
cargo test  -p rlx-mlx --release

Through rlx-runtime:

cargo build -p rlx-runtime --features mlx --release

Status

Mature on Apple Silicon (M1 / M2 / M3 / M4). On Intel Macs MLX falls back to its CPU path; supported but rarely the right choice.

Gotchas

  • Op coverage. First cut handled MatMul, Binary (Add/Mul/Sub/Div), Activation (Gelu/Silu), Cast, Softmax, LayerNorm. Now covers matmul, all binary / activation / cast / reduce / softmax / layer-norm / RMS-norm, fused attention (SDPA via fast::scaled_dot_product_attention), pool composition, dot-general, selective-scan unroll, calibrated cost model, async commit + sync. Anything else returns MlxError("unsupported op …") from lower::lower_and_run. Adding an op means: an entry in cpp/shim.h, the matching impl in shim.cpp, an extern "C" decl in ffi.rs, a wrapper in ops.rs, and a match arm in lower.rs.
  • Fresh-graph-per-run. Every run() rebuilds the MLX graph from scratch. MLX's own trace cache amortizes this, but if you need lower per-run latency, the next step is mlx::compile-style placeholder bindings (track the input/param NodeIds → MLX placeholder handles, reuse the compiled graph across runs).
  • F32 I/O default. Inputs/params come in as &[f32] and outputs come out as Vec<f32>. The shim casts to/from MLX's per-array dtype internally (so AutoMixedPrecision still does the right thing inside the graph). The runtime trait now exposes set_param_typed(name, &[u8], dtype) and run_typed(inputs: &[(&str, &[u8], DType)]) -> Vec<(Vec<u8>, DType)>; default impls handle F32 only; the MLX backend overrides with the zero-widen path through Array::from_bytes / Array::to_bytes. CPU and Metal inherit the F32 default — they panic for non-F32 typed inputs (override is a future PR for those backends).
  • Constants must be F32. Non-F32 Op::Constant payloads error in lower.rs — the constant byte format is little-endian f32. Add F16/I32 constant decoding when a model needs it.
  • Async pipeline: commit_no_wait schedules the lowered graph via mlx::core::async_eval and stashes the output handles; sync_pending calls mlx::core::synchronize and drops them. run() always calls sync_pending() first, so an explicit run() after a commit is safe. No per-stream isolation yet — synchronize() drains every MLX stream.
  • KV-cache pattern: if an output slot's name is out{i} and a handle of the same name is bound, run() syncs the f32 result back into the handle so the next iteration picks it up as input.
  • run_slots arena: the slot path keeps a synthetic Vec<u8> arena owned by the executable. Outputs are copied into it after each run_slots call so callers can read results via arena_ptr().add(offset) without per-output Vec<f32> allocations. Cheaper than run() when output sizes are tiny but the per-call bookkeeping cost matters.
  • Attention SlidingWindow mask: synthesized host-side as an additive [seq_q, seq_k] mask (0 where allowed, -inf elsewhere), then passed through fast::scaled_dot_product_attention with mode="array". MLX has no native sliding-window mode.
  • Sample: temperature scaling + top_k filter + top_p (nucleus) filter + mlx::random::categorical. top_k uses mc::topk for the threshold; top_p sorts descending (via sort + negate), takes an exclusive cumsum of the sorted probs, masks entries whose cumsum < top_p, picks the smallest probability still in that nucleus as the threshold, and applies it back to the original logits via where(p >= threshold, logits, -∞).
  • Persistent compiled graph (MlxMode::Compiled): the executable builds a CompiledFn lazily on first run(). Internally a Rust callback walks the IR via lower::lower_with_env; the shim wraps it as std::function, hands it to mc::compile, and stores the returned function. Subsequent calls replay the optimized trace.
  • Calibration + cost model: calibrate::Calibration::load_or_measure() measures sgemm GF/s at one large + one small shape plus a tiny-graph round-trip overhead, plus memory bandwidth (large contiguous copy), attention throughput (1×4×128×64 SDPA), and reduce throughput (1024×1024 sum-along-last-axis). Caches at ~/.cache/rlx/mlx-calib-<sanitized-device-name>.json and feeds rlx_runtime::cost::MlxCostModel so pick_best_device can rank MLX honestly.
  • Pool composition: Op::Pool is lowered by composing slice_strided over the kernel grid plus a reduction. Supports 1D / 2D / 3D inputs (channels-first layout) and all five reduction kinds (max/min/sum/mean/prod). Constant-pad with -∞ for max-pool, +∞ for min-pool, 1.0 for prod, 0 elsewhere.
  • DotGeneral lowering: the canonical 2D pattern (no batch dims, contract lhs[1] × rhs[0]) reduces to a plain MatMul, matching what the optimizer's LowerDotGeneral pass would have produced. Non-canonical patterns (batched, alternative contracting axes) error with a clear diagnostic — same coverage as the optimizer pass.
  • FusedTransformerLayer composition: the full BERT-style post-norm block (attention → residual+LN → FFN → residual+LN) composed from primitives. Honors all four mask kinds via the underlying SDPA path.
  • Op::If / Op::While are now lowered. We adopt a positional binding convention between the sub-graph's Op::Input nodes (in topo order) and the parent's captures (inputs[1..] for If, inputs[..] for While); sub-graph Op::Param nodes look up by name in the parent's param maps; sub-graph Op::Constant nodes are inline. Op::If evaluates both branches and combines via mc::where. Op::While requires max_iterations and unrolls; an active-mask gate via where(active && cond, body_out, carried) freezes loop-carried values once the condition becomes false. Single- output While only — multi-output convention isn't defined in the IR. Compile mode (MlxMode::Compiled) doesn't yet recurse through sub-graph leaves; If/While inside a compiled trace will fail with a missing-param diagnostic. Use Lazy/Eager for control flow.
  • SelectiveScan composition: Op::SelectiveScan (Mamba SSM step) is lowered by unrolling the time loop into seq many op chains. At each t we slice δ/x/B/C, broadcast against A, update the running state via exp(δA) * state + δ*B*x, and accumulate sum_n(C * state) as the output. Per-call cost amortizes through mlx::compile's trace cache. Acceptable for static-shape graphs (which all our graphs are); for very long sequences a custom Metal kernel via fast::metal_kernel would beat this on raw throughput.
  • Native ElementwiseRegion lowering (PLAN L2): Op::ElementwiseRegion is lowered in lower.rs by composing ops::* per ChainStep (Activation/Cast/Binary/Compare) directly into MLX's lazy trace. Each step is resolved positionally — ChainOperand::Input(i) reads node.inputs[i] and ChainOperand::Step(i) reads the array produced by chain step i. Because the whole chain becomes a sub-DAG inside MLX's trace, mlx::compile and the lazy evaluator get to fuse it into a single kernel — no decomposer round-trip and no extra Op nodes for the executor to walk. The runtime backend now runs MarkElementwiseRegions (instead of UnfuseElementwiseRegions) ahead of MLX compilation so chains are collapsed before lowering.

License

GPL-3.0-only.