// apple-mlx 0.1.1
//
// Rust bindings and safe wrappers for Apple MLX via the official mlx-c API.
// See the crate documentation for details.
#![allow(unsafe_op_in_unsafe_fn)]

#[path = "support/mod.rs"]
mod support;

use apple_mlx::raw;
use support::{Result, check, print_array, vector_array_get};

/// Collects cleanup actions and runs them in reverse registration order when
/// dropped.
///
/// `main` registers each mlx-c release call as soon as a handle takes its
/// final value. Every early `?` return therefore still frees the handles
/// acquired so far, instead of leaking them.
struct Cleanup(Vec<Box<dyn FnOnce()>>);

impl Cleanup {
    /// Schedules `action` to run when this guard is dropped.
    fn defer(&mut self, action: impl FnOnce() + 'static) {
        self.0.push(Box::new(action));
    }
}

impl Drop for Cleanup {
    fn drop(&mut self) {
        // LIFO: release handles in the reverse order they were acquired.
        while let Some(action) = self.0.pop() {
            action();
        }
    }
}

/// Runs a custom Metal kernel on a random 4x16 `float32` array on the default
/// GPU stream, then prints the input and output arrays.
///
/// The kernel applies `metal::exp` element-wise.
///
/// # Errors
///
/// Returns the error reported by `check` or the `support` helpers for the
/// first failing call. All mlx-c handles created before the failure are
/// released on the way out.
fn main() -> Result<()> {
    // SAFETY: every handle passed to an mlx-c function below was returned by
    // the matching mlx-c constructor and is used only while live. Each one is
    // released exactly once by `cleanup`. C-string pointers come from `c"..."`
    // literals or from `source`, and all of them outlive the calls that read
    // them.
    unsafe {
        // Declared first so it is dropped last, after every use of the handles.
        let mut cleanup = Cleanup(Vec::new());

        let stream = raw::mlx_default_gpu_stream_new();
        cleanup.defer(move || {
            let _ = raw::mlx_stream_free(stream);
        });

        // A null `ctx` presumably marks the optional RNG key as "not supplied",
        // so mlx-c uses its default key. NOTE(review): confirm against mlx-c docs.
        let empty_key = raw::mlx_array {
            ctx: std::ptr::null_mut(),
        };
        let dims = [4, 16];
        let mut input = raw::mlx_array_new();
        let status = raw::mlx_random_normal(
            &mut input,
            dims.as_ptr(),
            dims.len(),
            raw::mlx_dtype__MLX_FLOAT32,
            0.0,
            1.0,
            empty_key,
            stream,
        );
        // Register the free after the call, so it captures whatever handle the
        // call left in `input`. Register it before `check`, so a failure still
        // releases that handle.
        cleanup.defer(move || {
            let _ = raw::mlx_array_free(input);
        });
        check(status, "mlx_random_normal")?;

        // Kernel body: one thread per element. The trailing `\` continuations
        // join the statements without newlines, which is fine because each one
        // ends in `;`.
        let source = support::cstring(
            "uint elem = thread_position_in_grid.x;\
             T tmp = inp[elem];\
             out[elem] = metal::exp(tmp);",
        )?;
        let input_names = raw::mlx_vector_string_new_value(c"inp".as_ptr());
        cleanup.defer(move || {
            let _ = raw::mlx_vector_string_free(input_names);
        });
        let output_names = raw::mlx_vector_string_new_value(c"out".as_ptr());
        cleanup.defer(move || {
            let _ = raw::mlx_vector_string_free(output_names);
        });
        // Empty header string. The two trailing flags are presumably
        // `ensure_row_contiguous` and `atomic_outputs` (verify against mlx-c).
        let kernel = raw::mlx_fast_metal_kernel_new(
            c"myexp".as_ptr(),
            input_names,
            output_names,
            source.as_ptr(),
            c"".as_ptr(),
            true,
            false,
        );
        cleanup.defer(move || {
            let _ = raw::mlx_fast_metal_kernel_free(kernel);
        });

        let config = raw::mlx_fast_metal_kernel_config_new();
        cleanup.defer(move || {
            let _ = raw::mlx_fast_metal_kernel_config_free(config);
        });
        let inputs = raw::mlx_vector_array_new_value(input);
        cleanup.defer(move || {
            let _ = raw::mlx_vector_array_free(inputs);
        });
        check(
            raw::mlx_fast_metal_kernel_config_add_template_arg_dtype(
                config,
                c"T".as_ptr(),
                raw::mlx_dtype__MLX_FLOAT32,
            ),
            "mlx_fast_metal_kernel_config_add_template_arg_dtype",
        )?;
        // One thread per element. The grid dimension is an `i32` in the C API,
        // so convert with a check instead of a silently truncating `as` cast.
        let grid_x = i32::try_from(raw::mlx_array_size(input))
            .expect("input element count must fit in i32 for the Metal grid");
        check(
            raw::mlx_fast_metal_kernel_config_set_grid(config, grid_x, 1, 1),
            "mlx_fast_metal_kernel_config_set_grid",
        )?;
        check(
            raw::mlx_fast_metal_kernel_config_set_thread_group(config, 256, 1, 1),
            "mlx_fast_metal_kernel_config_set_thread_group",
        )?;
        // The single output mirrors the input's shape and dtype.
        check(
            raw::mlx_fast_metal_kernel_config_add_output_arg(
                config,
                raw::mlx_array_shape(input),
                raw::mlx_array_ndim(input),
                raw::mlx_array_dtype(input),
            ),
            "mlx_fast_metal_kernel_config_add_output_arg",
        )?;

        let mut outputs = raw::mlx_vector_array_new();
        let status =
            raw::mlx_fast_metal_kernel_apply(&mut outputs, kernel, inputs, config, stream);
        // Same pattern as `input`: capture the post-call handle and release it
        // even if the call failed.
        cleanup.defer(move || {
            let _ = raw::mlx_vector_array_free(outputs);
        });
        check(status, "mlx_fast_metal_kernel_apply")?;
        let output = vector_array_get(outputs, 0)?;
        cleanup.defer(move || {
            let _ = raw::mlx_array_free(output);
        });

        print_array("input", input)?;
        print_array("output", output)?;
    }
    Ok(())
}