use std::cmp::{max, min};
use std::collections::{hash_map::Entry, HashMap, VecDeque};
/// Prefix tree over symbols of type `C`, stored as a flat list of nodes.
///
/// Node `0` is the root. Every node is a map from a symbol to the index of
/// the child reached by that symbol.
pub struct Trie<C: std::hash::Hash + Eq> {
    /// `links[node][symbol]` is the index of the child of `node` along `symbol`.
    links: Vec<HashMap<C, usize>>,
}

impl<C: std::hash::Hash + Eq> Default for Trie<C> {
    /// Creates a trie that contains only the root node.
    fn default() -> Self {
        let root = HashMap::new();
        Self { links: vec![root] }
    }
}

impl<C: std::hash::Hash + Eq> Trie<C> {
    /// Inserts `word`, creating any missing nodes, and returns the index of
    /// the node that spells out the whole word.
    ///
    /// New nodes receive consecutive indices in creation order.
    pub fn insert(&mut self, word: impl IntoIterator<Item = C>) -> usize {
        let mut cur = 0;
        for symbol in word {
            // Index a new child would get if this edge does not exist yet.
            let fresh = self.links.len();
            let known = match self.links[cur].entry(symbol) {
                Entry::Occupied(slot) => Some(*slot.get()),
                Entry::Vacant(slot) => {
                    slot.insert(fresh);
                    None
                }
            };
            cur = known.unwrap_or_else(|| {
                self.links.push(HashMap::new());
                fresh
            });
        }
        cur
    }

    /// Follows `word` from the root and returns the node it leads to, or
    /// `None` as soon as a symbol has no matching edge.
    pub fn get(&self, word: impl IntoIterator<Item = C>) -> Option<usize> {
        word.into_iter()
            .try_fold(0, |cur, symbol| self.links[cur].get(&symbol).copied())
    }
}
/// Single-pattern matcher based on the Knuth-Morris-Pratt failure function.
pub struct Matcher<'a, C: Eq> {
    /// The pattern being searched for.
    pub pattern: &'a [C],
    /// `fail[i]` is the length of the longest proper prefix of
    /// `pattern[..=i]` that is also a suffix of it. Has the same length as
    /// `pattern` (empty for an empty pattern).
    pub fail: Vec<usize>,
}

impl<'a, C: Eq> Matcher<'a, C> {
    /// Precomputes the failure table for `pattern`.
    ///
    /// An empty pattern is accepted and yields an empty table.
    pub fn new(pattern: &'a [C]) -> Self {
        let mut fail = Vec::with_capacity(pattern.len());
        if !pattern.is_empty() {
            // A single symbol has no non-empty proper prefix.
            fail.push(0);
        }
        let mut len = 0;
        // `skip(1)` instead of `&pattern[1..]` so an empty pattern does not panic.
        for ch in pattern.iter().skip(1) {
            // Fall back to shorter borders until one can be extended by `ch`.
            while len > 0 && pattern[len] != *ch {
                len = fail[len - 1];
            }
            if pattern[len] == *ch {
                len += 1;
            }
            fail.push(len);
        }
        Self { pattern, fail }
    }

    /// Returns, for every position of `text`, the length of the longest
    /// prefix of the pattern that ends at that position.
    ///
    /// A value equal to `pattern.len()` marks the end of a full occurrence.
    /// For an empty pattern every entry is `0`.
    pub fn kmp_match(&self, text: &[C]) -> Vec<usize> {
        if self.pattern.is_empty() {
            // The only prefix of an empty pattern is the empty one.
            return vec![0; text.len()];
        }
        let mut match_lens = Vec::with_capacity(text.len());
        let mut len = 0;
        for ch in text {
            // After a full match, continue from its longest border so that
            // overlapping occurrences are found.
            if len == self.pattern.len() {
                len = self.fail[len - 1];
            }
            while len > 0 && self.pattern[len] != *ch {
                len = self.fail[len - 1];
            }
            if self.pattern[len] == *ch {
                len += 1;
            }
            match_lens.push(len);
        }
        match_lens
    }
}
/// Multi-pattern matcher: a trie of the patterns augmented with failure links.
pub struct MultiMatcher<C: std::hash::Hash + Eq> {
    /// Trie containing every pattern.
    pub trie: Trie<C>,
    /// `pat_id[node]` is the index of the pattern ending at `node`, if any.
    pub pat_id: Vec<Option<usize>>,
    /// `fail[node]` is the node to fall back to when no edge matches.
    pub fail: Vec<usize>,
    /// `fast[node]` is the closest node on the failure chain of `node` that
    /// ends a pattern, or `0` when the chain holds none.
    pub fast: Vec<usize>,
}

