use core::{
fmt,
iter::{Enumerate, FusedIterator},
slice::Iter,
};
#[cfg(feature = "nightly")]
use core::{
iter::FilterMap,
simd::{cmp::SimdPartialEq, usizex64},
};
use crate::raw::{
representation::assert_valid_range_bounds, Header, InnerNode, InnerNode48, InnerNodeCommon,
InnerNodeIndirect, InnerNodeSorted, Node, NodeType, OpaqueNodePtr,
};
/// Inner node that stores its children in a direct-mapped table of 256 slots.
///
/// The child for key fragment `b` (if any) lives at `child_pointers[b]`, so lookups,
/// insertions and removals are a single array access.
///
/// The nightly `min`/`max` implementations in this file reinterpret `child_pointers` as
/// `[usize; 256]`, so the layout of this struct and of its slot type must not change
/// without revisiting that code.
#[repr(C, align(8))]
pub struct InnerNodeDirect<K, V, const PREFIX_LEN: usize> {
    /// Header shared by inner nodes; the child count is maintained through it
    /// (`inc_num_children` / `dec_num_children`).
    pub header: Header<PREFIX_LEN>,
    /// One optional child pointer per possible key byte; `None` marks an empty slot.
    pub child_pointers: [Option<OpaqueNodePtr<K, V, PREFIX_LEN>>; 256],
}
impl<K, V, const PREFIX_LEN: usize, const OTHER_SIZE: usize>
    From<&InnerNodeSorted<K, V, PREFIX_LEN, OTHER_SIZE>> for InnerNodeDirect<K, V, PREFIX_LEN>
{
    /// Builds a direct-mapped node from a sorted node.
    ///
    /// The header is cloned and every `(key fragment, child)` pair produced by
    /// `value.inner_iter()` is stored in the slot indexed by its key fragment.
    fn from(value: &InnerNodeSorted<K, V, PREFIX_LEN, OTHER_SIZE>) -> Self {
        // Start from an empty table that already carries the source header, then fill it in.
        let mut node = InnerNodeDirect {
            header: value.header.clone(),
            child_pointers: [None; 256],
        };

        for (byte, child) in value.inner_iter() {
            node.child_pointers[usize::from(byte)] = Some(child);
        }

        node
    }
}
impl<K, V, const PREFIX_LEN: usize, const OTHER_SIZE: usize>
From<&InnerNodeIndirect<K, V, PREFIX_LEN, OTHER_SIZE>> for InnerNodeDirect<K, V, PREFIX_LEN>
{
fn from(value: &InnerNodeIndirect<K, V, PREFIX_LEN, OTHER_SIZE>) -> Self {
let header = value.header.clone();
let mut child_pointers = [None; 256];
let initialized_child_pointers = value.initialized_child_pointers();
for (key_fragment, idx) in value.child_indices.iter().enumerate() {
let Some(idx) = idx else {
continue;
};
let idx = usize::from(*idx);
unsafe {
core::hint::assert_unchecked(idx < initialized_child_pointers.len());
}
let child_pointer = initialized_child_pointers[idx];
child_pointers[key_fragment] = Some(child_pointer);
}
InnerNodeDirect {
header,
child_pointers,
}
}
}
impl<K, V, const PREFIX_LEN: usize> fmt::Debug for InnerNodeDirect<K, V, PREFIX_LEN> {
    /// Formats the node as a struct named `InnerNodeDirect` with its `header` and
    /// `child_pointers` fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("InnerNodeDirect");
        out.field("header", &self.header);
        out.field("child_pointers", &self.child_pointers);
        out.finish()
    }
}
impl<K, V, const PREFIX_LEN: usize> Clone for InnerNodeDirect<K, V, PREFIX_LEN> {
    /// Returns a node with a cloned header and a copy of the pointer table.
    ///
    /// Only the pointers are copied; the children they refer to are not duplicated.
    fn clone(&self) -> Self {
        let header = self.header.clone();
        // The slot array is copied by value.
        let child_pointers = self.child_pointers;

        Self {
            header,
            child_pointers,
        }
    }
}
// NOTE(review): `InnerNodeCommon` is an `unsafe` trait; its safety contract is declared
// elsewhere and is not restated here.
unsafe impl<K, V, const PREFIX_LEN: usize> InnerNodeCommon<K, V, PREFIX_LEN>
    for InnerNodeDirect<K, V, PREFIX_LEN>
{
    /// Iterator over the occupied slots (stable toolchains): a named wrapper type.
    #[cfg(not(feature = "nightly"))]
    type Iter<'a>
        = NodeDirectIter<'a, K, V, PREFIX_LEN>
    where
        Self: 'a;
    /// Iterator over the occupied slots (nightly): a `FilterMap` over the enumerated slots.
    #[cfg(feature = "nightly")]
    type Iter<'a>
        = FilterMap<
        Enumerate<Iter<'a, Option<OpaqueNodePtr<K, V, PREFIX_LEN>>>>,
        impl FnMut(
            (usize, &'a Option<OpaqueNodePtr<K, V, PREFIX_LEN>>),
        ) -> Option<(u8, OpaqueNodePtr<K, V, PREFIX_LEN>)>,
    >
    where
        Self: 'a;

    /// Returns a reference to the node header.
    fn header(&self) -> &Header<PREFIX_LEN> {
        &self.header
    }

    /// Creates a node with the given header and every slot empty.
    fn from_header(header: Header<PREFIX_LEN>) -> Self {
        Self {
            header,
            child_pointers: [None; 256],
        }
    }

    /// Returns the child stored for `key_fragment`, or `None` if that slot is empty.
    fn lookup_child(&self, key_fragment: u8) -> Option<OpaqueNodePtr<K, V, PREFIX_LEN>> {
        let slot = usize::from(key_fragment);
        self.child_pointers[slot]
    }

    /// Stores `child_pointer` in the slot for `key_fragment`.
    ///
    /// The child count in the header is incremented only when the slot was previously
    /// empty; overwriting an existing child leaves the count unchanged.
    fn write_child(&mut self, key_fragment: u8, child_pointer: OpaqueNodePtr<K, V, PREFIX_LEN>) {
        let slot = &mut self.child_pointers[usize::from(key_fragment)];
        if slot.replace(child_pointer).is_none() {
            self.header.inc_num_children();
        }
    }

    /// Empties the slot for `key_fragment` and returns the child that was stored there.
    ///
    /// The child count in the header is decremented only when a child was actually removed.
    fn remove_child(&mut self, key_fragment: u8) -> Option<OpaqueNodePtr<K, V, PREFIX_LEN>> {
        let slot = &mut self.child_pointers[usize::from(key_fragment)];
        // Nothing stored: return `None` without touching the header.
        let removed_child = slot.take()?;
        self.header.dec_num_children();
        Some(removed_child)
    }

    /// Iterates over the occupied slots in ascending key order, yielding `(key byte, child)`.
    fn iter(&self) -> Self::Iter<'_> {
        let slots = self.child_pointers.iter().enumerate();
        #[cfg(not(feature = "nightly"))]
        {
            NodeDirectIter { it: slots }
        }
        #[cfg(feature = "nightly")]
        {
            // The slot index is always below 256, so the cast to `u8` is lossless.
            slots.filter_map(|(byte, slot)| slot.map(|child| (byte as u8, child)))
        }
    }

    /// Iterates over the occupied slots whose key byte lies within `bound`, in ascending
    /// key order.
    ///
    /// # Panics
    ///
    /// The bounds are first checked by `assert_valid_range_bounds`; the tests in this file
    /// show it panicking when start and end are equal and both excluded, and when the start
    /// is greater than the end.
    fn range(
        &self,
        bound: impl core::ops::RangeBounds<u8>,
    ) -> impl DoubleEndedIterator<Item = (u8, OpaqueNodePtr<K, V, PREFIX_LEN>)> + FusedIterator
    {
        assert_valid_range_bounds(&bound);

        let lower = bound.start_bound().map(|byte| usize::from(*byte));
        // Key byte of the first slot in the selected window; positions reported by
        // `enumerate()` below are relative to it.
        let first_key = match bound.start_bound() {
            core::ops::Bound::Unbounded => 0,
            core::ops::Bound::Included(byte) => *byte,
            core::ops::Bound::Excluded(byte) => byte.saturating_add(1),
        };
        let upper = bound.end_bound().map(|byte| usize::from(*byte));

        let window = &self.child_pointers[(lower, upper)];
        window
            .iter()
            .enumerate()
            .filter_map(move |(offset, slot)| {
                let child = (*slot)?;
                Some(((offset as u8).saturating_add(first_key), child))
            })
    }

    /// Returns the occupied slot with the smallest key byte (SIMD implementation).
    ///
    /// Assumes the node has at least one child.
    #[cfg(feature = "nightly")]
    #[cfg_attr(test, mutants::skip)]
    fn min(&self) -> (u8, OpaqueNodePtr<K, V, PREFIX_LEN>) {
        // SAFETY: assumes `Option<OpaqueNodePtr<..>>` has the size and alignment of `usize`
        // and that `None` is represented as 0 — this depends on the definition of
        // `OpaqueNodePtr`, which is not visible here; TODO confirm.
        let raw_slots: &[usize; 256] = unsafe { core::mem::transmute(&self.child_pointers) };
        let zero = usizex64::splat(0);
        // One mask per quarter of the table; bit `i` is set when lane `i` equals zero,
        // i.e. when that slot is empty.
        let vacant_0 = usizex64::from_array(raw_slots[0..64].try_into().unwrap())
            .simd_eq(zero)
            .to_bitmask();
        let vacant_1 = usizex64::from_array(raw_slots[64..128].try_into().unwrap())
            .simd_eq(zero)
            .to_bitmask();
        let vacant_2 = usizex64::from_array(raw_slots[128..192].try_into().unwrap())
            .simd_eq(zero)
            .to_bitmask();
        let vacant_3 = usizex64::from_array(raw_slots[192..256].try_into().unwrap())
            .simd_eq(zero)
            .to_bitmask();
        // In the first quarter that is not entirely empty, the number of trailing ones is
        // the number of empty slots preceding the first occupied one.
        let key = if vacant_0 != u64::MAX {
            vacant_0.trailing_ones()
        } else if vacant_1 != u64::MAX {
            vacant_1.trailing_ones() + 64
        } else if vacant_2 != u64::MAX {
            vacant_2.trailing_ones() + 128
        } else {
            vacant_3.trailing_ones() + 192
        } as usize;
        // SAFETY: `key` stays below 256 as long as at least one slot is occupied, which is
        // the invariant this method relies on.
        unsafe {
            core::hint::assert_unchecked(key < self.child_pointers.len());
        }
        (key as u8, unsafe {
            // SAFETY: `key` names a lane that compared non-zero, so the slot is occupied
            // (under the layout assumption stated at the top of this function).
            self.child_pointers[key].unwrap_unchecked()
        })
    }

    /// Returns the occupied slot with the smallest key byte.
    ///
    /// # Panics
    ///
    /// Panics if every slot is empty.
    #[cfg(not(feature = "nightly"))]
    fn min(&self) -> (u8, OpaqueNodePtr<K, V, PREFIX_LEN>) {
        let first_occupied = self
            .child_pointers
            .iter()
            .enumerate()
            .find_map(|(byte, slot)| slot.map(|child| (byte as u8, child)));

        let Some(smallest) = first_occupied else {
            unreachable!("inner node must have non-zero number of children");
        };
        smallest
    }

    /// Returns the occupied slot with the largest key byte (SIMD implementation).
    ///
    /// Assumes the node has at least one child.
    #[cfg(feature = "nightly")]
    #[cfg_attr(test, mutants::skip)]
    fn max(&self) -> (u8, OpaqueNodePtr<K, V, PREFIX_LEN>) {
        // SAFETY: assumes `Option<OpaqueNodePtr<..>>` has the size and alignment of `usize`
        // and that `None` is represented as 0 — this depends on the definition of
        // `OpaqueNodePtr`, which is not visible here; TODO confirm.
        let raw_slots: &[usize; 256] = unsafe { core::mem::transmute(&self.child_pointers) };
        let zero = usizex64::splat(0);
        // One mask per quarter of the table; bit `i` is set when lane `i` equals zero,
        // i.e. when that slot is empty.
        let vacant_0 = usizex64::from_array(raw_slots[0..64].try_into().unwrap())
            .simd_eq(zero)
            .to_bitmask();
        let vacant_1 = usizex64::from_array(raw_slots[64..128].try_into().unwrap())
            .simd_eq(zero)
            .to_bitmask();
        let vacant_2 = usizex64::from_array(raw_slots[128..192].try_into().unwrap())
            .simd_eq(zero)
            .to_bitmask();
        let vacant_3 = usizex64::from_array(raw_slots[192..256].try_into().unwrap())
            .simd_eq(zero)
            .to_bitmask();
        // Scan the quarters from the top; the number of leading ones is the number of empty
        // slots following the last occupied one in that quarter. The subtraction in the
        // final branch cannot underflow as long as at least one slot is occupied.
        let key = if vacant_3 != u64::MAX {
            255 - vacant_3.leading_ones()
        } else if vacant_2 != u64::MAX {
            191 - vacant_2.leading_ones()
        } else if vacant_1 != u64::MAX {
            127 - vacant_1.leading_ones()
        } else {
            63 - vacant_0.leading_ones()
        } as usize;
        // SAFETY: every branch above yields a value of at most 255.
        unsafe {
            core::hint::assert_unchecked(key < self.child_pointers.len());
        }
        (key as u8, unsafe {
            // SAFETY: `key` names a lane that compared non-zero, so the slot is occupied
            // (under the layout assumption stated at the top of this function).
            self.child_pointers[key].unwrap_unchecked()
        })
    }

    /// Returns the occupied slot with the largest key byte.
    ///
    /// # Panics
    ///
    /// Panics if every slot is empty.
    #[cfg(not(feature = "nightly"))]
    fn max(&self) -> (u8, OpaqueNodePtr<K, V, PREFIX_LEN>) {
        let last_occupied = self
            .child_pointers
            .iter()
            .enumerate()
            .rev()
            .find_map(|(byte, slot)| slot.map(|child| (byte as u8, child)));

        let Some(largest) = last_occupied else {
            unreachable!("inner node must have non-zero number of children");
        };
        largest
    }
}
/// Registers `InnerNodeDirect` as the `Node256` node type.
impl<K, V, const PREFIX_LEN: usize> Node<PREFIX_LEN> for InnerNodeDirect<K, V, PREFIX_LEN> {
    /// Key type parameter of the node.
    type Key = K;
    /// Value type parameter of the node.
    type Value = V;
    /// Runtime tag identifying this node type.
    const TYPE: NodeType = NodeType::Node256;
}
impl<K, V, const PREFIX_LEN: usize> InnerNode<PREFIX_LEN> for InnerNodeDirect<K, V, PREFIX_LEN> {
    /// There is no larger node type, so the "grown" type is this type itself.
    type GrownNode = Self;
    /// The next smaller node type.
    type ShrunkNode = InnerNode48<K, V, PREFIX_LEN>;

