contest_algorithms/
string_proc.rs

1//! String processing algorithms.
2use std::cmp::{max, min};
3use std::collections::{hash_map::Entry, HashMap, VecDeque};
4
/// Prefix trie, easily augmentable by adding more fields and/or methods.
///
/// Nodes are identified by their index into `links`; node 0 is the root and
/// corresponds to the empty word. `links[v]` maps each outgoing symbol of node
/// `v` to the index of the child node reached by that symbol.
pub struct Trie<C: std::hash::Hash + Eq> {
    links: Vec<HashMap<C, usize>>,
}

impl<C: std::hash::Hash + Eq> Default for Trie<C> {
    /// Creates a trie holding only the root node (index 0).
    fn default() -> Self {
        let root = HashMap::new();
        Trie { links: vec![root] }
    }
}

impl<C: std::hash::Hash + Eq> Trie<C> {
    /// Inserts `word`, creating any nodes missing along its path, and returns
    /// the index of the node where the word ends.
    ///
    /// Newly created nodes receive consecutive indices in creation order.
    /// Inserting the empty word returns 0 (the root).
    pub fn insert(&mut self, word: impl IntoIterator<Item = C>) -> usize {
        let links = &mut self.links;
        word.into_iter().fold(0, |cur, symbol| {
            // Index the next node would get if it has to be created.
            let fresh = links.len();
            match links[cur].entry(symbol) {
                Entry::Occupied(slot) => *slot.get(),
                Entry::Vacant(slot) => {
                    slot.insert(fresh);
                    links.push(HashMap::new());
                    fresh
                }
            }
        })
    }

    /// Walks `word` from the root and returns the index of the node it ends at,
    /// or `None` as soon as a symbol has no outgoing edge.
    pub fn get(&self, word: impl IntoIterator<Item = C>) -> Option<usize> {
        word.into_iter()
            .try_fold(0, |cur, symbol| self.links[cur].get(&symbol).copied())
    }
}
47
/// Single-pattern matching with the Knuth-Morris-Pratt algorithm
pub struct Matcher<'a, C: Eq> {
    /// The string pattern to search for.
    pub pattern: &'a [C],
    /// KMP match failure automaton. fail[i] is the length of the longest
    /// proper prefix-suffix of pattern[0..=i].
    pub fail: Vec<usize>,
}

impl<'a, C: Eq> Matcher<'a, C> {
    /// Advances a matched-prefix length `len` by one symbol `ch`.
    ///
    /// Falls back along the failure table until `ch` can extend the current
    /// border (or the border is empty), then grows it by one on a match.
    /// Requires `len < pattern.len()`.
    fn step(pattern: &[C], fail: &[usize], mut len: usize, ch: &C) -> usize {
        while len > 0 && pattern[len] != *ch {
            len = fail[len - 1];
        }
        if pattern[len] == *ch {
            len + 1
        } else {
            len
        }
    }

    /// Precomputes the automaton that allows linear-time string matching.
    ///
    /// # Example
    ///
    /// ```
    /// use contest_algorithms::string_proc::Matcher;
    /// let byte_string: &[u8] = b"hello";
    /// let utf8_string: &str = "hello";
    /// let vec_char: Vec<char> = utf8_string.chars().collect();
    ///
    /// let match_from_byte_literal = Matcher::new(byte_string);
    /// let match_from_utf8 = Matcher::new(utf8_string.as_bytes());
    /// let match_from_chars = Matcher::new(&vec_char);
    ///
    /// let vec_int = vec![4, -3, 1];
    /// let match_from_ints = Matcher::new(&vec_int);
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if pattern is empty.
    pub fn new(pattern: &'a [C]) -> Self {
        let mut fail = Vec::with_capacity(pattern.len());
        fail.push(0);
        // Slicing from 1 is what rejects (panics on) an empty pattern.
        let rest = &pattern[1..];
        let mut border = 0;
        for ch in rest {
            border = Self::step(pattern, &fail, border, ch);
            fail.push(border);
        }
        Self { pattern, fail }
    }

