use crate::errors::{CrawdadError, Result};
use crate::mapper::CodeMapper;
use crate::{utils, MpTrie, Node, Trie};
use crate::{END_CODE, END_MARKER, INVALID_IDX, MAX_VALUE, OFFSET_MASK};
use core::cmp::Ordering;
use alloc::vec::Vec;
// Default value of `Builder::num_free_blocks`: once the node array holds at least this many
// blocks, `enlarge` closes an older block before appending a new one.
const DEFAULT_NUM_FREE_BLOCKS: u32 = 16;
/// One input entry: a key and the value associated with it.
#[derive(Default)]
struct Record {
// Characters of the input key; `make_prefix_free` may append `END_MARKER` to it.
key: Vec<char>,
// Value supplied by the caller for this key.
value: u32,
}
/// Remaining part of a record that is stored outside the node array in minimal-prefix mode.
#[derive(Default, Debug, PartialEq, Eq)]
struct Suffix {
// Characters of the record key from the leaf's depth onwards, without a trailing `END_MARKER`.
key: Vec<char>,
// Value of the record this suffix was taken from.
value: u32,
}
/// Builder that arranges sorted records into an array of `Node`s and releases it as a `Trie`
/// or an `MpTrie`.
pub struct Builder {
// Input records, sorted by key by `build_from_records`.
records: Vec<Record>,
// Character-to-code mapping built from the character frequencies of the records.
mapper: CodeMapper,
// Node array under construction; unused slots are chained into a circular doubly-linked list
// through their `base` (next) and `check` (prev) fields.
nodes: Vec<Node>,
// `Some` when minimal-prefix mode is enabled; collects the suffixes cut off at the leaves.
suffixes: Option<Vec<Suffix>>,
// Scratch buffer holding the codes of the children of the node being processed.
labels: Vec<u32>,
// First unused slot of the free list, or `INVALID_IDX` when no slot is free.
head_idx: u32,
// Number of slots per block: the alphabet size rounded up to a power of two, at least 2.
block_len: u32,
// Threshold used by `enlarge` to decide when an older block is closed.
num_free_blocks: u32,
}
impl Default for Builder {
    /// Creates an empty builder: no records, no nodes, minimal-prefix mode disabled and
    /// `num_free_blocks` set to `DEFAULT_NUM_FREE_BLOCKS`.
    fn default() -> Self {
        let mapper = CodeMapper::default();
        Self {
            records: Vec::new(),
            mapper,
            nodes: Vec::new(),
            suffixes: None,
            labels: Vec::new(),
            head_idx: 0,
            block_len: 0,
            num_free_blocks: DEFAULT_NUM_FREE_BLOCKS,
        }
    }
}
impl Builder {
pub fn new() -> Self {
Self::default()
}
/// Enables minimal-prefix mode by installing an empty suffix list.
///
/// A builder configured this way must be released with `release_mptrie`.
#[allow(clippy::missing_const_for_fn)]
pub fn minimal_prefix(mut self) -> Self {
    let empty_suffixes = Vec::new();
    self.suffixes = Some(empty_suffixes);
    self
}
/// Builds the trie from keys only; each key is associated with its position in `keys`.
///
/// # Errors
///
/// Returns whatever `build_from_records` returns for the numbered keys.
///
/// # Panics
///
/// Panics if a position does not fit in `u32`.
pub fn build_from_keys<I, K>(self, keys: I) -> Result<Self>
where
    I: IntoIterator<Item = K>,
    K: AsRef<str>,
{
    let numbered = keys.into_iter().enumerate().map(|(position, key)| {
        let value = u32::try_from(position).unwrap();
        (key, value)
    });
    self.build_from_records(numbered)
}
pub fn build_from_records<I, K>(mut self, records: I) -> Result<Self>
where
I: IntoIterator<Item = (K, u32)>,
K: AsRef<str>,
{
self.records = records
.into_iter()
.map(|(k, v)| Record {
key: k.as_ref().chars().collect(),
value: v,
})
.collect();
self.records.sort_unstable_by(|a, b| a.key.cmp(&b.key));
for &Record { key: _, value } in &self.records {
if MAX_VALUE < value {
return Err(CrawdadError::scale("input value", MAX_VALUE));
}
}
self.mapper = CodeMapper::new(&make_freqs(&self.records)?);
assert_eq!(self.mapper.get(END_MARKER).unwrap(), END_CODE);
make_prefix_free(&mut self.records)?;
self.block_len = self.mapper.alphabet_size().next_power_of_two().max(2);
self.init_array();
self.arrange_nodes(0, self.records.len(), 0, 0)?;
self.finish();
Ok(self)
}
/// Consumes the builder and hands its nodes and mapper over as a `Trie`.
///
/// # Errors
///
/// Returns a setup error when minimal-prefix mode is enabled.
#[allow(clippy::missing_const_for_fn)]
pub fn release_trie(self) -> Result<Trie> {
    if self.suffixes.is_some() {
        return Err(CrawdadError::setup("minimal_prefix must be disabled."));
    }
    let Self { nodes, mapper, .. } = self;
    Ok(Trie { nodes, mapper })
}
/// Consumes the builder and assembles an `MpTrie`.
///
/// Every leaf node is rewritten to hold either its value directly (when it is the
/// `END_CODE` child of a parent flagged with `has_leaf`) or the start position of its
/// entry in `tails`. A tail entry is the suffix length (one element), the packed codes of
/// the suffix characters, and the packed value.
///
/// # Errors
///
/// - A setup error if minimal-prefix mode is not enabled.
/// - A setup error if no suffix has been collected, i.e. nothing was built yet.
/// - A scale error if `tails` grows beyond `OFFSET_MASK` or a suffix is longer than
///   `u8::MAX` characters.
pub fn release_mptrie(self) -> Result<MpTrie> {
    let Self {
        mapper,
        mut nodes,
        suffixes,
        ..
    } = self;

    let suffixes =
        suffixes.ok_or_else(|| CrawdadError::setup("minimal_prefix must be enabled."))?;

    // Without any suffix there is no maximum value to size the packing with; report the
    // misuse instead of panicking on the empty maximum.
    let max_value = suffixes
        .iter()
        .map(|s| s.value)
        .max()
        .ok_or_else(|| CrawdadError::setup("records must be built before releasing."))?;
    let value_size = utils::pack_size(max_value);

    let max_code = mapper.alphabet_size() - 1;
    let code_size = utils::pack_size(max_code);

    let mut tails = vec![];

    for node_idx in 0..nodes.len() {
        if nodes[node_idx].is_vacant() || !nodes[node_idx].is_leaf() {
            continue;
        }

        debug_assert_eq!(nodes[node_idx].check & !OFFSET_MASK, 0);
        let suf_idx = usize::try_from(nodes[node_idx].base & OFFSET_MASK).unwrap();
        let suffix = &suffixes[suf_idx];

        // The root has no parent: `finish` stores `OFFSET_MASK` in its `check`, which is never
        // a valid index. A root that is itself a leaf (single-record build) therefore skips
        // the parent lookup and always gets a tail entry.
        if node_idx != 0 {
            let parent_idx = usize::try_from(nodes[node_idx].check).unwrap();
            // Is `node_idx` the child that `parent_idx` reaches through its `base` directly?
            if nodes[parent_idx].has_leaf()
                && usize::try_from(nodes[parent_idx].base).unwrap() == node_idx
            {
                assert!(suffix.key.is_empty());
                nodes[node_idx].base = suffix.value | !OFFSET_MASK;
                continue;
            }
        }

        let tail_start = if tails.len() <= usize::try_from(OFFSET_MASK).unwrap() {
            u32::try_from(tails.len()).unwrap()
        } else {
            return Err(CrawdadError::scale("length of tails", OFFSET_MASK));
        };
        // The suffix length is stored in a single element, so it must fit in `u8`.
        if suffix.key.len() > usize::from(u8::MAX) {
            return Err(CrawdadError::scale("length of suffix", u32::from(u8::MAX)));
        }

        nodes[node_idx].base = tail_start | !OFFSET_MASK;
        tails.push(suffix.key.len().try_into().unwrap());
        suffix
            .key
            .iter()
            .map(|&c| mapper.get(c).unwrap())
            .for_each(|c| utils::pack_u32(&mut tails, c, code_size));
        utils::pack_u32(&mut tails, suffix.value, value_size);
    }

    Ok(MpTrie {
        mapper,
        nodes,
        tails,
        code_size,
        value_size,
    })
}
/// Returns the current length of the node array as `u32`.
#[inline(always)]
fn num_nodes(&self) -> u32 {
    let len = self.nodes.len();
    u32::try_from(len).unwrap()
}
fn init_array(&mut self) {
self.nodes.clear();
self.nodes
.resize(usize::try_from(self.block_len).unwrap(), Node::default());
for i in 0..self.block_len {
if i == 0 {
self.set_prev(i, self.block_len - 1);
} else {
self.set_prev(i, i - 1);
}
if i == self.block_len - 1 {
self.set_next(i, 0);
} else {
self.set_next(i, i + 1);
}
}
self.head_idx = 0;
self.fix_node(0);
}
/// Recursively places the records `spos..epos`, which share their first `depth` characters,
/// below the already fixed node `node_idx`.
///
/// In minimal-prefix mode a range holding a single record ends the recursion: the rest of its
/// key becomes a `Suffix` and the node stores the suffix index. Otherwise the recursion ends
/// when the key is exhausted and the node stores the value. In both cases the stored number
/// is combined with the flag bits `!OFFSET_MASK`.
///
/// # Errors
///
/// Returns a scale error if the suffix list or the node array would exceed `OFFSET_MASK`.
fn arrange_nodes(
    &mut self,
    spos: usize,
    epos: usize,
    depth: usize,
    node_idx: u32,
) -> Result<()> {
    debug_assert!(self.is_fixed(node_idx));

    if let Some(suffixes) = self.suffixes.as_mut() {
        if spos + 1 == epos {
            let record = &self.records[spos];
            debug_assert_eq!(record.value & !OFFSET_MASK, 0);
            if usize::try_from(OFFSET_MASK).unwrap() < suffixes.len() {
                return Err(CrawdadError::scale("length of suffixes", OFFSET_MASK));
            }
            let suffix_idx = u32::try_from(suffixes.len()).unwrap();
            // Direct field access keeps this borrow disjoint from `suffixes` and `record`.
            self.nodes[usize::try_from(node_idx).unwrap()].base = suffix_idx | !OFFSET_MASK;
            suffixes.push(Suffix {
                key: pop_end_marker(&record.key[depth..]),
                value: record.value,
            });
            return Ok(());
        }
    } else if self.records[spos].key.len() == depth {
        debug_assert_eq!(spos + 1, epos);
        let value = self.records[spos].value;
        debug_assert_eq!(value & !OFFSET_MASK, 0);
        self.node_mut(node_idx).base = value | !OFFSET_MASK;
        return Ok(());
    }

    self.fetch_labels(spos, epos, depth);
    let base = self.define_nodes(node_idx)?;

    // Walk the range and recurse into each maximal run of records that share the character
    // at `depth`.
    let mut run_start = spos;
    let mut run_char = self.records[run_start].key[depth];
    for pos in spos + 1..epos {
        let next_char = self.records[pos].key[depth];
        if run_char == next_char {
            continue;
        }
        let child_idx = base ^ self.mapper.get(run_char).unwrap();
        self.arrange_nodes(run_start, pos, depth + 1, child_idx)?;
        run_start = pos;
        run_char = next_char;
    }
    let child_idx = base ^ self.mapper.get(run_char).unwrap();
    self.arrange_nodes(run_start, epos, depth + 1, child_idx)
}
/// Finalizes the node array after all records have been arranged.
///
/// The root's `check` is set to `OFFSET_MASK`, every slot still on the free list gets
/// `OFFSET_MASK` in both fields, and each remaining non-leaf node whose `base ^ END_CODE`
/// slot is its own child gets the flag bits `!OFFSET_MASK` set in `check`.
fn finish(&mut self) {
    self.node_mut(0).check = OFFSET_MASK;

    if self.head_idx != INVALID_IDX {
        let head = self.head_idx;
        let mut cursor = head;
        loop {
            // Read the link before the slot is overwritten.
            let following = self.get_next(cursor);
            let slot = self.node_mut(cursor);
            slot.base = OFFSET_MASK;
            slot.check = OFFSET_MASK;
            cursor = following;
            if cursor == head {
                break;
            }
        }
    }

    for idx in 0..self.num_nodes() {
        let node = self.node_ref(idx);
        if node.is_vacant() || node.is_leaf() {
            continue;
        }
        let end_idx = node.base ^ END_CODE;
        if self.node_ref(end_idx).check == idx {
            self.node_mut(idx).check |= !OFFSET_MASK;
        }
    }
}
fn fetch_labels(&mut self, spos: usize, epos: usize, depth: usize) {
self.labels.clear();
let mut c1 = self.records[spos].key[depth];
for i in spos + 1..epos {
let c2 = self.records[i].key[depth];
if c1 != c2 {
self.labels.push(self.mapper.get(c1).unwrap());
c1 = c2;
}
}
self.labels.push(self.mapper.get(c1).unwrap());
}
/// Chooses a base for `node_idx` that suits the codes in `self.labels`, fixes the child
/// slots and points their `check` back at `node_idx`.
///
/// Returns the chosen base.
///
/// # Errors
///
/// Propagates the scale error of `enlarge` when the array has to grow.
fn define_nodes(&mut self, node_idx: u32) -> Result<u32> {
    let base = self.find_base(&self.labels);
    if self.num_nodes() <= base {
        self.enlarge()?;
    }

    self.node_mut(node_idx).base = base;
    for pos in 0..self.labels.len() {
        let label = self.labels[pos];
        let child_idx = base ^ label;
        self.fix_node(child_idx);
        self.node_mut(child_idx).check = node_idx;
    }
    Ok(base)
}
/// Searches the free list for a base at which every `base ^ label` slot is unused.
///
/// When the free list is empty or no candidate fits, the returned base lies at or beyond the
/// current end of the array (`num_nodes() ^ labels[0]`).
fn find_base(&self, labels: &[u32]) -> u32 {
    debug_assert!(!labels.is_empty());
    let first_label = labels[0];

    if self.head_idx != INVALID_IDX {
        let mut candidate = self.head_idx;
        loop {
            let base = candidate ^ first_label;
            if self.verify_base(base, labels) {
                return base;
            }
            candidate = self.get_next(candidate);
            if candidate == self.head_idx {
                break;
            }
        }
    }

    self.num_nodes() ^ first_label
}
/// Returns `true` when none of the slots `base ^ label` is fixed yet.
#[inline(always)]
fn verify_base(&self, base: u32, labels: &[u32]) -> bool {
    labels.iter().all(|&label| !self.is_fixed(base ^ label))
}
/// Unlinks the unused slot `node_idx` from the free list and marks it as fixed.
///
/// If the slot was the list head, the head moves to its successor, or becomes `INVALID_IDX`
/// when the slot was the only element.
#[inline(always)]
fn fix_node(&mut self, node_idx: u32) {
    debug_assert!(!self.is_fixed(node_idx));

    let next = self.get_next(node_idx);
    let prev = self.get_prev(node_idx);

    // Bridge the neighbours before the slot's own links are wiped by `set_fixed`.
    self.set_next(prev, next);
    self.set_prev(next, prev);
    self.set_fixed(node_idx);

    if self.head_idx == node_idx {
        let was_last = next == node_idx;
        self.head_idx = if was_last { INVALID_IDX } else { next };
    }
}
/// Appends one block of unused slots to the node array and links it to the end of the free
/// list.
///
/// Before growing, once at least `num_free_blocks` blocks exist, the block
/// `num_blocks - num_free_blocks` is closed.
///
/// # Errors
///
/// Returns a scale error if the new length would exceed `OFFSET_MASK`.
fn enlarge(&mut self) -> Result<()> {
    let old_len = self.num_nodes();
    let new_len = old_len + self.block_len;
    if new_len > OFFSET_MASK {
        return Err(CrawdadError::scale("num_nodes", OFFSET_MASK));
    }

    let num_blocks = old_len / self.block_len;
    if let Some(closing_idx) = num_blocks.checked_sub(self.num_free_blocks) {
        self.close_block(closing_idx);
    }

    // Chain the new slots linearly; both ends are re-linked below.
    self.nodes
        .resize(usize::try_from(new_len).unwrap(), Node::default());
    for idx in old_len..new_len {
        self.set_next(idx, idx + 1);
        self.set_prev(idx, idx - 1);
    }

    let first_new = old_len;
    let last_new = new_len - 1;
    if self.head_idx == INVALID_IDX {
        // The new block becomes the whole free list.
        self.set_prev(first_new, last_new);
        self.set_next(last_new, first_new);
        self.head_idx = first_new;
    } else {
        // Splice the new block between the current tail and the head.
        let head_idx = self.head_idx;
        let tail_idx = self.get_prev(head_idx);
        self.set_prev(first_new, tail_idx);
        self.set_next(tail_idx, first_new);
        self.set_next(last_new, head_idx);
        self.set_prev(head_idx, last_new);
    }
    Ok(())
}
/// Removes from the free list every slot that lies before the end of block `block_idx`,
/// as long as such a slot is at the head of the list, and stores `OFFSET_MASK` in both of
/// its fields.
fn close_block(&mut self, block_idx: u32) {
    let end_idx = block_idx * self.block_len + self.block_len;

    while self.head_idx < end_idx {
        debug_assert_ne!(self.head_idx, INVALID_IDX);
        let closing = self.head_idx;
        // `fix_node` advances `head_idx`, which drives the loop.
        self.fix_node(closing);
        let slot = self.node_mut(closing);
        slot.base = OFFSET_MASK;
        slot.check = OFFSET_MASK;
    }
}
/// Returns a shared reference to node `i`.
#[inline(always)]
fn node_ref(&self, i: u32) -> &Node {
    let pos = usize::try_from(i).unwrap();
    &self.nodes[pos]
}
/// Returns a mutable reference to node `i`.
#[inline(always)]
fn node_mut(&mut self, i: u32) -> &mut Node {
    let pos = usize::try_from(i).unwrap();
    &mut self.nodes[pos]
}
/// Returns `true` when slot `i` is fixed, i.e. the flag bits `!OFFSET_MASK` of its `check`
/// are all clear.
#[inline(always)]
fn is_fixed(&self, i: u32) -> bool {
    let flag_bits = self.node_ref(i).check & !OFFSET_MASK;
    flag_bits == 0
}
/// Marks the unused slot `i` as fixed: `base` becomes `INVALID_IDX` and the flag bits of
/// `check` are cleared.
#[inline(always)]
fn set_fixed(&mut self, i: u32) {
    debug_assert!(!self.is_fixed(i));
    let slot = self.node_mut(i);
    slot.base = INVALID_IDX;
    slot.check &= OFFSET_MASK;
}
/// Returns the successor of the unused slot `i` in the free list (stored in `base`).
#[inline(always)]
fn get_next(&self, i: u32) -> u32 {
    let raw = self.node_ref(i).base;
    debug_assert_ne!(raw & !OFFSET_MASK, 0);
    raw & OFFSET_MASK
}
/// Returns the predecessor of the unused slot `i` in the free list (stored in `check`).
#[inline(always)]
fn get_prev(&self, i: u32) -> u32 {
    let raw = self.node_ref(i).check;
    debug_assert_ne!(raw & !OFFSET_MASK, 0);
    raw & OFFSET_MASK
}
/// Stores `x` as the free-list successor of slot `i`, together with the flag bits
/// `!OFFSET_MASK`.
#[inline(always)]
fn set_next(&mut self, i: u32, x: u32) {
    debug_assert_eq!(x & !OFFSET_MASK, 0);
    let slot = self.node_mut(i);
    slot.base = !OFFSET_MASK | x;
}
/// Stores `x` as the free-list predecessor of slot `i`, together with the flag bits
/// `!OFFSET_MASK`.
#[inline(always)]
fn set_prev(&mut self, i: u32, x: u32) {
    debug_assert_eq!(x & !OFFSET_MASK, 0);
    let slot = self.node_mut(i);
    slot.check = !OFFSET_MASK | x;
}
}
/// Counts how often each character occurs in the record keys, indexed by code point.
///
/// The entry of `END_MARKER` is forced to `u32::MAX` in the returned table.
///
/// # Errors
///
/// Returns an input error if any key contains `END_MARKER`.
fn make_freqs(records: &[Record]) -> Result<Vec<u32>> {
    let marker_pos = usize::try_from(u32::from(END_MARKER)).unwrap();
    let mut freqs = vec![0; marker_pos + 1];

    for ch in records.iter().flat_map(|rec| rec.key.iter().copied()) {
        let slot = usize::try_from(u32::from(ch)).unwrap();
        // Grow the table on demand so it always covers the largest code point seen.
        if slot >= freqs.len() {
            freqs.resize(slot + 1, 0);
        }
        freqs[slot] += 1;
    }

    // The table never shrinks, so the marker entry is always present.
    if freqs[marker_pos] != 0 {
        return Err(CrawdadError::input("END_MARKER must not be contained."));
    }
    freqs[marker_pos] = u32::MAX;
    Ok(freqs)
}
/// Makes the sorted records prefix-free: a key that is a proper prefix of its successor gets
/// `END_MARKER` appended.
///
/// # Errors
///
/// Returns an input error if `records` is empty, if the first key is empty, or if two
/// adjacent keys are equal.
///
/// # Panics
///
/// Panics if a key compares greater than its successor, which cannot happen for sorted
/// input.
fn make_prefix_free(records: &mut [Record]) -> Result<()> {
    match records.first() {
        None => return Err(CrawdadError::input("records must not be empty.")),
        Some(first) if first.key.is_empty() => {
            return Err(CrawdadError::input(
                "records must not contain an empty key.",
            ));
        }
        Some(_) => {}
    }

    for cur in 1..records.len() {
        let prev = cur - 1;
        let (lcp, cmp) = utils::longest_common_prefix(&records[prev].key, &records[cur].key);
        match cmp {
            Ordering::Equal => {
                return Err(CrawdadError::input(
                    "records must not contain duplicated keys.",
                ));
            }
            // The whole previous key is shared, so it is a prefix of the current one.
            Ordering::Less if lcp == records[prev].key.len() => {
                records[prev].key.push(END_MARKER);
            }
            Ordering::Less => {}
            Ordering::Greater => unreachable!(),
        }
    }
    Ok(())
}
fn pop_end_marker(x: &[char]) -> Vec<char> {
match x.split_last() {
Some((&END_MARKER, elems)) => elems.to_vec(),
_ => x.to_vec(),
}
}