use az::{Az, Cast};
use std::collections::BinaryHeap;
use std::ops::Rem;
use crate::float_sss::{
kdtree::{Axis, KdTree},
neighbour::Neighbour,
};
use crate::types::{Content, Index};
impl<A: Axis, T: Content, const K: usize, const B: usize, IDX: Index<T = IDX>>
    KdTree<A, T, K, B, IDX>
where
    usize: Cast<IDX>,
{
    /// Finds every item whose distance from `query`, as measured by `distance_fn`, is
    /// less than or equal to `dist`.
    ///
    /// Equivalent to [`Self::within_exclusive`] with `inclusive` set to `true`; see that
    /// method for the requirements on `distance_fn` and the ordering of the results.
    #[inline]
    pub fn within<F>(&self, query: &[A; K], dist: A, distance_fn: &F) -> Vec<Neighbour<A, T>>
    where
        F: Fn(&[A; K], &[A; K]) -> A,
    {
        self.within_exclusive(query, dist, distance_fn, true)
    }

    /// Finds every item whose distance from `query` is within `dist`.
    ///
    /// * `query` - the point to search around.
    /// * `dist` - the search radius, in the same units that `distance_fn` returns.
    /// * `distance_fn` - the distance between two points. Subtrees are pruned by
    ///   evaluating it between `query` and a point that differs from `query` only on the
    ///   axes of the splitting planes crossed so far. It must therefore never decrease
    ///   when the absolute difference along any single axis grows, which holds for
    ///   Manhattan, Euclidean and squared-Euclidean distance.
    /// * `inclusive` - when `true`, an item at exactly `dist` is included (`<=`); when
    ///   `false`, it is excluded (`<`).
    ///
    /// The results come from [`BinaryHeap::into_sorted_vec`], so they are in ascending
    /// order of `Neighbour`'s `Ord` implementation (presumably nearest first; confirm
    /// against `Neighbour`).
    #[inline]
    pub fn within_exclusive<F>(
        &self,
        query: &[A; K],
        dist: A,
        distance_fn: &F,
        inclusive: bool,
    ) -> Vec<Neighbour<A, T>>
    where
        F: Fn(&[A; K], &[A; K]) -> A,
    {
        // Closest position the current subtree could contain; before any split has been
        // crossed, that is the query point itself.
        let mut nearest = *query;
        let mut matching_items: BinaryHeap<Neighbour<A, T>> = BinaryHeap::new();
        // SAFETY: `within_recurse` indexes `stems` / `leaves` without bounds checks. It
        // starts at `self.root_index` and only follows the `left` / `right` links stored
        // in the tree. Soundness therefore relies on the tree-building code keeping those
        // links valid. NOTE(review): that invariant is maintained outside this file;
        // confirm it there.
        unsafe {
            self.within_recurse(
                query,
                dist,
                distance_fn,
                self.root_index,
                0,
                &mut matching_items,
                &mut nearest,
                inclusive,
            );
        }
        matching_items.into_sorted_vec()
    }

    /// Recursively pushes every item of the subtree rooted at `curr_node_idx` that lies
    /// within `radius` of `query` onto `matching_items`.
    ///
    /// `split_dim` is the axis the node at `curr_node_idx` splits on; children split on
    /// the next axis, cycling modulo `K`.
    ///
    /// `nearest` holds, per axis, the closest coordinate to `query` that the current
    /// subtree can contain, as far as the splitting planes crossed so far can tell. On
    /// axes that have not been crossed, it equals `query`. Any entry that is modified is
    /// restored before returning, so the caller sees `nearest` unchanged.
    ///
    /// # Safety
    ///
    /// `curr_node_idx` must refer to an existing stem in `self.stems`, or (after
    /// subtracting `IDX::leaf_offset()`) to an existing leaf in `self.leaves`. The same
    /// must hold for the `left` / `right` links of every stem reachable from it.
    unsafe fn within_recurse<F>(
        &self,
        query: &[A; K],
        radius: A,
        distance_fn: &F,
        curr_node_idx: IDX,
        split_dim: usize,
        matching_items: &mut BinaryHeap<Neighbour<A, T>>,
        nearest: &mut [A; K],
        inclusive: bool,
    ) where
        F: Fn(&[A; K], &[A; K]) -> A,
    {
        let in_range = |distance: A| {
            if inclusive {
                distance <= radius
            } else {
                distance < radius
            }
        };

        if KdTree::<A, T, K, B, IDX>::is_stem_index(curr_node_idx) {
            // SAFETY: the caller guarantees `curr_node_idx` is a valid stem index.
            let node = unsafe { self.stems.get_unchecked(curr_node_idx.az::<usize>()) };
            let [closer_node_idx, further_node_idx] = if query[split_dim] < node.split_val {
                [node.left, node.right]
            } else {
                [node.right, node.left]
            };
            let next_split_dim = (split_dim + 1).rem(K);

            // The query lies on the closer side, so that subtree can never be pruned.
            // SAFETY: links of a valid stem are valid indices (caller's contract).
            unsafe {
                self.within_recurse(
                    query,
                    radius,
                    distance_fn,
                    closer_node_idx,
                    next_split_dim,
                    matching_items,
                    nearest,
                    inclusive,
                );
            }

            // Every point on the far side is at least as far from `query`, on each axis,
            // as `nearest` moved onto this splitting plane. Evaluating `distance_fn` on
            // that point, rather than summing squared offsets, keeps the pruning bound
            // valid for any per-axis-monotone distance, not just squared Euclidean.
            let old_nearest = nearest[split_dim];
            nearest[split_dim] = node.split_val;
            if in_range(distance_fn(query, &*nearest)) {
                // SAFETY: links of a valid stem are valid indices (caller's contract).
                unsafe {
                    self.within_recurse(
                        query,
                        radius,
                        distance_fn,
                        further_node_idx,
                        next_split_dim,
                        matching_items,
                        nearest,
                        inclusive,
                    );
                }
            }
            nearest[split_dim] = old_nearest;
        } else {
            // SAFETY: the caller guarantees `curr_node_idx` is a valid leaf index once the
            // leaf offset is removed.
            let leaf_node = unsafe {
                self.leaves
                    .get_unchecked((curr_node_idx - IDX::leaf_offset()).az::<usize>())
            };
            let size = leaf_node.size.az::<usize>();
            // Only the first `size` slots of a leaf are populated. Zipping points with
            // items keeps both accesses bounds-safe.
            leaf_node
                .content_points
                .iter()
                .zip(leaf_node.content_items.iter())
                .take(size)
                .for_each(|(point, item)| {
                    let distance = distance_fn(query, point);
                    if in_range(distance) {
                        matching_items.push(Neighbour {
                            distance,
                            item: *item,
                        });
                    }
                });
        }
    }
}
#[cfg(test)]
mod tests {
    // NOTE(review): `KdTree` and `Axis` are imported from `crate::float`, not from
    // `crate::float_sss`. These tests therefore exercise the `float` tree's `within`
    // rather than the implementation in this file. Switch the import to
    // `crate::float_sss::kdtree` once it is confirmed that that `KdTree` provides `new`,
    // `with_capacity`, `add` and `size`.
    use crate::float::distance::manhattan;
    use crate::float::kdtree::{Axis, KdTree};
    use rand::Rng;
    use rstest::rstest;
    use std::cmp::Ordering;

