use crate::error::{SeqError, SeqResult};
/// Suffix array of a byte string, stored together with an owned copy of that string.
///
/// Instances are created through `SuffixArray::new`, which rejects empty input.
#[derive(Debug, Clone)]
pub struct SuffixArray {
/// Suffix start offsets, as returned by `build_sais` for the input given to `new`.
sa: Vec<usize>,
/// Owned copy of the bytes the array was built from.
text: Vec<u8>,
}
impl SuffixArray {
    /// Builds the suffix array of `s` and keeps an owned copy of the bytes.
    ///
    /// # Errors
    ///
    /// Returns [`SeqError::EmptyInput`] when `s` is empty.
    pub fn new(s: &[u8]) -> SeqResult<Self> {
        if s.is_empty() {
            return Err(SeqError::EmptyInput);
        }
        let sa = build_sais(s);
        Ok(Self {
            sa,
            text: s.to_vec(),
        })
    }

    /// Returns the suffix array: `sa()[i]` is the start offset of the suffix with rank `i`.
    pub fn sa(&self) -> &[usize] {
        &self.sa
    }

    /// Returns the bytes this suffix array was built from.
    pub fn text(&self) -> &[u8] {
        &self.text
    }

    /// Returns the inverse permutation of [`sa`](Self::sa): `rank()[sa()[i]] == i`.
    ///
    /// The vector is recomputed on every call.
    pub fn rank(&self) -> Vec<usize> {
        let mut rank = vec![0usize; self.sa.len()];
        for (i, &p) in self.sa.iter().enumerate() {
            rank[p] = i;
        }
        rank
    }

    /// Returns the LCP array aligned with [`sa`](Self::sa).
    ///
    /// For `i > 0`, `lcp()[i]` is the length of the longest common prefix of the suffixes
    /// starting at `sa()[i - 1]` and `sa()[i]`; `lcp()[0]` is `0`.
    pub fn lcp(&self) -> Vec<usize> {
        let n = self.sa.len();
        let mut lcp = vec![0usize; n];
        if n == 0 {
            return lcp;
        }
        let rank = self.rank();
        // `h` carries the match length over from text position `i` to `i + 1`; it can shrink by
        // at most one per step, which is what keeps the total work linear.
        let mut h = 0usize;
        for (i, &r) in rank.iter().enumerate() {
            if r == 0 {
                // The first suffix in the array has no predecessor to compare with.
                h = 0;
                continue;
            }
            let j = self.sa[r - 1];
            while i + h < n && j + h < n && self.text[i + h] == self.text[j + h] {
                h += 1;
            }
            lcp[r] = h;
            h = h.saturating_sub(1);
        }
        lcp
    }

    /// Returns the number of distinct non-empty substrings of the text.
    ///
    /// Computed as the number of all (start, end) pairs minus the sum of the LCP array.
    pub fn distinct_substring_count(&self) -> usize {
        let n = self.sa.len();
        // n * (n + 1) / 2, halving whichever factor is even *before* multiplying so the
        // intermediate product cannot overflow when the final value still fits in `usize`.
        let total = if n % 2 == 0 {
            (n / 2) * (n + 1)
        } else {
            n * ((n + 1) / 2)
        };
        let lcp_sum: usize = self.lcp().iter().sum();
        total - lcp_sum
    }

    /// Returns the start offsets of every occurrence of `pattern` in the text, in ascending
    /// order.
    ///
    /// An empty `pattern` yields an empty result.
    pub fn search(&self, pattern: &[u8]) -> Vec<usize> {
        if pattern.is_empty() {
            return Vec::new();
        }
        let lo = self.lower_bound(pattern);
        if lo == self.sa.len() {
            return Vec::new();
        }
        let hi = self.upper_bound(pattern);
        // The matching suffixes are contiguous in rank order; callers get text order instead.
        let mut out: Vec<usize> = self.sa[lo..hi].to_vec();
        out.sort_unstable();
        out
    }

    /// Index of the first suffix (in rank order) that is not lexicographically smaller than
    /// `pattern`; `sa.len()` when every suffix is smaller.
    fn lower_bound(&self, pattern: &[u8]) -> usize {
        self.sa
            .partition_point(|&start| self.suffix_lt_pattern(start, pattern))
    }

