use std::{
ffi::c_void,
os::raw::c_int,
panic::{AssertUnwindSafe, catch_unwind},
ptr,
};
use crate::{
Array,
error::{Error, ParsePayload, Result, ensure_handler_installed},
};
/// Boxed, type-erased body of a [`Closure`]: takes a slice of inputs and returns its outputs.
///
/// `Closure::new` wraps this in a second `Box` so that the value handed to mlx as the opaque
/// payload is a thin pointer (a `Box<dyn Fn ..>` itself is a fat pointer).
pub(crate) type BoxedFn = Box<dyn Fn(&[Array]) -> Result<Vec<Array>> + 'static>;

/// Owned `mlx_closure` handle whose body is a Rust closure.
///
/// Construction registers [`trampoline`] (to invoke the closure) and [`destroy_payload`]
/// (to release it) with mlx. Dropping a `Closure` calls `mlx_closure_free` on the handle.
pub struct Closure {
    inner: mlxrs_sys::mlx_closure,
}

impl Closure {
    /// Wraps `f` in a new `mlx_closure`.
    ///
    /// # Errors
    ///
    /// If mlx returns a null closure handle, this returns the error recorded by the crate's
    /// mlx error handler. If none was recorded, it returns `Error::FfiNullHandle` naming
    /// `mlx_closure_new_func_payload`.
    pub fn new<F>(f: F) -> Result<Self>
    where
        F: Fn(&[Array]) -> Result<Vec<Array>> + 'static,
    {
        ensure_handler_installed();

        // Erase the concrete closure type, then box again so the payload is a thin pointer.
        let erased: BoxedFn = Box::new(f);
        let payload_ptr: *mut c_void = Box::into_raw(Box::new(erased)).cast();

        // SAFETY: `payload_ptr` comes from `Box::into_raw` on a `Box<BoxedFn>`, which is the
        // type `trampoline` and `destroy_payload` cast it back to.
        // NOTE(review): the payload is not reclaimed on the Rust side if the handle comes
        // back null below. Presumably mlx runs `destroy_payload` itself on failure; confirm
        // against mlx-c.
        let inner = unsafe { call_closure_new_ffi(payload_ptr) };

        if inner.ctx.is_null() {
            return Err(crate::error::take_last().unwrap_or(Error::FfiNullHandle(
                crate::error::FfiNullHandlePayload::new("mlx_closure_new_func_payload"),
            )));
        }
        Ok(Self { inner })
    }

    /// Returns a copy of the raw handle.
    ///
    /// The handle stays owned by `self` (it is freed when `self` is dropped), so the caller
    /// must not free it.
    #[inline(always)]
    pub fn as_raw(&self) -> mlxrs_sys::mlx_closure {
        self.inner
    }
}

