use crate::error::{SeqError, SeqResult};
use std::collections::BTreeMap;
/// Sentinel stored in `State::link` when a state has no suffix link.
/// The initial state is created with this value; a state added by
/// `SuffixAutomaton::extend` is assigned a real link before `extend` returns.
const NO_LINK: isize = -1;
/// A single node of the suffix automaton.
#[derive(Debug, Clone)]
struct State {
    /// Length value of the state; a fresh state gets the previous `last`
    /// state's `len` plus one, a clone gets its parent's `len` plus one.
    len: usize,
    /// Index of the suffix-link target, or `NO_LINK` when there is none.
    link: isize,
    /// Outgoing transitions, keyed by byte, pointing at state indices.
    next: BTreeMap<u8, usize>,
    /// Occurrence counter: seeded by the automaton builder and aggregated
    /// along suffix links when counts are finalised.
    cnt: usize,
    /// `true` only for states produced by splitting an existing state.
    is_clone: bool,
}

impl State {
    /// Creates a state with the given `len` and `link`, no transitions,
    /// a zero occurrence counter, and the clone flag cleared.
    fn new(len: usize, link: isize) -> Self {
        let next = BTreeMap::new();
        Self {
            len,
            link,
            next,
            cnt: 0,
            is_clone: false,
        }
    }
}
/// Suffix automaton over a byte string.
///
/// Supports substring membership tests, counting distinct substrings and
/// counting occurrences of a pattern in the source it was built from.
#[derive(Debug, Clone)]
pub struct SuffixAutomaton {
    /// All states; indices into this vector serve as state ids.
    states: Vec<State>,
    /// Index of the initial state.
    init: usize,
    /// Index of the state added by the most recent `extend` call.
    last: usize,
    /// Length of the byte string the automaton was built from.
    source_len: usize,
    /// Whether `finalise_counts` has already aggregated the `cnt` fields.
    counts_finalised: bool,
}

impl SuffixAutomaton {
    /// Builds the automaton for `s` by feeding its bytes to `extend` in order.
    pub fn new(s: &[u8]) -> Self {
        let mut states = Vec::with_capacity(2 * s.len().max(1));
        states.push(State::new(0, NO_LINK));
        let mut automaton = Self {
            states,
            init: 0,
            last: 0,
            source_len: 0,
            counts_finalised: false,
        };
        for byte in s.iter().copied() {
            automaton.extend(byte);
        }
        automaton.source_len = s.len();
        automaton
    }

    /// Appends one byte to the automaton, adding one fresh state and, when an
    /// existing state has to be split, one clone.
    fn extend(&mut self, c: u8) {
        let fresh = self.states.len();
        let mut node = State::new(self.states[self.last].len + 1, NO_LINK);
        // A fresh (non-clone) state starts with an occurrence count of one.
        node.cnt = 1;
        self.states.push(node);

        // Walk the suffix-link chain from `last`, adding the missing `c`
        // transitions, until a state that already has one (or the chain ends).
        let mut walker = self.last as isize;
        while walker != NO_LINK {
            let edges = &mut self.states[walker as usize].next;
            if edges.contains_key(&c) {
                break;
            }
            edges.insert(c, fresh);
            walker = self.states[walker as usize].link;
        }

        let suffix_link = if walker == NO_LINK {
            self.init
        } else {
            let parent = walker as usize;
            let target = self.states[parent].next[&c];
            let split_len = self.states[parent].len + 1;
            if split_len == self.states[target].len {
                target
            } else {
                // `target` is too long to serve as the link: clone it with the
                // shorter length and redirect the affected transitions.
                let copy = self.states.len();
                let mut dup = State::new(split_len, self.states[target].link);
                dup.next = self.states[target].next.clone();
                dup.is_clone = true;
                self.states.push(dup);

                while walker != NO_LINK {
                    let idx = walker as usize;
                    if self.states[idx].next.get(&c) != Some(&target) {
                        break;
                    }
                    self.states[idx].next.insert(c, copy);
                    walker = self.states[idx].link;
                }
                self.states[target].link = copy as isize;
                copy
            }
        };

        self.states[fresh].link = suffix_link as isize;
        self.last = fresh;
    }

    /// Aggregates the per-state `cnt` values along suffix links.
    ///
    /// Runs at most once; later calls return immediately.
    fn finalise_counts(&mut self) {
        if self.counts_finalised {
            return;
        }

        // Counting sort of the state indices by `len`.
        let mut slots = vec![0usize; self.source_len + 2];
        for st in &self.states {
            slots[st.len] += 1;
        }
        let mut running = 0usize;
        for slot in slots.iter_mut() {
            running += *slot;
            *slot = running;
        }
        let mut by_len = vec![0usize; self.states.len()];
        for (idx, st) in self.states.iter().enumerate() {
            slots[st.len] -= 1;
            by_len[slots[st.len]] = idx;
        }

        // Longest states first, so each state's total is complete before it
        // is added to its suffix-link target.
        for &child in by_len.iter().rev() {
            let parent = self.states[child].link;
            if parent == NO_LINK {
                continue;
            }
            let inherited = self.states[child].cnt;
            self.states[parent as usize].cnt += inherited;
        }
        self.counts_finalised = true;
    }

