1mod disjoint_set;
2
3use std::collections::HashMap;
4
/// Index of a node inside `GeneralizedSuffixTree::node_storage`.
type NodeID = u32;
/// Index of a string inside `GeneralizedSuffixTree::str_storage`.
type StrID = u32;
/// Byte offset inside one of the stored strings.
type IndexType = u32;
/// The tree works on raw bytes; edges are keyed by a single byte.
type CharType = u8;

/// Node id of the root (first element of the node storage).
const ROOT: NodeID = 0;
/// Node id of the auxiliary sink node (second element of the node storage).
const SINK: NodeID = 1;
/// Sentinel meaning "no such node" (missing transition or unset suffix link).
const INVALID: NodeID = NodeID::MAX;
14
15#[derive(Debug, Clone)]
17struct MappedSubstring {
18 str_id: StrID,
20
21 start: IndexType,
23
24 end: IndexType,
30}
31
32impl MappedSubstring {
33 fn new(str_id: StrID, start: IndexType, end: IndexType) -> MappedSubstring {
34 MappedSubstring { str_id, start, end }
35 }
36
37 fn is_empty(&self) -> bool {
38 self.start == self.end
39 }
40
41 fn len(&self) -> IndexType {
42 self.end - self.start
43 }
44}
45
46#[derive(Debug)]
55struct Node {
56 transitions: HashMap<CharType, NodeID>,
57
58 suffix_link: NodeID,
59
60 substr: MappedSubstring,
62}
63
64impl Node {
65 fn new(str_id: StrID, start: IndexType, end: IndexType) -> Node {
66 Node {
67 transitions: HashMap::new(),
68 suffix_link: INVALID,
69 substr: MappedSubstring::new(str_id, start, end),
70 }
71 }
72
73 fn get_suffix_link(&self) -> NodeID {
74 assert!(self.suffix_link != INVALID, "Invalid suffix link");
75 self.suffix_link
76 }
77}
78
79struct ReferencePoint {
81 node: NodeID,
83
84 str_id: StrID,
86
87 index: IndexType,
89}
90
91impl ReferencePoint {
92 fn new(node: NodeID, str_id: StrID, index: IndexType) -> ReferencePoint {
93 ReferencePoint {
94 node,
95 str_id,
96 index,
97 }
98 }
99}
100
/// A suffix tree built over several strings at once.
///
/// Strings are added with `add_string`; each one is stored with its
/// terminator character appended.
#[derive(Debug)]
pub struct GeneralizedSuffixTree {
    /// Arena of all nodes; a `NodeID` is an index into this vector.
    /// Index 0 is the root and index 1 is the sink.
    node_storage: Vec<Node>,
    /// All added strings (terminator included); a `StrID` is an index into
    /// this vector.
    str_storage: Vec<String>,
}
120
121impl GeneralizedSuffixTree {
122 pub fn new() -> GeneralizedSuffixTree {
123 let mut root = Node::new(0, 0, 1);
126 let mut sink = Node::new(0, 0, 0);
127
128 root.suffix_link = SINK;
129 sink.suffix_link = ROOT;
130
131 let node_storage: Vec<Node> = vec![root, sink];
132 GeneralizedSuffixTree {
133 node_storage,
134 str_storage: vec![],
135 }
136 }
137
138 pub fn add_string(&mut self, mut s: String, term: char) {
140 self.validate_string(&s, term);
141
142 let str_id = self.str_storage.len() as StrID;
143
144 s.push(term);
146
147 self.str_storage.push(s);
148 self.process_suffixes(str_id);
149 }
150
151 fn validate_string(&self, s: &String, term: char) {
152 assert!(term.is_ascii(), "Only accept ASCII terminator");
153 assert!(
154 !s.contains(term),
155 "String should not contain terminator character"
156 );
157 for existing_str in &self.str_storage {
158 assert!(
159 !existing_str.contains(term),
160 "Any existing string should not contain terminator character"
161 );
162 }
163 }
164
165 pub fn longest_common_substring_all(&self) -> String {
171 let mut disjoint_set = disjoint_set::DisjointSet::new(self.node_storage.len());
172
173 let mut prev_node: HashMap<CharType, NodeID> = HashMap::new();
176
177 let mut lca_cnt: Vec<usize> = vec![0; self.node_storage.len()];
179
180 let mut longest_str: (Vec<&MappedSubstring>, IndexType) = (vec![], 0);
181 let mut cur_str: (Vec<&MappedSubstring>, IndexType) = (vec![], 0);
182 self.longest_common_substring_all_rec(
183 &mut disjoint_set,
184 &mut prev_node,
185 &mut lca_cnt,
186 ROOT,
187 &mut longest_str,
188 &mut cur_str,
189 );
190
191 let mut result = String::new();
192 for s in longest_str.0 {
193 result.push_str(&self.get_string_slice_short(&s));
194 }
195 result
196 }
197
    /// Depth-first helper of `longest_common_substring_all`.
    ///
    /// Visits every child of `node`. `cur_str` holds the edge labels on the
    /// path from the start of the traversal to `node` and their total
    /// length; `longest_str` is overwritten with a copy of `cur_str`
    /// whenever the subtree of `node` is counted as containing all stored
    /// strings and the current path is strictly longer than the recorded one.
    ///
    /// `prev_node` maps the last byte of a leaf label to the most recently
    /// visited leaf with that byte; `lca_cnt` is indexed by the value
    /// returned from `disjoint_set.find_set`.
    ///
    /// Returns `(leaf count, correction count)` for the subtree of `node`.
    // NOTE(review): this looks like an offline lowest-common-ancestor scheme
    // (union-find merged on the way back up the DFS) used to avoid counting
    // the same string twice — confirm against the `DisjointSet` semantics.
    fn longest_common_substring_all_rec<'a>(
        &'a self,
        disjoint_set: &mut disjoint_set::DisjointSet,
        prev_node: &mut HashMap<CharType, NodeID>,
        lca_cnt: &mut Vec<usize>,
        node: NodeID,
        longest_str: &mut (Vec<&'a MappedSubstring>, IndexType),
        cur_str: &mut (Vec<&'a MappedSubstring>, IndexType),
    ) -> (usize, usize) {
        let mut total_leaf = 0;
        let mut total_correction = 0;
        for target_node in self.get_node(node).transitions.values() {
            if *target_node == INVALID {
                continue;
            }
            let slice = &self.get_node(*target_node).substr;
            // A label that runs to the end of its string marks a leaf.
            if slice.end as usize == self.get_string(slice.str_id).len() {
                total_leaf += 1;
                // The last byte of a stored string is its terminator, so this
                // groups leaves by the string they belong to.
                let last_ch = self.get_char(slice.str_id, slice.end - 1);
                if let Some(prev) = prev_node.get(&last_ch) {
                    // Charge the pair (previous leaf, this leaf) to the set
                    // representative of the previous leaf.
                    let lca = disjoint_set.find_set(*prev as usize);
                    lca_cnt[lca as usize] += 1;
                }
                prev_node.insert(last_ch, *target_node);
            } else {
                // Internal node: extend the current path, recurse, then undo.
                cur_str.0.push(slice);
                cur_str.1 += slice.len();
                let result = self.longest_common_substring_all_rec(
                    disjoint_set,
                    prev_node,
                    lca_cnt,
                    *target_node,
                    longest_str,
                    cur_str,
                );
                total_leaf += result.0;
                total_correction += result.1;

                cur_str.0.pop();
                cur_str.1 -= slice.len();
            }

            // Merge the finished child into the set of `node`.
            disjoint_set.union(node as usize, *target_node as usize);
        }
        total_correction += lca_cnt[node as usize];
        // Leaves minus corrections is compared with the number of stored strings.
        let unique_str_cnt = total_leaf - total_correction;
        if unique_str_cnt == self.str_storage.len() {
            if cur_str.1 > longest_str.1 {
                *longest_str = cur_str.clone();
            }
        }
        (total_leaf, total_correction)
    }
268
269 pub fn longest_common_substring_with<'a>(&self, s: &'a String) -> &'a str {
272 let mut longest_start: IndexType = 0;
273 let mut longest_len: IndexType = 0;
274 let mut cur_start: IndexType = 0;
275 let mut cur_len: IndexType = 0;
276 let mut node: NodeID = ROOT;
277
278 let chars = s.as_bytes();
279 let mut index = 0;
280 let mut active_length = 0;
281 while index < chars.len() {
282 let target_node_id = self.transition(node, chars[index - active_length as usize]);
283 if target_node_id != INVALID {
284 let slice = &self.get_node(target_node_id).substr;
285 while index != chars.len()
286 && active_length < slice.len()
287 && self.get_char(slice.str_id, active_length + slice.start) == chars[index]
288 {
289 index += 1;
290 active_length += 1;
291 }
292
293 let final_len = cur_len + active_length;
294 if final_len > longest_len {
295 longest_len = final_len;
296 longest_start = cur_start;
297 }
298
299 if index == chars.len() {
300 break;
301 }
302
303 if active_length == slice.len() {
304 node = target_node_id;
306 cur_len = final_len;
307 active_length = 0;
308 continue;
309 }
310 }
311 cur_start += 1;
313 if cur_start > index as IndexType {
314 index += 1;
315 continue;
316 }
317 let suffix_link = self.get_node(node).suffix_link;
319 if suffix_link != INVALID && suffix_link != SINK {
320 assert!(cur_len > 0);
321 node = suffix_link;
322 cur_len -= 1;
323 } else {
324 node = ROOT;
325 active_length = active_length + cur_len - 1;
326 cur_len = 0;
327 }
328 while active_length > 0 {
329 assert!(cur_start + cur_len < chars.len() as IndexType);
330 let target_node_id = self.transition(node, chars[(cur_start + cur_len) as usize]);
331 assert!(target_node_id != INVALID);
332 let slice = &self.get_node(target_node_id).substr;
333 if active_length < slice.len() {
334 break;
335 }
336 active_length -= slice.len();
337 cur_len += slice.len();
338 node = target_node_id;
339 }
340 }
341 &s[longest_start as usize..(longest_start + longest_len) as usize]
342 }
343
    /// Returns `true` when `s` can be matched from the root and is directly
    /// followed by a string terminator, i.e. `s` is a suffix of a stored
    /// string.
    ///
    /// # Panics
    /// Panics if `s` contains the terminator character of a stored string.
    pub fn is_suffix(&self, s: &str) -> bool {
        self.is_suffix_or_substr(s, false)
    }
