1#![deny(clippy::panic, clippy::unwrap_used, clippy::expect_used)]
41#![allow(unexpected_cfgs)]
45
46#[macro_use]
48mod error;
49mod buffer;
50mod buffer_pool;
51mod device;
52mod dtypes;
53mod encoder;
54mod kernel_registry;
55pub mod gguf;
56pub mod graph;
57pub mod ops;
58pub mod turboquant;
59pub mod weight;
60
61pub use buffer::MlxBuffer;
63pub use buffer_pool::MlxBufferPool;
64pub use device::MlxDevice;
65pub use dtypes::DType;
66pub use encoder::{
67 dispatch_count, reset_counters, sync_count, CapturedNode, CommandEncoder, DispatchKind,
68 RecordedBinding,
69};
70pub use error::{MlxError, Result};
71pub use graph::{ComputeGraph, GraphExecutor, GraphSession, OpKind};
72pub use kernel_registry::KernelRegistry;
73
74pub use gguf::{GgufFile, MetadataValue, TensorInfo};
76
77pub use ops::quantized_matmul::{quantized_matmul, quantized_matmul_simd, QuantizedMatmulParams};
79pub use ops::quantized_matmul_ggml::{
80 dispatch_mm_for_test, quantized_matmul_ggml, quantized_matmul_mm_tensor_perm021,
81 GgmlQuantizedMatmulParams, GgmlQuantizedMatmulPerm021Params, GgmlType,
82 MM_ROUTING_THRESHOLD,
83};
84pub use ops::quantized_matmul_id::{quantized_matmul_id, QuantizedMatmulIdParams};
85pub use ops::quantized_matmul_id_ggml::{
86 dispatch_id_mm_for_test, quantized_matmul_id_ggml, quantized_matmul_id_ggml_pooled,
87 GgmlIdMmDispatchParams, GgmlQuantizedMatmulIdParams, IdMmScratch,
88 MM_ID_ROUTING_THRESHOLD,
89};
90
91pub use weight::{
93 load_quantized_weights, safetensors_to_metal_buffer, QuantizationConfig, QuantizedWeight,
94 SafetensorsFile, TensorQuantConfig,
95};
96
97pub use metal::MTLSize;
99pub use metal;
100
// Smoke tests for the crate's public surface. Almost all of them create an
// `MlxDevice`, so they need a Metal-capable machine to pass.
#[cfg(test)]
// The crate root denies `panic`/`unwrap`/`expect`; tests opt back in because a
// failing `expect`/`panic!` is the idiomatic way to fail a test.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}

    /// Compile-time check that the core handle types are `Send + Sync`.
    ///
    /// Nothing calls this function: the trait bounds are enforced when its
    /// body is type-checked, so building the test target is the assertion.
    /// `dead_code` is allowed for exactly that reason.
    #[allow(dead_code)]
    fn assert_send_sync() {
        _assert_send::<MlxDevice>();
        _assert_sync::<MlxDevice>();
        _assert_send::<MlxBuffer>();
        _assert_sync::<MlxBuffer>();
        _assert_send::<MlxError>();
        _assert_sync::<MlxError>();
    }

    /// A Metal device can be created and reports a non-empty name.
    #[test]
    fn test_device_init() {
        let device = MlxDevice::new().expect("MlxDevice::new() should succeed on Apple Silicon");
        let name = device.name();
        assert!(!name.is_empty(), "Device name should not be empty");
        println!("Metal device: {name}");
    }

    /// A fresh allocation reports the dtype, shape, and byte length it was
    /// created with.
    #[test]
    fn test_buffer_alloc() {
        let device = MlxDevice::new().expect("device");
        let shape = vec![2, 3, 4];
        // Derive the byte length from the shape so the two cannot drift apart.
        let byte_len = shape.iter().product::<usize>() * DType::F32.size_of();
        let buf = device
            .alloc_buffer(byte_len, DType::F32, shape.clone())
            .expect("alloc_buffer");

        assert_eq!(buf.dtype(), DType::F32);
        assert_eq!(buf.shape(), &shape);
        assert_eq!(buf.byte_len(), byte_len);
        assert_eq!(buf.element_count(), 24);
    }

    /// Values written through `as_mut_slice` read back unchanged through
    /// `as_slice`.
    #[test]
    fn test_buffer_readwrite() {
        let device = MlxDevice::new().expect("device");
        let n = 64;
        let byte_len = n * std::mem::size_of::<f32>();
        let mut buf = device
            .alloc_buffer(byte_len, DType::F32, vec![n])
            .expect("alloc_buffer");

        // Scope the mutable view so the read-only view below can be taken.
        {
            let slice: &mut [f32] = buf.as_mut_slice().expect("as_mut_slice");
            assert_eq!(slice.len(), n);
            for (i, val) in slice.iter_mut().enumerate() {
                *val = i as f32 * 1.5;
            }
        }

        let slice: &[f32] = buf.as_slice().expect("as_slice");
        // Without this, a short or empty read-back view would make the loop
        // below pass vacuously.
        assert_eq!(slice.len(), n);
        for (i, &val) in slice.iter().enumerate() {
            let expected = i as f32 * 1.5;
            // Reading back from the same buffer must be bit-exact, so compare
            // bit patterns rather than using an (ill-scaled) epsilon tolerance.
            assert_eq!(
                val.to_bits(),
                expected.to_bits(),
                "Mismatch at index {i}: got {val}, expected {expected}"
            );
        }
    }

    /// An encoder with no recorded work can still be committed and awaited.
    #[test]
    fn test_encoder_lifecycle() {
        let device = MlxDevice::new().expect("device");
        let mut enc = device.command_encoder().expect("command_encoder");
        enc.commit_and_wait().expect("commit_and_wait on empty encoder");
    }

    /// Releasing a buffer and re-requesting the same size returns the same
    /// underlying allocation and drains the free list.
    #[test]
    fn test_buffer_pool_reuse() {
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new(&device);

        let buf1 = pool.alloc(1024, DType::F32, vec![256]).expect("pool alloc 1");
        // `release` receives the buffer by value, so record its identity first.
        let buf1_ptr = buf1.contents_ptr();
        let buf1_byte_len = buf1.byte_len();

        pool.release(buf1);
        assert_eq!(pool.free_count(), 1);

        let buf2 = pool.alloc(1024, DType::F32, vec![256]).expect("pool alloc 2");
        let buf2_ptr = buf2.contents_ptr();
        let buf2_byte_len = buf2.byte_len();

        assert_eq!(buf1_ptr, buf2_ptr, "Pool should reuse the same Metal buffer");
        assert_eq!(buf1_byte_len, buf2_byte_len, "Byte lengths should match");
        assert_eq!(pool.free_count(), 0, "Free list should be empty after reuse");
    }

    /// The first `get_pipeline` call compiles and caches a registered kernel;
    /// the second returns the very same cached pipeline object.
    #[test]
    fn test_kernel_registry_caching() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();

        registry.register_source(
            "test_add",
            r#"
#include <metal_stdlib>
using namespace metal;
kernel void test_add(
    device float *a [[buffer(0)]],
    device float *b [[buffer(1)]],
    device float *c [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    c[id] = a[id] + b[id];
}
"#,
        );

        assert!(!registry.is_cached("test_add"));

        // Compare by address. Converting to raw pointers ends each borrow of
        // `registry` immediately (presumably `get_pipeline` takes `&mut self`).
        let p1 = registry
            .get_pipeline("test_add", device.metal_device())
            .expect("get_pipeline first call");
        let p1_ptr = p1 as *const _;
        assert!(registry.is_cached("test_add"));

        let p2 = registry
            .get_pipeline("test_add", device.metal_device())
            .expect("get_pipeline second call");
        let p2_ptr = p2 as *const _;

        assert_eq!(
            p1_ptr, p2_ptr,
            "Second get_pipeline call should return the same cached pipeline"
        );
    }

    /// Requesting a zero-byte buffer is rejected with `InvalidArgument`.
    #[test]
    fn test_buffer_alloc_zero_len_error() {
        let device = MlxDevice::new().expect("device");
        // The catch-all arm also fails the test on `Ok`, so no separate
        // `is_err()` check is needed.
        match device.alloc_buffer(0, DType::F32, vec![]) {
            Err(MlxError::InvalidArgument(_)) => {}
            other => panic!("Expected InvalidArgument, got {:?}", other),
        }
    }

    /// Looking up an unregistered kernel yields `KernelNotFound` carrying the
    /// requested name.
    #[test]
    fn test_kernel_not_found() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        match registry.get_pipeline("nonexistent_kernel", device.metal_device()) {
            Err(MlxError::KernelNotFound(name)) => {
                assert_eq!(name, "nonexistent_kernel");
            }
            other => panic!("Expected KernelNotFound, got {:?}", other),
        }
    }

    /// Element sizes in bytes for each dtype.
    #[test]
    fn test_dtype_sizes() {
        assert_eq!(DType::F32.size_of(), 4);
        assert_eq!(DType::F16.size_of(), 2);
        assert_eq!(DType::BF16.size_of(), 2);
        assert_eq!(DType::U8.size_of(), 1);
        assert_eq!(DType::U16.size_of(), 2);
        assert_eq!(DType::U32.size_of(), 4);
        assert_eq!(DType::I32.size_of(), 4);
    }

    /// `Debug` output names the type and includes the dtype and shape.
    #[test]
    fn test_buffer_debug() {
        let device = MlxDevice::new().expect("device");
        // 4 * 8 elements * 2 bytes per F16 = 64 bytes.
        let buf = device
            .alloc_buffer(64, DType::F16, vec![4, 8])
            .expect("alloc_buffer");
        let debug_str = format!("{:?}", buf);
        assert!(debug_str.contains("MlxBuffer"));
        assert!(debug_str.contains("F16"));
        assert!(debug_str.contains("[4, 8]"));
    }

    /// `Display` messages carry the relevant details of each error variant.
    #[test]
    fn test_error_display() {
        let msg = MlxError::DeviceNotFound.to_string();
        assert!(msg.contains("Metal GPU device"));

        let msg = MlxError::ShaderCompilationError {
            name: "foo".into(),
            message: "syntax error".into(),
        }
        .to_string();
        assert!(msg.contains("foo"));
        assert!(msg.contains("syntax error"));
    }

    /// Allocations are served from size buckets: each returned buffer is at
    /// least the requested size. The final free-list count expects exactly one
    /// of the three allocations to reuse a released buffer. Presumably 100 and
    /// 128 round to the same bucket while 200 does not; confirm against
    /// `MlxBufferPool`.
    #[test]
    fn test_buffer_pool_size_buckets() {
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new(&device);

        let buf_100 = pool.alloc(100, DType::U8, vec![100]).expect("alloc 100");
        assert!(
            buf_100.byte_len() >= 100,
            "Buffer should be at least 100 bytes"
        );
        pool.release(buf_100);

        let buf_128 = pool.alloc(128, DType::U8, vec![128]).expect("alloc 128");
        assert!(buf_128.byte_len() >= 128);
        pool.release(buf_128);

        let buf_200 = pool.alloc(200, DType::U8, vec![200]).expect("alloc 200");
        assert!(buf_200.byte_len() >= 200);
        pool.release(buf_200);

        assert_eq!(pool.free_count(), 2, "Two different bucket sizes in pool");
    }
}