use crate::distance::{Distance, Proximity};
use alloc::vec::Vec;
/// A search result: an item paired with its distance from the search target.
#[derive(Clone, Copy, Debug)]
pub struct Neighbor<V, D> {
    /// The item that was found.
    pub item: V,
    /// The distance from the search target to `item`.
    pub distance: D,
}

impl<V, D> Neighbor<V, D> {
    /// Pairs `item` with its `distance` from the search target.
    pub fn new(item: V, distance: D) -> Self {
        Neighbor { item, distance }
    }
}

/// Two neighbors are equal when both their items and their distances compare equal.
///
/// The impl is heterogeneous, so neighbors whose item or distance types differ (but are
/// mutually comparable) can be compared directly.
impl<V1, D1, V2, D2> PartialEq<Neighbor<V2, D2>> for Neighbor<V1, D1>
where
    V1: PartialEq<V2>,
    D1: PartialEq<D2>,
{
    fn eq(&self, other: &Neighbor<V2, D2>) -> bool {
        let Neighbor { item, distance } = self;
        // Items are compared first; distances are only compared if the items match.
        *item == other.item && *distance == other.distance
    }
}
/// Accumulates the results of a nearest-neighbor search around a fixed target.
///
/// `contains` reports whether a distance is currently admissible, and `consider`
/// measures a candidate item and keeps it if it qualifies.
pub trait Neighborhood<K, V>
where
    K: Proximity<V>,
{
    /// Returns the target that distances are measured from.
    fn target(&self) -> K;

    /// Returns whether a candidate at `distance` from the target would be accepted by
    /// this neighborhood in its current state.
    fn contains<T>(&self, distance: T) -> bool
    where
        T: PartialOrd<K::Distance>;

    /// Measures the distance from the target to `item`, keeps `item` if it qualifies,
    /// and returns the measured distance either way.
    fn consider(&mut self, item: V) -> K::Distance;
}
/// A [`Neighborhood`] that retains at most one result: the nearest item considered so far.
#[derive(Debug)]
struct SingletonNeighborhood<K, V, D> {
    /// The search target.
    target: K,
    /// Upper bound on accepted distances; `None` means unbounded.
    threshold: Option<D>,
    /// The best neighbor found so far, if any.
    neighbor: Option<Neighbor<V, D>>,
}

impl<K, V, D> SingletonNeighborhood<K, V, D> {
    /// Creates an empty neighborhood around `target`.
    ///
    /// * `threshold`: the maximum accepted distance, or `None` for no limit.
    fn new(target: K, threshold: Option<D>) -> Self {
        let neighbor = None;
        SingletonNeighborhood {
            target,
            threshold,
            neighbor,
        }
    }

    /// Consumes the neighborhood and yields the neighbor it found, if any.
    fn into_option(self) -> Option<Neighbor<V, D>> {
        let SingletonNeighborhood { neighbor, .. } = self;
        neighbor
    }
}
impl<K, V> Neighborhood<K, V> for SingletonNeighborhood<K, V, K::Distance>
where
    K: Copy + Proximity<V>,
{
    fn target(&self) -> K {
        self.target
    }

    /// A distance is admissible when no threshold is set yet, or when it does not exceed
    /// the current threshold.
    fn contains<D>(&self, distance: D) -> bool
    where
        D: PartialOrd<K::Distance>,
    {
        match &self.threshold {
            None => true,
            Some(limit) => distance <= *limit,
        }
    }

    /// Keeps `item` when it is at least as close as the current best. The threshold is
    /// then tightened to its distance, so later candidates must do at least as well.
    fn consider(&mut self, item: V) -> K::Distance {
        let dist = self.target.distance(&item);
        if !self.contains(dist) {
            return dist;
        }
        self.neighbor = Some(Neighbor::new(item, dist));
        self.threshold = Some(dist);
        dist
    }
}
/// A [`Neighborhood`] that keeps up to `k` results in a caller-provided vector.
///
/// During the search the vector is maintained as a binary max-heap on `distance`, so the
/// farthest kept result sits at index 0. [`sort`](HeapNeighborhood::sort) turns it back
/// into ascending order afterwards.
#[derive(Debug)]
struct HeapNeighborhood<'a, K, V, D> {
    /// The search target.
    target: K,
    /// The maximum number of results to keep.
    k: usize,
    /// Upper bound on accepted distances; `None` means unbounded.
    threshold: Option<D>,
    /// Borrowed result storage, used as a max-heap while searching.
    heap: &'a mut Vec<Neighbor<V, D>>,
}

