use super::{AddChildError, ArtNode, CompressedPrefix, NodeHeader};
use crate::persistent_artrie::swizzled_ptr::SwizzledPtr;
/// Maximum number of children a [`Node4`] can hold; also the length of its
/// `keys` and `children` arrays.
pub const NODE4_MAX_CHILDREN: usize = 4;
/// Smallest trie node variant: stores up to [`NODE4_MAX_CHILDREN`] children,
/// each addressed by a single key byte.
///
/// `#[repr(C)]` fixes the field layout to declaration order.
/// NOTE(review): presumably required because the node is persisted or shared
/// in a fixed binary layout — confirm before reordering fields.
#[repr(C)]
#[derive(Debug, Clone)]
pub struct Node4 {
/// Node metadata. The code in this file reads and writes its `node_type`,
/// `num_children` and `prefix_len` fields.
pub header: NodeHeader,
/// Compressed key prefix for this node (empty after `new`, populated from
/// bytes by `with_prefix`).
pub prefix: CompressedPrefix,
/// Key bytes. Only the first `header.num_children` entries are live;
/// `add_child` keeps them in ascending order and the sorted search in
/// `find_key_index` relies on that. `new` and `remove_child` leave unused
/// slots set to 0.
pub keys: [u8; NODE4_MAX_CHILDREN],
/// Child pointers; `children[i]` belongs to `keys[i]`. `new` and
/// `remove_child` leave unused slots set to `SwizzledPtr::null()`.
pub children: [SwizzledPtr; NODE4_MAX_CHILDREN],
}
impl Node4 {
    /// Creates an empty node: header built with `NodeHeader::new(4)`, an
    /// empty prefix, all keys zero and every child slot null.
    pub fn new() -> Self {
        Self {
            header: NodeHeader::new(4),
            prefix: CompressedPrefix::empty(),
            keys: [0; NODE4_MAX_CHILDREN],
            // One null pointer per slot; the array length comes from the field type.
            children: std::array::from_fn(|_| SwizzledPtr::null()),
        }
    }

    /// Creates an empty node whose compressed prefix is built from `prefix`
    /// and whose header records the prefix length.
    pub fn with_prefix(prefix: &[u8]) -> Self {
        // NOTE(review): `as u8` wraps for prefixes longer than 255 bytes —
        // confirm callers never pass more, or that wrapping is intended.
        let recorded_len = prefix.len() as u8;
        let mut node = Self::new();
        node.prefix = CompressedPrefix::from_bytes(prefix);
        node.header.prefix_len = recorded_len;
        node
    }

    /// Searches the live keys, which are expected to be in ascending order.
    ///
    /// Returns `Ok(slot)` when `key` is stored at `slot`, otherwise
    /// `Err(slot)` with the position where `key` would have to be inserted
    /// to keep the keys sorted (the live count when every key is smaller).
    fn find_key_index(&self, key: u8) -> Result<usize, usize> {
        let live = self.header.num_children as usize;
        // The first slot whose key is not smaller than `key` decides the outcome.
        match (0..live).find(|&slot| self.keys[slot] >= key) {
            Some(slot) if self.keys[slot] == key => Ok(slot),
            Some(slot) => Err(slot),
            None => Err(live),
        }
    }
}
impl Default for Node4 {
fn default() -> Self {
Self::new()
}
}
impl ArtNode for Node4 {
    /// Returns a reference to the child stored under `key`, or `None` when
    /// no live slot holds that key.
    fn find_child(&self, key: u8) -> Option<&SwizzledPtr> {
        let live = self.header.num_children as usize;
        (0..live)
            .find(|&slot| self.keys[slot] == key)
            .map(|slot| &self.children[slot])
    }

    /// Mutable counterpart of `find_child`.
    fn find_child_mut(&mut self, key: u8) -> Option<&mut SwizzledPtr> {
        let live = self.header.num_children as usize;
        let slot = (0..live).find(|&slot| self.keys[slot] == key)?;
        Some(&mut self.children[slot])
    }

