use vpsearch::{BestCandidate, MetricSpace};
use std::collections::HashSet;
use num_traits::Bounded;
/// A point described by an arbitrary number of `f32` coordinates.
#[derive(Clone, Debug)]
struct PointN {
    /// Coordinates of the point, one entry per dimension.
    data: Vec<f32>,
}

impl PointN {
    /// Builds a point from anything convertible into a `Vec<f32>`,
    /// such as an array literal or an existing vector.
    pub fn new(data: impl Into<Vec<f32>>) -> Self {
        let coordinates: Vec<f32> = data.into();
        PointN { data: coordinates }
    }
}
/// Makes `PointN` usable as an item of a vantage-point tree.
impl MetricSpace for PointN {
    /// No extra context is needed to compare two points.
    type UserData = ();
    /// Distances are expressed as `f32`.
    type Distance = f32;

    /// Returns the Euclidean distance between `self` and `other`.
    ///
    /// Coordinates are paired positionally; `zip` stops at the shorter of the
    /// two coordinate lists, so any surplus dimensions are ignored.
    fn distance(&self, other: &Self, _: &Self::UserData) -> Self::Distance {
        let squared_total: f32 = self
            .data
            .iter()
            .zip(&other.data)
            .map(|(mine, theirs)| {
                let delta = mine - theirs;
                delta.powi(2)
            })
            .sum();
        squared_total.sqrt()
    }
}
/// Search state that collects the indices of the closest candidates seen so far,
/// keeping at most `max_item_count` of them.
struct CountBasedNeighborhood<Item: MetricSpace<Impl>, Impl> {
/// Upper bound on how many candidates are retained.
max_item_count: usize,
/// Distance of the last (farthest) retained entry after the most recent insert;
/// starts at the smallest value `Item::Distance` can represent.
max_observed_distance: Item::Distance,
/// Retained `(distance, index)` pairs, kept in ascending order of distance.
distance_x_index: Vec<(Item::Distance, usize)>,
}
impl<Item: MetricSpace<Impl>, Impl> CountBasedNeighborhood<Item, Impl> {
fn new(item_count: usize) -> Self {
Self {
max_item_count: item_count,
max_observed_distance: <Item::Distance as Bounded>::min_value(),
distance_x_index: Vec::<(Item::Distance, usize)>::new(),
}
}
fn insert_index(&mut self, index: usize, distance: Item::Distance) {
self.distance_x_index.push((distance, index));
if self.distance_x_index.len() > 1 {
let mut n = self.distance_x_index.len() - 1;
while n > 0 && self.distance_x_index[n].0 < self.distance_x_index[n - 1].0 {
self.distance_x_index.swap(n, n - 1);
n -= 1;
}
self.distance_x_index.truncate(self.max_item_count);
}
self.max_observed_distance = self.distance_x_index.last().unwrap().0;
}
}
/// Lets a vantage-point tree search collect the indices of the nearest
/// candidates (up to `max_item_count` of them) instead of a single best match.
impl<Item: MetricSpace<Impl> + Clone, Impl> BestCandidate<Item, Impl>
    for CountBasedNeighborhood<Item, Impl>
{
    /// The indices of the retained candidates, in no particular order.
    type Output = HashSet<usize>;

    /// Offers one visited item to the neighborhood.
    ///
    /// The candidate is recorded when the neighborhood still has room, or when
    /// it is strictly closer than the farthest candidate currently retained.
    /// Nothing is ever recorded when `max_item_count` is zero.
    #[inline]
    fn consider(
        &mut self,
        _: &Item,
        distance: Item::Distance,
        candidate_index: usize,
        _: &Item::UserData,
    ) {
        if self.max_item_count == 0 {
            return;
        }
        let has_room = self.distance_x_index.len() < self.max_item_count;
        if has_room || distance < self.max_observed_distance {
            self.insert_index(candidate_index, distance);
        }
    }

    /// Returns the radius the search may use to skip parts of the tree.
    ///
    /// While fewer than `max_item_count` candidates have been collected, every
    /// remaining item is still wanted, so the largest representable distance
    /// is reported. Reporting the farthest retained distance at that stage
    /// would let the search discard subtrees whose items are needed to fill
    /// the neighborhood, yielding fewer results than requested.
    ///
    /// Once the neighborhood is full (or when `max_item_count` is zero) the
    /// tracked `max_observed_distance` is returned unchanged.
    ///
    /// NOTE(review): the search presumably adds this value to other distances;
    /// that saturates for floating-point distances such as `f32`, but an
    /// integer `Distance` could overflow. Confirm before reusing this with
    /// integer metrics.
    #[inline]
    fn distance(&self) -> Item::Distance {
        if self.distance_x_index.len() < self.max_item_count {
            Item::Distance::max_value()
        } else {
            self.max_observed_distance
        }
    }

    /// Consumes the neighborhood and returns the retained candidate indices.
    fn result(self, _: &Item::UserData) -> Self::Output {
        self.distance_x_index
            .into_iter()
            .map(|(_, index)| index)
            .collect()
    }
}
/// Builds a small tree of 2-D points and checks that querying for 1, 2 and 10
/// neighbours of `[1.0, 2.0]` yields the expected index sets.
fn main() {
    let points = vec![
        PointN::new([2.0, 3.0]),
        PointN::new([0.0, 1.0]),
        PointN::new([4.0, 5.0]),
    ];
    let tree = vpsearch::Tree::new(&points);

    // Every query uses the same needle; only the requested count varies.
    let needle = PointN::new([1.0, 2.0]);
    let nearest = |count: usize| -> HashSet<usize> {
        tree.find_nearest_custom(&needle, &(), CountBasedNeighborhood::new(count))
    };

    // Only the size is checked for a single neighbour.
    let single = nearest(1);
    assert_eq!(single.len(), 1);

    let closest_two: HashSet<usize> = vec![0, 1].into_iter().collect();
    assert_eq!(nearest(2), closest_two);

    // Asking for more neighbours than there are points returns all of them.
    let every_point: HashSet<usize> = (0..3).collect();
    assert_eq!(nearest(10), every_point);
}