    type AX = f32;

    #[test]
    fn can_query_items_within_radius() {
        let mut tree: KdTree<AX, u32, 4, 5, u32> = KdTree::new();

        let content_to_add: [([AX; 4], u32); 16] = [
            ([0.9f32, 0.0f32, 0.9f32, 0.0f32], 9),
            ([0.4f32, 0.5f32, 0.4f32, 0.5f32], 4),
            ([0.12f32, 0.3f32, 0.12f32, 0.3f32], 12),
            ([0.7f32, 0.2f32, 0.7f32, 0.2f32], 7),
            ([0.13f32, 0.4f32, 0.13f32, 0.4f32], 13),
            ([0.6f32, 0.3f32, 0.6f32, 0.3f32], 6),
            ([0.2f32, 0.7f32, 0.2f32, 0.7f32], 2),
            ([0.14f32, 0.5f32, 0.14f32, 0.5f32], 14),
            ([0.3f32, 0.6f32, 0.3f32, 0.6f32], 3),
            ([0.10f32, 0.1f32, 0.10f32, 0.1f32], 10),
            ([0.16f32, 0.7f32, 0.16f32, 0.7f32], 16),
            ([0.1f32, 0.8f32, 0.1f32, 0.8f32], 1),
            ([0.15f32, 0.6f32, 0.15f32, 0.6f32], 15),
            ([0.5f32, 0.4f32, 0.5f32, 0.4f32], 5),
            ([0.8f32, 0.1f32, 0.8f32, 0.1f32], 8),
            ([0.11f32, 0.2f32, 0.11f32, 0.2f32], 11),
        ];

        for (point, item) in content_to_add {
            tree.add(&point, item);
        }

        assert_eq!(tree.size(), 16);

        let query_point = [0.78f32, 0.55f32, 0.78f32, 0.55f32];
        let radius = 0.2;
        let expected = linear_search(&content_to_add, &query_point, radius);

        let mut result: Vec<_> = tree
            .within(&query_point, radius, &manhattan)
            .into_iter()
            .map(|n| (n.distance, n.item))
            .collect();
        stabilize_sort(&mut result);
        assert_eq!(result, expected);

        let mut rng = rand::rng();
        for _i in 0..1000 {
            let query_point = [
                rng.random_range(0f32..1f32),
                rng.random_range(0f32..1f32),
                rng.random_range(0f32..1f32),
                rng.random_range(0f32..1f32),
            ];
            let radius: f32 = 2.0;
            let expected = linear_search(&content_to_add, &query_point, radius);

            let mut result: Vec<_> = tree
                .within(&query_point, radius, &manhattan)
                .into_iter()
                .map(|n| (n.distance, n.item))
                .collect();
            stabilize_sort(&mut result);

            assert_eq!(result, expected);
        }
    }

