use std::rc::Rc;
use crate::{
Array,
error::{Result, check, check_vector_array_handle, ensure_handler_installed},
stream::assert_streams_not_cleared,
transforms::closure::{
BoxedFn, BoxedFn3, Closure, ClosureCustomGuard, RawClosureGuard, VectorArrayGuard,
closure_custom_new, drain_vector, vector_array_from_slice,
},
};
/// Wraps `f` so that reverse-mode differentiation through it uses `vjp_fn` (via MLX's
/// `custom_vjp`) instead of differentiating `f` directly.
///
/// `vjp_fn` is converted into an MLX closure once, when this function is called. The forward
/// closure and the MLX wrapper around it are rebuilt on every call of the returned function
/// and released before that call returns.
///
/// `vjp_fn` receives three array slices. NOTE(review): following MLX's `custom_vjp` convention
/// these are presumably `(primals, cotangents, outputs)` — confirm against `closure_custom_new`.
///
/// The returned closure captures an `Rc`, so it is not `Send`.
///
/// # Errors
///
/// Returns an error if `vjp_fn` cannot be turned into an MLX closure. The returned closure
/// returns an error if building the forward closure, registering the custom VJP, converting
/// `inputs`, allocating the output vector, or applying the wrapped closure fails.
pub fn custom_vjp<F, V>(f: F, vjp_fn: V) -> Result<impl Fn(&[Array]) -> Result<Vec<Array>>>
where
    F: Fn(&[Array]) -> Result<Vec<Array>> + 'static,
    V: Fn(&[Array], &[Array], &[Array]) -> Result<Vec<Array>> + 'static,
{
    ensure_handler_installed();
    // `f` is shared through an `Rc` because each call moves a fresh handle into a new
    // `'static` forward closure.
    let f: Rc<BoxedFn> = Rc::new(Box::new(f));
    // Only the guard's `Copy` handle (`.0`) is read per call, so no `Rc` is needed here.
    let vjp_closure: ClosureCustomGuard = closure_custom_new(vjp_fn)?;
    Ok(move |inputs: &[Array]| -> Result<Vec<Array>> {
        let f = Rc::clone(&f);
        let fwd = Closure::new(move |xs: &[Array]| f(xs))?;
        // Take ownership of the handle *before* MLX fills it in, so it is released on every
        // exit path, including when `mlx_custom_vjp` reports an error.
        // SAFETY: `mlx_closure_new` takes no arguments and returns a handle we now own.
        let mut wrapped = RawClosureGuard(unsafe { mlxrs_sys::mlx_closure_new() });
        // SAFETY: `wrapped.0` is a valid handle owned by `wrapped`; `fwd` and `vjp_closure`
        // are alive for the whole call, so the handles passed by value are valid.
        check(unsafe {
            mlxrs_sys::mlx_custom_vjp(&mut wrapped.0, fwd.as_raw(), vjp_closure.0)
        })?;
        apply_closure(wrapped.0, inputs)
    })
}

/// Like [`custom_vjp`], but built through MLX's general `mlx_custom_function` entry point with
/// only a VJP rule supplied (no custom JVP or vmap rule).
///
/// `vjp_fn` is converted into an MLX closure once, when this function is called. The forward
/// closure and the MLX wrapper are rebuilt on every call of the returned function.
///
/// The returned closure captures an `Rc`, so it is not `Send`.
///
/// # Errors
///
/// Returns an error if `vjp_fn` cannot be turned into an MLX closure. The returned closure
/// returns an error if building the forward closure, registering the custom function,
/// converting `inputs`, allocating the output vector, or applying the wrapped closure fails.
pub fn custom_function<F, V>(f: F, vjp_fn: V) -> Result<impl Fn(&[Array]) -> Result<Vec<Array>>>
where
    F: Fn(&[Array]) -> Result<Vec<Array>> + 'static,
    V: Fn(&[Array], &[Array], &[Array]) -> Result<Vec<Array>> + 'static,
{
    ensure_handler_installed();
    let f: Rc<BoxedFn> = Rc::new(Box::new(f));
    let vjp_closure: ClosureCustomGuard = closure_custom_new(vjp_fn)?;
    Ok(move |inputs: &[Array]| -> Result<Vec<Array>> {
        let f = Rc::clone(&f);
        let fwd = Closure::new(move |xs: &[Array]| f(xs))?;
        // Null-context handles stand for "no custom rule supplied".
        // NOTE(review): relies on mlx-c treating a null `ctx` as an absent optional rule —
        // confirm for the linked mlx-c version.
        let no_jvp = mlxrs_sys::mlx_closure_custom_jvp {
            ctx: std::ptr::null_mut(),
        };
        let no_vmap = mlxrs_sys::mlx_closure_custom_vmap {
            ctx: std::ptr::null_mut(),
        };
        // Owned by the guard before MLX writes into it, so an error below cannot leak it.
        // SAFETY: `mlx_closure_new` takes no arguments and returns a handle we now own.
        let mut wrapped = RawClosureGuard(unsafe { mlxrs_sys::mlx_closure_new() });
        // SAFETY: `wrapped.0` is a valid handle owned by `wrapped`; `fwd` and `vjp_closure`
        // outlive this call; the JVP/vmap handles are intentionally null (see above).
        check(unsafe {
            mlxrs_sys::mlx_custom_function(
                &mut wrapped.0,
                fwd.as_raw(),
                vjp_closure.0,
                no_jvp,
                no_vmap,
            )
        })?;
        apply_closure(wrapped.0, inputs)
    })
}

/// Applies the MLX closure handle `closure` to `inputs` and returns the outputs it produced.
///
/// The caller keeps ownership of `closure`; the temporary input and output vectors created
/// here are released before returning, on success and on error.
fn apply_closure(closure: mlxrs_sys::mlx_closure, inputs: &[Array]) -> Result<Vec<Array>> {
    let in_guard = vector_array_from_slice(inputs)?;
    // SAFETY: `mlx_vector_array_new` takes no arguments and returns a handle we now own;
    // it is validated immediately below before being guarded.
    let out = unsafe { mlxrs_sys::mlx_vector_array_new() };
    check_vector_array_handle(out)?;
    // MLX writes the results through a pointer to the guard's own field, so the guard frees
    // exactly the handle MLX leaves behind rather than a stale copy of it.
    let mut out_guard = VectorArrayGuard(out);
    // SAFETY: `out_guard.0` and `in_guard.0` are valid handles owned by live guards, and
    // `closure` is kept alive by the caller for the duration of this call.
    check(unsafe { mlxrs_sys::mlx_closure_apply(&mut out_guard.0, closure, in_guard.0) })?;
    drain_vector(out_guard.0)
}
// Private alias that appears to exist only to keep the `BoxedFn3` import referenced. The
// import is not otherwise used in the code shown in this file, so this avoids an
// unused-import warning; `allow(dead_code)` then silences the warning for the alias itself.
// NOTE(review): if `BoxedFn3` is truly unused here, dropping the import would let this go.
#[allow(dead_code)]
type _BoxedFn3 = BoxedFn3;
// Never-called helper that appears to exist only to keep the `assert_streams_not_cleared`
// import referenced. That import is not otherwise called in the code shown in this file;
// `allow(dead_code)` silences the warning for this uncalled function.
// NOTE(review): if the import is truly unused here, removing it would let this go.
#[allow(dead_code)]
fn _streams_guard() {
assert_streams_not_cleared();
}