rlx-mlx
Apple MLX backend for RLX — vendored MLX via a hand-rolled C++ shim, eager + lazy + compiled execution.
Modes
- Lazy (default): build the entire MLX graph in run(), then call mlx::core::eval once on all outputs. Lets MLX's optimizer schedule the whole DAG, equivalent in spirit to the mps_graph path in rlx-metal.
- Eager: eval after every op. Slower; useful for debugging because failures surface at the offending op rather than at the final eval.
- Compiled: mlx::compile-built persistent function for repeated shapes; trace-cache amortizes re-runs.
Mode is set per-compile via MlxExecutable::compile_with_mode, or
globally via RLX_MLX_MODE=eager|lazy|compiled (default lazy).
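As a rough illustration of the environment-variable path, the sketch below parses an RLX_MLX_MODE value into a mode enum with lazy as the fallback. The `Mode` enum, `parse_mode`, and `mode_from_env` are hypothetical names for this example, not the crate's API. Trimming, case-insensitive matching, and returning an error for unknown values are assumptions; the crate may handle those cases differently.

```rust
use std::env;

/// Hypothetical stand-in for the crate's mode type (assumption, not the real enum).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    Lazy,
    Eager,
    Compiled,
}

/// Parse a raw RLX_MLX_MODE value. `None` (variable unset) falls back to Lazy,
/// matching the documented default.
pub fn parse_mode(raw: Option<&str>) -> Result<Mode, String> {
    // Assumption: surrounding whitespace and letter case are ignored.
    match raw.map(|s| s.trim().to_ascii_lowercase()).as_deref() {
        None | Some("lazy") => Ok(Mode::Lazy),
        Some("eager") => Ok(Mode::Eager),
        Some("compiled") => Ok(Mode::Compiled),
        Some(other) => Err(format!("unknown RLX_MLX_MODE value: {:?}", other)),
    }
}

/// Read the mode from the process environment.
pub fn mode_from_env() -> Result<Mode, String> {
    parse_mode(env::var("RLX_MLX_MODE").ok().as_deref())
}
```

Calling `mode_from_env()` once at start-up and passing the result to whatever per-compile entry point is in use keeps the global setting and the per-compile override in one place.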
What's here
- rlx-mlx-sys: vendored MLX (vendor/mlx), CMake build, and cpp/rlx_mlx_shim.{h,cpp} C ABI over mlx::core::*.
- src/: re-exports rlx_mlx_sys::ffi; RAII wrappers and lowering:
  - src/array.rs: RAII Array wrapper, MlxError, top-level eval.
  - src/ops.rs: typed wrappers for matmul / add / mul / sub / div / softmax / gelu / silu / cast / layer_norm.
  - src/lower.rs: walks rlx_ir::Graph in topo order, building MLX arrays for each node. Rebuilds the graph fresh each run() (see the comment in lower.rs for why).
  - src/backend.rs: MlxExecutable (set_param / run / handles).
- FFT: native mlx::fft::fft via rlx_mlx_shim for Op::Fft; graph helpers (rfft, irfft, …) lower through the same path.
- Tier-1 / Tier-2 / Tier-3 backward op parity with rlx-cpu for reverse-mode autodiff (relu, activation, softmax cross-entropy, layer norm, conv2d, max-pool, fake-quantize).
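The topo-order walk in src/lower.rs can be pictured with a toy model. The sketch below is not the crate's code: `Op`, `Node`, `operand`, and `zip_with` are made up for illustration, a `Vec<f32>` stands in for an MLX array handle, and treating the last node as the graph output is an assumption. It only shows the shape of the idea: visit nodes in order, build one value per node from earlier values, and return an error for anything without a lowering.

```rust
use std::collections::HashMap;

/// Hypothetical miniature IR (the real rlx_ir::Graph is not shown here).
#[derive(Debug, Clone)]
pub enum Op {
    Input(String),
    Add,
    Mul,
    Unsupported(&'static str),
}

#[derive(Debug, Clone)]
pub struct Node {
    pub op: Op,
    /// Indices of earlier nodes; the slice of nodes is assumed to be topologically sorted.
    pub inputs: Vec<usize>,
}

/// Fetch the k-th operand of `node` from the values built so far.
fn operand<'a>(values: &'a [Vec<f32>], node: &Node, id: usize, k: usize) -> Result<&'a [f32], String> {
    let src = *node
        .inputs
        .get(k)
        .ok_or_else(|| format!("node {}: missing input {}", id, k))?;
    values
        .get(src)
        .map(|v| v.as_slice())
        .ok_or_else(|| format!("node {}: input {} is not an earlier node", id, src))
}

/// Elementwise binary op over two equally sized buffers.
fn zip_with(a: &[f32], b: &[f32], f: impl Fn(f32, f32) -> f32) -> Result<Vec<f32>, String> {
    if a.len() != b.len() {
        return Err(format!("shape mismatch: {} vs {}", a.len(), b.len()));
    }
    Ok(a.iter().zip(b).map(|(x, y)| f(*x, *y)).collect())
}

/// Walk the nodes in order, rebuilding every value from scratch on each call.
/// Returns the value of the last node (assumption: that node is the output).
pub fn lower_and_run(nodes: &[Node], feeds: &HashMap<String, Vec<f32>>) -> Result<Vec<f32>, String> {
    let mut values: Vec<Vec<f32>> = Vec::with_capacity(nodes.len());
    for (id, node) in nodes.iter().enumerate() {
        let out = match &node.op {
            Op::Input(name) => feeds
                .get(name)
                .cloned()
                .ok_or_else(|| format!("missing input {:?}", name))?,
            Op::Add => zip_with(operand(&values, node, id, 0)?, operand(&values, node, id, 1)?, |a, b| a + b)?,
            Op::Mul => zip_with(operand(&values, node, id, 0)?, operand(&values, node, id, 1)?, |a, b| a * b)?,
            Op::Unsupported(name) => return Err(format!("unsupported op {}", name)),
        };
        values.push(out);
    }
    values.pop().ok_or_else(|| "empty graph".to_string())
}
```

Because nothing is cached between calls, this toy mirrors the fresh-graph-per-run behaviour: each call to `lower_and_run` pays the full walk again.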
Install
Native MLX lives in rlx-mlx-sys (submodule + build.rs). After clone:
[]
= { = "0.2", = ["mlx"] }
# or directly:
= "0.2"
= "0.2"
The first build compiles MLX from source — minutes on Linux CPU, ~1 hour if you
opt into the CUDA backend (--features cuda or RLX_MLX_CUDA=1). See
rlx-mlx-sys for compile-time tips (ccache, RLX_MLX_JOBS).
Linux device selection
Runtime backend inside MLX (CPU OpenBLAS vs CUDA GPU):
RLX_MLX_DEVICE=cpu
RLX_MLX_DEVICE=gpu
Via WSL rig:
Full Linux/WSL guide and rlx-mlx CPU vs rlx-cpu matmul benchmarks:
docs/benchmarks/mlx-linux.md.
Build / test
Through rlx-runtime:
Status
Mature on Apple Silicon (M1 / M2 / M3 / M4). Linux/WSL: CPU MLX
compiles and passes parity tests; CUDA is opt-in. See
docs/benchmarks/mlx-linux.md.
On Intel Macs MLX falls back to its CPU path; supported but rarely the
right choice.
Gotchas
- Op coverage. First cut handled MatMul, Binary (Add/Mul/Sub/Div), Activation (Gelu/Silu), Cast, Softmax, LayerNorm. Now covers matmul, all binary / activation / cast / reduce / softmax / layer-norm / RMS-norm, fused attention (SDPA via fast::scaled_dot_product_attention), pool composition, dot-general, selective-scan unroll, calibrated cost model, async commit + sync. Anything else returns MlxError("unsupported op …") from lower::lower_and_run. Adding an op means: an entry in cpp/shim.h, the matching impl in shim.cpp, an extern "C" decl in ffi.rs, a wrapper in ops.rs, and a match arm in lower.rs.
- Fresh-graph-per-run. Every run() rebuilds the MLX graph from scratch. MLX's own trace cache amortizes this, but if you need lower per-run latency, the next step is mlx::compile-style placeholder bindings (track the input/param NodeIds → MLX placeholder handles, reuse the compiled graph across runs).
- F32 I/O default. Inputs/params come in as &[f32] and outputs come out as Vec<f32>. The shim casts to/from MLX's per-array dtype internally (so AutoMixedPrecision still does the right thing inside the graph). The runtime trait now exposes set_param_typed(name, &[u8], dtype) and run_typed(inputs: &[(&str, &[u8], DType)]) -> Vec<(Vec<u8>, DType)>; default impls handle F32 only; the MLX backend overrides with the zero-widen path through Array::from_bytes / Array::to_bytes. CPU and Metal inherit the F32 default, so they panic for non-F32 typed inputs (override is a future PR for those backends).
- Constants must be F32. Non-F32 Op::Constant payloads error in lower.rs because the constant byte format is little-endian f32. Add F16/I32 constant decoding when a model needs it.
- Async pipeline: commit_no_wait schedules the lowered graph via mlx::core::async_eval and stashes the output handles; sync_pending calls mlx::core::synchronize and drops them. run() always calls sync_pending() first, so an explicit run() after a commit is safe. No per-stream isolation yet: synchronize() drains every MLX stream.
- KV-cache pattern: if an output slot's name is out{i} and a handle of the same name is bound, run() syncs the f32 result back into the handle so the next iteration picks it up as input.
- run_slots arena: the slot path keeps a synthetic Vec<u8> arena owned by the executable. Outputs are copied into it after each run_slots call so callers can read results via arena_ptr().add(offset) without per-output Vec<f32> allocations. Cheaper than run() when output sizes are tiny but the per-call bookkeeping cost matters.
- Attention SlidingWindow mask: synthesized host-side as an additive [seq_q, seq_k] mask (0 where allowed, -inf elsewhere), then passed through fast::scaled_dot_product_attention with mode="array". MLX has no native sliding-window mode.
- Sample: temperature scaling + top_k filter + top_p (nucleus) filter + mlx::random::categorical. top_k uses mc::topk for the threshold; top_p sorts descending (via sort + negate), takes an exclusive cumsum of the sorted probs, selects the entries whose cumsum < top_p as the nucleus, picks the smallest probability still in that nucleus as the threshold, and applies it back to the original logits via where(p >= threshold, logits, -∞).
- Persistent compiled graph (MlxMode::Compiled): the executable builds a CompiledFn lazily on first run(). Internally a Rust callback walks the IR via lower::lower_with_env; the shim wraps it as std::function, hands it to mc::compile, and stores the returned function. Subsequent calls replay the optimized trace.
- Calibration + cost model: calibrate::Calibration::load_or_measure() measures sgemm GF/s at one large + one small shape plus a tiny-graph round-trip overhead, plus memory bandwidth (large contiguous copy), attention throughput (1×4×128×64 SDPA), and reduce throughput (1024×1024 sum-along-last-axis). Caches at ~/.cache/rlx/mlx-calib-<sanitized-device-name>.json and feeds rlx_runtime::cost::MlxCostModel so pick_best_device can rank MLX honestly.
- Pool composition: Op::Pool is lowered by composing slice_strided over the kernel grid plus a reduction. Supports 1D / 2D / 3D inputs (channels-first layout) and all five reduction kinds (max/min/sum/mean/prod). Constant-pad with -∞ for max-pool, +∞ for min-pool, 1.0 for prod, 0 elsewhere.
- DotGeneral lowering: the canonical 2D pattern (no batch dims, contract lhs[1]×rhs[0]) reduces to a plain MatMul, matching what the optimizer's LowerDotGeneral pass would have produced. Non-canonical patterns (batched, alternative contracting axes) error with a clear diagnostic, which is the same coverage as the optimizer pass.
- FusedTransformerLayer composition: the full BERT-style post-norm block (attention → residual+LN → FFN → residual+LN) is composed from primitives. Honors all four mask kinds via the underlying SDPA path.
- Op::If / Op::While are now lowered. We adopt a positional binding convention between the sub-graph's Op::Input nodes (in topo order) and the parent's captures (inputs[1..] for If, inputs[..] for While); sub-graph Op::Param nodes look up by name in the parent's param maps; sub-graph Op::Constant nodes are inline. Op::If evaluates both branches and combines via mc::where. Op::While requires max_iterations and unrolls; an active-mask gate via where(active && cond, body_out, carried) freezes loop-carried values once the condition becomes false. Single-output While only, because a multi-output convention isn't defined in the IR. Compile mode (MlxMode::Compiled) doesn't yet recurse through sub-graph leaves; If/While inside a compiled trace will fail with a missing-param diagnostic. Use Lazy/Eager for control flow.
- SelectiveScan composition: Op::SelectiveScan (Mamba SSM step) is lowered by unrolling the time loop into seq many op chains. At each t we slice δ/x/B/C, broadcast against A, update the running state via exp(δA) * state + δ*B*x, and accumulate sum_n(C * state) as the output. Per-call cost amortizes through mlx::compile's trace cache. Acceptable for static-shape graphs (which all our graphs are); for very long sequences a custom Metal kernel via fast::metal_kernel would beat this on raw throughput.
- Native ElementwiseRegion lowering (PLAN L2): Op::ElementwiseRegion is lowered in lower.rs by composing ops::* per ChainStep (Activation/Cast/Binary/Compare) directly into MLX's lazy trace. Each step is resolved positionally: ChainOperand::Input(i) reads node.inputs[i] and ChainOperand::Step(i) reads the array produced by chain step i. Because the whole chain becomes a sub-DAG inside MLX's trace, mlx::compile and the lazy evaluator get to fuse it into a single kernel, with no decomposer round-trip and no extra Op nodes for the executor to walk. The runtime backend now runs MarkElementwiseRegions (instead of UnfuseElementwiseRegions) ahead of MLX compilation so chains are collapsed before lowering.
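For the SlidingWindow mask item in the list above, a host-side additive mask could be built roughly as follows. This is a sketch under stated assumptions, not the crate's implementation: the function name is hypothetical, and the window is assumed to be causal (each query sees itself and the `window - 1` keys before it). When seq_q is shorter than seq_k, queries are assumed to line up with the end of the key sequence. The README itself only specifies the [seq_q, seq_k] shape and the 0 / -inf encoding.

```rust
/// Build a row-major [seq_q, seq_k] additive mask:
/// 0.0 where attention is allowed, -inf everywhere else.
///
/// ASSUMPTIONS (not confirmed by the source): causal window, and queries
/// aligned to the last `seq_q` key positions.
pub fn sliding_window_mask(seq_q: usize, seq_k: usize, window: usize) -> Vec<f32> {
    let offset = seq_k.saturating_sub(seq_q);
    let mut mask = vec![f32::NEG_INFINITY; seq_q * seq_k];
    for i in 0..seq_q {
        let pos = i + offset; // absolute position of query i
        let lo = (pos + 1).saturating_sub(window); // first visible key
        let hi = (pos + 1).min(seq_k); // one past the last visible key
        for j in lo..hi {
            mask[i * seq_k + j] = 0.0;
        }
    }
    mask
}
```

Adding this mask to the attention scores before the softmax leaves allowed positions untouched and drives the rest to zero weight, which is why an additive array mask is enough to emulate a windowed mode.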
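The top_p step from the Sample item above can be traced on the host with plain slices. The sketch below follows the described sequence (softmax, descending sort, exclusive cumulative sum, threshold, mask back onto the logits). It is an illustration only: `top_p_filter` is a hypothetical name, temperature and top_k are left out, and the real lowering runs these steps as MLX ops rather than Rust loops.

```rust
/// Apply a nucleus (top_p) filter to one row of logits.
/// Entries outside the nucleus are replaced by -inf; the rest keep their logit.
pub fn top_p_filter(logits: &[f32], top_p: f32) -> Vec<f32> {
    // Numerically stable softmax.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();

    // Sort probabilities in descending order.
    let mut sorted = probs.clone();
    sorted.sort_by(|a, b| b.total_cmp(a));

    // Exclusive cumulative sum: an entry is in the nucleus while the mass
    // accumulated *before* it is still below top_p. The threshold ends up
    // as the smallest probability inside the nucleus.
    let mut cum = 0.0f32;
    let mut threshold = f32::INFINITY; // top_p <= 0 keeps nothing
    for &p in &sorted {
        if cum < top_p {
            threshold = p;
        } else {
            break;
        }
        cum += p;
    }

    // where(p >= threshold, logits, -inf)
    probs
        .iter()
        .zip(logits)
        .map(|(&p, &l)| if p >= threshold { l } else { f32::NEG_INFINITY })
        .collect()
}
```

Using the exclusive sum means the most likely token is always kept for any positive top_p. Tokens that tie with the threshold probability are kept together.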
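The unrolled Op::While gating described above, where(active && cond, body_out, carried), can be mimicked with scalars to show how a fixed unroll depth still yields loop semantics. This is a minimal sketch with a hypothetical `unrolled_while` helper: a single f32 stands in for the loop-carried array, and closures stand in for the cond and body sub-graphs.

```rust
/// Simulate a While loop unrolled to exactly `max_iterations` steps.
/// Every step evaluates cond and body, as a fully unrolled trace would,
/// and the gate decides whether the body's result is kept.
pub fn unrolled_while(
    init: f32,
    max_iterations: usize,
    cond: impl Fn(f32) -> bool,
    body: impl Fn(f32) -> f32,
) -> f32 {
    let mut carried = init;
    let mut active = true;
    for _ in 0..max_iterations {
        let c = cond(carried);
        let body_out = body(carried); // computed unconditionally
        let take = active && c;
        // where(active && cond, body_out, carried)
        carried = if take { body_out } else { carried };
        // Once the condition has been false, the value stays frozen.
        active = take;
    }
    carried
}
```

If max_iterations is smaller than the number of steps the loop would naturally take, the result is simply whatever has been computed by the last unrolled step, so the bound has to be chosen generously.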
License
GPL-3.0-only.