use std::marker::PhantomData;
use ordered_float::NotNan;
use crate::{Point, Scalar};
/// Payload of a `Node`: either a split value or a bucket start index.
///
/// The union itself does not record which field is live. `Node::dispatch_on_type`
/// reads `bucket_start_index` when the node's stored dimension equals `P::DIM`
/// (leaf) and `split_val` otherwise (split node). Reading the field that was not
/// written is not supported: when `T` is wider than `u32`, reading `split_val`
/// from a leaf would observe uninitialized bytes.
union SplitValOrBucketIndex<T: Scalar> {
    /// Split coordinate; written by `Node::new_split_node`.
    split_val: NotNan<T>,
    /// Start index of the leaf's bucket; written by `Node::new_leaf_node`.
    bucket_start_index: u32,
}
/// Packs a dimension and a payload value into a single `u32`.
///
/// The dimension occupies the low bits selected by `P::DIM_MASK`. The payload is
/// stored above it, shifted left by `P::DIM_BIT_COUNT`. For split nodes the payload
/// is a child index; for leaves it is a bucket size.
///
/// In debug builds this asserts two things: that `dim` fits in the dimension bits,
/// and that no payload bits are lost by the shift. Release builds perform the same
/// unchecked packing as before, in which high payload bits would be silently discarded.
fn create_dim_child_bucket_size<T: Scalar, P: Point<T>>(
    dim: u32,
    child_index_or_bucket_size: u32,
) -> u32 {
    debug_assert_eq!(
        dim & P::DIM_MASK,
        dim,
        "dimension {} does not fit in the dimension bits",
        dim
    );
    let shifted = child_index_or_bucket_size << P::DIM_BIT_COUNT;
    // Shifting back must reproduce the input; otherwise high bits were lost.
    debug_assert_eq!(
        shifted >> P::DIM_BIT_COUNT,
        child_index_or_bucket_size,
        "child index / bucket size {} overflows the packed representation",
        child_index_or_bucket_size
    );
    dim | shifted
}
/// A tree node packed into one `u32` plus a union payload.
///
/// `dim_child_bucket_size` holds two values:
/// - a dimension, in the low bits selected by `P::DIM_MASK`;
/// - a second value, stored above the dimension (shifted left by `P::DIM_BIT_COUNT`).
///
/// The dimension field marks the node kind:
/// - split node: it is the split dimension, and the upper bits are the right-child index;
/// - leaf node: it equals `P::DIM`, and the upper bits are the bucket size.
pub(crate) struct Node<T: Scalar, P: Point<T>> {
    dim_child_bucket_size: u32,
    /// `split_val` for split nodes, `bucket_start_index` for leaves.
    split_val_or_bucket_start_index: SplitValOrBucketIndex<T>,
    /// Ties the node to its point type `P` without storing a `P`.
    phantom: PhantomData<P>,
}
impl<T: Scalar, P: Point<T>> Node<T, P> {
    /// Stores `child_index` as the right-child index of this split node.
    ///
    /// The split dimension in the low bits is preserved, and any previously stored
    /// child index is replaced rather than OR-ed into. Must only be called on split
    /// nodes; on a leaf, the upper bits hold the bucket size.
    pub(crate) fn set_child_index(&mut self, child_index: u32) {
        let dim = self.dim_child_bucket_size & P::DIM_MASK;
        debug_assert_ne!(dim, P::DIM, "set_child_index called on a leaf node");
        self.dim_child_bucket_size = create_dim_child_bucket_size::<T, P>(dim, child_index);
    }

    /// Creates a split node that splits on `split_dim` at `split_val`.
    ///
    /// The right-child index starts at 0; set it afterwards with `set_child_index`.
    /// `split_dim` must be less than `P::DIM`, because the value `P::DIM` marks a leaf.
    pub(crate) fn new_split_node(split_dim: u32, split_val: NotNan<T>) -> Self {
        // A split dimension equal to `P::DIM` would make this node look like a leaf,
        // and `dispatch_on_type` would then read the wrong union field.
        debug_assert!(
            split_dim < P::DIM,
            "split dimension {} is out of range (must be < {})",
            split_dim,
            P::DIM
        );
        Node {
            dim_child_bucket_size: split_dim,
            split_val_or_bucket_start_index: SplitValOrBucketIndex { split_val },
            phantom: PhantomData,
        }
    }

