use crate::distance::simd_euclidean_distance;
use ndarray::{Array2, ArrayView2, Axis, Zip};
use rayon::prelude::*;
use std::cmp::min;
use thiserror::Error;
use tracing::error;
use usearch::ffi::Matches;
use usearch::{IndexOptions, MetricKind};
pub fn find_k_nearest_neighbors_approx(
data: ArrayView2<f32>,
k: usize,
) -> Result<(Array2<u32>, Array2<f32>), KnnError> {
let n = data.nrows();
let options = IndexOptions {
dimensions: data.ncols(),
metric: MetricKind::L2sq,
multi: true,
..Default::default()
};
let index = usearch::new_index(&options).map_err(|e| KnnError::Construction(e.to_string()))?;
index
.reserve(n)
.map_err(|e| KnnError::Reservation(e.to_string()))?;
Zip::indexed(data.axis_iter(Axis(0))).par_for_each(|i, vector| {
let result = match vector.as_slice() {
None => {
let vector = vector.to_vec();
index.add(i as u64, &vector)
}
Some(slice) => index.add(i as u64, slice),
};
if let Err(error) = result {
error!("failed to add vector to index: {error}; skipping");
}
});
let mut neighbors = Array2::zeros((n, k));
let mut distances = Array2::zeros((n, k));
Zip::indexed(data.axis_iter(Axis(0)))
.and(neighbors.axis_iter_mut(Axis(0)))
.and(distances.axis_iter_mut(Axis(0)))
.par_for_each(|i, vector, mut nbrs, mut dists| {
let identity_check = |key| key != i as u64;
let result = match vector.as_slice() {
None => {
let vector = vector.to_vec();
index.filtered_search(&vector, k, identity_check)
}
Some(slice) => index.filtered_search(slice, k, identity_check),
};
let (keys, distances) = match result {
Ok(Matches { keys, distances }) => (keys, distances),
Err(error) => {
error!("index search failed: {error}; skipping");
return;
}
};
for (((nbr, dist), key), distance) in nbrs
.iter_mut()
.zip(dists.iter_mut())
.zip(keys.into_iter())
.zip(distances.into_iter())
{
*nbr = key as u32;
*dist = distance.sqrt();
}
});
Ok((neighbors, distances))
}
#[must_use]
pub fn find_k_nearest_neighbors(data: ArrayView2<f32>, k: usize) -> (Array2<u32>, Array2<f32>) {
let n = data.nrows();
if n == 0 {
return (Array2::<u32>::zeros((0, 0)), Array2::<f32>::zeros((0, 0)));
}
let k = min(k, n - 1);
let distances: Vec<_> = (0..n)
.into_par_iter()
.flat_map(|i| {
let row_i = data.row(i);
let a_slice = row_i.as_slice().unwrap_or(&[]);
(i + 1..n)
.map(move |j| {
let row_j = data.row(j);
let b_slice = row_j.as_slice().unwrap_or(&[]);
let dist = simd_euclidean_distance(a_slice, b_slice);
(i as u32, j as u32, dist)
})
.collect::<Vec<_>>()
})
.collect();
let mut distances_sorted = distances;
distances_sorted.par_sort_unstable_by(|a, b| f32::total_cmp(&a.2, &b.2));
let mut neighbor_array = Array2::<u32>::zeros((n, k));
let mut distance_array = Array2::<f32>::from_elem((n, k), f32::MAX);
let mut counts = vec![0; n];
for &(i, j, distance) in &distances_sorted {
let ix = i as usize;
let jx = j as usize;
if counts[ix] < k {
neighbor_array[(ix, counts[ix])] = j;
distance_array[(ix, counts[ix])] = distance;
counts[ix] += 1;
}
if counts[jx] < k {
neighbor_array[(jx, counts[jx])] = i;
distance_array[(jx, counts[jx])] = distance;
counts[jx] += 1;
}
}
(neighbor_array, distance_array)
}
#[derive(Debug, Error)]
pub enum KnnError {
#[error("failed to construct index: {0}")]
Construction(String),
#[error("failed to reserve space for vectors: {0}")]
Reservation(String),
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use ndarray::{array, Array2};
use std::f32::consts::FRAC_1_SQRT_2;
#[test]
fn test_empty_embeddings() {
let embeddings = Array2::<f32>::zeros((0, 128));
let k = 5;
let (neighbor_indices, distances) = find_k_nearest_neighbors(embeddings.view(), k);
assert_eq!(neighbor_indices.shape(), &[0, 0]);
assert_eq!(distances.shape(), &[0, 0]);
}
#[test]
fn test_single_embedding() {
let embeddings = array![[1.0, 0.0, 0.0]];
let k = 1;
let (neighbor_indices, distances) = find_k_nearest_neighbors(embeddings.view(), k);
assert_eq!(neighbor_indices.shape(), &[1, 0]);
assert_eq!(distances.shape(), &[1, 0]);
}
#[test]
fn test_k_zero() {
let embeddings = array![[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]];
let k = 0;
let (neighbor_indices, distances) = find_k_nearest_neighbors(embeddings.view(), k);
assert_eq!(neighbor_indices.shape(), &[3, 0]);
assert_eq!(distances.shape(), &[3, 0]);
}
#[test]
fn test_k_exceeds() {
let embeddings: Array2<f32> = array![[1.0], [2.0], [3.0]];
let k = 5;
let (neighbor_indices, distances) = find_k_nearest_neighbors(embeddings.view(), k);
assert_eq!(neighbor_indices.shape(), &[3, 2]);
assert_eq!(distances.shape(), &[3, 2]);
assert!(neighbor_indices.iter().all(|&idx| idx < 3));
assert!(distances.iter().all(|&d| d >= 0.0));
}
#[test]
fn test_normal_case() {
let embeddings: Array2<f32> = array![
[1.0, 0.0], [0.0, 1.0], [0.5, 0.5] ];
let k = 2;
let (neighbor_indices, distances) = find_k_nearest_neighbors(embeddings.view(), k);
assert_eq!(neighbor_indices.shape(), &[3, 2]);
assert_eq!(distances.shape(), &[3, 2]);
let expected = vec![
(0, vec![(2, FRAC_1_SQRT_2), (1, 1.4142)]),
(1, vec![(2, FRAC_1_SQRT_2), (0, 1.4142)]),
(2, vec![(0, FRAC_1_SQRT_2), (1, FRAC_1_SQRT_2)]),
];
check_neighbors_and_distances(&neighbor_indices, &distances, &expected);
}
#[test]
fn test_duplicate_embeddings() {
let embeddings: Array2<f32> = array![
[1.0, 0.0], [1.0, 0.0], [0.0, 1.0] ];
let k = 2;
let (neighbor_indices, distances) = find_k_nearest_neighbors(embeddings.view(), k);
assert_eq!(neighbor_indices.shape(), &[3, 2]);
assert_eq!(distances.shape(), &[3, 2]);
let expected = vec![
(0, vec![(1, 0.0), (2, 1.4142)]),
(1, vec![(0, 0.0), (2, 1.4142)]),
(2, vec![(0, 1.4142), (1, 1.4142)]),
];
check_neighbors_and_distances(&neighbor_indices, &distances, &expected);
}
#[test]
fn test_negative_components() {
let embeddings: Array2<f32> = array![
[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0] ];
let k = 2;
let (neighbor_indices, distances) = find_k_nearest_neighbors(embeddings.view(), k);
assert_eq!(neighbor_indices.shape(), &[3, 2]);
assert_eq!(distances.shape(), &[3, 2]);
let expected = vec![
(0, vec![(2, 1.4142), (1, 2.0)]),
(1, vec![(2, 1.4142), (0, 2.0)]),
(2, vec![(0, 1.4142), (1, 1.4142)]),
];
check_neighbors_and_distances(&neighbor_indices, &distances, &expected);
}
fn check_neighbors_and_distances(
neighbor_indices: &Array2<u32>,
distances: &Array2<f32>,
expected: &Vec<(usize, Vec<(usize, f32)>)>,
) {
for &(point_idx, ref expected_neighbors) in expected {
let neighbors = neighbor_indices.row(point_idx);
let neighbor_distances = distances.row(point_idx);
assert_eq!(
neighbors.len(),
expected_neighbors.len(),
"Mismatch in number of neighbors for point {}",
point_idx
);
let mut neighbor_info: Vec<(usize, f32)> = neighbors
.iter()
.zip(neighbor_distances.iter())
.map(|(&idx, &dist)| (idx as usize, dist))
.collect();
let mut expected_sorted = expected_neighbors.clone();
neighbor_info.sort_by_key(|&(idx, _)| idx);
expected_sorted.sort_by_key(|&(idx, _)| idx);
for (&(neighbor_idx, distance), &(exp_neighbor_idx, exp_distance)) in
neighbor_info.iter().zip(expected_sorted.iter())
{
assert_eq!(
neighbor_idx, exp_neighbor_idx,
"Mismatch in neighbor index for point {}",
point_idx
);
assert_relative_eq!(distance, exp_distance, epsilon = 1e-4, max_relative = 1e-4);
}
}
}
}