use std::fmt::Debug;
use burn::{
backend::{
autodiff::{
checkpoint::{base::Checkpointer, strategy::CheckpointStrategy},
grads::Gradients,
ops::{Backward, Ops, OpsKind},
NodeId,
},
Autodiff,
},
tensor::ops::FloatTensor,
};
use crate::backend::Backend;
/// Computes the Euclidean pairwise distance of `x` on an autodiff backend and,
/// when gradients are tracked for `x`, registers a custom backward step.
///
/// The forward value is produced by the inner backend's
/// `B::euclidean_pairwise_distance`. The gradient with respect to `x` is
/// produced by `B::euclidean_pairwise_distance_backward`, which receives the
/// incoming output gradient, the checkpointed input and the forward output.
///
/// NOTE(review): despite its name, this function performs the *forward* pass;
/// the backward logic lives in the nested `EuclideanPairwiseDistanceBackward`.
/// Expected input/output shapes are defined by `crate::backend::Backend` —
/// presumably `x` is `(n, d)` and the result `(n, n)`; confirm there.
pub fn backward<B: Backend, C: CheckpointStrategy>(
    x: FloatTensor<Autodiff<B, C>>,
) -> FloatTensor<Autodiff<B, C>> {
    // Backward step of the pairwise-distance op; it has a single parent (`x`).
    #[derive(Debug)]
    struct EuclideanPairwiseDistanceBackward;

    impl<B: Backend> Backward<B, 1> for EuclideanPairwiseDistanceBackward {
        // (checkpointed node id of `x`, forward output) — both are inputs to
        // `B::euclidean_pairwise_distance_backward`.
        type State = (NodeId, FloatTensor<B>);

        fn backward(
            self,
            ops: Ops<Self::State, 1>,
            grads: &mut Gradients,
            checkpointer: &mut Checkpointer,
        ) {
            let (node_x, pairwise) = ops.state;
            // Gradient flowing into this op's output node.
            let grad_pairwise = grads.consume::<B>(&ops.node);
            // Input tensor saved by `prep.checkpoint(&x)` in the forward pass.
            let x: FloatTensor<B> = checkpointer.retrieve_node_output(node_x);
            let grad_x = B::euclidean_pairwise_distance_backward(grad_pairwise, x, pairwise);
            grads.register::<B>(node_x, grad_x);
        }
    }

    match EuclideanPairwiseDistanceBackward
        .prepare::<C>([x.node.clone()])
        .compute_bound()
        .stateful()
    {
        OpsKind::Tracked(mut prep) => {
            // Save the input so the backward step can retrieve it later.
            let x_state = prep.checkpoint(&x);
            // `x` is not needed after checkpointing, so move its primitive
            // instead of cloning the whole autodiff tensor first.
            let pairwise = B::euclidean_pairwise_distance(x.primitive);
            // The forward output is used both as backward state and as the
            // returned value, hence the one clone here.
            let state = (x_state, pairwise.clone());
            prep.finish(state, pairwise)
        }
        OpsKind::UnTracked(prep) => {
            // No gradient needed: just run the forward computation.
            let output = B::euclidean_pairwise_distance(x.primitive);
            prep.finish(output)
        }
    }
}