    /// KMP algorithm, sets @return[i] = length of longest prefix of pattern
    /// matching a suffix of text[0..=i].
    ///
    /// A value equal to `pattern.len()` marks a complete match ending at i.
    pub fn kmp_match(&self, text: impl IntoIterator<Item = C>) -> Vec<usize> {
        let full = self.pattern.len();
        let mut state = 0;
        let mut matched = Vec::new();
        for ch in text {
            // After a complete match, fall back to its longest border so that
            // overlapping occurrences are still found.
            if state == full {
                state = self.fail[state - 1];
            }
            state = Self::step(self.pattern, &self.fail, state, &ch);
            matched.push(state);
        }
        matched
    }
}
115
116/// Multi-pattern matching with the Aho-Corasick algorithm
117pub struct MultiMatcher<C: std::hash::Hash + Eq> {
118    /// A prefix trie storing the string patterns to search for.
119    pub trie: Trie<C>,
120    /// Stores which completed pattern string each node corresponds to.
121    pub pat_id: Vec<Option<usize>>,
122    /// Aho-Corasick failure automaton. fail[i] is the node corresponding to the
123    /// longest prefix-suffix of the node corresponding to i.
124    pub fail: Vec<usize>,
125    /// Shortcut to the next match along the failure chain, or to the root.
126    pub fast: Vec<usize>,
127}
128
129impl<C: std::hash::Hash + Eq> MultiMatcher<C> {
130    fn next(trie: &Trie<C>, fail: &[usize], mut node: usize, ch: &C) -> usize {
131        loop {
132            if let Some(&child) = trie.links[node].get(ch) {
133                return child;
134            } else if node == 0 {
135                return 0;
136            }
137            node = fail[node];
138        }
139    }
140
141    /// Precomputes the automaton that allows linear-time string matching.
142    /// If there are duplicate patterns, all but one copy will be ignored.
143    pub fn new(patterns: impl IntoIterator<Item = impl IntoIterator<Item = C>>) -> Self {
144        let mut trie = Trie::default();
145        let pat_nodes: Vec<usize> = patterns.into_iter().map(|pat| trie.insert(pat)).collect();
146
147        let mut pat_id = vec![None; trie.links.len()];
148        for (i, node) in pat_nodes.into_iter().enumerate() {
149            pat_id[node] = Some(i);
150        }
151
152        let mut fail = vec![0; trie.links.len()];
153        let mut fast = vec![0; trie.links.len()];
154        let mut q: VecDeque<usize> = trie.links[0].values().cloned().collect();
155
156        while let Some(node) = q.pop_front() {
157            for (ch, &child) in &trie.links[node] {
158                let nx = Self::next(&trie, &fail, fail[node], &ch);
159                fail[child] = nx;
160                fast[child] = if pat_id[nx].is_some() { nx } else { fast[nx] };
161                q.push_back(child);
162            }
163        }
164
165        Self {
166            trie,
167            pat_id,
168            fail,
169            fast,
170        }
171    }
172
173    /// Aho-Corasick algorithm, sets @return[i] = node corresponding to
174    /// longest prefix of some pattern matching a suffix of text[0..=i].
175    pub fn ac_match(&self, text: impl IntoIterator<Item = C>) -> Vec<usize> {
176        let mut node = 0;
177        text.into_iter()
178            .map(|ch| {
179                node = Self::next(&self.trie, &self.fail, node, &ch);
180                node
181            })
182            .collect()
183    }
184
185    /// For each non-empty match, returns where in the text it ends, and the index
186    /// of the corresponding pattern.
187    pub fn get_end_pos_and_pat_id(&self, match_nodes: &[usize]) -> Vec<(usize, usize)> {
188        let mut res = vec![];
189        for (text_pos, &(mut node)) in match_nodes.iter().enumerate() {
190            while node != 0 {
191                if let Some(id) = self.pat_id[node] {
192                    res.push((text_pos + 1, id));
193                }
194                node = self.fast[node];
195            }
196        }
197        res
198    }
199}
200
/// Suffix array data structure, useful for a variety of string queries.
pub struct SuffixArray {
    /// The suffix array itself, holding suffix indices in sorted order.
    pub sfx: Vec<usize>,
    /// rank[i][j] = rank of the j'th suffix, considering only 2^i chars.
    /// In other words, rank[i] is a ranking of the substrings text[j..j+2^i]
    /// (truncated at the end of the text). rank[0] holds the raw byte values;
    /// every later level is a dense ranking starting from 0.
    pub rank: Vec<Vec<usize>>,
}

