use super::*;
use core::sync::atomic;
use core::sync::atomic::Ordering::Acquire;
use core::sync::atomic::Ordering::Relaxed;
use core::sync::atomic::Ordering::Release;
use std::alloc::alloc_zeroed;
use std::alloc::dealloc;
use std::alloc::handle_alloc_error;
use std::alloc::Layout;
use std::ops::Deref;
use std::ops::DerefMut;
use std::ptr::addr_of;
use std::ptr::addr_of_mut;
/// Alignment of every `Branch` allocation; matches `#[repr(C, align(16))]`
/// on `Branch` below.
const BRANCH_ALIGN: usize = 16;
/// Byte size of the fixed `Branch` header that precedes the child table.
/// NOTE(review): assumed to equal the real header layout
/// (rc + end_depth + childleaf + leaf_count + segment_count + hash = 48);
/// there is no compile-time check tying this to `Branch` — consider a
/// `size_of`-based assertion.
const BRANCH_BASE_SIZE: usize = 48;
/// Byte size of one child-table slot (`Option<Head<..>>`); assumed to be
/// pointer-sized via the niche optimization — TODO confirm.
const TABLE_ENTRY_SIZE: usize = 8;
/// Returns the length stored in the metadata of a raw slice pointer.
///
/// Uses `<*const [T]>::len` (stable since Rust 1.79, which this file
/// already requires for its inline `const {}` blocks). Unlike the old
/// cast-to-`*const [()]`-and-dereference trick, this reads the fat-pointer
/// metadata directly, needs no `unsafe`, and is sound even for pointers
/// that are not currently dereferenceable.
#[inline]
pub(crate) fn dst_len<T>(ptr: *const [T]) -> usize {
    ptr.len()
}
/// Non-null pointer to a `Branch` whose child table is the unsized slice
/// `[Option<Head<..>>]`; the table length travels in the fat pointer's
/// metadata (recovered via `dst_len`).
pub(crate) type BranchNN<const KEY_LEN: usize, O, V> =
    NonNull<Branch<KEY_LEN, O, [Option<Head<KEY_LEN, O, V>>], V>>;
