Skip to main content

oxicuda_seq/matching/
aho_corasick.rs

1//! Aho–Corasick multi-pattern string matching.
2//!
3//! Reference: Alfred V. Aho & Margaret J. Corasick, *"Efficient string
4//! matching: an aid to bibliographic search"*, Communications of the ACM
5//! 18(6), 1975, pp. 333–340.
6//!
7//! # Idea
8//!
9//! Aho–Corasick generalises the Knuth–Morris–Pratt single-pattern automaton to
10//! a *set* of patterns. It first builds the **goto** trie of all patterns, then
11//! augments it with two functions computed by a breadth-first sweep:
12//!
13//! * the **failure (suffix) link** `fail[v]` points to the node spelling the
14//!   longest proper suffix of `v`'s string that is itself a trie node — exactly
15//!   the state to fall back to when the current character cannot extend the
16//!   match, and
17//! * the **output function** which, for every node, lists the pattern ids whose
18//!   strings end at that node, *including* those reachable by following the
19//!   chain of failure links (the **dictionary-suffix** links).
20//!
//! Scanning a text of length `n` then advances the automaton once per input
//! byte — following failure links whenever the current state has no edge for
//! that byte — and emits every occurrence of every pattern. The whole scan runs
//! in `O(n + z)` time where `z` is the number of reported matches, after an
//! `O(Σ |pᵢ|)` construction.
25//!
26//! # Output convention
27//!
28//! [`AhoCorasick::find_iter`] reports each occurrence as a [`Match`] carrying
29//! the matched `pattern_id`, the **end** index `end` (one past the last matched
30//! byte, i.e. the standard exclusive bound), and the derived `start` index.
31//! Because the dictionary-suffix links are followed, **overlapping** matches are
32//! all reported — if both `he` and `she` end at the same text position, both
33//! appear. Matches are emitted in increasing order of `end`; ties (several
34//! patterns ending at the same position) are emitted in ascending `pattern_id`.
35//!
//! Patterns are matched over raw bytes (`&[u8]`); for ASCII this coincides with
//! character matching. The same pattern may be supplied more than once, in
//! which case each copy keeps its own id and is reported separately. An empty
//! pattern is rejected at construction time (it would match at every position
//! and has no well-defined occurrence semantics here).
40
41use crate::error::{SeqError, SeqResult};
42use std::collections::VecDeque;
43
/// Sentinel meaning "no node", used for absent goto edges and for the end of a
/// dictionary-suffix chain.
///
/// `usize::MAX` can never be a valid index into a `Vec`, so it cannot be
/// confused with the root (index `0`). Indexing the node table with it by
/// accident trips the slice bounds check — a panic in every build profile —
/// rather than silently aliasing a real node.
const NONE: usize = usize::MAX;
48
/// A single pattern occurrence found while scanning a text.
///
/// The occurrence covers `text[start..end]`: `start` is inclusive, `end` is
/// exclusive, and `end - start` is the byte length of pattern `pattern_id`.
///
/// The derived ordering compares `pattern_id` first, then `start`, then `end`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct Match {
    /// Position of the matched pattern in the `patterns` slice the automaton
    /// was built from.
    pub pattern_id: usize,
    /// Byte offset in the scanned text at which the occurrence begins.
    pub start: usize,
    /// Byte offset in the scanned text one past the occurrence's last byte.
    pub end: usize,
}
62
/// A single trie node of the Aho–Corasick automaton.
///
/// Nodes refer to each other by index into the automaton's node table; index
/// `0` is the root and `NONE` stands for "no node".
#[derive(Debug, Clone)]
struct Node {
    /// Child transitions of the *goto* trie, indexed by byte. `NONE` marks the
    /// absence of an explicit trie edge. Missing edges are never filled in:
    /// at query time the automaton falls back along failure links instead of
    /// writing to this table.
    next: [usize; 256],
    /// Failure link: the node spelling the longest proper suffix of this
    /// node's string that is itself a trie node. The root's failure link is
    /// the root.
    fail: usize,
    /// Pattern ids whose string ends exactly at this node (not following
    /// failure links). Stored sorted and deduplicated.
    outputs: Vec<usize>,
    /// Head of the dictionary-suffix chain: the nearest node reachable by one
    /// or more failure links that is itself the end of some pattern, or
    /// `NONE` if there is no such node.
    dict_link: usize,
}
79
80impl Node {
81    fn new() -> Self {
82        Self {
83            next: [NONE; 256],
84            fail: 0,
85            outputs: Vec::new(),
86            dict_link: NONE,
87        }
88    }
89}
90
91/// A compiled Aho–Corasick automaton over a fixed set of byte patterns.
92///
93/// Build it once with [`AhoCorasick::new`] and reuse it for many texts. The
94/// automaton stores, for each pattern, its length, so that occurrences can be
95/// reported with both `start` and `end` offsets.
96///
97/// # Examples
98///
99/// ```
100/// use oxicuda_seq::matching::AhoCorasick;
101///
102/// let ac = AhoCorasick::new(&["he", "she", "his", "hers"]).expect("non-empty");
103/// let hits: Vec<_> = ac
104///     .find_iter(b"ushers")
105///     .iter()
106///     .map(|m| (m.pattern_id, m.start, m.end))
107///     .collect();
108/// // "she" ends at 3, "he" ends at 3, "hers" ends at 6.
109/// assert!(hits.contains(&(1, 1, 4))); // she
110/// assert!(hits.contains(&(0, 2, 4))); // he
111/// assert!(hits.contains(&(3, 2, 6))); // hers
112/// ```
113#[derive(Debug, Clone)]
114pub struct AhoCorasick {
115    nodes: Vec<Node>,
116    /// Length in bytes of each pattern, indexed by `pattern_id`.
117    pattern_lens: Vec<usize>,
118}
119
120impl AhoCorasick {
121    /// Build the automaton from a slice of patterns.
122    ///
123    /// Patterns may be anything convertible to bytes (`&str`, `String`,
124    /// `&[u8]`, …) via [`AsRef<[u8]>`]. Duplicate patterns are permitted and
125    /// keep distinct ids. An empty pattern slice yields an automaton that never
126    /// matches; an *empty individual pattern* is rejected with
127    /// [`SeqError::EmptyInput`].
128    pub fn new<P: AsRef<[u8]>>(patterns: &[P]) -> SeqResult<Self> {
129        let mut nodes = vec![Node::new()];
130        let mut pattern_lens = Vec::with_capacity(patterns.len());
131
132        // --- Phase 1: build the goto trie. ---
133        for (pattern_id, pattern) in patterns.iter().enumerate() {
134            let bytes = pattern.as_ref();
135            if bytes.is_empty() {
136                return Err(SeqError::EmptyInput);
137            }
138            pattern_lens.push(bytes.len());
139
140            let mut state = 0usize;
141            for &byte in bytes {
142                let idx = usize::from(byte);
143                let next = nodes[state].next[idx];
144                state = if next == NONE {
145                    let new_state = nodes.len();
146                    nodes.push(Node::new());
147                    nodes[state].next[idx] = new_state;
148                    new_state
149                } else {
150                    next
151                };
152            }
153            nodes[state].outputs.push(pattern_id);
154        }
155
156        // A pattern can be supplied twice; keep each node's output list tidy.
157        for node in &mut nodes {
158            node.outputs.sort_unstable();
159            node.outputs.dedup();
160        }
161
162        let mut automaton = Self {
163            nodes,
164            pattern_lens,
165        };
166        automaton.build_failure_links();
167        Ok(automaton)
168    }
169
170    /// Compute failure links and dictionary-suffix links by BFS over the trie.
171    ///
172    /// Depth-1 nodes fail to the root. For a node `v` reached from `u` on byte
173    /// `c`, its failure target is `goto(fail[u], c)`, computed using the
174    /// already-finalised links of the (shallower) BFS frontier. The
175    /// dictionary-suffix link of `v` is `fail[v]` if that node is itself an
176    /// output, else the dictionary-suffix link of `fail[v]` — a classic
177    /// path-compressed chain so that reporting all suffix-matches at a node is
178    /// `O(#matches)` rather than `O(depth)`.
179    fn build_failure_links(&mut self) {
180        let mut queue: VecDeque<usize> = VecDeque::new();
181
182        // Root's children fail to the root itself.
183        self.nodes[0].fail = 0;
184        for c in 0..256usize {
185            let child = self.nodes[0].next[c];
186            if child != NONE {
187                self.nodes[child].fail = 0;
188                queue.push_back(child);
189            }
190        }
191
192        while let Some(u) = queue.pop_front() {
193            // Snapshot the failure target of `u`; it is already final because
194            // `u` was dequeued, hence shallower than its children.
195            let u_fail = self.nodes[u].fail;
196
197            for c in 0..256usize {
198                let child = self.nodes[u].next[c];
199                if child == NONE {
200                    continue;
201                }
202
203                // Failure link of `child`: walk failure links from `u`'s
204                // failure node until a goto edge on `c` exists, defaulting to
205                // the root.
206                let mut f = u_fail;
207                loop {
208                    let edge = self.nodes[f].next[c];
209                    if edge != NONE && edge != child {
210                        self.nodes[child].fail = edge;
211                        break;
212                    }
213                    if f == 0 {
214                        self.nodes[child].fail = 0;
215                        break;
216                    }
217                    f = self.nodes[f].fail;
218                }
219
220                // Dictionary-suffix link: nearest failure-ancestor that ends a
221                // pattern, with path compression.
222                let cf = self.nodes[child].fail;
223                self.nodes[child].dict_link = if !self.nodes[cf].outputs.is_empty() {
224                    cf
225                } else {
226                    self.nodes[cf].dict_link
227                };
228
229                queue.push_back(child);
230            }
231        }
232    }
233
234    /// Follow the goto function of the *automaton* (not merely the trie):
235    /// from `state` on byte `c`, take the explicit edge if present, otherwise
236    /// fall back along failure links until an edge exists or the root is
237    /// reached. Never returns `NONE`.
238    fn goto(&self, mut state: usize, c: usize) -> usize {
239        loop {
240            let edge = self.nodes[state].next[c];
241            if edge != NONE {
242                return edge;
243            }
244            if state == 0 {
245                return 0;
246            }
247            state = self.nodes[state].fail;
248        }
249    }
250
251    /// Scan `text`, invoking `report` for every occurrence of every pattern.
252    ///
253    /// This is the streaming core used by [`find_iter`](Self::find_iter): it
254    /// allocates nothing and lets the caller decide what to do with each
255    /// [`Match`]. The closure is called once per occurrence; at a given text
256    /// position the node's own outputs are reported before those reached via
257    /// the dictionary-suffix chain.
258    pub fn for_each_match<F: FnMut(Match)>(&self, text: &[u8], mut report: F) {
259        let mut state = 0usize;
260        for (pos, &byte) in text.iter().enumerate() {
261            state = self.goto(state, usize::from(byte));
262
263            // Emit outputs of the current node and every dictionary-suffix
264            // ancestor. `end` is one past the current byte.
265            let end = pos + 1;
266            let mut node = state;
267            while node != NONE {
268                for &pattern_id in &self.nodes[node].outputs {
269                    let len = self.pattern_lens[pattern_id];
270                    report(Match {
271                        pattern_id,
272                        start: end - len,
273                        end,
274                    });
275                }
276                node = self.nodes[node].dict_link;
277            }
278        }
279    }
280
281    /// Collect every occurrence of every pattern in `text`.
282    ///
283    /// The returned vector is sorted by `(end, pattern_id)`: occurrences ending
284    /// earlier come first, and several patterns ending at the same position are
285    /// ordered by ascending id. Overlapping matches are all included.
286    pub fn find_iter(&self, text: &[u8]) -> Vec<Match> {
287        // Group by end position so that, for a fixed `end`, the node-local
288        // ordering (own outputs, then dictionary-suffix links) is replaced by a
289        // deterministic ascending-`pattern_id` order independent of trie shape.
290        let mut matches: Vec<Match> = Vec::new();
291        self.for_each_match(text, |m| matches.push(m));
292        matches.sort_unstable_by(|a, b| {
293            a.end
294                .cmp(&b.end)
295                .then_with(|| a.pattern_id.cmp(&b.pattern_id))
296        });
297        matches
298    }
299
300    /// Return `true` if *any* pattern occurs in `text`.
301    ///
302    /// Short-circuits at the first match, so it is cheaper than materialising
303    /// [`find_iter`](Self::find_iter) when only presence is needed.
304    pub fn is_match(&self, text: &[u8]) -> bool {
305        let mut state = 0usize;
306        for &byte in text {
307            state = self.goto(state, usize::from(byte));
308            if !self.nodes[state].outputs.is_empty() || self.nodes[state].dict_link != NONE {
309                return true;
310            }
311        }
312        false
313    }
314
315    /// Number of patterns compiled into the automaton.
316    pub fn pattern_count(&self) -> usize {
317        self.pattern_lens.len()
318    }
319
320    /// Number of states (trie nodes) in the automaton, including the root.
321    pub fn state_count(&self) -> usize {
322        self.nodes.len()
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Naive `O(n · Σ|pᵢ|)` multi-substring search used as the cross-check
    /// oracle: for each pattern, slide it over every text position. The
    /// result is sorted by `(end, pattern_id)`, matching `find_iter`.
    fn naive_matches(patterns: &[&[u8]], text: &[u8]) -> Vec<Match> {
        let mut out: Vec<Match> = Vec::new();
        for (pattern_id, pat) in patterns.iter().enumerate() {
            // `windows` panics on a zero width; empty patterns never match.
            if pat.is_empty() {
                continue;
            }
            // A pattern longer than the text simply yields no windows.
            out.extend(
                text.windows(pat.len())
                    .enumerate()
                    .filter(|(_, window)| window == pat)
                    .map(|(start, _)| Match {
                        pattern_id,
                        start,
                        end: start + pat.len(),
                    }),
            );
        }
        out.sort_unstable_by_key(|m| (m.end, m.pattern_id));
        out
    }

    /// View a list of string patterns as byte slices for the oracle.
    fn byte_patterns<'a>(patterns: &[&'a str]) -> Vec<&'a [u8]> {
        patterns.iter().map(|p| p.as_bytes()).collect()
    }

    /// Draw `len` bytes uniformly from `alphabet` using the deterministic RNG.
    fn random_bytes(rng: &mut LcgRng, alphabet: &[u8], len: usize) -> Vec<u8> {
        (0..len)
            .map(|_| alphabet[rng.next_usize(alphabet.len())])
            .collect()
    }

    /// (a) The textbook {he, she, his, hers} over "ushers".
    ///
    /// Verifies both the exact positions *and* which pattern fired. With the
    /// exclusive `end` convention, `she` and `he` both end at 4 (they
    /// overlap) and `hers` ends at 6.
    #[test]
    fn classic_he_she_his_hers() {
        let patterns = ["he", "she", "his", "hers"];
        let ac = AhoCorasick::new(&patterns).expect("non-empty");
        let hits = ac.find_iter(b"ushers");

        // Expected occurrences (pattern_id, start, end), 0=he 1=she 2=his 3=hers.
        // `find_iter` sorts by (end, pattern_id); both `he` and `she` end at 4,
        // so `he` (id 0) precedes `she` (id 1) despite `she` starting earlier.
        let expected = vec![
            Match {
                pattern_id: 0,
                start: 2,
                end: 4,
            }, // he
            Match {
                pattern_id: 1,
                start: 1,
                end: 4,
            }, // she
            Match {
                pattern_id: 3,
                start: 2,
                end: 6,
            }, // hers
        ];
        assert_eq!(hits, expected);

        // "his" must NOT appear in "ushers".
        assert!(hits.iter().all(|m| m.pattern_id != 2), "his must be absent");

        // Cross-check against the naive oracle.
        assert_eq!(hits, naive_matches(&byte_patterns(&patterns), b"ushers"));
    }

    /// (b) Overlapping matches are *all* reported, not just the leftmost-longest.
    #[test]
    fn overlapping_matches_all_reported() {
        // "aa" and "aaa" over "aaaaa": every starting position of each.
        let patterns = ["aa", "aaa"];
        let ac = AhoCorasick::new(&patterns).expect("non-empty");
        let hits = ac.find_iter(b"aaaaa");

        let oracle = naive_matches(&byte_patterns(&patterns), b"aaaaa");
        assert_eq!(hits, oracle);

        // "aa" occurs at starts 0,1,2,3 (4 times); "aaa" at starts 0,1,2 (3).
        let aa = hits.iter().filter(|m| m.pattern_id == 0).count();
        let aaa = hits.iter().filter(|m| m.pattern_id == 1).count();
        assert_eq!(aa, 4, "every aa occurrence");
        assert_eq!(aaa, 3, "every aaa occurrence");
    }

    /// (c) A pattern absent from the text yields no matches for it.
    #[test]
    fn absent_pattern_no_matches() {
        let ac = AhoCorasick::new(&["xyz", "qqq"]).expect("non-empty");
        let hits = ac.find_iter(b"the quick brown fox");
        assert!(hits.is_empty(), "no pattern occurs");
        assert!(!ac.is_match(b"the quick brown fox"));

        // One present, one absent.
        let ac2 = AhoCorasick::new(&["fox", "zzz"]).expect("non-empty");
        let hits2 = ac2.find_iter(b"the quick brown fox");
        assert_eq!(hits2.len(), 1);
        assert_eq!(hits2[0].pattern_id, 0);
        assert!(hits2.iter().all(|m| m.pattern_id != 1));
    }

    /// (d) Single-character patterns.
    #[test]
    fn single_character_patterns() {
        let patterns = ["a", "b", "c"];
        let ac = AhoCorasick::new(&patterns).expect("non-empty");
        let hits = ac.find_iter(b"abcabc");

        assert_eq!(hits, naive_matches(&byte_patterns(&patterns), b"abcabc"));
        // Each of a/b/c appears twice.
        for id in 0..3 {
            assert_eq!(hits.iter().filter(|m| m.pattern_id == id).count(), 2);
        }
    }

    /// (e) A pattern that is a suffix of another forces a dictionary-suffix
    /// link; both must be reported where the longer one ends.
    #[test]
    fn dictionary_suffix_link_both_reported() {
        // "ers" is a suffix of "hers". Scanning "hers", at end=4 both "hers"
        // and "ers" complete and both must fire.
        let patterns = ["hers", "ers"];
        let ac = AhoCorasick::new(&patterns).expect("non-empty");
        let hits = ac.find_iter(b"hers");

        assert!(
            hits.iter()
                .any(|m| m.pattern_id == 0 && m.start == 0 && m.end == 4),
            "hers reported"
        );
        assert!(
            hits.iter()
                .any(|m| m.pattern_id == 1 && m.start == 1 && m.end == 4),
            "ers reported via dictionary-suffix link"
        );

        assert_eq!(hits, naive_matches(&byte_patterns(&patterns), b"hers"));

        // A deeper chain: "c" ⊂ "bc" ⊂ "abc" all end at the same position.
        let chain = ["abc", "bc", "c"];
        let ac2 = AhoCorasick::new(&chain).expect("non-empty");
        let hits2 = ac2.find_iter(b"abc");
        assert_eq!(hits2.len(), 3, "three nested suffixes all reported");
        assert_eq!(hits2, naive_matches(&byte_patterns(&chain), b"abc"));
    }

    /// (f) Repeated occurrences of the same pattern are all found.
    #[test]
    fn repeated_occurrences_all_found() {
        let ac = AhoCorasick::new(&["ab"]).expect("non-empty");
        let hits = ac.find_iter(b"ababab");
        assert_eq!(hits.len(), 3);
        let starts: Vec<usize> = hits.iter().map(|m| m.start).collect();
        assert_eq!(starts, vec![0, 2, 4]);

        // Same pattern supplied twice keeps both ids and both fire.
        let ac_dup = AhoCorasick::new(&["xy", "xy"]).expect("non-empty");
        let dup_hits = ac_dup.find_iter(b"xyxy");
        // 2 positions × 2 ids = 4 matches.
        assert_eq!(dup_hits.len(), 4);
        assert_eq!(dup_hits.iter().filter(|m| m.pattern_id == 0).count(), 2);
        assert_eq!(dup_hits.iter().filter(|m| m.pattern_id == 1).count(), 2);
    }

    /// (g) Randomised cross-check against the naive multi-substring search.
    #[test]
    fn random_cross_check_against_naive() {
        let mut rng = LcgRng::new(0xACDC);
        let alphabet = b"abc";

        for _ in 0..300 {
            // 1..=6 patterns, each length 1..=4.
            let num_patterns = 1 + rng.next_usize(6);
            let mut owned: Vec<Vec<u8>> = Vec::with_capacity(num_patterns);
            for _ in 0..num_patterns {
                let plen = 1 + rng.next_usize(4);
                owned.push(random_bytes(&mut rng, alphabet, plen));
            }
            let pat_refs: Vec<&[u8]> = owned.iter().map(|v| v.as_slice()).collect();

            let text_len = rng.next_usize(20);
            let text = random_bytes(&mut rng, alphabet, text_len);

            let ac = AhoCorasick::new(&pat_refs).expect("patterns non-empty");
            let got = ac.find_iter(&text);
            let oracle = naive_matches(&pat_refs, &text);
            assert_eq!(got, oracle, "mismatch: patterns={pat_refs:?} text={text:?}");

            // `is_match` must agree with whether the oracle found anything.
            assert_eq!(ac.is_match(&text), !oracle.is_empty());
        }
    }

    /// Empty individual patterns are rejected.
    #[test]
    fn empty_pattern_rejected() {
        let patterns: [&str; 2] = ["ok", ""];
        assert!(matches!(
            AhoCorasick::new(&patterns),
            Err(SeqError::EmptyInput)
        ));
    }

    /// An empty pattern set builds and never matches.
    #[test]
    fn empty_pattern_set_never_matches() {
        let patterns: [&str; 0] = [];
        let ac = AhoCorasick::new(&patterns).expect("empty set is valid");
        assert_eq!(ac.pattern_count(), 0);
        assert!(ac.find_iter(b"anything at all").is_empty());
        assert!(!ac.is_match(b"anything"));
    }

    /// Empty text yields no matches regardless of the patterns.
    #[test]
    fn empty_text_no_matches() {
        let ac = AhoCorasick::new(&["a", "abc"]).expect("non-empty");
        assert!(ac.find_iter(b"").is_empty());
        assert!(!ac.is_match(b""));
    }
}