use super::{AddChildError, CharArtNode, CharCompressedPrefix, CharNodeHeader};
use crate::persistent_artrie::swizzled_ptr::SwizzledPtr;
/// Maximum number of children a `CharNode48` can hold.
///
/// Also the length of the node's fixed-size `keys` and `children` arrays.
pub const CHARNODE48_MAX_CHILDREN: usize = 48;
/// Trie node that stores up to `CHARNODE48_MAX_CHILDREN` children keyed by
/// `u32` values, using two parallel fixed-size arrays.
///
/// The struct is `#[repr(C)]`, so the declared field order is part of its
/// memory layout; do not reorder fields casually.
#[repr(C)]
#[derive(Debug, Clone)]
pub struct CharNode48 {
/// Node header. This file reads and writes its `node_type`, `num_children`
/// and `prefix_len` fields.
pub header: CharNodeHeader,
/// Compressed prefix; empty for `new()` and built from a `u32` slice by
/// `with_prefix()`.
pub prefix: CharCompressedPrefix,
/// Child keys. The first `header.num_children` entries are live and kept in
/// ascending order (lookups binary-search them); unused slots hold `0`.
pub keys: [u32; CHARNODE48_MAX_CHILDREN],
/// Child pointers, parallel to `keys` (`children[i]` belongs to `keys[i]`).
/// Unused slots hold a null pointer.
pub children: [SwizzledPtr; CHARNODE48_MAX_CHILDREN],
/// Null in a freshly created node and copied unchanged by `shrink()` and
/// `grow()`. Presumably points at the value stored at this node -- TODO
/// confirm against the callers that set it.
pub value_ptr: SwizzledPtr,
}
impl CharNode48 {
    /// Creates an empty node: no children, an empty prefix, zeroed keys and
    /// null child/value pointers.
    ///
    /// The header is created with the node-type tag `148`.
    pub fn new() -> Self {
        let header = CharNodeHeader::new(148);
        let prefix = CharCompressedPrefix::empty();
        Self {
            header,
            prefix,
            keys: [0u32; CHARNODE48_MAX_CHILDREN],
            children: std::array::from_fn(|_| SwizzledPtr::null()),
            value_ptr: SwizzledPtr::null(),
        }
    }

    /// Creates an empty node whose compressed prefix is built from `prefix`.
    ///
    /// `header.prefix_len` is set to `prefix.len()` cast to `u8`.
    pub fn with_prefix(prefix: &[u32]) -> Self {
        let mut node = Self::new();
        // NOTE(review): this cast wraps for slices longer than 255 elements;
        // confirm callers never pass a prefix that long.
        node.header.prefix_len = prefix.len() as u8;
        node.prefix = CharCompressedPrefix::from_chars(prefix);
        node
    }

    /// Returns the slot of `key` among the live keys, or `None` if absent.
    fn find_key_index(&self, key: u32) -> Option<usize> {
        let live = &self.keys[..self.header.num_children as usize];
        // Searching an empty slice simply misses, so no special case is needed.
        live.binary_search(&key).ok()
    }

    /// Returns the slot where `key` is (if present) or where it would have to
    /// be inserted to keep the live keys sorted.
    fn find_insert_point(&self, key: u32) -> usize {
        let live = &self.keys[..self.header.num_children as usize];
        live.binary_search(&key).unwrap_or_else(|slot| slot)
    }

    /// Builds a `CharNode16` holding the same header flags, prefix, value
    /// pointer and live children as this node.
    ///
    /// The copy's `node_type` is set to `16`. Debug builds assert that this
    /// node has at most 16 children.
    pub fn shrink(&self) -> super::CharNode16 {
        let live = self.header.num_children;
        debug_assert!(
            live <= 16,
            "cannot shrink CharNode48 with {} children",
            live
        );

        let mut smaller = super::CharNode16::new();
        smaller.header = self.header.clone();
        smaller.header.node_type = 16;
        smaller.prefix = self.prefix;
        smaller.value_ptr = self.value_ptr.clone();

        let live = live as usize;
        let pairs = self.keys[..live].iter().zip(&self.children[..live]);
        for (slot, (&key, child)) in pairs.enumerate() {
            smaller.keys[slot] = key;
            smaller.children[slot] = child.clone();
        }
        smaller
    }

