// mlx_native/src/lib.rs
1//! # mlx-native
2//!
3//! Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple
4//! Silicon.
5//!
6//! This crate provides a thin, safe wrapper around Apple's Metal framework
7//! focused on compute shader dispatch for neural network inference.  It is
8//! designed to be the GPU backend for the `hf2q` inference engine.
9//!
10//! ## Key Types
11//!
12//! | Type | Purpose |
13//! |------|---------|
14//! | [`MlxDevice`]       | Metal device + command queue (entry point) |
15//! | [`CommandEncoder`]   | Batched compute command submission |
16//! | [`MlxBuffer`]        | Typed Metal buffer with shape/dtype metadata |
17//! | [`MlxBufferPool`]    | Arena allocator with power-of-two bucketing |
18//! | [`KernelRegistry`]   | Lazy MSL compilation + pipeline cache |
19//! | [`DType`]            | Element data type enum |
20//! | [`MlxError`]         | Unified error type (never panics) |
21//!
22//! ## Quick Start
23//!
24//! ```ignore
25//! use mlx_native::{MlxDevice, DType};
26//!
27//! let device = MlxDevice::new()?;
28//! let buf = device.alloc_buffer(1024, DType::F32, vec![256])?;
29//! let encoder = device.command_encoder()?;
30//! ```
31//!
32//! ## Design Principles
33//!
34//! * **No panics** — all public APIs return `Result<T, MlxError>`.
35//! * **Zero-copy** — `StorageModeShared` buffers on Apple Silicon unified memory.
36//! * **Thread-safe** — `MlxDevice` and `MlxBuffer` are `Send + Sync`.
37//! * **Lazy compilation** — MSL shaders compiled on first use, then cached.
38
39// Enforce the no-panic policy at compile time.
40#![deny(clippy::panic, clippy::unwrap_used, clippy::expect_used)]
41// The `objc` crate's `msg_send!` macro internally checks `cfg(feature = "cargo-clippy")`
42// which triggers unexpected_cfgs warnings. Suppress at crate level since we can't
43// control the macro expansion site.
44#![allow(unexpected_cfgs)]
45
46// ---- internal modules ----
47#[macro_use]
48mod error;
49mod buffer;
50mod buffer_pool;
51mod device;
52mod dtypes;
53mod encoder;
54mod kernel_registry;
55pub mod gguf;
56pub mod graph;
57pub mod ops;
58pub mod turboquant;
59pub mod weight;
60
61// ---- public re-exports ----
62pub use buffer::MlxBuffer;
63pub use buffer_pool::MlxBufferPool;
64pub use device::MlxDevice;
65pub use dtypes::DType;
66pub use encoder::{
67    dispatch_count, reset_counters, sync_count, CapturedNode, CommandEncoder, DispatchKind,
68    RecordedBinding,
69};
70pub use error::{MlxError, Result};
71pub use graph::{ComputeGraph, GraphExecutor, GraphSession, OpKind};
72pub use kernel_registry::KernelRegistry;
73
74// Re-export GGUF parser.
75pub use gguf::{GgufFile, MetadataValue, TensorInfo};
76
77// Re-export ops.
78pub use ops::quantized_matmul::{quantized_matmul, quantized_matmul_simd, QuantizedMatmulParams};
79pub use ops::quantized_matmul_ggml::{
80    dispatch_mm_for_test, quantized_matmul_ggml, quantized_matmul_mm_tensor_perm021,
81    GgmlQuantizedMatmulParams, GgmlQuantizedMatmulPerm021Params, GgmlType,
82    MM_ROUTING_THRESHOLD,
83};
84pub use ops::quantized_matmul_id::{quantized_matmul_id, QuantizedMatmulIdParams};
85pub use ops::quantized_matmul_id_ggml::{
86    dispatch_id_mm_for_test, quantized_matmul_id_ggml, quantized_matmul_id_ggml_pooled,
87    GgmlIdMmDispatchParams, GgmlQuantizedMatmulIdParams, IdMmScratch,
88    MM_ID_ROUTING_THRESHOLD,
89};
90
91// Re-export weight loading utilities.
92pub use weight::{
93    load_quantized_weights, safetensors_to_metal_buffer, QuantizationConfig, QuantizedWeight,
94    SafetensorsFile, TensorQuantConfig,
95};
96
97// Re-export metal types that appear in the public API.
98pub use metal::MTLSize;
99pub use metal;
100
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    // ---- T10.7: compile-time Send + Sync assertions ----
    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}

    /// Never called at runtime; merely compiling this function proves the
    /// core public types satisfy `Send` and `Sync`.
    #[allow(dead_code)]
    fn assert_send_sync() {
        _assert_send::<MlxDevice>();
        _assert_send::<MlxBuffer>();
        _assert_send::<MlxError>();
        _assert_sync::<MlxDevice>();
        _assert_sync::<MlxBuffer>();
        _assert_sync::<MlxError>();
    }

    // ---- T10.1: device initialization ----
    #[test]
    fn test_device_init() {
        let dev = MlxDevice::new().expect("MlxDevice::new() should succeed on Apple Silicon");
        let name = dev.name();
        assert!(!name.is_empty(), "Device name should not be empty");
        println!("Metal device: {name}");
    }

    // ---- T10.2: buffer allocation ----
    #[test]
    fn test_buffer_alloc() {
        let dev = MlxDevice::new().expect("device");
        let dims = vec![2, 3, 4];
        // 24 elements * 4 bytes each = 96 bytes total.
        let bytes = 24 * DType::F32.size_of();
        let buf = dev
            .alloc_buffer(bytes, DType::F32, dims.clone())
            .expect("alloc_buffer");

        // Metadata must round-trip exactly as supplied.
        assert_eq!(buf.dtype(), DType::F32);
        assert_eq!(buf.byte_len(), bytes);
        assert_eq!(buf.shape(), &dims);
        assert_eq!(buf.element_count(), 24);
    }

    // ---- T10.3: buffer read/write round-trip ----
    #[test]
    fn test_buffer_readwrite() {
        let dev = MlxDevice::new().expect("device");
        let count = 64;
        let mut buf = dev
            .alloc_buffer(count * std::mem::size_of::<f32>(), DType::F32, vec![count])
            .expect("alloc_buffer");

        // Fill the buffer with a known ramp pattern.
        let data: &mut [f32] = buf.as_mut_slice().expect("as_mut_slice");
        assert_eq!(data.len(), count);
        for (i, slot) in data.iter_mut().enumerate() {
            *slot = i as f32 * 1.5;
        }

        // Read it back and compare element-wise. The products are exactly
        // representable, so the epsilon tolerance is effectively exact.
        for (i, &val) in buf.as_slice().expect("as_slice").iter().enumerate() {
            let expected = i as f32 * 1.5;
            assert!(
                (val - expected).abs() < f32::EPSILON,
                "Mismatch at index {i}: got {val}, expected {expected}"
            );
        }
    }

    // ---- T10.4: encoder lifecycle ----
    #[test]
    fn test_encoder_lifecycle() {
        let dev = MlxDevice::new().expect("device");
        // Committing with nothing encoded is a GPU no-op that must still succeed.
        let mut enc = dev.command_encoder().expect("command_encoder");
        enc.commit_and_wait()
            .expect("commit_and_wait on empty encoder");
    }

    // ---- T10.5: buffer pool reuse ----
    #[test]
    fn test_buffer_pool_reuse() {
        let dev = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new(&dev);

        // First allocation: record its identity, then hand it back.
        let first = pool
            .alloc(1024, DType::F32, vec![256])
            .expect("pool alloc 1");
        let (first_ptr, first_len) = (first.contents_ptr(), first.byte_len());
        pool.release(first);
        assert_eq!(pool.free_count(), 1);

        // A second same-sized request must be served from the free list.
        let second = pool
            .alloc(1024, DType::F32, vec![256])
            .expect("pool alloc 2");

        assert_eq!(
            first_ptr,
            second.contents_ptr(),
            "Pool should reuse the same Metal buffer"
        );
        assert_eq!(first_len, second.byte_len(), "Byte lengths should match");
        assert_eq!(pool.free_count(), 0, "Free list should be empty after reuse");
    }

    // ---- T10.6: kernel registry caching ----
    #[test]
    fn test_kernel_registry_caching() {
        let dev = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();

        // Register a minimal test kernel.
        registry.register_source(
            "test_add",
            r#"
            #include <metal_stdlib>
            using namespace metal;
            kernel void test_add(
                device float *a [[buffer(0)]],
                device float *b [[buffer(1)]],
                device float *c [[buffer(2)]],
                uint id [[thread_position_in_grid]]
            ) {
                c[id] = a[id] + b[id];
            }
            "#,
        );

        assert!(!registry.is_cached("test_add"));

        // First request compiles the shader and populates the cache.
        let first = registry
            .get_pipeline("test_add", dev.metal_device())
            .expect("get_pipeline first call") as *const _;
        assert!(registry.is_cached("test_add"));

        // Second request must hit the cache and hand back the same object.
        let second = registry
            .get_pipeline("test_add", dev.metal_device())
            .expect("get_pipeline second call") as *const _;

        assert_eq!(
            first, second,
            "Second get_pipeline call should return the same cached pipeline"
        );
    }

    // ---- Additional: test alloc_buffer with zero length returns error ----
    #[test]
    fn test_buffer_alloc_zero_len_error() {
        let dev = MlxDevice::new().expect("device");
        let outcome = dev.alloc_buffer(0, DType::F32, vec![]);
        assert!(outcome.is_err(), "Zero-length allocation should fail");
        match outcome {
            Err(MlxError::InvalidArgument(_)) => {}
            unexpected => panic!("Expected InvalidArgument, got {:?}", unexpected),
        }
    }

    // ---- Additional: test kernel not found ----
    #[test]
    fn test_kernel_not_found() {
        let dev = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let outcome = registry.get_pipeline("nonexistent_kernel", dev.metal_device());
        assert!(outcome.is_err());
        match outcome {
            Err(MlxError::KernelNotFound(name)) => assert_eq!(name, "nonexistent_kernel"),
            unexpected => panic!("Expected KernelNotFound, got {:?}", unexpected),
        }
    }

    // ---- Additional: test DType properties ----
    #[test]
    fn test_dtype_sizes() {
        // Table of (dtype, expected element width in bytes).
        let expectations = [
            (DType::F32, 4),
            (DType::F16, 2),
            (DType::BF16, 2),
            (DType::U8, 1),
            (DType::U16, 2),
            (DType::U32, 4),
            (DType::I32, 4),
        ];
        for (dtype, width) in expectations {
            assert_eq!(dtype.size_of(), width);
        }
    }

    // ---- Additional: test MlxBuffer Debug ----
    #[test]
    fn test_buffer_debug() {
        let dev = MlxDevice::new().expect("device");
        let buf = dev
            .alloc_buffer(64, DType::F16, vec![4, 8])
            .expect("alloc_buffer");
        let rendered = format!("{:?}", buf);
        // The Debug form should surface the type name, dtype, and shape.
        for needle in ["MlxBuffer", "F16", "[4, 8]"] {
            assert!(rendered.contains(needle));
        }
    }

    // ---- Additional: test MlxError Display ----
    #[test]
    fn test_error_display() {
        let e = MlxError::DeviceNotFound;
        assert!(format!("{e}").contains("Metal GPU device"));

        let e = MlxError::ShaderCompilationError {
            name: "foo".into(),
            message: "syntax error".into(),
        };
        let rendered = format!("{e}");
        assert!(rendered.contains("foo"));
        assert!(rendered.contains("syntax error"));
    }

    // ---- Additional: test buffer pool with different sizes ----
    #[test]
    fn test_buffer_pool_size_buckets() {
        let dev = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new(&dev);

        // A 100-byte request rounds up to the 128-byte bucket.
        let small = pool.alloc(100, DType::U8, vec![100]).expect("alloc 100");
        assert!(
            small.byte_len() >= 100,
            "Buffer should be at least 100 bytes"
        );
        pool.release(small);

        // An exact 128-byte request lands in that same bucket.
        let exact = pool.alloc(128, DType::U8, vec![128]).expect("alloc 128");
        assert!(exact.byte_len() >= 128);
        pool.release(exact);

        // 200 bytes rounds up to a distinct 256-byte bucket.
        let medium = pool.alloc(200, DType::U8, vec![200]).expect("alloc 200");
        assert!(medium.byte_len() >= 200);
        pool.release(medium);

        assert_eq!(pool.free_count(), 2, "Two different bucket sizes in pool");
    }
}