mod iter;
pub use iter::TrieFindIter;
use alloc::{
collections::{BTreeMap, VecDeque},
vec::Vec,
};
use crate::daac::DoubleArrayAhoCorasick;
use crate::types::{Match, MatchKind, Output};
/// Sentinel stored in [`TrieState::fail`] instead of a real state index.
///
/// [`Trie::build`] writes it, in the leftmost match modes only, into every
/// state that ends a pattern and into every state below a state whose own
/// failure link is already `DEAD_STATE`. It is `u32::MAX` and must never be
/// used to index `Trie::states`.
// NOTE(review): presumably the matcher treats this as "stop, do not fall back
// any further" — confirm against `TrieFindIter`.
pub const DEAD_STATE: u32 = u32::MAX;
/// One node of the pattern trie.
#[derive(Debug, Clone)]
pub struct TrieState {
/// Outgoing transitions: input byte -> index of the child state in `Trie::states`.
pub edges: BTreeMap<u8, u32>,
/// Failure link: index of the state to fall back to, or `DEAD_STATE`.
/// It is `0` (the root) until one of the build methods computes it.
pub fail: u32,
/// Index into `Trie::outputs` of the output attached to this state, if any.
/// Set by `Trie::add` for the state that ends a pattern; the non-leftmost
/// build may additionally copy an output index from the failure target.
pub outpos: Option<u32>,
}
/// Byte-level pattern trie with failure links, the intermediate form that is
/// later compiled into a `DoubleArrayAhoCorasick`.
#[derive(Debug, Clone)]
pub struct Trie {
/// All states; index `0` is always the root.
pub states: Vec<TrieState>,
/// One entry per `add` call, in insertion order; states refer to these by index.
pub outputs: Vec<Output>,
/// Match kind recorded by the most recent `build` / `build_wordpiece` call
/// (`MatchKind::default()` before any build).
pub match_kind: MatchKind,
/// State reached by walking the WordPiece continuation prefix from the root.
/// Written only by `build_wordpiece`; `None` when that prefix is not a path
/// in the trie, and `None` on a freshly created trie.
pub anchor: Option<u32>,
}
/// `Trie::default()` is the same as [`Trie::new`]: a trie holding only the root state.
impl Default for Trie {
fn default() -> Self {
Self::new()
}
}
impl Trie {
pub fn new() -> Self {
let root = TrieState {
edges: BTreeMap::new(),
fail: 0,
outpos: None,
};
Trie {
states: vec![root],
outputs: Vec::new(),
match_kind: MatchKind::default(),
anchor: None,
}
}
/// Inserts `pattern` into the trie and tags its final state with `pattern_id`.
///
/// A new [`Output`] (with `parent` set to `u32::MAX`, i.e. no parent) is
/// appended to `self.outputs` on every call; the state reached after
/// consuming all bytes of `pattern` then points at it through `outpos`.
/// Missing states along the path are created on demand.
///
/// Adding the same byte sequence twice overwrites the final state's
/// `outpos`, so only the most recently added output stays reachable from
/// that state. An empty `pattern` attaches the output to the root.
///
/// Failure links are not touched here; call [`Trie::build`] afterwards.
pub fn add(&mut self, pattern: &[u8], pattern_id: u32) {
    let output_index = self.outputs.len() as u32;
    self.outputs.push(Output {
        pattern_id,
        length: pattern.len() as u32,
        parent: u32::MAX,
    });

    let mut current = 0usize;
    for &byte in pattern {
        // Follow the existing edge when there is one.
        if let Some(&existing) = self.states[current].edges.get(&byte) {
            current = existing as usize;
            continue;
        }
        // Otherwise grow the trie by one state and link it in.
        let fresh = self.states.len() as u32;
        self.states.push(TrieState {
            edges: BTreeMap::new(),
            fail: 0,
            outpos: None,
        });
        self.states[current].edges.insert(byte, fresh);
        current = fresh as usize;
    }

    self.states[current].outpos = Some(output_index);
}
/// Computes failure links (and, where applicable, output chains) for the
/// patterns added so far, according to `match_kind`.
///
/// * `MatchKind::WordPiece` delegates to [`Trie::build_wordpiece`] with the
///   prefix `b"##"`.
/// * `MatchKind::LeftmostFirst` first prunes outputs that are shadowed by an
///   earlier-added prefix pattern, then builds leftmost links.
/// * In both leftmost kinds, a state that ends a pattern gets
///   `DEAD_STATE` as its failure link, and so does every state below a
///   state whose failure link is already `DEAD_STATE`.
/// * In every other kind, ordinary failure links are built. A state with
///   its own output has that output's `parent` set to the failure target's
///   output, and a state without one inherits the failure target's `outpos`.
///
/// States are processed breadth-first so that a failure target (always
/// shallower) is finished before any state that refers to it.
///
// NOTE(review): the non-leftmost path rewrites `outpos`/`parent` in place, so
// it appears to be intended to run once per trie; calling it a second time
// would treat inherited outputs as own outputs. Confirm callers build once.
pub fn build(&mut self, match_kind: MatchKind) {
    self.match_kind = match_kind;
    match match_kind {
        MatchKind::WordPiece => {
            self.build_wordpiece(b"##");
            return;
        }
        MatchKind::LeftmostFirst => self.prune_outputs_for_leftmost_first(),
        _ => {}
    }
    let leftmost = matches!(
        match_kind,
        MatchKind::LeftmostFirst | MatchKind::LeftmostLongest
    );

    // Depth-1 states fall back to the root, or are dead when they already
    // complete a pattern in a leftmost mode.
    let mut pending: VecDeque<usize> = VecDeque::new();
    let depth_one: Vec<u32> = self.states[0].edges.values().copied().collect();
    for id in depth_one {
        let node = &mut self.states[id as usize];
        node.fail = if leftmost && node.outpos.is_some() {
            DEAD_STATE
        } else {
            0
        };
        pending.push_back(id as usize);
    }

    while let Some(parent) = pending.pop_front() {
        let parent_fail = self.states[parent].fail;
        let transitions: Vec<(u8, u32)> = self.states[parent]
            .edges
            .iter()
            .map(|(&byte, &target)| (byte, target))
            .collect();

        for (byte, target) in transitions {
            let child = target as usize;
            pending.push_back(child);

            // Everything below a dead state is dead as well.
            if leftmost && parent_fail == DEAD_STATE {
                self.states[child].fail = DEAD_STATE;
                continue;
            }

            // Walk the parent's failure chain until some state has an edge
            // for `byte`, the root is reached, or (leftmost only) the chain
            // would step into the dead state.
            let mut probe = parent_fail as usize;
            while probe != 0 && !self.states[probe].edges.contains_key(&byte) {
                let next = self.states[probe].fail;
                if leftmost && next == DEAD_STATE {
                    break;
                }
                probe = next as usize;
            }
            let computed_fail = self.states[probe]
                .edges
                .get(&byte)
                .copied()
                .unwrap_or(0);

            if leftmost {
                self.states[child].fail = if self.states[child].outpos.is_some() {
                    DEAD_STATE
                } else {
                    computed_fail
                };
                continue;
            }

            self.states[child].fail = computed_fail;

            // `u32::MAX` doubles as "no output" here; the root's own output
            // (if any) is deliberately never inherited.
            let inherited = match computed_fail {
                0 => u32::MAX,
                fallback => self.states[fallback as usize]
                    .outpos
                    .unwrap_or(u32::MAX),
            };
            match self.states[child].outpos {
                Some(own) => self.outputs[own as usize].parent = inherited,
                None if inherited != u32::MAX => {
                    self.states[child].outpos = Some(inherited);
                }
                None => {}
            }
        }
    }
}
/// Consumes the trie and hands it to `DoubleArrayAhoCorasick::from_trie`.
// NOTE(review): presumably `build` must have been called first so the failure
// links are populated — confirm against `from_trie`.
pub fn compile(self) -> DoubleArrayAhoCorasick {
DoubleArrayAhoCorasick::from_trie(self)
}
/// Drops outputs that can never win under leftmost-first semantics.
///
/// Output indices grow in insertion order, so a smaller `outpos` means the
/// pattern was added earlier. Walking the trie depth-first, a state's
/// output is cleared when some state on the path from the root to it
/// carries a smaller output index, i.e. an earlier-added pattern is a
/// proper prefix of this one. The entries in `self.outputs` themselves are
/// left in place; only the states' `outpos` links are removed.
fn prune_outputs_for_leftmost_first(&mut self) {
    // Each entry: (state to visit, smallest surviving outpos above it).
    let mut pending: Vec<(usize, Option<u32>)> = Vec::new();
    pending.push((0, None));

    while let Some((id, best_above)) = pending.pop() {
        let node = &mut self.states[id];

        let shadowed = matches!(
            (best_above, node.outpos),
            (Some(above), Some(own)) if above < own
        );
        if shadowed {
            node.outpos = None;
        }

        // Smallest output index seen on the path including this state.
        let best_here = match (best_above, node.outpos) {
            (Some(above), Some(own)) => Some(above.min(own)),
            (above, own) => above.or(own),
        };

        pending.extend(
            node.edges
                .values()
                .map(|&child| (child as usize, best_here)),
        );
    }
}
/// Builds failure links for WordPiece-style matching.
///
/// `prefix` is the continuation marker ([`Trie::build`] passes `b"##"`).
/// The state reached by walking `prefix` from the root becomes the
/// *anchor* and is stored in `self.anchor`.
///
/// * If `prefix` is not a path in the trie, `self.anchor` is `None` and
///   only the standard failure links are built.
/// * Otherwise, after the standard links are built, every state that
///   carries an output has its failure link redirected to the anchor, and
///   the anchor's own failure link is reset to the root.
///
/// `self.match_kind` is set to `MatchKind::WordPiece` in both cases.
pub fn build_wordpiece(&mut self, prefix: &[u8]) {
    self.match_kind = MatchKind::WordPiece;

    // Walk the prefix; `None` as soon as an edge is missing.
    let located = prefix.iter().try_fold(0u32, |at, byte| {
        self.states[at as usize].edges.get(byte).copied()
    });
    self.anchor = located;

    self.build_standard_failure_links();

    if let Some(anchor) = located {
        for state in self.states.iter_mut().filter(|s| s.outpos.is_some()) {
            state.fail = anchor;
        }
        // The anchor may itself carry an output; it must still fall back to the root.
        self.states[anchor as usize].fail = 0;
    }
}
/// Computes plain Aho-Corasick failure links for every state.
///
/// States directly below the root fail to the root. For any deeper state
/// reached from its parent via `byte`, the failure link is the `byte`-child
/// of the first state on the parent's failure chain that has such an edge,
/// or the root when none does. Outputs are not modified.
fn build_standard_failure_links(&mut self) {
    let mut frontier: VecDeque<usize> = self.states[0]
        .edges
        .values()
        .map(|&id| id as usize)
        .collect();
    for &id in &frontier {
        self.states[id].fail = 0;
    }

    // Breadth-first: a parent's link is final before its children are visited.
    while let Some(parent) = frontier.pop_front() {
        let parent_fail = self.states[parent].fail as usize;
        let transitions: Vec<(u8, u32)> = self.states[parent]
            .edges
            .iter()
            .map(|(&byte, &target)| (byte, target))
            .collect();

        for (byte, target) in transitions {
            frontier.push_back(target as usize);

            let mut probe = parent_fail;
            let fallback = loop {
                if let Some(&hit) = self.states[probe].edges.get(&byte) {
                    break hit;
                }
                if probe == 0 {
                    // Even the root has no such edge: fall back to the root.
                    break 0;
                }
                probe = self.states[probe].fail as usize;
            };
            self.states[target as usize].fail = fallback;
        }
    }
}
/// Returns a [`TrieFindIter`] over `text`, borrowing both the trie and the text.
// NOTE(review): which matches are produced depends on `TrieFindIter` and on
// the match kind the trie was built with; the tests always call `build` (or
// `build_wordpiece`) before searching.
pub fn find_iter<'a>(&'a self, text: &'a [u8]) -> TrieFindIter<'a> {
TrieFindIter::new(self, text)
}
/// Convenience wrapper: runs [`Trie::find_iter`] over `text` and collects
/// every yielded [`Match`] into a `Vec`, in the order the iterator produces them.
pub fn find(&self, text: &[u8]) -> Vec<Match> {
self.find_iter(text).collect()
}
/// Number of states in the trie, including the root (so never less than 1
/// for a trie created through [`Trie::new`]).
pub fn num_states(&self) -> usize {
self.states.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // ----- helpers -------------------------------------------------------

    /// Creates an unbuilt trie from `patterns`, numbering them 0, 1, 2, ... in order.
    fn trie_of(patterns: &[&str]) -> Trie {
        let mut trie = Trie::new();
        for (id, pattern) in patterns.iter().enumerate() {
            trie.add(pattern.as_bytes(), id as u32);
        }
        trie
    }

    /// Creates a trie from `patterns` and builds it with `kind`.
    fn built(patterns: &[&str], kind: MatchKind) -> Trie {
        let mut trie = trie_of(patterns);
        trie.build(kind);
        trie
    }

    /// Creates a trie from `patterns` and builds it via `build_wordpiece(b"##")`.
    fn built_wordpiece(patterns: &[&str]) -> Trie {
        let mut trie = trie_of(patterns);
        trie.build_wordpiece(b"##");
        trie
    }

    /// Follows `path` from the root and returns the state reached.
    ///
    /// Panics on a missing edge so that a malformed trie can never make an
    /// assertion pass by silently falling back to the root.
    fn state_at(trie: &Trie, path: &str) -> u32 {
        path.bytes().fold(0u32, |state, byte| {
            trie.states[state as usize]
                .edges
                .get(&byte)
                .copied()
                .unwrap_or_else(|| {
                    panic!("no edge for byte {:#04x} while walking {:?}", byte, path)
                })
        })
    }

    /// Failure link of the state reached via `path`.
    fn fail_of(trie: &Trie, path: &str) -> u32 {
        trie.states[state_at(trie, path) as usize].fail
    }

    /// Whether the state reached via `path` carries an output.
    fn has_output(trie: &Trie, path: &str) -> bool {
        trie.states[state_at(trie, path) as usize].outpos.is_some()
    }

    /// Flattens matches into `(pattern_id, start, end)` tuples.
    fn spans(matches: &[Match]) -> Vec<(u32, usize, usize)> {
        matches
            .iter()
            .map(|m| (m.pattern_id, m.start, m.end))
            .collect()
    }

    // ----- construction --------------------------------------------------

    #[test]
    fn test_new_creates_root_state() {
        let trie = Trie::new();
        assert_eq!(trie.states.len(), 1);
        assert_eq!(trie.states[0].edges.len(), 0);
        assert_eq!(trie.states[0].fail, 0);
    }

    #[test]
    fn test_add_single_pattern() {
        let trie = trie_of(&["he"]);
        // root + "h" + "he"
        assert_eq!(trie.states.len(), 3);
        let outpos = trie.states[2]
            .outpos
            .expect("state for \"he\" must carry an output");
        let output = &trie.outputs[outpos as usize];
        assert_eq!(output.pattern_id, 0);
        assert_eq!(output.length, 2);
    }

    /// Checks the failure link of "she": it must point at "he".
    #[test]
    fn test_build_fails() {
        let trie = built(&["he", "she"], MatchKind::Overlapping);
        assert_eq!(fail_of(&trie, "she"), state_at(&trie, "he"));
    }

    // ----- overlapping search --------------------------------------------

    #[test]
    fn test_find_single_pattern() {
        let trie = built(&["he"], MatchKind::Overlapping);
        assert_eq!(spans(&trie.find(b"she")), [(0, 1, 3)]);
    }

    #[test]
    fn test_find_no_match() {
        let trie = built(&["xyz"], MatchKind::Overlapping);
        assert_eq!(trie.find(b"abc").len(), 0);
    }

    #[test]
    fn test_find_multiple_patterns() {
        let trie = built(&["he", "she", "hers"], MatchKind::Overlapping);
        let found = spans(&trie.find(b"ushers"));
        assert_eq!(found.len(), 3);
        assert!(found.contains(&(1, 1, 4))); // "she"
        assert!(found.contains(&(0, 2, 4))); // "he"
        assert!(found.contains(&(2, 2, 6))); // "hers"
    }

    #[test]
    fn test_find_at_start() {
        let trie = built(&["hello"], MatchKind::Overlapping);
        assert_eq!(spans(&trie.find(b"hello world")), [(0, 0, 5)]);
    }

    #[test]
    fn test_find_at_end() {
        let trie = built(&["end"], MatchKind::Overlapping);
        assert_eq!(spans(&trie.find(b"the end")), [(0, 4, 7)]);
    }

    #[test]
    fn test_find_overlapping() {
        let trie = built(&["a", "aa", "aaa"], MatchKind::Overlapping);
        // 4 x "a" + 3 x "aa" + 2 x "aaa"
        assert_eq!(trie.find(b"aaaa").len(), 9);
    }

    #[test]
    fn test_find_empty_text() {
        let trie = built(&["he"], MatchKind::Overlapping);
        assert_eq!(trie.find(b"").len(), 0);
    }

    // ----- leftmost-longest ----------------------------------------------

    #[test]
    fn test_leftmost_longest_build_dead_state() {
        let trie = built(&["he", "she"], MatchKind::LeftmostLongest);
        assert_eq!(fail_of(&trie, "he"), DEAD_STATE);
    }

    #[test]
    fn test_leftmost_longest_propagates_dead() {
        let trie = built(&["he", "hers"], MatchKind::LeftmostLongest);
        assert_eq!(fail_of(&trie, "he"), DEAD_STATE);
        assert_eq!(fail_of(&trie, "her"), DEAD_STATE);
        assert_eq!(fail_of(&trie, "hers"), DEAD_STATE);
    }

    #[test]
    fn test_leftmost_longest_find_non_overlapping() {
        let trie = built(&["a", "aa", "aaa"], MatchKind::LeftmostLongest);
        assert_eq!(spans(&trie.find(b"aaaa")), [(2, 0, 3), (0, 3, 4)]);
    }

    #[test]
    fn test_leftmost_longest_different_add_order() {
        let trie = built(&["aaa", "aa", "a"], MatchKind::LeftmostLongest);
        assert_eq!(spans(&trie.find(b"aaaa")), [(0, 0, 3), (2, 3, 4)]);
    }

    #[test]
    fn test_leftmost_longest_find_she_he() {
        let trie = built(&["he", "she"], MatchKind::LeftmostLongest);
        assert_eq!(spans(&trie.find(b"she")), [(1, 0, 3)]);
    }

    #[test]
    fn test_leftmost_longest_find_ushers() {
        let trie = built(&["he", "she", "hers"], MatchKind::LeftmostLongest);
        let matches = trie.find(b"ushers");
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pattern_id, 1); // "she"
    }

