Skip to main content

aho_corasick/
lib.rs

//! Aho-Corasick multi-pattern string matching automaton.
//!
//! Builds a trie over the given pattern set, computes failure links with a
//! breadth-first traversal, and flattens both into a dense per-byte transition
//! table. Searching a text of length n then takes O(n + z) time, where z is the
//! number of matches reported; construction takes O(m · 256) time and space,
//! where m is the total pattern length.
//!
//! # Examples
//!
//! ```
//! use aho_corasick::{AhoCorasick, Match};
//!
//! let ac = AhoCorasick::new(&["he", "she", "his", "hers"]);
//! let matches = ac.find_all(b"ahishers");
//! assert_eq!(matches.len(), 4);
//! assert_eq!(matches[0], Match { pattern_id: 2, start: 1, end: 4 }); // "his"
//! ```
16
17use std::collections::VecDeque;
18
// ── node ──────────────────────────────────────────────────────────────────────
20
/// Alphabet size: one transition slot for every possible byte value.
const ALPHA: usize = u8::MAX as usize + 1;
22
/// One automaton state: a trie node together with its failure link and the
/// pattern IDs reported when the search reaches it.
struct Node {
    /// IDs of patterns reported on reaching this state. During construction
    /// this also absorbs the outputs of the failure-link target.
    output: Vec<usize>,
    /// Index of the state to fall back to when no trie edge exists.
    fail: usize,
    /// Trie edges indexed by byte value; `usize::MAX` marks an absent edge.
    children: Vec<usize>,
}
28
29impl Node {
30    fn new() -> Self {
31        Self { children: vec![usize::MAX; ALPHA], fail: 0, output: vec![] }
32    }
33}
34
// ── Match ─────────────────────────────────────────────────────────────────────
36
/// A pattern occurrence found in the text.
///
/// For matches returned by [`AhoCorasick::find_all`], the matched bytes are
/// `text[start..end]` and `end - start` equals the length of pattern
/// `pattern_id`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Match {
    /// Index of the matched pattern (its position in the pattern list given
    /// at construction).
    pub pattern_id: usize,
    /// Start position in the text (inclusive).
    pub start: usize,
    /// End position in the text (exclusive).
    pub end: usize,
}

