use crate::{decode::decode_scalar, Error};
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::{
__m128i, _mm_add_epi8, _mm_and_si128, _mm_cmpgt_epi8, _mm_cmplt_epi8, _mm_loadu_si128,
_mm_movemask_epi8, _mm_or_si128, _mm_packus_epi16, _mm_set1_epi16, _mm_set1_epi8,
_mm_setzero_si128, _mm_slli_epi16, _mm_srli_epi16, _mm_storel_epi64,
};
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::{
uint16x8_t, uint8x16_t, uint8x8_t, vaddq_u8, vandq_u16, vandq_u8, vbslq_u8, vcgeq_u8, vcleq_u8,
vdupq_n_u16, vdupq_n_u8, vld1q_u8, vminvq_u8, vmovn_u16, vorrq_u16, vorrq_u8,
vreinterpretq_u16_u8, vshlq_n_u16, vshrq_n_u16, vst1_u8, vst1q_u8, vsubq_u8,
};
/// Number of ASCII hex characters consumed per SIMD step (one 16-byte vector load).
const CHUNK_HEX: usize = 16;
/// Number of decoded bytes produced per SIMD step (two hex characters per output byte).
const CHUNK_OUT: usize = 8;
/// Decodes the ASCII hex in `src_hex` into `dst`, using 16-character SIMD kernels on
/// `x86_64` (SSE2) and `aarch64` (NEON); every other target delegates the whole input
/// to `decode_scalar`.
///
/// On the SIMD targets the work is done in two passes: all full 16-character chunks
/// and then the scalar tail are validated first, and only afterwards is anything
/// written to `dst`.
///
/// Callers are expected to pass `src_hex.len() == dst.len() * 2`; this is only checked
/// with `debug_assert_eq!`.
///
/// # Errors
///
/// Returns whatever error the chunk validators, `validate_hex_scalar`, or
/// `decode_scalar` report (for the validators: `Error::InvalidByte` with an index
/// relative to the start of `src_hex`).
///
/// On success returns `Ok(dst.len())`.
pub(crate) fn decode_to_slice_simd(src_hex: &[u8], dst: &mut [u8]) -> Result<usize, Error> {
    let out_len = dst.len();
    debug_assert_eq!(src_hex.len(), out_len * 2);

    // Targets without a SIMD kernel hand the entire input to the scalar decoder.
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        return decode_scalar(src_hex, dst);
    }

    let full_chunks = out_len / CHUNK_OUT;
    let simd_hex_len = full_chunks * CHUNK_HEX;
    let simd_out_len = full_chunks * CHUNK_OUT;

    // Pass 1a: validate every full 16-character chunk with the vector kernel.
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        for chunk_idx in 0..full_chunks {
            let hex_off = chunk_idx * CHUNK_HEX;
            let chunk = &src_hex[hex_off..hex_off + CHUNK_HEX];
            // SAFETY: `chunk` is exactly `CHUNK_HEX` (16) bytes long; SSE2 is part of
            // the x86_64 baseline instruction set.
            #[cfg(target_arch = "x86_64")]
            unsafe {
                validate_chunk16_sse2(chunk, hex_off)?;
            }
            // SAFETY: `chunk` is exactly `CHUNK_HEX` (16) bytes long; NEON is assumed
            // to be available on the aarch64 targets this crate is built for (no
            // runtime detection is performed here).
            #[cfg(target_arch = "aarch64")]
            unsafe {
                validate_chunk16_neon(chunk, hex_off)?;
            }
        }
    }

    // Pass 1b: validate whatever did not fill a whole chunk.
    let tail_hex = &src_hex[simd_hex_len..];
    if !tail_hex.is_empty() {
        validate_hex_scalar(tail_hex, simd_hex_len)?;
    }

    // Pass 2a: decode the full chunks; the input is already known to be valid hex.
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        for chunk_idx in 0..full_chunks {
            let hex_off = chunk_idx * CHUNK_HEX;
            let out_off = chunk_idx * CHUNK_OUT;
            let chunk_in = &src_hex[hex_off..hex_off + CHUNK_HEX];
            let chunk_out = &mut dst[out_off..out_off + CHUNK_OUT];
            // SAFETY: `chunk_in` is 16 bytes and `chunk_out` is 8 bytes, matching the
            // kernel's load and store widths; SSE2 is part of the x86_64 baseline.
            #[cfg(target_arch = "x86_64")]
            unsafe {
                decode_chunk16_sse2(chunk_in, chunk_out);
            }
            // SAFETY: `chunk_in` is 16 bytes and `chunk_out` is 8 bytes, matching the
            // kernel's load and store widths; NEON availability is assumed as above.
            #[cfg(target_arch = "aarch64")]
            unsafe {
                decode_chunk16_neon(chunk_in, chunk_out);
            }
        }
    }

    // Pass 2b: decode the tail with the scalar decoder.
    if !tail_hex.is_empty() {
        decode_scalar(tail_hex, &mut dst[simd_out_len..])?;
    }

    Ok(out_len)
}
/// Returns `true` when `b` is an ASCII hexadecimal digit: `0-9`, `a-f`, or `A-F`.
#[inline]
fn is_hex_ascii(b: u8) -> bool {
    matches!(b, b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F')
}
fn validate_hex_scalar(src_hex: &[u8], hex_base: usize) -> Result<(), Error> {
for (i, &b) in src_hex.iter().enumerate() {
if !is_hex_ascii(b) {
return Err(Error::InvalidByte {
index: hex_base + i,
byte: b,
});
}
}
Ok(())
}
/// Validates with SSE2 that the 16 bytes of `src16` are all ASCII hex digits.
///
/// `hex_base` is added to the lane number of the offending byte so the reported
/// index is expressed in the caller's coordinate system.
///
/// # Errors
///
/// Returns `Error::InvalidByte` for the lowest-numbered lane that is not one of
/// `0-9`, `a-f`, or `A-F`.
///
/// # Safety
///
/// `src16` must provide at least 16 readable bytes (only checked by
/// `debug_assert_eq!`), and the executing CPU must support SSE2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn validate_chunk16_sse2(src16: &[u8], hex_base: usize) -> Result<(), Error> {
    // `_mm_movemask_epi8` on a 128-bit vector fills only bits 0..=15 of the result
    // and zeroes the rest, so "every lane valid" is 0xFFFF. It can never equal -1.
    const ALL_LANES_VALID: i32 = 0xFFFF;

    debug_assert_eq!(src16.len(), 16);
    #[allow(clippy::cast_ptr_alignment)]
    let v = _mm_loadu_si128(src16.as_ptr().cast::<__m128i>());

    // The compares are signed. That is sufficient here: bytes >= 0x80 are negative
    // and therefore fail both "greater than" tests below.
    let ge_0 = _mm_cmpgt_epi8(v, _mm_set1_epi8((b'0' - 1).cast_signed()));
    let le_9 = _mm_cmplt_epi8(v, _mm_set1_epi8((b'9' + 1).cast_signed()));
    let is_digit = _mm_and_si128(ge_0, le_9);

    // OR-ing in 0x20 folds `A-F` onto `a-f` so a single range test covers both cases.
    let lower = _mm_or_si128(v, _mm_set1_epi8(0x20u8.cast_signed()));
    let ge_a = _mm_cmpgt_epi8(lower, _mm_set1_epi8((b'a' - 1).cast_signed()));
    let le_f = _mm_cmplt_epi8(lower, _mm_set1_epi8((b'f' + 1).cast_signed()));
    let is_alpha = _mm_and_si128(ge_a, le_f);

    let valid = _mm_or_si128(is_digit, is_alpha);
    let mask = _mm_movemask_epi8(valid);
    if mask != ALL_LANES_VALID {
        // At least one of the low 16 bits is clear, so the first set bit of the
        // inverted mask is the first invalid lane and is always < 16.
        let bad_lane = (!mask.cast_unsigned()).trailing_zeros() as usize;
        return Err(Error::InvalidByte {
            index: hex_base + bad_lane,
            byte: src16[bad_lane],
        });
    }
    Ok(())
}
/// Decodes 16 ASCII hex characters from `src16` into 8 bytes written to `dst8`, using SSE2.
///
/// No validation is performed; the input is expected to have been checked already.
///
/// # Safety
///
/// `src16` must provide at least 16 readable bytes and `dst8` at least 8 writable
/// bytes (both only checked by `debug_assert_eq!`), and the executing CPU must
/// support SSE2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn decode_chunk16_sse2(src16: &[u8], dst8: &mut [u8]) {
    debug_assert_eq!(src16.len(), 16);
    debug_assert_eq!(dst8.len(), 8);

    let case_bit = _mm_set1_epi8(0x20u8.cast_signed());
    let below_a = _mm_set1_epi8((b'a' - 1).cast_signed());
    let above_f = _mm_set1_epi8((b'f' + 1).cast_signed());
    let nibble_mask = _mm_set1_epi8(0x0Fu8.cast_signed());
    let letter_bias = _mm_set1_epi8(9i8);
    let even_byte_mask = _mm_set1_epi16(0x00FFu16.cast_signed());

    #[allow(clippy::cast_ptr_alignment)]
    let chars = _mm_loadu_si128(src16.as_ptr().cast::<__m128i>());

    // Lanes holding `a-f` / `A-F` (case folded by OR-ing in 0x20).
    let folded = _mm_or_si128(chars, case_bit);
    let letter_lanes = _mm_and_si128(
        _mm_cmpgt_epi8(folded, below_a),
        _mm_cmplt_epi8(folded, above_f),
    );

    // Low four bits give the value for digits; letters need +9 ('a' & 0x0F == 1 -> 10).
    let values = _mm_add_epi8(
        _mm_and_si128(chars, nibble_mask),
        _mm_and_si128(letter_lanes, letter_bias),
    );

    // View the lanes as 16-bit pairs: the low byte is the first character (high
    // nibble of the output), the high byte is the second character (low nibble).
    let first_char = _mm_and_si128(values, even_byte_mask);
    let second_char = _mm_srli_epi16(values, 8);
    let combined = _mm_or_si128(_mm_slli_epi16(first_char, 4), second_char);

    // Narrow the eight 16-bit results to bytes and store the low 64 bits.
    let narrowed = _mm_packus_epi16(combined, _mm_setzero_si128());
    #[allow(clippy::cast_ptr_alignment)]
    _mm_storel_epi64(dst8.as_mut_ptr().cast::<__m128i>(), narrowed);
}
/// Validates with NEON that the 16 bytes of `src16` are all ASCII hex digits.
///
/// `hex_base` is added to the lane number of the offending byte so the reported
/// index is expressed in the caller's coordinate system.
///
/// # Errors
///
/// Returns `Error::InvalidByte` for the lowest-numbered lane that is not one of
/// `0-9`, `a-f`, or `A-F`.
///
/// # Safety
///
/// `src16` must provide at least 16 readable bytes (only checked by
/// `debug_assert_eq!`), and the executing CPU must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn validate_chunk16_neon(src16: &[u8], hex_base: usize) -> Result<(), Error> {
    debug_assert_eq!(src16.len(), 16);
    let chars: uint8x16_t = vld1q_u8(src16.as_ptr());

    // Per-lane 0xFF / 0x00 masks for "is a decimal digit" and "is a-f after case folding".
    let digit_lanes = vandq_u8(
        vcgeq_u8(chars, vdupq_n_u8(b'0')),
        vcleq_u8(chars, vdupq_n_u8(b'9')),
    );
    let folded = vorrq_u8(chars, vdupq_n_u8(0x20));
    let letter_lanes = vandq_u8(
        vcgeq_u8(folded, vdupq_n_u8(b'a')),
        vcleq_u8(folded, vdupq_n_u8(b'f')),
    );
    let ok_lanes = vorrq_u8(digit_lanes, letter_lanes);

    // Fast path: the horizontal minimum is 0xFF only when every lane is 0xFF.
    if vminvq_u8(ok_lanes) == 0xFF {
        return Ok(());
    }

    // Slow path: spill the mask and locate the first rejected lane.
    let mut lane_flags = [0u8; 16];
    vst1q_u8(lane_flags.as_mut_ptr(), ok_lanes);
    match lane_flags.iter().position(|&flag| flag != 0xFF) {
        Some(lane) => Err(Error::InvalidByte {
            index: hex_base + lane,
            byte: src16[lane],
        }),
        None => {
            unreachable!("NEON validate invariant violated: min_lane != 0xFF but all lanes == 0xFF")
        }
    }
}
/// Decodes 16 ASCII hex characters from `src16` into 8 bytes written to `dst8`, using NEON.
///
/// No validation is performed; the input is expected to have been checked already.
///
/// # Safety
///
/// `src16` must provide at least 16 readable bytes and `dst8` at least 8 writable
/// bytes (both only checked by `debug_assert_eq!`), and the executing CPU must
/// support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn decode_chunk16_neon(src16: &[u8], dst8: &mut [u8]) {
    debug_assert_eq!(src16.len(), 16);
    debug_assert_eq!(dst8.len(), 8);

    let chars: uint8x16_t = vld1q_u8(src16.as_ptr());

    // Lanes that hold a decimal digit; every other lane is treated as a letter.
    let digit_lanes: uint8x16_t = vandq_u8(
        vcgeq_u8(chars, vdupq_n_u8(b'0')),
        vcleq_u8(chars, vdupq_n_u8(b'9')),
    );

    // Candidate values for both interpretations, then a per-lane select.
    let as_digit: uint8x16_t = vsubq_u8(chars, vdupq_n_u8(b'0'));
    let folded: uint8x16_t = vorrq_u8(chars, vdupq_n_u8(0x20));
    let as_letter: uint8x16_t = vaddq_u8(vsubq_u8(folded, vdupq_n_u8(b'a')), vdupq_n_u8(10));
    let values: uint8x16_t = vbslq_u8(digit_lanes, as_digit, as_letter);

    // View the lanes as 16-bit pairs: the low byte is the first character (high
    // nibble of the output), the high byte is the second character (low nibble).
    let pairs: uint16x8_t = vreinterpretq_u16_u8(values);
    let first_char: uint16x8_t = vandq_u16(pairs, vdupq_n_u16(0x00FF));
    let second_char: uint16x8_t = vshrq_n_u16(pairs, 8);
    let combined: uint16x8_t = vorrq_u16(vshlq_n_u16(first_char, 4), second_char);

    // Narrow the eight 16-bit results to bytes and store them.
    let narrowed: uint8x8_t = vmovn_u16(combined);
    vst1_u8(dst8.as_mut_ptr(), narrowed);
}
// Unit tests are kept in `simd/tests.rs` and are compiled only for test builds
// that have the `simd` feature enabled.
#[cfg(all(test, feature = "simd"))]
#[path = "simd/tests.rs"]
mod tests;