impl Drop for Closure {
    fn drop(&mut self) {
        // SAFETY: `inner` is only stored after the constructor returned a non-null handle,
        // and it is released here.
        // The status code is ignored because `drop` cannot report it.
        let _ = unsafe { mlxrs_sys::mlx_closure_free(self.inner) };
    }
}
/// Registers [`trampoline`] and [`destroy_payload`] with mlx for `payload_ptr`.
///
/// Normal builds call `mlx_closure_new_func_payload` directly. Test builds call whatever
/// constructor `test_seam::closure_new_fn` currently returns, so tests can substitute a stub.
///
/// # Safety
///
/// `payload_ptr` must come from `Box::into_raw` on a `Box<BoxedFn>`, because the registered
/// trampoline and destructor cast it back to that type.
#[inline]
unsafe fn call_closure_new_ffi(payload_ptr: *mut c_void) -> mlxrs_sys::mlx_closure {
    #[cfg(not(test))]
    let construct = mlxrs_sys::mlx_closure_new_func_payload;
    #[cfg(test)]
    let construct = test_seam::closure_new_fn();

    // SAFETY: forwarded from this function's contract on `payload_ptr`.
    unsafe { construct(Some(trampoline), payload_ptr, Some(destroy_payload)) }
}
/// Registers [`trampoline_custom`] and [`destroy_payload_3`] with mlx for `payload_ptr`.
///
/// Normal builds call `mlx_closure_custom_new_func_payload` directly. Test builds call
/// whatever constructor `test_seam::closure_custom_new_fn` currently returns.
///
/// # Safety
///
/// `payload_ptr` must come from `Box::into_raw` on a `Box<BoxedFn3>`, because the registered
/// trampoline and destructor cast it back to that type.
#[inline]
unsafe fn call_closure_custom_new_ffi(payload_ptr: *mut c_void) -> mlxrs_sys::mlx_closure_custom {
    #[cfg(not(test))]
    let construct = mlxrs_sys::mlx_closure_custom_new_func_payload;
    #[cfg(test)]
    let construct = test_seam::closure_custom_new_fn();

    // SAFETY: forwarded from this function's contract on `payload_ptr`.
    unsafe {
        construct(
            Some(trampoline_custom),
            payload_ptr,
            Some(destroy_payload_3),
        )
    }
}
extern "C" fn trampoline(
outputs_out: *mut mlxrs_sys::mlx_vector_array,
inputs: mlxrs_sys::mlx_vector_array,
payload: *mut c_void,
) -> c_int {
let result = catch_unwind(AssertUnwindSafe(|| {
let f: &BoxedFn = unsafe { &*payload.cast::<BoxedFn>() };
let inputs_vec = borrow_inputs(inputs)?;
let outputs = f(&inputs_vec)?;
write_outputs(outputs_out, &outputs)?;
Ok::<(), Error>(())
}));
match result {
Ok(Ok(())) => 0,
Ok(Err(e)) => {
crate::error::set_last(e);
unsafe {
if !outputs_out.is_null() {
*outputs_out = mlxrs_sys::mlx_vector_array_new();
}
}
1
}
Err(panic_payload) => {
let msg = if let Some(s) = panic_payload.downcast_ref::<&'static str>() {
(*s).to_string()
} else if let Some(s) = panic_payload.downcast_ref::<String>() {
s.clone()
} else {
"panic in mlxrs::transforms closure trampoline".to_string()
};
crate::error::set_last(Error::Parse(ParsePayload::new(
"transforms::closure trampoline: caught panic",
"Rust closure panic payload",
std::io::Error::other(msg),
)));
unsafe {
if !outputs_out.is_null() {
*outputs_out = mlxrs_sys::mlx_vector_array_new();
}
}
1
}
}
}
extern "C" fn destroy_payload(payload: *mut c_void) {
if payload.is_null() {
return;
}
let _ = catch_unwind(AssertUnwindSafe(|| {
let _: Box<BoxedFn> = unsafe { Box::from_raw(payload.cast::<BoxedFn>()) };
}));
}
pub(crate) use crate::ffi::{VectorArrayGuard, drain_vector};
/// Turns an input vector handed to a trampoline into a `Vec<Array>` via `drain_vector`.
///
/// NOTE(review): the trampolines never free `vec` themselves. This presumably relies on
/// `drain_vector` leaving the vector handle alone and only taking references to its
/// elements. Confirm in `crate::ffi`.
fn borrow_inputs(vec: mlxrs_sys::mlx_vector_array) -> Result<Vec<Array>> {
    let arrays = drain_vector(vec)?;
    Ok(arrays)
}
fn write_outputs(out: *mut mlxrs_sys::mlx_vector_array, outputs: &[Array]) -> Result<()> {
let raw: Vec<mlxrs_sys::mlx_array> = outputs.iter().map(|a| a.0).collect();
let data_ptr = if raw.is_empty() {
ptr::null()
} else {
raw.as_ptr()
};
unsafe {
*out = mlxrs_sys::mlx_vector_array_new_data(data_ptr, raw.len());
}
if unsafe { (*out).ctx.is_null() } && !outputs.is_empty() {
return Err(crate::error::take_last().unwrap_or(Error::FfiNullHandle(
crate::error::FfiNullHandlePayload::new("mlx_vector_array_new_data"),
)));
}
Ok(())
}
pub(crate) fn vector_array_from_borrow(arrays: &[&Array]) -> Result<VectorArrayGuard> {
ensure_handler_installed();
let raw: Vec<mlxrs_sys::mlx_array> = arrays.iter().map(|a| a.0).collect();
let data_ptr = if raw.is_empty() {
ptr::null()
} else {
raw.as_ptr()
};
let vec = unsafe { mlxrs_sys::mlx_vector_array_new_data(data_ptr, raw.len()) };
if vec.ctx.is_null() {
return Err(crate::error::take_last().unwrap_or(Error::FfiNullHandle(
crate::error::FfiNullHandlePayload::new("mlx_vector_array_new_data"),
)));
}
Ok(VectorArrayGuard(vec))
}
pub(crate) fn vector_array_from_slice(arrays: &[Array]) -> Result<VectorArrayGuard> {
ensure_handler_installed();
let raw: Vec<mlxrs_sys::mlx_array> = arrays.iter().map(|a| a.0).collect();
let data_ptr = if raw.is_empty() {
ptr::null()
} else {
raw.as_ptr()
};
let vec = unsafe { mlxrs_sys::mlx_vector_array_new_data(data_ptr, raw.len()) };
if vec.ctx.is_null() {
return Err(crate::error::take_last().unwrap_or(Error::FfiNullHandle(
crate::error::FfiNullHandlePayload::new("mlx_vector_array_new_data"),
)));
}
Ok(VectorArrayGuard(vec))
}
/// Owning wrapper that frees an `mlx_closure_value_and_grad` handle when dropped.
pub(crate) struct ClosureValueAndGradGuard(pub(crate) mlxrs_sys::mlx_closure_value_and_grad);