impl<'a, K, V, D: Distance> HeapNeighborhood<'a, K, V, D> {
    /// Creates a neighborhood that merges new results into `heap`.
    ///
    /// * `target`: the search target.
    /// * `k`: the maximum number of results to keep.
    /// * `threshold`: the maximum accepted distance, or `None` for no limit.
    /// * `heap`: existing results. They are expected to be sorted by ascending distance,
    ///   so that reversing them yields a valid max-heap.
    fn new(
        target: K,
        k: usize,
        mut threshold: Option<D>,
        heap: &'a mut Vec<Neighbor<V, D>>,
    ) -> Self {
        // A descending array already satisfies the max-heap property, so reversing the
        // (ascending) input builds the heap without any sifting.
        heap.reverse();

        // A full heap's root is the farthest kept result, so nothing farther can be
        // accepted. Use it as the threshold when it is no looser than the caller's.
        // NOTE(review): when `heap` holds more than `k` entries this branch is skipped and
        // the threshold is not tightened; callers presumably pass at most `k` — confirm.
        if k > 0 && heap.len() == k {
            let distance = heap[0].distance;
            if threshold.map_or(true, |t| distance <= t) {
                threshold = Some(distance);
            }
        }

        Self {
            target,
            k,
            threshold,
            heap,
        }
    }

    /// Appends `item` and sifts it up, swapping it with its parent while it is strictly
    /// farther than that parent.
    fn push(&mut self, item: Neighbor<V, D>) {
        let mut i = self.heap.len();
        self.heap.push(item);

        while i > 0 {
            let parent = (i - 1) / 2;
            if self.heap[i].distance > self.heap[parent].distance {
                self.heap.swap(i, parent);
                i = parent;
            } else {
                break;
            }
        }
    }

    /// Sifts the root down within the prefix `heap[..len]`, restoring the max-heap
    /// property there. Entries at `len..` are not touched.
    ///
    /// Panics if the vector is empty, because it reads `heap[0]`.
    fn sink_root(&mut self, len: usize) {
        let mut i = 0;
        // The sinking element always moves along with `i`, so its distance is read once.
        let dist = self.heap[i].distance;

        loop {
            // Pick the farther child. On ties (or an incomparable pair) keep the left one.
            let mut child = 2 * i + 1;
            let right = child + 1;
            if right < len && self.heap[child].distance < self.heap[right].distance {
                child = right;
            }

            if child < len && dist < self.heap[child].distance {
                self.heap.swap(i, child);
                i = child;
            } else {
                break;
            }
        }
    }

    /// Overwrites the root (the farthest kept result) with `item` and restores the heap.
    ///
    /// Panics if the heap is empty.
    fn replace_root(&mut self, item: Neighbor<V, D>) {
        self.heap[0] = item;
        self.sink_root(self.heap.len());
    }

    /// Heapsorts the vector in place into ascending order of distance. Each pass swaps
    /// the current maximum to the end of the shrinking heap prefix and re-sinks the root.
    fn sort(&mut self) {
        for i in (0..self.heap.len()).rev() {
            self.heap.swap(0, i);
            self.sink_root(i);
        }
    }
}
impl<'a, K, V> Neighborhood<K, V> for HeapNeighborhood<'a, K, V, K::Distance>
where
    K: Copy + Proximity<V>,
{
    fn target(&self) -> K {
        self.target
    }

    /// A distance is admissible when `k > 0` and there is either no threshold yet or the
    /// distance does not exceed it.
    fn contains<D>(&self, distance: D) -> bool
    where
        D: PartialOrd<K::Distance>,
    {
        self.k > 0 && self.threshold.map_or(true, |t| distance <= t)
    }

    /// Measures `item`, keeps it if it is admissible, and returns its distance.
    ///
    /// While fewer than `k` results are held, admissible items are pushed onto the heap.
    /// After that they replace the farthest kept result (the heap root).
    fn consider(&mut self, item: V) -> K::Distance {
        let distance = self.target.distance(&item);

        if self.contains(distance) {
            let neighbor = Neighbor::new(item, distance);

            if self.heap.len() < self.k {
                self.push(neighbor);
            } else {
                self.replace_root(neighbor);
            }

            if self.heap.len() == self.k {
                // Once the heap is full, nothing farther than its root can improve the
                // result, so the root bounds future candidates. Only ever *lower* the
                // threshold: results merged in by the caller may lie beyond the caller's
                // threshold, and adopting their distance would admit index items farther
                // than the requested limit (and make the outcome depend on visit order).
                let farthest = self.heap[0].distance;
                if self.threshold.map_or(true, |t| farthest <= t) {
                    self.threshold = Some(farthest);
                }
            }
        }

        distance
    }
}
/// A nearest-neighbor search index over items of type `V`, queried with targets of type `K`.
///
/// Implementors only provide [`search`](NearestNeighbors::search). The query methods below
/// are built on it.
pub trait NearestNeighbors<K: Proximity<V>, V = K> {
    /// Returns the neighbor nearest to `target`, or `None` if no candidate was found
    /// (e.g. the index is empty).
    fn nearest(&self, target: &K) -> Option<Neighbor<&V, K::Distance>> {
        self.search(SingletonNeighborhood::new(target, None))
            .into_option()
    }