    /// Builds a `CharBucket` holding the same header flags, prefix, value
    /// pointer and live children as this node.
    ///
    /// The copy's `node_type` is set to `49`; every live `(key, child)` pair is
    /// inserted into the bucket's `entries`.
    pub fn grow(&self) -> super::CharBucket {
        let mut larger = super::CharBucket::new();
        larger.header = self.header.clone();
        larger.header.node_type = 49;
        larger.prefix = self.prefix;
        larger.value_ptr = self.value_ptr.clone();

        let live = self.header.num_children as usize;
        for (&key, child) in self.keys[..live].iter().zip(&self.children[..live]) {
            larger.entries.insert(key, child.clone());
        }
        larger
    }
}
impl Default for CharNode48 {
fn default() -> Self {
Self::new()
}
}
impl CharArtNode for CharNode48 {
    /// Returns the child stored under `key`, or `None` if the key is absent.
    fn find_child(&self, key: u32) -> Option<&SwizzledPtr> {
        self.find_key_index(key).map(|slot| &self.children[slot])
    }

    /// Returns a mutable reference to the child stored under `key`, or `None`
    /// if the key is absent.
    fn find_child_mut(&mut self, key: u32) -> Option<&mut SwizzledPtr> {
        let slot = self.find_key_index(key)?;
        Some(&mut self.children[slot])
    }

    /// Inserts `child` under `key`, keeping the live keys sorted.
    ///
    /// # Errors
    ///
    /// * `AddChildError::NodeFull` if the node already holds
    ///   `CHARNODE48_MAX_CHILDREN` children. Capacity is checked first, so a
    ///   full node reports `NodeFull` even when `key` is already present.
    /// * `AddChildError::KeyExists` if `key` is already present.
    fn add_child(&mut self, key: u32, child: SwizzledPtr) -> Result<(), AddChildError> {
        let count = self.header.num_children as usize;
        if count >= CHARNODE48_MAX_CHILDREN {
            return Err(AddChildError::NodeFull);
        }

        // One binary search answers both "is it present?" and "where does it
        // go?": the key exists exactly when the returned slot is live and
        // already holds it.
        let insert_pos = self.find_insert_point(key);
        if self.keys[..count].get(insert_pos) == Some(&key) {
            return Err(AddChildError::KeyExists);
        }

        // Open a gap at `insert_pos` by shifting the tail one slot right.
        self.keys.copy_within(insert_pos..count, insert_pos + 1);
        self.keys[insert_pos] = key;
        // Slot `count` is unused, so rotating it to the front shifts the live
        // pointers by moving them instead of cloning each one.
        self.children[insert_pos..=count].rotate_right(1);
        self.children[insert_pos] = child;

        self.header.num_children += 1;
        Ok(())
    }

    /// Removes the child stored under `key` and returns it, or returns `None`
    /// if the key is absent.
    ///
    /// The remaining entries are shifted left so the live keys stay sorted and
    /// contiguous; the vacated last slot is reset to key `0` and a null pointer.
    fn remove_child(&mut self, key: u32) -> Option<SwizzledPtr> {
        let count = self.header.num_children as usize;
        let index = self.find_key_index(key)?;

        self.keys.copy_within(index + 1..count, index);
        self.keys[count - 1] = 0;

        // Take the pointer out by value, leaving null behind, then rotate that
        // null to the end of the live range; nothing is cloned.
        let removed = std::mem::replace(&mut self.children[index], SwizzledPtr::null());
        self.children[index..count].rotate_left(1);

        self.header.num_children -= 1;
        Some(removed)
    }

    /// Returns `true` when the node holds `CHARNODE48_MAX_CHILDREN` children.
    fn is_full(&self) -> bool {
        self.header.num_children as usize >= CHARNODE48_MAX_CHILDREN
    }

