1#![deny(clippy::panic, clippy::unwrap_used, clippy::expect_used)]
41#![allow(unexpected_cfgs)]
45
46#[macro_use]
48mod error;
49mod buffer;
50mod buffer_pool;
51mod device;
52mod dtypes;
53mod encoder;
54mod encoder_session;
55mod kernel_registry;
56mod mem_ranges;
57mod residency;
58pub mod gguf;
59pub mod kernel_profile;
60pub mod graph;
61pub mod metal_capture;
62pub mod ops;
63pub mod turboquant;
64pub mod weight;
65
66pub use buffer::MlxBuffer;
68pub use buffer_pool::MlxBufferPool;
69pub use device::MlxDevice;
70pub use dtypes::DType;
71pub use encoder::{
72 auto_barrier_concurrent_count, auto_barrier_count, barrier_count, barrier_total_ns,
73 cmd_buf_count, dispatch_count, reset_counters, sync_count, CapturedNode, CapturedOpKind,
74 CommandEncoder, DispatchKind, KernelArg, RecordedBinding,
75};
76pub use encoder_session::EncoderSession;
77pub use mem_ranges::{BufferRange, MemRangeRole, MemRanges};
78pub use error::{MlxError, Result};
79pub use graph::{ComputeGraph, GraphExecutor, GraphSession, OpKind};
80pub use kernel_registry::KernelRegistry;
81#[doc(hidden)]
87pub use residency::{
88 macos_15_or_newer_for_test, reset_residency_env_cache_for_test,
89 reset_residency_test_counters, residency_allocation_count_for_test,
90 residency_commit_call_count_for_test,
91};
92
93pub use gguf::{GgufFile, MetadataValue, TensorInfo};
95
96pub use ops::dense_mm_bf16::{dense_matmul_bf16_f32_tensor, DenseMmBf16F32Params};
98pub use ops::dense_mm_f16::{dense_matmul_f16_f32_tensor, DenseMmF16F32Params};
99pub use ops::dense_mm_f32_f32::{dense_matmul_f32_f32_tensor, DenseMmF32F32Params};
100pub use ops::quantized_matmul::{quantized_matmul, quantized_matmul_simd, QuantizedMatmulParams};
101pub use ops::quantized_matmul_ggml::{
102 dispatch_mm_for_test, quantized_matmul_ggml, quantized_matmul_mm_tensor_perm021,
103 GgmlQuantizedMatmulParams, GgmlQuantizedMatmulPerm021Params, GgmlType,
104 MM_ROUTING_THRESHOLD,
105};
106pub use ops::quantized_matmul_id::{quantized_matmul_id, QuantizedMatmulIdParams};
107pub use ops::quantized_matmul_id_ggml::{
108 dispatch_id_mm_for_test, quantized_matmul_id_ggml, quantized_matmul_id_ggml_pooled,
109 quantized_matmul_id_swiglu_q4_0,
110 GgmlIdMmDispatchParams, GgmlQuantizedMatmulIdParams, IdMmScratch,
111 MM_ID_ROUTING_THRESHOLD,
112};
113
114pub use weight::{
116 load_quantized_weights, safetensors_to_metal_buffer, QuantizationConfig, QuantizedWeight,
117 SafetensorsFile, TensorQuantConfig,
118};
119
120pub use metal::MTLSize;
122pub use metal;
123
#[cfg(test)]
// The crate root denies `clippy::panic`, `clippy::unwrap_used` and `clippy::expect_used`.
// Inside tests a failed expectation *is* the failure report, so the lints are relaxed here.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    //! Smoke tests for the crate's public surface: device creation, buffer
    //! allocation and access, command encoders, the buffer pool, the kernel
    //! registry, dtype sizes and error formatting.
    //!
    //! Every test except `test_dtype_sizes` and `test_error_display` constructs
    //! an [`MlxDevice`] and fails if `MlxDevice::new()` returns an error.

    use super::*;

    /// Creates the device shared by most tests; panics with `"device"` when
    /// `MlxDevice::new()` reports an error.
    fn open_device() -> MlxDevice {
        MlxDevice::new().expect("device")
    }

    /// Compile-time probe that can only be instantiated for `Send` types.
    fn _assert_send<T: Send>() {}

    /// Compile-time probe that can only be instantiated for `Sync` types.
    fn _assert_sync<T: Sync>() {}

    /// Never invoked: the body merely has to type-check, which proves the
    /// listed public types are both `Send` and `Sync`.
    #[allow(dead_code)]
    fn assert_send_sync() {
        _assert_send::<MlxDevice>();
        _assert_send::<MlxBuffer>();
        _assert_send::<MlxError>();

        _assert_sync::<MlxDevice>();
        _assert_sync::<MlxBuffer>();
        _assert_sync::<MlxError>();
    }

    /// A device can be created and reports a non-empty name.
    #[test]
    fn test_device_init() {
        let dev = MlxDevice::new().expect("MlxDevice::new() should succeed on Apple Silicon");
        let device_name = dev.name();
        assert!(!device_name.is_empty(), "Device name should not be empty");
        println!("Metal device: {device_name}");
    }

    /// An allocated buffer echoes back the dtype, shape, byte length and
    /// element count it was created with.
    #[test]
    fn test_buffer_alloc() {
        let dev = open_device();
        let dims = vec![2, 3, 4];
        let elem_size = DType::F32.size_of();
        let total_bytes = 2 * 3 * 4 * elem_size;

        let buffer = dev
            .alloc_buffer(total_bytes, DType::F32, dims.clone())
            .expect("alloc_buffer");

        assert_eq!(buffer.dtype(), DType::F32);
        assert_eq!(buffer.shape(), &dims);
        assert_eq!(buffer.byte_len(), total_bytes);
        assert_eq!(buffer.element_count(), 24);
    }

    /// Values written through the mutable slice view are read back unchanged
    /// through the shared slice view.
    #[test]
    fn test_buffer_readwrite() {
        let dev = open_device();
        let count = 64;
        let total_bytes = count * std::mem::size_of::<f32>();
        let mut buffer = dev
            .alloc_buffer(total_bytes, DType::F32, vec![count])
            .expect("alloc_buffer");

        // Fill: element `idx` receives `idx * 1.5`.
        let writable: &mut [f32] = buffer.as_mut_slice().expect("as_mut_slice");
        assert_eq!(writable.len(), count);
        writable
            .iter_mut()
            .enumerate()
            .for_each(|(idx, slot)| *slot = idx as f32 * 1.5);

        // Verify: every element still holds the value stored above.
        let readable: &[f32] = buffer.as_slice().expect("as_slice");
        for (idx, got) in readable.iter().copied().enumerate() {
            let want = idx as f32 * 1.5;
            assert!(
                (got - want).abs() < f32::EPSILON,
                "Mismatch at index {idx}: got {got}, expected {want}"
            );
        }
    }

    /// Committing an encoder that has recorded nothing succeeds.
    #[test]
    fn test_encoder_lifecycle() {
        let dev = open_device();
        let mut encoder = dev.command_encoder().expect("command_encoder");
        let outcome = encoder.commit_and_wait();
        outcome.expect("commit_and_wait on empty encoder");
    }

    /// A released buffer is handed out again for an identical request.
    #[test]
    fn test_buffer_pool_reuse() {
        let dev = open_device();
        let mut pool = MlxBufferPool::new();

        let first = pool
            .alloc(&dev, 1024, DType::F32, vec![256])
            .expect("pool alloc 1");
        let (first_ptr, first_len) = (first.contents_ptr(), first.byte_len());

        // Returning the buffer should leave exactly one entry on the free list.
        pool.release(first);
        assert_eq!(pool.free_count(), 1);

        let second = pool
            .alloc(&dev, 1024, DType::F32, vec![256])
            .expect("pool alloc 2");
        let (second_ptr, second_len) = (second.contents_ptr(), second.byte_len());

        assert_eq!(first_ptr, second_ptr, "Pool should reuse the same Metal buffer");
        assert_eq!(first_len, second_len, "Byte lengths should match");
        assert_eq!(pool.free_count(), 0, "Free list should be empty after reuse");
    }

    /// The registry reports a kernel as cached only after the first
    /// `get_pipeline` call, and a second call yields the same pipeline address.
    #[test]
    fn test_kernel_registry_caching() {
        let dev = open_device();
        let mut registry = KernelRegistry::new();

        registry.register_source(
            "test_add",
            r#"
            #include <metal_stdlib>
            using namespace metal;
            kernel void test_add(
                device float *a [[buffer(0)]],
                device float *b [[buffer(1)]],
                device float *c [[buffer(2)]],
                uint id [[thread_position_in_grid]]
            ) {
                c[id] = a[id] + b[id];
            }
            "#,
        );

        // Registering source alone must not populate the cache.
        assert!(!registry.is_cached("test_add"));

        // Only the address is kept, so the borrow of `registry` ends at once.
        let first_ptr = registry
            .get_pipeline("test_add", dev.metal_device())
            .expect("get_pipeline first call") as *const _;
        assert!(registry.is_cached("test_add"));

        let second_ptr = registry
            .get_pipeline("test_add", dev.metal_device())
            .expect("get_pipeline second call") as *const _;

        assert_eq!(
            first_ptr, second_ptr,
            "Second get_pipeline call should return the same cached pipeline"
        );
    }

    /// A zero-byte allocation is rejected with `MlxError::InvalidArgument`.
    #[test]
    fn test_buffer_alloc_zero_len_error() {
        let dev = open_device();
        let result = dev.alloc_buffer(0, DType::F32, vec![]);
        assert!(result.is_err(), "Zero-length allocation should fail");
        if !matches!(result, Err(MlxError::InvalidArgument(_))) {
            panic!("Expected InvalidArgument, got {result:?}");
        }
    }

    /// Asking for an unregistered kernel yields `MlxError::KernelNotFound`
    /// carrying the requested name.
    #[test]
    fn test_kernel_not_found() {
        let dev = open_device();
        let mut registry = KernelRegistry::new();
        let result = registry.get_pipeline("nonexistent_kernel", dev.metal_device());
        assert!(result.is_err());
        match result {
            Err(MlxError::KernelNotFound(missing)) => assert_eq!(missing, "nonexistent_kernel"),
            unexpected => panic!("Expected KernelNotFound, got {unexpected:?}"),
        }
    }

    /// Each dtype reports its expected element size in bytes.
    #[test]
    fn test_dtype_sizes() {
        let expected_sizes = [
            (DType::F32, 4),
            (DType::F16, 2),
            (DType::BF16, 2),
            (DType::U8, 1),
            (DType::U16, 2),
            (DType::U32, 4),
            (DType::I32, 4),
        ];
        for (dtype, bytes) in expected_sizes {
            assert_eq!(dtype.size_of(), bytes);
        }
    }

    /// The `Debug` rendering of a buffer mentions its type name, dtype and shape.
    #[test]
    fn test_buffer_debug() {
        let dev = open_device();
        let buf = dev
            .alloc_buffer(64, DType::F16, vec![4, 8])
            .expect("alloc_buffer");
        let debug_str = format!("{buf:?}");
        assert!(debug_str.contains("MlxBuffer"));
        assert!(debug_str.contains("F16"));
        assert!(debug_str.contains("[4, 8]"));
    }

    /// `Display` output of errors contains the expected fragments.
    #[test]
    fn test_error_display() {
        let e = MlxError::DeviceNotFound;
        assert!(format!("{e}").contains("Metal GPU device"));

        // A compilation error must surface both the kernel name and the message.
        let e = MlxError::ShaderCompilationError {
            name: "foo".into(),
            message: "syntax error".into(),
        };
        assert!(format!("{e}").contains("foo"));
        assert!(format!("{e}").contains("syntax error"));
    }

    /// Releasing buffers requested at 100, 128 and 200 bytes is expected to
    /// leave two buffers on the free list, i.e. two distinct bucket sizes.
    ///
    /// NOTE(review): presumably the 100- and 128-byte requests share one bucket
    /// and the second request reuses the first buffer; confirm against
    /// `MlxBufferPool`'s bucketing rule.
    #[test]
    fn test_buffer_pool_size_buckets() {
        let dev = open_device();
        let mut pool = MlxBufferPool::new();

        let buf_100 = pool
            .alloc(&dev, 100, DType::U8, vec![100])
            .expect("alloc 100");
        assert!(
            buf_100.byte_len() >= 100,
            "Buffer should be at least 100 bytes"
        );
        pool.release(buf_100);

        let buf_128 = pool
            .alloc(&dev, 128, DType::U8, vec![128])
            .expect("alloc 128");
        assert!(buf_128.byte_len() >= 128);
        pool.release(buf_128);

        let buf_200 = pool
            .alloc(&dev, 200, DType::U8, vec![200])
            .expect("alloc 200");
        assert!(buf_200.byte_len() >= 200);
        pool.release(buf_200);

        assert_eq!(pool.free_count(), 2, "Two different bucket sizes in pool");
    }
}