//! Crate root: declares the crate's modules and re-exports its public API.
//!
//! NOTE(review): the type names and the re-exported `metal` crate suggest a
//! Metal GPU compute layer — confirm against the crate README before relying
//! on this summary.

// Non-test code must not panic, `unwrap()` or `expect()`; the `tests` module
// at the bottom of this file re-allows all three lints for itself.
#![deny(clippy::panic, clippy::unwrap_used, clippy::expect_used)]
// Silences the `unexpected_cfgs` lint crate-wide. The cfg names that need
// this are not visible in this file.
#![allow(unexpected_cfgs)]

// `#[macro_use]` keeps any macros defined in `error` in scope for the modules
// declared after it, so this declaration has to stay ahead of the others.
#[macro_use]
mod error;

// Private modules; selected items are re-exported below.
mod buffer;
mod buffer_pool;
mod device;
mod dtypes;
mod encoder;
mod kernel_registry;

// Public modules, reachable by path from outside the crate.
pub mod gguf;
pub mod graph;
pub mod ops;
pub mod turboquant;
pub mod weight;

// Items re-exported from the private modules.
pub use buffer::MlxBuffer;
pub use buffer_pool::MlxBufferPool;
pub use device::MlxDevice;
pub use dtypes::DType;
pub use encoder::{
dispatch_count, reset_counters, sync_count, CapturedNode, CommandEncoder, DispatchKind,
RecordedBinding,
};
pub use error::{MlxError, Result};

// Shortcuts to items that are also reachable through the public modules.
pub use graph::{ComputeGraph, GraphExecutor, GraphSession, OpKind};
pub use kernel_registry::KernelRegistry;
pub use gguf::{GgufFile, MetadataValue, TensorInfo};

// Quantized matmul entry points and their parameter types, lifted to the
// crate root.
pub use ops::quantized_matmul::{quantized_matmul, quantized_matmul_simd, QuantizedMatmulParams};
pub use ops::quantized_matmul_ggml::{
quantized_matmul_ggml, GgmlQuantizedMatmulParams, GgmlType,
};
pub use ops::quantized_matmul_id::{quantized_matmul_id, QuantizedMatmulIdParams};
pub use ops::quantized_matmul_id_ggml::{
quantized_matmul_id_ggml, GgmlQuantizedMatmulIdParams,
};

// Weight-loading helpers and their configuration types.
pub use weight::{
load_quantized_weights, safetensors_to_metal_buffer, QuantizationConfig, QuantizedWeight,
SafetensorsFile, TensorQuantConfig,
};

