// Crate-wide lint policy: panicking escape hatches (panic!/unwrap/expect) are
// denied in library code. The #[cfg(test)] module at the bottom of this file
// explicitly re-allows them for test assertions.
#![deny(clippy::panic, clippy::unwrap_used, clippy::expect_used)]
// NOTE(review): presumably tolerates cfg flags set by a build script or
// feature matrix not visible here — TODO confirm why this is needed.
#![allow(unexpected_cfgs)]

// `error` is declared first with #[macro_use] so any macro_rules! macros it
// defines are visible to the sibling modules declared below it.
#[macro_use]
mod error;
// Internal modules: private; exposed only via the explicit re-exports below.
mod buffer;
mod buffer_pool;
mod device;
mod dtypes;
mod encoder;
mod kernel_registry;
// Public modules: their full contents are part of the crate's API surface.
pub mod gguf;
pub mod graph;
pub mod ops;
pub mod turboquant;
pub mod weight;

// Flat re-exports of the core runtime types, so callers can write e.g.
// `crate::MlxDevice` instead of reaching into the private modules.
pub use buffer::MlxBuffer;
pub use buffer_pool::MlxBufferPool;
pub use device::MlxDevice;
pub use dtypes::DType;
pub use encoder::{
    dispatch_count, reset_counters, sync_count, CapturedNode, CommandEncoder, DispatchKind,
    RecordedBinding,
};
pub use error::{MlxError, Result};
pub use graph::{ComputeGraph, GraphExecutor, GraphSession, OpKind};
pub use kernel_registry::KernelRegistry;

// GGUF container types (file reader, metadata, tensor descriptors).
pub use gguf::{GgufFile, MetadataValue, TensorInfo};

// Dense and quantized matmul entry points, re-exported flat for convenience.
pub use ops::dense_mm_bf16::{dense_matmul_bf16_f32_tensor, DenseMmBf16F32Params};
pub use ops::dense_mm_f16::{dense_matmul_f16_f32_tensor, DenseMmF16F32Params};
pub use ops::dense_mm_f32_f32::{dense_matmul_f32_f32_tensor, DenseMmF32F32Params};
pub use ops::quantized_matmul::{quantized_matmul, quantized_matmul_simd, QuantizedMatmulParams};
pub use ops::quantized_matmul_ggml::{
    dispatch_mm_for_test, quantized_matmul_ggml, quantized_matmul_mm_tensor_perm021,
    GgmlQuantizedMatmulParams, GgmlQuantizedMatmulPerm021Params, GgmlType,
    MM_ROUTING_THRESHOLD,
};
pub use ops::quantized_matmul_id::{quantized_matmul_id, QuantizedMatmulIdParams};
pub use ops::quantized_matmul_id_ggml::{
    dispatch_id_mm_for_test, quantized_matmul_id_ggml, quantized_matmul_id_ggml_pooled,
    GgmlIdMmDispatchParams, GgmlQuantizedMatmulIdParams, IdMmScratch,
    MM_ID_ROUTING_THRESHOLD,
};

// Weight loading / quantization configuration surface.
pub use weight::{
    load_quantized_weights, safetensors_to_metal_buffer, QuantizationConfig, QuantizedWeight,
    SafetensorsFile, TensorQuantConfig,
};

