use std::ffi::CStr;
use std::ptr;
use rlx_ir::DType;
use crate::ffi::{self, MlxDtype, RLX_MLX_OK, mlx_array_t};
/// Owning wrapper around an array handle obtained from the `rlx_mlx_*` C shim.
///
/// The handle is released with `rlx_mlx_array_free` when the wrapper is dropped.
pub struct Array {
/// Raw shim handle owned by this wrapper; `Drop` skips freeing it when it is null.
pub(crate) ptr: *mut mlx_array_t,
}
impl Array {
pub fn from_f32_slice(data: &[f32], shape: &[usize], dtype: DType) -> Result<Self, MlxError> {
let shape_i: Vec<i32> = shape.iter().map(|&d| d as i32).collect();
let mut out: *mut mlx_array_t = ptr::null_mut();
let rc = unsafe {
ffi::rlx_mlx_array_from_data(
shape_i.as_ptr(),
shape_i.len(),
data.as_ptr(),
data.len(),
map_dtype(dtype),
&mut out,
)
};
check(rc)?;
Ok(Self { ptr: out })
}
pub fn to_f32(&self) -> Result<Vec<f32>, MlxError> {
let nelems = self.num_elements()?;
let mut buf = vec![0f32; nelems];
let rc = unsafe { ffi::rlx_mlx_array_to_f32(self.ptr, buf.as_mut_ptr(), nelems) };
check(rc)?;
Ok(buf)
}
pub fn from_bytes(data: &[u8], shape: &[usize], dtype: DType) -> Result<Self, MlxError> {
let shape_i: Vec<i32> = shape.iter().map(|&d| d as i32).collect();
let mut out: *mut mlx_array_t = std::ptr::null_mut();
let rc = unsafe {
ffi::rlx_mlx_array_from_bytes(
shape_i.as_ptr(),
shape_i.len(),
data.as_ptr() as *const std::ffi::c_void,
data.len(),
map_dtype(dtype),
&mut out,
)
};
check(rc)?;
Ok(Self { ptr: out })
}
pub fn to_bytes(&self) -> Result<Vec<u8>, MlxError> {
let nelems = self.num_elements()?;
let mut buf = vec![0u8; nelems * 8];
let mut written = 0usize;
let rc = unsafe {
ffi::rlx_mlx_array_to_bytes(
self.ptr,
buf.as_mut_ptr() as *mut std::ffi::c_void,
buf.len(),
&mut written,
)
};
check(rc)?;
buf.truncate(written);
Ok(buf)
}
pub fn shape(&self) -> Result<Vec<usize>, MlxError> {
let mut tmp = [0i32; 8];
let mut ndim = 0usize;
let rc =
unsafe { ffi::rlx_mlx_array_shape(self.ptr, tmp.as_mut_ptr(), tmp.len(), &mut ndim) };
check(rc)?;
Ok(tmp[..ndim].iter().map(|&d| d as usize).collect())
}
pub fn num_elements(&self) -> Result<usize, MlxError> {
Ok(self.shape()?.iter().product())
}
pub(crate) fn from_raw(ptr: *mut mlx_array_t) -> Self {
Self { ptr }
}
pub fn clone_handle(&self) -> Result<Self, MlxError> {
let mut out: *mut mlx_array_t = std::ptr::null_mut();
let rc = unsafe { ffi::rlx_mlx_array_clone(self.ptr, &mut out) };
check(rc)?;
Ok(Self { ptr: out })
}
}
impl Drop for Array {
fn drop(&mut self) {
if !self.ptr.is_null() {
unsafe { ffi::rlx_mlx_array_free(self.ptr) };
self.ptr = ptr::null_mut();
}
}
}
// SAFETY: NOTE(review): `Array` holds only an opaque shim handle. This impl asserts that the
// handle may be moved to, used on, and freed from another thread. That depends on the MLX
// shim's threading guarantees, which are not visible here — confirm. Only `Send` is
// asserted; there is no `Sync` impl alongside it.
unsafe impl Send for Array {}
/// Error type of the MLX wrapper, carrying a human-readable message.
///
/// For shim failures, the message comes from `rlx_mlx_last_error` (see `check`).
#[derive(Debug, Clone)]
pub struct MlxError(pub String);
impl std::fmt::Display for MlxError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "mlx error: {}", self.0)
}
}
// Marker impl. `MlxError` already provides the required `Debug` and `Display` supertraits, so
// it can be boxed as `dyn std::error::Error` or propagated with `?` into such error types.
impl std::error::Error for MlxError {}
/// Converts a shim status code into a `Result`.
///
/// `RLX_MLX_OK` maps to `Ok(())`. Any other code becomes an [`MlxError`] holding the text
/// returned by `rlx_mlx_last_error`, or `"(no message)"` when that pointer is null.
pub(crate) fn check(rc: std::ffi::c_int) -> Result<(), MlxError> {
    if rc != RLX_MLX_OK {
        // SAFETY: argument-less FFI query.
        let raw = unsafe { ffi::rlx_mlx_last_error() };
        let message = if raw.is_null() {
            String::from("(no message)")
        } else {
            // SAFETY: NOTE(review): assumes a non-null pointer from the shim is a valid
            // NUL-terminated string that stays alive until it is copied here — confirm.
            let text = unsafe { CStr::from_ptr(raw) };
            text.to_string_lossy().into_owned()
        };
        return Err(MlxError(message));
    }
    Ok(())
}
/// Translates an IR [`DType`] into the shim's `MlxDtype` enum.
///
/// # Panics
///
/// Panics for `DType::C64`, because complex dtypes are not supported by this backend.
pub(crate) fn map_dtype(d: DType) -> MlxDtype {
    match d {
        // Floating point.
        DType::F16 => MlxDtype::F16,
        DType::BF16 => MlxDtype::BF16,
        DType::F32 => MlxDtype::F32,
        DType::F64 => MlxDtype::F64,
        // Signed integers.
        DType::I8 => MlxDtype::I8,
        DType::I16 => MlxDtype::I16,
        DType::I32 => MlxDtype::I32,
        DType::I64 => MlxDtype::I64,
        // Unsigned integers and booleans.
        DType::U8 => MlxDtype::U8,
        DType::U32 => MlxDtype::U32,
        DType::Bool => MlxDtype::Bool,
        // Complex: no shim counterpart.
        DType::C64 => panic!("rlx-mlx: DType::C64 (complex) not supported"),
    }
}
/// Returns the string reported by `rlx_mlx_version`, or an empty string when it returns null.
pub fn version() -> String {
    // SAFETY: argument-less FFI query.
    let raw = unsafe { ffi::rlx_mlx_version() };
    if raw.is_null() {
        return String::new();
    }
    // SAFETY: NOTE(review): assumes a non-null pointer is a valid NUL-terminated string that
    // outlives this copy — confirm against the shim.
    let text = unsafe { CStr::from_ptr(raw) };
    text.to_string_lossy().into_owned()
}
/// Returns the string reported by `rlx_mlx_device_name`, or an empty string when it returns
/// null.
pub fn device_name() -> String {
    // SAFETY: argument-less FFI query.
    let raw = unsafe { ffi::rlx_mlx_device_name() };
    if raw.is_null() {
        String::new()
    } else {
        // SAFETY: NOTE(review): assumes a non-null pointer is a valid NUL-terminated string
        // that outlives this copy — confirm against the shim.
        let text = unsafe { CStr::from_ptr(raw) };
        text.to_string_lossy().into_owned()
    }
}
/// Evaluates `arrays` through `rlx_mlx_eval`.
///
/// An empty slice is a no-op and never reaches the shim.
///
/// # Errors
///
/// Returns [`MlxError`] if the shim reports a failure.
pub fn eval(arrays: &[&Array]) -> Result<(), MlxError> {
    match arrays {
        [] => Ok(()),
        _ => {
            let raw: Vec<*mut mlx_array_t> = arrays.iter().map(|arr| arr.ptr).collect();
            // SAFETY: `raw` holds `raw.len()` handles borrowed from the `Array`s in `arrays`,
            // which stay alive for the duration of the call.
            let status = unsafe { ffi::rlx_mlx_eval(raw.as_ptr(), raw.len()) };
            check(status)
        }
    }
}
/// Hands `arrays` to `rlx_mlx_async_eval`.
///
/// An empty slice is a no-op and never reaches the shim. NOTE(review): presumably this
/// returns before evaluation completes, so pair it with `synchronize` — confirm.
///
/// # Errors
///
/// Returns [`MlxError`] if the shim reports a failure.
pub fn async_eval(arrays: &[&Array]) -> Result<(), MlxError> {
    match arrays {
        [] => Ok(()),
        _ => {
            let raw: Vec<*mut mlx_array_t> = arrays.iter().map(|arr| arr.ptr).collect();
            // SAFETY: `raw` holds `raw.len()` handles borrowed from the `Array`s in `arrays`,
            // which stay alive for the duration of the call.
            let status = unsafe { ffi::rlx_mlx_async_eval(raw.as_ptr(), raw.len()) };
            check(status)
        }
    }
}
/// Calls the shim's `rlx_mlx_synchronize` and converts its status code with `check`.
///
/// NOTE(review): presumably this blocks until queued MLX work has finished — confirm.
pub fn synchronize() -> Result<(), MlxError> {
    // SAFETY: argument-less FFI call.
    check(unsafe { ffi::rlx_mlx_synchronize() })
}