mod disjoint_set;
use std::collections::HashMap;
/// Index of a node in `GeneralizedSuffixTree::node_storage`.
type NodeID = u32;
/// Index of a string in `GeneralizedSuffixTree::str_storage`.
type StrID = u32;
/// Byte offset into a stored string.
type IndexType = u32;
/// Strings are processed byte by byte.
type CharType = u8;
/// The root node always occupies slot 0 of the node arena.
const ROOT: NodeID = 0;
/// Auxiliary "bottom" node of Ukkonen's construction: every character leads
/// from it to `ROOT`, and it is `ROOT`'s suffix link.
const SINK: NodeID = 1;
/// Sentinel meaning "no node" (missing transition or unset suffix link).
const INVALID: NodeID = u32::MAX;
/// A half-open range `[start, end)` of byte offsets into the stored string
/// `str_id`; edge labels are kept in this form rather than as owned text.
#[derive(Debug, Clone)]
struct MappedSubstring {
    str_id: StrID,
    start: IndexType,
    end: IndexType,
}

impl MappedSubstring {
    /// Builds the range `[start, end)` of string `str_id`.
    fn new(str_id: StrID, start: IndexType, end: IndexType) -> Self {
        Self { str_id, start, end }
    }

    /// `true` when the range covers no bytes (`start == end`).
    fn is_empty(&self) -> bool {
        self.end == self.start
    }

    /// Number of bytes covered (`end - start`).
    fn len(&self) -> IndexType {
        let (first, past_last) = (self.start, self.end);
        past_last - first
    }
}
/// One node of the suffix tree, stored in the tree's node arena.
#[derive(Debug)]
struct Node {
    /// Children, keyed by the first byte of each child's edge label.
    transitions: HashMap<CharType, NodeID>,
    /// Suffix link; `INVALID` until one has been assigned.
    suffix_link: NodeID,
    /// Label of the edge entering this node, as a range of a stored string.
    substr: MappedSubstring,
}

impl Node {
    /// Creates a childless node with no suffix link whose incoming edge is
    /// labelled by `[start, end)` of string `str_id`.
    fn new(str_id: StrID, start: IndexType, end: IndexType) -> Self {
        let substr = MappedSubstring::new(str_id, start, end);
        Self {
            substr,
            suffix_link: INVALID,
            transitions: HashMap::new(),
        }
    }

    /// Returns the suffix link.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid suffix link" if no link has been assigned yet.
    fn get_suffix_link(&self) -> NodeID {
        let link = self.suffix_link;
        assert!(link != INVALID, "Invalid suffix link");
        link
    }
}
/// A position in the tree: the explicit node `node`, plus the remaining
/// label that starts at byte `index` of string `str_id`.
struct ReferencePoint {
    node: NodeID,
    str_id: StrID,
    index: IndexType,
}

