use std::cmp::Ordering;
use ndarray::{Array2, ArrayBase, ArrayView2, AsArray, Axis, Ix2, ViewRepr};
use crate::error::ImgalError;
use crate::traits::numeric::AsNumeric;
/// A k-d tree over a borrowed 2-dimensional point cloud.
///
/// Each row of `cloud` is one point and each column is one coordinate axis.
/// The tree does not copy the points: nodes refer to them by row index, and
/// to each other by index into `nodes`.
pub struct KDTree<'a, T> {
/// Borrowed view of the point cloud (rows are points, columns are axes).
pub cloud: ArrayView2<'a, T>,
/// Flat storage for every node of the tree; child links are indices into
/// this vector.
pub nodes: Vec<Node>,
/// Index into `nodes` of the root node, or `None` when the tree holds no
/// nodes (for example, when the cloud has no rows).
pub root: Option<usize>,
}
/// A single node of a [`KDTree`].
///
/// A node stores no coordinates itself; it refers to its point by row index
/// into the tree's cloud and to its children by index into the tree's node
/// vector.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Node {
    /// Coordinate axis (column of the cloud) this node splits on.
    pub split_axis: usize,
    /// Row index, in the cloud, of the point stored at this node.
    pub point_index: usize,
    /// Index of the left child in the tree's node vector, if any.
    pub left: Option<usize>,
    /// Index of the right child in the tree's node vector, if any.
    pub right: Option<usize>,
}
impl<'a, T> KDTree<'a, T>
where
    T: AsNumeric,
{
    /// Build a k-d tree over a 2-dimensional point cloud.
    ///
    /// Each row of `cloud` is one point and each column is one coordinate
    /// axis. The cloud is borrowed, not copied; nodes refer to points by row
    /// index.
    ///
    /// # Panics
    ///
    /// Panics if the cloud has at least one row but zero columns, because the
    /// split axis is computed as `depth % number_of_columns`.
    pub fn build<A>(cloud: A) -> Self
    where
        A: AsArray<'a, T, Ix2>,
    {
        let view: ArrayBase<ViewRepr<&'a T>, Ix2> = cloud.into();
        let indices: Vec<usize> = (0..view.dim().0).collect();
        let mut tree = Self {
            cloud: view,
            // Exactly one node is created per point, so reserve up front.
            nodes: Vec::with_capacity(indices.len()),
            root: None,
        };
        tree.root = tree.recursive_build(&indices, 0);
        tree
    }

    /// Return the coordinates of every point within `radius` of `query`.
    ///
    /// The result has one row per matching point, in the same order as the
    /// indices returned by [`KDTree::search_for_indices`].
    ///
    /// # Errors
    ///
    /// Returns [`ImgalError::MismatchedArrayLengths`] when `query.len()`
    /// differs from the number of columns in the cloud.
    pub fn search_for_coords(&self, query: &[T], radius: f64) -> Result<Array2<T>, ImgalError> {
        // `search_for_indices` performs the dimension check and reports the
        // same error, so delegate instead of re-validating and unwrapping.
        let coord_indices = self.search_for_indices(query, radius)?;
        Ok(self.cloud.select(Axis(0), &coord_indices))
    }

    /// Return the row indices of every point within `radius` of `query`.
    ///
    /// Distances are Euclidean and computed in `f64`; a point is included
    /// when its squared distance is less than or equal to `radius` squared.
    /// Because the radius is squared before use, its sign is ignored. Indices
    /// are returned in tree traversal order, not sorted by distance.
    ///
    /// # Errors
    ///
    /// Returns [`ImgalError::MismatchedArrayLengths`] when `query.len()`
    /// differs from the number of columns in the cloud.
    pub fn search_for_indices(&self, query: &[T], radius: f64) -> Result<Vec<usize>, ImgalError> {
        let q_dims = query.len();
        let c_dims = self.cloud.dim().1;
        if q_dims != c_dims {
            return Err(ImgalError::MismatchedArrayLengths {
                a_arr_name: "query",
                a_arr_len: q_dims,
                b_arr_name: "cloud array shape",
                b_arr_len: c_dims,
            });
        }
        let radius_sq = radius.powi(2);
        let mut results: Vec<usize> = Vec::new();
        if let Some(root) = self.root {
            self.recursive_search(root, q_dims, query, radius_sq, &mut results);
        }
        Ok(results)
    }

    /// Recursively build the subtree containing the points in `indices`.
    ///
    /// The split axis cycles through the columns with `depth`. The points are
    /// ordered along that axis, the median becomes this subtree's node, and
    /// the lower and upper halves become the left and right subtrees. Returns
    /// the index of the created node in `self.nodes`, or `None` when
    /// `indices` is empty.
    fn recursive_build(&mut self, indices: &[usize], depth: usize) -> Option<usize> {
        if indices.is_empty() {
            return None;
        }
        let n_dims = self.cloud.dim().1;
        let split_axis = depth % n_dims;
        let mut inds_sorted = indices.to_vec();
        let cloud = &self.cloud;
        // The comparator must be a total order: `sort_by` leaves the order
        // unspecified, and may panic, when it is not. Incomparable values
        // (NaN) therefore fall back to `f64::total_cmp`, which orders NaNs
        // consistently; comparable values keep their native ordering.
        // NOTE(review): assumes `to_f64` preserves NaN for float `T` — confirm.
        inds_sorted.sort_by(|&a, &b| -> Ordering {
            let (va, vb) = (cloud[[a, split_axis]], cloud[[b, split_axis]]);
            va.partial_cmp(&vb)
                .unwrap_or_else(|| va.to_f64().total_cmp(&vb.to_f64()))
        });
        let median = inds_sorted.len() / 2;
        let point_index = inds_sorted[median];
        let left = self.recursive_build(&inds_sorted[..median], depth + 1);
        let right = self.recursive_build(&inds_sorted[median + 1..], depth + 1);
        // Children are pushed before their parent, so the root is pushed last.
        let node_index = self.nodes.len();
        self.nodes
            .push(Node::new(split_axis, point_index, left, right));
        Some(node_index)
    }

    /// Collect into `results` the row indices of all points in the subtree
    /// rooted at `node_index` whose squared distance to `query` is at most
    /// `radius_sq`.
    ///
    /// `n_dims` is the number of coordinates compared per point.
    fn recursive_search(
        &self,
        node_index: usize,
        n_dims: usize,
        query: &[T],
        radius_sq: f64,
        results: &mut Vec<usize>,
    ) {
        let node = &self.nodes[node_index];
        // Borrow the point's row instead of copying it into a fresh Vec on
        // every visited node.
        let node_point = self.cloud.row(node.point_index);
        let node_dist_sq = node_point
            .iter()
            .zip(query)
            .take(n_dims)
            .fold(0.0, |acc, (&n, &q)| {
                let d = n.to_f64() - q.to_f64();
                acc + (d * d)
            });
        if node_dist_sq <= radius_sq {
            results.push(node.point_index);
        }
        let ax = node.split_axis;
        // Signed distance from the query to this node's splitting plane.
        let diff = query[ax].to_f64() - node_point[ax].to_f64();
        let (near, far) = if diff <= 0.0 {
            (node.left, node.right)
        } else {
            (node.right, node.left)
        };
        if let Some(child) = near {
            self.recursive_search(child, n_dims, query, radius_sq, results);
        }
        // The far side can only contain matches when the splitting plane
        // itself lies within the search radius.
        if diff.powi(2) <= radius_sq {
            if let Some(child) = far {
                self.recursive_search(child, n_dims, query, radius_sq, results);
            }
        }
    }
}
impl Node {
pub fn new(
split_axis: usize,
point_index: usize,
left: Option<usize>,
right: Option<usize>,
) -> Self {
Self {
split_axis,
point_index,
left,
right,
}
}
}