// mlx_native/lib.rs
//! # mlx-native
//!
//! Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple
//! Silicon.
//!
//! This crate provides a thin, safe wrapper around Apple's Metal framework
//! focused on compute shader dispatch for neural network inference.  It is
//! designed to be the GPU backend for the `hf2q` inference engine.
//!
//! ## Key Types
//!
//! | Type | Purpose |
//! |------|---------|
//! | [`MlxDevice`]      | Metal device + command queue (entry point) |
//! | [`CommandEncoder`] | Batched compute command submission |
//! | [`MlxBuffer`]      | Typed Metal buffer with shape/dtype metadata |
//! | [`MlxBufferPool`]  | Arena allocator with power-of-two bucketing |
//! | [`KernelRegistry`] | Lazy MSL compilation + pipeline cache |
//! | [`DType`]          | Element data type enum |
//! | [`MlxError`]       | Unified error type (never panics) |
//!
//! ## Quick Start
//!
//! ```ignore
//! use mlx_native::{MlxDevice, DType};
//!
//! let device = MlxDevice::new()?;
//! let buf = device.alloc_buffer(1024, DType::F32, vec![256])?;
//! let encoder = device.command_encoder()?;
//! ```
//!
//! ## Design Principles
//!
//! * **No panics** — all public APIs return `Result<T, MlxError>`.
//! * **Zero-copy** — `StorageModeShared` buffers on Apple Silicon unified memory.
//! * **Thread-safe** — `MlxDevice` and `MlxBuffer` are `Send + Sync`.
//! * **Lazy compilation** — MSL shaders compiled on first use, then cached.

// Enforce the no-panic policy at compile time.
#![deny(clippy::panic, clippy::unwrap_used, clippy::expect_used)]
// The `objc` crate's `msg_send!` macro internally checks `cfg(feature = "cargo-clippy")`
// which triggers unexpected_cfgs warnings. Suppress at crate level since we can't
// control the macro expansion site.
#![allow(unexpected_cfgs)]