impl ClosureValueAndGradGuard {
    /// Returns a copy of the raw handle; the guard keeps ownership, so the caller must not free it.
    // `dead_code` is allowed because this accessor may have no callers.
    // NOTE(review): confirm it is still needed.
    #[allow(dead_code)]
    #[inline(always)]
    pub(crate) const fn as_raw(&self) -> mlxrs_sys::mlx_closure_value_and_grad {
        self.0
    }
}

impl Drop for ClosureValueAndGradGuard {
    fn drop(&mut self) {
        // SAFETY: the guard owns `self.0`, and this is where it is released.
        // The status code is ignored because `drop` cannot report it.
        let _ = unsafe { mlxrs_sys::mlx_closure_value_and_grad_free(self.0) };
    }
}
/// Owning wrapper that frees an `mlx_closure_custom` handle when dropped.
pub(crate) struct ClosureCustomGuard(pub(crate) mlxrs_sys::mlx_closure_custom);

impl ClosureCustomGuard {
    /// Returns a copy of the raw handle; the guard keeps ownership, so the caller must not free it.
    // `dead_code` is allowed because this accessor may have no callers.
    // NOTE(review): confirm it is still needed.
    #[allow(dead_code)]
    #[inline(always)]
    pub(crate) const fn as_raw(&self) -> mlxrs_sys::mlx_closure_custom {
        self.0
    }
}

impl Drop for ClosureCustomGuard {
    fn drop(&mut self) {
        // SAFETY: the guard owns `self.0`, and this is where it is released.
        // The status code is ignored because `drop` cannot report it.
        let _ = unsafe { mlxrs_sys::mlx_closure_custom_free(self.0) };
    }
}
/// Owning wrapper that frees a plain `mlx_closure` handle when dropped.
pub(crate) struct RawClosureGuard(pub(crate) mlxrs_sys::mlx_closure);

