use std::borrow::Borrow;
use std::iter::{repeat, Enumerate};
use crate::utils::TextSlice;
/// Longest-proper-prefix-that-is-also-a-suffix table, one entry per pattern byte.
type Lps = Vec<usize>;
/// Knuth-Morris-Pratt matcher for a single byte pattern.
///
/// The derived comparison, hashing and debug impls visit fields in declaration
/// order, so reordering the fields below changes their behavior.
#[derive(Default, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Serialize)]
pub struct KMP<'a> {
    /// Length of `pattern`, cached by `KMP::new`.
    m: usize,
    /// Prefix table: `lps[i]` is the length of the longest proper prefix of
    /// `pattern[..=i]` that is also a suffix of it.
    lps: Lps,
    /// The pattern being searched for.
    pattern: TextSlice<'a>,
}
impl<'a> KMP<'a> {
    /// Builds a matcher for `pattern`, precomputing its prefix table.
    pub fn new(pattern: TextSlice<'a>) -> Self {
        KMP {
            m: pattern.len(),
            lps: lps(pattern),
            pattern,
        }
    }

    /// Transition function of the matching automaton.
    ///
    /// `q` is the number of pattern bytes matched before reading `a`; the
    /// result is the number of pattern bytes matched after reading it. From
    /// the accepting state `q == m` the automaton first falls back through the
    /// prefix table, which is what lets overlapping occurrences be found.
    ///
    /// NOTE: for an empty pattern (`m == 0`) the fallback computes
    /// `state - 1` with `state == 0`, so this must not be called in that case.
    fn delta(&self, q: usize, a: u8) -> usize {
        let mut state = q;
        loop {
            // A full match has no next byte to compare against, so it always
            // falls back. Otherwise stop once `a` extends the current prefix
            // or there is no shorter border left to try.
            if state != self.m && (self.pattern[state] == a || state == 0) {
                break;
            }
            state = self.lps[state - 1];
        }
        if self.pattern[state] == a {
            state + 1
        } else {
            state
        }
    }

    /// Searches `text` and returns an iterator over the 0-based start
    /// positions of every occurrence of the pattern, overlapping ones included.
    ///
    /// `text` may yield bytes either by value (`u8`) or by reference (`&u8`).
    pub fn find_all<C, T>(&self, text: T) -> Matches<C, T::IntoIter>
    where
        C: Borrow<u8>,
        T: IntoIterator<Item = C>,
    {
        let text = text.into_iter().enumerate();
        Matches { kmp: self, q: 0, text }
    }
}
/// Computes the prefix table of `pattern`.
///
/// Entry `i` of the result is the length of the longest proper prefix of
/// `pattern[..=i]` that is also a suffix of it. An empty pattern yields an
/// empty table.
fn lps(pattern: &[u8]) -> Lps {
    let mut table: Lps = repeat(0).take(pattern.len()).collect();
    // Length of the border carried over from the previous position.
    let mut border = 0;
    for (i, &byte) in pattern.iter().enumerate().skip(1) {
        // Shrink the border until it can be extended by `byte` (or is empty).
        while border > 0 && pattern[border] != byte {
            border = table[border - 1];
        }
        if pattern[border] == byte {
            border += 1;
        }
        table[i] = border;
    }
    table
}
/// Iterator over the start positions of pattern occurrences in a text,
/// created by `KMP::find_all`.
#[derive(Clone, Debug)]
pub struct Matches<'a, C: Borrow<u8>, T: Iterator<Item = C>> {
    /// The matcher whose automaton drives the search.
    kmp: &'a KMP<'a>,
    /// Current automaton state: the number of pattern bytes matched so far.
    q: usize,
    /// The not-yet-consumed text, each byte paired with its index.
    text: Enumerate<T>,
}
impl<'a, C, T> Iterator for Matches<'a, C, T>
where
    C: Borrow<u8>,
    T: Iterator<Item = C>,
{
    type Item = usize;

    /// Consumes text until the automaton reaches the accepting state and
    /// returns the 0-based start position of that occurrence.
    ///
    /// The automaton state is kept in `self.q` between calls, so the search
    /// resumes where the previous match ended. An empty pattern yields no
    /// occurrences.
    fn next(&mut self) -> Option<usize> {
        let m = self.kmp.m;
        // With an empty pattern the automaton's fallback would compute
        // `lps[q - 1]` at `q == 0` and panic; report no matches instead.
        if m == 0 {
            return None;
        }
        for (i, c) in self.text.by_ref() {
            self.q = self.kmp.delta(self.q, *c.borrow());
            if self.q == m {
                // `i` is the index of the occurrence's last byte.
                return Some(i + 1 - m);
            }
        }
        None
    }
}
#[cfg(test)]
mod tests {
    use super::{lps, KMP};
    use itertools::Itertools;

