use super::kernel::*;
use crate::kernels::DEFAULT_CUBE_DIM;
use burn::tensor::{ops::FloatTensor, Shape};
use burn_cubecl::{
kernel::into_contiguous, tensor::CubeTensor, BoolElement, CubeBackend, CubeRuntime,
FloatElement, IntElement,
};
use cubecl::prelude::*;
pub fn forward<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement>(
x: FloatTensor<CubeBackend<R, F, I, BT>>,
) -> FloatTensor<CubeBackend<R, F, I, BT>> {
let xx = into_contiguous(x.clone());
let client = xx.client;
let device = xx.device;
let dims = xx.shape.dims;
let n = dims[0];
let output_shape = Shape::from(vec![n, n]);
let buffer = client.empty(output_shape.num_elements() * std::mem::size_of::<F>());
let output = CubeTensor::new_contiguous(
client.clone(),
device.clone(),
output_shape,
buffer,
F::dtype(),
);
let cube_dim = DEFAULT_CUBE_DIM;
let cubes_needed_in_x = (n as f32 / cube_dim.x as f32).ceil() as u32;
let cubes_needed_in_y = (n as f32 / cube_dim.y as f32).ceil() as u32;
let cube_count = CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, 1);
euclidean_pairwise_distance_kernel::launch::<F, R>(
&client,
cube_count,
cube_dim,
x.as_tensor_arg(1),
output.as_tensor_arg(1),
)
.expect("euclidean_pairwise_distance_kernel launch failed");
#[cfg(feature = "verbose")]
{
}
output
}
/// Launches `euclidean_pairwise_distance_backward_kernel` and returns the
/// gradient buffer it fills for `x`.
///
/// All three inputs are made contiguous before the launch. The result is a
/// freshly allocated `[n, d]` tensor on the same client and device as `x`,
/// where `n` and `d` are the sizes of the first two dimensions of `x`.
///
/// # Arguments
///
/// * `grad_pairwise` - gradient passed to the kernel as its third tensor
///   (presumably the upstream gradient w.r.t. the forward output, shape
///   `[n, n]` — TODO confirm against the kernel).
/// * `x` - the tensor whose gradient is being computed.
/// * `pairwise` - passed to the kernel as its second tensor (presumably the
///   forward output for `x` — TODO confirm).
///
/// # Panics
///
/// Panics if `x` has fewer than two dimensions, or if the kernel launch
/// returns an error.
pub fn backward<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement>(
    grad_pairwise: FloatTensor<CubeBackend<R, F, I, BT>>,
    x: FloatTensor<CubeBackend<R, F, I, BT>>,
    pairwise: FloatTensor<CubeBackend<R, F, I, BT>>,
) -> FloatTensor<CubeBackend<R, F, I, BT>> {
    let x = into_contiguous(x);
    let pairwise = into_contiguous(pairwise);
    let grad_pairwise = into_contiguous(grad_pairwise);

    let n = x.shape.dims[0];
    let d = x.shape.dims[1];

    let grad_x_shape = Shape::from(vec![n, d]);
    let buffer = x
        .client
        .empty(grad_x_shape.num_elements() * std::mem::size_of::<F>());
    let grad_x: CubeTensor<R> = CubeTensor::new_contiguous(
        x.client.clone(),
        x.device.clone(),
        grad_x_shape,
        buffer,
        F::dtype(),
    );

    // Grid x-axis spans the `n` rows, y-axis the `d` columns of `grad_x`.
    // Integer ceiling division is exact for every `usize` (an `f32` round-trip
    // can round a large extent down and leave trailing elements uncovered);
    // the conversion saturates at `u32::MAX` instead of wrapping.
    let cube_dim = DEFAULT_CUBE_DIM;
    let cubes_x = u32::try_from(n.div_ceil(cube_dim.x as usize)).unwrap_or(u32::MAX);
    let cubes_y = u32::try_from(d.div_ceil(cube_dim.y as usize)).unwrap_or(u32::MAX);
    let cube_count = CubeCount::Static(cubes_x, cubes_y, 1);

    euclidean_pairwise_distance_backward_kernel::launch::<F, R>(
        &x.client,
        cube_count,
        cube_dim,
        x.as_tensor_arg(1),
        pairwise.as_tensor_arg(1),
        grad_pairwise.as_tensor_arg(1),
        grad_x.as_tensor_arg(1),
    )
    .expect("euclidean_pairwise_distance_backward_kernel launch failed");

    grad_x
}