/// RAII write handle over a `Head` whose body is a `Branch`.
///
/// Mutations go through the cached `branch_nn` pointer, which operations
/// like table growth may re-point to a new allocation; on drop the current
/// pointer is written back into `head` so the owner observes the change.
pub(crate) struct BranchMut<'a, const KEY_LEN: usize, O: KeySchema<KEY_LEN>, V> {
    // Exclusively borrowed head that owns the branch for the handle's lifetime.
    head: &'a mut Head<KEY_LEN, O, V>,
    // Current branch allocation; may be replaced while the handle is alive.
    branch_nn: BranchNN<KEY_LEN, O, V>,
}
impl<'a, const KEY_LEN: usize, O: KeySchema<KEY_LEN>, V> BranchMut<'a, KEY_LEN, O, V> {
    /// Opens a write handle on `head`'s branch body.
    ///
    /// # Panics
    /// Panics if `head`'s body is a leaf rather than a branch.
    pub(crate) fn from_head(head: &'a mut Head<KEY_LEN, O, V>) -> Self {
        match head.body_mut() {
            BodyMut::Branch(branch_ref) => {
                // SAFETY: the pointer is derived from a live `&mut`, so it
                // cannot be null.
                let nn = unsafe { NonNull::new_unchecked(branch_ref as *mut _) };
                Self {
                    head,
                    branch_nn: nn,
                }
            }
            BodyMut::Leaf(_) => panic!("BranchMut requires a Branch body"),
        }
    }
    /// Convenience constructor: unwraps the slot and delegates to `from_head`.
    ///
    /// # Panics
    /// Panics if `slot` is empty, or if the head's body is a leaf.
    #[allow(dead_code)]
    pub(crate) fn from_slot(slot: &'a mut Option<Head<KEY_LEN, O, V>>) -> Self {
        let head = slot.as_mut().expect("slot should not be empty");
        Self::from_head(head)
    }
    /// Applies `f` to the child slot for `key`; see `Branch::modify_child`.
    /// The cached pointer is threaded through so growth/reallocation is
    /// captured and written back on drop.
    pub fn modify_child<F>(&mut self, key: u8, f: F)
    where
        F: FnOnce(Option<Head<KEY_LEN, O, V>>) -> Option<Head<KEY_LEN, O, V>>,
    {
        Branch::modify_child(&mut self.branch_nn, key, f);
    }
}
impl<'a, const KEY_LEN: usize, O: KeySchema<KEY_LEN>, V> Deref for BranchMut<'a, KEY_LEN, O, V> {
    type Target = Branch<KEY_LEN, O, [Option<Head<KEY_LEN, O, V>>], V>;
    fn deref(&self) -> &Self::Target {
        // SAFETY: `branch_nn` was derived from the exclusively borrowed head
        // in `from_head` and remains valid for the handle's lifetime.
        unsafe { self.branch_nn.as_ref() }
    }
}
impl<'a, const KEY_LEN: usize, O: KeySchema<KEY_LEN>, V> DerefMut for BranchMut<'a, KEY_LEN, O, V> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: exclusive access is guaranteed by holding `&mut self`,
        // which in turn holds the `&'a mut Head` borrow.
        unsafe { self.branch_nn.as_mut() }
    }
}
impl<'a, const KEY_LEN: usize, O: KeySchema<KEY_LEN>, V> Drop for BranchMut<'a, KEY_LEN, O, V> {
    fn drop(&mut self) {
        // Write the (possibly re-allocated) branch pointer back into the
        // head so the owner sees any growth/copy-on-write performed through
        // this handle.
        self.head.set_body(self.branch_nn);
    }
}
/// Interior node of the trie, laid out `repr(C)` as a slice DST: a fixed
/// header followed by a child table of `Option<Head<..>>` slots whose
/// length is a power of two (2..=256, see `Body::tag` and `grow`).
#[derive(Debug)]
#[repr(C, align(16))]
pub(crate) struct Branch<const KEY_LEN: usize, O: KeySchema<KEY_LEN>, Table: ?Sized, V> {
    // Zero-sized markers tying the node to its key ordering and segmentation.
    key_ordering: PhantomData<O>,
    key_segments: PhantomData<O::Segmentation>,
    // Atomic reference count, managed by `rc_inc`/`rc_dec`/`rc_cow`.
    rc: atomic::AtomicU32,
    /// Tree depth at which this node's shared path ends; children are
    /// indexed by the key byte at this depth.
    pub end_depth: u32,
    /// Representative leaf somewhere below this node; its key bytes stand
    /// in for the subtree's shared prefix during lookups.
    pub childleaf: *const Leaf<KEY_LEN, V>,
    /// Total number of leaves in this subtree.
    pub leaf_count: u64,
    /// Cached segment-aware count (see `count_segment`).
    pub segment_count: u64,
    /// XOR of all child hashes (see `new` and `modify_child`).
    pub hash: u128,
    /// Child table; `[Option<Head<..>>]` in the unsized form.
    pub child_table: Table,
}
impl<const KEY_LEN: usize, O: KeySchema<KEY_LEN>, Table: ?Sized, V> Branch<KEY_LEN, O, Table, V> {
    /// Dereferences the representative leaf.
    ///
    /// SAFETY(internal): relies on the invariant that `childleaf` always
    /// points at a live leaf reachable below this branch — maintained by
    /// `new` and `modify_child`.
    pub fn childleaf(&self) -> &Leaf<KEY_LEN, V> {
        unsafe { &*self.childleaf }
    }
    /// Raw pointer to the representative leaf, without dereferencing.
    pub fn childleaf_ptr(&self) -> *const Leaf<KEY_LEN, V> {
        self.childleaf
    }
}
impl<const KEY_LEN: usize, O: KeySchema<KEY_LEN>, V> Body
    for Branch<KEY_LEN, O, [Option<Head<KEY_LEN, O, V>>], V>
{
    /// Derives the head tag from the child-table size: tables are always a
    /// power of two between 2 and 256, so `log2(len)` fits in `1..=8`.
    fn tag(body: NonNull<Self>) -> HeadTag {
        unsafe {
            // `addr_of!` produces a raw pointer to the field without
            // creating a reference to the (possibly shared) branch.
            let ptr = addr_of!((*body.as_ptr()).child_table);
            let exp = dst_len(ptr).ilog2() as u8;
            debug_assert!((1..=8).contains(&exp));
            HeadTag::from_raw(exp)
        }
    }
}
impl<const KEY_LEN: usize, O: KeySchema<KEY_LEN>, V>
Branch<KEY_LEN, O, [Option<Head<KEY_LEN, O, V>>], V>
{
/// Allocates a fresh binary branch (child table of length 2) whose shared
/// path ends at `end_depth`, seeded with the two children that diverge
/// there. `lchild` donates the representative `childleaf`.
pub(super) fn new(
    end_depth: usize,
    lchild: Head<KEY_LEN, O, V>,
    rchild: Head<KEY_LEN, O, V>,
) -> NonNull<Self> {
    unsafe {
        let size = 2;
        // Hand-built layout of the slice DST: fixed header plus `size`
        // table entries (see the NOTE on BRANCH_BASE_SIZE).
        let layout = Layout::from_size_align_unchecked(
            BRANCH_BASE_SIZE + (TABLE_ENTRY_SIZE * size),
            BRANCH_ALIGN,
        );
        // Build a fat pointer carrying the table length as metadata.
        // Zeroed memory leaves every table slot all-zero, which is assumed
        // to be the `None` niche of `Option<Head<..>>` — TODO confirm.
        let Some(ptr) =
            NonNull::new(std::ptr::slice_from_raw_parts(alloc_zeroed(layout), size)
                as *mut Branch<KEY_LEN, O, [Option<Head<KEY_LEN, O, V>>], V>)
        else {
            handle_alloc_error(layout);
        };
        // Field-by-field initialization through raw pointers (no reference
        // to the not-yet-initialized struct is ever created).
        addr_of_mut!((*ptr.as_ptr()).rc).write(atomic::AtomicU32::new(1));
        addr_of_mut!((*ptr.as_ptr()).end_depth).write(end_depth as u32);
        addr_of_mut!((*ptr.as_ptr()).childleaf).write(lchild.childleaf_ptr());
        addr_of_mut!((*ptr.as_ptr()).leaf_count).write(lchild.count() + rchild.count());
        addr_of_mut!((*ptr.as_ptr()).segment_count)
            .write(lchild.count_segment(end_depth) + rchild.count_segment(end_depth));
        // The branch hash is the XOR of its children's hashes.
        addr_of_mut!((*ptr.as_ptr()).hash).write(lchild.hash() ^ rchild.hash());
        (*ptr.as_ptr()).child_table[0] = Some(lchild);
        (*ptr.as_ptr()).child_table[1] = Some(rchild);
        ptr
    }
}
/// Increments the reference count and returns the same pointer.
///
/// A CAS loop is used instead of `fetch_add` so the count can be checked
/// for overflow *before* it wraps; `Relaxed` suffices for increments
/// (no data is published by taking an extra reference).
///
/// # Safety
/// `branch` must point to a live allocation produced by
/// `new`/`rc_cow`/`grow`.
pub(super) unsafe fn rc_inc(branch: NonNull<Self>) -> NonNull<Self> {
    unsafe {
        let branch = branch.as_ptr();
        let mut current = (*branch).rc.load(Relaxed);
        loop {
            if current == u32::MAX {
                panic!("max refcount exceeded");
            }
            match (*branch)
                .rc
                .compare_exchange(current, current + 1, Relaxed, Relaxed)
            {
                Ok(_) => return NonNull::new_unchecked(branch),
                // Lost the race: retry with the freshly observed count.
                Err(v) => current = v,
            }
        }
    }
}
/// Releases one reference; destroys and frees the node when the count
/// reaches zero.
///
/// Mirrors `std::sync::Arc`: the decrement uses `Release`, and the final
/// owner performs an `Acquire` load before destruction so all prior writes
/// from other owners become visible.
///
/// # Safety
/// `branch` must point to a live branch allocation and the caller must own
/// one reference, which this call consumes.
pub(super) unsafe fn rc_dec(branch: NonNull<Self>) {
    unsafe {
        let branch = branch.as_ptr();
        if (*branch).rc.fetch_sub(1, Release) != 1 {
            return;
        }
        // Synchronize with every earlier `Release` decrement.
        (*branch).rc.load(Acquire);
        // Recover the table length so the allocation layout can be rebuilt.
        let size = dst_len(addr_of!((*branch).child_table));
        // Runs the drop glue, releasing all children in the table.
        std::ptr::drop_in_place(branch);
        let layout = Layout::from_size_align_unchecked(
            BRANCH_BASE_SIZE + (TABLE_ENTRY_SIZE * size),
            BRANCH_ALIGN,
        );
        let ptr = branch as *mut u8;
        dealloc(ptr, layout);
    }
}
/// Copy-on-write: ensures `branch_nn` points at a uniquely owned node.
///
/// Returns `None` when the node is already unique (caller may mutate it in
/// place); otherwise clones the node into a fresh allocation, releases the
/// shared original, redirects `branch_nn`, and returns `Some(())`.
///
/// # Safety
/// `branch_nn` must point to a live branch allocation the caller owns a
/// reference to.
pub(super) unsafe fn rc_cow(branch_nn: &mut NonNull<Self>) -> Option<()> {
    unsafe {
        let branch = branch_nn.as_ptr();
        if (*branch).rc.load(Acquire) == 1 {
            // Sole owner; `Acquire` pairs with the `Release` decrements of
            // former co-owners.
            None
        } else {
            let size = dst_len(addr_of!((*branch).child_table));
            let layout = Layout::from_size_align_unchecked(
                BRANCH_BASE_SIZE + (TABLE_ENTRY_SIZE * size),
                BRANCH_ALIGN,
            );
            if let Some(ptr) =
                NonNull::new(std::ptr::slice_from_raw_parts(alloc_zeroed(layout), size)
                    as *mut Branch<KEY_LEN, O, [Option<Head<KEY_LEN, O, V>>], V>)
            {
                // Copy the header field by field into the new allocation.
                addr_of_mut!((*ptr.as_ptr()).rc).write(atomic::AtomicU32::new(1));
                addr_of_mut!((*ptr.as_ptr()).end_depth).write((*branch).end_depth);
                addr_of_mut!((*ptr.as_ptr()).childleaf).write((*branch).childleaf);
                addr_of_mut!((*ptr.as_ptr()).leaf_count).write((*branch).leaf_count);
                addr_of_mut!((*ptr.as_ptr()).segment_count).write((*branch).segment_count);
                addr_of_mut!((*ptr.as_ptr()).hash).write((*branch).hash);
                // Clone every child slot; presumably `Head::clone` bumps the
                // children's refcounts — confirm.
                (*ptr.as_ptr())
                    .child_table
                    .clone_from_slice(&(*branch).child_table);
                // Release our reference to the shared original.
                Self::rc_dec(NonNull::new_unchecked(branch));
                *branch_nn = ptr;
                Some(())
            } else {
                handle_alloc_error(layout);
            }
        }
    }
}
/// Doubles the child table, moving the node into a fresh allocation and
/// redirecting `branch_nn` to it.
///
/// NOTE(review): `table_grow` takes the old table by `&mut` and presumably
/// moves the children into the new table before the old node is released,
/// so callers must hold the only reference (e.g. after `rc_cow`) — confirm.
pub(crate) fn grow(branch_nn: &mut NonNull<Self>) {
    unsafe {
        let branch = branch_nn.as_ptr();
        let old_size = dst_len(addr_of!((*branch).child_table));
        let new_size = old_size * 2;
        // One slot per possible key byte is the hard ceiling.
        assert!(new_size <= 256);
        let layout = Layout::from_size_align_unchecked(
            BRANCH_BASE_SIZE + (TABLE_ENTRY_SIZE * new_size),
            BRANCH_ALIGN,
        );
        if let Some(ptr) = NonNull::new(std::ptr::slice_from_raw_parts(
            alloc_zeroed(layout),
            new_size,
        )
            as *mut Branch<KEY_LEN, O, [Option<Head<KEY_LEN, O, V>>], V>)
        {
            // Copy the header into the larger allocation.
            addr_of_mut!((*ptr.as_ptr()).rc).write(atomic::AtomicU32::new(1));
            addr_of_mut!((*ptr.as_ptr()).end_depth).write((*branch).end_depth);
            addr_of_mut!((*ptr.as_ptr()).leaf_count).write((*branch).leaf_count);
            addr_of_mut!((*ptr.as_ptr()).segment_count).write((*branch).segment_count);
            addr_of_mut!((*ptr.as_ptr()).childleaf).write((*branch).childleaf);
            addr_of_mut!((*ptr.as_ptr()).hash).write((*branch).hash);
            // Redistribute the children into the doubled table.
            (*branch)
                .child_table
                .table_grow(&mut (*ptr.as_ptr()).child_table);
            // Release the old allocation (children already moved out).
            Branch::<KEY_LEN, O, [Option<Head<KEY_LEN, O, V>>], V>::rc_dec(
                NonNull::new_unchecked(branch),
            );
            *branch_nn = ptr;
        } else {
            handle_alloc_error(layout);
        }
    }
}
/// Applies `f` to the child slot for `key`, handling insert, replace, and
/// removal while keeping the cached aggregates (`hash`, `leaf_count`,
/// `segment_count`, `childleaf`) consistent.
///
/// `f` receives the current child (if any) and returns the child to store
/// (if any). On insert the table is doubled via `grow` until the entry
/// fits, which may re-point `branch_nn`.
pub(super) fn modify_child<F>(branch_nn: &mut NonNull<Self>, key: u8, f: F)
where
    F: FnOnce(Option<Head<KEY_LEN, O, V>>) -> Option<Head<KEY_LEN, O, V>>,
{
    unsafe {
        let branch = branch_nn.as_ptr();
        let end_depth = (*branch).end_depth as usize;
        if let Some(slot) = (*branch).child_table.table_get_slot(key) {
            // Existing child: take it out and record its contribution so it
            // can be subtracted from the cached aggregates below.
            let child = slot.take().unwrap();
            let old_child_hash = child.hash();
            let old_child_segment_count = child.count_segment(end_depth);
            let old_child_leaf_count = child.count();
            let replaced_childleaf = child.childleaf_ptr() == (*branch).childleaf;
            if let Some(new_child) = f(Some(child)) {
                // Replacement: XOR out the old hash, XOR in the new one, and
                // re-base the counters on the new child.
                (*branch).hash = ((*branch).hash ^ old_child_hash) ^ new_child.hash();
                (*branch).segment_count = ((*branch).segment_count - old_child_segment_count)
                    + new_child.count_segment(end_depth);
                (*branch).leaf_count =
                    ((*branch).leaf_count - old_child_leaf_count) + new_child.count();
                if replaced_childleaf {
                    // Keep the representative leaf pointer valid.
                    (*branch).childleaf = new_child.childleaf_ptr();
                }
                // The slot was emptied by `take` above, so `replace` must
                // return `None`.
                if slot.replace(new_child.with_key(key)).is_some() {
                    unreachable!();
                }
            } else {
                // Removal: subtract the old child's contribution.
                (*branch).hash ^= old_child_hash;
                (*branch).segment_count -= old_child_segment_count;
                (*branch).leaf_count -= old_child_leaf_count;
                if replaced_childleaf {
                    // The representative leaf lived under the removed child;
                    // fall back to any remaining child's childleaf.
                    if let Some(other) = (*branch).child_table.iter().find_map(|s| s.as_ref()) {
                        (*branch).childleaf = other.childleaf_ptr();
                    }
                }
            }
        } else {
            // No child for `key` yet: `f(None)` may create one.
            if let Some(mut inserted) = f(None) {
                (*branch).leaf_count += inserted.count();
                (*branch).segment_count += inserted.count_segment(end_depth);
                (*branch).hash ^= inserted.hash();
                // Insert, doubling the table until the entry fits. `grow`
                // replaces the allocation, so re-read the raw pointer after
                // every round.
                let mut branch_ptr = branch_nn.as_ptr();
                while let Some(new_displaced) = (*branch_ptr).child_table.table_insert(inserted)
                {
                    inserted = new_displaced;
                    Self::grow(branch_nn);
                    branch_ptr = branch_nn.as_ptr();
                }
            }
        }
        #[cfg(debug_assertions)]
        branch_nn.as_ref().debug_check_invariants();
    }
}
/// Segment-aware size of this subtree as seen from `at_depth`.
///
/// If this node's path crosses into a different key segment than the
/// caller's, the whole subtree collapses to a single unit; otherwise the
/// cached per-segment total is reported.
pub fn count_segment(&self, at_depth: usize) -> u64 {
    let end = self.end_depth as usize;
    if O::same_segment_tree(at_depth, end) {
        self.segment_count
    } else {
        1
    }
}
/// Debug-only consistency check: recomputes every cached aggregate from
/// the live children and asserts it matches the stored value.
#[cfg(debug_assertions)]
pub fn debug_check_invariants(&self) {
    let end_depth = self.end_depth as usize;
    let mut leaves: u64 = 0;
    let mut segments: u64 = 0;
    let mut hash_acc: u128 = 0;
    let mut childleaf_seen = false;
    for slot in self.child_table.iter() {
        let Some(child) = slot.as_ref() else { continue };
        // Saturating adds keep the check itself from panicking on overflow.
        leaves = leaves.saturating_add(child.count());
        segments = segments.saturating_add(child.count_segment(end_depth));
        hash_acc ^= child.hash();
        childleaf_seen |= child.childleaf_ptr() == self.childleaf;
    }
    debug_assert_eq!(
        leaves, self.leaf_count,
        "branch.leaf_count mismatch"
    );
    debug_assert_eq!(
        segments, self.segment_count,
        "branch.segment_count mismatch"
    );
    debug_assert_eq!(hash_acc, self.hash, "branch.hash mismatch");
    if leaves > 0 {
        // The representative childleaf must live under one of the children.
        debug_assert!(childleaf_seen, "branch.childleaf pointer mismatch");
    }
}
/// Invokes `f` with each distinct `INFIX_LEN`-byte infix that follows
/// `prefix` within this subtree. `at_depth` is the depth already verified
/// by the caller.
pub fn infixes<const PREFIX_LEN: usize, const INFIX_LEN: usize, F>(
    &self,
    prefix: &[u8; PREFIX_LEN],
    at_depth: usize,
    f: &mut F,
) where
    F: FnMut(&[u8; INFIX_LEN]),
{
    let node_end_depth = self.end_depth as usize;
    let limit = std::cmp::min(PREFIX_LEN, node_end_depth);
    // All keys below this node agree with the childleaf up to
    // `node_end_depth`, so one prefix check covers the whole subtree.
    if !self.childleaf().has_prefix::<O>(at_depth, &prefix[..limit]) {
        return;
    }
    if PREFIX_LEN + INFIX_LEN <= node_end_depth {
        // The entire infix lies on this node's shared path: every leaf
        // below yields the same infix, so report it exactly once.
        let infix: [u8; INFIX_LEN] =
            core::array::from_fn(|i| self.childleaf().key[O::TREE_TO_KEY[PREFIX_LEN + i]]);
        f(&infix);
        return;
    }
    if PREFIX_LEN > node_end_depth {
        // Still resolving the prefix: at most one child can match it.
        if let Some(child) = self.child_table.table_get(prefix[node_end_depth]) {
            child.infixes(prefix, node_end_depth, f);
        }
        return;
    }
    // The infix extends past this node: gather from every child.
    for entry in self.child_table.iter().flatten() {
        entry.infixes(prefix, node_end_depth, f);
    }
}
/// Like `infixes`, but only reports infixes within the inclusive range
/// `[min_infix, max_infix]` (bytewise lexicographic order).
pub fn infixes_range<const PREFIX_LEN: usize, const INFIX_LEN: usize, F>(
    &self,
    prefix: &[u8; PREFIX_LEN],
    at_depth: usize,
    min_infix: &[u8; INFIX_LEN],
    max_infix: &[u8; INFIX_LEN],
    f: &mut F,
) where
    F: FnMut(&[u8; INFIX_LEN]),
{
    let end_depth = self.end_depth as usize;
    let checked = std::cmp::min(PREFIX_LEN, end_depth);
    // One prefix comparison against the representative childleaf covers
    // the whole subtree.
    if !self.childleaf().has_prefix::<O>(at_depth, &prefix[..checked]) {
        return;
    }
    if PREFIX_LEN + INFIX_LEN <= end_depth {
        // The whole infix lies on this node's shared path: emit it once if
        // it falls inside the range.
        let infix: [u8; INFIX_LEN] =
            core::array::from_fn(|i| self.childleaf().key[O::TREE_TO_KEY[PREFIX_LEN + i]]);
        if min_infix <= &infix && &infix <= max_infix {
            f(&infix);
        }
        return;
    }
    if PREFIX_LEN > end_depth {
        // Still inside the prefix: only one child can possibly match.
        if let Some(child) = self.child_table.table_get(prefix[end_depth]) {
            child.infixes_range(prefix, end_depth, min_infix, max_infix, f);
        }
        return;
    }
    // The first `covered` infix bytes lie on this node's shared path.
    // Track whether the path still sits exactly on the lower/upper bound;
    // once it strictly exceeds a bound in-range, that bound no longer
    // constrains deeper bytes.
    let covered = end_depth - PREFIX_LEN;
    let mut on_lower = true;
    let mut on_upper = true;
    for i in 0..covered {
        let byte = self.childleaf().key[O::TREE_TO_KEY[PREFIX_LEN + i]];
        if on_lower {
            if byte < min_infix[i] {
                return;
            }
            on_lower = byte == min_infix[i];
        }
        if on_upper {
            if byte > max_infix[i] {
                return;
            }
            on_upper = byte == max_infix[i];
        }
    }
    // Prune children whose branch byte falls outside a still-tight bound.
    for child in self.child_table.iter().flatten() {
        let b = child.key();
        if on_lower && covered < INFIX_LEN && b < min_infix[covered] {
            continue;
        }
        if on_upper && covered < INFIX_LEN && b > max_infix[covered] {
            continue;
        }
        child.infixes_range(prefix, end_depth, min_infix, max_infix, f);
    }
}
/// Counts the leaves whose infix (the `INFIX_LEN` bytes after `prefix`)
/// falls within the inclusive range `[min_infix, max_infix]`.
///
/// Uses the cached counts to avoid descending into subtrees that are
/// entirely inside (or outside) the range.
pub fn count_range<const PREFIX_LEN: usize, const INFIX_LEN: usize>(
    &self,
    prefix: &[u8; PREFIX_LEN],
    at_depth: usize,
    min_infix: &[u8; INFIX_LEN],
    max_infix: &[u8; INFIX_LEN],
) -> u64 {
    let node_end_depth = self.end_depth as usize;
    let limit = std::cmp::min(PREFIX_LEN, node_end_depth);
    // One prefix check against the representative childleaf covers the
    // whole subtree.
    if !self.childleaf().has_prefix::<O>(at_depth, &prefix[..limit]) {
        return 0;
    }
    if PREFIX_LEN + INFIX_LEN <= node_end_depth {
        // The whole infix is on this node's shared path: either every leaf
        // below counts, or none does.
        let infix: [u8; INFIX_LEN] =
            core::array::from_fn(|i| self.childleaf().key[O::TREE_TO_KEY[PREFIX_LEN + i]]);
        return if &infix >= min_infix && &infix <= max_infix {
            self.leaf_count
        } else {
            0
        };
    }
    if PREFIX_LEN > node_end_depth {
        // Still resolving the prefix: at most one child can match.
        if let Some(child) = self.child_table.table_get(prefix[node_end_depth]) {
            return child.count_range(prefix, node_end_depth, min_infix, max_infix);
        }
        return 0;
    }
    // The first `infix_byte_idx` infix bytes lie on this node's shared
    // path; track whether the path still sits exactly on each bound.
    let infix_byte_idx = node_end_depth - PREFIX_LEN;
    let mut min_tight = true;
    let mut max_tight = true;
    for i in 0..infix_byte_idx {
        let path_byte = self.childleaf().key[O::TREE_TO_KEY[PREFIX_LEN + i]];
        if min_tight {
            if path_byte < min_infix[i] {
                return 0;
            }
            if path_byte > min_infix[i] {
                min_tight = false;
            }
        }
        if max_tight {
            if path_byte > max_infix[i] {
                return 0;
            }
            if path_byte < max_infix[i] {
                max_tight = false;
            }
        }
    }
    let mut total = 0u64;
    for entry in self.child_table.iter().flatten() {
        let child_byte = entry.key();
        // Skip children that fall outside a still-tight bound.
        let below_min = min_tight && child_byte < min_infix[infix_byte_idx];
        let above_max = max_tight && child_byte > max_infix[infix_byte_idx];
        if below_min || above_max {
            continue;
        }
        // A child sitting exactly on a tight bound may be partially in
        // range and must be counted recursively; strictly interior
        // children contribute their full cached count.
        let on_min = min_tight && child_byte == min_infix[infix_byte_idx];
        let on_max = max_tight && child_byte == max_infix[infix_byte_idx];
        if on_min || on_max {
            total += entry.count_range(prefix, node_end_depth, min_infix, max_infix);
        } else {
            total += entry.count();
        }
    }
    total
}
/// Reports whether any key in this subtree starts with `prefix`
/// (tree order), `at_depth` bytes of which were verified by the caller.
pub fn has_prefix<const PREFIX_LEN: usize>(
    &self,
    at_depth: usize,
    prefix: &[u8; PREFIX_LEN],
) -> bool {
    const {
        assert!(PREFIX_LEN <= KEY_LEN);
    }
    let end_depth = self.end_depth as usize;
    let checked = std::cmp::min(PREFIX_LEN, end_depth);
    // The shared path is validated against the representative childleaf.
    if !self.childleaf().has_prefix::<O>(at_depth, &prefix[..checked]) {
        return false;
    }
    if PREFIX_LEN <= end_depth {
        // The whole prefix lies on this node's path.
        return true;
    }
    // Otherwise exactly one child can continue the prefix.
    match self.child_table.table_get(prefix[end_depth]) {
        Some(child) => child.has_prefix::<PREFIX_LEN>(end_depth, prefix),
        None => false,
    }
}
/// Looks up the value stored under `key`, or `None` if the key is absent.
pub fn get<'a>(&'a self, at_depth: usize, key: &[u8; KEY_LEN]) -> Option<&'a V>
where
    O: 'a,
{
    let end_depth = self.end_depth as usize;
    let checked = std::cmp::min(KEY_LEN, end_depth);
    // Validate the shared path against the representative childleaf.
    if !self.childleaf().has_prefix::<O>(at_depth, &key[..checked]) {
        return None;
    }
    if end_depth >= KEY_LEN {
        // The node's path spans the full key, so the childleaf is the match.
        return Some(&self.childleaf().value);
    }
    // Descend into the single child selected by the next key byte.
    self.child_table
        .table_get(key[end_depth])
        .and_then(|child| child.get(end_depth, key))
}
/// Number of distinct continuations of `prefix` within the segment that
/// follows it (collapses whole segments to 1, mirroring `count_segment`).
pub fn segmented_len<const PREFIX_LEN: usize>(
    &self,
    at_depth: usize,
    prefix: &[u8; PREFIX_LEN],
) -> u64 {
    let end_depth = self.end_depth as usize;
    let checked = std::cmp::min(PREFIX_LEN, end_depth);
    // Validate the shared path against the representative childleaf.
    if !self.childleaf().has_prefix::<O>(at_depth, &prefix[..checked]) {
        return 0;
    }
    if PREFIX_LEN <= end_depth {
        // The prefix ends on this node's path: answer from cached state.
        return if O::same_segment_tree(PREFIX_LEN, end_depth) {
            self.segment_count
        } else {
            1
        };
    }
    // Otherwise only the child selected by the next prefix byte matters.
    match self.child_table.table_get(prefix[end_depth]) {
        Some(child) => child.segmented_len::<PREFIX_LEN>(end_depth, prefix),
        None => 0,
    }
}
}