use crate::utils::{dist, theiler_exclude};
use std::cmp::Ordering;
/// One node of a vantage-point tree.
///
/// A node owns its two optional subtrees; a node with neither is a leaf.
#[derive(Debug, Clone)]
pub struct VpNode {
    /// Index (into the tree's data slice) of the point used as this node's
    /// vantage point.
    pub idx: usize,
    /// Distance threshold that separates the two subtrees; `0.0` on a node
    /// built from a single index.
    pub tau: f64,
    /// Subtree of points whose distance to the vantage point is `<= tau`.
    pub left: Option<Box<VpNode>>,
    /// Subtree of points whose distance to the vantage point is `>= tau`.
    pub right: Option<Box<VpNode>>,
}
/// A vantage-point tree over a borrowed set of points.
///
/// The tree stores only indices; the coordinates stay in the borrowed slice,
/// which must therefore outlive the tree (lifetime `'a`).
#[derive(Debug)]
pub struct VpTree<'a> {
    /// The points the node indices refer to.
    data: &'a [Vec<f64>],
    /// Root node, or `None` when the tree was built from no indices.
    root: Option<Box<VpNode>>,
}
impl<'a> VpTree<'a> {
    /// Builds a vantage-point tree over the points of `data` selected by
    /// `indices`.
    ///
    /// `indices` is only read; the mutable borrow is kept so existing callers
    /// continue to compile. An empty `indices` yields an empty tree.
    ///
    /// # Panics
    ///
    /// Panics if any entry of `indices` is out of bounds for `data`.
    pub fn build(data: &'a [Vec<f64>], indices: &mut [usize]) -> Self {
        let root = Self::build_rec(data, indices);
        Self { data, root }
    }

    /// Orders `(index, distance)` pairs by distance.
    ///
    /// Uses `f64::total_cmp` so the comparator is a total order even if a
    /// distance is NaN, which the selection routine requires.
    fn by_distance(a: &(usize, f64), b: &(usize, f64)) -> Ordering {
        a.1.total_cmp(&b.1)
    }

    /// Recursively builds the subtree for `indices`.
    ///
    /// The last index becomes the vantage point. The remaining indices are
    /// partially ordered by their distance to it and split at position
    /// `len / 2`; the distance found there becomes `tau`.
    fn build_rec(data: &[Vec<f64>], indices: &[usize]) -> Option<Box<VpNode>> {
        let (&vp, rest) = indices.split_last()?;
        if rest.is_empty() {
            return Some(Box::new(VpNode {
                idx: vp,
                tau: 0.0,
                left: None,
                right: None,
            }));
        }

        let mut dists: Vec<(usize, f64)> = rest
            .iter()
            .map(|&j| (j, dist(&data[vp], &data[j])))
            .collect();
        let mid = dists.len() / 2;
        dists.select_nth_unstable_by(mid, Self::by_distance);
        let tau = dists[mid].1;

        // Split by position rather than by comparing against `tau`. After the
        // selection everything up to `mid` is `<= tau` and everything after it
        // is `>= tau`, so both halves keep the bounds the search relies on,
        // while tied distances no longer all pile into the inner side (which
        // made the tree a chain of depth O(n) for duplicate points).
        let (near, far) = dists.split_at(mid + 1);
        let inner: Vec<usize> = near.iter().map(|&(j, _)| j).collect();
        let outer: Vec<usize> = far.iter().map(|&(j, _)| j).collect();

        let left = Self::build_rec(data, &inner);
        let right = Self::build_rec(data, &outer);
        Some(Box::new(VpNode {
            idx: vp,
            tau,
            left,
            right,
        }))
    }

    /// Finds the point closest to `q`, skipping `target_i` itself and every
    /// index `j` for which `theiler_exclude(target_i, j, theiler)` is true.
    ///
    /// Returns the index of the winning point and its distance to `q`, or
    /// `None` when no admissible point exists. A candidate is accepted only
    /// when its distance is strictly below the current best, which starts at
    /// infinity, so infinite or NaN distances are never returned.
    pub fn nearest_excluding(
        &self,
        q: &[f64],
        target_i: usize,
        theiler: usize,
    ) -> Option<(usize, f64)> {
        let mut best: Option<(usize, f64)> = None;
        self.search(self.root.as_deref(), q, target_i, theiler, &mut best);
        best
    }

