pub use crate::float::kdtree::Axis;
use crate::float_leaf_slice::leaf_slice::{LeafSlice, LeafSliceFloat, LeafSliceFloatChunk};
#[cfg(feature = "modified_van_emde_boas")]
use crate::modified_van_emde_boas::modified_van_emde_boas_get_child_idx_v2_branchless;
use crate::traits::Content;
use aligned_vec::{avec, AVec, ConstAlign, CACHELINE_ALIGN};
use array_init::array_init;
use az::{Az, Cast};
use cmov::Cmov;
use ordered_float::OrderedFloat;
#[cfg(feature = "rkyv")]
use rkyv::vec::ArchivedVec;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::cmp::PartialEq;
use std::fmt::Debug;
/// An immutable k-d tree over `K`-dimensional points whose coordinates have type `A`,
/// storing one item of type `T` per point.
///
/// `B` is the bucket size: construction derives the number of leaves as
/// `item_count.div_ceil(B)`. Split values are kept in `stems`, while the points themselves
/// are stored per-dimension in `leaf_points` with the matching items in `leaf_items`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct ImmutableKdTree<A: Copy + Default, T: Copy + Default, const K: usize, const B: usize> {
/// Split value of every stem node. Slots are initialised to infinity and overwritten
/// with the pivot coordinate when the corresponding stem is populated.
pub(crate) stems: AVec<A>,
/// Point coordinates in struct-of-arrays form: `leaf_points[dim][i]` is the `dim`
/// coordinate of the `i`-th stored point.
#[cfg_attr(feature = "serde", serde(with = "crate::custom_serde::array_of_vecs"))]
#[cfg_attr(
feature = "serde",
serde(bound(
serialize = "A: Serialize, T: Serialize",
deserialize = "A: Deserialize<'de>, T: Deserialize<'de> + Copy + Default"
))
)]
pub(crate) leaf_points: [Vec<A>; K],
/// Item stored for each point, parallel to `leaf_points`. During construction this is
/// the point's index in the source slice, cast to `T`.
pub(crate) leaf_items: Vec<T>,
/// One `(start, end)` pair per leaf; `start..end` is the range of that leaf within
/// `leaf_points[dim]` and `leaf_items`.
pub(crate) leaf_extents: Vec<(u32, u32)>,
/// Index of the deepest stem level, computed as `log2(next_power_of_two(leaf_count)) - 1`.
/// It is `-1` when the tree has at most one leaf and therefore no stems.
pub(crate) max_stem_level: i32,
}
/// rkyv-serializable counterpart of [`ImmutableKdTree`].
///
/// It carries the same fields, except that `stems` is a plain `Vec<A>` rather than an
/// aligned `AVec<A>`. Only compiled when the `rkyv` feature is enabled.
#[cfg(feature = "rkyv")]
#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct ImmutableKdTreeRK<A: Copy + Default, T: Copy + Default, const K: usize, const B: usize> {
/// Split value of every stem node.
pub(crate) stems: Vec<A>,
/// Point coordinates, one vector per dimension.
pub(crate) leaf_points: [Vec<A>; K],
/// Item stored for each point, parallel to `leaf_points`.
pub(crate) leaf_items: Vec<T>,
/// `(start, end)` range of every leaf within the leaf vectors.
pub(crate) leaf_extents: Vec<(u32, u32)>,
/// Index of the deepest stem level (`-1` when there are no stems).
pub(crate) max_stem_level: i32,
}
/// Converts an [`ImmutableKdTree`] into its rkyv-serializable form.
///
/// All leaf data is moved across unchanged; only `stems` changes container type, from the
/// cache-line-aligned `AVec<A>` to a plain `Vec<A>`.
#[cfg(feature = "rkyv")]
impl<A: Axis, T: Content, const K: usize, const B: usize> From<ImmutableKdTree<A, T, K, B>>
    for ImmutableKdTreeRK<A, T, K, B>
