use crate::error::{CoreError, CoreResult, ErrorContext};
use std::collections::{HashMap, VecDeque};
/// One state (trie node) of the Aho-Corasick automaton.
#[derive(Debug, Clone)]
struct AcState {
/// Outgoing trie edges: input byte -> index of the child state.
transitions: HashMap<u8, usize>,
/// Index of the state the matcher falls back to when this state has no
/// transition for the current input byte.
fail: usize,
/// Indices of the patterns reported whenever the matcher is in this state.
output: Vec<usize>,
}
impl AcState {
fn new() -> Self {
AcState {
transitions: HashMap::new(),
fail: 0,
output: Vec::new(),
}
}
}
/// Multi-pattern byte matcher backed by a trie with failure links.
///
/// Matches are reported as `(start, end, pattern_index)` tuples where
/// `start..end` is the half-open byte range of the match in the searched text.
#[derive(Debug, Clone)]
pub struct AhoCorasick {
/// All automaton states; index 0 is the state every search starts from.
states: Vec<AcState>,
/// Byte length of each pattern, indexed by pattern index.
pattern_lengths: Vec<usize>,
/// Number of patterns the automaton was built from.
n_patterns: usize,
}
impl AhoCorasick {
    /// Builds an automaton from string patterns.
    ///
    /// Patterns are matched byte-wise on their UTF-8 encoding; this is a thin
    /// wrapper around [`AhoCorasick::from_bytes`].
    pub fn new(patterns: &[&str]) -> Self {
        let byte_patterns: Vec<&[u8]> = patterns.iter().map(|s| s.as_bytes()).collect();
        Self::from_bytes(&byte_patterns)
    }

    /// Builds an automaton from raw byte patterns.
    ///
    /// The index of a pattern in `patterns` is the `pattern_index` reported by
    /// the search methods. Duplicate patterns are all reported.
    ///
    /// Empty patterns keep their index (so they still count towards
    /// [`AhoCorasick::n_patterns`] and need a slot in `replace_all`), but they
    /// never produce a match. Attaching them to the root state would make them
    /// fire only at an arbitrary subset of positions.
    pub fn from_bytes(patterns: &[&[u8]]) -> Self {
        let n_patterns = patterns.len();
        let pattern_lengths: Vec<usize> = patterns.iter().map(|p| p.len()).collect();
        let mut states: Vec<AcState> = vec![AcState::new()];

        // Phase 1: insert every non-empty pattern into the trie.
        for (pi, pattern) in patterns.iter().enumerate() {
            if pattern.is_empty() {
                continue;
            }
            let mut cur = 0usize;
            for &byte in pattern.iter() {
                cur = match states[cur].transitions.get(&byte).copied() {
                    Some(next) => next,
                    None => {
                        let next = states.len();
                        states.push(AcState::new());
                        states[cur].transitions.insert(byte, next);
                        next
                    }
                };
            }
            states[cur].output.push(pi);
        }

        // Phase 2: breadth-first computation of failure links. Children of
        // the root keep the default failure link (0), so they only need to be
        // queued.
        let mut queue: VecDeque<usize> = states[0].transitions.values().copied().collect();
        while let Some(r) = queue.pop_front() {
            // Copy the edges out so `states` can be mutated while we walk them.
            let edges: Vec<(u8, usize)> = states[r]
                .transitions
                .iter()
                .map(|(&b, &s)| (b, s))
                .collect();
            for (byte, s) in edges {
                queue.push_back(s);

                // Walk r's failure chain until some state has an edge on
                // `byte`. Every state on that chain is strictly shallower
                // than r, so the edge found can never lead back to `s`.
                let mut probe = states[r].fail;
                let fail = loop {
                    if let Some(&target) = states[probe].transitions.get(&byte) {
                        break target;
                    }
                    if probe == 0 {
                        break 0;
                    }
                    probe = states[probe].fail;
                };
                states[s].fail = fail;

                // The failure target is shallower and therefore already
                // final: inherit the patterns that end there as suffixes.
                let inherited = states[fail].output.clone();
                states[s].output.extend(inherited);
            }
        }

        AhoCorasick {
            states,
            pattern_lengths,
            n_patterns,
        }
    }

    /// Advances the automaton from `state` on `byte`, following failure links
    /// until a transition exists or the root is reached.
    fn step(&self, mut state: usize, byte: u8) -> usize {
        loop {
            if let Some(&next) = self.states[state].transitions.get(&byte) {
                return next;
            }
            if state == 0 {
                return 0;
            }
            state = self.states[state].fail;
        }
    }

