use super::{DoubleArrayAhoCorasick, ROOT_IDX};
use crate::trie::DEAD_STATE;
use crate::types::{Match, MatchKind, Output};
/// Iterator over the matches that a `DoubleArrayAhoCorasick` finds in a byte slice.
///
/// Which search routine produces the matches depends on the automaton's `match_kind`.
pub struct FindIter<'a> {
/// Automaton whose `states` and `outputs` tables drive the search.
daac: &'a DoubleArrayAhoCorasick,
/// Bytes being searched.
text: &'a [u8],
/// Overlapping search: offset of the next byte to consume, which is also the exclusive end of
/// any match reported for the current state. Leftmost search: exclusive end of the last match
/// found, i.e. the offset at which the next scan resumes.
pos: usize,
/// Current automaton state; only the overlapping search reads and updates it.
state: u32,
/// Index into `daac.outputs` of the next pending output, or `u32::MAX` when there is none;
/// only the overlapping search uses it.
outpos: u32,
}
impl<'a> FindIter<'a> {
    /// Creates an iterator that reports the matches of `daac` in `text`.
    ///
    /// The scan starts at offset 0, in the root state, with no pending output.
    pub(super) fn new(daac: &'a DoubleArrayAhoCorasick, text: &'a [u8]) -> Self {
        Self {
            daac,
            text,
            pos: 0,
            state: ROOT_IDX,
            outpos: u32::MAX,
        }
    }

    /// Pops the output that `self.outpos` points at, if any, and turns it into a match that ends
    /// at `self.pos`.
    ///
    /// `self.outpos` is advanced to the output's `parent`, so repeated calls walk the whole
    /// chain and finally return `None` once it reaches `u32::MAX`.
    #[inline]
    fn take_pending_match(&mut self) -> Option<Match> {
        if self.outpos == u32::MAX {
            return None;
        }
        let output = &self.daac.outputs[self.outpos as usize];
        let end = self.pos;
        self.outpos = output.parent;
        Some(Match {
            pattern_id: output.pattern_id,
            start: end - output.length as usize,
            end,
        })
    }

    /// Builds the match for `outputs[outpos]`, ending at `self.pos`.
    ///
    /// The leftmost search stores the end of the remembered match in `self.pos` at the moment
    /// it records `outpos`, so the two always belong together.
    #[inline]
    fn leftmost_match(&self, outpos: u32) -> Match {
        let output = &self.daac.outputs[outpos as usize];
        let end = self.pos;
        Match {
            pattern_id: output.pattern_id,
            start: end - output.length as usize,
            end,
        }
    }

    /// Computes the transition taken from `state` on `byte` for the leftmost search.
    ///
    /// Returns `(child, child.outpos)` for the first state on the `fail` chain (starting with
    /// `state` itself) that has a transition on `byte`. Returns `(ROOT_IDX, u32::MAX)` when the
    /// chain reaches the root without finding one, or when a `fail` link is `DEAD_STATE`.
    #[inline(always)]
    fn next_state_leftmost(&self, mut state: u32, byte: u8) -> (u32, u32) {
        let states = &self.daac.states;
        let states_len = states.len();
        loop {
            let current = &states[state as usize];
            // Double-array lookup: the slot is `base ^ byte`, and it belongs to `state` only if
            // its `check` field points back at `state`.
            let child = current.base ^ u32::from(byte);
            if (child as usize) < states_len {
                let child_state = &states[child as usize];
                if child_state.check == state {
                    return (child, child_state.outpos);
                }
            }
            if state == ROOT_IDX || current.fail == DEAD_STATE {
                return (ROOT_IDX, u32::MAX);
            }
            state = current.fail;
        }
    }

    /// Returns the next match of the overlapping search, or `None` at the end of the text.
    ///
    /// Every output chained to a reached state is reported before the next byte is consumed.
    #[inline]
    fn next_overlapping(&mut self) -> Option<Match> {
        // Outputs left over from the state reached by the previous call come first.
        if let Some(found) = self.take_pending_match() {
            return Some(found);
        }
        // Copy the shared reference out of `self` so that `states` does not keep `self`
        // borrowed across the `&mut self` calls below.
        let daac = self.daac;
        let states = &daac.states;
        let states_len = states.len();
        while self.pos < self.text.len() {
            let byte = self.text[self.pos];
            self.pos += 1;
            // Follow `fail` links until some state has a transition on `byte`; the root simply
            // stays where it is when it has none.
            loop {
                let current = &states[self.state as usize];
                let child = current.base ^ u32::from(byte);
                if (child as usize) < states_len && states[child as usize].check == self.state {
                    self.state = child;
                    break;
                }
                if self.state == ROOT_IDX {
                    break;
                }
                self.state = current.fail;
            }
            self.outpos = states[self.state as usize].outpos;
            if let Some(found) = self.take_pending_match() {
                return Some(found);
            }
        }
        None
    }

