use super::{AddChildError, ArtNode, CompressedPrefix, NodeHeader};
use crate::persistent_artrie::swizzled_ptr::SwizzledPtr;
/// Capacity of a [`Node16`]: the number of key/child slot pairs it stores.
pub const NODE16_MAX_CHILDREN: usize = 16;

/// Radix-tree inner node holding up to [`NODE16_MAX_CHILDREN`] children.
///
/// Slot `i` pairs `keys[i]` with `children[i]`. Only the first `header.num_children` slots
/// are live; `add_child` keeps those keys in ascending order. `new` and `remove_child` leave
/// the remaining slots as key `0` / `SwizzledPtr::null()`.
#[repr(C, align(16))]
#[derive(Debug, Clone)]
pub struct Node16 {
    /// Shared node header (node type tag, live child count, prefix length).
    pub header: NodeHeader,
    /// Compressed key prefix for this node.
    pub prefix: CompressedPrefix,
    /// Key byte for each slot; only `keys[..header.num_children]` is meaningful.
    pub keys: [u8; NODE16_MAX_CHILDREN],
    /// Child pointer for each slot, paired by index with `keys`.
    pub children: [SwizzledPtr; NODE16_MAX_CHILDREN],
}
impl Node16 {
    /// Creates an empty `Node16`: no prefix, no children, every slot zero/null.
    ///
    /// `NodeHeader::new(16)` tags the header with node type 16.
    pub fn new() -> Self {
        Self {
            header: NodeHeader::new(16),
            prefix: CompressedPrefix::empty(),
            keys: [0; NODE16_MAX_CHILDREN],
            // `from_fn` sizes the array from `NODE16_MAX_CHILDREN` (instead of a hand-written
            // list that must be kept in sync) and does not require `SwizzledPtr: Copy`.
            children: std::array::from_fn(|_| SwizzledPtr::null()),
        }
    }

    /// Creates an empty `Node16` whose compressed prefix is built from `prefix`.
    ///
    /// `header.prefix_len` is set to `prefix.len()` narrowed to `u8`.
    pub fn with_prefix(prefix: &[u8]) -> Self {
        let mut node = Self::new();
        node.prefix = CompressedPrefix::from_bytes(prefix);
        // NOTE(review): `as u8` wraps silently for prefixes longer than 255 bytes; confirm
        // callers never pass such prefixes.
        node.header.prefix_len = prefix.len() as u8;
        node
    }

    /// Returns the slot holding `key` among the live slots, using one 16-byte SIMD compare.
    ///
    /// NOTE(review): every intrinsic used here is SSE2 (baseline on x86_64), yet the path is
    /// gated on `sse4.1`, which is not enabled by default, so default x86_64 builds fall back
    /// to the linear scan. Relaxing the gate must be done together with the matching `cfg`s
    /// at every call site (`find_child`, `find_child_mut`, `get_child_atomic`).
    #[cfg(all(target_arch = "x86_64", target_feature = "sse4.1"))]
    fn find_key_index_simd(&self, key: u8) -> Option<usize> {
        use std::arch::x86_64::*;
        // Clamp so the shift below cannot overflow even if the header count is corrupt;
        // the movemask only has 16 meaningful bits anyway.
        let count = (self.header.num_children as usize).min(NODE16_MAX_CHILDREN);
        // SAFETY: `keys` is exactly 16 initialized bytes, `_mm_loadu_si128` has no alignment
        // requirement, and the cfg gate guarantees the CPU supports these SSE2 intrinsics.
        let mask = unsafe {
            let keys = _mm_loadu_si128(self.keys.as_ptr() as *const __m128i);
            let search = _mm_set1_epi8(key as i8);
            _mm_movemask_epi8(_mm_cmpeq_epi8(keys, search)) as u32
        };
        // Unused slots are zero-padded and may match (e.g. key 0); keep only live-slot bits.
        let live = mask & ((1u32 << count) - 1);
        (live != 0).then_some(live.trailing_zeros() as usize)
    }

    /// Returns the slot holding `key` among the live slots, scanning linearly.
    fn find_key_index_linear(&self, key: u8) -> Option<usize> {
        let count = self.header.num_children as usize;
        self.keys[..count].iter().position(|&k| k == key)
    }

