1#![deny(clippy::panic, clippy::unwrap_used, clippy::expect_used)]
41#![allow(unexpected_cfgs)]
45
46#[macro_use]
48mod error;
49mod buffer;
50mod buffer_pool;
51mod device;
52mod dtypes;
53mod encoder;
54mod kernel_registry;
55mod mem_ranges;
56mod residency;
57pub mod gguf;
58pub mod kernel_profile;
59pub mod graph;
60pub mod metal_capture;
61pub mod ops;
62pub mod turboquant;
63pub mod weight;
64
65pub use buffer::MlxBuffer;
67pub use buffer_pool::MlxBufferPool;
68pub use device::MlxDevice;
69pub use dtypes::DType;
70pub use encoder::{
71 auto_barrier_concurrent_count, auto_barrier_count, barrier_count, barrier_total_ns,
72 cmd_buf_count, dispatch_count, reset_counters, sync_count, CapturedNode, CapturedOpKind,
73 CommandEncoder, DispatchKind, KernelArg, RecordedBinding,
74};
75pub use mem_ranges::{BufferRange, MemRangeRole, MemRanges};
76pub use error::{MlxError, Result};
77pub use graph::{ComputeGraph, GraphExecutor, GraphSession, OpKind};
78pub use kernel_registry::KernelRegistry;
79#[doc(hidden)]
85pub use residency::{
86 macos_15_or_newer_for_test, reset_residency_env_cache_for_test,
87 reset_residency_test_counters, residency_allocation_count_for_test,
88 residency_commit_call_count_for_test,
89};
90
91pub use gguf::{GgufFile, MetadataValue, TensorInfo};
93
94pub use ops::dense_mm_bf16::{dense_matmul_bf16_f32_tensor, DenseMmBf16F32Params};
96pub use ops::dense_mm_f16::{dense_matmul_f16_f32_tensor, DenseMmF16F32Params};
97pub use ops::dense_mm_f32_f32::{dense_matmul_f32_f32_tensor, DenseMmF32F32Params};
98pub use ops::quantized_matmul::{quantized_matmul, quantized_matmul_simd, QuantizedMatmulParams};
99pub use ops::quantized_matmul_ggml::{
100 dispatch_mm_for_test, quantized_matmul_ggml, quantized_matmul_mm_tensor_perm021,
101 GgmlQuantizedMatmulParams, GgmlQuantizedMatmulPerm021Params, GgmlType,
102 MM_ROUTING_THRESHOLD,
103};
104pub use ops::quantized_matmul_id::{quantized_matmul_id, QuantizedMatmulIdParams};
105pub use ops::quantized_matmul_id_ggml::{
106 dispatch_id_mm_for_test, quantized_matmul_id_ggml, quantized_matmul_id_ggml_pooled,
107 quantized_matmul_id_swiglu_q4_0,
108 GgmlIdMmDispatchParams, GgmlQuantizedMatmulIdParams, IdMmScratch,
109 MM_ID_ROUTING_THRESHOLD,
110};
111
112pub use weight::{
114 load_quantized_weights, safetensors_to_metal_buffer, QuantizationConfig, QuantizedWeight,
115 SafetensorsFile, TensorQuantConfig,
116};
117
118pub use metal::MTLSize;
120pub use metal;
121
#[cfg(test)]
// Tests may unwrap/expect/panic freely: a failed expectation is exactly the
// failure signal wanted here, and the crate-level `deny` at the top of this
// file would otherwise reject these calls.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    //! Smoke tests for the crate's core types: device, buffers, buffer pool,
    //! command encoder, kernel registry, dtypes and errors.
    //!
    //! Most tests call `MlxDevice::new()` and therefore need a Metal-capable
    //! GPU to run.

    use super::*;

    /// Type-level helper: only compiles when `T: Send`.
    fn _assert_send<T: Send>() {}

    /// Type-level helper: only compiles when `T: Sync`.
    fn _assert_sync<T: Sync>() {}

    /// Checks that the core public handle types are `Send + Sync`.
    ///
    /// The trait bounds are enforced when this function is type-checked, so a
    /// regression shows up as a compile error. Registering it as a `#[test]`
    /// keeps it "used" without needing `#[allow(dead_code)]`.
    #[test]
    fn assert_send_sync() {
        _assert_send::<MlxDevice>();
        _assert_sync::<MlxDevice>();
        _assert_send::<MlxBuffer>();
        _assert_sync::<MlxBuffer>();
        _assert_send::<MlxError>();
        _assert_sync::<MlxError>();
    }

    /// A Metal device can be created and reports a non-empty name.
    #[test]
    fn test_device_init() {
        let device = MlxDevice::new().expect("MlxDevice::new() should succeed on Apple Silicon");
        let name = device.name();
        assert!(!name.is_empty(), "Device name should not be empty");
        println!("Metal device: {name}");
    }

    /// A freshly allocated buffer reports the dtype, shape, byte length and
    /// element count it was created with.
    #[test]
    fn test_buffer_alloc() {
        let device = MlxDevice::new().expect("device");
        let shape = vec![2, 3, 4];
        let byte_len = 2 * 3 * 4 * DType::F32.size_of();
        let buf = device
            .alloc_buffer(byte_len, DType::F32, shape.clone())
            .expect("alloc_buffer");

        assert_eq!(buf.dtype(), DType::F32);
        assert_eq!(buf.shape(), &shape);
        assert_eq!(buf.byte_len(), byte_len);
        assert_eq!(buf.element_count(), 24);
    }

    /// Data written through `as_mut_slice` is read back unchanged through
    /// `as_slice`.
    #[test]
    fn test_buffer_readwrite() {
        let device = MlxDevice::new().expect("device");
        let n = 64;
        let byte_len = n * std::mem::size_of::<f32>();
        let mut buf = device
            .alloc_buffer(byte_len, DType::F32, vec![n])
            .expect("alloc_buffer");

        // Scope the mutable view so it is released before the read-back below.
        {
            let slice: &mut [f32] = buf.as_mut_slice().expect("as_mut_slice");
            assert_eq!(slice.len(), n);
            for (i, val) in slice.iter_mut().enumerate() {
                *val = i as f32 * 1.5;
            }
        }

        {
            let slice: &[f32] = buf.as_slice().expect("as_slice");
            assert_eq!(slice.len(), n);
            for (i, &val) in slice.iter().enumerate() {
                let expected = i as f32 * 1.5;
                // Nothing transforms the data between write and read, so the
                // round trip must be bit-exact. An absolute `f32::EPSILON`
                // tolerance is meaningless for magnitudes above 1.0.
                assert_eq!(
                    val.to_bits(),
                    expected.to_bits(),
                    "Mismatch at index {i}: got {val}, expected {expected}"
                );
            }
        }
    }

    /// Committing an encoder that recorded no work succeeds.
    #[test]
    fn test_encoder_lifecycle() {
        let device = MlxDevice::new().expect("device");
        let mut enc = device.command_encoder().expect("command_encoder");
        enc.commit_and_wait()
            .expect("commit_and_wait on empty encoder");
    }

    /// Releasing a buffer to the pool and allocating the same size again hands
    /// back the same underlying Metal buffer.
    #[test]
    fn test_buffer_pool_reuse() {
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();

        let buf1 = pool
            .alloc(&device, 1024, DType::F32, vec![256])
            .expect("pool alloc 1");
        let buf1_ptr = buf1.contents_ptr();
        let buf1_byte_len = buf1.byte_len();

        pool.release(buf1);
        assert_eq!(pool.free_count(), 1);

        let buf2 = pool
            .alloc(&device, 1024, DType::F32, vec![256])
            .expect("pool alloc 2");
        let buf2_ptr = buf2.contents_ptr();
        let buf2_byte_len = buf2.byte_len();

        assert_eq!(buf1_ptr, buf2_ptr, "Pool should reuse the same Metal buffer");
        assert_eq!(buf1_byte_len, buf2_byte_len, "Byte lengths should match");
        assert_eq!(pool.free_count(), 0, "Free list should be empty after reuse");
    }

    /// A registered kernel is compiled on first lookup and served from the
    /// cache (same pipeline object) on subsequent lookups.
    #[test]
    fn test_kernel_registry_caching() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();

        registry.register_source(
            "test_add",
            r#"
        #include <metal_stdlib>
        using namespace metal;
        kernel void test_add(
            device float *a [[buffer(0)]],
            device float *b [[buffer(1)]],
            device float *c [[buffer(2)]],
            uint id [[thread_position_in_grid]]
        ) {
            c[id] = a[id] + b[id];
        }
        "#,
        );

        assert!(!registry.is_cached("test_add"));
        let p1 = registry
            .get_pipeline("test_add", device.metal_device())
            .expect("get_pipeline first call");
        // Keep only the address so the returned reference is not held across
        // the following registry calls.
        let p1_ptr = p1 as *const _;
        assert!(registry.is_cached("test_add"));

        let p2 = registry
            .get_pipeline("test_add", device.metal_device())
            .expect("get_pipeline second call");
        let p2_ptr = p2 as *const _;

        assert_eq!(
            p1_ptr, p2_ptr,
            "Second get_pipeline call should return the same cached pipeline"
        );
    }

    /// Allocating zero bytes is rejected with `MlxError::InvalidArgument`.
    #[test]
    fn test_buffer_alloc_zero_len_error() {
        let device = MlxDevice::new().expect("device");
        let result = device.alloc_buffer(0, DType::F32, vec![]);
        assert!(result.is_err(), "Zero-length allocation should fail");
        match result {
            Err(MlxError::InvalidArgument(_)) => {}
            other => panic!("Expected InvalidArgument, got {:?}", other),
        }
    }

    /// Looking up a kernel that was never registered yields
    /// `MlxError::KernelNotFound` carrying the requested name.
    #[test]
    fn test_kernel_not_found() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let result = registry.get_pipeline("nonexistent_kernel", device.metal_device());
        assert!(result.is_err());
        match result {
            Err(MlxError::KernelNotFound(name)) => {
                assert_eq!(name, "nonexistent_kernel");
            }
            other => panic!("Expected KernelNotFound, got {:?}", other),
        }
    }

    /// Per-element byte sizes of every `DType` variant exercised here.
    #[test]
    fn test_dtype_sizes() {
        assert_eq!(DType::F32.size_of(), 4);
        assert_eq!(DType::F16.size_of(), 2);
        assert_eq!(DType::BF16.size_of(), 2);
        assert_eq!(DType::U8.size_of(), 1);
        assert_eq!(DType::U16.size_of(), 2);
        assert_eq!(DType::U32.size_of(), 4);
        assert_eq!(DType::I32.size_of(), 4);
    }

    /// The `Debug` output of a buffer names the type, its dtype and its shape.
    #[test]
    fn test_buffer_debug() {
        let device = MlxDevice::new().expect("device");
        let buf = device
            .alloc_buffer(64, DType::F16, vec![4, 8])
            .expect("alloc_buffer");
        let debug_str = format!("{:?}", buf);
        assert!(debug_str.contains("MlxBuffer"));
        assert!(debug_str.contains("F16"));
        assert!(debug_str.contains("[4, 8]"));
    }

    /// `Display` for errors includes the relevant context (device text,
    /// shader name and compiler message).
    #[test]
    fn test_error_display() {
        let e = MlxError::DeviceNotFound;
        assert!(format!("{e}").contains("Metal GPU device"));

        let e = MlxError::ShaderCompilationError {
            name: "foo".into(),
            message: "syntax error".into(),
        };
        assert!(format!("{e}").contains("foo"));
        assert!(format!("{e}").contains("syntax error"));
    }

    /// Requests of different sizes are grouped into buckets by the pool.
    ///
    /// After the 100-byte and 128-byte requests, the free list holds a single
    /// buffer, and the 200-byte request adds a second one. This relies on the
    /// pool rounding requests up to bucket sizes (presumably powers of two;
    /// confirm against `buffer_pool`).
    #[test]
    fn test_buffer_pool_size_buckets() {
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();

        let buf_100 = pool.alloc(&device, 100, DType::U8, vec![100]).expect("alloc 100");
        assert!(
            buf_100.byte_len() >= 100,
            "Buffer should be at least 100 bytes"
        );
        pool.release(buf_100);

        let buf_128 = pool.alloc(&device, 128, DType::U8, vec![128]).expect("alloc 128");
        assert!(buf_128.byte_len() >= 128);
        pool.release(buf_128);

        let buf_200 = pool.alloc(&device, 200, DType::U8, vec![200]).expect("alloc 200");
        assert!(buf_200.byte_len() >= 200);
        pool.release(buf_200);

        assert_eq!(pool.free_count(), 2, "Two different bucket sizes in pool");
    }
}