use crate::graph::AsGraph;
use crate::op::{ComputeContext, GradientContext, Op, OpError};
use crate::tensor::Tensor;
use crate::Float;
use once_cell::sync::Lazy;
use std::collections::HashSet;
use std::marker::PhantomData;
use std::sync::Mutex;
/// Process-wide checkpoint registry.
///
/// The registry is built on first access and every access goes through the
/// surrounding mutex.
static CHECKPOINT_REGISTRY: Lazy<Mutex<CheckpointRegistry>> = Lazy::new(|| {
    let registry = CheckpointRegistry::new();
    Mutex::new(registry)
});
/// Bookkeeping for checkpoint operations: the set of registered ids plus a
/// running total of the size estimates recorded while tracking is enabled.
#[derive(Debug, Default)]
struct CheckpointRegistry {
    /// Ids handed to `register_checkpoint`; registering an id twice keeps one entry.
    checkpoint_ops: HashSet<usize>,
    /// Sum of the size estimates registered while `tracking_enabled` was true.
    estimated_memory_saved: usize,
    /// When false, `register_checkpoint` records ids but ignores size estimates.
    tracking_enabled: bool,
}

impl CheckpointRegistry {
    /// Creates an empty registry with tracking disabled.
    fn new() -> Self {
        Self::default()
    }

    /// Records `tensor_id` and, if tracking is enabled, adds `estimatedsize`
    /// to the running total.
    ///
    /// The total saturates at `usize::MAX` instead of overflowing, so a huge
    /// estimate can neither panic (debug builds) nor wrap around to a small
    /// number (release builds).
    fn register_checkpoint(&mut self, tensor_id: usize, estimatedsize: usize) {
        self.checkpoint_ops.insert(tensor_id);
        if self.tracking_enabled {
            self.estimated_memory_saved = self.estimated_memory_saved.saturating_add(estimatedsize);
        }
    }

    /// Turns tracking on and restarts the running total from zero.
    ///
    /// The set of registered ids is left untouched.
    fn enable_tracking(&mut self) {
        self.tracking_enabled = true;
        self.estimated_memory_saved = 0;
    }

    /// Turns tracking off; the total accumulated so far is kept.
    fn disable_tracking(&mut self) {
        self.tracking_enabled = false;
    }

    /// Clears both the running total and the set of registered ids.
    ///
    /// The tracking flag is not changed.
    fn reset_statistics(&mut self) {
        self.estimated_memory_saved = 0;
        self.checkpoint_ops.clear();
    }

    /// Returns the running total of recorded size estimates.
    fn get_memory_saved(&self) -> usize {
        self.estimated_memory_saved
    }
}
/// Pass-through operation used to mark a tensor as a gradient checkpoint.
///
/// The forward pass copies its single input to the output and reports the
/// copy's size to the global checkpoint registry.
pub struct CheckpointOp;

