use cubecl::{cube, prelude::*};
/// Computes the matrix of pairwise Euclidean distances between the rows of `x`.
///
/// Each unit is addressed by `(row, col) = (ABSOLUTE_POS_X, ABSOLUTE_POS_Y)`. Only units with
/// `row <= col` inside `n x n` do any work:
/// * on the diagonal, `output[row, row]` is set to zero;
/// * above the diagonal, the distance is computed once and written to both
///   `output[row, col]` and its mirror `output[col, row]`.
///
/// Units in the strict lower triangle, or outside `n x n`, write nothing. So when the launch
/// grid covers at least `n x n` units, every output entry is written exactly once and no two
/// units write the same element.
///
/// * `x` - 2-D input; `x.shape(0)` is the number of points `n`, `x.shape(1)` the feature
///   count `d`.
/// * `output` - 2-D destination. NOTE(review): assumed to be at least `n x n`; this kernel
///   does not check its shape.
///
/// Element addresses are derived from each tensor's strides. Non-contiguous layouts such as
/// transposed views are therefore read and written correctly. For contiguous tensors the
/// addresses are the same as the plain row-major formula.
#[cube(launch)]
pub fn euclidean_pairwise_distance_kernel<F: Float>(x: &Tensor<F>, output: &mut Tensor<F>) {
    let row = ABSOLUTE_POS_X as usize;
    let col = ABSOLUTE_POS_Y as usize;
    let n = x.shape(0);
    let d = x.shape(1);

    // Strides are in elements. Using them instead of `* d` / `* n` keeps indexing correct
    // for strided views.
    let x_row_stride = x.stride(0);
    let x_col_stride = x.stride(1);
    let out_row_stride = output.stride(0);
    let out_col_stride = output.stride(1);

    // Only the upper triangle (including the diagonal) is computed. The lower triangle is
    // filled by the mirrored write below, so those units stay idle.
    if row < n && col < n && row <= col {
        if row == col {
            output[row * out_row_stride + col * out_col_stride] = F::new(0.0);
        } else {
            let mut sum = F::new(0.0);
            for i in 0..d {
                let diff = x[row * x_row_stride + i * x_col_stride]
                    - x[col * x_row_stride + i * x_col_stride];
                sum += diff * diff;
            }
            let dist = F::sqrt(sum);
            output[row * out_row_stride + col * out_col_stride] = dist;
            output[col * out_row_stride + row * out_col_stride] = dist;
        }
    }
}
/// Backward pass of the pairwise Euclidean distance: accumulates `grad_x` from `grad_pairwise`.
///
/// With `D[r, c] = ||x_r - x_c||`, both `D[r, c]` and `D[c, r]` depend on `x_r`, and
/// `dD[r, c] / dx[r, f] = (x[r, f] - x[c, f]) / D[r, c]`. Each unit therefore computes
///
/// `grad_x[r, f] = sum_{c != r} (G[r, c] + G[c, r]) * (x[r, f] - x[c, f]) / D[r, c]`
///
/// where `G` is `grad_pairwise`. The unit is addressed by
/// `(r, f) = (ABSOLUTE_POS_X, ABSOLUTE_POS_Y)` and writes only `grad_x[r, f]`. Units outside
/// `n x d` write nothing.
///
/// * `x` - forward input; `x.shape(1)` is the feature count `d`.
/// * `pairwise` - forward output `D`; `pairwise.shape(0)` is the number of points `n`.
/// * `grad_pairwise` - upstream gradient `G`. NOTE(review): assumed to be `n x n`, like
///   `pairwise`.
/// * `grad_x` - destination. NOTE(review): assumed to be `n x d`, like `x`.
///
/// Distances are clamped below by `1e-8` before dividing. If the clamped distance is still
/// zero, the pair is skipped instead of evaluating `0 / 0 = NaN`. That can happen when `F`
/// cannot represent the epsilon (in `f16`, `1e-8` rounds to zero) and two points coincide.
/// For formats where the epsilon is nonzero, the clamp keeps the distance positive, so no
/// pair is skipped.
///
/// Element addresses are derived from each tensor's strides. Non-contiguous inputs, such as
/// broadcast/expanded upstream gradients, are therefore read correctly.
#[cube(launch)]
pub fn euclidean_pairwise_distance_backward_kernel<F: Float>(
    x: &Tensor<F>,
    pairwise: &Tensor<F>,
    grad_pairwise: &Tensor<F>,
    grad_x: &mut Tensor<F>,
) {
    let row = ABSOLUTE_POS_X as usize;
    let feat = ABSOLUTE_POS_Y as usize;
    let n = pairwise.shape(0);
    let d = x.shape(1);
    let epsilon = F::new(1e-8);

    // Strides are in elements, so strided / broadcast views are addressed correctly.
    let x_row_stride = x.stride(0);
    let x_col_stride = x.stride(1);
    let p_row_stride = pairwise.stride(0);
    let p_col_stride = pairwise.stride(1);
    let g_row_stride = grad_pairwise.stride(0);
    let g_col_stride = grad_pairwise.stride(1);
    let gx_row_stride = grad_x.stride(0);
    let gx_col_stride = grad_x.stride(1);

    if row < n && feat < d {
        let x_row_feat = x[row * x_row_stride + feat * x_col_stride];
        let mut grad_sum = F::new(0.0);
        for col in 0..n {
            // The diagonal distance is constant zero and contributes no gradient.
            if col != row {
                let dist = F::max(pairwise[row * p_row_stride + col * p_col_stride], epsilon);
                // Guards against 0 / 0 when the epsilon underflowed to zero in `F`.
                if dist > F::new(0.0) {
                    let g_rc = grad_pairwise[row * g_row_stride + col * g_col_stride];
                    let g_cr = grad_pairwise[col * g_row_stride + row * g_col_stride];
                    let diff = x_row_feat - x[col * x_row_stride + feat * x_col_stride];
                    grad_sum += (g_rc + g_cr) * diff / dist;
                }
            }
        }
        grad_x[row * gx_row_stride + feat * gx_col_stride] = grad_sum;
    }
}