#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
#[cfg(feature = "std")]
use std::vec::Vec;
use core::cmp::Ordering::{self, Equal};
use num_traits::Float;
use crate::primitives::buffer::{NeighborhoodSearchBuffer, NeighborhoodStorage};
/// A search candidate: `NodeDistance(point_index, distance)`.
///
/// Equality and ordering look only at the distance (`.1`); the point index (`.0`)
/// never participates. Distances that cannot be compared (e.g. a float NaN) are
/// treated as `Equal` by `cmp`, which is what lets the type implement `Ord` and
/// live in a heap.
///
/// Caveat: for NaN distances `eq` reports `false` while `cmp` reports `Equal`,
/// so `Eq`/`Ord` consistency only holds for comparable distances.
#[derive(Debug, Clone, Copy)]
pub struct NodeDistance<T>(pub usize, pub T);

impl<T: PartialEq> PartialEq for NodeDistance<T> {
    fn eq(&self, rhs: &Self) -> bool {
        PartialEq::eq(&self.1, &rhs.1)
    }
}

impl<T: PartialEq> Eq for NodeDistance<T> {}

impl<T: PartialOrd> PartialOrd for NodeDistance<T> {
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        // Always defined: defer to the total order below.
        Some(Ord::cmp(self, rhs))
    }
}

impl<T: PartialOrd> Ord for NodeDistance<T> {
    fn cmp(&self, rhs: &Self) -> Ordering {
        match self.1.partial_cmp(&rhs.1) {
            Some(ordering) => ordering,
            // Incomparable distances are collapsed to "equal".
            None => Equal,
        }
    }
}
/// Distance metric plugged into [`KDTree::find_k_nearest`].
///
/// The search only ever compares values returned by `distance_squared` and
/// `split_distance_squared` against each other, and converts them to reported
/// distances with `post_process_distance`.
pub trait PointDistance<T: Float> {
    /// Distance between two points, in the internal (un-post-processed) units the
    /// search compares and keeps in its candidate heap.
    fn distance_squared(&self, a: &[T], b: &[T]) -> T;

    /// Distance from `query_val` to the splitting plane `split_val` along axis
    /// `dim`. Not called by the k-d tree code visible in this file.
    fn split_distance(&self, dim: usize, split_val: T, query_val: T) -> T;

    /// Same as `split_distance`, but in the units of `distance_squared`.
    ///
    /// The search skips the far subtree when this value is not below the current
    /// k-th best distance, so for exact results it should not overestimate the
    /// `distance_squared` to any point on the far side of the plane.
    fn split_distance_squared(&self, dim: usize, split_val: T, query_val: T) -> T;

    /// Maps an internal distance to the value reported to callers. It is applied to
    /// every returned distance and to the neighbourhood's maximum distance.
    fn post_process_distance(&self, d: T) -> T;
}
/// Output of a k-nearest-neighbour query.
///
/// `indices[j]` and `distances[j]` describe the same neighbour;
/// `KDTree::find_k_nearest` clears both vectors and then fills them together.
#[derive(Debug, Clone)]
pub struct Neighborhood<T> {
    /// Original (input-order) indices of the neighbours found.
    pub indices: Vec<usize>,
    /// Neighbour distances, already passed through `post_process_distance`.
    pub distances: Vec<T>,
    /// Post-processed distance of the farthest kept neighbour; zero when unset.
    pub max_distance: T,
}

impl<T: Float> Neighborhood<T> {
    /// Creates an empty neighbourhood without allocating.
    pub fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Creates an empty neighbourhood whose `indices` and `distances` can each
    /// hold `k` entries before reallocating.
    pub fn with_capacity(k: usize) -> Self {
        let indices = Vec::with_capacity(k);
        let distances = Vec::with_capacity(k);
        Neighborhood {
            indices,
            distances,
            max_distance: T::zero(),
        }
    }

    /// Number of neighbours currently stored.
    #[inline]
    pub fn len(&self) -> usize {
        Vec::len(&self.indices)
    }

