use crate::builder::Builder;
use crate::errors::Result;
use crate::mapper::CodeMapper;
use crate::{utils, Node};
use crate::END_CODE;
use alloc::vec::Vec;
use core::mem;
/// A minimal-prefix double-array trie mapping string keys to `u32` values.
///
/// Each key is stored as a path of nodes that ends either in a leaf, whose
/// remaining characters (the "tail") and value are kept in `tails`, or in an
/// `END_CODE` child node that holds the value directly.
#[derive(Debug)]
pub struct MpTrie {
    /// Maps `char`s to the integer codes used as edge labels.
    pub(crate) mapper: CodeMapper,
    /// Double-array nodes (base/check pairs); index `0` is where lookups start.
    pub(crate) nodes: Vec<Node>,
    /// Packed tails: for each tail, one length byte, then that many codes of
    /// `code_size` bytes each, then a `value_size`-byte value.
    pub(crate) tails: Vec<u8>,
    /// Number of bytes used to store one code in `tails`.
    pub(crate) code_size: u8,
    /// Number of bytes used to store one value in `tails`.
    pub(crate) value_size: u8,
}
impl MpTrie {
    /// Builds a trie from `keys` using the builder's minimal-prefix mode.
    ///
    /// The builder assigns each key its value; for the sorted keys used in this
    /// module's tests, the `i`-th key is assigned `i`.
    ///
    /// # Errors
    ///
    /// Returns an error if the builder rejects the keys. The tests show that an
    /// empty key set, an empty key and duplicate keys are rejected, while unsorted
    /// keys are accepted.
    pub fn from_keys<I, K>(keys: I) -> Result<Self>
    where
        I: IntoIterator<Item = K>,
        K: AsRef<str>,
    {
        Builder::new()
            .minimal_prefix()
            .build_from_keys(keys)?
            .release_mptrie()
    }

    /// Builds a trie from `(key, value)` records using the builder's
    /// minimal-prefix mode.
    ///
    /// # Errors
    ///
    /// Returns an error if the builder rejects the records (presumably for the
    /// same reasons as [`MpTrie::from_keys`]; confirm in `Builder`).
    pub fn from_records<I, K>(records: I) -> Result<Self>
    where
        I: IntoIterator<Item = (K, u32)>,
        K: AsRef<str>,
    {
        Builder::new()
            .minimal_prefix()
            .build_from_records(records)?
            .release_mptrie()
    }

    /// Serializes the trie into a new byte vector.
    ///
    /// Layout, in order:
    /// 1. the mapper's own encoding;
    /// 2. the node count as a little-endian `u32`, followed by every serialized node;
    /// 3. the tail length as a little-endian `u32`, followed by the tail bytes;
    /// 4. one byte for `code_size`, then one byte for `value_size`.
    ///
    /// The result is exactly [`MpTrie::io_bytes`] long and can be read back with
    /// [`MpTrie::deserialize_from_slice`].
    ///
    /// # Panics
    ///
    /// Panics if the node count or the tail length does not fit in a `u32`.
    pub fn serialize_to_vec(&self) -> Vec<u8> {
        let mut dest = Vec::with_capacity(self.io_bytes());
        self.mapper.serialize_into_vec(&mut dest);
        Self::push_len(&mut dest, self.nodes.len());
        for node in &self.nodes {
            dest.extend_from_slice(&node.serialize());
        }
        Self::push_len(&mut dest, self.tails.len());
        dest.extend_from_slice(&self.tails);
        dest.push(self.code_size);
        dest.push(self.value_size);
        dest
    }