    /// Returns the first live slot whose key is `>= key`, or `num_children` if every live key
    /// is smaller, i.e. the position at which inserting `key` keeps the live keys sorted.
    fn find_insert_point(&self, key: u8) -> usize {
        let count = self.header.num_children as usize;
        self.keys[..count]
            .iter()
            .position(|&k| k >= key)
            .unwrap_or(count)
    }
}
impl Default for Node16 {
fn default() -> Self {
Self::new()
}
}
impl ArtNode for Node16 {
    /// Returns the child stored under `key`, if any.
    ///
    /// Uses the SIMD search when compiled for x86_64 with `sse4.1`, otherwise a linear scan.
    fn find_child(&self, key: u8) -> Option<&SwizzledPtr> {
        #[cfg(all(target_arch = "x86_64", target_feature = "sse4.1"))]
        let index = self.find_key_index_simd(key);
        #[cfg(not(all(target_arch = "x86_64", target_feature = "sse4.1")))]
        let index = self.find_key_index_linear(key);
        index.map(|i| &self.children[i])
    }

    /// Mutable counterpart of `find_child`.
    fn find_child_mut(&mut self, key: u8) -> Option<&mut SwizzledPtr> {
        #[cfg(all(target_arch = "x86_64", target_feature = "sse4.1"))]
        let index = self.find_key_index_simd(key);
        #[cfg(not(all(target_arch = "x86_64", target_feature = "sse4.1")))]
        let index = self.find_key_index_linear(key);
        index.map(move |i| &mut self.children[i])
    }

    /// Inserts `child` under `key`, keeping the live keys in ascending order.
    ///
    /// # Errors
    ///
    /// - `AddChildError::NodeFull` if all slots are in use (checked first).
    /// - `AddChildError::KeyExists` if `key` is already present.
    fn add_child(&mut self, key: u8, child: SwizzledPtr) -> Result<(), AddChildError> {
        let count = self.header.num_children as usize;
        if count >= NODE16_MAX_CHILDREN {
            return Err(AddChildError::NodeFull);
        }
        if self.find_key_index_linear(key).is_some() {
            return Err(AddChildError::KeyExists);
        }
        let insert_pos = self.find_insert_point(key);
        // Slot `count` is the first unused one (count < capacity here). Rotating
        // `insert_pos..=count` right moves that spare slot to `insert_pos` and shifts the
        // tail up by one using moves only, with no duplicate `SwizzledPtr` clones.
        self.keys[insert_pos..=count].rotate_right(1);
        self.children[insert_pos..=count].rotate_right(1);
        self.keys[insert_pos] = key;
        self.children[insert_pos] = child;
        self.header.num_children += 1;
        Ok(())
    }

    /// Removes and returns the child stored under `key`, or `None` if it is absent.
    ///
    /// Later entries shift down one slot; the vacated last slot is reset to key `0` and
    /// `SwizzledPtr::null()`.
    fn remove_child(&mut self, key: u8) -> Option<SwizzledPtr> {
        let index = self.find_key_index_linear(key)?;
        // A match implies at least one live child, so this cannot underflow.
        let last = self.header.num_children as usize - 1;
        // Rotate the removed entry to the last live slot, sliding the tail down by one.
        self.keys[index..=last].rotate_left(1);
        self.children[index..=last].rotate_left(1);
        self.keys[last] = 0;
        self.header.num_children -= 1;
        Some(std::mem::replace(
            &mut self.children[last],
            SwizzledPtr::null(),
        ))
    }

    /// Returns `true` when no further child can be added.
    fn is_full(&self) -> bool {
        self.header.num_children as usize >= NODE16_MAX_CHILDREN
    }

    /// Iterates `(key, child)` pairs over the live slots in ascending key order.
    fn iter_children(&self) -> impl Iterator<Item = (u8, &SwizzledPtr)> {
        let count = self.header.num_children as usize;
        self.keys[..count]
            .iter()
            .copied()
            .zip(&self.children[..count])
    }
}
impl Node16 {
    /// Builds a `Node4` with this node's prefix and children.
    ///
    /// The header is cloned (so the child count and prefix length carry over) and re-tagged
    /// as node type 4.
    ///
    /// # Panics
    ///
    /// Panics if this node has more than 4 children.
    pub fn shrink(&self) -> super::Node4 {
        let count = self.header.num_children as usize;
        // Enforced in release builds too. Node4's arrays are presumably 4 wide (its
        // definition is not visible here), so extra children could not be copied anyway;
        // without this check that surfaces as an opaque index-out-of-bounds panic mid-copy.
        assert!(
            count <= 4,
            "cannot shrink Node16 with {} children",
            self.header.num_children
        );
        let mut node4 = super::Node4::new();
        node4.header = self.header.clone();
        node4.header.node_type = 4;
        node4.prefix = self.prefix;
        let live = self.keys[..count].iter().zip(&self.children[..count]);
        for (slot, (&key, child)) in live.enumerate() {
            node4.keys[slot] = key;
            node4.children[slot] = child.clone();
        }
        node4
    }