    /// Finds all (possibly overlapping) matches in `text`.
    ///
    /// See [`AhoCorasick::find_all_bytes`] for the result format.
    pub fn find_all(&self, text: &str) -> Vec<(usize, usize, usize)> {
        self.find_all_bytes(text.as_bytes())
    }

    /// Finds all (possibly overlapping) matches in `text`.
    ///
    /// Returns `(start, end, pattern_index)` tuples with `start..end` the
    /// half-open byte range of the match. Results are ordered by `end`; for
    /// matches sharing the same `end`, the longest pattern comes first.
    pub fn find_all_bytes(&self, text: &[u8]) -> Vec<(usize, usize, usize)> {
        let mut results = Vec::new();
        let mut state = 0usize;
        for (i, &byte) in text.iter().enumerate() {
            state = self.step(state, byte);
            let end = i + 1;
            // A reported pattern is a suffix of the bytes consumed so far,
            // so `end - len` cannot underflow.
            results.extend(
                self.states[state]
                    .output
                    .iter()
                    .map(|&pi| (end - self.pattern_lengths[pi], end, pi)),
            );
        }
        results
    }

    /// Returns the first match in the order produced by
    /// [`AhoCorasick::find_all_bytes`] (i.e. the earliest-ending match), or
    /// `None` if nothing matches. Stops scanning at that match.
    pub fn find_first(&self, text: &str) -> Option<(usize, usize, usize)> {
        let mut state = 0usize;
        for (i, &byte) in text.as_bytes().iter().enumerate() {
            state = self.step(state, byte);
            if let Some(&pi) = self.states[state].output.first() {
                let end = i + 1;
                return Some((end - self.pattern_lengths[pi], end, pi));
            }
        }
        None
    }

    /// Counts all (possibly overlapping) matches in `text` without building
    /// the list of matches.
    pub fn count_matches(&self, text: &str) -> usize {
        let mut state = 0usize;
        let mut total = 0usize;
        for byte in text.bytes() {
            state = self.step(state, byte);
            total += self.states[state].output.len();
        }
        total
    }

    /// Returns `true` as soon as any pattern is found in `text`.
    pub fn is_match(&self, text: &str) -> bool {
        let mut state = 0usize;
        for byte in text.bytes() {
            state = self.step(state, byte);
            if !self.states[state].output.is_empty() {
                return true;
            }
        }
        false
    }

