1#![deny(clippy::panic, clippy::unwrap_used, clippy::expect_used)]
41#![allow(unexpected_cfgs)]
45
46#[macro_use]
48mod error;
49mod buffer;
50mod buffer_pool;
51mod device;
52mod dtypes;
53mod encoder;
54mod encoder_session;
55mod kernel_registry;
56mod mem_ranges;
57mod residency;
58pub mod gguf;
59pub mod kernel_profile;
60pub mod graph;
61pub mod metal_capture;
62pub mod ops;
63pub mod turboquant;
64pub mod tq_oracle;
65pub mod weight;
66
67pub use buffer::MlxBuffer;
69pub use buffer_pool::MlxBufferPool;
70pub use device::MlxDevice;
71pub use dtypes::DType;
72pub use encoder::{
73 auto_barrier_concurrent_count, auto_barrier_count, barrier_count, barrier_total_ns,
74 cmd_buf_count, dispatch_count, reset_counters, sync_count, CapturedNode, CapturedOpKind,
75 CommandEncoder, DispatchKind, KernelArg, RecordedBinding,
76};
77pub use encoder_session::EncoderSession;
78pub use mem_ranges::{BufferRange, MemRangeRole, MemRanges};
79pub use error::{MlxError, Result};
80pub use graph::{ComputeGraph, GraphExecutor, GraphSession, OpKind};
81pub use kernel_registry::KernelRegistry;
82#[doc(hidden)]
88pub use residency::{
89 macos_15_or_newer_for_test, reset_residency_env_cache_for_test,
90 reset_residency_test_counters, residency_allocation_count_for_test,
91 residency_commit_call_count_for_test,
92};
93
94pub use gguf::{GgufFile, MetadataValue, TensorInfo};
96
97pub use ops::dense_mm_bf16::{dense_matmul_bf16_f32_tensor, DenseMmBf16F32Params};
99pub use ops::dense_mm_f16::{dense_matmul_f16_f32_tensor, DenseMmF16F32Params};
100pub use ops::dense_mm_f32_f32::{dense_matmul_f32_f32_tensor, DenseMmF32F32Params};
101pub use ops::quantized_matmul::{quantized_matmul, quantized_matmul_simd, QuantizedMatmulParams};
102pub use ops::quantized_matmul_ggml::{
103 dispatch_mm_for_test, quantized_matmul_ggml, quantized_matmul_mm_tensor_perm021,
104 GgmlQuantizedMatmulParams, GgmlQuantizedMatmulPerm021Params, GgmlType,
105 MM_ROUTING_THRESHOLD,
106};
107pub use ops::mul_mv_ext::{mul_mv_ext_dispatch, MulMvExtParams};
108pub use ops::quantized_matmul_id::{
109 quantized_matmul_id, quantized_matmul_id_into, QuantizedMatmulIdParams,
110};
111pub use ops::quantized_matmul_id_ggml::{
112 dispatch_id_mm_for_test, quantized_matmul_id_ggml, quantized_matmul_id_ggml_pooled,
113 quantized_matmul_id_swiglu_q4_0,
114 GgmlIdMmDispatchParams, GgmlQuantizedMatmulIdParams, IdMmScratch,
115 MM_ID_ROUTING_THRESHOLD,
116};
117
118pub use weight::{
120 load_quantized_weights, safetensors_to_metal_buffer, QuantizationConfig, QuantizedWeight,
121 SafetensorsFile, TensorQuantConfig,
122};
123
124pub use metal::MTLSize;
126pub use metal;
127
#[cfg(test)]
// A failing `expect`/`unwrap`/`panic!` is how a test reports failure, so the
// crate-level `deny` on these lints is relaxed for the test module only.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Compiles only when `T` implements `Send`.
    fn _assert_send<T: Send>() {}

    /// Compiles only when `T` implements `Sync`.
    fn _assert_sync<T: Sync>() {}

    /// Compile-time check that the core public types are `Send` and `Sync`.
    ///
    /// Never invoked at runtime; type-checking the body is the whole test.
    #[allow(dead_code)]
    fn assert_send_sync() {
        /// Requires both auto traits for `T` through the two helpers above.
        fn both<T: Send + Sync>() {
            _assert_send::<T>();
            _assert_sync::<T>();
        }

        both::<MlxDevice>();
        both::<MlxBuffer>();
        both::<MlxError>();
    }

    /// A device can be created and reports a non-empty name.
    #[test]
    fn test_device_init() {
        let gpu = MlxDevice::new().expect("MlxDevice::new() should succeed on Apple Silicon");
        let name = gpu.name();
        assert!(!name.is_empty(), "Device name should not be empty");
        println!("Metal device: {name}");
    }

    /// An allocated buffer reports the dtype, shape, byte length and element
    /// count it was created with.
    #[test]
    fn test_buffer_alloc() {
        let gpu = MlxDevice::new().expect("device");
        let dims = vec![2, 3, 4];
        // 2 * 3 * 4 = 24 elements, each the size of one F32.
        let total_bytes = 2 * 3 * 4 * DType::F32.size_of();

        let tensor = gpu
            .alloc_buffer(total_bytes, DType::F32, dims.clone())
            .expect("alloc_buffer");

        assert_eq!(tensor.dtype(), DType::F32);
        assert_eq!(tensor.shape(), &dims);
        assert_eq!(tensor.byte_len(), total_bytes);
        assert_eq!(tensor.element_count(), 24);
    }

    /// Values written through the mutable slice view are read back unchanged
    /// through the shared slice view.
    #[test]
    fn test_buffer_readwrite() {
        let gpu = MlxDevice::new().expect("device");
        let n = 64;
        let total_bytes = n * std::mem::size_of::<f32>();
        let mut storage = gpu
            .alloc_buffer(total_bytes, DType::F32, vec![n])
            .expect("alloc_buffer");

        // Fill phase: the mutable borrow ends after its last use below, so the
        // shared borrow in the verify phase does not conflict with it.
        let writable: &mut [f32] = storage.as_mut_slice().expect("as_mut_slice");
        assert_eq!(writable.len(), n);
        writable
            .iter_mut()
            .enumerate()
            .for_each(|(i, val)| *val = i as f32 * 1.5);

        // Verify phase: recompute each expected value and compare.
        let readable: &[f32] = storage.as_slice().expect("as_slice");
        for (i, val) in readable.iter().copied().enumerate() {
            let expected = i as f32 * 1.5;
            assert!(
                (val - expected).abs() < f32::EPSILON,
                "Mismatch at index {i}: got {val}, expected {expected}"
            );
        }
    }

    /// An encoder with no recorded work can still be committed and waited on.
    #[test]
    fn test_encoder_lifecycle() {
        let gpu = MlxDevice::new().expect("device");
        let mut encoder = gpu.command_encoder().expect("command_encoder");
        let outcome = encoder.commit_and_wait();
        outcome.expect("commit_and_wait on empty encoder");
    }

    /// Releasing a buffer to the pool and requesting the same size again hands
    /// back the same underlying allocation.
    #[test]
    fn test_buffer_pool_reuse() {
        let gpu = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();

        let first = pool
            .alloc(&gpu, 1024, DType::F32, vec![256])
            .expect("pool alloc 1");
        // Record identity before the buffer is moved back into the pool.
        let (first_ptr, first_len) = (first.contents_ptr(), first.byte_len());

        pool.release(first);
        assert_eq!(pool.free_count(), 1);

        let second = pool
            .alloc(&gpu, 1024, DType::F32, vec![256])
            .expect("pool alloc 2");
        let (second_ptr, second_len) = (second.contents_ptr(), second.byte_len());

        assert_eq!(first_ptr, second_ptr, "Pool should reuse the same Metal buffer");
        assert_eq!(first_len, second_len, "Byte lengths should match");
        assert_eq!(pool.free_count(), 0, "Free list should be empty after reuse");
    }

    /// A registered kernel is not cached until first requested, and a second
    /// request returns the very same pipeline object.
    #[test]
    fn test_kernel_registry_caching() {
        const KERNEL_NAME: &str = "test_add";
        const KERNEL_SOURCE: &str = r#"
 #include <metal_stdlib>
 using namespace metal;
 kernel void test_add(
 device float *a [[buffer(0)]],
 device float *b [[buffer(1)]],
 device float *c [[buffer(2)]],
 uint id [[thread_position_in_grid]]
 ) {
 c[id] = a[id] + b[id];
 }
 "#;

        let gpu = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        registry.register_source(KERNEL_NAME, KERNEL_SOURCE);

        // Registering the source alone must not populate the cache.
        assert!(!registry.is_cached(KERNEL_NAME));

        let first = registry
            .get_pipeline(KERNEL_NAME, gpu.metal_device())
            .expect("get_pipeline first call");
        // Keep only the address so the borrow of `registry` ends here.
        let first_addr = first as *const _;
        assert!(registry.is_cached(KERNEL_NAME));

        let second = registry
            .get_pipeline(KERNEL_NAME, gpu.metal_device())
            .expect("get_pipeline second call");
        let second_addr = second as *const _;

        assert_eq!(
            first_addr, second_addr,
            "Second get_pipeline call should return the same cached pipeline"
        );
    }

    /// Requesting a zero-byte buffer is rejected with `InvalidArgument`.
    #[test]
    fn test_buffer_alloc_zero_len_error() {
        let gpu = MlxDevice::new().expect("device");
        let outcome = gpu.alloc_buffer(0, DType::F32, vec![]);
        assert!(outcome.is_err(), "Zero-length allocation should fail");
        if !matches!(outcome, Err(MlxError::InvalidArgument(_))) {
            panic!("Expected InvalidArgument, got {:?}", outcome);
        }
    }

    /// Looking up an unregistered kernel yields `KernelNotFound` carrying the
    /// requested name.
    #[test]
    fn test_kernel_not_found() {
        let gpu = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let lookup = registry.get_pipeline("nonexistent_kernel", gpu.metal_device());
        assert!(lookup.is_err());
        if let Err(MlxError::KernelNotFound(missing)) = lookup {
            assert_eq!(missing, "nonexistent_kernel");
        } else {
            panic!("Expected KernelNotFound, got {:?}", lookup);
        }
    }

    /// Each dtype reports its expected per-element size in bytes.
    #[test]
    fn test_dtype_sizes() {
        let expected_sizes = [
            (DType::F32, 4),
            (DType::F16, 2),
            (DType::BF16, 2),
            (DType::U8, 1),
            (DType::U16, 2),
            (DType::U32, 4),
            (DType::I32, 4),
        ];
        for (dtype, bytes) in expected_sizes {
            assert_eq!(dtype.size_of(), bytes);
        }
    }

    /// The `Debug` rendering of a buffer mentions its type name, dtype and
    /// shape.
    #[test]
    fn test_buffer_debug() {
        let gpu = MlxDevice::new().expect("device");
        let tensor = gpu
            .alloc_buffer(64, DType::F16, vec![4, 8])
            .expect("alloc_buffer");
        let rendered = format!("{:?}", tensor);
        assert!(rendered.contains("MlxBuffer"));
        assert!(rendered.contains("F16"));
        assert!(rendered.contains("[4, 8]"));
    }

    /// The `Display` text of errors contains the identifying details.
    #[test]
    fn test_error_display() {
        let not_found_text = MlxError::DeviceNotFound.to_string();
        assert!(not_found_text.contains("Metal GPU device"));

        let shader_text = MlxError::ShaderCompilationError {
            name: "foo".into(),
            message: "syntax error".into(),
        }
        .to_string();
        assert!(shader_text.contains("foo"));
        assert!(shader_text.contains("syntax error"));
    }

    /// Three requests of 100, 128 and 200 bytes, each released right after
    /// allocation, must leave exactly two buffers on the free list. That is
    /// only possible if one of the later requests was served from an earlier
    /// released buffer.
    #[test]
    fn test_buffer_pool_size_buckets() {
        let gpu = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();

        let small = pool
            .alloc(&gpu, 100, DType::U8, vec![100])
            .expect("alloc 100");
        assert!(
            small.byte_len() >= 100,
            "Buffer should be at least 100 bytes"
        );
        pool.release(small);

        let medium = pool
            .alloc(&gpu, 128, DType::U8, vec![128])
            .expect("alloc 128");
        assert!(medium.byte_len() >= 128);
        pool.release(medium);

        let large = pool
            .alloc(&gpu, 200, DType::U8, vec![200])
            .expect("alloc 200");
        assert!(large.byte_len() >= 200);
        pool.release(large);

        assert_eq!(pool.free_count(), 2, "Two different bucket sizes in pool");
    }
}