//! Gradient checkpointing: [`checkpoint`].
//!
//! Mirrors `mlx.core.checkpoint` (Python), the checkpointing helpers in
//! [`mlx-swift`](https://github.com/ml-explore/mlx-swift), and the
//! `mlx.nn.utils.checkpoint` recipe. Wraps a function so its intermediate
//! activations are *recomputed* during the backward pass instead of being
//! stored, trading compute for memory. This pays off when peak memory, rather
//! than compute, is the limiting factor in training (e.g. long-sequence
//! transformers).
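//!
//! A simplified counting model of where the savings come from, assuming a
//! stack of identical blocks with each block checkpointed separately (the
//! block structure, `simulate`, and its parameters are illustrative and not
//! part of this crate). Without checkpointing, every intermediate stays live
//! until backward; with it, only each block's input is kept, and one block's
//! intermediates are rebuilt at a time:
//!
//! ```rust
//! /// Illustrative model (assumption): `layers` blocks, each producing
//! /// `per_block` intermediates that its backward step needs.
//! /// Returns (peak live activations, block forward evaluations).
//! fn simulate(layers: usize, per_block: usize, checkpointed: bool) -> (usize, usize) {
//!     let mut live = 0usize;
//!     let mut peak = 0usize;
//!     let mut forward_evals = 0usize;
//!
//!     // Forward pass, first block to last.
//!     for _ in 0..layers {
//!         forward_evals += 1;
//!         if checkpointed {
//!             live += 1; // keep only the block's input
//!         } else {
//!             live += per_block; // keep every intermediate
//!         }
//!         peak = peak.max(live);
//!     }
//!
//!     // Backward pass, last block first.
//!     for _ in 0..layers {
//!         if checkpointed {
//!             forward_evals += 1; // replay this block's forward
//!             live += per_block; // its intermediates are live again
//!             peak = peak.max(live);
//!             live -= per_block + 1; // free them and the saved input
//!         } else {
//!             live -= per_block; // free this block's stored intermediates
//!         }
//!     }
//!     (peak, forward_evals)
//! }
//!
//! fn main() {
//!     // 24 blocks with 10 intermediates each.
//!     assert_eq!(simulate(24, 10, false), (240, 24));
//!     assert_eq!(simulate(24, 10, true), (34, 48));
//! }
//! ```
//!
//! In this model peak memory drops from `layers * per_block` to about
//! `layers + per_block`, paid for by a second forward evaluation of every
//! block.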
//!
//! ## Semantics
//!
//! - Forward pass: identical to the unwrapped function. Returns the same
//! `Vec<Array>`.
//! - Backward pass (when differentiated via [`super::grad`] /
//! [`super::value_and_grad`] / [`super::vjp`]): mlx re-runs the forward
//! function to reconstruct activations on demand instead of keeping them
//! live from the forward pass. The gradient is mathematically the same and
//! peak memory is lower, at the cost of roughly 2x forward compute over the
//! wrapped region (it runs once in the forward pass and again in backward).
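//!
//! The gradient equivalence and the doubled forward cost can be seen on a
//! scalar toy outside mlx. This is an illustrative sketch: `Region`, its
//! `depth` repeated `sin` steps, and the hand-written chain rule are
//! assumptions, not how mlx builds its backward graph. The stored strategy
//! keeps every step's input from the forward pass; the checkpointed strategy
//! keeps only `x` and replays the forward during backward:
//!
//! ```rust
//! use std::cell::Cell;
//!
//! /// Toy wrapped region (illustrative): `depth` repeated applications of
//! /// `sin`, counting every call so the extra forward work is visible.
//! struct Region {
//!     depth: usize,
//!     sin_calls: Cell<usize>,
//! }
//!
//! impl Region {
//!     fn step(&self, h: f64) -> f64 {
//!         self.sin_calls.set(self.sin_calls.get() + 1);
//!         h.sin()
//!     }
//!
//!     /// Plain forward: nothing but the output survives.
//!     fn forward(&self, x: f64) -> f64 {
//!         (0..self.depth).fold(x, |h, _| self.step(h))
//!     }
//!
//!     /// Forward that also records each step's input (the "tape").
//!     fn forward_with_tape(&self, x: f64) -> (f64, Vec<f64>) {
//!         let mut tape = Vec::with_capacity(self.depth);
//!         let mut h = x;
//!         for _ in 0..self.depth {
//!             tape.push(h);
//!             h = self.step(h);
//!         }
//!         (h, tape)
//!     }
//!
//!     /// Chain rule from the tape: d sin(h)/dh = cos(h), last step first.
//!     fn backward(tape: &[f64]) -> f64 {
//!         tape.iter().rev().fold(1.0, |g, &h| g * h.cos())
//!     }
//!
//!     /// Stored: the tape from the forward pass is held until backward.
//!     fn grad_stored(&self, x: f64) -> (f64, f64) {
//!         let (y, tape) = self.forward_with_tape(x);
//!         (y, Self::backward(&tape))
//!     }
//!
//!     /// Checkpointed: only `x` is held; backward rebuilds the tape.
//!     fn grad_checkpointed(&self, x: f64) -> (f64, f64) {
//!         let y = self.forward(x);
//!         let (_, tape) = self.forward_with_tape(x);
//!         (y, Self::backward(&tape))
//!     }
//! }
//!
//! fn main() {
//!     let r = Region { depth: 8, sin_calls: Cell::new(0) };
//!     let stored = r.grad_stored(0.7);
//!     let stored_calls = r.sin_calls.replace(0);
//!     let checkpointed = r.grad_checkpointed(0.7);
//!     // Same operations in the same order, so the results match exactly.
//!     assert_eq!(stored, checkpointed);
//!     // The checkpointed version evaluates the region's forward twice.
//!     assert_eq!((stored_calls, r.sin_calls.get()), (8, 16));
//! }
//! ```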
//!
//! ## Re-entrancy
//!
//! As with [`super::custom_vjp`], the `mlx_closure` built from `f` is created
//! once, when [`checkpoint`] is called, and held in an `Rc` so the returned
//! `Fn` can invoke it any number of times. The checkpointed `mlx_closure`
//! returned by `mlx_checkpoint` is likewise built once and cached, so calling
//! the wrapper never rebuilds either closure.
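//!
//! The build-once pattern can be sketched in plain Rust. This is
//! illustrative only: `Handle` is a hypothetical stand-in for an
//! `mlx_closure`, and the real wrapper's fields and cleanup are not shown.
//! The expensive object is created once when the wrapper is constructed and
//! captured through an `Rc`, so calling the returned `Fn`, or cloning it,
//! never rebuilds it:
//!
//! ```rust
//! use std::cell::Cell;
//! use std::rc::Rc;
//!
//! /// Hypothetical stand-in for an expensive foreign handle such as an
//! /// `mlx_closure`; `builds` counts how many are constructed.
//! struct Handle {
//!     scale: i32,
//! }
//!
//! impl Handle {
//!     fn new(builds: &Cell<usize>, scale: i32) -> Self {
//!         builds.set(builds.get() + 1);
//!         Handle { scale }
//!     }
//!
//!     fn call(&self, x: i32) -> i32 {
//!         x * self.scale
//!     }
//! }
//!
//! /// Builds the handle once, up front. The returned closure owns an `Rc` to
//! /// it, so each call only borrows the handle and each clone bumps a refcount.
//! fn wrap(builds: &Cell<usize>) -> impl Fn(i32) -> i32 + Clone {
//!     let handle = Rc::new(Handle::new(builds, 3));
//!     move |x| handle.call(x)
//! }
//!
//! fn main() {
//!     let builds = Cell::new(0);
//!     let f = wrap(&builds);
//!     let g = f.clone(); // shares the same `Rc`'d handle
//!     assert_eq!(f(1) + f(2) + g(3), 18);
//!     assert_eq!(builds.get(), 1); // built once, called three times
//! }
//! ```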
use std::rc::Rc;
use crate::;
/// Wrap `f` so its activations are recomputed (not stored) during backward.
/// Forward pass is identical to invoking `f` directly.
///
/// ```no_run
/// # fn run() -> mlxrs::Result<()> {
/// use mlxrs::{Array, transforms::{checkpoint, grad}};
/// // `checkpoint` wraps the function: the forward result is unchanged, and
/// // the backward pass recomputes activations instead of storing them.
/// let cf = checkpoint(|xs| Ok(vec![mlxrs::ops::arithmetic::square(&xs[0])?]))?;
/// // Scalar input: an empty shape filled with 3.0.
/// let x = Array::full::<f32>(&[0i32; 0], 3.0)?;
/// let mut vals = cf(&[x.try_clone()?])?;
/// assert_eq!(vals[0].item::<f32>()?, 9.0); // 3^2
///
/// // Gradient through the checkpointed function is identical to the
/// // non-checkpointed gradient (same math, different memory profile).
/// let g = grad(cf, &[0])?;
/// let mut grads = g(&[x])?;
/// assert_eq!(grads[0].item::<f32>()?, 6.0); // d(x^2)/dx = 2x = 6 at x = 3
/// # Ok(()) }
/// ```