use crate::error::{SeqError, SeqResult};
use std::collections::VecDeque;
/// Sentinel node index meaning "no node": marks absent trie edges and absent
/// dictionary links.
const NONE: usize = usize::MAX;
/// A single occurrence of a pattern inside a searched text.
///
/// The derived ordering compares `pattern_id`, then `start`, then `end`
/// (field declaration order).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Match {
/// Position of the matched pattern in the slice given to `AhoCorasick::new`.
pub pattern_id: usize,
/// Byte offset in the text where the occurrence begins (inclusive).
pub start: usize,
/// Byte offset in the text just past the occurrence (exclusive).
pub end: usize,
}
/// One state of the pattern trie that backs the automaton.
#[derive(Debug, Clone)]
struct Node {
    /// Outgoing edge for every possible byte value; `NONE` marks a missing edge.
    next: [usize; 256],
    /// Index of the state to fall back to when no edge exists for the next byte.
    fail: usize,
    /// Ids of the patterns that end exactly at this state.
    outputs: Vec<usize>,
    /// Index of another state whose outputs must also be reported when this
    /// state is reached, or `NONE` when there is no such state.
    dict_link: usize,
}

impl Node {
    /// Creates a blank state: no edges, no outputs, failure link pointing at
    /// index 0 and no dictionary link.
    fn new() -> Self {
        let no_edges = [NONE; 256];
        Node {
            dict_link: NONE,
            outputs: Vec::new(),
            fail: 0,
            next: no_edges,
        }
    }
}
/// Multi-pattern byte matcher built on a trie with failure and dictionary links.
#[derive(Debug, Clone)]
pub struct AhoCorasick {
// All automaton states; index 0 is the root.
nodes: Vec<Node>,
// Byte length of each pattern, indexed by pattern id.
pattern_lens: Vec<usize>,
}
impl AhoCorasick {
    /// Builds an automaton that recognises every pattern in `patterns`.
    ///
    /// A pattern's id is its position in the slice. An empty slice is accepted
    /// and produces an automaton consisting of the root state only.
    ///
    /// # Errors
    ///
    /// Returns `SeqError::EmptyInput` if any pattern has zero length.
    pub fn new<P: AsRef<[u8]>>(patterns: &[P]) -> SeqResult<Self> {
        let mut nodes = vec![Node::new()];
        let mut pattern_lens = Vec::with_capacity(patterns.len());

        for (pattern_id, pattern) in patterns.iter().enumerate() {
            let bytes = pattern.as_ref();
            if bytes.is_empty() {
                return Err(SeqError::EmptyInput);
            }
            pattern_lens.push(bytes.len());

            // Walk the trie from the root, creating states for missing edges.
            let mut cursor = 0usize;
            for slot in bytes.iter().copied().map(usize::from) {
                let existing = nodes[cursor].next[slot];
                if existing != NONE {
                    cursor = existing;
                    continue;
                }
                let fresh = nodes.len();
                nodes.push(Node::new());
                nodes[cursor].next[slot] = fresh;
                cursor = fresh;
            }
            nodes[cursor].outputs.push(pattern_id);
        }

        // Keep every output list sorted and free of repeated ids.
        nodes.iter_mut().for_each(|node| {
            node.outputs.sort_unstable();
            node.outputs.dedup();
        });

        let mut automaton = Self {
            nodes,
            pattern_lens,
        };
        automaton.build_failure_links();
        Ok(automaton)
    }