    /// Returns the neighbor nearest to `target` whose distance is at most `threshold`.
    ///
    /// Returns `None` if no such neighbor exists, or if `threshold` cannot be converted
    /// into `K::Distance`.
    fn nearest_within<D>(&self, target: &K, threshold: D) -> Option<Neighbor<&V, K::Distance>>
    where
        D: TryInto<K::Distance>,
    {
        if let Ok(distance) = threshold.try_into() {
            self.search(SingletonNeighborhood::new(target, Some(distance)))
                .into_option()
        } else {
            None
        }
    }

    /// Returns up to `k` neighbors nearest to `target`, sorted by ascending distance.
    fn k_nearest(&self, target: &K, k: usize) -> Vec<Neighbor<&V, K::Distance>> {
        // Don't preallocate `k` slots. Callers may pass a huge `k` (e.g. `usize::MAX` to
        // request every item), and `Vec::with_capacity` would then panic on capacity
        // overflow or attempt an enormous allocation. The heap grows only as results are
        // actually found.
        let mut neighbors = Vec::new();
        self.merge_k_nearest(target, k, &mut neighbors);
        neighbors
    }

    /// Returns up to `k` neighbors nearest to `target` whose distance is at most
    /// `threshold`, sorted by ascending distance.
    ///
    /// Returns an empty vector if `threshold` cannot be converted into `K::Distance`.
    fn k_nearest_within<D>(
        &self,
        target: &K,
        k: usize,
        threshold: D,
    ) -> Vec<Neighbor<&V, K::Distance>>
    where
        D: TryInto<K::Distance>,
    {
        // See `k_nearest` for why `k` is not used as a capacity hint.
        let mut neighbors = Vec::new();
        self.merge_k_nearest_within(target, k, threshold, &mut neighbors);
        neighbors
    }