impl<C: std::hash::Hash + Eq> MultiMatcher<C> {
    /// Computes the state reached from `node` on symbol `ch`, following
    /// failure links until an edge for `ch` exists or the root is reached.
    fn next(trie: &Trie<C>, fail: &[usize], mut node: usize, ch: &C) -> usize {
        loop {
            match trie.links[node].get(ch) {
                Some(&child) => return child,
                None if node == 0 => return 0,
                None => node = fail[node],
            }
        }
    }

    /// Builds the automaton for `patterns`; a pattern's id is its index in
    /// the input vector.
    pub fn new(patterns: Vec<impl IntoIterator<Item = C>>) -> Self {
        let mut trie = Trie::default();
        let mut terminals = Vec::with_capacity(patterns.len());
        for pat in patterns {
            terminals.push(trie.insert(pat));
        }

        let node_count = trie.links.len();
        let mut pat_id = vec![None; node_count];
        for (id, &node) in terminals.iter().enumerate() {
            pat_id[node] = Some(id);
        }

        let mut fail = vec![0; node_count];
        let mut fast = vec![0; node_count];
        // Breadth-first traversal starting from the children of the root, so
        // the links of shallower nodes are final before deeper ones use them.
        let mut queue: VecDeque<usize> = trie.links[0].values().copied().collect();
        while let Some(parent) = queue.pop_front() {
            for (ch, &child) in &trie.links[parent] {
                let fallback = Self::next(&trie, &fail, fail[parent], ch);
                fail[child] = fallback;
                fast[child] = match pat_id[fallback] {
                    Some(_) => fallback,
                    None => fast[fallback],
                };
                queue.push_back(child);
            }
        }

        Self {
            trie,
            pat_id,
            fail,
            fast,
        }
    }

    /// Runs the automaton over `text` and returns the state reached after
    /// each symbol.
    pub fn ac_match(&self, text: &[C]) -> Vec<usize> {
        let mut state = 0;
        text.iter()
            .map(|ch| {
                state = Self::next(&self.trie, &self.fail, state, ch);
                state
            })
            .collect()
    }

    /// Expands the states produced by `ac_match` into `(end_position,
    /// pattern_id)` pairs, where `end_position` is one past the index of the
    /// last matched symbol.
    ///
    /// The walk along `fast` stops at the root, so nothing is reported for
    /// the root itself.
    pub fn get_end_pos_and_pat_id(&self, match_nodes: &[usize]) -> Vec<(usize, usize)> {
        let mut hits = Vec::new();
        for (idx, &start) in match_nodes.iter().enumerate() {
            let mut cursor = start;
            while cursor != 0 {
                if let Some(id) = self.pat_id[cursor] {
                    hits.push((idx + 1, id));
                }
                cursor = self.fast[cursor];
            }
        }
        hits
    }
}
/// Suffix array of a byte string, together with the rank tables produced by
/// each prefix-doubling round.
pub struct SuffixArray {
    /// Starting positions of the suffixes in lexicographic order.
    pub sfx: Vec<usize>,
    /// `rank[0]` holds the byte values of the text. For `k >= 1`, `rank[k][p]`
    /// orders position `p` by the first `2^k` bytes of its suffix; equal
    /// values mean those bytes are identical.
    pub rank: Vec<Vec<usize>>,
}

