#[cfg(not(all(feature = "simd_runtime_dispatch", target_arch = "aarch64")))]
use std::simd::Simd;
#[cfg(not(all(feature = "simd_runtime_dispatch", target_arch = "aarch64")))]
use std::simd::cmp::{SimdPartialEq, SimdPartialOrd};
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "aarch64"))]
use std::arch::aarch64::*;
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
use std::arch::x86_64::*;
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
use std::sync::OnceLock;
/// Function-pointer type for the `skip_ascii` slot of [`SimdDispatch`]:
/// `(bytes, offset) -> offset`.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
type SkipFn = fn(&[u8], usize) -> usize;
/// Function-pointer type for the `skip_ascii_non_delete` slot of [`SimdDispatch`]:
/// `(bytes, offset, ascii_lut) -> offset`, where `ascii_lut` is a 16-byte (128-bit) bitmap.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
type SkipDeleteFn = fn(&[u8], usize, &[u8; 16]) -> usize;
/// Function-pointer type for the `count_continuation` slot of [`SimdDispatch`]:
/// `bytes -> count`.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
type CountContinuationFn = fn(&[u8]) -> usize;
/// Byte-shuffle table mapping a bit position to a single-bit mask: entry `i` is `1 << (i & 7)`.
///
/// The SIMD delete-mask code indexes it with `byte & 7`, so only entries 0..=7 are ever
/// selected; the second half repeats the first.
const SHIFT_TABLE_16: [u8; 16] = [1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128];
/// 32-byte version of [`SHIFT_TABLE_16`] (the same 8-entry pattern repeated four times), used
/// by the 32-lane portable and AVX2 paths. Not compiled for the aarch64 runtime-dispatch build,
/// which only uses 16-lane NEON vectors.
#[cfg(not(all(feature = "simd_runtime_dispatch", target_arch = "aarch64")))]
const SHIFT_TABLE_32: [u8; 32] = [
1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4,
8, 16, 32, 64, 128,
];
/// Table of implementation function pointers chosen at runtime on x86_64.
///
/// Filled in by `SimdDispatch::detect` and cached by `dispatch()`.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
struct SimdDispatch {
/// Implementation used by `skip_ascii_simd`.
skip_ascii: SkipFn,
/// Implementation used by `skip_ascii_non_delete_simd`.
skip_ascii_non_delete: SkipDeleteFn,
/// Implementation used by `multibyte_density` to count UTF-8 continuation bytes.
count_continuation: CountContinuationFn,
}
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
impl SimdDispatch {
    /// Builds the dispatch table for the CPU this process is running on.
    ///
    /// The AVX2 entry points are selected when `is_x86_feature_detected!("avx2")` reports
    /// support; otherwise every slot falls back to the portable `std::simd` implementation.
    fn detect() -> Self {
        let has_avx2 = std::arch::is_x86_feature_detected!("avx2");
        if has_avx2 {
            Self {
                skip_ascii: skip_ascii_avx2,
                skip_ascii_non_delete: skip_ascii_non_delete_avx2,
                count_continuation: count_continuation_avx2,
            }
        } else {
            Self {
                skip_ascii: skip_ascii_portable,
                skip_ascii_non_delete: skip_ascii_non_delete_portable,
                count_continuation: count_continuation_portable,
            }
        }
    }
}
/// Returns the process-wide dispatch table.
///
/// CPU feature detection runs once, on the first call; later calls reuse the cached table.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
#[inline(always)]
fn dispatch() -> &'static SimdDispatch {
    static TABLE: OnceLock<SimdDispatch> = OnceLock::new();
    let table: &'static SimdDispatch = TABLE.get_or_init(SimdDispatch::detect);
    table
}
/// Tests whether `byte` is flagged in the 128-bit ASCII bitmap `ascii_lut`.
///
/// ASCII byte `b` is a member when bit `b & 7` of `ascii_lut[b >> 3]` is set.
///
/// Bytes `>= 0x80` have no slot in the 16-byte table and are reported as not contained.
/// (Indexing the table directly with such a byte would be out of bounds and panic.)
#[inline(always)]
fn ascii_delete_contains(byte: u8, ascii_lut: &[u8; 16]) -> bool {
    let slot = usize::from(byte >> 3);
    let bit = 1u8 << (byte & 7);
    // `get` yields `None` for slots 16..=31, i.e. for every non-ASCII byte.
    ascii_lut.get(slot).is_some_and(|&entry| entry & bit != 0)
}
/// Scalar scan: returns the index of the first byte `>= 0x80` at or after `offset`,
/// or `bytes.len()` when the rest of the slice is pure ASCII.
///
/// An `offset` past the end of the slice is returned unchanged.
#[inline(always)]
fn find_non_ascii_scalar(bytes: &[u8], offset: usize) -> usize {
    match bytes.get(offset..) {
        Some(tail) => offset + tail.iter().take_while(|b| b.is_ascii()).count(),
        None => offset,
    }
}
/// Scalar scan: returns the index of the first byte at or after `offset` that is either
/// non-ASCII (`>= 0x80`) or flagged in `ascii_lut`, or `bytes.len()` if there is none.
///
/// An `offset` past the end of the slice is returned unchanged.
#[inline(always)]
fn find_ascii_non_delete_scalar(bytes: &[u8], offset: usize, ascii_lut: &[u8; 16]) -> usize {
    let Some(tail) = bytes.get(offset..) else {
        return offset;
    };
    // The ASCII test comes first so the bitmap is only consulted for bytes below 0x80.
    let run = tail
        .iter()
        .take_while(|&&b| b < 0x80 && !ascii_delete_contains(b, ascii_lut))
        .count();
    offset + run
}
// Generates a safe, x86_64-runtime-dispatch-only wrapper `$name(bytes, offset, extra...)`
// around an `unsafe` implementation function `$impl_fn`.
//
// Invocation shape:
//   fn NAME(bytes, offset [, extra: Type]*), IMPL_FN, |b0| EARLY_CHECK
// where `b0` names the byte at `bytes[offset]` and EARLY_CHECK is a boolean expression that
// may use `b0` and the extra parameters.
macro_rules! define_avx2_entry {
(
$(#[$meta:meta])*
fn $name:ident ( bytes, offset $(, $extra:ident : $ety:ty)* ),
$impl_fn:ident,
|$b0:ident| $early_check:expr
) => {
$(#[$meta])*
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
fn $name(bytes: &[u8], offset: usize $(, $extra: $ety)*) -> usize {
// Nothing to scan: hand the offset back unchanged.
if offset >= bytes.len() {
return offset;
}
// Cheap scalar test on the first byte before entering the vector code.
let $b0 = bytes[offset];
if $early_check {
return offset;
}
// NOTE(review): the implementations passed in this file are `#[target_feature(enable = "avx2")]`;
// the generated wrappers are only installed by `SimdDispatch::detect` after AVX2 has been
// detected. Calling a generated wrapper directly on a CPU without AVX2 would not be sound.
unsafe { $impl_fn(bytes, offset $(, $extra)*) }
}
};
}
// Generates a `pub(crate)` front-end `$name(bytes, offset, extra...)` that forwards to the
// implementation appropriate for the build configuration.
//
// Invocation shape:
//   pub(crate) fn NAME(bytes, offset [, extra: Type]*), DISPATCH_FIELD, NEON_FN, PORTABLE_FN
// Exactly one of the three cfg-gated statements below survives compilation.
macro_rules! define_skip_dispatch {
(
$(#[$meta:meta])*
pub(crate) fn $name:ident ( bytes, offset $(, $extra:ident : $ety:ty)* ),
$field:ident, $neon:ident, $portable:ident
) => {
$(#[$meta])*
#[inline(always)]
pub(crate) fn $name(bytes: &[u8], offset: usize $(, $extra: $ety)*) -> usize {
// x86_64 with runtime dispatch: call through the cached function-pointer table.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
return (dispatch().$field)(bytes, offset $(, $extra)*);
// aarch64 with runtime dispatch: call the NEON implementation directly.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "aarch64"))]
return $neon(bytes, offset $(, $extra)*);
// Feature disabled, or any other architecture: portable `std::simd` implementation.
#[cfg(not(all(
feature = "simd_runtime_dispatch",
any(target_arch = "x86_64", target_arch = "aarch64")
)))]
$portable(bytes, offset $(, $extra)*)
}
};
}
/// Computes, for a 16-byte chunk, a bitmask whose bit `i` is set when lane `i` is flagged in
/// the 128-bit bitmap `ascii_lut`.
///
/// For each lane `b`, the bitmap byte `ascii_lut[b >> 3]` is fetched with a dynamic swizzle
/// and ANDed with the single-bit mask `1 << (b & 7)` looked up in `SHIFT_TABLE_16`.
/// Lanes `>= 0x80` produce an index `>= 16`; `swizzle_dyn` is documented to yield 0 for
/// out-of-range indices, so such lanes are not flagged here.
#[cfg(not(all(feature = "simd_runtime_dispatch", target_arch = "aarch64")))]
#[inline(always)]
fn portable_ascii_delete_mask_16(chunk: Simd<u8, 16>, ascii_lut: Simd<u8, 16>) -> u64 {
    type V = Simd<u8, 16>;
    let table_bytes = ascii_lut.swizzle_dyn(chunk >> V::splat(3));
    let lane_bits = V::from_array(SHIFT_TABLE_16).swizzle_dyn(chunk & V::splat(7));
    let flagged = (table_bytes & lane_bits).simd_ne(V::splat(0));
    flagged.to_bitmask()
}
/// Computes, for a 32-byte chunk, a bitmask whose bit `i` is set when lane `i` selects a set
/// bit in `ascii_lut`.
///
/// Each lane `b` fetches `ascii_lut[b >> 3]` with a dynamic swizzle and ANDs it with the
/// single-bit mask `1 << (b & 7)` from `SHIFT_TABLE_32`. Because `b >> 3` ranges over 0..=31,
/// lanes `>= 0x80` read the upper 16 bytes of `ascii_lut`; the caller in this file fills those
/// with a second copy of the 16-byte bitmap and ORs this mask with its own non-ASCII mask.
#[cfg(not(all(feature = "simd_runtime_dispatch", target_arch = "aarch64")))]
#[inline(always)]
fn portable_ascii_delete_mask_32(chunk: Simd<u8, 32>, ascii_lut: Simd<u8, 32>) -> u64 {
    type V = Simd<u8, 32>;
    let table_bytes = ascii_lut.swizzle_dyn(chunk >> V::splat(3));
    let lane_bits = V::from_array(SHIFT_TABLE_32).swizzle_dyn(chunk & V::splat(7));
    let flagged = (table_bytes & lane_bits).simd_ne(V::splat(0));
    flagged.to_bitmask()
}
/// Portable `std::simd` implementation of the ASCII skip.
///
/// Returns the index of the first byte `>= 0x80` at or after `offset`, or `bytes.len()` when
/// the remainder is pure ASCII. If `offset` is out of range, or the byte at `offset` is already
/// non-ASCII, `offset` is returned unchanged.
#[cfg(not(all(feature = "simd_runtime_dispatch", target_arch = "aarch64")))]
#[inline(always)]
fn skip_ascii_portable(bytes: &[u8], offset: usize) -> usize {
    const LANES: usize = 32;

    match bytes.get(offset) {
        Some(&first) if first < 0x80 => {}
        _ => return offset,
    }

    let high_bit = Simd::<u8, LANES>::splat(0x80u8);
    let tail = &bytes[offset..];
    // Examine whole 32-byte blocks; bit `i` of `hits` corresponds to lane `i`.
    for (block_idx, block) in tail.chunks_exact(LANES).enumerate() {
        let hits = Simd::<u8, LANES>::from_slice(block)
            .simd_ge(high_bit)
            .to_bitmask();
        if hits != 0 {
            return offset + block_idx * LANES + hits.trailing_zeros() as usize;
        }
    }

    // Fewer than 32 bytes remain: finish with the scalar loop.
    let vector_scanned = tail.len() - tail.len() % LANES;
    find_non_ascii_scalar(bytes, offset + vector_scanned)
}
/// Portable `std::simd` implementation of the "skip plain ASCII" scan.
///
/// Returns the index of the first byte at or after `offset` that is non-ASCII (`>= 0x80`) or
/// flagged in the 128-bit bitmap `ascii_lut`, or `bytes.len()` if there is none. If `offset`
/// is out of range, or the byte at `offset` already stops the scan, `offset` is returned.
///
/// Scans 32 bytes at a time, then 16 bytes at a time, then byte by byte.
#[cfg(not(all(feature = "simd_runtime_dispatch", target_arch = "aarch64")))]
#[inline(always)]
fn skip_ascii_non_delete_portable(bytes: &[u8], offset: usize, ascii_lut: &[u8; 16]) -> usize {
    const WIDE: usize = 32;
    const NARROW: usize = 16;

    match bytes.get(offset) {
        Some(&first) if first < 0x80 && !ascii_delete_contains(first, ascii_lut) => {}
        _ => return offset,
    }

    // The 32-lane lookup needs the 16-byte bitmap in both halves of the vector.
    let mut doubled = [0u8; WIDE];
    doubled[..NARROW].copy_from_slice(ascii_lut);
    doubled[NARROW..].copy_from_slice(ascii_lut);
    let lut_wide = Simd::<u8, WIDE>::from_array(doubled);
    let high_bit_wide = Simd::<u8, WIDE>::splat(0x80u8);

    let tail = &bytes[offset..];
    for (block_idx, block) in tail.chunks_exact(WIDE).enumerate() {
        let lanes = Simd::<u8, WIDE>::from_slice(block);
        let stop = lanes.simd_ge(high_bit_wide).to_bitmask()
            | portable_ascii_delete_mask_32(lanes, lut_wide);
        if stop != 0 {
            return offset + block_idx * WIDE + stop.trailing_zeros() as usize;
        }
    }

    // Continue after the last whole 32-byte block with 16-byte blocks.
    let mut cursor = offset + (tail.len() - tail.len() % WIDE);
    while cursor + NARROW <= bytes.len() {
        let lanes = Simd::<u8, NARROW>::from_slice(&bytes[cursor..]);
        let high = lanes
            .simd_ge(Simd::<u8, NARROW>::splat(0x80u8))
            .to_bitmask();
        let deleted =
            portable_ascii_delete_mask_16(lanes, Simd::<u8, NARROW>::from_array(*ascii_lut));
        let stop = high | deleted;
        if stop != 0 {
            return cursor + stop.trailing_zeros() as usize;
        }
        cursor += NARROW;
    }

    find_ascii_non_delete_scalar(bytes, cursor, ascii_lut)
}
/// AVX2 ASCII skip: returns the index of the first byte `>= 0x80` at or after `offset`, or
/// `bytes.len()` when the remainder is pure ASCII.
///
/// Processes 32 bytes per iteration using the per-byte sign-bit mask, then hands the
/// remaining bytes (fewer than 32) to the scalar loop.
///
/// # Safety
///
/// The CPU must support AVX2.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn skip_ascii_avx2_impl(bytes: &[u8], offset: usize) -> usize {
    const WIDTH: usize = 32;
    let base = bytes.as_ptr();
    let mut cursor = offset;
    while cursor + WIDTH <= bytes.len() {
        // SAFETY: the loop condition guarantees 32 readable bytes starting at `cursor`, and
        // the unaligned-load intrinsic has no alignment requirement.
        let block = unsafe { _mm256_loadu_si256(base.add(cursor) as *const __m256i) };
        // Bit `i` is the top bit of byte `i`, i.e. set exactly for non-ASCII bytes.
        let high_bits = _mm256_movemask_epi8(block) as u32;
        if high_bits != 0 {
            return cursor + high_bits.trailing_zeros() as usize;
        }
        cursor += WIDTH;
    }
    find_non_ascii_scalar(bytes, cursor)
}
// Safe entry point `skip_ascii_avx2(bytes, offset)`: returns `offset` unchanged when it is out
// of range or the byte there is non-ASCII, otherwise forwards to `skip_ascii_avx2_impl`.
define_avx2_entry! {
fn skip_ascii_avx2(bytes, offset),
skip_ascii_avx2_impl,
|b0| b0 >= 0x80
}
/// AVX2 "skip plain ASCII" scan: returns the index of the first byte at or after `offset`
/// that is non-ASCII (`>= 0x80`) or flagged in the 128-bit bitmap `ascii_lut`, or
/// `bytes.len()` if there is none.
///
/// Processes 32 bytes per iteration, then finishes the remaining bytes with
/// `find_ascii_non_delete_scalar`.
///
/// # Safety
///
/// The CPU must support AVX2.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn skip_ascii_non_delete_avx2_impl(
bytes: &[u8],
mut offset: usize,
ascii_lut: &[u8; 16],
) -> usize {
// Duplicate the 16-byte bitmap into both halves of a 32-byte table.
// NOTE(review): `_mm256_shuffle_epi8` shuffles within each 128-bit half independently, which
// is presumably why both halves must carry the same table — confirm against the Intel docs.
let mut lut32 = [0u8; 32];
lut32[..16].copy_from_slice(ascii_lut);
lut32[16..].copy_from_slice(ascii_lut);
let shuffle_lut = unsafe { _mm256_loadu_si256(lut32.as_ptr() as *const __m256i) };
let shift_table = unsafe { _mm256_loadu_si256(SHIFT_TABLE_32.as_ptr() as *const __m256i) };
let low_nibble_mask = _mm256_set1_epi8(0x0f);
let bit_pos_mask = _mm256_set1_epi8(0x07);
let zero = _mm256_setzero_si256();
while offset + 32 <= bytes.len() {
// The loop condition guarantees 32 readable bytes at `offset`.
let chunk = unsafe { _mm256_loadu_si256(bytes.as_ptr().add(offset) as *const __m256i) };
// Top bit of every byte: set exactly for non-ASCII bytes.
let non_ascii_mask = _mm256_movemask_epi8(chunk) as u32;
// Bitmap byte index `(b >> 3) & 0x0f`. The shift operates on 16-bit lanes, so the 0x0f
// mask also discards bits shifted in from the neighbouring byte.
let byte_idx = _mm256_and_si256(_mm256_srli_epi16(chunk, 3), low_nibble_mask);
let lut_byte = _mm256_shuffle_epi8(shuffle_lut, byte_idx);
// Single-bit mask `1 << (b & 7)` looked up from the shift table.
let bit_pos = _mm256_and_si256(chunk, bit_pos_mask);
let bit_mask = _mm256_shuffle_epi8(shift_table, bit_pos);
let deleted = _mm256_and_si256(lut_byte, bit_mask);
// `cmpeq` marks lanes whose flag is zero; inverting the movemask marks flagged lanes.
let delete_mask = !_mm256_movemask_epi8(_mm256_cmpeq_epi8(deleted, zero)) as u32;
let stop_mask = non_ascii_mask | delete_mask;
if stop_mask != 0 {
// Lowest set bit is the first stopping lane in the chunk.
return offset + stop_mask.trailing_zeros() as usize;
}
offset += 32;
}
find_ascii_non_delete_scalar(bytes, offset, ascii_lut)
}
// Safe entry point `skip_ascii_non_delete_avx2(bytes, offset, ascii_lut)`: returns `offset`
// unchanged when it is out of range or the byte there is non-ASCII or flagged in `ascii_lut`,
// otherwise forwards to `skip_ascii_non_delete_avx2_impl`.
define_avx2_entry! {
fn skip_ascii_non_delete_avx2(bytes, offset, ascii_lut: &[u8; 16]),
skip_ascii_non_delete_avx2_impl,
|b0| b0 >= 0x80 || ascii_delete_contains(b0, ascii_lut)
}
/// Returns `offset + i` for the first lane `i` of the 16-byte block at `bytes + offset` whose
/// value is `>= 0x80`, or `offset + 16` when every lane is ASCII.
///
/// # Safety
///
/// `bytes.add(offset)` must point to at least 16 readable bytes.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "aarch64"))]
#[inline(always)]
unsafe fn first_non_ascii_in_neon(bytes: *const u8, offset: usize) -> usize {
    let mut lanes = [0u8; 16];
    // SAFETY: the caller guarantees 16 readable bytes at `bytes + offset`; `lanes` is a
    // 16-byte local buffer, so the store stays in bounds.
    unsafe {
        let block = vld1q_u8(bytes.add(offset));
        vst1q_u8(lanes.as_mut_ptr(), block);
    }
    match lanes.iter().position(|&b| b >= 0x80) {
        Some(lane) => offset + lane,
        None => offset + 16,
    }
}
/// NEON implementation of the ASCII skip.
///
/// Returns the index of the first byte `>= 0x80` at or after `offset`, or `bytes.len()` when
/// the remainder is pure ASCII. If `offset` is out of range, or the byte at `offset` is already
/// non-ASCII, `offset` is returned unchanged.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "aarch64"))]
#[inline(always)]
fn skip_ascii_neon(bytes: &[u8], offset: usize) -> usize {
    const WIDTH: usize = 16;

    match bytes.get(offset) {
        Some(&first) if first < 0x80 => {}
        _ => return offset,
    }

    let base = bytes.as_ptr();
    let mut cursor = offset;
    while cursor + WIDTH <= bytes.len() {
        // SAFETY: the loop condition guarantees 16 readable bytes starting at `cursor`.
        let found = unsafe {
            let block = vld1q_u8(base.add(cursor));
            // A horizontal maximum >= 0x80 means at least one lane is non-ASCII.
            if vmaxvq_u8(block) >= 0x80 {
                Some(first_non_ascii_in_neon(base, cursor))
            } else {
                None
            }
        };
        if let Some(position) = found {
            return position;
        }
        cursor += WIDTH;
    }
    find_non_ascii_scalar(bytes, cursor)
}
/// NEON implementation of the "skip plain ASCII" scan.
///
/// Returns the index of the first byte at or after `offset` that is non-ASCII (`>= 0x80`) or
/// flagged in the 128-bit bitmap `ascii_lut`, or `bytes.len()` if there is none. If `offset`
/// is out of range, or the byte at `offset` already stops the scan, `offset` is returned.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "aarch64"))]
#[inline(always)]
fn skip_ascii_non_delete_neon(bytes: &[u8], offset: usize, ascii_lut: &[u8; 16]) -> usize {
if offset >= bytes.len() {
return offset;
}
// Scalar check of the first byte before entering the vector loop.
let b0 = bytes[offset];
if b0 >= 0x80 || ascii_delete_contains(b0, ascii_lut) {
return offset;
}
let mut offset = offset;
// The loads below read 16 bytes from fixed-size tables, or from `bytes` under the
// `offset + 16 <= bytes.len()` loop condition.
unsafe {
let lut = vld1q_u8(ascii_lut.as_ptr());
let shift = vld1q_u8(SHIFT_TABLE_16.as_ptr());
let seven = vdupq_n_u8(7);
while offset + 16 <= bytes.len() {
let chunk = vld1q_u8(bytes.as_ptr().add(offset));
// Horizontal maximum >= 0x80 means some lane is non-ASCII.
let has_non_ascii = vmaxvq_u8(chunk) >= 0x80;
// Bitmap byte index `b >> 3` and single-bit mask `1 << (b & 7)` per lane.
// NOTE(review): for lanes >= 0x80 the index is >= 16; the table lookup is expected to
// return 0 for out-of-range indices — confirm against Arm's TBL documentation. Such lanes
// are caught by `has_non_ascii` regardless.
let byte_idx = vshrq_n_u8(chunk, 3);
let lut_byte = vqtbl1q_u8(lut, byte_idx);
let bit_pos = vandq_u8(chunk, seven);
let bit_mask = vqtbl1q_u8(shift, bit_pos);
let deleted = vandq_u8(lut_byte, bit_mask);
if has_non_ascii || vmaxvq_u8(deleted) != 0 {
// A stop byte is in this chunk: spill it and locate the first one with a scalar scan.
let mut scratch = [0u8; 16];
vst1q_u8(scratch.as_mut_ptr(), chunk);
return scratch
.iter()
.position(|&b| b >= 0x80 || ascii_delete_contains(b, ascii_lut))
.map_or(offset + 16, |idx| offset + idx);
}
offset += 16;
}
}
// Fewer than 16 bytes remain.
find_ascii_non_delete_scalar(bytes, offset, ascii_lut)
}
// `skip_ascii_simd(bytes, offset)`: crate-visible ASCII skip. Forwards to the dispatch-table
// `skip_ascii` slot, `skip_ascii_neon`, or `skip_ascii_portable` depending on the build.
define_skip_dispatch! {
pub(crate) fn skip_ascii_simd(bytes, offset),
skip_ascii, skip_ascii_neon, skip_ascii_portable
}
// `skip_ascii_non_delete_simd(bytes, offset, ascii_lut)`: crate-visible scan that stops at the
// first non-ASCII byte or byte flagged in `ascii_lut`. Forwards to the dispatch-table
// `skip_ascii_non_delete` slot, `skip_ascii_non_delete_neon`, or
// `skip_ascii_non_delete_portable` depending on the build.
define_skip_dispatch! {
pub(crate) fn skip_ascii_non_delete_simd(bytes, offset, ascii_lut: &[u8; 16]),
skip_ascii_non_delete, skip_ascii_non_delete_neon, skip_ascii_non_delete_portable
}
/// Counts UTF-8 continuation bytes (bit pattern `10xxxxxx`) in `bytes[offset..]`.
///
/// # Panics
///
/// Panics if `offset > bytes.len()`.
#[inline(always)]
fn count_continuation_scalar(bytes: &[u8], offset: usize) -> usize {
    let tail = &bytes[offset..];
    // The top two bits being `10` is the same test as `(b & 0xC0) == 0x80`.
    tail.iter()
        .fold(0usize, |total, &b| total + usize::from(b >> 6 == 0b10))
}
/// Portable `std::simd` count of UTF-8 continuation bytes (`10xxxxxx`) in `bytes`.
///
/// Whole 32-byte blocks are counted with a vector compare and popcount; the remaining bytes
/// (fewer than 32) are counted by `count_continuation_scalar`.
#[cfg(not(all(feature = "simd_runtime_dispatch", target_arch = "aarch64")))]
#[inline(always)]
fn count_continuation_portable(bytes: &[u8]) -> usize {
    const LANES: usize = 32;
    let top_two = Simd::<u8, LANES>::splat(0xC0);
    let continuation = Simd::<u8, LANES>::splat(0x80);

    let blocks = bytes.chunks_exact(LANES);
    let tail_start = bytes.len() - blocks.remainder().len();
    let vector_total: usize = blocks
        .map(|block| {
            let lanes = Simd::<u8, LANES>::from_slice(block);
            (lanes & top_two)
                .simd_eq(continuation)
                .to_bitmask()
                .count_ones() as usize
        })
        .sum();

    vector_total + count_continuation_scalar(bytes, tail_start)
}
/// AVX2 count of UTF-8 continuation bytes (`10xxxxxx`) in `bytes`.
///
/// Whole 32-byte blocks are counted with a masked byte compare and popcount of the resulting
/// lane mask; the remaining bytes are counted by `count_continuation_scalar`.
///
/// # Safety
///
/// The CPU must support AVX2.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn count_continuation_avx2_impl(bytes: &[u8]) -> usize {
    const WIDTH: usize = 32;
    let top_two = _mm256_set1_epi8(0xC0u8 as i8);
    let continuation = _mm256_set1_epi8(0x80u8 as i8);
    let base = bytes.as_ptr();
    let mut total = 0usize;
    let mut cursor = 0usize;
    while cursor + WIDTH <= bytes.len() {
        // SAFETY: the loop condition guarantees 32 readable bytes starting at `cursor`, and
        // the unaligned-load intrinsic has no alignment requirement.
        let block = unsafe { _mm256_loadu_si256(base.add(cursor) as *const __m256i) };
        let top_bits = _mm256_and_si256(block, top_two);
        let matches = _mm256_cmpeq_epi8(top_bits, continuation);
        total += (_mm256_movemask_epi8(matches) as u32).count_ones() as usize;
        cursor += WIDTH;
    }
    total + count_continuation_scalar(bytes, cursor)
}
/// Safe entry point for the AVX2 continuation-byte count.
///
/// Inputs shorter than one 32-byte vector go straight to the scalar counter.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
fn count_continuation_avx2(bytes: &[u8]) -> usize {
    if bytes.len() >= 32 {
        // NOTE(review): sound only when AVX2 is available; in this file the function is
        // installed by `SimdDispatch::detect` after AVX2 has been detected.
        unsafe { count_continuation_avx2_impl(bytes) }
    } else {
        count_continuation_scalar(bytes, 0)
    }
}
/// NEON count of UTF-8 continuation bytes (`10xxxxxx`) in `bytes`.
///
/// Whole 16-byte blocks are counted with a masked compare followed by a horizontal add; the
/// remaining bytes are counted by `count_continuation_scalar`.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "aarch64"))]
#[inline(always)]
fn count_continuation_neon(bytes: &[u8]) -> usize {
    const WIDTH: usize = 16;
    let base = bytes.as_ptr();
    let mut total = 0usize;
    let mut cursor = 0usize;
    // SAFETY: every load reads 16 bytes at `cursor`, and the loop condition guarantees
    // `cursor + 16 <= bytes.len()`.
    unsafe {
        let top_two = vdupq_n_u8(0xC0);
        let continuation = vdupq_n_u8(0x80);
        while cursor + WIDTH <= bytes.len() {
            let block = vld1q_u8(base.add(cursor));
            let matches = vceqq_u8(vandq_u8(block, top_two), continuation);
            // Matching lanes are 0xFF; shifting right by 7 turns each into 1 so the
            // horizontal add (at most 16) yields the number of matches.
            let per_lane = vshrq_n_u8(matches, 7);
            total += vaddvq_u8(per_lane) as usize;
            cursor += WIDTH;
        }
    }
    total + count_continuation_scalar(bytes, cursor)
}
/// Returns the fraction of `bytes` that are UTF-8 continuation bytes (`10xxxxxx`).
///
/// The result is `0.0` for an empty slice. The count comes from the dispatch table on x86_64
/// with runtime dispatch, from NEON on aarch64 with runtime dispatch, and from the portable
/// `std::simd` implementation otherwise.
#[inline(always)]
pub(crate) fn multibyte_density(bytes: &[u8]) -> f32 {
    if bytes.is_empty() {
        return 0.0;
    }

    // Exactly one of these bindings survives compilation.
    #[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
    let continuation_bytes = (dispatch().count_continuation)(bytes);
    #[cfg(all(feature = "simd_runtime_dispatch", target_arch = "aarch64"))]
    let continuation_bytes = count_continuation_neon(bytes);
    #[cfg(not(all(
        feature = "simd_runtime_dispatch",
        any(target_arch = "x86_64", target_arch = "aarch64")
    )))]
    let continuation_bytes = count_continuation_portable(bytes);

    continuation_bytes as f32 / bytes.len() as f32
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Short inputs: exercised entirely by the first-byte check and the scalar tail.
    #[test]
    fn skip_ascii_matches_scalar_behavior() {
        let text = "plain ascii 123".as_bytes();
        assert_eq!(skip_ascii_simd(text, 0), text.len());
        let mixed = "hello世界".as_bytes();
        assert_eq!(skip_ascii_simd(mixed, 0), 5);
        assert_eq!(skip_ascii_simd(mixed, 5), 5);
    }

    /// Inputs longer than one vector, so the SIMD loops actually run.
    #[test]
    fn skip_ascii_long_inputs_use_vector_path() {
        let mut long = vec![b'a'; 100];
        assert_eq!(skip_ascii_simd(&long, 0), 100);
        assert_eq!(skip_ascii_simd(&long, 3), 100);
        // Out-of-range start offset is returned unchanged.
        assert_eq!(skip_ascii_simd(&long, 100), 100);

        // A non-ASCII byte in the middle of a later vector block.
        long[70] = 0xC3;
        assert_eq!(skip_ascii_simd(&long, 0), 70);
        assert_eq!(skip_ascii_simd(&long, 33), 70);
        assert_eq!(skip_ascii_simd(&long, 70), 70);
        assert_eq!(skip_ascii_simd(&long, 71), 100);
    }

    #[test]
    fn skip_ascii_non_delete_stops_on_delete_and_unicode() {
        let mut ascii_lut = [0u8; 16];
        ascii_lut[(b'!' as usize) >> 3] |= 1 << (b'!' & 7);
        let text = "abc!def".as_bytes();
        assert_eq!(skip_ascii_non_delete_simd(text, 0, &ascii_lut), 3);
        let unicode = "abcdef你".as_bytes();
        assert_eq!(skip_ascii_non_delete_simd(unicode, 0, &ascii_lut), 6);
    }

    /// Inputs longer than one vector, covering the wide loop, the narrow loop and the tail.
    #[test]
    fn skip_ascii_non_delete_long_inputs_use_vector_path() {
        let mut ascii_lut = [0u8; 16];
        ascii_lut[(b'!' as usize) >> 3] |= 1 << (b'!' & 7);

        let mut long = vec![b'a'; 100];
        assert_eq!(skip_ascii_non_delete_simd(&long, 0, &ascii_lut), 100);
        assert_eq!(skip_ascii_non_delete_simd(&long, 100, &ascii_lut), 100);

        // A flagged byte inside the second 32-byte block.
        long[50] = b'!';
        assert_eq!(skip_ascii_non_delete_simd(&long, 0, &ascii_lut), 50);
        assert_eq!(skip_ascii_non_delete_simd(&long, 50, &ascii_lut), 50);
        assert_eq!(skip_ascii_non_delete_simd(&long, 51, &ascii_lut), 100);

        // A non-ASCII byte that falls after the last whole 32-byte block when starting at 51.
        long[90] = 0xE4;
        assert_eq!(skip_ascii_non_delete_simd(&long, 51, &ascii_lut), 90);
        assert_eq!(skip_ascii_non_delete_simd(&long, 91, &ascii_lut), 100);
    }

    #[test]
    fn multibyte_density_pure_ascii() {
        assert_eq!(multibyte_density(b"hello world"), 0.0);
        assert_eq!(multibyte_density(b""), 0.0);
    }

    #[test]
    fn multibyte_density_pure_cjk() {
        let s = "你好";
        let d = multibyte_density(s.as_bytes());
        assert!((d - 4.0 / 6.0).abs() < 1e-5, "got {d}");
    }

    #[test]
    fn multibyte_density_mixed() {
        let s = "hi你";
        let d = multibyte_density(s.as_bytes());
        assert!((d - 2.0 / 5.0).abs() < 1e-5, "got {d}");
    }

    #[test]
    fn multibyte_density_long_ascii() {
        let s: String = "a".repeat(100);
        assert_eq!(multibyte_density(s.as_bytes()), 0.0);
    }

    #[test]
    fn multibyte_density_long_cjk() {
        let s: String = "你".repeat(100);
        let d = multibyte_density(s.as_bytes());
        assert!((d - 200.0 / 300.0).abs() < 1e-5, "got {d}");
    }
}