// Re-export the `metal` crate (and MTLSize directly) so downstream users can
// name Metal types without adding their own dependency on a matching version.
pub use metal::MTLSize;
pub use metal;
103
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    //! Smoke tests for the crate's core surface: device init, buffer
    //! allocation and read/write, encoder lifecycle, buffer-pool reuse,
    //! kernel-registry caching, dtype sizes, and error formatting.
    //! These require a Metal-capable device (Apple Silicon) at runtime.
    use super::*;

    // Monomorphizing these helpers for a type fails to compile unless the
    // bound holds, giving a compile-time Send/Sync check; the function is
    // never called at runtime, hence #[allow(dead_code)].
    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}

    #[allow(dead_code)]
    fn assert_send_sync() {
        _assert_send::<MlxDevice>();
        _assert_sync::<MlxDevice>();
        _assert_send::<MlxBuffer>();
        _assert_sync::<MlxBuffer>();
        _assert_send::<MlxError>();
        _assert_sync::<MlxError>();
    }

    /// Device creation succeeds and reports a non-empty name.
    #[test]
    fn test_device_init() {
        let device = MlxDevice::new().expect("MlxDevice::new() should succeed on Apple Silicon");
        let name = device.name();
        assert!(!name.is_empty(), "Device name should not be empty");
        println!("Metal device: {name}");
    }

    /// An allocated buffer reports the dtype, shape, byte length, and
    /// element count it was created with.
    #[test]
    fn test_buffer_alloc() {
        let device = MlxDevice::new().expect("device");
        let shape = vec![2, 3, 4];
        let byte_len = 2 * 3 * 4 * DType::F32.size_of();
        let buf = device
            .alloc_buffer(byte_len, DType::F32, shape.clone())
            .expect("alloc_buffer");

        assert_eq!(buf.dtype(), DType::F32);
        assert_eq!(buf.shape(), &shape);
        assert_eq!(buf.byte_len(), byte_len);
        assert_eq!(buf.element_count(), 24);
    }

    /// Values written through `as_mut_slice` read back unchanged through
    /// `as_slice`.
    #[test]
    fn test_buffer_readwrite() {
        let device = MlxDevice::new().expect("device");
        let n = 64;
        let byte_len = n * std::mem::size_of::<f32>();
        let mut buf = device
            .alloc_buffer(byte_len, DType::F32, vec![n])
            .expect("alloc_buffer");

        // Scope the mutable borrow so the shared borrow below is legal.
        {
            let slice: &mut [f32] = buf.as_mut_slice().expect("as_mut_slice");
            assert_eq!(slice.len(), n);
            for (i, val) in slice.iter_mut().enumerate() {
                *val = i as f32 * 1.5;
            }
        }

        {
            let slice: &[f32] = buf.as_slice().expect("as_slice");
            for (i, &val) in slice.iter().enumerate() {
                let expected = i as f32 * 1.5;
                assert!(
                    (val - expected).abs() < f32::EPSILON,
                    "Mismatch at index {i}: got {val}, expected {expected}"
                );
            }
        }
    }

    /// Committing an encoder with no recorded work completes without error.
    #[test]
    fn test_encoder_lifecycle() {
        let device = MlxDevice::new().expect("device");
        let mut enc = device.command_encoder().expect("command_encoder");
        enc.commit_and_wait()
            .expect("commit_and_wait on empty encoder");
    }

    /// Releasing a buffer back to the pool and allocating the same size again
    /// hands back the very same underlying Metal buffer (same contents
    /// pointer), draining the free list.
    #[test]
    fn test_buffer_pool_reuse() {
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();

        let buf1 = pool
            .alloc(&device, 1024, DType::F32, vec![256])
            .expect("pool alloc 1");
        let buf1_ptr = buf1.contents_ptr();
        let buf1_byte_len = buf1.byte_len();

        pool.release(buf1);
        assert_eq!(pool.free_count(), 1);

        let buf2 = pool
            .alloc(&device, 1024, DType::F32, vec![256])
            .expect("pool alloc 2");
        let buf2_ptr = buf2.contents_ptr();
        let buf2_byte_len = buf2.byte_len();

        assert_eq!(buf1_ptr, buf2_ptr, "Pool should reuse the same Metal buffer");
        assert_eq!(buf1_byte_len, buf2_byte_len, "Byte lengths should match");
        assert_eq!(pool.free_count(), 0, "Free list should be empty after reuse");
    }

    /// Compiling a kernel caches the pipeline: a second `get_pipeline` call
    /// returns the same object (compared by address) without recompiling.
    #[test]
    fn test_kernel_registry_caching() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();

        registry.register_source(
            "test_add",
            r#"
            #include <metal_stdlib>
            using namespace metal;
            kernel void test_add(
                device float *a [[buffer(0)]],
                device float *b [[buffer(1)]],
                device float *c [[buffer(2)]],
                uint id [[thread_position_in_grid]]
            ) {
                c[id] = a[id] + b[id];
            }
            "#,
        );

        // Registering source alone must not compile/cache anything yet.
        assert!(!registry.is_cached("test_add"));
        let p1 = registry
            .get_pipeline("test_add", device.metal_device())
            .expect("get_pipeline first call");
        let p1_ptr = p1 as *const _;
        assert!(registry.is_cached("test_add"));

        let p2 = registry
            .get_pipeline("test_add", device.metal_device())
            .expect("get_pipeline second call");
        let p2_ptr = p2 as *const _;

        assert_eq!(
            p1_ptr, p2_ptr,
            "Second get_pipeline call should return the same cached pipeline"
        );
    }

    /// A zero-byte allocation is rejected with `InvalidArgument`.
    #[test]
    fn test_buffer_alloc_zero_len_error() {
        let device = MlxDevice::new().expect("device");
        let result = device.alloc_buffer(0, DType::F32, vec![]);
        assert!(result.is_err(), "Zero-length allocation should fail");
        match result {
            Err(MlxError::InvalidArgument(_)) => {}
            other => panic!("Expected InvalidArgument, got {other:?}"),
        }
    }

    /// Requesting a pipeline for an unregistered name yields
    /// `KernelNotFound` carrying that name.
    #[test]
    fn test_kernel_not_found() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let result = registry.get_pipeline("nonexistent_kernel", device.metal_device());
        assert!(result.is_err());
        match result {
            Err(MlxError::KernelNotFound(name)) => {
                assert_eq!(name, "nonexistent_kernel");
            }
            other => panic!("Expected KernelNotFound, got {other:?}"),
        }
    }

    /// `DType::size_of` matches the storage width of each element type.
    #[test]
    fn test_dtype_sizes() {
        assert_eq!(DType::F32.size_of(), 4);
        assert_eq!(DType::F16.size_of(), 2);
        assert_eq!(DType::BF16.size_of(), 2);
        assert_eq!(DType::U8.size_of(), 1);
        assert_eq!(DType::U16.size_of(), 2);
        assert_eq!(DType::U32.size_of(), 4);
        assert_eq!(DType::I32.size_of(), 4);
    }

    /// The Debug representation names the type and includes dtype and shape.
    #[test]
    fn test_buffer_debug() {
        let device = MlxDevice::new().expect("device");
        let buf = device
            .alloc_buffer(64, DType::F16, vec![4, 8])
            .expect("alloc_buffer");
        let debug_str = format!("{buf:?}");
        assert!(debug_str.contains("MlxBuffer"));
        assert!(debug_str.contains("F16"));
        assert!(debug_str.contains("[4, 8]"));
    }

    /// Display output for errors carries the human-readable detail fields.
    #[test]
    fn test_error_display() {
        let e = MlxError::DeviceNotFound;
        assert!(format!("{e}").contains("Metal GPU device"));

        let e = MlxError::ShaderCompilationError {
            name: "foo".into(),
            message: "syntax error".into(),
        };
        assert!(format!("{e}").contains("foo"));
        assert!(format!("{e}").contains("syntax error"));
    }

    /// Released buffers are grouped by size bucket: after releasing 100-,
    /// 128-, and 200-byte buffers, only two distinct buckets remain free —
    /// evidently 100 is rounded up into the same bucket as 128 (presumably
    /// power-of-two rounding — the pool's bucketing rule is not visible here).
    #[test]
    fn test_buffer_pool_size_buckets() {
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();

        let buf_100 = pool.alloc(&device, 100, DType::U8, vec![100]).expect("alloc 100");
        assert!(
            buf_100.byte_len() >= 100,
            "Buffer should be at least 100 bytes"
        );
        pool.release(buf_100);

        let buf_128 = pool.alloc(&device, 128, DType::U8, vec![128]).expect("alloc 128");
        assert!(buf_128.byte_len() >= 128);
        pool.release(buf_128);

        let buf_200 = pool.alloc(&device, 200, DType::U8, vec![200]).expect("alloc 200");
        assert!(buf_200.byte_len() >= 200);
        pool.release(buf_200);

        assert_eq!(pool.free_count(), 2, "Two different bucket sizes in pool");
    }
}