impl SuffixArray {
    /// O(n + max_key) stable sort on the items generated by vals.
    /// Items v in vals are sorted according to val_to_key[v], which must be
    /// less than max_key.
    fn counting_sort(
        vals: impl Iterator<Item = usize> + Clone,
        val_to_key: &[usize],
        max_key: usize,
    ) -> Vec<usize> {
        // First pass: histogram of keys.
        let mut counts = vec![0; max_key];
        for v in vals.clone() {
            counts[val_to_key[v]] += 1;
        }
        // Exclusive prefix sums: counts[k] becomes the first output slot for key k.
        let mut total = 0;
        for c in counts.iter_mut() {
            total += *c;
            *c = total - *c;
        }
        // Second pass: place items in input order, which keeps the sort stable.
        let mut result = vec![0; total];
        for v in vals {
            let c = &mut counts[val_to_key[v]];
            result[*c] = v;
            *c += 1;
        }
        result
    }

    /// Suffix array construction in O(n log n) time via prefix doubling.
    pub fn new(text: impl IntoIterator<Item = u8>) -> Self {
        let init_rank = text.into_iter().map(usize::from).collect::<Vec<_>>();
        let n = init_rank.len();
        let mut sfx = Self::counting_sort(0..n, &init_rank, 256);
        let mut rank = vec![init_rank];
        // Invariant at the start of every loop iteration:
        // suffixes are sorted according to the first skip characters.
        for skip in (0..).map(|i| 1 << i).take_while(|&skip| skip < n) {
            let prev_rank = rank.last().expect("rank always holds the initial level");

            // Sort by the pair (prev_rank[p], prev_rank[p + skip]) using two
            // stable passes: `pos` is already ordered by the second half
            // (suffixes whose second half is empty come first), and the
            // counting sort then orders stably by the first half.
            let pos = (n - skip..n).chain(sfx.into_iter().filter_map(|p| p.checked_sub(skip)));
            sfx = Self::counting_sort(pos, prev_rank, max(n, 256));

            // Every entry is written below because sfx is a permutation of 0..n,
            // so there is no need to copy the previous level.
            let mut cur_rank = vec![0; n];
            let mut prev = sfx[0];
            cur_rank[prev] = 0;
            for &cur in sfx.iter().skip(1) {
                // Neighbours tie only if both halves match; a suffix whose
                // second half runs past the end never ties with another.
                let same = max(prev, cur) + skip < n
                    && prev_rank[prev] == prev_rank[cur]
                    && prev_rank[prev + skip] == prev_rank[cur + skip];
                cur_rank[cur] = if same {
                    cur_rank[prev]
                } else {
                    cur_rank[prev] + 1
                };
                prev = cur;
            }
            rank.push(cur_rank);
        }
        Self { sfx, rank }
    }

