use std::cmp;
use aho_corasick::{AhoCorasick, AhoCorasickBuilder};
use syntax::hir::literal::Literals;
use vector::avx2::{AVX2VectorBuilder, u8x32};
// Step size, in bytes, of the vector search loops: each iteration loads one
// `u8x32` from the haystack and then advances by this amount.
const BLOCK_SIZE: usize = 32;
/// A single match reported by a `Teddy` search.
#[derive(Debug, Clone)]
pub struct Match {
/// Index of the matched pattern in the list returned by `Teddy::patterns`.
pub pat: usize,
/// Haystack offset at which the match begins.
pub start: usize,
/// Haystack offset one past the final byte of the match
/// (`start` plus the pattern length).
pub end: usize,
}
/// A multi-pattern literal searcher built on `u8x32` vector operations, with
/// an Aho-Corasick automaton as the fallback for input the vector loops do
/// not cover.
#[derive(Debug, Clone)]
pub struct Teddy {
/// Builder used to create the constant vectors needed by the search loops.
vb: AVX2VectorBuilder,
/// The patterns being searched for. `Match::pat` indexes into this list.
pats: Vec<Vec<u8>>,
/// Automaton over the same patterns; `slow` uses it for haystacks shorter
/// than `BLOCK_SIZE + 2` and for the tail left over after the vector loops.
ac: AhoCorasick,
/// `buckets[b]` holds the indices of the patterns assigned to bucket `b`.
/// `new` creates 8 buckets and puts pattern `i` into bucket `i % 8`.
buckets: Vec<Vec<usize>>,
/// Lookup masks built from the leading bytes of every pattern.
masks: Masks,
}
impl Teddy {
/// Reports whether `AVX2VectorBuilder::new()` is able to produce a builder,
/// which is the precondition for `Teddy::new` to succeed.
pub fn available() -> bool {
    let builder = AVX2VectorBuilder::new();
    builder.is_some()
}
/// Builds a searcher for the given set of literals.
///
/// Returns `None` when an `AVX2VectorBuilder` cannot be created, when `pats`
/// contains no literals, or when any literal is empty.
pub fn new(pats: &Literals) -> Option<Teddy> {
    // Obtaining the builder *is* the availability check: `available()` is
    // defined as `AVX2VectorBuilder::new().is_some()`, so calling it again
    // after this succeeds would only repeat the same probe.
    let vb = AVX2VectorBuilder::new()?;

    let pats: Vec<_> = pats.literals().iter().map(|p| p.to_vec()).collect();
    // `min()` is `None` for an empty set; mapping that to 0 lets a single
    // test reject both "no patterns" and "some pattern is empty".
    let min_len = pats.iter().map(|p| p.len()).min().unwrap_or(0);
    if min_len < 1 {
        return None;
    }

    // `Masks::add` reads one byte per mask from every pattern, so the number
    // of masks may not exceed the shortest pattern. `Masks` stores at most 3.
    let nmasks = cmp::min(3, min_len);
    let mut masks = Masks::new(vb, nmasks);
    let mut buckets = vec![vec![]; 8];
    // Spread the patterns round-robin over the 8 buckets and record each
    // pattern's leading bytes in the masks under its bucket.
    for (pati, pat) in pats.iter().enumerate() {
        let bucket = pati % 8;
        buckets[bucket].push(pati);
        masks.add(bucket as u8, pat);
    }

    let ac = AhoCorasickBuilder::new()
        .dfa(true)
        .prefilter(false)
        .build(&pats);
    Some(Teddy {
        vb: vb,
        // `pats` is already an owned `Vec` and is not used again, so it is
        // moved into the struct instead of being deep-copied.
        pats: pats,
        ac: ac,
        buckets: buckets,
        masks: masks,
    })
}
/// Returns the patterns this searcher was built from, in the order that
/// `Match::pat` indexes them.
pub fn patterns(&self) -> &[Vec<u8>] {
    self.pats.as_slice()
}
/// Returns the number of patterns in this searcher.
pub fn len(&self) -> usize {
    self.patterns().len()
}
/// Returns the combined byte length of all patterns. The automaton, the
/// buckets and the masks are not included in this figure.
pub fn approximate_size(&self) -> usize {
    self.pats.iter().map(|pat| pat.len()).sum()
}
/// Searches `haystack` for any of the patterns and returns a match, or
/// `None` if the search finds none.
pub fn find(&self, haystack: &[u8]) -> Option<Match> {
// SAFETY: `find_impl` is compiled with the `avx2` target feature, so it must
// only run on a CPU that supports AVX2. In this file a `Teddy` is only
// constructed by `Teddy::new`, which returns `None` unless
// `AVX2VectorBuilder::new()` yields a builder.
// NOTE(review): this relies on the builder only being obtainable when AVX2
// has been detected at runtime -- confirm in `vector::avx2`.
unsafe { self.find_impl(haystack) }
}
/// Dispatches to the search routine matching the number of masks in use.
///
/// Haystacks shorter than `BLOCK_SIZE + 2` bytes (which includes the empty
/// haystack) are handed to `slow` in their entirety.
///
/// # Safety
///
/// This function is compiled with the `avx2` target feature enabled, so the
/// caller must ensure the running CPU supports AVX2.
#[allow(unused_attributes)]
#[target_feature(enable = "avx2")]
unsafe fn find_impl(&self, haystack: &[u8]) -> Option<Match> {
    // `find3` starts two bytes in, so the longest fingerprint needs
    // `BLOCK_SIZE + 2` bytes before a full block can be loaded.
    let too_short = haystack.len() < BLOCK_SIZE + 2;
    if too_short {
        return self.slow(haystack, 0);
    }
    let nmasks = self.masks.len();
    match nmasks {
        0 => None,
        1 => self.find1(haystack),
        2 => self.find2(haystack),
        3 => self.find3(haystack),
        _ => unreachable!(),
    }
}
/// Search routine used when a single mask is in use (one-byte fingerprint).
///
/// Walks `haystack` in `BLOCK_SIZE`-byte steps starting at offset 0, passes
/// every flagged block to `verify`, and hands whatever remains after the last
/// full block to `slow`.
///
/// Expects `haystack.len() >= BLOCK_SIZE`; `find_impl` routes anything shorter
/// than `BLOCK_SIZE + 2` to `slow` before calling this.
#[inline(always)]
fn find1(&self, haystack: &[u8]) -> Option<Match> {
let mut pos = 0;
let zero = self.vb.u8x32_splat(0);
let len = haystack.len();
// Guards the `len - BLOCK_SIZE` subtraction below against underflow.
debug_assert!(len >= BLOCK_SIZE);
while pos <= len - BLOCK_SIZE {
let h = unsafe {
// SAFETY: the loop condition gives `pos + BLOCK_SIZE <= len`, so `pos..`
// is in bounds and holds at least `BLOCK_SIZE` bytes.
// NOTE(review): assumes the unchecked load reads no more than
// `BLOCK_SIZE` bytes from the slice -- confirm in `vector::avx2`.
let p = haystack.get_unchecked(pos..);
self.vb.u8x32_load_unchecked_unaligned(p)
};
// Per-byte lookup result. `verify` reads each byte as a set of bucket bits,
// so a nonzero byte marks a candidate match position.
let res0 = self.masks.members1(h);
// `verify` treats bit `i` of this value as "byte `i` of the block is a
// candidate".
let bitfield = res0.ne(zero).movemask();
if bitfield != 0 {
if let Some(m) = self.verify(haystack, pos, res0, bitfield) {
return Some(m);
}
}
pos += BLOCK_SIZE;
}
// Fewer than `BLOCK_SIZE` bytes remain from `pos`; finish with the
// automaton.
self.slow(haystack, pos)
}
/// Search routine used when two masks are in use (two-byte fingerprint).
///
/// Blocks are loaded starting at offset 1, and candidate offsets are reported
/// relative to `pos - 1`, i.e. one byte before the loaded block. The remainder
/// after the last full block is handed to `slow`.
///
/// Expects `haystack.len() >= BLOCK_SIZE`; `find_impl` routes anything shorter
/// than `BLOCK_SIZE + 2` to `slow` before calling this.
#[inline(always)]
fn find2(&self, haystack: &[u8]) -> Option<Match> {
let zero = self.vb.u8x32_splat(0);
let len = haystack.len();
// First-mask result carried over from the previous block. It starts with
// every bit set.
// NOTE(review): presumably so that whatever `alignr_15` pulls in from this
// vector on the first iteration cannot rule a candidate out -- confirm
// against `alignr_15`.
let mut prev0 = self.vb.u8x32_splat(0xFF);
let mut pos = 1;
// Guards the `len - BLOCK_SIZE` subtraction below against underflow.
debug_assert!(len >= BLOCK_SIZE);
while pos <= len - BLOCK_SIZE {
let h = unsafe {
// SAFETY: the loop condition gives `pos + BLOCK_SIZE <= len`, so `pos..`
// is in bounds and holds at least `BLOCK_SIZE` bytes.
// NOTE(review): assumes the unchecked load reads no more than
// `BLOCK_SIZE` bytes from the slice -- confirm in `vector::avx2`.
let p = haystack.get_unchecked(pos..);
self.vb.u8x32_load_unchecked_unaligned(p)
};
// `res0` and `res1` are the lookups of this block in the first and second
// mask respectively.
let (res0, res1) = self.masks.members2(h);
// NOTE(review): `alignr_15` presumably lines up the first-mask result for
// each byte with the second-mask result for the byte that follows it,
// using `prev0` to supply what falls before this block -- confirm.
let res0prev0 = res0.alignr_15(prev0);
// A bucket bit survives only if both fingerprint bytes agreed with it.
let res = res0prev0.and(res1);
prev0 = res0;
let bitfield = res.ne(zero).movemask();
if bitfield != 0 {
// Candidate offsets in `res` are relative to one byte before the block.
let pos = pos.checked_sub(1).unwrap();
if let Some(m) = self.verify(haystack, pos, res, bitfield) {
return Some(m);
}
}
pos += BLOCK_SIZE;
}
// Resume one byte early so the automaton also covers the byte preceding
// `pos`. `pos >= 1` always holds here, so the subtraction cannot fail.
self.slow(haystack, pos.checked_sub(1).unwrap())
}
/// Search routine used when three masks are in use (three-byte fingerprint).
///
/// Blocks are loaded starting at offset 2, and candidate offsets are reported
/// relative to `pos - 2`, i.e. two bytes before the loaded block. The
/// remainder after the last full block is handed to `slow`.
///
/// Unlike `find1` and `find2` there is no `debug_assert!` on the length here;
/// the `len - BLOCK_SIZE` subtraction relies on `find_impl` having routed any
/// haystack shorter than `BLOCK_SIZE + 2` to `slow`.
#[inline(always)]
fn find3(&self, haystack: &[u8]) -> Option<Match> {
let zero = self.vb.u8x32_splat(0);
let len = haystack.len();
// First- and second-mask results carried over from the previous block. Both
// start with every bit set.
// NOTE(review): presumably so that whatever `alignr_14` / `alignr_15` pull
// in from them on the first iteration cannot rule a candidate out --
// confirm against those helpers.
let mut prev0 = self.vb.u8x32_splat(0xFF);
let mut prev1 = self.vb.u8x32_splat(0xFF);
let mut pos = 2;
while pos <= len - BLOCK_SIZE {
let h = unsafe {
// SAFETY: the loop condition gives `pos + BLOCK_SIZE <= len`, so `pos..`
// is in bounds and holds at least `BLOCK_SIZE` bytes.
// NOTE(review): assumes the unchecked load reads no more than
// `BLOCK_SIZE` bytes from the slice -- confirm in `vector::avx2`.
let p = haystack.get_unchecked(pos..);
self.vb.u8x32_load_unchecked_unaligned(p)
};
// Lookups of this block in the first, second and third mask.
let (res0, res1, res2) = self.masks.members3(h);
// NOTE(review): `alignr_14` / `alignr_15` presumably shift the first- and
// second-mask results so that, per lane, all three results refer to the
// same candidate start, with `prev0` / `prev1` supplying what falls before
// this block -- confirm.
let res0prev0 = res0.alignr_14(prev0);
let res1prev1 = res1.alignr_15(prev1);
// A bucket bit survives only if all three fingerprint bytes agreed with it.
let res = res0prev0.and(res1prev1).and(res2);
prev0 = res0;
prev1 = res1;
let bitfield = res.ne(zero).movemask();
if bitfield != 0 {
// Candidate offsets in `res` are relative to two bytes before the block.
let pos = pos.checked_sub(2).unwrap();
if let Some(m) = self.verify(haystack, pos, res, bitfield) {
return Some(m);
}
}
pos += BLOCK_SIZE;
}
// Resume two bytes early so the automaton also covers the bytes preceding
// `pos`. `pos >= 2` always holds here, so the subtraction cannot fail.
self.slow(haystack, pos.checked_sub(2).unwrap())
}
/// Confirms the candidates flagged for one block.
///
/// `bitfield` has one bit per byte position of the block; `res` holds, for
/// each position, the set of buckets whose fingerprint matched there. `pos`
/// is the haystack offset that position 0 corresponds to. Candidates are
/// tried in increasing position, and within a position in increasing bucket
/// number; the first confirmed match is returned.
#[inline(always)]
fn verify(
    &self,
    haystack: &[u8],
    pos: usize,
    res: u8x32,
    mut bitfield: u32,
) -> Option<Match> {
    while bitfield != 0 {
        let lane = bitfield.trailing_zeros() as usize;
        // Drop the lowest set bit, i.e. the lane being handled now.
        bitfield &= bitfield - 1;
        let candidate = pos + lane;

        let mut pending = res.extract(lane);
        while pending != 0 {
            let bucket = pending.trailing_zeros() as usize;
            // Drop the lowest set bit, i.e. the bucket being handled now.
            pending &= pending - 1;
            let hit = self.verify_bucket(haystack, bucket, candidate);
            if hit.is_some() {
                return hit;
            }
        }
    }
    None
}
/// Compares every pattern in `bucket` against the haystack at `start` and
/// returns the first one that matches in full. Patterns that would extend
/// past the end of the haystack are skipped.
#[inline(always)]
fn verify_bucket(
    &self,
    haystack: &[u8],
    bucket: usize,
    start: usize,
) -> Option<Match> {
    for &pati in &self.buckets[bucket] {
        let pat = self.pats[pati].as_slice();
        let end = start + pat.len();
        // `get` yields `None` when `end` lies beyond the haystack, which
        // doubles as the "pattern does not fit" test.
        if haystack.get(start..end) == Some(pat) {
            return Some(Match {
                pat: pati,
                start: start,
                end: end,
            });
        }
    }
    None
}
/// Fallback search: runs the Aho-Corasick automaton over `haystack[pos..]`
/// and translates the offsets of any match back into offsets relative to the
/// start of `haystack`.
#[inline(never)]
fn slow(&self, haystack: &[u8], pos: usize) -> Option<Match> {
    let tail = &haystack[pos..];
    let found = self.ac.find(tail)?;
    Some(Match {
        pat: found.pattern(),
        start: pos + found.start(),
        end: pos + found.end(),
    })
}
}
/// A set of up to three `Mask`s, one per leading pattern byte that takes part
/// in the fingerprint.
#[derive(Debug, Clone)]
struct Masks {
/// Builder used to create the nybble-selecting constant in the `members*`
/// lookups.
vb: AVX2VectorBuilder,
/// Storage for three masks; `add` only fills the first `size` of them.
masks: [Mask; 3],
/// Number of masks in use, as returned by `len`.
size: usize,
}
impl Masks {
    /// Creates a set with `n` masks in use. All three slots start out zeroed.
    fn new(vb: AVX2VectorBuilder, n: usize) -> Masks {
        let blank = Mask::new(vb);
        Masks {
            vb: vb,
            masks: [blank; 3],
            size: n,
        }
    }

