1#![deny(clippy::panic, clippy::unwrap_used, clippy::expect_used)]
41#![allow(unexpected_cfgs)]
45
46#[macro_use]
48mod error;
49mod buffer;
50mod buffer_pool;
51mod device;
52mod dtypes;
53mod encoder;
54mod kernel_registry;
55pub mod gguf;
56pub mod graph;
57pub mod ops;
58pub mod turboquant;
59pub mod weight;
60
61pub use buffer::MlxBuffer;
63pub use buffer_pool::MlxBufferPool;
64pub use device::MlxDevice;
65pub use dtypes::DType;
66pub use encoder::{
67 dispatch_count, reset_counters, sync_count, CapturedNode, CommandEncoder, DispatchKind,
68 RecordedBinding,
69};
70pub use error::{MlxError, Result};
71pub use graph::{ComputeGraph, GraphExecutor, GraphSession, OpKind};
72pub use kernel_registry::KernelRegistry;
73
74pub use gguf::{GgufFile, MetadataValue, TensorInfo};
76
77pub use ops::quantized_matmul::{quantized_matmul, quantized_matmul_simd, QuantizedMatmulParams};
79pub use ops::quantized_matmul_ggml::{
80 quantized_matmul_ggml, GgmlQuantizedMatmulParams, GgmlType,
81};
82pub use ops::quantized_matmul_id::{quantized_matmul_id, QuantizedMatmulIdParams};
83pub use ops::quantized_matmul_id_ggml::{
84 quantized_matmul_id_ggml, GgmlQuantizedMatmulIdParams,
85};
86
87pub use weight::{
89 load_quantized_weights, safetensors_to_metal_buffer, QuantizationConfig, QuantizedWeight,
90 SafetensorsFile, TensorQuantConfig,
91};
92
93pub use metal::MTLSize;
95pub use metal;
96
// Smoke tests for the crate's re-exported API surface.
//
// The crate root denies `clippy::panic`, `clippy::unwrap_used` and
// `clippy::expect_used`; those lints are relaxed here because panicking is the
// normal way for a test to report failure.
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Compiles only when `T: Send`.
    fn _assert_send<T: Send>() {}

    /// Compiles only when `T: Sync`.
    fn _assert_sync<T: Sync>() {}

    /// Compile-time check that the core public types are `Send` and `Sync`.
    ///
    /// This function is never called; type-checking its body is the test.
    #[allow(dead_code)]
    fn assert_send_sync() {
        _assert_send::<MlxDevice>();
        _assert_send::<MlxBuffer>();
        _assert_send::<MlxError>();

        _assert_sync::<MlxDevice>();
        _assert_sync::<MlxBuffer>();
        _assert_sync::<MlxError>();
    }

    /// A device can be created and reports a non-empty name.
    #[test]
    fn test_device_init() {
        let gpu = MlxDevice::new().expect("MlxDevice::new() should succeed on Apple Silicon");

        let name = gpu.name();
        assert!(!name.is_empty(), "Device name should not be empty");
        println!("Metal device: {name}");
    }

    /// A freshly allocated buffer echoes back the dtype, shape and sizes it
    /// was created with.
    #[test]
    fn test_buffer_alloc() {
        let gpu = MlxDevice::new().expect("device");

        let dims = vec![2, 3, 4];
        let elem_size = DType::F32.size_of();
        let total_bytes = 2 * 3 * 4 * elem_size;

        let buffer = gpu
            .alloc_buffer(total_bytes, DType::F32, dims.clone())
            .expect("alloc_buffer");

        assert_eq!(buffer.dtype(), DType::F32);
        assert_eq!(buffer.shape(), &dims);
        assert_eq!(buffer.byte_len(), total_bytes);
        assert_eq!(buffer.element_count(), 24);
    }

    /// Values written through the mutable slice view are read back unchanged
    /// through the shared slice view.
    #[test]
    fn test_buffer_readwrite() {
        let gpu = MlxDevice::new().expect("device");

        let n: usize = 64;
        let total_bytes = n * std::mem::size_of::<f32>();
        let mut buffer = gpu
            .alloc_buffer(total_bytes, DType::F32, vec![n])
            .expect("alloc_buffer");

        // Value stored at (and later expected from) element `i`.
        let pattern = |i: usize| i as f32 * 1.5;

        // Write phase: the mutable borrow ends after the last use below.
        let writable: &mut [f32] = buffer.as_mut_slice().expect("as_mut_slice");
        assert_eq!(writable.len(), n);
        writable
            .iter_mut()
            .enumerate()
            .for_each(|(i, slot)| *slot = pattern(i));

        // Read phase: every element must match what was written.
        let readable: &[f32] = buffer.as_slice().expect("as_slice");
        for (i, &val) in readable.iter().enumerate() {
            let expected = pattern(i);
            assert!(
                (val - expected).abs() < f32::EPSILON,
                "Mismatch at index {i}: got {val}, expected {expected}"
            );
        }
    }

    /// An encoder with no recorded work can still be committed and waited on.
    #[test]
    fn test_encoder_lifecycle() {
        let gpu = MlxDevice::new().expect("device");

        let mut encoder = gpu.command_encoder().expect("command_encoder");
        let outcome = encoder.commit_and_wait();
        outcome.expect("commit_and_wait on empty encoder");
    }

    /// Releasing a buffer and requesting the same size again hands back the
    /// same underlying allocation.
    #[test]
    fn test_buffer_pool_reuse() {
        let gpu = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new(&gpu);

        // First allocation: remember its identity, then return it to the pool.
        let first = pool
            .alloc(1024, DType::F32, vec![256])
            .expect("pool alloc 1");
        let (first_ptr, first_len) = (first.contents_ptr(), first.byte_len());
        pool.release(first);
        assert_eq!(pool.free_count(), 1);

        // Second allocation of the same size should come from the free list.
        let second = pool
            .alloc(1024, DType::F32, vec![256])
            .expect("pool alloc 2");
        let (second_ptr, second_len) = (second.contents_ptr(), second.byte_len());

        assert_eq!(first_ptr, second_ptr, "Pool should reuse the same Metal buffer");
        assert_eq!(first_len, second_len, "Byte lengths should match");
        assert_eq!(pool.free_count(), 0, "Free list should be empty after reuse");
    }

    /// The registry reports a kernel as cached only after its first pipeline
    /// lookup, and a second lookup returns the very same pipeline object.
    #[test]
    fn test_kernel_registry_caching() {
        // Must match the `kernel void` function name in the shader below.
        const KERNEL: &str = "test_add";

        let gpu = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();

        registry.register_source(
            KERNEL,
            r#"
            #include <metal_stdlib>
            using namespace metal;
            kernel void test_add(
                device float *a [[buffer(0)]],
                device float *b [[buffer(1)]],
                device float *c [[buffer(2)]],
                uint id [[thread_position_in_grid]]
            ) {
                c[id] = a[id] + b[id];
            }
            "#,
        );

        // Registering the source alone must not populate the cache.
        assert!(!registry.is_cached(KERNEL));

        // Casting to a raw pointer ends the borrow of `registry`, so it can be
        // queried again while the pipeline's address is kept for comparison.
        let first_ptr = registry
            .get_pipeline(KERNEL, gpu.metal_device())
            .expect("get_pipeline first call") as *const _;
        assert!(registry.is_cached(KERNEL));

        let second_ptr = registry
            .get_pipeline(KERNEL, gpu.metal_device())
            .expect("get_pipeline second call") as *const _;

        assert_eq!(
            first_ptr, second_ptr,
            "Second get_pipeline call should return the same cached pipeline"
        );
    }

    /// Requesting a zero-byte buffer is rejected with `InvalidArgument`.
    #[test]
    fn test_buffer_alloc_zero_len_error() {
        let gpu = MlxDevice::new().expect("device");

        let outcome = gpu.alloc_buffer(0, DType::F32, vec![]);
        assert!(outcome.is_err(), "Zero-length allocation should fail");

        // The failure must be the specific `InvalidArgument` variant.
        if !matches!(outcome, Err(MlxError::InvalidArgument(_))) {
            panic!("Expected InvalidArgument, got {:?}", outcome);
        }
    }

    /// Looking up an unregistered kernel yields `KernelNotFound` carrying the
    /// requested name.
    #[test]
    fn test_kernel_not_found() {
        const MISSING: &str = "nonexistent_kernel";

        let gpu = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();

        let lookup = registry.get_pipeline(MISSING, gpu.metal_device());
        assert!(lookup.is_err());

        match lookup {
            Err(MlxError::KernelNotFound(name)) => assert_eq!(name, MISSING),
            other => panic!("Expected KernelNotFound, got {:?}", other),
        }
    }

    /// Each dtype reports the expected element size in bytes.
    #[test]
    fn test_dtype_sizes() {
        let expected_sizes = [
            (DType::F32, 4),
            (DType::F16, 2),
            (DType::BF16, 2),
            (DType::U8, 1),
            (DType::U16, 2),
            (DType::U32, 4),
            (DType::I32, 4),
        ];

        for (dtype, bytes) in expected_sizes {
            assert_eq!(dtype.size_of(), bytes);
        }
    }

    /// The `Debug` rendering of a buffer mentions its type name, dtype and
    /// shape.
    #[test]
    fn test_buffer_debug() {
        let gpu = MlxDevice::new().expect("device");
        let buffer = gpu
            .alloc_buffer(64, DType::F16, vec![4, 8])
            .expect("alloc_buffer");

        let rendered = format!("{:?}", buffer);

        assert!(rendered.contains("MlxBuffer"));
        assert!(rendered.contains("F16"));
        assert!(rendered.contains("[4, 8]"));
    }

    /// The `Display` output of errors includes the details a user needs.
    #[test]
    fn test_error_display() {
        let not_found = MlxError::DeviceNotFound.to_string();
        assert!(not_found.contains("Metal GPU device"));

        let compile_failure = MlxError::ShaderCompilationError {
            name: "foo".into(),
            message: "syntax error".into(),
        };
        // Rendered separately for each check, as two independent formats.
        assert!(compile_failure.to_string().contains("foo"));
        assert!(compile_failure.to_string().contains("syntax error"));
    }

    /// Buffers handed out by the pool are at least as large as requested, and
    /// after allocating and releasing 100, 128 and 200 bytes exactly two
    /// buffers remain on the free list.
    #[test]
    fn test_buffer_pool_size_buckets() {
        let gpu = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new(&gpu);

        let smallest = pool.alloc(100, DType::U8, vec![100]).expect("alloc 100");
        assert!(
            smallest.byte_len() >= 100,
            "Buffer should be at least 100 bytes"
        );
        pool.release(smallest);

        // NOTE(review): the final count of 2 presumably relies on the 128-byte
        // request reusing the buffer released above — confirm against the
        // pool's bucketing policy.
        for (size, label) in [(128, "alloc 128"), (200, "alloc 200")] {
            let buffer = pool.alloc(size, DType::U8, vec![size]).expect(label);
            assert!(buffer.byte_len() >= size);
            pool.release(buffer);
        }

        assert_eq!(pool.free_count(), 2, "Two different bucket sizes in pool");
    }
}