#![allow(unsafe_code, reason = "XXX")]
#![allow(unsafe_op_in_unsafe_fn, reason = "XXX")]
use core::mem::MaybeUninit;
use core::simd::prelude::*;
use core::slice;
use crate::backend::generic::{
decode_generic_unchecked, encode_generic_unchecked, validate_generic,
};
use crate::error::InvalidInput;
use crate::util::digits16;
/// Hex-encodes `src` into `dst` with 128-bit portable SIMD, 16 input bytes per iteration.
/// The remaining (< 16 byte) tail is handed to [`encode_generic_unchecked`].
///
/// `UPPER` selects the digit table returned by [`digits16`] (upper- vs lower-case digits).
///
/// # Safety
///
/// `dst` must have at least `src.len()` elements. Each input byte produces one `[u8; 2]`
/// output pair, and the SIMD loop writes through raw pointers without bounds checks.
/// NOTE(review): the tail additionally has whatever contract `encode_generic_unchecked`
/// documents — confirm there.
pub(crate) unsafe fn encode_simd128_unchecked<const UPPER: bool>(
    mut src: &[u8],
    mut dst: &mut [[MaybeUninit<u8>; 2]],
) {
    const BATCH: usize = size_of::<Simd<u8, 16>>();

    if src.len() >= BATCH {
        // Selects the low nibble of every lane.
        let m = Simd::splat(0b_0000_1111);
        // 16-entry table mapping a nibble value (0..=15) to its ASCII hex digit.
        let lut = Simd::from_array(*digits16::<UPPER>());

        while src.len() >= BATCH {
            let chunk = Simd::<u8, 16>::from_slice(src);
            let hi = lut.swizzle_dyn(chunk >> 4);
            let lo = lut.swizzle_dyn(chunk & m);
            // Interleave so each input byte becomes `[hi digit, lo digit]`.
            // `out0` holds the pairs for input bytes 0..8, `out1` those for bytes 8..16.
            let (out0, out1) = Simd::<u8, 16>::interleave(hi, lo);

            let out = dst.as_mut_ptr().cast::<Simd<u8, 16>>();
            // SAFETY: the caller guarantees `dst.len() >= src.len() >= BATCH`, so `dst` has at
            // least 16 pairs (32 bytes) of writable memory. `write_unaligned` stores into the
            // possibly uninitialized, unaligned destination without ever forming a `&mut [u8]`
            // to uninitialized memory. `slice::from_raw_parts_mut` would require the bytes to
            // already be initialized.
            unsafe {
                out.write_unaligned(out0);
                out.add(1).write_unaligned(out1);
            }

            src = &src[BATCH..];
            // SAFETY: `dst` has at least `BATCH` elements (see above).
            dst = unsafe { dst.get_unchecked_mut(BATCH..) };
        }
    }

    encode_generic_unchecked::<UPPER>(src, dst);
}
/// Checks that every byte of every pair in `src` is an ASCII hex digit (`0-9`, `a-f`, `A-F`).
///
/// Complete groups of 8 pairs (16 bytes) are checked lane-wise with 128-bit portable SIMD.
/// The leftover pairs are delegated to [`validate_generic`].
///
/// # Errors
///
/// Returns [`InvalidInput`] as soon as a SIMD group contains a non-hex byte. Otherwise it
/// returns whatever `validate_generic` reports for the leftover pairs.
pub(crate) unsafe fn validate_simd128(src: &[[u8; 2]]) -> Result<(), InvalidInput> {
    const PAIRS_PER_GROUP: usize = size_of::<Simd<u8, 16>>() / 2;

    let digit_base = Simd::<u8, 16>::splat(b'0');
    let digit_span = Simd::splat(9 + 1);
    let letter_base = Simd::splat(b'a');
    let letter_span = Simd::splat(5 + 1);
    // Setting bit 5 folds 'A'..='F' onto 'a'..='f'.
    let lower_case_bit = Simd::splat(0b_0010_0000);

    let mut groups = src.chunks_exact(PAIRS_PER_GROUP);
    for group in &mut groups {
        let lanes: Simd<u8, 16> = Simd::from_slice(group.as_flattened());
        // Wrapping subtraction turns each range test into one unsigned compare: bytes below
        // the base wrap around to large values and fail the `<` check.
        let in_digits = (lanes - digit_base).simd_lt(digit_span);
        let in_letters = ((lanes | lower_case_bit) - letter_base).simd_lt(letter_span);
        if !(in_digits | in_letters).all() {
            return Err(InvalidInput);
        }
    }

    validate_generic(groups.remainder())
}
/// Decodes pairs of ASCII hex digits in `src` into bytes in `dst` with 128-bit portable SIMD.
/// Each iteration consumes 16 pairs (32 input bytes) and writes 16 output bytes. The remaining
/// tail is handed to [`decode_generic_unchecked`].
///
/// The SIMD loop does not validate its input. Any byte outside `0-9`/`a-f`/`A-F` yields an
/// unspecified output byte, so validate first (e.g. with `validate_simd128`).
///
/// # Safety
///
/// `dst` must have at least `src.len()` elements (one output byte per input pair). The SIMD
/// loop writes through raw pointers without bounds checks.
/// NOTE(review): the tail additionally has whatever contract `decode_generic_unchecked`
/// documents — confirm there.
pub(crate) unsafe fn decode_simd128_unchecked(
    mut src: &[[u8; 2]],
    mut dst: &mut [MaybeUninit<u8>],
) {
    const BATCH: usize = size_of::<Simd<u8, 16>>();

    /// Converts ASCII hex digits to their nibble values, lane-wise.
    ///
    /// Letters (`0x41..=0x46`, `0x61..=0x66`) have bit 6 set and digits (`0x30..=0x39`) do
    /// not, so `value >> 6` is 1 for letters and 0 for digits. Adding 9 times that to the low
    /// nibble maps `'a'`/`'A'` (low nibble 1) to 10 through `'f'`/`'F'` to 15, and
    /// `'0'..='9'` to 0..=9.
    #[inline]
    fn unhex(value: Simd<u8, 16>) -> Simd<u8, 16> {
        let sr6 = value >> 6;
        let and15 = value & Simd::splat(0b_0000_1111);
        // `sr6 * 9`, computed as `(sr6 << 3) + sr6`.
        let mul = (sr6 << 3) + sr6;
        mul + and15
    }

    /// Packs a high and a low nibble into one byte, lane-wise.
    #[inline]
    fn nib2byte(hi: Simd<u8, 16>, lo: Simd<u8, 16>) -> Simd<u8, 16> {
        (hi << 4) | lo
    }

    while src.len() >= BATCH {
        // SAFETY: `src` has at least `BATCH` pairs, i.e. `2 * BATCH` initialized bytes, and
        // `[u8; 2]` has the same alignment as `u8`.
        let bytes = unsafe { slice::from_raw_parts(src.as_ptr().cast::<u8>(), 2 * BATCH) };
        let (first, second) = bytes.split_at(BATCH);
        let chunk0 = Simd::<u8, 16>::from_slice(first);
        let chunk1 = Simd::<u8, 16>::from_slice(second);

        // Even input bytes are the high digits, odd ones the low digits.
        let hi = simd_swizzle!(
            chunk0,
            chunk1,
            [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30]
        );
        let lo = simd_swizzle!(
            chunk0,
            chunk1,
            [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31]
        );
        let out = nib2byte(unhex(hi), unhex(lo));

        // SAFETY: the caller guarantees `dst.len() >= src.len() >= BATCH`, so 16 bytes are
        // writable. `write_unaligned` stores into possibly uninitialized, unaligned memory
        // without forming a `&mut [u8]` over it. `slice::from_raw_parts_mut` would require
        // the bytes to already be initialized.
        unsafe { dst.as_mut_ptr().cast::<Simd<u8, 16>>().write_unaligned(out) };

        src = &src[BATCH..];
        // SAFETY: `dst` has at least `BATCH` elements (see above).
        dst = unsafe { dst.get_unchecked_mut(BATCH..) };
    }

    // NOTE(review): the return value is discarded exactly as before. Confirm what the `true`
    // const argument selects in `decode_generic_unchecked` and whether ignoring its result is
    // intended.
    let _ = unsafe { decode_generic_unchecked::<true>(src, dst) };
}
#[cfg(test)]
mod smoking {
    // Smoke tests for the 128-bit SIMD backend: round-trips encode -> validate -> decode
    // against a scalar reference, and checks validation for every possible byte value.
    use alloc::string::String;
    use alloc::vec;
    use core::mem::MaybeUninit;
    use core::{slice, str};
    use super::{decode_simd128_unchecked, encode_simd128_unchecked, validate_simd128};
    use crate::util::{DIGITS_LOWER_16, DIGITS_UPPER_16};

