1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
use crate::cs::error::{Error, Result};
use std::{
collections::{HashMap, VecDeque},
fmt,
sync::Arc,
};
/// Configuration options for pattern matching behavior.
///
/// The default configuration applies no boundary filtering and reports every match.
// `Default` is derived: `Option<T>` implements `Default` (as `None`) for every `T`, so the
// closure type inside the `Arc` not implementing `Default` does not get in the way.
#[derive(Clone, Default)]
pub struct MatchConfig {
    /// Optional custom boundary checker: returns `true` if the character is considered a
    /// boundary.
    ///
    /// When set, a match is only reported if the character immediately before it and the
    /// character immediately after it are both boundaries; the start and end of the text
    /// count as boundaries. If `None`, no boundary filtering is applied.
    pub boundary_checker: Option<Arc<dyn Fn(char) -> bool + Send + Sync>>,
    /// Only report the longest match at each start position.
    ///
    /// Equally long matches at the same start are resolved in favor of the lowest pattern
    /// index.
    pub longest_match_only: bool,
}

// Manually implement Debug since `Arc<dyn Fn(...)>` doesn't implement Debug.
impl fmt::Debug for MatchConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The closure itself cannot be printed; only report whether one is present.
        let boundary_checker = if self.boundary_checker.is_some() {
            "Some(<fn>)"
        } else {
            "None"
        };
        f.debug_struct("MatchConfig")
            .field("boundary_checker", &boundary_checker)
            .field("longest_match_only", &self.longest_match_only)
            .finish()
    }
}
/// A single occurrence of a pattern in the searched text, as reported by [`AhoCorasick`].
///
/// `start..end` is a half-open range of byte offsets into the searched text.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Match {
    /// Position of the matched pattern within the pattern list the automaton was built from.
    pub pattern_index: usize,
    /// Byte offset of the first byte of the match.
    pub start: usize,
    /// Byte offset one past the last byte of the match (exclusive).
    pub end: usize,
}
/// One state of the automaton's trie.
#[derive(Debug)]
struct TrieNode {
    /// Outgoing edges: maps the next character to the index of the child node.
    children: HashMap<char, usize>,
    /// Index of the state to fall back to on a mismatch. Left `None` until failure links
    /// are built; the root's link is never set.
    failure: Option<usize>,
    /// Indices of every pattern recognized at this node, including those inherited from
    /// the failure-link target.
    output: Vec<usize>,
    /// Number of characters on the path from the root to this node.
    depth: usize,
}