    /// Index one past the last suffix (in rank order) that is smaller than `pattern` or agrees
    /// with it over their common length.
    fn upper_bound(&self, pattern: &[u8]) -> usize {
        self.sa
            .partition_point(|&start| self.suffix_le_pattern_prefix(start, pattern))
    }

    /// Returns `true` when the suffix starting at `start` is lexicographically smaller than
    /// `pattern` (a proper prefix of `pattern` counts as smaller).
    fn suffix_lt_pattern(&self, start: usize, pattern: &[u8]) -> bool {
        // Byte slices order lexicographically, with a proper prefix sorting first.
        &self.text[start..] < pattern
    }

    /// Returns `true` when, over the first `min(suffix.len(), pattern.len())` bytes, the suffix
    /// starting at `start` is smaller than or equal to `pattern`.
    ///
    /// In particular this is `true` for every suffix that starts with `pattern`.
    fn suffix_le_pattern_prefix(&self, start: usize, pattern: &[u8]) -> bool {
        let suf = &self.text[start..];
        let m = pattern.len().min(suf.len());
        suf[..m] <= pattern[..m]
    }
}
/// Builds the suffix array of the byte string `s` via `sais_core`.
///
/// Every byte is shifted up by one so that symbol `0` can be appended as a terminator; the
/// terminator's own entry (the first element of the core result) is dropped again, so the
/// returned vector has one offset per byte of `s`.
fn build_sais(s: &[u8]) -> Vec<usize> {
    let terminated: Vec<usize> = s
        .iter()
        .map(|&byte| usize::from(byte) + 1)
        .chain(std::iter::once(0))
        .collect();
    // Shifted bytes lie in 1..=256 and the terminator is 0, hence 257 possible symbols.
    sais_core(&terminated, 257).into_iter().skip(1).collect()
}
/// Per-position suffix type flag used by the SA-IS helpers.
///
/// `classify_types` stores `true` where the symbol is smaller than its right neighbour (or
/// equal to it with a `true` neighbour) and for the last position, i.e. S-type; `false`
/// otherwise, i.e. L-type.
type SType = bool;
/// Computes the suffix array of the symbol string `s` by induced sorting.
///
/// `alphabet` is the number of buckets to allocate; every symbol of `s` is used as a bucket
/// index, so each must be smaller than `alphabet`.
///
/// NOTE(review): the two-symbol shortcut returns `[1, 0]` unconditionally, so the caller
/// presumably guarantees that the last symbol is the unique smallest one (`build_sais` appends
/// `0` after symbols that are all `>= 1`) — confirm before reusing this with other inputs.
fn sais_core(s: &[usize], alphabet: usize) -> Vec<usize> {
    let n = s.len();
    match n {
        1 => return vec![0],
        2 => return vec![1, 0],
        _ => {}
    }

    let types = classify_types(s);
    let sizes = bucket_sizes(s, alphabet);
    let lms_positions: Vec<usize> = (1..n).filter(|&i| is_lms(&types, i)).collect();

    // Pass 1: seed the LMS positions in text order and induce, which is enough to sort the
    // LMS substrings so they can be named.
    let mut sa = vec![usize::MAX; n];
    place_lms_at_bucket_ends(&mut sa, s, &sizes, &lms_positions);
    induce_l(&mut sa, s, &types, &sizes);
    induce_s(&mut sa, s, &types, &sizes);

    let (reduced, names_count, lms_order) = name_lms_substrings(&sa, s, &types);
    let lms_sorted: Vec<usize> = if names_count == lms_order.len() {
        // Every name is unique, so a name is already the rank of its LMS suffix.
        let mut by_rank = vec![0usize; lms_order.len()];
        for (&name, &pos) in reduced.iter().zip(&lms_order) {
            by_rank[name] = pos;
        }
        by_rank
    } else {
        // Names repeat: sort the reduced string recursively and map ranks back to positions.
        sais_core(&reduced, names_count + 1)
            .into_iter()
            .map(|r| lms_order[r])
            .collect()
    };

    // Pass 2: start from a clean array, seed the now fully sorted LMS suffixes, induce again.
    sa.fill(usize::MAX);
    place_lms_sorted_at_bucket_ends(&mut sa, s, &sizes, &lms_sorted);
    induce_l(&mut sa, s, &types, &sizes);
    induce_s(&mut sa, s, &types, &sizes);
    sa
}
/// Classifies every position of `s` as S-type (`true`) or L-type (`false`).
///
/// The last position is always S-type. Any other position is S-type when its symbol is smaller
/// than the next one, L-type when it is larger, and inherits the next position's type when the
/// two symbols are equal.
///
/// Panics when `s` is empty.
fn classify_types(s: &[usize]) -> Vec<SType> {
    let n = s.len();
    let mut types = vec![false; n];
    types[n - 1] = true;
    // Walk right to left so the type of the right neighbour is already known.
    for i in (0..n - 1).rev() {
        let (here, next) = (s[i], s[i + 1]);
        types[i] = here < next || (here == next && types[i + 1]);
    }
    types
}
/// Returns `true` when position `i` is S-type (`t[i]`) and its left neighbour is L-type
/// (`!t[i - 1]`). Position `0` never qualifies because it has no left neighbour.
fn is_lms(t: &[SType], i: usize) -> bool {
    match i {
        0 => false,
        _ => t[i] && !t[i - 1],
    }
}
/// Counts how often each symbol occurs in `s`.
///
/// The result has `alphabet` entries; entry `c` is the number of occurrences of symbol `c`.
/// Panics when a symbol is not smaller than `alphabet`.
fn bucket_sizes(s: &[usize], alphabet: usize) -> Vec<usize> {
    let mut counts = vec![0usize; alphabet];
    s.iter().for_each(|&symbol| counts[symbol] += 1);
    counts
}
/// Returns, for every bucket, the index of its first slot: the exclusive prefix sum of `sizes`.
fn bucket_heads(sizes: &[usize]) -> Vec<usize> {
    sizes
        .iter()
        .scan(0usize, |next_free, &size| {
            let head = *next_free;
            *next_free += size;
            Some(head)
        })
        .collect()
}
/// Returns, for every bucket, the index of its last slot: the inclusive prefix sum of `sizes`
/// minus one.
///
/// The subtraction wraps, so empty buckets at the front (running total still `0`) are reported
/// as `usize::MAX`.
fn bucket_tails(sizes: &[usize]) -> Vec<usize> {
    sizes
        .iter()
        .scan(0usize, |end, &size| {
            *end += size;
            Some(end.wrapping_sub(1))
        })
        .collect()
}
/// Writes each position in `lms_positions` into `sa`, filling the bucket of its symbol from the
/// bucket's last slot towards the front.
///
/// Positions are processed in the order given, so within one bucket the first position listed
/// ends up in the rightmost slot.
fn place_lms_at_bucket_ends(
    sa: &mut [usize],
    s: &[usize],
    sizes: &[usize],
    lms_positions: &[usize],
) {
    let mut tails = bucket_tails(sizes);
    lms_positions.iter().for_each(|&pos| {
        let slot = &mut tails[s[pos]];
        sa[*slot] = pos;
        // Wrapping: the cursor of a bucket that starts at index 0 may step below zero once full.
        *slot = slot.wrapping_sub(1);
    });
}
/// Writes the positions of `lms_sorted` into `sa`, filling each symbol's bucket from its last
/// slot towards the front.
///
/// The list is walked back to front, so positions sharing a bucket keep the relative order they
/// have in `lms_sorted`.
fn place_lms_sorted_at_bucket_ends(
    sa: &mut [usize],
    s: &[usize],
    sizes: &[usize],
    lms_sorted: &[usize],
) {
    let mut cursor = bucket_tails(sizes);
    for &pos in lms_sorted.iter().rev() {
        let symbol = s[pos];
        let slot = cursor[symbol];
        sa[slot] = pos;
        // Wrapping: the cursor of a bucket that starts at index 0 may step below zero once full.
        cursor[symbol] = slot.wrapping_sub(1);
    }
}
/// Left-to-right induction pass for L-type suffixes.
///
/// Scans `sa` from the front. For every filled entry `p > 0` whose predecessor `p - 1` is
/// L-type (`!t[p - 1]`), the predecessor is written to the next free slot at the head of its
/// symbol's bucket. Entries equal to `usize::MAX` are treated as empty.
fn induce_l(sa: &mut [usize], s: &[usize], t: &[SType], sizes: &[usize]) {
    let mut heads = bucket_heads(sizes);
    // Index loop on purpose: the pass reads slots that it has itself written earlier.
    for i in 0..s.len() {
        let suffix = sa[i];
        if suffix == usize::MAX || suffix == 0 {
            continue;
        }
        let pred = suffix - 1;
        if t[pred] {
            continue;
        }
        let head = &mut heads[s[pred]];
        sa[*head] = pred;
        *head += 1;
    }
}
/// Right-to-left induction pass for S-type suffixes.
///
/// Scans `sa` from the back. For every filled entry `p > 0` whose predecessor `p - 1` is S-type
/// (`t[p - 1]`), the predecessor is written to the next free slot at the tail of its symbol's
/// bucket. Entries equal to `usize::MAX` are treated as empty.
fn induce_s(sa: &mut [usize], s: &[usize], t: &[SType], sizes: &[usize]) {
    let mut tails = bucket_tails(sizes);
    // Index loop on purpose: the pass reads slots that it has itself written earlier.
    for i in (0..s.len()).rev() {
        let suffix = sa[i];
        let has_predecessor = suffix != usize::MAX && suffix != 0;
        if has_predecessor && t[suffix - 1] {
            let pred = suffix - 1;
            let tail = &mut tails[s[pred]];
            sa[*tail] = pred;
            // Wrapping: the cursor of a bucket starting at index 0 may step below zero.
            *tail = tail.wrapping_sub(1);
        }
    }
}
/// Assigns integer names to the LMS positions found in `sa` and builds the reduced string.
///
/// LMS positions are visited in the order they appear in `sa`; the first gets name `0` and the
/// name is incremented whenever `lms_substrings_equal` reports a difference from the previous
/// one.
///
/// Returns `(reduced, names_count, lms_order)`:
/// * `lms_order` — all LMS positions of `s` in increasing text order,
/// * `reduced` — the name of each entry of `lms_order`, in the same order,
/// * `names_count` — the number of distinct names handed out.
///
/// Panics when `sa` contains no LMS position.
fn name_lms_substrings(sa: &[usize], s: &[usize], t: &[SType]) -> (Vec<usize>, usize, Vec<usize>) {
    let lms_by_rank: Vec<usize> = sa
        .iter()
        .copied()
        .filter(|&p| p != usize::MAX && is_lms(t, p))
        .collect();

    let mut names = vec![usize::MAX; s.len()];
    let mut last_name = 0usize;
    names[lms_by_rank[0]] = last_name;
    for pair in lms_by_rank.windows(2) {
        let (prev, cur) = (pair[0], pair[1]);
        if !lms_substrings_equal(s, t, prev, cur) {
            last_name += 1;
        }
        names[cur] = last_name;
    }

    let lms_order: Vec<usize> = (0..s.len()).filter(|&i| is_lms(t, i)).collect();
    let reduced: Vec<usize> = lms_order.iter().map(|&i| names[i]).collect();
    (reduced, last_name + 1, lms_order)
}
/// Decides whether the LMS substrings starting at `a` and `b` are treated as equal for naming.
///
/// The two are walked in lockstep. The result is `true` when `a == b`, or when both walks
/// arrive at an LMS position at the same non-zero offset after agreeing on symbol and type at
/// every earlier offset. The symbols at those closing LMS positions are not compared. Any
/// mismatch in symbol, type or LMS status, or running past the end of `s`, yields `false`.
fn lms_substrings_equal(s: &[usize], t: &[SType], a: usize, b: usize) -> bool {
    if a == b {
        return true;
    }
    let len = s.len();
    let mut offset = 0usize;
    while a + offset < len && b + offset < len {
        let (x, y) = (a + offset, b + offset);
        let (x_lms, y_lms) = (is_lms(t, x), is_lms(t, y));
        if offset > 0 && x_lms && y_lms {
            // Both substrings end here, and nothing before this point differed.
            return true;
        }
        if x_lms != y_lms || s[x] != s[y] || t[x] != t[y] {
            return false;
        }
        offset += 1;
    }
    // One walk left the string before a common closing LMS position was found.
    false
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Reference suffix array: sort every start offset by the suffix it begins.
    fn brute_force_sa(s: &[u8]) -> Vec<usize> {
        let mut order: Vec<usize> = (0..s.len()).collect();
        order.sort_by_key(|&start| &s[start..]);
        order
    }

    /// Reference LCP array: compare each pair of neighbouring suffixes byte by byte.
    fn brute_force_lcp(s: &[u8], sa: &[usize]) -> Vec<usize> {
        let mut lcp = vec![0usize; sa.len()];
        for (slot, pair) in lcp.iter_mut().skip(1).zip(sa.windows(2)) {
            let (a, b) = (pair[0], pair[1]);
            *slot = (0usize..)
                .take_while(|&k| a + k < s.len() && b + k < s.len() && s[a + k] == s[b + k])
                .count();
        }
        lcp
    }

    /// Reference distinct-substring count: collect every substring into an ordered set.
    fn brute_force_distinct(s: &[u8]) -> usize {
        let n = s.len();
        let substrings: std::collections::BTreeSet<Vec<u8>> = (0..n)
            .flat_map(|i| (i + 1..=n).map(move |j| s[i..j].to_vec()))
            .collect();
        substrings.len()
    }

    /// Reference search: test every window of `t` that has the length of `p`.
    fn naive_search(p: &[u8], t: &[u8]) -> Vec<usize> {
        let (m, n) = (p.len(), t.len());
        if m == 0 || m > n {
            return Vec::new();
        }
        t.windows(m)
            .enumerate()
            .filter(|(_, window)| *window == p)
            .map(|(i, _)| i)
            .collect()
    }

    /// Draws `len` bytes from `alphabet`, one RNG call per byte.
    fn random_bytes(rng: &mut LcgRng, alphabet: &[u8], len: usize) -> Vec<u8> {
        let mut out = Vec::with_capacity(len);
        for _ in 0..len {
            out.push(alphabet[rng.next_usize(alphabet.len())]);
        }
        out
    }

    #[test]
    fn sa_matches_brute_force() {
        let fixed: [&[u8]; 3] = [b"banana", b"mississippi", b"abracadabra"];
        for s in fixed {
            let sa = SuffixArray::new(s).expect("non-empty");
            assert_eq!(sa.sa(), brute_force_sa(s).as_slice(), "SA for {s:?}");
        }

        let mut rng = LcgRng::new(11);
        let alphabets: [&[u8]; 4] = [b"a", b"ab", b"abc", b"abcd"];
        for alphabet in alphabets {
            for _ in 0..400 {
                let len = 1 + rng.next_usize(40);
                let s = random_bytes(&mut rng, alphabet, len);
                let got = SuffixArray::new(&s).expect("non-empty");
                assert_eq!(got.sa(), brute_force_sa(&s).as_slice(), "SA for {s:?}");
            }
        }
    }

    #[test]
    fn sa_is_permutation() {
        let mut rng = LcgRng::new(22);
        for _ in 0..300 {
            let len = 1 + rng.next_usize(40);
            let s = random_bytes(&mut rng, b"abc", len);
            let sa = SuffixArray::new(&s).expect("non-empty");
            let n = s.len();
            assert_eq!(sa.sa().len(), n);

            let mut seen = vec![false; n];
            for &p in sa.sa() {
                assert!(p < n, "index out of range");
                assert!(!seen[p], "duplicate index {p}");
                seen[p] = true;
            }
            assert!(seen.iter().all(|&b| b), "not all indices present");
        }
    }

    #[test]
    fn lcp_matches_brute_force() {
        let fixed: [&[u8]; 3] = [b"banana", b"mississippi", b"aaaa"];
        for s in fixed {
            let sa = SuffixArray::new(s).expect("non-empty");
            let got = sa.lcp();
            let want = brute_force_lcp(s, sa.sa());
            assert_eq!(got, want, "LCP for {s:?}");
        }

        let mut rng = LcgRng::new(33);
        let alphabets: [&[u8]; 2] = [b"ab", b"abc"];
        for alphabet in alphabets {
            for _ in 0..400 {
                let len = 1 + rng.next_usize(40);
                let s = random_bytes(&mut rng, alphabet, len);
                let sa = SuffixArray::new(&s).expect("non-empty");
                assert_eq!(sa.lcp(), brute_force_lcp(&s, sa.sa()), "LCP {s:?}");
            }
        }
    }

    #[test]
    fn repeated_characters() {
        let sa = SuffixArray::new(b"aaaa").expect("non-empty");
        assert_eq!(sa.sa(), &[3, 2, 1, 0]);
        assert_eq!(sa.lcp(), vec![0, 1, 2, 3]);

        let sa = SuffixArray::new(b"aaaaaaaa").expect("non-empty");
        assert_eq!(sa.sa(), brute_force_sa(b"aaaaaaaa").as_slice());
    }

    #[test]
    fn search_matches_naive() {
        let sa = SuffixArray::new(b"banana").expect("non-empty");
        assert_eq!(sa.search(b"ana"), vec![1, 3]);
        assert_eq!(sa.search(b"a"), vec![1, 3, 5]);
        assert_eq!(sa.search(b"banana"), vec![0]);
        assert!(sa.search(b"xyz").is_empty());
        assert!(sa.search(b"").is_empty());

        let mut rng = LcgRng::new(44);
        let alphabets: [&[u8]; 2] = [b"ab", b"abc"];
        for alphabet in alphabets {
            for _ in 0..400 {
                // Draw order (text length, pattern length, text, pattern) fixes the test cases.
                let tlen = 1 + rng.next_usize(40);
                let plen = 1 + rng.next_usize(5);
                let t = random_bytes(&mut rng, alphabet, tlen);
                let p = random_bytes(&mut rng, alphabet, plen);
                let sa = SuffixArray::new(&t).expect("non-empty");
                let mut want = naive_search(&p, &t);
                want.sort_unstable();
                assert_eq!(sa.search(&p), want, "search p={p:?} t={t:?}");
            }
        }
    }

    #[test]
    fn distinct_substring_count_matches_brute_force() {
        let fixed: [&[u8]; 4] = [b"banana", b"mississippi", b"aaaa", b"abcabc"];
        for s in fixed {
            let sa = SuffixArray::new(s).expect("non-empty");
            assert_eq!(
                sa.distinct_substring_count(),
                brute_force_distinct(s),
                "distinct for {s:?}"
            );
        }

        let mut rng = LcgRng::new(55);
        let alphabets: [&[u8]; 2] = [b"ab", b"abc"];
        for alphabet in alphabets {
            for _ in 0..200 {
                let len = 1 + rng.next_usize(24);
                let s = random_bytes(&mut rng, alphabet, len);
                let sa = SuffixArray::new(&s).expect("non-empty");
                assert_eq!(
                    sa.distinct_substring_count(),
                    brute_force_distinct(&s),
                    "distinct for {s:?}"
                );
            }
        }
    }

    #[test]
    fn empty_input_errors() {
        assert!(matches!(SuffixArray::new(b""), Err(SeqError::EmptyInput)));
    }

    #[test]
    fn single_char() {
        let sa = SuffixArray::new(b"x").expect("non-empty");
        assert_eq!(sa.sa(), &[0]);
        assert_eq!(sa.lcp(), vec![0]);
        assert_eq!(sa.search(b"x"), vec![0]);
        assert!(sa.search(b"y").is_empty());
        assert_eq!(sa.distinct_substring_count(), 1);
    }

    #[test]
    fn rank_is_inverse() {
        let mut rng = LcgRng::new(66);
        for _ in 0..200 {
            let len = 1 + rng.next_usize(30);
            let s = random_bytes(&mut rng, b"abc", len);
            let sa = SuffixArray::new(&s).expect("non-empty");
            let rank = sa.rank();
            for (i, &p) in sa.sa().iter().enumerate() {
                assert_eq!(rank[p], i);
            }
        }
    }
}