    /// Inserts `child` under `key`, keeping the keys in ascending order.
    ///
    /// # Errors
    ///
    /// * `AddChildError::NodeFull` when all `NODE4_MAX_CHILDREN` slots are
    ///   in use (checked before the key lookup).
    /// * `AddChildError::KeyExists` when `key` is already present.
    fn add_child(&mut self, key: u8, child: SwizzledPtr) -> Result<(), AddChildError> {
        let live = self.header.num_children as usize;
        if live >= NODE4_MAX_CHILDREN {
            return Err(AddChildError::NodeFull);
        }

        let Err(insert_pos) = self.find_key_index(key) else {
            return Err(AddChildError::KeyExists);
        };

        // Open a gap at `insert_pos` by moving the tail one slot to the
        // right, starting from the last live entry.
        for dst in (insert_pos + 1..=live).rev() {
            self.keys[dst] = self.keys[dst - 1];
            self.children[dst] = self.children[dst - 1].clone();
        }

        self.keys[insert_pos] = key;
        self.children[insert_pos] = child;
        self.header.num_children += 1;
        Ok(())
    }

    /// Removes the child stored under `key` and returns a clone of its
    /// pointer, or `None` when the key is absent.
    ///
    /// Later entries move one slot to the left; the vacated last slot is
    /// reset to key 0 and a null pointer.
    fn remove_child(&mut self, key: u8) -> Option<SwizzledPtr> {
        let live = self.header.num_children as usize;
        let index = self.find_key_index(key).ok()?;
        let removed = self.children[index].clone();

        // A successful lookup means at least one live child, so `live - 1`
        // cannot underflow.
        let last = live - 1;
        for dst in index..last {
            self.keys[dst] = self.keys[dst + 1];
            self.children[dst] = self.children[dst + 1].clone();
        }
        self.keys[last] = 0;
        self.children[last] = SwizzledPtr::null();
        self.header.num_children -= 1;

        Some(removed)
    }

    /// Returns `true` once the node holds `NODE4_MAX_CHILDREN` children.
    fn is_full(&self) -> bool {
        let live = self.header.num_children as usize;
        live >= NODE4_MAX_CHILDREN
    }

    /// Iterates over the live `(key, child)` pairs in slot order.
    fn iter_children(&self) -> impl Iterator<Item = (u8, &SwizzledPtr)> {
        let live = self.header.num_children as usize;
        let live_keys = &self.keys[..live];
        let live_children = &self.children[..live];
        live_keys.iter().copied().zip(live_children.iter())
    }
}
impl Node4 {
    /// Builds a `Node16` carrying this node's header (with `node_type` set
    /// to 16), prefix, and live key/child pairs in the same slots.
    ///
    /// `self` is left untouched; child pointers are cloned.
    pub fn grow(&self) -> super::Node16 {
        let live = self.header.num_children as usize;

        let mut grown = super::Node16::new();
        let mut header = self.header.clone();
        header.node_type = 16;
        grown.header = header;
        grown.prefix = self.prefix;

        for slot in 0..live {
            grown.keys[slot] = self.keys[slot];
            grown.children[slot] = self.children[slot].clone();
        }
        grown
    }

    /// Returns an owned clone of the child pointer stored under `key`, or
    /// `None` when the key is absent.
    ///
    /// NOTE(review): the "atomic" in the name presumably refers to how
    /// `SwizzledPtr::clone` reads the pointer — confirm against `SwizzledPtr`.
    pub fn get_child_atomic(&self, key: u8) -> Option<SwizzledPtr> {
        let live = self.header.num_children as usize;
        (0..live)
            .find(|&slot| self.keys[slot] == key)
            .map(|slot| self.children[slot].clone())
    }

    /// Returns the child pointer at `index`, whether or not that slot is
    /// currently live.
    ///
    /// # Panics
    ///
    /// Panics when `index >= NODE4_MAX_CHILDREN`.
    #[inline]
    pub fn child_slot(&self, index: usize) -> &SwizzledPtr {
        debug_assert!(index < NODE4_MAX_CHILDREN, "index {} out of bounds", index);
        &self.children[index]
    }

    /// Public wrapper around the sorted key search: `Ok(slot)` when `key`
    /// is present, `Err(slot)` with its sorted insertion position otherwise.
    pub fn find_slot_for_key(&self, key: u8) -> Result<usize, usize> {
        self.find_key_index(key)
    }