348
    /// Returns `true` when `s` can be matched from the root, i.e. `s` is a
    /// substring of a stored string.
    ///
    /// # Panics
    /// Panics if `s` contains the terminator character of a stored string.
    pub fn is_substr(&self, s: &str) -> bool {
        self.is_suffix_or_substr(s, true)
    }
354
355 fn is_suffix_or_substr(&self, s: &str, check_substr: bool) -> bool {
356 for existing_str in &self.str_storage {
357 assert!(
358 !s.contains(existing_str.chars().last().unwrap()),
359 "Queried string cannot contain terminator char"
360 );
361 }
362 let mut node = ROOT;
363 let mut index = 0;
364 let chars = s.as_bytes();
365 while index < s.len() {
366 let target_node = self.transition(node, chars[index]);
367 if target_node == INVALID {
368 return false;
369 }
370 let slice = &self.get_node(target_node).substr;
371 for i in slice.start..slice.end {
372 if index == s.len() {
373 let is_suffix = i as usize == self.get_string(slice.str_id).len() - 1;
374 return check_substr || is_suffix;
375 }
376 if chars[index] != self.get_char(slice.str_id, i) {
377 return false;
378 }
379 index += 1;
380 }
381 node = target_node;
382 }
383 let mut is_suffix = false;
384 for s in &self.str_storage {
385 if self.transition(node, *s.as_bytes().last().unwrap()) != INVALID {
390 is_suffix = true;
391 break;
392 }
393 }
394
395 check_substr || is_suffix
396 }
397
    /// Prints the edge labels of the whole tree to standard output, starting
    /// at the root with no indentation.
    pub fn pretty_print(&self) {
        self.print_recursive(ROOT, 0);
    }
