use std::collections::HashMap;
use super::{AddChildError, CharArtNode, CharCompressedPrefix, CharNodeHeader};
use crate::persistent_artrie::swizzled_ptr::SwizzledPtr;
/// Largest child count at which [`CharBucket::shrink`] may be called.
///
/// `shrink` checks `header.num_children` against this value with a
/// `debug_assert!` before converting the bucket into a `CharNode48`.
pub const CHARBUCKET_SHRINK_THRESHOLD: usize = 48;
/// Trie node that keeps its children in a hash map keyed by `u32`.
///
/// `is_full` always reports `false` for this node, so children can be added
/// without the node ever asking to be grown.
#[derive(Debug, Clone)]
pub struct CharBucket {
/// Node header. `new` tags it with node type `101`; `num_children` is
/// adjusted by `add_child` / `remove_child`.
pub header: CharNodeHeader,
/// Compressed prefix for this node (empty for nodes built with `new`).
pub prefix: CharCompressedPrefix,
/// Child pointers indexed by their `u32` key.
pub entries: HashMap<u32, SwizzledPtr>,
/// Initialised to the null pointer by `new`.
/// NOTE(review): presumably references the value stored at this node —
/// confirm against the code that reads it.
pub value_ptr: SwizzledPtr,
}
impl CharBucket {
    /// Creates an empty bucket node.
    ///
    /// The header is tagged with node type `101`, the prefix is empty, the
    /// value pointer is null, and the child map is pre-sized for 64 entries.
    pub fn new() -> Self {
        Self {
            header: CharNodeHeader::new(101),
            prefix: CharCompressedPrefix::empty(),
            entries: HashMap::with_capacity(64),
            value_ptr: SwizzledPtr::null(),
        }
    }

    /// Creates an empty bucket whose compressed prefix is built from `prefix`.
    ///
    /// `header.prefix_len` is a `u8`. For prefixes longer than `u8::MAX`
    /// characters the recorded length saturates at `u8::MAX` instead of
    /// wrapping around (a plain `as u8` cast would record a 256-character
    /// prefix as length 0).
    pub fn with_prefix(prefix: &[u32]) -> Self {
        let mut node = Self::new();
        node.prefix = CharCompressedPrefix::from_chars(prefix);
        // Clamp before narrowing so the cast below can never truncate.
        node.header.prefix_len = prefix.len().min(usize::from(u8::MAX)) as u8;
        node
    }

    /// Converts this bucket into a `CharNode48`.
    ///
    /// The header (re-tagged with node type `48`), prefix and value pointer
    /// are copied over, and the children are written into the new node's
    /// `keys` / `children` slots in ascending key order.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `header.num_children` exceeds
    /// [`CHARBUCKET_SHRINK_THRESHOLD`].
    /// NOTE(review): in release builds that check is compiled out; what
    /// happens with too many entries then depends on `CharNode48`'s storage,
    /// which is not visible here.
    pub fn shrink(&self) -> super::CharNode48 {
        debug_assert!(
            self.header.num_children as usize <= CHARBUCKET_SHRINK_THRESHOLD,
            "cannot shrink CharBucket with {} children",
            self.header.num_children
        );

        let mut node48 = super::CharNode48::new();
        node48.header = self.header.clone();
        node48.header.node_type = 48;
        node48.prefix = self.prefix;
        node48.value_ptr = self.value_ptr.clone();

        // Hash-map iteration order is unspecified, so sort by key first.
        // Keys are unique, which makes the cheaper unstable sort equivalent.
        let mut entries: Vec<_> = self.entries.iter().collect();
        entries.sort_unstable_by_key(|&(k, _)| *k);
        for (i, (key, child)) in entries.into_iter().enumerate() {
            node48.keys[i] = *key;
            node48.children[i] = child.clone();
        }
        node48
    }

    /// Returns the number of children stored in the map.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when the bucket has no children.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
impl Default for CharBucket {
fn default() -> Self {
Self::new()
}
}
impl CharArtNode for CharBucket {
    /// Looks up the child stored under `key`, if any.
    fn find_child(&self, key: u32) -> Option<&SwizzledPtr> {
        self.entries.get(&key)
    }

    /// Looks up the child stored under `key` for in-place modification.
    fn find_child_mut(&mut self, key: u32) -> Option<&mut SwizzledPtr> {
        self.entries.get_mut(&key)
    }

    /// Inserts `child` under `key` and bumps `header.num_children`.
    ///
    /// # Errors
    ///
    /// Returns [`AddChildError::KeyExists`] when `key` is already present;
    /// the existing child is left untouched and `child` is dropped.
    fn add_child(&mut self, key: u32, child: SwizzledPtr) -> Result<(), AddChildError> {
        use std::collections::hash_map::Entry;

        // A single `entry` lookup replaces the `contains_key` + `insert` pair.
        match self.entries.entry(key) {
            Entry::Occupied(_) => Err(AddChildError::KeyExists),
            Entry::Vacant(slot) => {
                slot.insert(child);
                // NOTE(review): unchecked increment — confirm the counter's
                // width covers the largest bucket this trie can produce.
                self.header.num_children += 1;
                Ok(())
            }
        }
    }

    /// Removes and returns the child stored under `key`, decrementing
    /// `header.num_children`. Returns `None` (and changes nothing) when the
    /// key is absent.
    fn remove_child(&mut self, key: u32) -> Option<SwizzledPtr> {
        let removed = self.entries.remove(&key)?;
        self.header.num_children -= 1;
        Some(removed)
    }

