use core::mem::size_of;
use crate::memmem::{util::memcmp, vector::Vector, NeedleInfo};
/// Shortest needle this searcher handles: `Forward::new` returns `None` for
/// any needle shorter than this.
pub(crate) const MIN_NEEDLE_LEN: usize = 2;
/// Longest needle this searcher handles: `Forward::new` returns `None` for
/// any needle longer than this.
pub(crate) const MAX_NEEDLE_LEN: usize = 32;
/// Precomputed state for a forward generic-SIMD substring search.
///
/// Holds two distinct byte positions within the needle, obtained from the
/// rare-byte analysis in `NeedleInfo`. The search compares those two needle
/// bytes against many haystack positions at once as a prefilter.
#[derive(Clone, Copy, Debug)]
pub(crate) struct Forward {
    /// Index into the needle of the first rare byte.
    rare1i: u8,
    /// Index into the needle of the second rare byte.
    rare2i: u8,
}

impl Forward {
    /// Builds a forward searcher for `needle`.
    ///
    /// Returns `None` when this strategy is not applicable. That happens
    /// when the needle length lies outside
    /// `MIN_NEEDLE_LEN..=MAX_NEEDLE_LEN`, or when the two rare-byte
    /// positions reported by `ninfo` are the same index.
    pub(crate) fn new(ninfo: &NeedleInfo, needle: &[u8]) -> Option<Forward> {
        let (rare1i, rare2i) = ninfo.rarebytes.as_rare_ordered_u8();
        let len = needle.len();
        let len_supported = MIN_NEEDLE_LEN <= len && len <= MAX_NEEDLE_LEN;
        if !len_supported || rare1i == rare2i {
            return None;
        }
        Some(Forward { rare1i, rare2i })
    }