    /// Always panics: this node already has a slot for every possible key byte.
    fn grow(&self) -> Self::GrownNode {
        panic!("unable to grow a Node256, something went wrong!")
    }

    /// Converts this node into the next smaller node type.
    ///
    /// The conversion is implemented elsewhere; the tests in this file expect it to panic
    /// when the node holds more than 48 children.
    fn shrink(&self) -> Self::ShrunkNode {
        Into::<Self::ShrunkNode>::into(self)
    }
}
/// Iterator over the occupied slots of an [`InnerNodeDirect`], used when the `nightly`
/// feature is disabled.
///
/// Yields `(key byte, child pointer)` pairs and skips empty slots.
#[cfg(not(feature = "nightly"))]
pub struct NodeDirectIter<'a, K, V, const PREFIX_LEN: usize> {
    /// Enumerated view of the 256 slots; the enumeration index is the key byte.
    pub(crate) it: Enumerate<Iter<'a, Option<OpaqueNodePtr<K, V, PREFIX_LEN>>>>,
}
#[cfg(not(feature = "nightly"))]
impl<K, V, const PREFIX_LEN: usize> Iterator for NodeDirectIter<'_, K, V, PREFIX_LEN> {
    type Item = (u8, OpaqueNodePtr<K, V, PREFIX_LEN>);

    /// Advances past empty slots and returns the next occupied one, or `None` once the
    /// remaining slots are exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        self.it
            .find_map(|(byte, slot)| slot.map(|child| (byte as u8, child)))
    }
}
#[cfg(not(feature = "nightly"))]
impl<K, V, const PREFIX_LEN: usize> DoubleEndedIterator for NodeDirectIter<'_, K, V, PREFIX_LEN> {
    /// Walks backwards past empty slots and returns the last occupied one that has not
    /// been yielded yet, or `None` once the remaining slots are exhausted.
    fn next_back(&mut self) -> Option<Self::Item> {
        loop {
            // `?` ends the iteration as soon as the underlying slot iterator is exhausted.
            let (byte, slot) = self.it.next_back()?;
            if let Some(child) = slot {
                return Some((byte as u8, *child));
            }
        }
    }
}
// Marker impl: declares that `NodeDirectIter` keeps returning `None` once it has returned
// `None`.
#[cfg(not(feature = "nightly"))]
impl<K, V, const PREFIX_LEN: usize> FusedIterator for NodeDirectIter<'_, K, V, PREFIX_LEN> {}
#[cfg(test)]
mod tests {
    use alloc::{boxed::Box, vec::Vec};
    use core::ops::{Bound, RangeBounds};