impl Match {
    /// Returns the matched span `start..end`, suitable for slicing the
    /// searched text.
    pub fn range(&self) -> std::ops::Range<usize> {
        self.start..self.end
    }
}
47
// ── AhoCorasick ───────────────────────────────────────────────────────────────
49
50/// Aho-Corasick automaton for multi-pattern searching.
51pub struct AhoCorasick {
52    nodes: Vec<Node>,
53    patterns: Vec<Vec<u8>>,
54    goto: Vec<[usize; ALPHA]>,
55}
56
57impl AhoCorasick {
58    /// Build the automaton from a list of patterns.
59    pub fn new(patterns: &[&str]) -> Self {
60        let pats: Vec<Vec<u8>> = patterns.iter().map(|p| p.as_bytes().to_vec()).collect();
61        Self::from_bytes(&pats.iter().map(|v| v.as_slice()).collect::<Vec<_>>())
62    }
63
64    /// Build the automaton from byte-string patterns.
65    pub fn from_bytes(patterns: &[&[u8]]) -> Self {
66        let pats: Vec<Vec<u8>> = patterns.iter().map(|p| p.to_vec()).collect();
67        let mut nodes: Vec<Node> = vec![Node::new()];
68
69        // Phase 1: build trie
70        for (pid, pat) in pats.iter().enumerate() {
71            let mut cur = 0usize;
72            for &b in pat {
73                let b = b as usize;
74                if nodes[cur].children[b] == usize::MAX {
75                    nodes[cur].children[b] = nodes.len();
76                    nodes.push(Node::new());
77                }
78                cur = nodes[cur].children[b];
79            }
80            nodes[cur].output.push(pid);
81        }
82
83        // Phase 2: build failure links and goto table with BFS
84        let n = nodes.len();
85        let mut goto = vec![[0usize; ALPHA]; n];
86        let mut queue = VecDeque::new();
87
88        // Root's children: collect first to avoid split borrows.
89        let root_children: Vec<usize> = nodes[0].children.clone();
90        for (b, child) in root_children.into_iter().enumerate() {
91            if child == usize::MAX {
92                goto[0][b] = 0; // self-loop at root
93            } else {
94                goto[0][b] = child;
95                nodes[child].fail = 0;
96                queue.push_back(child);
97            }
98        }
99
100        while let Some(u) = queue.pop_front() {
101            // Merge output of fail link
102            let fail_u = nodes[u].fail;
103            let extra: Vec<usize> = nodes[fail_u].output.clone();
104            nodes[u].output.extend(extra);
105
106            let u_children: Vec<usize> = nodes[u].children.clone();
107            for (b, child) in u_children.into_iter().enumerate() {
108                if child == usize::MAX {
109                    // No child: follow fail link's goto
110                    goto[u][b] = goto[nodes[u].fail][b];
111                } else {
112                    nodes[child].fail = goto[nodes[u].fail][b];
113                    goto[u][b] = child;
114                    queue.push_back(child);
115                }
116            }
117        }
118
119        Self { nodes, patterns: pats, goto }
120    }
121
122    /// Find all pattern occurrences in `text`.
123    pub fn find_all(&self, text: &[u8]) -> Vec<Match> {
124        let mut state = 0usize;
125        let mut matches = Vec::new();
126        for (i, &b) in text.iter().enumerate() {
127            state = self.goto[state][b as usize];
128            for &pid in &self.nodes[state].output {
129                let pat_len = self.patterns[pid].len();
130                matches.push(Match {
131                    pattern_id: pid,
132                    start: i + 1 - pat_len,
133                    end: i + 1,
134                });
135            }
136        }
137        matches
138    }
139
140    /// Returns true if `text` contains any of the patterns.
141    pub fn contains(&self, text: &[u8]) -> bool {
142        let mut state = 0usize;
143        for &b in text {
144            state = self.goto[state][b as usize];
145            if !self.nodes[state].output.is_empty() {
146                return true;
147            }
148        }
149        false
150    }
151
152    /// Count total occurrences of all patterns in `text`.
153    pub fn count(&self, text: &[u8]) -> usize {
154        self.find_all(text).len()
155    }
156
157    /// Returns the number of patterns.
158    pub fn num_patterns(&self) -> usize {
159        self.patterns.len()
160    }
161
162    /// Returns the number of trie nodes (automaton states).
163    pub fn num_states(&self) -> usize {
164        self.nodes.len()
165    }
166}
167
// ── tests ─────────────────────────────────────────────────────────────────────
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects matches as sorted `(start, end, pattern_id)` triples so that
    /// assertions do not depend on emission order.
    fn find_sorted(ac: &AhoCorasick, text: &[u8]) -> Vec<(usize, usize, usize)> {
        let mut v: Vec<_> =
            ac.find_all(text).into_iter().map(|m| (m.start, m.end, m.pattern_id)).collect();
        v.sort();
        v
    }

    // ── basic search ─────────────────────────────────────────────────────────

    #[test]
    fn classic_ahishers() {
        let ac = AhoCorasick::new(&["he", "she", "his", "hers"]);
        let matches = ac.find_all(b"ahishers");
        assert_eq!(matches.len(), 4);
    }

    #[test]
    fn ahishers_positions() {
        let ac = AhoCorasick::new(&["he", "she", "his", "hers"]);
        assert_eq!(
            find_sorted(&ac, b"ahishers"),
            vec![
                (1, 4, 2), // "his"
                (3, 6, 1), // "she"
                (4, 6, 0), // "he"
                (4, 8, 3), // "hers"
            ]
        );
    }

    #[test]
    fn no_match() {
        let ac = AhoCorasick::new(&["xyz", "foo"]);
        assert_eq!(ac.find_all(b"hello world"), vec![]);
    }

    #[test]
    fn single_pattern() {
        let ac = AhoCorasick::new(&["ab"]);
        assert_eq!(find_sorted(&ac, b"ababab"), vec![(0, 2, 0), (2, 4, 0), (4, 6, 0)]);
    }

    #[test]
    fn overlapping_patterns() {
        let ac = AhoCorasick::new(&["aa", "aaa"]);
        // "aa" ends at 2, 3, 4; "aaa" ends at 3, 4.
        assert_eq!(
            find_sorted(&ac, b"aaaa"),
            vec![(0, 2, 0), (0, 3, 1), (1, 3, 0), (1, 4, 1), (2, 4, 0)]
        );
        assert_eq!(ac.count(b"aaaa"), 5);
    }

    #[test]
    fn pattern_is_prefix_of_another() {
        let ac = AhoCorasick::new(&["a", "ab", "abc"]);
        assert_eq!(find_sorted(&ac, b"abc"), vec![(0, 1, 0), (0, 2, 1), (0, 3, 2)]);
    }

    #[test]
    fn pattern_is_suffix_of_another() {
        let ac = AhoCorasick::new(&["abc", "bc", "c"]);
        assert_eq!(find_sorted(&ac, b"abc"), vec![(0, 3, 0), (1, 3, 1), (2, 3, 2)]);
    }

    #[test]
    fn duplicate_patterns_each_reported() {
        let ac = AhoCorasick::new(&["ab", "ab"]);
        assert_eq!(find_sorted(&ac, b"xab"), vec![(1, 3, 0), (1, 3, 1)]);
    }

    // ── empty and edge cases ─────────────────────────────────────────────────

    #[test]
    fn empty_text() {
        let ac = AhoCorasick::new(&["hello"]);
        assert_eq!(ac.find_all(b""), vec![]);
    }

    #[test]
    fn empty_patterns_list() {
        let ac = AhoCorasick::new(&[]);
        assert_eq!(ac.find_all(b"hello"), vec![]);
        assert_eq!(ac.num_patterns(), 0);
    }

    #[test]
    fn single_char_patterns() {
        let ac = AhoCorasick::new(&["a", "b", "c"]);
        assert_eq!(find_sorted(&ac, b"abc"), vec![(0, 1, 0), (1, 2, 1), (2, 3, 2)]);
    }

    #[test]
    fn repeated_pattern() {
        let ac = AhoCorasick::new(&["aa"]);
        assert_eq!(find_sorted(&ac, b"aaaa"), vec![(0, 2, 0), (1, 3, 0), (2, 4, 0)]);
    }

    // ── contains / count ─────────────────────────────────────────────────────

    #[test]
    fn contains_true() {
        let ac = AhoCorasick::new(&["world"]);
        assert!(ac.contains(b"hello world"));
    }

    #[test]
    fn contains_false() {
        let ac = AhoCorasick::new(&["xyz"]);
        assert!(!ac.contains(b"hello world"));
    }

    #[test]
    fn count_multiple() {
        let ac = AhoCorasick::new(&["an", "ban"]);
        // "banana" → "ban"@0, "an"@1, "an"@3
        assert_eq!(ac.count(b"banana"), 3);
    }

    #[test]
    fn count_and_contains_agree_with_find_all() {
        let ac = AhoCorasick::new(&["an", "ban", "nana", "x"]);
        let texts: [&[u8]; 5] = [b"banana", b"", b"xx", b"bnn", b"bananax"];
        for &text in texts.iter() {
            let found = ac.find_all(text);
            assert_eq!(ac.count(text), found.len());
            assert_eq!(ac.contains(text), !found.is_empty());
        }
    }

    // ── metadata ─────────────────────────────────────────────────────────────

    #[test]
    fn num_patterns() {
        let ac = AhoCorasick::new(&["a", "bb", "ccc"]);
        assert_eq!(ac.num_patterns(), 3);
    }

    #[test]
    fn num_states_single() {
        let ac = AhoCorasick::new(&["abc"]);
        assert_eq!(ac.num_states(), 4); // root + 3 nodes
    }

    // ── byte patterns ────────────────────────────────────────────────────────

    #[test]
    fn from_bytes() {
        let ac = AhoCorasick::from_bytes(&[b"foo", b"bar"]);
        assert_eq!(ac.count(b"foobar"), 2);
    }

    #[test]
    fn non_ascii_bytes() {
        let pat: &[u8] = b"\xff\x00";
        let ac = AhoCorasick::from_bytes(&[pat]);
        assert_eq!(find_sorted(&ac, b"\x01\xff\x00\xff"), vec![(1, 3, 0)]);

        // Multi-byte UTF-8: "é" occupies bytes 3..5 of "café".
        let ac = AhoCorasick::new(&["é"]);
        assert_eq!(find_sorted(&ac, "café".as_bytes()), vec![(3, 5, 0)]);
    }

    #[test]
    fn find_all_returns_correct_span() {
        let ac = AhoCorasick::new(&["cat"]);
        let m = ac.find_all(b"concatenate");
        assert_eq!(m, vec![Match { pattern_id: 0, start: 3, end: 6 }]);
        assert_eq!(&b"concatenate"[m[0].start..m[0].end], b"cat");
    }

    #[test]
    fn long_text_single_pattern() {
        let ac = AhoCorasick::new(&["ab"]);
        let text = b"abababababababab";
        assert_eq!(ac.count(text), 8);
    }

    #[test]
    fn pattern_id_ordering() {
        let ac = AhoCorasick::new(&["z", "y", "x"]);
        // Matches are emitted in text order: x→pid2, y→pid1, z→pid0.
        let ids: Vec<usize> = ac.find_all(b"xyz").iter().map(|m| m.pattern_id).collect();
        assert_eq!(ids, vec![2, 1, 0]);
    }
}