    /// Smallest haystack length such that one vector of type `V` can be
    /// loaded starting `rare2i` bytes into the haystack without reading
    /// past its end.
    ///
    /// NOTE(review): this bounds only the `rare2i` load. It covers the
    /// `rare1i` load too only if `rare1i <= rare2i`, which
    /// `as_rare_ordered_u8` presumably guarantees. TODO confirm.
    #[inline(always)]
    pub(crate) fn min_haystack_len<V: Vector>(&self) -> usize {
        let vector_width = size_of::<V>();
        self.rare2i as usize + vector_width
    }
}
/// Returns the index of the first occurrence of `needle` in `haystack`, or
/// `None` if `needle` does not occur.
///
/// The search is a vectorized prefilter followed by verification:
/// - It works through `size_of::<V>()` consecutive candidate start
///   positions at a time.
/// - For each position, it checks whether the haystack bytes at offsets
///   `fwd.rare1i` and `fwd.rare2i` equal the corresponding needle bytes.
/// - Each surviving candidate is confirmed with a full `memcmp` of the
///   needle.
///
/// # Panics
///
/// Panics if `haystack.len() >= needle.len()` but `haystack` is shorter
/// than `fwd.min_haystack_len::<V>()`.
///
/// # Safety
///
/// NOTE(review): the contract is not spelled out in this file. Presumably
/// callers must ensure the CPU supports the SIMD instructions behind `V`.
/// TODO confirm.
///
/// Vector loads are bounded using `fwd.rare2i` only, so soundness also
/// relies on `fwd.rare1i <= fwd.rare2i`. This is presumably ensured by
/// `as_rare_ordered_u8` in `Forward::new`.
#[inline(always)]
pub(crate) unsafe fn fwd_find<V: Vector>(
    fwd: &Forward,
    haystack: &[u8],
    needle: &[u8],
) -> Option<usize> {
    // A haystack shorter than the needle cannot contain it. This check also
    // keeps `end_ptr.sub(needle.len())` in `fwd_find_in_chunk` from
    // pointing before the start of the haystack.
    if haystack.len() < needle.len() {
        return None;
    }
    let min_haystack_len = fwd.min_haystack_len::<V>();
    assert!(haystack.len() >= min_haystack_len, "haystack too small");
    debug_assert!(needle.len() <= haystack.len());
    debug_assert!(
        needle.len() >= MIN_NEEDLE_LEN,
        "needle must be at least {} bytes",
        MIN_NEEDLE_LEN,
    );
    debug_assert!(
        needle.len() <= MAX_NEEDLE_LEN,
        "needle must be at most {} bytes",
        MAX_NEEDLE_LEN,
    );
    // Broadcast each rare needle byte across every lane of a vector so one
    // compare tests it against many haystack positions.
    let (rare1i, rare2i) = (fwd.rare1i as usize, fwd.rare2i as usize);
    let rare1chunk = V::splat(needle[rare1i]);
    let rare2chunk = V::splat(needle[rare2i]);
    let start_ptr = haystack.as_ptr();
    let end_ptr = start_ptr.add(haystack.len());
    // Last chunk start at which a `V`-sized load at `ptr + rare2i` still
    // ends inside the haystack. It is in bounds because of the assert above.
    let max_ptr = end_ptr.sub(min_haystack_len);
    let mut ptr = start_ptr;
    // Main loop: each iteration examines the `size_of::<V>()` candidate
    // start positions beginning at `ptr`. The mask `!0` keeps every lane.
    while ptr <= max_ptr {
        let m = fwd_find_in_chunk(
            fwd, needle, ptr, end_ptr, rare1chunk, rare2chunk, !0,
        );
        if let Some(chunki) = m {
            return Some(matched(start_ptr, ptr, chunki));
        }
        ptr = ptr.add(size_of::<V>());
    }
    // Tail: fewer than a full chunk of unexamined positions remain. Run one
    // last chunk anchored at `max_ptr`. That chunk overlaps positions the
    // main loop already examined, so the lanes for those positions are
    // masked off.
    if ptr < end_ptr {
        let remaining = diff(end_ptr, ptr);
        debug_assert!(
            remaining < min_haystack_len,
            "remaining bytes should be smaller than the minimum haystack \
length of {}, but there are {} bytes remaining",
            min_haystack_len,
            remaining,
        );
        // No unexamined start position leaves room for the whole needle.
        if remaining < needle.len() {
            return None;
        }
        debug_assert!(
            max_ptr < ptr,
            "after main loop, ptr should have exceeded max_ptr",
        );
        // Number of positions in [max_ptr, ptr) that were already examined.
        let overlap = diff(ptr, max_ptr);
        debug_assert!(
            overlap > 0,
            "overlap ({}) must always be non-zero",
            overlap,
        );
        debug_assert!(
            overlap < size_of::<V>(),
            "overlap ({}) cannot possibly be >= than a vector ({})",
            overlap,
            size_of::<V>(),
        );
        // Every bit set except the lowest `overlap` bits. This discards
        // candidates at offsets the main loop already examined.
        let mask = !((1 << overlap) - 1);
        ptr = max_ptr;
        let m = fwd_find_in_chunk(
            fwd, needle, ptr, end_ptr, rare1chunk, rare2chunk, mask,
        );
        if let Some(chunki) = m {
            return Some(matched(start_ptr, ptr, chunki));
        }
    }
    None
}
/// Examines the `size_of::<V>()` candidate start positions beginning at
/// `ptr`, and returns the offset (relative to `ptr`) of the first one where
/// `needle` fully occurs.
///
/// A candidate qualifies for the full comparison only if:
/// - both rare bytes line up, and
/// - its bit survives `mask`.
///
/// Candidates are tried in increasing offset order. The function returns
/// `None` as soon as a candidate would leave too few bytes before `end_ptr`
/// to hold the whole needle.
///
/// # Safety
///
/// - `size_of::<V>()` bytes must be readable starting at both
///   `ptr + fwd.rare1i` and `ptr + fwd.rare2i`.
/// - `end_ptr.sub(needle.len())` must stay within the same allocation as
///   `ptr`.
#[inline(always)]
unsafe fn fwd_find_in_chunk<V: Vector>(
    fwd: &Forward,
    needle: &[u8],
    ptr: *const u8,
    end_ptr: *const u8,
    rare1chunk: V,
    rare2chunk: V,
    mask: u32,
) -> Option<usize> {
    // Lane i is all-ones when the haystack byte at `ptr + i + rareNi`
    // equals the corresponding rare needle byte.
    let rare1_hits =
        V::load_unaligned(ptr.add(fwd.rare1i as usize)).cmpeq(rare1chunk);
    let rare2_hits =
        V::load_unaligned(ptr.add(fwd.rare2i as usize)).cmpeq(rare2chunk);
    let mut candidates = rare1_hits.and(rare2_hits).movemask() & mask;
    while candidates != 0 {
        let offset = candidates.trailing_zeros() as usize;
        let candidate = ptr.add(offset);
        // Later candidates only start further right, so none of them can
        // fit the needle either.
        if candidate > end_ptr.sub(needle.len()) {
            return None;
        }
        let window = core::slice::from_raw_parts(candidate, needle.len());
        if memcmp(needle, window) {
            return Some(offset);
        }
        // Clear the lowest set bit to advance to the next candidate.
        candidates &= candidates - 1;
    }
    None
}
/// Turns a match at offset `chunki` within the chunk starting at `ptr` into
/// an index relative to `start_ptr`, the start of the haystack.
///
/// Every caller returns immediately with this value, so it runs at most once
/// per search. That is consistent with marking it cold and never-inlined.
#[cold]
#[inline(never)]
fn matched(start_ptr: *const u8, ptr: *const u8, chunki: usize) -> usize {
    let chunk_start = diff(ptr, start_ptr);
    chunk_start + chunki
}
/// Returns the distance in bytes from `b` forward to `a`.
///
/// `a` must not precede `b`. This is checked only in debug builds.
fn diff(a: *const u8, b: *const u8) -> usize {
    debug_assert!(a >= b);
    let (later, earlier) = (a as usize, b as usize);
    later - earlier
}