    /// `true` when no neighbours are stored.
    #[inline]
    // May be unused inside the crate; kept as the companion of `len`
    // (Clippy's `len_without_is_empty`).
    #[allow(dead_code)]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
impl<T: Float> NeighborhoodStorage for Neighborhood<T> {
fn with_capacity(k: usize) -> Self {
Self::with_capacity(k)
}
fn capacity(&self) -> usize {
self.indices.capacity()
}
}
impl<T: Float> Default for Neighborhood<T> {
fn default() -> Self {
Self::new()
}
}
/// One slot of a `KDTree`'s implicit binary-tree array.
///
/// The point's coordinates live in the tree's permuted point buffer at the same
/// slot position; this struct only remembers where that point came from.
#[derive(Debug, Clone, Copy, Default)]
pub struct KDNode {
    /// Original (input-order) index of the point stored at this node; this is the
    /// value reported back in `Neighborhood::indices`.
    pub index: usize,
}
/// A k-d tree over fixed-dimension points, stored as an implicit complete binary
/// tree (array layout, no child pointers).
#[derive(Debug, Clone)]
pub struct KDTree<T: Float> {
    /// Nodes in array order; the children of node `i` are at `2 * i + 1` and
    /// `2 * i + 2`.
    nodes: Vec<KDNode>,
    /// Point coordinates permuted into node order, row-major: node `i` owns
    /// `points[i * dimensions..(i + 1) * dimensions]`.
    points: Vec<T>,
    /// Number of coordinates per point.
    dimensions: usize,
}
impl<T: Float> KDTree<T> {
    /// Builds a balanced k-d tree over `points`.
    ///
    /// `points` is a flat, row-major buffer: point `i` occupies
    /// `points[i * dimensions..(i + 1) * dimensions]`. Only whole points are
    /// indexed; a trailing partial point (when `points.len()` is not a multiple of
    /// `dimensions`) is ignored. Neighbours are later reported by these original
    /// point indices.
    ///
    /// The tree is stored as a complete binary tree in array form: the children of
    /// node `i` are at `2 * i + 1` and `2 * i + 2`, and a node at depth `k` splits
    /// on axis `k % dimensions`.
    ///
    /// # Panics
    ///
    /// Panics if `dimensions` is zero.
    pub fn new(points: &[T], dimensions: usize) -> Self {
        assert!(dimensions > 0, "KDTree::new: `dimensions` must be non-zero");
        let n = points.len() / dimensions;
        let mut indices: Vec<usize> = (0..n).collect();
        let mut nodes = vec![KDNode::default(); n];
        // Sized for whole points only, so the node and point arrays stay in lockstep.
        let mut permuted_points = vec![T::zero(); n * dimensions];
        Self::build_recursive(
            points,
            dimensions,
            &mut indices,
            0,
            &mut nodes,
            &mut permuted_points,
            0,
        );
        Self {
            nodes,
            points: permuted_points,
            dimensions,
        }
    }

    /// Reassembles a tree from previously built parts.
    ///
    /// No validation happens here. `nodes` is expected to be in the array layout
    /// produced by [`KDTree::new`], with node `i`'s point at
    /// `points[i * dimensions..(i + 1) * dimensions]`. Searching a tree whose
    /// `points` holds fewer than `nodes.len() * dimensions` values panics.
    pub fn from_parts(nodes: Vec<KDNode>, points: Vec<T>, dimensions: usize) -> Self {
        Self {
            nodes,
            points,
            dimensions,
        }
    }

