1use std::cmp::{max, min};
3use std::collections::{hash_map::Entry, HashMap, VecDeque};
4
/// Prefix tree over symbols of type `C`.
///
/// Nodes are identified by their index into `links`; node `0` is the root,
/// which represents the empty word.
pub struct Trie<C: std::hash::Hash + Eq> {
    /// `links[node]` maps a symbol to the index of the child reached by it.
    links: Vec<HashMap<C, usize>>,
}

impl<C: std::hash::Hash + Eq> Default for Trie<C> {
    /// Builds a trie that contains only the root node.
    fn default() -> Self {
        let root = HashMap::new();
        Self { links: vec![root] }
    }
}

impl<C: std::hash::Hash + Eq> Trie<C> {
    /// Inserts `word`, creating any nodes that are missing along its path.
    ///
    /// Returns the index of the node at which the word ends. New nodes are
    /// numbered in creation order, so a fresh node always receives the
    /// current node count as its index.
    pub fn insert(&mut self, word: impl IntoIterator<Item = C>) -> usize {
        word.into_iter().fold(0, |current, symbol| {
            let fresh = self.links.len();
            match self.links[current].entry(symbol) {
                Entry::Occupied(slot) => *slot.get(),
                Entry::Vacant(slot) => {
                    slot.insert(fresh);
                    self.links.push(HashMap::new());
                    fresh
                }
            }
        })
    }

    /// Walks the trie along `word`.
    ///
    /// Returns the index of the node reached, or `None` as soon as a symbol
    /// has no matching edge. The empty word yields the root, `Some(0)`.
    pub fn get(&self, word: impl IntoIterator<Item = C>) -> Option<usize> {
        word.into_iter()
            .try_fold(0, |current, symbol| self.links[current].get(&symbol).copied())
    }
}
47
/// Single-pattern matcher based on the Knuth-Morris-Pratt failure function.
pub struct Matcher<'a, C: Eq> {
    /// The pattern being searched for.
    pub pattern: &'a [C],
    /// `fail[i]` is the length of the longest proper prefix of
    /// `pattern[..=i]` that is also a suffix of it. Empty when the pattern
    /// is empty.
    pub fail: Vec<usize>,
}

impl<'a, C: Eq> Matcher<'a, C> {
    /// Precomputes the failure function of `pattern`.
    ///
    /// An empty pattern is accepted and produces an empty failure table.
    pub fn new(pattern: &'a [C]) -> Self {
        let mut fail = Vec::with_capacity(pattern.len());
        if !pattern.is_empty() {
            // A single symbol has no non-empty proper prefix.
            fail.push(0);
        }
        let mut len = 0;
        // `skip(1)` instead of `&pattern[1..]` so an empty pattern cannot panic.
        for ch in pattern.iter().skip(1) {
            // Fall back to shorter borders until one can be extended by `ch`.
            while len > 0 && pattern[len] != *ch {
                len = fail[len - 1];
            }
            if pattern[len] == *ch {
                len += 1;
            }
            fail.push(len);
        }
        Self { pattern, fail }
    }