    /// Builds a `Node48` with this node's prefix and children.
    ///
    /// The header is cloned and re-tagged as node type 48. The child in slot `i` is stored in
    /// `children[i]`, and `index[key]` is set to `i`.
    ///
    /// NOTE(review): the raw slot number is written, so slot 0 is recorded as
    /// `index[key] == 0`. Confirm this matches Node48's "empty" sentinel (not visible here);
    /// if Node48 treats 0 as absent it would expect `i + 1`.
    pub fn grow(&self) -> super::Node48 {
        let mut node48 = super::Node48::new();
        node48.header = self.header.clone();
        node48.header.node_type = 48;
        node48.prefix = self.prefix;
        let count = self.header.num_children as usize;
        let live = self.keys[..count].iter().zip(&self.children[..count]);
        for (slot, (&key, child)) in live.enumerate() {
            // `slot < 16`, so the narrowing cast is lossless.
            node48.index[key as usize] = slot as u8;
            node48.children[slot] = child.clone();
        }
        node48
    }

    /// Returns an owned clone of the child stored under `key`, if any.
    ///
    /// NOTE(review): nothing here is atomic beyond whatever `SwizzledPtr::clone` does;
    /// confirm that is the guarantee the name promises.
    pub fn get_child_atomic(&self, key: u8) -> Option<SwizzledPtr> {
        #[cfg(all(target_arch = "x86_64", target_feature = "sse4.1"))]
        let index = self.find_key_index_simd(key);
        #[cfg(not(all(target_arch = "x86_64", target_feature = "sse4.1")))]
        let index = self.find_key_index_linear(key);
        index.map(|i| self.children[i].clone())
    }

    /// Returns the child pointer in raw slot `index` (live or not).
    ///
    /// # Panics
    ///
    /// Panics if `index >= NODE16_MAX_CHILDREN`.
    #[inline]
    pub fn child_slot(&self, index: usize) -> &SwizzledPtr {
        debug_assert!(index < NODE16_MAX_CHILDREN, "index {} out of bounds", index);
        &self.children[index]
    }

    /// Locates `key`: `Ok(slot)` if it is present, otherwise `Err(pos)` with the slot at
    /// which inserting it would keep the live keys sorted.
    pub fn find_slot_for_key(&self, key: u8) -> Result<usize, usize> {
        self.find_key_index_linear(key)
            .ok_or_else(|| self.find_insert_point(key))
    }

    /// Returns the first unused slot (`num_children`), or `None` if the node is full.
    pub fn next_slot(&self) -> Option<usize> {
        let count = self.header.num_children as usize;
        (count < NODE16_MAX_CHILDREN).then_some(count)
    }

    /// Returns the key byte in raw slot `index` (live or not).
    ///
    /// # Panics
    ///
    /// Panics if `index >= NODE16_MAX_CHILDREN`.
    #[inline]
    pub fn key_at(&self, index: usize) -> u8 {
        debug_assert!(index < NODE16_MAX_CHILDREN, "index {} out of bounds", index);
        self.keys[index]
    }

