use std::borrow::Borrow;
use std::iter::Enumerate;
/// Precomputed search state for one byte pattern, used for Shift-And matching.
///
/// Instances are built by `ShiftAnd::new` and queried with `find_all`. The
/// constructor rejects patterns longer than 64 symbols because the matching
/// state is kept in a single `u64`, one bit per pattern position.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
pub struct ShiftAnd {
/// Number of symbols in the pattern.
m: usize,
/// Per-byte masks: bit `i` of `masks[b]` is set when pattern symbol `i` equals byte `b`.
masks: [u64; 256],
/// Acceptance mask as returned by `masks()`; a match is reported whenever
/// the active state has a bit in common with it.
accept: u64,
}
impl ShiftAnd {
    /// Creates a searcher for `pattern`.
    ///
    /// The pattern may be any iterable of items that borrow as `u8` whose
    /// iterator knows its exact length; that length is stored and the
    /// per-byte masks are computed once up front.
    ///
    /// # Panics
    ///
    /// Panics if the pattern has more than 64 symbols.
    pub fn new<C, P>(pattern: P) -> Self
    where
        P::IntoIter: ExactSizeIterator,
        C: Borrow<u8>,
        P: IntoIterator<Item = C>,
    {
        let symbols = pattern.into_iter();
        let m = symbols.len();
        assert!(m <= 64, "Expecting a pattern of at most 64 symbols.");

        let (table, accept) = masks(symbols);
        Self {
            m,
            masks: table,
            accept,
        }
    }

    /// Returns a lazy iterator over the start positions of every occurrence
    /// of the pattern in `text`.
    ///
    /// Positions are 0-based indices into `text`. The matching state is not
    /// reset after a hit, so overlapping occurrences are reported as well.
    pub fn find_all<C, T>(&self, text: T) -> Matches<'_, C, T::IntoIter>
    where
        C: Borrow<u8>,
        T: IntoIterator<Item = C>,
    {
        // Pair every text symbol with its index so the iterator can turn the
        // end of a match into its start position.
        let text = text.into_iter().enumerate();
        Matches {
            shiftand: self,
            active: 0,
            text,
        }
    }
}
/// Builds the Shift-And bit masks for `pattern`.
///
/// Returns `(masks, accept)` where
///
/// * bit `i` of `masks[b]` is set exactly when the `i`-th pattern symbol
///   equals the byte `b`, and
/// * `accept` has only the bit of the last pattern symbol set, or is `0`
///   for an empty pattern.
///
/// # Panics
///
/// Panics if the pattern has more than 64 symbols, since each symbol needs
/// its own bit in a `u64`.
pub fn masks<C, P>(pattern: P) -> ([u64; 256], u64)
where
    C: Borrow<u8>,
    P: IntoIterator<Item = C>,
{
    let mut masks = [0u64; 256];
    let mut accept = 0u64;

    for (i, c) in pattern.into_iter().enumerate() {
        assert!(i < 64, "Expecting a pattern of at most 64 symbols.");
        let symbol: u8 = *c.borrow();
        // Derive the bit from the index instead of doubling a running value:
        // doubling overflows right after the 64th symbol, which made
        // full-width patterns unusable.
        accept = 1u64 << i;
        masks[usize::from(symbol)] |= accept;
    }

    // `accept` still holds the bit of the last symbol seen (0 if none).
    (masks, accept)
}
/// Iterator over the start positions of pattern occurrences in a text.
///
/// Created by `ShiftAnd::find_all`; it borrows the searcher and consumes the
/// text lazily, one symbol per step.
#[derive(Clone, Debug)]
pub struct Matches<'a, C, T>
where
C: Borrow<u8>,
T: Iterator<Item = C>,
{
/// Searcher that supplies the pattern length, byte masks and acceptance mask.
shiftand: &'a ShiftAnd,
/// Current matching state: bit `i` is set when the first `i + 1` pattern
/// symbols match the text ending at the most recently consumed symbol.
active: u64,
/// Remaining text, each symbol paired with its 0-based index.
text: Enumerate<T>,
}
impl<'a, C, T> Iterator for Matches<'a, C, T>
where
    C: Borrow<u8>,
    T: Iterator<Item = C>,
{
    type Item = usize;

    /// Advances through the text until the next occurrence of the pattern
    /// ends, and returns the 0-based position at which that occurrence
    /// starts. Returns `None` once the text is exhausted.
    fn next(&mut self) -> Option<usize> {
        let searcher = self.shiftand;
        let state = &mut self.active;

        // `find_map` stops consuming the text at the first symbol that
        // completes a match, so a later call resumes right after it.
        self.text.by_ref().find_map(|(pos, symbol)| {
            let byte: u8 = *symbol.borrow();
            // Extend every partial match by one position, start a fresh one
            // at bit 0, then keep only those compatible with this symbol.
            *state = ((*state << 1) | 1) & searcher.masks[byte as usize];
            if *state & searcher.accept > 0 {
                // `pos` is where the match ends; step back to its start.
                Some(pos + 1 - searcher.m)
            } else {
                None
            }
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::Itertools;

    /// A single occurrence in the middle of the text is reported at its
    /// start offset.
    #[test]
    fn test_find_all() {
        let haystack = b"dhjalkjwqnnnannanaflkjdklfj";
        let searcher = ShiftAnd::new(b"qnnnannan");
        let hits = searcher.find_all(haystack).collect_vec();
        assert_eq!(hits, [8]);
    }

    /// An occurrence starting at text offset 0 must be found (the test name
    /// refers to issue 416).
    #[test]
    fn test_issue_416() {
        let haystack = b"CCTTTTTTTTTTTTTTT";
        let searcher = ShiftAnd::new(b"CC");
        let hits = searcher.find_all(haystack).collect_vec();
        assert_eq!(hits, [0]);
    }

    /// Every occurrence is reported, in text order.
    #[test]
    fn test_multiple_finds() {
        let haystack = b"CCTCCTCC";
        let searcher = ShiftAnd::new(b"CC");
        let hits = searcher.find_all(haystack).collect_vec();
        assert_eq!(hits, [0, 3, 6]);
    }
}