#![allow(unsafe_code, reason = "SIMD")]
#![allow(unsafe_op_in_unsafe_fn, reason = "SIMD")]
use core::arch::aarch64::*;
use core::mem::MaybeUninit;
use crate::backend::generic::{decode_generic_unchecked, encode_generic_unchecked};
use crate::error::InvalidInput;
use crate::util::digits16;
/// Hex-encodes `src` into `dst` with NEON, producing one two-digit pair per input byte.
///
/// Input is consumed one 16-byte vector at a time. Each vector is split into its high and low
/// nibbles, both halves are translated through the table returned by `digits16::<UPPER>()`, and
/// the two digit vectors are interleaved so every byte becomes `[high digit, low digit]`.
/// Whatever remains (fewer than 16 bytes) is forwarded to [`encode_generic_unchecked`].
///
/// # Safety
///
/// - The executing CPU must support NEON.
/// - `dst` must provide at least one element for every byte of `src` handled by the vector loop:
///   neither the 32-byte stores nor the `get_unchecked_mut` re-slicing are bounds-checked.
/// - NOTE(review): assumes `digits16::<UPPER>()` exposes at least 16 readable bytes, and that
///   `encode_generic_unchecked` carries its own length contract for the tail — confirm both.
#[target_feature(enable = "neon")]
pub(crate) unsafe fn encode_neon_unchecked<const UPPER: bool>(
    mut src: &[u8],
    mut dst: &mut [[MaybeUninit<u8>; 2]],
) {
    const LANES: usize = size_of::<uint8x16_t>();

    // Only materialise the vector constants when at least one full vector will be processed.
    if src.len() >= LANES {
        let low_nibble_mask = vdupq_n_u8(0b_0000_1111);
        let digit_table = vld1q_u8(digits16::<UPPER>().as_ptr());

        loop {
            let bytes: uint8x16_t = vld1q_u8(src.as_ptr());

            // Table lookups turn each nibble value (0..=15) into its ASCII digit.
            let hi_digits = vqtbl1q_u8(digit_table, vshrq_n_u8(bytes, 4));
            let lo_digits = vqtbl1q_u8(digit_table, vandq_u8(bytes, low_nibble_mask));

            // Interleave to h0 l0 h1 l1 ... and write all 32 output bytes at once.
            vst1q_u8_x2(dst.as_mut_ptr().cast(), vzipq_u8(hi_digits, lo_digits));

            src = &src[LANES..];
            dst = dst.get_unchecked_mut(LANES..);

            if src.len() < LANES {
                break;
            }
        }
    }

    encode_generic_unchecked::<UPPER>(src, dst);
}
/// Decodes hex digit pairs from `src` into bytes in `dst` using NEON.
///
/// Full batches of 16 pairs (32 input bytes) are processed first, followed by at most one half
/// batch of 8 pairs (16 input bytes). Any pairs left after that are forwarded to
/// [`decode_generic_unchecked`]. Both upper- and lower-case digits are accepted by the vector
/// paths.
///
/// # Errors
///
/// Returns [`InvalidInput`] as soon as a vector-processed batch contains a byte that is not an
/// ASCII hex digit. Batches preceding the offending one have already been written to `dst` at
/// that point. For the tail, the result of the generic fallback is returned unchanged.
///
/// # Safety
///
/// - The executing CPU must support NEON.
/// - `dst` must provide at least one element for every pair of `src` handled by the vector
///   paths: neither the vector stores nor the `get_unchecked_mut` re-slicing are bounds-checked.
/// - NOTE(review): `decode_generic_unchecked` presumably carries its own length contract for the
///   tail — confirm against its definition.
#[target_feature(enable = "neon")]
pub(crate) unsafe fn decode_neon_unchecked(
    mut src: &[[u8; 2]],
    mut dst: &mut [MaybeUninit<u8>],
) -> Result<(), InvalidInput> {
    const BATCH: usize = size_of::<uint8x16_t>();
    const TRAILING_BATCH: usize = BATCH / 2;

    // Only materialise the vector constants when at least one vector path will run.
    if src.len() >= TRAILING_BATCH {
        let n_c6 = vdupq_n_u8(0xFF_u8 - b'9');
        let n_06 = vdupq_n_u8(0x06);
        let n_f0 = vdupq_n_u8(0xF0);
        let n_df = vdupq_n_u8(0xDF);
        let u_a = vdupq_n_u8(b'A');
        let n_0a = vdupq_n_u8(0x0A);

        // Maps 16 ASCII bytes to their nibble values. Every lane holding a valid hex digit ends
        // up in 0..=0x0F; every other lane ends up above 0x0F.
        let to_nibbles = |ascii: uint8x16_t| -> uint8x16_t {
            // '0'..='9' -> 0..=9. The wrapping add pushes digits to 0xF6..=0xFF, the saturating
            // subtract collapses bytes above '9' (which wrapped around), and the final wrapping
            // subtract leaves every non-digit at 0x10 or more.
            let digit = vsubq_u8(vqsubq_u8(vaddq_u8(ascii, n_c6), n_06), n_f0);
            // 'A'..='F' and 'a'..='f' -> 10..=15. Clearing bit 5 folds lower case onto upper
            // case; the saturating add keeps wrapped-around lanes from re-entering the range.
            let alpha = vqaddq_u8(vsubq_u8(vandq_u8(ascii, n_df), u_a), n_0a);
            // A lane is valid when either interpretation is, so keep the smaller value.
            vminq_u8(digit, alpha)
        };

        while src.len() >= BATCH {
            let uint8x16x2_t(chunk1, chunk2) = vld1q_u8_x2(src.as_ptr().cast::<u8>());
            let n1 = to_nibbles(chunk1);
            let n2 = to_nibbles(chunk2);

            // One horizontal reduction over the lane-wise maximum covers both vectors.
            if vmaxvq_u8(vmaxq_u8(n1, n2)) > 0x0F {
                return Err(InvalidInput);
            }

            // De-interleave: even lanes are the high nibbles, odd lanes the low nibbles.
            let uint8x16x2_t(hi, lo) = vuzpq_u8(n1, n2);
            let bytes = vorrq_u8(vshlq_n_u8(hi, 4), lo);
            vst1q_u8(dst.as_mut_ptr().cast::<u8>(), bytes);

            src = &src[BATCH..];
            dst = dst.get_unchecked_mut(BATCH..);
        }

        if src.len() >= TRAILING_BATCH {
            let n = to_nibbles(vld1q_u8(src.as_ptr().cast::<u8>()));
            if vmaxvq_u8(n) > 0x0F {
                return Err(InvalidInput);
            }

            // Only the low halves of the de-interleaved vectors carry the 8 results.
            let uint8x16x2_t(hi, lo) = vuzpq_u8(n, n);
            let bytes = vorr_u8(vshl_n_u8(vget_low_u8(hi), 4), vget_low_u8(lo));
            vst1_u8(dst.as_mut_ptr().cast::<u8>(), bytes);

            src = &src[TRAILING_BATCH..];
            dst = dst.get_unchecked_mut(TRAILING_BATCH..);
        }
    }

    decode_generic_unchecked::<false>(src, dst)
}
#[cfg(test)]
mod smoking {
    use alloc::string::String;
    use alloc::vec;
    use alloc::vec::Vec;
    use core::mem::MaybeUninit;
    use core::{slice, str};