impl<F: Float> Op<F> for CheckpointOp {
    /// Name under which this op is reported.
    fn name(&self) -> &'static str {
        "GradientCheckpoint"
    }

    /// Forwards input 0 unchanged and registers its estimated byte size.
    ///
    /// The size estimate is `element count * size_of::<F>()`.
    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
        let forwarded = ctx.input(0).to_owned();
        let byte_estimate = forwarded.len() * std::mem::size_of::<F>();
        ctx.append_output(forwarded);

        // NOTE(review): every call registers under id 0, so the registry's id
        // set never grows past one entry through this path — confirm whether a
        // real tensor id should be used here.
        let mut registry = CHECKPOINT_REGISTRY.lock().expect("Operation failed");
        registry.register_checkpoint(0, byte_estimate);
        Ok(())
    }

    /// Chooses the gradient for input 0.
    ///
    /// * Input cannot be evaluated: the output gradient is passed through.
    /// * Input evaluates but the output gradient does not: a scalar one is used.
    /// * Both evaluate with equal shapes: the output gradient is passed through.
    /// * Both evaluate with different shapes: a tensor of ones with the input's
    ///   shape is used.
    ///
    /// NOTE(review): both tensors are evaluated eagerly while the gradient
    /// graph is being built — confirm this is intended for a checkpoint op.
    fn grad(&self, ctx: &mut GradientContext<F>) {
        let upstream = ctx.output_grad();
        let graph = ctx.graph();
        let source = ctx.input(0);

        let input_grad = match source.eval(graph) {
            Err(_) => *upstream,
            Ok(source_values) => {
                let target_shape = source_values.shape();
                match upstream.eval(graph) {
                    Err(_) => crate::tensor_ops::scalar(F::one(), graph),
                    Ok(upstream_values) if upstream_values.shape() == target_shape => *upstream,
                    Ok(_) => {
                        // Encode the input's shape as a rank-1 tensor of floats
                        // so that `ones` can build a matching tensor from it.
                        let shape_tensor = crate::tensor_ops::convert_to_tensor(
                            scirs2_core::ndarray::Array::from_shape_vec(
                                scirs2_core::ndarray::IxDyn(&[target_shape.len()]),
                                target_shape
                                    .iter()
                                    .map(|&x| F::from(x).expect("Failed to convert to float"))
                                    .collect::<Vec<_>>(),
                            )
                            .expect("Operation failed"),
                            graph,
                        );
                        crate::tensor_ops::ones(&shape_tensor, graph)
                    }
                }
            }
        };

        ctx.append_input_grad(0, Some(input_grad));
    }
}
/// Wraps `tensor` in a [`CheckpointOp`] node built on the tensor's own graph.
///
/// The returned tensor has `tensor` as its only input.
// NOTE(review): the `false` passed to `append_input` presumably means the
// input is not mutated by the op — confirm against the builder API.
#[allow(dead_code)]
pub fn checkpoint<'g, F: Float>(tensor: &Tensor<'g, F>) -> Tensor<'g, F> {
    let builder = Tensor::builder(tensor.graph()).append_input(tensor, false);
    builder.build(CheckpointOp)
}
/// Builds a [`CheckpointOp`] node over `tensor` that is marked non-differentiable.
///
/// Apart from the `set_differentiable(false)` call, the node is constructed
/// exactly like the one returned by `checkpoint`.
#[allow(dead_code)]
pub fn detach<'g, F: Float>(tensor: &Tensor<'g, F>) -> Tensor<'g, F> {
    let builder = Tensor::builder(tensor.graph())
        .append_input(tensor, false)
        .set_differentiable(false);
    builder.build(CheckpointOp)
}
/// Runs `segment_fn` on detached versions of `input_tensors` and checkpoints
/// the result.
///
/// Every input is first passed through [`detach`]; the closure receives
/// references to the detached tensors in the same order as `input_tensors`,
/// and whatever it returns is wrapped with [`checkpoint`].
///
/// # Arguments
/// * `_ctx` - currently unused.
/// * `input_tensors` - fixed-size array of inputs to the segment.
/// * `segment_fn` - computation to run on the detached inputs.
///
/// # Returns
/// The checkpointed output of `segment_fn`.
#[allow(dead_code)]
pub fn checkpoint_segment<'g, F: Float, Func, const N: usize>(
    _ctx: &'g crate::graph::Context<'g, F>,
    input_tensors: [&Tensor<'g, F>; N],
    segment_fn: Func,
) -> Tensor<'g, F>
where
    Func: FnOnce([&Tensor<'g, F>; N]) -> Tensor<'g, F>,
{
    // Mapping the array directly keeps the length `N` in the type, so there is
    // no `Vec` -> array conversion that could fail and no heap allocation.
    let detached: [Tensor<'g, F>; N] = input_tensors.map(|tensor| detach(tensor));
    let detached_refs: [&Tensor<'g, F>; N] = std::array::from_fn(|i| &detached[i]);
    checkpoint(&segment_fn(detached_refs))
}
/// Slice-based variant of the segment helper: detaches every tensor in
/// `input_tensors`, calls `segment_fn` with references to the detached
/// tensors (same order), and wraps the closure's result with `checkpoint`.
///
/// `_ctx` is currently unused.
#[allow(dead_code)]
pub fn checkpoint_segment_flex<'g, F: Float, Func>(
    _ctx: &'g crate::graph::Context<'g, F>,
    input_tensors: &[&Tensor<'g, F>],
    segment_fn: Func,
) -> Tensor<'g, F>
where
    Func: FnOnce(&[&Tensor<'g, F>]) -> Tensor<'g, F>,
{
    let mut detached: Vec<Tensor<'g, F>> = Vec::with_capacity(input_tensors.len());
    for tensor in input_tensors {
        detached.push(detach(tensor));
    }

    // The closure takes a slice of references, so borrow each detached tensor.
    let borrowed: Vec<&Tensor<'g, F>> = detached.iter().collect();
    let segment_output = segment_fn(borrowed.as_slice());
    checkpoint(&segment_output)
}
/// Helper tied to a graph context that runs a closure and checkpoints
/// whatever the closure returns.
///
/// The stored context is not read by any method in this block.
pub struct CheckpointGroup<'g, F: Float> {
    _ctx: &'g crate::graph::Context<'g, F>,
    _phantom: PhantomData<F>,
}

impl<'g, F: Float> CheckpointGroup<'g, F> {
    /// Creates a group bound to `ctx`.
    pub fn new(ctx: &'g crate::graph::Context<'g, F>) -> Self {
        let _phantom = PhantomData;
        Self { _ctx: ctx, _phantom }
    }

    /// Calls `segmentfn(inputs)` and checkpoints the result through its
    /// [`CheckpointOutput`] implementation.
    ///
    /// The inputs are handed to the closure as-is; the `Clone` bound on
    /// `Inputs` is not used by this method's body.
    pub fn checkpoint_fn<Inputs, Outputs, Func>(&self, inputs: Inputs, segmentfn: Func) -> Outputs
    where
        Inputs: Clone,
        Func: FnOnce(Inputs) -> Outputs,
        Outputs: CheckpointOutput<'g, F>,
    {
        Outputs::checkpoint(segmentfn(inputs))
    }

    /// Calls `segment_fn` on the input slice and checkpoints its single output.
    pub fn checkpoint_fn_flex<Func>(
        &self,
        inputs: &[&Tensor<'g, F>],
        segment_fn: Func,
    ) -> Tensor<'g, F>
    where
        Func: FnOnce(&[&Tensor<'g, F>]) -> Tensor<'g, F>,
    {
        let result = segment_fn(inputs);
        checkpoint(&result)
    }

    /// Calls `segment_fn` on the input slice and checkpoints both outputs,
    /// first then second.
    pub fn checkpoint_fn_flex2<Func>(
        &self,
        inputs: &[&Tensor<'g, F>],
        segment_fn: Func,
    ) -> (Tensor<'g, F>, Tensor<'g, F>)
    where
        Func: FnOnce(&[&Tensor<'g, F>]) -> (Tensor<'g, F>, Tensor<'g, F>),
    {
        let (first, second) = segment_fn(inputs);
        let first = checkpoint(&first);
        let second = checkpoint(&second);
        (first, second)
    }

    /// Calls `segment_fn` on the input slice and checkpoints all three
    /// outputs in order.
    pub fn checkpoint_fn_flex3<Func>(
        &self,
        inputs: &[&Tensor<'g, F>],
        segment_fn: Func,
    ) -> (Tensor<'g, F>, Tensor<'g, F>, Tensor<'g, F>)
    where
        Func: FnOnce(&[&Tensor<'g, F>]) -> (Tensor<'g, F>, Tensor<'g, F>, Tensor<'g, F>),
    {
        let (first, second, third) = segment_fn(inputs);
        let first = checkpoint(&first);
        let second = checkpoint(&second);
        let third = checkpoint(&third);
        (first, second, third)
    }
}
/// A value whose tensors can all be wrapped in checkpoints.
///
/// This file implements it for a single `Tensor`, for 2- and 3-tuples of
/// tensors, and for `Vec<Tensor>`.
pub trait CheckpointOutput<'g, F: Float> {
/// Consumes the value and returns one of the same type after checkpointing.
fn checkpoint(self) -> Self;
}
impl<'g, F: Float> CheckpointOutput<'g, F> for Tensor<'g, F> {
    /// Wraps this tensor using the free function `checkpoint`.
    fn checkpoint(self) -> Self {
        let source = self;
        checkpoint(&source)
    }
}
impl<'g, F: Float> CheckpointOutput<'g, F> for (Tensor<'g, F>, Tensor<'g, F>) {
    /// Checkpoints both tensors, first element then second.
    fn checkpoint(self) -> Self {
        let (first, second) = self;
        let first = checkpoint(&first);
        let second = checkpoint(&second);
        (first, second)
    }
}
impl<'g, F: Float> CheckpointOutput<'g, F> for (Tensor<'g, F>, Tensor<'g, F>, Tensor<'g, F>) {
    /// Checkpoints all three tensors in element order.
    fn checkpoint(self) -> Self {
        let (first, second, third) = self;
        let first = checkpoint(&first);
        let second = checkpoint(&second);
        let third = checkpoint(&third);
        (first, second, third)
    }
}
impl<'g, F: Float> CheckpointOutput<'g, F> for Vec<Tensor<'g, F>> {
    /// Returns a new vector holding the checkpointed form of every tensor,
    /// in the original order.
    fn checkpoint(self) -> Self {
        let mut wrapped = Vec::with_capacity(self.len());
        for tensor in &self {
            wrapped.push(checkpoint(tensor));
        }
        wrapped
    }
}
/// Checkpoints `tensor` only when its estimated size exceeds a threshold.
///
/// The element count is the product of the tensor's positive dimensions,
/// obtained by evaluating `shape(tensor)`. Dimensions that are zero or cannot
/// be converted to `usize` are skipped, and if the shape cannot be evaluated
/// the count stays at 1. When the graph exposes no context (`context_ref()`
/// returns `None`), the tensor is assumed to hold 10 000 elements and the
/// threshold is halved.
///
/// Both the element count and the byte estimate saturate at `usize::MAX`, so
/// a very large shape cannot overflow into a small estimate that would skip
/// the checkpoint.
///
/// # Arguments
/// * `tensor` - tensor to consider for checkpointing.
/// * `memory_threshold_bytes` - estimated size, in bytes, above which the
///   tensor is checkpointed.
///
/// # Returns
/// `checkpoint(tensor)` when the estimate is strictly greater than the
/// (possibly halved) threshold, otherwise `*tensor`.
#[allow(dead_code)]
pub fn adaptive_checkpoint<'g, F: Float>(
    tensor: &Tensor<'g, F>,
    memory_threshold_bytes: usize,
) -> Tensor<'g, F> {
    // Element count assumed when the shape cannot be looked up at all.
    const FALLBACK_ELEMENT_COUNT: usize = 10000;

    let ctx = tensor.graph();
    // Built before the context lookup, as in the original statement order.
    let shape_tensor = crate::tensor_ops::shape(tensor);

    let (element_count, threshold) = match ctx.context_ref() {
        Some(ctx_ref) => {
            let count = match shape_tensor.eval(ctx_ref) {
                Ok(shape_array) => shape_array
                    .iter()
                    .filter_map(|&dim| dim.to_usize())
                    .filter(|&size| size > 0)
                    .fold(1_usize, usize::saturating_mul),
                Err(_) => 1,
            };
            (count, memory_threshold_bytes)
        }
        None => (FALLBACK_ELEMENT_COUNT, memory_threshold_bytes / 2),
    };

    // Saturate here too: the count may already be `usize::MAX`.
    let estimated_memory = element_count.saturating_mul(std::mem::size_of::<F>());
    if estimated_memory > threshold {
        checkpoint(tensor)
    } else {
        *tensor
    }
}
pub struct CheckpointProfiler;
impl CheckpointProfiler {
pub fn start_tracking() {
CHECKPOINT_REGISTRY
.lock()
.expect("Operation failed")
.enable_tracking();
}
pub fn stop_tracking() {
CHECKPOINT_REGISTRY
.lock()
.expect("Operation failed")
.disable_tracking();
}
pub fn reset_statistics() {
CHECKPOINT_REGISTRY
.lock()
.expect("Operation failed")
.reset_statistics();
}
pub fn memory_saved() -> usize {
CHECKPOINT_REGISTRY
.lock()
.expect("Operation failed")
.get_memory_saved()
}
pub fn checkpoint_count() -> usize {
CHECKPOINT_REGISTRY
.lock()
.expect("Operation failed")
.checkpoint_ops
.len()
}
}