    /// Finds up to `k` points nearest to `query` and writes them into `neighborhood`.
    ///
    /// * `query` - coordinates of the query point (should have `dimensions` values).
    /// * `k` - maximum number of neighbours; fewer are returned when the tree has
    ///   fewer eligible points.
    /// * `dist_calc` - metric used for point distances, plane-distance pruning and
    ///   the final distance transform (`post_process_distance`).
    /// * `exclude_self` - if `Some(i)`, the point with original index `i` is skipped.
    /// * `buffer` - reusable scratch space; it is cleared before a search runs.
    ///
    /// On return, `neighborhood.indices[j]` and `neighborhood.distances[j]` describe
    /// the same neighbour, in the order `buffer.heap.iter()` yields them (no sorting
    /// is done here). `max_distance` is the post-processed distance of the heap's
    /// top element, which is the farthest kept neighbour given the max-heap
    /// ordering the search relies on. If the heap ends up empty, it is
    /// `post_process_distance(T::zero())`. If `k == 0` or the tree is empty, both
    /// vectors are emptied and `max_distance` is `T::zero()`.
    ///
    /// # Panics
    ///
    /// Panics if the tree's point buffer is shorter than `nodes.len() * dimensions`
    /// (only possible for trees assembled with `from_parts`).
    pub fn find_k_nearest<D: PointDistance<T>>(
        &self,
        query: &[T],
        k: usize,
        dist_calc: &D,
        exclude_self: Option<usize>,
        buffer: &mut NeighborhoodSearchBuffer<NodeDistance<T>>,
        neighborhood: &mut Neighborhood<T>,
    ) {
        neighborhood.indices.clear();
        neighborhood.distances.clear();
        if k == 0 || self.nodes.is_empty() {
            neighborhood.max_distance = T::zero();
            return;
        }
        buffer.clear();
        self.search_iterative(query, k, dist_calc, exclude_self, buffer);
        for &NodeDistance(idx, dist) in buffer.heap.iter() {
            neighborhood.indices.push(idx);
            neighborhood
                .distances
                .push(dist_calc.post_process_distance(dist));
        }
        let raw_max = buffer.heap.peek().map_or(T::zero(), |nd| nd.1);
        neighborhood.max_distance = dist_calc.post_process_distance(raw_max);
    }

    /// Recursively builds the subtree rooted at array slot `curr_idx` from the
    /// original point indices in `indices`.
    ///
    /// The split point is the element of rank
    /// `calculate_left_subtree_size(indices.len())` along axis `depth % dims`,
    /// not the plain middle. That gives the left subtree exactly as many points
    /// as the complete-binary-tree layout assigns to it, so every node lands in
    /// `0..nodes.len()`. The chosen point's coordinates are copied to
    /// `permuted_points` at slot `curr_idx`.
    fn build_recursive(
        points: &[T],
        dims: usize,
        indices: &mut [usize],
        depth: usize,
        nodes: &mut [KDNode],
        permuted_points: &mut [T],
        curr_idx: usize,
    ) {
        if indices.is_empty() {
            return;
        }
        let axis = depth % dims;
        let n = indices.len();
        // Rank of the splitting point: everything before it goes left. For n >= 1
        // this is always < n, so the indexing below is in bounds.
        let median_idx = Self::calculate_left_subtree_size(n);
        indices.select_nth_unstable_by(median_idx, |&a, &b| {
            points[a * dims + axis]
                .partial_cmp(&points[b * dims + axis])
                .unwrap_or(Equal)
        });

        let point_idx = indices[median_idx];
        nodes[curr_idx] = KDNode { index: point_idx };
        let src_start = point_idx * dims;
        let dest_start = curr_idx * dims;
        permuted_points[dest_start..dest_start + dims]
            .copy_from_slice(&points[src_start..src_start + dims]);

        let (left_part, median_and_right) = indices.split_at_mut(median_idx);
        let right_part = &mut median_and_right[1..];
        Self::build_recursive(
            points,
            dims,
            left_part,
            depth + 1,
            nodes,
            permuted_points,
            2 * curr_idx + 1,
        );
        Self::build_recursive(
            points,
            dims,
            right_part,
            depth + 1,
            nodes,
            permuted_points,
            2 * curr_idx + 2,
        );
    }