    /// Iterates `(slot, key, child)` triples over the live slots.
    pub fn iter_indexed(&self) -> impl Iterator<Item = (usize, u8, &SwizzledPtr)> {
        let count = self.header.num_children as usize;
        self.keys[..count]
            .iter()
            .zip(&self.children[..count])
            .enumerate()
            .map(|(slot, (&key, child))| (slot, key, child))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an on-disk child pointer tagged with `id`, used as an opaque test payload.
    fn leaf(id: u32) -> SwizzledPtr {
        SwizzledPtr::on_disk(id, 0, crate::persistent_artrie::NodeType::Node4)
    }

    #[test]
    fn test_new_node16() {
        let node = Node16::new();
        assert_eq!(node.header.node_type, 16);
        assert_eq!(node.header.num_children, 0);
        assert!(!node.is_full());
    }

    #[test]
    fn test_with_prefix_sets_prefix_len() {
        let node = Node16::with_prefix(b"abc");
        assert_eq!(node.header.prefix_len, 3);
        assert_eq!(node.header.num_children, 0);
    }

    #[test]
    fn test_add_and_find_children() {
        let mut node = Node16::new();
        for &key in &[b'h', b'a', b'd', b'f', b'c', b'e', b'g', b'b'] {
            assert!(node.add_child(key, leaf(u32::from(key))).is_ok());
        }
        assert_eq!(node.header.num_children, 8);
        assert_eq!(&node.keys[..8], b"abcdefgh");
        for key in b'a'..=b'h' {
            assert!(
                node.find_child(key).is_some(),
                "should find key '{}'",
                key as char
            );
        }
        assert!(node.find_child(b'z').is_none());
    }

    #[test]
    fn test_add_duplicate_key_rejected() {
        let mut node = Node16::new();
        node.add_child(b'k', leaf(1)).expect("add should succeed");
        assert_eq!(node.add_child(b'k', leaf(2)), Err(AddChildError::KeyExists));
        assert_eq!(node.header.num_children, 1);
    }

    #[test]
    fn test_key_zero_not_confused_with_padding() {
        // Unused key slots are zero-filled, so a lookup for key 0 must only see live slots.
        let mut node = Node16::new();
        for key in 1..=3u8 {
            node.add_child(key, leaf(u32::from(key)))
                .expect("add should succeed");
        }
        assert!(node.find_child(0).is_none());
        node.add_child(0, leaf(0)).expect("add should succeed");
        assert!(node.find_child(0).is_some());
        assert_eq!(node.key_at(0), 0);
        assert_eq!(&node.keys[..4], &[0, 1, 2, 3]);
    }

    #[test]
    fn test_node16_full() {
        let mut node = Node16::new();
        for i in 0..16u32 {
            assert!(node.add_child(i as u8, leaf(i)).is_ok());
        }
        assert!(node.is_full());
        assert_eq!(node.add_child(16, leaf(16)), Err(AddChildError::NodeFull));
    }

    #[test]
    fn test_remove_child() {
        let mut node = Node16::new();
        for i in 0..10u32 {
            node.add_child(i as u8, leaf(i)).expect("add should succeed");
        }
        let removed = node.remove_child(5);
        assert!(removed.is_some());
        assert_eq!(node.header.num_children, 9);
        assert!(node.find_child(5).is_none());
        for i in 0..10 {
            if i != 5 {
                assert!(node.find_child(i).is_some());
            }
        }
    }

    #[test]
    fn test_remove_missing_and_from_empty() {
        let mut node = Node16::new();
        assert!(node.remove_child(7).is_none());
        node.add_child(7, leaf(7)).expect("add should succeed");
        assert!(node.remove_child(8).is_none());
        assert_eq!(node.header.num_children, 1);
        assert!(node.remove_child(7).is_some());
        assert_eq!(node.header.num_children, 0);
        assert!(node.find_child(7).is_none());
        assert_eq!(node.next_slot(), Some(0));
    }

    #[test]
    fn test_iter_children() {
        let mut node = Node16::new();
        for i in 0..8u8 {
            node.add_child(b'a' + i, leaf(u32::from(i)))
                .expect("add should succeed");
        }
        let keys: Vec<_> = node.iter_children().map(|(k, _)| k).collect();
        assert_eq!(keys, (b'a'..=b'h').collect::<Vec<_>>());
    }

    #[test]
    fn test_find_slot_for_key() {
        let mut node = Node16::new();
        node.add_child(b'b', leaf(1)).expect("add should succeed");
        node.add_child(b'd', leaf(2)).expect("add should succeed");
        assert_eq!(node.find_slot_for_key(b'a'), Err(0));
        assert_eq!(node.find_slot_for_key(b'b'), Ok(0));
        assert_eq!(node.find_slot_for_key(b'c'), Err(1));
        assert_eq!(node.find_slot_for_key(b'd'), Ok(1));
        assert_eq!(node.find_slot_for_key(b'z'), Err(2));
    }

    #[test]
    fn test_next_slot_and_iter_indexed() {
        let mut node = Node16::new();
        for i in 0..NODE16_MAX_CHILDREN {
            assert_eq!(node.next_slot(), Some(i));
            node.add_child(i as u8 * 2, leaf(i as u32))
                .expect("add should succeed");
        }
        assert_eq!(node.next_slot(), None);
        let seen: Vec<(usize, u8)> = node.iter_indexed().map(|(i, k, _)| (i, k)).collect();
        let expected: Vec<(usize, u8)> = (0..NODE16_MAX_CHILDREN)
            .map(|i| (i, i as u8 * 2))
            .collect();
        assert_eq!(seen, expected);
    }

    #[test]
    fn test_shrink_to_node4() {
        let mut node = Node16::new();
        for i in 0..4u32 {
            node.add_child(i as u8, leaf(i)).expect("add should succeed");
        }
        let node4 = node.shrink();
        assert_eq!(node4.header.node_type, 4);
        assert_eq!(node4.header.num_children, 4);
        for i in 0..4 {
            assert!(node4.find_child(i as u8).is_some());
        }
    }
}