impl RawClosureGuard {
    /// Returns a copy of the raw handle; the guard keeps ownership, so the caller must not free it.
    // `dead_code` is allowed because this accessor may have no callers.
    // NOTE(review): confirm it is still needed.
    #[allow(dead_code)]
    #[inline(always)]
    pub(crate) const fn as_raw(&self) -> mlxrs_sys::mlx_closure {
        self.0
    }
}

impl Drop for RawClosureGuard {
    fn drop(&mut self) {
        // SAFETY: the guard owns `self.0`, and this is where it is released.
        // The status code is ignored because `drop` cannot report it.
        let _ = unsafe { mlxrs_sys::mlx_closure_free(self.0) };
    }
}
/// Boxed, type-erased body of a custom-VJP closure.
///
/// It is called as `f(primals, cotangents, outputs)` by [`trampoline_custom`] and returns
/// the gradients.
pub(crate) type BoxedFn3 =
    Box<dyn Fn(&[Array], &[Array], &[Array]) -> Result<Vec<Array>> + 'static>;

/// Wraps `f` in a new `mlx_closure_custom`, returned inside a [`ClosureCustomGuard`].
///
/// Registers [`trampoline_custom`] (to invoke `f`) and [`destroy_payload_3`] (to release it).
///
/// # Errors
///
/// If mlx returns a null handle, this returns the error recorded by the crate's mlx error
/// handler. If none was recorded, it returns `Error::FfiNullHandle` naming
/// `mlx_closure_custom_new_func_payload`.
pub(crate) fn closure_custom_new<F>(f: F) -> Result<ClosureCustomGuard>
where
    F: Fn(&[Array], &[Array], &[Array]) -> Result<Vec<Array>> + 'static,
{
    ensure_handler_installed();

    // Erase the concrete closure type, then box again so the payload is a thin pointer.
    let erased: BoxedFn3 = Box::new(f);
    let payload_ptr: *mut c_void = Box::into_raw(Box::new(erased)).cast();

    // SAFETY: `payload_ptr` comes from `Box::into_raw` on a `Box<BoxedFn3>`, which is the
    // type `trampoline_custom` and `destroy_payload_3` cast it back to.
    // NOTE(review): the payload is not reclaimed here if the handle comes back null.
    // Presumably mlx runs the destructor itself on failure; confirm against mlx-c.
    let handle = unsafe { call_closure_custom_new_ffi(payload_ptr) };

    if handle.ctx.is_null() {
        return Err(crate::error::take_last().unwrap_or(Error::FfiNullHandle(
            crate::error::FfiNullHandlePayload::new("mlx_closure_custom_new_func_payload"),
        )));
    }
    Ok(ClosureCustomGuard(handle))
}
extern "C" fn trampoline_custom(
outputs_out: *mut mlxrs_sys::mlx_vector_array,
primals: mlxrs_sys::mlx_vector_array,
cotangents: mlxrs_sys::mlx_vector_array,
outputs: mlxrs_sys::mlx_vector_array,
payload: *mut c_void,
) -> c_int {
let result = catch_unwind(AssertUnwindSafe(|| {
let f: &BoxedFn3 = unsafe { &*payload.cast::<BoxedFn3>() };
let p = borrow_inputs(primals)?;
let c = borrow_inputs(cotangents)?;
let o = borrow_inputs(outputs)?;
let grads = f(&p, &c, &o)?;
write_outputs(outputs_out, &grads)?;
Ok::<(), Error>(())
}));
match result {
Ok(Ok(())) => 0,
Ok(Err(e)) => {
crate::error::set_last(e);
unsafe {
if !outputs_out.is_null() {
*outputs_out = mlxrs_sys::mlx_vector_array_new();
}
}
1
}
Err(panic_payload) => {
let msg = if let Some(s) = panic_payload.downcast_ref::<&'static str>() {
(*s).to_string()
} else if let Some(s) = panic_payload.downcast_ref::<String>() {
s.clone()
} else {
"panic in mlxrs::transforms custom-VJP trampoline".to_string()
};
crate::error::set_last(Error::Parse(ParsePayload::new(
"transforms::custom_vjp trampoline: caught panic",
"Rust closure panic payload",
std::io::Error::other(msg),
)));
unsafe {
if !outputs_out.is_null() {
*outputs_out = mlxrs_sys::mlx_vector_array_new();
}
}
1
}
}
}
extern "C" fn destroy_payload_3(payload: *mut c_void) {
if payload.is_null() {
return;
}
let _ = catch_unwind(AssertUnwindSafe(|| {
let _: Box<BoxedFn3> = unsafe { Box::from_raw(payload.cast::<BoxedFn3>()) };
}));
}
/// Test-only indirection for the closure constructors.
///
/// In `cfg(test)` builds, `call_closure_new_ffi` and `call_closure_custom_new_ffi` call
/// through the function pointers stored here rather than calling `mlxrs_sys` directly.
/// A test can therefore swap in a stub constructor for the lifetime of a
/// `ScopedClosureCtor` / `ScopedCustomCtor` guard.
#[cfg(test)]
pub(crate) mod test_seam {
    use std::sync::{
        Mutex, MutexGuard, OnceLock,
        atomic::{AtomicPtr, Ordering},
    };
    use super::*;

