extern crate utf16_literal;
use crate::par_quicksort::par_sort_unstable_by_key;
use anyhow::Result;
use rand::distributions::{Distribution, WeightedIndex};
use rand::thread_rng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::{fmt, ops::Deref, u64};
/// A suffix table over a sequence of `u16` tokens.
///
/// `text` holds the tokens and `table` holds `u64` start offsets into `text`.
/// `SuffixTable::new` fills `table` with every offset `0..text.len()` ordered by
/// the suffix `text[offset..]`; the query methods binary-search `table` and so
/// rely on that ordering. `from_parts` stores whatever it is given without
/// checking it.
///
/// The type parameters default to owned boxed slices; the query methods are
/// implemented for any `T`/`U` that dereference to `[u16]`/`[u64]`.
#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
pub struct SuffixTable<T = Box<[u16]>, U = Box<[u64]>> {
/// The indexed token sequence.
text: T,
/// Start offsets into `text`, expected to be sorted by the suffix they begin.
table: U,
}
impl SuffixTable<Box<[u16]>, Box<[u64]>> {
    /// Builds the suffix table for `src`.
    ///
    /// Every start offset `0..text.len()` is listed and then ordered by the
    /// suffix of the text beginning at that offset, using the crate's
    /// `par_sort_unstable_by_key`.
    ///
    /// `verbose` is passed straight through to the sort routine; its effect is
    /// defined there (presumably progress reporting — confirm against
    /// `par_quicksort`).
    pub fn new<S>(src: S, verbose: bool) -> Self
    where
        S: Into<Box<[u16]>>,
    {
        let text: Box<[u16]> = src.into();
        let token_count = text.len() as u64;

        // Start from the identity permutation; sorting turns it into the suffix order.
        let mut offsets = (0..token_count).collect::<Vec<u64>>();
        par_sort_unstable_by_key(
            offsets.as_mut_slice(),
            |&offset| &text[offset as usize..],
            verbose,
        );

        Self {
            text,
            table: offsets.into_boxed_slice(),
        }
    }
}
impl<T, U> SuffixTable<T, U>
where
T: Deref<Target = [u16]> + Sync,
U: Deref<Target = [u64]> + Sync,
{
pub fn from_parts(text: T, table: U) -> Self {
SuffixTable { text, table }
}
pub fn into_parts(self) -> (T, U) {
(self.text, self.table)
}
#[inline]
#[allow(dead_code)]
pub fn len(&self) -> usize {
self.table.len()
}
#[inline]
#[allow(dead_code)]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
#[inline]
#[allow(dead_code)]
pub fn suffix(&self, i: usize) -> &[u16] {
&self.text[self.table[i] as usize..]
}
#[allow(dead_code)]
pub fn contains(&self, query: &[u16]) -> bool {
!query.is_empty()
&& self
.table
.binary_search_by(|&sufi| {
self.text[sufi as usize..]
.iter()
.take(query.len())
.cmp(query.iter())
})
.is_ok()
}
#[allow(dead_code)]
pub fn positions(&self, query: &[u16]) -> &[u64] {
if self.text.is_empty()
|| query.is_empty()
|| (query < self.suffix(0) && !self.suffix(0).starts_with(query))
|| query > self.suffix(self.len() - 1)
{
return &[];
}
let start = binary_search(&self.table, |&sufi| query <= &self.text[sufi as usize..]);
let end = start
+ binary_search(&self.table[start..], |&sufi| {
!self.text[sufi as usize..].starts_with(query)
});
if start > end {
&[]
} else {
&self.table[start..end]
}
}
fn boundaries(&self, query: &[u16]) -> (usize, usize) {
if self.text.is_empty()
|| query.is_empty()
|| (query < self.suffix(0) && !self.suffix(0).starts_with(query))
|| query > self.suffix(self.len() - 1)
{
return (0, self.table.len());
}
let start = binary_search(&self.table, |&sufi| query <= &self.text[sufi as usize..]);
let end = start
+ binary_search(&self.table[start..], |&sufi| {
!self.text[sufi as usize..].starts_with(query)
});
(start, end)
}
fn range_positions(&self, query: &[u16], range_start: usize, range_end: usize) -> &[u64] {
if self.text.is_empty()
|| query.is_empty()
|| (query < self.suffix(range_start) && !self.suffix(range_start).starts_with(query))
|| query > self.suffix(std::cmp::max(0, range_end - 1))
{
return &[];
}
let start = binary_search(&self.table[range_start..range_end], |&sufi| {
query <= &self.text[sufi as usize..]
});
let end = start
+ binary_search(&self.table[range_start + start..range_end], |&sufi| {
!self.text[sufi as usize..].starts_with(query)
});
if start > end {
&[]
} else {
&self.table[range_start + start..range_start + end]
}
}
pub fn count_next(&self, query: &[u16], vocab: Option<u16>) -> Vec<usize> {
let mut counts: Vec<usize> = vec![0usize; vocab.unwrap_or(u16::MAX) as usize + 1];
let mut suffixed_query = query.to_vec();
let (range_start, range_end) = self.boundaries(query);
for (i, count) in counts.iter_mut().enumerate() {
suffixed_query.push(i as u16);
let positions = self.range_positions(&suffixed_query, range_start, range_end);
*count = positions.len();
suffixed_query.pop();
}
counts
}
pub fn batch_count_next(&self, queries: &[Vec<u16>], vocab: Option<u16>) -> Vec<Vec<usize>> {
queries
.into_par_iter()
.map(|query| self.count_next(query, vocab))
.collect()
}
pub fn sample(&self, query: &[u16], n: usize, k: usize) -> Result<Vec<u16>> {
let mut rng = thread_rng();
let mut sequence = Vec::from(query);
for _ in 0..k {
let start = sequence.len().saturating_sub(n - 1);
let prev = &sequence[start..];
let counts: Vec<usize> = self.count_next(prev, None);
let dist = WeightedIndex::new(&counts)?;
let sampled_index = dist.sample(&mut rng);
sequence.push(sampled_index as u16);
}
Ok(sequence)
}
pub fn batch_sample(
&self,
query: &[u16],
n: usize,
k: usize,
num_samples: usize,
) -> Result<Vec<Vec<u16>>> {
(0..num_samples)
.into_par_iter()
.map(|_| self.sample(query, n, k))
.collect()
}
pub fn is_sorted(&self) -> bool {
self.table
.windows(2)
.all(|pair| self.text[pair[0] as usize..] <= self.text[pair[1] as usize..])
}
}
impl fmt::Debug for SuffixTable {
    /// Writes a banner followed by one `suffix[rank] offset` line per table
    /// row and a closing rule. The text itself is not printed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        writeln!(f, "\n-----------------------------------------")?;
        writeln!(f, "SUFFIX TABLE")?;
        self.table
            .iter()
            .enumerate()
            .try_for_each(|(rank, sufstart)| writeln!(f, "suffix[{}] {}", rank, sufstart,))?;
        writeln!(f, "-----------------------------------------")
    }
}
/// Returns the index of the first element of `xs` for which `pred` is true,
/// or `xs.len()` if there is none.
///
/// `xs` must be partitioned with respect to `pred`: every element for which
/// it is false has to come before every element for which it is true.
/// Otherwise the returned index is unspecified.
#[allow(dead_code)]
fn binary_search<T, F>(xs: &[T], mut pred: F) -> usize
where
    F: FnMut(&T) -> bool,
{
    let (mut left, mut right) = (0, xs.len());
    while left < right {
        // `left + right` can exceed `usize::MAX` (slices of zero-sized elements
        // may be that long); this form stays within `left..right`.
        let mid = left + (right - left) / 2;
        if pred(&xs[mid]) {
            right = mid;
        } else {
            left = mid + 1;
        }
    }
    left
}
#[cfg(test)]
mod tests {
    use super::*;
    use utf16_literal::utf16;

