1use std::error::Error as StdError;
7use std::ffi::{CStr, CString};
8use std::fmt;
9use std::os::raw::c_int;
10use std::ptr;
11use std::slice;
12
13pub mod raw {
14 #![allow(
15 clippy::all,
16 non_camel_case_types,
17 non_snake_case,
18 non_upper_case_globals,
19 unsafe_op_in_unsafe_fn
20 )]
21
22 include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
23}
24
25type MlxArrayRaw = raw::mlx_array;
26type MlxDeviceRaw = raw::mlx_device;
27type MlxStreamRaw = raw::mlx_stream;
28type MlxDeviceInfoRaw = raw::mlx_device_info;
29
30const MLX_DTYPE_COMPLEX64: raw::mlx_dtype = raw::mlx_dtype__MLX_COMPLEX64;
31const MLX_DEVICE_CPU: raw::mlx_device_type = raw::mlx_device_type__MLX_CPU;
32const MLX_DEVICE_GPU: raw::mlx_device_type = raw::mlx_device_type__MLX_GPU;
33
/// Error type for this module: a human-readable message describing what failed.
#[derive(Debug)]
pub struct Error(String);

impl fmt::Display for Error {
    /// Writes the stored message verbatim; width and fill flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(message) = self;
        f.write_str(message)
    }
}

impl StdError for Error {}

/// Result alias whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
46
47fn check(code: c_int, context: &str) -> Result<()> {
48 if code == 0 {
49 Ok(())
50 } else {
51 Err(Error(format!(
52 "{context} failed with MLX error code {code}"
53 )))
54 }
55}
56
/// Single-precision complex number stored as two consecutive `f32` fields.
///
/// `#[repr(C)]` fixes the field order (`re`, then `im`) so slices of this
/// type can be handed to and read back from the C API as raw buffers.
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(C)]
pub struct Complex32 {
    /// Real part.
    pub re: f32,
    /// Imaginary part.
    pub im: f32,
}

impl Complex32 {
    /// Builds a complex number from its real and imaginary parts.
    pub const fn new(re: f32, im: f32) -> Self {
        Complex32 { re, im }
    }
}

