Skip to main content

apple_mlx/
lib.rs

1//! Rust bindings for Apple MLX through the official `mlx-c` C API.
2//!
3//! `raw` exposes the generated low-level bindings.
4//! The top-level types provide a small safe wrapper over a subset of the API.
5
6use std::error::Error as StdError;
7use std::ffi::{CStr, CString};
8use std::fmt;
9use std::os::raw::c_int;
10use std::ptr;
11use std::slice;
12
/// Low-level bindings, pulled in verbatim from the `bindings.rs` file located in the
/// build's `OUT_DIR`.
pub mod raw {
    // The included file is generated code; these lints are silenced for this module only.
    // NOTE(review): presumably the generated items keep the C API's naming and style,
    // which is what the naming lints below would otherwise flag — confirm against the
    // generated output.
    #![allow(
        clippy::all,
        non_camel_case_types,
        non_snake_case,
        non_upper_case_globals,
        unsafe_op_in_unsafe_fn
    )]

    include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
}
24
// Short private aliases for the raw handle types wrapped by the safe types in this file.
type MlxArrayRaw = raw::mlx_array;
type MlxDeviceRaw = raw::mlx_device;
type MlxStreamRaw = raw::mlx_stream;
type MlxDeviceInfoRaw = raw::mlx_device_info;

// Raw enum values the wrappers refer to: the complex64 dtype and the CPU/GPU device types.
const MLX_DTYPE_COMPLEX64: raw::mlx_dtype = raw::mlx_dtype__MLX_COMPLEX64;
const MLX_DEVICE_CPU: raw::mlx_device_type = raw::mlx_device_type__MLX_CPU;
const MLX_DEVICE_GPU: raw::mlx_device_type = raw::mlx_device_type__MLX_GPU;
33
/// Error returned by the safe wrappers in this crate; it carries a human-readable message.
#[derive(Debug)]
pub struct Error(String);

impl fmt::Display for Error {
    /// Writes the stored message verbatim, without applying any padding or width flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(message) = self;
        f.write_str(message)
    }
}

// No underlying cause is stored, so the trait's default methods are sufficient.
impl StdError for Error {}

/// Result alias whose error type is this crate's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
46
47fn check(code: c_int, context: &str) -> Result<()> {
48    if code == 0 {
49        Ok(())
50    } else {
51        Err(Error(format!(
52            "{context} failed with MLX error code {code}"
53        )))
54    }
55}
56
/// Single-precision complex number with C layout: the real part followed by the imaginary part.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Complex32 {
    /// Real part.
    pub re: f32,
    /// Imaginary part.
    pub im: f32,
}

impl Complex32 {
    /// Builds a value from its real and imaginary parts.
    pub const fn new(re: f32, im: f32) -> Self {
        Complex32 { re, im }
    }
}

