use core::fmt;
use core::ptr::NonNull;
use super::Nibble;
/// Returns the `u32` whose only set bit is bit `n`, where bit 0 is the most
/// significant bit and bit 31 the least significant one.
///
/// `n` must be below 32; the bound is only checked in debug builds.
#[inline]
const fn mask_bit(n: usize) -> u32 {
    debug_assert!(n < 32);
    // Slide the top bit down by `n` places instead of sliding the low bit up.
    0x8000_0000_u32 >> n
}
const MASKS: [u32; 16] = [
mask_bit(0) | mask_bit(1) | mask_bit(3) | mask_bit(7) | mask_bit(16), mask_bit(0) | mask_bit(1) | mask_bit(3) | mask_bit(7) | mask_bit(17), mask_bit(0) | mask_bit(1) | mask_bit(3) | mask_bit(8) | mask_bit(18), mask_bit(0) | mask_bit(1) | mask_bit(3) | mask_bit(8) | mask_bit(19), mask_bit(0) | mask_bit(1) | mask_bit(4) | mask_bit(9) | mask_bit(20), mask_bit(0) | mask_bit(1) | mask_bit(4) | mask_bit(9) | mask_bit(21), mask_bit(0) | mask_bit(1) | mask_bit(4) | mask_bit(10) | mask_bit(22), mask_bit(0) | mask_bit(1) | mask_bit(4) | mask_bit(10) | mask_bit(23), mask_bit(0) | mask_bit(2) | mask_bit(5) | mask_bit(11) | mask_bit(24), mask_bit(0) | mask_bit(2) | mask_bit(5) | mask_bit(11) | mask_bit(25), mask_bit(0) | mask_bit(2) | mask_bit(5) | mask_bit(12) | mask_bit(26), mask_bit(0) | mask_bit(2) | mask_bit(5) | mask_bit(12) | mask_bit(27), mask_bit(0) | mask_bit(2) | mask_bit(6) | mask_bit(13) | mask_bit(28), mask_bit(0) | mask_bit(2) | mask_bit(6) | mask_bit(13) | mask_bit(29), mask_bit(0) | mask_bit(2) | mask_bit(6) | mask_bit(14) | mask_bit(30), mask_bit(0) | mask_bit(2) | mask_bit(6) | mask_bit(14) | mask_bit(31), ];
const CHILD_MASKS: [u32; 16] = [
mask_bit(16), mask_bit(17), mask_bit(18), mask_bit(19), mask_bit(20), mask_bit(21), mask_bit(22), mask_bit(23), mask_bit(24), mask_bit(25), mask_bit(26), mask_bit(27), mask_bit(28), mask_bit(29), mask_bit(30), mask_bit(31), ];
#[rustfmt::skip]
const VALUE_MASKS: [[u32; 16]; 5] = [
[
mask_bit(0), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
[
mask_bit(0) | mask_bit(1), 0, 0, 0, 0, 0, 0, 0,
mask_bit(0) | mask_bit(2), 0, 0, 0, 0, 0, 0, 0,
],
[
mask_bit(0) | mask_bit(1) | mask_bit(3), 0, 0, 0,
mask_bit(0) | mask_bit(1) | mask_bit(4), 0, 0, 0,
mask_bit(0) | mask_bit(2) | mask_bit(5), 0, 0, 0,
mask_bit(0) | mask_bit(2) | mask_bit(6), 0, 0, 0,
],
[
mask_bit(0) | mask_bit(1) | mask_bit(3) | mask_bit(7), 0,
mask_bit(0) | mask_bit(1) | mask_bit(3) | mask_bit(8), 0,
mask_bit(0) | mask_bit(1) | mask_bit(4) | mask_bit(9), 0,
mask_bit(0) | mask_bit(1) | mask_bit(4) | mask_bit(10), 0,
mask_bit(0) | mask_bit(2) | mask_bit(5) | mask_bit(11), 0,
mask_bit(0) | mask_bit(2) | mask_bit(5) | mask_bit(12), 0,
mask_bit(0) | mask_bit(2) | mask_bit(6) | mask_bit(13), 0,
mask_bit(0) | mask_bit(2) | mask_bit(6) | mask_bit(14), 0,
],
MASKS,
];
/// Returns the single bitmap bit that belongs to exactly the first `bits`
/// bits of `nibble`.
///
/// The lower `4 - bits` bits of `nibble` are ignored. `nibble` must be below
/// 16 and `bits` below 5 (both checked in debug builds only).
#[inline]
const fn exact_value_mask(nibble: u8, bits: u8) -> u32 {
    debug_assert!(nibble < 16);
    debug_assert!(bits < 5);
    // Keep only the `bits` most significant bits of the nibble.
    let slot = nibble & (0b1111_0000 >> bits);
    let path = VALUE_MASKS[bits as usize][slot as usize];
    // The lowest set bit of the path mask is the entry for this exact length.
    1u32 << path.trailing_zeros()
}
/// Returns the mask of every bitmap bit on the path to the first `bits` bits
/// of `nibble`: the bit for that exact length plus the bits of all shorter
/// lengths leading to it.
///
/// The lower `4 - bits` bits of `nibble` are ignored. `nibble` must be below
/// 16 and `bits` below 5 (both checked in debug builds only).
#[inline]
const fn match_value_mask(nibble: u8, bits: u8) -> u32 {
    debug_assert!(nibble < 16);
    debug_assert!(bits < 5);
    // Keep only the `bits` most significant bits of the nibble.
    let slot = nibble & (0b1111_0000 >> bits);
    VALUE_MASKS[bits as usize][slot as usize]
}
/// A single trie node: an occupancy bitmap plus the value and child slots
/// that the bitmap describes.
pub(super) struct Node<V> {
/// Occupancy bitmap. Bit `1 << 16` (`END_NODE_BIT`) marks an end node. For an
/// internal node the high 16 bits are read as value bits and the low 16 bits
/// as child bits; an end node reads every bit except the flag as a value bit.
bitmap: u32,
/// Value slots. A value whose bitmap bit is `mask_bit(n)` is stored at index `n`.
values: [Option<Box<V>>; 32],
/// Child pointers, indexed by nibble. The `Drop` impl releases each present
/// child with `Box::from_raw`.
children: [Option<NonNull<Node<V>>>; 16],
}
impl<V> Node<V> {
    /// Bitmap flag marking an end node.
    ///
    /// `1 << 16` is the bit `mask_bit(15)` would produce, the one position that
    /// no entry of `MASKS` or `CHILD_MASKS` uses, so the flag never collides
    /// with a value bit or a child bit.
    const END_NODE_BIT: u32 = 1 << 16;

    /// Builds a node with a cleared bitmap, no values and no children.
    #[inline]
    fn empty() -> Self {
        Self {
            bitmap: 0,
            values: Default::default(),
            children: [None; 16],
        }
    }

    /// Moves `node` onto the heap and returns the owning raw pointer.
    ///
    /// The allocation is released again by `Box::from_raw`, which is what the
    /// `Drop` impl of the parent node does for its children.
    #[inline]
    fn boxed_raw(node: Self) -> NonNull<Self> {
        NonNull::from(Box::leak(Box::new(node)))
    }

    /// Index into `values` that corresponds to the lowest set bit of `bits`.
    ///
    /// `bits` must be non-zero.
    #[inline]
    const fn value_slot(bits: u32) -> usize {
        debug_assert!(bits != 0);
        31 - bits.trailing_zeros() as usize
    }

    /// Allocates an empty internal node (one that may have children).
    #[inline]
    pub fn internal() -> NonNull<Self> {
        Self::boxed_raw(Self::empty())
    }

    /// Allocates an empty end node, i.e. a node with `END_NODE_BIT` set.
    #[inline]
    pub fn end() -> NonNull<Self> {
        let mut node = Self::empty();
        node.bitmap |= Self::END_NODE_BIT;
        Self::boxed_raw(node)
    }

    /// Returns `true` when this node carries the end-node flag.
    #[inline]
    pub const fn is_end(&self) -> bool {
        self.bitmap & Self::END_NODE_BIT != 0
    }

    /// Bitmap restricted to the bits that describe stored values.
    ///
    /// An end node uses every bit except the flag; an internal node only uses
    /// the high 16 bits because the low 16 bits describe its children.
    #[inline]
    const fn value_bits(&self) -> u32 {
        if self.is_end() {
            self.bitmap & !Self::END_NODE_BIT
        } else {
            self.bitmap & 0xffff_0000
        }
    }

    /// Bitmap restricted to the bits that describe children.
    ///
    /// Must only be called on internal nodes (checked in debug builds).
    #[inline]
    const fn child_bits(&self) -> u32 {
        debug_assert!(!self.is_end());
        self.bitmap & 0x0000_ffff
    }

    /// Collects every stored value on the path to `nibble`, ordered by
    /// ascending slot index (fewest significant bits first).
    ///
    /// # Panics
    ///
    /// Panics if a value bit is set in the bitmap while its slot is empty,
    /// which would mean the node's bookkeeping is corrupted.
    #[inline]
    pub fn list_values(&self, nibble: Nibble) -> Vec<&V> {
        let hits = self.value_bits() & match_value_mask(nibble.byte, nibble.bits);
        (0..32usize)
            .filter(|&slot| hits & (0x8000_0000_u32 >> slot) != 0)
            .map(|slot| {
                self.values[slot]
                    .as_deref()
                    .expect("value bit set in bitmap but the slot is empty")
            })
            .collect()
    }

    /// Returns the stored value with the most significant bits in common with
    /// `nibble`, or `None` when nothing on the nibble's path is stored.
    #[inline]
    pub fn get_longest_match_value(&self, nibble: Nibble) -> Option<&V> {
        let hits = self.value_bits() & match_value_mask(nibble.byte, nibble.bits);
        if hits == 0 {
            return None;
        }
        // The lowest set bit belongs to the longest stored entry on the path.
        self.values[Self::value_slot(hits)].as_deref()
    }

    /// Returns the value stored for exactly `nibble` (same bits, same length),
    /// or `None` when that slot is unused.
    #[inline]
    pub fn get_exact_match_value(&self, nibble: Nibble) -> Option<&V> {
        let hit = self.value_bits() & exact_value_mask(nibble.byte, nibble.bits);
        if hit == 0 {
            return None;
        }
        self.values[Self::value_slot(hit)].as_deref()
    }

    /// Stores `value` in the slot for exactly `nibble` and marks it in the bitmap.
    ///
    /// The slot is expected to be empty (checked in debug builds only); in a
    /// release build an existing value is replaced.
    ///
    /// NOTE(review): with `nibble.bits == 4` the bit lies in the low half of
    /// the bitmap, which an internal node reads as child bits. Presumably
    /// callers only do that on end nodes; confirm against the call sites.
    #[inline]
    pub fn set_value(&mut self, nibble: Nibble, value: V) {
        let mask = exact_value_mask(nibble.byte, nibble.bits);
        debug_assert!(mask.is_power_of_two(), "mask = {mask:032b}");
        let index = Self::value_slot(mask);
        debug_assert!(self.values[index].is_none());
        self.bitmap |= mask;
        self.values[index] = Some(Box::new(value));
    }

    /// Removes and returns the value stored for exactly `nibble`, if any.
    ///
    /// The matching bitmap bit is cleared together with the slot. Leaving it
    /// set would make later lookups treat the empty slot as occupied:
    /// `list_values` would panic and `get_longest_match_value` would stop
    /// falling back to shorter entries.
    #[inline]
    pub fn remove_value(&mut self, nibble: Nibble) -> Option<V> {
        let mask = exact_value_mask(nibble.byte, nibble.bits);
        if self.value_bits() & mask == 0 {
            return None;
        }
        self.bitmap &= !mask;
        self.values[Self::value_slot(mask)].take().map(|boxed| *boxed)
    }

    /// Returns the child registered for `nibble`, if any.
    ///
    /// Must only be called on internal nodes (checked in debug builds).
    ///
    /// # Panics
    ///
    /// Panics if `nibble` is 16 or larger.
    #[inline]
    pub fn get_child(&self, nibble: u8) -> Option<NonNull<Node<V>>> {
        let slot = usize::from(nibble);
        if self.child_bits() & CHILD_MASKS[slot] == 0 {
            return None;
        }
        self.children[slot]
    }

    /// Registers `node` as the child for `nibble`.
    ///
    /// The child slot is expected to be free (checked in debug builds only).
    /// This node's `Drop` impl later releases `node` with `Box::from_raw`, so
    /// the pointer is assumed to come from `Node::internal` / `Node::end`.
    ///
    /// # Panics
    ///
    /// Panics if `nibble` is 16 or larger.
    #[inline]
    pub fn set_child(&mut self, nibble: u8, node: NonNull<Node<V>>) {
        debug_assert!(nibble < 16);
        let slot = usize::from(nibble);
        let mask = CHILD_MASKS[slot];
        debug_assert!(self.children[slot].is_none());
        debug_assert!(self.bitmap & mask == 0);
        self.children[slot] = Some(node);
        self.bitmap |= mask;
    }
}
/// Releases the whole subtree: every child pointer still stored in the node
/// is turned back into a `Box` and dropped, which recursively runs this impl
/// for the child as well.
impl<V> Drop for Node<V> {
    fn drop(&mut self) {
        self.children
            .iter_mut()
            .filter_map(Option::take)
            .for_each(|child| {
                // SAFETY: assumes every stored child was allocated through `Box`
                // (as `Node::internal` / `Node::end` do) and is owned only by
                // this node. Taking the pointer out of its slot first means it
                // cannot be freed twice from here.
                drop(unsafe { Box::from_raw(child.as_ptr()) });
            });
    }
}
/// Prints the end-node flag and the value bits as a 32-digit binary string;
/// internal nodes additionally print their child bits as 16 binary digits.
/// Stored values and child pointers themselves are not printed.
impl<V> fmt::Debug for Node<V> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let terminal = self.is_end();
        let mut out = f.debug_struct("Node");
        out.field("is_end", &terminal);
        out.field("values", &format_args!("{:032b}", self.value_bits()));
        if terminal {
            // End nodes have no children, and `child_bits` must not be called on them.
            return out.finish();
        }
        out.field("children", &format_args!("{:016b}", self.child_bits()))
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `MASKS` must equal the table spelled out with plain shifts from the MSB.
    #[test]
    fn test_mask() {
        const MSB: u32 = 1 << 31;
        const MATCH_MASKS: [u32; 16] = [
            MSB | MSB >> 1 | MSB >> 3 | MSB >> 7 | MSB >> 16,
            MSB | MSB >> 1 | MSB >> 3 | MSB >> 7 | MSB >> 17,
            MSB | MSB >> 1 | MSB >> 3 | MSB >> 8 | MSB >> 18,
            MSB | MSB >> 1 | MSB >> 3 | MSB >> 8 | MSB >> 19,
            MSB | MSB >> 1 | MSB >> 4 | MSB >> 9 | MSB >> 20,
            MSB | MSB >> 1 | MSB >> 4 | MSB >> 9 | MSB >> 21,
            MSB | MSB >> 1 | MSB >> 4 | MSB >> 10 | MSB >> 22,
            MSB | MSB >> 1 | MSB >> 4 | MSB >> 10 | MSB >> 23,
            MSB | MSB >> 2 | MSB >> 5 | MSB >> 11 | MSB >> 24,
            MSB | MSB >> 2 | MSB >> 5 | MSB >> 11 | MSB >> 25,
            MSB | MSB >> 2 | MSB >> 5 | MSB >> 12 | MSB >> 26,
            MSB | MSB >> 2 | MSB >> 5 | MSB >> 12 | MSB >> 27,
            MSB | MSB >> 2 | MSB >> 6 | MSB >> 13 | MSB >> 28,
            MSB | MSB >> 2 | MSB >> 6 | MSB >> 13 | MSB >> 29,
            MSB | MSB >> 2 | MSB >> 6 | MSB >> 14 | MSB >> 30,
            MSB | MSB >> 2 | MSB >> 6 | MSB >> 14 | MSB >> 31,
        ];
        for (nibble, (expected, actual)) in MATCH_MASKS.iter().zip(MASKS.iter()).enumerate() {
            assert_eq!(expected, actual, "nibble {nibble}");
        }
    }

    /// A child registered on an internal node can be looked up again.
    #[test]
    fn test_node() {
        let mut root = Node::<u64>::internal();
        // SAFETY: `root` was just allocated by `Node::internal` and nothing else
        // holds a reference to it.
        let p = unsafe { root.as_mut() };
        p.set_child(0b0010, Node::end());
        assert!(p.get_child(0b0010).is_some());
        // Hand the allocation back to a `Box` so the test does not leak it;
        // `Drop for Node` releases the child along with the root.
        // SAFETY: `root` is a live heap allocation created by `Node::internal`
        // and is not used again after this point.
        drop(unsafe { Box::from_raw(root.as_ptr()) });
    }

    /// `exact_value_mask` yields the single bit for the given prefix.
    #[test]
    fn test_exact_value_mask() {
        let tests = [
            ((0b0000, 0), [0b1000_0000, 0, 0, 0]),
            ((0b0000, 1), [0b0100_0000, 0, 0, 0]),
            ((0b1000, 1), [0b0010_0000, 0, 0, 0]),
            ((0b0000, 2), [0b0001_0000, 0, 0, 0]),
            ((0b0100, 2), [0b0000_1000, 0, 0, 0]),
        ];
        for ((nibble, bits), expected) in tests {
            let mask = exact_value_mask(nibble, bits);
            let actual = mask.to_be_bytes();
            assert_eq!(
                actual,
                expected,
                "input: ({nibble:08b}, {bits}) = {mask:032b}, expected: {:032b}",
                u32::from_be_bytes(expected),
            );
        }
    }

    /// `match_value_mask` yields the bit for the given prefix plus the bits of
    /// every shorter prefix on the way to it.
    #[test]
    fn test_match_value_mask() {
        let tests = [
            ((0b0000, 0), [0b1000_0000, 0, 0, 0]),
            ((0b0000, 1), [0b1100_0000, 0, 0, 0]),
            ((0b1000, 1), [0b1010_0000, 0, 0, 0]),
            ((0b0000, 2), [0b1101_0000, 0, 0, 0]),
            ((0b0100, 2), [0b1100_1000, 0, 0, 0]),
        ];
        for ((nibble, bits), expected) in tests {
            let mask = match_value_mask(nibble, bits);
            let actual = mask.to_be_bytes();
            assert_eq!(
                actual,
                expected,
                "input: ({nibble:08b}, {bits}) = {mask:032b}, expected: {:032b}",
                u32::from_be_bytes(expected),
            );
        }
    }
}