
Crate mlx_native


mlx-native

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon.

This crate provides a thin, safe wrapper around Apple’s Metal framework focused on compute shader dispatch for neural network inference. It is designed to be the GPU backend for the hf2q inference engine.

Key Types

| Type | Purpose |
|------|---------|
| MlxDevice | Metal device + command queue (entry point) |
| CommandEncoder | Batched compute command submission |
| MlxBuffer | Typed Metal buffer with shape/dtype metadata |
| MlxBufferPool | Arena allocator with power-of-two bucketing |
| KernelRegistry | Lazy MSL compilation + pipeline cache |
| DType | Element data type enum |
| MlxError | Unified error type, returned instead of panicking |
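To illustrate what "power-of-two bucketing" means for MlxBufferPool, here is a minimal sketch, not the crate's implementation. `BucketPool` and `FakeBuffer` are hypothetical stand-ins (no Metal calls are made); the sketch assumes each request is rounded up to the next power of two and that freed buffers are kept per bucket for reuse.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a GPU buffer: it only records its capacity in bytes.
#[derive(Debug)]
struct FakeBuffer {
    capacity: usize,
}

/// Arena-style pool that groups free buffers by power-of-two capacity.
#[derive(Default)]
struct BucketPool {
    // Free buffers keyed by their bucket size in bytes.
    free: HashMap<usize, Vec<FakeBuffer>>,
    // How many times the pool had to create a new buffer.
    fresh_allocs: usize,
}

impl BucketPool {
    /// Round a requested byte size up to its bucket (next power of two, minimum 1).
    fn bucket_for(size: usize) -> usize {
        size.max(1).next_power_of_two()
    }

    /// Reuse a free buffer from the matching bucket, or create a new one.
    fn acquire(&mut self, size: usize) -> FakeBuffer {
        let bucket = Self::bucket_for(size);
        if let Some(buf) = self.free.get_mut(&bucket).and_then(|v| v.pop()) {
            return buf;
        }
        self.fresh_allocs += 1;
        FakeBuffer { capacity: bucket }
    }

    /// Return a buffer to its bucket so a later request of similar size can reuse it.
    fn release(&mut self, buf: FakeBuffer) {
        self.free.entry(buf.capacity).or_default().push(buf);
    }
}

fn main() {
    let mut pool = BucketPool::default();
    let a = pool.acquire(1000); // rounds up to the 1024-byte bucket
    pool.release(a);
    let b = pool.acquire(600); // also maps to 1024, so the released buffer is reused
    println!("capacity={} fresh_allocs={}", b.capacity, pool.fresh_allocs);
}
```

Bucketing trades some wasted capacity (up to roughly half a buffer) for a high hit rate: requests of slightly different sizes share one bucket, so decode loops that allocate similar tensors each step stop hitting the allocator.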

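Quick Start

```rust
use mlx_native::{DType, MlxDevice, MlxError};

fn main() -> Result<(), MlxError> {
    let device = MlxDevice::new()?; // Metal device + command queue
    let buf = device.alloc_buffer(1024, DType::F32, vec![256])?; // shape [256] of f32 (256 × 4 = 1024 bytes)
    let encoder = device.command_encoder()?; // batched compute command submission
    Ok(())
}
```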
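The Quick Start call passes a byte length (1024), a dtype (F32), and a shape ([256]) that agree with each other. The sketch below shows that consistency check under the assumption that the byte length must equal element count × element size. `ElemType`, `expected_bytes`, and `validate` are hypothetical and are not mlx_native's DType API; the real enum's variants may differ.

```rust
/// Hypothetical element types; not the crate's actual DType variants.
#[derive(Clone, Copy, Debug, PartialEq)]
enum ElemType {
    F32,
    F16,
    BF16,
    U8,
}

impl ElemType {
    /// Size of one element in bytes.
    fn size_of(self) -> usize {
        match self {
            ElemType::F32 => 4,
            ElemType::F16 | ElemType::BF16 => 2,
            ElemType::U8 => 1,
        }
    }
}

/// Bytes needed for a dense tensor of `shape`, or None if the product overflows.
fn expected_bytes(dtype: ElemType, shape: &[usize]) -> Option<usize> {
    shape
        .iter()
        .try_fold(1usize, |acc, &d| acc.checked_mul(d))?
        .checked_mul(dtype.size_of())
}

/// Report a mismatch as an error instead of panicking.
fn validate(byte_len: usize, dtype: ElemType, shape: &[usize]) -> Result<(), String> {
    match expected_bytes(dtype, shape) {
        Some(n) if n == byte_len => Ok(()),
        Some(n) => Err(format!("byte_len {byte_len} != {n} implied by shape {shape:?}")),
        None => Err("shape overflows usize".to_string()),
    }
}

fn main() {
    // Same numbers as the Quick Start: 256 f32 elements occupy 1024 bytes.
    println!("{:?}", validate(1024, ElemType::F32, &[256]));
    // The same byte length is wrong for 256 f16 elements (512 bytes).
    println!("{:?}", validate(1024, ElemType::F16, &[256]));
}
```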

Design Principles

  • No panics: all public APIs return Result<T, MlxError>.
  • Zero-copy: StorageModeShared buffers on Apple Silicon unified memory let the CPU and GPU access the same allocation without copying.
  • Thread-safe: MlxDevice and MlxBuffer are Send + Sync.
  • Lazy compilation: MSL shaders are compiled on first use, then cached.
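The lazy-compilation principle boils down to a compile-once cache keyed by kernel name. This is a minimal sketch, not KernelRegistry's actual implementation: `PipelineCache`, `Pipeline`, and the `compile` closure are hypothetical stand-ins for Metal shader compilation, errors are plain Strings, and the cache is a HashMap behind a Mutex.

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};

/// Hypothetical compiled pipeline; a real one would wrap a Metal pipeline state object.
#[derive(Debug)]
struct Pipeline {
    name: String,
}

/// Compile-on-first-use cache keyed by kernel name.
struct PipelineCache<F: Fn(&str) -> Result<Pipeline, String>> {
    compile: F,
    cache: Mutex<HashMap<String, Arc<Pipeline>>>,
}

impl<F: Fn(&str) -> Result<Pipeline, String>> PipelineCache<F> {
    fn new(compile: F) -> Self {
        Self { compile, cache: Mutex::new(HashMap::new()) }
    }

    /// Return the cached pipeline, compiling it only on the first request.
    /// The lock is held during compilation, so concurrent first requests compile once.
    fn get(&self, name: &str) -> Result<Arc<Pipeline>, String> {
        let mut cache = self.cache.lock().map_err(|_| "cache lock poisoned".to_string())?;
        if let Some(p) = cache.get(name) {
            return Ok(Arc::clone(p));
        }
        let p = Arc::new((self.compile)(name)?);
        cache.insert(name.to_string(), Arc::clone(&p));
        Ok(p)
    }
}

fn main() {
    let compiles = AtomicUsize::new(0);
    let cache = PipelineCache::new(|name: &str| {
        compiles.fetch_add(1, Ordering::SeqCst); // count simulated compilations
        Ok(Pipeline { name: name.to_string() })
    });
    let a = cache.get("example_kernel").unwrap();
    let b = cache.get("example_kernel").unwrap();
    println!("{} same={} compiles={}", a.name, Arc::ptr_eq(&a, &b), compiles.load(Ordering::SeqCst));
}
```

Holding the lock across compilation keeps the sketch simple but serializes unrelated compilations; a production registry might compile outside the lock and accept an occasional duplicate compile instead.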

Re-exports

pub use graph::ComputeGraph;
pub use graph::GraphExecutor;
pub use graph::GraphSession;
pub use graph::OpKind;
pub use gguf::GgufFile;
pub use gguf::MetadataValue;
pub use gguf::TensorInfo;
pub use ops::dense_mm_bf16::dense_matmul_bf16_f32_tensor;
pub use ops::dense_mm_bf16::DenseMmBf16F32Params;
pub use ops::dense_mm_f16::dense_matmul_f16_f32_tensor;
pub use ops::dense_mm_f16::DenseMmF16F32Params;
pub use ops::dense_mm_f32_f32::dense_matmul_f32_f32_tensor;
pub use ops::dense_mm_f32_f32::DenseMmF32F32Params;
pub use ops::quantized_matmul::quantized_matmul;
pub use ops::quantized_matmul::quantized_matmul_simd;
pub use ops::quantized_matmul::QuantizedMatmulParams;
pub use ops::quantized_matmul_ggml::quantized_matmul_ggml;
pub use ops::quantized_matmul_ggml::quantized_matmul_mm_tensor_perm021;
pub use ops::quantized_matmul_ggml::quantized_matmul_mm_tensor_perm021_f16;
pub use ops::quantized_matmul_ggml::GgmlQuantizedMatmulParams;
pub use ops::quantized_matmul_ggml::GgmlQuantizedMatmulPerm021Params;
pub use ops::quantized_matmul_ggml::GgmlType;
pub use ops::quantized_matmul_ggml::MM_ROUTING_THRESHOLD;
pub use ops::mul_mv_ext::mul_mv_ext_dispatch;
pub use ops::mul_mv_ext::MulMvExtParams;
pub use ops::quantized_matmul_id::quantized_matmul_id;
pub use ops::quantized_matmul_id::quantized_matmul_id_into;
pub use ops::quantized_matmul_id::QuantizedMatmulIdParams;
pub use ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml;
pub use ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml_pooled;
pub use ops::quantized_matmul_id_ggml::quantized_matmul_id_swiglu_q4_0;
pub use ops::quantized_matmul_id_ggml::GgmlIdMmDispatchParams;
pub use ops::quantized_matmul_id_ggml::GgmlQuantizedMatmulIdParams;
pub use ops::quantized_matmul_id_ggml::IdMmScratch;
pub use ops::quantized_matmul_id_ggml::MM_ID_ROUTING_THRESHOLD;
pub use weight::load_quantized_weights;
pub use weight::safetensors_to_metal_buffer;
pub use weight::QuantizationConfig;
pub use weight::QuantizedWeight;
pub use weight::SafetensorsFile;
pub use weight::TensorQuantConfig;
pub use metal;

Modules

encoder_worker
Persistent encoder worker thread (ADR-028 iter-380).
gguf
GGUF v3 file format parser.
graph
GraphExecutor — batched Metal dispatch for single-encoder forward passes.
kernel_profile
Per-command-buffer + per-dispatch GPU timing accumulator for kernel-level profiling.
metal_capture
Programmatic Metal Frame Capture wrapping (ADR-015 iter63 Part B).
ops
GPU kernel host-side dispatch functions.
tq_oracle
ADR-007 Path C F-0.1: CPU F32 oracle for flash_attn_vec_tq_hb decode.
turboquant
TurboQuant KV cache compression — CPU reference implementation.
weight
Weight loading from safetensors files into Metal GPU buffers.
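As an illustration of the first step any GGUF parser performs, the sketch below decodes the fixed 24-byte GGUF header: the 4-byte magic "GGUF", a u32 version, then u64 tensor and metadata key/value counts, assuming a little-endian file. `GgufHeader` and `parse_gguf_header` are hypothetical names, not the crate's GgufFile API, and this sketch only accepts version 3. A full parser would continue with the metadata key/value pairs and tensor infos.

```rust
/// Fixed-size fields at the start of a GGUF file (little-endian assumed).
#[derive(Debug, PartialEq)]
struct GgufHeader {
    version: u32,
    tensor_count: u64,
    metadata_kv_count: u64,
}

/// Read a little-endian u32 at `off`, or None if the input is too short.
fn read_u32(bytes: &[u8], off: usize) -> Option<u32> {
    let b: [u8; 4] = bytes.get(off..off + 4)?.try_into().ok()?;
    Some(u32::from_le_bytes(b))
}

/// Read a little-endian u64 at `off`, or None if the input is too short.
fn read_u64(bytes: &[u8], off: usize) -> Option<u64> {
    let b: [u8; 8] = bytes.get(off..off + 8)?.try_into().ok()?;
    Some(u64::from_le_bytes(b))
}

/// Parse magic, version, tensor count, and metadata KV count (24 bytes total).
fn parse_gguf_header(bytes: &[u8]) -> Result<GgufHeader, String> {
    if !bytes.starts_with(b"GGUF") {
        return Err("missing GGUF magic".into());
    }
    let version = read_u32(bytes, 4).ok_or("truncated version")?;
    if version != 3 {
        return Err(format!("unsupported GGUF version {version}"));
    }
    let tensor_count = read_u64(bytes, 8).ok_or("truncated tensor count")?;
    let metadata_kv_count = read_u64(bytes, 16).ok_or("truncated metadata count")?;
    Ok(GgufHeader { version, tensor_count, metadata_kv_count })
}

fn main() {
    // Build a synthetic v3 header: 291 tensors, 25 metadata entries.
    let mut bytes = b"GGUF".to_vec();
    bytes.extend_from_slice(&3u32.to_le_bytes());
    bytes.extend_from_slice(&291u64.to_le_bytes());
    bytes.extend_from_slice(&25u64.to_le_bytes());
    println!("{:?}", parse_gguf_header(&bytes));
}
```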

Structs

BufferRange
A buffer region recorded for dataflow tracking.
CommandEncoder
A batched compute command encoder.
DispatchRecord
Pre-baked dispatch record for hot decode paths.
EncoderSession
Session-level wrapper around a CommandEncoder for one or more logical transformer stages.
KernelRegistry
Registry that lazily compiles and caches Metal compute pipelines from embedded MSL source.
MTLSize
See https://developer.apple.com/documentation/metal/mtlsize
MemRanges
Cumulative dataflow state for a sequence of concurrent dispatches.
MlxBuffer
A Metal GPU buffer annotated with element dtype and tensor shape.
MlxBufferPool
Arena-style buffer pool that reuses Metal buffer allocations.
MlxDevice
Wraps a Metal device and its command queue.

Enums

CapturedNode
A single captured compute dispatch or barrier sentinel.
CapturedOpKind
Operation kind tag for captured nodes, used by the fusion pass (4e.2).
DType
Element data type carried by an MlxBuffer.
DispatchKind
How to dispatch the recorded kernel.
KernelArg
A buffer or inline-bytes binding for a compute kernel argument slot.
MemRangeRole
Whether a recorded range was read by a dispatch (Src) or written by a dispatch (Dst). Mirrors ggml_mem_range_type in ggml-metal-common.h:14-17.
MlxError
Unified error type for all Metal GPU operations.
RecordedBinding
A recorded kernel argument binding.
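BufferRange, MemRanges, and MemRangeRole describe dataflow tracking for concurrent dispatches. The sketch below shows one plausible conflict rule, modeled loosely on ggml-metal's approach rather than taken from mlx_native: a new dispatch joins the current concurrent group only if none of its reads overlap an earlier write (read-after-write) and none of its writes overlap an earlier read or write (write-after-read, write-after-write). `Role`, `Range`, and `Tracker` are hypothetical, and buffers are identified by a plain integer id.

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum Role {
    Src, // read by a dispatch
    Dst, // written by a dispatch
}

#[derive(Clone, Copy, Debug)]
struct Range {
    buf_id: u64,
    start: u64, // inclusive byte offset
    end: u64,   // exclusive byte offset
    role: Role,
}

impl Range {
    /// Same buffer and intersecting half-open byte intervals.
    fn overlaps(&self, other: &Range) -> bool {
        self.buf_id == other.buf_id && self.start < other.end && other.start < self.end
    }
}

/// Cumulative ranges touched by the current group of concurrent dispatches.
#[derive(Default)]
struct Tracker {
    ranges: Vec<Range>,
    barriers: usize,
}

impl Tracker {
    /// True if `r` conflicts with a range already recorded in this group.
    fn conflicts(&self, r: &Range) -> bool {
        self.ranges.iter().any(|prev| {
            // Two reads never conflict; any overlapping pair involving a write does.
            let hazard = r.role == Role::Dst || prev.role == Role::Dst;
            hazard && prev.overlaps(r)
        })
    }

    /// Record a dispatch's ranges. On conflict, emit a barrier and start a new group.
    /// Returns true if a barrier was emitted.
    fn dispatch(&mut self, ranges: &[Range]) -> bool {
        let needs_barrier = ranges.iter().any(|r| self.conflicts(r));
        if needs_barrier {
            self.barriers += 1;
            self.ranges.clear();
        }
        self.ranges.extend_from_slice(ranges);
        needs_barrier
    }
}

fn main() {
    let r = |buf_id: u64, start: u64, end: u64, role: Role| Range { buf_id, start, end, role };
    let mut t = Tracker::default();
    // A writes buffer 1; B touches only buffers 2 and 3: they can run concurrently.
    let a = t.dispatch(&[r(1, 0, 1024, Role::Dst)]);
    let b = t.dispatch(&[r(2, 0, 512, Role::Src), r(3, 0, 512, Role::Dst)]);
    // C reads bytes that A wrote: read-after-write, so a barrier is required.
    let c = t.dispatch(&[r(1, 256, 512, Role::Src)]);
    println!("barrier before A={a} B={b} C={c}, total={}", t.barriers);
}
```

The two outcomes of `dispatch` correspond to the quantities the auto_barrier_count and auto_barrier_concurrent_count functions report: dispatches that forced a barrier versus those that ran concurrently with the previous group.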

Functions

auto_barrier_concurrent_count
Read the cumulative number of dispatch_tracked calls that did NOT emit a barrier (that is, ran concurrently with the previous group).
auto_barrier_count
Read the cumulative number of auto-emitted barriers across all encoders since process start (or last reset_counters).
barrier_count
Read the current value of BARRIER_COUNT.
barrier_total_ns
Read the total nanoseconds spent at the memoryBarrierWithScope: objc::msg_send! call site. Only non-zero when MLX_PROFILE_BARRIERS=1 was set in the environment at the time of the first memory_barrier() call (the environment check is cached on first use).
cmd_buf_count
Read the current value of CMD_BUF_COUNT.
dispatch_count
Read the current value of DISPATCH_COUNT.
pipeline_dispatch_buckets
Public dump of MLX_DISP_BUCKET data as a Vec<(label, count)> sorted by count in descending order. Returns an empty vector when the environment flag is off or nothing has been recorded.
reset_counters
Reset all counters to zero.
reset_pipeline_dispatch_buckets
Reset the per-pipeline dispatch buckets (typically called at decode start to ignore prefill / warmup contributions).
sync_count
Read the current value of SYNC_COUNT.

Type Aliases

Result
Convenience alias used throughout the crate.
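A crate-wide Result alias pairs naturally with the no-panics principle: every fallible function returns the alias and reports failures as error variants. A minimal sketch, assuming the alias has the common `Result<T> = std::result::Result<T, Error>` shape; the `GpuError` enum and `checked_elements` function are hypothetical and are not mlx_native's actual MlxError or API.

```rust
use std::fmt;

/// Hypothetical error enum standing in for a crate-wide error type.
#[derive(Debug, PartialEq)]
enum GpuError {
    EmptyShape,
    Overflow,
    ShapeMismatch { expected: usize, got: usize },
}

impl fmt::Display for GpuError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            GpuError::EmptyShape => write!(f, "shape must have at least one dimension"),
            GpuError::Overflow => write!(f, "element count overflows usize"),
            GpuError::ShapeMismatch { expected, got } => {
                write!(f, "shape mismatch: expected {expected} elements, got {got}")
            }
        }
    }
}

impl std::error::Error for GpuError {}

/// Convenience alias: callers write Result<T> instead of Result<T, GpuError>.
type Result<T> = std::result::Result<T, GpuError>;

/// Check that `data` matches `shape`, returning errors rather than panicking
/// (checked_mul avoids the overflow panic that `product()` can hit in debug builds).
fn checked_elements(shape: &[usize], data: &[f32]) -> Result<usize> {
    if shape.is_empty() {
        return Err(GpuError::EmptyShape);
    }
    let expected = shape
        .iter()
        .try_fold(1usize, |acc, &d| acc.checked_mul(d))
        .ok_or(GpuError::Overflow)?;
    if expected != data.len() {
        return Err(GpuError::ShapeMismatch { expected, got: data.len() });
    }
    Ok(expected)
}

fn main() {
    match checked_elements(&[4], &[1.0, 2.0]) {
        Ok(n) => println!("{n} elements"),
        Err(e) => println!("error: {e}"),
    }
}
```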