use burn::{
backend::{autodiff::checkpoint::strategy::CheckpointStrategy, Autodiff},
tensor::ops::{FloatTensor, IntTensor},
};
use cubecl::CubeDim;
use crate::backend::Backend;
use burn_cubecl::{BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement};
mod euclidean;
mod knn;
/// Default cube (workgroup) shape shared by this module's kernels: a flat 32 × 32 tile,
/// i.e. 1024 units per cube.
///
/// NOTE(review): presumably consumed by the `euclidean` / `knn` kernel launches — confirm.
/// 1024 units per cube exceeds WebGPU's *default* `maxComputeInvocationsPerWorkgroup`
/// (256), so runtimes/devices with that limit may reject launches using this shape.
pub const DEFAULT_CUBE_DIM: CubeDim = CubeDim {
    x: 32,
    y: 32,
    z: 1,
};
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Backend
for CubeBackend<R, F, I, BT>
{
fn euclidean_pairwise_distance(x: FloatTensor<Self>) -> FloatTensor<Self> {
euclidean::forward::forward::<R, F, I, BT>(x)
}
fn euclidean_pairwise_distance_backward(
grad_pairwise: FloatTensor<Self>,
x: FloatTensor<Self>,
pairwise: FloatTensor<Self>,
) -> FloatTensor<Self> {
euclidean::forward::backward::<R, F, I, BT>(grad_pairwise, x, pairwise)
}
fn knn(pairwise_distances: FloatTensor<Self>, k: u32) -> (IntTensor<Self>, FloatTensor<Self>) {
knn::forward::forward::<R, F, I, BT>(pairwise_distances, k)
}
fn knn_backward(
pairwise_distances: FloatTensor<Self>,
k: u32,
grad_output: FloatTensor<Self>,
) -> FloatTensor<Self> {
knn::forward::backward::<R, F, I, BT>(pairwise_distances, k, grad_output)
}
}
impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
fn euclidean_pairwise_distance(x: FloatTensor<Self>) -> FloatTensor<Self> {
euclidean::backward::backward::<B, C>(x)
}
fn euclidean_pairwise_distance_backward(
_grad_pairwise: FloatTensor<Self>,
_x: FloatTensor<Self>,
_pairwise: FloatTensor<Self>,
) -> FloatTensor<Self> {
unimplemented!(
"Called on inner CubeBackend only; Autodiff dispatches via euclidean::backward."
);
}
fn knn(pairwise_distances: FloatTensor<Self>, k: u32) -> (IntTensor<Self>, FloatTensor<Self>) {
knn::backward::backward::<B, C>(pairwise_distances, k)
}
fn knn_backward(
_pairwise_distances: FloatTensor<Self>,
_k: u32,
_grad_output: FloatTensor<Self>,
) -> FloatTensor<Self> {
unimplemented!(
"Triggered on the inner CubeBackend only; \
the Autodiff wrapper delegates via knn::backward."
);
}
}