use fid::BitArray;
use std::fmt::Debug;
use crate::converter::{Converter, IdConverter};
/// Builds a histogram of the converted symbols of `text`.
///
/// The returned vector has `converter.len()` entries; entry `k` is the number
/// of elements of `text` whose converted value equals `k`.
///
/// # Panics
///
/// Panics if a converted symbol is not smaller than `converter.len()`.
pub fn count_chars<T, C, K>(text: K, converter: &C) -> Vec<u64>
where
    T: Copy + Clone + Into<u64>,
    K: AsRef<[T]>,
    C: Converter<T>,
{
    let alphabet_size = converter.len() as usize;
    let mut histogram = vec![0u64; alphabet_size];
    text.as_ref().iter().for_each(|&symbol| {
        let code: u64 = converter.convert(symbol).into();
        histogram[code as usize] += 1;
    });
    histogram
}
/// Returns, for every bucket, the index at which it starts.
///
/// This is the exclusive prefix sum of `occs`: entry `k` is the sum of
/// `occs[..k]`, so the first entry is always `0`.
pub fn get_bucket_start_pos(occs: &[u64]) -> Vec<u64> {
    occs.iter()
        .scan(0u64, |running_total, &count| {
            // Emit the total seen so far, then account for this bucket.
            let start = *running_total;
            *running_total += count;
            Some(start)
        })
        .collect()
}
/// Returns, for every bucket, the index one past its last slot.
///
/// This is the inclusive prefix sum of `occs`: entry `k` is the sum of
/// `occs[..=k]`.
pub fn get_bucket_end_pos(occs: &[u64]) -> Vec<u64> {
    occs.iter()
        .scan(0u64, |running_total, &count| {
            // Account for this bucket first, then emit the new total.
            *running_total += count;
            Some(*running_total)
        })
        .collect()
}
/// Classifies every position of `text` as S-type or L-type and collects the
/// LMS positions.
///
/// In the returned bit array a set bit marks an S-type position and a cleared
/// bit an L-type position. The last position is always marked S-type.
///
/// The returned vector lists LMS positions in descending order, starting with
/// `n - 1`. For a text of length 0 or 1 the list is empty.
///
/// NOTE(review): `n - 1` is pushed unconditionally, which presumably relies on
/// the last symbol being a sentinel smaller than its predecessor; confirm
/// against callers.
fn get_types<T, K>(text: K) -> (BitArray, Vec<usize>)
where
    T: Copy + Clone + Ord,
    K: AsRef<[T]>,
{
    let text = text.as_ref();
    let n = text.len();
    let mut types = BitArray::new(n);
    if n == 0 {
        // Nothing to classify; without this guard `n - 1` below would underflow.
        return (types, Vec::new());
    }
    types.set_bit(n - 1, true);
    if n == 1 {
        return (types, vec![]);
    }
    types.set_bit(n - 2, false);

    let mut lms = vec![n - 1];
    // Deliberately starts as `false` even though position `n - 1` is marked
    // S-type: a symbol equal to the last one is therefore classified L-type.
    let mut prev_is_s_type = false;
    for i in (0..(n - 1)).rev() {
        // S-type: strictly smaller than the next symbol, or equal to it while
        // the next position is itself S-type.
        let is_s_type = text[i] < text[i + 1] || (text[i] == text[i + 1] && prev_is_s_type);
        if is_s_type {
            types.set_bit(i, true);
        } else if prev_is_s_type {
            // L-type directly followed by S-type: position `i + 1` is LMS.
            lms.push(i + 1);
        }
        prev_is_s_type = is_s_type;
    }
    (types, lms)
}
/// Tells whether position `i` is an LMS position: an S-type position whose
/// left neighbour is L-type.
///
/// Position `0` and the `u64::max_value()` placeholder are never LMS.
fn is_lms(types: &BitArray, i: u64) -> bool {
    if i == 0 || i == u64::max_value() {
        return false;
    }
    let pos = i as usize;
    types.get_bit(pos) && !types.get_bit(pos - 1)
}
/// Induces the order of the remaining suffixes from the entries already in `sa`.
///
/// Entries equal to `u64::max_value()` are treated as empty slots. A forward
/// pass places the L-type predecessor of every filled entry at the head of its
/// bucket; a backward pass then places S-type predecessors at the tail of
/// their bucket. `occs` must hold the per-symbol counts of `text`.
fn induced_sort<T, K, C>(text: K, converter: &C, types: &BitArray, occs: &[u64], sa: &mut [u64])
where
    T: Into<u64> + Copy + Clone + Ord,
    K: AsRef<[T]>,
    C: Converter<T>,
{
    let symbols = text.as_ref();
    let len = symbols.len();
    let unset = u64::max_value();

    // Forward pass: L-type predecessors go to the next free bucket head.
    let mut heads = get_bucket_start_pos(occs);
    for idx in 0..len {
        let suffix = sa[idx];
        if suffix == 0 || suffix >= unset {
            continue;
        }
        let pred = suffix as usize - 1;
        if types.get_bit(pred) {
            continue;
        }
        let bucket = converter.convert(symbols[pred]).into() as usize;
        let slot = heads[bucket] as usize;
        sa[slot] = suffix - 1;
        heads[bucket] += 1;
    }

    // Backward pass: S-type predecessors go to the next free bucket tail.
    let mut tails = get_bucket_end_pos(occs);
    for idx in (0..len).rev() {
        let suffix = sa[idx];
        if suffix == 0 || suffix == unset {
            continue;
        }
        let pred = suffix as usize - 1;
        if !types.get_bit(pred) {
            continue;
        }
        let bucket = converter.convert(symbols[pred]).into() as usize;
        let slot = tails[bucket] as usize - 1;
        sa[slot] = suffix - 1;
        tails[bucket] -= 1;
    }
}
/// Computes the suffix array of `text`.
///
/// An empty text yields an empty array and a one-element text yields `[0]`.
/// Longer texts are handed to `sais_sub`; in debug builds they are first
/// checked to end with a symbol whose `u64` value is zero.
///
/// # Panics
///
/// In debug builds, panics when `text` has two or more elements and its last
/// element does not convert to `0u64`.
pub fn sais<T, C, K>(text: K, converter: &C) -> Vec<u64>
where
    T: Into<u64> + Copy + Clone + Ord + Debug,
    K: AsRef<[T]>,
    C: Converter<T>,
{
    let len = text.as_ref().len();
    if len == 0 {
        return vec![];
    }
    if len == 1 {
        return vec![0];
    }

    debug_assert!(
        text.as_ref().last().map(|&c| c.into()) == Some(0u64),
        "expected: the last char in text should be zero"
    );

    // Every slot starts as the "empty" placeholder expected by `sais_sub`.
    let mut suffix_array = vec![u64::max_value(); len];
    sais_sub(&text, &mut suffix_array, converter);
    suffix_array
}
/// Recursive core of the suffix array construction.
///
/// Fills the first `text.len()` entries of `sa` with the suffix array of
/// `text`. `sa` must be at least as long as `text` (recursive calls pass a
/// longer slice and use the surplus as scratch space), and its first
/// `text.len()` entries must be `u64::max_value()` on entry, which is the
/// marker used throughout for "empty slot".
///
/// The work proceeds in phases:
/// 1. place the LMS positions at the tails of their buckets and induce-sort;
/// 2. compact the sorted LMS positions to the front of `sa` and give each LMS
///    substring a name, equal substrings sharing a name;
/// 3. order the LMS suffixes, recursing on the name sequence only when some
///    name is shared;
/// 4. place the ordered LMS suffixes back into their buckets and induce-sort
///    once more to obtain the final array.
// The phases share `sa` as working storage, which makes splitting this function awkward.
#[allow(clippy::cognitive_complexity)]
fn sais_sub<T, C, K>(text: K, sa: &mut [u64], converter: &C)
where
    T: Into<u64> + Copy + Clone + Ord + Debug,
    K: AsRef<[T]>,
    C: Converter<T>,
{
    let text = text.as_ref();
    let n = text.len();
    let (types, lms) = get_types(text);
    let lms_len = lms.len();
    let occs = count_chars(text, converter);

    // Phase 1: drop every LMS position at the tail of its bucket, then induce.
    let mut bucket_end_pos = get_bucket_end_pos(&occs);
    for &i in lms.iter().rev() {
        let c = converter.convert(text[i]).into();
        let k = bucket_end_pos[c as usize] as usize - 1;
        sa[k] = i as u64;
        bucket_end_pos[c as usize] = k as u64;
    }
    induced_sort(text, converter, &types, &occs, sa);

    // Phase 2a: gather the LMS positions, in their induced order, at the front.
    let mut k = 0;
    for i in 0..n {
        let p = sa[i];
        if is_lms(&types, p) {
            sa[k] = p;
            k += 1;
            if k == lms_len {
                break;
            }
        }
    }

    // Phase 2b: name the LMS substrings. `name` ends up as the largest name
    // handed out. The name of the LMS at position `p` is stored at `p / 2`
    // inside the region behind the compacted LMS list.
    let mut name = 1;
    {
        let (sa_lms, names) = sa.split_at_mut(lms_len);
        for slot in names.iter_mut() {
            *slot = u64::max_value();
        }
        // The first LMS in sorted order always gets name 0.
        names[sa_lms[0] as usize / 2] = 0;
        if lms_len <= 1 {
            debug_assert!(lms_len != 0);
        } else {
            // The second one always gets name 1.
            names[sa_lms[1] as usize / 2] = 1;
            for i in 2..lms_len {
                let p = sa_lms[i - 1] as usize;
                let q = sa_lms[i] as usize;
                let mut d = 1;
                // Compare the two LMS substrings symbol by symbol and type by type.
                let mut same = text[p] == text[q] && types.get_bit(p) == types.get_bit(q);
                while same {
                    if text[p + d] != text[q + d] || types.get_bit(p + d) != types.get_bit(q + d) {
                        same = false;
                        break;
                    } else if is_lms(&types, (p + d) as u64) && is_lms(&types, (q + d) as u64) {
                        // Both substrings reached their closing LMS position without a mismatch.
                        break;
                    }
                    d += 1;
                }
                if !same {
                    name += 1;
                }
                names[q / 2] = name;
            }
        }
        for s in sa_lms.iter_mut() {
            *s = u64::max_value();
        }
    }

    // Move the names, keeping their text order, to the very end of `sa`.
    let mut i = sa.len() - 1;
    let mut j = 0;
    while j < lms_len {
        if sa[i] < u64::max_value() {
            sa[sa.len() - 1 - j] = sa[i];
            j += 1;
        }
        i -= 1;
    }

    // Phase 3: order the LMS suffixes.
    {
        let (sa1, s1) = sa.split_at_mut(sa.len() - lms_len);
        // Names run from 0 to `name`, so there are `name + 1` distinct names.
        // Only when that is fewer than the number of LMS substrings do two of
        // them share a name, and only then is recursion needed.
        if name + 1 < lms_len as u64 {
            sais_sub(&s1, sa1, &IdConverter::new(name + 1));
        } else {
            // All names are distinct, so the order is the inverse of the name sequence.
            for (i, &s) in s1.iter().enumerate() {
                sa1[s as usize] = i as u64
            }
        }
        // Replace the names by the LMS text positions (ascending) ...
        for (j, i) in lms.into_iter().rev().enumerate() {
            s1[j] = i as u64;
        }
        // ... and translate ranks of the reduced problem into text positions.
        for i in 0..lms_len {
            sa1[i] = s1[sa1[i] as usize];
        }
    }

    // Phase 4: clear everything behind the sorted LMS list, put the LMS
    // suffixes at their bucket tails in sorted order, and induce the rest.
    for i in &mut sa[lms_len..] {
        *i = u64::max_value();
    }
    let mut bucket_end_pos = get_bucket_end_pos(&occs);
    for i in (0..lms_len).rev() {
        let j = sa[i] as usize;
        sa[i] = u64::max_value();
        let c = if j == n {
            0
        } else {
            converter.convert(text[j]).into()
        };
        let k = bucket_end_pos[c as usize] as usize - 1;
        sa[k] = j as u64;
        bucket_end_pos[c as usize] = k as u64;
    }
    induced_sort(text, converter, &types, &occs, sa);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::converter::RangeConverter;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    /// Turns a marker string into the LMS list `get_types` is expected to
    /// return: the byte offsets of every `*`, in descending order. Any other
    /// character is filler used only to align the markers with the text.
    fn marks_to_lms(s: &str) -> Vec<usize> {
        s.as_bytes()
            .iter()
            .enumerate()
            .filter(|(_, &c)| c == b'*')
            .map(|(i, _)| i)
            .rev()
            .collect::<Vec<_>>()
    }

    #[test]
    fn test_get_types() {
        let text = "mmiissiissiippii\0";
        let n = text.len();
        let types_expected = "LLSSLLSSLLSSLLLLS";
        // LMS positions are 2, 6, 10 and 16; `.` is used as filler so the
        // alignment does not depend on runs of spaces being preserved.
        let lms_expected = marks_to_lms("..*...*...*.....*");
        let (types, lms) = get_types(text);
        let types_actual = (0..n)
            .map(|i| if types.get_bit(i) { 'S' } else { 'L' })
            .collect::<String>();
        assert_eq!(types_expected, types_actual);
        assert_eq!(lms_expected, lms);
    }

    #[test]
    fn test_get_types_zeros() {
        let text = "m\0\0a\0";
        let n = text.len();
        let types_expected = "LSSLS".to_string();
        // LMS positions are 1 and 4.
        let lms_expected = marks_to_lms(".*..*");
        let (types, lms) = get_types(text);
        let types_actual = (0..n)
            .map(|i| if types.get_bit(i) { 'S' } else { 'L' })
            .collect::<String>();
        assert_eq!(types_expected, types_actual);
        assert_eq!(lms_expected, lms);
    }

    #[test]
    fn test_get_bucket_start_pos() {
        let text = "mmiissiissiippii\0";
        let converter = RangeConverter::new(b'a', b'z');
        let occs = count_chars(text, &converter);
        let bucket_start_pos = get_bucket_start_pos(&occs);
        let ans = vec![(b'\0', 0), (b'i', 1), (b'm', 9), (b'p', 11), (b's', 13)];
        for (c, expected) in ans {
            let actual = bucket_start_pos[converter.convert(c) as usize];
            assert_eq!(
                actual, expected,
                "bucket_start_pos['{}'] should be {} but {}",
                c as char, expected, actual
            );
        }
    }

    #[test]
    fn test_get_bucket_end_pos() {
        let text = "mmiissiissiippii\0";
        let ans = vec![(b'\0', 1), (b'i', 9), (b'm', 11), (b'p', 13), (b's', 17)];
        let converter = RangeConverter::new(b'a', b'z');
        let occs = count_chars(text, &converter);
        let bucket_end_pos = get_bucket_end_pos(&occs);
        for (c, expected) in ans {
            let actual = bucket_end_pos[converter.convert(c) as usize];
            assert_eq!(
                actual, expected,
                "bucket_end_pos['{}'] should be {} but {}",
                c as char, expected, actual
            );
        }
    }

    // Relies on the `debug_assert!` in `sais`, so it only panics in builds
    // with debug assertions enabled.
    #[test]
    #[should_panic(expected = "expected:")]
    fn test_sais_no_trailing_zero() {
        let text = "nozero".to_string().into_bytes();
        let converter = RangeConverter::new(b'a', b'z');
        sais(&text, &converter);
    }

    #[test]
    fn test_sais_1() {
        let text = &[0u8];
        let sa = sais(&text, &IdConverter::new(4));
        let expected = get_suffix_array(text);
        assert_eq!(sa, expected);
    }

    #[test]
    fn test_sais_2() {
        let text = &[3u8, 0];
        let sa = sais(&text, &IdConverter::new(4));
        let expected = get_suffix_array(text);
        assert_eq!(sa, expected);
    }

    #[test]
    fn test_sais_4() {
        let text = &[3u8, 2, 1, 0];
        let sa = sais(&text, &IdConverter::new(4));
        let expected = get_suffix_array(text);
        assert_eq!(sa, expected);
    }

    #[test]
    fn test_sais_with_nulls() {
        let text = b"mm\0ii\0s\0sii\0ssii\0ppii\0".to_vec();
        let sa = sais(&text, &RangeConverter::new(b'a', b'z'));
        let expected = get_suffix_array(text);
        assert_eq!(sa, expected);
    }

    #[test]
    #[ignore]
    fn test_sais_with_consecutive_nulls() {
        let text = b"mm\0\0ii\0s\0\0\0sii\0ssii\0ppii\0".to_vec();
        let sa = sais(&text, &RangeConverter::new(b'a', b'z'));
        let expected = get_suffix_array(text);
        assert_eq!(sa, expected);
    }

    #[test]
    fn test_small() {
        let mut text = "mmiissiissiippii".to_string().into_bytes();
        text.push(0);
        let converter = RangeConverter::new(b'a', b'z');
        let sa = sais(&text, &converter);
        let ans = get_suffix_array(text);
        assert_eq!(sa.len(), ans.len());
        for (i, (actual, expected)) in sa.into_iter().zip(ans.into_iter()).enumerate() {
            assert_eq!(
                actual, expected,
                "wrong at {}-th pos: expected {}, but actual {}",
                i, expected, actual
            );
        }
    }

    #[test]
    fn test_sais_rand() {
        let len = 100_000;
        let prob = 1.0 / 4.0;
        // Fixed seed keeps the generated text reproducible.
        let mut rng: StdRng = SeedableRng::from_seed([0; 32]);
        let mut text = (0..len)
            .map(|_| if rng.gen_bool(prob) { b'a' } else { b'b' })
            .collect::<Vec<_>>();
        text.push(0);
        let converter = RangeConverter::new(b'a', b'b');
        let sa = sais(&text, &converter);
        let ans = get_suffix_array(&text);
        assert_eq!(sa.len(), ans.len());
        for (i, (actual, expected)) in sa.into_iter().zip(ans.into_iter()).enumerate() {
            assert_eq!(
                actual, expected,
                "wrong at {}-th pos: expected {}, but actual {}",
                i, expected, actual
            );
        }
    }

    /// Reference implementation: sorts the suffix start positions by
    /// comparing the suffix slices directly.
    fn get_suffix_array<K: AsRef<[T]>, T: Copy + Clone + Ord>(text: K) -> Vec<u64> {
        let text = text.as_ref();
        let n = text.len();
        let suffixes = (0..n).map(|i| &text[i..n]).collect::<Vec<_>>();
        let mut sa = (0..(suffixes.len() as u64)).collect::<Vec<_>>();
        sa.sort_by_key(|i| suffixes[*i as usize]);
        sa
    }
}