    use super::*;
    use crate::raw::{
        representation::tests::{
            inner_node_min_max_test, inner_node_remove_child_test, inner_node_shrink_test,
            inner_node_write_child_test, FixtureReturn,
        },
        LeafNode, NodePtr,
    };

    /// Collects `node.range(bound)` and asserts that it yields exactly `expected_pairs`,
    /// in order.
    #[track_caller]
    fn check_range<K, V, const PREFIX_LEN: usize, const N: usize>(
        node: &InnerNodeDirect<K, V, PREFIX_LEN>,
        bound: impl RangeBounds<u8>,
        expected_pairs: [(u8, OpaqueNodePtr<K, V, PREFIX_LEN>); N],
    ) {
        let actual_pairs = node.range(bound).collect::<Vec<_>>();
        assert_eq!(actual_pairs, expected_pairs);
    }

    #[test]
    fn lookup() {
        let mut node = InnerNodeDirect::<Box<[u8]>, (), 16>::empty();
        let mut l1 = LeafNode::with_no_siblings(Box::from([]), ());
        let mut l2 = LeafNode::with_no_siblings(Box::from([]), ());
        let mut l3 = LeafNode::with_no_siblings(Box::from([]), ());
        let l1_ptr = NodePtr::from(&mut l1).to_opaque();
        let l2_ptr = NodePtr::from(&mut l2).to_opaque();
        let l3_ptr = NodePtr::from(&mut l3).to_opaque();

        // Nothing has been stored yet.
        assert!(node.lookup_child(123).is_none());

        // Populate the table directly, keeping the header count in step by hand.
        for _ in 0..3 {
            node.header.inc_num_children();
        }
        node.child_pointers[1] = Some(l1_ptr);
        node.child_pointers[123] = Some(l2_ptr);
        node.child_pointers[3] = Some(l3_ptr);

        assert_eq!(node.lookup_child(123), Some(l2_ptr));
    }

