1use std::collections::VecDeque;
18
/// Number of distinct byte values; every state has one transition slot per value.
const ALPHA: usize = u8::MAX as usize + 1;
22
23struct Node {
24 children: Vec<usize>,
25 fail: usize,
26 output: Vec<usize>,
27}
28
29impl Node {
30 fn new() -> Self {
31 Self { children: vec![usize::MAX; ALPHA], fail: 0, output: vec![] }
32 }
33}
34
/// One occurrence of a pattern inside the searched text.
///
/// The type is a plain triple of indices, so it is `Copy` and `Hash`: matches
/// can be passed by value and collected into hash-based sets or maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Match {
    /// Position of the matched pattern in the slice given to the constructor.
    pub pattern_id: usize,
    /// Byte offset in the text where the occurrence begins (inclusive).
    pub start: usize,
    /// Byte offset in the text just past the occurrence (exclusive).
    pub end: usize,
}
47
48pub struct AhoCorasick {
52 nodes: Vec<Node>,
53 patterns: Vec<Vec<u8>>,
54 goto: Vec<[usize; ALPHA]>,
55}
56
57impl AhoCorasick {
58 pub fn new(patterns: &[&str]) -> Self {
60 let pats: Vec<Vec<u8>> = patterns.iter().map(|p| p.as_bytes().to_vec()).collect();
61 Self::from_bytes(&pats.iter().map(|v| v.as_slice()).collect::<Vec<_>>())
62 }
63
64 pub fn from_bytes(patterns: &[&[u8]]) -> Self {
66 let pats: Vec<Vec<u8>> = patterns.iter().map(|p| p.to_vec()).collect();
67 let mut nodes: Vec<Node> = vec![Node::new()];
68
69 for (pid, pat) in pats.iter().enumerate() {
71 let mut cur = 0usize;
72 for &b in pat {
73 let b = b as usize;
74 if nodes[cur].children[b] == usize::MAX {
75 nodes[cur].children[b] = nodes.len();
76 nodes.push(Node::new());
77 }
78 cur = nodes[cur].children[b];
79 }
80 nodes[cur].output.push(pid);
81 }
82
83 let n = nodes.len();
85 let mut goto = vec![[0usize; ALPHA]; n];
86 let mut queue = VecDeque::new();
87
88 let root_children: Vec<usize> = nodes[0].children.clone();
90 for (b, child) in root_children.into_iter().enumerate() {
91 if child == usize::MAX {
92 goto[0][b] = 0; } else {
94 goto[0][b] = child;
95 nodes[child].fail = 0;
96 queue.push_back(child);
97 }
98 }
99
100 while let Some(u) = queue.pop_front() {
101 let fail_u = nodes[u].fail;
103 let extra: Vec<usize> = nodes[fail_u].output.clone();
104 nodes[u].output.extend(extra);
105
106 let u_children: Vec<usize> = nodes[u].children.clone();
107 for (b, child) in u_children.into_iter().enumerate() {
108 if child == usize::MAX {
109 goto[u][b] = goto[nodes[u].fail][b];
111 } else {
112 nodes[child].fail = goto[nodes[u].fail][b];
113 goto[u][b] = child;
114 queue.push_back(child);
115 }
116 }
117 }
118
119 Self { nodes, patterns: pats, goto }
120 }
121
122 pub fn find_all(&self, text: &[u8]) -> Vec<Match> {
124 let mut state = 0usize;
125 let mut matches = Vec::new();
126 for (i, &b) in text.iter().enumerate() {
127 state = self.goto[state][b as usize];
128 for &pid in &self.nodes[state].output {
129 let pat_len = self.patterns[pid].len();
130 matches.push(Match {
131 pattern_id: pid,
132 start: i + 1 - pat_len,
133 end: i + 1,
134 });
135 }
136 }
137 matches
138 }
139
140 pub fn contains(&self, text: &[u8]) -> bool {
142 let mut state = 0usize;
143 for &b in text {
144 state = self.goto[state][b as usize];
145 if !self.nodes[state].output.is_empty() {
146 return true;
147 }
148 }
149 false
150 }
151
152 pub fn count(&self, text: &[u8]) -> usize {
154 self.find_all(text).len()
155 }
156
157 pub fn num_patterns(&self) -> usize {
159 self.patterns.len()
160 }
161
162 pub fn num_states(&self) -> usize {
164 self.nodes.len()
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `find_all` and returns the matches as `(start, end, pattern_id)`
    /// tuples in sorted order, so assertions do not depend on report order.
    fn find_sorted(ac: &AhoCorasick, text: &[u8]) -> Vec<(usize, usize, usize)> {
        let mut spans: Vec<_> = ac
            .find_all(text)
            .into_iter()
            .map(|m| (m.start, m.end, m.pattern_id))
            .collect();
        spans.sort_unstable();
        spans
    }

    #[test]
    fn classic_ahishers() {
        let ac = AhoCorasick::new(&["he", "she", "his", "hers"]);
        let matches = ac.find_all(b"ahishers");
        assert_eq!(matches.len(), 4);
    }

    #[test]
    fn ahishers_positions() {
        let ac = AhoCorasick::new(&["he", "she", "his", "hers"]);
        // his @1..4, she @3..6, he @4..6, hers @4..8 — and nothing else.
        assert_eq!(
            find_sorted(&ac, b"ahishers"),
            vec![(1, 4, 2), (3, 6, 1), (4, 6, 0), (4, 8, 3)]
        );
    }

    #[test]
    fn no_match() {
        let ac = AhoCorasick::new(&["xyz", "foo"]);
        assert!(ac.find_all(b"hello world").is_empty());
    }

    #[test]
    fn single_pattern() {
        let ac = AhoCorasick::new(&["ab"]);
        let starts: Vec<usize> = ac.find_all(b"ababab").iter().map(|m| m.start).collect();
        assert_eq!(starts, vec![0, 2, 4]);
    }

    #[test]
    fn overlapping_patterns() {
        let ac = AhoCorasick::new(&["aa", "aaa"]);
        // "aa" occurs three times and "aaa" twice in "aaaa".
        assert_eq!(
            find_sorted(&ac, b"aaaa"),
            vec![(0, 2, 0), (0, 3, 1), (1, 3, 0), (1, 4, 1), (2, 4, 0)]
        );
        assert_eq!(ac.count(b"aaaa"), 5);
    }

    #[test]
    fn pattern_is_prefix_of_another() {
        let ac = AhoCorasick::new(&["a", "ab", "abc"]);
        assert_eq!(
            find_sorted(&ac, b"abc"),
            vec![(0, 1, 0), (0, 2, 1), (0, 3, 2)]
        );
    }

    #[test]
    fn pattern_is_suffix_of_another() {
        let ac = AhoCorasick::new(&["abc", "bc", "c"]);
        assert_eq!(
            find_sorted(&ac, b"abc"),
            vec![(0, 3, 0), (1, 3, 1), (2, 3, 2)]
        );
    }

    #[test]
    fn empty_text() {
        let ac = AhoCorasick::new(&["hello"]);
        assert!(ac.find_all(b"").is_empty());
    }

    #[test]
    fn empty_patterns_list() {
        let ac = AhoCorasick::new(&[]);
        assert!(ac.find_all(b"hello").is_empty());
        assert_eq!(ac.num_patterns(), 0);
        // Only the root state exists.
        assert_eq!(ac.num_states(), 1);
    }

    #[test]
    fn single_char_patterns() {
        let ac = AhoCorasick::new(&["a", "b", "c"]);
        assert_eq!(
            find_sorted(&ac, b"abc"),
            vec![(0, 1, 0), (1, 2, 1), (2, 3, 2)]
        );
    }

    #[test]
    fn repeated_pattern() {
        let ac = AhoCorasick::new(&["aa"]);
        assert_eq!(
            find_sorted(&ac, b"aaaa"),
            vec![(0, 2, 0), (1, 3, 0), (2, 4, 0)]
        );
    }

    #[test]
    fn contains_true() {
        let ac = AhoCorasick::new(&["world"]);
        assert!(ac.contains(b"hello world"));
    }

    #[test]
    fn contains_false() {
        let ac = AhoCorasick::new(&["xyz"]);
        assert!(!ac.contains(b"hello world"));
    }

    #[test]
    fn count_multiple() {
        let ac = AhoCorasick::new(&["an", "ban"]);
        // ban @0..3, an @1..3, an @3..5
        assert_eq!(ac.count(b"banana"), 3);
    }

    #[test]
    fn num_patterns() {
        let ac = AhoCorasick::new(&["a", "bb", "ccc"]);
        assert_eq!(ac.num_patterns(), 3);
    }

    #[test]
    fn num_states_single() {
        let ac = AhoCorasick::new(&["abc"]);
        // Root plus one state per byte of the pattern.
        assert_eq!(ac.num_states(), 4);
    }

    #[test]
    fn from_bytes() {
        let ac = AhoCorasick::from_bytes(&[&b"foo"[..], &b"bar"[..]]);
        assert_eq!(ac.count(b"foobar"), 2);
    }

    #[test]
    fn find_all_returns_correct_span() {
        let ac = AhoCorasick::new(&["cat"]);
        let text = b"concatenate";
        let m = ac.find_all(text);
        assert_eq!(m.len(), 1);
        let first = &m[0];
        assert_eq!((first.start, first.end), (3, 6));
        assert_eq!(&text[first.start..first.end], b"cat");
    }

    #[test]
    fn long_text_single_pattern() {
        let ac = AhoCorasick::new(&["ab"]);
        let text = b"abababababababab";
        assert_eq!(ac.count(text), 8);
    }

    #[test]
    fn pattern_id_ordering() {
        let ac = AhoCorasick::new(&["z", "y", "x"]);
        let m = ac.find_all(b"xyz");
        let ids: Vec<usize> = m.iter().map(|m| m.pattern_id).collect();
        // Matches come back in text order, so ids follow the text, not the pattern list.
        assert_eq!(ids, vec![2, 1, 0]);
    }
}