use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
use std::sync::Arc;
use crate::persistent_artrie_core::key_encoding::KeyEncoding;
use crate::persistent_artrie_core::swizzled_ptr::SwizzledPtr;
#[allow(unused_imports)]
use crate::value::DictionaryValue;
/// Bit masks for the atomic `flags` byte carried by every `OverlayNode`.
pub mod flags {
    /// Queried by `OverlayNode::is_final`; set by `try_set_final` / `as_final`.
    pub const IS_FINAL: u8 = 0x01;
    /// Dirty marker bit.
    pub const IS_DIRTY: u8 = 0x02;
    /// Leaf marker bit.
    pub const IS_LEAF: u8 = 0x04;
    /// Set by `OverlayNode::with_value`, cleared by `as_non_final`.
    pub const HAS_VALUE: u8 = 0x08;
}
/// Maximum number of children kept in the inline arrays of `ChildStore`
/// before the store spills to heap vectors.
///
/// The inline child arrays are written out as literal lists of
/// `Child::empty()`, so this value must stay in sync with those literals.
const INLINE_CAPACITY: usize = 4;
/// One child slot of an `OverlayNode`: either a node held in memory or a
/// swizzled pointer to a node that is not resident.
///
/// `Child::OnDisk(SwizzledPtr::null())` doubles as the "empty slot" filler used
/// by the inline child arrays (see `Child::empty` and `Child::is_null`).
pub enum Child<K: KeyEncoding, V = ()> {
    /// Resident child, shared by reference count.
    InMem(Arc<OverlayNode<K, V>>),
    /// Child referenced through a swizzled pointer (null for an empty slot).
    OnDisk(SwizzledPtr),
}
/// Cloning a `Child` never clones the payload type `V`.
///
/// An in-memory child only bumps the `Arc` reference count and an on-disk
/// child clones its pointer. No `V: Clone` bound is therefore required, and
/// children of nodes with non-`Clone` payloads remain cloneable.
impl<K: KeyEncoding, V> Clone for Child<K, V> {
    fn clone(&self) -> Self {
        match self {
            // Shares the same node; no deep copy.
            Child::InMem(node) => Child::InMem(Arc::clone(node)),
            Child::OnDisk(ptr) => Child::OnDisk(ptr.clone()),
        }
    }
}
impl<K: KeyEncoding, V> std::fmt::Debug for Child<K, V> {
    /// Prints on-disk pointers in full but summarizes in-memory children as
    /// `Child::InMem(..)` instead of descending into the subtree.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Child::OnDisk(ptr) = self {
            return f.debug_tuple("Child::OnDisk").field(ptr).finish();
        }
        f.write_str("Child::InMem(..)")
    }
}
impl<K: KeyEncoding, V> Child<K, V> {
    /// Filler for unused inline slots: an on-disk child with a null pointer.
    #[inline]
    fn empty() -> Self {
        let null_ptr = SwizzledPtr::null();
        Child::OnDisk(null_ptr)
    }

    /// `true` only for an on-disk child whose pointer is null.
    #[inline]
    pub fn is_null(&self) -> bool {
        match self.as_on_disk() {
            Some(ptr) => ptr.is_null(),
            None => false,
        }
    }

    /// `true` when the child is referenced through a swizzled pointer
    /// (including the null filler).
    #[inline]
    pub fn is_on_disk(&self) -> bool {
        self.as_on_disk().is_some()
    }

    /// Borrows the resident node, or `None` for an on-disk child.
    #[inline]
    pub fn as_in_mem(&self) -> Option<&Arc<OverlayNode<K, V>>> {
        if let Child::InMem(node) = self {
            Some(node)
        } else {
            None
        }
    }

    /// Borrows the swizzled pointer, or `None` for an in-memory child.
    #[inline]
    pub fn as_on_disk(&self) -> Option<&SwizzledPtr> {
        if let Child::OnDisk(ptr) = self {
            Some(ptr)
        } else {
            None
        }
    }
}
/// Sorted child table of an `OverlayNode`, keyed by edge unit.
///
/// Up to `INLINE_CAPACITY` children live in fixed arrays inside the enum; more
/// than that spill to parallel heap vectors. In both variants `keys` is kept in
/// ascending order and `children[i]` is the child for `keys[i]`.
enum ChildStore<K: KeyEncoding, V = ()> {
    /// Inline tier: only the first `count` slots are live; the remaining slots
    /// hold `K::UNIT_ZERO` / `Child::empty()` fillers.
    Inline {
        count: u8,
        keys: [K::Unit; INLINE_CAPACITY],
        children: [Child<K, V>; INLINE_CAPACITY],
    },
    /// Heap tier, entered when an insert overflows the inline arrays and left
    /// again when a removal brings the count back to `INLINE_CAPACITY`.
    Heap {
        keys: Vec<K::Unit>,
        children: Vec<Child<K, V>>,
    },
}
impl<K: KeyEncoding, V: Clone> Clone for ChildStore<K, V> {
    /// Produces an independent table with the same tier, keys and children.
    fn clone(&self) -> Self {
        match self {
            ChildStore::Heap { keys, children } => {
                let keys = keys.clone();
                let children = children.clone();
                ChildStore::Heap { keys, children }
            }
            ChildStore::Inline {
                count,
                keys,
                children,
            } => {
                // Count and keys copy bitwise; children are cloned slot by slot.
                let (count, keys) = (*count, *keys);
                let children = children.clone();
                ChildStore::Inline {
                    count,
                    keys,
                    children,
                }
            }
        }
    }
}
impl<K: KeyEncoding, V> std::fmt::Debug for ChildStore<K, V> {
    /// Shows only the tier and its occupancy; keys and children are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (type_name, field_name, occupancy) = match self {
            ChildStore::Inline { count, .. } => ("ChildStore::Inline", "count", usize::from(*count)),
            ChildStore::Heap { keys, .. } => ("ChildStore::Heap", "len", keys.len()),
        };
        f.debug_struct(type_name)
            .field(field_name, &occupancy)
            .finish_non_exhaustive()
    }
}
impl<K: KeyEncoding, V> ChildStore<K, V> {
#[inline]
fn empty_inline() -> Self {
ChildStore::Inline {
count: 0,
keys: [K::UNIT_ZERO; INLINE_CAPACITY],
children: [
Child::empty(),
Child::empty(),
Child::empty(),
Child::empty(),
],
}
}
#[inline]
fn take(&mut self) -> Self {
std::mem::replace(self, Self::empty_inline())
}
fn drain_in_mem_into(self, out: &mut Vec<Arc<OverlayNode<K, V>>>) {
match self {
ChildStore::Inline {
count, children, ..
} => {
for child in children.into_iter().take(count as usize) {
if let Child::InMem(arc) = child {
out.push(arc);
}
}
}
ChildStore::Heap { children, .. } => {
for child in children {
if let Child::InMem(arc) = child {
out.push(arc);
}
}
}
}
}
}
impl<K: KeyEncoding, V: Clone> ChildStore<K, V> {
    /// Creates an empty store in the inline tier.
    #[inline]
    fn new() -> Self {
        Self::empty_inline()
    }

    /// Number of live children in either tier.
    #[inline]
    fn len(&self) -> usize {
        match self {
            ChildStore::Inline { count, .. } => *count as usize,
            ChildStore::Heap { keys, .. } => keys.len(),
        }
    }