401
402 fn print_recursive(&self, node: NodeID, space_count: u32) {
403 for target_node in self.get_node(node).transitions.values() {
404 if *target_node == INVALID {
405 continue;
406 }
407 for _ in 0..space_count {
408 print!(" ");
409 }
410 let slice = &self.get_node(*target_node).substr;
411 println!(
412 "{}",
413 self.get_string_slice(slice.str_id, slice.start, slice.end),
414 );
415 self.print_recursive(*target_node, space_count + 4);
416 }
417 }
418
419 fn process_suffixes(&mut self, str_id: StrID) {
420 let mut active_point = ReferencePoint::new(ROOT, str_id, 0);
421 for i in 0..self.get_string(str_id).len() {
422 let mut cur_str =
423 MappedSubstring::new(str_id, active_point.index, (i + 1) as IndexType);
424 active_point = self.update(active_point.node, &cur_str);
425 cur_str.start = active_point.index;
426 active_point = self.canonize(active_point.node, &cur_str);
427 }
428 }
429
    /// Extends the tree with the last character of `cur_str`.
    ///
    /// `node` together with `cur_str` minus its last character is the
    /// position to start from. While `test_and_split` reports that the
    /// character is not yet present, a new leaf is attached and the position
    /// moves on via the suffix link. Returns the position at which the
    /// character was found to be present.
    ///
    /// # Panics
    /// Panics if `cur_str` is empty, or if a required suffix link is unset.
    // NOTE(review): the structure (update / test_and_split / canonize, sink
    // node) appears to follow Ukkonen's on-line construction — confirm.
    fn update(&mut self, node: NodeID, cur_str: &MappedSubstring) -> ReferencePoint {
        assert!(!cur_str.is_empty());

        let mut cur_str = cur_str.clone();

        // Node that received a leaf in the previous round; ROOT doubles as
        // "none yet".
        let mut oldr = ROOT;

        // `cur_str` without its last character.
        let mut split_str = cur_str.clone();
        split_str.end -= 1;

        // The character being added.
        let last_ch = self.get_char(cur_str.str_id, cur_str.end - 1);

        let mut active_point = ReferencePoint::new(node, cur_str.str_id, cur_str.start);

        // Filled in by `test_and_split`: the node a new leaf must hang off.
        let mut r = node;

        let mut is_endpoint = self.test_and_split(node, &split_str, last_ch, &mut r);
        while !is_endpoint {
            // The new leaf is labelled from the added character to the end
            // of the string.
            let str_len = self.get_string(active_point.str_id).len() as IndexType;
            let leaf_node =
                self.create_node_with_slice(active_point.str_id, cur_str.end - 1, str_len);
            self.set_transition(r, last_ch, leaf_node);
            // Link the node from the previous round to this one.
            if oldr != ROOT {
                self.get_node_mut(oldr).suffix_link = r;
            }
            oldr = r;
            // Continue from the suffix link of the active node.
            let suffix_link = self.get_node(active_point.node).get_suffix_link();
            active_point = self.canonize(suffix_link, &split_str);
            split_str.start = active_point.index;
            cur_str.start = active_point.index;
            is_endpoint = self.test_and_split(active_point.node, &split_str, last_ch, &mut r);
        }
        // The last node that received a leaf links to the final active node.
        if oldr != ROOT {
            self.get_node_mut(oldr).suffix_link = active_point.node;
        }
        active_point
    }