    /// Returns the index of the first unused slot (equal to the current
    /// child count), or `None` when the node is full.
    pub fn next_slot(&self) -> Option<usize> {
        let live = self.header.num_children as usize;
        (live < NODE4_MAX_CHILDREN).then_some(live)
    }

    /// Returns the key byte at `index`, whether or not that slot is
    /// currently live.
    ///
    /// # Panics
    ///
    /// Panics when `index >= NODE4_MAX_CHILDREN`.
    #[inline]
    pub fn key_at(&self, index: usize) -> u8 {
        debug_assert!(index < NODE4_MAX_CHILDREN, "index {} out of bounds", index);
        self.keys[index]
    }

    /// Iterates over the live slots as `(slot index, key, child)` triples in
    /// slot order.
    pub fn iter_indexed(&self) -> impl Iterator<Item = (usize, u8, &SwizzledPtr)> {
        let live = self.header.num_children as usize;
        let slots = 0..live;
        slots.map(move |slot| {
            let key = self.keys[slot];
            let child = &self.children[slot];
            (slot, key, child)
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an on-disk `Node4` pointer with the given first argument and a
    /// zero second argument, as every test here needs.
    fn disk_child(id: u32) -> SwizzledPtr {
        SwizzledPtr::on_disk(id, 0, crate::persistent_artrie::NodeType::Node4)
    }

    /// A fresh node is typed as 4, has no children and is not full.
    #[test]
    fn test_new_node4() {
        let node = Node4::new();
        assert_eq!(node.header.node_type, 4);
        assert_eq!(node.header.num_children, 0);
        assert!(!node.is_full());
    }

    /// Children added out of order end up sorted by key and are findable.
    #[test]
    fn test_add_and_find_children() {
        let mut node = Node4::new();
        for (key, id) in [(b'c', 1), (b'a', 2), (b'b', 3)] {
            assert!(node.add_child(key, disk_child(id)).is_ok());
        }

        assert_eq!(node.header.num_children, 3);
        assert_eq!(node.keys[..3], [b'a', b'b', b'c']);

        for key in [b'a', b'b', b'c'] {
            assert!(node.find_child(key).is_some());
        }
        assert!(node.find_child(b'd').is_none());
    }

    /// A fifth insertion into a full node is rejected with `NodeFull`.
    #[test]
    fn test_node4_full() {
        let mut node = Node4::new();
        for key in 0u8..4 {
            assert!(node.add_child(key, disk_child(u32::from(key))).is_ok());
        }
        assert!(node.is_full());

        assert_eq!(
            node.add_child(4, disk_child(4)),
            Err(AddChildError::NodeFull)
        );
    }

    /// Removing a middle key drops only that key and decrements the count.
    #[test]
    fn test_remove_child() {
        let mut node = Node4::new();
        for key in 0u8..4 {
            node.add_child(key, disk_child(u32::from(key)))
                .expect("add should succeed");
        }
        assert_eq!(node.header.num_children, 4);

        let removed = node.remove_child(2);
        assert!(removed.is_some());
        assert_eq!(node.header.num_children, 3);

        assert!(node.find_child(2).is_none());
        for key in [0, 1, 3] {
            assert!(node.find_child(key).is_some());
        }
    }

    /// Adding the same key twice fails with `KeyExists`.
    #[test]
    fn test_duplicate_key() {
        let mut node = Node4::new();
        assert!(node.add_child(b'a', disk_child(1)).is_ok());
        assert_eq!(
            node.add_child(b'a', disk_child(2)),
            Err(AddChildError::KeyExists)
        );
    }

    /// `iter_children` yields the keys in ascending order.
    #[test]
    fn test_iter_children() {
        let mut node = Node4::new();
        for offset in 0u8..3 {
            node.add_child(b'a' + offset, disk_child(u32::from(offset)))
                .expect("add should succeed");
        }

        let keys: Vec<u8> = node.iter_children().map(|(key, _)| key).collect();
        assert_eq!(keys, vec![b'a', b'b', b'c']);
    }

    /// `with_prefix` records both the prefix bytes and their length.
    #[test]
    fn test_with_prefix() {
        let node = Node4::with_prefix(b"hello");
        assert_eq!(node.header.prefix_len, 5);
        assert_eq!(node.prefix.as_slice(5), b"hello");
    }
}