fashex 0.0.3

Conversion from bytes to hexadecimal string.
Documentation
//! Optimized implementations for `aarch64`.

#![allow(unsafe_code, reason = "SIMD")]
#![allow(unsafe_op_in_unsafe_fn, reason = "SIMD")]

use core::arch::aarch64::*;
use core::mem::MaybeUninit;

use crate::backend::generic::{decode_generic_unchecked, encode_generic_unchecked};
use crate::error::InvalidInput;
use crate::util::digits16;

/// Hex-encodes `src` into `dst` with NEON, 16 input bytes per loop iteration.
///
/// Every input byte yields one two-byte entry of `dst`: the digit for its high nibble followed by
/// the digit for its low nibble, both looked up in the table returned by `digits16::<UPPER>()`.
/// Input left over after the last full batch is handed to `encode_generic_unchecked`.
///
/// # Safety
///
/// The caller must guarantee that:
///
/// 1. The CPU supports `neon`.
/// 2. `src.len() == dst.len()`.
#[target_feature(enable = "neon")]
pub(crate) unsafe fn encode_neon_unchecked<const UPPER: bool>(
    mut src: &[u8],
    mut dst: &mut [[MaybeUninit<u8>; 2]],
) {
    /// Input bytes consumed per iteration (16); each iteration writes 16 * 2 output bytes.
    const BATCH: usize = size_of::<uint8x16_t>();

    if src.len() >= BATCH {
        // Loop invariants, only materialised when at least one full batch exists.
        let low_nibble_mask = vdupq_n_u8(0x0F);
        let digit_table = vld1q_u8(digits16::<UPPER>().as_ptr());

        while let Some((batch, rest)) = src.split_first_chunk::<BATCH>() {
            let bytes: uint8x16_t = vld1q_u8(batch.as_ptr());

            // hi_digits[i] = digit_table[bytes[i] >> 4]
            let hi_digits = vqtbl1q_u8(digit_table, vshrq_n_u8(bytes, 4));
            // lo_digits[i] = digit_table[bytes[i] & 0x0F]
            let lo_digits = vqtbl1q_u8(digit_table, vandq_u8(bytes, low_nibble_mask));

            // Interleave into [hi[0], lo[0], hi[1], lo[1], ...] so each byte's two digits are
            // adjacent in the output.
            let interleaved = vzipq_u8(hi_digits, lo_digits);

            // One store covers the 16 two-byte entries at the front of `dst`.
            vst1q_u8_x2(dst.as_mut_ptr().cast(), interleaved);

            src = rest;
            dst = dst.get_unchecked_mut(BATCH..);
        }
    }

    // Fewer than `BATCH` bytes remain (possibly none).
    encode_generic_unchecked::<UPPER>(src, dst);
}