    #[test]
    fn can_query_items_within_radius_large_scale() {
        const TREE_SIZE: usize = 100_000;
        const NUM_QUERIES: usize = 100;
        const RADIUS: f32 = 0.2;

        let content_to_add: Vec<([f32; 4], u32)> = (0..TREE_SIZE)
            .map(|_| rand::random::<([f32; 4], u32)>())
            .collect();

        let mut tree: KdTree<AX, u32, 4, 32, u32> = KdTree::with_capacity(TREE_SIZE);
        content_to_add
            .iter()
            .for_each(|(point, content)| tree.add(point, *content));
        assert_eq!(tree.size(), TREE_SIZE as u32);

        let query_points: Vec<[f32; 4]> = (0..NUM_QUERIES)
            .map(|_| rand::random::<[f32; 4]>())
            .collect();

        for query_point in query_points {
            let expected = linear_search(&content_to_add, &query_point, RADIUS);

            let mut result: Vec<_> = tree
                .within(&query_point, RADIUS, &manhattan)
                .into_iter()
                .map(|n| (n.distance, n.item))
                .collect();
            // `linear_search` breaks distance ties by item. Apply the same ordering to the
            // tree's output so equal-distance matches cannot make the comparison flaky.
            stabilize_sort(&mut result);

            assert_eq!(result, expected);
        }
    }

    #[rstest]
    #[case(true, 1)]
    #[case(false, 0)]
    fn test_within_boundary_inclusiveness(#[case] inclusive: bool, #[case] expected_len: usize) {
        let mut kdtree: KdTree<f32, u32, 2, 5, u32> = KdTree::new();
        kdtree.add(&[1.0, 0.0], 1);
        kdtree.add(&[2.0, 0.0], 2);

        let query = [0.0, 0.0];
        let radius = 1.0;

        // Squared-Euclidean distance: the item at (1, 0) sits exactly on the radius.
        let results = kdtree.within_exclusive(
            &query,
            radius,
            &|a, b| {
                let mut dist = 0.0;
                for i in 0..2 {
                    dist += (a[i] - b[i]) * (a[i] - b[i]);
                }
                dist
            },
            inclusive,
        );

        assert_eq!(results.len(), expected_len);
        if expected_len > 0 {
            assert_eq!(results[0].item, 1);
            assert_eq!(results[0].distance, 1.0);
        }
    }

    /// Brute-force reference for `within`.
    ///
    /// Returns every `(distance, item)` in `content` whose Manhattan distance from
    /// `query_point` is `<= radius`, ordered by [`stabilize_sort`].
    fn linear_search<A: Axis, const K: usize>(
        content: &[([A; K], u32)],
        query_point: &[A; K],
        radius: A,
    ) -> Vec<(A, u32)> {
        let mut matching_items = vec![];

        for &(p, item) in content {
            let dist = manhattan(query_point, &p);
            if dist <= radius {
                matching_items.push((dist, item));
            }
        }

        stabilize_sort(&mut matching_items);

        matching_items
    }

    /// Sorts `(distance, item)` pairs by distance, then by item.
    ///
    /// The secondary key makes results with tied distances compare deterministically.
    /// Panics if any distance is NaN, because `partial_cmp` returns `None` and is unwrapped.
    fn stabilize_sort<A: Axis>(matching_items: &mut Vec<(A, u32)>) {
        matching_items.sort_unstable_by(|a, b| {
            let dist_cmp = a.0.partial_cmp(&b.0).unwrap();
            if dist_cmp == Ordering::Equal {
                a.1.cmp(&b.1)
            } else {
                dist_cmp
            }
        });
    }
}