use std::sync::Arc;
use crate::dtype::Float;
use crate::error::FerrotorchResult;
use crate::gpu_dispatch::GpuRngState;
use crate::tensor::Tensor;
/// Shared, thread-safe single-input forward closure stored by [`CheckpointBackward`].
type CheckpointFn<T> = Arc<dyn Fn(&Tensor<T>) -> FerrotorchResult<Tensor<T>> + Send + Sync>;
/// Shared, thread-safe multi-input forward closure stored by [`CheckpointMultiBackward`].
type CheckpointMultiFn<T> =
    Arc<dyn Fn(&[Tensor<T>]) -> FerrotorchResult<Tensor<T>> + Send + Sync>;
/// Gradient checkpointing for a single-input function.
///
/// Runs `f` on `input` without recording the autograd graph, so the
/// intermediate activations inside `f` are not kept alive. If `input`
/// requires gradients, the returned tensor carries a [`CheckpointBackward`]
/// node that re-runs `f` during the backward pass to recover them,
/// trading compute for memory.
///
/// # Errors
/// Propagates any error from `f`, from copying the output data, or from
/// rebuilding the output tensor on its device.
pub fn checkpoint<T, F>(f: F, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>
where
    T: Float,
    F: Fn(&Tensor<T>) -> FerrotorchResult<Tensor<T>> + Send + Sync + 'static,
{
    use crate::autograd::no_grad::no_grad;
    use crate::storage::TensorStorage;

    // Snapshot GPU RNG state *before* the forward pass so the backward
    // recomputation can replay identical randomness (e.g. dropout).
    let rng_snapshot = save_gpu_rng_state(input);

    // Forward pass with autograd recording disabled.
    let forward_out = no_grad(|| f(input))?;

    // No gradient flows through this node: hand back the plain output.
    if !input.requires_grad() {
        return Ok(forward_out);
    }

    let shape = forward_out.shape().to_vec();
    let grad_fn = Arc::new(CheckpointBackward {
        func: Arc::new(f),
        input: input.clone(),
        output_shape: shape.clone(),
        saved_gpu_rng: rng_snapshot,
    });

    // Re-materialize the output storage on its original device and attach
    // the checkpoint backward node to it.
    let storage = TensorStorage::on_device(forward_out.data_vec()?, forward_out.device())?;
    Tensor::from_operation(storage, shape, grad_fn)
}
/// Gradient checkpointing for a function of several input tensors.
///
/// Like [`checkpoint`], but `f` receives a slice of inputs. The forward
/// pass runs without autograd recording; if any input requires gradients,
/// the output carries a [`CheckpointMultiBackward`] node that re-runs `f`
/// during backward.
///
/// # Errors
/// Returns `InvalidArgument` for an empty input slice, and otherwise
/// propagates errors from `f`, the data copy, or tensor construction.
pub fn checkpoint_multi<T, F>(f: F, inputs: &[Tensor<T>]) -> FerrotorchResult<Tensor<T>>
where
    T: Float,
    F: Fn(&[Tensor<T>]) -> FerrotorchResult<Tensor<T>> + Send + Sync + 'static,
{
    use crate::autograd::no_grad::no_grad;
    use crate::storage::TensorStorage;

    // Nothing to checkpoint without at least one input.
    if inputs.is_empty() {
        return Err(crate::error::FerrotorchError::InvalidArgument {
            message: "checkpoint_multi: at least one input required".into(),
        });
    }

    // Snapshot GPU RNG state before the forward pass so backward can replay it.
    let rng_snapshot = save_gpu_rng_state(&inputs[0]);

    // Forward pass with autograd recording disabled.
    let forward_out = no_grad(|| f(inputs))?;

    // If no input participates in autograd, skip building the backward node.
    if !inputs.iter().any(|t| t.requires_grad()) {
        return Ok(forward_out);
    }

    let shape = forward_out.shape().to_vec();
    let grad_fn = Arc::new(CheckpointMultiBackward {
        func: Arc::new(f),
        inputs: inputs.to_vec(),
        output_shape: shape.clone(),
        saved_gpu_rng: rng_snapshot,
    });

    // Re-materialize the output storage on its device with the backward node attached.
    let storage = TensorStorage::on_device(forward_out.data_vec()?, forward_out.device())?;
    Tensor::from_operation(storage, shape, grad_fn)
}
/// Capture the GPU RNG state for the device `tensor` lives on.
///
/// Returns `None` for CPU tensors, when no GPU backend is registered, or
/// when the backend fails to capture its state (failures are swallowed via
/// `.ok()` — checkpointing degrades gracefully rather than erroring).
fn save_gpu_rng_state<T: Float>(tensor: &Tensor<T>) -> Option<GpuRngState> {
    match tensor.device() {
        crate::device::Device::Cuda(ordinal) => {
            crate::gpu_dispatch::gpu_backend()?.save_rng_state(ordinal).ok()
        }
        crate::device::Device::Cpu => None,
    }
}
/// Drop guard that reinstates the GPU RNG state that was current when
/// backward recomputation began, undoing the temporary restore of the
/// pre-forward state (see the `backward` implementations below).
struct GpuRngGuard {
    // RNG state to put back when the guard goes out of scope.
    previous: GpuRngState,
}

impl Drop for GpuRngGuard {
    fn drop(&mut self) {
        // Best-effort: `Drop` cannot propagate errors and the backend may be
        // unavailable by now, so any restore failure is silently ignored.
        if let Some(backend) = crate::gpu_dispatch::gpu_backend() {
            let _ = backend.restore_rng_state(self.previous);
        }
    }
}
/// Backward node for single-input checkpointing: instead of saving forward
/// activations, it stores everything needed to *recompute* the forward pass
/// during backward.
struct CheckpointBackward<T: Float> {
    // The checkpointed forward function, re-invoked in `backward`.
    func: CheckpointFn<T>,
    // Saved input tensor used to replay the forward pass.
    input: Tensor<T>,
    // Shape of the forward output; in this node it is only read by the Debug impl.
    output_shape: Vec<usize>,
    // GPU RNG state captured before the original forward (`None` on CPU or
    // when no GPU backend was available).
    saved_gpu_rng: Option<GpuRngState>,
}
impl<T: Float> std::fmt::Debug for CheckpointBackward<T> {
    /// Debug output summarizes shapes and RNG presence; the closure and the
    /// tensor contents are deliberately omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("CheckpointBackward");
        dbg.field("input_shape", &self.input.shape());
        dbg.field("output_shape", &self.output_shape);
        dbg.field("has_gpu_rng_state", &self.saved_gpu_rng.is_some());
        dbg.finish()
    }
}
impl<T: Float> crate::tensor::GradFn<T> for CheckpointBackward<T> {
    /// Re-runs the forward pass with gradient tracking enabled and
    /// backpropagates through the recomputed graph to produce the input
    /// gradient.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        // If the forward ran on a GPU, temporarily restore the pre-forward
        // RNG state so randomness (e.g. dropout) replays identically during
        // recomputation. The guard puts the *current* state back when it is
        // dropped at the end of this call.
        let _rng_guard = if let Some(saved_state) = self.saved_gpu_rng {
            let current_state = save_gpu_rng_state(&self.input);
            if let Some(backend) = crate::gpu_dispatch::gpu_backend() {
                let _ = backend.restore_rng_state(saved_state);
            }
            current_state.map(|prev| GpuRngGuard { previous: prev })
        } else {
            None
        };
        // Recompute the forward pass with grad tracking on the input.
        // NOTE(review): the input is not detached before recomputation; if
        // `clone()` shares the autograd graph, `scalar.backward()` below may
        // also propagate into the input's ancestors (double accumulation) —
        // confirm against Tensor::clone semantics.
        let input_with_grad = self.input.clone().requires_grad_(true);
        let recomputed = (self.func)(&input_with_grad)?;
        use crate::grad_fns::arithmetic::mul;
        use crate::grad_fns::reduction::sum;
        // Vector-Jacobian product via a scalar: d/d_input sum(r ⊙ g) = Jᵀ g,
        // which is exactly the gradient this node must return.
        // NOTE(review): `.clone().requires_grad_(false).detach()` looks
        // redundant — `detach()` alone presumably suffices; verify.
        let weighted = mul(
            &recomputed,
            &grad_output.clone().requires_grad_(false).detach(),
        )?;
        let scalar = sum(&weighted)?;
        scalar.backward()?;
        let input_grad = input_with_grad.grad()?;
        Ok(vec![input_grad])
    }
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }
    fn name(&self) -> &'static str {
        "CheckpointBackward"
    }
}
/// Backward node for multi-input checkpointing; the multi-tensor analogue of
/// [`CheckpointBackward`].
struct CheckpointMultiBackward<T: Float> {
    // The checkpointed forward function, re-invoked in `backward`.
    func: CheckpointMultiFn<T>,
    // Saved input tensors used to replay the forward pass.
    inputs: Vec<Tensor<T>>,
    // Shape of the forward output; in this node it is only read by the Debug impl.
    output_shape: Vec<usize>,
    // GPU RNG state captured before the original forward, keyed off the
    // first input's device (`None` on CPU).
    saved_gpu_rng: Option<GpuRngState>,
}
impl<T: Float> std::fmt::Debug for CheckpointMultiBackward<T> {
    /// Debug output reports counts/shapes only; closures and tensor data are
    /// deliberately omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("CheckpointMultiBackward");
        dbg.field("num_inputs", &self.inputs.len());
        dbg.field("output_shape", &self.output_shape);
        dbg.field("has_gpu_rng_state", &self.saved_gpu_rng.is_some());
        dbg.finish()
    }
}
impl<T: Float> crate::tensor::GradFn<T> for CheckpointMultiBackward<T> {
    /// Re-runs the multi-input forward pass with gradient tracking and
    /// backpropagates through the recomputed graph, returning one gradient
    /// slot per saved input (None for inputs that don't require grad).
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        // Temporarily restore the pre-forward GPU RNG state so randomness
        // replays identically during recomputation; the guard reinstates the
        // current state on drop.
        // NOTE(review): RNG state is saved/restored only for inputs[0]'s
        // device — inputs spread across multiple GPUs would not have their
        // other devices' RNG replayed. Confirm whether that is a supported case.
        let _rng_guard = if let Some(saved_state) = self.saved_gpu_rng {
            let current_state = save_gpu_rng_state(&self.inputs[0]);
            if let Some(backend) = crate::gpu_dispatch::gpu_backend() {
                let _ = backend.restore_rng_state(saved_state);
            }
            current_state.map(|prev| GpuRngGuard { previous: prev })
        } else {
            None
        };
        // Clone every input, enabling grad tracking only on those that
        // require gradients, so `grad()` can be queried after backward.
        let inputs_with_grad: Vec<Tensor<T>> = self
            .inputs
            .iter()
            .map(|t| {
                if t.requires_grad() {
                    t.clone().requires_grad_(true)
                } else {
                    t.clone()
                }
            })
            .collect();
        let recomputed = (self.func)(&inputs_with_grad)?;
        use crate::grad_fns::arithmetic::mul;
        use crate::grad_fns::reduction::sum;
        // Vector-Jacobian product via a scalar: d/d_inputs sum(r ⊙ g) = Jᵀ g.
        let weighted = mul(
            &recomputed,
            &grad_output.clone().requires_grad_(false).detach(),
        )?;
        let scalar = sum(&weighted)?;
        scalar.backward()?;
        // Collect gradients in the original input order; inputs that never
        // required grad report None.
        let mut grads = Vec::with_capacity(self.inputs.len());
        for (orig, with_grad) in self.inputs.iter().zip(inputs_with_grad.iter()) {
            if orig.requires_grad() {
                grads.push(with_grad.grad()?);
            } else {
                grads.push(None);
            }
        }
        Ok(grads)
    }
    fn inputs(&self) -> Vec<&Tensor<T>> {
        self.inputs.iter().collect()
    }
    fn name(&self) -> &'static str {
        "CheckpointMultiBackward"
    }
}