where
    A: Axis + LeafSliceFloat<T> + LeafSliceFloatChunk<T, K>,
    T: Content,
    usize: Cast<T>,
{
    fn from(orig: ImmutableKdTree<A, T, K, B>) -> Self {
        let ImmutableKdTree {
            stems,
            leaf_points,
            leaf_items,
            leaf_extents,
            max_stem_level,
        } = orig;

        // The aligned buffer must NOT be handed to `Vec::from_raw_parts`: `Vec` frees (and
        // reallocates) its buffer using the natural alignment of `A`, whereas the `AVec`
        // allocated it with cache-line alignment. Deallocating with a layout different from
        // the one used for allocation is undefined behaviour, so the stems are copied into
        // a normally-allocated `Vec` instead. This costs one O(n) copy of the stem values.
        let stems: Vec<A> = stems.as_slice().to_vec();

        ImmutableKdTreeRK {
            stems,
            leaf_points,
            leaf_items,
            leaf_extents,
            max_stem_level,
        }
    }
}
/// A queryable view over an rkyv-archived [`ImmutableKdTreeRK`].
///
/// The stems are held in an owned, cache-line-aligned vector, while all leaf data is
/// borrowed for `'a` from the archived tree. Only compiled when the `rkyv` feature is
/// enabled.
#[cfg(feature = "rkyv")]
#[derive(Debug, PartialEq)]
pub struct AlignedArchivedImmutableKdTree<
'a,
A: Copy + Default,
T: Copy + Default,
const K: usize,
const B: usize,
> {
/// Owned, cache-line-aligned copy of the archived stem values.
pub(crate) stems: AVec<A, ConstAlign<CACHELINE_ALIGN>>,
/// Borrowed per-dimension point coordinates.
pub(crate) leaf_points: &'a [ArchivedVec<A>; K],
/// Borrowed items, parallel to `leaf_points`.
pub(crate) leaf_items: &'a ArchivedVec<T>,
/// Borrowed `(start, end)` range of every leaf within the leaf vectors.
pub(crate) leaf_extents: &'a ArchivedVec<(u32, u32)>,
/// Index of the deepest stem level, copied from the archived tree.
pub(crate) max_stem_level: i32,
}
/// Constructors that wrap archived tree data.
#[cfg(feature = "rkyv")]
impl<'a, A, T, const K: usize, const B: usize> AlignedArchivedImmutableKdTree<'a, A, T, K, B>
where
    A: Copy + Default + rkyv::Archive<Archived = A>,
    T: Copy + Default + rkyv::Archive<Archived = T>,
{
    /// Wraps an archived tree.
    ///
    /// The stem values are copied into a freshly allocated cache-line-aligned vector; the
    /// leaf points, items and extents are borrowed from `value` without copying.
    pub(crate) fn new_from(value: &'a ArchivedImmutableKdTreeRK<A, T, K, B>) -> Self {
        let aligned_stems = AVec::from_slice(CACHELINE_ALIGN, &value.stems[..]);

        Self {
            stems: aligned_stems,
            leaf_points: &value.leaf_points,
            leaf_extents: &value.leaf_extents,
            leaf_items: &value.leaf_items,
            max_stem_level: value.max_stem_level,
        }
    }

    /// Interprets `bytes` as an archived `ImmutableKdTreeRK` and wraps it.
    ///
    /// NOTE(review): the buffer is passed to `rkyv::archived_root` without any validation,
    /// so callers must supply bytes that really were produced by serializing an
    /// `ImmutableKdTreeRK<A, T, K, B>` with matching type parameters.
    #[cfg(feature = "rkyv")]
    pub fn from_bytes(bytes: &'a [u8]) -> Self {
        // SAFETY: not checked here; relies entirely on the caller-provided buffer being a
        // valid archive of the expected type (see the note above).
        let archived = unsafe { rkyv::archived_root::<ImmutableKdTreeRK<A, T, K, B>>(bytes) };
        Self::new_from(archived)
    }
}
/// Read accessors for the archived tree view.
#[cfg(feature = "rkyv")]
impl<A, T, const K: usize, const B: usize> AlignedArchivedImmutableKdTree<'_, A, T, K, B>
where
    A: Axis + LeafSliceFloat<T> + LeafSliceFloatChunk<T, K> + rkyv::Archive<Archived = A>,
    T: Content + rkyv::Archive<Archived = T>,
    usize: Cast<T>,
{
    /// Number of items stored in the tree.
    #[inline]
    pub fn size(&self) -> usize {
        self.leaf_items.len()
    }

    /// Returns the points and items belonging to leaf `leaf_idx`.
    ///
    /// The extent lookup is unchecked: `leaf_idx` must be less than the number of leaf
    /// extents, which is the caller's responsibility.
    #[inline]
    pub(crate) fn get_leaf_slice(&self, leaf_idx: usize) -> LeafSlice<A, T, K> {
        // SAFETY: relies on the caller passing an in-bounds `leaf_idx`.
        let extent = unsafe { *self.leaf_extents.get_unchecked(leaf_idx) };
        let lo = extent.0 as usize;
        let hi = extent.1 as usize;

        let points = array_init(|dim| &self.leaf_points[dim][lo..hi]);
        let items = &self.leaf_items[lo..hi];
        LeafSlice::new(points, items)
    }
}
/// Builds a tree directly from a slice of points; equivalent to calling
/// `ImmutableKdTree::new_from_slice`.
impl<A: Axis, T: Content, const K: usize, const B: usize> From<&[[A; K]]>
    for ImmutableKdTree<A, T, K, B>