impl fmt::Display for Complex32 {
    /// Renders both parts with three decimals; the imaginary part always carries its sign
    /// and is followed by `i` (for example `1.000-2.000i`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { re, im } = *self;
        write!(f, "{:.3}{:+.3}i", re, im)
    }
}
75
76struct DeviceInfo {
77    raw: MlxDeviceInfoRaw,
78}
79
80impl DeviceInfo {
81    fn load(device: &Device) -> Result<Self> {
82        let mut raw = MlxDeviceInfoRaw {
83            ctx: ptr::null_mut(),
84        };
85        unsafe {
86            check(
87                raw::mlx_device_info_get(&mut raw, device.raw),
88                "mlx_device_info_get",
89            )?;
90        }
91        if raw.ctx.is_null() {
92            return Err(Error("mlx_device_info_get returned a null handle".into()));
93        }
94        Ok(Self { raw })
95    }
96
97    fn get_string(&self, key: &str) -> Result<Option<String>> {
98        let key = CString::new(key)
99            .map_err(|_| Error(format!("device info key contains interior null: {key:?}")))?;
100        let mut exists = false;
101        unsafe {
102            check(
103                raw::mlx_device_info_has_key(&mut exists, self.raw, key.as_ptr()),
104                "mlx_device_info_has_key",
105            )?;
106        }
107        if !exists {
108            return Ok(None);
109        }
110
111        let mut value = ptr::null();
112        unsafe {
113            check(
114                raw::mlx_device_info_get_string(&mut value, self.raw, key.as_ptr()),
115                "mlx_device_info_get_string",
116            )?;
117            if value.is_null() {
118                return Ok(None);
119            }
120            Ok(Some(CStr::from_ptr(value).to_string_lossy().into_owned()))
121        }
122    }
123}
124
125impl Drop for DeviceInfo {
126    fn drop(&mut self) {
127        unsafe {
128            let _ = raw::mlx_device_info_free(self.raw);
129        }
130    }
131}
132
133pub struct Device {
134    raw: MlxDeviceRaw,
135}
136
137impl Device {
138    pub fn gpu_if_available() -> Result<Option<Self>> {
139        let raw = unsafe { raw::mlx_device_new_type(MLX_DEVICE_GPU, 0) };
140        let device = Self { raw };
141        let mut available = false;
142        unsafe {
143            check(
144                raw::mlx_device_is_available(&mut available, device.raw),
145                "mlx_device_is_available",
146            )?;
147        }
148        if available {
149            Ok(Some(device))
150        } else {
151            Ok(None)
152        }
153    }
154
155    pub fn cpu() -> Self {
156        let raw = unsafe { raw::mlx_device_new_type(MLX_DEVICE_CPU, 0) };
157        Self { raw }
158    }
159
160    pub fn preferred() -> Result<Self> {
161        if let Some(gpu) = Self::gpu_if_available()? {
162            return Ok(gpu);
163        }
164        Ok(Self::cpu())
165    }
166
167    pub fn kind(&self) -> Result<&'static str> {
168        let mut kind = MLX_DEVICE_CPU;
169        unsafe {
170            check(
171                raw::mlx_device_get_type(&mut kind, self.raw),
172                "mlx_device_get_type",
173            )?;
174        }
175        Ok(match kind {
176            MLX_DEVICE_CPU => "CPU",
177            MLX_DEVICE_GPU => "GPU",
178            _ => "Unknown",
179        })
180    }
181
182    pub fn index(&self) -> Result<i32> {
183        let mut index = 0;
184        unsafe {
185            check(
186                raw::mlx_device_get_index(&mut index, self.raw),
187                "mlx_device_get_index",
188            )?;
189        }
190        Ok(index)
191    }
192
193    pub fn name(&self) -> Result<String> {
194        let info = DeviceInfo::load(self)?;
195        if let Some(name) = info.get_string("device_name")? {
196            return Ok(name);
197        }
198        Ok(format!("{} device {}", self.kind()?, self.index()?))
199    }
200}
201
202impl Drop for Device {
203    fn drop(&mut self) {
204        unsafe {
205            let _ = raw::mlx_device_free(self.raw);
206        }
207    }
208}
209
210pub struct Stream {
211    raw: MlxStreamRaw,
212}
213
214impl Stream {
215    pub fn new(device: &Device) -> Self {
216        let raw = unsafe { raw::mlx_stream_new_device(device.raw) };
217        Self { raw }
218    }
219
220    pub fn synchronize(&self) -> Result<()> {
221        unsafe { check(raw::mlx_synchronize(self.raw), "mlx_synchronize") }
222    }
223}
224
225impl Drop for Stream {
226    fn drop(&mut self) {
227        unsafe {
228            let _ = raw::mlx_stream_free(self.raw);
229        }
230    }
231}
232
233pub struct Array {
234    raw: MlxArrayRaw,
235}
236
237impl Array {
238    pub fn from_complex_matrix(rows: usize, cols: usize, values: &[Complex32]) -> Result<Self> {
239        if rows * cols != values.len() {
240            return Err(Error(format!(
241                "shape {rows}x{cols} does not match {} values",
242                values.len()
243            )));
244        }
245
246        let shape = [rows as c_int, cols as c_int];
247        let raw = unsafe {
248            raw::mlx_array_new_data(
249                values.as_ptr().cast(),
250                shape.as_ptr(),
251                shape.len() as c_int,
252                MLX_DTYPE_COMPLEX64,
253            )
254        };
255
256        if raw.ctx.is_null() {
257            return Err(Error("mlx_array_new_data returned a null handle".into()));
258        }
259
260        Ok(Self { raw })
261    }
262
263    pub fn matmul(&self, rhs: &Self, stream: &Stream) -> Result<Self> {
264        let mut out = MlxArrayRaw {
265            ctx: ptr::null_mut(),
266        };
267        unsafe {
268            check(
269                raw::mlx_matmul(&mut out, self.raw, rhs.raw, stream.raw),
270                "mlx_matmul",
271            )?;
272        }
273        Ok(Self { raw: out })
274    }
275
276    pub fn max_abs_error(&self, rhs: &Self, stream: &Stream) -> Result<f32> {
277        let mut delta = MlxArrayRaw {
278            ctx: ptr::null_mut(),
279        };
280        let mut magnitude = MlxArrayRaw {
281            ctx: ptr::null_mut(),
282        };
283        let mut max_value = MlxArrayRaw {
284            ctx: ptr::null_mut(),
285        };
286
287        unsafe {
288            check(
289                raw::mlx_subtract(&mut delta, self.raw, rhs.raw, stream.raw),
290                "mlx_subtract",
291            )?;
292            check(raw::mlx_abs(&mut magnitude, delta, stream.raw), "mlx_abs")?;
293            check(
294                raw::mlx_max(&mut max_value, magnitude, false, stream.raw),
295                "mlx_max",
296            )?;
297            check(raw::mlx_array_eval(max_value), "mlx_array_eval")?;
298            stream.synchronize()?;
299            let mut value = 0.0;
300            check(
301                raw::mlx_array_item_float32(&mut value, max_value),
302                "mlx_array_item_float32",
303            )?;
304            let _ = raw::mlx_array_free(delta);
305            let _ = raw::mlx_array_free(magnitude);
306            let _ = raw::mlx_array_free(max_value);
307            Ok(value)
308        }
309    }
310
311    pub fn shape(&self) -> Result<Vec<usize>> {
312        unsafe {
313            let ndim = raw::mlx_array_ndim(self.raw);
314            let shape_ptr = raw::mlx_array_shape(self.raw);
315            if shape_ptr.is_null() {
316                return Err(Error("mlx_array_shape returned a null pointer".into()));
317            }
318            Ok(slice::from_raw_parts(shape_ptr, ndim)
319                .iter()
320                .map(|dim| *dim as usize)
321                .collect())
322        }
323    }
324
325    pub fn to_complex_vec(&self, stream: &Stream) -> Result<Vec<Complex32>> {
326        unsafe {
327            if raw::mlx_array_dtype(self.raw) != MLX_DTYPE_COMPLEX64 {
328                return Err(Error("expected MLX complex64 output".into()));
329            }
330            check(raw::mlx_array_eval(self.raw), "mlx_array_eval")?;
331            stream.synchronize()?;
332            let count = raw::mlx_array_size(self.raw);
333            let ptr = raw::mlx_array_data_complex64(self.raw) as *const Complex32;
334            if ptr.is_null() {
335                return Err(Error("mlx_array_data_complex64 returned null".into()));
336            }
337            Ok(slice::from_raw_parts(ptr, count).to_vec())
338        }
339    }
340}
341
342impl Drop for Array {
343    fn drop(&mut self) {
344        unsafe {
345            let _ = raw::mlx_array_free(self.raw);
346        }
347    }
348}
349
350pub fn cpu_complex_matmul(
351    lhs: &[Complex32],
352    rhs: &[Complex32],
353    lhs_rows: usize,
354    lhs_cols: usize,
355    rhs_cols: usize,
356) -> Vec<Complex32> {
357    let mut out = vec![Complex32::new(0.0, 0.0); lhs_rows * rhs_cols];
358    for row in 0..lhs_rows {
359        for col in 0..rhs_cols {
360            let mut acc = Complex32::new(0.0, 0.0);
361            for k in 0..lhs_cols {
362                let a = lhs[row * lhs_cols + k];
363                let b = rhs[k * rhs_cols + col];
364                acc.re += a.re * b.re - a.im * b.im;
365                acc.im += a.re * b.im + a.im * b.re;
366            }
367            out[row * rhs_cols + col] = acc;
368        }
369    }
370    out
371}
372
373pub fn print_matrix(values: &[Complex32], rows: usize, cols: usize, label: &str) {
374    println!("{label}:");
375    for row in values.chunks(cols).take(rows) {
376        let rendered = row
377            .iter()
378            .map(ToString::to_string)
379            .collect::<Vec<_>>()
380            .join("  ");
381        println!("  {rendered}");
382    }
383}
384
385pub fn demo_complex_matmul() -> Result<()> {
386    let lhs = vec![
387        Complex32::new(1.0, 2.0),
388        Complex32::new(3.0, -1.0),
389        Complex32::new(-2.0, 0.5),
390        Complex32::new(0.0, 4.0),
391    ];
392    let rhs = vec![
393        Complex32::new(0.5, -1.0),
394        Complex32::new(2.0, 0.0),
395        Complex32::new(-3.0, 1.5),
396        Complex32::new(1.0, -2.0),
397    ];
398
399    let device = Device::preferred()?;
400    let stream = Stream::new(&device);
401    let lhs_array = Array::from_complex_matrix(2, 2, &lhs)?;
402    let rhs_array = Array::from_complex_matrix(2, 2, &rhs)?;
403    let product = lhs_array.matmul(&rhs_array, &stream)?;
404    let product_shape = product.shape()?;
405    let product_values = product.to_complex_vec(&stream)?;
406
407    let expected_values = cpu_complex_matmul(&lhs, &rhs, 2, 2, 2);
408    let expected = Array::from_complex_matrix(2, 2, &expected_values)?;
409    let max_abs_error = product.max_abs_error(&expected, &stream)?;
410
411    println!(
412        "Using Apple MLX on {} device {} ({})",
413        device.kind()?,
414        device.index()?,
415        device.name()?
416    );
417    println!("Output shape: {:?}", product_shape);
418    print_matrix(&lhs, 2, 2, "Left matrix");
419    print_matrix(&rhs, 2, 2, "Right matrix");
420    print_matrix(&product_values, 2, 2, "MLX product");
421    println!("Max absolute error vs CPU reference: {max_abs_error:.6}");
422
423    if max_abs_error > 1e-4 {
424        return Err(Error(format!(
425            "MLX result drifted from the CPU reference: {max_abs_error}"
426        )));
427    }
428
429    Ok(())
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the CPU reference product against hand-computed values for one 2x2 case.
    #[test]
    fn cpu_reference_matches_known_values() {
        let lhs = [
            Complex32::new(1.0, 2.0),
            Complex32::new(3.0, -1.0),
            Complex32::new(-2.0, 0.5),
            Complex32::new(0.0, 4.0),
        ];
        let rhs = [
            Complex32::new(0.5, -1.0),
            Complex32::new(2.0, 0.0),
            Complex32::new(-3.0, 1.5),
            Complex32::new(1.0, -2.0),
        ];
        let expected = [
            Complex32::new(-5.0, 7.5),
            Complex32::new(3.0, -3.0),
            Complex32::new(-6.5, -9.75),
            Complex32::new(4.0, 5.0),
        ];

        let actual = cpu_complex_matmul(&lhs, &rhs, 2, 2, 2);

        assert_eq!(actual, expected);
    }
}
461}