use std::arch::aarch64::*;
use crate::prefilter::{case_needle, scalar};
/// Loads 16 bytes of `haystack` into a NEON register for byte-membership tests.
///
/// Every lane of the result holds a byte taken from `haystack`; no padding
/// value is ever introduced, so a lane can only compare equal to a byte that
/// really occurs in the haystack.
///
/// * `len` in `8..=15`: the low half is the first eight bytes and the high
///   half is the last eight bytes (`haystack[len - 8..]`). The two halves
///   overlap, and for `len == 8` they are identical. `start` is ignored.
/// * `len >= 16`: a 16-byte window beginning at `start`, clamped to
///   `len - 16` so the window never runs past the end. A clamped window
///   re-reads bytes that an earlier window already covered.
///
/// # Panics
///
/// Panics if `len < 8`; callers are expected to handle such haystacks
/// without SIMD.
///
/// # Safety
///
/// `len` must satisfy `8 <= len <= haystack.len()`; the unchecked loads read
/// 8 or 16 bytes relative to that length.
#[inline(always)]
unsafe fn overlapping_load(haystack: &[u8], start: usize, len: usize) -> uint8x16_t {
    debug_assert!(len <= haystack.len(), "len must not exceed the haystack");
    // SAFETY: the caller guarantees `8 <= len <= haystack.len()`. Each load
    // therefore starts at an offset that leaves at least 8 (`vld1_u8`) or
    // 16 (`vld1q_u8`) readable bytes inside the slice; the slicing itself is
    // bounds-checked.
    unsafe {
        match len {
            0..=7 => unreachable!(),
            8..=15 => {
                let low = vld1_u8(haystack.as_ptr());
                // Fill the upper half from the tail instead of zero-padding,
                // so a NUL needle byte cannot match lanes that are not part
                // of the haystack.
                let high = vld1_u8(haystack[len - 8..].as_ptr());
                vcombine_u8(low, high)
            }
            // For `len == 16` the clamp yields offset 0, i.e. the whole slice.
            _ => vld1q_u8(haystack[start.min(len - 16)..].as_ptr()),
        }
    }
}
/// Byte-presence prefilter implemented with aarch64 NEON intrinsics.
///
/// Stores the needle in the form returned by `case_needle`; the
/// `match_haystack` and `match_haystack_typos` methods test 16-byte haystack
/// chunks against it.
#[derive(Debug, Clone)]
pub struct PrefilterNEON {
/// One `(u8, u8)` pair per needle position, as returned by `case_needle`.
/// The matchers accept a haystack byte equal to either element of a pair.
/// NOTE(review): presumably the two case variants of the needle byte —
/// confirm against `case_needle`.
needle: Vec<(u8, u8)>,
}
impl PrefilterNEON {
    /// Builds a prefilter for `needle`.
    ///
    /// The needle is converted once by `case_needle` into one `(u8, u8)` pair
    /// per position, so the matchers do no per-call needle preparation beyond
    /// broadcasting each pair into vector registers.
    #[inline]
    pub fn new(needle: &[u8]) -> Self {
        Self {
            needle: case_needle(needle),
        }
    }