where
    A: Axis + LeafSliceFloat<T> + LeafSliceFloatChunk<T, K>,
    T: Content,
    usize: Cast<T>,
{
    fn from(slice: &[[A; K]]) -> Self {
        Self::new_from_slice(slice)
    }
}
// Construction and basic accessors for `ImmutableKdTree`.
//
// NOTE(review): `unexpected_cfgs` is allowed presumably because the
// `unreliable_select_nth_unstable` feature referenced below is not a declared cargo
// feature - confirm against Cargo.toml.
#[allow(unexpected_cfgs)]
impl<A, T, const K: usize, const B: usize> ImmutableKdTree<A, T, K, B>
where
A: Axis + LeafSliceFloat<T> + LeafSliceFloatChunk<T, K>,
T: Content,
usize: Cast<T>,
{
/// Builds a tree from `source`, a slice of `K`-dimensional points.
///
/// The item stored for each point is its index within `source`, cast to `T`.
/// An empty `source` yields a tree with no stems and a single, empty leaf.
#[inline]
pub fn new_from_slice(source: &[[A; K]]) -> Self
where
usize: Cast<T>,
{
let item_count = source.len();
// Number of leaves needed when each leaf is sized for up to `B` items.
let leaf_node_count = item_count.div_ceil(B);
// Without the modified van Emde Boas layout the root stem lives at index 1 and the
// children of stem `i` at `2i` / `2i + 1`, so `next_power_of_two()` slots are reserved.
#[cfg(not(feature = "modified_van_emde_boas"))]
let stem_node_count = if leaf_node_count < 2 {
0
} else {
leaf_node_count.next_power_of_two()
};
// With the modified van Emde Boas layout the root stem lives at index 0.
#[cfg(feature = "modified_van_emde_boas")]
let stem_node_count = if leaf_node_count < 2 {
0
} else {
leaf_node_count.next_power_of_two() - 1
};
// Index of the deepest stem level; evaluates to -1 when there are 0 or 1 leaves.
let max_stem_level: i32 = leaf_node_count.next_power_of_two().ilog2() as i32 - 1;
// Over-allocate for the van Emde Boas layout; the unused tail is truncated below.
#[cfg(feature = "modified_van_emde_boas")]
let stem_node_count = stem_node_count * 5;
// Every stem starts as infinity; populated stems are overwritten with finite pivots.
let mut stems = avec![A::infinity(); stem_node_count];
let mut leaf_points: [Vec<A>; K] = array_init(|_| Vec::with_capacity(item_count));
let mut leaf_items: Vec<T> = Vec::with_capacity(item_count);
let mut leaf_extents: Vec<(u32, u32)> = Vec::with_capacity(item_count.div_ceil(B));
// Indices into `source`; partitioned in place while the tree is built.
let mut sort_index = Vec::from_iter(0..item_count);
if stem_node_count == 0 {
// No stems: everything goes into one leaf, in source order.
// NOTE(review): the `as u32` cast silently truncates for more than u32::MAX items.
leaf_extents.push((0u32, sort_index.len() as u32));
(0..sort_index.len()).for_each(|i| {
(0..K).for_each(|dim| leaf_points[dim].push(source[sort_index[i]][dim]));
leaf_items.push(sort_index[i].az::<T>())
});
} else {
#[cfg(not(feature = "modified_van_emde_boas"))]
let initial_stem_idx = 1;
#[cfg(feature = "modified_van_emde_boas")]
let initial_stem_idx = 0;
Self::populate_recursive(
&mut stems,
0,
source,
&mut sort_index,
initial_stem_idx,
0,
0,
max_stem_level,
leaf_node_count * B,
&mut leaf_points,
&mut leaf_items,
&mut leaf_extents,
);
// Walk down from the root for `max_stem_level` steps, taking the right child whenever
// the current stem holds a finite value and the left child otherwise, then drop every
// stem slot past the index reached.
// NOTE(review): presumably this finds the highest stem index in use - confirm.
#[cfg(feature = "modified_van_emde_boas")]
if !stems.is_empty() {
let mut level: usize = 0;
let mut minor_level: u64 = 0;
let mut stem_idx = 0;
loop {
let val = stems[stem_idx];
let is_right_child = val.is_finite();
stem_idx = modified_van_emde_boas_get_child_idx_v2_branchless(
stem_idx as u32,
is_right_child,
minor_level as u32,
) as usize;
level += 1;
minor_level += 1;
// Conditional move: resets `minor_level` to 0 once it reaches 3.
minor_level.cmovnz(&0, u8::from(minor_level == 3));
if level == max_stem_level as usize {
break;
}
}
stems.truncate(stem_idx + 1);
}
}
Self {
stems,
leaf_points,
leaf_items,
leaf_extents,
max_stem_level,
}
}
/// Recursively partitions `sort_index` and fills in the stems and leaves.
///
/// * `stems` - stem storage; the split value for this call is written at `stem_index`.
/// * `dim` - dimension to split on at this level; children use `(dim + 1) % K`.
/// * `source` - the original points.
/// * `sort_index` - indices into `source` covered by this subtree; reordered in place.
/// * `level` - depth of this call; once it exceeds `max_stem_level` a leaf is emitted.
/// * `minor_level` - counter cycling 0, 1, 2 that is passed to the van Emde Boas child
///   index helper.
/// * `capacity` - item capacity budgeted for this subtree; it is split between the
///   children and handed to `calc_pivot`, which currently ignores it.
/// * `leaf_points`, `leaf_items`, `leaf_extents` - leaf storage appended to, left subtree
///   before right subtree.
#[allow(clippy::too_many_arguments)]
fn populate_recursive(
stems: &mut AVec<A, ConstAlign<{ CACHELINE_ALIGN }>>,
dim: usize,
source: &[[A; K]],
sort_index: &mut [usize],
stem_index: usize,
mut level: i32,
mut minor_level: u64,
max_stem_level: i32,
capacity: usize,
leaf_points: &mut [Vec<A>; K],
leaf_items: &mut Vec<T>,
leaf_extents: &mut Vec<(u32, u32)>,
) {
let chunk_length = sort_index.len();
// Below the deepest stem level: write this chunk out as a leaf and stop recursing.
if level > max_stem_level {
leaf_extents.push((
leaf_items.len() as u32,
(leaf_items.len() + chunk_length) as u32,
));
(0..chunk_length).for_each(|i| {
(0..K).for_each(|dim| leaf_points[dim].push(source[sort_index[i]][dim]));
leaf_items.push(sort_index[i].az::<T>())
});
return;
}
let levels_below = max_stem_level - level;
// A full left subtree holds `2^levels_below` leaves of `B` items each.
let left_capacity = (2usize.pow(levels_below as u32) * B).min(capacity);
let right_capacity = capacity.saturating_sub(left_capacity);
let mut pivot = Self::calc_pivot(chunk_length, stem_index, right_capacity);
// An empty chunk has `pivot == chunk_length == 0`; its stem is left as infinity.
if pivot < chunk_length {
pivot = Self::update_pivot(source, sort_index, dim, pivot);
debug_assert!(pivot > 0 || chunk_length == 1);
debug_assert!(
stems[stem_index].is_infinite(),
"Wrote to stem #{:?} for a second time",
stem_index
);
// The split value is the `dim` coordinate of the first item of the upper partition.
stems[stem_index] = source[sort_index[pivot]][dim];
}
#[cfg(feature = "modified_van_emde_boas")]
let left_child_idx = modified_van_emde_boas_get_child_idx_v2_branchless(
stem_index as u32,
false,
minor_level as u32,
) as usize;
#[cfg(feature = "modified_van_emde_boas")]
let right_child_idx = modified_van_emde_boas_get_child_idx_v2_branchless(
stem_index as u32,
true,
minor_level as u32,
) as usize;
// Without the van Emde Boas layout, the children of stem `i` are `2i` and `2i + 1`.
#[cfg(not(feature = "modified_van_emde_boas"))]
let left_child_idx = stem_index << 1;
#[cfg(not(feature = "modified_van_emde_boas"))]
let right_child_idx = (stem_index << 1) + 1;
let (lower_sort_index, upper_sort_index) = sort_index.split_at_mut(pivot);
level += 1;
minor_level += 1;
// Conditional move: resets `minor_level` to 0 once it reaches 3.
minor_level.cmovnz(&0, u8::from(minor_level == 3));
let next_dim = (dim + 1) % K;
Self::populate_recursive(
stems,
next_dim,
source,
lower_sort_index,
left_child_idx,
level,
minor_level,
max_stem_level,
left_capacity,
leaf_points,
leaf_items,
leaf_extents,
);
Self::populate_recursive(
stems,
next_dim,
source,
upper_sort_index,
right_child_idx,
level,
minor_level,
max_stem_level,
right_capacity,
leaf_points,
leaf_items,
leaf_extents,
);
}
/// Partitions `sort_index` around `pivot` on dimension `dim` and returns the adjusted
/// pivot position.
///
/// After the selection, the pivot is moved left while the item at `pivot` compares equal
/// (on `dim`) to the item just before it, stopping once `pivot` reaches 1. A `pivot` of 0
/// is returned unchanged.
///
/// NOTE(review): `select_nth_unstable_by_key` does not sort the items left of the pivot,
/// so only the immediately adjacent item is compared on each step; equal values located
/// further left are not detected. Presumably this relies on an assumption about how the
/// selection clusters equal values - confirm.
#[cfg(not(feature = "unreliable_select_nth_unstable"))]
#[inline]
fn update_pivot(
source: &[[A; K]],
sort_index: &mut [usize],
dim: usize,
mut pivot: usize,
) -> usize {
sort_index.select_nth_unstable_by_key(pivot, |&i| OrderedFloat(source[i][dim]));
if pivot == 0 {
return pivot;
}
// `pivot >= 1` is guaranteed here, so `pivot - 1` cannot underflow.
while source[sort_index[pivot]][dim] == source[sort_index[pivot - 1]][dim] && pivot > 1 {
pivot -= 1;
}
pivot
}
/// Number of items stored in the tree.
#[inline]
pub fn size(&self) -> usize {
self.leaf_items.len()
}
/// Capacity of the tree; always equal to [`size`](Self::size).
#[inline]
pub fn capacity(&self) -> usize {
self.size()
}
/// Initial pivot position for a chunk: its midpoint (`chunk_length / 2`, rounded down).
/// The stem index and right-subtree capacity are currently unused.
fn calc_pivot(chunk_length: usize, _stem_index: usize, _right_capacity: usize) -> usize {
chunk_length >> 1
}
/// Returns the points and items belonging to leaf `leaf_idx`.
///
/// The extent lookup is unchecked: the caller must ensure that `leaf_idx` is less than
/// `self.leaf_extents.len()`.
#[inline]
pub(crate) fn get_leaf_slice(&self, leaf_idx: usize) -> LeafSlice<A, T, K> {
// SAFETY: relies on the caller passing an in-bounds `leaf_idx`.
let (start, end) = unsafe { *self.leaf_extents.get_unchecked(leaf_idx) };
LeafSlice::new(
array_init::array_init(|i| &self.leaf_points[i][start as usize..end as usize]),
&self.leaf_items[start as usize..end as usize],
)
}
}
#[cfg(test)]
mod tests {
    use crate::immutable::float::kdtree::ImmutableKdTree;
    use crate::SquaredEuclidean;
    use ordered_float::OrderedFloat;
    use rand::{Rng, SeedableRng};

