use burn::tensor::{ops::FloatTensor, ops::IntTensor, Tensor, TensorData, TensorPrimitive};
/// Computes the full `n x n` matrix of Euclidean distances between the rows
/// of `x`, using the expansion `‖a − b‖² = ‖a‖² + ‖b‖² − 2·a·b`.
///
/// Returns the distance matrix as a raw float tensor primitive.
pub fn euclidean_pairwise_distance<B: burn::tensor::backend::Backend>(
    x: FloatTensor<B>,
) -> FloatTensor<B> {
    let points: Tensor<B, 2> = Tensor::from_primitive(TensorPrimitive::Float(x));
    let [rows, _cols] = points.dims();

    // Row-wise squared norms, shaped as a column vector so broadcasting
    // spreads them down the rows; its transpose spreads them across columns.
    let norms_col = points.clone().powf_scalar(2.0).sum_dim(1).reshape([rows, 1]);
    let norms_row = norms_col.clone().reshape([1, rows]);

    // Cross term of the expansion: -2 * (x xᵀ).
    let neg_double_gram = points.clone().matmul(points.transpose()).mul_scalar(-2.0f32);

    // Clamp guards against tiny negative values introduced by floating-point
    // cancellation before the square root.
    let distances = neg_double_gram
        .add(norms_col)
        .add(norms_row)
        .clamp_min(0.0f32)
        .sqrt();

    distances.into_primitive().tensor()
}
/// Backward pass for the pairwise Euclidean distance matrix.
///
/// Given the upstream gradient w.r.t. the distance matrix `D` and the inputs
/// `x`, computes the gradient w.r.t. `x`. With `S = (G + Gᵀ) ⊘ D` (the
/// symmetrized upstream gradient divided elementwise by the distances):
///
///   ∂L/∂x_i = Σ_j S_ij · (x_i − x_j) = rowsum(S)_i · x_i − (S·x)_i
///
/// * `grad_pairwise` — upstream gradient, shape `[n, n]`.
/// * `x` — original input points, shape `[n, d]`.
/// * `pairwise` — forward-pass distance matrix, shape `[n, n]`.
pub fn euclidean_pairwise_distance_backward<B: burn::tensor::backend::Backend>(
    grad_pairwise: FloatTensor<B>,
    x: FloatTensor<B>,
    pairwise: FloatTensor<B>,
) -> FloatTensor<B> {
    let x_t: Tensor<B, 2> = Tensor::from_primitive(TensorPrimitive::Float(x));
    let grad_t: Tensor<B, 2> = Tensor::from_primitive(TensorPrimitive::Float(grad_pairwise));
    let pw_t: Tensor<B, 2> = Tensor::from_primitive(TensorPrimitive::Float(pairwise));
    // `d` is not needed; bind it as `_d` to avoid an unused-variable warning.
    let [n, _d] = x_t.dims();

    // D is symmetric, so x_i receives contributions from both G_ij and G_ji.
    let sym_grad = grad_t.clone().add(grad_t.transpose());
    // Clamp avoids division by zero on the diagonal (D_ii = 0). The diagonal
    // term cancels exactly below because x_i − x_i = 0.
    let inv_dist = pw_t.clamp_min(1e-8f32).recip();
    let scale = sym_grad.mul(inv_dist);

    // BUG FIX: the original computed `S·x − rowsum(S)⊙x`, which is the
    // negation of the true gradient. Correct form: rowsum(S)⊙x − S·x.
    let grad_x = scale.clone().sum_dim(1).reshape([n, 1]).mul(x_t.clone())
        - scale.matmul(x_t);
    grad_x.into_primitive().tensor()
}
/// Finds the `k` nearest neighbours of every point from a precomputed
/// `n x n` pairwise distance matrix.
///
/// The search itself runs on the CPU: the distances are copied to host
/// memory and handed to `knn_from_pairwise_cpu`, then the resulting index
/// and distance buffers (each `n * k` long, row-major) are uploaded back to
/// the tensor device.
///
/// Returns `(indices, distances)` as raw primitives, both shaped `[n, k]`.
///
/// # Panics
/// Panics if the backend's float data cannot be read out as `f32`.
pub fn knn<B: burn::tensor::backend::Backend>(
    pairwise_distances: FloatTensor<B>,
    k: u32,
) -> (IntTensor<B>, FloatTensor<B>) {
    let pw: Tensor<B, 2> = Tensor::from_primitive(TensorPrimitive::Float(pairwise_distances));
    let [n, _n2] = pw.dims();
    let device = pw.device();
    let k = k as usize;

    // Host round-trip: the selection is done on the CPU, not on device.
    // `expect` states the invariant instead of a bare `unwrap`.
    let flat: Vec<f32> = pw
        .to_data()
        .to_vec::<f32>()
        .expect("pairwise distance tensor must be convertible to f32 host data");
    let (idx_flat, dist_flat) = crate::train::get_distance_by_metric::knn_from_pairwise_cpu(
        &flat, n, k,
    );

    let indices: Tensor<B, 2, burn::tensor::Int> =
        Tensor::from_data(TensorData::new(idx_flat, [n, k]), &device);
    let distances: Tensor<B, 2> =
        Tensor::from_data(TensorData::new(dist_flat, [n, k]), &device);
    (indices.into_primitive(), distances.into_primitive().tensor())
}
/// Backward pass placeholder for [`knn`].
///
/// Deliberately panics: this CPU path is only used for sparse training,
/// which (per the message below) never backpropagates through the k-NN
/// selection. Dense training is expected to use the GPU (CubeCL) backend.
pub fn knn_backward<B: burn::tensor::backend::Backend>(
_pairwise_distances: FloatTensor<B>,
_k: u32,
_grad_output: FloatTensor<B>,
) -> FloatTensor<B> {
// Stub by design — reaching this at runtime indicates a wiring error,
// not a missing feature in this module.
unimplemented!(
"knn_backward is not needed for the sparse training path. \
Use the GPU (CubeCL) backend for the dense training path."
);
}