#![allow(clippy::wildcard_imports)] #![allow(clippy::ptr_as_ptr)] #![allow(clippy::implicit_hasher)]
use std::collections::HashSet;
/// A trigram: three consecutive bytes stored as a fixed-size array.
pub type Trigram = [u8; 3];
/// Capability tier used to choose a trigram-extraction routine.
///
/// Which variants exist depends on the compilation target; `Scalar` is
/// present on every target.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum TrigramSimdLevel {
/// x86_64 only: reported by `detect` when both `avx512f` and `avx512bw` are found.
#[cfg(target_arch = "x86_64")]
Avx512,
/// x86_64 only: reported by `detect` when `avx2` is found.
#[cfg(target_arch = "x86_64")]
Avx2,
/// aarch64 only.
#[cfg(target_arch = "aarch64")]
Neon,
/// Portable fallback, available on every target.
Scalar,
}
impl TrigramSimdLevel {
    /// Reports the most capable level available to the running program.
    ///
    /// On x86_64 the CPU is probed at runtime: `Avx512` is returned when both
    /// `avx512f` and `avx512bw` are detected, otherwise `Avx2` when `avx2` is
    /// detected. On aarch64 the answer is always `Neon`. Every other case
    /// yields `Scalar`.
    #[must_use]
    pub fn detect() -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw") {
                return Self::Avx512;
            }
            if is_x86_feature_detected!("avx2") {
                return Self::Avx2;
            }
        }

        // Exactly one of these bindings survives cfg-stripping. Selecting the
        // baseline this way avoids an unconditional `return` followed by an
        // unreachable trailing expression on aarch64.
        #[cfg(target_arch = "aarch64")]
        let baseline = Self::Neon;
        #[cfg(not(target_arch = "aarch64"))]
        let baseline = Self::Scalar;

        baseline
    }

    /// Returns a fixed, human-readable label for the level.
    #[must_use]
    #[allow(dead_code)]
    pub const fn name(self) -> &'static str {
        match self {
            #[cfg(target_arch = "x86_64")]
            Self::Avx512 => "AVX-512",
            #[cfg(target_arch = "x86_64")]
            Self::Avx2 => "AVX2",
            #[cfg(target_arch = "aarch64")]
            Self::Neon => "NEON",
            Self::Scalar => "Scalar",
        }
    }
}
#[must_use]
pub fn extract_trigrams_simd(text: &str) -> HashSet<Trigram> {
let level = TrigramSimdLevel::detect();
match level {
#[cfg(target_arch = "x86_64")]
TrigramSimdLevel::Avx512 => extract_trigrams_avx512(text),
#[cfg(target_arch = "x86_64")]
TrigramSimdLevel::Avx2 => extract_trigrams_avx2(text),
#[cfg(target_arch = "aarch64")]
TrigramSimdLevel::Neon => extract_trigrams_neon(text),
TrigramSimdLevel::Scalar => extract_trigrams_scalar(text),
}
}
/// Portable trigram extraction.
///
/// `text` is treated as if it were surrounded by two ASCII spaces on each
/// side, and every 3-byte window of that padded sequence is collected. The
/// windows are taken over raw bytes, not characters. An empty `text` yields
/// an empty set.
#[must_use]
pub fn extract_trigrams_scalar(text: &str) -> HashSet<Trigram> {
    const PAD: usize = 2;

    let raw = text.as_bytes();
    if raw.is_empty() {
        return HashSet::new();
    }

    // Virtual view of the padded input: positions outside the text read as a
    // space, so no padded buffer has to be allocated.
    let padded_byte = |pos: usize| -> u8 {
        pos.checked_sub(PAD)
            .and_then(|offset| raw.get(offset))
            .copied()
            .unwrap_or(b' ')
    };

    // A padded length of `len + 2 * PAD` gives `len + PAD` windows of three.
    let window_count = raw.len() + PAD;
    let mut found = HashSet::with_capacity(window_count);
    found.extend((0..window_count).map(|start| {
        [
            padded_byte(start),
            padded_byte(start + 1),
            padded_byte(start + 2),
        ]
    }));
    found
}
/// Returns the bytes of `text` with two ASCII spaces added on each side.
///
/// The padding width matches `extract_trigrams_scalar`, which also treats the
/// input as surrounded by two spaces per side, so every extraction path sees
/// the same padded byte sequence.
#[inline]
fn build_padded_bytes(text: &str) -> Vec<u8> {
    // Spelled as an array so the padding width is explicit and agrees with
    // the capacity reserved below.
    const PAD: [u8; 2] = [b' ', b' '];

    let text_bytes = text.as_bytes();
    let mut padded = Vec::with_capacity(text_bytes.len() + 2 * PAD.len());
    padded.extend_from_slice(&PAD);
    padded.extend_from_slice(text_bytes);
    padded.extend_from_slice(&PAD);
    padded
}
/// Collects every distinct 3-byte window of `bytes`, compiled with AVX2 enabled.
///
/// The result equals the set of `bytes.windows(3)`; inputs shorter than three
/// bytes give an empty set. No vector intrinsic is called explicitly: the only
/// architecture-specific operation is a software prefetch hint.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports AVX2, for example by
/// checking `is_x86_feature_detected!("avx2")` beforehand.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[must_use]
unsafe fn extract_trigrams_avx2_inner(bytes: &[u8]) -> HashSet<Trigram> {
    use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};

    // Trigrams handled per outer iteration.
    const BLOCK: usize = 30;
    // Bytes that must remain before a full block is taken.
    const BLOCK_GUARD: usize = 34;
    // Distance, in bytes, of the prefetch hint ahead of the cursor.
    const PREFETCH_AHEAD: usize = 64;

    let len = bytes.len();
    let mut trigrams = HashSet::with_capacity(len);
    let mut i = 0;
    while i + BLOCK_GUARD <= len {
        // The hinted address may lie beyond the end of `bytes`. Forming it with
        // `add` would be undefined behaviour, so `wrapping_add` is used; the
        // pointer is only passed to the prefetch hint and never dereferenced.
        let hint = bytes.as_ptr().wrapping_add(i + PREFETCH_AHEAD).cast::<i8>();
        _mm_prefetch(hint, _MM_HINT_T0);

        // BLOCK + 2 bytes contain exactly BLOCK overlapping windows.
        trigrams.extend(bytes[i..i + BLOCK + 2].windows(3).map(|w| [w[0], w[1], w[2]]));
        i += BLOCK;
    }
    // Remaining windows; empty when fewer than three bytes are left.
    trigrams.extend(bytes[i..].windows(3).map(|w| [w[0], w[1], w[2]]));
    trigrams
}
/// Extracts the distinct trigrams of `text` through the AVX2-compiled routine.
///
/// Returns an empty set for empty input. When the running CPU does not report
/// AVX2, the work is handed to `extract_trigrams_scalar` instead.
#[cfg(target_arch = "x86_64")]
#[must_use]
pub fn extract_trigrams_avx2(text: &str) -> HashSet<Trigram> {
    if text.is_empty() {
        return HashSet::new();
    }
    if !is_x86_feature_detected!("avx2") {
        return extract_trigrams_scalar(text);
    }

    let padded = build_padded_bytes(text);
    // SAFETY: AVX2 support was confirmed by the runtime check just above.
    unsafe { extract_trigrams_avx2_inner(&padded) }
}
/// Collects every distinct 3-byte window of `bytes`, compiled with AVX-512
/// (`avx512f` and `avx512bw`) enabled.
///
/// The result equals the set of `bytes.windows(3)`; inputs shorter than three
/// bytes give an empty set. No vector intrinsic is called explicitly: the only
/// architecture-specific operation is a software prefetch hint.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports both `avx512f` and
/// `avx512bw`, for example via `is_x86_feature_detected!`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "avx512bw")]
#[must_use]
unsafe fn extract_trigrams_avx512_inner(bytes: &[u8]) -> HashSet<Trigram> {
    use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};

    // Trigrams handled per outer iteration.
    const BLOCK: usize = 62;
    // Bytes that must remain before a full block is taken.
    const BLOCK_GUARD: usize = 66;
    // Distance, in bytes, of the prefetch hint ahead of the cursor.
    const PREFETCH_AHEAD: usize = 128;

    let len = bytes.len();
    let mut trigrams = HashSet::with_capacity(len);
    let mut i = 0;
    while i + BLOCK_GUARD <= len {
        // The hinted address may lie beyond the end of `bytes`. Forming it with
        // `add` would be undefined behaviour, so `wrapping_add` is used; the
        // pointer is only passed to the prefetch hint and never dereferenced.
        let hint = bytes.as_ptr().wrapping_add(i + PREFETCH_AHEAD).cast::<i8>();
        _mm_prefetch(hint, _MM_HINT_T0);

        // BLOCK + 2 bytes contain exactly BLOCK overlapping windows.
        trigrams.extend(bytes[i..i + BLOCK + 2].windows(3).map(|w| [w[0], w[1], w[2]]));
        i += BLOCK;
    }
    // Remaining windows; empty when fewer than three bytes are left.
    trigrams.extend(bytes[i..].windows(3).map(|w| [w[0], w[1], w[2]]));
    trigrams
}
/// Extracts the distinct trigrams of `text` through the AVX-512-compiled routine.
///
/// Returns an empty set for empty input. When the running CPU does not report
/// both `avx512f` and `avx512bw`, the call is forwarded to
/// `extract_trigrams_avx2`.
#[cfg(target_arch = "x86_64")]
#[must_use]
pub fn extract_trigrams_avx512(text: &str) -> HashSet<Trigram> {
    if text.is_empty() {
        return HashSet::new();
    }

    let has_avx512 = is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw");
    if !has_avx512 {
        return extract_trigrams_avx2(text);
    }

    let padded = build_padded_bytes(text);
    // SAFETY: both required AVX-512 features were confirmed by the runtime
    // check just above.
    unsafe { extract_trigrams_avx512_inner(&padded) }
}
/// Extracts the distinct trigrams of `text` on aarch64.
///
/// The text is padded via `build_padded_bytes` and every 3-byte window of the
/// padded bytes is collected. Empty input yields an empty set.
///
/// No NEON intrinsic is needed to produce the result, so the routine is
/// written entirely in safe code.
#[cfg(target_arch = "aarch64")]
#[must_use]
pub fn extract_trigrams_neon(text: &str) -> HashSet<Trigram> {
    if text.is_empty() {
        return HashSet::new();
    }

    let padded = build_padded_bytes(text);
    let mut trigrams = HashSet::with_capacity(padded.len());
    // `windows(3)` visits exactly the offsets the chunked and tail index loops
    // used to cover, and yields nothing when fewer than three bytes exist.
    trigrams.extend(padded.windows(3).map(|w| [w[0], w[1], w[2]]));
    trigrams
}
/// Counts how many entries of `query_trigrams` are present in `doc_trigrams`.
///
/// Each query entry is counted separately, so repeated entries count once per
/// occurrence. On x86_64, queries of at least 16 entries are handed to
/// `count_matching_avx2` when AVX2 is detected; every other case is counted
/// with a direct per-entry lookup.
#[must_use]
#[allow(clippy::cast_possible_truncation)]
pub fn count_matching_trigrams_simd(
    query_trigrams: &[[u8; 3]],
    doc_trigrams: &HashSet<[u8; 3]>,
) -> usize {
    #[cfg(target_arch = "x86_64")]
    {
        // Queries shorter than this are always counted by direct lookup.
        const BATCH_THRESHOLD: usize = 16;
        if query_trigrams.len() >= BATCH_THRESHOLD && is_x86_feature_detected!("avx2") {
            return count_matching_avx2(query_trigrams, doc_trigrams);
        }
    }

    query_trigrams
        .iter()
        .filter(|candidate| doc_trigrams.contains(*candidate))
        .count()
}
/// Counts the entries of `query_trigrams` found in `doc_trigrams`, walking the
/// query in batches of up to eight.
///
/// Each entry is looked up individually in the set; repeated query entries
/// count once per occurrence. No vector intrinsic is used here.
#[cfg(target_arch = "x86_64")]
fn count_matching_avx2(query_trigrams: &[[u8; 3]], doc_trigrams: &HashSet<[u8; 3]>) -> usize {
    query_trigrams
        .chunks(8)
        .map(|batch| {
            batch
                .iter()
                .filter(|candidate| doc_trigrams.contains(*candidate))
                .count()
        })
        .sum::<usize>()
}