apple-mlx 0.1.1

Rust bindings and safe wrappers for Apple MLX via the official mlx-c API
Documentation
#![allow(unsafe_op_in_unsafe_fn)]

#[path = "support/mod.rs"]
mod support;

use apple_mlx::raw;
use support::{Result, check};

// C stdio's standard-output stream, declared directly as an extern static so it can be
// handed to MLX C functions that take a `FILE*`.
//
// The symbol name is selected per target: `__stdoutp` when building for macOS and
// `stdout` everywhere else.
// NOTE(review): the non-macOS branch assumes the platform libc exports a data symbol
// literally named `stdout` — confirm for any target other than Linux.
unsafe extern "C" {
    // macOS spelling of the stream symbol.
    #[cfg(target_os = "macos")]
    static mut __stdoutp: *mut raw::FILE;
    // Spelling used on every other target.
    #[cfg(not(target_os = "macos"))]
    static mut stdout: *mut raw::FILE;
}

/// Returns the C standard-output stream as a raw `FILE*`.
///
/// The value is read from whichever extern static is declared for the current target
/// (`__stdoutp` on macOS, `stdout` otherwise).
///
/// # Safety
///
/// This reads a `static mut` owned by the C runtime. The caller must make sure no other
/// code is writing to that static at the same time.
unsafe fn stdout_file() -> *mut raw::FILE {
    // Exactly one of the two bindings below survives conditional compilation.
    #[cfg(target_os = "macos")]
    let stream_ptr = __stdoutp;
    #[cfg(not(target_os = "macos"))]
    let stream_ptr = stdout;

    stream_ptr
}

/// Builds a tiny MLX graph, `log(inputs / 2)`, names its nodes and writes the graph to the
/// C standard-output stream in DOT format.
///
/// # Errors
///
/// Returns the error produced by `check` for the first MLX call that reports a failure.
/// Every MLX handle created here is released before returning, on success and on error.
fn main() -> Result<()> {
    // SAFETY: `data` and `shape` are locals that stay alive until the end of this block,
    // i.e. past every MLX call that receives a pointer into them. Each handle is used only
    // between its creation and the matching `*_free` call at the bottom of the block.
    unsafe {
        // Create every handle up front so the cleanup section below can release all of
        // them unconditionally, no matter which step failed.
        let stream = raw::mlx_default_cpu_stream_new();
        let data = [1f32, 2., 3., 4., 5., 6.];
        let shape = [2, 3];
        let mut res = raw::mlx_array_new();
        let val = raw::mlx_array_new_data(
            data.as_ptr().cast(),
            shape.as_ptr(),
            shape.len() as i32,
            raw::mlx_dtype__MLX_FLOAT32,
        );
        let two = raw::mlx_array_new_int(2);
        let namer = raw::mlx_node_namer_new();
        let vec = raw::mlx_vector_array_new();

        // The fallible steps run inside an immediately-invoked closure so that `?` only
        // leaves the closure; control always reaches the frees afterwards.
        let outcome = (|| -> Result<()> {
            check(raw::mlx_divide(&mut res, val, two, stream), "mlx_divide")?;
            check(raw::mlx_log(&mut res, res, stream), "mlx_log")?;

            check(
                raw::mlx_node_namer_set_name(namer, val, c"inputs".as_ptr()),
                "mlx_node_namer_set_name inputs",
            )?;
            check(
                raw::mlx_node_namer_set_name(namer, res, c"result".as_ptr()),
                "mlx_node_namer_set_name result",
            )?;

            // A failed append would silently produce an incomplete graph, so these status
            // codes are checked like every other call.
            for node in [val, two, res] {
                check(
                    raw::mlx_vector_array_append_value(vec, node),
                    "mlx_vector_array_append_value",
                )?;
            }

            check(
                raw::mlx_export_to_dot(stdout_file(), namer, vec),
                "mlx_export_to_dot",
            )?;
            Ok(())
        })();

        // Cleanup is best effort: a failing free has no useful recovery, so the status
        // codes are deliberately ignored.
        let _ = raw::mlx_array_free(val);
        let _ = raw::mlx_array_free(two);
        let _ = raw::mlx_array_free(res);
        let _ = raw::mlx_vector_array_free(vec);
        let _ = raw::mlx_node_namer_free(namer);
        let _ = raw::mlx_stream_free(stream);

        outcome
    }
}