// competitive_programming_rs/data_structure/suffix_array.rs

pub mod suffix_array {
    use std::cmp::Ordering;

    /// Suffix array of a byte string, built by prefix doubling.
    ///
    /// `array` has `n + 1` entries: it lists every suffix start position `0..=n` — including the
    /// empty suffix starting at `n`, which always sorts first — in ascending lexicographic order
    /// of the suffixes.
    #[derive(Debug, Clone)]
    pub struct SuffixArray {
        /// Length of the indexed string.
        pub n: usize,
        /// The indexed string.
        pub s: Vec<u8>,
        /// Suffix start positions in ascending lexicographic order of the suffixes.
        pub array: Vec<usize>,
    }

    /// Orders suffixes `i` and `j` by the pair `(rank[p], rank[p + k])`.
    ///
    /// `rank` ranks every suffix by its first `k` symbols. A second-half position past the end
    /// of `rank` counts as `-1`, so a suffix that ends earlier sorts first.
    fn compare_node(i: usize, j: usize, k: usize, rank: &[i32]) -> Ordering {
        let second = |p: usize| if p + k < rank.len() { rank[p + k] } else { -1 };
        rank[i].cmp(&rank[j]).then_with(|| second(i).cmp(&second(j)))
    }

    impl SuffixArray {
        /// Builds the suffix array of `s` by prefix doubling.
        ///
        /// Each round sorts with `O(n log n)` comparisons and the number of rounds is
        /// logarithmic in `n`, giving `O(n log^2 n)` overall.
        pub fn new(s: &[u8]) -> SuffixArray {
            let n = s.len();
            // Suffix start positions, including the empty suffix at `n`.
            let mut array: Vec<usize> = (0..=n).collect();
            // rank[i] orders suffix i by its first k symbols. The empty suffix gets -1 so it
            // sorts before every non-empty suffix.
            let mut rank: Vec<i32> = s
                .iter()
                .map(|&c| i32::from(c))
                .chain(std::iter::once(-1))
                .collect();
            let mut tmp = vec![0; n + 1];

            let mut k = 1;
            while k <= n {
                // Sort by the first 2k symbols using the ranks of the first k. Stability is not
                // needed: ties receive equal ranks below, and in the final round (2k > n) every
                // pair of distinct suffixes compares unequal.
                array.sort_unstable_by(|&a, &b| compare_node(a, b, k, &rank));

                // Re-rank: neighbours that compare equal share a rank, otherwise it grows by one.
                tmp[array[0]] = 0;
                for w in 1..=n {
                    let (prev, cur) = (array[w - 1], array[w]);
                    let step = if compare_node(prev, cur, k, &rank) == Ordering::Less {
                        1
                    } else {
                        0
                    };
                    tmp[cur] = tmp[prev] + step;
                }
                std::mem::swap(&mut rank, &mut tmp);
                k *= 2;
            }

            SuffixArray {
                n,
                array,
                s: s.to_vec(),
            }
        }

        /// Returns `true` if `t` occurs as a substring of the indexed string.
        ///
        /// The empty pattern always occurs.
        pub fn contains(&self, t: &[u8]) -> bool {
            // If any suffix starts with `t`, the first suffix not less than `t` does.
            self.array
                .get(self.lower_bound(t))
                .map_or(false, |&start| self.s[start..].starts_with(t))
        }

        /// Returns the first index `i` in `0..=n` for which `f(prefix, pattern)` is false, or
        /// `n + 1` if there is none.
        ///
        /// `prefix` is the suffix starting at `array[i]`, truncated to `pattern.len()` symbols.
        /// `f` must be true on a leading run of indices and false afterwards.
        fn binary_search<F>(&self, pattern: &[u8], f: F) -> usize
        where
            F: Fn(&[u8], &[u8]) -> bool,
        {
            // Half-open search window [lo, hi) over positions of `array`.
            let (mut lo, mut hi) = (0, self.array.len());
            while lo < hi {
                let mid = lo + (hi - lo) / 2;
                let start = self.array[mid];
                let end = (start + pattern.len()).min(self.s.len());
                if f(&self.s[start..end], pattern) {
                    lo = mid + 1;
                } else {
                    hi = mid;
                }
            }
            lo
        }

        /// Returns the index into `array` of the first suffix whose first `t.len()` symbols are
        /// not less than `t`, or `n + 1` if every suffix is smaller.
        ///
        /// `upper_bound(t) - lower_bound(t)` is the number of occurrences of `t`.
        pub fn lower_bound(&self, t: &[u8]) -> usize {
            self.binary_search(t, |prefix, pattern| prefix < pattern)
        }

        /// Returns the index into `array` of the first suffix whose first `t.len()` symbols are
        /// greater than `t`, or `n + 1` if there is none.
        pub fn upper_bound(&self, t: &[u8]) -> usize {
            self.binary_search(t, |prefix, pattern| prefix <= pattern)
        }
    }

