use burn::tensor::{Int, Tensor, TensorData, TensorPrimitive};
use crate::backend::Backend;
/// Computes pairwise Euclidean distances for `data` using the backend's
/// custom `euclidean_pairwise_distance` operation.
///
/// NOTE(review): presumably `data` is `[n, d]` (one point per row) and the
/// result is the `[n, n]` distance matrix — confirm against the backend op.
pub fn pairwise_distances<B: Backend>(data: Tensor<B, 2>) -> Tensor<B, 2> {
    // Unwrap to the backend's float primitive so the custom op can consume it,
    // then wrap the op's output back into a high-level float tensor.
    let primitive = data.into_primitive().tensor();
    Tensor::from_primitive(TensorPrimitive::Float(B::euclidean_pairwise_distance(
        primitive,
    )))
}
/// Brute-force k-nearest-neighbour search over a dense, row-major `n x n`
/// distance matrix, executed on the CPU.
///
/// For every row `i`, the `k` smallest entries `flat[i * n + j]` with `j != i`
/// are selected and emitted in ascending distance order. Ties are broken by
/// the smaller column index, so the result is deterministic. NaN distances
/// rank after every other value, so they are only chosen when a row has
/// fewer than `k` non-NaN candidates.
///
/// Returns `(indices, distances)`, both flat row-major buffers of length
/// `n * k`: entry `i * k + s` holds the `s`-th nearest neighbour of row `i`.
/// When `k == 0` both buffers are empty.
///
/// # Panics
///
/// - if `k >= n` (each row has only `n - 1` candidates once the diagonal is excluded);
/// - if `flat.len() < n * n`;
/// - if `n - 1` does not fit in an `i32` (indices are emitted as `i32`).
pub fn knn_from_pairwise_cpu(flat: &[f32], n: usize, k: usize) -> (Vec<i32>, Vec<f32>) {
    /// Total order on `(distance, column)` candidates: ascending distance with
    /// NaN ranked after every other value, then ascending column index.
    /// A consistent total order is required by the slice sort/select APIs.
    fn cmp_candidate(a: &(f32, usize), b: &(f32, usize)) -> std::cmp::Ordering {
        a.0.partial_cmp(&b.0)
            .unwrap_or_else(|| a.0.is_nan().cmp(&b.0.is_nan()))
            .then_with(|| a.1.cmp(&b.1))
    }

    assert!(k < n, "k ({k}) must be strictly less than n ({n})");
    assert!(
        flat.len() >= n * n,
        "flat has {} elements, but an {n}x{n} matrix needs {}",
        flat.len(),
        n * n
    );
    // Every emitted index is at most `n - 1`; `n >= 1` is guaranteed by `k < n`.
    assert!(
        i32::try_from(n - 1).is_ok(),
        "n ({n}) is too large to index with i32"
    );

    let mut out_idx = vec![0i32; n * k];
    let mut out_dist = vec![0f32; n * k];
    if k == 0 {
        // Nothing to select; `k - 1` and `chunks_exact_mut(0)` below would panic.
        return (out_idx, out_dist);
    }

    // One scratch buffer reused for every row instead of allocating per row.
    let mut row: Vec<(f32, usize)> = Vec::with_capacity(n - 1);
    let out_rows = out_idx
        .chunks_exact_mut(k)
        .zip(out_dist.chunks_exact_mut(k));
    for (i, (idx_row, dist_row)) in out_rows.enumerate() {
        row.clear();
        row.extend(
            flat[i * n..(i + 1) * n]
                .iter()
                .enumerate()
                .filter(|&(j, _)| j != i) // exclude the self-distance
                .map(|(j, &d)| (d, j)),
        );
        // Move the k best candidates into row[..k] (unordered), then order
        // only those k instead of sorting the whole row.
        if k < row.len() {
            row.select_nth_unstable_by(k - 1, cmp_candidate);
        }
        let nearest = &mut row[..k];
        nearest.sort_unstable_by(cmp_candidate);
        for ((slot_idx, slot_dist), &(d, j)) in idx_row
            .iter_mut()
            .zip(dist_row.iter_mut())
            .zip(nearest.iter())
        {
            *slot_idx = j as i32; // cannot truncate: j <= n - 1, checked above
            *slot_dist = d;
        }
    }
    (out_idx, out_dist)
}
/// Uploads flat, row-major k-NN buffers (the `n * k` layout returned by
/// `knn_from_pairwise_cpu`) to `device` as two `[n, k]` tensors.
///
/// Returns `(indices, distances)`: an integer tensor of neighbour indices and
/// a float tensor of the matching distances.
// NOTE(review): nothing in this file calls this function; `dead_code` is
// presumably allowed because callers are feature-gated or live elsewhere — confirm.
#[allow(dead_code)]
pub fn knn_tensors_from_cpu<B: Backend>(
    idx_flat: Vec<i32>,
    dist_flat: Vec<f32>,
    n: usize,
    k: usize,
    device: &burn::tensor::Device<B>,
) -> (Tensor<B, 2, Int>, Tensor<B, 2>) {
    let shape = [n, k];
    // Indices are uploaded before distances, as in the original ordering.
    let indices = Tensor::<B, 2, Int>::from_data(TensorData::new(idx_flat, shape), device);
    let distances = Tensor::<B, 2>::from_data(TensorData::new(dist_flat, shape), device);
    (indices, distances)
}