generalized_suffix_tree/
lib.rs

1mod disjoint_set;
2
3use std::collections::HashMap;
4
/// Index of a node in the tree's node storage.
type NodeID = u32;
/// Index of a string in the tree's string storage.
type StrID = u32;
/// Byte offset into a stored string.
type IndexType = u32;
/// A single byte of a stored string; the tree branches on bytes, not `char`s.
type CharType = u8;

// Special nodes.
/// The root node; always the first entry of the node storage.
const ROOT: NodeID = 0;
/// Auxiliary node the root's suffix link points to; `transition` maps it to
/// `ROOT` on every character. Always the second entry of the node storage.
const SINK: NodeID = 1;
/// Sentinel meaning "no node", e.g. a missing transition or an unset suffix link.
const INVALID: NodeID = NodeID::MAX;
14
15/// This structure represents a slice to a string.
16#[derive(Debug, Clone)]
17struct MappedSubstring {
18    /// Unique ID of the string it's slicing, which can be used to locate the string from the tree's string storage.
19    str_id: StrID,
20
21    /// Index of the first character of the slice.
22    start: IndexType,
23
24    /// One past the index of the last character of the slice.
25    /// e.g. when `end` is equal to `start`, this is an empty slice.
26    /// Note that `end` here always represents a meaningful index, unlike in the original algorithm where a slice could potentially be open-ended.
27    /// Such open-endedness allows for online construction of the tree. Here I chose to not support online construction for convenience. It's possible
28    /// to support it by changing `end`'s type to `Option<IndexType>`.
29    end: IndexType,
30}
31
32impl MappedSubstring {
33    fn new(str_id: StrID, start: IndexType, end: IndexType) -> MappedSubstring {
34        MappedSubstring { str_id, start, end }
35    }
36
37    fn is_empty(&self) -> bool {
38        self.start == self.end
39    }
40
41    fn len(&self) -> IndexType {
42        self.end - self.start
43    }
44}
45
46/// This is a node in the tree. `transitions` represents all the possible
47/// transitions from this node to other nodes, indexed by the first character
48/// of the string slice that transition represents. The character needs to
49/// be encoded to an index between 0..MAX_CHAR_COUNT first.
50/// `suffix_link` contains the suffix link of this node (a term used in the
51/// context of Ukkonen's algorithm).
52/// `substr` stores the slice of the string that the transition from the parent
53/// node represents. By doing so we avoid having an explicit edge data type.
54#[derive(Debug)]
55struct Node {
56    transitions: HashMap<CharType, NodeID>,
57
58    suffix_link: NodeID,
59
60    /// The slice of the string this node represents.
61    substr: MappedSubstring,
62}
63
64impl Node {
65    fn new(str_id: StrID, start: IndexType, end: IndexType) -> Node {
66        Node {
67            transitions: HashMap::new(),
68            suffix_link: INVALID,
69            substr: MappedSubstring::new(str_id, start, end),
70        }
71    }
72
73    fn get_suffix_link(&self) -> NodeID {
74        assert!(self.suffix_link != INVALID, "Invalid suffix link");
75        self.suffix_link
76    }
77}
78
79/// A data structure used to store the current state during the Ukkonen's algorithm.
80struct ReferencePoint {
81    /// The active node.
82    node: NodeID,
83
84    /// The current string we are processing.
85    str_id: StrID,
86
87    /// The active point.
88    index: IndexType,
89}
90
91impl ReferencePoint {
92    fn new(node: NodeID, str_id: StrID, index: IndexType) -> ReferencePoint {
93        ReferencePoint {
94            node,
95            str_id,
96            index,
97        }
98    }
99}
100
101/// This is the generalized suffix tree, implemented using Ukkonen's Algorithm.
102/// One important modification to the algorithm is that this is no longer an online
103/// algorithm, i.e. it only accepts strings fully provided to the suffix tree, instead
104/// of being able to stream processing each string. It is not a fundamental limitation and can be supported.
105///
106/// # Examples
107///
108/// ```
109/// use generalized_suffix_tree::GeneralizedSuffixTree;
110/// let mut tree = GeneralizedSuffixTree::new();
111/// tree.add_string(String::from("ABCDABCE"), '$');
112/// tree.add_string(String::from("CDEFDEFG"), '#');
113/// println!("{}", tree.is_suffix("BCE"));
114/// ```
115#[derive(Debug)]
116pub struct GeneralizedSuffixTree {
117    node_storage: Vec<Node>,
118    str_storage: Vec<String>,
119}
120
121impl GeneralizedSuffixTree {
122    pub fn new() -> GeneralizedSuffixTree {
123        // Set the slice of root to be [0, 1) to allow it consume one character whenever we are transitioning from sink to root.
124        // No other node will ever transition to root so this won't affect anything else.
125        let mut root = Node::new(0, 0, 1);
126        let mut sink = Node::new(0, 0, 0);
127
128        root.suffix_link = SINK;
129        sink.suffix_link = ROOT;
130
131        let node_storage: Vec<Node> = vec![root, sink];
132        GeneralizedSuffixTree {
133            node_storage,
134            str_storage: vec![],
135        }
136    }
137
138    /// Add a new string to the generalized suffix tree.
139    pub fn add_string(&mut self, mut s: String, term: char) {
140        self.validate_string(&s, term);
141
142        let str_id = self.str_storage.len() as StrID;
143
144        // Add a unique terminator character to the end of the string.
145        s.push(term);
146
147        self.str_storage.push(s);
148        self.process_suffixes(str_id);
149    }
150
151    fn validate_string(&self, s: &String, term: char) {
152        assert!(term.is_ascii(), "Only accept ASCII terminator");
153        assert!(
154            !s.contains(term),
155            "String should not contain terminator character"
156        );
157        for existing_str in &self.str_storage {
158            assert!(
159                !existing_str.contains(term),
160                "Any existing string should not contain terminator character"
161            );
162        }
163    }
164
165    /// Find the longest common substring among all strings in the suffix.
166    /// This function can be used when you already have a suffix tree built,
167    /// and would need to know the longest commmon substring.
168    /// It can be trivially extended to support longest common substring among
169    /// `K` strings.
170    pub fn longest_common_substring_all(&self) -> String {
171        let mut disjoint_set = disjoint_set::DisjointSet::new(self.node_storage.len());
172
173        // prev_node stores the most recent occurance of a leaf that belongs to each string.
174        // We use the terminator character (which uniquely represents a string) as the key.
175        let mut prev_node: HashMap<CharType, NodeID> = HashMap::new();
176
177        // lca_cnt[v] means the total number of times that the lca of two nodes is node v.
178        let mut lca_cnt: Vec<usize> = vec![0; self.node_storage.len()];
179
180        let mut longest_str: (Vec<&MappedSubstring>, IndexType) = (vec![], 0);
181        let mut cur_str: (Vec<&MappedSubstring>, IndexType) = (vec![], 0);
182        self.longest_common_substring_all_rec(
183            &mut disjoint_set,
184            &mut prev_node,
185            &mut lca_cnt,
186            ROOT,
187            &mut longest_str,
188            &mut cur_str,
189        );
190
191        let mut result = String::new();
192        for s in longest_str.0 {
193            result.push_str(&self.get_string_slice_short(&s));
194        }
195        result
196    }
197
198    /// A recursive DFS that does a couple of things in one run:
199    /// - Obtain the each pair of leaves that belong to the same string and are
200    ///   consecutive in DFS visits. (stored in `prev_node`)
201    /// - Tarjan's Algorithm to compute the least common ancestor for each
202    ///   of the above pairs. (information stored in `disjoint_set`)
203    /// - Maintain the number of times an LCA lands on each node. (`lca_cnt`)
204    /// This function returns a tuple:
205    /// - Total number of leaves in the subtree.
206    /// - Sum of all LCA counts from each node in the subtree,
207    /// including the node itself.
208    /// These two numbers can be used to compute the number of unique strings
209    /// occured in the subtree, which can be used to check whether we found
210    /// a common substring.
211    /// Details of the algorithm can be found here:
212    /// https://web.cs.ucdavis.edu/~gusfield/cs224f09/commonsubstrings.pdf
213    fn longest_common_substring_all_rec<'a>(
214        &'a self,
215        disjoint_set: &mut disjoint_set::DisjointSet,
216        prev_node: &mut HashMap<CharType, NodeID>,
217        lca_cnt: &mut Vec<usize>,
218        node: NodeID,
219        longest_str: &mut (Vec<&'a MappedSubstring>, IndexType),
220        cur_str: &mut (Vec<&'a MappedSubstring>, IndexType),
221    ) -> (usize, usize) {
222        let mut total_leaf = 0;
223        let mut total_correction = 0;
224        for target_node in self.get_node(node).transitions.values() {
225            if *target_node == INVALID {
226                continue;
227            }
228            let slice = &self.get_node(*target_node).substr;
229            if slice.end as usize == self.get_string(slice.str_id).len() {
230                // target_node is a leaf node.
231                total_leaf += 1;
232                let last_ch = self.get_char(slice.str_id, slice.end - 1);
233                if let Some(prev) = prev_node.get(&last_ch) {
234                    let lca = disjoint_set.find_set(*prev as usize);
235                    lca_cnt[lca as usize] += 1;
236                }
237                prev_node.insert(last_ch, *target_node);
238            } else {
239                cur_str.0.push(slice);
240                cur_str.1 += slice.len();
241                let result = self.longest_common_substring_all_rec(
242                    disjoint_set,
243                    prev_node,
244                    lca_cnt,
245                    *target_node,
246                    longest_str,
247                    cur_str,
248                );
249                total_leaf += result.0;
250                total_correction += result.1;
251
252                cur_str.0.pop();
253                cur_str.1 -= slice.len();
254            }
255
256            disjoint_set.union(node as usize, *target_node as usize);
257        }
258        total_correction += lca_cnt[node as usize];
259        let unique_str_cnt = total_leaf - total_correction;
260        if unique_str_cnt == self.str_storage.len() {
261            // This node represnets a substring that is common among all strings.
262            if cur_str.1 > longest_str.1 {
263                *longest_str = cur_str.clone();
264            }
265        }
266        (total_leaf, total_correction)
267    }
268
269    /// Find the longest common substring between string `s` and the current suffix.
270    /// This function allows us compute this without adding `s` to the suffix.
271    pub fn longest_common_substring_with<'a>(&self, s: &'a String) -> &'a str {
272        let mut longest_start: IndexType = 0;
273        let mut longest_len: IndexType = 0;
274        let mut cur_start: IndexType = 0;
275        let mut cur_len: IndexType = 0;
276        let mut node: NodeID = ROOT;
277
278        let chars = s.as_bytes();
279        let mut index = 0;
280        let mut active_length = 0;
281        while index < chars.len() {
282            let target_node_id = self.transition(node, chars[index - active_length as usize]);
283            if target_node_id != INVALID {
284                let slice = &self.get_node(target_node_id).substr;
285                while index != chars.len()
286                    && active_length < slice.len()
287                    && self.get_char(slice.str_id, active_length + slice.start) == chars[index]
288                {
289                    index += 1;
290                    active_length += 1;
291                }
292
293                let final_len = cur_len + active_length;
294                if final_len > longest_len {
295                    longest_len = final_len;
296                    longest_start = cur_start;
297                }
298
299                if index == chars.len() {
300                    break;
301                }
302
303                if active_length == slice.len() {
304                    // We can keep following this route.
305                    node = target_node_id;
306                    cur_len = final_len;
307                    active_length = 0;
308                    continue;
309                }
310            }
311            // There was a mismatch.
312            cur_start += 1;
313            if cur_start > index as IndexType {
314                index += 1;
315                continue;
316            }
317            // We want to follow a different path with one less character from the start.
318            let suffix_link = self.get_node(node).suffix_link;
319            if suffix_link != INVALID && suffix_link != SINK {
320                assert!(cur_len > 0);
321                node = suffix_link;
322                cur_len -= 1;
323            } else {
324                node = ROOT;
325                active_length = active_length + cur_len - 1;
326                cur_len = 0;
327            }
328            while active_length > 0 {
329                assert!(cur_start + cur_len < chars.len() as IndexType);
330                let target_node_id = self.transition(node, chars[(cur_start + cur_len) as usize]);
331                assert!(target_node_id != INVALID);
332                let slice = &self.get_node(target_node_id).substr;
333                if active_length < slice.len() {
334                    break;
335                }
336                active_length -= slice.len();
337                cur_len += slice.len();
338                node = target_node_id;
339            }
340        }
341        &s[longest_start as usize..(longest_start + longest_len) as usize]
342    }
343
344    /// Checks whether a given string `s` is a suffix in the suffix tree.
345    pub fn is_suffix(&self, s: &str) -> bool {
346        self.is_suffix_or_substr(s, false)
347    }
348
349    /// Checks whether a given string `s` is a substring of any of the strings
350    /// in the suffix tree.
351    pub fn is_substr(&self, s: &str) -> bool {
352        self.is_suffix_or_substr(s, true)
353    }
354
355    fn is_suffix_or_substr(&self, s: &str, check_substr: bool) -> bool {
356        for existing_str in &self.str_storage {
357            assert!(
358                !s.contains(existing_str.chars().last().unwrap()),
359                "Queried string cannot contain terminator char"
360            );
361        }
362        let mut node = ROOT;
363        let mut index = 0;
364        let chars = s.as_bytes();
365        while index < s.len() {
366            let target_node = self.transition(node, chars[index]);
367            if target_node == INVALID {
368                return false;
369            }
370            let slice = &self.get_node(target_node).substr;
371            for i in slice.start..slice.end {
372                if index == s.len() {
373                    let is_suffix = i as usize == self.get_string(slice.str_id).len() - 1;
374                    return check_substr || is_suffix;
375                }
376                if chars[index] != self.get_char(slice.str_id, i) {
377                    return false;
378                }
379                index += 1;
380            }
381            node = target_node;
382        }
383        let mut is_suffix = false;
384        for s in &self.str_storage {
385            // The last character of each string is a terminator. We use that
386            // to look up in the current transitions to determine if we have
387            // reached the end of any string. If needed, we are also able to
388            // return which string the queried string is a suffix of.
389            if self.transition(node, *s.as_bytes().last().unwrap()) != INVALID {
390                is_suffix = true;
391                break;
392            }
393        }
394
395        check_substr || is_suffix
396    }
397
398    pub fn pretty_print(&self) {
399        self.print_recursive(ROOT, 0);
400    }
401
402    fn print_recursive(&self, node: NodeID, space_count: u32) {
403        for target_node in self.get_node(node).transitions.values() {
404            if *target_node == INVALID {
405                continue;
406            }
407            for _ in 0..space_count {
408                print!(" ");
409            }
410            let slice = &self.get_node(*target_node).substr;
411            println!(
412                "{}",
413                self.get_string_slice(slice.str_id, slice.start, slice.end),
414            );
415            self.print_recursive(*target_node, space_count + 4);
416        }
417    }
418
419    fn process_suffixes(&mut self, str_id: StrID) {
420        let mut active_point = ReferencePoint::new(ROOT, str_id, 0);
421        for i in 0..self.get_string(str_id).len() {
422            let mut cur_str =
423                MappedSubstring::new(str_id, active_point.index, (i + 1) as IndexType);
424            active_point = self.update(active_point.node, &cur_str);
425            cur_str.start = active_point.index;
426            active_point = self.canonize(active_point.node, &cur_str);
427        }
428    }
429
430    fn update(&mut self, node: NodeID, cur_str: &MappedSubstring) -> ReferencePoint {
431        assert!(!cur_str.is_empty());
432
433        let mut cur_str = cur_str.clone();
434
435        let mut oldr = ROOT;
436
437        let mut split_str = cur_str.clone();
438        split_str.end -= 1;
439
440        let last_ch = self.get_char(cur_str.str_id, cur_str.end - 1);
441
442        let mut active_point = ReferencePoint::new(node, cur_str.str_id, cur_str.start);
443
444        let mut r = node;
445
446        let mut is_endpoint = self.test_and_split(node, &split_str, last_ch, &mut r);
447        while !is_endpoint {
448            let str_len = self.get_string(active_point.str_id).len() as IndexType;
449            let leaf_node =
450                self.create_node_with_slice(active_point.str_id, cur_str.end - 1, str_len);
451            self.set_transition(r, last_ch, leaf_node);
452            if oldr != ROOT {
453                self.get_node_mut(oldr).suffix_link = r;
454            }
455            oldr = r;
456            let suffix_link = self.get_node(active_point.node).get_suffix_link();
457            active_point = self.canonize(suffix_link, &split_str);
458            split_str.start = active_point.index;
459            cur_str.start = active_point.index;
460            is_endpoint = self.test_and_split(active_point.node, &split_str, last_ch, &mut r);
461        }
462        if oldr != ROOT {
463            self.get_node_mut(oldr).suffix_link = active_point.node;
464        }
465        active_point
466    }
467
468    fn test_and_split(
469        &mut self,
470        node: NodeID,
471        split_str: &MappedSubstring,
472        ch: CharType,
473        r: &mut NodeID,
474    ) -> bool {
475        if split_str.is_empty() {
476            *r = node;
477            return self.transition(node, ch) != INVALID;
478        }
479        let first_ch = self.get_char(split_str.str_id, split_str.start);
480
481        let target_node_id = self.transition(node, first_ch);
482        let target_node_slice = self.get_node(target_node_id).substr.clone();
483
484        let split_index = target_node_slice.start + split_str.len();
485        let ref_ch = self.get_char(target_node_slice.str_id, split_index);
486
487        if ref_ch == ch {
488            *r = node;
489            return true;
490        }
491        // Split target_node into two nodes by inserting r in the middle.
492        *r = self.create_node_with_slice(split_str.str_id, split_str.start, split_str.end);
493        self.set_transition(*r, ref_ch, target_node_id);
494        self.set_transition(node, first_ch, *r);
495        self.get_node_mut(target_node_id).substr.start = split_index;
496
497        false
498    }
499
500    fn canonize(&mut self, mut node: NodeID, cur_str: &MappedSubstring) -> ReferencePoint {
501        let mut cur_str = cur_str.clone();
502        loop {
503            if cur_str.is_empty() {
504                return ReferencePoint::new(node, cur_str.str_id, cur_str.start);
505            }
506
507            let ch = self.get_char(cur_str.str_id, cur_str.start);
508
509            let target_node = self.transition(node, ch);
510            if target_node == INVALID {
511                break;
512            }
513            let slice = &self.get_node(target_node).substr;
514            if slice.len() > cur_str.len() {
515                break;
516            }
517            cur_str.start += slice.len();
518            node = target_node;
519        }
520        ReferencePoint::new(node, cur_str.str_id, cur_str.start)
521    }
522
523    fn create_node_with_slice(
524        &mut self,
525        str_id: StrID,
526        start: IndexType,
527        end: IndexType,
528    ) -> NodeID {
529        let node = Node::new(str_id, start, end);
530        self.node_storage.push(node);
531
532        (self.node_storage.len() - 1) as NodeID
533    }
534
535    fn get_node(&self, node_id: NodeID) -> &Node {
536        &self.node_storage[node_id as usize]
537    }
538
539    fn get_node_mut(&mut self, node_id: NodeID) -> &mut Node {
540        &mut self.node_storage[node_id as usize]
541    }
542
543    fn get_string(&self, str_id: StrID) -> &String {
544        &self.str_storage[str_id as usize]
545    }
546
547    fn get_string_slice(&self, str_id: StrID, start: IndexType, end: IndexType) -> &str {
548        &self.get_string(str_id)[start as usize..end as usize]
549    }
550
551    fn get_string_slice_short(&self, slice: &MappedSubstring) -> &str {
552        &self.get_string_slice(slice.str_id, slice.start, slice.end)
553    }
554
555    fn transition(&self, node: NodeID, ch: CharType) -> NodeID {
556        if node == SINK {
557            // SINK always transition to ROOT.
558            return ROOT;
559        }
560        match self.get_node(node).transitions.get(&ch) {
561            None => INVALID,
562            Some(x) => *x,
563        }
564    }
565
566    fn set_transition(&mut self, node: NodeID, ch: CharType, target_node: NodeID) {
567        self.get_node_mut(node).transitions.insert(ch, target_node);
568    }
569
570    fn get_char(&self, str_id: StrID, index: IndexType) -> u8 {
571        assert!((index as usize) < self.get_string(str_id).len());
572        self.get_string(str_id).as_bytes()[index as usize]
573    }
574}