    /// Fills in `fail` and `dict_link` for every state with a breadth-first
    /// walk over the trie, so a state's links are computed only after the
    /// links of every shallower state.
    fn build_failure_links(&mut self) {
        self.nodes[0].fail = 0;

        // The root's direct children always fall back to the root.
        let mut pending: VecDeque<usize> = self.nodes[0]
            .next
            .iter()
            .copied()
            .filter(|&child| child != NONE)
            .collect();
        for &child in &pending {
            self.nodes[child].fail = 0;
        }

        while let Some(parent) = pending.pop_front() {
            let parent_fail = self.nodes[parent].fail;
            for byte in 0..256usize {
                let child = self.nodes[parent].next[byte];
                if child == NONE {
                    continue;
                }

                // Follow the parent's failure chain until a state with an edge
                // on `byte` is found; give up at the root.
                let mut probe = parent_fail;
                let fallback = loop {
                    let edge = self.nodes[probe].next[byte];
                    if edge != NONE && edge != child {
                        break edge;
                    }
                    if probe == 0 {
                        break 0;
                    }
                    probe = self.nodes[probe].fail;
                };
                self.nodes[child].fail = fallback;

                // The dictionary link is the fallback itself when it carries
                // outputs, otherwise whatever the fallback links to.
                let link = if self.nodes[fallback].outputs.is_empty() {
                    self.nodes[fallback].dict_link
                } else {
                    fallback
                };
                self.nodes[child].dict_link = link;

                pending.push_back(child);
            }
        }
    }

    /// Returns the state reached from `state` after consuming byte value `c`:
    /// failure links are followed until some state has an edge on `c`; if even
    /// the root has none, the result is the root (0).
    fn goto(&self, mut state: usize, c: usize) -> usize {
        let mut edge = self.nodes[state].next[c];
        while edge == NONE && state != 0 {
            state = self.nodes[state].fail;
            edge = self.nodes[state].next[c];
        }
        if edge == NONE {
            0
        } else {
            edge
        }
    }

    /// Scans `text` once and calls `report` for every pattern occurrence,
    /// including overlapping ones.
    ///
    /// Occurrences are reported as their last byte is consumed. For a given
    /// end offset, the current state's outputs come first, followed by the
    /// outputs of each state along its dictionary-link chain.
    pub fn for_each_match<F: FnMut(Match)>(&self, text: &[u8], mut report: F) {
        let mut state = 0usize;
        for (pos, &byte) in text.iter().enumerate() {
            state = self.goto(state, usize::from(byte));
            let end = pos + 1;

            // The current state followed by its dictionary-link chain.
            let chain = std::iter::successors(Some(state), |&node| {
                let link = self.nodes[node].dict_link;
                if link == NONE {
                    None
                } else {
                    Some(link)
                }
            });
            for node in chain {
                for &pattern_id in &self.nodes[node].outputs {
                    report(Match {
                        pattern_id,
                        start: end - self.pattern_lens[pattern_id],
                        end,
                    });
                }
            }
        }
    }

    /// Collects every occurrence in `text`, sorted by `end` and then by
    /// `pattern_id`.
    pub fn find_iter(&self, text: &[u8]) -> Vec<Match> {
        let mut found: Vec<Match> = Vec::new();
        self.for_each_match(text, |hit| found.push(hit));
        found.sort_unstable_by_key(|hit| (hit.end, hit.pattern_id));
        found
    }

    /// Returns `true` as soon as any pattern occurrence ends in `text`,
    /// without scanning the remainder.
    pub fn is_match(&self, text: &[u8]) -> bool {
        let mut state = 0usize;
        text.iter().any(|&byte| {
            state = self.goto(state, usize::from(byte));
            let node = &self.nodes[state];
            // A hit is either an output on this state or on a state reachable
            // through the dictionary link.
            node.dict_link != NONE || !node.outputs.is_empty()
        })
    }

    /// Number of patterns the automaton was built from.
    pub fn pattern_count(&self) -> usize {
        self.pattern_lens.len()
    }