    /// Follows `pattern` from the initial state; returns the state reached,
    /// or `None` as soon as a transition is missing.
    fn walk(&self, pattern: &[u8]) -> Option<usize> {
        pattern
            .iter()
            .try_fold(self.init, |at, byte| self.states[at].next.get(byte).copied())
    }

    /// Returns `true` when `pattern` can be read from the initial state,
    /// i.e. it is a substring of the source. The empty pattern always matches.
    pub fn contains(&self, pattern: &[u8]) -> bool {
        self.walk(pattern).is_some()
    }

    /// Number of distinct non-empty substrings of the source, computed as the
    /// sum of `len - len(link)` over every state except the initial one.
    pub fn distinct_substring_count(&self) -> usize {
        self.states
            .iter()
            .enumerate()
            .filter(|&(idx, _)| idx != self.init)
            .map(|(_, st)| {
                let link_len = if st.link == NO_LINK {
                    0
                } else {
                    self.states[st.link as usize].len
                };
                st.len - link_len
            })
            .sum()
    }

    /// Number of occurrences of `pattern` in the source; overlapping matches
    /// are counted separately.
    ///
    /// The empty pattern yields `source_len + 1`. Takes `&mut self` because
    /// the per-state counts are aggregated lazily on the first non-empty query.
    pub fn occurrences(&mut self, pattern: &[u8]) -> usize {
        if pattern.is_empty() {
            return self.source_len + 1;
        }
        self.finalise_counts();
        self.walk(pattern)
            .map_or(0, |state| self.states[state].cnt)
    }

    /// Total number of states, including the initial state and clones.
    pub fn state_count(&self) -> usize {
        self.states.len()
    }

    /// Number of states that were created by splitting an existing state.
    pub fn clone_count(&self) -> usize {
        self.states
            .iter()
            .map(|st| usize::from(st.is_clone))
            .sum()
    }

