use crate::error::{Result, ZiporaError};
/// Size in bytes of one pool slot; byte sizes are converted to slot counts in units of this.
pub const ALIGN_SIZE: usize = 4;
/// Sentinel state id meaning "no state"; returned by `state_move` when no transition exists.
pub const NIL_STATE: u32 = u32::MAX;
/// Maximum number of compressed-path (zpath) bytes stored in one link node of a suffix chain.
pub const MAX_ZPATH: usize = 254;
/// Slot index of the root node.
pub const INITIAL_STATE: u32 = 0;
/// Number of exact-size free-list bins; freed blocks of up to this many slots go into `fast_bins`.
const FREE_LIST_MAX_SLOTS: usize = 128;
/// End-of-list marker for the free lists threaded through freed slots.
const FREE_LIST_NIL: u32 = u32::MAX;
/// Number of header slots that precede the child array, indexed by count type (`flags & 0x0F`).
///
/// Entries holding `u32::MAX` belong to count types that no code in this file constructs.
pub const SKIP_SLOTS: [u32; 16] = [
1, 1, 1, 2, 2, 2, 2, 5, 10, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, 2, ];
/// One 4-byte, 4-byte-aligned slot of the node pool, viewable in several ways.
///
/// Reading any field requires `unsafe`; which view is meaningful depends on the
/// slot's position inside a node.
#[repr(C, align(4))]
#[derive(Clone, Copy)]
pub union PatriciaNode {
/// Header view: flags, zpath length and two inline labels.
pub meta: MetaInfo,
/// Header view exposing a 16-bit child count in bytes 2..4 of the slot.
pub big: BigCount,
/// A state id; the allocator also stores the free-list "next" link here.
pub child: u32,
/// Raw byte view of the slot.
pub bytes: [u8; 4],
}
/// Header layout of a node's first slot.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct MetaInfo {
// `flags`: low nibble is the count type (see `NodeView::cnt_type`), bit 0x10 marks a final state.
// `n_zpath_len`: number of compressed-path bytes stored in the node.
pub flags: u8, pub n_zpath_len: u8,
/// Up to two inline transition labels (used by count types 1..=6).
pub c_label: [u8; 2],
}
/// Alternative view of a header slot whose last two bytes hold a child count.
///
/// `NodeView::n_children` reads this view for count types above 6.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct BigCount {
/// Overlays the `flags` and `n_zpath_len` bytes of [`MetaInfo`].
pub _unused: u16,
/// Child count; overlays the `c_label` bytes of [`MetaInfo`].
pub n_children: u16,
}
impl PatriciaNode {
    /// Returns a slot whose `child` view holds [`NIL_STATE`].
    #[inline(always)]
    pub fn empty() -> Self {
        Self { child: NIL_STATE }
    }
}
/// Snapshot of allocator statistics as produced by `CsppTrie::mem_get_stat`.
#[derive(Debug, Clone, Default)]
pub struct MemStat {
/// Number of free blocks in each fast bin (index = block size in slots minus one).
pub fastbin: Vec<usize>,
/// Bytes covered by the pool's current length.
pub used_size: usize,
/// Bytes covered by the pool's allocated capacity.
pub capacity: usize,
/// Bytes currently held in the free lists.
pub frag_size: usize,
/// Total bytes of the blocks on the large-block list.
pub large_size: usize,
/// Number of blocks on the large-block list.
pub large_cnt: usize,
/// Total bytes of blocks queued for deferred freeing.
pub lazy_free_sum: usize,
/// Number of blocks queued for deferred freeing.
pub lazy_free_cnt: usize,
}
/// A block whose release has been postponed until `reclaim_lazy_frees` runs.
#[derive(Clone, Copy)]
struct LazyFreeItem {
/// First slot of the block.
slot: u32,
/// Length of the block in slots.
slots: u32,
}
/// Read-only cursor over the node that starts at slot `curr` of a slot array.
pub struct NodeView<'a> {
/// The whole slot array the node lives in.
nodes: &'a [PatriciaNode],
/// Index of the node's first (header) slot.
curr: u32,
}
impl<'a> NodeView<'a> {
#[inline(always)]
pub fn new(nodes: &'a [PatriciaNode], curr: u32) -> Self {
debug_assert!((curr as usize) < nodes.len());
Self { nodes, curr }
}
/// Returns the header slot interpreted as [`MetaInfo`].
#[inline(always)]
pub fn meta(&self) -> MetaInfo {
    let at = self.curr as usize;
    // SAFETY: not checked here; relies on `curr` being a valid slot index
    // (only debug-asserted in `new`).
    unsafe { self.nodes.get_unchecked(at).meta }
}
/// Returns the header slot interpreted as [`BigCount`].
#[inline(always)]
pub fn big(&self) -> BigCount {
    let at = self.curr as usize;
    // SAFETY: not checked here; relies on `curr` being a valid slot index
    // (only debug-asserted in `new`).
    unsafe { self.nodes.get_unchecked(at).big }
}
/// Reads the slot `offset` slots after the header as a state id.
///
/// No bounds check is performed.
#[inline(always)]
pub fn child(&self, offset: usize) -> u32 {
    let at = self.curr as usize + offset;
    // SAFETY: not checked here; the caller must pass an offset inside the node.
    unsafe { self.nodes.get_unchecked(at).child }
}
/// Reads the slot `offset` slots after the header as four raw bytes.
///
/// No bounds check is performed.
#[inline(always)]
pub fn bytes(&self, offset: usize) -> [u8; 4] {
    let at = self.curr as usize + offset;
    // SAFETY: not checked here; the caller must pass an offset inside the node.
    unsafe { self.nodes.get_unchecked(at).bytes }
}
#[inline(always)]
pub fn cnt_type(&self) -> u8 {
self.meta().flags & 0x0F
}
#[inline(always)]
pub fn is_final(&self) -> bool {
(self.meta().flags & 0x10) != 0
}
/// Returns the number of compressed-path bytes recorded in the header.
#[inline(always)]
pub fn zpath_len(&self) -> usize {
    usize::from(self.meta().n_zpath_len)
}
/// Returns the number of child slots of the node.
///
/// For count types 0..=6 the type itself is the count; otherwise the 16-bit
/// count stored in the header's [`BigCount`] view is used.
#[inline(always)]
pub fn n_children(&self) -> usize {
    match self.cnt_type() {
        small @ 0..=6 => usize::from(small),
        _ => usize::from(self.big().n_children),
    }
}
/// Returns the number of header slots preceding the child array, looked up
/// in [`SKIP_SLOTS`] by count type.
#[inline(always)]
pub fn skip_slots(&self) -> usize {
    let kind = usize::from(self.cnt_type());
    SKIP_SLOTS[kind] as usize
}
/// Returns label `idx`: the first two come from the header, the rest from
/// the raw bytes of the slot right after it.
#[inline(always)]
fn get_label(&self, idx: usize) -> u8 {
    match idx {
        0 | 1 => self.meta().c_label[idx],
        _ => self.bytes(1)[idx - 2],
    }
}
/// Follows the transition labelled `ch` and returns the target state, or
/// [`NIL_STATE`] when the node has no such transition.
///
/// Dispatches on the count type: inline labels (1..=6), a 16-byte label
/// array (7), a 256-bit bitmap with rank prefixes (8) or a direct 256-entry
/// table (15). Every other count type has no transitions.
#[inline(always)]
pub fn state_move(&self, ch: u8) -> u32 {
    match self.cnt_type() {
        1 => {
            if self.meta().c_label[0] == ch {
                self.child(1)
            } else {
                NIL_STATE
            }
        }
        2 => {
            let [first, second] = self.meta().c_label;
            // The second label is tested first, as before.
            if ch == second {
                self.child(2)
            } else if ch == first {
                self.child(1)
            } else {
                NIL_STATE
            }
        }
        // Inline labels are probed from the highest index downwards.
        n @ 3..=6 => (0..n as usize)
            .rev()
            .find(|&i| self.get_label(i) == ch)
            .map_or(NIL_STATE, |i| self.child(2 + i)),
        7 => {
            let n = self.n_children();
            // SAFETY: not checked here; assumes the node owns the four label
            // slots that follow its header.
            let labels = unsafe {
                let start = self.nodes.as_ptr().add(self.curr as usize + 1) as *const u8;
                std::slice::from_raw_parts(start, 16)
            };
            let hit = crate::fsa::fast_search::fast_search_byte_max_16(&labels[..n], ch);
            if hit < n {
                self.child(5 + hit)
            } else {
                NIL_STATE
            }
        }
        8 => {
            let base = self.curr as usize;
            // SAFETY: not checked here; assumes the node owns a 32-byte bitmap
            // starting two slots after its header.
            let bitmap = unsafe {
                let start = self.nodes.as_ptr().add(base + 2) as *const u8;
                std::slice::from_raw_parts(start, 32)
            };
            if bitmap[(ch / 8) as usize] & (1 << (ch % 8)) == 0 {
                return NIL_STATE;
            }
            // Slot 1 holds four rank prefixes (one per 64-bit bitmap word);
            // the bitmap words follow directly after it.
            let rank_ptr = unsafe { self.nodes.as_ptr().add(base + 1) as *const u8 };
            let word_idx = (ch / 64) as usize;
            let word = unsafe {
                std::ptr::read_unaligned(rank_ptr.add(4 + word_idx * 8) as *const u64)
            };
            let preceding = unsafe { *rank_ptr.add(word_idx) } as usize;
            let below = (1u64 << (ch % 64)) - 1;
            let rank = preceding + (word & below).count_ones() as usize;
            self.child(10 + rank)
        }
        15 => self.child(2 + ch as usize),
        _ => NIL_STATE,
    }
}
/// Returns the node's compressed-path bytes, which are stored right after
/// the header slots and the child slots. Empty when the zpath length is 0.
pub fn zpath_slice(&self) -> &'a [u8] {
    let len = self.zpath_len();
    if len == 0 {
        return &[];
    }
    let first_slot = self.curr as usize + self.skip_slots() + self.n_children();
    // SAFETY: not checked here; assumes the node really stores `len` bytes
    // of zpath at that position.
    unsafe {
        let start = self.nodes.as_ptr().add(first_slot) as *const u8;
        std::slice::from_raw_parts(start, len)
    }
}
/// Returns the byte offset (from the start of the slot array) at which the
/// node's value is stored: after the header, the children and the zpath
/// rounded up to a multiple of four bytes.
pub fn valpos(&self) -> usize {
    let slots_before_zpath = self.skip_slots() + self.n_children();
    let padded_zpath = (self.zpath_len() + 3) & !3;
    (self.curr as usize + slots_before_zpath) * 4 + padded_zpath
}
/// Calls `f(label, child_state)` for every transition of the node, in
/// ascending storage order.
///
/// For the direct 256-entry table (count type 15) entries equal to
/// [`NIL_STATE`] are skipped. Count types without transitions call nothing.
#[inline(always)]
pub fn for_each_child<F>(&self, mut f: F)
where
    F: FnMut(u8, u32),
{
    match self.cnt_type() {
        1 => f(self.meta().c_label[0], self.child(1)),
        2 => {
            f(self.meta().c_label[0], self.child(1));
            f(self.meta().c_label[1], self.child(2));
        }
        // Inline-label layouts: children start two slots after the header.
        n @ 3..=6 => {
            for i in 0..n as usize {
                f(self.get_label(i), self.child(2 + i));
            }
        }
        7 => {
            let n = self.n_children();
            // SAFETY: not checked here; assumes the node owns the four label
            // slots that follow its header.
            let labels = unsafe {
                let start = self.nodes.as_ptr().add(self.curr as usize + 1) as *const u8;
                std::slice::from_raw_parts(start, 16)
            };
            for i in 0..n {
                f(labels[i], self.child(5 + i));
            }
        }
        8 => {
            // SAFETY: not checked here; assumes the node owns a 32-byte bitmap
            // starting two slots after its header.
            let bitmap = unsafe {
                let start = self.nodes.as_ptr().add(self.curr as usize + 2) as *const u8;
                std::slice::from_raw_parts(start, 32)
            };
            // Children are stored in bitmap order, so a running counter is
            // enough to locate each one.
            let mut child_idx = 0usize;
            for (byte_idx, &byte) in bitmap.iter().enumerate() {
                let mut bits = byte;
                while bits != 0 {
                    let tz = bits.trailing_zeros();
                    let label = (byte_idx * 8) as u8 + tz as u8;
                    f(label, self.child(10 + child_idx));
                    child_idx += 1;
                    // Clear the lowest set bit.
                    bits &= bits - 1;
                }
            }
        }
        15 => {
            for ch in 0..=u8::MAX {
                let child = self.child(2 + ch as usize);
                if child != NIL_STATE {
                    f(ch, child);
                }
            }
        }
        _ => {}
    }
}
}
impl std::fmt::Debug for CsppTrie {
    /// Prints only the word and node counters, not the node pool.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CsppTrie");
        out.field("n_words", &self.n_words);
        out.field("n_nodes", &self.n_nodes);
        out.finish()
    }
}
/// Trie whose nodes live in a single pool of 4-byte slots and are addressed by slot index.
pub struct CsppTrie {
/// Backing slot pool; the root node starts at slot 0.
pub mempool: Vec<PatriciaNode>,
/// Number of keys inserted so far.
pub n_words: usize,
/// Initialised to 1 by `new`; no code visible in this file updates it.
pub n_nodes: usize,
/// Number of bytes reserved for the value of each final node.
pub valsize: usize,
/// Length of the longest key inserted so far.
pub max_word_len: usize,
// fast_bins: free-list head per exact block size (index = slots - 1).
// large_list: (first slot, slot count) of free blocks larger than FREE_LIST_MAX_SLOTS.
// frag_size: bytes currently held in the free lists.
// lazy_free_list: blocks queued by `free_node_deferred` until `reclaim_lazy_frees` runs.
fast_bins: Vec<u32>, large_list: Vec<(u32, u32)>, frag_size: usize, lazy_free_list: Vec<LazyFreeItem>,
}
impl CsppTrie {
pub fn new(valsize: usize) -> Self {
let mut trie = Self {
mempool: Vec::new(),
n_words: 0,
n_nodes: 1, valsize,
max_word_len: 0,
fast_bins: vec![FREE_LIST_NIL; FREE_LIST_MAX_SLOTS],
large_list: Vec::new(),
frag_size: 0,
lazy_free_list: Vec::new(),
};
trie.init_root();
trie
}
fn init_root(&mut self) {
let val_slots = (self.valsize + 3) / 4;
let root_slots = 2 + 256 + val_slots;
self.mempool.resize(root_slots, PatriciaNode::empty());
self.mempool[0].meta = MetaInfo {
flags: 15, n_zpath_len: 0,
c_label: [0, 0],
};
unsafe {
let meta_ptr = &mut self.mempool[0].meta as *mut MetaInfo as *mut u8;
std::ptr::write_unaligned(meta_ptr.add(2) as *mut u16, 256);
}
self.mempool[1].big = BigCount {
_unused: 0,
n_children: 0,
};
}
#[inline]
pub fn node_view(&self, pos: u32) -> NodeView {
NodeView::new(&self.mempool, pos)
}
/// Returns the current length of the slot pool (in slots).
#[inline]
pub fn total_states(&self) -> usize {
    Vec::len(&self.mempool)
}
#[inline]
pub fn num_words(&self) -> usize {
self.n_words
}
#[inline]
pub fn get_value<T: Copy>(&self, valpos: usize) -> T {
debug_assert!(valpos + std::mem::size_of::<T>() <= self.mempool.len() * 4);
unsafe {
let ptr = self.mempool.as_ptr() as *const u8;
std::ptr::read_unaligned(ptr.add(valpos) as *const T)
}
}
/// Looks up `key` and returns the byte offset of its value, or `None` when
/// the key is not stored.
///
/// At each node the compressed path (if any) must match the next key bytes
/// in full; the walk then follows the transition for the following byte.
pub fn lookup(&self, key: &[u8]) -> Option<usize> {
    let mut state = INITIAL_STATE;
    let mut pos = 0usize;
    loop {
        let view = self.node_view(state);
        let zlen = view.zpath_len();
        if zlen > 0 {
            // A key that ends inside the zpath cannot be present.
            let segment = key.get(pos..pos + zlen)?;
            if segment != view.zpath_slice() {
                return None;
            }
            pos += zlen;
        }
        if pos == key.len() {
            return if view.is_final() {
                Some(view.valpos())
            } else {
                None
            };
        }
        state = match view.state_move(key[pos]) {
            NIL_STATE => return None,
            next => next,
        };
        pos += 1;
    }
}
/// Returns `true` when `key` is stored in the trie.
pub fn contains(&self, key: &[u8]) -> bool {
    matches!(self.lookup(key), Some(_))
}
/// Allocates a block of at least `byte_size` bytes and returns its first slot.
///
/// Small requests are served from the exact-size fast bin, large ones from
/// the first fitting block on the large list (the unused tail is freed
/// again). When nothing fits, the pool is grown at its end.
fn alloc_node(&mut self, byte_size: usize) -> u32 {
    let slots = (byte_size + 3) / 4;
    if (1..=FREE_LIST_MAX_SLOTS).contains(&slots) {
        let bin = slots - 1;
        let head = self.fast_bins[bin];
        if head != FREE_LIST_NIL {
            // A free block stores the next list entry in its first slot.
            self.fast_bins[bin] = unsafe { self.mempool[head as usize].child };
            self.frag_size -= slots * ALIGN_SIZE;
            return head;
        }
    } else if slots > FREE_LIST_MAX_SLOTS {
        let fit = self
            .large_list
            .iter()
            .position(|&(_, avail)| avail as usize >= slots);
        if let Some(idx) = fit {
            let (start, avail) = self.large_list.swap_remove(idx);
            let avail = avail as usize;
            self.frag_size -= avail * ALIGN_SIZE;
            let spare = avail - slots;
            if spare > 0 {
                self.free_node(start + slots as u32, spare * ALIGN_SIZE);
            }
            return start;
        }
    }
    let old_len = self.mempool.len();
    self.mempool.resize(old_len + slots, PatriciaNode::empty());
    old_len as u32
}
/// Releases the block of `byte_size` bytes that starts at `slot`.
///
/// A block at the very end of the pool shrinks the pool; any other block is
/// pushed onto the matching fast bin or the large list and counted in
/// `frag_size`.
fn free_node(&mut self, slot: u32, byte_size: usize) {
    let slots = (byte_size + 3) / 4;
    if slots == 0 {
        return;
    }
    let start = slot as usize;
    if start + slots == self.mempool.len() {
        self.mempool.truncate(start);
        return;
    }
    if slots > FREE_LIST_MAX_SLOTS {
        self.large_list.push((slot, slots as u32));
    } else {
        let bin = slots - 1;
        let previous_head = self.fast_bins[bin];
        // SAFETY: not checked here; assumes `slot` addresses a block inside the pool.
        unsafe {
            (*self.mempool.as_mut_ptr().add(start)).child = previous_head;
        }
        self.fast_bins[bin] = slot;
    }
    self.frag_size += slots * ALIGN_SIZE;
}
pub fn free_node_deferred_pub(&mut self, slot: u32, byte_size: usize) {
self.free_node_deferred(slot, byte_size);
}
/// Records the block on the lazy-free list instead of freeing it now.
fn free_node_deferred(&mut self, slot: u32, byte_size: usize) {
    let slot_count = (byte_size + 3) / 4;
    let item = LazyFreeItem {
        slot,
        slots: slot_count as u32,
    };
    self.lazy_free_list.push(item);
}
pub fn reclaim_lazy_frees(&mut self) {
let items: Vec<_> = self.lazy_free_list.drain(..).collect();
for item in items {
self.free_node(item.slot, item.slots as usize * ALIGN_SIZE);
}
}
/// Resizes the block at `old_slot` from `old_size` to `new_size` bytes and
/// returns the (possibly different) first slot.
///
/// Nothing happens when the slot count is unchanged; a block at the end of
/// the pool is resized in place; otherwise a new block is allocated, the
/// common prefix is copied over and the old block is freed.
fn realloc_node(&mut self, old_slot: u32, old_size: usize, new_size: usize) -> u32 {
    let had = (old_size + 3) / 4;
    let want = (new_size + 3) / 4;
    if had == want {
        return old_slot;
    }
    let start = old_slot as usize;
    if start + had == self.mempool.len() {
        self.mempool.resize(start + want, PatriciaNode::empty());
        return old_slot;
    }
    let new_slot = self.alloc_node(new_size);
    let keep = std::cmp::min(had, want);
    // SAFETY: not checked here; the old block is still live and the new one
    // was just allocated, so the two ranges are assumed to be distinct.
    // Pointers are taken after `alloc_node` because it may grow the pool.
    unsafe {
        let from = self.mempool.as_ptr().add(start);
        let to = self.mempool.as_mut_ptr().add(new_slot as usize);
        std::ptr::copy_nonoverlapping(from, to, keep);
    }
    self.free_node(old_slot, old_size);
    new_slot
}
/// Collects allocator statistics by walking every fast-bin list and summing
/// the large and lazy-free lists.
pub fn mem_get_stat(&self) -> MemStat {
    // Length of the free list that starts at `head`.
    let chain_len = |mut head: u32| -> usize {
        let mut len = 0;
        while head != FREE_LIST_NIL {
            len += 1;
            head = unsafe { self.mempool[head as usize].child };
        }
        len
    };
    let fastbin: Vec<usize> = self.fast_bins[..FREE_LIST_MAX_SLOTS]
        .iter()
        .map(|&head| chain_len(head))
        .collect();
    let large_size = self
        .large_list
        .iter()
        .fold(0usize, |acc, &(_, slots)| acc + slots as usize * ALIGN_SIZE);
    let lazy_free_sum = self
        .lazy_free_list
        .iter()
        .fold(0usize, |acc, item| acc + item.slots as usize * ALIGN_SIZE);
    MemStat {
        fastbin,
        used_size: self.mempool.len() * ALIGN_SIZE,
        capacity: self.mempool.capacity() * ALIGN_SIZE,
        frag_size: self.frag_size,
        large_size,
        large_cnt: self.large_list.len(),
        lazy_free_sum,
        lazy_free_cnt: self.lazy_free_list.len(),
    }
}
pub fn mem_frag_size(&self) -> usize {
self.frag_size
}
// Builds the node(s) that spell out `suffix` and end in a final leaf.
//
// While more than MAX_ZPATH bytes remain, a non-final single-child link node is
// emitted that stores MAX_ZPATH bytes as its zpath and the following byte as its
// transition label. The rest goes into a final leaf with room for one value.
// Returns (first node of the chain, byte offset of the leaf's value).
fn new_suffix_chain(&mut self, suffix: &[u8]) -> (u32, usize) {
let mut remaining = suffix;
let mut head = NIL_STATE;
// Slot holding the previous link's child pointer, patched once the next node exists.
let mut prev_child_slot: u32 = NIL_STATE;
while remaining.len() > MAX_ZPATH {
// Header slot + child slot + zpath bytes.
let link_size = ALIGN_SIZE * 2 + MAX_ZPATH; let node = self.alloc_node(link_size);
unsafe {
let p = self.mempool.as_mut_ptr().add(node as usize);
(*p).meta = MetaInfo {
flags: 1, n_zpath_len: MAX_ZPATH as u8,
c_label: [remaining[MAX_ZPATH], 0],
};
(*p.add(1)).child = NIL_STATE; let zpath_dst = p.add(2) as *mut u8;
std::ptr::copy_nonoverlapping(remaining.as_ptr(), zpath_dst, MAX_ZPATH);
// Zero the two padding bytes that round the zpath up to whole slots.
*zpath_dst.add(254) = 0;
*zpath_dst.add(255) = 0;
}
if head == NIL_STATE { head = node; }
if prev_child_slot != NIL_STATE {
unsafe { (*self.mempool.as_mut_ptr().add(prev_child_slot as usize)).child = node; }
}
// Consumed: MAX_ZPATH zpath bytes plus the one label byte.
prev_child_slot = node + 1; remaining = &remaining[MAX_ZPATH + 1..];
}
let zpath_padded = (remaining.len() + 3) & !3;
let leaf_size = ALIGN_SIZE + zpath_padded + self.valsize;
let node = self.alloc_node(leaf_size);
let valpos;
unsafe {
let p = self.mempool.as_mut_ptr().add(node as usize);
// Count type 0 (no children) with the final bit set.
(*p).meta = MetaInfo {
flags: 0x10, n_zpath_len: remaining.len() as u8,
c_label: [0, 0],
};
let zpath_dst = (p as *mut u8).add(ALIGN_SIZE);
std::ptr::copy_nonoverlapping(remaining.as_ptr(), zpath_dst, remaining.len());
for i in remaining.len()..zpath_padded {
*zpath_dst.add(i) = 0;
}
// The value follows the header slot and the padded zpath; it is not initialised here.
valpos = (node as usize + 1) * ALIGN_SIZE + zpath_padded;
}
if head == NIL_STATE { head = node; }
if prev_child_slot != NIL_STATE {
unsafe { (*self.mempool.as_mut_ptr().add(prev_child_slot as usize)).child = node; }
}
(head, valpos)
}
// Allocates and fills a bitmap node (count type 8) and returns its first slot.
//
// Layout written here: slot 0 header (child count in bytes 2..4), slot 1 four
// rank-prefix bytes (one per 64-bit bitmap word), slots 2..10 the 256-bit label
// bitmap, then `n_children` child slots, then `trailing_len` bytes copied from
// `trailing` (the caller's zpath/value bytes).
// `children[i]` is stored at child index i; callers are expected to pass labels
// and children in ascending label order so that index matches the bitmap rank.
fn build_bitmap_node(
&mut self, labels: &[u8], children: &[u32], n_children: usize,
flags: u8, zpath_len: usize, trailing: &[u8], trailing_len: usize,
) -> u32 {
let node_size = (10 + n_children) * ALIGN_SIZE + trailing_len;
let node = self.alloc_node(node_size);
unsafe {
let p = self.mempool.as_mut_ptr().add(node as usize);
// Keep the upper flag bits (e.g. final) and set the count type to 8.
let new_flags = (flags & !0x0F) | 8;
(*p).meta = MetaInfo {
flags: new_flags,
n_zpath_len: zpath_len as u8,
c_label: [0, 0],
};
// Overwrites the c_label bytes with the 16-bit child count.
std::ptr::write_unaligned((p as *mut u8).add(2) as *mut u16, n_children as u16);
let bmp = p.add(2) as *mut u8;
std::ptr::write_bytes(bmp, 0, 32);
for i in 0..n_children {
let label = labels[i];
*bmp.add(label as usize / 8) |= 1 << (label % 8);
}
let rank = p.add(1) as *mut u8;
// rank[q] = number of set bits in bitmap words 0..q.
let mut cumulative = 0u32;
for q in 0..4 {
*rank.add(q) = cumulative as u8;
let w = std::ptr::read_unaligned(bmp.add(q * 8) as *const u64);
cumulative += w.count_ones();
}
for i in 0..n_children {
(*p.add(10 + i)).child = children[i];
}
if trailing_len > 0 {
let dst = (p as *mut u8).add((10 + n_children) * ALIGN_SIZE);
std::ptr::copy_nonoverlapping(trailing.as_ptr(), dst, trailing_len);
}
}
node
}
/// Builds a copy of the bitmap node (count type 8) at `curr` with one extra
/// transition `ch -> suffix_node` and returns the new node's first slot.
///
/// The old node is left untouched; the caller relinks the parent and frees
/// it. The trailing zpath/value bytes are carried over unchanged.
///
/// The scratch buffer for the trailing bytes lives on the stack when they
/// fit in 512 bytes and on the heap otherwise, so large `valsize` values
/// cannot overflow it.
fn add_state_move_bitmap(&mut self, curr: u32, ch: u8, suffix_node: u32) -> u32 {
    let meta = unsafe { self.mempool[curr as usize].meta };
    let zpath_len = meta.n_zpath_len as usize;
    let is_final = meta.flags & 0x10 != 0;
    let old_n = unsafe { self.mempool[curr as usize].big }.n_children as usize;

    // Snapshot the bitmap (slots 2..10) and the rank prefixes (slot 1).
    let mut bitmap = [0u8; 32];
    let mut rank_prefix = [0u8; 4];
    unsafe {
        let bmp_src = self.mempool.as_ptr().add(curr as usize + 2) as *const u8;
        std::ptr::copy_nonoverlapping(bmp_src, bitmap.as_mut_ptr(), 32);
        let rank_src = self.mempool.as_ptr().add(curr as usize + 1) as *const u8;
        std::ptr::copy_nonoverlapping(rank_src, rank_prefix.as_mut_ptr(), 4);
    }
    let mut old_children = [0u32; 257];
    for i in 0..old_n {
        old_children[i] = unsafe { self.mempool[curr as usize + 10 + i].child };
    }

    // Snapshot the trailing bytes (padded zpath plus value, if final).
    let zpath_padded = (zpath_len + 3) & !3;
    let trailing_len = zpath_padded + if is_final { self.valsize } else { 0 };
    let mut inline_buf = [0u8; 512];
    let mut spill_buf: Vec<u8> = Vec::new();
    let trailing: &mut [u8] = if trailing_len <= inline_buf.len() {
        &mut inline_buf[..trailing_len]
    } else {
        spill_buf.resize(trailing_len, 0);
        &mut spill_buf[..]
    };
    if trailing_len > 0 {
        let off = (10 + old_n) * ALIGN_SIZE;
        // SAFETY: `trailing` is exactly `trailing_len` bytes long; the source
        // is assumed to lie inside the node at `curr`.
        unsafe {
            let src = (self.mempool.as_ptr().add(curr as usize) as *const u8).add(off);
            std::ptr::copy_nonoverlapping(src, trailing.as_mut_ptr(), trailing_len);
        }
    }
    let trailing: &[u8] = trailing;

    // Rank of `ch` among the existing labels = insertion index of the new child.
    let ch_rank = {
        let q = (ch / 64) as usize;
        let w = unsafe { std::ptr::read_unaligned(bitmap.as_ptr().add(q * 8) as *const u64) };
        let mask = (1u64 << (ch % 64)) - 1;
        rank_prefix[q] as usize + (w & mask).count_ones() as usize
    };
    bitmap[(ch / 8) as usize] |= 1 << (ch % 8);
    // Recompute the rank prefixes for the updated bitmap.
    let mut cumulative = 0u32;
    for q in 0..4 {
        rank_prefix[q] = cumulative as u8;
        let w = unsafe { std::ptr::read_unaligned(bitmap.as_ptr().add(q * 8) as *const u64) };
        cumulative += w.count_ones();
    }
    // Make room for the new child at its rank.
    old_children.copy_within(ch_rank..old_n, ch_rank + 1);
    old_children[ch_rank] = suffix_node;
    let new_n = old_n + 1;

    let node_size = (10 + new_n) * ALIGN_SIZE + trailing_len;
    let node = self.alloc_node(node_size);
    // SAFETY: `node` was just allocated with `node_size` bytes; pointers are
    // taken after `alloc_node` because it may grow the pool.
    unsafe {
        let p = self.mempool.as_mut_ptr().add(node as usize);
        (*p).meta = MetaInfo {
            flags: meta.flags,
            n_zpath_len: zpath_len as u8,
            c_label: [0, 0],
        };
        std::ptr::write_unaligned((p as *mut u8).add(2) as *mut u16, new_n as u16);
        let rank_dst = p.add(1) as *mut u8;
        std::ptr::copy_nonoverlapping(rank_prefix.as_ptr(), rank_dst, 4);
        let bmp_dst = p.add(2) as *mut u8;
        std::ptr::copy_nonoverlapping(bitmap.as_ptr(), bmp_dst, 32);
        for i in 0..new_n {
            (*p.add(10 + i)).child = old_children[i];
        }
        if trailing_len > 0 {
            let dst = (p as *mut u8).add((10 + new_n) * ALIGN_SIZE);
            std::ptr::copy_nonoverlapping(trailing.as_ptr(), dst, trailing_len);
        }
    }
    node
}
/// Builds a copy of the node at `curr` with one extra transition
/// `ch -> suffix_node`, kept in ascending label order, and returns the new
/// node's first slot.
///
/// The node is promoted to the next layout as it grows: inline labels
/// (count types 1..=6), the 16-label array (7) and finally a bitmap node
/// (8). Bitmap nodes are delegated to `add_state_move_bitmap`. The old node
/// is left untouched; the caller relinks the parent and frees it.
///
/// The scratch buffer for the trailing zpath/value bytes lives on the stack
/// when they fit in 512 bytes and on the heap otherwise, so large `valsize`
/// values cannot overflow it.
///
/// # Panics
///
/// Panics (`unreachable!`) for count types other than 0..=8; the caller
/// handles the direct-table root (15) itself.
fn add_state_move(&mut self, curr: u32, ch: u8, suffix_node: u32) -> u32 {
    let meta = unsafe { self.mempool[curr as usize].meta };
    let cnt_type = meta.flags & 0x0F;
    if cnt_type == 8 {
        return self.add_state_move_bitmap(curr, ch, suffix_node);
    }
    let zpath_len = meta.n_zpath_len as usize;
    let is_final = meta.flags & 0x10 != 0;
    let old_skip = SKIP_SLOTS[cnt_type as usize] as usize;
    let old_n: usize = if cnt_type <= 6 {
        cnt_type as usize
    } else {
        unsafe { self.mempool[curr as usize].big }.n_children as usize
    };

    // Gather the existing labels from wherever this layout keeps them.
    let mut labels = [0u8; 17];
    match cnt_type {
        0 => {}
        1 | 2 => {
            labels[0] = meta.c_label[0];
            if cnt_type >= 2 {
                labels[1] = meta.c_label[1];
            }
        }
        3..=6 => {
            labels[0] = meta.c_label[0];
            labels[1] = meta.c_label[1];
            let pad = unsafe { self.mempool[curr as usize + 1].bytes };
            for i in 2..old_n {
                labels[i] = pad[i - 2];
            }
        }
        7 => unsafe {
            let src = self.mempool.as_ptr().add(curr as usize + 1) as *const u8;
            for i in 0..old_n {
                labels[i] = *src.add(i);
            }
        },
        _ => unreachable!(),
    }
    let mut children = [0u32; 17];
    for i in 0..old_n {
        children[i] = unsafe { self.mempool[curr as usize + old_skip + i].child };
    }

    // Snapshot the trailing bytes (padded zpath plus value, if final).
    let zpath_padded = (zpath_len + 3) & !3;
    let trailing_len = zpath_padded + if is_final { self.valsize } else { 0 };
    let mut inline_buf = [0u8; 512];
    let mut spill_buf: Vec<u8> = Vec::new();
    let trailing: &mut [u8] = if trailing_len <= inline_buf.len() {
        &mut inline_buf[..trailing_len]
    } else {
        spill_buf.resize(trailing_len, 0);
        &mut spill_buf[..]
    };
    if trailing_len > 0 {
        let trailing_start = (old_skip + old_n) * ALIGN_SIZE;
        // SAFETY: `trailing` is exactly `trailing_len` bytes long; the source
        // is assumed to lie inside the node at `curr`.
        unsafe {
            let src = (self.mempool.as_ptr().add(curr as usize) as *const u8).add(trailing_start);
            std::ptr::copy_nonoverlapping(src, trailing.as_mut_ptr(), trailing_len);
        }
    }
    let trailing: &[u8] = trailing;

    // Insert the new label/child pair at its sorted position.
    let idx = labels[..old_n].partition_point(|&l| l < ch);
    labels.copy_within(idx..old_n, idx + 1);
    children.copy_within(idx..old_n, idx + 1);
    labels[idx] = ch;
    children[idx] = suffix_node;
    let new_n = old_n + 1;

    let new_cnt_type: u8 = match cnt_type {
        0..=5 => cnt_type + 1,
        6 => 7,
        7 if old_n < 16 => 7,
        // The 16-label array is full: switch to the bitmap layout.
        7 => 8,
        _ => unreachable!(),
    };
    if new_cnt_type == 8 {
        return self.build_bitmap_node(
            &labels, &children, new_n,
            meta.flags, zpath_len, trailing, trailing_len,
        );
    }

    let new_skip = SKIP_SLOTS[new_cnt_type as usize] as usize;
    let new_size = (new_skip + new_n) * ALIGN_SIZE + trailing_len;
    let node = self.alloc_node(new_size);
    // SAFETY: `node` was just allocated with `new_size` bytes; pointers are
    // taken after `alloc_node` because it may grow the pool.
    unsafe {
        let p = self.mempool.as_mut_ptr().add(node as usize);
        let new_flags = (meta.flags & !0x0F) | new_cnt_type;
        match new_cnt_type {
            1 | 2 => {
                (*p).meta = MetaInfo {
                    flags: new_flags,
                    n_zpath_len: zpath_len as u8,
                    c_label: [labels[0], if new_cnt_type >= 2 { labels[1] } else { 0 }],
                };
            }
            3..=6 => {
                (*p).meta = MetaInfo {
                    flags: new_flags,
                    n_zpath_len: zpath_len as u8,
                    c_label: [labels[0], labels[1]],
                };
                // Labels 2.. go into the slot after the header, zero-padded.
                let pad_ptr = p.add(1) as *mut u8;
                for i in 2..new_n {
                    *pad_ptr.add(i - 2) = labels[i];
                }
                for i in (new_n - 2)..4 {
                    *pad_ptr.add(i) = 0;
                }
            }
            7 => {
                (*p).meta = MetaInfo {
                    flags: new_flags,
                    n_zpath_len: zpath_len as u8,
                    c_label: [0, 0],
                };
                // The child count replaces the inline labels in bytes 2..4.
                std::ptr::write_unaligned((p as *mut u8).add(2) as *mut u16, new_n as u16);
                let lbl_ptr = p.add(1) as *mut u8;
                for i in 0..new_n {
                    *lbl_ptr.add(i) = labels[i];
                }
                for i in new_n..16 {
                    *lbl_ptr.add(i) = 0;
                }
            }
            _ => unreachable!(),
        }
        for i in 0..new_n {
            (*p.add(new_skip + i)).child = children[i];
        }
        if trailing_len > 0 {
            let dst = (p as *mut u8).add((new_skip + new_n) * ALIGN_SIZE);
            std::ptr::copy_nonoverlapping(trailing.as_ptr(), dst, trailing_len);
        }
    }
    node
}
// Splits the node at `curr` where its zpath diverges from the key being inserted.
//
// `zidx` is the index of the first differing zpath byte. A "suffix" node is
// created that keeps the old node's header, children and value but only the
// zpath bytes after `zidx`; a new non-final two-child parent keeps the zpath
// bytes before `zidx` and branches on the old byte (to the suffix node) and on
// `new_char` (to `new_suffix_node`), lower label first.
// Returns (parent, suffix node). The node at `curr` itself is not freed here.
fn fork(
&mut self, curr: u32, zidx: usize,
old_skip: usize, old_n_children: usize, zpath_len: usize,
node_size: usize, zpath_buf: &[u8],
new_char: u8, new_suffix_node: u32,
) -> (u32, u32) {
let old_char = zpath_buf[zidx];
let suffix_zlen = zpath_len - zidx - 1;
let suffix_zpath_padded = (suffix_zlen + 3) & !3;
// Whatever follows the padded zpath in the old node is its value.
let val_size = node_size - ((old_skip + old_n_children) * ALIGN_SIZE + ((zpath_len + 3) & !3));
let suffix_size = (old_skip + old_n_children) * ALIGN_SIZE + suffix_zpath_padded + val_size;
let suffix_node = self.alloc_node(suffix_size);
unsafe {
// Pointers are taken after alloc_node, which may grow the pool.
let base = self.mempool.as_mut_ptr();
let src = base.add(curr as usize) as *const u8;
let dst = base.add(suffix_node as usize) as *mut u8;
let struct_size = (old_skip + old_n_children) * ALIGN_SIZE;
// Header slots and child slots are copied verbatim.
std::ptr::copy_nonoverlapping(src, dst, struct_size);
(*base.add(suffix_node as usize)).meta.n_zpath_len = suffix_zlen as u8;
let zpath_dst = dst.add(struct_size);
for i in 0..suffix_zlen {
*zpath_dst.add(i) = zpath_buf[zidx + 1 + i];
}
for i in suffix_zlen..suffix_zpath_padded {
*zpath_dst.add(i) = 0;
}
if val_size > 0 {
let old_val_off = struct_size + ((zpath_len + 3) & !3);
std::ptr::copy_nonoverlapping(
src.add(old_val_off),
zpath_dst.add(suffix_zpath_padded),
val_size,
);
}
}
let prefix_zpath_padded = (zidx + 3) & !3;
// Header slot + two child slots + padded prefix zpath.
let parent_size = 3 * ALIGN_SIZE + prefix_zpath_padded; let parent = self.alloc_node(parent_size);
unsafe {
let base = self.mempool.as_mut_ptr();
let p = base.add(parent as usize);
let (label0, child0, label1, child1) = if old_char < new_char {
(old_char, suffix_node, new_char, new_suffix_node)
} else {
(new_char, new_suffix_node, old_char, suffix_node)
};
(*p).meta = MetaInfo {
flags: 2, n_zpath_len: zidx as u8,
c_label: [label0, label1],
};
(*p.add(1)).child = child0;
(*p.add(2)).child = child1;
let zpath_dst = (p as *mut u8).add(3 * ALIGN_SIZE);
for i in 0..zidx {
*zpath_dst.add(i) = zpath_buf[i];
}
for i in zidx..prefix_zpath_padded {
*zpath_dst.add(i) = 0;
}
}
(parent, suffix_node)
}
// Splits the node at `curr` because the inserted key ends inside its zpath.
//
// A "suffix" node keeps the old header, children and value with only the zpath
// bytes after `split_pos`; a new final single-child "prefix" node keeps the
// first `split_pos` zpath bytes, has room for a value, and links to the suffix
// node via the byte at `split_pos`.
// Returns (prefix node, byte offset of the prefix node's value). The node at
// `curr` itself is not freed here.
fn split_zpath(
&mut self, curr: u32,
split_pos: usize,
old_skip: usize, old_n_children: usize, zpath_len: usize,
node_size: usize, zpath_buf: &[u8],
) -> (u32, usize) {
let split_char = zpath_buf[split_pos];
let suffix_zlen = zpath_len - split_pos - 1;
let suffix_zpath_padded = (suffix_zlen + 3) & !3;
// Whatever follows the padded zpath in the old node is its value.
let val_size = node_size - ((old_skip + old_n_children) * ALIGN_SIZE + ((zpath_len + 3) & !3));
let suffix_size = (old_skip + old_n_children) * ALIGN_SIZE + suffix_zpath_padded + val_size;
let suffix_node = self.alloc_node(suffix_size);
unsafe {
// Pointers are taken after alloc_node, which may grow the pool.
let base = self.mempool.as_mut_ptr();
let src = base.add(curr as usize) as *const u8;
let dst = base.add(suffix_node as usize) as *mut u8;
let struct_size = (old_skip + old_n_children) * ALIGN_SIZE;
// Header slots and child slots are copied verbatim.
std::ptr::copy_nonoverlapping(src, dst, struct_size);
(*base.add(suffix_node as usize)).meta.n_zpath_len = suffix_zlen as u8;
let zpath_dst = dst.add(struct_size);
for i in 0..suffix_zlen {
*zpath_dst.add(i) = zpath_buf[split_pos + 1 + i];
}
for i in suffix_zlen..suffix_zpath_padded {
*zpath_dst.add(i) = 0;
}
if val_size > 0 {
let old_val_off = struct_size + ((zpath_len + 3) & !3);
std::ptr::copy_nonoverlapping(
src.add(old_val_off),
zpath_dst.add(suffix_zpath_padded),
val_size,
);
}
}
let prefix_zpath_padded = (split_pos + 3) & !3;
// Header slot + one child slot + padded prefix zpath + value.
let prefix_size = 2 * ALIGN_SIZE + prefix_zpath_padded + self.valsize;
let prefix_node = self.alloc_node(prefix_size);
let valpos;
unsafe {
let base = self.mempool.as_mut_ptr();
let p = base.add(prefix_node as usize);
// Count type 1 with the final bit set.
(*p).meta = MetaInfo {
flags: 1 | 0x10, n_zpath_len: split_pos as u8,
c_label: [split_char, 0],
};
(*p.add(1)).child = suffix_node;
let zpath_dst = (p as *mut u8).add(2 * ALIGN_SIZE);
for i in 0..split_pos {
*zpath_dst.add(i) = zpath_buf[i];
}
for i in split_pos..prefix_zpath_padded {
*zpath_dst.add(i) = 0;
}
// The value area is not initialised here.
valpos = (prefix_node as usize + 2) * ALIGN_SIZE + prefix_zpath_padded;
}
(prefix_node, valpos)
}
/// Returns the index of the pool slot that stores the child pointer for
/// transition `ch` of the node at `curr`, or [`NIL_STATE`] when the node has
/// no such transition.
///
/// For the direct 256-entry table (count type 15) the slot is returned even
/// when the entry is empty.
fn find_child_slot(&self, curr: u32, ch: u8) -> u32 {
    let view = self.node_view(curr);
    match view.cnt_type() {
        1 => {
            if view.meta().c_label[0] == ch {
                curr + 1
            } else {
                NIL_STATE
            }
        }
        2 => {
            let [first, second] = view.meta().c_label;
            if ch == first {
                curr + 1
            } else if ch == second {
                curr + 2
            } else {
                NIL_STATE
            }
        }
        n @ 3..=6 => (0..n as usize)
            .find(|&i| view.get_label(i) == ch)
            .map_or(NIL_STATE, |i| curr + 2 + i as u32),
        7 => {
            let n = view.n_children();
            // SAFETY: not checked here; assumes the node owns the four label
            // slots that follow its header.
            let labels = unsafe {
                let start = self.mempool.as_ptr().add(curr as usize + 1) as *const u8;
                std::slice::from_raw_parts(start, 16)
            };
            let hit = crate::fsa::fast_search::fast_search_byte_max_16(&labels[..n], ch);
            if hit < n {
                curr + 5 + hit as u32
            } else {
                NIL_STATE
            }
        }
        8 => {
            // SAFETY: not checked here; assumes the node owns a 32-byte bitmap
            // starting two slots after its header.
            let bitmap = unsafe {
                let start = self.mempool.as_ptr().add(curr as usize + 2) as *const u8;
                std::slice::from_raw_parts(start, 32)
            };
            if bitmap[(ch / 8) as usize] & (1 << (ch % 8)) == 0 {
                return NIL_STATE;
            }
            // Slot 1 holds the rank prefixes; the bitmap words follow it.
            let rank_ptr = unsafe { self.mempool.as_ptr().add(curr as usize + 1) as *const u8 };
            let word_idx = (ch / 64) as usize;
            let word = unsafe {
                std::ptr::read_unaligned(rank_ptr.add(4 + word_idx * 8) as *const u64)
            };
            let preceding = unsafe { *rank_ptr.add(word_idx) } as usize;
            let below = (1u64 << (ch % 64)) - 1;
            let rank = preceding + (word & below).count_ones() as usize;
            curr + 10 + rank as u32
        }
        15 => curr + 2 + ch as u32,
        _ => NIL_STATE,
    }
}
/// Inserts `key` and returns `(inserted, valpos)`.
///
/// `inserted` is `false` when the key was already present. `valpos` is the
/// byte offset in the pool at which the key's value is stored; this function
/// does not write the value itself.
///
/// Nodes are never modified in place when they have to grow (the root's
/// direct table is the exception): a replacement node is built, the parent's
/// child slot is redirected to it and the old node is freed.
pub fn insert(&mut self, key: &[u8]) -> (bool, usize) {
// `curr_slot` is the pool slot holding the pointer to `curr` (NIL_STATE for the root).
let mut curr_slot: u32 = NIL_STATE; let mut curr: u32 = INITIAL_STATE;
let mut pos: usize = 0;
loop {
let (cnt_type, zpath_len, is_final, skip, n_children, flags) = {
let view = self.node_view(curr);
(view.cnt_type(), view.zpath_len(), view.is_final(),
view.skip_slots(), view.n_children(), view.meta().flags)
};
// Total byte size of the current node: header + children, padded zpath, value if final.
let node_size = (skip + n_children) * ALIGN_SIZE
+ ((zpath_len + 3) & !3)
+ if is_final { self.valsize } else { 0 };
if zpath_len > 0 {
// Copy the zpath out of the pool: the pool may be reallocated below.
let mut zpath_buf = [0u8; 256];
let zpath_off = (skip + n_children) * ALIGN_SIZE;
unsafe {
let src = (self.mempool.as_ptr().add(curr as usize) as *const u8).add(zpath_off);
std::ptr::copy_nonoverlapping(src, zpath_buf.as_mut_ptr(), zpath_len);
}
let remaining_key = key.len() - pos;
let match_len = std::cmp::min(zpath_len, remaining_key);
let mut mismatch_at: Option<usize> = None;
for i in 0..match_len {
if key[pos + i] != zpath_buf[i] {
mismatch_at = Some(i);
break;
}
}
// Case 1: the key diverges inside the zpath -> fork into a two-child parent.
if let Some(zidx) = mismatch_at {
let (new_suffix, valpos) = self.new_suffix_chain(&key[pos + zidx + 1..]);
let (new_parent, _old_suffix) = self.fork(
curr, zidx, skip, n_children, zpath_len, node_size,
&zpath_buf[..zpath_len], key[pos + zidx], new_suffix,
);
if curr_slot != NIL_STATE {
unsafe { (*self.mempool.as_mut_ptr().add(curr_slot as usize)).child = new_parent; }
}
self.free_node(curr, node_size);
self.n_words += 1;
if key.len() > self.max_word_len { self.max_word_len = key.len(); }
return (true, valpos);
}
pos += match_len;
// Case 2: the key ends inside the zpath -> split off a final prefix node.
if remaining_key < zpath_len {
let (prefix_node, valpos) = self.split_zpath(
curr, match_len, skip, n_children, zpath_len, node_size,
&zpath_buf[..zpath_len],
);
if curr_slot != NIL_STATE {
unsafe { (*self.mempool.as_mut_ptr().add(curr_slot as usize)).child = prefix_node; }
}
self.free_node(curr, node_size);
self.n_words += 1;
if key.len() > self.max_word_len { self.max_word_len = key.len(); }
return (true, valpos);
}
// Case 3: the key ends exactly at the end of the zpath.
if pos == key.len() {
if is_final {
let vp = (curr as usize + skip + n_children) * ALIGN_SIZE + ((zpath_len + 3) & !3);
return (false, vp);
}
// Grow the node by one value and mark it final.
let old_size = node_size;
let new_size = old_size + self.valsize;
let new_curr = self.realloc_node(curr, old_size, new_size);
unsafe {
(*self.mempool.as_mut_ptr().add(new_curr as usize)).meta.flags |= 0x10;
}
if curr_slot != NIL_STATE && new_curr != curr {
unsafe { (*self.mempool.as_mut_ptr().add(curr_slot as usize)).child = new_curr; }
}
let vp = (new_curr as usize + skip + n_children) * ALIGN_SIZE + ((zpath_len + 3) & !3);
self.n_words += 1;
if key.len() > self.max_word_len { self.max_word_len = key.len(); }
return (true, vp);
}
} else {
// No zpath: the key may end at this node.
if pos == key.len() {
if is_final {
let vp = (curr as usize + skip + n_children) * ALIGN_SIZE;
return (false, vp);
}
// The direct-table node already has value space reserved (see init_root).
if cnt_type == 15 {
unsafe {
(*self.mempool.as_mut_ptr().add(curr as usize)).meta.flags |= 0x10;
}
let vp = (curr as usize + 2 + 256) * ALIGN_SIZE;
self.n_words += 1;
if key.len() > self.max_word_len { self.max_word_len = key.len(); }
return (true, vp);
}
let old_size = node_size;
let new_size = old_size + self.valsize;
let new_curr = self.realloc_node(curr, old_size, new_size);
unsafe {
(*self.mempool.as_mut_ptr().add(new_curr as usize)).meta.flags |= 0x10;
}
if curr_slot != NIL_STATE && new_curr != curr {
unsafe { (*self.mempool.as_mut_ptr().add(curr_slot as usize)).child = new_curr; }
}
let vp = (new_curr as usize + skip + n_children) * ALIGN_SIZE;
self.n_words += 1;
if key.len() > self.max_word_len { self.max_word_len = key.len(); }
return (true, vp);
}
}
// Key bytes remain: follow, or create, the transition for the next byte.
let ch = key[pos];
let next = self.node_view(curr).state_move(ch);
if next == NIL_STATE {
let (suffix_node, valpos) = self.new_suffix_chain(&key[pos + 1..]);
if cnt_type != 15 {
// Rebuild the node with the extra child, relink the parent, free the old node.
let new_curr = self.add_state_move(curr, ch, suffix_node);
if curr_slot != NIL_STATE {
unsafe { (*self.mempool.as_mut_ptr().add(curr_slot as usize)).child = new_curr; }
}
self.free_node(curr, node_size);
} else {
// Direct table: fill the entry in place and bump the occupancy count in slot 1.
unsafe {
(*self.mempool.as_mut_ptr().add(curr as usize + 2 + ch as usize)).child = suffix_node;
let real_cnt = &mut (*self.mempool.as_mut_ptr().add(curr as usize + 1)).big;
(*real_cnt).n_children += 1;
}
}
self.n_words += 1;
if key.len() > self.max_word_len { self.max_word_len = key.len(); }
return (true, valpos);
}
// Remember which slot points at the next node so it can be relinked later.
curr_slot = self.find_child_slot(curr, ch);
curr = next;
pos += 1;
}
}
}
/// One frame of the iterator's depth-first traversal stack.
pub struct IterEntry {
/// First slot of the node this frame refers to.
pub state: u32,
/// Index of the next child to visit.
pub child_idx: usize,
/// Child count of the node, captured when the frame was created.
pub n_children: usize,
/// Whether the node's zpath has already been appended to the current word.
pub zpath_consumed: bool,
}
/// Depth-first iterator over the keys of a [`CsppTrie`]; `T` is the type values are read as.
pub struct CsppTrieIterator<'a, T> {
/// The trie being traversed.
trie: &'a CsppTrie,
/// Path of frames from the root to the current node.
stack: Vec<IterEntry>,
/// Key bytes accumulated along the current path.
word: Vec<u8>,
/// Ties the value type `T` to the iterator without storing one.
_marker: std::marker::PhantomData<T>,
}
impl<'a, T: Copy> CsppTrieIterator<'a, T> {
pub fn new(trie: &'a CsppTrie) -> Self {
Self {
trie,
stack: Vec::with_capacity(32),
word: Vec::with_capacity(32),
_marker: std::marker::PhantomData,
}
}
/// Resets the iterator and positions it on the first key.
///
/// Returns `false` when no key is found.
pub fn seek_begin(&mut self) -> bool {
    self.stack.clear();
    self.word.clear();
    let root_children = self.trie.node_view(INITIAL_STATE).n_children();
    let root = IterEntry {
        state: INITIAL_STATE,
        child_idx: 0,
        n_children: root_children,
        zpath_consumed: false,
    };
    self.stack.push(root);
    self.descend_leftmost()
}
/// Walks down from the top frame, always taking the next unvisited child,
/// until a final node is reached.
///
/// Returns `true` with the final node on top of the stack, or whatever
/// `incr` returns once a node's children are exhausted.
fn descend_leftmost(&mut self) -> bool {
    while let Some(mut top) = self.stack.pop() {
        let view = self.trie.node_view(top.state);
        if !top.zpath_consumed {
            if view.zpath_len() > 0 {
                self.word.extend_from_slice(view.zpath_slice());
            }
            top.zpath_consumed = true;
            if view.is_final() {
                self.stack.push(top);
                return true;
            }
        }
        if top.child_idx >= view.n_children() {
            self.stack.push(top);
            return self.incr();
        }
        // Locate the child at position `child_idx` in enumeration order.
        let wanted = top.child_idx;
        let mut seen = 0;
        let mut picked = None;
        view.for_each_child(|ch, child_state| {
            if seen == wanted {
                picked = Some((ch, child_state));
            }
            seen += 1;
        });
        top.child_idx += 1;
        self.stack.push(top);
        if let Some((ch, child_state)) = picked {
            self.word.push(ch);
            let child_count = self.trie.node_view(child_state).n_children();
            self.stack.push(IterEntry {
                state: child_state,
                child_idx: 0,
                n_children: child_count,
                zpath_consumed: false,
            });
        }
    }
    false
}
/// Advances to the next key.
///
/// Returns `false` (with an empty word) once the traversal is exhausted.
pub fn incr(&mut self) -> bool {
    while let Some(mut top) = self.stack.pop() {
        let view = self.trie.node_view(top.state);
        if top.child_idx >= view.n_children() {
            // This node is done: drop its label byte and zpath from the word
            // and continue with the parent, if any.
            if self.stack.is_empty() {
                self.word.clear();
                return false;
            }
            let drop_len = 1 + view.zpath_len();
            let keep = self.word.len().saturating_sub(drop_len);
            self.word.truncate(keep);
            continue;
        }
        // Locate the child at position `child_idx` in enumeration order.
        let wanted = top.child_idx;
        let mut seen = 0;
        let mut picked = None;
        view.for_each_child(|ch, child_state| {
            if seen == wanted {
                picked = Some((ch, child_state));
            }
            seen += 1;
        });
        top.child_idx += 1;
        self.stack.push(top);
        if let Some((ch, child_state)) = picked {
            self.word.push(ch);
            let child_count = self.trie.node_view(child_state).n_children();
            self.stack.push(IterEntry {
                state: child_state,
                child_idx: 0,
                n_children: child_count,
                zpath_consumed: false,
            });
            if self.descend_leftmost() {
                return true;
            }
        }
    }
    false
}
/// Returns the key bytes accumulated for the current position.
pub fn word(&self) -> &[u8] {
    self.word.as_slice()
}
/// Reads the value of the node on top of the stack as a `T`.
///
/// # Panics
///
/// Panics when the stack is empty (the iterator is not positioned).
pub fn value(&self) -> T {
    let state = self.stack.last().unwrap().state;
    let valpos = self.trie.node_view(state).valpos();
    self.trie.get_value(valpos)
}
}