use az::{Az, Cast};
use divrem::DivCeil;
use fixed::traits::Fixed;
use std::cmp::PartialEq;
use std::fmt::Debug;
use crate::iter::TreeIter;
use crate::{
iter::IterableTreeData,
traits::{Content, Index},
};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Coordinate type accepted by the fixed-point tree.
///
/// This is blanket-implemented below for every [`Fixed`] type that is also
/// `Default + Debug + Copy + Sync + Send`, so it never needs a manual impl.
pub trait Axis: Fixed + Default + Debug + Copy + Sync + Send {
    /// Adds `delta` to the running value `rd`.
    ///
    /// The addition saturates at the type's bounds rather than wrapping or panicking.
    fn rd_update(rd: Self, delta: Self) -> Self;
}

impl<T> Axis for T
where
    T: Fixed + Default + Debug + Copy + Sync + Send,
{
    #[inline]
    fn rd_update(rd: Self, delta: Self) -> Self {
        // Saturating add: clamps to MIN/MAX instead of overflowing.
        rd.saturating_add(delta)
    }
}
/// Marker trait for coordinate types of the rkyv-archivable tree types.
///
/// It has no methods. It is blanket-implemented for every type that satisfies
/// its supertraits, and it only exists when the `rkyv` feature is enabled.
#[cfg(feature = "rkyv")]
pub trait AxisRK: num_traits::Zero + Default + Debug + rkyv::Archive {}
#[cfg(feature = "rkyv")]
impl<T> AxisRK for T where T: num_traits::Zero + Default + Debug + rkyv::Archive {}
/// Archivable (rkyv) counterpart of [`KdTree`], only compiled with the `rkyv` feature.
///
/// It has the same four fields as `KdTree`, but uses the `*RK` node types.
/// It also requires the axis type `A` to be a primitive integer
/// (`num_traits::PrimInt`) instead of a [`Fixed`] type.
/// NOTE(review): presumably `A` here is the raw bit representation of the
/// fixed-point coordinates, so the data can be archived — confirm against the
/// conversion code.
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
#[cfg(feature = "rkyv")]
pub struct KdTreeRK<
    A: num_traits::PrimInt,
    T: Content,
    const K: usize,
    const B: usize,
    IDX: Index<T = IDX>,
> {
    /// Leaf buckets holding up to `B` points of `K` dimensions each.
    pub(crate) leaves: Vec<LeafNodeRK<A, T, K, B, IDX>>,
    /// Internal nodes, each holding two child indices and a split value.
    pub(crate) stems: Vec<StemNodeRK<A, K, IDX>>,
    /// Index of the root node.
    pub(crate) root_index: IDX,
    /// Item count, mirroring `KdTree::size`.
    pub(crate) size: T,
}
/// A k-d tree whose coordinates are fixed-point numbers.
///
/// Type parameters:
/// * `A` – coordinate type (see [`Axis`]); each point is an `[A; K]`.
/// * `T` – item stored alongside each point. The same type is used for the
///   tree's item count.
/// * `K` – number of dimensions.
/// * `B` – bucket size, i.e. the number of point slots in every leaf.
/// * `IDX` – integer type used to index stems and leaves.
///
/// Points are stored in fixed-size leaf buckets. Internal "stem" nodes each hold
/// a `left` index, a `right` index and a split value.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct KdTree<A: Copy + Default, T: Copy + Default, const K: usize, const B: usize, IDX> {
    /// Leaf buckets. A freshly constructed tree holds exactly one empty leaf.
    pub(crate) leaves: Vec<LeafNode<A, T, K, B, IDX>>,
    /// Internal (splitting) nodes.
    pub(crate) stems: Vec<StemNode<A, K, IDX>>,
    /// Index of the root node. It is initialised to `IDX::leaf_offset()`.
    /// NOTE(review): presumably indices at or above `leaf_offset()` address
    /// `leaves` — confirm.
    pub(crate) root_index: IDX,
    /// Number of items stored in the tree, as returned by `size()`.
    pub(crate) size: T,
}
/// Archivable (rkyv) counterpart of `StemNode`; only compiled with the `rkyv` feature.
///
/// NOTE(review): `K` is not used by any field. It is presumably kept so that the
/// parameter list mirrors the tree's parameter list.
#[doc(hidden)]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
#[cfg(feature = "rkyv")]
pub struct StemNodeRK<A: num_traits::PrimInt, const K: usize, IDX: Index<T = IDX>> {
    /// Index of the left child (a stem or a leaf — see `KdTree::root_index`).
    pub(crate) left: IDX,
    /// Index of the right child (a stem or a leaf — see `KdTree::root_index`).
    pub(crate) right: IDX,
    /// Value at which this node splits its region.
    pub(crate) split_val: A,
}
/// Internal node of a `KdTree`, holding two child indices and a split value.
///
/// NOTE(review): the dimension a stem splits on is not stored here. It is
/// presumably derived from the node's depth — confirm in the query code. `K`
/// is not used by any field.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct StemNode<A: Copy + Default, const K: usize, IDX> {
    /// Index of the left child (a stem or a leaf).
    pub(crate) left: IDX,
    /// Index of the right child (a stem or a leaf).
    pub(crate) right: IDX,
    /// Value at which this node splits its region.
    pub(crate) split_val: A,
}
/// Archivable (rkyv) counterpart of `LeafNode`; only compiled with the `rkyv` feature.
#[doc(hidden)]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
#[cfg(feature = "rkyv")]
pub struct LeafNodeRK<
    A: num_traits::PrimInt,
    T: Content,
    const K: usize,
    const B: usize,
    IDX: Index<T = IDX>,