    /// Always `false`: this node never reports itself as full.
    fn is_full(&self) -> bool {
        false
    }

    /// Iterates over `(key, child)` pairs in the hash map's iteration order,
    /// which is unspecified.
    fn iter_children(&self) -> impl Iterator<Item = (u32, &SwizzledPtr)> {
        self.entries.iter().map(|(&k, c)| (k, c))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::persistent_artrie::NodeType;

    /// Builds the on-disk child pointer the tests store under key `id`.
    fn disk_child(id: u32) -> SwizzledPtr {
        SwizzledPtr::on_disk(id, 0, NodeType::Node4)
    }

    /// A fresh bucket carries node type 101, no children, and is not full.
    #[test]
    fn test_new_charbucket() {
        let bucket = CharBucket::new();
        assert_eq!(bucket.header.node_type, 101);
        assert_eq!(bucket.header.num_children, 0);
        assert!(!bucket.is_full());
    }

    /// Every inserted key can be found again; an absent key cannot.
    #[test]
    fn test_add_and_find_children() {
        let mut bucket = CharBucket::new();
        for key in 0..100 {
            assert!(bucket.add_child(key, disk_child(key)).is_ok());
        }

        assert_eq!(bucket.header.num_children, 100);
        assert_eq!(bucket.len(), 100);

        for key in 0..100 {
            assert!(bucket.find_child(key).is_some(), "should find key {}", key);
        }
        assert!(bucket.find_child(200).is_none());
    }

    /// The bucket keeps accepting children and never reports being full.
    #[test]
    fn test_charbucket_never_full() {
        let mut bucket = CharBucket::new();
        for key in 0..500 {
            assert!(bucket.add_child(key, disk_child(key)).is_ok());
            assert!(!bucket.is_full());
        }
        assert_eq!(bucket.header.num_children, 500);
    }

    /// Inserting an already-present key is rejected with `KeyExists`.
    #[test]
    fn test_duplicate_key() {
        let mut bucket = CharBucket::new();
        let ptr = disk_child(42);
        assert!(bucket.add_child(42, ptr.clone()).is_ok());
        assert_eq!(bucket.add_child(42, ptr), Err(AddChildError::KeyExists));
    }

    /// Removing the even keys leaves exactly the odd keys behind.
    #[test]
    fn test_remove_child() {
        let mut bucket = CharBucket::new();
        for key in 0..60 {
            bucket
                .add_child(key, disk_child(key))
                .expect("add should succeed");
        }

        for even in (0..60u32).filter(|k| k % 2 == 0) {
            let taken = bucket.remove_child(even);
            assert!(taken.is_some());
        }
        assert_eq!(bucket.header.num_children, 30);

        for odd in (0..60u32).filter(|k| k % 2 == 1) {
            assert!(bucket.find_child(odd).is_some());
        }
        for even in (0..60u32).filter(|k| k % 2 == 0) {
            assert!(bucket.find_child(even).is_none());
        }
    }

    /// Shrinking keeps the header flags and yields the keys in sorted order.
    #[test]
    fn test_shrink_to_node48() {
        let mut bucket = CharBucket::new();
        for key in 0..48 {
            bucket
                .add_child(key, disk_child(key))
                .expect("add should succeed");
        }
        bucket.header.set_final(true);

        let shrunk = bucket.shrink();
        assert_eq!(shrunk.header.node_type, 48);
        assert_eq!(shrunk.header.num_children, 48);
        assert!(shrunk.header.is_final());

        let ordered: Vec<_> = shrunk.iter_children().map(|(k, _)| k).collect();
        for (slot, expected) in (0..48u32).enumerate() {
            assert_eq!(ordered[slot], expected);
        }
    }

    /// Keys spanning several Unicode ranges round-trip through the bucket.
    #[test]
    fn test_unicode_keys() {
        let mut bucket = CharBucket::new();
        let code_points: Vec<u32> = "αβγδεζηθικλμνξοπρστυφχψω日本語中文한글🎉🎊🎋🎌🎍🎎🎏🎐🎑🎒🎓"
            .chars()
            .map(u32::from)
            .collect();

        for &cp in &code_points {
            assert!(bucket.add_child(cp, disk_child(cp)).is_ok());
        }
        for &cp in &code_points {
            assert!(
                bucket.find_child(cp).is_some(),
                "should find key {}",
                char::from_u32(cp).unwrap_or('?')
            );
        }
    }

    /// `iter_children` visits every stored key exactly once.
    #[test]
    fn test_iter_children() {
        use std::collections::HashSet;

        let mut bucket = CharBucket::new();
        for key in 0..20 {
            bucket
                .add_child(key, disk_child(key))
                .expect("add should succeed");
        }

        let seen: HashSet<_> = bucket.iter_children().map(|(k, _)| k).collect();
        assert_eq!(seen.len(), 20);
        for key in 0..20 {
            assert!(seen.contains(&key));
        }
    }

    /// `len` and `is_empty` track the number of stored children.
    #[test]
    fn test_len_and_is_empty() {
        let mut bucket = CharBucket::new();
        assert!(bucket.is_empty());
        assert_eq!(bucket.len(), 0);

        for key in 0..10 {
            bucket
                .add_child(key, disk_child(key))
                .expect("add should succeed");
        }

        assert!(!bucket.is_empty());
        assert_eq!(bucket.len(), 10);
    }
}