    /// Checks whether `haystack` may contain the needle, without allowing typos.
    ///
    /// The haystack is visited in 16-byte steps. Within a chunk, needle
    /// positions are consumed for as long as the current position's byte
    /// (either element of its pair) occurs anywhere in the chunk; the scan
    /// then moves to the next chunk. Only membership is tested, so the order
    /// of bytes inside a chunk is not checked and the result can be `true`
    /// for a haystack that does not contain the needle in order.
    ///
    /// Returns `(matched, skipped_chunks)`:
    /// * `matched` is `true` once every needle position has been consumed.
    /// * `skipped_chunks` is the index (`start / 16`) of the chunk in which
    ///   the first needle position was found. It is recorded when the scan
    ///   advances to the second position, so it stays `0` for a
    ///   single-position needle.
    ///
    /// Special cases: an empty haystack yields `(true, 0)`; haystacks of 1–7
    /// bytes are delegated to `scalar::match_haystack`; an empty needle on a
    /// longer haystack yields `(true, 0)` instead of panicking.
    ///
    /// # Safety
    ///
    /// The CPU must support the `neon` target feature.
    #[inline]
    #[target_feature(enable = "neon")]
    pub unsafe fn match_haystack(&self, haystack: &[u8]) -> (bool, usize) {
        let len = haystack.len();
        match len {
            0 => return (true, 0),
            // Too short for an 8-byte vector load: use the scalar fallback.
            1..=7 => return (scalar::match_haystack(&self.needle, haystack), 0),
            _ => {}
        }

        let mut can_skip_chunks = true;
        let mut skipped_chunks = 0;

        // Broadcast each needle pair into two 16-lane vectors on demand.
        let mut needle_iter = self
            .needle
            .iter()
            .map(|&(c1, c2)| (vdupq_n_u8(c1), vdupq_n_u8(c2)));
        // An empty needle has nothing left to find, so it matches trivially.
        let Some(mut needle_char) = needle_iter.next() else {
            return (true, 0);
        };

        for start in (0..len).step_by(16) {
            // SAFETY: `len == haystack.len()` and `len >= 8`, because shorter
            // haystacks returned above.
            let haystack_chunk = unsafe { overlapping_load(haystack, start, len) };

            loop {
                // Non-zero when any lane equals either element of the pair.
                let mask = vmaxvq_u8(vorrq_u8(
                    vceqq_u8(needle_char.1, haystack_chunk),
                    vceqq_u8(needle_char.0, haystack_chunk),
                ));
                if mask == 0 {
                    // Current position is absent from this chunk; try the next chunk.
                    break;
                }

                let Some(next_needle_char) = needle_iter.next() else {
                    // Every needle position has been consumed.
                    return (true, skipped_chunks);
                };
                if can_skip_chunks {
                    // First advance: remember the chunk holding the first position.
                    skipped_chunks = start / 16;
                }
                can_skip_chunks = false;
                needle_char = next_needle_char;
            }
        }

        (false, skipped_chunks)
    }

    /// Checks whether `haystack` may contain the needle with at most
    /// `max_typos` needle positions missing.
    ///
    /// Each pass scans every 16-byte chunk exactly like [`Self::match_haystack`].
    /// If a pass ends with needle positions still unconsumed, one typo is
    /// counted, the current position is skipped, and the haystack is scanned
    /// again from its beginning with the following position.
    ///
    /// Returns `(matched, 0)`; this variant never reports skipped chunks.
    ///
    /// Special cases: an empty haystack yields `(true, 0)`; haystacks of 1–7
    /// bytes are delegated to `scalar::match_haystack_typos`; for longer
    /// haystacks `max_typos >= 3` yields `(true, 0)` without scanning, and an
    /// empty needle yields `(true, 0)` instead of panicking.
    ///
    /// # Safety
    ///
    /// The CPU must support the `neon` target feature.
    #[inline]
    #[target_feature(enable = "neon")]
    pub unsafe fn match_haystack_typos(&self, haystack: &[u8], max_typos: u16) -> (bool, usize) {
        let len = haystack.len();
        match len {
            0 => return (true, 0),
            // Too short for an 8-byte vector load: use the scalar fallback.
            1..=7 => {
                return (
                    scalar::match_haystack_typos(&self.needle, haystack, max_typos),
                    0,
                );
            }
            _ => {}
        }

        // With this many typos allowed the prefilter accepts everything.
        if max_typos >= 3 {
            return (true, 0);
        }
        let typo_budget = usize::from(max_typos);

        // Broadcast each needle pair into two 16-lane vectors on demand.
        let mut needle_iter = self
            .needle
            .iter()
            .map(|&(c1, c2)| (vdupq_n_u8(c1), vdupq_n_u8(c2)));
        // An empty needle has nothing left to find, so it matches trivially.
        let Some(mut needle_char) = needle_iter.next() else {
            return (true, 0);
        };

        let mut typos = 0usize;
        loop {
            for start in (0..len).step_by(16) {
                // SAFETY: `len == haystack.len()` and `len >= 8`, because
                // shorter haystacks returned above.
                let haystack_chunk = unsafe { overlapping_load(haystack, start, len) };

                loop {
                    // Non-zero when any lane equals either element of the pair.
                    let mask = vmaxvq_u8(vorrq_u8(
                        vceqq_u8(needle_char.1, haystack_chunk),
                        vceqq_u8(needle_char.0, haystack_chunk),
                    ));
                    if mask == 0 {
                        break;
                    }

                    let Some(next_needle_char) = needle_iter.next() else {
                        return (true, 0);
                    };
                    needle_char = next_needle_char;
                }
            }

            // The pass ended with the current position unmatched: count a typo.
            typos += 1;
            if typos > typo_budget {
                return (false, 0);
            }

            // Skip the unmatched position and rescan from the first chunk.
            let Some(next_needle_char) = needle_iter.next() else {
                return (true, 0);
            };
            needle_char = next_needle_char;
        }
    }
}