> {
    /// `B` point slots of `K` coordinates each.
    pub(crate) content_points: [[A; K]; B],
    /// `B` item slots, parallel to `content_points`.
    pub(crate) content_items: [T; B],
    /// Number of occupied slots (counted from index 0).
    pub(crate) size: IDX,
}
/// Leaf bucket of a `KdTree`, with `B` fixed slots.
///
/// Slot `i` pairs `content_points[i]` with `content_items[i]`. Only the first
/// `size` slots are meaningful; the rest keep whatever values they were
/// initialised or last written with.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct LeafNode<
    A: Copy + Default,
    T: Copy + Default,
    const K: usize,
    const B: usize,
    IDX,
> {
    /// `B` point slots of `K` coordinates each.
    // Custom (de)serializers are used for the nested arrays. Presumably this is
    // because serde's built-in array impls only cover lengths 0..=32 and are
    // not generic over const `N`.
    #[cfg_attr(
        feature = "serde",
        serde(with = "crate::custom_serde::array_of_arrays")
    )]
    // An explicit `bound` replaces the bounds serde would otherwise infer for
    // this field.
    #[cfg_attr(
        feature = "serde",
        serde(bound(
            serialize = "A: Serialize",
            deserialize = "A: Deserialize<'de> + Copy + Default"
        ))
    )]
    pub(crate) content_points: [[A; K]; B],
    /// `B` item slots, parallel to `content_points`.
    #[cfg_attr(feature = "serde", serde(with = "crate::custom_serde::array"))]
    // NOTE(review): this field only stores `T`; the `A: ...` parts of the bound
    // look unnecessary here — confirm before removing.
    #[cfg_attr(
        feature = "serde",
        serde(bound(
            serialize = "A: Serialize, T: Serialize",
            deserialize = "A: Deserialize<'de>, T: Deserialize<'de> + Copy + Default"
        ))
    )]
    pub(crate) content_items: [T; B],
    /// Number of occupied slots (counted from index 0).
    pub(crate) size: IDX,
}
impl<A, T, const K: usize, const B: usize, IDX> LeafNode<A, T, K, B, IDX>
where
A: Axis,
T: Content,
IDX: Index<T = IDX>,
{
pub(crate) fn new() -> Self {
Self {
content_points: [[A::ZERO; K]; B],
content_items: [T::zero(); B],
size: IDX::zero(),
}
}
}
impl<A, T, const K: usize, const B: usize, IDX> Default for KdTree<A, T, K, B, IDX>
where
A: Axis,
T: Content,
IDX: Index<T = IDX>,
usize: Cast<IDX>,
{
fn default() -> Self {
Self::new()
}
}
impl<A, T, const K: usize, const B: usize, IDX> KdTree<A, T, K, B, IDX>
where
    A: Axis,
    T: Content,
    IDX: Index<T = IDX>,
    usize: Cast<IDX>,
{
    /// Creates an empty tree with storage pre-allocated for `B * 10` items.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as `with_capacity`, i.e. if `B * 10`
    /// exceeds `IDX::capacity_with_bucket_size(B)`.
    #[inline]
    pub fn new() -> Self {
        KdTree::with_capacity(B * 10)
    }

    /// Creates an empty tree with storage pre-allocated for about `capacity` items.
    ///
    /// The capacity is only an allocation hint, and `0` is accepted. The returned
    /// tree has a size of zero and contains one empty leaf.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` exceeds `IDX::capacity_with_bucket_size(B)`.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        assert!(capacity <= <IDX as Index>::capacity_with_bucket_size(B));

        // One leaf is always pushed below, so never reserve fewer than one.
        let leaf_capacity = DivCeil::div_ceil(capacity, B.az::<usize>()).max(1);
        // Each stem has both a `left` and a `right` child, so a tree with `n`
        // leaves has `n - 1` stems. `leaf_capacity >= 1`, so this cannot underflow.
        let stem_capacity = leaf_capacity - 1;

        let mut tree = Self {
            size: T::zero(),
            stems: Vec::with_capacity(stem_capacity),
            leaves: Vec::with_capacity(leaf_capacity),
            root_index: <IDX as Index>::leaf_offset(),
        };

        tree.leaves.push(LeafNode::new());

        tree
    }

    /// Returns the number of items stored in the tree.
    #[inline]
    pub fn size(&self) -> T {
        self.size
    }

    /// Returns an iterator over every stored `(item, point)` pair.
    ///
    /// Callers should not rely on any particular visiting order.
    pub fn iter(&self) -> impl Iterator<Item = (T, [A; K])> + '_ {
        TreeIter::new(self, B)
    }
}
impl<A: Axis, T: Content, const K: usize, const B: usize, IDX: Index<T = IDX>>
    IterableTreeData<A, T, K> for KdTree<A, T, K, B, IDX>
{
    /// Appends the occupied `(item, point)` pairs of leaf `idx` to `out`.
    ///
    /// Returns the leaf's occupancy count, or `None` if `idx` is not a valid
    /// leaf index. The number of pairs appended is capped at `B`.
    fn get_leaf_data(&self, idx: usize, out: &mut Vec<(T, [A; K])>) -> Option<usize> {
        self.leaves.get(idx).map(|leaf| {
            let occupied: usize = leaf.size.cast();
            let entries = leaf
                .content_items
                .iter()
                .cloned()
                .zip(leaf.content_points.iter().cloned())
                .take(occupied);
            out.extend(entries);
            occupied
        })
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use fixed::types::extra::U14;
    use fixed::FixedU16;

    use crate::fixed::kdtree::KdTree;

    type Fxd = FixedU16<U14>;

    /// Builds the 4-D point `[x, y, x, y]`. Every `can_serde` fixture has this shape.
    #[cfg(feature = "serde")]
    fn mirrored(x: f64, y: f64) -> [Fxd; 4] {
        let (fx, fy) = (Fxd::from_num(x), Fxd::from_num(y));
        [fx, fy, fx, fy]
    }

    #[test]
    fn it_can_be_constructed_with_new() {
        let tree: KdTree<Fxd, u32, 4, 32, u32> = KdTree::new();
        assert_eq!(tree.size(), 0);
    }

    #[test]
    fn it_can_be_constructed_with_a_defined_capacity() {
        let tree: KdTree<Fxd, u32, 4, 32, u32> = KdTree::with_capacity(10);
        assert_eq!(tree.size(), 0);
    }

    #[test]
    fn it_can_be_constructed_with_a_capacity_of_zero() {
        let tree: KdTree<Fxd, u32, 4, 32, u32> = KdTree::with_capacity(0);
        assert_eq!(tree.size(), 0);
    }

    /// A tree round-tripped through JSON must compare equal to the original.
    #[cfg(feature = "serde")]
    #[test]
    fn can_serde() {
        let mut tree: KdTree<Fxd, u32, 4, 32, u32> = KdTree::new();

        let content_to_add: [([Fxd; 4], u32); 16] = [
            (mirrored(0.9, 0.0), 9),
            (mirrored(0.4, 0.5), 4),
            (mirrored(0.12, 0.3), 12),
            (mirrored(0.7, 0.2), 7),
            (mirrored(0.13, 0.4), 13),
            (mirrored(0.6, 0.3), 6),
            (mirrored(0.2, 0.7), 2),
            (mirrored(0.14, 0.5), 14),
            (mirrored(0.3, 0.6), 3),
            (mirrored(0.1, 0.1), 10),
            (mirrored(0.16, 0.7), 16),
            (mirrored(0.1, 0.8), 1),
            (mirrored(0.15, 0.6), 15),
            (mirrored(0.5, 0.4), 5),
            (mirrored(0.8, 0.1), 8),
            (mirrored(0.11, 0.2), 11),
        ];

        for (point, item) in content_to_add {
            tree.add(&point, item);
        }
        assert_eq!(tree.size(), 16);

        let serialized = serde_json::to_string(&tree).unwrap();
        println!("JSON: {:?}", &serialized);

        let deserialized: KdTree<Fxd, u32, 4, 32, u32> =
            serde_json::from_str(&serialized).unwrap();
        assert_eq!(tree, deserialized);
    }

    /// `iter()` must yield exactly the pairs that were added. With B = 2 and
    /// three items, this also exercises more than one leaf.
    #[test]
    fn can_iterate() {
        let mut tree: KdTree<Fxd, u32, 2, 2, u32> = KdTree::new();

        let content_to_add: [(u32, [Fxd; 2]); 3] = [
            (9, [Fxd::from_num(0.9), Fxd::from_num(0)]),
            (4, [Fxd::from_num(0.4), Fxd::from_num(0.5)]),
            (12, [Fxd::from_num(0.12), Fxd::from_num(0.3)]),
        ];

        let mut expected: HashMap<u32, [Fxd; 2]> = HashMap::new();
        for (item, point) in content_to_add {
            tree.add(&point, item);
            expected.insert(item, point);
        }

        let actual: HashMap<u32, [Fxd; 2]> = tree.iter().collect();
        assert_eq!(actual, expected);
    }
}