    /// Number of masks in use.
    fn len(&self) -> usize {
        self.size
    }

    /// Records the first `len()` bytes of `pat` under `bucket`, byte `i`
    /// going into mask `i`. Panics if `pat` is shorter than `len()`.
    fn add(&mut self, bucket: u8, pat: &[u8]) {
        let active = self.len();
        for (i, mask) in self.masks[..active].iter_mut().enumerate() {
            mask.add(bucket, pat[i]);
        }
    }

    /// Splits every byte of `haystack_block` into its low and high nybble,
    /// returned as `(low, high)`.
    #[inline(always)]
    fn nybbles(&self, haystack_block: u8x32) -> (u8x32, u8x32) {
        let low_bits = self.vb.u8x32_splat(0xF);
        let lows = haystack_block.and(low_bits);
        let highs = haystack_block.bit_shift_right_4().and(low_bits);
        (lows, highs)
    }

    /// Looks both nybble vectors up in mask `idx` and intersects the results.
    #[inline(always)]
    fn lookup(&self, idx: usize, lows: u8x32, highs: u8x32) -> u8x32 {
        let mask = &self.masks[idx];
        mask.lo.shuffle(lows).and(mask.hi.shuffle(highs))
    }

    /// Lookup of `haystack_block` in the first mask.
    #[inline(always)]
    fn members1(&self, haystack_block: u8x32) -> u8x32 {
        let (lows, highs) = self.nybbles(haystack_block);
        self.lookup(0, lows, highs)
    }

