use std::sync::Arc;
use crate::dtype::Float;
use crate::error::FerrotorchResult;
use crate::gpu_dispatch::GpuRngState;
use crate::tensor::Tensor;
type CheckpointFn<T> = Arc<dyn Fn(&Tensor<T>) -> FerrotorchResult<Tensor<T>> + Send + Sync>;
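/// Gradient checkpointing: runs `f` once under `no_grad`, so no intermediate
/// activations are kept alive, and recomputes the subgraph during the
/// backward pass. Trades extra compute for a smaller peak memory footprint.
///
/// The GPU RNG state is snapshotted before the first forward pass so that
/// stochastic ops (e.g. dropout) replay identically on recomputation.
///
/// A minimal usage sketch; `ones` and `relu` are hypothetical stand-ins for
/// whatever constructors and ops this crate actually exposes:
///
/// ```ignore
/// let x = Tensor::<f32>::ones(&[128, 128]).requires_grad_(true);
/// let y = checkpoint(|t| t.relu(), &x)?;
/// let loss = crate::grad_fns::reduction::sum(&y)?;
/// loss.backward()?; // `f` is rerun here to rebuild the local graph
/// ```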
pub fn checkpoint<T, F>(f: F, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>
where
T: Float,
F: Fn(&Tensor<T>) -> FerrotorchResult<Tensor<T>> + Send + Sync + 'static,
{
use crate::autograd::no_grad::no_grad;
use crate::storage::TensorStorage;
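    // Capture the GPU RNG state *before* the forward pass so that stochastic
    // ops (e.g. dropout) can be replayed bit-for-bit during recomputation.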
let saved_gpu_rng = save_gpu_rng_state(input);
let output = no_grad(|| f(input))?;
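    // No gradient flows through this subgraph, so there is nothing to recompute.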
if !input.requires_grad() {
return Ok(output);
}
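    // Keep only what backward needs: the function, the input tensor, the
    // output shape, and the RNG snapshot. Intermediate activations are dropped.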
let checkpoint_fn = Arc::new(CheckpointBackward {
func: Arc::new(f),
input: input.clone(),
output_shape: output.shape().to_vec(),
saved_gpu_rng,
});
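    // Rebuild the output on its original device with the checkpoint node
    // attached as its grad fn.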
let device = output.device();
let storage = TensorStorage::on_device(output.data_vec()?, device)?;
Tensor::from_operation(storage, output.shape().to_vec(), checkpoint_fn)
}
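/// Captures the RNG state of the GPU backend owning `tensor`, if any.
/// Returns `None` for CPU tensors, when no GPU backend is registered, or when
/// the backend fails to snapshot its state.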
fn save_gpu_rng_state<T: Float>(tensor: &Tensor<T>) -> Option<GpuRngState> {
let device_ordinal = match tensor.device() {
crate::device::Device::Cuda(id) => id,
crate::device::Device::Cpu => return None,
};
let backend = crate::gpu_dispatch::gpu_backend()?;
backend.save_rng_state(device_ordinal).ok()
}
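/// RAII guard that restores a previously captured GPU RNG state on drop, so
/// the backward pass leaves the generator exactly where it found it.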
struct GpuRngGuard {
previous: GpuRngState,
}
impl Drop for GpuRngGuard {
fn drop(&mut self) {
if let Some(backend) = crate::gpu_dispatch::gpu_backend() {
            let _ = backend.restore_rng_state(self.previous.clone());
}
}
}
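/// Autograd node that reruns the checkpointed function during the backward
/// pass and differentiates the recomputed subgraph.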
struct CheckpointBackward<T: Float> {
func: CheckpointFn<T>,
input: Tensor<T>,
output_shape: Vec<usize>,
saved_gpu_rng: Option<GpuRngState>,
}
impl<T: Float> std::fmt::Debug for CheckpointBackward<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CheckpointBackward")
.field("input_shape", &self.input.shape())
.field("output_shape", &self.output_shape)
.field("has_gpu_rng_state", &self.saved_gpu_rng.is_some())
.finish()
}
}
impl<T: Float> crate::tensor::GradFn<T> for CheckpointBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
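        // Swap RNG states: remember where the generator is now, rewind it to
        // the forward-time snapshot, and let the guard restore the current
        // state once recomputation finishes.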
        let _rng_guard = if let Some(saved_state) = self.saved_gpu_rng.clone() {
let current_state = save_gpu_rng_state(&self.input);
if let Some(backend) = crate::gpu_dispatch::gpu_backend() {
let _ = backend.restore_rng_state(saved_state);
}
current_state.map(|prev| GpuRngGuard { previous: prev })
} else {
None
};
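        // Recompute the forward pass on a detached leaf so the recomputation
        // graph is rooted here and gradients do not leak into the outer graph.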
        let input_with_grad = self.input.clone().detach().requires_grad_(true);
let recomputed = (self.func)(&input_with_grad)?;
use crate::grad_fns::arithmetic::mul;
use crate::grad_fns::reduction::sum;
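        // Vector-Jacobian product: backpropagating through
        // sum(recomputed * grad_output), with grad_output detached, produces
        // J^T * grad_output with respect to the input.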
        let weighted = mul(&recomputed, &grad_output.clone().detach())?;
let scalar = sum(&weighted)?;
scalar.backward()?;
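        // The detached leaf now holds d(loss)/d(input) for the original input.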
let input_grad = input_with_grad.grad()?;
Ok(vec![input_grad])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"CheckpointBackward"
}
}