    #[test]
    fn write_child() {
        let empty_node = InnerNodeDirect::<_, _, 16>::empty();
        inner_node_write_child_test(empty_node, 256)
    }

    #[test]
    fn remove_child() {
        let empty_node = InnerNodeDirect::<_, _, 16>::empty();
        inner_node_remove_child_test(empty_node, 256)
    }

    #[test]
    #[should_panic = "unable to grow a Node256, something went wrong!"]
    fn grow() {
        let node = InnerNodeDirect::<Box<[u8]>, (), 16>::empty();
        node.grow();
    }

    #[test]
    fn shrink() {
        let empty_node = InnerNodeDirect::<_, _, 16>::empty();
        inner_node_shrink_test(empty_node, 48);
    }

    #[test]
    #[should_panic = "Cannot shrink a InnerNodeDirect when it has more than 48 children. Currently \
                      has [49] children."]
    fn shrink_too_many_children_panic() {
        let empty_node = InnerNodeDirect::<_, _, 16>::empty();
        inner_node_shrink_test(empty_node, 49);
    }

    #[test]
    fn min_max() {
        let empty_node = InnerNodeDirect::<_, _, 16>::empty();
        inner_node_min_max_test(empty_node, 256);
    }

    /// Node with children at key bytes 0, 3, 85 and 255, i.e. both edges of the table are
    /// occupied.
    fn fixture() -> FixtureReturn<InnerNodeDirect<Box<[u8]>, (), 16>, 4> {
        let mut node = InnerNodeDirect::empty();
        let mut l1 = LeafNode::with_no_siblings(vec![].into(), ());
        let mut l2 = LeafNode::with_no_siblings(vec![].into(), ());
        let mut l3 = LeafNode::with_no_siblings(vec![].into(), ());
        let mut l4 = LeafNode::with_no_siblings(vec![].into(), ());
        let l1_ptr = NodePtr::from(&mut l1).to_opaque();
        let l2_ptr = NodePtr::from(&mut l2).to_opaque();
        let l3_ptr = NodePtr::from(&mut l3).to_opaque();
        let l4_ptr = NodePtr::from(&mut l4).to_opaque();

        node.write_child(3, l1_ptr);
        node.write_child(255, l2_ptr);
        node.write_child(0u8, l3_ptr);
        node.write_child(85, l4_ptr);

        (node, [l1, l2, l3, l4], [l1_ptr, l2_ptr, l3_ptr, l4_ptr])
    }