    use super::*;
    use crate::util::{DIGITS_LOWER_16, DIGITS_UPPER_16};

    /// Views an encoder output buffer as initialized digit pairs.
    ///
    /// # Safety
    ///
    /// Every byte of `pairs` must have been written.
    unsafe fn assume_pairs_init(pairs: &[[MaybeUninit<u8>; 2]]) -> &[[u8; 2]] {
        slice::from_raw_parts(pairs.as_ptr().cast::<[u8; 2]>(), pairs.len())
    }

    /// Round-trips one input through an encoder and any number of decoders.
    ///
    /// The input is encoded in lower and upper case and compared against a reference built from
    /// the digit tables; each listed decoder must then turn both encodings back into the input.
    macro_rules! test {
        (
            Encode = $encode_f:ident;
            Decode = $($decode_f:ident),*;
            Case = $i:expr
        ) => {{
            let input = $i;

            // Reference encodings, built one nibble at a time from the digit tables.
            let expected_lower = input
                .iter()
                .flat_map(|b| {
                    [
                        DIGITS_LOWER_16[(*b >> 4) as usize] as char,
                        DIGITS_LOWER_16[(*b & 0b1111) as usize] as char,
                    ]
                })
                .collect::<String>();
            let expected_upper = input
                .iter()
                .flat_map(|b| {
                    [
                        DIGITS_UPPER_16[(*b >> 4) as usize] as char,
                        DIGITS_UPPER_16[(*b & 0b1111) as usize] as char,
                    ]
                })
                .collect::<String>();

            let mut raw_lower = vec![[MaybeUninit::<u8>::uninit(); 2]; input.len()];
            let mut raw_upper = vec![[MaybeUninit::<u8>::uninit(); 2]; input.len()];
            unsafe {
                $encode_f::<false>(input, &mut raw_lower);
                $encode_f::<true>(input, &mut raw_upper);
            }

            // The encoder is expected to have filled both buffers completely.
            let encoded_lower = unsafe { assume_pairs_init(&raw_lower) };
            let encoded_upper = unsafe { assume_pairs_init(&raw_upper) };

            assert_eq!(
                encoded_lower.as_flattened(),
                expected_lower.as_bytes(),
                "Encode error, expect \"{expected_lower}\", got \"{}\" ({:?})",
                str::from_utf8(encoded_lower.as_flattened()).unwrap_or("<invalid utf-8>"),
                encoded_lower.as_flattened()
            );
            assert_eq!(
                encoded_upper.as_flattened(),
                expected_upper.as_bytes(),
                "Encode error, expect \"{expected_upper}\", got \"{}\" ({:?})",
                str::from_utf8(encoded_upper.as_flattened()).unwrap_or("<invalid utf-8>"),
                encoded_upper.as_flattened()
            );

            $({
                let mut decoded_lower = vec![MaybeUninit::<u8>::uninit(); input.len()];
                let mut decoded_upper = vec![MaybeUninit::<u8>::uninit(); input.len()];
                unsafe {
                    $decode_f(encoded_lower, &mut decoded_lower).unwrap();
                    $decode_f(encoded_upper, &mut decoded_upper).unwrap();

                    assert_eq!(
                        decoded_lower.assume_init_ref(),
                        input,
                        "Decode error for {}, expect {:?}, got {:?}",
                        stringify!($decode_f),
                        input,
                        decoded_lower.assume_init_ref()
                    );
                    assert_eq!(
                        decoded_upper.assume_init_ref(),
                        input,
                        "Decode error for {}, expect {:?}, got {:?}",
                        stringify!($decode_f),
                        input,
                        decoded_upper.assume_init_ref()
                    );
                }
            })*
        }};
    }

