//! Function transforms: autograd (`value_and_grad`, `grad`, `vjp`, `jvp`),
//! custom-VJP overrides, gradient checkpointing, and bulk eval / async-eval.
//!
//! Mirrors `mlx-swift`'s `MLX.Transforms` (`Transforms.swift`,
//! `Transforms+Eval.swift`, `Transforms+Grad.swift`, `Transforms+Internal.swift`)
//! and `mlx.core.{value_and_grad,grad,vjp,jvp,custom_function,custom_vjp,
//! checkpoint,eval,async_eval}` on the Python side.
//!
//! ## API surface
//!
//! - [`crate::transforms::closure::Closure`] — RAII wrapper over
//! `mlx_closure` that owns the captured Rust callable for the FFI's
//! lifetime. Used internally by the autograd builders; exposed in case a
//! caller needs to build a closure directly.
//! - [`crate::transforms::autograd::value_and_grad`] /
//! [`crate::transforms::autograd::grad`] — return a Rust closure that, when
//! invoked on a slice of [`crate::Array`], runs the forward pass and
//! computes gradients with respect to a chosen subset of inputs. The
//! returned closure implements `Fn`, so it can be called repeatedly with
//! different inputs.
//! - [`crate::transforms::autograd::vjp`] /
//! [`crate::transforms::autograd::jvp`] — one-shot vector-Jacobian and
//! Jacobian-vector products over a user function evaluated at `primals`.
//! - [`crate::transforms::custom::custom_vjp`] /
//! [`crate::transforms::custom::custom_function`] — wrap a forward function
//! with a user-defined backward (cotangent) function, overriding the
//! autograd-derived VJP.
//! - [`crate::transforms::checkpoint::checkpoint`] — wrap a function so its
//! activations are recomputed (rather than stored) during the backward
//! pass, trading compute for memory.
//! - [`crate::transforms::eval::eval`] / [`crate::transforms::eval::async_eval`]
//! — synchronously / asynchronously materialize the lazy graph rooted at a
//! batch of arrays.
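//!
//! To make the shape of the `value_and_grad` transform concrete, the block
//! below is a minimal, self-contained sketch. It is not this crate's API:
//! `value_and_grad_fd` is a hypothetical name, plain `f64` values stand in
//! for [`crate::Array`], and central finite differences stand in for
//! reverse-mode autograd. It only illustrates "take a function, return a
//! reusable closure that yields the value plus gradients for the chosen
//! inputs".
//!
//! ```rust
//! // Hypothetical stand-in, not part of mlxrs: differentiates numerically
//! // with central differences instead of building an autograd graph.
//! fn value_and_grad_fd<F>(f: F, argnums: Vec<usize>) -> impl Fn(&[f64]) -> (f64, Vec<f64>)
//! where
//!     F: Fn(&[f64]) -> f64 + 'static,
//! {
//!     move |inputs: &[f64]| {
//!         // Forward pass.
//!         let value = f(inputs);
//!         // One partial derivative per requested input index.
//!         let h = 1e-6;
//!         let grads: Vec<f64> = argnums
//!             .iter()
//!             .map(|&i| {
//!                 let mut plus = inputs.to_vec();
//!                 let mut minus = inputs.to_vec();
//!                 plus[i] += h;
//!                 minus[i] -= h;
//!                 (f(&plus) - f(&minus)) / (2.0 * h)
//!             })
//!             .collect();
//!         (value, grads)
//!     }
//! }
//! ```
//!
//! The returned closure captures `f` and `argnums` by move, which is why it
//! can be called any number of times with fresh inputs.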
//!
//! ## Threading
//!
//! Like the rest of mlxrs, `Closure` and the returned `impl Fn` callables are
//! `!Send` + `!Sync`: they own [`crate::Array`] handles transitively through
//! the trampoline's closure, and mlx's evaluator is single-threaded (see
//! `crate::array::Array` for the rationale). The Rust callable passed in
//! (`F: Fn(&[Array]) -> Result<Vec<Array>>`) must still be `'static` so that
//! it can outlive the construction scope and be invoked from mlx-c.
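//!
//! The ownership pattern described above can be sketched in isolation. The
//! types here are assumptions made for illustration, not the crate's real
//! definitions: `OwnedCallback` and `build_scaler` are hypothetical, `f64`
//! stands in for `Array`, and `String` stands in for the error type.
//!
//! ```rust
//! use std::marker::PhantomData;
//!
//! // Hypothetical stand-in for a closure wrapper; not the crate's `Closure`.
//! struct OwnedCallback {
//!     // Boxed so the wrapper owns the callable. `'static` means the callable
//!     // borrows nothing from the scope that built it.
//!     call: Box<dyn Fn(&[f64]) -> Result<Vec<f64>, String> + 'static>,
//!     // A raw-pointer marker opts the type out of `Send` and `Sync`.
//!     _not_send_sync: PhantomData<*const ()>,
//! }
//!
//! impl OwnedCallback {
//!     fn new<F>(f: F) -> Self
//!     where
//!         F: Fn(&[f64]) -> Result<Vec<f64>, String> + 'static,
//!     {
//!         OwnedCallback { call: Box::new(f), _not_send_sync: PhantomData }
//!     }
//!
//!     fn invoke(&self, inputs: &[f64]) -> Result<Vec<f64>, String> {
//!         (self.call)(inputs)
//!     }
//! }
//!
//! // `scale` is moved into the callable, so the wrapper stays valid after
//! // this function (the construction scope) has returned.
//! fn build_scaler(scale: f64) -> OwnedCallback {
//!     OwnedCallback::new(move |xs: &[f64]| Ok(xs.iter().map(|x| x * scale).collect()))
//! }
//! ```
//!
//! In this sketch, moving an `OwnedCallback` into `std::thread::spawn` is
//! rejected at compile time, because `*const ()` is neither `Send` nor `Sync`.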
pub use ;
pub use checkpoint;
pub use Closure;
pub use ;
pub use ;