impl SuffixArray {
    /// Stable counting sort of `vals` by `val_to_key[v]`.
    ///
    /// Every key must be smaller than `max_key`.
    fn counting_sort(
        vals: impl Iterator<Item = usize> + Clone,
        val_to_key: &[usize],
        max_key: usize,
    ) -> Vec<usize> {
        let mut counts = vec![0; max_key];
        for v in vals.clone() {
            counts[val_to_key[v]] += 1;
        }
        // Turn the histogram into the first output slot of each key.
        let mut total = 0;
        for c in counts.iter_mut() {
            total += *c;
            *c = total - *c;
        }
        let mut result = vec![0; total];
        for v in vals {
            let c = &mut counts[val_to_key[v]];
            result[*c] = v;
            *c += 1;
        }
        result
    }

    /// Builds the suffix array of `text` by prefix doubling, using one
    /// counting sort per round.
    pub fn new(text: &[u8]) -> Self {
        let n = text.len();
        let init_rank = text.iter().map(|&ch| ch as usize).collect::<Vec<_>>();
        let mut sfx = Self::counting_sort(0..n, &init_rank, 256);
        let mut rank = vec![init_rank];
        // `skip` is the prefix length that is already sorted; each round doubles it.
        for skip in (0..).map(|i| 1 << i).take_while(|&skip| skip < n) {
            let prev_rank = rank.last().expect("rank always holds the initial table");
            let mut cur_rank = prev_rank.clone();
            // Order positions by their second half: the last `skip` positions
            // have an empty second half and come first, the rest follow in
            // the order of `sfx` shifted back by `skip`.
            let pos = (n - skip..n).chain(sfx.into_iter().filter_map(|p| p.checked_sub(skip)));
            // The stable sort by first-half rank then yields the order by `2 * skip` bytes.
            sfx = Self::counting_sort(pos, prev_rank, max(n, 256));
            let mut prev = sfx[0];
            cur_rank[prev] = 0;
            for &cur in sfx.iter().skip(1) {
                // Neighbours share a rank only if both halves exist and agree.
                if max(prev, cur) + skip < n
                    && prev_rank[prev] == prev_rank[cur]
                    && prev_rank[prev + skip] == prev_rank[cur + skip]
                {
                    cur_rank[cur] = cur_rank[prev];
                } else {
                    cur_rank[cur] = cur_rank[prev] + 1;
                }
                prev = cur;
            }
            rank.push(cur_rank);
        }
        Self { sfx, rank }
    }

