# mlx-native
Pure-Rust Metal GPU compute library for transformer inference on Apple Silicon. Built as the GPU backend for the hf2q inference engine.
## When to use this
mlx-native is the right tool when all of these hold:
- You're running transformer (or Mamba / Gated DeltaNet) inference on Apple Silicon
- Your weights are GGUF, MLX-quant, or safetensors (no PyTorch checkpoints, no ONNX)
- You want low Metal decode latency and are willing to drive a kernel-dispatch API
- You're fine assembling the forward pass yourself — there is no `Tensor` type, no `Module` system, no model zoo
Reach for candle instead if you need autograd / training, multi-backend support (CUDA / CPU / WASM), Python bindings, ONNX import, a built-in model zoo, or a high-level tensor algebra surface. The two are complementary: candle is a "PyTorch-shaped Rust ML framework"; mlx-native is the Metal compute backend of a llama.cpp-shaped inference engine.
## What we do that candle's Metal backend doesn't

- One `ComputeCommandEncoder` per forward pass (`GraphExecutor` / `GraphSession`) — candle acquires an encoder per op and pools ~50 per command buffer
- TurboQuant KV cache — Lloyd-Max codebooks (2 / 3 / 4-bit nibble-packed) and byte-packed higher-bit (5 / 6 / 8-bit) variants, with fused Hadamard incoherence transform
- MoE routing on GPU — `moe_gate` + `moe_softmax_topk` + expert-routed quantized matmul (no CPU round-trip for top-k expert selection)
- Custom Metal kernels for state-space models — `gated_delta_net`, `ssm_conv`, `ssm_norm_gate`, `tri_solve`, `cumsum`
- Shape-specialized prefill — D=256 / D=512 tiled flash-attention kernels tuned for production model shapes (Qwen3, Gemma 3 / 4)
- Fused norm-family kernels — `fused_norm_add`, `fused_residual_norm`, `fused_post_attn_triple_norm`, `fused_moe_wsum_norm_add`, `fused_head_norm_rope`
- GPU-resident sampling — `softmax_sample` eliminates the logits-to-CPU readback on the hot path
- Sliding-window KV cache copy with ring wrap — single GPU kernel instead of CPU-side index math
- Explicit barrier control — `session.barrier()` and `session.barrier_between(reads, writes)` for precise GPU sync between dependent ops
## Trade-offs to know going in
- Apple Silicon only. No CPU, no CUDA, no WASM. If you need to ship cross-platform, this is the wrong layer.
- No autograd. A growing set of backward + optimizer kernels exists — SiLU / RMSNorm / softmax / log / row-sum / embedding-scatter / exp / divide / sqrt / outer-product / conv1d-depthwise-causal / MoE-weighted-sum / MoE-SwiGLU backward, differentiable affine qdq, Adam step, and `flash_attn_train` (forward + backward through attention with dQ/dK/dV) — but you wire the training loop yourself; there is no `Var` / `VarMap` / autodiff / `Module` system.
- GGML matmul coverage is the inference subset, not the full set. Q4_0, Q8_0, Q6_K have full mat-vec / mat-mat / tensor-mm and expert-routed variants. Q4_K and Q5_K have dense mat-vec / mat-mat plus expert-routed (`mm_id`) variants. Q5_1 and IQ4_NL have dense and expert-routed variants. Q4_1, Q5_0, Q8_1, Q2_K, Q3_K, Q8_K are not supported in the Metal matmul path. MLX-format affine quantization supports 4 / 6 / 8-bit (no 3-bit).
- No high-level model code. This is a kernel library; the consumer (e.g. hf2q) builds the actual transformer forward pass.
## Status
Active development, pre-1.0. API may change between minor versions (0.x.0 → 0.(x+1).0 signals breaking changes). Public functions and structs evolve as new model families are added. Patch versions (0.x.y → 0.x.(y+1)) are non-breaking.
Supported model families used in production:
- Qwen3 / Qwen3.5 / Qwen3.6 (dense + MoE, GGUF)
- Gemma 3 / Gemma 4 (dense, with SWA + softcap, GQA)
- BERT-style embeddings (bge-small-en-v1.5)
- Generic transformer kernels for custom architectures
## What is this?
A thin, safe wrapper around Apple's Metal framework focused on compute shader dispatch for neural network inference. It handles buffer management, MSL shader compilation, and GPU command encoding so callers can focus on graph construction and execution.
Apple Silicon only — leverages unified memory (`StorageModeShared`) for zero-copy CPU↔GPU buffer access.
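As a sketch of what that buys in practice: CPU-side reads and writes go straight to GPU-visible memory, with no explicit upload or download step. The `as_slice` / `as_mut_slice` accessor names below are hypothetical.

```rust
// Hypothetical accessors shown for illustration; the zero-copy property
// itself comes from StorageModeShared on unified memory.
let buf = device.alloc_buffer(/* 1024 x DType::F32 */)?;

// CPU writes land directly in memory the GPU will read...
buf.as_mut_slice::<f32>()[..4].copy_from_slice(&[1.0, 2.0, 3.0, 4.0]);

// ...and after the command buffer completes, results are readable in place.
enc.commit_and_wait()?;
let first = buf.as_slice::<f32>()[0];
```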
## Design principles

- No panics — all public APIs return `Result<T, MlxError>`
- Zero-copy — `StorageModeShared` buffers on Apple Silicon unified memory
- Thread-safe — `MlxDevice` and `MlxBuffer` are `Send + Sync`
- Lazy compilation — MSL shaders compiled on first use, then cached
- Buffer pooling — power-of-two arena allocator for reuse
- Single-encoder graphs — `GraphExecutor` batches dispatches for ~120× lower per-token overhead than per-op encoders (matches the llama.cpp pattern)
## Quick start
A Q4_0 GGUF mat-vec dispatch (import path and call signatures are illustrative; argument lists are elided):
```rust
use mlx_native::{MlxDevice, KernelRegistry, quantized_matmul_ggml}; // import path illustrative

let device = MlxDevice::new()?;
let mut registry = KernelRegistry::new();

let input = device.alloc_buffer(/* shape, DType::F32 */)?; // f32 input
let weight = /* mmap GGUF Q4_0 blocks into an MlxBuffer */;
let mut output = device.alloc_buffer(/* shape, DType::F32 */)?;

let mut enc = device.command_encoder()?;
quantized_matmul_ggml(&mut enc, &mut registry, /* weight, input, output, dims */)?;
enc.commit_and_wait()?;
```
For multi-op forward passes, use `GraphExecutor` to batch all dispatches into a single command buffer with one GPU sync:
```rust
let executor = GraphExecutor::new(device); // takes ownership (constructor args illustrative)
let mut session = executor.begin()?;

session.rms_norm(/* ... */)?;
session.barrier(); // explicit barrier between dependent ops
session.quantized_matmul_ggml(/* ... */)?;
session.barrier();
session.flash_attn_vec(/* ... */)?;

session.finish()?; // one commit_and_wait for the whole pass
```
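When a full `barrier()` is coarser than needed, `barrier_between(reads, writes)` scopes the sync to specific buffers so unrelated dispatches can stay overlapped. A sketch, with buffer names and argument lists illustrative:

```rust
// Only the dispatches touching `hidden` are ordered; other in-flight
// work in the same encoder is left free to overlap. (Illustrative.)
session.quantized_matmul_ggml(/* ... writes `hidden` ... */)?;
session.barrier_between(&[&hidden], &[&hidden]); // reads, writes
session.rms_norm(/* ... reads `hidden` ... */)?;
```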
## Key types

| Type | Purpose |
|---|---|
| `MlxDevice` | Metal device + command queue (entry point) |
| `MlxBuffer` | Typed Metal buffer with shape/dtype metadata + `byte_offset` slicing |
| `MlxBufferPool` | Arena allocator with power-of-two bucketing |
| `CommandEncoder` | Compute command submission (single dispatch path) |
| `KernelRegistry` | Lazy MSL compilation + pipeline cache |
| `GraphExecutor` / `GraphSession` | Single-encoder batched forward passes |
| `ComputeGraph` | Recorded graph IR (capture, fuse, replay) |
| `DType` | Element data type enum (F32, F16, BF16, U8/16/32, I32) |
| `MlxError` | Unified error type |
| `GgufFile` / `TensorInfo` | GGUF model file mmap + metadata |
| `SafetensorsFile` | Safetensors mmap + tensor loading |
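One consequence of `MlxBuffer` carrying shape/dtype metadata plus `byte_offset` slicing is that a single Metal allocation can back several logical tensors. A sketch with a hypothetical `slice(byte_offset, shape)` helper, viewing a fused QKV projection:

```rust
// Hypothetical `slice` helper over MlxBuffer's byte_offset slicing:
// one allocation, three zero-copy views.
let n_embd = 4096; // example hidden size
let f32_size = std::mem::size_of::<f32>();
let qkv = device.alloc_buffer(/* 3 * n_embd f32 elements */)?;
let q = qkv.slice(0, &[n_embd])?;
let k = qkv.slice(n_embd * f32_size, &[n_embd])?;
let v = qkv.slice(2 * n_embd * f32_size, &[n_embd])?;
```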
## GPU operations

### Attention

- `flash_attn_vec` — SIMD-vectorized decode-path SDPA (NWG-parallel, llama.cpp port)
- `flash_attn_vec_tq` / `flash_attn_vec_tq_hb` — TurboQuant-quantized KV variants (Lloyd-Max + Hadamard)
- `flash_attn_vec_hybrid` — F16-K + TQ-HB-V SDPA (memory savings without full KV quant cost)
- `flash_attn_vec_peer_port_f16` (+ `_nwg32` NWG=32 variant with reduce dispatcher) — verbatim peer kernel port for F16 decode
- `flash_attn_prefill` (D=256, D=512) — Tiled prefill with bf16 kernels, SWA mask, sentinel handling — plus F16/BF16 `_resume` dispatchers for restart from arbitrary qL offset
- `flash_attn_train` — forward + backward (dQ/dK/dV via FA-2 Algorithm 4) bf16 kernels at D=64 / D=256, the missing piece for transformer training on this backend
- `sdpa` / `sdpa_sliding` — Reference SDPA with optional sliding window; `do_causal` flag toggles causal vs bidirectional (DFlash drafter block-diffusion)
- `sdpa_decode` — Tiled decode-path SDPA with N_SG=4 simdgroups
### Matrix multiplication

- GGUF formats: Q4_0, Q4_K, Q5_K, Q5_1, Q6_K, Q8_0, IQ4_NL, I16 — mat-vec + `mul_mm` tensor-core kernels (peer-parity with llama.cpp inference subset)
- GGUF expert-routed (`mm_id`): Q4_0, Q4_K, Q5_K, Q5_1, Q6_K, Q8_0, IQ4_NL (top_k>1 MoE mat-vec + tensor-mm)
- MLX format: 4/6/8-bit affine quantization (`quantized_matmul`)
- MLX fused dequant+matmul: `qmm_affine_t_f32` + `qmm_affine_t_f32_tiled` (2.29× over non-tiled), simdgroup-MMA `qmm_affine_t_f32_simd` / `qmm_affine_simd4` variants, and packed-U32 `qmm_affine_t_packed_simd4_b4`
- MoE expert-routed: `quantized_matmul_id` / `_id_ggml` / `_id_into` (top_k=1 tensor-mm fast path; `_into` accepts caller-provided output buffer)
- Dense BF16: `dense_mm_bf16_tensor`, `dense_gemv_bf16_f32` (M=1 decode)
- Dense F16: `dense_gemm_f16`, `dense_matvec_f16`
### Normalization

- `rms_norm` — RMS normalization (f32 + triple-output variants)
- `l2_norm` — L2 normalization
- `fused_residual_norm` — RMS norm + residual add
- `fused_norm_add` — MoE weighted_sum + RMS norm + add
- `fused_head_norm_rope` — Per-head RMS norm + RoPE (with bf16 co-write variants)
### Activation & gating

- `gelu` — GeLU activation (F32, BF16)
- `silu_mul` — SwiGLU (SiLU + elementwise multiply)
- `sigmoid_mul` — Sigmoid-gated multiply
- `softmax`, `softcap`, `scale_mask_softmax` — Softmax variants
- `softmax_sample` — Sampling from logits
### Position encoding

- `rope` — Standard RoPE
- `rope_multi` — Multi-axis RoPE with IMROPE (Qwen3.5) and Vision (Qwen3-VL ViT 2D positions) modes
### MoE

- `moe_gate` — Gate logits → weights
- `moe_softmax_topk` — GPU softmax + top-k expert selection
- `moe_dispatch` — Per-expert matvec sequence with proper barriers
- `moe_weighted_reduce` — Weighted sum across selected experts
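Chained inside one `GraphSession`, these kernels keep top-k expert selection entirely on the GPU. A sketch of the sequencing, with argument lists elided and barrier placement illustrative:

```rust
// Illustrative sequencing only: gate -> top-k -> routed matmul -> reduce,
// with no logits readback to the CPU in between.
session.moe_gate(/* hidden -> gate logits */)?;
session.barrier();
session.moe_softmax_topk(/* logits -> expert ids + weights */)?;
session.barrier();
session.quantized_matmul_id(/* expert-routed matmul on selected experts */)?;
session.barrier();
session.moe_weighted_reduce(/* weighted sum across selected experts */)?;
```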
### State-space (Mamba/Gated DeltaNet)

- `ssm_conv` — Depthwise causal 1D convolution + SiLU
- `ssm_norm_gate` — Norm + gate fusion (eliminates CPU bridge)
- `gated_delta_net` — Fused GDN kernel
- `compute_g_beta` — GDN g/beta computation
- `tri_solve` — Lower-triangular unit-diagonal forward substitution
- `cumsum` — Cumulative sum
### Memory & layout

- `kv_cache_copy` — Linear + sliding-window KV cache copy (with ring-wrap)
- `kv_cache_copy_seq_bf16` / `_seq_bf16_to_bf16_head_major` — BF16 sequence-batched cache copies (incl. head-major layout for prefill)
- `embedding` — Embedding lookup
- `gather` — Indexed gather (F16, nibble-packed)
- `transpose`, `permute_021` — Layout conversions
- `copy`, `offset_copy` — Strided copy
- `argmax`, `argsort`, `top_k` — Reductions
### Dispatch pre-bake (ADR-029)

Pre-baked `DispatchRecord` objects skip per-dispatch pipeline lookups, env-var reads, and parameter packing — meaningful on short-prompt decode hot paths.

- `build_q6k_nr2_m1_record` — dense Q6_K mv NR2 m=1
- `build_q6k_id_nr2_m1_record` — MoE Q6_K_ID NR2 m=1
- `build_q8_0_id_decode_record` — MoE Q8_0_ID regular decode
- `build_rms_norm_decode_record` — per-(dtype, rows, dim) RMSNorm decode
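The intended pattern is build-once, replay-per-token; the `dispatch_record` replay call below is an assumed name:

```rust
// Build at load time: pipeline lookup, env-var reads, and parameter
// packing are paid once instead of on every decode step.
let record = build_rms_norm_decode_record(&device, &mut registry, /* dtype, rows, dim */)?;

// Replay on the hot path. `dispatch_record` is a hypothetical entry point.
let n_new_tokens = 32; // generation length (example)
for _ in 0..n_new_tokens {
    session.dispatch_record(&record)?;
}
```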
### Vision / ViT (Qwen3-VL prelude)

- `im2col_2d_3ch_f32` + `add_bias_row_2d_f32` — patch-embed helpers
- `bilinear_resize_2d_f32` — antialiased 2-D resize
- `block_merge_2x2_f32` — 2×2 spatial merge / permutation
- `feature_concat_f32` — strided channel-axis concat
### Hadamard / TurboQuant

- `hadamard` — Standalone FWHT (D=128/256/512)
- `hadamard_quantize_kv` — Fused Hadamard + KV quantization
- `tq_dequantize_kv` — TurboQuant KV dequantization
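End to end, the flow is: quantize K/V once at cache-write time, then attend directly over the packed cache at decode. A sketch with elided arguments:

```rust
// Cache write: fused Hadamard incoherence transform + Lloyd-Max
// quantization packs K/V into the TurboQuant cache layout.
session.hadamard_quantize_kv(/* k, v -> tq_kv_cache */)?;
session.barrier();

// Decode: the TQ-aware SDPA reads the quantized cache directly, so a
// full-precision copy of the KV cache never materializes.
session.flash_attn_vec_tq(/* q, tq_kv_cache -> attn_out */)?;
```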
### Quantize / dequantize (qdq)

- `qdq_q4_0_f32`, `qdq_q8_0_f32` — GPU-side dequant for legacy GGUF blocks
- `qdq_affine_init_f32` / `qdq_affine_forward_f32` — MLX-format affine qdq with differentiable variants
- `qdq_affine_backward_scales_f32`, `qdq_affine_backward_biases_f32` — backward through quantization parameters
### Backward & training kernels

- `flash_attn_train_fwd_bf16_{d64,d256}` + `flash_attn_train_bwd_bf16_{d64,d256}` — attention forward (with logsumexp output) and backward (dQ / dK / dV via FA-2 Algorithm 4)
- `silu_backward_f32`, `softmax_backward`, `log_backward_f32`, `row_sum_backward_f32`, `exp_backward_f32`, `divide_backward_f32`, `sqrt_backward_f32`
- `rms_norm_compute_rms_inv` + `rms_norm_backward_dx` + `rms_norm_backward_dw`
- `outer_product` forward + backward
- `conv1d_depthwise_causal` forward + backward
- `take_along_axis` (gather + scatter-backward)
- `moe_weighted_sum_seq` backward; `moe_swiglu_seq` fused backward
- `embedding_lookup_f32` + `embedding_scatter_add_f32` (forward + scatter-add backward)
- `adam_update_f32` — fused Adam optimizer step (m / v moments + bias-correction)
- `slice_2d_cols_f32` + `copy_2d_cols_into_f32` — strided 2-D slice / scatter for column-major training layouts
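Since there is no autodiff, a training step is explicit kernel wiring: run the forward kernels, call the matching backward kernels in reverse order by hand, then apply the fused Adam step. A schematic sketch (whether these are session methods, and all argument lists, are assumptions):

```rust
// Forward: keep activations (and rms_inv) alive for the backward pass.
session.rms_norm_compute_rms_inv(/* x -> rms_inv */)?;
session.barrier();

// Backward, ordered by hand in reverse of the forward graph.
session.rms_norm_backward_dx(/* dy, w, rms_inv -> dx */)?;
session.rms_norm_backward_dw(/* dy, x, rms_inv -> dw */)?;
session.barrier();

// Fused optimizer step: updates weights and m/v moments in place.
session.adam_update_f32(/* w, dw, m, v, step */)?;
```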
## Weight loading
Load safetensors and GGUF models directly into Metal buffers via mmap (file paths and exact signatures below are illustrative):
```rust
use mlx_native::{MlxDevice, SafetensorsFile, GgufFile}; // import path illustrative
use std::path::Path;

let device = MlxDevice::new()?;

// Safetensors — returns (dtype, shape, buffer)
let st = SafetensorsFile::open(Path::new("model.safetensors"))?; // path illustrative
let (dtype, shape, buffer) = st.load_tensor(&device, /* tensor name */)?;

// GGUF — raw block format passed through to GPU (no intermediate dequant)
let gguf = GgufFile::open(Path::new("model.gguf"))?; // path illustrative
for name in gguf.tensor_names() {
    // load each tensor by name into an MlxBuffer
}
```
## Third-party licenses

This crate includes Metal kernels and dispatch code derived from upstream open-source projects. Per-file attribution headers identify which kernels are derived from which upstream.
## License
MIT — see LICENSE.