// Re-export `MTLSize` and the whole `metal` crate under this crate's
// namespace — presumably so downstream code uses the same `metal` version;
// confirm intent.
pub use metal::MTLSize;
pub use metal;
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
use super::*;
/// Compiles only when `T` is `Send`; does nothing at runtime.
fn _assert_send<T>()
where
    T: Send,
{
}
/// Compiles only when `T` is `Sync`; does nothing at runtime.
fn _assert_sync<T>()
where
    T: Sync,
{
}
/// Compile-time check that `MlxDevice`, `MlxBuffer` and `MlxError` are both
/// `Send` and `Sync`.
///
/// Nothing in this module calls the function (hence `dead_code`); the check
/// happens when the body is type-checked.
#[allow(dead_code)]
fn assert_send_sync() {
    // Bundles the two marker-trait checks for a single type.
    fn send_and_sync<T: Send + Sync>() {
        _assert_send::<T>();
        _assert_sync::<T>();
    }

    send_and_sync::<MlxDevice>();
    send_and_sync::<MlxBuffer>();
    send_and_sync::<MlxError>();
}
/// Creates a device and checks that it reports a non-empty name.
#[test]
fn test_device_init() {
    let gpu = MlxDevice::new().expect("MlxDevice::new() should succeed on Apple Silicon");
    let name = gpu.name();
    let has_name = !name.is_empty();
    assert!(has_name, "Device name should not be empty");
    // Only visible when the test harness is run with `--nocapture`.
    println!("Metal device: {name}");
}
/// Allocates a 2x3x4 `F32` buffer and checks that dtype, shape, byte length
/// and element count are reported back unchanged.
#[test]
fn test_buffer_alloc() {
    let gpu = MlxDevice::new().expect("device");
    let dims = vec![2, 3, 4];
    let elem_size = DType::F32.size_of();
    let expected_bytes = 2 * 3 * 4 * elem_size;

    let tensor = gpu
        .alloc_buffer(expected_bytes, DType::F32, dims.clone())
        .expect("alloc_buffer");

    assert_eq!(tensor.dtype(), DType::F32);
    assert_eq!(tensor.shape(), &dims);
    assert_eq!(tensor.byte_len(), expected_bytes);
    assert_eq!(tensor.element_count(), 24);
}
/// Writes `i * 1.5` into every element of a 64-element `F32` buffer through
/// the mutable slice view, then reads the values back through the shared view.
#[test]
fn test_buffer_readwrite() {
    let gpu = MlxDevice::new().expect("device");
    let n = 64;
    let mut buf = gpu
        .alloc_buffer(n * std::mem::size_of::<f32>(), DType::F32, vec![n])
        .expect("alloc_buffer");

    // Write phase: the mutable borrow ends after the last use of `writable`.
    let writable: &mut [f32] = buf.as_mut_slice().expect("as_mut_slice");
    assert_eq!(writable.len(), n);
    writable
        .iter_mut()
        .enumerate()
        .for_each(|(idx, slot)| *slot = idx as f32 * 1.5);

    // Read phase: every element must hold the value written above.
    let readable: &[f32] = buf.as_slice().expect("as_slice");
    for (i, &val) in readable.iter().enumerate() {
        let expected = i as f32 * 1.5;
        let delta = (val - expected).abs();
        assert!(
            delta < f32::EPSILON,
            "Mismatch at index {i}: got {val}, expected {expected}"
        );
    }
}
/// Creates a command encoder and commits it without recording any work;
/// the commit-and-wait call must still succeed.
#[test]
fn test_encoder_lifecycle() {
    let gpu = MlxDevice::new().expect("device");
    let mut encoder = gpu.command_encoder().expect("command_encoder");
    let outcome = encoder.commit_and_wait();
    outcome.expect("commit_and_wait on empty encoder");
}
/// Releases a pooled buffer and requests the same size again: the second
/// allocation must come back with the same contents pointer and byte length,
/// and the free list must be empty afterwards.
#[test]
fn test_buffer_pool_reuse() {
    let gpu = MlxDevice::new().expect("device");
    let mut pool = MlxBufferPool::new(&gpu);

    // First allocation: remember its identity, then hand it back.
    let first = pool
        .alloc(1024, DType::F32, vec![256])
        .expect("pool alloc 1");
    let (first_ptr, first_len) = (first.contents_ptr(), first.byte_len());
    pool.release(first);
    assert_eq!(pool.free_count(), 1);

    // Second allocation with identical parameters.
    let second = pool
        .alloc(1024, DType::F32, vec![256])
        .expect("pool alloc 2");
    let (second_ptr, second_len) = (second.contents_ptr(), second.byte_len());

    assert_eq!(first_ptr, second_ptr, "Pool should reuse the same Metal buffer");
    assert_eq!(first_len, second_len, "Byte lengths should match");
    assert_eq!(pool.free_count(), 0, "Free list should be empty after reuse");
}
/// Registers a small element-wise add kernel and checks that the registry
/// reports it as uncached before the first `get_pipeline` call, cached after
/// it, and returns the same pipeline address on the second call.
#[test]
fn test_kernel_registry_caching() {
    const KERNEL: &str = "test_add";

    let gpu = MlxDevice::new().expect("device");
    let mut kernels = KernelRegistry::new();
    let msl_source = r#"
#include <metal_stdlib>
using namespace metal;
kernel void test_add(
device float *a [[buffer(0)]],
device float *b [[buffer(1)]],
device float *c [[buffer(2)]],
uint id [[thread_position_in_grid]]
) {
c[id] = a[id] + b[id];
}
"#;
    kernels.register_source(KERNEL, msl_source);
    assert!(!kernels.is_cached(KERNEL));

    // Keep only raw addresses so the borrow of `kernels` ends with each
    // statement and the registry can be queried again.
    let first_ptr = kernels
        .get_pipeline(KERNEL, gpu.metal_device())
        .expect("get_pipeline first call") as *const _;
    assert!(kernels.is_cached(KERNEL));

    let second_ptr = kernels
        .get_pipeline(KERNEL, gpu.metal_device())
        .expect("get_pipeline second call") as *const _;

    assert_eq!(
        first_ptr, second_ptr,
        "Second get_pipeline call should return the same cached pipeline"
    );
}
/// A zero-byte allocation request must be rejected with
/// `MlxError::InvalidArgument`.
#[test]
fn test_buffer_alloc_zero_len_error() {
    let gpu = MlxDevice::new().expect("device");
    let outcome = gpu.alloc_buffer(0, DType::F32, vec![]);
    assert!(outcome.is_err(), "Zero-length allocation should fail");

    // Any other error variant is a failure as well.
    if !matches!(&outcome, Err(MlxError::InvalidArgument(_))) {
        panic!("Expected InvalidArgument, got {:?}", outcome);
    }
}
/// Requesting a pipeline for a kernel that was never registered must fail
/// with `MlxError::KernelNotFound` carrying the requested name.
#[test]
fn test_kernel_not_found() {
    let gpu = MlxDevice::new().expect("device");
    let mut kernels = KernelRegistry::new();

    let lookup = kernels.get_pipeline("nonexistent_kernel", gpu.metal_device());
    assert!(lookup.is_err());

    match lookup {
        Err(MlxError::KernelNotFound(reported)) => assert_eq!(reported, "nonexistent_kernel"),
        unexpected => panic!("Expected KernelNotFound, got {:?}", unexpected),
    }
}
/// Checks the per-element byte size reported by each `DType` variant listed
/// in the table below.
#[test]
fn test_dtype_sizes() {
    // (variant, expected size in bytes), checked in this order.
    let expected_sizes = [
        (DType::F32, 4),
        (DType::F16, 2),
        (DType::BF16, 2),
        (DType::U8, 1),
        (DType::U16, 2),
        (DType::U32, 4),
        (DType::I32, 4),
    ];

    for (dtype, bytes) in expected_sizes {
        assert_eq!(dtype.size_of(), bytes);
    }
}
/// The `Debug` output of a buffer must mention the type name, the dtype and
/// the shape.
#[test]
fn test_buffer_debug() {
    let gpu = MlxDevice::new().expect("device");
    let half_buf = gpu
        .alloc_buffer(64, DType::F16, vec![4, 8])
        .expect("alloc_buffer");

    let rendered = format!("{:?}", half_buf);

    let mentions_type = rendered.contains("MlxBuffer");
    assert!(mentions_type);
    let mentions_dtype = rendered.contains("F16");
    assert!(mentions_dtype);
    let mentions_shape = rendered.contains("[4, 8]");
    assert!(mentions_shape);
}
/// Checks the `Display` text of two error variants: `DeviceNotFound` must
/// mention "Metal GPU device", and `ShaderCompilationError` must include both
/// its `name` and its `message`.
#[test]
fn test_error_display() {
    let not_found_text = MlxError::DeviceNotFound.to_string();
    assert!(not_found_text.contains("Metal GPU device"));

    // Render once and check both fields against the same string.
    let compile_text = MlxError::ShaderCompilationError {
        name: "foo".into(),
        message: "syntax error".into(),
    }
    .to_string();
    assert!(compile_text.contains("foo"));
    assert!(compile_text.contains("syntax error"));
}
/// Allocates and releases 100-, 128- and 200-byte buffers in turn. Each must
/// be at least as large as requested, and two buffers must remain on the
/// free list at the end.
///
/// NOTE(review): the final count presumably relies on the 100- and 128-byte
/// requests sharing a bucket — confirm against the pool's rounding policy.
#[test]
fn test_buffer_pool_size_buckets() {
    let gpu = MlxDevice::new().expect("device");
    let mut pool = MlxBufferPool::new(&gpu);

    let small = pool.alloc(100, DType::U8, vec![100]).expect("alloc 100");
    let small_len = small.byte_len();
    assert!(small_len >= 100, "Buffer should be at least 100 bytes");
    pool.release(small);

    let medium = pool.alloc(128, DType::U8, vec![128]).expect("alloc 128");
    let medium_len = medium.byte_len();
    assert!(medium_len >= 128);
    pool.release(medium);

    let large = pool.alloc(200, DType::U8, vec![200]).expect("alloc 200");
    let large_len = large.byte_len();
    assert!(large_len >= 200);
    pool.release(large);

    let parked = pool.free_count();
    assert_eq!(parked, 2, "Two different bucket sizes in pool");
}
}