use burn::{
backend::Autodiff,
tensor::ops::{FloatTensor, IntTensor},
};
use burn_cubecl::{BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement};
/// Extension of burn's [`burn::tensor::backend::Backend`] with custom distance and
/// nearest-neighbour operations.
///
/// Each forward operation is paired with an explicit backward operation so that an
/// autodiff wrapper can route gradients through it.
pub trait Backend: burn::tensor::backend::Backend {
    // ---- forward operations -------------------------------------------------------

    /// Computes Euclidean pairwise distances for the rows of `x`.
    ///
    /// NOTE(review): presumably `x` is `(n, d)` and the result is `(n, n)`; confirm
    /// against the concrete implementations.
    fn euclidean_pairwise_distance(x: FloatTensor<Self>) -> FloatTensor<Self>;

    /// Selects the `k` nearest neighbours from a pairwise-distance tensor.
    ///
    /// Returns an integer tensor and a float tensor. These are presumably the
    /// neighbour indices and their distances; verify the ordering with the
    /// implementations.
    fn knn(
        pairwise_distances: FloatTensor<Self>,
        k: u32,
    ) -> (IntTensor<Self>, FloatTensor<Self>);

    // ---- backward operations ------------------------------------------------------

    /// Backward pass for [`Backend::euclidean_pairwise_distance`].
    ///
    /// Receives the upstream gradient `grad_pairwise`, the forward input `x`, and the
    /// forward output `pairwise`. It returns a float tensor, presumably the gradient
    /// with respect to `x` (to be confirmed).
    fn euclidean_pairwise_distance_backward(
        grad_pairwise: FloatTensor<Self>,
        x: FloatTensor<Self>,
        pairwise: FloatTensor<Self>,
    ) -> FloatTensor<Self>;

    /// Backward pass for [`Backend::knn`].
    ///
    /// Takes the forward inputs (`pairwise_distances`, `k`) and the upstream gradient
    /// `grad_output`. It returns a float tensor, presumably the gradient with respect
    /// to `pairwise_distances` (to be confirmed).
    fn knn_backward(
        pairwise_distances: FloatTensor<Self>,
        k: u32,
        grad_output: FloatTensor<Self>,
    ) -> FloatTensor<Self>;
}
/// Marker trait for backends that provide both this module's custom [`Backend`]
/// operations and burn's automatic differentiation.
///
/// The trait has no items of its own. Any implementor must also implement
/// [`Backend`] and [`burn::tensor::backend::AutodiffBackend`].
pub trait AutodiffBackend
where
    Self: Backend + burn::tensor::backend::AutodiffBackend,
{
}
/// Opts burn's `Autodiff` decorator around any CubeCL backend into [`AutodiffBackend`].
///
/// The impl body is empty because the trait declares no items. It only compiles if
/// `Autodiff<CubeBackend<R, F, I, BT>>` already satisfies the [`AutodiffBackend`]
/// supertraits, including this module's [`Backend`]. That impl is not visible in this
/// file and is presumably provided elsewhere.
impl<R, F, I, BT> AutodiffBackend for Autodiff<CubeBackend<R, F, I, BT>>
where
    R: CubeRuntime,
    F: FloatElement,
    I: IntElement,
    BT: BoolElement,
{
}