/// Outcome of a pattern search: the matching start positions, or an error message.
pub type SearchResult = Result<Vec<usize>, String>;

/// Suffix array over the characters of a text, together with its inverse
/// permutation (rank array) and its LCP array.
///
/// Every position stored or returned by this type is an index into the text's
/// sequence of `char`s, not a byte offset.
#[derive(Debug)]
pub struct SuffixArray {
    /// The indexed text.
    text: String,
    /// Start positions of all suffixes, in ascending lexicographic order.
    array: Vec<usize>,
    /// Inverse of `array`: `rank[p]` is the index in `array` of the suffix starting at `p`.
    rank: Vec<usize>,
    /// `lcp[i]` is the length of the common prefix of the suffixes at `array[i - 1]`
    /// and `array[i]`; `lcp[0]` is 0.
    lcp: Vec<usize>,
}

impl SuffixArray {
    /// Builds the suffix array, rank array and LCP array of `text` by prefix doubling.
    ///
    /// Characters are ordered by code point, and a suffix that is a proper prefix
    /// of another suffix sorts before it. Empty and one-character texts are valid
    /// inputs.
    pub fn new(text: &str) -> Self {
        let text = text.to_string();
        let chars: Vec<char> = text.chars().collect();
        let n = chars.len();
        let mut array: Vec<usize> = (0..n).collect();
        // Round zero: a suffix is ranked by the code point of its first character.
        let mut rank: Vec<usize> = chars.iter().map(|&ch| ch as usize).collect();
        let mut next_rank = vec![0; n];

        let mut k = 1;
        while k < n {
            // Sort key: rank of the first `k` characters, then rank of the next `k`.
            // `None` orders below every `Some(_)`, so "ran off the end of the text"
            // can never be confused with a real rank of 0 (or with a NUL character).
            let key = |start: usize| (rank[start], rank.get(start + k).copied());
            array.sort_unstable_by_key(|&start| key(start));

            // Dense re-ranking: equal keys share a rank, each new key gets the next one.
            next_rank[array[0]] = 0;
            for pair in array.windows(2) {
                let (prev, curr) = (pair[0], pair[1]);
                next_rank[curr] = next_rank[prev] + usize::from(key(prev) != key(curr));
            }
            // Ranks are dense, so the largest one equals `n - 1` exactly when all
            // suffixes are already told apart.
            let all_distinct = next_rank[array[n - 1]] == n - 1;
            rank.copy_from_slice(&next_rank);
            if all_distinct {
                break;
            }
            k *= 2;
        }

        // Rebuild `rank` as the exact inverse of `array`. This also covers texts of
        // length 1, for which the loop above never runs and `rank` would otherwise
        // still hold a raw code point.
        for (position, &start) in array.iter().enumerate() {
            rank[start] = position;
        }

        let lcp = Self::compute_lcp_array(&chars, &array, &rank);
        Self {
            text,
            array,
            rank,
            lcp,
        }
    }

    /// Computes the LCP array from the text, its suffix array and the inverse
    /// permutation `rank`, walking the suffixes in text order and reusing the
    /// previous match length minus one as the starting point (Kasai's method).
    fn compute_lcp_array(chars: &[char], suffix_array: &[usize], rank: &[usize]) -> Vec<usize> {
        let n = chars.len();
        let mut lcp = vec![0; n];
        let mut h = 0;
        for (i, &r) in rank.iter().enumerate() {
            if r == 0 {
                // The smallest suffix has no predecessor to compare against.
                h = 0;
                continue;
            }
            let j = suffix_array[r - 1];
            while i + h < n && j + h < n && chars[i + h] == chars[j + h] {
                h += 1;
            }
            lcp[r] = h;
            // Dropping the first character shortens the shared prefix by at most one.
            h = h.saturating_sub(1);
        }
        lcp
    }

    /// Returns the suffix array: suffix start positions in sorted suffix order.
    pub fn get_array(&self) -> &[usize] {
        &self.array
    }

    /// Returns the rank array, the inverse permutation of the suffix array.
    pub fn get_rank(&self) -> &[usize] {
        &self.rank
    }

    /// Returns the LCP array aligned with the suffix array.
    pub fn get_lcp(&self) -> &[usize] {
        &self.lcp
    }

    /// Returns every character position at which `pattern` occurs, in ascending order.
    ///
    /// Overlapping occurrences are all reported.
    ///
    /// # Errors
    ///
    /// Returns `Err` when `pattern` is empty.
    pub fn find_all(&self, pattern: &str) -> SearchResult {
        if pattern.is_empty() {
            return Err("Pattern cannot be empty".to_string());
        }
        // Cheap rejection: a pattern with more bytes than the text cannot occur in it.
        if pattern.len() > self.text.len() {
            return Ok(vec![]);
        }
        self.find_bounds(pattern)
    }

    /// Returns the smallest character position at which `pattern` occurs, if any.
    ///
    /// # Errors
    ///
    /// Returns `Err` when `pattern` is empty.
    pub fn find_first(&self, pattern: &str) -> Result<Option<usize>, String> {
        self.find_all(pattern)
            .map(|positions| positions.first().copied())
    }