    /// Recursive worker for [`VpTree::nearest_excluding`].
    ///
    /// `best` carries the best admissible `(index, distance)` found so far.
    /// The pruning step assumes `dist` obeys the triangle inequality.
    fn search(
        &self,
        node: Option<&VpNode>,
        q: &[f64],
        target_i: usize,
        theiler: usize,
        best: &mut Option<(usize, f64)>,
    ) {
        let Some(n) = node else {
            return;
        };

        let d = dist(q, &self.data[n.idx]);
        let admissible = n.idx != target_i && !theiler_exclude(target_i, n.idx, theiler);
        if admissible && d < best.map_or(f64::INFINITY, |(_, bd)| bd) {
            *best = Some((n.idx, d));
        }

        // Descend first into the side the query itself falls on.
        let (first, second) = if d < n.tau || n.right.is_none() {
            (n.left.as_deref(), n.right.as_deref())
        } else {
            (n.right.as_deref(), n.left.as_deref())
        };
        self.search(first, q, target_i, theiler, best);

        // The other side can only hold a closer point if the current search
        // radius reaches across the `tau` boundary.
        let radius = best.map_or(f64::INFINITY, |(_, bd)| bd);
        if second.is_some() && (d - n.tau).abs() <= radius {
            self.search(second, q, target_i, theiler, best);
        }
    }

    /// Walks the whole tree and returns node count, leaf count and maximum
    /// depth (the root is at depth 0). An empty tree yields all zeros.
    pub fn stats(&self) -> TreeStats {
        let mut stats = TreeStats::default();
        Self::compute_stats(self.root.as_deref(), 0, &mut stats);
        stats
    }

    /// Adds the subtree rooted at `node`, located at `depth`, to `stats`.
    fn compute_stats(node: Option<&VpNode>, depth: usize, stats: &mut TreeStats) {
        let Some(n) = node else {
            return;
        };
        stats.node_count += 1;
        stats.max_depth = stats.max_depth.max(depth);
        if n.left.is_none() && n.right.is_none() {
            stats.leaf_count += 1;
        }
        Self::compute_stats(n.left.as_deref(), depth + 1, stats);
        Self::compute_stats(n.right.as_deref(), depth + 1, stats);
    }
}
/// Summary figures describing the shape of a [`VpTree`].
///
/// The default value (all zeros) is what an empty tree reports.
#[derive(Debug, Default, Clone)]
pub struct TreeStats {
    /// Total number of nodes in the tree.
    pub node_count: usize,
    /// Number of nodes that have neither a left nor a right child.
    pub leaf_count: usize,
    /// Depth of the deepest node, counting the root as depth 0.
    pub max_depth: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Building over a small, non-empty point set must produce a root node.
    #[test]
    fn test_vp_tree_build() {
        let points = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let mut order: Vec<usize> = (0..points.len()).collect();

        let tree = VpTree::build(&points, &mut order);

        assert!(tree.root.is_some());
    }

    /// With a Theiler window of 0, a query next to point 0 that excludes
    /// index 0 must resolve to index 1.
    #[test]
    fn test_nearest_excluding() {
        let points: Vec<Vec<f64>> = [0.0, 1.0, 2.0, 3.0]
            .iter()
            .map(|&x| vec![x, 0.0])
            .collect();
        let mut order: Vec<usize> = (0..points.len()).collect();
        let tree = VpTree::build(&points, &mut order);

        let probe = [0.1, 0.0];
        let winner = tree.nearest_excluding(&probe, 0, 0).map(|(i, _)| i);

        assert_eq!(winner, Some(1));
    }

    /// With a Theiler window of 2 around index 0, the geometrically closest
    /// point (index 1) is expected to be skipped, presumably by
    /// `theiler_exclude`, so index 2 must be returned.
    #[test]
    fn test_theiler_exclusion() {
        let points = vec![vec![0.0, 0.0], vec![0.1, 0.0], vec![2.0, 0.0]];
        let mut order: Vec<usize> = (0..points.len()).collect();
        let tree = VpTree::build(&points, &mut order);

        let probe = [0.0, 0.0];
        let winner = tree.nearest_excluding(&probe, 0, 2).map(|(i, _)| i);

        assert_eq!(winner, Some(2));
    }
}