use crate::coords::{Coordinates, CoordinateMetric, CoordinateProximity};
use crate::distance::{Metric, Proximity};
use crate::util::Ordered;
use crate::{ExactNeighbors, NearestNeighbors, Neighborhood};
use std::iter::FromIterator;
use std::ops::Deref;
/// A node of the pointer-based [`KdTree`], owning its item and both child subtrees.
///
/// The splitting axis of a node is not stored: it is the `level` passed down during
/// construction, insertion and search, starting at 0 for the root and advancing by one
/// (modulo the number of dimensions) per tree level.
#[derive(Debug)]
struct KdNode<T> {
/// The item stored at this node; its coordinate on this node's axis is the split value.
item: T,
/// Subtree of items whose coordinate on this node's axis is not greater than the split value.
left: Option<Box<Self>>,
/// Subtree of items whose coordinate on this node's axis is not less than the split value.
right: Option<Box<Self>>,
}
impl<T: Coordinates> KdNode<T> {
fn new(item: T) -> Self {
Self {
item,
left: None,
right: None,
}
}
fn balanced<I: IntoIterator<Item = T>>(items: I) -> Option<Self> {
let mut nodes: Vec<_> = items
.into_iter()
.map(Self::new)
.map(Box::new)
.map(Some)
.collect();
Self::balanced_recursive(&mut nodes, 0)
.map(|node| *node)
}
fn balanced_recursive(nodes: &mut [Option<Box<Self>>], level: usize) -> Option<Box<Self>> {
if nodes.is_empty() {
return None;
}
nodes.sort_by_cached_key(|x| Ordered::new(x.as_ref().unwrap().item.coord(level)));
let (left, right) = nodes.split_at_mut(nodes.len() / 2);
let (node, right) = right.split_first_mut().unwrap();
let mut node = node.take().unwrap();
let next = (level + 1) % node.item.dims();
node.left = Self::balanced_recursive(left, next);
node.right = Self::balanced_recursive(right, next);
Some(node)
}
fn push(&mut self, item: T, level: usize) {
let next = (level + 1) % item.dims();
if item.coord(level) <= self.item.coord(level) {
if let Some(left) = &mut self.left {
left.push(item, next);
} else {
self.left = Some(Box::new(Self::new(item)));
}
} else {
if let Some(right) = &mut self.right {
right.push(item, next);
} else {
self.right = Some(Box::new(Self::new(item)));
}
}
}
}
/// Proximity functions that a k-d tree can search with.
///
/// Besides measuring [`Proximity`] to items of type `V`, an implementor must itself expose
/// [`Coordinates`] with the same coordinate value type as `V` (the search compares the
/// target's coordinates against split values). It must also measure its distance to a bare
/// list of coordinates ([`CoordinateProximity`]) in the same `Distance` type. The search
/// uses that distance to decide whether a far subtree is worth visiting.
pub trait KdProximity<V: ?Sized = Self>
where
Self: Coordinates<Value = V::Value>,
Self: Proximity<V>,
Self: CoordinateProximity<V::Value, Distance = <Self as Proximity<V>>::Distance>,
V: Coordinates,
{}
/// Blanket implementation: every type meeting the bounds of [`KdProximity`] implements it.
// NOTE: unlike the trait, this impl leaves `V` implicitly `Sized`.
impl<K, V> KdProximity<V> for K
where
K: Coordinates<Value = V::Value>,
K: Proximity<V>,
K: CoordinateProximity<V::Value, Distance = <K as Proximity<V>>::Distance>,
V: Coordinates,
{}
/// Metric counterpart of [`KdProximity`]: additionally requires [`Metric`] to `V` and
/// [`CoordinateMetric`] over `V`'s coordinate values.
///
/// This is the bound required by the `ExactNeighbors` implementations for `KdTree` and
/// `FlatKdTree`.
pub trait KdMetric<V: ?Sized = Self>
where
Self: KdProximity<V>,
Self: Metric<V>,
Self: CoordinateMetric<V::Value>,
V: Coordinates,
{}
/// Blanket implementation: every type meeting the bounds of [`KdMetric`] implements it.
// NOTE: unlike the trait, this impl leaves `V` implicitly `Sized`.
impl<K, V> KdMetric<V> for K
where
K: KdProximity<V>,
K: Metric<V>,
K: CoordinateMetric<V::Value>,
V: Coordinates,
{}
/// Nearest-neighbor traversal shared by the boxed ([`KdTree`]) and flat ([`FlatKdTree`])
/// layouts.
///
/// Implementors are `Copy` handles to a subtree. They expose the item stored at the subtree
/// root and handles to its two children; the traversal itself lives in the default methods.
trait KdSearch<K, V, N>: Copy
where
    K: KdProximity<V>,
    V: Coordinates + Copy,
    N: Neighborhood<K, V>,
{
    /// The item stored at the root of this subtree.
    fn item(self) -> V;

    /// The child subtree on the low side of the root's split, if any.
    fn left(self) -> Option<Self>;

    /// The child subtree on the high side of the root's split, if any.
    fn right(self) -> Option<Self>;

    /// Feeds this subtree's items to `neighborhood`, where the root splits on coordinate
    /// `level`.
    ///
    /// `closest` starts out as the target's coordinates. Before a far child is entered, the
    /// split value replaces the target's coordinate on that axis, and the distance from the
    /// target to `closest` decides whether the child is visited at all.
    fn search(self, level: usize, closest: &mut [V::Value], neighborhood: &mut N) {
        let item = self.item();
        neighborhood.consider(item);

        // Search the side of the split holding the target first; the other side may then
        // be pruned.
        let target_is_low = neighborhood.target().coord(level) <= item.coord(level);
        let (near, far) = if target_is_low {
            (self.left(), self.right())
        } else {
            (self.right(), self.left())
        };
        self.search_near(near, level, closest, neighborhood);
        self.search_far(far, level, closest, neighborhood);
    }

    /// Unconditionally descends into `near`, whose root splits on the next axis.
    fn search_near(
        self,
        near: Option<Self>,
        level: usize,
        closest: &mut [V::Value],
        neighborhood: &mut N,
    ) {
        let child = match near {
            Some(child) => child,
            None => return,
        };
        let next = (level + 1) % self.item().dims();
        child.search(next, closest, neighborhood);
    }

    /// Descends into `far` only if `neighborhood` still contains the bound computed from
    /// `closest`.
    ///
    /// While `far` is considered, `closest[level]` holds this root's split value. The
    /// previous value is restored afterwards, so the caller sees `closest` unchanged.
    fn search_far(
        self,
        far: Option<Self>,
        level: usize,
        closest: &mut [V::Value],
        neighborhood: &mut N,
    ) {
        let child = match far {
            Some(child) => child,
            None => return,
        };
        let item = self.item();
        let previous = std::mem::replace(&mut closest[level], item.coord(level));
        let bound = neighborhood.target().distance_to_coords(closest);
        if neighborhood.contains(bound) {
            child.search((level + 1) % item.dims(), closest, neighborhood);
        }
        closest[level] = previous;
    }
}
/// Boxed layout: a subtree handle is a borrow of its root node.
impl<'a, K, V, N> KdSearch<K, &'a V, N> for &'a KdNode<V>
where
    K: KdProximity<&'a V>,
    V: Coordinates,
    N: Neighborhood<K, &'a V>,
{
    /// Borrows the item stored in this node.
    fn item(self) -> &'a V {
        &self.item
    }

    /// Borrows the boxed left child, if present.
    fn left(self) -> Option<Self> {
        self.left.as_ref().map(Deref::deref)
    }

    /// Borrows the boxed right child, if present.
    fn right(self) -> Option<Self> {
        self.right.as_ref().map(Deref::deref)
    }
}
/// A k-d tree of boxed nodes, supporting incremental [`push`](KdTree::push) and in-place
/// [`balance`](KdTree::balance).
#[derive(Debug)]
pub struct KdTree<T> {
/// The root node, or `None` for an empty tree.
root: Option<KdNode<T>>,
}
impl<T: Coordinates> KdTree<T> {
pub fn new() -> Self {
Self {
root: None,
}
}
pub fn balanced<I: IntoIterator<Item = T>>(items: I) -> Self {
Self {
root: KdNode::balanced(items),
}
}
pub fn balance(&mut self) {
let mut nodes = Vec::new();
if let Some(root) = self.root.take() {
nodes.push(Some(Box::new(root)));
}
let mut i = 0;
while i < nodes.len() {
let node = nodes[i].as_mut().unwrap();
let inside = node.left.take();
let outside = node.right.take();
if inside.is_some() {
nodes.push(inside);
}
if outside.is_some() {
nodes.push(outside);
}
i += 1;
}
self.root = KdNode::balanced_recursive(&mut nodes, 0)
.map(|node| *node);
}
pub fn push(&mut self, item: T) {
if let Some(root) = &mut self.root {
root.push(item, 0);
} else {
self.root = Some(KdNode::new(item));
}
}
}
impl<T: Coordinates> Extend<T> for KdTree<T> {
    /// Adds `items` to the tree.
    ///
    /// An empty tree is built balanced from `items`. A non-empty tree receives each item
    /// through [`KdTree::push`], without rebalancing.
    fn extend<I: IntoIterator<Item = T>>(&mut self, items: I) {
        if self.root.is_none() {
            self.root = KdNode::balanced(items);
        } else {
            items.into_iter().for_each(|item| self.push(item));
        }
    }
}
impl<T: Coordinates> FromIterator<T> for KdTree<T> {
fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
Self::balanced(items)
}
}
#[derive(Debug)]
pub struct IntoIter<T> {
stack: Vec<KdNode<T>>,
}
impl<T> IntoIter<T> {
fn new(node: Option<KdNode<T>>) -> Self {
Self {
stack: node.into_iter().collect(),
}
}
}
impl<T> Iterator for IntoIter<T> {
type Item = T;
fn next(&mut self) -> Option<T> {
self.stack.pop().map(|node| {
if let Some(left) = node.left {
self.stack.push(*left);
}
if let Some(right) = node.right {
self.stack.push(*right);
}
node.item
})
}
}
impl<T> IntoIterator for KdTree<T> {
type Item = T;
type IntoIter = IntoIter<T>;
fn into_iter(self) -> Self::IntoIter {
IntoIter::new(self.root)
}
}
impl<K, V> NearestNeighbors<K, V> for KdTree<V>
where
    K: KdProximity<V>,
    V: Coordinates,
{
    /// Runs the k-d traversal from the root, returning `neighborhood` untouched for an
    /// empty tree.
    fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N
    where
        K: 'k,
        V: 'v,
        N: Neighborhood<&'k K, &'v V>,
    {
        let root = match &self.root {
            Some(root) => root,
            None => return neighborhood,
        };
        // The pruning point starts at the target itself and is clamped to split values
        // while descending.
        let mut closest = neighborhood.target().as_vec();
        root.search(0, &mut closest, &mut neighborhood);
        neighborhood
    }
}
/// With a [`KdMetric`] proximity, [`KdTree`] also implements [`ExactNeighbors`].
impl<K: KdMetric<V>, V: Coordinates> ExactNeighbors<K, V> for KdTree<V> {}
/// A node of [`FlatKdTree`], stored in a flat vector rather than linked by pointers.
///
/// Within the slice holding a node's subtree, the node itself comes first. It is followed by
/// the `left_len` nodes of its left subtree, then by all nodes of its right subtree.
#[derive(Debug)]
struct FlatKdNode<T> {
/// The item stored at this node.
item: T,
/// Number of nodes in this node's left subtree, which immediately follow it.
left_len: usize,
}
impl<T: Coordinates> FlatKdNode<T> {
fn new(item: T) -> Self {
Self {
item,
left_len: 0,
}
}
fn balanced<I: IntoIterator<Item = T>>(items: I) -> Vec<Self> {
let mut nodes: Vec<_> = items
.into_iter()
.map(Self::new)
.collect();
Self::balance_recursive(&mut nodes, 0);
nodes
}
fn balance_recursive(nodes: &mut [Self], level: usize) {
if !nodes.is_empty() {
nodes.sort_by_cached_key(|x| Ordered::new(x.item.coord(level)));
let mid = nodes.len() / 2;
nodes.swap(0, mid);
let (node, children) = nodes.split_first_mut().unwrap();
let (left, right) = children.split_at_mut(mid);
node.left_len = left.len();
let next = (level + 1) % node.item.dims();
Self::balance_recursive(left, next);
Self::balance_recursive(right, next);
}
}
}
/// Flat layout: a subtree handle is the slice holding that subtree in pre-order.
impl<'a, K, V, N> KdSearch<K, &'a V, N> for &'a [FlatKdNode<V>]
where
    K: KdProximity<&'a V>,
    V: Coordinates,
    N: Neighborhood<K, &'a V>,
{
    /// The subtree root is always the first node of the slice.
    fn item(self) -> &'a V {
        &self[0].item
    }

    /// The left subtree is the `left_len` nodes directly after the root.
    fn left(self) -> Option<Self> {
        match self[0].left_len {
            0 => None,
            len => Some(&self[1..=len]),
        }
    }

    /// The right subtree is everything after the left subtree, if anything remains.
    fn right(self) -> Option<Self> {
        let start = self[0].left_len + 1;
        self.get(start..).filter(|rest| !rest.is_empty())
    }
}
/// A k-d tree stored as a single vector of nodes in pre-order, always built balanced in one
/// go.
#[derive(Debug)]
pub struct FlatKdTree<T> {
/// All nodes; the first one (if any) is the root of the whole tree.
nodes: Vec<FlatKdNode<T>>,
}
impl<T: Coordinates> FlatKdTree<T> {
pub fn balanced<I: IntoIterator<Item = T>>(items: I) -> Self {
Self {
nodes: FlatKdNode::balanced(items),
}
}
}
impl<T: Coordinates> FromIterator<T> for FlatKdTree<T> {
fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
Self::balanced(items)
}
}
/// An iterator that moves the items out of a [`FlatKdTree`], in the tree's storage order.
#[derive(Debug)]
pub struct FlatIntoIter<T>(std::vec::IntoIter<FlatKdNode<T>>);
impl<T> Iterator for FlatIntoIter<T> {
type Item = T;
fn next(&mut self) -> Option<T> {
self.0.next().map(|n| n.item)
}
}
impl<T> IntoIterator for FlatKdTree<T> {
type Item = T;
type IntoIter = FlatIntoIter<T>;
fn into_iter(self) -> Self::IntoIter {
FlatIntoIter(self.nodes.into_iter())
}
}
impl<K, V> NearestNeighbors<K, V> for FlatKdTree<V>
where
    K: KdProximity<V>,
    V: Coordinates,
{
    /// Runs the k-d traversal over the whole node vector, returning `neighborhood`
    /// untouched for an empty tree.
    fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N
    where
        K: 'k,
        V: 'v,
        N: Neighborhood<&'k K, &'v V>,
    {
        if self.nodes.is_empty() {
            return neighborhood;
        }
        // The whole vector is the root's subtree; the pruning point starts at the target.
        let mut closest = neighborhood.target().as_vec();
        self.nodes.as_slice().search(0, &mut closest, &mut neighborhood);
        neighborhood
    }
}
/// With a [`KdMetric`] proximity, [`FlatKdTree`] also implements [`ExactNeighbors`].
impl<K: KdMetric<V>, V: Coordinates> ExactNeighbors<K, V> for FlatKdTree<V> {}
#[cfg(test)]
mod tests {
    // Each test hands a tree constructor to the shared nearest-neighbor test harness.
    use super::*;