    #[test]
    fn iterate() {
        let (node, _, [l1_ptr, l2_ptr, l3_ptr, l4_ptr]) = fixture();

        let visited = node.iter().collect::<Vec<_>>();

        assert_eq!(
            visited,
            [(0u8, l3_ptr), (3, l1_ptr), (85, l4_ptr), (255, l2_ptr)]
        );
    }

    #[test]
    fn iterate_rev() {
        let (node, _, [l1_ptr, l2_ptr, l3_ptr, l4_ptr]) = fixture();

        let visited = node.iter().rev().collect::<Vec<_>>();

        assert_eq!(
            visited,
            [(255, l2_ptr), (85, l4_ptr), (3, l1_ptr), (0u8, l3_ptr)]
        );
    }

    #[test]
    fn range_iterate() {
        let (node, _, [l1_ptr, l2_ptr, l3_ptr, l4_ptr]) = fixture();

        check_range(
            &node,
            (Bound::Included(0), Bound::Included(3)),
            [(0u8, l3_ptr), (3, l1_ptr)],
        );
        check_range(&node, (Bound::Excluded(0), Bound::Excluded(3)), []);
        check_range(
            &node,
            (Bound::Included(0), Bound::Included(0)),
            [(0u8, l3_ptr)],
        );
        check_range(
            &node,
            (Bound::Included(0), Bound::Included(255)),
            [(0u8, l3_ptr), (3, l1_ptr), (85, l4_ptr), (255, l2_ptr)],
        );
        check_range(
            &node,
            (Bound::Included(255), Bound::Included(255)),
            [(255, l2_ptr)],
        );
        check_range(&node, (Bound::Included(255), Bound::Excluded(255)), []);
        check_range(&node, (Bound::Excluded(255), Bound::Included(255)), []);
        check_range(
            &node,
            (Bound::Excluded(0), Bound::Excluded(255)),
            [(3, l1_ptr), (85, l4_ptr)],
        );
        check_range(
            &node,
            (Bound::<u8>::Unbounded, Bound::Unbounded),
            [(0u8, l3_ptr), (3, l1_ptr), (85, l4_ptr), (255, l2_ptr)],
        );
        check_range(
            &node,
            (Bound::<u8>::Unbounded, Bound::Included(86)),
            [(0u8, l3_ptr), (3, l1_ptr), (85, l4_ptr)],
        );
    }