    /// Reads a trie written by [`MpTrie::serialize_to_vec`] from the front of
    /// `source`, returning it together with the unread rest of `source`.
    ///
    /// # Panics
    ///
    /// Panics if `source` is truncated or malformed. Each length prefix is checked
    /// against the bytes actually available before anything is allocated. A
    /// corrupted prefix therefore panics instead of requesting a huge allocation.
    pub fn deserialize_from_slice(source: &[u8]) -> (Self, &[u8]) {
        let (mapper, source) = CodeMapper::deserialize_from_slice(source);

        let (num_nodes, source) = Self::split_len(source);
        let node_section_len = num_nodes
            .checked_mul(Node::io_bytes())
            .expect("serialized node section length overflows usize");
        // `split_at` bounds-checks the declared length before we allocate.
        let (node_bytes, source) = source.split_at(node_section_len);
        let nodes = node_bytes
            .chunks_exact(Node::io_bytes())
            .map(|chunk| Node::deserialize(chunk.try_into().unwrap()))
            .collect();

        let (num_tail_bytes, source) = Self::split_len(source);
        let (tail_bytes, source) = source.split_at(num_tail_bytes);
        let tails = tail_bytes.to_vec();

        let code_size = source[0];
        let value_size = source[1];
        (
            Self {
                mapper,
                nodes,
                tails,
                code_size,
                value_size,
            },
            &source[2..],
        )
    }

    /// Appends `len` to `dest` as a little-endian `u32` length prefix.
    fn push_len(dest: &mut Vec<u8>, len: usize) {
        let len = u32::try_from(len).expect("length must fit in a u32 to be serialized");
        dest.extend_from_slice(&len.to_le_bytes());
    }

    /// Splits a little-endian `u32` length prefix off the front of `source`.
    ///
    /// Panics if `source` is shorter than four bytes.
    fn split_len(source: &[u8]) -> (usize, &[u8]) {
        let (head, rest) = source.split_at(mem::size_of::<u32>());
        let len = u32::from_le_bytes(head.try_into().unwrap());
        (usize::try_from(len).unwrap(), rest)
    }

    /// Returns the value stored for `key`, or `None` if `key` is not in the trie.
    ///
    /// A char unknown to the mapper makes the lookup fail immediately.
    #[inline(always)]
    pub fn exact_match<I>(&self, key: I) -> Option<u32>
    where
        I: IntoIterator<Item = char>,
    {
        let mut node_idx = 0;
        let mut chars = key.into_iter();
        // Follow the key through the double array until a leaf is reached.
        while !self.is_leaf(node_idx) {
            if let Some(c) = chars.next() {
                node_idx = self
                    .mapper
                    .get(c)
                    .and_then(|mc| self.get_child_idx(node_idx, mc))?;
            } else {
                // The key ends at an inner node. It is stored only if that node
                // reports an `END_CODE` child (`has_leaf`) holding the value.
                return self
                    .has_leaf(node_idx)
                    .then(|| self.get_value(self.get_leaf_idx(node_idx)));
            }
        }
        // The leaf's base points at the key's tail. The remaining chars must
        // match the tail code by code.
        let tail_pos = usize::try_from(self.get_value(node_idx)).unwrap();
        let mut tail_iter = self.tail_iter(tail_pos);
        for tc in tail_iter.by_ref() {
            chars
                .next()
                .and_then(|c| self.mapper.get(c))
                .filter(|&mc| mc == tc)?;
        }
        // The key must not continue past its tail.
        chars.next().is_none().then(|| tail_iter.value())
    }

    /// Returns an iterator over the stored keys that are prefixes of `haystack`.
    ///
    /// Each item is `(value, end)`, where `end` is the number of chars of
    /// `haystack` that the key spans. Matches are produced from shortest to longest.
    pub const fn common_prefix_search<I>(&self, haystack: I) -> CommonPrefixSearchIter<I> {
        CommonPrefixSearchIter {
            haystack,
            haystack_pos: 0,
            trie: self,
            node_idx: 0,
        }
    }

    /// Returns an iterator over the codes of the tail starting at `tail_pos`.
    ///
    /// `tails[tail_pos]` holds the number of codes. The codes follow, each
    /// `code_size` bytes wide, and the `value_size`-byte value comes after them.
    #[inline(always)]
    fn tail_iter(&self, tail_pos: usize) -> TailIter {
        TailIter {
            trie: self,
            pos: tail_pos + 1,
            len: usize::from(self.tails[tail_pos]),
        }
    }