    /// Returns the next match of the leftmost search, or `None` when no further match exists.
    ///
    /// Each call restarts in the root state at `self.pos`. While scanning it remembers the most
    /// recent output it passed (and stores the offset just after it in `self.pos`); that match
    /// is reported as soon as the automaton is sent back to the root, or when the text ends.
    #[inline(always)]
    fn next_leftmost(&mut self) -> Option<Match> {
        let text = self.text;
        let mut state = ROOT_IDX;
        let mut last_outpos = u32::MAX;
        let mut pos = self.pos;
        while pos < text.len() {
            let (next_state, outpos) = self.next_state_leftmost(state, text[pos]);
            pos += 1;
            state = next_state;
            if next_state == ROOT_IDX {
                // Back at the root: a remembered match cannot be extended any further.
                if last_outpos != u32::MAX {
                    return Some(self.leftmost_match(last_outpos));
                }
            } else if outpos != u32::MAX {
                last_outpos = outpos;
                self.pos = pos;
            }
        }
        if last_outpos != u32::MAX {
            return Some(self.leftmost_match(last_outpos));
        }
        // Nothing matched in the rest of the text. Park the cursor at the end so that further
        // calls return `None` immediately instead of re-scanning the same tail.
        self.pos = text.len();
        None
    }
}
impl<'a> Iterator for FindIter<'a> {
    type Item = Match;

    /// Yields the next match, using the search routine selected by the automaton's match kind:
    /// `Overlapping` uses the overlapping search, every other kind uses the leftmost search.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        match self.daac.match_kind {
            MatchKind::LeftmostFirst | MatchKind::LeftmostLongest | MatchKind::WordPiece => {
                self.next_leftmost()
            }
            MatchKind::Overlapping => self.next_overlapping(),
        }
    }
}
/// Iterator that walks a chain of `Output` entries linked through their `parent` field.
pub struct OutputIter<'a> {
/// Table that the `parent` links index into.
outputs: &'a [Output],
/// Index of the next entry to yield, or `u32::MAX` once the chain has ended.
outpos: u32,
}
impl<'a> OutputIter<'a> {
pub(super) fn new(outputs: &'a [Output], outpos: u32) -> Self {
Self { outputs, outpos }
}
}
impl<'a> Iterator for OutputIter<'a> {
    type Item = &'a Output;

    /// Yields the entry at the current index and moves on to its `parent`; returns `None` once
    /// the index is `u32::MAX`.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let index = self.outpos;
        if index != u32::MAX {
            let entry = &self.outputs[index as usize];
            self.outpos = entry.parent;
            Some(entry)
        } else {
            None
        }
    }
}
#[cfg(test)]
mod tests {
    use super::super::DoubleArrayAhoCorasick;
    use crate::trie::Trie;
    use crate::types::{Match, MatchKind};

    // Builds a `DoubleArrayAhoCorasick` from `pattern => id` pairs: the patterns are added to a
    // fresh `Trie` in the order written, the trie is built with the given match kind, and the
    // result is converted with `from_trie`.
    macro_rules! automaton {
        ($kind:expr; $($pattern:expr => $id:expr),+ $(,)?) => {{
            let mut trie = Trie::new();
            $(trie.add($pattern, $id);)+
            trie.build($kind);
            DoubleArrayAhoCorasick::from_trie(trie)
        }};
    }

    #[test]
    fn test_find_single_pattern() {
        let automaton = automaton!(MatchKind::Overlapping; b"he" => 0);

        let found = automaton.find(b"she");

        assert_eq!(found.len(), 1);
        let only = &found[0];
        assert_eq!(only.pattern_id, 0);
        assert_eq!((only.start, only.end), (1, 3));
    }

