# mlx-native
Pure-Rust Metal GPU compute library for transformer inference on Apple Silicon. Built as the GPU backend for the hf2q inference engine.
## Status
Active development, pre-1.0. API may change between minor versions (0.x.0 → 0.(x+1).0 signals breaking changes). Public functions and structs evolve as new model families are added. Patch versions (0.x.y → 0.x.(y+1)) are non-breaking.
Supported model families used in production:
- Qwen3 / Qwen3.5 / Qwen3.6 (dense + MoE, GGUF)
- Gemma 3 / Gemma 4 (dense, with SWA + softcap, GQA)
- BERT-style embeddings (bge-small-en-v1.5)
- Generic transformer kernels for custom architectures
## What is this?
A thin, safe wrapper around Apple's Metal framework focused on compute shader dispatch for neural network inference. It handles buffer management, MSL shader compilation, and GPU command encoding so callers can focus on graph construction and execution.
Apple Silicon only — leverages unified memory (StorageModeShared) for zero-copy CPU↔GPU buffer access.
## Design principles
- No panics — all public APIs return `Result<T, MlxError>`
- Zero-copy — `StorageModeShared` buffers on Apple Silicon unified memory
- Thread-safe — `MlxDevice` and `MlxBuffer` are `Send + Sync`
- Lazy compilation — MSL shaders compiled on first use, then cached
- Buffer pooling — power-of-two arena allocator for reuse (sketched after this list)
- Single-encoder graphs — `GraphExecutor` batches dispatches for ~120× lower per-token overhead than per-op encoders (matches the llama.cpp pattern)
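The pooling behavior is easy to show in isolation. A minimal sketch of power-of-two bucketing, assuming only the rounding rule named above (`bucket_for` is illustrative, not part of the crate's API):

```rust
/// Illustrative only: round a request up to its power-of-two size class,
/// so a freed buffer can satisfy any later request in the same class.
fn bucket_for(size_bytes: usize) -> usize {
    size_bytes.next_power_of_two()
}

fn main() {
    // A ~5 KB request draws from the 8 KiB bucket; when freed, the
    // buffer is reusable by the next 4-8 KiB request without a realloc.
    assert_eq!(bucket_for(5_000), 8_192);
}
```

Rounding up wastes less than 2× per allocation in exchange for cheap reuse hits, which is the usual trade for this style of arena.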
## Quick start
A Q4_0 GGUF mat-vec dispatch:
```rust
use mlx_native::{KernelRegistry, MlxDevice}; // plus the ops you dispatch

let device = MlxDevice::new()?;
let mut registry = KernelRegistry::new();

let input = device.alloc_buffer(/* shape, dtype */)?; // f32 input
let weight = /* mmap GGUF Q4_0 blocks into an MlxBuffer */;
let mut output = device.alloc_buffer(/* shape, dtype */)?;

let mut enc = device.command_encoder()?;
quantized_matmul_ggml(/* enc, registry, weight, input, output, dims */)?;
enc.commit_and_wait()?;
```
For multi-op forward passes, use `GraphExecutor` to batch all dispatches into a single command buffer with one GPU sync:
```rust
let executor = GraphExecutor::new(device); // takes ownership
let mut session = executor.begin()?;

session.rms_norm(/* … */)?;
session.barrier(); // explicit barrier between dependent ops
session.quantized_matmul_ggml(/* … */)?;
session.barrier();
session.flash_attn_vec(/* … */)?;

session.finish()?; // one commit_and_wait for the whole pass
```
## Key types
| Type | Purpose |
|---|---|
| `MlxDevice` | Metal device + command queue (entry point) |
| `MlxBuffer` | Typed Metal buffer with shape/dtype metadata + `byte_offset` slicing |
| `MlxBufferPool` | Arena allocator with power-of-two bucketing |
| `CommandEncoder` | Compute command submission (single-dispatch path) |
| `KernelRegistry` | Lazy MSL compilation + pipeline cache |
| `GraphExecutor` / `GraphSession` | Single-encoder batched forward passes |
| `ComputeGraph` | Recorded graph IR (capture, fuse, replay) |
| `DType` | Element data type enum (F32, F16, BF16, U8/U16/U32, I32) |
| `MlxError` | Unified error type |
| `GgufFile` / `TensorInfo` | GGUF model file mmap + metadata |
| `SafetensorsFile` | Safetensors mmap + tensor loading |
## GPU operations
### Attention

- `flash_attn_vec` — SIMD-vectorized decode-path SDPA (NWG-parallel, llama.cpp port); the underlying math is sketched after this list
- `flash_attn_vec_tq` / `flash_attn_vec_tq_hb` — TurboQuant-quantized KV variants (Lloyd-Max + Hadamard)
- `flash_attn_prefill` (D=256, D=512) — Tiled prefill with bf16 kernels, SWA mask, sentinel handling
- `sdpa` / `sdpa_sliding` — Reference SDPA with optional sliding window
- `sdpa_decode` — Tiled decode-path SDPA with N_SG=4 simdgroups
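All the decode-path variants compute single-query scaled dot-product attention; the kernels differ in tiling and SIMD-group parallelism, not in the math. As a plain CPU reference (a sketch for clarity, not crate API):

```rust
/// Single-query SDPA: softmax(q · K^T / sqrt(d)) · V.
fn sdpa_decode_ref(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
    let scale = 1.0 / (q.len() as f32).sqrt();
    // Scaled dot product against every cached key.
    let scores: Vec<f32> = keys
        .iter()
        .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale)
        .collect();
    // Numerically stable softmax over the scores.
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let denom: f32 = exps.iter().sum();
    // Weighted sum of the cached values.
    let mut out = vec![0.0f32; values[0].len()];
    for (w, v) in exps.iter().zip(values) {
        for (o, x) in out.iter_mut().zip(v) {
            *o += (w / denom) * x;
        }
    }
    out
}
```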
### Matrix multiplication

- GGUF formats: Q4_0, Q5_K, Q6_K, Q8_0, I16 — mat-vec + `mul_mm` tensor-core kernels (the Q4_0 block layout is sketched below)
- MLX format: 4/6/8-bit affine quantization (`quantized_matmul`)
- MoE expert-routed: `quantized_matmul_id` / `quantized_matmul_id_ggml` (top_k=1 tensor-mm fast path)
- Dense BF16: `dense_mm_bf16_tensor`, `dense_gemv_bf16_f32` (M=1 decode)
- Dense F16: `dense_gemm_f16`, `dense_matvec_f16`
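For orientation, a sketch of the Q4_0 block these kernels consume, assuming the standard llama.cpp layout of 32 weights per block with one f16 scale and packed 4-bit quants (the struct below is illustrative, not the crate's type):

```rust
/// Illustrative Q4_0 block (llama.cpp layout assumption). In the file
/// the scale is stored as f16; it is widened to f32 here for clarity.
struct BlockQ4_0 {
    d: f32,       // per-block scale
    qs: [u8; 16], // 32 packed 4-bit quants, two per byte
}

/// Each 4-bit quant q decodes as d * (q - 8); low nibbles hold
/// elements 0..16, high nibbles elements 16..32.
fn dequant_q4_0(block: &BlockQ4_0) -> [f32; 32] {
    let mut out = [0.0f32; 32];
    for i in 0..16 {
        let lo = (block.qs[i] & 0x0F) as i32 - 8;
        let hi = (block.qs[i] >> 4) as i32 - 8;
        out[i] = block.d * lo as f32;
        out[i + 16] = block.d * hi as f32;
    }
    out
}
```

As noted under Weight loading below, the raw blocks go straight to the GPU; the kernels dequantize in-register inside the mat-vec loop rather than materializing an f32 tensor like this.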
### Normalization

- `rms_norm` — RMS normalization (f32 + triple-output variants; reference semantics below)
- `l2_norm` — L2 normalization
- `fused_residual_norm` — RMS norm + residual add
- `fused_norm_add` — MoE weighted_sum + RMS norm + add
- `fused_head_norm_rope` — Per-head RMS norm + RoPE (with bf16 co-write variants)
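For reference, the semantics assumed for the base `rms_norm` kernel, with eps inside the square root per the common convention (a sketch, not the kernel's signature):

```rust
/// Reference RMS normalization: x_i * weight_i / sqrt(mean(x^2) + eps).
fn rms_norm_ref(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(weight).map(|(v, w)| v * inv_rms * w).collect()
}
```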
### Activation & gating

- `gelu` — GeLU activation (F32, BF16)
- `silu_mul` — SwiGLU (SiLU + elementwise multiply; see the sketch after this list)
- `sigmoid_mul` — Sigmoid-gated multiply
- `softmax`, `softcap`, `scale_mask_softmax` — Softmax variants
- `softmax_sample` — Sampling from logits
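Reference semantics for two of these, written as CPU sketches (assumed forms; the softcap formula follows the Gemma-style convention):

```rust
/// silu_mul (SwiGLU): out = silu(gate) * up, with silu(x) = x * sigmoid(x).
fn silu_mul_ref(gate: &[f32], up: &[f32]) -> Vec<f32> {
    gate.iter()
        .zip(up)
        .map(|(&g, &u)| g / (1.0 + (-g).exp()) * u)
        .collect()
}

/// softcap (Gemma-style soft-capping, applied elementwise):
/// y = cap * tanh(x / cap), which bounds values to (-cap, cap).
fn softcap_ref(x: f32, cap: f32) -> f32 {
    cap * (x / cap).tanh()
}
```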
### Position encoding

- `rope` — Standard RoPE (per-pair rotation sketched below)
- `rope_multi` — Multi-axis RoPE with IMROPE mode (Qwen3.5)
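Standard RoPE rotates element pairs by a position-dependent angle. A per-pair CPU reference (a sketch; the kernel's pairing and frequency layout vary by model family):

```rust
/// Rotate one (x0, x1) pair by theta = pos * inv_freq.
fn rope_pair(x0: f32, x1: f32, pos: f32, inv_freq: f32) -> (f32, f32) {
    let (sin, cos) = (pos * inv_freq).sin_cos();
    (x0 * cos - x1 * sin, x0 * sin + x1 * cos)
}
```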
### MoE

- `moe_gate` — Gate logits → weights
- `moe_softmax_topk` — GPU softmax + top-k expert selection (reference below)
- `moe_dispatch` — Per-expert matvec sequence with proper barriers
- `moe_weighted_reduce` — Weighted sum across selected experts
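The routing math assumed for `moe_softmax_topk`, as a CPU reference (sketch only; the kernel keeps the whole step on-GPU):

```rust
/// Softmax over gate logits, then keep the k highest-probability experts.
/// Returns (expert_index, weight) pairs, highest weight first.
fn softmax_topk_ref(logits: &[f32], k: usize) -> Vec<(usize, f32)> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let mut probs: Vec<(usize, f32)> =
        exps.iter().enumerate().map(|(i, &e)| (i, e / sum)).collect();
    probs.sort_by(|a, b| b.1.total_cmp(&a.1));
    probs.truncate(k);
    probs
}
```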
### State-space (Mamba/Gated DeltaNet)

- `ssm_conv` — Depthwise causal 1D convolution + SiLU
- `ssm_norm_gate` — Norm + gate fusion (eliminates CPU bridge)
- `gated_delta_net` — Fused GDN kernel
- `compute_g_beta` — GDN g/beta computation
- `tri_solve` — Lower-triangular unit-diagonal forward substitution (reference below)
- `cumsum` — Cumulative sum
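`tri_solve` is standard forward substitution specialized to a unit diagonal, which drops the per-row division. A CPU reference:

```rust
/// Solve L x = b where L is lower-triangular with an implicit unit
/// diagonal, so x[i] = b[i] - sum_{j<i} L[i][j] * x[j].
fn tri_solve_ref(l: &[Vec<f32>], b: &[f32]) -> Vec<f32> {
    let n = b.len();
    let mut x = vec![0.0f32; n];
    for i in 0..n {
        let mut acc = b[i];
        for j in 0..i {
            acc -= l[i][j] * x[j];
        }
        x[i] = acc; // unit diagonal: no division needed
    }
    x
}
```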
### Memory & layout

- `kv_cache_copy` — Linear + sliding-window KV cache copy with ring-wrap (indexing sketched below)
- `embedding` — Embedding lookup
- `gather` — Indexed gather (F16, nibble-packed)
- `transpose`, `permute_021` — Layout conversions
- `copy`, `offset_copy` — Strided copy
- `argmax`, `argsort`, `top_k` — Reductions
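The ring-wrap in the sliding-window path reduces to modular indexing: position `p` lands in slot `p % window`, overwriting the oldest entry. A minimal sketch (illustrative, not the kernel's addressing code):

```rust
/// Map an absolute token position into a fixed-size sliding-window
/// cache; once the window fills, new tokens overwrite the oldest slot.
fn ring_slot(pos: usize, window: usize) -> usize {
    pos % window
}
```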
### Hadamard / TurboQuant

- `hadamard` — Standalone FWHT (D=128/256/512; CPU reference after this list)
- `hadamard_quantize_kv` — Fused Hadamard + KV quantization
- `tq_dequantize_kv` — TurboQuant KV dequantization
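The standalone `hadamard` kernel computes a fast Walsh-Hadamard transform. For reference, the classic in-place butterfly on the CPU (unnormalized; the length must be a power of two, matching the supported D = 128/256/512):

```rust
/// In-place unnormalized fast Walsh-Hadamard transform.
fn fwht(data: &mut [f32]) {
    let n = data.len();
    assert!(n.is_power_of_two());
    let mut h = 1;
    while h < n {
        for i in (0..n).step_by(2 * h) {
            for j in i..i + h {
                let (x, y) = (data[j], data[j + h]);
                data[j] = x + y;
                data[j + h] = x - y;
            }
        }
        h *= 2;
    }
}
```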
## Weight loading
Load safetensors and GGUF models directly into Metal buffers via mmap:
```rust
use mlx_native::{GgufFile, MlxDevice, SafetensorsFile};
use std::path::Path;

let device = MlxDevice::new()?;

// Safetensors — returns (dtype, shape, buffer)
let st = SafetensorsFile::open(Path::new(/* model path */))?;
let (dtype, shape, buffer) = st.load_tensor(/* tensor name, &device */)?;

// GGUF — raw block format passed through to GPU (no intermediate dequant)
let gguf = GgufFile::open(Path::new(/* model path */))?;
for name in gguf.tensor_names() {
    /* look up TensorInfo and upload the raw blocks */
}
```
## Third-party licenses
This crate includes Metal kernels and dispatch code derived from upstream open-source projects (see the llama.cpp-derived kernels noted above). Per-file attribution headers identify which kernels are derived from which upstream.
## License
MIT — see LICENSE.