    /// Encodes `Case` in both lower and upper case and compares with a scalar reference.
    /// Then runs each `Validate` function on both outputs and each `Decode` function back to
    /// the original input.
    macro_rules! test {
        (
            Encode = $encode_f:ident;
            Validate = $($validate_f:ident),*;
            Decode = $($decode_f:ident),*;
            Case = $i:expr
        ) => {{
            let input = $i;
            let expected_lower = input
                .iter()
                .flat_map(|b| [
                    DIGITS_LOWER_16[(*b >> 4) as usize] as char,
                    DIGITS_LOWER_16[(*b & 0b1111) as usize] as char,
                ])
                .collect::<String>();
            let expected_upper = input
                .iter()
                .flat_map(|b| [
                    DIGITS_UPPER_16[(*b >> 4) as usize] as char,
                    DIGITS_UPPER_16[(*b & 0b1111) as usize] as char,
                ])
                .collect::<String>();
            let mut output_lower = vec![[MaybeUninit::<u8>::uninit(); 2]; input.len()];
            let mut output_upper = vec![[MaybeUninit::<u8>::uninit(); 2]; input.len()];
            unsafe {
                $encode_f::<false>(input, &mut output_lower);
                $encode_f::<true>(input, &mut output_upper);
            }
            // The encoders have written every pair, so the buffers may now be read as `u8`.
            let output_lower = unsafe {
                slice::from_raw_parts(
                    output_lower.as_ptr().cast::<[u8; 2]>(),
                    output_lower.len(),
                )
            };
            let output_upper = unsafe {
                slice::from_raw_parts(
                    output_upper.as_ptr().cast::<[u8; 2]>(),
                    output_upper.len(),
                )
            };
            assert_eq!(
                output_lower.as_flattened(),
                expected_lower.as_bytes(),
                "Encode error, expect \"{expected_lower}\", got \"{}\" ({:?})",
                str::from_utf8(output_lower.as_flattened()).unwrap_or("<invalid utf-8>"),
                output_lower.as_flattened()
            );
            assert_eq!(
                output_upper.as_flattened(),
                expected_upper.as_bytes(),
                "Encode error, expect \"{expected_upper}\", got \"{}\" ({:?})",
                str::from_utf8(output_upper.as_flattened()).unwrap_or("<invalid utf-8>"),
                output_upper.as_flattened()
            );
            $(
                unsafe {
                    $validate_f(output_lower)
                        .unwrap_or_else(|_| panic!("validation failed for {}", stringify!($validate_f)));
                    $validate_f(output_upper)
                        .unwrap_or_else(|_| panic!("validation failed for {}", stringify!($validate_f)));
                }
            )*
            $({
                let mut decoded_lower = vec![MaybeUninit::<u8>::uninit(); input.len()];
                let mut decoded_upper = vec![MaybeUninit::<u8>::uninit(); input.len()];
                unsafe {
                    $decode_f(output_lower, &mut decoded_lower);
                    $decode_f(output_upper, &mut decoded_upper);
                    assert_eq!(
                        decoded_lower.assume_init_ref(),
                        input,
                        "Decode error for {}, expect {:?}, got {:?}",
                        stringify!($decode_f),
                        input,
                        decoded_lower.assume_init_ref()
                    );
                    assert_eq!(
                        decoded_upper.assume_init_ref(),
                        input,
                        "Decode error for {}, expect {:?}, got {:?}",
                        stringify!($decode_f),
                        input,
                        decoded_upper.assume_init_ref()
                    );
                }
            })*
        }};
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_simd128() {
        const CASE: &[u8; 17] = &[
            0x12, 0x77, 0x4C, 0x16, 0x16, 0x2B, 0x99, 0x97, 0x37, 0x62, 0x24, 0x24, 0x36, 0x83,
            0xA4, 0xF1, 0xDD,
        ];
        // 33 bytes: two full SIMD iterations for both encode (16 bytes per iteration) and
        // decode (16 pairs per iteration), plus a one-element tail for the generic fallback.
        // Includes the extreme values 0x00 and 0xFF.
        const LONG: &[u8; 33] = &[
            0x12, 0x77, 0x4C, 0x16, 0x16, 0x2B, 0x99, 0x97, 0x37, 0x62, 0x24, 0x24, 0x36, 0x83,
            0xA4, 0xF1, 0xDD, 0x00, 0xFF, 0x5A, 0xA5, 0x0F, 0xF0, 0x3C, 0xC3, 0x81, 0x7E, 0xBE,
            0xEF, 0xCA, 0xFE, 0x09, 0x90,
        ];
        test! {
            Encode = encode_simd128_unchecked;
            Validate = validate_simd128;
            Decode = decode_simd128_unchecked;
            Case = &[]
        }
        test! {
            Encode = encode_simd128_unchecked;
            Validate = validate_simd128;
            Decode = decode_simd128_unchecked;
            Case = &CASE[..15]
        }
        test! {
            Encode = encode_simd128_unchecked;
            Validate = validate_simd128;
            Decode = decode_simd128_unchecked;
            Case = &CASE[..16]
        }
        test! {
            Encode = encode_simd128_unchecked;
            Validate = validate_simd128;
            Decode = decode_simd128_unchecked;
            Case = &CASE[..17]
        }
        test! {
            Encode = encode_simd128_unchecked;
            Validate = validate_simd128;
            Decode = decode_simd128_unchecked;
            Case = &LONG[..]
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_validation() {
        for l in [15, 16, 17, 31, 32, 33, 63, 64, 65] {
            // Probe the first byte, a middle byte and the last byte. For lengths that are not
            // a multiple of 8 pairs, the last byte falls into the generic tail rather than a
            // SIMD group, so both code paths see every candidate byte.
            for pos in [0, l, l * 2 - 1] {
                for c in 0..=255 {
                    let mut bytes = vec![b'a'; l * 2];
                    bytes[pos] = c;
                    let bytes = unsafe { bytes.as_chunks_unchecked() };
                    if c.is_ascii_hexdigit() {
                        unsafe {
                            assert!(
                                validate_simd128(bytes).is_ok(),
                                "simd128 validation failed for char `{}` (l={l}, pos={pos})",
                                c as char
                            );
                        }
                    } else {
                        unsafe {
                            assert!(
                                validate_simd128(bytes).is_err(),
                                "simd128 validation should have failed for byte {c} (l={l}, pos={pos})"
                            );
                        }
                    }
                }
            }
        }
    }
}