// competitive_programming_rs/data_structure/suffix_array.rs
pub mod suffix_array {
2 use std::cmp::Ordering;
3
/// Suffix array of a byte string, together with a copy of that string.
///
/// Instances are normally created with `SuffixArray::new`, which fills the
/// fields consistently.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SuffixArray {
    /// Length of the indexed byte string.
    pub n: usize,
    /// Owned copy of the indexed byte string.
    pub s: Vec<u8>,
    /// Start offsets of all suffixes of `s` in ascending lexicographic order.
    /// When built by `new` it has `n + 1` entries because the empty suffix
    /// (offset `n`) is included.
    pub array: Vec<usize>,
}
9
/// Orders the positions `i` and `j` by the pair `(rank[p], rank[p + k])`.
///
/// The second component is taken to be `-1` whenever `p + k` lies past the
/// end of `rank`, so a position without a second half sorts before any
/// position that has one (ranks produced by the caller are never below `-1`).
fn compare_node(i: usize, j: usize, k: usize, rank: &[i32]) -> Ordering {
    // Rank of the half that starts `k` positions later, or -1 if there is none.
    let shifted = |pos: usize| rank.get(pos + k).copied().unwrap_or(-1);

    rank[i]
        .cmp(&rank[j])
        .then_with(|| shifted(i).cmp(&shifted(j)))
}
19
20 impl SuffixArray {
21 pub fn new(s: &[u8]) -> SuffixArray {
22 let n = s.len();
23 let mut rank = vec![0; n + 1];
24 let mut array = vec![0; n + 1];
25
26 for i in 0..=n {
27 array[i] = i;
28 rank[i] = if i < n { s[i] as i32 } else { -1 };
29 }
30
31 let mut tmp = vec![0; n + 1];
32 let mut k = 1;
33 while k <= n {
34 array.sort_by(|a, b| compare_node(*a, *b, k, &rank));
35
36 tmp[array[0]] = 0;
37 for i in 1..=n {
38 let d = if compare_node(array[i - 1], array[i], k, &rank) == Ordering::Less {
39 1
40 } else {
41 0
42 };
43 tmp[array[i]] = tmp[array[i - 1]] + d;
44 }
45 std::mem::swap(&mut rank, &mut tmp);
46 k *= 2;
47 }
48
49 SuffixArray {
50 n,
51 array,
52 s: Vec::from(s),
53 }
54 }
55
56 pub fn contains(&self, t: &[u8]) -> bool {
57 let b = self.lower_bound(t);
58 if b >= self.array.len() {
59 false
60 } else {
61 let start = self.array[b];
62 let end = (t.len() + start).min(self.s.len());
63 let sub = &self.s[start..end];
64 sub == t
65 }
66 }
67
68 fn binary_search<F>(&self, string: &[u8], f: F) -> usize
69 where
70 F: Fn(&[u8], &[u8]) -> bool,
71 {
72 let (mut ng, mut ok) = (-1, self.n as i32 + 1);
73 while ok - ng > 1 {
74 let pos = (ng + ok) / 2;
75 let start = self.array[pos as usize];
76 let end = (start + string.len()).min(self.s.len());
77 let substring = &self.s[start..end];
78 if f(substring, string) {
79 ng = pos;
80 } else {
81 ok = pos;
82 }
83 }
84 ok as usize
85 }
86
87 pub fn lower_bound(&self, t: &[u8]) -> usize {
88 let check_function = |sub: &[u8], s: &[u8]| sub.cmp(s) == Ordering::Less;
89 self.binary_search(t, check_function)
90 }
91
92 pub fn upper_bound(&self, t: &[u8]) -> usize {
93 let check_function = |sub: &[u8], s: &[u8]| sub.cmp(s) != Ordering::Greater;
94 self.binary_search(t, check_function)
95 }
96 }
97
/// Computes the LCP (longest common prefix) array for `string` from its
/// suffix array.
///
/// `suffix_array` must list the start offsets of all `string.len() + 1`
/// suffixes (the empty suffix included) in sorted order, with the empty
/// suffix first. The returned vector has `string.len()` entries; entry `p` is
/// the length of the common prefix of the suffixes at positions `p` and
/// `p + 1` of `suffix_array`. An empty `string` yields an empty vector.
///
/// # Panics
///
/// Panics if `suffix_array.len() != string.len() + 1`, or if the suffix array
/// is malformed (for example, the empty suffix is not at position 0).
pub fn construct_lcp<T: Ord>(string: &[T], suffix_array: &[usize]) -> Vec<usize> {
    assert_eq!(string.len() + 1, suffix_array.len());
    let n = string.len();
    let mut lcp = vec![0; n];

    // rank[start] = position of the suffix beginning at `start` in sorted order.
    let mut rank = vec![0; n + 1];
    for (position, &start) in suffix_array.iter().enumerate() {
        rank[start] = position;
    }

    // Walk the suffixes in text order; dropping the first character shortens
    // the previous match by at most one, so `height` is carried over.
    let mut height: usize = 0;
    for i in 0..n {
        // Suffix that immediately precedes suffix `i` in sorted order.
        let j = suffix_array[rank[i] - 1];

        height = height.saturating_sub(1);
        while j + height < n && i + height < n && string[j + height] == string[i + height] {
            height += 1;
        }

        lcp[rank[i] - 1] = height;
    }

    lcp
}
127}
128#[cfg(test)]
129mod test {
130 use super::suffix_array::*;
131 use crate::data_structure::segment_tree::SegmentTree;
132 use crate::utils::test_helper::Tester;
133 use rand::prelude::*;
134
/// Checks bound queries and substring lookup on a string made of "abcde"
/// repeated twice.
#[test]
fn small_test() {
    let text: &[u8] = b"abcdeabcde";
    let sa = SuffixArray::new(text);

    // The suffixes starting with "a" are expected at positions 1 and 2.
    let pattern: &[u8] = b"a";
    assert_eq!(sa.lower_bound(pattern), 1);
    assert_eq!(sa.upper_bound(pattern), 3);

    let present: &[u8] = b"abcde";
    let absent: &[u8] = b"abce";
    assert!(sa.contains(present));
    assert!(!sa.contains(absent));
}
151
/// Checks the bounds of a pattern that matches only the very last entry of
/// the suffix array.
#[test]
fn corner_case() {
    let text: &[u8] = b"cba";
    let sa = SuffixArray::new(text);

    let pattern: &[u8] = b"c";
    assert_eq!(sa.lower_bound(pattern), 3);
    assert_eq!(sa.upper_bound(pattern), 4);
}
165
/// Randomised check of `SuffixArray::new` against a naive sort of all
/// suffixes, and of `construct_lcp` against a direct prefix comparison.
#[test]
fn test_suffix_array() {
    let mut rng = thread_rng();
    let n = 100;
    for _ in 0..100 {
        let string = (0..n).map(|_| rng.gen_range(0, 30)).collect::<Vec<_>>();
        let sa = SuffixArray::new(&string);

        // Reference answer: sort every suffix (the empty one included) directly.
        let mut naive: Vec<_> = (0..=n)
            .map(|start| (string[start..].to_vec(), start))
            .collect();
        naive.sort();

        for (position, entry) in naive.iter().enumerate() {
            assert_eq!(sa.array[position], entry.1);
        }

        let lcp_array = construct_lcp(&string, &sa.array);
        for i in 0..n {
            let lcp = lcp_array[i];
            let (prev, next) = (sa.array[i], sa.array[i + 1]);

            // The first `lcp` elements must agree ...
            let prev_substring = &string[prev..(prev + lcp)];
            let next_substring = &string[next..(next + lcp)];
            assert_eq!(prev_substring, next_substring);
            // ... and the element right after them must differ (or be missing on one side).
            assert_ne!(string.get(prev + lcp), string.get(next + lcp));
        }
    }
}
199
/// Runs the solution for JAG 2014 Summer Day 4, problem F against the stored
/// input/output files.
///
/// For each query pair `(x, y)` the solution looks up `x` in the suffix array
/// of the text and the reversed `y` in the suffix array of the reversed text,
/// then combines range-minimum queries over both arrays into the answer.
#[test]
fn jag2014summer_day4_f() {
    let tester = Tester::new(
        "./assets/jag2014summer-day4/F/in/",
        "./assets/jag2014summer-day4/F/out/",
    );
    tester.test_solution(|sc| {
        let s: Vec<u8> = sc.read::<String>().bytes().collect();
        let n = s.len();
        let reverse_s = {
            let mut r = s.clone();
            r.reverse();
            r
        };
        let sa = SuffixArray::new(&s);
        let reverse_sa = SuffixArray::new(&reverse_s);

        let op = |a: i64, b: i64| a.min(b);

        // Range-minimum structures over the suffix start offsets.
        let mut rmq = SegmentTree::new(n + 1, op);
        let mut reverse_rmq = SegmentTree::new(n + 1, op);
        for i in 0..=n {
            rmq.update(i, sa.array[i] as i64);
            reverse_rmq.update(i, reverse_sa.array[i] as i64);
        }

        let m: usize = sc.read();
        for _ in 0..m {
            let x = sc.read::<String>().bytes().collect::<Vec<_>>();
            let y = {
                let mut y: Vec<u8> = sc.read::<String>().bytes().collect::<Vec<_>>();
                y.reverse();
                y
            };

            if !sa.contains(&x) {
                sc.write("0\n");
                continue;
            }
            let low = sa.lower_bound(&x);
            let up = sa.upper_bound(&x);

            if !reverse_sa.contains(&y) {
                sc.write("0\n");
                continue;
            }
            let reverse_low = reverse_sa.lower_bound(&y);
            let reverse_up = reverse_sa.upper_bound(&y);

            if low >= up || reverse_low >= reverse_up {
                sc.write("0\n");
                // Without this `continue` the empty ranges would still be
                // queried below and a second answer line could be written.
                continue;
            }

            // Presumably the smallest start offset among occurrences of `x`
            // (minimum over the matching suffix-array range).
            let s = rmq.query(low..up).unwrap() as usize;
            // Presumably the largest end offset among occurrences of the
            // original `y`, obtained by mirroring the smallest start offset
            // in the reversed text.
            let t = n - reverse_rmq.query(reverse_low..reverse_up).unwrap() as usize;
            if s + x.len() <= t && s <= t - y.len() {
                sc.write(format!("{}\n", t - s));
            } else {
                sc.write("0\n");
            }
        }
    });
}
263}