    /// Replaces matches in `text`; pattern `i` is replaced by `replacements[i]`.
    ///
    /// Matches are taken in the order reported by
    /// [`AhoCorasick::find_all_bytes`]; a match that starts before the end of
    /// the previously replaced match is skipped, so replacements never overlap.
    ///
    /// # Errors
    ///
    /// Returns `CoreError::InvalidArgument` when `replacements.len()` differs
    /// from the number of patterns, or when a stretch of text between two
    /// replaced matches is not valid UTF-8 (which requires a byte pattern that
    /// matched part of a multi-byte character).
    pub fn replace_all(&self, text: &str, replacements: &[&str]) -> CoreResult<String> {
        if replacements.len() != self.n_patterns {
            return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
                "replace_all: expected {} replacements, got {}",
                self.n_patterns,
                replacements.len()
            ))));
        }
        let bytes = text.as_bytes();
        let mut result = String::with_capacity(text.len());
        let mut pos = 0usize;
        for (start, end, pi) in self.find_all_bytes(bytes) {
            if start < pos {
                // Overlaps the match that was just replaced.
                continue;
            }
            let gap = std::str::from_utf8(&bytes[pos..start]).map_err(|e| {
                CoreError::InvalidArgument(ErrorContext::new(format!(
                    "replace_all: invalid UTF-8 in source text: {e}"
                )))
            })?;
            result.push_str(gap);
            result.push_str(replacements[pi]);
            pos = end;
        }
        let tail = std::str::from_utf8(&bytes[pos..]).map_err(|e| {
            CoreError::InvalidArgument(ErrorContext::new(format!(
                "replace_all: invalid UTF-8 in source text tail: {e}"
            )))
        })?;
        result.push_str(tail);
        Ok(result)
    }

    /// Number of patterns the automaton was built from (empty ones included).
    #[inline]
    pub fn n_patterns(&self) -> usize {
        self.n_patterns
    }

    /// Number of automaton states, including the root.
    #[inline]
    pub fn n_states(&self) -> usize {
        self.states.len()
    }
}
/// Finds every (possibly overlapping) occurrence of `pattern` in `text` using
/// the Boyer-Moore-Horspool bad-character heuristic.
///
/// Returns the start offsets of all matches in increasing order. An empty
/// pattern, or a pattern longer than the text, yields no matches.
pub fn bm_horspool_search(text: &[u8], pattern: &[u8]) -> Vec<usize> {
    let n = text.len();
    let m = pattern.len();
    if m == 0 || m > n {
        return Vec::new();
    }

    // shift[b]: how far the window may slide when the text byte aligned with
    // the pattern's last position is `b`. Bytes absent from pattern[..m-1]
    // allow a full jump of `m`; later occurrences overwrite earlier ones so
    // the rightmost occurrence wins.
    let mut shift = [m; 256];
    for (i, &b) in pattern[..m - 1].iter().enumerate() {
        shift[usize::from(b)] = m - 1 - i;
    }

    let mut results = Vec::new();
    // `end` is the text index aligned with the last pattern byte.
    let mut end = m - 1;
    while end < n {
        let start = end + 1 - m;
        if &text[start..=end] == pattern {
            results.push(start);
        }
        // Every shift is in 1..=m and both `end` and `m` are bounded by the
        // slice length, so this addition cannot overflow.
        end += shift[usize::from(text[end])];
    }
    results
}
/// Computes the KMP failure (prefix) table for `pattern`.
///
/// Entry `i` is the length of the longest proper prefix of `pattern[..=i]`
/// that is also a suffix of it. The result has one entry per pattern byte and
/// is empty for an empty pattern.
pub fn kmp_failure_function(pattern: &[u8]) -> Vec<usize> {
    let mut table = vec![0usize; pattern.len()];
    // Length of the border carried over from the previous position.
    let mut border = 0usize;
    for (pos, &byte) in pattern.iter().enumerate().skip(1) {
        // Shrink the border until it can be extended by `byte` (or is empty).
        while border > 0 && pattern[border] != byte {
            border = table[border - 1];
        }
        if pattern[border] == byte {
            border += 1;
        }
        table[pos] = border;
    }
    table
}
/// Finds every (possibly overlapping) occurrence of `pattern` in `text` with
/// the Knuth-Morris-Pratt algorithm.
///
/// Returns the start offsets of all matches in increasing order. An empty
/// pattern matches at every offset `0..=text.len()`; a pattern longer than
/// the text never matches.
pub fn kmp_search(text: &[u8], pattern: &[u8]) -> Vec<usize> {
    let pat_len = pattern.len();
    let text_len = text.len();
    if pat_len == 0 {
        return (0..=text_len).collect();
    }
    if pat_len > text_len {
        return Vec::new();
    }

    let fallback = kmp_failure_function(pattern);
    let mut hits = Vec::new();
    // Number of pattern bytes currently matched against the text.
    let mut matched = 0usize;
    for (idx, &byte) in text.iter().enumerate() {
        while matched > 0 && pattern[matched] != byte {
            matched = fallback[matched - 1];
        }
        if pattern[matched] == byte {
            matched += 1;
        }
        if matched == pat_len {
            hits.push(idx + 1 - pat_len);
            // Fall back to the longest border so overlapping matches are found.
            matched = fallback[pat_len - 1];
        }
    }
    hits
}
#[cfg(test)]
mod tests {
    //! Unit tests for the Aho-Corasick automaton and the single-pattern
    //! Horspool / KMP searches.
    use super::*;

    // ---- Aho-Corasick ----------------------------------------------------

    #[test]
    fn test_ac_basic_find_all() {
        let ac = AhoCorasick::new(&["he", "she", "his", "hers"]);
        let hits = ac.find_all("ushers");
        let patterns: Vec<usize> = hits.iter().map(|&(_, _, p)| p).collect();
        // "she", "he" and "hers" occur in "ushers"; "his" does not.
        assert_eq!(hits.len(), 3);
        assert!(patterns.contains(&0));
        assert!(patterns.contains(&1));
        assert!(patterns.contains(&3));
        assert!(!patterns.contains(&2));
    }

    #[test]
    fn test_ac_no_match() {
        let ac = AhoCorasick::new(&["xyz", "abc"]);
        assert!(ac.find_all("hello world").is_empty());
    }

    #[test]
    fn test_ac_overlapping() {
        let ac = AhoCorasick::new(&["aa"]);
        let hits = ac.find_all("aaa");
        assert_eq!(hits, vec![(0, 2, 0), (1, 3, 0)]);
    }

    #[test]
    fn test_ac_single_char_patterns() {
        let ac = AhoCorasick::new(&["a", "b"]);
        let hits = ac.find_all("abab");
        assert_eq!(hits, vec![(0, 1, 0), (1, 2, 1), (2, 3, 0), (3, 4, 1)]);
    }