    /// `true` when the store holds no children.
    #[inline]
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Linear-scan counterpart of `slice::binary_search` for the inline tier.
    ///
    /// `keys` must be sorted ascending. Returns `Ok(i)` when `keys[i] == key`;
    /// otherwise returns `Err(pos)`, the index at which `key` would be inserted
    /// to keep the slice sorted. The scan stops at the first larger key.
    #[inline]
    fn inline_search(keys: &[K::Unit], key: K::Unit) -> Result<usize, usize> {
        for (i, &k) in keys.iter().enumerate() {
            if k == key {
                return Ok(i);
            }
            if k > key {
                return Err(i);
            }
        }
        Err(keys.len())
    }

    /// Looks up the child stored under `key`.
    ///
    /// The inline tier is scanned linearly with early exit. The heap tier
    /// uses binary search.
    #[inline]
    fn find_child(&self, key: K::Unit) -> Option<&Child<K, V>> {
        match self {
            ChildStore::Inline {
                count,
                keys,
                children,
            } => {
                let n = *count as usize;
                Self::inline_search(&keys[..n], key)
                    .ok()
                    .map(|idx| &children[idx])
            }
            ChildStore::Heap { keys, children } => {
                keys.binary_search(&key).ok().map(|idx| &children[idx])
            }
        }
    }

    /// `true` when a child is stored under `key`.
    #[inline]
    fn has_child(&self, key: K::Unit) -> bool {
        self.find_child(key).is_some()
    }

    /// Returns the `index`-th child in key order, or `None` past the end.
    #[inline]
    fn child_at(&self, index: usize) -> Option<(&K::Unit, &Child<K, V>)> {
        let (keys, children) = self.slices();
        Some((keys.get(index)?, children.get(index)?))
    }

    /// Borrows the live keys and children as parallel, equally long slices.
    #[inline]
    fn slices(&self) -> (&[K::Unit], &[Child<K, V>]) {
        match self {
            ChildStore::Inline {
                count,
                keys,
                children,
            } => {
                let n = *count as usize;
                (&keys[..n], &children[..n])
            }
            ChildStore::Heap { keys, children } => (keys.as_slice(), children.as_slice()),
        }
    }

    /// Returns a copy of this store with `child` stored under `key`.
    ///
    /// If `key` already exists, its child is replaced and the tier is kept.
    /// Otherwise `key` is inserted at its sorted position, and a full inline
    /// store (already `INLINE_CAPACITY` children) is promoted to the heap tier.
    /// `self` is never modified.
    fn with_child(&self, key: K::Unit, child: Child<K, V>) -> Self {
        match self {
            ChildStore::Inline {
                count,
                keys,
                children,
            } => {
                let n = *count as usize;
                match Self::inline_search(&keys[..n], key) {
                    Ok(idx) => {
                        let mut new_children = children.clone();
                        new_children[idx] = child;
                        ChildStore::Inline {
                            count: *count,
                            keys: *keys,
                            children: new_children,
                        }
                    }
                    Err(pos) if n < INLINE_CAPACITY => {
                        let mut new_keys = *keys;
                        let mut new_children = children.clone();
                        // Slot `n` is a filler. Rotating it to `pos` shifts the
                        // tail one slot right and opens the gap for the new entry.
                        new_keys[pos..=n].rotate_right(1);
                        new_children[pos..=n].rotate_right(1);
                        new_keys[pos] = key;
                        new_children[pos] = child;
                        ChildStore::Inline {
                            count: *count + 1,
                            keys: new_keys,
                            children: new_children,
                        }
                    }
                    Err(pos) => {
                        // Inline arrays are full: spill to the heap tier.
                        let mut new_keys = Vec::with_capacity(n + 1);
                        new_keys.extend_from_slice(&keys[..pos]);
                        new_keys.push(key);
                        new_keys.extend_from_slice(&keys[pos..n]);
                        let mut new_children = Vec::with_capacity(n + 1);
                        new_children.extend_from_slice(&children[..pos]);
                        new_children.push(child);
                        new_children.extend_from_slice(&children[pos..n]);
                        ChildStore::Heap {
                            keys: new_keys,
                            children: new_children,
                        }
                    }
                }
            }
            ChildStore::Heap { keys, children } => {
                let mut new_keys = keys.clone();
                let mut new_children = children.clone();
                match keys.binary_search(&key) {
                    Ok(idx) => new_children[idx] = child,
                    Err(idx) => {
                        new_keys.insert(idx, key);
                        new_children.insert(idx, child);
                    }
                }
                ChildStore::Heap {
                    keys: new_keys,
                    children: new_children,
                }
            }
        }
    }

    /// Returns a copy of this store without the entry for `key`.
    ///
    /// Returns `None` when `key` is absent. A heap store that shrinks to
    /// `INLINE_CAPACITY` or fewer children is demoted to the inline tier.
    /// `self` is never modified.
    fn without_child(&self, key: K::Unit) -> Option<Self> {
        match self {
            ChildStore::Inline {
                count,
                keys,
                children,
            } => {
                let n = *count as usize;
                let pos = Self::inline_search(&keys[..n], key).ok()?;
                let mut new_keys = *keys;
                let mut new_children = children.clone();
                // Rotate the removed entry to the end of the live range, shifting
                // its successors left, then overwrite it with the fillers.
                new_keys[pos..n].rotate_left(1);
                new_children[pos..n].rotate_left(1);
                new_keys[n - 1] = K::UNIT_ZERO;
                new_children[n - 1] = Child::empty();
                Some(ChildStore::Inline {
                    count: *count - 1,
                    keys: new_keys,
                    children: new_children,
                })
            }
            ChildStore::Heap { keys, children } => {
                let idx = keys.binary_search(&key).ok()?;
                let new_len = keys.len() - 1;
                if new_len <= INLINE_CAPACITY {
                    // Demote back to the inline tier.
                    let mut new_keys = [K::UNIT_ZERO; INLINE_CAPACITY];
                    let mut new_children = [
                        Child::empty(),
                        Child::empty(),
                        Child::empty(),
                        Child::empty(),
                    ];
                    for (slot, i) in (0..keys.len()).filter(|&i| i != idx).enumerate() {
                        new_keys[slot] = keys[i];
                        new_children[slot] = children[i].clone();
                    }
                    Some(ChildStore::Inline {
                        count: new_len as u8,
                        keys: new_keys,
                        children: new_children,
                    })
                } else {
                    let mut new_keys = keys.clone();
                    let mut new_children = children.clone();
                    new_keys.remove(idx);
                    new_children.remove(idx);
                    Some(ChildStore::Heap {
                        keys: new_keys,
                        children: new_children,
                    })
                }
            }
        }
    }