    /// Returns the length of the longest common prefix of the suffixes
    /// starting at positions `i` and `j`.
    ///
    /// # Panics
    ///
    /// Panics if `i` or `j` is not a valid position in the text.
    pub fn longest_common_prefix(&self, mut i: usize, mut j: usize) -> usize {
        let n = self.sfx.len();
        if i == j {
            // A suffix compared with itself matches at every rank level, so
            // the doubling loop below would add a full power of two and
            // overshoot the end of the text. The answer is the suffix length.
            assert!(
                i < n,
                "suffix start {} out of bounds for text of length {}",
                i,
                n
            );
            return n - i;
        }
        let mut len = 0;
        // Greedily take the largest block of `2^k` matching bytes at each level.
        for (k, rank) in self.rank.iter().enumerate().rev() {
            if rank[i] == rank[j] {
                i += 1 << k;
                j += 1 << k;
                len += 1 << k;
                if max(i, j) >= n {
                    break;
                }
            }
        }
        len
    }
}
/// Computes the longest palindrome around every centre of `text`.
///
/// The result has `2 * text.len() - 1` entries: entry `2 * k` is the length
/// of the longest odd-length palindrome centred on `text[k]`, and entry
/// `2 * k + 1` is the length of the longest even-length palindrome centred
/// between `text[k]` and `text[k + 1]`. An empty `text` yields an empty
/// vector.
pub fn palindromes<T: Eq>(text: &[T]) -> Vec<usize> {
    if text.is_empty() {
        // `2 * 0 - 1` would underflow; there are no centres to report.
        return Vec::new();
    }
    // Track the target length explicitly: `Vec::with_capacity` may reserve
    // more than requested, so `capacity()` is not a reliable loop bound.
    let total = 2 * text.len() - 1;
    let mut pal = Vec::with_capacity(total);
    pal.push(1);
    while pal.len() < total {
        let i = pal.len() - 1;
        // Longest palindrome around centre `i` that still fits in the text.
        let max_len = min(i + 1, total - i);
        while pal[i] < max_len && text[(i - pal[i] - 1) / 2] == text[(i + pal[i] + 1) / 2] {
            pal[i] += 2;
        }
        if let Some(a) = 1usize.checked_sub(pal[i]) {
            // `pal[i]` is 0 or 1: nothing to mirror, so seed the next centre
            // with the opposite parity.
            pal.push(a);
        } else {
            // Copy values from the mirrored centres while they stay strictly
            // inside the palindrome around `i`; the first one that reaches
            // its edge becomes the starting value of the next centre.
            for d in 1.. {
                let (a, b) = (pal[i - d], pal[i] - d);
                if a < b {
                    pal.push(a);
                } else {
                    pal.push(b);
                    break;
                }
            }
        }
    }
    pal
}
// Unit tests for the trie, the single- and multi-pattern matchers, the
// suffix array and the palindrome finder.
#[cfg(test)]
mod test {
use super::*;
// Checks node lookup after inserting several words that share prefixes.
#[test]
fn test_trie() {
let dict = vec!["banana", "benefit", "banapple", "ban"];
let trie = dict.into_iter().fold(Trie::default(), |mut trie, word| {
trie.insert(word.bytes());
trie
});
// The empty word maps to the root.
assert_eq!(trie.get("".bytes()), Some(0));
// Node indices follow creation order: "banana" creates nodes 1..=6, then
// the 'e' of "benefit" becomes node 7.
assert_eq!(trie.get("b".bytes()), Some(1));
assert_eq!(trie.get("banana".bytes()), Some(6));
assert_eq!(trie.get("be".bytes()), Some(7));
// No inserted word has the prefix "bane".
assert_eq!(trie.get("bane".bytes()), None);
}
// Checks the per-position matched-prefix lengths of "ana" in "banana".
#[test]
fn test_kmp_matching() {
let text = b"banana";
let pattern = b"ana";
let matches = Matcher::new(pattern).kmp_match(text);
// A value of 3 (the pattern length) marks the end of a full occurrence.
assert_eq!(matches, vec![0, 1, 2, 3, 2, 3]);
}
// Checks that every dictionary occurrence in the text is reported.
#[test]
fn test_ac_matching() {
let text = b"banana bans, apple benefits.";
let dict = vec![
"banana".bytes(),
"benefit".bytes(),
"banapple".bytes(),
"ban".bytes(),
"fit".bytes(),
];
let matcher = MultiMatcher::new(dict);
let match_nodes = matcher.ac_match(text);
let end_pos_and_id = matcher.get_end_pos_and_pat_id(&match_nodes);
// Each pair is (position just past the match, index of the pattern in
// `dict`); "benefit" and "fit" both end at position 26.
assert_eq!(
end_pos_and_id,
vec![(3, 3), (6, 0), (10, 3), (26, 1), (26, 4)]
);
}
// Checks suffix order, two longest-common-prefix queries, and that the
// last rank table is the inverse of the suffix array.
#[test]
fn test_suffix_array() {
let text1 = b"bobocel";
let text2 = b"banana";
let sfx1 = SuffixArray::new(text1);
let sfx2 = SuffixArray::new(text2);
assert_eq!(sfx1.sfx, vec![0, 2, 4, 5, 6, 1, 3]);
assert_eq!(sfx2.sfx, vec![5, 3, 1, 0, 4, 2]);
// "bobocel" and "bocel" share "bo"; "anana" and "ana" share "ana".
assert_eq!(sfx1.longest_common_prefix(0, 2), 2);
assert_eq!(sfx2.longest_common_prefix(1, 3), 3);
for (p, &r) in sfx1.rank.last().unwrap().iter().enumerate() {
assert_eq!(sfx1.sfx[r], p);
}
for (p, &r) in sfx2.rank.last().unwrap().iter().enumerate() {
assert_eq!(sfx2.sfx[r], p);
}
}
// Checks the palindrome lengths for every centre of "banana".
#[test]
fn test_palindrome() {
let text = b"banana";
let pal_len = palindromes(text);
// Even indices are centred on a character, odd indices between two
// characters; index 6 is the centre of "anana".
assert_eq!(pal_len, vec![1, 0, 1, 0, 3, 0, 5, 0, 3, 0, 1]);
}
}