    /// Computes the length of longest common prefix of text[i..] and text[j..].
    ///
    /// An index at or beyond the text length denotes the empty suffix, so the
    /// result is 0 whenever either suffix is empty.
    pub fn longest_common_prefix(&self, mut i: usize, mut j: usize) -> usize {
        let n = self.sfx.len();
        if i == j {
            // A suffix shares its whole length with itself. The doubling walk
            // below would overshoot here, since every level compares equal.
            return n.saturating_sub(i);
        }
        let mut len = 0;
        for (k, rank) in self.rank.iter().enumerate().rev() {
            // Stop once either suffix is exhausted; this also guards indexing.
            if max(i, j) >= n {
                break;
            }
            // For i != j, equal ranks at level k imply both substrings have the
            // full length 2^k, so it is safe to jump ahead by that much.
            if rank[i] == rank[j] {
                i += 1 << k;
                j += 1 << k;
                len += 1 << k;
            }
        }
        len
    }
}
285
/// Manacher's algorithm for computing palindrome substrings in linear time.
/// pal[2*i] = odd length of palindrome centred at text[i].
/// pal[2*i+1] = even length of palindrome centred at text[i+0.5].
///
/// Each entry is the length of the longest palindrome with that centre. The
/// result has `2 * text.len() - 1` entries; an empty text yields an empty
/// vector.
pub fn palindromes(text: &[impl Eq]) -> Vec<usize> {
    if text.is_empty() {
        return Vec::new();
    }
    // Number of centres: one per character plus one per adjacent pair. This is
    // kept in a local because `Vec::capacity` only guarantees *at least* the
    // requested size and so must not serve as the loop limit.
    let total = 2 * text.len() - 1;
    let mut pal = Vec::with_capacity(total);
    pal.push(1);
    while pal.len() < total {
        // Centre whose value is currently a lower bound and gets finalized now.
        let i = pal.len() - 1;
        // A palindrome at centre i cannot extend past either end of the text.
        let max_len = min(i + 1, total - i);
        while pal[i] < max_len && text[(i - pal[i] - 1) / 2] == text[(i + pal[i] + 1) / 2] {
            pal[i] += 2;
        }
        if let Some(a) = 1usize.checked_sub(pal[i]) {
            // Trivial palindrome at i (length 0 or 1): the next centre starts
            // from its own trivial lower bound.
            pal.push(a);
        } else {
            // Mirror step: a palindrome reflected strictly inside pal[i] has
            // exactly its mirror's length; the first one reaching the boundary
            // gets the boundary as a lower bound and becomes the next centre.
            for d in 1.. {
                let (a, b) = (pal[i - d], pal[i] - d);
                if a < b {
                    pal.push(a);
                } else {
                    pal.push(b);
                    break;
                }
            }
        }
    }
    pal
}
318
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_trie() {
        let mut trie = Trie::default();
        for word in &["banana", "benefit", "banapple", "ban"] {
            trie.insert(word.bytes());
        }

        let expectations = [
            ("", Some(0)),
            ("b", Some(1)),
            ("banana", Some(6)),
            ("be", Some(7)),
            ("bane", None),
        ];
        for &(query, expected) in &expectations {
            assert_eq!(trie.get(query.bytes()), expected);
        }
    }

    #[test]
    fn test_kmp_matching() {
        let matcher = Matcher::new("ana".as_bytes());
        let lengths = matcher.kmp_match("banana".bytes());
        assert_eq!(lengths, vec![0, 1, 2, 3, 2, 3]);
    }

    #[test]
    fn test_ac_matching() {
        let patterns = ["banana", "benefit", "banapple", "ban", "fit"];
        let matcher = MultiMatcher::new(patterns.iter().map(|p| p.bytes()));

        let nodes = matcher.ac_match("banana bans, apple benefits.".bytes());
        let found = matcher.get_end_pos_and_pat_id(&nodes);

        assert_eq!(found, vec![(3, 3), (6, 0), (10, 3), (26, 1), (26, 4)]);
    }

    #[test]
    fn test_suffix_array() {
        let bobocel = SuffixArray::new("bobocel".bytes());
        let banana = SuffixArray::new("banana".bytes());

        assert_eq!(bobocel.sfx, vec![0, 2, 4, 5, 6, 1, 3]);
        assert_eq!(banana.sfx, vec![5, 3, 1, 0, 4, 2]);

        assert_eq!(bobocel.longest_common_prefix(0, 2), 2);
        assert_eq!(banana.longest_common_prefix(1, 3), 3);

        // sfx and the final rank level must be mutually inverse permutations.
        for sa in &[&bobocel, &banana] {
            let last_level = sa.rank.last().unwrap();
            for (pos, &r) in last_level.iter().enumerate() {
                assert_eq!(sa.sfx[r], pos);
            }
        }
    }

    #[test]
    fn test_palindrome() {
        let lengths = palindromes("banana".as_bytes());
        assert_eq!(lengths, vec![1, 0, 1, 0, 3, 0, 5, 0, 3, 0, 1]);
    }
}