    /// Returns the child of `node_idx` labelled with mapped code `mc`.
    ///
    /// Returns `None` if `node_idx` is a leaf, or if the slot `base ^ mc` is not
    /// owned by `node_idx` (its check does not point back to it).
    #[inline(always)]
    fn get_child_idx(&self, node_idx: u32, mc: u32) -> Option<u32> {
        if self.is_leaf(node_idx) {
            return None;
        }
        Some(self.get_base(node_idx) ^ mc)
            .filter(|&child_idx| self.get_check(child_idx) == node_idx)
    }

    /// Returns the node at `node_idx`, panicking if it is out of range.
    #[inline(always)]
    fn node_ref(&self, node_idx: u32) -> &Node {
        &self.nodes[usize::try_from(node_idx).unwrap()]
    }

    /// Returns the base field of the node at `node_idx`.
    #[inline(always)]
    fn get_base(&self, node_idx: u32) -> u32 {
        self.node_ref(node_idx).get_base()
    }

    /// Returns the check field (the parent index) of the node at `node_idx`.
    #[inline(always)]
    fn get_check(&self, node_idx: u32) -> u32 {
        self.node_ref(node_idx).get_check()
    }

    /// Returns whether the node at `node_idx` is a leaf.
    #[inline(always)]
    fn is_leaf(&self, node_idx: u32) -> bool {
        self.node_ref(node_idx).is_leaf()
    }

    /// Returns whether a key ends exactly at the inner node `node_idx`.
    #[inline(always)]
    fn has_leaf(&self, node_idx: u32) -> bool {
        self.node_ref(node_idx).has_leaf()
    }

    /// Returns the index of the `END_CODE` child of `node_idx`, which holds the
    /// value of the key ending at `node_idx`.
    #[inline(always)]
    fn get_leaf_idx(&self, node_idx: u32) -> u32 {
        let leaf_idx = self.get_base(node_idx) ^ END_CODE;
        debug_assert_eq!(self.get_check(leaf_idx), node_idx);
        leaf_idx
    }

    /// Returns the payload stored in the base field of the leaf at `node_idx`.
    ///
    /// For an `END_CODE` leaf this is the key's value. For a leaf reached by a
    /// key char it is the position of the key's tail in `tails`.
    #[inline(always)]
    fn get_value(&self, node_idx: u32) -> u32 {
        debug_assert!(self.is_leaf(node_idx));
        self.node_ref(node_idx).get_base()
    }

    /// Returns the approximate heap usage in bytes.
    ///
    /// Counts nodes and tails by length (not capacity), plus the mapper's own report.
    pub fn heap_bytes(&self) -> usize {
        self.mapper.heap_bytes()
            + self.nodes.len() * mem::size_of::<Node>()
            + self.tails.len() * mem::size_of::<u8>()
    }

    /// Returns the number of bytes produced by [`MpTrie::serialize_to_vec`].
    pub fn io_bytes(&self) -> usize {
        self.mapper.io_bytes()
            + mem::size_of::<u32>()
            + self.nodes.len() * Node::io_bytes()
            + mem::size_of::<u32>()
            + self.tails.len() * mem::size_of::<u8>()
            + mem::size_of::<u8>() * 2
    }

    /// Returns the number of node slots, including vacant ones.
    pub fn num_elems(&self) -> usize {
        self.nodes.len()
    }

