use num_traits::Num;
use std::iter::FromIterator;
use crate::coords::{Point, PointAxial};
use crate::quad_prism::QuadPrism;
/// A kd-tree over [`Geom`] values, stored as an implicit binary tree in a vector.
///
/// Slot 0 is the root and the children of slot `i` live at `2i + 1` and `2i + 2`
/// (see `NodeRef::left` / `NodeRef::right`); vacant slots hold `None`. The split
/// axis cycles X, Y, W with depth.
///
/// Note: the derived `PartialEq` compares the raw slot layout, so two trees holding
/// the same elements can compare unequal when their internal layouts differ.
#[derive(Clone, Eq, PartialEq, Debug)]
pub struct KdTree<V, T> {
/// Slot storage; `None` marks a vacant slot.
data: Vec<Option<NodeData<V, T>>>,
/// Number of occupied slots.
count: usize,
}
impl<V, T> Default for KdTree<V, T> {
fn default() -> Self {
KdTree::new()
}
}
impl<V, T> KdTree<V, T> {
pub fn new() -> Self {
KdTree {
data: Vec::new(),
count: 0,
}
}
}
impl<V, T> KdTree<V, T>
where
V: Copy + Num + PartialOrd,
{
pub fn count(&self) -> usize {
self.count
}
pub fn is_empty(&self) -> bool {
self.count == 0
}
pub fn max_depth(&self) -> usize {
if self.data.is_empty() {
0
} else {
NodeRef(self.data.len() - 1).depth() + 1
}
}
pub fn exact_depth(&self) -> usize {
for (i, data) in self.data.iter().enumerate().rev() {
if data.is_some() {
return NodeRef(i).depth() + 1;
}
}
0
}
fn root(&self) -> Option<NodeRef> {
if self.is_empty() {
None
} else {
Some(NodeRef(0))
}
}
fn resize_if_less(&mut self, len: usize) {
if len >= self.data.len() {
if len >= self.data.capacity() {
self.data.reserve(len - self.data.capacity() + 1);
}
for _ in 0..=(len - self.data.len()) {
self.data.push(None);
}
}
}
fn node_data(&self, node: NodeRef) -> Option<&Option<NodeData<V, T>>> {
self.data.get(node.0)
}
fn node_data_mut(&mut self, node: NodeRef) -> &mut Option<NodeData<V, T>> {
self.resize_if_less(node.0);
if let Some(option) = self.data.get_mut(node.0) {
option
} else {
panic!("cannot get node even after resize");
}
}
fn drain_subtree(&mut self, node: NodeRef) -> Vec<NodeData<V, T>> {
let max_depth = NodeRef(self.data.len() - 1).depth() - node.depth() + 1;
let mut vec = Vec::with_capacity((1 << max_depth) - 1);
let mut cur = node;
let mut width = 1;
for _ in 0..max_depth {
for i in 0..width {
if let Some(opt) = self.data.get_mut(cur.0 + i) {
if let Some(node) = opt.take() {
vec.push(node);
}
}
}
cur = cur.left();
width *= 2;
}
self.count -= vec.len();
vec
}
fn batch_insert_at<I: IntoIterator<Item = (Geom<V>, T)>>(&mut self, node: NodeRef, it: I) {
struct PartitionResult<'a, V, T> {
pivot: &'a mut (Geom<V>, T),
left: &'a mut [(Geom<V>, T)],
right: &'a mut [(Geom<V>, T)],
}
fn median_of_three<V, T>(slice: &[(Geom<V>, T)], axis: Axis) -> (usize, V)
where
V: Copy + Num + PartialOrd,
{
fn idx_and_value<V, T>(slice: &[(Geom<V>, T)], axis: Axis, idx: usize) -> (usize, V)
where
V: Copy + Num + PartialOrd,
{
(idx, slice[idx].0.base_of_axis(axis))
}
if slice.is_empty() {
panic!("slice is empty");
}
let mut arr: [(usize, V); 3] = [(0, V::zero()); 3];
arr[0] = idx_and_value(slice, axis, 0);
arr[1] = idx_and_value(slice, axis, slice.len() / 2);
arr[2] = idx_and_value(slice, axis, slice.len() - 1);
if arr[2].1 < arr[0].1 {
arr.swap(2, 0);
}
if arr[1].1 < arr[0].1 {
arr.swap(1, 0);
}
if arr[2].1 < arr[1].1 {
arr.swap(2, 1);
}
arr[1]
}
fn partition<V, T>(slice: &mut [(Geom<V>, T)], axis: Axis) -> PartitionResult<V, T>
where
V: Copy + Num + PartialOrd,
{
if slice.is_empty() {
panic!("slice is empty");
}
let (mid_idx, mid_val) = median_of_three(slice, axis);
slice.swap(0, mid_idx);
let (pivot, slice) = slice.split_first_mut().expect("slice is empty");
if slice.is_empty() {
let (left, right) = slice.split_at_mut(0);
return PartitionResult { pivot, left, right };
}
let mut start = 0;
let mut end = slice.len() - 1;
let mut cur = 0;
while cur < end {
if slice[cur].0.base_of_axis(axis) < mid_val {
slice.swap(start, cur);
start += 1;
cur += 1;
} else if slice[cur].0.base_of_axis(axis) >= mid_val {
slice.swap(end, cur);
end -= 1;
}
}
if slice[cur].0.base_of_axis(axis) < mid_val {
cur += 1;
}
let (left, right) = slice.split_at_mut(cur);
PartitionResult { pivot, left, right }
}
fn recursive_insert<V, T>(
tree: &mut KdTree<V, T>,
node: NodeRef,
slice: &mut [(Geom<V>, Option<T>)],
) where
V: Copy + Num + PartialOrd,
{
let PartitionResult { pivot, left, right } =
partition(slice, Axis::of_depth(node.depth()));
let geom = pivot.0;
let data = pivot.1.take().expect("is nothing");
tree.resize_if_less(node.0);
tree.node_data_mut(node).replace(NodeData { geom, data });
if !left.is_empty() {
recursive_insert(tree, node.left(), left);
}
if !right.is_empty() {
recursive_insert(tree, node.right(), right);
}
}
let mut vec: Vec<(Geom<V>, Option<T>)> =
it.into_iter().map(|(g, t)| (g, Some(t))).collect();
if vec.is_empty() {
return;
}
self.count += vec.len();
recursive_insert(self, node, vec.as_mut_slice())
}
pub fn insert<G>(&mut self, geom: G, data: T)
where
G: Into<Geom<V>>,
{
let geom = geom.into();
if self.is_empty() {
self.resize_if_less(1);
self.data[0].replace(NodeData { geom, data });
self.count += 1;
return;
}
let mut node = self.root().expect("tree is not empty");
let mut axis = Axis::X;
while let Some(data) = self.node_data_mut(node) {
let cur_base = data.geom.base_of_axis(axis);
let insert_base = geom.base_of_axis(axis);
if insert_base < cur_base {
node = node.left();
} else {
node = node.right();
}
axis = axis.rotate();
}
self.count += 1;
self.resize_if_less(node.0);
self.node_data_mut(node).replace(NodeData { geom, data });
}
fn find<F>(&self, mut predicate: F) -> Option<NodeRef>
where
F: FnMut(&T) -> bool,
{
for (i, item) in self.data.iter().enumerate() {
if let Some(node) = item {
if predicate(&node.data) {
return Some(NodeRef(i));
}
}
}
None
}
pub fn get<F>(&self, predicate: F) -> Option<(&Geom<V>, &T)>
where
F: FnMut(&T) -> bool,
{
if let Some(node_ref) = self.find(predicate) {
self.data[node_ref.0]
.as_ref()
.map(|node| (&node.geom, &node.data))
} else {
None
}
}
pub fn get_mut<F>(&mut self, predicate: F) -> Option<(&Geom<V>, &mut T)>
where
F: FnMut(&T) -> bool,
{
if let Some(node_ref) = self.find(predicate) {
self.data[node_ref.0]
.as_mut()
.map(|node| (&node.geom, &mut node.data))
} else {
None
}
}
pub fn remove<F>(&mut self, predicate: F) -> Option<(Geom<V>, T)>
where
F: FnMut(&T) -> bool,
{
if let Some(node) = self.find(predicate) {
let mut subtree = self.drain_subtree(node);
let root = subtree.swap_remove(0);
self.batch_insert_at(node, subtree.into_iter().map(|d| (d.geom, d.data)));
Some((root.geom, root.data))
} else {
None
}
}
pub fn query<G>(&self, geom: G) -> Query<V, T>
where
G: Into<Geom<V>>,
{
Query {
tree: &self,
stack: vec![NodeRef(0)],
geom: geom.into(),
}
}
pub fn iter(&self) -> Iter<V, T> {
Iter(self.data.iter())
}
}
/// Convenience lookups keyed on value equality instead of an arbitrary predicate.
impl<V, T: PartialEq> KdTree<V, T>
where
    V: Copy + Num + PartialOrd,
{
    /// Returns the first element (in storage order) whose value equals `data`.
    pub fn get_at_data(&self, data: &T) -> Option<(&Geom<V>, &T)> {
        let wanted = data;
        self.get(move |candidate| candidate == wanted)
    }

    /// Mutable variant of [`get_at_data`](Self::get_at_data).
    pub fn get_mut_at_data(&mut self, data: &T) -> Option<(&Geom<V>, &mut T)> {
        let wanted = data;
        self.get_mut(move |candidate| candidate == wanted)
    }

    /// Removes and returns the first element (in storage order) whose value equals `data`.
    pub fn remove_at_data(&mut self, data: &T) -> Option<(Geom<V>, T)> {
        let wanted = data;
        self.remove(move |candidate| candidate == wanted)
    }
}
/// Iterator over the tree elements whose geometry matches a query geometry.
///
/// Created by [`KdTree::query`]. Slots are visited depth-first through an explicit
/// stack; every occupied slot is visited (no spatial pruning is performed).
#[derive(Debug)]
pub struct Query<'a, V, T> {
/// Tree being searched.
tree: &'a KdTree<V, T>,
/// Slots still to visit; vacant or out-of-range slots are skipped when popped.
stack: Vec<NodeRef>,
/// Geometry each yielded element must match (see `Geom::matches`).
geom: Geom<V>,
}
impl<'a, V, T> Iterator for Query<'a, V, T>
where
    V: Copy + Num + PartialOrd,
{
    type Item = (&'a Geom<V>, &'a T);

    /// Pops slots until an occupied one matches the query geometry. The children of
    /// every occupied slot are scheduled, whether or not the slot itself matched.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let top = self.stack.pop()?;
            let node = match self.tree.node_data(top) {
                Some(Some(node)) => node,
                _ => continue,
            };
            self.stack.push(top.left());
            self.stack.push(top.right());
            if self.geom.matches(&node.geom) {
                return Some((&node.geom, &node.data));
            }
        }
    }

    /// At most every element in the tree can still be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let upper = self.tree.count();
        (0, Some(upper))
    }
}
/// Borrowing iterator over `(geometry, &value)` pairs in storage (slot) order,
/// skipping vacant slots. Created by [`KdTree::iter`].
#[derive(Debug)]
pub struct Iter<'a, V, T>(std::slice::Iter<'a, Option<NodeData<V, T>>>);
impl<'a, V, T> Iterator for Iter<'a, V, T>
where
    V: Copy + Num + PartialOrd,
{
    type Item = (Geom<V>, &'a T);

    /// Advances to the next occupied slot; the geometry is copied out.
    fn next(&mut self) -> Option<Self::Item> {
        self.0
            .find_map(|slot| slot.as_ref().map(|node| (node.geom, &node.data)))
    }
}
/// Owning iterator over `(geometry, value)` pairs in storage (slot) order,
/// skipping vacant slots. Created by consuming a [`KdTree`] with `into_iter`.
#[derive(Debug)]
pub struct IntoIter<V, T>(std::vec::IntoIter<Option<NodeData<V, T>>>);
impl<V, T> IntoIterator for KdTree<V, T>
where
V: Copy + Num + PartialOrd,
{
type IntoIter = IntoIter<V, T>;
type Item = (Geom<V>, T);
fn into_iter(self) -> Self::IntoIter {
IntoIter(self.data.into_iter())
}
}
impl<V, T> Iterator for IntoIter<V, T>
where
V: Copy + Num + PartialOrd,
{
type Item = (Geom<V>, T);
fn next(&mut self) -> Option<Self::Item> {
while let Some(item) = self.0.next() {
if let Some(node) = item {
return Some((node.geom, node.data));
}
}
None
}
fn size_hint(&self) -> (usize, Option<usize>) {
let (_, max) = self.0.size_hint();
(0, max)
}
}
impl<V, T> FromIterator<(Geom<V>, T)> for KdTree<V, T>
where
    V: Copy + Num + PartialOrd,
{
    /// Builds a tree from `(geometry, value)` pairs in a single batch, splitting each
    /// level around a median-of-three pivot.
    fn from_iter<I: IntoIterator<Item = (Geom<V>, T)>>(into_iter: I) -> Self {
        let iter = into_iter.into_iter();
        let mut tree = KdTree::new();
        // Reserve capacity from the lower size bound instead of filling the storage
        // with placeholder slots. Placeholders would inflate `max_depth()`, e.g. an
        // empty input would otherwise report a depth of 1.
        let (min, _) = iter.size_hint();
        tree.data.reserve(min);
        tree.batch_insert_at(NodeRef(0), iter);
        tree
    }
}
/// Geometry attached to each tree element.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub enum Geom<V> {
/// A single point in axial coordinates.
Point(PointAxial<V>),
/// A quad prism.
Prism(QuadPrism<V>),
}
impl<V> Geom<V>
where
    V: Copy + Num + PartialOrd,
{
    /// Query predicate between two geometries:
    /// - point vs point: equality,
    /// - point vs prism (either order): the prism contains the point,
    /// - prism vs prism: the prisms intersect.
    fn matches(&self, other: &Geom<V>) -> bool {
        match (self, other) {
            (Geom::Point(a), Geom::Point(b)) => *a == *b,
            (Geom::Point(p), Geom::Prism(r)) | (Geom::Prism(r), Geom::Point(p)) => {
                r.contains_axial(p)
            }
            (Geom::Prism(a), Geom::Prism(b)) => a.intersects(b),
        }
    }
}
impl<V> From<PointAxial<V>> for Geom<V> {
fn from(p: PointAxial<V>) -> Self {
Geom::Point(p)
}
}
impl<V> From<Point<V>> for Geom<V> {
fn from(p: Point<V>) -> Self {
Geom::Point(p.into_axial())
}
}
impl<V> From<QuadPrism<V>> for Geom<V> {
fn from(r: QuadPrism<V>) -> Self {
Geom::Prism(r)
}
}
/// Split axes of the tree, cycled X → Y → W as depth increases.
/// Variant order is significant for the derived `Ord`.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
enum Axis {
X,
Y,
W,
}
impl Axis {
    /// Split axis used at tree depth `n`: X, Y, W, X, Y, W, ...
    fn of_depth(n: usize) -> Self {
        const CYCLE: [Axis; 3] = [Axis::X, Axis::Y, Axis::W];
        CYCLE[n % CYCLE.len()]
    }