    #[test]
    fn test_ac_is_match_true() {
        let ac = AhoCorasick::new(&["hello", "world"]);
        assert!(ac.is_match("say hello!"));
    }

    #[test]
    fn test_ac_is_match_false() {
        let ac = AhoCorasick::new(&["hello", "world"]);
        assert!(!ac.is_match("greetings"));
    }

    #[test]
    fn test_ac_count_matches() {
        let ac = AhoCorasick::new(&["ab"]);
        assert_eq!(ac.count_matches("ababab"), 3);
    }

    #[test]
    fn test_ac_find_first() {
        let ac = AhoCorasick::new(&["cd", "ab"]);
        // "ab" (pattern 1) ends before "cd" does, so it is reported first.
        assert_eq!(ac.find_first("xabcd"), Some((1, 3, 1)));
    }

    #[test]
    fn test_ac_replace_all() {
        let ac = AhoCorasick::new(&["cat", "dog"]);
        let out = ac
            .replace_all("I have a cat and a dog", &["kitty", "puppy"])
            .expect("replace_all should succeed");
        assert_eq!(out, "I have a kitty and a puppy");
    }

    #[test]
    fn test_ac_replace_all_wrong_count() {
        let ac = AhoCorasick::new(&["cat", "dog"]);
        let result = ac.replace_all("text", &["only_one"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_ac_empty_text() {
        let ac = AhoCorasick::new(&["abc"]);
        assert!(ac.find_all("").is_empty());
        assert!(!ac.is_match(""));
    }

    #[test]
    fn test_ac_empty_patterns_slice() {
        let ac = AhoCorasick::new(&[]);
        assert_eq!(ac.n_patterns(), 0);
        assert!(ac.find_all("any text").is_empty());
    }

    #[test]
    fn test_ac_pattern_longer_than_text() {
        let ac = AhoCorasick::new(&["verylongpattern"]);
        assert!(ac.find_all("short").is_empty());
    }

    #[test]
    fn test_ac_positions_correct() {
        let ac = AhoCorasick::new(&["bc"]);
        let hits = ac.find_all("abcabc");
        assert_eq!(hits, vec![(1, 3, 0), (4, 6, 0)]);
    }

    #[test]
    fn test_ac_binary_patterns() {
        let patterns: &[&[u8]] = &[b"\x00\x01", b"\xFF\xFE"];
        let ac = AhoCorasick::from_bytes(patterns);
        let text: &[u8] = &[0x00, 0x01, 0x02, 0xFF, 0xFE];
        let hits = ac.find_all_bytes(text);
        assert_eq!(hits, vec![(0, 2, 0), (3, 5, 1)]);
    }

    // ---- Boyer-Moore-Horspool --------------------------------------------

    #[test]
    fn test_bmh_basic() {
        let pos = bm_horspool_search(b"AABAAABAAABAA", b"AAB");
        assert_eq!(pos, vec![0, 4, 8]);
    }

    #[test]
    fn test_bmh_no_match() {
        let pos = bm_horspool_search(b"hello world", b"xyz");
        assert!(pos.is_empty());
    }

    #[test]
    fn test_bmh_single_char() {
        let pos = bm_horspool_search(b"aaa", b"a");
        assert_eq!(pos, vec![0, 1, 2]);
    }

    #[test]
    fn test_bmh_pattern_longer_than_text() {
        let pos = bm_horspool_search(b"ab", b"abc");
        assert!(pos.is_empty());
    }

    // ---- Knuth-Morris-Pratt ----------------------------------------------

    #[test]
    fn test_kmp_basic() {
        let pos = kmp_search(b"aababcab", b"ab");
        assert_eq!(pos, vec![1, 3, 6]);
    }

    #[test]
    fn test_kmp_no_match() {
        let pos = kmp_search(b"hello", b"xyz");
        assert!(pos.is_empty());
    }

    #[test]
    fn test_kmp_overlapping() {
        let pos = kmp_search(b"aaa", b"aa");
        assert_eq!(pos, vec![0, 1]);
    }

    #[test]
    fn test_kmp_empty_pattern() {
        let pos = kmp_search(b"abc", b"");
        assert_eq!(pos, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_kmp_failure_function_basic() {
        let f = kmp_failure_function(b"abcabd");
        assert_eq!(f, vec![0, 0, 0, 1, 2, 0]);
    }

    #[test]
    fn test_kmp_failure_aaaa() {
        let f = kmp_failure_function(b"aaaa");
        assert_eq!(f, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_kmp_full_text_match() {
        let pos = kmp_search(b"abcabc", b"abcabc");
        assert_eq!(pos, vec![0]);
    }
}