use num_traits::Zero;
use crate::{
allocator::{AllocIdx, CellAllocator, Loc, NodeAllocator, RawPtr},
node::{
child_bit, child_cover_mask, data_bit, data_cover_mask, extend_repr, Key, MaskedLexIter,
MultiBitNode, DATA_BIT_TO_PREFIX,
},
prefix::mask_from_prefix_len,
Prefix,
};
/// Number of key bits consumed per trie level (the stride).
pub(crate) const K: u32 = 5;
/// Number of data slots per node: one per prefix that ends inside a stride
/// (lengths `0..K`, i.e. `2^0 + 2^1 + ... + 2^(K-1)`).
pub(crate) const NUM_DATA: usize = NUM_CHILDREN - 1;
/// Number of child slots per node: one per `K`-bit stride value.
pub(crate) const NUM_CHILDREN: usize = 1 << K;
/// Address of one data slot: the node owning it, the slot's data bit in that node, and the
/// depth of that node.
///
/// The cell position is not cached. Every `resolve*` call recomputes it from the node's
/// current data bitmap.
#[derive(Debug, Clone, Copy)]
pub(crate) struct DataIdx {
    /// Location of the node that owns the data slot.
    pub(crate) node: Loc,
    /// Data bit of the slot inside that node.
    pub(crate) bit: u32,
    /// Depth (in key bits) of the owning node.
    pub(crate) depth: u32,
}
impl DataIdx {
    /// Returns a shared handle to the slot's value, or `None` if the slot is currently empty.
    ///
    /// # Safety
    /// Presumably `self.node` must still address a live node of `table` — confirm against
    /// the allocator's contract.
    #[inline]
    pub(crate) unsafe fn resolve<'a, T>(self, table: &'a Table<T>) -> Option<Present<'a, T>> {
        let owner = table.node(self.node);
        owner.has_data_bit(self.bit).then(|| Present {
            table,
            data: Loc::new(owner.data_idx, self.bit, owner.data_bitmap),
            depth: self.depth,
        })
    }
    /// Returns a mutable handle to the slot's value, or `None` if the slot is currently empty.
    ///
    /// # Safety
    /// Same requirements as [`DataIdx::resolve`].
    #[inline]
    pub(crate) unsafe fn resolve_mut<'a, T>(
        self,
        table: &'a mut Table<T>,
    ) -> Option<PresentMut<'a, T>> {
        let owner = table.node(self.node);
        let data = owner
            .has_data_bit(self.bit)
            .then(|| Loc::new(owner.data_idx, self.bit, owner.data_bitmap))?;
        Some(PresentMut {
            table,
            node: self.node,
            data,
            depth: self.depth,
        })
    }
}
/// Shared handle to an occupied data slot of a [`Table`].
pub(crate) struct Present<'a, T> {
    /// Table the slot belongs to.
    table: &'a Table<T>,
    /// Cell location of the stored value.
    data: Loc,
    /// Depth of the node that owns the slot.
    depth: u32,
}
impl<'a, T> Present<'a, T> {
    /// Returns the stored value, borrowed for the table's lifetime.
    #[inline]
    pub(crate) fn get(self) -> &'a T {
        let Present { table, data, .. } = self;
        unsafe { table.cells.get(data) }
    }
    /// Returns the stored value mutably through a raw cell pointer (see `Table::raw_cells`).
    ///
    /// # Safety
    /// The caller must ensure no other reference to this value is alive while the returned
    /// one is used.
    #[inline]
    // The mutable reference comes from `ptr`, not from the shared `table` borrow.
    #[allow(clippy::mut_from_ref)]
    pub(crate) unsafe fn unsafe_get_mut(self, ptr: &mut RawPtr<T>) -> &'a mut T {
        let Present { table, data, .. } = self;
        unsafe { table.cells.unsafe_get_mut(ptr, data) }
    }
    /// Rebuilds the prefix stored in this slot.
    ///
    /// Only the leading `depth` bits of `key` are used (the path to the owning node).
    pub(crate) fn prefix<P: Prefix>(&self, key: P::R) -> P {
        let data_offset = self.data.bit as usize;
        prefix(key, self.depth, data_offset)
    }
}
/// Mutable handle to an occupied data slot of a [`Table`].
pub(crate) struct PresentMut<'a, T> {
    /// Table the slot belongs to.
    table: &'a mut Table<T>,
    /// Location of the node that owns the slot (needed to clear its data bit on `take`).
    node: Loc,
    /// Cell location of the stored value.
    data: Loc,
    /// Depth of the owning node.
    depth: u32,
}
impl<'a, T> PresentMut<'a, T> {
    /// Consumes the handle and returns the value, borrowed mutably for the table's lifetime.
    #[inline]
    pub(crate) fn get_mut(self) -> &'a mut T {
        let PresentMut { table, data, .. } = self;
        unsafe { table.cells.get_mut(data) }
    }
    /// Borrows the stored value.
    #[inline]
    pub(crate) fn get(&self) -> &T {
        let cells = &self.table.cells;
        unsafe { cells.get(self.data) }
    }
    /// Borrows the stored value mutably without consuming the handle.
    #[inline]
    pub(crate) fn as_mut(&mut self) -> &mut T {
        let cells = &mut self.table.cells;
        unsafe { cells.get_mut(self.data) }
    }
    /// Overwrites the stored value with `val` and returns the previous one.
    #[inline]
    pub(crate) fn replace(self, val: T) -> T {
        let PresentMut { table, data, .. } = self;
        unsafe { table.cells.replace(data, val) }
    }
    /// Removes the value from the table (clearing the slot in its node) and returns it.
    #[inline]
    pub(crate) fn take(self) -> T {
        let PresentMut {
            table, node, data, ..
        } = self;
        let owner = &mut table.nodes[node];
        table.cells.remove_bit(owner, data)
    }
    /// Rebuilds the prefix stored in this slot; only the leading `depth` bits of `key` are used.
    pub(crate) fn prefix<P: Prefix>(&self, key: P::R) -> P {
        let data_offset = self.data.bit as usize;
        prefix(key, self.depth, data_offset)
    }
}
/// Mutable handle to an empty data slot in an existing node.
pub(crate) struct EmptyMut<'a, T> {
    /// Table being modified.
    pub(crate) table: &'a mut Table<T>,
    /// Node in which the slot lives.
    pub(crate) node: Loc,
    /// Data bit of the (currently unset) slot.
    pub(crate) data_bit: u32,
    /// Depth of `node`.
    pub(crate) depth: u32,
}
impl<'a, T> EmptyMut<'a, T> {
    /// Stores `val` in the slot and returns a handle to the now-occupied slot.
    #[inline]
    pub(crate) fn insert(self, val: T) -> PresentMut<'a, T> {
        let EmptyMut {
            table,
            node,
            data_bit,
            depth,
        } = self;
        let data = {
            let owner = &mut table.nodes[node];
            table.cells.insert_new_bit(owner, data_bit, val)
        };
        PresentMut {
            table,
            node,
            data,
            depth,
        }
    }
}
/// Mutable cursor at the deepest existing node on a key's path, used when the path
/// continues into nodes that have not been created yet.
pub(crate) struct NoNodeMut<'a, T> {
    /// Table being modified.
    pub(crate) table: &'a mut Table<T>,
    /// Deepest node on the path that exists.
    pub(crate) last_node: Loc,
    /// Depth of `last_node`.
    pub(crate) last_depth: u32,
}
impl<'a, T> NoNodeMut<'a, T> {
    /// Takes one step toward the node that should hold (`key`, `prefix_len`).
    ///
    /// If the prefix ends within the current node's stride, returns `Ok` with an
    /// [`EmptyMut`] for the target data slot. Otherwise it creates the next child on the
    /// path and returns `Err` with a cursor positioned at that child.
    #[inline(always)]
    pub(crate) fn advance<R: Key>(self, key: R, prefix_len: u32) -> Result<EmptyMut<'a, T>, Self> {
        let next_depth = self.last_depth + K;
        if prefix_len >= next_depth {
            let bit = child_bit(self.last_depth, key);
            let child = self.table.nodes.insert_new_bit(self.last_node, bit);
            return Err(NoNodeMut {
                table: self.table,
                last_node: child,
                last_depth: next_depth,
            });
        }
        Ok(EmptyMut {
            data_bit: data_bit(key, prefix_len),
            table: self.table,
            node: self.last_node,
            depth: self.last_depth,
        })
    }
    /// Creates every missing node on the path of (`key`, `prefix_len`), stores `val` in the
    /// final data slot and returns a handle to it.
    pub(crate) fn insert_path_and_data<R: Key>(
        self,
        key: R,
        prefix_len: u32,
        val: T,
    ) -> PresentMut<'a, T> {
        let mut cursor = self;
        loop {
            cursor = match cursor.advance(key, prefix_len) {
                Ok(slot) => return slot.insert(val),
                Err(deeper) => deeper,
            };
        }
    }
}
pub(crate) enum Location<'a, T> {
Present(PresentMut<'a, T>),
Empty(EmptyMut<'a, T>),
NoNode(NoNodeMut<'a, T>),
}
impl<'a, T> Location<'a, T> {
#[inline]
pub(crate) fn present(self) -> Option<PresentMut<'a, T>> {
match self {
Location::Present(p) => Some(p),
_ => None,
}
}
#[inline]
pub(crate) fn node_loc(&self) -> Loc {
match self {
Location::Present(p) => p.node,
Location::Empty(e) => e.node,
Location::NoNode(n) => n.last_node,
}
}
#[inline]
pub(crate) fn depth(&self) -> u32 {
match self {
Location::Present(p) => p.depth,
Location::Empty(e) => e.depth,
Location::NoNode(n) => n.last_depth,
}
}
}
pub(crate) struct Table<T> {
nodes: NodeAllocator,
cells: CellAllocator<T>,
}
impl<T> Default for Table<T> {
fn default() -> Self {
Self {
nodes: Default::default(),
cells: Default::default(),
}
}
}
impl<T> Drop for Table<T> {
fn drop(&mut self) {
self.drop_values();
}
}
// SAFETY(review): `Table<T>` owns its `T` values (they are dropped by `drop_values` when the
// table is dropped), so it is `Send` only when `T: Send` and `Sync` only when `T: Sync`.
// These manual impls are presumably needed because the allocators keep raw, pointer-based
// storage that is not automatically `Send`/`Sync` — confirm against `crate::allocator`.
unsafe impl<T: Send> Send for Table<T> {}
unsafe impl<T: Sync> Sync for Table<T> {}
impl<T> Table<T> {
/// Raw pointer into the value storage, presumably for use with `Present::unsafe_get_mut`.
pub(crate) fn raw_cells(&mut self) -> RawPtr<T> {
    let cells = &mut self.cells;
    cells.raw_ptr()
}
/// The node stored at `pos`.
#[inline(always)]
pub(crate) fn node(&self, pos: Loc) -> &MultiBitNode {
    let nodes = &self.nodes;
    &nodes[pos]
}
/// Location of the child of the node at `pos` selected by `child_bit`, if that child exists.
///
/// # Safety
/// Presumably `pos` must address a live node of this table — confirm.
#[inline(always)]
pub(crate) unsafe fn child(&self, pos: Loc, child_bit: u32) -> Option<Loc> {
    let parent = self.node(pos);
    parent
        .has_child_bit(child_bit)
        .then(|| Loc::new(parent.children_idx, child_bit, parent.child_bitmap))
}
/// Unlinks the child `child_bit` from the node at `parent_loc`.
///
/// # Safety
/// Presumably the child must exist and its subtree must already be empty or released —
/// confirm against `NodeAllocator::remove_bit`.
#[inline(always)]
pub(crate) unsafe fn remove_child_at(&mut self, parent_loc: Loc, child_bit: u32) {
    let nodes = &mut self.nodes;
    nodes.remove_bit(parent_loc, child_bit);
}
/// Sum of the `mem_size` figures of the node and cell allocators.
pub(crate) fn mem_size(&self) -> usize {
    let node_mem = self.nodes.mem_size();
    let cell_mem = self.cells.mem_size();
    node_mem + cell_mem
}
/// Drops every value stored in the trie by moving each one out with `remove_raw`.
///
/// Cell blocks are not freed here; only the values are dropped.
fn drop_values(&mut self) {
    let mut pending = vec![Loc::root()];
    while let Some(cur) = pending.pop() {
        let snapshot = *self.node(cur);
        pending.extend(snapshot.child_locs());
        let cells = &mut self.cells;
        snapshot.data_locs().for_each(|slot| {
            // The moved-out value is dropped immediately.
            let _ = unsafe { cells.remove_raw(slot) };
        });
    }
}
/// Descends from the root toward the node whose stride contains `prefix_len`.
///
/// Returns that node and its depth, or `None` if a node on the path is missing.
#[inline(always)]
fn find_loc<R: Key>(&self, key: R, prefix_len: u32) -> Option<(Loc, u32)> {
    let mut cursor = (Loc::root(), 0u32);
    while prefix_len >= cursor.1 + K {
        let (loc, depth) = cursor;
        let next = unsafe { self.child(loc, child_bit(depth, key)) }?;
        cursor = (next, depth + K);
    }
    Some(cursor)
}
/// Exact lookup of (`key`, `prefix_len`).
///
/// Returns `None` if the owning node or its data slot is missing.
#[inline(always)]
pub(crate) fn find<R: Key>(&self, key: R, prefix_len: u32) -> Option<Present<'_, T>> {
    let (loc, depth) = self.find_loc(key, prefix_len)?;
    let node = self.node(loc);
    let db = data_bit(key, prefix_len);
    node.has_data_bit(db).then(|| Present {
        table: self,
        data: Loc::new(node.data_idx, db, node.data_bitmap),
        depth,
    })
}
/// Longest-prefix match for (`key`, `prefix_len`).
///
/// Each node on the path is asked for its best slot via `data_lpm_loc`. The deepest hit
/// wins.
#[inline(always)]
pub(crate) fn find_lpm<R: Key>(&self, key: R, prefix_len: u32) -> Option<Present<'_, T>> {
    let mut best: Option<Present<'_, T>> = None;
    let mut loc = Loc::root();
    let mut depth = 0;
    loop {
        if let Some(data) = self.node(loc).data_lpm_loc(depth, key, prefix_len) {
            best = Some(Present {
                table: self,
                data,
                depth,
            });
        }
        let next = if prefix_len >= depth + K {
            unsafe { self.child(loc, child_bit(depth, key)) }
        } else {
            None
        };
        let Some(child) = next else {
            return best;
        };
        loc = child;
        depth += K;
    }
}
/// Mutable longest-prefix match for (`key`, `prefix_len`); see [`Table::find_lpm`].
#[inline(always)]
pub(crate) fn find_lpm_mut<R: Key>(
    &mut self,
    key: R,
    prefix_len: u32,
) -> Option<PresentMut<'_, T>> {
    // Best hit so far as (owning node, data bit, node depth). Only plain values are kept, so
    // that `self` is free to be borrowed mutably once the scan is done.
    let mut best: Option<(Loc, u32, u32)> = None;
    let mut loc = Loc::root();
    let mut depth = 0;
    loop {
        if let Some(hit) = self.node(loc).data_lpm_loc(depth, key, prefix_len) {
            best = Some((loc, hit.bit, depth));
        }
        let next = if prefix_len >= depth + K {
            unsafe { self.child(loc, child_bit(depth, key)) }
        } else {
            None
        };
        match next {
            Some(child) => {
                loc = child;
                depth += K;
            }
            None => break,
        }
    }
    let (node, bit, depth) = best?;
    let data = {
        let owner = self.node(node);
        Loc::new(owner.data_idx, bit, owner.data_bitmap)
    };
    Some(PresentMut {
        table: self,
        node,
        data,
        depth,
    })
}
/// Shortest-prefix match for (`key`, `prefix_len`).
///
/// Returns the first slot reported by `data_spm_loc` while descending from the root.
#[inline(always)]
pub(crate) fn find_spm<R: Key>(&self, key: R, prefix_len: u32) -> Option<Present<'_, T>> {
    let mut loc = Loc::root();
    let mut depth = 0;
    loop {
        let node = self.node(loc);
        match node.data_spm_loc(depth, key, prefix_len) {
            Some(data) => {
                return Some(Present {
                    table: self,
                    data,
                    depth,
                })
            }
            None if prefix_len < depth + K => return None,
            None => {}
        }
        loc = unsafe { self.child(loc, child_bit(depth, key)) }?;
        depth += K;
    }
}
/// Mutable lookup of (`key`, `prefix_len`) that reports where the search ended.
///
/// The result is an occupied slot, an empty slot, or the deepest existing node when the
/// path is incomplete. No nodes are created.
#[inline(always)]
pub(crate) fn find_mut<R: Key>(&mut self, key: R, prefix_len: u32) -> Location<'_, T> {
    let mut loc = Loc::root();
    let mut depth = 0;
    while prefix_len >= depth + K {
        match unsafe { self.child(loc, child_bit(depth, key)) } {
            Some(next) => {
                loc = next;
                depth += K;
            }
            None => {
                return Location::NoNode(NoNodeMut {
                    table: self,
                    last_node: loc,
                    last_depth: depth,
                });
            }
        }
    }
    let db = data_bit(key, prefix_len);
    let present_data = {
        let node = self.node(loc);
        node.has_data_bit(db)
            .then(|| Loc::new(node.data_idx, db, node.data_bitmap))
    };
    match present_data {
        Some(data) => Location::Present(PresentMut {
            table: self,
            node: loc,
            data,
            depth,
        }),
        None => Location::Empty(EmptyMut {
            table: self,
            node: loc,
            data_bit: db,
            depth,
        }),
    }
}
/// Like [`Table::find_mut`], but creates any missing nodes on the path.
///
/// The result is therefore always either an occupied (`Ok`) or an empty (`Err`) slot.
#[inline(always)]
pub(crate) fn find_or_insert_mut<R: Key>(
    &mut self,
    key: R,
    prefix_len: u32,
) -> Result<PresentMut<'_, T>, EmptyMut<'_, T>> {
    let mut loc = Loc::root();
    let mut depth = 0;
    while prefix_len >= depth + K {
        let cb = child_bit(depth, key);
        let existing = unsafe { self.child(loc, cb) };
        loc = existing.unwrap_or_else(|| self.nodes.insert_new_bit(loc, cb));
        depth += K;
    }
    let db = data_bit(key, prefix_len);
    let present_data = {
        let node = self.node(loc);
        node.has_data_bit(db)
            .then(|| Loc::new(node.data_idx, db, node.data_bitmap))
    };
    match present_data {
        Some(data) => Ok(PresentMut {
            table: self,
            node: loc,
            data,
            depth,
        }),
        None => Err(EmptyMut {
            table: self,
            node: loc,
            data_bit: db,
            depth,
        }),
    }
}
/// Mutable lookup of (`key`, `prefix_len`) that also returns the path it took.
///
/// The path is the (parent node, child bit) pairs from the root down to the owning node,
/// suitable for `cleanup_tree`. Returns `None` if a node on the path is missing.
#[inline(always)]
#[allow(clippy::type_complexity)]
pub(crate) fn find_mut_with_path<R: Key>(
    &mut self,
    key: R,
    prefix_len: u32,
) -> Option<(Location<'_, T>, Vec<(Loc, u32)>)> {
    let mut path: Vec<(Loc, u32)> = Vec::new();
    let mut loc = Loc::root();
    let mut depth = 0;
    while prefix_len >= depth + K {
        let bit = child_bit(depth, key);
        let next = unsafe { self.child(loc, bit) }?;
        path.push((loc, bit));
        loc = next;
        depth += K;
    }
    let db = data_bit(key, prefix_len);
    let present_data = {
        let node = self.node(loc);
        node.has_data_bit(db)
            .then(|| Loc::new(node.data_idx, db, node.data_bitmap))
    };
    let location = match present_data {
        Some(data) => Location::Present(PresentMut {
            table: self,
            node: loc,
            data,
            depth,
        }),
        None => Location::Empty(EmptyMut {
            table: self,
            node: loc,
            data_bit: db,
            depth,
        }),
    };
    Some((location, path))
}
/// Data slots of the node at `loc` (at `depth`) selected by `MultiBitNode::data_cover_locs`
/// for the query (`key`, `prefix_len`), as [`DataIdx`]s.
///
/// # Safety
/// Presumably `loc` must address a live node of this table — confirm.
pub(crate) unsafe fn data_descendants<R: Key>(
    &self,
    loc: Loc,
    depth: u32,
    key: R,
    prefix_len: u32,
) -> impl DoubleEndedIterator<Item = DataIdx> + 'static {
    let covered = self.node(loc).data_cover_locs(depth, key, prefix_len);
    covered.map(move |slot| DataIdx {
        node: loc,
        bit: slot.bit,
        depth,
    })
}
/// Data slots of the node at `loc` (at `depth`) selected by `MultiBitNode::data_lpm_locs`.
///
/// These are presumably the stored prefixes of (`key`, `prefix_len`) inside this node.
///
/// # Safety
/// Presumably `loc` must address a live node of this table — confirm.
pub(crate) unsafe fn data_ancestors<R: Key>(
    &self,
    loc: Loc,
    depth: u32,
    key: R,
    prefix_len: u32,
) -> impl Iterator<Item = DataIdx> + 'static {
    let matching = self.node(loc).data_lpm_locs(depth, key, prefix_len);
    matching.map(move |slot| DataIdx {
        node: loc,
        bit: slot.bit,
        depth,
    })
}
/// `MaskedLexIter` over the node holding (`key`, `prefix_len`).
///
/// The iterator is masked with `data_cover_mask` / `child_cover_mask` for that prefix. If
/// the node does not exist, the default (empty) iterator is returned.
pub(crate) fn lex_iter_at<R: Key>(&self, key: R, prefix_len: u32) -> MaskedLexIter<R> {
    match self.find_loc(key, prefix_len) {
        None => MaskedLexIter::default(),
        Some((loc, depth)) => {
            let mut iter = unsafe { self.lex_iter(loc, depth, key) };
            iter.apply_data_mask(data_cover_mask(depth, key, prefix_len));
            iter.apply_child_mask(child_cover_mask(depth, key, prefix_len));
            iter
        }
    }
}
/// All occupied data slots of the node at `loc` (at `depth`), as [`DataIdx`]s.
///
/// # Safety
/// Presumably `loc` must address a live node of this table — confirm.
pub(crate) unsafe fn data_iter(
    &self,
    loc: Loc,
    depth: u32,
) -> impl DoubleEndedIterator<Item = DataIdx> + 'static {
    let slots = self.node(loc).data_locs();
    slots.map(move |slot| DataIdx {
        node: loc,
        bit: slot.bit,
        depth,
    })
}
/// Unmasked `MaskedLexIter` built from a copy of the node at `loc`.
///
/// # Safety
/// Presumably `loc` must address a live node of this table — confirm.
pub(crate) unsafe fn lex_iter<R: Key>(&self, loc: Loc, depth: u32, key: R) -> MaskedLexIter<R> {
    let snapshot = *self.node(loc);
    MaskedLexIter::new(loc, depth, key, snapshot)
}
/// Visits every stored value, keeps those for which `f(&prefix, &value)` returns `true` and
/// removes the rest. Returns how many values were removed.
///
/// Nodes are visited depth-first from the root, with a node's own data handled before its
/// children. After a node's rejected values are removed, `cleanup_tree` unlinks the node and
/// any ancestors that became completely empty.
///
/// Nodes are addressed by the list of child bits leading to them from the root. They are
/// re-located from the root on every visit, and again after `cleanup_tree`, instead of
/// through a cached `Loc`. Presumably this is because removing a child can move its
/// siblings' storage.
pub(crate) fn retain_all<P: Prefix, F>(&mut self, f: &mut F) -> usize
where
F: FnMut(&P, &T) -> bool,
{
let mut removed_total = 0usize;
// Stack of sibling groups. Each entry is (child-bit path from the root, node depth, key
// representation of the node's path). A group is drained one entry at a time, and a
// node's children are pushed on top of it. That way a node's subtree is finished before
// its next sibling is started.
#[allow(clippy::type_complexity)]
let mut stack: Vec<Vec<(Vec<u32>, u32, P::R)>> =
vec![vec![(vec![], 0, <P::R as Zero>::zero())]];
'main: while let Some(mut siblings) = stack.pop() {
// An exhausted group is simply discarded.
let Some((offsets, depth, key)) = siblings.pop() else {
continue;
};
stack.push(siblings);
// Walk from the root to the node, recording (parent, child bit) pairs for
// `cleanup_tree`. A missing child presumably means an earlier cleanup already removed
// this node, so there is nothing left to visit.
let mut loc = Loc::root();
let mut path: Vec<(Loc, u32)> = Vec::with_capacity(offsets.len());
for &offset in &offsets {
let Some(child_loc) = (unsafe { self.child(loc, offset) }) else {
continue 'main; };
path.push((loc, offset));
loc = child_loc;
}
// Decide first, mutate later: the predicate runs while `self` is only borrowed shared.
let to_remove: Vec<u32> = unsafe { self.data_iter(loc, depth) }
.filter(|&dl| {
let r = unsafe { dl.resolve(self) }.expect("data_iter: bit not in bitmap");
!f(&r.prefix::<P>(key), r.get())
})
.map(|dl| dl.bit)
.collect();
removed_total += to_remove.len();
// Each removal is re-resolved from the node's current bitmap, so taking one value does
// not invalidate the remaining bits of the same node.
for bit in to_remove {
let idx = DataIdx {
node: loc,
bit,
depth,
};
let r = unsafe { idx.resolve_mut(self) }.expect("retain_all: bit not in bitmap");
r.take();
}
unsafe { self.cleanup_tree(loc, &mut path) };
// `cleanup_tree` may have removed this node, so locate it again from the root before
// reading its children. If it is gone, it had no children to visit.
let mut cur = Loc::root();
for &offset in &offsets {
let Some(child_loc) = (unsafe { self.child(cur, offset) }) else {
continue 'main;
};
cur = child_loc;
}
let node_snap = *self.node(cur);
if node_snap.child_bitmap != 0 {
// Each child inherits this node's path plus its own child bit, and a key extended
// with that bit via `extend_repr`.
let children: Vec<_> = node_snap
.child_locs()
.map(|child| {
let mut child_offsets = offsets.clone();
child_offsets.push(child.bit);
(child_offsets, depth + K, extend_repr(key, depth, child.bit))
})
.collect();
stack.push(children);
}
}
removed_total
}
/// Unlinks `start_loc`, and then each ancestor on `path`, while the node is completely
/// empty (no data and no children).
///
/// The walk stops at the first non-empty node or at the root, which is never removed.
///
/// `path` must hold the (parent, child bit) pairs from the root down to `start_loc`'s
/// parent. One entry is popped per removed node.
///
/// Returns the node where the walk stopped and the number of nodes removed. When the last
/// removal was a direct child of the root, the returned node is the root.
///
/// # Panics
/// Panics (`unreachable!`) if `path` runs out before the root is reached.
///
/// # Safety
/// Presumably `start_loc` and every `Loc` in `path` must address live nodes — confirm.
pub(crate) unsafe fn cleanup_tree(
    &mut self,
    start_loc: Loc,
    path: &mut Vec<(Loc, u32)>,
) -> (Loc, usize) {
    let mut current = start_loc;
    let mut removed = 0;
    loop {
        let node = self.node(current);
        let is_empty = node.data_bitmap == 0 && node.child_bitmap == 0;
        if !is_empty || current.is_root() {
            return (current, removed);
        }
        removed += 1;
        let (parent, offset) = match path.pop() {
            Some(entry) => entry,
            None => unreachable!("Path must go back all the way to the root"),
        };
        unsafe { self.remove_child_at(parent, offset) };
        if parent.is_root() {
            return (Loc::root(), removed);
        }
        current = parent;
    }
}
/// Drops every value stored in the node at `loc` and in all of its descendants, and
/// releases their storage. Returns the number of values dropped.
///
/// If `loc` is the root, the whole table is reset via `NodeAllocator::clear` and
/// `CellAllocator::clear`. Otherwise:
/// - every visited node's cell block is freed;
/// - every visited node's child array is returned to the node allocator;
/// - the node at `loc` is left empty (no data, no children) but still linked into its
///   parent, so the caller can unlink it (e.g. via `remove_child_at` or `cleanup_tree`).
///
/// # Safety
/// Presumably `loc` must address a live node of this table — confirm.
pub(crate) unsafe fn clear_node_and_children(&mut self, loc: Loc) -> usize {
    let is_all = loc.is_root();
    let mut stack = vec![loc];
    let mut count = 0;
    // Child arrays are released only after the walk, because the walk still reads the
    // descendant nodes stored inside them.
    let mut children_to_free: Vec<(AllocIdx, usize)> = Vec::new();
    while let Some(cur) = stack.pop() {
        let node_snap = *self.node(cur);
        stack.extend(node_snap.child_locs());
        let child_count = node_snap.child_bitmap.count_ones() as usize;
        children_to_free.push((node_snap.children_idx, child_count));
        let data_count = node_snap.data_bitmap.count_ones() as usize;
        count += data_count;
        if data_count > 0 {
            for data_loc in node_snap.data_locs() {
                let _ = unsafe { self.cells.remove_raw(data_loc) };
            }
            self.cells.free(node_snap.data_idx, data_count);
            self.nodes[cur].data_bitmap = 0;
            self.nodes[cur].data_idx = AllocIdx::empty();
        }
    }
    if is_all {
        self.nodes.clear();
        self.cells.clear();
    } else {
        for (to_free, child_count) in children_to_free {
            if !to_free.is_empty() && child_count > 0 {
                self.nodes.free(to_free, child_count);
            }
        }
        // The start node's child array was just released. Detach it so the node no longer
        // points into freed node storage, mirroring the data reset done in the loop. This
        // also lets `cleanup_tree` recognise the node as empty.
        self.nodes[loc].child_bitmap = 0;
        self.nodes[loc].children_idx = AllocIdx::empty();
    }
    count
}
/// Test-only consistency check of the allocators' bookkeeping.
///
/// Walks every node reachable from the root and marks the cell slots and node slots that
/// each live data block / child array occupies, using the tier capacity from
/// `DATA_SPACING` / `CHILD_SPACING`. It then marks every slot reported by the allocators'
/// free lists. Each problem is printed to stderr:
/// - a non-zero bitmap paired with an empty index, or a zero bitmap with a non-empty index;
/// - a slot claimed by more than one live node, or by both a live node and a free list;
/// - a slot that is neither live nor free (leaked).
///
/// All problems are reported, not just the first. Returns `true` when none were found.
#[cfg(test)]
pub(crate) fn check_memory_alloc(&self) -> bool {
    use crate::allocator::{
        CHILD_COUNT_TO_TIER, CHILD_SPACING, DATA_COUNT_TO_TIER, DATA_SPACING,
    };
    // Like `assert!`, but records the failure in `$var` and keeps going, so one run reports
    // every inconsistency.
    macro_rules! assert_soft {
        ($var:ident, $check:expr, $($fmt:expr),*) => {
            if !($check) {
                $var = false;
                eprintln!($($fmt),*);
            }
        }
    }
    let mut correct = true;
    let cell_len = self.cells.total_slots();
    let node_len = self.nodes.total_slots();
    let mut cell_acc = vec![false; cell_len];
    let mut node_acc = vec![false; node_len];
    // Node slot 0 is presumably the root, which no parent references.
    node_acc[0] = true;
    let mut stack = vec![Loc::root()];
    while let Some(loc) = stack.pop() {
        let node = *self.node(loc);
        if node.data_bitmap != 0 {
            assert_soft!(
                correct,
                !node.data_idx.is_empty(),
                "node at slot {} has non-zero data_bitmap but empty data_idx",
                loc.idx.as_usize() + loc.slot as usize
            );
            let count = node.data_bitmap.count_ones() as usize;
            let cap = DATA_SPACING[DATA_COUNT_TO_TIER[count.min(31)] as usize];
            let start = node.data_idx.as_usize();
            for (i, item) in cell_acc.iter_mut().enumerate().skip(start).take(cap) {
                assert_soft!(
                    correct,
                    !*item,
                    "cell slot {i} is referenced by more than one live node"
                );
                *item = true;
            }
        } else {
            // The index printed is the node's slot, like the other node-level checks.
            assert_soft!(
                correct,
                node.data_idx.is_empty(),
                "node at slot {} has zero data_bitmap but non-empty data_idx",
                loc.idx.as_usize() + loc.slot as usize
            );
        }
        if node.child_bitmap != 0 {
            assert_soft!(
                correct,
                !node.children_idx.is_empty(),
                "node at slot {} has non-zero child_bitmap but empty children_idx",
                loc.idx.as_usize() + loc.slot as usize
            );
            let count = node.child_bitmap.count_ones() as usize;
            let cap = CHILD_SPACING[CHILD_COUNT_TO_TIER[count.min(32)] as usize];
            let start = node.children_idx.as_usize();
            for (i, item) in node_acc.iter_mut().enumerate().skip(start).take(cap) {
                assert_soft!(
                    correct,
                    !*item,
                    "node slot {i} is referenced by more than one live node"
                );
                *item = true;
            }
            stack.extend(node.child_locs());
        } else {
            assert_soft!(
                correct,
                node.children_idx.is_empty(),
                "node at slot {} has zero child_bitmap but non-empty children_idx",
                loc.idx.as_usize() + loc.slot as usize
            );
        }
    }
    for (start, cap) in self.cells.free_list_slots() {
        for (i, item) in cell_acc.iter_mut().enumerate().skip(start).take(cap) {
            assert_soft!(
                correct,
                !*item,
                "cell slot {i} appears in both the live tree and a free list"
            );
            *item = true;
        }
    }
    for (start, cap) in self.nodes.free_list_slots() {
        for (i, item) in node_acc.iter_mut().enumerate().skip(start).take(cap) {
            assert_soft!(
                correct,
                !*item,
                "node slot {i} appears in both the live tree and a free list"
            );
            *item = true;
        }
    }
    for (i, &acc) in cell_acc.iter().enumerate() {
        assert_soft!(
            correct,
            acc,
            "cell slot {i} is leaked (neither referenced by any live node nor in any free list)"
        );
    }
    for (i, &acc) in node_acc.iter().enumerate() {
        assert_soft!(
            correct,
            acc,
            "node slot {i} is leaked (neither referenced by any live node nor in any free list)"
        );
    }
    correct
}
}
impl<T: Clone> Clone for Table<T> {
/// Deep-copies the table, cloning every stored value.
///
/// The node storage is copied as-is, but the value storage is rebuilt. Each node with data
/// gets a freshly allocated cell block in the copy, and its values are cloned into the same
/// slot positions they occupy in `self`.
fn clone(&self) -> Self {
let mut x = Self {
nodes: self.nodes.clone(),
cells: Default::default(),
};
// First pass: strip the data references copied along with the nodes, since they point
// into `self`'s cell storage. Until a node is repopulated below, `x` claims no values.
// Dropping a partially built `x` (e.g. if `T::clone` panics) therefore presumably never
// touches cells `x` does not own.
let mut stack = vec![Loc::root()];
while let Some(loc) = stack.pop() {
let node = self.nodes[loc];
x.nodes[loc].data_bitmap = 0;
x.nodes[loc].data_idx = AllocIdx::empty();
if node.child_bitmap != 0 {
stack.extend(node.child_locs());
}
}
// Second pass: allocate a block for the node's value count, clone each value into the
// slot it has in `self`, and only then mark its data bit.
// NOTE(review): marking bits one at a time keeps `x` droppable mid-loop only if
// `data_locs()` yields bits in ascending order, so that each set bit's rank equals a
// slot that has already been written — confirm.
let mut stack = vec![Loc::root()];
while let Some(loc) = stack.pop() {
let node = self.nodes[loc];
if node.data_bitmap != 0 && !node.data_idx.is_empty() {
let count = node.data_bitmap.count_ones() as usize;
let data_idx = x.cells.alloc(count);
x.nodes[loc].data_idx = data_idx;
for data_loc in node.data_locs() {
let val = unsafe { self.cells.get(data_loc) }.clone();
unsafe { x.cells.write_at(data_idx, data_loc.slot, val) };
x.nodes[loc].set_data_bit(data_loc.bit);
}
}
if node.child_bitmap != 0 {
stack.extend(node.child_locs());
}
}
x
}
}
/// Rebuilds the prefix stored in data slot `data_offset` of a node at `depth`.
///
/// - The leading `depth` bits come from `key`, which is the path to the node.
/// - The offset bits and the extra length come from `DATA_BIT_TO_PREFIX[data_offset]`.
/// - The offset is presumably a `K - 1`-bit value. It is shifted so that it starts right
///   after the first `depth` bits.
/// - Near the end of the key, `depth + K - 1` can exceed the key width. In that case the
///   offset is shifted right by the excess instead.
///
/// # Panics
/// Panics if `data_offset` is out of range for `DATA_BIT_TO_PREFIX`, or if the offset
/// cannot be represented in `P::R`.
fn prefix<P: Prefix>(key: P::R, depth: u32, data_offset: usize) -> P {
    let path_bits = key & mask_from_prefix_len(depth as u8);
    let (raw_offset, offset_len) = DATA_BIT_TO_PREFIX[data_offset];
    let offset = <P::R as num_traits::cast::NumCast>::from(raw_offset).unwrap();
    // Bit position (counted from the most significant bit) just past the offset field.
    let offset_end = depth + (K - 1);
    let width = P::num_bits();
    let aligned_offset = if width > offset_end {
        offset << (width - offset_end) as usize
    } else {
        offset >> (offset_end - width) as usize
    };
    let full_len = depth + offset_len as u32;
    P::from_repr_len(path_bits | aligned_offset, full_len as u8)
}