46// ---- internal modules ----
47#[macro_use]
48mod error;
49mod buffer;
50mod buffer_pool;
51mod device;
52mod dtypes;
53mod encoder;
54mod encoder_session;
55pub mod encoder_worker;
56mod env_flags;
57mod kernel_registry;
58mod mem_ranges;
59mod residency;
60pub mod gguf;
61pub mod kernel_profile;
62pub mod graph;
63pub mod metal_capture;
64pub mod ops;
65pub mod turboquant;
66pub mod tq_oracle;
67pub mod weight;
68
69// ---- public re-exports ----
70pub use buffer::MlxBuffer;
71pub use buffer_pool::MlxBufferPool;
72pub use device::MlxDevice;
73pub use dtypes::DType;
74pub use encoder::{
75    auto_barrier_concurrent_count, auto_barrier_count, barrier_count, barrier_total_ns,
76    cmd_buf_count, dispatch_count, pipeline_dispatch_buckets, reset_counters,
77    reset_pipeline_dispatch_buckets, sync_count, CapturedNode, CapturedOpKind,
78    CommandEncoder, DispatchKind, KernelArg, RecordedBinding,
79};
80pub use encoder_session::EncoderSession;
81pub use mem_ranges::{BufferRange, MemRangeRole, MemRanges};
82pub use error::{MlxError, Result};
83pub use graph::{ComputeGraph, GraphExecutor, GraphSession, OpKind};
84pub use kernel_registry::KernelRegistry;
85// Test-only counters and gate-reset helpers.  Marked #[doc(hidden)] so
86// they don't appear in published rustdoc; consumers should not depend
87// on them outside test code.  Not feature-gated because integration
88// tests in tests/ are a separate crate and cannot rely on the lib's
89// `test` cfg flag.
90#[doc(hidden)]
91pub use residency::{
92    macos_15_or_newer_for_test, reset_residency_env_cache_for_test,
93    reset_residency_test_counters, residency_allocation_count_for_test,
94    residency_commit_call_count_for_test,
95};
96
97// Re-export GGUF parser.
98pub use gguf::{GgufFile, MetadataValue, TensorInfo};
99
100// Re-export ops.
101pub use ops::dense_mm_bf16::{dense_matmul_bf16_f32_tensor, DenseMmBf16F32Params};
102pub use ops::dense_mm_f16::{dense_matmul_f16_f32_tensor, DenseMmF16F32Params};
103pub use ops::dense_mm_f32_f32::{dense_matmul_f32_f32_tensor, DenseMmF32F32Params};
104pub use ops::quantized_matmul::{quantized_matmul, quantized_matmul_simd, QuantizedMatmulParams};
105pub use ops::quantized_matmul_ggml::{
106    dispatch_mm_for_test, quantized_matmul_ggml, quantized_matmul_mm_tensor_perm021,
107    quantized_matmul_mm_tensor_perm021_f16,
108    GgmlQuantizedMatmulParams, GgmlQuantizedMatmulPerm021Params, GgmlType,
109    MM_ROUTING_THRESHOLD,
110};
111pub use ops::mul_mv_ext::{mul_mv_ext_dispatch, MulMvExtParams};
112pub use ops::quantized_matmul_id::{
113    quantized_matmul_id, quantized_matmul_id_into, QuantizedMatmulIdParams,
114};
115pub use ops::quantized_matmul_id_ggml::{
116    dispatch_id_mm_for_test, quantized_matmul_id_ggml, quantized_matmul_id_ggml_pooled,
117    quantized_matmul_id_swiglu_q4_0,
118    GgmlIdMmDispatchParams, GgmlQuantizedMatmulIdParams, IdMmScratch,
119    MM_ID_ROUTING_THRESHOLD,
120};
121
122// Re-export weight loading utilities.
123pub use weight::{
124    load_quantized_weights, safetensors_to_metal_buffer, QuantizationConfig, QuantizedWeight,
125    SafetensorsFile, TensorQuantConfig,
126};
127
128// Re-export metal types that appear in the public API.
129pub use metal::MTLSize;
130pub use metal;
131
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    // ---- T10.7: compile-time Send + Sync assertions ----
    // Instantiating these generic helpers for a type is a compile-time proof
    // that the type satisfies the bound; nothing executes at test time.
    fn _require_send<T: Send>() {}
    fn _require_sync<T: Sync>() {}

    #[allow(dead_code)]
    fn assert_send_sync() {
        _require_send::<MlxDevice>();
        _require_sync::<MlxDevice>();
        _require_send::<MlxBuffer>();
        _require_sync::<MlxBuffer>();
        _require_send::<MlxError>();
        _require_sync::<MlxError>();
    }

    // ---- T10.1: device initialization ----
    #[test]
    fn test_device_init() {
        let device = MlxDevice::new().expect("MlxDevice::new() should succeed on Apple Silicon");
        let name = device.name();
        assert!(!name.is_empty(), "Device name should not be empty");
        println!("Metal device: {name}");
    }

    // ---- T10.2: buffer allocation ----
    #[test]
    fn test_buffer_alloc() {
        let device = MlxDevice::new().expect("device");
        let shape = vec![2, 3, 4];
        let elements: usize = shape.iter().product(); // 24
        let byte_len = elements * DType::F32.size_of(); // 96 bytes
        let buf = device
            .alloc_buffer(byte_len, DType::F32, shape.clone())
            .expect("alloc_buffer");

        assert_eq!(buf.dtype(), DType::F32);
        assert_eq!(buf.shape(), &shape);
        assert_eq!(buf.byte_len(), byte_len);
        assert_eq!(buf.element_count(), elements);
    }

    // ---- T10.3: buffer read/write round-trip ----
    #[test]
    fn test_buffer_readwrite() {
        let device = MlxDevice::new().expect("device");
        let n = 64;
        let byte_len = n * std::mem::size_of::<f32>();
        let mut buf = device
            .alloc_buffer(byte_len, DType::F32, vec![n])
            .expect("alloc_buffer");

        // Fill with a known ramp: element i holds i * 1.5.
        {
            let slice: &mut [f32] = buf.as_mut_slice().expect("as_mut_slice");
            assert_eq!(slice.len(), n);
            for i in 0..n {
                slice[i] = i as f32 * 1.5;
            }
        }

        // Read the buffer back and check every element survived the trip.
        {
            let slice: &[f32] = buf.as_slice().expect("as_slice");
            for (i, &val) in slice.iter().enumerate() {
                let expected = i as f32 * 1.5;
                assert!(
                    (val - expected).abs() < f32::EPSILON,
                    "Mismatch at index {i}: got {val}, expected {expected}"
                );
            }
        }
    }

    // ---- T10.4: encoder lifecycle ----
    #[test]
    fn test_encoder_lifecycle() {
        let device = MlxDevice::new().expect("device");
        let mut encoder = device.command_encoder().expect("command_encoder");
        // Committing an empty command buffer is a GPU no-op and must succeed.
        encoder
            .commit_and_wait()
            .expect("commit_and_wait on empty encoder");
    }

    // ---- T10.5: buffer pool reuse ----
    #[test]
    fn test_buffer_pool_reuse() {
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();

        // First allocation: record its backing-store identity.
        let first = pool
            .alloc(&device, 1024, DType::F32, vec![256])
            .expect("pool alloc 1");
        let (first_ptr, first_len) = (first.contents_ptr(), first.byte_len());

        // Return it to the pool's free list.
        pool.release(first);
        assert_eq!(pool.free_count(), 1);

        // A same-size allocation must come from the free list, not a fresh buffer.
        let second = pool
            .alloc(&device, 1024, DType::F32, vec![256])
            .expect("pool alloc 2");
        let (second_ptr, second_len) = (second.contents_ptr(), second.byte_len());

        assert_eq!(first_ptr, second_ptr, "Pool should reuse the same Metal buffer");
        assert_eq!(first_len, second_len, "Byte lengths should match");
        assert_eq!(pool.free_count(), 0, "Free list should be empty after reuse");
    }

    // ---- T10.6: kernel registry caching ----
    #[test]
    fn test_kernel_registry_caching() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();

        // Register a minimal test kernel.
        registry.register_source(
            "test_add",
            r#"
            #include <metal_stdlib>
            using namespace metal;
            kernel void test_add(
                device float *a [[buffer(0)]],
                device float *b [[buffer(1)]],
                device float *c [[buffer(2)]],
                uint id [[thread_position_in_grid]]
            ) {
                c[id] = a[id] + b[id];
            }
            "#,
        );

        // First lookup compiles the shader and populates the cache.
        assert!(!registry.is_cached("test_add"));
        let first = registry
            .get_pipeline("test_add", device.metal_device())
            .expect("get_pipeline first call");
        let first_ptr = first as *const _;
        assert!(registry.is_cached("test_add"));

        // Second lookup must hit the cache and hand back the same pipeline.
        let second = registry
            .get_pipeline("test_add", device.metal_device())
            .expect("get_pipeline second call");
        let second_ptr = second as *const _;

        assert_eq!(
            first_ptr, second_ptr,
            "Second get_pipeline call should return the same cached pipeline"
        );
    }

    // ---- Additional: test alloc_buffer with zero length returns error ----
    #[test]
    fn test_buffer_alloc_zero_len_error() {
        let device = MlxDevice::new().expect("device");
        let result = device.alloc_buffer(0, DType::F32, vec![]);
        assert!(result.is_err(), "Zero-length allocation should fail");
        match result {
            Err(MlxError::InvalidArgument(_)) => {}
            unexpected => panic!("Expected InvalidArgument, got {:?}", unexpected),
        }
    }

    // ---- Additional: test kernel not found ----
    #[test]
    fn test_kernel_not_found() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let result = registry.get_pipeline("nonexistent_kernel", device.metal_device());
        assert!(result.is_err());
        match result {
            Err(MlxError::KernelNotFound(name)) => {
                assert_eq!(name, "nonexistent_kernel");
            }
            unexpected => panic!("Expected KernelNotFound, got {:?}", unexpected),
        }
    }

    // ---- Additional: test DType properties ----
    #[test]
    fn test_dtype_sizes() {
        // Table-driven: each dtype paired with its expected element width.
        let expected: [(DType, usize); 7] = [
            (DType::F32, 4),
            (DType::F16, 2),
            (DType::BF16, 2),
            (DType::U8, 1),
            (DType::U16, 2),
            (DType::U32, 4),
            (DType::I32, 4),
        ];
        for (dtype, bytes) in expected {
            assert_eq!(dtype.size_of(), bytes);
        }
    }

    // ---- Additional: test MlxBuffer Debug ----
    #[test]
    fn test_buffer_debug() {
        let device = MlxDevice::new().expect("device");
        let buf = device
            .alloc_buffer(64, DType::F16, vec![4, 8])
            .expect("alloc_buffer");
        let rendered = format!("{buf:?}");
        assert!(rendered.contains("MlxBuffer"));
        assert!(rendered.contains("F16"));
        assert!(rendered.contains("[4, 8]"));
    }

    // ---- Additional: test MlxError Display ----
    #[test]
    fn test_error_display() {
        let not_found = MlxError::DeviceNotFound;
        assert!(not_found.to_string().contains("Metal GPU device"));

        let compile_err = MlxError::ShaderCompilationError {
            name: "foo".into(),
            message: "syntax error".into(),
        };
        let rendered = compile_err.to_string();
        assert!(rendered.contains("foo"));
        assert!(rendered.contains("syntax error"));
    }

    // ---- Additional: test buffer pool with different sizes ----
    #[test]
    fn test_buffer_pool_size_buckets() {
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();

        // 100 bytes rounds up to the 128-byte bucket.
        let small = pool.alloc(&device, 100, DType::U8, vec![100]).expect("alloc 100");
        assert!(
            small.byte_len() >= 100,
            "Buffer should be at least 100 bytes"
        );
        pool.release(small);

        // 128 bytes lands in the same bucket, so the buffer is reused.
        let exact = pool.alloc(&device, 128, DType::U8, vec![128]).expect("alloc 128");
        assert!(exact.byte_len() >= 128);
        pool.release(exact);

        // 200 bytes needs the 256-byte bucket — a fresh allocation.
        let larger = pool.alloc(&device, 200, DType::U8, vec![200]).expect("alloc 200");
        assert!(larger.byte_len() >= 200);
        pool.release(larger);

        assert_eq!(pool.free_count(), 2, "Two different bucket sizes in pool");
    }
}
382}