    /// Signature a stub replacing `mlx_closure_new_func_payload` must have.
    ///
    /// NOTE(review): the slot stores constructors as an untyped `*mut ()` and transmutes
    /// them back. The compiler therefore does not check that this signature matches the
    /// `mlxrs_sys` declaration, so keep the two in sync by hand.
    pub(crate) type ClosureNewFn = unsafe extern "C" fn(
        fun: Option<
            unsafe extern "C" fn(
                *mut mlxrs_sys::mlx_vector_array,
                mlxrs_sys::mlx_vector_array,
                *mut c_void,
            ) -> c_int,
        >,
        payload: *mut c_void,
        dtor: Option<unsafe extern "C" fn(*mut c_void)>,
    ) -> mlxrs_sys::mlx_closure;

    /// Signature a stub replacing `mlx_closure_custom_new_func_payload` must have.
    ///
    /// NOTE(review): as with `ClosureNewFn`, this is not checked against `mlxrs_sys`.
    pub(crate) type ClosureCustomNewFn = unsafe extern "C" fn(
        fun: Option<
            unsafe extern "C" fn(
                *mut mlxrs_sys::mlx_vector_array,
                mlxrs_sys::mlx_vector_array,
                mlxrs_sys::mlx_vector_array,
                mlxrs_sys::mlx_vector_array,
                *mut c_void,
            ) -> c_int,
        >,
        payload: *mut c_void,
        dtor: Option<unsafe extern "C" fn(*mut c_void)>,
    ) -> mlxrs_sys::mlx_closure_custom;