    /// Returns the number of node slots for which `Node::is_vacant` holds.
    pub fn num_vacants(&self) -> usize {
        self.nodes.iter().filter(|nd| nd.is_vacant()).count()
    }
}
/// Iterator returned by [`MpTrie::common_prefix_search`].
///
/// Yields `(value, end)` for every stored key that is a prefix of the haystack,
/// where `end` is the number of haystack chars covered by that key.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct CommonPrefixSearchIter<'t, I> {
    /// Remaining chars of the haystack.
    haystack: I,
    /// Number of chars consumed from `haystack` so far.
    haystack_pos: usize,
    /// Trie being searched.
    trie: &'t MpTrie,
    /// Node reached so far; starts at index `0`, where every lookup begins.
    node_idx: u32,
}
impl<I> Iterator for CommonPrefixSearchIter<'_, I>
where
    I: Iterator<Item = char>,
{
    type Item = (u32, usize);

    /// Walks further down the trie along the haystack and yields `(value, end)`
    /// for the next key that is a prefix of it. `end` is the number of chars
    /// consumed so far.
    ///
    /// Returns `None` once the haystack is exhausted, or once a char cannot
    /// extend the current path: it is unknown to the mapper, the node has no
    /// such child, or a tail mismatches.
    ///
    /// NOTE(review): the iterator is not fused. Calling `next` again after a
    /// `None` caused by a mismatch resumes from the same node with the following
    /// char.
    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let trie = self.trie;
        // `for c in self.haystack.by_ref()` would hold a mutable borrow of
        // `self.haystack` for the whole loop, but the tail comparison below must
        // pull further chars from the same iterator. `while let` re-borrows on
        // each step.
        while let Some(c) = self.haystack.next() {
            let mc = trie.mapper.get(c)?;
            self.node_idx = trie.get_child_idx(self.node_idx, mc)?;
            self.haystack_pos += 1;
            if trie.is_leaf(self.node_idx) {
                // The rest of the only key below this node lives in its tail,
                // which must match the haystack code by code.
                let tail_pos = usize::try_from(trie.get_value(self.node_idx)).unwrap();
                let mut tail_iter = trie.tail_iter(tail_pos);
                for tc in tail_iter.by_ref() {
                    let mc = trie.mapper.get(self.haystack.next()?);
                    mc.filter(|&c| c == tc)?;
                    self.haystack_pos += 1;
                }
                return Some((tail_iter.value(), self.haystack_pos));
            } else if trie.has_leaf(self.node_idx) {
                // A key ends exactly here; more keys may continue below.
                let leaf_idx = trie.get_leaf_idx(self.node_idx);
                return Some((trie.get_value(leaf_idx), self.haystack_pos));
            }
        }
        None
    }
}
/// Iterator over the mapped codes of one tail stored in the trie's `tails` buffer.
///
/// Once every code has been yielded, `pos` points at the tail's value.
struct TailIter<'a> {
/// Trie whose `tails` buffer is read.
trie: &'a MpTrie,
/// Byte offset in `tails` of the next code to decode.
pos: usize,
/// Number of codes not yet yielded.
len: usize,
}
impl TailIter<'_> {
    /// Decodes the `value_size`-byte value located at the current position.
    ///
    /// Callers invoke this only after consuming every code, so that `pos` has
    /// moved past the codes onto the value.
    #[inline(always)]
    fn value(&self) -> u32 {
        let trie = self.trie;
        let remaining = &trie.tails[self.pos..];
        utils::unpack_u32(remaining, trie.value_size)
    }
}
impl Iterator for TailIter<'_> {
    type Item = u32;

    /// Decodes the next `code_size`-byte code of the tail and advances past it.
    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        if self.len == 0 {
            return None;
        }
        let code_size = self.trie.code_size;
        let code = utils::unpack_u32(&self.trie.tails[self.pos..], code_size);
        // `u8 -> usize` is lossless, so no fallible conversion is needed.
        self.pos += usize::from(code_size);
        self.len -= 1;
        Some(code)
    }

    /// The number of remaining codes is known exactly.
    #[inline(always)]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.len, Some(self.len))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_exact_match() {
        let keys = vec!["世界", "世界中", "世論調査", "統計調査"];
        let trie = MpTrie::from_keys(&keys).unwrap();
        for (i, key) in keys.iter().enumerate() {
            assert_eq!(
                trie.exact_match(key.chars()),
                Some(u32::try_from(i).unwrap())
            );
        }
        assert_eq!(trie.exact_match("世".chars()), None);
        assert_eq!(trie.exact_match("世論".chars()), None);
        assert_eq!(trie.exact_match("世界中で".chars()), None);
        assert_eq!(trie.exact_match("統計".chars()), None);
        assert_eq!(trie.exact_match("統計調".chars()), None);
        assert_eq!(trie.exact_match("日本".chars()), None);
    }

    #[test]
    fn test_from_records() {
        let records = vec![("世界", 10), ("世界中", 20), ("世論調査", 30)];
        let trie = MpTrie::from_records(records.iter().copied()).unwrap();
        for &(key, value) in &records {
            assert_eq!(trie.exact_match(key.chars()), Some(value));
        }
        assert_eq!(trie.exact_match("世論".chars()), None);
    }

    #[test]
    fn test_common_prefix_search() {
        let keys = vec!["世界", "世界中", "世論調査", "統計調査"];
        let trie = MpTrie::from_keys(&keys).unwrap();
        let haystack: Vec<_> = "世界中の統計世論調査".chars().collect();
        let mut matches = vec![];
        for i in 0..haystack.len() {
            for (v, j) in trie.common_prefix_search(haystack[i..].iter().copied()) {
                matches.push((v, i..i + j));
            }
        }
        assert_eq!(matches, vec![(0, 0..2), (1, 0..3), (2, 6..10)]);
    }

    #[test]
    fn test_common_prefix_search_unknown_chars() {
        let keys = vec!["世界", "世界中", "世論調査", "統計調査"];
        let trie = MpTrie::from_keys(&keys).unwrap();
        let haystack: Vec<_> = "日本世界".chars().collect();
        let mut matches = vec![];
        for i in 0..haystack.len() {
            for (v, j) in trie.common_prefix_search(haystack[i..].iter().copied()) {
                matches.push((v, i..i + j));
            }
        }
        assert_eq!(matches, vec![(0, 2..4)]);
    }

    #[test]
    fn test_serialize() {
        let keys = vec!["世界", "世界中", "世論調査", "統計調査"];
        let trie = MpTrie::from_keys(&keys).unwrap();
        let bytes = trie.serialize_to_vec();
        assert_eq!(trie.io_bytes(), bytes.len());
        let (other, remain) = MpTrie::deserialize_from_slice(&bytes);
        assert!(remain.is_empty());
        assert_eq!(trie.mapper, other.mapper);
        assert_eq!(trie.nodes, other.nodes);
        assert_eq!(trie.tails, other.tails);
        assert_eq!(trie.code_size, other.code_size);
        assert_eq!(trie.value_size, other.value_size);
    }

    #[test]
    fn test_deserialize_keeps_trailing_bytes() {
        let keys = vec!["世界", "世界中", "世論調査", "統計調査"];
        let trie = MpTrie::from_keys(&keys).unwrap();
        let mut bytes = trie.serialize_to_vec();
        bytes.extend_from_slice(&[0xAB, 0xCD]);
        let (other, remain) = MpTrie::deserialize_from_slice(&bytes);
        assert_eq!(remain, &[0xAB_u8, 0xCD][..]);
        for (i, key) in keys.iter().enumerate() {
            assert_eq!(
                other.exact_match(key.chars()),
                Some(u32::try_from(i).unwrap())
            );
        }
    }

    #[test]
    #[should_panic]
    fn test_deserialize_truncated() {
        let trie = MpTrie::from_keys(["世界", "世論調査"]).unwrap();
        let bytes = trie.serialize_to_vec();
        let _ = MpTrie::deserialize_from_slice(&bytes[..bytes.len() - 1]);
    }

    #[test]
    fn test_empty_set() {
        assert!(MpTrie::from_keys(&[""][0..0]).is_err());
    }

    #[test]
    fn test_empty_char() {
        assert!(MpTrie::from_keys([""]).is_err());
    }

    #[test]
    fn test_empty_key() {
        assert!(MpTrie::from_keys(["", "AAA"]).is_err());
    }

    #[test]
    fn test_unsorted_keys() {
        assert!(MpTrie::from_keys(["BB", "AA"]).is_ok());
        assert!(MpTrie::from_keys(["AAA", "AA"]).is_ok());
    }

    #[test]
    fn test_duplicate_keys() {
        assert!(MpTrie::from_keys(["AA", "AA"]).is_err());
    }
}