    /// Builds the LCP array of `string` in `O(n)` time using Kasai's algorithm.
    ///
    /// `suffix_array` must list every suffix start of `string` in sorted order, including the
    /// empty suffix (the layout produced by [`SuffixArray::new`]). The result has `string.len()`
    /// entries, and `lcp[i]` is the length of the longest common prefix of the suffixes starting
    /// at `suffix_array[i]` and `suffix_array[i + 1]`. An empty `string` yields an empty vector.
    ///
    /// # Panics
    ///
    /// Panics if `suffix_array.len() != string.len() + 1`.
    pub fn construct_lcp<T: Eq>(string: &[T], suffix_array: &[usize]) -> Vec<usize> {
        assert_eq!(string.len() + 1, suffix_array.len());
        let n = string.len();

        // rank[p] = position of suffix p within `suffix_array`.
        let mut rank = vec![0; n + 1];
        for (r, &p) in suffix_array.iter().enumerate() {
            rank[p] = r;
        }

        let mut lcp = vec![0; n];
        let mut height: usize = 0;
        for i in 0..n {
            // For a valid suffix array the empty suffix has rank 0, so every non-empty suffix
            // has a predecessor.
            let r = rank[i];
            let j = suffix_array[r - 1];

            // Kasai: moving from suffix i - 1 to suffix i shrinks the LCP by at most one.
            height = height.saturating_sub(1);
            while i + height < n && j + height < n && string[i + height] == string[j + height] {
                height += 1;
            }
            lcp[r - 1] = height;
        }
        lcp
    }
}
#[cfg(test)]
mod test {
    use super::suffix_array::*;
    use crate::data_structure::segment_tree::SegmentTree;
    use crate::utils::test_helper::Tester;
    use rand::prelude::*;

    #[test]
    fn small_test() {
        let sa = SuffixArray::new(b"abcdeabcde");
        assert_eq!(sa.lower_bound(b"a"), 1);
        assert_eq!(sa.upper_bound(b"a"), 3);

        assert!(sa.contains(b"abcde"));
        assert!(!sa.contains(b"abce"));
    }

    // upper_bound may return n + 1 when no truncated suffix exceeds the pattern.
    #[test]
    fn corner_case() {
        let sa = SuffixArray::new(b"cba");
        assert_eq!(sa.lower_bound(b"c"), 3);
        assert_eq!(sa.upper_bound(b"c"), 4);
    }

    #[test]
    fn test_suffix_array() {
        let mut rng = thread_rng();
        let n = 100;
        for _ in 0..100 {
            let string: Vec<u8> = (0..n).map(|_| rng.gen_range(0, 30)).collect();
            let sa = SuffixArray::new(&string);

            // Naive reference: sort every suffix (including the empty one) directly.
            let mut naive: Vec<(&[u8], usize)> = (0..=n).map(|i| (&string[i..], i)).collect();
            naive.sort();
            let expected: Vec<usize> = naive.into_iter().map(|(_, i)| i).collect();
            assert_eq!(sa.array, expected);

            let lcp_array = construct_lcp(&string, &sa.array);
            assert_eq!(lcp_array.len(), n);
            for (i, &lcp) in lcp_array.iter().enumerate() {
                let prev = sa.array[i];
                let next = sa.array[i + 1];

                assert_eq!(&string[prev..prev + lcp], &string[next..next + lcp]);
                // The common prefix is maximal: the next symbols differ or one suffix ended.
                assert_ne!(string.get(prev + lcp), string.get(next + lcp));
            }
        }
    }

    #[test]
    fn jag2014summer_day4_f() {
        let tester = Tester::new(
            "./assets/jag2014summer-day4/F/in/",
            "./assets/jag2014summer-day4/F/out/",
        );
        tester.test_solution(|sc| {
            let s: Vec<u8> = sc.read::<String>().bytes().collect();
            let n = s.len();
            let reverse_s: Vec<u8> = s.iter().rev().cloned().collect();
            let sa = SuffixArray::new(&s);
            let reverse_sa = SuffixArray::new(&reverse_s);

            // Range-minimum trees over suffix start positions, indexed by sorted rank.
            let op = |a: i64, b: i64| a.min(b);
            let mut rmq = SegmentTree::new(n + 1, op);
            let mut reverse_rmq = SegmentTree::new(n + 1, op);
            for i in 0..=n {
                rmq.update(i, sa.array[i] as i64);
                reverse_rmq.update(i, reverse_sa.array[i] as i64);
            }

            let m: usize = sc.read();
            for _ in 0..m {
                let x: Vec<u8> = sc.read::<String>().bytes().collect();
                // y must end a match, so it is searched reversed in the reversed string.
                let y: Vec<u8> = sc.read::<String>().bytes().rev().collect();

                if !sa.contains(&x) || !reverse_sa.contains(&y) {
                    sc.write("0\n");
                    continue;
                }
                let low = sa.lower_bound(&x);
                let up = sa.upper_bound(&x);
                let reverse_low = reverse_sa.lower_bound(&y);
                let reverse_up = reverse_sa.upper_bound(&y);

                if low >= up || reverse_low >= reverse_up {
                    // An empty match range means no occurrence; skip the RMQ queries entirely.
                    sc.write("0\n");
                    continue;
                }

                // Leftmost start of x, and rightmost (exclusive) end of y in the original string.
                let start = rmq.query(low..up).unwrap() as usize;
                let end = n - reverse_rmq.query(reverse_low..reverse_up).unwrap() as usize;
                if start + x.len() <= end && start <= end - y.len() {
                    sc.write(format!("{}\n", end - start));
                } else {
                    sc.write("0\n");
                }
            }
        });
    }
}