    /// Slot holding the current plain-closure constructor.
    /// On first use it is seeded with the real `mlx_closure_new_func_payload`.
    fn closure_new_slot() -> &'static AtomicPtr<()> {
        static SLOT: OnceLock<AtomicPtr<()>> = OnceLock::new();
        SLOT.get_or_init(|| AtomicPtr::new(mlxrs_sys::mlx_closure_new_func_payload as *mut ()))
    }

    /// Lock held by a `ScopedClosureCtor` for its whole lifetime; it serializes stub installers.
    fn closure_new_install_lock() -> &'static Mutex<()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
    }

    /// Slot holding the current custom-closure constructor.
    /// On first use it is seeded with the real `mlx_closure_custom_new_func_payload`.
    fn closure_custom_new_slot() -> &'static AtomicPtr<()> {
        static SLOT: OnceLock<AtomicPtr<()>> = OnceLock::new();
        SLOT.get_or_init(|| AtomicPtr::new(mlxrs_sys::mlx_closure_custom_new_func_payload as *mut ()))
    }

    /// Lock held by a `ScopedCustomCtor` for its whole lifetime; it serializes stub installers.
    fn closure_custom_new_install_lock() -> &'static Mutex<()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
    }

    /// Returns the plain-closure constructor currently installed in the slot.
    pub(crate) fn closure_new_fn() -> ClosureNewFn {
        let ptr = closure_new_slot().load(Ordering::Acquire);
        // SAFETY: the slot only ever holds one of two values. The first is the real
        // constructor seeded in `closure_new_slot`. The second is a `ClosureNewFn` stored
        // by `ScopedClosureCtor::install`. Both were cast from function pointers with this
        // (assumed-matching) signature.
        unsafe { std::mem::transmute::<*mut (), ClosureNewFn>(ptr) }
    }

    /// Returns the custom-closure constructor currently installed in the slot.
    pub(crate) fn closure_custom_new_fn() -> ClosureCustomNewFn {
        let ptr = closure_custom_new_slot().load(Ordering::Acquire);
        // SAFETY: the slot only ever holds one of two values. The first is the real
        // constructor seeded in `closure_custom_new_slot`. The second is a
        // `ClosureCustomNewFn` stored by `ScopedCustomCtor::install`.
        unsafe { std::mem::transmute::<*mut (), ClosureCustomNewFn>(ptr) }
    }

    /// RAII guard that installs a stub plain-closure constructor and restores the previous
    /// constructor when dropped.
    ///
    /// NOTE(review): readers (`closure_new_fn`) do not take the install lock. Code on other
    /// threads that builds closures while a stub is installed will therefore also see the stub.
    pub(crate) struct ScopedClosureCtor {
        // Keeps other installers out until this guard is dropped.
        _install_guard: MutexGuard<'static, ()>,
        // Constructor that was in the slot before `install`; put back in `Drop`.
        prev: *mut (),
    }
    impl ScopedClosureCtor {
        /// Takes the install lock (recovering it if a previous holder panicked) and swaps `stub` into the slot.
        pub(crate) fn install(stub: ClosureNewFn) -> Self {
            let guard = closure_new_install_lock()
                .lock()
                .unwrap_or_else(|poison| poison.into_inner());
            let stub_ptr = stub as *mut ();
            let prev = closure_new_slot().swap(stub_ptr, Ordering::AcqRel);
            Self {
                _install_guard: guard,
                prev,
            }
        }
    }
    impl Drop for ScopedClosureCtor {
        fn drop(&mut self) {
            // `drop` runs before the struct's fields are dropped. The previous constructor
            // is therefore restored while `_install_guard` still holds the install lock.
            closure_new_slot().store(self.prev, Ordering::Release);
        }
    }

    /// RAII guard that installs a stub custom-closure constructor and restores the previous
    /// constructor when dropped.
    ///
    /// NOTE(review): as with `ScopedClosureCtor`, readers do not take the install lock.
    pub(crate) struct ScopedCustomCtor {
        // Keeps other installers out until this guard is dropped.
        _install_guard: MutexGuard<'static, ()>,
        // Constructor that was in the slot before `install`; put back in `Drop`.
        prev: *mut (),
    }
    impl ScopedCustomCtor {
        /// Takes the install lock (recovering it if a previous holder panicked) and swaps `stub` into the slot.
        pub(crate) fn install(stub: ClosureCustomNewFn) -> Self {
            let guard = closure_custom_new_install_lock()
                .lock()
                .unwrap_or_else(|poison| poison.into_inner());
            let stub_ptr = stub as *mut ();
            let prev = closure_custom_new_slot().swap(stub_ptr, Ordering::AcqRel);
            Self {
                _install_guard: guard,
                prev,
            }
        }
    }
    impl Drop for ScopedCustomCtor {
        fn drop(&mut self) {
            // Restored before `_install_guard` is dropped, i.e. while still holding the lock.
            closure_custom_new_slot().store(self.prev, Ordering::Release);
        }
    }
}
// Unit tests live in a separate `tests` module file and are compiled only for test builds.
#[cfg(test)]
mod tests;