    /// The axis used one level deeper than `self`.
    fn rotate(self) -> Self {
        match self {
            Axis::W => Axis::X,
            Axis::X => Axis::Y,
            Axis::Y => Axis::W,
        }
    }
}
impl<V> Geom<V>
where
V: Copy + Num + PartialOrd,
{
fn base_of_axis(&self, axis: Axis) -> V {
match self {
Geom::Point(point) => match axis {
Axis::X => point.x,
Axis::Y => point.y,
Axis::W => point.w,
},
Geom::Prism(rect) => match axis {
Axis::X => rect.center_axial().x,
Axis::Y => rect.center_axial().y,
Axis::W => rect.center_axial().w,
},
}
}
#[allow(dead_code)]
fn min_of_axis(&self, axis: Axis) -> V {
match self {
Geom::Point(point) => match axis {
Axis::X => point.x,
Axis::Y => point.y,
Axis::W => point.w,
},
Geom::Prism(rect) => match axis {
Axis::X => rect.low.x,
Axis::Y => rect.low.y,
Axis::W => rect.low.w,
},
}
}
#[allow(dead_code)]
fn max_of_axis(&self, axis: Axis) -> V {
match self {
Geom::Point(point) => match axis {
Axis::X => point.x,
Axis::Y => point.y,
Axis::W => point.w,
},
Geom::Prism(rect) => match axis {
Axis::X => rect.high.x - V::one(),
Axis::Y => rect.high.y - V::one(),
Axis::W => rect.high.w - V::one(),
},
}
}
}
/// Contents of an occupied tree slot.
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
struct NodeData<V, T> {
/// Geometry used for ordering and querying.
geom: Geom<V>,
/// User payload.
data: T,
}
/// Index of a slot in `KdTree::data`, using an implicit binary-heap layout:
/// the root is slot 0 and the children of slot `i` are `2i + 1` and `2i + 2`.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
struct NodeRef(usize);
impl NodeRef {
    /// Returns the parent slot, or `None` for the root (slot 0).
    ///
    /// Children sit at `2i + 1` and `2i + 2`, so the parent of slot `n > 0` is
    /// `(n - 1) / 2`. Using `n / 2` would be wrong for every right child
    /// (e.g. slot 2 would map to 1 instead of 0).
    #[allow(dead_code)] // currently unused
    fn parent(self) -> Option<Self> {
        match self.0 {
            0 => None,
            n => Some(NodeRef((n - 1) >> 1)),
        }
    }

