use super::kernel::*;
use crate::kernels::DEFAULT_CUBE_DIM;
use burn::tensor::{ops::FloatTensor, Shape};
use burn_cubecl::{
kernel::into_contiguous, tensor::CubeTensor, BoolElement, CubeBackend, CubeRuntime,
FloatElement, IntElement,
};
use cubecl::prelude::*;
/// Launches `knn_kernel` on a pairwise-distance tensor and returns the kernel outputs.
///
/// # Arguments
/// * `pairwise_distances` - input tensor; it is made contiguous before the launch and
///   its first dimension `n` determines the number of output rows.
/// * `k` - passed to the kernel as a scalar and used as the number of output columns
///   (presumably the neighbour count — confirm against the kernel).
///
/// # Returns
/// `(indices, distances)`, both shaped `[n, k]`. `indices` is tagged `DType::I64`,
/// `distances` carries the dtype of `F`.
///
/// # Panics
/// Panics if the input tensor has no dimensions, or if the kernel launch reports an error.
pub fn forward<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement>(
    pairwise_distances: FloatTensor<CubeBackend<R, F, I, BT>>,
    k: u32,
) -> (CubeTensor<R>, CubeTensor<R>) {
    // The argument is owned, so it can be moved straight into `into_contiguous`.
    let pairwise_distances = into_contiguous(pairwise_distances);
    let client = pairwise_distances.client.clone();
    let device = pairwise_distances.device.clone();

    let n = pairwise_distances.shape.dims[0];
    let num_elements = n * k as usize;

    // Allocates an `[n, k]` tensor whose byte size matches the dtype it is tagged with.
    let alloc = |elem_size: usize, dtype: burn::tensor::DType| -> CubeTensor<R> {
        CubeTensor::new_contiguous(
            client.clone(),
            device.clone(),
            Shape::from(vec![n, k as usize]),
            client.empty(num_elements * elem_size),
            dtype,
        )
    };

    let float_size = std::mem::size_of::<F>();
    // Index tensors are tagged I64, so they must be sized with `i64`, not `F`:
    // sizing them with a 4-byte float would allocate only half of the required bytes.
    // NOTE(review): the kernel is launched with the generic int element `I` while these
    // tensors are tagged I64 — confirm that `I` is a 64-bit integer for this backend.
    let index_size = std::mem::size_of::<i64>();

    // Kernel outputs returned to the caller.
    let indices = alloc(index_size, burn::tensor::DType::I64);
    let distances = alloc(float_size, F::dtype());

    // Additional `[n, k]` buffers handed to the kernel; they are not returned.
    let local_distances = alloc(float_size, F::dtype());
    let local_indices = alloc(index_size, burn::tensor::DType::I64);

    // Enough cubes along x to cover `n`. Integer ceiling division avoids the
    // precision loss of an `f32` round trip for large `n`.
    let cube_dim = DEFAULT_CUBE_DIM;
    let cubes_needed_in_x = n.div_ceil(cube_dim.x as usize) as u32;
    let cubes_needed_in_y = 1_u32;
    let cube_count = CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, 1);

    let vectorisation = 1;
    knn_kernel::launch::<F, I, R>(
        &client,
        cube_count,
        cube_dim,
        pairwise_distances.as_tensor_arg(vectorisation),
        ScalarArg::new(k),
        local_distances.as_tensor_arg(vectorisation),
        local_indices.as_tensor_arg(vectorisation),
        indices.as_tensor_arg(vectorisation),
        distances.as_tensor_arg(vectorisation),
    )
    .expect("knn_kernel launch failed");

    (indices, distances)
}
/// Launches `knn_backward_kernel` and returns the gradient tensor it fills.
///
/// # Arguments
/// * `pairwise_distances` - input tensor from the forward pass; made contiguous before
///   the launch. The returned gradient has exactly this tensor's shape.
/// * `k` - passed to the kernel as a scalar and used as the column count of the
///   `[n, k]` scratch tensors, where `n` is the first dimension of `pairwise_distances`.
/// * `grad_output` - upstream gradient, forwarded to the kernel as-is
///   (presumably shaped `[n, k]` — confirm against the caller).
///
/// # Returns
/// A zero-initialised tensor with the shape of `pairwise_distances` and the dtype of
/// `F`, after the kernel has been launched on it.
///
/// # Panics
/// Panics if the input tensor has no dimensions, or if the kernel launch reports an error.
pub fn backward<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement>(
    pairwise_distances: FloatTensor<CubeBackend<R, F, I, BT>>,
    k: u32,
    grad_output: FloatTensor<CubeBackend<R, F, I, BT>>,
) -> FloatTensor<CubeBackend<R, F, I, BT>> {
    let pairwise_distances = into_contiguous(pairwise_distances);
    let client = pairwise_distances.client.clone();
    let device = pairwise_distances.device.clone();

    let n = pairwise_distances.shape.dims[0];
    let elem_size = std::mem::size_of::<F>();

    // The gradient tensor is declared with the full shape of `pairwise_distances`, so
    // its zero-filled backing buffer must hold that many elements. Sizing it as
    // `n * k` would leave the buffer too small whenever the shapes differ.
    let zero_bytes = vec![0u8; pairwise_distances.shape.num_elements() * elem_size];
    let grad_buffer = client.create_from_slice(&zero_bytes);
    let grad_pairwise_distances: CubeTensor<R> = CubeTensor::new_contiguous(
        client.clone(),
        device.clone(),
        pairwise_distances.shape.clone(),
        grad_buffer,
        F::dtype(),
    );

    // `[n, k]` buffers handed to the kernel; they are not returned. Both are tagged
    // with `F`'s dtype, in line with the kernel being launched with `F` as its only
    // element generic.
    let local_elements = n * k as usize;
    let local_distances: CubeTensor<R> = CubeTensor::new_contiguous(
        client.clone(),
        device.clone(),
        Shape::from(vec![n, k as usize]),
        client.empty(local_elements * elem_size),
        F::dtype(),
    );
    let local_indices: CubeTensor<R> = CubeTensor::new_contiguous(
        client.clone(),
        device.clone(),
        Shape::from(vec![n, k as usize]),
        client.empty(local_elements * elem_size),
        F::dtype(),
    );

    // Enough cubes along x to cover `n`. Integer ceiling division avoids the
    // precision loss of an `f32` round trip for large `n`.
    let cube_dim = DEFAULT_CUBE_DIM;
    let cubes_needed_in_x = n.div_ceil(cube_dim.x as usize) as u32;
    let cubes_needed_in_y = 1_u32;
    let cube_count = CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, 1);

    let vectorization = 1;
    knn_backward_kernel::launch::<F, R>(
        &client,
        cube_count,
        cube_dim,
        pairwise_distances.as_tensor_arg(vectorization),
        ScalarArg::new(k),
        local_distances.as_tensor_arg(vectorization),
        local_indices.as_tensor_arg(vectorization),
        grad_output.as_tensor_arg(vectorization),
        grad_pairwise_distances.as_tensor_arg(vectorization),
    )
    .expect("knn_backward_kernel launch failed");

    grad_pairwise_distances
}