    // ----- leftmost-first ------------------------------------------------

    #[test]
    fn test_leftmost_first_prunes_outputs() {
        let trie = built(&["a", "aa", "aaa"], MatchKind::LeftmostFirst);
        assert!(has_output(&trie, "a"));
        // Both longer patterns are shadowed by the earlier-added "a".
        assert!(!has_output(&trie, "aa"));
        assert!(!has_output(&trie, "aaa"));
    }

    #[test]
    fn test_leftmost_first_short_pattern_wins() {
        let trie = built(&["a", "aa", "aaa"], MatchKind::LeftmostFirst);
        assert_eq!(
            spans(&trie.find(b"aaaa")),
            [(0, 0, 1), (0, 1, 2), (0, 2, 3), (0, 3, 4)]
        );
    }

    #[test]
    fn test_leftmost_first_long_pattern_first() {
        let trie = built(&["aaa", "aa", "a"], MatchKind::LeftmostFirst);
        // Longer patterns were added first, so nothing is shadowed.
        assert!(has_output(&trie, "a"));
        assert!(has_output(&trie, "aa"));
        assert!(has_output(&trie, "aaa"));
    }

    #[test]
    fn test_leftmost_first_long_pattern_first_finds_longest() {
        let trie = built(&["aaa", "aa", "a"], MatchKind::LeftmostFirst);
        assert_eq!(spans(&trie.find(b"aaaa")), [(0, 0, 3), (2, 3, 4)]);
    }