    /// Scans `text` and returns one value per text symbol: the length of the
    /// longest prefix of the pattern that ends at that symbol.
    ///
    /// A value equal to `pattern.len()` marks the end of a full occurrence.
    /// For an empty pattern every value is `0`.
    pub fn kmp_match(&self, text: impl IntoIterator<Item = C>) -> Vec<usize> {
        let mut len = 0;
        text.into_iter()
            .map(|ch| {
                if self.pattern.is_empty() {
                    // No prefix can be matched; also avoids indexing `fail`.
                    return 0;
                }
                if len == self.pattern.len() {
                    // Previous position was a full match: continue from its border.
                    len = self.fail[len - 1];
                }
                while len > 0 && self.pattern[len] != ch {
                    len = self.fail[len - 1];
                }
                if self.pattern[len] == ch {
                    len += 1;
                }
                len
            })
            .collect()
    }
}
115
116pub struct MultiMatcher<C: std::hash::Hash + Eq> {
118 pub trie: Trie<C>,
120 pub pat_id: Vec<Option<usize>>,
122 pub fail: Vec<usize>,
125 pub fast: Vec<usize>,
127}
128
129impl<C: std::hash::Hash + Eq> MultiMatcher<C> {
130 fn next(trie: &Trie<C>, fail: &[usize], mut node: usize, ch: &C) -> usize {
131 loop {
132 if let Some(&child) = trie.links[node].get(ch) {
133 return child;
134 } else if node == 0 {
135 return 0;
136 }
137 node = fail[node];
138 }
139 }
140
141 pub fn new(patterns: impl IntoIterator<Item = impl IntoIterator<Item = C>>) -> Self {
144 let mut trie = Trie::default();
145 let pat_nodes: Vec<usize> = patterns.into_iter().map(|pat| trie.insert(pat)).collect();
146
147 let mut pat_id = vec![None; trie.links.len()];
148 for (i, node) in pat_nodes.into_iter().enumerate() {
149 pat_id[node] = Some(i);
150 }
151
152 let mut fail = vec![0; trie.links.len()];
153 let mut fast = vec![0; trie.links.len()];
154 let mut q: VecDeque<usize> = trie.links[0].values().cloned().collect();
155
156 while let Some(node) = q.pop_front() {
157 for (ch, &child) in &trie.links[node] {
158 let nx = Self::next(&trie, &fail, fail[node], &ch);
159 fail[child] = nx;
160 fast[child] = if pat_id[nx].is_some() { nx } else { fast[nx] };
161 q.push_back(child);
162 }
163 }
164
165 Self {
166 trie,
167 pat_id,
168 fail,
169 fast,
170 }
171 }
172
173 pub fn ac_match(&self, text: impl IntoIterator<Item = C>) -> Vec<usize> {
176 let mut node = 0;
177 text.into_iter()
178 .map(|ch| {
179 node = Self::next(&self.trie, &self.fail, node, &ch);
180 node
181 })
182 .collect()
183 }
184
185 pub fn get_end_pos_and_pat_id(&self, match_nodes: &[usize]) -> Vec<(usize, usize)> {
188 let mut res = vec![];
189 for (text_pos, &(mut node)) in match_nodes.iter().enumerate() {
190 while node != 0 {
191 if let Some(id) = self.pat_id[node] {
192 res.push((text_pos + 1, id));
193 }
194 node = self.fast[node];
195 }
196 }
197 res
198 }
199}
200
/// Suffix array of a byte string, built by prefix doubling, together with
/// the intermediate rank tables used to answer longest-common-prefix queries.
pub struct SuffixArray {
    /// Starting positions of all suffixes, in sorted order of the suffixes.
    pub sfx: Vec<usize>,
    /// `rank[k][i]` is the rank of the suffix starting at `i` when suffixes
    /// are compared by their first `2^k` bytes; `rank[0]` holds the raw byte
    /// values.
    pub rank: Vec<Vec<usize>>,
}

impl SuffixArray {
    /// Stable counting sort of `vals` by `val_to_key[v]`.
    ///
    /// Every key must be strictly less than `max_key`.
    fn counting_sort(
        vals: impl Iterator<Item = usize> + Clone,
        val_to_key: &[usize],
        max_key: usize,
    ) -> Vec<usize> {
        let mut counts = vec![0; max_key];
        for v in vals.clone() {
            counts[val_to_key[v]] += 1;
        }
        // Turn the counts into the starting offset of each key's bucket.
        let mut total = 0;
        for c in counts.iter_mut() {
            total += *c;
            *c = total - *c;
        }
        let mut result = vec![0; total];
        for v in vals {
            let c = &mut counts[val_to_key[v]];
            result[*c] = v;
            *c += 1;
        }
        result
    }

    /// Builds the suffix array of `text`.
    ///
    /// Each round doubles the compared prefix length with one counting sort.
    pub fn new(text: impl IntoIterator<Item = u8>) -> Self {
        let init_rank = text.into_iter().map(usize::from).collect::<Vec<_>>();
        let n = init_rank.len();
        let mut sfx = Self::counting_sort(0..n, &init_rank, 256);
        let mut rank = vec![init_rank];
        for skip in (0..).map(|i| 1 << i).take_while(|&skip| skip < n) {
            let prev_rank = rank.last().expect("rank always holds the initial level");
            let mut cur_rank = prev_rank.clone();

            // Order by second half first: suffixes shorter than `skip` have an
            // empty second half and come first; the rest inherit the previous
            // order shifted by `skip`. The stable sort then orders by first half.
            let pos = (n - skip..n).chain(sfx.into_iter().filter_map(|p| p.checked_sub(skip)));
            sfx = Self::counting_sort(pos, prev_rank, max(n, 256));

            let mut prev = sfx[0];
            cur_rank[prev] = 0;
            for &cur in sfx.iter().skip(1) {
                // Equal rank only when both halves exist and compare equal.
                if max(prev, cur) + skip < n
                    && prev_rank[prev] == prev_rank[cur]
                    && prev_rank[prev + skip] == prev_rank[cur + skip]
                {
                    cur_rank[cur] = cur_rank[prev];
                } else {
                    cur_rank[cur] = cur_rank[prev] + 1;
                }
                prev = cur;
            }
            rank.push(cur_rank);
        }
        Self { sfx, rank }
    }

