docs.rs failed to build rlx-mlx-0.2.0
Please check the build logs for more information.
See Builds for ideas on how to fix a failed build, or Metadata for how to configure docs.rs builds.
If you believe this is docs.rs' fault, open an issue.
rlx-mlx
Apple MLX backend for RLX — vendored MLX via a hand-rolled C++ shim, eager + lazy + compiled execution.
Modes
- Lazy (default): build the entire MLX graph in run(), then call mlx::core::eval once on all outputs. This lets MLX's optimizer schedule the whole DAG, equivalent in spirit to the mps_graph path in rlx-metal.
- Eager: eval after every op. Slower, but useful for debugging because failures surface at the offending op rather than at the final eval.
- Compiled: a persistent function built with mlx::compile for repeated shapes; the trace cache amortizes re-runs.
Mode is set per-compile via MlxExecutable::compile_with_mode, or
globally via RLX_MLX_MODE=eager|lazy|compiled (default lazy).
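Below is a minimal sketch of how an RLX_MLX_MODE value could map onto the three modes. The Mode enum, parse_mode and mode_from_env are hypothetical stand-ins, not the crate's MlxMode API. Falling back to Lazy for unrecognized values, rather than erroring, is also an assumption; the README only says lazy is the default.

```rust
/// Hypothetical stand-in for the crate's execution-mode enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Mode {
    Eager,
    Lazy,
    Compiled,
}

/// Parse a raw RLX_MLX_MODE-style value. Unset or unrecognized values
/// fall back to Lazy (assumption: the real crate may reject them instead).
fn parse_mode(raw: Option<&str>) -> Mode {
    match raw.map(|s| s.trim().to_ascii_lowercase()).as_deref() {
        Some("eager") => Mode::Eager,
        Some("compiled") => Mode::Compiled,
        _ => Mode::Lazy,
    }
}

/// Read the mode from the process environment.
fn mode_from_env() -> Mode {
    parse_mode(std::env::var("RLX_MLX_MODE").ok().as_deref())
}
```

Taking Option<&str> keeps the parsing testable without touching the environment; mode_from_env only forwards std::env::var.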
What's here
- rlx-mlx-sys: vendored MLX (vendor/mlx), a CMake build, and cpp/rlx_mlx_shim.{h,cpp}, a C ABI over mlx::core::*.
- src/: re-exports rlx_mlx_sys::ffi; RAII wrappers and lowering:
  - src/array.rs: RAII Array wrapper, MlxError, top-level eval.
  - src/ops.rs: typed wrappers for matmul / add / mul / sub / div / softmax / gelu / silu / cast / layer_norm.
  - src/lower.rs: walks rlx_ir::Graph in topo order, building MLX arrays for each node. Rebuilds the graph fresh each run() (see the comment in lower.rs for why).
  - src/backend.rs: MlxExecutable (set_param / run / handles).
- Tier-1 / Tier-2 / Tier-3 backward op parity with rlx-cpu for reverse-mode autodiff (relu, activation, softmax cross-entropy, layer norm, conv2d, max-pool, fake-quantize).
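The topo-order walk that src/lower.rs performs can be pictured with a std-only sketch. Everything in it is hypothetical: the Node enum, the walk_graph function, and the plain Vec<f32> buffers standing in for MLX arrays. Because nodes are stored in topological order, each node's inputs are already materialized when the walk reaches it. An op the walker does not know returns an "unsupported op" error instead of panicking.

```rust
/// Hypothetical, simplified IR node. Nodes are stored in topological order
/// and refer to earlier nodes by index.
#[derive(Debug, Clone)]
enum Node {
    Input(usize), // index into the caller-supplied inputs
    Add(usize, usize),
    Mul(usize, usize),
    Relu(usize),
    Unsupported(&'static str),
}

/// Visit every node once, materializing one buffer per node
/// (a stand-in for building one MLX array per node).
fn walk_graph(nodes: &[Node], inputs: &[Vec<f32>]) -> Result<Vec<Vec<f32>>, String> {
    let mut env: Vec<Vec<f32>> = Vec::with_capacity(nodes.len());
    for node in nodes {
        let value: Vec<f32> = match node {
            Node::Input(i) => inputs
                .get(*i)
                .cloned()
                .ok_or_else(|| format!("missing input {}", i))?,
            // Topo order guarantees a and b refer to already-visited nodes.
            Node::Add(a, b) => env[*a].iter().zip(&env[*b]).map(|(x, y)| x + y).collect(),
            Node::Mul(a, b) => env[*a].iter().zip(&env[*b]).map(|(x, y)| x * y).collect(),
            Node::Relu(a) => env[*a].iter().map(|&x| x.max(0.0)).collect(),
            Node::Unsupported(name) => return Err(format!("unsupported op {}", name)),
        };
        env.push(value);
    }
    Ok(env)
}
```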
Install
Native MLX lives in rlx-mlx-sys (submodule + build.rs). After clone:
[dependencies]
… = { version = "0.2", features = ["mlx"] }
# or directly:
… = "0.2"
… = "0.2"
The first build compiles MLX from source — minutes, not seconds.
Build / test
Through rlx-runtime:
Status
Mature on Apple Silicon (M1 / M2 / M3 / M4). On Intel Macs MLX falls back to its CPU path; supported but rarely the right choice.
Gotchas
- Op coverage. The first cut handled MatMul, Binary (Add/Mul/Sub/Div), Activation (Gelu/Silu), Cast, Softmax and LayerNorm. Coverage now includes matmul; all binary / activation / cast / reduce / softmax / layer-norm / RMS-norm ops; fused attention (SDPA via fast::scaled_dot_product_attention); pool composition; dot-general; selective-scan unroll; a calibrated cost model; and async commit + sync. Anything else returns MlxError("unsupported op …") from lower::lower_and_run. Adding an op means: an entry in cpp/shim.h, the matching impl in shim.cpp, an extern "C" decl in ffi.rs, a wrapper in ops.rs, and a match arm in lower.rs.
- Fresh-graph-per-run. Every
run() rebuilds the MLX graph from scratch. MLX's own trace cache amortizes this, but if you need lower per-run latency, the next step is mlx::compile-style placeholder bindings (map the input/param NodeIds to MLX placeholder handles and reuse the compiled graph across runs).
- F32 I/O default. Inputs/params come in as &[f32] and outputs come out as Vec<f32>. The shim casts to/from MLX's per-array dtype internally (so AutoMixedPrecision still does the right thing inside the graph). The runtime trait now exposes set_param_typed(name, &[u8], dtype) and run_typed(inputs: &[(&str, &[u8], DType)]) -> Vec<(Vec<u8>, DType)>. The default impls handle F32 only; the MLX backend overrides them with the zero-widen path through Array::from_bytes / Array::to_bytes. CPU and Metal inherit the F32 default, so they panic on non-F32 typed inputs (overrides for those backends are a future PR).
- Constants must be F32. Non-F32
Op::Constant payloads error in lower.rs, because the constant byte format is little-endian f32. Add F16/I32 constant decoding when a model needs it.
- Async pipeline: commit_no_wait schedules the lowered graph via mlx::core::async_eval and stashes the output handles; sync_pending calls mlx::core::synchronize and drops them. run() always calls sync_pending() first, so an explicit run() after a commit is safe. There is no per-stream isolation yet: synchronize() drains every MLX stream.
- KV-cache pattern: if an output slot's name is
out{i} and a handle of the same name is bound, run() syncs the f32 result back into the handle so the next iteration picks it up as input.
- run_slots arena: the slot path keeps a synthetic Vec<u8> arena owned by the executable. Outputs are copied into it after each run_slots call, so callers can read results via arena_ptr().add(offset) without per-output Vec<f32> allocations. This is cheaper than run() when output sizes are tiny but the per-call bookkeeping cost matters.
- Attention SlidingWindow mask: synthesized host-side as an additive [seq_q, seq_k] mask (0 where allowed, -inf elsewhere), then passed through fast::scaled_dot_product_attention with mode="array". MLX has no native sliding-window mode.
- Sample: temperature scaling +
top_k filter + top_p (nucleus) filter + mlx::random::categorical. top_k uses mc::topk for the threshold. top_p sorts the probabilities in descending order (via sort + negate), takes an exclusive cumsum of the sorted probs, keeps the entries whose cumsum is < top_p (the nucleus), picks the smallest probability still in that nucleus as the threshold, and applies it back to the original logits via where(p >= threshold, logits, -∞).
- Persistent compiled graph (MlxMode::Compiled): the executable builds a CompiledFn lazily on the first run(). Internally a Rust callback walks the IR via lower::lower_with_env; the shim wraps it as a std::function, hands it to mc::compile, and stores the returned function. Subsequent calls replay the optimized trace.
- Calibration + cost model:
calibrate::Calibration::load_or_measure() measures sgemm GF/s at one large and one small shape, a tiny-graph round-trip overhead, memory bandwidth (large contiguous copy), attention throughput (1×4×128×64 SDPA), and reduce throughput (1024×1024 sum along the last axis). Results are cached at ~/.cache/rlx/mlx-calib-<sanitized-device-name>.json and feed rlx_runtime::cost::MlxCostModel, so pick_best_device can rank MLX honestly.
- Pool composition: Op::Pool is lowered by composing slice_strided over the kernel grid plus a reduction. It supports 1D / 2D / 3D inputs (channels-first layout) and all five reduction kinds (max/min/sum/mean/prod). Constant padding uses -∞ for max-pool, +∞ for min-pool, 1.0 for prod, and 0 elsewhere.
- DotGeneral lowering: the canonical 2D pattern (no batch dims,
contracting lhs[1] × rhs[0]) reduces to a plain MatMul, matching what the optimizer's LowerDotGeneral pass would have produced. Non-canonical patterns (batched, alternative contracting axes) error with a clear diagnostic, the same coverage as the optimizer pass.
- FusedTransformerLayer composition: the full BERT-style post-norm block (attention → residual+LN → FFN → residual+LN) composed from primitives. It honors all four mask kinds via the underlying SDPA path.
- Op::If / Op::While: now lowered. We adopt a positional binding convention between the sub-graph's Op::Input nodes (in topo order) and the parent's captures (inputs[1..] for If, inputs[..] for While); sub-graph Op::Param nodes are looked up by name in the parent's param maps; sub-graph Op::Constant nodes are inlined. Op::If evaluates both branches and combines them via mc::where. Op::While requires max_iterations and unrolls; an active-mask gate via where(active && cond, body_out, carried) freezes loop-carried values once the condition becomes false. Only single-output While is supported, because the multi-output convention isn't defined in the IR. Compile mode (MlxMode::Compiled) doesn't yet recurse through sub-graph leaves, so If/While inside a compiled trace will fail with a missing-param diagnostic. Use Lazy / Eager for control flow.
- SelectiveScan composition:
Op::SelectiveScan (Mamba SSM step) is lowered by unrolling the time loop into seq op chains, one per time step. At each t we slice δ/x/B/C, broadcast against A, update the running state via exp(δA) * state + δ*B*x, and accumulate sum_n(C * state) as the output. Per-call cost amortizes through mlx::compile's trace cache. This is acceptable for static-shape graphs (which all our graphs are); for very long sequences a custom Metal kernel via fast::metal_kernel would beat it on raw throughput.
- Native ElementwiseRegion lowering (PLAN L2): Op::ElementwiseRegion is lowered in lower.rs by composing ops::* per ChainStep (Activation/Cast/Binary/Compare) directly into MLX's lazy trace. Each step is resolved positionally: ChainOperand::Input(i) reads node.inputs[i] and ChainOperand::Step(i) reads the array produced by chain step i. Because the whole chain becomes a sub-DAG inside MLX's trace, mlx::compile and the lazy evaluator can fuse it into a single kernel, with no decomposer round-trip and no extra Op nodes for the executor to walk. The runtime backend now runs MarkElementwiseRegions (instead of UnfuseElementwiseRegions) ahead of MLX compilation, so chains are collapsed before lowering.
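The "Constants must be F32" gotcha amounts to a decoder that accepts only little-endian f32 payloads. The sketch below shows that decoder under stated assumptions. ConstDType and decode_constant are hypothetical names, not the crate's DType or lower.rs code. Rejecting ragged byte lengths is an extra check the README does not mention.

```rust
/// Hypothetical dtype tag for a constant payload (not the crate's DType).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConstDType {
    F32,
    F16,
    I32,
}

/// Decode a constant payload, accepting only little-endian f32 bytes.
fn decode_constant(dtype: ConstDType, bytes: &[u8]) -> Result<Vec<f32>, String> {
    if dtype != ConstDType::F32 {
        return Err(format!("unsupported constant dtype {:?}", dtype));
    }
    if bytes.len() % 4 != 0 {
        return Err(format!(
            "f32 constant payload length {} is not a multiple of 4",
            bytes.len()
        ));
    }
    Ok(bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect())
}
```

Supporting F16/I32 later would mean adding match arms here; this sketch deliberately errors instead.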
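A host-side model of the KV-cache naming convention follows. Output slot i is written back only when a handle literally named out{i} is bound. The handle map (a HashMap<String, Vec<f32>>) and write_back_kv are hypothetical; the real executable stores MLX-backed handles rather than vectors.

```rust
use std::collections::HashMap;

/// For each output slot i, if a handle named "out{i}" is bound, overwrite it
/// with that output so the next iteration reads it as input.
/// Returns how many handles were updated.
fn write_back_kv(outputs: &[Vec<f32>], handles: &mut HashMap<String, Vec<f32>>) -> usize {
    let mut updated = 0;
    for (i, out) in outputs.iter().enumerate() {
        if let Some(handle) = handles.get_mut(&format!("out{}", i)) {
            handle.clear();
            handle.extend_from_slice(out);
            updated += 1;
        }
    }
    updated
}
```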
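The run_slots arena idea is one reusable byte buffer plus a byte offset per output, so reads need no per-output allocation. The sketch assumes a hypothetical SlotArena type. It uses safe slice indexing where the README's callers use arena_ptr().add(offset), and stores values in native byte order.

```rust
/// Hypothetical slot arena: one reusable byte buffer plus each output's byte offset.
struct SlotArena {
    bytes: Vec<u8>,
    offsets: Vec<usize>,
}

impl SlotArena {
    fn new() -> Self {
        SlotArena { bytes: Vec::new(), offsets: Vec::new() }
    }

    /// Copy every output into the shared buffer. `clear` keeps the allocation,
    /// so repeated calls whose outputs fit the existing capacity do not reallocate.
    fn store(&mut self, outputs: &[&[f32]]) {
        self.bytes.clear();
        self.offsets.clear();
        for out in outputs {
            self.offsets.push(self.bytes.len());
            for v in out.iter() {
                self.bytes.extend_from_slice(&v.to_ne_bytes());
            }
        }
    }

    /// Read element `idx` of output `slot` by byte offset, without allocating.
    /// (This sketch does not check `idx` against the slot's own length.)
    fn get(&self, slot: usize, idx: usize) -> f32 {
        let start = self.offsets[slot] + idx * std::mem::size_of::<f32>();
        let b = &self.bytes[start..start + 4];
        f32::from_ne_bytes([b[0], b[1], b[2], b[3]])
    }
}
```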
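The SlidingWindow gotcha builds an additive [seq_q, seq_k] mask on the host. The README does not define which positions a window allows, so this sketch assumes two things: a causal window of `window` keys ending at the query's position, and queries aligned to the end of the key sequence (as with a KV cache). sliding_window_mask is a hypothetical name. The result is row-major, with 0.0 where allowed and -inf elsewhere.

```rust
/// Build a row-major [seq_q, seq_k] additive mask: 0.0 where attention is
/// allowed, -inf elsewhere. Assumption: causal window of `window` keys,
/// queries aligned to the end of the key sequence.
fn sliding_window_mask(seq_q: usize, seq_k: usize, window: usize) -> Vec<f32> {
    let offset = seq_k.saturating_sub(seq_q);
    let mut mask = vec![f32::NEG_INFINITY; seq_q * seq_k];
    for i in 0..seq_q {
        let q_pos = i + offset; // absolute position of query i among the keys
        for j in 0..seq_k {
            // Allowed: not in the future, and within the last `window` keys.
            if j <= q_pos && q_pos - j < window {
                mask[i * seq_k + j] = 0.0;
            }
        }
    }
    mask
}
```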
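A CPU sketch of the Sample filters follows. It implements the top_p steps in the Sample bullet: descending sort, exclusive cumsum, nucleus = entries whose exclusive cumsum is below top_p, threshold = smallest probability in the nucleus, then where(p >= threshold, logits, -∞). It stops before the final categorical draw. The function names are hypothetical, and k == 0 meaning "disabled" is an assumed convention. Logits are assumed non-empty and NaN-free.

```rust
/// Divide logits by the temperature (assumes temperature > 0).
fn apply_temperature(logits: &[f32], temperature: f32) -> Vec<f32> {
    logits.iter().map(|&l| l / temperature).collect()
}

/// Numerically stable softmax.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Keep logits >= the k-th largest value; k == 0 disables the filter (assumption).
fn top_k_filter(logits: &[f32], k: usize) -> Vec<f32> {
    if k == 0 || k >= logits.len() {
        return logits.to_vec();
    }
    let mut sorted = logits.to_vec();
    sorted.sort_by(|a, b| b.partial_cmp(a).unwrap()); // descending; panics on NaN
    let threshold = sorted[k - 1];
    logits.iter().map(|&l| if l >= threshold { l } else { f32::NEG_INFINITY }).collect()
}

/// Nucleus (top_p) filter via an exclusive cumsum over descending probabilities.
fn top_p_filter(logits: &[f32], top_p: f32) -> Vec<f32> {
    let probs = softmax(logits);
    let mut sorted = probs.clone();
    sorted.sort_by(|a, b| b.partial_cmp(a).unwrap()); // descending
    let mut exclusive_cumsum = 0.0f32; // mass strictly before the current entry
    let mut threshold = sorted[0]; // assumes non-empty logits
    for &p in &sorted {
        if exclusive_cumsum >= top_p {
            break; // this entry and everything after it fall outside the nucleus
        }
        threshold = p; // smallest probability admitted so far
        exclusive_cumsum += p;
    }
    logits
        .iter()
        .zip(&probs)
        .map(|(&l, &p)| if p >= threshold { l } else { f32::NEG_INFINITY })
        .collect()
}
```

Because the first sorted entry always has an exclusive cumsum of 0, the nucleus is never empty for top_p > 0.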
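The Pool composition bullet has two parts: one strided slice per kernel offset, then an elementwise reduction across those slices. The padding constants are the ones the bullet lists. The sketch below shows both in 1D only; the bullet's 2D/3D channels-first layouts are omitted. Reduce, pad_value and pool1d are hypothetical names. Dividing the mean by the full kernel size (counting zero padding) is an assumption that follows from padding mean with 0.

```rust
#[derive(Clone, Copy)]
enum Reduce {
    Max,
    Min,
    Sum,
    Mean,
    Prod,
}

/// Padding constant that leaves each reduction unaffected.
fn pad_value(kind: Reduce) -> f32 {
    match kind {
        Reduce::Max => f32::NEG_INFINITY,
        Reduce::Min => f32::INFINITY,
        Reduce::Prod => 1.0,
        Reduce::Sum | Reduce::Mean => 0.0,
    }
}

/// 1D pool: pad, take one strided slice per kernel offset, reduce across slices.
/// Assumes stride >= 1 and kernel <= padded length.
fn pool1d(x: &[f32], kernel: usize, stride: usize, pad: usize, kind: Reduce) -> Vec<f32> {
    let fill = pad_value(kind);
    let mut padded = vec![fill; pad];
    padded.extend_from_slice(x);
    padded.extend(std::iter::repeat(fill).take(pad));
    let out_len = (padded.len() - kernel) / stride + 1;

    // slices[k][o] == padded[o * stride + k]: a strided slice starting at offset k.
    let slices: Vec<Vec<f32>> = (0..kernel)
        .map(|k| (0..out_len).map(|o| padded[o * stride + k]).collect())
        .collect();

    // Reduce elementwise across the kernel-offset slices.
    (0..out_len)
        .map(|o| {
            let vals = slices.iter().map(|s| s[o]);
            match kind {
                Reduce::Max => vals.fold(f32::NEG_INFINITY, f32::max),
                Reduce::Min => vals.fold(f32::INFINITY, f32::min),
                Reduce::Sum => vals.sum::<f32>(),
                Reduce::Mean => vals.sum::<f32>() / kernel as f32,
                Reduce::Prod => vals.product::<f32>(),
            }
        })
        .collect()
}
```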
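The DotGeneral rule reduces to a guard plus an ordinary matmul: rank-2 operands, no batch dims, and lhs axis 1 contracted with rhs axis 0 become a plain MatMul, and everything else is rejected with a diagnostic. DotDims, check_canonical and matmul are hypothetical names, and the row-major layout is an assumption.

```rust
/// Hypothetical dot-general dimension numbers.
struct DotDims {
    lhs_batch: Vec<usize>,
    rhs_batch: Vec<usize>,
    lhs_contract: Vec<usize>,
    rhs_contract: Vec<usize>,
}

/// Accept only the canonical 2D pattern; anything else gets a diagnostic.
fn check_canonical(lhs_rank: usize, rhs_rank: usize, d: &DotDims) -> Result<(), String> {
    if lhs_rank != 2 || rhs_rank != 2 {
        return Err(format!("dot-general: expected rank-2 operands, got {} and {}", lhs_rank, rhs_rank));
    }
    if !d.lhs_batch.is_empty() || !d.rhs_batch.is_empty() {
        return Err("dot-general: batch dimensions are not supported".to_string());
    }
    if d.lhs_contract != [1] || d.rhs_contract != [0] {
        return Err(format!(
            "dot-general: unsupported contracting axes {:?} x {:?}",
            d.lhs_contract, d.rhs_contract
        ));
    }
    Ok(())
}

/// Plain row-major matmul: (m x k) · (k x n) -> (m x n).
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0.0; m * n];
    for i in 0..m {
        for p in 0..k {
            let aip = a[i * k + p];
            for j in 0..n {
                out[i * n + j] += aip * b[p * n + j];
            }
        }
    }
    out
}
```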
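The Op::If / Op::While gotcha describes two patterns. An If evaluates both branches and selects per element. An unrolled While always computes the body but only commits it while the active gate holds: the where(active && cond, body_out, carried) rule. The sketch models both with plain Rust values, using hypothetical select and unrolled_while functions and a single scalar loop-carried value (mirroring the single-output restriction).

```rust
/// If-style select: both branches were already evaluated; pick per element.
fn select(cond: &[bool], then_vals: &[f32], else_vals: &[f32]) -> Vec<f32> {
    cond.iter()
        .zip(then_vals.iter().zip(else_vals))
        .map(|(&c, (&t, &e))| if c { t } else { e })
        .collect()
}

/// While-style unroll for one scalar loop-carried value.
fn unrolled_while(
    init: f32,
    max_iterations: usize,
    cond: impl Fn(f32) -> bool,
    body: impl Fn(f32) -> f32,
) -> f32 {
    let mut carried = init;
    let mut active = true;
    for _ in 0..max_iterations {
        // A traced graph computes both the condition and the body at every step.
        let c = cond(carried);
        let body_out = body(carried);
        active = active && c; // once false, stays false
        // where(active && cond, body_out, carried)
        carried = if active { body_out } else { carried };
    }
    carried
}
```

If max_iterations is too small, the loop is silently truncated, which is why the lowering requires it to be set explicitly.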
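The SelectiveScan unroll computes state_t = exp(δ_t·A) ⊙ state_{t-1} + δ_t·B_t·x_t and y_t = Σ_n C_t[n]·state_t[n]. Below is a sequential reference for that recurrence. It assumes a single channel, a zero initial state, and per-step B and C vectors of length N. Batch and model dimensions, plus any skip/D term, are left out. selective_scan is a hypothetical name; the real lowering builds these steps as MLX ops instead.

```rust
/// Sequential reference for the unrolled selective-scan recurrence
/// (single channel, zero initial state).
fn selective_scan(a: &[f32], delta: &[f32], x: &[f32], b: &[Vec<f32>], c: &[Vec<f32>]) -> Vec<f32> {
    let n = a.len();
    let mut state = vec![0.0f32; n];
    let mut y = Vec::with_capacity(x.len());
    for t in 0..x.len() {
        let mut acc = 0.0f32;
        for i in 0..n {
            // state = exp(δA) * state + δ * B * x
            state[i] = (delta[t] * a[i]).exp() * state[i] + delta[t] * b[t][i] * x[t];
            // accumulate sum_n(C * state)
            acc += c[t][i] * state[i];
        }
        y.push(acc);
    }
    y
}
```

With A = 0 the decay factor is exactly 1, so the state becomes a running sum of δ·B·x, which makes a handy sanity check.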
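The ElementwiseRegion bullet resolves each step's operands positionally: Input(i) reads the region's i-th input, and Step(i) reads the result of an earlier chain step. The sketch shows only that resolution order, on the CPU, so it does none of the kernel fusion MLX performs. Operand and ChainStep are hypothetical mirrors of ChainOperand and the crate's ChainStep. The step kinds (Relu, Add, Mul, Greater) are simplified examples of the Activation / Binary / Compare families; Cast is omitted.

```rust
/// Hypothetical mirror of the positional chain operands.
#[derive(Clone, Copy)]
enum Operand {
    Input(usize), // node.inputs[i]
    Step(usize),  // result of chain step i
}

enum ChainStep {
    Relu(Operand),
    Add(Operand, Operand),
    Mul(Operand, Operand),
    Greater(Operand, Operand), // compare: 1.0 where lhs > rhs, else 0.0
}

fn resolve<'a>(op: Operand, inputs: &'a [Vec<f32>], done: &'a [Vec<f32>]) -> &'a [f32] {
    match op {
        Operand::Input(i) => &inputs[i],
        Operand::Step(i) => &done[i],
    }
}

fn zip_with(a: &[f32], b: &[f32], f: impl Fn(f32, f32) -> f32) -> Vec<f32> {
    a.iter().zip(b).map(|(&x, &y)| f(x, y)).collect()
}

/// Evaluate the chain in order; the last step's result is the region output.
fn run_chain(steps: &[ChainStep], inputs: &[Vec<f32>]) -> Vec<f32> {
    let mut done: Vec<Vec<f32>> = Vec::with_capacity(steps.len());
    for step in steps {
        let out: Vec<f32> = match *step {
            ChainStep::Relu(a) => resolve(a, inputs, &done).iter().map(|&v| v.max(0.0)).collect(),
            ChainStep::Add(a, b) => zip_with(resolve(a, inputs, &done), resolve(b, inputs, &done), |x, y| x + y),
            ChainStep::Mul(a, b) => zip_with(resolve(a, inputs, &done), resolve(b, inputs, &done), |x, y| x * y),
            ChainStep::Greater(a, b) => zip_with(resolve(a, inputs, &done), resolve(b, inputs, &done), |x, y| {
                if x > y { 1.0 } else { 0.0 }
            }),
        };
        done.push(out);
    }
    done.pop().unwrap_or_default()
}
```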
License
GPL-3.0-only.