use core::f32;
use cubecl::{cube, prelude::*};
/// Converts an unsigned 32-bit integer to `f32` inside a cube function.
///
/// Thin wrapper over cubecl's `f32::cast_from`. Note that `f32` has a 24-bit
/// significand, so integers above 2^24 cannot all be represented exactly.
#[cube]
pub fn u32_to_float(x: u32) -> f32 {
f32::cast_from(x)
}
/// Sentinel distance used to seed the k-NN candidate buffers.
///
/// Despite its name this is the largest *finite* `f32` (bit-identical to
/// `f32::MAX`), not IEEE infinity; any candidate distance smaller than
/// `f32::MAX` compares strictly less than it.
const INFINITY: f32 = 3.40282347e38;
/// Brute-force k-nearest-neighbour selection over a precomputed distance matrix.
///
/// Each unit handles one row (`ABSOLUTE_POS_X`). It scans every other column of
/// that row in `pairwise_distances` and keeps the `k` smallest distances in
/// ascending order, using insertion into a sorted per-row buffer. The results
/// are written to `distances` / `indices`.
///
/// Layout assumptions (not checked here — TODO confirm against the launcher):
/// - `pairwise_distances` is a contiguous, row-major `n x n` matrix with
///   `n = pairwise_distances.shape(0)`.
/// - `local_distances`, `local_indices`, `distances` and `indices` each hold at
///   least `n * k` contiguous elements; row `r` uses `[r * k, r * k + k)`.
///
/// The diagonal (`col == row`) is skipped. Slots that cannot be filled because
/// the row has fewer than `k` candidates keep distance `INFINITY` (really
/// `f32::MAX`) and index `n`. Ties keep the lower column index first.
/// Rows beyond `n`, and `k == 0`, are no-ops.
#[cube(launch)]
pub fn knn_kernel<F: Float + CubePrimitive, I: Int>(
    pairwise_distances: &Tensor<F>,
    k: u32,
    local_distances: &mut Tensor<F>,
    local_indices: &mut Tensor<I>,
    indices: &mut Tensor<I>,
    distances: &mut Tensor<F>,
) {
    let row = ABSOLUTE_POS_X as usize;
    let n = pairwise_distances.shape(0);
    let k_usize = k as usize;
    if row < n {
        // With k == 0 every `k_usize - 1` below would underflow; nothing to do.
        if k_usize > 0 {
            let base = row * k_usize;
            let last = base + k_usize - 1;

            // Seed this row's buffer with "no neighbour yet" entries.
            for i in 0..k_usize {
                local_distances[base + i] = F::new(INFINITY);
                local_indices[base + i] = I::cast_from(n);
            }

            for col in 0..n {
                if row != col {
                    let dist = pairwise_distances[row * n + col];
                    // Only a candidate better than the current worst survivor enters.
                    if dist < local_distances[last] {
                        // Insertion step: starting from the evicted worst slot, shift
                        // entries that are larger than `dist` one slot to the right,
                        // stopping once the left neighbour is <= `dist`.
                        let mut i = k_usize - 1;
                        while i > 0 {
                            if dist < local_distances[base + i - 1] {
                                local_distances[base + i] = local_distances[base + i - 1];
                                local_indices[base + i] = local_indices[base + i - 1];
                            } else {
                                break;
                            }
                            i -= 1;
                        }
                        local_distances[base + i] = dist;
                        local_indices[base + i] = I::cast_from(col);
                    }
                }
            }

            // Publish the sorted per-row result.
            for i in 0..k_usize {
                distances[base + i] = local_distances[base + i];
                indices[base + i] = local_indices[base + i];
            }
        }
    }
}
/// Backward pass for the k-NN selection over a precomputed distance matrix.
///
/// Each unit handles one row (`ABSOLUTE_POS_X`):
/// 1. It recomputes that row's `k` nearest columns (diagonal skipped, ascending
///    distance) into the scratch buffers `local_distances` / `local_indices`.
///    Indices are stored as `F` values here.
/// 2. For every filled slot `i` with a non-zero `grad_output[row * k + i]`, it adds
///    `grad_output[row * k + i] / max(dist_i, 1e-8)` into
///    `grad_pairwise_distances[row * n + neighbour]`.
///
/// Each unit writes only within its own row of `grad_pairwise_distances`. The
/// update uses `+=`, so the caller presumably zero-initialises that tensor — verify.
///
/// Layout assumptions (not checked here — TODO confirm against the launcher):
/// - `pairwise_distances` and `grad_pairwise_distances` are contiguous row-major
///   `n x n` buffers with `n = pairwise_distances.shape(0)`.
/// - `local_distances`, `local_indices` and `grad_output` hold at least `n * k`
///   elements; row `r` uses `[r * k, r * k + k)`.
///
/// NOTE(review): float-encoded indices are exact only while `n` fits in `F`'s
/// significand (e.g. 2^24 for f32, 2^11 for f16).
/// NOTE(review): confirm that `grad / dist` is the intended derivative of the
/// forward op; a pure selection would pass `grad_value` through unchanged.
#[cube(launch)]
pub fn knn_backward_kernel<F: Float + CubePrimitive>(
    pairwise_distances: &Tensor<F>,
    k: u32,
    local_distances: &mut Tensor<F>,
    local_indices: &mut Tensor<F>,
    grad_output: &Tensor<F>,
    grad_pairwise_distances: &mut Tensor<F>,
) {
    let row = ABSOLUTE_POS_X as usize;
    let n = pairwise_distances.shape(0);
    let k_usize = k as usize;
    if row < n {
        // With k == 0 every `k_usize - 1` below would underflow; nothing to do.
        if k_usize > 0 {
            let base = row * k_usize;
            let last = base + k_usize - 1;

            // Seed with "no neighbour yet". Use index `n` (out of range) rather than
            // 0, so unfilled slots cannot alias column 0.
            for i in 0..k_usize {
                local_distances[base + i] = F::new(INFINITY);
                local_indices[base + i] = F::cast_from(n);
            }

            for col in 0..n {
                if row != col {
                    let dist = pairwise_distances[row * n + col];
                    // Only a candidate better than the current worst survivor enters.
                    if dist < local_distances[last] {
                        // Insertion step: shift entries larger than `dist` one slot to
                        // the right, stopping once the left neighbour is <= `dist`.
                        let mut i = k_usize - 1;
                        while i > 0 {
                            if dist < local_distances[base + i - 1] {
                                local_distances[base + i] = local_distances[base + i - 1];
                                local_indices[base + i] = local_indices[base + i - 1];
                            } else {
                                break;
                            }
                            i -= 1;
                        }
                        local_distances[base + i] = dist;
                        local_indices[base + i] = F::cast_from(col);
                    }
                }
            }

            let epsilon = F::new(1e-8);
            for i in 0..k_usize {
                let grad_value = grad_output[base + i];
                if grad_value != F::new(0.0) {
                    let slot_dist = local_distances[base + i];
                    // A slot still holding the sentinel distance was never filled.
                    // Only values strictly below the sentinel can ever be inserted.
                    if slot_dist < F::new(INFINITY) {
                        // Clamp to avoid dividing by a zero distance.
                        let grad_pairwise = grad_value / F::max(slot_dist, epsilon);
                        let neighbor_col = u32::cast_from(local_indices[base + i]) as usize;
                        if neighbor_col < n {
                            grad_pairwise_distances[row * n + neighbor_col] += grad_pairwise;
                        }
                    }
                }
            }
        }
    }
}