    /// Length of the longest common prefix of the suffixes starting at `i`
    /// and `j`.
    ///
    /// A position at or past the end of the text is treated as the empty
    /// suffix, so the result is `0`. When `i == j` the result is the full
    /// length of that suffix.
    pub fn longest_common_prefix(&self, mut i: usize, mut j: usize) -> usize {
        let n = self.sfx.len();
        if max(i, j) >= n {
            return 0;
        }
        if i == j {
            // The doubling walk below would overshoot past the end of the
            // text here, because a suffix always ties with itself.
            return n - i;
        }
        let mut len = 0;
        for (k, rank) in self.rank.iter().enumerate().rev() {
            if rank[i] == rank[j] {
                i += 1 << k;
                j += 1 << k;
                len += 1 << k;
                if max(i, j) >= n {
                    break;
                }
            }
        }
        len
    }
}
285
/// Computes the longest palindrome around every centre of `text`
/// (Manacher-style, reusing mirrored results).
///
/// The result has `2 * text.len() - 1` entries: entry `2 * i` is the length
/// of the longest odd-length palindrome centred on `text[i]`, and entry
/// `2 * i + 1` is the length of the longest even-length palindrome centred
/// between `text[i]` and `text[i + 1]`. An empty `text` yields an empty
/// vector.
pub fn palindromes(text: &[impl Eq]) -> Vec<usize> {
    if text.is_empty() {
        // `2 * 0 - 1` would underflow; there are no centres to report.
        return Vec::new();
    }
    // Track the target length explicitly: `Vec::capacity` is only guaranteed
    // to be *at least* what was requested.
    let total = 2 * text.len() - 1;
    let mut pal = Vec::with_capacity(total);
    pal.push(1);
    while pal.len() < total {
        let i = pal.len() - 1;
        let max_len = min(i + 1, total - i);
        // Grow the palindrome at centre `i` one symbol on each side.
        while pal[i] < max_len && text[(i - pal[i] - 1) / 2] == text[(i + pal[i] + 1) / 2] {
            pal[i] += 2;
        }
        if let Some(a) = 1usize.checked_sub(pal[i]) {
            // Palindrome of length 0 or 1: the next centre starts from scratch.
            pal.push(a);
        } else {
            // Copy mirrored values while they stay strictly inside the
            // palindrome at `i`; the first one reaching its border becomes
            // the seed for the next centre to extend.
            for d in 1.. {
                let (a, b) = (pal[i - d], pal[i] - d);
                if a < b {
                    pal.push(a);
                } else {
                    pal.push(b);
                    break;
                }
            }
        }
    }
    pal
}
318
#[cfg(test)]
mod test {
    use super::*;

    /// Node indices follow insertion order, so lookups return fixed ids.
    #[test]
    fn test_trie() {
        let mut trie = Trie::default();
        for word in &["banana", "benefit", "banapple", "ban"] {
            trie.insert(word.bytes());
        }

        let expectations = [
            ("", Some(0)),
            ("b", Some(1)),
            ("banana", Some(6)),
            ("be", Some(7)),
            ("bane", None),
        ];
        for &(word, node) in expectations.iter() {
            assert_eq!(trie.get(word.bytes()), node);
        }
    }

    /// Matched prefix lengths of "ana" at every position of "banana".
    #[test]
    fn test_kmp_matching() {
        let matcher = Matcher::new("ana".as_bytes());

        let matches = matcher.kmp_match("banana".bytes());

        assert_eq!(matches, vec![0, 1, 2, 3, 2, 3]);
    }

    /// Every dictionary word occurrence is reported as (end position, id).
    #[test]
    fn test_ac_matching() {
        let dict = ["banana", "benefit", "banapple", "ban", "fit"];
        let text = "banana bans, apple benefits.";

        let matcher = MultiMatcher::new(dict.iter().map(|s| s.bytes()));
        let visited = matcher.ac_match(text.bytes());
        let found = matcher.get_end_pos_and_pat_id(&visited);

        let expected = vec![(3, 3), (6, 0), (10, 3), (26, 1), (26, 4)];
        assert_eq!(found, expected);
    }

    /// Checks the suffix order, LCP queries, and that the final rank table
    /// is the inverse permutation of the suffix array.
    #[test]
    fn test_suffix_array() {
        let sfx1 = SuffixArray::new("bobocel".bytes());
        let sfx2 = SuffixArray::new("banana".bytes());

        assert_eq!(sfx1.sfx, vec![0, 2, 4, 5, 6, 1, 3]);
        assert_eq!(sfx2.sfx, vec![5, 3, 1, 0, 4, 2]);

        assert_eq!(sfx1.longest_common_prefix(0, 2), 2);
        assert_eq!(sfx2.longest_common_prefix(1, 3), 3);

        for array in [&sfx1, &sfx2].iter() {
            let final_rank = array.rank.last().unwrap();
            for (p, &r) in final_rank.iter().enumerate() {
                assert_eq!(array.sfx[r], p);
            }
        }
    }

    /// Palindrome lengths for all 2n - 1 centres of "banana".
    #[test]
    fn test_palindrome() {
        let pal_len = palindromes("banana".as_bytes());

        assert_eq!(pal_len, vec![1, 0, 1, 0, 3, 0, 5, 0, 3, 0, 1]);
    }
}