    /// Round-trips inputs whose lengths sit on and around the 16- and 32-byte boundaries.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_neon() {
        const CASE: &[u8; 33] = &[
            0x62, 0xBE, 0x66, 0xE0, 0x1C, 0x1E, 0xFB, 0x43, 0x16, 0xA0, 0x9F, 0x8A, 0xE4, 0x93,
            0xE3, 0x7F, 0x23, 0x9F, 0x0D, 0xEF, 0x94, 0x25, 0xE0, 0x60, 0x62, 0xBA, 0x10, 0xB2,
            0x7B, 0xB6, 0x2B, 0xFB, 0x44,
        ];

        for len in [0, 15, 16, 17, 31, 32, 33] {
            test! {
                Encode = encode_neon_unchecked;
                Decode = decode_neon_unchecked;
                Case = &CASE[..len]
            }
        }
    }

    /// Plants every possible byte value into otherwise valid input and checks that decoding
    /// succeeds exactly when that byte is an ASCII hex digit.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_validation() {
        for l in [15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129] {
            for c in 0u8..=255 {
                // `l` pairs of 'a', with the byte at index `l` replaced by the probe value.
                let mut bytes = vec![b'a'; l * 2];
                bytes[l] = c;
                let bytes = unsafe { bytes.as_chunks_unchecked() };

                let outcome = unsafe {
                    decode_neon_unchecked(bytes, Vec::with_capacity(l).spare_capacity_mut())
                };

                assert!(
                    outcome.is_ok() == c.is_ascii_hexdigit(),
                    "neon validation failed for byte {c} (l={l})"
                );
            }
        }
    }
}