    #[test]
    fn test_find_multiple_patterns() {
        let automaton = automaton!(
            MatchKind::Overlapping;
            b"he" => 0,
            b"she" => 1,
            b"hers" => 2,
        );

        let found = automaton.find(b"ushers");

        assert_eq!(found.len(), 3);
        let spans: Vec<(u32, usize, usize)> = found
            .iter()
            .map(|m| (m.pattern_id, m.start, m.end))
            .collect();
        // "she" at 1..4, "he" at 2..4 and "hers" at 2..6.
        for expected in &[(1, 1, 4), (0, 2, 4), (2, 2, 6)] {
            assert!(spans.contains(expected));
        }
    }

    #[test]
    fn test_find_no_match() {
        let automaton = automaton!(MatchKind::Overlapping; b"xyz" => 0);

        let found = automaton.find(b"abc");

        assert_eq!(found.len(), 0);
    }

    #[test]
    fn test_find_empty_text() {
        let automaton = automaton!(MatchKind::Overlapping; b"he" => 0);

        let found = automaton.find(b"");

        assert_eq!(found.len(), 0);
    }

    #[test]
    fn test_find_at_start() {
        let automaton = automaton!(MatchKind::Overlapping; b"hello" => 0);

        let found = automaton.find(b"hello world");

        assert_eq!(found.len(), 1);
        assert_eq!((found[0].start, found[0].end), (0, 5));
    }

    #[test]
    fn test_find_at_end() {
        let automaton = automaton!(MatchKind::Overlapping; b"end" => 0);

        let found = automaton.find(b"the end");

        assert_eq!(found.len(), 1);
        assert_eq!((found[0].start, found[0].end), (4, 7));
    }

    #[test]
    fn test_find_overlapping() {
        let automaton = automaton!(
            MatchKind::Overlapping;
            b"a" => 0,
            b"aa" => 1,
            b"aaa" => 2,
        );

        let found = automaton.find(b"aaaa");

        // Four "a", three "aa" and two "aaa".
        assert_eq!(found.len(), 9);
    }

    #[test]
    fn test_find_iter_count() {
        let automaton = automaton!(MatchKind::Overlapping; b"a" => 0, b"aa" => 1);

        let lazy_count = automaton.find_iter(b"aaaa").count();
        let eager = automaton.find(b"aaaa");

        assert_eq!(lazy_count, eager.len());
    }

    #[test]
    fn test_find_iter_early_termination() {
        let automaton = automaton!(MatchKind::Overlapping; b"a" => 0);

        let first_two: Vec<Match> = automaton.find_iter(b"aaaaa").take(2).collect();

        assert_eq!(first_two.len(), 2);
        let starts: Vec<usize> = first_two.iter().map(|m| m.start).collect();
        assert_eq!(starts, vec![0, 1]);
    }

    #[test]
    fn test_find_iter_matches_find() {
        let automaton = automaton!(
            MatchKind::Overlapping;
            b"he" => 0,
            b"she" => 1,
            b"his" => 2,
            b"hers" => 3,
        );
        let text = b"ushers and his";

        let lazy: Vec<Match> = automaton.find_iter(text).collect();
        let eager = automaton.find(text);

        assert_eq!(lazy.len(), eager.len());
        for (from_iter, from_find) in lazy.iter().zip(eager.iter()) {
            assert_eq!(from_iter.pattern_id, from_find.pattern_id);
            assert_eq!(from_iter.start, from_find.start);
            assert_eq!(from_iter.end, from_find.end);
        }
    }

    #[test]
    fn test_find_suffix_outputs() {
        let automaton = automaton!(MatchKind::Overlapping; b"he" => 0, b"she" => 1);

        let found = automaton.find(b"she");

        assert_eq!(found.len(), 2);
        let ids: Vec<u32> = found.iter().map(|m| m.pattern_id).collect();
        // Both "he" and "she" end at the same position.
        for expected in &[0, 1] {
            assert!(ids.contains(expected));
        }
    }

    #[test]
    fn test_leftmost_longest_find_non_overlapping() {
        let automaton = automaton!(
            MatchKind::LeftmostLongest;
            b"a" => 0,
            b"aa" => 1,
            b"aaa" => 2,
        );

        let found = automaton.find(b"aaaa");

        assert_eq!(found.len(), 2);
        // "aaa" first, then the remaining "a".
        let (first, second) = (&found[0], &found[1]);
        assert_eq!((first.pattern_id, first.start, first.end), (2, 0, 3));
        assert_eq!((second.pattern_id, second.start, second.end), (0, 3, 4));
    }

