use std::rc::Rc;
use crate::{
Array,
error::{
EmptyInputPayload, Error, Result, check, check_vector_array_handle, ensure_handler_installed,
},
stream::assert_streams_not_cleared,
transforms::closure::{
BoxedFn, Closure, ClosureValueAndGradGuard, VectorArrayGuard, drain_vector,
vector_array_from_slice,
},
};
#[allow(clippy::type_complexity)]
pub fn value_and_grad<F>(
f: F,
argnums: &[i32],
) -> Result<impl Fn(&[Array]) -> Result<(Vec<Array>, Vec<Array>)>>
where
F: Fn(&[Array]) -> Result<Vec<Array>> + 'static,
{
if argnums.is_empty() {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"value_and_grad: argnums",
)));
}
let f: Rc<BoxedFn> = Rc::new(Box::new(f));
let argnums = argnums.to_vec();
Ok(
move |inputs: &[Array]| -> Result<(Vec<Array>, Vec<Array>)> {
let f = Rc::clone(&f);
let closure = Closure::new(move |xs: &[Array]| f(xs))?;
let vag = build_value_and_grad(&closure, &argnums)?;
apply_value_and_grad(&vag, inputs)
},
)
}
pub fn grad<F>(f: F, argnums: &[i32]) -> Result<impl Fn(&[Array]) -> Result<Vec<Array>>>
where
F: Fn(&[Array]) -> Result<Vec<Array>> + 'static,
{
let vag = value_and_grad(f, argnums)?;
Ok(move |inputs: &[Array]| -> Result<Vec<Array>> { Ok(vag(inputs)?.1) })
}
pub fn vjp<F>(f: F, primals: &[Array], cotangents: &[Array]) -> Result<(Vec<Array>, Vec<Array>)>
where
F: Fn(&[Array]) -> Result<Vec<Array>> + 'static,
{
ensure_handler_installed();
assert_streams_not_cleared();
let closure = Closure::new(f)?;
let p_guard = vector_array_from_slice(primals)?;
let c_guard = vector_array_from_slice(cotangents)?;
let mut out0 = unsafe { mlxrs_sys::mlx_vector_array_new() };
check_vector_array_handle(out0)?;
let _out0_guard = VectorArrayGuard(out0);
let mut out1 = unsafe { mlxrs_sys::mlx_vector_array_new() };
check_vector_array_handle(out1)?;
let _out1_guard = VectorArrayGuard(out1);
check(unsafe {
mlxrs_sys::mlx_vjp(&mut out0, &mut out1, closure.as_raw(), p_guard.0, c_guard.0)
})?;
let values = drain_vector(out0)?;
let grads = drain_vector(out1)?;
Ok((values, grads))
}
pub fn jvp<F>(f: F, primals: &[Array], tangents: &[Array]) -> Result<(Vec<Array>, Vec<Array>)>
where
F: Fn(&[Array]) -> Result<Vec<Array>> + 'static,
{
ensure_handler_installed();
assert_streams_not_cleared();
let closure = Closure::new(f)?;
let p_guard = vector_array_from_slice(primals)?;
let t_guard = vector_array_from_slice(tangents)?;
let mut out0 = unsafe { mlxrs_sys::mlx_vector_array_new() };
check_vector_array_handle(out0)?;
let _out0_guard = VectorArrayGuard(out0);
let mut out1 = unsafe { mlxrs_sys::mlx_vector_array_new() };
check_vector_array_handle(out1)?;
let _out1_guard = VectorArrayGuard(out1);
check(unsafe {
mlxrs_sys::mlx_jvp(&mut out0, &mut out1, closure.as_raw(), p_guard.0, t_guard.0)
})?;
let values = drain_vector(out0)?;
let jvp_out = drain_vector(out1)?;
Ok((values, jvp_out))
}
fn build_value_and_grad(closure: &Closure, argnums: &[i32]) -> Result<ClosureValueAndGradGuard> {
ensure_handler_installed();
let mut vag = unsafe { mlxrs_sys::mlx_closure_value_and_grad_new() };
debug_assert!(
!argnums.is_empty(),
"build_value_and_grad: empty argnums must be rejected at value_and_grad"
);
let argnums_ptr = argnums.as_ptr();
check(unsafe {
mlxrs_sys::mlx_value_and_grad(&mut vag, closure.as_raw(), argnums_ptr, argnums.len())
})?;
Ok(ClosureValueAndGradGuard(vag))
}
fn apply_value_and_grad(
vag: &ClosureValueAndGradGuard,
inputs: &[Array],
) -> Result<(Vec<Array>, Vec<Array>)> {
ensure_handler_installed();
assert_streams_not_cleared();
let in_guard = vector_array_from_slice(inputs)?;
let mut out0 = unsafe { mlxrs_sys::mlx_vector_array_new() };
check_vector_array_handle(out0)?;
let _out0_guard = VectorArrayGuard(out0);
let mut out1 = unsafe { mlxrs_sys::mlx_vector_array_new() };
check_vector_array_handle(out1)?;
let _out1_guard = VectorArrayGuard(out1);
check(unsafe {
mlxrs_sys::mlx_closure_value_and_grad_apply(&mut out0, &mut out1, vag.0, in_guard.0)
})?;
let values = drain_vector(out0)?;
let grads = drain_vector(out1)?;
Ok((values, grads))
}