    /// Number of automaton states, root included.
    pub fn state_count(&self) -> usize {
        self.nodes.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Brute-force oracle: compares every non-empty pattern against every
    /// window of `text` and returns the hits sorted by `end`, then `pattern_id`.
    fn naive_matches(patterns: &[&[u8]], text: &[u8]) -> Vec<Match> {
        let mut found = Vec::new();
        for (pattern_id, &pat) in patterns.iter().enumerate() {
            // Empty patterns are ignored; over-long ones cannot occur in `text`.
            if pat.is_empty() || pat.len() > text.len() {
                continue;
            }
            for (start, window) in text.windows(pat.len()).enumerate() {
                if window == pat {
                    found.push(Match {
                        pattern_id,
                        start,
                        end: start + pat.len(),
                    });
                }
            }
        }
        found.sort_unstable_by_key(|m| (m.end, m.pattern_id));
        found
    }

    /// Draws `len` bytes from `alphabet`, one `rng.next_usize` call per byte.
    fn random_bytes(rng: &mut LcgRng, alphabet: &[u8], len: usize) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(len);
        for _ in 0..len {
            bytes.push(alphabet[rng.next_usize(alphabet.len())]);
        }
        bytes
    }

    /// Views string patterns as byte slices for the brute-force oracle.
    fn byte_views<'a>(patterns: &[&'a str]) -> Vec<&'a [u8]> {
        patterns.iter().map(|&p| p.as_bytes()).collect()
    }

    /// Counts the hits that belong to pattern `id`.
    fn count_for(hits: &[Match], id: usize) -> usize {
        hits.iter().filter(|m| m.pattern_id == id).count()
    }

    #[test]
    fn classic_he_she_his_hers() {
        let patterns = ["he", "she", "his", "hers"];
        let text: &[u8] = b"ushers";
        let ac = AhoCorasick::new(&patterns).expect("non-empty");
        let hits = ac.find_iter(text);

        let hit = |pattern_id, start, end| Match {
            pattern_id,
            start,
            end,
        };
        let expected = vec![hit(0, 2, 4), hit(1, 1, 4), hit(3, 2, 6)];
        assert_eq!(hits, expected);
        assert!(hits.iter().all(|m| m.pattern_id != 2), "his must be absent");

        let pat_bytes = byte_views(&patterns);
        assert_eq!(hits, naive_matches(&pat_bytes, text));
    }

    #[test]
    fn overlapping_matches_all_reported() {
        let patterns = ["aa", "aaa"];
        let text: &[u8] = b"aaaaa";
        let ac = AhoCorasick::new(&patterns).expect("non-empty");
        let hits = ac.find_iter(text);

        let pat_bytes = byte_views(&patterns);
        let oracle = naive_matches(&pat_bytes, text);
        assert_eq!(hits, oracle);

        let aa = count_for(&hits, 0);
        let aaa = count_for(&hits, 1);
        assert_eq!(aa, 4, "every aa occurrence");
        assert_eq!(aaa, 3, "every aaa occurrence");
    }

    #[test]
    fn absent_pattern_no_matches() {
        let text: &[u8] = b"the quick brown fox";

        let ac = AhoCorasick::new(&["xyz", "qqq"]).expect("non-empty");
        let hits = ac.find_iter(text);
        assert!(hits.is_empty(), "no pattern occurs");
        assert!(!ac.is_match(text));

        let ac2 = AhoCorasick::new(&["fox", "zzz"]).expect("non-empty");
        let hits2 = ac2.find_iter(text);
        assert_eq!(hits2.len(), 1);
        assert_eq!(hits2[0].pattern_id, 0);
        assert!(hits2.iter().all(|m| m.pattern_id != 1));
    }

    #[test]
    fn single_character_patterns() {
        let patterns = ["a", "b", "c"];
        let text: &[u8] = b"abcabc";
        let ac = AhoCorasick::new(&patterns).expect("non-empty");
        let hits = ac.find_iter(text);

        let pat_bytes = byte_views(&patterns);
        assert_eq!(hits, naive_matches(&pat_bytes, text));
        for id in 0..3 {
            assert_eq!(count_for(&hits, id), 2);
        }
    }

    #[test]
    fn dictionary_suffix_link_both_reported() {
        let patterns = ["hers", "ers"];
        let text: &[u8] = b"hers";
        let ac = AhoCorasick::new(&patterns).expect("non-empty");
        let hits = ac.find_iter(text);

        let whole = Match {
            pattern_id: 0,
            start: 0,
            end: 4,
        };
        let suffix = Match {
            pattern_id: 1,
            start: 1,
            end: 4,
        };
        assert!(hits.iter().any(|m| *m == whole), "hers reported");
        assert!(
            hits.iter().any(|m| *m == suffix),
            "ers reported via dictionary-suffix link"
        );
        let pat_bytes = byte_views(&patterns);
        assert_eq!(hits, naive_matches(&pat_bytes, text));

        let chain = ["abc", "bc", "c"];
        let chain_text: &[u8] = b"abc";
        let ac2 = AhoCorasick::new(&chain).expect("non-empty");
        let hits2 = ac2.find_iter(chain_text);
        assert_eq!(hits2.len(), 3, "three nested suffixes all reported");
        let chain_bytes = byte_views(&chain);
        assert_eq!(hits2, naive_matches(&chain_bytes, chain_text));
    }

    #[test]
    fn repeated_occurrences_all_found() {
        let ac = AhoCorasick::new(&["ab"]).expect("non-empty");
        let hits = ac.find_iter(b"ababab");
        assert_eq!(hits.len(), 3);
        let starts: Vec<usize> = hits.iter().map(|m| m.start).collect();
        assert_eq!(starts, vec![0, 2, 4]);

        // Identical patterns keep their own ids and are each reported.
        let ac_dup = AhoCorasick::new(&["xy", "xy"]).expect("non-empty");
        let dup_hits = ac_dup.find_iter(b"xyxy");
        assert_eq!(dup_hits.len(), 4);
        assert_eq!(count_for(&dup_hits, 0), 2);
        assert_eq!(count_for(&dup_hits, 1), 2);
    }

    #[test]
    fn random_cross_check_against_naive() {
        let mut rng = LcgRng::new(0xACDC);
        let alphabet = b"abc";
        for _ in 0..300 {
            // Draw order matters for reproducibility: pattern count, then each
            // pattern's length and bytes, then the text length and bytes.
            let num_patterns = 1 + rng.next_usize(6);
            let owned: Vec<Vec<u8>> = (0..num_patterns)
                .map(|_| {
                    let plen = 1 + rng.next_usize(4);
                    random_bytes(&mut rng, alphabet, plen)
                })
                .collect();
            let pat_refs: Vec<&[u8]> = owned.iter().map(Vec::as_slice).collect();

            let text_len = rng.next_usize(20);
            let text = random_bytes(&mut rng, alphabet, text_len);

            let ac = AhoCorasick::new(&pat_refs).expect("patterns non-empty");
            let got = ac.find_iter(&text);
            let oracle = naive_matches(&pat_refs, &text);
            assert_eq!(got, oracle, "mismatch: patterns={pat_refs:?} text={text:?}");
            assert_eq!(ac.is_match(&text), !oracle.is_empty());
        }
    }

    #[test]
    fn empty_pattern_rejected() {
        let patterns: [&str; 2] = ["ok", ""];
        let built = AhoCorasick::new(&patterns);
        assert!(matches!(built, Err(SeqError::EmptyInput)));
    }

    #[test]
    fn empty_pattern_set_never_matches() {
        let patterns: [&str; 0] = [];
        let ac = AhoCorasick::new(&patterns).expect("empty set is valid");
        assert_eq!(ac.pattern_count(), 0);
        assert!(ac.find_iter(b"anything at all").is_empty());
        assert!(!ac.is_match(b"anything"));
    }

    #[test]
    fn empty_text_no_matches() {
        let ac = AhoCorasick::new(&["a", "abc"]).expect("non-empty");
        let nothing: &[u8] = b"";
        assert!(ac.find_iter(nothing).is_empty());
        assert!(!ac.is_match(nothing));
    }
}