impl TrieNode {
    /// Creates a node at `depth` with no edges, no failure link and no outputs.
    fn new(depth: usize) -> Self {
        TrieNode {
            depth,
            output: Vec::new(),
            failure: None,
            children: HashMap::default(),
        }
    }
}
/// A multi-pattern string searcher built on the Aho-Corasick automaton.
///
/// Patterns are matched character by character (`char`), while reported match positions
/// are byte offsets into the searched text.
#[derive(Debug)]
pub struct AhoCorasick {
    /// Every trie node; nodes refer to one another by index into this vector.
    nodes: Vec<TrieNode>,
    /// The patterns the automaton was built from, kept so match lengths can be recovered.
    patterns: Vec<String>,
    /// Index of the root node within `nodes` (set to 0 at construction).
    root: usize,
    /// Options controlling which matches are reported.
    config: MatchConfig,
}
impl AhoCorasick {
/// Creates a new Aho-Corasick automaton from the given patterns with default
/// configuration.
pub fn new(patterns: Vec<String>) -> Result<Self> {
Self::with_config(patterns, MatchConfig::default())
}
/// Creates a new Aho-Corasick automaton with the specified configuration.
pub fn with_config(patterns: Vec<String>, config: MatchConfig) -> Result<Self> {
// Validate patterns.
if patterns.is_empty() {
return Err(Error::invalid_input("At least one pattern is required"));
}
if patterns.iter().any(|p| p.is_empty()) {
return Err(Error::empty_pattern());
}
let mut ac = Self {
nodes: vec![TrieNode::new(0)],
patterns,
root: 0,
config,
};
// Build trie and failure links.
ac.build_trie()?;
ac.build_failure_links();
Ok(ac)
}
/// Builds the initial trie structure from the patterns.
fn build_trie(&mut self) -> Result<()> {
for (pattern_idx, pattern) in self.patterns.iter().enumerate() {
let mut current = self.root;
// Follow/create path for each character.
for ch in pattern.chars() {
// Instead of using or_insert_with (which causes E0500),
// we explicitly check if the child exists or not.
if let Some(&next) = self.nodes[current].children.get(&ch) {
current = next;
} else {
let new_idx = self.nodes.len();
self.nodes
.push(TrieNode::new(self.nodes[current].depth + 1));
self.nodes[current].children.insert(ch, new_idx);
current = new_idx;
}
}
// Store the index of this pattern in the output list.
self.nodes[current].output.push(pattern_idx);
}
Ok(())
}
/// Builds failure links using a breadth-first traversal of the trie.
fn build_failure_links(&mut self) {
let mut queue = VecDeque::new();
// Initialize root's children.
let root_children: Vec<_> = self.nodes[self.root].children.values().copied().collect();
for child in root_children {
self.nodes[child].failure = Some(self.root);
queue.push_back(child);
}
// Process remaining nodes.
while let Some(current) = queue.pop_front() {
let current_failure = self.nodes[current].failure.unwrap_or(self.root);
let children: Vec<(char, usize)> = self.nodes[current]
.children
.iter()
.map(|(ch, &node)| (*ch, node))
.collect();
for (ch, child) in children {
queue.push_back(child);
// Find the failure link by following parent's failure.
let mut fail_state = current_failure;
let mut next_failure = self.root;
while fail_state != self.root {
if let Some(&next) = self.nodes[fail_state].children.get(&ch) {
next_failure = next;
break;
}
fail_state = self.nodes[fail_state].failure.unwrap_or(self.root);
}
// Check root's children if needed.
if fail_state == self.root {
if let Some(&next) = self.nodes[self.root].children.get(&ch) {
next_failure = next;
}
}
// Set failure link.
self.nodes[child].failure = Some(next_failure);
// Merge outputs from the failure link.
let output_clone = self.nodes[next_failure].output.clone();
self.nodes[child].output.extend_from_slice(&output_clone);
}
}
}
/// Finds the next trie state given the current state and an input character.
fn find_next_state(&self, mut current: usize, ch: char) -> usize {
while !self.nodes[current].children.contains_key(&ch) && current != self.root {
current = self.nodes[current].failure.unwrap_or(self.root);
}
self.nodes[current]
.children
.get(&ch)
.copied()
.unwrap_or(self.root)
}
/// Helper function to check if a match is at a word boundary.
///
/// If `boundary_checker` is `None`, we do no special check (always return true).
fn is_word_boundary(&self, text: &str, start: usize, end: usize) -> bool {
// If no boundary checker is provided, don't filter anything out.
let Some(check_fn) = &self.config.boundary_checker else {
return true;
};
let is_boundary_char = |c: char| check_fn(c);
let before_is_boundary = start == 0
|| text[..start]
.chars()
.next_back()
.is_none_or(is_boundary_char);
let after_is_boundary =
end >= text.len() || text[end..].chars().next().is_none_or(is_boundary_char);
before_is_boundary && after_is_boundary
}
/// Finds all occurrences of any pattern in the given text.
pub fn find_all<'a>(&'a self, text: &'a str) -> impl Iterator<Item = Match> + 'a {
let mut matches = Vec::new();
let mut current = self.root;
// Convert text to (byte_offset, char).
let chars: Vec<(usize, char)> = text.char_indices().collect();
// If longest_match_only is set, we collect matches per position.
let mut matches_at_pos = if self.config.longest_match_only {
vec![Vec::new(); chars.len()]
} else {
Vec::new()
};
for (pos, (byte_offset, ch)) in chars.iter().enumerate() {
current = self.find_next_state(current, *ch);
// Check outputs for the current node.
for &pattern_idx in &self.nodes[current].output {
let pat_len = self.patterns[pattern_idx].chars().count();
if pos + 1 >= pat_len {
let start_pos = pos + 1 - pat_len;
let start_byte = chars[start_pos].0;
let end_byte = byte_offset + ch.len_utf8();
// Check word boundaries if needed.
if self.is_word_boundary(text, start_byte, end_byte) {
let m = Match {
pattern_index: pattern_idx,
start: start_byte,
end: end_byte,
};
if self.config.longest_match_only {
matches_at_pos[start_pos].push(m);
} else {
matches.push(m);
}
}
}
}
}
// If we only want the longest match per start position.
if self.config.longest_match_only {
for pos_matches in matches_at_pos.into_iter().filter(|v| !v.is_empty()) {
// Sort by (longest match first, then pattern index).
let mut pos_matches = pos_matches;
pos_matches
.sort_by_key(|m| (-(m.end as isize - m.start as isize), m.pattern_index));
matches.push(pos_matches[0].clone());
}
}
matches.into_iter()
}
/// Finds the first occurrence of any pattern in the given text.
pub fn find_first(&self, text: &str) -> Option<Match> {
self.find_all(text).next()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned pattern list from string literals.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| (*s).to_string()).collect()
    }

    /// Shorthand for an expected `Match`.
    fn m(pattern_index: usize, start: usize, end: usize) -> Match {
        Match {
            pattern_index,
            start,
            end,
        }
    }

    /// A configuration that reports only the longest match per start position.
    fn longest_only() -> MatchConfig {
        MatchConfig {
            longest_match_only: true,
            ..MatchConfig::default()
        }
    }

    #[test]
    fn test_basic_single_pattern() {
        let ac = AhoCorasick::new(strings(&["test"])).unwrap();
        let matches: Vec<_> = ac.find_all("this is a test case").collect();
        assert_eq!(matches, vec![m(0, 10, 14)]);
    }

    #[test]
    fn test_multiple_patterns() {
        let patterns = strings(&["he", "she", "his", "hers"]);
        let text = "she sells seashells";
        let ac = AhoCorasick::new(patterns.clone()).unwrap();
        let matches: Vec<_> = ac.find_all(text).collect();
        // "she" and its suffix "he" occur at the start of the text and again inside
        // "seashells"; nothing matches in "sells".
        assert_eq!(
            matches,
            vec![m(1, 0, 3), m(0, 1, 3), m(1, 13, 16), m(0, 14, 16)]
        );
        // With a boundary checker, only the free-standing "she" at the start survives.
        let config = MatchConfig {
            boundary_checker: Some(Arc::new(|c: char| !c.is_alphanumeric())),
            ..MatchConfig::default()
        };
        let ac = AhoCorasick::with_config(patterns, config).unwrap();
        let matches: Vec<_> = ac.find_all(text).collect();
        assert_eq!(matches, vec![m(1, 0, 3)]);
    }

    #[test]
    fn test_boundary_with_multibyte_neighbors() {
        let config = MatchConfig {
            boundary_checker: Some(Arc::new(|c: char| !c.is_alphanumeric())),
            ..MatchConfig::default()
        };
        let ac = AhoCorasick::with_config(strings(&["ab"]), config).unwrap();
        // 'é' is alphanumeric, so a directly preceding 'é' is not a boundary.
        assert_eq!(ac.find_all("éab").count(), 0);
        // 'é' occupies bytes 0..2, so "ab" after the space spans bytes 3..5.
        let matches: Vec<_> = ac.find_all("é ab").collect();
        assert_eq!(matches, vec![m(0, 3, 5)]);
    }

    #[test]
    fn test_overlapping_patterns() {
        let patterns = strings(&["ant", "ant colony", "colony"]);
        // Default config => all matches.
        let ac = AhoCorasick::new(patterns.clone()).unwrap();
        let matches: Vec<_> = ac.find_all("ant colony").collect();
        assert_eq!(matches, vec![m(0, 0, 3), m(1, 0, 10), m(2, 4, 10)]);
        // Longest match only => "ant colony" (start 0) and "colony" (start 4).
        let ac = AhoCorasick::with_config(patterns, longest_only()).unwrap();
        let matches: Vec<_> = ac.find_all("ant colony").collect();
        assert_eq!(matches, vec![m(1, 0, 10), m(2, 4, 10)]);
    }

    #[test]
    fn test_unicode() {
        let patterns = strings(&["🦀", "🦀🔧", "🔧"]);
        // Default config => all matches; offsets are bytes and each emoji is 4 bytes.
        let ac = AhoCorasick::new(patterns.clone()).unwrap();
        let matches: Vec<_> = ac.find_all("🦀🔧").collect();
        assert_eq!(matches, vec![m(0, 0, 4), m(1, 0, 8), m(2, 4, 8)]);
        // Longest match only => "🦀🔧" and "🔧".
        let ac = AhoCorasick::with_config(patterns, longest_only()).unwrap();
        let matches: Vec<_> = ac.find_all("🦀🔧").collect();
        assert_eq!(matches, vec![m(1, 0, 8), m(2, 4, 8)]);
    }

    #[test]
    fn test_rejects_invalid_patterns() {
        assert!(AhoCorasick::new(Vec::new()).is_err());
        assert!(AhoCorasick::new(strings(&["a", ""])).is_err());
    }

    #[test]
    fn test_find_first() {
        let patterns = strings(&["abcd", "bc"]);
        let ac = AhoCorasick::new(patterns.clone()).unwrap();
        // By default the first match is the one that ends first, not the leftmost start.
        assert_eq!(ac.find_first("abcd"), Some(m(1, 1, 3)));
        assert_eq!(ac.find_first("xyz"), None);
        // In longest-only mode it is the longest match at the leftmost start.
        let ac = AhoCorasick::with_config(patterns, longest_only()).unwrap();
        assert_eq!(ac.find_first("abcd"), Some(m(0, 0, 4)));
    }

    #[test]
    fn test_duplicate_patterns() {
        let patterns = strings(&["ab", "ab"]);
        let ac = AhoCorasick::new(patterns.clone()).unwrap();
        let matches: Vec<_> = ac.find_all("ab").collect();
        assert_eq!(matches, vec![m(0, 0, 2), m(1, 0, 2)]);
        // Equally long matches at one start resolve to the lowest pattern index.
        let ac = AhoCorasick::with_config(patterns, longest_only()).unwrap();
        let matches: Vec<_> = ac.find_all("ab").collect();
        assert_eq!(matches, vec![m(0, 0, 2)]);
    }
}