    /// Heap bytes owned by the store beyond `size_of::<ChildStore<K, V>>()`.
    ///
    /// The inline tier's arrays are stored inside the enum itself, so they are
    /// already part of the enclosing value's `size_of`. Counting the live slots
    /// again would double-count them, so the inline tier reports 0. The heap
    /// tier reports the allocated capacity of both vectors.
    fn memory_usage(&self) -> usize {
        match self {
            ChildStore::Inline { .. } => 0,
            ChildStore::Heap { keys, children } => {
                keys.capacity() * std::mem::size_of::<K::Unit>()
                    + children.capacity() * std::mem::size_of::<Child<K, V>>()
            }
        }
    }
}
/// A node of the in-memory overlay trie.
///
/// Structural edits (`with_child`, `without_child`, `with_prefix_replaced`,
/// `as_final`, `as_non_final`, `with_value`) return a new node and leave `self`
/// untouched. The in-place mutators shown here (`try_set_final`,
/// `set_durable_stamp`) only touch the atomic fields. Dropping a node releases
/// its in-memory descendants iteratively rather than recursively.
pub struct OverlayNode<K: KeyEncoding, V = ()> {
    /// Edit counter. Each derived copy gets `self.version + 1`; `Clone` keeps
    /// the value unchanged.
    version: AtomicU64,
    /// Raw stamp exposed through `durable_stamp` / `set_durable_stamp`. It is 0
    /// in every freshly built or derived node and is reset to 0 when
    /// `try_set_final` newly sets `IS_FINAL`.
    /// NOTE(review): presumably the raw `SwizzledPtr` of the last persisted copy
    /// (tests store `SwizzledPtr::to_raw()` values here); confirm against the
    /// persistence layer.
    serial_disk_ptr: AtomicU64,
    /// Sorted child table (inline for small fan-out, heap otherwise).
    store: ChildStore<K, V>,
    /// Bitset of `flags::*` masks.
    flags: AtomicU8,
    /// Optional payload; `has_value` / `get_value` read this field directly.
    value: Option<V>,
    /// Prefix units, at most `K::MAX_PREFIX_LEN` long. The buffer is shared
    /// (reference-counted) with copies derived from this node unless the
    /// prefix is replaced.
    prefix: Arc<[K::Unit]>,
    /// Number of meaningful units in `prefix`.
    prefix_len: u8,
}
impl<K: KeyEncoding, V: Clone> std::fmt::Debug for OverlayNode<K, V> {
    /// Summarizes finality, fan-out, value presence and prefix length. It does
    /// not print the children, the value, or the prefix contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut summary = f.debug_struct("OverlayNode");
        summary.field("is_final", &self.is_final());
        summary.field("num_children", &self.num_children());
        summary.field("has_value", &self.has_value());
        summary.field("prefix_len", &self.prefix_len);
        summary.finish()
    }
}
impl<K: KeyEncoding, V: Clone> OverlayNode<K, V> {
pub fn new() -> Self {
Self {
version: AtomicU64::new(0),
serial_disk_ptr: AtomicU64::new(0),
store: ChildStore::new(),
flags: AtomicU8::new(0),
value: None,
prefix: Arc::new([]),
prefix_len: 0,
}
}
pub fn with_prefix(prefix: &[K::Unit]) -> Self {
let prefix_len = prefix.len().min(K::MAX_PREFIX_LEN) as u8;
let prefix_data: Arc<[K::Unit]> = prefix[..prefix_len as usize].into();
Self {
version: AtomicU64::new(0),
serial_disk_ptr: AtomicU64::new(0),
store: ChildStore::new(),
flags: AtomicU8::new(0),
value: None,
prefix: prefix_data,
prefix_len,
}
}
#[inline]
pub fn version(&self) -> u64 {
self.version.load(Ordering::Acquire)
}
#[inline]
pub fn durable_stamp(&self) -> u64 {
self.serial_disk_ptr.load(Ordering::Acquire)
}
#[inline]
pub fn set_durable_stamp(&self, raw: u64) {
self.serial_disk_ptr.store(raw, Ordering::Release);
}
#[inline]
pub fn num_children(&self) -> usize {
self.store.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.store.is_empty()
}
#[inline]
pub fn prefix(&self) -> &[K::Unit] {
&self.prefix[..self.prefix_len as usize]
}
#[inline]
pub fn prefix_len(&self) -> usize {
self.prefix_len as usize
}
#[inline]
pub fn is_final(&self) -> bool {
self.flags.load(Ordering::Acquire) & flags::IS_FINAL != 0
}
#[inline]
pub fn has_value(&self) -> bool {
self.value.is_some()
}
#[inline]
pub fn get_value(&self) -> Option<V> {
self.value.clone()
}
#[inline]
pub fn try_set_final(&self) -> bool {
let old = self.flags.fetch_or(flags::IS_FINAL, Ordering::AcqRel);
let newly_final = (old & flags::IS_FINAL) == 0;
if newly_final {
self.serial_disk_ptr.store(0, Ordering::Release);
}
newly_final
}
#[inline]
pub fn find_child(&self, key: K::Unit) -> Option<&Child<K, V>> {
self.store.find_child(key)
}
#[inline]
pub fn has_child(&self, key: K::Unit) -> bool {
self.store.has_child(key)
}
#[inline]
pub fn child_at(&self, index: usize) -> Option<(&K::Unit, &Child<K, V>)> {
self.store.child_at(index)
}
pub fn iter_children(&self) -> impl Iterator<Item = (&K::Unit, &Child<K, V>)> {
let (keys, children) = self.store.slices();
keys.iter().zip(children.iter())
}
pub fn with_child(&self, key: K::Unit, child: Child<K, V>) -> Self {
Self {
version: AtomicU64::new(self.version.load(Ordering::Acquire) + 1),
serial_disk_ptr: AtomicU64::new(0),
store: self.store.with_child(key, child),
flags: AtomicU8::new(self.flags.load(Ordering::Acquire)),
value: self.value.clone(),
prefix: self.prefix.clone(),
prefix_len: self.prefix_len,
}
}
pub fn without_child(&self, key: K::Unit) -> Option<Self> {
self.store.without_child(key).map(|new_store| Self {
version: AtomicU64::new(self.version.load(Ordering::Acquire) + 1),
serial_disk_ptr: AtomicU64::new(0),
store: new_store,
flags: AtomicU8::new(self.flags.load(Ordering::Acquire)),
value: self.value.clone(),
prefix: self.prefix.clone(),
prefix_len: self.prefix_len,
})
}
pub fn with_prefix_replaced(&self, prefix: &[K::Unit]) -> Self {
let prefix_len = prefix.len().min(K::MAX_PREFIX_LEN) as u8;
let prefix_data: Arc<[K::Unit]> = prefix[..prefix_len as usize].into();
Self {
version: AtomicU64::new(self.version.load(Ordering::Acquire) + 1),
serial_disk_ptr: AtomicU64::new(0),
store: self.store.clone(),
flags: AtomicU8::new(self.flags.load(Ordering::Acquire)),
value: self.value.clone(),
prefix: prefix_data,
prefix_len,
}
}
pub fn as_final(&self) -> Self {
Self {
version: AtomicU64::new(self.version.load(Ordering::Acquire) + 1),
serial_disk_ptr: AtomicU64::new(0),
store: self.store.clone(),
flags: AtomicU8::new(self.flags.load(Ordering::Acquire) | flags::IS_FINAL),
value: self.value.clone(),
prefix: self.prefix.clone(),
prefix_len: self.prefix_len,
}
}
pub fn as_non_final(&self) -> Self {
Self {
version: AtomicU64::new(self.version.load(Ordering::Acquire) + 1),
serial_disk_ptr: AtomicU64::new(0),
store: self.store.clone(),
flags: AtomicU8::new(
self.flags.load(Ordering::Acquire) & !(flags::IS_FINAL | flags::HAS_VALUE),
),
value: None,
prefix: self.prefix.clone(),
prefix_len: self.prefix_len,
}
}
pub fn with_value(&self, value: V) -> Self {
Self {
version: AtomicU64::new(self.version.load(Ordering::Acquire) + 1),
serial_disk_ptr: AtomicU64::new(0),
store: self.store.clone(),
flags: AtomicU8::new(self.flags.load(Ordering::Acquire) | flags::HAS_VALUE),
value: Some(value),
prefix: self.prefix.clone(),
prefix_len: self.prefix_len,
}
}
pub fn match_prefix(&self, key: &[K::Unit]) -> usize {
let prefix = self.prefix();
let check_len = prefix.len().min(key.len());
for i in 0..check_len {
if prefix[i] != key[i] {
return i;
}
}
check_len
}
#[inline]
pub fn prefix_matches(&self, key: &[K::Unit]) -> bool {
self.match_prefix(key) == self.prefix_len()
}
pub fn memory_usage(&self) -> usize {
let base = std::mem::size_of::<Self>();
let store_heap = self.store.memory_usage();
let prefix_size = self.prefix.len() * std::mem::size_of::<K::Unit>();
base + store_heap + prefix_size
}
}
impl<K: KeyEncoding, V: Clone> Default for OverlayNode<K, V> {
fn default() -> Self {
Self::new()
}
}
impl<K: KeyEncoding, V: Clone> Clone for OverlayNode<K, V> {
fn clone(&self) -> Self {
Self {
version: AtomicU64::new(self.version.load(Ordering::Acquire)),
serial_disk_ptr: AtomicU64::new(0),
store: self.store.clone(),
flags: AtomicU8::new(self.flags.load(Ordering::Acquire)),
value: self.value.clone(),
prefix: self.prefix.clone(),
prefix_len: self.prefix_len,
}
}
}
impl<K: KeyEncoding, V> Drop for OverlayNode<K, V> {
    /// Releases in-memory descendants with an explicit work stack instead of
    /// recursion, so teardown depth does not grow the call stack.
    ///
    /// A child that is still shared elsewhere (`Arc::try_unwrap` fails) just
    /// has this handle released. A child owned outright has its own children
    /// detached first, so its `drop` runs against an empty store.
    fn drop(&mut self) {
        let mut pending: Vec<Arc<OverlayNode<K, V>>> = Vec::new();
        self.store.take().drain_in_mem_into(&mut pending);
        while let Some(handle) = pending.pop() {
            match Arc::try_unwrap(handle) {
                Ok(mut owned) => owned.store.take().drain_in_mem_into(&mut pending),
                // Still referenced elsewhere: dropping our handle is enough.
                Err(_still_shared) => {}
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::persistent_artrie_core::key_encoding::{ByteKey, CharKey};
use crate::persistent_artrie_core::swizzled_ptr::NodeType;
type ByteNode = OverlayNode<ByteKey, ()>;
type ByteValuedNode = OverlayNode<ByteKey, u64>;
type CharNode = OverlayNode<CharKey, ()>;
type CharValuedNode = OverlayNode<CharKey, u64>;
#[test]
fn test_new_node() {
let node = ByteNode::new();
assert_eq!(node.num_children(), 0);
assert!(node.is_empty());
assert!(!node.is_final());
assert!(!node.has_value());
assert_eq!(node.version(), 0);
let cnode = CharNode::new();
assert_eq!(cnode.num_children(), 0);
assert!(cnode.is_empty());
}
fn check_durable_stamp_invariant<K: KeyEncoding>(nt: NodeType, edge: K::Unit) {
let node = OverlayNode::<K, u64>::new();
assert_eq!(node.durable_stamp(), 0, "new() stamp must be 0");
assert_eq!(
OverlayNode::<K, u64>::with_prefix(&[]).durable_stamp(),
0,
"with_prefix() stamp must be 0"
);
let raw = SwizzledPtr::on_disk(1, 100, nt).to_raw();
assert_ne!(raw, 0, "test ptr must be non-zero");
node.set_durable_stamp(raw);
assert_eq!(node.durable_stamp(), raw, "stamp must round-trip");
let child = Child::OnDisk(SwizzledPtr::on_disk(2, 200, nt));
assert_eq!(
node.with_child(edge, child).durable_stamp(),
0,
"with_child copy must clear the stamp"
);
assert_eq!(node.as_final().durable_stamp(), 0, "as_final copy clears");
assert_eq!(
node.as_non_final().durable_stamp(),
0,
"as_non_final copy clears"
);
assert_eq!(
node.with_value(7).durable_stamp(),
0,
"with_value copy clears"
);
assert_eq!(
node.with_prefix_replaced(&[]).durable_stamp(),
0,
"with_prefix_replaced copy clears"
);
assert_eq!(node.clone().durable_stamp(), 0, "Clone clears");
let parent = node.with_child(edge, Child::OnDisk(SwizzledPtr::on_disk(3, 300, nt)));
parent.set_durable_stamp(raw);
if let Some(removed) = parent.without_child(edge) {
assert_eq!(removed.durable_stamp(), 0, "without_child copy clears");
}
}
#[test]
fn durable_stamp_invariant_byte() {
check_durable_stamp_invariant::<ByteKey>(NodeType::Node4, b'a');
}
#[test]
fn durable_stamp_invariant_char() {
check_durable_stamp_invariant::<CharKey>(NodeType::CharNode4, 'a' as u32);
}
#[test]
fn test_with_prefix_byte() {
let prefix: Vec<u8> = b"hello".to_vec();
let node = ByteNode::with_prefix(&prefix);
assert_eq!(node.prefix_len(), 5);
assert_eq!(node.prefix(), b"hello");
}
#[test]
fn test_with_prefix_char() {
let prefix: Vec<u32> = "hello".chars().map(|c| c as u32).collect();
let node = CharNode::with_prefix(&prefix);
assert_eq!(node.prefix_len(), 5);
let got: Vec<u32> = node.prefix().to_vec();
assert_eq!(got, prefix);
}
#[test]
fn test_prefix_max_length_byte() {
let prefix: Vec<u8> = b"abcdefghijklmnop".to_vec();
let node = ByteNode::with_prefix(&prefix);
assert_eq!(node.prefix_len(), ByteKey::MAX_PREFIX_LEN);
assert_eq!(node.prefix_len(), 12);
assert_eq!(node.prefix(), b"abcdefghijkl");
}
#[test]
fn test_prefix_max_length_char() {
let prefix: Vec<u32> = "abcdefghi".chars().map(|c| c as u32).collect();
let node = CharNode::with_prefix(&prefix);
assert_eq!(node.prefix_len(), CharKey::MAX_PREFIX_LEN);
assert_eq!(node.prefix_len(), 6);
let got: Vec<u32> = node.prefix().to_vec();
assert_eq!(
got,
"abcdef".chars().map(|c| c as u32).collect::<Vec<u32>>()
);
}
#[test]
fn test_with_child_immutability_byte() {
let node = ByteNode::new();
let child = Child::OnDisk(SwizzledPtr::on_disk(1, 100, NodeType::Node4));
let node2 = node.with_child(b'a', child);
assert_eq!(node.num_children(), 0);
assert_eq!(node2.num_children(), 1);
assert!(node2.has_child(b'a'));
}
#[test]
fn test_with_child_sorted_order_char() {
let mut node = CharNode::new();
let keys: [u32; 4] = ['z' as u32, 'a' as u32, 'm' as u32, 'f' as u32];
for &k in &keys {
let child = Child::OnDisk(SwizzledPtr::on_disk(k, 0, NodeType::CharNode4));
node = node.with_child(k, child);
}
assert_eq!(node.num_children(), 4);
let collected: Vec<u32> = node.iter_children().map(|(&k, _)| k).collect();
assert_eq!(
collected,
vec!['a' as u32, 'f' as u32, 'm' as u32, 'z' as u32]
);
}
#[test]
fn test_with_child_replace_byte() {
let child1 = Child::OnDisk(SwizzledPtr::on_disk(1, 100, NodeType::Node4));
let ptr2 = SwizzledPtr::on_disk(2, 200, NodeType::Node4);
let child2_raw = ptr2.to_raw();
let node = ByteNode::new().with_child(b'a', child1);
assert_eq!(node.num_children(), 1);
let node2 = node.with_child(b'a', Child::OnDisk(ptr2));
assert_eq!(node2.num_children(), 1);
let found = node2.find_child(b'a').expect("should find child");
assert_eq!(
found.as_on_disk().expect("on-disk child").to_raw(),
child2_raw
);
}
#[test]
fn test_without_child_byte() {
let child = Child::OnDisk(SwizzledPtr::on_disk(1, 100, NodeType::Node4));
let node = ByteNode::new()
.with_child(b'a', child.clone())
.with_child(b'b', child.clone())
.with_child(b'c', child);
assert_eq!(node.num_children(), 3);
let node2 = node.without_child(b'b').expect("should remove");
assert_eq!(node2.num_children(), 2);
assert!(node2.has_child(b'a'));
assert!(!node2.has_child(b'b'));
assert!(node2.has_child(b'c'));
assert_eq!(node.num_children(), 3);
}
#[test]
fn test_without_child_not_found_byte() {
let node = ByteNode::new();
assert!(node.without_child(b'x').is_none());
}
#[test]
fn test_try_set_final_byte() {
let node = ByteNode::new();
assert!(node.try_set_final());
assert!(node.is_final());
assert!(!node.try_set_final());
assert!(node.is_final());
}
#[test]
fn test_version_increment_char() {
let node = CharNode::new();
assert_eq!(node.version(), 0);
let child = Child::OnDisk(SwizzledPtr::on_disk(1, 100, NodeType::CharNode4));
let node2 = node.with_child('a' as u32, child);
assert_eq!(node2.version(), 1);
let node3 = node2.as_final();
assert_eq!(node3.version(), 2);
}
#[test]
fn test_prefix_matching_byte() {
let prefix: Vec<u8> = b"hello".to_vec();
let node = ByteNode::with_prefix(&prefix);
assert_eq!(node.match_prefix(b"helloworld"), 5);
assert!(node.prefix_matches(b"helloworld"));
assert_eq!(node.match_prefix(b"help"), 3);
assert!(!node.prefix_matches(b"help"));
assert_eq!(node.match_prefix(b"world"), 0);
assert!(!node.prefix_matches(b"world"));
}
#[test]
fn test_clone_byte_valued() {
let node = ByteValuedNode::new();
let child = Child::OnDisk(SwizzledPtr::on_disk(1, 100, NodeType::Node4));
let node = node.with_child(b'a', child).as_final().with_value(42);
let cloned = node.clone();
assert_eq!(cloned.num_children(), 1);
assert!(cloned.is_final());
assert_eq!(cloned.get_value(), Some(42));
assert_eq!(cloned.version(), node.version());
}
#[test]
fn test_as_final_char() {
let node = CharNode::new();
assert!(!node.is_final());
let final_node = node.as_final();
assert!(final_node.is_final());
assert!(!node.is_final());
}
#[test]
fn test_with_value_char_valued() {
let node = CharValuedNode::new();
assert!(!node.has_value());
let valued_node = node.with_value(123);
assert!(valued_node.has_value());
assert_eq!(valued_node.get_value(), Some(123));
assert!(!node.has_value());
}
#[test]
fn test_iter_children_byte() {
let child = Child::OnDisk(SwizzledPtr::on_disk(1, 100, NodeType::Node4));
let node = ByteNode::new()
.with_child(b'c', child.clone())
.with_child(b'a', child.clone())
.with_child(b'b', child);
let pairs: Vec<(u8, u64)> = node
.iter_children()
.map(|(&k, c)| (k, c.as_on_disk().expect("on-disk child").to_raw()))
.collect();
assert_eq!(pairs.len(), 3);
assert_eq!(pairs[0].0, b'a');
assert_eq!(pairs[1].0, b'b');
assert_eq!(pairs[2].0, b'c');
}
#[test]
fn test_all_byte_values() {
let mut node = ByteNode::new();
let child = Child::OnDisk(SwizzledPtr::on_disk(1, 100, NodeType::Node4));
for key in 0u8..=255 {
node = node.with_child(key, child.clone());
}
assert_eq!(node.num_children(), 256);
assert!(matches!(node.store, ChildStore::Heap { .. }));
for key in 0u8..=255 {
assert!(node.has_child(key), "should find key {}", key);
}
let collected: Vec<u8> = node.iter_children().map(|(&k, _)| k).collect();
let expected: Vec<u8> = (0u8..=255).collect();
assert_eq!(collected, expected);
}
#[test]
fn test_inline_to_heap_promotion_byte() {
let mut node = ByteNode::new();
for i in 0..4u8 {
let child = Child::OnDisk(SwizzledPtr::on_disk(
i as u32,
(i as u32) * 100,
NodeType::Node4,
));
node = node.with_child(i + 100, child);
}
assert_eq!(node.num_children(), 4);
assert!(matches!(node.store, ChildStore::Inline { .. }));
let child = Child::OnDisk(SwizzledPtr::on_disk(5, 500, NodeType::Node4));
node = node.with_child(104, child);
assert_eq!(node.num_children(), 5);
assert!(matches!(node.store, ChildStore::Heap { .. }));
let keys: Vec<u8> = node.iter_children().map(|(&k, _)| k).collect();
assert_eq!(keys, vec![100, 101, 102, 103, 104]);
}
#[test]
fn test_heap_to_inline_demotion_char() {
let mut node = CharNode::new();
for i in 0..5u32 {
let child = Child::OnDisk(SwizzledPtr::on_disk(i, i * 100, NodeType::CharNode4));
node = node.with_child(i + 100, child);
}
assert!(matches!(node.store, ChildStore::Heap { .. }));
let node2 = node.without_child(102).expect("should remove");
assert_eq!(node2.num_children(), 4);
assert!(matches!(node2.store, ChildStore::Inline { .. }));
let keys: Vec<u32> = node2.iter_children().map(|(&k, _)| k).collect();
assert_eq!(keys, vec![100, 101, 103, 104]);
}
#[test]
fn test_heap_stays_heap_above_threshold_byte() {
let mut node = ByteNode::new();
for i in 0..6u8 {
let child = Child::OnDisk(SwizzledPtr::on_disk(
i as u32,
(i as u32) * 100,
NodeType::Node4,
));
node = node.with_child(i + 100, child);
}
assert!(matches!(node.store, ChildStore::Heap { .. }));
let node2 = node.without_child(102).expect("should remove");
assert_eq!(node2.num_children(), 5);
assert!(matches!(node2.store, ChildStore::Heap { .. }));
}
#[test]
fn test_child_at_byte() {
let child = Child::OnDisk(SwizzledPtr::on_disk(1, 100, NodeType::Node4));
let node = ByteNode::new()
.with_child(b'b', child.clone())
.with_child(b'a', child);
let (k, _) = node.child_at(0).expect("should exist");
assert_eq!(*k, b'a');
let (k, _) = node.child_at(1).expect("should exist");
assert_eq!(*k, b'b');
assert!(node.child_at(2).is_none());
}
#[test]
fn test_inline_replace_preserves_count_byte() {
let child1 = Child::OnDisk(SwizzledPtr::on_disk(1, 100, NodeType::Node4));
let child2 = Child::OnDisk(SwizzledPtr::on_disk(2, 200, NodeType::Node4));
let node = ByteNode::new()
.with_child(b'a', child1.clone())
.with_child(b'b', child1);
assert_eq!(node.num_children(), 2);
assert!(matches!(node.store, ChildStore::Inline { .. }));
let node2 = node.with_child(b'a', child2);
assert_eq!(node2.num_children(), 2);
assert!(matches!(node2.store, ChildStore::Inline { .. }));
}
#[test]
fn test_supersized_unicode_keys_char() {
let mut node = CharNode::new();
let keys: [u32; 3] = [0x1F600, 0x10FFFF, 'a' as u32];
for &k in &keys {
let child = Child::OnDisk(SwizzledPtr::on_disk(k, 0, NodeType::CharNode4));
node = node.with_child(k, child);
}
assert_eq!(node.num_children(), 3);
assert!(node.has_child(0x1F600));
assert!(node.has_child(0x10FFFF));
let collected: Vec<u32> = node.iter_children().map(|(&k, _)| k).collect();
assert_eq!(collected, vec!['a' as u32, 0x1F600, 0x10FFFF]);
}
fn an_on_disk_child<K: KeyEncoding, V>(raw: u32, nt: NodeType) -> Child<K, V> {
Child::OnDisk(SwizzledPtr::on_disk(raw, 0, nt))
}
fn check_sorted_order<K: KeyEncoding>(keys: &[K::Unit], nt: NodeType)
where
K::Unit: Into<u64>,
{
let mut node = OverlayNode::<K, ()>::new();
let original = OverlayNode::<K, ()>::new();
for (i, &k) in keys.iter().enumerate() {
node = node.with_child(k, an_on_disk_child::<K, ()>(i as u32, nt));
}
assert_eq!(node.num_children(), keys.len());
assert_eq!(original.num_children(), 0);
let collected: Vec<u64> = node.iter_children().map(|(&k, _)| k.into()).collect();
let mut sorted = collected.clone();
sorted.sort_unstable();
assert_eq!(
collected, sorted,
"children must iterate in ascending key order"
);
for &k in keys {
assert!(node.has_child(k), "inserted key must be present");
assert!(node.find_child(k).is_some());
}
for (idx, (&k, _)) in node.iter_children().enumerate() {
let (ck, _) = node.child_at(idx).expect("child_at within bounds");
assert_eq!(*ck, k);
}
assert!(node.child_at(keys.len()).is_none());
}
fn check_tier_transitions<K: KeyEncoding>(base: u32, nt: NodeType)
where
K::Unit: TryFrom<u32>,
<K::Unit as TryFrom<u32>>::Error: std::fmt::Debug,
{
let mk = |u: u32| K::Unit::try_from(u).expect("unit fits");
let mut node = OverlayNode::<K, ()>::new();
for i in 0..4u32 {
node = node.with_child(mk(base + i), an_on_disk_child::<K, ()>(i, nt));
}
assert_eq!(node.num_children(), 4);
assert!(matches!(node.store, ChildStore::Inline { .. }));
node = node.with_child(mk(base + 4), an_on_disk_child::<K, ()>(4, nt));
assert_eq!(node.num_children(), 5);
assert!(matches!(node.store, ChildStore::Heap { .. }));
let demoted = node
.without_child(mk(base + 2))
.expect("remove present key");
assert_eq!(demoted.num_children(), 4);
assert!(matches!(demoted.store, ChildStore::Inline { .. }));
let six = node.with_child(mk(base + 5), an_on_disk_child::<K, ()>(5, nt));
let five = six.without_child(mk(base + 2)).expect("remove present key");
assert_eq!(five.num_children(), 5);
assert!(matches!(five.store, ChildStore::Heap { .. }));
assert!(OverlayNode::<K, ()>::new()
.without_child(mk(base))
.is_none());
}
fn check_value_and_final<K: KeyEncoding>(nt: NodeType)
where
K::Unit: TryFrom<u32>,
<K::Unit as TryFrom<u32>>::Error: std::fmt::Debug,
{
let k0 = K::Unit::try_from(b'a' as u32).expect("unit fits");
let node = OverlayNode::<K, u64>::new();
assert!(!node.is_final());
assert!(!node.has_value());
let valued = node
.with_child(k0, an_on_disk_child::<K, u64>(1, nt))
.as_final()
.with_value(7);
assert!(valued.is_final());
assert!(valued.has_value());
assert_eq!(valued.get_value(), Some(7));
assert_eq!(valued.num_children(), 1);
let fresh = OverlayNode::<K, u64>::new();
assert!(fresh.try_set_final());
assert!(!fresh.try_set_final());
assert!(!node.is_final());
assert!(node.get_value().is_none());
}
fn check_prefix<K: KeyEncoding>(units: &[K::Unit])
where
K::Unit: PartialEq,
{
let node = OverlayNode::<K, ()>::with_prefix(units);
let expect_len = units.len().min(K::MAX_PREFIX_LEN);
assert_eq!(node.prefix_len(), expect_len);
assert_eq!(node.prefix(), &units[..expect_len]);
let key = &units[..expect_len];
assert_eq!(node.match_prefix(key), expect_len);
assert!(node.prefix_matches(key));
}
#[test]
fn generic_sorted_order_byte() {
check_sorted_order::<ByteKey>(&[b'z', b'a', b'm', b'f'], NodeType::Node4);
}
#[test]
fn generic_sorted_order_char() {
check_sorted_order::<CharKey>(
&['z' as u32, 'a' as u32, 0x1F600, 'f' as u32],
NodeType::CharNode4,
);
}
#[test]
fn generic_tier_transitions_byte() {
check_tier_transitions::<ByteKey>(100, NodeType::Node4);
}
#[test]
fn generic_tier_transitions_char() {
check_tier_transitions::<CharKey>(0x3000, NodeType::CharNode4);
}
#[test]
fn generic_value_and_final_byte() {
check_value_and_final::<ByteKey>(NodeType::Node4);
}
#[test]
fn generic_value_and_final_char() {
check_value_and_final::<CharKey>(NodeType::CharNode4);
}
#[test]
fn generic_prefix_byte() {
let units: Vec<u8> = b"abcdefghijklmnop".to_vec();
check_prefix::<ByteKey>(&units);
check_prefix::<ByteKey>(b"hi");
}
#[test]
fn generic_prefix_char() {
let units: Vec<u32> = "abcdefghi".chars().map(|c| c as u32).collect();
check_prefix::<CharKey>(&units);
let short: Vec<u32> = vec!['h' as u32, 0x1F600];
check_prefix::<CharKey>(&short);
}
#[test]
fn generic_memory_usage_is_monotonic_byte() {
let empty = OverlayNode::<ByteKey, ()>::new().memory_usage();
let with_one = OverlayNode::<ByteKey, ()>::new()
.with_child(b'a', an_on_disk_child::<ByteKey, ()>(1, NodeType::Node4))
.memory_usage();
assert!(
with_one >= empty,
"adding a child must not shrink reported usage"
);
}
#[test]
fn generic_memory_usage_is_monotonic_char() {
let empty = OverlayNode::<CharKey, ()>::new().memory_usage();
let with_one = OverlayNode::<CharKey, ()>::new()
.with_child(
'a' as u32,
an_on_disk_child::<CharKey, ()>(1, NodeType::CharNode4),
)
.memory_usage();
assert!(
with_one >= empty,
"adding a child must not shrink reported usage"
);
}
fn check_replace_child<K: KeyEncoding>(nt: NodeType)
where
K::Unit: TryFrom<u32>,
<K::Unit as TryFrom<u32>>::Error: std::fmt::Debug,
{
let k = K::Unit::try_from(b'a' as u32).expect("unit fits");
let raw1 = SwizzledPtr::on_disk(1, 100, nt).to_raw();
let raw2 = SwizzledPtr::on_disk(2, 200, nt).to_raw();
let node =
OverlayNode::<K, ()>::new().with_child(k, Child::OnDisk(SwizzledPtr::from_raw(raw1)));
assert_eq!(node.num_children(), 1);
let found1 = node
.find_child(k)
.expect("present")
.as_on_disk()
.expect("on-disk")
.to_raw();
assert_eq!(found1, raw1);
let node2 = node.with_child(k, Child::OnDisk(SwizzledPtr::from_raw(raw2)));
assert_eq!(
node2.num_children(),
1,
"replace must not change child count"
);
let found2 = node2
.find_child(k)
.expect("present")
.as_on_disk()
.expect("on-disk")
.to_raw();
assert_eq!(found2, raw2);
}
fn check_remove_middle<K: KeyEncoding>(keys: &[K::Unit], remove: K::Unit, nt: NodeType)
where
K::Unit: Into<u64> + PartialEq,
{
let mut node = OverlayNode::<K, ()>::new();
for (i, &k) in keys.iter().enumerate() {
node = node.with_child(k, an_on_disk_child::<K, ()>(i as u32, nt));
}
let before = node.num_children();
let node2 = node.without_child(remove).expect("remove a present key");
assert_eq!(node2.num_children(), before - 1);
assert!(!node2.has_child(remove));
let survivors: Vec<u64> = node2.iter_children().map(|(&k, _)| k.into()).collect();
let mut sorted = survivors.clone();
sorted.sort_unstable();
assert_eq!(survivors, sorted);
assert_eq!(node.num_children(), before);
assert!(node.has_child(remove));
}
/// Lookups of absent keys return `None`, both on an empty node and on a node
/// holding a single child. The version starts at 0 and grows by one for
/// `with_child` and again for `as_final`.
fn check_find_miss_and_version<K: KeyEncoding>(nt: NodeType)
where
    K::Unit: TryFrom<u32>,
    <K::Unit as TryFrom<u32>>::Error: std::fmt::Debug,
{
    let unit = |byte: u8| K::Unit::try_from(u32::from(byte)).expect("unit fits");

    let empty = OverlayNode::<K, ()>::new();
    assert_eq!(empty.version(), 0);
    assert!(empty.find_child(unit(b'x')).is_none());

    let single = empty.with_child(unit(b'm'), an_on_disk_child::<K, ()>(1, nt));
    assert_eq!(single.version(), 1);
    // Probe one key below and one key above the only stored key.
    for probe in [b'a', b'z'] {
        assert!(single.find_child(unit(probe)).is_none());
    }

    let finalized = single.as_final();
    assert_eq!(
        finalized.version(),
        2,
        "as_final is a structural edit ⇒ +1 version"
    );
}
/// Inserts one child for every unit in `range`. The node must then report the
/// full count, hold its children in the heap store, find every key, and
/// iterate the keys in ascending order.
///
/// `range` must contain more units than the inline store holds, so that the
/// heap assertion applies.
fn check_large_fanout<K: KeyEncoding>(range: std::ops::RangeInclusive<u32>, nt: NodeType)
where
    K::Unit: TryFrom<u32> + Into<u64>,
    <K::Unit as TryFrom<u32>>::Error: std::fmt::Debug,
{
    // Pair each raw id with its key unit; the id doubles as the child's page id.
    let entries: Vec<(u32, K::Unit)> = range
        .map(|u| (u, K::Unit::try_from(u).expect("unit fits")))
        .collect();

    let node = entries
        .iter()
        .fold(OverlayNode::<K, ()>::new(), |acc, &(u, k)| {
            acc.with_child(k, an_on_disk_child::<K, ()>(u, nt))
        });

    assert_eq!(node.num_children(), entries.len());
    assert!(matches!(node.store, ChildStore::Heap { .. }));
    for &(u, k) in &entries {
        assert!(node.has_child(k), "key {u} must be present");
    }

    let mut expected: Vec<u64> = entries.iter().map(|&(_, k)| k.into()).collect();
    expected.sort_unstable();
    let collected: Vec<u64> = node.iter_children().map(|(&k, _)| k.into()).collect();
    assert_eq!(collected, expected);
}
/// `check_replace_child` over byte keys.
#[test]
fn generic_replace_child_byte() {
    let node_type = NodeType::Node4;
    check_replace_child::<ByteKey>(node_type);
}
/// `check_replace_child` over char keys.
#[test]
fn generic_replace_child_char() {
    let node_type = NodeType::CharNode4;
    check_replace_child::<CharKey>(node_type);
}
/// `check_remove_middle` over byte keys, removing the middle of three.
#[test]
fn generic_remove_middle_byte() {
    let keys = b"abc";
    check_remove_middle::<ByteKey>(keys, b'b', NodeType::Node4);
}
/// `check_remove_middle` over char keys, including a non-BMP code point.
#[test]
fn generic_remove_middle_char() {
    let keys = ['a', 'b', '\u{1F600}'].map(u32::from);
    check_remove_middle::<CharKey>(&keys, u32::from('b'), NodeType::CharNode4);
}
/// `check_find_miss_and_version` over byte keys.
#[test]
fn generic_find_miss_and_version_byte() {
    let node_type = NodeType::Node4;
    check_find_miss_and_version::<ByteKey>(node_type);
}
/// `check_find_miss_and_version` over char keys.
#[test]
fn generic_find_miss_and_version_char() {
    let node_type = NodeType::CharNode4;
    check_find_miss_and_version::<CharKey>(node_type);
}
/// `check_large_fanout` over every possible byte value.
#[test]
fn generic_large_fanout_byte() {
    let every_byte = 0..=u32::from(u8::MAX);
    check_large_fanout::<ByteKey>(every_byte, NodeType::Node4);
}
/// `check_large_fanout` over 257 consecutive code points.
#[test]
fn generic_large_fanout_char() {
    let span = 0x2000_u32..=0x2100;
    check_large_fanout::<CharKey>(span, NodeType::CharNode4);
}
/// Replacing the compressed prefix must keep at most `K::MAX_PREFIX_LEN`
/// leading units of `units`, keep the existing child, and bump the version
/// by one. Also smoke-tests `Default` (empty, non-final) and the `Debug`
/// output.
fn check_prefix_replaced_default_debug<K: KeyEncoding>(units: &[K::Unit], nt: NodeType)
where
    K::Unit: TryFrom<u32> + PartialEq,
    <K::Unit as TryFrom<u32>>::Error: std::fmt::Debug,
{
    let first = K::Unit::try_from(u32::from(b'a')).expect("unit fits");
    let base = OverlayNode::<K, ()>::new().with_child(first, an_on_disk_child::<K, ()>(1, nt));
    assert!(!base.is_empty());
    let base_version = base.version();

    let replaced = base.with_prefix_replaced(units);
    let kept = std::cmp::min(units.len(), K::MAX_PREFIX_LEN);
    assert_eq!(replaced.prefix_len(), kept);
    assert_eq!(replaced.prefix(), &units[..kept]);
    assert_eq!(
        replaced.num_children(),
        1,
        "replacing prefix keeps children"
    );
    assert_eq!(replaced.version(), base_version + 1);

    let fresh = OverlayNode::<K, ()>::default();
    assert!(fresh.is_empty());
    assert!(!fresh.is_final());

    let rendered = format!("{replaced:?}");
    assert!(rendered.contains("OverlayNode"));
}
/// Cloning a node whose children live in the heap store must produce an
/// equal node. The clone must have the same keys in the same order, the same
/// child pointer under every key, and the same version.
///
/// Six children are inserted so the store exceeds the inline capacity. The
/// `ChildStore::Heap` assertion guards that precondition.
fn check_heap_clone<K: KeyEncoding>(nt: NodeType)
where
    K::Unit: TryFrom<u32> + Into<u64>,
    <K::Unit as TryFrom<u32>>::Error: std::fmt::Debug,
{
    let mut node = OverlayNode::<K, ()>::new();
    for u in 0..6u32 {
        let k = K::Unit::try_from(0x40 + u).expect("unit fits");
        node = node.with_child(k, an_on_disk_child::<K, ()>(u, nt));
    }
    assert!(matches!(node.store, ChildStore::Heap { .. }));
    let cloned = node.clone();
    assert_eq!(cloned.num_children(), node.num_children());

    let keys: Vec<K::Unit> = node.iter_children().map(|(&k, _)| k).collect();
    let a: Vec<u64> = keys.iter().map(|&k| k.into()).collect();
    let b: Vec<u64> = cloned.iter_children().map(|(&k, _)| k.into()).collect();
    assert_eq!(a, b);

    // Comparing keys alone would not catch a clone that mis-pairs or corrupts
    // children, so also compare the on-disk pointer stored under each key.
    for &k in &keys {
        let orig_raw = node
            .find_child(k)
            .and_then(|c| c.as_on_disk().map(|p| p.to_raw()));
        let clone_raw = cloned
            .find_child(k)
            .and_then(|c| c.as_on_disk().map(|p| p.to_raw()));
        assert!(orig_raw.is_some(), "every child was inserted on-disk");
        assert_eq!(
            clone_raw, orig_raw,
            "cloned child pointer must match the original"
        );
    }

    assert_eq!(cloned.version(), node.version());
}
/// Prefix replacement over byte keys, with a 16-unit input.
#[test]
fn generic_prefix_replaced_default_debug_byte() {
    let prefix: &[u8] = b"abcdefghijklmnop";
    check_prefix_replaced_default_debug::<ByteKey>(prefix, NodeType::Node4);
}
/// Prefix replacement over char keys, with a 9-unit input.
#[test]
fn generic_prefix_replaced_default_debug_char() {
    let units: Vec<u32> = "abcdefghi".chars().map(u32::from).collect();
    check_prefix_replaced_default_debug::<CharKey>(&units, NodeType::CharNode4);
}
/// Heap-store clone over byte keys.
#[test]
fn generic_heap_clone_byte() {
    let node_type = NodeType::Node4;
    check_heap_clone::<ByteKey>(node_type);
}
/// Heap-store clone over char keys.
#[test]
fn generic_heap_clone_char() {
    let node_type = NodeType::CharNode4;
    check_heap_clone::<CharKey>(node_type);
}
/// Tests `as_non_final` on a final node that holds a value and one child.
///
/// The result must clear `IS_FINAL` and `HAS_VALUE` and drop the value, but
/// keep the child. It must bump the version by one and leave the source node
/// untouched. Re-finalizing the result restores finality but not the dropped
/// value.
fn check_as_non_final<K: KeyEncoding>(nt: NodeType)
where
    K::Unit: TryFrom<u32>,
    <K::Unit as TryFrom<u32>>::Error: std::fmt::Debug,
{
    let key_x = K::Unit::try_from(u32::from(b'x')).expect("unit fits");
    let original = OverlayNode::<K, u64>::new()
        .with_child(key_x, an_on_disk_child::<K, u64>(1, nt))
        .as_final()
        .with_value(42);

    // Starting state: final, valued, one child.
    assert!(original.is_final());
    assert!(original.has_value());
    assert_eq!(original.get_value(), Some(42));
    assert_eq!(original.num_children(), 1);
    let version_before = original.version();

    // Clearing finality.
    let cleared = original.as_non_final();
    assert!(!cleared.is_final(), "as_non_final must clear IS_FINAL");
    assert!(!cleared.has_value(), "as_non_final must clear HAS_VALUE");
    assert_eq!(
        cleared.get_value(),
        None,
        "as_non_final must drop the value (None, not Some(0))"
    );
    assert_eq!(
        cleared.num_children(),
        1,
        "as_non_final must RETAIN children (remove \"cat\" keeps \"cats\")"
    );
    assert!(
        cleared.has_child(key_x),
        "the retained child must still be found"
    );
    assert_eq!(
        cleared.version(),
        version_before + 1,
        "as_non_final is a structural edit ⇒ +1 version"
    );

    // The source node is unaffected.
    assert!(original.is_final(), "original must remain final");
    assert!(original.has_value(), "original must keep its value");
    assert_eq!(original.get_value(), Some(42));

    // Round trip back to final.
    let refinalized = cleared.as_final();
    assert!(refinalized.is_final(), "re-finalized node must be final again");
    assert!(
        !refinalized.has_value(),
        "re-finalizing a cleared node does not resurrect the dropped value"
    );
    assert_eq!(refinalized.num_children(), 1, "round-trip must keep children");
}
fn check_as_non_final_deep_child<K: KeyEncoding>(_nt: NodeType)
where
K::Unit: TryFrom<u32>,
<K::Unit as TryFrom<u32>>::Error: std::fmt::Debug,
{
let ks = K::Unit::try_from(b's' as u32).expect("unit fits");
let cats_leaf = Arc::new(OverlayNode::<K, ()>::new().as_final());
let cat = OverlayNode::<K, ()>::new()
.with_child(ks, Child::InMem(Arc::clone(&cats_leaf)))
.as_final();
assert!(cat.is_final());
assert!(cat.has_child(ks));
let cat_removed = cat.as_non_final();
assert!(!cat_removed.is_final(), "\"cat\" must no longer be final");
let surviving = cat_removed
.find_child(ks)
.expect("\"cats\" edge must survive removing \"cat\"")
.as_in_mem()
.expect("the surviving child is in-memory");
assert!(
surviving.is_final(),
"\"cats\" must remain final after removing the prefix \"cat\""
);
assert!(cat.is_final(), "original \"cat\" node unchanged");
}
/// `check_as_non_final` over byte keys.
#[test]
fn generic_as_non_final_byte() {
    let node_type = NodeType::Node4;
    check_as_non_final::<ByteKey>(node_type);
}
/// `check_as_non_final` over char keys.
#[test]
fn generic_as_non_final_char() {
    let node_type = NodeType::CharNode4;
    check_as_non_final::<CharKey>(node_type);
}
/// `check_as_non_final_deep_child` over byte keys.
#[test]
fn generic_as_non_final_deep_child_byte() {
    let node_type = NodeType::Node4;
    check_as_non_final_deep_child::<ByteKey>(node_type);
}
/// `check_as_non_final_deep_child` over char keys.
#[test]
fn generic_as_non_final_deep_child_char() {
    let node_type = NodeType::CharNode4;
    check_as_non_final_deep_child::<CharKey>(node_type);
}
}