1#![deny(clippy::panic, clippy::unwrap_used, clippy::expect_used)]
41#![allow(unexpected_cfgs)]
45
46#[macro_use]
48mod error;
49mod buffer;
50mod buffer_pool;
51mod device;
52mod dtypes;
53mod encoder;
54mod kernel_registry;
55mod residency;
56pub mod gguf;
57pub mod kernel_profile;
58pub mod graph;
59pub mod ops;
60pub mod turboquant;
61pub mod weight;
62
63pub use buffer::MlxBuffer;
65pub use buffer_pool::MlxBufferPool;
66pub use device::MlxDevice;
67pub use dtypes::DType;
68pub use encoder::{
69 barrier_count, barrier_total_ns, cmd_buf_count, dispatch_count, reset_counters, sync_count,
70 CapturedNode, CommandEncoder, DispatchKind, RecordedBinding,
71};
72pub use error::{MlxError, Result};
73pub use graph::{ComputeGraph, GraphExecutor, GraphSession, OpKind};
74pub use kernel_registry::KernelRegistry;
75#[doc(hidden)]
81pub use residency::{
82 macos_15_or_newer_for_test, reset_residency_env_cache_for_test,
83 reset_residency_test_counters, residency_allocation_count_for_test,
84 residency_commit_call_count_for_test,
85};
86
87pub use gguf::{GgufFile, MetadataValue, TensorInfo};
89
90pub use ops::dense_mm_bf16::{dense_matmul_bf16_f32_tensor, DenseMmBf16F32Params};
92pub use ops::dense_mm_f16::{dense_matmul_f16_f32_tensor, DenseMmF16F32Params};
93pub use ops::dense_mm_f32_f32::{dense_matmul_f32_f32_tensor, DenseMmF32F32Params};
94pub use ops::quantized_matmul::{quantized_matmul, quantized_matmul_simd, QuantizedMatmulParams};
95pub use ops::quantized_matmul_ggml::{
96 dispatch_mm_for_test, quantized_matmul_ggml, quantized_matmul_mm_tensor_perm021,
97 GgmlQuantizedMatmulParams, GgmlQuantizedMatmulPerm021Params, GgmlType,
98 MM_ROUTING_THRESHOLD,
99};
100pub use ops::quantized_matmul_id::{quantized_matmul_id, QuantizedMatmulIdParams};
101pub use ops::quantized_matmul_id_ggml::{
102 dispatch_id_mm_for_test, quantized_matmul_id_ggml, quantized_matmul_id_ggml_pooled,
103 quantized_matmul_id_swiglu_q4_0,
104 GgmlIdMmDispatchParams, GgmlQuantizedMatmulIdParams, IdMmScratch,
105 MM_ID_ROUTING_THRESHOLD,
106};
107
108pub use weight::{
110 load_quantized_weights, safetensors_to_metal_buffer, QuantizationConfig, QuantizedWeight,
111 SafetensorsFile, TensorQuantConfig,
112};
113
114pub use metal::MTLSize;
116pub use metal;
117
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}

    // Compile-time proof that the core public types can cross thread
    // boundaries. Never called at runtime — it only has to type-check.
    #[allow(dead_code)]
    fn assert_send_sync() {
        _assert_send::<MlxDevice>();
        _assert_sync::<MlxDevice>();
        _assert_send::<MlxBuffer>();
        _assert_sync::<MlxBuffer>();
        _assert_send::<MlxError>();
        _assert_sync::<MlxError>();
    }

    // Device creation should succeed on Apple Silicon and report a
    // non-empty device name.
    #[test]
    fn test_device_init() {
        let device = MlxDevice::new().expect("MlxDevice::new() should succeed on Apple Silicon");
        let name = device.name();
        assert!(!name.is_empty(), "Device name should not be empty");
        println!("Metal device: {name}");
    }

    // Allocation must preserve the requested dtype, shape, and byte length,
    // and derive the element count from the shape (2*3*4 = 24).
    #[test]
    fn test_buffer_alloc() {
        let device = MlxDevice::new().expect("device");
        let shape = vec![2, 3, 4];
        let byte_len = 2 * 3 * 4 * DType::F32.size_of();
        let buf = device
            .alloc_buffer(byte_len, DType::F32, shape.clone())
            .expect("alloc_buffer");

        assert_eq!(buf.dtype(), DType::F32);
        assert_eq!(buf.shape(), &shape);
        assert_eq!(buf.byte_len(), byte_len);
        assert_eq!(buf.element_count(), 24);
    }

    // Values written through `as_mut_slice` must read back unchanged
    // through `as_slice` (CPU-visible round-trip).
    #[test]
    fn test_buffer_readwrite() {
        let device = MlxDevice::new().expect("device");
        let n = 64;
        let byte_len = n * std::mem::size_of::<f32>();
        let mut buf = device
            .alloc_buffer(byte_len, DType::F32, vec![n])
            .expect("alloc_buffer");

        // Write phase: scoped so the mutable borrow ends before reading.
        {
            let slice: &mut [f32] = buf.as_mut_slice().expect("as_mut_slice");
            assert_eq!(slice.len(), n);
            for (i, val) in slice.iter_mut().enumerate() {
                *val = i as f32 * 1.5;
            }
        }

        // Read phase: every element must match exactly (values are exactly
        // representable, so an EPSILON tolerance is effectively equality).
        {
            let slice: &[f32] = buf.as_slice().expect("as_slice");
            for (i, &val) in slice.iter().enumerate() {
                let expected = i as f32 * 1.5;
                assert!(
                    (val - expected).abs() < f32::EPSILON,
                    "Mismatch at index {i}: got {val}, expected {expected}"
                );
            }
        }
    }

    // Committing an encoder with no recorded work must still complete
    // cleanly rather than error or hang.
    #[test]
    fn test_encoder_lifecycle() {
        let device = MlxDevice::new().expect("device");
        let mut enc = device.command_encoder().expect("command_encoder");
        enc.commit_and_wait()
            .expect("commit_and_wait on empty encoder");
    }

    // Releasing a buffer and re-allocating the same size must hand back the
    // identical Metal allocation (same contents pointer), draining the
    // free list.
    #[test]
    fn test_buffer_pool_reuse() {
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();

        let buf1 = pool
            .alloc(&device, 1024, DType::F32, vec![256])
            .expect("pool alloc 1");
        let buf1_ptr = buf1.contents_ptr();
        let buf1_byte_len = buf1.byte_len();

        pool.release(buf1);
        assert_eq!(pool.free_count(), 1);

        let buf2 = pool
            .alloc(&device, 1024, DType::F32, vec![256])
            .expect("pool alloc 2");
        let buf2_ptr = buf2.contents_ptr();
        let buf2_byte_len = buf2.byte_len();

        assert_eq!(buf1_ptr, buf2_ptr, "Pool should reuse the same Metal buffer");
        assert_eq!(buf1_byte_len, buf2_byte_len, "Byte lengths should match");
        assert_eq!(pool.free_count(), 0, "Free list should be empty after reuse");
    }

    // The registry must compile a source kernel on first request and return
    // the same cached pipeline object (same address) on the second request.
    #[test]
    fn test_kernel_registry_caching() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();

        registry.register_source(
            "test_add",
            r#"
            #include <metal_stdlib>
            using namespace metal;
            kernel void test_add(
                device float *a [[buffer(0)]],
                device float *b [[buffer(1)]],
                device float *c [[buffer(2)]],
                uint id [[thread_position_in_grid]]
            ) {
                c[id] = a[id] + b[id];
            }
            "#,
        );

        assert!(!registry.is_cached("test_add"));
        let p1 = registry
            .get_pipeline("test_add", device.metal_device())
            .expect("get_pipeline first call");
        let p1_ptr = p1 as *const _;
        assert!(registry.is_cached("test_add"));

        let p2 = registry
            .get_pipeline("test_add", device.metal_device())
            .expect("get_pipeline second call");
        let p2_ptr = p2 as *const _;

        assert_eq!(
            p1_ptr, p2_ptr,
            "Second get_pipeline call should return the same cached pipeline"
        );
    }

    // A zero-length allocation is rejected with InvalidArgument rather than
    // being passed through to Metal.
    #[test]
    fn test_buffer_alloc_zero_len_error() {
        let device = MlxDevice::new().expect("device");
        let result = device.alloc_buffer(0, DType::F32, vec![]);
        assert!(result.is_err(), "Zero-length allocation should fail");
        match result {
            Err(MlxError::InvalidArgument(_)) => {}
            other => panic!("Expected InvalidArgument, got {other:?}"),
        }
    }

    // Requesting an unregistered kernel yields KernelNotFound carrying the
    // requested name.
    #[test]
    fn test_kernel_not_found() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let result = registry.get_pipeline("nonexistent_kernel", device.metal_device());
        assert!(result.is_err());
        match result {
            Err(MlxError::KernelNotFound(name)) => {
                assert_eq!(name, "nonexistent_kernel");
            }
            other => panic!("Expected KernelNotFound, got {other:?}"),
        }
    }

    // Element sizes in bytes for every supported dtype.
    #[test]
    fn test_dtype_sizes() {
        assert_eq!(DType::F32.size_of(), 4);
        assert_eq!(DType::F16.size_of(), 2);
        assert_eq!(DType::BF16.size_of(), 2);
        assert_eq!(DType::U8.size_of(), 1);
        assert_eq!(DType::U16.size_of(), 2);
        assert_eq!(DType::U32.size_of(), 4);
        assert_eq!(DType::I32.size_of(), 4);
    }

    // Debug output should identify the type, the dtype, and the shape.
    #[test]
    fn test_buffer_debug() {
        let device = MlxDevice::new().expect("device");
        let buf = device
            .alloc_buffer(64, DType::F16, vec![4, 8])
            .expect("alloc_buffer");
        let debug_str = format!("{:?}", buf);
        assert!(debug_str.contains("MlxBuffer"));
        assert!(debug_str.contains("F16"));
        assert!(debug_str.contains("[4, 8]"));
    }

    // Display impls must surface the key details of each error variant.
    #[test]
    fn test_error_display() {
        let e = MlxError::DeviceNotFound;
        assert!(format!("{e}").contains("Metal GPU device"));

        let e = MlxError::ShaderCompilationError {
            name: "foo".into(),
            message: "syntax error".into(),
        };
        assert!(format!("{e}").contains("foo"));
        assert!(format!("{e}").contains("syntax error"));
    }

    // After releasing three buffers, only two free-list entries remain —
    // presumably the 100- and 128-byte requests round up to the same size
    // bucket so the second reuses the first, while 200 lands in a distinct
    // bucket. NOTE(review): verify against MlxBufferPool's bucketing policy.
    #[test]
    fn test_buffer_pool_size_buckets() {
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();

        let buf_100 = pool.alloc(&device, 100, DType::U8, vec![100]).expect("alloc 100");
        assert!(
            buf_100.byte_len() >= 100,
            "Buffer should be at least 100 bytes"
        );
        pool.release(buf_100);

        let buf_128 = pool.alloc(&device, 128, DType::U8, vec![128]).expect("alloc 128");
        assert!(buf_128.byte_len() >= 128);
        pool.release(buf_128);

        let buf_200 = pool.alloc(&device, 200, DType::U8, vec![200]).expect("alloc 200");
        assert!(buf_200.byte_len() >= 200);
        pool.release(buf_200);

        assert_eq!(pool.free_count(), 2, "Two different bucket sizes in pool");
    }
}