use super::distance::DistanceMetric;
use super::flat::DataRef;
use rand::prelude::*;
struct VpNode {
point_idx: usize,
threshold: f32,
left: Option<Box<VpNode>>,
right: Option<Box<VpNode>>,
}
pub(crate) struct VpTree<'a, D: DistanceMetric, T: DataRef + ?Sized> {
root: Option<Box<VpNode>>,
data: &'a T,
metric: &'a D,
}
#[allow(dead_code)]
impl<'a, D: DistanceMetric, T: DataRef + ?Sized> VpTree<'a, D, T> {
pub(crate) fn new(data: &'a T, metric: &'a D) -> Self {
let n = data.n();
let mut indices: Vec<usize> = (0..n).collect();
let mut rng = StdRng::seed_from_u64(42);
let root = Self::build(&mut indices, data, metric, &mut rng);
Self { root, data, metric }
}
fn build(indices: &mut [usize], data: &T, metric: &D, rng: &mut StdRng) -> Option<Box<VpNode>> {
if indices.is_empty() {
return None;
}
if indices.len() == 1 {
return Some(Box::new(VpNode {
point_idx: indices[0],
threshold: 0.0,
left: None,
right: None,
}));
}
let vp_pos = rng.random_range(0..indices.len());
indices.swap(0, vp_pos);
let vp_idx = indices[0];
let rest = &mut indices[1..];
let mut dists: Vec<(usize, f32)> = rest
.iter()
.map(|&i| (i, metric.distance(data.row(vp_idx), data.row(i))))
.collect();
let mid = dists.len() / 2;
dists.select_nth_unstable_by(mid, |a, b| a.1.total_cmp(&b.1));
let threshold = dists[mid].1;
let mut left_indices: Vec<usize> = Vec::with_capacity(mid + 1);
let mut right_indices: Vec<usize> = Vec::with_capacity(dists.len() - mid);
for &(idx, dist) in &dists {
if dist <= threshold {
left_indices.push(idx);
} else {
right_indices.push(idx);
}
}
let left = Self::build(&mut left_indices, data, metric, rng);
let right = Self::build(&mut right_indices, data, metric, rng);
Some(Box::new(VpNode {
point_idx: vp_idx,
threshold,
left,
right,
}))
}
pub(crate) fn range_query(&self, query: &[f32], radius: f32) -> Vec<usize> {
let mut results = Vec::new();
if let Some(ref root) = self.root {
self.range_search(root, query, radius, &mut results);
}
results
}
fn range_search(&self, node: &VpNode, query: &[f32], radius: f32, results: &mut Vec<usize>) {
let dist = self.metric.distance(query, self.data.row(node.point_idx));
if dist <= radius {
results.push(node.point_idx);
}
if dist - radius <= node.threshold {
if let Some(ref left) = node.left {
self.range_search(left, query, radius, results);
}
}
if dist + radius > node.threshold {
if let Some(ref right) = node.right {
self.range_search(right, query, radius, results);
}
}
}
pub(crate) fn knn(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
let mut heap: Vec<(usize, f32)> = Vec::with_capacity(k + 1);
let mut tau = f32::MAX; if let Some(ref root) = self.root {
self.knn_search(root, query, k, &mut heap, &mut tau);
}
heap.sort_by(|a, b| a.1.total_cmp(&b.1));
heap
}
fn knn_search(
&self,
node: &VpNode,
query: &[f32],
k: usize,
heap: &mut Vec<(usize, f32)>,
tau: &mut f32,
) {
let dist = self.metric.distance(query, self.data.row(node.point_idx));
if heap.len() < k {
heap.push((node.point_idx, dist));
if heap.len() == k {
*tau = heap.iter().map(|(_, d)| *d).fold(0.0f32, f32::max);
}
} else if dist < *tau {
let max_pos = heap
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.1.total_cmp(&b.1))
.map(|(i, _)| i)
.unwrap();
heap[max_pos] = (node.point_idx, dist);
*tau = heap.iter().map(|(_, d)| *d).fold(0.0f32, f32::max);
}
let (first, second) = if dist <= node.threshold {
(&node.left, &node.right)
} else {
(&node.right, &node.left)
};
if let Some(ref child) = first {
if dist <= node.threshold {
if dist - *tau <= node.threshold {
self.knn_search(child, query, k, heap, tau);
}
} else {
if dist + *tau > node.threshold {
self.knn_search(child, query, k, heap, tau);
}
}
}
if let Some(ref child) = second {
if dist <= node.threshold {
if dist + *tau > node.threshold {
self.knn_search(child, query, k, heap, tau);
}
} else {
if dist - *tau <= node.threshold {
self.knn_search(child, query, k, heap, tau);
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::cluster::distance::Euclidean;
#[test]
fn range_query_finds_neighbors() {
let data = vec![
vec![0.0, 0.0],
vec![0.1, 0.1],
vec![10.0, 10.0],
vec![10.1, 10.1],
];
let tree = VpTree::new(&data, &Euclidean);
let mut results = tree.range_query(&[0.0, 0.0], 0.5);
results.sort();
assert!(results.contains(&0));
assert!(results.contains(&1));
assert!(!results.contains(&2));
}
#[test]
fn knn_finds_nearest() {
let data = vec![
vec![0.0, 0.0],
vec![0.1, 0.1],
vec![1.0, 1.0],
vec![10.0, 10.0],
];
let tree = VpTree::new(&data, &Euclidean);
let results = tree.knn(&[0.0, 0.0], 2);
assert_eq!(results.len(), 2);
assert_eq!(results[0].0, 0); assert_eq!(results[1].0, 1); }
#[test]
fn empty_data() {
let data: Vec<Vec<f32>> = vec![];
let tree = VpTree::new(&data, &Euclidean);
assert!(tree.range_query(&[0.0], 1.0).is_empty());
assert!(tree.knn(&[0.0], 1).is_empty());
}
#[test]
fn single_point() {
let data = vec![vec![5.0, 5.0]];
let tree = VpTree::new(&data, &Euclidean);
let results = tree.range_query(&[5.0, 5.0], 0.1);
assert_eq!(results, vec![0]);
}
}