use burn::tensor::ops::{FloatTensor, IntTensor};
use cubecl::CubeDim;
use crate::backend::Backend;
use burn_cubecl::{BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement};
mod euclidean;
mod knn;
/// Default cube (workgroup) launch dimensions: a 32 x 32 x 1 layout, i.e. 1024 units per cube.
///
/// NOTE(review): this constant is not referenced in this file. It is presumably used by the
/// kernel launchers in the `euclidean` and/or `knn` submodules — confirm. 1024 units per cube
/// may exceed the per-cube limit of some runtimes; verify against the targeted hardware.
pub const DEFAULT_CUBE_DIM: CubeDim = CubeDim {
    x: 32,
    y: 32,
    z: 1,
};
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Backend
for CubeBackend<R, F, I, BT>
{
fn euclidean_pairwise_distance(x: FloatTensor<Self>) -> FloatTensor<Self> {
euclidean::forward::forward::<R, F, I, BT>(x)
}
fn euclidean_pairwise_distance_backward(
grad_pairwise: FloatTensor<Self>,
x: FloatTensor<Self>,
pairwise: FloatTensor<Self>,
) -> FloatTensor<Self> {
euclidean::forward::backward::<R, F, I, BT>(grad_pairwise, x, pairwise)
}
fn knn(pairwise_distances: FloatTensor<Self>, k: u32) -> (IntTensor<Self>, FloatTensor<Self>) {
knn::forward::forward::<R, F, I, BT>(pairwise_distances, k)
}
fn knn_backward(
pairwise_distances: FloatTensor<Self>,
k: u32,
grad_output: FloatTensor<Self>,
) -> FloatTensor<Self> {
knn::forward::backward::<R, F, I, BT>(pairwise_distances, k, grad_output)
}
}