use super::SearchResult;
use super::common::{
advance_step, build_skip_table, early_search_result, find_match_in_mask, tail_match,
};
use std::arch::aarch64::*;
use std::collections::HashSet;
/// Number of bytes covered by one NEON vector load (`uint8x16_t` holds 16 `u8` lanes).
/// Used to decide when the remaining haystack is too short for a full vector load.
const LANE_BYTES: usize = 16;
/// Finds the first occurrence of `needle` in `haystack` using NEON-accelerated scanning.
///
/// `early_search_result` gets the first chance to answer (it is handed the single-byte
/// searcher; presumably it covers trivial needles — confirm against `common`). Otherwise the
/// haystack is walked window by window: full 16-byte windows go through the vector path, and
/// windows too close to the end fall back to `tail_match`. The step between windows is
/// decided by `advance_step` using the precomputed skip table.
///
/// # Safety
///
/// The CPU must support NEON.
#[target_feature(enable = "neon")]
pub unsafe fn search(haystack: &[u8], needle: &[u8]) -> SearchResult {
    unsafe {
        if let Some(shortcut) = early_search_result(haystack, needle, search_single_byte_neon) {
            return shortcut;
        }

        let skip_table = build_skip_table(needle);
        let needle_len = needle.len();
        // Broadcast the first needle byte so one compare tests 16 start positions at once.
        let first_vec = vdupq_n_u8(needle[0]);

        let hay_len = haystack.len();
        // Last position at which the needle could still begin.
        let last_start = hay_len.saturating_sub(needle_len);

        let mut pos = 0;
        while pos <= last_start {
            // Fewer than LANE_BYTES bytes remain: a full vector load would overrun.
            let at_tail = pos + LANE_BYTES > hay_len;
            match scan_window_neon(haystack, needle, pos, needle_len, first_vec, at_tail) {
                Some(found) => return Some(found),
                None => pos += advance_step(haystack, needle_len, pos, &skip_table, at_tail),
            }
        }
        None
    }
}
/// Examines the window starting at `pos`, choosing between the tail path and the vector path.
///
/// When `at_tail` is set the window is delegated to `tail_match`; otherwise the 16-byte
/// vector scan in `scan_candidates_neon` is used. Returns the match position, if any.
///
/// # Safety
///
/// The CPU must support NEON, and when `at_tail` is `false` at least 16 bytes must be
/// readable from `haystack[pos..]`.
#[target_feature(enable = "neon")]
unsafe fn scan_window_neon(
    haystack: &[u8],
    needle: &[u8],
    pos: usize,
    needle_len: usize,
    first_vec: uint8x16_t,
    at_tail: bool,
) -> Option<usize> {
    unsafe {
        if at_tail {
            tail_match(haystack, needle, pos, needle_len)
        } else {
            scan_candidates_neon(haystack, needle, pos, needle_len, first_vec)
        }
    }
}
/// Loads 16 haystack bytes at `pos`, marks the lanes equal to the needle's first byte, and
/// passes the resulting bitmask to `find_match_in_mask` together with `verify_match_neon`.
///
/// Returns `None` immediately when no lane matches the first byte.
///
/// # Safety
///
/// The CPU must support NEON and `haystack[pos..]` must contain at least 16 bytes, because a
/// full 128-bit load is issued from that address.
#[target_feature(enable = "neon")]
unsafe fn scan_candidates_neon(
    haystack: &[u8],
    needle: &[u8],
    pos: usize,
    needle_len: usize,
    first_vec: uint8x16_t,
) -> Option<usize> {
    unsafe {
        let window = vld1q_u8(haystack[pos..].as_ptr());
        // One bit per lane whose byte equals the first needle byte.
        let lane_hits = neon_movemask(vceqq_u8(window, first_vec));
        match lane_hits {
            0 => None,
            mask => find_match_in_mask(
                haystack,
                needle,
                pos,
                needle_len,
                mask as u32,
                LANE_BYTES,
                verify_match_neon,
            ),
        }
    }
}
#[inline]
unsafe fn neon_movemask(v: uint8x16_t) -> u16 {
unsafe {
let shift = vshrq_n_u8::<7>(v);
let bytes: [u8; 16] = std::mem::transmute(shift);
let mut result = 0u16;
for i in 0..16 {
result |= ((bytes[i] & 1) as u16) << i;
}
result
}
}
/// Returns `true` when `needle` occurs in `haystack` starting exactly at `pos`.
///
/// The comparison runs 16 bytes at a time with NEON and finishes any remainder shorter than
/// 16 bytes with a plain slice comparison.
///
/// A candidate whose window `pos..pos + needle.len()` does not fit inside `haystack`
/// (including arithmetic overflow of the end index) is reported as a non-match. Without this
/// guard a 16-byte vector load could read past the end of `haystack`, or the trailing slice
/// comparison could panic.
///
/// # Safety
///
/// The CPU must support NEON.
#[target_feature(enable = "neon")]
unsafe fn verify_match_neon(haystack: &[u8], needle: &[u8], pos: usize) -> bool {
    // Bounds-check the whole candidate window once, up front.
    let window = match pos
        .checked_add(needle.len())
        .and_then(|end| haystack.get(pos..end))
    {
        Some(window) => window,
        None => return false,
    };

    let mut offset = 0;
    while offset + 16 <= needle.len() {
        // SAFETY: `window` and `needle` have the same length and `offset + 16` does not
        // exceed it, so both 16-byte loads stay inside their slices.
        let (low, high) = unsafe {
            let hay_vec = vld1q_u8(window[offset..].as_ptr());
            let needle_vec = vld1q_u8(needle[offset..].as_ptr());
            // Equal lanes become 0xFF; view the result as two u64 halves to test all at once.
            let cmp64 = vreinterpretq_u64_u8(vceqq_u8(hay_vec, needle_vec));
            (vgetq_lane_u64::<0>(cmp64), vgetq_lane_u64::<1>(cmp64))
        };
        if low != u64::MAX || high != u64::MAX {
            return false;
        }
        offset += 16;
    }

    // Fewer than 16 bytes remain; compare them directly.
    window[offset..] == needle[offset..]
}
/// Finds the first index of `byte` in `haystack`.
///
/// Full 16-byte blocks are compared in one NEON operation each; the lowest set bit of the
/// resulting lane mask gives the offset within the block. Any trailing bytes that do not fill
/// a whole block are scanned one by one.
///
/// # Safety
///
/// The CPU must support NEON.
#[target_feature(enable = "neon")]
unsafe fn search_single_byte_neon(haystack: &[u8], byte: u8) -> SearchResult {
    unsafe {
        let target = vdupq_n_u8(byte);
        let mut blocks = haystack.chunks_exact(16);
        // Index in `haystack` of the block currently being examined.
        let mut base = 0;
        for block in blocks.by_ref() {
            let hits = neon_movemask(vceqq_u8(vld1q_u8(block.as_ptr()), target));
            if hits != 0 {
                return Some(base + hits.trailing_zeros() as usize);
            }
            base += 16;
        }
        blocks
            .remainder()
            .iter()
            .position(|&candidate| candidate == byte)
            .map(|within| base + within)
    }
}
/// Returns `true` when every byte in `bytes` is ASCII (below `0x80`).
///
/// Full 16-byte blocks are checked with a single horizontal maximum: if the largest byte in
/// the block is below `0x80`, all of them are. This avoids building a per-lane bitmask just to
/// test it for zero. Trailing bytes that do not fill a block are checked with the scalar
/// slice method. An empty slice is ASCII.
///
/// # Safety
///
/// The CPU must support NEON.
#[target_feature(enable = "neon")]
unsafe fn is_ascii_neon(bytes: &[u8]) -> bool {
    let mut blocks = bytes.chunks_exact(16);
    for block in blocks.by_ref() {
        // SAFETY: `chunks_exact(16)` yields slices of exactly 16 bytes, so the 128-bit load
        // stays in bounds.
        let largest = unsafe { vmaxvq_u8(vld1q_u8(block.as_ptr())) };
        if largest >= 0x80 {
            return false;
        }
    }
    blocks.remainder().is_ascii()
}
/// Splits `text` into its distinct three-character substrings, in order of first appearance.
///
/// Inputs shorter than three bytes, or with fewer than three characters, are returned whole
/// as a single-element vector. Pure-ASCII input is delegated to
/// `super::common::extract_trigrams_ascii_fast`; anything else is processed per `char`, with
/// repeated trigrams dropped.
///
/// # Safety
///
/// The CPU must support NEON.
#[target_feature(enable = "neon")]
pub unsafe fn extract_trigrams(text: &str) -> Vec<String> {
    if text.len() < 3 {
        return vec![text.to_owned()];
    }

    let all_ascii = unsafe { is_ascii_neon(text.as_bytes()) };
    if all_ascii {
        return super::common::extract_trigrams_ascii_fast(text);
    }

    // Multi-byte characters: work on decoded chars so trigrams never split a code point.
    let chars: Vec<char> = text.chars().collect();
    if chars.len() < 3 {
        return vec![text.to_owned()];
    }

    let mut seen = HashSet::new();
    let mut trigrams = Vec::with_capacity(chars.len() - 2);
    trigrams.extend(
        chars
            .windows(3)
            .map(|window| window.iter().collect::<String>())
            // `insert` returns false for a trigram already emitted.
            .filter(|trigram| seen.insert(trigram.clone())),
    );
    trigrams
}
/// Returns a copy of `text` with the ASCII letters `A`–`Z` converted to lowercase.
///
/// All other bytes, including those of multi-byte UTF-8 sequences, are copied unchanged.
/// Full 16-byte blocks are converted with NEON; the trailing bytes are converted one by one.
///
/// # Safety
///
/// The CPU must support NEON.
#[target_feature(enable = "neon")]
pub unsafe fn to_lowercase_ascii(text: &str) -> String {
    unsafe {
        let source = text.as_bytes();
        let mut lowered: Vec<u8> = Vec::with_capacity(source.len());

        let range_start = vdupq_n_u8(b'A');
        let range_end = vdupq_n_u8(b'Z');
        // Distance between an uppercase ASCII letter and its lowercase form.
        let case_delta = vdupq_n_u8(32);

        let mut blocks = source.chunks_exact(16);
        for block in blocks.by_ref() {
            let lanes = vld1q_u8(block.as_ptr());
            // 0xFF in lanes holding 'A'..='Z', 0x00 elsewhere.
            let uppercase = vandq_u8(vcgeq_u8(lanes, range_start), vcleq_u8(lanes, range_end));
            // Add 32 only to the uppercase lanes.
            let converted = vaddq_u8(lanes, vandq_u8(uppercase, case_delta));

            let mut staged = [0u8; 16];
            vst1q_u8(staged.as_mut_ptr(), converted);
            lowered.extend_from_slice(&staged);
        }
        lowered.extend(blocks.remainder().iter().map(u8::to_ascii_lowercase));

        // Only ASCII letters were rewritten to other ASCII letters, so the bytes are still
        // the valid UTF-8 that `text` started as.
        String::from_utf8_unchecked(lowered)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Every function under test is an `unsafe fn` gated on the `neon` target feature, so each
    // call is wrapped in its own `unsafe` expression.

    #[test]
    fn test_neon_search_basic() {
        let cases: [(&[u8], &[u8], Option<usize>); 2] = [
            (b"hello world", b"world", Some(6)),
            (b"hello", b"xyz", None),
        ];
        for (haystack, needle, expected) in cases {
            assert_eq!(unsafe { search(haystack, needle) }, expected);
        }
    }

    #[test]
    fn test_neon_search_single_byte() {
        let cases: [(&[u8], Option<usize>); 3] =
            [(b"h", Some(0)), (b"o", Some(4)), (b"x", None)];
        for (needle, expected) in cases {
            assert_eq!(unsafe { search(b"hello", needle) }, expected);
        }
    }

    #[test]
    fn test_neon_search_long_haystack() {
        // Longer than one 16-byte vector so the SIMD window loop is exercised.
        let haystack = b"abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let lower = unsafe { search(haystack, b"xyz") };
        let upper = unsafe { search(haystack, b"XYZ") };
        assert_eq!(lower, Some(23));
        assert_eq!(upper, Some(59));
    }

    #[test]
    fn test_neon_search_repeated_pattern() {
        let first = unsafe { search(b"aaaaaaaaaa", b"aa") };
        assert_eq!(first, Some(0));
    }

    #[test]
    fn test_neon_lowercase_basic() {
        for (input, expected) in [("HELLO", "hello"), ("HeLLo", "hello"), ("hello", "hello")] {
            assert_eq!(unsafe { to_lowercase_ascii(input) }, expected);
        }
    }

    #[test]
    fn test_neon_lowercase_long_string() {
        let input = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let expected = "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz";
        let lowered = unsafe { to_lowercase_ascii(input) };
        assert_eq!(lowered, expected);
    }

    #[test]
    fn test_neon_lowercase_mixed() {
        let lowered = unsafe { to_lowercase_ascii("ABC123xyz") };
        assert_eq!(lowered, "abc123xyz");
    }

    #[test]
    fn test_neon_trigram_basic() {
        let mut trigrams = unsafe { extract_trigrams("hello") };
        trigrams.sort();
        assert_eq!(trigrams, ["ell", "hel", "llo"]);
    }

    #[test]
    fn test_neon_trigram_short() {
        // Inputs of at most three characters come back as a single entry.
        for short in ["ab", "abc"] {
            assert_eq!(unsafe { extract_trigrams(short) }, [short]);
        }
    }

    #[test]
    fn test_neon_trigram_ascii_fast_path() {
        let mut trigrams = unsafe { extract_trigrams("abcdef") };
        trigrams.sort();
        assert_eq!(trigrams, ["abc", "bcd", "cde", "def"]);
    }

    #[test]
    fn test_neon_trigram_ascii_long() {
        let input = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOP";
        let trigrams = unsafe { extract_trigrams(input) };
        assert_eq!(trigrams.len(), 50);
        assert_eq!(trigrams.first().map(String::as_str), Some("abc"));
        assert_eq!(trigrams.last().map(String::as_str), Some("NOP"));
    }

    #[test]
    fn test_neon_trigram_non_ascii_fallback() {
        let mut trigrams = unsafe { extract_trigrams("héllo") };
        trigrams.sort();
        assert_eq!(trigrams.len(), 3);
        for expected in ["hél", "éll", "llo"] {
            assert!(trigrams.iter().any(|trigram| trigram.as_str() == expected));
        }
    }

    #[test]
    fn test_neon_trigram_dedup() {
        let trigrams = unsafe { extract_trigrams("aaaa") };
        assert_eq!(trigrams, ["aaa"]);
    }

    #[test]
    fn test_neon_is_ascii() {
        let ascii_inputs: [&[u8]; 3] = [
            b"hello world",
            b"",
            b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
        ];
        for input in ascii_inputs {
            assert!(unsafe { is_ascii_neon(input) });
        }

        for input in ["héllo", "hello 世界", "abcdefghijklmnopé"] {
            assert!(!unsafe { is_ascii_neon(input.as_bytes()) });
        }
    }
}