    /// Left child slot: `2i + 1`.
    fn left(self) -> Self {
        NodeRef((self.0 << 1) + 1)
    }

    /// Right child slot: `2i + 2`.
    fn right(self) -> Self {
        NodeRef((self.0 << 1) + 2)
    }

    /// Depth of the slot (root = 0), i.e. `floor(log2(i + 1))`.
    fn depth(self) -> usize {
        use std::mem::size_of;
        (size_of::<usize>() * 8) - ((self.0 + 1).leading_zeros() as usize) - 1
    }
}
// Unit tests covering insertion, removal, iteration, querying and batch construction.
#[cfg(test)]
mod tests {
use super::*;
use crate::{PointAxial, QuadPrism, VectorAxial};
// Builds a five-element tree, then removes the element with value 5 (the only value > 4)
// and the element with value 2, leaving three elements.
#[test]
fn can_insert_remove() {
let mut tree = KdTree::new();
assert_eq!(0, tree.count());
assert!(tree.is_empty());
tree.insert(
QuadPrism::from_base_size_axial(
PointAxial::new(-100, -100, 0),
VectorAxial::new(200, 200, 1),
),
1,
);
tree.insert(PointAxial::new(0, 0, 0), 2);
tree.insert(PointAxial::new(10, -10, 0), 3);
tree.insert(PointAxial::new(-2, 4, 0), 4);
tree.insert(PointAxial::new(1, -1, 0), 5);
assert_eq!(5, tree.count());
assert_eq!(false, tree.is_empty());
tree.remove(|i| *i > 4);
tree.remove_at_data(&2);
assert_eq!(3, tree.count());
assert_eq!(false, tree.is_empty());
}
// After removing 5 and 2, the remaining values 1 + 3 + 4 sum to 8, both through the
// borrowing `iter()` and the consuming `into_iter()`.
#[test]
fn can_iter() {
let mut tree = KdTree::new();
tree.insert(
QuadPrism::from_base_size_axial(
PointAxial::new(-100, -100, 0),
VectorAxial::new(200, 200, 1),
),
1,
);
tree.insert(PointAxial::new(0, 0, 0), 2);
tree.insert(PointAxial::new(10, -10, 0), 3);
tree.insert(PointAxial::new(-2, 4, 0), 4);
tree.insert(PointAxial::new(1, -1, 0), 5);
tree.remove(|i| *i > 4);
tree.remove_at_data(&2);
assert_eq!(8, tree.iter().fold(0, |a, (_, x)| a + x));
assert_eq!(8, tree.into_iter().fold(0, |a, (_, x)| a + x));
}
// The query prism is expected to match the large prism (1) and the points with values
// 2, 4 and 5, but not the point (10, -10) with value 3: four matches summing to 12.
#[test]
fn can_query() {
let mut tree = KdTree::new();
tree.insert(
QuadPrism::from_base_size_axial(
PointAxial::new(-100, -100, 0),
VectorAxial::new(200, 200, 1),
),
1,
);
tree.insert(PointAxial::new(0, 0, 0), 2);
tree.insert(PointAxial::new(10, -10, 0), 3);
tree.insert(PointAxial::new(-2, 4, 0), 4);
tree.insert(PointAxial::new(1, -1, 0), 5);
let query = tree.query(QuadPrism::from_base_size_axial(
PointAxial::new(-10, -2, 0),
VectorAxial::new(12, 8, 1),
));
let mut value = 0;
let mut count = 0;
for (_, n) in query {
value += *n;
count += 1;
}
assert_eq!(4, count);
assert_eq!(12, value);
}
// Batch-builds a tree from values 1..=7 (sum 28), then rebuilds a second tree by
// consuming the first; no element may be lost along the way.
#[test]
fn can_from_iter() {
let tree = KdTree::from_iter(vec![
(
QuadPrism::from_base_size_axial(
PointAxial::new(-100, -100, 0),
VectorAxial::new(200, 200, 1),
)
.into(),
1,
),
(PointAxial::new(0, 0, 0).into(), 2),
(PointAxial::new(10, -10, 0).into(), 3),
(PointAxial::new(1, -1, 0).into(), 4),
(PointAxial::new(-2, 4, 0).into(), 5),
(PointAxial::new(-1, 1, 0).into(), 6),
(PointAxial::new(-9, 2, 0).into(), 7),
]);
assert_eq!(28, tree.iter().fold(0, |a, (_, x)| a + x));
let rebuilt_tree = KdTree::from_iter(tree);
assert_eq!(28, rebuilt_tree.iter().fold(0, |a, (_, x)| a + x));
}
}