#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "aarch64"))]
use std::arch::aarch64::*;
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
use std::arch::x86_64::*;
#[cfg(not(all(feature = "simd_runtime_dispatch", target_arch = "aarch64")))]
use std::simd::Simd;
#[cfg(not(all(feature = "simd_runtime_dispatch", target_arch = "aarch64")))]
use std::simd::cmp::SimdPartialOrd;
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
use std::sync::OnceLock;
/// Signature shared by the runtime-selectable non-ASCII counting kernels:
/// takes a byte slice and returns how many of its bytes are `>= 0x80`.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
type CountFn = fn(bytes: &[u8]) -> usize;
/// Table of kernels chosen once from the CPU features detected at runtime.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
struct SimdDispatch {
    /// Kernel that counts the bytes `>= 0x80` in a slice.
    count_non_ascii: CountFn,
}
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
impl SimdDispatch {
    /// Probes the CPU and selects the kernels to use.
    ///
    /// Picks the AVX2 kernel when the CPU reports AVX2 support, otherwise the
    /// portable `std::simd` kernel.
    fn detect() -> Self {
        let has_avx2 = std::arch::is_x86_feature_detected!("avx2");
        // The annotation coerces both distinct fn-item types to one fn pointer.
        let count_non_ascii: CountFn = if has_avx2 {
            count_non_ascii_avx2
        } else {
            count_non_ascii_portable
        };
        Self { count_non_ascii }
    }
}
/// Returns the process-wide kernel table.
///
/// CPU detection runs on the first call only; later calls reuse the cached table.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
fn dispatch() -> &'static SimdDispatch {
    static SELECTED: OnceLock<SimdDispatch> = OnceLock::new();
    SELECTED.get_or_init(SimdDispatch::detect)
}
/// Portable `std::simd` kernel: counts the bytes with the high bit set (`>= 0x80`).
///
/// This kernel is used in three situations:
/// - directly, when runtime dispatch is off or the target has no dedicated kernel;
/// - as the x86_64 fallback when AVX2 is unavailable;
/// - never on aarch64 with `simd_runtime_dispatch`, where it is compiled out in
///   favor of the NEON kernel.
///
/// The count is accumulated in `usize` rather than `u32`. Slices longer than
/// `u32::MAX` bytes therefore cannot overflow the counter.
#[cfg(not(all(feature = "simd_runtime_dispatch", target_arch = "aarch64")))]
fn count_non_ascii_portable(bytes: &[u8]) -> usize {
    const LANES: usize = 32;
    let threshold = Simd::<u8, LANES>::splat(0x80);
    let chunks = bytes.chunks_exact(LANES);
    // Fewer than LANES trailing bytes; counted with a scalar pass below.
    let tail = chunks.remainder();
    let vector_count: usize = chunks
        .map(|chunk| {
            // One mask bit per lane whose byte is >= 0x80.
            let mask = Simd::<u8, LANES>::from_slice(chunk).simd_ge(threshold);
            mask.to_bitmask().count_ones() as usize
        })
        .sum();
    // `b >> 7` is 1 exactly when the byte's high bit is set.
    let tail_count: usize = tail.iter().map(|&b| usize::from(b >> 7)).sum();
    vector_count + tail_count
}
/// AVX2 kernel: counts the bytes with the high bit set (`>= 0x80`).
///
/// `_mm256_movemask_epi8` gathers the top bit of each of the 32 bytes in a
/// chunk into a 32-bit mask, so the mask's popcount is the number of
/// non-ASCII bytes in that chunk. The remaining tail of fewer than 32 bytes is
/// counted with a scalar pass.
///
/// The total is accumulated in `usize`, so inputs longer than `u32::MAX`
/// bytes cannot overflow the counter.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn count_non_ascii_avx2_impl(bytes: &[u8]) -> usize {
    const LANES: usize = 32;
    let chunks = bytes.chunks_exact(LANES);
    let tail = chunks.remainder();
    let mut count = 0usize;
    for chunk in chunks {
        // SAFETY: `chunk` is exactly 32 bytes, so the 256-bit load stays in
        // bounds, and `_mm256_loadu_si256` has no alignment requirement. AVX2
        // availability is this function's documented caller contract.
        let mask = unsafe {
            let v = _mm256_loadu_si256(chunk.as_ptr().cast::<__m256i>());
            _mm256_movemask_epi8(v)
        };
        // `i32::count_ones` counts set bits of the raw bit pattern, sign bit included.
        count += mask.count_ones() as usize;
    }
    count + tail.iter().map(|&b| usize::from(b >> 7)).sum::<usize>()
}
/// Safe-signature entry point for the AVX2 kernel, so it fits in a `CountFn` slot.
///
/// In this module it is only selected by `SimdDispatch::detect`, after
/// `is_x86_feature_detected!("avx2")` returned true. The `debug_assert!`
/// re-checks that invariant. An accidental direct call on a CPU without AVX2
/// then fails loudly in debug builds instead of executing unsupported
/// instructions.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
fn count_non_ascii_avx2(bytes: &[u8]) -> usize {
    debug_assert!(
        std::arch::is_x86_feature_detected!("avx2"),
        "count_non_ascii_avx2 called on a CPU without AVX2"
    );
    // SAFETY: AVX2 support is confirmed at runtime before this function is
    // selected as the dispatch target (see `SimdDispatch::detect`).
    unsafe { count_non_ascii_avx2_impl(bytes) }
}
/// NEON kernel: counts the bytes with the high bit set (`>= 0x80`).
///
/// Full 16-byte chunks are processed with NEON intrinsics. The remaining tail
/// of fewer than 16 bytes is counted with a scalar pass.
///
/// The total is accumulated in `usize`, so inputs longer than `u32::MAX`
/// bytes cannot overflow the counter.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "aarch64"))]
fn count_non_ascii_neon(bytes: &[u8]) -> usize {
    const LANES: usize = 16;
    let chunks = bytes.chunks_exact(LANES);
    let tail = chunks.remainder();
    let mut count = 0usize;
    for chunk in chunks {
        // SAFETY: `chunk` is exactly 16 bytes, so the 16-byte load stays in
        // bounds, and `vld1q_u8` only needs byte alignment.
        // NOTE(review): assumes the aarch64 target enables NEON (the default for
        // standard aarch64 targets) — confirm if soft-float targets must be supported.
        let high_bits = unsafe {
            let v = vld1q_u8(chunk.as_ptr());
            // Shifting each lane right by 7 leaves 1 for bytes >= 0x80, else 0.
            // The horizontal sum of 16 such lanes is at most 16, so it fits in u8.
            vaddvq_u8(vshrq_n_u8(v, 7))
        };
        count += usize::from(high_bits);
    }
    count + tail.iter().map(|&b| usize::from(b >> 7)).sum::<usize>()
}
/// Counts the bytes in `bytes` whose high bit is set (values `>= 0x80`), i.e.
/// every byte that is not 7-bit ASCII.
///
/// The implementation is fixed at compile time by target and the
/// `simd_runtime_dispatch` feature:
/// - x86_64 with the feature: a kernel chosen once at runtime from CPU features;
/// - aarch64 with the feature: the NEON kernel;
/// - any other configuration: the portable `std::simd` kernel.
#[inline(always)]
pub(super) fn count_non_ascii_simd(bytes: &[u8]) -> usize {
    #[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
    {
        let kernel = dispatch().count_non_ascii;
        kernel(bytes)
    }
    #[cfg(all(feature = "simd_runtime_dispatch", target_arch = "aarch64"))]
    {
        count_non_ascii_neon(bytes)
    }
    #[cfg(not(all(
        feature = "simd_runtime_dispatch",
        any(target_arch = "x86_64", target_arch = "aarch64")
    )))]
    {
        count_non_ascii_portable(bytes)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar reference: a byte is non-ASCII exactly when it is `>= 0x80`.
    fn reference_count(bytes: &[u8]) -> usize {
        bytes.iter().filter(|&&b| b >= 0x80).count()
    }

    #[test]
    fn test_count_non_ascii() {
        let cjk = "你好世界".as_bytes();
        let cases: &[(&[u8], usize)] = &[
            (b"hello world 123", 0),
            (cjk, cjk.len()),
            ("hello世界".as_bytes(), 6),
            (b"", 0),
            (b"a", 0),
            (&[0x7F], 0),
            (&[0x80], 1),
            (&[0xFF], 1),
        ];
        for &(input, expected) in cases {
            assert_eq!(count_non_ascii_simd(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn boundary_at_simd_width() {
        let ascii_32 = b"abcdefghijklmnopqrstuvwxyz012345";
        assert_eq!(count_non_ascii_simd(ascii_32), 0);
        let mut buf = vec![0x80u8; 33];
        assert_eq!(count_non_ascii_simd(&buf), 33);
        buf[32] = b'a';
        assert_eq!(count_non_ascii_simd(&buf), 32);
    }

    #[test]
    fn matches_scalar_reference_for_every_length() {
        // A mixed byte pattern (73 is coprime with 256). Sweeping every prefix
        // length and a few start offsets exercises each tail length of both the
        // 16-lane and 32-lane kernels, plus unaligned slice starts.
        let data: Vec<u8> = (0..200u32).map(|i| ((i * 73 + 41) % 256) as u8).collect();
        for start in 0..4 {
            for end in start..=data.len() {
                let slice = &data[start..end];
                assert_eq!(
                    count_non_ascii_simd(slice),
                    reference_count(slice),
                    "start {start}, end {end}"
                );
            }
        }
    }
}