    use crate::tests::test_nearest_neighbors;

    /// Balanced construction via `FromIterator`.
    #[test]
    fn test_kd_tree() {
        test_nearest_neighbors(KdTree::from_iter);
    }

    /// Incremental construction via `push`, with no rebalancing.
    #[test]
    fn test_unbalanced_kd_tree() {
        test_nearest_neighbors(|points| {
            let mut tree = KdTree::new();
            for point in points {
                tree.push(point);
            }
            tree
        });
    }

    /// Incremental construction followed by an in-place `balance()`.
    #[test]
    fn test_rebalanced_kd_tree() {
        test_nearest_neighbors(|points| {
            let mut tree = KdTree::new();
            for point in points {
                tree.push(point);
            }
            tree.balance();
            tree
        });
    }

    /// `Extend` on a non-empty tree, which inserts through `push`.
    #[test]
    fn test_extended_kd_tree() {
        test_nearest_neighbors(|points| {
            let mut points = points.into_iter();
            let mut tree = KdTree::new();
            if let Some(first) = points.next() {
                tree.push(first);
            }
            tree.extend(points);
            tree
        });
    }

    /// Balanced construction of the flat layout via `FromIterator`.
    #[test]
    fn test_flat_kd_tree() {
        test_nearest_neighbors(FlatKdTree::from_iter);
    }
}