use std::cmp::Ordering;
use ndarray::{Array1, Array2, ArrayBase, ArrayView2, AsArray, Axis, Ix1, Ix2, ViewRepr};
use crate::prelude::*;
/// A k-d tree spatial index over the rows of a borrowed 2-D point cloud.
///
/// The tree does not copy coordinates: each [`Node`] stores the row index of its
/// point in `cloud`, and child links are indices into `nodes`. Cloning a tree is
/// therefore cheap (the view is copied, the cloud data is not).
#[derive(Debug, Clone)]
pub struct KDTree<'a, T> {
    /// The point cloud: one point per row, one coordinate axis per column.
    pub cloud: ArrayView2<'a, T>,
    /// Flat node storage; the `left`/`right` links of each [`Node`] index into it.
    pub nodes: Vec<Node>,
    /// Index of the root node in `nodes`, or `None` when the cloud has no rows.
    pub root: Option<usize>,
}
/// One node of a [`KDTree`], stored in the tree's flat `nodes` vector.
///
/// A node is a handful of indices, so it is `Copy` and can be compared/hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Node {
    /// Column (coordinate axis) on which this node partitions its subtree.
    pub split_axis: usize,
    /// Row index, in the tree's cloud, of the point stored at this node.
    pub point_index: usize,
    /// Index in `nodes` of the subtree ordered at or below this node's point
    /// along `split_axis`, if that subtree is non-empty.
    pub left: Option<usize>,
    /// Index in `nodes` of the subtree ordered at or above this node's point
    /// along `split_axis`, if that subtree is non-empty.
    pub right: Option<usize>,
}
impl<'a, T> KDTree<'a, T>
where
    T: 'a + AsNumeric,
{
    /// Builds a balanced k-d tree over the rows of `cloud`.
    ///
    /// Each row of `cloud` is one point and each column is one coordinate axis.
    /// The tree stores row indices instead of copying coordinates, so it borrows
    /// `cloud` for `'a`. Split axes cycle through the columns by depth
    /// (`depth % n_columns`), and each node holds the median point of its subtree
    /// along its split axis.
    ///
    /// # Panics
    ///
    /// Panics if `cloud` has at least one row but zero columns, because the split
    /// axis is computed as `depth % 0`.
    pub fn build<A>(cloud: A) -> Self
    where
        A: AsArray<'a, T, Ix2>,
    {
        let cloud: ArrayBase<ViewRepr<&'a T>, Ix2> = cloud.into();
        let total_points = cloud.dim().0;
        let mut tree = Self {
            cloud,
            nodes: Vec::with_capacity(total_points),
            root: None,
        };
        let mut indices: Vec<usize> = (0..total_points).collect();
        tree.root = tree.recursive_build(&mut indices, 0);
        tree
    }

    /// Returns the coordinates of every point within `radius` of `query`.
    ///
    /// This runs [`Self::search_for_indices`] and gathers the matching rows of the
    /// cloud, in the same order as those indices. The result has one row per match
    /// and the same number of columns as the cloud.
    ///
    /// # Errors
    ///
    /// Returns [`ImgalError::MismatchedArrayLengths`] if `query.len()` differs from
    /// the number of columns in the cloud.
    pub fn search_for_coords<'b, B>(&self, query: B, radius: f64) -> Result<Array2<T>, ImgalError>
    where
        B: AsArray<'b, T, Ix1>,
        T: 'b + AsNumeric,
    {
        // `search_for_indices` validates the query length, so it is not repeated here.
        let (coord_indices, _) = self
            .search_for_indices(query, radius)?
            .into_raw_vec_and_offset();
        Ok(self.cloud.select(Axis(0), &coord_indices))
    }

    /// Returns the row indices of every point whose Euclidean distance to `query`
    /// is at most `radius` (the boundary is inclusive).
    ///
    /// Indices are returned in tree-traversal order, not sorted. A negative or NaN
    /// `radius` matches no points.
    ///
    /// # Errors
    ///
    /// Returns [`ImgalError::MismatchedArrayLengths`] if `query.len()` differs from
    /// the number of columns in the cloud.
    pub fn search_for_indices<'b, B>(
        &self,
        query: B,
        radius: f64,
    ) -> Result<Array1<usize>, ImgalError>
    where
        B: AsArray<'b, T, Ix1>,
        T: 'b + AsNumeric,
    {
        let query: ArrayBase<ViewRepr<&'b T>, Ix1> = query.into();
        // Copy into a contiguous buffer so the recursive search can index a slice.
        let query = query.to_vec();
        let q_dims = query.len();
        let c_dims = self.cloud.dim().1;
        if q_dims != c_dims {
            return Err(ImgalError::MismatchedArrayLengths {
                a_arr_name: "query",
                a_arr_len: q_dims,
                b_arr_name: "cloud array shape",
                b_arr_len: c_dims,
            });
        }
        let mut results: Vec<usize> = Vec::new();
        // Squaring a negative radius would silently turn it into a positive search
        // region. No distance is negative, so such a radius matches nothing. A NaN
        // radius also fails this test and likewise matches nothing.
        if radius >= 0.0
            && let Some(root) = self.root
        {
            self.recursive_search(root, &query, radius * radius, &mut results);
        }
        Ok(Array1::from_vec(results))
    }

    /// Recursively builds the subtree for the points in `indices` and returns the
    /// index of its root node in `self.nodes`, or `None` when `indices` is empty.
    ///
    /// `indices` is reordered in place. After partitioning around the median on
    /// this depth's split axis, the lower half becomes the left subtree and the
    /// upper half becomes the right subtree. Children are pushed before their
    /// parent, so the overall root is the last node pushed.
    fn recursive_build(&mut self, indices: &mut [usize], depth: usize) -> Option<usize> {
        if indices.is_empty() {
            return None;
        }
        let n_dims = self.cloud.dim().1;
        let split_axis = depth % n_dims;
        let median = indices.len() / 2;
        // `compare_on_axis` is a total order (NaN included). A valid partition from
        // `select_nth_unstable_by` depends on that.
        indices.select_nth_unstable_by(median, |&a, &b| {
            self.compare_on_axis(a, b, split_axis)
        });
        let point_index = indices[median];
        let left = self.recursive_build(&mut indices[..median], depth + 1);
        let right = self.recursive_build(&mut indices[median + 1..], depth + 1);
        let node_index = self.nodes.len();
        self.nodes
            .push(Node::new(split_axis, point_index, left, right));
        Some(node_index)
    }

    /// Orders the points in rows `a` and `b` by their coordinate on `axis`.
    ///
    /// Coordinates are compared as `f64` with [`f64::total_cmp`]. NaN values
    /// therefore get a consistent position instead of making the ordering
    /// inconsistent.
    fn compare_on_axis(&self, a: usize, b: usize, axis: usize) -> Ordering {
        let value_a = self.cloud[[a, axis]].to_f64();
        let value_b = self.cloud[[b, axis]].to_f64();
        value_a.total_cmp(&value_b)
    }

    /// Appends to `results` the point index of every node in the subtree rooted at
    /// `node_index` whose squared distance to `query` is at most `radius_sq`.
    ///
    /// The child on the query's side of the split plane is always searched. The
    /// other child is skipped only when the plane is farther than the radius.
    fn recursive_search(
        &self,
        node_index: usize,
        query: &[T],
        radius_sq: f64,
        results: &mut Vec<usize>,
    ) {
        let node = &self.nodes[node_index];
        let node_dist_sq = query.iter().enumerate().fold(0.0, |acc, (i, &q)| {
            let d = self.cloud[[node.point_index, i]].to_f64() - q.to_f64();
            acc + d * d
        });
        if node_dist_sq <= radius_sq {
            results.push(node.point_index);
        }
        let ax = node.split_axis;
        let split_value = self.cloud[[node.point_index, ax]].to_f64();
        let diff = query[ax].to_f64() - split_value;
        let (near, far) = if diff <= 0.0 {
            (node.left, node.right)
        } else {
            (node.right, node.left)
        };
        if let Some(child) = near {
            self.recursive_search(child, query, radius_sq, results);
        }
        // A NaN split coordinate gives no bound on the far subtree: `diff` is NaN
        // and every comparison with it is false. That subtree can still contain
        // valid matches, so it must not be pruned.
        if (split_value.is_nan() || diff * diff <= radius_sq)
            && let Some(child) = far
        {
            self.recursive_search(child, query, radius_sq, results);
        }
    }
}
impl Node {
pub fn new(
split_axis: usize,
point_index: usize,
left: Option<usize>,
right: Option<usize>,
) -> Self {
Self {
split_axis,
point_index,
left,
right,
}
}
}