/// Decodes ASCII hex digit pairs from `src` into bytes in `dst` with NEON.
///
/// Full batches of 16 pairs are decoded in a loop, followed by at most one half-width step of
/// 8 pairs; anything still left is handed to `decode_generic_unchecked`. Both upper- and
/// lower-case digits are accepted by the vectorised paths.
///
/// # Errors
///
/// Returns `InvalidInput` when a vectorised batch contains a byte that is not an ASCII hex digit.
/// The result of `decode_generic_unchecked` on the remaining tail is returned unchanged.
///
/// # Safety
///
/// The caller must guarantee that:
///
/// 1. The CPU supports `neon`.
/// 2. `src.len() == dst.len()`.
#[target_feature(enable = "neon")]
pub(crate) unsafe fn decode_neon_unchecked(
    mut src: &[[u8; 2]],
    mut dst: &mut [MaybeUninit<u8>],
) -> Result<(), InvalidInput> {
    /// Pairs consumed per full iteration: 16 * 2 input bytes in, 16 output bytes out.
    const BATCH: usize = size_of::<uint8x16_t>();

    /// Pairs consumed by the half-width step: 8 * 2 input bytes in, 8 output bytes out.
    const TRAILING_BATCH: usize = BATCH / 2;

    // Too short for even the half-width step: the scalar decoder handles everything.
    if src.len() < TRAILING_BATCH {
        return decode_generic_unchecked::<false>(src, dst);
    }

    // Digit path: wrapping `+ 0xC6` lifts '0'..='9' to 0xF6..=0xFF, saturating `- 6` gives
    // 0xF0..=0xF9, wrapping `- 0xF0` gives 0..=9. Every other byte ends up above 15.
    let digit_bias = vdupq_n_u8(0xFF_u8 - b'9');
    let digit_gap = vdupq_n_u8(0x06);
    let digit_base = vdupq_n_u8(0xF0);

    // Letter path: `& 0xDF` clears the ASCII case bit, wrapping `- b'A'` gives 0..=5 for
    // 'A'..='F' / 'a'..='f', saturating `+ 10` gives 10..=15. Every other byte ends up above 15.
    let case_fold = vdupq_n_u8(0xDF);
    let letter_a = vdupq_n_u8(b'A');
    let ten = vdupq_n_u8(0x0A);

    while let Some((batch, rest)) = src.split_first_chunk::<BATCH>() {
        let uint8x16x2_t(ascii_a, ascii_b) = vld1q_u8_x2(batch.as_ptr().cast::<u8>());

        // First 16 characters: the smaller of the two candidates is the nibble when the
        // character is valid, and stays above 15 otherwise.
        let digits_a = vsubq_u8(vqsubq_u8(vaddq_u8(ascii_a, digit_bias), digit_gap), digit_base);
        let letters_a = vqaddq_u8(vsubq_u8(vandq_u8(ascii_a, case_fold), letter_a), ten);
        let nibbles_a = vminq_u8(digits_a, letters_a);

        // Second 16 characters, same treatment.
        let digits_b = vsubq_u8(vqsubq_u8(vaddq_u8(ascii_b, digit_bias), digit_gap), digit_base);
        let letters_b = vqaddq_u8(vsubq_u8(vandq_u8(ascii_b, case_fold), letter_a), ten);
        let nibbles_b = vminq_u8(digits_b, letters_b);

        // A lane above 15 means the corresponding character was not a hex digit.
        let worst = vmaxvq_u8(nibbles_a).max(vmaxvq_u8(nibbles_b));
        if worst > 0x0F {
            return Err(InvalidInput);
        }

        // De-interleave: even lanes are the high nibbles, odd lanes the low nibbles.
        let uint8x16x2_t(hi, lo) = vuzpq_u8(nibbles_a, nibbles_b);
        let packed = vorrq_u8(vshlq_n_u8(hi, 4), lo);

        vst1q_u8(dst.as_mut_ptr().cast::<u8>(), packed);

        src = rest;
        dst = dst.get_unchecked_mut(BATCH..);
    }

    // At most one half-width step: 16 characters in, 8 bytes out.
    if let Some((batch, rest)) = src.split_first_chunk::<TRAILING_BATCH>() {
        let ascii = vld1q_u8(batch.as_ptr().cast::<u8>());

        let digits = vsubq_u8(vqsubq_u8(vaddq_u8(ascii, digit_bias), digit_gap), digit_base);
        let letters = vqaddq_u8(vsubq_u8(vandq_u8(ascii, case_fold), letter_a), ten);
        let nibbles = vminq_u8(digits, letters);

        if vmaxvq_u8(nibbles) > 0x0F {
            return Err(InvalidInput);
        }

        // Un-zipping the vector with itself puts the 8 high nibbles in the low half of `hi`
        // and the 8 low nibbles in the low half of `lo`.
        let uint8x16x2_t(hi, lo) = vuzpq_u8(nibbles, nibbles);
        let packed = vorr_u8(vshl_n_u8(vget_low_u8(hi), 4), vget_low_u8(lo));

        vst1_u8(dst.as_mut_ptr().cast::<u8>(), packed);

        src = rest;
        dst = dst.get_unchecked_mut(TRAILING_BATCH..);
    }

    // Fewer than `TRAILING_BATCH` pairs remain (possibly none).
    decode_generic_unchecked::<false>(src, dst)
}

#[cfg(test)]
/// Smoke tests for the NEON encoder and decoder: round trips at lengths around the batch
/// boundaries, and validation of every possible byte value.
mod smoking {
    use alloc::string::String;
    use alloc::vec;
    use alloc::vec::Vec;
    use core::mem::MaybeUninit;
    use core::{slice, str};

    use super::*;
    use crate::util::{DIGITS_LOWER_16, DIGITS_UPPER_16};