    /// Builds a suffix table over the UTF-16 code units of `text`.
    fn sais(text: &str) -> SuffixTable {
        SuffixTable::new(text.encode_utf16().collect::<Vec<_>>(), false)
    }

    #[test]
    fn count_next_exists() {
        let sa = sais("aaab");
        let query = utf16!("a");
        let a_index = utf16!("a")[0] as usize;
        let b_index = utf16!("b")[0] as usize;

        // Compute the 65 536-entry result once instead of once per assertion.
        let counts = sa.count_next(query, None);
        assert_eq!(2, counts[a_index]);
        assert_eq!(1, counts[b_index]);
    }

    #[test]
    fn count_next_empty_query() {
        let sa = sais("aaab");
        let query = utf16!("");
        let a_index = utf16!("a")[0] as usize;
        let b_index = utf16!("b")[0] as usize;

        let counts = sa.count_next(query, None);
        assert_eq!(3, counts[a_index]);
        assert_eq!(1, counts[b_index]);
    }

    #[test]
    fn count_next_absent_query() {
        let sa = sais("aaab");
        let query = utf16!("c");

        // Nothing can follow a context that never occurs in the text.
        let counts = sa.count_next(query, None);
        assert!(counts.iter().all(|&count| count == 0));
    }

    #[test]
    fn batch_count_next_exists() {
        let sa = sais("aaab");
        let queries: Vec<Vec<u16>> = vec![vec![utf16!("a")[0]; 1]; 10_000];
        let a_index = utf16!("a")[0] as usize;
        let b_index = utf16!("b")[0] as usize;

        // Cap the vocabulary at 'b': with the default of 65 536 entries per
        // query, 10 000 queries would need roughly 5 GB of counts on a 64-bit
        // target. The default vocabulary is covered by the tests above.
        let counts = sa.batch_count_next(&queries, Some(utf16!("b")[0]));
        assert_eq!(queries.len(), counts.len());
        assert_eq!(2, counts[0][a_index]);
        assert_eq!(1, counts[0][b_index]);
        assert!(counts
            .iter()
            .all(|row| row[a_index] == 2 && row[b_index] == 1));
    }
}