    #[test]
    fn test_leftmost_first_non_prefix_patterns() {
        let trie = built(&["ab", "ac"], MatchKind::LeftmostFirst);
        assert!(has_output(&trie, "ab"));
        assert!(has_output(&trie, "ac"));
    }

    #[test]
    fn test_leftmost_first_stores_match_kind() {
        let mut trie = trie_of(&["test"]);
        assert!(matches!(trie.match_kind, MatchKind::Overlapping));
        trie.build(MatchKind::LeftmostFirst);
        assert!(matches!(trie.match_kind, MatchKind::LeftmostFirst));
    }

    #[test]
    fn test_leftmost_longest_stores_match_kind() {
        let trie = built(&["test"], MatchKind::LeftmostLongest);
        assert!(matches!(trie.match_kind, MatchKind::LeftmostLongest));
    }

    // ----- wordpiece -----------------------------------------------------

    #[test]
    fn test_wordpiece_finds_anchor() {
        let trie = built_wordpiece(&["un", "break", "##break", "##able"]);
        assert!(trie.anchor.is_some());
        assert_eq!(trie.anchor, Some(state_at(&trie, "##")));
    }

    #[test]
    fn test_wordpiece_failure_links_point_to_anchor() {
        let trie = built_wordpiece(&["un", "break", "##break", "##able"]);
        let anchor = trie.anchor.unwrap();
        for path in ["un", "break", "##able"] {
            assert!(has_output(&trie, path), "{:?} should carry an output", path);
            assert_eq!(
                fail_of(&trie, path),
                anchor,
                "{:?} should fail to the anchor",
                path
            );
        }
    }

    #[test]
    fn test_wordpiece_stores_match_kind() {
        let trie = built_wordpiece(&["test", "##ing"]);
        assert!(matches!(trie.match_kind, MatchKind::WordPiece));
    }

    #[test]
    fn test_wordpiece_no_anchor_prefix() {
        let trie = built_wordpiece(&["hello", "world"]);
        assert!(trie.anchor.is_none());
    }

    #[test]
    fn test_wordpiece_find_uses_leftmost_longest() {
        let trie = built_wordpiece(&["a", "ab", "abc", "##d"]);
        assert_eq!(spans(&trie.find(b"abc")), [(2, 0, 3)]);
    }
}