#![allow(unsafe_code, reason = "SIMD")]
#![allow(unsafe_op_in_unsafe_fn, reason = "SIMD")]
#![allow(clippy::cast_lossless, reason = "SIMD")]
#![allow(clippy::cast_possible_truncation, reason = "SIMD")]
#![allow(clippy::cast_possible_wrap, reason = "SIMD")]
use core::arch::loongarch64::*;
use core::mem;
use core::mem::MaybeUninit;
use crate::backend::generic::{decode_generic_unchecked, encode_generic_unchecked};
use crate::error::InvalidInput;
use crate::util::{digits16, digits32};
/// Hex-encodes `src` into `dst` with 128-bit LSX vectors.
///
/// Every input byte becomes one `[u8; 2]` element of `dst`: the digit for its high nibble
/// followed by the digit for its low nibble, both looked up in `digits16::<UPPER>()`.
/// Complete 16-byte groups are handled by the vector loop; the bytes left over (fewer than 16)
/// are passed on to [`encode_generic_unchecked`].
///
/// # Safety
///
/// - The executing CPU must support the `lsx` target feature.
/// - `dst` must have room for at least `src.len()` elements: neither the vector stores nor the
///   re-slicing of `dst` is bounds-checked.
#[target_feature(enable = "lsx")]
pub(crate) unsafe fn encode_lsx_unchecked<const UPPER: bool>(
    mut src: &[u8],
    mut dst: &mut [[MaybeUninit<u8>; 2]],
) {
    // Input bytes consumed per vector iteration (each iteration emits twice as many digits).
    const STEP: usize = size_of::<m128i>();

    if src.len() < STEP {
        // Not even one full vector: skip the constant setup entirely.
        encode_generic_unchecked::<UPPER>(src, dst);
        return;
    }

    let low_nibble_mask = lsx_vreplgr2vr_b(0b_0000_1111);
    let table = lsx_vld::<0>(digits16::<UPPER>().as_ptr().cast());

    while let Some((block, rest)) = src.split_first_chunk::<STEP>() {
        let input = lsx_vld::<0>(block.as_ptr().cast());

        // Split every byte into its two nibbles, then turn each nibble into an ASCII digit
        // by using it as an index into the 16-entry table.
        let high_idx = lsx_vsrli_b::<4>(input);
        let low_idx = lsx_vand_v(input, low_nibble_mask);
        let high_digits = lsx_vshuf_b(table, table, high_idx);
        let low_digits = lsx_vshuf_b(table, table, low_idx);

        // Interleave so that each byte's high digit is immediately followed by its low digit;
        // the low half of the interleave covers input bytes 0..8, the high half bytes 8..16.
        let first = lsx_vilvl_b(low_digits, high_digits);
        let second = lsx_vilvh_b(low_digits, high_digits);

        let out = dst.as_mut_ptr().cast();
        lsx_vst::<0>(first, out);
        lsx_vst::<{ size_of::<m128i>() as _ }>(second, out);

        src = rest;
        dst = dst.get_unchecked_mut(STEP..);
    }

    encode_generic_unchecked::<UPPER>(src, dst);
}
/// Hex-encodes `src` into `dst` with 256-bit LASX vectors.
///
/// Every input byte becomes one `[u8; 2]` element of `dst`: the digit for its high nibble
/// followed by the digit for its low nibble, both looked up in `digits32::<UPPER>()`.
/// Complete 32-byte groups are handled by the vector loop; the bytes left over (fewer than 32)
/// are passed on to [`encode_lsx_unchecked`].
///
/// # Safety
///
/// - The executing CPU must support the `lasx` target feature (and `lsx`, which the tail
///   routine requires).
/// - `dst` must have room for at least `src.len()` elements: neither the vector stores nor the
///   re-slicing of `dst` is bounds-checked.
#[target_feature(enable = "lasx")]
pub(crate) unsafe fn encode_lasx_unchecked<const UPPER: bool>(
    mut src: &[u8],
    mut dst: &mut [[MaybeUninit<u8>; 2]],
) {
    // Input bytes consumed per vector iteration (each iteration emits twice as many digits).
    const STEP: usize = size_of::<m256i>();

    if src.len() < STEP {
        // Not even one full 256-bit vector: let the 128-bit routine deal with it.
        encode_lsx_unchecked::<UPPER>(src, dst);
        return;
    }

    let low_nibble_mask = lasx_xvreplgr2vr_b(0b_0000_1111);
    let table = lasx_xvld::<0>(digits32::<UPPER>().as_ptr().cast());

    while let Some((block, rest)) = src.split_first_chunk::<STEP>() {
        let input = lasx_xvld::<0>(block.as_ptr().cast());
        let high_idx = lasx_xvsrli_b::<4>(input);
        let low_idx = lasx_xvand_v(input, low_nibble_mask);

        // The interleaves work inside each 128-bit lane. Calling the four 8-byte quarters of
        // the input q0..q3, the low interleave yields [q0 | q2] and the high one [q1 | q3].
        let q0_q2 = lasx_xvilvl_b(low_idx, high_idx);
        let q1_q3 = lasx_xvilvh_b(low_idx, high_idx);

        // Regroup the 128-bit halves so the nibble indices are back in input order.
        let q0_q1 = lasx_xvpermi_q::<0x02>(q0_q2, q1_q3);
        let q2_q3 = lasx_xvpermi_q::<0x13>(q0_q2, q1_q3);

        // Only now translate the nibble indices into ASCII digits.
        let first = lasx_xvshuf_b(table, table, q0_q1);
        let second = lasx_xvshuf_b(table, table, q2_q3);

        let out = dst.as_mut_ptr().cast();
        lasx_xvst::<0>(first, out);
        lasx_xvst::<{ size_of::<m256i>() as _ }>(second, out);

        src = rest;
        dst = dst.get_unchecked_mut(STEP..);
    }

    encode_lsx_unchecked::<UPPER>(src, dst);
}
/// Decodes pairs of ASCII hex digits from `src` into bytes in `dst` with 128-bit LSX vectors.
///
/// Each `[u8; 2]` element of `src` (high digit, then low digit) produces one byte of `dst`.
/// The main loop decodes 16 pairs per iteration from two vector loads; one extra single-vector
/// step handles a remaining group of 8 pairs. Whatever is left after that is passed on to
/// [`decode_generic_unchecked`] (instantiated with `false`).
///
/// # Errors
///
/// Returns [`InvalidInput`] as soon as a vector-processed group contains a byte that is not
/// `0-9`, `a-f` or `A-F`. Otherwise the result of the generic decoder on the remainder is
/// returned. Groups decoded before the failure have already been written to `dst`.
///
/// # Safety
///
/// - The executing CPU must support the `lsx` target feature.
/// - `dst` must have room for at least `src.len()` bytes: neither the vector stores nor the
///   re-slicing of `dst` is bounds-checked.
#[target_feature(enable = "lsx")]
pub(crate) unsafe fn decode_lsx_unchecked(
    mut src: &[[u8; 2]],
    mut dst: &mut [MaybeUninit<u8>],
) -> Result<(), InvalidInput> {
    // Digit pairs decoded (= bytes produced) by one main-loop iteration.
    const FULL: usize = size_of::<m128i>();
    // Digit pairs decoded by the single-vector step after the main loop.
    const HALF: usize = FULL / 2;

    if src.len() < HALF {
        // Too short for any vector work: skip the constant setup entirely.
        return decode_generic_unchecked::<false>(src, dst);
    }

    let digit_bias = lsx_vreplgr2vr_b((0xFF_u8 - b'9') as i32);
    let six = lsx_vreplgr2vr_b(0x06);
    let top_nibble = lsx_vreplgr2vr_b(0xF0);
    let case_fold = lsx_vreplgr2vr_b(0xDF);
    let upper_a = lsx_vreplgr2vr_b(b'A' as i32);
    let ten = lsx_vreplgr2vr_b(0x0A);
    let nibble_max = lsx_vreplgr2vr_b(0x0F);

    while src.len() >= FULL {
        let input = src.as_ptr().cast();
        let ascii_a = lsx_vld::<0>(input);
        let ascii_b = lsx_vld::<{ size_of::<m128i>() as _ }>(input);

        // Decimal path: '0'..='9' become 0..=9, every other byte ends up above 0x0F.
        let dec_a = lsx_vsub_b(
            lsx_vssub_bu(lsx_vadd_b(ascii_a, digit_bias), six),
            top_nibble,
        );
        let dec_b = lsx_vsub_b(
            lsx_vssub_bu(lsx_vadd_b(ascii_b, digit_bias), six),
            top_nibble,
        );

        // Letter path: clearing bit 0x20 folds 'a'..='f' onto 'A'..='F', which become 10..=15;
        // every other byte ends up above 0x0F (the final add saturates).
        let alpha_a = lsx_vsadd_bu(lsx_vsub_b(lsx_vand_v(ascii_a, case_fold), upper_a), ten);
        let alpha_b = lsx_vsadd_bu(lsx_vsub_b(lsx_vand_v(ascii_b, case_fold), upper_a), ten);

        // At most one path yields a value in 0..=15, so the minimum is the nibble when the
        // byte is a hex digit and stays above 0x0F when it is not.
        let nib_a = lsx_vmin_bu(dec_a, alpha_a);
        let nib_b = lsx_vmin_bu(dec_b, alpha_b);

        let worst = lsx_vmax_bu(nib_a, nib_b);
        let out_of_range = lsx_vslt_bu(nibble_max, worst);
        if lsx_bz_v(out_of_range) == 0 {
            return Err(InvalidInput);
        }

        // Even positions hold the high nibbles, odd positions the low nibbles; gather each
        // set across both vectors and merge them into bytes.
        let highs = lsx_vpickev_b(nib_b, nib_a);
        let lows = lsx_vpickod_b(nib_b, nib_a);
        let packed = lsx_vor_v(lsx_vslli_b::<4>(highs), lows);
        lsx_vst::<0>(packed, dst.as_mut_ptr().cast());

        src = &src[FULL..];
        dst = dst.get_unchecked_mut(FULL..);
    }

    if src.len() >= HALF {
        // Same conversion on a single vector (8 pairs).
        let ascii = lsx_vld::<0>(src.as_ptr().cast());
        let dec = lsx_vsub_b(
            lsx_vssub_bu(lsx_vadd_b(ascii, digit_bias), six),
            top_nibble,
        );
        let alpha = lsx_vsadd_bu(lsx_vsub_b(lsx_vand_v(ascii, case_fold), upper_a), ten);
        let nib = lsx_vmin_bu(dec, alpha);

        let out_of_range = lsx_vslt_bu(nibble_max, nib);
        if lsx_bz_v(out_of_range) == 0 {
            return Err(InvalidInput);
        }

        let highs = lsx_vpickev_b(nib, nib);
        let lows = lsx_vpickod_b(nib, nib);
        let packed = lsx_vor_v(lsx_vslli_b::<4>(highs), lows);
        // Only the low 64 bits carry results, so store just that element.
        lsx_vstelm_d(packed, dst.as_mut_ptr().cast(), 0, 0);

        src = &src[HALF..];
        dst = dst.get_unchecked_mut(HALF..);
    }

    decode_generic_unchecked::<false>(src, dst)
}
/// Decodes pairs of ASCII hex digits from `src` into bytes in `dst` with 256-bit LASX vectors.
///
/// Each `[u8; 2]` element of `src` (high digit, then low digit) produces one byte of `dst`.
/// The main loop decodes 32 pairs per iteration from two vector loads; one extra single-vector
/// step handles a remaining group of 16 pairs. Whatever is left after that is passed on to
/// [`decode_generic_unchecked`] (instantiated with `false`).
///
/// # Errors
///
/// Returns [`InvalidInput`] as soon as a vector-processed group contains a byte that is not
/// `0-9`, `a-f` or `A-F`. Otherwise the result of the generic decoder on the remainder is
/// returned. Groups decoded before the failure have already been written to `dst`.
///
/// # Safety
///
/// - The executing CPU must support the `lasx` target feature (and `lsx`, used for the final
///   128-bit store).
/// - `dst` must have room for at least `src.len()` bytes: neither the vector stores nor the
///   re-slicing of `dst` is bounds-checked.
#[target_feature(enable = "lasx")]
pub(crate) unsafe fn decode_lasx_unchecked(
    mut src: &[[u8; 2]],
    mut dst: &mut [MaybeUninit<u8>],
) -> Result<(), InvalidInput> {
    // Digit pairs decoded (= bytes produced) by one main-loop iteration.
    const FULL: usize = size_of::<m256i>();
    // Digit pairs decoded by the single-vector step after the main loop.
    const HALF: usize = FULL / 2;

    if src.len() < HALF {
        // Too short for any vector work: skip the constant setup entirely.
        return decode_generic_unchecked::<false>(src, dst);
    }

    let digit_bias = lasx_xvreplgr2vr_b((0xFF_u8 - b'9') as i32);
    let six = lasx_xvreplgr2vr_b(0x06);
    let top_nibble = lasx_xvreplgr2vr_b(0xF0);
    let case_fold = lasx_xvreplgr2vr_b(0xDF);
    let upper_a = lasx_xvreplgr2vr_b(b'A' as i32);
    let ten = lasx_xvreplgr2vr_b(0x0A);
    let nibble_max = lasx_xvreplgr2vr_b(0x0F);

    while src.len() >= FULL {
        let input = src.as_ptr().cast();
        let ascii_a = lasx_xvld::<0>(input);
        let ascii_b = lasx_xvld::<{ size_of::<m256i>() as _ }>(input);

        // Decimal path: '0'..='9' become 0..=9, every other byte ends up above 0x0F.
        let dec_a = lasx_xvsub_b(
            lasx_xvssub_bu(lasx_xvadd_b(ascii_a, digit_bias), six),
            top_nibble,
        );
        let dec_b = lasx_xvsub_b(
            lasx_xvssub_bu(lasx_xvadd_b(ascii_b, digit_bias), six),
            top_nibble,
        );

        // Letter path: clearing bit 0x20 folds 'a'..='f' onto 'A'..='F', which become 10..=15;
        // every other byte ends up above 0x0F (the final add saturates).
        let alpha_a =
            lasx_xvsadd_bu(lasx_xvsub_b(lasx_xvand_v(ascii_a, case_fold), upper_a), ten);
        let alpha_b =
            lasx_xvsadd_bu(lasx_xvsub_b(lasx_xvand_v(ascii_b, case_fold), upper_a), ten);

        // At most one path yields a value in 0..=15, so the minimum is the nibble when the
        // byte is a hex digit and stays above 0x0F when it is not.
        let nib_a = lasx_xvmin_bu(dec_a, alpha_a);
        let nib_b = lasx_xvmin_bu(dec_b, alpha_b);

        let worst = lasx_xvmax_bu(nib_a, nib_b);
        let out_of_range = lasx_xvslt_bu(nibble_max, worst);
        if lasx_xbz_v(out_of_range) == 0 {
            return Err(InvalidInput);
        }

        // Even positions hold the high nibbles, odd positions the low nibbles. The picks work
        // inside each 128-bit lane, so the merged 64-bit groups come out as
        // [a.lane0, b.lane0, a.lane1, b.lane1]; swapping the two middle groups restores
        // input order.
        let highs = lasx_xvpickev_b(nib_b, nib_a);
        let lows = lasx_xvpickod_b(nib_b, nib_a);
        let merged = lasx_xvor_v(lasx_xvslli_b::<4>(highs), lows);
        let packed = lasx_xvpermi_d::<0b_11_01_10_00>(merged);
        lasx_xvst::<0>(packed, dst.as_mut_ptr().cast());

        src = &src[FULL..];
        dst = dst.get_unchecked_mut(FULL..);
    }

    if src.len() >= HALF {
        // Same conversion on a single vector (16 pairs).
        let ascii = lasx_xvld::<0>(src.as_ptr().cast());
        let dec = lasx_xvsub_b(
            lasx_xvssub_bu(lasx_xvadd_b(ascii, digit_bias), six),
            top_nibble,
        );
        let alpha = lasx_xvsadd_bu(lasx_xvsub_b(lasx_xvand_v(ascii, case_fold), upper_a), ten);
        let nib = lasx_xvmin_bu(dec, alpha);

        let out_of_range = lasx_xvslt_bu(nibble_max, nib);
        if lasx_xbz_v(out_of_range) == 0 {
            return Err(InvalidInput);
        }

        // With the same vector as both pick operands, the middle-group swap leaves the 16
        // result bytes, in order, in the low 128 bits.
        let highs = lasx_xvpickev_b(nib, nib);
        let lows = lasx_xvpickod_b(nib, nib);
        let merged = lasx_xvor_v(lasx_xvslli_b::<4>(highs), lows);
        let packed = lasx_xvpermi_d::<0b_11_01_10_00>(merged);

        // Reinterpret the 256-bit vector as two 128-bit halves and store only the low one.
        let [low_half, _] = unsafe { mem::transmute::<m256i, [m128i; 2]>(packed) };
        lsx_vst::<0>(low_half, dst.as_mut_ptr().cast());

        src = &src[HALF..];
        dst = dst.get_unchecked_mut(HALF..);
    }

    decode_generic_unchecked::<false>(src, dst)
}
/// Smoke tests for the LSX/LASX encoders and decoders.
///
/// Inputs are sliced to lengths on and around the vector batch sizes so that the main loops,
/// the single-vector trailing steps and the fallback tails are all reached.
#[cfg(test)]
mod smoking {
    use alloc::string::String;
    use alloc::vec;
    use alloc::vec::Vec;
    use core::mem::MaybeUninit;
    use core::{slice, str};

