use std::fmt::Debug;
use burn::{
backend::{
autodiff::{
checkpoint::{base::Checkpointer, strategy::CheckpointStrategy},
grads::Gradients,
ops::{Backward, Ops, OpsKind},
NodeId,
},
Autodiff,
},
tensor::ops::{FloatTensor, IntTensor},
};
use crate::{backend::*, print_if, print_primitive_tensor};
const VERBOSE: bool = false;
/// Autodiff-aware k-nearest-neighbour selection over a pairwise-distance tensor.
///
/// Runs `B::knn` exactly once on the inner backend and returns its `(indices, distances)`
/// pair. Only `distances` participates in the autodiff graph. When the input is tracked, a
/// single `KnnBackward` step is registered. Its backward pass calls `B::knn_backward` to map
/// the incoming gradient back onto `pairwise_distances`. `indices` is returned as the integer
/// tensor produced by `B::knn`. Autodiff does not track integer tensors, so no gradient flows
/// through it, and it is never round-tripped through a float tensor. Such a round trip would
/// make large indices inexact.
///
/// NOTE(review): the expected layout of `pairwise_distances` (presumably a square `(n, n)`
/// matrix) and of the outputs (presumably `(n, k)`) is defined by `B::knn` — confirm there.
///
/// # Arguments
/// * `pairwise_distances` - distance tensor to select neighbours from.
/// * `k` - number of neighbours; forwarded unchanged to `B::knn` and `B::knn_backward`.
///
/// # Returns
/// `(indices, distances)` as produced by `B::knn`, with `distances` wrapped for autodiff.
pub fn backward<B: Backend, C: CheckpointStrategy>(
    pairwise_distances: FloatTensor<Autodiff<B, C>>,
    k: u32,
) -> (IntTensor<Autodiff<B, C>>, FloatTensor<Autodiff<B, C>>) {
    /// Backward step attached to the `distances` output of the kNN forward pass.
    #[derive(Debug)]
    struct KnnBackward;

    impl<B: Backend> Backward<B, 1> for KnnBackward {
        /// Checkpointed id of the input node, plus the `k` used in the forward pass.
        type State = (NodeId, u32);

        fn backward(
            self,
            ops: Ops<Self::State, 1>,
            grads: &mut Gradients,
            checkpointer: &mut Checkpointer,
        ) {
            let (node_pairwise_distances, k) = ops.state;
            // Gradient arriving at the `distances` output of the forward pass.
            let grad_output = grads.consume::<B>(&ops.node);
            // Recover the forward input that was checkpointed in the tracked branch below.
            let pairwise_distances: FloatTensor<B> =
                checkpointer.retrieve_node_output(node_pairwise_distances);
            if VERBOSE {
                println!("grad_output {grad_output:?}");
                print_primitive_tensor::<B>(&grad_output, 10, 10);
                println!("pairwise_distances {pairwise_distances:?}");
                print_primitive_tensor::<B>(&pairwise_distances, 10, 10);
            }
            let grad_pairwise_distances = B::knn_backward(pairwise_distances, k, grad_output);
            if VERBOSE {
                println!("===grad_pairwise_distances=== {grad_pairwise_distances:?}");
                print_primitive_tensor::<B>(&grad_pairwise_distances, 0, 0);
            }
            grads.register::<B>(node_pairwise_distances, grad_pairwise_distances);
        }
    }

    // Prepare the op once, so the forward kernel runs once and at most one graph node
    // (for `distances`) is created.
    let ops = KnnBackward
        .prepare::<C>([pairwise_distances.node.clone()])
        .compute_bound()
        .stateful();

    match ops {
        OpsKind::Tracked(mut prep) => {
            let pairwise_distances_state = prep.checkpoint(&pairwise_distances);
            let (indices, distances) = B::knn(pairwise_distances.primitive.clone(), k);
            print_if!(VERBOSE, "Forward pass indicies (Tracked): {:?}", indices);
            print_if!(VERBOSE, "Forward pass distances (Tracked): {:?}", distances);
            let distances = prep.finish((pairwise_distances_state, k), distances);
            (indices, distances)
        }
        OpsKind::UnTracked(prep) => {
            let (indices, distances) = B::knn(pairwise_distances.primitive.clone(), k);
            print_if!(VERBOSE, "Forward pass indicies (UnTracked): {:?}", indices);
            print_if!(VERBOSE, "Forward pass distances (UnTracked): {:?}", distances);
            (indices, prep.finish(distances))
        }
    }
}