467
468 fn test_and_split(
469 &mut self,
470 node: NodeID,
471 split_str: &MappedSubstring,
472 ch: CharType,
473 r: &mut NodeID,
474 ) -> bool {
475 if split_str.is_empty() {
476 *r = node;
477 return self.transition(node, ch) != INVALID;
478 }
479 let first_ch = self.get_char(split_str.str_id, split_str.start);
480
481 let target_node_id = self.transition(node, first_ch);
482 let target_node_slice = self.get_node(target_node_id).substr.clone();
483
484 let split_index = target_node_slice.start + split_str.len();
485 let ref_ch = self.get_char(target_node_slice.str_id, split_index);
486
487 if ref_ch == ch {
488 *r = node;
489 return true;
490 }
491 *r = self.create_node_with_slice(split_str.str_id, split_str.start, split_str.end);
493 self.set_transition(*r, ref_ch, target_node_id);
494 self.set_transition(node, first_ch, *r);
495 self.get_node_mut(target_node_id).substr.start = split_index;
496
497 false
498 }
499
500 fn canonize(&mut self, mut node: NodeID, cur_str: &MappedSubstring) -> ReferencePoint {
501 let mut cur_str = cur_str.clone();
502 loop {
503 if cur_str.is_empty() {
504 return ReferencePoint::new(node, cur_str.str_id, cur_str.start);
505 }
506
507 let ch = self.get_char(cur_str.str_id, cur_str.start);
508
509 let target_node = self.transition(node, ch);
510 if target_node == INVALID {
511 break;
512 }
513 let slice = &self.get_node(target_node).substr;
514 if slice.len() > cur_str.len() {
515 break;
516 }
517 cur_str.start += slice.len();
518 node = target_node;
519 }
520 ReferencePoint::new(node, cur_str.str_id, cur_str.start)
521 }
522
523 fn create_node_with_slice(
524 &mut self,
525 str_id: StrID,
526 start: IndexType,
527 end: IndexType,
528 ) -> NodeID {
529 let node = Node::new(str_id, start, end);
530 self.node_storage.push(node);
531
532 (self.node_storage.len() - 1) as NodeID
533 }
534
    /// Returns the node with id `node_id`.
    ///
    /// # Panics
    /// Panics if `node_id` is out of bounds for the node storage (this
    /// includes `INVALID`).
    fn get_node(&self, node_id: NodeID) -> &Node {
        &self.node_storage[node_id as usize]
    }
