use rayon::prelude::*;
const INFINITY: f32 = f32::MAX;
const PARALLEL_MIN_ELEMS: usize = 64 * 64;
pub fn knn_forward_packed(pairwise: &[f32], n: usize, k: usize, packed: &mut [f32]) {
assert_eq!(pairwise.len(), n * n);
assert_eq!(packed.len(), n * 2 * k);
assert!(k < n, "k ({k}) must be strictly less than n ({n})");
if n * n >= PARALLEL_MIN_ELEMS {
packed
.par_chunks_mut(2 * k)
.enumerate()
.for_each(|(row, out_row)| {
knn_row(&pairwise[row * n..(row + 1) * n], row, n, k, out_row);
});
} else {
for row in 0..n {
let base = row * 2 * k;
knn_row(
&pairwise[row * n..(row + 1) * n],
row,
n,
k,
&mut packed[base..base + 2 * k],
);
}
}
}
#[inline]
fn knn_row(row_pw: &[f32], row: usize, n: usize, k: usize, out: &mut [f32]) {
let (idx_slice, dist_slice) = out.split_at_mut(k);
for slot in idx_slice.iter_mut() {
*slot = n as f32;
}
for slot in dist_slice.iter_mut() {
*slot = INFINITY;
}
let mut worst = INFINITY;
for (col, &dist) in row_pw.iter().enumerate() {
if col == row || dist >= worst {
continue;
}
if dist < dist_slice[k - 1] {
let mut slot = k - 1;
while slot > 0 && dist < dist_slice[slot - 1] {
dist_slice[slot] = dist_slice[slot - 1];
idx_slice[slot] = idx_slice[slot - 1];
slot -= 1;
}
dist_slice[slot] = dist;
idx_slice[slot] = col as f32;
worst = dist_slice[k - 1];
}
}
}
pub fn knn_backward_pairwise(
pairwise: &[f32],
d_dist: &[f32],
n: usize,
k: usize,
d_pairwise: &mut [f32],
) {
assert_eq!(pairwise.len(), n * n);
assert_eq!(d_dist.len(), n * k);
assert_eq!(d_pairwise.len(), n * n);
let eps = 1e-8f32;
d_pairwise.fill(0.0);
if n * n >= PARALLEL_MIN_ELEMS {
d_pairwise
.par_chunks_mut(n)
.enumerate()
.for_each(|(row, d_row)| {
let mut scratch_idx = vec![0f32; k];
let mut scratch_dist = vec![INFINITY; k];
knn_backward_row(
pairwise,
d_dist,
row,
n,
k,
eps,
d_row,
&mut scratch_idx,
&mut scratch_dist,
);
});
} else {
let mut scratch_idx = vec![0f32; k];
let mut scratch_dist = vec![INFINITY; k];
for row in 0..n {
knn_backward_row(
pairwise,
d_dist,
row,
n,
k,
eps,
&mut d_pairwise[row * n..(row + 1) * n],
&mut scratch_idx,
&mut scratch_dist,
);
}
}
}
fn knn_backward_row(
pairwise: &[f32],
d_dist: &[f32],
row: usize,
n: usize,
k: usize,
eps: f32,
d_row: &mut [f32],
scratch_idx: &mut [f32],
scratch_dist: &mut [f32],
) {
for slot in scratch_idx.iter_mut() {
*slot = n as f32;
}
for slot in scratch_dist.iter_mut() {
*slot = INFINITY;
}
let row_off = row * n;
for col in 0..n {
if row == col {
continue;
}
let dist = pairwise[row_off + col];
if dist < scratch_dist[k - 1] {
let mut slot = k - 1;
while slot > 0 && dist < scratch_dist[slot - 1] {
scratch_dist[slot] = scratch_dist[slot - 1];
scratch_idx[slot] = scratch_idx[slot - 1];
slot -= 1;
}
scratch_dist[slot] = dist;
scratch_idx[slot] = col as f32;
}
}
for slot in 0..k {
let grad_value = d_dist[row * k + slot];
if grad_value == 0.0 {
continue;
}
let dist = scratch_dist[slot].max(eps);
let neighbor_col = scratch_idx[slot] as usize;
if neighbor_col < n {
d_row[neighbor_col] += grad_value / dist;
}
}
}