// apple-mlx 0.1.1
//
// Rust bindings and safe wrappers for Apple MLX via the official mlx-c API.
// Documentation
#![allow(unsafe_op_in_unsafe_fn)]

#[path = "support/mod.rs"]
mod support;

use apple_mlx::raw;
use std::ffi::CStr;
use std::ptr;
use support::{Result, check, cstring, print_array, vector_string_get};

/// Prints the number of GPU devices and every key/value entry reported by
/// `mlx_device_info_get` for the default device.
///
/// String-valued entries are printed as text (or `(null)` when the returned
/// pointer is null); all other entries are read with
/// `mlx_device_info_get_size` and printed as integers.
///
/// If `mlx_device_info_get` itself reports a non-zero status the key listing
/// is skipped without raising an error (best-effort, as before).
///
/// # Errors
///
/// Returns the error produced by `check` / `vector_string_get` / `cstring`
/// for the first failing step. The device, device-info and key-vector handles
/// are freed on both the success and the error path.
fn gpu_info() -> Result<()> {
    // SAFETY: every pointer handed to the mlx-c functions below refers to a
    // live local (out-parameters) or to a `CString` that outlives the call.
    // NOTE(review): assumes the `raw` bindings follow the mlx-c contract that
    // `*_new()` handles may be passed to the matching `*_free()` even when
    // they were never filled in — confirm against the mlx-c headers.
    unsafe {
        println!("==================================================");
        println!("GPU info (using mlx_device_info API):");

        let mut gpu_count = 0;
        check(
            raw::mlx_device_count(&mut gpu_count, raw::mlx_device_type__MLX_GPU),
            "mlx_device_count",
        )?;
        println!("GPU device count: {gpu_count}");

        // Create all handles up front so that a single cleanup sequence
        // below covers every exit path of the fallible section.
        let mut dev = raw::mlx_device_new();
        let mut info = raw::mlx_device_info_new();
        let mut keys = raw::mlx_vector_string_new();

        // All fallible work lives in this closure; `?` inside it returns from
        // the closure only, so the frees after it always run.
        let mut dump_info = || -> Result<()> {
            check(
                raw::mlx_get_default_device(&mut dev),
                "mlx_get_default_device",
            )?;

            // A non-zero status here is tolerated: just skip the listing.
            if raw::mlx_device_info_get(&mut info, dev) != 0 {
                return Ok(());
            }

            check(
                raw::mlx_device_info_get_keys(&mut keys, info),
                "mlx_device_info_get_keys",
            )?;
            let num_keys = raw::mlx_vector_string_size(keys);
            println!("Device info ({num_keys} keys):");

            for i in 0..num_keys {
                let key = vector_string_get(keys, i)?;
                let key_c = cstring(&key)?;

                let mut is_string = false;
                check(
                    raw::mlx_device_info_is_string(&mut is_string, info, key_c.as_ptr()),
                    "mlx_device_info_is_string",
                )?;

                if is_string {
                    let mut value = ptr::null();
                    check(
                        raw::mlx_device_info_get_string(&mut value, info, key_c.as_ptr()),
                        "mlx_device_info_get_string",
                    )?;
                    let rendered = if value.is_null() {
                        "(null)".to_string()
                    } else {
                        // Copied into an owned `String` immediately.
                        // NOTE(review): presumably `value` is a NUL-terminated
                        // string that stays valid while `info` is alive — confirm.
                        CStr::from_ptr(value).to_string_lossy().into_owned()
                    };
                    println!("  {key}: {rendered}");
                } else {
                    let mut value = 0usize;
                    check(
                        raw::mlx_device_info_get_size(&mut value, info, key_c.as_ptr()),
                        "mlx_device_info_get_size",
                    )?;
                    println!("  {key}: {value}");
                }
            }
            Ok(())
        };
        let outcome = dump_info();

        // Unconditional cleanup; free statuses are deliberately ignored.
        let _ = raw::mlx_vector_string_free(keys);
        let _ = raw::mlx_device_info_free(info);
        let _ = raw::mlx_device_free(dev);

        outcome?;
        println!("==================================================");
        Ok(())
    }
}

/// Example entry point.
///
/// Prints the MLX version and the GPU info, then builds a 2x3 `f32` array
/// from a stack buffer, divides it by two, overwrites it with
/// `arange(0.0, 3.0, 0.5)`, and finally prints a second 2x3 array that is
/// backed by a heap buffer handed over to MLX together with a destructor
/// callback.
///
/// # Errors
///
/// Returns the first error reported by `check`, `print_array` or
/// `gpu_info`. Every MLX handle created here is freed on both the success
/// and the error path.
fn main() -> Result<()> {
    // SAFETY: all pointers passed to mlx-c point at live locals (`data`,
    // `shape`, out-parameters) or at the heap buffer whose ownership is
    // transferred to MLX below.
    unsafe {
        let mut version = raw::mlx_string_new();
        let version_status = check(raw::mlx_version(&mut version), "mlx_version");
        if version_status.is_ok() {
            println!("MLX version: {}", support::mlx_string_to_rust(version));
        }
        // Free the string before propagating a possible failure.
        let _ = raw::mlx_string_free(version);
        version_status?;

        gpu_info()?;

        let stream = raw::mlx_default_cpu_stream_new();
        let data = [1f32, 2., 3., 4., 5., 6.];
        let shape = [2, 3];
        let mut arr = raw::mlx_array_new_data(
            data.as_ptr().cast(),
            shape.as_ptr(),
            shape.len() as i32, // rank is 2; cannot truncate
            raw::mlx_dtype__MLX_FLOAT32,
        );
        let two = raw::mlx_array_new_int(2);

        // Hand the boxed buffer to MLX. The data pointer is derived from the
        // raw pointer returned by `Box::into_raw`, so it is not a pointer
        // obtained through a borrow that predates the move of the `Box`.
        // NOTE(review): `support::drop_boxed_f32x6` is expected to rebuild and
        // drop the `Box<[f32; 6]>` from `payload` — confirm in support/mod.rs.
        let payload = Box::into_raw(Box::new([0f32, 1., 2., 3., 4., 5.]));
        let data_ptr = payload.cast::<f32>();
        let arr_managed = raw::mlx_array_new_data_managed_payload(
            data_ptr.cast(),
            shape.as_ptr(),
            shape.len() as i32,
            raw::mlx_dtype__MLX_FLOAT32,
            payload.cast(),
            Some(support::drop_boxed_f32x6),
        );

        // Fallible steps are grouped in a closure so that `?` cannot skip
        // the cleanup sequence that follows it.
        let mut run_demo = || -> Result<()> {
            print_array("hello world!", arr)?;

            check(raw::mlx_divide(&mut arr, arr, two, stream), "mlx_divide")?;
            print_array("divide by 2!", arr)?;

            check(
                raw::mlx_arange(&mut arr, 0.0, 3.0, 0.5, raw::mlx_dtype__MLX_FLOAT32, stream),
                "mlx_arange",
            )?;
            print_array("arange", arr)?;

            print_array("from user buffer", arr_managed)
        };
        let outcome = run_demo();

        // Unconditional cleanup; free statuses are deliberately ignored.
        let _ = raw::mlx_array_free(arr);
        let _ = raw::mlx_array_free(two);
        let _ = raw::mlx_array_free(arr_managed);
        let _ = raw::mlx_stream_free(stream);

        outcome
    }
}