538
    /// Returns a mutable reference to the node with id `node_id`.
    ///
    /// # Panics
    /// Panics if `node_id` is out of bounds for the node storage.
    fn get_node_mut(&mut self, node_id: NodeID) -> &mut Node {
        &mut self.node_storage[node_id as usize]
    }
542
    /// Returns the stored string (terminator included) with id `str_id`.
    ///
    /// # Panics
    /// Panics if `str_id` is out of bounds for the string storage.
    fn get_string(&self, str_id: StrID) -> &String {
        &self.str_storage[str_id as usize]
    }
546
    /// Returns the byte range `[start, end)` of stored string `str_id`.
    ///
    /// # Panics
    /// Panics if the range is out of bounds or does not fall on `char`
    /// boundaries of the stored string.
    fn get_string_slice(&self, str_id: StrID, start: IndexType, end: IndexType) -> &str {
        &self.get_string(str_id)[start as usize..end as usize]
    }
550
551 fn get_string_slice_short(&self, slice: &MappedSubstring) -> &str {
552 &self.get_string_slice(slice.str_id, slice.start, slice.end)
553 }
554
555 fn transition(&self, node: NodeID, ch: CharType) -> NodeID {
556 if node == SINK {
557 return ROOT;
559 }
560 match self.get_node(node).transitions.get(&ch) {
561 None => INVALID,
562 Some(x) => *x,
563 }
564 }
565
    /// Makes `target_node` the child of `node` for byte `ch`, replacing any
    /// child previously stored under that byte.
    fn set_transition(&mut self, node: NodeID, ch: CharType, target_node: NodeID) {
        self.get_node_mut(node).transitions.insert(ch, target_node);
    }
569
    /// Returns the byte at offset `index` of stored string `str_id`.
    ///
    /// # Panics
    /// Panics if `index` is not smaller than the byte length of the string.
    fn get_char(&self, str_id: StrID, index: IndexType) -> u8 {
        assert!((index as usize) < self.get_string(str_id).len());
        self.get_string(str_id).as_bytes()[index as usize]
    }
574}