    /// Non-recursive depth-first k-NN search that fills `buffer.heap` with up to
    /// `k` candidates carrying raw (not post-processed) distances.
    ///
    /// Stack entries pack a node index and its splitting axis as
    /// `(node_idx << 8) | axis`. Eight bits are ample: a node's axis is
    /// `depth % d <= depth`, and an in-memory complete binary tree is less than
    /// `usize::BITS` levels deep.
    ///
    /// # Panics
    ///
    /// Panics if `self.points` holds fewer than `nodes.len() * dimensions` values.
    #[inline]
    fn search_iterative<D: PointDistance<T>>(
        &self,
        query: &[T],
        k: usize,
        dist_calc: &D,
        exclude_self: Option<usize>,
        buffer: &mut NeighborhoodSearchBuffer<NodeDistance<T>>,
    ) {
        let d = self.dimensions;
        let nodes_len = self.nodes.len();
        // Validate the point buffer once per query. Trees from `new` always satisfy
        // this, but `from_parts` accepts arbitrary vectors, and the unchecked reads
        // below must never leave `points`.
        let points = nodes_len
            .checked_mul(d)
            .and_then(|len| self.points.get(..len))
            .expect("KDTree: `points` must hold at least `nodes.len() * dimensions` values");
        let heap = &mut buffer.heap;
        let stack = &mut buffer.stack;
        let mut heap_full = false;
        let mut max_dist = T::infinity();

        if nodes_len > 0 {
            // Root: node 0 splitting on axis 0.
            stack.push(0);
        }
        while let Some(packed) = stack.pop() {
            let axis = packed & 0xFF;
            let node_idx = packed >> 8;
            // SAFETY: every stack entry is built from a node index checked to be
            // `< nodes_len` (the root only when `nodes_len > 0`), and `>> 8` recovers
            // that index (or a smaller one, should `<< 8` ever have dropped high bits).
            let node = unsafe { self.nodes.get_unchecked(node_idx) };
            let offset = node_idx * d;
            // SAFETY: `node_idx < nodes_len`, so
            // `offset + d <= nodes_len * d == points.len()` (validated above).
            let node_point = unsafe { points.get_unchecked(offset..offset + d) };

            if exclude_self != Some(node.index) {
                let dist = dist_calc.distance_squared(query, node_point);
                if !heap_full {
                    heap.push(NodeDistance(node.index, dist));
                    if heap.len() == k {
                        heap_full = true;
                        max_dist = heap.peek().map_or(T::infinity(), |nd| nd.1);
                    }
                } else if dist < max_dist {
                    // Replace the current worst candidate in place (for a
                    // `BinaryHeap`, `PeekMut` restores heap order when dropped).
                    if let Some(mut top) = heap.peek_mut() {
                        *top = NodeDistance(node.index, dist);
                    }
                    max_dist = heap.peek().map_or(T::infinity(), |nd| nd.1);
                }
            }

            let left_child = 2 * node_idx + 1;
            if left_child >= nodes_len {
                // Leaf: with the array layout there is no right child either.
                continue;
            }
            let right_child = left_child + 1;
            let has_right = right_child < nodes_len;
            let split_dim = axis;
            // Checked indexing: also guarantees `split_dim < d` for `next_axis` below.
            let split_val = node_point[split_dim];
            let query_val = query[split_dim];
            let next_axis = if split_dim + 1 == d { 0 } else { split_dim + 1 };
            let packed_left = (left_child << 8) | next_axis;
            let packed_right = (right_child << 8) | next_axis;

            // The near child is the one on the query's side of the splitting plane.
            let go_left = query_val - split_val <= T::zero();
            let (near_packed, near_exists, far_packed, far_exists) = if go_left {
                (packed_left, true, packed_right, has_right)
            } else {
                (packed_right, has_right, packed_left, true)
            };

            // Push the far side first so the near side is popped (explored) first.
            if far_exists {
                let dist_to_plane =
                    dist_calc.split_distance_squared(split_dim, split_val, query_val);
                if !heap_full || dist_to_plane < max_dist {
                    stack.push(far_packed);
                }
            }
            if near_exists {
                stack.push(near_packed);
            }
        }
    }

    /// Returns how many of `n` nodes belong to the root's left subtree when the
    /// nodes form a complete binary tree. In such a tree every level is full
    /// except possibly the last, which is filled from the left.
    ///
    /// For example, `n = 4` gives `2`, `n = 5` gives `3`, and `n <= 1` gives `0`.
    pub fn calculate_left_subtree_size(n: usize) -> usize {
        if n == 0 {
            return 0;
        }
        // Index of the last (possibly partial) level, i.e. floor(log2(n)).
        let h = (usize::BITS - n.leading_zeros() - 1) as usize;
        if h == 0 {
            return 0;
        }
        // A full last level would hold 2^h nodes; the levels above it hold 2^h - 1.
        let max_leaf_capacity = 1 << h;
        let total_nodes_above_leaf = max_leaf_capacity - 1;
        // Nodes actually present on the last level.
        let r = n - total_nodes_above_leaf;
        // The left half of the last level fills first.
        let left_part_leaves = r.min(max_leaf_capacity / 2);
        // The left subtree's full levels contain 2^(h-1) - 1 nodes.
        let left_subtree_capacity_full = (max_leaf_capacity / 2) - 1;
        left_subtree_capacity_full + left_part_leaves
    }
}