impl fmt::Display for Complex32 {
    /// Renders as `<re><±im>i` with three decimals each, e.g. `1.000-2.000i`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { re, im } = *self;
        // The imaginary part always carries an explicit sign, which doubles
        // as the separator between the two components.
        f.write_fmt(format_args!("{:.3}{:+.3}i", re, im))
    }
}
75
76struct DeviceInfo {
77 raw: MlxDeviceInfoRaw,
78}
79
80impl DeviceInfo {
81 fn load(device: &Device) -> Result<Self> {
82 let mut raw = MlxDeviceInfoRaw {
83 ctx: ptr::null_mut(),
84 };
85 unsafe {
86 check(
87 raw::mlx_device_info_get(&mut raw, device.raw),
88 "mlx_device_info_get",
89 )?;
90 }
91 if raw.ctx.is_null() {
92 return Err(Error("mlx_device_info_get returned a null handle".into()));
93 }
94 Ok(Self { raw })
95 }
96
97 fn get_string(&self, key: &str) -> Result<Option<String>> {
98 let key = CString::new(key)
99 .map_err(|_| Error(format!("device info key contains interior null: {key:?}")))?;
100 let mut exists = false;
101 unsafe {
102 check(
103 raw::mlx_device_info_has_key(&mut exists, self.raw, key.as_ptr()),
104 "mlx_device_info_has_key",
105 )?;
106 }
107 if !exists {
108 return Ok(None);
109 }
110
111 let mut value = ptr::null();
112 unsafe {
113 check(
114 raw::mlx_device_info_get_string(&mut value, self.raw, key.as_ptr()),
115 "mlx_device_info_get_string",
116 )?;
117 if value.is_null() {
118 return Ok(None);
119 }
120 Ok(Some(CStr::from_ptr(value).to_string_lossy().into_owned()))
121 }
122 }
123}
124
125impl Drop for DeviceInfo {
126 fn drop(&mut self) {
127 unsafe {
128 let _ = raw::mlx_device_info_free(self.raw);
129 }
130 }
131}
132
133pub struct Device {
134 raw: MlxDeviceRaw,
135}
136
137impl Device {
138 pub fn gpu_if_available() -> Result<Option<Self>> {
139 let raw = unsafe { raw::mlx_device_new_type(MLX_DEVICE_GPU, 0) };
140 let device = Self { raw };
141 let mut available = false;
142 unsafe {
143 check(
144 raw::mlx_device_is_available(&mut available, device.raw),
145 "mlx_device_is_available",
146 )?;
147 }
148 if available {
149 Ok(Some(device))
150 } else {
151 Ok(None)
152 }
153 }
154
155 pub fn cpu() -> Self {
156 let raw = unsafe { raw::mlx_device_new_type(MLX_DEVICE_CPU, 0) };
157 Self { raw }
158 }
159
160 pub fn preferred() -> Result<Self> {
161 if let Some(gpu) = Self::gpu_if_available()? {
162 return Ok(gpu);
163 }
164 Ok(Self::cpu())
165 }
166
167 pub fn kind(&self) -> Result<&'static str> {
168 let mut kind = MLX_DEVICE_CPU;
169 unsafe {
170 check(
171 raw::mlx_device_get_type(&mut kind, self.raw),
172 "mlx_device_get_type",
173 )?;
174 }
175 Ok(match kind {
176 MLX_DEVICE_CPU => "CPU",
177 MLX_DEVICE_GPU => "GPU",
178 _ => "Unknown",
179 })
180 }
181
182 pub fn index(&self) -> Result<i32> {
183 let mut index = 0;
184 unsafe {
185 check(
186 raw::mlx_device_get_index(&mut index, self.raw),
187 "mlx_device_get_index",
188 )?;
189 }
190 Ok(index)
191 }
192
193 pub fn name(&self) -> Result<String> {
194 let info = DeviceInfo::load(self)?;
195 if let Some(name) = info.get_string("device_name")? {
196 return Ok(name);
197 }
198 Ok(format!("{} device {}", self.kind()?, self.index()?))
199 }
200}
201
202impl Drop for Device {
203 fn drop(&mut self) {
204 unsafe {
205 let _ = raw::mlx_device_free(self.raw);
206 }
207 }
208}
209
210pub struct Stream {
211 raw: MlxStreamRaw,
212}
213
214impl Stream {
215 pub fn new(device: &Device) -> Self {
216 let raw = unsafe { raw::mlx_stream_new_device(device.raw) };
217 Self { raw }
218 }
219
220 pub fn synchronize(&self) -> Result<()> {
221 unsafe { check(raw::mlx_synchronize(self.raw), "mlx_synchronize") }
222 }
223}
224
225impl Drop for Stream {
226 fn drop(&mut self) {
227 unsafe {
228 let _ = raw::mlx_stream_free(self.raw);
229 }
230 }
231}
232
233pub struct Array {
234 raw: MlxArrayRaw,
235}
236
237impl Array {
238 pub fn from_complex_matrix(rows: usize, cols: usize, values: &[Complex32]) -> Result<Self> {
239 if rows * cols != values.len() {
240 return Err(Error(format!(
241 "shape {rows}x{cols} does not match {} values",
242 values.len()
243 )));
244 }
245
246 let shape = [rows as c_int, cols as c_int];
247 let raw = unsafe {
248 raw::mlx_array_new_data(
249 values.as_ptr().cast(),
250 shape.as_ptr(),
251 shape.len() as c_int,
252 MLX_DTYPE_COMPLEX64,
253 )
254 };
255
256 if raw.ctx.is_null() {
257 return Err(Error("mlx_array_new_data returned a null handle".into()));
258 }
259
260 Ok(Self { raw })
261 }
262
263 pub fn matmul(&self, rhs: &Self, stream: &Stream) -> Result<Self> {
264 let mut out = MlxArrayRaw {
265 ctx: ptr::null_mut(),
266 };
267 unsafe {
268 check(
269 raw::mlx_matmul(&mut out, self.raw, rhs.raw, stream.raw),
270 "mlx_matmul",
271 )?;
272 }
273 Ok(Self { raw: out })
274 }
275
276 pub fn max_abs_error(&self, rhs: &Self, stream: &Stream) -> Result<f32> {
277 let mut delta = MlxArrayRaw {
278 ctx: ptr::null_mut(),
279 };
280 let mut magnitude = MlxArrayRaw {
281 ctx: ptr::null_mut(),
282 };
283 let mut max_value = MlxArrayRaw {
284 ctx: ptr::null_mut(),
285 };
286
287 unsafe {
288 check(
289 raw::mlx_subtract(&mut delta, self.raw, rhs.raw, stream.raw),
290 "mlx_subtract",
291 )?;
292 check(raw::mlx_abs(&mut magnitude, delta, stream.raw), "mlx_abs")?;
293 check(
294 raw::mlx_max(&mut max_value, magnitude, false, stream.raw),
295 "mlx_max",
296 )?;
297 check(raw::mlx_array_eval(max_value), "mlx_array_eval")?;
298 stream.synchronize()?;
299 let mut value = 0.0;
300 check(
301 raw::mlx_array_item_float32(&mut value, max_value),
302 "mlx_array_item_float32",
303 )?;
304 let _ = raw::mlx_array_free(delta);
305 let _ = raw::mlx_array_free(magnitude);
306 let _ = raw::mlx_array_free(max_value);
307 Ok(value)
308 }
309 }
310
311 pub fn shape(&self) -> Result<Vec<usize>> {
312 unsafe {
313 let ndim = raw::mlx_array_ndim(self.raw);
314 let shape_ptr = raw::mlx_array_shape(self.raw);
315 if shape_ptr.is_null() {
316 return Err(Error("mlx_array_shape returned a null pointer".into()));
317 }
318 Ok(slice::from_raw_parts(shape_ptr, ndim)
319 .iter()
320 .map(|dim| *dim as usize)
321 .collect())
322 }
323 }
324
325 pub fn to_complex_vec(&self, stream: &Stream) -> Result<Vec<Complex32>> {
326 unsafe {
327 if raw::mlx_array_dtype(self.raw) != MLX_DTYPE_COMPLEX64 {
328 return Err(Error("expected MLX complex64 output".into()));
329 }
330 check(raw::mlx_array_eval(self.raw), "mlx_array_eval")?;
331 stream.synchronize()?;
332 let count = raw::mlx_array_size(self.raw);
333 let ptr = raw::mlx_array_data_complex64(self.raw) as *const Complex32;
334 if ptr.is_null() {
335 return Err(Error("mlx_array_data_complex64 returned null".into()));
336 }
337 Ok(slice::from_raw_parts(ptr, count).to_vec())
338 }
339 }
340}
341
342impl Drop for Array {
343 fn drop(&mut self) {
344 unsafe {
345 let _ = raw::mlx_array_free(self.raw);
346 }
347 }
348}
349
350pub fn cpu_complex_matmul(
351 lhs: &[Complex32],
352 rhs: &[Complex32],
353 lhs_rows: usize,
354 lhs_cols: usize,
355 rhs_cols: usize,
356) -> Vec<Complex32> {
357 let mut out = vec![Complex32::new(0.0, 0.0); lhs_rows * rhs_cols];
358 for row in 0..lhs_rows {
359 for col in 0..rhs_cols {
360 let mut acc = Complex32::new(0.0, 0.0);
361 for k in 0..lhs_cols {
362 let a = lhs[row * lhs_cols + k];
363 let b = rhs[k * rhs_cols + col];
364 acc.re += a.re * b.re - a.im * b.im;
365 acc.im += a.re * b.im + a.im * b.re;
366 }
367 out[row * rhs_cols + col] = acc;
368 }
369 }
370 out
371}
372
373pub fn print_matrix(values: &[Complex32], rows: usize, cols: usize, label: &str) {
374 println!("{label}:");
375 for row in values.chunks(cols).take(rows) {
376 let rendered = row
377 .iter()
378 .map(ToString::to_string)
379 .collect::<Vec<_>>()
380 .join(" ");
381 println!(" {rendered}");
382 }
383}
384
385pub fn demo_complex_matmul() -> Result<()> {
386 let lhs = vec![
387 Complex32::new(1.0, 2.0),
388 Complex32::new(3.0, -1.0),
389 Complex32::new(-2.0, 0.5),
390 Complex32::new(0.0, 4.0),
391 ];
392 let rhs = vec![
393 Complex32::new(0.5, -1.0),
394 Complex32::new(2.0, 0.0),
395 Complex32::new(-3.0, 1.5),
396 Complex32::new(1.0, -2.0),
397 ];
398
399 let device = Device::preferred()?;
400 let stream = Stream::new(&device);
401 let lhs_array = Array::from_complex_matrix(2, 2, &lhs)?;
402 let rhs_array = Array::from_complex_matrix(2, 2, &rhs)?;
403 let product = lhs_array.matmul(&rhs_array, &stream)?;
404 let product_shape = product.shape()?;
405 let product_values = product.to_complex_vec(&stream)?;
406
407 let expected_values = cpu_complex_matmul(&lhs, &rhs, 2, 2, 2);
408 let expected = Array::from_complex_matrix(2, 2, &expected_values)?;
409 let max_abs_error = product.max_abs_error(&expected, &stream)?;
410
411 println!(
412 "Using Apple MLX on {} device {} ({})",
413 device.kind()?,
414 device.index()?,
415 device.name()?
416 );
417 println!("Output shape: {:?}", product_shape);
418 print_matrix(&lhs, 2, 2, "Left matrix");
419 print_matrix(&rhs, 2, 2, "Right matrix");
420 print_matrix(&product_values, 2, 2, "MLX product");
421 println!("Max absolute error vs CPU reference: {max_abs_error:.6}");
422
423 if max_abs_error > 1e-4 {
424 return Err(Error(format!(
425 "MLX result drifted from the CPU reference: {max_abs_error}"
426 )));
427 }
428
429 Ok(())
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    /// The CPU reference must reproduce a 2x2 product worked out by hand.
    #[test]
    fn cpu_reference_matches_known_values() {
        let c = Complex32::new;

        // Row-major 2x2 operands.
        let lhs = [c(1.0, 2.0), c(3.0, -1.0), c(-2.0, 0.5), c(0.0, 4.0)];
        let rhs = [c(0.5, -1.0), c(2.0, 0.0), c(-3.0, 1.5), c(1.0, -2.0)];

        let expected = vec![c(-5.0, 7.5), c(3.0, -3.0), c(-6.5, -9.75), c(4.0, 5.0)];
        let actual = cpu_complex_matmul(&lhs, &rhs, 2, 2, 2);

        assert_eq!(actual, expected);
    }
}
461}