    /// Point set used by the multiple-dupes test: many copies of `[4.0, 104.0]` mixed in
    /// with distinct points.
    const DUPE_HEAVY_POINTS: [[f32; 2]; 18] = [
        [4.0, 104.0],
        [2.0, 102.0],
        [3.0, 103.0],
        [4.0, 104.0],
        [4.0, 104.0],
        [4.0, 104.0],
        [4.0, 104.0],
        [7.0, 107.0],
        [8.0, 108.0],
        [9.0, 109.0],
        [10.0, 110.0],
        [4.0, 104.0],
        [12.0, 112.0],
        [13.0, 113.0],
        [4.0, 104.0],
        [4.0, 104.0],
        [17.0, 117.0],
        [18.0, 118.0],
    ];

    /// Draws `count` random 4-D points from a ChaCha8 generator seeded with `seed`.
    fn random_points(count: usize, seed: u64) -> Vec<[f32; 4]> {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
        (0..count).map(|_| rng.gen::<[f32; 4]>()).collect()
    }

    /// Builds a 4-D tree with bucket size 4 from `count` seeded random points.
    fn build_random_tree(count: usize, seed: u64) -> ImmutableKdTree<f32, usize, 4, 4> {
        let points = random_points(count, seed);
        ImmutableKdTree::new_from_slice(&points)
    }