    /// Lookups of `haystack_block` in the first and second mask.
    #[inline(always)]
    fn members2(&self, haystack_block: u8x32) -> (u8x32, u8x32) {
        let (lows, highs) = self.nybbles(haystack_block);
        let first = self.lookup(0, lows, highs);
        let second = self.lookup(1, lows, highs);
        (first, second)
    }

    /// Lookups of `haystack_block` in the first, second and third mask.
    #[inline(always)]
    fn members3(&self, haystack_block: u8x32) -> (u8x32, u8x32, u8x32) {
        let (lows, highs) = self.nybbles(haystack_block);
        let first = self.lookup(0, lows, highs);
        let second = self.lookup(1, lows, highs);
        let third = self.lookup(2, lows, highs);
        (first, second, third)
    }
}
/// A pair of lookup tables for one fingerprint byte. `add` indexes `lo` by a
/// byte's low nybble and `hi` by its high nybble; each entry is a set of
/// bucket bits, written at both index `n` and index `n + 16`.
#[derive(Debug, Clone, Copy)]
struct Mask {
/// Table indexed by the low four bits of a byte.
lo: u8x32,
/// Table indexed by the high four bits of a byte.
hi: u8x32,
}
impl Mask {
fn new(vb: AVX2VectorBuilder) -> Mask {
Mask {
lo: vb.u8x32_splat(0),
hi: vb.u8x32_splat(0),
}
}
fn add(&mut self, bucket: u8, byte: u8) {
let byte_lo = (byte & 0xF) as usize;
let byte_hi = (byte >> 4) as usize;
let lo = self.lo.extract(byte_lo) | ((1 << bucket) as u8);
self.lo.replace(byte_lo, lo);
self.lo.replace(byte_lo + 16, lo);
let hi = self.hi.extract(byte_hi) | ((1 << bucket) as u8);
self.hi.replace(byte_hi, hi);
self.hi.replace(byte_hi + 16, hi);
}
}