1#![deny(clippy::panic, clippy::unwrap_used, clippy::expect_used)]
41#![allow(unexpected_cfgs)]
45
46#[macro_use]
48mod error;
49mod buffer;
50mod buffer_pool;
51mod device;
52mod dtypes;
53mod encoder;
54mod encoder_session;
55pub mod encoder_worker;
56mod env_flags;
57mod kernel_registry;
58mod mem_ranges;
59mod residency;
60pub mod gguf;
61pub mod kernel_profile;
62pub mod graph;
63pub mod metal_capture;
64pub mod ops;
65pub mod turboquant;
66pub mod tq_oracle;
67pub mod weight;
68
69pub use buffer::MlxBuffer;
71pub use buffer_pool::MlxBufferPool;
72pub use device::MlxDevice;
73pub use dtypes::DType;
74pub use encoder::{
75 auto_barrier_concurrent_count, auto_barrier_count, barrier_count, barrier_total_ns,
76 cmd_buf_count, dispatch_count, pipeline_dispatch_buckets, reset_counters,
77 reset_pipeline_dispatch_buckets, sync_count, CapturedNode, CapturedOpKind,
78 CommandEncoder, DispatchKind, KernelArg, RecordedBinding,
79};
80pub use encoder_session::EncoderSession;
81pub use mem_ranges::{BufferRange, MemRangeRole, MemRanges};
82pub use error::{MlxError, Result};
83pub use graph::{ComputeGraph, GraphExecutor, GraphSession, OpKind};
84pub use kernel_registry::KernelRegistry;
85#[doc(hidden)]
91pub use residency::{
92 macos_15_or_newer_for_test, reset_residency_env_cache_for_test,
93 reset_residency_test_counters, residency_allocation_count_for_test,
94 residency_commit_call_count_for_test,
95};
96
97pub use gguf::{GgufFile, MetadataValue, TensorInfo};
99
100pub use ops::dense_mm_bf16::{dense_matmul_bf16_f32_tensor, DenseMmBf16F32Params};
102pub use ops::dense_mm_f16::{dense_matmul_f16_f32_tensor, DenseMmF16F32Params};
103pub use ops::dense_mm_f32_f32::{dense_matmul_f32_f32_tensor, DenseMmF32F32Params};
104pub use ops::quantized_matmul::{quantized_matmul, quantized_matmul_simd, QuantizedMatmulParams};
105pub use ops::quantized_matmul_ggml::{
106 dispatch_mm_for_test, quantized_matmul_ggml, quantized_matmul_mm_tensor_perm021,
107 quantized_matmul_mm_tensor_perm021_f16,
108 GgmlQuantizedMatmulParams, GgmlQuantizedMatmulPerm021Params, GgmlType,
109 MM_ROUTING_THRESHOLD,
110};
111pub use ops::mul_mv_ext::{mul_mv_ext_dispatch, MulMvExtParams};
112pub use ops::quantized_matmul_id::{
113 quantized_matmul_id, quantized_matmul_id_into, QuantizedMatmulIdParams,
114};
115pub use ops::quantized_matmul_id_ggml::{
116 dispatch_id_mm_for_test, quantized_matmul_id_ggml, quantized_matmul_id_ggml_pooled,
117 quantized_matmul_id_swiglu_q4_0,
118 GgmlIdMmDispatchParams, GgmlQuantizedMatmulIdParams, IdMmScratch,
119 MM_ID_ROUTING_THRESHOLD,
120};
121
122pub use weight::{
124 load_quantized_weights, safetensors_to_metal_buffer, QuantizationConfig, QuantizedWeight,
125 SafetensorsFile, TensorQuantConfig,
126};
127
128pub use metal::MTLSize;
130pub use metal;
131
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    // Compile-time-only probes: instantiating one of these with a type
    // proves the corresponding auto-trait bound holds.
    fn require_send<T: Send>() {}
    fn require_sync<T: Sync>() {}

    // Never executed at runtime; merely compiling this function asserts that
    // the listed public types are Send + Sync.
    #[allow(dead_code)]
    fn assert_send_sync() {
        require_send::<MlxDevice>();
        require_sync::<MlxDevice>();
        require_send::<MlxBuffer>();
        require_sync::<MlxBuffer>();
        require_send::<MlxError>();
        require_sync::<MlxError>();
    }

    /// The default Metal device must be creatable and report a non-empty name.
    #[test]
    fn test_device_init() {
        let dev = MlxDevice::new().expect("MlxDevice::new() should succeed on Apple Silicon");
        let name = dev.name();
        assert!(!name.is_empty(), "Device name should not be empty");
        println!("Metal device: {name}");
    }

    /// A freshly allocated buffer records dtype, shape, byte length,
    /// and element count exactly as requested.
    #[test]
    fn test_buffer_alloc() {
        let dev = MlxDevice::new().expect("device");
        let shape = vec![2, 3, 4];
        let byte_len = shape.iter().product::<usize>() * DType::F32.size_of();
        let buf = dev
            .alloc_buffer(byte_len, DType::F32, shape.clone())
            .expect("alloc_buffer");

        assert_eq!(buf.dtype(), DType::F32);
        assert_eq!(buf.shape(), &shape);
        assert_eq!(buf.byte_len(), byte_len);
        assert_eq!(buf.element_count(), 24);
    }

    /// Data written through the mutable CPU view must read back unchanged
    /// through the shared view.
    #[test]
    fn test_buffer_readwrite() {
        let dev = MlxDevice::new().expect("device");
        let n = 64;
        let byte_len = n * std::mem::size_of::<f32>();
        let mut buf = dev
            .alloc_buffer(byte_len, DType::F32, vec![n])
            .expect("alloc_buffer");

        // Fill with a recognizable ramp (i * 1.5 is exact in f32 for i < 64).
        {
            let dst: &mut [f32] = buf.as_mut_slice().expect("as_mut_slice");
            assert_eq!(dst.len(), n);
            for (i, slot) in dst.iter_mut().enumerate() {
                *slot = i as f32 * 1.5;
            }
        }

        // Read back and compare element-wise.
        {
            let src: &[f32] = buf.as_slice().expect("as_slice");
            for (i, &val) in src.iter().enumerate() {
                let expected = i as f32 * 1.5;
                assert!(
                    (val - expected).abs() < f32::EPSILON,
                    "Mismatch at index {i}: got {val}, expected {expected}"
                );
            }
        }
    }

    /// An encoder with no recorded work still commits and completes cleanly.
    #[test]
    fn test_encoder_lifecycle() {
        let dev = MlxDevice::new().expect("device");
        let mut enc = dev.command_encoder().expect("command_encoder");
        enc.commit_and_wait()
            .expect("commit_and_wait on empty encoder");
    }

    /// Releasing a buffer to the pool lets the next same-sized request
    /// reuse the identical underlying Metal allocation.
    #[test]
    fn test_buffer_pool_reuse() {
        let dev = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();

        let first = pool
            .alloc(&dev, 1024, DType::F32, vec![256])
            .expect("pool alloc 1");
        let (first_ptr, first_len) = (first.contents_ptr(), first.byte_len());

        pool.release(first);
        assert_eq!(pool.free_count(), 1);

        let second = pool
            .alloc(&dev, 1024, DType::F32, vec![256])
            .expect("pool alloc 2");
        let (second_ptr, second_len) = (second.contents_ptr(), second.byte_len());

        assert_eq!(first_ptr, second_ptr, "Pool should reuse the same Metal buffer");
        assert_eq!(first_len, second_len, "Byte lengths should match");
        assert_eq!(pool.free_count(), 0, "Free list should be empty after reuse");
    }

    /// A second pipeline lookup must hit the cache and hand back the
    /// very same pipeline object (same address).
    #[test]
    fn test_kernel_registry_caching() {
        let dev = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();

        registry.register_source(
            "test_add",
            r#"
            #include <metal_stdlib>
            using namespace metal;
            kernel void test_add(
                device float *a [[buffer(0)]],
                device float *b [[buffer(1)]],
                device float *c [[buffer(2)]],
                uint id [[thread_position_in_grid]]
            ) {
                c[id] = a[id] + b[id];
            }
            "#,
        );

        assert!(!registry.is_cached("test_add"));
        let first = registry
            .get_pipeline("test_add", dev.metal_device())
            .expect("get_pipeline first call");
        // Capture the address before the second (re-borrowing) lookup.
        let first_addr = first as *const _;
        assert!(registry.is_cached("test_add"));

        let second = registry
            .get_pipeline("test_add", dev.metal_device())
            .expect("get_pipeline second call");
        let second_addr = second as *const _;

        assert_eq!(
            first_addr, second_addr,
            "Second get_pipeline call should return the same cached pipeline"
        );
    }

    /// A zero-byte allocation must be rejected with InvalidArgument.
    #[test]
    fn test_buffer_alloc_zero_len_error() {
        let dev = MlxDevice::new().expect("device");
        let outcome = dev.alloc_buffer(0, DType::F32, vec![]);
        assert!(outcome.is_err(), "Zero-length allocation should fail");
        match outcome {
            Err(MlxError::InvalidArgument(_)) => {}
            other => panic!("Expected InvalidArgument, got {:?}", other),
        }
    }

    /// Looking up an unregistered kernel yields KernelNotFound carrying
    /// the requested name.
    #[test]
    fn test_kernel_not_found() {
        let dev = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let outcome = registry.get_pipeline("nonexistent_kernel", dev.metal_device());
        assert!(outcome.is_err());
        match outcome {
            Err(MlxError::KernelNotFound(name)) => {
                assert_eq!(name, "nonexistent_kernel");
            }
            other => panic!("Expected KernelNotFound, got {:?}", other),
        }
    }

    /// Element byte sizes for every supported dtype.
    #[test]
    fn test_dtype_sizes() {
        let expected_sizes = [
            (DType::F32, 4),
            (DType::F16, 2),
            (DType::BF16, 2),
            (DType::U8, 1),
            (DType::U16, 2),
            (DType::U32, 4),
            (DType::I32, 4),
        ];
        for (dtype, bytes) in expected_sizes {
            assert_eq!(dtype.size_of(), bytes);
        }
    }

    /// Debug formatting of a buffer mentions the type name, dtype, and shape.
    #[test]
    fn test_buffer_debug() {
        let dev = MlxDevice::new().expect("device");
        let buf = dev
            .alloc_buffer(64, DType::F16, vec![4, 8])
            .expect("alloc_buffer");
        let rendered = format!("{:?}", buf);
        for needle in ["MlxBuffer", "F16", "[4, 8]"] {
            assert!(rendered.contains(needle));
        }
    }

    /// Display impls surface the key details of each error variant.
    #[test]
    fn test_error_display() {
        let not_found = MlxError::DeviceNotFound;
        assert!(not_found.to_string().contains("Metal GPU device"));

        let compile_err = MlxError::ShaderCompilationError {
            name: "foo".into(),
            message: "syntax error".into(),
        };
        let rendered = compile_err.to_string();
        assert!(rendered.contains("foo"));
        assert!(rendered.contains("syntax error"));
    }

    /// Requests of different sizes land in different pool buckets.
    #[test]
    fn test_buffer_pool_size_buckets() {
        let dev = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();

        let small = pool.alloc(&dev, 100, DType::U8, vec![100]).expect("alloc 100");
        assert!(
            small.byte_len() >= 100,
            "Buffer should be at least 100 bytes"
        );
        pool.release(small);

        let mid = pool.alloc(&dev, 128, DType::U8, vec![128]).expect("alloc 128");
        assert!(mid.byte_len() >= 128);
        pool.release(mid);

        let large = pool.alloc(&dev, 200, DType::U8, vec![200]).expect("alloc 200");
        assert!(large.byte_len() >= 200);
        pool.release(large);

        // Three releases but only two free entries: presumably the 100-byte
        // request rounds up into the same bucket as 128 — confirm in pool impl.
        assert_eq!(pool.free_count(), 2, "Two different bucket sizes in pool");
    }
}