    /// Produces the points `[v, 100 + v]` for `v` in `1..=last`, with the `v == 4` point
    /// appearing twice in a row so that a split can straddle the duplicate.
    fn ascending_points_with_dupe(last: u16) -> Vec<[f32; 2]> {
        let mut points: Vec<[f32; 2]> = Vec::new();
        for v in 1..=last {
            let x = f32::from(v);
            let point = [x, 100.0 + x];
            points.push(point);
            if v == 4 {
                points.push(point);
            }
        }
        points
    }

    #[test]
    fn can_construct_an_empty_tree() {
        let tree = ImmutableKdTree::<f64, u32, 3, 32>::new_from_slice(&[]);
        let _result = tree.nearest_one::<SquaredEuclidean>(&[0.; 3]);
    }

    #[test]
    fn can_construct_optimized_tree_with_straddled_split() {
        let content_to_add = ascending_points_with_dupe(15);
        let _tree = ImmutableKdTree::<f32, usize, 2, 4>::new_from_slice(&content_to_add);
    }

    #[test]
    fn can_construct_optimized_tree_with_straddled_split_2() {
        let content_to_add = ascending_points_with_dupe(18);
        let _tree = ImmutableKdTree::<f32, usize, 2, 4>::new_from_slice(&content_to_add);
    }