    #[test]
    fn test_leftmost_longest_find_she_he() {
        let automaton = automaton!(MatchKind::LeftmostLongest; b"he" => 0, b"she" => 1);

        let found = automaton.find(b"she");

        assert_eq!(found.len(), 1);
        // "she" starts further left than "he".
        let only = &found[0];
        assert_eq!((only.pattern_id, only.start, only.end), (1, 0, 3));
    }

    #[test]
    fn test_leftmost_longest_find_ushers() {
        let automaton = automaton!(
            MatchKind::LeftmostLongest;
            b"he" => 0,
            b"she" => 1,
            b"hers" => 2,
        );

        let found = automaton.find(b"ushers");

        assert_eq!(found.len(), 1);
        // "she"
        assert_eq!(found[0].pattern_id, 1);
    }

    #[test]
    fn test_leftmost_longest_stores_match_kind() {
        let automaton = automaton!(MatchKind::LeftmostLongest; b"test" => 0);

        assert!(matches!(automaton.match_kind, MatchKind::LeftmostLongest));
    }

    #[test]
    fn test_leftmost_longest_matches_trie() {
        let mut reference = Trie::new();
        reference.add(b"a", 0);
        reference.add(b"aa", 1);
        reference.add(b"aaa", 2);
        reference.build(MatchKind::LeftmostLongest);
        let automaton = automaton!(
            MatchKind::LeftmostLongest;
            b"a" => 0,
            b"aa" => 1,
            b"aaa" => 2,
        );
        let text = b"aaaaaaa";

        let expected = reference.find(text);
        let actual = automaton.find(text);

        assert_eq!(expected.len(), actual.len());
        for (want, got) in expected.iter().zip(actual.iter()) {
            assert_eq!(want.pattern_id, got.pattern_id);
            assert_eq!(want.start, got.start);
            assert_eq!(want.end, got.end);
        }
    }

    #[test]
    fn test_leftmost_first_short_pattern_wins() {
        let automaton = automaton!(
            MatchKind::LeftmostFirst;
            b"a" => 0,
            b"aa" => 1,
            b"aaa" => 2,
        );

        let found = automaton.find(b"aaaa");

        assert_eq!(found.len(), 4);
        // "a" was added first, so it is reported at every position.
        for (offset, m) in found.iter().enumerate() {
            assert_eq!((m.pattern_id, m.start, m.end), (0, offset, offset + 1));
        }
    }

    #[test]
    fn test_leftmost_first_long_pattern_first() {
        let automaton = automaton!(
            MatchKind::LeftmostFirst;
            b"aaa" => 0,
            b"aa" => 1,
            b"a" => 2,
        );

        let found = automaton.find(b"aaaa");

        assert_eq!(found.len(), 2);
        // "aaa" first, then the remaining "a".
        let (first, second) = (&found[0], &found[1]);
        assert_eq!((first.pattern_id, first.start, first.end), (0, 0, 3));
        assert_eq!((second.pattern_id, second.start, second.end), (2, 3, 4));
    }

    #[test]
    fn test_leftmost_first_stores_match_kind() {
        let automaton = automaton!(MatchKind::LeftmostFirst; b"test" => 0);

        assert!(matches!(automaton.match_kind, MatchKind::LeftmostFirst));
    }

    #[test]
    fn test_leftmost_first_matches_trie() {
        let mut reference = Trie::new();
        reference.add(b"a", 0);
        reference.add(b"aa", 1);
        reference.add(b"aaa", 2);
        reference.build(MatchKind::LeftmostFirst);
        let automaton = automaton!(
            MatchKind::LeftmostFirst;
            b"a" => 0,
            b"aa" => 1,
            b"aaa" => 2,
        );
        let text = b"aaaaaaa";

        let expected = reference.find(text);
        let actual = automaton.find(text);

        assert_eq!(expected.len(), actual.len());
        for (want, got) in expected.iter().zip(actual.iter()) {
            assert_eq!(want.pattern_id, got.pattern_id);
            assert_eq!(want.start, got.start);
            assert_eq!(want.end, got.end);
        }
    }
}