    /// Node with children at key bytes 2, 3, 85 and 254, i.e. both edges of the table are
    /// left empty.
    fn fixture_empty_edges() -> FixtureReturn<InnerNodeDirect<Box<[u8]>, (), 16>, 4> {
        let mut node = InnerNodeDirect::empty();
        let mut l1 = LeafNode::with_no_siblings(vec![].into(), ());
        let mut l2 = LeafNode::with_no_siblings(vec![].into(), ());
        let mut l3 = LeafNode::with_no_siblings(vec![].into(), ());
        let mut l4 = LeafNode::with_no_siblings(vec![].into(), ());
        let l1_ptr = NodePtr::from(&mut l1).to_opaque();
        let l2_ptr = NodePtr::from(&mut l2).to_opaque();
        let l3_ptr = NodePtr::from(&mut l3).to_opaque();
        let l4_ptr = NodePtr::from(&mut l4).to_opaque();

        node.write_child(3, l1_ptr);
        node.write_child(254, l2_ptr);
        node.write_child(2u8, l3_ptr);
        node.write_child(85, l4_ptr);

        (node, [l1, l2, l3, l4], [l1_ptr, l2_ptr, l3_ptr, l4_ptr])
    }

    #[test]
    fn range_iterate_boundary_conditions() {
        let (node, _, [l1_ptr, l2_ptr, l3_ptr, l4_ptr]) = fixture_empty_edges();

        // Upper bound only.
        check_range(
            &node,
            (Bound::<u8>::Unbounded, Bound::Included(86)),
            [(2u8, l3_ptr), (3, l1_ptr), (85, l4_ptr)],
        );
        check_range(
            &node,
            (Bound::<u8>::Unbounded, Bound::Included(4)),
            [(2u8, l3_ptr), (3, l1_ptr)],
        );
        check_range(
            &node,
            (Bound::<u8>::Unbounded, Bound::Excluded(3)),
            [(2u8, l3_ptr)],
        );
        check_range(
            &node,
            (Bound::<u8>::Unbounded, Bound::Included(2)),
            [(2u8, l3_ptr)],
        );
        check_range(&node, (Bound::<u8>::Unbounded, Bound::Included(1)), []);
        check_range(&node, (Bound::<u8>::Unbounded, Bound::Included(0)), []);

        // Lower bound only.
        check_range(
            &node,
            (Bound::Included(1), Bound::<u8>::Unbounded),
            [(2u8, l3_ptr), (3, l1_ptr), (85, l4_ptr), (254, l2_ptr)],
        );
        check_range(
            &node,
            (Bound::Included(3), Bound::<u8>::Unbounded),
            [(3, l1_ptr), (85, l4_ptr), (254, l2_ptr)],
        );
        check_range(
            &node,
            (Bound::Excluded(84), Bound::<u8>::Unbounded),
            [(85, l4_ptr), (254, l2_ptr)],
        );
        check_range(
            &node,
            (Bound::Included(253), Bound::<u8>::Unbounded),
            [(254, l2_ptr)],
        );
        check_range(&node, (Bound::Included(255), Bound::<u8>::Unbounded), []);
    }

    #[test]
    #[should_panic = "range start and end are equal and excluded: (80)"]
    fn range_iterate_out_of_bounds_panic_both_excluded() {
        let (node, _, [_l1_ptr, _l2_ptr, _l3_ptr, _l4_ptr]) = fixture();

        let equal_excluded = (Bound::Excluded(80), Bound::Excluded(80));
        let pairs = node.range(equal_excluded).collect::<Vec<_>>();
        assert_eq!(pairs, &[]);
    }

    #[test]
    #[should_panic = "range start (80) is greater than range end (0)"]
    fn range_iterate_start_greater_than_end() {
        let (node, _, [_l1_ptr, _l2_ptr, _l3_ptr, _l4_ptr]) = fixture();

        let inverted = (Bound::Excluded(80), Bound::Included(0));
        let _pairs = node.range(inverted).collect::<Vec<_>>();
    }
}