1#![deny(clippy::panic, clippy::unwrap_used, clippy::expect_used)]
41#![allow(unexpected_cfgs)]
45
46#[macro_use]
48mod error;
49mod buffer;
50mod buffer_pool;
51mod device;
52mod dtypes;
53mod encoder;
54mod encoder_session;
55mod env_flags;
56mod kernel_registry;
57mod mem_ranges;
58mod residency;
59pub mod gguf;
60pub mod kernel_profile;
61pub mod graph;
62pub mod metal_capture;
63pub mod ops;
64pub mod turboquant;
65pub mod tq_oracle;
66pub mod weight;
67
68pub use buffer::MlxBuffer;
70pub use buffer_pool::MlxBufferPool;
71pub use device::MlxDevice;
72pub use dtypes::DType;
73pub use encoder::{
74 auto_barrier_concurrent_count, auto_barrier_count, barrier_count, barrier_total_ns,
75 cmd_buf_count, dispatch_count, pipeline_dispatch_buckets, reset_counters,
76 reset_pipeline_dispatch_buckets, sync_count, CapturedNode, CapturedOpKind,
77 CommandEncoder, DispatchKind, KernelArg, RecordedBinding,
78};
79pub use encoder_session::EncoderSession;
80pub use mem_ranges::{BufferRange, MemRangeRole, MemRanges};
81pub use error::{MlxError, Result};
82pub use graph::{ComputeGraph, GraphExecutor, GraphSession, OpKind};
83pub use kernel_registry::KernelRegistry;
84#[doc(hidden)]
90pub use residency::{
91 macos_15_or_newer_for_test, reset_residency_env_cache_for_test,
92 reset_residency_test_counters, residency_allocation_count_for_test,
93 residency_commit_call_count_for_test,
94};
95
96pub use gguf::{GgufFile, MetadataValue, TensorInfo};
98
99pub use ops::dense_mm_bf16::{dense_matmul_bf16_f32_tensor, DenseMmBf16F32Params};
101pub use ops::dense_mm_f16::{dense_matmul_f16_f32_tensor, DenseMmF16F32Params};
102pub use ops::dense_mm_f32_f32::{dense_matmul_f32_f32_tensor, DenseMmF32F32Params};
103pub use ops::quantized_matmul::{quantized_matmul, quantized_matmul_simd, QuantizedMatmulParams};
104pub use ops::quantized_matmul_ggml::{
105 dispatch_mm_for_test, quantized_matmul_ggml, quantized_matmul_mm_tensor_perm021,
106 GgmlQuantizedMatmulParams, GgmlQuantizedMatmulPerm021Params, GgmlType,
107 MM_ROUTING_THRESHOLD,
108};
109pub use ops::mul_mv_ext::{mul_mv_ext_dispatch, MulMvExtParams};
110pub use ops::quantized_matmul_id::{
111 quantized_matmul_id, quantized_matmul_id_into, QuantizedMatmulIdParams,
112};
113pub use ops::quantized_matmul_id_ggml::{
114 dispatch_id_mm_for_test, quantized_matmul_id_ggml, quantized_matmul_id_ggml_pooled,
115 quantized_matmul_id_swiglu_q4_0,
116 GgmlIdMmDispatchParams, GgmlQuantizedMatmulIdParams, IdMmScratch,
117 MM_ID_ROUTING_THRESHOLD,
118};
119
120pub use weight::{
122 load_quantized_weights, safetensors_to_metal_buffer, QuantizationConfig, QuantizedWeight,
123 SafetensorsFile, TensorQuantConfig,
124};
125
126pub use metal::MTLSize;
128pub use metal;
129
130#[cfg(test)]
131#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
132mod tests {
133 use super::*;
134
135 fn _assert_send<T: Send>() {}
137 fn _assert_sync<T: Sync>() {}
138
139 #[allow(dead_code)]
140 fn assert_send_sync() {
141 _assert_send::<MlxDevice>();
142 _assert_sync::<MlxDevice>();
143 _assert_send::<MlxBuffer>();
144 _assert_sync::<MlxBuffer>();
145 _assert_send::<MlxError>();
146 _assert_sync::<MlxError>();
147 }
148
/// Creates a device and checks that it reports a non-empty name.
#[test]
fn test_device_init() {
    let dev = MlxDevice::new().expect("MlxDevice::new() should succeed on Apple Silicon");
    let device_name = dev.name();
    assert!(!device_name.is_empty(), "Device name should not be empty");
    // Printed for diagnostics; visible with `cargo test -- --nocapture`.
    println!("Metal device: {}", device_name);
}
157
/// Allocates a 2x3x4 `F32` buffer and verifies that the buffer reports the
/// dtype, shape, byte length and element count it was created with.
#[test]
fn test_buffer_alloc() {
    let device = MlxDevice::new().expect("device");
    let dims = vec![2, 3, 4];
    // 24 elements, each `DType::F32.size_of()` bytes wide.
    let total_bytes = 2 * 3 * 4 * DType::F32.size_of();
    let tensor = device
        .alloc_buffer(total_bytes, DType::F32, dims.clone())
        .expect("alloc_buffer");

    assert_eq!(tensor.dtype(), DType::F32);
    assert_eq!(tensor.shape(), &dims);
    assert_eq!(tensor.byte_len(), total_bytes);
    assert_eq!(tensor.element_count(), 24);
}
173
/// Writes a known pattern through the mutable `f32` view of a buffer and
/// reads it back through the shared view, checking every element.
#[test]
fn test_buffer_readwrite() {
    let device = MlxDevice::new().expect("device");
    let count = 64;
    let total_bytes = count * std::mem::size_of::<f32>();
    let mut buf = device
        .alloc_buffer(total_bytes, DType::F32, vec![count])
        .expect("alloc_buffer");

    // Value stored at (and later expected from) element `i`.
    let pattern = |i: usize| i as f32 * 1.5;

    // Write pass: the mutable borrow ends after the last use of `writable`.
    let writable: &mut [f32] = buf.as_mut_slice().expect("as_mut_slice");
    assert_eq!(writable.len(), count);
    writable
        .iter_mut()
        .enumerate()
        .for_each(|(i, slot)| *slot = pattern(i));

    // Read pass: every element must match what was written.
    let readable: &[f32] = buf.as_slice().expect("as_slice");
    for (i, &val) in readable.iter().enumerate() {
        let expected = pattern(i);
        assert!(
            (val - expected).abs() < f32::EPSILON,
            "Mismatch at index {i}: got {val}, expected {expected}"
        );
    }
}
205
/// Creates a command encoder and commits it without recording any work;
/// `commit_and_wait` is expected to succeed.
#[test]
fn test_encoder_lifecycle() {
    let device = MlxDevice::new().expect("device");
    let mut encoder = device.command_encoder().expect("command_encoder");
    let outcome = encoder.commit_and_wait();
    outcome.expect("commit_and_wait on empty encoder");
}
215
/// Allocates from the pool, releases, and allocates the same size again:
/// the second allocation must hand back the released buffer (same contents
/// pointer, same byte length) and leave the free list empty.
#[test]
fn test_buffer_pool_reuse() {
    let device = MlxDevice::new().expect("device");
    let mut pool = MlxBufferPool::new();

    let first = pool
        .alloc(&device, 1024, DType::F32, vec![256])
        .expect("pool alloc 1");
    // Record identity before the buffer is moved back into the pool.
    let (first_ptr, first_len) = (first.contents_ptr(), first.byte_len());

    pool.release(first);
    assert_eq!(pool.free_count(), 1);

    let second = pool
        .alloc(&device, 1024, DType::F32, vec![256])
        .expect("pool alloc 2");
    let (second_ptr, second_len) = (second.contents_ptr(), second.byte_len());

    assert_eq!(first_ptr, second_ptr, "Pool should reuse the same Metal buffer");
    assert_eq!(first_len, second_len, "Byte lengths should match");
    assert_eq!(pool.free_count(), 0, "Free list should be empty after reuse");
}
244
/// Registers a small kernel source, then requests its pipeline twice:
/// the kernel must be reported as uncached before the first request, cached
/// after it, and both requests must yield the same pipeline address.
#[test]
fn test_kernel_registry_caching() {
    let device = MlxDevice::new().expect("device");
    let mut registry = KernelRegistry::new();

    let kernel = "test_add";
    let shader_src = r#"
        #include <metal_stdlib>
        using namespace metal;
        kernel void test_add(
            device float *a [[buffer(0)]],
            device float *b [[buffer(1)]],
            device float *c [[buffer(2)]],
            uint id [[thread_position_in_grid]]
        ) {
            c[id] = a[id] + b[id];
        }
    "#;
    registry.register_source(kernel, shader_src);

    assert!(!registry.is_cached(kernel));

    // Keep only raw addresses so the borrow of `registry` ends between calls.
    let first_ptr = registry
        .get_pipeline(kernel, device.metal_device())
        .expect("get_pipeline first call") as *const _;
    assert!(registry.is_cached(kernel));

    let second_ptr = registry
        .get_pipeline(kernel, device.metal_device())
        .expect("get_pipeline second call") as *const _;

    assert_eq!(
        first_ptr, second_ptr,
        "Second get_pipeline call should return the same cached pipeline"
    );
}
287
/// Requesting a zero-byte buffer must fail with `MlxError::InvalidArgument`.
#[test]
fn test_buffer_alloc_zero_len_error() {
    let device = MlxDevice::new().expect("device");
    let outcome = device.alloc_buffer(0, DType::F32, vec![]);
    assert!(outcome.is_err(), "Zero-length allocation should fail");
    // The wildcard pattern binds nothing, so `outcome` stays usable for the message.
    if !matches!(outcome, Err(MlxError::InvalidArgument(_))) {
        panic!("Expected InvalidArgument, got {:?}", outcome);
    }
}
299
/// Looking up a kernel that was never registered must fail with
/// `MlxError::KernelNotFound` carrying the requested name.
#[test]
fn test_kernel_not_found() {
    let device = MlxDevice::new().expect("device");
    let mut registry = KernelRegistry::new();
    let lookup = registry.get_pipeline("nonexistent_kernel", device.metal_device());
    assert!(lookup.is_err());
    if let Err(MlxError::KernelNotFound(name)) = lookup {
        assert_eq!(name, "nonexistent_kernel");
    } else {
        panic!("Expected KernelNotFound, got {:?}", lookup);
    }
}
314
/// Checks the per-element byte width reported by `DType::size_of` for each
/// dtype listed below.
#[test]
fn test_dtype_sizes() {
    // (dtype, expected size in bytes), checked in this order.
    let expected_sizes = [
        (DType::F32, 4),
        (DType::F16, 2),
        (DType::BF16, 2),
        (DType::U8, 1),
        (DType::U16, 2),
        (DType::U32, 4),
        (DType::I32, 4),
    ];
    for (dtype, bytes) in expected_sizes {
        assert_eq!(dtype.size_of(), bytes);
    }
}
326
/// The `Debug` rendering of a buffer must mention the type name, its dtype
/// and its shape.
#[test]
fn test_buffer_debug() {
    let device = MlxDevice::new().expect("device");
    let buf = device
        .alloc_buffer(64, DType::F16, vec![4, 8])
        .expect("alloc_buffer");
    let rendered = format!("{buf:?}");
    assert!(rendered.contains("MlxBuffer"));
    assert!(rendered.contains("F16"));
    assert!(rendered.contains("[4, 8]"));
}
339
/// Spot-checks the `Display` output of two `MlxError` variants.
#[test]
fn test_error_display() {
    let not_found = MlxError::DeviceNotFound.to_string();
    assert!(not_found.contains("Metal GPU device"));

    // Both the kernel name and the compiler message must appear in the text.
    let compile_failure = MlxError::ShaderCompilationError {
        name: "foo".into(),
        message: "syntax error".into(),
    };
    let rendered = compile_failure.to_string();
    assert!(rendered.contains("foo"));
    assert!(rendered.contains("syntax error"));
}
353
/// Runs three alloc/release round-trips (100, 128 and 200 bytes) through one
/// pool. Each buffer must be at least as large as requested, and afterwards
/// the pool must hold exactly two free buffers.
///
/// NOTE(review): the final count presumably relies on the 100- and 128-byte
/// requests sharing a size bucket while 200 bytes lands in another; confirm
/// against `MlxBufferPool`'s rounding policy.
#[test]
fn test_buffer_pool_size_buckets() {
    let device = MlxDevice::new().expect("device");
    let mut pool = MlxBufferPool::new();

    let small = pool
        .alloc(&device, 100, DType::U8, vec![100])
        .expect("alloc 100");
    assert!(small.byte_len() >= 100, "Buffer should be at least 100 bytes");
    pool.release(small);

    let medium = pool
        .alloc(&device, 128, DType::U8, vec![128])
        .expect("alloc 128");
    assert!(medium.byte_len() >= 128);
    pool.release(medium);

    let large = pool
        .alloc(&device, 200, DType::U8, vec![200])
        .expect("alloc 200");
    assert!(large.byte_len() >= 200);
    pool.release(large);

    assert_eq!(pool.free_count(), 2, "Two different bucket sizes in pool");
}
380}