    #[test]
    fn test_find_all() {
        let text = b"dhjalkjwqnnnannanaflkjdklfj";
        let pattern = b"qnnnannan";
        let kmp = KMP::new(pattern);
        assert_eq!(kmp.find_all(text).collect_vec(), [8]);
    }

    #[test]
    fn test_find_all_at_start() {
        let text = b"dhjalkjwqnnnannanaflkjdklfj";
        let pattern = b"dhjalk";
        let kmp = KMP::new(pattern);
        assert_eq!(kmp.find_all(text).collect_vec(), [0]);
    }

    /// Occurrences at the very end are reported, and `find_all` accepts an
    /// iterator of owned bytes as well as borrowed ones.
    #[test]
    fn test_find_all_at_end_with_owned_bytes() {
        let text = b"dhjalkjwqnnnannanaflkjdklfj";
        let kmp = KMP::new(b"dklfj");
        assert_eq!(kmp.find_all(text).collect_vec(), [22]);
        assert_eq!(kmp.find_all(text.iter().copied()).collect_vec(), [22]);
    }

    /// After a full match the automaton falls back through the prefix table,
    /// so overlapping occurrences must all be found.
    #[test]
    fn test_find_all_overlapping() {
        let kmp = KMP::new(b"aa");
        assert_eq!(kmp.find_all(b"aaaa").collect_vec(), [0, 1, 2]);
        let kmp = KMP::new(b"abab");
        assert_eq!(kmp.find_all(b"abababab").collect_vec(), [0, 2, 4]);
    }

    #[test]
    fn test_find_all_no_match() {
        let kmp = KMP::new(b"xyz");
        assert_eq!(kmp.find_all(b"dhjalkjwqnnnannanaflkjdklfj").count(), 0);
        // Pattern longer than the text.
        let kmp = KMP::new(b"abcd");
        assert_eq!(kmp.find_all(b"abc").count(), 0);
        // Empty text.
        assert_eq!(kmp.find_all(b"").count(), 0);
    }

    #[test]
    fn test_lps() {
        let pattern = b"ababaca";
        let lps = lps(pattern);
        assert_eq!(lps, [0, 0, 1, 2, 3, 0, 1]);
    }

    #[test]
    fn test_lps_edge_cases() {
        assert!(lps(b"").is_empty());
        assert_eq!(lps(b"a"), [0]);
        assert_eq!(lps(b"aaaa"), [0, 1, 2, 3]);
        assert_eq!(lps(b"abcd"), [0, 0, 0, 0]);
    }

    #[test]
    fn test_delta() {
        let pattern = b"abbab";
        let kmp = KMP::new(pattern);
        assert_eq!(kmp.delta(0, b'a'), 1);
        assert_eq!(kmp.delta(0, b'b'), 0);
        assert_eq!(kmp.delta(1, b'a'), 1);
        assert_eq!(kmp.delta(1, b'b'), 2);
        assert_eq!(kmp.delta(2, b'a'), 1);
        assert_eq!(kmp.delta(2, b'b'), 3);
        assert_eq!(kmp.delta(3, b'a'), 4);
        assert_eq!(kmp.delta(3, b'b'), 0);
        assert_eq!(kmp.delta(4, b'a'), 1);
        assert_eq!(kmp.delta(4, b'b'), 5);
        assert_eq!(kmp.delta(5, b'a'), 1);
        assert_eq!(kmp.delta(5, b'b'), 3);
    }
}