    #[test]
    fn can_construct_optimized_tree_with_straddled_split_3() {
        use rand::seq::SliceRandom;

        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(23);
        let mut content_to_add = ascending_points_with_dupe(18);
        content_to_add.shuffle(&mut rng);
        let _tree = ImmutableKdTree::<f32, usize, 2, 4>::new_from_slice(&content_to_add);
    }

    #[test]
    fn can_construct_optimized_tree_with_multiple_dupes() {
        use rand::seq::SliceRandom;

        // Try many different orderings of the same duplicate-heavy point set.
        for seed in 0..1_000 {
            let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
            let mut content_to_add = DUPE_HEAVY_POINTS.to_vec();
            content_to_add.shuffle(&mut rng);
            let _tree = ImmutableKdTree::<f32, usize, 2, 8>::new_from_slice(&content_to_add);
        }
    }

    #[test]
    fn can_construct_optimized_tree_bad_example_0() {
        let tree = build_random_tree(18, 894771);
        println!("tree: {:?}", tree);
    }

    #[test]
    fn can_construct_optimized_tree_bad_example_1() {
        let _tree = build_random_tree(33, 100045);
    }

    #[test]
    fn can_construct_optimized_tree_bad_example_2() {
        let _tree = build_random_tree(155, 480);
    }