    // Round-trip check for one input.
    //
    // Encodes `Case` with `Encode` in both lower and upper case and compares against a scalar
    // reference built from `DIGITS_LOWER_16` / `DIGITS_UPPER_16`, then feeds both encodings to
    // every function listed in `Decode` and expects the original input back.
    macro_rules! test {
        (
            Encode = $encode_f:ident;
            Decode = $($decode_f:ident),*;
            Case = $i:expr
        ) => {{
            let input = $i;

            // Scalar reference encodings.
            let expected_lower = input
                .iter()
                .flat_map(|b| [
                    DIGITS_LOWER_16[(*b >> 4) as usize] as char,
                    DIGITS_LOWER_16[(*b & 0b1111) as usize] as char,
                ])
                .collect::<String>();
            let expected_upper = input
                .iter()
                .flat_map(|b| [
                    DIGITS_UPPER_16[(*b >> 4) as usize] as char,
                    DIGITS_UPPER_16[(*b & 0b1111) as usize] as char,
                ])
                .collect::<String>();

            let mut output_lower = vec![[MaybeUninit::<u8>::uninit(); 2]; input.len()];
            let mut output_upper = vec![[MaybeUninit::<u8>::uninit(); 2]; input.len()];

            // SAFETY: both output buffers have exactly `input.len()` entries. NEON availability
            // is assumed here, as these tests call the NEON routines directly.
            unsafe {
                $encode_f::<false>(input, &mut output_lower);
                $encode_f::<true>(input, &mut output_upper);
            }

            // SAFETY: the encoder is expected to have initialised every entry, and
            // `[MaybeUninit<u8>; 2]` has the same layout as `[u8; 2]`.
            let output_lower = unsafe {
                slice::from_raw_parts(
                    output_lower.as_ptr().cast::<[u8; 2]>(),
                    output_lower.len(),
                )
            };
            let output_upper = unsafe {
                slice::from_raw_parts(
                    output_upper.as_ptr().cast::<[u8; 2]>(),
                    output_upper.len(),
                )
            };

            assert_eq!(
                output_lower.as_flattened(),
                expected_lower.as_bytes(),
                "Encode error, expect \"{expected_lower}\", got \"{}\" ({:?})",
                str::from_utf8(output_lower.as_flattened()).unwrap_or("<invalid utf-8>"),
                output_lower.as_flattened()
            );
            assert_eq!(
                output_upper.as_flattened(),
                expected_upper.as_bytes(),
                "Encode error, expect \"{expected_upper}\", got \"{}\" ({:?})",
                str::from_utf8(output_upper.as_flattened()).unwrap_or("<invalid utf-8>"),
                output_upper.as_flattened()
            );

            $({
                let mut decoded_lower = vec![MaybeUninit::<u8>::uninit(); input.len()];
                let mut decoded_upper = vec![MaybeUninit::<u8>::uninit(); input.len()];

                // SAFETY: source and destination both hold `input.len()` elements, and a
                // successful decode is expected to have initialised the whole destination
                // before it is read back.
                unsafe {
                    $decode_f(output_lower, &mut decoded_lower).unwrap();
                    $decode_f(output_upper, &mut decoded_upper).unwrap();

                    assert_eq!(
                        decoded_lower.assume_init_ref(),
                        input,
                        "Decode error for {}, expect {:?}, got {:?}",
                        stringify!($decode_f),
                        input,
                        decoded_lower.assume_init_ref()
                    );
                    assert_eq!(
                        decoded_upper.assume_init_ref(),
                        input,
                        "Decode error for {}, expect {:?}, got {:?}",
                        stringify!($decode_f),
                        input,
                        decoded_upper.assume_init_ref()
                    );
                }
            })*
        }};
    }

    /// Round-trips inputs whose lengths sit on and around the 16-byte batch boundaries.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_neon() {
        const CASE: &[u8; 33] = &[
            0x62, 0xBE, 0x66, 0xE0, 0x1C, 0x1E, 0xFB, 0x43, 0x16, 0xA0, 0x9F, 0x8A, 0xE4, 0x93,
            0xE3, 0x7F, 0x23, 0x9F, 0x0D, 0xEF, 0x94, 0x25, 0xE0, 0x60, 0x62, 0xBA, 0x10, 0xB2,
            0x7B, 0xB6, 0x2B, 0xFB, 0x44,
        ];

        // Empty input, then one below / exactly / one above one and two full batches.
        // Slicing `CASE` (also for the empty case) gives every input the concrete type `&[u8]`.
        for len in [0, 15, 16, 17, 31, 32, 33] {
            test! {
                Encode = encode_neon_unchecked;
                Decode = decode_neon_unchecked;
                Case = &CASE[..len]
            };
        }
    }

    /// Plants every possible byte value into otherwise valid input and checks that decoding
    /// succeeds exactly when that byte is an ASCII hex digit.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_validation() {
        for l in [15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129] {
            for c in 0u8..=255 {
                let mut bytes = vec![b'a'; l * 2];

                // Byte index `l` is the middle of the `2 * l` input characters.
                bytes[l] = c;

                // `2 * l` bytes always split into exactly `l` pairs.
                let (pairs, rest) = bytes.as_chunks::<2>();
                assert!(rest.is_empty());

                // `Vec::with_capacity` may allocate more than requested, so trim the spare
                // capacity to exactly `l` to honour `src.len() == dst.len()`.
                let mut sink = Vec::<u8>::with_capacity(l);
                let dst = &mut sink.spare_capacity_mut()[..l];

                // SAFETY: `pairs` and `dst` both hold exactly `l` elements. NEON availability
                // is assumed here, as this test calls the NEON decoder directly.
                let result = unsafe { decode_neon_unchecked(pairs, dst) };

                assert_eq!(
                    result.is_ok(),
                    c.is_ascii_hexdigit(),
                    "neon validation failed for byte {c} (l={l})"
                );
            }
        }
    }
}