    use super::{
        decode_lasx_unchecked, decode_lsx_unchecked, encode_lasx_unchecked, encode_lsx_unchecked,
    };
    use crate::util::{DIGITS_LOWER_16, DIGITS_UPPER_16};

    /// Round-trips one input: encodes it in lower and upper case with `Encode`, compares both
    /// against strings built from the digit tables, then feeds both encodings to every
    /// `Decode` function and expects the original bytes back.
    macro_rules! test {
        (
            Encode = $encode_f:ident;
            Decode = $($decode_f:ident),*;
            Case = $i:expr
        ) => {{
            let input = $i;

            // Reference encodings, built one nibble at a time from the digit tables.
            let mut expected_lower = String::with_capacity(input.len() * 2);
            let mut expected_upper = String::with_capacity(input.len() * 2);
            for byte in input.iter() {
                let high = (*byte >> 4) as usize;
                let low = (*byte & 0b1111) as usize;
                expected_lower.push(DIGITS_LOWER_16[high] as char);
                expected_lower.push(DIGITS_LOWER_16[low] as char);
                expected_upper.push(DIGITS_UPPER_16[high] as char);
                expected_upper.push(DIGITS_UPPER_16[low] as char);
            }

            let mut buf_lower = vec![[MaybeUninit::<u8>::uninit(); 2]; input.len()];
            let mut buf_upper = vec![[MaybeUninit::<u8>::uninit(); 2]; input.len()];
            unsafe {
                $encode_f::<false>(input, &mut buf_lower);
                $encode_f::<true>(input, &mut buf_upper);
            }

            // The encoder is expected to have written every element, so view the buffers as
            // plain `[u8; 2]` pairs from here on.
            let output_lower = unsafe {
                slice::from_raw_parts(buf_lower.as_ptr().cast::<[u8; 2]>(), buf_lower.len())
            };
            let output_upper = unsafe {
                slice::from_raw_parts(buf_upper.as_ptr().cast::<[u8; 2]>(), buf_upper.len())
            };

            assert_eq!(
                output_lower.as_flattened(),
                expected_lower.as_bytes(),
                "Encode error, expect \"{expected_lower}\", got \"{}\" ({:?})",
                str::from_utf8(output_lower.as_flattened()).unwrap_or("<invalid utf-8>"),
                output_lower.as_flattened()
            );
            assert_eq!(
                output_upper.as_flattened(),
                expected_upper.as_bytes(),
                "Encode error, expect \"{expected_upper}\", got \"{}\" ({:?})",
                str::from_utf8(output_upper.as_flattened()).unwrap_or("<invalid utf-8>"),
                output_upper.as_flattened()
            );

            $({
                let mut decoded_lower = vec![MaybeUninit::<u8>::uninit(); input.len()];
                let mut decoded_upper = vec![MaybeUninit::<u8>::uninit(); input.len()];
                unsafe {
                    $decode_f(output_lower, &mut decoded_lower).unwrap();
                    $decode_f(output_upper, &mut decoded_upper).unwrap();

                    assert_eq!(
                        decoded_lower.assume_init_ref(),
                        input,
                        "Decode error for {}, expect {:?}, got {:?}",
                        stringify!($decode_f),
                        input,
                        decoded_lower.assume_init_ref()
                    );
                    assert_eq!(
                        decoded_upper.assume_init_ref(),
                        input,
                        "Decode error for {}, expect {:?}, got {:?}",
                        stringify!($decode_f),
                        input,
                        decoded_upper.assume_init_ref()
                    );
                }
            })*
        }};
    }