    #[test]
    fn can_construct_optimized_tree_bad_example_3() {
        let _tree = build_random_tree(26, 455191);
    }

    #[test]
    fn can_construct_optimized_tree_bad_example_4() {
        let _tree = build_random_tree(21, 131851);
    }

    #[test]
    fn can_construct_optimized_tree_bad_example_5() {
        let _tree = build_random_tree(32, 455191);
    }

    #[test]
    fn can_construct_optimized_tree_bad_example_6() {
        let _tree = build_random_tree(56, 450533);
    }

    #[test]
    fn can_construct_optimized_tree_bad_example_7() {
        let _tree = build_random_tree(18, 992063);
    }

    #[test]
    fn can_construct_optimized_tree_bad_example_8() {
        let _tree = build_random_tree(19, 894771);
    }

    #[test]
    fn can_construct_optimized_tree_bad_example_9() {
        let _tree = build_random_tree(20, 894771);
    }

    #[test]
    fn can_construct_optimized_tree_bad_example_10() {
        let _tree = build_random_tree(36, 375096);
    }

    #[test]
    fn can_construct_optimized_tree_bad_example_11() {
        let _tree = build_random_tree(10000, 257281);
    }

    #[test]
    fn can_construct_optimized_tree_many_dupes() {
        let content_to_add = random_points(8, 0);

        // Repeat every random point six times.
        let mut duped: Vec<[f32; 4]> = Vec::with_capacity(content_to_add.len() * 10);
        for item in content_to_add {
            duped.extend([item; 6]);
        }

        let _tree = ImmutableKdTree::<f32, usize, 4, 8>::new_from_slice(&duped);
    }

    #[test]
    fn can_construct_optimized_tree_medium_rand() {
        use itertools::Itertools;

        const TREE_SIZE: usize = 2usize.pow(19);
        let content_to_add = random_points(TREE_SIZE, 493);

        // Report how many coordinate values are repeated across the whole data set.
        let num_uniq = content_to_add
            .iter()
            .flatten()
            .map(|&x| OrderedFloat(x))
            .unique()
            .count();
        println!("dupes: {:?}", TREE_SIZE * 4 - num_uniq);

        let _tree = ImmutableKdTree::<f32, usize, 4, 4>::new_from_slice(&content_to_add);
    }

    #[test]
    fn can_construct_optimized_tree_large_rand() {
        const TREE_SIZE: usize = 2usize.pow(23);
        let content_to_add = random_points(TREE_SIZE, 493);
        let _tree = ImmutableKdTree::<f32, usize, 4, 32>::new_from_slice(&content_to_add);
    }
}