    /// Creates a leaf node covering `bucket_size` items, starting at `bucket_start_index`.
    ///
    /// The leaf is marked by storing `P::DIM` as its dimension.
    pub(crate) fn new_leaf_node(bucket_start_index: u32, bucket_size: u32) -> Self {
        Node {
            dim_child_bucket_size: create_dim_child_bucket_size::<T, P>(P::DIM, bucket_size),
            split_val_or_bucket_start_index: SplitValOrBucketIndex { bucket_start_index },
            phantom: PhantomData,
        }
    }

    /// Calls `leaf_cb` or `split_cb` according to the node kind, and returns its result.
    ///
    /// - Leaf (stored dimension == `P::DIM`): `leaf_cb(ctx, bucket_start_index, bucket_size)`.
    /// - Split node: `split_cb(ctx, split_dim, split_val, right_child)`.
    ///
    /// `ctx` is passed to whichever callback runs. This lets both callbacks use the
    /// same mutable state without both closures capturing it.
    #[inline]
    pub(crate) fn dispatch_on_type<Fl, Fs, Ctx, R>(&self, ctx: Ctx, split_cb: Fs, leaf_cb: Fl) -> R
    where
        Fl: FnOnce(Ctx, u32, u32) -> R,
        Fs: FnOnce(Ctx, u32, NotNan<T>, u32) -> R,
    {
        let dim = self.dim_child_bucket_size & P::DIM_MASK;
        let upper = self.dim_child_bucket_size >> P::DIM_BIT_COUNT;
        if dim == P::DIM {
            // SAFETY: within this module, the only constructor that stores `P::DIM` as
            // the dimension is `new_leaf_node`, and it writes `bucket_start_index`.
            let bucket_start_index =
                unsafe { self.split_val_or_bucket_start_index.bucket_start_index };
            leaf_cb(ctx, bucket_start_index, upper)
        } else {
            // SAFETY: any other dimension comes from `new_split_node`, which writes
            // `split_val` and requires `split_dim < P::DIM`.
            let split_val = unsafe { self.split_val_or_bucket_start_index.split_val };
            split_cb(ctx, dim, split_val, upper)
        }
    }
}
impl<T: Scalar, P: Point<T>> std::fmt::Debug for Node<T, P> {
    /// Formats the node according to its kind:
    /// - a split node as `Node(split)`, with its split dimension, split value and
    ///   right-child index;
    /// - a leaf as `Node(leaf)`, with its bucket size and bucket start index.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.dispatch_on_type(
            f,
            |out, dim, val, right| {
                let mut s = out.debug_struct("Node(split)");
                s.field("split_dim", &dim);
                s.field("split_val", &val);
                s.field("right_child", &right);
                s.finish()
            },
            |out, start, size| {
                let mut s = out.debug_struct("Node(leaf)");
                s.field("bucket_size", &size);
                s.field("bucket_start_index", &start);
                s.finish()
            },
        )
    }
}
#[cfg(test)]
mod tests {
    use crate::dummy_point::P2;
    use crate::*;

    /// With `f32` coordinates, a node should be exactly two 32-bit words: the packed
    /// `u32` plus the 4-byte union payload (`PhantomData` adds nothing).
    #[test]
    fn sizes() {
        assert_eq!(std::mem::size_of::<Node<f32, P2>>(), 8);
    }

    /// Checks the bit-count formula `32 - d.leading_zeros()`: representing values up
    /// to 4 takes 3 bits.
    #[test]
    fn dim_bit_count() {
        let d: u32 = 4;
        let dim_bit_count = 32 - d.leading_zeros();
        assert_eq!(dim_bit_count, 3);
    }
}