impl ReferencePoint {
    /// Bundles the three coordinates into a reference point.
    fn new(node: NodeID, str_id: StrID, index: IndexType) -> Self {
        Self {
            index,
            str_id,
            node,
        }
    }
}
#[derive(Debug)]
pub struct GeneralizedSuffixTree {
node_storage: Vec<Node>,
str_storage: Vec<String>,
}
impl GeneralizedSuffixTree {
pub fn new() -> GeneralizedSuffixTree {
let mut root = Node::new(0, 0, 1);
let mut sink = Node::new(0, 0, 0);
root.suffix_link = SINK;
sink.suffix_link = ROOT;
let node_storage: Vec<Node> = vec![root, sink];
GeneralizedSuffixTree {
node_storage,
str_storage: vec![],
}
}
pub fn add_string(&mut self, mut s: String, term: char) {
self.validate_string(&s, term);
let str_id = self.str_storage.len() as StrID;
s.push(term);
self.str_storage.push(s);
self.process_suffixes(str_id);
}
fn validate_string(&self, s: &String, term: char) {
assert!(term.is_ascii(), "Only accept ASCII terminator");
assert!(
!s.contains(term),
"String should not contain terminator character"
);
for existing_str in &self.str_storage {
assert!(
!existing_str.contains(term),
"Any existing string should not contain terminator character"
);
}
}
pub fn longest_common_substring_all(&self) -> String {
let mut disjoint_set = disjoint_set::DisjointSet::new(self.node_storage.len());
let mut prev_node: HashMap<CharType, NodeID> = HashMap::new();
let mut lca_cnt: Vec<usize> = vec![0; self.node_storage.len()];
let mut longest_str: (Vec<&MappedSubstring>, IndexType) = (vec![], 0);
let mut cur_str: (Vec<&MappedSubstring>, IndexType) = (vec![], 0);
self.longest_common_substring_all_rec(
&mut disjoint_set,
&mut prev_node,
&mut lca_cnt,
ROOT,
&mut longest_str,
&mut cur_str,
);
let mut result = String::new();
for s in longest_str.0 {
result.push_str(&self.get_string_slice_short(&s));
}
result
}
fn longest_common_substring_all_rec<'a>(
&'a self,
disjoint_set: &mut disjoint_set::DisjointSet,
prev_node: &mut HashMap<CharType, NodeID>,
lca_cnt: &mut Vec<usize>,
node: NodeID,
longest_str: &mut (Vec<&'a MappedSubstring>, IndexType),
cur_str: &mut (Vec<&'a MappedSubstring>, IndexType),
) -> (usize, usize) {
let mut total_leaf = 0;
let mut total_correction = 0;
for target_node in self.get_node(node).transitions.values() {
if *target_node == INVALID {
continue;
}
let slice = &self.get_node(*target_node).substr;
if slice.end as usize == self.get_string(slice.str_id).len() {
total_leaf += 1;
let last_ch = self.get_char(slice.str_id, slice.end - 1);
if let Some(prev) = prev_node.get(&last_ch) {
let lca = disjoint_set.find_set(*prev as usize);
lca_cnt[lca as usize] += 1;
}
prev_node.insert(last_ch, *target_node);
} else {
cur_str.0.push(slice);
cur_str.1 += slice.len();
let result = self.longest_common_substring_all_rec(
disjoint_set,
prev_node,
lca_cnt,
*target_node,
longest_str,
cur_str,
);
total_leaf += result.0;
total_correction += result.1;
cur_str.0.pop();
cur_str.1 -= slice.len();
}
disjoint_set.union(node as usize, *target_node as usize);
}
total_correction += lca_cnt[node as usize];
let unique_str_cnt = total_leaf - total_correction;
if unique_str_cnt == self.str_storage.len() {
if cur_str.1 > longest_str.1 {
*longest_str = cur_str.clone();
}
}
(total_leaf, total_correction)
}
pub fn longest_common_substring_with<'a>(&self, s: &'a String) -> &'a str {
let mut longest_start: IndexType = 0;
let mut longest_len: IndexType = 0;
let mut cur_start: IndexType = 0;
let mut cur_len: IndexType = 0;
let mut node: NodeID = ROOT;
let chars = s.as_bytes();
let mut index = 0;
let mut active_length = 0;
while index < chars.len() {
let target_node_id = self.transition(node, chars[index - active_length as usize]);
if target_node_id != INVALID {
let slice = &self.get_node(target_node_id).substr;
while index != chars.len()
&& active_length < slice.len()
&& self.get_char(slice.str_id, active_length + slice.start) == chars[index]
{
index += 1;
active_length += 1;
}
let final_len = cur_len + active_length;
if final_len > longest_len {
longest_len = final_len;
longest_start = cur_start;
}
if index == chars.len() {
break;
}
if active_length == slice.len() {
node = target_node_id;
cur_len = final_len;
active_length = 0;
continue;
}
}
cur_start += 1;
if cur_start > index as IndexType {
index += 1;
continue;
}
let suffix_link = self.get_node(node).suffix_link;
if suffix_link != INVALID && suffix_link != SINK {
assert!(cur_len > 0);
node = suffix_link;
cur_len -= 1;
} else {
node = ROOT;
active_length = active_length + cur_len - 1;
cur_len = 0;
}
while active_length > 0 {
assert!(cur_start + cur_len < chars.len() as IndexType);
let target_node_id = self.transition(node, chars[(cur_start + cur_len) as usize]);
assert!(target_node_id != INVALID);
let slice = &self.get_node(target_node_id).substr;
if active_length < slice.len() {
break;
}
active_length -= slice.len();
cur_len += slice.len();
node = target_node_id;
}
}
&s[longest_start as usize..(longest_start + longest_len) as usize]
}
pub fn is_suffix(&self, s: &str) -> bool {
self.is_suffix_or_substr(s, false)
}
pub fn is_substr(&self, s: &str) -> bool {
self.is_suffix_or_substr(s, true)
}
fn is_suffix_or_substr(&self, s: &str, check_substr: bool) -> bool {
for existing_str in &self.str_storage {
assert!(
!s.contains(existing_str.chars().last().unwrap()),
"Queried string cannot contain terminator char"
);
}
let mut node = ROOT;
let mut index = 0;
let chars = s.as_bytes();
while index < s.len() {
let target_node = self.transition(node, chars[index]);
if target_node == INVALID {
return false;
}
let slice = &self.get_node(target_node).substr;
for i in slice.start..slice.end {
if index == s.len() {
let is_suffix = i as usize == self.get_string(slice.str_id).len() - 1;
return check_substr || is_suffix;
}
if chars[index] != self.get_char(slice.str_id, i) {
return false;
}
index += 1;
}
node = target_node;
}
let mut is_suffix = false;
for s in &self.str_storage {
if self.transition(node, *s.as_bytes().last().unwrap()) != INVALID {
is_suffix = true;
break;
}
}
check_substr || is_suffix
}
pub fn pretty_print(&self) {
self.print_recursive(ROOT, 0);
}
fn print_recursive(&self, node: NodeID, space_count: u32) {
for target_node in self.get_node(node).transitions.values() {
if *target_node == INVALID {
continue;
}
for _ in 0..space_count {
print!(" ");
}
let slice = &self.get_node(*target_node).substr;
println!(
"{}",
self.get_string_slice(slice.str_id, slice.start, slice.end),
);
self.print_recursive(*target_node, space_count + 4);
}
}
fn process_suffixes(&mut self, str_id: StrID) {
let mut active_point = ReferencePoint::new(ROOT, str_id, 0);
for i in 0..self.get_string(str_id).len() {
let mut cur_str =
MappedSubstring::new(str_id, active_point.index, (i + 1) as IndexType);
active_point = self.update(active_point.node, &cur_str);
cur_str.start = active_point.index;
active_point = self.canonize(active_point.node, &cur_str);
}
}
fn update(&mut self, node: NodeID, cur_str: &MappedSubstring) -> ReferencePoint {
assert!(!cur_str.is_empty());
let mut cur_str = cur_str.clone();
let mut oldr = ROOT;
let mut split_str = cur_str.clone();
split_str.end -= 1;
let last_ch = self.get_char(cur_str.str_id, cur_str.end - 1);
let mut active_point = ReferencePoint::new(node, cur_str.str_id, cur_str.start);
let mut r = node;
let mut is_endpoint = self.test_and_split(node, &split_str, last_ch, &mut r);
while !is_endpoint {
let str_len = self.get_string(active_point.str_id).len() as IndexType;
let leaf_node =
self.create_node_with_slice(active_point.str_id, cur_str.end - 1, str_len);
self.set_transition(r, last_ch, leaf_node);
if oldr != ROOT {
self.get_node_mut(oldr).suffix_link = r;
}
oldr = r;
let suffix_link = self.get_node(active_point.node).get_suffix_link();
active_point = self.canonize(suffix_link, &split_str);
split_str.start = active_point.index;
cur_str.start = active_point.index;
is_endpoint = self.test_and_split(active_point.node, &split_str, last_ch, &mut r);
}
if oldr != ROOT {
self.get_node_mut(oldr).suffix_link = active_point.node;
}
active_point
}
fn test_and_split(
&mut self,
node: NodeID,
split_str: &MappedSubstring,
ch: CharType,
r: &mut NodeID,
) -> bool {
if split_str.is_empty() {
*r = node;
return self.transition(node, ch) != INVALID;
}
let first_ch = self.get_char(split_str.str_id, split_str.start);
let target_node_id = self.transition(node, first_ch);
let target_node_slice = self.get_node(target_node_id).substr.clone();
let split_index = target_node_slice.start + split_str.len();
let ref_ch = self.get_char(target_node_slice.str_id, split_index);
if ref_ch == ch {
*r = node;
return true;
}
*r = self.create_node_with_slice(split_str.str_id, split_str.start, split_str.end);
self.set_transition(*r, ref_ch, target_node_id);
self.set_transition(node, first_ch, *r);
self.get_node_mut(target_node_id).substr.start = split_index;
false
}
fn canonize(&mut self, mut node: NodeID, cur_str: &MappedSubstring) -> ReferencePoint {
let mut cur_str = cur_str.clone();
loop {
if cur_str.is_empty() {
return ReferencePoint::new(node, cur_str.str_id, cur_str.start);
}
let ch = self.get_char(cur_str.str_id, cur_str.start);
let target_node = self.transition(node, ch);
if target_node == INVALID {
break;
}
let slice = &self.get_node(target_node).substr;
if slice.len() > cur_str.len() {
break;
}
cur_str.start += slice.len();
node = target_node;
}
ReferencePoint::new(node, cur_str.str_id, cur_str.start)
}
fn create_node_with_slice(
&mut self,
str_id: StrID,
start: IndexType,
end: IndexType,
) -> NodeID {
let node = Node::new(str_id, start, end);
self.node_storage.push(node);
(self.node_storage.len() - 1) as NodeID
}
fn get_node(&self, node_id: NodeID) -> &Node {
&self.node_storage[node_id as usize]
}
fn get_node_mut(&mut self, node_id: NodeID) -> &mut Node {
&mut self.node_storage[node_id as usize]
}
fn get_string(&self, str_id: StrID) -> &String {
&self.str_storage[str_id as usize]
}
fn get_string_slice(&self, str_id: StrID, start: IndexType, end: IndexType) -> &str {
&self.get_string(str_id)[start as usize..end as usize]
}
fn get_string_slice_short(&self, slice: &MappedSubstring) -> &str {
&self.get_string_slice(slice.str_id, slice.start, slice.end)
}
fn transition(&self, node: NodeID, ch: CharType) -> NodeID {
if node == SINK {
return ROOT;
}
match self.get_node(node).transitions.get(&ch) {
None => INVALID,
Some(x) => *x,
}
}
fn set_transition(&mut self, node: NodeID, ch: CharType, target_node: NodeID) {
self.get_node_mut(node).transitions.insert(ch, target_node);
}
fn get_char(&self, str_id: StrID, index: IndexType) -> u8 {
assert!((index as usize) < self.get_string(str_id).len());
self.get_string(str_id).as_bytes()[index as usize]
}
}