    /// Iterates over the live `(key, child)` pairs in ascending key order.
    fn iter_children(&self) -> impl Iterator<Item = (u32, &SwizzledPtr)> {
        let count = self.header.num_children as usize;
        self.keys[..count]
            .iter()
            .copied()
            .zip(self.children[..count].iter())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::persistent_artrie::NodeType;

    /// Builds the on-disk pointer every test uses as the child for `key`.
    fn disk_child(key: u32) -> SwizzledPtr {
        SwizzledPtr::on_disk(key, 0, NodeType::Node4)
    }

    /// A fresh node carries type tag 148, has no children and is not full.
    #[test]
    fn test_new_charnode48() {
        let fresh = CharNode48::new();
        assert_eq!(fresh.header.node_type, 148);
        assert_eq!(fresh.header.num_children, 0);
        assert!(!fresh.is_full());
    }

    /// Keys inserted out of order are stored sorted and can all be found.
    #[test]
    fn test_add_and_find_children() {
        let insertion_order: [u32; 8] = [50, 10, 30, 70, 20, 40, 60, 80];
        let mut node = CharNode48::new();
        for &key in &insertion_order {
            assert!(node.add_child(key, disk_child(key)).is_ok());
        }
        assert_eq!(node.header.num_children, 8);

        let stored: Vec<u32> = node.iter_children().map(|(key, _)| key).collect();
        assert_eq!(stored, vec![10, 20, 30, 40, 50, 60, 70, 80]);

        for &key in &insertion_order {
            assert!(node.find_child(key).is_some(), "should find key {}", key);
        }
        assert!(node.find_child(100).is_none());
    }

    /// Lookups hit every inserted key and miss the keys right next to them.
    #[test]
    fn test_binary_search_correctness() {
        let thousands = || (0..30u32).map(|step| step * 1000);

        let mut node = CharNode48::new();
        for key in thousands() {
            node.add_child(key, disk_child(key))
                .expect("add should succeed");
        }
        for key in thousands() {
            assert!(node.find_child(key).is_some(), "should find key {}", key);
        }
        for absent in thousands().map(|key| key + 1) {
            assert!(
                node.find_child(absent).is_none(),
                "should not find key {}",
                absent
            );
        }
    }

    /// After 48 insertions the node is full and rejects a 49th child.
    #[test]
    fn test_charnode48_full() {
        let mut node = CharNode48::new();
        for key in 0..48 {
            assert!(node.add_child(key, disk_child(key)).is_ok());
        }
        assert!(node.is_full());
        assert_eq!(
            node.add_child(48, disk_child(48)),
            Err(AddChildError::NodeFull)
        );
    }

    /// Removing one key drops only that key and keeps the rest sorted.
    #[test]
    fn test_remove_child() {
        let mut node = CharNode48::new();
        for key in 0..30 {
            node.add_child(key, disk_child(key))
                .expect("add should succeed");
        }

        assert!(node.remove_child(15).is_some());
        assert_eq!(node.header.num_children, 29);
        assert!(node.find_child(15).is_none());

        for key in (0..30).filter(|&key| key != 15) {
            assert!(node.find_child(key).is_some());
        }

        let remaining: Vec<_> = node.iter_children().map(|(key, _)| key).collect();
        for pair in remaining.windows(2) {
            assert!(pair[0] < pair[1], "keys should be sorted");
        }
    }

    /// Shrinking keeps the children and the final flag, and retags as type 16.
    #[test]
    fn test_shrink_to_node16() {
        let mut node = CharNode48::new();
        for key in 0..16 {
            node.add_child(key, disk_child(key))
                .expect("add should succeed");
        }
        node.header.set_final(true);

        let node16 = node.shrink();
        assert_eq!(node16.header.node_type, 16);
        assert_eq!(node16.header.num_children, 16);
        assert!(node16.header.is_final());
        for key in 0..16 {
            assert!(node16.find_child(key).is_some());
        }
    }

    /// Non-ASCII code points work as keys.
    #[test]
    fn test_unicode_keys() {
        let greek: Vec<u32> = "αβγδεζηθικλμνξοπρστυφχψω"
            .chars()
            .map(|letter| letter as u32)
            .collect();

        let mut node = CharNode48::new();
        for &key in &greek {
            assert!(node.add_child(key, disk_child(key)).is_ok());
        }
        assert_eq!(node.header.num_children, 24);

        for &key in &greek {
            assert!(
                node.find_child(key).is_some(),
                "should find key {}",
                char::from_u32(key).unwrap_or('?')
            );
        }
    }
}