    /// Round-trips through the LSX encoder/decoder at lengths around 16 and 32 bytes.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_lsx() {
        const CASE: &[u8; 33] = &[
            0x62, 0xBE, 0x66, 0xE0, 0x1C, 0x1E, 0xFB, 0x43, 0x16, 0xA0, 0x9F, 0x8A, 0xE4, 0x93,
            0xE3, 0x7F, 0x23, 0x9F, 0x0D, 0xEF, 0x94, 0x25, 0xE0, 0x60, 0x62, 0xBA, 0x10, 0xB2,
            0x7B, 0xB6, 0x2B, 0xFB, 0x44,
        ];

        for len in [0, 15, 16, 17, 31, 32, 33] {
            test! {
                Encode = encode_lsx_unchecked;
                Decode = decode_lsx_unchecked;
                Case = &CASE[..len]
            }
        }
    }

    /// Round-trips through the LASX encoder and both decoders at lengths around 32 and 64
    /// bytes.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_lasx() {
        const CASE: &[u8; 65] = &[
            0xA1, 0xA4, 0xA2, 0x49, 0x4A, 0x43, 0x03, 0x31, 0x5F, 0x60, 0xE7, 0x8F, 0x17, 0x36,
            0x31, 0xAD, 0xB3, 0xE4, 0xF2, 0x35, 0x33, 0x6F, 0x05, 0xF0, 0xAA, 0x52, 0xD2, 0x6F,
            0x3A, 0xB7, 0x4A, 0xAB, 0x66, 0x32, 0xB0, 0xD6, 0x1C, 0x8C, 0xED, 0x85, 0x9E, 0x03,
            0x90, 0x87, 0x16, 0x9C, 0xBA, 0x34, 0xAD, 0x59, 0x35, 0x66, 0xED, 0x80, 0x22, 0x85,
            0xDB, 0x54, 0x5E, 0x79, 0xD3, 0x9A, 0x6F, 0x24, 0x43,
        ];

        for len in [0, 31, 32, 33, 63, 64, 65] {
            test! {
                Encode = encode_lasx_unchecked;
                Decode = decode_lasx_unchecked, decode_lsx_unchecked;
                Case = &CASE[..len]
            }
        }
    }

    /// Plants every possible byte value in the middle of an otherwise valid all-`a` input and
    /// checks that both decoders accept it exactly when it is an ASCII hex digit.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_validation() {
        for l in [15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129] {
            for c in 0u8..=255 {
                let mut text = vec![b'a'; l * 2];
                text[l] = c;
                let pairs = unsafe { text.as_chunks_unchecked() };
                let should_decode = c.is_ascii_hexdigit();

                let lsx = unsafe {
                    decode_lsx_unchecked(pairs, Vec::with_capacity(l).spare_capacity_mut())
                };
                assert!(
                    lsx.is_ok() == should_decode,
                    "lsx validation failed for byte {c} (l={l})"
                );

                let lasx = unsafe {
                    decode_lasx_unchecked(pairs, Vec::with_capacity(l).spare_capacity_mut())
                };
                assert!(
                    lasx.is_ok() == should_decode,
                    "lasx validation failed for byte {c} (l={l})"
                );
            }
        }
    }
}