    /// Locates the contiguous run of suffixes that start with `pattern` by binary
    /// search over the suffix array and returns their start positions, ascending.
    fn find_bounds(&self, pattern: &str) -> Result<Vec<usize>, String> {
        let pattern_chars: Vec<char> = pattern.chars().collect();
        let text_chars: Vec<char> = self.text.chars().collect();
        let n = text_chars.len();
        let m = pattern_chars.len();

        // First suffix whose leading `m` characters are not below the pattern.
        let lower = self.array.partition_point(|&pos| {
            let end = (pos + m).min(n);
            text_chars[pos..end] < pattern_chars[..]
        });
        // From there on, the suffixes that start with the pattern come first.
        let match_count = self.array[lower..]
            .partition_point(|&pos| self.is_pattern_prefix(&pattern_chars, &text_chars[pos..]));

        let mut positions = self.array[lower..lower + match_count].to_vec();
        positions.sort_unstable();
        Ok(positions)
    }

    /// Reports whether `suffix` begins with `pattern`.
    fn is_pattern_prefix(&self, pattern: &[char], suffix: &[char]) -> bool {
        suffix.starts_with(pattern)
    }
}
/// One-shot search: builds a temporary `SuffixArray` over `text` and returns
/// the result of its `find_all` method for `pattern`.
///
/// The index is rebuilt on every call; construct a `SuffixArray` directly when
/// running several searches against the same text.
pub fn find_all(text: &str, pattern: &str) -> SearchResult {
    SuffixArray::new(text).find_all(pattern)
}
/// One-shot search: builds a temporary `SuffixArray` over `text` and returns
/// the result of its `find_first` method for `pattern`.
///
/// The index is rebuilt on every call; construct a `SuffixArray` directly when
/// running several searches against the same text.
pub fn find_first(text: &str, pattern: &str) -> Result<Option<usize>, String> {
    SuffixArray::new(text).find_first(pattern)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Text shared by most of the cases below.
    const BANANA: &str = "banana";

    #[test]
    fn test_basic_suffix_array() {
        let index = SuffixArray::new(BANANA);
        assert_eq!(index.get_array(), &[5, 3, 1, 0, 4, 2]);
    }

    #[test]
    fn test_lcp_array() {
        let index = SuffixArray::new(BANANA);
        assert_eq!(index.get_lcp(), &[0, 1, 3, 0, 0, 2]);
    }

    #[test]
    fn test_find_all() {
        let index = SuffixArray::new(BANANA);
        let cases: [(&str, Vec<usize>); 5] = [
            ("ana", vec![1, 3]),
            ("na", vec![2, 4]),
            ("a", vec![1, 3, 5]),
            ("ban", vec![0]),
            ("xyz", Vec::new()),
        ];
        for (pattern, expected) in cases.iter() {
            assert_eq!(index.find_all(pattern).unwrap(), *expected);
        }
    }

    #[test]
    fn test_find_first() {
        let index = SuffixArray::new(BANANA);
        let cases: [(&str, Option<usize>); 5] = [
            ("ana", Some(1)),
            ("na", Some(2)),
            ("a", Some(1)),
            ("ban", Some(0)),
            ("xyz", None),
        ];
        for (pattern, expected) in cases.iter() {
            assert_eq!(index.find_first(pattern).unwrap(), *expected);
        }
    }

    #[test]
    fn test_empty_pattern() {
        let index = SuffixArray::new(BANANA);
        assert!(index.find_all("").is_err());
        assert!(index.find_first("").is_err());
    }

    #[test]
    fn test_pattern_longer_than_text() {
        let index = SuffixArray::new("abc");
        let no_hits: Vec<usize> = Vec::new();
        assert_eq!(index.find_all("abcd").unwrap(), no_hits);
        assert_eq!(index.find_first("abcd").unwrap(), None);
    }

    #[test]
    fn test_unicode_text() {
        // Expected positions are character indices, not byte offsets.
        let index = SuffixArray::new("こんにちは世界");
        let cases: [(&str, usize); 3] = [("にち", 2), ("世界", 5), ("ちは", 3)];
        for (pattern, position) in cases.iter() {
            assert_eq!(index.find_all(pattern).unwrap(), vec![*position]);
        }
    }

    #[test]
    fn test_overlapping_patterns() {
        let index = SuffixArray::new("aaaaa");
        assert_eq!(index.find_all("aa").unwrap(), vec![0, 1, 2, 3]);
        assert_eq!(index.find_all("aaa").unwrap(), vec![0, 1, 2]);
    }

    #[test]
    fn test_long_text() {
        let mut text = "a".repeat(10000);
        text.push('b');
        let index = SuffixArray::new(&text);
        assert_eq!(index.find_first("b").unwrap(), Some(10000));
        let hits = index.find_all("aa").unwrap();
        assert_eq!(hits.len(), 9999);
    }

    #[test]
    fn test_module_level_functions() {
        assert_eq!(find_all(BANANA, "ana").unwrap(), vec![1, 3]);
        assert_eq!(find_first(BANANA, "ana").unwrap(), Some(1));
        assert!(find_all(BANANA, "").is_err());
        assert!(find_first(BANANA, "").is_err());
    }

    #[test]
    fn test_repeated_patterns() {
        let index = SuffixArray::new("abababab");
        let cases: [(&str, Vec<usize>); 3] = [
            ("ab", vec![0, 2, 4, 6]),
            ("aba", vec![0, 2, 4]),
            ("abab", vec![0, 2, 4]),
        ];
        for (pattern, expected) in cases.iter() {
            assert_eq!(index.find_all(pattern).unwrap(), *expected);
        }
    }

    #[test]
    fn test_case_sensitivity() {
        let index = SuffixArray::new("bAnAnA");
        let no_hits: Vec<usize> = Vec::new();
        assert_eq!(index.find_all("ana").unwrap(), no_hits);
        assert_eq!(index.find_all("AnA").unwrap(), vec![1, 3]);
    }
}