    /// Merges the neighbors nearest to `target` into `neighbors`, keeping the `k` best.
    ///
    /// `neighbors` may already hold results. These must be sorted by ascending distance
    /// (they are reused as a heap) and should number at most `k`. On return the vector is
    /// sorted by ascending distance.
    fn merge_k_nearest<'v>(
        &'v self,
        target: &K,
        k: usize,
        neighbors: &mut Vec<Neighbor<&'v V, K::Distance>>,
    ) {
        self.search(HeapNeighborhood::new(target, k, None, neighbors))
            .sort();
    }

    /// Like [`merge_k_nearest`](NearestNeighbors::merge_k_nearest), but only items whose
    /// distance is at most `threshold` are merged in.
    ///
    /// If `threshold` cannot be converted into `K::Distance`, `neighbors` is left unchanged.
    fn merge_k_nearest_within<'v, D>(
        &'v self,
        target: &K,
        k: usize,
        threshold: D,
        neighbors: &mut Vec<Neighbor<&'v V, K::Distance>>,
    ) where
        D: TryInto<K::Distance>,
    {
        if let Ok(distance) = threshold.try_into() {
            self.search(HeapNeighborhood::new(target, k, Some(distance), neighbors))
                .sort();
        }
    }

    /// Runs a search driven by `neighborhood` and returns it once the search is done.
    ///
    /// Implementations presumably feed candidate items to `neighborhood.consider` and may
    /// use `neighborhood.contains` to prune; confirm against each implementor.
    fn search<'k, 'v, N>(&'v self, neighborhood: N) -> N
    where
        K: 'k,
        V: 'v,
        N: Neighborhood<&'k K, &'v V>;
}
/// Marker trait for [`NearestNeighbors`] implementations that are expected to return exact
/// (not approximate) results.
///
/// It adds no methods. The shared `tests::test_exact_neighbors` suite checks implementors'
/// results against an exhaustive search.
pub trait ExactNeighbors<K, V = K>: NearestNeighbors<K, V>
where
    K: Proximity<V>,
{
}
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::euclid::{Euclidean, EuclideanDistance};
    use crate::exhaustive::ExhaustiveSearch;
    use rand::random;
    use alloc::vec;

    type Point = Euclidean<[f32; 3]>;

    /// Runs the shared conformance checks against an index built by `from_iter`.
    ///
    /// Intended to be called from the tests of each [`ExactNeighbors`] implementation.
    pub fn test_exact_neighbors<T, F>(from_iter: F)
    where
        T: ExactNeighbors<Point>,
        F: Fn(Vec<Point>) -> T,
    {
        test_empty(&from_iter);
        test_pythagorean(&from_iter);
        test_random_points(&from_iter);
    }

    /// Every query against an empty index must come back empty.
    fn test_empty<T, F>(from_iter: &F)
    where
        T: NearestNeighbors<Point>,
        F: Fn(Vec<Point>) -> T,
    {
        let points = Vec::new();
        let index = from_iter(points);
        let target = Euclidean([0.0, 0.0, 0.0]);

        assert_eq!(index.nearest(&target), None);
        assert_eq!(index.nearest_within(&target, 1.0), None);
        assert!(index.k_nearest(&target, 0).is_empty());
        assert!(index.k_nearest(&target, 3).is_empty());
        assert!(index.k_nearest_within(&target, 0, 1.0).is_empty());
        assert!(index.k_nearest_within(&target, 3, 1.0).is_empty());
    }

    /// Exact checks on points whose distances from the origin are all integers
    /// (3, 5, 7, 9, 13 and 17), so expected results can be written down exactly.
    fn test_pythagorean<T, F>(from_iter: &F)
    where
        T: NearestNeighbors<Point>,
        F: Fn(Vec<Point>) -> T,
    {
        let points = vec![
            Euclidean([3.0, 4.0, 0.0]),
            Euclidean([5.0, 0.0, 12.0]),
            Euclidean([0.0, 8.0, 15.0]),
            Euclidean([1.0, 2.0, 2.0]),
            Euclidean([2.0, 3.0, 6.0]),
            Euclidean([4.0, 4.0, 7.0]),
        ];
        let index = from_iter(points);
        let target = Euclidean([0.0, 0.0, 0.0]);

        assert_eq!(
            index.nearest(&target).expect("No nearest neighbor found"),
            Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0)
        );

        assert_eq!(index.nearest_within(&target, 2.0), None);
        assert_eq!(
            index.nearest_within(&target, 4.0).expect("No nearest neighbor found within 4.0"),
            Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0)
        );

        assert!(index.k_nearest(&target, 0).is_empty());
        assert_eq!(
            index.k_nearest(&target, 3),
            vec![
                Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0),
                Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0),
                Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0),
            ]
        );

        // Previously this line repeated the `k_nearest(&target, 0)` check above, leaving
        // `k_nearest_within` with `k == 0` untested.
        assert!(index.k_nearest_within(&target, 0, 6.0).is_empty());
        assert_eq!(
            index.k_nearest_within(&target, 3, 6.0),
            vec![
                Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0),
                Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0),
            ]
        );
        assert_eq!(
            index.k_nearest_within(&target, 3, 8.0),
            vec![
                Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0),
                Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0),
                Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0),
            ]
        );

        let mut neighbors = Vec::new();
        index.merge_k_nearest(&target, 3, &mut neighbors);
        assert_eq!(
            neighbors,
            vec![
                Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0),
                Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0),
                Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0),
            ]
        );

        // Merging into pre-existing (sorted) results: only index items within 4.0 are
        // merged in, displacing the farthest existing entry.
        neighbors = vec![
            Neighbor::new(&target, EuclideanDistance::from_squared(0.0)),
            Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), EuclideanDistance::from_squared(25.0)),
            Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), EuclideanDistance::from_squared(49.0)),
        ];
        index.merge_k_nearest_within(&target, 3, 4.0, &mut neighbors);
        assert_eq!(
            neighbors,
            vec![
                Neighbor::new(&target, 0.0),
                Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0),
                Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0),
            ]
        );
    }

    /// Cross-checks the index against an exhaustive search on random points.
    fn test_random_points<T, F>(from_iter: &F)
    where
        T: NearestNeighbors<Point>,
        F: Fn(Vec<Point>) -> T,
    {
        let mut points = Vec::new();
        for _ in 0..256 {
            points.push(Euclidean([random(), random(), random()]));
        }

        let index = from_iter(points.clone());
        let eindex = ExhaustiveSearch::from_iter(points.clone());

        let target = Euclidean([random(), random(), random()]);

        assert_eq!(
            index.nearest(&target),
            eindex.nearest(&target),
            "target: {:?}, points: {:#?}",
            target,
            points,
        );

        assert_eq!(
            index.k_nearest(&target, 3),
            eindex.k_nearest(&target, 3),
            "target: {:?}, points: {:#?}",
            target,
            points,
        );
    }
}