use std::sync::Arc;
use crate::autograd::autocast::{
AutocastSnapshot, current_autocast_snapshot, with_autocast_state,
};
use crate::dtype::Float;
use crate::error::FerrotorchResult;
use crate::gpu_dispatch::GpuRngState;
use crate::tensor::Tensor;
/// Shared, thread-safe recompute closure for single-input checkpointing.
/// `Arc` lets the backward node co-own the user's closure.
type CheckpointFn<T> = Arc<dyn Fn(&Tensor<T>) -> FerrotorchResult<Tensor<T>> + Send + Sync>;
/// Shared, thread-safe recompute closure for multi-input checkpointing:
/// maps a slice of input tensors to a single output tensor.
type CheckpointMultiFn<T> =
    Arc<dyn Fn(&[Tensor<T>]) -> FerrotorchResult<Tensor<T>> + Send + Sync>;
/// Runs `f(input)` under gradient checkpointing: the forward pass executes
/// with autograd disabled (no intermediate graph is stored), and activations
/// are recomputed from `input` during the backward pass instead.
///
/// GPU RNG state and the autocast snapshot are captured *before* the forward
/// pass so the backward recomputation replays the same randomness and
/// mixed-precision settings.
///
/// Returns the forward output. When `input` does not require grad, the
/// output is returned as-is with no backward node attached.
///
/// # Errors
/// Propagates any error from `f`, from reading the output's data, or from
/// rebuilding storage on the output's device.
pub fn checkpoint<T, F>(f: F, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>
where
    T: Float,
    F: Fn(&Tensor<T>) -> FerrotorchResult<Tensor<T>> + Send + Sync + 'static,
{
    use crate::autograd::no_grad::no_grad;
    use crate::storage::TensorStorage;

    // Snapshot replayable state up front so backward sees exactly what the
    // original forward saw.
    let saved_gpu_rng = save_gpu_rng_state(input);
    let saved_autocast = current_autocast_snapshot();

    // Forward pass with graph construction suppressed.
    let output = no_grad(|| f(input))?;

    // Nothing flows back through this input: skip building the node.
    if !input.requires_grad() {
        return Ok(output);
    }

    let shape = output.shape().to_vec();
    let grad_fn = Arc::new(CheckpointBackward {
        func: Arc::new(f),
        input: input.clone(),
        output_shape: shape.clone(),
        saved_gpu_rng,
        saved_autocast,
    });

    // Rebuild the output tensor on its device so the returned tensor carries
    // the checkpoint backward node.
    let storage = TensorStorage::on_device(output.data_vec()?, output.device())?;
    Tensor::from_operation(storage, shape, grad_fn)
}
/// Multi-input variant of [`checkpoint`]: runs `f(inputs)` without retaining
/// the autograd graph and recomputes activations during backward.
///
/// GPU RNG state (taken from the first input's device) and the autocast
/// snapshot are captured before the forward pass so recomputation replays
/// them.
///
/// Returns the forward output; when no input requires grad, the output is
/// returned directly with no backward node.
///
/// # Errors
/// Returns `InvalidArgument` when `inputs` is empty, and propagates any
/// error from `f` or from rebuilding the output storage.
pub fn checkpoint_multi<T, F>(f: F, inputs: &[Tensor<T>]) -> FerrotorchResult<Tensor<T>>
where
    T: Float,
    F: Fn(&[Tensor<T>]) -> FerrotorchResult<Tensor<T>> + Send + Sync + 'static,
{
    use crate::autograd::no_grad::no_grad;
    use crate::storage::TensorStorage;

    if inputs.is_empty() {
        return Err(crate::error::FerrotorchError::InvalidArgument {
            message: "checkpoint_multi: at least one input required".into(),
        });
    }

    // Snapshot replayable state before running the forward pass.
    let saved_gpu_rng = save_gpu_rng_state(&inputs[0]);
    let saved_autocast = current_autocast_snapshot();

    let output = no_grad(|| f(inputs))?;

    // With no grad-requiring input there is nothing to checkpoint.
    if inputs.iter().all(|t| !t.requires_grad()) {
        return Ok(output);
    }

    let shape = output.shape().to_vec();
    let grad_fn = Arc::new(CheckpointMultiBackward {
        func: Arc::new(f),
        inputs: inputs.to_vec(),
        output_shape: shape.clone(),
        saved_gpu_rng,
        saved_autocast,
    });

    // Rebuild the output on its device so it carries the backward node.
    let storage = TensorStorage::on_device(output.data_vec()?, output.device())?;
    Tensor::from_operation(storage, shape, grad_fn)
}
fn save_gpu_rng_state<T: Float>(tensor: &Tensor<T>) -> Option<GpuRngState> {
let device_ordinal = match tensor.device() {
crate::device::Device::Cuda(id) => id,
crate::device::Device::Xpu(_)
| crate::device::Device::Cpu
| crate::device::Device::Meta => return None,
};
let backend = crate::gpu_dispatch::gpu_backend()?;
backend.save_rng_state(device_ordinal).ok()
}
/// RAII guard that restores a previously captured GPU RNG state on drop.
struct GpuRngGuard {
    // The RNG state that was current before backward temporarily rewound
    // the generator; restored when the guard is dropped.
    previous: GpuRngState,
}
impl Drop for GpuRngGuard {
    /// Best-effort restore: `drop` cannot report failures, and a missing
    /// backend means there is nothing to restore.
    fn drop(&mut self) {
        let backend = match crate::gpu_dispatch::gpu_backend() {
            Some(b) => b,
            None => return,
        };
        let _ = backend.restore_rng_state(self.previous);
    }
}
/// Backward node for [`checkpoint`]: holds everything needed to replay the
/// forward pass during the backward call.
struct CheckpointBackward<T: Float> {
    // The user closure, re-invoked with grad tracking enabled in `backward`.
    func: CheckpointFn<T>,
    // The original input tensor, kept so the forward can be recomputed.
    input: Tensor<T>,
    // Shape of the checkpointed output (reported in the `Debug` impl).
    output_shape: Vec<usize>,
    // GPU RNG state captured before the original forward pass; `None` for
    // non-CUDA devices or when no GPU backend is registered.
    saved_gpu_rng: Option<GpuRngState>,
    // Autocast (mixed precision) settings in effect at checkpoint time.
    saved_autocast: AutocastSnapshot,
}
impl<T: Float> std::fmt::Debug for CheckpointBackward<T> {
    /// Summarizes the node (shapes and captured state flags) without
    /// dumping tensor contents.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("CheckpointBackward");
        builder.field("input_shape", &self.input.shape());
        builder.field("output_shape", &self.output_shape);
        builder.field("has_gpu_rng_state", &self.saved_gpu_rng.is_some());
        builder.field("autocast_enabled", &self.saved_autocast.enabled);
        builder.finish()
    }
}
impl<T: Float> crate::tensor::GradFn<T> for CheckpointBackward<T> {
    /// Recomputes the forward pass with grad tracking enabled and extracts
    /// the input gradient via a vector-Jacobian product: backprop through
    /// `sum(recomputed * grad_output)`, whose input gradient is
    /// `J^T * grad_output`.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        // Temporarily rewind the GPU RNG to the state saved at checkpoint
        // time so stochastic ops replay identically; the guard restores the
        // caller's RNG state when it goes out of scope at function exit.
        let _rng_guard = if let Some(saved_state) = self.saved_gpu_rng {
            // Capture the caller's current RNG state *before* overwriting it.
            let current_state = save_gpu_rng_state(&self.input);
            if let Some(backend) = crate::gpu_dispatch::gpu_backend() {
                // Best-effort: restore failure is ignored here.
                let _ = backend.restore_rng_state(saved_state);
            }
            current_state.map(|prev| GpuRngGuard { previous: prev })
        } else {
            None
        };
        // Replay under the autocast settings captured at checkpoint time so
        // the recomputation matches the original forward's precision.
        with_autocast_state(self.saved_autocast, || {
            // NOTE(review): assumes clone + requires_grad_(true) yields a
            // fresh grad accumulator separate from the original leaf —
            // confirm against Tensor::clone semantics.
            let input_with_grad = self.input.clone().requires_grad_(true);
            let recomputed = (self.func)(&input_with_grad)?;
            use crate::grad_fns::arithmetic::mul;
            use crate::grad_fns::reduction::sum;
            // grad_output is detached so no graph is built through it.
            let weighted = mul(
                &recomputed,
                &grad_output.clone().requires_grad_(false).detach(),
            )?;
            let scalar = sum(&weighted)?;
            scalar.backward()?;
            let input_grad = input_with_grad.grad()?;
            Ok(vec![input_grad])
        })
    }
    /// The single tensor this node depends on.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }
    fn name(&self) -> &'static str {
        "CheckpointBackward"
    }
}
/// Backward node for [`checkpoint_multi`]: holds the closure and all inputs
/// so the forward pass can be replayed during backward.
struct CheckpointMultiBackward<T: Float> {
    // The user closure, re-invoked with grad tracking enabled in `backward`.
    func: CheckpointMultiFn<T>,
    // The original inputs, kept for recomputation. Guaranteed non-empty by
    // the empty-slice check in `checkpoint_multi`.
    inputs: Vec<Tensor<T>>,
    // Shape of the checkpointed output (reported in the `Debug` impl).
    output_shape: Vec<usize>,
    // GPU RNG state captured from the first input's device, if applicable.
    saved_gpu_rng: Option<GpuRngState>,
    // Autocast (mixed precision) settings in effect at checkpoint time.
    saved_autocast: AutocastSnapshot,
}
impl<T: Float> std::fmt::Debug for CheckpointMultiBackward<T> {
    /// Summarizes the node (input count, output shape, captured state
    /// flags) without dumping tensor contents.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("CheckpointMultiBackward");
        builder.field("num_inputs", &self.inputs.len());
        builder.field("output_shape", &self.output_shape);
        builder.field("has_gpu_rng_state", &self.saved_gpu_rng.is_some());
        builder.field("autocast_enabled", &self.saved_autocast.enabled);
        builder.finish()
    }
}
impl<T: Float> crate::tensor::GradFn<T> for CheckpointMultiBackward<T> {
    /// Recomputes the forward pass with grad tracking enabled and extracts
    /// per-input gradients via a vector-Jacobian product: backprop through
    /// `sum(recomputed * grad_output)`. Inputs that did not require grad
    /// get `None`.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        // Rewind the GPU RNG to the checkpoint-time state so stochastic ops
        // replay identically; the guard restores the caller's state on exit.
        // RNG is keyed off the first input's device, mirroring the save in
        // `checkpoint_multi` (inputs is non-empty by construction).
        let _rng_guard = if let Some(saved_state) = self.saved_gpu_rng {
            let current_state = save_gpu_rng_state(&self.inputs[0]);
            if let Some(backend) = crate::gpu_dispatch::gpu_backend() {
                let _ = backend.restore_rng_state(saved_state);
            }
            current_state.map(|prev| GpuRngGuard { previous: prev })
        } else {
            None
        };
        // Replay under the autocast settings captured at checkpoint time.
        with_autocast_state(self.saved_autocast, || {
            // NOTE(review): assumes clone + requires_grad_(true) yields a
            // fresh grad accumulator separate from the original leaf —
            // confirm against Tensor::clone semantics.
            let inputs_with_grad: Vec<Tensor<T>> = self
                .inputs
                .iter()
                .map(|t| {
                    if t.requires_grad() {
                        t.clone().requires_grad_(true)
                    } else {
                        t.clone()
                    }
                })
                .collect();
            let recomputed = (self.func)(&inputs_with_grad)?;
            use crate::grad_fns::arithmetic::mul;
            use crate::grad_fns::reduction::sum;
            // grad_output is detached so no graph is built through it.
            let weighted = mul(
                &recomputed,
                &grad_output.clone().requires_grad_(false).detach(),
            )?;
            let scalar = sum(&weighted)?;
            scalar.backward()?;
            // Collect gradients in input order; non-grad inputs map to None.
            let mut grads = Vec::with_capacity(self.inputs.len());
            for (orig, with_grad) in self.inputs.iter().zip(inputs_with_grad.iter()) {
                if orig.requires_grad() {
                    grads.push(with_grad.grad()?);
                } else {
                    grads.push(None);
                }
            }
            Ok(grads)
        })
    }
    /// All tensors this node depends on, in their original order.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        self.inputs.iter().collect()
    }
    fn name(&self) -> &'static str {
        "CheckpointMultiBackward"
    }
}
/// Unit tests: checkpoint forward/backward correctness, autocast snapshot
/// capture, and restoration of autocast/RNG state around recomputation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::autograd::autocast::{AutocastDtype, autocast, is_autocast_enabled};
    use crate::creation::{from_slice, scalar};
    use crate::grad_fns::arithmetic::{add, mul};
    use crate::grad_fns::reduction::sum;
    use crate::storage::TensorStorage;

    // Builds a CPU f32 leaf tensor with requires_grad = true.
    fn leaf_grad(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), true).unwrap()
    }

    // f(t) = t^2 + t on x = [1,2,3]: sum = 2+6+12 = 20, d/dt = 2t+1 = [3,5,7].
    #[test]
    fn test_checkpoint_single_input_basic() {
        let x = leaf_grad(&[1.0, 2.0, 3.0], &[3]);
        let y = checkpoint(
            |t: &Tensor<f32>| {
                let sq = mul(t, t)?;
                add(&sq, t)
            },
            &x,
        )
        .unwrap();
        let s = sum(&y).unwrap();
        assert!((s.item().unwrap() - 20.0).abs() < 1e-5);
        s.backward().unwrap();
        let g = x.grad().unwrap().expect("x should have a gradient");
        let gd = g.data().unwrap();
        assert!((gd[0] - 3.0).abs() < 1e-5);
        assert!((gd[1] - 5.0).abs() < 1e-5);
        assert!((gd[2] - 7.0).abs() < 1e-5);
    }

    // Input without requires_grad: output values are correct and no backward
    // node is attached.
    #[test]
    fn test_checkpoint_no_grad_input_returns_output_only() {
        let x = from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let y = checkpoint(
            |t: &Tensor<f32>| {
                let two = scalar(2.0f32)?;
                mul(t, &two)
            },
            &x,
        )
        .unwrap();
        let yd = y.data().unwrap();
        assert_eq!(yd, &[2.0, 4.0, 6.0]);
        assert!(y.grad_fn().is_none());
    }

    // y = a*b + a: d/da = b+1 = [5,6,7], d/db = a = [1,2,3].
    #[test]
    fn test_checkpoint_multi_two_inputs_both_grad() {
        let a = leaf_grad(&[1.0, 2.0, 3.0], &[3]);
        let b = leaf_grad(&[4.0, 5.0, 6.0], &[3]);
        let y = checkpoint_multi(
            |ts: &[Tensor<f32>]| {
                let prod = mul(&ts[0], &ts[1])?;
                add(&prod, &ts[0])
            },
            &[a.clone(), b.clone()],
        )
        .unwrap();
        let s = sum(&y).unwrap();
        s.backward().unwrap();
        let ga = a.grad().unwrap().expect("a should have a gradient");
        let gad = ga.data().unwrap();
        assert!((gad[0] - 5.0).abs() < 1e-5);
        assert!((gad[1] - 6.0).abs() < 1e-5);
        assert!((gad[2] - 7.0).abs() < 1e-5);
        let gb = b.grad().unwrap().expect("b should have a gradient");
        let gbd = gb.data().unwrap();
        assert!((gbd[0] - 1.0).abs() < 1e-5);
        assert!((gbd[1] - 2.0).abs() < 1e-5);
        assert!((gbd[2] - 3.0).abs() < 1e-5);
    }

    // Only one of two inputs requires grad: the other must get no gradient.
    #[test]
    fn test_checkpoint_multi_partial_grad() {
        let a = from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let b = leaf_grad(&[4.0, 5.0, 6.0], &[3]);
        let y = checkpoint_multi(
            |ts: &[Tensor<f32>]| mul(&ts[0], &ts[1]),
            &[a.clone(), b.clone()],
        )
        .unwrap();
        let s = sum(&y).unwrap();
        s.backward().unwrap();
        assert!(a.grad().unwrap().is_none());
        let gb = b.grad().unwrap().expect("b should have a gradient");
        let gbd = gb.data().unwrap();
        assert_eq!(gbd, &[1.0, 2.0, 3.0]);
    }

    // Empty input slice must error before the closure ever runs.
    #[test]
    fn test_checkpoint_multi_empty_inputs_errors() {
        let result = checkpoint_multi(|_: &[Tensor<f32>]| panic!("should not run"), &[]);
        assert!(result.is_err());
    }

    // No grad-requiring inputs: correct values, no backward node.
    #[test]
    fn test_checkpoint_multi_no_grad_inputs_returns_output_only() {
        let a = from_slice(&[1.0f32, 2.0], &[2]).unwrap();
        let b = from_slice(&[3.0f32, 4.0], &[2]).unwrap();
        let y = checkpoint_multi(
            |ts: &[Tensor<f32>]| add(&ts[0], &ts[1]),
            &[a, b],
        )
        .unwrap();
        let yd = y.data().unwrap();
        assert_eq!(yd, &[4.0, 6.0]);
        assert!(y.grad_fn().is_none());
    }

    // Snapshot taken outside any autocast region reports disabled.
    #[test]
    fn test_current_autocast_snapshot_outside_region() {
        let snap = current_autocast_snapshot();
        assert!(!snap.enabled);
    }

    // Snapshot taken inside a region captures both enabled flag and dtype.
    #[test]
    fn test_current_autocast_snapshot_inside_region() {
        autocast(AutocastDtype::BF16, || {
            let snap = current_autocast_snapshot();
            assert!(snap.enabled);
            assert_eq!(snap.dtype, AutocastDtype::BF16);
        });
    }

    // with_autocast_state can disable autocast inside a region and must
    // restore the enclosing enabled state afterwards.
    #[test]
    fn test_with_autocast_state_restores_disabled() {
        let disabled = AutocastSnapshot {
            enabled: false,
            dtype: AutocastDtype::F16,
        };
        autocast(AutocastDtype::F16, || {
            assert!(is_autocast_enabled());
            with_autocast_state(disabled, || {
                assert!(!is_autocast_enabled());
            });
            assert!(is_autocast_enabled());
        });
    }

    // with_autocast_state can swap the dtype and must restore the enclosing
    // dtype afterwards.
    #[test]
    fn test_with_autocast_state_overrides_dtype() {
        let f16_state = AutocastSnapshot {
            enabled: true,
            dtype: AutocastDtype::F16,
        };
        autocast(AutocastDtype::BF16, || {
            with_autocast_state(f16_state, || {
                assert!(is_autocast_enabled());
                assert_eq!(crate::autograd::autocast::autocast_dtype(), AutocastDtype::F16);
            });
            assert_eq!(crate::autograd::autocast::autocast_dtype(), AutocastDtype::BF16);
        });
    }

    // Checkpoint created inside autocast records enabled=true, observable via
    // the backward node's Debug output.
    #[test]
    fn test_checkpoint_captures_autocast_snapshot() {
        let x = leaf_grad(&[1.0f32, 2.0, 3.0], &[3]);
        let y_inside = autocast(AutocastDtype::F16, || {
            checkpoint(|t: &Tensor<f32>| mul(t, t), &x)
        })
        .unwrap();
        let dbg = format!("{:?}", y_inside.grad_fn().unwrap());
        assert!(
            dbg.contains("autocast_enabled: true"),
            "expected captured autocast=true in debug repr, got {}",
            dbg
        );
    }

    // Checkpoint created outside autocast records enabled=false.
    #[test]
    fn test_checkpoint_outside_autocast_captures_disabled() {
        let x = leaf_grad(&[1.0f32, 2.0, 3.0], &[3]);
        let y = checkpoint(|t: &Tensor<f32>| mul(t, t), &x).unwrap();
        let dbg = format!("{:?}", y.grad_fn().unwrap());
        assert!(
            dbg.contains("autocast_enabled: false"),
            "expected captured autocast=false in debug repr, got {}",
            dbg
        );
    }

    // The closure observes autocast enabled during backward recomputation
    // even though backward itself runs outside the autocast region.
    #[test]
    fn test_checkpoint_recomputation_uses_saved_autocast() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::Arc as StdArc;
        let saw_autocast = StdArc::new(AtomicBool::new(false));
        let saw_autocast_clone = StdArc::clone(&saw_autocast);
        let x = leaf_grad(&[1.0f32, 2.0, 3.0], &[3]);
        let y = autocast(AutocastDtype::F16, || {
            checkpoint(
                move |t: &Tensor<f32>| {
                    saw_autocast_clone.store(is_autocast_enabled(), Ordering::SeqCst);
                    mul(t, t)
                },
                &x,
            )
        })
        .unwrap();
        // Reset the flag so only the recomputation during backward can set it.
        saw_autocast.store(false, Ordering::SeqCst);
        assert!(!is_autocast_enabled());
        let s = sum(&y).unwrap();
        s.backward().unwrap();
        assert!(
            saw_autocast.load(Ordering::SeqCst),
            "checkpoint backward should re-enable autocast during recomputation"
        );
        assert!(!is_autocast_enabled());
    }

    // Multi-input variant: recomputation must observe the saved BF16
    // snapshot (encoded as 0 = off, 1 = F16, 2 = BF16).
    #[test]
    fn test_checkpoint_multi_recomputation_uses_saved_autocast() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc as StdArc;
        let observed = StdArc::new(AtomicUsize::new(0));
        let observed_clone = StdArc::clone(&observed);
        let a = leaf_grad(&[1.0f32, 2.0], &[2]);
        let b = leaf_grad(&[3.0f32, 4.0], &[2]);
        let y = autocast(AutocastDtype::BF16, || {
            checkpoint_multi(
                move |ts: &[Tensor<f32>]| {
                    let dtype = crate::autograd::autocast::autocast_dtype();
                    let val = if is_autocast_enabled() {
                        match dtype {
                            AutocastDtype::F16 => 1,
                            AutocastDtype::BF16 => 2,
                        }
                    } else {
                        0
                    };
                    observed_clone.store(val, Ordering::SeqCst);
                    add(&ts[0], &ts[1])
                },
                &[a.clone(), b.clone()],
            )
        })
        .unwrap();
        // Reset so only the backward-time recomputation can set the code.
        observed.store(0, Ordering::SeqCst);
        let s = sum(&y).unwrap();
        s.backward().unwrap();
        assert_eq!(
            observed.load(Ordering::SeqCst),
            2,
            "expected recomputation to see autocast(BF16), got code {}",
            observed.load(Ordering::SeqCst)
        );
    }

    // Backward must not leave autocast enabled in the caller's thread state.
    #[test]
    fn test_checkpoint_recomputation_does_not_leak_autocast() {
        let x = leaf_grad(&[1.0f32, 2.0], &[2]);
        let y = autocast(AutocastDtype::F16, || {
            checkpoint(|t: &Tensor<f32>| mul(t, t), &x)
        })
        .unwrap();
        assert!(!is_autocast_enabled());
        let s = sum(&y).unwrap();
        s.backward().unwrap();
        assert!(
            !is_autocast_enabled(),
            "checkpoint backward should not leak autocast state to caller"
        );
    }

    // Backward running inside a *different* autocast region (BF16) must
    // still recompute under the saved F16 snapshot, then restore BF16.
    #[test]
    fn test_checkpoint_recomputation_inside_different_autocast() {
        use std::sync::atomic::{AtomicU8, Ordering};
        use std::sync::Arc as StdArc;
        let observed = StdArc::new(AtomicU8::new(0));
        let observed_clone = StdArc::clone(&observed);
        let x = leaf_grad(&[1.0f32, 2.0], &[2]);
        let y = autocast(AutocastDtype::F16, || {
            checkpoint(
                move |t: &Tensor<f32>| {
                    let code: u8 = if is_autocast_enabled() {
                        match crate::autograd::autocast::autocast_dtype() {
                            AutocastDtype::F16 => 1,
                            AutocastDtype::BF16 => 2,
                        }
                    } else {
                        0
                    };
                    observed_clone.store(code, Ordering::SeqCst);
                    mul(t, t)
                },
                &x,
            )
        })
        .unwrap();
        observed.store(0, Ordering::SeqCst);
        autocast(AutocastDtype::BF16, || {
            let s = sum(&y).unwrap();
            s.backward().unwrap();
            assert_eq!(
                crate::autograd::autocast::autocast_dtype(),
                AutocastDtype::BF16,
                "with_autocast_state should restore caller's BF16 state"
            );
        });
        assert_eq!(
            observed.load(Ordering::SeqCst),
            1,
            "expected recomputation to see F16 (saved snapshot), got code {}",
            observed.load(Ordering::SeqCst)
        );
    }
}