    /// Length of the byte string the automaton was built from.
    pub fn source_len(&self) -> usize {
        self.source_len
    }
}
/// Finds a longest substring common to `a` and `b`.
///
/// Builds a suffix automaton over `a` and streams `b` through it, tracking
/// the length of the current match.
///
/// Returns `(start, len)` where `b[start..start + len]` is the first longest
/// common substring met while scanning `b`; `len` is `0` (and `start` is `0`)
/// when the inputs share no byte.
///
/// # Errors
///
/// Returns `SeqError::EmptyInput` when either input is empty.
pub fn longest_common_substring(a: &[u8], b: &[u8]) -> SeqResult<(usize, usize)> {
    if a.is_empty() || b.is_empty() {
        return Err(SeqError::EmptyInput);
    }

    let sam = SuffixAutomaton::new(a);
    let mut at = sam.init;
    let mut matched = 0usize;
    // (length, exclusive end position in `b`) of the best match so far.
    let mut best = (0usize, 0usize);

    for (pos, byte) in b.iter().enumerate() {
        // Shrink the current match along suffix links until `byte` can be
        // consumed or the initial state is reached.
        while !sam.states[at].next.contains_key(byte) && sam.states[at].link != NO_LINK {
            at = sam.states[at].link as usize;
            matched = sam.states[at].len;
        }
        match sam.states[at].next.get(byte) {
            Some(&next) => {
                at = next;
                matched += 1;
            }
            // Stuck at the initial state: nothing ending here matches.
            None => matched = 0,
        }
        if matched > best.0 {
            best = (matched, pos + 1);
        }
    }

    let (best_len, best_end) = best;
    Ok((best_end - best_len, best_len))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::handle::LcgRng;
use std::collections::HashSet;
/// Reference implementation: collects every non-empty substring of `s`
/// into a set by enumerating all `(start, end)` pairs.
fn brute_distinct(s: &[u8]) -> HashSet<Vec<u8>> {
    let n = s.len();
    (0..n)
        .flat_map(|start| ((start + 1)..=n).map(move |end| s[start..end].to_vec()))
        .collect()
}
/// Reference implementation: counts the (possibly overlapping) positions at
/// which `pat` occurs in `s`. The empty pattern is counted `s.len() + 1` times.
fn brute_occurrences(s: &[u8], pat: &[u8]) -> usize {
    match pat.len() {
        0 => s.len() + 1,
        m if m > s.len() => 0,
        m => s.windows(m).filter(|window| *window == pat).count(),
    }
}
/// Reference implementation: length of the longest common substring of `a`
/// and `b`, via the classic dynamic programme over common-suffix lengths.
fn brute_lcs_len(a: &[u8], b: &[u8]) -> usize {
    if a.is_empty() || b.is_empty() {
        return 0;
    }
    // `run[j + 1]` holds the length of the common suffix of the prefix of `a`
    // processed so far and `b[..=j]`. Filling right-to-left lets a single row
    // stand in for both the previous and the current DP row.
    let mut run = vec![0usize; b.len() + 1];
    let mut best = 0usize;
    for &x in a {
        for j in (0..b.len()).rev() {
            run[j + 1] = if x == b[j] { run[j] + 1 } else { 0 };
            best = best.max(run[j + 1]);
        }
    }
    best
}
/// Draws `len` bytes from `alphabet`, consuming one `rng.next_usize` call
/// per byte.
fn random_bytes(rng: &mut LcgRng, alphabet: &[u8], len: usize) -> Vec<u8> {
    let mut out = Vec::with_capacity(len);
    for _ in 0..len {
        // NOTE(review): presumably `next_usize(n)` yields a value in `0..n`;
        // anything else would make the indexing below panic.
        let pick = rng.next_usize(alphabet.len());
        out.push(alphabet[pick]);
    }
    out
}
/// The automaton's distinct-substring count must agree with brute-force
/// enumeration on fixed words and on random short strings.
#[test]
fn distinct_count_matches_brute_force() {
    let fixed: [&[u8]; 7] = [
        b"abcbc",
        b"aaaa",
        b"abracadabra",
        b"banana",
        b"mississippi",
        b"a",
        b"ab",
    ];
    for &s in &fixed {
        let got = SuffixAutomaton::new(s).distinct_substring_count();
        let expect = brute_distinct(s).len();
        assert_eq!(got, expect, "distinct count for {s:?}");
    }

    // Randomised check over two small alphabets; lengths may be zero.
    let mut rng = LcgRng::new(99);
    let alphabets: [&[u8]; 2] = [b"ab", b"abc"];
    for &alphabet in &alphabets {
        for _ in 0..300 {
            let len = rng.next_usize(18);
            let s = random_bytes(&mut rng, alphabet, len);
            let got = SuffixAutomaton::new(&s).distinct_substring_count();
            let expect = brute_distinct(&s).len();
            assert_eq!(got, expect, "random distinct count for {s:?}");
        }
    }
}
/// `contains` must accept every substring of the source and agree with the
/// brute-force substring set on arbitrary probes.
#[test]
fn membership_positive_and_negative() {
    let word = b"abracadabra";
    let sam = SuffixAutomaton::new(word);
    let all = brute_distinct(word);

    // Every real substring is accepted.
    for sub in &all {
        assert!(sam.contains(sub), "should contain {sub:?}");
    }

    // Hand-picked probes; expected membership is taken from the brute-force set.
    let probes: [&[u8]; 6] = [b"xyz", b"aa", b"brc", b"cadabrra", b"z", b"abrad"];
    for &bad in &probes {
        let present = all.contains(bad);
        assert_eq!(sam.contains(bad), present, "membership of {bad:?}");
    }

    // The empty pattern is always contained.
    assert!(sam.contains(b""));

    // Random texts over "abc", probed with strings over "abcd" so that some
    // probes use a byte that never occurs in the text.
    let mut rng = LcgRng::new(7);
    for _ in 0..200 {
        let len = 1 + rng.next_usize(12);
        let text = random_bytes(&mut rng, b"abc", len);
        let sam = SuffixAutomaton::new(&text);
        let set = brute_distinct(&text);
        for _ in 0..8 {
            let qlen = 1 + rng.next_usize(5);
            let q = random_bytes(&mut rng, b"abcd", qlen);
            assert_eq!(
                sam.contains(&q),
                set.contains(&q),
                "probe {q:?} against {text:?}"
            );
        }
    }
}
/// `longest_common_substring` must report the same length as the DP
/// reference, and the reported slice of `b` must really occur in `a`.
#[test]
fn lcs_matches_dp() {
    let cases: &[(&[u8], &[u8])] = &[
        (b"abcde", b"cdefg"),
        (b"abcbc", b"bcbcd"),
        (b"banana", b"ananas"),
        (b"xxxx", b"yyyy"),
        (b"hello", b"yellow"),
        (b"a", b"a"),
        (b"a", b"b"),
    ];
    for &(a, b) in cases {
        let expect = brute_lcs_len(a, b);
        let (offset, got_len) = longest_common_substring(a, b).expect("non-empty");
        assert_eq!(got_len, expect, "lcs length for {a:?},{b:?}");

        // The reported range indexes into `b`; verify it against `a`.
        let sub = &b[offset..offset + got_len];
        let sam = SuffixAutomaton::new(a);
        assert!(sam.contains(sub), "lcs slice {sub:?} must occur in {a:?}");
    }

    let mut rng = LcgRng::new(2024);
    for _ in 0..300 {
        let la = 1 + rng.next_usize(14);
        let lb = 1 + rng.next_usize(14);
        let a = random_bytes(&mut rng, b"abc", la);
        let b = random_bytes(&mut rng, b"abc", lb);

        let (offset, got_len) = longest_common_substring(&a, &b).expect("non-empty");
        assert_eq!(got_len, brute_lcs_len(&a, &b), "random lcs {a:?},{b:?}");

        let sub = &b[offset..offset + got_len];
        let sam = SuffixAutomaton::new(&a);
        assert!(sam.contains(sub), "random lcs slice {sub:?} in {a:?}");
    }
}
/// Inputs that force at least one state split must still give correct
/// distinct-substring counts.
#[test]
fn clone_split_is_exercised() {
    let sam = SuffixAutomaton::new(b"abcbc");
    let clones = sam.clone_count();
    assert!(clones >= 1, "abcbc must force a clone");
    // State-count bound of twice the input length (input length is 5).
    assert!(sam.state_count() <= 2 * 5);
    let distinct = sam.distinct_substring_count();
    assert_eq!(distinct, 12);
    assert_eq!(distinct, brute_distinct(b"abcbc").len());

    let sam2 = SuffixAutomaton::new(b"abcbcba");
    assert!(sam2.clone_count() >= 1);
    let expect2 = brute_distinct(b"abcbcba").len();
    assert_eq!(sam2.distinct_substring_count(), expect2);
}
/// `occurrences` must agree with a brute-force scan, including overlapping
/// matches, absent patterns and the empty pattern.
#[test]
fn occurrence_counting_matches_brute_force() {
    let s = b"abracadabra";
    let mut sam = SuffixAutomaton::new(s);
    let patterns: [&[u8]; 8] = [
        b"a", b"ab", b"abra", b"bra", b"ra", b"cad", b"z", b"abrad",
    ];
    for &pat in &patterns {
        let got = sam.occurrences(pat);
        let expect = brute_occurrences(s, pat);
        assert_eq!(got, expect, "occurrences of {pat:?}");
    }

    // Overlapping matches in a unary string.
    let mut sam_a = SuffixAutomaton::new(b"aaaa");
    assert_eq!(sam_a.occurrences(b"aa"), 3);
    assert_eq!(sam_a.occurrences(b"aa"), brute_occurrences(b"aaaa", b"aa"));
    assert_eq!(sam_a.occurrences(b"aaaa"), 1);
    assert_eq!(sam_a.occurrences(b""), 5);

    // Random binary texts probed with patterns that may contain a foreign byte.
    let mut rng = LcgRng::new(555);
    for _ in 0..150 {
        let len = 1 + rng.next_usize(14);
        let text = random_bytes(&mut rng, b"ab", len);
        let mut sam = SuffixAutomaton::new(&text);
        for _ in 0..6 {
            let qlen = 1 + rng.next_usize(4);
            let q = random_bytes(&mut rng, b"abc", qlen);
            let got = sam.occurrences(&q);
            let expect = brute_occurrences(&text, &q);
            assert_eq!(got, expect, "occ {q:?} in {text:?}");
        }
    }
}
/// Degenerate inputs: the empty source, a one-byte source, and empty
/// arguments to `longest_common_substring`.
#[test]
fn edge_cases_empty_and_single() {
    // Empty source: only the initial state exists.
    let sam_empty = SuffixAutomaton::new(b"");
    assert_eq!(sam_empty.state_count(), 1);
    assert_eq!(sam_empty.distinct_substring_count(), 0);
    assert!(sam_empty.contains(b""));
    assert!(!sam_empty.contains(b"a"));
    assert_eq!(sam_empty.source_len(), 0);

    // One-byte source.
    let mut sam_one = SuffixAutomaton::new(b"a");
    assert_eq!(sam_one.distinct_substring_count(), 1);
    assert!(sam_one.contains(b"a"));
    assert!(!sam_one.contains(b"b"));
    assert_eq!(sam_one.occurrences(b"a"), 1);
    assert_eq!(sam_one.occurrences(b"b"), 0);

    // Either argument being empty is rejected.
    let left_empty = longest_common_substring(b"", b"abc");
    assert!(matches!(left_empty, Err(SeqError::EmptyInput)));
    let right_empty = longest_common_substring(b"abc", b"");
    assert!(matches!(right_empty, Err(SeqError::EmptyInput)));
}
}